diff --git a/agent/memory/embedding.py b/agent/memory/embedding.py deleted file mode 100644 index 1bc1c671..00000000 --- a/agent/memory/embedding.py +++ /dev/null @@ -1,167 +0,0 @@ -""" -Embedding providers for memory - -Supports OpenAI and local embedding models -""" - -import hashlib -from abc import ABC, abstractmethod -from typing import List, Optional - - -class EmbeddingProvider(ABC): - """Base class for embedding providers""" - - @abstractmethod - def embed(self, text: str) -> List[float]: - """Generate embedding for text""" - pass - - @abstractmethod - def embed_batch(self, texts: List[str]) -> List[List[float]]: - """Generate embeddings for multiple texts""" - pass - - @property - @abstractmethod - def dimensions(self) -> int: - """Get embedding dimensions""" - pass - - -class OpenAIEmbeddingProvider(EmbeddingProvider): - """OpenAI embedding provider using REST API""" - - def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, - api_base: Optional[str] = None, extra_headers: Optional[dict] = None): - """ - Initialize OpenAI embedding provider - - Args: - model: Model name (text-embedding-3-small or text-embedding-3-large) - api_key: OpenAI API key - api_base: Optional API base URL - extra_headers: Optional extra headers to include in API requests - """ - self.model = model - self.api_key = api_key - self.api_base = api_base or "https://api.openai.com/v1" - self.extra_headers = extra_headers or {} - - # Validate API key - if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]: - raise ValueError("OpenAI API key is not configured. 
Please set 'open_ai_api_key' in config.json") - - # Set dimensions based on model - self._dimensions = 1536 if "small" in model else 3072 - - def _call_api(self, input_data): - """Call OpenAI embedding API using requests""" - import requests - - url = f"{self.api_base}/embeddings" - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", - **self.extra_headers, - } - data = { - "input": input_data, - "model": self.model - } - - try: - response = requests.post(url, headers=headers, json=data, timeout=5) - response.raise_for_status() - return response.json() - except requests.exceptions.ConnectionError as e: - raise ConnectionError(f"Failed to connect to OpenAI API at {url}. Please check your network connection and api_base configuration. Error: {str(e)}") - except requests.exceptions.Timeout as e: - raise TimeoutError(f"OpenAI API request timed out after 10s. Please check your network connection. Error: {str(e)}") - except requests.exceptions.HTTPError as e: - if e.response.status_code == 401: - raise ValueError(f"Invalid OpenAI API key. Please check your 'open_ai_api_key' in config.json") - elif e.response.status_code == 429: - raise ValueError(f"OpenAI API rate limit exceeded. 
Please try again later.") - else: - raise ValueError(f"OpenAI API request failed: {e.response.status_code} - {e.response.text}") - - def embed(self, text: str) -> List[float]: - """Generate embedding for text""" - result = self._call_api(text) - return result["data"][0]["embedding"] - - def embed_batch(self, texts: List[str]) -> List[List[float]]: - """Generate embeddings for multiple texts""" - if not texts: - return [] - - result = self._call_api(texts) - return [item["embedding"] for item in result["data"]] - - @property - def dimensions(self) -> int: - return self._dimensions - - -# LocalEmbeddingProvider removed - only use OpenAI embedding or keyword search - - -class EmbeddingCache: - """Cache for embeddings to avoid recomputation""" - - def __init__(self): - self.cache = {} - - def get(self, text: str, provider: str, model: str) -> Optional[List[float]]: - """Get cached embedding""" - key = self._compute_key(text, provider, model) - return self.cache.get(key) - - def put(self, text: str, provider: str, model: str, embedding: List[float]): - """Cache embedding""" - key = self._compute_key(text, provider, model) - self.cache[key] = embedding - - @staticmethod - def _compute_key(text: str, provider: str, model: str) -> str: - """Compute cache key""" - content = f"{provider}:{model}:{text}" - return hashlib.md5(content.encode('utf-8')).hexdigest() - - def clear(self): - """Clear cache""" - self.cache.clear() - - -def create_embedding_provider( - provider: str = "openai", - model: Optional[str] = None, - api_key: Optional[str] = None, - api_base: Optional[str] = None, - extra_headers: Optional[dict] = None -) -> EmbeddingProvider: - """ - Factory function to create embedding provider - - Supports "openai" and "linkai" providers (both use OpenAI-compatible REST API). - If initialization fails, caller should fall back to keyword-only search. 
- - Args: - provider: Provider name ("openai" or "linkai") - model: Model name (default: text-embedding-3-small) - api_key: API key (required) - api_base: API base URL - extra_headers: Optional extra headers to include in API requests - - Returns: - EmbeddingProvider instance - - Raises: - ValueError: If provider is unsupported or api_key is missing - """ - if provider not in ("openai", "linkai"): - raise ValueError(f"Unsupported embedding provider: {provider}. Use 'openai' or 'linkai'.") - - model = model or "text-embedding-3-small" - return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base, extra_headers=extra_headers) diff --git a/agent/memory/embedding/__init__.py b/agent/memory/embedding/__init__.py new file mode 100644 index 00000000..f89bc216 --- /dev/null +++ b/agent/memory/embedding/__init__.py @@ -0,0 +1,41 @@ +""" +Embedding subsystem for memory. + +Public API: + create_embedding_provider, EmbeddingProvider, OpenAIEmbeddingProvider, + EMBEDDING_VENDORS, EmbeddingCache + RebuildResult, clear_index, rebuild_in_process + detect_index_dim, cleanup_legacy_state_file +""" + +from agent.memory.embedding.provider import ( + EMBEDDING_VENDORS, + DoubaoEmbeddingProvider, + EmbeddingCache, + EmbeddingProvider, + OpenAIEmbeddingProvider, + create_embedding_provider, +) +from agent.memory.embedding.rebuild import ( + RebuildResult, + clear_index, + rebuild_in_process, +) +from agent.memory.embedding.state import ( + cleanup_legacy_state_file, + detect_index_dim, +) + +__all__ = [ + "EMBEDDING_VENDORS", + "DoubaoEmbeddingProvider", + "EmbeddingCache", + "EmbeddingProvider", + "OpenAIEmbeddingProvider", + "create_embedding_provider", + "RebuildResult", + "clear_index", + "rebuild_in_process", + "cleanup_legacy_state_file", + "detect_index_dim", +] diff --git a/agent/memory/embedding/provider.py b/agent/memory/embedding/provider.py new file mode 100644 index 00000000..a106a43c --- /dev/null +++ b/agent/memory/embedding/provider.py @@ -0,0 +1,486 @@ 
"""
Embedding providers for memory.

Supported vendors:
    - openai    (text-embedding-3-small / large)
    - linkai    (OpenAI-compatible passthrough)
    - dashscope (Aliyun Tongyi text-embedding-v4)
    - doubao    (ByteDance Doubao on Volcengine Ark, multimodal endpoint)
    - zhipu     (ZhipuAI embedding-3)

Vendor keys here intentionally match the project's bot_type constants in
common.const (OPENAI, LINKAI, QWEN_DASHSCOPE, DOUBAO, ZHIPU_AI).

Every vendor except doubao goes through OpenAIEmbeddingProvider, a single
OpenAI-compatible REST client whose vendor-specific behaviors (truncation,
query instruction prefix, batch size) are configured via EMBEDDING_VENDORS.
Doubao uses the dedicated DoubaoEmbeddingProvider adapter.
"""

import hashlib
import math
from abc import ABC, abstractmethod
from typing import List, Optional

# HTTP timeout for a single embeddings request (seconds). A batch of
# 64+ chunks can take 30-50s end-to-end from China-side networks, so 30s is
# routinely too tight; 90s gives meaningful headroom without letting bad
# endpoints hang forever.
EMBEDDING_HTTP_TIMEOUT = 90

# Placeholder values that count as "no API key configured".
_PLACEHOLDER_API_KEYS = ["", "YOUR API KEY", "YOUR_API_KEY"]


class EmbeddingProvider(ABC):
    """Base class for embedding providers."""

    @abstractmethod
    def embed(self, text: str) -> List[float]:
        """Generate the embedding for a single text, without any query prefix."""

    @abstractmethod
    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for multiple texts (treated as documents)."""

    def embed_query(self, text: str) -> List[float]:
        """Generate the embedding for a query string.

        Subclasses may override this to apply a vendor instruction prefix;
        the default simply delegates to embed().
        """
        return self.embed(text)

    @property
    @abstractmethod
    def dimensions(self) -> int:
        """Effective embedding dimensions."""


