選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

80 行
2.7KB

  1. """
  2. Google gemini bot
  3. @author zhayujie
  4. @Date 2023/12/15
  5. """
  6. # encoding:utf-8
  7. from bot.bot import Bot
  8. import google.generativeai as genai
  9. from bot.session_manager import SessionManager
  10. from bridge.context import ContextType, Context
  11. from bridge.reply import Reply, ReplyType
  12. from common.log import logger
  13. from config import conf
  14. from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
# Google Gemini chat model bot (available)
  16. class GoogleGeminiBot(Bot):
  17. def __init__(self):
  18. super().__init__()
  19. self.api_key = conf().get("gemini_api_key")
  20. # 复用文心的token计算方式
  21. self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo")
  22. def reply(self, query, context: Context = None) -> Reply:
  23. try:
  24. if context.type != ContextType.TEXT:
  25. logger.warn(f"[Gemini] Unsupported message type, type={context.type}")
  26. return Reply(ReplyType.TEXT, None)
  27. logger.info(f"[Gemini] query={query}")
  28. session_id = context["session_id"]
  29. session = self.sessions.session_query(query, session_id)
  30. gemini_messages = self._convert_to_gemini_messages(self.filter_messages(session.messages))
  31. genai.configure(api_key=self.api_key)
  32. model = genai.GenerativeModel('gemini-pro')
  33. response = model.generate_content(gemini_messages)
  34. reply_text = response.text
  35. self.sessions.session_reply(reply_text, session_id)
  36. logger.info(f"[Gemini] reply={reply_text}")
  37. return Reply(ReplyType.TEXT, reply_text)
  38. except Exception as e:
  39. logger.error("[Gemini] fetch reply error, may contain unsafe content")
  40. logger.error(e)
  41. return Reply(ReplyType.ERROR, "invoke [Gemini] api failed!")
  42. def _convert_to_gemini_messages(self, messages: list):
  43. res = []
  44. for msg in messages:
  45. if msg.get("role") == "user":
  46. role = "user"
  47. elif msg.get("role") == "assistant":
  48. role = "model"
  49. else:
  50. continue
  51. res.append({
  52. "role": role,
  53. "parts": [{"text": msg.get("content")}]
  54. })
  55. return res
  56. @staticmethod
  57. def filter_messages(messages: list):
  58. res = []
  59. turn = "user"
  60. if not messages:
  61. return res
  62. for i in range(len(messages) - 1, -1, -1):
  63. message = messages[i]
  64. if message.get("role") != turn:
  65. continue
  66. res.insert(0, message)
  67. if turn == "user":
  68. turn = "assistant"
  69. elif turn == "assistant":
  70. turn = "user"
  71. return res