mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
feat(memory): support multi-vendor embedding fallback
Add embedding_provider config knob with native support for openai / dashscope / doubao / zhipu / linkai, plus an in-chat /memory status and /memory rebuild-index workflow for switching vendors safely.
This commit is contained in:
@@ -1,167 +0,0 @@
|
|||||||
"""
|
|
||||||
Embedding providers for memory
|
|
||||||
|
|
||||||
Supports OpenAI and local embedding models
|
|
||||||
"""
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingProvider(ABC):
|
|
||||||
"""Base class for embedding providers"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def embed(self, text: str) -> List[float]:
|
|
||||||
"""Generate embedding for text"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
|
||||||
"""Generate embeddings for multiple texts"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def dimensions(self) -> int:
|
|
||||||
"""Get embedding dimensions"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
|
||||||
"""OpenAI embedding provider using REST API"""
|
|
||||||
|
|
||||||
def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None,
|
|
||||||
api_base: Optional[str] = None, extra_headers: Optional[dict] = None):
|
|
||||||
"""
|
|
||||||
Initialize OpenAI embedding provider
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: Model name (text-embedding-3-small or text-embedding-3-large)
|
|
||||||
api_key: OpenAI API key
|
|
||||||
api_base: Optional API base URL
|
|
||||||
extra_headers: Optional extra headers to include in API requests
|
|
||||||
"""
|
|
||||||
self.model = model
|
|
||||||
self.api_key = api_key
|
|
||||||
self.api_base = api_base or "https://api.openai.com/v1"
|
|
||||||
self.extra_headers = extra_headers or {}
|
|
||||||
|
|
||||||
# Validate API key
|
|
||||||
if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
|
||||||
raise ValueError("OpenAI API key is not configured. Please set 'open_ai_api_key' in config.json")
|
|
||||||
|
|
||||||
# Set dimensions based on model
|
|
||||||
self._dimensions = 1536 if "small" in model else 3072
|
|
||||||
|
|
||||||
def _call_api(self, input_data):
|
|
||||||
"""Call OpenAI embedding API using requests"""
|
|
||||||
import requests
|
|
||||||
|
|
||||||
url = f"{self.api_base}/embeddings"
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
**self.extra_headers,
|
|
||||||
}
|
|
||||||
data = {
|
|
||||||
"input": input_data,
|
|
||||||
"model": self.model
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = requests.post(url, headers=headers, json=data, timeout=5)
|
|
||||||
response.raise_for_status()
|
|
||||||
return response.json()
|
|
||||||
except requests.exceptions.ConnectionError as e:
|
|
||||||
raise ConnectionError(f"Failed to connect to OpenAI API at {url}. Please check your network connection and api_base configuration. Error: {str(e)}")
|
|
||||||
except requests.exceptions.Timeout as e:
|
|
||||||
raise TimeoutError(f"OpenAI API request timed out after 10s. Please check your network connection. Error: {str(e)}")
|
|
||||||
except requests.exceptions.HTTPError as e:
|
|
||||||
if e.response.status_code == 401:
|
|
||||||
raise ValueError(f"Invalid OpenAI API key. Please check your 'open_ai_api_key' in config.json")
|
|
||||||
elif e.response.status_code == 429:
|
|
||||||
raise ValueError(f"OpenAI API rate limit exceeded. Please try again later.")
|
|
||||||
else:
|
|
||||||
raise ValueError(f"OpenAI API request failed: {e.response.status_code} - {e.response.text}")
|
|
||||||
|
|
||||||
def embed(self, text: str) -> List[float]:
|
|
||||||
"""Generate embedding for text"""
|
|
||||||
result = self._call_api(text)
|
|
||||||
return result["data"][0]["embedding"]
|
|
||||||
|
|
||||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
|
||||||
"""Generate embeddings for multiple texts"""
|
|
||||||
if not texts:
|
|
||||||
return []
|
|
||||||
|
|
||||||
result = self._call_api(texts)
|
|
||||||
return [item["embedding"] for item in result["data"]]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dimensions(self) -> int:
|
|
||||||
return self._dimensions
|
|
||||||
|
|
||||||
|
|
||||||
# LocalEmbeddingProvider removed - only use OpenAI embedding or keyword search
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingCache:
|
|
||||||
"""Cache for embeddings to avoid recomputation"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.cache = {}
|
|
||||||
|
|
||||||
def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
|
|
||||||
"""Get cached embedding"""
|
|
||||||
key = self._compute_key(text, provider, model)
|
|
||||||
return self.cache.get(key)
|
|
||||||
|
|
||||||
def put(self, text: str, provider: str, model: str, embedding: List[float]):
|
|
||||||
"""Cache embedding"""
|
|
||||||
key = self._compute_key(text, provider, model)
|
|
||||||
self.cache[key] = embedding
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _compute_key(text: str, provider: str, model: str) -> str:
|
|
||||||
"""Compute cache key"""
|
|
||||||
content = f"{provider}:{model}:{text}"
|
|
||||||
return hashlib.md5(content.encode('utf-8')).hexdigest()
|
|
||||||
|
|
||||||
def clear(self):
|
|
||||||
"""Clear cache"""
|
|
||||||
self.cache.clear()
|
|
||||||
|
|
||||||
|
|
||||||
def create_embedding_provider(
|
|
||||||
provider: str = "openai",
|
|
||||||
model: Optional[str] = None,
|
|
||||||
api_key: Optional[str] = None,
|
|
||||||
api_base: Optional[str] = None,
|
|
||||||
extra_headers: Optional[dict] = None
|
|
||||||
) -> EmbeddingProvider:
|
|
||||||
"""
|
|
||||||
Factory function to create embedding provider
|
|
||||||
|
|
||||||
Supports "openai" and "linkai" providers (both use OpenAI-compatible REST API).
|
|
||||||
If initialization fails, caller should fall back to keyword-only search.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
provider: Provider name ("openai" or "linkai")
|
|
||||||
model: Model name (default: text-embedding-3-small)
|
|
||||||
api_key: API key (required)
|
|
||||||
api_base: API base URL
|
|
||||||
extra_headers: Optional extra headers to include in API requests
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
EmbeddingProvider instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If provider is unsupported or api_key is missing
|
|
||||||
"""
|
|
||||||
if provider not in ("openai", "linkai"):
|
|
||||||
raise ValueError(f"Unsupported embedding provider: {provider}. Use 'openai' or 'linkai'.")
|
|
||||||
|
|
||||||
model = model or "text-embedding-3-small"
|
|
||||||
return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base, extra_headers=extra_headers)
|
|
||||||
41
agent/memory/embedding/__init__.py
Normal file
41
agent/memory/embedding/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
"""
|
||||||
|
Embedding subsystem for memory.
|
||||||
|
|
||||||
|
Public API:
|
||||||
|
create_embedding_provider, EmbeddingProvider, OpenAIEmbeddingProvider,
|
||||||
|
EMBEDDING_VENDORS, EmbeddingCache
|
||||||
|
RebuildResult, clear_index, rebuild_in_process
|
||||||
|
detect_index_dim, cleanup_legacy_state_file
|
||||||
|
"""
|
||||||
|
|
||||||
|
from agent.memory.embedding.provider import (
|
||||||
|
EMBEDDING_VENDORS,
|
||||||
|
DoubaoEmbeddingProvider,
|
||||||
|
EmbeddingCache,
|
||||||
|
EmbeddingProvider,
|
||||||
|
OpenAIEmbeddingProvider,
|
||||||
|
create_embedding_provider,
|
||||||
|
)
|
||||||
|
from agent.memory.embedding.rebuild import (
|
||||||
|
RebuildResult,
|
||||||
|
clear_index,
|
||||||
|
rebuild_in_process,
|
||||||
|
)
|
||||||
|
from agent.memory.embedding.state import (
|
||||||
|
cleanup_legacy_state_file,
|
||||||
|
detect_index_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"EMBEDDING_VENDORS",
|
||||||
|
"DoubaoEmbeddingProvider",
|
||||||
|
"EmbeddingCache",
|
||||||
|
"EmbeddingProvider",
|
||||||
|
"OpenAIEmbeddingProvider",
|
||||||
|
"create_embedding_provider",
|
||||||
|
"RebuildResult",
|
||||||
|
"clear_index",
|
||||||
|
"rebuild_in_process",
|
||||||
|
"cleanup_legacy_state_file",
|
||||||
|
"detect_index_dim",
|
||||||
|
]
|
||||||
486
agent/memory/embedding/provider.py
Normal file
486
agent/memory/embedding/provider.py
Normal file
@@ -0,0 +1,486 @@
|
|||||||
|
"""
|
||||||
|
Embedding providers for memory
|
||||||
|
|
||||||
|
Supports multiple OpenAI-compatible embedding vendors:
|
||||||
|
- openai (text-embedding-3-small / large)
|
||||||
|
- linkai (OpenAI-compatible passthrough)
|
||||||
|
- dashscope (Aliyun Tongyi text-embedding-v4)
|
||||||
|
- doubao (ByteDance Doubao Seed1.5 / large-text on Volcengine Ark)
|
||||||
|
- zhipu (ZhipuAI embedding-3)
|
||||||
|
|
||||||
|
Vendor keys here intentionally match the project's bot_type constants in
|
||||||
|
common.const (OPENAI, LINKAI, QWEN_DASHSCOPE, DOUBAO, ZHIPU_AI).
|
||||||
|
|
||||||
|
All providers share a single OpenAI-compatible REST client. Vendor-specific
|
||||||
|
behaviors (truncation, query instruction prefix) are configured via metadata.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import math
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
# HTTP read timeout for a single embeddings request (seconds). A batch of
|
||||||
|
# 64+ chunks can take 30-50s end-to-end from China-side networks, so 30s is
|
||||||
|
# routinely too tight; 90s gives meaningful headroom without letting bad
|
||||||
|
# endpoints hang forever.
|
||||||
|
EMBEDDING_HTTP_TIMEOUT = 90
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingProvider(ABC):
|
||||||
|
"""Base class for embedding providers"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def embed(self, text: str) -> List[float]:
|
||||||
|
"""Generate embedding for a single text (treated as a query by default)"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
"""Generate embeddings for multiple texts (treated as documents)"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def embed_query(self, text: str) -> List[float]:
|
||||||
|
"""Generate embedding for a query string (may apply vendor instruction prefix)"""
|
||||||
|
return self.embed(text)
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def dimensions(self) -> int:
|
||||||
|
"""Effective embedding dimensions"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Vendor metadata table
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
#
|
||||||
|
# Each entry describes how to reach a vendor's embedding endpoint. Most
|
||||||
|
# vendors expose an OpenAI-compatible /embeddings API; the few that don't
|
||||||
|
# (currently: doubao) set `provider_class` to pick a dedicated adapter.
|
||||||
|
# Fields:
|
||||||
|
# provider_class : optional adapter key ("doubao"); defaults to OpenAI-compat
|
||||||
|
# default_base_url : default API base when not overridden by user
|
||||||
|
# default_model : default embedding model name
|
||||||
|
# default_dimensions : recommended unified dim when explicit path is enabled
|
||||||
|
# supports_dim_param : whether the API accepts a `dimensions` request param
|
||||||
|
# needs_client_truncate : whether to slice + L2-normalize on the client side
|
||||||
|
# needs_client_normalize : whether to L2-normalize on the client (always safe)
|
||||||
|
# query_instruction : optional prefix for asymmetric retrieval (Doubao Seed)
|
||||||
|
# max_batch_size : max texts per /embeddings request; embed_batch
|
||||||
|
# auto-paginates above this. Conservative defaults.
|
||||||
|
#
|
||||||
|
EMBEDDING_VENDORS = {
|
||||||
|
"openai": {
|
||||||
|
"default_base_url": "https://api.openai.com/v1",
|
||||||
|
"default_model": "text-embedding-3-small",
|
||||||
|
# Match the legacy default so users adding `embedding_provider: openai`
|
||||||
|
# to an existing index don't need to rebuild. Override via
|
||||||
|
# embedding_dimensions if you want 1024 / 1536 / 3072.
|
||||||
|
"default_dimensions": 1536,
|
||||||
|
"supports_dim_param": True,
|
||||||
|
"needs_client_truncate": False,
|
||||||
|
"needs_client_normalize": False,
|
||||||
|
"query_instruction": "",
|
||||||
|
# OpenAI permits up to 2048 items per request, but a single call
|
||||||
|
# carrying hundreds of long chunks routinely exceeds the 30s read
|
||||||
|
# timeout from China-side networks. 64 keeps each call well under
|
||||||
|
# both the token-per-request budget and a reasonable wall clock.
|
||||||
|
"max_batch_size": 64,
|
||||||
|
},
|
||||||
|
"linkai": {
|
||||||
|
"default_base_url": "https://api.link-ai.tech/v1",
|
||||||
|
"default_model": "text-embedding-3-small",
|
||||||
|
"default_dimensions": 1536,
|
||||||
|
"supports_dim_param": True,
|
||||||
|
"needs_client_truncate": False,
|
||||||
|
"needs_client_normalize": False,
|
||||||
|
"query_instruction": "",
|
||||||
|
"max_batch_size": 64,
|
||||||
|
},
|
||||||
|
"dashscope": {
|
||||||
|
"default_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||||
|
"default_model": "text-embedding-v4",
|
||||||
|
"default_dimensions": 1024,
|
||||||
|
"supports_dim_param": True,
|
||||||
|
"needs_client_truncate": False,
|
||||||
|
"needs_client_normalize": False,
|
||||||
|
"query_instruction": "",
|
||||||
|
"max_batch_size": 10, # DashScope hard cap (text-embedding-v4)
|
||||||
|
},
|
||||||
|
"doubao": {
|
||||||
|
# Doubao no longer offers an OpenAI-compatible /v1/embeddings endpoint.
|
||||||
|
# Current models are unified under /api/v3/embeddings/multimodal
|
||||||
|
# which uses a structured `input` payload — see DoubaoEmbeddingProvider.
|
||||||
|
"provider_class": "doubao",
|
||||||
|
"default_base_url": "https://ark.cn-beijing.volces.com/api/v3",
|
||||||
|
"default_model": "doubao-embedding-vision-251215",
|
||||||
|
# Native options: 1024 or 2048. We default to 1024 to align with the
|
||||||
|
# other Chinese vendors (dashscope/zhipu) and keep storage footprint
|
||||||
|
# consistent across providers; users can still override via
|
||||||
|
# `embedding_dimensions: 2048` in config.
|
||||||
|
"default_dimensions": 1024,
|
||||||
|
"supports_dim_param": True,
|
||||||
|
"needs_client_truncate": False,
|
||||||
|
"needs_client_normalize": False,
|
||||||
|
"query_instruction": "",
|
||||||
|
# Multimodal endpoint produces ONE embedding per call (input list is
|
||||||
|
# a single document's parts, not a batch). embed_batch loops.
|
||||||
|
"max_batch_size": 1,
|
||||||
|
},
|
||||||
|
"zhipu": {
|
||||||
|
"default_base_url": "https://open.bigmodel.cn/api/paas/v4",
|
||||||
|
"default_model": "embedding-3",
|
||||||
|
"default_dimensions": 1024,
|
||||||
|
"supports_dim_param": True,
|
||||||
|
"needs_client_truncate": False,
|
||||||
|
"needs_client_normalize": False,
|
||||||
|
"query_instruction": "",
|
||||||
|
"max_batch_size": 64,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _l2_normalize(vec: List[float]) -> List[float]:
|
||||||
|
"""Normalize a vector to unit length (L2 norm). Returns input on zero vector."""
|
||||||
|
norm = math.sqrt(sum(v * v for v in vec))
|
||||||
|
if norm == 0:
|
||||||
|
return vec
|
||||||
|
return [v / norm for v in vec]
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||||
|
"""
|
||||||
|
OpenAI-compatible embedding provider.
|
||||||
|
|
||||||
|
Used for openai/linkai/dashscope/ark/zhipu by configuring the metadata
|
||||||
|
fields. The legacy two-arg constructor (model, api_key, api_base) keeps
|
||||||
|
working, so the original OpenAI/LinkAI fallback code path is unchanged.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str = "text-embedding-3-small",
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
api_base: Optional[str] = None,
|
||||||
|
extra_headers: Optional[dict] = None,
|
||||||
|
dimensions: Optional[int] = None,
|
||||||
|
supports_dim_param: bool = True,
|
||||||
|
needs_client_truncate: bool = False,
|
||||||
|
needs_client_normalize: bool = False,
|
||||||
|
query_instruction: str = "",
|
||||||
|
max_batch_size: int = 256,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
model: Model name (e.g. text-embedding-3-small, text-embedding-v4, embedding-3)
|
||||||
|
api_key: API key (required)
|
||||||
|
api_base: API base URL (defaults to OpenAI)
|
||||||
|
extra_headers: Optional extra HTTP headers
|
||||||
|
dimensions: Target output dimension. Required when supports_dim_param
|
||||||
|
is False and needs_client_truncate is True (used to slice).
|
||||||
|
supports_dim_param: Whether the vendor accepts a `dimensions` body param
|
||||||
|
needs_client_truncate: Slice the returned vector to `dimensions`
|
||||||
|
needs_client_normalize: L2-normalize on the client after slicing
|
||||||
|
query_instruction: Optional prefix prepended to query texts only
|
||||||
|
max_batch_size: Max items per /embeddings request; embed_batch
|
||||||
|
auto-paginates above this.
|
||||||
|
"""
|
||||||
|
self.model = model
|
||||||
|
self.api_key = api_key
|
||||||
|
self.api_base = api_base or "https://api.openai.com/v1"
|
||||||
|
self.extra_headers = extra_headers or {}
|
||||||
|
self.supports_dim_param = supports_dim_param
|
||||||
|
self.needs_client_truncate = needs_client_truncate
|
||||||
|
self.needs_client_normalize = needs_client_normalize
|
||||||
|
self.query_instruction = query_instruction or ""
|
||||||
|
self.max_batch_size = max(1, int(max_batch_size or 1))
|
||||||
|
|
||||||
|
if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||||
|
raise ValueError("Embedding API key is not configured")
|
||||||
|
|
||||||
|
if dimensions is not None and dimensions > 0:
|
||||||
|
self._dimensions = dimensions
|
||||||
|
else:
|
||||||
|
# Legacy heuristic for OpenAI text-embedding-3-* family
|
||||||
|
self._dimensions = 1536 if "small" in model else 3072
|
||||||
|
|
||||||
|
def _call_api(self, input_data):
|
||||||
|
"""Call OpenAI-compatible /embeddings endpoint"""
|
||||||
|
import requests
|
||||||
|
|
||||||
|
url = f"{self.api_base}/embeddings"
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
**self.extra_headers,
|
||||||
|
}
|
||||||
|
data = {
|
||||||
|
"input": input_data,
|
||||||
|
"model": self.model,
|
||||||
|
}
|
||||||
|
if self.supports_dim_param and self._dimensions:
|
||||||
|
data["dimensions"] = self._dimensions
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(url, headers=headers, json=data, timeout=EMBEDDING_HTTP_TIMEOUT)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
except requests.exceptions.ConnectionError as e:
|
||||||
|
raise ConnectionError(
|
||||||
|
f"Failed to connect to embedding API at {url}. "
|
||||||
|
f"Please check network and api_base. Error: {str(e)}"
|
||||||
|
)
|
||||||
|
except requests.exceptions.Timeout as e:
|
||||||
|
raise TimeoutError(f"Embedding API request timed out. Error: {str(e)}")
|
||||||
|
except requests.exceptions.HTTPError as e:
|
||||||
|
if e.response.status_code == 401:
|
||||||
|
raise ValueError("Invalid embedding API key")
|
||||||
|
elif e.response.status_code == 429:
|
||||||
|
raise ValueError("Embedding API rate limit exceeded")
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Embedding API request failed: "
|
||||||
|
f"{e.response.status_code} - {e.response.text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _post_process(self, raw: List[float]) -> List[float]:
|
||||||
|
"""Apply optional client-side truncation + normalization"""
|
||||||
|
vec = raw
|
||||||
|
if self.needs_client_truncate and self._dimensions and len(vec) > self._dimensions:
|
||||||
|
vec = vec[: self._dimensions]
|
||||||
|
if self.needs_client_normalize:
|
||||||
|
vec = _l2_normalize(vec)
|
||||||
|
return vec
|
||||||
|
|
||||||
|
def embed(self, text: str) -> List[float]:
|
||||||
|
"""Generate embedding (treated as document by default)"""
|
||||||
|
result = self._call_api(text)
|
||||||
|
return self._post_process(result["data"][0]["embedding"])
|
||||||
|
|
||||||
|
def embed_query(self, text: str) -> List[float]:
|
||||||
|
"""Generate embedding for a query (applies vendor instruction prefix if any)"""
|
||||||
|
if self.query_instruction:
|
||||||
|
text = f"{self.query_instruction}{text}"
|
||||||
|
return self.embed(text)
|
||||||
|
|
||||||
|
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
"""Generate embeddings for multiple documents.
|
||||||
|
|
||||||
|
Automatically paginates by self.max_batch_size so callers can pass any
|
||||||
|
number of texts. Order of returned vectors matches the input order.
|
||||||
|
"""
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
out: List[List[float]] = []
|
||||||
|
step = self.max_batch_size
|
||||||
|
for i in range(0, len(texts), step):
|
||||||
|
chunk = texts[i:i + step]
|
||||||
|
result = self._call_api(chunk)
|
||||||
|
out.extend(self._post_process(item["embedding"]) for item in result["data"])
|
||||||
|
return out
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dimensions(self) -> int:
|
||||||
|
return self._dimensions
|
||||||
|
|
||||||
|
|
||||||
|
class DoubaoEmbeddingProvider(EmbeddingProvider):
|
||||||
|
"""
|
||||||
|
Doubao (Volcengine Ark) multimodal embedding provider.
|
||||||
|
|
||||||
|
Doubao deprecated their OpenAI-compatible /v1/embeddings endpoint and
|
||||||
|
unified everything under /api/v3/embeddings/multimodal, which uses a
|
||||||
|
structured `input: [{type, text|image_url|video_url}, ...]` payload.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
* The endpoint produces ONE embedding per call (input list is multiple
|
||||||
|
modality parts of a single document, not a batch). embed_batch
|
||||||
|
therefore loops per-text — no native batch support.
|
||||||
|
* Native dimensions: 1024 or 2048 (default 1024 to align with other
|
||||||
|
Chinese vendors). No client-side truncation needed.
|
||||||
|
* Auth: Bearer ARK API key.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
api_base: Optional[str] = None,
|
||||||
|
extra_headers: Optional[dict] = None,
|
||||||
|
dimensions: Optional[int] = None,
|
||||||
|
):
|
||||||
|
self.model = model
|
||||||
|
self.api_key = api_key
|
||||||
|
self.api_base = api_base or "https://ark.cn-beijing.volces.com/api/v3"
|
||||||
|
self.extra_headers = extra_headers or {}
|
||||||
|
if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||||
|
raise ValueError("Doubao embedding API key (ark_api_key) is not configured")
|
||||||
|
|
||||||
|
if dimensions in (1024, 2048):
|
||||||
|
self._dimensions = dimensions
|
||||||
|
elif dimensions is None:
|
||||||
|
self._dimensions = 1024
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Doubao embedding dimensions must be 1024 or 2048, got {dimensions}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _call_api(self, text: str) -> List[float]:
|
||||||
|
"""One call → one embedding. multimodal endpoint takes a single
|
||||||
|
document represented as a list of typed parts; we send a single
|
||||||
|
text part."""
|
||||||
|
import requests
|
||||||
|
|
||||||
|
url = f"{self.api_base}/embeddings/multimodal"
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
**self.extra_headers,
|
||||||
|
}
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"input": [{"type": "text", "text": text}],
|
||||||
|
"dimensions": self._dimensions,
|
||||||
|
"encoding_format": "float",
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(url, headers=headers, json=payload, timeout=EMBEDDING_HTTP_TIMEOUT)
|
||||||
|
response.raise_for_status()
|
||||||
|
body = response.json()
|
||||||
|
except requests.exceptions.ConnectionError as e:
|
||||||
|
raise ConnectionError(
|
||||||
|
f"Failed to connect to Doubao embedding API at {url}. "
|
||||||
|
f"Please check network and api_base. Error: {str(e)}"
|
||||||
|
)
|
||||||
|
except requests.exceptions.Timeout as e:
|
||||||
|
raise TimeoutError(f"Doubao embedding API request timed out. Error: {str(e)}")
|
||||||
|
except requests.exceptions.HTTPError as e:
|
||||||
|
if e.response.status_code == 401:
|
||||||
|
raise ValueError("Invalid Doubao (ark) embedding API key")
|
||||||
|
elif e.response.status_code == 429:
|
||||||
|
raise ValueError("Doubao embedding API rate limit exceeded")
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Doubao embedding API request failed: "
|
||||||
|
f"{e.response.status_code} - {e.response.text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Response shape per docs: {"data": {"embedding": [...]}}
|
||||||
|
data = body.get("data")
|
||||||
|
if isinstance(data, dict) and "embedding" in data:
|
||||||
|
return data["embedding"]
|
||||||
|
# Some providers wrap as a list of one — be defensive
|
||||||
|
if isinstance(data, list) and data and "embedding" in data[0]:
|
||||||
|
return data[0]["embedding"]
|
||||||
|
raise ValueError(f"Unexpected Doubao embedding response shape: {body}")
|
||||||
|
|
||||||
|
def embed(self, text: str) -> List[float]:
|
||||||
|
return self._call_api(text)
|
||||||
|
|
||||||
|
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
# Endpoint produces one embedding per call; loop. Order preserved.
|
||||||
|
return [self._call_api(t) for t in texts]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dimensions(self) -> int:
|
||||||
|
return self._dimensions
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingCache:
|
||||||
|
"""In-memory cache for embeddings to avoid recomputation"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.cache = {}
|
||||||
|
|
||||||
|
def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
|
||||||
|
key = self._compute_key(text, provider, model)
|
||||||
|
return self.cache.get(key)
|
||||||
|
|
||||||
|
def put(self, text: str, provider: str, model: str, embedding: List[float]):
|
||||||
|
key = self._compute_key(text, provider, model)
|
||||||
|
self.cache[key] = embedding
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _compute_key(text: str, provider: str, model: str) -> str:
|
||||||
|
content = f"{provider}:{model}:{text}"
|
||||||
|
return hashlib.md5(content.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self.cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def create_embedding_provider(
|
||||||
|
provider: str = "openai",
|
||||||
|
model: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
api_base: Optional[str] = None,
|
||||||
|
extra_headers: Optional[dict] = None,
|
||||||
|
dimensions: Optional[int] = None,
|
||||||
|
) -> EmbeddingProvider:
|
||||||
|
"""
|
||||||
|
Factory function to create an embedding provider.
|
||||||
|
|
||||||
|
Backward compatible: when called with provider in {"openai", "linkai"}
|
||||||
|
and no `dimensions` arg, behaves exactly as before (1536-dim OpenAI).
|
||||||
|
|
||||||
|
New providers ("dashscope", "doubao", "zhipu") require explicit configuration
|
||||||
|
and use the unified 1024-dim defaults from EMBEDDING_VENDORS.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: Vendor key (one of EMBEDDING_VENDORS)
|
||||||
|
model: Model name (uses vendor default if None)
|
||||||
|
api_key: API key (required)
|
||||||
|
api_base: API base URL (uses vendor default if None)
|
||||||
|
extra_headers: Optional extra HTTP headers
|
||||||
|
dimensions: Target output dimension (uses vendor default if None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EmbeddingProvider instance
|
||||||
|
"""
|
||||||
|
meta = EMBEDDING_VENDORS.get(provider)
|
||||||
|
if meta is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported embedding provider: {provider}. "
|
||||||
|
f"Supported: {sorted(EMBEDDING_VENDORS.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Doubao uses a non-OpenAI-compatible multimodal endpoint.
|
||||||
|
if meta.get("provider_class") == "doubao":
|
||||||
|
final_dim = dimensions if (dimensions and dimensions > 0) else meta["default_dimensions"]
|
||||||
|
return DoubaoEmbeddingProvider(
|
||||||
|
model=model or meta["default_model"],
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base or meta["default_base_url"],
|
||||||
|
extra_headers=extra_headers,
|
||||||
|
dimensions=final_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Legacy two-arg call for openai/linkai keeps 1536-dim default behavior
|
||||||
|
# so existing data isn't invalidated.
|
||||||
|
is_legacy_call = (
|
||||||
|
provider in ("openai", "linkai")
|
||||||
|
and dimensions is None
|
||||||
|
)
|
||||||
|
if is_legacy_call:
|
||||||
|
return OpenAIEmbeddingProvider(
|
||||||
|
model=model or "text-embedding-3-small",
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
extra_headers=extra_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
final_dim = dimensions if (dimensions and dimensions > 0) else meta["default_dimensions"]
|
||||||
|
return OpenAIEmbeddingProvider(
|
||||||
|
model=model or meta["default_model"],
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base or meta["default_base_url"],
|
||||||
|
extra_headers=extra_headers,
|
||||||
|
dimensions=final_dim,
|
||||||
|
supports_dim_param=meta["supports_dim_param"],
|
||||||
|
needs_client_truncate=meta["needs_client_truncate"],
|
||||||
|
needs_client_normalize=meta["needs_client_normalize"],
|
||||||
|
query_instruction=meta["query_instruction"],
|
||||||
|
max_batch_size=meta.get("max_batch_size", 256),
|
||||||
|
)
|
||||||
191
agent/memory/embedding/rebuild.py
Normal file
191
agent/memory/embedding/rebuild.py
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
"""
|
||||||
|
Rebuild memory vector index.
|
||||||
|
|
||||||
|
Recommended entry point (in-chat, while agent is running):
|
||||||
|
/memory rebuild-index
|
||||||
|
|
||||||
|
Backward-compatible CLI entry (must run from project root):
|
||||||
|
python -m agent.memory.rebuild_index
|
||||||
|
|
||||||
|
What it does:
|
||||||
|
1. Probes the embedding endpoint with a tiny call to fail fast on
|
||||||
|
bad provider/model/key — before touching the index.
|
||||||
|
2. Clears the SQLite chunks/files tables (workspace markdown stays intact).
|
||||||
|
3. Runs a fresh sync, regenerating embeddings with the currently configured
|
||||||
|
provider/model/dimensions.
|
||||||
|
|
||||||
|
This is the only safe way to switch embedding_provider after the existing
|
||||||
|
index has been populated by a different-dim model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from common.log import logger
|
||||||
|
from common.utils import expand_path
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RebuildResult:
|
||||||
|
"""Outcome of a rebuild_in_process() call"""
|
||||||
|
ok: bool
|
||||||
|
removed: int = 0
|
||||||
|
chunks: int = 0
|
||||||
|
files: int = 0
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
def clear_index(db_path, storage=None) -> int:
|
||||||
|
"""Wipe chunks/files, reset FTS5, and clean up any legacy state file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Path of the index DB (also used to locate the legacy state
|
||||||
|
file for migration cleanup, and — when *storage* is None — to
|
||||||
|
open a fresh connection).
|
||||||
|
storage: Optional pre-opened MemoryStorage. When provided we reuse it
|
||||||
|
so the live connection's triggers stay in sync — opening a second
|
||||||
|
connection would leave the original one's triggers pointing at a
|
||||||
|
DROP'd chunks_fts table.
|
||||||
|
|
||||||
|
We reset (DROP+recreate) chunks_fts because its shadow tables can become
|
||||||
|
inconsistent across rebuild cycles, causing bm25() / ORDER BY rank to
|
||||||
|
raise "database disk image is malformed" even when raw MATCH still works.
|
||||||
|
|
||||||
|
Returns number of chunks removed.
|
||||||
|
"""
|
||||||
|
from agent.memory.embedding.state import cleanup_legacy_state_file
|
||||||
|
from agent.memory.storage import MemoryStorage
|
||||||
|
|
||||||
|
owns_storage = storage is None
|
||||||
|
if owns_storage:
|
||||||
|
storage = MemoryStorage(db_path)
|
||||||
|
try:
|
||||||
|
before = storage.conn.execute("SELECT COUNT(*) FROM chunks").fetchone()[0]
|
||||||
|
storage.conn.execute("DELETE FROM chunks")
|
||||||
|
storage.conn.execute("DELETE FROM files")
|
||||||
|
storage.conn.commit()
|
||||||
|
storage.reset_fts5()
|
||||||
|
finally:
|
||||||
|
if owns_storage:
|
||||||
|
storage.close()
|
||||||
|
|
||||||
|
cleanup_legacy_state_file(db_path)
|
||||||
|
return int(before)
|
||||||
|
|
||||||
|
|
||||||
|
def rebuild_in_process(memory_manager) -> RebuildResult:
|
||||||
|
"""
|
||||||
|
Rebuild the index using an existing, fully-initialized MemoryManager.
|
||||||
|
|
||||||
|
Used by the in-chat /memory rebuild-index command. The caller already has
|
||||||
|
config loaded, embedding_provider built, and (optionally) the agent
|
||||||
|
running, so we only need to:
|
||||||
|
1. Clear chunks/files + state on the manager's storage.
|
||||||
|
2. Re-sync (force=True).
|
||||||
|
|
||||||
|
NOTE: caller must ensure memory_manager.embedding_provider is set, otherwise
|
||||||
|
sync() will silently skip embedding generation.
|
||||||
|
"""
|
||||||
|
if memory_manager is None:
|
||||||
|
return RebuildResult(ok=False, error="memory_manager is None")
|
||||||
|
if memory_manager.embedding_provider is None:
|
||||||
|
return RebuildResult(ok=False, error="embedding_provider is not initialized")
|
||||||
|
|
||||||
|
# Probe the embedding endpoint BEFORE clearing the index. A bad
|
||||||
|
# provider/model/key would otherwise leave the user with an empty index
|
||||||
|
# that not even keyword search can serve.
|
||||||
|
try:
|
||||||
|
memory_manager.embedding_provider.embed_query("ping")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[RebuildIndex] embedding probe failed, aborting rebuild: {e}")
|
||||||
|
return RebuildResult(ok=False, error=f"embedding endpoint not reachable: {e}")
|
||||||
|
|
||||||
|
db_path = memory_manager.config.get_db_path()
|
||||||
|
try:
|
||||||
|
removed = clear_index(db_path, storage=memory_manager.storage)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("[RebuildIndex] clear_index failed")
|
||||||
|
return RebuildResult(ok=False, error=f"clear failed: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
asyncio.run(memory_manager.sync(force=True))
|
||||||
|
except RuntimeError:
|
||||||
|
# Already inside a running event loop (rare in chat handler thread).
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(memory_manager.sync(force=True))
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("[RebuildIndex] sync failed")
|
||||||
|
return RebuildResult(ok=False, removed=removed, error=f"re-embed failed: {e}")
|
||||||
|
|
||||||
|
stats = memory_manager.storage.get_stats()
|
||||||
|
chunks = int(stats.get("chunks", 0))
|
||||||
|
embedded = int(stats.get("embedded", 0))
|
||||||
|
|
||||||
|
# sync() degrades to "no embeddings" on batch failure so keyword search
|
||||||
|
# still works at startup — but in a /rebuild-index request the user
|
||||||
|
# explicitly asked for vectors. Surface that as a failure.
|
||||||
|
if chunks > 0 and embedded == 0:
|
||||||
|
return RebuildResult(
|
||||||
|
ok=False,
|
||||||
|
removed=removed,
|
||||||
|
chunks=chunks,
|
||||||
|
files=int(stats.get("files", 0)),
|
||||||
|
error=(
|
||||||
|
"embedding API failed during sync; index now has chunks but no "
|
||||||
|
"vectors. Check embedding provider/model/key and retry."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return RebuildResult(
|
||||||
|
ok=True,
|
||||||
|
removed=removed,
|
||||||
|
chunks=chunks,
|
||||||
|
files=int(stats.get("files", 0)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
"""Standalone CLI entry. Must be run from project root (relative config path)."""
|
||||||
|
from config import conf, load_config
|
||||||
|
from agent.memory import MemoryConfig, MemoryManager
|
||||||
|
|
||||||
|
load_config()
|
||||||
|
|
||||||
|
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||||
|
memory_config = MemoryConfig(workspace_root=workspace_root)
|
||||||
|
|
||||||
|
logger.info(f"[RebuildIndex] Workspace: {workspace_root}")
|
||||||
|
logger.info(f"[RebuildIndex] Index db: {memory_config.get_db_path()}")
|
||||||
|
|
||||||
|
from bridge.agent_initializer import AgentInitializer
|
||||||
|
|
||||||
|
initializer = AgentInitializer(bridge=None, agent_bridge=None)
|
||||||
|
embedding_provider = initializer._init_embedding_provider(memory_config, session_id=None)
|
||||||
|
if embedding_provider is None:
|
||||||
|
logger.error(
|
||||||
|
"[RebuildIndex] No embedding provider could be initialized. "
|
||||||
|
"Check your config.json. Aborting rebuild."
|
||||||
|
)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
|
||||||
|
result = rebuild_in_process(manager)
|
||||||
|
if not result.ok:
|
||||||
|
logger.error(f"[RebuildIndex] {result.error}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[RebuildIndex] Done. removed={result.removed}, "
|
||||||
|
f"chunks={result.chunks}, files={result.files}"
|
||||||
|
)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
47
agent/memory/embedding/state.py
Normal file
47
agent/memory/embedding/state.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
"""
|
||||||
|
Embedding-related index utilities.
|
||||||
|
|
||||||
|
We don't keep a sidecar state file — the SQLite index is the source of truth
|
||||||
|
and config.json is the source of intent. The two functions below are the
|
||||||
|
only things needing on-disk awareness:
|
||||||
|
|
||||||
|
detect_index_dim : read the dim of stored vectors (display-only)
|
||||||
|
cleanup_legacy_state_file: remove old embedding_state.json from earlier
|
||||||
|
versions; safe no-op when absent.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
PathLike = Union[str, os.PathLike]
|
||||||
|
|
||||||
|
|
||||||
|
def detect_index_dim(storage) -> Optional[int]:
|
||||||
|
"""Return the dim of the first stored embedding, or None if the index
|
||||||
|
has no embeddings. Used by /memory status."""
|
||||||
|
try:
|
||||||
|
row = storage.conn.execute(
|
||||||
|
"SELECT embedding FROM chunks WHERE embedding IS NOT NULL LIMIT 1"
|
||||||
|
).fetchone()
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
if not row or not row["embedding"]:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
emb = json.loads(row["embedding"])
|
||||||
|
return len(emb) if isinstance(emb, list) else None
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_legacy_state_file(db_path: PathLike) -> None:
|
||||||
|
"""Remove old embedding_state.json files from earlier versions.
|
||||||
|
Safe to call repeatedly; no-op if the file is absent."""
|
||||||
|
legacy = Path(db_path).parent / "embedding_state.json"
|
||||||
|
try:
|
||||||
|
legacy.unlink(missing_ok=True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
@@ -13,7 +13,7 @@ from datetime import datetime, timedelta
|
|||||||
from agent.memory.config import MemoryConfig, get_default_memory_config
|
from agent.memory.config import MemoryConfig, get_default_memory_config
|
||||||
from agent.memory.storage import MemoryStorage, MemoryChunk, SearchResult
|
from agent.memory.storage import MemoryStorage, MemoryChunk, SearchResult
|
||||||
from agent.memory.chunker import TextChunker
|
from agent.memory.chunker import TextChunker
|
||||||
from agent.memory.embedding import create_embedding_provider, EmbeddingProvider
|
from agent.memory.embedding import EmbeddingProvider
|
||||||
from agent.memory.summarizer import MemoryFlushManager, create_memory_files_if_needed
|
from agent.memory.summarizer import MemoryFlushManager, create_memory_files_if_needed
|
||||||
|
|
||||||
|
|
||||||
@@ -50,49 +50,17 @@ class MemoryManager:
|
|||||||
overlap_tokens=self.config.chunk_overlap_tokens
|
overlap_tokens=self.config.chunk_overlap_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize embedding provider (optional, prefer OpenAI, fallback to LinkAI)
|
# Embedding provider is owned by the caller (agent_initializer is the
|
||||||
self.embedding_provider = None
|
# canonical entry point and handles legacy/explicit + state validation).
|
||||||
if embedding_provider:
|
# When None is passed, memory degrades to keyword-only search instead
|
||||||
self.embedding_provider = embedding_provider
|
# of silently re-initializing a vendor here, which would bypass the
|
||||||
else:
|
# caller's state checks and risk corrupting the index.
|
||||||
# Try OpenAI first
|
self.embedding_provider = embedding_provider
|
||||||
try:
|
if self.embedding_provider is None:
|
||||||
api_key = os.environ.get('OPENAI_API_KEY')
|
from common.log import logger
|
||||||
api_base = os.environ.get('OPENAI_API_BASE')
|
logger.info(
|
||||||
if api_key:
|
"[MemoryManager] No embedding provider; memory will use keyword search only"
|
||||||
self.embedding_provider = create_embedding_provider(
|
)
|
||||||
provider="openai",
|
|
||||||
model=self.config.embedding_model,
|
|
||||||
api_key=api_key,
|
|
||||||
api_base=api_base
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
from common.log import logger
|
|
||||||
logger.warning(f"[MemoryManager] OpenAI embedding failed: {e}")
|
|
||||||
|
|
||||||
# Fallback to LinkAI
|
|
||||||
if self.embedding_provider is None:
|
|
||||||
try:
|
|
||||||
linkai_key = os.environ.get('LINKAI_API_KEY')
|
|
||||||
linkai_base = os.environ.get('LINKAI_API_BASE', 'https://api.link-ai.tech')
|
|
||||||
if linkai_key:
|
|
||||||
from common.utils import get_cloud_headers
|
|
||||||
cloud_headers = get_cloud_headers(linkai_key)
|
|
||||||
cloud_headers.pop("Authorization", None)
|
|
||||||
self.embedding_provider = create_embedding_provider(
|
|
||||||
provider="linkai",
|
|
||||||
model=self.config.embedding_model,
|
|
||||||
api_key=linkai_key,
|
|
||||||
api_base=f"{linkai_base}/v1",
|
|
||||||
extra_headers=cloud_headers,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
from common.log import logger
|
|
||||||
logger.warning(f"[MemoryManager] LinkAI embedding failed: {e}")
|
|
||||||
|
|
||||||
if self.embedding_provider is None:
|
|
||||||
from common.log import logger
|
|
||||||
logger.info(f"[MemoryManager] Memory will work with keyword search only (no vector search)")
|
|
||||||
|
|
||||||
# Initialize memory flush manager
|
# Initialize memory flush manager
|
||||||
workspace_dir = self.config.get_workspace()
|
workspace_dir = self.config.get_workspace()
|
||||||
@@ -153,12 +121,14 @@ class MemoryManager:
|
|||||||
if self.config.sync_on_search and self._dirty:
|
if self.config.sync_on_search and self._dirty:
|
||||||
await self.sync()
|
await self.sync()
|
||||||
|
|
||||||
# Perform vector search (if embedding provider available)
|
from common.log import logger
|
||||||
|
|
||||||
|
# Perform vector search (if embedding provider available).
|
||||||
|
# Failures degrade silently to keyword-only — no exception is raised.
|
||||||
vector_results = []
|
vector_results = []
|
||||||
if self.embedding_provider:
|
if self.embedding_provider:
|
||||||
try:
|
try:
|
||||||
from common.log import logger
|
query_embedding = self.embedding_provider.embed_query(query)
|
||||||
query_embedding = self.embedding_provider.embed(query)
|
|
||||||
vector_results = self.storage.search_vector(
|
vector_results = self.storage.search_vector(
|
||||||
query_embedding=query_embedding,
|
query_embedding=query_embedding,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -167,19 +137,19 @@ class MemoryManager:
|
|||||||
)
|
)
|
||||||
logger.info(f"[MemoryManager] Vector search found {len(vector_results)} results for query: {query}")
|
logger.info(f"[MemoryManager] Vector search found {len(vector_results)} results for query: {query}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
from common.log import logger
|
logger.error(
|
||||||
logger.warning(f"[MemoryManager] Vector search failed: {e}")
|
f"[MemoryManager] Vector search failed, falling back to keyword-only: {e}"
|
||||||
|
)
|
||||||
# Perform keyword search
|
|
||||||
|
# Perform keyword search (also runs as fallback when vector failed)
|
||||||
keyword_results = self.storage.search_keyword(
|
keyword_results = self.storage.search_keyword(
|
||||||
query=query,
|
query=query,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
scopes=scopes,
|
scopes=scopes,
|
||||||
limit=max_results * 2
|
limit=max_results * 2
|
||||||
)
|
)
|
||||||
from common.log import logger
|
|
||||||
logger.info(f"[MemoryManager] Keyword search found {len(keyword_results)} results for query: {query}")
|
logger.info(f"[MemoryManager] Keyword search found {len(keyword_results)} results for query: {query}")
|
||||||
|
|
||||||
# Merge results
|
# Merge results
|
||||||
merged = self._merge_results(
|
merged = self._merge_results(
|
||||||
vector_results,
|
vector_results,
|
||||||
@@ -187,7 +157,7 @@ class MemoryManager:
|
|||||||
self.config.vector_weight,
|
self.config.vector_weight,
|
||||||
self.config.keyword_weight
|
self.config.keyword_weight
|
||||||
)
|
)
|
||||||
|
|
||||||
# Filter by min score and limit
|
# Filter by min score and limit
|
||||||
filtered = [r for r in merged if r.score >= min_score]
|
filtered = [r for r in merged if r.score >= min_score]
|
||||||
return filtered[:max_results]
|
return filtered[:max_results]
|
||||||
@@ -269,132 +239,157 @@ class MemoryManager:
|
|||||||
|
|
||||||
async def sync(self, force: bool = False):
|
async def sync(self, force: bool = False):
|
||||||
"""
|
"""
|
||||||
Synchronize memory from files
|
Synchronize memory from files.
|
||||||
|
|
||||||
|
Two-pass design to amortize embedding HTTP cost:
|
||||||
|
1. Walk all files, chunk those whose hash changed, collect pending
|
||||||
|
chunks across files. No embedding calls yet.
|
||||||
|
2. Run a single embed_batch over the union of pending chunks (the
|
||||||
|
provider auto-paginates by vendor cap), then persist per-file.
|
||||||
|
|
||||||
|
For workspaces with many small files (101 files / ~1 chunk each), this
|
||||||
|
cuts ~100 HTTP calls down to ~ceil(total_chunks / vendor_cap).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
force: Force full reindex
|
force: Force full reindex
|
||||||
"""
|
"""
|
||||||
memory_dir = self.config.get_memory_dir()
|
memory_dir = self.config.get_memory_dir()
|
||||||
workspace_dir = self.config.get_workspace()
|
workspace_dir = self.config.get_workspace()
|
||||||
|
|
||||||
# Scan MEMORY.md (workspace root)
|
files_to_scan: List[tuple] = [] # (file_path, source, scope, user_id)
|
||||||
|
|
||||||
memory_file = Path(workspace_dir) / "MEMORY.md"
|
memory_file = Path(workspace_dir) / "MEMORY.md"
|
||||||
if memory_file.exists():
|
if memory_file.exists():
|
||||||
await self._sync_file(memory_file, "memory", "shared", None)
|
files_to_scan.append((memory_file, "memory", "shared", None))
|
||||||
|
|
||||||
# Scan memory directory (including daily summaries)
|
|
||||||
if memory_dir.exists():
|
if memory_dir.exists():
|
||||||
for file_path in memory_dir.rglob("*.md"):
|
for file_path in memory_dir.rglob("*.md"):
|
||||||
# Skip hidden directories (e.g. .dreams/)
|
|
||||||
if any(part.startswith('.') for part in file_path.relative_to(workspace_dir).parts):
|
if any(part.startswith('.') for part in file_path.relative_to(workspace_dir).parts):
|
||||||
continue
|
continue
|
||||||
|
rel_parts = file_path.relative_to(workspace_dir).parts
|
||||||
# Determine scope and user_id from path
|
if "daily" in rel_parts:
|
||||||
rel_path = file_path.relative_to(workspace_dir)
|
if "users" in rel_parts or len(rel_parts) > 3:
|
||||||
parts = rel_path.parts
|
user_idx = rel_parts.index("daily") + 1
|
||||||
|
user_id = rel_parts[user_idx] if user_idx < len(rel_parts) else None
|
||||||
# Check if it's in daily summary directory
|
|
||||||
if "daily" in parts:
|
|
||||||
# Daily summary files
|
|
||||||
if "users" in parts or len(parts) > 3:
|
|
||||||
# User-scoped daily summary: memory/daily/{user_id}/2024-01-29.md
|
|
||||||
user_idx = parts.index("daily") + 1
|
|
||||||
user_id = parts[user_idx] if user_idx < len(parts) else None
|
|
||||||
scope = "user"
|
scope = "user"
|
||||||
else:
|
else:
|
||||||
# Shared daily summary: memory/daily/2024-01-29.md
|
|
||||||
user_id = None
|
user_id = None
|
||||||
scope = "shared"
|
scope = "shared"
|
||||||
elif "users" in parts:
|
elif "users" in rel_parts:
|
||||||
# User-scoped memory
|
user_idx = rel_parts.index("users") + 1
|
||||||
user_idx = parts.index("users") + 1
|
user_id = rel_parts[user_idx] if user_idx < len(rel_parts) else None
|
||||||
user_id = parts[user_idx] if user_idx < len(parts) else None
|
|
||||||
scope = "user"
|
scope = "user"
|
||||||
else:
|
else:
|
||||||
# Shared memory
|
|
||||||
user_id = None
|
user_id = None
|
||||||
scope = "shared"
|
scope = "shared"
|
||||||
|
files_to_scan.append((file_path, "memory", scope, user_id))
|
||||||
await self._sync_file(file_path, "memory", scope, user_id)
|
|
||||||
|
|
||||||
# Scan knowledge directory (structured knowledge wiki)
|
|
||||||
from config import conf
|
from config import conf
|
||||||
if conf().get("knowledge", True):
|
if conf().get("knowledge", True):
|
||||||
knowledge_dir = Path(workspace_dir) / "knowledge"
|
knowledge_dir = Path(workspace_dir) / "knowledge"
|
||||||
if knowledge_dir.exists():
|
if knowledge_dir.exists():
|
||||||
for file_path in knowledge_dir.rglob("*.md"):
|
for file_path in knowledge_dir.rglob("*.md"):
|
||||||
await self._sync_file(file_path, "knowledge", "shared", None)
|
files_to_scan.append((file_path, "knowledge", "shared", None))
|
||||||
|
|
||||||
self._dirty = False
|
# Pass 1: inline chunking + change detection. Inlined (instead of
|
||||||
|
# calling self._prepare_file_for_sync) so this method does not depend
|
||||||
async def _sync_file(
|
# on any sibling helpers — keeps it robust against partial reloads
|
||||||
self,
|
# where the class object is older than the method's source.
|
||||||
file_path: Path,
|
pending: List[Dict[str, Any]] = []
|
||||||
source: str,
|
workspace_dir_path = self.config.get_workspace()
|
||||||
scope: str,
|
for file_path, source, scope, user_id in files_to_scan:
|
||||||
user_id: Optional[str]
|
try:
|
||||||
):
|
content = file_path.read_text(encoding='utf-8')
|
||||||
"""Sync a single file"""
|
except Exception:
|
||||||
# Compute file hash
|
continue
|
||||||
content = file_path.read_text(encoding='utf-8')
|
file_hash = MemoryStorage.compute_hash(content)
|
||||||
file_hash = MemoryStorage.compute_hash(content)
|
rel_path = str(file_path.relative_to(workspace_dir_path))
|
||||||
|
if self.storage.get_file_hash(rel_path) == file_hash:
|
||||||
# Get relative path
|
continue
|
||||||
workspace_dir = self.config.get_workspace()
|
chunks = self.chunker.chunk_text(content)
|
||||||
rel_path = str(file_path.relative_to(workspace_dir))
|
if not chunks:
|
||||||
|
continue
|
||||||
# Check if file changed
|
pending.append({
|
||||||
stored_hash = self.storage.get_file_hash(rel_path)
|
"file_path": file_path,
|
||||||
if stored_hash == file_hash:
|
"rel_path": rel_path,
|
||||||
return # No changes
|
"source": source,
|
||||||
|
"scope": scope,
|
||||||
# Delete old chunks
|
"user_id": user_id,
|
||||||
self.storage.delete_by_path(rel_path)
|
"file_hash": file_hash,
|
||||||
|
"chunks": chunks,
|
||||||
# Chunk and embed
|
"texts": [c.text for c in chunks],
|
||||||
chunks = self.chunker.chunk_text(content)
|
})
|
||||||
if not chunks:
|
|
||||||
|
if not pending:
|
||||||
|
self._dirty = False
|
||||||
return
|
return
|
||||||
|
|
||||||
texts = [chunk.text for chunk in chunks]
|
# Pass 2: single batched embed across all pending chunks.
|
||||||
if self.embedding_provider:
|
# CRITICAL: never touch the index until we hold valid embeddings.
|
||||||
embeddings = self.embedding_provider.embed_batch(texts)
|
# If embed_batch fails, leave the existing index intact (chunks +
|
||||||
|
# file_hash) so the next sync will retry the same files. Writing
|
||||||
|
# NULL embeddings + updating file_hash here would mark the file as
|
||||||
|
# "successfully synced" and silently strand it without vectors.
|
||||||
|
all_texts: List[str] = []
|
||||||
|
for entry in pending:
|
||||||
|
all_texts.extend(entry["texts"])
|
||||||
|
|
||||||
|
if not self.embedding_provider:
|
||||||
|
# No provider configured at all (legacy keyword-only). Persist
|
||||||
|
# chunks without embeddings — this is the user's intent.
|
||||||
|
all_embeddings: List[Optional[List[float]]] = [None] * len(all_texts)
|
||||||
else:
|
else:
|
||||||
embeddings = [None] * len(texts)
|
try:
|
||||||
|
all_embeddings = self.embedding_provider.embed_batch(all_texts)
|
||||||
# Create memory chunks
|
except Exception as e:
|
||||||
memory_chunks = []
|
from common.log import logger
|
||||||
for chunk, embedding in zip(chunks, embeddings):
|
logger.error(
|
||||||
chunk_id = self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line)
|
f"[MemoryManager] Batch embedding failed for {len(all_texts)} "
|
||||||
chunk_hash = MemoryStorage.compute_hash(chunk.text)
|
f"chunks across {len(pending)} files: {e}. "
|
||||||
|
f"Index left untouched; will retry on next sync."
|
||||||
memory_chunks.append(MemoryChunk(
|
)
|
||||||
id=chunk_id,
|
# Bail before touching storage. self._dirty stays True so
|
||||||
user_id=user_id,
|
# callers know there is pending work.
|
||||||
scope=scope,
|
return
|
||||||
source=source,
|
|
||||||
|
# Pass 3: inline persist — same self-contained reasoning as Pass 1.
|
||||||
|
cursor = 0
|
||||||
|
for entry in pending:
|
||||||
|
n = len(entry["texts"])
|
||||||
|
entry_embeddings = all_embeddings[cursor:cursor + n]
|
||||||
|
cursor += n
|
||||||
|
|
||||||
|
rel_path = entry["rel_path"]
|
||||||
|
self.storage.delete_by_path(rel_path)
|
||||||
|
memory_chunks = []
|
||||||
|
for chunk, embedding in zip(entry["chunks"], entry_embeddings):
|
||||||
|
chunk_id = self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line)
|
||||||
|
chunk_hash = MemoryStorage.compute_hash(chunk.text)
|
||||||
|
memory_chunks.append(MemoryChunk(
|
||||||
|
id=chunk_id,
|
||||||
|
user_id=entry["user_id"],
|
||||||
|
scope=entry["scope"],
|
||||||
|
source=entry["source"],
|
||||||
|
path=rel_path,
|
||||||
|
start_line=chunk.start_line,
|
||||||
|
end_line=chunk.end_line,
|
||||||
|
text=chunk.text,
|
||||||
|
embedding=embedding,
|
||||||
|
hash=chunk_hash,
|
||||||
|
metadata=None,
|
||||||
|
))
|
||||||
|
self.storage.save_chunks_batch(memory_chunks)
|
||||||
|
stat = entry["file_path"].stat()
|
||||||
|
self.storage.update_file_metadata(
|
||||||
path=rel_path,
|
path=rel_path,
|
||||||
start_line=chunk.start_line,
|
source=entry["source"],
|
||||||
end_line=chunk.end_line,
|
file_hash=entry["file_hash"],
|
||||||
text=chunk.text,
|
mtime=int(stat.st_mtime),
|
||||||
embedding=embedding,
|
size=stat.st_size,
|
||||||
hash=chunk_hash,
|
)
|
||||||
metadata=None
|
|
||||||
))
|
self._dirty = False
|
||||||
|
|
||||||
# Save
|
|
||||||
self.storage.save_chunks_batch(memory_chunks)
|
|
||||||
|
|
||||||
# Update file metadata
|
|
||||||
stat = file_path.stat()
|
|
||||||
self.storage.update_file_metadata(
|
|
||||||
path=rel_path,
|
|
||||||
source=source,
|
|
||||||
file_hash=file_hash,
|
|
||||||
mtime=int(stat.st_mtime),
|
|
||||||
size=stat.st_size
|
|
||||||
)
|
|
||||||
|
|
||||||
def flush_memory(
|
def flush_memory(
|
||||||
self,
|
self,
|
||||||
messages: list,
|
messages: list,
|
||||||
|
|||||||
14
agent/memory/rebuild_index.py
Normal file
14
agent/memory/rebuild_index.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
Backward-compatible shim for the legacy entry point:
|
||||||
|
python -m agent.memory.rebuild_index
|
||||||
|
|
||||||
|
The implementation now lives in agent.memory.embedding.rebuild.
|
||||||
|
Prefer using `/memory rebuild-index` in chat going forward.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from agent.memory.embedding.rebuild import main
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.exit(main())
|
||||||
@@ -144,45 +144,37 @@ class MemoryStorage:
|
|||||||
ON chunks(path, hash)
|
ON chunks(path, hash)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
# Create FTS5 virtual table for keyword search (only if supported)
|
# Create FTS5 virtual table + triggers (only if supported).
|
||||||
|
# Self-heal: if the previous process crashed mid-rebuild and left
|
||||||
|
# triggers pointing at a missing chunks_fts (or vice versa), wipe
|
||||||
|
# both sides and recreate cleanly. Otherwise next chunks INSERT
|
||||||
|
# will fail with "no such table: chunks_fts".
|
||||||
if self.fts5_available:
|
if self.fts5_available:
|
||||||
# Use default unicode61 tokenizer (stable and compatible)
|
if self._fts5_state_inconsistent():
|
||||||
# For CJK support, we'll use LIKE queries as fallback
|
from common.log import logger
|
||||||
self.conn.execute("""
|
logger.warning(
|
||||||
CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
|
"[MemoryStorage] FTS5 state inconsistent (triggers/table mismatch). "
|
||||||
text,
|
"Resetting chunks_fts to recover."
|
||||||
id UNINDEXED,
|
|
||||||
user_id UNINDEXED,
|
|
||||||
path UNINDEXED,
|
|
||||||
source UNINDEXED,
|
|
||||||
scope UNINDEXED,
|
|
||||||
content='chunks',
|
|
||||||
content_rowid='rowid'
|
|
||||||
)
|
)
|
||||||
""")
|
self.conn.execute("DROP TRIGGER IF EXISTS chunks_ai")
|
||||||
|
self.conn.execute("DROP TRIGGER IF EXISTS chunks_ad")
|
||||||
# Create triggers to keep FTS in sync
|
self.conn.execute("DROP TRIGGER IF EXISTS chunks_au")
|
||||||
self.conn.execute("""
|
self.conn.execute("DROP TABLE IF EXISTS chunks_fts")
|
||||||
CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN
|
self.conn.commit()
|
||||||
INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
|
self._create_fts5_objects()
|
||||||
VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
|
|
||||||
END
|
# Probe FTS5 shadow tables. The schema may be intact but the
|
||||||
""")
|
# internal _data/_idx/_docsize blob can still be corrupt — that
|
||||||
|
# surfaces as "database disk image is malformed" on bm25 / MATCH.
|
||||||
self.conn.execute("""
|
# We rebuild from the chunks table when that happens; data isn't
|
||||||
CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN
|
# lost because chunks (the content table) is the source of truth.
|
||||||
DELETE FROM chunks_fts WHERE rowid = old.rowid;
|
if self._fts5_shadow_corrupt():
|
||||||
END
|
from common.log import logger
|
||||||
""")
|
logger.warning(
|
||||||
|
"[MemoryStorage] FTS5 shadow tables corrupt; rebuilding from chunks."
|
||||||
self.conn.execute("""
|
)
|
||||||
CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN
|
self._rebuild_fts5_from_chunks()
|
||||||
UPDATE chunks_fts SET text = new.text, id = new.id,
|
|
||||||
user_id = new.user_id, path = new.path, source = new.source, scope = new.scope
|
|
||||||
WHERE rowid = new.rowid;
|
|
||||||
END
|
|
||||||
""")
|
|
||||||
|
|
||||||
# Create files metadata table
|
# Create files metadata table
|
||||||
self.conn.execute("""
|
self.conn.execute("""
|
||||||
CREATE TABLE IF NOT EXISTS files (
|
CREATE TABLE IF NOT EXISTS files (
|
||||||
@@ -196,7 +188,116 @@ class MemoryStorage:
|
|||||||
""")
|
""")
|
||||||
|
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
|
|
||||||
|
def _fts5_state_inconsistent(self) -> bool:
|
||||||
|
"""Detect a half-broken FTS5 setup (e.g. trigger exists but table doesn't)."""
|
||||||
|
try:
|
||||||
|
row = self.conn.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' AND name='chunks_fts'"
|
||||||
|
).fetchone()
|
||||||
|
table_exists = row is not None
|
||||||
|
row = self.conn.execute(
|
||||||
|
"SELECT COUNT(*) FROM sqlite_master WHERE type='trigger' "
|
||||||
|
"AND name IN ('chunks_ai','chunks_ad','chunks_au')"
|
||||||
|
).fetchone()
|
||||||
|
trigger_count = int(row[0]) if row else 0
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
# Healthy = both present (3 triggers + table) or both absent.
|
||||||
|
return table_exists != (trigger_count > 0)
|
||||||
|
|
||||||
|
def _create_fts5_objects(self):
|
||||||
|
"""Create chunks_fts virtual table and the 3 sync triggers.
|
||||||
|
|
||||||
|
Idempotent: uses IF NOT EXISTS. Caller must hold self.conn.
|
||||||
|
"""
|
||||||
|
self.conn.execute("""
|
||||||
|
CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
|
||||||
|
text,
|
||||||
|
id UNINDEXED,
|
||||||
|
user_id UNINDEXED,
|
||||||
|
path UNINDEXED,
|
||||||
|
source UNINDEXED,
|
||||||
|
scope UNINDEXED,
|
||||||
|
content='chunks',
|
||||||
|
content_rowid='rowid'
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
self.conn.execute("""
|
||||||
|
CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN
|
||||||
|
INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
|
||||||
|
VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
|
||||||
|
END
|
||||||
|
""")
|
||||||
|
self.conn.execute("""
|
||||||
|
CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN
|
||||||
|
DELETE FROM chunks_fts WHERE rowid = old.rowid;
|
||||||
|
END
|
||||||
|
""")
|
||||||
|
self.conn.execute("""
|
||||||
|
CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN
|
||||||
|
UPDATE chunks_fts SET text = new.text, id = new.id,
|
||||||
|
user_id = new.user_id, path = new.path,
|
||||||
|
source = new.source, scope = new.scope
|
||||||
|
WHERE rowid = new.rowid;
|
||||||
|
END
|
||||||
|
""")
|
||||||
|
|
||||||
|
def reset_fts5(self):
|
||||||
|
"""Drop and recreate chunks_fts + triggers in one transaction.
|
||||||
|
|
||||||
|
Used by rebuild_index to recover from FTS5 shadow-table corruption
|
||||||
|
(bm25/ORDER BY rank may raise "database disk image is malformed"
|
||||||
|
even when raw MATCH still works).
|
||||||
|
|
||||||
|
Triggers must be dropped first; otherwise the next chunks INSERT/DELETE
|
||||||
|
on the existing connection will hit "no such table: chunks_fts".
|
||||||
|
"""
|
||||||
|
if not self.fts5_available:
|
||||||
|
return
|
||||||
|
self.conn.execute("DROP TRIGGER IF EXISTS chunks_ai")
|
||||||
|
self.conn.execute("DROP TRIGGER IF EXISTS chunks_ad")
|
||||||
|
self.conn.execute("DROP TRIGGER IF EXISTS chunks_au")
|
||||||
|
self.conn.execute("DROP TABLE IF EXISTS chunks_fts")
|
||||||
|
self._create_fts5_objects()
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
|
def _fts5_shadow_corrupt(self) -> bool:
|
||||||
|
"""Probe whether bm25 over chunks_fts errors out at startup.
|
||||||
|
|
||||||
|
Schema (table + triggers) can be intact while the underlying
|
||||||
|
FTS5 shadow blobs are malformed — typically because the previous
|
||||||
|
process crashed mid-write or wrote with a different SQLite build.
|
||||||
|
A cheap MATCH probe surfaces it immediately."""
|
||||||
|
try:
|
||||||
|
self.conn.execute(
|
||||||
|
"SELECT bm25(chunks_fts) FROM chunks_fts WHERE chunks_fts MATCH 'a' LIMIT 1"
|
||||||
|
).fetchone()
|
||||||
|
return False
|
||||||
|
except sqlite3.DatabaseError as e:
|
||||||
|
msg = str(e).lower()
|
||||||
|
return "malformed" in msg or "corrupt" in msg
|
||||||
|
except Exception:
|
||||||
|
# Any other error (e.g. table missing) is handled by the
|
||||||
|
# state-inconsistent path; treat as healthy here.
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _rebuild_fts5_from_chunks(self):
|
||||||
|
"""Drop FTS5, recreate it, then INSERT every row from chunks.
|
||||||
|
|
||||||
|
Safe data-wise: chunks (the content table) is the source of truth.
|
||||||
|
Done in one transaction so a crash leaves either fully old or fully
|
||||||
|
new state, not a partial rebuild.
|
||||||
|
"""
|
||||||
|
# Reset schema first; this clears any malformed shadow blobs.
|
||||||
|
self.reset_fts5()
|
||||||
|
# Re-feed content. Triggers handle future writes automatically.
|
||||||
|
self.conn.execute("""
|
||||||
|
INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
|
||||||
|
SELECT rowid, text, id, user_id, path, source, scope FROM chunks
|
||||||
|
""")
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
def save_chunk(self, chunk: MemoryChunk):
|
def save_chunk(self, chunk: MemoryChunk):
|
||||||
"""Save a memory chunk"""
|
"""Save a memory chunk"""
|
||||||
self.conn.execute("""
|
self.conn.execute("""
|
||||||
@@ -283,13 +384,26 @@ class MemoryStorage:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
rows = self.conn.execute(query, params).fetchall()
|
rows = self.conn.execute(query, params).fetchall()
|
||||||
|
|
||||||
# Calculate cosine similarity
|
# Calculate cosine similarity. We probe the first row's dim to fail
|
||||||
|
# loudly on a query/index dim mismatch — otherwise every doc would
|
||||||
|
# score 0 silently, leaving the user wondering why search broke.
|
||||||
results = []
|
results = []
|
||||||
|
query_dim = len(query_embedding)
|
||||||
|
if rows:
|
||||||
|
first = json.loads(rows[0]['embedding'])
|
||||||
|
if isinstance(first, list) and len(first) != query_dim:
|
||||||
|
raise ValueError(
|
||||||
|
f"Embedding dim mismatch: query is {query_dim}-dim but "
|
||||||
|
f"index stores {len(first)}-dim vectors. The configured "
|
||||||
|
f"embedding model differs from the one that built the "
|
||||||
|
f"index — run /memory rebuild-index to re-embed."
|
||||||
|
)
|
||||||
|
|
||||||
for row in rows:
|
for row in rows:
|
||||||
embedding = json.loads(row['embedding'])
|
embedding = json.loads(row['embedding'])
|
||||||
similarity = self._cosine_similarity(query_embedding, embedding)
|
similarity = self._cosine_similarity(query_embedding, embedding)
|
||||||
|
|
||||||
if similarity > 0:
|
if similarity > 0:
|
||||||
results.append((similarity, row))
|
results.append((similarity, row))
|
||||||
|
|
||||||
@@ -319,27 +433,24 @@ class MemoryStorage:
|
|||||||
) -> List[SearchResult]:
|
) -> List[SearchResult]:
|
||||||
"""
|
"""
|
||||||
Keyword search using FTS5 + LIKE fallback
|
Keyword search using FTS5 + LIKE fallback
|
||||||
|
|
||||||
Strategy:
|
Strategy:
|
||||||
1. If FTS5 available: Try FTS5 search first (good for English and word-based languages)
|
1. If FTS5 available and healthy: try FTS5 first
|
||||||
2. If no FTS5 or no results and query contains CJK: Use LIKE search
|
2. Always fall back to LIKE for CJK queries
|
||||||
|
3. If FTS5 fails OR returns empty for non-CJK, also try LIKE so a
|
||||||
|
broken FTS5 shadow table doesn't silently kill keyword search.
|
||||||
"""
|
"""
|
||||||
if scopes is None:
|
if scopes is None:
|
||||||
scopes = ["shared"]
|
scopes = ["shared"]
|
||||||
if user_id:
|
if user_id:
|
||||||
scopes.append("user")
|
scopes.append("user")
|
||||||
|
|
||||||
# Try FTS5 search first (if available)
|
|
||||||
if self.fts5_available:
|
if self.fts5_available:
|
||||||
fts_results = self._search_fts5(query, user_id, scopes, limit)
|
fts_results = self._search_fts5(query, user_id, scopes, limit)
|
||||||
if fts_results:
|
if fts_results:
|
||||||
return fts_results
|
return fts_results
|
||||||
|
|
||||||
# Fallback to LIKE search (always for CJK, or if FTS5 not available)
|
return self._search_like(query, user_id, scopes, limit)
|
||||||
if not self.fts5_available or MemoryStorage._contains_cjk(query):
|
|
||||||
return self._search_like(query, user_id, scopes, limit)
|
|
||||||
|
|
||||||
return []
|
|
||||||
|
|
||||||
def _search_fts5(
|
def _search_fts5(
|
||||||
self,
|
self,
|
||||||
@@ -394,7 +505,11 @@ class MemoryStorage:
|
|||||||
)
|
)
|
||||||
for row in rows
|
for row in rows
|
||||||
]
|
]
|
||||||
except Exception:
|
except Exception as e:
|
||||||
|
from common.log import logger
|
||||||
|
logger.error(
|
||||||
|
f"[MemoryStorage] FTS5 search failed (caller will fall back to LIKE): {e}"
|
||||||
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def _search_like(
|
def _search_like(
|
||||||
@@ -404,21 +519,28 @@ class MemoryStorage:
|
|||||||
scopes: List[str],
|
scopes: List[str],
|
||||||
limit: int
|
limit: int
|
||||||
) -> List[SearchResult]:
|
) -> List[SearchResult]:
|
||||||
"""LIKE-based search for CJK characters"""
|
"""LIKE-based search.
|
||||||
|
|
||||||
|
Used as the keyword-search fallback when FTS5 is unavailable, fails,
|
||||||
|
or returns empty. Supports both CJK runs and ASCII word tokens so it
|
||||||
|
can serve as a true safety net for any query.
|
||||||
|
"""
|
||||||
import re
|
import re
|
||||||
# Extract CJK words (2+ characters)
|
# CJK runs (2+ chars) + ASCII word tokens (3+ chars to avoid noise)
|
||||||
cjk_words = re.findall(r'[\u4e00-\u9fff]{2,}', query)
|
cjk_words = re.findall(r'[\u4e00-\u9fff]{2,}', query)
|
||||||
if not cjk_words:
|
ascii_words = [t for t in re.findall(r'[A-Za-z0-9_]+', query) if len(t) >= 3]
|
||||||
|
words = cjk_words + ascii_words
|
||||||
|
if not words:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
scope_placeholders = ','.join('?' * len(scopes))
|
scope_placeholders = ','.join('?' * len(scopes))
|
||||||
|
|
||||||
# Build LIKE conditions for each word
|
# Build LIKE conditions for each word (case-insensitive for ASCII)
|
||||||
like_conditions = []
|
like_conditions = []
|
||||||
params = []
|
params = []
|
||||||
for word in cjk_words:
|
for word in words:
|
||||||
like_conditions.append("text LIKE ?")
|
like_conditions.append("LOWER(text) LIKE ?")
|
||||||
params.append(f'%{word}%')
|
params.append(f'%{word.lower()}%')
|
||||||
|
|
||||||
where_clause = ' OR '.join(like_conditions)
|
where_clause = ' OR '.join(like_conditions)
|
||||||
params.extend(scopes)
|
params.extend(scopes)
|
||||||
@@ -455,7 +577,9 @@ class MemoryStorage:
|
|||||||
)
|
)
|
||||||
for row in rows
|
for row in rows
|
||||||
]
|
]
|
||||||
except Exception:
|
except Exception as e:
|
||||||
|
from common.log import logger
|
||||||
|
logger.error(f"[MemoryStorage] LIKE search failed: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def delete_by_path(self, path: str):
|
def delete_by_path(self, path: str):
|
||||||
@@ -485,14 +609,19 @@ class MemoryStorage:
|
|||||||
chunks_count = self.conn.execute("""
|
chunks_count = self.conn.execute("""
|
||||||
SELECT COUNT(*) as cnt FROM chunks
|
SELECT COUNT(*) as cnt FROM chunks
|
||||||
""").fetchone()['cnt']
|
""").fetchone()['cnt']
|
||||||
|
|
||||||
files_count = self.conn.execute("""
|
files_count = self.conn.execute("""
|
||||||
SELECT COUNT(*) as cnt FROM files
|
SELECT COUNT(*) as cnt FROM files
|
||||||
""").fetchone()['cnt']
|
""").fetchone()['cnt']
|
||||||
|
|
||||||
|
embedded_count = self.conn.execute("""
|
||||||
|
SELECT COUNT(*) as cnt FROM chunks WHERE embedding IS NOT NULL
|
||||||
|
""").fetchone()['cnt']
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'chunks': chunks_count,
|
'chunks': chunks_count,
|
||||||
'files': files_count
|
'files': files_count,
|
||||||
|
'embedded': embedded_count,
|
||||||
}
|
}
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
|||||||
@@ -17,6 +17,10 @@ from common.utils import expand_path
|
|||||||
# Module-level lock to serialize scheduler init across concurrent sessions
|
# Module-level lock to serialize scheduler init across concurrent sessions
|
||||||
_scheduler_init_lock = threading.Lock()
|
_scheduler_init_lock = threading.Lock()
|
||||||
|
|
||||||
|
# Track whether the embedding model log has been printed in this process,
|
||||||
|
# so we avoid spamming it once per session.
|
||||||
|
_embedding_logged: bool = False
|
||||||
|
|
||||||
|
|
||||||
class AgentInitializer:
|
class AgentInitializer:
|
||||||
"""
|
"""
|
||||||
@@ -272,52 +276,19 @@ class AgentInitializer:
|
|||||||
memory_tools = []
|
memory_tools = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from agent.memory import MemoryManager, MemoryConfig, create_embedding_provider
|
from agent.memory import MemoryManager, MemoryConfig
|
||||||
from agent.tools import MemorySearchTool, MemoryGetTool
|
from agent.tools import MemorySearchTool, MemoryGetTool
|
||||||
from config import conf
|
from config import conf
|
||||||
|
|
||||||
# Initialize embedding provider (prefer OpenAI, fallback to LinkAI)
|
|
||||||
embedding_provider = None
|
|
||||||
|
|
||||||
openai_api_key = conf().get("open_ai_api_key", "")
|
|
||||||
openai_api_base = conf().get("open_ai_api_base", "")
|
|
||||||
if openai_api_key and openai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
|
||||||
try:
|
|
||||||
embedding_provider = create_embedding_provider(
|
|
||||||
provider="openai",
|
|
||||||
model="text-embedding-3-small",
|
|
||||||
api_key=openai_api_key,
|
|
||||||
api_base=openai_api_base or "https://api.openai.com/v1"
|
|
||||||
)
|
|
||||||
if session_id is None:
|
|
||||||
logger.info("[AgentInitializer] OpenAI embedding initialized")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[AgentInitializer] OpenAI embedding failed: {e}")
|
|
||||||
|
|
||||||
if embedding_provider is None:
|
|
||||||
linkai_api_key = conf().get("linkai_api_key", "") or os.environ.get("LINKAI_API_KEY", "")
|
|
||||||
linkai_api_base = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
|
||||||
if linkai_api_key and linkai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
|
||||||
try:
|
|
||||||
embedding_provider = create_embedding_provider(
|
|
||||||
provider="linkai",
|
|
||||||
model="text-embedding-3-small",
|
|
||||||
api_key=linkai_api_key,
|
|
||||||
api_base=f"{linkai_api_base}/v1"
|
|
||||||
)
|
|
||||||
if session_id is None:
|
|
||||||
logger.info("[AgentInitializer] LinkAI embedding initialized (fallback)")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[AgentInitializer] LinkAI embedding failed: {e}")
|
|
||||||
|
|
||||||
# Create memory manager
|
|
||||||
memory_config = MemoryConfig(workspace_root=workspace_root)
|
memory_config = MemoryConfig(workspace_root=workspace_root)
|
||||||
|
|
||||||
|
embedding_provider = self._init_embedding_provider(
|
||||||
|
memory_config, session_id=session_id
|
||||||
|
)
|
||||||
|
|
||||||
memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
|
memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
|
||||||
|
|
||||||
# Sync memory
|
|
||||||
self._sync_memory(memory_manager, session_id)
|
self._sync_memory(memory_manager, session_id)
|
||||||
|
|
||||||
# Create memory tools
|
|
||||||
memory_tools = [
|
memory_tools = [
|
||||||
MemorySearchTool(memory_manager),
|
MemorySearchTool(memory_manager),
|
||||||
MemoryGetTool(memory_manager)
|
MemoryGetTool(memory_manager)
|
||||||
@@ -330,6 +301,190 @@ class AgentInitializer:
|
|||||||
logger.warning(f"[AgentInitializer] Memory system not available: {e}")
|
logger.warning(f"[AgentInitializer] Memory system not available: {e}")
|
||||||
|
|
||||||
return memory_manager, memory_tools
|
return memory_manager, memory_tools
|
||||||
|
|
||||||
|
def _init_embedding_provider(self, memory_config, session_id: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Initialize the embedding provider for memory.
|
||||||
|
|
||||||
|
Two paths:
|
||||||
|
A. Default (no `embedding_provider` in config.json):
|
||||||
|
Auto-init OpenAI -> LinkAI fallback. Existing 1536-dim indices
|
||||||
|
keep working.
|
||||||
|
B. Explicit (`embedding_provider` is set):
|
||||||
|
Initialize the requested vendor with unified dim (default 1024).
|
||||||
|
If the index was built with a different dim, vector search will
|
||||||
|
quietly return no results (cosine returns 0) and keyword search
|
||||||
|
takes over until the user runs /memory rebuild-index.
|
||||||
|
"""
|
||||||
|
from agent.memory import create_embedding_provider
|
||||||
|
from config import conf
|
||||||
|
|
||||||
|
explicit_provider = (conf().get("embedding_provider") or "").strip().lower()
|
||||||
|
|
||||||
|
if not explicit_provider:
|
||||||
|
return self._init_embedding_provider_legacy(session_id=session_id)
|
||||||
|
|
||||||
|
return self._init_embedding_provider_explicit(
|
||||||
|
memory_config, explicit_provider, session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_embedding_provider_legacy(self, session_id: Optional[str] = None):
|
||||||
|
"""Legacy auto-init path: OpenAI -> LinkAI. Preserved verbatim for compat."""
|
||||||
|
from agent.memory import create_embedding_provider
|
||||||
|
from config import conf
|
||||||
|
|
||||||
|
embedding_provider = None
|
||||||
|
embedding_model = None
|
||||||
|
|
||||||
|
openai_api_key = conf().get("open_ai_api_key", "")
|
||||||
|
openai_api_base = conf().get("open_ai_api_base", "")
|
||||||
|
if openai_api_key and openai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||||
|
try:
|
||||||
|
model = "text-embedding-3-small"
|
||||||
|
embedding_provider = create_embedding_provider(
|
||||||
|
provider="openai",
|
||||||
|
model=model,
|
||||||
|
api_key=openai_api_key,
|
||||||
|
api_base=openai_api_base or "https://api.openai.com/v1"
|
||||||
|
)
|
||||||
|
embedding_model = f"openai/{model}"
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[AgentInitializer] OpenAI embedding failed: {e}")
|
||||||
|
|
||||||
|
if embedding_provider is None:
|
||||||
|
linkai_api_key = conf().get("linkai_api_key", "") or os.environ.get("LINKAI_API_KEY", "")
|
||||||
|
linkai_api_base = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||||
|
if linkai_api_key and linkai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||||
|
try:
|
||||||
|
model = "text-embedding-3-small"
|
||||||
|
embedding_provider = create_embedding_provider(
|
||||||
|
provider="linkai",
|
||||||
|
model=model,
|
||||||
|
api_key=linkai_api_key,
|
||||||
|
api_base=f"{linkai_api_base}/v1"
|
||||||
|
)
|
||||||
|
embedding_model = f"linkai/{model}"
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[AgentInitializer] LinkAI embedding failed: {e}")
|
||||||
|
|
||||||
|
if embedding_provider is not None and embedding_model:
|
||||||
|
global _embedding_logged
|
||||||
|
if not _embedding_logged:
|
||||||
|
logger.info(
|
||||||
|
f"[AgentInitializer] Embedding model in use: {embedding_model} "
|
||||||
|
f"(dim={embedding_provider.dimensions})"
|
||||||
|
)
|
||||||
|
_embedding_logged = True
|
||||||
|
|
||||||
|
return embedding_provider
|
||||||
|
|
||||||
|
def _init_embedding_provider_explicit(
|
||||||
|
self,
|
||||||
|
memory_config,
|
||||||
|
provider_key: str,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""Explicit-provider path: build the configured vendor.
|
||||||
|
|
||||||
|
If the index was built with a different dim, vector search will
|
||||||
|
silently return no results (cosine returns 0 for mismatched dims)
|
||||||
|
and keyword search takes over. Users switch vendors by running
|
||||||
|
/memory rebuild-index — see docs.
|
||||||
|
"""
|
||||||
|
from agent.memory import create_embedding_provider
|
||||||
|
from agent.memory.embedding import EMBEDDING_VENDORS
|
||||||
|
from config import conf
|
||||||
|
|
||||||
|
meta = EMBEDDING_VENDORS.get(provider_key)
|
||||||
|
if meta is None:
|
||||||
|
logger.error(
|
||||||
|
f"[AgentInitializer] Unknown embedding_provider '{provider_key}'. "
|
||||||
|
f"Supported: {sorted(EMBEDDING_VENDORS.keys())}. "
|
||||||
|
f"Memory will run in keyword-only mode."
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
api_key = self._resolve_embedding_api_key(provider_key)
|
||||||
|
api_base = self._resolve_embedding_api_base(provider_key, meta["default_base_url"])
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
logger.error(
|
||||||
|
f"[AgentInitializer] embedding_provider='{provider_key}' is set but its "
|
||||||
|
f"API key is missing. Memory will run in keyword-only mode."
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
model = (conf().get("embedding_model") or "").strip() or meta["default_model"]
|
||||||
|
try:
|
||||||
|
cfg_dim = int(conf().get("embedding_dimensions") or 0)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
cfg_dim = 0
|
||||||
|
dim = cfg_dim if cfg_dim > 0 else meta["default_dimensions"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
provider = create_embedding_provider(
|
||||||
|
provider=provider_key,
|
||||||
|
model=model,
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
dimensions=dim,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[AgentInitializer] Failed to init embedding provider "
|
||||||
|
f"'{provider_key}/{model}': {e}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
global _embedding_logged
|
||||||
|
if not _embedding_logged:
|
||||||
|
logger.info(
|
||||||
|
f"[AgentInitializer] Embedding model in use: "
|
||||||
|
f"{provider_key}/{model} (dim={provider.dimensions})"
|
||||||
|
)
|
||||||
|
_embedding_logged = True
|
||||||
|
return provider
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_embedding_api_key(provider_key: str) -> str:
|
||||||
|
"""Pick the API key for an explicit embedding provider from config."""
|
||||||
|
from config import conf
|
||||||
|
|
||||||
|
key_map = {
|
||||||
|
"openai": "open_ai_api_key",
|
||||||
|
"linkai": "linkai_api_key",
|
||||||
|
"dashscope": "dashscope_api_key",
|
||||||
|
"doubao": "ark_api_key",
|
||||||
|
"zhipu": "zhipu_ai_api_key",
|
||||||
|
}
|
||||||
|
field = key_map.get(provider_key)
|
||||||
|
if not field:
|
||||||
|
return ""
|
||||||
|
value = conf().get(field, "") or ""
|
||||||
|
if value in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||||
|
return ""
|
||||||
|
return value
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_embedding_api_base(provider_key: str, default_base: str) -> str:
|
||||||
|
"""Pick the API base for an explicit embedding provider from config."""
|
||||||
|
from config import conf
|
||||||
|
|
||||||
|
base_map = {
|
||||||
|
"openai": "open_ai_api_base",
|
||||||
|
"linkai": "linkai_api_base",
|
||||||
|
"doubao": "ark_base_url",
|
||||||
|
"zhipu": "zhipu_ai_api_base",
|
||||||
|
}
|
||||||
|
field = base_map.get(provider_key)
|
||||||
|
if not field:
|
||||||
|
return default_base
|
||||||
|
value = (conf().get(field) or "").strip()
|
||||||
|
if not value:
|
||||||
|
return default_base
|
||||||
|
if provider_key == "linkai" and not value.rstrip("/").endswith("/v1"):
|
||||||
|
return f"{value.rstrip('/')}/v1"
|
||||||
|
return value
|
||||||
|
|
||||||
def _sync_memory(self, memory_manager, session_id: Optional[str] = None):
|
def _sync_memory(self, memory_manager, session_id: Optional[str] = None):
|
||||||
"""Sync memory database"""
|
"""Sync memory database"""
|
||||||
|
|||||||
@@ -26,7 +26,8 @@ Commands:
|
|||||||
knowledge Manage knowledge base.
|
knowledge Manage knowledge base.
|
||||||
install-browser Install browser tool (Playwright + Chromium).
|
install-browser Install browser tool (Playwright + Chromium).
|
||||||
|
|
||||||
Tip: You can also send /help, /skill list, etc. in agent chat."""
|
Tip: Memory index management lives in chat — send /memory status or
|
||||||
|
/memory rebuild-index to the running agent."""
|
||||||
|
|
||||||
|
|
||||||
class CowCLI(click.Group):
|
class CowCLI(click.Group):
|
||||||
|
|||||||
@@ -100,6 +100,10 @@ available_setting = {
|
|||||||
"dashscope_api_key": "",
|
"dashscope_api_key": "",
|
||||||
# Google Gemini Api Key
|
# Google Gemini Api Key
|
||||||
"gemini_api_key": "",
|
"gemini_api_key": "",
|
||||||
|
# Embedding 模型设置
|
||||||
|
"embedding_provider": "", # 显式指定厂商:openai / linkai / dashscope / doubao / zhipu (与 bot_type 命名一致)
|
||||||
|
"embedding_model": "", # 留空使用厂商默认 model
|
||||||
|
"embedding_dimensions": 0, # 留空/0 使用厂商默认维度(推荐统一 1024)
|
||||||
# 语音设置
|
# 语音设置
|
||||||
"speech_recognition": True, # 是否开启语音识别
|
"speech_recognition": True, # 是否开启语音识别
|
||||||
"group_speech_recognition": False, # 是否开启群组语音识别
|
"group_speech_recognition": False, # 是否开启群组语音识别
|
||||||
|
|||||||
@@ -62,10 +62,25 @@ class CowCliPlugin(Plugin):
|
|||||||
|
|
||||||
content = e_context["context"].content.strip()
|
content = e_context["context"].content.strip()
|
||||||
parsed = self._parse_command(content)
|
parsed = self._parse_command(content)
|
||||||
if not parsed:
|
if parsed is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
cmd, args = parsed
|
cmd, args = parsed
|
||||||
|
|
||||||
|
if cmd not in KNOWN_COMMANDS:
|
||||||
|
# Slash-prefixed near-miss: looks like a typo of a real command.
|
||||||
|
# Intercept with a hint so we don't burn an LLM round on "/momory".
|
||||||
|
suggestion = self._suggest_command(cmd)
|
||||||
|
if suggestion is None:
|
||||||
|
return
|
||||||
|
hint = f"未知命令: /{cmd}"
|
||||||
|
if suggestion:
|
||||||
|
hint += f"\n你是不是想输入 /{suggestion} ?"
|
||||||
|
hint += "\n发送 /help 查看全部命令。"
|
||||||
|
e_context["reply"] = Reply(ReplyType.TEXT, hint)
|
||||||
|
e_context.action = EventAction.BREAK_PASS
|
||||||
|
return
|
||||||
|
|
||||||
logger.info(f"[CowCli] intercepted command: {cmd} {args}")
|
logger.info(f"[CowCli] intercepted command: {cmd} {args}")
|
||||||
|
|
||||||
result = self._dispatch(cmd, args, e_context)
|
result = self._dispatch(cmd, args, e_context)
|
||||||
@@ -82,28 +97,80 @@ class CowCliPlugin(Plugin):
|
|||||||
cow <command> [args...] e.g. "cow skill list"
|
cow <command> [args...] e.g. "cow skill list"
|
||||||
/<command> [args...] e.g. "/skill list"
|
/<command> [args...] e.g. "/skill list"
|
||||||
|
|
||||||
Returns (command, args_string) or None if not a cow command.
|
Returns:
|
||||||
"""
|
- (command, args_string): when the message looks like a command.
|
||||||
parts = None
|
'command' may NOT be in KNOWN_COMMANDS; caller should validate.
|
||||||
|
- None: when the message is not command-like at all.
|
||||||
|
|
||||||
|
We deliberately return parsed-but-unknown for the slash form so the
|
||||||
|
caller can offer a typo hint instead of silently passing the message
|
||||||
|
through to the agent.
|
||||||
|
"""
|
||||||
if content.startswith("/"):
|
if content.startswith("/"):
|
||||||
rest = content[1:].strip()
|
rest = content[1:].strip()
|
||||||
if rest:
|
if not rest:
|
||||||
parts = rest.split(None, 1)
|
return None
|
||||||
elif content.startswith("cow "):
|
parts = rest.split(None, 1)
|
||||||
|
cmd = parts[0].lower()
|
||||||
|
args = parts[1] if len(parts) > 1 else ""
|
||||||
|
return cmd, args
|
||||||
|
|
||||||
|
if content.startswith("cow "):
|
||||||
rest = content[4:].strip()
|
rest = content[4:].strip()
|
||||||
if rest:
|
if not rest:
|
||||||
parts = rest.split(None, 1)
|
return None
|
||||||
|
parts = rest.split(None, 1)
|
||||||
|
cmd = parts[0].lower()
|
||||||
|
if cmd not in KNOWN_COMMANDS:
|
||||||
|
# 'cow xxx' that isn't a command — don't intercept (could be
|
||||||
|
# natural language like "cow xxx 怎么样").
|
||||||
|
return None
|
||||||
|
args = parts[1] if len(parts) > 1 else ""
|
||||||
|
return cmd, args
|
||||||
|
|
||||||
if not parts:
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _suggest_command(cmd: str) -> str:
|
||||||
|
"""
|
||||||
|
Return the closest known command if cmd is a likely typo, else "".
|
||||||
|
Returns None to indicate "do not intercept" (when input is too far off).
|
||||||
|
|
||||||
|
Heuristic: edit distance <= 1 (single insert/delete/substitute) when
|
||||||
|
|cmd| >= 3, and the candidate shares the same first letter.
|
||||||
|
"""
|
||||||
|
if not cmd:
|
||||||
|
return ""
|
||||||
|
if len(cmd) < 3:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
cmd = parts[0].lower()
|
def edit_distance_le1(a: str, b: str) -> bool:
|
||||||
if cmd not in KNOWN_COMMANDS:
|
if a == b:
|
||||||
return None
|
return True
|
||||||
|
la, lb = len(a), len(b)
|
||||||
|
if abs(la - lb) > 1:
|
||||||
|
return False
|
||||||
|
if la == lb:
|
||||||
|
diffs = sum(1 for x, y in zip(a, b) if x != y)
|
||||||
|
return diffs <= 1
|
||||||
|
short, long_ = (a, b) if la < lb else (b, a)
|
||||||
|
i = j = 0
|
||||||
|
skipped = False
|
||||||
|
while i < len(short) and j < len(long_):
|
||||||
|
if short[i] != long_[j]:
|
||||||
|
if skipped:
|
||||||
|
return False
|
||||||
|
skipped = True
|
||||||
|
j += 1
|
||||||
|
else:
|
||||||
|
i += 1
|
||||||
|
j += 1
|
||||||
|
return True
|
||||||
|
|
||||||
args = parts[1] if len(parts) > 1 else ""
|
for known in KNOWN_COMMANDS:
|
||||||
return cmd, args
|
if known[0] == cmd[0] and edit_distance_le1(cmd, known):
|
||||||
|
return known
|
||||||
|
return None
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Command dispatch
|
# Command dispatch
|
||||||
@@ -113,12 +180,23 @@ class CowCliPlugin(Plugin):
|
|||||||
"""Execute a cow/slash command string without a channel context.
|
"""Execute a cow/slash command string without a channel context.
|
||||||
|
|
||||||
Used by cloud on_chat to intercept commands before the agent runs.
|
Used by cloud on_chat to intercept commands before the agent runs.
|
||||||
Returns None when *query* is not a recognised command.
|
Returns None when *query* is not command-like at all (e.g. natural
|
||||||
|
language). For slash-prefixed typos returns a hint string so the
|
||||||
|
caller still short-circuits the agent round.
|
||||||
"""
|
"""
|
||||||
parsed = self._parse_command(query.strip())
|
parsed = self._parse_command(query.strip())
|
||||||
if not parsed:
|
if parsed is None:
|
||||||
return None
|
return None
|
||||||
cmd, args = parsed
|
cmd, args = parsed
|
||||||
|
if cmd not in KNOWN_COMMANDS:
|
||||||
|
suggestion = self._suggest_command(cmd)
|
||||||
|
if suggestion is None:
|
||||||
|
return None
|
||||||
|
hint = f"未知命令: /{cmd}"
|
||||||
|
if suggestion:
|
||||||
|
hint += f"\n你是不是想输入 /{suggestion} ?"
|
||||||
|
hint += "\n发送 /help 查看全部命令。"
|
||||||
|
return hint
|
||||||
return self._dispatch(cmd, args, e_context=None, session_id=session_id)
|
return self._dispatch(cmd, args, e_context=None, session_id=session_id)
|
||||||
|
|
||||||
def _dispatch(self, cmd: str, args: str, e_context: EventContext, session_id: str = "") -> str:
|
def _dispatch(self, cmd: str, args: str, e_context: EventContext, session_id: str = "") -> str:
|
||||||
@@ -158,7 +236,9 @@ class CowCliPlugin(Plugin):
|
|||||||
" /config 查看当前配置",
|
" /config 查看当前配置",
|
||||||
" /config <key> 查看某项配置",
|
" /config <key> 查看某项配置",
|
||||||
" /config <key> <val> 修改配置",
|
" /config <key> <val> 修改配置",
|
||||||
" /memory dream [N] 手动触发记忆蒸馏 (整理近N天, 默认3, 最多30)",
|
" /memory status 查看记忆索引状态",
|
||||||
|
" /memory rebuild-index 清空并重建向量索引 (切换 embedding 模型后必须执行)",
|
||||||
|
" /memory dream [N] 手动触发记忆蒸馏 (整理近N天, 默认3, 最多30)",
|
||||||
" /knowledge 查看知识库统计",
|
" /knowledge 查看知识库统计",
|
||||||
" /knowledge list 查看知识库文件树",
|
" /knowledge list 查看知识库文件树",
|
||||||
" /knowledge on|off 开启/关闭知识库",
|
" /knowledge on|off 开启/关闭知识库",
|
||||||
@@ -907,12 +987,25 @@ class CowCliPlugin(Plugin):
|
|||||||
if len(parts) > 1 and parts[1].isdigit():
|
if len(parts) > 1 and parts[1].isdigit():
|
||||||
days = max(1, min(int(parts[1]), 30))
|
days = max(1, min(int(parts[1]), 30))
|
||||||
return self._memory_dream(days, e_context, session_id)
|
return self._memory_dream(days, e_context, session_id)
|
||||||
|
elif sub in ("rebuild-index", "rebuild_index", "rebuild"):
|
||||||
|
return self._memory_rebuild_index(e_context, session_id)
|
||||||
|
elif sub in ("status", "info", ""):
|
||||||
|
if sub == "":
|
||||||
|
return self._memory_help()
|
||||||
|
return self._memory_status()
|
||||||
else:
|
else:
|
||||||
return (
|
return self._memory_help()
|
||||||
"用法: /memory <子命令>\n\n"
|
|
||||||
"子命令:\n"
|
@staticmethod
|
||||||
" dream [N] 手动触发记忆蒸馏 (整理近N天, 默认3, 最多30)"
|
def _memory_help() -> str:
|
||||||
)
|
return (
|
||||||
|
"🧠 记忆管理\n\n"
|
||||||
|
"用法: /memory <子命令>\n\n"
|
||||||
|
"子命令:\n"
|
||||||
|
" status 查看索引状态 (provider / model / dim / chunks)\n"
|
||||||
|
" rebuild-index 清空并重建向量索引 (切换 embedding 模型后必须执行)\n"
|
||||||
|
" dream [N] 手动触发记忆蒸馏 (整理近N天, 默认3, 最多30)"
|
||||||
|
)
|
||||||
|
|
||||||
def _memory_dream(self, days: int, e_context, session_id: str) -> str:
|
def _memory_dream(self, days: int, e_context, session_id: str) -> str:
|
||||||
session_id = self._get_session_id(e_context, fallback=session_id)
|
session_id = self._get_session_id(e_context, fallback=session_id)
|
||||||
@@ -963,6 +1056,140 @@ class CowCliPlugin(Plugin):
|
|||||||
logger.warning(f"[CowCli] /memory dream sync failed: {e}")
|
logger.warning(f"[CowCli] /memory dream sync failed: {e}")
|
||||||
return f"❌ 记忆蒸馏失败: {e}"
|
return f"❌ 记忆蒸馏失败: {e}"
|
||||||
|
|
||||||
|
def _memory_status(self) -> str:
|
||||||
|
"""Show current memory index status."""
|
||||||
|
from agent.memory.embedding import detect_index_dim
|
||||||
|
from config import conf
|
||||||
|
|
||||||
|
agent = self._get_agent("")
|
||||||
|
memory_manager = agent.memory_manager if agent else None
|
||||||
|
|
||||||
|
lines = ["🧠 记忆索引状态", ""]
|
||||||
|
if not memory_manager:
|
||||||
|
lines.append(" ⚠️ Agent 尚未初始化,先发一条普通消息再试")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
stats = memory_manager.storage.get_stats()
|
||||||
|
db_path = memory_manager.config.get_db_path()
|
||||||
|
embedded = stats.get('embedded', 0)
|
||||||
|
chunks = stats.get('chunks', 0)
|
||||||
|
lines.append(f" 索引DB : {db_path}")
|
||||||
|
lines.append(f" Files : {stats.get('files', 0)}")
|
||||||
|
lines.append(f" Chunks : {chunks} (embedded: {embedded})")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
# Active provider (from running config + provider instance).
|
||||||
|
provider_obj = memory_manager.embedding_provider
|
||||||
|
cfg_provider = (conf().get("embedding_provider") or "").strip().lower() or "(legacy)"
|
||||||
|
if provider_obj is not None:
|
||||||
|
cfg_model = getattr(provider_obj, "model", "?")
|
||||||
|
cfg_dim = getattr(provider_obj, "_dimensions", None) or "?"
|
||||||
|
lines.append(f" Provider : {cfg_provider}")
|
||||||
|
lines.append(f" Model : {cfg_model}")
|
||||||
|
lines.append(f" Dim : {cfg_dim}")
|
||||||
|
else:
|
||||||
|
lines.append(" Provider : (未初始化, keyword-only)")
|
||||||
|
|
||||||
|
# Health hints — only shown when the user has explicitly opted into
|
||||||
|
# vector search via `embedding_provider`. Legacy users (no explicit
|
||||||
|
# provider) are running in a "best-effort vectors" mode by design;
|
||||||
|
# nagging them about missing/mismatched vectors would be noise.
|
||||||
|
warnings = []
|
||||||
|
explicitly_opted_in = (conf().get("embedding_provider") or "").strip() != ""
|
||||||
|
if explicitly_opted_in and provider_obj is not None:
|
||||||
|
if chunks > 0 and embedded < chunks:
|
||||||
|
missing = chunks - embedded
|
||||||
|
warnings.append(
|
||||||
|
f" ⚠️ {missing}/{chunks} 个 chunk 没有向量;"
|
||||||
|
f"运行 /memory rebuild-index 后所有记忆才会被向量化检索"
|
||||||
|
)
|
||||||
|
|
||||||
|
index_dim = detect_index_dim(memory_manager.storage)
|
||||||
|
cfg_dim = getattr(provider_obj, "_dimensions", None)
|
||||||
|
if index_dim is not None and cfg_dim and index_dim != cfg_dim:
|
||||||
|
warnings.append(
|
||||||
|
f" ⚠️ 索引中存量向量为 {index_dim} 维,与当前配置 {cfg_dim} 维不一致;"
|
||||||
|
f"运行 /memory rebuild-index 重建后向量检索才会生效"
|
||||||
|
)
|
||||||
|
|
||||||
|
if warnings:
|
||||||
|
lines.append("")
|
||||||
|
lines.extend(warnings)
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
def _memory_rebuild_index(self, e_context, session_id: str) -> str:
|
||||||
|
"""Rebuild the vector index using the current agent's memory_manager."""
|
||||||
|
session_id = self._get_session_id(e_context, fallback=session_id)
|
||||||
|
agent = self._get_agent(session_id)
|
||||||
|
if not agent or not agent.memory_manager:
|
||||||
|
return (
|
||||||
|
"⚠️ Agent 尚未初始化,无法重建索引。\n"
|
||||||
|
"请先发送一条普通消息触发 Agent 启动后再试。"
|
||||||
|
)
|
||||||
|
|
||||||
|
memory_manager = agent.memory_manager
|
||||||
|
if memory_manager.embedding_provider is None:
|
||||||
|
return (
|
||||||
|
"⚠️ 当前没有可用的 embedding provider。\n"
|
||||||
|
"请检查 config.json 中的 embedding 相关配置 (provider / api key)。"
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_obj = memory_manager.embedding_provider
|
||||||
|
model_label = getattr(provider_obj, "model", "?")
|
||||||
|
dim_label = getattr(provider_obj, "dimensions", "?")
|
||||||
|
|
||||||
|
# SaaS (e_context is None): run synchronously, return final result
|
||||||
|
if e_context is None:
|
||||||
|
return self._memory_rebuild_sync(memory_manager, model_label, dim_label)
|
||||||
|
|
||||||
|
# Local channels: run in background, push progress + final result
|
||||||
|
from agent.memory.embedding import rebuild_in_process
|
||||||
|
|
||||||
|
def _run():
|
||||||
|
try:
|
||||||
|
result = rebuild_in_process(memory_manager)
|
||||||
|
if result.ok:
|
||||||
|
self._notify(
|
||||||
|
e_context,
|
||||||
|
(
|
||||||
|
f"✅ 索引重建完成\n"
|
||||||
|
f" cleared : {result.removed}\n"
|
||||||
|
f" chunks : {result.chunks}\n"
|
||||||
|
f" files : {result.files}"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._notify(e_context, f"❌ 索引重建失败: {result.error}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("[CowCli] /memory rebuild-index failed")
|
||||||
|
self._notify(e_context, f"❌ 索引重建失败: {e}")
|
||||||
|
|
||||||
|
threading.Thread(target=_run, daemon=True).start()
|
||||||
|
return (
|
||||||
|
f"🔧 索引重建已启动 (model={model_label}, dim={dim_label})\n\n"
|
||||||
|
f"将清空现有 chunks 并重新 embed 所有记忆文件,完成后会通知你。"
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _memory_rebuild_sync(memory_manager, model_label, dim_label) -> str:
|
||||||
|
from agent.memory.embedding import rebuild_in_process
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = rebuild_in_process(memory_manager)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("[CowCli] /memory rebuild-index sync failed")
|
||||||
|
return f"❌ 索引重建失败: {e}"
|
||||||
|
|
||||||
|
if not result.ok:
|
||||||
|
return f"❌ 索引重建失败: {result.error}"
|
||||||
|
return (
|
||||||
|
f"✅ 索引重建完成 (model={model_label}, dim={dim_label})\n"
|
||||||
|
f" cleared : {result.removed}\n"
|
||||||
|
f" chunks : {result.chunks}\n"
|
||||||
|
f" files : {result.files}"
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _notify(e_context, text: str):
|
def _notify(e_context, text: str):
|
||||||
"""Push a notification message back to the chat channel."""
|
"""Push a notification message back to the chat channel."""
|
||||||
|
|||||||
Reference in New Issue
Block a user