mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
Add embedding_provider config knob with native support for openai / dashscope / doubao / zhipu / linkai, plus an in-chat /memory status and /memory rebuild-index workflow for switching vendors safely.
487 lines
19 KiB
Python
487 lines
19 KiB
Python
"""
|
|
Embedding providers for memory
|
|
|
|
Supports multiple embedding vendors (OpenAI-compatible unless noted otherwise):
|
|
- openai (text-embedding-3-small / large)
|
|
- linkai (OpenAI-compatible passthrough)
|
|
- dashscope (Aliyun Tongyi text-embedding-v4)
|
|
- doubao (ByteDance Doubao Seed1.5 / large-text on Volcengine Ark)
|
|
- zhipu (ZhipuAI embedding-3)
|
|
|
|
Vendor keys here intentionally match the project's bot_type constants in
|
|
common.const (OPENAI, LINKAI, QWEN_DASHSCOPE, DOUBAO, ZHIPU_AI).
|
|
|
|
Most providers share a single OpenAI-compatible REST client; vendor-specific
behaviors (truncation, query instruction prefix) are configured via metadata.
Doubao is the exception: it uses a dedicated adapter for its multimodal
embeddings endpoint.
|
|
"""
|
|
|
|
import hashlib
|
|
import math
|
|
from abc import ABC, abstractmethod
|
|
from typing import List, Optional
|
|
|
|
# HTTP timeout for a single embeddings request (seconds). Passed as the
# `timeout=` argument of every `requests.post` call made by the providers in
# this module. A batch of 64+ chunks can take 30-50s end-to-end from
# China-side networks, so 30s is routinely too tight; 90s gives meaningful
# headroom without letting bad endpoints hang forever.
EMBEDDING_HTTP_TIMEOUT = 90
|
|
|
|
|
|
class EmbeddingProvider(ABC):
    """Abstract interface every embedding backend must implement.

    Concrete providers supply `embed`, `embed_batch` and the `dimensions`
    property; `embed_query` has a default that simply forwards to `embed`.
    """

    @property
    @abstractmethod
    def dimensions(self) -> int:
        """Length of the vectors this provider returns."""

    @abstractmethod
    def embed(self, text: str) -> List[float]:
        """Return the embedding vector for one piece of text."""

    @abstractmethod
    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per input text (document side)."""

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding for a search query.

        The base implementation delegates to `embed`; subclasses may override
        it to apply a vendor-specific query instruction prefix.
        """
        vector = self.embed(text)
        return vector
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Vendor metadata table
# ---------------------------------------------------------------------------
#
# Each entry describes how to reach a vendor's embedding endpoint. Most
# vendors expose an OpenAI-compatible /embeddings API; the few that don't
# (currently: doubao) set `provider_class` to pick a dedicated adapter.
# Fields:
#   provider_class         : optional adapter key ("doubao"); defaults to OpenAI-compat
#   default_base_url       : default API base when not overridden by user
#   default_model          : default embedding model name
#   default_dimensions     : recommended unified dim when explicit path is enabled
#   supports_dim_param     : whether the API accepts a `dimensions` request param
#   needs_client_truncate  : whether to slice the returned vector down to the
#                            target dimension on the client side
#   needs_client_normalize : whether to L2-normalize on the client (always safe)
#   query_instruction      : optional prefix prepended to query texts only
#                            (asymmetric retrieval); empty for every vendor below
#   max_batch_size         : max texts per /embeddings request; embed_batch
#                            auto-paginates above this. Conservative defaults.
#
# NOTE: create_embedding_provider forwards the supports_dim_param /
# needs_client_* / query_instruction / max_batch_size flags only to the
# OpenAI-compatible provider; the doubao adapter receives just the model,
# base URL, headers and dimensions.
#
EMBEDDING_VENDORS = {
    "openai": {
        "default_base_url": "https://api.openai.com/v1",
        "default_model": "text-embedding-3-small",
        # Match the legacy default so users adding `embedding_provider: openai`
        # to an existing index don't need to rebuild. Override via
        # embedding_dimensions if you want 1024 / 1536 / 3072.
        "default_dimensions": 1536,
        "supports_dim_param": True,
        "needs_client_truncate": False,
        "needs_client_normalize": False,
        "query_instruction": "",
        # OpenAI permits up to 2048 items per request, but a single call
        # carrying hundreds of long chunks can run past the HTTP timeout
        # (EMBEDDING_HTTP_TIMEOUT) from China-side networks. 64 keeps each
        # call well under both the token-per-request budget and a reasonable
        # wall clock.
        "max_batch_size": 64,
    },
    "linkai": {
        "default_base_url": "https://api.link-ai.tech/v1",
        "default_model": "text-embedding-3-small",
        "default_dimensions": 1536,
        "supports_dim_param": True,
        "needs_client_truncate": False,
        "needs_client_normalize": False,
        "query_instruction": "",
        "max_batch_size": 64,
    },
    "dashscope": {
        "default_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
        "default_model": "text-embedding-v4",
        "default_dimensions": 1024,
        "supports_dim_param": True,
        "needs_client_truncate": False,
        "needs_client_normalize": False,
        "query_instruction": "",
        "max_batch_size": 10,  # DashScope hard cap (text-embedding-v4)
    },
    "doubao": {
        # Doubao no longer offers an OpenAI-compatible /v1/embeddings endpoint.
        # Current models are unified under /api/v3/embeddings/multimodal
        # which uses a structured `input` payload — see DoubaoEmbeddingProvider.
        "provider_class": "doubao",
        "default_base_url": "https://ark.cn-beijing.volces.com/api/v3",
        "default_model": "doubao-embedding-vision-251215",
        # Native options: 1024 or 2048. We default to 1024 to align with the
        # other Chinese vendors (dashscope/zhipu) and keep storage footprint
        # consistent across providers; users can still override via
        # `embedding_dimensions: 2048` in config.
        "default_dimensions": 1024,
        "supports_dim_param": True,
        "needs_client_truncate": False,
        "needs_client_normalize": False,
        "query_instruction": "",
        # Multimodal endpoint produces ONE embedding per call (input list is
        # a single document's parts, not a batch). embed_batch loops.
        "max_batch_size": 1,
    },
    "zhipu": {
        "default_base_url": "https://open.bigmodel.cn/api/paas/v4",
        "default_model": "embedding-3",
        "default_dimensions": 1024,
        "supports_dim_param": True,
        "needs_client_truncate": False,
        "needs_client_normalize": False,
        "query_instruction": "",
        "max_batch_size": 64,
    },
}
|
|
|
|
|
|
def _l2_normalize(vec: List[float]) -> List[float]:
|
|
"""Normalize a vector to unit length (L2 norm). Returns input on zero vector."""
|
|
norm = math.sqrt(sum(v * v for v in vec))
|
|
if norm == 0:
|
|
return vec
|
|
return [v / norm for v in vec]
|
|
|
|
|
|
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """
    OpenAI-compatible embedding provider.

    Serves every vendor that exposes an OpenAI-style ``/embeddings`` endpoint
    (openai / linkai / dashscope / zhipu); vendor differences are expressed
    through the constructor flags. The legacy constructor form
    (model, api_key, api_base) keeps working, so the original OpenAI/LinkAI
    fallback code path is unchanged.
    """

    def __init__(
        self,
        model: str = "text-embedding-3-small",
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        extra_headers: Optional[dict] = None,
        dimensions: Optional[int] = None,
        supports_dim_param: bool = True,
        needs_client_truncate: bool = False,
        needs_client_normalize: bool = False,
        query_instruction: str = "",
        max_batch_size: int = 256,
    ):
        """
        Args:
            model: Model name (e.g. text-embedding-3-small, text-embedding-v4, embedding-3)
            api_key: API key (required)
            api_base: API base URL (defaults to OpenAI)
            extra_headers: Optional extra HTTP headers
            dimensions: Target output dimension. Required when supports_dim_param
                is False and needs_client_truncate is True (used to slice).
                When omitted (or not positive) a legacy heuristic is used:
                1536 if the model name contains "small", otherwise 3072.
            supports_dim_param: Whether the vendor accepts a `dimensions` body param
            needs_client_truncate: Slice the returned vector to `dimensions`
            needs_client_normalize: L2-normalize on the client after slicing
            query_instruction: Optional prefix prepended to query texts only
            max_batch_size: Max items per /embeddings request; embed_batch
                auto-paginates above this. Values below 1 are clamped to 1.

        Raises:
            ValueError: If `api_key` is missing or is a known placeholder.
        """
        self.model = model
        self.api_key = api_key
        self.api_base = api_base or "https://api.openai.com/v1"
        self.extra_headers = extra_headers or {}
        self.supports_dim_param = supports_dim_param
        self.needs_client_truncate = needs_client_truncate
        self.needs_client_normalize = needs_client_normalize
        self.query_instruction = query_instruction or ""
        self.max_batch_size = max(1, int(max_batch_size or 1))

        if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
            raise ValueError("Embedding API key is not configured")

        if dimensions is not None and dimensions > 0:
            self._dimensions = dimensions
        else:
            # Legacy heuristic for OpenAI text-embedding-3-* family
            self._dimensions = 1536 if "small" in model else 3072

    def _call_api(self, input_data):
        """Call the OpenAI-compatible /embeddings endpoint.

        Args:
            input_data: Value placed in the request's `input` field (a single
                string from `embed`, or a list of strings from `embed_batch`).

        Returns:
            The decoded JSON response body.

        Raises:
            ConnectionError: The endpoint could not be reached.
            TimeoutError: The request exceeded EMBEDDING_HTTP_TIMEOUT.
            ValueError: The server answered with an HTTP error status.
        """
        import requests

        # A user-supplied api_base often ends with "/"; strip it so the final
        # URL never contains "//embeddings".
        url = f"{self.api_base.rstrip('/')}/embeddings"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
            **self.extra_headers,
        }
        data = {
            "input": input_data,
            "model": self.model,
        }
        if self.supports_dim_param and self._dimensions:
            data["dimensions"] = self._dimensions

        try:
            response = requests.post(url, headers=headers, json=data, timeout=EMBEDDING_HTTP_TIMEOUT)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.ConnectionError as e:
            raise ConnectionError(
                f"Failed to connect to embedding API at {url}. "
                f"Please check network and api_base. Error: {str(e)}"
            ) from e
        except requests.exceptions.Timeout as e:
            raise TimeoutError(f"Embedding API request timed out. Error: {str(e)}") from e
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:
                raise ValueError("Invalid embedding API key") from e
            elif e.response.status_code == 429:
                raise ValueError("Embedding API rate limit exceeded") from e
            else:
                raise ValueError(
                    f"Embedding API request failed: "
                    f"{e.response.status_code} - {e.response.text}"
                ) from e

    def _post_process(self, raw: List[float]) -> List[float]:
        """Apply optional client-side truncation, then optional normalization."""
        vec = raw
        if self.needs_client_truncate and self._dimensions and len(vec) > self._dimensions:
            vec = vec[: self._dimensions]
        if self.needs_client_normalize:
            vec = _l2_normalize(vec)
        return vec

    @staticmethod
    def _ordered_items(items, expected: int):
        """Return the response items aligned with the request's input order.

        Raises ValueError when the item count differs from the number of
        inputs sent, because the vectors could then no longer be paired with
        their texts. When every item carries an integer `index` and those
        indexes form a permutation of 0..expected-1, the items are sorted by
        it; otherwise the response order is trusted as-is.
        """
        if len(items) != expected:
            raise ValueError(
                f"Embedding API returned {len(items)} embeddings for {expected} inputs"
            )
        indexes = [
            item.get("index") if isinstance(item, dict) else None
            for item in items
        ]
        valid = [i for i in indexes if isinstance(i, int) and not isinstance(i, bool)]
        if sorted(valid) == list(range(expected)):
            return sorted(items, key=lambda item: item["index"])
        return items

    def embed(self, text: str) -> List[float]:
        """Generate the embedding for a single text, sent as-is (no query prefix)."""
        result = self._call_api(text)
        return self._post_process(result["data"][0]["embedding"])

    def embed_query(self, text: str) -> List[float]:
        """Generate embedding for a query (applies vendor instruction prefix if any)"""
        if self.query_instruction:
            text = f"{self.query_instruction}{text}"
        return self.embed(text)

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for multiple documents.

        Automatically paginates by self.max_batch_size so callers can pass any
        number of texts. Order of returned vectors matches the input order:
        each page's items are re-aligned by their `index` field when present.

        Raises:
            ValueError: If a page's response does not contain exactly one
                embedding per text sent.
        """
        if not texts:
            return []
        out: List[List[float]] = []
        step = self.max_batch_size
        for start in range(0, len(texts), step):
            chunk = texts[start:start + step]
            result = self._call_api(chunk)
            items = self._ordered_items(result["data"], len(chunk))
            out.extend(self._post_process(item["embedding"]) for item in items)
        return out

    @property
    def dimensions(self) -> int:
        """Effective embedding dimension resolved in the constructor."""
        return self._dimensions
|
|
|
|
|
|
class DoubaoEmbeddingProvider(EmbeddingProvider):
    """
    Doubao (Volcengine Ark) multimodal embedding provider.

    Doubao deprecated their OpenAI-compatible /v1/embeddings endpoint and
    unified everything under /api/v3/embeddings/multimodal, which uses a
    structured `input: [{type, text|image_url|video_url}, ...]` payload.

    Notes:
      * The endpoint produces ONE embedding per call (input list is multiple
        modality parts of a single document, not a batch). embed_batch
        therefore loops per-text — no native batch support.
      * Native dimensions: 1024 or 2048 (default 1024 to align with other
        Chinese vendors). No client-side truncation needed.
      * Auth: Bearer ARK API key.
    """

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        extra_headers: Optional[dict] = None,
        dimensions: Optional[int] = None,
    ):
        """
        Args:
            model: Doubao embedding model name.
            api_key: ARK API key (required).
            api_base: API base URL (defaults to the Ark cn-beijing endpoint).
            extra_headers: Optional extra HTTP headers.
            dimensions: 1024 or 2048; None selects 1024.

        Raises:
            ValueError: If the API key is missing/placeholder, or `dimensions`
                is neither None, 1024 nor 2048.
        """
        self.model = model
        self.api_key = api_key
        self.api_base = api_base or "https://ark.cn-beijing.volces.com/api/v3"
        self.extra_headers = extra_headers or {}
        if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
            raise ValueError("Doubao embedding API key (ark_api_key) is not configured")

        if dimensions in (1024, 2048):
            self._dimensions = dimensions
        elif dimensions is None:
            self._dimensions = 1024
        else:
            raise ValueError(
                f"Doubao embedding dimensions must be 1024 or 2048, got {dimensions}"
            )

    def _call_api(self, text: str) -> List[float]:
        """One call → one embedding. The multimodal endpoint takes a single
        document represented as a list of typed parts; we send a single
        text part.

        Raises:
            ConnectionError: The endpoint could not be reached.
            TimeoutError: The request exceeded EMBEDDING_HTTP_TIMEOUT.
            ValueError: HTTP error status, or a response body that does not
                contain an embedding in a recognised shape.
        """
        import requests

        # Strip a trailing "/" from api_base so the URL never contains
        # "//embeddings/multimodal".
        url = f"{self.api_base.rstrip('/')}/embeddings/multimodal"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
            **self.extra_headers,
        }
        payload = {
            "model": self.model,
            "input": [{"type": "text", "text": text}],
            "dimensions": self._dimensions,
            "encoding_format": "float",
        }

        try:
            response = requests.post(url, headers=headers, json=payload, timeout=EMBEDDING_HTTP_TIMEOUT)
            response.raise_for_status()
            body = response.json()
        except requests.exceptions.ConnectionError as e:
            raise ConnectionError(
                f"Failed to connect to Doubao embedding API at {url}. "
                f"Please check network and api_base. Error: {str(e)}"
            ) from e
        except requests.exceptions.Timeout as e:
            raise TimeoutError(f"Doubao embedding API request timed out. Error: {str(e)}") from e
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:
                raise ValueError("Invalid Doubao (ark) embedding API key") from e
            elif e.response.status_code == 429:
                raise ValueError("Doubao embedding API rate limit exceeded") from e
            else:
                raise ValueError(
                    f"Doubao embedding API request failed: "
                    f"{e.response.status_code} - {e.response.text}"
                ) from e

        # Response shape per docs: {"data": {"embedding": [...]}}
        # A body that is not a JSON object falls through to the ValueError
        # below instead of failing on `.get`.
        data = body.get("data") if isinstance(body, dict) else None
        if isinstance(data, dict) and "embedding" in data:
            return data["embedding"]
        # Some providers wrap as a list of one — be defensive
        if (
            isinstance(data, list)
            and data
            and isinstance(data[0], dict)
            and "embedding" in data[0]
        ):
            return data[0]["embedding"]
        raise ValueError(f"Unexpected Doubao embedding response shape: {body}")

    def embed(self, text: str) -> List[float]:
        """Generate the embedding for a single text."""
        return self._call_api(text)

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate one embedding per text, issuing one request per text."""
        # Endpoint produces one embedding per call; loop. Order preserved.
        return [self._call_api(t) for t in texts]

    @property
    def dimensions(self) -> int:
        """Embedding dimension validated in the constructor (1024 or 2048)."""
        return self._dimensions
