mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
feat: web ui channel update
This commit is contained in:
@@ -2,7 +2,7 @@ import sys
|
||||
import time
|
||||
import web
|
||||
import json
|
||||
from queue import Queue
|
||||
from queue import Queue, Empty
|
||||
from bridge.context import *
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
@@ -58,91 +58,29 @@ class WebChannel(ChatChannel):
|
||||
logger.warning(f"Web channel doesn't support {reply.type} yet")
|
||||
return
|
||||
|
||||
if reply.type == ReplyType.IMAGE:
|
||||
from PIL import Image
|
||||
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
img = Image.open(image_storage)
|
||||
print("<IMAGE>")
|
||||
img.show()
|
||||
elif reply.type == ReplyType.IMAGE_URL:
|
||||
import io
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
img_url = reply.content
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
image_storage.seek(0)
|
||||
img = Image.open(image_storage)
|
||||
print(img_url)
|
||||
img.show()
|
||||
else:
|
||||
print(reply.content)
|
||||
|
||||
# 获取用户ID
|
||||
user_id = context.get("receiver", None)
|
||||
if not user_id:
|
||||
logger.error("No receiver found in context, cannot send message")
|
||||
return
|
||||
|
||||
# 确保用户有对应的消息队列
|
||||
if user_id not in self.message_queues:
|
||||
self.message_queues[user_id] = Queue()
|
||||
logger.debug(f"Created message queue for user {user_id}")
|
||||
|
||||
# 将消息放入对应用户的队列
|
||||
message_data = {
|
||||
"type": str(reply.type),
|
||||
"content": reply.content,
|
||||
"timestamp": time.time() # 使用 Unix 时间戳
|
||||
}
|
||||
self.message_queues[user_id].put(message_data)
|
||||
logger.debug(f"Message queued for user {user_id}: {reply.content[:30]}...")
|
||||
# 检查是否有响应队列
|
||||
response_queue = context.get("response_queue", None)
|
||||
if response_queue:
|
||||
# 直接将响应放入队列
|
||||
response_data = {
|
||||
"type": str(reply.type),
|
||||
"content": reply.content,
|
||||
"timestamp": time.time()
|
||||
}
|
||||
response_queue.put(response_data)
|
||||
logger.debug(f"Response sent to queue for user {user_id}")
|
||||
else:
|
||||
logger.warning(f"No response queue found for user {user_id}, response dropped")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in send method: {e}")
|
||||
|
||||
def sse_handler(self, user_id):
|
||||
"""
|
||||
Handle Server-Sent Events (SSE) for real-time communication.
|
||||
"""
|
||||
web.header('Content-Type', 'text/event-stream')
|
||||
web.header('Cache-Control', 'no-cache')
|
||||
web.header('Connection', 'keep-alive')
|
||||
|
||||
logger.debug(f"SSE connection established for user {user_id}")
|
||||
|
||||
# 确保用户有消息队列
|
||||
if user_id not in self.message_queues:
|
||||
self.message_queues[user_id] = Queue()
|
||||
logger.debug(f"Created new message queue for user {user_id}")
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# 发送心跳
|
||||
yield f": heartbeat\n\n"
|
||||
|
||||
# 非阻塞方式获取消息
|
||||
if user_id in self.message_queues and not self.message_queues[user_id].empty():
|
||||
message = self.message_queues[user_id].get_nowait()
|
||||
logger.debug(f"Sending message to user {user_id}: {message}")
|
||||
data = json.dumps(message)
|
||||
yield f"data: {data}\n\n"
|
||||
logger.debug(f"Message sent to user {user_id}")
|
||||
time.sleep(0.5)
|
||||
except Exception as e:
|
||||
logger.error(f"SSE Error for user {user_id}: {str(e)}")
|
||||
break
|
||||
finally:
|
||||
# 清理资源
|
||||
logger.debug(f"SSE connection closed for user {user_id}")
|
||||
|
||||
def post_message(self):
|
||||
"""
|
||||
Handle incoming messages from users via POST request.
|
||||
@@ -167,7 +105,7 @@ class WebChannel(ChatChannel):
|
||||
msg_id=msg_id,
|
||||
content=prompt,
|
||||
from_user_id=user_id,
|
||||
to_user_id="Chatgpt", # 明确指定接收者
|
||||
to_user_id="Chatgpt",
|
||||
other_user_id=user_id
|
||||
)
|
||||
|
||||
@@ -175,13 +113,24 @@ class WebChannel(ChatChannel):
|
||||
if not context:
|
||||
return json.dumps({"status": "error", "message": "Failed to process message"})
|
||||
|
||||
# 创建一个响应队列
|
||||
response_queue = Queue()
|
||||
|
||||
# 确保上下文包含必要的信息
|
||||
context["isgroup"] = False
|
||||
context["receiver"] = user_id # 添加接收者信息,用于send方法中识别用户
|
||||
context["session_id"] = session_id # 添加会话ID
|
||||
context["receiver"] = user_id
|
||||
context["session_id"] = user_id
|
||||
context["response_queue"] = response_queue
|
||||
|
||||
# 发送消息到处理队列
|
||||
self.produce(context)
|
||||
return json.dumps({"status": "success", "message": "Message received"})
|
||||
|
||||
# 等待响应,最多等待30秒
|
||||
try:
|
||||
response = response_queue.get(timeout=30)
|
||||
return json.dumps({"status": "success", "reply": response["content"]})
|
||||
except Empty:
|
||||
return json.dumps({"status": "error", "message": "Response timeout"})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message: {e}")
|
||||
@@ -203,31 +152,14 @@ class WebChannel(ChatChannel):
|
||||
logger.info(f"Created static directory: {static_dir}")
|
||||
|
||||
urls = (
|
||||
'/sse/(.+)', 'SSEHandler',
|
||||
'/poll/(.+)', 'PollHandler',
|
||||
'/message', 'MessageHandler',
|
||||
'/chat', 'ChatHandler',
|
||||
'/assets/(.*)', 'AssetsHandler', # 匹配 /static/任何路径
|
||||
'/assets/(.*)', 'AssetsHandler', # 匹配 /assets/任何路径
|
||||
)
|
||||
port = conf().get("web_port", 9899)
|
||||
app = web.application(urls, globals(), autoreload=False)
|
||||
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
|
||||
|
||||
def poll_messages(self, user_id):
|
||||
"""Poll for new messages."""
|
||||
messages = []
|
||||
|
||||
if user_id in self.message_queues:
|
||||
while not self.message_queues[user_id].empty():
|
||||
messages.append(self.message_queues[user_id].get_nowait())
|
||||
|
||||
return json.dumps(messages)
|
||||
|
||||
|
||||
class SSEHandler:
|
||||
def GET(self, user_id):
|
||||
return WebChannel().sse_handler(user_id)
|
||||
|
||||
|
||||
class MessageHandler:
|
||||
def POST(self):
|
||||
@@ -242,13 +174,6 @@ class ChatHandler:
|
||||
return f.read()
|
||||
|
||||
|
||||
# 添加轮询处理器
|
||||
class PollHandler:
|
||||
def GET(self, user_id):
|
||||
web.header('Content-Type', 'application/json')
|
||||
return WebChannel().poll_messages(user_id)
|
||||
|
||||
|
||||
class AssetsHandler:
|
||||
def GET(self, file_path): # 修改默认参数
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user