mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
修正会话tokens计算
This commit is contained in:
@@ -40,21 +40,21 @@ class ChatGPTBot(Bot):
|
||||
# return self.reply_text_stream(query, new_query, from_user_id)
|
||||
|
||||
reply_content = self.reply_text(new_query, from_user_id, 0)
|
||||
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
|
||||
if reply_content:
|
||||
Session.save_session(query, reply_content, from_user_id)
|
||||
return reply_content[1]
|
||||
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content["content"]))
|
||||
if reply_content["completion_tokens"] > 0:
|
||||
Session.save_session(reply_content["content"], from_user_id, reply_content["total_tokens"])
|
||||
return reply_content["content"]
|
||||
|
||||
elif context.get('type', None) == 'IMAGE_CREATE':
|
||||
return self.create_img(query, 0)
|
||||
|
||||
def reply_text(self, query, user_id, retry_count=0):
|
||||
def reply_text(self, query, user_id, retry_count=0) ->dict:
|
||||
'''
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param query: query content
|
||||
:param user_id: from user id
|
||||
:param retry_count: retry count
|
||||
:return: [0]-tokens used and [1]-answer
|
||||
:return: {}
|
||||
'''
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
@@ -68,8 +68,9 @@ class ChatGPTBot(Bot):
|
||||
)
|
||||
# res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
|
||||
logger.info(response.choices[0]['message']['content'])
|
||||
|
||||
return response["usage"]["prompt_tokens"],response.choices[0]['message']['content']
|
||||
return {"total_tokens": response["usage"]["total_tokens"],
|
||||
"completion_tokens": response["usage"]["completion_tokens"],
|
||||
"content": response.choices[0]['message']['content']}
|
||||
except openai.error.RateLimitError as e:
|
||||
# rate limit exception
|
||||
logger.warn(e)
|
||||
@@ -78,21 +79,21 @@ class ChatGPTBot(Bot):
|
||||
logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||
return self.reply_text(query, user_id, retry_count+1)
|
||||
else:
|
||||
return 0,"提问太快啦,请休息一下再问我吧"
|
||||
return {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
|
||||
except openai.error.APIConnectionError as e:
|
||||
# api connection exception
|
||||
logger.warn(e)
|
||||
logger.warn("[OPEN_AI] APIConnection failed")
|
||||
return 0,"我连接不到你的网络"
|
||||
return {"completion_tokens": 0, "content":"我连接不到你的网络"}
|
||||
except openai.error.Timeout as e:
|
||||
logger.warn(e)
|
||||
logger.warn("[OPEN_AI] Timeout")
|
||||
return 0,"我没有收到你的消息"
|
||||
return {"completion_tokens": 0, "content":"我没有收到你的消息"}
|
||||
except Exception as e:
|
||||
# unknown exception
|
||||
logger.exception(e)
|
||||
Session.clear_session(user_id)
|
||||
return 0,"请再问我一次吧"
|
||||
return {"completion_tokens": 0, "content": "请再问我一次吧"}
|
||||
|
||||
def create_img(self, query, retry_count=0):
|
||||
try:
|
||||
@@ -143,7 +144,7 @@ class Session(object):
|
||||
return session
|
||||
|
||||
@staticmethod
|
||||
def save_session(query, answer, user_id):
|
||||
def save_session(answer, user_id, total_tokens):
|
||||
max_tokens = conf().get("conversation_max_tokens")
|
||||
if not max_tokens:
|
||||
# default 3000
|
||||
@@ -153,22 +154,23 @@ class Session(object):
|
||||
session = user_session.get(user_id)
|
||||
if session:
|
||||
# append conversation
|
||||
gpt_item = {'role': 'assistant', 'content': answer[1]}
|
||||
gpt_item = {'role': 'assistant', 'content': answer}
|
||||
session.append(gpt_item)
|
||||
|
||||
# discard exceed limit conversation
|
||||
used_tokens=int(answer[0])
|
||||
# logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens))
|
||||
Session.discard_exceed_conversation(session, max_tokens, total_tokens)
|
||||
|
||||
while used_tokens > max_tokens:
|
||||
@staticmethod
|
||||
def discard_exceed_conversation(session, max_tokens, total_tokens):
|
||||
dec_tokens=int(total_tokens)
|
||||
# logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens))
|
||||
while dec_tokens > max_tokens:
|
||||
# pop first conversation
|
||||
if len(session) > 0:
|
||||
session.pop(0)
|
||||
else:
|
||||
break
|
||||
|
||||
used_tokens=used_tokens-max_tokens
|
||||
|
||||
dec_tokens=dec_tokens-max_tokens
|
||||
|
||||
@staticmethod
|
||||
def clear_session(user_id):
|
||||
|
||||
Reference in New Issue
Block a user