feat(memory): support multi-vendor embedding fallback

Add embedding_provider config knob with native support for
openai / dashscope / doubao / zhipu / linkai, plus an in-chat
/memory status and /memory rebuild-index workflow for switching
vendors safely.
This commit is contained in:
zhayujie
2026-05-20 11:00:53 +08:00
parent a0dfdb79df
commit 3ffb563a44
12 changed files with 1572 additions and 449 deletions

View File

@@ -1,167 +0,0 @@
"""
Embedding providers for memory
Supports OpenAI and local embedding models
"""
import hashlib
from abc import ABC, abstractmethod
from typing import List, Optional
class EmbeddingProvider(ABC):
    """Abstract interface implemented by every memory embedding backend."""

    @abstractmethod
    def embed(self, text: str) -> List[float]:
        """Return the embedding vector for a single piece of text."""

    @abstractmethod
    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector for each entry of ``texts``."""

    @property
    @abstractmethod
    def dimensions(self) -> int:
        """Length of the vectors produced by this provider."""
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """OpenAI embedding provider using the REST ``/embeddings`` endpoint."""

    # HTTP timeout in seconds for one request. Defined once so the timeout
    # error message always reports the value that was actually used (the
    # message previously claimed 10s while the request used 5s).
    REQUEST_TIMEOUT = 5

    def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None,
                 api_base: Optional[str] = None, extra_headers: Optional[dict] = None):
        """
        Initialize OpenAI embedding provider

        Args:
            model: Model name (text-embedding-3-small or text-embedding-3-large)
            api_key: OpenAI API key
            api_base: Optional API base URL
            extra_headers: Optional extra headers to include in API requests

        Raises:
            ValueError: If api_key is missing or is a known placeholder value.
        """
        self.model = model
        self.api_key = api_key
        self.api_base = api_base or "https://api.openai.com/v1"
        self.extra_headers = extra_headers or {}

        # Validate API key
        if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
            raise ValueError("OpenAI API key is not configured. Please set 'open_ai_api_key' in config.json")

        # Set dimensions based on model name: 1536 when it contains "small",
        # otherwise 3072.
        self._dimensions = 1536 if "small" in model else 3072

    def _call_api(self, input_data):
        """Call the OpenAI embedding API using requests.

        Args:
            input_data: A single string or a list of strings to embed.

        Returns:
            The decoded JSON response body.

        Raises:
            ConnectionError: If the endpoint cannot be reached.
            TimeoutError: If the request exceeds REQUEST_TIMEOUT seconds.
            ValueError: If the API answers with an HTTP error status.
        """
        import requests

        url = f"{self.api_base}/embeddings"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
            **self.extra_headers,
        }
        data = {
            "input": input_data,
            "model": self.model
        }

        try:
            response = requests.post(url, headers=headers, json=data, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.ConnectionError as e:
            raise ConnectionError(f"Failed to connect to OpenAI API at {url}. Please check your network connection and api_base configuration. Error: {str(e)}") from e
        except requests.exceptions.Timeout as e:
            raise TimeoutError(f"OpenAI API request timed out after {self.REQUEST_TIMEOUT}s. Please check your network connection. Error: {str(e)}") from e
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:
                raise ValueError("Invalid OpenAI API key. Please check your 'open_ai_api_key' in config.json") from e
            elif e.response.status_code == 429:
                raise ValueError("OpenAI API rate limit exceeded. Please try again later.") from e
            else:
                raise ValueError(f"OpenAI API request failed: {e.response.status_code} - {e.response.text}") from e

    def embed(self, text: str) -> List[float]:
        """Generate embedding for text"""
        result = self._call_api(text)
        return result["data"][0]["embedding"]

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for multiple texts (empty input returns [])"""
        if not texts:
            return []
        result = self._call_api(texts)
        return [item["embedding"] for item in result["data"]]

    @property
    def dimensions(self) -> int:
        """Embedding size derived from the model name at construction time."""
        return self._dimensions
# LocalEmbeddingProvider removed - only use OpenAI embedding or keyword search
class EmbeddingCache:
    """In-process memo of embeddings so identical texts are not re-embedded."""

    def __init__(self):
        self.cache = {}

    @staticmethod
    def _compute_key(text: str, provider: str, model: str) -> str:
        """Return the MD5 hex digest of ``provider:model:text``."""
        raw_key = "{}:{}:{}".format(provider, model, text)
        return hashlib.md5(raw_key.encode("utf-8")).hexdigest()

    def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
        """Look up a stored embedding; returns None when nothing is cached."""
        return self.cache.get(self._compute_key(text, provider, model))

    def put(self, text: str, provider: str, model: str, embedding: List[float]):
        """Store ``embedding`` under the key derived from the three identifiers."""
        self.cache[self._compute_key(text, provider, model)] = embedding

    def clear(self):
        """Drop every cached entry."""
        self.cache.clear()
def create_embedding_provider(
    provider: str = "openai",
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    extra_headers: Optional[dict] = None
) -> EmbeddingProvider:
    """
    Build an embedding provider for the given vendor name.

    Both accepted vendors ("openai" and "linkai") are served by the same
    OpenAI-compatible REST client. When construction fails the caller is
    expected to fall back to keyword-only search.

    Args:
        provider: Vendor name, either "openai" or "linkai"
        model: Model name; text-embedding-3-small when omitted
        api_key: API key (required)
        api_base: API base URL
        extra_headers: Optional extra headers to include in API requests

    Returns:
        An OpenAIEmbeddingProvider instance

    Raises:
        ValueError: If provider is unsupported or api_key is missing
    """
    supported_vendors = ("openai", "linkai")
    if provider in supported_vendors:
        chosen_model = model if model else "text-embedding-3-small"
        return OpenAIEmbeddingProvider(
            model=chosen_model,
            api_key=api_key,
            api_base=api_base,
            extra_headers=extra_headers,
        )
    raise ValueError(f"Unsupported embedding provider: {provider}. Use 'openai' or 'linkai'.")

View File

@@ -0,0 +1,41 @@
"""
Embedding subsystem for memory.
Public API:
create_embedding_provider, EmbeddingProvider, OpenAIEmbeddingProvider,
DoubaoEmbeddingProvider, EMBEDDING_VENDORS, EmbeddingCache
RebuildResult, clear_index, rebuild_in_process
detect_index_dim, cleanup_legacy_state_file
"""
from agent.memory.embedding.provider import (
EMBEDDING_VENDORS,
DoubaoEmbeddingProvider,
EmbeddingCache,
EmbeddingProvider,
OpenAIEmbeddingProvider,
create_embedding_provider,
)
from agent.memory.embedding.rebuild import (
RebuildResult,
clear_index,
rebuild_in_process,
)
from agent.memory.embedding.state import (
cleanup_legacy_state_file,
detect_index_dim,
)
__all__ = [
"EMBEDDING_VENDORS",
"DoubaoEmbeddingProvider",
"EmbeddingCache",
"EmbeddingProvider",
"OpenAIEmbeddingProvider",
"create_embedding_provider",
"RebuildResult",
"clear_index",
"rebuild_in_process",
"cleanup_legacy_state_file",
"detect_index_dim",
]

View File

