```python
# encoding:utf-8
"""
Google Gemini bot
@author zhayujie
@Date 2023/12/15
"""
from bot.bot import Bot
import google.generativeai as genai
from bot.session_manager import SessionManager
from bridge.context import ContextType, Context
from bridge.reply import Reply, ReplyType
from common.log import logger
from config import conf
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
```
```python
# Google Gemini chat model API
class GoogleGeminiBot(Bot):

    def __init__(self):
        super().__init__()
        self.api_key = conf().get("gemini_api_key")
        # Reuse the ERNIE Bot (Baidu Wenxin) session class for token counting
        self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo")
        self.model = conf().get("model") or "gemini-pro"
        if self.model == "gemini":
            # Expand the short alias to a concrete Gemini model id
            self.model = "gemini-pro"
```
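The model name comes from the bot configuration. If `model` is unset or empty, it falls back to `gemini-pro`, and the short alias `gemini` is expanded to `gemini-pro`. Any other value is passed to `genai.GenerativeModel` unchanged. The sketch below shows that rule in isolation. It assumes a plain dict in place of `conf()`, and the helper name `resolve_gemini_model` is hypothetical.

```python
def resolve_gemini_model(config: dict) -> str:
    """Hypothetical helper mirroring the model-name logic in GoogleGeminiBot.__init__."""
    model = config.get("model") or "gemini-pro"  # unset or empty -> default model
    if model == "gemini":  # short alias -> concrete model id
        model = "gemini-pro"
    return model  # anything else passes through as-is


print(resolve_gemini_model({}))                   # gemini-pro
print(resolve_gemini_model({"model": "gemini"}))  # gemini-pro
print(resolve_gemini_model({"model": "gemini-pro-vision"}))
```

Because there is no validation, a model name meant for a different bot (for example `gpt-3.5-turbo`) would be sent to Gemini verbatim.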
```python
    def reply(self, query, context: Context = None) -> Reply:
        try:
            if context.type != ContextType.TEXT:
                logger.warning(f"[Gemini] Unsupported message type, type={context.type}")
                return Reply(ReplyType.TEXT, None)
            logger.info(f"[Gemini] query={query}")
            session_id = context["session_id"]
            # Record the user's query in the session history
            session = self.sessions.session_query(query, session_id)
            # Keep an alternating user/assistant history, then convert it to Gemini's format
            gemini_messages = self._convert_to_gemini_messages(self.filter_messages(session.messages))
            genai.configure(api_key=self.api_key)
            model = genai.GenerativeModel(self.model)
            response = model.generate_content(gemini_messages)
            reply_text = response.text
            # Store the reply only after a successful call
            self.sessions.session_reply(reply_text, session_id)
            logger.info(f"[Gemini] reply={reply_text}")
            return Reply(ReplyType.TEXT, reply_text)
        except Exception as e:
            logger.error("[Gemini] fetch reply error, may contain unsafe content")
            logger.error(e)
            return Reply(ReplyType.ERROR, "invoke [Gemini] api failed!")
```
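The order of operations in `reply()` has a consequence. `session_query` records the user's message before the API is called, but `session_reply` runs only after a successful response. The sketch below replays that order without any network access.

Some assumptions apply:

- `MiniSession`, `mini_reply`, and `failing_generate` are hypothetical names.
- The session is assumed to store plain `{"role", "content"}` dicts appended in this order, which simplifies the real `SessionManager`.
- The Gemini call is replaced by an injected function.

```python
from typing import Callable


class MiniSession:
    """Hypothetical stand-in for the real session: a list of role/content dicts."""

    def __init__(self):
        self.messages = []

    def add_query(self, query: str) -> None:
        self.messages.append({"role": "user", "content": query})

    def add_reply(self, reply: str) -> None:
        self.messages.append({"role": "assistant", "content": reply})


def mini_reply(session: MiniSession, query: str, generate: Callable[[list], str]) -> str:
    """Simplified mirror of the control flow in GoogleGeminiBot.reply."""
    session.add_query(query)  # like session_query: recorded before the API call
    try:
        text = generate(session.messages)  # stands in for model.generate_content(...).text
    except Exception:
        return "invoke [Gemini] api failed!"  # the reply is never recorded
    session.add_reply(text)  # like session_reply: reached only on success
    return text


def failing_generate(messages: list) -> str:
    raise ValueError("blocked")  # simulates a rejected or failed request


session = MiniSession()
mini_reply(session, "hi", lambda msgs: "hello")
mini_reply(session, "tell me something unsafe", failing_generate)
mini_reply(session, "ok, something else", lambda msgs: "sure")
print([m["role"] for m in session.messages])
# ['user', 'assistant', 'user', 'user', 'assistant']
```

A failed call leaves two `user` entries back to back in the history. `filter_messages` tolerates this situation because it keeps only one message per turn.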
```python
    def _convert_to_gemini_messages(self, messages: list):
        res = []
        for msg in messages:
            if msg.get("role") == "user":
                role = "user"
            elif msg.get("role") == "assistant":
                # Gemini calls the assistant side of the conversation "model"
                role = "model"
            else:
                # Other roles (e.g. "system") are dropped
                continue
            res.append({
                "role": role,
                "parts": [{"text": msg.get("content")}]
            })
        return res
```
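The conversion maps the session's OpenAI-style roles onto the two roles that Gemini's `contents` accept, `user` and `model`, and wraps each message's text in a `parts` list. The dict-based sketch below performs the same mapping and runs without the project's imports. The names `ROLE_MAP` and `to_gemini_contents` are hypothetical.

```python
ROLE_MAP = {"user": "user", "assistant": "model"}  # hypothetical lookup table


def to_gemini_contents(messages: list) -> list:
    """Equivalent of _convert_to_gemini_messages: map roles, drop the rest."""
    return [
        {"role": ROLE_MAP[m.get("role")], "parts": [{"text": m.get("content")}]}
        for m in messages
        if m.get("role") in ROLE_MAP  # e.g. "system" entries are skipped
    ]


history = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello! How can I help?"},
    {"role": "user", "content": "What is Gemini?"},
]
for item in to_gemini_contents(history):
    print(item)
# {'role': 'user', 'parts': [{'text': 'Hi'}]}
# {'role': 'model', 'parts': [{'text': 'Hello! How can I help?'}]}
# {'role': 'user', 'parts': [{'text': 'What is Gemini?'}]}
```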
```python
    @staticmethod
    def filter_messages(messages: list):
        res = []
        turn = "user"
        if not messages:
            return res
        # Walk from the newest message backwards, keeping one message per turn
        for i in range(len(messages) - 1, -1, -1):
            message = messages[i]
            if message.get("role") != turn:
                continue
            res.insert(0, message)
            if turn == "user":
                turn = "assistant"
            elif turn == "assistant":
                turn = "user"
        return res
```
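`filter_messages` walks the history from newest to oldest. It keeps a message only when its role matches the expected turn. As a result, the output strictly alternates between `user` and `assistant` and ends with the most recent user message, if there is one. System prompts and out-of-turn duplicates are skipped, for example a user message whose reply was never stored.

The standalone sketch below restates the method with `reversed()`, which is equivalent to the original index loop, so its behavior can be checked on a sample history.

```python
def filter_messages(messages: list) -> list:
    """Standalone copy of GoogleGeminiBot.filter_messages."""
    res = []
    turn = "user"  # the newest kept message must be a user turn
    for message in reversed(messages):  # newest to oldest
        if message.get("role") != turn:
            continue  # skip system prompts and out-of-turn duplicates
        res.insert(0, message)  # prepend to keep chronological order
        turn = "assistant" if turn == "user" else "user"
    return res


history = [
    {"role": "system", "content": "sys"},
    {"role": "user", "content": "q1"},
    {"role": "assistant", "content": "a1"},
    {"role": "user", "content": "q2 (no reply stored)"},
    {"role": "user", "content": "q3"},
]
print([m["content"] for m in filter_messages(history)])
# ['q1', 'a1', 'q3']
```

The orphaned `q2` is dropped because, walking backwards from `q3`, the next expected role is `assistant`.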