Files
chatgpt-on-wechat/models/openai/openai_http_client.py

494 lines
18 KiB
Python

# encoding:utf-8
"""
Lightweight HTTP client for OpenAI-compatible APIs.
This client is a drop-in replacement for the parts of the `openai` SDK that this
project actually uses (chat completions, completions, image generation), so we
can drop the hard dependency on `openai==0.27.x`.
Design goals:
- Pure `requests` based (no httpx / pydantic / openai SDK dependency).
- Returns plain `dict` responses with the same shape OpenAI's HTTP API returns,
so existing code that does `response["choices"][0]["message"]["content"]` /
`response["usage"]["total_tokens"]` keeps working.
- Streaming yields plain `dict` chunks (parsed SSE `data:` JSON), matching the
shape that `agent/protocol/agent_stream.py` consumes:
chunk["choices"][0]["delta"]["content" | "tool_calls" | "reasoning_content"]
chunk["choices"][0]["finish_reason"]
Plus dict-style error chunks:
{"error": {"message": ..., "code": ..., "type": ...}, "message": ..., "status_code": ...}
- Compatible with arbitrary OpenAI-compatible endpoints (LinkAI, Azure-style
proxies, DeepSeek, Moonshot, etc.) by allowing per-call api_key / api_base
override and trusting whatever path/payload shape the caller passes.
"""
import json
from typing import Any, Dict, Generator, Optional
from urllib.parse import urlparse
import requests
from common.log import logger
# Base URL used when a client is created without an explicit ``api_base``.
DEFAULT_API_BASE = "https://api.openai.com/v1"
# Request timeout in seconds; matches old openai SDK default.
DEFAULT_TIMEOUT = 600

_APP_TITLE = "CowAgent"
_APP_REFERER = "https://github.com/zhayujie/CowAgent"

# App-attribution headers, keyed by gateway host. They are attached only when
# the request host matches one of these documented gateways: sending them to
# user-configured custom proxies would leak app identity, so the lookup is
# dispatched by host suffix. Each gateway gets its own dict instance.
_ATTRIBUTION_HEADERS_BY_HOST: Dict[str, Dict[str, str]] = {
    gateway_host: {"HTTP-Referer": _APP_REFERER, "X-Title": _APP_TITLE}
    for gateway_host in ("openrouter.ai", "ai-gateway.vercel.sh")
}
def _resolve_attribution_headers(url: str) -> Dict[str, str]:
    """Return a copy of the attribution headers registered for *url*'s host.

    A host matches a registered gateway when it equals the gateway name or is
    a subdomain of it (comparison is done on the lower-cased host). A URL that
    cannot be parsed, has no host, or points at an unregistered host yields an
    empty dict.
    """
    try:
        normalized_host = (urlparse(url).hostname or "").lower()
    except Exception:
        return {}

    if normalized_host:
        for gateway, gateway_headers in _ATTRIBUTION_HEADERS_BY_HOST.items():
            is_gateway = normalized_host == gateway
            if is_gateway or normalized_host.endswith("." + gateway):
                # Hand back a copy so callers can mutate the result freely.
                return dict(gateway_headers)
    return {}
class OpenAIHTTPError(Exception):
    """Error raised for failed requests; carries the status code and body.

    Attributes set on the instance:
    - ``status_code``: the status code passed by the raiser.
    - ``body``: the body passed by the raiser (parsed JSON or a fallback).
    - ``message``: the explicit ``message`` argument if given, otherwise the
      text pulled from an OpenAI-style ``{"error": ...}`` envelope, otherwise
      the first 500 characters of ``str(body)``.
    """

    def __init__(self, status_code: int, body: Any, message: str = ""):
        self.status_code = status_code
        self.body = body

        detail = message
        if not detail and isinstance(body, dict):
            # OpenAI-style envelope: "error" is either an object with a
            # "message" key or a bare string.
            envelope = body.get("error") or {}
            if isinstance(envelope, str):
                detail = envelope
            elif isinstance(envelope, dict):
                detail = envelope.get("message") or ""

        # Fall back to a truncated dump of the body when nothing readable
        # could be extracted.
        self.message = detail or str(body)[:500]
        super().__init__(f"HTTP {status_code}: {self.message}")