@@ -0,0 +1,486 @@
"""
Embedding providers for memory
Supports multiple OpenAI-compatible embedding vendors:
- openai (text-embedding-3-small / large)
- linkai (OpenAI-compatible passthrough)
- dashscope (Aliyun Tongyi text-embedding-v4)
- doubao (ByteDance Doubao Seed1.5 / large-text on Volcengine Ark)
- zhipu (ZhipuAI embedding-3)
Vendor keys here intentionally match the project's bot_type constants in
common.const (OPENAI, LINKAI, QWEN_DASHSCOPE, DOUBAO, ZHIPU_AI).
All vendors except doubao share a single OpenAI-compatible REST client;
doubao uses a dedicated multimodal adapter (DoubaoEmbeddingProvider).
Vendor-specific behaviors (truncation, query instruction prefix) are
configured via metadata.
"""
import hashlib
import math
from abc import ABC, abstractmethod
from typing import List, Optional
# HTTP read timeout for a single embeddings request (seconds). A batch of
# 64+ chunks can take 30-50s end-to-end from China-side networks, so 30s is
# routinely too tight; 90s gives meaningful headroom without letting bad
# endpoints hang forever.
EMBEDDING_HTTP_TIMEOUT = 90
class EmbeddingProvider(ABC):
    """Contract shared by all embedding backends used by the memory index."""

    @abstractmethod
    def embed(self, text: str) -> List[float]:
        """Embed a single text and return its vector."""

    @abstractmethod
    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed several texts (treated as documents), one vector per input."""

    def embed_query(self, text: str) -> List[float]:
        """Embed a search query.

        The default simply delegates to embed(); subclasses may override it
        to apply a vendor-specific instruction prefix.
        """
        query_vector = self.embed(text)
        return query_vector

    @property
    @abstractmethod
    def dimensions(self) -> int:
        """Effective length of the vectors this provider returns."""
# ---------------------------------------------------------------------------
# Vendor metadata table
# ---------------------------------------------------------------------------
#
# Each entry describes how to reach a vendor's embedding endpoint. Most
# vendors expose an OpenAI-compatible /embeddings API; the few that don't
# (currently: doubao) set `provider_class` to pick a dedicated adapter.
#
# Fields:
#   provider_class         : optional adapter key ("doubao"); when absent the
#                            OpenAI-compatible client is used
#   default_base_url       : default API base when not overridden by user
#   default_model          : default embedding model name
#   default_dimensions     : recommended unified dim when explicit path is enabled
#   supports_dim_param     : whether the API accepts a `dimensions` request param
#   needs_client_truncate  : whether to slice the vector on the client side
#   needs_client_normalize : whether to L2-normalize on the client (always safe)
#   query_instruction      : optional prefix for asymmetric retrieval; empty
#                            for every vendor currently listed here
#   max_batch_size         : max texts per /embeddings request; embed_batch
#                            auto-paginates above this. Conservative defaults.
#
EMBEDDING_VENDORS = {
    "openai": {
        "default_base_url": "https://api.openai.com/v1",
        "default_model": "text-embedding-3-small",
        # Match the legacy default so users adding `embedding_provider: openai`
        # to an existing index don't need to rebuild. Override via
        # embedding_dimensions if you want 1024 / 1536 / 3072.
        "default_dimensions": 1536,
        "supports_dim_param": True,
        "needs_client_truncate": False,
        "needs_client_normalize": False,
        "query_instruction": "",
        # OpenAI permits up to 2048 items per request, but a single call
        # carrying hundreds of long chunks can run long enough to hit the
        # HTTP read timeout (see EMBEDDING_HTTP_TIMEOUT) from China-side
        # networks. 64 keeps each call well under both the token-per-request
        # budget and a reasonable wall clock.
        "max_batch_size": 64,
    },
    "linkai": {
        # Same model, dimensions and batch size as the "openai" entry; only
        # the base URL differs.
        "default_base_url": "https://api.link-ai.tech/v1",
        "default_model": "text-embedding-3-small",
        "default_dimensions": 1536,
        "supports_dim_param": True,
        "needs_client_truncate": False,
        "needs_client_normalize": False,
        "query_instruction": "",
        "max_batch_size": 64,
    },
    "dashscope": {
        "default_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
        "default_model": "text-embedding-v4",
        "default_dimensions": 1024,
        "supports_dim_param": True,
        "needs_client_truncate": False,
        "needs_client_normalize": False,
        "query_instruction": "",
        "max_batch_size": 10,  # DashScope hard cap (text-embedding-v4)
    },
    "doubao": {
        # Doubao no longer offers an OpenAI-compatible /v1/embeddings endpoint.
        # Current models are unified under /api/v3/embeddings/multimodal
        # which uses a structured `input` payload — see DoubaoEmbeddingProvider.
        "provider_class": "doubao",
        "default_base_url": "https://ark.cn-beijing.volces.com/api/v3",
        "default_model": "doubao-embedding-vision-251215",
        # Native options: 1024 or 2048. We default to 1024 to align with the
        # other Chinese vendors (dashscope/zhipu) and keep storage footprint
        # consistent across providers; users can still override via
        # `embedding_dimensions: 2048` in config.
        "default_dimensions": 1024,
        "supports_dim_param": True,
        "needs_client_truncate": False,
        "needs_client_normalize": False,
        "query_instruction": "",
        # Multimodal endpoint produces ONE embedding per call (input list is
        # a single document's parts, not a batch). embed_batch loops.
        "max_batch_size": 1,
    },
    "zhipu": {
        "default_base_url": "https://open.bigmodel.cn/api/paas/v4",
        "default_model": "embedding-3",
        "default_dimensions": 1024,
        "supports_dim_param": True,
        "needs_client_truncate": False,
        "needs_client_normalize": False,
        "query_instruction": "",
        "max_batch_size": 64,
    },
}
def _l2_normalize(vec: List[float]) -> List[float]:
    """Scale ``vec`` to unit Euclidean length.

    A vector whose norm is zero (including an empty one) is returned
    unchanged, as the very same object.
    """
    magnitude = math.sqrt(sum(component * component for component in vec))
    if magnitude != 0:
        return [component / magnitude for component in vec]
    return vec
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """
    OpenAI-compatible embedding provider.

    Used for openai/linkai/dashscope/zhipu by configuring the metadata
    fields. The legacy constructor form (model, api_key, api_base) keeps
    working: when no explicit ``dimensions`` is passed, no ``dimensions``
    parameter is sent to the API, so the original OpenAI/LinkAI request
    shape is unchanged.
    """

    def __init__(
        self,
        model: str = "text-embedding-3-small",
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        extra_headers: Optional[dict] = None,
        dimensions: Optional[int] = None,
        supports_dim_param: bool = True,
        needs_client_truncate: bool = False,
        needs_client_normalize: bool = False,
        query_instruction: str = "",
        max_batch_size: int = 256,
    ):
        """
        Args:
            model: Model name (e.g. text-embedding-3-small, text-embedding-v4, embedding-3)
            api_key: API key (required)
            api_base: API base URL (defaults to OpenAI)
            extra_headers: Optional extra HTTP headers
            dimensions: Target output dimension. Required when supports_dim_param
                is False and needs_client_truncate is True (used to slice).
            supports_dim_param: Whether the vendor accepts a `dimensions` body param
            needs_client_truncate: Slice the returned vector to `dimensions`
            needs_client_normalize: L2-normalize on the client after slicing
            query_instruction: Optional prefix prepended to query texts only
            max_batch_size: Max items per /embeddings request; embed_batch
                auto-paginates above this.

        Raises:
            ValueError: If api_key is missing or is a known placeholder value.
        """
        self.model = model
        self.api_key = api_key
        self.api_base = api_base or "https://api.openai.com/v1"
        self.extra_headers = extra_headers or {}
        self.supports_dim_param = supports_dim_param
        self.needs_client_truncate = needs_client_truncate
        self.needs_client_normalize = needs_client_normalize
        self.query_instruction = query_instruction or ""
        self.max_batch_size = max(1, int(max_batch_size or 1))

        if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
            raise ValueError("Embedding API key is not configured")

        # Track whether the caller asked for a specific size. Only an explicit
        # request is forwarded to the API; the heuristic below is a guess for
        # reporting purposes and must not be sent, because models without a
        # configurable size would be handed a value they never asked for.
        self._explicit_dimensions = dimensions is not None and dimensions > 0
        if self._explicit_dimensions:
            self._dimensions = dimensions
        else:
            # Legacy heuristic for OpenAI text-embedding-3-* family
            self._dimensions = 1536 if "small" in model else 3072

    def _call_api(self, input_data):
        """Call the OpenAI-compatible /embeddings endpoint.

        Args:
            input_data: A single string or a list of strings to embed.

        Returns:
            The decoded JSON response body.

        Raises:
            ConnectionError: If the endpoint cannot be reached.
            TimeoutError: If the request exceeds EMBEDDING_HTTP_TIMEOUT.
            ValueError: If the API answers with an HTTP error status.
        """
        import requests

        url = f"{self.api_base}/embeddings"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
            **self.extra_headers,
        }
        data = {
            "input": input_data,
            "model": self.model,
        }
        if self.supports_dim_param and self._explicit_dimensions:
            data["dimensions"] = self._dimensions

        try:
            response = requests.post(url, headers=headers, json=data, timeout=EMBEDDING_HTTP_TIMEOUT)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.ConnectionError as e:
            raise ConnectionError(
                f"Failed to connect to embedding API at {url}. "
                f"Please check network and api_base. Error: {str(e)}"
            ) from e
        except requests.exceptions.Timeout as e:
            raise TimeoutError(f"Embedding API request timed out. Error: {str(e)}") from e
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:
                raise ValueError("Invalid embedding API key") from e
            elif e.response.status_code == 429:
                raise ValueError("Embedding API rate limit exceeded") from e
            else:
                raise ValueError(
                    f"Embedding API request failed: "
                    f"{e.response.status_code} - {e.response.text}"
                ) from e

    def _post_process(self, raw: List[float]) -> List[float]:
        """Apply optional client-side truncation + normalization"""
        vec = raw
        if self.needs_client_truncate and self._dimensions and len(vec) > self._dimensions:
            vec = vec[: self._dimensions]
        if self.needs_client_normalize:
            vec = _l2_normalize(vec)
        return vec

    def _ordered_vectors(self, result, expected: int) -> List[List[float]]:
        """Return post-processed vectors from a response, in input order.

        Raises:
            ValueError: If the response does not hold exactly ``expected``
                items, since pairing vectors with texts would be wrong.
        """
        items = result["data"]
        if len(items) != expected:
            raise ValueError(
                f"Embedding API returned {len(items)} vectors for {expected} inputs"
            )
        # Items carry an "index" pointing back at the input position; sorting
        # on it makes the result independent of response order. sorted() is
        # stable, so responses without "index" keep their original order.
        ordered = sorted(items, key=lambda item: item.get("index") or 0)
        return [self._post_process(item["embedding"]) for item in ordered]

    def embed(self, text: str) -> List[float]:
        """Generate embedding for a single text; no query prefix is applied"""
        result = self._call_api(text)
        return self._post_process(result["data"][0]["embedding"])

    def embed_query(self, text: str) -> List[float]:
        """Generate embedding for a query (applies vendor instruction prefix if any)"""
        if self.query_instruction:
            text = f"{self.query_instruction}{text}"
        return self.embed(text)

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for multiple documents.

        Automatically paginates by self.max_batch_size so callers can pass any
        number of texts. Order of returned vectors matches the input order.

        Raises:
            ValueError: If a page of the response has a different number of
                vectors than texts sent.
        """
        if not texts:
            return []
        out: List[List[float]] = []
        step = self.max_batch_size
        for i in range(0, len(texts), step):
            chunk = texts[i:i + step]
            result = self._call_api(chunk)
            out.extend(self._ordered_vectors(result, len(chunk)))
        return out

    @property
    def dimensions(self) -> int:
        """Explicitly requested size, or the model-name heuristic otherwise."""
        return self._dimensions
