Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

google_gemini_bot.py 2.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. """
  2. Google gemini bot
  3. @author zhayujie
  4. @Date 2023/12/15
  5. """
  6. # encoding:utf-8
  7. from bot.bot import Bot
  8. import google.generativeai as genai
  9. from bot.session_manager import SessionManager
  10. from bridge.context import ContextType, Context
  11. from bridge.reply import Reply, ReplyType
  12. from common.log import logger
  13. from config import conf
  14. from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
  15. # Google Gemini chat model API (available)
  class GoogleGeminiBot(Bot):
      """Bot that answers text queries through Google's ``gemini-pro`` model.

      Conversation history is kept per session id in a ``SessionManager``.
      Before each request the history is reduced to a strictly alternating
      user/assistant sequence and converted to the role/parts dictionaries
      passed to ``GenerativeModel.generate_content``.
      """

      def __init__(self):
          """Read the Gemini API key from the config and set up session storage."""
          super().__init__()
          self.api_key = conf().get("gemini_api_key")
          # Reuse the Wenxin session class for its token-counting behaviour.
          self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo")

      def reply(self, query, context: Context = None) -> Reply:
          """Generate a reply for ``query`` within the session named by ``context``.

          Args:
              query: The user's text message.
              context: Request context; its ``type`` must be ``ContextType.TEXT``
                  and it must provide a ``"session_id"`` entry.

          Returns:
              ``Reply(ReplyType.TEXT, <model text>)`` on success,
              ``Reply(ReplyType.TEXT, None)`` when the context type is not text,
              and ``Reply(ReplyType.ERROR, ...)`` when anything raises (including
              a missing ``context``).
          """
          try:
              if context.type != ContextType.TEXT:
                  # ``warning`` replaces the deprecated ``warn`` alias.
                  logger.warning(f"[Gemini] Unsupported message type, type={context.type}")
                  return Reply(ReplyType.TEXT, None)

              logger.info(f"[Gemini] query={query}")
              session_id = context["session_id"]
              session = self.sessions.session_query(query, session_id)

              # Drop messages that break user/assistant alternation, then map
              # them to Gemini's role/parts structure.
              alternating = self._filter_messages(session.messages)
              gemini_messages = self._convert_to_gemini_messages(alternating)

              genai.configure(api_key=self.api_key)
              model = genai.GenerativeModel('gemini-pro')
              response = model.generate_content(gemini_messages)
              reply_text = response.text

              self.sessions.session_reply(reply_text, session_id)
              logger.info(f"[Gemini] reply={reply_text}")
              return Reply(ReplyType.TEXT, reply_text)
          except Exception as e:
              # Boundary handler: any failure is logged and reported to the
              # caller as an error reply instead of propagating.
              logger.error("[Gemini] fetch reply error, may contain unsafe content")
              logger.error(e)
              return Reply(ReplyType.ERROR, "invoke [Gemini] api failed!")

      def _convert_to_gemini_messages(self, messages: list):
          """Translate session messages into Gemini role/parts dictionaries.

          ``"user"`` stays ``"user"``, ``"assistant"`` becomes ``"model"``;
          messages with any other role are skipped.

          Args:
              messages: Dicts with ``"role"`` and ``"content"`` keys.

          Returns:
              A list of ``{"role": ..., "parts": [{"text": ...}]}`` dicts in the
              same order as the input.
          """
          role_map = {"user": "user", "assistant": "model"}
          converted = []
          for msg in messages:
              role = role_map.get(msg.get("role"))
              if role is None:
                  continue
              converted.append({
                  "role": role,
                  "parts": [{"text": msg.get("content")}],
              })
          return converted

      def _filter_messages(self, messages: list):
          """Return the messages that form an alternating user/assistant chain.

          The history is scanned from newest to oldest. A message is kept only
          if its role matches the currently expected turn, which starts as
          ``"user"`` and flips after every kept message. The result is in
          chronological order and, when non-empty, ends with a user message.

          Args:
              messages: Dicts with a ``"role"`` key, oldest first.

          Returns:
              A new list containing the kept message dicts (not copies).
          """
          if not messages:
              return []

          kept = []
          expected = "user"
          for message in reversed(messages):
              if message.get("role") != expected:
                  continue
              kept.append(message)
              expected = "assistant" if expected == "user" else "user"

          # Collected newest-first; restore chronological order in one pass
          # instead of inserting at the front on every iteration.
          kept.reverse()
          return kept