mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
feat(memory): support multi-vendor embedding fallback
Add embedding_provider config knob with native support for openai / dashscope / doubao / zhipu / linkai, plus an in-chat /memory status and /memory rebuild-index workflow for switching vendors safely.
This commit is contained in:
@@ -13,7 +13,7 @@ from datetime import datetime, timedelta
|
||||
from agent.memory.config import MemoryConfig, get_default_memory_config
|
||||
from agent.memory.storage import MemoryStorage, MemoryChunk, SearchResult
|
||||
from agent.memory.chunker import TextChunker
|
||||
from agent.memory.embedding import create_embedding_provider, EmbeddingProvider
|
||||
from agent.memory.embedding import EmbeddingProvider
|
||||
from agent.memory.summarizer import MemoryFlushManager, create_memory_files_if_needed
|
||||
|
||||
|
||||
@@ -50,49 +50,17 @@ class MemoryManager:
|
||||
overlap_tokens=self.config.chunk_overlap_tokens
|
||||
)
|
||||
|
||||
# Initialize embedding provider (optional, prefer OpenAI, fallback to LinkAI)
|
||||
self.embedding_provider = None
|
||||
if embedding_provider:
|
||||
self.embedding_provider = embedding_provider
|
||||
else:
|
||||
# Try OpenAI first
|
||||
try:
|
||||
api_key = os.environ.get('OPENAI_API_KEY')
|
||||
api_base = os.environ.get('OPENAI_API_BASE')
|
||||
if api_key:
|
||||
self.embedding_provider = create_embedding_provider(
|
||||
provider="openai",
|
||||
model=self.config.embedding_model,
|
||||
api_key=api_key,
|
||||
api_base=api_base
|
||||
)
|
||||
except Exception as e:
|
||||
from common.log import logger
|
||||
logger.warning(f"[MemoryManager] OpenAI embedding failed: {e}")
|
||||
|
||||
# Fallback to LinkAI
|
||||
if self.embedding_provider is None:
|
||||
try:
|
||||
linkai_key = os.environ.get('LINKAI_API_KEY')
|
||||
linkai_base = os.environ.get('LINKAI_API_BASE', 'https://api.link-ai.tech')
|
||||
if linkai_key:
|
||||
from common.utils import get_cloud_headers
|
||||
cloud_headers = get_cloud_headers(linkai_key)
|
||||
cloud_headers.pop("Authorization", None)
|
||||
self.embedding_provider = create_embedding_provider(
|
||||
provider="linkai",
|
||||
model=self.config.embedding_model,
|
||||
api_key=linkai_key,
|
||||
api_base=f"{linkai_base}/v1",
|
||||
extra_headers=cloud_headers,
|
||||
)
|
||||
except Exception as e:
|
||||
from common.log import logger
|
||||
logger.warning(f"[MemoryManager] LinkAI embedding failed: {e}")
|
||||
|
||||
if self.embedding_provider is None:
|
||||
from common.log import logger
|
||||
logger.info(f"[MemoryManager] Memory will work with keyword search only (no vector search)")
|
||||
# Embedding provider is owned by the caller (agent_initializer is the
|
||||
# canonical entry point and handles legacy/explicit + state validation).
|
||||
# When None is passed, memory degrades to keyword-only search instead
|
||||
# of silently re-initializing a vendor here, which would bypass the
|
||||
# caller's state checks and risk corrupting the index.
|
||||
self.embedding_provider = embedding_provider
|
||||
if self.embedding_provider is None:
|
||||
from common.log import logger
|
||||
logger.info(
|
||||
"[MemoryManager] No embedding provider; memory will use keyword search only"
|
||||
)
|
||||
|
||||
# Initialize memory flush manager
|
||||
workspace_dir = self.config.get_workspace()
|
||||
@@ -153,12 +121,14 @@ class MemoryManager:
|
||||
if self.config.sync_on_search and self._dirty:
|
||||
await self.sync()
|
||||
|
||||
# Perform vector search (if embedding provider available)
|
||||
from common.log import logger
|
||||
|
||||
# Perform vector search (if embedding provider available).
|
||||
# Failures degrade silently to keyword-only — no exception is raised.
|
||||
vector_results = []
|
||||
if self.embedding_provider:
|
||||
try:
|
||||
from common.log import logger
|
||||
query_embedding = self.embedding_provider.embed(query)
|
||||
query_embedding = self.embedding_provider.embed_query(query)
|
||||
vector_results = self.storage.search_vector(
|
||||
query_embedding=query_embedding,
|
||||
user_id=user_id,
|
||||
@@ -167,19 +137,19 @@ class MemoryManager:
|
||||
)
|
||||
logger.info(f"[MemoryManager] Vector search found {len(vector_results)} results for query: {query}")
|
||||
except Exception as e:
|
||||
from common.log import logger
|
||||
logger.warning(f"[MemoryManager] Vector search failed: {e}")
|
||||
|
||||
# Perform keyword search
|
||||
logger.error(
|
||||
f"[MemoryManager] Vector search failed, falling back to keyword-only: {e}"
|
||||
)
|
||||
|
||||
# Perform keyword search (also runs as fallback when vector failed)
|
||||
keyword_results = self.storage.search_keyword(
|
||||
query=query,
|
||||
user_id=user_id,
|
||||
scopes=scopes,
|
||||
limit=max_results * 2
|
||||
)
|
||||
from common.log import logger
|
||||
logger.info(f"[MemoryManager] Keyword search found {len(keyword_results)} results for query: {query}")
|
||||
|
||||
|
||||
# Merge results
|
||||
merged = self._merge_results(
|
||||
vector_results,
|
||||
@@ -187,7 +157,7 @@ class MemoryManager:
|
||||
self.config.vector_weight,
|
||||
self.config.keyword_weight
|
||||
)
|
||||
|
||||
|
||||
# Filter by min score and limit
|
||||
filtered = [r for r in merged if r.score >= min_score]
|
||||
return filtered[:max_results]
|
||||
@@ -269,132 +239,157 @@ class MemoryManager:
|
||||
|
||||
async def sync(self, force: bool = False):
|
||||
"""
|
||||
Synchronize memory from files
|
||||
|
||||
Synchronize memory from files.
|
||||
|
||||
Two-pass design to amortize embedding HTTP cost:
|
||||
1. Walk all files, chunk those whose hash changed, collect pending
|
||||
chunks across files. No embedding calls yet.
|
||||
2. Run a single embed_batch over the union of pending chunks (the
|
||||
provider auto-paginates by vendor cap), then persist per-file.
|
||||
|
||||
For workspaces with many small files (101 files / ~1 chunk each), this
|
||||
cuts ~100 HTTP calls down to ~ceil(total_chunks / vendor_cap).
|
||||
|
||||
Args:
|
||||
force: Force full reindex
|
||||
"""
|
||||
memory_dir = self.config.get_memory_dir()
|
||||
workspace_dir = self.config.get_workspace()
|
||||
|
||||
# Scan MEMORY.md (workspace root)
|
||||
|
||||
files_to_scan: List[tuple] = [] # (file_path, source, scope, user_id)
|
||||
|
||||
memory_file = Path(workspace_dir) / "MEMORY.md"
|
||||
if memory_file.exists():
|
||||
await self._sync_file(memory_file, "memory", "shared", None)
|
||||
|
||||
# Scan memory directory (including daily summaries)
|
||||
files_to_scan.append((memory_file, "memory", "shared", None))
|
||||
|
||||
if memory_dir.exists():
|
||||
for file_path in memory_dir.rglob("*.md"):
|
||||
# Skip hidden directories (e.g. .dreams/)
|
||||
if any(part.startswith('.') for part in file_path.relative_to(workspace_dir).parts):
|
||||
continue
|
||||
|
||||
# Determine scope and user_id from path
|
||||
rel_path = file_path.relative_to(workspace_dir)
|
||||
parts = rel_path.parts
|
||||
|
||||
# Check if it's in daily summary directory
|
||||
if "daily" in parts:
|
||||
# Daily summary files
|
||||
if "users" in parts or len(parts) > 3:
|
||||
# User-scoped daily summary: memory/daily/{user_id}/2024-01-29.md
|
||||
user_idx = parts.index("daily") + 1
|
||||
user_id = parts[user_idx] if user_idx < len(parts) else None
|
||||
rel_parts = file_path.relative_to(workspace_dir).parts
|
||||
if "daily" in rel_parts:
|
||||
if "users" in rel_parts or len(rel_parts) > 3:
|
||||
user_idx = rel_parts.index("daily") + 1
|
||||
user_id = rel_parts[user_idx] if user_idx < len(rel_parts) else None
|
||||
scope = "user"
|
||||
else:
|
||||
# Shared daily summary: memory/daily/2024-01-29.md
|
||||
user_id = None
|
||||
scope = "shared"
|
||||
elif "users" in parts:
|
||||
# User-scoped memory
|
||||
user_idx = parts.index("users") + 1
|
||||
user_id = parts[user_idx] if user_idx < len(parts) else None
|
||||
elif "users" in rel_parts:
|
||||
user_idx = rel_parts.index("users") + 1
|
||||
user_id = rel_parts[user_idx] if user_idx < len(rel_parts) else None
|
||||
scope = "user"
|
||||
else:
|
||||
# Shared memory
|
||||
user_id = None
|
||||
scope = "shared"
|
||||
|
||||
await self._sync_file(file_path, "memory", scope, user_id)
|
||||
files_to_scan.append((file_path, "memory", scope, user_id))
|
||||
|
||||
# Scan knowledge directory (structured knowledge wiki)
|
||||
from config import conf
|
||||
if conf().get("knowledge", True):
|
||||
knowledge_dir = Path(workspace_dir) / "knowledge"
|
||||
if knowledge_dir.exists():
|
||||
for file_path in knowledge_dir.rglob("*.md"):
|
||||
await self._sync_file(file_path, "knowledge", "shared", None)
|
||||
|
||||
self._dirty = False
|
||||
|
||||
async def _sync_file(
|
||||
self,
|
||||
file_path: Path,
|
||||
source: str,
|
||||
scope: str,
|
||||
user_id: Optional[str]
|
||||
):
|
||||
"""Sync a single file"""
|
||||
# Compute file hash
|
||||
content = file_path.read_text(encoding='utf-8')
|
||||
file_hash = MemoryStorage.compute_hash(content)
|
||||
|
||||
# Get relative path
|
||||
workspace_dir = self.config.get_workspace()
|
||||
rel_path = str(file_path.relative_to(workspace_dir))
|
||||
|
||||
# Check if file changed
|
||||
stored_hash = self.storage.get_file_hash(rel_path)
|
||||
if stored_hash == file_hash:
|
||||
return # No changes
|
||||
|
||||
# Delete old chunks
|
||||
self.storage.delete_by_path(rel_path)
|
||||
|
||||
# Chunk and embed
|
||||
chunks = self.chunker.chunk_text(content)
|
||||
if not chunks:
|
||||
files_to_scan.append((file_path, "knowledge", "shared", None))
|
||||
|
||||
# Pass 1: inline chunking + change detection. Inlined (instead of
|
||||
# calling self._prepare_file_for_sync) so this method does not depend
|
||||
# on any sibling helpers — keeps it robust against partial reloads
|
||||
# where the class object is older than the method's source.
|
||||
pending: List[Dict[str, Any]] = []
|
||||
workspace_dir_path = self.config.get_workspace()
|
||||
for file_path, source, scope, user_id in files_to_scan:
|
||||
try:
|
||||
content = file_path.read_text(encoding='utf-8')
|
||||
except Exception:
|
||||
continue
|
||||
file_hash = MemoryStorage.compute_hash(content)
|
||||
rel_path = str(file_path.relative_to(workspace_dir_path))
|
||||
if self.storage.get_file_hash(rel_path) == file_hash:
|
||||
continue
|
||||
chunks = self.chunker.chunk_text(content)
|
||||
if not chunks:
|
||||
continue
|
||||
pending.append({
|
||||
"file_path": file_path,
|
||||
"rel_path": rel_path,
|
||||
"source": source,
|
||||
"scope": scope,
|
||||
"user_id": user_id,
|
||||
"file_hash": file_hash,
|
||||
"chunks": chunks,
|
||||
"texts": [c.text for c in chunks],
|
||||
})
|
||||
|
||||
if not pending:
|
||||
self._dirty = False
|
||||
return
|
||||
|
||||
texts = [chunk.text for chunk in chunks]
|
||||
if self.embedding_provider:
|
||||
embeddings = self.embedding_provider.embed_batch(texts)
|
||||
|
||||
# Pass 2: single batched embed across all pending chunks.
|
||||
# CRITICAL: never touch the index until we hold valid embeddings.
|
||||
# If embed_batch fails, leave the existing index intact (chunks +
|
||||
# file_hash) so the next sync will retry the same files. Writing
|
||||
# NULL embeddings + updating file_hash here would mark the file as
|
||||
# "successfully synced" and silently strand it without vectors.
|
||||
all_texts: List[str] = []
|
||||
for entry in pending:
|
||||
all_texts.extend(entry["texts"])
|
||||
|
||||
if not self.embedding_provider:
|
||||
# No provider configured at all (legacy keyword-only). Persist
|
||||
# chunks without embeddings — this is the user's intent.
|
||||
all_embeddings: List[Optional[List[float]]] = [None] * len(all_texts)
|
||||
else:
|
||||
embeddings = [None] * len(texts)
|
||||
|
||||
# Create memory chunks
|
||||
memory_chunks = []
|
||||
for chunk, embedding in zip(chunks, embeddings):
|
||||
chunk_id = self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line)
|
||||
chunk_hash = MemoryStorage.compute_hash(chunk.text)
|
||||
|
||||
memory_chunks.append(MemoryChunk(
|
||||
id=chunk_id,
|
||||
user_id=user_id,
|
||||
scope=scope,
|
||||
source=source,
|
||||
try:
|
||||
all_embeddings = self.embedding_provider.embed_batch(all_texts)
|
||||
except Exception as e:
|
||||
from common.log import logger
|
||||
logger.error(
|
||||
f"[MemoryManager] Batch embedding failed for {len(all_texts)} "
|
||||
f"chunks across {len(pending)} files: {e}. "
|
||||
f"Index left untouched; will retry on next sync."
|
||||
)
|
||||
# Bail before touching storage. self._dirty stays True so
|
||||
# callers know there is pending work.
|
||||
return
|
||||
|
||||
# Pass 3: inline persist — same self-contained reasoning as Pass 1.
|
||||
cursor = 0
|
||||
for entry in pending:
|
||||
n = len(entry["texts"])
|
||||
entry_embeddings = all_embeddings[cursor:cursor + n]
|
||||
cursor += n
|
||||
|
||||
rel_path = entry["rel_path"]
|
||||
self.storage.delete_by_path(rel_path)
|
||||
memory_chunks = []
|
||||
for chunk, embedding in zip(entry["chunks"], entry_embeddings):
|
||||
chunk_id = self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line)
|
||||
chunk_hash = MemoryStorage.compute_hash(chunk.text)
|
||||
memory_chunks.append(MemoryChunk(
|
||||
id=chunk_id,
|
||||
user_id=entry["user_id"],
|
||||
scope=entry["scope"],
|
||||
source=entry["source"],
|
||||
path=rel_path,
|
||||
start_line=chunk.start_line,
|
||||
end_line=chunk.end_line,
|
||||
text=chunk.text,
|
||||
embedding=embedding,
|
||||
hash=chunk_hash,
|
||||
metadata=None,
|
||||
))
|
||||
self.storage.save_chunks_batch(memory_chunks)
|
||||
stat = entry["file_path"].stat()
|
||||
self.storage.update_file_metadata(
|
||||
path=rel_path,
|
||||
start_line=chunk.start_line,
|
||||
end_line=chunk.end_line,
|
||||
text=chunk.text,
|
||||
embedding=embedding,
|
||||
hash=chunk_hash,
|
||||
metadata=None
|
||||
))
|
||||
|
||||
# Save
|
||||
self.storage.save_chunks_batch(memory_chunks)
|
||||
|
||||
# Update file metadata
|
||||
stat = file_path.stat()
|
||||
self.storage.update_file_metadata(
|
||||
path=rel_path,
|
||||
source=source,
|
||||
file_hash=file_hash,
|
||||
mtime=int(stat.st_mtime),
|
||||
size=stat.st_size
|
||||
)
|
||||
|
||||
source=entry["source"],
|
||||
file_hash=entry["file_hash"],
|
||||
mtime=int(stat.st_mtime),
|
||||
size=stat.st_size,
|
||||
)
|
||||
|
||||
self._dirty = False
|
||||
|
||||
def flush_memory(
|
||||
self,
|
||||
messages: list,
|
||||
|
||||
Reference in New Issue
Block a user