class DoubaoEmbeddingProvider(EmbeddingProvider):
    """
    Embedding provider for Doubao (Volcengine Ark) multimodal embeddings.

    Requests go to ``<api_base>/embeddings/multimodal`` with a structured
    ``input`` list of typed parts; this provider always sends one text part.

    Notes:
        * Each call yields a single embedding, so embed_batch issues one
          request per text.
        * Accepted dimensions are 1024 or 2048; 1024 is used when none is
          given.
        * Authentication is a Bearer API key.
    """

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        extra_headers: Optional[dict] = None,
        dimensions: Optional[int] = None,
    ):
        self.model = model
        self.api_key = api_key
        self.api_base = api_base if api_base else "https://ark.cn-beijing.volces.com/api/v3"
        self.extra_headers = extra_headers if extra_headers else {}

        key_unusable = not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]
        if key_unusable:
            raise ValueError("Doubao embedding API key (ark_api_key) is not configured")

        if dimensions is None:
            self._dimensions = 1024
        elif dimensions in (1024, 2048):
            self._dimensions = dimensions
        else:
            raise ValueError(
                f"Doubao embedding dimensions must be 1024 or 2048, got {dimensions}"
            )

    @staticmethod
    def _extract_embedding(body) -> List[float]:
        """Pull the vector out of a decoded response body."""
        data = body.get("data")
        # Documented shape: {"data": {"embedding": [...]}}
        if isinstance(data, dict) and "embedding" in data:
            return data["embedding"]
        # Tolerate a one-element list wrapper as well.
        if isinstance(data, list) and data and "embedding" in data[0]:
            return data[0]["embedding"]
        raise ValueError(f"Unexpected Doubao embedding response shape: {body}")

    def _call_api(self, text: str) -> List[float]:
        """Send one text part to the multimodal endpoint and return its vector."""
        import requests

        url = f"{self.api_base}/embeddings/multimodal"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
            **self.extra_headers,
        }
        text_part = {"type": "text", "text": text}
        payload = {
            "model": self.model,
            "input": [text_part],
            "dimensions": self._dimensions,
            "encoding_format": "float",
        }

        try:
            response = requests.post(url, headers=headers, json=payload, timeout=EMBEDDING_HTTP_TIMEOUT)
            response.raise_for_status()
            body = response.json()
        except requests.exceptions.ConnectionError as e:
            raise ConnectionError(
                f"Failed to connect to Doubao embedding API at {url}. "
                f"Please check network and api_base. Error: {str(e)}"
            )
        except requests.exceptions.Timeout as e:
            raise TimeoutError(f"Doubao embedding API request timed out. Error: {str(e)}")
        except requests.exceptions.HTTPError as e:
            status = e.response.status_code
            if status == 401:
                raise ValueError("Invalid Doubao (ark) embedding API key")
            if status == 429:
                raise ValueError("Doubao embedding API rate limit exceeded")
            raise ValueError(
                f"Doubao embedding API request failed: "
                f"{e.response.status_code} - {e.response.text}"
            )

        return self._extract_embedding(body)

    def embed(self, text: str) -> List[float]:
        """Embed a single text with one API call."""
        vector = self._call_api(text)
        return vector

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed each text with its own API call, preserving input order."""
        vectors = []
        for entry in texts:
            vectors.append(self._call_api(entry))
        return vectors

    @property
    def dimensions(self) -> int:
        """Dimension sent with every request (1024 or 2048)."""
        return self._dimensions
class EmbeddingCache:
    """Dictionary-backed, in-memory store of already computed embeddings."""

    def __init__(self):
        self.cache = {}

    @staticmethod
    def _compute_key(text: str, provider: str, model: str) -> str:
        """Return the MD5 hex digest of ``provider:model:text``."""
        raw_key = "{}:{}:{}".format(provider, model, text)
        digest = hashlib.md5(raw_key.encode("utf-8"))
        return digest.hexdigest()

    def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
        """Return the cached embedding, or None when there is no entry."""
        cache_key = self._compute_key(text, provider, model)
        return self.cache.get(cache_key)

    def put(self, text: str, provider: str, model: str, embedding: List[float]):
        """Remember ``embedding`` for this text/provider/model combination."""
        cache_key = self._compute_key(text, provider, model)
        self.cache[cache_key] = embedding

    def clear(self):
        """Forget every cached embedding."""
        self.cache.clear()
def create_embedding_provider(
    provider: str = "openai",
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    extra_headers: Optional[dict] = None,
    dimensions: Optional[int] = None,
) -> EmbeddingProvider:
    """
    Factory function to create an embedding provider.

    Backward compatible: when called with provider in {"openai", "linkai"}
    and no `dimensions` arg, the provider is built the legacy way (no
    explicit dimension, text-embedding-3-small unless a model is given).
    New providers ("dashscope", "doubao", "zhipu") require explicit configuration
    and use the unified 1024-dim defaults from EMBEDDING_VENDORS.

    Args:
        provider: Vendor key (one of EMBEDDING_VENDORS)
        model: Model name (uses vendor default if None)
        api_key: API key (required)
        api_base: API base URL (uses vendor default if None)
        extra_headers: Optional extra HTTP headers
        dimensions: Target output dimension (uses vendor default if None)

    Returns:
        EmbeddingProvider instance

    Raises:
        ValueError: If provider is not a key of EMBEDDING_VENDORS, or the
            selected provider class rejects its arguments (e.g. missing key).
    """
    meta = EMBEDDING_VENDORS.get(provider)
    if meta is None:
        raise ValueError(
            f"Unsupported embedding provider: {provider}. "
            f"Supported: {sorted(EMBEDDING_VENDORS.keys())}"
        )

    # Resolve the endpoint once for every path. The legacy branch used to
    # pass api_base through untouched, so "linkai" without an api_base fell
    # back to the OpenAI URL instead of the vendor's own default.
    resolved_base = api_base or meta["default_base_url"]

    # Doubao uses a non-OpenAI-compatible multimodal endpoint.
    if meta.get("provider_class") == "doubao":
        final_dim = dimensions if (dimensions and dimensions > 0) else meta["default_dimensions"]
        return DoubaoEmbeddingProvider(
            model=model or meta["default_model"],
            api_key=api_key,
            api_base=resolved_base,
            extra_headers=extra_headers,
            dimensions=final_dim,
        )

    # Legacy call for openai/linkai (no dimensions given) keeps the original
    # construction so existing data isn't invalidated.
    is_legacy_call = (
        provider in ("openai", "linkai")
        and dimensions is None
    )
    if is_legacy_call:
        return OpenAIEmbeddingProvider(
            model=model or "text-embedding-3-small",
            api_key=api_key,
            api_base=resolved_base,
            extra_headers=extra_headers,
            # Honor the vendor's batch cap here too; it only affects how
            # embed_batch paginates, not the vectors returned.
            max_batch_size=meta.get("max_batch_size", 256),
        )

    final_dim = dimensions if (dimensions and dimensions > 0) else meta["default_dimensions"]
    return OpenAIEmbeddingProvider(
        model=model or meta["default_model"],
        api_key=api_key,
        api_base=resolved_base,
        extra_headers=extra_headers,
        dimensions=final_dim,
        supports_dim_param=meta["supports_dim_param"],
        needs_client_truncate=meta["needs_client_truncate"],
        needs_client_normalize=meta["needs_client_normalize"],
        query_instruction=meta["query_instruction"],
        max_batch_size=meta.get("max_batch_size", 256),
    )

View File

@@ -0,0 +1,191 @@
"""
Rebuild memory vector index.
Recommended entry point (in-chat, while agent is running):
/memory rebuild-index
Backward-compatible CLI entry (must run from project root):
python -m agent.memory.rebuild_index
What it does:
1. Probes the embedding endpoint with a tiny call to fail fast on
bad provider/model/key — before touching the index.
2. Clears the SQLite chunks/files tables (workspace markdown stays intact).
3. Runs a fresh sync, regenerating embeddings with the currently configured
provider/model/dimensions.
This is the only safe way to switch embedding_provider after the existing
index has been populated by a different-dim model.
"""
from __future__ import annotations
import asyncio
import sys
from dataclasses import dataclass
from typing import Optional
from common.log import logger
from common.utils import expand_path
@dataclass
class RebuildResult:
    """Outcome of a rebuild_in_process() call.

    Returned instead of raising: failures are reported with ok=False and a
    human-readable message in ``error``.
    """
    # True only when the rebuild ran to completion.
    ok: bool
    # Number of chunks that clear_index() reported as deleted.
    removed: int = 0
    # "chunks" count from the storage stats read after the re-sync.
    chunks: int = 0
    # "files" count from the storage stats read after the re-sync.
    files: int = 0
    # Failure description; stays None on success.
    error: Optional[str] = None
def clear_index(db_path, storage=None) -> int:
    """Wipe chunks/files, reset FTS5, and clean up any legacy state file.

    Args:
        db_path: Path of the index DB (also used to locate the legacy state
            file for migration cleanup, and — when *storage* is None — to
            open a fresh connection).
        storage: Optional pre-opened MemoryStorage. When provided we reuse it
            so the live connection's triggers stay in sync — opening a second
            connection would leave the original one's triggers pointing at a
            DROP'd chunks_fts table.

    We reset (DROP+recreate) chunks_fts because its shadow tables can become
    inconsistent across rebuild cycles, causing bm25() / ORDER BY rank to
    raise "database disk image is malformed" even when raw MATCH still works.

    Returns:
        Number of chunks that were present before the wipe.

    Raises:
        Exception: Whatever the storage layer raises. If the failure happens
            while deleting, the pending deletes are rolled back first.
    """
    from agent.memory.embedding.state import cleanup_legacy_state_file
    from agent.memory.storage import MemoryStorage

    owns_storage = storage is None
    if owns_storage:
        storage = MemoryStorage(db_path)
    try:
        conn = storage.conn
        before = conn.execute("SELECT COUNT(*) FROM chunks").fetchone()[0]
        try:
            conn.execute("DELETE FROM chunks")
            conn.execute("DELETE FROM files")
            conn.commit()
        except Exception:
            # Don't leave a half-applied delete pending on a connection the
            # caller keeps using: a later commit elsewhere would persist it.
            # NOTE(review): assumes storage.conn is a DB-API connection
            # exposing rollback() (it already exposes execute()/commit()).
            conn.rollback()
            raise
        storage.reset_fts5()
    finally:
        # Only close connections this function opened itself.
        if owns_storage:
            storage.close()
    cleanup_legacy_state_file(db_path)
    return int(before)
def _run_sync_blocking(memory_manager) -> None:
    """Run ``memory_manager.sync(force=True)`` to completion from sync code."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop is running in this thread: the plain path.
        asyncio.run(memory_manager.sync(force=True))
        return

    # A loop is already running in this thread. Neither asyncio.run() nor a
    # second loop's run_until_complete() may be used here, so the coroutine
    # is driven on a short-lived worker thread while this thread blocks.
    # NOTE(review): sync() then executes on that worker thread; if the
    # storage connection is bound to its creating thread this raises and is
    # reported through the caller's normal error path.
    from concurrent.futures import ThreadPoolExecutor

    with ThreadPoolExecutor(max_workers=1) as pool:
        pool.submit(lambda: asyncio.run(memory_manager.sync(force=True))).result()


def rebuild_in_process(memory_manager) -> RebuildResult:
    """
    Rebuild the index using an existing, fully-initialized MemoryManager.

    Used by the in-chat /memory rebuild-index command. The caller already has
    config loaded, embedding_provider built, and (optionally) the agent
    running, so we only need to:
      1. Probe the embedding endpoint.
      2. Clear chunks/files + state on the manager's storage.
      3. Re-sync (force=True).

    NOTE: caller must ensure memory_manager.embedding_provider is set, otherwise
    sync() will silently skip embedding generation.

    Args:
        memory_manager: The live MemoryManager whose storage is rebuilt.

    Returns:
        RebuildResult. Every failure is reported as ok=False with ``error``
        set; this function does not raise for probe, clear or sync errors.
    """
    if memory_manager is None:
        return RebuildResult(ok=False, error="memory_manager is None")
    if memory_manager.embedding_provider is None:
        return RebuildResult(ok=False, error="embedding_provider is not initialized")

    # Probe the embedding endpoint BEFORE clearing the index. A bad
    # provider/model/key would otherwise leave the user with an empty index
    # that not even keyword search can serve.
    try:
        memory_manager.embedding_provider.embed_query("ping")
    except Exception as e:
        logger.error(f"[RebuildIndex] embedding probe failed, aborting rebuild: {e}")
        return RebuildResult(ok=False, error=f"embedding endpoint not reachable: {e}")

    db_path = memory_manager.config.get_db_path()
    try:
        removed = clear_index(db_path, storage=memory_manager.storage)
    except Exception as e:
        logger.exception("[RebuildIndex] clear_index failed")
        return RebuildResult(ok=False, error=f"clear failed: {e}")

    # Any failure here, including a RuntimeError raised by sync() itself, is
    # reported once. The sync is never silently re-run.
    try:
        _run_sync_blocking(memory_manager)
    except Exception as e:
        logger.exception("[RebuildIndex] sync failed")
        return RebuildResult(ok=False, removed=removed, error=f"re-embed failed: {e}")

    stats = memory_manager.storage.get_stats()
    chunks = int(stats.get("chunks", 0))
    embedded = int(stats.get("embedded", 0))
    files = int(stats.get("files", 0))

    # sync() degrades to "no embeddings" on batch failure so keyword search
    # still works at startup — but in a /rebuild-index request the user
    # explicitly asked for vectors. Surface that as a failure.
    if chunks > 0 and embedded == 0:
        return RebuildResult(
            ok=False,
            removed=removed,
            chunks=chunks,
            files=files,
            error=(
                "embedding API failed during sync; index now has chunks but no "
                "vectors. Check embedding provider/model/key and retry."
            ),
        )

    return RebuildResult(
        ok=True,
        removed=removed,
        chunks=chunks,
        files=files,
    )
