From 76783f0ad325ae1eeae3d7fbeccc9bdfb9a81d0b Mon Sep 17 00:00:00 2001 From: JS00000 Date: Wed, 29 Mar 2023 23:08:30 +0800 Subject: [PATCH] private openai_api_key --- bot/chatgpt/chat_gpt_bot.py | 21 +++++++++++++------- channel/wechatmp/wechatmp_channel.py | 14 ++++++++++++-- plugins/godcmd/godcmd.py | 29 ++++++++++++++++++++++++++-- 3 files changed, 53 insertions(+), 11 deletions(-) diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index 49a99c1..e4507b3 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -13,10 +13,13 @@ from common.expired_dict import ExpiredDict import openai import openai.error import time +import redis + # OpenAI对话模型API (可用) class ChatGPTBot(Bot,OpenAIImage): def __init__(self): super().__init__() + # set the default api_key openai.api_key = conf().get('open_ai_api_key') if conf().get('open_ai_api_base'): openai.api_base = conf().get('open_ai_api_base') @@ -33,6 +36,7 @@ class ChatGPTBot(Bot,OpenAIImage): if context.type == ContextType.TEXT: logger.info("[CHATGPT] query={}".format(query)) + session_id = context['session_id'] reply = None clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆']) @@ -50,11 +54,13 @@ class ChatGPTBot(Bot,OpenAIImage): session = self.sessions.session_query(query, session_id) logger.debug("[CHATGPT] session query={}".format(session.messages)) + api_key = context.get('openai_api_key') + # if context.get('stream'): # # reply in stream # return self.reply_text_stream(query, new_query, session_id) - reply_content = self.reply_text(session, session_id, 0) + reply_content = self.reply_text(session, session_id, api_key, 0) logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"])) if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0: reply = Reply(ReplyType.ERROR, reply_content['content']) @@ -90,7 +96,7 @@ class ChatGPTBot(Bot,OpenAIImage): "timeout": 120, #重试超时时间,在这个时间内,将会自动重试 } - def reply_text(self, session:ChatGPTSession, session_id, retry_count=0) -> dict: + def reply_text(self, session:ChatGPTSession, session_id, api_key, retry_count=0) -> dict: ''' call openai's ChatCompletion to get the answer :param session: a conversation session @@ -101,8 +107,9 @@ class ChatGPTBot(Bot,OpenAIImage): try: if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token(): raise openai.error.RateLimitError("RateLimitError: rate limit exceeded") + # if api_key == None, the default openai.api_key will be used response = openai.ChatCompletion.create( - messages=session.messages, **self.compose_args() + api_key=api_key, messages=session.messages, **self.compose_args() ) # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) return {"total_tokens": response["usage"]["total_tokens"], @@ -118,21 +125,21 @@ class ChatGPTBot(Bot,OpenAIImage): time.sleep(5) elif isinstance(e, openai.error.Timeout): logger.warn("[CHATGPT] Timeout: {}".format(e)) - result['content'] = "我没有收到你的消息" + result['content'] = "服务器出现问题" if need_retry: time.sleep(5) elif isinstance(e, openai.error.APIConnectionError): logger.warn("[CHATGPT] APIConnectionError: {}".format(e)) need_retry = False - result['content'] = "我连接不到你的网络" + result['content'] = "网络连接出现问题" else: logger.warn("[CHATGPT] Exception: {}".format(e)) need_retry = False self.sessions.clear_session(session_id) - + result['content'] = str(e) if need_retry: logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1)) - return self.reply_text(session, session_id, retry_count+1) + return self.reply_text(session, session_id, api_key, retry_count+1) else: return result diff --git a/channel/wechatmp/wechatmp_channel.py b/channel/wechatmp/wechatmp_channel.py index 0599bb1..85393e6 100644 --- a/channel/wechatmp/wechatmp_channel.py +++ b/channel/wechatmp/wechatmp_channel.py @@ -13,6 +13,7 @@ from bridge.reply import * from bridge.context import * from plugins import * import traceback +import redis class WechatMPServer(): def __init__(self): @@ -82,7 +83,6 @@ class WechatMPChannel(Channel): global cache_dict try: reply = Reply() - logger.debug('[wechatmp] ready to handle context: {}'.format(context)) # reply的构建步骤 @@ -134,6 +134,8 @@ class WechatMPChannel(Channel): self.send(reply, context['receiver']) else: cache_dict[context['receiver']] = (1, "No reply") + + logger.info("[threaded] Get reply for {}: {} \nA: {}".format(context['receiver'], context.content, reply.content)) except Exception as exc: print(traceback.format_exc()) cache_dict[context['receiver']] = (1, "ERROR") @@ -171,6 +173,14 @@ class WechatMPChannel(Channel): context = Context() context.kwargs = {'isgroup': False, 'receiver': fromUser, 'session_id': fromUser} + + R = redis.Redis(host='localhost', port=6379, db=0) + user_openai_api_key = "openai_api_key_" + fromUser + api_key = R.get(user_openai_api_key) + if api_key != None: + api_key = api_key.decode("utf-8") + context['openai_api_key'] = api_key # None or user openai_api_key + img_match_prefix = check_prefix(message, conf().get('image_create_prefix')) if img_match_prefix: message = message.replace(img_match_prefix, '', 1).strip() @@ -240,7 +250,7 @@ class WechatMPChannel(Channel): if cnt == 45: # Have waiting for 3x5 seconds # return timeout message - reply_text = "【服务器有点忙,回复任意文字再次尝试】" + reply_text = "【正在响应中,回复任意文字尝试获取回复】" logger.info("[wechatmp] Three queries has finished For {}: {}".format(fromUser, message_id)) replyPost = reply.TextMsg(fromUser, toUser, reply_text).send() return replyPost diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index bc4f09c..33e72a3 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -29,6 +29,15 @@ COMMANDS = { "args": ["口令"], "desc": "管理员认证", }, + "set_openai_api_key": { + "alias": ["set_openai_api_key"], + "args": ["api_key"], + "desc": "设置你的OpenAI私有api_key", + }, + "reset_openai_api_key": { + "alias": ["reset_openai_api_key"], + "desc": "重置为默认的api_key", + }, # "id": { # "alias": ["id", "用户"], # "desc": "获取用户id", #目前无实际意义 @@ -99,7 +108,7 @@ def get_help_text(isadmin, isgroup): alias=["#"+a for a in info['alias']] help_text += f"{','.join(alias)} " if 'args' in info: - args=["{"+a+"}" for a in info['args']] + args=["'"+a+"'" for a in info['args']] help_text += f"{' '.join(args)} " help_text += f": {info['desc']}\n" @@ -162,7 +171,7 @@ class Godcmd(Plugin): bottype = Bridge().get_bot_type("chat") bot = Bridge().get_bot("chat") # 将命令和参数分割 - command_parts = content[1:].split(" ") + command_parts = content[1:].strip().split(" ") cmd = command_parts[0] args = command_parts[1:] isadmin=False @@ -184,6 +193,22 @@ class Godcmd(Plugin): ok, result = True, PluginManager().instances[name].get_help_text(verbose=True) else: ok, result = False, "unknown args" + elif cmd == "set_openai_api_key": + if len(args) == 1: + import redis + R = redis.Redis(host='localhost', port=6379, db=0) + user_openai_api_key = "openai_api_key_" + user + R.set(user_openai_api_key, args[0]) + # R.sadd("openai_api_key", args[0]) + ok, result = True, "你的OpenAI私有api_key已设置为" + args[0] + else: + ok, result = False, "请提供一个api_key" + elif cmd == "reset_openai_api_key": + import redis + R = redis.Redis(host='localhost', port=6379, db=0) + user_openai_api_key = "openai_api_key_" + user + R.delete(user_openai_api_key) + ok, result = True, "OpenAI的api_key已重置" # elif cmd == "helpp": # if len(args) != 1: # ok, result = False, "请提供插件名"