From 23a237074ed1bcd1097fb3fcb370687896060b00 Mon Sep 17 00:00:00 2001 From: zhayujie Date: Fri, 15 Dec 2023 10:19:48 +0800 Subject: [PATCH] feat: support gemini model --- bot/bot_factory.py | 5 +++ bot/chatgpt/chat_gpt_session.py | 2 +- bot/gemini/google_gemini_bot.py | 58 +++++++++++++++++++++++++++++++++ bridge/bridge.py | 4 +++ common/const.py | 3 +- config.py | 2 ++ plugins/godcmd/godcmd.py | 2 +- 7 files changed, 73 insertions(+), 3 deletions(-) create mode 100644 bot/gemini/google_gemini_bot.py diff --git a/bot/bot_factory.py b/bot/bot_factory.py index a0edde1..a103209 100644 --- a/bot/bot_factory.py +++ b/bot/bot_factory.py @@ -47,4 +47,9 @@ def create_bot(bot_type): elif bot_type == const.QWEN: from bot.tongyi.tongyi_qwen_bot import TongyiQwenBot return TongyiQwenBot() + + elif bot_type == const.GEMINI: + from bot.gemini.google_gemini_bot import GoogleGeminiBot + return GoogleGeminiBot() + raise RuntimeError diff --git a/bot/chatgpt/chat_gpt_session.py b/bot/chatgpt/chat_gpt_session.py index e7dabec..74914f2 100644 --- a/bot/chatgpt/chat_gpt_session.py +++ b/bot/chatgpt/chat_gpt_session.py @@ -57,7 +57,7 @@ class ChatGPTSession(Session): def num_tokens_from_messages(messages, model): """Returns the number of tokens used by a list of messages.""" - if model in ["wenxin", "xunfei"]: + if model in ["wenxin", "xunfei", const.GEMINI]: return num_tokens_by_character(messages) import tiktoken diff --git a/bot/gemini/google_gemini_bot.py b/bot/gemini/google_gemini_bot.py new file mode 100644 index 0000000..4cc0dd3 --- /dev/null +++ b/bot/gemini/google_gemini_bot.py @@ -0,0 +1,58 @@ +""" +Google gemini bot + +@author zhayujie +@Date 2023/12/15 +""" +# encoding:utf-8 + +from bot.bot import Bot +import google.generativeai as genai +from bot.session_manager import SessionManager +from bridge.context import ContextType, Context +from bridge.reply import Reply, ReplyType +from common.log import logger +from config import conf +from bot.baidu.baidu_wenxin_session import BaiduWenxinSession + + +# OpenAI对话模型API (可用) +class GoogleGeminiBot(Bot): + + def __init__(self): + super().__init__() + self.api_key = conf().get("gemini_api_key") + # 复用文心的token计算方式 + self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo") + + def reply(self, query, context: Context = None) -> Reply: + if context.type != ContextType.TEXT: + logger.warn(f"[Gemini] Unsupported message type, type={context.type}") + return Reply(ReplyType.TEXT, None) + logger.info(f"[Gemini] query={query}") + session_id = context["session_id"] + session = self.sessions.session_query(query, session_id) + gemini_messages = self._convert_to_gemini_messages(session.messages) + genai.configure(api_key=self.api_key) + model = genai.GenerativeModel('gemini-pro') + response = model.generate_content(gemini_messages) + reply_text = response.text + self.sessions.session_reply(reply_text, session_id) + logger.info(f"[Gemini] reply={reply_text}") + return Reply(ReplyType.TEXT, reply_text) + + + def _convert_to_gemini_messages(self, messages: list): + res = [] + for msg in messages: + if msg.get("role") == "user": + role = "user" + elif msg.get("role") == "assistant": + role = "model" + else: + continue + res.append({ + "role": role, + "parts": [{"text": msg.get("content")}] + }) + return res diff --git a/bridge/bridge.py b/bridge/bridge.py index 2b637c3..53ee878 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -29,12 +29,16 @@ class Bridge(object): self.btype["chat"] = const.XUNFEI if model_type in [const.QWEN]: self.btype["chat"] = const.QWEN + if model_type in [const.GEMINI]: + self.btype["chat"] = const.GEMINI + if conf().get("use_linkai") and conf().get("linkai_api_key"): self.btype["chat"] = const.LINKAI if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]: self.btype["voice_to_text"] = const.LINKAI if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]: self.btype["text_to_voice"] = const.LINKAI + if model_type in ["claude"]: self.btype["chat"] = const.CLAUDEAI self.bots = {} diff --git a/common/const.py b/common/const.py index fc74e64..b2d0df6 100644 --- a/common/const.py +++ b/common/const.py @@ -7,6 +7,7 @@ CHATGPTONAZURE = "chatGPTOnAzure" LINKAI = "linkai" CLAUDEAI = "claude" QWEN = "qwen" +GEMINI = "gemini" # model GPT35 = "gpt-3.5-turbo" @@ -17,7 +18,7 @@ WHISPER_1 = "whisper-1" TTS_1 = "tts-1" TTS_1_HD = "tts-1-hd" -MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo", GPT4_TURBO_PREVIEW, QWEN] +MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo", GPT4_TURBO_PREVIEW, QWEN, GEMINI] # channel FEISHU = "feishu" diff --git a/config.py b/config.py index 8300699..bc4d9f7 100644 --- a/config.py +++ b/config.py @@ -73,6 +73,8 @@ available_setting = { "qwen_agent_key": "", "qwen_app_id": "", "qwen_node_id": "", # 流程编排模型用到的id,如果没有用到qwen_node_id,请务必保持为空字符串 + # Google Gemini Api Key + "gemini_api_key": "", # wework的通用配置 "wework_smart": True, # 配置wework是否使用已登录的企业微信,False为多开 # 语音设置 diff --git a/plugins/godcmd/godcmd.py b/plugins/godcmd/godcmd.py index 03a96bd..dd301e6 100644 --- a/plugins/godcmd/godcmd.py +++ b/plugins/godcmd/godcmd.py @@ -313,7 +313,7 @@ class Godcmd(Plugin): except Exception as e: ok, result = False, "你没有设置私有GPT模型" elif cmd == "reset": - if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI]: + if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI, const.GEMINI]: bot.sessions.clear_session(session_id) if Bridge().chat_bots.get(bottype): Bridge().chat_bots.get(bottype).sessions.clear_session(session_id)