mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-03 10:47:08 +08:00
fix: persist ASR model in models API
This commit is contained in:
@@ -2613,7 +2613,7 @@ class ModelsHandler:
|
||||
if capability == "vision":
|
||||
return self._set_vision(provider_id, model)
|
||||
if capability == "asr":
|
||||
return self._set_simple("voice_to_text", provider_id)
|
||||
return self._set_asr(provider_id, model)
|
||||
if capability == "tts":
|
||||
return self._set_tts(provider_id, model, (data.get("voice") or "").strip())
|
||||
if capability == "embedding":
|
||||
@@ -2773,6 +2773,24 @@ class ModelsHandler:
|
||||
self._refresh_voice_routing()
|
||||
return json.dumps({"status": "success", key: value})
|
||||
|
||||
def _set_asr(self, provider_id: str, model: str) -> str:
|
||||
local_config = conf()
|
||||
file_cfg = self._read_file_config()
|
||||
local_config["voice_to_text"] = provider_id
|
||||
file_cfg["voice_to_text"] = provider_id
|
||||
local_config["voice_to_text_model"] = model
|
||||
file_cfg["voice_to_text_model"] = model
|
||||
self._write_file_config(file_cfg)
|
||||
logger.info(
|
||||
f"[ModelsHandler] asr updated: provider={provider_id!r} "
|
||||
f"model={model!r}"
|
||||
)
|
||||
self._refresh_voice_routing()
|
||||
return json.dumps({
|
||||
"status": "success",
|
||||
"provider": provider_id, "model": model,
|
||||
})
|
||||
|
||||
def _set_tts(self, provider_id: str, model: str, voice: str = "") -> str:
|
||||
local_config = conf()
|
||||
file_cfg = self._read_file_config()
|
||||
|
||||
59
tests/test_models_handler.py
Normal file
59
tests/test_models_handler.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# encoding:utf-8
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
if "web" not in sys.modules:
|
||||
web_stub = types.ModuleType("web")
|
||||
web_stub.HTTPError = type("HTTPError", (Exception,), {})
|
||||
web_stub.cookies = lambda: {}
|
||||
web_stub.header = lambda *args, **kwargs: None
|
||||
web_stub.data = lambda: b"{}"
|
||||
web_stub.input = lambda **kwargs: types.SimpleNamespace(**kwargs)
|
||||
web_stub.setcookie = lambda *args, **kwargs: None
|
||||
web_stub.seeother = lambda *args, **kwargs: Exception("seeother")
|
||||
web_stub.notfound = lambda *args, **kwargs: Exception("notfound")
|
||||
web_stub.badrequest = lambda *args, **kwargs: Exception("badrequest")
|
||||
web_stub.application = lambda *args, **kwargs: types.SimpleNamespace(wsgifunc=lambda: None)
|
||||
web_stub.httpserver = types.SimpleNamespace(
|
||||
LogMiddleware=type("LogMiddleware", (), {"log": lambda *args, **kwargs: None}),
|
||||
StaticMiddleware=lambda app: app,
|
||||
WSGIServer=lambda *args, **kwargs: types.SimpleNamespace(serve_forever=lambda: None),
|
||||
)
|
||||
sys.modules["web"] = web_stub
|
||||
|
||||
|
||||
class TestModelsHandler(unittest.TestCase):
|
||||
def test_set_asr_capability_persists_provider_and_model(self):
|
||||
from channel.web.web_channel import ModelsHandler
|
||||
|
||||
local_config = {}
|
||||
file_config = {}
|
||||
handler = ModelsHandler()
|
||||
|
||||
with patch("channel.web.web_channel.conf", return_value=local_config):
|
||||
with patch.object(ModelsHandler, "_read_file_config", return_value=file_config):
|
||||
with patch.object(ModelsHandler, "_write_file_config") as write_file:
|
||||
with patch.object(ModelsHandler, "_refresh_voice_routing") as refresh_voice:
|
||||
result = json.loads(handler._handle_set_capability({
|
||||
"capability": "asr",
|
||||
"provider_id": "dashscope",
|
||||
"model": "qwen3-asr-flash",
|
||||
}))
|
||||
|
||||
self.assertEqual(result["status"], "success")
|
||||
self.assertEqual(local_config["voice_to_text"], "dashscope")
|
||||
self.assertEqual(local_config["voice_to_text_model"], "qwen3-asr-flash")
|
||||
self.assertEqual(file_config["voice_to_text"], "dashscope")
|
||||
self.assertEqual(file_config["voice_to_text_model"], "qwen3-asr-flash")
|
||||
write_file.assert_called_once_with(file_config)
|
||||
refresh_voice.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user