mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
fix(models): persist explicit provider for vision and image capabilities
This commit is contained in:
@@ -78,6 +78,22 @@ _MODEL_PREFIX_TO_PROVIDER = [
|
||||
# Model prefixes that natively belong to OpenAI / LinkAI (raw HTTP providers).
|
||||
_OPENAI_MODEL_PREFIXES = ("gpt-", "o1-", "o3-", "o4-", "chatgpt-")
|
||||
|
||||
# Maps the UI provider id (persisted in tools.vision.provider) to the internal
|
||||
# display name used in VisionProvider.name. Keep in sync with _DISCOVERABLE_MODELS
|
||||
# and the openai/linkai branches in _route_by_model_name.
|
||||
_PROVIDER_ID_TO_DISPLAY = {
|
||||
"openai": "OpenAI",
|
||||
"linkai": "LinkAI",
|
||||
"moonshot": "Moonshot",
|
||||
"doubao": "Doubao",
|
||||
"dashscope": "DashScope",
|
||||
"claudeAPI": "Claude",
|
||||
"gemini": "Gemini",
|
||||
"qianfan": "Qianfan",
|
||||
"zhipu": "ZhipuAI",
|
||||
"minimax": "MiniMax",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisionProvider:
|
||||
@@ -211,10 +227,16 @@ class Vision(BaseTool):
|
||||
are de-duplicated to avoid retrying the same endpoint twice.
|
||||
"""
|
||||
user_model = self._resolve_user_vision_model()
|
||||
user_provider = self._resolve_user_vision_provider()
|
||||
providers: List[VisionProvider] = []
|
||||
|
||||
# Step 1: preferred provider derived from tools.vision.model
|
||||
if user_model:
|
||||
# Step 1: preferred provider — explicit `tools.vision.provider`
|
||||
# wins so custom model names can still be routed correctly. Falls
|
||||
# through to model-name prefix inference when provider is unset.
|
||||
preferred = None
|
||||
if user_provider and user_model:
|
||||
preferred = self._route_by_provider_id(user_provider, user_model)
|
||||
if not preferred and user_model:
|
||||
preferred = self._route_by_model_name(user_model)
|
||||
if preferred:
|
||||
providers.extend(preferred)
|
||||
@@ -263,6 +285,24 @@ class Vision(BaseTool):
|
||||
return m.strip()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _resolve_user_vision_provider() -> Optional[str]:
|
||||
"""Read tools.vision.provider — the UI-persisted vendor id.
|
||||
|
||||
Lets users pin a vendor for custom model names that prefix-inference
|
||||
can't recognize. Returns None when unset/blank.
|
||||
"""
|
||||
tools_conf = conf().get("tools") or conf().get("tool") or {}
|
||||
if not isinstance(tools_conf, dict):
|
||||
return None
|
||||
vision_conf = tools_conf.get("vision", {})
|
||||
if not isinstance(vision_conf, dict):
|
||||
return None
|
||||
p = vision_conf.get("provider")
|
||||
if isinstance(p, str) and p.strip():
|
||||
return p.strip()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _infer_provider_from_model(model_name: str) -> Optional[str]:
|
||||
"""
|
||||
@@ -279,6 +319,54 @@ class Vision(BaseTool):
|
||||
return display_name
|
||||
return None
|
||||
|
||||
def _route_by_provider_id(self, provider_id: str, user_model: str) -> Optional[List[VisionProvider]]:
|
||||
"""Route by the UI-persisted provider id.
|
||||
|
||||
Returns:
|
||||
- [provider] : provider id is known and its key is configured.
|
||||
- None : unknown provider id, or the bot can't be created.
|
||||
Caller falls through to model-name-based routing.
|
||||
"""
|
||||
display_name = _PROVIDER_ID_TO_DISPLAY.get(provider_id)
|
||||
if not display_name:
|
||||
return None
|
||||
|
||||
# OpenAI / LinkAI use raw HTTP providers, not the discoverable bot path.
|
||||
if provider_id == "openai":
|
||||
p = self._build_openai_provider(user_model)
|
||||
return [p] if p else None
|
||||
if provider_id == "linkai":
|
||||
p = self._build_linkai_provider(user_model)
|
||||
return [p] if p else None
|
||||
|
||||
# Discoverable bot-backed providers.
|
||||
for config_key, bot_type, _default_model, name in _DISCOVERABLE_MODELS:
|
||||
if name != display_name:
|
||||
continue
|
||||
api_key = conf().get(config_key, "")
|
||||
if not api_key or not api_key.strip():
|
||||
logger.warning(f"[Vision] tools.vision.provider='{provider_id}' "
|
||||
f"but '{config_key}' is not configured. Falling back.")
|
||||
return None
|
||||
try:
|
||||
from models.bot_factory import create_bot
|
||||
bot = create_bot(bot_type)
|
||||
if not hasattr(bot, 'call_vision'):
|
||||
logger.warning(f"[Vision] '{display_name}' bot does not implement call_vision.")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"[Vision] Failed to create '{display_name}' bot: {e}")
|
||||
return None
|
||||
return [VisionProvider(
|
||||
name=display_name,
|
||||
api_key="",
|
||||
api_base="",
|
||||
model_override=user_model,
|
||||
use_bot=True,
|
||||
fallback_bot=bot,
|
||||
)]
|
||||
return None
|
||||
|
||||
def _route_by_model_name(self, user_model: str) -> Optional[List[VisionProvider]]:
|
||||
"""
|
||||
Try to build a provider list using the user-specified model name.
|
||||
|
||||
@@ -106,7 +106,7 @@ const I18N = {
|
||||
config_custom_model_hint: '输入自定义模型名称',
|
||||
config_save: '保存', config_saved: '已保存',
|
||||
config_save_error: '保存失败',
|
||||
config_custom_option: '自定义...',
|
||||
config_custom_option: '自定义',
|
||||
config_custom_tip: '接口需遵循 OpenAI API 协议',
|
||||
config_security: '安全设置', config_password: '访问密码',
|
||||
config_password_hint: '留空则不启用密码保护',
|
||||
@@ -280,7 +280,7 @@ const I18N = {
|
||||
config_custom_model_hint: 'Enter custom model name',
|
||||
config_save: 'Save', config_saved: 'Saved',
|
||||
config_save_error: 'Save failed',
|
||||
config_custom_option: 'Custom...',
|
||||
config_custom_option: 'Custom',
|
||||
config_custom_tip: 'API must follow OpenAI protocol.',
|
||||
config_security: 'Security', config_password: 'Password',
|
||||
config_password_hint: 'Leave empty to disable password protection',
|
||||
@@ -4798,7 +4798,7 @@ function rebuildCapabilityModelDropdown(def, providerId, selectedModel, scope) {
|
||||
modelValues.push(entry.value);
|
||||
return { value: entry.value, label: entry.label || entry.value, hint: entry.hint || '' };
|
||||
});
|
||||
opts.push({ value: '__custom__', label: currentLang === 'zh' ? '自定义...' : 'Custom...' });
|
||||
opts.push({ value: '__custom__', label: currentLang === 'zh' ? '自定义' : 'Custom' });
|
||||
|
||||
let initialValue = selectedModel || '';
|
||||
if (initialValue && !modelValues.includes(initialValue)) {
|
||||
@@ -4881,7 +4881,7 @@ function rebuildCapabilityVoiceDropdown(providerId, selectedVoice, scope, modelI
|
||||
hint: desc === code ? '' : code,
|
||||
};
|
||||
});
|
||||
opts.push({ value: '__custom__', label: currentLang === 'zh' ? '自定义...' : 'Custom...' });
|
||||
opts.push({ value: '__custom__', label: currentLang === 'zh' ? '自定义' : 'Custom' });
|
||||
|
||||
// Off-catalog values route through the custom branch.
|
||||
let initial = selectedVoice || '';
|
||||
|
||||
@@ -2011,12 +2011,17 @@ class ModelsHandler:
|
||||
if not isinstance(vision_conf, dict):
|
||||
vision_conf = {}
|
||||
user_specified = (vision_conf.get("model") or "").strip()
|
||||
explicit_provider = (vision_conf.get("provider") or "").strip()
|
||||
|
||||
# When the user pinned a specific model, infer which vendor card to
|
||||
# highlight by scanning the per-provider model lists. Falls back to
|
||||
# an empty provider so the dropdown stays on "auto" if we can't tell.
|
||||
# Provider resolution priority:
|
||||
# 1. Explicit `tools.vision.provider` (persisted via UI; supports
|
||||
# custom model names that prefix-inference can't recognize).
|
||||
# 2. Scan per-provider model lists by model name.
|
||||
# Empty provider keeps the dropdown on "auto" when we can't tell.
|
||||
inferred_provider = ""
|
||||
if user_specified:
|
||||
if explicit_provider and explicit_provider in cls._VISION_PROVIDER_MODELS:
|
||||
inferred_provider = explicit_provider
|
||||
elif user_specified:
|
||||
for pid, models in cls._VISION_PROVIDER_MODELS.items():
|
||||
if user_specified in models:
|
||||
inferred_provider = pid
|
||||
@@ -2181,11 +2186,17 @@ class ModelsHandler:
|
||||
if not isinstance(img_node, dict):
|
||||
img_node = {}
|
||||
explicit_model = (img_node.get("model") or "").strip()
|
||||
explicit_provider = (img_node.get("provider") or "").strip()
|
||||
|
||||
# Infer the provider card to highlight by scanning per-provider
|
||||
# model lists, including alias values inside {value, hint} entries.
|
||||
# Provider resolution priority:
|
||||
# 1. Explicit `skills.image-generation.provider` (persisted via UI;
|
||||
# supports custom model names that prefix-inference can't catch).
|
||||
# 2. Scan per-provider model catalog by model name.
|
||||
# Empty provider keeps the dropdown on "auto" when we can't tell.
|
||||
inferred_provider = ""
|
||||
if explicit_model:
|
||||
if explicit_provider and explicit_provider in cls._IMAGE_PROVIDER_MODELS:
|
||||
inferred_provider = explicit_provider
|
||||
elif explicit_model:
|
||||
for pid, models in cls._IMAGE_PROVIDER_MODELS.items():
|
||||
for entry in models:
|
||||
val = entry if isinstance(entry, str) else (entry.get("value") or "")
|
||||
@@ -2440,27 +2451,37 @@ class ModelsHandler:
|
||||
return json.dumps({"status": "error", "message": f"capability not editable: {capability}"})
|
||||
|
||||
def _set_image(self, provider_id: str, model: str) -> str:
|
||||
# Source of truth: skills.image-generation.model. provider_id is
|
||||
# informational only; the resolver picks the vendor by model prefix.
|
||||
# Source of truth: skills.image-generation.{provider, model}. The
|
||||
# provider field is persisted so users picking a custom model under
|
||||
# a specific vendor still get routed there — runtime falls back to
|
||||
# model-name prefix inference only when provider is empty.
|
||||
local_config = conf()
|
||||
file_cfg = self._read_file_config()
|
||||
|
||||
self._set_nested_namespace_value(local_config, "skills", "image-generation", "model", model or "")
|
||||
self._set_nested_namespace_value(file_cfg, "skills", "image-generation", "model", model or "")
|
||||
self._set_nested_namespace_value(local_config, "skills", "image-generation", "provider", provider_id or "")
|
||||
self._set_nested_namespace_value(file_cfg, "skills", "image-generation", "provider", provider_id or "")
|
||||
self._drop_legacy_namespace(local_config, "skill", "skills", child="image-generation")
|
||||
self._drop_legacy_namespace(file_cfg, "skill", "skills", child="image-generation")
|
||||
|
||||
self._write_file_config(file_cfg)
|
||||
|
||||
# The skill subprocess reads SKILL_IMAGE_GENERATION_MODEL from env at
|
||||
# startup; mirror the change so live edits apply without restart.
|
||||
env_key = "SKILL_IMAGE_GENERATION_MODEL"
|
||||
# The skill subprocess reads SKILL_IMAGE_GENERATION_{MODEL,PROVIDER}
|
||||
# from env at startup; mirror the change so live edits apply without
|
||||
# restart.
|
||||
model_env = "SKILL_IMAGE_GENERATION_MODEL"
|
||||
provider_env = "SKILL_IMAGE_GENERATION_PROVIDER"
|
||||
if model:
|
||||
os.environ[env_key] = model
|
||||
os.environ[model_env] = model
|
||||
else:
|
||||
os.environ.pop(env_key, None)
|
||||
os.environ.pop(model_env, None)
|
||||
if provider_id:
|
||||
os.environ[provider_env] = provider_id
|
||||
else:
|
||||
os.environ.pop(provider_env, None)
|
||||
|
||||
logger.info(f"[ModelsHandler] image updated: provider_hint={provider_id!r} model={model!r}")
|
||||
logger.info(f"[ModelsHandler] image updated: provider={provider_id!r} model={model!r}")
|
||||
return json.dumps({
|
||||
"status": "success",
|
||||
"provider": provider_id,
|
||||
@@ -2499,18 +2520,22 @@ class ModelsHandler:
|
||||
return json.dumps({"status": "success", "applied": applied})
|
||||
|
||||
def _set_vision(self, provider_id: str, model: str) -> str:
|
||||
# Source of truth: tools.vision.model. provider_id is informational
|
||||
# only; the resolver picks the vendor by model prefix.
|
||||
# Source of truth: tools.vision.{provider, model}. The provider field
|
||||
# is persisted so users picking a custom model under a specific vendor
|
||||
# still get routed there — runtime falls back to model-name prefix
|
||||
# inference only when provider is empty.
|
||||
local_config = conf()
|
||||
file_cfg = self._read_file_config()
|
||||
self._set_nested_namespace_value(file_cfg, "tools", "vision", "model", model)
|
||||
self._set_nested_namespace_value(local_config, "tools", "vision", "model", model)
|
||||
self._set_nested_namespace_value(file_cfg, "tools", "vision", "provider", provider_id or "")
|
||||
self._set_nested_namespace_value(local_config, "tools", "vision", "provider", provider_id or "")
|
||||
self._drop_legacy_namespace(file_cfg, "tool", "tools", child="vision")
|
||||
self._drop_legacy_namespace(local_config, "tool", "tools", child="vision")
|
||||
|
||||
self._write_file_config(file_cfg)
|
||||
logger.info(f"[ModelsHandler] vision model set: {model!r}")
|
||||
return json.dumps({"status": "success", "model": model})
|
||||
logger.info(f"[ModelsHandler] vision updated: provider={provider_id!r} model={model!r}")
|
||||
return json.dumps({"status": "success", "provider": provider_id, "model": model})
|
||||
|
||||
@staticmethod
|
||||
def _set_nested_namespace_value(cfg, top: str, name: str, key: str, value):
|
||||
|
||||
@@ -1011,6 +1011,18 @@ _MODEL_PREFERRED_PROVIDER: list[tuple[tuple[str, ...], str]] = [
|
||||
# Default global priority when the model has no preferred provider.
|
||||
_DEFAULT_PROVIDER_ORDER = ["OpenAI", "Gemini", "Seedream", "Qwen", "MiniMax", "LinkAI"]
|
||||
|
||||
# UI provider id (persisted via the Models page) → internal label used by
|
||||
# the factory dict in `_build_providers`. Allows pinning a vendor for
|
||||
# custom model names that prefix-inference can't recognize.
|
||||
_PROVIDER_ID_TO_LABEL = {
|
||||
"openai": "OpenAI",
|
||||
"gemini": "Gemini",
|
||||
"doubao": "Seedream",
|
||||
"dashscope": "Qwen",
|
||||
"minimax": "MiniMax",
|
||||
"linkai": "LinkAI",
|
||||
}
|
||||
|
||||
|
||||
def _preferred_provider(model: str) -> str | None:
|
||||
m = (model or "").lower()
|
||||
@@ -1020,7 +1032,7 @@ def _preferred_provider(model: str) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _build_providers(model: str) -> list[tuple[str, ImageProvider]]:
|
||||
def _build_providers(model: str, provider_id: str = "") -> list[tuple[str, ImageProvider]]:
|
||||
"""Build an ordered list of (label, provider) to try.
|
||||
|
||||
Behaviour:
|
||||
@@ -1051,6 +1063,11 @@ def _build_providers(model: str) -> list[tuple[str, ImageProvider]]:
|
||||
"LinkAI": os.environ.get("LINKAI_API_BASE", "https://api.link-ai.tech"),
|
||||
}
|
||||
|
||||
# Provider preference resolution priority:
|
||||
# 1. Explicit `provider_id` (UI-persisted, supports custom model names).
|
||||
# 2. Model-name prefix inference.
|
||||
pref = _PROVIDER_ID_TO_LABEL.get(provider_id) if provider_id else None
|
||||
if not pref:
|
||||
pref = _preferred_provider(model)
|
||||
|
||||
# If a specific model is requested and its native provider has no key,
|
||||
@@ -1114,6 +1131,9 @@ def main():
|
||||
# 3. None → fall back to automatic provider routing (try every
|
||||
# provider with a configured API key in global priority order)
|
||||
model = args.get("model") or os.environ.get("SKILL_IMAGE_GENERATION_MODEL") or ""
|
||||
# Provider hint persisted by the Models UI; lets users pin a vendor for
|
||||
# custom model names that prefix-inference can't recognize.
|
||||
provider_id = args.get("provider") or os.environ.get("SKILL_IMAGE_GENERATION_PROVIDER") or ""
|
||||
quality = args.get("quality")
|
||||
size = args.get("size")
|
||||
aspect_ratio = args.get("aspect_ratio")
|
||||
@@ -1121,7 +1141,7 @@ def main():
|
||||
|
||||
output_dir = os.environ.get("IMAGE_OUTPUT_DIR", os.path.join(os.getcwd(), "images"))
|
||||
|
||||
providers = _build_providers(model)
|
||||
providers = _build_providers(model, provider_id=provider_id)
|
||||
if not providers:
|
||||
target = f"model '{model}'" if model else "image generation"
|
||||
print(json.dumps({
|
||||
|
||||
Reference in New Issue
Block a user