|
@@ -1,6 +1,8 @@ |
|
|
# encoding:utf-8 |
|
|
# encoding:utf-8 |
|
|
|
|
|
|
|
|
from bot.bot import Bot |
|
|
from bot.bot import Bot |
|
|
|
|
|
from bridge.context import ContextType |
|
|
|
|
|
from bridge.reply import Reply, ReplyType |
|
|
from config import conf |
|
|
from config import conf |
|
|
from common.log import logger |
|
|
from common.log import logger |
|
|
import openai |
|
|
import openai |
|
@@ -13,30 +15,31 @@ class OpenAIBot(Bot): |
|
|
def __init__(self): |
|
|
def __init__(self): |
|
|
openai.api_key = conf().get('open_ai_api_key') |
|
|
openai.api_key = conf().get('open_ai_api_key') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reply(self, query, context=None): |
|
|
def reply(self, query, context=None): |
|
|
# acquire reply content |
|
|
# acquire reply content |
|
|
if not context or not context.get('type') or context.get('type') == 'TEXT': |
|
|
|
|
|
logger.info("[OPEN_AI] query={}".format(query)) |
|
|
|
|
|
from_user_id = context.get('from_user_id') or context.get('session_id') |
|
|
|
|
|
if query == '#清除记忆': |
|
|
|
|
|
Session.clear_session(from_user_id) |
|
|
|
|
|
return '记忆已清除' |
|
|
|
|
|
elif query == '#清除所有': |
|
|
|
|
|
Session.clear_all_session() |
|
|
|
|
|
return '所有人记忆已清除' |
|
|
|
|
|
|
|
|
|
|
|
new_query = Session.build_session_query(query, from_user_id) |
|
|
|
|
|
logger.debug("[OPEN_AI] session query={}".format(new_query)) |
|
|
|
|
|
|
|
|
|
|
|
reply_content = self.reply_text(new_query, from_user_id, 0) |
|
|
|
|
|
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content)) |
|
|
|
|
|
if reply_content and query: |
|
|
|
|
|
Session.save_session(query, reply_content, from_user_id) |
|
|
|
|
|
return reply_content |
|
|
|
|
|
|
|
|
|
|
|
elif context.get('type', None) == 'IMAGE_CREATE': |
|
|
|
|
|
return self.create_img(query, 0) |
|
|
|
|
|
|
|
|
if context and context.type: |
|
|
|
|
|
if context.type == ContextType.TEXT: |
|
|
|
|
|
logger.info("[OPEN_AI] query={}".format(query)) |
|
|
|
|
|
from_user_id = context['session_id'] |
|
|
|
|
|
reply = None |
|
|
|
|
|
if query == '#清除记忆': |
|
|
|
|
|
Session.clear_session(from_user_id) |
|
|
|
|
|
reply = Reply(ReplyType.INFO, '记忆已清除') |
|
|
|
|
|
elif query == '#清除所有': |
|
|
|
|
|
Session.clear_all_session() |
|
|
|
|
|
reply = Reply(ReplyType.INFO, '所有人记忆已清除') |
|
|
|
|
|
else: |
|
|
|
|
|
new_query = Session.build_session_query(query, from_user_id) |
|
|
|
|
|
logger.debug("[OPEN_AI] session query={}".format(new_query)) |
|
|
|
|
|
|
|
|
|
|
|
reply_content = self.reply_text(new_query, from_user_id, 0) |
|
|
|
|
|
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content)) |
|
|
|
|
|
if reply_content and query: |
|
|
|
|
|
Session.save_session(query, reply_content, from_user_id) |
|
|
|
|
|
reply = Reply(ReplyType.TEXT, reply_content) |
|
|
|
|
|
return reply |
|
|
|
|
|
elif context.type == ContextType.IMAGE_CREATE: |
|
|
|
|
|
return self.create_img(query, 0) |
|
|
|
|
|
|
|
|
def reply_text(self, query, user_id, retry_count=0): |
|
|
def reply_text(self, query, user_id, retry_count=0): |
|
|
try: |
|
|
try: |
|
|