def main() -> int:
    """Command-line entry point; run it from the project root.

    Loads the configuration, builds an embedding provider and a
    MemoryManager, then delegates to rebuild_in_process().

    Returns:
        0 when the rebuild succeeded, 1 otherwise.
    """
    from config import conf, load_config
    from agent.memory import MemoryConfig, MemoryManager

    load_config()
    configured_workspace = conf().get("agent_workspace", "~/cow")
    workspace_root = expand_path(configured_workspace)
    memory_config = MemoryConfig(workspace_root=workspace_root)
    logger.info(f"[RebuildIndex] Workspace: {workspace_root}")
    logger.info(f"[RebuildIndex] Index db: {memory_config.get_db_path()}")

    from bridge.agent_initializer import AgentInitializer

    provider_factory = AgentInitializer(bridge=None, agent_bridge=None)
    embedding_provider = provider_factory._init_embedding_provider(memory_config, session_id=None)
    if embedding_provider is None:
        logger.error(
            "[RebuildIndex] No embedding provider could be initialized. "
            "Check your config.json. Aborting rebuild."
        )
        return 1

    outcome = rebuild_in_process(
        MemoryManager(memory_config, embedding_provider=embedding_provider)
    )
    if outcome.ok:
        logger.info(
            f"[RebuildIndex] Done. removed={outcome.removed}, "
            f"chunks={outcome.chunks}, files={outcome.files}"
        )
        return 0

    logger.error(f"[RebuildIndex] {outcome.error}")
    return 1


if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,47 @@
"""
Embedding-related index utilities.
We don't keep a sidecar state file — the SQLite index is the source of truth
and config.json is the source of intent. The two functions below are the
only things needing on-disk awareness:
detect_index_dim : read the dim of stored vectors (display-only)
cleanup_legacy_state_file: remove old embedding_state.json from earlier
versions; safe no-op when absent.
"""
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Optional, Union
PathLike = Union[str, os.PathLike]
def detect_index_dim(storage) -> Optional[int]:
    """Report the length of one stored embedding (used by /memory status).

    Reads a single non-NULL ``embedding`` value from the ``chunks`` table and
    decodes it as JSON. Returns None when the query fails, nothing is stored,
    the value is not valid JSON, or the decoded value is not a list.
    """
    query = "SELECT embedding FROM chunks WHERE embedding IS NOT NULL LIMIT 1"
    try:
        row = storage.conn.execute(query).fetchone()
    except Exception:
        return None

    if not row:
        return None
    if not row["embedding"]:
        return None

    try:
        decoded = json.loads(row["embedding"])
    except (json.JSONDecodeError, TypeError):
        return None

    if isinstance(decoded, list):
        return len(decoded)
    return None
def cleanup_legacy_state_file(db_path: PathLike) -> None:
    """Delete ``embedding_state.json`` from the directory holding ``db_path``.

    Best-effort: a missing file is fine and any error raised while deleting
    is ignored, so the call can be repeated freely.
    """
    stale_state = Path(db_path).parent.joinpath("embedding_state.json")
    try:
        stale_state.unlink(missing_ok=True)
    except Exception:
        # Deliberately ignored: leftover cleanup must never break the caller.
        pass

View File

