diff --git a/bot/gemini/google_gemini_bot.py b/bot/gemini/google_gemini_bot.py index 4cc0dd3..1a49d60 100644 --- a/bot/gemini/google_gemini_bot.py +++ b/bot/gemini/google_gemini_bot.py @@ -26,21 +26,24 @@ class GoogleGeminiBot(Bot): self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo") def reply(self, query, context: Context = None) -> Reply: - if context.type != ContextType.TEXT: - logger.warn(f"[Gemini] Unsupported message type, type={context.type}") - return Reply(ReplyType.TEXT, None) - logger.info(f"[Gemini] query={query}") - session_id = context["session_id"] - session = self.sessions.session_query(query, session_id) - gemini_messages = self._convert_to_gemini_messages(session.messages) - genai.configure(api_key=self.api_key) - model = genai.GenerativeModel('gemini-pro') - response = model.generate_content(gemini_messages) - reply_text = response.text - self.sessions.session_reply(reply_text, session_id) - logger.info(f"[Gemini] reply={reply_text}") - return Reply(ReplyType.TEXT, reply_text) - + try: + if context.type != ContextType.TEXT: + logger.warn(f"[Gemini] Unsupported message type, type={context.type}") + return Reply(ReplyType.TEXT, None) + logger.info(f"[Gemini] query={query}") + session_id = context["session_id"] + session = self.sessions.session_query(query, session_id) + gemini_messages = self._convert_to_gemini_messages(self._filter_messages(session.messages)) + genai.configure(api_key=self.api_key) + model = genai.GenerativeModel('gemini-pro') + response = model.generate_content(gemini_messages) + reply_text = response.text + self.sessions.session_reply(reply_text, session_id) + logger.info(f"[Gemini] reply={reply_text}") + return Reply(ReplyType.TEXT, reply_text) + except Exception as e: + logger.error("[Gemini] fetch reply error, may contain unsafe content") + logger.error(e) def _convert_to_gemini_messages(self, messages: list): res = [] @@ -56,3 +59,17 @@ class GoogleGeminiBot(Bot): "parts": [{"text": msg.get("content")}] }) return res + + def _filter_messages(self, messages: list): + res = [] + turn = "user" + for i in range(len(messages) - 1, -1, -1): + message = messages[i] + if message.get("role") != turn: + continue + res.insert(0, message) + if turn == "user": + turn = "assistant" + elif turn == "assistant": + turn = "user" + return res