feat: add request header

This commit is contained in:
zhayujie
2026-03-19 17:06:05 +08:00
parent 1b5be1b981
commit b4e711f411
8 changed files with 85 additions and 25 deletions

View File

@@ -44,6 +44,11 @@ class ChatService:
if agent is None:
raise RuntimeError("Failed to initialise agent for the session")
# Pass context metadata to model for downstream API requests
if hasattr(agent, 'model'):
agent.model.channel_type = channel_type or ""
agent.model.session_id = session_id or ""
# State shared between the event callback and this method
state = _StreamState()

View File

@@ -32,18 +32,21 @@ class EmbeddingProvider(ABC):
class OpenAIEmbeddingProvider(EmbeddingProvider):
"""OpenAI embedding provider using REST API"""
def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, api_base: Optional[str] = None):
def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None,
api_base: Optional[str] = None, extra_headers: Optional[dict] = None):
"""
Initialize OpenAI embedding provider
Args:
model: Model name (text-embedding-3-small or text-embedding-3-large)
api_key: OpenAI API key
api_base: Optional API base URL
extra_headers: Optional extra headers to include in API requests
"""
self.model = model
self.api_key = api_key
self.api_base = api_base or "https://api.openai.com/v1"
self.extra_headers = extra_headers or {}
# Validate API key
if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
@@ -59,7 +62,8 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
url = f"{self.api_base}/embeddings"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
"Authorization": f"Bearer {self.api_key}",
**self.extra_headers,
}
data = {
"input": input_data,
@@ -134,7 +138,8 @@ def create_embedding_provider(
provider: str = "openai",
model: Optional[str] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None
api_base: Optional[str] = None,
extra_headers: Optional[dict] = None
) -> EmbeddingProvider:
"""
Factory function to create embedding provider
@@ -147,10 +152,11 @@ def create_embedding_provider(
model: Model name (default: text-embedding-3-small)
api_key: API key (required)
api_base: API base URL
extra_headers: Optional extra headers to include in API requests
Returns:
EmbeddingProvider instance
Raises:
ValueError: If provider is unsupported or api_key is missing
"""
@@ -158,4 +164,4 @@ def create_embedding_provider(
raise ValueError(f"Unsupported embedding provider: {provider}. Use 'openai' or 'linkai'.")
model = model or "text-embedding-3-small"
return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base)
return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base, extra_headers=extra_headers)

View File

@@ -76,11 +76,15 @@ class MemoryManager:
linkai_key = os.environ.get('LINKAI_API_KEY')
linkai_base = os.environ.get('LINKAI_API_BASE', 'https://api.link-ai.tech')
if linkai_key:
from common.utils import get_cloud_headers
cloud_headers = get_cloud_headers(linkai_key)
cloud_headers.pop("Authorization", None)
self.embedding_provider = create_embedding_provider(
provider="linkai",
model=self.config.embedding_model,
api_key=linkai_key,
api_base=f"{linkai_base}/v1"
api_base=f"{linkai_base}/v1",
extra_headers=cloud_headers,
)
except Exception as e:
from common.log import logger

View File

