Merge pull request #2832 from yangluxin613/feat/cjk-search-fix

fix(memory): CJK keyword search + vector search optimization
This commit is contained in:
zhayujie
2026-05-25 14:45:49 +08:00
committed by GitHub
4 changed files with 493 additions and 139 deletions

View File

@@ -31,9 +31,13 @@ def detect_index_dim(storage) -> Optional[int]:
if not row or not row["embedding"]: if not row or not row["embedding"]:
return None return None
try: try:
emb = json.loads(row["embedding"]) raw = row["embedding"]
if isinstance(raw, (bytes, bytearray)):
# New BLOB format: 4 bytes per float32
return len(raw) // 4
emb = json.loads(raw)
return len(emb) if isinstance(emb, list) else None return len(emb) if isinstance(emb, list) else None
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError, Exception):
return None return None

View File

@@ -13,7 +13,7 @@ from datetime import datetime, timedelta
from agent.memory.config import MemoryConfig, get_default_memory_config from agent.memory.config import MemoryConfig, get_default_memory_config
from agent.memory.storage import MemoryStorage, MemoryChunk, SearchResult from agent.memory.storage import MemoryStorage, MemoryChunk, SearchResult
from agent.memory.chunker import TextChunker from agent.memory.chunker import TextChunker
from agent.memory.embedding import EmbeddingProvider from agent.memory.embedding import EmbeddingProvider, EmbeddingCache
from agent.memory.summarizer import MemoryFlushManager, create_memory_files_if_needed from agent.memory.summarizer import MemoryFlushManager, create_memory_files_if_needed
@@ -61,7 +61,11 @@ class MemoryManager:
logger.info( logger.info(
"[MemoryManager] No embedding provider; memory will use keyword search only" "[MemoryManager] No embedding provider; memory will use keyword search only"
) )
# Cache for query embeddings (avoids redundant API calls within a session)
self._embedding_cache = EmbeddingCache()
# Initialize memory flush manager # Initialize memory flush manager
workspace_dir = self.config.get_workspace() workspace_dir = self.config.get_workspace()
self.flush_manager = MemoryFlushManager( self.flush_manager = MemoryFlushManager(
@@ -128,7 +132,14 @@ class MemoryManager:
vector_results = [] vector_results = []
if self.embedding_provider: if self.embedding_provider:
try: try:
query_embedding = self.embedding_provider.embed_query(query) provider_name = type(self.embedding_provider).__name__
model_name = getattr(self.embedding_provider, 'model', '')
cached = self._embedding_cache.get(query, provider_name, model_name)
if cached is not None:
query_embedding = cached
else:
query_embedding = self.embedding_provider.embed_query(query)
self._embedding_cache.put(query, provider_name, model_name, query_embedding)
vector_results = self.storage.search_vector( vector_results = self.storage.search_vector(
query_embedding=query_embedding, query_embedding=query_embedding,
user_id=user_id, user_id=user_id,

View File

@@ -5,12 +5,42 @@ Provides vector and keyword search capabilities
""" """
from __future__ import annotations from __future__ import annotations
import re
import sqlite3 import sqlite3
import json import json
import hashlib import hashlib
import threading
from typing import List, Dict, Optional, Any from typing import List, Dict, Optional, Any
from pathlib import Path from pathlib import Path
from dataclasses import dataclass from dataclasses import dataclass
try:
import numpy as np
_HAS_NUMPY = True
except ImportError:
_HAS_NUMPY = False
np = None # type: ignore[assignment]
# UPSERT (INSERT … ON CONFLICT DO UPDATE) requires SQLite ≥ 3.24.0 (2018).
# Older systems (e.g. CentOS 7 ships SQLite 3.7) fall back to INSERT OR REPLACE,
# which risks FTS5 rowid drift on chunk updates (see save_chunk docstring).
_HAS_UPSERT = sqlite3.sqlite_version_info >= (3, 24, 0)
# ---------------------------------------------------------------------------
# CJK character ranges, compiled once at module load.
# Covers: CJK Symbols/Punctuation, Japanese kana (hiragana + katakana),
# CJK Unified Ideographs + Extension A, Korean syllables (Hangul),
# CJK Compatibility Ideographs, and CJK Extensions B–F.
# The fragments below are concatenated into one string that is spliced into
# a regex character class ([...]) by the patterns that follow.
# ---------------------------------------------------------------------------
_CJK_RANGES = (
r'\u3000-\u30ff' # CJK Symbols/Punctuation + Japanese kana
r'\u3400-\u9fff' # CJK Unified Ideographs (incl. Extension A)
r'\uac00-\ud7af' # Korean syllables (Hangul)
r'\uf900-\ufaff' # CJK Compatibility Ideographs
r'\U00020000-\U0002fa1f' # CJK Extensions B–F (supplementary plane)
)
# Matches a single character from any of the ranges above (presence check).
_RE_CONTAINS_CJK = re.compile(f'[{_CJK_RANGES}]')
# Matches maximal runs of one or more characters from the ranges above.
_RE_CJK_WORDS = re.compile(f'[{_CJK_RANGES}]+')
# Matches either a run of characters from the ranges above or an ASCII
# word ([A-Za-z0-9_]+); used to tokenise queries for the trigram index.
_RE_TRIGRAM_TOKENS = re.compile(f'[{_CJK_RANGES}]+|[A-Za-z0-9_]+')
@dataclass @dataclass
@@ -48,6 +78,10 @@ class MemoryStorage:
self.db_path = db_path self.db_path = db_path
self.conn: Optional[sqlite3.Connection] = None self.conn: Optional[sqlite3.Connection] = None
self.fts5_available = False # Track FTS5 availability self.fts5_available = False # Track FTS5 availability
# RLock protects concurrent writes from the same process.
# SQLite WAL mode handles read/write concurrency at the file level,
# but same-process concurrent writes still need a Python-level lock.
self._lock = threading.RLock()
self._init_db() self._init_db()
def _check_fts5_support(self) -> bool: def _check_fts5_support(self) -> bool:
@@ -69,6 +103,14 @@ class MemoryStorage:
# Check FTS5 support # Check FTS5 support
self.fts5_available = self._check_fts5_support() self.fts5_available = self._check_fts5_support()
if not _HAS_UPSERT:
from common.log import logger
logger.warning(
"[MemoryStorage] SQLite %s < 3.24 — UPSERT unavailable. "
"Falling back to INSERT OR REPLACE; FTS5 rowid may drift on "
"chunk updates (rebuild index periodically to recover).",
sqlite3.sqlite_version,
)
if not self.fts5_available: if not self.fts5_available:
from common.log import logger from common.log import logger
logger.debug("[MemoryStorage] FTS5 not available, using LIKE-based keyword search") logger.debug("[MemoryStorage] FTS5 not available, using LIKE-based keyword search")
@@ -175,6 +217,75 @@ class MemoryStorage:
) )
self._rebuild_fts5_from_chunks() self._rebuild_fts5_from_chunks()
# Internal key-value store for persistent flags (e.g. backfill tracking)
self.conn.execute("""
CREATE TABLE IF NOT EXISTS _meta (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
)
""")
# Create trigram FTS5 table for CJK / mixed-language search
self.trigram_fts5_available = False
if self.fts5_available:
try:
self.conn.execute("""
CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts_trigram USING fts5(
text,
id UNINDEXED,
user_id UNINDEXED,
path UNINDEXED,
source UNINDEXED,
scope UNINDEXED,
content='chunks',
content_rowid='rowid',
tokenize='trigram case_sensitive 0'
)
""")
self.conn.execute("""
CREATE TRIGGER IF NOT EXISTS chunks_trigram_ai
AFTER INSERT ON chunks BEGIN
INSERT INTO chunks_fts_trigram(rowid, text, id, user_id, path, source, scope)
VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
END
""")
self.conn.execute("""
CREATE TRIGGER IF NOT EXISTS chunks_trigram_ad
AFTER DELETE ON chunks BEGIN
DELETE FROM chunks_fts_trigram WHERE rowid = old.rowid;
END
""")
self.conn.execute("""
CREATE TRIGGER IF NOT EXISTS chunks_trigram_au
AFTER UPDATE ON chunks BEGIN
UPDATE chunks_fts_trigram
SET text=new.text, id=new.id, user_id=new.user_id,
path=new.path, source=new.source, scope=new.scope
WHERE rowid = new.rowid;
END
""")
# One-time backfill for existing rows.
# NOTE: COUNT(*) on an FTS5 content table always returns 0, so we
# use a persistent flag in _meta instead of counting trigram rows.
backfill_done = self.conn.execute(
"SELECT 1 FROM _meta WHERE key = 'trigram_backfill_done'"
).fetchone()
chunks_count = self.conn.execute(
"SELECT COUNT(*) as c FROM chunks"
).fetchone()['c']
if chunks_count > 0 and not backfill_done:
self.conn.execute(
"INSERT INTO chunks_fts_trigram(chunks_fts_trigram) VALUES('rebuild')"
)
self.conn.execute(
"INSERT OR REPLACE INTO _meta(key, value) VALUES('trigram_backfill_done', '1')"
)
self.trigram_fts5_available = True
except Exception:
from common.log import logger
logger.warning("[MemoryStorage] trigram FTS5 unavailable, CJK search will use LIKE fallback", exc_info=True)
self.trigram_fts5_available = False
# Create files metadata table # Create files metadata table
self.conn.execute(""" self.conn.execute("""
CREATE TABLE IF NOT EXISTS files ( CREATE TABLE IF NOT EXISTS files (
@@ -186,7 +297,7 @@ class MemoryStorage:
updated_at INTEGER DEFAULT (strftime('%s', 'now')) updated_at INTEGER DEFAULT (strftime('%s', 'now'))
) )
""") """)
self.conn.commit() self.conn.commit()
def _fts5_state_inconsistent(self) -> bool: def _fts5_state_inconsistent(self) -> bool:
@@ -299,43 +410,98 @@ class MemoryStorage:
self.conn.commit() self.conn.commit()
def save_chunk(self, chunk: MemoryChunk): def save_chunk(self, chunk: MemoryChunk):
"""Save a memory chunk""" """Save a memory chunk (insert or update by id).
self.conn.execute("""
INSERT OR REPLACE INTO chunks Uses SQLite UPSERT (INSERT … ON CONFLICT DO UPDATE) instead of
(id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at) INSERT OR REPLACE. INSERT OR REPLACE internally does DELETE+INSERT,
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now')) which changes the row's rowid. Because both FTS5 tables use
""", ( content_rowid='rowid', a new rowid would leave the old FTS index
chunk.id, entries pointing at a non-existent rowid and trigger
chunk.user_id, "fts5: missing row N from content table" errors.
chunk.scope, ON CONFLICT DO UPDATE fires the AFTER UPDATE trigger (chunks_au /
chunk.source, chunks_trigram_au) and keeps the original rowid intact.
chunk.path, """
chunk.start_line, if _HAS_UPSERT:
chunk.end_line, _SQL = """
chunk.text, INSERT INTO chunks
json.dumps(chunk.embedding) if chunk.embedding else None, (id, user_id, scope, source, path, start_line, end_line,
text, embedding, hash, metadata, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
ON CONFLICT(id) DO UPDATE SET
user_id = excluded.user_id,
scope = excluded.scope,
source = excluded.source,
path = excluded.path,
start_line = excluded.start_line,
end_line = excluded.end_line,
text = excluded.text,
embedding = excluded.embedding,
hash = excluded.hash,
metadata = excluded.metadata,
updated_at = strftime('%s', 'now')
"""
else:
_SQL = """
INSERT OR REPLACE INTO chunks
(id, user_id, scope, source, path, start_line, end_line,
text, embedding, hash, metadata, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
"""
params = (
chunk.id, chunk.user_id, chunk.scope, chunk.source, chunk.path,
chunk.start_line, chunk.end_line, chunk.text,
self._encode_embedding(chunk.embedding),
chunk.hash, chunk.hash,
json.dumps(chunk.metadata) if chunk.metadata else None json.dumps(chunk.metadata) if chunk.metadata else None,
)) )
self.conn.commit() with self._lock:
self.conn.execute(_SQL, params)
self.conn.commit()
def save_chunks_batch(self, chunks: List[MemoryChunk]): def save_chunks_batch(self, chunks: List[MemoryChunk]):
"""Save multiple chunks in a batch""" """Save multiple chunks in a batch (insert or update by id).
self.conn.executemany("""
INSERT OR REPLACE INTO chunks See save_chunk for why UPSERT is used instead of INSERT OR REPLACE.
(id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at) """
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now')) if _HAS_UPSERT:
""", [ _SQL = """
INSERT INTO chunks
(id, user_id, scope, source, path, start_line, end_line,
text, embedding, hash, metadata, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
ON CONFLICT(id) DO UPDATE SET
user_id = excluded.user_id,
scope = excluded.scope,
source = excluded.source,
path = excluded.path,
start_line = excluded.start_line,
end_line = excluded.end_line,
text = excluded.text,
embedding = excluded.embedding,
hash = excluded.hash,
metadata = excluded.metadata,
updated_at = strftime('%s', 'now')
"""
else:
_SQL = """
INSERT OR REPLACE INTO chunks
(id, user_id, scope, source, path, start_line, end_line,
text, embedding, hash, metadata, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
"""
params_list = [
( (
c.id, c.user_id, c.scope, c.source, c.path, c.id, c.user_id, c.scope, c.source, c.path,
c.start_line, c.end_line, c.text, c.start_line, c.end_line, c.text,
json.dumps(c.embedding) if c.embedding else None, self._encode_embedding(c.embedding),
c.hash, c.hash,
json.dumps(c.metadata) if c.metadata else None json.dumps(c.metadata) if c.metadata else None,
) )
for c in chunks for c in chunks
]) ]
self.conn.commit() with self._lock:
self.conn.executemany(_SQL, params_list)
self.conn.commit()
def get_chunk(self, chunk_id: str) -> Optional[MemoryChunk]: def get_chunk(self, chunk_id: str) -> Optional[MemoryChunk]:
"""Get a chunk by ID""" """Get a chunk by ID"""
@@ -356,21 +522,21 @@ class MemoryStorage:
limit: int = 10 limit: int = 10
) -> List[SearchResult]: ) -> List[SearchResult]:
""" """
Vector similarity search using in-memory cosine similarity Vector similarity search using numpy-vectorized cosine similarity.
(sqlite-vec can be added later for better performance) All embeddings are loaded then scored in a single BLAS matrix-vector
multiply, which is ~100x faster than the pure-Python per-row loop.
""" """
if scopes is None: if scopes is None:
scopes = ["shared"] scopes = ["shared"]
if user_id: if user_id:
scopes.append("user") scopes.append("user")
# Build query
scope_placeholders = ','.join('?' * len(scopes)) scope_placeholders = ','.join('?' * len(scopes))
params = scopes params = list(scopes)
if user_id: if user_id:
query = f""" query = f"""
SELECT * FROM chunks SELECT * FROM chunks
WHERE scope IN ({scope_placeholders}) WHERE scope IN ({scope_placeholders})
AND (scope = 'shared' OR user_id = ?) AND (scope = 'shared' OR user_id = ?)
AND embedding IS NOT NULL AND embedding IS NOT NULL
@@ -378,51 +544,95 @@ class MemoryStorage:
params.append(user_id) params.append(user_id)
else: else:
query = f""" query = f"""
SELECT * FROM chunks SELECT * FROM chunks
WHERE scope IN ({scope_placeholders}) WHERE scope IN ({scope_placeholders})
AND embedding IS NOT NULL AND embedding IS NOT NULL
""" """
rows = self.conn.execute(query, params).fetchall() rows = self.conn.execute(query, params).fetchall()
if not rows:
return []
# Calculate cosine similarity. We probe the first row's dim to fail # Parse embeddings and build a (N, D) matrix in one pass.
# loudly on a query/index dim mismatch — otherwise every doc would # New rows store BLOB bytes (np.frombuffer); legacy rows fall back to JSON.
# score 0 silently, leaving the user wondering why search broke. # Filter out rows whose embedding dimension differs from the query —
results = [] # mixing dimensions would cause np.array() to produce an object array
query_dim = len(query_embedding) # and matrix @ q_vec to raise ValueError.
if rows: expected_dim = len(query_embedding)
first = json.loads(rows[0]['embedding']) valid_rows = []
if isinstance(first, list) and len(first) != query_dim: vectors = []
raise ValueError(
f"Embedding dim mismatch: query is {query_dim}-dim but "
f"index stores {len(first)}-dim vectors. The configured "
f"embedding model differs from the one that built the "
f"index — run /memory rebuild-index to re-embed."
)
for row in rows: for row in rows:
embedding = json.loads(row['embedding']) vec = self._decode_embedding(row['embedding'])
similarity = self._cosine_similarity(query_embedding, embedding) if not vec:
continue
if len(vec) != expected_dim:
from common.log import logger
logger.warning(
"[MemoryStorage] Skipping chunk %s: embedding dim %d != query dim %d",
row['id'], len(vec), expected_dim
)
continue
valid_rows.append(row)
vectors.append(vec)
if similarity > 0: if not vectors:
results.append((similarity, row)) return []
# Sort by similarity and limit if _HAS_NUMPY:
results.sort(key=lambda x: x[0], reverse=True) matrix = np.array(vectors, dtype=np.float32) # (N, D)
results = results[:limit] q_vec = np.array(query_embedding, dtype=np.float32) # (D,)
return [ # Vectorized cosine similarity: dot(matrix, q) / (||matrix|| * ||q||)
SearchResult( dots = matrix @ q_vec # (N,)
path=row['path'], row_norms = np.linalg.norm(matrix, axis=1) # (N,)
start_line=row['start_line'], q_norm = float(np.linalg.norm(q_vec))
end_line=row['end_line'], denominators = row_norms * q_norm
score=score, np.maximum(denominators, 1e-10, out=denominators) # avoid div-by-zero
snippet=self._truncate_text(row['text'], 500), sims = dots / denominators # (N,)
source=row['source'],
user_id=row['user_id'] # Select TopK using argpartition (O(N) average), then sort only those K
) k = min(limit, len(valid_rows))
for score, row in results top_idx = np.argpartition(sims, -k)[-k:]
] top_idx = top_idx[np.argsort(sims[top_idx])[::-1]]
return [
SearchResult(
path=valid_rows[i]['path'],
start_line=valid_rows[i]['start_line'],
end_line=valid_rows[i]['end_line'],
score=float(sims[i]),
snippet=self._truncate_text(valid_rows[i]['text'], 500),
source=valid_rows[i]['source'],
user_id=valid_rows[i]['user_id']
)
for i in top_idx
if sims[i] > 0
]
else:
# Pure-Python cosine similarity fallback (numpy not installed)
import math
q = query_embedding
q_norm = math.sqrt(sum(x * x for x in q)) or 1e-10
scored = []
for i, vec in enumerate(vectors):
dot = sum(a * b for a, b in zip(vec, q))
v_norm = math.sqrt(sum(x * x for x in vec)) or 1e-10
sim = dot / (v_norm * q_norm)
if sim > 0:
scored.append((sim, valid_rows[i]))
scored.sort(key=lambda x: x[0], reverse=True)
return [
SearchResult(
path=row['path'],
start_line=row['start_line'],
end_line=row['end_line'],
score=sim,
snippet=self._truncate_text(row['text'], 500),
source=row['source'],
user_id=row['user_id']
)
for sim, row in scored[:limit]
]
def search_keyword( def search_keyword(
self, self,
@@ -445,12 +655,37 @@ class MemoryStorage:
if user_id: if user_id:
scopes.append("user") scopes.append("user")
if self.fts5_available: # Step 1: Standard FTS5 (unicode61) — pure ASCII queries only.
# Skipped when query contains any CJK characters: unicode61 tokenises CJK
# as individual characters without forming meaningful tokens, so it would
# match only the ASCII portion of a mixed query (e.g. "Python" from
# "Python教程") and silently discard the CJK part. Those queries go
# directly to Step 2 (trigram), which handles both ASCII and CJK together.
fts1_attempted = False
if (self.fts5_available
and not MemoryStorage._contains_cjk(query)
and MemoryStorage._build_fts_query(query)):
fts1_attempted = True
fts_results = self._search_fts5(query, user_id, scopes, limit) fts_results = self._search_fts5(query, user_id, scopes, limit)
if fts_results: if fts_results:
return fts_results return fts_results
return self._search_like(query, user_id, scopes, limit) # Step 2: Trigram FTS5 — CJK/mixed queries, plus fallback when unicode61
# returned nothing (trigram indexes all scripts with 3-char sliding windows,
# so it can catch terms that unicode61 tokenisation misses).
if self.trigram_fts5_available and (
MemoryStorage._contains_cjk(query) or fts1_attempted
):
trigram_results = self._search_fts5_trigram(query, user_id, scopes, limit)
if trigram_results:
return trigram_results
# Step 3: LIKE fallback — last resort (FTS5 unavailable, or CJK tokens
# shorter than 3 characters that trigram cannot match, e.g. a single-char query).
if not self.fts5_available or MemoryStorage._contains_cjk(query):
return self._search_like(query, user_id, scopes, limit)
return []
def _search_fts5( def _search_fts5(
self, self,
@@ -471,7 +706,7 @@ class MemoryStorage:
sql_query = f""" sql_query = f"""
SELECT chunks.*, bm25(chunks_fts) as rank SELECT chunks.*, bm25(chunks_fts) as rank
FROM chunks_fts FROM chunks_fts
JOIN chunks ON chunks.id = chunks_fts.id JOIN chunks ON chunks.rowid = chunks_fts.rowid
WHERE chunks_fts MATCH ? WHERE chunks_fts MATCH ?
AND chunks.scope IN ({scope_placeholders}) AND chunks.scope IN ({scope_placeholders})
AND (chunks.scope = 'shared' OR chunks.user_id = ?) AND (chunks.scope = 'shared' OR chunks.user_id = ?)
@@ -483,7 +718,7 @@ class MemoryStorage:
sql_query = f""" sql_query = f"""
SELECT chunks.*, bm25(chunks_fts) as rank SELECT chunks.*, bm25(chunks_fts) as rank
FROM chunks_fts FROM chunks_fts
JOIN chunks ON chunks.id = chunks_fts.id JOIN chunks ON chunks.rowid = chunks_fts.rowid
WHERE chunks_fts MATCH ? WHERE chunks_fts MATCH ?
AND chunks.scope IN ({scope_placeholders}) AND chunks.scope IN ({scope_placeholders})
ORDER BY rank ORDER BY rank
@@ -505,13 +740,11 @@ class MemoryStorage:
) )
for row in rows for row in rows
] ]
except Exception as e: except Exception:
from common.log import logger from common.log import logger
logger.error( logger.warning("[MemoryStorage] _search_fts5 failed, returning empty", exc_info=True)
f"[MemoryStorage] FTS5 search failed (caller will fall back to LIKE): {e}"
)
return [] return []
def _search_like( def _search_like(
self, self,
query: str, query: str,
@@ -522,12 +755,11 @@ class MemoryStorage:
"""LIKE-based search. """LIKE-based search.
Used as the keyword-search fallback when FTS5 is unavailable, fails, Used as the keyword-search fallback when FTS5 is unavailable, fails,
or returns empty. Supports both CJK runs and ASCII word tokens so it or returns empty. Supports both CJK runs (1+ chars) and ASCII word
can serve as a true safety net for any query. tokens (3+ chars) so it can serve as a true safety net for any query.
""" """
import re # CJK runs (1+ chars, wide Unicode range) + ASCII words (3+ chars to avoid noise)
# CJK runs (2+ chars) + ASCII word tokens (3+ chars to avoid noise) cjk_words = _RE_CJK_WORDS.findall(query)
cjk_words = re.findall(r'[\u4e00-\u9fff]{2,}', query)
ascii_words = [t for t in re.findall(r'[A-Za-z0-9_]+', query) if len(t) >= 3] ascii_words = [t for t in re.findall(r'[A-Za-z0-9_]+', query) if len(t) >= 3]
words = cjk_words + ascii_words words = cjk_words + ascii_words
if not words: if not words:
@@ -565,44 +797,54 @@ class MemoryStorage:
try: try:
rows = self.conn.execute(sql_query, params).fetchall() rows = self.conn.execute(sql_query, params).fetchall()
return [ results = []
SearchResult( for row in rows:
# Dynamic score: reward chunks that contain more of the query words.
# Use all tokens (CJK + ASCII) so pure-ASCII queries are not skipped.
# matched_count is always ≥1 because the WHERE clause uses OR, but
# guard defensively so unexpected zero-match rows are never surfaced.
text_lower = row['text'].lower()
matched_count = sum(1 for w in words if w.lower() in text_lower)
if matched_count == 0:
continue
score = min(0.85, 0.3 + 0.15 * matched_count)
results.append(SearchResult(
path=row['path'], path=row['path'],
start_line=row['start_line'], start_line=row['start_line'],
end_line=row['end_line'], end_line=row['end_line'],
score=0.5, # Fixed score for LIKE search score=score,
snippet=self._truncate_text(row['text'], 500), snippet=self._truncate_text(row['text'], 500),
source=row['source'], source=row['source'],
user_id=row['user_id'] user_id=row['user_id']
) ))
for row in rows results.sort(key=lambda r: r.score, reverse=True)
] return results
except Exception as e: except Exception:
from common.log import logger from common.log import logger
logger.error(f"[MemoryStorage] LIKE search failed: {e}") logger.warning("[MemoryStorage] _search_like failed, returning empty", exc_info=True)
return [] return []
def delete_by_path(self, path: str): def delete_by_path(self, path: str):
"""Delete all chunks from a file""" """Delete all chunks from a file"""
self.conn.execute(""" with self._lock:
DELETE FROM chunks WHERE path = ? self.conn.execute("DELETE FROM chunks WHERE path = ?", (path,))
""", (path,)) self.conn.commit()
self.conn.commit()
def get_file_hash(self, path: str) -> Optional[str]: def get_file_hash(self, path: str) -> Optional[str]:
"""Get stored file hash""" """Get stored file hash"""
row = self.conn.execute(""" row = self.conn.execute("""
SELECT hash FROM files WHERE path = ? SELECT hash FROM files WHERE path = ?
""", (path,)).fetchone() """, (path,)).fetchone()
return row['hash'] if row else None return row['hash'] if row else None
def update_file_metadata(self, path: str, source: str, file_hash: str, mtime: int, size: int): def update_file_metadata(self, path: str, source: str, file_hash: str, mtime: int, size: int):
"""Update file metadata""" """Update file metadata"""
self.conn.execute(""" with self._lock:
INSERT OR REPLACE INTO files (path, source, hash, mtime, size, updated_at) self.conn.execute("""
VALUES (?, ?, ?, ?, ?, strftime('%s', 'now')) INSERT OR REPLACE INTO files (path, source, hash, mtime, size, updated_at)
""", (path, source, file_hash, mtime, size)) VALUES (?, ?, ?, ?, ?, strftime('%s', 'now'))
self.conn.commit() """, (path, source, file_hash, mtime, size))
self.conn.commit()
def get_stats(self) -> Dict[str, int]: def get_stats(self) -> Dict[str, int]:
"""Get storage statistics""" """Get storage statistics"""
@@ -632,7 +874,8 @@ class MemoryStorage:
self.conn.close() self.conn.close()
self.conn = None # Mark as closed self.conn = None # Mark as closed
except Exception as e: except Exception as e:
print(f"⚠️ Error closing database connection: {e}") from common.log import logger
logger.warning("[MemoryStorage] Error closing database connection: %s", e)
def __del__(self): def __del__(self):
"""Destructor to ensure connection is closed""" """Destructor to ensure connection is closed"""
@@ -642,7 +885,33 @@ class MemoryStorage:
pass # Ignore errors during cleanup pass # Ignore errors during cleanup
# Helper methods # Helper methods
@staticmethod
def _encode_embedding(embedding: Optional[List[float]]) -> Optional[bytes]:
"""Encode embedding as float32 BLOB bytes (~6x smaller and faster than JSON).
Falls back to struct.pack when numpy is unavailable."""
if embedding is None:
return None
if _HAS_NUMPY:
return np.array(embedding, dtype=np.float32).tobytes()
import struct
return struct.pack(f'{len(embedding)}f', *embedding)
@staticmethod
def _decode_embedding(raw) -> Optional[List[float]]:
"""Decode embedding from BLOB bytes or legacy JSON string.
Handles both numpy and numpy-free environments."""
if raw is None:
return None
if isinstance(raw, (bytes, bytearray)):
if _HAS_NUMPY:
return np.frombuffer(raw, dtype=np.float32).tolist()
import struct
n = len(raw) // 4
return list(struct.unpack(f'{n}f', raw))
# Legacy JSON format written by older versions
return json.loads(raw)
def _row_to_chunk(self, row) -> MemoryChunk: def _row_to_chunk(self, row) -> MemoryChunk:
"""Convert database row to MemoryChunk""" """Convert database row to MemoryChunk"""
return MemoryChunk( return MemoryChunk(
@@ -654,32 +923,89 @@ class MemoryStorage:
start_line=row['start_line'], start_line=row['start_line'],
end_line=row['end_line'], end_line=row['end_line'],
text=row['text'], text=row['text'],
embedding=json.loads(row['embedding']) if row['embedding'] else None, embedding=self._decode_embedding(row['embedding']),
hash=row['hash'], hash=row['hash'],
metadata=json.loads(row['metadata']) if row['metadata'] else None metadata=json.loads(row['metadata']) if row['metadata'] else None
) )
@staticmethod @staticmethod
def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float: def _contains_cjk(text: str) -> bool:
"""Calculate cosine similarity between two vectors""" """Check if text contains CJK or related characters (Chinese, Japanese, Korean)."""
if len(vec1) != len(vec2): return bool(_RE_CONTAINS_CJK.search(text))
return 0.0
dot_product = sum(a * b for a, b in zip(vec1, vec2))
norm1 = sum(a * a for a in vec1) ** 0.5
norm2 = sum(b * b for b in vec2) ** 0.5
if norm1 == 0 or norm2 == 0:
return 0.0
return dot_product / (norm1 * norm2)
@staticmethod @staticmethod
def _contains_cjk(text: str) -> bool: def _build_trigram_query(raw_query: str) -> Optional[str]:
"""Check if text contains CJK (Chinese/Japanese/Korean) characters""" """
import re Build FTS5 MATCH query for the trigram tokenizer.
return bool(re.search(r'[\u4e00-\u9fff]', text)) Extracts CJK sequences (including single characters) and ASCII words,
joining them with AND so all terms must appear in the matched chunk.
"""
tokens = _RE_TRIGRAM_TOKENS.findall(raw_query)
tokens = [t for t in tokens if t]
if not tokens:
return None
# Escape embedded double-quotes (FTS5 uses "" inside quoted phrases)
quoted = [f'"{t.replace(chr(34), chr(34)*2)}"' for t in tokens]
return ' AND '.join(quoted)
def _search_fts5_trigram(
    self,
    query: str,
    user_id: Optional[str],
    scopes: List[str],
    limit: int
) -> List[SearchResult]:
    """Search the ``chunks_fts_trigram`` index and rank hits with BM25.

    Args:
        query: Raw user query; converted to a MATCH expression by
            ``_build_trigram_query``.
        user_id: When truthy, non-shared rows are restricted to this user.
        scopes: Scope values a row must belong to.
        limit: Maximum number of rows to return.

    Returns:
        Results in ascending ``bm25`` rank order. An empty list is returned
        when no MATCH expression can be built or when the query raises.
    """
    match_expr = self._build_trigram_query(query)
    if not match_expr:
        return []

    scope_placeholders = ','.join('?' * len(scopes))
    # Bind order follows the placeholders: MATCH, scopes, [user_id], LIMIT.
    bind_values = [match_expr, *scopes]

    if user_id:
        sql = f"""
            SELECT chunks.*, bm25(chunks_fts_trigram) as rank
            FROM chunks_fts_trigram
            JOIN chunks ON chunks.rowid = chunks_fts_trigram.rowid
            WHERE chunks_fts_trigram MATCH ?
            AND chunks.scope IN ({scope_placeholders})
            AND (chunks.scope = 'shared' OR chunks.user_id = ?)
            ORDER BY rank
            LIMIT ?
        """
        bind_values.append(user_id)
    else:
        sql = f"""
            SELECT chunks.*, bm25(chunks_fts_trigram) as rank
            FROM chunks_fts_trigram
            JOIN chunks ON chunks.rowid = chunks_fts_trigram.rowid
            WHERE chunks_fts_trigram MATCH ?
            AND chunks.scope IN ({scope_placeholders})
            ORDER BY rank
            LIMIT ?
        """
    bind_values.append(limit)

    try:
        hits = []
        for row in self.conn.execute(sql, bind_values).fetchall():
            hits.append(
                SearchResult(
                    path=row['path'],
                    start_line=row['start_line'],
                    end_line=row['end_line'],
                    score=self._bm25_rank_to_score(row['rank']),
                    snippet=self._truncate_text(row['text'], 500),
                    source=row['source'],
                    user_id=row['user_id']
                )
            )
        return hits
    except Exception:
        # Best-effort search: log with traceback and return no results.
        from common.log import logger
        logger.warning("[MemoryStorage] _search_fts5_trigram failed, returning empty", exc_info=True)
        return []
@staticmethod @staticmethod
def _build_fts_query(raw_query: str) -> Optional[str]: def _build_fts_query(raw_query: str) -> Optional[str]:
""" """
@@ -688,7 +1014,6 @@ class MemoryStorage:
Works best for English and word-based languages. Works best for English and word-based languages.
For CJK characters, LIKE search will be used as fallback. For CJK characters, LIKE search will be used as fallback.
""" """
import re
# Extract words (primarily English words and numbers) # Extract words (primarily English words and numbers)
tokens = re.findall(r'[A-Za-z0-9_]+', raw_query) tokens = re.findall(r'[A-Za-z0-9_]+', raw_query)
if not tokens: if not tokens:
@@ -701,9 +1026,22 @@ class MemoryStorage:
@staticmethod @staticmethod
def _bm25_rank_to_score(rank: float) -> float: def _bm25_rank_to_score(rank: float) -> float:
"""Convert BM25 rank to 0-1 score""" """Convert SQLite BM25 rank to a [0, 1) relevance score.
normalized = max(0, rank) if rank is not None else 999
return 1 / (1 + normalized) SQLite's bm25() returns a non-positive float (0 or negative).
More negative = more relevant. max(0, rank) would clip every
negative value to 0, making every score 1/(1+0) = 1.0 and
destroying all ranking information.
abs(rank) / (1 + abs(rank)) maps the absolute relevance magnitude
to [0, 1): larger |rank| (stronger match) → score closer to 1.
"""
if rank is None:
return 0.0
# Add a floor of 0.3 so any FTS5 match always exceeds typical
# min_score thresholds (default 0.1). Small-corpus ranks close to
# 0 would otherwise produce score≈0 and be filtered out downstream.
return 0.3 + 0.69 * (abs(rank) / (1.0 + abs(rank)))
@staticmethod @staticmethod
def _truncate_text(text: str, max_chars: int) -> str: def _truncate_text(text: str, max_chars: int) -> str:

View File

@@ -1,3 +1,4 @@
numpy>=1.24
aiohttp>=3.8.6,<3.10 aiohttp>=3.8.6,<3.10
requests>=2.28.2 requests>=2.28.2
chardet>=5.1.0 chardet>=5.1.0