mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
feat: add llm retry
This commit is contained in:
@@ -354,7 +354,7 @@ def _build_workspace_section(workspace_dir: str, language: str) -> List[str]:
|
|||||||
"",
|
"",
|
||||||
"**路径使用规则** (非常重要):",
|
"**路径使用规则** (非常重要):",
|
||||||
"",
|
"",
|
||||||
"- **工作空间内的文件**: 使用相对路径(如 `SOUL.md`、`memory/daily.md`)",
|
"- **工作空间内的文件**: 可以使用相对路径(如 `SOUL.md`、`MEMORY.md`)",
|
||||||
"- **工作空间外的文件**: 必须使用绝对路径(如 `~/project/code.py`、`/etc/config`)",
|
"- **工作空间外的文件**: 必须使用绝对路径(如 `~/project/code.py`、`/etc/config`)",
|
||||||
"- **不确定时**: 先用 `bash pwd` 确认当前目录,或用 `ls .` 查看当前位置",
|
"- **不确定时**: 先用 `bash pwd` 确认当前目录,或用 `ls .` 查看当前位置",
|
||||||
"",
|
"",
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ class AgentStreamExecutor:
|
|||||||
try:
|
try:
|
||||||
while turn < self.max_turns:
|
while turn < self.max_turns:
|
||||||
turn += 1
|
turn += 1
|
||||||
logger.info(f"\n{'='*50} 第 {turn} 轮 {'='*50}")
|
logger.info(f"\n🔄 第 {turn} 轮")
|
||||||
self._emit_event("turn_start", {"turn": turn})
|
self._emit_event("turn_start", {"turn": turn})
|
||||||
|
|
||||||
# Check if memory flush is needed (before calling LLM)
|
# Check if memory flush is needed (before calling LLM)
|
||||||
@@ -156,9 +156,15 @@ class AgentStreamExecutor:
|
|||||||
})
|
})
|
||||||
break
|
break
|
||||||
|
|
||||||
# Log tool calls in compact format
|
# Log tool calls with arguments
|
||||||
tool_names = [tc['name'] for tc in tool_calls]
|
tool_calls_str = []
|
||||||
logger.info(f"🔧 调用工具: {', '.join(tool_names)}")
|
for tc in tool_calls:
|
||||||
|
args_str = ', '.join([f"{k}={v}" for k, v in tc['arguments'].items()])
|
||||||
|
if args_str:
|
||||||
|
tool_calls_str.append(f"{tc['name']}({args_str})")
|
||||||
|
else:
|
||||||
|
tool_calls_str.append(tc['name'])
|
||||||
|
logger.info(f"🔧 {', '.join(tool_calls_str)}")
|
||||||
|
|
||||||
# Execute tools
|
# Execute tools
|
||||||
tool_results = []
|
tool_results = []
|
||||||
@@ -179,13 +185,33 @@ class AgentStreamExecutor:
|
|||||||
logger.info(f" {status_emoji} {tool_call['name']} ({result.get('execution_time', 0):.2f}s): {result_str[:200]}{'...' if len(result_str) > 200 else ''}")
|
logger.info(f" {status_emoji} {tool_call['name']} ({result.get('execution_time', 0):.2f}s): {result_str[:200]}{'...' if len(result_str) > 200 else ''}")
|
||||||
|
|
||||||
# Build tool result block (Claude format)
|
# Build tool result block (Claude format)
|
||||||
# Content should be a string representation of the result
|
# Format content in a way that's easy for LLM to understand
|
||||||
result_content = json.dumps(result, ensure_ascii=False) if not isinstance(result, str) else result
|
is_error = result.get("status") == "error"
|
||||||
tool_result_blocks.append({
|
|
||||||
|
if is_error:
|
||||||
|
# For errors, provide clear error message
|
||||||
|
result_content = f"Error: {result.get('result', 'Unknown error')}"
|
||||||
|
elif isinstance(result.get('result'), dict):
|
||||||
|
# For dict results, use JSON format
|
||||||
|
result_content = json.dumps(result.get('result'), ensure_ascii=False)
|
||||||
|
elif isinstance(result.get('result'), str):
|
||||||
|
# For string results, use directly
|
||||||
|
result_content = result.get('result')
|
||||||
|
else:
|
||||||
|
# Fallback to full JSON
|
||||||
|
result_content = json.dumps(result, ensure_ascii=False)
|
||||||
|
|
||||||
|
tool_result_block = {
|
||||||
"type": "tool_result",
|
"type": "tool_result",
|
||||||
"tool_use_id": tool_call["id"],
|
"tool_use_id": tool_call["id"],
|
||||||
"content": result_content
|
"content": result_content
|
||||||
})
|
}
|
||||||
|
|
||||||
|
# Add is_error field for Claude API (helps model understand failures)
|
||||||
|
if is_error:
|
||||||
|
tool_result_block["is_error"] = True
|
||||||
|
|
||||||
|
tool_result_blocks.append(tool_result_block)
|
||||||
|
|
||||||
# Add tool results to message history as user message (Claude format)
|
# Add tool results to message history as user message (Claude format)
|
||||||
self.messages.append({
|
self.messages.append({
|
||||||
@@ -201,6 +227,11 @@ class AgentStreamExecutor:
|
|||||||
|
|
||||||
if turn >= self.max_turns:
|
if turn >= self.max_turns:
|
||||||
logger.warning(f"⚠️ 已达到最大轮数限制: {self.max_turns}")
|
logger.warning(f"⚠️ 已达到最大轮数限制: {self.max_turns}")
|
||||||
|
if not final_response:
|
||||||
|
final_response = (
|
||||||
|
"抱歉,我在处理你的请求时遇到了一些困难,尝试了多次仍未能完成。"
|
||||||
|
"请尝试简化你的问题,或换一种方式描述。"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"❌ Agent执行错误: {e}")
|
logger.error(f"❌ Agent执行错误: {e}")
|
||||||
@@ -208,14 +239,19 @@ class AgentStreamExecutor:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
logger.info(f"{'='*50} 完成({turn}轮) {'='*50}\n")
|
logger.info(f"🏁 完成({turn}轮)\n")
|
||||||
self._emit_event("agent_end", {"final_response": final_response})
|
self._emit_event("agent_end", {"final_response": final_response})
|
||||||
|
|
||||||
return final_response
|
return final_response
|
||||||
|
|
||||||
def _call_llm_stream(self) -> tuple[str, List[Dict]]:
|
def _call_llm_stream(self, retry_on_empty=True, retry_count=0, max_retries=3) -> tuple[str, List[Dict]]:
|
||||||
"""
|
"""
|
||||||
Call LLM with streaming
|
Call LLM with streaming and automatic retry on errors
|
||||||
|
|
||||||
|
Args:
|
||||||
|
retry_on_empty: Whether to retry once if an empty response is received
|
||||||
|
retry_count: Current retry attempt (internal use)
|
||||||
|
max_retries: Maximum number of retries for API errors
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(response_text, tool_calls)
|
(response_text, tool_calls)
|
||||||
@@ -309,7 +345,28 @@ class AgentStreamExecutor:
|
|||||||
tool_calls_buffer[index]["arguments"] += func["arguments"]
|
tool_calls_buffer[index]["arguments"] += func["arguments"]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"LLM call error: {e}")
|
error_str = str(e).lower()
|
||||||
|
# Check if error is retryable (timeout, connection, rate limit, etc.)
|
||||||
|
is_retryable = any(keyword in error_str for keyword in [
|
||||||
|
'timeout', 'timed out', 'connection', 'network',
|
||||||
|
'rate limit', 'overloaded', 'unavailable', '429', '500', '502', '503', '504'
|
||||||
|
])
|
||||||
|
|
||||||
|
if is_retryable and retry_count < max_retries:
|
||||||
|
wait_time = (retry_count + 1) * 2 # Linear backoff: 2s, 4s, 6s
|
||||||
|
logger.warning(f"⚠️ LLM API error (attempt {retry_count + 1}/{max_retries}): {e}")
|
||||||
|
logger.info(f"🔄 Retrying in {wait_time}s...")
|
||||||
|
time.sleep(wait_time)
|
||||||
|
return self._call_llm_stream(
|
||||||
|
retry_on_empty=retry_on_empty,
|
||||||
|
retry_count=retry_count + 1,
|
||||||
|
max_retries=max_retries
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if retry_count >= max_retries:
|
||||||
|
logger.error(f"❌ LLM API error after {max_retries} retries: {e}")
|
||||||
|
else:
|
||||||
|
logger.error(f"❌ LLM call error (non-retryable): {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Parse tool calls
|
# Parse tool calls
|
||||||
@@ -328,6 +385,21 @@ class AgentStreamExecutor:
|
|||||||
"arguments": arguments
|
"arguments": arguments
|
||||||
})
|
})
|
||||||
|
|
||||||
|
# Check for empty response and retry once if enabled
|
||||||
|
if retry_on_empty and not full_content and not tool_calls:
|
||||||
|
logger.warning(f"⚠️ LLM returned empty response, retrying once...")
|
||||||
|
self._emit_event("message_end", {
|
||||||
|
"content": "",
|
||||||
|
"tool_calls": [],
|
||||||
|
"empty_retry": True
|
||||||
|
})
|
||||||
|
# Retry without retry flag to avoid infinite loop
|
||||||
|
return self._call_llm_stream(
|
||||||
|
retry_on_empty=False,
|
||||||
|
retry_count=retry_count,
|
||||||
|
max_retries=max_retries
|
||||||
|
)
|
||||||
|
|
||||||
# Add assistant message to history (Claude format uses content blocks)
|
# Add assistant message to history (Claude format uses content blocks)
|
||||||
assistant_msg = {"role": "assistant", "content": []}
|
assistant_msg = {"role": "assistant", "content": []}
|
||||||
|
|
||||||
|
|||||||
@@ -255,7 +255,7 @@ class GoogleGeminiBot(Bot):
|
|||||||
gemini_tools = self._convert_tools_to_gemini_rest_format(tools)
|
gemini_tools = self._convert_tools_to_gemini_rest_format(tools)
|
||||||
if gemini_tools:
|
if gemini_tools:
|
||||||
payload["tools"] = gemini_tools
|
payload["tools"] = gemini_tools
|
||||||
logger.info(f"[Gemini] Added {len(tools)} tools to request")
|
logger.debug(f"[Gemini] Added {len(tools)} tools to request")
|
||||||
|
|
||||||
# Make REST API call
|
# Make REST API call
|
||||||
base_url = f"{self.api_base}/v1beta"
|
base_url = f"{self.api_base}/v1beta"
|
||||||
@@ -445,6 +445,9 @@ class GoogleGeminiBot(Bot):
|
|||||||
all_tool_calls = []
|
all_tool_calls = []
|
||||||
has_sent_tool_calls = False
|
has_sent_tool_calls = False
|
||||||
has_content = False # Track if any content was sent
|
has_content = False # Track if any content was sent
|
||||||
|
chunk_count = 0
|
||||||
|
last_finish_reason = None
|
||||||
|
last_safety_ratings = None
|
||||||
|
|
||||||
for line in response.iter_lines():
|
for line in response.iter_lines():
|
||||||
if not line:
|
if not line:
|
||||||
@@ -461,6 +464,7 @@ class GoogleGeminiBot(Bot):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
chunk_data = json.loads(line)
|
chunk_data = json.loads(line)
|
||||||
|
chunk_count += 1
|
||||||
logger.debug(f"[Gemini] Stream chunk: {json.dumps(chunk_data, ensure_ascii=False)[:200]}")
|
logger.debug(f"[Gemini] Stream chunk: {json.dumps(chunk_data, ensure_ascii=False)[:200]}")
|
||||||
|
|
||||||
candidates = chunk_data.get("candidates", [])
|
candidates = chunk_data.get("candidates", [])
|
||||||
@@ -469,6 +473,13 @@ class GoogleGeminiBot(Bot):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
candidate = candidates[0]
|
candidate = candidates[0]
|
||||||
|
|
||||||
|
# 记录 finish_reason 和 safety_ratings
|
||||||
|
if "finishReason" in candidate:
|
||||||
|
last_finish_reason = candidate["finishReason"]
|
||||||
|
if "safetyRatings" in candidate:
|
||||||
|
last_safety_ratings = candidate["safetyRatings"]
|
||||||
|
|
||||||
content = candidate.get("content", {})
|
content = candidate.get("content", {})
|
||||||
parts = content.get("parts", [])
|
parts = content.get("parts", [])
|
||||||
|
|
||||||
@@ -512,7 +523,7 @@ class GoogleGeminiBot(Bot):
|
|||||||
|
|
||||||
# Send tool calls if any were collected
|
# Send tool calls if any were collected
|
||||||
if all_tool_calls and not has_sent_tool_calls:
|
if all_tool_calls and not has_sent_tool_calls:
|
||||||
logger.info(f"[Gemini] Stream detected {len(all_tool_calls)} tool calls")
|
logger.debug(f"[Gemini] Stream detected {len(all_tool_calls)} tool calls")
|
||||||
yield {
|
yield {
|
||||||
"id": f"chatcmpl-{time.time()}",
|
"id": f"chatcmpl-{time.time()}",
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
@@ -526,17 +537,17 @@ class GoogleGeminiBot(Bot):
|
|||||||
}
|
}
|
||||||
has_sent_tool_calls = True
|
has_sent_tool_calls = True
|
||||||
|
|
||||||
# Log summary
|
# Log summary (only if there's something interesting)
|
||||||
logger.info(f"[Gemini] Stream complete: has_content={has_content}, tool_calls={len(all_tool_calls)}")
|
if not has_content and not all_tool_calls:
|
||||||
|
logger.debug(f"[Gemini] Stream complete: has_content={has_content}, tool_calls={len(all_tool_calls)}")
|
||||||
|
elif all_tool_calls:
|
||||||
|
logger.debug(f"[Gemini] Stream complete: {len(all_tool_calls)} tool calls")
|
||||||
|
else:
|
||||||
|
logger.debug(f"[Gemini] Stream complete: text response")
|
||||||
|
|
||||||
# 如果返回空响应,记录详细警告
|
# 如果返回空响应,记录详细警告
|
||||||
if not has_content and not all_tool_calls:
|
if not has_content and not all_tool_calls:
|
||||||
logger.warning(f"[Gemini] ⚠️ Empty response detected!")
|
logger.warning(f"[Gemini] ⚠️ Empty response detected!")
|
||||||
logger.warning(f"[Gemini] Possible reasons:")
|
|
||||||
logger.warning(f" 1. Model couldn't generate response based on context")
|
|
||||||
logger.warning(f" 2. Content blocked by safety filters")
|
|
||||||
logger.warning(f" 3. All previous tool calls failed")
|
|
||||||
logger.warning(f" 4. API error not properly caught")
|
|
||||||
|
|
||||||
# Final chunk
|
# Final chunk
|
||||||
yield {
|
yield {
|
||||||
|
|||||||
Reference in New Issue
Block a user