Refactor: inherit ChatChannel

This commit is contained in:
JS00000
2023-04-05 20:55:24 +08:00
parent 44f6892cb7
commit 1a981ea970
3 changed files with 134 additions and 220 deletions

View File

@@ -18,6 +18,6 @@ def create_channel(channel_type):
from channel.terminal.terminal_channel import TerminalChannel from channel.terminal.terminal_channel import TerminalChannel
return TerminalChannel() return TerminalChannel()
elif channel_type == 'wechatmp': elif channel_type == 'wechatmp':
from channel.wechatmp.wechatmp_channel import WechatMPServer from channel.wechatmp.wechatmp_channel import WechatMPChannel
return WechatMPServer() return WechatMPChannel()
raise RuntimeError raise RuntimeError

View File

@@ -1,47 +1,43 @@
# -*- coding: utf-8 -*-# # -*- coding: utf-8 -*-#
# filename: receive.py # filename: receive.py
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from bridge.context import ContextType
from channel.chat_message import ChatMessage
from common.tmp_dir import TmpDir
from common.log import logger
def parse_xml(web_data): def parse_xml(web_data):
if len(web_data) == 0: if len(web_data) == 0:
return None return None
xmlData = ET.fromstring(web_data) xmlData = ET.fromstring(web_data)
msg_type = xmlData.find('MsgType').text return WeChatMPMessage(xmlData)
if msg_type == 'text':
return TextMsg(xmlData)
elif msg_type == 'image':
return ImageMsg(xmlData)
elif msg_type == 'event':
return Event(xmlData)
class WeChatMPMessage(ChatMessage):
class Msg(object):
def __init__(self, xmlData): def __init__(self, xmlData):
self.ToUserName = xmlData.find('ToUserName').text super().__init__(xmlData)
self.FromUserName = xmlData.find('FromUserName').text self.to_user_id = xmlData.find('ToUserName').text
self.CreateTime = xmlData.find('CreateTime').text self.from_user_id = xmlData.find('FromUserName').text
self.MsgType = xmlData.find('MsgType').text self.create_time = xmlData.find('CreateTime').text
self.MsgId = xmlData.find('MsgId').text self.msg_type = xmlData.find('MsgType').text
self.msg_id = xmlData.find('MsgId').text
self.is_group = False
# reply to other_user_id
self.other_user_id = self.from_user_id
if self.msg_type == 'text':
class TextMsg(Msg): self.ctype = ContextType.TEXT
def __init__(self, xmlData): self.content = xmlData.find('Content').text.encode("utf-8")
Msg.__init__(self, xmlData) elif self.msg_type == 'voice':
self.Content = xmlData.find('Content').text.encode("utf-8") self.ctype = ContextType.TEXT
self.content = xmlData.find('Recognition').text.encode("utf-8") # 接收语音识别结果
elif self.msg_type == 'image':
class ImageMsg(Msg): # not implemented
def __init__(self, xmlData): self.pic_url = xmlData.find('PicUrl').text
Msg.__init__(self, xmlData) self.media_id = xmlData.find('MediaId').text
self.PicUrl = xmlData.find('PicUrl').text elif self.msg_type == 'event':
self.MediaId = xmlData.find('MediaId').text self.event = xmlData.find('Event').text
else: # video, shortvideo, location, link
# not implemented
class Event(object): pass
def __init__(self, xmlData):
self.ToUserName = xmlData.find('ToUserName').text
self.FromUserName = xmlData.find('FromUserName').text
self.CreateTime = xmlData.find('CreateTime').text
self.MsgType = xmlData.find('MsgType').text
self.Event = xmlData.find('Event').text

View File

