@@ -47,4 +47,9 @@ def create_bot(bot_type): | |||||
elif bot_type == const.QWEN: | elif bot_type == const.QWEN: | ||||
from bot.ali.ali_qwen_bot import AliQwenBot | from bot.ali.ali_qwen_bot import AliQwenBot | ||||
return AliQwenBot() | return AliQwenBot() | ||||
elif bot_type == const.GEMINI: | |||||
from bot.gemini.google_gemini_bot import GoogleGeminiBot | |||||
return GoogleGeminiBot() | |||||
raise RuntimeError | raise RuntimeError |
@@ -57,7 +57,7 @@ class ChatGPTSession(Session): | |||||
def num_tokens_from_messages(messages, model): | def num_tokens_from_messages(messages, model): | ||||
"""Returns the number of tokens used by a list of messages.""" | """Returns the number of tokens used by a list of messages.""" | ||||
if model in ["wenxin", "xunfei"]: | |||||
if model in ["wenxin", "xunfei", const.GEMINI]: | |||||
return num_tokens_by_character(messages) | return num_tokens_by_character(messages) | ||||
import tiktoken | import tiktoken | ||||
@@ -0,0 +1,75 @@ | |||||
""" | |||||
Google gemini bot | |||||
@author zhayujie | |||||
@Date 2023/12/15 | |||||
""" | |||||
# encoding:utf-8 | |||||
from bot.bot import Bot | |||||
import google.generativeai as genai | |||||
from bot.session_manager import SessionManager | |||||
from bridge.context import ContextType, Context | |||||
from bridge.reply import Reply, ReplyType | |||||
from common.log import logger | |||||
from config import conf | |||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession | |||||
# OpenAI对话模型API (可用) | |||||
class GoogleGeminiBot(Bot): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.api_key = conf().get("gemini_api_key") | |||||
# 复用文心的token计算方式 | |||||
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo") | |||||
def reply(self, query, context: Context = None) -> Reply: | |||||
try: | |||||
if context.type != ContextType.TEXT: | |||||
logger.warn(f"[Gemini] Unsupported message type, type={context.type}") | |||||
return Reply(ReplyType.TEXT, None) | |||||
logger.info(f"[Gemini] query={query}") | |||||
session_id = context["session_id"] | |||||
session = self.sessions.session_query(query, session_id) | |||||
gemini_messages = self._convert_to_gemini_messages(self._filter_messages(session.messages)) | |||||
genai.configure(api_key=self.api_key) | |||||
model = genai.GenerativeModel('gemini-pro') | |||||
response = model.generate_content(gemini_messages) | |||||
reply_text = response.text | |||||
self.sessions.session_reply(reply_text, session_id) | |||||
logger.info(f"[Gemini] reply={reply_text}") | |||||
return Reply(ReplyType.TEXT, reply_text) | |||||
except Exception as e: | |||||
logger.error("[Gemini] fetch reply error, may contain unsafe content") | |||||
logger.error(e) | |||||
def _convert_to_gemini_messages(self, messages: list): | |||||
res = [] | |||||
for msg in messages: | |||||
if msg.get("role") == "user": | |||||
role = "user" | |||||
elif msg.get("role") == "assistant": | |||||
role = "model" | |||||
else: | |||||
continue | |||||
res.append({ | |||||
"role": role, | |||||
"parts": [{"text": msg.get("content")}] | |||||
}) | |||||
return res | |||||
def _filter_messages(self, messages: list): | |||||
res = [] | |||||
turn = "user" | |||||
for i in range(len(messages) - 1, -1, -1): | |||||
message = messages[i] | |||||
if message.get("role") != turn: | |||||
continue | |||||
res.insert(0, message) | |||||
if turn == "user": | |||||
turn = "assistant" | |||||
elif turn == "assistant": | |||||
turn = "user" | |||||
return res |
@@ -29,12 +29,16 @@ class Bridge(object): | |||||
self.btype["chat"] = const.XUNFEI | self.btype["chat"] = const.XUNFEI | ||||
if model_type in [const.QWEN]: | if model_type in [const.QWEN]: | ||||
self.btype["chat"] = const.QWEN | self.btype["chat"] = const.QWEN | ||||
if model_type in [const.GEMINI]: | |||||
self.btype["chat"] = const.GEMINI | |||||
if conf().get("use_linkai") and conf().get("linkai_api_key"): | if conf().get("use_linkai") and conf().get("linkai_api_key"): | ||||
self.btype["chat"] = const.LINKAI | self.btype["chat"] = const.LINKAI | ||||
if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]: | if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]: | ||||
self.btype["voice_to_text"] = const.LINKAI | self.btype["voice_to_text"] = const.LINKAI | ||||
if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]: | if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]: | ||||
self.btype["text_to_voice"] = const.LINKAI | self.btype["text_to_voice"] = const.LINKAI | ||||
if model_type in ["claude"]: | if model_type in ["claude"]: | ||||
self.btype["chat"] = const.CLAUDEAI | self.btype["chat"] = const.CLAUDEAI | ||||
self.bots = {} | self.bots = {} | ||||
@@ -7,6 +7,7 @@ CHATGPTONAZURE = "chatGPTOnAzure" | |||||
LINKAI = "linkai" | LINKAI = "linkai" | ||||
CLAUDEAI = "claude" | CLAUDEAI = "claude" | ||||
QWEN = "qwen" | QWEN = "qwen" | ||||
GEMINI = "gemini" | |||||
# model | # model | ||||
GPT35 = "gpt-3.5-turbo" | GPT35 = "gpt-3.5-turbo" | ||||
@@ -17,7 +18,7 @@ WHISPER_1 = "whisper-1" | |||||
TTS_1 = "tts-1" | TTS_1 = "tts-1" | ||||
TTS_1_HD = "tts-1-hd" | TTS_1_HD = "tts-1-hd" | ||||
MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo", GPT4_TURBO_PREVIEW, QWEN] | |||||
MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo", GPT4_TURBO_PREVIEW, QWEN, GEMINI] | |||||
# channel | # channel | ||||
FEISHU = "feishu" | FEISHU = "feishu" |
@@ -73,6 +73,8 @@ available_setting = { | |||||
"qwen_agent_key": "", | "qwen_agent_key": "", | ||||
"qwen_app_id": "", | "qwen_app_id": "", | ||||
"qwen_node_id": "", # 流程编排模型用到的id,如果没有用到qwen_node_id,请务必保持为空字符串 | "qwen_node_id": "", # 流程编排模型用到的id,如果没有用到qwen_node_id,请务必保持为空字符串 | ||||
# Google Gemini Api Key | |||||
"gemini_api_key": "", | |||||
# wework的通用配置 | # wework的通用配置 | ||||
"wework_smart": True, # 配置wework是否使用已登录的企业微信,False为多开 | "wework_smart": True, # 配置wework是否使用已登录的企业微信,False为多开 | ||||
# 语音设置 | # 语音设置 | ||||
@@ -313,7 +313,7 @@ class Godcmd(Plugin): | |||||
except Exception as e: | except Exception as e: | ||||
ok, result = False, "你没有设置私有GPT模型" | ok, result = False, "你没有设置私有GPT模型" | ||||
elif cmd == "reset": | elif cmd == "reset": | ||||
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI, const.QWEN]: | |||||
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI]: | |||||
bot.sessions.clear_session(session_id) | bot.sessions.clear_session(session_id) | ||||
if Bridge().chat_bots.get(bottype): | if Bridge().chat_bots.get(bottype): | ||||
Bridge().chat_bots.get(bottype).sessions.clear_session(session_id) | Bridge().chat_bots.get(bottype).sessions.clear_session(session_id) | ||||
@@ -339,7 +339,7 @@ class Godcmd(Plugin): | |||||
ok, result = True, "配置已重载" | ok, result = True, "配置已重载" | ||||
elif cmd == "resetall": | elif cmd == "resetall": | ||||
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, | if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, | ||||
const.BAIDU, const.XUNFEI, const.QWEN]: | |||||
const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI]: | |||||
channel.cancel_all_session() | channel.cancel_all_session() | ||||
bot.sessions.clear_all_session() | bot.sessions.clear_all_session() | ||||
ok, result = True, "重置所有会话成功" | ok, result = True, "重置所有会话成功" | ||||