mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat
This commit is contained in:
@@ -28,11 +28,13 @@ from common.log import logger
|
||||
|
||||
_DDL = """
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
session_id TEXT PRIMARY KEY,
|
||||
channel_type TEXT NOT NULL DEFAULT '',
|
||||
created_at INTEGER NOT NULL,
|
||||
last_active INTEGER NOT NULL,
|
||||
msg_count INTEGER NOT NULL DEFAULT 0
|
||||
session_id TEXT PRIMARY KEY,
|
||||
channel_type TEXT NOT NULL DEFAULT '',
|
||||
title TEXT NOT NULL DEFAULT '',
|
||||
context_start_seq INTEGER NOT NULL DEFAULT 0,
|
||||
created_at INTEGER NOT NULL,
|
||||
last_active INTEGER NOT NULL,
|
||||
msg_count INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
@@ -57,6 +59,14 @@ _MIGRATION_ADD_CHANNEL_TYPE = """
|
||||
ALTER TABLE sessions ADD COLUMN channel_type TEXT NOT NULL DEFAULT '';
|
||||
"""
|
||||
|
||||
# Migration for databases created before sessions had a `title` column.
# Adds the column with an empty-string default so existing rows stay valid.
_MIGRATION_ADD_TITLE = """
ALTER TABLE sessions ADD COLUMN title TEXT NOT NULL DEFAULT '';
"""
|
||||
|
||||
# Migration for databases created before sessions had a `context_start_seq`
# column. The default of 0 leaves every existing message at or after the
# boundary, so existing sessions are unaffected.
_MIGRATION_ADD_CONTEXT_START_SEQ = """
ALTER TABLE sessions ADD COLUMN context_start_seq INTEGER NOT NULL DEFAULT 0;
"""
|
||||
|
||||
DEFAULT_MAX_AGE_DAYS: int = 30
|
||||
|
||||
|
||||
@@ -287,14 +297,21 @@ class ConversationStore:
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
# Respect context_start_seq: only load messages at or after the boundary
|
||||
ctx_row = conn.execute(
|
||||
"SELECT context_start_seq FROM sessions WHERE session_id = ?",
|
||||
(session_id,),
|
||||
).fetchone()
|
||||
ctx_start = ctx_row[0] if ctx_row else 0
|
||||
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT seq, role, content
|
||||
FROM messages
|
||||
WHERE session_id = ?
|
||||
WHERE session_id = ? AND seq >= ?
|
||||
ORDER BY seq DESC
|
||||
""",
|
||||
(session_id,),
|
||||
(session_id, ctx_start),
|
||||
).fetchall()
|
||||
finally:
|
||||
conn.close()
|
||||
@@ -302,10 +319,7 @@ class ConversationStore:
|
||||
if not rows:
|
||||
return []
|
||||
|
||||
# Walk newest-to-oldest counting *visible* user turns (actual user text,
|
||||
# not tool_result injections). Record the seq of every visible user
|
||||
# message so we can find a clean cut point later.
|
||||
visible_turn_seqs: List[int] = [] # newest first
|
||||
visible_turn_seqs: List[int] = []
|
||||
for seq, role, raw_content in rows:
|
||||
if role != "user":
|
||||
continue
|
||||
@@ -316,17 +330,11 @@ class ConversationStore:
|
||||
if _is_visible_user_message(content):
|
||||
visible_turn_seqs.append(seq)
|
||||
|
||||
# Determine the seq of the oldest visible user message we want to keep.
|
||||
# If the total turns fit within max_turns, keep everything.
|
||||
if len(visible_turn_seqs) <= max_turns:
|
||||
cutoff_seq = None # keep all
|
||||
cutoff_seq = None
|
||||
else:
|
||||
# The Nth visible user message (0-indexed) is the oldest we keep.
|
||||
cutoff_seq = visible_turn_seqs[max_turns - 1]
|
||||
|
||||
# Build result in chronological order, starting from cutoff.
|
||||
# IMPORTANT: we start exactly at cutoff_seq (the visible user message),
|
||||
# never mid-group, so tool_use / tool_result pairs are always complete.
|
||||
result = []
|
||||
for seq, role, raw_content in reversed(rows):
|
||||
if cutoff_seq is not None and seq < cutoff_seq:
|
||||
@@ -415,6 +423,61 @@ class ConversationStore:
|
||||
""",
|
||||
(session_id, session_id),
|
||||
)
|
||||
|
||||
# Auto-generate title from the first visible user message
|
||||
cur_title = conn.execute(
|
||||
"SELECT title FROM sessions WHERE session_id = ?",
|
||||
(session_id,),
|
||||
).fetchone()
|
||||
if cur_title and not cur_title[0]:
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
content = msg.get("content", "")
|
||||
text = _extract_display_text(content)
|
||||
if text:
|
||||
title = text[:50].split("\n")[0]
|
||||
conn.execute(
|
||||
"UPDATE sessions SET title = ? WHERE session_id = ?",
|
||||
(title, session_id),
|
||||
)
|
||||
break
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def clear_context(self, session_id: str) -> int:
    """
    Set the context boundary to after the current last message.

    Messages before this boundary are still stored but excluded from LLM context.

    Args:
        session_id: Session whose context boundary is moved.

    Returns:
        The new context_start_seq value: MAX(seq) + 1 over the session's
        messages, or 0 when the session has no messages.
    """
    with self._lock:
        conn = self._connect()
        try:
            # `with conn` makes the read and the update one transaction
            # (committed on success, rolled back on error).
            with conn:
                row = conn.execute(
                    "SELECT COALESCE(MAX(seq), -1) FROM messages WHERE session_id = ?",
                    (session_id,),
                ).fetchone()
                # COALESCE(..., -1) makes an empty session yield a boundary of 0.
                new_start = row[0] + 1
                conn.execute(
                    "UPDATE sessions SET context_start_seq = ? WHERE session_id = ?",
                    (new_start, session_id),
                )
            return new_start
        finally:
            conn.close()
|
||||
|
||||
def get_context_start_seq(self, session_id: str) -> int:
    """
    Return the context_start_seq for a session (0 if not set).

    Args:
        session_id: Session to look up.

    Returns:
        The stored context_start_seq, or 0 when no row exists for the session.
    """
    with self._lock:
        conn = self._connect()
        try:
            row = conn.execute(
                "SELECT context_start_seq FROM sessions WHERE session_id = ?",
                (session_id,),
            ).fetchone()
            # An unknown session has no boundary yet: treat it as "from the start".
            return row[0] if row else 0
        finally:
            conn.close()
|
||||
|
||||
@@ -436,6 +499,7 @@ class ConversationStore:
|
||||
def cleanup_old_sessions(self, max_age_days: Optional[int] = None) -> int:
|
||||
"""
|
||||
Delete sessions that have not been active within max_age_days.
|
||||
Web channel sessions are excluded — they are meant to be permanent.
|
||||
|
||||
Args:
|
||||
max_age_days: Override the default retention period.
|
||||
@@ -459,7 +523,8 @@ class ConversationStore:
|
||||
try:
|
||||
with conn:
|
||||
stale = conn.execute(
|
||||
"SELECT session_id FROM sessions WHERE last_active < ?",
|
||||
"SELECT session_id FROM sessions "
|
||||
"WHERE last_active < ? AND channel_type != 'web'",
|
||||
(cutoff,),
|
||||
).fetchall()
|
||||
for (sid,) in stale:
|
||||
@@ -518,9 +583,15 @@ class ConversationStore:
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
ctx_row = conn.execute(
|
||||
"SELECT context_start_seq FROM sessions WHERE session_id = ?",
|
||||
(session_id,),
|
||||
).fetchone()
|
||||
ctx_start = ctx_row[0] if ctx_row else 0
|
||||
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT role, content, created_at
|
||||
SELECT seq, role, content, created_at
|
||||
FROM messages
|
||||
WHERE session_id = ?
|
||||
ORDER BY seq ASC
|
||||
@@ -530,7 +601,30 @@ class ConversationStore:
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
visible = _group_into_display_turns(rows)
|
||||
# Strip seq for display grouping, but record max seq per visible user group
|
||||
plain_rows = [(role, content, created_at) for _seq, role, content, created_at in rows]
|
||||
visible = _group_into_display_turns(plain_rows)
|
||||
|
||||
# Build a mapping: find the seq of each visible user message to annotate context boundary.
|
||||
# Walk through rows to find visible user message seqs in order.
|
||||
visible_user_seqs: List[int] = []
|
||||
for seq, role, raw_content, _ts in rows:
|
||||
if role != "user":
|
||||
continue
|
||||
try:
|
||||
content = json.loads(raw_content)
|
||||
except Exception:
|
||||
content = raw_content
|
||||
if _is_visible_user_message(content):
|
||||
visible_user_seqs.append(seq)
|
||||
|
||||
# Each pair of display turns (user+assistant) corresponds to a visible user seq.
|
||||
# Mark which turns are before the context boundary.
|
||||
user_turn_idx = 0
|
||||
for turn in visible:
|
||||
if turn["role"] == "user" and user_turn_idx < len(visible_user_seqs):
|
||||
turn["_seq"] = visible_user_seqs[user_turn_idx]
|
||||
user_turn_idx += 1
|
||||
|
||||
total = len(visible)
|
||||
offset = (page - 1) * page_size
|
||||
@@ -539,12 +633,98 @@ class ConversationStore:
|
||||
|
||||
return {
|
||||
"messages": page_items,
|
||||
"context_start_seq": ctx_start,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"has_more": offset + page_size < total,
|
||||
}
|
||||
|
||||
def list_sessions(
    self,
    channel_type: Optional[str] = None,
    page: int = 1,
    page_size: int = 50,
) -> Dict[str, Any]:
    """
    List sessions ordered by last_active DESC, with optional channel_type filter.

    Args:
        channel_type: When truthy, only sessions with this channel_type are
            counted and returned; otherwise all sessions are.
        page: 1-based page number; values below 1 are clamped to 1.
        page_size: Rows per page; values below 1 are clamped to 1.

    Returns:
        {
            "sessions": [{session_id, title, created_at, last_active, msg_count}, ...],
            "total": int,
            "page": int,
            "page_size": int,
            "has_more": bool,
        }
    """
    page = max(1, page)
    # SQLite treats a negative LIMIT as "no upper bound"; clamp so a bad
    # page_size can never return the whole table in one response.
    page_size = max(1, page_size)
    offset = (page - 1) * page_size

    # Only the fixed WHERE fragment is interpolated; the filter value itself
    # is always passed as a bound parameter.
    where_sql = "WHERE channel_type = ?" if channel_type else ""
    where_args = (channel_type,) if channel_type else ()

    with self._lock:
        conn = self._connect()
        try:
            total = conn.execute(
                f"SELECT COUNT(*) FROM sessions {where_sql}",
                where_args,
            ).fetchone()[0]
            rows = conn.execute(
                f"""
                SELECT session_id, title, created_at, last_active, msg_count
                FROM sessions
                {where_sql}
                ORDER BY last_active DESC
                LIMIT ? OFFSET ?
                """,
                (*where_args, page_size, offset),
            ).fetchall()
        finally:
            conn.close()

    sessions = [
        {
            "session_id": sid,
            "title": title,
            "created_at": created_at,
            "last_active": last_active,
            "msg_count": msg_count,
        }
        for sid, title, created_at, last_active, msg_count in rows
    ]
    return {
        "sessions": sessions,
        "total": total,
        "page": page,
        "page_size": page_size,
        "has_more": offset + page_size < total,
    }
