diff --git a/agent/memory/conversation_store.py b/agent/memory/conversation_store.py index c5d215bf..48148f61 100644 --- a/agent/memory/conversation_store.py +++ b/agent/memory/conversation_store.py @@ -44,6 +44,7 @@ CREATE TABLE IF NOT EXISTS messages ( role TEXT NOT NULL, content TEXT NOT NULL, created_at INTEGER NOT NULL, + extras TEXT NOT NULL DEFAULT '', UNIQUE (session_id, seq) ); @@ -67,6 +68,12 @@ _MIGRATION_ADD_CONTEXT_START_SEQ = """ ALTER TABLE sessions ADD COLUMN context_start_seq INTEGER NOT NULL DEFAULT 0; """ +# Generic JSON sidecar for per-message attachments (TTS audio URL, future use). +# Always optional — readers must tolerate missing column / empty / invalid JSON. +_MIGRATION_ADD_MSG_EXTRAS = """ +ALTER TABLE messages ADD COLUMN extras TEXT NOT NULL DEFAULT ''; +""" + DEFAULT_MAX_AGE_DAYS: int = 30 @@ -169,20 +176,26 @@ def _group_into_display_turns( cur_rest: List[tuple] = [] started = False - for role, raw_content, created_at in rows: + for role, raw_content, created_at, raw_extras in rows: try: content = json.loads(raw_content) except Exception: content = raw_content + try: + extras = json.loads(raw_extras) if raw_extras else {} + if not isinstance(extras, dict): + extras = {} + except Exception: + extras = {} if role == "user" and _is_visible_user_message(content): if started: groups.append((cur_user, cur_rest)) - cur_user = (content, created_at) + cur_user = (content, created_at, extras) cur_rest = [] started = True else: - cur_rest.append((role, content, created_at)) + cur_rest.append((role, content, created_at, extras)) if started: groups.append((cur_user, cur_rest)) @@ -195,7 +208,7 @@ def _group_into_display_turns( for user_row, rest in groups: # User turn if user_row: - content, created_at = user_row + content, created_at, _u_extras = user_row text = _extract_display_text(content) if text: turns.append({"role": "user", "content": text, "created_at": created_at}) @@ -206,8 +219,11 @@ def _group_into_display_turns( tool_results: Dict[str, str] = {} final_text = "" final_ts: Optional[int] = None + merged_extras: Dict[str, Any] = {} - for role, content, created_at in rest: + for role, content, created_at, extras in rest: + if role == "assistant" and isinstance(extras, dict): + merged_extras.update(extras) if role == "user": tool_results.update(_extract_tool_results(content)) elif role == "assistant": @@ -256,6 +272,8 @@ def _group_into_display_turns( "steps": steps, "created_at": final_ts or (user_row[1] if user_row else 0), } + if merged_extras: + turn["extras"] = merged_extras turns.append(turn) return turns @@ -411,13 +429,15 @@ class ConversationStore: content = json.dumps( msg.get("content", ""), ensure_ascii=False ) + extras_obj = msg.get("extras") or {} + extras = json.dumps(extras_obj, ensure_ascii=False) if extras_obj else "" conn.execute( """ INSERT OR IGNORE INTO messages - (session_id, seq, role, content, created_at) - VALUES (?, ?, ?, ?, ?) + (session_id, seq, role, content, created_at, extras) + VALUES (?, ?, ?, ?, ?, ?) """, - (session_id, next_seq, role, content, now), + (session_id, next_seq, role, content, now, extras), ) next_seq += 1 @@ -651,6 +671,55 @@ class ConversationStore: logger.info(f"[ConversationStore] Pruned {deleted} expired sessions") return deleted + def attach_extras_to_last_assistant( + self, + session_id: str, + extras: Dict[str, Any], + ) -> Optional[int]: + """ + Merge ``extras`` into the latest assistant message of a session. + + Used by post-processing (e.g. TTS) that needs to annotate an already + persisted bot reply with attachments such as audio URLs. + + Returns the message seq that was updated, or ``None`` if no assistant + message exists or the update could not be applied. + """ + if not extras: + return None + with self._lock: + conn = self._connect() + try: + row = conn.execute( + """ + SELECT seq, extras FROM messages + WHERE session_id = ? AND role = 'assistant' + ORDER BY seq DESC LIMIT 1 + """, + (session_id,), + ).fetchone() + if not row: + return None + seq, raw = row + try: + cur = json.loads(raw) if raw else {} + if not isinstance(cur, dict): + cur = {} + except Exception: + cur = {} + cur.update(extras) + conn.execute( + "UPDATE messages SET extras = ? WHERE session_id = ? AND seq = ?", + (json.dumps(cur, ensure_ascii=False), session_id, seq), + ) + conn.commit() + return seq + except Exception as e: + logger.warning(f"[ConversationStore] attach_extras failed: {e}") + return None + finally: + conn.close() + def load_history_page( self, session_id: str, @@ -698,15 +767,31 @@ class ConversationStore: ).fetchone() ctx_start = ctx_row[0] if ctx_row else 0 - rows = conn.execute( - """ - SELECT seq, role, content, created_at - FROM messages - WHERE session_id = ? - ORDER BY seq ASC - """, - (session_id,), - ).fetchall() + # extras column is added by migration; tolerate older DBs that + # might miss it by falling back to a NULL literal. + try: + rows = conn.execute( + """ + SELECT seq, role, content, created_at, extras + FROM messages + WHERE session_id = ? + ORDER BY seq ASC + """, + (session_id,), + ).fetchall() + except sqlite3.OperationalError: + rows = [ + (seq, role, content, created_at, "") + for (seq, role, content, created_at) in conn.execute( + """ + SELECT seq, role, content, created_at + FROM messages + WHERE session_id = ? + ORDER BY seq ASC + """, + (session_id,), + ).fetchall() + ] finally: conn.close() @@ -719,13 +804,16 @@ class ConversationStore: include_thinking = False # Strip seq for display grouping, but record max seq per visible user group - plain_rows = [(role, content, created_at) for _seq, role, content, created_at in rows] + plain_rows = [ + (role, content, created_at, extras_raw) + for _seq, role, content, created_at, extras_raw in rows + ] visible = _group_into_display_turns(plain_rows, include_thinking=include_thinking) # Build a mapping: find the seq of each visible user message to annotate context boundary. # Walk through rows to find visible user message seqs in order. visible_user_seqs: List[int] = [] - for seq, role, raw_content, _ts in rows: + for seq, role, raw_content, _ts, _extras in rows: if role != "user": continue try: @@ -911,6 +999,18 @@ class ConversationStore: except Exception as e: logger.warning(f"[ConversationStore] Migration (context_start_seq) failed: {e}") + msg_cols = { + row[1] + for row in conn.execute("PRAGMA table_info(messages)").fetchall() + } + if "extras" not in msg_cols: + try: + conn.execute(_MIGRATION_ADD_MSG_EXTRAS) + conn.commit() + logger.info("[ConversationStore] Migrated: added messages.extras column") + except Exception as e: + logger.warning(f"[ConversationStore] Migration (extras) failed: {e}") + def _connect(self) -> sqlite3.Connection: conn = sqlite3.connect(str(self._db_path), timeout=10) conn.execute("PRAGMA journal_mode=WAL") diff --git a/agent/protocol/agent_stream.py b/agent/protocol/agent_stream.py index 75b4f4ff..ef4f975b 100644 --- a/agent/protocol/agent_stream.py +++ b/agent/protocol/agent_stream.py @@ -603,15 +603,24 @@ class AgentStreamExecutor: except Exception as e: logger.debug(f"[Agent] MCP sync skipped: {e}") - # Prepare tool definitions (OpenAI/Claude format) + # Prepare tool definitions. Prefer get_json_schema() when it yields + # real properties (lets tools augment schema at runtime), otherwise + # fall back to the static `tool.params` (MCP tools rely on this). tools_schema = None if self.tools: tools_schema = [] for tool in self.tools.values(): + input_schema = tool.params + try: + dynamic = (tool.get_json_schema() or {}).get("parameters") or {} + if dynamic.get("properties"): + input_schema = dynamic + except Exception: + pass tools_schema.append({ "name": tool.name, "description": tool.description, - "input_schema": tool.params # Claude uses input_schema + "input_schema": input_schema, }) # Create request diff --git a/agent/tools/vision/vision.py b/agent/tools/vision/vision.py index a1c3265f..d8d7b7a3 100644 --- a/agent/tools/vision/vision.py +++ b/agent/tools/vision/vision.py @@ -3,7 +3,7 @@ Vision tool - Analyze images using Vision API. Supports local files (auto base64-encoded) and HTTP URLs. Provider resolution: - - tool.vision.model (if set) means "prefer this model first; fall back to + - tools.vision.model (if set) means "prefer this model first; fall back to other configured providers if it fails". The model name is mapped to its native provider (e.g. doubao-* → Doubao, kimi-* → Moonshot, gpt-* → OpenAI/LinkAI). That provider is tried first, then the standard auto @@ -30,7 +30,7 @@ from common import const from common.log import logger from config import conf -DEFAULT_MODEL = const.GPT_41_MINI +DEFAULT_MODEL = const.GPT_55 DEFAULT_TIMEOUT = 60 MAX_TOKENS = 1000 COMPRESS_THRESHOLD = 1_048_576 # 1 MB @@ -53,14 +53,14 @@ _DISCOVERABLE_MODELS = [ ("ark_api_key", const.DOUBAO, const.DOUBAO_SEED_2_PRO, "Doubao"), ("dashscope_api_key", const.QWEN_DASHSCOPE, const.QWEN36_PLUS, "DashScope"), ("claude_api_key", const.CLAUDEAPI, const.CLAUDE_4_6_SONNET, "Claude"), - ("gemini_api_key", const.GEMINI, const.GEMINI_31_FLASH_LITE_PRE, "Gemini"), + ("gemini_api_key", const.GEMINI, const.GEMINI_35_FLASH, "Gemini"), ("qianfan_api_key", const.QIANFAN, const.ERNIE_45_TURBO_VL, "Qianfan"), ("zhipu_ai_api_key", const.ZHIPU_AI, const.GLM_4_7, "ZhipuAI"), ("minimax_api_key", const.MiniMax, const.MINIMAX_M2_7, "MiniMax"), ] # Model name prefix → discoverable provider display_name. -# Used to auto-route tool.vision.model to its native provider. +# Used to auto-route tools.vision.model to its native provider. # Matched case-insensitively; longest prefix wins. _MODEL_PREFIX_TO_PROVIDER = [ ("doubao-", "Doubao"), @@ -154,7 +154,7 @@ class Vision(BaseTool): # Default model is only used as a last-resort placeholder for providers # whose VisionProvider.model_override is None (e.g. raw OpenAI provider - # when the user did not configure tool.vision.model). + # when the user did not configure tools.vision.model). return self._call_with_fallback(providers, DEFAULT_MODEL, question, image_content) def _call_with_fallback(self, providers: List[VisionProvider], model: str, @@ -193,12 +193,12 @@ class Vision(BaseTool): """ Build an ordered list of providers to try. - Semantics of `tool.vision.model`: + Semantics of `tools.vision.model`: "Prefer this model first; fall back to other configured providers if it fails." Order: - 1. The provider that natively serves `tool.vision.model` (if any + 1. The provider that natively serves `tools.vision.model` (if any and its API key is configured) — using the user-specified model name verbatim. 2. Auto-discovery chain as fallback: @@ -213,7 +213,7 @@ class Vision(BaseTool): user_model = self._resolve_user_vision_model() providers: List[VisionProvider] = [] - # Step 1: preferred provider derived from tool.vision.model + # Step 1: preferred provider derived from tools.vision.model if user_model: preferred = self._route_by_model_name(user_model) if preferred: @@ -251,11 +251,11 @@ class Vision(BaseTool): @staticmethod def _resolve_user_vision_model() -> Optional[str]: - """Read tool.vision.model from config; return None if unset/blank.""" - tool_conf = conf().get("tool", {}) - if not isinstance(tool_conf, dict): + """Read tools.vision.model (singular ``tool`` kept as runtime fallback).""" + tools_conf = conf().get("tools") or conf().get("tool") or {} + if not isinstance(tools_conf, dict): return None - vision_conf = tool_conf.get("vision", {}) + vision_conf = tools_conf.get("vision", {}) if not isinstance(vision_conf, dict): return None m = vision_conf.get("model") @@ -303,7 +303,7 @@ class Vision(BaseTool): self._append_provider(providers, lambda: self._build_linkai_provider(user_model)) if providers: return providers - logger.warning(f"[Vision] tool.vision.model='{user_model}' looks like an OpenAI " + logger.warning(f"[Vision] tools.vision.model='{user_model}' looks like an OpenAI " f"model but neither OPENAI_API_KEY nor LINKAI_API_KEY is configured.") return None # fall through to auto @@ -317,7 +317,7 @@ class Vision(BaseTool): continue api_key = conf().get(config_key, "") if not api_key or not api_key.strip(): - logger.warning(f"[Vision] tool.vision.model='{user_model}' routes to " + logger.warning(f"[Vision] tools.vision.model='{user_model}' routes to " f"'{display_name}' but '{config_key}' is not configured. " f"Falling back to auto-discovery.") return None # fall through to auto @@ -452,8 +452,8 @@ class Vision(BaseTool): if not self._main_bot_supports_vision(bot): return None - # Use the configured main model name; do NOT inject tool.vision.model - # here, because by the time we reach this branch the tool.vision.model + # Use the configured main model name; do NOT inject tools.vision.model + # here, because by the time we reach this branch the tools.vision.model # routing has already been attempted (and either matched the main bot # or failed to find a provider). main_model_name = conf().get("model") or None diff --git a/agent/tools/web_search/web_search.py b/agent/tools/web_search/web_search.py index 4c6d1e45..ca56567d 100644 --- a/agent/tools/web_search/web_search.py +++ b/agent/tools/web_search/web_search.py @@ -1,13 +1,27 @@ -""" -Web Search tool - Search the web using Bocha or LinkAI search API. -Supports two backends with unified response format: - 1. Bocha Search (primary, requires BOCHA_API_KEY) - 2. LinkAI Search (fallback, requires LINKAI_API_KEY) +"""Web Search tool. Supports four backends with a unified response format: + - bocha (https://open.bochaai.com) + - zhipu (https://docs.bigmodel.cn/cn/guide/tools/web-search) + - qianfan (https://cloud.baidu.com/doc/qianfan/s/2mh4su4uy) + - linkai (https://link-ai.tech, fallback) + +Provider selection + - strategy 'auto' (default): pick the first configured provider in the + canonical order [bocha, zhipu, qianfan, linkai]. When the caller passes + an explicit `provider` it overrides the pick; an invalid/unconfigured + one silently falls back to the auto order. + - strategy 'fixed': use the configured provider; if its credential is + missing at call time, silently fall back to auto order (no card hint). + +Credentials + - bocha : tools.web_search.bocha_api_key -> env BOCHA_API_KEY + - zhipu : conf.zhipu_ai_api_key -> env ZHIPUAI_API_KEY + - qianfan : conf.qianfan_api_key -> env QIANFAN_API_KEY + - linkai : conf.linkai_api_key -> env LINKAI_API_KEY """ -import os import json -from typing import Dict, Any, Optional +import os +from typing import Any, Dict, List, Optional import requests @@ -16,12 +30,63 @@ from common.log import logger from config import conf -# Default timeout for API requests (seconds) DEFAULT_TIMEOUT = 30 +# Canonical fallback order. Empirically ordered by Chinese real-time +# quality + relevance: bocha (best overall), qianfan (best for hot news), +# zhipu (strong on long-form articles), linkai (cloud aggregator, last +# resort). +PROVIDER_ORDER = ("bocha", "qianfan", "zhipu", "linkai") + +PROVIDER_LABELS = { + "bocha": "Bocha", + "zhipu": "Zhipu", + "qianfan": "Baidu Qianfan", + "linkai": "LinkAI", +} + + +def _tools_web_search_conf() -> dict: + """Return the tools.web_search config block (dict-like).""" + tools_cfg = conf().get("tools") or {} + if not isinstance(tools_cfg, dict): + return {} + block = tools_cfg.get("web_search") or {} + return block if isinstance(block, dict) else {} + + +def _get_api_key(provider: str) -> str: + """Resolve API key for a provider, with conf -> env fallback.""" + if provider == "bocha": + key = (_tools_web_search_conf().get("bocha_api_key") or "").strip() + return key or os.environ.get("BOCHA_API_KEY", "").strip() + if provider == "zhipu": + key = (conf().get("zhipu_ai_api_key") or "").strip() + return key or os.environ.get("ZHIPUAI_API_KEY", "").strip() + if provider == "qianfan": + key = (conf().get("qianfan_api_key") or "").strip() + return key or os.environ.get("QIANFAN_API_KEY", "").strip() + if provider == "linkai": + key = (conf().get("linkai_api_key") or "").strip() + return key or os.environ.get("LINKAI_API_KEY", "").strip() + return "" + + +def configured_providers() -> List[str]: + """Return configured providers in canonical order.""" + return [p for p in PROVIDER_ORDER if _get_api_key(p)] + + +def _configured_strategy() -> str: + return (_tools_web_search_conf().get("strategy") or "auto").strip().lower() + + +def _configured_provider() -> str: + return (_tools_web_search_conf().get("provider") or "").strip().lower() + class WebSearch(BaseTool): - """Tool for searching the web using Bocha or LinkAI search API""" + """Tool for searching the web across multiple providers.""" name: str = "web_search" description: str = "Search the web for real-time information. Returns titles, URLs, and snippets." @@ -55,264 +120,368 @@ class WebSearch(BaseTool): def __init__(self, config: dict = None): self.config = config or {} - self._backend = None # Will be resolved on first execute @staticmethod def is_available() -> bool: - """Check if web search is available (at least one API key is configured)""" - return bool(os.environ.get("BOCHA_API_KEY") or os.environ.get("LINKAI_API_KEY")) + """Tool is offered to the agent when at least one provider has a key.""" + return bool(configured_providers()) - def _resolve_backend(self) -> Optional[str]: - """ - Determine which search backend to use. - Priority: Bocha > LinkAI + @classmethod + def get_json_schema(cls) -> dict: + """Augment the static schema with a `provider` field — only when the + user has ≥2 providers configured AND strategy is 'auto'. Otherwise + the backend picks silently and exposing the field would only waste + the agent's tokens.""" + schema = { + "name": cls.name, + "description": cls.description, + "parameters": json.loads(json.dumps(cls.params)), # deep copy + } + if _configured_strategy() != "auto": + return schema + available = configured_providers() + if len(available) < 2: + return schema - :return: 'bocha', 'linkai', or None + schema["parameters"]["properties"]["provider"] = { + "type": "string", + "enum": available, + "description": "Optional. Specifies the search backend. You may switch between providers when the user wants results from a particular source or from multiple sources.", + } + return schema + + # ------------------------------------------------------------------ + # Provider resolution + # ------------------------------------------------------------------ + + def _resolve_provider(self, requested: Optional[str]) -> Optional[str]: + """Pick a provider for this call. + + Priority: caller-supplied (if configured) > fixed strategy (if + configured) > first configured in PROVIDER_ORDER. Silent fallback + when the desired one has no key. """ - if os.environ.get("BOCHA_API_KEY"): - return "bocha" - if os.environ.get("LINKAI_API_KEY"): - return "linkai" - return None + available = configured_providers() + if not available: + return None + + if requested: + req = requested.strip().lower() + if req in available: + return req + logger.warning(f"[WebSearch] requested provider '{requested}' unavailable, falling back") + + if _configured_strategy() == "fixed": + pinned = _configured_provider() + if pinned in available: + return pinned + if pinned: + logger.warning(f"[WebSearch] pinned provider '{pinned}' unavailable, falling back to auto") + + return available[0] + + @staticmethod + def _resolution_reason(requested: Optional[str], chosen: str) -> str: + """Human-readable explanation for why `chosen` won the resolver.""" + if requested and requested.strip().lower() == chosen: + return "caller-requested" + strategy = _configured_strategy() + if strategy == "fixed" and _configured_provider() == chosen: + return "fixed-strategy" + return "auto-fallback" + + # ------------------------------------------------------------------ + # Entry point + # ------------------------------------------------------------------ def execute(self, args: Dict[str, Any]) -> ToolResult: - """ - Execute web search - - :param args: Search parameters (query, count, freshness, summary) - :return: Search results - """ - query = args.get("query", "").strip() + query = (args.get("query") or "").strip() if not query: return ToolResult.fail("Error: 'query' parameter is required") count = args.get("count", 10) freshness = args.get("freshness", "noLimit") summary = args.get("summary", False) - - # Validate count if not isinstance(count, int) or count < 1 or count > 50: count = 10 - # Resolve backend - backend = self._resolve_backend() - if not backend: + requested = args.get("provider") + provider = self._resolve_provider(requested) + if not provider: return ToolResult.fail( - "Error: No search API key configured. " - "Please set BOCHA_API_KEY or LINKAI_API_KEY using env_config tool.\n" - " - Bocha Search: https://open.bocha.cn\n" - " - LinkAI Search: https://link-ai.tech" + "Error: No search provider configured. " + "Configure one of BOCHA_API_KEY / zhipu_ai_api_key / qianfan_api_key / linkai_api_key." ) + # Always log the routing decision so multi-provider deployments can + # tell at a glance which backend served any given query. + available = configured_providers() + reason = self._resolution_reason(requested, provider) + q_preview = query if len(query) <= 60 else (query[:57] + "...") + logger.info( + f"[WebSearch] provider={provider} reason={reason} " + f"available={list(available)} query={q_preview!r} count={count} freshness={freshness}" + ) + try: - if backend == "bocha": + if provider == "bocha": return self._search_bocha(query, count, freshness, summary) - else: + if provider == "zhipu": + return self._search_zhipu(query, count, freshness) + if provider == "qianfan": + return self._search_qianfan(query, count, freshness) + if provider == "linkai": return self._search_linkai(query, count, freshness) + return ToolResult.fail(f"Error: Unknown provider '{provider}'") except requests.Timeout: return ToolResult.fail(f"Error: Search request timed out after {DEFAULT_TIMEOUT}s") except requests.ConnectionError: return ToolResult.fail("Error: Failed to connect to search API") except Exception as e: - logger.error(f"[WebSearch] Unexpected error: {e}", exc_info=True) + logger.error(f"[WebSearch] Unexpected error ({provider}): {e}", exc_info=True) return ToolResult.fail(f"Error: Search failed - {str(e)}") + # ------------------------------------------------------------------ + # Bocha + # ------------------------------------------------------------------ + def _search_bocha(self, query: str, count: int, freshness: str, summary: bool) -> ToolResult: - """ - Search using Bocha API - - :param query: Search query - :param count: Number of results - :param freshness: Time range filter - :param summary: Whether to include summary - :return: Formatted search results - """ - api_key = os.environ.get("BOCHA_API_KEY", "") - url = "https://api.bocha.cn/v1/web-search" - + api_key = _get_api_key("bocha") + url = "https://api.bochaai.com/v1/web-search" headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", - "Accept": "application/json" + "Accept": "application/json", } + payload = {"query": query, "count": count, "freshness": freshness, "summary": summary} - payload = { - "query": query, - "count": count, - "freshness": freshness, - "summary": summary - } + logger.debug(f"[WebSearch] bocha: query='{query}', count={count}") + resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT) - logger.debug(f"[WebSearch] Bocha search: query='{query}', count={count}") + if resp.status_code == 401: + return ToolResult.fail("Error: Invalid bocha API key.") + if resp.status_code == 403: + return ToolResult.fail("Error: bocha API — insufficient balance. Top up at https://open.bochaai.com") + if resp.status_code == 429: + return ToolResult.fail("Error: bocha API rate limit reached.") + if resp.status_code != 200: + return ToolResult.fail(f"Error: bocha API returned HTTP {resp.status_code}") - response = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT) - - if response.status_code == 401: - return ToolResult.fail("Error: Invalid BOCHA_API_KEY. Please check your API key.") - if response.status_code == 403: - return ToolResult.fail("Error: Bocha API - insufficient balance. Please top up at https://open.bocha.cn") - if response.status_code == 429: - return ToolResult.fail("Error: Bocha API rate limit reached. Please try again later.") - if response.status_code != 200: - return ToolResult.fail(f"Error: Bocha API returned HTTP {response.status_code}") - - data = response.json() - - # Check API-level error code + data = resp.json() api_code = data.get("code") if api_code is not None and api_code != 200: msg = data.get("msg") or "Unknown error" - return ToolResult.fail(f"Error: Bocha API error (code={api_code}): {msg}") - - # Extract and format results - return self._format_bocha_results(data, query) - - def _format_bocha_results(self, data: dict, query: str) -> ToolResult: - """ - Format Bocha API response into unified result structure - - :param data: Raw API response - :param query: Original query - :return: Formatted ToolResult - """ - search_data = data.get("data", {}) - web_pages = search_data.get("webPages", {}) - pages = web_pages.get("value", []) - - if not pages: - return ToolResult.success({ - "query": query, - "backend": "bocha", - "total": 0, - "results": [], - "message": "No results found" - }) + return ToolResult.fail(f"Error: bocha API error (code={api_code}): {msg}") + pages = (data.get("data") or {}).get("webPages", {}).get("value", []) or [] results = [] - for page in pages: - result = { - "title": page.get("name", ""), - "url": page.get("url", ""), - "snippet": page.get("snippet", ""), - "siteName": page.get("siteName", ""), - "datePublished": page.get("datePublished") or page.get("dateLastCrawled", ""), + for p in pages: + item = { + "title": p.get("name", ""), + "url": p.get("url", ""), + "snippet": p.get("snippet", ""), + "siteName": p.get("siteName", ""), + "datePublished": p.get("datePublished") or p.get("dateLastCrawled", ""), } - # Include summary only if present - if page.get("summary"): - result["summary"] = page["summary"] - results.append(result) - - total = web_pages.get("totalEstimatedMatches", len(results)) - + if p.get("summary"): + item["summary"] = p["summary"] + results.append(item) + total = (data.get("data") or {}).get("webPages", {}).get("totalEstimatedMatches", len(results)) return ToolResult.success({ - "query": query, - "backend": "bocha", - "total": total, - "count": len(results), - "results": results + "query": query, "backend": "bocha", + "total": total, "count": len(results), "results": results, }) - def _search_linkai(self, query: str, count: int, freshness: str) -> ToolResult: - """ - Search using LinkAI plugin API + # ------------------------------------------------------------------ + # Zhipu + # ------------------------------------------------------------------ - :param query: Search query - :param count: Number of results - :param freshness: Time range filter - :return: Formatted search results - """ - api_key = os.environ.get("LINKAI_API_KEY", "") - api_base = conf().get("linkai_api_base", "https://api.link-ai.tech") - url = f"{api_base.rstrip('/')}/v1/plugin/execute" + def _search_zhipu(self, query: str, count: int, freshness: str) -> ToolResult: + api_key = _get_api_key("zhipu") + api_base = (conf().get("zhipu_ai_api_base") or "https://open.bigmodel.cn/api/paas/v4").rstrip("/") + url = f"{api_base}/web_search" + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + + # Zhipu Web Search expects `search_query` <= 70 chars; truncate + # gracefully so a long agent-supplied query doesn't get rejected. + trimmed_query = (query or "")[:70] + engine = (_tools_web_search_conf().get("zhipu_search_engine") or "search_pro").strip().lower() + if engine not in ("search_std", "search_pro", "search_pro_sogou", "search_pro_quark"): + engine = "search_pro" + + payload: Dict[str, Any] = { + "search_engine": engine, + "search_query": trimmed_query, + "search_intent": False, + "count": max(1, min(int(count or 10), 50)), + "search_recency_filter": freshness if freshness in ( + "oneDay", "oneWeek", "oneMonth", "oneYear", "noLimit" + ) else "noLimit", + } + content_size = (_tools_web_search_conf().get("zhipu_content_size") or "").strip().lower() + if content_size in ("medium", "high"): + payload["content_size"] = content_size + + logger.debug(f"[WebSearch] zhipu: query='{trimmed_query}', count={payload['count']}, engine={engine}") + resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT) + + if resp.status_code == 401: + return ToolResult.fail("Error: Invalid Zhipu API key.") + if resp.status_code != 200: + return ToolResult.fail(f"Error: Zhipu API returned HTTP {resp.status_code}: {resp.text[:200]}") + + data = resp.json() + # Business-level errors (1701/1702/1703 etc.) come back as + # {"error": {"code","message"}} even on HTTP 200. + if isinstance(data, dict) and data.get("error"): + err = data["error"] or {} + return ToolResult.fail(f"Error: Zhipu returned {err.get('code')}: {err.get('message','')}") + + items = data.get("search_result") or (data.get("data") or {}).get("search_result") or [] + results = [] + for it in items: + results.append({ + "title": it.get("title", ""), + "url": it.get("link") or it.get("url", ""), + "snippet": it.get("content") or it.get("snippet", ""), + "siteName": it.get("media") or it.get("siteName", ""), + "datePublished": it.get("publish_date") or it.get("datePublished", ""), + }) + return ToolResult.success({ + "query": query, "backend": "zhipu", + "total": len(results), "count": len(results), "results": results, + }) + + # ------------------------------------------------------------------ + # Qianfan (Baidu) + # ------------------------------------------------------------------ + + def _search_qianfan(self, query: str, count: int, freshness: str) -> ToolResult: + api_key = _get_api_key("qianfan") + api_base = (conf().get("qianfan_api_base") or "https://qianfan.baidubce.com/v2").rstrip("/") + url = f"{api_base}/ai_search/web_search" + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "X-Appbuilder-From": "cow", + } + + count = max(1, min(int(count or 10), 50)) + payload: Dict[str, Any] = { + "messages": [{"role": "user", "content": query}], + "search_source": "baidu_search_v2", + "resource_type_filter": [{"type": "web", "top_k": count}], + } + + # Baidu AI Search expects freshness as a date-range filter, not a + # named recency token. Translate our shared vocabulary into the + # underlying page_time range expected by the API. + search_filter = self._qianfan_build_freshness_filter(freshness) + if search_filter: + payload["search_filter"] = search_filter + + logger.debug(f"[WebSearch] qianfan: query='{query}', count={count}, freshness={freshness!r}") + resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT) + + if resp.status_code == 401: + return ToolResult.fail("Error: Invalid Qianfan API key.") + if resp.status_code != 200: + return ToolResult.fail(f"Error: Qianfan API returned HTTP {resp.status_code}: {resp.text[:200]}") + + data = resp.json() + # Even on HTTP 200 Baidu surfaces business errors as {"code","message"}. + if isinstance(data, dict) and data.get("code"): + return ToolResult.fail(f"Error: Qianfan returned {data.get('code')}: {data.get('message','')}") + + refs = data.get("references") or [] + results = [] + for d in refs: + results.append({ + "title": d.get("title", ""), + "url": d.get("url", ""), + "snippet": (d.get("content") or "")[:200], + "siteName": d.get("web_anchor") or d.get("website") or "", + "datePublished": d.get("date", ""), + }) + return ToolResult.success({ + "query": query, "backend": "qianfan", + "total": len(results), "count": len(results), "results": results, + }) + + @staticmethod + def _qianfan_build_freshness_filter(freshness: str) -> Optional[Dict[str, Any]]: + if not freshness or freshness == "noLimit": + return None + delta_days = {"oneDay": 1, "oneWeek": 7, "oneMonth": 30, "oneYear": 365}.get(freshness) + if not delta_days: + return None + from datetime import datetime, timedelta + now = datetime.now() + end_date = (now + timedelta(days=1)).strftime("%Y-%m-%d") + start_date = (now - timedelta(days=delta_days)).strftime("%Y-%m-%d") + return {"range": {"page_time": {"gte": start_date, "lt": end_date}}} + + # ------------------------------------------------------------------ + # LinkAI (plugin) + # ------------------------------------------------------------------ + + def _search_linkai(self, query: str, count: int, freshness: str) -> ToolResult: + api_key = _get_api_key("linkai") + api_base = (conf().get("linkai_api_base") or "https://api.link-ai.tech").rstrip("/") + url = f"{api_base}/v1/plugin/execute" from common.utils import get_cloud_headers headers = get_cloud_headers(api_key) - payload = { - "code": "web-search", - "args": { - "query": query, - "count": count, - "freshness": freshness - } - } + payload = {"code": "web-search", "args": {"query": query, "count": count, "freshness": freshness}} + logger.debug(f"[WebSearch] linkai: query='{query}', count={count}") + resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT) - logger.debug(f"[WebSearch] LinkAI search: query='{query}', count={count}") - - response = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT) - - if response.status_code == 401: - return ToolResult.fail("Error: Invalid LINKAI_API_KEY. Please check your API key.") - if response.status_code != 200: - return ToolResult.fail(f"Error: LinkAI API returned HTTP {response.status_code}") - - data = response.json() + if resp.status_code == 401: + return ToolResult.fail("Error: Invalid LinkAI API key.") + if resp.status_code != 200: + return ToolResult.fail(f"Error: LinkAI API returned HTTP {resp.status_code}") + data = resp.json() if not data.get("success"): msg = data.get("message") or "Unknown error" return ToolResult.fail(f"Error: LinkAI search failed: {msg}") - return self._format_linkai_results(data, query) - - def _format_linkai_results(self, data: dict, query: str) -> ToolResult: - """ - Format LinkAI API response into unified result structure. - LinkAI returns the search data in data.data field, which follows - the same Bing-compatible format as Bocha. - - :param data: Raw API response - :param query: Original query - :return: Formatted ToolResult - """ - raw_data = data.get("data", "") - - # LinkAI may return data as a JSON string - if isinstance(raw_data, str): + raw = data.get("data", "") + if isinstance(raw, str): try: - raw_data = json.loads(raw_data) + raw = json.loads(raw) except (json.JSONDecodeError, TypeError): - # If data is plain text, return it as a single result return ToolResult.success({ - "query": query, - "backend": "linkai", - "total": 1, - "count": 1, - "results": [{"content": raw_data}] + "query": query, "backend": "linkai", + "total": 1, "count": 1, "results": [{"content": raw}], }) - # If the response follows Bing-compatible structure - if isinstance(raw_data, dict): - web_pages = raw_data.get("webPages", {}) - pages = web_pages.get("value", []) - + if isinstance(raw, dict): + pages = (raw.get("webPages") or {}).get("value", []) or [] if pages: results = [] - for page in pages: - result = { - "title": page.get("name", ""), - "url": page.get("url", ""), - "snippet": page.get("snippet", ""), - "siteName": page.get("siteName", ""), - "datePublished": page.get("datePublished") or page.get("dateLastCrawled", ""), + for p in pages: + item = { + "title": p.get("name", ""), + "url": p.get("url", ""), + "snippet": p.get("snippet", ""), + "siteName": p.get("siteName", ""), + "datePublished": p.get("datePublished") or p.get("dateLastCrawled", ""), } - if page.get("summary"): - result["summary"] = page["summary"] - results.append(result) - - total = web_pages.get("totalEstimatedMatches", len(results)) + if p.get("summary"): + item["summary"] = p["summary"] + results.append(item) + total = (raw.get("webPages") or {}).get("totalEstimatedMatches", len(results)) return ToolResult.success({ - "query": query, - "backend": "linkai", - "total": total, - "count": len(results), - "results": results + "query": query, "backend": "linkai", + "total": total, "count": len(results), "results": results, }) - # Fallback: return raw data return ToolResult.success({ - "query": query, - "backend": "linkai", - "total": 1, - "count": 1, - "results": [{"content": str(raw_data)}] + "query": query, "backend": "linkai", + "total": 1, "count": 1, "results": [{"content": str(raw)}], }) diff --git a/bridge/agent_initializer.py b/bridge/agent_initializer.py index d17dcb0c..7d5afb4a 100644 --- a/bridge/agent_initializer.py +++ b/bridge/agent_initializer.py @@ -521,7 +521,7 @@ class AgentInitializer: if tool_name == "web_search": from agent.tools.web_search.web_search import WebSearch if not WebSearch.is_available(): - logger.debug("[AgentInitializer] WebSearch skipped - no BOCHA_API_KEY or LINKAI_API_KEY") + logger.debug("[AgentInitializer] WebSearch skipped - no search provider configured") continue # Special handling for EnvConfig tool diff --git a/bridge/bridge.py b/bridge/bridge.py index 753e394a..c0cb62e4 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -14,7 +14,9 @@ class Bridge(object): def __init__(self): self.btype = { "chat": const.OPENAI, - "voice_to_text": conf().get("voice_to_text", "openai"), + # Empty `voice_to_text` (the default in new configs) triggers + # the auto-pick below — see _auto_pick_voice_to_text for order. + "voice_to_text": conf().get("voice_to_text") or self._auto_pick_voice_to_text(), "text_to_voice": conf().get("text_to_voice", "google"), "translate": conf().get("translate", "baidu"), } @@ -84,6 +86,46 @@ class Bridge(object): self.chat_bots = {} self._agent_bridge = None + def refresh_voice(self): + """Re-read voice_to_text / text_to_voice from config and drop the + cached voice bots so the next call picks up the new provider. + Used by the web console after the user edits voice settings. + Does NOT touch the agent_bridge / agent state. + """ + new_v2t = conf().get("voice_to_text") or self._auto_pick_voice_to_text() + new_t2v = conf().get("text_to_voice", "google") + if conf().get("use_linkai") and conf().get("linkai_api_key"): + if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]: + new_v2t = const.LINKAI + if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]: + new_t2v = const.LINKAI + self.btype["voice_to_text"] = new_v2t + self.btype["text_to_voice"] = new_t2v + self.bots.pop("voice_to_text", None) + self.bots.pop("text_to_voice", None) + logger.info(f"[Bridge] voice refreshed: voice_to_text={new_v2t}, text_to_voice={new_t2v}") + + @staticmethod + def _auto_pick_voice_to_text() -> str: + """Pick an ASR provider by configured api keys when voice_to_text is + unset. Order matches the web console: openai → dashscope → zhipu → + linkai. Falls back to 'openai' when nothing is configured so the + original "missing key" error is preserved. + """ + def has(k: str) -> bool: + v = (conf().get(k) or "").strip() + return v != "" and v not in ("YOUR API KEY", "YOUR_API_KEY") + + for key, provider in ( + ("open_ai_api_key", "openai"), + ("dashscope_api_key", "dashscope"), + ("zhipu_ai_api_key", "zhipu"), + ("linkai_api_key", "linkai"), + ): + if has(key): + return provider + return "openai" + # 模型对应的接口 def get_bot(self, typename): if self.bots.get(typename) is None: diff --git a/channel/chat_channel.py b/channel/chat_channel.py index 3251c286..c38dd7c8 100644 --- a/channel/chat_channel.py +++ b/channel/chat_channel.py @@ -171,7 +171,13 @@ class ChatChannel(Channel): if "desire_rtype" not in context and conf().get("always_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: context["desire_rtype"] = ReplyType.VOICE elif context.type == ContextType.VOICE: - if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: + # Voice input replies with voice when either voice_reply_voice + # (mirror voice) or the global always_reply_voice toggle is on. + if ( + "desire_rtype" not in context + and (conf().get("voice_reply_voice") or conf().get("always_reply_voice")) + and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE + ): context["desire_rtype"] = ReplyType.VOICE return context @@ -264,6 +270,8 @@ class ChatChannel(Channel): if reply.type == ReplyType.TEXT: reply_text = reply.content if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: + # Preserve original text for the "text-then-voice" pattern in _send_reply. + context["voice_reply_text"] = reply.content reply = super().build_text_to_voice(reply.content) return self._decorate_reply(context, reply) if context.get("isgroup", False): @@ -311,6 +319,15 @@ class ChatChannel(Channel): # 短暂延迟后发送图片 time.sleep(0.3) self._send(reply, context) + # Send text bubble before voice, unless channel already streamed + # the text (feishu) or natively renders STT under the voice (wechatcom). + elif reply.type == ReplyType.VOICE and context.get("voice_reply_text") \ + and not context.get("feishu_streamed") \ + and context.get("channel_type") not in ("wechatcom_app",): + text_reply = Reply(ReplyType.TEXT, context.get("voice_reply_text")) + self._send(text_reply, context) + time.sleep(0.3) + self._send(reply, context) else: self._send(reply, context) diff --git a/channel/dingtalk/dingtalk_channel.py b/channel/dingtalk/dingtalk_channel.py index d572e35d..b1ae86c2 100644 --- a/channel/dingtalk/dingtalk_channel.py +++ b/channel/dingtalk/dingtalk_channel.py @@ -86,6 +86,8 @@ def _check(func): @singleton class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler): + NOT_SUPPORT_REPLYTYPE = [] + dingtalk_client_id = conf().get('dingtalk_client_id') dingtalk_client_secret = conf().get('dingtalk_client_secret') @@ -870,6 +872,48 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler): self.reply_text("抱歉,文件上传失败", incoming_message) return + # Native sampleAudio. Upload only accepts ogg/amr, so convert TTS mp3/wav to amr. + elif reply.type == ReplyType.VOICE: + logger.info(f"[DingTalk] Sending voice: {reply.content}") + access_token = self.get_access_token() + if not access_token: + logger.error("[DingTalk] Cannot get access token for voice") + self.reply_text("抱歉,语音发送失败(无法获取token)", incoming_message) + return + + voice_path = reply.content + if voice_path.startswith("file://"): + voice_path = voice_path[7:] + + amr_path = voice_path + duration_ms = 0 + if not voice_path.lower().endswith((".amr", ".ogg")): + try: + from voice.audio_convert import any_to_amr + amr_path = os.path.splitext(voice_path)[0] + ".amr" + duration_ms = int(any_to_amr(voice_path, amr_path) or 0) + except Exception as e: + logger.error(f"[DingTalk] Failed to convert voice to amr: {e}") + self.reply_text("抱歉,语音转码失败", incoming_message) + return + + media_id = self.upload_media(amr_path, media_type="voice") + if not media_id: + logger.error("[DingTalk] Failed to upload voice media") + self.reply_text("抱歉,语音上传失败", incoming_message) + return + + msg_param = { + "mediaId": media_id, + "duration": str(duration_ms or 1000), + } + success = self._send_file_message( + access_token, incoming_message, "sampleAudio", msg_param, isgroup + ) + if not success: + self.reply_text("抱歉,语音发送失败", incoming_message) + return + # 处理文本消息 elif reply.type == ReplyType.TEXT: logger.info(f"[DingTalk] Sending text message, length={len(reply.content)}") diff --git a/channel/feishu/feishu_channel.py b/channel/feishu/feishu_channel.py index f479394a..ca18e64b 100644 --- a/channel/feishu/feishu_channel.py +++ b/channel/feishu/feishu_channel.py @@ -1515,10 +1515,16 @@ class FeiShuChanel(ChatChannel): else: context.type = ContextType.TEXT context.content = content.strip() + # Text input opts into voice replies only when the always-on toggle is set. + if "desire_rtype" not in context and conf().get("always_reply_voice"): + context["desire_rtype"] = ReplyType.VOICE elif context.type == ContextType.VOICE: - # 2.语音请求 - if "desire_rtype" not in context and conf().get("voice_reply_voice"): + # 2.语音请求: voice input replies with voice if either + # voice_reply_voice (mirror reply) or always_reply_voice is on. + if "desire_rtype" not in context and ( + conf().get("voice_reply_voice") or conf().get("always_reply_voice") + ): context["desire_rtype"] = ReplyType.VOICE return context diff --git a/channel/web/chat.html b/channel/web/chat.html index 56ce808f..947e07b7 100644 --- a/channel/web/chat.html +++ b/channel/web/chat.html @@ -137,6 +137,11 @@ 配置 + + + 模型 + @@ -417,15 +422,24 @@ - +
+ + +
+ +
+ Loading... +
+ + + + + @@ -959,7 +1013,7 @@ - `; @@ -1481,11 +1911,12 @@ function startSSE(requestId, loadingEl, timestamp, titleInfo) { scrollChatToBottom(); } else if (item.type === 'done') { + // Don't close the stream yet: the backend keeps it open + // for a short tail to deliver async attachments such as + // TTS audio (`voice_attach`). It will close the stream on + // its own via onerror once the tail expires. done = true; - es.close(); - delete activeStreams[requestId]; - // item.content may be empty when "done" is only a stream-close signal after media. const finalText = item.content || accumulatedText; if (!botEl && finalText) { @@ -1499,6 +1930,7 @@ function startSSE(requestId, loadingEl, timestamp, titleInfo) { if (copyBtn && finalText) copyBtn.style.display = ''; applyHighlighting(botEl); } + renderBotSpeakerButton(botEl, finalText); scrollChatToBottom(); if (titleInfo) { @@ -1508,6 +1940,15 @@ function startSSE(requestId, loadingEl, timestamp, titleInfo) { loadSessionList(); } + } else if (item.type === 'voice_attach') { + // TTS finished — attach a playable audio element to the + // current bot bubble. The stream closes right after. + if (botEl && item.url) { + attachAudioToBotBubble(botEl, item.url, { autoplay: true }); + } + es.close(); + delete activeStreams[requestId]; + } else if (item.type === 'error') { done = true; es.close(); @@ -1521,7 +1962,10 @@ function startSSE(requestId, loadingEl, timestamp, titleInfo) { es.close(); delete activeStreams[requestId]; - if (done) return; + if (done) { + // Normal close after the post-done tail expired; nothing to do. + return; + } if (currentReasoningEl) { finalizeThinking(currentReasoningEl, reasoningStartTime, reasoningText); @@ -1812,21 +2256,174 @@ function createBotMessageEl(content, timestamp, requestId, msg) {
${stepsHtml ? `
${stepsHtml}
` : ''}
${renderMarkdown(displayContent)}
+
${formatTime(timestamp)} +
`; el.querySelector('.answer-content').dataset.rawMd = displayContent; + // Existing TTS attachment (history replay): mount the player up-front. + const existingAudio = msg && msg.extras && msg.extras.audio && msg.extras.audio.url; + if (existingAudio) { + attachAudioToBotBubble(el, existingAudio, { autoplay: false }); + } + renderBotSpeakerButton(el, displayContent); applyHighlighting(el); bindChatKnowledgeLinks(el); return el; } +// Append (or replace) a small audio player inside a bot bubble's +// dedicated `.bot-audio-slot`. Used by both live TTS pushes and history +// replay. Silent failures: never throws. +function attachAudioToBotBubble(botEl, audioUrl, opts) { + try { + if (!botEl || !audioUrl) return; + const slot = botEl.querySelector('.bot-audio-slot'); + if (!slot) return; + slot.innerHTML = ''; + slot.style.marginTop = '6px'; + const pill = renderVoicePill(audioUrl, { autoplay: !!(opts && opts.autoplay) }); + slot.appendChild(pill); + const speakBtn = botEl.querySelector('.speak-msg-btn'); + if (speakBtn) speakBtn.style.display = 'none'; + } catch (_) { /* silent */ } +} + +// Build a compact play/pause + progress + duration pill that wraps a +// hidden