mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
494 lines
18 KiB
Python
494 lines
18 KiB
Python
# encoding:utf-8
|
|
|
|
"""
Lightweight HTTP client for OpenAI-compatible APIs.

This client is a drop-in replacement for the parts of the `openai` SDK that this
project actually uses (chat completions, completions, image generation), so we
can drop the hard dependency on `openai==0.27.x`.

Design goals:
- Pure `requests` based (no httpx / pydantic / openai SDK dependency).
- Returns plain `dict` responses with the same shape OpenAI's HTTP API returns,
  so existing code that does `response["choices"][0]["message"]["content"]` /
  `response["usage"]["total_tokens"]` keeps working.
- Streaming yields plain `dict` chunks (parsed SSE `data:` JSON), matching the
  shape that `agent/protocol/agent_stream.py` consumes:
      chunk["choices"][0]["delta"]["content" | "tool_calls" | "reasoning_content"]
      chunk["choices"][0]["finish_reason"]
  Plus dict-style error chunks:
      {"error": {"message": ..., "code": ..., "type": ...},
       "message": ..., "status_code": ...}
- Compatible with arbitrary OpenAI-compatible endpoints (LinkAI, Azure-style
  proxies, DeepSeek, Moonshot, etc.) by allowing per-call api_key / api_base
  override and trusting whatever path/payload shape the caller passes.
"""
|
|
|
|
import json
|
|
from typing import Any, Dict, Generator, Optional
|
|
from urllib.parse import urlparse
|
|
|
|
import requests
|
|
|
|
from common.log import logger
|
|
|
|
|
|
# Base URL used when neither the client nor the call supplies ``api_base``.
DEFAULT_API_BASE = "https://api.openai.com/v1"
# Request timeout in seconds; matches the old openai SDK default.
DEFAULT_TIMEOUT = 600

# Identity advertised to gateways that support app attribution.
_APP_TITLE = "CowAgent"
_APP_REFERER = "https://github.com/zhayujie/CowAgent"

# Per-gateway app attribution headers, only sent when the request host
# matches a documented gateway. Sending these to user-configured custom
# proxies would leak app identity, so we dispatch by host suffix.
# Every gateway gets its own dict instance so entries never alias each other.
_ATTRIBUTION_HEADERS_BY_HOST: Dict[str, Dict[str, str]] = {
    gateway_host: {
        "HTTP-Referer": _APP_REFERER,
        "X-Title": _APP_TITLE,
    }
    for gateway_host in ("openrouter.ai", "ai-gateway.vercel.sh")
}
|
|
|
|
|
|
def _resolve_attribution_headers(url: str) -> Dict[str, str]:
    """Return the attribution headers registered for the host of *url*.

    A host matches a registered gateway when it equals the gateway name or is
    a subdomain of it. An empty dict is returned when the URL cannot be
    parsed, has no host, or matches no gateway. The returned dict is a copy,
    so callers may mutate it freely.
    """
    try:
        host = (urlparse(url).hostname or "").lower()
    except Exception:
        # Unparseable URL: treat as "no known gateway".
        return {}
    if not host:
        return {}

    matching = (
        dict(gateway_headers)
        for gateway, gateway_headers in _ATTRIBUTION_HEADERS_BY_HOST.items()
        if host == gateway or host.endswith("." + gateway)
    )
    # First registered gateway that matches wins; no match -> no headers.
    return next(matching, {})
