You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

google_gemini_bot.py 2.8KB

8 kuukautta sitten
11 kuukautta sitten
8 kuukautta sitten
11 kuukautta sitten
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. """
  2. Google gemini bot
  3. @author zhayujie
  4. @Date 2023/12/15
  5. """
  6. # encoding:utf-8
  7. from bot.bot import Bot
  8. import google.generativeai as genai
  9. from bot.session_manager import SessionManager
  10. from bridge.context import ContextType, Context
  11. from bridge.reply import Reply, ReplyType
  12. from common.log import logger
  13. from config import conf
  14. from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
  # Chat bot backed by the Google Gemini API (google.generativeai).
class GoogleGeminiBot(Bot):
    """Bot that answers text messages with a Google Gemini model.

    Conversation history is kept in a ``SessionManager``; on every call it is
    reduced to an alternating user/assistant sequence and converted to the
    Gemini ``contents`` format (roles ``"user"`` / ``"model"``).
    """

    def __init__(self):
        super().__init__()
        self.api_key = conf().get("gemini_api_key")
        # Reuse the Baidu Wenxin (ERNIE) session class for its token-counting scheme.
        # NOTE(review): the session falls back to "gpt-3.5-turbo" while self.model
        # falls back to "gemini-pro"; presumably the session just needs a model name
        # its token counter understands — confirm before unifying the two defaults.
        self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo")
        self.model = conf().get("model") or "gemini-pro"

    def reply(self, query, context: Context = None) -> Reply:
        """Send ``query`` together with the session history to Gemini.

        :param query: the user's message text; it is appended to the session
            before the request is built.
        :param context: message context; must be a TEXT context and provide
            ``context["session_id"]``.
        :return: ``Reply(TEXT, answer)`` on success; ``Reply(TEXT, None)`` for
            non-text contexts; ``Reply(ERROR, "invoke [Gemini] api failed!")``
            if anything raises (including a missing context).
        """
        try:
            if context.type != ContextType.TEXT:
                logger.warning(f"[Gemini] Unsupported message type, type={context.type}")
                return Reply(ReplyType.TEXT, None)
            logger.info(f"[Gemini] query={query}")
            session_id = context["session_id"]
            session = self.sessions.session_query(query, session_id)
            gemini_messages = self._convert_to_gemini_messages(self.filter_messages(session.messages))
            # Configured on every call, so the key in effect is always this bot's.
            genai.configure(api_key=self.api_key)
            model = genai.GenerativeModel(self.model)
            response = model.generate_content(gemini_messages)
            # Presumably raises when the response has no text part (e.g. it was
            # blocked by safety filters) — that case falls through to the handler below.
            reply_text = response.text
            self.sessions.session_reply(reply_text, session_id)
            logger.info(f"[Gemini] reply={reply_text}")
            return Reply(ReplyType.TEXT, reply_text)
        except Exception:
            # Top-level boundary: an API failure must not propagate to the caller.
            # logger.exception records the traceback, so unrelated failures
            # (e.g. context=None) are not hidden behind the "unsafe content" text.
            logger.exception("[Gemini] fetch reply error, may contain unsafe content")
            return Reply(ReplyType.ERROR, "invoke [Gemini] api failed!")

    def _convert_to_gemini_messages(self, messages: list):
        """Convert role/content message dicts into Gemini ``contents`` entries.

        "user" stays "user", "assistant" becomes "model"; any other role
        (e.g. "system") is skipped.

        :param messages: list of dicts with "role" and "content" keys.
        :return: list of ``{"role": ..., "parts": [{"text": ...}]}`` dicts.
        """
        role_map = {"user": "user", "assistant": "model"}
        res = []
        for msg in messages:
            role = role_map.get(msg.get("role"))
            if role is None:
                continue
            res.append({"role": role, "parts": [{"text": msg.get("content")}]})
        return res

    @staticmethod
    def filter_messages(messages: list):
        """Select an alternating user/assistant subsequence ending at the latest user message.

        Walks the history from newest to oldest, keeping a message only when its
        role is the one expected next (user, assistant, user, ...). Out-of-turn
        messages — e.g. an earlier user message that never received a reply, or
        "system" entries — are dropped. Kept messages retain chronological order.

        NOTE(review): the result can start with an "assistant" message if the
        history does; confirm whether the Gemini API accepts that.

        :param messages: session messages (dicts with a "role" key); may be empty or None.
        :return: the filtered list, oldest first.
        """
        res = []
        turn = "user"
        for message in reversed(messages or []):
            if message.get("role") != turn:
                continue
            res.append(message)
            turn = "assistant" if turn == "user" else "user"
        # Collected newest-first; restore chronological order in one O(n) pass.
        res.reverse()
        return res