@@ -82,7 +82,7 @@ class Vision(BaseTool):
if not question:
return ToolResult.fail("Error: 'question' parameter is required")
api_key, api_base = self._resolve_provider()
api_key, api_base, extra_headers = self._resolve_provider()
if not api_key:
return ToolResult.fail(
"Error: No API key configured for Vision.\n"
@@ -98,7 +98,7 @@ class Vision(BaseTool):
return ToolResult.fail(f"Error: {e}")
try:
return self._call_api(api_key, api_base, model, question, image_content)
return self._call_api(api_key, api_base, model, question, image_content, extra_headers)
except requests.Timeout:
return ToolResult.fail(f"Error: Vision API request timed out after {DEFAULT_TIMEOUT}s")
except requests.ConnectionError:
@@ -107,22 +107,26 @@ class Vision(BaseTool):
logger.error(f"[Vision] Unexpected error: {e}", exc_info=True)
return ToolResult.fail(f"Error: Vision API call failed - {e}")
def _resolve_provider(self) -> Tuple[Optional[str], str]:
"""Resolve API key and base URL. Priority: conf() > env vars."""
def _resolve_provider(self) -> Tuple[Optional[str], str, dict]:
"""Resolve API key, base URL and extra headers. Priority: conf() > env vars."""
api_key = conf().get("open_ai_api_key") or os.environ.get("OPENAI_API_KEY")
if api_key:
api_base = (conf().get("open_ai_api_base") or os.environ.get("OPENAI_API_BASE", "")).rstrip("/") \
or "https://api.openai.com/v1"
return api_key, self._ensure_v1(api_base)
return api_key, self._ensure_v1(api_base), {}
api_key = conf().get("linkai_api_key") or os.environ.get("LINKAI_API_KEY")
if api_key:
api_base = (conf().get("linkai_api_base") or os.environ.get("LINKAI_API_BASE", "")).rstrip("/") \
or "https://api.link-ai.tech"
logger.debug("[Vision] Using LinkAI API (OPENAI_API_KEY not set)")
return api_key, self._ensure_v1(api_base)
from common.utils import get_cloud_headers
extra = get_cloud_headers(api_key)
extra.pop("Authorization", None)
extra.pop("Content-Type", None)
return api_key, self._ensure_v1(api_base), extra
return None, ""
return None, "", {}
@staticmethod
def _ensure_v1(api_base: str) -> str:
@@ -197,7 +201,7 @@ class Vision(BaseTool):
return path
def _call_api(self, api_key: str, api_base: str, model: str,
question: str, image_content: dict) -> ToolResult:
question: str, image_content: dict, extra_headers: dict = None) -> ToolResult:
payload = {
"model": model,
"messages": [
@@ -215,6 +219,7 @@ class Vision(BaseTool):
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
**(extra_headers or {}),
}
resp = requests.post(

View File

@@ -225,10 +225,8 @@ class WebSearch(BaseTool):
api_base = conf().get("linkai_api_base", "https://api.link-ai.tech")
url = f"{api_base.rstrip('/')}/v1/plugin/execute"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}"
}
from common.utils import get_cloud_headers
headers = get_cloud_headers(api_key)
payload = {
"code": "web-search",

View File

@@ -152,12 +152,20 @@ class AgentLLMModel(LLMModel):
# Only pass max_tokens if it's explicitly set
if request.max_tokens is not None:
kwargs['max_tokens'] = request.max_tokens
# Extract system prompt if present
system_prompt = getattr(request, 'system', None)
if system_prompt:
kwargs['system'] = system_prompt
# Pass context metadata to bot
channel_type = getattr(self, 'channel_type', None)
if channel_type:
kwargs['channel_type'] = channel_type
session_id = getattr(self, 'session_id', None)
if session_id:
kwargs['session_id'] = session_id
response = self.bot.call_with_tools(**kwargs)
return self._format_response(response)
else:
@@ -195,10 +203,13 @@ class AgentLLMModel(LLMModel):
if system_prompt:
kwargs['system'] = system_prompt
# Pass channel_type for linkai tracking
# Pass context metadata to bot
channel_type = getattr(self, 'channel_type', None)
if channel_type:
kwargs['channel_type'] = channel_type
session_id = getattr(self, 'session_id', None)
if session_id:
kwargs['session_id'] = session_id
stream = self.bot.call_with_tools(**kwargs)
@@ -375,9 +386,10 @@ class AgentBridge:
logger.warning(f"[AgentBridge] Failed to attach context to scheduler: {e}")
break
# Pass channel_type to model so linkai requests carry it
# Pass context metadata to model for downstream API requests
if context and hasattr(agent, 'model'):
agent.model.channel_type = context.get("channel_type", "")
agent.model.session_id = session_id or ""
# Store session_id on agent so executor can clear DB on fatal errors
agent._current_session_id = session_id

View File

@@ -115,3 +115,22 @@ def expand_path(path: str) -> str:
expanded = os.path.join(home, path[2:])
return expanded
def get_cloud_headers(api_key: str) -> dict:
"""
Build standard headers for LinkAI API requests,
including client_id when available.
"""
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
}
try:
from linkai import LinkAIClient
client_id = LinkAIClient.fetch_client_id()
if client_id:
headers["X-Client-Id"] = client_id
except Exception:
pass
return headers

View File

@@ -534,6 +534,7 @@ def _linkai_call_with_tools(self, messages, tools=None, stream=False, **kwargs):
else:
channel_type = raw_ct
session_id = kwargs.get("session_id", "")
body = {
"messages": messages,
"model": kwargs.get("model", conf().get("model") or "gpt-3.5-turbo"),
@@ -543,12 +544,22 @@ def _linkai_call_with_tools(self, messages, tools=None, stream=False, **kwargs):
"presence_penalty": kwargs.get("presence_penalty", conf().get("presence_penalty", 0.0)),
"stream": stream,
"channel_type": kwargs.get("channel_type", channel_type),
"session_id": session_id,
"sender_id": session_id,
}
try:
from linkai import LinkAIClient
client_id = LinkAIClient.fetch_client_id()
if client_id:
body["client_id"] = client_id
except Exception:
pass
if tools:
body["tools"] = tools
body["tool_choice"] = kwargs.get("tool_choice", "auto")
# Prepare headers
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
base_url = conf().get("linkai_api_base", "https://api.link-ai.tech")