diff --git a/bot/gemini/google_gemini_bot.py b/bot/gemini/google_gemini_bot.py index 6132b78..6305481 100644 --- a/bot/gemini/google_gemini_bot.py +++ b/bot/gemini/google_gemini_bot.py @@ -24,7 +24,7 @@ class GoogleGeminiBot(Bot): self.api_key = conf().get("gemini_api_key") # 复用文心的token计算方式 self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo") - + self.model = conf().get("model") or "gemini-pro" def reply(self, query, context: Context = None) -> Reply: try: if context.type != ContextType.TEXT: @@ -35,7 +35,7 @@ class GoogleGeminiBot(Bot): session = self.sessions.session_query(query, session_id) gemini_messages = self._convert_to_gemini_messages(self.filter_messages(session.messages)) genai.configure(api_key=self.api_key) - model = genai.GenerativeModel('gemini-pro') + model = genai.GenerativeModel(self.model) response = model.generate_content(gemini_messages) reply_text = response.text self.sessions.session_reply(reply_text, session_id) diff --git a/bridge/bridge.py b/bridge/bridge.py index 2432926..b7b3ebf 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -36,7 +36,7 @@ class Bridge(object): self.btype["chat"] = const.QWEN if model_type in [const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]: self.btype["chat"] = const.QWEN_DASHSCOPE - if model_type in [const.GEMINI]: + if model_type and model_type.startswith("gemini"): self.btype["chat"] = const.GEMINI if model_type in [const.ZHIPU_AI]: self.btype["chat"] = const.ZHIPU_AI