diff --git a/agent/tools/vision/vision.py b/agent/tools/vision/vision.py index 72cd7cbb..3f8ad308 100644 --- a/agent/tools/vision/vision.py +++ b/agent/tools/vision/vision.py @@ -1,14 +1,15 @@ """ Vision tool - Analyze images using OpenAI-compatible Vision API. Supports local files (auto base64-encoded) and HTTP URLs. -Providers: OpenAI (preferred) > LinkAI (fallback). +Providers are tried in priority order with automatic fallback on failure. """ import base64 import os import subprocess import tempfile -from typing import Any, Dict, Optional, Tuple +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional import requests @@ -30,6 +31,24 @@ SUPPORTED_EXTENSIONS = { } +OPENAI_COMPATIBLE_BOT_TYPES = {"openai", "openAI", "chatGPT"} + + +@dataclass +class VisionProvider: + """A single Vision API provider configuration.""" + name: str + api_key: str + api_base: str + extra_headers: dict = field(default_factory=dict) + model_override: Optional[str] = None + + +class VisionAPIError(Exception): + """Raised when a Vision API call fails and should trigger fallback.""" + pass + + class Vision(BaseTool): """Analyze images using OpenAI-compatible Vision API""" @@ -82,8 +101,8 @@ class Vision(BaseTool): if not question: return ToolResult.fail("Error: 'question' parameter is required") - api_key, api_base, extra_headers = self._resolve_provider() - if not api_key: + providers = self._resolve_providers() + if not providers: return ToolResult.fail( "Error: No API key configured for Vision.\n" "Please configure one of the following using env_config tool:\n" @@ -97,36 +116,91 @@ class Vision(BaseTool): except Exception as e: return ToolResult.fail(f"Error: {e}") - try: - return self._call_api(api_key, api_base, model, question, image_content, extra_headers) - except requests.Timeout: - return ToolResult.fail(f"Error: Vision API request timed out after {DEFAULT_TIMEOUT}s") - except requests.ConnectionError: - return ToolResult.fail("Error: Failed to connect to Vision API") - except Exception as e: - logger.error(f"[Vision] Unexpected error: {e}", exc_info=True) - return ToolResult.fail(f"Error: Vision API call failed - {e}") + return self._call_with_fallback(providers, model, question, image_content) - def _resolve_provider(self) -> Tuple[Optional[str], str, dict]: - """Resolve API key, base URL and extra headers. Priority: conf() > env vars.""" + def _call_with_fallback(self, providers: List[VisionProvider], model: str, + question: str, image_content: dict) -> ToolResult: + """Try each provider in order; fall back to the next one on failure.""" + errors: List[str] = [] + for i, provider in enumerate(providers): + use_model = provider.model_override or model + try: + logger.debug(f"[Vision] Trying provider '{provider.name}' " + f"with model '{use_model}' ({i + 1}/{len(providers)})") + return self._call_api(provider, use_model, question, image_content) + except VisionAPIError as e: + errors.append(f"[{provider.name}/{use_model}] {e}") + logger.warning(f"[Vision] Provider '{provider.name}' failed: {e}") + except requests.Timeout: + errors.append(f"[{provider.name}/{use_model}] Request timed out after {DEFAULT_TIMEOUT}s") + logger.warning(f"[Vision] Provider '{provider.name}' timed out") + except requests.ConnectionError: + errors.append(f"[{provider.name}/{use_model}] Connection failed") + logger.warning(f"[Vision] Provider '{provider.name}' connection failed") + except Exception as e: + errors.append(f"[{provider.name}/{use_model}] {e}") + logger.error(f"[Vision] Provider '{provider.name}' unexpected error: {e}", exc_info=True) + + return ToolResult.fail( + "Error: All Vision API providers failed.\n" + "\n".join(f" - {err}" for err in errors) + ) + + def _resolve_providers(self) -> List[VisionProvider]: + """ + Build an ordered list of available providers. + Each provider builder returns a VisionProvider or None. + To add a new provider, append a builder method to _PROVIDER_BUILDERS. + """ + providers: List[VisionProvider] = [] + for builder in self._PROVIDER_BUILDERS: + provider = builder(self) + if provider: + providers.append(provider) + return providers + + def _build_custom_model_provider(self) -> Optional[VisionProvider]: + """ + When bot_type is openai-compatible and a custom model is configured, + try the user's own model first — it may already support multimodal input. + """ + bot_type = conf().get("bot_type", "") + if bot_type not in OPENAI_COMPATIBLE_BOT_TYPES: + return None + custom_model = conf().get("model", "") + if not custom_model or custom_model == DEFAULT_MODEL: + return None api_key = conf().get("open_ai_api_key") or os.environ.get("OPENAI_API_KEY") - if api_key: - api_base = (conf().get("open_ai_api_base") or os.environ.get("OPENAI_API_BASE", "")).rstrip("/") \ - or "https://api.openai.com/v1" - return api_key, self._ensure_v1(api_base), {} + if not api_key: + return None + api_base = (conf().get("open_ai_api_base") or os.environ.get("OPENAI_API_BASE", "")).rstrip("/") \ + or "https://api.openai.com/v1" + return VisionProvider( + name="CustomModel", api_key=api_key, api_base=self._ensure_v1(api_base), + model_override=custom_model, + ) + def _build_openai_provider(self) -> Optional[VisionProvider]: + api_key = conf().get("open_ai_api_key") or os.environ.get("OPENAI_API_KEY") + if not api_key: + return None + api_base = (conf().get("open_ai_api_base") or os.environ.get("OPENAI_API_BASE", "")).rstrip("/") \ + or "https://api.openai.com/v1" + return VisionProvider(name="OpenAI", api_key=api_key, api_base=self._ensure_v1(api_base)) + + def _build_linkai_provider(self) -> Optional[VisionProvider]: api_key = conf().get("linkai_api_key") or os.environ.get("LINKAI_API_KEY") - if api_key: - api_base = (conf().get("linkai_api_base") or os.environ.get("LINKAI_API_BASE", "")).rstrip("/") \ - or "https://api.link-ai.tech" - logger.debug("[Vision] Using LinkAI API (OPENAI_API_KEY not set)") - from common.utils import get_cloud_headers - extra = get_cloud_headers(api_key) - extra.pop("Authorization", None) - extra.pop("Content-Type", None) - return api_key, self._ensure_v1(api_base), extra + if not api_key: + return None + api_base = (conf().get("linkai_api_base") or os.environ.get("LINKAI_API_BASE", "")).rstrip("/") \ + or "https://api.link-ai.tech" + from common.utils import get_cloud_headers + extra = get_cloud_headers(api_key) + extra.pop("Authorization", None) + extra.pop("Content-Type", None) + return VisionProvider(name="LinkAI", api_key=api_key, api_base=self._ensure_v1(api_base), + extra_headers=extra) - return None, "", {} + _PROVIDER_BUILDERS = [_build_custom_model_provider, _build_openai_provider, _build_linkai_provider] @staticmethod def _ensure_v1(api_base: str) -> str: @@ -220,8 +294,13 @@ class Vision(BaseTool): os.remove(tmp.name) return path - def _call_api(self, api_key: str, api_base: str, model: str, - question: str, image_content: dict, extra_headers: dict = None) -> ToolResult: + def _call_api(self, provider: VisionProvider, model: str, + question: str, image_content: dict) -> ToolResult: + """ + Call a single provider's Vision API. + Raises VisionAPIError on recoverable failures so the caller can try + the next provider. + """ payload = { "model": model, "messages": [ @@ -233,34 +312,30 @@ class Vision(BaseTool): ], } ], - "max_tokens": MAX_TOKENS, + "max_completion_tokens": MAX_TOKENS, } headers = { - "Authorization": f"Bearer {api_key}", + "Authorization": f"Bearer {provider.api_key}", "Content-Type": "application/json", - **(extra_headers or {}), + **provider.extra_headers, } resp = requests.post( - f"{api_base}/chat/completions", + f"{provider.api_base}/chat/completions", headers=headers, json=payload, timeout=DEFAULT_TIMEOUT, ) - if resp.status_code == 401: - return ToolResult.fail("Error: Invalid API key. Please check your configuration.") - if resp.status_code == 429: - return ToolResult.fail("Error: API rate limit reached. Please try again later.") if resp.status_code != 200: - return ToolResult.fail(f"Error: Vision API returned HTTP {resp.status_code}: {resp.text[:200]}") + raise VisionAPIError(f"HTTP {resp.status_code}: {resp.text[:200]}") data = resp.json() if "error" in data: msg = data["error"].get("message", "Unknown API error") - return ToolResult.fail(f"Error: Vision API error - {msg}") + raise VisionAPIError(f"API error - {msg}") content = "" choices = data.get("choices", []) @@ -270,6 +345,7 @@ class Vision(BaseTool): usage = data.get("usage", {}) result = { "model": model, + "provider": provider.name, "content": content, "usage": { "prompt_tokens": usage.get("prompt_tokens", 0),