Selaa lähdekoodia

private openai_api_key

master
JS00000 1 vuosi sitten
vanhempi
commit
76783f0ad3
3 muutettua tiedostoa jossa 53 lisäystä ja 11 poistoa
  1. +14
    -7
      bot/chatgpt/chat_gpt_bot.py
  2. +12
    -2
      channel/wechatmp/wechatmp_channel.py
  3. +27
    -2
      plugins/godcmd/godcmd.py

+ 14
- 7
bot/chatgpt/chat_gpt_bot.py Näytä tiedosto

@@ -13,10 +13,13 @@ from common.expired_dict import ExpiredDict
import openai
import openai.error
import time
import redis

# OpenAI对话模型API (可用)
class ChatGPTBot(Bot,OpenAIImage):
def __init__(self):
super().__init__()
# set the default api_key
openai.api_key = conf().get('open_ai_api_key')
if conf().get('open_ai_api_base'):
openai.api_base = conf().get('open_ai_api_base')
@@ -33,6 +36,7 @@ class ChatGPTBot(Bot,OpenAIImage):
if context.type == ContextType.TEXT:
logger.info("[CHATGPT] query={}".format(query))


session_id = context['session_id']
reply = None
clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆'])
@@ -50,11 +54,13 @@ class ChatGPTBot(Bot,OpenAIImage):
session = self.sessions.session_query(query, session_id)
logger.debug("[CHATGPT] session query={}".format(session.messages))

api_key = context.get('openai_api_key')

# if context.get('stream'):
# # reply in stream
# return self.reply_text_stream(query, new_query, session_id)

reply_content = self.reply_text(session, session_id, 0)
reply_content = self.reply_text(session, session_id, api_key, 0)
logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"]))
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
reply = Reply(ReplyType.ERROR, reply_content['content'])
@@ -90,7 +96,7 @@ class ChatGPTBot(Bot,OpenAIImage):
"timeout": 120, #重试超时时间,在这个时间内,将会自动重试
}

