Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat

This commit is contained in:
zhayujie
2026-04-13 20:13:30 +08:00
8 changed files with 1257 additions and 105 deletions

View File

@@ -28,11 +28,13 @@ from common.log import logger
_DDL = """
CREATE TABLE IF NOT EXISTS sessions (
session_id TEXT PRIMARY KEY,
channel_type TEXT NOT NULL DEFAULT '',
created_at INTEGER NOT NULL,
last_active INTEGER NOT NULL,
msg_count INTEGER NOT NULL DEFAULT 0
session_id TEXT PRIMARY KEY,
channel_type TEXT NOT NULL DEFAULT '',
title TEXT NOT NULL DEFAULT '',
context_start_seq INTEGER NOT NULL DEFAULT 0,
created_at INTEGER NOT NULL,
last_active INTEGER NOT NULL,
msg_count INTEGER NOT NULL DEFAULT 0
);
CREATE TABLE IF NOT EXISTS messages (
@@ -57,6 +59,14 @@ _MIGRATION_ADD_CHANNEL_TYPE = """
ALTER TABLE sessions ADD COLUMN channel_type TEXT NOT NULL DEFAULT '';
"""
_MIGRATION_ADD_TITLE = """
ALTER TABLE sessions ADD COLUMN title TEXT NOT NULL DEFAULT '';
"""
_MIGRATION_ADD_CONTEXT_START_SEQ = """
ALTER TABLE sessions ADD COLUMN context_start_seq INTEGER NOT NULL DEFAULT 0;
"""
DEFAULT_MAX_AGE_DAYS: int = 30
@@ -287,14 +297,21 @@ class ConversationStore:
with self._lock:
conn = self._connect()
try:
# Respect context_start_seq: only load messages at or after the boundary
ctx_row = conn.execute(
"SELECT context_start_seq FROM sessions WHERE session_id = ?",
(session_id,),
).fetchone()
ctx_start = ctx_row[0] if ctx_row else 0
rows = conn.execute(
"""
SELECT seq, role, content
FROM messages
WHERE session_id = ?
WHERE session_id = ? AND seq >= ?
ORDER BY seq DESC
""",
(session_id,),
(session_id, ctx_start),
).fetchall()
finally:
conn.close()
@@ -302,10 +319,7 @@ class ConversationStore:
if not rows:
return []
# Walk newest-to-oldest counting *visible* user turns (actual user text,
# not tool_result injections). Record the seq of every visible user
# message so we can find a clean cut point later.
visible_turn_seqs: List[int] = [] # newest first
visible_turn_seqs: List[int] = []
for seq, role, raw_content in rows:
if role != "user":
continue
@@ -316,17 +330,11 @@ class ConversationStore:
if _is_visible_user_message(content):
visible_turn_seqs.append(seq)
# Determine the seq of the oldest visible user message we want to keep.
# If the total turns fit within max_turns, keep everything.
if len(visible_turn_seqs) <= max_turns:
cutoff_seq = None # keep all
cutoff_seq = None
else:
# The Nth visible user message (0-indexed) is the oldest we keep.
cutoff_seq = visible_turn_seqs[max_turns - 1]
# Build result in chronological order, starting from cutoff.
# IMPORTANT: we start exactly at cutoff_seq (the visible user message),
# never mid-group, so tool_use / tool_result pairs are always complete.
result = []
for seq, role, raw_content in reversed(rows):
if cutoff_seq is not None and seq < cutoff_seq:
@@ -415,6 +423,61 @@ class ConversationStore:
""",
(session_id, session_id),
)
# Auto-generate title from the first visible user message
cur_title = conn.execute(
"SELECT title FROM sessions WHERE session_id = ?",
(session_id,),
).fetchone()
if cur_title and not cur_title[0]:
for msg in messages:
if msg.get("role") == "user":
content = msg.get("content", "")
text = _extract_display_text(content)
if text:
title = text[:50].split("\n")[0]
conn.execute(
"UPDATE sessions SET title = ? WHERE session_id = ?",
(title, session_id),
)
break
finally:
conn.close()
def clear_context(self, session_id: str) -> int:
"""
Set the context boundary to after the current last message.
Messages before this boundary are still stored but excluded from LLM context.
Returns the new context_start_seq value.
"""
with self._lock:
conn = self._connect()
try:
with conn:
row = conn.execute(
"SELECT COALESCE(MAX(seq), -1) FROM messages WHERE session_id = ?",
(session_id,),
).fetchone()
new_start = row[0] + 1
conn.execute(
"UPDATE sessions SET context_start_seq = ? WHERE session_id = ?",
(new_start, session_id),
)
return new_start
finally:
conn.close()
def get_context_start_seq(self, session_id: str) -> int:
"""Return the context_start_seq for a session (0 if not set)."""
with self._lock:
conn = self._connect()
try:
row = conn.execute(
"SELECT context_start_seq FROM sessions WHERE session_id = ?",
(session_id,),
).fetchone()
return row[0] if row else 0
finally:
conn.close()
@@ -436,6 +499,7 @@ class ConversationStore:
def cleanup_old_sessions(self, max_age_days: Optional[int] = None) -> int:
"""
Delete sessions that have not been active within max_age_days.
Web channel sessions are excluded — they are meant to be permanent.
Args:
max_age_days: Override the default retention period.
@@ -459,7 +523,8 @@ class ConversationStore:
try:
with conn:
stale = conn.execute(
"SELECT session_id FROM sessions WHERE last_active < ?",
"SELECT session_id FROM sessions "
"WHERE last_active < ? AND channel_type != 'web'",
(cutoff,),
).fetchall()
for (sid,) in stale:
@@ -518,9 +583,15 @@ class ConversationStore:
with self._lock:
conn = self._connect()
try:
ctx_row = conn.execute(
"SELECT context_start_seq FROM sessions WHERE session_id = ?",
(session_id,),
).fetchone()
ctx_start = ctx_row[0] if ctx_row else 0
rows = conn.execute(
"""
SELECT role, content, created_at
SELECT seq, role, content, created_at
FROM messages
WHERE session_id = ?
ORDER BY seq ASC
@@ -530,7 +601,30 @@ class ConversationStore:
finally:
conn.close()
visible = _group_into_display_turns(rows)
# Strip seq for display grouping, but record max seq per visible user group
plain_rows = [(role, content, created_at) for _seq, role, content, created_at in rows]
visible = _group_into_display_turns(plain_rows)
# Build a mapping: find the seq of each visible user message to annotate context boundary.
# Walk through rows to find visible user message seqs in order.
visible_user_seqs: List[int] = []
for seq, role, raw_content, _ts in rows:
if role != "user":
continue
try:
content = json.loads(raw_content)
except Exception:
content = raw_content
if _is_visible_user_message(content):
visible_user_seqs.append(seq)
# Each pair of display turns (user+assistant) corresponds to a visible user seq.
# Mark which turns are before the context boundary.
user_turn_idx = 0
for turn in visible:
if turn["role"] == "user" and user_turn_idx < len(visible_user_seqs):
turn["_seq"] = visible_user_seqs[user_turn_idx]
user_turn_idx += 1
total = len(visible)
offset = (page - 1) * page_size
@@ -539,12 +633,98 @@ class ConversationStore:
return {
"messages": page_items,
"context_start_seq": ctx_start,
"total": total,
"page": page,
"page_size": page_size,
"has_more": offset + page_size < total,
}
def list_sessions(
self,
channel_type: Optional[str] = None,
page: int = 1,
page_size: int = 50,
) -> Dict[str, Any]:
"""
List sessions ordered by last_active DESC, with optional channel_type filter.
Returns:
{
"sessions": [{session_id, title, created_at, last_active, msg_count}, ...],
"total": int,
"page": int,
"page_size": int,
"has_more": bool,
}
"""
page = max(1, page)
with self._lock:
conn = self._connect()
try:
if channel_type:
total = conn.execute(
"SELECT COUNT(*) FROM sessions WHERE channel_type = ?",
(channel_type,),
).fetchone()[0]
rows = conn.execute(
"""
SELECT session_id, title, created_at, last_active, msg_count
FROM sessions
WHERE channel_type = ?
ORDER BY last_active DESC
LIMIT ? OFFSET ?
""",
(channel_type, page_size, (page - 1) * page_size),
).fetchall()
else:
total = conn.execute(
"SELECT COUNT(*) FROM sessions",
).fetchone()[0]
rows = conn.execute(
"""
SELECT session_id, title, created_at, last_active, msg_count
FROM sessions
ORDER BY last_active DESC
LIMIT ? OFFSET ?
""",
(page_size, (page - 1) * page_size),
).fetchall()
finally:
conn.close()
sessions = [
{
"session_id": r[0],
"title": r[1],
"created_at": r[2],
"last_active": r[3],
"msg_count": r[4],
}
for r in rows
]
return {
"sessions": sessions,
"total": total,
"page": page,
"page_size": page_size,
"has_more": (page - 1) * page_size + page_size < total,
}
def rename_session(self, session_id: str, title: str) -> bool:
"""Update the title of a session. Returns True if the session existed."""
with self._lock:
conn = self._connect()
try:
with conn:
cur = conn.execute(
"UPDATE sessions SET title = ? WHERE session_id = ?",
(title, session_id),
)
return cur.rowcount > 0
finally:
conn.close()
def get_stats(self) -> Dict[str, Any]:
"""Return basic stats keyed by channel_type, for monitoring."""
with self._lock:
@@ -599,6 +779,20 @@ class ConversationStore:
logger.info("[ConversationStore] Migrated: added channel_type column")
except Exception as e:
logger.warning(f"[ConversationStore] Migration failed: {e}")
if "title" not in cols:
try:
conn.execute(_MIGRATION_ADD_TITLE)
conn.commit()
logger.info("[ConversationStore] Migrated: added title column")
except Exception as e:
logger.warning(f"[ConversationStore] Migration (title) failed: {e}")
if "context_start_seq" not in cols:
try:
conn.execute(_MIGRATION_ADD_CONTEXT_START_SEQ)
conn.commit()
logger.info("[ConversationStore] Migrated: added context_start_seq column")
except Exception as e:
logger.warning(f"[ConversationStore] Migration (context_start_seq) failed: {e}")
def _connect(self) -> sqlite3.Connection:
conn = sqlite3.connect(str(self._db_path), timeout=10)