|
||||
|
||||
def rename_session(self, session_id: str, title: str) -> bool:
    """
    Update the title of a session.

    Args:
        session_id: Session to rename.
        title: New title, stored as given.

    Returns:
        True if the session existed (a row was updated), False otherwise.
    """
    with self._lock:
        conn = self._connect()
        try:
            # `with conn` commits the update on success and rolls back on error.
            with conn:
                cur = conn.execute(
                    "UPDATE sessions SET title = ? WHERE session_id = ?",
                    (title, session_id),
                )
            # rowcount is 0 when no session matched the id.
            return cur.rowcount > 0
        finally:
            conn.close()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Return basic stats keyed by channel_type, for monitoring."""
|
||||
with self._lock:
|
||||
@@ -599,6 +779,20 @@ class ConversationStore:
|
||||
logger.info("[ConversationStore] Migrated: added channel_type column")
|
||||
except Exception as e:
|
||||
logger.warning(f"[ConversationStore] Migration failed: {e}")
|
||||
if "title" not in cols:
|
||||
try:
|
||||
conn.execute(_MIGRATION_ADD_TITLE)
|
||||
conn.commit()
|
||||
logger.info("[ConversationStore] Migrated: added title column")
|
||||
except Exception as e:
|
||||
logger.warning(f"[ConversationStore] Migration (title) failed: {e}")
|
||||
if "context_start_seq" not in cols:
|
||||
try:
|
||||
conn.execute(_MIGRATION_ADD_CONTEXT_START_SEQ)
|
||||
conn.commit()
|
||||
logger.info("[ConversationStore] Migrated: added context_start_seq column")
|
||||
except Exception as e:
|
||||
logger.warning(f"[ConversationStore] Migration (context_start_seq) failed: {e}")
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(str(self._db_path), timeout=10)
|
||||
|
||||
Reference in New Issue
Block a user