Browse Source

private openai_api_key

master
JS00000 1 year ago
parent
commit
76783f0ad3
3 changed files with 53 additions and 11 deletions
  1. +14
    -7
      bot/chatgpt/chat_gpt_bot.py
  2. +12
    -2
      channel/wechatmp/wechatmp_channel.py
  3. +27
    -2
      plugins/godcmd/godcmd.py

+ 14
- 7
bot/chatgpt/chat_gpt_bot.py View File

@@ -13,10 +13,13 @@ from common.expired_dict import ExpiredDict
import openai import openai
import openai.error import openai.error
import time import time
import redis

# OpenAI对话模型API (可用) # OpenAI对话模型API (可用)
class ChatGPTBot(Bot,OpenAIImage): class ChatGPTBot(Bot,OpenAIImage):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
# set the default api_key
openai.api_key = conf().get('open_ai_api_key') openai.api_key = conf().get('open_ai_api_key')
if conf().get('open_ai_api_base'): if conf().get('open_ai_api_base'):
openai.api_base = conf().get('open_ai_api_base') openai.api_base = conf().get('open_ai_api_base')
@@ -33,6 +36,7 @@ class ChatGPTBot(Bot,OpenAIImage):
if context.type == ContextType.TEXT: if context.type == ContextType.TEXT:
logger.info("[CHATGPT] query={}".format(query)) logger.info("[CHATGPT] query={}".format(query))



session_id = context['session_id'] session_id = context['session_id']
reply = None reply = None
clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆']) clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆'])
@@ -50,11 +54,13 @@ class ChatGPTBot(Bot,OpenAIImage):
session = self.sessions.session_query(query, session_id) session = self.sessions.session_query(query, session_id)
logger.debug("[CHATGPT] session query={}".format(session.messages)) logger.debug("[CHATGPT] session query={}".format(session.messages))


api_key = context.get('openai_api_key')

# if context.get('stream'): # if context.get('stream'):
# # reply in stream # # reply in stream
# return self.reply_text_stream(query, new_query, session_id) # return self.reply_text_stream(query, new_query, session_id)


reply_content = self.reply_text(session, session_id, 0)
reply_content = self.reply_text(session, session_id, api_key, 0)
logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"])) logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"]))
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0: if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
reply = Reply(ReplyType.ERROR, reply_content['content']) reply = Reply(ReplyType.ERROR, reply_content['content'])
@@ -90,7 +96,7 @@ class ChatGPTBot(Bot,OpenAIImage):
"timeout": 120, #重试超时时间,在这个时间内,将会自动重试 "timeout": 120, #重试超时时间,在这个时间内,将会自动重试
} }


def reply_text(self, session:ChatGPTSession, session_id, retry_count=0) -> dict:
def reply_text(self, session:ChatGPTSession, session_id, api_key, retry_count=0) -> dict:
''' '''
call openai's ChatCompletion to get the answer call openai's ChatCompletion to get the answer
:param session: a conversation session :param session: a conversation session
@@ -101,8 +107,9 @@ class ChatGPTBot(Bot,OpenAIImage):
try: try:
if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token(): if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token():
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded") raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
# if api_key == None, the default openai.api_key will be used
response = openai.ChatCompletion.create( response = openai.ChatCompletion.create(
messages=session.messages, **self.compose_args()
api_key=api_key, messages=session.messages, **self.compose_args()
) )
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
return {"total_tokens": response["usage"]["total_tokens"], return {"total_tokens": response["usage"]["total_tokens"],
@@ -118,21 +125,21 @@ class ChatGPTBot(Bot,OpenAIImage):
time.sleep(5) time.sleep(5)
elif isinstance(e, openai.error.Timeout): elif isinstance(e, openai.error.Timeout):
logger.warn("[CHATGPT] Timeout: {}".format(e)) logger.warn("[CHATGPT] Timeout: {}".format(e))
result['content'] = "我没有收到你的消息"
result['content'] = "服务器出现问题"
if need_retry: if need_retry:
time.sleep(5) time.sleep(5)
elif isinstance(e, openai.error.APIConnectionError): elif isinstance(e, openai.error.APIConnectionError):
logger.warn("[CHATGPT] APIConnectionError: {}".format(e)) logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
need_retry = False need_retry = False
result['content'] = "我连接不到你的网络"
result['content'] = "网络连接出现问题"
else: else:
logger.warn("[CHATGPT] Exception: {}".format(e)) logger.warn("[CHATGPT] Exception: {}".format(e))
need_retry = False need_retry = False
self.sessions.clear_session(session_id) self.sessions.clear_session(session_id)
result['content'] = str(e)
if need_retry: if need_retry:
logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1)) logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1))
return self.reply_text(session, session_id, retry_count+1)
return self.reply_text(session, session_id, api_key, retry_count+1)
else: else:
return result return result




+ 12
- 2
channel/wechatmp/wechatmp_channel.py View File

@@ -13,6 +13,7 @@ from bridge.reply import *
from bridge.context import * from bridge.context import *
from plugins import * from plugins import *
import traceback import traceback
import redis


class WechatMPServer(): class WechatMPServer():
def __init__(self): def __init__(self):
@@ -82,7 +83,6 @@ class WechatMPChannel(Channel):
global cache_dict global cache_dict
try: try:
reply = Reply() reply = Reply()

logger.debug('[wechatmp] ready to handle context: {}'.format(context)) logger.debug('[wechatmp] ready to handle context: {}'.format(context))


