diff --git a/agent/tools/vision/vision.py b/agent/tools/vision/vision.py index ce2477a7..0c8e48ba 100644 --- a/agent/tools/vision/vision.py +++ b/agent/tools/vision/vision.py @@ -53,6 +53,7 @@ _DISCOVERABLE_MODELS = [ ("ark_api_key", const.DOUBAO, const.DOUBAO_SEED_2_PRO, "Doubao"), ("dashscope_api_key", const.QWEN_DASHSCOPE, const.QWEN36_PLUS, "DashScope"), ("claude_api_key", const.CLAUDEAPI, const.CLAUDE_4_6_SONNET, "Claude"), + ("qianfan_api_key", const.QIANFAN, const.ERNIE_45_TURBO_VL_PREVIEW, "Qianfan"), ("gemini_api_key", const.GEMINI, const.GEMINI_31_FLASH_LITE_PRE, "Gemini"), ("zhipu_ai_api_key", const.ZHIPU_AI, const.GLM_4_7, "ZhipuAI"), ("minimax_api_key", const.MiniMax, const.MINIMAX_M2_7, "MiniMax"), @@ -67,6 +68,7 @@ _MODEL_PREFIX_TO_PROVIDER = [ ("moonshot-", "Moonshot"), ("qwen", "DashScope"), # qwen-*, qwen3-*, qwen3.6-*, etc. ("claude-", "Claude"), + ("ernie-", "Qianfan"), ("gemini-", "Gemini"), ("glm-", "ZhipuAI"), ("minimax-", "MiniMax"), @@ -140,7 +142,7 @@ class Vision(BaseTool): "Error: No model available for Vision.\n" "The main model does not support vision and no other API keys are configured.\n" "Options:\n" - " 1. Switch to a multimodal model (e.g. qwen3.6-plus, claude-sonnet-4-6, gemini-2.0-flash)\n" + " 1. Switch to a multimodal model (e.g. ernie-4.5-turbo-vl-preview, qwen3.6-plus, claude-sonnet-4-6, gemini-2.0-flash)\n" " 2. Configure OPENAI_API_KEY: env_config(action=\"set\", key=\"OPENAI_API_KEY\", value=\"your-key\")\n" " 3. Configure LINKAI_API_KEY: env_config(action=\"set\", key=\"LINKAI_API_KEY\", value=\"your-key\")" ) diff --git a/tests/test_qianfan_provider.py b/tests/test_qianfan_provider.py index d97211c0..2e51224a 100644 --- a/tests/test_qianfan_provider.py +++ b/tests/test_qianfan_provider.py @@ -360,6 +360,82 @@ class TestQianfanSurfaces(unittest.TestCase): self.assertIn("const.QIANFAN", godcmd_source) +class TestQianfanVisionTool(unittest.TestCase): + def _fake_conf(self, values=None): + data = { + "model": "deepseek-v4-flash", + "qianfan_api_key": "", + "qianfan_api_base": "https://qianfan.baidubce.com/v2", + "open_ai_api_key": "", + "linkai_api_key": "", + "use_linkai": False, + "tool": {}, + } + if values: + data.update(values) + fake_conf = MagicMock() + fake_conf.get.side_effect = lambda key, default=None: data.get(key, default) + return fake_conf + + def test_vision_auto_discovers_qianfan_when_key_configured(self): + fake_conf = self._fake_conf({"qianfan_api_key": "test-qianfan-key"}) + fake_bot = MagicMock() + fake_bot.call_vision = MagicMock() + + with patch("agent.tools.vision.vision.conf", return_value=fake_conf): + with patch("models.bot_factory.create_bot", return_value=fake_bot) as create_bot: + from agent.tools.vision.vision import Vision + from common import const + + tool = Vision() + tool.model = None + providers = tool._resolve_providers() + + self.assertEqual(providers[0].name, "Qianfan") + self.assertEqual(providers[0].model_override, const.ERNIE_45_TURBO_VL_PREVIEW) + self.assertTrue(providers[0].use_bot) + create_bot.assert_called_with(const.QIANFAN) + + def test_vision_routes_ernie_model_override_to_qianfan(self): + fake_conf = self._fake_conf({ + "qianfan_api_key": "test-qianfan-key", + "tool": {"vision": {"model": "ernie-4.5-vl-28b-a3b"}}, + }) + fake_bot = MagicMock() + fake_bot.call_vision = MagicMock() + + with patch("agent.tools.vision.vision.conf", return_value=fake_conf): + with patch("models.bot_factory.create_bot", return_value=fake_bot): + from agent.tools.vision.vision import Vision + + tool = Vision() + tool.model = None + providers = tool._resolve_providers() + + self.assertEqual(providers[0].name, "Qianfan") + self.assertEqual(providers[0].model_override, "ernie-4.5-vl-28b-a3b") + + def test_vision_main_model_uses_qianfan_when_configured_model_is_ernie(self): + fake_conf = self._fake_conf({"model": "ernie-4.5-vl-28b-a3b"}) + from common import const + + fake_model = MagicMock() + fake_model._resolve_bot_type.return_value = const.QIANFAN + fake_model.bot = MagicMock() + fake_model.bot.supports_vision = True + fake_model.bot.call_vision = MagicMock() + + with patch("agent.tools.vision.vision.conf", return_value=fake_conf): + from agent.tools.vision.vision import Vision + + tool = Vision() + tool.model = fake_model + providers = tool._resolve_providers() + + self.assertEqual(providers[0].name, "MainModel") + self.assertEqual(providers[0].model_override, "ernie-4.5-vl-28b-a3b") + + class TestQianfanDocs(unittest.TestCase): def _read(self, relative_path): root = os.path.join(os.path.dirname(__file__), "..")