# ---------------------------------------------------------------------------
# Vendor metadata table
# ---------------------------------------------------------------------------
#
# Each entry describes how to reach a vendor's embedding endpoint. Most
# vendors expose an OpenAI-compatible /embeddings API; the few that don't
# (currently: doubao) set `provider_class` to pick a dedicated adapter.
# Fields:
#   provider_class         : optional adapter key ("doubao"); defaults to OpenAI-compat
#   default_base_url       : default API base when not overridden by user
#   default_model          : default embedding model name
#   default_dimensions     : recommended unified dim when explicit path is enabled
#   supports_dim_param     : whether the API accepts a `dimensions` request param
#   needs_client_truncate  : whether to slice the vector on the client side
#   needs_client_normalize : whether to L2-normalize on the client (always safe)
#   query_instruction      : optional prefix for asymmetric retrieval
#   max_batch_size         : max texts per /embeddings request; embed_batch
#                            auto-paginates above this. Conservative defaults.
#
EMBEDDING_VENDORS = {
    "openai": {
        "default_base_url": "https://api.openai.com/v1",
        "default_model": "text-embedding-3-small",
        # Match the legacy default so users adding `embedding_provider: openai`
        # to an existing index don't need to rebuild. Override via
        # embedding_dimensions if you want 1024 / 1536 / 3072.
        "default_dimensions": 1536,
        "supports_dim_param": True,
        "needs_client_truncate": False,
        "needs_client_normalize": False,
        "query_instruction": "",
        # OpenAI permits up to 2048 items per request, but a single call
        # carrying hundreds of long chunks can take a long time from
        # China-side networks. 64 keeps each call well under both the
        # token-per-request budget and a reasonable wall clock.
        "max_batch_size": 64,
    },
    "linkai": {
        "default_base_url": "https://api.link-ai.tech/v1",
        "default_model": "text-embedding-3-small",
        "default_dimensions": 1536,
        "supports_dim_param": True,
        "needs_client_truncate": False,
        "needs_client_normalize": False,
        "query_instruction": "",
        "max_batch_size": 64,
    },
    "dashscope": {
        "default_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
        "default_model": "text-embedding-v4",
        "default_dimensions": 1024,
        "supports_dim_param": True,
        "needs_client_truncate": False,
        "needs_client_normalize": False,
        "query_instruction": "",
        "max_batch_size": 10,  # DashScope hard cap (text-embedding-v4)
    },
    "doubao": {
        # Doubao no longer offers an OpenAI-compatible /v1/embeddings endpoint.
        # Current models are unified under /api/v3/embeddings/multimodal
        # which uses a structured `input` payload — see DoubaoEmbeddingProvider.
        "provider_class": "doubao",
        "default_base_url": "https://ark.cn-beijing.volces.com/api/v3",
        "default_model": "doubao-embedding-vision-251215",
        # Native options: 1024 or 2048. We default to 1024 to align with the
        # other Chinese vendors (dashscope/zhipu) and keep storage footprint
        # consistent across providers; users can still override via
        # `embedding_dimensions: 2048` in config.
        "default_dimensions": 1024,
        "supports_dim_param": True,
        "needs_client_truncate": False,
        "needs_client_normalize": False,
        "query_instruction": "",
        # Multimodal endpoint produces ONE embedding per call (input list is
        # a single document's parts, not a batch). embed_batch loops.
        "max_batch_size": 1,
    },
    "zhipu": {
        "default_base_url": "https://open.bigmodel.cn/api/paas/v4",
        "default_model": "embedding-3",
        "default_dimensions": 1024,
        "supports_dim_param": True,
        "needs_client_truncate": False,
        "needs_client_normalize": False,
        "query_instruction": "",
        "max_batch_size": 64,
    },
}


def _l2_normalize(vec: List[float]) -> List[float]:
    """Scale *vec* to unit L2 length; a zero vector is returned unchanged."""
    norm = math.sqrt(sum(v * v for v in vec))
    if norm == 0:
        return vec
    return [v / norm for v in vec]


class OpenAIEmbeddingProvider(EmbeddingProvider):
    """
    OpenAI-compatible embedding provider.

    Used for openai/linkai/dashscope/zhipu by configuring the metadata
    fields. The legacy constructor call (model, api_key, api_base,
    extra_headers) keeps working and issues the same request body as the
    pre-vendor-table implementation: no `dimensions` field is sent unless
    the caller passed one explicitly.
    """

    def __init__(
        self,
        model: str = "text-embedding-3-small",
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        extra_headers: Optional[dict] = None,
        dimensions: Optional[int] = None,
        supports_dim_param: bool = True,
        needs_client_truncate: bool = False,
        needs_client_normalize: bool = False,
        query_instruction: str = "",
        max_batch_size: int = 256,
    ):
        """
        Args:
            model: Model name (e.g. text-embedding-3-small, text-embedding-v4, embedding-3)
            api_key: API key (required)
            api_base: API base URL (defaults to OpenAI)
            extra_headers: Optional extra HTTP headers
            dimensions: Target output dimension. When omitted (or <= 0) the
                dimension is guessed from the model name (1536 for "small",
                3072 otherwise) and is NOT sent to the API.
            supports_dim_param: Whether the vendor accepts a `dimensions` body param
            needs_client_truncate: Slice the returned vector to `dimensions`
            needs_client_normalize: L2-normalize on the client after slicing
            query_instruction: Optional prefix prepended to query texts only
            max_batch_size: Max items per /embeddings request; embed_batch
                auto-paginates above this.

        Raises:
            ValueError: If api_key is missing or a known placeholder.
        """
        self.model = model
        self.api_key = api_key
        self.api_base = api_base or "https://api.openai.com/v1"
        self.extra_headers = extra_headers or {}
        self.supports_dim_param = supports_dim_param
        self.needs_client_truncate = needs_client_truncate
        self.needs_client_normalize = needs_client_normalize
        self.query_instruction = query_instruction or ""
        self.max_batch_size = max(1, int(max_batch_size or 1))

        if not self.api_key or self.api_key in _PLACEHOLDER_API_KEYS:
            raise ValueError("Embedding API key is not configured")

        # Remember whether the caller chose the dimension. A guessed value is
        # only bookkeeping for `dimensions`; sending it would change the
        # request for legacy callers and fail on models that do not accept
        # the parameter.
        self._dimensions_explicit = dimensions is not None and dimensions > 0
        if self._dimensions_explicit:
            self._dimensions = dimensions
        else:
            # Legacy heuristic for OpenAI text-embedding-3-* family
            self._dimensions = 1536 if "small" in model else 3072

    def _build_payload(self, input_data) -> dict:
        """Build the JSON body for one /embeddings request."""
        payload = {
            "input": input_data,
            "model": self.model,
        }
        if self.supports_dim_param and self._dimensions_explicit:
            payload["dimensions"] = self._dimensions
        return payload

    def _call_api(self, input_data):
        """POST to the OpenAI-compatible /embeddings endpoint and return the parsed JSON.

        Raises:
            ConnectionError: The endpoint could not be reached.
            TimeoutError: The request exceeded EMBEDDING_HTTP_TIMEOUT.
            ValueError: The endpoint answered with an HTTP error status.
        """
        import requests

        url = f"{self.api_base}/embeddings"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
            **self.extra_headers,
        }
        data = self._build_payload(input_data)

        try:
            response = requests.post(url, headers=headers, json=data, timeout=EMBEDDING_HTTP_TIMEOUT)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.ConnectionError as e:
            raise ConnectionError(
                f"Failed to connect to embedding API at {url}. "
                f"Please check network and api_base. Error: {str(e)}"
            ) from e
        except requests.exceptions.Timeout as e:
            raise TimeoutError(f"Embedding API request timed out. Error: {str(e)}") from e
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:
                raise ValueError("Invalid embedding API key") from e
            elif e.response.status_code == 429:
                raise ValueError("Embedding API rate limit exceeded") from e
            else:
                raise ValueError(
                    f"Embedding API request failed: "
                    f"{e.response.status_code} - {e.response.text}"
                ) from e

    def _post_process(self, raw: List[float]) -> List[float]:
        """Apply optional client-side truncation, then optional normalization."""
        vec = raw
        if self.needs_client_truncate and self._dimensions and len(vec) > self._dimensions:
            vec = vec[: self._dimensions]
        if self.needs_client_normalize:
            vec = _l2_normalize(vec)
        return vec

    def embed(self, text: str) -> List[float]:
        """Generate the embedding for a single text (no query prefix applied)."""
        result = self._call_api(text)
        return self._post_process(result["data"][0]["embedding"])

    def embed_query(self, text: str) -> List[float]:
        """Generate embedding for a query (applies vendor instruction prefix if any)."""
        if self.query_instruction:
            text = f"{self.query_instruction}{text}"
        return self.embed(text)

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for multiple documents.

        Automatically paginates by self.max_batch_size so callers can pass any
        number of texts. Order of returned vectors matches the input order.

        Raises:
            ValueError: If a response carries a different number of vectors
                than the texts that were sent.
        """
        if not texts:
            return []
        out: List[List[float]] = []
        step = self.max_batch_size
        for start in range(0, len(texts), step):
            chunk = texts[start:start + step]
            items = self._call_api(chunk)["data"]
            if len(items) != len(chunk):
                # A short or long response would silently pair vectors with
                # the wrong texts further down the pipeline.
                raise ValueError(
                    f"Embedding API returned {len(items)} vectors for {len(chunk)} inputs"
                )
            # Each item may carry an `index` naming the input it belongs to;
            # honor it when present and fall back to response position.
            ordered = sorted(
                enumerate(items),
                key=lambda pair: pair[1].get("index", pair[0]),
            )
            out.extend(self._post_process(item["embedding"]) for _, item in ordered)
        return out

    @property
    def dimensions(self) -> int:
        """Configured (or guessed) embedding dimension."""
        return self._dimensions


class DoubaoEmbeddingProvider(EmbeddingProvider):
    """
    Doubao (Volcengine Ark) multimodal embedding provider.

    Doubao deprecated their OpenAI-compatible /v1/embeddings endpoint and
    unified everything under /api/v3/embeddings/multimodal, which uses a
    structured `input: [{type, text|image_url|video_url}, ...]` payload.

    Notes:
        * The endpoint produces ONE embedding per call (input list is multiple
          modality parts of a single document, not a batch). embed_batch
          therefore loops per-text — no native batch support.
        * Native dimensions: 1024 or 2048 (default 1024 to align with other
          Chinese vendors). No client-side truncation needed.
        * Auth: Bearer ARK API key.
    """

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        extra_headers: Optional[dict] = None,
        dimensions: Optional[int] = None,
    ):
        """
        Args:
            model: Doubao embedding model name
            api_key: ARK API key (required)
            api_base: API base URL (defaults to the Ark cn-beijing endpoint)
            extra_headers: Optional extra HTTP headers
            dimensions: 1024, 2048, or None (treated as 1024)

        Raises:
            ValueError: If api_key is missing/placeholder or dimensions is
                neither None, 1024 nor 2048.
        """
        self.model = model
        self.api_key = api_key
        self.api_base = api_base or "https://ark.cn-beijing.volces.com/api/v3"
        self.extra_headers = extra_headers or {}
        if not self.api_key or self.api_key in _PLACEHOLDER_API_KEYS:
            raise ValueError("Doubao embedding API key (ark_api_key) is not configured")

        if dimensions in (1024, 2048):
            self._dimensions = dimensions
        elif dimensions is None:
            self._dimensions = 1024
        else:
            raise ValueError(
                f"Doubao embedding dimensions must be 1024 or 2048, got {dimensions}"
            )

    def _call_api(self, text: str) -> List[float]:
        """One call → one embedding.

        The multimodal endpoint takes a single document represented as a list
        of typed parts; we send a single text part.

        Raises:
            ConnectionError: The endpoint could not be reached.
            TimeoutError: The request exceeded EMBEDDING_HTTP_TIMEOUT.
            ValueError: HTTP error status, or an unrecognized response shape.
        """
        import requests

        url = f"{self.api_base}/embeddings/multimodal"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
            **self.extra_headers,
        }
        payload = {
            "model": self.model,
            "input": [{"type": "text", "text": text}],
            "dimensions": self._dimensions,
            "encoding_format": "float",
        }

        try:
            response = requests.post(url, headers=headers, json=payload, timeout=EMBEDDING_HTTP_TIMEOUT)
            response.raise_for_status()
            body = response.json()
        except requests.exceptions.ConnectionError as e:
            raise ConnectionError(
                f"Failed to connect to Doubao embedding API at {url}. "
                f"Please check network and api_base. Error: {str(e)}"
            ) from e
        except requests.exceptions.Timeout as e:
            raise TimeoutError(f"Doubao embedding API request timed out. Error: {str(e)}") from e
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:
                raise ValueError("Invalid Doubao (ark) embedding API key") from e
            elif e.response.status_code == 429:
                raise ValueError("Doubao embedding API rate limit exceeded") from e
            else:
                raise ValueError(
                    f"Doubao embedding API request failed: "
                    f"{e.response.status_code} - {e.response.text}"
                ) from e

        # Expected response shape: {"data": {"embedding": [...]}}
        data = body.get("data")
        if isinstance(data, dict) and "embedding" in data:
            return data["embedding"]
        # Be defensive about a list-of-one wrapper as well.
        if isinstance(data, list) and data and isinstance(data[0], dict) and "embedding" in data[0]:
            return data[0]["embedding"]
        raise ValueError(f"Unexpected Doubao embedding response shape: {body}")

    def embed(self, text: str) -> List[float]:
        """Generate the embedding for a single text."""
        return self._call_api(text)

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed each text with its own request; output order matches input."""
        return [self._call_api(t) for t in texts]

    @property
    def dimensions(self) -> int:
        """Configured embedding dimension (1024 or 2048)."""
        return self._dimensions


