Compare commits

..

59 Commits
1.2.3 ... 1.2.5

Author SHA1 Message Date
lanvent
a52f54d988 docs(wechatmp): Update README.md 2023-04-22 12:15:56 +08:00
lanvent
618c94edb8 formatting: run precommit on all files 2023-04-22 12:01:29 +08:00
lanvent
eaf4e9174f style(linting): increase max-line-length to 176
The max-line-length configuration was increased to 176 in both .flake8 and pyproject.toml files to allow for longer lines of code.
2023-04-22 11:59:12 +08:00
lanvent
4af2c7f3d7 fix: escape regex pattern 2023-04-22 11:39:59 +08:00
lanvent
361f599df0 fix: escape regex patterns when matching name 2023-04-22 11:29:41 +08:00
Jianglang
ffe4ea5e4c Update README.md 2023-04-22 11:12:30 +08:00
Jianglang
9461e3e01a Merge pull request #912 from zhayujie/wechatmp
公众号功能优化:支持图片输入、消息加密模式、用户体验优化
2023-04-22 11:08:08 +08:00
lanvent
7c85c6f742 feat(wechatmp): add support for message encryption
- Add support for message encryption in WeChat MP channel.
- Add `wechatmp_aes_key` configuration item to `config.json`.
2023-04-22 02:33:51 +08:00
lanvent
b5df6faadf feat: verify server when receive message in wechatmp 2023-04-22 01:30:21 +08:00
lanvent
7cefe2d825 fix: split long text messages into multiple parts in wechatmp_service 2023-04-21 21:03:38 +08:00
lanvent
350633b69b Merge Purll Request #920 into wechatmp 2023-04-21 20:46:16 +08:00
JS00000
1cd6a71ce0 fix the bug of pytts in linux 2023-04-21 18:31:20 +08:00
JS00000
3a08b002a0 Merge remote-tracking branch 'origin/wechatmp' into wechatmp 2023-04-21 16:20:57 +08:00
lanvent
cca49da730 fix: fix subscribe_msg 2023-04-21 13:49:51 +08:00
lanvent
f6d370ad29 fix: check if event is subscribe 2023-04-21 13:43:01 +08:00
lanvent
c9131b333b feat: add clear_quota_v2 method to clear API quota when it's used up 2023-04-21 13:41:21 +08:00
lanvent
e44161bf42 fix: voice_reply_voice not work 2023-04-21 03:28:31 +08:00
lanvent
a26189fb25 chore: remove passive_reply_message.py 2023-04-21 03:04:50 +08:00
lanvent
89dd8a1db6 refactor(wechatmp): use wechatpy to handle wechatmp messages
feat(wechatmp): add support for image and voice replies
2023-04-21 02:47:33 +08:00
JS00000
650e0b4ad4 wechatmp: adjust log 2023-04-21 02:16:13 +08:00
lanvent
c60f0517fb refactor(audio_convert.py): remove redundant functions 2023-04-20 23:22:08 +08:00
lanvent
0f8dc91a8b fix: add check for empty command and return error message if so 2023-04-20 23:13:07 +08:00
lanvent
b58feb5d8e Merge Pull Request #904 into master 2023-04-20 23:06:17 +08:00
JS00000
71c8043699 update README 2023-04-20 12:35:54 +08:00
JS00000
40264bc9cb fix: delete permanent media 2023-04-20 12:03:48 +08:00
JS00000
a7772316f9 feat: wechatmp channel support voice/image reply 2023-04-20 10:26:58 +08:00
JS00000
34209021c8 fix: pytts second round not work 2023-04-20 09:04:42 +08:00
JS00000
1e58c1ad2b fix: wechatmp channel now do not need client 2023-04-20 04:35:06 +08:00
JS00000
8cea022ec5 Merge branch 'master' into wechatmp 2023-04-20 03:41:37 +08:00
JS00000
f32f8aa08e Update readme, and make the structure more clear 2023-04-20 03:18:21 +08:00
goldfish菌
0a7d6e4577 plugin(tool) ver0.4.1 (#891)
* plugin(tool) fix bugs

* plugin(tool) tool插件更新至0.4.1 版本
2023-04-19 10:05:28 +08:00
JS00000
df4c1f0401 wechatmp: logic simplification 2023-04-19 01:56:25 +08:00
JS00000
9a86a67984 update readme 2023-04-19 01:54:20 +08:00
lanvent
a0cbe9c3e2 feat(azure_voice.py): improve error logging in voiceToText method 2023-04-19 00:55:22 +08:00
lanvent
a83e5a9b65 feat(azure_voice.py): improve error logging in textToVoice method 2023-04-19 00:51:52 +08:00
lanvent
de33911460 feat: add support for PATPAT context 2023-04-18 23:34:08 +08:00
lanvent
0be56e5b25 Merge branch Pull Request #882 into master 2023-04-18 14:26:16 +08:00
lanvent
abcbb34b1c fix(chat_gpt_bot.py, open_ai_bot.py): increase retry time to 20 seconds when encountering RateLimitError 2023-04-18 14:18:22 +08:00
林督翔
6a13dd04a3 feat(插件开发):新增关键字匹配插件 2023-04-18 13:57:20 +08:00
lanvent
f2e29f3f2e fix: banwords help 2023-04-18 11:43:34 +08:00
JS00000
68361cddd2 wechatmp_service: image and voice reply supported 2023-04-18 03:08:18 +08:00
lanvent
6404332adc feat: itchat support joingroup message 2023-04-18 02:21:41 +08:00
JS00000
e060b6fea2 Merge branch 'master' into wechatmp 2023-04-17 20:11:41 +08:00
lanvent
e8aae27ee9 fix: missing lib in banwords 2023-04-17 15:41:29 +08:00
lanvent
2f732e5493 fix: toolhub request_timeout should be str 2023-04-17 12:00:28 +08:00
lanvent
65f20ff2c1 Merge Pull Request #860 into master 2023-04-17 01:24:39 +08:00
lanvent
8f72e8c3e6 formatting code 2023-04-17 01:01:02 +08:00
lanvent
3b8972ce1f add pre-commit hook 2023-04-17 00:57:48 +08:00
李超
fc5d3e4e9c feat: Make the size parameter of the resulting picture configurable 2023-04-16 22:31:53 +08:00
李超
29fbf69945 feat: Add configuration items to support custom data directories and facilitate the storage of itchat.pkl 2023-04-16 22:31:53 +08:00
lanvent
583440b82b banwords: move WordsSearch to lib 2023-04-16 19:04:21 +08:00
lanvent
720de9d73f chore: strip content 2023-04-16 00:47:32 +08:00
JS00000
7fb4f72b84 update wechatmp README 2023-04-12 05:52:26 +08:00
JS00000
d4fc322101 Merge branch 'master' into wechatmp 2023-04-12 05:43:05 +08:00
JS00000
8fa3da9ca5 wechatmp: voice input support 2023-04-12 05:41:48 +08:00
JS00000
68ef5aa3ae ctrl+c exit 2023-04-12 05:35:31 +08:00
JS00000
15e6cf850b Merge branch 'master' into wechatmp 2023-04-10 18:57:01 +08:00
JS00000
f687b2b6f4 remove _success_callback 2023-04-09 18:32:09 +08:00
JS00000
8ee7a48151 fix: wechatmp's deadloop when reply is None 2023-04-09 18:00:34 +08:00
108 changed files with 2247 additions and 1571 deletions

13
.flake8 Normal file
View File

@@ -0,0 +1,13 @@
[flake8]
max-line-length = 176
select = E303,W293,W291,W292,E305,E231,E302
exclude =
.tox,
__pycache__,
*.pyc,
.env
venv/*
.venv/*
reports/*
dist/*
lib/*

View File

@@ -27,5 +27,5 @@
### 环境 ### 环境
- 操作系统类型 (Mac/Windows/Linux) - 操作系统类型 (Mac/Windows/Linux)
- Python版本 ( 执行 `python3 -V` ) - Python版本 ( 执行 `python3 -V` )
- pip版本 ( 依赖问题此项必填,执行 `pip3 -V`) - pip版本 ( 依赖问题此项必填,执行 `pip3 -V`)

View File

@@ -49,9 +49,9 @@ jobs:
file: ./docker/Dockerfile.latest file: ./docker/Dockerfile.latest
tags: ${{ steps.meta.outputs.tags }} tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}
- uses: actions/delete-package-versions@v4 - uses: actions/delete-package-versions@v4
with: with:
package-name: 'chatgpt-on-wechat' package-name: 'chatgpt-on-wechat'
package-type: 'container' package-type: 'container'
min-versions-to-keep: 10 min-versions-to-keep: 10

5
.gitignore vendored
View File

@@ -13,6 +13,7 @@ plugins.json
itchat.pkl itchat.pkl
*.log *.log
user_datas.pkl user_datas.pkl
chatgpt_tool_hub/
plugins/**/ plugins/**/
!plugins/bdunit !plugins/bdunit
!plugins/dungeon !plugins/dungeon
@@ -20,5 +21,7 @@ plugins/**/
!plugins/godcmd !plugins/godcmd
!plugins/tool !plugins/tool
!plugins/banwords !plugins/banwords
!plugins/banwords/**/
!plugins/hello !plugins/hello
!plugins/role !plugins/role
!plugins/keyword

30
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,30 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: fix-byte-order-marker
- id: check-case-conflict
- id: check-merge-conflict
- id: debug-statements
- id: pretty-format-json
types: [text]
files: \.json(.template)?$
args: [ --autofix , --no-ensure-ascii, --indent=2, --no-sort-keys]
- id: trailing-whitespace
exclude: '(\/|^)lib\/'
args: [ --markdown-linebreak-ext=md ]
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
exclude: '(\/|^)lib\/'
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
exclude: '(\/|^)lib\/'
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
hooks:
- id: flake8
exclude: '(\/|^)lib\/'

View File

@@ -22,7 +22,7 @@
# 更新日志 # 更新日志
>**2023.04.05** 支持微信个人号部署,兼容角色扮演等预设插件,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatmp/README.md)。(contributed by [@JS00000](https://github.com/JS00000) in [#686](https://github.com/zhayujie/chatgpt-on-wechat/pull/686)) >**2023.04.05** 支持微信公众号部署,兼容角色扮演等预设插件,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatmp/README.md)。(contributed by [@JS00000](https://github.com/JS00000) in [#686](https://github.com/zhayujie/chatgpt-on-wechat/pull/686))
>**2023.04.05** 增加能让ChatGPT使用工具的`tool`插件,[使用文档](https://github.com/goldfishh/chatgpt-on-wechat/blob/master/plugins/tool/README.md)。工具相关issue可反馈至[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)。(contributed by [@goldfishh](https://github.com/goldfishh) in [#663](https://github.com/zhayujie/chatgpt-on-wechat/pull/663)) >**2023.04.05** 增加能让ChatGPT使用工具的`tool`插件,[使用文档](https://github.com/goldfishh/chatgpt-on-wechat/blob/master/plugins/tool/README.md)。工具相关issue可反馈至[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)。(contributed by [@goldfishh](https://github.com/goldfishh) in [#663](https://github.com/zhayujie/chatgpt-on-wechat/pull/663))
@@ -120,7 +120,7 @@ pip3 install azure-cognitiveservices-speech
```bash ```bash
# config.json文件内容示例 # config.json文件内容示例
{ {
"open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY "open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY
"model": "gpt-3.5-turbo", # 模型名称。当use_azure_chatgpt为true时其名称为Azure上model deployment名称 "model": "gpt-3.5-turbo", # 模型名称。当use_azure_chatgpt为true时其名称为Azure上model deployment名称
"proxy": "127.0.0.1:7890", # 代理客户端的ip和端口 "proxy": "127.0.0.1:7890", # 代理客户端的ip和端口
@@ -128,7 +128,7 @@ pip3 install azure-cognitiveservices-speech
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 "single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复 "group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表 "group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称 "group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀 "image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数 "conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
"speech_recognition": false, # 是否开启语音识别 "speech_recognition": false, # 是否开启语音识别
@@ -160,7 +160,7 @@ pip3 install azure-cognitiveservices-speech
**4.其他配置** **4.其他配置**
+ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k` (其中gpt-4 api暂未开放) + `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k` (其中gpt-4 api暂未开放)
+ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat) + `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat)
+ `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351) + `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
+ 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix ` + 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix `
+ 关于OpenAI对话及图片接口的参数配置内容自由度、回复字数限制、图片大小等可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/completions) 文档直接在 [代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/bot/openai/open_ai_bot.py) `bot/openai/open_ai_bot.py` 中进行调整。 + 关于OpenAI对话及图片接口的参数配置内容自由度、回复字数限制、图片大小等可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/completions) 文档直接在 [代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/bot/openai/open_ai_bot.py) `bot/openai/open_ai_bot.py` 中进行调整。
@@ -181,7 +181,7 @@ pip3 install azure-cognitiveservices-speech
```bash ```bash
python3 app.py python3 app.py
``` ```
终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。 终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。
### 2.服务器部署 ### 2.服务器部署
@@ -189,7 +189,7 @@ python3 app.py
使用nohup命令在后台运行程序 使用nohup命令在后台运行程序
```bash ```bash
touch nohup.out # 首次运行需要新建日志文件 touch nohup.out # 首次运行需要新建日志文件
nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通过日志输出二维码 nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通过日志输出二维码
``` ```
扫码登录后程序即可运行于服务器后台,此时可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。此外,`scripts` 目录下有一键运行、关闭程序的脚本供使用。 扫码登录后程序即可运行于服务器后台,此时可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。此外,`scripts` 目录下有一键运行、关闭程序的脚本供使用。

30
app.py
View File

@@ -1,23 +1,28 @@
# encoding:utf-8 # encoding:utf-8
import os import os
from config import conf, load_config
from channel import channel_factory
from common.log import logger
from plugins import *
import signal import signal
import sys import sys
from channel import channel_factory
from common.log import logger
from config import conf, load_config
from plugins import *
def sigterm_handler_wrap(_signo): def sigterm_handler_wrap(_signo):
old_handler = signal.getsignal(_signo) old_handler = signal.getsignal(_signo)
def func(_signo, _stack_frame): def func(_signo, _stack_frame):
logger.info("signal {} received, exiting...".format(_signo)) logger.info("signal {} received, exiting...".format(_signo))
conf().save_user_datas() conf().save_user_datas()
if callable(old_handler): # check old_handler if callable(old_handler): # check old_handler
return old_handler(_signo, _stack_frame) return old_handler(_signo, _stack_frame)
sys.exit(0) sys.exit(0)
signal.signal(_signo, func) signal.signal(_signo, func)
def run(): def run():
try: try:
# load config # load config
@@ -28,17 +33,17 @@ def run():
sigterm_handler_wrap(signal.SIGTERM) sigterm_handler_wrap(signal.SIGTERM)
# create channel # create channel
channel_name=conf().get('channel_type', 'wx') channel_name = conf().get("channel_type", "wx")
if "--cmd" in sys.argv: if "--cmd" in sys.argv:
channel_name = 'terminal' channel_name = "terminal"
if channel_name == 'wxy': if channel_name == "wxy":
os.environ['WECHATY_LOG']="warn" os.environ["WECHATY_LOG"] = "warn"
# os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001' # os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
channel = channel_factory.create_channel(channel_name) channel = channel_factory.create_channel(channel_name)
if channel_name in ['wx','wxy','terminal','wechatmp','wechatmp_service']: if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service"]:
PluginManager().load_plugins() PluginManager().load_plugins()
# startup channel # startup channel
@@ -47,5 +52,6 @@ def run():
logger.error("App startup failed!") logger.error("App startup failed!")
logger.exception(e) logger.exception(e)
if __name__ == '__main__':
run() if __name__ == "__main__":
run()

View File

@@ -1,6 +1,7 @@
# encoding:utf-8 # encoding:utf-8
import requests import requests
from bot.bot import Bot from bot.bot import Bot
from bridge.reply import Reply, ReplyType from bridge.reply import Reply, ReplyType
@@ -9,20 +10,27 @@ from bridge.reply import Reply, ReplyType
class BaiduUnitBot(Bot): class BaiduUnitBot(Bot):
def reply(self, query, context=None): def reply(self, query, context=None):
token = self.get_token() token = self.get_token()
url = 'https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=' + token url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + token
post_data = "{\"version\":\"3.0\",\"service_id\":\"S73177\",\"session_id\":\"\",\"log_id\":\"7758521\",\"skill_ids\":[\"1221886\"],\"request\":{\"terminal_id\":\"88888\",\"query\":\"" + query + "\", \"hyper_params\": {\"chat_custom_bot_profile\": 1}}}" post_data = (
'{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"'
+ query
+ '", "hyper_params": {"chat_custom_bot_profile": 1}}}'
)
print(post_data) print(post_data)
headers = {'content-type': 'application/x-www-form-urlencoded'} headers = {"content-type": "application/x-www-form-urlencoded"}
response = requests.post(url, data=post_data.encode(), headers=headers) response = requests.post(url, data=post_data.encode(), headers=headers)
if response: if response:
reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1]) reply = Reply(
ReplyType.TEXT,
response.json()["result"]["context"]["SYS_PRESUMED_HIST"][1],
)
return reply return reply
def get_token(self): def get_token(self):
access_key = 'YOUR_ACCESS_KEY' access_key = "YOUR_ACCESS_KEY"
secret_key = 'YOUR_SECRET_KEY' secret_key = "YOUR_SECRET_KEY"
host = 'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=' + access_key + '&client_secret=' + secret_key host = "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + access_key + "&client_secret=" + secret_key
response = requests.get(host) response = requests.get(host)
if response: if response:
print(response.json()) print(response.json())
return response.json()['access_token'] return response.json()["access_token"]

View File

@@ -8,7 +8,7 @@ from bridge.reply import Reply
class Bot(object): class Bot(object):
def reply(self, query, context : Context =None) -> Reply: def reply(self, query, context: Context = None) -> Reply:
""" """
bot auto-reply content bot auto-reply content
:param req: received message :param req: received message

View File

@@ -13,20 +13,24 @@ def create_bot(bot_type):
if bot_type == const.BAIDU: if bot_type == const.BAIDU:
# Baidu Unit对话接口 # Baidu Unit对话接口
from bot.baidu.baidu_unit_bot import BaiduUnitBot from bot.baidu.baidu_unit_bot import BaiduUnitBot
return BaiduUnitBot() return BaiduUnitBot()
elif bot_type == const.CHATGPT: elif bot_type == const.CHATGPT:
# ChatGPT 网页端web接口 # ChatGPT 网页端web接口
from bot.chatgpt.chat_gpt_bot import ChatGPTBot from bot.chatgpt.chat_gpt_bot import ChatGPTBot
return ChatGPTBot() return ChatGPTBot()
elif bot_type == const.OPEN_AI: elif bot_type == const.OPEN_AI:
# OpenAI 官方对话模型API # OpenAI 官方对话模型API
from bot.openai.open_ai_bot import OpenAIBot from bot.openai.open_ai_bot import OpenAIBot
return OpenAIBot() return OpenAIBot()
elif bot_type == const.CHATGPTONAZURE: elif bot_type == const.CHATGPTONAZURE:
# Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/ # Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot
return AzureChatGPTBot() return AzureChatGPTBot()
raise RuntimeError raise RuntimeError

View File

@@ -1,42 +1,45 @@
# encoding:utf-8 # encoding:utf-8
import time
import openai
import openai.error
from bot.bot import Bot from bot.bot import Bot
from bot.chatgpt.chat_gpt_session import ChatGPTSession from bot.chatgpt.chat_gpt_session import ChatGPTSession
from bot.openai.open_ai_image import OpenAIImage from bot.openai.open_ai_image import OpenAIImage
from bot.session_manager import SessionManager from bot.session_manager import SessionManager
from bridge.context import ContextType from bridge.context import ContextType
from bridge.reply import Reply, ReplyType from bridge.reply import Reply, ReplyType
from config import conf, load_config
from common.log import logger from common.log import logger
from common.token_bucket import TokenBucket from common.token_bucket import TokenBucket
import openai from config import conf, load_config
import openai.error
import time
# OpenAI对话模型API (可用) # OpenAI对话模型API (可用)
class ChatGPTBot(Bot,OpenAIImage): class ChatGPTBot(Bot, OpenAIImage):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
# set the default api_key # set the default api_key
openai.api_key = conf().get('open_ai_api_key') openai.api_key = conf().get("open_ai_api_key")
if conf().get('open_ai_api_base'): if conf().get("open_ai_api_base"):
openai.api_base = conf().get('open_ai_api_base') openai.api_base = conf().get("open_ai_api_base")
proxy = conf().get('proxy') proxy = conf().get("proxy")
if proxy: if proxy:
openai.proxy = proxy openai.proxy = proxy
if conf().get('rate_limit_chatgpt'): if conf().get("rate_limit_chatgpt"):
self.tb4chatgpt = TokenBucket(conf().get('rate_limit_chatgpt', 20)) self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))
self.sessions = SessionManager(ChatGPTSession, model= conf().get("model") or "gpt-3.5-turbo") self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
self.args ={ self.args = {
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称 "model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
"temperature":conf().get('temperature', 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 "temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
# "max_tokens":4096, # 回复最大的字符数 # "max_tokens":4096, # 回复最大的字符数
"top_p":1, "top_p": 1,
"frequency_penalty":conf().get('frequency_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"request_timeout": conf().get('request_timeout', None), # 请求超时时间openai接口默认设置为600对于难问题一般需要较长时间 "request_timeout": conf().get("request_timeout", None), # 请求超时时间openai接口默认设置为600对于难问题一般需要较长时间
"timeout": conf().get('request_timeout', None), #重试超时时间,在这个时间内,将会自动重试 "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
} }
def reply(self, query, context=None): def reply(self, query, context=None):
@@ -44,39 +47,45 @@ class ChatGPTBot(Bot,OpenAIImage):
if context.type == ContextType.TEXT: if context.type == ContextType.TEXT:
logger.info("[CHATGPT] query={}".format(query)) logger.info("[CHATGPT] query={}".format(query))
session_id = context["session_id"]
session_id = context['session_id']
reply = None reply = None
clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆']) clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
if query in clear_memory_commands: if query in clear_memory_commands:
self.sessions.clear_session(session_id) self.sessions.clear_session(session_id)
reply = Reply(ReplyType.INFO, '记忆已清除') reply = Reply(ReplyType.INFO, "记忆已清除")
elif query == '#清除所有': elif query == "#清除所有":
self.sessions.clear_all_session() self.sessions.clear_all_session()
reply = Reply(ReplyType.INFO, '所有人记忆已清除') reply = Reply(ReplyType.INFO, "所有人记忆已清除")
elif query == '#更新配置': elif query == "#更新配置":
load_config() load_config()
reply = Reply(ReplyType.INFO, '配置已更新') reply = Reply(ReplyType.INFO, "配置已更新")
if reply: if reply:
return reply return reply
session = self.sessions.session_query(query, session_id) session = self.sessions.session_query(query, session_id)
logger.debug("[CHATGPT] session query={}".format(session.messages)) logger.debug("[CHATGPT] session query={}".format(session.messages))
api_key = context.get('openai_api_key') api_key = context.get("openai_api_key")
# if context.get('stream'): # if context.get('stream'):
# # reply in stream # # reply in stream
# return self.reply_text_stream(query, new_query, session_id) # return self.reply_text_stream(query, new_query, session_id)
reply_content = self.reply_text(session, api_key) reply_content = self.reply_text(session, api_key)
logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"])) logger.debug(
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0: "[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
reply = Reply(ReplyType.ERROR, reply_content['content']) session.messages,
session_id,
reply_content["content"],
reply_content["completion_tokens"],
)
)
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
reply = Reply(ReplyType.ERROR, reply_content["content"])
elif reply_content["completion_tokens"] > 0: elif reply_content["completion_tokens"] > 0:
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"]) self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
reply = Reply(ReplyType.TEXT, reply_content["content"]) reply = Reply(ReplyType.TEXT, reply_content["content"])
else: else:
reply = Reply(ReplyType.ERROR, reply_content['content']) reply = Reply(ReplyType.ERROR, reply_content["content"])
logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content)) logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content))
return reply return reply
@@ -89,53 +98,53 @@ class ChatGPTBot(Bot,OpenAIImage):
reply = Reply(ReplyType.ERROR, retstring) reply = Reply(ReplyType.ERROR, retstring)
return reply return reply
else: else:
reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type)) reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
return reply return reply
def reply_text(self, session:ChatGPTSession, api_key=None, retry_count=0) -> dict: def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> dict:
''' """
call openai's ChatCompletion to get the answer call openai's ChatCompletion to get the answer
:param session: a conversation session :param session: a conversation session
:param session_id: session id :param session_id: session id
:param retry_count: retry count :param retry_count: retry count
:return: {} :return: {}
''' """
try: try:
if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token(): if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded") raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
# if api_key == None, the default openai.api_key will be used # if api_key == None, the default openai.api_key will be used
response = openai.ChatCompletion.create( response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **self.args)
api_key=api_key, messages=session.messages, **self.args
)
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
return {"total_tokens": response["usage"]["total_tokens"], return {
"completion_tokens": response["usage"]["completion_tokens"], "total_tokens": response["usage"]["total_tokens"],
"content": response.choices[0]['message']['content']} "completion_tokens": response["usage"]["completion_tokens"],
"content": response.choices[0]["message"]["content"],
}
except Exception as e: except Exception as e:
need_retry = retry_count < 2 need_retry = retry_count < 2
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
if isinstance(e, openai.error.RateLimitError): if isinstance(e, openai.error.RateLimitError):
logger.warn("[CHATGPT] RateLimitError: {}".format(e)) logger.warn("[CHATGPT] RateLimitError: {}".format(e))
result['content'] = "提问太快啦,请休息一下再问我吧" result["content"] = "提问太快啦,请休息一下再问我吧"
if need_retry: if need_retry:
time.sleep(5) time.sleep(20)
elif isinstance(e, openai.error.Timeout): elif isinstance(e, openai.error.Timeout):
logger.warn("[CHATGPT] Timeout: {}".format(e)) logger.warn("[CHATGPT] Timeout: {}".format(e))
result['content'] = "我没有收到你的消息" result["content"] = "我没有收到你的消息"
if need_retry: if need_retry:
time.sleep(5) time.sleep(5)
elif isinstance(e, openai.error.APIConnectionError): elif isinstance(e, openai.error.APIConnectionError):
logger.warn("[CHATGPT] APIConnectionError: {}".format(e)) logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
need_retry = False need_retry = False
result['content'] = "我连接不到你的网络" result["content"] = "我连接不到你的网络"
else: else:
logger.warn("[CHATGPT] Exception: {}".format(e)) logger.warn("[CHATGPT] Exception: {}".format(e))
need_retry = False need_retry = False
self.sessions.clear_session(session.session_id) self.sessions.clear_session(session.session_id)
if need_retry: if need_retry:
logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1)) logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1))
return self.reply_text(session, api_key, retry_count+1) return self.reply_text(session, api_key, retry_count + 1)
else: else:
return result return result
@@ -145,4 +154,4 @@ class AzureChatGPTBot(ChatGPTBot):
super().__init__() super().__init__()
openai.api_type = "azure" openai.api_type = "azure"
openai.api_version = "2023-03-15-preview" openai.api_version = "2023-03-15-preview"
self.args["deployment_id"] = conf().get("azure_deployment_id") self.args["deployment_id"] = conf().get("azure_deployment_id")

View File

@@ -1,20 +1,23 @@
from bot.session_manager import Session from bot.session_manager import Session
from common.log import logger from common.log import logger
'''
"""
e.g. [ e.g. [
{"role": "system", "content": "You are a helpful assistant."}, {"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Who won the world series in 2020?"}, {"role": "user", "content": "Who won the world series in 2020?"},
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."}, {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
{"role": "user", "content": "Where was it played?"} {"role": "user", "content": "Where was it played?"}
] ]
''' """
class ChatGPTSession(Session): class ChatGPTSession(Session):
def __init__(self, session_id, system_prompt=None, model= "gpt-3.5-turbo"): def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
super().__init__(session_id, system_prompt) super().__init__(session_id, system_prompt)
self.model = model self.model = model
self.reset() self.reset()
def discard_exceeding(self, max_tokens, cur_tokens= None): def discard_exceeding(self, max_tokens, cur_tokens=None):
precise = True precise = True
try: try:
cur_tokens = self.calc_tokens() cur_tokens = self.calc_tokens()
@@ -44,15 +47,16 @@ class ChatGPTSession(Session):
else: else:
cur_tokens = cur_tokens - max_tokens cur_tokens = cur_tokens - max_tokens
return cur_tokens return cur_tokens
def calc_tokens(self): def calc_tokens(self):
return num_tokens_from_messages(self.messages, self.model) return num_tokens_from_messages(self.messages, self.model)
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def num_tokens_from_messages(messages, model): def num_tokens_from_messages(messages, model):
"""Returns the number of tokens used by a list of messages.""" """Returns the number of tokens used by a list of messages."""
import tiktoken import tiktoken
try: try:
encoding = tiktoken.encoding_for_model(model) encoding = tiktoken.encoding_for_model(model)
except KeyError: except KeyError:

View File

@@ -1,41 +1,44 @@
# encoding:utf-8 # encoding:utf-8
import time
import openai
import openai.error
from bot.bot import Bot from bot.bot import Bot
from bot.openai.open_ai_image import OpenAIImage from bot.openai.open_ai_image import OpenAIImage
from bot.openai.open_ai_session import OpenAISession from bot.openai.open_ai_session import OpenAISession
from bot.session_manager import SessionManager from bot.session_manager import SessionManager
from bridge.context import ContextType from bridge.context import ContextType
from bridge.reply import Reply, ReplyType from bridge.reply import Reply, ReplyType
from config import conf
from common.log import logger from common.log import logger
import openai from config import conf
import openai.error
import time
user_session = dict() user_session = dict()
# OpenAI对话模型API (可用) # OpenAI对话模型API (可用)
class OpenAIBot(Bot, OpenAIImage): class OpenAIBot(Bot, OpenAIImage):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
openai.api_key = conf().get('open_ai_api_key') openai.api_key = conf().get("open_ai_api_key")
if conf().get('open_ai_api_base'): if conf().get("open_ai_api_base"):
openai.api_base = conf().get('open_ai_api_base') openai.api_base = conf().get("open_ai_api_base")
proxy = conf().get('proxy') proxy = conf().get("proxy")
if proxy: if proxy:
openai.proxy = proxy openai.proxy = proxy
self.sessions = SessionManager(OpenAISession, model= conf().get("model") or "text-davinci-003") self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003")
self.args = { self.args = {
"model": conf().get("model") or "text-davinci-003", # 对话模型的名称 "model": conf().get("model") or "text-davinci-003", # 对话模型的名称
"temperature":conf().get('temperature', 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 "temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
"max_tokens":1200, # 回复最大的字符数 "max_tokens": 1200, # 回复最大的字符数
"top_p":1, "top_p": 1,
"frequency_penalty":conf().get('frequency_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"request_timeout": conf().get('request_timeout', None), # 请求超时时间openai接口默认设置为600对于难问题一般需要较长时间 "request_timeout": conf().get("request_timeout", None), # 请求超时时间openai接口默认设置为600对于难问题一般需要较长时间
"timeout": conf().get('request_timeout', None), #重试超时时间,在这个时间内,将会自动重试 "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
"stop":["\n\n\n"] "stop": ["\n\n\n"],
} }
def reply(self, query, context=None): def reply(self, query, context=None):
@@ -43,21 +46,27 @@ class OpenAIBot(Bot, OpenAIImage):
if context and context.type: if context and context.type:
if context.type == ContextType.TEXT: if context.type == ContextType.TEXT:
logger.info("[OPEN_AI] query={}".format(query)) logger.info("[OPEN_AI] query={}".format(query))
session_id = context['session_id'] session_id = context["session_id"]
reply = None reply = None
if query == '#清除记忆': if query == "#清除记忆":
self.sessions.clear_session(session_id) self.sessions.clear_session(session_id)
reply = Reply(ReplyType.INFO, '记忆已清除') reply = Reply(ReplyType.INFO, "记忆已清除")
elif query == '#清除所有': elif query == "#清除所有":
self.sessions.clear_all_session() self.sessions.clear_all_session()
reply = Reply(ReplyType.INFO, '所有人记忆已清除') reply = Reply(ReplyType.INFO, "所有人记忆已清除")
else: else:
session = self.sessions.session_query(query, session_id) session = self.sessions.session_query(query, session_id)
result = self.reply_text(session) result = self.reply_text(session)
total_tokens, completion_tokens, reply_content = result['total_tokens'], result['completion_tokens'], result['content'] total_tokens, completion_tokens, reply_content = (
logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)) result["total_tokens"],
result["completion_tokens"],
result["content"],
)
logger.debug(
"[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
)
if total_tokens == 0 : if total_tokens == 0:
reply = Reply(ReplyType.ERROR, reply_content) reply = Reply(ReplyType.ERROR, reply_content)
else: else:
self.sessions.session_reply(reply_content, session_id, total_tokens) self.sessions.session_reply(reply_content, session_id, total_tokens)
@@ -72,42 +81,42 @@ class OpenAIBot(Bot, OpenAIImage):
reply = Reply(ReplyType.ERROR, retstring) reply = Reply(ReplyType.ERROR, retstring)
return reply return reply
def reply_text(self, session:OpenAISession, retry_count=0): def reply_text(self, session: OpenAISession, retry_count=0):
try: try:
response = openai.Completion.create( response = openai.Completion.create(prompt=str(session), **self.args)
prompt=str(session), **self.args res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "")
)
res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
total_tokens = response["usage"]["total_tokens"] total_tokens = response["usage"]["total_tokens"]
completion_tokens = response["usage"]["completion_tokens"] completion_tokens = response["usage"]["completion_tokens"]
logger.info("[OPEN_AI] reply={}".format(res_content)) logger.info("[OPEN_AI] reply={}".format(res_content))
return {"total_tokens": total_tokens, return {
"completion_tokens": completion_tokens, "total_tokens": total_tokens,
"content": res_content} "completion_tokens": completion_tokens,
"content": res_content,
}
except Exception as e: except Exception as e:
need_retry = retry_count < 2 need_retry = retry_count < 2
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
if isinstance(e, openai.error.RateLimitError): if isinstance(e, openai.error.RateLimitError):
logger.warn("[OPEN_AI] RateLimitError: {}".format(e)) logger.warn("[OPEN_AI] RateLimitError: {}".format(e))
result['content'] = "提问太快啦,请休息一下再问我吧" result["content"] = "提问太快啦,请休息一下再问我吧"
if need_retry: if need_retry:
time.sleep(5) time.sleep(20)
elif isinstance(e, openai.error.Timeout): elif isinstance(e, openai.error.Timeout):
logger.warn("[OPEN_AI] Timeout: {}".format(e)) logger.warn("[OPEN_AI] Timeout: {}".format(e))
result['content'] = "我没有收到你的消息" result["content"] = "我没有收到你的消息"
if need_retry: if need_retry:
time.sleep(5) time.sleep(5)
elif isinstance(e, openai.error.APIConnectionError): elif isinstance(e, openai.error.APIConnectionError):
logger.warn("[OPEN_AI] APIConnectionError: {}".format(e)) logger.warn("[OPEN_AI] APIConnectionError: {}".format(e))
need_retry = False need_retry = False
result['content'] = "我连接不到你的网络" result["content"] = "我连接不到你的网络"
else: else:
logger.warn("[OPEN_AI] Exception: {}".format(e)) logger.warn("[OPEN_AI] Exception: {}".format(e))
need_retry = False need_retry = False
self.sessions.clear_session(session.session_id) self.sessions.clear_session(session.session_id)
if need_retry: if need_retry:
logger.warn("[OPEN_AI] 第{}次重试".format(retry_count+1)) logger.warn("[OPEN_AI] 第{}次重试".format(retry_count + 1))
return self.reply_text(session, retry_count+1) return self.reply_text(session, retry_count + 1)
else: else:
return result return result

View File

@@ -1,38 +1,41 @@
import time import time
import openai import openai
import openai.error import openai.error
from common.token_bucket import TokenBucket
from common.log import logger from common.log import logger
from common.token_bucket import TokenBucket
from config import conf from config import conf
# OPENAI提供的画图接口 # OPENAI提供的画图接口
class OpenAIImage(object): class OpenAIImage(object):
def __init__(self): def __init__(self):
openai.api_key = conf().get('open_ai_api_key') openai.api_key = conf().get("open_ai_api_key")
if conf().get('rate_limit_dalle'): if conf().get("rate_limit_dalle"):
self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50)) self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50))
def create_img(self, query, retry_count=0): def create_img(self, query, retry_count=0):
try: try:
if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token(): if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token():
return False, "请求太快了,请休息一下再问我吧" return False, "请求太快了,请休息一下再问我吧"
logger.info("[OPEN_AI] image_query={}".format(query)) logger.info("[OPEN_AI] image_query={}".format(query))
response = openai.Image.create( response = openai.Image.create(
prompt=query, #图片描述 prompt=query, # 图片描述
n=1, #每次生成图片的数量 n=1, # 每次生成图片的数量
size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024 size=conf().get("image_create_size", "256x256"), # 图片大小,可选有 256x256, 512x512, 1024x1024
) )
image_url = response['data'][0]['url'] image_url = response["data"][0]["url"]
logger.info("[OPEN_AI] image_url={}".format(image_url)) logger.info("[OPEN_AI] image_url={}".format(image_url))
return True, image_url return True, image_url
except openai.error.RateLimitError as e: except openai.error.RateLimitError as e:
logger.warn(e) logger.warn(e)
if retry_count < 1: if retry_count < 1:
time.sleep(5) time.sleep(5)
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1))
return self.create_img(query, retry_count+1) return self.create_img(query, retry_count + 1)
else: else:
return False, "提问太快啦,请休息一下再问我吧" return False, "提问太快啦,请休息一下再问我吧"
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
return False, str(e) return False, str(e)

View File

@@ -1,32 +1,34 @@
from bot.session_manager import Session from bot.session_manager import Session
from common.log import logger from common.log import logger
class OpenAISession(Session): class OpenAISession(Session):
def __init__(self, session_id, system_prompt=None, model= "text-davinci-003"): def __init__(self, session_id, system_prompt=None, model="text-davinci-003"):
super().__init__(session_id, system_prompt) super().__init__(session_id, system_prompt)
self.model = model self.model = model
self.reset() self.reset()
def __str__(self): def __str__(self):
# 构造对话模型的输入 # 构造对话模型的输入
''' """
e.g. Q: xxx e.g. Q: xxx
A: xxx A: xxx
Q: xxx Q: xxx
''' """
prompt = "" prompt = ""
for item in self.messages: for item in self.messages:
if item['role'] == 'system': if item["role"] == "system":
prompt += item['content'] + "<|endoftext|>\n\n\n" prompt += item["content"] + "<|endoftext|>\n\n\n"
elif item['role'] == 'user': elif item["role"] == "user":
prompt += "Q: " + item['content'] + "\n" prompt += "Q: " + item["content"] + "\n"
elif item['role'] == 'assistant': elif item["role"] == "assistant":
prompt += "\n\nA: " + item['content'] + "<|endoftext|>\n" prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n"
if len(self.messages) > 0 and self.messages[-1]['role'] == 'user': if len(self.messages) > 0 and self.messages[-1]["role"] == "user":
prompt += "A: " prompt += "A: "
return prompt return prompt
def discard_exceeding(self, max_tokens, cur_tokens= None): def discard_exceeding(self, max_tokens, cur_tokens=None):
precise = True precise = True
try: try:
cur_tokens = self.calc_tokens() cur_tokens = self.calc_tokens()
@@ -56,14 +58,16 @@ class OpenAISession(Session):
else: else:
cur_tokens = len(str(self)) cur_tokens = len(str(self))
return cur_tokens return cur_tokens
def calc_tokens(self): def calc_tokens(self):
return num_tokens_from_string(str(self), self.model) return num_tokens_from_string(str(self), self.model)
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def num_tokens_from_string(string: str, model: str) -> int: def num_tokens_from_string(string: str, model: str) -> int:
"""Returns the number of tokens in a text string.""" """Returns the number of tokens in a text string."""
import tiktoken import tiktoken
encoding = tiktoken.encoding_for_model(model) encoding = tiktoken.encoding_for_model(model)
num_tokens = len(encoding.encode(string,disallowed_special=())) num_tokens = len(encoding.encode(string, disallowed_special=()))
return num_tokens return num_tokens