|
|
|
|
|
|
class OpenAIHTTPError(Exception):
    """Error raised for failed requests; carries status code and parsed body.

    Attributes:
        status_code: Status code passed by the raiser.
        body: Response body as passed by the raiser.
        message: Human-readable description (see ``__init__`` for how it is
            derived when not given explicitly).
    """

    def __init__(self, status_code: int, body: Any, message: str = ""):
        """Build the error, deriving ``message`` from *body* when not given.

        Resolution order for the message:
        1. the explicit *message* argument, if truthy;
        2. ``body["error"]["message"]`` or a string ``body["error"]`` when
           *body* is an OpenAI-style error envelope;
        3. the first 500 characters of ``str(body)``.
        """
        detail = message
        if not detail and isinstance(body, dict):
            envelope = body.get("error") or {}
            if isinstance(envelope, str):
                detail = envelope
            elif isinstance(envelope, dict):
                detail = envelope.get("message") or ""
        # Last resort: a truncated dump of whatever body we were handed.
        detail = detail or str(body)[:500]

        self.status_code = status_code
        self.body = body
        self.message = detail
        super().__init__(f"HTTP {status_code}: {detail}")
|
|
|
|
|
|
class OpenAIHTTPClient:
    """Minimal HTTP client for OpenAI-compatible endpoints.

    Per-instance defaults (api_key / api_base / proxy / timeout) can be
    overridden on every call. Callers can also pass ``extra_headers`` for
    Azure-style ``api-key`` headers or custom routing headers.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        proxy: Optional[str] = None,
        timeout: Optional[float] = None,
        extra_headers: Optional[Dict[str, str]] = None,
    ):
        """Store per-instance defaults.

        Args:
            api_key: Bearer token sent in ``Authorization`` (omitted if falsy).
            api_base: Base URL; falls back to ``DEFAULT_API_BASE``. Trailing
                slashes are stripped.
            proxy: Proxy URL applied to both http and https traffic.
            timeout: Request timeout in seconds; falls back to
                ``DEFAULT_TIMEOUT``.
            extra_headers: Headers added to every request (copied).
        """
        self.api_key = api_key
        self.api_base = (api_base or DEFAULT_API_BASE).rstrip("/")
        self.timeout = timeout if timeout is not None else DEFAULT_TIMEOUT
        self.proxies = {"http": proxy, "https": proxy} if proxy else None
        self.extra_headers = dict(extra_headers) if extra_headers else {}

    # ------------------------------------------------------------------ #
    # Public API surface (mirrors what the old openai SDK provided)
    # ------------------------------------------------------------------ #

    def chat_completions(
        self,
        *,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        timeout: Optional[float] = None,
        proxy: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        extra_query: Optional[Dict[str, str]] = None,
        path: str = "/chat/completions",
        stream: bool = False,
        **payload,
    ):
        """POST /chat/completions.

        All remaining keyword arguments form the JSON request body.

        Returns:
            With ``stream=False``: the parsed JSON response.
            With ``stream=True``: a generator yielding parsed SSE chunks
            (plain ``dict``). On error during streaming it yields a single
            error chunk of the form
            ``{"error": {"message", "code", "type"}, "message", "status_code"}``
            and stops, instead of raising.

        Raises:
            OpenAIHTTPError: Non-streaming only, on transport failure or a
                response with status >= 400.
        """
        payload["stream"] = stream
        return self._request(
            path=path,
            payload=payload,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            proxy=proxy,
            extra_headers=extra_headers,
            extra_query=extra_query,
            stream=stream,
        )

    def completions(
        self,
        *,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        timeout: Optional[float] = None,
        proxy: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        extra_query: Optional[Dict[str, str]] = None,
        **payload,
    ) -> Dict[str, Any]:
        """POST /completions (legacy text completion). Non-streaming only.

        Accepts the same per-call transport overrides as
        ``chat_completions`` so they are never mistaken for payload fields.
        Any ``stream`` entry in the payload is discarded.

        Raises:
            OpenAIHTTPError: On transport failure or status >= 400.
        """
        payload.pop("stream", None)
        return self._request(
            path="/completions",
            payload=payload,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            proxy=proxy,
            extra_headers=extra_headers,
            extra_query=extra_query,
            stream=False,
        )

    def images_generate(
        self,
        *,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        timeout: Optional[float] = None,
        proxy: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        extra_query: Optional[Dict[str, str]] = None,
        **payload,
    ) -> Dict[str, Any]:
        """POST /images/generations.

        Accepts the same per-call transport overrides as
        ``chat_completions`` so they are never mistaken for payload fields.

        Raises:
            OpenAIHTTPError: On transport failure or status >= 400.
        """
        return self._request(
            path="/images/generations",
            payload=payload,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            proxy=proxy,
            extra_headers=extra_headers,
            extra_query=extra_query,
            stream=False,
        )

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    def _build_headers(
        self,
        api_key: Optional[str],
        extra_headers: Optional[Dict[str, str]],
        url: Optional[str] = None,
    ) -> Dict[str, str]:
        """Assemble request headers.

        Later sources override earlier ones: content type, bearer auth,
        gateway attribution headers for *url*, instance-level extra headers,
        then per-call *extra_headers*.
        """
        key = api_key if api_key is not None else self.api_key
        headers = {"Content-Type": "application/json"}
        if key:
            headers["Authorization"] = f"Bearer {key}"
        if url:
            attribution = _resolve_attribution_headers(url)
            if attribution:
                headers.update(attribution)
        if self.extra_headers:
            headers.update(self.extra_headers)
        if extra_headers:
            headers.update(extra_headers)
        return headers

    def _request(
        self,
        *,
        path: str,
        payload: Dict[str, Any],
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Optional[float],
        stream: bool,
        proxy: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        extra_query: Optional[Dict[str, str]] = None,
    ):
        """Send one POST request, resolving per-call overrides first.

        Returns the parsed JSON response, or (when *stream* is true) the
        generator produced by ``_stream_chat``.

        Raises:
            OpenAIHTTPError: Non-streaming only. Status 408 for timeouts,
                0 for other transport failures, otherwise the HTTP status.
        """
        base = api_base.rstrip("/") if api_base else self.api_base
        url = f"{base}{path}" if path.startswith("/") else f"{base}/{path}"
        headers = self._build_headers(api_key, extra_headers, url=url)
        req_timeout = timeout if timeout is not None else self.timeout
        proxies = {"http": proxy, "https": proxy} if proxy else self.proxies

        # Drop None-valued keys; some providers reject explicit nulls.
        clean_payload = {k: v for k, v in payload.items() if v is not None}

        if stream:
            # Return a generator. Errors during stream are yielded as a single
            # error chunk so callers (agent_stream) can map them to their
            # existing error-handling path without try/except around the loop.
            return self._stream_chat(
                url=url,
                headers=headers,
                payload=clean_payload,
                proxies=proxies,
                timeout=req_timeout,
                params=extra_query,
            )

        try:
            resp = requests.post(
                url,
                headers=headers,
                json=clean_payload,
                timeout=req_timeout,
                proxies=proxies,
                params=extra_query,
            )
        except requests.exceptions.Timeout as e:
            raise OpenAIHTTPError(408, {}, f"Request timed out: {e}") from e
        except requests.exceptions.ConnectionError as e:
            raise OpenAIHTTPError(0, {}, f"Connection error: {e}") from e
        except requests.exceptions.RequestException as e:
            raise OpenAIHTTPError(0, {}, f"Request failed: {e}") from e

        return self._parse_response(resp)

    @staticmethod
    def _parse_response(resp: "requests.Response") -> Dict[str, Any]:
        """Return the JSON body of *resp*, or ``{"raw": text}`` if not JSON.

        Raises:
            OpenAIHTTPError: When the status code is >= 400.
        """
        try:
            data = resp.json()
        except ValueError:
            data = {"raw": resp.text}

        if resp.status_code >= 400:
            raise OpenAIHTTPError(resp.status_code, data)

        return data

    def _stream_chat(
        self,
        *,
        url: str,
        headers: Dict[str, str],
        payload: Dict[str, Any],
        proxies: Optional[Dict[str, str]],
        timeout: float,
        params: Optional[Dict[str, str]] = None,
    ) -> Generator[Dict[str, Any], None, None]:
        """Stream SSE response and yield parsed JSON chunks.

        Yields:
            - Normal chunks: dict with ``choices[0].delta`` etc.
            - Error chunks:
              ``{"error": {"message", "code", "type"}, "message", "status_code"}``
              followed by termination of the generator. ``status_code`` is the
              HTTP status, 408 for a request timeout, or 0 for other
              transport failures.

        The response is always closed when the generator finishes or is
        discarded, including on the HTTP-error path.
        """
        try:
            resp = requests.post(
                url,
                headers=headers,
                json=payload,
                timeout=timeout,
                proxies=proxies,
                stream=True,
                params=params,
            )
        except requests.exceptions.Timeout as e:
            yield self._make_error_chunk(408, f"Request timed out: {e}")
            return
        except requests.exceptions.ConnectionError as e:
            yield self._make_error_chunk(0, f"Connection error: {e}")
            return
        except requests.exceptions.RequestException as e:
            yield self._make_error_chunk(0, f"Request failed: {e}")
            return

        # IMPORTANT: do NOT use `iter_lines(decode_unicode=True)`.
        #
        # `requests` decodes per-network-chunk using the response's declared
        # encoding (often Latin-1 / ISO-8859-1 for SSE), which mangles UTF-8
        # codepoints that straddle a chunk boundary. Instead we read raw
        # bytes, accumulate them until we have a complete SSE event
        # (terminated by a blank line per the SSE spec:
        # https://html.spec.whatwg.org/multipage/server-sent-events.html),
        # and only THEN decode as UTF-8.
        try:
            if resp.status_code >= 400:
                yield self._http_error_chunk(resp)
                return

            for sse_event in self._iter_sse_events(resp):
                # `sse_event` is the joined `data:` payload as a str.
                if sse_event == "[DONE]":
                    return
                if not sse_event:
                    continue
                try:
                    chunk = json.loads(sse_event)
                except ValueError:
                    logger.debug(
                        f"[OpenAIHTTP] skip malformed SSE chunk: {sse_event[:200]}"
                    )
                    continue
                yield chunk
        except requests.exceptions.ChunkedEncodingError as e:
            yield self._make_error_chunk(0, f"Stream interrupted: {e}")
        except requests.exceptions.RequestException as e:
            yield self._make_error_chunk(0, f"Stream error: {e}")
        finally:
            # Best-effort cleanup: a failure to close must not mask the
            # result (or the error chunk) already delivered to the caller.
            try:
                resp.close()
            except Exception:
                pass

    @staticmethod
    def _http_error_chunk(resp: "requests.Response") -> Dict[str, Any]:
        """Build the error chunk for a streaming response with status >= 400.

        Reads the body once. Message/code/type are taken from an OpenAI-style
        ``{"error": {...}}`` envelope when present; otherwise the message is
        a truncated dump of the body. A transport failure while reading the
        body is reported inside the chunk rather than raised, so the
        generator contract (errors are yielded, not raised) holds.
        """
        try:
            body = resp.json()
        except ValueError:
            body = {"raw": resp.text[:1000]}
        except requests.exceptions.RequestException as e:
            body = {"raw": f"failed to read error body: {e}"}

        err_msg = ""
        err_code = ""
        err_type = ""
        if isinstance(body, dict):
            err = body.get("error") or {}
            if isinstance(err, dict):
                err_msg = err.get("message") or ""
                err_code = err.get("code") or ""
                err_type = err.get("type") or ""
            elif isinstance(err, str):
                err_msg = err
        if not err_msg:
            err_msg = str(body)[:500]
        return OpenAIHTTPClient._make_error_chunk(
            resp.status_code, err_msg, err_code, err_type
        )

    @staticmethod
    def _decode_sse_data(event_bytes: bytes) -> Optional[str]:
        """Return the joined ``data:`` payload of one raw SSE event.

        Returns ``None`` when the event has no ``data`` field at all.

        Lines are split ONLY on CR, LF and CRLF, the terminators the SSE
        spec defines. ``str.splitlines()`` must not be used here: it also
        breaks on U+2028, U+2029, U+0085 and several control characters,
        all of which may legally appear unescaped inside a JSON string and
        would cut the payload in half.
        """
        # ``errors="replace"`` is a safety net for malformed upstream bytes;
        # well-formed UTF-8 decodes unchanged.
        text = event_bytes.decode("utf-8", errors="replace")
        data_lines = []
        for line in text.replace("\r\n", "\n").replace("\r", "\n").split("\n"):
            if not line or line.startswith(":"):
                # Blank remnants and comment lines carry no payload.
                continue
            field, _, value = line.partition(":")
            # Per SSE spec, a single optional space after the colon
            # is part of the framing, not the value.
            if value.startswith(" "):
                value = value[1:]
            if field == "data":
                data_lines.append(value)
            # Other fields (event/id/retry) are intentionally ignored.
        if not data_lines:
            return None
        return "\n".join(data_lines)

    @staticmethod
    def _iter_sse_events(resp: "requests.Response") -> Generator[str, None, None]:
        """Decode an SSE byte stream into joined `data:` payloads.

        Implements the subset of the SSE spec that OpenAI / OpenAI-compatible
        endpoints actually use:
        - Events are separated by blank lines (\\r\\r, \\n\\n, or \\r\\n\\r\\n).
        - Within an event, multiple ``data:`` lines are concatenated with
          "\\n" (per spec).
        - ``event:``, ``id:``, ``retry:`` and comment lines (``:``) are
          tolerated but not yielded.
        - Bytes are buffered until a complete event boundary is seen so
          UTF-8 codepoints split across network chunks decode correctly.

        Yields each event's joined ``data`` string. The terminal sentinel
        ``[DONE]`` is yielded as a literal string so the caller can break.
        """
        terminators = (b"\r\n\r\n", b"\n\n", b"\r\r")
        decode = OpenAIHTTPClient._decode_sse_data
        buf = bytearray()
        for raw in resp.iter_content(chunk_size=None, decode_unicode=False):
            if not raw:
                continue
            buf += raw
            # Drain every complete event currently sitting in the buffer.
            while True:
                hits = [
                    (pos, len(term))
                    for term in terminators
                    for pos in (buf.find(term),)
                    if pos != -1
                ]
                if not hits:
                    break
                # Earliest terminator wins; two different terminators can
                # never start at the same offset, so the min is unambiguous.
                end_pos, term_len = min(hits)
                event_bytes = bytes(buf[:end_pos])
                del buf[:end_pos + term_len]

                data = decode(event_bytes)
                if data is not None:
                    yield data

        # Flush any trailing bytes the server left unterminated (some
        # providers omit the final blank line).
        if buf.strip():
            data = decode(bytes(buf))
            if data is not None:
                yield data

    @staticmethod
    def _make_error_chunk(
        status_code: int,
        message: str,
        code: str = "",
        err_type: str = "",
    ) -> Dict[str, Any]:
        """Build a stream error chunk.

        The nested ``error`` object mirrors OpenAI's envelope; the top-level
        ``message`` / ``status_code`` fields are kept for backward
        compatibility with consumers of the older flat error shape.
        """
        return {
            "error": {"message": message, "code": code, "type": err_type},
            "message": message,
            "status_code": status_code,
        }
|
|
|
|
|
|
def get_default_client(
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    proxy: Optional[str] = None,
    timeout: Optional[float] = None,
) -> OpenAIHTTPClient:
    """Return a new ``OpenAIHTTPClient`` configured with the given defaults.

    A tiny helper for callers that just need a one-shot client without
    storing state; it forwards every argument to the constructor unchanged.
    """
    settings = {
        "api_key": api_key,
        "api_base": api_base,
        "proxy": proxy,
        "timeout": timeout,
    }
    return OpenAIHTTPClient(**settings)
|