You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

59 lines
1.9KB

  1. """
  2. Google gemini bot
  3. @author zhayujie
  4. @Date 2023/12/15
  5. """
  6. # encoding:utf-8
  7. from bot.bot import Bot
  8. import google.generativeai as genai
  9. from bot.session_manager import SessionManager
  10. from bridge.context import ContextType, Context
  11. from bridge.reply import Reply, ReplyType
  12. from common.log import logger
  13. from config import conf
  14. from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
  15. # Google Gemini chat model API (available)
  16. class GoogleGeminiBot(Bot):
  17. def __init__(self):
  18. super().__init__()
  19. self.api_key = conf().get("gemini_api_key")
  20. # 复用文心的token计算方式
  21. self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo")
  22. def reply(self, query, context: Context = None) -> Reply:
  23. if context.type != ContextType.TEXT:
  24. logger.warn(f"[Gemini] Unsupported message type, type={context.type}")
  25. return Reply(ReplyType.TEXT, None)
  26. logger.info(f"[Gemini] query={query}")
  27. session_id = context["session_id"]
  28. session = self.sessions.session_query(query, session_id)
  29. gemini_messages = self._convert_to_gemini_messages(session.messages)
  30. genai.configure(api_key=self.api_key)
  31. model = genai.GenerativeModel('gemini-pro')
  32. response = model.generate_content(gemini_messages)
  33. reply_text = response.text
  34. self.sessions.session_reply(reply_text, session_id)
  35. logger.info(f"[Gemini] reply={reply_text}")
  36. return Reply(ReplyType.TEXT, reply_text)
  37. def _convert_to_gemini_messages(self, messages: list):
  38. res = []
  39. for msg in messages:
  40. if msg.get("role") == "user":
  41. role = "user"
  42. elif msg.get("role") == "assistant":
  43. role = "model"
  44. else:
  45. continue
  46. res.append({
  47. "role": role,
  48. "parts": [{"text": msg.get("content")}]
  49. })
  50. return res