class EmbeddingCache:
    """In-memory cache for embeddings to avoid recomputation."""

    def __init__(self):
        # Maps md5(provider:model:text) -> embedding vector.
        self.cache = {}

    def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
        """Return the cached embedding, or None when absent."""
        key = self._compute_key(text, provider, model)
        return self.cache.get(key)

    def put(self, text: str, provider: str, model: str, embedding: List[float]):
        """Store *embedding* for the (provider, model, text) triple."""
        key = self._compute_key(text, provider, model)
        self.cache[key] = embedding

    @staticmethod
    def _compute_key(text: str, provider: str, model: str) -> str:
        """Hash the triple into a fixed-size key (md5 is used as a digest, not for security)."""
        content = f"{provider}:{model}:{text}"
        return hashlib.md5(content.encode("utf-8")).hexdigest()

    def clear(self):
        """Drop every cached embedding."""
        self.cache.clear()


def create_embedding_provider(
    provider: str = "openai",
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    extra_headers: Optional[dict] = None,
    dimensions: Optional[int] = None,
) -> EmbeddingProvider:
    """
    Factory function to create an embedding provider.

    Backward compatible: when called with provider in {"openai", "linkai"}
    and no `dimensions` arg, the request body is the same as before the
    vendor table existed (no `dimensions` field, 1536-dim bookkeeping for
    "small" models). Batches are additionally paginated by the vendor's
    max_batch_size, which does not change the returned vectors.

    New providers ("dashscope", "doubao", "zhipu") require explicit configuration
    and use the unified 1024-dim defaults from EMBEDDING_VENDORS.

    Args:
        provider: Vendor key (one of EMBEDDING_VENDORS)
        model: Model name (uses vendor default if None)
        api_key: API key (required)
        api_base: API base URL (uses vendor default if None, except on the
            legacy openai/linkai path — see note in the body)
        extra_headers: Optional extra HTTP headers
        dimensions: Target output dimension (uses vendor default if None or <= 0)

    Returns:
        EmbeddingProvider instance

    Raises:
        ValueError: If provider is unsupported, api_key is missing, or the
            dimension is invalid for the vendor.
    """
    meta = EMBEDDING_VENDORS.get(provider)
    if meta is None:
        raise ValueError(
            f"Unsupported embedding provider: {provider}. "
            f"Supported: {sorted(EMBEDDING_VENDORS.keys())}"
        )

    # Doubao uses a non-OpenAI-compatible multimodal endpoint.
    if meta.get("provider_class") == "doubao":
        final_dim = dimensions if (dimensions and dimensions > 0) else meta["default_dimensions"]
        return DoubaoEmbeddingProvider(
            model=model or meta["default_model"],
            api_key=api_key,
            api_base=api_base or meta["default_base_url"],
            extra_headers=extra_headers,
            dimensions=final_dim,
        )

    # Legacy call for openai/linkai (no dimensions) keeps the 1536-dim
    # default behavior so existing data isn't invalidated.
    is_legacy_call = (
        provider in ("openai", "linkai")
        and dimensions is None
    )
    if is_legacy_call:
        # NOTE(review): api_base is forwarded as-is, so "linkai" without an
        # api_base falls back to the OpenAI URL exactly like the old factory
        # did — confirm whether the vendor default should apply here.
        return OpenAIEmbeddingProvider(
            model=model or "text-embedding-3-small",
            api_key=api_key,
            api_base=api_base,
            extra_headers=extra_headers,
            max_batch_size=meta.get("max_batch_size", 256),
        )

    final_dim = dimensions if (dimensions and dimensions > 0) else meta["default_dimensions"]
    return OpenAIEmbeddingProvider(
        model=model or meta["default_model"],
        api_key=api_key,
        api_base=api_base or meta["default_base_url"],
        extra_headers=extra_headers,
        dimensions=final_dim,
        supports_dim_param=meta["supports_dim_param"],
        needs_client_truncate=meta["needs_client_truncate"],
        needs_client_normalize=meta["needs_client_normalize"],
        query_instruction=meta["query_instruction"],
        max_batch_size=meta.get("max_batch_size", 256),
    )
"""
Rebuild memory vector index.

Recommended entry point (in-chat, while agent is running):
    /memory rebuild-index

Backward-compatible CLI entry (must run from project root):
    python -m agent.memory.rebuild_index

What it does:
    1. Probes the embedding endpoint with a tiny call to fail fast on
       bad provider/model/key — before touching the index.
    2. Clears the SQLite chunks/files tables (workspace markdown stays intact).
    3. Runs a fresh sync, regenerating embeddings with the currently configured
       provider/model/dimensions.

This is the only safe way to switch embedding_provider after the existing
index has been populated by a different-dim model.
"""

import asyncio
import concurrent.futures
import sys
from dataclasses import dataclass
from typing import Optional

from common.log import logger
from common.utils import expand_path


@dataclass
class RebuildResult:
    """Outcome of a rebuild_in_process() call."""
    # True when the index was rebuilt and holds vectors (or has no chunks).
    ok: bool
    # Number of chunks that were in the index before it was cleared.
    removed: int = 0
    # Chunk / file counts reported by storage.get_stats() after the sync.
    chunks: int = 0
    files: int = 0
    # Human-readable failure reason; None on success.
    error: Optional[str] = None


def clear_index(db_path, storage=None) -> int:
    """Wipe chunks/files, reset FTS5, and clean up any legacy state file.

    Args:
        db_path: Path of the index DB (also used to locate the legacy state
            file for migration cleanup, and — when *storage* is None — to
            open a fresh connection).
        storage: Optional pre-opened MemoryStorage. When provided we reuse it
            so the live connection's triggers stay in sync — opening a second
            connection would leave the original one's triggers pointing at a
            DROP'd chunks_fts table.

    We reset (DROP+recreate) chunks_fts because its shadow tables can become
    inconsistent across rebuild cycles, causing bm25() / ORDER BY rank to
    raise "database disk image is malformed" even when raw MATCH still works.

    Returns:
        Number of chunks removed.
    """
    from agent.memory.embedding.state import cleanup_legacy_state_file
    from agent.memory.storage import MemoryStorage

    # Only close the storage if this function opened it.
    owns_storage = storage is None
    if owns_storage:
        storage = MemoryStorage(db_path)
    try:
        before = storage.conn.execute("SELECT COUNT(*) FROM chunks").fetchone()[0]
        storage.conn.execute("DELETE FROM chunks")
        storage.conn.execute("DELETE FROM files")
        storage.conn.commit()
        storage.reset_fts5()
    finally:
        if owns_storage:
            storage.close()

    cleanup_legacy_state_file(db_path)
    return int(before)


