mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
feat: support gemini model
This commit is contained in:
@@ -47,4 +47,9 @@ def create_bot(bot_type):
|
|||||||
elif bot_type == const.QWEN:
|
elif bot_type == const.QWEN:
|
||||||
from bot.tongyi.tongyi_qwen_bot import TongyiQwenBot
|
from bot.tongyi.tongyi_qwen_bot import TongyiQwenBot
|
||||||
return TongyiQwenBot()
|
return TongyiQwenBot()
|
||||||
|
|
||||||
|
elif bot_type == const.GEMINI:
|
||||||
|
from bot.gemini.google_gemini_bot import GoogleGeminiBot
|
||||||
|
return GoogleGeminiBot()
|
||||||
|
|
||||||
raise RuntimeError
|
raise RuntimeError
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ class ChatGPTSession(Session):
|
|||||||
def num_tokens_from_messages(messages, model):
|
def num_tokens_from_messages(messages, model):
|
||||||
"""Returns the number of tokens used by a list of messages."""
|
"""Returns the number of tokens used by a list of messages."""
|
||||||
|
|
||||||
if model in ["wenxin", "xunfei"]:
|
if model in ["wenxin", "xunfei", const.GEMINI]:
|
||||||
return num_tokens_by_character(messages)
|
return num_tokens_by_character(messages)
|
||||||
|
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|||||||
58
bot/gemini/google_gemini_bot.py
Normal file
58
bot/gemini/google_gemini_bot.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
"""
|
||||||
|
Google gemini bot
|
||||||
|
|
||||||
|
@author zhayujie
|
||||||
|
@Date 2023/12/15
|
||||||
|
"""
|
||||||
|
# encoding:utf-8
|
||||||
|
|
||||||
|
from bot.bot import Bot
|
||||||
|
import google.generativeai as genai
|
||||||
|
from bot.session_manager import SessionManager
|
||||||
|
from bridge.context import ContextType, Context
|
||||||
|
from bridge.reply import Reply, ReplyType
|
||||||
|
from common.log import logger
|
||||||
|
from config import conf
|
||||||
|
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||||
|
|
||||||
|
|
||||||
|
# OpenAI对话模型API (可用)
|
||||||
|
class GoogleGeminiBot(Bot):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.api_key = conf().get("gemini_api_key")
|
||||||
|
# 复用文心的token计算方式
|
||||||
|
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||||
|
|
||||||
|
def reply(self, query, context: Context = None) -> Reply:
|
||||||
|
if context.type != ContextType.TEXT:
|
||||||
|
logger.warn(f"[Gemini] Unsupported message type, type={context.type}")
|
||||||
|
return Reply(ReplyType.TEXT, None)
|
||||||
|
logger.info(f"[Gemini] query={query}")
|
||||||
|
session_id = context["session_id"]
|
||||||
|
session = self.sessions.session_query(query, session_id)
|
||||||
|
gemini_messages = self._convert_to_gemini_messages(session.messages)
|
||||||
|
genai.configure(api_key=self.api_key)
|
||||||
|
model = genai.GenerativeModel('gemini-pro')
|
||||||
|
response = model.generate_content(gemini_messages)
|
||||||
|
reply_text = response.text
|
||||||
|
self.sessions.session_reply(reply_text, session_id)
|
||||||
|
logger.info(f"[Gemini] reply={reply_text}")
|
||||||
|
return Reply(ReplyType.TEXT, reply_text)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_to_gemini_messages(self, messages: list):
|
||||||
|
res = []
|
||||||
|
for msg in messages:
|
||||||
|
if msg.get("role") == "user":
|
||||||
|
role = "user"
|
||||||
|
elif msg.get("role") == "assistant":
|
||||||
|
role = "model"
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
res.append({
|
||||||
|
"role": role,
|
||||||
|
"parts": [{"text": msg.get("content")}]
|
||||||
|
})
|
||||||
|
return res
|
||||||
@@ -29,12 +29,16 @@ class Bridge(object):
|
|||||||
self.btype["chat"] = const.XUNFEI
|
self.btype["chat"] = const.XUNFEI
|
||||||
if model_type in [const.QWEN]:
|
if model_type in [const.QWEN]:
|
||||||
self.btype["chat"] = const.QWEN
|
self.btype["chat"] = const.QWEN
|
||||||
|
if model_type in [const.GEMINI]:
|
||||||
|
self.btype["chat"] = const.GEMINI
|
||||||
|
|
||||||
if conf().get("use_linkai") and conf().get("linkai_api_key"):
|
if conf().get("use_linkai") and conf().get("linkai_api_key"):
|
||||||
self.btype["chat"] = const.LINKAI
|
self.btype["chat"] = const.LINKAI
|
||||||
if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]:
|
if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]:
|
||||||
self.btype["voice_to_text"] = const.LINKAI
|
self.btype["voice_to_text"] = const.LINKAI
|
||||||
if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
|
if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
|
||||||
self.btype["text_to_voice"] = const.LINKAI
|
self.btype["text_to_voice"] = const.LINKAI
|
||||||
|
|
||||||
if model_type in ["claude"]:
|
if model_type in ["claude"]:
|
||||||
self.btype["chat"] = const.CLAUDEAI
|
self.btype["chat"] = const.CLAUDEAI
|
||||||
self.bots = {}
|
self.bots = {}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ CHATGPTONAZURE = "chatGPTOnAzure"
|
|||||||
LINKAI = "linkai"
|
LINKAI = "linkai"
|
||||||
CLAUDEAI = "claude"
|
CLAUDEAI = "claude"
|
||||||
QWEN = "qwen"
|
QWEN = "qwen"
|
||||||
|
GEMINI = "gemini"
|
||||||
|
|
||||||
# model
|
# model
|
||||||
GPT35 = "gpt-3.5-turbo"
|
GPT35 = "gpt-3.5-turbo"
|
||||||
@@ -17,7 +18,7 @@ WHISPER_1 = "whisper-1"
|
|||||||
TTS_1 = "tts-1"
|
TTS_1 = "tts-1"
|
||||||
TTS_1_HD = "tts-1-hd"
|
TTS_1_HD = "tts-1-hd"
|
||||||
|
|
||||||
MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo", GPT4_TURBO_PREVIEW, QWEN]
|
MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo", GPT4_TURBO_PREVIEW, QWEN, GEMINI]
|
||||||
|
|
||||||
# channel
|
# channel
|
||||||
FEISHU = "feishu"
|
FEISHU = "feishu"
|
||||||
|
|||||||
@@ -73,6 +73,8 @@ available_setting = {
|
|||||||
"qwen_agent_key": "",
|
"qwen_agent_key": "",
|
||||||
"qwen_app_id": "",
|
"qwen_app_id": "",
|
||||||
"qwen_node_id": "", # 流程编排模型用到的id,如果没有用到qwen_node_id,请务必保持为空字符串
|
"qwen_node_id": "", # 流程编排模型用到的id,如果没有用到qwen_node_id,请务必保持为空字符串
|
||||||
|
# Google Gemini Api Key
|
||||||
|
"gemini_api_key": "",
|
||||||
# wework的通用配置
|
# wework的通用配置
|
||||||
"wework_smart": True, # 配置wework是否使用已登录的企业微信,False为多开
|
"wework_smart": True, # 配置wework是否使用已登录的企业微信,False为多开
|
||||||
# 语音设置
|
# 语音设置
|
||||||
|
|||||||
@@ -313,7 +313,7 @@ class Godcmd(Plugin):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
ok, result = False, "你没有设置私有GPT模型"
|
ok, result = False, "你没有设置私有GPT模型"
|
||||||
elif cmd == "reset":
|
elif cmd == "reset":
|
||||||
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI]:
|
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI, const.GEMINI]:
|
||||||
bot.sessions.clear_session(session_id)
|
bot.sessions.clear_session(session_id)
|
||||||
if Bridge().chat_bots.get(bottype):
|
if Bridge().chat_bots.get(bottype):
|
||||||
Bridge().chat_bots.get(bottype).sessions.clear_session(session_id)
|
Bridge().chat_bots.get(bottype).sessions.clear_session(session_id)
|
||||||
|
|||||||
Reference in New Issue
Block a user