@@ -4,9 +4,10 @@ import time
import math import math
import hashlib import hashlib
import textwrap import textwrap
from channel.channel import Channel from channel.chat_channel import ChatChannel
import channel.wechatmp.reply as reply import channel.wechatmp.reply as reply
import channel.wechatmp.receive as receive import channel.wechatmp.receive as receive
from common.singleton import singleton
from common.log import logger from common.log import logger
from config import conf from config import conf
from bridge.reply import * from bridge.reply import *
@@ -21,202 +22,125 @@ import traceback
# certificate='/ssl/cert.pem', # certificate='/ssl/cert.pem',
# private_key='/ssl/cert.key') # private_key='/ssl/cert.key')
class WechatMPServer():
def __init__(self):
pass
def startup(self): # from concurrent.futures import ThreadPoolExecutor
# thread_pool = ThreadPoolExecutor(max_workers=8)
@singleton
class WechatMPChannel(ChatChannel):
def __init__(self):
super().__init__()
self.cache_dict = dict()
self.query1 = dict()
self.query2 = dict()
self.query3 = dict()
def startup(self):
urls = ( urls = (
'/wx', 'WechatMPChannel', '/wx', 'SubsribeAccountQuery',
) )
app = web.application(urls, globals()) app = web.application(urls, globals())
app.run() app.run()
cache_dict = dict()
query1 = dict()
query2 = dict()
query3 = dict()
from concurrent.futures import ThreadPoolExecutor def send(self, reply: Reply, context: Context):
thread_pool = ThreadPoolExecutor(max_workers=8) reply_cnt = math.ceil(len(reply.content) / 600)
receiver = context["receiver"]
self.cache_dict[receiver] = (reply_cnt, reply.content)
logger.debug("[send] reply to {} saved to cache: {}".format(receiver, reply))
class WechatMPChannel(Channel):
def verify_server():
try:
data = web.input()
if len(data) == 0:
return "None"
signature = data.signature
timestamp = data.timestamp
nonce = data.nonce
echostr = data.echostr
token = conf().get('wechatmp_token') #请按照公众平台官网\基本配置中信息填写
data_list = [token, timestamp, nonce]
data_list.sort()
sha1 = hashlib.sha1()
# map(sha1.update, data_list) #python2
sha1.update("".join(data_list).encode('utf-8'))
hashcode = sha1.hexdigest()
print("handle/GET func: hashcode, signature: ", hashcode, signature)
if hashcode == signature:
return echostr
else:
return ""
except Exception as Argument:
return Argument
# This class is instantiated once per query
class SubsribeAccountQuery():
def GET(self): def GET(self):
try: return verify_server()
data = web.input()
if len(data) == 0:
return "hello, this is handle view"
signature = data.signature
timestamp = data.timestamp
nonce = data.nonce
echostr = data.echostr
token = conf().get('wechatmp_token') #请按照公众平台官网\基本配置中信息填写
data_list = [token, timestamp, nonce]
data_list.sort()
sha1 = hashlib.sha1()
# map(sha1.update, data_list) #python2
sha1.update("".join(data_list).encode('utf-8'))
hashcode = sha1.hexdigest()
print("handle/GET func: hashcode, signature: ", hashcode, signature)
if hashcode == signature:
return echostr
else:
return ""
except Exception as Argument:
return Argument
def _do_build_reply(self, cache_key, fromUser, message):
context = dict()
context['session_id'] = fromUser
reply_text = super().build_reply_content(message, context)
# The query is done, record the cache
logger.info("[threaded] Get reply for {}: {} \nA: {}".format(fromUser, message, reply_text))
global cache_dict
reply_cnt = math.ceil(len(reply_text) / 600)
cache_dict[cache_key] = (reply_cnt, reply_text)
def send(self, reply : Reply, cache_key):
global cache_dict
reply_cnt = math.ceil(len(reply.content) / 600)
cache_dict[cache_key] = (reply_cnt, reply.content)
def handle(self, context):
global cache_dict
try:
reply = Reply()
logger.debug('[wechatmp] ready to handle context: {}'.format(context))
# reply的构建步骤
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply}))
reply = e_context['reply']
if not e_context.is_pass():
logger.debug('[wechatmp] ready to handle context: type={}, content={}'.format(context.type, context.content))
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE:
reply = super().build_reply_content(context.content, context)
# elif context.type == ContextType.VOICE:
# msg = context['msg']
# file_name = TmpDir().path() + context.content
# msg.download(file_name)
# reply = super().build_voice_to_text(file_name)
# if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO:
# context.content = reply.content # 语音转文字后将文字内容作为新的context
# context.type = ContextType.TEXT
# reply = super().build_reply_content(context.content, context)
# if reply.type == ReplyType.TEXT:
# if conf().get('voice_reply_voice'):
# reply = super().build_text_to_voice(reply.content)
else:
logger.error('[wechatmp] unknown context type: {}'.format(context.type))
return
logger.debug('[wechatmp] ready to decorate reply: {}'.format(reply))
# reply的包装步骤
if reply and reply.type:
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
reply=e_context['reply']
if not e_context.is_pass() and reply and reply.type:
if reply.type == ReplyType.TEXT:
pass
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
reply.content = str(reply.type)+":\n" + reply.content
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
pass
else:
logger.error('[wechatmp] unknown reply type: {}'.format(reply.type))
return
# reply的发送步骤
if reply and reply.type:
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
reply=e_context['reply']
if not e_context.is_pass() and reply and reply.type:
logger.debug('[wechatmp] ready to send reply: {} to {}'.format(reply, context['receiver']))
self.send(reply, context['receiver'])
else:
cache_dict[context['receiver']] = (1, "No reply")
logger.info("[threaded] Get reply for {}: {} \nA: {}".format(context['receiver'], context.content, reply.content))
except Exception as exc:
print(traceback.format_exc())
cache_dict[context['receiver']] = (1, "ERROR")
def POST(self): def POST(self):
channel_instance = WechatMPChannel()
try: try:
queryTime = time.time() query_time = time.time()
webData = web.data() webData = web.data()
# logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8")) # logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
recMsg = receive.parse_xml(webData) wechat_msg = receive.parse_xml(webData)
if isinstance(recMsg, receive.Msg) and recMsg.MsgType == 'text': if wechat_msg.msg_type == 'text':
fromUser = recMsg.FromUserName from_user = wechat_msg.from_user_id
toUser = recMsg.ToUserName to_user = wechat_msg.to_user_id
createTime = recMsg.CreateTime message = wechat_msg.content.decode("utf-8")
message = recMsg.Content.decode("utf-8") message_id = wechat_msg.msg_id
message_id = recMsg.MsgId
logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), fromUser, message_id, message)) logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message))
global cache_dict cache_key = from_user
global query1 cache = channel_instance.cache_dict.get(cache_key)
global query2
global query3
cache_key = fromUser
cache = cache_dict.get(cache_key)
reply_text = "" reply_text = ""
# New request # New request
if cache == None: if cache == None:
# The first query begin, reset the cache # The first query begin, reset the cache
cache_dict[cache_key] = (0, "") channel_instance.cache_dict[cache_key] = (0, "")
# thread_pool.submit(self._do_build_reply, cache_key, fromUser, message)
context = Context() context = channel_instance._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechat_msg)
context.kwargs = {'isgroup': False, 'receiver': fromUser, 'session_id': fromUser} if context:
# set private openai_api_key
# if from_user is not changed in itchat, this can be placed at chat_channel
user_data = conf().get_user_data(from_user)
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key
channel_instance.produce(context)
user_data = conf().get_user_data(fromUser)
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key
img_match_prefix = check_prefix(message, conf().get('image_create_prefix')) channel_instance.query1[cache_key] = False
if img_match_prefix: channel_instance.query2[cache_key] = False
message = message.replace(img_match_prefix, '', 1).strip() channel_instance.query3[cache_key] = False
context.type = ContextType.IMAGE_CREATE
else:
context.type = ContextType.TEXT
context.content = message
thread_pool.submit(self.handle, context)
query1[cache_key] = False
query2[cache_key] = False
query3[cache_key] = False
# Request again # Request again
elif cache[0] == 0 and query1.get(cache_key) == True and query2.get(cache_key) == True and query3.get(cache_key) == True: elif cache[0] == 0 and channel_instance.query1.get(cache_key) == True and channel_instance.query2.get(cache_key) == True and channel_instance.query3.get(cache_key) == True:
query1[cache_key] = False #To improve waiting experience, this can be set to True. channel_instance.query1[cache_key] = False #To improve waiting experience, this can be set to True.
query2[cache_key] = False #To improve waiting experience, this can be set to True. channel_instance.query2[cache_key] = False #To improve waiting experience, this can be set to True.
query3[cache_key] = False channel_instance.query3[cache_key] = False
elif cache[0] >= 1: elif cache[0] >= 1:
# Skip the waiting phase # Skip the waiting phase
query1[cache_key] = True channel_instance.query1[cache_key] = True
query2[cache_key] = True channel_instance.query2[cache_key] = True
query3[cache_key] = True channel_instance.query3[cache_key] = True
cache = cache_dict.get(cache_key) cache = channel_instance.cache_dict.get(cache_key)
if query1.get(cache_key) == False: if channel_instance.query1.get(cache_key) == False:
# The first query from wechat official server # The first query from wechat official server
logger.debug("[wechatmp] query1 {}".format(cache_key)) logger.debug("[wechatmp] query1 {}".format(cache_key))
query1[cache_key] = True channel_instance.query1[cache_key] = True
cnt = 0 cnt = 0
while cache[0] == 0 and cnt < 45: while cache[0] == 0 and cnt < 45:
cnt = cnt + 1 cnt = cnt + 1
time.sleep(0.1) time.sleep(0.1)
cache = cache_dict.get(cache_key) cache = channel_instance.cache_dict.get(cache_key)
if cnt == 45: if cnt == 45:
# waiting for timeout (the POST query will be closed by wechat official server) # waiting for timeout (the POST query will be closed by wechat official server)
time.sleep(5) time.sleep(5)
@@ -224,15 +148,15 @@ class WechatMPChannel(Channel):
return return
else: else:
pass pass
elif query2.get(cache_key) == False: elif channel_instance.query2.get(cache_key) == False:
# The second query from wechat official server # The second query from wechat official server
logger.debug("[wechatmp] query2 {}".format(cache_key)) logger.debug("[wechatmp] query2 {}".format(cache_key))
query2[cache_key] = True channel_instance.query2[cache_key] = True
cnt = 0 cnt = 0
while cache[0] == 0 and cnt < 45: while cache[0] == 0 and cnt < 45:
cnt = cnt + 1 cnt = cnt + 1
time.sleep(0.1) time.sleep(0.1)
cache = cache_dict.get(cache_key) cache = channel_instance.cache_dict.get(cache_key)
if cnt == 45: if cnt == 45:
# waiting for timeout (the POST query will be closed by wechat official server) # waiting for timeout (the POST query will be closed by wechat official server)
time.sleep(5) time.sleep(5)
@@ -240,42 +164,42 @@ class WechatMPChannel(Channel):
return return
else: else:
pass pass
elif query3.get(cache_key) == False: elif channel_instance.query3.get(cache_key) == False:
# The third query from wechat official server # The third query from wechat official server
logger.debug("[wechatmp] query3 {}".format(cache_key)) logger.debug("[wechatmp] query3 {}".format(cache_key))
query3[cache_key] = True channel_instance.query3[cache_key] = True
cnt = 0 cnt = 0
while cache[0] == 0 and cnt < 45: while cache[0] == 0 and cnt < 45:
cnt = cnt + 1 cnt = cnt + 1
time.sleep(0.1) time.sleep(0.1)
cache = cache_dict.get(cache_key) cache = channel_instance.cache_dict.get(cache_key)
if cnt == 45: if cnt == 45:
# Have waiting for 3x5 seconds # Have waiting for 3x5 seconds
# return timeout message # return timeout message
reply_text = "【正在响应中,回复任意文字尝试获取回复】" reply_text = "【正在响应中,回复任意文字尝试获取回复】"
logger.info("[wechatmp] Three queries has finished For {}: {}".format(fromUser, message_id)) logger.info("[wechatmp] Three queries has finished For {}: {}".format(from_user, message_id))
replyPost = reply.TextMsg(fromUser, toUser, reply_text).send() replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
return replyPost return replyPost
else: else:
pass pass
if float(time.time()) - float(queryTime) > 4.8: if float(time.time()) - float(query_time) > 4.8:
logger.info("[wechatmp] Timeout for {} {}".format(fromUser, message_id)) logger.info("[wechatmp] Timeout for {} {}".format(from_user, message_id))
return return
if cache[0] > 1: if cache[0] > 1:
reply_text = cache[1][:600] + "\n【未完待续,回复任意文字以继续】" #wechatmp auto_reply length limit reply_text = cache[1][:600] + "\n【未完待续,回复任意文字以继续】" #wechatmp auto_reply length limit
cache_dict[cache_key] = (cache[0] - 1, cache[1][600:]) channel_instance.cache_dict[cache_key] = (cache[0] - 1, cache[1][600:])
elif cache[0] == 1: elif cache[0] == 1:
reply_text = cache[1] reply_text = cache[1]
cache_dict.pop(cache_key) channel_instance.cache_dict.pop(cache_key)
logger.info("[wechatmp] {}:{} Do send {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), reply_text)) logger.info("[wechatmp] {}:{} Do send {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), reply_text))
replyPost = reply.TextMsg(fromUser, toUser, reply_text).send() replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
return replyPost return replyPost
elif isinstance(recMsg, receive.Event) and recMsg.MsgType == 'event': elif wechat_msg.msg_type == 'event':
logger.info("[wechatmp] Event {} from {}".format(recMsg.Event, recMsg.FromUserName)) logger.info("[wechatmp] Event {} from {}".format(wechat_msg.Event, wechat_msg.from_user_id))
content = textwrap.dedent("""\ content = textwrap.dedent("""\
感谢您的关注! 感谢您的关注!
这里是ChatGPT可以自由对话。 这里是ChatGPT可以自由对话。
@@ -285,7 +209,7 @@ class WechatMPChannel(Channel):
支持图片输出,画字开头的问题将回复图片链接。 支持图片输出,画字开头的问题将回复图片链接。
支持角色扮演和文字冒险两种定制模式对话。 支持角色扮演和文字冒险两种定制模式对话。
输入'#帮助' 查看详细指令。""") 输入'#帮助' 查看详细指令。""")
replyMsg = reply.TextMsg(recMsg.FromUserName, recMsg.ToUserName, content) replyMsg = reply.TextMsg(wechat_msg.from_user_id, wechat_msg.to_user_id, content)
return replyMsg.send() return replyMsg.send()
else: else:
logger.info("暂且不处理") logger.info("暂且不处理")
@@ -294,9 +218,3 @@ class WechatMPChannel(Channel):
logger.exception(exc) logger.exception(exc)
return exc return exc
def check_prefix(content, prefix_list):
for prefix in prefix_list:
if content.startswith(prefix):
return prefix
return None