feat: web ui channel update

This commit is contained in:
Saboteur7
2025-05-18 16:56:50 +08:00
parent 8c8e996c87
commit 03fc8c1202
2 changed files with 435 additions and 222 deletions

View File

@@ -2,7 +2,7 @@ import sys
import time
import web
import json
from queue import Queue
from queue import Queue, Empty
from bridge.context import *
from bridge.reply import Reply, ReplyType
from channel.chat_channel import ChatChannel, check_prefix
@@ -58,91 +58,29 @@ class WebChannel(ChatChannel):
logger.warning(f"Web channel doesn't support {reply.type} yet")
return
if reply.type == ReplyType.IMAGE:
from PIL import Image
image_storage = reply.content
image_storage.seek(0)
img = Image.open(image_storage)
print("<IMAGE>")
img.show()
elif reply.type == ReplyType.IMAGE_URL:
import io
import requests
from PIL import Image
img_url = reply.content
pic_res = requests.get(img_url, stream=True)
image_storage = io.BytesIO()
for block in pic_res.iter_content(1024):
image_storage.write(block)
image_storage.seek(0)
img = Image.open(image_storage)
print(img_url)
img.show()
else:
print(reply.content)
# 获取用户ID
user_id = context.get("receiver", None)
if not user_id:
logger.error("No receiver found in context, cannot send message")
return
# 确保用户有对应的消息队列
if user_id not in self.message_queues:
self.message_queues[user_id] = Queue()
logger.debug(f"Created message queue for user {user_id}")
# 将消息放入对应用户的队列
message_data = {
"type": str(reply.type),
"content": reply.content,
"timestamp": time.time() # 使用 Unix 时间戳
}
self.message_queues[user_id].put(message_data)
logger.debug(f"Message queued for user {user_id}: {reply.content[:30]}...")
# 检查是否有响应队列
response_queue = context.get("response_queue", None)
if response_queue:
# 直接将响应放入队列
response_data = {
"type": str(reply.type),
"content": reply.content,
"timestamp": time.time()
}
response_queue.put(response_data)
logger.debug(f"Response sent to queue for user {user_id}")
else:
logger.warning(f"No response queue found for user {user_id}, response dropped")
except Exception as e:
logger.error(f"Error in send method: {e}")
def sse_handler(self, user_id):
"""
Handle Server-Sent Events (SSE) for real-time communication.
"""
web.header('Content-Type', 'text/event-stream')
web.header('Cache-Control', 'no-cache')
web.header('Connection', 'keep-alive')
logger.debug(f"SSE connection established for user {user_id}")
# 确保用户有消息队列
if user_id not in self.message_queues:
self.message_queues[user_id] = Queue()
logger.debug(f"Created new message queue for user {user_id}")
try:
while True:
try:
# 发送心跳
yield f": heartbeat\n\n"
# 非阻塞方式获取消息
if user_id in self.message_queues and not self.message_queues[user_id].empty():
message = self.message_queues[user_id].get_nowait()
logger.debug(f"Sending message to user {user_id}: {message}")
data = json.dumps(message)
yield f"data: {data}\n\n"
logger.debug(f"Message sent to user {user_id}")
time.sleep(0.5)
except Exception as e:
logger.error(f"SSE Error for user {user_id}: {str(e)}")
break
finally:
# 清理资源
logger.debug(f"SSE connection closed for user {user_id}")
def post_message(self):
"""
Handle incoming messages from users via POST request.
@@ -167,7 +105,7 @@ class WebChannel(ChatChannel):
msg_id=msg_id,
content=prompt,
from_user_id=user_id,
to_user_id="Chatgpt", # 明确指定接收者
to_user_id="Chatgpt",
other_user_id=user_id
)
@@ -175,13 +113,24 @@ class WebChannel(ChatChannel):
if not context:
return json.dumps({"status": "error", "message": "Failed to process message"})
# 创建一个响应队列
response_queue = Queue()
# 确保上下文包含必要的信息
context["isgroup"] = False
context["receiver"] = user_id # 添加接收者信息用于send方法中识别用户
context["session_id"] = session_id # 添加会话ID
context["receiver"] = user_id
context["session_id"] = user_id
context["response_queue"] = response_queue
# 发送消息到处理队列
self.produce(context)
return json.dumps({"status": "success", "message": "Message received"})
# 等待响应最多等待30秒
try:
response = response_queue.get(timeout=30)
return json.dumps({"status": "success", "reply": response["content"]})
except Empty:
return json.dumps({"status": "error", "message": "Response timeout"})
except Exception as e:
logger.error(f"Error processing message: {e}")
@@ -203,31 +152,14 @@ class WebChannel(ChatChannel):
logger.info(f"Created static directory: {static_dir}")
urls = (
'/sse/(.+)', 'SSEHandler',
'/poll/(.+)', 'PollHandler',
'/message', 'MessageHandler',
'/chat', 'ChatHandler',
'/assets/(.*)', 'AssetsHandler', # 匹配 /static/任何路径
'/assets/(.*)', 'AssetsHandler', # 匹配 /assets/任何路径
)
port = conf().get("web_port", 9899)
app = web.application(urls, globals(), autoreload=False)
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
def poll_messages(self, user_id):
"""Poll for new messages."""
messages = []
if user_id in self.message_queues:
while not self.message_queues[user_id].empty():
messages.append(self.message_queues[user_id].get_nowait())
return json.dumps(messages)
class SSEHandler:
def GET(self, user_id):
return WebChannel().sse_handler(user_id)
class MessageHandler:
def POST(self):
@@ -242,13 +174,6 @@ class ChatHandler:
return f.read()
# 添加轮询处理器
class PollHandler:
def GET(self, user_id):
web.header('Content-Type', 'application/json')
return WebChannel().poll_messages(user_id)
class AssetsHandler:
def GET(self, file_path): # 修改默认参数
try: