You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

google_gemini_bot.py 2.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. """
  2. Google gemini bot
  3. @author zhayujie
  4. @Date 2023/12/15
  5. """
  6. # encoding:utf-8
  7. from bot.bot import Bot
  8. import google.generativeai as genai
  9. from bot.session_manager import SessionManager
  10. from bridge.context import ContextType, Context
  11. from bridge.reply import Reply, ReplyType
  12. from common.log import logger
  13. from config import conf
  14. from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
  # Google Gemini chat-model bot (usable)
class GoogleGeminiBot(Bot):
    """Chat bot that answers text messages with Google's ``gemini-pro`` model.

    Conversation history is tracked per ``session_id`` by a ``SessionManager``
    built on ``BaiduWenxinSession`` (reusing Wenxin's token accounting).
    """

    def __init__(self):
        super().__init__()
        self.api_key = conf().get("gemini_api_key")
        # Reuse Wenxin's token-counting method for session trimming.
        self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo")

    def reply(self, query, context: Context = None) -> Reply:
        """Generate a Gemini reply for ``query``.

        :param query: the user's text message.
        :param context: message context; must be of type ``ContextType.TEXT``
            and provide a ``"session_id"`` entry.
        :return: a TEXT reply with the model's answer on success; a TEXT reply
            with ``None`` content for non-text messages; an ERROR reply when
            any step of the request fails (previously this returned ``None``).
        """
        try:
            if context.type != ContextType.TEXT:
                logger.warning(f"[Gemini] Unsupported message type, type={context.type}")
                return Reply(ReplyType.TEXT, None)
            logger.info(f"[Gemini] query={query}")
            session_id = context["session_id"]
            session = self.sessions.session_query(query, session_id)
            gemini_messages = self._convert_to_gemini_messages(self._filter_messages(session.messages))
            genai.configure(api_key=self.api_key)
            model = genai.GenerativeModel('gemini-pro')
            response = model.generate_content(gemini_messages)
            # NOTE(review): the error log below assumes failures here are often
            # caused by blocked/unsafe content making `response.text` raise — confirm.
            reply_text = response.text
            self.sessions.session_reply(reply_text, session_id)
            logger.info(f"[Gemini] reply={reply_text}")
            return Reply(ReplyType.TEXT, reply_text)
        except Exception as e:
            logger.error("[Gemini] fetch reply error, may contain unsafe content")
            logger.error(e)
            # Always hand the caller a Reply instead of implicitly returning None.
            return Reply(ReplyType.ERROR, "invoke [Gemini] api failed!")

    def _convert_to_gemini_messages(self, messages: list):
        """Convert session messages to Gemini ``contents`` entries.

        ``"user"`` keeps its role, ``"assistant"`` becomes ``"model"``; messages
        with any other role (e.g. ``"system"``) are dropped.

        :param messages: dicts with ``"role"`` and ``"content"`` keys.
        :return: list of ``{"role": ..., "parts": [{"text": ...}]}`` dicts.
        """
        role_map = {"user": "user", "assistant": "model"}
        return [
            {"role": role_map[msg.get("role")], "parts": [{"text": msg.get("content")}]}
            for msg in messages
            if msg.get("role") in role_map
        ]

    def _filter_messages(self, messages: list):
        """Keep a strictly alternating user/assistant history ending on a user turn.

        Walks the history from newest to oldest, keeping a message only when its
        role is the one expected next (starting with ``"user"``), so repeated
        roles and other roles are skipped.  Presumably this guards against e.g.
        two consecutive user messages left behind by a failed reply — verify.

        :param messages: dicts with a ``"role"`` key, oldest first.
        :return: the kept messages, oldest first.
        """
        kept = []
        expected = "user"
        for message in reversed(messages):
            if message.get("role") != expected:
                continue
            kept.append(message)
            expected = "assistant" if expected == "user" else "user"
        # Collected newest-first; restore chronological order in one O(n) pass.
        kept.reverse()
        return kept