mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
refactor(openai): drop SDK dependency and switch to native HTTP client
This commit is contained in:
@@ -3,8 +3,15 @@
|
||||
import time
|
||||
import json
|
||||
|
||||
import openai
|
||||
from models.openai.openai_compat import error as openai_error, RateLimitError, Timeout, APIError, APIConnectionError
|
||||
from models.openai.openai_compat import (
|
||||
error as openai_error,
|
||||
RateLimitError,
|
||||
Timeout,
|
||||
APIError,
|
||||
APIConnectionError,
|
||||
wrap_http_error,
|
||||
)
|
||||
from models.openai.openai_http_client import OpenAIHTTPClient, OpenAIHTTPError
|
||||
import requests
|
||||
from common import const
|
||||
from models.bot import Bot
|
||||
@@ -23,18 +30,19 @@ from models.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
class ChatGPTBot(Bot, OpenAIImage, OpenAICompatibleBot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# set the default api_key / api_base based on bot_type
|
||||
# Resolve api key / base from config (no global SDK state anymore).
|
||||
if conf().get("bot_type") == "custom":
|
||||
openai.api_key = conf().get("custom_api_key", "")
|
||||
if conf().get("custom_api_base"):
|
||||
openai.api_base = conf().get("custom_api_base")
|
||||
self._api_key = conf().get("custom_api_key", "")
|
||||
self._api_base = conf().get("custom_api_base") or None
|
||||
else:
|
||||
openai.api_key = conf().get("open_ai_api_key")
|
||||
if conf().get("open_ai_api_base"):
|
||||
openai.api_base = conf().get("open_ai_api_base")
|
||||
proxy = conf().get("proxy")
|
||||
if proxy:
|
||||
openai.proxy = proxy
|
||||
self._api_key = conf().get("open_ai_api_key")
|
||||
self._api_base = conf().get("open_ai_api_base") or None
|
||||
self._proxy = conf().get("proxy") or None
|
||||
self._http_client = OpenAIHTTPClient(
|
||||
api_key=self._api_key,
|
||||
api_base=self._api_base,
|
||||
proxy=self._proxy,
|
||||
)
|
||||
if conf().get("rate_limit_chatgpt"):
|
||||
self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))
|
||||
conf_model = conf().get("model") or "gpt-3.5-turbo"
|
||||
@@ -72,6 +80,10 @@ class ChatGPTBot(Bot, OpenAIImage, OpenAICompatibleBot):
|
||||
'default_presence_penalty': conf().get("presence_penalty", 0.0),
|
||||
}
|
||||
|
||||
def _get_http_client(self) -> OpenAIHTTPClient:
    """Return the HTTP client stored on ``self._http_client``.

    Overrides the default client lookup so callers of this accessor reuse
    the bot's pre-configured client.
    """
    client = self._http_client
    return client
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context.type == ContextType.TEXT:
|
||||
@@ -195,20 +207,16 @@ class ChatGPTBot(Bot, OpenAIImage, OpenAICompatibleBot):
|
||||
|
||||
logger.info(f"[CHATGPT] Calling vision API with model: {model}")
|
||||
|
||||
# Call OpenAI API
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": 1000
|
||||
}
|
||||
if api_key:
|
||||
kwargs["api_key"] = api_key
|
||||
if api_base:
|
||||
kwargs["api_base"] = api_base
|
||||
# Call OpenAI-compatible API via HTTP
|
||||
response = self._http_client.chat_completions(
|
||||
api_key=api_key or None,
|
||||
api_base=api_base or None,
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=1000,
|
||||
)
|
||||
|
||||
response = openai.ChatCompletion.create(**kwargs)
|
||||
|
||||
content = response.choices[0]["message"]["content"]
|
||||
content = response["choices"][0]["message"]["content"]
|
||||
logger.info(f"[CHATGPT] Vision API response: {content[:100]}...")
|
||||
|
||||
# Clean up temp file
|
||||
@@ -237,18 +245,37 @@ class ChatGPTBot(Bot, OpenAIImage, OpenAICompatibleBot):
|
||||
try:
|
||||
if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
|
||||
raise RateLimitError("RateLimitError: rate limit exceeded")
|
||||
# if api_key == None, the default openai.api_key will be used
|
||||
# If api_key is None, the per-instance default key will be used.
|
||||
if args is None:
|
||||
args = self.args
|
||||
response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args)
|
||||
# logger.debug("[CHATGPT] response={}".format(response))
|
||||
logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
||||
# Translate old SDK kwargs to HTTP client params:
|
||||
# - request_timeout / timeout -> per-call timeout
|
||||
call_args = dict(args)
|
||||
timeout = call_args.pop("request_timeout", None) or call_args.pop("timeout", None)
|
||||
response = self._http_client.chat_completions(
|
||||
api_key=api_key or None,
|
||||
timeout=timeout,
|
||||
messages=session.messages,
|
||||
**call_args,
|
||||
)
|
||||
logger.info("[ChatGPT] reply={}, total_tokens={}".format(
|
||||
response["choices"][0]["message"]["content"],
|
||||
response["usage"]["total_tokens"]
|
||||
))
|
||||
return {
|
||||
"total_tokens": response["usage"]["total_tokens"],
|
||||
"completion_tokens": response["usage"]["completion_tokens"],
|
||||
"content": response.choices[0]["message"]["content"],
|
||||
"content": response["choices"][0]["message"]["content"],
|
||||
}
|
||||
except OpenAIHTTPError as http_err:
|
||||
return self._handle_reply_error(
|
||||
wrap_http_error(http_err), session, api_key, args, retry_count
|
||||
)
|
||||
except Exception as e:
|
||||
return self._handle_reply_error(e, session, api_key, args, retry_count)
|
||||
|
||||
def _handle_reply_error(self, e, session, api_key, args, retry_count):
|
||||
"""Map exception to user-facing reply with retry/backoff (mirrors SDK behavior)."""
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if isinstance(e, RateLimitError):
|
||||
@@ -261,16 +288,16 @@ class ChatGPTBot(Bot, OpenAIImage, OpenAICompatibleBot):
|
||||
result["content"] = "我没有收到你的消息"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, APIError):
|
||||
logger.warn("[CHATGPT] Bad Gateway: {}".format(e))
|
||||
result["content"] = "请再问我一次"
|
||||
if need_retry:
|
||||
time.sleep(10)
|
||||
elif isinstance(e, APIConnectionError):
|
||||
logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
|
||||
result["content"] = "我连接不到你的网络"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, APIError):
|
||||
logger.warn("[CHATGPT] Bad Gateway: {}".format(e))
|
||||
result["content"] = "请再问我一次"
|
||||
if need_retry:
|
||||
time.sleep(10)
|
||||
else:
|
||||
logger.exception("[CHATGPT] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
@@ -279,15 +306,39 @@ class ChatGPTBot(Bot, OpenAIImage, OpenAICompatibleBot):
|
||||
if need_retry:
|
||||
logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1))
|
||||
return self.reply_text(session, api_key, args, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
|
||||
class AzureChatGPTBot(ChatGPTBot):
|
||||
"""Azure OpenAI variant.
|
||||
|
||||
Azure's HTTP shape differs from public OpenAI:
|
||||
URL : {endpoint}/openai/deployments/{deployment}/chat/completions
|
||||
Auth : api-key header (not Bearer)
|
||||
Query : ?api-version={version}
|
||||
We model that with a dedicated HTTP client and override _get_http_client
|
||||
so the OpenAICompatibleBot streaming/tool path uses it transparently.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
openai.api_type = "azure"
|
||||
openai.api_version = conf().get("azure_api_version", "2023-06-01-preview")
|
||||
self.args["deployment_id"] = conf().get("azure_deployment_id")
|
||||
self._azure_api_version = conf().get("azure_api_version", "2023-06-01-preview")
|
||||
self._azure_deployment_id = conf().get("azure_deployment_id")
|
||||
# Drop legacy SDK kwarg; Azure deployment is encoded in the URL now.
|
||||
self.args.pop("deployment_id", None)
|
||||
|
||||
endpoint = (self._api_base or "").rstrip("/")
|
||||
deployment = self._azure_deployment_id or ""
|
||||
# Build a base that already includes /openai/deployments/{deployment}.
|
||||
# /chat/completions will be appended by the client.
|
||||
azure_base = (
|
||||
f"{endpoint}/openai/deployments/{deployment}" if endpoint and deployment else endpoint
|
||||
)
|
||||
self._http_client = _AzureChatHTTPClient(
|
||||
api_key=self._api_key,
|
||||
api_base=azure_base,
|
||||
api_version=self._azure_api_version,
|
||||
proxy=self._proxy,
|
||||
)
|
||||
|
||||
def create_img(self, query, retry_count=0, api_key=None):
|
||||
text_to_image_model = conf().get("text_to_image")
|
||||
@@ -357,3 +408,35 @@ class AzureChatGPTBot(ChatGPTBot):
|
||||
return False, "图片生成失败"
|
||||
else:
|
||||
return False, "图片生成失败,未配置text_to_image参数"
|
||||
|
||||
|
||||
class _AzureChatHTTPClient(OpenAIHTTPClient):
    """Chat-completions client that speaks the Azure OpenAI request dialect.

    Differences from the parent client, as implemented here:

    * the key is sent in an ``api-key`` header, not as a Bearer token;
    * an ``api-version`` query parameter is attached to every
      chat-completion call unless the caller already put one in
      ``extra_query``.

    ``api_base`` is the deployment-scoped base URL set by
    :class:`AzureChatGPTBot`.
    """

    def __init__(self, api_key, api_base, api_version, proxy=None, timeout=None):
        """Forward the shared settings to the parent and remember ``api_version``."""
        super().__init__(
            api_key=api_key,
            api_base=api_base,
            proxy=proxy,
            timeout=timeout,
        )
        self._api_version = api_version

    def _build_headers(self, api_key, extra_headers):
        """Assemble request headers using Azure's ``api-key`` header.

        Args:
            api_key: Per-call key; ``None`` selects the instance key.
            extra_headers: Per-call headers, applied last so they override
                both the defaults and the instance-level ``extra_headers``.

        Returns:
            A new ``dict`` of headers.
        """
        effective_key = self.api_key if api_key is None else api_key
        merged = {"Content-Type": "application/json"}
        if effective_key:
            merged["api-key"] = effective_key
        # Instance-level headers first, then per-call ones (later wins).
        for layer in (self.extra_headers, extra_headers):
            if layer:
                merged.update(layer)
        return merged

    def chat_completions(self, **kwargs):
        """Call the parent ``chat_completions`` with ``api-version`` in the query.

        A caller-supplied ``extra_query["api-version"]`` is left untouched;
        the caller's ``extra_query`` mapping itself is copied, not mutated.
        """
        query = dict(kwargs.get("extra_query") or {})
        if "api-version" not in query:
            query["api-version"] = self._api_version
        kwargs["extra_query"] = query
        return super().chat_completions(**kwargs)
|
||||
|
||||
@@ -2,8 +2,14 @@
|
||||
|
||||
import time
|
||||
|
||||
import openai
|
||||
from models.openai.openai_compat import RateLimitError, Timeout, APIConnectionError
|
||||
from models.openai.openai_compat import (
|
||||
RateLimitError,
|
||||
Timeout,
|
||||
APIConnectionError,
|
||||
APIError,
|
||||
wrap_http_error,
|
||||
)
|
||||
from models.openai.openai_http_client import OpenAIHTTPClient, OpenAIHTTPError
|
||||
|
||||
from models.bot import Bot
|
||||
from models.openai_compatible_bot import OpenAICompatibleBot
|
||||
@@ -22,12 +28,14 @@ user_session = dict()
|
||||
class OpenAIBot(Bot, OpenAIImage, OpenAICompatibleBot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
openai.api_key = conf().get("open_ai_api_key")
|
||||
if conf().get("open_ai_api_base"):
|
||||
openai.api_base = conf().get("open_ai_api_base")
|
||||
proxy = conf().get("proxy")
|
||||
if proxy:
|
||||
openai.proxy = proxy
|
||||
self._api_key = conf().get("open_ai_api_key")
|
||||
self._api_base = conf().get("open_ai_api_base") or None
|
||||
self._proxy = conf().get("proxy") or None
|
||||
self._http_client = OpenAIHTTPClient(
|
||||
api_key=self._api_key,
|
||||
api_base=self._api_base,
|
||||
proxy=self._proxy,
|
||||
)
|
||||
|
||||
self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003")
|
||||
self.args = {
|
||||
@@ -54,6 +62,10 @@ class OpenAIBot(Bot, OpenAIImage, OpenAICompatibleBot):
|
||||
'default_presence_penalty': conf().get("presence_penalty", 0.0),
|
||||
}
|
||||
|
||||
def _get_http_client(self) -> OpenAIHTTPClient:
    """Return the per-instance HTTP client held in ``self._http_client``.

    The streaming/tool path obtains its client through this accessor, so it
    reuses the same instance.
    """
    client = self._http_client
    return client
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context and context.type:
|
||||
@@ -96,8 +108,14 @@ class OpenAIBot(Bot, OpenAIImage, OpenAICompatibleBot):
|
||||
|
||||
def reply_text(self, session: OpenAISession, retry_count=0):
|
||||
try:
|
||||
response = openai.Completion.create(prompt=str(session), **self.args)
|
||||
res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "")
|
||||
call_args = dict(self.args)
|
||||
timeout = call_args.pop("request_timeout", None) or call_args.pop("timeout", None)
|
||||
response = self._http_client.completions(
|
||||
timeout=timeout,
|
||||
prompt=str(session),
|
||||
**call_args,
|
||||
)
|
||||
res_content = response["choices"][0]["text"].strip().replace("<|endoftext|>", "")
|
||||
total_tokens = response["usage"]["total_tokens"]
|
||||
completion_tokens = response["usage"]["completion_tokens"]
|
||||
logger.info("[OPEN_AI] reply={}".format(res_content))
|
||||
@@ -106,7 +124,13 @@ class OpenAIBot(Bot, OpenAIImage, OpenAICompatibleBot):
|
||||
"completion_tokens": completion_tokens,
|
||||
"content": res_content,
|
||||
}
|
||||
except OpenAIHTTPError as http_err:
|
||||
return self._handle_legacy_error(wrap_http_error(http_err), session, retry_count)
|
||||
except Exception as e:
|
||||
return self._handle_legacy_error(e, session, retry_count)
|
||||
|
||||
def _handle_legacy_error(self, e, session, retry_count):
|
||||
"""Map exception -> reply for the legacy /completions endpoint."""
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if isinstance(e, RateLimitError):
|
||||
@@ -131,100 +155,10 @@ class OpenAIBot(Bot, OpenAIImage, OpenAICompatibleBot):
|
||||
if need_retry:
|
||||
logger.warn("[OPEN_AI] 第{}次重试".format(retry_count + 1))
|
||||
return self.reply_text(session, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
|
||||
def call_with_tools(self, messages, tools=None, stream=False, **kwargs):
|
||||
"""
|
||||
Call OpenAI API with tool support for agent integration
|
||||
Note: This bot uses the old Completion API which doesn't support tools.
|
||||
For tool support, use ChatGPTBot instead.
|
||||
|
||||
This method converts to ChatCompletion API when tools are provided.
|
||||
|
||||
Args:
|
||||
messages: List of messages
|
||||
tools: List of tool definitions (OpenAI format)
|
||||
stream: Whether to use streaming
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
Formatted response in OpenAI format or generator for streaming
|
||||
"""
|
||||
try:
|
||||
# The old Completion API doesn't support tools
|
||||
# We need to use ChatCompletion API instead
|
||||
logger.info("[OPEN_AI] Using ChatCompletion API for tool support")
|
||||
|
||||
# Build request parameters for ChatCompletion
|
||||
request_params = {
|
||||
"model": kwargs.get("model", conf().get("model") or "gpt-4.1"),
|
||||
"messages": messages,
|
||||
"temperature": kwargs.get("temperature", conf().get("temperature", 0.9)),
|
||||
"top_p": kwargs.get("top_p", 1),
|
||||
"frequency_penalty": kwargs.get("frequency_penalty", conf().get("frequency_penalty", 0.0)),
|
||||
"presence_penalty": kwargs.get("presence_penalty", conf().get("presence_penalty", 0.0)),
|
||||
"stream": stream
|
||||
}
|
||||
|
||||
# Add max_tokens if specified
|
||||
if kwargs.get("max_tokens"):
|
||||
request_params["max_tokens"] = kwargs["max_tokens"]
|
||||
|
||||
# Add tools if provided
|
||||
if tools:
|
||||
request_params["tools"] = tools
|
||||
request_params["tool_choice"] = kwargs.get("tool_choice", "auto")
|
||||
|
||||
# Make API call using ChatCompletion
|
||||
if stream:
|
||||
return self._handle_stream_response(request_params)
|
||||
else:
|
||||
return self._handle_sync_response(request_params)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[OPEN_AI] call_with_tools error: {e}")
|
||||
if stream:
|
||||
def error_generator():
|
||||
yield {
|
||||
"error": True,
|
||||
"message": str(e),
|
||||
"status_code": 500
|
||||
}
|
||||
return error_generator()
|
||||
else:
|
||||
return {
|
||||
"error": True,
|
||||
"message": str(e),
|
||||
"status_code": 500
|
||||
}
|
||||
|
||||
def _handle_sync_response(self, request_params):
|
||||
"""Handle synchronous OpenAI ChatCompletion API response"""
|
||||
try:
|
||||
response = openai.ChatCompletion.create(**request_params)
|
||||
|
||||
logger.info(f"[OPEN_AI] call_with_tools reply, model={response.get('model')}, "
|
||||
f"total_tokens={response.get('usage', {}).get('total_tokens', 0)}")
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[OPEN_AI] sync response error: {e}")
|
||||
raise
|
||||
|
||||
def _handle_stream_response(self, request_params):
|
||||
"""Handle streaming OpenAI ChatCompletion API response"""
|
||||
try:
|
||||
stream = openai.ChatCompletion.create(**request_params)
|
||||
|
||||
for chunk in stream:
|
||||
yield chunk
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[OPEN_AI] stream response error: {e}")
|
||||
yield {
|
||||
"error": True,
|
||||
"message": str(e),
|
||||
"status_code": 500
|
||||
}
|
||||
# NOTE: Tool-call routing is delegated to OpenAICompatibleBot.call_with_tools,
|
||||
# which calls /chat/completions via our shared HTTP client. The previous
|
||||
# bespoke implementation here bypassed Claude->OpenAI message/tool conversion
|
||||
# and was effectively broken for agent flows; we now inherit the correct
|
||||
# implementation from the base class.
|
||||
|
||||
@@ -1,17 +1,25 @@
|
||||
import time
|
||||
|
||||
import openai
|
||||
from models.openai.openai_compat import RateLimitError
|
||||
|
||||
from common.log import logger
|
||||
from common.token_bucket import TokenBucket
|
||||
from config import conf
|
||||
from models.openai.openai_compat import RateLimitError, wrap_http_error
|
||||
from models.openai.openai_http_client import OpenAIHTTPClient, OpenAIHTTPError
|
||||
|
||||
|
||||
# OPENAI提供的画图接口
|
||||
# OpenAI image generation API wrapper
|
||||
class OpenAIImage(object):
|
||||
def __init__(self):
|
||||
openai.api_key = conf().get("open_ai_api_key")
|
||||
# Lazy default client; subclasses (ChatGPTBot/OpenAIBot) typically
|
||||
# construct their own _http_client and override _get_image_client().
|
||||
self._image_api_key = conf().get("open_ai_api_key")
|
||||
self._image_api_base = conf().get("open_ai_api_base") or None
|
||||
self._image_proxy = conf().get("proxy") or None
|
||||
self._image_client = OpenAIHTTPClient(
|
||||
api_key=self._image_api_key,
|
||||
api_base=self._image_api_base,
|
||||
proxy=self._image_proxy,
|
||||
)
|
||||
if conf().get("rate_limit_dalle"):
|
||||
self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50))
|
||||
|
||||
@@ -20,23 +28,34 @@ class OpenAIImage(object):
|
||||
if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token():
|
||||
return False, "请求太快了,请休息一下再问我吧"
|
||||
logger.info("[OPEN_AI] image_query={}".format(query))
|
||||
response = openai.Image.create(
|
||||
api_key=api_key,
|
||||
prompt=query, # 图片描述
|
||||
n=1, # 每次生成图片的数量
|
||||
response = self._image_client.images_generate(
|
||||
api_key=api_key or None,
|
||||
api_base=api_base or None,
|
||||
prompt=query, # image description
|
||||
n=1,
|
||||
model=conf().get("text_to_image") or "dall-e-2",
|
||||
# size=conf().get("image_create_size", "256x256"), # 图片大小,可选有 256x256, 512x512, 1024x1024
|
||||
# size=conf().get("image_create_size", "256x256"),
|
||||
)
|
||||
image_url = response["data"][0]["url"]
|
||||
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
||||
return True, image_url
|
||||
except OpenAIHTTPError as http_err:
|
||||
mapped = wrap_http_error(http_err)
|
||||
if isinstance(mapped, RateLimitError):
|
||||
logger.warn(mapped)
|
||||
if retry_count < 1:
|
||||
time.sleep(5)
|
||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1))
|
||||
return self.create_img(query, retry_count + 1)
|
||||
return False, "画图出现问题,请休息一下再问我吧"
|
||||
logger.exception(mapped)
|
||||
return False, "画图出现问题,请休息一下再问我吧"
|
||||
except RateLimitError as e:
|
||||
logger.warn(e)
|
||||
if retry_count < 1:
|
||||
time.sleep(5)
|
||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1))
|
||||
return self.create_img(query, retry_count + 1)
|
||||
else:
|
||||
return False, "画图出现问题,请休息一下再问我吧"
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
@@ -1,102 +1,163 @@
|
||||
"""
|
||||
OpenAI compatibility layer for different versions.
|
||||
OpenAI-compatible exception layer.
|
||||
|
||||
This module provides a compatibility layer between OpenAI library versions:
|
||||
- OpenAI < 1.0 (old API with openai.error module)
|
||||
- OpenAI >= 1.0 (new API with direct exception imports)
|
||||
This module used to bridge between openai SDK 0.x and 1.x exception types.
|
||||
Since we no longer depend on the `openai` SDK at all (we call HTTP directly
|
||||
via :mod:`models.openai.openai_http_client`), this file now provides:
|
||||
|
||||
1. Pure Python exception classes that match the *names* the rest of the
|
||||
codebase already imports (RateLimitError / Timeout / APIError /
|
||||
APIConnectionError / AuthenticationError / InvalidRequestError ...).
|
||||
2. A :func:`map_http_error` helper that converts an
|
||||
:class:`OpenAIHTTPError` (or any HTTP status code + message) into the
|
||||
appropriate exception subclass, so existing ``except RateLimitError``
|
||||
``except Timeout`` etc. blocks keep working unchanged.
|
||||
|
||||
This keeps the behavior of all existing bots (rate-limit backoff, timeout
|
||||
retry, auth-error fast-fail) identical to the openai-SDK-based version, while
|
||||
removing the hard dependency on the `openai` package.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Try new OpenAI >= 1.0 API
|
||||
from openai import (
|
||||
OpenAIError,
|
||||
RateLimitError,
|
||||
APIError,
|
||||
APIConnectionError,
|
||||
AuthenticationError,
|
||||
APITimeoutError,
|
||||
BadRequestError,
|
||||
)
|
||||
from typing import Optional
|
||||
|
||||
# Create a mock error module for backward compatibility
|
||||
class ErrorModule:
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Exception hierarchy (mirrors openai SDK names so call sites don't change)
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
class OpenAIError(Exception):
    """Root of the OpenAI-compatible exception tree.

    Attributes:
        message: Text passed by the caller; also forwarded to ``Exception``.
        status_code: Status value passed by the caller, ``None`` by default.
        body: Payload passed by the caller, ``None`` by default.
    """

    def __init__(self, message: str = "", status_code: Optional[int] = None,
                 body=None):
        super().__init__(message)
        # Stored verbatim; no validation or conversion happens here.
        self.body = body
        self.status_code = status_code
        self.message = message


class RateLimitError(OpenAIError):
    """HTTP 429 (Too Many Requests)."""


class InvalidRequestError(OpenAIError):
    """HTTP 400 (Bad Request); also exported as ``BadRequestError``."""


class NotFoundError(OpenAIError):
    """HTTP 404 (Not Found)."""


class PermissionDeniedError(OpenAIError):
    """HTTP 403 (Forbidden)."""


class AuthenticationError(OpenAIError):
    """HTTP 401 (Unauthorized)."""


class Timeout(OpenAIError):
    """The request timed out; also exported as ``APITimeoutError``."""


class APIConnectionError(OpenAIError):
    """Network-level failure such as DNS errors, refused or reset connections."""


class APIError(OpenAIError):
    """Catch-all API failure: 5xx responses and anything unclassified."""


# New-SDK-style names that some code paths in the project import.
APITimeoutError, BadRequestError = Timeout, InvalidRequestError
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Backward-compat ``error`` module-style accessor
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Some legacy code in the codebase (and possibly user plugins) does
|
||||
# from models.openai.openai_compat import error
|
||||
# except error.RateLimitError: ...
|
||||
# Keep that path working by exposing an attribute namespace.
|
||||
class _ErrorModule:
|
||||
OpenAIError = OpenAIError
|
||||
RateLimitError = RateLimitError
|
||||
APIError = APIError
|
||||
APIConnectionError = APIConnectionError
|
||||
AuthenticationError = AuthenticationError
|
||||
Timeout = APITimeoutError # Renamed in new version
|
||||
InvalidRequestError = BadRequestError # Renamed in new version
|
||||
|
||||
error = ErrorModule()
|
||||
|
||||
# Also export with new names
|
||||
Timeout = APITimeoutError
|
||||
InvalidRequestError = BadRequestError
|
||||
|
||||
except ImportError:
|
||||
# Fall back to old OpenAI < 1.0 API
|
||||
try:
|
||||
import openai.error as error
|
||||
|
||||
# Export individual exceptions for direct import
|
||||
OpenAIError = error.OpenAIError
|
||||
RateLimitError = error.RateLimitError
|
||||
APIError = error.APIError
|
||||
APIConnectionError = error.APIConnectionError
|
||||
AuthenticationError = error.AuthenticationError
|
||||
InvalidRequestError = error.InvalidRequestError
|
||||
Timeout = error.Timeout
|
||||
BadRequestError = error.InvalidRequestError # Alias
|
||||
APITimeoutError = error.Timeout # Alias
|
||||
except (ImportError, AttributeError):
|
||||
# Neither version works, create dummy classes
|
||||
class OpenAIError(Exception):
|
||||
pass
|
||||
|
||||
class RateLimitError(OpenAIError):
|
||||
pass
|
||||
|
||||
class APIError(OpenAIError):
|
||||
pass
|
||||
|
||||
class APIConnectionError(OpenAIError):
|
||||
pass
|
||||
|
||||
class AuthenticationError(OpenAIError):
|
||||
pass
|
||||
|
||||
class InvalidRequestError(OpenAIError):
|
||||
pass
|
||||
|
||||
class Timeout(OpenAIError):
|
||||
pass
|
||||
|
||||
BadRequestError = InvalidRequestError
|
||||
APITimeoutError = Timeout
|
||||
|
||||
# Create error module
|
||||
class ErrorModule:
|
||||
OpenAIError = OpenAIError
|
||||
RateLimitError = RateLimitError
|
||||
APIError = APIError
|
||||
APIConnectionError = APIConnectionError
|
||||
AuthenticationError = AuthenticationError
|
||||
InvalidRequestError = InvalidRequestError
|
||||
Timeout = Timeout
|
||||
AuthenticationError = AuthenticationError
|
||||
PermissionDeniedError = PermissionDeniedError
|
||||
NotFoundError = NotFoundError
|
||||
InvalidRequestError = InvalidRequestError
|
||||
RateLimitError = RateLimitError
|
||||
|
||||
|
||||
error = _ErrorModule()
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# HTTP -> exception mapping
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
def map_http_error(status_code: Optional[int], message: str = "",
                   body=None) -> OpenAIError:
    """Build the compat exception that corresponds to an HTTP failure.

    Used by HTTP-based bot wrappers so that downstream
    ``except RateLimitError`` / ``except Timeout`` blocks keep working.

    Args:
        status_code: HTTP status of the failed call. ``None`` or ``0`` is
            treated as a connection-level failure with no HTTP status.
        message: Error text. A non-string value is converted with ``str()``
            and a falsy value becomes ``""``.
        body: Response payload; stored unchanged on the exception.

    Returns:
        A new exception instance (it is returned, not raised):

        * no status: ``Timeout`` when the text contains "timeout" or
          "timed out", otherwise ``APIConnectionError``; the exception's
          ``status_code`` is ``0`` in this case;
        * 408 ``Timeout``, 401 ``AuthenticationError``,
          403 ``PermissionDeniedError``, 404 ``NotFoundError``,
          429 ``RateLimitError``;
        * any other 4xx: ``InvalidRequestError``;
        * everything else (5xx and unclassified codes): ``APIError``.
    """
    sc = status_code or 0

    # Error envelopes come from arbitrary OpenAI-compatible servers, so the
    # "message" field is not guaranteed to be a string. Coerce it here so the
    # ``.lower()`` below cannot raise AttributeError while an error is
    # already being handled.
    if not message:
        msg = ""
    elif isinstance(message, str):
        msg = message
    else:
        msg = str(message)

    # Connection-level failure (no status / non-HTTP failure).
    if sc == 0:
        msg_lower = msg.lower()
        if "timeout" in msg_lower or "timed out" in msg_lower:
            return Timeout(msg, sc, body)
        return APIConnectionError(msg, sc, body)

    # Status codes that have a dedicated exception class.
    exact_matches = {
        408: Timeout,
        401: AuthenticationError,
        403: PermissionDeniedError,
        404: NotFoundError,
        429: RateLimitError,
    }
    error_cls = exact_matches.get(sc)
    if error_cls is None:
        # Remaining 4xx are client errors; 5xx and any other code fall back
        # to the generic APIError.
        error_cls = InvalidRequestError if 400 <= sc < 500 else APIError
    return error_cls(msg, sc, body)
|
||||
|
||||
|
||||
def wrap_http_error(http_err) -> OpenAIError:
    """Translate an :class:`OpenAIHTTPError`-like object into a compat exception.

    ``status_code``, ``message`` and ``body`` are read with ``getattr``, so any
    object exposing those attributes is accepted. A missing or empty
    ``message`` falls back to ``str(http_err)``. The values are then handed
    to :func:`map_http_error`, whose result is returned.
    """
    status = getattr(http_err, "status_code", None)
    text = getattr(http_err, "message", "")
    if not text:
        text = str(http_err)
    payload = getattr(http_err, "body", None)
    return map_http_error(status, text, payload)
|
||||
|
||||
error = ErrorModule()
|
||||
|
||||
# Export all for easy import
|
||||
__all__ = [
|
||||
'error',
|
||||
'OpenAIError',
|
||||
'RateLimitError',
|
||||
'APIError',
|
||||
'APIConnectionError',
|
||||
'AuthenticationError',
|
||||
'InvalidRequestError',
|
||||
'Timeout',
|
||||
'BadRequestError',
|
||||
'APITimeoutError',
|
||||
"error",
|
||||
"OpenAIError",
|
||||
"APIError",
|
||||
"APIConnectionError",
|
||||
"Timeout",
|
||||
"APITimeoutError",
|
||||
"AuthenticationError",
|
||||
"PermissionDeniedError",
|
||||
"NotFoundError",
|
||||
"InvalidRequestError",
|
||||
"BadRequestError",
|
||||
"RateLimitError",
|
||||
"map_http_error",
|
||||
"wrap_http_error",
|
||||
]
|
||||
|
||||
456
models/openai/openai_http_client.py
Normal file
456
models/openai/openai_http_client.py
Normal file
@@ -0,0 +1,456 @@
|
||||
# encoding:utf-8
|
||||
|
||||
"""
|
||||
Lightweight HTTP client for OpenAI-compatible APIs.
|
||||
|
||||
This client is a drop-in replacement for the parts of the `openai` SDK that this
|
||||
project actually uses (chat completions, completions, image generation), so we
|
||||
can drop the hard dependency on `openai==0.27.x`.
|
||||
|
||||
Design goals:
|
||||
- Pure `requests` based (no httpx / pydantic / openai SDK dependency).
|
||||
- Returns plain `dict` responses with the same shape OpenAI's HTTP API returns,
|
||||
so existing code that does `response["choices"][0]["message"]["content"]` /
|
||||
`response["usage"]["total_tokens"]` keeps working.
|
||||
- Streaming yields plain `dict` chunks (parsed SSE `data:` JSON), matching the
|
||||
shape that `agent/protocol/agent_stream.py` consumes:
|
||||
chunk["choices"][0]["delta"]["content" | "tool_calls" | "reasoning_content"]
|
||||
chunk["choices"][0]["finish_reason"]
|
||||
Plus dict-style error chunks: {"error": True, "message": ..., "status_code": ...}
|
||||
- Compatible with arbitrary OpenAI-compatible endpoints (LinkAI, Azure-style
|
||||
proxies, DeepSeek, Moonshot, etc.) by allowing per-call api_key / api_base
|
||||
override and trusting whatever path/payload shape the caller passes.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, Generator, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
DEFAULT_API_BASE = "https://api.openai.com/v1"
|
||||
DEFAULT_TIMEOUT = 600 # seconds; matches old openai SDK default
|
||||
|
||||
|
||||
class OpenAIHTTPError(Exception):
    """Raised for non-2xx responses. Carries status code + parsed body.

    Attributes:
        status_code: The status value passed to the constructor.
        body: The response body passed to the constructor, unchanged.
        message: Human-readable error text; always a ``str``.
    """

    def __init__(self, status_code: int, body: Any, message: str = ""):
        """Store the failure details and derive a readable message.

        Args:
            status_code: HTTP status of the failed response.
            body: Parsed response body (any type).
            message: Explicit error text. When empty, it is taken from an
                OpenAI-style envelope (``body["error"]["message"]`` or a
                string ``body["error"]``), falling back to the first 500
                characters of ``str(body)``.
        """
        self.status_code = status_code
        self.body = body
        # Try to extract a human-readable message from an OpenAI-style
        # error envelope.
        if not message and isinstance(body, dict):
            err = body.get("error") or {}
            if isinstance(err, dict):
                message = err.get("message") or ""
            elif isinstance(err, str):
                message = err
        if not message:
            message = str(body)[:500]
        elif not isinstance(message, str):
            # The envelope is server-controlled JSON, so "message" may be a
            # dict/list/number. Consumers treat ``.message`` as text (for
            # example they call ``.lower()`` on it), so normalise it here.
            message = str(message)[:500]
        self.message = message
        super().__init__(f"HTTP {status_code}: {message}")
|
||||
|
||||
|
||||
class OpenAIHTTPClient:
    """Minimal HTTP client for OpenAI-compatible endpoints.

    Per-instance defaults (api_key / api_base / proxy / timeout) can be
    overridden on every call. Callers can also pass ``extra_headers`` for
    Azure-style ``api-key`` headers or custom routing headers.
    """

    # Blank-line sequences that terminate one SSE event. No two of them can
    # match at the same offset, so "earliest match wins" is unambiguous.
    _SSE_TERMINATORS = (b"\r\n\r\n", b"\n\n", b"\r\r")

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        proxy: Optional[str] = None,
        timeout: Optional[float] = None,
        extra_headers: Optional[Dict[str, str]] = None,
    ):
        """Store per-instance defaults.

        Args:
            api_key: Bearer token; omitted from headers when falsy.
            api_base: Endpoint root; falls back to ``DEFAULT_API_BASE``.
                A trailing ``/`` is stripped.
            proxy: Proxy URL applied to both http and https traffic.
            timeout: Request timeout in seconds; falls back to
                ``DEFAULT_TIMEOUT``.
            extra_headers: Headers added to every request (copied).
        """
        self.api_key = api_key
        self.api_base = (api_base or DEFAULT_API_BASE).rstrip("/")
        self.timeout = timeout if timeout is not None else DEFAULT_TIMEOUT
        self.proxies = {"http": proxy, "https": proxy} if proxy else None
        self.extra_headers = dict(extra_headers) if extra_headers else {}

    # ------------------------------------------------------------------ #
    # Public API surface (mirrors what the old openai SDK provided)
    # ------------------------------------------------------------------ #

    def chat_completions(
        self,
        *,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        timeout: Optional[float] = None,
        proxy: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        extra_query: Optional[Dict[str, str]] = None,
        path: str = "/chat/completions",
        stream: bool = False,
        **payload,
    ):
        """POST /chat/completions.

        When ``stream=False`` returns the parsed JSON response and raises
        ``OpenAIHTTPError`` on transport failures or HTTP status >= 400.

        When ``stream=True`` returns a generator yielding parsed SSE chunks
        (plain ``dict``). On error, the generator yields a single chunk of
        the form ``{"error": {"message", "code", "type"}, "message": ...,
        "status_code": ...}`` and stops instead of raising.
        """
        payload["stream"] = stream
        return self._request(
            path=path,
            payload=payload,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            proxy=proxy,
            extra_headers=extra_headers,
            extra_query=extra_query,
            stream=stream,
        )

    def completions(
        self,
        *,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        timeout: Optional[float] = None,
        **payload,
    ) -> Dict[str, Any]:
        """POST /completions (legacy text completion). Non-streaming only."""
        # Any caller-supplied ``stream`` flag is discarded.
        payload.pop("stream", None)
        return self._request(
            path="/completions",
            payload=payload,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            stream=False,
        )

    def images_generate(
        self,
        *,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        timeout: Optional[float] = None,
        **payload,
    ) -> Dict[str, Any]:
        """POST /images/generations."""
        return self._request(
            path="/images/generations",
            payload=payload,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            stream=False,
        )

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    def _build_headers(
        self,
        api_key: Optional[str],
        extra_headers: Optional[Dict[str, str]],
    ) -> Dict[str, str]:
        """Merge default, instance-level and per-call headers.

        Precedence (later wins): Content-Type / Authorization, then the
        instance's ``extra_headers``, then the per-call ``extra_headers``.
        """
        key = api_key if api_key is not None else self.api_key
        headers = {"Content-Type": "application/json"}
        if key:
            headers["Authorization"] = f"Bearer {key}"
        if self.extra_headers:
            headers.update(self.extra_headers)
        if extra_headers:
            headers.update(extra_headers)
        return headers

    def _request(
        self,
        *,
        path: str,
        payload: Dict[str, Any],
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Optional[float],
        stream: bool,
        proxy: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        extra_query: Optional[Dict[str, str]] = None,
    ):
        """Send one POST and return parsed JSON, or a chunk generator.

        Raises:
            OpenAIHTTPError: non-streaming only. Status 408 for timeouts,
                0 for connection / other transport failures, or the HTTP
                status for responses >= 400.
        """
        base = api_base.rstrip("/") if api_base else self.api_base
        url = f"{base}{path}" if path.startswith("/") else f"{base}/{path}"
        headers = self._build_headers(api_key, extra_headers)
        req_timeout = timeout if timeout is not None else self.timeout
        proxies = {"http": proxy, "https": proxy} if proxy else self.proxies

        # Drop None-valued keys; some providers reject explicit nulls.
        clean_payload = {k: v for k, v in payload.items() if v is not None}

        if stream:
            # Return a generator. Errors during stream are yielded as a single
            # error chunk so callers (agent_stream) can map them to their
            # existing error-handling path without try/except around the loop.
            return self._stream_chat(
                url=url,
                headers=headers,
                payload=clean_payload,
                proxies=proxies,
                timeout=req_timeout,
                params=extra_query,
            )

        try:
            resp = requests.post(
                url,
                headers=headers,
                json=clean_payload,
                timeout=req_timeout,
                proxies=proxies,
                params=extra_query,
            )
        except requests.exceptions.Timeout as e:
            raise OpenAIHTTPError(408, {}, f"Request timed out: {e}") from e
        except requests.exceptions.ConnectionError as e:
            raise OpenAIHTTPError(0, {}, f"Connection error: {e}") from e
        except requests.exceptions.RequestException as e:
            raise OpenAIHTTPError(0, {}, f"Request failed: {e}") from e

        return self._parse_response(resp)

    @staticmethod
    def _parse_response(resp: "requests.Response") -> Dict[str, Any]:
        """Return the JSON body, raising ``OpenAIHTTPError`` for status >= 400.

        A body that is not valid JSON is wrapped as ``{"raw": <text>}``.
        """
        try:
            data = resp.json()
        except ValueError:
            data = {"raw": resp.text}

        if resp.status_code >= 400:
            raise OpenAIHTTPError(resp.status_code, data)

        return data

    def _stream_chat(
        self,
        *,
        url: str,
        headers: Dict[str, str],
        payload: Dict[str, Any],
        proxies: Optional[Dict[str, str]],
        timeout: float,
        params: Optional[Dict[str, str]] = None,
    ) -> Generator[Dict[str, Any], None, None]:
        """Stream SSE response and yield parsed JSON chunks.

        Yields:
            - Normal chunks: dict with ``choices[0].delta`` etc.
            - Error chunks: ``{"error": {"message", "code", "type"},
              "message": str, "status_code": int}`` followed by termination
              of the generator.
        """
        try:
            resp = requests.post(
                url,
                headers=headers,
                json=payload,
                timeout=timeout,
                proxies=proxies,
                stream=True,
                params=params,
            )
        except requests.exceptions.Timeout as e:
            yield self._make_error_chunk(408, f"Request timed out: {e}")
            return
        except requests.exceptions.ConnectionError as e:
            yield self._make_error_chunk(0, f"Connection error: {e}")
            return
        except requests.exceptions.RequestException as e:
            yield self._make_error_chunk(0, f"Request failed: {e}")
            return

        # Everything that touches the open response lives inside this
        # try/finally so the connection is closed on every exit path,
        # including the HTTP-error early return.
        try:
            if resp.status_code >= 400:
                yield self._error_chunk_from_response(resp)
                return

            # IMPORTANT: do NOT use `iter_lines(decode_unicode=True)`.
            # Decoding per network chunk mangles UTF-8 codepoints that
            # straddle a chunk boundary; `_iter_sse_events` buffers raw
            # bytes until a full event is available and only then decodes.
            for sse_event in self._iter_sse_events(resp):
                # `sse_event` is the joined `data:` payload as a str.
                if sse_event == "[DONE]":
                    return
                if not sse_event:
                    continue
                try:
                    chunk = json.loads(sse_event)
                except ValueError:
                    logger.debug(
                        f"[OpenAIHTTP] skip malformed SSE chunk: {sse_event[:200]}"
                    )
                    continue
                yield chunk
        except requests.exceptions.ChunkedEncodingError as e:
            yield self._make_error_chunk(0, f"Stream interrupted: {e}")
        except requests.exceptions.RequestException as e:
            yield self._make_error_chunk(0, f"Stream error: {e}")
        finally:
            try:
                resp.close()
            except Exception:
                # Best-effort cleanup; a failure to close must not mask the
                # result (or error chunk) already delivered to the caller.
                pass

    @staticmethod
    def _error_chunk_from_response(resp: "requests.Response") -> Dict[str, Any]:
        """Build an error chunk from a response with status >= 400."""
        try:
            body = resp.json()
        except ValueError:
            body = {"raw": resp.text[:1000]}

        err_msg = ""
        err_code = ""
        err_type = ""
        if isinstance(body, dict):
            err = body.get("error") or {}
            if isinstance(err, dict):
                err_msg = err.get("message") or ""
                err_code = err.get("code") or ""
                err_type = err.get("type") or ""
            elif isinstance(err, str):
                err_msg = err
        if not err_msg:
            err_msg = str(body)[:500]
        return OpenAIHTTPClient._make_error_chunk(
            resp.status_code, err_msg, err_code, err_type
        )

    @staticmethod
    def _parse_sse_event(event_bytes: bytes) -> Optional[str]:
        """Extract the joined ``data:`` payload from one raw SSE event.

        Lines are split on the BYTES, so only CR, LF and CRLF count as line
        ends (the SSE framing). ``str.splitlines()`` would also break on
        U+2028 / U+2029 / NEL and friends, which may legally appear
        unescaped inside JSON string values and would truncate the payload.

        Returns:
            The ``data`` values joined with a newline, or ``None`` when
            the event carries no ``data:`` field (comments, ``event:``,
            ``id:``, ``retry:`` only).
        """
        data_lines = []
        for line in event_bytes.splitlines():
            if not line or line.startswith(b":"):
                continue
            field, _, value = line.partition(b":")
            # Per SSE spec, a single optional space after the colon is part
            # of the framing, not the value.
            if value.startswith(b" "):
                value = value[1:]
            if field == b"data":
                data_lines.append(value)
            # Other fields (event/id/retry) are intentionally ignored.
        if not data_lines:
            return None
        # ``errors="replace"`` only matters for malformed upstream bytes;
        # valid UTF-8 decodes exactly as in strict mode.
        return b"\n".join(data_lines).decode("utf-8", errors="replace")

    @staticmethod
    def _iter_sse_events(resp: "requests.Response") -> Generator[str, None, None]:
        """Decode an SSE byte stream into joined `data:` payloads.

        Implements the subset of the SSE spec that OpenAI / OpenAI-compatible
        endpoints actually use:
        - Events are separated by blank lines (CR CR, LF LF, or CRLF CRLF).
        - Within an event, multiple ``data:`` lines are concatenated with a
          newline (per spec).
        - ``event:``, ``id:``, ``retry:`` and comment lines (``:``) are
          tolerated but not yielded.
        - Bytes are buffered until a complete event boundary is seen so
          UTF-8 codepoints split across network chunks decode correctly.

        Yields each event's joined ``data`` string. The terminal sentinel
        ``[DONE]`` is yielded as a literal string so the caller can break.
        """
        terminators = OpenAIHTTPClient._SSE_TERMINATORS
        parse_event = OpenAIHTTPClient._parse_sse_event

        buf = b""
        for raw in resp.iter_content(chunk_size=None, decode_unicode=False):
            if not raw:
                continue
            buf += raw
            while True:
                # Pick the earliest terminator of any of the three forms and
                # remember its length so we can advance past it.
                end_pos, term_len = -1, 0
                for term in terminators:
                    pos = buf.find(term)
                    if pos != -1 and (end_pos == -1 or pos < end_pos):
                        end_pos, term_len = pos, len(term)
                if end_pos == -1:
                    break
                event_bytes = buf[:end_pos]
                buf = buf[end_pos + term_len:]

                data = parse_event(event_bytes)
                if data is not None:
                    yield data

        # Flush any trailing bytes the server did not terminate with a blank
        # line (some providers omit the final one).
        if buf.strip():
            data = parse_event(buf)
            if data is not None:
                yield data

    @staticmethod
    def _make_error_chunk(
        status_code: int,
        message: str,
        code: Any = "",
        err_type: Any = "",
    ) -> Dict[str, Any]:
        """Build the dict-style error chunk emitted by streaming calls.

        The top-level ``message`` / ``status_code`` fields duplicate the
        nested ``error`` object for backward compatibility with consumers
        of the older flat error shape.
        """
        return {
            "error": {"message": message, "code": code, "type": err_type},
            "message": message,
            "status_code": status_code,
        }
|
||||
|
||||
|
||||
# A tiny helper for callers that just need a one-shot client without storing
|
||||
# state. Keeps call sites cleaner than instantiating the class every time.
|
||||
def get_default_client(
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    proxy: Optional[str] = None,
    timeout: Optional[float] = None,
) -> OpenAIHTTPClient:
    """Return a fresh ``OpenAIHTTPClient`` built from the given defaults.

    Each call constructs a new client; the four arguments are forwarded
    unchanged to the ``OpenAIHTTPClient`` constructor.
    """
    settings = {
        "api_key": api_key,
        "api_base": api_base,
        "proxy": proxy,
        "timeout": timeout,
    }
    return OpenAIHTTPClient(**settings)
|
||||
@@ -8,11 +8,11 @@ This includes: OpenAI, LinkAI, Azure OpenAI, and many third-party providers.
|
||||
"""
|
||||
|
||||
import json
|
||||
import openai
|
||||
import requests
|
||||
from typing import Optional
|
||||
from common.log import logger
|
||||
from agent.protocol.message_utils import drop_orphaned_tool_results_openai
|
||||
from models.openai.openai_http_client import OpenAIHTTPClient, OpenAIHTTPError
|
||||
|
||||
|
||||
class OpenAICompatibleBot:
|
||||
@@ -135,49 +135,87 @@ class OpenAICompatibleBot:
|
||||
"status_code": 500
|
||||
}
|
||||
|
||||
def _get_http_client(self) -> OpenAIHTTPClient:
    """Create the HTTP client used for chat-completion requests.

    The client is configured only with the ``proxy`` config entry (an
    empty / missing value means no proxy). Subclasses that need custom
    auth headers (e.g. Azure's ``api-key`` header) can override this and
    return a pre-configured client.
    """
    # Imported lazily inside the method, as in the original code.
    from config import conf

    return OpenAIHTTPClient(proxy=conf().get("proxy") or None)
|
||||
|
||||
def _handle_sync_response(self, request_params, api_key, api_base):
    """Handle synchronous chat-completion via HTTP.

    Args:
        request_params: Chat-completion payload. ``stream`` is discarded;
            the legacy ``request_timeout`` / ``timeout`` keys are removed
            and translated into the HTTP client's ``timeout`` argument.
        api_key: Per-call API key forwarded to the client.
        api_base: Per-call endpoint root forwarded to the client.

    Returns:
        The parsed JSON response on success, otherwise a dict of the form
        ``{"error": True, "message": str, "status_code": int}``.
    """
    params = dict(request_params)
    params.pop("stream", None)
    # Pop BOTH timeout spellings unconditionally. Chaining the two pops with
    # ``or`` short-circuits and leaves "timeout" inside ``params`` whenever
    # "request_timeout" is set, which then collides with the explicit
    # ``timeout=`` keyword below (TypeError: multiple values for 'timeout').
    legacy_timeout = params.pop("request_timeout", None)
    plain_timeout = params.pop("timeout", None)
    timeout = legacy_timeout or plain_timeout
    try:
        client = self._get_http_client()
        return client.chat_completions(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            stream=False,
            **params,
        )
    except OpenAIHTTPError as e:
        logger.error(
            f"[{self.__class__.__name__}] sync response error: "
            f"HTTP {e.status_code}: {e.message}"
        )
        return {
            "error": True,
            "message": e.message,
            # Transport failures carry status 0; report them as 500.
            "status_code": e.status_code or 500,
        }
    except Exception as e:
        logger.error(f"[{self.__class__.__name__}] sync response error: {e}")
        return {
            "error": True,
            "message": str(e),
            "status_code": 500,
        }
|
||||
|
||||
def _handle_stream_response(self, request_params, api_key, api_base):
    """Handle streaming chat-completion via HTTP (SSE).

    Yields dict chunks in OpenAI's standard streaming shape:
        {"choices": [{"delta": {...}, "finish_reason": ...}], ...}
    Error chunks produced by the HTTP client are passed through as-is.
    If an exception is raised here instead, a single
    ``{"error": True, "message": ..., "status_code": ...}`` chunk is yielded.

    Args:
        request_params: Chat-completion payload. ``stream`` is discarded;
            the legacy ``request_timeout`` / ``timeout`` keys are removed
            and translated into the HTTP client's ``timeout`` argument.
        api_key: Per-call API key forwarded to the client.
        api_base: Per-call endpoint root forwarded to the client.
    """
    params = dict(request_params)
    params.pop("stream", None)
    # Pop BOTH timeout spellings unconditionally. Chaining the two pops with
    # ``or`` short-circuits and leaves "timeout" inside ``params`` whenever
    # "request_timeout" is set, which then collides with the explicit
    # ``timeout=`` keyword below (TypeError: multiple values for 'timeout').
    legacy_timeout = params.pop("request_timeout", None)
    plain_timeout = params.pop("timeout", None)
    timeout = legacy_timeout or plain_timeout
    try:
        client = self._get_http_client()
        stream = client.chat_completions(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            stream=True,
            **params,
        )
        for chunk in stream:
            yield chunk
    except OpenAIHTTPError as e:
        logger.error(
            f"[{self.__class__.__name__}] stream response error: "
            f"HTTP {e.status_code}: {e.message}"
        )
        yield {
            "error": True,
            "message": e.message,
            # Transport failures carry status 0; report them as 500.
            "status_code": e.status_code or 500,
        }
    except Exception as e:
        logger.error(f"[{self.__class__.__name__}] stream response error: {e}")
        yield {
            "error": True,
            "message": str(e),
            "status_code": 500,
        }
|
||||
|
||||
def _convert_tools_to_openai_format(self, tools):
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
openai==0.27.8
|
||||
aiohttp>=3.8.6,<3.10
|
||||
requests>=2.28.2
|
||||
chardet>=5.1.0
|
||||
|
||||
@@ -3,8 +3,6 @@ google voice service
|
||||
"""
|
||||
import json
|
||||
|
||||
import openai
|
||||
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
@@ -15,7 +13,9 @@ import datetime, random
|
||||
|
||||
class OpenaiVoice(Voice):
|
||||
def __init__(self):
|
||||
openai.api_key = conf().get("open_ai_api_key")
|
||||
# No-op: this implementation calls OpenAI HTTP endpoints directly via
|
||||
# `requests`, so it does not need a global SDK to be configured.
|
||||
pass
|
||||
|
||||
def voiceToText(self, voice_file):
|
||||
logger.debug("[Openai] voice file name={}".format(voice_file))
|
||||
|
||||
Reference in New Issue
Block a user