|
|
@@ -26,21 +26,24 @@ class GoogleGeminiBot(Bot): |
|
|
|
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo") |
|
|
|
|
|
|
|
def reply(self, query, context: Context = None) -> Reply: |
|
|
|
if context.type != ContextType.TEXT: |
|
|
|
logger.warn(f"[Gemini] Unsupported message type, type={context.type}") |
|
|
|
return Reply(ReplyType.TEXT, None) |
|
|
|
logger.info(f"[Gemini] query={query}") |
|
|
|
session_id = context["session_id"] |
|
|
|
session = self.sessions.session_query(query, session_id) |
|
|
|
gemini_messages = self._convert_to_gemini_messages(session.messages) |
|
|
|
genai.configure(api_key=self.api_key) |
|
|
|
model = genai.GenerativeModel('gemini-pro') |
|
|
|
response = model.generate_content(gemini_messages) |
|
|
|
reply_text = response.text |
|
|
|
self.sessions.session_reply(reply_text, session_id) |
|
|
|
logger.info(f"[Gemini] reply={reply_text}") |
|
|
|
return Reply(ReplyType.TEXT, reply_text) |
|
|
|
|
|
|
|
try: |
|
|
|
if context.type != ContextType.TEXT: |
|
|
|
logger.warn(f"[Gemini] Unsupported message type, type={context.type}") |
|
|
|
return Reply(ReplyType.TEXT, None) |
|
|
|
logger.info(f"[Gemini] query={query}") |
|
|
|
session_id = context["session_id"] |
|
|
|
session = self.sessions.session_query(query, session_id) |
|
|
|
gemini_messages = self._convert_to_gemini_messages(self._filter_messages(session.messages)) |
|
|
|
genai.configure(api_key=self.api_key) |
|
|
|
model = genai.GenerativeModel('gemini-pro') |
|
|
|
response = model.generate_content(gemini_messages) |
|
|
|
reply_text = response.text |
|
|
|
self.sessions.session_reply(reply_text, session_id) |
|
|
|
logger.info(f"[Gemini] reply={reply_text}") |
|
|
|
return Reply(ReplyType.TEXT, reply_text) |
|
|
|
except Exception as e: |
|
|
|
logger.error("[Gemini] fetch reply error, may contain unsafe content") |
|
|
|
logger.error(e) |
|
|
|
|
|
|
|
def _convert_to_gemini_messages(self, messages: list): |
|
|
|
res = [] |
|
|
@@ -56,3 +59,17 @@ class GoogleGeminiBot(Bot): |
|
|
|
"parts": [{"text": msg.get("content")}] |
|
|
|
}) |
|
|
|
return res |
|
|
|
|
|
|
|
def _filter_messages(self, messages: list): |
|
|
|
res = [] |
|
|
|
turn = "user" |
|
|
|
for i in range(len(messages) - 1, -1, -1): |
|
|
|
message = messages[i] |
|
|
|
if message.get("role") != turn: |
|
|
|
continue |
|
|
|
res.insert(0, message) |
|
|
|
if turn == "user": |
|
|
|
turn = "assistant" |
|
|
|
elif turn == "assistant": |
|
|
|
turn = "user" |
|
|
|
return res |