def reply_text(self, session:ChatGPTSession, session_id, retry_count=0) -> dict:
def reply_text(self, session:ChatGPTSession, session_id, api_key, retry_count=0) -> dict:
'''
call openai's ChatCompletion to get the answer
:param session: a conversation session
@@ -101,8 +107,9 @@ class ChatGPTBot(Bot,OpenAIImage):
try:
if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token():
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
# if api_key == None, the default openai.api_key will be used
response = openai.ChatCompletion.create(
messages=session.messages, **self.compose_args()
api_key=api_key, messages=session.messages, **self.compose_args()
)
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
return {"total_tokens": response["usage"]["total_tokens"],
@@ -118,21 +125,21 @@ class ChatGPTBot(Bot,OpenAIImage):
time.sleep(5)
elif isinstance(e, openai.error.Timeout):
logger.warn("[CHATGPT] Timeout: {}".format(e))
result['content'] = "我没有收到你的消息"
result['content'] = "服务器出现问题"
if need_retry:
time.sleep(5)
elif isinstance(e, openai.error.APIConnectionError):
logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
need_retry = False
result['content'] = "我连接不到你的网络"
result['content'] = "网络连接出现问题"
else:
logger.warn("[CHATGPT] Exception: {}".format(e))
need_retry = False
self.sessions.clear_session(session_id)
result['content'] = str(e)
if need_retry:
logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1))
return self.reply_text(session, session_id, retry_count+1)
return self.reply_text(session, session_id, api_key, retry_count+1)
else:
return result



+ 12
- 2
channel/wechatmp/wechatmp_channel.py Näytä tiedosto

@@ -13,6 +13,7 @@ from bridge.reply import *
from bridge.context import *
from plugins import *
import traceback
import redis

class WechatMPServer():
def __init__(self):
@@ -82,7 +83,6 @@ class WechatMPChannel(Channel):
global cache_dict
try:
reply = Reply()

logger.debug('[wechatmp] ready to handle context: {}'.format(context))

# reply的构建步骤
@@ -134,6 +134,8 @@ class WechatMPChannel(Channel):
self.send(reply, context['receiver'])
else:
cache_dict[context['receiver']] = (1, "No reply")

logger.info("[threaded] Get reply for {}: {} \nA: {}".format(context['receiver'], context.content, reply.content))
except Exception as exc:
print(traceback.format_exc())
cache_dict[context['receiver']] = (1, "ERROR")
@@ -171,6 +173,14 @@ class WechatMPChannel(Channel):

context = Context()
context.kwargs = {'isgroup': False, 'receiver': fromUser, 'session_id': fromUser}

R = redis.Redis(host='localhost', port=6379, db=0)
user_openai_api_key = "openai_api_key_" + fromUser
api_key = R.get(user_openai_api_key)
if api_key != None:
api_key = api_key.decode("utf-8")
context['openai_api_key'] = api_key # None or user openai_api_key

img_match_prefix = check_prefix(message, conf().get('image_create_prefix'))
if img_match_prefix:
message = message.replace(img_match_prefix, '', 1).strip()
@@ -240,7 +250,7 @@ class WechatMPChannel(Channel):
if cnt == 45:
# Have waiting for 3x5 seconds
# return timeout message
reply_text = "【服务器有点忙,回复任意文字再次尝试】"
reply_text = "【正在响应中,回复任意文字尝试获取回复】"
logger.info("[wechatmp] Three queries has finished For {}: {}".format(fromUser, message_id))
replyPost = reply.TextMsg(fromUser, toUser, reply_text).send()
return replyPost


+ 27
- 2
plugins/godcmd/godcmd.py Näytä tiedosto

@@ -29,6 +29,15 @@ COMMANDS = {
"args": ["口令"],
"desc": "管理员认证",
},
"set_openai_api_key": {
"alias": ["set_openai_api_key"],
"args": ["api_key"],
"desc": "设置你的OpenAI私有api_key",
},
"reset_openai_api_key": {
"alias": ["reset_openai_api_key"],
"desc": "重置为默认的api_key",
},
# "id": {
# "alias": ["id", "用户"],
# "desc": "获取用户id", #目前无实际意义
@@ -99,7 +108,7 @@ def get_help_text(isadmin, isgroup):
alias=["#"+a for a in info['alias']]
help_text += f"{','.join(alias)} "
if 'args' in info:
args=["{"+a+"}" for a in info['args']]
args=["'"+a+"'" for a in info['args']]
help_text += f"{' '.join(args)} "
help_text += f": {info['desc']}\n"

@@ -162,7 +171,7 @@ class Godcmd(Plugin):
bottype = Bridge().get_bot_type("chat")
bot = Bridge().get_bot("chat")
# 将命令和参数分割
command_parts = content[1:].split(" ")
command_parts = content[1:].strip().split(" ")
cmd = command_parts[0]
args = command_parts[1:]
isadmin=False
@@ -184,6 +193,22 @@ class Godcmd(Plugin):
ok, result = True, PluginManager().instances[name].get_help_text(verbose=True)
else:
ok, result = False, "unknown args"
elif cmd == "set_openai_api_key":
if len(args) == 1:
import redis
R = redis.Redis(host='localhost', port=6379, db=0)
user_openai_api_key = "openai_api_key_" + user
R.set(user_openai_api_key, args[0])
# R.sadd("openai_api_key", args[0])
ok, result = True, "你的OpenAI私有api_key已设置为" + args[0]
else:
ok, result = False, "请提供一个api_key"
elif cmd == "reset_openai_api_key":
import redis
R = redis.Redis(host='localhost', port=6379, db=0)
user_openai_api_key = "openai_api_key_" + user
R.delete(user_openai_api_key)
ok, result = True, "OpenAI的api_key已重置"
# elif cmd == "helpp":
# if len(args) != 1:
# ok, result = False, "请提供插件名"


Loading…
Peruuta
Tallenna