fix: prioritize using a custom master model for vision

This commit is contained in:
zhayujie
2026-04-08 15:16:59 +08:00
parent 9525dc7584
commit ad86deb014

View File

@@ -1,14 +1,15 @@
"""
Vision tool - Analyze images using OpenAI-compatible Vision API.
Supports local files (auto base64-encoded) and HTTP URLs.
Providers: OpenAI (preferred) > LinkAI (fallback).
Providers are tried in priority order with automatic fallback on failure.
"""
import base64
import os
import subprocess
import tempfile
from typing import Any, Dict, Optional, Tuple
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import requests
@@ -30,6 +31,24 @@ SUPPORTED_EXTENSIONS = {
}
OPENAI_COMPATIBLE_BOT_TYPES = {"openai", "openAI", "chatGPT"}
@dataclass
class VisionProvider:
"""A single Vision API provider configuration."""
name: str
api_key: str
api_base: str
extra_headers: dict = field(default_factory=dict)
model_override: Optional[str] = None
class VisionAPIError(Exception):
"""Raised when a Vision API call fails and should trigger fallback."""
pass
class Vision(BaseTool):
"""Analyze images using OpenAI-compatible Vision API"""
@@ -82,8 +101,8 @@ class Vision(BaseTool):
if not question:
return ToolResult.fail("Error: 'question' parameter is required")
api_key, api_base, extra_headers = self._resolve_provider()
if not api_key:
providers = self._resolve_providers()
if not providers:
return ToolResult.fail(
"Error: No API key configured for Vision.\n"
"Please configure one of the following using env_config tool:\n"
@@ -97,36 +116,91 @@ class Vision(BaseTool):
except Exception as e:
return ToolResult.fail(f"Error: {e}")
try:
return self._call_api(api_key, api_base, model, question, image_content, extra_headers)
except requests.Timeout:
return ToolResult.fail(f"Error: Vision API request timed out after {DEFAULT_TIMEOUT}s")
except requests.ConnectionError:
return ToolResult.fail("Error: Failed to connect to Vision API")
except Exception as e:
logger.error(f"[Vision] Unexpected error: {e}", exc_info=True)
return ToolResult.fail(f"Error: Vision API call failed - {e}")
return self._call_with_fallback(providers, model, question, image_content)
def _resolve_provider(self) -> Tuple[Optional[str], str, dict]:
"""Resolve API key, base URL and extra headers. Priority: conf() > env vars."""
def _call_with_fallback(self, providers: List[VisionProvider], model: str,
question: str, image_content: dict) -> ToolResult:
"""Try each provider in order; fall back to the next one on failure."""
errors: List[str] = []
for i, provider in enumerate(providers):
use_model = provider.model_override or model
try:
logger.debug(f"[Vision] Trying provider '{provider.name}' "
f"with model '{use_model}' ({i + 1}/{len(providers)})")
return self._call_api(provider, use_model, question, image_content)
except VisionAPIError as e:
errors.append(f"[{provider.name}/{use_model}] {e}")
logger.warning(f"[Vision] Provider '{provider.name}' failed: {e}")
except requests.Timeout:
errors.append(f"[{provider.name}/{use_model}] Request timed out after {DEFAULT_TIMEOUT}s")
logger.warning(f"[Vision] Provider '{provider.name}' timed out")
except requests.ConnectionError:
errors.append(f"[{provider.name}/{use_model}] Connection failed")
logger.warning(f"[Vision] Provider '{provider.name}' connection failed")
except Exception as e:
errors.append(f"[{provider.name}/{use_model}] {e}")
logger.error(f"[Vision] Provider '{provider.name}' unexpected error: {e}", exc_info=True)
return ToolResult.fail(
"Error: All Vision API providers failed.\n" + "\n".join(f" - {err}" for err in errors)
)
def _resolve_providers(self) -> List[VisionProvider]:
"""
Build an ordered list of available providers.
Each provider builder returns a VisionProvider or None.
To add a new provider, append a builder method to _PROVIDER_BUILDERS.
"""
providers: List[VisionProvider] = []
for builder in self._PROVIDER_BUILDERS:
provider = builder(self)
if provider:
providers.append(provider)
return providers
def _build_custom_model_provider(self) -> Optional[VisionProvider]:
"""
When bot_type is openai-compatible and a custom model is configured,
try the user's own model first — it may already support multimodal input.
"""
bot_type = conf().get("bot_type", "")
if bot_type not in OPENAI_COMPATIBLE_BOT_TYPES:
return None
custom_model = conf().get("model", "")
if not custom_model or custom_model == DEFAULT_MODEL:
return None
api_key = conf().get("open_ai_api_key") or os.environ.get("OPENAI_API_KEY")
if api_key:
api_base = (conf().get("open_ai_api_base") or os.environ.get("OPENAI_API_BASE", "")).rstrip("/") \
or "https://api.openai.com/v1"
return api_key, self._ensure_v1(api_base), {}
if not api_key:
return None
api_base = (conf().get("open_ai_api_base") or os.environ.get("OPENAI_API_BASE", "")).rstrip("/") \
or "https://api.openai.com/v1"
return VisionProvider(
name="CustomModel", api_key=api_key, api_base=self._ensure_v1(api_base),
model_override=custom_model,
)
def _build_openai_provider(self) -> Optional[VisionProvider]:
api_key = conf().get("open_ai_api_key") or os.environ.get("OPENAI_API_KEY")
if not api_key:
return None
api_base = (conf().get("open_ai_api_base") or os.environ.get("OPENAI_API_BASE", "")).rstrip("/") \
or "https://api.openai.com/v1"
return VisionProvider(name="OpenAI", api_key=api_key, api_base=self._ensure_v1(api_base))
def _build_linkai_provider(self) -> Optional[VisionProvider]:
api_key = conf().get("linkai_api_key") or os.environ.get("LINKAI_API_KEY")
if api_key:
api_base = (conf().get("linkai_api_base") or os.environ.get("LINKAI_API_BASE", "")).rstrip("/") \
or "https://api.link-ai.tech"
logger.debug("[Vision] Using LinkAI API (OPENAI_API_KEY not set)")
from common.utils import get_cloud_headers
extra = get_cloud_headers(api_key)
extra.pop("Authorization", None)
extra.pop("Content-Type", None)
return api_key, self._ensure_v1(api_base), extra
if not api_key:
return None
api_base = (conf().get("linkai_api_base") or os.environ.get("LINKAI_API_BASE", "")).rstrip("/") \
or "https://api.link-ai.tech"
from common.utils import get_cloud_headers
extra = get_cloud_headers(api_key)
extra.pop("Authorization", None)
extra.pop("Content-Type", None)
return VisionProvider(name="LinkAI", api_key=api_key, api_base=self._ensure_v1(api_base),
extra_headers=extra)
return None, "", {}
_PROVIDER_BUILDERS = [_build_custom_model_provider, _build_openai_provider, _build_linkai_provider]
@staticmethod
def _ensure_v1(api_base: str) -> str:
@@ -220,8 +294,13 @@ class Vision(BaseTool):
os.remove(tmp.name)
return path
def _call_api(self, api_key: str, api_base: str, model: str,
question: str, image_content: dict, extra_headers: dict = None) -> ToolResult:
def _call_api(self, provider: VisionProvider, model: str,
question: str, image_content: dict) -> ToolResult:
"""
Call a single provider's Vision API.
Raises VisionAPIError on recoverable failures so the caller can try
the next provider.
"""
payload = {
"model": model,
"messages": [
@@ -233,34 +312,30 @@ class Vision(BaseTool):
],
}
],
"max_tokens": MAX_TOKENS,
"max_completion_tokens": MAX_TOKENS,
}
headers = {
"Authorization": f"Bearer {api_key}",
"Authorization": f"Bearer {provider.api_key}",
"Content-Type": "application/json",
**(extra_headers or {}),
**provider.extra_headers,
}
resp = requests.post(
f"{api_base}/chat/completions",
f"{provider.api_base}/chat/completions",
headers=headers,
json=payload,
timeout=DEFAULT_TIMEOUT,
)
if resp.status_code == 401:
return ToolResult.fail("Error: Invalid API key. Please check your configuration.")
if resp.status_code == 429:
return ToolResult.fail("Error: API rate limit reached. Please try again later.")
if resp.status_code != 200:
return ToolResult.fail(f"Error: Vision API returned HTTP {resp.status_code}: {resp.text[:200]}")
raise VisionAPIError(f"HTTP {resp.status_code}: {resp.text[:200]}")
data = resp.json()
if "error" in data:
msg = data["error"].get("message", "Unknown API error")
return ToolResult.fail(f"Error: Vision API error - {msg}")
raise VisionAPIError(f"API error - {msg}")
content = ""
choices = data.get("choices", [])
@@ -270,6 +345,7 @@ class Vision(BaseTool):
usage = data.get("usage", {})
result = {
"model": model,
"provider": provider.name,
"content": content,
"usage": {
"prompt_tokens": usage.get("prompt_tokens", 0),