feat: web channel support multiple message and picture display

This commit is contained in:
Saboteur7
2025-05-23 00:43:54 +08:00
parent 70d7e52df0
commit 5f7ade20dc
5 changed files with 848 additions and 265 deletions

View File

@@ -2,6 +2,7 @@ import sys
import time
import web
import json
import uuid
from queue import Queue, Empty
from bridge.context import *
from bridge.reply import Reply, ReplyType
@@ -12,6 +13,8 @@ from common.singleton import singleton
from config import conf
import os
import mimetypes # 添加这行来处理MIME类型
import threading
import logging
class WebMessage(ChatMessage):
def __init__(
@@ -43,39 +46,54 @@ class WebChannel(ChatChannel):
def __init__(self):
super().__init__()
self.message_queues = {} # 为每个用户存储一个消息队列
self.msg_id_counter = 0 # 添加消息ID计数器
self.session_queues = {} # 存储session_id到队列的映射
self.request_to_session = {} # 存储request_id到session_id的映射
def _generate_msg_id(self):
"""生成唯一的消息ID"""
self.msg_id_counter += 1
return str(int(time.time())) + str(self.msg_id_counter)
def _generate_request_id(self):
"""生成唯一的请求ID"""
return str(uuid.uuid4())
def send(self, reply: Reply, context: Context):
try:
if reply.type in self.NOT_SUPPORT_REPLYTYPE:
logger.warning(f"Web channel doesn't support {reply.type} yet")
return
if reply.type == ReplyType.IMAGE_URL:
time.sleep(0.5)
# 获取请求ID和会话ID
request_id = context.get("request_id", None)
if not request_id:
logger.error("No request_id found in context, cannot send message")
return
# 获取用户ID
user_id = context.get("receiver", None)
if not user_id:
logger.error("No receiver found in context, cannot send message")
# 通过request_id获取session_id
session_id = self.request_to_session.get(request_id)
if not session_id:
logger.error(f"No session_id found for request {request_id}")
return
# 检查是否有响应队列
response_queue = context.get("response_queue", None)
if response_queue:
# 直接将响应放入队列
# 检查是否有会话队列
if session_id in self.session_queues:
# 创建响应数据包含请求ID以区分不同请求的响应
response_data = {
"type": str(reply.type),
"content": reply.content,
"timestamp": time.time()
"timestamp": time.time(),
"request_id": request_id
}
response_queue.put(response_data)
logger.debug(f"Response sent to queue for user {user_id}")
self.session_queues[session_id].put(response_data)
logger.debug(f"Response sent to queue for session {session_id}, request {request_id}")
else:
logger.warning(f"No response queue found for user {user_id}, response dropped")
logger.warning(f"No response queue found for session {session_id}, response dropped")
except Exception as e:
logger.error(f"Error in send method: {e}")
@@ -83,57 +101,83 @@ class WebChannel(ChatChannel):
def post_message(self):
"""
Handle incoming messages from users via POST request.
Returns a request_id for tracking this specific request.
"""
try:
data = web.data() # 获取原始POST数据
json_data = json.loads(data)
user_id = json_data.get('user_id', 'default_user')
prompt = json_data.get('message', '')
session_id = json_data.get('session_id', f'session_{int(time.time())}')
except json.JSONDecodeError:
return json.dumps({"status": "error", "message": "Invalid JSON"})
except Exception as e:
return json.dumps({"status": "error", "message": str(e)})
if not prompt:
return json.dumps({"status": "error", "message": "No message provided"})
prompt = json_data.get('message', '')
try:
msg_id = self._generate_msg_id()
web_message = WebMessage(
msg_id=msg_id,
content=prompt,
from_user_id=user_id,
to_user_id="Chatgpt",
other_user_id=user_id
)
# 生成请求ID
request_id = self._generate_request_id()
context = self._compose_context(ContextType.TEXT, prompt, msg=web_message)
if not context:
return json.dumps({"status": "error", "message": "Failed to process message"})
# 创建一个响应队列
response_queue = Queue()
# 将请求ID与会话ID关联
self.request_to_session[request_id] = session_id
# 确保上下文包含必要的信息
context["isgroup"] = False
context["receiver"] = user_id
context["session_id"] = user_id
context["response_queue"] = response_queue
# 发送消息到处理队列
self.produce(context)
# 确保会话队列存在
if session_id not in self.session_queues:
self.session_queues[session_id] = Queue()
# 等待响应最多等待30秒
try:
response = response_queue.get(timeout=120)
return json.dumps({"status": "success", "reply": response["content"]})
except Empty:
return json.dumps({"status": "error", "message": "Response timeout"})
# 创建消息对象
msg = WebMessage(self._generate_msg_id(), prompt)
msg.from_user_id = session_id # 使用会话ID作为用户ID
# 创建上下文
context = self._compose_context(ContextType.TEXT, prompt, msg=msg)
# 添加必要的字段
context["session_id"] = session_id
context["request_id"] = request_id
context["isgroup"] = False # 添加 isgroup 字段
context["receiver"] = session_id # 添加 receiver 字段
# 异步处理消息 - 只传递上下文
threading.Thread(target=self.produce, args=(context,)).start()
# 返回请求ID
return json.dumps({"status": "success", "request_id": request_id})
except Exception as e:
logger.error(f"Error processing message: {e}")
return json.dumps({"status": "error", "message": "Internal server error"})
return json.dumps({"status": "error", "message": str(e)})
def poll_response(self):
"""
Poll for responses using the session_id.
"""
try:
# 不记录轮询请求的日志
web.ctx.log_request = False
data = web.data()
json_data = json.loads(data)
session_id = json_data.get('session_id')
if not session_id or session_id not in self.session_queues:
return json.dumps({"status": "error", "message": "Invalid session ID"})
# 尝试从队列获取响应,不等待
try:
# 使用peek而不是get这样如果前端没有成功处理下次还能获取到
response = self.session_queues[session_id].get(block=False)
# 返回响应包含请求ID以区分不同请求
return json.dumps({
"status": "success",
"has_content": True,
"content": response["content"],
"request_id": response["request_id"],
"timestamp": response["timestamp"]
})
except Empty:
# 没有新响应
return json.dumps({"status": "success", "has_content": False})
except Exception as e:
logger.error(f"Error polling response: {e}")
return json.dumps({"status": "error", "message": str(e)})
def chat_page(self):
"""Serve the chat HTML page."""
@@ -153,6 +197,7 @@ class WebChannel(ChatChannel):
urls = (
'/', 'RootHandler', # 添加根路径处理器
'/message', 'MessageHandler',
'/poll', 'PollHandler', # 添加轮询处理器
'/chat', 'ChatHandler',
'/assets/(.*)', 'AssetsHandler', # 匹配 /assets/任何路径
)
@@ -163,6 +208,12 @@ class WebChannel(ChatChannel):
import io
from contextlib import redirect_stdout
# 配置web.py的日志级别为ERROR只显示错误
logging.getLogger("web").setLevel(logging.ERROR)
# 禁用web.httpserver的日志
logging.getLogger("web.httpserver").setLevel(logging.ERROR)
# 临时重定向标准输出捕获web.py的启动消息
with redirect_stdout(io.StringIO()):
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
@@ -179,6 +230,11 @@ class MessageHandler:
return WebChannel().post_message()
class PollHandler:
def POST(self):
return WebChannel().poll_response()
class ChatHandler:
def GET(self):
# 正常返回聊天页面