From 8752f0cc606638f04361cb049c9add9f2000186e Mon Sep 17 00:00:00 2001 From: zhayujie Date: Mon, 27 Apr 2026 20:21:54 +0800 Subject: [PATCH] refactor(openai): drop SDK dependency and switch to native HTTP client --- models/chatgpt/chat_gpt_bot.py | 211 +++++++++---- models/openai/open_ai_bot.py | 190 ++++-------- models/openai/open_ai_image.py | 43 ++- models/openai/openai_compat.py | 251 +++++++++------ models/openai/openai_http_client.py | 456 ++++++++++++++++++++++++++++ models/openai_compatible_bot.py | 92 ++++-- requirements.txt | 1 - voice/openai/openai_voice.py | 6 +- 8 files changed, 920 insertions(+), 330 deletions(-) create mode 100644 models/openai/openai_http_client.py diff --git a/models/chatgpt/chat_gpt_bot.py b/models/chatgpt/chat_gpt_bot.py index 1c01c902..7e7e9aa9 100644 --- a/models/chatgpt/chat_gpt_bot.py +++ b/models/chatgpt/chat_gpt_bot.py @@ -3,8 +3,15 @@ import time import json -import openai -from models.openai.openai_compat import error as openai_error, RateLimitError, Timeout, APIError, APIConnectionError +from models.openai.openai_compat import ( + error as openai_error, + RateLimitError, + Timeout, + APIError, + APIConnectionError, + wrap_http_error, +) +from models.openai.openai_http_client import OpenAIHTTPClient, OpenAIHTTPError import requests from common import const from models.bot import Bot @@ -23,18 +30,19 @@ from models.baidu.baidu_wenxin_session import BaiduWenxinSession class ChatGPTBot(Bot, OpenAIImage, OpenAICompatibleBot): def __init__(self): super().__init__() - # set the default api_key / api_base based on bot_type + # Resolve api key / base from config (no global SDK state anymore). 
if conf().get("bot_type") == "custom": - openai.api_key = conf().get("custom_api_key", "") - if conf().get("custom_api_base"): - openai.api_base = conf().get("custom_api_base") + self._api_key = conf().get("custom_api_key", "") + self._api_base = conf().get("custom_api_base") or None else: - openai.api_key = conf().get("open_ai_api_key") - if conf().get("open_ai_api_base"): - openai.api_base = conf().get("open_ai_api_base") - proxy = conf().get("proxy") - if proxy: - openai.proxy = proxy + self._api_key = conf().get("open_ai_api_key") + self._api_base = conf().get("open_ai_api_base") or None + self._proxy = conf().get("proxy") or None + self._http_client = OpenAIHTTPClient( + api_key=self._api_key, + api_base=self._api_base, + proxy=self._proxy, + ) if conf().get("rate_limit_chatgpt"): self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20)) conf_model = conf().get("model") or "gpt-3.5-turbo" @@ -71,6 +79,10 @@ class ChatGPTBot(Bot, OpenAIImage, OpenAICompatibleBot): 'default_frequency_penalty': conf().get("frequency_penalty", 0.0), 'default_presence_penalty': conf().get("presence_penalty", 0.0), } + + def _get_http_client(self) -> OpenAIHTTPClient: + """Override the default HTTP client to reuse our pre-configured one.""" + return self._http_client def reply(self, query, context=None): # acquire reply content @@ -195,20 +207,16 @@ class ChatGPTBot(Bot, OpenAIImage, OpenAICompatibleBot): logger.info(f"[CHATGPT] Calling vision API with model: {model}") - # Call OpenAI API - kwargs = { - "model": model, - "messages": messages, - "max_tokens": 1000 - } - if api_key: - kwargs["api_key"] = api_key - if api_base: - kwargs["api_base"] = api_base - - response = openai.ChatCompletion.create(**kwargs) - - content = response.choices[0]["message"]["content"] + # Call OpenAI-compatible API via HTTP + response = self._http_client.chat_completions( + api_key=api_key or None, + api_base=api_base or None, + model=model, + messages=messages, + max_tokens=1000, + ) + + 
content = response["choices"][0]["message"]["content"] logger.info(f"[CHATGPT] Vision API response: {content[:100]}...") # Clean up temp file @@ -237,57 +245,100 @@ class ChatGPTBot(Bot, OpenAIImage, OpenAICompatibleBot): try: if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token(): raise RateLimitError("RateLimitError: rate limit exceeded") - # if api_key == None, the default openai.api_key will be used + # If api_key is None, the per-instance default key will be used. if args is None: args = self.args - response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args) - # logger.debug("[CHATGPT] response={}".format(response)) - logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) + # Translate old SDK kwargs to HTTP client params: + # - request_timeout / timeout -> per-call timeout + call_args = dict(args) + timeout = call_args.pop("request_timeout", None) or call_args.pop("timeout", None) + response = self._http_client.chat_completions( + api_key=api_key or None, + timeout=timeout, + messages=session.messages, + **call_args, + ) + logger.info("[ChatGPT] reply={}, total_tokens={}".format( + response["choices"][0]["message"]["content"], + response["usage"]["total_tokens"] + )) return { "total_tokens": response["usage"]["total_tokens"], "completion_tokens": response["usage"]["completion_tokens"], - "content": response.choices[0]["message"]["content"], + "content": response["choices"][0]["message"]["content"], } + except OpenAIHTTPError as http_err: + return self._handle_reply_error( + wrap_http_error(http_err), session, api_key, args, retry_count + ) except Exception as e: - need_retry = retry_count < 2 - result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} - if isinstance(e, RateLimitError): - logger.warn("[CHATGPT] RateLimitError: {}".format(e)) - result["content"] = "提问太快啦,请休息一下再问我吧" - if need_retry: - time.sleep(20) - elif 
isinstance(e, Timeout): - logger.warn("[CHATGPT] Timeout: {}".format(e)) - result["content"] = "我没有收到你的消息" - if need_retry: - time.sleep(5) - elif isinstance(e, APIError): - logger.warn("[CHATGPT] Bad Gateway: {}".format(e)) - result["content"] = "请再问我一次" - if need_retry: - time.sleep(10) - elif isinstance(e, APIConnectionError): - logger.warn("[CHATGPT] APIConnectionError: {}".format(e)) - result["content"] = "我连接不到你的网络" - if need_retry: - time.sleep(5) - else: - logger.exception("[CHATGPT] Exception: {}".format(e)) - need_retry = False - self.sessions.clear_session(session.session_id) + return self._handle_reply_error(e, session, api_key, args, retry_count) + def _handle_reply_error(self, e, session, api_key, args, retry_count): + """Map exception to user-facing reply with retry/backoff (mirrors SDK behavior).""" + need_retry = retry_count < 2 + result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} + if isinstance(e, RateLimitError): + logger.warn("[CHATGPT] RateLimitError: {}".format(e)) + result["content"] = "提问太快啦,请休息一下再问我吧" if need_retry: - logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1)) - return self.reply_text(session, api_key, args, retry_count + 1) - else: - return result + time.sleep(20) + elif isinstance(e, Timeout): + logger.warn("[CHATGPT] Timeout: {}".format(e)) + result["content"] = "我没有收到你的消息" + if need_retry: + time.sleep(5) + elif isinstance(e, APIConnectionError): + logger.warn("[CHATGPT] APIConnectionError: {}".format(e)) + result["content"] = "我连接不到你的网络" + if need_retry: + time.sleep(5) + elif isinstance(e, APIError): + logger.warn("[CHATGPT] Bad Gateway: {}".format(e)) + result["content"] = "请再问我一次" + if need_retry: + time.sleep(10) + else: + logger.exception("[CHATGPT] Exception: {}".format(e)) + need_retry = False + self.sessions.clear_session(session.session_id) + + if need_retry: + logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1)) + return self.reply_text(session, api_key, args, retry_count + 1) + return result class 
AzureChatGPTBot(ChatGPTBot): + """Azure OpenAI variant. + + Azure's HTTP shape differs from public OpenAI: + URL : {endpoint}/openai/deployments/{deployment}/chat/completions + Auth : api-key header (not Bearer) + Query : ?api-version={version} + We model that with a dedicated HTTP client and override _get_http_client + so the OpenAICompatibleBot streaming/tool path uses it transparently. + """ + def __init__(self): super().__init__() - openai.api_type = "azure" - openai.api_version = conf().get("azure_api_version", "2023-06-01-preview") - self.args["deployment_id"] = conf().get("azure_deployment_id") + self._azure_api_version = conf().get("azure_api_version", "2023-06-01-preview") + self._azure_deployment_id = conf().get("azure_deployment_id") + # Drop legacy SDK kwarg; Azure deployment is encoded in the URL now. + self.args.pop("deployment_id", None) + + endpoint = (self._api_base or "").rstrip("/") + deployment = self._azure_deployment_id or "" + # Build a base that already includes /openai/deployments/{deployment}. + # /chat/completions will be appended by the client. + azure_base = ( + f"{endpoint}/openai/deployments/{deployment}" if endpoint and deployment else endpoint + ) + self._http_client = _AzureChatHTTPClient( + api_key=self._api_key, + api_base=azure_base, + api_version=self._azure_api_version, + proxy=self._proxy, + ) def create_img(self, query, retry_count=0, api_key=None): text_to_image_model = conf().get("text_to_image") @@ -357,3 +408,35 @@ class AzureChatGPTBot(ChatGPTBot): return False, "图片生成失败" else: return False, "图片生成失败,未配置text_to_image参数" + + +class _AzureChatHTTPClient(OpenAIHTTPClient): + """Subclass that injects Azure's ``api-version`` query param and ``api-key`` + header on every chat-completion request, and accepts the deployment-scoped + base URL set by :class:`AzureChatGPTBot`. 
+ """ + + def __init__(self, api_key, api_base, api_version, proxy=None, timeout=None): + super().__init__( + api_key=api_key, api_base=api_base, proxy=proxy, timeout=timeout + ) + self._api_version = api_version + + def _build_headers(self, api_key, extra_headers): + # Azure uses api-key header, not Bearer token. + key = api_key if api_key is not None else self.api_key + headers = {"Content-Type": "application/json"} + if key: + headers["api-key"] = key + if self.extra_headers: + headers.update(self.extra_headers) + if extra_headers: + headers.update(extra_headers) + return headers + + def chat_completions(self, **kwargs): + # Always force api-version query param for Azure. + eq = dict(kwargs.get("extra_query") or {}) + eq.setdefault("api-version", self._api_version) + kwargs["extra_query"] = eq + return super().chat_completions(**kwargs) diff --git a/models/openai/open_ai_bot.py b/models/openai/open_ai_bot.py index d603f6aa..ef9df3f5 100644 --- a/models/openai/open_ai_bot.py +++ b/models/openai/open_ai_bot.py @@ -2,8 +2,14 @@ import time -import openai -from models.openai.openai_compat import RateLimitError, Timeout, APIConnectionError +from models.openai.openai_compat import ( + RateLimitError, + Timeout, + APIConnectionError, + APIError, + wrap_http_error, +) +from models.openai.openai_http_client import OpenAIHTTPClient, OpenAIHTTPError from models.bot import Bot from models.openai_compatible_bot import OpenAICompatibleBot @@ -22,12 +28,14 @@ user_session = dict() class OpenAIBot(Bot, OpenAIImage, OpenAICompatibleBot): def __init__(self): super().__init__() - openai.api_key = conf().get("open_ai_api_key") - if conf().get("open_ai_api_base"): - openai.api_base = conf().get("open_ai_api_base") - proxy = conf().get("proxy") - if proxy: - openai.proxy = proxy + self._api_key = conf().get("open_ai_api_key") + self._api_base = conf().get("open_ai_api_base") or None + self._proxy = conf().get("proxy") or None + self._http_client = OpenAIHTTPClient( + 
api_key=self._api_key, + api_base=self._api_base, + proxy=self._proxy, + ) self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003") self.args = { @@ -54,6 +62,10 @@ class OpenAIBot(Bot, OpenAIImage, OpenAICompatibleBot): 'default_presence_penalty': conf().get("presence_penalty", 0.0), } + def _get_http_client(self) -> OpenAIHTTPClient: + """Reuse the per-instance HTTP client for the streaming/tool path.""" + return self._http_client + def reply(self, query, context=None): # acquire reply content if context and context.type: @@ -96,8 +108,14 @@ class OpenAIBot(Bot, OpenAIImage, OpenAICompatibleBot): def reply_text(self, session: OpenAISession, retry_count=0): try: - response = openai.Completion.create(prompt=str(session), **self.args) - res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "") + call_args = dict(self.args) + timeout = call_args.pop("request_timeout", None) or call_args.pop("timeout", None) + response = self._http_client.completions( + timeout=timeout, + prompt=str(session), + **call_args, + ) + res_content = response["choices"][0]["text"].strip().replace("<|endoftext|>", "") total_tokens = response["usage"]["total_tokens"] completion_tokens = response["usage"]["completion_tokens"] logger.info("[OPEN_AI] reply={}".format(res_content)) @@ -106,125 +124,41 @@ class OpenAIBot(Bot, OpenAIImage, OpenAICompatibleBot): "completion_tokens": completion_tokens, "content": res_content, } + except OpenAIHTTPError as http_err: + return self._handle_legacy_error(wrap_http_error(http_err), session, retry_count) except Exception as e: - need_retry = retry_count < 2 - result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} - if isinstance(e, RateLimitError): - logger.warn("[OPEN_AI] RateLimitError: {}".format(e)) - result["content"] = "提问太快啦,请休息一下再问我吧" - if need_retry: - time.sleep(20) - elif isinstance(e, Timeout): - logger.warn("[OPEN_AI] Timeout: {}".format(e)) - result["content"] = "我没有收到你的消息" - if 
need_retry: - time.sleep(5) - elif isinstance(e, APIConnectionError): - logger.warn("[OPEN_AI] APIConnectionError: {}".format(e)) - need_retry = False - result["content"] = "我连接不到你的网络" - else: - logger.warn("[OPEN_AI] Exception: {}".format(e)) - need_retry = False - self.sessions.clear_session(session.session_id) + return self._handle_legacy_error(e, session, retry_count) + def _handle_legacy_error(self, e, session, retry_count): + """Map exception -> reply for the legacy /completions endpoint.""" + need_retry = retry_count < 2 + result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} + if isinstance(e, RateLimitError): + logger.warn("[OPEN_AI] RateLimitError: {}".format(e)) + result["content"] = "提问太快啦,请休息一下再问我吧" if need_retry: - logger.warn("[OPEN_AI] 第{}次重试".format(retry_count + 1)) - return self.reply_text(session, retry_count + 1) - else: - return result + time.sleep(20) + elif isinstance(e, Timeout): + logger.warn("[OPEN_AI] Timeout: {}".format(e)) + result["content"] = "我没有收到你的消息" + if need_retry: + time.sleep(5) + elif isinstance(e, APIConnectionError): + logger.warn("[OPEN_AI] APIConnectionError: {}".format(e)) + need_retry = False + result["content"] = "我连接不到你的网络" + else: + logger.warn("[OPEN_AI] Exception: {}".format(e)) + need_retry = False + self.sessions.clear_session(session.session_id) - def call_with_tools(self, messages, tools=None, stream=False, **kwargs): - """ - Call OpenAI API with tool support for agent integration - Note: This bot uses the old Completion API which doesn't support tools. - For tool support, use ChatGPTBot instead. - - This method converts to ChatCompletion API when tools are provided. 
- - Args: - messages: List of messages - tools: List of tool definitions (OpenAI format) - stream: Whether to use streaming - **kwargs: Additional parameters - - Returns: - Formatted response in OpenAI format or generator for streaming - """ - try: - # The old Completion API doesn't support tools - # We need to use ChatCompletion API instead - logger.info("[OPEN_AI] Using ChatCompletion API for tool support") - - # Build request parameters for ChatCompletion - request_params = { - "model": kwargs.get("model", conf().get("model") or "gpt-4.1"), - "messages": messages, - "temperature": kwargs.get("temperature", conf().get("temperature", 0.9)), - "top_p": kwargs.get("top_p", 1), - "frequency_penalty": kwargs.get("frequency_penalty", conf().get("frequency_penalty", 0.0)), - "presence_penalty": kwargs.get("presence_penalty", conf().get("presence_penalty", 0.0)), - "stream": stream - } - - # Add max_tokens if specified - if kwargs.get("max_tokens"): - request_params["max_tokens"] = kwargs["max_tokens"] - - # Add tools if provided - if tools: - request_params["tools"] = tools - request_params["tool_choice"] = kwargs.get("tool_choice", "auto") - - # Make API call using ChatCompletion - if stream: - return self._handle_stream_response(request_params) - else: - return self._handle_sync_response(request_params) - - except Exception as e: - logger.error(f"[OPEN_AI] call_with_tools error: {e}") - if stream: - def error_generator(): - yield { - "error": True, - "message": str(e), - "status_code": 500 - } - return error_generator() - else: - return { - "error": True, - "message": str(e), - "status_code": 500 - } - - def _handle_sync_response(self, request_params): - """Handle synchronous OpenAI ChatCompletion API response""" - try: - response = openai.ChatCompletion.create(**request_params) - - logger.info(f"[OPEN_AI] call_with_tools reply, model={response.get('model')}, " - f"total_tokens={response.get('usage', {}).get('total_tokens', 0)}") - - return response - - except 
Exception as e: - logger.error(f"[OPEN_AI] sync response error: {e}") - raise - - def _handle_stream_response(self, request_params): - """Handle streaming OpenAI ChatCompletion API response""" - try: - stream = openai.ChatCompletion.create(**request_params) - - for chunk in stream: - yield chunk - - except Exception as e: - logger.error(f"[OPEN_AI] stream response error: {e}") - yield { - "error": True, - "message": str(e), - "status_code": 500 - } + if need_retry: + logger.warn("[OPEN_AI] 第{}次重试".format(retry_count + 1)) + return self.reply_text(session, retry_count + 1) + return result + + # NOTE: Tool-call routing is delegated to OpenAICompatibleBot.call_with_tools, + # which calls /chat/completions via our shared HTTP client. The previous + # bespoke implementation here bypassed Claude->OpenAI message/tool conversion + # and was effectively broken for agent flows; we now inherit the correct + # implementation from the base class. diff --git a/models/openai/open_ai_image.py b/models/openai/open_ai_image.py index fb113a02..9683baa7 100644 --- a/models/openai/open_ai_image.py +++ b/models/openai/open_ai_image.py @@ -1,17 +1,25 @@ import time -import openai -from models.openai.openai_compat import RateLimitError - from common.log import logger from common.token_bucket import TokenBucket from config import conf +from models.openai.openai_compat import RateLimitError, wrap_http_error +from models.openai.openai_http_client import OpenAIHTTPClient, OpenAIHTTPError -# OPENAI提供的画图接口 +# OpenAI image generation API wrapper class OpenAIImage(object): def __init__(self): - openai.api_key = conf().get("open_ai_api_key") + # Lazy default client; subclasses (ChatGPTBot/OpenAIBot) typically + # construct their own _http_client and override _get_image_client(). 
+ self._image_api_key = conf().get("open_ai_api_key") + self._image_api_base = conf().get("open_ai_api_base") or None + self._image_proxy = conf().get("proxy") or None + self._image_client = OpenAIHTTPClient( + api_key=self._image_api_key, + api_base=self._image_api_base, + proxy=self._image_proxy, + ) if conf().get("rate_limit_dalle"): self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50)) @@ -20,24 +28,35 @@ class OpenAIImage(object): if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token(): return False, "请求太快了,请休息一下再问我吧" logger.info("[OPEN_AI] image_query={}".format(query)) - response = openai.Image.create( - api_key=api_key, - prompt=query, # 图片描述 - n=1, # 每次生成图片的数量 + response = self._image_client.images_generate( + api_key=api_key or None, + api_base=api_base or None, + prompt=query, # image description + n=1, model=conf().get("text_to_image") or "dall-e-2", - # size=conf().get("image_create_size", "256x256"), # 图片大小,可选有 256x256, 512x512, 1024x1024 + # size=conf().get("image_create_size", "256x256"), ) image_url = response["data"][0]["url"] logger.info("[OPEN_AI] image_url={}".format(image_url)) return True, image_url + except OpenAIHTTPError as http_err: + mapped = wrap_http_error(http_err) + if isinstance(mapped, RateLimitError): + logger.warn(mapped) + if retry_count < 1: + time.sleep(5) + logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1)) + return self.create_img(query, retry_count + 1) + return False, "画图出现问题,请休息一下再问我吧" + logger.exception(mapped) + return False, "画图出现问题,请休息一下再问我吧" except RateLimitError as e: logger.warn(e) if retry_count < 1: time.sleep(5) logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1)) return self.create_img(query, retry_count + 1) - else: - return False, "画图出现问题,请休息一下再问我吧" + return False, "画图出现问题,请休息一下再问我吧" except Exception as e: logger.exception(e) return False, "画图出现问题,请休息一下再问我吧" diff --git a/models/openai/openai_compat.py 
b/models/openai/openai_compat.py index 7668e6ac..099acb59 100644 --- a/models/openai/openai_compat.py +++ b/models/openai/openai_compat.py @@ -1,102 +1,163 @@ """ -OpenAI compatibility layer for different versions. +OpenAI-compatible exception layer. -This module provides a compatibility layer between OpenAI library versions: -- OpenAI < 1.0 (old API with openai.error module) -- OpenAI >= 1.0 (new API with direct exception imports) +This module used to bridge between openai SDK 0.x and 1.x exception types. +Since we no longer depend on the `openai` SDK at all (we call HTTP directly +via :mod:`models.openai.openai_http_client`), this file now provides: + +1. Pure Python exception classes that match the *names* the rest of the + codebase already imports (RateLimitError / Timeout / APIError / + APIConnectionError / AuthenticationError / InvalidRequestError ...). +2. A :func:`map_http_error` helper that converts an + :class:`OpenAIHTTPError` (or any HTTP status code + message) into the + appropriate exception subclass, so existing ``except RateLimitError`` + ``except Timeout`` etc. blocks keep working unchanged. + +This keeps the behavior of all existing bots (rate-limit backoff, timeout +retry, auth-error fast-fail) identical to the openai-SDK-based version, while +removing the hard dependency on the `openai` package. 
""" -try: - # Try new OpenAI >= 1.0 API - from openai import ( - OpenAIError, - RateLimitError, - APIError, - APIConnectionError, - AuthenticationError, - APITimeoutError, - BadRequestError, - ) - - # Create a mock error module for backward compatibility - class ErrorModule: - OpenAIError = OpenAIError - RateLimitError = RateLimitError - APIError = APIError - APIConnectionError = APIConnectionError - AuthenticationError = AuthenticationError - Timeout = APITimeoutError # Renamed in new version - InvalidRequestError = BadRequestError # Renamed in new version - - error = ErrorModule() - - # Also export with new names - Timeout = APITimeoutError - InvalidRequestError = BadRequestError - -except ImportError: - # Fall back to old OpenAI < 1.0 API - try: - import openai.error as error - - # Export individual exceptions for direct import - OpenAIError = error.OpenAIError - RateLimitError = error.RateLimitError - APIError = error.APIError - APIConnectionError = error.APIConnectionError - AuthenticationError = error.AuthenticationError - InvalidRequestError = error.InvalidRequestError - Timeout = error.Timeout - BadRequestError = error.InvalidRequestError # Alias - APITimeoutError = error.Timeout # Alias - except (ImportError, AttributeError): - # Neither version works, create dummy classes - class OpenAIError(Exception): - pass - - class RateLimitError(OpenAIError): - pass - - class APIError(OpenAIError): - pass - - class APIConnectionError(OpenAIError): - pass - - class AuthenticationError(OpenAIError): - pass - - class InvalidRequestError(OpenAIError): - pass - - class Timeout(OpenAIError): - pass - - BadRequestError = InvalidRequestError - APITimeoutError = Timeout - - # Create error module - class ErrorModule: - OpenAIError = OpenAIError - RateLimitError = RateLimitError - APIError = APIError - APIConnectionError = APIConnectionError - AuthenticationError = AuthenticationError - InvalidRequestError = InvalidRequestError - Timeout = Timeout - - error = ErrorModule() 
+from typing import Optional + + +# --------------------------------------------------------------------------- # +# Exception hierarchy (mirrors openai SDK names so call sites don't change) +# --------------------------------------------------------------------------- # + +class OpenAIError(Exception): + """Base exception for all OpenAI-compatible API errors.""" + + def __init__(self, message: str = "", status_code: Optional[int] = None, + body=None): + super().__init__(message) + self.message = message + self.status_code = status_code + self.body = body + + +class APIError(OpenAIError): + """Generic API error (5xx and unclassified errors).""" + + +class APIConnectionError(OpenAIError): + """Network / connection failure (DNS, refused, reset...).""" + + +class Timeout(OpenAIError): + """Request timeout. Aliased as APITimeoutError for new-SDK style imports.""" + + +class AuthenticationError(OpenAIError): + """401 Unauthorized.""" + + +class PermissionDeniedError(OpenAIError): + """403 Forbidden.""" + + +class NotFoundError(OpenAIError): + """404 Not Found.""" + + +class InvalidRequestError(OpenAIError): + """400 Bad Request. Aliased as BadRequestError.""" + + +class RateLimitError(OpenAIError): + """429 Too Many Requests.""" + + +# Aliases used by some new-SDK-style code paths in the project. +APITimeoutError = Timeout +BadRequestError = InvalidRequestError + + +# --------------------------------------------------------------------------- # +# Backward-compat ``error`` module-style accessor +# --------------------------------------------------------------------------- # +# Some legacy code in the codebase (and possibly user plugins) does +# from models.openai.openai_compat import error +# except error.RateLimitError: ... +# Keep that path working by exposing an attribute namespace. 
+class _ErrorModule: + OpenAIError = OpenAIError + APIError = APIError + APIConnectionError = APIConnectionError + Timeout = Timeout + AuthenticationError = AuthenticationError + PermissionDeniedError = PermissionDeniedError + NotFoundError = NotFoundError + InvalidRequestError = InvalidRequestError + RateLimitError = RateLimitError + + +error = _ErrorModule() + + +# --------------------------------------------------------------------------- # +# HTTP -> exception mapping +# --------------------------------------------------------------------------- # + +def map_http_error(status_code: Optional[int], message: str = "", + body=None) -> OpenAIError: + """Convert an HTTP status (+ optional message/body) to the right subclass. + + Used by HTTP-based bot wrappers so that downstream ``except RateLimitError`` + blocks behave identically to when the openai SDK was raising them. + """ + sc = status_code or 0 + msg = message or "" + msg_lower = msg.lower() + + # Connection-level (no status / non-HTTP failure) + if sc == 0: + if "timeout" in msg_lower or "timed out" in msg_lower: + return Timeout(msg, sc, body) + return APIConnectionError(msg, sc, body) + + if sc == 408: + return Timeout(msg, sc, body) + if sc == 401: + return AuthenticationError(msg, sc, body) + if sc == 403: + return PermissionDeniedError(msg, sc, body) + if sc == 404: + return NotFoundError(msg, sc, body) + if sc == 429: + return RateLimitError(msg, sc, body) + if 400 <= sc < 500: + return InvalidRequestError(msg, sc, body) + if sc >= 500: + return APIError(msg, sc, body) + + return APIError(msg, sc, body) + + +def wrap_http_error(http_err) -> OpenAIError: + """Adapter for :class:`OpenAIHTTPError` -> compat exception subclass. + + Accepts any object with ``status_code`` / ``message`` / ``body`` attrs. 
+ """ + sc = getattr(http_err, "status_code", None) + msg = getattr(http_err, "message", "") or str(http_err) + body = getattr(http_err, "body", None) + return map_http_error(sc, msg, body) + -# Export all for easy import __all__ = [ - 'error', - 'OpenAIError', - 'RateLimitError', - 'APIError', - 'APIConnectionError', - 'AuthenticationError', - 'InvalidRequestError', - 'Timeout', - 'BadRequestError', - 'APITimeoutError', + "error", + "OpenAIError", + "APIError", + "APIConnectionError", + "Timeout", + "APITimeoutError", + "AuthenticationError", + "PermissionDeniedError", + "NotFoundError", + "InvalidRequestError", + "BadRequestError", + "RateLimitError", + "map_http_error", + "wrap_http_error", ] diff --git a/models/openai/openai_http_client.py b/models/openai/openai_http_client.py new file mode 100644 index 00000000..ca9e91a6 --- /dev/null +++ b/models/openai/openai_http_client.py @@ -0,0 +1,456 @@ +# encoding:utf-8 + +""" +Lightweight HTTP client for OpenAI-compatible APIs. + +This client is a drop-in replacement for the parts of the `openai` SDK that this +project actually uses (chat completions, completions, image generation), so we +can drop the hard dependency on `openai==0.27.x`. + +Design goals: +- Pure `requests` based (no httpx / pydantic / openai SDK dependency). +- Returns plain `dict` responses with the same shape OpenAI's HTTP API returns, + so existing code that does `response["choices"][0]["message"]["content"]` / + `response["usage"]["total_tokens"]` keeps working. +- Streaming yields plain `dict` chunks (parsed SSE `data:` JSON), matching the + shape that `agent/protocol/agent_stream.py` consumes: + chunk["choices"][0]["delta"]["content" | "tool_calls" | "reasoning_content"] + chunk["choices"][0]["finish_reason"] + Plus dict-style error chunks: {"error": True, "message": ..., "status_code": ...} +- Compatible with arbitrary OpenAI-compatible endpoints (LinkAI, Azure-style + proxies, DeepSeek, Moonshot, etc.) 
DEFAULT_API_BASE = "https://api.openai.com/v1"
DEFAULT_TIMEOUT = 600  # seconds; matches old openai SDK default


