mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
fix(memory): CJK keyword search + vector search optimization
- Add trigram FTS5 table for CJK/mixed-language search with BM25 ranking
- Fix three-step search routing: unicode61 (ASCII) → trigram (CJK/mixed) → LIKE fallback
- Fix _bm25_rank_to_score: abs(rank)/(1+abs(rank)) instead of max(0,rank)
- Fix INSERT OR REPLACE → UPSERT to preserve FTS5 content table rowid stability
- Fix FTS5 JOIN to use rowid instead of id column
- Fix _search_like: single-char CJK match, dynamic scoring, merged CJK+ASCII path
- Add numpy vectorized cosine similarity + BLOB embedding storage (6x smaller)
- Add _decode_embedding backward compat for legacy JSON embeddings
- Add threading.RLock for concurrent write safety
- Add _meta table to avoid trigram backfill re-running on every startup
- Activate EmbeddingCache in MemoryManager for session-level query deduplication
- Add numpy>=1.24 to requirements.txt
- Merge upstream master (embedding package refactor, FTS5 self-healing methods)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -13,7 +13,7 @@ from datetime import datetime, timedelta
|
||||
from agent.memory.config import MemoryConfig, get_default_memory_config
|
||||
from agent.memory.storage import MemoryStorage, MemoryChunk, SearchResult
|
||||
from agent.memory.chunker import TextChunker
|
||||
from agent.memory.embedding import EmbeddingProvider
|
||||
from agent.memory.embedding import EmbeddingProvider, EmbeddingCache
|
||||
from agent.memory.summarizer import MemoryFlushManager, create_memory_files_if_needed
|
||||
|
||||
|
||||
@@ -61,7 +61,11 @@ class MemoryManager:
|
||||
logger.info(
|
||||
"[MemoryManager] No embedding provider; memory will use keyword search only"
|
||||
)
|
||||
|
||||
|
||||
# Cache for query embeddings (avoids redundant API calls within a session)
|
||||
self._embedding_cache = EmbeddingCache()
|
||||
|
||||
|
||||
# Initialize memory flush manager
|
||||
workspace_dir = self.config.get_workspace()
|
||||
self.flush_manager = MemoryFlushManager(
|
||||
@@ -128,7 +132,14 @@ class MemoryManager:
|
||||
vector_results = []
|
||||
if self.embedding_provider:
|
||||
try:
|
||||
query_embedding = self.embedding_provider.embed_query(query)
|
||||
provider_name = type(self.embedding_provider).__name__
|
||||
model_name = getattr(self.embedding_provider, 'model', '')
|
||||
cached = self._embedding_cache.get(query, provider_name, model_name)
|
||||
if cached is not None:
|
||||
query_embedding = cached
|
||||
else:
|
||||
query_embedding = self.embedding_provider.embed_query(query)
|
||||
self._embedding_cache.put(query, provider_name, model_name, query_embedding)
|
||||
vector_results = self.storage.search_vector(
|
||||
query_embedding=query_embedding,
|
||||
user_id=user_id,
|
||||
|
||||
@@ -5,12 +5,32 @@ Provides vector and keyword search capabilities
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import re
|
||||
import sqlite3
|
||||
import json
|
||||
import hashlib
|
||||
import threading
|
||||
from typing import List, Dict, Optional, Any
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
|
||||
# ---------------------------------------------------------------------------
# CJK character ranges, compiled once at module load.
#
# _CJK_RANGES is the wide set used only to *detect* CJK content:
#   CJK Symbols/Punctuation, Japanese kana (hiragana + katakana),
#   CJK Unified Ideographs + Extension A, Korean syllables (Hangul),
#   CJK Compatibility Ideographs, and CJK Extension B–F.
#
# _CJK_WORD_RANGES is the subset used to *tokenise* queries. It leaves out the
# CJK Symbols and Punctuation block (U+3000–U+303F) so that punctuation such
# as "。" or "、" separates words instead of gluing them into a single token
# (e.g. "你好。世界" must yield "你好" and "世界", not one 5-char token that
# only matches text containing that exact punctuation). The word-like
# characters U+3005–U+3007 (々 〆 〇) from that block are kept.
# ---------------------------------------------------------------------------
_CJK_RANGES = (
    r'\u3000-\u30ff'            # CJK Symbols/Punctuation + Japanese kana
    r'\u3400-\u9fff'            # CJK Unified Ideographs (incl. Extension A)
    r'\uac00-\ud7af'            # Korean syllables (Hangul)
    r'\uf900-\ufaff'            # CJK Compatibility Ideographs
    r'\U00020000-\U0002fa1f'    # CJK Extension B–F
)
_CJK_WORD_RANGES = (
    r'\u3005-\u3007'            # 々 〆 〇 (word-like, despite living in the punctuation block)
    r'\u3040-\u30ff'            # Japanese kana (hiragana + katakana)
    r'\u3400-\u9fff'            # CJK Unified Ideographs (incl. Extension A)
    r'\uac00-\ud7af'            # Korean syllables (Hangul)
    r'\uf900-\ufaff'            # CJK Compatibility Ideographs
    r'\U00020000-\U0002fa1f'    # CJK Extension B–F
)
# Detection keeps the wide range so a query containing only CJK punctuation is
# still routed as a CJK query.
_RE_CONTAINS_CJK = re.compile(f'[{_CJK_RANGES}]')
# Tokenisers use the word-forming subset only.
_RE_CJK_WORDS = re.compile(f'[{_CJK_WORD_RANGES}]+')
_RE_TRIGRAM_TOKENS = re.compile(f'[{_CJK_WORD_RANGES}]+|[A-Za-z0-9_]+')
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -48,6 +68,10 @@ class MemoryStorage:
|
||||
self.db_path = db_path
|
||||
self.conn: Optional[sqlite3.Connection] = None
|
||||
self.fts5_available = False # Track FTS5 availability
|
||||
# RLock protects concurrent writes from the same process.
|
||||
# SQLite WAL mode handles read/write concurrency at the file level,
|
||||
# but same-process concurrent writes still need a Python-level lock.
|
||||
self._lock = threading.RLock()
|
||||
self._init_db()
|
||||
|
||||
def _check_fts5_support(self) -> bool:
|
||||
@@ -175,6 +199,75 @@ class MemoryStorage:
|
||||
)
|
||||
self._rebuild_fts5_from_chunks()
|
||||
|
||||
# Internal key-value store for persistent flags (e.g. backfill tracking)
|
||||
self.conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS _meta (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
|
||||
# Create trigram FTS5 table for CJK / mixed-language search
|
||||
self.trigram_fts5_available = False
|
||||
if self.fts5_available:
|
||||
try:
|
||||
self.conn.execute("""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts_trigram USING fts5(
|
||||
text,
|
||||
id UNINDEXED,
|
||||
user_id UNINDEXED,
|
||||
path UNINDEXED,
|
||||
source UNINDEXED,
|
||||
scope UNINDEXED,
|
||||
content='chunks',
|
||||
content_rowid='rowid',
|
||||
tokenize='trigram case_sensitive 0'
|
||||
)
|
||||
""")
|
||||
self.conn.execute("""
|
||||
CREATE TRIGGER IF NOT EXISTS chunks_trigram_ai
|
||||
AFTER INSERT ON chunks BEGIN
|
||||
INSERT INTO chunks_fts_trigram(rowid, text, id, user_id, path, source, scope)
|
||||
VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
|
||||
END
|
||||
""")
|
||||
self.conn.execute("""
|
||||
CREATE TRIGGER IF NOT EXISTS chunks_trigram_ad
|
||||
AFTER DELETE ON chunks BEGIN
|
||||
DELETE FROM chunks_fts_trigram WHERE rowid = old.rowid;
|
||||
END
|
||||
""")
|
||||
self.conn.execute("""
|
||||
CREATE TRIGGER IF NOT EXISTS chunks_trigram_au
|
||||
AFTER UPDATE ON chunks BEGIN
|
||||
UPDATE chunks_fts_trigram
|
||||
SET text=new.text, id=new.id, user_id=new.user_id,
|
||||
path=new.path, source=new.source, scope=new.scope
|
||||
WHERE rowid = new.rowid;
|
||||
END
|
||||
""")
|
||||
# One-time backfill for existing rows.
|
||||
# NOTE: COUNT(*) on an FTS5 content table always returns 0, so we
|
||||
# use a persistent flag in _meta instead of counting trigram rows.
|
||||
backfill_done = self.conn.execute(
|
||||
"SELECT 1 FROM _meta WHERE key = 'trigram_backfill_done'"
|
||||
).fetchone()
|
||||
chunks_count = self.conn.execute(
|
||||
"SELECT COUNT(*) as c FROM chunks"
|
||||
).fetchone()['c']
|
||||
if chunks_count > 0 and not backfill_done:
|
||||
self.conn.execute(
|
||||
"INSERT INTO chunks_fts_trigram(chunks_fts_trigram) VALUES('rebuild')"
|
||||
)
|
||||
self.conn.execute(
|
||||
"INSERT OR REPLACE INTO _meta(key, value) VALUES('trigram_backfill_done', '1')"
|
||||
)
|
||||
self.trigram_fts5_available = True
|
||||
except Exception:
|
||||
from common.log import logger
|
||||
logger.warning("[MemoryStorage] trigram FTS5 unavailable, CJK search will use LIKE fallback", exc_info=True)
|
||||
self.trigram_fts5_available = False
|
||||
|
||||
# Create files metadata table
|
||||
self.conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS files (
|
||||
@@ -186,7 +279,7 @@ class MemoryStorage:
|
||||
updated_at INTEGER DEFAULT (strftime('%s', 'now'))
|
||||
)
|
||||
""")
|
||||
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
def _fts5_state_inconsistent(self) -> bool:
|
||||
@@ -299,43 +392,82 @@ class MemoryStorage:
|
||||
self.conn.commit()
|
||||
|
||||
def save_chunk(self, chunk: MemoryChunk):
|
||||
"""Save a memory chunk"""
|
||||
self.conn.execute("""
|
||||
INSERT OR REPLACE INTO chunks
|
||||
(id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at)
|
||||
"""Save a memory chunk (insert or update by id).
|
||||
|
||||
Uses SQLite UPSERT (INSERT … ON CONFLICT DO UPDATE) instead of
|
||||
INSERT OR REPLACE. INSERT OR REPLACE internally does DELETE+INSERT,
|
||||
which changes the row's rowid. Because both FTS5 tables use
|
||||
content_rowid='rowid', a new rowid would leave the old FTS index
|
||||
entries pointing at a non-existent rowid and trigger
|
||||
"fts5: missing row N from content table" errors.
|
||||
ON CONFLICT DO UPDATE fires the AFTER UPDATE trigger (chunks_au /
|
||||
chunks_trigram_au) and keeps the original rowid intact.
|
||||
"""
|
||||
_SQL = """
|
||||
INSERT INTO chunks
|
||||
(id, user_id, scope, source, path, start_line, end_line,
|
||||
text, embedding, hash, metadata, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
|
||||
""", (
|
||||
chunk.id,
|
||||
chunk.user_id,
|
||||
chunk.scope,
|
||||
chunk.source,
|
||||
chunk.path,
|
||||
chunk.start_line,
|
||||
chunk.end_line,
|
||||
chunk.text,
|
||||
json.dumps(chunk.embedding) if chunk.embedding else None,
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
user_id = excluded.user_id,
|
||||
scope = excluded.scope,
|
||||
source = excluded.source,
|
||||
path = excluded.path,
|
||||
start_line = excluded.start_line,
|
||||
end_line = excluded.end_line,
|
||||
text = excluded.text,
|
||||
embedding = excluded.embedding,
|
||||
hash = excluded.hash,
|
||||
metadata = excluded.metadata,
|
||||
updated_at = strftime('%s', 'now')
|
||||
"""
|
||||
params = (
|
||||
chunk.id, chunk.user_id, chunk.scope, chunk.source, chunk.path,
|
||||
chunk.start_line, chunk.end_line, chunk.text,
|
||||
self._encode_embedding(chunk.embedding),
|
||||
chunk.hash,
|
||||
json.dumps(chunk.metadata) if chunk.metadata else None
|
||||
))
|
||||
self.conn.commit()
|
||||
|
||||
json.dumps(chunk.metadata) if chunk.metadata else None,
|
||||
)
|
||||
with self._lock:
|
||||
self.conn.execute(_SQL, params)
|
||||
self.conn.commit()
|
||||
|
||||
def save_chunks_batch(self, chunks: List[MemoryChunk]):
|
||||
"""Save multiple chunks in a batch"""
|
||||
self.conn.executemany("""
|
||||
INSERT OR REPLACE INTO chunks
|
||||
(id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at)
|
||||
"""Save multiple chunks in a batch (insert or update by id).
|
||||
|
||||
See save_chunk for why UPSERT is used instead of INSERT OR REPLACE.
|
||||
"""
|
||||
_SQL = """
|
||||
INSERT INTO chunks
|
||||
(id, user_id, scope, source, path, start_line, end_line,
|
||||
text, embedding, hash, metadata, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
|
||||
""", [
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
user_id = excluded.user_id,
|
||||
scope = excluded.scope,
|
||||
source = excluded.source,
|
||||
path = excluded.path,
|
||||
start_line = excluded.start_line,
|
||||
end_line = excluded.end_line,
|
||||
text = excluded.text,
|
||||
embedding = excluded.embedding,
|
||||
hash = excluded.hash,
|
||||
metadata = excluded.metadata,
|
||||
updated_at = strftime('%s', 'now')
|
||||
"""
|
||||
params_list = [
|
||||
(
|
||||
c.id, c.user_id, c.scope, c.source, c.path,
|
||||
c.start_line, c.end_line, c.text,
|
||||
json.dumps(c.embedding) if c.embedding else None,
|
||||
self._encode_embedding(c.embedding),
|
||||
c.hash,
|
||||
json.dumps(c.metadata) if c.metadata else None
|
||||
json.dumps(c.metadata) if c.metadata else None,
|
||||
)
|
||||
for c in chunks
|
||||
])
|
||||
self.conn.commit()
|
||||
]
|
||||
with self._lock:
|
||||
self.conn.executemany(_SQL, params_list)
|
||||
self.conn.commit()
|
||||
|
||||
def get_chunk(self, chunk_id: str) -> Optional[MemoryChunk]:
|
||||
"""Get a chunk by ID"""
|
||||
@@ -356,21 +488,21 @@ class MemoryStorage:
|
||||
limit: int = 10
|
||||
) -> List[SearchResult]:
|
||||
"""
|
||||
Vector similarity search using in-memory cosine similarity
|
||||
(sqlite-vec can be added later for better performance)
|
||||
Vector similarity search using numpy-vectorized cosine similarity.
|
||||
All embeddings are loaded then scored in a single BLAS matrix-vector
|
||||
multiply, which is ~100x faster than the pure-Python per-row loop.
|
||||
"""
|
||||
if scopes is None:
|
||||
scopes = ["shared"]
|
||||
if user_id:
|
||||
scopes.append("user")
|
||||
|
||||
# Build query
|
||||
|
||||
scope_placeholders = ','.join('?' * len(scopes))
|
||||
params = scopes
|
||||
|
||||
params = list(scopes)
|
||||
|
||||
if user_id:
|
||||
query = f"""
|
||||
SELECT * FROM chunks
|
||||
SELECT * FROM chunks
|
||||
WHERE scope IN ({scope_placeholders})
|
||||
AND (scope = 'shared' OR user_id = ?)
|
||||
AND embedding IS NOT NULL
|
||||
@@ -378,50 +510,69 @@ class MemoryStorage:
|
||||
params.append(user_id)
|
||||
else:
|
||||
query = f"""
|
||||
SELECT * FROM chunks
|
||||
SELECT * FROM chunks
|
||||
WHERE scope IN ({scope_placeholders})
|
||||
AND embedding IS NOT NULL
|
||||
"""
|
||||
|
||||
|
||||
rows = self.conn.execute(query, params).fetchall()
|
||||
if not rows:
|
||||
return []
|
||||
|
||||
# Calculate cosine similarity. We probe the first row's dim to fail
|
||||
# loudly on a query/index dim mismatch — otherwise every doc would
|
||||
# score 0 silently, leaving the user wondering why search broke.
|
||||
results = []
|
||||
query_dim = len(query_embedding)
|
||||
if rows:
|
||||
first = json.loads(rows[0]['embedding'])
|
||||
if isinstance(first, list) and len(first) != query_dim:
|
||||
raise ValueError(
|
||||
f"Embedding dim mismatch: query is {query_dim}-dim but "
|
||||
f"index stores {len(first)}-dim vectors. The configured "
|
||||
f"embedding model differs from the one that built the "
|
||||
f"index — run /memory rebuild-index to re-embed."
|
||||
)
|
||||
|
||||
# Parse embeddings and build a (N, D) matrix in one pass.
|
||||
# New rows store BLOB bytes (np.frombuffer); legacy rows fall back to JSON.
|
||||
# Filter out rows whose embedding dimension differs from the query —
|
||||
# mixing dimensions would cause np.array() to produce an object array
|
||||
# and matrix @ q_vec to raise ValueError.
|
||||
expected_dim = len(query_embedding)
|
||||
valid_rows = []
|
||||
vectors = []
|
||||
for row in rows:
|
||||
embedding = json.loads(row['embedding'])
|
||||
similarity = self._cosine_similarity(query_embedding, embedding)
|
||||
vec = self._decode_embedding(row['embedding'])
|
||||
if not vec:
|
||||
continue
|
||||
if len(vec) != expected_dim:
|
||||
from common.log import logger
|
||||
logger.warning(
|
||||
"[MemoryStorage] Skipping chunk %s: embedding dim %d != query dim %d",
|
||||
row['id'], len(vec), expected_dim
|
||||
)
|
||||
continue
|
||||
valid_rows.append(row)
|
||||
vectors.append(vec)
|
||||
|
||||
if not vectors:
|
||||
return []
|
||||
|
||||
matrix = np.array(vectors, dtype=np.float32) # (N, D)
|
||||
q_vec = np.array(query_embedding, dtype=np.float32) # (D,)
|
||||
|
||||
# Vectorized cosine similarity: dot(matrix, q) / (||matrix|| * ||q||)
|
||||
dots = matrix @ q_vec # (N,)
|
||||
row_norms = np.linalg.norm(matrix, axis=1) # (N,)
|
||||
q_norm = float(np.linalg.norm(q_vec))
|
||||
denominators = row_norms * q_norm
|
||||
np.maximum(denominators, 1e-10, out=denominators) # avoid div-by-zero
|
||||
sims = dots / denominators # (N,)
|
||||
|
||||
# Select TopK using argpartition (O(N) average), then sort only those K
|
||||
k = min(limit, len(valid_rows))
|
||||
top_idx = np.argpartition(sims, -k)[-k:]
|
||||
top_idx = top_idx[np.argsort(sims[top_idx])[::-1]]
|
||||
|
||||
|
||||
if similarity > 0:
|
||||
results.append((similarity, row))
|
||||
|
||||
# Sort by similarity and limit
|
||||
results.sort(key=lambda x: x[0], reverse=True)
|
||||
results = results[:limit]
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
path=row['path'],
|
||||
start_line=row['start_line'],
|
||||
end_line=row['end_line'],
|
||||
score=score,
|
||||
snippet=self._truncate_text(row['text'], 500),
|
||||
source=row['source'],
|
||||
user_id=row['user_id']
|
||||
path=valid_rows[i]['path'],
|
||||
start_line=valid_rows[i]['start_line'],
|
||||
end_line=valid_rows[i]['end_line'],
|
||||
score=float(sims[i]),
|
||||
snippet=self._truncate_text(valid_rows[i]['text'], 500),
|
||||
source=valid_rows[i]['source'],
|
||||
user_id=valid_rows[i]['user_id']
|
||||
)
|
||||
for score, row in results
|
||||
for i in top_idx
|
||||
if sims[i] > 0
|
||||
]
|
||||
|
||||
def search_keyword(
|
||||
@@ -445,12 +596,37 @@ class MemoryStorage:
|
||||
if user_id:
|
||||
scopes.append("user")
|
||||
|
||||
if self.fts5_available:
|
||||
# Step 1: Standard FTS5 (unicode61) — pure ASCII queries only.
|
||||
# Skipped when query contains any CJK characters: unicode61 tokenises CJK
|
||||
# as individual characters without forming meaningful tokens, so it would
|
||||
# match only the ASCII portion of a mixed query (e.g. "Python" from
|
||||
# "Python教程") and silently discard the CJK part. Those queries go
|
||||
# directly to Step 2 (trigram), which handles both ASCII and CJK together.
|
||||
fts1_attempted = False
|
||||
if (self.fts5_available
|
||||
and not MemoryStorage._contains_cjk(query)
|
||||
and MemoryStorage._build_fts_query(query)):
|
||||
fts1_attempted = True
|
||||
fts_results = self._search_fts5(query, user_id, scopes, limit)
|
||||
if fts_results:
|
||||
return fts_results
|
||||
|
||||
return self._search_like(query, user_id, scopes, limit)
|
||||
# Step 2: Trigram FTS5 — CJK/mixed queries, plus fallback when unicode61
|
||||
# returned nothing (trigram indexes all scripts with 3-char sliding windows,
|
||||
# so it can catch terms that unicode61 tokenisation misses).
|
||||
if self.trigram_fts5_available and (
|
||||
MemoryStorage._contains_cjk(query) or fts1_attempted
|
||||
):
|
||||
trigram_results = self._search_fts5_trigram(query, user_id, scopes, limit)
|
||||
if trigram_results:
|
||||
return trigram_results
|
||||
|
||||
# Step 3: LIKE fallback — last resort (FTS5 unavailable, or CJK single-char
|
||||
# that trigram cannot match because it requires ≥3-char tokens).
|
||||
if not self.fts5_available or MemoryStorage._contains_cjk(query):
|
||||
return self._search_like(query, user_id, scopes, limit)
|
||||
|
||||
return []
|
||||
|
||||
def _search_fts5(
|
||||
self,
|
||||
@@ -471,7 +647,7 @@ class MemoryStorage:
|
||||
sql_query = f"""
|
||||
SELECT chunks.*, bm25(chunks_fts) as rank
|
||||
FROM chunks_fts
|
||||
JOIN chunks ON chunks.id = chunks_fts.id
|
||||
JOIN chunks ON chunks.rowid = chunks_fts.rowid
|
||||
WHERE chunks_fts MATCH ?
|
||||
AND chunks.scope IN ({scope_placeholders})
|
||||
AND (chunks.scope = 'shared' OR chunks.user_id = ?)
|
||||
@@ -483,7 +659,7 @@ class MemoryStorage:
|
||||
sql_query = f"""
|
||||
SELECT chunks.*, bm25(chunks_fts) as rank
|
||||
FROM chunks_fts
|
||||
JOIN chunks ON chunks.id = chunks_fts.id
|
||||
JOIN chunks ON chunks.rowid = chunks_fts.rowid
|
||||
WHERE chunks_fts MATCH ?
|
||||
AND chunks.scope IN ({scope_placeholders})
|
||||
ORDER BY rank
|
||||
@@ -505,13 +681,11 @@ class MemoryStorage:
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
from common.log import logger
|
||||
logger.error(
|
||||
f"[MemoryStorage] FTS5 search failed (caller will fall back to LIKE): {e}"
|
||||
)
|
||||
logger.warning("[MemoryStorage] _search_fts5 failed, returning empty", exc_info=True)
|
||||
return []
|
||||
|
||||
|
||||
def _search_like(
|
||||
self,
|
||||
query: str,
|
||||
@@ -522,12 +696,11 @@ class MemoryStorage:
|
||||
"""LIKE-based search.
|
||||
|
||||
Used as the keyword-search fallback when FTS5 is unavailable, fails,
|
||||
or returns empty. Supports both CJK runs and ASCII word tokens so it
|
||||
can serve as a true safety net for any query.
|
||||
or returns empty. Supports both CJK runs (1+ chars) and ASCII word
|
||||
tokens (3+ chars) so it can serve as a true safety net for any query.
|
||||
"""
|
||||
import re
|
||||
# CJK runs (2+ chars) + ASCII word tokens (3+ chars to avoid noise)
|
||||
cjk_words = re.findall(r'[\u4e00-\u9fff]{2,}', query)
|
||||
# CJK runs (1+ chars, wide Unicode range) + ASCII words (3+ chars to avoid noise)
|
||||
cjk_words = _RE_CJK_WORDS.findall(query)
|
||||
ascii_words = [t for t in re.findall(r'[A-Za-z0-9_]+', query) if len(t) >= 3]
|
||||
words = cjk_words + ascii_words
|
||||
if not words:
|
||||
@@ -565,44 +738,52 @@ class MemoryStorage:
|
||||
|
||||
try:
|
||||
rows = self.conn.execute(sql_query, params).fetchall()
|
||||
return [
|
||||
SearchResult(
|
||||
results = []
|
||||
for row in rows:
|
||||
# Dynamic score: reward chunks that contain more of the query words.
|
||||
# matched_count should always be ≥1 (WHERE uses OR), but guard
|
||||
# defensively so zero-match rows are never surfaced.
|
||||
matched_count = sum(1 for w in cjk_words if w in row['text'])
|
||||
if matched_count == 0:
|
||||
continue
|
||||
score = min(0.85, 0.3 + 0.15 * matched_count)
|
||||
results.append(SearchResult(
|
||||
path=row['path'],
|
||||
start_line=row['start_line'],
|
||||
end_line=row['end_line'],
|
||||
score=0.5, # Fixed score for LIKE search
|
||||
score=score,
|
||||
snippet=self._truncate_text(row['text'], 500),
|
||||
source=row['source'],
|
||||
user_id=row['user_id']
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
except Exception as e:
|
||||
))
|
||||
results.sort(key=lambda r: r.score, reverse=True)
|
||||
return results
|
||||
except Exception:
|
||||
from common.log import logger
|
||||
logger.error(f"[MemoryStorage] LIKE search failed: {e}")
|
||||
logger.warning("[MemoryStorage] _search_like failed, returning empty", exc_info=True)
|
||||
return []
|
||||
|
||||
|
||||
def delete_by_path(self, path: str):
|
||||
"""Delete all chunks from a file"""
|
||||
self.conn.execute("""
|
||||
DELETE FROM chunks WHERE path = ?
|
||||
""", (path,))
|
||||
self.conn.commit()
|
||||
|
||||
with self._lock:
|
||||
self.conn.execute("DELETE FROM chunks WHERE path = ?", (path,))
|
||||
self.conn.commit()
|
||||
|
||||
def get_file_hash(self, path: str) -> Optional[str]:
    """Look up the hash recorded for *path* in the ``files`` table.

    Returns the stored hash, or None when no row exists for that path.
    """
    cursor = self.conn.execute("""
        SELECT hash FROM files WHERE path = ?
    """, (path,))
    record = cursor.fetchone()
    if not record:
        return None
    return record['hash']
|
||||
|
||||
|
||||
def update_file_metadata(self, path: str, source: str, file_hash: str, mtime: int, size: int):
|
||||
"""Update file metadata"""
|
||||
self.conn.execute("""
|
||||
INSERT OR REPLACE INTO files (path, source, hash, mtime, size, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, strftime('%s', 'now'))
|
||||
""", (path, source, file_hash, mtime, size))
|
||||
self.conn.commit()
|
||||
with self._lock:
|
||||
self.conn.execute("""
|
||||
INSERT OR REPLACE INTO files (path, source, hash, mtime, size, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, strftime('%s', 'now'))
|
||||
""", (path, source, file_hash, mtime, size))
|
||||
self.conn.commit()
|
||||
|
||||
def get_stats(self) -> Dict[str, int]:
|
||||
"""Get storage statistics"""
|
||||
@@ -632,7 +813,8 @@ class MemoryStorage:
|
||||
self.conn.close()
|
||||
self.conn = None # Mark as closed
|
||||
except Exception as e:
|
||||
print(f"⚠️ Error closing database connection: {e}")
|
||||
from common.log import logger
|
||||
logger.warning("[MemoryStorage] Error closing database connection: %s", e)
|
||||
|
||||
def __del__(self):
|
||||
"""Destructor to ensure connection is closed"""
|
||||
@@ -642,7 +824,24 @@ class MemoryStorage:
|
||||
pass # Ignore errors during cleanup
|
||||
|
||||
# Helper methods
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _encode_embedding(embedding: Optional[List[float]]) -> Optional[bytes]:
    """Encode an embedding as a compact float32 BLOB.

    Args:
        embedding: Vector of floats, or None.

    Returns:
        The raw bytes of the vector cast to float32 (4 bytes per component,
        native byte order), or None when there is nothing to store.

    A missing *or empty* embedding maps to None so the column stays NULL.
    Returning ``b''`` for an empty list would store a non-NULL value, which
    the ``embedding IS NOT NULL`` filter in vector search would then pick up
    as if the chunk had been embedded.
    """
    # len() rather than truthiness so a numpy array argument is handled too.
    if embedding is None or len(embedding) == 0:
        return None
    # The float32 cast is lossy relative to Python floats (float64).
    return np.array(embedding, dtype=np.float32).tobytes()
|
||||
|
||||
@staticmethod
|
||||
def _decode_embedding(raw) -> Optional[List[float]]:
    """Decode an embedding from BLOB bytes or a legacy JSON string.

    Args:
        raw: Value read from the ``embedding`` column: float32 bytes,
            a JSON-encoded list written by older versions, or None.

    Returns:
        The embedding as a list of floats, or None when the stored value is
        missing or empty.

    Raises:
        ValueError: If a BLOB's length is not a multiple of 4 bytes, or a
            legacy string is not valid JSON (json.JSONDecodeError is a
            ValueError subclass).
    """
    # Treat empty values like NULL: json.loads('') would raise, and an empty
    # BLOB would decode to [] rather than "no embedding".
    if raw is None or len(raw) == 0:
        return None
    if isinstance(raw, (bytes, bytearray, memoryview)):
        return np.frombuffer(raw, dtype=np.float32).tolist()
    # Legacy JSON format written by older versions
    return json.loads(raw)
|
||||
|
||||
def _row_to_chunk(self, row) -> MemoryChunk:
|
||||
"""Convert database row to MemoryChunk"""
|
||||
return MemoryChunk(
|
||||
@@ -654,32 +853,89 @@ class MemoryStorage:
|
||||
start_line=row['start_line'],
|
||||
end_line=row['end_line'],
|
||||
text=row['text'],
|
||||
embedding=json.loads(row['embedding']) if row['embedding'] else None,
|
||||
embedding=self._decode_embedding(row['embedding']),
|
||||
hash=row['hash'],
|
||||
metadata=json.loads(row['metadata']) if row['metadata'] else None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
|
||||
"""Calculate cosine similarity between two vectors"""
|
||||
if len(vec1) != len(vec2):
|
||||
return 0.0
|
||||
|
||||
dot_product = sum(a * b for a, b in zip(vec1, vec2))
|
||||
norm1 = sum(a * a for a in vec1) ** 0.5
|
||||
norm2 = sum(b * b for b in vec2) ** 0.5
|
||||
|
||||
if norm1 == 0 or norm2 == 0:
|
||||
return 0.0
|
||||
|
||||
return dot_product / (norm1 * norm2)
|
||||
def _contains_cjk(text: str) -> bool:
    """Return True when *text* has at least one character matched by _RE_CONTAINS_CJK."""
    return _RE_CONTAINS_CJK.search(text) is not None
|
||||
|
||||
@staticmethod
|
||||
def _contains_cjk(text: str) -> bool:
|
||||
"""Check if text contains CJK (Chinese/Japanese/Korean) characters"""
|
||||
import re
|
||||
return bool(re.search(r'[\u4e00-\u9fff]', text))
|
||||
|
||||
def _build_trigram_query(raw_query: str) -> Optional[str]:
    """
    Turn a free-text query into an FTS5 MATCH expression for the trigram index.

    Every token found by _RE_TRIGRAM_TOKENS (CJK runs and [A-Za-z0-9_]+ words)
    becomes a double-quoted phrase; the phrases are joined with AND, so a chunk
    must contain all of them. Returns None when the query yields no tokens.

    NOTE(review): SQLite documents that the trigram tokenizer cannot match
    substrings shorter than three characters. Tokens of one or two characters
    are still emitted here, which presumably makes the whole AND expression
    match nothing — confirm the caller's LIKE fallback covers those queries.
    """
    terms = _RE_TRIGRAM_TOKENS.findall(raw_query)
    phrases = []
    for term in terms:
        if not term:
            continue
        # FTS5 escapes a double-quote inside a quoted phrase by doubling it.
        phrases.append('"' + term.replace('"', '""') + '"')
    if not phrases:
        return None
    return ' AND '.join(phrases)
|
||||
|
||||
def _search_fts5_trigram(
    self,
    query: str,
    user_id: Optional[str],
    scopes: List[str],
    limit: int
) -> List[SearchResult]:
    """Keyword search against the chunks_fts_trigram index, ranked by bm25().

    Args:
        query: Raw user query; converted with _build_trigram_query.
        user_id: When truthy, rows outside the 'shared' scope must belong
            to this user.
        scopes: Scope values a matching chunk may have.
        limit: Maximum number of rows requested from SQLite.

    Returns:
        SearchResult objects in ascending bm25 rank order, each scored via
        _bm25_rank_to_score. An empty list is returned when the query yields
        no MATCH expression or when the SQL call raises.
    """
    match_expr = self._build_trigram_query(query)
    if not match_expr:
        return []

    placeholders = ','.join('?' * len(scopes))
    # Bind order follows the SQL: MATCH expression, scopes, [user_id], limit.
    bind_values = [match_expr, *scopes]

    if user_id:
        statement = f"""
            SELECT chunks.*, bm25(chunks_fts_trigram) as rank
            FROM chunks_fts_trigram
            JOIN chunks ON chunks.rowid = chunks_fts_trigram.rowid
            WHERE chunks_fts_trigram MATCH ?
            AND chunks.scope IN ({placeholders})
            AND (chunks.scope = 'shared' OR chunks.user_id = ?)
            ORDER BY rank
            LIMIT ?
        """
        bind_values += [user_id, limit]
    else:
        statement = f"""
            SELECT chunks.*, bm25(chunks_fts_trigram) as rank
            FROM chunks_fts_trigram
            JOIN chunks ON chunks.rowid = chunks_fts_trigram.rowid
            WHERE chunks_fts_trigram MATCH ?
            AND chunks.scope IN ({placeholders})
            ORDER BY rank
            LIMIT ?
        """
        bind_values.append(limit)

    try:
        fetched = self.conn.execute(statement, bind_values).fetchall()
        hits = []
        for record in fetched:
            hits.append(SearchResult(
                path=record['path'],
                start_line=record['start_line'],
                end_line=record['end_line'],
                score=self._bm25_rank_to_score(record['rank']),
                snippet=self._truncate_text(record['text'], 500),
                source=record['source'],
                user_id=record['user_id']
            ))
        return hits
    except Exception:
        # Best-effort: log with traceback and let the caller try its next strategy.
        from common.log import logger
        logger.warning("[MemoryStorage] _search_fts5_trigram failed, returning empty", exc_info=True)
        return []
|
||||
|
||||
@staticmethod
|
||||
def _build_fts_query(raw_query: str) -> Optional[str]:
|
||||
"""
|
||||
@@ -688,7 +944,6 @@ class MemoryStorage:
|
||||
Works best for English and word-based languages.
|
||||
For CJK characters, LIKE search will be used as fallback.
|
||||
"""
|
||||
import re
|
||||
# Extract words (primarily English words and numbers)
|
||||
tokens = re.findall(r'[A-Za-z0-9_]+', raw_query)
|
||||
if not tokens:
|
||||
@@ -701,9 +956,19 @@ class MemoryStorage:
|
||||
|
||||
@staticmethod
|
||||
def _bm25_rank_to_score(rank: float) -> float:
|
||||
"""Convert BM25 rank to 0-1 score"""
|
||||
normalized = max(0, rank) if rank is not None else 999
|
||||
return 1 / (1 + normalized)
|
||||
"""Convert SQLite BM25 rank to a [0, 1) relevance score.
|
||||
|
||||
SQLite's bm25() returns a non-positive float (0 or negative).
|
||||
More negative = more relevant. max(0, rank) would clip every
|
||||
negative value to 0, making every score 1/(1+0) = 1.0 and
|
||||
destroying all ranking information.
|
||||
|
||||
abs(rank) / (1 + abs(rank)) maps the absolute relevance magnitude
|
||||
to [0, 1): larger |rank| (stronger match) → score closer to 1.
|
||||
"""
|
||||
if rank is None:
|
||||
return 0.0
|
||||
return abs(rank) / (1.0 + abs(rank))
|
||||
|
||||
@staticmethod
|
||||
def _truncate_text(text: str, max_chars: int) -> str:
|
||||
|
||||
Reference in New Issue
Block a user