def _run_sync_blocking(memory_manager) -> None:
    """Run memory_manager.sync(force=True) to completion from synchronous code."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop is running in this thread: the plain runner works.
        asyncio.run(memory_manager.sync(force=True))
        return

    # A loop is already running in this thread. Neither asyncio.run() nor a
    # fresh loop's run_until_complete() may be used here (both raise
    # RuntimeError while another loop is running), so drive the coroutine on
    # a worker thread with its own loop and block until it finishes.
    # NOTE(review): this assumes the manager's storage tolerates being used
    # from another thread — confirm; any failure surfaces as an exception
    # to the caller.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        pool.submit(lambda: asyncio.run(memory_manager.sync(force=True))).result()


def rebuild_in_process(memory_manager) -> RebuildResult:
    """
    Rebuild the index using an existing, fully-initialized MemoryManager.

    Used by the in-chat /memory rebuild-index command. The caller already has
    config loaded, embedding_provider built, and (optionally) the agent
    running, so we only need to:
        1. Probe the embedding endpoint.
        2. Clear chunks/files + state on the manager's storage.
        3. Re-sync (force=True).

    Every failure is reported through the returned RebuildResult
    (ok=False, error=...).

    Args:
        memory_manager: Manager exposing embedding_provider, config, storage
            and an async sync(force=...) method.

    Returns:
        RebuildResult describing what was removed and re-indexed.
    """
    if memory_manager is None:
        return RebuildResult(ok=False, error="memory_manager is None")
    if memory_manager.embedding_provider is None:
        return RebuildResult(ok=False, error="embedding_provider is not initialized")

    # Probe the embedding endpoint BEFORE clearing the index. A bad
    # provider/model/key would otherwise leave the user with an empty index
    # that not even keyword search can serve.
    try:
        memory_manager.embedding_provider.embed_query("ping")
    except Exception as e:
        logger.error(f"[RebuildIndex] embedding probe failed, aborting rebuild: {e}")
        return RebuildResult(ok=False, error=f"embedding endpoint not reachable: {e}")

    db_path = memory_manager.config.get_db_path()
    try:
        removed = clear_index(db_path, storage=memory_manager.storage)
    except Exception as e:
        logger.exception("[RebuildIndex] clear_index failed")
        return RebuildResult(ok=False, error=f"clear failed: {e}")

    try:
        _run_sync_blocking(memory_manager)
    except Exception as e:
        logger.exception("[RebuildIndex] sync failed")
        return RebuildResult(ok=False, removed=removed, error=f"re-embed failed: {e}")

    stats = memory_manager.storage.get_stats()
    chunks = int(stats.get("chunks", 0))
    embedded = int(stats.get("embedded", 0))
    files = int(stats.get("files", 0))

    # sync() degrades to "no embeddings" on batch failure so keyword search
    # still works at startup — but in a /rebuild-index request the user
    # explicitly asked for vectors. Surface that as a failure.
    if chunks > 0 and embedded == 0:
        return RebuildResult(
            ok=False,
            removed=removed,
            chunks=chunks,
            files=files,
            error=(
                "embedding API failed during sync; index now has chunks but no "
                "vectors. Check embedding provider/model/key and retry."
            ),
        )

    return RebuildResult(
        ok=True,
        removed=removed,
        chunks=chunks,
        files=files,
    )


def main() -> int:
    """Standalone CLI entry. Must be run from project root (relative config path).

    Returns:
        Process exit code: 0 on success, 1 when no provider could be built or
        the rebuild failed.
    """
    from config import conf, load_config
    from agent.memory import MemoryConfig, MemoryManager

    load_config()

    workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
    memory_config = MemoryConfig(workspace_root=workspace_root)

    logger.info(f"[RebuildIndex] Workspace: {workspace_root}")
    logger.info(f"[RebuildIndex] Index db: {memory_config.get_db_path()}")

    from bridge.agent_initializer import AgentInitializer

    # Reuse the initializer's provider construction so the CLI resolves the
    # embedding provider the same way the running agent does.
    initializer = AgentInitializer(bridge=None, agent_bridge=None)
    embedding_provider = initializer._init_embedding_provider(memory_config, session_id=None)
    if embedding_provider is None:
        logger.error(
            "[RebuildIndex] No embedding provider could be initialized. "
            "Check your config.json. Aborting rebuild."
        )
        return 1

    manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
    result = rebuild_in_process(manager)
    if not result.ok:
        logger.error(f"[RebuildIndex] {result.error}")
        return 1

    logger.info(
        f"[RebuildIndex] Done. removed={result.removed}, "
        f"chunks={result.chunks}, files={result.files}"
    )
    return 0


if __name__ == "__main__":
    sys.exit(main())
"""
Embedding-related index utilities.

We don't keep a sidecar state file — the SQLite index is the source of truth
and config.json is the source of intent. The two functions below are the
only things needing on-disk awareness:

    detect_index_dim         : read the dim of stored vectors (display-only)
    cleanup_legacy_state_file: remove old embedding_state.json from earlier
                               versions; safe no-op when absent.