class OpenAIHTTPClient:
    """Minimal HTTP client for OpenAI-compatible endpoints.

    Per-instance defaults (api_key / api_base / proxy / timeout) can be
    overridden on every call. Callers can also pass ``extra_headers`` for
    Azure-style ``api-key`` headers or custom routing headers.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        proxy: Optional[str] = None,
        timeout: Optional[float] = None,
        extra_headers: Optional[Dict[str, str]] = None,
    ):
        """Store the per-instance defaults.

        Args:
            api_key: Bearer token; when falsy no ``Authorization`` header is
                sent unless a per-call key is given.
            api_base: Base URL; falls back to ``DEFAULT_API_BASE``. Trailing
                slashes are stripped.
            proxy: Proxy URL applied to both ``http`` and ``https`` traffic.
            timeout: Request timeout in seconds; ``None`` means
                ``DEFAULT_TIMEOUT``.
            extra_headers: Headers added to every request made by this client.
        """
        self.api_key = api_key
        self.api_base = (api_base or DEFAULT_API_BASE).rstrip("/")
        self.timeout = timeout if timeout is not None else DEFAULT_TIMEOUT
        self.proxies = {"http": proxy, "https": proxy} if proxy else None
        # Copy so later mutation of the caller's dict does not leak in here.
        self.extra_headers = dict(extra_headers) if extra_headers else {}

    # ------------------------------------------------------------------ #
    # Public API surface (mirrors what the old openai SDK provided)
    # ------------------------------------------------------------------ #
    def chat_completions(
        self,
        *,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        timeout: Optional[float] = None,
        proxy: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        extra_query: Optional[Dict[str, str]] = None,
        path: str = "/chat/completions",
        stream: bool = False,
        **payload,
    ):
        """POST /chat/completions.

        All remaining keyword arguments form the JSON payload.

        Returns:
            With ``stream=False``, the parsed JSON response (``dict``).
            With ``stream=True``, a generator yielding parsed SSE chunks
            (plain ``dict``). On error during streaming the generator yields a
            single chunk of the form ``{"error": {"message", "code", "type"},
            "message": ..., "status_code": ...}`` and stops.

        Raises:
            OpenAIHTTPError: Non-streaming only, on transport failure or a
                response with status >= 400.
        """
        payload["stream"] = stream
        return self._request(
            path=path,
            payload=payload,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            proxy=proxy,
            extra_headers=extra_headers,
            extra_query=extra_query,
            stream=stream,
        )

    def completions(
        self,
        *,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        timeout: Optional[float] = None,
        proxy: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        extra_query: Optional[Dict[str, str]] = None,
        **payload,
    ) -> Dict[str, Any]:
        """POST /completions (legacy text completion). Non-streaming only.

        ``proxy`` / ``extra_headers`` / ``extra_query`` are per-call overrides
        with the same meaning as in ``chat_completions``; they are transport
        options and are not sent as part of the JSON payload.

        Raises:
            OpenAIHTTPError: On transport failure or status >= 400.
        """
        # Streaming is not supported here, so never forward a "stream" flag.
        payload.pop("stream", None)
        return self._request(
            path="/completions",
            payload=payload,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            proxy=proxy,
            extra_headers=extra_headers,
            extra_query=extra_query,
            stream=False,
        )

    def images_generate(
        self,
        *,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        timeout: Optional[float] = None,
        proxy: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        extra_query: Optional[Dict[str, str]] = None,
        **payload,
    ) -> Dict[str, Any]:
        """POST /images/generations.

        ``proxy`` / ``extra_headers`` / ``extra_query`` are per-call overrides
        with the same meaning as in ``chat_completions``; they are transport
        options and are not sent as part of the JSON payload.

        Raises:
            OpenAIHTTPError: On transport failure or status >= 400.
        """
        return self._request(
            path="/images/generations",
            payload=payload,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            proxy=proxy,
            extra_headers=extra_headers,
            extra_query=extra_query,
            stream=False,
        )

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #
    def _build_headers(
        self,
        api_key: Optional[str],
        extra_headers: Optional[Dict[str, str]],
        url: Optional[str] = None,
    ) -> Dict[str, str]:
        """Assemble the request headers.

        Later sources override earlier ones: JSON content type and bearer
        auth, then gateway attribution headers resolved from ``url``, then the
        instance-level ``extra_headers``, then the per-call ``extra_headers``.
        A per-call ``api_key`` of ``None`` falls back to the instance key.
        """
        key = api_key if api_key is not None else self.api_key
        headers = {"Content-Type": "application/json"}
        if key:
            headers["Authorization"] = f"Bearer {key}"
        if url:
            headers.update(_resolve_attribution_headers(url))
        if self.extra_headers:
            headers.update(self.extra_headers)
        if extra_headers:
            headers.update(extra_headers)
        return headers

    def _request(
        self,
        *,
        path: str,
        payload: Dict[str, Any],
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Optional[float],
        stream: bool,
        proxy: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        extra_query: Optional[Dict[str, str]] = None,
    ):
        """Resolve per-call overrides and perform the POST.

        Returns a chunk generator when ``stream`` is true, otherwise the
        parsed JSON response.

        Raises:
            OpenAIHTTPError: Non-streaming only. Status 408 for timeouts,
                0 for other transport failures, or the HTTP status for
                responses with status >= 400.
        """
        base = api_base.rstrip("/") if api_base else self.api_base
        url = f"{base}{path}" if path.startswith("/") else f"{base}/{path}"
        headers = self._build_headers(api_key, extra_headers, url=url)
        req_timeout = timeout if timeout is not None else self.timeout
        proxies = {"http": proxy, "https": proxy} if proxy else self.proxies
        # Drop None-valued keys; some providers reject explicit nulls.
        clean_payload = {k: v for k, v in payload.items() if v is not None}

        if stream:
            # Return a generator. Errors during stream are yielded as a single
            # error chunk so callers (agent_stream) can map them to their
            # existing error-handling path without try/except around the loop.
            return self._stream_chat(
                url=url,
                headers=headers,
                payload=clean_payload,
                proxies=proxies,
                timeout=req_timeout,
                params=extra_query,
            )

        try:
            resp = requests.post(
                url,
                headers=headers,
                json=clean_payload,
                timeout=req_timeout,
                proxies=proxies,
                params=extra_query,
            )
        except requests.exceptions.Timeout as e:
            raise OpenAIHTTPError(408, {}, f"Request timed out: {e}") from e
        except requests.exceptions.ConnectionError as e:
            raise OpenAIHTTPError(0, {}, f"Connection error: {e}") from e
        except requests.exceptions.RequestException as e:
            raise OpenAIHTTPError(0, {}, f"Request failed: {e}") from e
        return self._parse_response(resp)

    @staticmethod
    def _parse_response(resp: "requests.Response") -> Dict[str, Any]:
        """Return the JSON body of ``resp``, or ``{"raw": text}`` if not JSON.

        Raises:
            OpenAIHTTPError: When the status code is >= 400.
        """
        try:
            data = resp.json()
        except ValueError:
            data = {"raw": resp.text}
        if resp.status_code >= 400:
            raise OpenAIHTTPError(resp.status_code, data)
        return data

    def _stream_chat(
        self,
        *,
        url: str,
        headers: Dict[str, str],
        payload: Dict[str, Any],
        proxies: Optional[Dict[str, str]],
        timeout: float,
        params: Optional[Dict[str, str]] = None,
    ) -> Generator[Dict[str, Any], None, None]:
        """Stream SSE response and yield parsed JSON chunks.

        Yields:
            - Normal chunks: the JSON value parsed from each ``data:`` payload
              (``choices[0].delta`` etc.).
            - Error chunks: ``{"error": {"message", "code", "type"},
              "message": str, "status_code": int}`` followed by termination
              of the generator. Transport failures use status 408 (timeout)
              or 0.

        The response is closed when the generator finishes, fails, or is
        closed early by the consumer.
        """
        try:
            resp = requests.post(
                url,
                headers=headers,
                json=payload,
                timeout=timeout,
                proxies=proxies,
                stream=True,
                params=params,
            )
        except requests.exceptions.Timeout as e:
            yield self._make_error_chunk(408, f"Request timed out: {e}")
            return
        except requests.exceptions.ConnectionError as e:
            yield self._make_error_chunk(0, f"Connection error: {e}")
            return
        except requests.exceptions.RequestException as e:
            yield self._make_error_chunk(0, f"Request failed: {e}")
            return

        # Everything that touches `resp` lives inside this try so the
        # connection is released on every exit path (including the HTTP-error
        # path) and so a transport failure while reading the error body is
        # reported as an error chunk instead of escaping the generator.
        try:
            if resp.status_code >= 400:
                yield self._http_error_chunk(resp)
                return

            # IMPORTANT: do NOT use `iter_lines(decode_unicode=True)`.
            #
            # `requests` decodes per-network-chunk using the response's
            # declared encoding, which mangles UTF-8 codepoints that straddle
            # a chunk boundary. `_iter_sse_events` instead buffers raw bytes
            # until a complete SSE event (terminated by a blank line) is
            # available and only then decodes it.
            for sse_event in self._iter_sse_events(resp):
                # `sse_event` is the joined `data:` payload as a str.
                if sse_event == "[DONE]":
                    return
                if not sse_event:
                    continue
                try:
                    chunk = json.loads(sse_event)
                except ValueError:
                    logger.debug(
                        f"[OpenAIHTTP] skip malformed SSE chunk: {sse_event[:200]}"
                    )
                    continue
                yield chunk
        except requests.exceptions.ChunkedEncodingError as e:
            yield self._make_error_chunk(0, f"Stream interrupted: {e}")
        except requests.exceptions.RequestException as e:
            yield self._make_error_chunk(0, f"Stream error: {e}")
        finally:
            # Best-effort cleanup: a failure to close must not mask the
            # outcome of the stream itself.
            try:
                resp.close()
            except Exception:
                pass

    @staticmethod
    def _http_error_chunk(resp: "requests.Response") -> Dict[str, Any]:
        """Build the error chunk for a streamed response with status >= 400.

        Reads the body once, extracts message / code / type from an
        OpenAI-style ``{"error": ...}`` envelope when present, and otherwise
        falls back to the first 500 characters of the stringified body.
        """
        try:
            body = resp.json()
        except ValueError:
            body = {"raw": resp.text[:1000]}

        err_msg = ""
        err_code = ""
        err_type = ""
        if isinstance(body, dict):
            err = body.get("error") or {}
            if isinstance(err, dict):
                err_msg = err.get("message") or ""
                err_code = err.get("code") or ""
                err_type = err.get("type") or ""
            elif isinstance(err, str):
                err_msg = err
        if not err_msg:
            err_msg = str(body)[:500]

        return {
            "error": {
                "message": err_msg,
                "code": err_code,
                "type": err_type,
            },
            # Top-level fields kept for backward compatibility with the
            # error-shape that `_handle_stream_response` previously emitted.
            "message": err_msg,
            "status_code": resp.status_code,
        }

    @staticmethod
    def _iter_sse_events(resp: "requests.Response") -> Generator[str, None, None]:
        """Decode an SSE byte stream into joined `data:` payloads.

        Implements the subset of the SSE spec that OpenAI / OpenAI-compatible
        endpoints actually use:
        - Events are separated by blank lines (\\r\\r, \\n\\n, or \\r\\n\\r\\n).
        - Within an event, multiple ``data:`` lines are concatenated with
          "\\n" (per spec).
        - ``event:``, ``id:``, ``retry:`` and comment lines (``:``) are
          tolerated but not yielded.
        - Bytes are buffered until a complete event boundary is seen so
          UTF-8 codepoints split across network chunks decode correctly.

        Yields each event's joined ``data`` string. The terminal sentinel
        ``[DONE]`` is yielded as a literal string so the caller can break.
        Trailing bytes that the server did not terminate with a blank line
        are parsed as a final event.
        """
        buf = b""
        for raw in resp.iter_content(chunk_size=None, decode_unicode=False):
            if not raw:
                continue
            buf += raw
            # Drain every complete event currently in the buffer.
            while True:
                # SSE allows three terminator forms; pick the earliest match.
                idx_nn = buf.find(b"\n\n")
                idx_rr = buf.find(b"\r\r")
                idx_rnrn = buf.find(b"\r\n\r\n")
                candidates = [i for i in (idx_nn, idx_rr, idx_rnrn) if i != -1]
                if not candidates:
                    break
                end_pos = min(candidates)
                # The terminator length is needed to advance past it.
                term_len = 4 if end_pos == idx_rnrn else 2
                event_bytes = buf[:end_pos]
                buf = buf[end_pos + term_len:]
                data = OpenAIHTTPClient._join_sse_data(event_bytes)
                if data is not None:
                    yield data

        # Flush any trailing bytes the server forgot to terminate.
        if buf.strip():
            data = OpenAIHTTPClient._join_sse_data(buf)
            if data is not None:
                yield data

    @staticmethod
    def _join_sse_data(event_bytes: bytes) -> Optional[str]:
        """Return the joined ``data:`` payload of one raw SSE event.

        Returns ``None`` when the event has no ``data`` field at all; an
        event whose data is empty returns ``""``.
        """
        data_lines = []
        # Split on BYTES, not on decoded text: ``bytes.splitlines`` only
        # breaks at \n, \r and \r\n, whereas ``str.splitlines`` also breaks at
        # U+2028 / U+2029 / U+0085 / VT / FF, which may legitimately appear
        # inside a JSON string and would truncate the payload.
        for raw_line in event_bytes.splitlines():
            # Line boundaries are ASCII, so decoding per line never cuts a
            # UTF-8 sequence. ``errors="replace"`` is a safety net for truly
            # malformed upstream bytes.
            line = raw_line.decode("utf-8", errors="replace")
            if not line or line.startswith(":"):
                continue
            field, _, value = line.partition(":")
            # Per SSE spec, a single optional space after the colon is part
            # of the framing, not the value.
            if value.startswith(" "):
                value = value[1:]
            if field == "data":
                data_lines.append(value)
            # Other fields (event/id/retry) are intentionally ignored.
        if not data_lines:
            return None
        return "\n".join(data_lines)

    @staticmethod
    def _make_error_chunk(status_code: int, message: str) -> Dict[str, Any]:
        """Build an error chunk for failures that carry no error code/type."""
        return {
            "error": {"message": message, "code": "", "type": ""},
            "message": message,
            "status_code": status_code,
        }
# Convenience factory for call sites that want a client without keeping one
# around; it reads more cleanly than spelling out the constructor each time.
def get_default_client(
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    proxy: Optional[str] = None,
    timeout: Optional[float] = None,
) -> OpenAIHTTPClient:
    """Return a new ``OpenAIHTTPClient`` built from the given settings.

    Every argument is forwarded unchanged to the ``OpenAIHTTPClient``
    constructor under the same keyword.
    """
    client_options = {
        "api_key": api_key,
        "api_base": api_base,
        "proxy": proxy,
        "timeout": timeout,
    }
    return OpenAIHTTPClient(**client_options)