@@ -13,10 +13,13 @@ from common.expired_dict import ExpiredDict | |||
import openai | |||
import openai.error | |||
import time | |||
import redis | |||
# OpenAI对话模型API (可用) | |||
class ChatGPTBot(Bot,OpenAIImage): | |||
def __init__(self): | |||
super().__init__() | |||
# set the default api_key | |||
openai.api_key = conf().get('open_ai_api_key') | |||
if conf().get('open_ai_api_base'): | |||
openai.api_base = conf().get('open_ai_api_base') | |||
@@ -33,6 +36,7 @@ class ChatGPTBot(Bot,OpenAIImage): | |||
if context.type == ContextType.TEXT: | |||
logger.info("[CHATGPT] query={}".format(query)) | |||
session_id = context['session_id'] | |||
reply = None | |||
clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆']) | |||
@@ -50,11 +54,13 @@ class ChatGPTBot(Bot,OpenAIImage): | |||
session = self.sessions.session_query(query, session_id) | |||
logger.debug("[CHATGPT] session query={}".format(session.messages)) | |||
api_key = context.get('openai_api_key') | |||
# if context.get('stream'): | |||
# # reply in stream | |||
# return self.reply_text_stream(query, new_query, session_id) | |||
reply_content = self.reply_text(session, session_id, 0) | |||
reply_content = self.reply_text(session, session_id, api_key, 0) | |||
logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"])) | |||
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0: | |||
reply = Reply(ReplyType.ERROR, reply_content['content']) | |||
@@ -90,7 +96,7 @@ class ChatGPTBot(Bot,OpenAIImage): | |||
"timeout": 120, #重试超时时间,在这个时间内,将会自动重试 | |||
} | |||
def reply_text(self, session:ChatGPTSession, session_id, retry_count=0) -> dict: | |||
def reply_text(self, session:ChatGPTSession, session_id, api_key, retry_count=0) -> dict: | |||
''' | |||
call openai's ChatCompletion to get the answer | |||
:param session: a conversation session | |||
@@ -101,8 +107,9 @@ class ChatGPTBot(Bot,OpenAIImage): | |||
try: | |||
if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token(): | |||
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded") | |||
# if api_key == None, the default openai.api_key will be used | |||
response = openai.ChatCompletion.create( | |||
messages=session.messages, **self.compose_args() | |||
api_key=api_key, messages=session.messages, **self.compose_args() | |||
) | |||
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) | |||
return {"total_tokens": response["usage"]["total_tokens"], | |||
@@ -118,21 +125,21 @@ class ChatGPTBot(Bot,OpenAIImage): | |||
time.sleep(5) | |||
elif isinstance(e, openai.error.Timeout): | |||
logger.warn("[CHATGPT] Timeout: {}".format(e)) | |||
result['content'] = "我没有收到你的消息" | |||
result['content'] = "服务器出现问题" | |||
if need_retry: | |||
time.sleep(5) | |||
elif isinstance(e, openai.error.APIConnectionError): | |||
logger.warn("[CHATGPT] APIConnectionError: {}".format(e)) | |||
need_retry = False | |||
result['content'] = "我连接不到你的网络" | |||
result['content'] = "网络连接出现问题" | |||
else: | |||
logger.warn("[CHATGPT] Exception: {}".format(e)) | |||
need_retry = False | |||
self.sessions.clear_session(session_id) | |||
result['content'] = str(e) | |||
if need_retry: | |||
logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1)) | |||
return self.reply_text(session, session_id, retry_count+1) | |||
return self.reply_text(session, session_id, api_key, retry_count+1) | |||
else: | |||
return result | |||
@@ -13,6 +13,7 @@ from bridge.reply import * | |||
from bridge.context import * | |||
from plugins import * | |||
import traceback | |||
import redis | |||
class WechatMPServer(): | |||
def __init__(self): | |||
@@ -82,7 +83,6 @@ class WechatMPChannel(Channel): | |||
global cache_dict | |||
try: | |||
reply = Reply() | |||
logger.debug('[wechatmp] ready to handle context: {}'.format(context)) | |||
# reply的构建步骤 | |||
@@ -134,6 +134,8 @@ class WechatMPChannel(Channel): | |||
self.send(reply, context['receiver']) | |||
else: | |||
cache_dict[context['receiver']] = (1, "No reply") | |||
logger.info("[threaded] Get reply for {}: {} \nA: {}".format(context['receiver'], context.content, reply.content)) | |||
except Exception as exc: | |||
print(traceback.format_exc()) | |||
cache_dict[context['receiver']] = (1, "ERROR") | |||
@@ -171,6 +173,14 @@ class WechatMPChannel(Channel): | |||
context = Context() | |||
context.kwargs = {'isgroup': False, 'receiver': fromUser, 'session_id': fromUser} | |||
R = redis.Redis(host='localhost', port=6379, db=0) | |||
user_openai_api_key = "openai_api_key_" + fromUser | |||
api_key = R.get(user_openai_api_key) | |||
if api_key != None: | |||
api_key = api_key.decode("utf-8") | |||
context['openai_api_key'] = api_key # None or user openai_api_key | |||
img_match_prefix = check_prefix(message, conf().get('image_create_prefix')) | |||
if img_match_prefix: | |||
message = message.replace(img_match_prefix, '', 1).strip() | |||
@@ -240,7 +250,7 @@ class WechatMPChannel(Channel): | |||
if cnt == 45: | |||
# Have waiting for 3x5 seconds | |||
# return timeout message | |||
reply_text = "【服务器有点忙,回复任意文字再次尝试】" | |||
reply_text = "【正在响应中,回复任意文字尝试获取回复】" | |||
logger.info("[wechatmp] Three queries has finished For {}: {}".format(fromUser, message_id)) | |||
replyPost = reply.TextMsg(fromUser, toUser, reply_text).send() | |||
return replyPost | |||
@@ -29,6 +29,15 @@ COMMANDS = { | |||
"args": ["口令"], | |||
"desc": "管理员认证", | |||
}, | |||
"set_openai_api_key": { | |||
"alias": ["set_openai_api_key"], | |||
"args": ["api_key"], | |||
"desc": "设置你的OpenAI私有api_key", | |||
}, | |||
"reset_openai_api_key": { | |||
"alias": ["reset_openai_api_key"], | |||
"desc": "重置为默认的api_key", | |||
}, | |||
# "id": { | |||
# "alias": ["id", "用户"], | |||
# "desc": "获取用户id", #目前无实际意义 | |||
@@ -99,7 +108,7 @@ def get_help_text(isadmin, isgroup): | |||
alias=["#"+a for a in info['alias']] | |||
help_text += f"{','.join(alias)} " | |||
if 'args' in info: | |||
args=["{"+a+"}" for a in info['args']] | |||
args=["'"+a+"'" for a in info['args']] | |||
help_text += f"{' '.join(args)} " | |||
help_text += f": {info['desc']}\n" | |||
@@ -162,7 +171,7 @@ class Godcmd(Plugin): | |||
bottype = Bridge().get_bot_type("chat") | |||
bot = Bridge().get_bot("chat") | |||
# 将命令和参数分割 | |||
command_parts = content[1:].split(" ") | |||
command_parts = content[1:].strip().split(" ") | |||
cmd = command_parts[0] | |||
args = command_parts[1:] | |||
isadmin=False | |||
@@ -184,6 +193,22 @@ class Godcmd(Plugin): | |||
ok, result = True, PluginManager().instances[name].get_help_text(verbose=True) | |||
else: | |||
ok, result = False, "unknown args" | |||
elif cmd == "set_openai_api_key": | |||
if len(args) == 1: | |||
import redis | |||
R = redis.Redis(host='localhost', port=6379, db=0) | |||
user_openai_api_key = "openai_api_key_" + user | |||
R.set(user_openai_api_key, args[0]) | |||
# R.sadd("openai_api_key", args[0]) | |||
ok, result = True, "你的OpenAI私有api_key已设置为" + args[0] | |||
else: | |||
ok, result = False, "请提供一个api_key" | |||
elif cmd == "reset_openai_api_key": | |||
import redis | |||
R = redis.Redis(host='localhost', port=6379, db=0) | |||
user_openai_api_key = "openai_api_key_" + user | |||
R.delete(user_openai_api_key) | |||
ok, result = True, "OpenAI的api_key已重置" | |||
# elif cmd == "helpp": | |||
# if len(args) != 1: | |||
# ok, result = False, "请提供插件名" | |||