mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
feat: add request header
This commit is contained in:
@@ -44,6 +44,11 @@ class ChatService:
|
||||
if agent is None:
|
||||
raise RuntimeError("Failed to initialise agent for the session")
|
||||
|
||||
# Pass context metadata to model for downstream API requests
|
||||
if hasattr(agent, 'model'):
|
||||
agent.model.channel_type = channel_type or ""
|
||||
agent.model.session_id = session_id or ""
|
||||
|
||||
# State shared between the event callback and this method
|
||||
state = _StreamState()
|
||||
|
||||
|
||||
@@ -32,18 +32,21 @@ class EmbeddingProvider(ABC):
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
"""OpenAI embedding provider using REST API"""
|
||||
|
||||
def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, api_base: Optional[str] = None):
|
||||
def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None, extra_headers: Optional[dict] = None):
|
||||
"""
|
||||
Initialize OpenAI embedding provider
|
||||
|
||||
|
||||
Args:
|
||||
model: Model name (text-embedding-3-small or text-embedding-3-large)
|
||||
api_key: OpenAI API key
|
||||
api_base: Optional API base URL
|
||||
extra_headers: Optional extra headers to include in API requests
|
||||
"""
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base or "https://api.openai.com/v1"
|
||||
self.extra_headers = extra_headers or {}
|
||||
|
||||
# Validate API key
|
||||
if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
@@ -59,7 +62,8 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
url = f"{self.api_base}/embeddings"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
**self.extra_headers,
|
||||
}
|
||||
data = {
|
||||
"input": input_data,
|
||||
@@ -134,7 +138,8 @@ def create_embedding_provider(
|
||||
provider: str = "openai",
|
||||
model: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None
|
||||
api_base: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None
|
||||
) -> EmbeddingProvider:
|
||||
"""
|
||||
Factory function to create embedding provider
|
||||
@@ -147,10 +152,11 @@ def create_embedding_provider(
|
||||
model: Model name (default: text-embedding-3-small)
|
||||
api_key: API key (required)
|
||||
api_base: API base URL
|
||||
|
||||
extra_headers: Optional extra headers to include in API requests
|
||||
|
||||
Returns:
|
||||
EmbeddingProvider instance
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is unsupported or api_key is missing
|
||||
"""
|
||||
@@ -158,4 +164,4 @@ def create_embedding_provider(
|
||||
raise ValueError(f"Unsupported embedding provider: {provider}. Use 'openai' or 'linkai'.")
|
||||
|
||||
model = model or "text-embedding-3-small"
|
||||
return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base)
|
||||
return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base, extra_headers=extra_headers)
|
||||
|
||||
@@ -76,11 +76,15 @@ class MemoryManager:
|
||||
linkai_key = os.environ.get('LINKAI_API_KEY')
|
||||
linkai_base = os.environ.get('LINKAI_API_BASE', 'https://api.link-ai.tech')
|
||||
if linkai_key:
|
||||
from common.utils import get_cloud_headers
|
||||
cloud_headers = get_cloud_headers(linkai_key)
|
||||
cloud_headers.pop("Authorization", None)
|
||||
self.embedding_provider = create_embedding_provider(
|
||||
provider="linkai",
|
||||
model=self.config.embedding_model,
|
||||
api_key=linkai_key,
|
||||
api_base=f"{linkai_base}/v1"
|
||||
api_base=f"{linkai_base}/v1",
|
||||
extra_headers=cloud_headers,
|
||||
)
|
||||
except Exception as e:
|
||||
from common.log import logger
|
||||
|
||||
@@ -82,7 +82,7 @@ class Vision(BaseTool):
|
||||
if not question:
|
||||
return ToolResult.fail("Error: 'question' parameter is required")
|
||||
|
||||
api_key, api_base = self._resolve_provider()
|
||||
api_key, api_base, extra_headers = self._resolve_provider()
|
||||
if not api_key:
|
||||
return ToolResult.fail(
|
||||
"Error: No API key configured for Vision.\n"
|
||||
@@ -98,7 +98,7 @@ class Vision(BaseTool):
|
||||
return ToolResult.fail(f"Error: {e}")
|
||||
|
||||
try:
|
||||
return self._call_api(api_key, api_base, model, question, image_content)
|
||||
return self._call_api(api_key, api_base, model, question, image_content, extra_headers)
|
||||
except requests.Timeout:
|
||||
return ToolResult.fail(f"Error: Vision API request timed out after {DEFAULT_TIMEOUT}s")
|
||||
except requests.ConnectionError:
|
||||
@@ -107,22 +107,26 @@ class Vision(BaseTool):
|
||||
logger.error(f"[Vision] Unexpected error: {e}", exc_info=True)
|
||||
return ToolResult.fail(f"Error: Vision API call failed - {e}")
|
||||
|
||||
def _resolve_provider(self) -> Tuple[Optional[str], str]:
|
||||
"""Resolve API key and base URL. Priority: conf() > env vars."""
|
||||
def _resolve_provider(self) -> Tuple[Optional[str], str, dict]:
|
||||
"""Resolve API key, base URL and extra headers. Priority: conf() > env vars."""
|
||||
api_key = conf().get("open_ai_api_key") or os.environ.get("OPENAI_API_KEY")
|
||||
if api_key:
|
||||
api_base = (conf().get("open_ai_api_base") or os.environ.get("OPENAI_API_BASE", "")).rstrip("/") \
|
||||
or "https://api.openai.com/v1"
|
||||
return api_key, self._ensure_v1(api_base)
|
||||
return api_key, self._ensure_v1(api_base), {}
|
||||
|
||||
api_key = conf().get("linkai_api_key") or os.environ.get("LINKAI_API_KEY")
|
||||
if api_key:
|
||||
api_base = (conf().get("linkai_api_base") or os.environ.get("LINKAI_API_BASE", "")).rstrip("/") \
|
||||
or "https://api.link-ai.tech"
|
||||
logger.debug("[Vision] Using LinkAI API (OPENAI_API_KEY not set)")
|
||||
return api_key, self._ensure_v1(api_base)
|
||||
from common.utils import get_cloud_headers
|
||||
extra = get_cloud_headers(api_key)
|
||||
extra.pop("Authorization", None)
|
||||
extra.pop("Content-Type", None)
|
||||
return api_key, self._ensure_v1(api_base), extra
|
||||
|
||||
return None, ""
|
||||
return None, "", {}
|
||||
|
||||
@staticmethod
|
||||
def _ensure_v1(api_base: str) -> str:
|
||||
@@ -197,7 +201,7 @@ class Vision(BaseTool):
|
||||
return path
|
||||
|
||||
def _call_api(self, api_key: str, api_base: str, model: str,
|
||||
question: str, image_content: dict) -> ToolResult:
|
||||
question: str, image_content: dict, extra_headers: dict = None) -> ToolResult:
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [
|
||||
@@ -215,6 +219,7 @@ class Vision(BaseTool):
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
**(extra_headers or {}),
|
||||
}
|
||||
|
||||
resp = requests.post(
|
||||
|
||||
@@ -225,10 +225,8 @@ class WebSearch(BaseTool):
|
||||
api_base = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||
url = f"{api_base.rstrip('/')}/v1/plugin/execute"
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
}
|
||||
from common.utils import get_cloud_headers
|
||||
headers = get_cloud_headers(api_key)
|
||||
|
||||
payload = {
|
||||
"code": "web-search",
|
||||
|
||||
@@ -152,12 +152,20 @@ class AgentLLMModel(LLMModel):
|
||||
# Only pass max_tokens if it's explicitly set
|
||||
if request.max_tokens is not None:
|
||||
kwargs['max_tokens'] = request.max_tokens
|
||||
|
||||
|
||||
# Extract system prompt if present
|
||||
system_prompt = getattr(request, 'system', None)
|
||||
if system_prompt:
|
||||
kwargs['system'] = system_prompt
|
||||
|
||||
|
||||
# Pass context metadata to bot
|
||||
channel_type = getattr(self, 'channel_type', None)
|
||||
if channel_type:
|
||||
kwargs['channel_type'] = channel_type
|
||||
session_id = getattr(self, 'session_id', None)
|
||||
if session_id:
|
||||
kwargs['session_id'] = session_id
|
||||
|
||||
response = self.bot.call_with_tools(**kwargs)
|
||||
return self._format_response(response)
|
||||
else:
|
||||
@@ -195,10 +203,13 @@ class AgentLLMModel(LLMModel):
|
||||
if system_prompt:
|
||||
kwargs['system'] = system_prompt
|
||||
|
||||
# Pass channel_type for linkai tracking
|
||||
# Pass context metadata to bot
|
||||
channel_type = getattr(self, 'channel_type', None)
|
||||
if channel_type:
|
||||
kwargs['channel_type'] = channel_type
|
||||
session_id = getattr(self, 'session_id', None)
|
||||
if session_id:
|
||||
kwargs['session_id'] = session_id
|
||||
|
||||
stream = self.bot.call_with_tools(**kwargs)
|
||||
|
||||
@@ -375,9 +386,10 @@ class AgentBridge:
|
||||
logger.warning(f"[AgentBridge] Failed to attach context to scheduler: {e}")
|
||||
break
|
||||
|
||||
# Pass channel_type to model so linkai requests carry it
|
||||
# Pass context metadata to model for downstream API requests
|
||||
if context and hasattr(agent, 'model'):
|
||||
agent.model.channel_type = context.get("channel_type", "")
|
||||
agent.model.session_id = session_id or ""
|
||||
|
||||
# Store session_id on agent so executor can clear DB on fatal errors
|
||||
agent._current_session_id = session_id
|
||||
|
||||
@@ -115,3 +115,22 @@ def expand_path(path: str) -> str:
|
||||
expanded = os.path.join(home, path[2:])
|
||||
|
||||
return expanded
|
||||
|
||||
|
||||
def get_cloud_headers(api_key: str) -> dict:
|
||||
"""
|
||||
Build standard headers for LinkAI API requests,
|
||||
including client_id when available.
|
||||
"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
}
|
||||
try:
|
||||
from linkai import LinkAIClient
|
||||
client_id = LinkAIClient.fetch_client_id()
|
||||
if client_id:
|
||||
headers["X-Client-Id"] = client_id
|
||||
except Exception:
|
||||
pass
|
||||
return headers
|
||||
|
||||
@@ -534,6 +534,7 @@ def _linkai_call_with_tools(self, messages, tools=None, stream=False, **kwargs):
|
||||
else:
|
||||
channel_type = raw_ct
|
||||
|
||||
session_id = kwargs.get("session_id", "")
|
||||
body = {
|
||||
"messages": messages,
|
||||
"model": kwargs.get("model", conf().get("model") or "gpt-3.5-turbo"),
|
||||
@@ -543,12 +544,22 @@ def _linkai_call_with_tools(self, messages, tools=None, stream=False, **kwargs):
|
||||
"presence_penalty": kwargs.get("presence_penalty", conf().get("presence_penalty", 0.0)),
|
||||
"stream": stream,
|
||||
"channel_type": kwargs.get("channel_type", channel_type),
|
||||
"session_id": session_id,
|
||||
"sender_id": session_id,
|
||||
}
|
||||
|
||||
try:
|
||||
from linkai import LinkAIClient
|
||||
client_id = LinkAIClient.fetch_client_id()
|
||||
if client_id:
|
||||
body["client_id"] = client_id
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if tools:
|
||||
body["tools"] = tools
|
||||
body["tool_choice"] = kwargs.get("tool_choice", "auto")
|
||||
|
||||
|
||||
# Prepare headers
|
||||
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||
|
||||
Reference in New Issue
Block a user