mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
fix: gemini session bug
This commit is contained in:
@@ -62,7 +62,7 @@ def num_tokens_from_messages(messages, model):
|
|||||||
|
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo", "gpt-3.5-turbo-1106"]:
|
if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo", "gpt-3.5-turbo-1106", "moonshot"]:
|
||||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
|
return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
|
||||||
elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613",
|
elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613",
|
||||||
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k", "gpt-4-turbo-preview",
|
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k", "gpt-4-turbo-preview",
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ import anthropic
|
|||||||
|
|
||||||
from bot.bot import Bot
|
from bot.bot import Bot
|
||||||
from bot.openai.open_ai_image import OpenAIImage
|
from bot.openai.open_ai_image import OpenAIImage
|
||||||
from bot.claudeapi.claude_api_session import ClaudeAPISession
|
|
||||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||||
|
from bot.gemini.google_gemini_bot import GoogleGeminiBot
|
||||||
from bot.session_manager import SessionManager
|
from bot.session_manager import SessionManager
|
||||||
from bridge.context import ContextType
|
from bridge.context import ContextType
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
@@ -78,15 +78,12 @@ class ClaudeAPIBot(Bot, OpenAIImage):
|
|||||||
|
|
||||||
def reply_text(self, session: ChatGPTSession, retry_count=0):
|
def reply_text(self, session: ChatGPTSession, retry_count=0):
|
||||||
try:
|
try:
|
||||||
if session.messages[0].get("role") == "system":
|
|
||||||
system = session.messages[0].get("content")
|
|
||||||
session.messages.pop(0)
|
|
||||||
actual_model = self._model_mapping(conf().get("model"))
|
actual_model = self._model_mapping(conf().get("model"))
|
||||||
response = self.claudeClient.messages.create(
|
response = self.claudeClient.messages.create(
|
||||||
model=actual_model,
|
model=actual_model,
|
||||||
max_tokens=1024,
|
max_tokens=1024,
|
||||||
# system=conf().get("system"),
|
# system=conf().get("system"),
|
||||||
messages=session.messages
|
messages=GoogleGeminiBot.filter_messages(session.messages)
|
||||||
)
|
)
|
||||||
# response = openai.Completion.create(prompt=str(session), **self.args)
|
# response = openai.Completion.create(prompt=str(session), **self.args)
|
||||||
res_content = response.content[0].text.strip().replace("<|endoftext|>", "")
|
res_content = response.content[0].text.strip().replace("<|endoftext|>", "")
|
||||||
|
|||||||
@@ -1,74 +0,0 @@
|
|||||||
from bot.session_manager import Session
|
|
||||||
from common.log import logger
|
|
||||||
|
|
||||||
|
|
||||||
class ClaudeAPISession(Session):
|
|
||||||
def __init__(self, session_id, system_prompt=None, model="text-davinci-003"):
|
|
||||||
super().__init__(session_id, system_prompt)
|
|
||||||
self.model = model
|
|
||||||
self.reset()
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
# 构造对话模型的输入
|
|
||||||
"""
|
|
||||||
e.g. Q: xxx
|
|
||||||
A: xxx
|
|
||||||
Q: xxx
|
|
||||||
"""
|
|
||||||
prompt = ""
|
|
||||||
for item in self.messages:
|
|
||||||
if item["role"] == "system":
|
|
||||||
prompt += item["content"] + "<|endoftext|>\n\n\n"
|
|
||||||
elif item["role"] == "user":
|
|
||||||
prompt += "Q: " + item["content"] + "\n"
|
|
||||||
elif item["role"] == "assistant":
|
|
||||||
prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n"
|
|
||||||
|
|
||||||
if len(self.messages) > 0 and self.messages[-1]["role"] == "user":
|
|
||||||
prompt += "A: "
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
|
||||||
precise = True
|
|
||||||
|
|
||||||
try:
|
|
||||||
cur_tokens = self.calc_tokens()
|
|
||||||
except Exception as e:
|
|
||||||
precise = False
|
|
||||||
if cur_tokens is None:
|
|
||||||
raise e
|
|
||||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
|
||||||
while cur_tokens > max_tokens:
|
|
||||||
if len(self.messages) > 1:
|
|
||||||
self.messages.pop(0)
|
|
||||||
elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant":
|
|
||||||
self.messages.pop(0)
|
|
||||||
if precise:
|
|
||||||
cur_tokens = self.calc_tokens()
|
|
||||||
else:
|
|
||||||
cur_tokens = len(str(self))
|
|
||||||
break
|
|
||||||
elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
|
|
||||||
logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
|
||||||
break
|
|
||||||
if precise:
|
|
||||||
cur_tokens = self.calc_tokens()
|
|
||||||
else:
|
|
||||||
cur_tokens = len(str(self))
|
|
||||||
return cur_tokens
|
|
||||||
def calc_tokens(self):
|
|
||||||
return num_tokens_from_string(str(self), self.model)
|
|
||||||
|
|
||||||
|
|
||||||
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
|
||||||
def num_tokens_from_string(string: str, model: str) -> int:
|
|
||||||
"""Returns the number of tokens in a text string."""
|
|
||||||
num_tokens = len(string)
|
|
||||||
return num_tokens
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -33,7 +33,7 @@ class GoogleGeminiBot(Bot):
|
|||||||
logger.info(f"[Gemini] query={query}")
|
logger.info(f"[Gemini] query={query}")
|
||||||
session_id = context["session_id"]
|
session_id = context["session_id"]
|
||||||
session = self.sessions.session_query(query, session_id)
|
session = self.sessions.session_query(query, session_id)
|
||||||
gemini_messages = self._convert_to_gemini_messages(self._filter_messages(session.messages))
|
gemini_messages = self._convert_to_gemini_messages(self.filter_messages(session.messages))
|
||||||
genai.configure(api_key=self.api_key)
|
genai.configure(api_key=self.api_key)
|
||||||
model = genai.GenerativeModel('gemini-pro')
|
model = genai.GenerativeModel('gemini-pro')
|
||||||
response = model.generate_content(gemini_messages)
|
response = model.generate_content(gemini_messages)
|
||||||
@@ -61,7 +61,8 @@ class GoogleGeminiBot(Bot):
|
|||||||
})
|
})
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def _filter_messages(self, messages: list):
|
@staticmethod
|
||||||
|
def filter_messages(messages: list):
|
||||||
res = []
|
res = []
|
||||||
turn = "user"
|
turn = "user"
|
||||||
if not messages:
|
if not messages:
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import requests
|
|||||||
import config
|
import config
|
||||||
from bot.bot import Bot
|
from bot.bot import Bot
|
||||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||||
|
from bot.gemini.google_gemini_bot import GoogleGeminiBot
|
||||||
from bot.session_manager import SessionManager
|
from bot.session_manager import SessionManager
|
||||||
from bridge.context import Context, ContextType
|
from bridge.context import Context, ContextType
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
|
|||||||
@@ -10,10 +10,11 @@ CLAUDEAPI= "claudeAPI"
|
|||||||
QWEN = "qwen"
|
QWEN = "qwen"
|
||||||
GEMINI = "gemini"
|
GEMINI = "gemini"
|
||||||
ZHIPU_AI = "glm-4"
|
ZHIPU_AI = "glm-4"
|
||||||
|
MOONSHOT = "moonshot"
|
||||||
|
|
||||||
|
|
||||||
# model
|
# model
|
||||||
CLAUDE3="claude-3-opus-20240229"
|
CLAUDE3 = "claude-3-opus-20240229"
|
||||||
GPT35 = "gpt-3.5-turbo"
|
GPT35 = "gpt-3.5-turbo"
|
||||||
GPT4 = "gpt-4"
|
GPT4 = "gpt-4"
|
||||||
GPT4_TURBO_PREVIEW = "gpt-4-0125-preview"
|
GPT4_TURBO_PREVIEW = "gpt-4-0125-preview"
|
||||||
@@ -23,7 +24,7 @@ TTS_1 = "tts-1"
|
|||||||
TTS_1_HD = "tts-1-hd"
|
TTS_1_HD = "tts-1-hd"
|
||||||
|
|
||||||
MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude","claude-3-opus-20240229", "gpt-4-turbo",
|
MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude","claude-3-opus-20240229", "gpt-4-turbo",
|
||||||
"gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI]
|
"gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI, MOONSHOT]
|
||||||
|
|
||||||
# channel
|
# channel
|
||||||
FEISHU = "feishu"
|
FEISHU = "feishu"
|
||||||
|
|||||||
@@ -339,7 +339,7 @@ class Godcmd(Plugin):
|
|||||||
ok, result = True, "配置已重载"
|
ok, result = True, "配置已重载"
|
||||||
elif cmd == "resetall":
|
elif cmd == "resetall":
|
||||||
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI,
|
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI,
|
||||||
const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI, const.ZHIPU_AI]:
|
const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI, const.ZHIPU_AI, const.MOONSHOT]:
|
||||||
channel.cancel_all_session()
|
channel.cancel_all_session()
|
||||||
bot.sessions.clear_all_session()
|
bot.sessions.clear_all_session()
|
||||||
ok, result = True, "重置所有会话成功"
|
ok, result = True, "重置所有会话成功"
|
||||||
|
|||||||
Reference in New Issue
Block a user