|
|
|
|
|
|
class EmbeddingCache:
    """Dictionary-backed store of already computed embeddings.

    Entries are keyed by the MD5 hex digest of "provider:model:text" and live
    in the public `cache` dict until `clear` is called.
    """

    def __init__(self):
        self.cache = dict()

    @staticmethod
    def _compute_key(text: str, provider: str, model: str) -> str:
        """Build the lookup key for a (provider, model, text) triple."""
        digest = hashlib.md5()
        digest.update("{}:{}:{}".format(provider, model, text).encode("utf-8"))
        return digest.hexdigest()

    def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
        """Return the stored embedding for the triple, or None if absent."""
        return self.cache.get(self._compute_key(text, provider, model))

    def put(self, text: str, provider: str, model: str, embedding: List[float]):
        """Store `embedding` for the triple, replacing any previous entry."""
        self.cache[self._compute_key(text, provider, model)] = embedding

    def clear(self):
        """Drop every stored embedding."""
        self.cache.clear()
|
|
|
|
|
|
def create_embedding_provider(
    provider: str = "openai",
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    extra_headers: Optional[dict] = None,
    dimensions: Optional[int] = None,
) -> EmbeddingProvider:
    """
    Factory function to create an embedding provider.

    Backward compatible: when called with provider in {"openai", "linkai"}
    and no `dimensions` arg, the provider is built without the vendor
    metadata flags, so the constructor's legacy dimension heuristic applies
    (1536 for "small" models, 3072 otherwise) and existing indexes stay valid.

    All other calls use the defaults from EMBEDDING_VENDORS for any argument
    the caller left unset.

    Args:
        provider: Vendor key (one of EMBEDDING_VENDORS)
        model: Model name (uses vendor default if None)
        api_key: API key (required)
        api_base: API base URL (uses vendor default if None)
        extra_headers: Optional extra HTTP headers
        dimensions: Target output dimension (uses vendor default if None or
            not positive)

    Returns:
        EmbeddingProvider instance

    Raises:
        ValueError: If `provider` is not a key of EMBEDDING_VENDORS, or the
            selected provider class rejects the configuration (e.g. missing
            API key).
    """
    meta = EMBEDDING_VENDORS.get(provider)
    if meta is None:
        raise ValueError(
            f"Unsupported embedding provider: {provider}. "
            f"Supported: {sorted(EMBEDDING_VENDORS.keys())}"
        )

    # Resolve once; every branch below falls back to the vendor's own base
    # URL rather than the OpenAI default baked into the provider class.
    resolved_base = api_base or meta["default_base_url"]
    resolved_dim = dimensions if (dimensions and dimensions > 0) else meta["default_dimensions"]

    # Doubao uses a non-OpenAI-compatible multimodal endpoint.
    if meta.get("provider_class") == "doubao":
        return DoubaoEmbeddingProvider(
            model=model or meta["default_model"],
            api_key=api_key,
            api_base=resolved_base,
            extra_headers=extra_headers,
            dimensions=resolved_dim,
        )

    # Legacy call for openai/linkai (no explicit dimensions) keeps the
    # constructor's dimension heuristic so existing data isn't invalidated.
    is_legacy_call = (
        provider in ("openai", "linkai")
        and dimensions is None
    )
    if is_legacy_call:
        return OpenAIEmbeddingProvider(
            model=model or "text-embedding-3-small",
            api_key=api_key,
            # Previously `api_base` was passed through untouched, which sent
            # linkai requests to api.openai.com when no base was configured.
            api_base=resolved_base,
            extra_headers=extra_headers,
        )

    return OpenAIEmbeddingProvider(
        model=model or meta["default_model"],
        api_key=api_key,
        api_base=resolved_base,
        extra_headers=extra_headers,
        dimensions=resolved_dim,
        supports_dim_param=meta["supports_dim_param"],
        needs_client_truncate=meta["needs_client_truncate"],
        needs_client_normalize=meta["needs_client_normalize"],
        query_instruction=meta["query_instruction"],
        max_batch_size=meta.get("max_batch_size", 256),
    )
|