diff --git a/channel/web/web_channel.py b/channel/web/web_channel.py index bfc038cb..6000d8a1 100644 --- a/channel/web/web_channel.py +++ b/channel/web/web_channel.py @@ -2613,7 +2613,7 @@ class ModelsHandler: if capability == "vision": return self._set_vision(provider_id, model) if capability == "asr": - return self._set_simple("voice_to_text", provider_id) + return self._set_asr(provider_id, model) if capability == "tts": return self._set_tts(provider_id, model, (data.get("voice") or "").strip()) if capability == "embedding": @@ -2773,6 +2773,24 @@ class ModelsHandler: self._refresh_voice_routing() return json.dumps({"status": "success", key: value}) + def _set_asr(self, provider_id: str, model: str) -> str: + local_config = conf() + file_cfg = self._read_file_config() + local_config["voice_to_text"] = provider_id + file_cfg["voice_to_text"] = provider_id + local_config["voice_to_text_model"] = model + file_cfg["voice_to_text_model"] = model + self._write_file_config(file_cfg) + logger.info( + f"[ModelsHandler] asr updated: provider={provider_id!r} " + f"model={model!r}" + ) + self._refresh_voice_routing() + return json.dumps({ + "status": "success", + "provider": provider_id, "model": model, + }) + def _set_tts(self, provider_id: str, model: str, voice: str = "") -> str: local_config = conf() file_cfg = self._read_file_config() diff --git a/tests/test_models_handler.py b/tests/test_models_handler.py new file mode 100644 index 00000000..cd36c1c9 --- /dev/null +++ b/tests/test_models_handler.py @@ -0,0 +1,59 @@ +# encoding:utf-8 +import json +import os +import sys +import types +import unittest +from unittest.mock import patch + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +if "web" not in sys.modules: + web_stub = types.ModuleType("web") + web_stub.HTTPError = type("HTTPError", (Exception,), {}) + web_stub.cookies = lambda: {} + web_stub.header = lambda *args, **kwargs: None + web_stub.data = lambda: b"{}" + web_stub.input = lambda **kwargs: types.SimpleNamespace(**kwargs) + web_stub.setcookie = lambda *args, **kwargs: None + web_stub.seeother = lambda *args, **kwargs: Exception("seeother") + web_stub.notfound = lambda *args, **kwargs: Exception("notfound") + web_stub.badrequest = lambda *args, **kwargs: Exception("badrequest") + web_stub.application = lambda *args, **kwargs: types.SimpleNamespace(wsgifunc=lambda: None) + web_stub.httpserver = types.SimpleNamespace( + LogMiddleware=type("LogMiddleware", (), {"log": lambda *args, **kwargs: None}), + StaticMiddleware=lambda app: app, + WSGIServer=lambda *args, **kwargs: types.SimpleNamespace(serve_forever=lambda: None), + ) + sys.modules["web"] = web_stub + + +class TestModelsHandler(unittest.TestCase): + def test_set_asr_capability_persists_provider_and_model(self): + from channel.web.web_channel import ModelsHandler + + local_config = {} + file_config = {} + handler = ModelsHandler() + + with patch("channel.web.web_channel.conf", return_value=local_config): + with patch.object(ModelsHandler, "_read_file_config", return_value=file_config): + with patch.object(ModelsHandler, "_write_file_config") as write_file: + with patch.object(ModelsHandler, "_refresh_voice_routing") as refresh_voice: + result = json.loads(handler._handle_set_capability({ + "capability": "asr", + "provider_id": "dashscope", + "model": "qwen3-asr-flash", + })) + + self.assertEqual(result["status"], "success") + self.assertEqual(local_config["voice_to_text"], "dashscope") + self.assertEqual(local_config["voice_to_text_model"], "qwen3-asr-flash") + self.assertEqual(file_config["voice_to_text"], "dashscope") + self.assertEqual(file_config["voice_to_text_model"], "qwen3-asr-flash") + write_file.assert_called_once_with(file_config) + refresh_voice.assert_called_once() + + +if __name__ == "__main__": + unittest.main()