"""

import json
import os
from pathlib import Path
from typing import Optional, Union

PathLike = Union[str, os.PathLike]


def detect_index_dim(storage) -> Optional[int]:
    """Return the dim of the first stored embedding, or None if the index
    has no embeddings. Used by /memory status.

    Args:
        storage: Object exposing a DB-API connection as ``storage.conn``
            with a ``chunks`` table whose ``embedding`` column holds a
            JSON-encoded list.

    Returns:
        Length of the first non-NULL embedding, or None when the query
        fails, no embedding is stored, or the stored value is not a JSON list.
    """
    try:
        row = storage.conn.execute(
            "SELECT embedding FROM chunks WHERE embedding IS NOT NULL LIMIT 1"
        ).fetchone()
    except Exception:
        # Display-only helper: an unreadable index is reported as "unknown".
        return None

    # Index by position so both sqlite3.Row and plain tuple rows work.
    raw = row[0] if row else None
    if not raw:
        return None
    try:
        emb = json.loads(raw)
    except (ValueError, TypeError):
        # ValueError covers JSONDecodeError and undecodable bytes.
        return None
    return len(emb) if isinstance(emb, list) else None


def cleanup_legacy_state_file(db_path: PathLike) -> None:
    """Remove old embedding_state.json files from earlier versions.

    The file is looked up next to *db_path*. Safe to call repeatedly; no-op
    if the file is absent.
    """
    legacy = Path(db_path).parent / "embedding_state.json"
    try:
        legacy.unlink(missing_ok=True)
    except OSError:
        # Best-effort cleanup: a leftover state file is harmless, so a
        # filesystem error here must not abort the caller.
        pass
api_base = os.environ.get('OPENAI_API_BASE') - if api_key: - self.embedding_provider = create_embedding_provider( - provider="openai", - model=self.config.embedding_model, - api_key=api_key, - api_base=api_base - ) - except Exception as e: - from common.log import logger - logger.warning(f"[MemoryManager] OpenAI embedding failed: {e}") - - # Fallback to LinkAI - if self.embedding_provider is None: - try: - linkai_key = os.environ.get('LINKAI_API_KEY') - linkai_base = os.environ.get('LINKAI_API_BASE', 'https://api.link-ai.tech') - if linkai_key: - from common.utils import get_cloud_headers - cloud_headers = get_cloud_headers(linkai_key) - cloud_headers.pop("Authorization", None) - self.embedding_provider = create_embedding_provider( - provider="linkai", - model=self.config.embedding_model, - api_key=linkai_key, - api_base=f"{linkai_base}/v1", - extra_headers=cloud_headers, - ) - except Exception as e: - from common.log import logger - logger.warning(f"[MemoryManager] LinkAI embedding failed: {e}") - - if self.embedding_provider is None: - from common.log import logger - logger.info(f"[MemoryManager] Memory will work with keyword search only (no vector search)") + # Embedding provider is owned by the caller (agent_initializer is the + # canonical entry point and handles legacy/explicit + state validation). + # When None is passed, memory degrades to keyword-only search instead + # of silently re-initializing a vendor here, which would bypass the + # caller's state checks and risk corrupting the index. 
+ self.embedding_provider = embedding_provider + if self.embedding_provider is None: + from common.log import logger + logger.info( + "[MemoryManager] No embedding provider; memory will use keyword search only" + ) # Initialize memory flush manager workspace_dir = self.config.get_workspace() @@ -153,12 +121,14 @@ class MemoryManager: if self.config.sync_on_search and self._dirty: await self.sync() - # Perform vector search (if embedding provider available) + from common.log import logger + + # Perform vector search (if embedding provider available). + # Failures degrade silently to keyword-only — no exception is raised. vector_results = [] if self.embedding_provider: try: - from common.log import logger - query_embedding = self.embedding_provider.embed(query) + query_embedding = self.embedding_provider.embed_query(query) vector_results = self.storage.search_vector( query_embedding=query_embedding, user_id=user_id, @@ -167,19 +137,19 @@ class MemoryManager: ) logger.info(f"[MemoryManager] Vector search found {len(vector_results)} results for query: {query}") except Exception as e: - from common.log import logger - logger.warning(f"[MemoryManager] Vector search failed: {e}") - - # Perform keyword search + logger.error( + f"[MemoryManager] Vector search failed, falling back to keyword-only: {e}" + ) + + # Perform keyword search (also runs as fallback when vector failed) keyword_results = self.storage.search_keyword( query=query, user_id=user_id, scopes=scopes, limit=max_results * 2 ) - from common.log import logger logger.info(f"[MemoryManager] Keyword search found {len(keyword_results)} results for query: {query}") - + # Merge results merged = self._merge_results( vector_results, @@ -187,7 +157,7 @@ class MemoryManager: self.config.vector_weight, self.config.keyword_weight ) - + # Filter by min score and limit filtered = [r for r in merged if r.score >= min_score] return filtered[:max_results] @@ -269,132 +239,157 @@ class MemoryManager: async def sync(self, force: 
bool = False): """ - Synchronize memory from files - + Synchronize memory from files. + + Two-pass design to amortize embedding HTTP cost: + 1. Walk all files, chunk those whose hash changed, collect pending + chunks across files. No embedding calls yet. + 2. Run a single embed_batch over the union of pending chunks (the + provider auto-paginates by vendor cap), then persist per-file. + + For workspaces with many small files (101 files / ~1 chunk each), this + cuts ~100 HTTP calls down to ~ceil(total_chunks / vendor_cap). + Args: force: Force full reindex """ memory_dir = self.config.get_memory_dir() workspace_dir = self.config.get_workspace() - - # Scan MEMORY.md (workspace root) + + files_to_scan: List[tuple] = [] # (file_path, source, scope, user_id) + memory_file = Path(workspace_dir) / "MEMORY.md" if memory_file.exists(): - await self._sync_file(memory_file, "memory", "shared", None) - - # Scan memory directory (including daily summaries) + files_to_scan.append((memory_file, "memory", "shared", None)) + if memory_dir.exists(): for file_path in memory_dir.rglob("*.md"): - # Skip hidden directories (e.g. 
.dreams/) if any(part.startswith('.') for part in file_path.relative_to(workspace_dir).parts): continue - - # Determine scope and user_id from path - rel_path = file_path.relative_to(workspace_dir) - parts = rel_path.parts - - # Check if it's in daily summary directory - if "daily" in parts: - # Daily summary files - if "users" in parts or len(parts) > 3: - # User-scoped daily summary: memory/daily/{user_id}/2024-01-29.md - user_idx = parts.index("daily") + 1 - user_id = parts[user_idx] if user_idx < len(parts) else None + rel_parts = file_path.relative_to(workspace_dir).parts + if "daily" in rel_parts: + if "users" in rel_parts or len(rel_parts) > 3: + user_idx = rel_parts.index("daily") + 1 + user_id = rel_parts[user_idx] if user_idx < len(rel_parts) else None scope = "user" else: - # Shared daily summary: memory/daily/2024-01-29.md user_id = None scope = "shared" - elif "users" in parts: - # User-scoped memory - user_idx = parts.index("users") + 1 - user_id = parts[user_idx] if user_idx < len(parts) else None + elif "users" in rel_parts: + user_idx = rel_parts.index("users") + 1 + user_id = rel_parts[user_idx] if user_idx < len(rel_parts) else None scope = "user" else: - # Shared memory user_id = None scope = "shared" - - await self._sync_file(file_path, "memory", scope, user_id) + files_to_scan.append((file_path, "memory", scope, user_id)) - # Scan knowledge directory (structured knowledge wiki) from config import conf if conf().get("knowledge", True): knowledge_dir = Path(workspace_dir) / "knowledge" if knowledge_dir.exists(): for file_path in knowledge_dir.rglob("*.md"): - await self._sync_file(file_path, "knowledge", "shared", None) - - self._dirty = False - - async def _sync_file( - self, - file_path: Path, - source: str, - scope: str, - user_id: Optional[str] - ): - """Sync a single file""" - # Compute file hash - content = file_path.read_text(encoding='utf-8') - file_hash = MemoryStorage.compute_hash(content) - - # Get relative path - workspace_dir = 
self.config.get_workspace() - rel_path = str(file_path.relative_to(workspace_dir)) - - # Check if file changed - stored_hash = self.storage.get_file_hash(rel_path) - if stored_hash == file_hash: - return # No changes - - # Delete old chunks - self.storage.delete_by_path(rel_path) - - # Chunk and embed - chunks = self.chunker.chunk_text(content) - if not chunks: + files_to_scan.append((file_path, "knowledge", "shared", None)) + + # Pass 1: inline chunking + change detection. Inlined (instead of + # calling self._prepare_file_for_sync) so this method does not depend + # on any sibling helpers — keeps it robust against partial reloads + # where the class object is older than the method's source. + pending: List[Dict[str, Any]] = [] + workspace_dir_path = self.config.get_workspace() + for file_path, source, scope, user_id in files_to_scan: + try: + content = file_path.read_text(encoding='utf-8') + except Exception: + continue + file_hash = MemoryStorage.compute_hash(content) + rel_path = str(file_path.relative_to(workspace_dir_path)) + if self.storage.get_file_hash(rel_path) == file_hash: + continue + chunks = self.chunker.chunk_text(content) + if not chunks: + continue + pending.append({ + "file_path": file_path, + "rel_path": rel_path, + "source": source, + "scope": scope, + "user_id": user_id, + "file_hash": file_hash, + "chunks": chunks, + "texts": [c.text for c in chunks], + }) + + if not pending: + self._dirty = False return - - texts = [chunk.text for chunk in chunks] - if self.embedding_provider: - embeddings = self.embedding_provider.embed_batch(texts) + + # Pass 2: single batched embed across all pending chunks. + # CRITICAL: never touch the index until we hold valid embeddings. + # If embed_batch fails, leave the existing index intact (chunks + + # file_hash) so the next sync will retry the same files. Writing + # NULL embeddings + updating file_hash here would mark the file as + # "successfully synced" and silently strand it without vectors. 
+ all_texts: List[str] = [] + for entry in pending: + all_texts.extend(entry["texts"]) + + if not self.embedding_provider: + # No provider configured at all (legacy keyword-only). Persist + # chunks without embeddings — this is the user's intent. + all_embeddings: List[Optional[List[float]]] = [None] * len(all_texts) else: - embeddings = [None] * len(texts) - - # Create memory chunks - memory_chunks = [] - for chunk, embedding in zip(chunks, embeddings): - chunk_id = self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line) - chunk_hash = MemoryStorage.compute_hash(chunk.text) - - memory_chunks.append(MemoryChunk( - id=chunk_id, - user_id=user_id, - scope=scope, - source=source, + try: + all_embeddings = self.embedding_provider.embed_batch(all_texts) + except Exception as e: + from common.log import logger + logger.error( + f"[MemoryManager] Batch embedding failed for {len(all_texts)} " + f"chunks across {len(pending)} files: {e}. " + f"Index left untouched; will retry on next sync." + ) + # Bail before touching storage. self._dirty stays True so + # callers know there is pending work. + return + + # Pass 3: inline persist — same self-contained reasoning as Pass 1. 
+ cursor = 0 + for entry in pending: + n = len(entry["texts"]) + entry_embeddings = all_embeddings[cursor:cursor + n] + cursor += n + + rel_path = entry["rel_path"] + self.storage.delete_by_path(rel_path) + memory_chunks = [] + for chunk, embedding in zip(entry["chunks"], entry_embeddings): + chunk_id = self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line) + chunk_hash = MemoryStorage.compute_hash(chunk.text) + memory_chunks.append(MemoryChunk( + id=chunk_id, + user_id=entry["user_id"], + scope=entry["scope"], + source=entry["source"], + path=rel_path, + start_line=chunk.start_line, + end_line=chunk.end_line, + text=chunk.text, + embedding=embedding, + hash=chunk_hash, + metadata=None, + )) + self.storage.save_chunks_batch(memory_chunks) + stat = entry["file_path"].stat() + self.storage.update_file_metadata( path=rel_path, - start_line=chunk.start_line, - end_line=chunk.end_line, - text=chunk.text, - embedding=embedding, - hash=chunk_hash, - metadata=None - )) - - # Save - self.storage.save_chunks_batch(memory_chunks) - - # Update file metadata - stat = file_path.stat() - self.storage.update_file_metadata( - path=rel_path, - source=source, - file_hash=file_hash, - mtime=int(stat.st_mtime), - size=stat.st_size - ) - + source=entry["source"], + file_hash=entry["file_hash"], + mtime=int(stat.st_mtime), + size=stat.st_size, + ) + + self._dirty = False + def flush_memory( self, messages: list, diff --git a/agent/memory/rebuild_index.py b/agent/memory/rebuild_index.py new file mode 100644 index 00000000..a975503d --- /dev/null +++ b/agent/memory/rebuild_index.py @@ -0,0 +1,14 @@ +""" +Backward-compatible shim for the legacy entry point: + python -m agent.memory.rebuild_index + +The implementation now lives in agent.memory.embedding.rebuild. +Prefer using `/memory rebuild-index` in chat going forward. 
+""" + +from agent.memory.embedding.rebuild import main + +if __name__ == "__main__": + import sys + + sys.exit(main()) diff --git a/agent/memory/storage.py b/agent/memory/storage.py index 8ff0504a..0a4e6edb 100644 --- a/agent/memory/storage.py +++ b/agent/memory/storage.py @@ -144,45 +144,37 @@ class MemoryStorage: ON chunks(path, hash) """) - # Create FTS5 virtual table for keyword search (only if supported) + # Create FTS5 virtual table + triggers (only if supported). + # Self-heal: if the previous process crashed mid-rebuild and left + # triggers pointing at a missing chunks_fts (or vice versa), wipe + # both sides and recreate cleanly. Otherwise next chunks INSERT + # will fail with "no such table: chunks_fts". if self.fts5_available: - # Use default unicode61 tokenizer (stable and compatible) - # For CJK support, we'll use LIKE queries as fallback - self.conn.execute(""" - CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5( - text, - id UNINDEXED, - user_id UNINDEXED, - path UNINDEXED, - source UNINDEXED, - scope UNINDEXED, - content='chunks', - content_rowid='rowid' + if self._fts5_state_inconsistent(): + from common.log import logger + logger.warning( + "[MemoryStorage] FTS5 state inconsistent (triggers/table mismatch). " + "Resetting chunks_fts to recover." 
) - """) - - # Create triggers to keep FTS in sync - self.conn.execute(""" - CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN - INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope) - VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope); - END - """) - - self.conn.execute(""" - CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN - DELETE FROM chunks_fts WHERE rowid = old.rowid; - END - """) - - self.conn.execute(""" - CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN - UPDATE chunks_fts SET text = new.text, id = new.id, - user_id = new.user_id, path = new.path, source = new.source, scope = new.scope - WHERE rowid = new.rowid; - END - """) - + self.conn.execute("DROP TRIGGER IF EXISTS chunks_ai") + self.conn.execute("DROP TRIGGER IF EXISTS chunks_ad") + self.conn.execute("DROP TRIGGER IF EXISTS chunks_au") + self.conn.execute("DROP TABLE IF EXISTS chunks_fts") + self.conn.commit() + self._create_fts5_objects() + + # Probe FTS5 shadow tables. The schema may be intact but the + # internal _data/_idx/_docsize blob can still be corrupt — that + # surfaces as "database disk image is malformed" on bm25 / MATCH. + # We rebuild from the chunks table when that happens; data isn't + # lost because chunks (the content table) is the source of truth. + if self._fts5_shadow_corrupt(): + from common.log import logger + logger.warning( + "[MemoryStorage] FTS5 shadow tables corrupt; rebuilding from chunks." + ) + self._rebuild_fts5_from_chunks() + # Create files metadata table self.conn.execute(""" CREATE TABLE IF NOT EXISTS files ( @@ -196,7 +188,116 @@ class MemoryStorage: """) self.conn.commit() - + + def _fts5_state_inconsistent(self) -> bool: + """Detect a half-broken FTS5 setup (e.g. 
trigger exists but table doesn't).""" + try: + row = self.conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='chunks_fts'" + ).fetchone() + table_exists = row is not None + row = self.conn.execute( + "SELECT COUNT(*) FROM sqlite_master WHERE type='trigger' " + "AND name IN ('chunks_ai','chunks_ad','chunks_au')" + ).fetchone() + trigger_count = int(row[0]) if row else 0 + except Exception: + return False + # Healthy = both present (3 triggers + table) or both absent. + return table_exists != (trigger_count > 0) + + def _create_fts5_objects(self): + """Create chunks_fts virtual table and the 3 sync triggers. + + Idempotent: uses IF NOT EXISTS. Caller must hold self.conn. + """ + self.conn.execute(""" + CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5( + text, + id UNINDEXED, + user_id UNINDEXED, + path UNINDEXED, + source UNINDEXED, + scope UNINDEXED, + content='chunks', + content_rowid='rowid' + ) + """) + self.conn.execute(""" + CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN + INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope) + VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope); + END + """) + self.conn.execute(""" + CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN + DELETE FROM chunks_fts WHERE rowid = old.rowid; + END + """) + self.conn.execute(""" + CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN + UPDATE chunks_fts SET text = new.text, id = new.id, + user_id = new.user_id, path = new.path, + source = new.source, scope = new.scope + WHERE rowid = new.rowid; + END + """) + + def reset_fts5(self): + """Drop and recreate chunks_fts + triggers in one transaction. + + Used by rebuild_index to recover from FTS5 shadow-table corruption + (bm25/ORDER BY rank may raise "database disk image is malformed" + even when raw MATCH still works). 
+ + Triggers must be dropped first; otherwise the next chunks INSERT/DELETE + on the existing connection will hit "no such table: chunks_fts". + """ + if not self.fts5_available: + return + self.conn.execute("DROP TRIGGER IF EXISTS chunks_ai") + self.conn.execute("DROP TRIGGER IF EXISTS chunks_ad") + self.conn.execute("DROP TRIGGER IF EXISTS chunks_au") + self.conn.execute("DROP TABLE IF EXISTS chunks_fts") + self._create_fts5_objects() + self.conn.commit() + + def _fts5_shadow_corrupt(self) -> bool: + """Probe whether bm25 over chunks_fts errors out at startup. + + Schema (table + triggers) can be intact while the underlying + FTS5 shadow blobs are malformed — typically because the previous + process crashed mid-write or wrote with a different SQLite build. + A cheap MATCH probe surfaces it immediately.""" + try: + self.conn.execute( + "SELECT bm25(chunks_fts) FROM chunks_fts WHERE chunks_fts MATCH 'a' LIMIT 1" + ).fetchone() + return False + except sqlite3.DatabaseError as e: + msg = str(e).lower() + return "malformed" in msg or "corrupt" in msg + except Exception: + # Any other error (e.g. table missing) is handled by the + # state-inconsistent path; treat as healthy here. + return False + + def _rebuild_fts5_from_chunks(self): + """Drop FTS5, recreate it, then INSERT every row from chunks. + + Safe data-wise: chunks (the content table) is the source of truth. + Done in one transaction so a crash leaves either fully old or fully + new state, not a partial rebuild. + """ + # Reset schema first; this clears any malformed shadow blobs. + self.reset_fts5() + # Re-feed content. Triggers handle future writes automatically. 
+ self.conn.execute(""" + INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope) + SELECT rowid, text, id, user_id, path, source, scope FROM chunks + """) + self.conn.commit() + def save_chunk(self, chunk: MemoryChunk): """Save a memory chunk""" self.conn.execute(""" @@ -283,13 +384,26 @@ class MemoryStorage: """ rows = self.conn.execute(query, params).fetchall() - - # Calculate cosine similarity + + # Calculate cosine similarity. We probe the first row's dim to fail + # loudly on a query/index dim mismatch — otherwise every doc would + # score 0 silently, leaving the user wondering why search broke. results = [] + query_dim = len(query_embedding) + if rows: + first = json.loads(rows[0]['embedding']) + if isinstance(first, list) and len(first) != query_dim: + raise ValueError( + f"Embedding dim mismatch: query is {query_dim}-dim but " + f"index stores {len(first)}-dim vectors. The configured " + f"embedding model differs from the one that built the " + f"index — run /memory rebuild-index to re-embed." + ) + for row in rows: embedding = json.loads(row['embedding']) similarity = self._cosine_similarity(query_embedding, embedding) - + if similarity > 0: results.append((similarity, row)) @@ -319,27 +433,24 @@ class MemoryStorage: ) -> List[SearchResult]: """ Keyword search using FTS5 + LIKE fallback - + Strategy: - 1. If FTS5 available: Try FTS5 search first (good for English and word-based languages) - 2. If no FTS5 or no results and query contains CJK: Use LIKE search + 1. If FTS5 available and healthy: try FTS5 first + 2. Always fall back to LIKE for CJK queries + 3. If FTS5 fails OR returns empty for non-CJK, also try LIKE so a + broken FTS5 shadow table doesn't silently kill keyword search. 
""" if scopes is None: scopes = ["shared"] if user_id: scopes.append("user") - - # Try FTS5 search first (if available) + if self.fts5_available: fts_results = self._search_fts5(query, user_id, scopes, limit) if fts_results: return fts_results - - # Fallback to LIKE search (always for CJK, or if FTS5 not available) - if not self.fts5_available or MemoryStorage._contains_cjk(query): - return self._search_like(query, user_id, scopes, limit) - - return [] + + return self._search_like(query, user_id, scopes, limit) def _search_fts5( self, @@ -394,7 +505,11 @@ class MemoryStorage: ) for row in rows ] - except Exception: + except Exception as e: + from common.log import logger + logger.error( + f"[MemoryStorage] FTS5 search failed (caller will fall back to LIKE): {e}" + ) return [] def _search_like( @@ -404,21 +519,28 @@ class MemoryStorage: scopes: List[str], limit: int ) -> List[SearchResult]: - """LIKE-based search for CJK characters""" + """LIKE-based search. + + Used as the keyword-search fallback when FTS5 is unavailable, fails, + or returns empty. Supports both CJK runs and ASCII word tokens so it + can serve as a true safety net for any query. + """ import re - # Extract CJK words (2+ characters) + # CJK runs (2+ chars) + ASCII word tokens (3+ chars to avoid noise) cjk_words = re.findall(r'[\u4e00-\u9fff]{2,}', query) - if not cjk_words: + ascii_words = [t for t in re.findall(r'[A-Za-z0-9_]+', query) if len(t) >= 3] + words = cjk_words + ascii_words + if not words: return [] - + scope_placeholders = ','.join('?' 
* len(scopes)) - - # Build LIKE conditions for each word + + # Build LIKE conditions for each word (case-insensitive for ASCII) like_conditions = [] params = [] - for word in cjk_words: - like_conditions.append("text LIKE ?") - params.append(f'%{word}%') + for word in words: + like_conditions.append("LOWER(text) LIKE ?") + params.append(f'%{word.lower()}%') where_clause = ' OR '.join(like_conditions) params.extend(scopes) @@ -455,7 +577,9 @@ class MemoryStorage: ) for row in rows ] - except Exception: + except Exception as e: + from common.log import logger + logger.error(f"[MemoryStorage] LIKE search failed: {e}") return [] def delete_by_path(self, path: str): @@ -485,14 +609,19 @@ class MemoryStorage: chunks_count = self.conn.execute(""" SELECT COUNT(*) as cnt FROM chunks """).fetchone()['cnt'] - + files_count = self.conn.execute(""" SELECT COUNT(*) as cnt FROM files """).fetchone()['cnt'] - + + embedded_count = self.conn.execute(""" + SELECT COUNT(*) as cnt FROM chunks WHERE embedding IS NOT NULL + """).fetchone()['cnt'] + return { 'chunks': chunks_count, - 'files': files_count + 'files': files_count, + 'embedded': embedded_count, } def close(self): diff --git a/bridge/agent_initializer.py b/bridge/agent_initializer.py index 5e0f3d37..d17dcb0c 100644 --- a/bridge/agent_initializer.py +++ b/bridge/agent_initializer.py @@ -17,6 +17,10 @@ from common.utils import expand_path # Module-level lock to serialize scheduler init across concurrent sessions _scheduler_init_lock = threading.Lock() +# Track whether the embedding model log has been printed in this process, +# so we avoid spamming it once per session. 
+_embedding_logged: bool = False + class AgentInitializer: """ @@ -272,52 +276,19 @@ class AgentInitializer: memory_tools = [] try: - from agent.memory import MemoryManager, MemoryConfig, create_embedding_provider + from agent.memory import MemoryManager, MemoryConfig from agent.tools import MemorySearchTool, MemoryGetTool from config import conf - - # Initialize embedding provider (prefer OpenAI, fallback to LinkAI) - embedding_provider = None - openai_api_key = conf().get("open_ai_api_key", "") - openai_api_base = conf().get("open_ai_api_base", "") - if openai_api_key and openai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]: - try: - embedding_provider = create_embedding_provider( - provider="openai", - model="text-embedding-3-small", - api_key=openai_api_key, - api_base=openai_api_base or "https://api.openai.com/v1" - ) - if session_id is None: - logger.info("[AgentInitializer] OpenAI embedding initialized") - except Exception as e: - logger.warning(f"[AgentInitializer] OpenAI embedding failed: {e}") - - if embedding_provider is None: - linkai_api_key = conf().get("linkai_api_key", "") or os.environ.get("LINKAI_API_KEY", "") - linkai_api_base = conf().get("linkai_api_base", "https://api.link-ai.tech") - if linkai_api_key and linkai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]: - try: - embedding_provider = create_embedding_provider( - provider="linkai", - model="text-embedding-3-small", - api_key=linkai_api_key, - api_base=f"{linkai_api_base}/v1" - ) - if session_id is None: - logger.info("[AgentInitializer] LinkAI embedding initialized (fallback)") - except Exception as e: - logger.warning(f"[AgentInitializer] LinkAI embedding failed: {e}") - - # Create memory manager memory_config = MemoryConfig(workspace_root=workspace_root) + + embedding_provider = self._init_embedding_provider( + memory_config, session_id=session_id + ) + memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider) - - # Sync memory 
self._sync_memory(memory_manager, session_id) - - # Create memory tools + memory_tools = [ MemorySearchTool(memory_manager), MemoryGetTool(memory_manager) @@ -330,6 +301,190 @@ class AgentInitializer: logger.warning(f"[AgentInitializer] Memory system not available: {e}") return memory_manager, memory_tools + + def _init_embedding_provider(self, memory_config, session_id: Optional[str] = None): + """ + Initialize the embedding provider for memory. + + Two paths: + A. Default (no `embedding_provider` in config.json): + Auto-init OpenAI -> LinkAI fallback. Existing 1536-dim indices + keep working. + B. Explicit (`embedding_provider` is set): + Initialize the requested vendor with unified dim (default 1024). + If the index was built with a different dim, vector search will + quietly return no results (cosine returns 0) and keyword search + takes over until the user runs /memory rebuild-index. + """ + from agent.memory import create_embedding_provider + from config import conf + + explicit_provider = (conf().get("embedding_provider") or "").strip().lower() + + if not explicit_provider: + return self._init_embedding_provider_legacy(session_id=session_id) + + return self._init_embedding_provider_explicit( + memory_config, explicit_provider, session_id=session_id, + ) + + def _init_embedding_provider_legacy(self, session_id: Optional[str] = None): + """Legacy auto-init path: OpenAI -> LinkAI. 
Preserved verbatim for compat.""" + from agent.memory import create_embedding_provider + from config import conf + + embedding_provider = None + embedding_model = None + + openai_api_key = conf().get("open_ai_api_key", "") + openai_api_base = conf().get("open_ai_api_base", "") + if openai_api_key and openai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]: + try: + model = "text-embedding-3-small" + embedding_provider = create_embedding_provider( + provider="openai", + model=model, + api_key=openai_api_key, + api_base=openai_api_base or "https://api.openai.com/v1" + ) + embedding_model = f"openai/{model}" + except Exception as e: + logger.warning(f"[AgentInitializer] OpenAI embedding failed: {e}") + + if embedding_provider is None: + linkai_api_key = conf().get("linkai_api_key", "") or os.environ.get("LINKAI_API_KEY", "") + linkai_api_base = conf().get("linkai_api_base", "https://api.link-ai.tech") + if linkai_api_key and linkai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]: + try: + model = "text-embedding-3-small" + embedding_provider = create_embedding_provider( + provider="linkai", + model=model, + api_key=linkai_api_key, + api_base=f"{linkai_api_base}/v1" + ) + embedding_model = f"linkai/{model}" + except Exception as e: + logger.warning(f"[AgentInitializer] LinkAI embedding failed: {e}") + + if embedding_provider is not None and embedding_model: + global _embedding_logged + if not _embedding_logged: + logger.info( + f"[AgentInitializer] Embedding model in use: {embedding_model} " + f"(dim={embedding_provider.dimensions})" + ) + _embedding_logged = True + + return embedding_provider + + def _init_embedding_provider_explicit( + self, + memory_config, + provider_key: str, + session_id: Optional[str] = None, + ): + """Explicit-provider path: build the configured vendor. + + If the index was built with a different dim, vector search will + silently return no results (cosine returns 0 for mismatched dims) + and keyword search takes over. 
Users switch vendors by running + /memory rebuild-index — see docs. + """ + from agent.memory import create_embedding_provider + from agent.memory.embedding import EMBEDDING_VENDORS + from config import conf + + meta = EMBEDDING_VENDORS.get(provider_key) + if meta is None: + logger.error( + f"[AgentInitializer] Unknown embedding_provider '{provider_key}'. " + f"Supported: {sorted(EMBEDDING_VENDORS.keys())}. " + f"Memory will run in keyword-only mode." + ) + return None + + api_key = self._resolve_embedding_api_key(provider_key) + api_base = self._resolve_embedding_api_base(provider_key, meta["default_base_url"]) + + if not api_key: + logger.error( + f"[AgentInitializer] embedding_provider='{provider_key}' is set but its " + f"API key is missing. Memory will run in keyword-only mode." + ) + return None + + model = (conf().get("embedding_model") or "").strip() or meta["default_model"] + try: + cfg_dim = int(conf().get("embedding_dimensions") or 0) + except (TypeError, ValueError): + cfg_dim = 0 + dim = cfg_dim if cfg_dim > 0 else meta["default_dimensions"] + + try: + provider = create_embedding_provider( + provider=provider_key, + model=model, + api_key=api_key, + api_base=api_base, + dimensions=dim, + ) + except Exception as e: + logger.error( + f"[AgentInitializer] Failed to init embedding provider " + f"'{provider_key}/{model}': {e}" + ) + return None + + global _embedding_logged + if not _embedding_logged: + logger.info( + f"[AgentInitializer] Embedding model in use: " + f"{provider_key}/{model} (dim={provider.dimensions})" + ) + _embedding_logged = True + return provider + + @staticmethod + def _resolve_embedding_api_key(provider_key: str) -> str: + """Pick the API key for an explicit embedding provider from config.""" + from config import conf + + key_map = { + "openai": "open_ai_api_key", + "linkai": "linkai_api_key", + "dashscope": "dashscope_api_key", + "doubao": "ark_api_key", + "zhipu": "zhipu_ai_api_key", + } + field = key_map.get(provider_key) + if not 
field: + return "" + value = conf().get(field, "") or "" + if value in ["", "YOUR API KEY", "YOUR_API_KEY"]: + return "" + return value + + @staticmethod + def _resolve_embedding_api_base(provider_key: str, default_base: str) -> str: + """Pick the API base for an explicit embedding provider from config.""" + from config import conf + + base_map = { + "openai": "open_ai_api_base", + "linkai": "linkai_api_base", + "doubao": "ark_base_url", + "zhipu": "zhipu_ai_api_base", + } + field = base_map.get(provider_key) + if not field: + return default_base + value = (conf().get(field) or "").strip() + if not value: + return default_base + if provider_key == "linkai" and not value.rstrip("/").endswith("/v1"): + return f"{value.rstrip('/')}/v1" + return value def _sync_memory(self, memory_manager, session_id: Optional[str] = None): """Sync memory database""" diff --git a/cli/cli.py b/cli/cli.py index 9bc6f48f..99b837ab 100644 --- a/cli/cli.py +++ b/cli/cli.py @@ -26,7 +26,8 @@ Commands: knowledge Manage knowledge base. install-browser Install browser tool (Playwright + Chromium). -Tip: You can also send /help, /skill list, etc. 
in agent chat.""" +Tip: Memory index management lives in chat — send /memory status or +/memory rebuild-index to the running agent.""" class CowCLI(click.Group): diff --git a/config.py b/config.py index f281cecb..1d08bb64 100644 --- a/config.py +++ b/config.py @@ -100,6 +100,10 @@ available_setting = { "dashscope_api_key": "", # Google Gemini Api Key "gemini_api_key": "", + # Embedding 模型设置 + "embedding_provider": "", # 显式指定厂商:openai / linkai / dashscope / doubao / zhipu (与 bot_type 命名一致) + "embedding_model": "", # 留空使用厂商默认 model + "embedding_dimensions": 0, # 留空/0 使用厂商默认维度(推荐统一 1024) # 语音设置 "speech_recognition": True, # 是否开启语音识别 "group_speech_recognition": False, # 是否开启群组语音识别 diff --git a/plugins/cow_cli/cow_cli.py b/plugins/cow_cli/cow_cli.py index 42d27330..3ecbfaec 100644 --- a/plugins/cow_cli/cow_cli.py +++ b/plugins/cow_cli/cow_cli.py @@ -62,10 +62,25 @@ class CowCliPlugin(Plugin): content = e_context["context"].content.strip() parsed = self._parse_command(content) - if not parsed: + if parsed is None: return cmd, args = parsed + + if cmd not in KNOWN_COMMANDS: + # Slash-prefixed near-miss: looks like a typo of a real command. + # Intercept with a hint so we don't burn an LLM round on "/momory". + suggestion = self._suggest_command(cmd) + if suggestion is None: + return + hint = f"未知命令: /{cmd}" + if suggestion: + hint += f"\n你是不是想输入 /{suggestion} ?" + hint += "\n发送 /help 查看全部命令。" + e_context["reply"] = Reply(ReplyType.TEXT, hint) + e_context.action = EventAction.BREAK_PASS + return + logger.info(f"[CowCli] intercepted command: {cmd} {args}") result = self._dispatch(cmd, args, e_context) @@ -82,28 +97,80 @@ class CowCliPlugin(Plugin): cow [args...] e.g. "cow skill list" / [args...] e.g. "/skill list" - Returns (command, args_string) or None if not a cow command. - """ - parts = None + Returns: + - (command, args_string): when the message looks like a command. + 'command' may NOT be in KNOWN_COMMANDS; caller should validate. 
@staticmethod
def _suggest_command(cmd: str) -> "str | None":
    """Suggest the known command that *cmd* most likely mistypes.

    A candidate from ``KNOWN_COMMANDS`` qualifies when it starts with the
    same character as *cmd* and is at most one single-character edit
    (insert, delete or substitute) away from it.

    Args:
        cmd: The command word as typed, without the leading slash.

    Returns:
        * the first qualifying entry of ``KNOWN_COMMANDS``;
        * ``""`` when *cmd* is empty;
        * ``None`` when *cmd* has fewer than three characters or no entry
          qualifies. Callers in this file treat ``None`` as "do not
          intercept the message".
    """
    if not cmd:
        return ""
    if len(cmd) < 3:
        # Very short inputs are never treated as typos.
        return None

    def within_one_edit(a: str, b: str) -> bool:
        """Return True when *a* and *b* differ by at most one edit."""
        if a == b:
            return True
        len_a, len_b = len(a), len(b)
        if abs(len_a - len_b) > 1:
            return False
        if len_a == len_b:
            # Same length: allow a single substitution.
            return sum(1 for x, y in zip(a, b) if x != y) <= 1

        # Lengths differ by one: the shorter string must equal the longer
        # one with exactly one character skipped.
        shorter, longer = (a, b) if len_a < len_b else (b, a)
        i = j = 0
        skipped = False
        while i < len(shorter) and j < len(longer):
            if shorter[i] == longer[j]:
                i += 1
            elif skipped:
                return False
            else:
                skipped = True
            j += 1
        return True

    initial = cmd[0]
    for known in KNOWN_COMMANDS:
        # startswith() tolerates an empty entry, unlike known[0].
        if known.startswith(initial) and within_one_edit(cmd, known):
            return known
    return None
@staticmethod
def _memory_help() -> str:
    """Return the usage text listing the ``/memory`` sub-commands."""
    return (
        "🧠 记忆管理\n\n"
        "用法: /memory <子命令>\n\n"
        "子命令:\n"
        " status 查看索引状态 (provider / model / dim / chunks)\n"
        " rebuild-index 清空并重建向量索引 (切换 embedding 模型后必须执行)\n"
        " dream [N] 手动触发记忆蒸馏 (整理近N天, 默认3, 最多30)"
    )

def _memory_status(self) -> str:
    """Build a text report describing the current memory index.

    The report lists the index DB path, file/chunk counts and the active
    embedding provider (configured name, model, dimensions). Health
    warnings are appended only when ``embedding_provider`` is set in the
    config and a provider instance exists.

    Returns:
        The report as a newline-joined string. When no agent/memory manager
        is available, a short "not initialised" notice is returned instead.
    """
    from agent.memory.embedding import detect_index_dim
    from config import conf

    # NOTE(review): looked up with an empty session id, unlike
    # _memory_rebuild_index which resolves the caller's session - confirm.
    agent = self._get_agent("")
    memory_manager = agent.memory_manager if agent else None

    lines = ["🧠 记忆索引状态", ""]
    if not memory_manager:
        lines.append(" ⚠️ Agent 尚未初始化,先发一条普通消息再试")
        return "\n".join(lines)

    stats = memory_manager.storage.get_stats()
    db_path = memory_manager.config.get_db_path()
    embedded = stats.get('embedded', 0)
    chunks = stats.get('chunks', 0)
    lines.append(f" 索引DB : {db_path}")
    lines.append(f" Files : {stats.get('files', 0)}")
    lines.append(f" Chunks : {chunks} (embedded: {embedded})")
    lines.append("")

    # Active provider (from running config + provider instance).
    provider_obj = memory_manager.embedding_provider
    configured_provider = (conf().get("embedding_provider") or "").strip()
    provider_dim = None
    if provider_obj is not None:
        # Prefer the public `dimensions` attribute (also used by
        # _memory_rebuild_index); fall back to the private field so that
        # providers exposing only `_dimensions` keep working.
        provider_dim = (
            getattr(provider_obj, "dimensions", None)
            or getattr(provider_obj, "_dimensions", None)
        )
        provider_label = configured_provider.lower() or "(legacy)"
        model_label = getattr(provider_obj, "model", "?")
        lines.append(f" Provider : {provider_label}")
        lines.append(f" Model : {model_label}")
        lines.append(f" Dim : {provider_dim or '?'}")
    else:
        lines.append(" Provider : (未初始化, keyword-only)")

    # Health hints are shown only when the user explicitly opted into
    # vector search via `embedding_provider`. Without it, vectors are
    # best-effort by design and warnings would just be noise.
    warnings = []
    if configured_provider and provider_obj is not None:
        if chunks > 0 and embedded < chunks:
            missing = chunks - embedded
            warnings.append(
                f" ⚠️ {missing}/{chunks} 个 chunk 没有向量;"
                f"运行 /memory rebuild-index 后所有记忆才会被向量化检索"
            )

        index_dim = detect_index_dim(memory_manager.storage)
        cfg_dim = provider_dim
        if index_dim is not None and cfg_dim and index_dim != cfg_dim:
            warnings.append(
                f" ⚠️ 索引中存量向量为 {index_dim} 维,与当前配置 {cfg_dim} 维不一致;"
                f"运行 /memory rebuild-index 重建后向量检索才会生效"
            )

    if warnings:
        lines.append("")
        lines.extend(warnings)

    return "\n".join(lines)

def _memory_rebuild_index(self, e_context, session_id: str) -> str:
    """Rebuild the vector index using the current agent's memory_manager.

    Args:
        e_context: Channel event context, or ``None`` when there is no
            channel to push notifications to.
        session_id: Fallback session id used when *e_context* does not
            provide one.

    Returns:
        A user-facing message: an error notice when the agent or embedding
        provider is unavailable, the final result when run synchronously
        (*e_context* is ``None``), or a "started" notice when the rebuild
        was handed to a background thread.
    """
    session_id = self._get_session_id(e_context, fallback=session_id)
    agent = self._get_agent(session_id)
    if not agent or not agent.memory_manager:
        return (
            "⚠️ Agent 尚未初始化,无法重建索引。\n"
            "请先发送一条普通消息触发 Agent 启动后再试。"
        )

    memory_manager = agent.memory_manager
    provider_obj = memory_manager.embedding_provider
    if provider_obj is None:
        return (
            "⚠️ 当前没有可用的 embedding provider。\n"
            "请检查 config.json 中的 embedding 相关配置 (provider / api key)。"
        )

    model_label = getattr(provider_obj, "model", "?")
    dim_label = getattr(provider_obj, "dimensions", "?")

    # SaaS (e_context is None): run synchronously, return final result.
    if e_context is None:
        return self._memory_rebuild_sync(memory_manager, model_label, dim_label)

    # Local channels: run in background, push the final result when done.
    from agent.memory.embedding import rebuild_in_process

    def _run():
        try:
            result = rebuild_in_process(memory_manager)
            if result.ok:
                self._notify(
                    e_context,
                    (
                        f"✅ 索引重建完成\n"
                        f" cleared : {result.removed}\n"
                        f" chunks : {result.chunks}\n"
                        f" files : {result.files}"
                    ),
                )
            else:
                self._notify(e_context, f"❌ 索引重建失败: {result.error}")
        except Exception as e:
            # Boundary of the worker thread: log and report instead of
            # letting the exception vanish with the thread.
            logger.exception("[CowCli] /memory rebuild-index failed")
            self._notify(e_context, f"❌ 索引重建失败: {e}")

    threading.Thread(target=_run, daemon=True).start()
    return (
        f"🔧 索引重建已启动 (model={model_label}, dim={dim_label})\n\n"
        f"将清空现有 chunks 并重新 embed 所有记忆文件,完成后会通知你。"
    )

@staticmethod
def _memory_rebuild_sync(memory_manager, model_label, dim_label) -> str:
    """Run the index rebuild in the calling thread and format the outcome.

    Args:
        memory_manager: Memory manager passed to ``rebuild_in_process``.
        model_label: Model name shown in the success message.
        dim_label: Embedding dimension shown in the success message.

    Returns:
        A success summary (cleared / chunks / files counts) or a failure
        message carrying the reported error or raised exception.
    """
    from agent.memory.embedding import rebuild_in_process

    try:
        result = rebuild_in_process(memory_manager)
    except Exception as e:
        logger.exception("[CowCli] /memory rebuild-index sync failed")
        return f"❌ 索引重建失败: {e}"

    if not result.ok:
        return f"❌ 索引重建失败: {result.error}"
    return (
        f"✅ 索引重建完成 (model={model_label}, dim={dim_label})\n"
        f" cleared : {result.removed}\n"
        f" chunks : {result.chunks}\n"
        f" files : {result.files}"
    )