@@ -13,7 +13,7 @@ from datetime import datetime, timedelta
from agent.memory.config import MemoryConfig, get_default_memory_config from agent.memory.config import MemoryConfig, get_default_memory_config
from agent.memory.storage import MemoryStorage, MemoryChunk, SearchResult from agent.memory.storage import MemoryStorage, MemoryChunk, SearchResult
from agent.memory.chunker import TextChunker from agent.memory.chunker import TextChunker
from agent.memory.embedding import create_embedding_provider, EmbeddingProvider from agent.memory.embedding import EmbeddingProvider
from agent.memory.summarizer import MemoryFlushManager, create_memory_files_if_needed from agent.memory.summarizer import MemoryFlushManager, create_memory_files_if_needed
@@ -50,49 +50,17 @@ class MemoryManager:
overlap_tokens=self.config.chunk_overlap_tokens overlap_tokens=self.config.chunk_overlap_tokens
) )
# Initialize embedding provider (optional, prefer OpenAI, fallback to LinkAI) # Embedding provider is owned by the caller (agent_initializer is the
self.embedding_provider = None # canonical entry point and handles legacy/explicit + state validation).
if embedding_provider: # When None is passed, memory degrades to keyword-only search instead
# of silently re-initializing a vendor here, which would bypass the
# caller's state checks and risk corrupting the index.
self.embedding_provider = embedding_provider self.embedding_provider = embedding_provider
else:
# Try OpenAI first
try:
api_key = os.environ.get('OPENAI_API_KEY')
api_base = os.environ.get('OPENAI_API_BASE')
if api_key:
self.embedding_provider = create_embedding_provider(
provider="openai",
model=self.config.embedding_model,
api_key=api_key,
api_base=api_base
)
except Exception as e:
from common.log import logger
logger.warning(f"[MemoryManager] OpenAI embedding failed: {e}")
# Fallback to LinkAI
if self.embedding_provider is None:
try:
linkai_key = os.environ.get('LINKAI_API_KEY')
linkai_base = os.environ.get('LINKAI_API_BASE', 'https://api.link-ai.tech')
if linkai_key:
from common.utils import get_cloud_headers
cloud_headers = get_cloud_headers(linkai_key)
cloud_headers.pop("Authorization", None)
self.embedding_provider = create_embedding_provider(
provider="linkai",
model=self.config.embedding_model,
api_key=linkai_key,
api_base=f"{linkai_base}/v1",
extra_headers=cloud_headers,
)
except Exception as e:
from common.log import logger
logger.warning(f"[MemoryManager] LinkAI embedding failed: {e}")
if self.embedding_provider is None: if self.embedding_provider is None:
from common.log import logger from common.log import logger
logger.info(f"[MemoryManager] Memory will work with keyword search only (no vector search)") logger.info(
"[MemoryManager] No embedding provider; memory will use keyword search only"
)
# Initialize memory flush manager # Initialize memory flush manager
workspace_dir = self.config.get_workspace() workspace_dir = self.config.get_workspace()
@@ -153,12 +121,14 @@ class MemoryManager:
if self.config.sync_on_search and self._dirty: if self.config.sync_on_search and self._dirty:
await self.sync() await self.sync()
# Perform vector search (if embedding provider available) from common.log import logger
# Perform vector search (if embedding provider available).
# Failures degrade silently to keyword-only — no exception is raised.
vector_results = [] vector_results = []
if self.embedding_provider: if self.embedding_provider:
try: try:
from common.log import logger query_embedding = self.embedding_provider.embed_query(query)
query_embedding = self.embedding_provider.embed(query)
vector_results = self.storage.search_vector( vector_results = self.storage.search_vector(
query_embedding=query_embedding, query_embedding=query_embedding,
user_id=user_id, user_id=user_id,
@@ -167,17 +137,17 @@ class MemoryManager:
) )
logger.info(f"[MemoryManager] Vector search found {len(vector_results)} results for query: {query}") logger.info(f"[MemoryManager] Vector search found {len(vector_results)} results for query: {query}")
except Exception as e: except Exception as e:
from common.log import logger logger.error(
logger.warning(f"[MemoryManager] Vector search failed: {e}") f"[MemoryManager] Vector search failed, falling back to keyword-only: {e}"
)
# Perform keyword search # Perform keyword search (also runs as fallback when vector failed)
keyword_results = self.storage.search_keyword( keyword_results = self.storage.search_keyword(
query=query, query=query,
user_id=user_id, user_id=user_id,
scopes=scopes, scopes=scopes,
limit=max_results * 2 limit=max_results * 2
) )
from common.log import logger
logger.info(f"[MemoryManager] Keyword search found {len(keyword_results)} results for query: {query}") logger.info(f"[MemoryManager] Keyword search found {len(keyword_results)} results for query: {query}")
# Merge results # Merge results
@@ -269,7 +239,16 @@ class MemoryManager:
async def sync(self, force: bool = False): async def sync(self, force: bool = False):
""" """
Synchronize memory from files Synchronize memory from files.
Two-pass design to amortize embedding HTTP cost:
1. Walk all files, chunk those whose hash changed, collect pending
chunks across files. No embedding calls yet.
2. Run a single embed_batch over the union of pending chunks (the
provider auto-paginates by vendor cap), then persist per-file.
For workspaces with many small files (101 files / ~1 chunk each), this
cuts ~100 HTTP calls down to ~ceil(total_chunks / vendor_cap).
Args: Args:
force: Force full reindex force: Force full reindex
@@ -277,124 +256,140 @@ class MemoryManager:
memory_dir = self.config.get_memory_dir() memory_dir = self.config.get_memory_dir()
workspace_dir = self.config.get_workspace() workspace_dir = self.config.get_workspace()
# Scan MEMORY.md (workspace root) files_to_scan: List[tuple] = [] # (file_path, source, scope, user_id)
memory_file = Path(workspace_dir) / "MEMORY.md" memory_file = Path(workspace_dir) / "MEMORY.md"
if memory_file.exists(): if memory_file.exists():
await self._sync_file(memory_file, "memory", "shared", None) files_to_scan.append((memory_file, "memory", "shared", None))
# Scan memory directory (including daily summaries)
if memory_dir.exists(): if memory_dir.exists():
for file_path in memory_dir.rglob("*.md"): for file_path in memory_dir.rglob("*.md"):
# Skip hidden directories (e.g. .dreams/)
if any(part.startswith('.') for part in file_path.relative_to(workspace_dir).parts): if any(part.startswith('.') for part in file_path.relative_to(workspace_dir).parts):
continue continue
rel_parts = file_path.relative_to(workspace_dir).parts
# Determine scope and user_id from path if "daily" in rel_parts:
rel_path = file_path.relative_to(workspace_dir) if "users" in rel_parts or len(rel_parts) > 3:
parts = rel_path.parts user_idx = rel_parts.index("daily") + 1
user_id = rel_parts[user_idx] if user_idx < len(rel_parts) else None
# Check if it's in daily summary directory
if "daily" in parts:
# Daily summary files
if "users" in parts or len(parts) > 3:
# User-scoped daily summary: memory/daily/{user_id}/2024-01-29.md
user_idx = parts.index("daily") + 1
user_id = parts[user_idx] if user_idx < len(parts) else None
scope = "user" scope = "user"
else: else:
# Shared daily summary: memory/daily/2024-01-29.md
user_id = None user_id = None
scope = "shared" scope = "shared"
elif "users" in parts: elif "users" in rel_parts:
# User-scoped memory user_idx = rel_parts.index("users") + 1
user_idx = parts.index("users") + 1 user_id = rel_parts[user_idx] if user_idx < len(rel_parts) else None
user_id = parts[user_idx] if user_idx < len(parts) else None
scope = "user" scope = "user"
else: else:
# Shared memory
user_id = None user_id = None
scope = "shared" scope = "shared"
files_to_scan.append((file_path, "memory", scope, user_id))
await self._sync_file(file_path, "memory", scope, user_id)
# Scan knowledge directory (structured knowledge wiki)
from config import conf from config import conf
if conf().get("knowledge", True): if conf().get("knowledge", True):
knowledge_dir = Path(workspace_dir) / "knowledge" knowledge_dir = Path(workspace_dir) / "knowledge"
if knowledge_dir.exists(): if knowledge_dir.exists():
for file_path in knowledge_dir.rglob("*.md"): for file_path in knowledge_dir.rglob("*.md"):
await self._sync_file(file_path, "knowledge", "shared", None) files_to_scan.append((file_path, "knowledge", "shared", None))
self._dirty = False # Pass 1: inline chunking + change detection. Inlined (instead of
# calling self._prepare_file_for_sync) so this method does not depend
async def _sync_file( # on any sibling helpers — keeps it robust against partial reloads
self, # where the class object is older than the method's source.
file_path: Path, pending: List[Dict[str, Any]] = []
source: str, workspace_dir_path = self.config.get_workspace()
scope: str, for file_path, source, scope, user_id in files_to_scan:
user_id: Optional[str] try:
):
"""Sync a single file"""
# Compute file hash
content = file_path.read_text(encoding='utf-8') content = file_path.read_text(encoding='utf-8')
except Exception:
continue
file_hash = MemoryStorage.compute_hash(content) file_hash = MemoryStorage.compute_hash(content)
rel_path = str(file_path.relative_to(workspace_dir_path))
# Get relative path if self.storage.get_file_hash(rel_path) == file_hash:
workspace_dir = self.config.get_workspace() continue
rel_path = str(file_path.relative_to(workspace_dir))
# Check if file changed
stored_hash = self.storage.get_file_hash(rel_path)
if stored_hash == file_hash:
return # No changes
# Delete old chunks
self.storage.delete_by_path(rel_path)
# Chunk and embed
chunks = self.chunker.chunk_text(content) chunks = self.chunker.chunk_text(content)
if not chunks: if not chunks:
continue
pending.append({
"file_path": file_path,
"rel_path": rel_path,
"source": source,
"scope": scope,
"user_id": user_id,
"file_hash": file_hash,
"chunks": chunks,
"texts": [c.text for c in chunks],
})
if not pending:
self._dirty = False
return return
texts = [chunk.text for chunk in chunks] # Pass 2: single batched embed across all pending chunks.
if self.embedding_provider: # CRITICAL: never touch the index until we hold valid embeddings.
embeddings = self.embedding_provider.embed_batch(texts) # If embed_batch fails, leave the existing index intact (chunks +
else: # file_hash) so the next sync will retry the same files. Writing
embeddings = [None] * len(texts) # NULL embeddings + updating file_hash here would mark the file as
# "successfully synced" and silently strand it without vectors.
all_texts: List[str] = []
for entry in pending:
all_texts.extend(entry["texts"])
# Create memory chunks if not self.embedding_provider:
# No provider configured at all (legacy keyword-only). Persist
# chunks without embeddings — this is the user's intent.
all_embeddings: List[Optional[List[float]]] = [None] * len(all_texts)
else:
try:
all_embeddings = self.embedding_provider.embed_batch(all_texts)
except Exception as e:
from common.log import logger
logger.error(
f"[MemoryManager] Batch embedding failed for {len(all_texts)} "
f"chunks across {len(pending)} files: {e}. "
f"Index left untouched; will retry on next sync."
)
# Bail before touching storage. self._dirty stays True so
# callers know there is pending work.
return
# Pass 3: inline persist — same self-contained reasoning as Pass 1.
cursor = 0
for entry in pending:
n = len(entry["texts"])
entry_embeddings = all_embeddings[cursor:cursor + n]
cursor += n
rel_path = entry["rel_path"]
self.storage.delete_by_path(rel_path)
memory_chunks = [] memory_chunks = []
for chunk, embedding in zip(chunks, embeddings): for chunk, embedding in zip(entry["chunks"], entry_embeddings):
chunk_id = self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line) chunk_id = self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line)
chunk_hash = MemoryStorage.compute_hash(chunk.text) chunk_hash = MemoryStorage.compute_hash(chunk.text)
memory_chunks.append(MemoryChunk( memory_chunks.append(MemoryChunk(
id=chunk_id, id=chunk_id,
user_id=user_id, user_id=entry["user_id"],
scope=scope, scope=entry["scope"],
source=source, source=entry["source"],
path=rel_path, path=rel_path,
start_line=chunk.start_line, start_line=chunk.start_line,
end_line=chunk.end_line, end_line=chunk.end_line,
text=chunk.text, text=chunk.text,
embedding=embedding, embedding=embedding,
hash=chunk_hash, hash=chunk_hash,
metadata=None metadata=None,
)) ))
# Save
self.storage.save_chunks_batch(memory_chunks) self.storage.save_chunks_batch(memory_chunks)
stat = entry["file_path"].stat()
# Update file metadata
stat = file_path.stat()
self.storage.update_file_metadata( self.storage.update_file_metadata(
path=rel_path, path=rel_path,
source=source, source=entry["source"],
file_hash=file_hash, file_hash=entry["file_hash"],
mtime=int(stat.st_mtime), mtime=int(stat.st_mtime),
size=stat.st_size size=stat.st_size,
) )
self._dirty = False
def flush_memory( def flush_memory(
self, self,
messages: list, messages: list,

View File

@@ -0,0 +1,14 @@
"""
Legacy entry-point shim, kept so the old invocation keeps working:

    python -m agent.memory.rebuild_index

The implementation itself lives in agent.memory.embedding.rebuild; this
module only forwards to its main(). Prefer `/memory rebuild-index` in chat
going forward.
"""
from agent.memory.embedding.rebuild import main

if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())

View File

@@ -144,44 +144,36 @@ class MemoryStorage:
ON chunks(path, hash) ON chunks(path, hash)
""") """)
# Create FTS5 virtual table for keyword search (only if supported) # Create FTS5 virtual table + triggers (only if supported).
# Self-heal: if the previous process crashed mid-rebuild and left
# triggers pointing at a missing chunks_fts (or vice versa), wipe
# both sides and recreate cleanly. Otherwise next chunks INSERT
# will fail with "no such table: chunks_fts".
if self.fts5_available: if self.fts5_available:
# Use default unicode61 tokenizer (stable and compatible) if self._fts5_state_inconsistent():
# For CJK support, we'll use LIKE queries as fallback from common.log import logger
self.conn.execute(""" logger.warning(
CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5( "[MemoryStorage] FTS5 state inconsistent (triggers/table mismatch). "
text, "Resetting chunks_fts to recover."
id UNINDEXED,
user_id UNINDEXED,
path UNINDEXED,
source UNINDEXED,
scope UNINDEXED,
content='chunks',
content_rowid='rowid'
) )
""") self.conn.execute("DROP TRIGGER IF EXISTS chunks_ai")
self.conn.execute("DROP TRIGGER IF EXISTS chunks_ad")
self.conn.execute("DROP TRIGGER IF EXISTS chunks_au")
self.conn.execute("DROP TABLE IF EXISTS chunks_fts")
self.conn.commit()
self._create_fts5_objects()
# Create triggers to keep FTS in sync # Probe FTS5 shadow tables. The schema may be intact but the
self.conn.execute(""" # internal _data/_idx/_docsize blob can still be corrupt — that
CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN # surfaces as "database disk image is malformed" on bm25 / MATCH.
INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope) # We rebuild from the chunks table when that happens; data isn't
VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope); # lost because chunks (the content table) is the source of truth.
END if self._fts5_shadow_corrupt():
""") from common.log import logger
logger.warning(
self.conn.execute(""" "[MemoryStorage] FTS5 shadow tables corrupt; rebuilding from chunks."
CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN )
DELETE FROM chunks_fts WHERE rowid = old.rowid; self._rebuild_fts5_from_chunks()
END
""")
self.conn.execute("""
CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN
UPDATE chunks_fts SET text = new.text, id = new.id,
user_id = new.user_id, path = new.path, source = new.source, scope = new.scope
WHERE rowid = new.rowid;
END
""")
# Create files metadata table # Create files metadata table
self.conn.execute(""" self.conn.execute("""
@@ -197,6 +189,115 @@ class MemoryStorage:
self.conn.commit() self.conn.commit()
def _fts5_state_inconsistent(self) -> bool:
"""Detect a half-broken FTS5 setup (e.g. trigger exists but table doesn't)."""
try:
row = self.conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='chunks_fts'"
).fetchone()
table_exists = row is not None
row = self.conn.execute(
"SELECT COUNT(*) FROM sqlite_master WHERE type='trigger' "
"AND name IN ('chunks_ai','chunks_ad','chunks_au')"
).fetchone()
trigger_count = int(row[0]) if row else 0
except Exception:
return False
# Healthy = both present (3 triggers + table) or both absent.
return table_exists != (trigger_count > 0)
def _create_fts5_objects(self):
"""Create chunks_fts virtual table and the 3 sync triggers.
Idempotent: uses IF NOT EXISTS. Caller must hold self.conn.
"""
self.conn.execute("""
CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
text,
id UNINDEXED,
user_id UNINDEXED,
path UNINDEXED,
source UNINDEXED,
scope UNINDEXED,
content='chunks',
content_rowid='rowid'
)
""")
self.conn.execute("""
CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN
INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
END
""")
self.conn.execute("""
CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN
DELETE FROM chunks_fts WHERE rowid = old.rowid;
END
""")
self.conn.execute("""
CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN
UPDATE chunks_fts SET text = new.text, id = new.id,
user_id = new.user_id, path = new.path,
source = new.source, scope = new.scope
WHERE rowid = new.rowid;
END
""")
def reset_fts5(self):
    """Drop and recreate chunks_fts and its triggers atomically.

    Used by rebuild_index to recover from FTS5 shadow-table corruption
    (bm25/ORDER BY rank may raise "database disk image is malformed"
    even when raw MATCH still works).

    Triggers are dropped before the table; otherwise the next chunks
    INSERT/DELETE on the existing connection would hit
    "no such table: chunks_fts".

    Python's sqlite3 module only opens an implicit transaction before DML,
    so without an explicit BEGIN each DROP/CREATE would be committed on its
    own and a failure half-way would leave the objects dropped. When no
    transaction is open this method therefore starts one itself and rolls
    it back if anything fails. When the caller already has a transaction
    open, the statements join it and it is committed at the end.

    Does nothing when FTS5 is not available.

    Raises:
        Whatever the DROP statements or _create_fts5_objects() raise; the
        exception is re-raised after the rollback.
    """
    if not self.fts5_available:
        return

    conn = self.conn
    owns_transaction = not conn.in_transaction
    if owns_transaction:
        conn.execute("BEGIN")

    try:
        for trigger_name in ("chunks_ai", "chunks_ad", "chunks_au"):
            conn.execute(f"DROP TRIGGER IF EXISTS {trigger_name}")
        conn.execute("DROP TABLE IF EXISTS chunks_fts")
        self._create_fts5_objects()
    except Exception:
        # Only undo work we started; never roll back a caller's transaction.
        if owns_transaction and conn.in_transaction:
            conn.execute("ROLLBACK")
        raise

    if owns_transaction:
        # Explicit SQL so the transaction opened above is always closed,
        # regardless of how the connection's commit() is configured.
        if conn.in_transaction:
            conn.execute("COMMIT")
    else:
        conn.commit()
def _fts5_shadow_corrupt(self) -> bool:
"""Probe whether bm25 over chunks_fts errors out at startup.
Schema (table + triggers) can be intact while the underlying
FTS5 shadow blobs are malformed — typically because the previous
process crashed mid-write or wrote with a different SQLite build.
A cheap MATCH probe surfaces it immediately."""
try:
self.conn.execute(
"SELECT bm25(chunks_fts) FROM chunks_fts WHERE chunks_fts MATCH 'a' LIMIT 1"
).fetchone()
return False
except sqlite3.DatabaseError as e:
msg = str(e).lower()
return "malformed" in msg or "corrupt" in msg
except Exception:
# Any other error (e.g. table missing) is handled by the
# state-inconsistent path; treat as healthy here.
return False
def _rebuild_fts5_from_chunks(self):
"""Drop FTS5, recreate it, then INSERT every row from chunks.
Safe data-wise: chunks (the content table) is the source of truth.
Done in one transaction so a crash leaves either fully old or fully
new state, not a partial rebuild.
"""
# Reset schema first; this clears any malformed shadow blobs.
self.reset_fts5()
# Re-feed content. Triggers handle future writes automatically.
self.conn.execute("""
INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
SELECT rowid, text, id, user_id, path, source, scope FROM chunks
""")
self.conn.commit()
def save_chunk(self, chunk: MemoryChunk): def save_chunk(self, chunk: MemoryChunk):
"""Save a memory chunk""" """Save a memory chunk"""
self.conn.execute(""" self.conn.execute("""
@@ -284,8 +385,21 @@ class MemoryStorage:
rows = self.conn.execute(query, params).fetchall() rows = self.conn.execute(query, params).fetchall()
# Calculate cosine similarity # Calculate cosine similarity. We probe the first row's dim to fail
# loudly on a query/index dim mismatch — otherwise every doc would
# score 0 silently, leaving the user wondering why search broke.
results = [] results = []
query_dim = len(query_embedding)
if rows:
first = json.loads(rows[0]['embedding'])
if isinstance(first, list) and len(first) != query_dim:
raise ValueError(
f"Embedding dim mismatch: query is {query_dim}-dim but "
f"index stores {len(first)}-dim vectors. The configured "
f"embedding model differs from the one that built the "
f"index — run /memory rebuild-index to re-embed."
)
for row in rows: for row in rows:
embedding = json.loads(row['embedding']) embedding = json.loads(row['embedding'])
similarity = self._cosine_similarity(query_embedding, embedding) similarity = self._cosine_similarity(query_embedding, embedding)
@@ -321,26 +435,23 @@ class MemoryStorage:
Keyword search using FTS5 + LIKE fallback Keyword search using FTS5 + LIKE fallback
Strategy: Strategy:
1. If FTS5 available: Try FTS5 search first (good for English and word-based languages) 1. If FTS5 available and healthy: try FTS5 first
2. If no FTS5 or no results and query contains CJK: Use LIKE search 2. Always fall back to LIKE for CJK queries
3. If FTS5 fails OR returns empty for non-CJK, also try LIKE so a
broken FTS5 shadow table doesn't silently kill keyword search.
""" """
if scopes is None: if scopes is None:
scopes = ["shared"] scopes = ["shared"]
if user_id: if user_id:
scopes.append("user") scopes.append("user")
# Try FTS5 search first (if available)
if self.fts5_available: if self.fts5_available:
fts_results = self._search_fts5(query, user_id, scopes, limit) fts_results = self._search_fts5(query, user_id, scopes, limit)
if fts_results: if fts_results:
return fts_results return fts_results
# Fallback to LIKE search (always for CJK, or if FTS5 not available)
if not self.fts5_available or MemoryStorage._contains_cjk(query):
return self._search_like(query, user_id, scopes, limit) return self._search_like(query, user_id, scopes, limit)
return []
def _search_fts5( def _search_fts5(
self, self,
query: str, query: str,
@@ -394,7 +505,11 @@ class MemoryStorage:
) )
for row in rows for row in rows
] ]
except Exception: except Exception as e:
from common.log import logger
logger.error(
f"[MemoryStorage] FTS5 search failed (caller will fall back to LIKE): {e}"
)
return [] return []
def _search_like( def _search_like(
@@ -404,21 +519,28 @@ class MemoryStorage:
scopes: List[str], scopes: List[str],
limit: int limit: int
) -> List[SearchResult]: ) -> List[SearchResult]:
"""LIKE-based search for CJK characters""" """LIKE-based search.
Used as the keyword-search fallback when FTS5 is unavailable, fails,
or returns empty. Supports both CJK runs and ASCII word tokens so it
can serve as a true safety net for any query.
"""
import re import re
# Extract CJK words (2+ characters) # CJK runs (2+ chars) + ASCII word tokens (3+ chars to avoid noise)
cjk_words = re.findall(r'[\u4e00-\u9fff]{2,}', query) cjk_words = re.findall(r'[\u4e00-\u9fff]{2,}', query)
if not cjk_words: ascii_words = [t for t in re.findall(r'[A-Za-z0-9_]+', query) if len(t) >= 3]
words = cjk_words + ascii_words
if not words:
return [] return []
scope_placeholders = ','.join('?' * len(scopes)) scope_placeholders = ','.join('?' * len(scopes))
# Build LIKE conditions for each word # Build LIKE conditions for each word (case-insensitive for ASCII)
like_conditions = [] like_conditions = []
params = [] params = []
for word in cjk_words: for word in words:
like_conditions.append("text LIKE ?") like_conditions.append("LOWER(text) LIKE ?")
params.append(f'%{word}%') params.append(f'%{word.lower()}%')
where_clause = ' OR '.join(like_conditions) where_clause = ' OR '.join(like_conditions)
params.extend(scopes) params.extend(scopes)
@@ -455,7 +577,9 @@ class MemoryStorage:
) )
for row in rows for row in rows
] ]
except Exception: except Exception as e:
from common.log import logger
logger.error(f"[MemoryStorage] LIKE search failed: {e}")
return [] return []
def delete_by_path(self, path: str): def delete_by_path(self, path: str):
@@ -490,9 +614,14 @@ class MemoryStorage:
SELECT COUNT(*) as cnt FROM files SELECT COUNT(*) as cnt FROM files
""").fetchone()['cnt'] """).fetchone()['cnt']
embedded_count = self.conn.execute("""
SELECT COUNT(*) as cnt FROM chunks WHERE embedding IS NOT NULL
""").fetchone()['cnt']
return { return {
'chunks': chunks_count, 'chunks': chunks_count,
'files': files_count 'files': files_count,
'embedded': embedded_count,
} }
def close(self): def close(self):

