mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
feat: add qianfan chat bot
This commit is contained in:
@@ -99,5 +99,99 @@ class TestQianfanConstantsAndRouting(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
class TestQianfanBot(unittest.TestCase):
|
||||
def _fake_conf(self, values=None):
|
||||
data = {
|
||||
"model": "ernie-4.5-turbo-128k",
|
||||
"qianfan_api_key": "test-qianfan-key",
|
||||
"qianfan_api_base": "https://qianfan.baidubce.com/v2",
|
||||
"temperature": 0.7,
|
||||
"top_p": 1.0,
|
||||
"frequency_penalty": 0.0,
|
||||
"presence_penalty": 0.0,
|
||||
"request_timeout": 180,
|
||||
"clear_memory_commands": ["#清除记忆"],
|
||||
"conversation_max_tokens": 1000,
|
||||
"expires_in_seconds": 3600,
|
||||
}
|
||||
if values:
|
||||
data.update(values)
|
||||
fake_conf = MagicMock()
|
||||
fake_conf.get.side_effect = lambda key, default=None: data.get(key, default)
|
||||
return fake_conf
|
||||
|
||||
def test_bot_factory_returns_qianfan_bot(self):
|
||||
from common import const
|
||||
from models.bot_factory import create_bot
|
||||
|
||||
fake_conf = self._fake_conf()
|
||||
with patch("models.qianfan.qianfan_bot.conf", return_value=fake_conf):
|
||||
with patch("models.qianfan.qianfan_bot.SessionManager"):
|
||||
bot = create_bot(const.QIANFAN)
|
||||
|
||||
from models.qianfan.qianfan_bot import QianfanBot
|
||||
self.assertIsInstance(bot, QianfanBot)
|
||||
|
||||
def test_default_model_uses_ernie_when_model_is_provider_alias(self):
|
||||
fake_conf = self._fake_conf({"model": "qianfan"})
|
||||
with patch("models.qianfan.qianfan_bot.conf", return_value=fake_conf):
|
||||
with patch("models.qianfan.qianfan_bot.SessionManager"):
|
||||
from models.qianfan.qianfan_bot import QianfanBot
|
||||
|
||||
bot = QianfanBot()
|
||||
|
||||
self.assertEqual(bot.args["model"], "ernie-4.5-turbo-128k")
|
||||
|
||||
def test_reply_text_posts_openai_compatible_payload(self):
|
||||
fake_conf = self._fake_conf()
|
||||
fake_response = MagicMock()
|
||||
fake_response.status_code = 200
|
||||
fake_response.json.return_value = {
|
||||
"choices": [{"message": {"content": "你好,我是文心。"}}],
|
||||
"usage": {"total_tokens": 12, "completion_tokens": 6},
|
||||
}
|
||||
session = MagicMock()
|
||||
session.messages = [{"role": "user", "content": "你好"}]
|
||||
|
||||
with patch("models.qianfan.qianfan_bot.conf", return_value=fake_conf):
|
||||
with patch("models.qianfan.qianfan_bot.SessionManager"):
|
||||
from models.qianfan.qianfan_bot import QianfanBot
|
||||
|
||||
bot = QianfanBot()
|
||||
with patch("models.qianfan.qianfan_bot.requests.post", return_value=fake_response) as post:
|
||||
result = bot.reply_text(session)
|
||||
|
||||
self.assertEqual(result["content"], "你好,我是文心。")
|
||||
self.assertEqual(result["total_tokens"], 12)
|
||||
self.assertEqual(result["completion_tokens"], 6)
|
||||
post.assert_called_once()
|
||||
url = post.call_args.args[0]
|
||||
kwargs = post.call_args.kwargs
|
||||
self.assertEqual(url, "https://qianfan.baidubce.com/v2/chat/completions")
|
||||
self.assertEqual(kwargs["headers"]["Authorization"], "Bearer test-qianfan-key")
|
||||
self.assertEqual(kwargs["json"]["model"], "ernie-4.5-turbo-128k")
|
||||
self.assertEqual(kwargs["json"]["messages"], [{"role": "user", "content": "你好"}])
|
||||
|
||||
def test_reply_text_returns_auth_error_for_401(self):
|
||||
fake_conf = self._fake_conf()
|
||||
fake_response = MagicMock()
|
||||
fake_response.status_code = 401
|
||||
fake_response.json.return_value = {"error": {"message": "invalid api key"}}
|
||||
fake_response.text = '{"error":{"message":"invalid api key"}}'
|
||||
session = MagicMock()
|
||||
session.messages = [{"role": "user", "content": "你好"}]
|
||||
|
||||
with patch("models.qianfan.qianfan_bot.conf", return_value=fake_conf):
|
||||
with patch("models.qianfan.qianfan_bot.SessionManager"):
|
||||
from models.qianfan.qianfan_bot import QianfanBot
|
||||
|
||||
bot = QianfanBot()
|
||||
with patch("models.qianfan.qianfan_bot.requests.post", return_value=fake_response):
|
||||
result = bot.reply_text(session)
|
||||
|
||||
self.assertEqual(result["completion_tokens"], 0)
|
||||
self.assertEqual(result["content"], "授权失败,请检查 Qianfan API Key 是否正确")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user