diff --git a/agent/chat/service.py b/agent/chat/service.py index b6a9ae67..d3712dbf 100644 --- a/agent/chat/service.py +++ b/agent/chat/service.py @@ -44,6 +44,11 @@ class ChatService: if agent is None: raise RuntimeError("Failed to initialise agent for the session") + # Pass context metadata to model for downstream API requests + if hasattr(agent, 'model'): + agent.model.channel_type = channel_type or "" + agent.model.session_id = session_id or "" + # State shared between the event callback and this method state = _StreamState() diff --git a/agent/memory/embedding.py b/agent/memory/embedding.py index ea8c7538..1bc1c671 100644 --- a/agent/memory/embedding.py +++ b/agent/memory/embedding.py @@ -32,18 +32,21 @@ class EmbeddingProvider(ABC): class OpenAIEmbeddingProvider(EmbeddingProvider): """OpenAI embedding provider using REST API""" - def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, api_base: Optional[str] = None): + def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, + api_base: Optional[str] = None, extra_headers: Optional[dict] = None): """ Initialize OpenAI embedding provider - + Args: model: Model name (text-embedding-3-small or text-embedding-3-large) api_key: OpenAI API key api_base: Optional API base URL + extra_headers: Optional extra headers to include in API requests """ self.model = model self.api_key = api_key self.api_base = api_base or "https://api.openai.com/v1" + self.extra_headers = extra_headers or {} # Validate API key if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]: @@ -59,7 +62,8 @@ class OpenAIEmbeddingProvider(EmbeddingProvider): url = f"{self.api_base}/embeddings" headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}" + "Authorization": f"Bearer {self.api_key}", + **self.extra_headers, } data = { "input": input_data, @@ -134,7 +138,8 @@ def create_embedding_provider( provider: str = "openai", model: Optional[str] = None, api_key: Optional[str] = None, - api_base: Optional[str] = None + api_base: Optional[str] = None, + extra_headers: Optional[dict] = None ) -> EmbeddingProvider: """ Factory function to create embedding provider @@ -147,10 +152,11 @@ def create_embedding_provider( model: Model name (default: text-embedding-3-small) api_key: API key (required) api_base: API base URL - + extra_headers: Optional extra headers to include in API requests + Returns: EmbeddingProvider instance - + Raises: ValueError: If provider is unsupported or api_key is missing """ @@ -158,4 +164,4 @@ def create_embedding_provider( raise ValueError(f"Unsupported embedding provider: {provider}. Use 'openai' or 'linkai'.") model = model or "text-embedding-3-small" - return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base) + return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base, extra_headers=extra_headers) diff --git a/agent/memory/manager.py b/agent/memory/manager.py index bb007fc6..197c9ffd 100644 --- a/agent/memory/manager.py +++ b/agent/memory/manager.py @@ -76,11 +76,15 @@ class MemoryManager: linkai_key = os.environ.get('LINKAI_API_KEY') linkai_base = os.environ.get('LINKAI_API_BASE', 'https://api.link-ai.tech') if linkai_key: + from common.utils import get_cloud_headers + cloud_headers = get_cloud_headers(linkai_key) + cloud_headers.pop("Authorization", None) self.embedding_provider = create_embedding_provider( provider="linkai", model=self.config.embedding_model, api_key=linkai_key, - api_base=f"{linkai_base}/v1" + api_base=f"{linkai_base}/v1", + extra_headers=cloud_headers, ) except Exception as e: from common.log import logger diff --git a/agent/tools/vision/vision.py b/agent/tools/vision/vision.py index 53074141..8b06c437 100644 --- a/agent/tools/vision/vision.py +++ b/agent/tools/vision/vision.py @@ -82,7 +82,7 @@ class Vision(BaseTool): if not question: return ToolResult.fail("Error: 'question' parameter is required") - api_key, api_base = self._resolve_provider() + api_key, api_base, extra_headers = self._resolve_provider() if not api_key: return ToolResult.fail( "Error: No API key configured for Vision.\n" @@ -98,7 +98,7 @@ class Vision(BaseTool): return ToolResult.fail(f"Error: {e}") try: - return self._call_api(api_key, api_base, model, question, image_content) + return self._call_api(api_key, api_base, model, question, image_content, extra_headers) except requests.Timeout: return ToolResult.fail(f"Error: Vision API request timed out after {DEFAULT_TIMEOUT}s") except requests.ConnectionError: @@ -107,22 +107,26 @@ class Vision(BaseTool): logger.error(f"[Vision] Unexpected error: {e}", exc_info=True) return ToolResult.fail(f"Error: Vision API call failed - {e}") - def _resolve_provider(self) -> Tuple[Optional[str], str]: - """Resolve API key and base URL. Priority: conf() > env vars.""" + def _resolve_provider(self) -> Tuple[Optional[str], str, dict]: + """Resolve API key, base URL and extra headers. Priority: conf() > env vars.""" api_key = conf().get("open_ai_api_key") or os.environ.get("OPENAI_API_KEY") if api_key: api_base = (conf().get("open_ai_api_base") or os.environ.get("OPENAI_API_BASE", "")).rstrip("/") \ or "https://api.openai.com/v1" - return api_key, self._ensure_v1(api_base) + return api_key, self._ensure_v1(api_base), {} api_key = conf().get("linkai_api_key") or os.environ.get("LINKAI_API_KEY") if api_key: api_base = (conf().get("linkai_api_base") or os.environ.get("LINKAI_API_BASE", "")).rstrip("/") \ or "https://api.link-ai.tech" logger.debug("[Vision] Using LinkAI API (OPENAI_API_KEY not set)") - return api_key, self._ensure_v1(api_base) + from common.utils import get_cloud_headers + extra = get_cloud_headers(api_key) + extra.pop("Authorization", None) + extra.pop("Content-Type", None) + return api_key, self._ensure_v1(api_base), extra - return None, "" + return None, "", {} @staticmethod def _ensure_v1(api_base: str) -> str: @@ -197,7 +201,7 @@ class Vision(BaseTool): return path def _call_api(self, api_key: str, api_base: str, model: str, - question: str, image_content: dict) -> ToolResult: + question: str, image_content: dict, extra_headers: dict = None) -> ToolResult: payload = { "model": model, "messages": [ @@ -215,6 +219,7 @@ class Vision(BaseTool): headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", + **(extra_headers or {}), } resp = requests.post( diff --git a/agent/tools/web_search/web_search.py b/agent/tools/web_search/web_search.py index 47d4d90b..4c6d1e45 100644 --- a/agent/tools/web_search/web_search.py +++ b/agent/tools/web_search/web_search.py @@ -225,10 +225,8 @@ class WebSearch(BaseTool): api_base = conf().get("linkai_api_base", "https://api.link-ai.tech") url = f"{api_base.rstrip('/')}/v1/plugin/execute" - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}" - } + from common.utils import get_cloud_headers + headers = get_cloud_headers(api_key) payload = { "code": "web-search", diff --git a/bridge/agent_bridge.py b/bridge/agent_bridge.py index 7e5e605a..44d59b7f 100644 --- a/bridge/agent_bridge.py +++ b/bridge/agent_bridge.py @@ -152,12 +152,20 @@ class AgentLLMModel(LLMModel): # Only pass max_tokens if it's explicitly set if request.max_tokens is not None: kwargs['max_tokens'] = request.max_tokens - + # Extract system prompt if present system_prompt = getattr(request, 'system', None) if system_prompt: kwargs['system'] = system_prompt - + + # Pass context metadata to bot + channel_type = getattr(self, 'channel_type', None) + if channel_type: + kwargs['channel_type'] = channel_type + session_id = getattr(self, 'session_id', None) + if session_id: + kwargs['session_id'] = session_id + response = self.bot.call_with_tools(**kwargs) return self._format_response(response) else: @@ -195,10 +203,13 @@ class AgentLLMModel(LLMModel): if system_prompt: kwargs['system'] = system_prompt - # Pass channel_type for linkai tracking + # Pass context metadata to bot channel_type = getattr(self, 'channel_type', None) if channel_type: kwargs['channel_type'] = channel_type + session_id = getattr(self, 'session_id', None) + if session_id: + kwargs['session_id'] = session_id stream = self.bot.call_with_tools(**kwargs) @@ -375,9 +386,10 @@ class AgentBridge: logger.warning(f"[AgentBridge] Failed to attach context to scheduler: {e}") break - # Pass channel_type to model so linkai requests carry it + # Pass context metadata to model for downstream API requests if context and hasattr(agent, 'model'): agent.model.channel_type = context.get("channel_type", "") + agent.model.session_id = session_id or "" # Store session_id on agent so executor can clear DB on fatal errors agent._current_session_id = session_id diff --git a/common/utils.py b/common/utils.py index c7bcb7a3..812b20ab 100644 --- a/common/utils.py +++ b/common/utils.py @@ -115,3 +115,22 @@ def expand_path(path: str) -> str: expanded = os.path.join(home, path[2:]) return expanded + + +def get_cloud_headers(api_key: str) -> dict: + """ + Build standard headers for LinkAI API requests, + including client_id when available. + """ + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + } + try: + from linkai import LinkAIClient + client_id = LinkAIClient.fetch_client_id() + if client_id: + headers["X-Client-Id"] = client_id + except Exception: + pass + return headers diff --git a/models/linkai/link_ai_bot.py b/models/linkai/link_ai_bot.py index 73891878..dd91f3db 100644 --- a/models/linkai/link_ai_bot.py +++ b/models/linkai/link_ai_bot.py @@ -534,6 +534,7 @@ def _linkai_call_with_tools(self, messages, tools=None, stream=False, **kwargs): else: channel_type = raw_ct + session_id = kwargs.get("session_id", "") body = { "messages": messages, "model": kwargs.get("model", conf().get("model") or "gpt-3.5-turbo"), @@ -543,12 +544,22 @@ def _linkai_call_with_tools(self, messages, tools=None, stream=False, **kwargs): "presence_penalty": kwargs.get("presence_penalty", conf().get("presence_penalty", 0.0)), "stream": stream, "channel_type": kwargs.get("channel_type", channel_type), + "session_id": session_id, + "sender_id": session_id, } + try: + from linkai import LinkAIClient + client_id = LinkAIClient.fetch_client_id() + if client_id: + body["client_id"] = client_id + except Exception: + pass + if tools: body["tools"] = tools body["tool_choice"] = kwargs.get("tool_choice", "auto") - + # Prepare headers headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")} base_url = conf().get("linkai_api_base", "https://api.link-ai.tech")