From 3b12ef2e66b8c3cb4ad1d0bddc63cb69a4c7167c Mon Sep 17 00:00:00 2001 From: jimmyzhuu Date: Wed, 6 May 2026 13:24:41 +0800 Subject: [PATCH] feat: add qianfan vision calls --- models/qianfan/qianfan_bot.py | 51 ++++++++++++++++ tests/test_qianfan_provider.py | 107 +++++++++++++++++++++++++++++++++ 2 files changed, 158 insertions(+) diff --git a/models/qianfan/qianfan_bot.py b/models/qianfan/qianfan_bot.py index 1626479e..9e3321fb 100644 --- a/models/qianfan/qianfan_bot.py +++ b/models/qianfan/qianfan_bot.py @@ -15,9 +15,12 @@ from .qianfan_session import QianfanSession DEFAULT_API_BASE = "https://qianfan.baidubce.com/v2" DEFAULT_MODEL = const.ERNIE_5 +DEFAULT_VISION_MODEL = const.ERNIE_45_TURBO_VL_PREVIEW class QianfanBot(Bot, OpenAICompatibleBot): + supports_vision = True + def __init__(self): super().__init__() model = self._resolve_model() @@ -136,6 +139,54 @@ class QianfanBot(Bot, OpenAICompatibleBot): return self.reply_text(session, args, retry_count + 1) return {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} + def call_vision(self, image_url: str, question: str, + model: str = None, max_tokens: int = 1000) -> dict: + vision_model = model or DEFAULT_VISION_MODEL + payload = { + "model": vision_model, + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": question}, + {"type": "image_url", "image_url": {"url": image_url}}, + ], + } + ], + "max_tokens": max_tokens, + } + + try: + response = requests.post( + "{}/chat/completions".format(self.api_base), + headers=self._build_headers(), + json=payload, + timeout=conf().get("request_timeout", 180), + ) + if response.status_code != 200: + err = self._error_result(response, None) + return { + "error": True, + "message": err.get("content", "Qianfan vision request failed"), + } + + data = response.json() + choices = data.get("choices", []) + content = choices[0].get("message", {}).get("content", "") if choices else "" + usage = data.get("usage", {}) or {} + return { + "content": content, + "model": data.get("model", vision_model), + "usage": { + "prompt_tokens": usage.get("prompt_tokens", 0), + "completion_tokens": usage.get("completion_tokens", 0), + "total_tokens": usage.get("total_tokens", 0), + }, + } + except Exception as e: + logger.exception(e) + return {"error": True, "message": str(e)} + def _error_result(self, response, session, args=None, retry_count=0): try: body = response.json() diff --git a/tests/test_qianfan_provider.py b/tests/test_qianfan_provider.py index 51e01ff9..d97211c0 100644 --- a/tests/test_qianfan_provider.py +++ b/tests/test_qianfan_provider.py @@ -223,6 +223,113 @@ class TestQianfanBot(unittest.TestCase): self.assertEqual(result["content"], "请求失败:bad gateway text") post.assert_called_once() + def test_qianfan_bot_supports_vision(self): + fake_conf = self._fake_conf() + with patch("models.qianfan.qianfan_bot.conf", return_value=fake_conf): + with patch("models.qianfan.qianfan_bot.SessionManager"): + from models.qianfan.qianfan_bot import QianfanBot + + bot = QianfanBot() + + self.assertTrue(bot.supports_vision) + + def test_call_vision_posts_openai_compatible_multimodal_payload(self): + fake_conf = self._fake_conf() + fake_response = MagicMock() + fake_response.status_code = 200 + fake_response.json.return_value = { + "id": "chatcmpl-test", + "model": "ernie-4.5-turbo-vl-preview", + "choices": [{"message": {"content": "图中有一个红色方块。"}}], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 8, + "total_tokens": 18, + }, + } + + with patch("models.qianfan.qianfan_bot.conf", return_value=fake_conf): + with patch("models.qianfan.qianfan_bot.SessionManager"): + from models.qianfan.qianfan_bot import QianfanBot + + bot = QianfanBot() + with patch("models.qianfan.qianfan_bot.requests.post", return_value=fake_response) as post: + result = bot.call_vision( + image_url="data:image/png;base64,AAAA", + question="这张图里有什么?", + ) + + self.assertEqual(result["content"], "图中有一个红色方块。") + self.assertEqual(result["model"], "ernie-4.5-turbo-vl-preview") + self.assertEqual(result["usage"]["total_tokens"], 18) + post.assert_called_once() + url = post.call_args.args[0] + kwargs = post.call_args.kwargs + self.assertEqual(url, "https://qianfan.baidubce.com/v2/chat/completions") + self.assertEqual(kwargs["headers"]["Authorization"], "Bearer test-qianfan-key") + self.assertEqual(kwargs["json"]["model"], "ernie-4.5-turbo-vl-preview") + self.assertEqual(kwargs["json"]["max_tokens"], 1000) + self.assertEqual(kwargs["json"]["messages"], [ + { + "role": "user", + "content": [ + {"type": "text", "text": "这张图里有什么?"}, + { + "type": "image_url", + "image_url": {"url": "data:image/png;base64,AAAA"}, + }, + ], + } + ]) + + def test_call_vision_allows_explicit_model_override(self): + fake_conf = self._fake_conf() + fake_response = MagicMock() + fake_response.status_code = 200 + fake_response.json.return_value = { + "model": "ernie-4.5-vl-28b-a3b", + "choices": [{"message": {"content": "有文字。"}}], + "usage": {}, + } + + with patch("models.qianfan.qianfan_bot.conf", return_value=fake_conf): + with patch("models.qianfan.qianfan_bot.SessionManager"): + from models.qianfan.qianfan_bot import QianfanBot + + bot = QianfanBot() + with patch("models.qianfan.qianfan_bot.requests.post", return_value=fake_response) as post: + result = bot.call_vision( + image_url="data:image/jpeg;base64,BBBB", + question="识别文字", + model="ernie-4.5-vl-28b-a3b", + max_tokens=256, + ) + + self.assertEqual(result["model"], "ernie-4.5-vl-28b-a3b") + self.assertEqual(post.call_args.kwargs["json"]["model"], "ernie-4.5-vl-28b-a3b") + self.assertEqual(post.call_args.kwargs["json"]["max_tokens"], 256) + + def test_call_vision_returns_error_dict_for_api_error(self): + fake_conf = self._fake_conf() + fake_response = MagicMock() + fake_response.status_code = 400 + fake_response.json.return_value = {"error": {"message": "bad image"}} + fake_response.text = '{"error":{"message":"bad image"}}' + + with patch("models.qianfan.qianfan_bot.conf", return_value=fake_conf): + with patch("models.qianfan.qianfan_bot.SessionManager"): + from models.qianfan.qianfan_bot import QianfanBot + + bot = QianfanBot() + with patch("models.qianfan.qianfan_bot.requests.post", return_value=fake_response): + result = bot.call_vision( + image_url="data:image/png;base64,AAAA", + question="这张图里有什么?", + ) + + self.assertTrue(result["error"]) + self.assertEqual(result["message"], "请求失败:bad image") + class TestQianfanSurfaces(unittest.TestCase): def _read(self, relative_path):