class OpenAIHTTPError(Exception):
    """Raised when a non-streaming request fails.

    ``OpenAIHTTPClient`` raises this for responses with status >= 400 and for
    transport failures (status 408 for client-side timeouts, 0 for other
    ``requests`` errors).

    Attributes:
        status_code: HTTP status code, or the synthetic 408 / 0 described above.
        body: Parsed response body (whatever the caller passed in).
        message: Human-readable message; extracted from an OpenAI-style
            ``{"error": {"message": ...}}`` envelope when none is given.
    """

    def __init__(self, status_code: int, body: Any, message: str = ""):
        self.status_code = status_code
        self.body = body
        # Try to extract a human-readable message from the OpenAI-style
        # error envelope; ``error`` may be a dict or a bare string.
        if not message and isinstance(body, dict):
            err = body.get("error") or {}
            if isinstance(err, dict):
                message = err.get("message") or ""
            elif isinstance(err, str):
                message = err
        if not message:
            # Last resort: a bounded dump of the body.
            message = str(body)[:500]
        self.message = message
        super().__init__(f"HTTP {status_code}: {message}")


class OpenAIHTTPClient:
    """Minimal HTTP client for OpenAI-compatible endpoints.

    Per-instance defaults (api_key / api_base / proxy / timeout) can be
    overridden on every call. Callers can also pass ``extra_headers`` for
    Azure-style ``api-key`` headers or custom routing headers.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        proxy: Optional[str] = None,
        timeout: Optional[float] = None,
        extra_headers: Optional[Dict[str, str]] = None,
    ):
        self.api_key = api_key
        # Stored without a trailing slash so paths can be appended directly.
        self.api_base = (api_base or DEFAULT_API_BASE).rstrip("/")
        self.timeout = timeout if timeout is not None else DEFAULT_TIMEOUT
        self.proxies = (
            {"http": proxy, "https": proxy} if proxy else None
        )
        # Copy so later mutation of the caller's dict does not leak in.
        self.extra_headers = dict(extra_headers) if extra_headers else {}

    # ------------------------------------------------------------------ #
    # Public API surface (mirrors what the old openai SDK provided)
    # ------------------------------------------------------------------ #

    def chat_completions(
        self,
        *,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        timeout: Optional[float] = None,
        proxy: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        extra_query: Optional[Dict[str, str]] = None,
        path: str = "/chat/completions",
        stream: bool = False,
        **payload,
    ):
        """POST to ``path`` (default ``/chat/completions``).

        Remaining keyword arguments form the JSON payload; ``None`` values
        are dropped before sending.

        Returns:
            ``stream=False``: the parsed JSON response.
            ``stream=True``: a generator yielding parsed SSE chunks (plain
            ``dict``). On failure the generator yields one error chunk of
            the form ``{"error": {"message", "code", "type"}, "message": str,
            "status_code": int}`` and stops instead of raising.

        Raises:
            OpenAIHTTPError: non-streaming calls only, on status >= 400 or a
                transport failure.
        """
        payload["stream"] = stream
        return self._request(
            path=path,
            payload=payload,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            proxy=proxy,
            extra_headers=extra_headers,
            extra_query=extra_query,
            stream=stream,
        )

    def completions(
        self,
        *,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        timeout: Optional[float] = None,
        proxy: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        extra_query: Optional[Dict[str, str]] = None,
        **payload,
    ) -> Dict[str, Any]:
        """POST /completions (legacy text completion). Non-streaming only.

        Any ``stream`` key in the payload is removed. Raises
        ``OpenAIHTTPError`` on status >= 400 or a transport failure.
        """
        payload.pop("stream", None)
        return self._request(
            path="/completions",
            payload=payload,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            proxy=proxy,
            extra_headers=extra_headers,
            extra_query=extra_query,
            stream=False,
        )

    def images_generate(
        self,
        *,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        timeout: Optional[float] = None,
        proxy: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        extra_query: Optional[Dict[str, str]] = None,
        **payload,
    ) -> Dict[str, Any]:
        """POST /images/generations.

        Raises ``OpenAIHTTPError`` on status >= 400 or a transport failure.
        """
        return self._request(
            path="/images/generations",
            payload=payload,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            proxy=proxy,
            extra_headers=extra_headers,
            extra_query=extra_query,
            stream=False,
        )

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    def _build_headers(
        self,
        api_key: Optional[str],
        extra_headers: Optional[Dict[str, str]],
    ) -> Dict[str, str]:
        """Merge default, instance-level and per-call headers.

        Precedence (lowest to highest): built-in defaults, instance
        ``extra_headers``, per-call ``extra_headers``. A per-call ``api_key``
        of ``None`` falls back to the instance key; an empty key omits the
        ``Authorization`` header entirely.
        """
        key = api_key if api_key is not None else self.api_key
        headers = {"Content-Type": "application/json"}
        if key:
            headers["Authorization"] = f"Bearer {key}"
        if self.extra_headers:
            headers.update(self.extra_headers)
        if extra_headers:
            headers.update(extra_headers)
        return headers

    def _request(
        self,
        *,
        path: str,
        payload: Dict[str, Any],
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Optional[float],
        stream: bool,
        proxy: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        extra_query: Optional[Dict[str, str]] = None,
    ):
        """Resolve per-call overrides and dispatch the POST.

        Returns the parsed JSON body, or (``stream=True``) the generator
        produced by ``_stream_chat``.

        Raises:
            OpenAIHTTPError: non-streaming only; 408 on timeout, 0 on other
                transport errors, the HTTP status on responses >= 400.
        """
        base = api_base.rstrip("/") if api_base else self.api_base
        url = f"{base}{path}" if path.startswith("/") else f"{base}/{path}"
        headers = self._build_headers(api_key, extra_headers)
        req_timeout = timeout if timeout is not None else self.timeout
        proxies = (
            {"http": proxy, "https": proxy} if proxy else self.proxies
        )

        # Drop None-valued keys; some providers reject explicit nulls.
        clean_payload = {k: v for k, v in payload.items() if v is not None}

        if stream:
            # Return a generator. Errors during stream are yielded as a single
            # error chunk so callers (agent_stream) can map them to their
            # existing error-handling path without try/except around the loop.
            return self._stream_chat(
                url=url,
                headers=headers,
                payload=clean_payload,
                proxies=proxies,
                timeout=req_timeout,
                params=extra_query,
            )

        try:
            resp = requests.post(
                url,
                headers=headers,
                json=clean_payload,
                timeout=req_timeout,
                proxies=proxies,
                params=extra_query,
            )
        except requests.exceptions.Timeout as e:
            raise OpenAIHTTPError(408, {}, f"Request timed out: {e}") from e
        except requests.exceptions.ConnectionError as e:
            raise OpenAIHTTPError(0, {}, f"Connection error: {e}") from e
        except requests.exceptions.RequestException as e:
            raise OpenAIHTTPError(0, {}, f"Request failed: {e}") from e

        return self._parse_response(resp)

    @staticmethod
    def _parse_response(resp: "requests.Response") -> Dict[str, Any]:
        """Return the JSON body, or ``{"raw": text}`` when it is not JSON.

        Raises ``OpenAIHTTPError`` carrying the parsed body when the status
        code is >= 400.
        """
        try:
            data = resp.json()
        except ValueError:
            data = {"raw": resp.text}

        if resp.status_code >= 400:
            raise OpenAIHTTPError(resp.status_code, data)

        return data

    def _stream_chat(
        self,
        *,
        url: str,
        headers: Dict[str, str],
        payload: Dict[str, Any],
        proxies: Optional[Dict[str, str]],
        timeout: float,
        params: Optional[Dict[str, str]] = None,
    ) -> Generator[Dict[str, Any], None, None]:
        """Stream SSE response and yield parsed JSON chunks.

        Yields:
            - Normal chunks: the decoded JSON object of each ``data:`` event.
            - On failure, one error chunk
              ``{"error": {"message", "code", "type"}, "message": str,
              "status_code": int}`` followed by termination of the generator.
              ``status_code`` is 408 for a connect/read timeout on the initial
              request and 0 for other transport errors.

        Events that are empty or not valid JSON are skipped; the ``[DONE]``
        sentinel ends the stream. This generator never raises for transport
        or HTTP errors.
        """
        try:
            resp = requests.post(
                url,
                headers=headers,
                json=payload,
                timeout=timeout,
                proxies=proxies,
                stream=True,
                params=params,
            )
        except requests.exceptions.Timeout as e:
            yield self._make_error_chunk(408, f"Request timed out: {e}")
            return
        except requests.exceptions.ConnectionError as e:
            yield self._make_error_chunk(0, f"Connection error: {e}")
            return
        except requests.exceptions.RequestException as e:
            yield self._make_error_chunk(0, f"Request failed: {e}")
            return

        # Everything that touches the open response lives inside one
        # try/finally so the connection is released on every exit path,
        # including the HTTP-error branch and early consumer exit.
        try:
            if resp.status_code >= 400:
                yield self._error_chunk_from_response(resp)
                return

            # Do NOT use `iter_lines(decode_unicode=True)`: requests decodes
            # per network chunk, which mangles UTF-8 codepoints that straddle
            # a chunk boundary. `_iter_sse_events` buffers raw bytes until a
            # complete event is available and only then decodes.
            for sse_event in self._iter_sse_events(resp):
                # `sse_event` is the joined `data:` payload as a str.
                if sse_event == "[DONE]":
                    return
                if not sse_event:
                    continue
                try:
                    chunk = json.loads(sse_event)
                except ValueError:
                    logger.debug(
                        f"[OpenAIHTTP] skip malformed SSE chunk: {sse_event[:200]}"
                    )
                    continue
                yield chunk
        except requests.exceptions.ChunkedEncodingError as e:
            yield self._make_error_chunk(0, f"Stream interrupted: {e}")
        except requests.exceptions.RequestException as e:
            yield self._make_error_chunk(0, f"Stream error: {e}")
        finally:
            # Best-effort cleanup; a failing close must not mask the result.
            try:
                resp.close()
            except Exception:
                pass

    @staticmethod
    def _error_chunk_from_response(resp: "requests.Response") -> Dict[str, Any]:
        """Build the streaming error chunk for a response with status >= 400."""
        try:
            body = resp.json()
        except ValueError:
            body = {"raw": resp.text[:1000]}
        err_msg = ""
        err_code = ""
        err_type = ""
        if isinstance(body, dict):
            err = body.get("error") or {}
            if isinstance(err, dict):
                err_msg = err.get("message") or ""
                err_code = err.get("code") or ""
                err_type = err.get("type") or ""
            elif isinstance(err, str):
                err_msg = err
        if not err_msg:
            err_msg = str(body)[:500]
        return {
            "error": {
                "message": err_msg,
                "code": err_code,
                "type": err_type,
            },
            # Top-level fields kept for backward compatibility with the
            # error-shape that `_handle_stream_response` previously emitted.
            "message": err_msg,
            "status_code": resp.status_code,
        }

    @staticmethod
    def _extract_sse_data(event_bytes: bytes) -> Optional[str]:
        """Return the joined ``data:`` payload of one SSE event, or ``None``.

        ``None`` means the event carried no ``data`` field at all (only
        comments or ``event:`` / ``id:`` / ``retry:`` lines).
        """
        # ``errors="replace"`` is a safety net for truly malformed upstream
        # bytes; well-formed UTF-8 decodes identically to strict mode.
        text = event_bytes.decode("utf-8", errors="replace")
        data_lines = []
        # SSE recognises only CRLF, LF and CR as line terminators. Do not use
        # ``str.splitlines()``: it also splits on U+2028/U+2029/U+0085 etc.,
        # which may legally appear unescaped inside a JSON string and would
        # cut a ``data:`` line in half.
        for line in text.replace("\r\n", "\n").replace("\r", "\n").split("\n"):
            if not line or line.startswith(":"):
                continue
            field, _, value = line.partition(":")
            # Per SSE spec, a single optional space after the colon is part
            # of the framing, not the value.
            if value.startswith(" "):
                value = value[1:]
            if field == "data":
                data_lines.append(value)
            # Other fields (event/id/retry) are intentionally ignored.
        if not data_lines:
            return None
        return "\n".join(data_lines)

    @staticmethod
    def _iter_sse_events(resp: "requests.Response") -> Generator[str, None, None]:
        """Decode an SSE byte stream into joined `data:` payloads.

        Implements the subset of the SSE spec that OpenAI / OpenAI-compatible
        endpoints actually use:
          - Events are separated by blank lines (\\r\\r, \\n\\n, or \\r\\n\\r\\n).
          - Within an event, multiple ``data:`` lines are concatenated with
            "\\n" (per spec).
          - ``event:``, ``id:``, ``retry:`` and comment lines (``:``) are
            tolerated but not yielded.
          - Bytes are buffered until a complete event boundary is seen so
            UTF-8 codepoints split across network chunks decode correctly.

        Yields each event's joined ``data`` string. The terminal sentinel
        ``[DONE]`` is yielded as a literal string so the caller can break.
        """
        buf = b""
        for raw in resp.iter_content(chunk_size=None, decode_unicode=False):
            if not raw:
                continue
            buf += raw
            # Drain every complete event currently in the buffer.
            while True:
                # SSE allows three terminator forms; take the earliest match.
                idx_nn = buf.find(b"\n\n")
                idx_rr = buf.find(b"\r\r")
                idx_rnrn = buf.find(b"\r\n\r\n")
                candidates = [i for i in (idx_nn, idx_rr, idx_rnrn) if i != -1]
                if not candidates:
                    break
                end_pos = min(candidates)
                # The terminator length decides how far to advance.
                term_len = 4 if end_pos == idx_rnrn else 2
                event_bytes = buf[:end_pos]
                buf = buf[end_pos + term_len:]

                data = OpenAIHTTPClient._extract_sse_data(event_bytes)
                if data is not None:
                    yield data

        # Flush any trailing bytes the server forgot to terminate (some
        # providers omit the final blank line).
        if buf.strip():
            data = OpenAIHTTPClient._extract_sse_data(buf)
            if data is not None:
                yield data

    @staticmethod
    def _make_error_chunk(status_code: int, message: str) -> Dict[str, Any]:
        """Build a streaming error chunk for a transport-level failure."""
        return {
            "error": {"message": message, "code": "", "type": ""},
            "message": message,
            "status_code": status_code,
        }


# A tiny helper for callers that just need a one-shot client without storing
# state. Keeps call sites cleaner than instantiating the class every time.
def get_default_client(
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    proxy: Optional[str] = None,
    timeout: Optional[float] = None,
) -> OpenAIHTTPClient:
    """Return a new ``OpenAIHTTPClient`` built from the given defaults."""
    return OpenAIHTTPClient(
        api_key=api_key, api_base=api_base, proxy=proxy, timeout=timeout
    )
def _get_http_client(self) -> "OpenAIHTTPClient":
    """Build an HTTP client honoring the global proxy config.

    Subclasses can override this for custom auth headers (e.g. Azure's
    ``api-key`` header) by returning a pre-configured client.
    """
    # Imported lazily, as in the original code.
    from config import conf
    proxy = conf().get("proxy") or None
    return OpenAIHTTPClient(proxy=proxy)

def _handle_sync_response(self, request_params, api_key, api_base):
    """Handle synchronous chat-completion via HTTP.

    Args:
        request_params: Mapping of chat-completion parameters. It is copied,
            not mutated. ``stream`` is discarded; ``request_timeout`` /
            ``timeout`` are translated into the client's ``timeout`` argument.
        api_key: Per-call API key forwarded to the client.
        api_base: Per-call API base URL forwarded to the client.

    Returns:
        Whatever the client's ``chat_completions`` returns, or on failure a
        dict ``{"error": True, "message": str, "status_code": int}``.
    """
    params = dict(request_params)
    params.pop("stream", None)
    # Translate legacy SDK timeout kwarg to our HTTP client kwarg. Both keys
    # must be removed unconditionally: a short-circuiting ``or`` would leave
    # "timeout" in ``params`` whenever "request_timeout" is set, and the
    # ``**params`` expansion below would then collide with ``timeout=``.
    request_timeout = params.pop("request_timeout", None)
    plain_timeout = params.pop("timeout", None)
    timeout = request_timeout or plain_timeout
    try:
        client = self._get_http_client()
        return client.chat_completions(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            stream=False,
            **params,
        )
    except OpenAIHTTPError as e:
        logger.error(
            f"[{self.__class__.__name__}] sync response error: "
            f"HTTP {e.status_code}: {e.message}"
        )
        return {
            "error": True,
            "message": e.message,
            # Transport failures carry status 0; report those as 500.
            "status_code": e.status_code or 500,
        }
    except Exception as e:
        logger.error(f"[{self.__class__.__name__}] sync response error: {e}")
        return {
            "error": True,
            "message": str(e),
            "status_code": 500,
        }

def _handle_stream_response(self, request_params, api_key, api_base):
    """Handle streaming chat-completion via HTTP (SSE).

    Yields the chunks produced by the client's streaming
    ``chat_completions`` call unchanged. If the call or the iteration
    raises, yields a single
    ``{"error": True, "message": str, "status_code": int}`` chunk instead.

    ``request_params`` is copied, not mutated; ``stream`` is discarded and
    ``request_timeout`` / ``timeout`` become the client's ``timeout``.
    """
    params = dict(request_params)
    params.pop("stream", None)
    # Pop both timeout spellings so neither leaks into ``**params`` and
    # collides with the explicit ``timeout=`` keyword below.
    request_timeout = params.pop("request_timeout", None)
    plain_timeout = params.pop("timeout", None)
    timeout = request_timeout or plain_timeout
    try:
        client = self._get_http_client()
        stream = client.chat_completions(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            stream=True,
            **params,
        )
        for chunk in stream:
            yield chunk
    except OpenAIHTTPError as e:
        logger.error(
            f"[{self.__class__.__name__}] stream response error: "
            f"HTTP {e.status_code}: {e.message}"
        )
        yield {
            "error": True,
            "message": e.message,
            # Transport failures carry status 0; report those as 500.
            "status_code": e.status_code or 500,
        }
    except Exception as e:
        logger.error(f"[{self.__class__.__name__}] stream response error: {e}")
        yield {
            "error": True,
            "message": str(e),
            "status_code": 500,
        }
+ pass def voiceToText(self, voice_file): logger.debug("[Openai] voice file name={}".format(voice_file))