fix(memory): prevent context memory loss by improving trim strategy

This commit is contained in:
zhayujie
2026-03-12 15:25:46 +08:00
parent e791a77f77
commit c11623596d
2 changed files with 114 additions and 19 deletions

View File

@@ -8,7 +8,7 @@ import time
from typing import List, Dict, Any, Optional, Callable, Tuple
from agent.protocol.models import LLMRequest, LLMModel
from agent.protocol.message_utils import sanitize_claude_messages
from agent.protocol.message_utils import sanitize_claude_messages, compress_turn_to_text_only
from agent.tools.base_tool import BaseTool, ToolResult
from common.log import logger
@@ -191,6 +191,11 @@ class AgentStreamExecutor:
]
})
# Trim context ONCE before the agent loop starts, not during tool steps.
# This ensures tool_use/tool_result chains created during the current run
# are never stripped mid-execution (which would cause LLM loops).
self._trim_messages()
self._emit_event("agent_start")
final_response = ""
@@ -481,14 +486,10 @@ class AgentStreamExecutor:
Returns:
(response_text, tool_calls)
"""
# Validate and fix message history first
self._validate_and_fix_messages()
# Trim messages if needed (using agent's context management)
self._trim_messages()
# Re-validate after trimming: trimming may produce new orphaned
# tool_result messages when it removes turns at the boundary.
# Validate and fix message history (e.g. orphaned tool_result blocks).
# Context trimming is done once in run_stream() before the loop starts,
# NOT here — trimming mid-execution would strip the current run's
# tool_use/tool_result chains and cause LLM loops.
self._validate_and_fix_messages()
# Prepare messages
@@ -1165,10 +1166,10 @@ class AgentStreamExecutor:
if not turns:
return
# Step 2: 轮次限制 - 超出时裁到 max_turns/2,批量 flush 被裁的轮次
# Step 2: 轮次限制 - 超出时移除前一半,保留后一半
if len(turns) > self.max_context_turns:
keep_count = max(1, self.max_context_turns // 2)
removed_count = len(turns) - keep_count
removed_count = len(turns) // 2
keep_count = len(turns) - removed_count
# Flush discarded turns to daily memory
if self.agent.memory_manager:
@@ -1223,9 +1224,47 @@ class AgentStreamExecutor:
logger.info(f" 重建消息列表: {old_count} -> {len(self.messages)} 条消息")
return
# Token limit exceeded - keep the latest half of turns (same strategy as turn limit)
keep_count = max(1, len(turns) // 2)
removed_count = len(turns) - keep_count
# Token limit exceeded — tiered strategy based on turn count:
#
# Few turns (<5): Compress ALL turns to text-only (strip tool chains,
# keep user query + final reply). Never discard turns
# — losing even one is too painful when context is thin.
#
# Many turns (>=5): Directly discard the first half of turns.
# With enough turns the oldest ones are less
# critical, and keeping the recent half intact
# (with full tool chains) is more useful.
COMPRESS_THRESHOLD = 5
if len(turns) < COMPRESS_THRESHOLD:
# --- Few turns: compress ALL turns to text-only; a turn is dropped only if its compressed form has no messages ---
compressed_turns = []
for t in turns:
compressed = compress_turn_to_text_only(t)
if compressed["messages"]:
compressed_turns.append(compressed)
new_messages = []
for turn in compressed_turns:
new_messages.extend(turn["messages"])
new_tokens = sum(self._estimate_turn_tokens(t) for t in compressed_turns)
old_count = len(self.messages)
self.messages = new_messages
logger.info(
f"📦 上下文tokens超限(轮次<{COMPRESS_THRESHOLD}): "
f"~{current_tokens + system_tokens} > {max_tokens}"
f"压缩全部 {len(turns)} 轮为纯文本 "
f"({old_count} -> {len(self.messages)} 条消息,"
f"~{current_tokens + system_tokens} -> ~{new_tokens + system_tokens} tokens)"
)
return
# --- Many turns (>=5): discard the older half, keep the newer half ---
removed_count = len(turns) // 2
keep_count = len(turns) - removed_count
kept_turns = turns[-keep_count:]
kept_tokens = sum(self._estimate_turn_tokens(t) for t in kept_turns)
@@ -1234,7 +1273,6 @@ class AgentStreamExecutor:
f"裁剪至 {keep_count} 轮(移除 {removed_count} 轮)"
)
# Flush discarded turns to daily memory
if self.agent.memory_manager:
discarded_messages = []
for turn in turns[:removed_count]:
@@ -1245,14 +1283,14 @@ class AgentStreamExecutor:
messages=discarded_messages, user_id=user_id,
reason="trim", max_messages=0
)
new_messages = []
for turn in kept_turns:
new_messages.extend(turn['messages'])
old_count = len(self.messages)
self.messages = new_messages
logger.info(
f" 移除了 {removed_count} 轮对话 "
f"({old_count} -> {len(self.messages)} 条消息,"