diff --git a/README.md b/README.md index d3201dc..00ef3d3 100644 --- a/README.md +++ b/README.md @@ -96,12 +96,8 @@ pip3 install --upgrade openai # config.json文件内容示例 { "open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY - "open_ai_api_base": "https://api.openai.com/v1", # 自定义 OpenAI API 地址 + "model": "gpt-3.5-turbo", # 模型名称 "proxy": "127.0.0.1:7890", # 代理客户端的ip和端口 - "baidu_app_id": "", # 百度AI的App Id - "baidu_api_key": "", # 百度AI的API KEY - "baidu_secret_key": "", # 百度AI的Secret KEY - "wechaty_puppet_service_token":"", # wechaty服务token "single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 "single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 "group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复 @@ -109,7 +105,6 @@ pip3 install --upgrade openai "image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀 "conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数 "speech_recognition": false, # 是否开启语音识别 - "voice_reply_voice": false, # 是否开启语音回复 "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述 } ``` @@ -133,6 +128,7 @@ pip3 install --upgrade openai **4.其他配置** ++ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k` (其中gpt-4 api暂未开放) + `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351) + 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix ` + 关于OpenAI对话及图片接口的参数配置(内容自由度、回复字数限制、图片大小等),可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/completions) 文档直接在 [代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/bot/openai/open_ai_bot.py) `bot/openai/open_ai_bot.py` 中进行调整。 diff --git a/bot/bot_factory.py b/bot/bot_factory.py index dd590c7..a920524 100644 --- a/bot/bot_factory.py +++ b/bot/bot_factory.py @@ -1,6 +1,7 @@ """ channel factory """ +from common import const def create_bot(bot_type): @@ -9,17 +10,17 @@ def create_bot(bot_type): :param channel_type: channel type code :return: channel instance """ - if bot_type == 'baidu': + if bot_type == const.BAIDU: # Baidu Unit对话接口 from bot.baidu.baidu_unit_bot import BaiduUnitBot return BaiduUnitBot() - elif bot_type == 'chatGPT': + elif bot_type == const.CHATGPT: # ChatGPT 网页端web接口 from bot.chatgpt.chat_gpt_bot import ChatGPTBot return ChatGPTBot() - elif bot_type == 'openAI': + elif bot_type == const.OPEN_AI: # OpenAI 官方对话模型API from bot.openai.open_ai_bot import OpenAIBot return OpenAIBot() diff --git a/bot/chatgpt/chat_gpt_bot.py b/bot/chatgpt/chat_gpt_bot.py index 685514f..be2fa67 100644 --- a/bot/chatgpt/chat_gpt_bot.py +++ b/bot/chatgpt/chat_gpt_bot.py @@ -63,7 +63,7 @@ class ChatGPTBot(Bot): ''' try: response = openai.ChatCompletion.create( - model="gpt-3.5-turbo", # 对话模型的名称 + model= conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称 messages=session, temperature=0.9, # 值在[0,1]之间,越大表示回复越具有不确定性 #max_tokens=4096, # 回复最大的字符数 diff --git a/bot/openai/open_ai_bot.py b/bot/openai/open_ai_bot.py index 7822bae..5ed1aa2 100644 --- a/bot/openai/open_ai_bot.py +++ b/bot/openai/open_ai_bot.py @@ -45,7 +45,7 @@ class OpenAIBot(Bot): def reply_text(self, query, user_id, retry_count=0): try: response = openai.Completion.create( - model="text-davinci-003", # 对话模型的名称 + model= conf().get("model") or "text-davinci-003", # 对话模型的名称 prompt=query, temperature=0.9, # 值在[0,1]之间,越大表示回复越具有不确定性 max_tokens=1200, # 回复最大的字符数 diff --git a/bridge/bridge.py b/bridge/bridge.py index e739a7f..e44c3e4 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -1,5 +1,7 @@ from bot import bot_factory from voice import voice_factory +from config import conf +from common import const class Bridge(object): @@ -7,7 +9,13 @@ class Bridge(object): pass def fetch_reply_content(self, query, context): - return bot_factory.create_bot("chatGPT").reply(query, context) + bot_type = const.CHATGPT + model_type = conf().get("model") + if model_type in ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k"]: + bot_type = const.CHATGPT + elif model_type in ["text-davinci-003"]: + bot_type = const.OPEN_AI + return bot_factory.create_bot(bot_type).reply(query, context) def fetch_voice_to_text(self, voiceFile): return voice_factory.create_voice("openai").voiceToText(voiceFile) diff --git a/channel/channel_factory.py b/channel/channel_factory.py index bfeaacf..546f18a 100644 --- a/channel/channel_factory.py +++ b/channel/channel_factory.py @@ -14,4 +14,7 @@ def create_channel(channel_type): elif channel_type == 'wxy': from channel.wechat.wechaty_channel import WechatyChannel return WechatyChannel() + elif channel_type == 'terminal': + from channel.terminal.terminal_channel import TerminalChannel + return TerminalChannel() raise RuntimeError diff --git a/channel/terminal/terminal_channel.py b/channel/terminal/terminal_channel.py new file mode 100644 index 0000000..1c9a61d --- /dev/null +++ b/channel/terminal/terminal_channel.py @@ -0,0 +1,29 @@ +from channel.channel import Channel +import sys + +class TerminalChannel(Channel): + def startup(self): + context = {"from_user_id": "User"} + print("\nPlease input your question") + while True: + try: + prompt = self.get_input("User:\n") + except KeyboardInterrupt: + print("\nExiting...") + sys.exit() + + print("Bot:") + sys.stdout.flush() + for res in super().build_reply_content(prompt, context): + print(res, end="") + sys.stdout.flush() + print("\n") + + + def get_input(self, prompt): + """ + Multi-line input function + """ + print(prompt, end="") + line = input() + return line diff --git a/common/const.py b/common/const.py new file mode 100644 index 0000000..37f2dbd --- /dev/null +++ b/common/const.py @@ -0,0 +1,4 @@ +# bot_type +OPEN_AI = "openAI" +CHATGPT = "chatGPT" +BAIDU = "baidu" \ No newline at end of file diff --git a/config-template.json b/config-template.json index 88e0e34..154cb9e 100644 --- a/config-template.json +++ b/config-template.json @@ -1,5 +1,6 @@ { "open_ai_api_key": "YOUR API KEY", + "model": "gpt-3.5-turbo", "proxy": "", "single_chat_prefix": ["bot", "@bot"], "single_chat_reply_prefix": "[bot] ",