Memory storage: run without numpy and on SQLite < 3.24.

  * state.py: detect_index_dim() reads the dimension of BLOB embeddings
    (4 bytes per float32) in addition to legacy JSON embeddings.
  * storage.py: numpy becomes optional (struct / pure-Python fallbacks),
    chunk writes fall back to INSERT OR REPLACE when UPSERT is unavailable,
    and the rank-to-score mapping gets a 0.3 floor.

Review changes relative to the submitted patch:

  * detect_index_dim keeps the narrow `except (json.JSONDecodeError,
    TypeError)`. Adding `Exception` to the tuple made it catch everything,
    and the new BLOB branch (isinstance + len // 4) cannot raise.
  * The save_chunks_batch docstring now mentions the INSERT OR REPLACE
    fallback instead of claiming UPSERT is always used.
  * NOTE(review): indentation was reconstructed from a whitespace-collapsed
    copy of the patch, and the `index` hashes below are carried over from
    the submitted patch (not recomputed) -- apply with whitespace tolerance
    and confirm against the working tree.

diff --git a/agent/memory/embedding/state.py b/agent/memory/embedding/state.py
index 3fb60b23..5efffef2 100644
--- a/agent/memory/embedding/state.py
+++ b/agent/memory/embedding/state.py
@@ -31,9 +31,13 @@ def detect_index_dim(storage) -> Optional[int]:
     if not row or not row["embedding"]:
         return None
     try:
-        emb = json.loads(row["embedding"])
+        raw = row["embedding"]
+        if isinstance(raw, (bytes, bytearray)):
+            # New BLOB format: 4 bytes per float32
+            return len(raw) // 4
+        emb = json.loads(raw)
         return len(emb) if isinstance(emb, list) else None
     except (json.JSONDecodeError, TypeError):
         return None
 
 
diff --git a/agent/memory/storage.py b/agent/memory/storage.py
index 1a004904..683b083f 100644
--- a/agent/memory/storage.py
+++ b/agent/memory/storage.py
@@ -13,7 +13,17 @@ import threading
 from typing import List, Dict, Optional, Any
 from pathlib import Path
 from dataclasses import dataclass
-import numpy as np
+try:
+    import numpy as np
+    _HAS_NUMPY = True
+except ImportError:
+    _HAS_NUMPY = False
+    np = None  # type: ignore[assignment]
+
+# UPSERT (INSERT … ON CONFLICT DO UPDATE) requires SQLite ≥ 3.24.0 (2018).
+# Older systems (e.g. CentOS 7 ships SQLite 3.7) fall back to INSERT OR REPLACE,
+# which risks FTS5 rowid drift on chunk updates (see save_chunk docstring).
+_HAS_UPSERT = sqlite3.sqlite_version_info >= (3, 24, 0)
 
 # ---------------------------------------------------------------------------
 # CJK character ranges, compiled once at module load.
@@ -93,6 +103,14 @@ class MemoryStorage:
 
         # Check FTS5 support
         self.fts5_available = self._check_fts5_support()
+        if not _HAS_UPSERT:
+            from common.log import logger
+            logger.warning(
+                "[MemoryStorage] SQLite %s < 3.24 — UPSERT unavailable. "
+                "Falling back to INSERT OR REPLACE; FTS5 rowid may drift on "
+                "chunk updates (rebuild index periodically to recover).",
+                sqlite3.sqlite_version,
+            )
         if not self.fts5_available:
             from common.log import logger
             logger.debug("[MemoryStorage] FTS5 not available, using LIKE-based keyword search")
@@ -403,24 +421,32 @@ class MemoryStorage:
         ON CONFLICT DO UPDATE fires the AFTER UPDATE trigger (chunks_au /
         chunks_trigram_au) and keeps the original rowid intact.
         """
-        _SQL = """
-            INSERT INTO chunks
-            (id, user_id, scope, source, path, start_line, end_line,
-             text, embedding, hash, metadata, updated_at)
-            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
-            ON CONFLICT(id) DO UPDATE SET
-                user_id = excluded.user_id,
-                scope = excluded.scope,
-                source = excluded.source,
-                path = excluded.path,
-                start_line = excluded.start_line,
-                end_line = excluded.end_line,
-                text = excluded.text,
-                embedding = excluded.embedding,
-                hash = excluded.hash,
-                metadata = excluded.metadata,
-                updated_at = strftime('%s', 'now')
-        """
+        if _HAS_UPSERT:
+            _SQL = """
+                INSERT INTO chunks
+                (id, user_id, scope, source, path, start_line, end_line,
+                 text, embedding, hash, metadata, updated_at)
+                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
+                ON CONFLICT(id) DO UPDATE SET
+                    user_id = excluded.user_id,
+                    scope = excluded.scope,
+                    source = excluded.source,
+                    path = excluded.path,
+                    start_line = excluded.start_line,
+                    end_line = excluded.end_line,
+                    text = excluded.text,
+                    embedding = excluded.embedding,
+                    hash = excluded.hash,
+                    metadata = excluded.metadata,
+                    updated_at = strftime('%s', 'now')
+            """
+        else:
+            _SQL = """
+                INSERT OR REPLACE INTO chunks
+                (id, user_id, scope, source, path, start_line, end_line,
+                 text, embedding, hash, metadata, updated_at)
+                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
+            """
         params = (
             chunk.id, chunk.user_id, chunk.scope, chunk.source, chunk.path,
             chunk.start_line, chunk.end_line, chunk.text,
@@ -437,24 +463,32 @@ class MemoryStorage:
 
-        See save_chunk for why UPSERT is used instead of INSERT OR REPLACE.
+        See save_chunk for why UPSERT is preferred; INSERT OR REPLACE is only the SQLite < 3.24 fallback.
         """
-        _SQL = """
-            INSERT INTO chunks
-            (id, user_id, scope, source, path, start_line, end_line,
-             text, embedding, hash, metadata, updated_at)
-            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
-            ON CONFLICT(id) DO UPDATE SET
-                user_id = excluded.user_id,
-                scope = excluded.scope,
-                source = excluded.source,
-                path = excluded.path,
-                start_line = excluded.start_line,
-                end_line = excluded.end_line,
-                text = excluded.text,
-                embedding = excluded.embedding,
-                hash = excluded.hash,
-                metadata = excluded.metadata,
-                updated_at = strftime('%s', 'now')
-        """
+        if _HAS_UPSERT:
+            _SQL = """
+                INSERT INTO chunks
+                (id, user_id, scope, source, path, start_line, end_line,
+                 text, embedding, hash, metadata, updated_at)
+                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
+                ON CONFLICT(id) DO UPDATE SET
+                    user_id = excluded.user_id,
+                    scope = excluded.scope,
+                    source = excluded.source,
+                    path = excluded.path,
+                    start_line = excluded.start_line,
+                    end_line = excluded.end_line,
+                    text = excluded.text,
+                    embedding = excluded.embedding,
+                    hash = excluded.hash,
+                    metadata = excluded.metadata,
+                    updated_at = strftime('%s', 'now')
+            """
+        else:
+            _SQL = """
+                INSERT OR REPLACE INTO chunks
+                (id, user_id, scope, source, path, start_line, end_line,
+                 text, embedding, hash, metadata, updated_at)
+                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
+            """
         params_list = [
             (
                 c.id, c.user_id, c.scope, c.source, c.path,
@@ -544,36 +578,61 @@ class MemoryStorage:
         if not vectors:
             return []
 
-        matrix = np.array(vectors, dtype=np.float32)  # (N, D)
-        q_vec = np.array(query_embedding, dtype=np.float32)  # (D,)
+        if _HAS_NUMPY:
+            matrix = np.array(vectors, dtype=np.float32)  # (N, D)
+            q_vec = np.array(query_embedding, dtype=np.float32)  # (D,)
 
-        # Vectorized cosine similarity: dot(matrix, q) / (||matrix|| * ||q||)
-        dots = matrix @ q_vec  # (N,)
-        row_norms = np.linalg.norm(matrix, axis=1)  # (N,)
-        q_norm = float(np.linalg.norm(q_vec))
-        denominators = row_norms * q_norm
-        np.maximum(denominators, 1e-10, out=denominators)  # avoid div-by-zero
-        sims = dots / denominators  # (N,)
+            # Vectorized cosine similarity: dot(matrix, q) / (||matrix|| * ||q||)
+            dots = matrix @ q_vec  # (N,)
+            row_norms = np.linalg.norm(matrix, axis=1)  # (N,)
+            q_norm = float(np.linalg.norm(q_vec))
+            denominators = row_norms * q_norm
+            np.maximum(denominators, 1e-10, out=denominators)  # avoid div-by-zero
+            sims = dots / denominators  # (N,)
 
-        # Select TopK using argpartition (O(N) average), then sort only those K
-        k = min(limit, len(valid_rows))
-        top_idx = np.argpartition(sims, -k)[-k:]
-        top_idx = top_idx[np.argsort(sims[top_idx])[::-1]]
+            # Select TopK using argpartition (O(N) average), then sort only those K
+            k = min(limit, len(valid_rows))
+            top_idx = np.argpartition(sims, -k)[-k:]
+            top_idx = top_idx[np.argsort(sims[top_idx])[::-1]]
 
-
-        return [
-            SearchResult(
-                path=valid_rows[i]['path'],
-                start_line=valid_rows[i]['start_line'],
-                end_line=valid_rows[i]['end_line'],
-                score=float(sims[i]),
-                snippet=self._truncate_text(valid_rows[i]['text'], 500),
-                source=valid_rows[i]['source'],
-                user_id=valid_rows[i]['user_id']
-            )
-            for i in top_idx
-            if sims[i] > 0
-        ]
+            return [
+                SearchResult(
+                    path=valid_rows[i]['path'],
+                    start_line=valid_rows[i]['start_line'],
+                    end_line=valid_rows[i]['end_line'],
+                    score=float(sims[i]),
+                    snippet=self._truncate_text(valid_rows[i]['text'], 500),
+                    source=valid_rows[i]['source'],
+                    user_id=valid_rows[i]['user_id']
+                )
+                for i in top_idx
+                if sims[i] > 0
+            ]
+        else:
+            # Pure-Python cosine similarity fallback (numpy not installed)
+            import math
+            q = query_embedding
+            q_norm = math.sqrt(sum(x * x for x in q)) or 1e-10
+            scored = []
+            for i, vec in enumerate(vectors):
+                dot = sum(a * b for a, b in zip(vec, q))
+                v_norm = math.sqrt(sum(x * x for x in vec)) or 1e-10
+                sim = dot / (v_norm * q_norm)
+                if sim > 0:
+                    scored.append((sim, valid_rows[i]))
+            scored.sort(key=lambda x: x[0], reverse=True)
+            return [
+                SearchResult(
+                    path=row['path'],
+                    start_line=row['start_line'],
+                    end_line=row['end_line'],
+                    score=sim,
+                    snippet=self._truncate_text(row['text'], 500),
+                    source=row['source'],
+                    user_id=row['user_id']
+                )
+                for sim, row in scored[:limit]
+            ]
 
     def search_keyword(
         self,
@@ -621,8 +680,8 @@ class MemoryStorage:
             if trigram_results:
                 return trigram_results
 
-        # Step 3: LIKE fallback — last resort (FTS5 unavailable, or CJK single-char
-        # that trigram cannot match because it requires ≥3-char tokens).
+        # Step 3: LIKE fallback — last resort (FTS5 unavailable, or CJK tokens
+        # shorter than 3 characters that trigram cannot match, e.g. a single-char query).
         if not self.fts5_available or MemoryStorage._contains_cjk(query):
             return self._search_like(query, user_id, scopes, limit)
 
@@ -829,18 +888,27 @@ class MemoryStorage:
 
     @staticmethod
     def _encode_embedding(embedding: Optional[List[float]]) -> Optional[bytes]:
-        """Encode embedding as float32 BLOB bytes (~6x smaller and faster than JSON)."""
+        """Encode embedding as float32 BLOB bytes (~6x smaller and faster than JSON).
+        Falls back to struct.pack when numpy is unavailable."""
         if embedding is None:
             return None
-        return np.array(embedding, dtype=np.float32).tobytes()
+        if _HAS_NUMPY:
+            return np.array(embedding, dtype=np.float32).tobytes()
+        import struct
+        return struct.pack(f'{len(embedding)}f', *embedding)
 
     @staticmethod
     def _decode_embedding(raw) -> Optional[List[float]]:
-        """Decode embedding from BLOB bytes or legacy JSON string."""
+        """Decode embedding from BLOB bytes or legacy JSON string.
+        Handles both numpy and numpy-free environments."""
         if raw is None:
             return None
         if isinstance(raw, (bytes, bytearray)):
-            return np.frombuffer(raw, dtype=np.float32).tolist()
+            if _HAS_NUMPY:
+                return np.frombuffer(raw, dtype=np.float32).tolist()
+            import struct
+            n = len(raw) // 4
+            return list(struct.unpack(f'{n}f', raw))
         # Legacy JSON format written by older versions
         return json.loads(raw)
 
@@ -970,7 +1038,10 @@ class MemoryStorage:
         """
         if rank is None:
             return 0.0
-        return abs(rank) / (1.0 + abs(rank))
+        # Add a floor of 0.3 so any FTS5 match always exceeds typical
+        # min_score thresholds (default 0.1). Small-corpus ranks close to
+        # 0 would otherwise produce score≈0 and be filtered out downstream.
+        return 0.3 + 0.69 * (abs(rank) / (1.0 + abs(rank)))
 
     @staticmethod
     def _truncate_text(text: str, max_chars: int) -> str: