瀏覽代碼

Merge pull request #360 from zwssunny/master

修正会话tokens计算
master
zhayujie GitHub 1 年之前
父節點
當前提交
2886f48788
沒有發現已知的金鑰在資料庫的簽署中 GPG 金鑰 ID: 4AEE18F83AFDEB23
共有 2 個檔案被更改,包括 33 行新增28 行删除
  1. +1
    -0
      README.md
  2. +32
    -28
      bot/chatgpt/chat_gpt_bot.py

+ 1
- 0
README.md 查看文件

@@ -142,6 +142,7 @@ touch nohup.out # 首次运行需要新建日
nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通过日志输出二维码
```
扫码登录后程序即可运行于服务器后台,此时可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。
scripts/目录有相应的脚本可以调用

> **注意:** 如果 扫码后手机提示登录验证需要等待5s,而终端的二维码再次刷新并提示 `Log in time out, reloading QR code`,此时需参考此 [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/8) 修改一行代码即可解决。



+ 32
- 28
bot/chatgpt/chat_gpt_bot.py 查看文件

@@ -6,7 +6,6 @@ from common.log import logger
from common.expired_dict import ExpiredDict
import openai
import time
import json

if conf().get('expires_in_seconds'):
user_session = ExpiredDict(conf().get('expires_in_seconds'))
@@ -41,15 +40,22 @@ class ChatGPTBot(Bot):
# return self.reply_text_stream(query, new_query, from_user_id)

reply_content = self.reply_text(new_query, from_user_id, 0)
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
if reply_content:
Session.save_session(query, reply_content, from_user_id)
return reply_content
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content["content"]))
if reply_content["completion_tokens"] > 0:
Session.save_session(reply_content["content"], from_user_id, reply_content["total_tokens"])
return reply_content["content"]

elif context.get('type', None) == 'IMAGE_CREATE':
return self.create_img(query, 0)

def reply_text(self, query, user_id, retry_count=0):
def reply_text(self, query, user_id, retry_count=0) ->dict:
'''
call openai's ChatCompletion to get the answer
:param query: query content
:param user_id: from user id
:param retry_count: retry count
:return: {}
'''
try:
response = openai.ChatCompletion.create(
model="gpt-3.5-turbo", # 对话模型的名称
@@ -62,8 +68,9 @@ class ChatGPTBot(Bot):
)
# res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
logger.info(response.choices[0]['message']['content'])
# log.info("[OPEN_AI] reply={}".format(res_content))
return response.choices[0]['message']['content']
return {"total_tokens": response["usage"]["total_tokens"],
"completion_tokens": response["usage"]["completion_tokens"],
"content": response.choices[0]['message']['content']}
except openai.error.RateLimitError as e:
# rate limit exception
logger.warn(e)
@@ -72,21 +79,21 @@ class ChatGPTBot(Bot):
logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
return self.reply_text(query, user_id, retry_count+1)
else:
return "提问太快啦,请休息一下再问我吧"
return {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
except openai.error.APIConnectionError as e:
# api connection exception
logger.warn(e)
logger.warn("[OPEN_AI] APIConnection failed")
return "我连接不到你的网络"
return {"completion_tokens": 0, "content":"我连接不到你的网络"}
except openai.error.Timeout as e:
logger.warn(e)
logger.warn("[OPEN_AI] Timeout")
return "我没有收到你的消息"
return {"completion_tokens": 0, "content":"我没有收到你的消息"}
except Exception as e:
# unknown exception
logger.exception(e)
Session.clear_session(user_id)
return "请再问我一次吧"
return {"completion_tokens": 0, "content": "请再问我一次吧"}

def create_img(self, query, retry_count=0):
try:
@@ -137,11 +144,12 @@ class Session(object):
return session

@staticmethod
def save_session(query, answer, user_id):
def save_session(answer, user_id, total_tokens):
max_tokens = conf().get("conversation_max_tokens")
if not max_tokens:
# default 3000
max_tokens = 1000
max_tokens=int(max_tokens)

session = user_session.get(user_id)
if session:
@@ -150,23 +158,19 @@ class Session(object):
session.append(gpt_item)

# discard exceed limit conversation
Session.discard_exceed_conversation(user_session[user_id], max_tokens)
Session.discard_exceed_conversation(session, max_tokens, total_tokens)

@staticmethod
def discard_exceed_conversation(session, max_tokens):
count = 0
count_list = list()
for i in range(len(session)-1, -1, -1):
# count tokens of conversation list
history_conv = session[i]
tokens=json.dumps(history_conv).split()
count += len(tokens)
count_list.append(count)

for c in count_list:
if c > max_tokens:
# pop first conversation
session.pop(0)
def discard_exceed_conversation(session, max_tokens, total_tokens):
dec_tokens=int(total_tokens)
# logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens))
while dec_tokens > max_tokens:
# pop first conversation
if len(session) > 0:
session.pop(0)
else:
break
dec_tokens=dec_tokens-max_tokens

@staticmethod
def clear_session(user_id):


Loading…
取消
儲存