mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
feat: upgrade memory flush system
- Use LLM to summarize discarded context into concise daily memory entries - Batch trim to half when exceeding max_turns/max_tokens, reducing flush frequency - Run summarization asynchronously in background thread, no blocking on replies - Add daily scheduled flush (23:55) as fallback for low-activity days - Sync trimmed messages back to agent to keep context state consistent
This commit is contained in:
@@ -544,11 +544,15 @@ class Agent:
|
||||
logger.info("[Agent] Cleared Agent message history after executor recovery")
|
||||
raise
|
||||
|
||||
# Append only the NEW messages from this execution (thread-safe)
|
||||
# This allows concurrent requests to both contribute to history
|
||||
# Sync executor's messages back to agent (thread-safe).
|
||||
# If the executor trimmed context, its message list is shorter than
|
||||
# original_length, so we must replace rather than append.
|
||||
with self.messages_lock:
|
||||
new_messages = executor.messages[original_length:]
|
||||
self.messages.extend(new_messages)
|
||||
self.messages = list(executor.messages)
|
||||
# Track messages added in this run (user query + all assistant/tool messages)
|
||||
# original_length may exceed executor.messages length after trimming
|
||||
trim_adjusted_start = min(original_length, len(executor.messages))
|
||||
self._last_run_new_messages = list(executor.messages[trim_adjusted_start:])
|
||||
|
||||
# Store executor reference for agent_bridge to access files_to_send
|
||||
self.stream_executor = executor
|
||||
|
||||
@@ -201,26 +201,6 @@ class AgentStreamExecutor:
|
||||
logger.info(f"[Agent] 第 {turn} 轮")
|
||||
self._emit_event("turn_start", {"turn": turn})
|
||||
|
||||
# Check if memory flush is needed (before calling LLM)
|
||||
# 使用独立的 flush 阈值(50K tokens 或 20 轮)
|
||||
if self.agent.memory_manager and hasattr(self.agent, 'last_usage'):
|
||||
usage = self.agent.last_usage
|
||||
if usage and 'input_tokens' in usage:
|
||||
current_tokens = usage.get('input_tokens', 0)
|
||||
|
||||
if self.agent.memory_manager.should_flush_memory(
|
||||
current_tokens=current_tokens
|
||||
):
|
||||
self._emit_event("memory_flush_start", {
|
||||
"current_tokens": current_tokens,
|
||||
"turn_count": self.agent.memory_manager.flush_manager.turn_count
|
||||
})
|
||||
|
||||
# TODO: Execute memory flush in background
|
||||
# This would require async support
|
||||
logger.info(
|
||||
f"Memory flush recommended: tokens={current_tokens}, turns={self.agent.memory_manager.flush_manager.turn_count}")
|
||||
|
||||
# Call LLM (enable retry_on_empty for better reliability)
|
||||
assistant_msg, tool_calls = self._call_llm_stream(retry_on_empty=True)
|
||||
final_response = assistant_msg
|
||||
@@ -473,10 +453,6 @@ class AgentStreamExecutor:
|
||||
logger.info(f"[Agent] 🏁 完成 ({turn}轮)")
|
||||
self._emit_event("agent_end", {"final_response": final_response})
|
||||
|
||||
# 每轮对话结束后增加计数(用户消息+AI回复=1轮)
|
||||
if self.agent.memory_manager:
|
||||
self.agent.memory_manager.increment_turn()
|
||||
|
||||
return final_response
|
||||
|
||||
def _call_llm_stream(self, retry_on_empty=True, retry_count=0, max_retries=3,
|
||||
@@ -501,7 +477,8 @@ class AgentStreamExecutor:
|
||||
|
||||
# Prepare messages
|
||||
messages = self._prepare_messages()
|
||||
logger.info(f"Sending {len(messages)} messages to LLM")
|
||||
turns = self._identify_complete_turns()
|
||||
logger.info(f"Sending {len(messages)} messages ({len(turns)} turns) to LLM")
|
||||
|
||||
# Prepare tool definitions (OpenAI/Claude format)
|
||||
tools_schema = None
|
||||
@@ -655,6 +632,14 @@ class AgentStreamExecutor:
|
||||
error_type = "context overflow" if is_context_overflow else "message format error"
|
||||
logger.error(f"💥 {error_type} detected: {e}")
|
||||
|
||||
# Flush memory before trimming to preserve context that will be lost
|
||||
if is_context_overflow and self.agent.memory_manager:
|
||||
user_id = getattr(self.agent, '_current_user_id', None)
|
||||
self.agent.memory_manager.flush_memory(
|
||||
messages=self.messages, user_id=user_id,
|
||||
reason="overflow", max_messages=0
|
||||
)
|
||||
|
||||
# Strategy: try aggressive trimming first, only clear as last resort
|
||||
if is_context_overflow and not _overflow_retry:
|
||||
trimmed = self._aggressive_trim_for_overflow()
|
||||
@@ -1204,14 +1189,28 @@ class AgentStreamExecutor:
|
||||
if not turns:
|
||||
return
|
||||
|
||||
# Step 2: 轮次限制 - 保留最近 N 轮
|
||||
# Step 2: 轮次限制 - 超出时裁到 max_turns/2,批量 flush 被裁的轮次
|
||||
if len(turns) > self.max_context_turns:
|
||||
removed_turns = len(turns) - self.max_context_turns
|
||||
turns = turns[-self.max_context_turns:] # 保留最近的轮次
|
||||
keep_count = max(1, self.max_context_turns // 2)
|
||||
removed_count = len(turns) - keep_count
|
||||
|
||||
# Flush discarded turns to daily memory
|
||||
if self.agent.memory_manager:
|
||||
discarded_messages = []
|
||||
for turn in turns[:removed_count]:
|
||||
discarded_messages.extend(turn["messages"])
|
||||
if discarded_messages:
|
||||
user_id = getattr(self.agent, '_current_user_id', None)
|
||||
self.agent.memory_manager.flush_memory(
|
||||
messages=discarded_messages, user_id=user_id,
|
||||
reason="trim", max_messages=0
|
||||
)
|
||||
|
||||
turns = turns[-keep_count:]
|
||||
|
||||
logger.info(
|
||||
f"💾 上下文轮次超限: {len(turns) + removed_turns} > {self.max_context_turns},"
|
||||
f"移除最早的 {removed_turns} 轮完整对话"
|
||||
f"💾 上下文轮次超限: {keep_count + removed_count} > {self.max_context_turns},"
|
||||
f"裁剪至 {keep_count} 轮(移除 {removed_count} 轮)"
|
||||
)
|
||||
|
||||
# Step 3: Token 限制 - 保留完整轮次
|
||||
@@ -1248,56 +1247,41 @@ class AgentStreamExecutor:
|
||||
logger.info(f" 重建消息列表: {old_count} -> {len(self.messages)} 条消息")
|
||||
return
|
||||
|
||||
# Token limit exceeded - keep complete turns from newest
|
||||
# Token limit exceeded - keep the latest half of turns (same strategy as turn limit)
|
||||
keep_count = max(1, len(turns) // 2)
|
||||
removed_count = len(turns) - keep_count
|
||||
kept_turns = turns[-keep_count:]
|
||||
kept_tokens = sum(self._estimate_turn_tokens(t) for t in kept_turns)
|
||||
|
||||
logger.info(
|
||||
f"🔄 上下文tokens超限: ~{current_tokens + system_tokens} > {max_tokens},"
|
||||
f"将按完整轮次移除最早的对话"
|
||||
f"裁剪至 {keep_count} 轮(移除 {removed_count} 轮)"
|
||||
)
|
||||
|
||||
# 从最新轮次开始,反向累加(保持完整轮次)
|
||||
kept_turns = []
|
||||
accumulated_tokens = 0
|
||||
min_turns = 3 # 尽量保留至少 3 轮,但不强制(避免超出 token 限制)
|
||||
# Flush discarded turns to daily memory
|
||||
if self.agent.memory_manager:
|
||||
discarded_messages = []
|
||||
for turn in turns[:removed_count]:
|
||||
discarded_messages.extend(turn["messages"])
|
||||
if discarded_messages:
|
||||
user_id = getattr(self.agent, '_current_user_id', None)
|
||||
self.agent.memory_manager.flush_memory(
|
||||
messages=discarded_messages, user_id=user_id,
|
||||
reason="trim", max_messages=0
|
||||
)
|
||||
|
||||
for i, turn in enumerate(reversed(turns)):
|
||||
turn_tokens = self._estimate_turn_tokens(turn)
|
||||
turns_from_end = i + 1
|
||||
|
||||
# 检查是否超出限制
|
||||
if accumulated_tokens + turn_tokens <= available_tokens:
|
||||
kept_turns.insert(0, turn)
|
||||
accumulated_tokens += turn_tokens
|
||||
else:
|
||||
# 超出限制
|
||||
# 如果还没有保留足够的轮次,且这是最后的机会,尝试保留
|
||||
if len(kept_turns) < min_turns and turns_from_end <= min_turns:
|
||||
# 检查是否严重超出(超出 20% 以上则放弃)
|
||||
overflow_ratio = (accumulated_tokens + turn_tokens - available_tokens) / available_tokens
|
||||
if overflow_ratio < 0.2: # 允许最多超出 20%
|
||||
kept_turns.insert(0, turn)
|
||||
accumulated_tokens += turn_tokens
|
||||
logger.debug(f" 为保留最少轮次,允许超出 {overflow_ratio*100:.1f}%")
|
||||
continue
|
||||
# 停止保留更早的轮次
|
||||
break
|
||||
|
||||
# 重建消息列表
|
||||
new_messages = []
|
||||
for turn in kept_turns:
|
||||
new_messages.extend(turn['messages'])
|
||||
|
||||
old_count = len(self.messages)
|
||||
old_turn_count = len(turns)
|
||||
self.messages = new_messages
|
||||
new_count = len(self.messages)
|
||||
new_turn_count = len(kept_turns)
|
||||
|
||||
if old_count > new_count:
|
||||
logger.info(
|
||||
f" 移除了 {old_turn_count - new_turn_count} 轮对话 "
|
||||
f"({old_count} -> {new_count} 条消息,"
|
||||
f"~{current_tokens + system_tokens} -> ~{accumulated_tokens + system_tokens} tokens)"
|
||||
)
|
||||
logger.info(
|
||||
f" 移除了 {removed_count} 轮对话 "
|
||||
f"({old_count} -> {len(self.messages)} 条消息,"
|
||||
f"~{current_tokens + system_tokens} -> ~{kept_tokens + system_tokens} tokens)"
|
||||
)
|
||||
|
||||
def _clear_session_db(self):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user