mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
fix(memory): address PR review — numpy/UPSERT soft deps + BM25 floor + BLOB dim
- numpy soft dependency: try/except import + _HAS_NUMPY flag; _encode_embedding and _decode_embedding fall back to struct.pack/unpack; search_vector falls back to pure-Python cosine loop — startup never fails without numpy reinstalled - SQLite UPSERT guard: _HAS_UPSERT = sqlite_version_info >= (3,24,0); save_chunk and save_chunks_batch fall back to INSERT OR REPLACE on SQLite < 3.24 with a one-time startup warning about potential FTS rowid drift - _bm25_rank_to_score floor: 0.3 + 0.69*(|rank|/(1+|rank|)) → always in [0.3, 0.99), prevents small-corpus matches scoring 0.0 and being filtered by min_score - detect_index_dim BLOB-aware: check isinstance(raw, bytes) first and return len(raw)//4 before json.loads, so /memory status works after embedding format switch - Comment: "CJK single-char" → "CJK tokens shorter than 3 characters" Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -31,9 +31,13 @@ def detect_index_dim(storage) -> Optional[int]:
|
||||
if not row or not row["embedding"]:
|
||||
return None
|
||||
try:
|
||||
emb = json.loads(row["embedding"])
|
||||
raw = row["embedding"]
|
||||
if isinstance(raw, (bytes, bytearray)):
|
||||
# New BLOB format: 4 bytes per float32
|
||||
return len(raw) // 4
|
||||
emb = json.loads(raw)
|
||||
return len(emb) if isinstance(emb, list) else None
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
except (json.JSONDecodeError, TypeError, Exception):
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,17 @@ import threading
|
||||
from typing import List, Dict, Optional, Any
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
try:
|
||||
import numpy as np
|
||||
_HAS_NUMPY = True
|
||||
except ImportError:
|
||||
_HAS_NUMPY = False
|
||||
np = None # type: ignore[assignment]
|
||||
|
||||
# UPSERT (INSERT … ON CONFLICT DO UPDATE) requires SQLite ≥ 3.24.0 (2018).
|
||||
# Older systems (e.g. CentOS 7 ships SQLite 3.7) fall back to INSERT OR REPLACE,
|
||||
# which risks FTS5 rowid drift on chunk updates (see save_chunk docstring).
|
||||
_HAS_UPSERT = sqlite3.sqlite_version_info >= (3, 24, 0)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CJK character ranges, compiled once at module load.
|
||||
@@ -93,6 +103,14 @@ class MemoryStorage:
|
||||
|
||||
# Check FTS5 support
|
||||
self.fts5_available = self._check_fts5_support()
|
||||
if not _HAS_UPSERT:
|
||||
from common.log import logger
|
||||
logger.warning(
|
||||
"[MemoryStorage] SQLite %s < 3.24 — UPSERT unavailable. "
|
||||
"Falling back to INSERT OR REPLACE; FTS5 rowid may drift on "
|
||||
"chunk updates (rebuild index periodically to recover).",
|
||||
sqlite3.sqlite_version,
|
||||
)
|
||||
if not self.fts5_available:
|
||||
from common.log import logger
|
||||
logger.debug("[MemoryStorage] FTS5 not available, using LIKE-based keyword search")
|
||||
@@ -403,24 +421,32 @@ class MemoryStorage:
|
||||
ON CONFLICT DO UPDATE fires the AFTER UPDATE trigger (chunks_au /
|
||||
chunks_trigram_au) and keeps the original rowid intact.
|
||||
"""
|
||||
_SQL = """
|
||||
INSERT INTO chunks
|
||||
(id, user_id, scope, source, path, start_line, end_line,
|
||||
text, embedding, hash, metadata, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
user_id = excluded.user_id,
|
||||
scope = excluded.scope,
|
||||
source = excluded.source,
|
||||
path = excluded.path,
|
||||
start_line = excluded.start_line,
|
||||
end_line = excluded.end_line,
|
||||
text = excluded.text,
|
||||
embedding = excluded.embedding,
|
||||
hash = excluded.hash,
|
||||
metadata = excluded.metadata,
|
||||
updated_at = strftime('%s', 'now')
|
||||
"""
|
||||
if _HAS_UPSERT:
|
||||
_SQL = """
|
||||
INSERT INTO chunks
|
||||
(id, user_id, scope, source, path, start_line, end_line,
|
||||
text, embedding, hash, metadata, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
user_id = excluded.user_id,
|
||||
scope = excluded.scope,
|
||||
source = excluded.source,
|
||||
path = excluded.path,
|
||||
start_line = excluded.start_line,
|
||||
end_line = excluded.end_line,
|
||||
text = excluded.text,
|
||||
embedding = excluded.embedding,
|
||||
hash = excluded.hash,
|
||||
metadata = excluded.metadata,
|
||||
updated_at = strftime('%s', 'now')
|
||||
"""
|
||||
else:
|
||||
_SQL = """
|
||||
INSERT OR REPLACE INTO chunks
|
||||
(id, user_id, scope, source, path, start_line, end_line,
|
||||
text, embedding, hash, metadata, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
|
||||
"""
|
||||
params = (
|
||||
chunk.id, chunk.user_id, chunk.scope, chunk.source, chunk.path,
|
||||
chunk.start_line, chunk.end_line, chunk.text,
|
||||
@@ -437,24 +463,32 @@ class MemoryStorage:
|
||||
|
||||
See save_chunk for why UPSERT is used instead of INSERT OR REPLACE.
|
||||
"""
|
||||
_SQL = """
|
||||
INSERT INTO chunks
|
||||
(id, user_id, scope, source, path, start_line, end_line,
|
||||
text, embedding, hash, metadata, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
user_id = excluded.user_id,
|
||||
scope = excluded.scope,
|
||||
source = excluded.source,
|
||||
path = excluded.path,
|
||||
start_line = excluded.start_line,
|
||||
end_line = excluded.end_line,
|
||||
text = excluded.text,
|
||||
embedding = excluded.embedding,
|
||||
hash = excluded.hash,
|
||||
metadata = excluded.metadata,
|
||||
updated_at = strftime('%s', 'now')
|
||||
"""
|
||||
if _HAS_UPSERT:
|
||||
_SQL = """
|
||||
INSERT INTO chunks
|
||||
(id, user_id, scope, source, path, start_line, end_line,
|
||||
text, embedding, hash, metadata, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
user_id = excluded.user_id,
|
||||
scope = excluded.scope,
|
||||
source = excluded.source,
|
||||
path = excluded.path,
|
||||
start_line = excluded.start_line,
|
||||
end_line = excluded.end_line,
|
||||
text = excluded.text,
|
||||
embedding = excluded.embedding,
|
||||
hash = excluded.hash,
|
||||
metadata = excluded.metadata,
|
||||
updated_at = strftime('%s', 'now')
|
||||
"""
|
||||
else:
|
||||
_SQL = """
|
||||
INSERT OR REPLACE INTO chunks
|
||||
(id, user_id, scope, source, path, start_line, end_line,
|
||||
text, embedding, hash, metadata, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
|
||||
"""
|
||||
params_list = [
|
||||
(
|
||||
c.id, c.user_id, c.scope, c.source, c.path,
|
||||
@@ -544,36 +578,61 @@ class MemoryStorage:
|
||||
if not vectors:
|
||||
return []
|
||||
|
||||
matrix = np.array(vectors, dtype=np.float32) # (N, D)
|
||||
q_vec = np.array(query_embedding, dtype=np.float32) # (D,)
|
||||
if _HAS_NUMPY:
|
||||
matrix = np.array(vectors, dtype=np.float32) # (N, D)
|
||||
q_vec = np.array(query_embedding, dtype=np.float32) # (D,)
|
||||
|
||||
# Vectorized cosine similarity: dot(matrix, q) / (||matrix|| * ||q||)
|
||||
dots = matrix @ q_vec # (N,)
|
||||
row_norms = np.linalg.norm(matrix, axis=1) # (N,)
|
||||
q_norm = float(np.linalg.norm(q_vec))
|
||||
denominators = row_norms * q_norm
|
||||
np.maximum(denominators, 1e-10, out=denominators) # avoid div-by-zero
|
||||
sims = dots / denominators # (N,)
|
||||
# Vectorized cosine similarity: dot(matrix, q) / (||matrix|| * ||q||)
|
||||
dots = matrix @ q_vec # (N,)
|
||||
row_norms = np.linalg.norm(matrix, axis=1) # (N,)
|
||||
q_norm = float(np.linalg.norm(q_vec))
|
||||
denominators = row_norms * q_norm
|
||||
np.maximum(denominators, 1e-10, out=denominators) # avoid div-by-zero
|
||||
sims = dots / denominators # (N,)
|
||||
|
||||
# Select TopK using argpartition (O(N) average), then sort only those K
|
||||
k = min(limit, len(valid_rows))
|
||||
top_idx = np.argpartition(sims, -k)[-k:]
|
||||
top_idx = top_idx[np.argsort(sims[top_idx])[::-1]]
|
||||
# Select TopK using argpartition (O(N) average), then sort only those K
|
||||
k = min(limit, len(valid_rows))
|
||||
top_idx = np.argpartition(sims, -k)[-k:]
|
||||
top_idx = top_idx[np.argsort(sims[top_idx])[::-1]]
|
||||
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
path=valid_rows[i]['path'],
|
||||
start_line=valid_rows[i]['start_line'],
|
||||
end_line=valid_rows[i]['end_line'],
|
||||
score=float(sims[i]),
|
||||
snippet=self._truncate_text(valid_rows[i]['text'], 500),
|
||||
source=valid_rows[i]['source'],
|
||||
user_id=valid_rows[i]['user_id']
|
||||
)
|
||||
for i in top_idx
|
||||
if sims[i] > 0
|
||||
]
|
||||
return [
|
||||
SearchResult(
|
||||
path=valid_rows[i]['path'],
|
||||
start_line=valid_rows[i]['start_line'],
|
||||
end_line=valid_rows[i]['end_line'],
|
||||
score=float(sims[i]),
|
||||
snippet=self._truncate_text(valid_rows[i]['text'], 500),
|
||||
source=valid_rows[i]['source'],
|
||||
user_id=valid_rows[i]['user_id']
|
||||
)
|
||||
for i in top_idx
|
||||
if sims[i] > 0
|
||||
]
|
||||
else:
|
||||
# Pure-Python cosine similarity fallback (numpy not installed)
|
||||
import math
|
||||
q = query_embedding
|
||||
q_norm = math.sqrt(sum(x * x for x in q)) or 1e-10
|
||||
scored = []
|
||||
for i, vec in enumerate(vectors):
|
||||
dot = sum(a * b for a, b in zip(vec, q))
|
||||
v_norm = math.sqrt(sum(x * x for x in vec)) or 1e-10
|
||||
sim = dot / (v_norm * q_norm)
|
||||
if sim > 0:
|
||||
scored.append((sim, valid_rows[i]))
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
return [
|
||||
SearchResult(
|
||||
path=row['path'],
|
||||
start_line=row['start_line'],
|
||||
end_line=row['end_line'],
|
||||
score=sim,
|
||||
snippet=self._truncate_text(row['text'], 500),
|
||||
source=row['source'],
|
||||
user_id=row['user_id']
|
||||
)
|
||||
for sim, row in scored[:limit]
|
||||
]
|
||||
|
||||
def search_keyword(
|
||||
self,
|
||||
@@ -621,8 +680,8 @@ class MemoryStorage:
|
||||
if trigram_results:
|
||||
return trigram_results
|
||||
|
||||
# Step 3: LIKE fallback — last resort (FTS5 unavailable, or CJK single-char
|
||||
# that trigram cannot match because it requires ≥3-char tokens).
|
||||
# Step 3: LIKE fallback — last resort (FTS5 unavailable, or CJK tokens
|
||||
# shorter than 3 characters that trigram cannot match, e.g. a single-char query).
|
||||
if not self.fts5_available or MemoryStorage._contains_cjk(query):
|
||||
return self._search_like(query, user_id, scopes, limit)
|
||||
|
||||
@@ -829,18 +888,27 @@ class MemoryStorage:
|
||||
|
||||
@staticmethod
|
||||
def _encode_embedding(embedding: Optional[List[float]]) -> Optional[bytes]:
|
||||
"""Encode embedding as float32 BLOB bytes (~6x smaller and faster than JSON)."""
|
||||
"""Encode embedding as float32 BLOB bytes (~6x smaller and faster than JSON).
|
||||
Falls back to struct.pack when numpy is unavailable."""
|
||||
if embedding is None:
|
||||
return None
|
||||
return np.array(embedding, dtype=np.float32).tobytes()
|
||||
if _HAS_NUMPY:
|
||||
return np.array(embedding, dtype=np.float32).tobytes()
|
||||
import struct
|
||||
return struct.pack(f'{len(embedding)}f', *embedding)
|
||||
|
||||
@staticmethod
|
||||
def _decode_embedding(raw) -> Optional[List[float]]:
|
||||
"""Decode embedding from BLOB bytes or legacy JSON string."""
|
||||
"""Decode embedding from BLOB bytes or legacy JSON string.
|
||||
Handles both numpy and numpy-free environments."""
|
||||
if raw is None:
|
||||
return None
|
||||
if isinstance(raw, (bytes, bytearray)):
|
||||
return np.frombuffer(raw, dtype=np.float32).tolist()
|
||||
if _HAS_NUMPY:
|
||||
return np.frombuffer(raw, dtype=np.float32).tolist()
|
||||
import struct
|
||||
n = len(raw) // 4
|
||||
return list(struct.unpack(f'{n}f', raw))
|
||||
# Legacy JSON format written by older versions
|
||||
return json.loads(raw)
|
||||
|
||||
@@ -970,7 +1038,10 @@ class MemoryStorage:
|
||||
"""
|
||||
if rank is None:
|
||||
return 0.0
|
||||
return abs(rank) / (1.0 + abs(rank))
|
||||
# Add a floor of 0.3 so any FTS5 match always exceeds typical
|
||||
# min_score thresholds (default 0.1). Small-corpus ranks close to
|
||||
# 0 would otherwise produce score≈0 and be filtered out downstream.
|
||||
return 0.3 + 0.69 * (abs(rank) / (1.0 + abs(rank)))
|
||||
|
||||
@staticmethod
|
||||
def _truncate_text(text: str, max_chars: int) -> str:
|
||||
|
||||
Reference in New Issue
Block a user