mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
feat(qianfan): scope vision support to multimodal models
This commit is contained in:
@@ -18,7 +18,7 @@ class TestQianfanConstantsAndRouting(unittest.TestCase):
|
||||
|
||||
self.assertEqual(const.ERNIE_45_TURBO_128K, "ernie-4.5-turbo-128k")
|
||||
self.assertEqual(const.ERNIE_45_TURBO_32K, "ernie-4.5-turbo-32k")
|
||||
self.assertEqual(const.ERNIE_X1_TURBO_32K, "ernie-x1-turbo-32k")
|
||||
self.assertEqual(const.ERNIE_X1_1, "ernie-x1.1")
|
||||
self.assertEqual(
|
||||
const.ERNIE_45_TURBO_VL,
|
||||
"ernie-4.5-turbo-vl",
|
||||
@@ -30,7 +30,7 @@ class TestQianfanConstantsAndRouting(unittest.TestCase):
|
||||
self.assertIn(const.QIANFAN, const.MODEL_LIST)
|
||||
self.assertIn(const.ERNIE_45_TURBO_128K, const.MODEL_LIST)
|
||||
self.assertIn(const.ERNIE_45_TURBO_32K, const.MODEL_LIST)
|
||||
self.assertIn(const.ERNIE_X1_TURBO_32K, const.MODEL_LIST)
|
||||
self.assertIn(const.ERNIE_X1_1, const.MODEL_LIST)
|
||||
self.assertIn(const.ERNIE_45_TURBO_VL, const.MODEL_LIST)
|
||||
self.assertIn(const.ERNIE_45_TURBO_VL_32K, const.MODEL_LIST)
|
||||
|
||||
@@ -223,15 +223,31 @@ class TestQianfanBot(unittest.TestCase):
|
||||
self.assertEqual(result["content"], "请求失败:bad gateway text")
|
||||
post.assert_called_once()
|
||||
|
||||
def test_qianfan_bot_supports_vision(self):
|
||||
fake_conf = self._fake_conf()
|
||||
with patch("models.qianfan.qianfan_bot.conf", return_value=fake_conf):
|
||||
with patch("models.qianfan.qianfan_bot.SessionManager"):
|
||||
from models.qianfan.qianfan_bot import QianfanBot
|
||||
def test_qianfan_bot_supports_vision_for_multimodal_models(self):
|
||||
for model in ("ernie-5.0", "ernie-x1.1", "ernie-4.5-turbo-vl", "ernie-4.5-turbo-vl-32k"):
|
||||
fake_conf = self._fake_conf({"model": model})
|
||||
with patch("models.qianfan.qianfan_bot.conf", return_value=fake_conf):
|
||||
with patch("models.qianfan.qianfan_bot.SessionManager"):
|
||||
from models.qianfan.qianfan_bot import QianfanBot
|
||||
|
||||
bot = QianfanBot()
|
||||
bot = QianfanBot()
|
||||
self.assertTrue(
|
||||
bot.supports_vision,
|
||||
msg=f"{model} should be marked as multimodal",
|
||||
)
|
||||
|
||||
self.assertTrue(bot.supports_vision)
|
||||
def test_qianfan_bot_does_not_advertise_vision_for_text_only_models(self):
    """Text-only ERNIE models must not be reported as vision-capable.

    Each model is checked in its own ``subTest`` so that a failure for
    one model does not prevent the remaining models from being checked.
    """
    text_only_models = ("ernie-4.5-turbo-128k", "ernie-4.5-turbo-32k")
    for model in text_only_models:
        with self.subTest(model=model):
            fake_conf = self._fake_conf({"model": model})
            # Build the bot against the fake configuration and a mocked
            # SessionManager; the import happens while the patches are live.
            with patch("models.qianfan.qianfan_bot.conf", return_value=fake_conf):
                with patch("models.qianfan.qianfan_bot.SessionManager"):
                    from models.qianfan.qianfan_bot import QianfanBot

                    bot = QianfanBot()
                    self.assertFalse(
                        bot.supports_vision,
                        msg=f"{model} should not be marked as multimodal",
                    )
|
||||
|
||||
def test_call_vision_posts_openai_compatible_multimodal_payload(self):
|
||||
fake_conf = self._fake_conf()
|
||||
@@ -435,6 +451,105 @@ class TestQianfanVisionTool(unittest.TestCase):
|
||||
self.assertEqual(providers[0].name, "MainModel")
|
||||
self.assertEqual(providers[0].model_override, "ernie-4.5-turbo-vl-32k")
|
||||
|
||||
def test_vision_main_model_uses_ernie_5_directly(self):
    """With ``ernie-5.0`` configured and a vision-capable main bot, the
    first resolved provider is MainModel pinned to ``ernie-5.0``."""
    settings = self._fake_conf({"model": "ernie-5.0"})
    from common import const

    # Stand-in main bot that reports vision support.
    main_bot = MagicMock()
    main_bot.supports_vision = True
    main_bot.call_vision = MagicMock()

    # Stand-in model wrapper that resolves to the Qianfan bot type.
    agent_model = MagicMock()
    agent_model._resolve_bot_type.return_value = const.QIANFAN
    agent_model.bot = main_bot

    with patch("agent.tools.vision.vision.conf", return_value=settings):
        from agent.tools.vision.vision import Vision

        vision_tool = Vision()
        vision_tool.model = agent_model
        resolved = vision_tool._resolve_providers()

    first_choice = resolved[0]
    self.assertEqual(first_choice.name, "MainModel")
    self.assertEqual(first_choice.model_override, "ernie-5.0")
|
||||
|
||||
def test_vision_falls_back_to_qianfan_vl_when_main_model_is_text_only_ernie(self):
    """A text-only ERNIE main model must not be offered as a vision provider.

    With ``ernie-4.5-turbo-128k`` configured and a Qianfan API key set,
    the resolved chain must omit MainModel and start with the Qianfan
    provider pinned to ``const.ERNIE_45_TURBO_VL``.
    """
    settings = self._fake_conf({
        "model": "ernie-4.5-turbo-128k",
        "qianfan_api_key": "test-qianfan-key",
    })
    from common import const

    # The main bot reports no vision support for the text-only model.
    text_only_bot = MagicMock()
    text_only_bot.supports_vision = False
    text_only_bot.call_vision = MagicMock()

    agent_model = MagicMock()
    agent_model._resolve_bot_type.return_value = const.QIANFAN
    agent_model.bot = text_only_bot

    # Bot handed out by the patched factory for the fallback provider.
    factory_bot = MagicMock()
    factory_bot.call_vision = MagicMock()

    conf_patch = patch("agent.tools.vision.vision.conf", return_value=settings)
    factory_patch = patch("models.bot_factory.create_bot", return_value=factory_bot)
    with conf_patch, factory_patch:
        from agent.tools.vision.vision import Vision

        vision_tool = Vision()
        vision_tool.model = agent_model
        resolved = vision_tool._resolve_providers()

    # MainModel must be absent; the Qianfan fallback must come first and
    # be pinned to the dedicated vision model.
    provider_names = [provider.name for provider in resolved]
    self.assertNotIn("MainModel", provider_names)
    self.assertEqual(provider_names[0], "Qianfan")
    self.assertEqual(resolved[0].model_override, const.ERNIE_45_TURBO_VL)
|
||||
|
||||
def test_vision_prefers_same_vendor_fallback_over_other_configured_keys(self):
    """The Qianfan fallback is ranked first among several configured keys.

    With a text-only ERNIE main model and Qianfan, Ark, Claude and
    MiniMax keys configured, the first provider must be Qianfan pinned
    to ``const.ERNIE_45_TURBO_VL``, while Doubao, Claude and MiniMax
    still appear somewhere in the chain.
    """
    settings = self._fake_conf({
        "model": "ernie-4.5-turbo-128k",
        "qianfan_api_key": "test-qianfan-key",
        "ark_api_key": "test-ark-key",
        "claude_api_key": "test-claude-key",
        "minimax_api_key": "test-minimax-key",
    })
    from common import const

    # Main bot that reports no vision support.
    text_only_bot = MagicMock()
    text_only_bot.supports_vision = False
    text_only_bot.call_vision = MagicMock()

    agent_model = MagicMock()
    agent_model._resolve_bot_type.return_value = const.QIANFAN
    agent_model.bot = text_only_bot

    # Bot handed out by the patched factory.
    factory_bot = MagicMock()
    factory_bot.call_vision = MagicMock()

    conf_patch = patch("agent.tools.vision.vision.conf", return_value=settings)
    factory_patch = patch("models.bot_factory.create_bot", return_value=factory_bot)
    with conf_patch, factory_patch:
        from agent.tools.vision.vision import Vision

        vision_tool = Vision()
        vision_tool.model = agent_model
        resolved = vision_tool._resolve_providers()

    provider_names = [provider.name for provider in resolved]
    self.assertEqual(provider_names[0], "Qianfan")
    self.assertEqual(resolved[0].model_override, const.ERNIE_45_TURBO_VL)
    # The other configured providers should still appear in the chain.
    for other_vendor in ("Doubao", "Claude", "MiniMax"):
        self.assertIn(other_vendor, provider_names)
|
||||
|
||||
|
||||
class TestQianfanDocs(unittest.TestCase):
|
||||
def _read(self, relative_path):
|
||||
|
||||
Reference in New Issue
Block a user