# reply的构建步骤 # reply的构建步骤
@@ -134,6 +134,8 @@ class WechatMPChannel(Channel):
self.send(reply, context['receiver']) self.send(reply, context['receiver'])
else: else:
cache_dict[context['receiver']] = (1, "No reply") cache_dict[context['receiver']] = (1, "No reply")

logger.info("[threaded] Get reply for {}: {} \nA: {}".format(context['receiver'], context.content, reply.content))
except Exception as exc: except Exception as exc:
print(traceback.format_exc()) print(traceback.format_exc())
cache_dict[context['receiver']] = (1, "ERROR") cache_dict[context['receiver']] = (1, "ERROR")
@@ -171,6 +173,14 @@ class WechatMPChannel(Channel):


context = Context() context = Context()
context.kwargs = {'isgroup': False, 'receiver': fromUser, 'session_id': fromUser} context.kwargs = {'isgroup': False, 'receiver': fromUser, 'session_id': fromUser}

R = redis.Redis(host='localhost', port=6379, db=0)
user_openai_api_key = "openai_api_key_" + fromUser
api_key = R.get(user_openai_api_key)
if api_key != None:
api_key = api_key.decode("utf-8")
context['openai_api_key'] = api_key # None or user openai_api_key

img_match_prefix = check_prefix(message, conf().get('image_create_prefix')) img_match_prefix = check_prefix(message, conf().get('image_create_prefix'))
if img_match_prefix: if img_match_prefix:
message = message.replace(img_match_prefix, '', 1).strip() message = message.replace(img_match_prefix, '', 1).strip()
@@ -240,7 +250,7 @@ class WechatMPChannel(Channel):
if cnt == 45: if cnt == 45:
# Have waiting for 3x5 seconds # Have waiting for 3x5 seconds
# return timeout message # return timeout message
reply_text = "【服务器有点忙,回复任意文字再次尝试】"
reply_text = "【正在响应中,回复任意文字尝试获取回复】"
logger.info("[wechatmp] Three queries has finished For {}: {}".format(fromUser, message_id)) logger.info("[wechatmp] Three queries has finished For {}: {}".format(fromUser, message_id))
replyPost = reply.TextMsg(fromUser, toUser, reply_text).send() replyPost = reply.TextMsg(fromUser, toUser, reply_text).send()
return replyPost return replyPost


+ 27
- 2
plugins/godcmd/godcmd.py View File

@@ -29,6 +29,15 @@ COMMANDS = {
"args": ["口令"], "args": ["口令"],
"desc": "管理员认证", "desc": "管理员认证",
}, },
"set_openai_api_key": {
"alias": ["set_openai_api_key"],
"args": ["api_key"],
"desc": "设置你的OpenAI私有api_key",
},
"reset_openai_api_key": {
"alias": ["reset_openai_api_key"],
"desc": "重置为默认的api_key",
},
# "id": { # "id": {
# "alias": ["id", "用户"], # "alias": ["id", "用户"],
# "desc": "获取用户id", #目前无实际意义 # "desc": "获取用户id", #目前无实际意义
@@ -99,7 +108,7 @@ def get_help_text(isadmin, isgroup):
alias=["#"+a for a in info['alias']] alias=["#"+a for a in info['alias']]
help_text += f"{','.join(alias)} " help_text += f"{','.join(alias)} "
if 'args' in info: if 'args' in info:
args=["{"+a+"}" for a in info['args']]
args=["'"+a+"'" for a in info['args']]
help_text += f"{' '.join(args)} " help_text += f"{' '.join(args)} "
help_text += f": {info['desc']}\n" help_text += f": {info['desc']}\n"


@@ -162,7 +171,7 @@ class Godcmd(Plugin):
bottype = Bridge().get_bot_type("chat") bottype = Bridge().get_bot_type("chat")
bot = Bridge().get_bot("chat") bot = Bridge().get_bot("chat")
# 将命令和参数分割 # 将命令和参数分割
command_parts = content[1:].split(" ")
command_parts = content[1:].strip().split(" ")
cmd = command_parts[0] cmd = command_parts[0]
args = command_parts[1:] args = command_parts[1:]
isadmin=False isadmin=False
@@ -184,6 +193,22 @@ class Godcmd(Plugin):
ok, result = True, PluginManager().instances[name].get_help_text(verbose=True) ok, result = True, PluginManager().instances[name].get_help_text(verbose=True)
else: else:
ok, result = False, "unknown args" ok, result = False, "unknown args"
elif cmd == "set_openai_api_key":
if len(args) == 1:
import redis
R = redis.Redis(host='localhost', port=6379, db=0)
user_openai_api_key = "openai_api_key_" + user
R.set(user_openai_api_key, args[0])
# R.sadd("openai_api_key", args[0])
ok, result = True, "你的OpenAI私有api_key已设置为" + args[0]
else:
ok, result = False, "请提供一个api_key"
elif cmd == "reset_openai_api_key":
import redis
R = redis.Redis(host='localhost', port=6379, db=0)
user_openai_api_key = "openai_api_key_" + user
R.delete(user_openai_api_key)
ok, result = True, "OpenAI的api_key已重置"
# elif cmd == "helpp": # elif cmd == "helpp":
# if len(args) != 1: # if len(args) != 1:
# ok, result = False, "请提供插件名" # ok, result = False, "请提供插件名"


Loading…
Cancel
Save