mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-03 19:17:10 +08:00
Compare commits
21 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
22d67b3a59 | ||
|
|
e102cbb8c4 | ||
|
|
d90eeb7ee4 | ||
|
|
1989d53031 | ||
|
|
04ef0907b4 | ||
|
|
517b43561c | ||
|
|
ccb8c7227f | ||
|
|
9fbfeeb04f | ||
|
|
8b753a5a1f | ||
|
|
d25cab0627 | ||
|
|
84da0a8a35 | ||
|
|
6f665cffba | ||
|
|
aea8ac2e97 | ||
|
|
8418fa7b45 | ||
|
|
9cc4d0ee07 | ||
|
|
da60831c44 | ||
|
|
0773174a20 | ||
|
|
70e007d8ca | ||
|
|
fcc4d02c2f | ||
|
|
f4a5f00593 | ||
|
|
1170ed6566 |
@@ -45,6 +45,7 @@ DEMO视频:https://cdn.link-ai.tech/doc/cow_demo.mp4
|
|||||||
<br>
|
<br>
|
||||||
|
|
||||||
# 🏷 更新日志
|
# 🏷 更新日志
|
||||||
|
>**2024.10.31:** [1.7.3版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.3) 程序稳定性提升、数据库功能、Claude模型优化、linkai插件优化、离线通知
|
||||||
|
|
||||||
>**2024.09.26:** [1.7.2版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.2) 和 [1.7.1版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.1) 文心,讯飞等模型优化、o1 模型、快速安装和管理脚本
|
>**2024.09.26:** [1.7.2版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.2) 和 [1.7.1版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.1) 文心,讯飞等模型优化、o1 模型、快速安装和管理脚本
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ class ChatGPTSession(Session):
|
|||||||
def num_tokens_from_messages(messages, model):
|
def num_tokens_from_messages(messages, model):
|
||||||
"""Returns the number of tokens used by a list of messages."""
|
"""Returns the number of tokens used by a list of messages."""
|
||||||
|
|
||||||
if model in ["wenxin", "xunfei", const.GEMINI]:
|
if model in ["wenxin", "xunfei"] or model.startswith(const.GEMINI):
|
||||||
return num_tokens_by_character(messages)
|
return num_tokens_by_character(messages)
|
||||||
|
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|||||||
@@ -8,12 +8,12 @@ import anthropic
|
|||||||
|
|
||||||
from bot.bot import Bot
|
from bot.bot import Bot
|
||||||
from bot.openai.open_ai_image import OpenAIImage
|
from bot.openai.open_ai_image import OpenAIImage
|
||||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||||
from bot.gemini.google_gemini_bot import GoogleGeminiBot
|
|
||||||
from bot.session_manager import SessionManager
|
from bot.session_manager import SessionManager
|
||||||
from bridge.context import ContextType
|
from bridge.context import ContextType
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
|
from common import const
|
||||||
from config import conf
|
from config import conf
|
||||||
|
|
||||||
user_session = dict()
|
user_session = dict()
|
||||||
@@ -23,17 +23,14 @@ user_session = dict()
|
|||||||
class ClaudeAPIBot(Bot, OpenAIImage):
|
class ClaudeAPIBot(Bot, OpenAIImage):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
proxy = conf().get("proxy", None)
|
||||||
|
base_url = conf().get("open_ai_api_base", None) # 复用"open_ai_api_base"参数作为base_url
|
||||||
self.claudeClient = anthropic.Anthropic(
|
self.claudeClient = anthropic.Anthropic(
|
||||||
api_key=conf().get("claude_api_key")
|
api_key=conf().get("claude_api_key"),
|
||||||
|
proxies=proxy if proxy else None,
|
||||||
|
base_url=base_url if base_url else None
|
||||||
)
|
)
|
||||||
openai.api_key = conf().get("open_ai_api_key")
|
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "text-davinci-003")
|
||||||
if conf().get("open_ai_api_base"):
|
|
||||||
openai.api_base = conf().get("open_ai_api_base")
|
|
||||||
proxy = conf().get("proxy")
|
|
||||||
if proxy:
|
|
||||||
openai.proxy = proxy
|
|
||||||
|
|
||||||
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "text-davinci-003")
|
|
||||||
|
|
||||||
def reply(self, query, context=None):
|
def reply(self, query, context=None):
|
||||||
# acquire reply content
|
# acquire reply content
|
||||||
@@ -76,14 +73,14 @@ class ClaudeAPIBot(Bot, OpenAIImage):
|
|||||||
reply = Reply(ReplyType.ERROR, retstring)
|
reply = Reply(ReplyType.ERROR, retstring)
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
def reply_text(self, session: ChatGPTSession, retry_count=0):
|
def reply_text(self, session: BaiduWenxinSession, retry_count=0):
|
||||||
try:
|
try:
|
||||||
actual_model = self._model_mapping(conf().get("model"))
|
actual_model = self._model_mapping(conf().get("model"))
|
||||||
response = self.claudeClient.messages.create(
|
response = self.claudeClient.messages.create(
|
||||||
model=actual_model,
|
model=actual_model,
|
||||||
max_tokens=1024,
|
max_tokens=4096,
|
||||||
# system=conf().get("system"),
|
system=conf().get("character_desc", ""),
|
||||||
messages=GoogleGeminiBot.filter_messages(session.messages)
|
messages=session.messages
|
||||||
)
|
)
|
||||||
# response = openai.Completion.create(prompt=str(session), **self.args)
|
# response = openai.Completion.create(prompt=str(session), **self.args)
|
||||||
res_content = response.content[0].text.strip().replace("<|endoftext|>", "")
|
res_content = response.content[0].text.strip().replace("<|endoftext|>", "")
|
||||||
@@ -97,7 +94,7 @@ class ClaudeAPIBot(Bot, OpenAIImage):
|
|||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
need_retry = retry_count < 2
|
need_retry = retry_count < 2
|
||||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
result = {"total_tokens": 0, "completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||||
if isinstance(e, openai.error.RateLimitError):
|
if isinstance(e, openai.error.RateLimitError):
|
||||||
logger.warn("[CLAUDE_API] RateLimitError: {}".format(e))
|
logger.warn("[CLAUDE_API] RateLimitError: {}".format(e))
|
||||||
result["content"] = "提问太快啦,请休息一下再问我吧"
|
result["content"] = "提问太快啦,请休息一下再问我吧"
|
||||||
@@ -125,11 +122,11 @@ class ClaudeAPIBot(Bot, OpenAIImage):
|
|||||||
|
|
||||||
def _model_mapping(self, model) -> str:
|
def _model_mapping(self, model) -> str:
|
||||||
if model == "claude-3-opus":
|
if model == "claude-3-opus":
|
||||||
return "claude-3-opus-20240229"
|
return const.CLAUDE_3_OPUS
|
||||||
elif model == "claude-3-sonnet":
|
elif model == "claude-3-sonnet":
|
||||||
return "claude-3-sonnet-20240229"
|
return const.CLAUDE_3_SONNET
|
||||||
elif model == "claude-3-haiku":
|
elif model == "claude-3-haiku":
|
||||||
return "claude-3-haiku-20240307"
|
return const.CLAUDE_3_HAIKU
|
||||||
elif model == "claude-3.5-sonnet":
|
elif model == "claude-3.5-sonnet":
|
||||||
return "claude-3-5-sonnet-20240620"
|
return const.CLAUDE_35_SONNET
|
||||||
return model
|
return model
|
||||||
|
|||||||
@@ -337,24 +337,27 @@ class ChatChannel(Channel):
|
|||||||
while True:
|
while True:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
session_ids = list(self.sessions.keys())
|
session_ids = list(self.sessions.keys())
|
||||||
for session_id in session_ids:
|
for session_id in session_ids:
|
||||||
|
with self.lock:
|
||||||
context_queue, semaphore = self.sessions[session_id]
|
context_queue, semaphore = self.sessions[session_id]
|
||||||
if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除
|
if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除
|
||||||
if not context_queue.empty():
|
if not context_queue.empty():
|
||||||
context = context_queue.get()
|
context = context_queue.get()
|
||||||
logger.debug("[chat_channel] consume context: {}".format(context))
|
logger.debug("[chat_channel] consume context: {}".format(context))
|
||||||
future: Future = handler_pool.submit(self._handle, context)
|
future: Future = handler_pool.submit(self._handle, context)
|
||||||
future.add_done_callback(self._thread_pool_callback(session_id, context=context))
|
future.add_done_callback(self._thread_pool_callback(session_id, context=context))
|
||||||
|
with self.lock:
|
||||||
if session_id not in self.futures:
|
if session_id not in self.futures:
|
||||||
self.futures[session_id] = []
|
self.futures[session_id] = []
|
||||||
self.futures[session_id].append(future)
|
self.futures[session_id].append(future)
|
||||||
elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
|
elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
|
||||||
|
with self.lock:
|
||||||
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
|
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
|
||||||
assert len(self.futures[session_id]) == 0, "thread pool error"
|
assert len(self.futures[session_id]) == 0, "thread pool error"
|
||||||
del self.sessions[session_id]
|
del self.sessions[session_id]
|
||||||
else:
|
else:
|
||||||
semaphore.release()
|
semaphore.release()
|
||||||
time.sleep(0.1)
|
time.sleep(0.2)
|
||||||
|
|
||||||
# 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务
|
# 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务
|
||||||
def cancel_session(self, session_id):
|
def cancel_session(self, session_id):
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from common.expired_dict import ExpiredDict
|
|||||||
from common.log import logger
|
from common.log import logger
|
||||||
from common.singleton import singleton
|
from common.singleton import singleton
|
||||||
from common.time_check import time_checker
|
from common.time_check import time_checker
|
||||||
from common.utils import convert_webp_to_png
|
from common.utils import convert_webp_to_png, remove_markdown_symbol
|
||||||
from config import conf, get_appdata_dir
|
from config import conf, get_appdata_dir
|
||||||
from lib import itchat
|
from lib import itchat
|
||||||
from lib.itchat.content import *
|
from lib.itchat.content import *
|
||||||
@@ -213,9 +213,11 @@ class WechatChannel(ChatChannel):
|
|||||||
def send(self, reply: Reply, context: Context):
|
def send(self, reply: Reply, context: Context):
|
||||||
receiver = context["receiver"]
|
receiver = context["receiver"]
|
||||||
if reply.type == ReplyType.TEXT:
|
if reply.type == ReplyType.TEXT:
|
||||||
|
reply.content = remove_markdown_symbol(reply.content)
|
||||||
itchat.send(reply.content, toUserName=receiver)
|
itchat.send(reply.content, toUserName=receiver)
|
||||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||||
|
reply.content = remove_markdown_symbol(reply.content)
|
||||||
itchat.send(reply.content, toUserName=receiver)
|
itchat.send(reply.content, toUserName=receiver)
|
||||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||||
elif reply.type == ReplyType.VOICE:
|
elif reply.type == ReplyType.VOICE:
|
||||||
|
|||||||
@@ -55,6 +55,16 @@ class WechatMessage(ChatMessage):
|
|||||||
self.ctype = ContextType.EXIT_GROUP
|
self.ctype = ContextType.EXIT_GROUP
|
||||||
self.content = itchat_msg["Content"]
|
self.content = itchat_msg["Content"]
|
||||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||||
|
|
||||||
|
elif any(note_patpat in itchat_msg["Content"] for note_patpat in notes_patpat): # 若有任何在notes_patpat列表中的字符串出现在NOTE中:
|
||||||
|
self.ctype = ContextType.PATPAT
|
||||||
|
self.content = itchat_msg["Content"]
|
||||||
|
if "拍了拍我" in itchat_msg["Content"]: # 识别中文
|
||||||
|
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||||
|
elif "tickled my" in itchat_msg["Content"] or "tickled me" in itchat_msg["Content"]:
|
||||||
|
self.actual_user_nickname = re.findall(r'^(.*?)(?:tickled my|tickled me)', itchat_msg["Content"])[0]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
|
||||||
|
|
||||||
elif "你已添加了" in itchat_msg["Content"]: #通过好友请求
|
elif "你已添加了" in itchat_msg["Content"]: #通过好友请求
|
||||||
self.ctype = ContextType.ACCEPT_FRIEND
|
self.ctype = ContextType.ACCEPT_FRIEND
|
||||||
@@ -62,11 +72,6 @@ class WechatMessage(ChatMessage):
|
|||||||
elif any(note_patpat in itchat_msg["Content"] for note_patpat in notes_patpat): # 若有任何在notes_patpat列表中的字符串出现在NOTE中:
|
elif any(note_patpat in itchat_msg["Content"] for note_patpat in notes_patpat): # 若有任何在notes_patpat列表中的字符串出现在NOTE中:
|
||||||
self.ctype = ContextType.PATPAT
|
self.ctype = ContextType.PATPAT
|
||||||
self.content = itchat_msg["Content"]
|
self.content = itchat_msg["Content"]
|
||||||
if is_group:
|
|
||||||
if "拍了拍我" in itchat_msg["Content"]: # 识别中文
|
|
||||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
|
||||||
elif ("tickled my" in itchat_msg["Content"] or "tickled me" in itchat_msg["Content"]):
|
|
||||||
self.actual_user_nickname = re.findall(r'^(.*?)(?:tickled my|tickled me)', itchat_msg["Content"])[0]
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
|
raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
|
||||||
elif itchat_msg["Type"] == ATTACHMENT:
|
elif itchat_msg["Type"] == ATTACHMENT:
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from channel.wechatcom.wechatcomapp_client import WechatComAppClient
|
|||||||
from channel.wechatcom.wechatcomapp_message import WechatComAppMessage
|
from channel.wechatcom.wechatcomapp_message import WechatComAppMessage
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
from common.singleton import singleton
|
from common.singleton import singleton
|
||||||
from common.utils import compress_imgfile, fsize, split_string_by_utf8_length, convert_webp_to_png
|
from common.utils import compress_imgfile, fsize, split_string_by_utf8_length, convert_webp_to_png, remove_markdown_symbol
|
||||||
from config import conf, subscribe_msg
|
from config import conf, subscribe_msg
|
||||||
from voice.audio_convert import any_to_amr, split_audio
|
from voice.audio_convert import any_to_amr, split_audio
|
||||||
|
|
||||||
@@ -52,7 +52,7 @@ class WechatComAppChannel(ChatChannel):
|
|||||||
def send(self, reply: Reply, context: Context):
|
def send(self, reply: Reply, context: Context):
|
||||||
receiver = context["receiver"]
|
receiver = context["receiver"]
|
||||||
if reply.type in [ReplyType.TEXT, ReplyType.ERROR, ReplyType.INFO]:
|
if reply.type in [ReplyType.TEXT, ReplyType.ERROR, ReplyType.INFO]:
|
||||||
reply_text = reply.content
|
reply_text = remove_markdown_symbol(reply.content)
|
||||||
texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN)
|
texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN)
|
||||||
if len(texts) > 1:
|
if len(texts) > 1:
|
||||||
logger.info("[wechatcom] text too long, split into {} parts".format(len(texts)))
|
logger.info("[wechatcom] text too long, split into {} parts".format(len(texts)))
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from channel.wechatmp.common import *
|
|||||||
from channel.wechatmp.wechatmp_client import WechatMPClient
|
from channel.wechatmp.wechatmp_client import WechatMPClient
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
from common.singleton import singleton
|
from common.singleton import singleton
|
||||||
from common.utils import split_string_by_utf8_length
|
from common.utils import split_string_by_utf8_length, remove_markdown_symbol
|
||||||
from config import conf
|
from config import conf
|
||||||
from voice.audio_convert import any_to_mp3, split_audio
|
from voice.audio_convert import any_to_mp3, split_audio
|
||||||
|
|
||||||
@@ -81,7 +81,7 @@ class WechatMPChannel(ChatChannel):
|
|||||||
receiver = context["receiver"]
|
receiver = context["receiver"]
|
||||||
if self.passive_reply:
|
if self.passive_reply:
|
||||||
if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
|
if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
|
||||||
reply_text = reply.content
|
reply_text = remove_markdown_symbol(reply.content)
|
||||||
logger.info("[wechatmp] text cached, receiver {}\n{}".format(receiver, reply_text))
|
logger.info("[wechatmp] text cached, receiver {}\n{}".format(receiver, reply_text))
|
||||||
self.cache_dict[receiver].append(("text", reply_text))
|
self.cache_dict[receiver].append(("text", reply_text))
|
||||||
elif reply.type == ReplyType.VOICE:
|
elif reply.type == ReplyType.VOICE:
|
||||||
|
|||||||
@@ -69,6 +69,17 @@ GLM_4_0520 = "glm-4-0520"
|
|||||||
GLM_4_AIR = "glm-4-air"
|
GLM_4_AIR = "glm-4-air"
|
||||||
GLM_4_AIRX = "glm-4-airx"
|
GLM_4_AIRX = "glm-4-airx"
|
||||||
|
|
||||||
|
|
||||||
|
CLAUDE_3_OPUS = "claude-3-opus-latest"
|
||||||
|
CLAUDE_3_OPUS_0229 = "claude-3-opus-20240229"
|
||||||
|
|
||||||
|
CLAUDE_35_SONNET = "claude-3-5-sonnet-latest" # 带 latest 标签的模型名称,会不断更新指向最新发布的模型
|
||||||
|
CLAUDE_35_SONNET_1022 = "claude-3-5-sonnet-20241022" # 带具体日期的模型名称,会固定为该日期发布的模型
|
||||||
|
CLAUDE_35_SONNET_0620 = "claude-3-5-sonnet-20240620"
|
||||||
|
CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
|
||||||
|
|
||||||
|
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||||
|
|
||||||
MODEL_LIST = [
|
MODEL_LIST = [
|
||||||
GPT35, GPT35_0125, GPT35_1106, "gpt-3.5-turbo-16k",
|
GPT35, GPT35_0125, GPT35_1106, "gpt-3.5-turbo-16k",
|
||||||
O1, O1_MINI, GPT_4o, GPT_4O_0806, GPT_4o_MINI, GPT4_TURBO, GPT4_TURBO_PREVIEW, GPT4_TURBO_01_25, GPT4_TURBO_11_06, GPT4, GPT4_32k, GPT4_06_13, GPT4_32k_06_13,
|
O1, O1_MINI, GPT_4o, GPT_4O_0806, GPT_4o_MINI, GPT4_TURBO, GPT4_TURBO_PREVIEW, GPT4_TURBO_01_25, GPT4_TURBO_11_06, GPT4, GPT4_32k, GPT4_06_13, GPT4_32k_06_13,
|
||||||
@@ -77,7 +88,7 @@ MODEL_LIST = [
|
|||||||
ZHIPU_AI, GLM_4, GLM_4_PLUS, GLM_4_flash, GLM_4_LONG, GLM_4_ALLTOOLS, GLM_4_0520, GLM_4_AIR, GLM_4_AIRX,
|
ZHIPU_AI, GLM_4, GLM_4_PLUS, GLM_4_flash, GLM_4_LONG, GLM_4_ALLTOOLS, GLM_4_0520, GLM_4_AIR, GLM_4_AIRX,
|
||||||
MOONSHOT, MiniMax,
|
MOONSHOT, MiniMax,
|
||||||
GEMINI, GEMINI_PRO, GEMINI_15_flash, GEMINI_15_PRO,
|
GEMINI, GEMINI_PRO, GEMINI_15_flash, GEMINI_15_PRO,
|
||||||
"claude", "claude-3-haiku", "claude-3-sonnet", "claude-3-opus", "claude-3-opus-20240229", "claude-3.5-sonnet",
|
CLAUDE_3_OPUS, CLAUDE_3_OPUS_0229, CLAUDE_35_SONNET, CLAUDE_35_SONNET_1022, CLAUDE_35_SONNET_0620, CLAUDE_3_SONNET, CLAUDE_3_HAIKU, "claude", "claude-3-haiku", "claude-3-sonnet", "claude-3-opus", "claude-3.5-sonnet",
|
||||||
"moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k",
|
"moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k",
|
||||||
QWEN, QWEN_TURBO, QWEN_PLUS, QWEN_MAX,
|
QWEN, QWEN_TURBO, QWEN_PLUS, QWEN_MAX,
|
||||||
LINKAI_35, LINKAI_4_TURBO, LINKAI_4o
|
LINKAI_35, LINKAI_4_TURBO, LINKAI_4o
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
@@ -68,3 +69,10 @@ def convert_webp_to_png(webp_image):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to convert WEBP to PNG: {e}")
|
logger.error(f"Failed to convert WEBP to PNG: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def remove_markdown_symbol(text: str):
|
||||||
|
# 移除markdown格式,目前先移除**
|
||||||
|
if not text:
|
||||||
|
return text
|
||||||
|
return re.sub(r'\*\*(.*?)\*\*', r'\1', text)
|
||||||
|
|||||||
@@ -313,7 +313,7 @@ class Godcmd(Plugin):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
ok, result = False, "你没有设置私有GPT模型"
|
ok, result = False, "你没有设置私有GPT模型"
|
||||||
elif cmd == "reset":
|
elif cmd == "reset":
|
||||||
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI, const.ZHIPU_AI]:
|
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI, const.ZHIPU_AI, const.CLAUDEAPI]:
|
||||||
bot.sessions.clear_session(session_id)
|
bot.sessions.clear_session(session_id)
|
||||||
if Bridge().chat_bots.get(bottype):
|
if Bridge().chat_bots.get(bottype):
|
||||||
Bridge().chat_bots.get(bottype).sessions.clear_session(session_id)
|
Bridge().chat_bots.get(bottype).sessions.clear_session(session_id)
|
||||||
|
|||||||
@@ -98,6 +98,8 @@
|
|||||||
|
|
||||||
如果不想创建 `plugins/linkai/config.json` 配置,可以直接通过 `$linkai sum open` 指令开启该功能。
|
如果不想创建 `plugins/linkai/config.json` 配置,可以直接通过 `$linkai sum open` 指令开启该功能。
|
||||||
|
|
||||||
|
也可以通过私聊(全局 `config.json` 中的 `linkai_app_code`)或者群聊绑定(通过`group_app_map`参数配置)的应用来开启该功能:在LinkAI平台 [应用配置](https://link-ai.tech/console/factory) 里添加并开启**内容总结**插件。
|
||||||
|
|
||||||
#### 使用
|
#### 使用
|
||||||
|
|
||||||
功能开启后,向机器人发送 **文件**、 **分享链接卡片**、**图片** 即可生成摘要,进一步可以与文件或链接的内容进行多轮对话。如果需要关闭某种类型的内容总结,设置 `summary`配置中的type字段即可。
|
功能开启后,向机器人发送 **文件**、 **分享链接卡片**、**图片** 即可生成摘要,进一步可以与文件或链接的内容进行多轮对话。如果需要关闭某种类型的内容总结,设置 `summary`配置中的type字段即可。
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from common.expired_dict import ExpiredDict
|
|||||||
from common import const
|
from common import const
|
||||||
import os
|
import os
|
||||||
from .utils import Util
|
from .utils import Util
|
||||||
from config import plugin_config
|
from config import plugin_config, conf
|
||||||
|
|
||||||
|
|
||||||
@plugins.register(
|
@plugins.register(
|
||||||
@@ -28,7 +28,7 @@ class LinkAI(Plugin):
|
|||||||
# 未加载到配置,使用模板中的配置
|
# 未加载到配置,使用模板中的配置
|
||||||
self.config = self._load_config_template()
|
self.config = self._load_config_template()
|
||||||
if self.config:
|
if self.config:
|
||||||
self.mj_bot = MJBot(self.config.get("midjourney"))
|
self.mj_bot = MJBot(self.config.get("midjourney"), self._fetch_group_app_code)
|
||||||
self.sum_config = {}
|
self.sum_config = {}
|
||||||
if self.config:
|
if self.config:
|
||||||
self.sum_config = self.config.get("summary")
|
self.sum_config = self.config.get("summary")
|
||||||
@@ -56,7 +56,8 @@ class LinkAI(Plugin):
|
|||||||
return
|
return
|
||||||
if context.type != ContextType.IMAGE:
|
if context.type != ContextType.IMAGE:
|
||||||
_send_info(e_context, "正在为你加速生成摘要,请稍后")
|
_send_info(e_context, "正在为你加速生成摘要,请稍后")
|
||||||
res = LinkSummary().summary_file(file_path)
|
app_code = self._fetch_app_code(context)
|
||||||
|
res = LinkSummary().summary_file(file_path, app_code)
|
||||||
if not res:
|
if not res:
|
||||||
if context.type != ContextType.IMAGE:
|
if context.type != ContextType.IMAGE:
|
||||||
_set_reply_text("因为神秘力量无法获取内容,请稍后再试吧", e_context, level=ReplyType.TEXT)
|
_set_reply_text("因为神秘力量无法获取内容,请稍后再试吧", e_context, level=ReplyType.TEXT)
|
||||||
@@ -74,7 +75,8 @@ class LinkAI(Plugin):
|
|||||||
if not LinkSummary().check_url(context.content):
|
if not LinkSummary().check_url(context.content):
|
||||||
return
|
return
|
||||||
_send_info(e_context, "正在为你加速生成摘要,请稍后")
|
_send_info(e_context, "正在为你加速生成摘要,请稍后")
|
||||||
res = LinkSummary().summary_url(context.content)
|
app_code = self._fetch_app_code(context)
|
||||||
|
res = LinkSummary().summary_url(context.content, app_code)
|
||||||
if not res:
|
if not res:
|
||||||
_set_reply_text("因为神秘力量无法获取文章内容,请稍后再试吧~", e_context, level=ReplyType.TEXT)
|
_set_reply_text("因为神秘力量无法获取文章内容,请稍后再试吧~", e_context, level=ReplyType.TEXT)
|
||||||
return
|
return
|
||||||
@@ -169,7 +171,7 @@ class LinkAI(Plugin):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if len(cmd) == 3 and cmd[1] == "sum" and (cmd[2] == "open" or cmd[2] == "close"):
|
if len(cmd) == 3 and cmd[1] == "sum" and (cmd[2] == "open" or cmd[2] == "close"):
|
||||||
# 知识库开关指令
|
# 总结对话开关指令
|
||||||
if not Util.is_admin(e_context):
|
if not Util.is_admin(e_context):
|
||||||
_set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
|
_set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
|
||||||
return
|
return
|
||||||
@@ -192,14 +194,34 @@ class LinkAI(Plugin):
|
|||||||
return
|
return
|
||||||
|
|
||||||
def _is_summary_open(self, context) -> bool:
|
def _is_summary_open(self, context) -> bool:
|
||||||
if not self.sum_config or not self.sum_config.get("enabled"):
|
# 获取远程应用插件状态
|
||||||
return False
|
remote_enabled = False
|
||||||
if context.kwargs.get("isgroup") and not self.sum_config.get("group_enabled"):
|
if context.kwargs.get("isgroup"):
|
||||||
return False
|
# 群聊场景只查询群对应的app_code
|
||||||
support_type = self.sum_config.get("type") or ["FILE", "SHARING"]
|
group_name = context.get("msg").from_user_nickname
|
||||||
if context.type.name not in support_type and context.type.name != "TEXT":
|
app_code = self._fetch_group_app_code(group_name)
|
||||||
return False
|
if app_code:
|
||||||
return True
|
remote_enabled = Util.fetch_app_plugin(app_code, "内容总结")
|
||||||
|
else:
|
||||||
|
# 非群聊场景使用全局app_code
|
||||||
|
app_code = conf().get("linkai_app_code")
|
||||||
|
if app_code:
|
||||||
|
remote_enabled = Util.fetch_app_plugin(app_code, "内容总结")
|
||||||
|
|
||||||
|
# 基础条件:总开关开启且消息类型符合要求
|
||||||
|
base_enabled = (
|
||||||
|
self.sum_config
|
||||||
|
and self.sum_config.get("enabled")
|
||||||
|
and (context.type.name in (
|
||||||
|
self.sum_config.get("type") or ["FILE", "SHARING"]) or context.type.name == "TEXT")
|
||||||
|
)
|
||||||
|
|
||||||
|
# 群聊:需要满足(总开关和群开关)或远程插件开启
|
||||||
|
if context.kwargs.get("isgroup"):
|
||||||
|
return (base_enabled and self.sum_config.get("group_enabled")) or remote_enabled
|
||||||
|
|
||||||
|
# 非群聊:只需要满足总开关或远程插件开启
|
||||||
|
return base_enabled or remote_enabled
|
||||||
|
|
||||||
# LinkAI 对话任务处理
|
# LinkAI 对话任务处理
|
||||||
def _is_chat_task(self, e_context: EventContext):
|
def _is_chat_task(self, e_context: EventContext):
|
||||||
@@ -230,6 +252,19 @@ class LinkAI(Plugin):
|
|||||||
app_code = group_mapping.get(group_name) or group_mapping.get("ALL_GROUP")
|
app_code = group_mapping.get(group_name) or group_mapping.get("ALL_GROUP")
|
||||||
return app_code
|
return app_code
|
||||||
|
|
||||||
|
def _fetch_app_code(self, context) -> str:
|
||||||
|
"""
|
||||||
|
根据主配置或者群聊名称获取对应的应用code,优先获取群聊配置的应用code
|
||||||
|
:param context: 上下文
|
||||||
|
:return: 应用code
|
||||||
|
"""
|
||||||
|
app_code = conf().get("linkai_app_code")
|
||||||
|
if context.kwargs.get("isgroup"):
|
||||||
|
# 群聊场景只查询群对应的app_code
|
||||||
|
group_name = context.get("msg").from_user_nickname
|
||||||
|
app_code = self._fetch_group_app_code(group_name)
|
||||||
|
return app_code
|
||||||
|
|
||||||
def get_help_text(self, verbose=False, **kwargs):
|
def get_help_text(self, verbose=False, **kwargs):
|
||||||
trigger_prefix = _get_trigger_prefix()
|
trigger_prefix = _get_trigger_prefix()
|
||||||
help_text = "用于集成 LinkAI 提供的知识库、Midjourney绘画、文档总结、联网搜索等能力。\n\n"
|
help_text = "用于集成 LinkAI 提供的知识库、Midjourney绘画、文档总结、联网搜索等能力。\n\n"
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from bridge.context import ContextType
|
|||||||
from plugins import EventContext, EventAction
|
from plugins import EventContext, EventAction
|
||||||
from .utils import Util
|
from .utils import Util
|
||||||
|
|
||||||
|
|
||||||
INVALID_REQUEST = 410
|
INVALID_REQUEST = 410
|
||||||
NOT_FOUND_ORIGIN_IMAGE = 461
|
NOT_FOUND_ORIGIN_IMAGE = 461
|
||||||
NOT_FOUND_TASK = 462
|
NOT_FOUND_TASK = 462
|
||||||
@@ -67,10 +68,11 @@ class MJTask:
|
|||||||
|
|
||||||
# midjourney bot
|
# midjourney bot
|
||||||
class MJBot:
|
class MJBot:
|
||||||
def __init__(self, config):
|
def __init__(self, config, fetch_group_app_code):
|
||||||
self.base_url = conf().get("linkai_api_base", "https://api.link-ai.tech") + "/v1/img/midjourney"
|
self.base_url = conf().get("linkai_api_base", "https://api.link-ai.tech") + "/v1/img/midjourney"
|
||||||
self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.fetch_group_app_code = fetch_group_app_code
|
||||||
self.tasks = {}
|
self.tasks = {}
|
||||||
self.temp_dict = {}
|
self.temp_dict = {}
|
||||||
self.tasks_lock = threading.Lock()
|
self.tasks_lock = threading.Lock()
|
||||||
@@ -98,7 +100,7 @@ class MJBot:
|
|||||||
return TaskType.VARIATION
|
return TaskType.VARIATION
|
||||||
elif cmd_list[0].lower() == f"{trigger_prefix}mjr":
|
elif cmd_list[0].lower() == f"{trigger_prefix}mjr":
|
||||||
return TaskType.RESET
|
return TaskType.RESET
|
||||||
elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix") and self.config.get("enabled"):
|
elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix") and self._is_mj_open(context):
|
||||||
return TaskType.GENERATE
|
return TaskType.GENERATE
|
||||||
|
|
||||||
def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
|
def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
|
||||||
@@ -129,8 +131,8 @@ class MJBot:
|
|||||||
self._set_reply_text(f"Midjourney绘画已{tips_text}", e_context, level=ReplyType.INFO)
|
self._set_reply_text(f"Midjourney绘画已{tips_text}", e_context, level=ReplyType.INFO)
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self.config.get("enabled"):
|
if not self._is_mj_open(context):
|
||||||
logger.warn("Midjourney绘画未开启,请查看 plugins/linkai/config.json 中的配置")
|
logger.warn("Midjourney绘画未开启,请查看 plugins/linkai/config.json 中的配置,或者在LinkAI平台 应用中添加/打开”MJ“插件")
|
||||||
self._set_reply_text(f"Midjourney绘画未开启", e_context, level=ReplyType.INFO)
|
self._set_reply_text(f"Midjourney绘画未开启", e_context, level=ReplyType.INFO)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -409,6 +411,25 @@ class MJBot:
|
|||||||
result.append(task)
|
result.append(task)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _is_mj_open(self, context) -> bool:
|
||||||
|
# 获取远程应用插件状态
|
||||||
|
remote_enabled = False
|
||||||
|
if context.kwargs.get("isgroup"):
|
||||||
|
# 群聊场景只查询群对应的app_code
|
||||||
|
group_name = context.get("msg").from_user_nickname
|
||||||
|
app_code = self.fetch_group_app_code(group_name)
|
||||||
|
if app_code:
|
||||||
|
remote_enabled = Util.fetch_app_plugin(app_code, "Midjourney")
|
||||||
|
else:
|
||||||
|
# 非群聊场景使用全局app_code
|
||||||
|
app_code = conf().get("linkai_app_code")
|
||||||
|
if app_code:
|
||||||
|
remote_enabled = Util.fetch_app_plugin(app_code, "Midjourney")
|
||||||
|
|
||||||
|
# 本地配置
|
||||||
|
base_enabled = self.config.get("enabled")
|
||||||
|
|
||||||
|
return base_enabled or remote_enabled
|
||||||
|
|
||||||
def _send(channel, reply: Reply, context, retry_cnt=0):
|
def _send(channel, reply: Reply, context, retry_cnt=0):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -9,19 +9,21 @@ class LinkSummary:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def summary_file(self, file_path: str):
|
def summary_file(self, file_path: str, app_code: str):
|
||||||
file_body = {
|
file_body = {
|
||||||
"file": open(file_path, "rb"),
|
"file": open(file_path, "rb"),
|
||||||
"name": file_path.split("/")[-1],
|
"name": file_path.split("/")[-1],
|
||||||
|
"app_code": app_code
|
||||||
}
|
}
|
||||||
url = self.base_url() + "/v1/summary/file"
|
url = self.base_url() + "/v1/summary/file"
|
||||||
res = requests.post(url, headers=self.headers(), files=file_body, timeout=(5, 300))
|
res = requests.post(url, headers=self.headers(), files=file_body, timeout=(5, 300))
|
||||||
return self._parse_summary_res(res)
|
return self._parse_summary_res(res)
|
||||||
|
|
||||||
def summary_url(self, url: str):
|
def summary_url(self, url: str, app_code: str):
|
||||||
url = html.unescape(url)
|
url = html.unescape(url)
|
||||||
body = {
|
body = {
|
||||||
"url": url
|
"url": url,
|
||||||
|
"app_code": app_code
|
||||||
}
|
}
|
||||||
res = requests.post(url=self.base_url() + "/v1/summary/url", headers=self.headers(), json=body, timeout=(5, 180))
|
res = requests.post(url=self.base_url() + "/v1/summary/url", headers=self.headers(), json=body, timeout=(5, 180))
|
||||||
return self._parse_summary_res(res)
|
return self._parse_summary_res(res)
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
|
import requests
|
||||||
|
from common.log import logger
|
||||||
from config import global_config
|
from config import global_config
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
from plugins.event import EventContext, EventAction
|
from plugins.event import EventContext, EventAction
|
||||||
|
from config import conf
|
||||||
|
|
||||||
class Util:
|
class Util:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -26,3 +28,20 @@ class Util:
|
|||||||
reply = Reply(level, content)
|
reply = Reply(level, content)
|
||||||
e_context["reply"] = reply
|
e_context["reply"] = reply
|
||||||
e_context.action = EventAction.BREAK_PASS
|
e_context.action = EventAction.BREAK_PASS
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def fetch_app_plugin(app_code: str, plugin_name: str) -> bool:
|
||||||
|
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||||
|
# do http request
|
||||||
|
base_url = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||||
|
params = {"app_code": app_code}
|
||||||
|
res = requests.get(url=base_url + "/v1/app/info", params=params, headers=headers, timeout=(5, 10))
|
||||||
|
if res.status_code == 200:
|
||||||
|
plugins = res.json().get("data").get("plugins")
|
||||||
|
for plugin in plugins:
|
||||||
|
if plugin.get("name") and plugin.get("name") == plugin_name:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
logger.warning(f"[LinkAI] find app info exception, res={res}")
|
||||||
|
return False
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
openai==0.27.8
|
openai==0.27.8
|
||||||
HTMLParser>=0.0.2
|
HTMLParser>=0.0.2
|
||||||
PyQRCode>=1.2.1
|
PyQRCode==1.2.1
|
||||||
qrcode>=7.4.2
|
qrcode==7.4.2
|
||||||
requests>=2.28.2
|
requests>=2.28.2
|
||||||
chardet>=5.1.0
|
chardet>=5.1.0
|
||||||
Pillow
|
Pillow
|
||||||
|
|||||||
Reference in New Issue
Block a user