mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 09:48:22 +08:00
Compare commits
51 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a52f54d988 | ||
|
|
618c94edb8 | ||
|
|
eaf4e9174f | ||
|
|
4af2c7f3d7 | ||
|
|
361f599df0 | ||
|
|
ffe4ea5e4c | ||
|
|
9461e3e01a | ||
|
|
7c85c6f742 | ||
|
|
b5df6faadf | ||
|
|
7cefe2d825 | ||
|
|
350633b69b | ||
|
|
1cd6a71ce0 | ||
|
|
3a08b002a0 | ||
|
|
cca49da730 | ||
|
|
f6d370ad29 | ||
|
|
c9131b333b | ||
|
|
e44161bf42 | ||
|
|
a26189fb25 | ||
|
|
89dd8a1db6 | ||
|
|
650e0b4ad4 | ||
|
|
c60f0517fb | ||
|
|
0f8dc91a8b | ||
|
|
b58feb5d8e | ||
|
|
71c8043699 | ||
|
|
40264bc9cb | ||
|
|
a7772316f9 | ||
|
|
34209021c8 | ||
|
|
1e58c1ad2b | ||
|
|
8cea022ec5 | ||
|
|
f32f8aa08e | ||
|
|
0a7d6e4577 | ||
|
|
df4c1f0401 | ||
|
|
9a86a67984 | ||
|
|
a0cbe9c3e2 | ||
|
|
a83e5a9b65 | ||
|
|
de33911460 | ||
|
|
0be56e5b25 | ||
|
|
abcbb34b1c | ||
|
|
6a13dd04a3 | ||
|
|
f2e29f3f2e | ||
|
|
68361cddd2 | ||
|
|
6404332adc | ||
|
|
e060b6fea2 | ||
|
|
e8aae27ee9 | ||
|
|
7fb4f72b84 | ||
|
|
d4fc322101 | ||
|
|
8fa3da9ca5 | ||
|
|
68ef5aa3ae | ||
|
|
15e6cf850b | ||
|
|
f687b2b6f4 | ||
|
|
8ee7a48151 |
2
.flake8
2
.flake8
@@ -1,5 +1,5 @@
|
||||
[flake8]
|
||||
max-line-length = 88
|
||||
max-line-length = 176
|
||||
select = E303,W293,W291,W292,E305,E231,E302
|
||||
exclude =
|
||||
.tox,
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -13,6 +13,7 @@ plugins.json
|
||||
itchat.pkl
|
||||
*.log
|
||||
user_datas.pkl
|
||||
chatgpt_tool_hub/
|
||||
plugins/**/
|
||||
!plugins/bdunit
|
||||
!plugins/dungeon
|
||||
@@ -20,5 +21,7 @@ plugins/**/
|
||||
!plugins/godcmd
|
||||
!plugins/tool
|
||||
!plugins/banwords
|
||||
!plugins/banwords/**/
|
||||
!plugins/hello
|
||||
!plugins/role
|
||||
!plugins/role
|
||||
!plugins/keyword
|
||||
@@ -27,3 +27,4 @@ repos:
|
||||
rev: 6.0.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
exclude: '(\/|^)lib\/'
|
||||
@@ -22,7 +22,7 @@
|
||||
|
||||
# 更新日志
|
||||
|
||||
>**2023.04.05:** 支持微信个人号部署,兼容角色扮演等预设插件,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatmp/README.md)。(contributed by [@JS00000](https://github.com/JS00000) in [#686](https://github.com/zhayujie/chatgpt-on-wechat/pull/686))
|
||||
>**2023.04.05:** 支持微信公众号部署,兼容角色扮演等预设插件,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatmp/README.md)。(contributed by [@JS00000](https://github.com/JS00000) in [#686](https://github.com/zhayujie/chatgpt-on-wechat/pull/686))
|
||||
|
||||
>**2023.04.05:** 增加能让ChatGPT使用工具的`tool`插件,[使用文档](https://github.com/goldfishh/chatgpt-on-wechat/blob/master/plugins/tool/README.md)。工具相关issue可反馈至[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)。(contributed by [@goldfishh](https://github.com/goldfishh) in [#663](https://github.com/zhayujie/chatgpt-on-wechat/pull/663))
|
||||
|
||||
|
||||
@@ -10,10 +10,7 @@ from bridge.reply import Reply, ReplyType
|
||||
class BaiduUnitBot(Bot):
|
||||
def reply(self, query, context=None):
|
||||
token = self.get_token()
|
||||
url = (
|
||||
"https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token="
|
||||
+ token
|
||||
)
|
||||
url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + token
|
||||
post_data = (
|
||||
'{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"'
|
||||
+ query
|
||||
@@ -32,12 +29,7 @@ class BaiduUnitBot(Bot):
|
||||
def get_token(self):
|
||||
access_key = "YOUR_ACCESS_KEY"
|
||||
secret_key = "YOUR_SECRET_KEY"
|
||||
host = (
|
||||
"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id="
|
||||
+ access_key
|
||||
+ "&client_secret="
|
||||
+ secret_key
|
||||
)
|
||||
host = "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + access_key + "&client_secret=" + secret_key
|
||||
response = requests.get(host)
|
||||
if response:
|
||||
print(response.json())
|
||||
|
||||
@@ -30,23 +30,15 @@ class ChatGPTBot(Bot, OpenAIImage):
|
||||
if conf().get("rate_limit_chatgpt"):
|
||||
self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))
|
||||
|
||||
self.sessions = SessionManager(
|
||||
ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo"
|
||||
)
|
||||
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
self.args = {
|
||||
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
|
||||
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||
# "max_tokens":4096, # 回复最大的字符数
|
||||
"top_p": 1,
|
||||
"frequency_penalty": conf().get(
|
||||
"frequency_penalty", 0.0
|
||||
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty": conf().get(
|
||||
"presence_penalty", 0.0
|
||||
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"request_timeout": conf().get(
|
||||
"request_timeout", None
|
||||
), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
|
||||
}
|
||||
|
||||
@@ -87,15 +79,10 @@ class ChatGPTBot(Bot, OpenAIImage):
|
||||
reply_content["completion_tokens"],
|
||||
)
|
||||
)
|
||||
if (
|
||||
reply_content["completion_tokens"] == 0
|
||||
and len(reply_content["content"]) > 0
|
||||
):
|
||||
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.session_reply(
|
||||
reply_content["content"], session_id, reply_content["total_tokens"]
|
||||
)
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
@@ -126,9 +113,7 @@ class ChatGPTBot(Bot, OpenAIImage):
|
||||
if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
|
||||
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
|
||||
# if api_key == None, the default openai.api_key will be used
|
||||
response = openai.ChatCompletion.create(
|
||||
api_key=api_key, messages=session.messages, **self.args
|
||||
)
|
||||
response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **self.args)
|
||||
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
||||
return {
|
||||
"total_tokens": response["usage"]["total_tokens"],
|
||||
@@ -142,7 +127,7 @@ class ChatGPTBot(Bot, OpenAIImage):
|
||||
logger.warn("[CHATGPT] RateLimitError: {}".format(e))
|
||||
result["content"] = "提问太快啦,请休息一下再问我吧"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
time.sleep(20)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
logger.warn("[CHATGPT] Timeout: {}".format(e))
|
||||
result["content"] = "我没有收到你的消息"
|
||||
|
||||
@@ -25,9 +25,7 @@ class ChatGPTSession(Session):
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug(
|
||||
"Exception when counting tokens precisely for query: {}".format(e)
|
||||
)
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 2:
|
||||
self.messages.pop(1)
|
||||
@@ -39,16 +37,10 @@ class ChatGPTSession(Session):
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
break
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
|
||||
logger.warn(
|
||||
"user message exceed max_tokens. total_tokens={}".format(cur_tokens)
|
||||
)
|
||||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug(
|
||||
"max_tokens={}, total_tokens={}, len(messages)={}".format(
|
||||
max_tokens, cur_tokens, len(self.messages)
|
||||
)
|
||||
)
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
@@ -75,17 +67,13 @@ def num_tokens_from_messages(messages, model):
|
||||
elif model == "gpt-4":
|
||||
return num_tokens_from_messages(messages, model="gpt-4-0314")
|
||||
elif model == "gpt-3.5-turbo-0301":
|
||||
tokens_per_message = (
|
||||
4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
)
|
||||
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
tokens_per_name = -1 # if there's a name, the role is omitted
|
||||
elif model == "gpt-4-0314":
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
else:
|
||||
logger.warn(
|
||||
f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301."
|
||||
)
|
||||
logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.")
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
|
||||
num_tokens = 0
|
||||
for message in messages:
|
||||
|
||||
@@ -28,23 +28,15 @@ class OpenAIBot(Bot, OpenAIImage):
|
||||
if proxy:
|
||||
openai.proxy = proxy
|
||||
|
||||
self.sessions = SessionManager(
|
||||
OpenAISession, model=conf().get("model") or "text-davinci-003"
|
||||
)
|
||||
self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003")
|
||||
self.args = {
|
||||
"model": conf().get("model") or "text-davinci-003", # 对话模型的名称
|
||||
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||
"max_tokens": 1200, # 回复最大的字符数
|
||||
"top_p": 1,
|
||||
"frequency_penalty": conf().get(
|
||||
"frequency_penalty", 0.0
|
||||
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty": conf().get(
|
||||
"presence_penalty", 0.0
|
||||
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"request_timeout": conf().get(
|
||||
"request_timeout", None
|
||||
), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
|
||||
"stop": ["\n\n\n"],
|
||||
}
|
||||
@@ -71,17 +63,13 @@ class OpenAIBot(Bot, OpenAIImage):
|
||||
result["content"],
|
||||
)
|
||||
logger.debug(
|
||||
"[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
||||
str(session), session_id, reply_content, completion_tokens
|
||||
)
|
||||
"[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
|
||||
)
|
||||
|
||||
if total_tokens == 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content)
|
||||
else:
|
||||
self.sessions.session_reply(
|
||||
reply_content, session_id, total_tokens
|
||||
)
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens)
|
||||
reply = Reply(ReplyType.TEXT, reply_content)
|
||||
return reply
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
@@ -96,9 +84,7 @@ class OpenAIBot(Bot, OpenAIImage):
|
||||
def reply_text(self, session: OpenAISession, retry_count=0):
|
||||
try:
|
||||
response = openai.Completion.create(prompt=str(session), **self.args)
|
||||
res_content = (
|
||||
response.choices[0]["text"].strip().replace("<|endoftext|>", "")
|
||||
)
|
||||
res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "")
|
||||
total_tokens = response["usage"]["total_tokens"]
|
||||
completion_tokens = response["usage"]["completion_tokens"]
|
||||
logger.info("[OPEN_AI] reply={}".format(res_content))
|
||||
@@ -114,7 +100,7 @@ class OpenAIBot(Bot, OpenAIImage):
|
||||
logger.warn("[OPEN_AI] RateLimitError: {}".format(e))
|
||||
result["content"] = "提问太快啦,请休息一下再问我吧"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
time.sleep(20)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
logger.warn("[OPEN_AI] Timeout: {}".format(e))
|
||||
result["content"] = "我没有收到你的消息"
|
||||
|
||||
@@ -23,9 +23,7 @@ class OpenAIImage(object):
|
||||
response = openai.Image.create(
|
||||
prompt=query, # 图片描述
|
||||
n=1, # 每次生成图片的数量
|
||||
size=conf().get(
|
||||
"image_create_size", "256x256"
|
||||
), # 图片大小,可选有 256x256, 512x512, 1024x1024
|
||||
size=conf().get("image_create_size", "256x256"), # 图片大小,可选有 256x256, 512x512, 1024x1024
|
||||
)
|
||||
image_url = response["data"][0]["url"]
|
||||
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
||||
@@ -34,11 +32,7 @@ class OpenAIImage(object):
|
||||
logger.warn(e)
|
||||
if retry_count < 1:
|
||||
time.sleep(5)
|
||||
logger.warn(
|
||||
"[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(
|
||||
retry_count + 1
|
||||
)
|
||||
)
|
||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1))
|
||||
return self.create_img(query, retry_count + 1)
|
||||
else:
|
||||
return False, "提问太快啦,请休息一下再问我吧"
|
||||
|
||||
@@ -36,9 +36,7 @@ class OpenAISession(Session):
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug(
|
||||
"Exception when counting tokens precisely for query: {}".format(e)
|
||||
)
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 1:
|
||||
self.messages.pop(0)
|
||||
@@ -50,18 +48,10 @@ class OpenAISession(Session):
|
||||
cur_tokens = len(str(self))
|
||||
break
|
||||
elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
|
||||
logger.warn(
|
||||
"user question exceed max_tokens. total_tokens={}".format(
|
||||
cur_tokens
|
||||
)
|
||||
)
|
||||
logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug(
|
||||
"max_tokens={}, total_tokens={}, len(conversation)={}".format(
|
||||
max_tokens, cur_tokens, len(self.messages)
|
||||
)
|
||||
)
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
|
||||
@@ -55,9 +55,7 @@ class SessionManager(object):
|
||||
return self.sessioncls(session_id, system_prompt, **self.session_args)
|
||||
|
||||
if session_id not in self.sessions:
|
||||
self.sessions[session_id] = self.sessioncls(
|
||||
session_id, system_prompt, **self.session_args
|
||||
)
|
||||
self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args)
|
||||
elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session
|
||||
self.sessions[session_id].set_system_prompt(system_prompt)
|
||||
session = self.sessions[session_id]
|
||||
@@ -71,9 +69,7 @@ class SessionManager(object):
|
||||
total_tokens = session.discard_exceeding(max_tokens, None)
|
||||
logger.debug("prompt tokens used={}".format(total_tokens))
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Exception when counting tokens precisely for prompt: {}".format(str(e))
|
||||
)
|
||||
logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
|
||||
return session
|
||||
|
||||
def session_reply(self, reply, session_id, total_tokens=None):
|
||||
@@ -82,17 +78,9 @@ class SessionManager(object):
|
||||
try:
|
||||
max_tokens = conf().get("conversation_max_tokens", 1000)
|
||||
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
|
||||
logger.debug(
|
||||
"raw total_tokens={}, savesession tokens={}".format(
|
||||
total_tokens, tokens_cnt
|
||||
)
|
||||
)
|
||||
logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Exception when counting tokens precisely for session: {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
logger.debug("Exception when counting tokens precisely for session: {}".format(str(e)))
|
||||
return session
|
||||
|
||||
def clear_session(self, session_id):
|
||||
|
||||
@@ -8,6 +8,8 @@ class ContextType(Enum):
|
||||
VOICE = 2 # 音频消息
|
||||
IMAGE = 3 # 图片消息
|
||||
IMAGE_CREATE = 10 # 创建图片命令
|
||||
JOIN_GROUP = 20 # 加入群聊
|
||||
PATPAT = 21 # 拍了拍
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
@@ -58,6 +60,4 @@ class Context:
|
||||
del self.kwargs[key]
|
||||
|
||||
def __str__(self):
|
||||
return "Context(type={}, content={}, kwargs={})".format(
|
||||
self.type, self.content, self.kwargs
|
||||
)
|
||||
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)
|
||||
|
||||
@@ -53,9 +53,7 @@ class ChatChannel(Channel):
|
||||
group_id = cmsg.other_user_id
|
||||
|
||||
group_name_white_list = config.get("group_name_white_list", [])
|
||||
group_name_keyword_white_list = config.get(
|
||||
"group_name_keyword_white_list", []
|
||||
)
|
||||
group_name_keyword_white_list = config.get("group_name_keyword_white_list", [])
|
||||
if any(
|
||||
[
|
||||
group_name in group_name_white_list,
|
||||
@@ -63,9 +61,7 @@ class ChatChannel(Channel):
|
||||
check_contain(group_name, group_name_keyword_white_list),
|
||||
]
|
||||
):
|
||||
group_chat_in_one_session = conf().get(
|
||||
"group_chat_in_one_session", []
|
||||
)
|
||||
group_chat_in_one_session = conf().get("group_chat_in_one_session", [])
|
||||
session_id = cmsg.actual_user_id
|
||||
if any(
|
||||
[
|
||||
@@ -81,17 +77,11 @@ class ChatChannel(Channel):
|
||||
else:
|
||||
context["session_id"] = cmsg.other_user_id
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
e_context = PluginManager().emit_event(
|
||||
EventContext(
|
||||
Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}
|
||||
)
|
||||
)
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}))
|
||||
context = e_context["context"]
|
||||
if e_context.is_pass() or context is None:
|
||||
return context
|
||||
if cmsg.from_user_id == self.user_id and not config.get(
|
||||
"trigger_by_self", True
|
||||
):
|
||||
if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True):
|
||||
logger.debug("[WX]self message skipped")
|
||||
return None
|
||||
|
||||
@@ -114,24 +104,18 @@ class ChatChannel(Channel):
|
||||
logger.info("[WX]receive group at")
|
||||
if not conf().get("group_at_off", False):
|
||||
flag = True
|
||||
pattern = f"@{self.name}(\u2005|\u0020)"
|
||||
pattern = f"@{re.escape(self.name)}(\u2005|\u0020)"
|
||||
content = re.sub(pattern, r"", content)
|
||||
|
||||
if not flag:
|
||||
if context["origin_ctype"] == ContextType.VOICE:
|
||||
logger.info(
|
||||
"[WX]receive group voice, but checkprefix didn't match"
|
||||
)
|
||||
logger.info("[WX]receive group voice, but checkprefix didn't match")
|
||||
return None
|
||||
else: # 单聊
|
||||
match_prefix = check_prefix(
|
||||
content, conf().get("single_chat_prefix", [""])
|
||||
)
|
||||
match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""]))
|
||||
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
|
||||
content = content.replace(match_prefix, "", 1).strip()
|
||||
elif (
|
||||
context["origin_ctype"] == ContextType.VOICE
|
||||
): # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
|
||||
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
|
||||
pass
|
||||
else:
|
||||
return None
|
||||
@@ -143,18 +127,10 @@ class ChatChannel(Channel):
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content.strip()
|
||||
if (
|
||||
"desire_rtype" not in context
|
||||
and conf().get("always_reply_voice")
|
||||
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
|
||||
):
|
||||
if "desire_rtype" not in context and conf().get("always_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
elif context.type == ContextType.VOICE:
|
||||
if (
|
||||
"desire_rtype" not in context
|
||||
and conf().get("voice_reply_voice")
|
||||
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
|
||||
):
|
||||
if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
|
||||
return context
|
||||
@@ -182,15 +158,8 @@ class ChatChannel(Channel):
|
||||
)
|
||||
reply = e_context["reply"]
|
||||
if not e_context.is_pass():
|
||||
logger.debug(
|
||||
"[WX] ready to handle context: type={}, content={}".format(
|
||||
context.type, context.content
|
||||
)
|
||||
)
|
||||
if (
|
||||
context.type == ContextType.TEXT
|
||||
or context.type == ContextType.IMAGE_CREATE
|
||||
): # 文字和图片消息
|
||||
logger.debug("[WX] ready to handle context: type={}, content={}".format(context.type, context.content))
|
||||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
|
||||
reply = super().build_reply_content(context.content, context)
|
||||
elif context.type == ContextType.VOICE: # 语音消息
|
||||
cmsg = context["msg"]
|
||||
@@ -214,9 +183,7 @@ class ChatChannel(Channel):
|
||||
# logger.warning("[WX]delete temp file error: " + str(e))
|
||||
|
||||
if reply.type == ReplyType.TEXT:
|
||||
new_context = self._compose_context(
|
||||
ContextType.TEXT, reply.content, **context.kwargs
|
||||
)
|
||||
new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs)
|
||||
if new_context:
|
||||
reply = self._generate_reply(new_context)
|
||||
else:
|
||||
@@ -246,48 +213,24 @@ class ChatChannel(Channel):
|
||||
|
||||
if reply.type == ReplyType.TEXT:
|
||||
reply_text = reply.content
|
||||
if (
|
||||
desire_rtype == ReplyType.VOICE
|
||||
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
|
||||
):
|
||||
if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
reply = super().build_text_to_voice(reply.content)
|
||||
return self._decorate_reply(context, reply)
|
||||
if context.get("isgroup", False):
|
||||
reply_text = (
|
||||
"@"
|
||||
+ context["msg"].actual_user_nickname
|
||||
+ " "
|
||||
+ reply_text.strip()
|
||||
)
|
||||
reply_text = (
|
||||
conf().get("group_chat_reply_prefix", "") + reply_text
|
||||
)
|
||||
reply_text = "@" + context["msg"].actual_user_nickname + " " + reply_text.strip()
|
||||
reply_text = conf().get("group_chat_reply_prefix", "") + reply_text
|
||||
else:
|
||||
reply_text = (
|
||||
conf().get("single_chat_reply_prefix", "") + reply_text
|
||||
)
|
||||
reply_text = conf().get("single_chat_reply_prefix", "") + reply_text
|
||||
reply.content = reply_text
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
reply.content = "[" + str(reply.type) + "]\n" + reply.content
|
||||
elif (
|
||||
reply.type == ReplyType.IMAGE_URL
|
||||
or reply.type == ReplyType.VOICE
|
||||
or reply.type == ReplyType.IMAGE
|
||||
):
|
||||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
|
||||
pass
|
||||
else:
|
||||
logger.error("[WX] unknown reply type: {}".format(reply.type))
|
||||
return
|
||||
if (
|
||||
desire_rtype
|
||||
and desire_rtype != reply.type
|
||||
and reply.type not in [ReplyType.ERROR, ReplyType.INFO]
|
||||
):
|
||||
logger.warning(
|
||||
"[WX] desire_rtype: {}, but reply type: {}".format(
|
||||
context.get("desire_rtype"), reply.type
|
||||
)
|
||||
)
|
||||
if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
|
||||
logger.warning("[WX] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type))
|
||||
return reply
|
||||
|
||||
def _send_reply(self, context: Context, reply: Reply):
|
||||
@@ -300,9 +243,7 @@ class ChatChannel(Channel):
|
||||
)
|
||||
reply = e_context["reply"]
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
logger.debug(
|
||||
"[WX] ready to send reply: {}, context: {}".format(reply, context)
|
||||
)
|
||||
logger.debug("[WX] ready to send reply: {}, context: {}".format(reply, context))
|
||||
self._send(reply, context)
|
||||
|
||||
def _send(self, reply: Reply, context: Context, retry_cnt=0):
|
||||
@@ -328,9 +269,7 @@ class ChatChannel(Channel):
|
||||
try:
|
||||
worker_exception = worker.exception()
|
||||
if worker_exception:
|
||||
self._fail_callback(
|
||||
session_id, exception=worker_exception, **kwargs
|
||||
)
|
||||
self._fail_callback(session_id, exception=worker_exception, **kwargs)
|
||||
else:
|
||||
self._success_callback(session_id, **kwargs)
|
||||
except CancelledError as e:
|
||||
@@ -366,24 +305,14 @@ class ChatChannel(Channel):
|
||||
if not context_queue.empty():
|
||||
context = context_queue.get()
|
||||
logger.debug("[WX] consume context: {}".format(context))
|
||||
future: Future = self.handler_pool.submit(
|
||||
self._handle, context
|
||||
)
|
||||
future.add_done_callback(
|
||||
self._thread_pool_callback(session_id, context=context)
|
||||
)
|
||||
future: Future = self.handler_pool.submit(self._handle, context)
|
||||
future.add_done_callback(self._thread_pool_callback(session_id, context=context))
|
||||
if session_id not in self.futures:
|
||||
self.futures[session_id] = []
|
||||
self.futures[session_id].append(future)
|
||||
elif (
|
||||
semaphore._initial_value == semaphore._value + 1
|
||||
): # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
|
||||
self.futures[session_id] = [
|
||||
t for t in self.futures[session_id] if not t.done()
|
||||
]
|
||||
assert (
|
||||
len(self.futures[session_id]) == 0
|
||||
), "thread pool error"
|
||||
elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
|
||||
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
|
||||
assert len(self.futures[session_id]) == 0, "thread pool error"
|
||||
del self.sessions[session_id]
|
||||
else:
|
||||
semaphore.release()
|
||||
@@ -397,9 +326,7 @@ class ChatChannel(Channel):
|
||||
future.cancel()
|
||||
cnt = self.sessions[session_id][0].qsize()
|
||||
if cnt > 0:
|
||||
logger.info(
|
||||
"Cancel {} messages in session {}".format(cnt, session_id)
|
||||
)
|
||||
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
|
||||
self.sessions[session_id][0] = Dequeue()
|
||||
|
||||
def cancel_all_session(self):
|
||||
@@ -409,9 +336,7 @@ class ChatChannel(Channel):
|
||||
future.cancel()
|
||||
cnt = self.sessions[session_id][0].qsize()
|
||||
if cnt > 0:
|
||||
logger.info(
|
||||
"Cancel {} messages in session {}".format(cnt, session_id)
|
||||
)
|
||||
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
|
||||
self.sessions[session_id][0] = Dequeue()
|
||||
|
||||
|
||||
|
||||
@@ -77,9 +77,7 @@ class TerminalChannel(ChatChannel):
|
||||
if check_prefix(prompt, trigger_prefixs) is None:
|
||||
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
|
||||
|
||||
context = self._compose_context(
|
||||
ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt)
|
||||
)
|
||||
context = self._compose_context(ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt))
|
||||
if context:
|
||||
self.produce(context)
|
||||
else:
|
||||
|
||||
@@ -26,20 +26,25 @@ from lib.itchat.content import *
|
||||
from plugins import *
|
||||
|
||||
|
||||
@itchat.msg_register([TEXT, VOICE, PICTURE])
|
||||
@itchat.msg_register([TEXT, VOICE, PICTURE, NOTE])
|
||||
def handler_single_msg(msg):
|
||||
# logger.debug("handler_single_msg: {}".format(msg))
|
||||
if msg["Type"] == PICTURE and msg["MsgType"] == 47:
|
||||
try:
|
||||
cmsg = WeChatMessage(msg, False)
|
||||
except NotImplementedError as e:
|
||||
logger.debug("[WX]single message {} skipped: {}".format(msg["MsgId"], e))
|
||||
return None
|
||||
WechatChannel().handle_single(WeChatMessage(msg))
|
||||
WechatChannel().handle_single(cmsg)
|
||||
return None
|
||||
|
||||
|
||||
@itchat.msg_register([TEXT, VOICE, PICTURE], isGroupChat=True)
|
||||
@itchat.msg_register([TEXT, VOICE, PICTURE, NOTE], isGroupChat=True)
|
||||
def handler_group_msg(msg):
|
||||
if msg["Type"] == PICTURE and msg["MsgType"] == 47:
|
||||
try:
|
||||
cmsg = WeChatMessage(msg, True)
|
||||
except NotImplementedError as e:
|
||||
logger.debug("[WX]group message {} skipped: {}".format(msg["MsgId"], e))
|
||||
return None
|
||||
WechatChannel().handle_group(WeChatMessage(msg, True))
|
||||
WechatChannel().handle_group(cmsg)
|
||||
return None
|
||||
|
||||
|
||||
@@ -51,10 +56,7 @@ def _check(func):
|
||||
return
|
||||
self.receivedMsgs[msgId] = cmsg
|
||||
create_time = cmsg.create_time # 消息时间戳
|
||||
if (
|
||||
conf().get("hot_reload") == True
|
||||
and int(create_time) < int(time.time()) - 60
|
||||
): # 跳过1分钟前的历史消息
|
||||
if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
|
||||
logger.debug("[WX]history message {} skipped".format(msgId))
|
||||
return
|
||||
return func(self, cmsg)
|
||||
@@ -83,15 +85,9 @@ def qrCallback(uuid, status, qrcode):
|
||||
url = f"https://login.weixin.qq.com/l/{uuid}"
|
||||
|
||||
qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
|
||||
qr_api2 = (
|
||||
"https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(
|
||||
url
|
||||
)
|
||||
)
|
||||
qr_api2 = "https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url)
|
||||
qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url)
|
||||
qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(
|
||||
url
|
||||
)
|
||||
qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url)
|
||||
print("You can also scan QRCode in any website below:")
|
||||
print(qr_api3)
|
||||
print(qr_api4)
|
||||
@@ -129,18 +125,12 @@ class WechatChannel(ChatChannel):
|
||||
logger.error("Hot reload failed, try to login without hot reload")
|
||||
itchat.logout()
|
||||
os.remove(status_path)
|
||||
itchat.auto_login(
|
||||
enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback
|
||||
)
|
||||
itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback)
|
||||
else:
|
||||
raise e
|
||||
self.user_id = itchat.instance.storageClass.userName
|
||||
self.name = itchat.instance.storageClass.nickName
|
||||
logger.info(
|
||||
"Wechat login success, user_id: {}, nickname: {}".format(
|
||||
self.user_id, self.name
|
||||
)
|
||||
)
|
||||
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
|
||||
# start message listener
|
||||
itchat.run()
|
||||
|
||||
@@ -165,15 +155,13 @@ class WechatChannel(ChatChannel):
|
||||
logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[WX]receive image msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.PATPAT:
|
||||
logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
else:
|
||||
logger.debug(
|
||||
"[WX]receive text msg: {}, cmsg={}".format(
|
||||
json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg
|
||||
)
|
||||
)
|
||||
context = self._compose_context(
|
||||
cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg
|
||||
)
|
||||
logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
@@ -186,12 +174,14 @@ class WechatChannel(ChatChannel):
|
||||
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[WX]receive image for group msg: {}".format(cmsg.content))
|
||||
else:
|
||||
elif cmsg.ctype in [ContextType.JOIN_GROUP, ContextType.PATPAT]:
|
||||
logger.debug("[WX]receive note msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
# logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
pass
|
||||
context = self._compose_context(
|
||||
cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg
|
||||
)
|
||||
else:
|
||||
logger.debug("[WX]receive group msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import re
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
@@ -24,10 +26,24 @@ class WeChatMessage(ChatMessage):
|
||||
self.ctype = ContextType.IMAGE
|
||||
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
|
||||
self._prepare_fn = lambda: itchat_msg.download(self.content)
|
||||
elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000:
|
||||
if is_group and ("加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]):
|
||||
self.ctype = ContextType.JOIN_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
# 这里只能得到nickname, actual_user_id还是机器人的id
|
||||
if "加入了群聊" in itchat_msg["Content"]:
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1]
|
||||
elif "加入群聊" in itchat_msg["Content"]:
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
elif "拍了拍我" in itchat_msg["Content"]:
|
||||
self.ctype = ContextType.PATPAT
|
||||
self.content = itchat_msg["Content"]
|
||||
if is_group:
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
else:
|
||||
raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unsupported message type: {}".format(itchat_msg["Type"])
|
||||
)
|
||||
raise NotImplementedError("Unsupported message type: Type:{} MsgType:{}".format(itchat_msg["Type"], itchat_msg["MsgType"]))
|
||||
|
||||
self.from_user_id = itchat_msg["FromUserName"]
|
||||
self.to_user_id = itchat_msg["ToUserName"]
|
||||
@@ -58,4 +74,5 @@ class WeChatMessage(ChatMessage):
|
||||
if self.is_group:
|
||||
self.is_at = itchat_msg["IsAt"]
|
||||
self.actual_user_id = itchat_msg["ActualUserName"]
|
||||
self.actual_user_nickname = itchat_msg["ActualNickName"]
|
||||
if self.ctype not in [ContextType.JOIN_GROUP, ContextType.PATPAT]:
|
||||
self.actual_user_nickname = itchat_msg["ActualNickName"]
|
||||
|
||||
@@ -60,13 +60,9 @@ class WechatyChannel(ChatChannel):
|
||||
receiver_id = context["receiver"]
|
||||
loop = asyncio.get_event_loop()
|
||||
if context["isgroup"]:
|
||||
receiver = asyncio.run_coroutine_threadsafe(
|
||||
self.bot.Room.find(receiver_id), loop
|
||||
).result()
|
||||
receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id), loop).result()
|
||||
else:
|
||||
receiver = asyncio.run_coroutine_threadsafe(
|
||||
self.bot.Contact.find(receiver_id), loop
|
||||
).result()
|
||||
receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id), loop).result()
|
||||
msg = None
|
||||
if reply.type == ReplyType.TEXT:
|
||||
msg = reply.content
|
||||
@@ -83,9 +79,7 @@ class WechatyChannel(ChatChannel):
|
||||
voiceLength = int(any_to_sil(file_path, sil_file))
|
||||
if voiceLength >= 60000:
|
||||
voiceLength = 60000
|
||||
logger.info(
|
||||
"[WX] voice too long, length={}, set to 60s".format(voiceLength)
|
||||
)
|
||||
logger.info("[WX] voice too long, length={}, set to 60s".format(voiceLength))
|
||||
# 发送语音
|
||||
t = int(time.time())
|
||||
msg = FileBox.from_file(sil_file, name=str(t) + ".sil")
|
||||
@@ -98,9 +92,7 @@ class WechatyChannel(ChatChannel):
|
||||
os.remove(sil_file)
|
||||
except Exception as e:
|
||||
pass
|
||||
logger.info(
|
||||
"[WX] sendVoice={}, receiver={}".format(reply.content, receiver)
|
||||
)
|
||||
logger.info("[WX] sendVoice={}, receiver={}".format(reply.content, receiver))
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
t = int(time.time())
|
||||
@@ -111,9 +103,7 @@ class WechatyChannel(ChatChannel):
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
t = int(time.time())
|
||||
msg = FileBox.from_base64(
|
||||
base64.b64encode(image_storage.read()), str(t) + ".png"
|
||||
)
|
||||
msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + ".png")
|
||||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
|
||||
logger.info("[WX] sendImage, receiver={}".format(receiver))
|
||||
|
||||
|
||||
@@ -45,16 +45,12 @@ class WechatyMessage(ChatMessage, aobject):
|
||||
|
||||
def func():
|
||||
loop = asyncio.get_event_loop()
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
voice_file.to_file(self.content), loop
|
||||
).result()
|
||||
asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content), loop).result()
|
||||
|
||||
self._prepare_fn = func
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unsupported message type: {}".format(wechaty_msg.type())
|
||||
)
|
||||
raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type()))
|
||||
|
||||
from_contact = wechaty_msg.talker() # 获取消息的发送者
|
||||
self.from_user_id = from_contact.contact_id
|
||||
@@ -73,9 +69,7 @@ class WechatyMessage(ChatMessage, aobject):
|
||||
self.to_user_id = to_contact.contact_id
|
||||
self.to_user_nickname = to_contact.name
|
||||
|
||||
if (
|
||||
self.is_group or wechaty_msg.is_self()
|
||||
): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
|
||||
if self.is_group or wechaty_msg.is_self(): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
|
||||
self.other_user_id = self.to_user_id
|
||||
self.other_user_nickname = self.to_user_nickname
|
||||
else:
|
||||
@@ -86,7 +80,7 @@ class WechatyMessage(ChatMessage, aobject):
|
||||
self.is_at = await wechaty_msg.mention_self()
|
||||
if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容
|
||||
name = wechaty_msg.wechaty.user_self().name
|
||||
pattern = f"@{name}(\u2005|\u0020)"
|
||||
pattern = f"@{re.escape(name)}(\u2005|\u0020)"
|
||||
if re.search(pattern, self.content):
|
||||
logger.debug(f"wechaty message {self.msg_id} include at")
|
||||
self.is_at = True
|
||||
|
||||
@@ -1,57 +1,100 @@
|
||||
# 微信公众号channel
|
||||
|
||||
鉴于个人微信号在服务器上通过itchat登录有封号风险,这里新增了微信公众号channel,提供无风险的服务。
|
||||
目前支持订阅号(个人)和服务号(企业)两种类型的公众号,它们的主要区别就是被动回复和主动回复。
|
||||
个人微信订阅号有许多接口限制,目前仅支持最基本的文本对话和语音输入,支持加载插件,支持私有api_key。
|
||||
暂未实现图片输入输出、语音输出等交互形式。
|
||||
目前支持订阅号和服务号两种类型的公众号,它们都支持文本交互,语音和图片输入。其中个人主体的微信订阅号由于无法通过微信认证,存在回复时间限制,每天的图片和声音回复次数也有限制。
|
||||
|
||||
## 使用方法(订阅号,服务号类似)
|
||||
|
||||
在开始部署前,你需要一个拥有公网IP的服务器,以提供微信服务器和我们自己服务器的连接。或者你需要进行内网穿透,否则微信服务器无法将消息发送给我们的服务器。
|
||||
|
||||
此外,需要在我们的服务器上安装python的web框架web.py。
|
||||
此外,需要在我们的服务器上安装python的web框架web.py和wechatpy。
|
||||
以ubuntu为例(在ubuntu 22.04上测试):
|
||||
```
|
||||
pip3 install web.py
|
||||
pip3 install wechatpy
|
||||
```
|
||||
|
||||
然后在[微信公众平台](https://mp.weixin.qq.com)注册一个自己的公众号,类型选择订阅号,主体为个人即可。
|
||||
|
||||
然后根据[接入指南](https://developers.weixin.qq.com/doc/offiaccount/Basic_Information/Access_Overview.html)的说明,在[微信公众平台](https://mp.weixin.qq.com)的“设置与开发”-“基本配置”-“服务器配置”中填写服务器地址`URL`和令牌`Token`。这里的`URL`是`example.com/wx`的形式,不可以使用IP,`Token`是你自己编的一个特定的令牌。消息加解密方式目前选择的是明文模式。
|
||||
然后根据[接入指南](https://developers.weixin.qq.com/doc/offiaccount/Basic_Information/Access_Overview.html)的说明,在[微信公众平台](https://mp.weixin.qq.com)的“设置与开发”-“基本配置”-“服务器配置”中填写服务器地址`URL`和令牌`Token`。这里的`URL`是`example.com/wx`的形式,不可以使用IP,`Token`是你自己编的一个特定的令牌。消息加解密方式如果选择了需要加密的模式,需要在配置中填写`wechatmp_aes_key`。
|
||||
|
||||
相关的服务器验证代码已经写好,你不需要再添加任何代码。你只需要在本项目根目录的`config.json`中添加
|
||||
```
|
||||
"channel_type": "wechatmp",
|
||||
"wechatmp_token": "Token", # 微信公众平台的Token
|
||||
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
|
||||
"wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要
|
||||
"wechatmp_app_secret": "", # 微信公众平台的appsecret,仅服务号需要
|
||||
"channel_type": "wechatmp", # 如果通过了微信认证,将"wechatmp"替换为"wechatmp_service",可极大的优化使用体验
|
||||
"wechatmp_token": "xxxx", # 微信公众平台的Token
|
||||
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
|
||||
"wechatmp_app_id": "xxxx", # 微信公众平台的appID
|
||||
"wechatmp_app_secret": "xxxx", # 微信公众平台的appsecret
|
||||
"wechatmp_aes_key": "", # 微信公众平台的EncodingAESKey,加密模式需要
|
||||
"single_chat_prefix": [""], # 推荐设置,任意对话都可以触发回复,不添加前缀
|
||||
"single_chat_reply_prefix": "", # 推荐设置,回复不设置前缀
|
||||
"plugin_trigger_prefix": "&", # 推荐设置,在手机微信客户端中,$%^等符号与中文连在一起时会自动显示一段较大的间隔,用户体验不好。请不要使用管理员指令前缀"#",这会造成未知问题。
|
||||
```
|
||||
然后运行`python3 app.py`启动web服务器。这里会默认监听8080端口,但是微信公众号的服务器配置只支持80/443端口,有两种方法来解决这个问题。第一个是推荐的方法,使用端口转发命令将80端口转发到8080端口(443同理,注意需要支持SSL,也就是https的访问,在`wechatmp_channel.py`需要修改相应的证书路径):
|
||||
然后运行`python3 app.py`启动web服务器。这里会默认监听8080端口,但是微信公众号的服务器配置只支持80/443端口,有两种方法来解决这个问题。第一个是推荐的方法,使用端口转发命令将80端口转发到8080端口:
|
||||
```
|
||||
sudo iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to-port 8080
|
||||
sudo iptables-save > /etc/iptables/rules.v4
|
||||
```
|
||||
第二个方法是让python程序直接监听80端口。这样可能会导致权限问题,在linux上需要使用`sudo`。然而这会导致后续缓存文件的权限问题,因此不是推荐的方法。
|
||||
最后在刚才的“服务器配置”中点击`提交`即可验证你的服务器。
|
||||
第二个方法是让python程序直接监听80端口,在配置文件中设置`"wechatmp_port": 80` ,在linux上需要使用`sudo python3 app.py`启动程序。然而这会导致一系列环境和权限问题,因此不是推荐的方法。
|
||||
|
||||
443端口同理,注意需要支持SSL,也就是https的访问,在`wechatmp_channel.py`中需要修改相应的证书路径。
|
||||
|
||||
程序启动并监听端口后,在刚才的“服务器配置”中点击`提交`即可验证你的服务器。
|
||||
随后在[微信公众平台](https://mp.weixin.qq.com)启用服务器,关闭手动填写规则的自动回复,即可实现ChatGPT的自动回复。
|
||||
|
||||
之后需要在公众号开发信息下将本机IP加入到IP白名单。
|
||||
|
||||
不然在启用后,发送语音、图片等消息可能会遇到如下报错:
|
||||
```
|
||||
'errcode': 40164, 'errmsg': 'invalid ip xx.xx.xx.xx not in whitelist rid
|
||||
```
|
||||
|
||||
|
||||
## 个人微信公众号的限制
|
||||
由于人微信公众号不能通过微信认证,所以没有客服接口,因此公众号无法主动发出消息,只能被动回复。而微信官方对被动回复有5秒的时间限制,最多重试2次,因此最多只有15秒的自动回复时间窗口。因此如果问题比较复杂或者我们的服务器比较忙,ChatGPT的回答就没办法及时回复给用户。为了解决这个问题,这里做了回答缓存,它需要你在回复超时后,再次主动发送任意文字(例如1)来尝试拿到回答缓存。为了优化使用体验,目前设置了两分钟(120秒)的timeout,用户在至多两分钟后即可得到查询到回复或者错误原因。
|
||||
|
||||
另外,由于微信官方的限制,自动回复有长度限制。因此这里将ChatGPT的回答拆分,分成每段600字回复(限制大约在700字)。
|
||||
另外,由于微信官方的限制,自动回复有长度限制。因此这里将ChatGPT的回答进行了拆分,以满足限制。
|
||||
|
||||
## 私有api_key
|
||||
公共api有访问频率限制(免费账号每分钟最多20次ChatGPT的API调用),这在服务多人的时候会遇到问题。因此这里多加了一个设置私有api_key的功能。目前通过godcmd插件的命令来设置私有api_key。
|
||||
公共api有访问频率限制(免费账号每分钟最多3次ChatGPT的API调用),这在服务多人的时候会遇到问题。因此这里多加了一个设置私有api_key的功能。目前通过godcmd插件的命令来设置私有api_key。
|
||||
|
||||
## 语音输入
|
||||
利用微信自带的语音识别功能,提供语音输入能力。需要在公众号管理页面的“设置与开发”->“接口权限”页面开启“接收语音识别结果”。
|
||||
|
||||
## 测试范围
|
||||
目前在`RoboStyle`这个公众号上进行了测试(基于[wechatmp分支](https://github.com/JS00000/chatgpt-on-wechat/tree/wechatmp)),感兴趣的可以关注并体验。开启了godcmd, Banwords, role, dungeon, finish这五个插件,其他的插件还没有测试。百度的接口暂未测试。语音对话没有测试。图片直接以链接形式回复(没有临时素材上传接口的权限)。
|
||||
## 语音回复
|
||||
请在配置文件中添加以下词条:
|
||||
```
|
||||
"voice_reply_voice": true,
|
||||
```
|
||||
这样公众号将会用语音回复语音消息,实现语音对话。
|
||||
|
||||
默认的语音合成引擎是`google`,它是免费使用的。
|
||||
|
||||
如果要选择其他的语音合成引擎,请添加以下配置项:
|
||||
```
|
||||
"text_to_voice": "pytts"
|
||||
```
|
||||
|
||||
pytts是本地的语音合成引擎。还支持baidu,azure,这些你需要自行配置相关的依赖和key。
|
||||
|
||||
如果使用pytts,在ubuntu上需要安装如下依赖:
|
||||
```
|
||||
sudo apt update
|
||||
sudo apt install espeak
|
||||
sudo apt install ffmpeg
|
||||
python3 -m pip install pyttsx3
|
||||
```
|
||||
不是很建议开启pytts语音回复,因为它是离线本地计算,算的慢会拖垮服务器,且声音不好听。
|
||||
|
||||
## 图片回复
|
||||
现在认证公众号和非认证公众号都可以实现的图片和语音回复。但是非认证公众号使用了永久素材接口,每天有1000次的调用上限(每个月有10次重置机会,程序中已设定遇到上限会自动重置),且永久素材库存也有上限。因此对于非认证公众号,我们会在回复图片或者语音消息后的10秒内从永久素材库存内删除该素材。
|
||||
|
||||
## 测试
|
||||
目前在`RoboStyle`这个公众号上进行了测试(基于[wechatmp分支](https://github.com/JS00000/chatgpt-on-wechat/tree/wechatmp)),感兴趣的可以关注并体验。开启了godcmd, Banwords, role, dungeon, finish这五个插件,其他的插件还没有详尽测试。百度的接口暂未测试。[wechatmp-stable分支](https://github.com/JS00000/chatgpt-on-wechat/tree/wechatmp-stable)是较稳定的上个版本,但也缺少最新的功能支持。
|
||||
|
||||
## TODO
|
||||
* 服务号交互完善
|
||||
* 服务号使用临时素材接口,提供图片回复能力
|
||||
* 插件测试
|
||||
- [x] 语音输入
|
||||
- [x] 图片输入
|
||||
- [x] 使用临时素材接口提供认证公众号的图片和语音回复
|
||||
- [x] 使用永久素材接口提供未认证公众号的图片和语音回复
|
||||
- [ ] 高并发支持
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
import time
|
||||
|
||||
import web
|
||||
|
||||
import channel.wechatmp.receive as receive
|
||||
import channel.wechatmp.reply as reply
|
||||
from bridge.context import *
|
||||
from channel.wechatmp.common import *
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
# This class is instantiated once per query
|
||||
class Query:
|
||||
def GET(self):
|
||||
return verify_server(web.input())
|
||||
|
||||
def POST(self):
|
||||
# Make sure to return the instance that first created, @singleton will do that.
|
||||
channel = WechatMPChannel()
|
||||
try:
|
||||
webData = web.data()
|
||||
# logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
|
||||
wechatmp_msg = receive.parse_xml(webData)
|
||||
if wechatmp_msg.msg_type == "text" or wechatmp_msg.msg_type == "voice":
|
||||
from_user = wechatmp_msg.from_user_id
|
||||
message = wechatmp_msg.content.decode("utf-8")
|
||||
message_id = wechatmp_msg.msg_id
|
||||
|
||||
logger.info(
|
||||
"[wechatmp] {}:{} Receive post query {} {}: {}".format(
|
||||
web.ctx.env.get("REMOTE_ADDR"),
|
||||
web.ctx.env.get("REMOTE_PORT"),
|
||||
from_user,
|
||||
message_id,
|
||||
message,
|
||||
)
|
||||
)
|
||||
context = channel._compose_context(
|
||||
ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg
|
||||
)
|
||||
if context:
|
||||
# set private openai_api_key
|
||||
# if from_user is not changed in itchat, this can be placed at chat_channel
|
||||
user_data = conf().get_user_data(from_user)
|
||||
context["openai_api_key"] = user_data.get(
|
||||
"openai_api_key"
|
||||
) # None or user openai_api_key
|
||||
channel.produce(context)
|
||||
# The reply will be sent by channel.send() in another thread
|
||||
return "success"
|
||||
|
||||
elif wechatmp_msg.msg_type == "event":
|
||||
logger.info(
|
||||
"[wechatmp] Event {} from {}".format(
|
||||
wechatmp_msg.Event, wechatmp_msg.from_user_id
|
||||
)
|
||||
)
|
||||
content = subscribe_msg()
|
||||
replyMsg = reply.TextMsg(
|
||||
wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content
|
||||
)
|
||||
return replyMsg.send()
|
||||
else:
|
||||
logger.info("暂且不处理")
|
||||
return "success"
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
return exc
|
||||
@@ -1,232 +0,0 @@
|
||||
import time
|
||||
|
||||
import web
|
||||
|
||||
import channel.wechatmp.receive as receive
|
||||
import channel.wechatmp.reply as reply
|
||||
from bridge.context import *
|
||||
from channel.wechatmp.common import *
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
# This class is instantiated once per query
|
||||
class Query:
|
||||
def GET(self):
|
||||
return verify_server(web.input())
|
||||
|
||||
def POST(self):
|
||||
# Make sure to return the instance that first created, @singleton will do that.
|
||||
channel = WechatMPChannel()
|
||||
try:
|
||||
query_time = time.time()
|
||||
webData = web.data()
|
||||
logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
|
||||
wechatmp_msg = receive.parse_xml(webData)
|
||||
if wechatmp_msg.msg_type == "text" or wechatmp_msg.msg_type == "voice":
|
||||
from_user = wechatmp_msg.from_user_id
|
||||
to_user = wechatmp_msg.to_user_id
|
||||
message = wechatmp_msg.content.decode("utf-8")
|
||||
message_id = wechatmp_msg.msg_id
|
||||
|
||||
logger.info(
|
||||
"[wechatmp] {}:{} Receive post query {} {}: {}".format(
|
||||
web.ctx.env.get("REMOTE_ADDR"),
|
||||
web.ctx.env.get("REMOTE_PORT"),
|
||||
from_user,
|
||||
message_id,
|
||||
message,
|
||||
)
|
||||
)
|
||||
supported = True
|
||||
if "【收到不支持的消息类型,暂无法显示】" in message:
|
||||
supported = False # not supported, used to refresh
|
||||
cache_key = from_user
|
||||
|
||||
reply_text = ""
|
||||
# New request
|
||||
if (
|
||||
cache_key not in channel.cache_dict
|
||||
and cache_key not in channel.running
|
||||
):
|
||||
# The first query begin, reset the cache
|
||||
context = channel._compose_context(
|
||||
ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg
|
||||
)
|
||||
logger.debug(
|
||||
"[wechatmp] context: {} {}".format(context, wechatmp_msg)
|
||||
)
|
||||
if message_id in channel.received_msgs: # received and finished
|
||||
# no return because of bandwords or other reasons
|
||||
return "success"
|
||||
if supported and context:
|
||||
# set private openai_api_key
|
||||
# if from_user is not changed in itchat, this can be placed at chat_channel
|
||||
user_data = conf().get_user_data(from_user)
|
||||
context["openai_api_key"] = user_data.get(
|
||||
"openai_api_key"
|
||||
) # None or user openai_api_key
|
||||
channel.received_msgs[message_id] = wechatmp_msg
|
||||
channel.running.add(cache_key)
|
||||
channel.produce(context)
|
||||
else:
|
||||
trigger_prefix = conf().get("single_chat_prefix", [""])[0]
|
||||
if trigger_prefix or not supported:
|
||||
if trigger_prefix:
|
||||
content = textwrap.dedent(
|
||||
f"""\
|
||||
请输入'{trigger_prefix}'接你想说的话跟我说话。
|
||||
例如:
|
||||
{trigger_prefix}你好,很高兴见到你。"""
|
||||
)
|
||||
else:
|
||||
content = textwrap.dedent(
|
||||
"""\
|
||||
你好,很高兴见到你。
|
||||
请跟我说话吧。"""
|
||||
)
|
||||
else:
|
||||
logger.error(f"[wechatmp] unknown error")
|
||||
content = textwrap.dedent(
|
||||
"""\
|
||||
未知错误,请稍后再试"""
|
||||
)
|
||||
replyMsg = reply.TextMsg(
|
||||
wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content
|
||||
)
|
||||
return replyMsg.send()
|
||||
channel.query1[cache_key] = False
|
||||
channel.query2[cache_key] = False
|
||||
channel.query3[cache_key] = False
|
||||
# User request again, and the answer is not ready
|
||||
elif (
|
||||
cache_key in channel.running
|
||||
and channel.query1.get(cache_key) == True
|
||||
and channel.query2.get(cache_key) == True
|
||||
and channel.query3.get(cache_key) == True
|
||||
):
|
||||
channel.query1[
|
||||
cache_key
|
||||
] = False # To improve waiting experience, this can be set to True.
|
||||
channel.query2[
|
||||
cache_key
|
||||
] = False # To improve waiting experience, this can be set to True.
|
||||
channel.query3[cache_key] = False
|
||||
# User request again, and the answer is ready
|
||||
elif cache_key in channel.cache_dict:
|
||||
# Skip the waiting phase
|
||||
channel.query1[cache_key] = True
|
||||
channel.query2[cache_key] = True
|
||||
channel.query3[cache_key] = True
|
||||
|
||||
assert not (
|
||||
cache_key in channel.cache_dict and cache_key in channel.running
|
||||
)
|
||||
|
||||
if channel.query1.get(cache_key) == False:
|
||||
# The first query from wechat official server
|
||||
logger.debug("[wechatmp] query1 {}".format(cache_key))
|
||||
channel.query1[cache_key] = True
|
||||
cnt = 0
|
||||
while cache_key in channel.running and cnt < 45:
|
||||
cnt = cnt + 1
|
||||
time.sleep(0.1)
|
||||
if cnt == 45:
|
||||
# waiting for timeout (the POST query will be closed by wechat official server)
|
||||
time.sleep(1)
|
||||
# and do nothing
|
||||
return
|
||||
else:
|
||||
pass
|
||||
elif channel.query2.get(cache_key) == False:
|
||||
# The second query from wechat official server
|
||||
logger.debug("[wechatmp] query2 {}".format(cache_key))
|
||||
channel.query2[cache_key] = True
|
||||
cnt = 0
|
||||
while cache_key in channel.running and cnt < 45:
|
||||
cnt = cnt + 1
|
||||
time.sleep(0.1)
|
||||
if cnt == 45:
|
||||
# waiting for timeout (the POST query will be closed by wechat official server)
|
||||
time.sleep(1)
|
||||
# and do nothing
|
||||
return
|
||||
else:
|
||||
pass
|
||||
elif channel.query3.get(cache_key) == False:
|
||||
# The third query from wechat official server
|
||||
logger.debug("[wechatmp] query3 {}".format(cache_key))
|
||||
channel.query3[cache_key] = True
|
||||
cnt = 0
|
||||
while cache_key in channel.running and cnt < 40:
|
||||
cnt = cnt + 1
|
||||
time.sleep(0.1)
|
||||
if cnt == 40:
|
||||
# Have waiting for 3x5 seconds
|
||||
# return timeout message
|
||||
reply_text = "【正在思考中,回复任意文字尝试获取回复】"
|
||||
logger.info(
|
||||
"[wechatmp] Three queries has finished For {}: {}".format(
|
||||
from_user, message_id
|
||||
)
|
||||
)
|
||||
replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
|
||||
return replyPost
|
||||
else:
|
||||
pass
|
||||
|
||||
if (
|
||||
cache_key not in channel.cache_dict
|
||||
and cache_key not in channel.running
|
||||
):
|
||||
# no return because of bandwords or other reasons
|
||||
return "success"
|
||||
|
||||
# if float(time.time()) - float(query_time) > 4.8:
|
||||
# reply_text = "【正在思考中,回复任意文字尝试获取回复】"
|
||||
# logger.info("[wechatmp] Timeout for {} {}, return".format(from_user, message_id))
|
||||
# replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
|
||||
# return replyPost
|
||||
|
||||
if cache_key in channel.cache_dict:
|
||||
content = channel.cache_dict[cache_key]
|
||||
if len(content.encode("utf8")) <= MAX_UTF8_LEN:
|
||||
reply_text = channel.cache_dict[cache_key]
|
||||
channel.cache_dict.pop(cache_key)
|
||||
else:
|
||||
continue_text = "\n【未完待续,回复任意文字以继续】"
|
||||
splits = split_string_by_utf8_length(
|
||||
content,
|
||||
MAX_UTF8_LEN - len(continue_text.encode("utf-8")),
|
||||
max_split=1,
|
||||
)
|
||||
reply_text = splits[0] + continue_text
|
||||
channel.cache_dict[cache_key] = splits[1]
|
||||
logger.info(
|
||||
"[wechatmp] {}:{} Do send {}".format(
|
||||
web.ctx.env.get("REMOTE_ADDR"),
|
||||
web.ctx.env.get("REMOTE_PORT"),
|
||||
reply_text,
|
||||
)
|
||||
)
|
||||
replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
|
||||
return replyPost
|
||||
|
||||
elif wechatmp_msg.msg_type == "event":
|
||||
logger.info(
|
||||
"[wechatmp] Event {} from {}".format(
|
||||
wechatmp_msg.content, wechatmp_msg.from_user_id
|
||||
)
|
||||
)
|
||||
content = subscribe_msg()
|
||||
replyMsg = reply.TextMsg(
|
||||
wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content
|
||||
)
|
||||
return replyMsg.send()
|
||||
else:
|
||||
logger.info("暂且不处理")
|
||||
return "success"
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
return exc
|
||||
78
channel/wechatmp/active_reply.py
Normal file
78
channel/wechatmp/active_reply.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import time
|
||||
|
||||
import web
|
||||
from wechatpy import parse_message
|
||||
from wechatpy.replies import create_reply
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.wechatmp.common import *
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
from channel.wechatmp.wechatmp_message import WeChatMPMessage
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
# This class is instantiated once per query
|
||||
class Query:
|
||||
def GET(self):
|
||||
return verify_server(web.input())
|
||||
|
||||
def POST(self):
|
||||
# Make sure to return the instance that first created, @singleton will do that.
|
||||
try:
|
||||
args = web.input()
|
||||
verify_server(args)
|
||||
channel = WechatMPChannel()
|
||||
message = web.data()
|
||||
encrypt_func = lambda x: x
|
||||
if args.get("encrypt_type") == "aes":
|
||||
logger.debug("[wechatmp] Receive encrypted post data:\n" + message.decode("utf-8"))
|
||||
if not channel.crypto:
|
||||
raise Exception("Crypto not initialized, Please set wechatmp_aes_key in config.json")
|
||||
message = channel.crypto.decrypt_message(message, args.msg_signature, args.timestamp, args.nonce)
|
||||
encrypt_func = lambda x: channel.crypto.encrypt_message(x, args.nonce, args.timestamp)
|
||||
else:
|
||||
logger.debug("[wechatmp] Receive post data:\n" + message.decode("utf-8"))
|
||||
msg = parse_message(message)
|
||||
if msg.type in ["text", "voice", "image"]:
|
||||
wechatmp_msg = WeChatMPMessage(msg, client=channel.client)
|
||||
from_user = wechatmp_msg.from_user_id
|
||||
content = wechatmp_msg.content
|
||||
message_id = wechatmp_msg.msg_id
|
||||
|
||||
logger.info(
|
||||
"[wechatmp] {}:{} Receive post query {} {}: {}".format(
|
||||
web.ctx.env.get("REMOTE_ADDR"),
|
||||
web.ctx.env.get("REMOTE_PORT"),
|
||||
from_user,
|
||||
message_id,
|
||||
content,
|
||||
)
|
||||
)
|
||||
if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
|
||||
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
|
||||
else:
|
||||
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
|
||||
if context:
|
||||
# set private openai_api_key
|
||||
# if from_user is not changed in itchat, this can be placed at chat_channel
|
||||
user_data = conf().get_user_data(from_user)
|
||||
context["openai_api_key"] = user_data.get("openai_api_key") # None or user openai_api_key
|
||||
channel.produce(context)
|
||||
# The reply will be sent by channel.send() in another thread
|
||||
return "success"
|
||||
elif msg.type == "event":
|
||||
logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
|
||||
if msg.event in ["subscribe", "subscribe_scan"]:
|
||||
reply_text = subscribe_msg()
|
||||
replyPost = create_reply(reply_text, msg)
|
||||
return encrypt_func(replyPost.render())
|
||||
else:
|
||||
return "success"
|
||||
else:
|
||||
logger.info("暂且不处理")
|
||||
return "success"
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
return exc
|
||||
@@ -1,6 +1,10 @@
|
||||
import hashlib
|
||||
import textwrap
|
||||
|
||||
import web
|
||||
from wechatpy.crypto import WeChatCrypto
|
||||
from wechatpy.exceptions import InvalidSignatureException
|
||||
from wechatpy.utils import check_signature
|
||||
|
||||
from config import conf
|
||||
|
||||
MAX_UTF8_LEN = 2048
|
||||
@@ -12,27 +16,17 @@ class WeChatAPIException(Exception):
|
||||
|
||||
def verify_server(data):
|
||||
try:
|
||||
if len(data) == 0:
|
||||
return "None"
|
||||
signature = data.signature
|
||||
timestamp = data.timestamp
|
||||
nonce = data.nonce
|
||||
echostr = data.echostr
|
||||
echostr = data.get("echostr", None)
|
||||
token = conf().get("wechatmp_token") # 请按照公众平台官网\基本配置中信息填写
|
||||
|
||||
data_list = [token, timestamp, nonce]
|
||||
data_list.sort()
|
||||
sha1 = hashlib.sha1()
|
||||
# map(sha1.update, data_list) #python2
|
||||
sha1.update("".join(data_list).encode("utf-8"))
|
||||
hashcode = sha1.hexdigest()
|
||||
print("handle/GET func: hashcode, signature: ", hashcode, signature)
|
||||
if hashcode == signature:
|
||||
return echostr
|
||||
else:
|
||||
return ""
|
||||
except Exception as Argument:
|
||||
return Argument
|
||||
check_signature(token, signature, timestamp, nonce)
|
||||
return echostr
|
||||
except InvalidSignatureException:
|
||||
raise web.Forbidden("Invalid signature")
|
||||
except Exception as e:
|
||||
raise web.Forbidden(str(e))
|
||||
|
||||
|
||||
def subscribe_msg():
|
||||
@@ -42,10 +36,10 @@ def subscribe_msg():
|
||||
感谢您的关注!
|
||||
这里是ChatGPT,可以自由对话。
|
||||
资源有限,回复较慢,请勿着急。
|
||||
支持通用表情输入。
|
||||
暂时不支持图片输入。
|
||||
支持图片输出,画字开头的问题将回复图片链接。
|
||||
支持角色扮演和文字冒险两种定制模式对话。
|
||||
支持语音对话。
|
||||
支持图片输入。
|
||||
支持图片输出,画字开头的消息将按要求创作图片。
|
||||
支持tool、角色扮演和文字冒险等丰富的插件。
|
||||
输入'{trigger_prefix}#帮助' 查看详细指令。"""
|
||||
)
|
||||
return msg
|
||||
@@ -59,7 +53,7 @@ def split_string_by_utf8_length(string, max_length, max_split=0):
|
||||
if max_split > 0 and len(result) >= max_split:
|
||||
result.append(encoded[start:].decode("utf-8"))
|
||||
break
|
||||
end = start + max_length
|
||||
end = min(start + max_length, len(encoded))
|
||||
# 如果当前字节不是 UTF-8 编码的开始字节,则向前查找直到找到开始字节为止
|
||||
while end < len(encoded) and (encoded[end] & 0b11000000) == 0b10000000:
|
||||
end -= 1
|
||||
|
||||
212
channel/wechatmp/passive_reply.py
Normal file
212
channel/wechatmp/passive_reply.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import web
|
||||
from wechatpy import parse_message
|
||||
from wechatpy.replies import ImageReply, VoiceReply, create_reply
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.wechatmp.common import *
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
from channel.wechatmp.wechatmp_message import WeChatMPMessage
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
# This class is instantiated once per query
|
||||
class Query:
|
||||
def GET(self):
|
||||
return verify_server(web.input())
|
||||
|
||||
def POST(self):
|
||||
try:
|
||||
args = web.input()
|
||||
verify_server(args)
|
||||
request_time = time.time()
|
||||
channel = WechatMPChannel()
|
||||
message = web.data()
|
||||
encrypt_func = lambda x: x
|
||||
if args.get("encrypt_type") == "aes":
|
||||
logger.debug("[wechatmp] Receive encrypted post data:\n" + message.decode("utf-8"))
|
||||
if not channel.crypto:
|
||||
raise Exception("Crypto not initialized, Please set wechatmp_aes_key in config.json")
|
||||
message = channel.crypto.decrypt_message(message, args.msg_signature, args.timestamp, args.nonce)
|
||||
encrypt_func = lambda x: channel.crypto.encrypt_message(x, args.nonce, args.timestamp)
|
||||
else:
|
||||
logger.debug("[wechatmp] Receive post data:\n" + message.decode("utf-8"))
|
||||
msg = parse_message(message)
|
||||
if msg.type in ["text", "voice", "image"]:
|
||||
wechatmp_msg = WeChatMPMessage(msg, client=channel.client)
|
||||
from_user = wechatmp_msg.from_user_id
|
||||
content = wechatmp_msg.content
|
||||
message_id = wechatmp_msg.msg_id
|
||||
|
||||
supported = True
|
||||
if "【收到不支持的消息类型,暂无法显示】" in content:
|
||||
supported = False # not supported, used to refresh
|
||||
|
||||
# New request
|
||||
if (
|
||||
from_user not in channel.cache_dict
|
||||
and from_user not in channel.running
|
||||
or content.startswith("#")
|
||||
and message_id not in channel.request_cnt # insert the godcmd
|
||||
):
|
||||
# The first query begin
|
||||
if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
|
||||
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
|
||||
else:
|
||||
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
|
||||
logger.debug("[wechatmp] context: {} {} {}".format(context, wechatmp_msg, supported))
|
||||
|
||||
if supported and context:
|
||||
# set private openai_api_key
|
||||
# if from_user is not changed in itchat, this can be placed at chat_channel
|
||||
user_data = conf().get_user_data(from_user)
|
||||
context["openai_api_key"] = user_data.get("openai_api_key")
|
||||
channel.running.add(from_user)
|
||||
channel.produce(context)
|
||||
else:
|
||||
trigger_prefix = conf().get("single_chat_prefix", [""])[0]
|
||||
if trigger_prefix or not supported:
|
||||
if trigger_prefix:
|
||||
reply_text = textwrap.dedent(
|
||||
f"""\
|
||||
请输入'{trigger_prefix}'接你想说的话跟我说话。
|
||||
例如:
|
||||
{trigger_prefix}你好,很高兴见到你。"""
|
||||
)
|
||||
else:
|
||||
reply_text = textwrap.dedent(
|
||||
"""\
|
||||
你好,很高兴见到你。
|
||||
请跟我说话吧。"""
|
||||
)
|
||||
else:
|
||||
logger.error(f"[wechatmp] unknown error")
|
||||
reply_text = textwrap.dedent(
|
||||
"""\
|
||||
未知错误,请稍后再试"""
|
||||
)
|
||||
|
||||
replyPost = create_reply(reply_text, msg)
|
||||
return encrypt_func(replyPost.render())
|
||||
|
||||
# Wechat official server will request 3 times (5 seconds each), with the same message_id.
|
||||
# Because the interval is 5 seconds, here assumed that do not have multithreading problems.
|
||||
request_cnt = channel.request_cnt.get(message_id, 0) + 1
|
||||
channel.request_cnt[message_id] = request_cnt
|
||||
logger.info(
|
||||
"[wechatmp] Request {} from {} {} {}:{}\n{}".format(
|
||||
request_cnt, from_user, message_id, web.ctx.env.get("REMOTE_ADDR"), web.ctx.env.get("REMOTE_PORT"), content
|
||||
)
|
||||
)
|
||||
|
||||
task_running = True
|
||||
waiting_until = request_time + 4
|
||||
while time.time() < waiting_until:
|
||||
if from_user in channel.running:
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
task_running = False
|
||||
break
|
||||
|
||||
reply_text = ""
|
||||
if task_running:
|
||||
if request_cnt < 3:
|
||||
# waiting for timeout (the POST request will be closed by Wechat official server)
|
||||
time.sleep(2)
|
||||
# and do nothing, waiting for the next request
|
||||
return "success"
|
||||
else: # request_cnt == 3:
|
||||
# return timeout message
|
||||
reply_text = "【正在思考中,回复任意文字尝试获取回复】"
|
||||
replyPost = create_reply(reply_text, msg)
|
||||
return encrypt_func(replyPost.render())
|
||||
|
||||
# reply is ready
|
||||
channel.request_cnt.pop(message_id)
|
||||
|
||||
# no return because of bandwords or other reasons
|
||||
if from_user not in channel.cache_dict and from_user not in channel.running:
|
||||
return "success"
|
||||
|
||||
# Only one request can access to the cached data
|
||||
try:
|
||||
(reply_type, reply_content) = channel.cache_dict.pop(from_user)
|
||||
except KeyError:
|
||||
return "success"
|
||||
|
||||
if reply_type == "text":
|
||||
if len(reply_content.encode("utf8")) <= MAX_UTF8_LEN:
|
||||
reply_text = reply_content
|
||||
else:
|
||||
continue_text = "\n【未完待续,回复任意文字以继续】"
|
||||
splits = split_string_by_utf8_length(
|
||||
reply_content,
|
||||
MAX_UTF8_LEN - len(continue_text.encode("utf-8")),
|
||||
max_split=1,
|
||||
)
|
||||
reply_text = splits[0] + continue_text
|
||||
channel.cache_dict[from_user] = ("text", splits[1])
|
||||
|
||||
logger.info(
|
||||
"[wechatmp] Request {} do send to {} {}: {}\n{}".format(
|
||||
request_cnt,
|
||||
from_user,
|
||||
message_id,
|
||||
content,
|
||||
reply_text,
|
||||
)
|
||||
)
|
||||
replyPost = create_reply(reply_text, msg)
|
||||
return encrypt_func(replyPost.render())
|
||||
|
||||
elif reply_type == "voice":
|
||||
media_id = reply_content
|
||||
asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
|
||||
logger.info(
|
||||
"[wechatmp] Request {} do send to {} {}: {} voice media_id {}".format(
|
||||
request_cnt,
|
||||
from_user,
|
||||
message_id,
|
||||
content,
|
||||
media_id,
|
||||
)
|
||||
)
|
||||
replyPost = VoiceReply(message=msg)
|
||||
replyPost.media_id = media_id
|
||||
return encrypt_func(replyPost.render())
|
||||
|
||||
elif reply_type == "image":
|
||||
media_id = reply_content
|
||||
asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
|
||||
logger.info(
|
||||
"[wechatmp] Request {} do send to {} {}: {} image media_id {}".format(
|
||||
request_cnt,
|
||||
from_user,
|
||||
message_id,
|
||||
content,
|
||||
media_id,
|
||||
)
|
||||
)
|
||||
replyPost = ImageReply(message=msg)
|
||||
replyPost.media_id = media_id
|
||||
return encrypt_func(replyPost.render())
|
||||
|
||||
elif msg.type == "event":
|
||||
logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
|
||||
if msg.event in ["subscribe", "subscribe_scan"]:
|
||||
reply_text = subscribe_msg()
|
||||
replyPost = create_reply(reply_text, msg)
|
||||
return encrypt_func(replyPost.render())
|
||||
else:
|
||||
return "success"
|
||||
|
||||
else:
|
||||
logger.info("暂且不处理")
|
||||
return "success"
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
return exc
|
||||
@@ -1,47 +0,0 @@
|
||||
# -*- coding: utf-8 -*-#
|
||||
# filename: receive.py
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
|
||||
|
||||
def parse_xml(web_data):
|
||||
if len(web_data) == 0:
|
||||
return None
|
||||
xmlData = ET.fromstring(web_data)
|
||||
return WeChatMPMessage(xmlData)
|
||||
|
||||
|
||||
class WeChatMPMessage(ChatMessage):
|
||||
def __init__(self, xmlData):
|
||||
super().__init__(xmlData)
|
||||
self.to_user_id = xmlData.find("ToUserName").text
|
||||
self.from_user_id = xmlData.find("FromUserName").text
|
||||
self.create_time = xmlData.find("CreateTime").text
|
||||
self.msg_type = xmlData.find("MsgType").text
|
||||
try:
|
||||
self.msg_id = xmlData.find("MsgId").text
|
||||
except:
|
||||
self.msg_id = self.from_user_id + self.create_time
|
||||
self.is_group = False
|
||||
|
||||
# reply to other_user_id
|
||||
self.other_user_id = self.from_user_id
|
||||
|
||||
if self.msg_type == "text":
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = xmlData.find("Content").text.encode("utf-8")
|
||||
elif self.msg_type == "voice":
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = xmlData.find("Recognition").text.encode("utf-8") # 接收语音识别结果
|
||||
elif self.msg_type == "image":
|
||||
# not implemented
|
||||
self.pic_url = xmlData.find("PicUrl").text
|
||||
self.media_id = xmlData.find("MediaId").text
|
||||
elif self.msg_type == "event":
|
||||
self.content = xmlData.find("Event").text
|
||||
else: # video, shortvideo, location, link
|
||||
# not implemented
|
||||
pass
|
||||
@@ -1,55 +0,0 @@
|
||||
# -*- coding: utf-8 -*-#
|
||||
# filename: reply.py
|
||||
import time
|
||||
|
||||
|
||||
class Msg(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def send(self):
|
||||
return "success"
|
||||
|
||||
|
||||
class TextMsg(Msg):
|
||||
def __init__(self, toUserName, fromUserName, content):
|
||||
self.__dict = dict()
|
||||
self.__dict["ToUserName"] = toUserName
|
||||
self.__dict["FromUserName"] = fromUserName
|
||||
self.__dict["CreateTime"] = int(time.time())
|
||||
self.__dict["Content"] = content
|
||||
|
||||
def send(self):
|
||||
XmlForm = """
|
||||
<xml>
|
||||
<ToUserName><![CDATA[{ToUserName}]]></ToUserName>
|
||||
<FromUserName><![CDATA[{FromUserName}]]></FromUserName>
|
||||
<CreateTime>{CreateTime}</CreateTime>
|
||||
<MsgType><![CDATA[text]]></MsgType>
|
||||
<Content><![CDATA[{Content}]]></Content>
|
||||
</xml>
|
||||
"""
|
||||
return XmlForm.format(**self.__dict)
|
||||
|
||||
|
||||
class ImageMsg(Msg):
|
||||
def __init__(self, toUserName, fromUserName, mediaId):
|
||||
self.__dict = dict()
|
||||
self.__dict["ToUserName"] = toUserName
|
||||
self.__dict["FromUserName"] = fromUserName
|
||||
self.__dict["CreateTime"] = int(time.time())
|
||||
self.__dict["MediaId"] = mediaId
|
||||
|
||||
def send(self):
|
||||
XmlForm = """
|
||||
<xml>
|
||||
<ToUserName><![CDATA[{ToUserName}]]></ToUserName>
|
||||
<FromUserName><![CDATA[{FromUserName}]]></FromUserName>
|
||||
<CreateTime>{CreateTime}</CreateTime>
|
||||
<MsgType><![CDATA[image]]></MsgType>
|
||||
<Image>
|
||||
<MediaId><![CDATA[{MediaId}]]></MediaId>
|
||||
</Image>
|
||||
</xml>
|
||||
"""
|
||||
return XmlForm.format(**self.__dict)
|
||||
@@ -1,19 +1,25 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import json
|
||||
import asyncio
|
||||
import imghdr
|
||||
import io
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import requests
|
||||
import web
|
||||
from wechatpy.crypto import WeChatCrypto
|
||||
from wechatpy.exceptions import WeChatClientException
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.wechatmp.common import *
|
||||
from common.expired_dict import ExpiredDict
|
||||
from channel.wechatmp.wechatmp_client import WechatMPClient
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
from voice.audio_convert import any_to_mp3
|
||||
|
||||
# If using SSL, uncomment the following lines, and modify the certificate path.
|
||||
# from cheroot.server import HTTPServer
|
||||
@@ -28,111 +34,180 @@ class WechatMPChannel(ChatChannel):
|
||||
def __init__(self, passive_reply=True):
|
||||
super().__init__()
|
||||
self.passive_reply = passive_reply
|
||||
self.running = set()
|
||||
self.received_msgs = ExpiredDict(60 * 60 * 24)
|
||||
self.NOT_SUPPORT_REPLYTYPE = []
|
||||
appid = conf().get("wechatmp_app_id")
|
||||
secret = conf().get("wechatmp_app_secret")
|
||||
token = conf().get("wechatmp_token")
|
||||
aes_key = conf().get("wechatmp_aes_key")
|
||||
self.client = WechatMPClient(appid, secret)
|
||||
self.crypto = None
|
||||
if aes_key:
|
||||
self.crypto = WeChatCrypto(token, aes_key, appid)
|
||||
if self.passive_reply:
|
||||
self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE]
|
||||
# Cache the reply to the user's first message
|
||||
self.cache_dict = dict()
|
||||
self.query1 = dict()
|
||||
self.query2 = dict()
|
||||
self.query3 = dict()
|
||||
else:
|
||||
# TODO support image
|
||||
self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE]
|
||||
self.app_id = conf().get("wechatmp_app_id")
|
||||
self.app_secret = conf().get("wechatmp_app_secret")
|
||||
self.access_token = None
|
||||
self.access_token_expires_time = 0
|
||||
self.access_token_lock = threading.Lock()
|
||||
self.get_access_token()
|
||||
# Record whether the current message is being processed
|
||||
self.running = set()
|
||||
# Count the request from wechat official server by message_id
|
||||
self.request_cnt = dict()
|
||||
# The permanent media need to be deleted to avoid media number limit
|
||||
self.delete_media_loop = asyncio.new_event_loop()
|
||||
t = threading.Thread(target=self.start_loop, args=(self.delete_media_loop,))
|
||||
t.setDaemon(True)
|
||||
t.start()
|
||||
|
||||
def startup(self):
|
||||
if self.passive_reply:
|
||||
urls = ("/wx", "channel.wechatmp.SubscribeAccount.Query")
|
||||
urls = ("/wx", "channel.wechatmp.passive_reply.Query")
|
||||
else:
|
||||
urls = ("/wx", "channel.wechatmp.ServiceAccount.Query")
|
||||
urls = ("/wx", "channel.wechatmp.active_reply.Query")
|
||||
app = web.application(urls, globals(), autoreload=False)
|
||||
port = conf().get("wechatmp_port", 8080)
|
||||
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
|
||||
|
||||
def wechatmp_request(self, method, url, **kwargs):
|
||||
r = requests.request(method=method, url=url, **kwargs)
|
||||
r.raise_for_status()
|
||||
r.encoding = "utf-8"
|
||||
ret = r.json()
|
||||
if "errcode" in ret and ret["errcode"] != 0:
|
||||
raise WeChatAPIException("{}".format(ret))
|
||||
return ret
|
||||
def start_loop(self, loop):
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_forever()
|
||||
|
||||
def get_access_token(self):
|
||||
# return the access_token
|
||||
if self.access_token:
|
||||
if self.access_token_expires_time - time.time() > 60:
|
||||
return self.access_token
|
||||
|
||||
# Get new access_token
|
||||
# Do not request access_token in parallel! Only the last obtained is valid.
|
||||
if self.access_token_lock.acquire(blocking=False):
|
||||
# Wait for other threads that have previously obtained access_token to complete the request
|
||||
# This happens every 2 hours, so it doesn't affect the experience very much
|
||||
time.sleep(1)
|
||||
self.access_token = None
|
||||
url = "https://api.weixin.qq.com/cgi-bin/token"
|
||||
params = {
|
||||
"grant_type": "client_credential",
|
||||
"appid": self.app_id,
|
||||
"secret": self.app_secret,
|
||||
}
|
||||
data = self.wechatmp_request(method="get", url=url, params=params)
|
||||
self.access_token = data["access_token"]
|
||||
self.access_token_expires_time = int(time.time()) + data["expires_in"]
|
||||
logger.info("[wechatmp] access_token: {}".format(self.access_token))
|
||||
self.access_token_lock.release()
|
||||
else:
|
||||
# Wait for token update
|
||||
while self.access_token_lock.locked():
|
||||
time.sleep(0.1)
|
||||
return self.access_token
|
||||
async def delete_media(self, media_id):
|
||||
logger.debug("[wechatmp] permanent media {} will be deleted in 10s".format(media_id))
|
||||
await asyncio.sleep(10)
|
||||
self.client.material.delete(media_id)
|
||||
logger.info("[wechatmp] permanent media {} has been deleted".format(media_id))
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
receiver = context["receiver"]
|
||||
if self.passive_reply:
|
||||
receiver = context["receiver"]
|
||||
self.cache_dict[receiver] = reply.content
|
||||
logger.info("[send] reply to {} saved to cache: {}".format(receiver, reply))
|
||||
if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
|
||||
reply_text = reply.content
|
||||
logger.info("[wechatmp] text cached, receiver {}\n{}".format(receiver, reply_text))
|
||||
self.cache_dict[receiver] = ("text", reply_text)
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
try:
|
||||
voice_file_path = reply.content
|
||||
with open(voice_file_path, "rb") as f:
|
||||
# support: <2M, <60s, mp3/wma/wav/amr
|
||||
response = self.client.material.add("voice", f)
|
||||
logger.debug("[wechatmp] upload voice response: {}".format(response))
|
||||
# 根据文件大小估计一个微信自动审核的时间,审核结束前返回将会导致语音无法播放,这个估计有待验证
|
||||
f_size = os.fstat(f.fileno()).st_size
|
||||
time.sleep(1.0 + 2 * f_size / 1024 / 1024)
|
||||
# todo check media_id
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload voice failed: {}".format(e))
|
||||
return
|
||||
media_id = response["media_id"]
|
||||
logger.info("[wechatmp] voice uploaded, receiver {}, media_id {}".format(receiver, media_id))
|
||||
self.cache_dict[receiver] = ("voice", media_id)
|
||||
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
image_storage.seek(0)
|
||||
image_type = imghdr.what(image_storage)
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
|
||||
content_type = "image/" + image_type
|
||||
try:
|
||||
response = self.client.material.add("image", (filename, image_storage, content_type))
|
||||
logger.debug("[wechatmp] upload image response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload image failed: {}".format(e))
|
||||
return
|
||||
media_id = response["media_id"]
|
||||
logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id))
|
||||
self.cache_dict[receiver] = ("image", media_id)
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
image_type = imghdr.what(image_storage)
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
|
||||
content_type = "image/" + image_type
|
||||
try:
|
||||
response = self.client.material.add("image", (filename, image_storage, content_type))
|
||||
logger.debug("[wechatmp] upload image response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload image failed: {}".format(e))
|
||||
return
|
||||
media_id = response["media_id"]
|
||||
logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id))
|
||||
self.cache_dict[receiver] = ("image", media_id)
|
||||
else:
|
||||
receiver = context["receiver"]
|
||||
reply_text = reply.content
|
||||
url = "https://api.weixin.qq.com/cgi-bin/message/custom/send"
|
||||
params = {"access_token": self.get_access_token()}
|
||||
json_data = {
|
||||
"touser": receiver,
|
||||
"msgtype": "text",
|
||||
"text": {"content": reply_text},
|
||||
}
|
||||
self.wechatmp_request(
|
||||
method="post",
|
||||
url=url,
|
||||
params=params,
|
||||
data=json.dumps(json_data, ensure_ascii=False).encode("utf8"),
|
||||
)
|
||||
logger.info("[send] Do send to {}: {}".format(receiver, reply_text))
|
||||
if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
|
||||
reply_text = reply.content
|
||||
texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN)
|
||||
if len(texts) > 1:
|
||||
logger.info("[wechatmp] text too long, split into {} parts".format(len(texts)))
|
||||
for text in texts:
|
||||
self.client.message.send_text(receiver, text)
|
||||
logger.info("[wechatmp] Do send text to {}: {}".format(receiver, reply_text))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
try:
|
||||
file_path = reply.content
|
||||
file_name = os.path.basename(file_path)
|
||||
file_type = os.path.splitext(file_name)[1]
|
||||
if file_type == ".mp3":
|
||||
file_type = "audio/mpeg"
|
||||
elif file_type == ".amr":
|
||||
file_type = "audio/amr"
|
||||
else:
|
||||
mp3_file = os.path.splitext(file_path)[0] + ".mp3"
|
||||
any_to_mp3(file_path, mp3_file)
|
||||
file_path = mp3_file
|
||||
file_name = os.path.basename(file_path)
|
||||
file_type = "audio/mpeg"
|
||||
logger.info("[wechatmp] file_name: {}, file_type: {} ".format(file_name, file_type))
|
||||
# support: <2M, <60s, AMR\MP3
|
||||
response = self.client.media.upload("voice", (file_name, open(file_path, "rb"), file_type))
|
||||
logger.debug("[wechatmp] upload voice response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload voice failed: {}".format(e))
|
||||
return
|
||||
self.client.message.send_voice(receiver, response["media_id"])
|
||||
logger.info("[wechatmp] Do send voice to {}".format(receiver))
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
image_storage.seek(0)
|
||||
image_type = imghdr.what(image_storage)
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
|
||||
content_type = "image/" + image_type
|
||||
try:
|
||||
response = self.client.media.upload("image", (filename, image_storage, content_type))
|
||||
logger.debug("[wechatmp] upload image response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload image failed: {}".format(e))
|
||||
return
|
||||
self.client.message.send_image(receiver, response["media_id"])
|
||||
logger.info("[wechatmp] Do send image to {}".format(receiver))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
image_type = imghdr.what(image_storage)
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
|
||||
content_type = "image/" + image_type
|
||||
try:
|
||||
response = self.client.media.upload("image", (filename, image_storage, content_type))
|
||||
logger.debug("[wechatmp] upload image response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload image failed: {}".format(e))
|
||||
return
|
||||
self.client.message.send_image(receiver, response["media_id"])
|
||||
logger.info("[wechatmp] Do send image to {}".format(receiver))
|
||||
return
|
||||
|
||||
def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数
|
||||
logger.debug(
|
||||
"[wechatmp] Success to generate reply, msgId={}".format(
|
||||
context["msg"].msg_id
|
||||
)
|
||||
)
|
||||
logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context["msg"].msg_id))
|
||||
if self.passive_reply:
|
||||
self.running.remove(session_id)
|
||||
|
||||
def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数
|
||||
logger.exception(
|
||||
"[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(
|
||||
context["msg"].msg_id, exception
|
||||
)
|
||||
)
|
||||
logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context["msg"].msg_id, exception))
|
||||
if self.passive_reply:
|
||||
assert session_id not in self.cache_dict
|
||||
self.running.remove(session_id)
|
||||
|
||||
40
channel/wechatmp/wechatmp_client.py
Normal file
40
channel/wechatmp/wechatmp_client.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import threading
|
||||
import time
|
||||
|
||||
from wechatpy.client import WeChatClient
|
||||
from wechatpy.exceptions import APILimitedException
|
||||
|
||||
from channel.wechatmp.common import *
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class WechatMPClient(WeChatClient):
|
||||
def __init__(self, appid, secret, access_token=None, session=None, timeout=None, auto_retry=True):
|
||||
super(WechatMPClient, self).__init__(appid, secret, access_token, session, timeout, auto_retry)
|
||||
self.fetch_access_token_lock = threading.Lock()
|
||||
|
||||
def clear_quota(self):
|
||||
return self.post("clear_quota", data={"appid": self.appid})
|
||||
|
||||
def clear_quota_v2(self):
|
||||
return self.post("clear_quota/v2", params={"appid": self.appid, "appsecret": self.secret})
|
||||
|
||||
def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复获取access_token
|
||||
with self.fetch_access_token_lock:
|
||||
access_token = self.session.get(self.access_token_key)
|
||||
if access_token:
|
||||
if not self.expires_at:
|
||||
return access_token
|
||||
timestamp = time.time()
|
||||
if self.expires_at - timestamp > 60:
|
||||
return access_token
|
||||
return super().fetch_access_token()
|
||||
|
||||
def _request(self, method, url_or_endpoint, **kwargs): # 重载父类方法,遇到API限流时,清除quota后重试
|
||||
try:
|
||||
return super()._request(method, url_or_endpoint, **kwargs)
|
||||
except APILimitedException as e:
|
||||
logger.error("[wechatmp] API quata has been used up. {}".format(e))
|
||||
response = self.clear_quota_v2()
|
||||
logger.debug("[wechatmp] API quata has been cleard, {}".format(response))
|
||||
return super()._request(method, url_or_endpoint, **kwargs)
|
||||
56
channel/wechatmp/wechatmp_message.py
Normal file
56
channel/wechatmp/wechatmp_message.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# -*- coding: utf-8 -*-#
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
|
||||
|
||||
class WeChatMPMessage(ChatMessage):
|
||||
def __init__(self, msg, client=None):
|
||||
super().__init__(msg)
|
||||
self.msg_id = msg.id
|
||||
self.create_time = msg.time
|
||||
self.is_group = False
|
||||
|
||||
if msg.type == "text":
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = msg.content
|
||||
elif msg.type == "voice":
|
||||
if msg.recognition == None:
|
||||
self.ctype = ContextType.VOICE
|
||||
self.content = TmpDir().path() + msg.media_id + "." + msg.format # content直接存临时目录路径
|
||||
|
||||
def download_voice():
|
||||
# 如果响应状态码是200,则将响应内容写入本地文件
|
||||
response = client.media.download(msg.media_id)
|
||||
if response.status_code == 200:
|
||||
with open(self.content, "wb") as f:
|
||||
f.write(response.content)
|
||||
else:
|
||||
logger.info(f"[wechatmp] Failed to download voice file, {response.content}")
|
||||
|
||||
self._prepare_fn = download_voice
|
||||
else:
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = msg.recognition
|
||||
elif msg.type == "image":
|
||||
self.ctype = ContextType.IMAGE
|
||||
self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径
|
||||
|
||||
def download_image():
|
||||
# 如果响应状态码是200,则将响应内容写入本地文件
|
||||
response = client.media.download(msg.media_id)
|
||||
if response.status_code == 200:
|
||||
with open(self.content, "wb") as f:
|
||||
f.write(response.content)
|
||||
else:
|
||||
logger.info(f"[wechatmp] Failed to download image file, {response.content}")
|
||||
|
||||
self._prepare_fn = download_image
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type))
|
||||
|
||||
self.from_user_id = msg.source
|
||||
self.to_user_id = msg.target
|
||||
self.other_user_id = msg.source
|
||||
@@ -13,23 +13,15 @@ def time_checker(f):
|
||||
if chat_time_module:
|
||||
chat_start_time = _config.get("chat_start_time", "00:00")
|
||||
chat_stopt_time = _config.get("chat_stop_time", "24:00")
|
||||
time_regex = re.compile(
|
||||
r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$"
|
||||
) # 时间匹配,包含24:00
|
||||
time_regex = re.compile(r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$") # 时间匹配,包含24:00
|
||||
|
||||
starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式
|
||||
stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式
|
||||
chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间
|
||||
|
||||
# 时间格式检查
|
||||
if not (
|
||||
starttime_format_check and stoptime_format_check and chat_time_check
|
||||
):
|
||||
logger.warn(
|
||||
"时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(
|
||||
starttime_format_check, stoptime_format_check
|
||||
)
|
||||
)
|
||||
if not (starttime_format_check and stoptime_format_check and chat_time_check):
|
||||
logger.warn("时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(starttime_format_check, stoptime_format_check))
|
||||
if chat_start_time > "23:59":
|
||||
logger.error("启动时间可能存在问题,请修改!")
|
||||
|
||||
|
||||
@@ -73,8 +73,9 @@ available_setting = {
|
||||
# wechatmp的配置
|
||||
"wechatmp_token": "", # 微信公众平台的Token
|
||||
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
|
||||
"wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要
|
||||
"wechatmp_app_secret": "", # 微信公众平台的appsecret,仅服务号需要
|
||||
"wechatmp_app_id": "", # 微信公众平台的appID
|
||||
"wechatmp_app_secret": "", # 微信公众平台的appsecret
|
||||
"wechatmp_aes_key": "", # 微信公众平台的EncodingAESKey,加密模式需要
|
||||
# chatgpt指令自定义触发词
|
||||
"clear_memory_commands": ["#清除记忆"], # 重置会话指令,必须以#开头
|
||||
# channel配置
|
||||
@@ -157,9 +158,7 @@ def load_config():
|
||||
for name, value in os.environ.items():
|
||||
name = name.lower()
|
||||
if name in available_setting:
|
||||
logger.info(
|
||||
"[INIT] override config by environ args: {}={}".format(name, value)
|
||||
)
|
||||
logger.info("[INIT] override config by environ args: {}={}".format(name, value))
|
||||
try:
|
||||
config[name] = eval(value)
|
||||
except:
|
||||
|
||||
@@ -50,9 +50,7 @@ class Banwords(Plugin):
|
||||
self.reply_action = conf.get("reply_action", "ignore")
|
||||
logger.info("[Banwords] inited")
|
||||
except Exception as e:
|
||||
logger.warn(
|
||||
"[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords ."
|
||||
)
|
||||
logger.warn("[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords .")
|
||||
raise e
|
||||
|
||||
def on_handle_context(self, e_context: EventContext):
|
||||
@@ -72,9 +70,7 @@ class Banwords(Plugin):
|
||||
return
|
||||
elif self.action == "replace":
|
||||
if self.searchr.ContainsAny(content):
|
||||
reply = Reply(
|
||||
ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content)
|
||||
)
|
||||
reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content))
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
@@ -94,12 +90,10 @@ class Banwords(Plugin):
|
||||
return
|
||||
elif self.reply_action == "replace":
|
||||
if self.searchr.ContainsAny(content):
|
||||
reply = Reply(
|
||||
ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content)
|
||||
)
|
||||
reply = Reply(ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content))
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.CONTINUE
|
||||
return
|
||||
|
||||
def get_help_text(self, **kwargs):
|
||||
return Banwords.desc
|
||||
return "过滤消息中的敏感词。"
|
||||
|
||||
250
plugins/banwords/lib/WordsSearch.py
Normal file
250
plugins/banwords/lib/WordsSearch.py
Normal file
@@ -0,0 +1,250 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
# ToolGood.Words.WordsSearch.py
|
||||
# 2020, Lin Zhijun, https://github.com/toolgood/ToolGood.Words
|
||||
# Licensed under the Apache License 2.0
|
||||
# 更新日志
|
||||
# 2020.04.06 第一次提交
|
||||
# 2020.05.16 修改,支持大于0xffff的字符
|
||||
|
||||
__all__ = ['WordsSearch']
|
||||
__author__ = 'Lin Zhijun'
|
||||
__date__ = '2020.05.16'
|
||||
|
||||
class TrieNode():
|
||||
def __init__(self):
|
||||
self.Index = 0
|
||||
self.Index = 0
|
||||
self.Layer = 0
|
||||
self.End = False
|
||||
self.Char = ''
|
||||
self.Results = []
|
||||
self.m_values = {}
|
||||
self.Failure = None
|
||||
self.Parent = None
|
||||
|
||||
def Add(self,c):
|
||||
if c in self.m_values :
|
||||
return self.m_values[c]
|
||||
node = TrieNode()
|
||||
node.Parent = self
|
||||
node.Char = c
|
||||
self.m_values[c] = node
|
||||
return node
|
||||
|
||||
def SetResults(self,index):
|
||||
if (self.End == False):
|
||||
self.End = True
|
||||
self.Results.append(index)
|
||||
|
||||
class TrieNode2():
|
||||
def __init__(self):
|
||||
self.End = False
|
||||
self.Results = []
|
||||
self.m_values = {}
|
||||
self.minflag = 0xffff
|
||||
self.maxflag = 0
|
||||
|
||||
def Add(self,c,node3):
|
||||
if (self.minflag > c):
|
||||
self.minflag = c
|
||||
if (self.maxflag < c):
|
||||
self.maxflag = c
|
||||
self.m_values[c] = node3
|
||||
|
||||
def SetResults(self,index):
|
||||
if (self.End == False) :
|
||||
self.End = True
|
||||
if (index in self.Results )==False :
|
||||
self.Results.append(index)
|
||||
|
||||
def HasKey(self,c):
|
||||
return c in self.m_values
|
||||
|
||||
|
||||
def TryGetValue(self,c):
|
||||
if (self.minflag <= c and self.maxflag >= c):
|
||||
if c in self.m_values:
|
||||
return self.m_values[c]
|
||||
return None
|
||||
|
||||
|
||||
class WordsSearch():
|
||||
def __init__(self):
|
||||
self._first = {}
|
||||
self._keywords = []
|
||||
self._indexs=[]
|
||||
|
||||
def SetKeywords(self,keywords):
|
||||
self._keywords = keywords
|
||||
self._indexs=[]
|
||||
for i in range(len(keywords)):
|
||||
self._indexs.append(i)
|
||||
|
||||
root = TrieNode()
|
||||
allNodeLayer={}
|
||||
|
||||
for i in range(len(self._keywords)): # for (i = 0; i < _keywords.length; i++)
|
||||
p = self._keywords[i]
|
||||
nd = root
|
||||
for j in range(len(p)): # for (j = 0; j < p.length; j++)
|
||||
nd = nd.Add(ord(p[j]))
|
||||
if (nd.Layer == 0):
|
||||
nd.Layer = j + 1
|
||||
if nd.Layer in allNodeLayer:
|
||||
allNodeLayer[nd.Layer].append(nd)
|
||||
else:
|
||||
allNodeLayer[nd.Layer]=[]
|
||||
allNodeLayer[nd.Layer].append(nd)
|
||||
nd.SetResults(i)
|
||||
|
||||
|
||||
allNode = []
|
||||
allNode.append(root)
|
||||
for key in allNodeLayer.keys():
|
||||
for nd in allNodeLayer[key]:
|
||||
allNode.append(nd)
|
||||
allNodeLayer=None
|
||||
|
||||
for i in range(len(allNode)): # for (i = 0; i < allNode.length; i++)
|
||||
if i==0 :
|
||||
continue
|
||||
nd=allNode[i]
|
||||
nd.Index = i
|
||||
r = nd.Parent.Failure
|
||||
c = nd.Char
|
||||
while (r != None and (c in r.m_values)==False):
|
||||
r = r.Failure
|
||||
if (r == None):
|
||||
nd.Failure = root
|
||||
else:
|
||||
nd.Failure = r.m_values[c]
|
||||
for key2 in nd.Failure.Results :
|
||||
nd.SetResults(key2)
|
||||
root.Failure = root
|
||||
|
||||
allNode2 = []
|
||||
for i in range(len(allNode)): # for (i = 0; i < allNode.length; i++)
|
||||
allNode2.append( TrieNode2())
|
||||
|
||||
for i in range(len(allNode2)): # for (i = 0; i < allNode2.length; i++)
|
||||
oldNode = allNode[i]
|
||||
newNode = allNode2[i]
|
||||
|
||||
for key in oldNode.m_values :
|
||||
index = oldNode.m_values[key].Index
|
||||
newNode.Add(key, allNode2[index])
|
||||
|
||||
for index in range(len(oldNode.Results)): # for (index = 0; index < oldNode.Results.length; index++)
|
||||
item = oldNode.Results[index]
|
||||
newNode.SetResults(item)
|
||||
|
||||
oldNode=oldNode.Failure
|
||||
while oldNode != root:
|
||||
for key in oldNode.m_values :
|
||||
if (newNode.HasKey(key) == False):
|
||||
index = oldNode.m_values[key].Index
|
||||
newNode.Add(key, allNode2[index])
|
||||
for index in range(len(oldNode.Results)):
|
||||
item = oldNode.Results[index]
|
||||
newNode.SetResults(item)
|
||||
oldNode=oldNode.Failure
|
||||
allNode = None
|
||||
root = None
|
||||
|
||||
# first = []
|
||||
# for index in range(65535):# for (index = 0; index < 0xffff; index++)
|
||||
# first.append(None)
|
||||
|
||||
# for key in allNode2[0].m_values :
|
||||
# first[key] = allNode2[0].m_values[key]
|
||||
|
||||
self._first = allNode2[0]
|
||||
|
||||
|
||||
def FindFirst(self,text):
|
||||
ptr = None
|
||||
for index in range(len(text)): # for (index = 0; index < text.length; index++)
|
||||
t =ord(text[index]) # text.charCodeAt(index)
|
||||
tn = None
|
||||
if (ptr == None):
|
||||
tn = self._first.TryGetValue(t)
|
||||
else:
|
||||
tn = ptr.TryGetValue(t)
|
||||
if (tn==None):
|
||||
tn = self._first.TryGetValue(t)
|
||||
|
||||
|
||||
if (tn != None):
|
||||
if (tn.End):
|
||||
item = tn.Results[0]
|
||||
keyword = self._keywords[item]
|
||||
return { "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item] }
|
||||
ptr = tn
|
||||
return None
|
||||
|
||||
def FindAll(self,text):
|
||||
ptr = None
|
||||
list = []
|
||||
|
||||
for index in range(len(text)): # for (index = 0; index < text.length; index++)
|
||||
t =ord(text[index]) # text.charCodeAt(index)
|
||||
tn = None
|
||||
if (ptr == None):
|
||||
tn = self._first.TryGetValue(t)
|
||||
else:
|
||||
tn = ptr.TryGetValue(t)
|
||||
if (tn==None):
|
||||
tn = self._first.TryGetValue(t)
|
||||
|
||||
|
||||
if (tn != None):
|
||||
if (tn.End):
|
||||
for j in range(len(tn.Results)): # for (j = 0; j < tn.Results.length; j++)
|
||||
item = tn.Results[j]
|
||||
keyword = self._keywords[item]
|
||||
list.append({ "Keyword": keyword, "Success": True, "End": index, "Start": index + 1 - len(keyword), "Index": self._indexs[item] })
|
||||
ptr = tn
|
||||
return list
|
||||
|
||||
|
||||
def ContainsAny(self,text):
|
||||
ptr = None
|
||||
for index in range(len(text)): # for (index = 0; index < text.length; index++)
|
||||
t =ord(text[index]) # text.charCodeAt(index)
|
||||
tn = None
|
||||
if (ptr == None):
|
||||
tn = self._first.TryGetValue(t)
|
||||
else:
|
||||
tn = ptr.TryGetValue(t)
|
||||
if (tn==None):
|
||||
tn = self._first.TryGetValue(t)
|
||||
|
||||
if (tn != None):
|
||||
if (tn.End):
|
||||
return True
|
||||
ptr = tn
|
||||
return False
|
||||
|
||||
def Replace(self,text, replaceChar = '*'):
|
||||
result = list(text)
|
||||
|
||||
ptr = None
|
||||
for i in range(len(text)): # for (i = 0; i < text.length; i++)
|
||||
t =ord(text[i]) # text.charCodeAt(index)
|
||||
tn = None
|
||||
if (ptr == None):
|
||||
tn = self._first.TryGetValue(t)
|
||||
else:
|
||||
tn = ptr.TryGetValue(t)
|
||||
if (tn==None):
|
||||
tn = self._first.TryGetValue(t)
|
||||
|
||||
if (tn != None):
|
||||
if (tn.End):
|
||||
maxLength = len( self._keywords[tn.Results[0]])
|
||||
start = i + 1 - maxLength
|
||||
for j in range(start,i+1): # for (j = start; j <= i; j++)
|
||||
result[j] = replaceChar
|
||||
ptr = tn
|
||||
return ''.join(result)
|
||||
@@ -76,9 +76,7 @@ class BDunit(Plugin):
|
||||
Returns:
|
||||
string: access_token
|
||||
"""
|
||||
url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(
|
||||
self.api_key, self.secret_key
|
||||
)
|
||||
url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(self.api_key, self.secret_key)
|
||||
payload = ""
|
||||
headers = {"Content-Type": "application/json", "Accept": "application/json"}
|
||||
|
||||
@@ -94,10 +92,7 @@ class BDunit(Plugin):
|
||||
:returns: UNIT 解析结果。如果解析失败,返回 None
|
||||
"""
|
||||
|
||||
url = (
|
||||
"https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token="
|
||||
+ self.access_token
|
||||
)
|
||||
url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + self.access_token
|
||||
request = {
|
||||
"query": query,
|
||||
"user_id": str(get_mac())[:32],
|
||||
@@ -124,10 +119,7 @@ class BDunit(Plugin):
|
||||
:param query: 用户的指令字符串
|
||||
:returns: UNIT 解析结果。如果解析失败,返回 None
|
||||
"""
|
||||
url = (
|
||||
"https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token="
|
||||
+ self.access_token
|
||||
)
|
||||
url = "https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token=" + self.access_token
|
||||
request = {"query": query, "user_id": str(get_mac())[:32]}
|
||||
body = {
|
||||
"log_id": str(uuid.uuid1()),
|
||||
@@ -170,11 +162,7 @@ class BDunit(Plugin):
|
||||
if parsed and "result" in parsed and "response_list" in parsed["result"]:
|
||||
response_list = parsed["result"]["response_list"]
|
||||
for response in response_list:
|
||||
if (
|
||||
"schema" in response
|
||||
and "intent" in response["schema"]
|
||||
and response["schema"]["intent"] == intent
|
||||
):
|
||||
if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
|
||||
return True
|
||||
return False
|
||||
else:
|
||||
@@ -198,12 +186,7 @@ class BDunit(Plugin):
|
||||
logger.warning(e)
|
||||
return []
|
||||
for response in response_list:
|
||||
if (
|
||||
"schema" in response
|
||||
and "intent" in response["schema"]
|
||||
and "slots" in response["schema"]
|
||||
and response["schema"]["intent"] == intent
|
||||
):
|
||||
if "schema" in response and "intent" in response["schema"] and "slots" in response["schema"] and response["schema"]["intent"] == intent:
|
||||
return response["schema"]["slots"]
|
||||
return []
|
||||
else:
|
||||
@@ -239,11 +222,7 @@ class BDunit(Plugin):
|
||||
if (
|
||||
"schema" in response
|
||||
and "intent_confidence" in response["schema"]
|
||||
and (
|
||||
not answer
|
||||
or response["schema"]["intent_confidence"]
|
||||
> answer["schema"]["intent_confidence"]
|
||||
)
|
||||
and (not answer or response["schema"]["intent_confidence"] > answer["schema"]["intent_confidence"])
|
||||
):
|
||||
answer = response
|
||||
return answer["action_list"][0]["say"]
|
||||
@@ -267,11 +246,7 @@ class BDunit(Plugin):
|
||||
logger.warning(e)
|
||||
return ""
|
||||
for response in response_list:
|
||||
if (
|
||||
"schema" in response
|
||||
and "intent" in response["schema"]
|
||||
and response["schema"]["intent"] == intent
|
||||
):
|
||||
if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
|
||||
try:
|
||||
return response["action_list"][0]["say"]
|
||||
except Exception as e:
|
||||
|
||||
@@ -84,9 +84,7 @@ class Dungeon(Plugin):
|
||||
if len(clist) > 1:
|
||||
story = clist[1]
|
||||
else:
|
||||
story = (
|
||||
"你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。"
|
||||
)
|
||||
story = "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。"
|
||||
self.games[sessionid] = StoryTeller(bot, sessionid, story)
|
||||
reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story)
|
||||
e_context["reply"] = reply
|
||||
@@ -102,11 +100,7 @@ class Dungeon(Plugin):
|
||||
if kwargs.get("verbose") != True:
|
||||
return help_text
|
||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||
help_text = (
|
||||
f"{trigger_prefix}开始冒险 "
|
||||
+ "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n"
|
||||
+ f"{trigger_prefix}停止冒险: 结束游戏。\n"
|
||||
)
|
||||
help_text = f"{trigger_prefix}开始冒险 " + "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n" + f"{trigger_prefix}停止冒险: 结束游戏。\n"
|
||||
if kwargs.get("verbose") == True:
|
||||
help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'"
|
||||
return help_text
|
||||
|
||||
@@ -140,9 +140,7 @@ def get_help_text(isadmin, isgroup):
|
||||
if plugins[plugin].enabled and not plugins[plugin].hidden:
|
||||
namecn = plugins[plugin].namecn
|
||||
help_text += "\n%s:" % namecn
|
||||
help_text += (
|
||||
PluginManager().instances[plugin].get_help_text(verbose=False).strip()
|
||||
)
|
||||
help_text += PluginManager().instances[plugin].get_help_text(verbose=False).strip()
|
||||
|
||||
if ADMIN_COMMANDS and isadmin:
|
||||
help_text += "\n\n管理员指令:\n"
|
||||
@@ -191,9 +189,7 @@ class Godcmd(Plugin):
|
||||
COMMANDS["reset"]["alias"].append(custom_command)
|
||||
|
||||
self.password = gconf["password"]
|
||||
self.admin_users = gconf[
|
||||
"admin_users"
|
||||
] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用
|
||||
self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用
|
||||
self.isrunning = True # 机器人是否运行中
|
||||
|
||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
||||
@@ -209,6 +205,13 @@ class Godcmd(Plugin):
|
||||
content = e_context["context"].content
|
||||
logger.debug("[Godcmd] on_handle_context. content: %s" % content)
|
||||
if content.startswith("#"):
|
||||
if len(content) == 1:
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.ERROR
|
||||
reply.content = f"空指令,输入#help查看指令列表\n"
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
# msg = e_context['context']['msg']
|
||||
channel = e_context["channel"]
|
||||
user = e_context["context"]["receiver"]
|
||||
@@ -241,11 +244,7 @@ class Godcmd(Plugin):
|
||||
if not plugincls.enabled:
|
||||
continue
|
||||
if query_name == name or query_name == plugincls.namecn:
|
||||
ok, result = True, PluginManager().instances[
|
||||
name
|
||||
].get_help_text(
|
||||
isgroup=isgroup, isadmin=isadmin, verbose=True
|
||||
)
|
||||
ok, result = True, PluginManager().instances[name].get_help_text(isgroup=isgroup, isadmin=isadmin, verbose=True)
|
||||
break
|
||||
if not ok:
|
||||
result = "插件不存在或未启用"
|
||||
@@ -278,11 +277,7 @@ class Godcmd(Plugin):
|
||||
if isgroup:
|
||||
ok, result = False, "群聊不可执行管理员指令"
|
||||
else:
|
||||
cmd = next(
|
||||
c
|
||||
for c, info in ADMIN_COMMANDS.items()
|
||||
if cmd in info["alias"]
|
||||
)
|
||||
cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info["alias"])
|
||||
if cmd == "stop":
|
||||
self.isrunning = False
|
||||
ok, result = True, "服务已暂停"
|
||||
@@ -318,18 +313,14 @@ class Godcmd(Plugin):
|
||||
PluginManager().activate_plugins()
|
||||
if len(new_plugins) > 0:
|
||||
result += "\n发现新插件:\n"
|
||||
result += "\n".join(
|
||||
[f"{p.name}_v{p.version}" for p in new_plugins]
|
||||
)
|
||||
result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins])
|
||||
else:
|
||||
result += ", 未发现新插件"
|
||||
elif cmd == "setpri":
|
||||
if len(args) != 2:
|
||||
ok, result = False, "请提供插件名和优先级"
|
||||
else:
|
||||
ok = PluginManager().set_plugin_priority(
|
||||
args[0], int(args[1])
|
||||
)
|
||||
ok = PluginManager().set_plugin_priority(args[0], int(args[1]))
|
||||
if ok:
|
||||
result = "插件" + args[0] + "优先级已设置为" + args[1]
|
||||
else:
|
||||
|
||||
@@ -23,7 +23,25 @@ class Hello(Plugin):
|
||||
logger.info("[Hello] inited")
|
||||
|
||||
def on_handle_context(self, e_context: EventContext):
|
||||
if e_context["context"].type != ContextType.TEXT:
|
||||
if e_context["context"].type not in [
|
||||
ContextType.TEXT,
|
||||
ContextType.JOIN_GROUP,
|
||||
ContextType.PATPAT,
|
||||
]:
|
||||
return
|
||||
|
||||
if e_context["context"].type == ContextType.JOIN_GROUP:
|
||||
e_context["context"].type = ContextType.TEXT
|
||||
msg: ChatMessage = e_context["context"]["msg"]
|
||||
e_context["context"].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。'
|
||||
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
|
||||
return
|
||||
|
||||
if e_context["context"].type == ContextType.PATPAT:
|
||||
e_context["context"].type = ContextType.TEXT
|
||||
msg: ChatMessage = e_context["context"]["msg"]
|
||||
e_context["context"].content = f"请你随机使用一种风格介绍你自己,并告诉用户输入#help可以查看帮助信息。"
|
||||
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
|
||||
return
|
||||
|
||||
content = e_context["context"].content
|
||||
@@ -33,9 +51,7 @@ class Hello(Plugin):
|
||||
reply.type = ReplyType.TEXT
|
||||
msg: ChatMessage = e_context["context"]["msg"]
|
||||
if e_context["context"]["isgroup"]:
|
||||
reply.content = (
|
||||
f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
|
||||
)
|
||||
reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
|
||||
else:
|
||||
reply.content = f"Hello, {msg.from_user_nickname}"
|
||||
e_context["reply"] = reply
|
||||
|
||||
13
plugins/keyword/README.md
Normal file
13
plugins/keyword/README.md
Normal file
@@ -0,0 +1,13 @@
|
||||
# 目的
|
||||
关键字匹配并回复
|
||||
|
||||
# 试用场景
|
||||
目前是在微信公众号下面使用过。
|
||||
|
||||
# 使用步骤
|
||||
1. 复制 `config.json.template` 为 `config.json`
|
||||
2. 在关键字 `keyword` 新增需要关键字匹配的内容
|
||||
3. 重启程序做验证
|
||||
|
||||
# 验证结果
|
||||

|
||||
1
plugins/keyword/__init__.py
Normal file
1
plugins/keyword/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .keyword import *
|
||||
5
plugins/keyword/config.json.template
Normal file
5
plugins/keyword/config.json.template
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"keyword": {
|
||||
"关键字匹配": "测试成功"
|
||||
}
|
||||
}
|
||||
65
plugins/keyword/keyword.py
Normal file
65
plugins/keyword/keyword.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import plugins
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from plugins import *
|
||||
|
||||
|
||||
@plugins.register(
|
||||
name="Keyword",
|
||||
desire_priority=900,
|
||||
hidden=True,
|
||||
desc="关键词匹配过滤",
|
||||
version="0.1",
|
||||
author="fengyege.top",
|
||||
)
|
||||
class Keyword(Plugin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
try:
|
||||
curdir = os.path.dirname(__file__)
|
||||
config_path = os.path.join(curdir, "config.json")
|
||||
conf = None
|
||||
if not os.path.exists(config_path):
|
||||
logger.debug(f"[keyword]不存在配置文件{config_path}")
|
||||
conf = {"keyword": {}}
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(conf, f, indent=4)
|
||||
else:
|
||||
logger.debug(f"[keyword]加载配置文件{config_path}")
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
conf = json.load(f)
|
||||
# 加载关键词
|
||||
self.keyword = conf["keyword"]
|
||||
|
||||
logger.info("[keyword] {}".format(self.keyword))
|
||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
||||
logger.info("[keyword] inited.")
|
||||
except Exception as e:
|
||||
logger.warn("[keyword] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/keyword .")
|
||||
raise e
|
||||
|
||||
def on_handle_context(self, e_context: EventContext):
|
||||
if e_context["context"].type != ContextType.TEXT:
|
||||
return
|
||||
|
||||
content = e_context["context"].content.strip()
|
||||
logger.debug("[keyword] on_handle_context. content: %s" % content)
|
||||
if content in self.keyword:
|
||||
logger.debug(f"[keyword] 匹配到关键字【{content}】")
|
||||
reply_text = self.keyword[content]
|
||||
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.TEXT
|
||||
reply.content = reply_text
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
|
||||
|
||||
def get_help_text(self, **kwargs):
|
||||
help_text = "关键词过滤"
|
||||
return help_text
|
||||
BIN
plugins/keyword/test-keyword.png
Normal file
BIN
plugins/keyword/test-keyword.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 12 KiB |
@@ -31,23 +31,14 @@ class PluginManager:
|
||||
plugincls.desc = kwargs.get("desc")
|
||||
plugincls.author = kwargs.get("author")
|
||||
plugincls.path = self.current_plugin_path
|
||||
plugincls.version = (
|
||||
kwargs.get("version") if kwargs.get("version") != None else "1.0"
|
||||
)
|
||||
plugincls.namecn = (
|
||||
kwargs.get("namecn") if kwargs.get("namecn") != None else name
|
||||
)
|
||||
plugincls.hidden = (
|
||||
kwargs.get("hidden") if kwargs.get("hidden") != None else False
|
||||
)
|
||||
plugincls.version = kwargs.get("version") if kwargs.get("version") != None else "1.0"
|
||||
plugincls.namecn = kwargs.get("namecn") if kwargs.get("namecn") != None else name
|
||||
plugincls.hidden = kwargs.get("hidden") if kwargs.get("hidden") != None else False
|
||||
plugincls.enabled = True
|
||||
if self.current_plugin_path == None:
|
||||
raise Exception("Plugin path not set")
|
||||
self.plugins[name.upper()] = plugincls
|
||||
logger.info(
|
||||
"Plugin %s_v%s registered, path=%s"
|
||||
% (name, plugincls.version, plugincls.path)
|
||||
)
|
||||
logger.info("Plugin %s_v%s registered, path=%s" % (name, plugincls.version, plugincls.path))
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -62,9 +53,7 @@ class PluginManager:
|
||||
if os.path.exists("./plugins/plugins.json"):
|
||||
with open("./plugins/plugins.json", "r", encoding="utf-8") as f:
|
||||
pconf = json.load(f)
|
||||
pconf["plugins"] = SortedDict(
|
||||
lambda k, v: v["priority"], pconf["plugins"], reverse=True
|
||||
)
|
||||
pconf["plugins"] = SortedDict(lambda k, v: v["priority"], pconf["plugins"], reverse=True)
|
||||
else:
|
||||
modified = True
|
||||
pconf = {"plugins": SortedDict(lambda k, v: v["priority"], reverse=True)}
|
||||
@@ -90,26 +79,16 @@ class PluginManager:
|
||||
if plugin_path in self.loaded:
|
||||
if self.loaded[plugin_path] == None:
|
||||
logger.info("reload module %s" % plugin_name)
|
||||
self.loaded[plugin_path] = importlib.reload(
|
||||
sys.modules[import_path]
|
||||
)
|
||||
dependent_module_names = [
|
||||
name
|
||||
for name in sys.modules.keys()
|
||||
if name.startswith(import_path + ".")
|
||||
]
|
||||
self.loaded[plugin_path] = importlib.reload(sys.modules[import_path])
|
||||
dependent_module_names = [name for name in sys.modules.keys() if name.startswith(import_path + ".")]
|
||||
for name in dependent_module_names:
|
||||
logger.info("reload module %s" % name)
|
||||
importlib.reload(sys.modules[name])
|
||||
else:
|
||||
self.loaded[plugin_path] = importlib.import_module(
|
||||
import_path
|
||||
)
|
||||
self.loaded[plugin_path] = importlib.import_module(import_path)
|
||||
self.current_plugin_path = None
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to import plugin %s: %s" % (plugin_name, e)
|
||||
)
|
||||
logger.exception("Failed to import plugin %s: %s" % (plugin_name, e))
|
||||
continue
|
||||
pconf = self.pconf
|
||||
news = [self.plugins[name] for name in self.plugins]
|
||||
@@ -119,9 +98,7 @@ class PluginManager:
|
||||
rawname = plugincls.name
|
||||
if rawname not in pconf["plugins"]:
|
||||
modified = True
|
||||
logger.info(
|
||||
"Plugin %s not found in pconfig, adding to pconfig..." % name
|
||||
)
|
||||
logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name)
|
||||
pconf["plugins"][rawname] = {
|
||||
"enabled": plugincls.enabled,
|
||||
"priority": plugincls.priority,
|
||||
@@ -136,9 +113,7 @@ class PluginManager:
|
||||
|
||||
def refresh_order(self):
|
||||
for event in self.listening_plugins.keys():
|
||||
self.listening_plugins[event].sort(
|
||||
key=lambda name: self.plugins[name].priority, reverse=True
|
||||
)
|
||||
self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True)
|
||||
|
||||
def activate_plugins(self): # 生成新开启的插件实例
|
||||
failed_plugins = []
|
||||
@@ -184,13 +159,8 @@ class PluginManager:
|
||||
def emit_event(self, e_context: EventContext, *args, **kwargs):
|
||||
if e_context.event in self.listening_plugins:
|
||||
for name in self.listening_plugins[e_context.event]:
|
||||
if (
|
||||
self.plugins[name].enabled
|
||||
and e_context.action == EventAction.CONTINUE
|
||||
):
|
||||
logger.debug(
|
||||
"Plugin %s triggered by event %s" % (name, e_context.event)
|
||||
)
|
||||
if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE:
|
||||
logger.debug("Plugin %s triggered by event %s" % (name, e_context.event))
|
||||
instance = self.instances[name]
|
||||
instance.handlers[e_context.event](e_context, *args, **kwargs)
|
||||
return e_context
|
||||
@@ -262,9 +232,7 @@ class PluginManager:
|
||||
source = json.load(f)
|
||||
if repo in source["repo"]:
|
||||
repo = source["repo"][repo]["url"]
|
||||
match = re.match(
|
||||
r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo
|
||||
)
|
||||
match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo)
|
||||
if not match:
|
||||
return False, "安装插件失败,source中的仓库地址不合法"
|
||||
else:
|
||||
|
||||
@@ -69,13 +69,9 @@ class Role(Plugin):
|
||||
logger.info("[Role] inited")
|
||||
except Exception as e:
|
||||
if isinstance(e, FileNotFoundError):
|
||||
logger.warn(
|
||||
f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ."
|
||||
)
|
||||
logger.warn(f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .")
|
||||
else:
|
||||
logger.warn(
|
||||
"[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ."
|
||||
)
|
||||
logger.warn("[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .")
|
||||
raise e
|
||||
|
||||
def get_role(self, name, find_closest=True, min_sim=0.35):
|
||||
@@ -143,9 +139,7 @@ class Role(Plugin):
|
||||
else:
|
||||
help_text = f"未知角色类型。\n"
|
||||
help_text += "目前的角色类型有: \n"
|
||||
help_text += (
|
||||
",".join([self.tags[tag][0] for tag in self.tags]) + "\n"
|
||||
)
|
||||
help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "\n"
|
||||
else:
|
||||
help_text = f"请输入角色类型。\n"
|
||||
help_text += "目前的角色类型有: \n"
|
||||
@@ -158,9 +152,7 @@ class Role(Plugin):
|
||||
return
|
||||
logger.debug("[Role] on_handle_context. content: %s" % content)
|
||||
if desckey is not None:
|
||||
if len(clist) == 1 or (
|
||||
len(clist) > 1 and clist[1].lower() in ["help", "帮助"]
|
||||
):
|
||||
if len(clist) == 1 or (len(clist) > 1 and clist[1].lower() in ["help", "帮助"]):
|
||||
reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True))
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
@@ -178,9 +170,7 @@ class Role(Plugin):
|
||||
self.roles[role][desckey],
|
||||
self.roles[role].get("wrapper", "%s"),
|
||||
)
|
||||
reply = Reply(
|
||||
ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey]
|
||||
)
|
||||
reply = Reply(ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey])
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
elif customize == True:
|
||||
@@ -199,17 +189,10 @@ class Role(Plugin):
|
||||
if not verbose:
|
||||
return help_text
|
||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||
help_text = (
|
||||
f"使用方法:\n{trigger_prefix}角色"
|
||||
+ " 预设角色名: 设定角色为{预设角色名}。\n"
|
||||
+ f"{trigger_prefix}role"
|
||||
+ " 预设角色名: 同上,但使用英文设定。\n"
|
||||
)
|
||||
help_text = f"使用方法:\n{trigger_prefix}角色" + " 预设角色名: 设定角色为{预设角色名}。\n" + f"{trigger_prefix}role" + " 预设角色名: 同上,但使用英文设定。\n"
|
||||
help_text += f"{trigger_prefix}设定扮演" + " 角色设定: 设定自定义角色人设为{角色设定}。\n"
|
||||
help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n"
|
||||
help_text += (
|
||||
f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n"
|
||||
)
|
||||
help_text += f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n"
|
||||
help_text += "\n目前的角色类型有: \n"
|
||||
help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "。\n"
|
||||
help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n"
|
||||
|
||||
@@ -9,9 +9,21 @@
|
||||
### 1. python
|
||||
###### python解释器,使用它来解释执行python指令,可以配合你想要chatgpt生成的代码输出结果或执行事务
|
||||
|
||||
### 2. url-get
|
||||
### 2. 访问网页的工具汇总(默认url-get)
|
||||
|
||||
#### 2.1 url-get
|
||||
###### 往往用来获取某个网站具体内容,结果可能会被反爬策略影响
|
||||
|
||||
#### 2.2 browser
|
||||
###### 浏览器,功能与2.1类似,但能更好模拟,不会被识别为爬虫影响获取网站内容
|
||||
|
||||
> 注1:url-get默认配置、browser需额外配置,browser依赖google-chrome,你需要提前安装好
|
||||
|
||||
> 注2:browser默认使用summary tool 分段总结长文本信息,tokens可能会大量消耗!
|
||||
|
||||
这是debian端安装google-chrome教程,其他系统请执行查找
|
||||
> https://www.linuxjournal.com/content/how-can-you-install-google-browser-debian
|
||||
|
||||
### 3. terminal
|
||||
###### 在你运行的电脑里执行shell命令,可以配合你想要chatgpt生成的代码使用,给予自然语言控制手段
|
||||
|
||||
@@ -38,47 +50,83 @@
|
||||
### 5. wikipedia
|
||||
###### 可以回答你想要知道确切的人事物
|
||||
|
||||
### 6. news *
|
||||
### 6. 新闻类工具
|
||||
|
||||
#### 6.1. news-api *
|
||||
###### 从全球 80,000 多个信息源中获取当前和历史新闻文章
|
||||
|
||||
### 7. morning-news *
|
||||
#### 6.2. morning-news *
|
||||
###### 每日60秒早报,每天凌晨一点更新,本工具使用了[alapi-每日60秒早报](https://alapi.cn/api/view/93)
|
||||
|
||||
> 该tool每天返回内容相同
|
||||
|
||||
### 8. bing-search *
|
||||
#### 6.3. finance-news
|
||||
###### 获取实时的金融财政新闻
|
||||
|
||||
> 该工具需要解决browser tool 的google-chrome依赖安装
|
||||
|
||||
### 7. bing-search *
|
||||
###### bing搜索引擎,从此你不用再烦恼搜索要用哪些关键词
|
||||
|
||||
### 9. wolfram-alpha *
|
||||
### 8. wolfram-alpha *
|
||||
###### 知识搜索引擎、科学问答系统,常用于专业学科计算
|
||||
|
||||
### 10. google-search *
|
||||
### 9. google-search *
|
||||
###### google搜索引擎,申请流程较bing-search繁琐
|
||||
|
||||
###### 注1:带*工具需要获取api-key才能使用,部分工具需要外网支持
|
||||
|
||||
### 10. arxiv(dev 开发中)
|
||||
###### 用于查找论文
|
||||
|
||||
|
||||
### 11. debug(dev 开发中,目前没有接入wechat)
|
||||
###### 当bot遇到无法确定的信息时,将会向你寻求帮助的工具
|
||||
|
||||
|
||||
### 12. summary
|
||||
###### 总结工具,该工具必须输入一个本地文件的绝对路径
|
||||
|
||||
> 该工具目前是和其他工具配合使用,暂未测试单独使用效果
|
||||
|
||||
|
||||
### 13. image2text
|
||||
###### 将图片转换成文字,底层调用imageCaption模型,该工具必须输入一个本地文件的绝对路径
|
||||
|
||||
|
||||
### 14. searxng-search *
|
||||
###### 一个私有化的搜索引擎工具
|
||||
|
||||
> 安装教程:https://docs.searxng.org/admin/installation.html
|
||||
|
||||
---
|
||||
|
||||
###### 注1:带*工具需要获取api-key才能使用(在config.json内的kwargs添加项),部分工具需要外网支持
|
||||
#### [申请方法](https://github.com/goldfishh/chatgpt-tool-hub/blob/master/docs/apply_optional_tool.md)
|
||||
|
||||
## config.json 配置说明
|
||||
###### 默认工具无需配置,其它工具需手动配置,一个例子:
|
||||
```json
|
||||
{
|
||||
"tools": ["wikipedia"], // 填入你想用到的额外工具名
|
||||
"tools": ["wikipedia", "你想要添加的其他工具"], // 填入你想用到的额外工具名
|
||||
"kwargs": {
|
||||
"request_timeout": 60, // openai接口超时时间
|
||||
"debug": true, // 当你遇到问题求助时,需要配置
|
||||
"request_timeout": 120, // openai接口超时时间
|
||||
"no_default": false, // 是否不使用默认的4个工具
|
||||
"OPTIONAL_API_NAME": "OPTIONAL_API_KEY" // 带*工具需要申请api-key,在这里填入,api_name参考前述`申请方法`
|
||||
// 带*工具需要申请api-key,在这里填入,api_name参考前述`申请方法`
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
注:config.json文件非必须,未创建仍可使用本tool;带*工具需在kwargs填入对应api-key键值对
|
||||
- `tools`:本插件初始化时加载的工具, 目前可选集:["wikipedia", "wolfram-alpha", "bing-search", "google-search", "news", "morning-news"] & 默认工具,除wikipedia工具之外均需要申请api-key
|
||||
- `tools`:本插件初始化时加载的工具, 目前可选集:["wikipedia", "wolfram-alpha", "bing-search", "google-search", "news"] & 默认工具,除wikipedia工具之外均需要申请api-key
|
||||
- `kwargs`:工具执行时的配置,一般在这里存放**api-key**,或环境配置
|
||||
- `debug`: 输出chatgpt-tool-hub额外信息用于调试
|
||||
- `request_timeout`: 访问openai接口的超时时间,默认与wechat-on-chatgpt配置一致,可单独配置
|
||||
- `no_default`: 用于配置默认加载4个工具的行为,如果为true则仅使用tools列表工具,不加载默认工具
|
||||
- `top_k_results`: 控制所有有关搜索的工具返回条目数,数字越高则参考信息越多,但无用信息可能干扰判断,该值一般为2
|
||||
- `model_name`: 用于控制tool插件底层使用的llm模型,目前暂未测试3.5以外的模型,一般保持默认
|
||||
|
||||
---
|
||||
|
||||
## 备注
|
||||
- 强烈建议申请搜索工具搭配使用,推荐bing-search
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from chatgpt_tool_hub.apps import load_app
|
||||
from chatgpt_tool_hub.apps import AppFactory
|
||||
from chatgpt_tool_hub.apps.app import App
|
||||
from chatgpt_tool_hub.tools.all_tool_list import get_all_tool_names
|
||||
|
||||
@@ -18,7 +18,7 @@ from plugins import *
|
||||
@plugins.register(
|
||||
name="tool",
|
||||
desc="Arming your ChatGPT bot with various tools",
|
||||
version="0.3",
|
||||
version="0.4",
|
||||
author="goldfishh",
|
||||
desire_priority=0,
|
||||
)
|
||||
@@ -82,9 +82,7 @@ class Tool(Plugin):
|
||||
return
|
||||
elif content_list[1].startswith("reset"):
|
||||
logger.debug("[tool]: remind")
|
||||
e_context[
|
||||
"context"
|
||||
].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符"
|
||||
e_context["context"].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符"
|
||||
|
||||
e_context.action = EventAction.BREAK
|
||||
return
|
||||
@@ -93,18 +91,14 @@ class Tool(Plugin):
|
||||
|
||||
# Don't modify bot name
|
||||
all_sessions = Bridge().get_bot("chat").sessions
|
||||
user_session = all_sessions.session_query(
|
||||
query, e_context["context"]["session_id"]
|
||||
).messages
|
||||
user_session = all_sessions.session_query(query, e_context["context"]["session_id"]).messages
|
||||
|
||||
# chatgpt-tool-hub will reply you with many tools
|
||||
logger.debug("[tool]: just-go")
|
||||
try:
|
||||
_reply = self.app.ask(query, user_session)
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
all_sessions.session_reply(
|
||||
_reply, e_context["context"]["session_id"]
|
||||
)
|
||||
all_sessions.session_reply(_reply, e_context["context"]["session_id"])
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
logger.error(str(e))
|
||||
@@ -131,17 +125,17 @@ class Tool(Plugin):
|
||||
|
||||
def _build_tool_kwargs(self, kwargs: dict):
|
||||
tool_model_name = kwargs.get("model_name")
|
||||
request_timeout = kwargs.get("request_timeout")
|
||||
|
||||
return {
|
||||
"debug": kwargs.get("debug", False),
|
||||
"openai_api_key": conf().get("open_ai_api_key", ""),
|
||||
"proxy": conf().get("proxy", ""),
|
||||
"request_timeout": str(conf().get("request_timeout", 60)),
|
||||
"request_timeout": request_timeout if request_timeout else conf().get("request_timeout", 120),
|
||||
# note: 目前tool暂未对其他模型测试,但这里仍对配置来源做了优先级区分,一般插件配置可覆盖全局配置
|
||||
"model_name": tool_model_name
|
||||
if tool_model_name
|
||||
else conf().get("model", "gpt-3.5-turbo"),
|
||||
"model_name": tool_model_name if tool_model_name else conf().get("model", "gpt-3.5-turbo"),
|
||||
"no_default": kwargs.get("no_default", False),
|
||||
"top_k_results": kwargs.get("top_k_results", 2),
|
||||
"top_k_results": kwargs.get("top_k_results", 3),
|
||||
# for news tool
|
||||
"news_api_key": kwargs.get("news_api_key", ""),
|
||||
# for bing-search tool
|
||||
@@ -157,8 +151,6 @@ class Tool(Plugin):
|
||||
"zaobao_api_key": kwargs.get("zaobao_api_key", ""),
|
||||
# for visual_dl tool
|
||||
"cuda_device": kwargs.get("cuda_device", "cpu"),
|
||||
# for browser tool
|
||||
"phantomjs_exec_path": kwargs.get("phantomjs_exec_path", ""),
|
||||
}
|
||||
|
||||
def _filter_tool_list(self, tool_list: list):
|
||||
@@ -172,11 +164,12 @@ class Tool(Plugin):
|
||||
|
||||
def _reset_app(self) -> App:
|
||||
tool_config = self._read_json()
|
||||
app_kwargs = self._build_tool_kwargs(tool_config.get("kwargs", {}))
|
||||
|
||||
app = AppFactory()
|
||||
app.init_env(**app_kwargs)
|
||||
|
||||
# filter not support tool
|
||||
tool_list = self._filter_tool_list(tool_config.get("tools", []))
|
||||
|
||||
return load_app(
|
||||
tools_list=tool_list,
|
||||
**self._build_tool_kwargs(tool_config.get("kwargs", {})),
|
||||
)
|
||||
return app.create_app(tools_list=tool_list, **app_kwargs)
|
||||
|
||||
8
pyproject.toml
Normal file
8
pyproject.toml
Normal file
@@ -0,0 +1,8 @@
|
||||
[tool.black]
|
||||
line-length = 176
|
||||
target-version = ['py37']
|
||||
include = '\.pyi?$'
|
||||
extend-exclude = '.+/(dist|.venv|venv|build|lib)/.+'
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
@@ -18,7 +18,9 @@ pysilk_mod>=1.6.0 # needed by send voice
|
||||
|
||||
# wechatmp
|
||||
web.py
|
||||
wechatpy
|
||||
|
||||
# chatgpt-tool-hub plugin
|
||||
|
||||
--extra-index-url https://pypi.python.org/simple
|
||||
chatgpt_tool_hub>=0.3.9
|
||||
chatgpt_tool_hub>=0.4.1
|
||||
@@ -34,6 +34,20 @@ def get_pcm_from_wav(wav_path):
|
||||
return wav.readframes(wav.getnframes())
|
||||
|
||||
|
||||
def any_to_mp3(any_path, mp3_path):
|
||||
"""
|
||||
把任意格式转成mp3文件
|
||||
"""
|
||||
if any_path.endswith(".mp3"):
|
||||
shutil.copy2(any_path, mp3_path)
|
||||
return
|
||||
if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
|
||||
sil_to_wav(any_path, any_path)
|
||||
any_path = mp3_path
|
||||
audio = AudioSegment.from_file(any_path)
|
||||
audio.export(mp3_path, format="mp3")
|
||||
|
||||
|
||||
def any_to_wav(any_path, wav_path):
|
||||
"""
|
||||
把任意格式转成wav文件
|
||||
@@ -41,11 +55,7 @@ def any_to_wav(any_path, wav_path):
|
||||
if any_path.endswith(".wav"):
|
||||
shutil.copy2(any_path, wav_path)
|
||||
return
|
||||
if (
|
||||
any_path.endswith(".sil")
|
||||
or any_path.endswith(".silk")
|
||||
or any_path.endswith(".slk")
|
||||
):
|
||||
if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
|
||||
return sil_to_wav(any_path, wav_path)
|
||||
audio = AudioSegment.from_file(any_path)
|
||||
audio.export(wav_path, format="wav")
|
||||
@@ -55,59 +65,17 @@ def any_to_sil(any_path, sil_path):
|
||||
"""
|
||||
把任意格式转成sil文件
|
||||
"""
|
||||
if (
|
||||
any_path.endswith(".sil")
|
||||
or any_path.endswith(".silk")
|
||||
or any_path.endswith(".slk")
|
||||
):
|
||||
if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
|
||||
shutil.copy2(any_path, sil_path)
|
||||
return 10000
|
||||
if any_path.endswith(".wav"):
|
||||
return pcm_to_sil(any_path, sil_path)
|
||||
if any_path.endswith(".mp3"):
|
||||
return mp3_to_sil(any_path, sil_path)
|
||||
raise NotImplementedError("Not support file type: {}".format(any_path))
|
||||
|
||||
|
||||
def mp3_to_wav(mp3_path, wav_path):
|
||||
"""
|
||||
把mp3格式转成pcm文件
|
||||
"""
|
||||
audio = AudioSegment.from_mp3(mp3_path)
|
||||
audio.export(wav_path, format="wav")
|
||||
|
||||
|
||||
def pcm_to_sil(pcm_path, silk_path):
|
||||
"""
|
||||
wav 文件转成 silk
|
||||
return 声音长度,毫秒
|
||||
"""
|
||||
audio = AudioSegment.from_wav(pcm_path)
|
||||
audio = AudioSegment.from_file(any_path)
|
||||
rate = find_closest_sil_supports(audio.frame_rate)
|
||||
# Convert to PCM_s16
|
||||
pcm_s16 = audio.set_sample_width(2)
|
||||
pcm_s16 = pcm_s16.set_frame_rate(rate)
|
||||
wav_data = pcm_s16.raw_data
|
||||
silk_data = pysilk.encode(wav_data, data_rate=rate, sample_rate=rate)
|
||||
with open(silk_path, "wb") as f:
|
||||
f.write(silk_data)
|
||||
return audio.duration_seconds * 1000
|
||||
|
||||
|
||||
def mp3_to_sil(mp3_path, silk_path):
|
||||
"""
|
||||
mp3 文件转成 silk
|
||||
return 声音长度,毫秒
|
||||
"""
|
||||
audio = AudioSegment.from_mp3(mp3_path)
|
||||
rate = find_closest_sil_supports(audio.frame_rate)
|
||||
# Convert to PCM_s16
|
||||
pcm_s16 = audio.set_sample_width(2)
|
||||
pcm_s16 = pcm_s16.set_frame_rate(rate)
|
||||
wav_data = pcm_s16.raw_data
|
||||
silk_data = pysilk.encode(wav_data, data_rate=rate, sample_rate=rate)
|
||||
# Save the silk file
|
||||
with open(silk_path, "wb") as f:
|
||||
with open(sil_path, "wb") as f:
|
||||
f.write(silk_data)
|
||||
return audio.duration_seconds * 1000
|
||||
|
||||
|
||||
@@ -40,49 +40,33 @@ class AzureVoice(Voice):
|
||||
config = json.load(fr)
|
||||
self.api_key = conf().get("azure_voice_api_key")
|
||||
self.api_region = conf().get("azure_voice_region")
|
||||
self.speech_config = speechsdk.SpeechConfig(
|
||||
subscription=self.api_key, region=self.api_region
|
||||
)
|
||||
self.speech_config.speech_synthesis_voice_name = config[
|
||||
"speech_synthesis_voice_name"
|
||||
]
|
||||
self.speech_config.speech_recognition_language = config[
|
||||
"speech_recognition_language"
|
||||
]
|
||||
self.speech_config = speechsdk.SpeechConfig(subscription=self.api_key, region=self.api_region)
|
||||
self.speech_config.speech_synthesis_voice_name = config["speech_synthesis_voice_name"]
|
||||
self.speech_config.speech_recognition_language = config["speech_recognition_language"]
|
||||
except Exception as e:
|
||||
logger.warn("AzureVoice init failed: %s, ignore " % e)
|
||||
|
||||
def voiceToText(self, voice_file):
|
||||
audio_config = speechsdk.AudioConfig(filename=voice_file)
|
||||
speech_recognizer = speechsdk.SpeechRecognizer(
|
||||
speech_config=self.speech_config, audio_config=audio_config
|
||||
)
|
||||
speech_recognizer = speechsdk.SpeechRecognizer(speech_config=self.speech_config, audio_config=audio_config)
|
||||
result = speech_recognizer.recognize_once()
|
||||
if result.reason == speechsdk.ResultReason.RecognizedSpeech:
|
||||
logger.info(
|
||||
"[Azure] voiceToText voice file name={} text={}".format(
|
||||
voice_file, result.text
|
||||
)
|
||||
)
|
||||
logger.info("[Azure] voiceToText voice file name={} text={}".format(voice_file, result.text))
|
||||
reply = Reply(ReplyType.TEXT, result.text)
|
||||
else:
|
||||
logger.error("[Azure] voiceToText error, result={}".format(result))
|
||||
logger.error("[Azure] voiceToText error, result={}, canceldetails={}".format(result, result.cancellation_details))
|
||||
reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败")
|
||||
return reply
|
||||
|
||||
def textToVoice(self, text):
|
||||
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav"
|
||||
audio_config = speechsdk.AudioConfig(filename=fileName)
|
||||
speech_synthesizer = speechsdk.SpeechSynthesizer(
|
||||
speech_config=self.speech_config, audio_config=audio_config
|
||||
)
|
||||
speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=self.speech_config, audio_config=audio_config)
|
||||
result = speech_synthesizer.speak_text(text)
|
||||
if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
|
||||
logger.info(
|
||||
"[Azure] textToVoice text={} voice file name={}".format(text, fileName)
|
||||
)
|
||||
logger.info("[Azure] textToVoice text={} voice file name={}".format(text, fileName))
|
||||
reply = Reply(ReplyType.VOICE, fileName)
|
||||
else:
|
||||
logger.error("[Azure] textToVoice error, result={}".format(result))
|
||||
logger.error("[Azure] textToVoice error, result={}, canceldetails={}".format(result, result.cancellation_details))
|
||||
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
|
||||
return reply
|
||||
|
||||
@@ -85,9 +85,7 @@ class BaiduVoice(Voice):
|
||||
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3"
|
||||
with open(fileName, "wb") as f:
|
||||
f.write(result)
|
||||
logger.info(
|
||||
"[Baidu] textToVoice text={} voice file name={}".format(text, fileName)
|
||||
)
|
||||
logger.info("[Baidu] textToVoice text={} voice file name={}".format(text, fileName))
|
||||
reply = Reply(ReplyType.VOICE, fileName)
|
||||
else:
|
||||
logger.error("[Baidu] textToVoice error={}".format(result))
|
||||
|
||||
@@ -24,11 +24,7 @@ class GoogleVoice(Voice):
|
||||
audio = self.recognizer.record(source)
|
||||
try:
|
||||
text = self.recognizer.recognize_google(audio, language="zh-CN")
|
||||
logger.info(
|
||||
"[Google] voiceToText text={} voice file name={}".format(
|
||||
text, voice_file
|
||||
)
|
||||
)
|
||||
logger.info("[Google] voiceToText text={} voice file name={}".format(text, voice_file))
|
||||
reply = Reply(ReplyType.TEXT, text)
|
||||
except speech_recognition.UnknownValueError:
|
||||
reply = Reply(ReplyType.ERROR, "抱歉,我听不懂")
|
||||
@@ -42,9 +38,7 @@ class GoogleVoice(Voice):
|
||||
mp3File = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3"
|
||||
tts = gTTS(text=text, lang="zh")
|
||||
tts.save(mp3File)
|
||||
logger.info(
|
||||
"[Google] textToVoice text={} voice file name={}".format(text, mp3File)
|
||||
)
|
||||
logger.info("[Google] textToVoice text={} voice file name={}".format(text, mp3File))
|
||||
reply = Reply(ReplyType.VOICE, mp3File)
|
||||
except Exception as e:
|
||||
reply = Reply(ReplyType.ERROR, str(e))
|
||||
|
||||
@@ -22,11 +22,7 @@ class OpenaiVoice(Voice):
|
||||
result = openai.Audio.transcribe("whisper-1", file)
|
||||
text = result["text"]
|
||||
reply = Reply(ReplyType.TEXT, text)
|
||||
logger.info(
|
||||
"[Openai] voiceToText text={} voice file name={}".format(
|
||||
text, voice_file
|
||||
)
|
||||
)
|
||||
logger.info("[Openai] voiceToText text={} voice file name={}".format(text, voice_file))
|
||||
except Exception as e:
|
||||
reply = Reply(ReplyType.ERROR, str(e))
|
||||
finally:
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
pytts voice service (offline)
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import pyttsx3
|
||||
@@ -20,19 +22,42 @@ class PyttsVoice(Voice):
|
||||
self.engine.setProperty("rate", 125)
|
||||
# 音量
|
||||
self.engine.setProperty("volume", 1.0)
|
||||
for voice in self.engine.getProperty("voices"):
|
||||
if "Chinese" in voice.name:
|
||||
self.engine.setProperty("voice", voice.id)
|
||||
if sys.platform == "win32":
|
||||
for voice in self.engine.getProperty("voices"):
|
||||
if "Chinese" in voice.name:
|
||||
self.engine.setProperty("voice", voice.id)
|
||||
else:
|
||||
self.engine.setProperty("voice", "zh")
|
||||
# If the problem of espeak is fixed, using runAndWait() and remove this startLoop()
|
||||
# TODO: check if this is work on win32
|
||||
self.engine.startLoop(useDriverLoop=False)
|
||||
|
||||
def textToVoice(self, text):
|
||||
try:
|
||||
wavFile = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav"
|
||||
# avoid the same filename
|
||||
wavFileName = "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".wav"
|
||||
wavFile = TmpDir().path() + wavFileName
|
||||
logger.info("[Pytts] textToVoice text={} voice file name={}".format(text, wavFile))
|
||||
|
||||
self.engine.save_to_file(text, wavFile)
|
||||
self.engine.runAndWait()
|
||||
logger.info(
|
||||
"[Pytts] textToVoice text={} voice file name={}".format(text, wavFile)
|
||||
)
|
||||
|
||||
if sys.platform == "win32":
|
||||
self.engine.runAndWait()
|
||||
else:
|
||||
# In ubuntu, runAndWait do not really wait until the file created.
|
||||
# It will return once the task queue is empty, but the task is still running in coroutine.
|
||||
# And if you call runAndWait() and time.sleep() twice, it will stuck, so do not use this.
|
||||
# If you want to fix this, add self._proxy.setBusy(True) in line 127 in espeak.py, at the beginning of the function save_to_file.
|
||||
# self.engine.runAndWait()
|
||||
|
||||
# Before espeak fix this problem, we iterate the generator and control the waiting by ourself.
|
||||
# But this is not the canonical way to use it, for example if the file already exists it also cannot wait.
|
||||
self.engine.iterate()
|
||||
while self.engine.isBusy() or wavFileName not in os.listdir(TmpDir().path()):
|
||||
time.sleep(0.1)
|
||||
|
||||
reply = Reply(ReplyType.VOICE, wavFile)
|
||||
|
||||
except Exception as e:
|
||||
reply = Reply(ReplyType.ERROR, str(e))
|
||||
finally:
|
||||
|
||||
Reference in New Issue
Block a user