View File

@@ -2,6 +2,7 @@ from common.expired_dict import ExpiredDict
from common.log import logger from common.log import logger
from config import conf from config import conf
class Session(object): class Session(object):
def __init__(self, session_id, system_prompt=None): def __init__(self, session_id, system_prompt=None):
self.session_id = session_id self.session_id = session_id
@@ -13,7 +14,7 @@ class Session(object):
# 重置会话 # 重置会话
def reset(self): def reset(self):
system_item = {'role': 'system', 'content': self.system_prompt} system_item = {"role": "system", "content": self.system_prompt}
self.messages = [system_item] self.messages = [system_item]
def set_system_prompt(self, system_prompt): def set_system_prompt(self, system_prompt):
@@ -21,13 +22,13 @@ class Session(object):
self.reset() self.reset()
def add_query(self, query): def add_query(self, query):
user_item = {'role': 'user', 'content': query} user_item = {"role": "user", "content": query}
self.messages.append(user_item) self.messages.append(user_item)
def add_reply(self, reply): def add_reply(self, reply):
assistant_item = {'role': 'assistant', 'content': reply} assistant_item = {"role": "assistant", "content": reply}
self.messages.append(assistant_item) self.messages.append(assistant_item)
def discard_exceeding(self, max_tokens=None, cur_tokens=None): def discard_exceeding(self, max_tokens=None, cur_tokens=None):
raise NotImplementedError raise NotImplementedError
@@ -37,8 +38,8 @@ class Session(object):
class SessionManager(object): class SessionManager(object):
def __init__(self, sessioncls, **session_args): def __init__(self, sessioncls, **session_args):
if conf().get('expires_in_seconds'): if conf().get("expires_in_seconds"):
sessions = ExpiredDict(conf().get('expires_in_seconds')) sessions = ExpiredDict(conf().get("expires_in_seconds"))
else: else:
sessions = dict() sessions = dict()
self.sessions = sessions self.sessions = sessions
@@ -46,17 +47,20 @@ class SessionManager(object):
self.session_args = session_args self.session_args = session_args
def build_session(self, session_id, system_prompt=None): def build_session(self, session_id, system_prompt=None):
''' """
如果session_id不在sessions中创建一个新的session并添加到sessions中 如果session_id不在sessions中创建一个新的session并添加到sessions中
如果system_prompt不会空会更新session的system_prompt并重置session 如果system_prompt不会空会更新session的system_prompt并重置session
''' """
if session_id is None:
return self.sessioncls(session_id, system_prompt, **self.session_args)
if session_id not in self.sessions: if session_id not in self.sessions:
self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args) self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args)
elif system_prompt is not None: # 如果有新的system_prompt更新并重置session elif system_prompt is not None: # 如果有新的system_prompt更新并重置session
self.sessions[session_id].set_system_prompt(system_prompt) self.sessions[session_id].set_system_prompt(system_prompt)
session = self.sessions[session_id] session = self.sessions[session_id]
return session return session
def session_query(self, query, session_id): def session_query(self, query, session_id):
session = self.build_session(session_id) session = self.build_session(session_id)
session.add_query(query) session.add_query(query)
@@ -68,7 +72,7 @@ class SessionManager(object):
logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e))) logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
return session return session
def session_reply(self, reply, session_id, total_tokens = None): def session_reply(self, reply, session_id, total_tokens=None):
session = self.build_session(session_id) session = self.build_session(session_id)
session.add_reply(reply) session.add_reply(reply)
try: try:
@@ -81,7 +85,7 @@ class SessionManager(object):
def clear_session(self, session_id): def clear_session(self, session_id):
if session_id in self.sessions: if session_id in self.sessions:
del(self.sessions[session_id]) del self.sessions[session_id]
def clear_all_session(self): def clear_all_session(self):
self.sessions.clear() self.sessions.clear()

View File

@@ -1,31 +1,31 @@
from bot import bot_factory
from bridge.context import Context from bridge.context import Context
from bridge.reply import Reply from bridge.reply import Reply
from common.log import logger
from bot import bot_factory
from common.singleton import singleton
from voice import voice_factory
from config import conf
from common import const from common import const
from common.log import logger
from common.singleton import singleton
from config import conf
from voice import voice_factory
@singleton @singleton
class Bridge(object): class Bridge(object):
def __init__(self): def __init__(self):
self.btype={ self.btype = {
"chat": const.CHATGPT, "chat": const.CHATGPT,
"voice_to_text": conf().get("voice_to_text", "openai"), "voice_to_text": conf().get("voice_to_text", "openai"),
"text_to_voice": conf().get("text_to_voice", "google") "text_to_voice": conf().get("text_to_voice", "google"),
} }
model_type = conf().get("model") model_type = conf().get("model")
if model_type in ["text-davinci-003"]: if model_type in ["text-davinci-003"]:
self.btype['chat'] = const.OPEN_AI self.btype["chat"] = const.OPEN_AI
if conf().get("use_azure_chatgpt", False): if conf().get("use_azure_chatgpt", False):
self.btype['chat'] = const.CHATGPTONAZURE self.btype["chat"] = const.CHATGPTONAZURE
self.bots={} self.bots = {}
def get_bot(self,typename): def get_bot(self, typename):
if self.bots.get(typename) is None: if self.bots.get(typename) is None:
logger.info("create bot {} for {}".format(self.btype[typename],typename)) logger.info("create bot {} for {}".format(self.btype[typename], typename))
if typename == "text_to_voice": if typename == "text_to_voice":
self.bots[typename] = voice_factory.create_voice(self.btype[typename]) self.bots[typename] = voice_factory.create_voice(self.btype[typename])
elif typename == "voice_to_text": elif typename == "voice_to_text":
@@ -33,18 +33,15 @@ class Bridge(object):
elif typename == "chat": elif typename == "chat":
self.bots[typename] = bot_factory.create_bot(self.btype[typename]) self.bots[typename] = bot_factory.create_bot(self.btype[typename])
return self.bots[typename] return self.bots[typename]
def get_bot_type(self,typename): def get_bot_type(self, typename):
return self.btype[typename] return self.btype[typename]
def fetch_reply_content(self, query, context: Context) -> Reply:
def fetch_reply_content(self, query, context : Context) -> Reply:
return self.get_bot("chat").reply(query, context) return self.get_bot("chat").reply(query, context)
def fetch_voice_to_text(self, voiceFile) -> Reply: def fetch_voice_to_text(self, voiceFile) -> Reply:
return self.get_bot("voice_to_text").voiceToText(voiceFile) return self.get_bot("voice_to_text").voiceToText(voiceFile)
def fetch_text_to_voice(self, text) -> Reply: def fetch_text_to_voice(self, text) -> Reply:
return self.get_bot("text_to_voice").textToVoice(text) return self.get_bot("text_to_voice").textToVoice(text)

View File

@@ -2,36 +2,41 @@
from enum import Enum from enum import Enum
class ContextType (Enum):
TEXT = 1 # 文本消息 class ContextType(Enum):
VOICE = 2 # 音频消息 TEXT = 1 # 文本消息
IMAGE = 3 # 图片消息 VOICE = 2 # 音频消息
IMAGE_CREATE = 10 # 创建图片命令 IMAGE = 3 # 图片消息
IMAGE_CREATE = 10 # 创建图片命令
JOIN_GROUP = 20 # 加入群聊
PATPAT = 21 # 拍了拍
def __str__(self): def __str__(self):
return self.name return self.name
class Context: class Context:
def __init__(self, type : ContextType = None , content = None, kwargs = dict()): def __init__(self, type: ContextType = None, content=None, kwargs=dict()):
self.type = type self.type = type
self.content = content self.content = content
self.kwargs = kwargs self.kwargs = kwargs
def __contains__(self, key): def __contains__(self, key):
if key == 'type': if key == "type":
return self.type is not None return self.type is not None
elif key == 'content': elif key == "content":
return self.content is not None return self.content is not None
else: else:
return key in self.kwargs return key in self.kwargs
def __getitem__(self, key): def __getitem__(self, key):
if key == 'type': if key == "type":
return self.type return self.type
elif key == 'content': elif key == "content":
return self.content return self.content
else: else:
return self.kwargs[key] return self.kwargs[key]
def get(self, key, default=None): def get(self, key, default=None):
try: try:
return self[key] return self[key]
@@ -39,20 +44,20 @@ class Context:
return default return default
def __setitem__(self, key, value): def __setitem__(self, key, value):
if key == 'type': if key == "type":
self.type = value self.type = value
elif key == 'content': elif key == "content":
self.content = value self.content = value
else: else:
self.kwargs[key] = value self.kwargs[key] = value
def __delitem__(self, key): def __delitem__(self, key):
if key == 'type': if key == "type":
self.type = None self.type = None
elif key == 'content': elif key == "content":
self.content = None self.content = None
else: else:
del self.kwargs[key] del self.kwargs[key]
def __str__(self): def __str__(self):
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs) return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)

View File

@@ -1,22 +1,25 @@
# encoding:utf-8 # encoding:utf-8
from enum import Enum from enum import Enum
class ReplyType(Enum): class ReplyType(Enum):
TEXT = 1 # 文本 TEXT = 1 # 文本
VOICE = 2 # 音频文件 VOICE = 2 # 音频文件
IMAGE = 3 # 图片文件 IMAGE = 3 # 图片文件
IMAGE_URL = 4 # 图片URL IMAGE_URL = 4 # 图片URL
INFO = 9 INFO = 9
ERROR = 10 ERROR = 10
def __str__(self): def __str__(self):
return self.name return self.name
class Reply: class Reply:
def __init__(self, type : ReplyType = None , content = None): def __init__(self, type: ReplyType = None, content=None):
self.type = type self.type = type
self.content = content self.content = content
def __str__(self): def __str__(self):
return "Reply(type={}, content={})".format(self.type, self.content) return "Reply(type={}, content={})".format(self.type, self.content)

View File

@@ -6,8 +6,10 @@ from bridge.bridge import Bridge
from bridge.context import Context from bridge.context import Context
from bridge.reply import * from bridge.reply import *
class Channel(object): class Channel(object):
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE] NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE]
def startup(self): def startup(self):
""" """
init channel init channel
@@ -27,15 +29,15 @@ class Channel(object):
send message to user send message to user
:param msg: message content :param msg: message content
:param receiver: receiver channel account :param receiver: receiver channel account
:return: :return:
""" """
raise NotImplementedError raise NotImplementedError
def build_reply_content(self, query, context : Context=None) -> Reply: def build_reply_content(self, query, context: Context = None) -> Reply:
return Bridge().fetch_reply_content(query, context) return Bridge().fetch_reply_content(query, context)
def build_voice_to_text(self, voice_file) -> Reply: def build_voice_to_text(self, voice_file) -> Reply:
return Bridge().fetch_voice_to_text(voice_file) return Bridge().fetch_voice_to_text(voice_file)
def build_text_to_voice(self, text) -> Reply: def build_text_to_voice(self, text) -> Reply:
return Bridge().fetch_text_to_voice(text) return Bridge().fetch_text_to_voice(text)

View File

@@ -2,25 +2,31 @@
channel factory channel factory
""" """
def create_channel(channel_type): def create_channel(channel_type):
""" """
create a channel instance create a channel instance
:param channel_type: channel type code :param channel_type: channel type code
:return: channel instance :return: channel instance
""" """
if channel_type == 'wx': if channel_type == "wx":
from channel.wechat.wechat_channel import WechatChannel from channel.wechat.wechat_channel import WechatChannel
return WechatChannel() return WechatChannel()
elif channel_type == 'wxy': elif channel_type == "wxy":
from channel.wechat.wechaty_channel import WechatyChannel from channel.wechat.wechaty_channel import WechatyChannel
return WechatyChannel() return WechatyChannel()
elif channel_type == 'terminal': elif channel_type == "terminal":
from channel.terminal.terminal_channel import TerminalChannel from channel.terminal.terminal_channel import TerminalChannel
return TerminalChannel() return TerminalChannel()
elif channel_type == 'wechatmp': elif channel_type == "wechatmp":
from channel.wechatmp.wechatmp_channel import WechatMPChannel from channel.wechatmp.wechatmp_channel import WechatMPChannel
return WechatMPChannel(passive_reply = True)
elif channel_type == 'wechatmp_service': return WechatMPChannel(passive_reply=True)
elif channel_type == "wechatmp_service":
from channel.wechatmp.wechatmp_channel import WechatMPChannel from channel.wechatmp.wechatmp_channel import WechatMPChannel
return WechatMPChannel(passive_reply = False)
return WechatMPChannel(passive_reply=False)
raise RuntimeError raise RuntimeError

View File

@@ -1,137 +1,148 @@
from asyncio import CancelledError
from concurrent.futures import Future, ThreadPoolExecutor
import os import os
import re import re
import threading import threading
import time import time
from common.dequeue import Dequeue from asyncio import CancelledError
from channel.channel import Channel from concurrent.futures import Future, ThreadPoolExecutor
from bridge.reply import *
from bridge.context import * from bridge.context import *
from config import conf from bridge.reply import *
from channel.channel import Channel
from common.dequeue import Dequeue
from common.log import logger from common.log import logger
from config import conf
from plugins import * from plugins import *
try: try:
from voice.audio_convert import any_to_wav from voice.audio_convert import any_to_wav
except Exception as e: except Exception as e:
pass pass
# 抽象类, 它包含了与消息通道无关的通用处理逻辑 # 抽象类, 它包含了与消息通道无关的通用处理逻辑
class ChatChannel(Channel): class ChatChannel(Channel):
name = None # 登录的用户名 name = None # 登录的用户名
user_id = None # 登录的用户id user_id = None # 登录的用户id
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉正在执行的不会被取消 futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉正在执行的不会被取消
sessions = {} # 用于控制并发每个session_id同时只能有一个context在处理 sessions = {} # 用于控制并发每个session_id同时只能有一个context在处理
lock = threading.Lock() # 用于控制对sessions的访问 lock = threading.Lock() # 用于控制对sessions的访问
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池 handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
def __init__(self): def __init__(self):
_thread = threading.Thread(target=self.consume) _thread = threading.Thread(target=self.consume)
_thread.setDaemon(True) _thread.setDaemon(True)
_thread.start() _thread.start()
# 根据消息构造context消息内容相关的触发项写在这里 # 根据消息构造context消息内容相关的触发项写在这里
def _compose_context(self, ctype: ContextType, content, **kwargs): def _compose_context(self, ctype: ContextType, content, **kwargs):
context = Context(ctype, content) context = Context(ctype, content)
context.kwargs = kwargs context.kwargs = kwargs
# context首次传入时origin_ctype是None, # context首次传入时origin_ctype是None,
# 引入的起因是当输入语音时会嵌套生成两个context第一步语音转文本第二步通过文本生成文字回复。 # 引入的起因是当输入语音时会嵌套生成两个context第一步语音转文本第二步通过文本生成文字回复。
# origin_ctype用于第二步文本回复时判断是否需要匹配前缀如果是私聊的语音就不需要匹配前缀 # origin_ctype用于第二步文本回复时判断是否需要匹配前缀如果是私聊的语音就不需要匹配前缀
if 'origin_ctype' not in context: if "origin_ctype" not in context:
context['origin_ctype'] = ctype context["origin_ctype"] = ctype
# context首次传入时receiver是None根据类型设置receiver # context首次传入时receiver是None根据类型设置receiver
first_in = 'receiver' not in context first_in = "receiver" not in context
# 群名匹配过程设置session_id和receiver # 群名匹配过程设置session_id和receiver
if first_in: # context首次传入时receiver是None根据类型设置receiver if first_in: # context首次传入时receiver是None根据类型设置receiver
config = conf() config = conf()
cmsg = context['msg'] cmsg = context["msg"]
if context.get("isgroup", False): if context.get("isgroup", False):
group_name = cmsg.other_user_nickname group_name = cmsg.other_user_nickname
group_id = cmsg.other_user_id group_id = cmsg.other_user_id
group_name_white_list = config.get('group_name_white_list', []) group_name_white_list = config.get("group_name_white_list", [])
group_name_keyword_white_list = config.get('group_name_keyword_white_list', []) group_name_keyword_white_list = config.get("group_name_keyword_white_list", [])
if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]): if any(
group_chat_in_one_session = conf().get('group_chat_in_one_session', []) [
group_name in group_name_white_list,
"ALL_GROUP" in group_name_white_list,
check_contain(group_name, group_name_keyword_white_list),
]
):
group_chat_in_one_session = conf().get("group_chat_in_one_session", [])
session_id = cmsg.actual_user_id session_id = cmsg.actual_user_id
if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]): if any(
[
group_name in group_chat_in_one_session,
"ALL_GROUP" in group_chat_in_one_session,
]
):
session_id = group_id session_id = group_id
else: else:
return None return None
context['session_id'] = session_id context["session_id"] = session_id
context['receiver'] = group_id context["receiver"] = group_id
else: else:
context['session_id'] = cmsg.other_user_id context["session_id"] = cmsg.other_user_id
context['receiver'] = cmsg.other_user_id context["receiver"] = cmsg.other_user_id
e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {'channel': self, 'context': context})) e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}))
context = e_context['context'] context = e_context["context"]
if e_context.is_pass() or context is None: if e_context.is_pass() or context is None:
return context return context
if cmsg.from_user_id == self.user_id and not config.get('trigger_by_self', True): if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True):
logger.debug("[WX]self message skipped") logger.debug("[WX]self message skipped")
return None return None
# 消息内容匹配过程并处理content # 消息内容匹配过程并处理content
if ctype == ContextType.TEXT: if ctype == ContextType.TEXT:
if first_in and "\n- - - - - - -" in content: # 初次匹配 过滤引用消息 if first_in and "\n- - - - - - -" in content: # 初次匹配 过滤引用消息
logger.debug("[WX]reference query skipped") logger.debug("[WX]reference query skipped")
return None return None
if context.get("isgroup", False): # 群聊 if context.get("isgroup", False): # 群聊
# 校验关键字 # 校验关键字
match_prefix = check_prefix(content, conf().get('group_chat_prefix')) match_prefix = check_prefix(content, conf().get("group_chat_prefix"))
match_contain = check_contain(content, conf().get('group_chat_keyword')) match_contain = check_contain(content, conf().get("group_chat_keyword"))
flag = False flag = False
if match_prefix is not None or match_contain is not None: if match_prefix is not None or match_contain is not None:
flag = True flag = True
if match_prefix: if match_prefix:
content = content.replace(match_prefix, '', 1).strip() content = content.replace(match_prefix, "", 1).strip()
if context['msg'].is_at: if context["msg"].is_at:
logger.info("[WX]receive group at") logger.info("[WX]receive group at")
if not conf().get("group_at_off", False): if not conf().get("group_at_off", False):
flag = True flag = True
pattern = f'@{self.name}(\u2005|\u0020)' pattern = f"@{re.escape(self.name)}(\u2005|\u0020)"
content = re.sub(pattern, r'', content) content = re.sub(pattern, r"", content)
if not flag: if not flag:
if context["origin_ctype"] == ContextType.VOICE: if context["origin_ctype"] == ContextType.VOICE:
logger.info("[WX]receive group voice, but checkprefix didn't match") logger.info("[WX]receive group voice, but checkprefix didn't match")
return None return None
else: # 单聊 else: # 单聊
match_prefix = check_prefix(content, conf().get('single_chat_prefix',[''])) match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""]))
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容 if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
content = content.replace(match_prefix, '', 1).strip() content = content.replace(match_prefix, "", 1).strip()
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件 elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
pass pass
else: else:
return None return None
img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
if img_match_prefix: if img_match_prefix:
content = content.replace(img_match_prefix, '', 1).strip() content = content.replace(img_match_prefix, "", 1)
context.type = ContextType.IMAGE_CREATE context.type = ContextType.IMAGE_CREATE
else: else:
context.type = ContextType.TEXT context.type = ContextType.TEXT
context.content = content context.content = content.strip()
if 'desire_rtype' not in context and conf().get('always_reply_voice') and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: if "desire_rtype" not in context and conf().get("always_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
context['desire_rtype'] = ReplyType.VOICE context["desire_rtype"] = ReplyType.VOICE
elif context.type == ContextType.VOICE: elif context.type == ContextType.VOICE:
if 'desire_rtype' not in context and conf().get('voice_reply_voice') and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
context['desire_rtype'] = ReplyType.VOICE context["desire_rtype"] = ReplyType.VOICE
return context return context
def _handle(self, context: Context): def _handle(self, context: Context):
if context is None or not context.content: if context is None or not context.content:
return return
logger.debug('[WX] ready to handle context: {}'.format(context)) logger.debug("[WX] ready to handle context: {}".format(context))
# reply的构建步骤 # reply的构建步骤
reply = self._generate_reply(context) reply = self._generate_reply(context)
logger.debug('[WX] ready to decorate reply: {}'.format(reply)) logger.debug("[WX] ready to decorate reply: {}".format(reply))
# reply的包装步骤 # reply的包装步骤
reply = self._decorate_reply(context, reply) reply = self._decorate_reply(context, reply)
@@ -139,20 +150,24 @@ class ChatChannel(Channel):
self._send_reply(context, reply) self._send_reply(context, reply)
def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply: def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, { e_context = PluginManager().emit_event(
'channel': self, 'context': context, 'reply': reply})) EventContext(
reply = e_context['reply'] Event.ON_HANDLE_CONTEXT,
{"channel": self, "context": context, "reply": reply},
)
)
reply = e_context["reply"]
if not e_context.is_pass(): if not e_context.is_pass():
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content)) logger.debug("[WX] ready to handle context: type={}, content={}".format(context.type, context.content))
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息 if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
reply = super().build_reply_content(context.content, context) reply = super().build_reply_content(context.content, context)
elif context.type == ContextType.VOICE: # 语音消息 elif context.type == ContextType.VOICE: # 语音消息
cmsg = context['msg'] cmsg = context["msg"]
cmsg.prepare() cmsg.prepare()
file_path = context.content file_path = context.content
wav_path = os.path.splitext(file_path)[0] + '.wav' wav_path = os.path.splitext(file_path)[0] + ".wav"
try: try:
any_to_wav(file_path, wav_path) any_to_wav(file_path, wav_path)
except Exception as e: # 转换失败直接使用mp3对于某些apimp3也可以识别 except Exception as e: # 转换失败直接使用mp3对于某些apimp3也可以识别
logger.warning("[WX]any to wav error, use raw path. " + str(e)) logger.warning("[WX]any to wav error, use raw path. " + str(e))
wav_path = file_path wav_path = file_path
@@ -168,8 +183,7 @@ class ChatChannel(Channel):
# logger.warning("[WX]delete temp file error: " + str(e)) # logger.warning("[WX]delete temp file error: " + str(e))
if reply.type == ReplyType.TEXT: if reply.type == ReplyType.TEXT:
new_context = self._compose_context( new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs)
ContextType.TEXT, reply.content, **context.kwargs)
if new_context: if new_context:
reply = self._generate_reply(new_context) reply = self._generate_reply(new_context)
else: else:
@@ -177,18 +191,21 @@ class ChatChannel(Channel):
elif context.type == ContextType.IMAGE: # 图片消息,当前无默认逻辑 elif context.type == ContextType.IMAGE: # 图片消息,当前无默认逻辑
pass pass
else: else:
logger.error('[WX] unknown context type: {}'.format(context.type)) logger.error("[WX] unknown context type: {}".format(context.type))
return return
return reply return reply
def _decorate_reply(self, context: Context, reply: Reply) -> Reply: def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
if reply and reply.type: if reply and reply.type:
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, { e_context = PluginManager().emit_event(
'channel': self, 'context': context, 'reply': reply})) EventContext(
reply = e_context['reply'] Event.ON_DECORATE_REPLY,
desire_rtype = context.get('desire_rtype') {"channel": self, "context": context, "reply": reply},
)
)
reply = e_context["reply"]
desire_rtype = context.get("desire_rtype")
if not e_context.is_pass() and reply and reply.type: if not e_context.is_pass() and reply and reply.type:
if reply.type in self.NOT_SUPPORT_REPLYTYPE: if reply.type in self.NOT_SUPPORT_REPLYTYPE:
logger.error("[WX]reply type not support: " + str(reply.type)) logger.error("[WX]reply type not support: " + str(reply.type))
reply.type = ReplyType.ERROR reply.type = ReplyType.ERROR
@@ -200,55 +217,59 @@ class ChatChannel(Channel):
reply = super().build_text_to_voice(reply.content) reply = super().build_text_to_voice(reply.content)
return self._decorate_reply(context, reply) return self._decorate_reply(context, reply)
if context.get("isgroup", False): if context.get("isgroup", False):
reply_text = '@' + context['msg'].actual_user_nickname + ' ' + reply_text.strip() reply_text = "@" + context["msg"].actual_user_nickname + " " + reply_text.strip()
reply_text = conf().get("group_chat_reply_prefix", "") + reply_text reply_text = conf().get("group_chat_reply_prefix", "") + reply_text
else: else:
reply_text = conf().get("single_chat_reply_prefix", "") + reply_text reply_text = conf().get("single_chat_reply_prefix", "") + reply_text
reply.content = reply_text reply.content = reply_text
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
reply.content = "["+str(reply.type)+"]\n" + reply.content reply.content = "[" + str(reply.type) + "]\n" + reply.content
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE: elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
pass pass
else: else:
logger.error('[WX] unknown reply type: {}'.format(reply.type)) logger.error("[WX] unknown reply type: {}".format(reply.type))
return return
if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]: if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
logger.warning('[WX] desire_rtype: {}, but reply type: {}'.format(context.get('desire_rtype'), reply.type)) logger.warning("[WX] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type))
return reply return reply
def _send_reply(self, context: Context, reply: Reply): def _send_reply(self, context: Context, reply: Reply):
if reply and reply.type: if reply and reply.type:
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, { e_context = PluginManager().emit_event(
'channel': self, 'context': context, 'reply': reply})) EventContext(
reply = e_context['reply'] Event.ON_SEND_REPLY,
{"channel": self, "context": context, "reply": reply},
)
)
reply = e_context["reply"]
if not e_context.is_pass() and reply and reply.type: if not e_context.is_pass() and reply and reply.type:
logger.debug('[WX] ready to send reply: {}, context: {}'.format(reply, context)) logger.debug("[WX] ready to send reply: {}, context: {}".format(reply, context))
self._send(reply, context) self._send(reply, context)
def _send(self, reply: Reply, context: Context, retry_cnt = 0): def _send(self, reply: Reply, context: Context, retry_cnt=0):
try: try:
self.send(reply, context) self.send(reply, context)
except Exception as e: except Exception as e:
logger.error('[WX] sendMsg error: {}'.format(str(e))) logger.error("[WX] sendMsg error: {}".format(str(e)))
if isinstance(e, NotImplementedError): if isinstance(e, NotImplementedError):
return return
logger.exception(e) logger.exception(e)
if retry_cnt < 2: if retry_cnt < 2:
time.sleep(3+3*retry_cnt) time.sleep(3 + 3 * retry_cnt)
self._send(reply, context, retry_cnt+1) self._send(reply, context, retry_cnt + 1)
def _success_callback(self, session_id, **kwargs):# 线程正常结束时的回调函数 def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数
logger.debug("Worker return success, session_id = {}".format(session_id)) logger.debug("Worker return success, session_id = {}".format(session_id))
def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数 def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数
logger.exception("Worker return exception: {}".format(exception)) logger.exception("Worker return exception: {}".format(exception))
def _thread_pool_callback(self, session_id, **kwargs): def _thread_pool_callback(self, session_id, **kwargs):
def func(worker:Future): def func(worker: Future):
try: try:
worker_exception = worker.exception() worker_exception = worker.exception()
if worker_exception: if worker_exception:
self._fail_callback(session_id, exception = worker_exception, **kwargs) self._fail_callback(session_id, exception=worker_exception, **kwargs)
else: else:
self._success_callback(session_id, **kwargs) self._success_callback(session_id, **kwargs)
except CancelledError as e: except CancelledError as e:
@@ -257,15 +278,19 @@ class ChatChannel(Channel):
logger.exception("Worker raise exception: {}".format(e)) logger.exception("Worker raise exception: {}".format(e))
with self.lock: with self.lock:
self.sessions[session_id][1].release() self.sessions[session_id][1].release()
return func return func
def produce(self, context: Context): def produce(self, context: Context):
session_id = context['session_id'] session_id = context["session_id"]
with self.lock: with self.lock:
if session_id not in self.sessions: if session_id not in self.sessions:
self.sessions[session_id] = [Dequeue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 4))] self.sessions[session_id] = [
if context.type == ContextType.TEXT and context.content.startswith("#"): Dequeue(),
self.sessions[session_id][0].putleft(context) # 优先处理管理命令 threading.BoundedSemaphore(conf().get("concurrency_in_session", 4)),
]
if context.type == ContextType.TEXT and context.content.startswith("#"):
self.sessions[session_id][0].putleft(context) # 优先处理管理命令
else: else:
self.sessions[session_id][0].put(context) self.sessions[session_id][0].put(context)
@@ -276,16 +301,16 @@ class ChatChannel(Channel):
session_ids = list(self.sessions.keys()) session_ids = list(self.sessions.keys())
for session_id in session_ids: for session_id in session_ids:
context_queue, semaphore = self.sessions[session_id] context_queue, semaphore = self.sessions[session_id]
if semaphore.acquire(blocking = False): # 等线程处理完毕才能删除 if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除
if not context_queue.empty(): if not context_queue.empty():
context = context_queue.get() context = context_queue.get()
logger.debug("[WX] consume context: {}".format(context)) logger.debug("[WX] consume context: {}".format(context))
future:Future = self.handler_pool.submit(self._handle, context) future: Future = self.handler_pool.submit(self._handle, context)
future.add_done_callback(self._thread_pool_callback(session_id, context = context)) future.add_done_callback(self._thread_pool_callback(session_id, context=context))
if session_id not in self.futures: if session_id not in self.futures:
self.futures[session_id] = [] self.futures[session_id] = []
self.futures[session_id].append(future) self.futures[session_id].append(future)
elif semaphore._initial_value == semaphore._value+1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕 elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()] self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
assert len(self.futures[session_id]) == 0, "thread pool error" assert len(self.futures[session_id]) == 0, "thread pool error"
del self.sessions[session_id] del self.sessions[session_id]
@@ -294,26 +319,26 @@ class ChatChannel(Channel):
time.sleep(0.1) time.sleep(0.1)
# 取消session_id对应的所有任务只能取消排队的消息和已提交线程池但未执行的任务 # 取消session_id对应的所有任务只能取消排队的消息和已提交线程池但未执行的任务
def cancel_session(self, session_id): def cancel_session(self, session_id):
with self.lock: with self.lock:
if session_id in self.sessions: if session_id in self.sessions:
for future in self.futures[session_id]: for future in self.futures[session_id]:
future.cancel() future.cancel()
cnt = self.sessions[session_id][0].qsize() cnt = self.sessions[session_id][0].qsize()
if cnt>0: if cnt > 0:
logger.info("Cancel {} messages in session {}".format(cnt, session_id)) logger.info("Cancel {} messages in session {}".format(cnt, session_id))
self.sessions[session_id][0] = Dequeue() self.sessions[session_id][0] = Dequeue()
def cancel_all_session(self): def cancel_all_session(self):
with self.lock: with self.lock:
for session_id in self.sessions: for session_id in self.sessions:
for future in self.futures[session_id]: for future in self.futures[session_id]:
future.cancel() future.cancel()
cnt = self.sessions[session_id][0].qsize() cnt = self.sessions[session_id][0].qsize()
if cnt>0: if cnt > 0:
logger.info("Cancel {} messages in session {}".format(cnt, session_id)) logger.info("Cancel {} messages in session {}".format(cnt, session_id))
self.sessions[session_id][0] = Dequeue() self.sessions[session_id][0] = Dequeue()
def check_prefix(content, prefix_list): def check_prefix(content, prefix_list):
if not prefix_list: if not prefix_list:
@@ -323,6 +348,7 @@ def check_prefix(content, prefix_list):
return prefix return prefix
return None return None
def check_contain(content, keyword_list): def check_contain(content, keyword_list):
if not keyword_list: if not keyword_list:
return None return None

View File

@@ -1,5 +1,4 @@
"""
"""
本类表示聊天消息用于对itchat和wechaty的消息进行统一的封装。 本类表示聊天消息用于对itchat和wechaty的消息进行统一的封装。
填好必填项(群聊6个非群聊8个)即可接入ChatChannel并支持插件参考TerminalChannel 填好必填项(群聊6个非群聊8个)即可接入ChatChannel并支持插件参考TerminalChannel
@@ -20,7 +19,7 @@ other_user_id: 对方的id如果你是发送者那这个就是接收者id
other_user_nickname: 同上 other_user_nickname: 同上
is_group: 是否是群消息 (群聊必填) is_group: 是否是群消息 (群聊必填)
is_at: 是否被at is_at: 是否被at
- (群消息时一般会存在实际发送者是群内某个成员的id和昵称下列项仅在群消息时存在) - (群消息时一般会存在实际发送者是群内某个成员的id和昵称下列项仅在群消息时存在)
actual_user_id: 实际发送者id (群聊必填) actual_user_id: 实际发送者id (群聊必填)
@@ -34,20 +33,22 @@ _prepared: 是否已经调用过准备函数
_rawmsg: 原始消息对象 _rawmsg: 原始消息对象
""" """
class ChatMessage(object): class ChatMessage(object):
msg_id = None msg_id = None
create_time = None create_time = None
ctype = None ctype = None
content = None content = None
from_user_id = None from_user_id = None
from_user_nickname = None from_user_nickname = None
to_user_id = None to_user_id = None
to_user_nickname = None to_user_nickname = None
other_user_id = None other_user_id = None
other_user_nickname = None other_user_nickname = None
is_group = False is_group = False
is_at = False is_at = False
actual_user_id = None actual_user_id = None
@@ -57,8 +58,7 @@ class ChatMessage(object):
_prepared = False _prepared = False
_rawmsg = None _rawmsg = None
def __init__(self, _rawmsg):
def __init__(self,_rawmsg):
self._rawmsg = _rawmsg self._rawmsg = _rawmsg
def prepare(self): def prepare(self):
@@ -67,7 +67,7 @@ class ChatMessage(object):
self._prepare_fn() self._prepare_fn()
def __str__(self): def __str__(self):
return 'ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}'.format( return "ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}".format(
self.msg_id, self.msg_id,
self.create_time, self.create_time,
self.ctype, self.ctype,
@@ -82,4 +82,4 @@ class ChatMessage(object):
self.is_at, self.is_at,
self.actual_user_id, self.actual_user_id,
self.actual_user_nickname, self.actual_user_nickname,
) )

View File

@@ -1,14 +1,23 @@
import sys
from bridge.context import * from bridge.context import *
from bridge.reply import Reply, ReplyType from bridge.reply import Reply, ReplyType
from channel.chat_channel import ChatChannel, check_prefix from channel.chat_channel import ChatChannel, check_prefix
from channel.chat_message import ChatMessage from channel.chat_message import ChatMessage
import sys
from config import conf
from common.log import logger from common.log import logger
from config import conf
class TerminalMessage(ChatMessage): class TerminalMessage(ChatMessage):
def __init__(self, msg_id, content, ctype = ContextType.TEXT, from_user_id = "User", to_user_id = "Chatgpt", other_user_id = "Chatgpt"): def __init__(
self,
msg_id,
content,
ctype=ContextType.TEXT,
from_user_id="User",
to_user_id="Chatgpt",
other_user_id="Chatgpt",
):
self.msg_id = msg_id self.msg_id = msg_id
self.ctype = ctype self.ctype = ctype
self.content = content self.content = content
@@ -16,6 +25,7 @@ class TerminalMessage(ChatMessage):
self.to_user_id = to_user_id self.to_user_id = to_user_id
self.other_user_id = other_user_id self.other_user_id = other_user_id
class TerminalChannel(ChatChannel): class TerminalChannel(ChatChannel):
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE] NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE]
@@ -23,14 +33,18 @@ class TerminalChannel(ChatChannel):
print("\nBot:") print("\nBot:")
if reply.type == ReplyType.IMAGE: if reply.type == ReplyType.IMAGE:
from PIL import Image from PIL import Image
image_storage = reply.content image_storage = reply.content
image_storage.seek(0) image_storage.seek(0)
img = Image.open(image_storage) img = Image.open(image_storage)
print("<IMAGE>") print("<IMAGE>")
img.show() img.show()
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
import io
import requests
from PIL import Image from PIL import Image
import requests,io
img_url = reply.content img_url = reply.content
pic_res = requests.get(img_url, stream=True) pic_res = requests.get(img_url, stream=True)
image_storage = io.BytesIO() image_storage = io.BytesIO()
@@ -59,11 +73,11 @@ class TerminalChannel(ChatChannel):
print("\nExiting...") print("\nExiting...")
sys.exit() sys.exit()
msg_id += 1 msg_id += 1
trigger_prefixs = conf().get("single_chat_prefix",[""]) trigger_prefixs = conf().get("single_chat_prefix", [""])
if check_prefix(prompt, trigger_prefixs) is None: if check_prefix(prompt, trigger_prefixs) is None:
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀 prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
context = self._compose_context(ContextType.TEXT, prompt, msg = TerminalMessage(msg_id, prompt)) context = self._compose_context(ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt))
if context: if context:
self.produce(context) self.produce(context)
else: else:

View File

@@ -4,40 +4,50 @@
wechat channel wechat channel
""" """
import io
import json
import os import os
import threading import threading
import requests
import io
import time import time
import json
import requests
from bridge.context import *
from bridge.reply import *
from channel.chat_channel import ChatChannel from channel.chat_channel import ChatChannel
from channel.wechat.wechat_message import * from channel.wechat.wechat_message import *
from common.singleton import singleton from common.expired_dict import ExpiredDict
from common.log import logger from common.log import logger
from common.singleton import singleton
from common.time_check import time_checker
from config import conf, get_appdata_dir
from lib import itchat from lib import itchat
from lib.itchat.content import * from lib.itchat.content import *
from bridge.reply import *
from bridge.context import *
from config import conf
from common.time_check import time_checker
from common.expired_dict import ExpiredDict
from plugins import * from plugins import *
@itchat.msg_register([TEXT,VOICE,PICTURE])
@itchat.msg_register([TEXT, VOICE, PICTURE, NOTE])
def handler_single_msg(msg): def handler_single_msg(msg):
# logger.debug("handler_single_msg: {}".format(msg)) try:
if msg['Type'] == PICTURE and msg['MsgType'] == 47: cmsg = WeChatMessage(msg, False)
except NotImplementedError as e:
logger.debug("[WX]single message {} skipped: {}".format(msg["MsgId"], e))
return None return None
WechatChannel().handle_single(WeChatMessage(msg)) WechatChannel().handle_single(cmsg)
return None return None
@itchat.msg_register([TEXT,VOICE,PICTURE], isGroupChat=True)
@itchat.msg_register([TEXT, VOICE, PICTURE, NOTE], isGroupChat=True)
def handler_group_msg(msg): def handler_group_msg(msg):
if msg['Type'] == PICTURE and msg['MsgType'] == 47: try:
cmsg = WeChatMessage(msg, True)
except NotImplementedError as e:
logger.debug("[WX]group message {} skipped: {}".format(msg["MsgId"], e))
return None return None
WechatChannel().handle_group(WeChatMessage(msg,True)) WechatChannel().handle_group(cmsg)
return None return None
def _check(func): def _check(func):
def wrapper(self, cmsg: ChatMessage): def wrapper(self, cmsg: ChatMessage):
msgId = cmsg.msg_id msgId = cmsg.msg_id
@@ -45,21 +55,24 @@ def _check(func):
logger.info("Wechat message {} already received, ignore".format(msgId)) logger.info("Wechat message {} already received, ignore".format(msgId))
return return
self.receivedMsgs[msgId] = cmsg self.receivedMsgs[msgId] = cmsg
create_time = cmsg.create_time # 消息时间戳 create_time = cmsg.create_time # 消息时间戳
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息 if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
logger.debug("[WX]history message {} skipped".format(msgId)) logger.debug("[WX]history message {} skipped".format(msgId))
return return
return func(self, cmsg) return func(self, cmsg)
return wrapper return wrapper
#可用的二维码生成接口
#https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com # 可用的二维码生成接口
#https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com # https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com
def qrCallback(uuid,status,qrcode): # https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com
def qrCallback(uuid, status, qrcode):
# logger.debug("qrCallback: {} {}".format(uuid,status)) # logger.debug("qrCallback: {} {}".format(uuid,status))
if status == '0': if status == "0":
try: try:
from PIL import Image from PIL import Image
img = Image.open(io.BytesIO(qrcode)) img = Image.open(io.BytesIO(qrcode))
_thread = threading.Thread(target=img.show, args=("QRCode",)) _thread = threading.Thread(target=img.show, args=("QRCode",))
_thread.setDaemon(True) _thread.setDaemon(True)
@@ -68,42 +81,50 @@ def qrCallback(uuid,status,qrcode):
pass pass
import qrcode import qrcode
url = f"https://login.weixin.qq.com/l/{uuid}" url = f"https://login.weixin.qq.com/l/{uuid}"
qr_api1="https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url) qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
qr_api2="https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url) qr_api2 = "https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url)
qr_api3="https://api.pwmqr.com/qrcode/create/?url={}".format(url) qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url)
qr_api4="https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url) qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url)
print("You can also scan QRCode in any website below:") print("You can also scan QRCode in any website below:")
print(qr_api3) print(qr_api3)
print(qr_api4) print(qr_api4)
print(qr_api2) print(qr_api2)
print(qr_api1) print(qr_api1)
qr = qrcode.QRCode(border=1) qr = qrcode.QRCode(border=1)
qr.add_data(url) qr.add_data(url)
qr.make(fit=True) qr.make(fit=True)
qr.print_ascii(invert=True) qr.print_ascii(invert=True)
@singleton @singleton
class WechatChannel(ChatChannel): class WechatChannel(ChatChannel):
NOT_SUPPORT_REPLYTYPE = [] NOT_SUPPORT_REPLYTYPE = []
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.receivedMsgs = ExpiredDict(60*60*24) self.receivedMsgs = ExpiredDict(60 * 60 * 24)
def startup(self): def startup(self):
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间 itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
# login by scan QRCode # login by scan QRCode
hotReload = conf().get('hot_reload', False) hotReload = conf().get("hot_reload", False)
status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
try: try:
itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback) itchat.auto_login(
enableCmdQR=2,
hotReload=hotReload,
statusStorageDir=status_path,
qrCallback=qrCallback,
)
except Exception as e: except Exception as e:
if hotReload: if hotReload:
logger.error("Hot reload failed, try to login without hot reload") logger.error("Hot reload failed, try to login without hot reload")
itchat.logout() itchat.logout()
os.remove("itchat.pkl") os.remove(status_path)
itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback) itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback)
else: else:
raise e raise e
@@ -127,48 +148,56 @@ class WechatChannel(ChatChannel):
@time_checker @time_checker
@_check @_check
def handle_single(self, cmsg : ChatMessage): def handle_single(self, cmsg: ChatMessage):
if cmsg.ctype == ContextType.VOICE: if cmsg.ctype == ContextType.VOICE:
if conf().get('speech_recognition') != True: if conf().get("speech_recognition") != True:
return return
logger.debug("[WX]receive voice msg: {}".format(cmsg.content)) logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.IMAGE: elif cmsg.ctype == ContextType.IMAGE:
logger.debug("[WX]receive image msg: {}".format(cmsg.content)) logger.debug("[WX]receive image msg: {}".format(cmsg.content))
else: elif cmsg.ctype == ContextType.PATPAT:
logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.TEXT:
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg)) logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
else:
logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg) context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
if context: if context:
self.produce(context) self.produce(context)
@time_checker @time_checker
@_check @_check
def handle_group(self, cmsg : ChatMessage): def handle_group(self, cmsg: ChatMessage):
if cmsg.ctype == ContextType.VOICE: if cmsg.ctype == ContextType.VOICE:
if conf().get('speech_recognition') != True: if conf().get("speech_recognition") != True:
return return
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content)) logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.IMAGE: elif cmsg.ctype == ContextType.IMAGE:
logger.debug("[WX]receive image for group msg: {}".format(cmsg.content)) logger.debug("[WX]receive image for group msg: {}".format(cmsg.content))
else: elif cmsg.ctype in [ContextType.JOIN_GROUP, ContextType.PATPAT]:
logger.debug("[WX]receive note msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.TEXT:
# logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg)) # logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
pass pass
else:
logger.debug("[WX]receive group msg: {}".format(cmsg.content))
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg) context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
if context: if context:
self.produce(context) self.produce(context)
# 统一的发送函数每个Channel自行实现根据reply的type字段发送不同类型的消息 # 统一的发送函数每个Channel自行实现根据reply的type字段发送不同类型的消息
def send(self, reply: Reply, context: Context): def send(self, reply: Reply, context: Context):
receiver = context["receiver"] receiver = context["receiver"]
if reply.type == ReplyType.TEXT: if reply.type == ReplyType.TEXT:
itchat.send(reply.content, toUserName=receiver) itchat.send(reply.content, toUserName=receiver)
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
itchat.send(reply.content, toUserName=receiver) itchat.send(reply.content, toUserName=receiver)
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
elif reply.type == ReplyType.VOICE: elif reply.type == ReplyType.VOICE:
itchat.send_file(reply.content, toUserName=receiver) itchat.send_file(reply.content, toUserName=receiver)
logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver)) logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
img_url = reply.content img_url = reply.content
pic_res = requests.get(img_url, stream=True) pic_res = requests.get(img_url, stream=True)
image_storage = io.BytesIO() image_storage = io.BytesIO()
@@ -176,9 +205,9 @@ class WechatChannel(ChatChannel):
image_storage.write(block) image_storage.write(block)
image_storage.seek(0) image_storage.seek(0)
itchat.send_image(image_storage, toUserName=receiver) itchat.send_image(image_storage, toUserName=receiver)
logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver)) logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
elif reply.type == ReplyType.IMAGE: # 从文件读取图片 elif reply.type == ReplyType.IMAGE: # 从文件读取图片
image_storage = reply.content image_storage = reply.content
image_storage.seek(0) image_storage.seek(0)
itchat.send_image(image_storage, toUserName=receiver) itchat.send_image(image_storage, toUserName=receiver)
logger.info('[WX] sendImage, receiver={}'.format(receiver)) logger.info("[WX] sendImage, receiver={}".format(receiver))

View File

@@ -1,54 +1,70 @@
import re
from bridge.context import ContextType from bridge.context import ContextType
from channel.chat_message import ChatMessage from channel.chat_message import ChatMessage
from common.tmp_dir import TmpDir
from common.log import logger from common.log import logger
from lib.itchat.content import * from common.tmp_dir import TmpDir
from lib import itchat from lib import itchat
from lib.itchat.content import *
class WeChatMessage(ChatMessage): class WeChatMessage(ChatMessage):
def __init__(self, itchat_msg, is_group=False): def __init__(self, itchat_msg, is_group=False):
super().__init__( itchat_msg) super().__init__(itchat_msg)
self.msg_id = itchat_msg['MsgId'] self.msg_id = itchat_msg["MsgId"]
self.create_time = itchat_msg['CreateTime'] self.create_time = itchat_msg["CreateTime"]
self.is_group = is_group self.is_group = is_group
if itchat_msg['Type'] == TEXT: if itchat_msg["Type"] == TEXT:
self.ctype = ContextType.TEXT self.ctype = ContextType.TEXT
self.content = itchat_msg['Text'] self.content = itchat_msg["Text"]
elif itchat_msg['Type'] == VOICE: elif itchat_msg["Type"] == VOICE:
self.ctype = ContextType.VOICE self.ctype = ContextType.VOICE
self.content = TmpDir().path() + itchat_msg['FileName'] # content直接存临时目录路径 self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
self._prepare_fn = lambda: itchat_msg.download(self.content) self._prepare_fn = lambda: itchat_msg.download(self.content)
elif itchat_msg['Type'] == PICTURE and itchat_msg['MsgType'] == 3: elif itchat_msg["Type"] == PICTURE and itchat_msg["MsgType"] == 3:
self.ctype = ContextType.IMAGE self.ctype = ContextType.IMAGE
self.content = TmpDir().path() + itchat_msg['FileName'] # content直接存临时目录路径 self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
self._prepare_fn = lambda: itchat_msg.download(self.content) self._prepare_fn = lambda: itchat_msg.download(self.content)
elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000:
if is_group and ("加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]):
self.ctype = ContextType.JOIN_GROUP
self.content = itchat_msg["Content"]
# 这里只能得到nickname actual_user_id还是机器人的id
if "加入了群聊" in itchat_msg["Content"]:
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1]
elif "加入群聊" in itchat_msg["Content"]:
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
elif "拍了拍我" in itchat_msg["Content"]:
self.ctype = ContextType.PATPAT
self.content = itchat_msg["Content"]
if is_group:
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
else:
raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
else: else:
raise NotImplementedError("Unsupported message type: {}".format(itchat_msg['Type'])) raise NotImplementedError("Unsupported message type: Type:{} MsgType:{}".format(itchat_msg["Type"], itchat_msg["MsgType"]))
self.from_user_id = itchat_msg['FromUserName'] self.from_user_id = itchat_msg["FromUserName"]
self.to_user_id = itchat_msg['ToUserName'] self.to_user_id = itchat_msg["ToUserName"]
user_id = itchat.instance.storageClass.userName user_id = itchat.instance.storageClass.userName
nickname = itchat.instance.storageClass.nickName nickname = itchat.instance.storageClass.nickName
# 虽然from_user_id和to_user_id用的少但是为了保持一致性还是要填充一下 # 虽然from_user_id和to_user_id用的少但是为了保持一致性还是要填充一下
# 以下很繁琐,一句话总结:能填的都填了。 # 以下很繁琐,一句话总结:能填的都填了。
if self.from_user_id == user_id: if self.from_user_id == user_id:
self.from_user_nickname = nickname self.from_user_nickname = nickname
if self.to_user_id == user_id: if self.to_user_id == user_id:
self.to_user_nickname = nickname self.to_user_nickname = nickname
try: # 陌生人时候, 'User'字段可能不存在 try: # 陌生人时候, 'User'字段可能不存在
self.other_user_id = itchat_msg['User']['UserName'] self.other_user_id = itchat_msg["User"]["UserName"]
self.other_user_nickname = itchat_msg['User']['NickName'] self.other_user_nickname = itchat_msg["User"]["NickName"]
if self.other_user_id == self.from_user_id: if self.other_user_id == self.from_user_id:
self.from_user_nickname = self.other_user_nickname self.from_user_nickname = self.other_user_nickname
if self.other_user_id == self.to_user_id: if self.other_user_id == self.to_user_id:
self.to_user_nickname = self.other_user_nickname self.to_user_nickname = self.other_user_nickname
except KeyError as e: # 处理偶尔没有对方信息的情况 except KeyError as e: # 处理偶尔没有对方信息的情况
logger.warn("[WX]get other_user_id failed: " + str(e)) logger.warn("[WX]get other_user_id failed: " + str(e))
if self.from_user_id == user_id: if self.from_user_id == user_id:
self.other_user_id = self.to_user_id self.other_user_id = self.to_user_id
@@ -56,6 +72,7 @@ class WeChatMessage(ChatMessage):
self.other_user_id = self.from_user_id self.other_user_id = self.from_user_id
if self.is_group: if self.is_group:
self.is_at = itchat_msg['IsAt'] self.is_at = itchat_msg["IsAt"]
self.actual_user_id = itchat_msg['ActualUserName'] self.actual_user_id = itchat_msg["ActualUserName"]
self.actual_user_nickname = itchat_msg['ActualNickName'] if self.ctype not in [ContextType.JOIN_GROUP, ContextType.PATPAT]:
self.actual_user_nickname = itchat_msg["ActualNickName"]

View File

@@ -4,104 +4,108 @@
wechaty channel wechaty channel
Python Wechaty - https://github.com/wechaty/python-wechaty Python Wechaty - https://github.com/wechaty/python-wechaty
""" """
import asyncio
import base64 import base64
import os import os
import time import time
import asyncio
from bridge.context import Context from wechaty import Contact, Wechaty
from wechaty_puppet import FileBox
from wechaty import Wechaty, Contact
from wechaty.user import Message from wechaty.user import Message
from bridge.reply import * from wechaty_puppet import FileBox
from bridge.context import * from bridge.context import *
from bridge.context import Context
from bridge.reply import *
from channel.chat_channel import ChatChannel from channel.chat_channel import ChatChannel
from channel.wechat.wechaty_message import WechatyMessage from channel.wechat.wechaty_message import WechatyMessage
from common.log import logger from common.log import logger
from common.singleton import singleton from common.singleton import singleton
from config import conf from config import conf
try: try:
from voice.audio_convert import any_to_sil from voice.audio_convert import any_to_sil
except Exception as e: except Exception as e:
pass pass
@singleton @singleton
class WechatyChannel(ChatChannel): class WechatyChannel(ChatChannel):
NOT_SUPPORT_REPLYTYPE = [] NOT_SUPPORT_REPLYTYPE = []
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def startup(self): def startup(self):
config = conf() config = conf()
token = config.get('wechaty_puppet_service_token') token = config.get("wechaty_puppet_service_token")
os.environ['WECHATY_PUPPET_SERVICE_TOKEN'] = token os.environ["WECHATY_PUPPET_SERVICE_TOKEN"] = token
asyncio.run(self.main()) asyncio.run(self.main())
async def main(self): async def main(self):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
#将asyncio的loop传入处理线程 # 将asyncio的loop传入处理线程
self.handler_pool._initializer= lambda: asyncio.set_event_loop(loop) self.handler_pool._initializer = lambda: asyncio.set_event_loop(loop)
self.bot = Wechaty() self.bot = Wechaty()
self.bot.on('login', self.on_login) self.bot.on("login", self.on_login)
self.bot.on('message', self.on_message) self.bot.on("message", self.on_message)
await self.bot.start() await self.bot.start()
async def on_login(self, contact: Contact): async def on_login(self, contact: Contact):
self.user_id = contact.contact_id self.user_id = contact.contact_id
self.name = contact.name self.name = contact.name
logger.info('[WX] login user={}'.format(contact)) logger.info("[WX] login user={}".format(contact))
# 统一的发送函数每个Channel自行实现根据reply的type字段发送不同类型的消息 # 统一的发送函数每个Channel自行实现根据reply的type字段发送不同类型的消息
def send(self, reply: Reply, context: Context): def send(self, reply: Reply, context: Context):
receiver_id = context['receiver'] receiver_id = context["receiver"]
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
if context['isgroup']: if context["isgroup"]:
receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id),loop).result() receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id), loop).result()
else: else:
receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id),loop).result() receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id), loop).result()
msg = None msg = None
if reply.type == ReplyType.TEXT: if reply.type == ReplyType.TEXT:
msg = reply.content msg = reply.content
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result() asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
msg = reply.content msg = reply.content
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result() asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
elif reply.type == ReplyType.VOICE: elif reply.type == ReplyType.VOICE:
voiceLength = None voiceLength = None
file_path = reply.content file_path = reply.content
sil_file = os.path.splitext(file_path)[0] + '.sil' sil_file = os.path.splitext(file_path)[0] + ".sil"
voiceLength = int(any_to_sil(file_path, sil_file)) voiceLength = int(any_to_sil(file_path, sil_file))
if voiceLength >= 60000: if voiceLength >= 60000:
voiceLength = 60000 voiceLength = 60000
logger.info('[WX] voice too long, length={}, set to 60s'.format(voiceLength)) logger.info("[WX] voice too long, length={}, set to 60s".format(voiceLength))
# 发送语音 # 发送语音
t = int(time.time()) t = int(time.time())
msg = FileBox.from_file(sil_file, name=str(t) + '.sil') msg = FileBox.from_file(sil_file, name=str(t) + ".sil")
if voiceLength is not None: if voiceLength is not None:
msg.metadata['voiceLength'] = voiceLength msg.metadata["voiceLength"] = voiceLength
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result() asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
try: try:
os.remove(file_path) os.remove(file_path)
if sil_file != file_path: if sil_file != file_path:
os.remove(sil_file) os.remove(sil_file)
except Exception as e: except Exception as e:
pass pass
logger.info('[WX] sendVoice={}, receiver={}'.format(reply.content, receiver)) logger.info("[WX] sendVoice={}, receiver={}".format(reply.content, receiver))
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
img_url = reply.content img_url = reply.content
t = int(time.time()) t = int(time.time())
msg = FileBox.from_url(url=img_url, name=str(t) + '.png') msg = FileBox.from_url(url=img_url, name=str(t) + ".png")
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result() asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver)) logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
elif reply.type == ReplyType.IMAGE: # 从文件读取图片 elif reply.type == ReplyType.IMAGE: # 从文件读取图片
image_storage = reply.content image_storage = reply.content
image_storage.seek(0) image_storage.seek(0)
t = int(time.time()) t = int(time.time())
msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + '.png') msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + ".png")
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result() asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
logger.info('[WX] sendImage, receiver={}'.format(receiver)) logger.info("[WX] sendImage, receiver={}".format(receiver))
async def on_message(self, msg: Message): async def on_message(self, msg: Message):
""" """
@@ -110,16 +114,16 @@ class WechatyChannel(ChatChannel):
try: try:
cmsg = await WechatyMessage(msg) cmsg = await WechatyMessage(msg)
except NotImplementedError as e: except NotImplementedError as e:
logger.debug('[WX] {}'.format(e)) logger.debug("[WX] {}".format(e))
return return
except Exception as e: except Exception as e:
logger.exception('[WX] {}'.format(e)) logger.exception("[WX] {}".format(e))
return return
logger.debug('[WX] message:{}'.format(cmsg)) logger.debug("[WX] message:{}".format(cmsg))
room = msg.room() # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None room = msg.room() # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None
isgroup = room is not None isgroup = room is not None
ctype = cmsg.ctype ctype = cmsg.ctype
context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg) context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg)
if context: if context:
logger.info('[WX] receiveMsg={}, context={}'.format(cmsg, context)) logger.info("[WX] receiveMsg={}, context={}".format(cmsg, context))
self.produce(context) self.produce(context)

View File

@@ -1,17 +1,21 @@
import asyncio import asyncio
import re import re
from wechaty import MessageType from wechaty import MessageType
from wechaty.user import Message
from bridge.context import ContextType from bridge.context import ContextType
from channel.chat_message import ChatMessage from channel.chat_message import ChatMessage
from common.tmp_dir import TmpDir
from common.log import logger from common.log import logger
from wechaty.user import Message from common.tmp_dir import TmpDir
class aobject(object): class aobject(object):
"""Inheriting this class allows you to define an async __init__. """Inheriting this class allows you to define an async __init__.
So you can create objects by doing something like `await MyClass(params)` So you can create objects by doing something like `await MyClass(params)`
""" """
async def __new__(cls, *a, **kw): async def __new__(cls, *a, **kw):
instance = super().__new__(cls) instance = super().__new__(cls)
await instance.__init__(*a, **kw) await instance.__init__(*a, **kw)
@@ -19,17 +23,18 @@ class aobject(object):
async def __init__(self): async def __init__(self):
pass pass
class WechatyMessage(ChatMessage, aobject):
class WechatyMessage(ChatMessage, aobject):
async def __init__(self, wechaty_msg: Message): async def __init__(self, wechaty_msg: Message):
super().__init__(wechaty_msg) super().__init__(wechaty_msg)
room = wechaty_msg.room() room = wechaty_msg.room()
self.msg_id = wechaty_msg.message_id self.msg_id = wechaty_msg.message_id
self.create_time = wechaty_msg.payload.timestamp self.create_time = wechaty_msg.payload.timestamp
self.is_group = room is not None self.is_group = room is not None
if wechaty_msg.type() == MessageType.MESSAGE_TYPE_TEXT: if wechaty_msg.type() == MessageType.MESSAGE_TYPE_TEXT:
self.ctype = ContextType.TEXT self.ctype = ContextType.TEXT
self.content = wechaty_msg.text() self.content = wechaty_msg.text()
@@ -40,12 +45,13 @@ class WechatyMessage(ChatMessage, aobject):
def func(): def func():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content),loop).result() asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content), loop).result()
self._prepare_fn = func self._prepare_fn = func
else: else:
raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type())) raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type()))
from_contact = wechaty_msg.talker() # 获取消息的发送者 from_contact = wechaty_msg.talker() # 获取消息的发送者
self.from_user_id = from_contact.contact_id self.from_user_id = from_contact.contact_id
self.from_user_nickname = from_contact.name self.from_user_nickname = from_contact.name
@@ -54,7 +60,7 @@ class WechatyMessage(ChatMessage, aobject):
# wecahty: from是消息实际发送者, to:所在群 # wecahty: from是消息实际发送者, to:所在群
# itchat: 如果是你发送群消息from和to是你自己和所在群如果是别人发群消息from和to是所在群和你自己 # itchat: 如果是你发送群消息from和to是你自己和所在群如果是别人发群消息from和to是所在群和你自己
# 但这个差别不影响逻辑group中只使用到1.用from来判断是否是自己发的2.actual_user_id来判断实际发送用户 # 但这个差别不影响逻辑group中只使用到1.用from来判断是否是自己发的2.actual_user_id来判断实际发送用户
if self.is_group: if self.is_group:
self.to_user_id = room.room_id self.to_user_id = room.room_id
self.to_user_nickname = await room.topic() self.to_user_nickname = await room.topic()
@@ -63,22 +69,20 @@ class WechatyMessage(ChatMessage, aobject):
self.to_user_id = to_contact.contact_id self.to_user_id = to_contact.contact_id
self.to_user_nickname = to_contact.name self.to_user_nickname = to_contact.name
if self.is_group or wechaty_msg.is_self(): # 如果是群消息other_user设置为群如果是私聊消息而且自己发的就设置成对方。 if self.is_group or wechaty_msg.is_self(): # 如果是群消息other_user设置为群如果是私聊消息而且自己发的就设置成对方。
self.other_user_id = self.to_user_id self.other_user_id = self.to_user_id
self.other_user_nickname = self.to_user_nickname self.other_user_nickname = self.to_user_nickname
else: else:
self.other_user_id = self.from_user_id self.other_user_id = self.from_user_id
self.other_user_nickname = self.from_user_nickname self.other_user_nickname = self.from_user_nickname
if self.is_group: # wechaty群聊中实际发送用户就是from_user
if self.is_group: # wechaty群聊中实际发送用户就是from_user
self.is_at = await wechaty_msg.mention_self() self.is_at = await wechaty_msg.mention_self()
if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx这里做一下兼容 if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx这里做一下兼容
name = wechaty_msg.wechaty.user_self().name name = wechaty_msg.wechaty.user_self().name
pattern = f'@{name}(\u2005|\u0020)' pattern = f"@{re.escape(name)}(\u2005|\u0020)"
if re.search(pattern,self.content): if re.search(pattern, self.content):
logger.debug(f'wechaty message {self.msg_id} include at') logger.debug(f"wechaty message {self.msg_id} include at")
self.is_at = True self.is_at = True
self.actual_user_id = self.from_user_id self.actual_user_id = self.from_user_id

View File

@@ -1,57 +1,100 @@
# 微信公众号channel # 微信公众号channel
鉴于个人微信号在服务器上通过itchat登录有封号风险这里新增了微信公众号channel提供无风险的服务。 鉴于个人微信号在服务器上通过itchat登录有封号风险这里新增了微信公众号channel提供无风险的服务。
目前支持订阅号(个人)和服务号(企业)两种类型的公众号,它们的主要区别就是被动回复和主动回复 目前支持订阅号和服务号两种类型的公众号,它们都支持文本交互,语音和图片输入。其中个人主体的微信订阅号由于无法通过微信认证,存在回复时间限制,每天的图片和声音回复次数也有限制
个人微信订阅号有许多接口限制目前仅支持最基本的文本对话和语音输入支持加载插件支持私有api_key。
暂未实现图片输入输出、语音输出等交互形式。
## 使用方法(订阅号,服务号类似) ## 使用方法(订阅号,服务号类似)
在开始部署前你需要一个拥有公网IP的服务器以提供微信服务器和我们自己服务器的连接。或者你需要进行内网穿透否则微信服务器无法将消息发送给我们的服务器。 在开始部署前你需要一个拥有公网IP的服务器以提供微信服务器和我们自己服务器的连接。或者你需要进行内网穿透否则微信服务器无法将消息发送给我们的服务器。
此外需要在我们的服务器上安装python的web框架web.py。 此外需要在我们的服务器上安装python的web框架web.py和wechatpy
以ubuntu为例(在ubuntu 22.04上测试): 以ubuntu为例(在ubuntu 22.04上测试):
``` ```
pip3 install web.py pip3 install web.py
pip3 install wechatpy
``` ```
然后在[微信公众平台](https://mp.weixin.qq.com)注册一个自己的公众号,类型选择订阅号,主体为个人即可。 然后在[微信公众平台](https://mp.weixin.qq.com)注册一个自己的公众号,类型选择订阅号,主体为个人即可。
然后根据[接入指南](https://developers.weixin.qq.com/doc/offiaccount/Basic_Information/Access_Overview.html)的说明,在[微信公众平台](https://mp.weixin.qq.com)的“设置与开发”-“基本配置”-“服务器配置”中填写服务器地址`URL`和令牌`Token`。这里的`URL``example.com/wx`的形式不可以使用IP`Token`是你自己编的一个特定的令牌。消息加解密方式目前选择的是明文模式 然后根据[接入指南](https://developers.weixin.qq.com/doc/offiaccount/Basic_Information/Access_Overview.html)的说明,在[微信公众平台](https://mp.weixin.qq.com)的“设置与开发”-“基本配置”-“服务器配置”中填写服务器地址`URL`和令牌`Token`。这里的`URL``example.com/wx`的形式不可以使用IP`Token`是你自己编的一个特定的令牌。消息加解密方式如果选择了需要加密的模式,需要在配置中填写`wechatmp_aes_key`
相关的服务器验证代码已经写好,你不需要再添加任何代码。你只需要在本项目根目录的`config.json`中添加 相关的服务器验证代码已经写好,你不需要再添加任何代码。你只需要在本项目根目录的`config.json`中添加
``` ```
"channel_type": "wechatmp", "channel_type": "wechatmp", # 如果通过了微信认证,将"wechatmp"替换为"wechatmp_service",可极大的优化使用体验
"wechatmp_token": "Token", # 微信公众平台的Token "wechatmp_token": "xxxx", # 微信公众平台的Token
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443 "wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
"wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要 "wechatmp_app_id": "xxxx", # 微信公众平台的appID
"wechatmp_app_secret": "", # 微信公众平台的appsecret,仅服务号需要 "wechatmp_app_secret": "xxxx", # 微信公众平台的appsecret
``` "wechatmp_aes_key": "", # 微信公众平台的EncodingAESKey加密模式需要
然后运行`python3 app.py`启动web服务器。这里会默认监听8080端口但是微信公众号的服务器配置只支持80/443端口有两种方法来解决这个问题。第一个是推荐的方法使用端口转发命令将80端口转发到8080端口443同理注意需要支持SSL也就是https的访问在`wechatmp_channel.py`需要修改相应的证书路径): "single_chat_prefix": [""], # 推荐设置,任意对话都可以触发回复,不添加前缀
"single_chat_reply_prefix": "", # 推荐设置,回复不设置前缀
"plugin_trigger_prefix": "&", # 推荐设置,在手机微信客户端中,$%^等符号与中文连在一起时会自动显示一段较大的间隔,用户体验不好。请不要使用管理员指令前缀"#",这会造成未知问题。
```
然后运行`python3 app.py`启动web服务器。这里会默认监听8080端口但是微信公众号的服务器配置只支持80/443端口有两种方法来解决这个问题。第一个是推荐的方法使用端口转发命令将80端口转发到8080端口
``` ```
sudo iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to-port 8080 sudo iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to-port 8080
sudo iptables-save > /etc/iptables/rules.v4 sudo iptables-save > /etc/iptables/rules.v4
``` ```
第二个方法是让python程序直接监听80端口。这样可能会导致权限问题在linux上需要使用`sudo`。然而这会导致后续缓存文件的权限问题,因此不是推荐的方法。 第二个方法是让python程序直接监听80端口,在配置文件中设置`"wechatmp_port": 80` 在linux上需要使用`sudo python3 app.py`启动程序。然而这会导致一系列环境和权限问题,因此不是推荐的方法。
最后在刚才的“服务器配置”中点击`提交`即可验证你的服务器。
443端口同理注意需要支持SSL也就是https的访问`wechatmp_channel.py`中需要修改相应的证书路径。
程序启动并监听端口后,在刚才的“服务器配置”中点击`提交`即可验证你的服务器。
随后在[微信公众平台](https://mp.weixin.qq.com)启用服务器关闭手动填写规则的自动回复即可实现ChatGPT的自动回复。 随后在[微信公众平台](https://mp.weixin.qq.com)启用服务器关闭手动填写规则的自动回复即可实现ChatGPT的自动回复。
之后需要在公众号开发信息下将本机IP加入到IP白名单。
不然在启用后,发送语音、图片等消息可能会遇到如下报错:
```
'errcode': 40164, 'errmsg': 'invalid ip xx.xx.xx.xx not in whitelist rid
```
## 个人微信公众号的限制 ## 个人微信公众号的限制
由于人微信公众号不能通过微信认证所以没有客服接口因此公众号无法主动发出消息只能被动回复。而微信官方对被动回复有5秒的时间限制最多重试2次因此最多只有15秒的自动回复时间窗口。因此如果问题比较复杂或者我们的服务器比较忙ChatGPT的回答就没办法及时回复给用户。为了解决这个问题这里做了回答缓存它需要你在回复超时后再次主动发送任意文字例如1来尝试拿到回答缓存。为了优化使用体验目前设置了两分钟120秒的timeout用户在至多两分钟后即可得到查询到回复或者错误原因。 由于人微信公众号不能通过微信认证所以没有客服接口因此公众号无法主动发出消息只能被动回复。而微信官方对被动回复有5秒的时间限制最多重试2次因此最多只有15秒的自动回复时间窗口。因此如果问题比较复杂或者我们的服务器比较忙ChatGPT的回答就没办法及时回复给用户。为了解决这个问题这里做了回答缓存它需要你在回复超时后再次主动发送任意文字例如1来尝试拿到回答缓存。为了优化使用体验目前设置了两分钟120秒的timeout用户在至多两分钟后即可得到查询到回复或者错误原因。
另外由于微信官方的限制自动回复有长度限制。因此这里将ChatGPT的回答拆分分成每段600字回复限制大约在700字 另外由于微信官方的限制自动回复有长度限制。因此这里将ChatGPT的回答进行了拆分,以满足限制
## 私有api_key ## 私有api_key
公共api有访问频率限制免费账号每分钟最多20次ChatGPT的API调用这在服务多人的时候会遇到问题。因此这里多加了一个设置私有api_key的功能。目前通过godcmd插件的命令来设置私有api_key。 公共api有访问频率限制免费账号每分钟最多3次ChatGPT的API调用这在服务多人的时候会遇到问题。因此这里多加了一个设置私有api_key的功能。目前通过godcmd插件的命令来设置私有api_key。
## 语音输入 ## 语音输入
利用微信自带的语音识别功能,提供语音输入能力。需要在公众号管理页面的“设置与开发”->“接口权限”页面开启“接收语音识别结果”。 利用微信自带的语音识别功能,提供语音输入能力。需要在公众号管理页面的“设置与开发”->“接口权限”页面开启“接收语音识别结果”。
## 测试范围 ## 语音回复
目前在`RoboStyle`这个公众号上进行了测试(基于[wechatmp分支](https://github.com/JS00000/chatgpt-on-wechat/tree/wechatmp)感兴趣的可以关注并体验。开启了godcmd, Banwords, role, dungeon, finish这五个插件其他的插件还没有测试。百度的接口暂未测试。语音对话没有测试。图片直接以链接形式回复没有临时素材上传接口的权限 请在配置文件中添加以下词条:
```
"voice_reply_voice": true,
```
这样公众号将会用语音回复语音消息,实现语音对话。
默认的语音合成引擎是`google`,它是免费使用的。
如果要选择其他的语音合成引擎,请添加以下配置项:
```
"text_to_voice": "pytts"
```
pytts是本地的语音合成引擎。还支持baidu,azure这些你需要自行配置相关的依赖和key。
如果使用pytts在ubuntu上需要安装如下依赖
```
sudo apt update
sudo apt install espeak
sudo apt install ffmpeg
python3 -m pip install pyttsx3
```
不是很建议开启pytts语音回复因为它是离线本地计算算的慢会拖垮服务器且声音不好听。
## 图片回复
现在认证公众号和非认证公众号都可以实现的图片和语音回复。但是非认证公众号使用了永久素材接口每天有1000次的调用上限每个月有10次重置机会程序中已设定遇到上限会自动重置且永久素材库存也有上限。因此对于非认证公众号我们会在回复图片或者语音消息后的10秒内从永久素材库存内删除该素材。
## 测试
目前在`RoboStyle`这个公众号上进行了测试(基于[wechatmp分支](https://github.com/JS00000/chatgpt-on-wechat/tree/wechatmp)感兴趣的可以关注并体验。开启了godcmd, Banwords, role, dungeon, finish这五个插件其他的插件还没有详尽测试。百度的接口暂未测试。[wechatmp-stable分支](https://github.com/JS00000/chatgpt-on-wechat/tree/wechatmp-stable)是较稳定的上个版本,但也缺少最新的功能支持。
## TODO ## TODO
* 服务号交互完善 - [x] 语音输入
* 服务号使用临时素材接口,提供图片回复能力 - [x] 图片输入
* 插件测试 - [x] 使用临时素材接口提供认证公众号的图片和语音回复
- [x] 使用永久素材接口提供未认证公众号的图片和语音回复
- [ ] 高并发支持

View File

@@ -1,51 +0,0 @@
import web
import time
import channel.wechatmp.reply as reply
import channel.wechatmp.receive as receive
from config import conf
from common.log import logger
from bridge.context import *
from channel.wechatmp.common import *
from channel.wechatmp.wechatmp_channel import WechatMPChannel
# This class is instantiated once per query
class Query():
def GET(self):
return verify_server(web.input())
def POST(self):
# Make sure to return the instance that first created, @singleton will do that.
channel = WechatMPChannel()
try:
webData = web.data()
# logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
wechatmp_msg = receive.parse_xml(webData)
if wechatmp_msg.msg_type == 'text' or wechatmp_msg.msg_type == 'voice':
from_user = wechatmp_msg.from_user_id
message = wechatmp_msg.content.decode("utf-8")
message_id = wechatmp_msg.msg_id
logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message))
context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg)
if context:
# set private openai_api_key
# if from_user is not changed in itchat, this can be placed at chat_channel
user_data = conf().get_user_data(from_user)
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key
channel.produce(context)
# The reply will be sent by channel.send() in another thread
return "success"
elif wechatmp_msg.msg_type == 'event':
logger.info("[wechatmp] Event {} from {}".format(wechatmp_msg.Event, wechatmp_msg.from_user_id))
content = subscribe_msg()
replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content)
return replyMsg.send()
else:
logger.info("暂且不处理")
return "success"
except Exception as exc:
logger.exception(exc)
return exc

View File

@@ -1,172 +0,0 @@
import web
import time
import channel.wechatmp.reply as reply
import channel.wechatmp.receive as receive
from config import conf
from common.log import logger
from bridge.context import *
from channel.wechatmp.common import *
from channel.wechatmp.wechatmp_channel import WechatMPChannel
# This class is instantiated once per query
class Query():
def GET(self):
return verify_server(web.input())
def POST(self):
# Make sure to return the instance that first created, @singleton will do that.
channel = WechatMPChannel()
try:
query_time = time.time()
webData = web.data()
logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
wechatmp_msg = receive.parse_xml(webData)
if wechatmp_msg.msg_type == 'text' or wechatmp_msg.msg_type == 'voice':
from_user = wechatmp_msg.from_user_id
to_user = wechatmp_msg.to_user_id
message = wechatmp_msg.content.decode("utf-8")
message_id = wechatmp_msg.msg_id
logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message))
supported = True
if "【收到不支持的消息类型,暂无法显示】" in message:
supported = False # not supported, used to refresh
cache_key = from_user
reply_text = ""
# New request
if cache_key not in channel.cache_dict and cache_key not in channel.running:
# The first query begin, reset the cache
context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg)
logger.debug("[wechatmp] context: {} {}".format(context, wechatmp_msg))
if message_id in channel.received_msgs: # received and finished
# no return because of bandwords or other reasons
return "success"
if supported and context:
# set private openai_api_key
# if from_user is not changed in itchat, this can be placed at chat_channel
user_data = conf().get_user_data(from_user)
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key
channel.received_msgs[message_id] = wechatmp_msg
channel.running.add(cache_key)
channel.produce(context)
else:
trigger_prefix = conf().get('single_chat_prefix',[''])[0]
if trigger_prefix or not supported:
if trigger_prefix:
content = textwrap.dedent(f"""\
请输入'{trigger_prefix}'接你想说的话跟我说话。
例如:
{trigger_prefix}你好,很高兴见到你。""")
else:
content = textwrap.dedent("""\
你好,很高兴见到你。
请跟我说话吧。""")
else:
logger.error(f"[wechatmp] unknown error")
content = textwrap.dedent("""\
未知错误,请稍后再试""")
replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content)
return replyMsg.send()
channel.query1[cache_key] = False
channel.query2[cache_key] = False
channel.query3[cache_key] = False
# User request again, and the answer is not ready
elif cache_key in channel.running and channel.query1.get(cache_key) == True and channel.query2.get(cache_key) == True and channel.query3.get(cache_key) == True:
channel.query1[cache_key] = False #To improve waiting experience, this can be set to True.
channel.query2[cache_key] = False #To improve waiting experience, this can be set to True.
channel.query3[cache_key] = False
# User request again, and the answer is ready
elif cache_key in channel.cache_dict:
# Skip the waiting phase
channel.query1[cache_key] = True
channel.query2[cache_key] = True
channel.query3[cache_key] = True
assert not (cache_key in channel.cache_dict and cache_key in channel.running)
if channel.query1.get(cache_key) == False:
# The first query from wechat official server
logger.debug("[wechatmp] query1 {}".format(cache_key))
channel.query1[cache_key] = True
cnt = 0
while cache_key in channel.running and cnt < 45:
cnt = cnt + 1
time.sleep(0.1)
if cnt == 45:
# waiting for timeout (the POST query will be closed by wechat official server)
time.sleep(1)
# and do nothing
return
else:
pass
elif channel.query2.get(cache_key) == False:
# The second query from wechat official server
logger.debug("[wechatmp] query2 {}".format(cache_key))
channel.query2[cache_key] = True
cnt = 0
while cache_key in channel.running and cnt < 45:
cnt = cnt + 1
time.sleep(0.1)
if cnt == 45:
# waiting for timeout (the POST query will be closed by wechat official server)
time.sleep(1)
# and do nothing
return
else:
pass
elif channel.query3.get(cache_key) == False:
# The third query from wechat official server
logger.debug("[wechatmp] query3 {}".format(cache_key))
channel.query3[cache_key] = True
cnt = 0
while cache_key in channel.running and cnt < 40:
cnt = cnt + 1
time.sleep(0.1)
if cnt == 40:
# Have waiting for 3x5 seconds
# return timeout message
reply_text = "【正在思考中,回复任意文字尝试获取回复】"
logger.info("[wechatmp] Three queries has finished For {}: {}".format(from_user, message_id))
replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
return replyPost
else:
pass
if cache_key not in channel.cache_dict and cache_key not in channel.running:
# no return because of bandwords or other reasons
return "success"
# if float(time.time()) - float(query_time) > 4.8:
# reply_text = "【正在思考中,回复任意文字尝试获取回复】"
# logger.info("[wechatmp] Timeout for {} {}, return".format(from_user, message_id))
# replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
# return replyPost
if cache_key in channel.cache_dict:
content = channel.cache_dict[cache_key]
if len(content.encode('utf8'))<=MAX_UTF8_LEN:
reply_text = channel.cache_dict[cache_key]
channel.cache_dict.pop(cache_key)
else:
continue_text = "\n【未完待续,回复任意文字以继续】"
splits = split_string_by_utf8_length(content, MAX_UTF8_LEN - len(continue_text.encode('utf-8')), max_split= 1)
reply_text = splits[0] + continue_text
channel.cache_dict[cache_key] = splits[1]
logger.info("[wechatmp] {}:{} Do send {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), reply_text))
replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
return replyPost
elif wechatmp_msg.msg_type == 'event':
logger.info("[wechatmp] Event {} from {}".format(wechatmp_msg.content, wechatmp_msg.from_user_id))
content = subscribe_msg()
replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content)
return replyMsg.send()
else:
logger.info("暂且不处理")
return "success"
except Exception as exc:
logger.exception(exc)
return exc

View File

@@ -0,0 +1,78 @@
import time
import web
from wechatpy import parse_message
from wechatpy.replies import create_reply
from bridge.context import *
from bridge.reply import *
from channel.wechatmp.common import *
from channel.wechatmp.wechatmp_channel import WechatMPChannel
from channel.wechatmp.wechatmp_message import WeChatMPMessage
from common.log import logger
from config import conf
# This class is instantiated once per query
class Query:
def GET(self):
return verify_server(web.input())
def POST(self):
# Make sure to return the instance that first created, @singleton will do that.
try:
args = web.input()
verify_server(args)
channel = WechatMPChannel()
message = web.data()
encrypt_func = lambda x: x
if args.get("encrypt_type") == "aes":
logger.debug("[wechatmp] Receive encrypted post data:\n" + message.decode("utf-8"))
if not channel.crypto:
raise Exception("Crypto not initialized, Please set wechatmp_aes_key in config.json")
message = channel.crypto.decrypt_message(message, args.msg_signature, args.timestamp, args.nonce)
encrypt_func = lambda x: channel.crypto.encrypt_message(x, args.nonce, args.timestamp)
else:
logger.debug("[wechatmp] Receive post data:\n" + message.decode("utf-8"))
msg = parse_message(message)
if msg.type in ["text", "voice", "image"]:
wechatmp_msg = WeChatMPMessage(msg, client=channel.client)
from_user = wechatmp_msg.from_user_id
content = wechatmp_msg.content
message_id = wechatmp_msg.msg_id
logger.info(
"[wechatmp] {}:{} Receive post query {} {}: {}".format(
web.ctx.env.get("REMOTE_ADDR"),
web.ctx.env.get("REMOTE_PORT"),
from_user,
message_id,
content,
)
)
if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
else:
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
if context:
# set private openai_api_key
# if from_user is not changed in itchat, this can be placed at chat_channel
user_data = conf().get_user_data(from_user)
context["openai_api_key"] = user_data.get("openai_api_key") # None or user openai_api_key
channel.produce(context)
# The reply will be sent by channel.send() in another thread
return "success"
elif msg.type == "event":
logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
if msg.event in ["subscribe", "subscribe_scan"]:
reply_text = subscribe_msg()
replyPost = create_reply(reply_text, msg)
return encrypt_func(replyPost.render())
else:
return "success"
else:
logger.info("暂且不处理")
return "success"
except Exception as exc:
logger.exception(exc)
return exc

View File

@@ -1,63 +1,62 @@
from config import conf
import hashlib
import textwrap import textwrap
import web
from wechatpy.crypto import WeChatCrypto
from wechatpy.exceptions import InvalidSignatureException
from wechatpy.utils import check_signature
from config import conf
MAX_UTF8_LEN = 2048 MAX_UTF8_LEN = 2048
class WeChatAPIException(Exception): class WeChatAPIException(Exception):
pass pass
def verify_server(data): def verify_server(data):
try: try:
if len(data) == 0:
return "None"
signature = data.signature signature = data.signature
timestamp = data.timestamp timestamp = data.timestamp
nonce = data.nonce nonce = data.nonce
echostr = data.echostr echostr = data.get("echostr", None)
token = conf().get('wechatmp_token') #请按照公众平台官网\基本配置中信息填写 token = conf().get("wechatmp_token") # 请按照公众平台官网\基本配置中信息填写
check_signature(token, signature, timestamp, nonce)
return echostr
except InvalidSignatureException:
raise web.Forbidden("Invalid signature")
except Exception as e:
raise web.Forbidden(str(e))
data_list = [token, timestamp, nonce]
data_list.sort()
sha1 = hashlib.sha1()
# map(sha1.update, data_list) #python2
sha1.update("".join(data_list).encode('utf-8'))
hashcode = sha1.hexdigest()
print("handle/GET func: hashcode, signature: ", hashcode, signature)
if hashcode == signature:
return echostr
else:
return ""
except Exception as Argument:
return Argument
def subscribe_msg(): def subscribe_msg():
trigger_prefix = conf().get('single_chat_prefix',[''])[0] trigger_prefix = conf().get("single_chat_prefix", [""])[0]
msg = textwrap.dedent(f"""\ msg = textwrap.dedent(
f"""\
感谢您的关注! 感谢您的关注!
这里是ChatGPT可以自由对话。 这里是ChatGPT可以自由对话。
资源有限,回复较慢,请勿着急。 资源有限,回复较慢,请勿着急。
支持通用表情输入 支持语音对话
暂时不支持图片输入。 支持图片输入。
支持图片输出,画字开头的问题将回复图片链接 支持图片输出,画字开头的消息将按要求创作图片
支持角色扮演和文字冒险两种定制模式对话 支持tool、角色扮演和文字冒险等丰富的插件
输入'{trigger_prefix}#帮助' 查看详细指令。""") 输入'{trigger_prefix}#帮助' 查看详细指令。"""
)
return msg return msg
def split_string_by_utf8_length(string, max_length, max_split=0): def split_string_by_utf8_length(string, max_length, max_split=0):
encoded = string.encode('utf-8') encoded = string.encode("utf-8")
start, end = 0, 0 start, end = 0, 0
result = [] result = []
while end < len(encoded): while end < len(encoded):
if max_split > 0 and len(result) >= max_split: if max_split > 0 and len(result) >= max_split:
result.append(encoded[start:].decode('utf-8')) result.append(encoded[start:].decode("utf-8"))
break break
end = start + max_length end = min(start + max_length, len(encoded))
# 如果当前字节不是 UTF-8 编码的开始字节,则向前查找直到找到开始字节为止 # 如果当前字节不是 UTF-8 编码的开始字节,则向前查找直到找到开始字节为止
while end < len(encoded) and (encoded[end] & 0b11000000) == 0b10000000: while end < len(encoded) and (encoded[end] & 0b11000000) == 0b10000000:
end -= 1 end -= 1
result.append(encoded[start:end].decode('utf-8')) result.append(encoded[start:end].decode("utf-8"))
start = end start = end
return result return result

View File

@@ -0,0 +1,212 @@
import asyncio
import time
import web
from wechatpy import parse_message
from wechatpy.replies import ImageReply, VoiceReply, create_reply
from bridge.context import *
from bridge.reply import *
from channel.wechatmp.common import *
from channel.wechatmp.wechatmp_channel import WechatMPChannel
from channel.wechatmp.wechatmp_message import WeChatMPMessage
from common.log import logger
from config import conf
# This class is instantiated once per query
class Query:
def GET(self):
return verify_server(web.input())
def POST(self):
try:
args = web.input()
verify_server(args)
request_time = time.time()
channel = WechatMPChannel()
message = web.data()
encrypt_func = lambda x: x
if args.get("encrypt_type") == "aes":
logger.debug("[wechatmp] Receive encrypted post data:\n" + message.decode("utf-8"))
if not channel.crypto:
raise Exception("Crypto not initialized, Please set wechatmp_aes_key in config.json")
message = channel.crypto.decrypt_message(message, args.msg_signature, args.timestamp, args.nonce)
encrypt_func = lambda x: channel.crypto.encrypt_message(x, args.nonce, args.timestamp)
else:
logger.debug("[wechatmp] Receive post data:\n" + message.decode("utf-8"))
msg = parse_message(message)
if msg.type in ["text", "voice", "image"]:
wechatmp_msg = WeChatMPMessage(msg, client=channel.client)
from_user = wechatmp_msg.from_user_id
content = wechatmp_msg.content
message_id = wechatmp_msg.msg_id
supported = True
if "【收到不支持的消息类型,暂无法显示】" in content:
supported = False # not supported, used to refresh
# New request
if (
from_user not in channel.cache_dict
and from_user not in channel.running
or content.startswith("#")
and message_id not in channel.request_cnt # insert the godcmd
):
# The first query begin
if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
else:
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
logger.debug("[wechatmp] context: {} {} {}".format(context, wechatmp_msg, supported))
if supported and context:
# set private openai_api_key
# if from_user is not changed in itchat, this can be placed at chat_channel
user_data = conf().get_user_data(from_user)
context["openai_api_key"] = user_data.get("openai_api_key")
channel.running.add(from_user)
channel.produce(context)
else:
trigger_prefix = conf().get("single_chat_prefix", [""])[0]
if trigger_prefix or not supported:
if trigger_prefix:
reply_text = textwrap.dedent(
f"""\
请输入'{trigger_prefix}'接你想说的话跟我说话。
例如:
{trigger_prefix}你好,很高兴见到你。"""
)
else:
reply_text = textwrap.dedent(
"""\
你好,很高兴见到你。
请跟我说话吧。"""
)
else:
logger.error(f"[wechatmp] unknown error")
reply_text = textwrap.dedent(
"""\
未知错误,请稍后再试"""
)
replyPost = create_reply(reply_text, msg)
return encrypt_func(replyPost.render())
# Wechat official server will request 3 times (5 seconds each), with the same message_id.
# Because the interval is 5 seconds, here assumed that do not have multithreading problems.
request_cnt = channel.request_cnt.get(message_id, 0) + 1
channel.request_cnt[message_id] = request_cnt
logger.info(
"[wechatmp] Request {} from {} {} {}:{}\n{}".format(
request_cnt, from_user, message_id, web.ctx.env.get("REMOTE_ADDR"), web.ctx.env.get("REMOTE_PORT"), content
)
)
task_running = True
waiting_until = request_time + 4
while time.time() < waiting_until:
if from_user in channel.running:
time.sleep(0.1)
else:
task_running = False
break
reply_text = ""
if task_running:
if request_cnt < 3:
# waiting for timeout (the POST request will be closed by Wechat official server)
time.sleep(2)
# and do nothing, waiting for the next request
return "success"
else: # request_cnt == 3:
# return timeout message
reply_text = "【正在思考中,回复任意文字尝试获取回复】"
replyPost = create_reply(reply_text, msg)
return encrypt_func(replyPost.render())
# reply is ready
channel.request_cnt.pop(message_id)
# no return because of bandwords or other reasons
if from_user not in channel.cache_dict and from_user not in channel.running:
return "success"
# Only one request can access to the cached data
try:
(reply_type, reply_content) = channel.cache_dict.pop(from_user)
except KeyError:
return "success"
if reply_type == "text":
if len(reply_content.encode("utf8")) <= MAX_UTF8_LEN:
reply_text = reply_content
else:
continue_text = "\n【未完待续,回复任意文字以继续】"
splits = split_string_by_utf8_length(
reply_content,
MAX_UTF8_LEN - len(continue_text.encode("utf-8")),
max_split=1,
)
reply_text = splits[0] + continue_text
channel.cache_dict[from_user] = ("text", splits[1])
logger.info(
"[wechatmp] Request {} do send to {} {}: {}\n{}".format(
request_cnt,
from_user,
message_id,
content,
reply_text,
)
)
replyPost = create_reply(reply_text, msg)
return encrypt_func(replyPost.render())
elif reply_type == "voice":
media_id = reply_content
asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
logger.info(
"[wechatmp] Request {} do send to {} {}: {} voice media_id {}".format(
request_cnt,
from_user,
message_id,
content,
media_id,
)
)
replyPost = VoiceReply(message=msg)
replyPost.media_id = media_id
return encrypt_func(replyPost.render())
elif reply_type == "image":
media_id = reply_content
asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
logger.info(
"[wechatmp] Request {} do send to {} {}: {} image media_id {}".format(
request_cnt,
from_user,
message_id,
content,
media_id,
)
)
replyPost = ImageReply(message=msg)
replyPost.media_id = media_id
return encrypt_func(replyPost.render())
elif msg.type == "event":
logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
if msg.event in ["subscribe", "subscribe_scan"]:
reply_text = subscribe_msg()
replyPost = create_reply(reply_text, msg)
return encrypt_func(replyPost.render())
else:
return "success"
else:
logger.info("暂且不处理")
return "success"
except Exception as exc:
logger.exception(exc)
return exc

View File

@@ -1,45 +0,0 @@
# -*- coding: utf-8 -*-#
# filename: receive.py
import xml.etree.ElementTree as ET
from bridge.context import ContextType
from channel.chat_message import ChatMessage
from common.log import logger
def parse_xml(web_data):
if len(web_data) == 0:
return None
xmlData = ET.fromstring(web_data)
return WeChatMPMessage(xmlData)
class WeChatMPMessage(ChatMessage):
def __init__(self, xmlData):
super().__init__(xmlData)
self.to_user_id = xmlData.find('ToUserName').text
self.from_user_id = xmlData.find('FromUserName').text
self.create_time = xmlData.find('CreateTime').text
self.msg_type = xmlData.find('MsgType').text
try:
self.msg_id = xmlData.find('MsgId').text
except:
self.msg_id = self.from_user_id+self.create_time
self.is_group = False
# reply to other_user_id
self.other_user_id = self.from_user_id
if self.msg_type == 'text':
self.ctype = ContextType.TEXT
self.content = xmlData.find('Content').text.encode("utf-8")
elif self.msg_type == 'voice':
self.ctype = ContextType.TEXT
self.content = xmlData.find('Recognition').text.encode("utf-8") # 接收语音识别结果
elif self.msg_type == 'image':
# not implemented
self.pic_url = xmlData.find('PicUrl').text
self.media_id = xmlData.find('MediaId').text
elif self.msg_type == 'event':
self.content = xmlData.find('Event').text
else: # video, shortvideo, location, link
# not implemented
pass

View File

@@ -1,52 +0,0 @@
# -*- coding: utf-8 -*-#
# filename: reply.py
import time
class Msg(object):
def __init__(self):
pass
def send(self):
return "success"
class TextMsg(Msg):
def __init__(self, toUserName, fromUserName, content):
self.__dict = dict()
self.__dict['ToUserName'] = toUserName
self.__dict['FromUserName'] = fromUserName
self.__dict['CreateTime'] = int(time.time())
self.__dict['Content'] = content
def send(self):
XmlForm = """
<xml>
<ToUserName><![CDATA[{ToUserName}]]></ToUserName>
<FromUserName><![CDATA[{FromUserName}]]></FromUserName>
<CreateTime>{CreateTime}</CreateTime>
<MsgType><![CDATA[text]]></MsgType>
<Content><![CDATA[{Content}]]></Content>
</xml>
"""
return XmlForm.format(**self.__dict)
class ImageMsg(Msg):
def __init__(self, toUserName, fromUserName, mediaId):
self.__dict = dict()
self.__dict['ToUserName'] = toUserName
self.__dict['FromUserName'] = fromUserName
self.__dict['CreateTime'] = int(time.time())
self.__dict['MediaId'] = mediaId
def send(self):
XmlForm = """
<xml>
<ToUserName><![CDATA[{ToUserName}]]></ToUserName>
<FromUserName><![CDATA[{FromUserName}]]></FromUserName>
<CreateTime>{CreateTime}</CreateTime>
<MsgType><![CDATA[image]]></MsgType>
<Image>
<MediaId><![CDATA[{MediaId}]]></MediaId>
</Image>
</xml>
"""
return XmlForm.format(**self.__dict)

View File

@@ -1,17 +1,25 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import web import asyncio
import time import imghdr
import json import io
import requests import os
import threading import threading
from common.singleton import singleton import time
from common.log import logger
from common.expired_dict import ExpiredDict import requests
from config import conf import web
from bridge.reply import * from wechatpy.crypto import WeChatCrypto
from wechatpy.exceptions import WeChatClientException
from bridge.context import * from bridge.context import *
from bridge.reply import *
from channel.chat_channel import ChatChannel from channel.chat_channel import ChatChannel
from channel.wechatmp.common import * from channel.wechatmp.common import *
from channel.wechatmp.wechatmp_client import WechatMPClient
from common.log import logger
from common.singleton import singleton
from config import conf
from voice.audio_convert import any_to_mp3
# If using SSL, uncomment the following lines, and modify the certificate path. # If using SSL, uncomment the following lines, and modify the certificate path.
# from cheroot.server import HTTPServer # from cheroot.server import HTTPServer
@@ -20,110 +28,186 @@ from channel.wechatmp.common import *
# certificate='/ssl/cert.pem', # certificate='/ssl/cert.pem',
# private_key='/ssl/cert.key') # private_key='/ssl/cert.key')
@singleton @singleton
class WechatMPChannel(ChatChannel): class WechatMPChannel(ChatChannel):
def __init__(self, passive_reply = True): def __init__(self, passive_reply=True):
super().__init__() super().__init__()
self.passive_reply = passive_reply self.passive_reply = passive_reply
self.running = set() self.NOT_SUPPORT_REPLYTYPE = []
self.received_msgs = ExpiredDict(60*60*24) appid = conf().get("wechatmp_app_id")
secret = conf().get("wechatmp_app_secret")
token = conf().get("wechatmp_token")
aes_key = conf().get("wechatmp_aes_key")
self.client = WechatMPClient(appid, secret)
self.crypto = None
if aes_key:
self.crypto = WeChatCrypto(token, aes_key, appid)
if self.passive_reply: if self.passive_reply:
self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE] # Cache the reply to the user's first message
self.cache_dict = dict() self.cache_dict = dict()
self.query1 = dict() # Record whether the current message is being processed
self.query2 = dict() self.running = set()
self.query3 = dict() # Count the request from wechat official server by message_id
else: self.request_cnt = dict()
# TODO support image # The permanent media need to be deleted to avoid media number limit
self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE] self.delete_media_loop = asyncio.new_event_loop()
self.app_id = conf().get('wechatmp_app_id') t = threading.Thread(target=self.start_loop, args=(self.delete_media_loop,))
self.app_secret = conf().get('wechatmp_app_secret') t.setDaemon(True)
self.access_token = None t.start()
self.access_token_expires_time = 0
self.access_token_lock = threading.Lock()
self.get_access_token()
def startup(self): def startup(self):
if self.passive_reply: if self.passive_reply:
urls = ('/wx', 'channel.wechatmp.SubscribeAccount.Query') urls = ("/wx", "channel.wechatmp.passive_reply.Query")
else: else:
urls = ('/wx', 'channel.wechatmp.ServiceAccount.Query') urls = ("/wx", "channel.wechatmp.active_reply.Query")
app = web.application(urls, globals(), autoreload=False) app = web.application(urls, globals(), autoreload=False)
port = conf().get('wechatmp_port', 8080) port = conf().get("wechatmp_port", 8080)
web.httpserver.runsimple(app.wsgifunc(), ('0.0.0.0', port)) web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
def start_loop(self, loop):
asyncio.set_event_loop(loop)
loop.run_forever()
def wechatmp_request(self, method, url, **kwargs): async def delete_media(self, media_id):
r = requests.request(method=method, url=url, **kwargs) logger.debug("[wechatmp] permanent media {} will be deleted in 10s".format(media_id))
r.raise_for_status() await asyncio.sleep(10)
r.encoding = "utf-8" self.client.material.delete(media_id)
ret = r.json() logger.info("[wechatmp] permanent media {} has been deleted".format(media_id))
if "errcode" in ret and ret["errcode"] != 0:
raise WeChatAPIException("{}".format(ret))
return ret
def get_access_token(self):
# return the access_token
if self.access_token:
if self.access_token_expires_time - time.time() > 60:
return self.access_token
# Get new access_token
# Do not request access_token in parallel! Only the last obtained is valid.
if self.access_token_lock.acquire(blocking=False):
# Wait for other threads that have previously obtained access_token to complete the request
# This happens every 2 hours, so it doesn't affect the experience very much
time.sleep(1)
self.access_token = None
url="https://api.weixin.qq.com/cgi-bin/token"
params={
"grant_type": "client_credential",
"appid": self.app_id,
"secret": self.app_secret
}
data = self.wechatmp_request(method='get', url=url, params=params)
self.access_token = data['access_token']
self.access_token_expires_time = int(time.time()) + data['expires_in']
logger.info("[wechatmp] access_token: {}".format(self.access_token))
self.access_token_lock.release()
else:
# Wait for token update
while self.access_token_lock.locked():
time.sleep(0.1)
return self.access_token
def send(self, reply: Reply, context: Context): def send(self, reply: Reply, context: Context):
receiver = context["receiver"]
if self.passive_reply: if self.passive_reply:
receiver = context["receiver"] if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
self.cache_dict[receiver] = reply.content reply_text = reply.content
logger.info("[send] reply to {} saved to cache: {}".format(receiver, reply)) logger.info("[wechatmp] text cached, receiver {}\n{}".format(receiver, reply_text))
self.cache_dict[receiver] = ("text", reply_text)
elif reply.type == ReplyType.VOICE:
try:
voice_file_path = reply.content
with open(voice_file_path, "rb") as f:
# support: <2M, <60s, mp3/wma/wav/amr
response = self.client.material.add("voice", f)
logger.debug("[wechatmp] upload voice response: {}".format(response))
# 根据文件大小估计一个微信自动审核的时间,审核结束前返回将会导致语音无法播放,这个估计有待验证
f_size = os.fstat(f.fileno()).st_size
time.sleep(1.0 + 2 * f_size / 1024 / 1024)
# todo check media_id
except WeChatClientException as e:
logger.error("[wechatmp] upload voice failed: {}".format(e))
return
media_id = response["media_id"]
logger.info("[wechatmp] voice uploaded, receiver {}, media_id {}".format(receiver, media_id))
self.cache_dict[receiver] = ("voice", media_id)
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
img_url = reply.content
pic_res = requests.get(img_url, stream=True)
image_storage = io.BytesIO()
for block in pic_res.iter_content(1024):
image_storage.write(block)
image_storage.seek(0)
image_type = imghdr.what(image_storage)
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
content_type = "image/" + image_type
try:
response = self.client.material.add("image", (filename, image_storage, content_type))
logger.debug("[wechatmp] upload image response: {}".format(response))
except WeChatClientException as e:
logger.error("[wechatmp] upload image failed: {}".format(e))
return
media_id = response["media_id"]
logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id))
self.cache_dict[receiver] = ("image", media_id)
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
image_storage = reply.content
image_storage.seek(0)
image_type = imghdr.what(image_storage)
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
content_type = "image/" + image_type
try:
response = self.client.material.add("image", (filename, image_storage, content_type))
logger.debug("[wechatmp] upload image response: {}".format(response))
except WeChatClientException as e:
logger.error("[wechatmp] upload image failed: {}".format(e))
return
media_id = response["media_id"]
logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id))
self.cache_dict[receiver] = ("image", media_id)
else: else:
receiver = context["receiver"] if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
reply_text = reply.content reply_text = reply.content
url="https://api.weixin.qq.com/cgi-bin/message/custom/send" texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN)
params = { if len(texts) > 1:
"access_token": self.get_access_token() logger.info("[wechatmp] text too long, split into {} parts".format(len(texts)))
} for text in texts:
json_data = { self.client.message.send_text(receiver, text)
"touser": receiver, logger.info("[wechatmp] Do send text to {}: {}".format(receiver, reply_text))
"msgtype": "text", elif reply.type == ReplyType.VOICE:
"text": {"content": reply_text} try:
} file_path = reply.content
self.wechatmp_request(method='post', url=url, params=params, data=json.dumps(json_data, ensure_ascii=False).encode('utf8')) file_name = os.path.basename(file_path)
logger.info("[send] Do send to {}: {}".format(receiver, reply_text)) file_type = os.path.splitext(file_name)[1]
if file_type == ".mp3":
file_type = "audio/mpeg"
elif file_type == ".amr":
file_type = "audio/amr"
else:
mp3_file = os.path.splitext(file_path)[0] + ".mp3"
any_to_mp3(file_path, mp3_file)
file_path = mp3_file
file_name = os.path.basename(file_path)
file_type = "audio/mpeg"
logger.info("[wechatmp] file_name: {}, file_type: {} ".format(file_name, file_type))
# support: <2M, <60s, AMR\MP3
response = self.client.media.upload("voice", (file_name, open(file_path, "rb"), file_type))
logger.debug("[wechatmp] upload voice response: {}".format(response))
except WeChatClientException as e:
logger.error("[wechatmp] upload voice failed: {}".format(e))
return
self.client.message.send_voice(receiver, response["media_id"])
logger.info("[wechatmp] Do send voice to {}".format(receiver))
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
img_url = reply.content
pic_res = requests.get(img_url, stream=True)
image_storage = io.BytesIO()
for block in pic_res.iter_content(1024):
image_storage.write(block)
image_storage.seek(0)
image_type = imghdr.what(image_storage)
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
content_type = "image/" + image_type
try:
response = self.client.media.upload("image", (filename, image_storage, content_type))
logger.debug("[wechatmp] upload image response: {}".format(response))
except WeChatClientException as e:
logger.error("[wechatmp] upload image failed: {}".format(e))
return
self.client.message.send_image(receiver, response["media_id"])
logger.info("[wechatmp] Do send image to {}".format(receiver))
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
image_storage = reply.content
image_storage.seek(0)
image_type = imghdr.what(image_storage)
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
content_type = "image/" + image_type
try:
response = self.client.media.upload("image", (filename, image_storage, content_type))
logger.debug("[wechatmp] upload image response: {}".format(response))
except WeChatClientException as e:
logger.error("[wechatmp] upload image failed: {}".format(e))
return
self.client.message.send_image(receiver, response["media_id"])
logger.info("[wechatmp] Do send image to {}".format(receiver))
return return
def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数
def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数 logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context["msg"].msg_id))
logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context['msg'].msg_id))
if self.passive_reply: if self.passive_reply:
self.running.remove(session_id) self.running.remove(session_id)
def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数
def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数 logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context["msg"].msg_id, exception))
logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context['msg'].msg_id, exception))
if self.passive_reply: if self.passive_reply:
assert session_id not in self.cache_dict assert session_id not in self.cache_dict
self.running.remove(session_id) self.running.remove(session_id)

View File

@@ -0,0 +1,40 @@
import threading
import time
from wechatpy.client import WeChatClient
from wechatpy.exceptions import APILimitedException
from channel.wechatmp.common import *
from common.log import logger
class WechatMPClient(WeChatClient):
def __init__(self, appid, secret, access_token=None, session=None, timeout=None, auto_retry=True):
super(WechatMPClient, self).__init__(appid, secret, access_token, session, timeout, auto_retry)
self.fetch_access_token_lock = threading.Lock()
def clear_quota(self):
return self.post("clear_quota", data={"appid": self.appid})
def clear_quota_v2(self):
return self.post("clear_quota/v2", params={"appid": self.appid, "appsecret": self.secret})
def fetch_access_token(self): # 重载父类方法加锁避免多线程重复获取access_token
with self.fetch_access_token_lock:
access_token = self.session.get(self.access_token_key)
if access_token:
if not self.expires_at:
return access_token
timestamp = time.time()
if self.expires_at - timestamp > 60:
return access_token
return super().fetch_access_token()
def _request(self, method, url_or_endpoint, **kwargs): # 重载父类方法遇到API限流时清除quota后重试
try:
return super()._request(method, url_or_endpoint, **kwargs)
except APILimitedException as e:
logger.error("[wechatmp] API quata has been used up. {}".format(e))
response = self.clear_quota_v2()
logger.debug("[wechatmp] API quata has been cleard, {}".format(response))
return super()._request(method, url_or_endpoint, **kwargs)

View File

@@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-#
from bridge.context import ContextType
from channel.chat_message import ChatMessage
from common.log import logger
from common.tmp_dir import TmpDir
class WeChatMPMessage(ChatMessage):
def __init__(self, msg, client=None):
super().__init__(msg)
self.msg_id = msg.id
self.create_time = msg.time
self.is_group = False
if msg.type == "text":
self.ctype = ContextType.TEXT
self.content = msg.content
elif msg.type == "voice":
if msg.recognition == None:
self.ctype = ContextType.VOICE
self.content = TmpDir().path() + msg.media_id + "." + msg.format # content直接存临时目录路径
def download_voice():
# 如果响应状态码是200则将响应内容写入本地文件
response = client.media.download(msg.media_id)
if response.status_code == 200:
with open(self.content, "wb") as f:
f.write(response.content)
else:
logger.info(f"[wechatmp] Failed to download voice file, {response.content}")
self._prepare_fn = download_voice
else:
self.ctype = ContextType.TEXT
self.content = msg.recognition
elif msg.type == "image":
self.ctype = ContextType.IMAGE
self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径
def download_image():
# 如果响应状态码是200则将响应内容写入本地文件
response = client.media.download(msg.media_id)
if response.status_code == 200:
with open(self.content, "wb") as f:
f.write(response.content)
else:
logger.info(f"[wechatmp] Failed to download image file, {response.content}")
self._prepare_fn = download_image
else:
raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type))
self.from_user_id = msg.source
self.to_user_id = msg.target
self.other_user_id = msg.source

View File

@@ -2,4 +2,4 @@
OPEN_AI = "openAI" OPEN_AI = "openAI"
CHATGPT = "chatGPT" CHATGPT = "chatGPT"
BAIDU = "baidu" BAIDU = "baidu"
CHATGPTONAZURE = "chatGPTOnAzure" CHATGPTONAZURE = "chatGPTOnAzure"

View File

@@ -1,7 +1,7 @@
from queue import Full, Queue from queue import Full, Queue
from time import monotonic as time from time import monotonic as time
# add implementation of putleft to Queue # add implementation of putleft to Queue
class Dequeue(Queue): class Dequeue(Queue):
def putleft(self, item, block=True, timeout=None): def putleft(self, item, block=True, timeout=None):
@@ -30,4 +30,4 @@ class Dequeue(Queue):
return self.putleft(item, block=False) return self.putleft(item, block=False)
def _putleft(self, item): def _putleft(self, item):
self.queue.appendleft(item) self.queue.appendleft(item)

View File

@@ -39,4 +39,4 @@ class ExpiredDict(dict):
return [(key, self[key]) for key in self.keys()] return [(key, self[key]) for key in self.keys()]
def __iter__(self): def __iter__(self):
return self.keys().__iter__() return self.keys().__iter__()

View File

@@ -10,20 +10,29 @@ def _reset_logger(log):
log.handlers.clear() log.handlers.clear()
log.propagate = False log.propagate = False
console_handle = logging.StreamHandler(sys.stdout) console_handle = logging.StreamHandler(sys.stdout)
console_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s', console_handle.setFormatter(
datefmt='%Y-%m-%d %H:%M:%S')) logging.Formatter(
file_handle = logging.FileHandler('run.log', encoding='utf-8') "[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s",
file_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s', datefmt="%Y-%m-%d %H:%M:%S",
datefmt='%Y-%m-%d %H:%M:%S')) )
)
file_handle = logging.FileHandler("run.log", encoding="utf-8")
file_handle.setFormatter(
logging.Formatter(
"[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
)
log.addHandler(file_handle) log.addHandler(file_handle)
log.addHandler(console_handle) log.addHandler(console_handle)
def _get_logger(): def _get_logger():
log = logging.getLogger('log') log = logging.getLogger("log")
_reset_logger(log) _reset_logger(log)
log.setLevel(logging.INFO) log.setLevel(logging.INFO)
return log return log
# 日志句柄 # 日志句柄
logger = _get_logger() logger = _get_logger()

View File

@@ -1,15 +1,20 @@
import time import time
import pip import pip
from pip._internal import main as pipmain from pip._internal import main as pipmain
from common.log import logger,_reset_logger
from common.log import _reset_logger, logger
def install(package): def install(package):
pipmain(['install', package]) pipmain(["install", package])
def install_requirements(file): def install_requirements(file):
pipmain(['install', '-r', file, "--upgrade"]) pipmain(["install", "-r", file, "--upgrade"])
_reset_logger(logger) _reset_logger(logger)
def check_dulwich(): def check_dulwich():
needwait = False needwait = False
for i in range(2): for i in range(2):
@@ -18,13 +23,14 @@ def check_dulwich():
needwait = False needwait = False
try: try:
import dulwich import dulwich
return return
except ImportError: except ImportError:
try: try:
install('dulwich') install("dulwich")
except: except:
needwait = True needwait = True
try: try:
import dulwich import dulwich
except ImportError: except ImportError:
raise ImportError("Unable to import dulwich") raise ImportError("Unable to import dulwich")

View File

@@ -62,4 +62,4 @@ class SortedDict(dict):
return iter(self.keys()) return iter(self.keys())
def __repr__(self): def __repr__(self):
return f'{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})' return f"{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})"

View File

@@ -1,7 +1,11 @@
import time,re,hashlib import hashlib
import re
import time
import config import config
from common.log import logger from common.log import logger
def time_checker(f): def time_checker(f):
def _time_checker(self, *args, **kwargs): def _time_checker(self, *args, **kwargs):
_config = config.conf() _config = config.conf()
@@ -9,17 +13,17 @@ def time_checker(f):
if chat_time_module: if chat_time_module:
chat_start_time = _config.get("chat_start_time", "00:00") chat_start_time = _config.get("chat_start_time", "00:00")
chat_stopt_time = _config.get("chat_stop_time", "24:00") chat_stopt_time = _config.get("chat_stop_time", "24:00")
time_regex = re.compile(r'^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$') #时间匹配包含24:00 time_regex = re.compile(r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$") # 时间匹配包含24:00
starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式 starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式
stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式 stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式
chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间 chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间
# 时间格式检查 # 时间格式检查
if not (starttime_format_check and stoptime_format_check and chat_time_check): if not (starttime_format_check and stoptime_format_check and chat_time_check):
logger.warn('时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})'.format(starttime_format_check,stoptime_format_check)) logger.warn("时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(starttime_format_check, stoptime_format_check))
if chat_start_time>"23:59": if chat_start_time > "23:59":
logger.error('启动时间可能存在问题,请修改!') logger.error("启动时间可能存在问题,请修改!")
# 服务时间检查 # 服务时间检查
now_time = time.strftime("%H:%M", time.localtime()) now_time = time.strftime("%H:%M", time.localtime())
@@ -27,12 +31,12 @@ def time_checker(f):
f(self, *args, **kwargs) f(self, *args, **kwargs)
return None return None
else: else:
if args[0]['Content'] == "#更新配置": # 不在服务时间内也可以更新配置 if args[0]["Content"] == "#更新配置": # 不在服务时间内也可以更新配置
f(self, *args, **kwargs) f(self, *args, **kwargs)
else: else:
logger.info('非服务时间内,不接受访问') logger.info("非服务时间内,不接受访问")
return None return None
else: else:
f(self, *args, **kwargs) # 未开启时间模块则直接回答 f(self, *args, **kwargs) # 未开启时间模块则直接回答
return _time_checker
return _time_checker

View File

@@ -1,20 +1,18 @@
import os import os
import pathlib import pathlib
from config import conf from config import conf
class TmpDir(object): class TmpDir(object):
"""A temporary directory that is deleted when the object is destroyed. """A temporary directory that is deleted when the object is destroyed."""
"""
tmpFilePath = pathlib.Path("./tmp/")
tmpFilePath = pathlib.Path('./tmp/')
def __init__(self): def __init__(self):
pathExists = os.path.exists(self.tmpFilePath) pathExists = os.path.exists(self.tmpFilePath)
if not pathExists: if not pathExists:
os.makedirs(self.tmpFilePath) os.makedirs(self.tmpFilePath)
def path(self): def path(self):
return str(self.tmpFilePath) + '/' return str(self.tmpFilePath) + "/"

View File

@@ -2,16 +2,30 @@
"open_ai_api_key": "YOUR API KEY", "open_ai_api_key": "YOUR API KEY",
"model": "gpt-3.5-turbo", "model": "gpt-3.5-turbo",
"proxy": "", "proxy": "",
"single_chat_prefix": ["bot", "@bot"], "single_chat_prefix": [
"bot",
"@bot"
],
"single_chat_reply_prefix": "[bot] ", "single_chat_reply_prefix": "[bot] ",
"group_chat_prefix": ["@bot"], "group_chat_prefix": [
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], "@bot"
"group_chat_in_one_session": ["ChatGPT测试群"], ],
"image_create_prefix": ["画", "看", "找"], "group_name_white_list": [
"ChatGPT测试群",
"ChatGPT测试群2"
],
"group_chat_in_one_session": [
"ChatGPT测试群"
],
"image_create_prefix": [
"画",
"看",
"找"
],
"speech_recognition": false, "speech_recognition": false,
"group_speech_recognition": false, "group_speech_recognition": false,
"voice_reply_voice": false, "voice_reply_voice": false,
"conversation_max_tokens": 1000, "conversation_max_tokens": 1000,
"expires_in_seconds": 3600, "expires_in_seconds": 3600,
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。" "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。"
} }

View File

@@ -3,9 +3,10 @@
import json import json
import logging import logging
import os import os
from common.log import logger
import pickle import pickle
from common.log import logger
# 将所有可用的配置项写在字典里, 请使用小写字母 # 将所有可用的配置项写在字典里, 请使用小写字母
available_setting = { available_setting = {
# openai api配置 # openai api配置
@@ -16,8 +17,7 @@ available_setting = {
# chatgpt模型 当use_azure_chatgpt为true时其名称为Azure上model deployment名称 # chatgpt模型 当use_azure_chatgpt为true时其名称为Azure上model deployment名称
"model": "gpt-3.5-turbo", "model": "gpt-3.5-turbo",
"use_azure_chatgpt": False, # 是否使用azure的chatgpt "use_azure_chatgpt": False, # 是否使用azure的chatgpt
"azure_deployment_id": "", #azure 模型部署名称 "azure_deployment_id": "", # azure 模型部署名称
# Bot触发配置 # Bot触发配置
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 "single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 "single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
@@ -30,25 +30,22 @@ available_setting = {
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称 "group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
"trigger_by_self": False, # 是否允许机器人触发 "trigger_by_self": False, # 是否允许机器人触发
"image_create_prefix": ["", "", ""], # 开启图片回复的前缀 "image_create_prefix": ["", "", ""], # 开启图片回复的前缀
"concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中大于1可能乱序 "concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中大于1可能乱序
"image_create_size": "256x256", # 图片大小,可选有 256x256, 512x512, 1024x1024
# chatgpt会话参数 # chatgpt会话参数
"expires_in_seconds": 3600, # 无操作会话的过期时间 "expires_in_seconds": 3600, # 无操作会话的过期时间
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述 "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数 "conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
# chatgpt限流配置 # chatgpt限流配置
"rate_limit_chatgpt": 20, # chatgpt的调用频率限制 "rate_limit_chatgpt": 20, # chatgpt的调用频率限制
"rate_limit_dalle": 50, # openai dalle的调用频率限制 "rate_limit_dalle": 50, # openai dalle的调用频率限制
# chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create # chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create
"temperature": 0.9, "temperature": 0.9,
"top_p": 1, "top_p": 1,
"frequency_penalty": 0, "frequency_penalty": 0,
"presence_penalty": 0, "presence_penalty": 0,
"request_timeout": 60, # chatgpt请求超时时间openai接口默认设置为600对于难问题一般需要较长时间 "request_timeout": 60, # chatgpt请求超时时间openai接口默认设置为600对于难问题一般需要较长时间
"timeout": 120, # chatgpt重试超时时间在这个时间内将会自动重试 "timeout": 120, # chatgpt重试超时时间在这个时间内将会自动重试
# 语音设置 # 语音设置
"speech_recognition": False, # 是否开启语音识别 "speech_recognition": False, # 是否开启语音识别
"group_speech_recognition": False, # 是否开启群组语音识别 "group_speech_recognition": False, # 是否开启群组语音识别
@@ -56,50 +53,42 @@ available_setting = {
"always_reply_voice": False, # 是否一直使用语音回复 "always_reply_voice": False, # 是否一直使用语音回复
"voice_to_text": "openai", # 语音识别引擎支持openai,baidu,google,azure "voice_to_text": "openai", # 语音识别引擎支持openai,baidu,google,azure
"text_to_voice": "baidu", # 语音合成引擎支持baidu,google,pytts(offline),azure "text_to_voice": "baidu", # 语音合成引擎支持baidu,google,pytts(offline),azure
# baidu 语音api配置 使用百度语音识别和语音合成时需要 # baidu 语音api配置 使用百度语音识别和语音合成时需要
"baidu_app_id": "", "baidu_app_id": "",
"baidu_api_key": "", "baidu_api_key": "",
"baidu_secret_key": "", "baidu_secret_key": "",
# 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场 # 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场
"baidu_dev_pid": "1536", "baidu_dev_pid": "1536",
# azure 语音api配置 使用azure语音识别和语音合成时需要 # azure 语音api配置 使用azure语音识别和语音合成时需要
"azure_voice_api_key": "", "azure_voice_api_key": "",
"azure_voice_region": "japaneast", "azure_voice_region": "japaneast",
# 服务时间限制目前支持itchat # 服务时间限制目前支持itchat
"chat_time_module": False, # 是否开启服务时间限制 "chat_time_module": False, # 是否开启服务时间限制
"chat_start_time": "00:00", # 服务开始时间 "chat_start_time": "00:00", # 服务开始时间
"chat_stop_time": "24:00", # 服务结束时间 "chat_stop_time": "24:00", # 服务结束时间
# itchat的配置 # itchat的配置
"hot_reload": False, # 是否开启热重载 "hot_reload": False, # 是否开启热重载
# wechaty的配置 # wechaty的配置
"wechaty_puppet_service_token": "", # wechaty的token "wechaty_puppet_service_token": "", # wechaty的token
# wechatmp的配置 # wechatmp的配置
"wechatmp_token": "", # 微信公众平台的Token "wechatmp_token": "", # 微信公众平台的Token
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443 "wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
"wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要 "wechatmp_app_id": "", # 微信公众平台的appID
"wechatmp_app_secret": "", # 微信公众平台的appsecret,仅服务号需要 "wechatmp_app_secret": "", # 微信公众平台的appsecret
"wechatmp_aes_key": "", # 微信公众平台的EncodingAESKey加密模式需要
# chatgpt指令自定义触发词 # chatgpt指令自定义触发词
"clear_memory_commands": ['#清除记忆'], # 重置会话指令,必须以#开头 "clear_memory_commands": ["#清除记忆"], # 重置会话指令,必须以#开头
# channel配置 # channel配置
"channel_type": "wx", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service} "channel_type": "wx", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service}
"debug": False, # 是否开启debug模式开启后会打印更多日志 "debug": False, # 是否开启debug模式开启后会打印更多日志
"appdata_dir": "", # 数据目录
# 插件配置 # 插件配置
"plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突 "plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突
} }
class Config(dict): class Config(dict):
def __init__(self, d:dict={}): def __init__(self, d: dict = {}):
super().__init__(d) super().__init__(d)
# user_datas: 用户数据key为用户名value为用户数据也是dict # user_datas: 用户数据key为用户名value为用户数据也是dict
self.user_datas = {} self.user_datas = {}
@@ -130,7 +119,7 @@ class Config(dict):
def load_user_datas(self): def load_user_datas(self):
try: try:
with open('user_datas.pkl', 'rb') as f: with open(os.path.join(get_appdata_dir(), "user_datas.pkl"), "rb") as f:
self.user_datas = pickle.load(f) self.user_datas = pickle.load(f)
logger.info("[Config] User datas loaded.") logger.info("[Config] User datas loaded.")
except FileNotFoundError as e: except FileNotFoundError as e:
@@ -141,12 +130,13 @@ class Config(dict):
def save_user_datas(self): def save_user_datas(self):
try: try:
with open('user_datas.pkl', 'wb') as f: with open(os.path.join(get_appdata_dir(), "user_datas.pkl"), "wb") as f:
pickle.dump(self.user_datas, f) pickle.dump(self.user_datas, f)
logger.info("[Config] User datas saved.") logger.info("[Config] User datas saved.")
except Exception as e: except Exception as e:
logger.info("[Config] User datas error: {}".format(e)) logger.info("[Config] User datas error: {}".format(e))
config = Config() config = Config()
@@ -154,7 +144,7 @@ def load_config():
global config global config
config_path = "./config.json" config_path = "./config.json"
if not os.path.exists(config_path): if not os.path.exists(config_path):
logger.info('配置文件不存在将使用config-template.json模板') logger.info("配置文件不存在将使用config-template.json模板")
config_path = "./config-template.json" config_path = "./config-template.json"
config_str = read_file(config_path) config_str = read_file(config_path)
@@ -168,8 +158,7 @@ def load_config():
for name, value in os.environ.items(): for name, value in os.environ.items():
name = name.lower() name = name.lower()
if name in available_setting: if name in available_setting:
logger.info( logger.info("[INIT] override config by environ args: {}={}".format(name, value))
"[INIT] override config by environ args: {}={}".format(name, value))
try: try:
config[name] = eval(value) config[name] = eval(value)
except: except:
@@ -182,20 +171,29 @@ def load_config():
if config.get("debug", False): if config.get("debug", False):
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
logger.debug("[INIT] set log level to DEBUG") logger.debug("[INIT] set log level to DEBUG")
logger.info("[INIT] load config: {}".format(config)) logger.info("[INIT] load config: {}".format(config))
config.load_user_datas() config.load_user_datas()
def get_root(): def get_root():
return os.path.dirname(os.path.abspath(__file__)) return os.path.dirname(os.path.abspath(__file__))
def read_file(path): def read_file(path):
with open(path, mode='r', encoding='utf-8') as f: with open(path, mode="r", encoding="utf-8") as f:
return f.read() return f.read()
def conf(): def conf():
return config return config
def get_appdata_dir():
data_path = os.path.join(get_root(), conf().get("appdata_dir", ""))
if not os.path.exists(data_path):
logger.info("[INIT] data path not exists, create it: {}".format(data_path))
os.makedirs(data_path)
return data_path

View File

@@ -33,7 +33,7 @@ ADD ./entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh \ RUN chmod +x /entrypoint.sh \
&& groupadd -r noroot \ && groupadd -r noroot \
&& useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \ && useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \
&& chown -R noroot:noroot ${BUILD_PREFIX} && chown -R noroot:noroot ${BUILD_PREFIX}
USER noroot USER noroot

View File

@@ -18,7 +18,7 @@ RUN apt-get update \
&& pip install --no-cache -r requirements.txt \ && pip install --no-cache -r requirements.txt \
&& pip install --no-cache -r requirements-optional.txt \ && pip install --no-cache -r requirements-optional.txt \
&& pip install azure-cognitiveservices-speech && pip install azure-cognitiveservices-speech
WORKDIR ${BUILD_PREFIX} WORKDIR ${BUILD_PREFIX}
ADD docker/entrypoint.sh /entrypoint.sh ADD docker/entrypoint.sh /entrypoint.sh

View File

@@ -11,6 +11,5 @@ docker build -f Dockerfile.alpine \
-t zhayujie/chatgpt-on-wechat . -t zhayujie/chatgpt-on-wechat .
# tag image # tag image
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:alpine docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:alpine
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-alpine docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-alpine

View File

@@ -11,5 +11,5 @@ docker build -f Dockerfile.debian \
-t zhayujie/chatgpt-on-wechat . -t zhayujie/chatgpt-on-wechat .
# tag image # tag image
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:debian docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:debian
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-debian docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-debian

View File

@@ -1,4 +1,8 @@
#!/bin/bash #!/bin/bash
unset KUBECONFIG
cd .. && docker build -f docker/Dockerfile.latest \ cd .. && docker build -f docker/Dockerfile.latest \
-t zhayujie/chatgpt-on-wechat . -t zhayujie/chatgpt-on-wechat .
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$(date +%y%m%d)

View File

@@ -9,7 +9,7 @@ RUN apk add --no-cache \
ffmpeg \ ffmpeg \
espeak \ espeak \
&& pip install --no-cache \ && pip install --no-cache \
baidu-aip \ baidu-aip \
chardet \ chardet \
SpeechRecognition SpeechRecognition

View File

@@ -10,7 +10,7 @@ RUN apt-get update \
ffmpeg \ ffmpeg \
espeak \ espeak \
&& pip install --no-cache \ && pip install --no-cache \
baidu-aip \ baidu-aip \
chardet \ chardet \
SpeechRecognition SpeechRecognition

View File

@@ -11,13 +11,13 @@ run_d:
docker rm $(CONTAINER_NAME) || echo docker rm $(CONTAINER_NAME) || echo
docker run -dt --name $(CONTAINER_NAME) $(PORT_MAP) \ docker run -dt --name $(CONTAINER_NAME) $(PORT_MAP) \
--env-file=$(DOTENV) \ --env-file=$(DOTENV) \
$(MOUNT) $(IMG) $(MOUNT) $(IMG)
run_i: run_i:
docker rm $(CONTAINER_NAME) || echo docker rm $(CONTAINER_NAME) || echo
docker run -it --name $(CONTAINER_NAME) $(PORT_MAP) \ docker run -it --name $(CONTAINER_NAME) $(PORT_MAP) \
--env-file=$(DOTENV) \ --env-file=$(DOTENV) \
$(MOUNT) $(IMG) $(MOUNT) $(IMG)
stop: stop:
docker stop $(CONTAINER_NAME) docker stop $(CONTAINER_NAME)

View File

@@ -24,17 +24,17 @@
在本仓库中预置了一些插件,如果要安装其他仓库的插件,有两种方法。 在本仓库中预置了一些插件,如果要安装其他仓库的插件,有两种方法。
- 第一种方法是在将下载的插件文件都解压到"plugins"文件夹的一个单独的文件夹,最终插件的代码都位于"plugins/PLUGIN_NAME/*"中。启动程序后,如果插件的目录结构正确,插件会自动被扫描加载。除此以外,注意你还需要安装文件夹中`requirements.txt`中的依赖。 - 第一种方法是在将下载的插件文件都解压到"plugins"文件夹的一个单独的文件夹,最终插件的代码都位于"plugins/PLUGIN_NAME/*"中。启动程序后,如果插件的目录结构正确,插件会自动被扫描加载。除此以外,注意你还需要安装文件夹中`requirements.txt`中的依赖。
- 第二种方法是`Godcmd`插件,它是预置的管理员插件,能够让程序在运行时就能安装插件,它能够自动安装依赖。 - 第二种方法是`Godcmd`插件,它是预置的管理员插件,能够让程序在运行时就能安装插件,它能够自动安装依赖。
安装插件的命令是"#installp [仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件名/仓库地址"。这是管理员命令,认证方法在[这里](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/godcmd)。 安装插件的命令是"#installp [仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件名/仓库地址"。这是管理员命令,认证方法在[这里](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/godcmd)。
- 安装[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件:#installp sdwebui - 安装[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件:#installp sdwebui
- 安装指定仓库的插件:#installp https://github.com/lanvent/plugin_sdwebui.git - 安装指定仓库的插件:#installp https://github.com/lanvent/plugin_sdwebui.git
在安装之后,需要执行"#scanp"命令来扫描加载新安装的插件(或者重新启动程序)。 在安装之后,需要执行"#scanp"命令来扫描加载新安装的插件(或者重新启动程序)。
安装插件后需要注意有些插件有自己的配置模板,一般要去掉".template"新建一个配置文件。 安装插件后需要注意有些插件有自己的配置模板,一般要去掉".template"新建一个配置文件。
## 插件化实现 ## 插件化实现
@@ -107,14 +107,14 @@
``` ```
回复`Reply`的定义如下所示它允许Bot可以回复多类不同的消息。同时也加入了`INFO``ERROR`消息类型区分系统提示和系统错误。 回复`Reply`的定义如下所示它允许Bot可以回复多类不同的消息。同时也加入了`INFO``ERROR`消息类型区分系统提示和系统错误。
```python ```python
class ReplyType(Enum): class ReplyType(Enum):
TEXT = 1 # 文本 TEXT = 1 # 文本
VOICE = 2 # 音频文件 VOICE = 2 # 音频文件
IMAGE = 3 # 图片文件 IMAGE = 3 # 图片文件
IMAGE_URL = 4 # 图片URL IMAGE_URL = 4 # 图片URL
INFO = 9 INFO = 9
ERROR = 10 ERROR = 10
class Reply: class Reply:
@@ -159,12 +159,12 @@
目前支持三类触发事件: 目前支持三类触发事件:
``` ```
1.收到消息 1.收到消息
---> `ON_HANDLE_CONTEXT` ---> `ON_HANDLE_CONTEXT`
2.产生回复 2.产生回复
---> `ON_DECORATE_REPLY` ---> `ON_DECORATE_REPLY`
3.装饰回复 3.装饰回复
---> `ON_SEND_REPLY` ---> `ON_SEND_REPLY`
4.发送回复 4.发送回复
``` ```
@@ -268,6 +268,6 @@ class Hello(Plugin):
- 一个插件目录建议只注册一个插件类。建议使用单独的仓库维护插件,便于更新。 - 一个插件目录建议只注册一个插件类。建议使用单独的仓库维护插件,便于更新。
在测试调试好后提交`PR`,把自己的仓库加入到[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)中。 在测试调试好后提交`PR`,把自己的仓库加入到[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)中。
- 插件的config文件、使用说明`README.md`、`requirement.txt`等放置在插件目录中。 - 插件的config文件、使用说明`README.md`、`requirement.txt`等放置在插件目录中。
- 默认优先级不要超过管理员插件`Godcmd`的优先级(999)`Godcmd`插件提供了配置管理、插件管理等功能。 - 默认优先级不要超过管理员插件`Godcmd`的优先级(999)`Godcmd`插件提供了配置管理、插件管理等功能。

View File

@@ -1,9 +1,9 @@
from .plugin_manager import PluginManager
from .event import * from .event import *
from .plugin import * from .plugin import *
from .plugin_manager import PluginManager
instance = PluginManager() instance = PluginManager()
register = instance.register register = instance.register
# load_plugins = instance.load_plugins # load_plugins = instance.load_plugins
# emit_event = instance.emit_event # emit_event = instance.emit_event

View File

@@ -1 +1 @@
from .banwords import * from .banwords import *

View File

@@ -2,56 +2,65 @@
import json import json
import os import os
import plugins
from bridge.context import ContextType from bridge.context import ContextType
from bridge.reply import Reply, ReplyType from bridge.reply import Reply, ReplyType
import plugins
from plugins import *
from common.log import logger from common.log import logger
from .WordsSearch import WordsSearch from plugins import *
from .lib.WordsSearch import WordsSearch
@plugins.register(name="Banwords", desire_priority=100, hidden=True, desc="判断消息中是否有敏感词、决定是否回复。", version="1.0", author="lanvent") @plugins.register(
name="Banwords",
desire_priority=100,
hidden=True,
desc="判断消息中是否有敏感词、决定是否回复。",
version="1.0",
author="lanvent",
)
class Banwords(Plugin): class Banwords(Plugin):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
try: try:
curdir=os.path.dirname(__file__) curdir = os.path.dirname(__file__)
config_path=os.path.join(curdir,"config.json") config_path = os.path.join(curdir, "config.json")
conf=None conf = None
if not os.path.exists(config_path): if not os.path.exists(config_path):
conf={"action":"ignore"} conf = {"action": "ignore"}
with open(config_path,"w") as f: with open(config_path, "w") as f:
json.dump(conf,f,indent=4) json.dump(conf, f, indent=4)
else: else:
with open(config_path,"r") as f: with open(config_path, "r") as f:
conf=json.load(f) conf = json.load(f)
self.searchr = WordsSearch() self.searchr = WordsSearch()
self.action = conf["action"] self.action = conf["action"]
banwords_path = os.path.join(curdir,"banwords.txt") banwords_path = os.path.join(curdir, "banwords.txt")
with open(banwords_path, 'r', encoding='utf-8') as f: with open(banwords_path, "r", encoding="utf-8") as f:
words=[] words = []
for line in f: for line in f:
word = line.strip() word = line.strip()
if word: if word:
words.append(word) words.append(word)
self.searchr.SetKeywords(words) self.searchr.SetKeywords(words)
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
if conf.get("reply_filter",True): if conf.get("reply_filter", True):
self.handlers[Event.ON_DECORATE_REPLY] = self.on_decorate_reply self.handlers[Event.ON_DECORATE_REPLY] = self.on_decorate_reply
self.reply_action = conf.get("reply_action","ignore") self.reply_action = conf.get("reply_action", "ignore")
logger.info("[Banwords] inited") logger.info("[Banwords] inited")
except Exception as e: except Exception as e:
logger.warn("[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords .") logger.warn("[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords .")
raise e raise e
def on_handle_context(self, e_context: EventContext): def on_handle_context(self, e_context: EventContext):
if e_context["context"].type not in [
if e_context['context'].type not in [ContextType.TEXT,ContextType.IMAGE_CREATE]: ContextType.TEXT,
ContextType.IMAGE_CREATE,
]:
return return
content = e_context['context'].content content = e_context["context"].content
logger.debug("[Banwords] on_handle_context. content: %s" % content) logger.debug("[Banwords] on_handle_context. content: %s" % content)
if self.action == "ignore": if self.action == "ignore":
f = self.searchr.FindFirst(content) f = self.searchr.FindFirst(content)
@@ -61,31 +70,30 @@ class Banwords(Plugin):
return return
elif self.action == "replace": elif self.action == "replace":
if self.searchr.ContainsAny(content): if self.searchr.ContainsAny(content):
reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n"+self.searchr.Replace(content)) reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content))
e_context['reply'] = reply e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
return return
def on_decorate_reply(self, e_context: EventContext):
if e_context['reply'].type not in [ReplyType.TEXT]: def on_decorate_reply(self, e_context: EventContext):
if e_context["reply"].type not in [ReplyType.TEXT]:
return return
reply = e_context['reply'] reply = e_context["reply"]
content = reply.content content = reply.content
if self.reply_action == "ignore": if self.reply_action == "ignore":
f = self.searchr.FindFirst(content) f = self.searchr.FindFirst(content)
if f: if f:
logger.info("[Banwords] %s in reply" % f["Keyword"]) logger.info("[Banwords] %s in reply" % f["Keyword"])
e_context['reply'] = None e_context["reply"] = None
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
return return
elif self.reply_action == "replace": elif self.reply_action == "replace":
if self.searchr.ContainsAny(content): if self.searchr.ContainsAny(content):
reply = Reply(ReplyType.INFO, "已替换回复中的敏感词: \n"+self.searchr.Replace(content)) reply = Reply(ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content))
e_context['reply'] = reply e_context["reply"] = reply
e_context.action = EventAction.CONTINUE e_context.action = EventAction.CONTINUE
return return
def get_help_text(self, **kwargs): def get_help_text(self, **kwargs):
return Banwords.desc return "过滤消息中的敏感词。"

View File

@@ -1,5 +1,5 @@
{ {
"action": "replace", "action": "replace",
"reply_filter": true, "reply_filter": true,
"reply_action": "ignore" "reply_action": "ignore"
} }

View File

@@ -24,7 +24,7 @@ see https://ai.baidu.com/unit/home#/home?track=61fe1b0d3407ce3face1d92cb5c291087
``` json ``` json
{ {
"service_id": "s...", #"机器人ID" "service_id": "s...", #"机器人ID"
"api_key": "", "api_key": "",
"secret_key": "" "secret_key": ""
} }
``` ```

View File

@@ -1 +1 @@
from .bdunit import * from .bdunit import *

View File

@@ -2,21 +2,29 @@
import json import json
import os import os
import uuid import uuid
from uuid import getnode as get_mac
import requests import requests
import plugins
from bridge.context import ContextType from bridge.context import ContextType
from bridge.reply import Reply, ReplyType from bridge.reply import Reply, ReplyType
from common.log import logger from common.log import logger
import plugins
from plugins import * from plugins import *
from uuid import getnode as get_mac
"""利用百度UNIT实现智能对话 """利用百度UNIT实现智能对话
如果命中意图,返回意图对应的回复,否则返回继续交付给下个插件处理 如果命中意图,返回意图对应的回复,否则返回继续交付给下个插件处理
""" """
@plugins.register(name="BDunit", desire_priority=0, hidden=True, desc="Baidu unit bot system", version="0.1", author="jackson") @plugins.register(
name="BDunit",
desire_priority=0,
hidden=True,
desc="Baidu unit bot system",
version="0.1",
author="jackson",
)
class BDunit(Plugin): class BDunit(Plugin):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@@ -40,11 +48,10 @@ class BDunit(Plugin):
raise e raise e
def on_handle_context(self, e_context: EventContext): def on_handle_context(self, e_context: EventContext):
if e_context["context"].type != ContextType.TEXT:
if e_context['context'].type != ContextType.TEXT:
return return
content = e_context['context'].content content = e_context["context"].content
logger.debug("[BDunit] on_handle_context. content: %s" % content) logger.debug("[BDunit] on_handle_context. content: %s" % content)
parsed = self.getUnit2(content) parsed = self.getUnit2(content)
intent = self.getIntent(parsed) intent = self.getIntent(parsed)
@@ -53,7 +60,7 @@ class BDunit(Plugin):
reply = Reply() reply = Reply()
reply.type = ReplyType.TEXT reply.type = ReplyType.TEXT
reply.content = self.getSay(parsed) reply.content = self.getSay(parsed)
e_context['reply'] = reply e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS # 事件结束并跳过处理context的默认逻辑 e_context.action = EventAction.BREAK_PASS # 事件结束并跳过处理context的默认逻辑
else: else:
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
@@ -69,18 +76,14 @@ class BDunit(Plugin):
Returns: Returns:
string: access_token string: access_token
""" """
url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format( url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(self.api_key, self.secret_key)
self.api_key, self.secret_key)
payload = "" payload = ""
headers = { headers = {"Content-Type": "application/json", "Accept": "application/json"}
'Content-Type': 'application/json',
'Accept': 'application/json'
}
response = requests.request("POST", url, headers=headers, data=payload) response = requests.request("POST", url, headers=headers, data=payload)
# print(response.text) # print(response.text)
return response.json()['access_token'] return response.json()["access_token"]
def getUnit(self, query): def getUnit(self, query):
""" """
@@ -89,12 +92,12 @@ class BDunit(Plugin):
:returns: UNIT 解析结果。如果解析失败,返回 None :returns: UNIT 解析结果。如果解析失败,返回 None
""" """
url = ( url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + self.access_token
'https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=' request = {
+ self.access_token "query": query,
) "user_id": str(get_mac())[:32],
request = {"query": query, "user_id": str( "terminal_id": "88888",
get_mac())[:32], "terminal_id": "88888"} }
body = { body = {
"log_id": str(uuid.uuid1()), "log_id": str(uuid.uuid1()),
"version": "3.0", "version": "3.0",
@@ -116,10 +119,7 @@ class BDunit(Plugin):
:param query: 用户的指令字符串 :param query: 用户的指令字符串
:returns: UNIT 解析结果。如果解析失败,返回 None :returns: UNIT 解析结果。如果解析失败,返回 None
""" """
url = ( url = "https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token=" + self.access_token
"https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token="
+ self.access_token
)
request = {"query": query, "user_id": str(get_mac())[:32]} request = {"query": query, "user_id": str(get_mac())[:32]}
body = { body = {
"log_id": str(uuid.uuid1()), "log_id": str(uuid.uuid1()),
@@ -142,11 +142,7 @@ class BDunit(Plugin):
:param parsed: UNIT 解析结果 :param parsed: UNIT 解析结果
:returns: 意图数组 :returns: 意图数组
""" """
if ( if parsed and "result" in parsed and "response_list" in parsed["result"]:
parsed
and "result" in parsed
and "response_list" in parsed["result"]
):
try: try:
return parsed["result"]["response_list"][0]["schema"]["intent"] return parsed["result"]["response_list"][0]["schema"]["intent"]
except Exception as e: except Exception as e:
@@ -163,18 +159,10 @@ class BDunit(Plugin):
:param intent: 意图的名称 :param intent: 意图的名称
:returns: True: 包含; False: 不包含 :returns: True: 包含; False: 不包含
""" """
if ( if parsed and "result" in parsed and "response_list" in parsed["result"]:
parsed
and "result" in parsed
and "response_list" in parsed["result"]
):
response_list = parsed["result"]["response_list"] response_list = parsed["result"]["response_list"]
for response in response_list: for response in response_list:
if ( if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
"schema" in response
and "intent" in response["schema"]
and response["schema"]["intent"] == intent
):
return True return True
return False return False
else: else:
@@ -189,11 +177,7 @@ class BDunit(Plugin):
:returns: 词槽列表。你可以通过 name 属性筛选词槽, :returns: 词槽列表。你可以通过 name 属性筛选词槽,
再通过 normalized_word 属性取出相应的值 再通过 normalized_word 属性取出相应的值
""" """
if ( if parsed and "result" in parsed and "response_list" in parsed["result"]:
parsed
and "result" in parsed
and "response_list" in parsed["result"]
):
response_list = parsed["result"]["response_list"] response_list = parsed["result"]["response_list"]
if intent == "": if intent == "":
try: try:
@@ -202,12 +186,7 @@ class BDunit(Plugin):
logger.warning(e) logger.warning(e)
return [] return []
for response in response_list: for response in response_list:
if ( if "schema" in response and "intent" in response["schema"] and "slots" in response["schema"] and response["schema"]["intent"] == intent:
"schema" in response
and "intent" in response["schema"]
and "slots" in response["schema"]
and response["schema"]["intent"] == intent
):
return response["schema"]["slots"] return response["schema"]["slots"]
return [] return []
else: else:
@@ -236,22 +215,14 @@ class BDunit(Plugin):
:param parsed: UNIT 解析结果 :param parsed: UNIT 解析结果
:returns: UNIT 的回复文本 :returns: UNIT 的回复文本
""" """
if ( if parsed and "result" in parsed and "response_list" in parsed["result"]:
parsed
and "result" in parsed
and "response_list" in parsed["result"]
):
response_list = parsed["result"]["response_list"] response_list = parsed["result"]["response_list"]
answer = {} answer = {}
for response in response_list: for response in response_list:
if ( if (
"schema" in response "schema" in response
and "intent_confidence" in response["schema"] and "intent_confidence" in response["schema"]
and ( and (not answer or response["schema"]["intent_confidence"] > answer["schema"]["intent_confidence"])
not answer
or response["schema"]["intent_confidence"]
> answer["schema"]["intent_confidence"]
)
): ):
answer = response answer = response
return answer["action_list"][0]["say"] return answer["action_list"][0]["say"]
@@ -266,11 +237,7 @@ class BDunit(Plugin):
:param intent: 意图的名称 :param intent: 意图的名称
:returns: UNIT 的回复文本 :returns: UNIT 的回复文本
""" """
if ( if parsed and "result" in parsed and "response_list" in parsed["result"]:
parsed
and "result" in parsed
and "response_list" in parsed["result"]
):
response_list = parsed["result"]["response_list"] response_list = parsed["result"]["response_list"]
if intent == "": if intent == "":
try: try:
@@ -279,11 +246,7 @@ class BDunit(Plugin):
logger.warning(e) logger.warning(e)
return "" return ""
for response in response_list: for response in response_list:
if ( if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
"schema" in response
and "intent" in response["schema"]
and response["schema"]["intent"] == intent
):
try: try:
return response["action_list"][0]["say"] return response["action_list"][0]["say"]
except Exception as e: except Exception as e:

View File

@@ -1,5 +1,5 @@
{ {
"service_id": "s...", "service_id": "s...",
"api_key": "", "api_key": "",
"secret_key": "" "secret_key": ""
} }

View File

@@ -1 +1 @@
from .dungeon import * from .dungeon import *

View File

@@ -1,17 +1,18 @@
# encoding:utf-8 # encoding:utf-8
import plugins
from bridge.bridge import Bridge from bridge.bridge import Bridge
from bridge.context import ContextType from bridge.context import ContextType
from bridge.reply import Reply, ReplyType from bridge.reply import Reply, ReplyType
from common.expired_dict import ExpiredDict
from config import conf
import plugins
from plugins import *
from common.log import logger
from common import const from common import const
from common.expired_dict import ExpiredDict
from common.log import logger
from config import conf
from plugins import *
# https://github.com/bupticybee/ChineseAiDungeonChatGPT # https://github.com/bupticybee/ChineseAiDungeonChatGPT
class StoryTeller(): class StoryTeller:
def __init__(self, bot, sessionid, story): def __init__(self, bot, sessionid, story):
self.bot = bot self.bot = bot
self.sessionid = sessionid self.sessionid = sessionid
@@ -27,67 +28,79 @@ class StoryTeller():
if user_action[-1] != "": if user_action[-1] != "":
user_action = user_action + "" user_action = user_action + ""
if self.first_interact: if self.first_interact:
prompt = """现在来充当一个文字冒险游戏,描述时候注意节奏,不要太快,仔细描述各个人物的心情和周边环境。一次只需写四到六句话。 prompt = (
开头是,""" + self.story + " " + user_action """现在来充当一个文字冒险游戏,描述时候注意节奏,不要太快,仔细描述各个人物的心情和周边环境。一次只需写四到六句话。
开头是,"""
+ self.story
+ " "
+ user_action
)
self.first_interact = False self.first_interact = False
else: else:
prompt = """继续一次只需要续写四到六句话总共就只讲5分钟内发生的事情。""" + user_action prompt = """继续一次只需要续写四到六句话总共就只讲5分钟内发生的事情。""" + user_action
return prompt return prompt
@plugins.register(name="Dungeon", desire_priority=0, namecn="文字冒险", desc="A plugin to play dungeon game", version="1.0", author="lanvent") @plugins.register(
name="Dungeon",
desire_priority=0,
namecn="文字冒险",
desc="A plugin to play dungeon game",
version="1.0",
author="lanvent",
)
class Dungeon(Plugin): class Dungeon(Plugin):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
logger.info("[Dungeon] inited") logger.info("[Dungeon] inited")
# 目前没有设计session过期事件这里先暂时使用过期字典 # 目前没有设计session过期事件这里先暂时使用过期字典
if conf().get('expires_in_seconds'): if conf().get("expires_in_seconds"):
self.games = ExpiredDict(conf().get('expires_in_seconds')) self.games = ExpiredDict(conf().get("expires_in_seconds"))
else: else:
self.games = dict() self.games = dict()
def on_handle_context(self, e_context: EventContext): def on_handle_context(self, e_context: EventContext):
if e_context["context"].type != ContextType.TEXT:
if e_context['context'].type != ContextType.TEXT:
return return
bottype = Bridge().get_bot_type("chat") bottype = Bridge().get_bot_type("chat")
if bottype not in (const.CHATGPT, const.OPEN_AI): if bottype not in (const.CHATGPT, const.OPEN_AI):
return return
bot = Bridge().get_bot("chat") bot = Bridge().get_bot("chat")
content = e_context['context'].content[:] content = e_context["context"].content[:]
clist = e_context['context'].content.split(maxsplit=1) clist = e_context["context"].content.split(maxsplit=1)
sessionid = e_context['context']['session_id'] sessionid = e_context["context"]["session_id"]
logger.debug("[Dungeon] on_handle_context. content: %s" % clist) logger.debug("[Dungeon] on_handle_context. content: %s" % clist)
trigger_prefix = conf().get('plugin_trigger_prefix', "$") trigger_prefix = conf().get("plugin_trigger_prefix", "$")
if clist[0] == f"{trigger_prefix}停止冒险": if clist[0] == f"{trigger_prefix}停止冒险":
if sessionid in self.games: if sessionid in self.games:
self.games[sessionid].reset() self.games[sessionid].reset()
del self.games[sessionid] del self.games[sessionid]
reply = Reply(ReplyType.INFO, "冒险结束!") reply = Reply(ReplyType.INFO, "冒险结束!")
e_context['reply'] = reply e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
elif clist[0] == f"{trigger_prefix}开始冒险" or sessionid in self.games: elif clist[0] == f"{trigger_prefix}开始冒险" or sessionid in self.games:
if sessionid not in self.games or clist[0] == f"{trigger_prefix}开始冒险": if sessionid not in self.games or clist[0] == f"{trigger_prefix}开始冒险":
if len(clist)>1 : if len(clist) > 1:
story = clist[1] story = clist[1]
else: else:
story = "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。" story = "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。"
self.games[sessionid] = StoryTeller(bot, sessionid, story) self.games[sessionid] = StoryTeller(bot, sessionid, story)
reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story) reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story)
e_context['reply'] = reply e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS # 事件结束并跳过处理context的默认逻辑 e_context.action = EventAction.BREAK_PASS # 事件结束并跳过处理context的默认逻辑
else: else:
prompt = self.games[sessionid].action(content) prompt = self.games[sessionid].action(content)
e_context['context'].type = ContextType.TEXT e_context["context"].type = ContextType.TEXT
e_context['context'].content = prompt e_context["context"].content = prompt
e_context.action = EventAction.BREAK # 事件结束不跳过处理context的默认逻辑 e_context.action = EventAction.BREAK # 事件结束不跳过处理context的默认逻辑
def get_help_text(self, **kwargs): def get_help_text(self, **kwargs):
help_text = "可以和机器人一起玩文字冒险游戏。\n" help_text = "可以和机器人一起玩文字冒险游戏。\n"
if kwargs.get('verbose') != True: if kwargs.get("verbose") != True:
return help_text return help_text
trigger_prefix = conf().get('plugin_trigger_prefix', "$") trigger_prefix = conf().get("plugin_trigger_prefix", "$")
help_text = f"{trigger_prefix}开始冒险 "+"背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n"+f"{trigger_prefix}停止冒险: 结束游戏。\n" help_text = f"{trigger_prefix}开始冒险 " + "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n" + f"{trigger_prefix}停止冒险: 结束游戏。\n"
if kwargs.get('verbose') == True: if kwargs.get("verbose") == True:
help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'" help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'"
return help_text return help_text

View File

@@ -9,17 +9,17 @@ class Event(Enum):
e_context = { "channel": 消息channel, "context" : 本次消息的context} e_context = { "channel": 消息channel, "context" : 本次消息的context}
""" """
ON_HANDLE_CONTEXT = 2 # 处理消息前 ON_HANDLE_CONTEXT = 2 # 处理消息前
""" """
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空 } e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空 }
""" """
ON_DECORATE_REPLY = 3 # 得到回复后准备装饰 ON_DECORATE_REPLY = 3 # 得到回复后准备装饰
""" """
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 } e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
""" """
ON_SEND_REPLY = 4 # 发送回复前 ON_SEND_REPLY = 4 # 发送回复前
""" """
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 } e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
""" """
@@ -28,9 +28,9 @@ class Event(Enum):
class EventAction(Enum): class EventAction(Enum):
CONTINUE = 1 # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑 CONTINUE = 1 # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑
BREAK = 2 # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑 BREAK = 2 # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑
BREAK_PASS = 3 # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑 BREAK_PASS = 3 # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑
class EventContext: class EventContext:

View File

@@ -1 +1 @@
from .finish import * from .finish import *

View File

@@ -1,14 +1,21 @@
# encoding:utf-8 # encoding:utf-8
import plugins
from bridge.context import ContextType from bridge.context import ContextType
from bridge.reply import Reply, ReplyType from bridge.reply import Reply, ReplyType
from config import conf
import plugins
from plugins import *
from common.log import logger from common.log import logger
from config import conf
from plugins import *
@plugins.register(name="Finish", desire_priority=-999, hidden=True, desc="A plugin that check unknown command", version="1.0", author="js00000") @plugins.register(
name="Finish",
desire_priority=-999,
hidden=True,
desc="A plugin that check unknown command",
version="1.0",
author="js00000",
)
class Finish(Plugin): class Finish(Plugin):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@@ -16,19 +23,18 @@ class Finish(Plugin):
logger.info("[Finish] inited") logger.info("[Finish] inited")
def on_handle_context(self, e_context: EventContext): def on_handle_context(self, e_context: EventContext):
if e_context["context"].type != ContextType.TEXT:
if e_context['context'].type != ContextType.TEXT:
return return
content = e_context['context'].content content = e_context["context"].content
logger.debug("[Finish] on_handle_context. content: %s" % content) logger.debug("[Finish] on_handle_context. content: %s" % content)
trigger_prefix = conf().get('plugin_trigger_prefix',"$") trigger_prefix = conf().get("plugin_trigger_prefix", "$")
if content.startswith(trigger_prefix): if content.startswith(trigger_prefix):
reply = Reply() reply = Reply()
reply.type = ReplyType.ERROR reply.type = ReplyType.ERROR
reply.content = "未知插件命令\n查看插件命令列表请输入#help 插件名\n" reply.content = "未知插件命令\n查看插件命令列表请输入#help 插件名\n"
e_context['reply'] = reply e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS # 事件结束并跳过处理context的默认逻辑 e_context.action = EventAction.BREAK_PASS # 事件结束并跳过处理context的默认逻辑
def get_help_text(self, **kwargs): def get_help_text(self, **kwargs):
return "" return ""

View File

@@ -1 +1 @@
from .godcmd import * from .godcmd import *

View File

@@ -1,4 +1,4 @@
{ {
"password": "", "password": "",
"admin_users": [] "admin_users": []
} }

View File

@@ -6,14 +6,16 @@ import random
import string import string
import traceback import traceback
from typing import Tuple from typing import Tuple
import plugins
from bridge.bridge import Bridge from bridge.bridge import Bridge
from bridge.context import ContextType from bridge.context import ContextType
from bridge.reply import Reply, ReplyType from bridge.reply import Reply, ReplyType
from config import conf, load_config
import plugins
from plugins import *
from common import const from common import const
from common.log import logger from common.log import logger
from config import conf, load_config
from plugins import *
# 定义指令集 # 定义指令集
COMMANDS = { COMMANDS = {
"help": { "help": {
@@ -41,7 +43,7 @@ COMMANDS = {
}, },
"id": { "id": {
"alias": ["id", "用户"], "alias": ["id", "用户"],
"desc": "获取用户id", # wechaty和wechatmp的用户id不会变化可用于绑定管理员 "desc": "获取用户id", # wechaty和wechatmp的用户id不会变化可用于绑定管理员
}, },
"reset": { "reset": {
"alias": ["reset", "重置会话"], "alias": ["reset", "重置会话"],
@@ -114,18 +116,20 @@ ADMIN_COMMANDS = {
"desc": "开启机器调试日志", "desc": "开启机器调试日志",
}, },
} }
# 定义帮助函数 # 定义帮助函数
def get_help_text(isadmin, isgroup): def get_help_text(isadmin, isgroup):
help_text = "通用指令:\n" help_text = "通用指令:\n"
for cmd, info in COMMANDS.items(): for cmd, info in COMMANDS.items():
if cmd=="auth": #不提示认证指令 if cmd == "auth": # 不提示认证指令
continue continue
if cmd=="id" and conf().get("channel_type","wx") not in ["wxy","wechatmp"]: if cmd == "id" and conf().get("channel_type", "wx") not in ["wxy", "wechatmp"]:
continue continue
alias=["#"+a for a in info['alias'][:1]] alias = ["#" + a for a in info["alias"][:1]]
help_text += f"{','.join(alias)} " help_text += f"{','.join(alias)} "
if 'args' in info: if "args" in info:
args=[a for a in info['args']] args = [a for a in info["args"]]
help_text += f"{' '.join(args)}" help_text += f"{' '.join(args)}"
help_text += f": {info['desc']}\n" help_text += f": {info['desc']}\n"
@@ -135,39 +139,46 @@ def get_help_text(isadmin, isgroup):
for plugin in plugins: for plugin in plugins:
if plugins[plugin].enabled and not plugins[plugin].hidden: if plugins[plugin].enabled and not plugins[plugin].hidden:
namecn = plugins[plugin].namecn namecn = plugins[plugin].namecn
help_text += "\n%s:"%namecn help_text += "\n%s:" % namecn
help_text += PluginManager().instances[plugin].get_help_text(verbose=False).strip() help_text += PluginManager().instances[plugin].get_help_text(verbose=False).strip()
if ADMIN_COMMANDS and isadmin: if ADMIN_COMMANDS and isadmin:
help_text += "\n\n管理员指令:\n" help_text += "\n\n管理员指令:\n"
for cmd, info in ADMIN_COMMANDS.items(): for cmd, info in ADMIN_COMMANDS.items():
alias=["#"+a for a in info['alias'][:1]] alias = ["#" + a for a in info["alias"][:1]]
help_text += f"{','.join(alias)} " help_text += f"{','.join(alias)} "
if 'args' in info: if "args" in info:
args=[a for a in info['args']] args = [a for a in info["args"]]
help_text += f"{' '.join(args)}" help_text += f"{' '.join(args)}"
help_text += f": {info['desc']}\n" help_text += f": {info['desc']}\n"
return help_text return help_text
@plugins.register(name="Godcmd", desire_priority=999, hidden=True, desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证", version="1.0", author="lanvent")
class Godcmd(Plugin):
@plugins.register(
name="Godcmd",
desire_priority=999,
hidden=True,
desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证",
version="1.0",
author="lanvent",
)
class Godcmd(Plugin):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
curdir=os.path.dirname(__file__) curdir = os.path.dirname(__file__)
config_path=os.path.join(curdir,"config.json") config_path = os.path.join(curdir, "config.json")
gconf=None gconf = None
if not os.path.exists(config_path): if not os.path.exists(config_path):
gconf={"password":"","admin_users":[]} gconf = {"password": "", "admin_users": []}
with open(config_path,"w") as f: with open(config_path, "w") as f:
json.dump(gconf,f,indent=4) json.dump(gconf, f, indent=4)
else: else:
with open(config_path,"r") as f: with open(config_path, "r") as f:
gconf=json.load(f) gconf = json.load(f)
if gconf["password"] == "": if gconf["password"] == "":
self.temp_password = "".join(random.sample(string.digits, 4)) self.temp_password = "".join(random.sample(string.digits, 4))
logger.info("[Godcmd] 因未设置口令,本次的临时口令为%s"%self.temp_password) logger.info("[Godcmd] 因未设置口令,本次的临时口令为%s" % self.temp_password)
else: else:
self.temp_password = None self.temp_password = None
custom_commands = conf().get("clear_memory_commands", []) custom_commands = conf().get("clear_memory_commands", [])
@@ -178,41 +189,47 @@ class Godcmd(Plugin):
COMMANDS["reset"]["alias"].append(custom_command) COMMANDS["reset"]["alias"].append(custom_command)
self.password = gconf["password"] self.password = gconf["password"]
self.admin_users = gconf["admin_users"] # 预存的管理员账号这些账号不需要认证。itchat的用户名每次都会变不可用 self.admin_users = gconf["admin_users"] # 预存的管理员账号这些账号不需要认证。itchat的用户名每次都会变不可用
self.isrunning = True # 机器人是否运行中 self.isrunning = True # 机器人是否运行中
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
logger.info("[Godcmd] inited") logger.info("[Godcmd] inited")
def on_handle_context(self, e_context: EventContext): def on_handle_context(self, e_context: EventContext):
context_type = e_context['context'].type context_type = e_context["context"].type
if context_type != ContextType.TEXT: if context_type != ContextType.TEXT:
if not self.isrunning: if not self.isrunning:
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
return return
content = e_context['context'].content content = e_context["context"].content
logger.debug("[Godcmd] on_handle_context. content: %s" % content) logger.debug("[Godcmd] on_handle_context. content: %s" % content)
if content.startswith("#"): if content.startswith("#"):
if len(content) == 1:
reply = Reply()
reply.type = ReplyType.ERROR
reply.content = f"空指令,输入#help查看指令列表\n"
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
return
# msg = e_context['context']['msg'] # msg = e_context['context']['msg']
channel = e_context['channel'] channel = e_context["channel"]
user = e_context['context']['receiver'] user = e_context["context"]["receiver"]
session_id = e_context['context']['session_id'] session_id = e_context["context"]["session_id"]
isgroup = e_context['context'].get("isgroup", False) isgroup = e_context["context"].get("isgroup", False)
bottype = Bridge().get_bot_type("chat") bottype = Bridge().get_bot_type("chat")
bot = Bridge().get_bot("chat") bot = Bridge().get_bot("chat")
# 将命令和参数分割 # 将命令和参数分割
command_parts = content[1:].strip().split() command_parts = content[1:].strip().split()
cmd = command_parts[0] cmd = command_parts[0]
args = command_parts[1:] args = command_parts[1:]
isadmin=False isadmin = False
if user in self.admin_users: if user in self.admin_users:
isadmin=True isadmin = True
ok=False ok = False
result="string" result = "string"
if any(cmd in info['alias'] for info in COMMANDS.values()): if any(cmd in info["alias"] for info in COMMANDS.values()):
cmd = next(c for c, info in COMMANDS.items() if cmd in info['alias']) cmd = next(c for c, info in COMMANDS.items() if cmd in info["alias"])
if cmd == "auth": if cmd == "auth":
ok, result = self.authenticate(user, args, isadmin, isgroup) ok, result = self.authenticate(user, args, isadmin, isgroup)
elif cmd == "help" or cmd == "helpp": elif cmd == "help" or cmd == "helpp":
@@ -224,7 +241,7 @@ class Godcmd(Plugin):
query_name = args[0].upper() query_name = args[0].upper()
# search name and namecn # search name and namecn
for name, plugincls in plugins.items(): for name, plugincls in plugins.items():
if not plugincls.enabled : if not plugincls.enabled:
continue continue
if query_name == name or query_name == plugincls.namecn: if query_name == name or query_name == plugincls.namecn:
ok, result = True, PluginManager().instances[name].get_help_text(isgroup=isgroup, isadmin=isadmin, verbose=True) ok, result = True, PluginManager().instances[name].get_help_text(isgroup=isgroup, isadmin=isadmin, verbose=True)
@@ -236,14 +253,14 @@ class Godcmd(Plugin):
elif cmd == "set_openai_api_key": elif cmd == "set_openai_api_key":
if len(args) == 1: if len(args) == 1:
user_data = conf().get_user_data(user) user_data = conf().get_user_data(user)
user_data['openai_api_key'] = args[0] user_data["openai_api_key"] = args[0]
ok, result = True, "你的OpenAI私有api_key已设置为" + args[0] ok, result = True, "你的OpenAI私有api_key已设置为" + args[0]
else: else:
ok, result = False, "请提供一个api_key" ok, result = False, "请提供一个api_key"
elif cmd == "reset_openai_api_key": elif cmd == "reset_openai_api_key":
try: try:
user_data = conf().get_user_data(user) user_data = conf().get_user_data(user)
user_data.pop('openai_api_key') user_data.pop("openai_api_key")
ok, result = True, "你的OpenAI私有api_key已清除" ok, result = True, "你的OpenAI私有api_key已清除"
except Exception as e: except Exception as e:
ok, result = False, "你没有设置私有api_key" ok, result = False, "你没有设置私有api_key"
@@ -255,12 +272,12 @@ class Godcmd(Plugin):
else: else:
ok, result = False, "当前对话机器人不支持重置会话" ok, result = False, "当前对话机器人不支持重置会话"
logger.debug("[Godcmd] command: %s by %s" % (cmd, user)) logger.debug("[Godcmd] command: %s by %s" % (cmd, user))
elif any(cmd in info['alias'] for info in ADMIN_COMMANDS.values()): elif any(cmd in info["alias"] for info in ADMIN_COMMANDS.values()):
if isadmin: if isadmin:
if isgroup: if isgroup:
ok, result = False, "群聊不可执行管理员指令" ok, result = False, "群聊不可执行管理员指令"
else: else:
cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info['alias']) cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info["alias"])
if cmd == "stop": if cmd == "stop":
self.isrunning = False self.isrunning = False
ok, result = True, "服务已暂停" ok, result = True, "服务已暂停"
@@ -278,13 +295,13 @@ class Godcmd(Plugin):
else: else:
ok, result = False, "当前对话机器人不支持重置会话" ok, result = False, "当前对话机器人不支持重置会话"
elif cmd == "debug": elif cmd == "debug":
logger.setLevel('DEBUG') logger.setLevel("DEBUG")
ok, result = True, "DEBUG模式已开启" ok, result = True, "DEBUG模式已开启"
elif cmd == "plist": elif cmd == "plist":
plugins = PluginManager().list_plugins() plugins = PluginManager().list_plugins()
ok = True ok = True
result = "插件列表:\n" result = "插件列表:\n"
for name,plugincls in plugins.items(): for name, plugincls in plugins.items():
result += f"{plugincls.name}_v{plugincls.version} {plugincls.priority} - " result += f"{plugincls.name}_v{plugincls.version} {plugincls.priority} - "
if plugincls.enabled: if plugincls.enabled:
result += "已启用\n" result += "已启用\n"
@@ -294,11 +311,11 @@ class Godcmd(Plugin):
new_plugins = PluginManager().scan_plugins() new_plugins = PluginManager().scan_plugins()
ok, result = True, "插件扫描完成" ok, result = True, "插件扫描完成"
PluginManager().activate_plugins() PluginManager().activate_plugins()
if len(new_plugins) >0 : if len(new_plugins) > 0:
result += "\n发现新插件:\n" result += "\n发现新插件:\n"
result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins]) result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins])
else : else:
result +=", 未发现新插件" result += ", 未发现新插件"
elif cmd == "setpri": elif cmd == "setpri":
if len(args) != 2: if len(args) != 2:
ok, result = False, "请提供插件名和优先级" ok, result = False, "请提供插件名和优先级"
@@ -350,42 +367,42 @@ class Godcmd(Plugin):
else: else:
ok, result = False, "需要管理员权限才能执行该指令" ok, result = False, "需要管理员权限才能执行该指令"
else: else:
trigger_prefix = conf().get('plugin_trigger_prefix',"$") trigger_prefix = conf().get("plugin_trigger_prefix", "$")
if trigger_prefix == "#": # 跟插件聊天指令前缀相同,继续递交 if trigger_prefix == "#": # 跟插件聊天指令前缀相同,继续递交
return return
ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n" ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n"
reply = Reply() reply = Reply()
if ok: if ok:
reply.type = ReplyType.INFO reply.type = ReplyType.INFO
else: else:
reply.type = ReplyType.ERROR reply.type = ReplyType.ERROR
reply.content = result reply.content = result
e_context['reply'] = reply e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS # 事件结束并跳过处理context的默认逻辑 e_context.action = EventAction.BREAK_PASS # 事件结束并跳过处理context的默认逻辑
elif not self.isrunning: elif not self.isrunning:
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool,str] : def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool, str]:
if isgroup: if isgroup:
return False,"请勿在群聊中认证" return False, "请勿在群聊中认证"
if isadmin: if isadmin:
return False,"管理员账号无需认证" return False, "管理员账号无需认证"
if len(args) != 1: if len(args) != 1:
return False,"请提供口令" return False, "请提供口令"
password = args[0] password = args[0]
if password == self.password: if password == self.password:
self.admin_users.append(userid) self.admin_users.append(userid)
return True,"认证成功" return True, "认证成功"
elif password == self.temp_password: elif password == self.temp_password:
self.admin_users.append(userid) self.admin_users.append(userid)
return True,"认证成功,请尽快设置口令" return True, "认证成功,请尽快设置口令"
else: else:
return False,"认证失败" return False, "认证失败"
def get_help_text(self, isadmin = False, isgroup = False, **kwargs): def get_help_text(self, isadmin=False, isgroup=False, **kwargs):
return get_help_text(isadmin, isgroup) return get_help_text(isadmin, isgroup)

View File

@@ -1 +1 @@
from .hello import * from .hello import *

View File

@@ -1,14 +1,21 @@
# encoding:utf-8 # encoding:utf-8
import plugins
from bridge.context import ContextType from bridge.context import ContextType
from bridge.reply import Reply, ReplyType from bridge.reply import Reply, ReplyType
from channel.chat_message import ChatMessage from channel.chat_message import ChatMessage
import plugins
from plugins import *
from common.log import logger from common.log import logger
from plugins import *
@plugins.register(name="Hello", desire_priority=-1, hidden=True, desc="A simple plugin that says hello", version="0.1", author="lanvent") @plugins.register(
name="Hello",
desire_priority=-1,
hidden=True,
desc="A simple plugin that says hello",
version="0.1",
author="lanvent",
)
class Hello(Plugin): class Hello(Plugin):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@@ -16,33 +23,50 @@ class Hello(Plugin):
logger.info("[Hello] inited") logger.info("[Hello] inited")
def on_handle_context(self, e_context: EventContext): def on_handle_context(self, e_context: EventContext):
if e_context["context"].type not in [
if e_context['context'].type != ContextType.TEXT: ContextType.TEXT,
ContextType.JOIN_GROUP,
ContextType.PATPAT,
]:
return return
content = e_context['context'].content if e_context["context"].type == ContextType.JOIN_GROUP:
e_context["context"].type = ContextType.TEXT
msg: ChatMessage = e_context["context"]["msg"]
e_context["context"].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。'
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
return
if e_context["context"].type == ContextType.PATPAT:
e_context["context"].type = ContextType.TEXT
msg: ChatMessage = e_context["context"]["msg"]
e_context["context"].content = f"请你随机使用一种风格介绍你自己,并告诉用户输入#help可以查看帮助信息。"
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
return
content = e_context["context"].content
logger.debug("[Hello] on_handle_context. content: %s" % content) logger.debug("[Hello] on_handle_context. content: %s" % content)
if content == "Hello": if content == "Hello":
reply = Reply() reply = Reply()
reply.type = ReplyType.TEXT reply.type = ReplyType.TEXT
msg:ChatMessage = e_context['context']['msg'] msg: ChatMessage = e_context["context"]["msg"]
if e_context['context']['isgroup']: if e_context["context"]["isgroup"]:
reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}" reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
else: else:
reply.content = f"Hello, {msg.from_user_nickname}" reply.content = f"Hello, {msg.from_user_nickname}"
e_context['reply'] = reply e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS # 事件结束并跳过处理context的默认逻辑 e_context.action = EventAction.BREAK_PASS # 事件结束并跳过处理context的默认逻辑
if content == "Hi": if content == "Hi":
reply = Reply() reply = Reply()
reply.type = ReplyType.TEXT reply.type = ReplyType.TEXT
reply.content = "Hi" reply.content = "Hi"
e_context['reply'] = reply e_context["reply"] = reply
e_context.action = EventAction.BREAK # 事件结束进入默认处理逻辑一般会覆写reply e_context.action = EventAction.BREAK # 事件结束进入默认处理逻辑一般会覆写reply
if content == "End": if content == "End":
# 如果是文本消息"End",将请求转换成"IMAGE_CREATE"并将content设置为"The World" # 如果是文本消息"End",将请求转换成"IMAGE_CREATE"并将content设置为"The World"
e_context['context'].type = ContextType.IMAGE_CREATE e_context["context"].type = ContextType.IMAGE_CREATE
content = "The World" content = "The World"
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑

13
plugins/keyword/README.md Normal file
View File

@@ -0,0 +1,13 @@
# 目的
关键字匹配并回复
# 试用场景
目前是在微信公众号下面使用过。
# 使用步骤
1. 复制 `config.json.template``config.json`
2. 在关键字 `keyword` 新增需要关键字匹配的内容
3. 重启程序做验证
# 验证结果
![结果](test-keyword.png)

View File

@@ -0,0 +1 @@
from .keyword import *

View File

@@ -0,0 +1,5 @@
{
"keyword": {
"关键字匹配": "测试成功"
}
}

View File

@@ -0,0 +1,65 @@
# encoding:utf-8
import json
import os
import plugins
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from plugins import *
@plugins.register(
name="Keyword",
desire_priority=900,
hidden=True,
desc="关键词匹配过滤",
version="0.1",
author="fengyege.top",
)
class Keyword(Plugin):
def __init__(self):
super().__init__()
try:
curdir = os.path.dirname(__file__)
config_path = os.path.join(curdir, "config.json")
conf = None
if not os.path.exists(config_path):
logger.debug(f"[keyword]不存在配置文件{config_path}")
conf = {"keyword": {}}
with open(config_path, "w", encoding="utf-8") as f:
json.dump(conf, f, indent=4)
else:
logger.debug(f"[keyword]加载配置文件{config_path}")
with open(config_path, "r", encoding="utf-8") as f:
conf = json.load(f)
# 加载关键词
self.keyword = conf["keyword"]
logger.info("[keyword] {}".format(self.keyword))
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
logger.info("[keyword] inited.")
except Exception as e:
logger.warn("[keyword] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/keyword .")
raise e
def on_handle_context(self, e_context: EventContext):
if e_context["context"].type != ContextType.TEXT:
return
content = e_context["context"].content.strip()
logger.debug("[keyword] on_handle_context. content: %s" % content)
if content in self.keyword:
logger.debug(f"[keyword] 匹配到关键字【{content}")
reply_text = self.keyword[content]
reply = Reply()
reply.type = ReplyType.TEXT
reply.content = reply_text
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS # 事件结束并跳过处理context的默认逻辑
def get_help_text(self, **kwargs):
help_text = "关键词过滤"
return help_text

Binary file not shown.

After

Width:  |  Height:  |  Size: 12 KiB

View File

@@ -3,4 +3,4 @@ class Plugin:
self.handlers = {} self.handlers = {}
def get_help_text(self, **kwargs): def get_help_text(self, **kwargs):
return "暂无帮助信息" return "暂无帮助信息"

View File

@@ -5,17 +5,19 @@ import importlib.util
import json import json
import os import os
import sys import sys
from common.log import logger
from common.singleton import singleton from common.singleton import singleton
from common.sorted_dict import SortedDict from common.sorted_dict import SortedDict
from .event import *
from common.log import logger
from config import conf from config import conf
from .event import *
@singleton @singleton
class PluginManager: class PluginManager:
def __init__(self): def __init__(self):
self.plugins = SortedDict(lambda k,v: v.priority,reverse=True) self.plugins = SortedDict(lambda k, v: v.priority, reverse=True)
self.listening_plugins = {} self.listening_plugins = {}
self.instances = {} self.instances = {}
self.pconf = {} self.pconf = {}
@@ -26,17 +28,18 @@ class PluginManager:
def wrapper(plugincls): def wrapper(plugincls):
plugincls.name = name plugincls.name = name
plugincls.priority = desire_priority plugincls.priority = desire_priority
plugincls.desc = kwargs.get('desc') plugincls.desc = kwargs.get("desc")
plugincls.author = kwargs.get('author') plugincls.author = kwargs.get("author")
plugincls.path = self.current_plugin_path plugincls.path = self.current_plugin_path
plugincls.version = kwargs.get('version') if kwargs.get('version') != None else "1.0" plugincls.version = kwargs.get("version") if kwargs.get("version") != None else "1.0"
plugincls.namecn = kwargs.get('namecn') if kwargs.get('namecn') != None else name plugincls.namecn = kwargs.get("namecn") if kwargs.get("namecn") != None else name
plugincls.hidden = kwargs.get('hidden') if kwargs.get('hidden') != None else False plugincls.hidden = kwargs.get("hidden") if kwargs.get("hidden") != None else False
plugincls.enabled = True plugincls.enabled = True
if self.current_plugin_path == None: if self.current_plugin_path == None:
raise Exception("Plugin path not set") raise Exception("Plugin path not set")
self.plugins[name.upper()] = plugincls self.plugins[name.upper()] = plugincls
logger.info("Plugin %s_v%s registered, path=%s" % (name, plugincls.version, plugincls.path)) logger.info("Plugin %s_v%s registered, path=%s" % (name, plugincls.version, plugincls.path))
return wrapper return wrapper
def save_config(self): def save_config(self):
@@ -50,10 +53,10 @@ class PluginManager:
if os.path.exists("./plugins/plugins.json"): if os.path.exists("./plugins/plugins.json"):
with open("./plugins/plugins.json", "r", encoding="utf-8") as f: with open("./plugins/plugins.json", "r", encoding="utf-8") as f:
pconf = json.load(f) pconf = json.load(f)
pconf['plugins'] = SortedDict(lambda k,v: v["priority"],pconf['plugins'],reverse=True) pconf["plugins"] = SortedDict(lambda k, v: v["priority"], pconf["plugins"], reverse=True)
else: else:
modified = True modified = True
pconf = {"plugins": SortedDict(lambda k,v: v["priority"],reverse=True)} pconf = {"plugins": SortedDict(lambda k, v: v["priority"], reverse=True)}
self.pconf = pconf self.pconf = pconf
if modified: if modified:
self.save_config() self.save_config()
@@ -67,7 +70,7 @@ class PluginManager:
plugin_path = os.path.join(plugins_dir, plugin_name) plugin_path = os.path.join(plugins_dir, plugin_name)
if os.path.isdir(plugin_path): if os.path.isdir(plugin_path):
# 判断插件是否包含同名__init__.py文件 # 判断插件是否包含同名__init__.py文件
main_module_path = os.path.join(plugin_path,"__init__.py") main_module_path = os.path.join(plugin_path, "__init__.py")
if os.path.isfile(main_module_path): if os.path.isfile(main_module_path):
# 导入插件 # 导入插件
import_path = "plugins.{}".format(plugin_name) import_path = "plugins.{}".format(plugin_name)
@@ -77,7 +80,7 @@ class PluginManager:
if self.loaded[plugin_path] == None: if self.loaded[plugin_path] == None:
logger.info("reload module %s" % plugin_name) logger.info("reload module %s" % plugin_name)
self.loaded[plugin_path] = importlib.reload(sys.modules[import_path]) self.loaded[plugin_path] = importlib.reload(sys.modules[import_path])
dependent_module_names = [name for name in sys.modules.keys() if name.startswith( import_path+ '.')] dependent_module_names = [name for name in sys.modules.keys() if name.startswith(import_path + ".")]
for name in dependent_module_names: for name in dependent_module_names:
logger.info("reload module %s" % name) logger.info("reload module %s" % name)
importlib.reload(sys.modules[name]) importlib.reload(sys.modules[name])
@@ -96,11 +99,14 @@ class PluginManager:
if rawname not in pconf["plugins"]: if rawname not in pconf["plugins"]:
modified = True modified = True
logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name) logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name)
pconf["plugins"][rawname] = {"enabled": plugincls.enabled, "priority": plugincls.priority} pconf["plugins"][rawname] = {
"enabled": plugincls.enabled,
"priority": plugincls.priority,
}
else: else:
self.plugins[name].enabled = pconf["plugins"][rawname]["enabled"] self.plugins[name].enabled = pconf["plugins"][rawname]["enabled"]
self.plugins[name].priority = pconf["plugins"][rawname]["priority"] self.plugins[name].priority = pconf["plugins"][rawname]["priority"]
self.plugins._update_heap(name) # 更新下plugins中的顺序 self.plugins._update_heap(name) # 更新下plugins中的顺序
if modified: if modified:
self.save_config() self.save_config()
return new_plugins return new_plugins
@@ -109,7 +115,7 @@ class PluginManager:
for event in self.listening_plugins.keys(): for event in self.listening_plugins.keys():
self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True) self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True)
def activate_plugins(self): # 生成新开启的插件实例 def activate_plugins(self): # 生成新开启的插件实例
failed_plugins = [] failed_plugins = []
for name, plugincls in self.plugins.items(): for name, plugincls in self.plugins.items():
if plugincls.enabled: if plugincls.enabled:
@@ -117,7 +123,7 @@ class PluginManager:
try: try:
instance = plugincls() instance = plugincls()
except Exception as e: except Exception as e:
logger.warn("Failed to init %s, diabled. %s" % (name, e)) logger.error("Failed to init %s, diabled. %s" % (name, e))
self.disable_plugin(name) self.disable_plugin(name)
failed_plugins.append(name) failed_plugins.append(name)
continue continue
@@ -129,7 +135,7 @@ class PluginManager:
self.refresh_order() self.refresh_order()
return failed_plugins return failed_plugins
def reload_plugin(self, name:str): def reload_plugin(self, name: str):
name = name.upper() name = name.upper()
if name in self.instances: if name in self.instances:
for event in self.listening_plugins: for event in self.listening_plugins:
@@ -139,13 +145,13 @@ class PluginManager:
self.activate_plugins() self.activate_plugins()
return True return True
return False return False
def load_plugins(self): def load_plugins(self):
self.load_config() self.load_config()
self.scan_plugins() self.scan_plugins()
pconf = self.pconf pconf = self.pconf
logger.debug("plugins.json config={}".format(pconf)) logger.debug("plugins.json config={}".format(pconf))
for name,plugin in pconf["plugins"].items(): for name, plugin in pconf["plugins"].items():
if name.upper() not in self.plugins: if name.upper() not in self.plugins:
logger.error("Plugin %s not found, but found in plugins.json" % name) logger.error("Plugin %s not found, but found in plugins.json" % name)
self.activate_plugins() self.activate_plugins()
@@ -154,12 +160,12 @@ class PluginManager:
if e_context.event in self.listening_plugins: if e_context.event in self.listening_plugins:
for name in self.listening_plugins[e_context.event]: for name in self.listening_plugins[e_context.event]:
if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE: if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE:
logger.debug("Plugin %s triggered by event %s" % (name,e_context.event)) logger.debug("Plugin %s triggered by event %s" % (name, e_context.event))
instance = self.instances[name] instance = self.instances[name]
instance.handlers[e_context.event](e_context, *args, **kwargs) instance.handlers[e_context.event](e_context, *args, **kwargs)
return e_context return e_context
def set_plugin_priority(self, name:str, priority:int): def set_plugin_priority(self, name: str, priority: int):
name = name.upper() name = name.upper()
if name not in self.plugins: if name not in self.plugins:
return False return False
@@ -174,11 +180,11 @@ class PluginManager:
self.refresh_order() self.refresh_order()
return True return True
def enable_plugin(self, name:str): def enable_plugin(self, name: str):
name = name.upper() name = name.upper()
if name not in self.plugins: if name not in self.plugins:
return False, "插件不存在" return False, "插件不存在"
if not self.plugins[name].enabled : if not self.plugins[name].enabled:
self.plugins[name].enabled = True self.plugins[name].enabled = True
rawname = self.plugins[name].name rawname = self.plugins[name].name
self.pconf["plugins"][rawname]["enabled"] = True self.pconf["plugins"][rawname]["enabled"] = True
@@ -188,39 +194,41 @@ class PluginManager:
return False, "插件开启失败" return False, "插件开启失败"
return True, "插件已开启" return True, "插件已开启"
return True, "插件已开启" return True, "插件已开启"
def disable_plugin(self, name:str): def disable_plugin(self, name: str):
name = name.upper() name = name.upper()
if name not in self.plugins: if name not in self.plugins:
return False return False
if self.plugins[name].enabled : if self.plugins[name].enabled:
self.plugins[name].enabled = False self.plugins[name].enabled = False
rawname = self.plugins[name].name rawname = self.plugins[name].name
self.pconf["plugins"][rawname]["enabled"] = False self.pconf["plugins"][rawname]["enabled"] = False
self.save_config() self.save_config()
return True return True
return True return True
def list_plugins(self): def list_plugins(self):
return self.plugins return self.plugins
def install_plugin(self, repo:str): def install_plugin(self, repo: str):
try: try:
import common.package_manager as pkgmgr import common.package_manager as pkgmgr
pkgmgr.check_dulwich() pkgmgr.check_dulwich()
except Exception as e: except Exception as e:
logger.error("Failed to install plugin, {}".format(e)) logger.error("Failed to install plugin, {}".format(e))
return False, "无法导入dulwich安装插件失败" return False, "无法导入dulwich安装插件失败"
import re import re
from dulwich import porcelain from dulwich import porcelain
logger.info("clone git repo: {}".format(repo)) logger.info("clone git repo: {}".format(repo))
match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo) match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo)
if not match: if not match:
try: try:
with open("./plugins/source.json","r", encoding="utf-8") as f: with open("./plugins/source.json", "r", encoding="utf-8") as f:
source = json.load(f) source = json.load(f)
if repo in source["repo"]: if repo in source["repo"]:
repo = source["repo"][repo]["url"] repo = source["repo"][repo]["url"]
@@ -232,42 +240,53 @@ class PluginManager:
except Exception as e: except Exception as e:
logger.error("Failed to install plugin, {}".format(e)) logger.error("Failed to install plugin, {}".format(e))
return False, "安装插件失败,请检查仓库地址是否正确" return False, "安装插件失败,请检查仓库地址是否正确"
dirname = os.path.join("./plugins",match.group(4)) dirname = os.path.join("./plugins", match.group(4))
try: try:
repo = porcelain.clone(repo, dirname, checkout=True) repo = porcelain.clone(repo, dirname, checkout=True)
if os.path.exists(os.path.join(dirname,"requirements.txt")): if os.path.exists(os.path.join(dirname, "requirements.txt")):
logger.info("detect requirements.txtinstalling...") logger.info("detect requirements.txtinstalling...")
pkgmgr.install_requirements(os.path.join(dirname,"requirements.txt")) pkgmgr.install_requirements(os.path.join(dirname, "requirements.txt"))
return True, "安装插件成功,请使用 #scanp 命令扫描插件或重启程序,开启前请检查插件是否需要配置" return True, "安装插件成功,请使用 #scanp 命令扫描插件或重启程序,开启前请检查插件是否需要配置"
except Exception as e: except Exception as e:
logger.error("Failed to install plugin, {}".format(e)) logger.error("Failed to install plugin, {}".format(e))
return False, "安装插件失败,"+str(e) return False, "安装插件失败," + str(e)
def update_plugin(self, name:str): def update_plugin(self, name: str):
try: try:
import common.package_manager as pkgmgr import common.package_manager as pkgmgr
pkgmgr.check_dulwich() pkgmgr.check_dulwich()
except Exception as e: except Exception as e:
logger.error("Failed to install plugin, {}".format(e)) logger.error("Failed to install plugin, {}".format(e))
return False, "无法导入dulwich更新插件失败" return False, "无法导入dulwich更新插件失败"
from dulwich import porcelain from dulwich import porcelain
name = name.upper() name = name.upper()
if name not in self.plugins: if name not in self.plugins:
return False, "插件不存在" return False, "插件不存在"
if name in ["HELLO","GODCMD","ROLE","TOOL","BDUNIT","BANWORDS","FINISH","DUNGEON"]: if name in [
"HELLO",
"GODCMD",
"ROLE",
"TOOL",
"BDUNIT",
"BANWORDS",
"FINISH",
"DUNGEON",
]:
return False, "预置插件无法更新,请更新主程序仓库" return False, "预置插件无法更新,请更新主程序仓库"
dirname = self.plugins[name].path dirname = self.plugins[name].path
try: try:
porcelain.pull(dirname, "origin") porcelain.pull(dirname, "origin")
if os.path.exists(os.path.join(dirname,"requirements.txt")): if os.path.exists(os.path.join(dirname, "requirements.txt")):
logger.info("detect requirements.txtinstalling...") logger.info("detect requirements.txtinstalling...")
pkgmgr.install_requirements(os.path.join(dirname,"requirements.txt")) pkgmgr.install_requirements(os.path.join(dirname, "requirements.txt"))
return True, "更新插件成功,请重新运行程序" return True, "更新插件成功,请重新运行程序"
except Exception as e: except Exception as e:
logger.error("Failed to update plugin, {}".format(e)) logger.error("Failed to update plugin, {}".format(e))
return False, "更新插件失败,"+str(e) return False, "更新插件失败," + str(e)
def uninstall_plugin(self, name:str): def uninstall_plugin(self, name: str):
name = name.upper() name = name.upper()
if name not in self.plugins: if name not in self.plugins:
return False, "插件不存在" return False, "插件不存在"
@@ -276,6 +295,7 @@ class PluginManager:
dirname = self.plugins[name].path dirname = self.plugins[name].path
try: try:
import shutil import shutil
shutil.rmtree(dirname) shutil.rmtree(dirname)
rawname = self.plugins[name].name rawname = self.plugins[name].name
for event in self.listening_plugins: for event in self.listening_plugins:
@@ -288,4 +308,4 @@ class PluginManager:
return True, "卸载插件成功" return True, "卸载插件成功"
except Exception as e: except Exception as e:
logger.error("Failed to uninstall plugin, {}".format(e)) logger.error("Failed to uninstall plugin, {}".format(e))
return False, "卸载插件失败,请手动删除文件夹完成卸载,"+str(e) return False, "卸载插件失败,请手动删除文件夹完成卸载," + str(e)

View File

@@ -1 +1 @@
from .role import * from .role import *

View File

@@ -2,17 +2,18 @@
import json import json
import os import os
import plugins
from bridge.bridge import Bridge from bridge.bridge import Bridge
from bridge.context import ContextType from bridge.context import ContextType
from bridge.reply import Reply, ReplyType from bridge.reply import Reply, ReplyType
from common import const from common import const
from config import conf
import plugins
from plugins import *
from common.log import logger from common.log import logger
from config import conf
from plugins import *
class RolePlay(): class RolePlay:
def __init__(self, bot, sessionid, desc, wrapper=None): def __init__(self, bot, sessionid, desc, wrapper=None):
self.bot = bot self.bot = bot
self.sessionid = sessionid self.sessionid = sessionid
@@ -25,12 +26,20 @@ class RolePlay():
def action(self, user_action): def action(self, user_action):
session = self.bot.sessions.build_session(self.sessionid) session = self.bot.sessions.build_session(self.sessionid)
if session.system_prompt != self.desc: # 目前没有触发session过期事件这里先简单判断然后重置 if session.system_prompt != self.desc: # 目前没有触发session过期事件这里先简单判断然后重置
session.set_system_prompt(self.desc) session.set_system_prompt(self.desc)
prompt = self.wrapper % user_action prompt = self.wrapper % user_action
return prompt return prompt
@plugins.register(name="Role", desire_priority=0, namecn="角色扮演", desc="为你的Bot设置预设角色", version="1.0", author="lanvent")
@plugins.register(
name="Role",
desire_priority=0,
namecn="角色扮演",
desc="为你的Bot设置预设角色",
version="1.0",
author="lanvent",
)
class Role(Plugin): class Role(Plugin):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@@ -39,7 +48,7 @@ class Role(Plugin):
try: try:
with open(config_path, "r", encoding="utf-8") as f: with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f) config = json.load(f)
self.tags = { tag:(desc,[]) for tag,desc in config["tags"].items()} self.tags = {tag: (desc, []) for tag, desc in config["tags"].items()}
self.roles = {} self.roles = {}
for role in config["roles"]: for role in config["roles"]:
self.roles[role["title"].lower()] = role self.roles[role["title"].lower()] = role
@@ -65,7 +74,7 @@ class Role(Plugin):
logger.warn("[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .") logger.warn("[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .")
raise e raise e
def get_role(self, name, find_closest=True, min_sim = 0.35): def get_role(self, name, find_closest=True, min_sim=0.35):
name = name.lower() name = name.lower()
found_role = None found_role = None
if name in self.roles: if name in self.roles:
@@ -75,6 +84,7 @@ class Role(Plugin):
def str_simularity(a, b): def str_simularity(a, b):
return difflib.SequenceMatcher(None, a, b).ratio() return difflib.SequenceMatcher(None, a, b).ratio()
max_sim = min_sim max_sim = min_sim
max_role = None max_role = None
for role in self.roles: for role in self.roles:
@@ -86,25 +96,24 @@ class Role(Plugin):
return found_role return found_role
def on_handle_context(self, e_context: EventContext): def on_handle_context(self, e_context: EventContext):
if e_context["context"].type != ContextType.TEXT:
if e_context['context'].type != ContextType.TEXT:
return return
bottype = Bridge().get_bot_type("chat") bottype = Bridge().get_bot_type("chat")
if bottype not in (const.CHATGPT, const.OPEN_AI): if bottype not in (const.CHATGPT, const.OPEN_AI):
return return
bot = Bridge().get_bot("chat") bot = Bridge().get_bot("chat")
content = e_context['context'].content[:] content = e_context["context"].content[:]
clist = e_context['context'].content.split(maxsplit=1) clist = e_context["context"].content.split(maxsplit=1)
desckey = None desckey = None
customize = False customize = False
sessionid = e_context['context']['session_id'] sessionid = e_context["context"]["session_id"]
trigger_prefix = conf().get('plugin_trigger_prefix', "$") trigger_prefix = conf().get("plugin_trigger_prefix", "$")
if clist[0] == f"{trigger_prefix}停止扮演": if clist[0] == f"{trigger_prefix}停止扮演":
if sessionid in self.roleplays: if sessionid in self.roleplays:
self.roleplays[sessionid].reset() self.roleplays[sessionid].reset()
del self.roleplays[sessionid] del self.roleplays[sessionid]
reply = Reply(ReplyType.INFO, "角色扮演结束!") reply = Reply(ReplyType.INFO, "角色扮演结束!")
e_context['reply'] = reply e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
return return
elif clist[0] == f"{trigger_prefix}角色": elif clist[0] == f"{trigger_prefix}角色":
@@ -114,10 +123,10 @@ class Role(Plugin):
elif clist[0] == f"{trigger_prefix}设定扮演": elif clist[0] == f"{trigger_prefix}设定扮演":
customize = True customize = True
elif clist[0] == f"{trigger_prefix}角色类型": elif clist[0] == f"{trigger_prefix}角色类型":
if len(clist) >1: if len(clist) > 1:
tag = clist[1].strip() tag = clist[1].strip()
help_text = "角色列表:\n" help_text = "角色列表:\n"
for key,value in self.tags.items(): for key, value in self.tags.items():
if value[0] == tag: if value[0] == tag:
tag = key tag = key
break break
@@ -130,13 +139,13 @@ class Role(Plugin):
else: else:
help_text = f"未知角色类型。\n" help_text = f"未知角色类型。\n"
help_text += "目前的角色类型有: \n" help_text += "目前的角色类型有: \n"
help_text += "".join([self.tags[tag][0] for tag in self.tags])+"\n" help_text += "".join([self.tags[tag][0] for tag in self.tags]) + "\n"
else: else:
help_text = f"请输入角色类型。\n" help_text = f"请输入角色类型。\n"
help_text += "目前的角色类型有: \n" help_text += "目前的角色类型有: \n"
help_text += "".join([self.tags[tag][0] for tag in self.tags])+"\n" help_text += "".join([self.tags[tag][0] for tag in self.tags]) + "\n"
reply = Reply(ReplyType.INFO, help_text) reply = Reply(ReplyType.INFO, help_text)
e_context['reply'] = reply e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
return return
elif sessionid not in self.roleplays: elif sessionid not in self.roleplays:
@@ -145,42 +154,47 @@ class Role(Plugin):
if desckey is not None: if desckey is not None:
if len(clist) == 1 or (len(clist) > 1 and clist[1].lower() in ["help", "帮助"]): if len(clist) == 1 or (len(clist) > 1 and clist[1].lower() in ["help", "帮助"]):
reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True)) reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True))
e_context['reply'] = reply e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
return return
role = self.get_role(clist[1]) role = self.get_role(clist[1])
if role is None: if role is None:
reply = Reply(ReplyType.ERROR, "角色不存在") reply = Reply(ReplyType.ERROR, "角色不存在")
e_context['reply'] = reply e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
return return
else: else:
self.roleplays[sessionid] = RolePlay(bot, sessionid, self.roles[role][desckey], self.roles[role].get("wrapper","%s")) self.roleplays[sessionid] = RolePlay(
reply = Reply(ReplyType.INFO, f"预设角色为 {role}:\n"+self.roles[role][desckey]) bot,
e_context['reply'] = reply sessionid,
self.roles[role][desckey],
self.roles[role].get("wrapper", "%s"),
)
reply = Reply(ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey])
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
elif customize == True: elif customize == True:
self.roleplays[sessionid] = RolePlay(bot, sessionid, clist[1], "%s") self.roleplays[sessionid] = RolePlay(bot, sessionid, clist[1], "%s")
reply = Reply(ReplyType.INFO, f"角色设定为:\n{clist[1]}") reply = Reply(ReplyType.INFO, f"角色设定为:\n{clist[1]}")
e_context['reply'] = reply e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
else: else:
prompt = self.roleplays[sessionid].action(content) prompt = self.roleplays[sessionid].action(content)
e_context['context'].type = ContextType.TEXT e_context["context"].type = ContextType.TEXT
e_context['context'].content = prompt e_context["context"].content = prompt
e_context.action = EventAction.BREAK e_context.action = EventAction.BREAK
def get_help_text(self, verbose=False, **kwargs): def get_help_text(self, verbose=False, **kwargs):
help_text = "让机器人扮演不同的角色。\n" help_text = "让机器人扮演不同的角色。\n"
if not verbose: if not verbose:
return help_text return help_text
trigger_prefix = conf().get('plugin_trigger_prefix', "$") trigger_prefix = conf().get("plugin_trigger_prefix", "$")
help_text = f"使用方法:\n{trigger_prefix}角色"+" 预设角色名: 设定角色为{预设角色名}\n"+f"{trigger_prefix}role"+" 预设角色名: 同上,但使用英文设定。\n" help_text = f"使用方法:\n{trigger_prefix}角色" + " 预设角色名: 设定角色为{预设角色名}\n" + f"{trigger_prefix}role" + " 预设角色名: 同上,但使用英文设定。\n"
help_text += f"{trigger_prefix}设定扮演"+" 角色设定: 设定自定义角色人设为{角色设定}\n" help_text += f"{trigger_prefix}设定扮演" + " 角色设定: 设定自定义角色人设为{角色设定}\n"
help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n" help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n"
help_text += f"{trigger_prefix}角色类型"+" 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n" help_text += f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n"
help_text += "\n目前的角色类型有: \n" help_text += "\n目前的角色类型有: \n"
help_text += "".join([self.tags[tag][0] for tag in self.tags])+"\n" help_text += "".join([self.tags[tag][0] for tag in self.tags]) + "\n"
help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n" help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n"
help_text += f"{trigger_prefix}角色类型 所有\n" help_text += f"{trigger_prefix}角色类型 所有\n"
help_text += f"{trigger_prefix}停止扮演\n" help_text += f"{trigger_prefix}停止扮演\n"

View File

@@ -428,4 +428,4 @@
] ]
} }
] ]
} }

View File

@@ -1,16 +1,16 @@
{ {
"repo": { "repo": {
"sdwebui": { "sdwebui": {
"url": "https://github.com/lanvent/plugin_sdwebui.git", "url": "https://github.com/lanvent/plugin_sdwebui.git",
"desc": "利用stable-diffusion画图的插件" "desc": "利用stable-diffusion画图的插件"
}, },
"replicate": { "replicate": {
"url": "https://github.com/lanvent/plugin_replicate.git", "url": "https://github.com/lanvent/plugin_replicate.git",
"desc": "利用replicate api画图的插件" "desc": "利用replicate api画图的插件"
}, },
"summary": { "summary": {
"url": "https://github.com/lanvent/plugin_summary.git", "url": "https://github.com/lanvent/plugin_summary.git",
"desc": "总结聊天记录的插件" "desc": "总结聊天记录的插件"
}
} }
} }
}

View File

@@ -1,17 +1,29 @@
## 插件描述 ## 插件描述
一个能让chatgpt联网搜索数字运算的插件将赋予强大且丰富的扩展能力 一个能让chatgpt联网搜索数字运算的插件将赋予强大且丰富的扩展能力
使用该插件需在机器人回复你的前提下,在对话内容前加$tool仅输入$tool将返回tool插件帮助信息用于测试插件是否加载成功 使用该插件需在机器人回复你的前提下,在对话内容前加$tool仅输入$tool将返回tool插件帮助信息用于测试插件是否加载成功
### 本插件所有工具同步存放至专用仓库:[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub) ### 本插件所有工具同步存放至专用仓库:[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)
## 使用说明 ## 使用说明
使用该插件后将默认使用4个工具, 无需额外配置长期生效: 使用该插件后将默认使用4个工具, 无需额外配置长期生效:
### 1. python ### 1. python
###### python解释器使用它来解释执行python指令可以配合你想要chatgpt生成的代码输出结果或执行事务 ###### python解释器使用它来解释执行python指令可以配合你想要chatgpt生成的代码输出结果或执行事务
### 2. url-get ### 2. 访问网页的工具汇总(默认url-get)
#### 2.1 url-get
###### 往往用来获取某个网站具体内容,结果可能会被反爬策略影响 ###### 往往用来获取某个网站具体内容,结果可能会被反爬策略影响
#### 2.2 browser
###### 浏览器功能与2.1类似,但能更好模拟,不会被识别为爬虫影响获取网站内容
> 注1url-get默认配置、browser需额外配置browser依赖google-chrome你需要提前安装好
> 注2browser默认使用summary tool 分段总结长文本信息tokens可能会大量消耗
这是debian端安装google-chrome教程其他系统请执行查找
> https://www.linuxjournal.com/content/how-can-you-install-google-browser-debian
### 3. terminal ### 3. terminal
###### 在你运行的电脑里执行shell命令可以配合你想要chatgpt生成的代码使用给予自然语言控制手段 ###### 在你运行的电脑里执行shell命令可以配合你想要chatgpt生成的代码使用给予自然语言控制手段
@@ -23,63 +35,99 @@
> meteo调优记录https://github.com/zhayujie/chatgpt-on-wechat/issues/776#issuecomment-1500771334 > meteo调优记录https://github.com/zhayujie/chatgpt-on-wechat/issues/776#issuecomment-1500771334
## 使用本插件对话prompt技巧 ## 使用本插件对话prompt技巧
### 1. 有指引的询问 ### 1. 有指引的询问
#### 例如: #### 例如:
- 总结这个链接的内容 https://github.com/goldfishh/chatgpt-tool-hub - 总结这个链接的内容 https://github.com/goldfishh/chatgpt-tool-hub
- 使用Terminal执行curl cip.cc - 使用Terminal执行curl cip.cc
- 使用python查询今天日期 - 使用python查询今天日期
### 2. 使用搜索引擎工具 ### 2. 使用搜索引擎工具
- 如果有搜索工具就能让chatgpt获取到你的未传达清楚的上下文信息比如chatgpt不知道你的地理位置现在时间等所以无法查询到天气 - 如果有搜索工具就能让chatgpt获取到你的未传达清楚的上下文信息比如chatgpt不知道你的地理位置现在时间等所以无法查询到天气
## 其他工具 ## 其他工具
### 5. wikipedia ### 5. wikipedia
###### 可以回答你想要知道确切的人事物 ###### 可以回答你想要知道确切的人事物
### 6. news * ### 6. 新闻类工具
#### 6.1. news-api *
###### 从全球 80,000 多个信息源中获取当前和历史新闻文章 ###### 从全球 80,000 多个信息源中获取当前和历史新闻文章
### 7. morning-news * #### 6.2. morning-news *
###### 每日60秒早报每天凌晨一点更新本工具使用了[alapi-每日60秒早报](https://alapi.cn/api/view/93) ###### 每日60秒早报每天凌晨一点更新本工具使用了[alapi-每日60秒早报](https://alapi.cn/api/view/93)
> 该tool每天返回内容相同 > 该tool每天返回内容相同
### 8. bing-search * #### 6.3. finance-news
###### 获取实时的金融财政新闻
> 该工具需要解决browser tool 的google-chrome依赖安装
### 7. bing-search *
###### bing搜索引擎从此你不用再烦恼搜索要用哪些关键词 ###### bing搜索引擎从此你不用再烦恼搜索要用哪些关键词
### 9. wolfram-alpha * ### 8. wolfram-alpha *
###### 知识搜索引擎、科学问答系统,常用于专业学科计算 ###### 知识搜索引擎、科学问答系统,常用于专业学科计算
### 10. google-search * ### 9. google-search *
###### google搜索引擎申请流程较bing-search繁琐 ###### google搜索引擎申请流程较bing-search繁琐
###### 注1带*工具需要获取api-key才能使用部分工具需要外网支持
### 10. arxiv(dev 开发中)
###### 用于查找论文
### 11. debug(dev 开发中目前没有接入wechat)
###### 当bot遇到无法确定的信息时将会向你寻求帮助的工具
### 12. summary
###### 总结工具,该工具必须输入一个本地文件的绝对路径
> 该工具目前是和其他工具配合使用,暂未测试单独使用效果
### 13. image2text
###### 将图片转换成文字底层调用imageCaption模型该工具必须输入一个本地文件的绝对路径
### 14. searxng-search *
###### 一个私有化的搜索引擎工具
> 安装教程https://docs.searxng.org/admin/installation.html
---
###### 注1带*工具需要获取api-key才能使用(在config.json内的kwargs添加项),部分工具需要外网支持
#### [申请方法](https://github.com/goldfishh/chatgpt-tool-hub/blob/master/docs/apply_optional_tool.md) #### [申请方法](https://github.com/goldfishh/chatgpt-tool-hub/blob/master/docs/apply_optional_tool.md)
## config.json 配置说明 ## config.json 配置说明
###### 默认工具无需配置,其它工具需手动配置,一个例子: ###### 默认工具无需配置,其它工具需手动配置,一个例子:
```json ```json
{ {
"tools": ["wikipedia"], // 填入你想用到的额外工具名 "tools": ["wikipedia", "你想要添加的其他工具"], // 填入你想用到的额外工具名
"kwargs": { "kwargs": {
"request_timeout": 60, // openai接口超时时间 "debug": true, // 当你遇到问题求助时,需要配置
"request_timeout": 120, // openai接口超时时间
"no_default": false, // 是否不使用默认的4个工具 "no_default": false, // 是否不使用默认的4个工具
"OPTIONAL_API_NAME": "OPTIONAL_API_KEY" // 带*工具需要申请api-key在这里填入api_name参考前述`申请方法` // 带*工具需要申请api-key在这里填入api_name参考前述`申请方法`
} }
} }
``` ```
config.json文件非必须未创建仍可使用本tool带*工具需在kwargs填入对应api-key键值对 config.json文件非必须未创建仍可使用本tool带*工具需在kwargs填入对应api-key键值对
- `tools`:本插件初始化时加载的工具, 目前可选集:["wikipedia", "wolfram-alpha", "bing-search", "google-search", "news", "morning-news"] & 默认工具除wikipedia工具之外均需要申请api-key - `tools`:本插件初始化时加载的工具, 目前可选集:["wikipedia", "wolfram-alpha", "bing-search", "google-search", "news"] & 默认工具除wikipedia工具之外均需要申请api-key
- `kwargs`:工具执行时的配置,一般在这里存放**api-key**,或环境配置 - `kwargs`:工具执行时的配置,一般在这里存放**api-key**,或环境配置
- `debug`: 输出chatgpt-tool-hub额外信息用于调试
- `request_timeout`: 访问openai接口的超时时间默认与wechat-on-chatgpt配置一致可单独配置 - `request_timeout`: 访问openai接口的超时时间默认与wechat-on-chatgpt配置一致可单独配置
- `no_default`: 用于配置默认加载4个工具的行为如果为true则仅使用tools列表工具不加载默认工具 - `no_default`: 用于配置默认加载4个工具的行为如果为true则仅使用tools列表工具不加载默认工具
- `top_k_results`: 控制所有有关搜索的工具返回条目数数字越高则参考信息越多但无用信息可能干扰判断该值一般为2 - `top_k_results`: 控制所有有关搜索的工具返回条目数数字越高则参考信息越多但无用信息可能干扰判断该值一般为2
- `model_name`: 用于控制tool插件底层使用的llm模型目前暂未测试3.5以外的模型,一般保持默认 - `model_name`: 用于控制tool插件底层使用的llm模型目前暂未测试3.5以外的模型,一般保持默认
---
## 备注 ## 备注
- 强烈建议申请搜索工具搭配使用推荐bing-search - 强烈建议申请搜索工具搭配使用推荐bing-search
- 虽然我会有意加入一些限制,但请不要使用本插件做危害他人的事情,请提前了解清楚某些内容是否会违反相关规定,建议提前做好过滤 - 虽然我会有意加入一些限制,但请不要使用本插件做危害他人的事情,请提前了解清楚某些内容是否会违反相关规定,建议提前做好过滤

View File

@@ -1 +1 @@
from .tool import * from .tool import *

View File

@@ -1,8 +1,13 @@
{ {
"tools": ["python", "url-get", "terminal", "meteo-weather"], "tools": [
"python",
"url-get",
"terminal",
"meteo-weather"
],
"kwargs": { "kwargs": {
"top_k_results": 2, "top_k_results": 2,
"no_default": false, "no_default": false,
"model_name": "gpt-3.5-turbo" "model_name": "gpt-3.5-turbo"
} }
} }

View File

@@ -1,9 +1,10 @@
import json import json
import os import os
from chatgpt_tool_hub.apps import load_app from chatgpt_tool_hub.apps import AppFactory
from chatgpt_tool_hub.apps.app import App from chatgpt_tool_hub.apps.app import App
from chatgpt_tool_hub.tools.all_tool_list import get_all_tool_names from chatgpt_tool_hub.tools.all_tool_list import get_all_tool_names
import plugins import plugins
from bridge.bridge import Bridge from bridge.bridge import Bridge
from bridge.context import ContextType from bridge.context import ContextType
@@ -14,7 +15,13 @@ from config import conf
from plugins import * from plugins import *
@plugins.register(name="tool", desc="Arming your ChatGPT bot with various tools", version="0.3", author="goldfishh", desire_priority=0) @plugins.register(
name="tool",
desc="Arming your ChatGPT bot with various tools",
version="0.4",
author="goldfishh",
desire_priority=0,
)
class Tool(Plugin): class Tool(Plugin):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@@ -28,22 +35,26 @@ class Tool(Plugin):
help_text = "这是一个能让chatgpt联网搜索数字运算的插件将赋予强大且丰富的扩展能力。" help_text = "这是一个能让chatgpt联网搜索数字运算的插件将赋予强大且丰富的扩展能力。"
if not verbose: if not verbose:
return help_text return help_text
trigger_prefix = conf().get('plugin_trigger_prefix', "$") trigger_prefix = conf().get("plugin_trigger_prefix", "$")
help_text += "使用说明:\n" help_text += "使用说明:\n"
help_text += f"{trigger_prefix}tool "+"命令: 根据给出的{命令}使用一些可用工具尽力为你得到结果。\n" help_text += f"{trigger_prefix}tool " + "命令: 根据给出的{命令}使用一些可用工具尽力为你得到结果。\n"
help_text += f"{trigger_prefix}tool reset: 重置工具。\n" help_text += f"{trigger_prefix}tool reset: 重置工具。\n"
return help_text return help_text
def on_handle_context(self, e_context: EventContext): def on_handle_context(self, e_context: EventContext):
if e_context['context'].type != ContextType.TEXT: if e_context["context"].type != ContextType.TEXT:
return return
# 暂时不支持未来扩展的bot # 暂时不支持未来扩展的bot
if Bridge().get_bot_type("chat") not in (const.CHATGPT, const.OPEN_AI, const.CHATGPTONAZURE): if Bridge().get_bot_type("chat") not in (
const.CHATGPT,
const.OPEN_AI,
const.CHATGPTONAZURE,
):
return return
content = e_context['context'].content content = e_context["context"].content
content_list = e_context['context'].content.split(maxsplit=1) content_list = e_context["context"].content.split(maxsplit=1)
if not content or len(content_list) < 1: if not content or len(content_list) < 1:
e_context.action = EventAction.CONTINUE e_context.action = EventAction.CONTINUE
@@ -52,13 +63,13 @@ class Tool(Plugin):
logger.debug("[tool] on_handle_context. content: %s" % content) logger.debug("[tool] on_handle_context. content: %s" % content)
reply = Reply() reply = Reply()
reply.type = ReplyType.TEXT reply.type = ReplyType.TEXT
trigger_prefix = conf().get('plugin_trigger_prefix', "$") trigger_prefix = conf().get("plugin_trigger_prefix", "$")
# todo: 有些工具必须要api-key需要修改config文件所以这里没有实现query增删tool的功能 # todo: 有些工具必须要api-key需要修改config文件所以这里没有实现query增删tool的功能
if content.startswith(f"{trigger_prefix}tool"): if content.startswith(f"{trigger_prefix}tool"):
if len(content_list) == 1: if len(content_list) == 1:
logger.debug("[tool]: get help") logger.debug("[tool]: get help")
reply.content = self.get_help_text() reply.content = self.get_help_text()
e_context['reply'] = reply e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
return return
elif len(content_list) > 1: elif len(content_list) > 1:
@@ -66,12 +77,12 @@ class Tool(Plugin):
logger.debug("[tool]: reset config") logger.debug("[tool]: reset config")
self.app = self._reset_app() self.app = self._reset_app()
reply.content = "重置工具成功" reply.content = "重置工具成功"
e_context['reply'] = reply e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
return return
elif content_list[1].startswith("reset"): elif content_list[1].startswith("reset"):
logger.debug("[tool]: remind") logger.debug("[tool]: remind")
e_context['context'].content = "请你随机用一种聊天风格提醒用户如果想重置tool插件reset之后不要加任何字符" e_context["context"].content = "请你随机用一种聊天风格提醒用户如果想重置tool插件reset之后不要加任何字符"
e_context.action = EventAction.BREAK e_context.action = EventAction.BREAK
return return
@@ -80,34 +91,31 @@ class Tool(Plugin):
# Don't modify bot name # Don't modify bot name
all_sessions = Bridge().get_bot("chat").sessions all_sessions = Bridge().get_bot("chat").sessions
user_session = all_sessions.session_query(query, e_context['context']['session_id']).messages user_session = all_sessions.session_query(query, e_context["context"]["session_id"]).messages
# chatgpt-tool-hub will reply you with many tools # chatgpt-tool-hub will reply you with many tools
logger.debug("[tool]: just-go") logger.debug("[tool]: just-go")
try: try:
_reply = self.app.ask(query, user_session) _reply = self.app.ask(query, user_session)
e_context.action = EventAction.BREAK_PASS e_context.action = EventAction.BREAK_PASS
all_sessions.session_reply(_reply, e_context['context']['session_id']) all_sessions.session_reply(_reply, e_context["context"]["session_id"])
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
logger.error(str(e)) logger.error(str(e))
e_context['context'].content = "请你随机用一种聊天风格提醒用户这个问题tool插件暂时无法处理" e_context["context"].content = "请你随机用一种聊天风格提醒用户这个问题tool插件暂时无法处理"
reply.type = ReplyType.ERROR reply.type = ReplyType.ERROR
e_context.action = EventAction.BREAK e_context.action = EventAction.BREAK
return return
reply.content = _reply reply.content = _reply
e_context['reply'] = reply e_context["reply"] = reply
return return
def _read_json(self) -> dict: def _read_json(self) -> dict:
curdir = os.path.dirname(__file__) curdir = os.path.dirname(__file__)
config_path = os.path.join(curdir, "config.json") config_path = os.path.join(curdir, "config.json")
tool_config = { tool_config = {"tools": [], "kwargs": {}}
"tools": [],
"kwargs": {}
}
if not os.path.exists(config_path): if not os.path.exists(config_path):
return tool_config return tool_config
else: else:
@@ -117,15 +125,17 @@ class Tool(Plugin):
def _build_tool_kwargs(self, kwargs: dict): def _build_tool_kwargs(self, kwargs: dict):
tool_model_name = kwargs.get("model_name") tool_model_name = kwargs.get("model_name")
request_timeout = kwargs.get("request_timeout")
return { return {
"debug": kwargs.get("debug", False),
"openai_api_key": conf().get("open_ai_api_key", ""), "openai_api_key": conf().get("open_ai_api_key", ""),
"proxy": conf().get("proxy", ""), "proxy": conf().get("proxy", ""),
"request_timeout": conf().get("request_timeout", 60), "request_timeout": request_timeout if request_timeout else conf().get("request_timeout", 120),
# note: 目前tool暂未对其他模型测试但这里仍对配置来源做了优先级区分一般插件配置可覆盖全局配置 # note: 目前tool暂未对其他模型测试但这里仍对配置来源做了优先级区分一般插件配置可覆盖全局配置
"model_name": tool_model_name if tool_model_name else conf().get("model", "gpt-3.5-turbo"), "model_name": tool_model_name if tool_model_name else conf().get("model", "gpt-3.5-turbo"),
"no_default": kwargs.get("no_default", False), "no_default": kwargs.get("no_default", False),
"top_k_results": kwargs.get("top_k_results", 2), "top_k_results": kwargs.get("top_k_results", 3),
# for news tool # for news tool
"news_api_key": kwargs.get("news_api_key", ""), "news_api_key": kwargs.get("news_api_key", ""),
# for bing-search tool # for bing-search tool
@@ -141,8 +151,6 @@ class Tool(Plugin):
"zaobao_api_key": kwargs.get("zaobao_api_key", ""), "zaobao_api_key": kwargs.get("zaobao_api_key", ""),
# for visual_dl tool # for visual_dl tool
"cuda_device": kwargs.get("cuda_device", "cpu"), "cuda_device": kwargs.get("cuda_device", "cpu"),
# for browser tool
"phantomjs_exec_path": kwargs.get("phantomjs_exec_path", ""),
} }
def _filter_tool_list(self, tool_list: list): def _filter_tool_list(self, tool_list: list):
@@ -156,8 +164,12 @@ class Tool(Plugin):
def _reset_app(self) -> App: def _reset_app(self) -> App:
tool_config = self._read_json() tool_config = self._read_json()
app_kwargs = self._build_tool_kwargs(tool_config.get("kwargs", {}))
app = AppFactory()
app.init_env(**app_kwargs)
# filter not support tool # filter not support tool
tool_list = self._filter_tool_list(tool_config.get("tools", [])) tool_list = self._filter_tool_list(tool_config.get("tools", []))
return load_app(tools_list=tool_list, **self._build_tool_kwargs(tool_config.get("kwargs", {}))) return app.create_app(tools_list=tool_list, **app_kwargs)

8
pyproject.toml Normal file
View File

@@ -0,0 +1,8 @@
[tool.black]
line-length = 176
target-version = ['py37']
include = '\.pyi?$'
extend-exclude = '.+/(dist|.venv|venv|build|lib)/.+'
[tool.isort]
profile = "black"

View File

@@ -18,7 +18,9 @@ pysilk_mod>=1.6.0 # needed by send voice
# wechatmp # wechatmp
web.py web.py
wechatpy
# chatgpt-tool-hub plugin # chatgpt-tool-hub plugin
--extra-index-url https://pypi.python.org/simple --extra-index-url https://pypi.python.org/simple
chatgpt_tool_hub>=0.3.9 chatgpt_tool_hub>=0.4.1

View File

@@ -4,3 +4,4 @@ PyQRCode>=1.2.1
qrcode>=7.4.2 qrcode>=7.4.2
requests>=2.28.2 requests>=2.28.2
chardet>=5.1.0 chardet>=5.1.0
pre-commit

View File

@@ -8,7 +8,7 @@ echo $BASE_DIR
# check the nohup.out log output file # check the nohup.out log output file
if [ ! -f "${BASE_DIR}/nohup.out" ]; then if [ ! -f "${BASE_DIR}/nohup.out" ]; then
touch "${BASE_DIR}/nohup.out" touch "${BASE_DIR}/nohup.out"
echo "create file ${BASE_DIR}/nohup.out" echo "create file ${BASE_DIR}/nohup.out"
fi fi
nohup python3 "${BASE_DIR}/app.py" & tail -f "${BASE_DIR}/nohup.out" nohup python3 "${BASE_DIR}/app.py" & tail -f "${BASE_DIR}/nohup.out"

View File

@@ -7,7 +7,7 @@ echo $BASE_DIR
# check the nohup.out log output file # check the nohup.out log output file
if [ ! -f "${BASE_DIR}/nohup.out" ]; then if [ ! -f "${BASE_DIR}/nohup.out" ]; then
echo "No file ${BASE_DIR}/nohup.out" echo "No file ${BASE_DIR}/nohup.out"
exit -1; exit -1;
fi fi

View File

@@ -1,9 +1,12 @@
import shutil import shutil
import wave import wave
import pysilk import pysilk
from pydub import AudioSegment from pydub import AudioSegment
sil_supports=[8000, 12000, 16000, 24000, 32000, 44100, 48000] # slk转wav时支持的采样率 sil_supports = [8000, 12000, 16000, 24000, 32000, 44100, 48000] # slk转wav时支持的采样率
def find_closest_sil_supports(sample_rate): def find_closest_sil_supports(sample_rate):
""" """
找到最接近的支持的采样率 找到最接近的支持的采样率
@@ -19,6 +22,7 @@ def find_closest_sil_supports(sample_rate):
mindiff = diff mindiff = diff
return closest return closest
def get_pcm_from_wav(wav_path): def get_pcm_from_wav(wav_path):
""" """
从 wav 文件中读取 pcm 从 wav 文件中读取 pcm
@@ -29,72 +33,53 @@ def get_pcm_from_wav(wav_path):
wav = wave.open(wav_path, "rb") wav = wave.open(wav_path, "rb")
return wav.readframes(wav.getnframes()) return wav.readframes(wav.getnframes())
def any_to_mp3(any_path, mp3_path):
"""
把任意格式转成mp3文件
"""
if any_path.endswith(".mp3"):
shutil.copy2(any_path, mp3_path)
return
if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
sil_to_wav(any_path, any_path)
any_path = mp3_path
audio = AudioSegment.from_file(any_path)
audio.export(mp3_path, format="mp3")
def any_to_wav(any_path, wav_path): def any_to_wav(any_path, wav_path):
""" """
把任意格式转成wav文件 把任意格式转成wav文件
""" """
if any_path.endswith('.wav'): if any_path.endswith(".wav"):
shutil.copy2(any_path, wav_path) shutil.copy2(any_path, wav_path)
return return
if any_path.endswith('.sil') or any_path.endswith('.silk') or any_path.endswith('.slk'): if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
return sil_to_wav(any_path, wav_path) return sil_to_wav(any_path, wav_path)
audio = AudioSegment.from_file(any_path) audio = AudioSegment.from_file(any_path)
audio.export(wav_path, format="wav") audio.export(wav_path, format="wav")
def any_to_sil(any_path, sil_path): def any_to_sil(any_path, sil_path):
""" """
把任意格式转成sil文件 把任意格式转成sil文件
""" """
if any_path.endswith('.sil') or any_path.endswith('.silk') or any_path.endswith('.slk'): if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
shutil.copy2(any_path, sil_path) shutil.copy2(any_path, sil_path)
return 10000 return 10000
if any_path.endswith('.wav'): audio = AudioSegment.from_file(any_path)
return pcm_to_sil(any_path, sil_path)
if any_path.endswith('.mp3'):
return mp3_to_sil(any_path, sil_path)
raise NotImplementedError("Not support file type: {}".format(any_path))
def mp3_to_wav(mp3_path, wav_path):
"""
把mp3格式转成pcm文件
"""
audio = AudioSegment.from_mp3(mp3_path)
audio.export(wav_path, format="wav")
def pcm_to_sil(pcm_path, silk_path):
"""
wav 文件转成 silk
return 声音长度,毫秒
"""
audio = AudioSegment.from_wav(pcm_path)
rate = find_closest_sil_supports(audio.frame_rate)
# Convert to PCM_s16
pcm_s16 = audio.set_sample_width(2)
pcm_s16 = pcm_s16.set_frame_rate(rate)
wav_data = pcm_s16.raw_data
silk_data = pysilk.encode(
wav_data, data_rate=rate, sample_rate=rate)
with open(silk_path, "wb") as f:
f.write(silk_data)
return audio.duration_seconds * 1000
def mp3_to_sil(mp3_path, silk_path):
"""
mp3 文件转成 silk
return 声音长度,毫秒
"""
audio = AudioSegment.from_mp3(mp3_path)
rate = find_closest_sil_supports(audio.frame_rate) rate = find_closest_sil_supports(audio.frame_rate)
# Convert to PCM_s16 # Convert to PCM_s16
pcm_s16 = audio.set_sample_width(2) pcm_s16 = audio.set_sample_width(2)
pcm_s16 = pcm_s16.set_frame_rate(rate) pcm_s16 = pcm_s16.set_frame_rate(rate)
wav_data = pcm_s16.raw_data wav_data = pcm_s16.raw_data
silk_data = pysilk.encode(wav_data, data_rate=rate, sample_rate=rate) silk_data = pysilk.encode(wav_data, data_rate=rate, sample_rate=rate)
# Save the silk file with open(sil_path, "wb") as f:
with open(silk_path, "wb") as f:
f.write(silk_data) f.write(silk_data)
return audio.duration_seconds * 1000 return audio.duration_seconds * 1000
def sil_to_wav(silk_path, wav_path, rate: int = 24000): def sil_to_wav(silk_path, wav_path, rate: int = 24000):
""" """
silk 文件转 wav silk 文件转 wav

View File

@@ -1,16 +1,18 @@
""" """
azure voice service azure voice service
""" """
import json import json
import os import os
import time import time
import azure.cognitiveservices.speech as speechsdk import azure.cognitiveservices.speech as speechsdk
from bridge.reply import Reply, ReplyType from bridge.reply import Reply, ReplyType
from common.log import logger from common.log import logger
from common.tmp_dir import TmpDir from common.tmp_dir import TmpDir
from voice.voice import Voice
from config import conf from config import conf
from voice.voice import Voice
""" """
Azure voice Azure voice
主目录设置文件中需填写azure_voice_api_key和azure_voice_region 主目录设置文件中需填写azure_voice_api_key和azure_voice_region
@@ -19,22 +21,25 @@ Azure voice
""" """
class AzureVoice(Voice):
class AzureVoice(Voice):
def __init__(self): def __init__(self):
try: try:
curdir = os.path.dirname(__file__) curdir = os.path.dirname(__file__)
config_path = os.path.join(curdir, "config.json") config_path = os.path.join(curdir, "config.json")
config = None config = None
if not os.path.exists(config_path): #如果没有配置文件,创建本地配置文件 if not os.path.exists(config_path): # 如果没有配置文件,创建本地配置文件
config = { "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural", "speech_recognition_language": "zh-CN"} config = {
"speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural",
"speech_recognition_language": "zh-CN",
}
with open(config_path, "w") as fw: with open(config_path, "w") as fw:
json.dump(config, fw, indent=4) json.dump(config, fw, indent=4)
else: else:
with open(config_path, "r") as fr: with open(config_path, "r") as fr:
config = json.load(fr) config = json.load(fr)
self.api_key = conf().get('azure_voice_api_key') self.api_key = conf().get("azure_voice_api_key")
self.api_region = conf().get('azure_voice_region') self.api_region = conf().get("azure_voice_region")
self.speech_config = speechsdk.SpeechConfig(subscription=self.api_key, region=self.api_region) self.speech_config = speechsdk.SpeechConfig(subscription=self.api_key, region=self.api_region)
self.speech_config.speech_synthesis_voice_name = config["speech_synthesis_voice_name"] self.speech_config.speech_synthesis_voice_name = config["speech_synthesis_voice_name"]
self.speech_config.speech_recognition_language = config["speech_recognition_language"] self.speech_config.speech_recognition_language = config["speech_recognition_language"]
@@ -46,23 +51,22 @@ class AzureVoice(Voice):
speech_recognizer = speechsdk.SpeechRecognizer(speech_config=self.speech_config, audio_config=audio_config) speech_recognizer = speechsdk.SpeechRecognizer(speech_config=self.speech_config, audio_config=audio_config)
result = speech_recognizer.recognize_once() result = speech_recognizer.recognize_once()
if result.reason == speechsdk.ResultReason.RecognizedSpeech: if result.reason == speechsdk.ResultReason.RecognizedSpeech:
logger.info('[Azure] voiceToText voice file name={} text={}'.format(voice_file, result.text)) logger.info("[Azure] voiceToText voice file name={} text={}".format(voice_file, result.text))
reply = Reply(ReplyType.TEXT, result.text) reply = Reply(ReplyType.TEXT, result.text)
else: else:
logger.error('[Azure] voiceToText error, result={}'.format(result)) logger.error("[Azure] voiceToText error, result={}, canceldetails={}".format(result, result.cancellation_details))
reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败") reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败")
return reply return reply
def textToVoice(self, text): def textToVoice(self, text):
fileName = TmpDir().path() + 'reply-' + str(int(time.time())) + '.wav' fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav"
audio_config = speechsdk.AudioConfig(filename=fileName) audio_config = speechsdk.AudioConfig(filename=fileName)
speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=self.speech_config, audio_config=audio_config) speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=self.speech_config, audio_config=audio_config)
result = speech_synthesizer.speak_text(text) result = speech_synthesizer.speak_text(text)
if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted: if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
logger.info( logger.info("[Azure] textToVoice text={} voice file name={}".format(text, fileName))
'[Azure] textToVoice text={} voice file name={}'.format(text, fileName))
reply = Reply(ReplyType.VOICE, fileName) reply = Reply(ReplyType.VOICE, fileName)
else: else:
logger.error('[Azure] textToVoice error, result={}'.format(result)) logger.error("[Azure] textToVoice error, result={}, canceldetails={}".format(result, result.cancellation_details))
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败") reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
return reply return reply

View File

@@ -1,4 +1,4 @@
{ {
"speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural", "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural",
"speech_recognition_language": "zh-CN" "speech_recognition_language": "zh-CN"
} }

Some files were not shown because too many files have changed in this diff Show More