mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
fix: prioritize using a custom master model for vision
This commit is contained in:
@@ -1,14 +1,15 @@
|
||||
"""
|
||||
Vision tool - Analyze images using OpenAI-compatible Vision API.
|
||||
Supports local files (auto base64-encoded) and HTTP URLs.
|
||||
Providers: OpenAI (preferred) > LinkAI (fallback).
|
||||
Providers are tried in priority order with automatic fallback on failure.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
@@ -30,6 +31,24 @@ SUPPORTED_EXTENSIONS = {
|
||||
}
|
||||
|
||||
|
||||
OPENAI_COMPATIBLE_BOT_TYPES = {"openai", "openAI", "chatGPT"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisionProvider:
|
||||
"""A single Vision API provider configuration."""
|
||||
name: str
|
||||
api_key: str
|
||||
api_base: str
|
||||
extra_headers: dict = field(default_factory=dict)
|
||||
model_override: Optional[str] = None
|
||||
|
||||
|
||||
class VisionAPIError(Exception):
|
||||
"""Raised when a Vision API call fails and should trigger fallback."""
|
||||
pass
|
||||
|
||||
|
||||
class Vision(BaseTool):
|
||||
"""Analyze images using OpenAI-compatible Vision API"""
|
||||
|
||||
@@ -82,8 +101,8 @@ class Vision(BaseTool):
|
||||
if not question:
|
||||
return ToolResult.fail("Error: 'question' parameter is required")
|
||||
|
||||
api_key, api_base, extra_headers = self._resolve_provider()
|
||||
if not api_key:
|
||||
providers = self._resolve_providers()
|
||||
if not providers:
|
||||
return ToolResult.fail(
|
||||
"Error: No API key configured for Vision.\n"
|
||||
"Please configure one of the following using env_config tool:\n"
|
||||
@@ -97,36 +116,91 @@ class Vision(BaseTool):
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error: {e}")
|
||||
|
||||
try:
|
||||
return self._call_api(api_key, api_base, model, question, image_content, extra_headers)
|
||||
except requests.Timeout:
|
||||
return ToolResult.fail(f"Error: Vision API request timed out after {DEFAULT_TIMEOUT}s")
|
||||
except requests.ConnectionError:
|
||||
return ToolResult.fail("Error: Failed to connect to Vision API")
|
||||
except Exception as e:
|
||||
logger.error(f"[Vision] Unexpected error: {e}", exc_info=True)
|
||||
return ToolResult.fail(f"Error: Vision API call failed - {e}")
|
||||
return self._call_with_fallback(providers, model, question, image_content)
|
||||
|
||||
def _resolve_provider(self) -> Tuple[Optional[str], str, dict]:
|
||||
"""Resolve API key, base URL and extra headers. Priority: conf() > env vars."""
|
||||
def _call_with_fallback(self, providers: List[VisionProvider], model: str,
|
||||
question: str, image_content: dict) -> ToolResult:
|
||||
"""Try each provider in order; fall back to the next one on failure."""
|
||||
errors: List[str] = []
|
||||
for i, provider in enumerate(providers):
|
||||
use_model = provider.model_override or model
|
||||
try:
|
||||
logger.debug(f"[Vision] Trying provider '{provider.name}' "
|
||||
f"with model '{use_model}' ({i + 1}/{len(providers)})")
|
||||
return self._call_api(provider, use_model, question, image_content)
|
||||
except VisionAPIError as e:
|
||||
errors.append(f"[{provider.name}/{use_model}] {e}")
|
||||
logger.warning(f"[Vision] Provider '{provider.name}' failed: {e}")
|
||||
except requests.Timeout:
|
||||
errors.append(f"[{provider.name}/{use_model}] Request timed out after {DEFAULT_TIMEOUT}s")
|
||||
logger.warning(f"[Vision] Provider '{provider.name}' timed out")
|
||||
except requests.ConnectionError:
|
||||
errors.append(f"[{provider.name}/{use_model}] Connection failed")
|
||||
logger.warning(f"[Vision] Provider '{provider.name}' connection failed")
|
||||
except Exception as e:
|
||||
errors.append(f"[{provider.name}/{use_model}] {e}")
|
||||
logger.error(f"[Vision] Provider '{provider.name}' unexpected error: {e}", exc_info=True)
|
||||
|
||||
return ToolResult.fail(
|
||||
"Error: All Vision API providers failed.\n" + "\n".join(f" - {err}" for err in errors)
|
||||
)
|
||||
|
||||
def _resolve_providers(self) -> List[VisionProvider]:
|
||||
"""
|
||||
Build an ordered list of available providers.
|
||||
Each provider builder returns a VisionProvider or None.
|
||||
To add a new provider, append a builder method to _PROVIDER_BUILDERS.
|
||||
"""
|
||||
providers: List[VisionProvider] = []
|
||||
for builder in self._PROVIDER_BUILDERS:
|
||||
provider = builder(self)
|
||||
if provider:
|
||||
providers.append(provider)
|
||||
return providers
|
||||
|
||||
def _build_custom_model_provider(self) -> Optional[VisionProvider]:
|
||||
"""
|
||||
When bot_type is openai-compatible and a custom model is configured,
|
||||
try the user's own model first — it may already support multimodal input.
|
||||
"""
|
||||
bot_type = conf().get("bot_type", "")
|
||||
if bot_type not in OPENAI_COMPATIBLE_BOT_TYPES:
|
||||
return None
|
||||
custom_model = conf().get("model", "")
|
||||
if not custom_model or custom_model == DEFAULT_MODEL:
|
||||
return None
|
||||
api_key = conf().get("open_ai_api_key") or os.environ.get("OPENAI_API_KEY")
|
||||
if api_key:
|
||||
api_base = (conf().get("open_ai_api_base") or os.environ.get("OPENAI_API_BASE", "")).rstrip("/") \
|
||||
or "https://api.openai.com/v1"
|
||||
return api_key, self._ensure_v1(api_base), {}
|
||||
if not api_key:
|
||||
return None
|
||||
api_base = (conf().get("open_ai_api_base") or os.environ.get("OPENAI_API_BASE", "")).rstrip("/") \
|
||||
or "https://api.openai.com/v1"
|
||||
return VisionProvider(
|
||||
name="CustomModel", api_key=api_key, api_base=self._ensure_v1(api_base),
|
||||
model_override=custom_model,
|
||||
)
|
||||
|
||||
def _build_openai_provider(self) -> Optional[VisionProvider]:
|
||||
api_key = conf().get("open_ai_api_key") or os.environ.get("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
return None
|
||||
api_base = (conf().get("open_ai_api_base") or os.environ.get("OPENAI_API_BASE", "")).rstrip("/") \
|
||||
or "https://api.openai.com/v1"
|
||||
return VisionProvider(name="OpenAI", api_key=api_key, api_base=self._ensure_v1(api_base))
|
||||
|
||||
def _build_linkai_provider(self) -> Optional[VisionProvider]:
|
||||
api_key = conf().get("linkai_api_key") or os.environ.get("LINKAI_API_KEY")
|
||||
if api_key:
|
||||
api_base = (conf().get("linkai_api_base") or os.environ.get("LINKAI_API_BASE", "")).rstrip("/") \
|
||||
or "https://api.link-ai.tech"
|
||||
logger.debug("[Vision] Using LinkAI API (OPENAI_API_KEY not set)")
|
||||
from common.utils import get_cloud_headers
|
||||
extra = get_cloud_headers(api_key)
|
||||
extra.pop("Authorization", None)
|
||||
extra.pop("Content-Type", None)
|
||||
return api_key, self._ensure_v1(api_base), extra
|
||||
if not api_key:
|
||||
return None
|
||||
api_base = (conf().get("linkai_api_base") or os.environ.get("LINKAI_API_BASE", "")).rstrip("/") \
|
||||
or "https://api.link-ai.tech"
|
||||
from common.utils import get_cloud_headers
|
||||
extra = get_cloud_headers(api_key)
|
||||
extra.pop("Authorization", None)
|
||||
extra.pop("Content-Type", None)
|
||||
return VisionProvider(name="LinkAI", api_key=api_key, api_base=self._ensure_v1(api_base),
|
||||
extra_headers=extra)
|
||||
|
||||
return None, "", {}
|
||||
_PROVIDER_BUILDERS = [_build_custom_model_provider, _build_openai_provider, _build_linkai_provider]
|
||||
|
||||
@staticmethod
|
||||
def _ensure_v1(api_base: str) -> str:
|
||||
@@ -220,8 +294,13 @@ class Vision(BaseTool):
|
||||
os.remove(tmp.name)
|
||||
return path
|
||||
|
||||
def _call_api(self, api_key: str, api_base: str, model: str,
|
||||
question: str, image_content: dict, extra_headers: dict = None) -> ToolResult:
|
||||
def _call_api(self, provider: VisionProvider, model: str,
|
||||
question: str, image_content: dict) -> ToolResult:
|
||||
"""
|
||||
Call a single provider's Vision API.
|
||||
Raises VisionAPIError on recoverable failures so the caller can try
|
||||
the next provider.
|
||||
"""
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [
|
||||
@@ -233,34 +312,30 @@ class Vision(BaseTool):
|
||||
],
|
||||
}
|
||||
],
|
||||
"max_tokens": MAX_TOKENS,
|
||||
"max_completion_tokens": MAX_TOKENS,
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Authorization": f"Bearer {provider.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
**(extra_headers or {}),
|
||||
**provider.extra_headers,
|
||||
}
|
||||
|
||||
resp = requests.post(
|
||||
f"{api_base}/chat/completions",
|
||||
f"{provider.api_base}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return ToolResult.fail("Error: Invalid API key. Please check your configuration.")
|
||||
if resp.status_code == 429:
|
||||
return ToolResult.fail("Error: API rate limit reached. Please try again later.")
|
||||
if resp.status_code != 200:
|
||||
return ToolResult.fail(f"Error: Vision API returned HTTP {resp.status_code}: {resp.text[:200]}")
|
||||
raise VisionAPIError(f"HTTP {resp.status_code}: {resp.text[:200]}")
|
||||
|
||||
data = resp.json()
|
||||
|
||||
if "error" in data:
|
||||
msg = data["error"].get("message", "Unknown API error")
|
||||
return ToolResult.fail(f"Error: Vision API error - {msg}")
|
||||
raise VisionAPIError(f"API error - {msg}")
|
||||
|
||||
content = ""
|
||||
choices = data.get("choices", [])
|
||||
@@ -270,6 +345,7 @@ class Vision(BaseTool):
|
||||
usage = data.get("usage", {})
|
||||
result = {
|
||||
"model": model,
|
||||
"provider": provider.name,
|
||||
"content": content,
|
||||
"usage": {
|
||||
"prompt_tokens": usage.get("prompt_tokens", 0),
|
||||
|
||||
Reference in New Issue
Block a user