ソースを参照

修正会话tokens计算

develop
zwssunny 1年前
コミット
5de600c689
1個のファイルの変更22行の追加20行の削除
  1. +22
    -20
      bot/chatgpt/chat_gpt_bot.py

+ 22
- 20
bot/chatgpt/chat_gpt_bot.py ファイルの表示

@@ -40,21 +40,21 @@ class ChatGPTBot(Bot):
# return self.reply_text_stream(query, new_query, from_user_id)

reply_content = self.reply_text(new_query, from_user_id, 0)
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
if reply_content:
Session.save_session(query, reply_content, from_user_id)
return reply_content[1]
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content["content"]))
if reply_content["completion_tokens"] > 0:
Session.save_session(reply_content["content"], from_user_id, reply_content["total_tokens"])
return reply_content["content"]

elif context.get('type', None) == 'IMAGE_CREATE':
return self.create_img(query, 0)

def reply_text(self, query, user_id, retry_count=0):
def reply_text(self, query, user_id, retry_count=0) ->dict:
'''
call openai's ChatCompletion to get the answer
:param query: query content
:param user_id: from user id
:param retry_count: retry count
:return: [0]-tokens used and [1]-answer
:return: {}
'''
try:
response = openai.ChatCompletion.create(
@@ -68,8 +68,9 @@ class ChatGPTBot(Bot):
)
# res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
logger.info(response.choices[0]['message']['content'])

return response["usage"]["prompt_tokens"],response.choices[0]['message']['content']
return {"total_tokens": response["usage"]["total_tokens"],
"completion_tokens": response["usage"]["completion_tokens"],
"content": response.choices[0]['message']['content']}
except openai.error.RateLimitError as e:
# rate limit exception
logger.warn(e)
@@ -78,21 +79,21 @@ class ChatGPTBot(Bot):
logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
return self.reply_text(query, user_id, retry_count+1)
else:
return 0,"提问太快啦,请休息一下再问我吧"
return {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
except openai.error.APIConnectionError as e:
# api connection exception
logger.warn(e)
logger.warn("[OPEN_AI] APIConnection failed")
return 0,"我连接不到你的网络"
return {"completion_tokens": 0, "content":"我连接不到你的网络"}
except openai.error.Timeout as e:
logger.warn(e)
logger.warn("[OPEN_AI] Timeout")
return 0,"我没有收到你的消息"
return {"completion_tokens": 0, "content":"我没有收到你的消息"}
except Exception as e:
# unknown exception
logger.exception(e)
Session.clear_session(user_id)
return 0,"请再问我一次吧"
return {"completion_tokens": 0, "content": "请再问我一次吧"}

def create_img(self, query, retry_count=0):
try:
@@ -143,7 +144,7 @@ class Session(object):
return session

@staticmethod
def save_session(query, answer, user_id):
def save_session(answer, user_id, total_tokens):
max_tokens = conf().get("conversation_max_tokens")
if not max_tokens:
# default 3000
@@ -153,22 +154,23 @@ class Session(object):
session = user_session.get(user_id)
if session:
# append conversation
gpt_item = {'role': 'assistant', 'content': answer[1]}
gpt_item = {'role': 'assistant', 'content': answer}
session.append(gpt_item)

# discard exceed limit conversation
used_tokens=int(answer[0])
# logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens))
Session.discard_exceed_conversation(session, max_tokens, total_tokens)

while used_tokens > max_tokens:
@staticmethod
def discard_exceed_conversation(session, max_tokens, total_tokens):
dec_tokens=int(total_tokens)
# logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens))
while dec_tokens > max_tokens:
# pop first conversation
if len(session) > 0:
session.pop(0)
else:
break

used_tokens=used_tokens-max_tokens

dec_tokens=dec_tokens-max_tokens

@staticmethod
def clear_session(user_id):


読み込み中…
キャンセル
保存