View File

@@ -17,6 +17,10 @@ from common.utils import expand_path
# Module-level lock to serialize scheduler init across concurrent sessions # Module-level lock to serialize scheduler init across concurrent sessions
_scheduler_init_lock = threading.Lock() _scheduler_init_lock = threading.Lock()
# Track whether the embedding model log has been printed in this process,
# so we avoid spamming it once per session.
_embedding_logged: bool = False
class AgentInitializer: class AgentInitializer:
""" """
@@ -272,52 +276,19 @@ class AgentInitializer:
memory_tools = [] memory_tools = []
try: try:
from agent.memory import MemoryManager, MemoryConfig, create_embedding_provider from agent.memory import MemoryManager, MemoryConfig
from agent.tools import MemorySearchTool, MemoryGetTool from agent.tools import MemorySearchTool, MemoryGetTool
from config import conf from config import conf
# Initialize embedding provider (prefer OpenAI, fallback to LinkAI)
embedding_provider = None
openai_api_key = conf().get("open_ai_api_key", "")
openai_api_base = conf().get("open_ai_api_base", "")
if openai_api_key and openai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]:
try:
embedding_provider = create_embedding_provider(
provider="openai",
model="text-embedding-3-small",
api_key=openai_api_key,
api_base=openai_api_base or "https://api.openai.com/v1"
)
if session_id is None:
logger.info("[AgentInitializer] OpenAI embedding initialized")
except Exception as e:
logger.warning(f"[AgentInitializer] OpenAI embedding failed: {e}")
if embedding_provider is None:
linkai_api_key = conf().get("linkai_api_key", "") or os.environ.get("LINKAI_API_KEY", "")
linkai_api_base = conf().get("linkai_api_base", "https://api.link-ai.tech")
if linkai_api_key and linkai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]:
try:
embedding_provider = create_embedding_provider(
provider="linkai",
model="text-embedding-3-small",
api_key=linkai_api_key,
api_base=f"{linkai_api_base}/v1"
)
if session_id is None:
logger.info("[AgentInitializer] LinkAI embedding initialized (fallback)")
except Exception as e:
logger.warning(f"[AgentInitializer] LinkAI embedding failed: {e}")
# Create memory manager
memory_config = MemoryConfig(workspace_root=workspace_root) memory_config = MemoryConfig(workspace_root=workspace_root)
memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
# Sync memory embedding_provider = self._init_embedding_provider(
memory_config, session_id=session_id
)
memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
self._sync_memory(memory_manager, session_id) self._sync_memory(memory_manager, session_id)
# Create memory tools
memory_tools = [ memory_tools = [
MemorySearchTool(memory_manager), MemorySearchTool(memory_manager),
MemoryGetTool(memory_manager) MemoryGetTool(memory_manager)
@@ -331,6 +302,190 @@ class AgentInitializer:
return memory_manager, memory_tools return memory_manager, memory_tools
def _init_embedding_provider(self, memory_config, session_id: Optional[str] = None):
    """
    Choose and build the embedding provider used by memory.

    Dispatches on the ``embedding_provider`` config value:

    A. Empty / unset: the legacy auto-init path (OpenAI first, then
       LinkAI), so installations that never set the option behave as
       before.
    B. Set: the explicit path builds exactly the requested vendor.

    When the stored index was built with a different embedding dimension
    than the provider returned here, the storage layer's vector search
    rejects the query with an error pointing at /memory rebuild-index;
    the memory manager logs that failure and keyword search still runs.

    Args:
        memory_config: Memory configuration, forwarded to the explicit path.
        session_id: Optional session identifier, forwarded to both paths.

    Returns:
        Whatever the selected path returns: a provider instance, or None
        when no provider could be initialized.
    """
    from config import conf

    # Normalize so that None, "", and values with stray whitespace or
    # capitals are all handled the same way.
    explicit_provider = (conf().get("embedding_provider") or "").strip().lower()
    if not explicit_provider:
        return self._init_embedding_provider_legacy(session_id=session_id)
    return self._init_embedding_provider_explicit(
        memory_config, explicit_provider, session_id=session_id,
    )
def _init_embedding_provider_legacy(self, session_id: Optional[str] = None):
    """Legacy auto-init path: try OpenAI, then fall back to LinkAI.

    Both vendors are asked for the ``text-embedding-3-small`` model. A
    vendor is attempted only when its API key is configured and is not one
    of the placeholder values; a construction failure is logged as a
    warning and the next vendor is tried.

    Args:
        session_id: Accepted for signature parity with the explicit path;
            not used here.

    Returns:
        The first provider that was constructed, or None when neither
        vendor could be initialized.
    """
    from agent.memory import create_embedding_provider
    from config import conf

    global _embedding_logged

    placeholder_keys = ["", "YOUR API KEY", "YOUR_API_KEY"]
    provider = None
    model_label = None

    # --- Attempt 1: OpenAI -------------------------------------------------
    openai_key = conf().get("open_ai_api_key", "")
    openai_base = conf().get("open_ai_api_base", "")
    if openai_key and openai_key not in placeholder_keys:
        model = "text-embedding-3-small"
        try:
            provider = create_embedding_provider(
                provider="openai",
                model=model,
                api_key=openai_key,
                api_base=openai_base or "https://api.openai.com/v1"
            )
        except Exception as e:
            logger.warning(f"[AgentInitializer] OpenAI embedding failed: {e}")
        else:
            model_label = f"openai/{model}"

    # --- Attempt 2: LinkAI (only when OpenAI produced nothing) --------------
    if provider is None:
        # The LinkAI key may come from config or from the environment.
        linkai_key = conf().get("linkai_api_key", "") or os.environ.get("LINKAI_API_KEY", "")
        linkai_base = conf().get("linkai_api_base", "https://api.link-ai.tech")
        if linkai_key and linkai_key not in placeholder_keys:
            model = "text-embedding-3-small"
            try:
                provider = create_embedding_provider(
                    provider="linkai",
                    model=model,
                    api_key=linkai_key,
                    api_base=f"{linkai_base}/v1"
                )
            except Exception as e:
                logger.warning(f"[AgentInitializer] LinkAI embedding failed: {e}")
            else:
                model_label = f"linkai/{model}"

    # Announce the chosen model once per process, not once per session.
    if provider is not None and model_label and not _embedding_logged:
        logger.info(
            f"[AgentInitializer] Embedding model in use: {model_label} "
            f"(dim={provider.dimensions})"
        )
        _embedding_logged = True

    return provider
def _init_embedding_provider_explicit(
    self,
    memory_config,
    provider_key: str,
    session_id: Optional[str] = None,
):
    """Build the embedding provider named by ``embedding_provider``.

    Model and dimension come from the ``embedding_model`` /
    ``embedding_dimensions`` config values when set, otherwise from the
    vendor's defaults in EMBEDDING_VENDORS. Every failure is logged as an
    error and reported by returning None, which leaves memory in
    keyword-only mode.

    This method does not compare the chosen dimension with an existing
    index; switching vendors is expected to be followed by
    /memory rebuild-index.

    Args:
        memory_config: Memory configuration; not used in this method.
        provider_key: Normalized vendor key, looked up in EMBEDDING_VENDORS.
        session_id: Optional session identifier; not used in this method.

    Returns:
        The constructed provider, or None when the vendor is unknown, its
        API key is missing, or construction raised.
    """
    from agent.memory import create_embedding_provider
    from agent.memory.embedding import EMBEDDING_VENDORS
    from config import conf

    global _embedding_logged

    vendor = EMBEDDING_VENDORS.get(provider_key)
    if vendor is None:
        logger.error(
            f"[AgentInitializer] Unknown embedding_provider '{provider_key}'. "
            f"Supported: {sorted(EMBEDDING_VENDORS.keys())}. "
            f"Memory will run in keyword-only mode."
        )
        return None

    resolved_key = self._resolve_embedding_api_key(provider_key)
    resolved_base = self._resolve_embedding_api_base(provider_key, vendor["default_base_url"])
    if not resolved_key:
        logger.error(
            f"[AgentInitializer] embedding_provider='{provider_key}' is set but its "
            f"API key is missing. Memory will run in keyword-only mode."
        )
        return None

    model = (conf().get("embedding_model") or "").strip() or vendor["default_model"]

    # A missing, zero, negative or non-numeric setting means "vendor default".
    try:
        requested_dim = int(conf().get("embedding_dimensions") or 0)
    except (TypeError, ValueError):
        requested_dim = 0
    dim = requested_dim if requested_dim > 0 else vendor["default_dimensions"]

    try:
        provider = create_embedding_provider(
            provider=provider_key,
            model=model,
            api_key=resolved_key,
            api_base=resolved_base,
            dimensions=dim,
        )
    except Exception as e:
        logger.error(
            f"[AgentInitializer] Failed to init embedding provider "
            f"'{provider_key}/{model}': {e}"
        )
        return None

    # Announce the chosen model once per process, not once per session.
    if not _embedding_logged:
        logger.info(
            f"[AgentInitializer] Embedding model in use: "
            f"{provider_key}/{model} (dim={provider.dimensions})"
        )
        _embedding_logged = True

    return provider
@staticmethod
def _resolve_embedding_api_key(provider_key: str) -> str:
    """Look up the configured API key for an explicit embedding provider.

    Args:
        provider_key: Vendor key such as "openai" or "dashscope".

    Returns:
        The configured key, or "" when the vendor has no known config
        field, the field is unset/empty, or it still holds a placeholder.
    """
    from config import conf

    # Vendor key -> name of the config.json field holding its API key.
    config_fields = {
        "openai": "open_ai_api_key",
        "linkai": "linkai_api_key",
        "dashscope": "dashscope_api_key",
        "doubao": "ark_api_key",
        "zhipu": "zhipu_ai_api_key",
    }

    config_field = config_fields.get(provider_key)
    if not config_field:
        return ""

    api_key = conf().get(config_field, "") or ""
    # Template placeholders count as "not configured".
    is_placeholder = api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]
    return "" if is_placeholder else api_key
@staticmethod
def _resolve_embedding_api_base(provider_key: str, default_base: str) -> str:
    """Look up the API base URL for an explicit embedding provider.

    Args:
        provider_key: Vendor key such as "openai" or "linkai".
        default_base: URL to use when nothing is configured for the vendor.

    Returns:
        The configured base URL (whitespace-stripped), or ``default_base``
        when the vendor has no base-URL config field or the field is
        empty. For "linkai" a configured URL that does not already end in
        "/v1" gets "/v1" appended.
    """
    from config import conf

    # Vendor key -> name of the config.json field holding its base URL.
    # Vendors absent from this map always use default_base.
    config_fields = {
        "openai": "open_ai_api_base",
        "linkai": "linkai_api_base",
        "doubao": "ark_base_url",
        "zhipu": "zhipu_ai_api_base",
    }

    config_field = config_fields.get(provider_key)
    configured = (conf().get(config_field) or "").strip() if config_field else ""
    if not configured:
        return default_base

    without_slash = configured.rstrip("/")
    if provider_key == "linkai" and not without_slash.endswith("/v1"):
        return f"{without_slash}/v1"
    return configured
def _sync_memory(self, memory_manager, session_id: Optional[str] = None): def _sync_memory(self, memory_manager, session_id: Optional[str] = None):
"""Sync memory database""" """Sync memory database"""
try: try:

View File

@@ -26,7 +26,8 @@ Commands:
knowledge Manage knowledge base. knowledge Manage knowledge base.
install-browser Install browser tool (Playwright + Chromium). install-browser Install browser tool (Playwright + Chromium).
Tip: You can also send /help, /skill list, etc. in agent chat.""" Tip: Memory index management lives in chat — send /memory status or
/memory rebuild-index to the running agent."""
class CowCLI(click.Group): class CowCLI(click.Group):

View File

@@ -100,6 +100,10 @@ available_setting = {
"dashscope_api_key": "", "dashscope_api_key": "",
# Google Gemini Api Key # Google Gemini Api Key
"gemini_api_key": "", "gemini_api_key": "",
# Embedding 模型设置
"embedding_provider": "", # 显式指定厂商: openai / linkai / dashscope / doubao / zhipu (与 bot_type 命名一致)
"embedding_model": "", # 留空使用厂商默认 model
"embedding_dimensions": 0, # 留空/0 使用厂商默认维度 (推荐统一 1024)
# 语音设置 # 语音设置
"speech_recognition": True, # 是否开启语音识别 "speech_recognition": True, # 是否开启语音识别
"group_speech_recognition": False, # 是否开启群组语音识别 "group_speech_recognition": False, # 是否开启群组语音识别

View File

@@ -62,10 +62,25 @@ class CowCliPlugin(Plugin):
content = e_context["context"].content.strip() content = e_context["context"].content.strip()
parsed = self._parse_command(content) parsed = self._parse_command(content)
if not parsed: if parsed is None:
return return
cmd, args = parsed cmd, args = parsed
if cmd not in KNOWN_COMMANDS:
# Slash-prefixed near-miss: looks like a typo of a real command.
# Intercept with a hint so we don't burn an LLM round on "/momory".
suggestion = self._suggest_command(cmd)
if suggestion is None:
return
hint = f"未知命令: /{cmd}"
if suggestion:
hint += f"\n你是不是想输入 /{suggestion} ?"
hint += "\n发送 /help 查看全部命令。"
e_context["reply"] = Reply(ReplyType.TEXT, hint)
e_context.action = EventAction.BREAK_PASS
return
logger.info(f"[CowCli] intercepted command: {cmd} {args}") logger.info(f"[CowCli] intercepted command: {cmd} {args}")
result = self._dispatch(cmd, args, e_context) result = self._dispatch(cmd, args, e_context)
@@ -82,29 +97,81 @@ class CowCliPlugin(Plugin):
cow <command> [args...] e.g. "cow skill list" cow <command> [args...] e.g. "cow skill list"
/<command> [args...] e.g. "/skill list" /<command> [args...] e.g. "/skill list"
Returns (command, args_string) or None if not a cow command. Returns:
""" - (command, args_string): when the message looks like a command.
parts = None 'command' may NOT be in KNOWN_COMMANDS; caller should validate.
- None: when the message is not command-like at all.
We deliberately return parsed-but-unknown for the slash form so the
caller can offer a typo hint instead of silently passing the message
through to the agent.
"""
if content.startswith("/"): if content.startswith("/"):
rest = content[1:].strip() rest = content[1:].strip()
if rest: if not rest:
parts = rest.split(None, 1)
elif content.startswith("cow "):
rest = content[4:].strip()
if rest:
parts = rest.split(None, 1)
if not parts:
return None return None
parts = rest.split(None, 1)
cmd = parts[0].lower() cmd = parts[0].lower()
if cmd not in KNOWN_COMMANDS:
return None
args = parts[1] if len(parts) > 1 else "" args = parts[1] if len(parts) > 1 else ""
return cmd, args return cmd, args
if content.startswith("cow "):
rest = content[4:].strip()
if not rest:
return None
parts = rest.split(None, 1)
cmd = parts[0].lower()
if cmd not in KNOWN_COMMANDS:
# 'cow xxx' that isn't a command — don't intercept (could be
# natural language like "cow xxx 怎么样").
return None
args = parts[1] if len(parts) > 1 else ""
return cmd, args
return None
@staticmethod
def _suggest_command(cmd: str) -> "str | None":
    """
    Suggest the known command that *cmd* most likely mistypes.

    Heuristic: a candidate qualifies when it starts with the same character
    as cmd and is at most one single-character insert, delete or substitute
    away from it. Inputs shorter than 3 characters are never matched.

    Args:
        cmd: The (already lower-cased) command word without its prefix.

    Returns:
        - the first qualifying entry of KNOWN_COMMANDS, in its iteration order;
        - None when cmd is shorter than 3 characters or nothing qualifies
          (the callers in this file treat None as "do not intercept");
        - "" only when cmd itself is empty.
    """
    if not cmd:
        return ""
    if len(cmd) < 3:
        return None

    def within_one_edit(a: str, b: str) -> bool:
        """Return True when a and b differ by at most one single-char edit."""
        if a == b:
            return True
        if abs(len(a) - len(b)) > 1:
            return False
        if len(a) == len(b):
            # Same length: only a single substitution is allowed.
            return sum(x != y for x, y in zip(a, b)) <= 1
        # Lengths differ by exactly one: skip the common prefix, drop one
        # character from the longer string, and require the tails to match.
        shorter, longer = sorted((a, b), key=len)
        k = 0
        while k < len(shorter) and shorter[k] == longer[k]:
            k += 1
        return shorter[k:] == longer[k + 1:]

    initial = cmd[0]
    for known in KNOWN_COMMANDS:
        if known[0] == initial and within_one_edit(cmd, known):
            return known
    return None
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Command dispatch # Command dispatch
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@@ -113,12 +180,23 @@ class CowCliPlugin(Plugin):
"""Execute a cow/slash command string without a channel context. """Execute a cow/slash command string without a channel context.
Used by cloud on_chat to intercept commands before the agent runs. Used by cloud on_chat to intercept commands before the agent runs.
Returns None when *query* is not a recognised command. Returns None when *query* is not command-like at all (e.g. natural
language). For slash-prefixed typos returns a hint string so the
caller still short-circuits the agent round.
""" """
parsed = self._parse_command(query.strip()) parsed = self._parse_command(query.strip())
if not parsed: if parsed is None:
return None return None
cmd, args = parsed cmd, args = parsed
if cmd not in KNOWN_COMMANDS:
suggestion = self._suggest_command(cmd)
if suggestion is None:
return None
hint = f"未知命令: /{cmd}"
if suggestion:
hint += f"\n你是不是想输入 /{suggestion} ?"
hint += "\n发送 /help 查看全部命令。"
return hint
return self._dispatch(cmd, args, e_context=None, session_id=session_id) return self._dispatch(cmd, args, e_context=None, session_id=session_id)
def _dispatch(self, cmd: str, args: str, e_context: EventContext, session_id: str = "") -> str: def _dispatch(self, cmd: str, args: str, e_context: EventContext, session_id: str = "") -> str:
@@ -158,6 +236,8 @@ class CowCliPlugin(Plugin):
" /config 查看当前配置", " /config 查看当前配置",
" /config <key> 查看某项配置", " /config <key> 查看某项配置",
" /config <key> <val> 修改配置", " /config <key> <val> 修改配置",
" /memory status 查看记忆索引状态",
" /memory rebuild-index 清空并重建向量索引 (切换 embedding 模型后必须执行)",
" /memory dream [N] 手动触发记忆蒸馏 (整理近N天, 默认3, 最多30)", " /memory dream [N] 手动触发记忆蒸馏 (整理近N天, 默认3, 最多30)",
" /knowledge 查看知识库统计", " /knowledge 查看知识库统计",
" /knowledge list 查看知识库文件树", " /knowledge list 查看知识库文件树",
@@ -907,10 +987,23 @@ class CowCliPlugin(Plugin):
if len(parts) > 1 and parts[1].isdigit(): if len(parts) > 1 and parts[1].isdigit():
days = max(1, min(int(parts[1]), 30)) days = max(1, min(int(parts[1]), 30))
return self._memory_dream(days, e_context, session_id) return self._memory_dream(days, e_context, session_id)
elif sub in ("rebuild-index", "rebuild_index", "rebuild"):
return self._memory_rebuild_index(e_context, session_id)
elif sub in ("status", "info", ""):
if sub == "":
return self._memory_help()
return self._memory_status()
else: else:
return self._memory_help()
@staticmethod
def _memory_help() -> str:
return ( return (
"🧠 记忆管理\n\n"
"用法: /memory <子命令>\n\n" "用法: /memory <子命令>\n\n"
"子命令:\n" "子命令:\n"
" status 查看索引状态 (provider / model / dim / chunks)\n"
" rebuild-index 清空并重建向量索引 (切换 embedding 模型后必须执行)\n"
" dream [N] 手动触发记忆蒸馏 (整理近N天, 默认3, 最多30)" " dream [N] 手动触发记忆蒸馏 (整理近N天, 默认3, 最多30)"
) )
@@ -963,6 +1056,140 @@ class CowCliPlugin(Plugin):
logger.warning(f"[CowCli] /memory dream sync failed: {e}") logger.warning(f"[CowCli] /memory dream sync failed: {e}")
return f"❌ 记忆蒸馏失败: {e}" return f"❌ 记忆蒸馏失败: {e}"
def _memory_status(self) -> str:
    """
    Build a human-readable report of the current memory index.

    The report lists the index DB path, file/chunk counts, and the active
    embedding provider, model and dimension. Health warnings (chunks without
    vectors, stored-vector dimension differing from the provider's) are only
    added when `embedding_provider` is explicitly set in the config and a
    provider instance exists.

    Returns:
        The multi-line report text. When no agent / memory manager is
        available yet, a short notice asking the user to send a normal
        message first.
    """
    from agent.memory.embedding import detect_index_dim
    from config import conf

    agent = self._get_agent("")
    memory_manager = agent.memory_manager if agent else None

    lines = ["🧠 记忆索引状态", ""]
    if not memory_manager:
        lines.append(" ⚠️ Agent 尚未初始化,先发一条普通消息再试")
        return "\n".join(lines)

    stats = memory_manager.storage.get_stats()
    db_path = memory_manager.config.get_db_path()
    embedded = stats.get('embedded', 0)
    chunks = stats.get('chunks', 0)
    lines.append(f" 索引DB : {db_path}")
    lines.append(f" Files : {stats.get('files', 0)}")
    lines.append(f" Chunks : {chunks} (embedded: {embedded})")
    lines.append("")

    # Read the config once; an empty value means the user did not explicitly
    # pick a provider (reported as "(legacy)").
    configured_provider = (conf().get("embedding_provider") or "").strip()
    provider_obj = memory_manager.embedding_provider

    provider_dim = None
    if provider_obj is not None:
        # Prefer the public `dimensions` property (the attribute the rebuild
        # command reports); fall back to the private field for providers
        # that only expose that.
        provider_dim = (
            getattr(provider_obj, "dimensions", None)
            or getattr(provider_obj, "_dimensions", None)
        )
        provider_label = configured_provider.lower() or "(legacy)"
        model_label = getattr(provider_obj, "model", "?")
        dim_label = provider_dim or "?"
        lines.append(f" Provider : {provider_label}")
        lines.append(f" Model : {model_label}")
        lines.append(f" Dim : {dim_label}")
    else:
        lines.append(" Provider : (未初始化, keyword-only)")

    # Health hints are only shown when the user explicitly opted into vector
    # search via `embedding_provider`; without it, missing or mismatched
    # vectors are expected and warning about them would be noise.
    warnings = []
    if configured_provider and provider_obj is not None:
        if chunks > 0 and embedded < chunks:
            missing = chunks - embedded
            warnings.append(
                f" ⚠️ {missing}/{chunks} 个 chunk 没有向量;"
                f"运行 /memory rebuild-index 后所有记忆才会被向量化检索"
            )
        index_dim = detect_index_dim(memory_manager.storage)
        if index_dim is not None and provider_dim and index_dim != provider_dim:
            warnings.append(
                f" ⚠️ 索引中存量向量为 {index_dim} 维,与当前配置 {provider_dim} 维不一致;"
                f"运行 /memory rebuild-index 重建后向量检索才会生效"
            )

    if warnings:
        lines.append("")
        lines.extend(warnings)

    return "\n".join(lines)
def _memory_rebuild_index(self, e_context, session_id: str) -> str:
    """
    Rebuild the vector index through the current agent's memory manager.

    Args:
        e_context: Channel event context, or None when invoked without a
            channel. With None the rebuild runs synchronously and the final
            result is returned; otherwise it runs in a daemon thread and the
            outcome is pushed back through `_notify`.
        session_id: Fallback session id used when resolving the session.

    Returns:
        An immediate status message: an "agent not initialised" or
        "no embedding provider" notice, the synchronous result, or a
        "rebuild started" acknowledgement for the background path.
    """
    session_id = self._get_session_id(e_context, fallback=session_id)
    agent = self._get_agent(session_id)
    memory_manager = agent.memory_manager if agent else None
    if not memory_manager:
        return (
            "⚠️ Agent 尚未初始化,无法重建索引。\n"
            "请先发送一条普通消息触发 Agent 启动后再试。"
        )

    provider_obj = memory_manager.embedding_provider
    if provider_obj is None:
        return (
            "⚠️ 当前没有可用的 embedding provider。\n"
            "请检查 config.json 中的 embedding 相关配置 (provider / api key)。"
        )

    model_label = getattr(provider_obj, "model", "?")
    dim_label = getattr(provider_obj, "dimensions", "?")

    # No channel context: nothing to push progress to, so block and return
    # the final result directly.
    if e_context is None:
        return self._memory_rebuild_sync(memory_manager, model_label, dim_label)

    from agent.memory.embedding import rebuild_in_process

    def _run():
        # Background worker: every outcome, including unexpected errors, is
        # reported back to the chat instead of being lost in the thread.
        try:
            outcome = rebuild_in_process(memory_manager)
            if not outcome.ok:
                self._notify(e_context, f"❌ 索引重建失败: {outcome.error}")
                return
            summary = (
                f"✅ 索引重建完成\n"
                f" cleared : {outcome.removed}\n"
                f" chunks : {outcome.chunks}\n"
                f" files : {outcome.files}"
            )
            self._notify(e_context, summary)
        except Exception as exc:
            logger.exception("[CowCli] /memory rebuild-index failed")
            self._notify(e_context, f"❌ 索引重建失败: {exc}")

    worker = threading.Thread(target=_run, daemon=True)
    worker.start()

    return (
        f"🔧 索引重建已启动 (model={model_label}, dim={dim_label})\n\n"
        f"将清空现有 chunks 并重新 embed 所有记忆文件,完成后会通知你。"
    )
@staticmethod
def _memory_rebuild_sync(memory_manager, model_label, dim_label) -> str:
    """
    Run the index rebuild in the calling thread and format its outcome.

    Args:
        memory_manager: Memory manager handed to `rebuild_in_process`.
        model_label: Model name shown in the success message.
        dim_label: Embedding dimension shown in the success message.

    Returns:
        A success summary (cleared / chunks / files counts) or a failure
        message carrying either the raised exception or the result's error.
    """
    from agent.memory.embedding import rebuild_in_process

    try:
        outcome = rebuild_in_process(memory_manager)
    except Exception as exc:
        logger.exception("[CowCli] /memory rebuild-index sync failed")
        return f"❌ 索引重建失败: {exc}"

    if outcome.ok:
        report = [
            f"✅ 索引重建完成 (model={model_label}, dim={dim_label})",
            f" cleared : {outcome.removed}",
            f" chunks : {outcome.chunks}",
            f" files : {outcome.files}",
        ]
        return "\n".join(report)

    return f"❌ 索引重建失败: {outcome.error}"
@staticmethod @staticmethod
def _notify(e_context, text: str): def _notify(e_context, text: str):
"""Push a notification message back to the chat channel.""" """Push a notification message back to the chat channel."""