@@ -12,13 +12,15 @@ from common import const | |||||
class Bridge(object): | class Bridge(object): | ||||
def __init__(self): | def __init__(self): | ||||
self.btype={ | self.btype={ | ||||
"chat": "chatGPT", | |||||
"chat": const.CHATGPT, | |||||
"voice_to_text": "openai", | "voice_to_text": "openai", | ||||
"text_to_voice": "baidu" | "text_to_voice": "baidu" | ||||
} | } | ||||
model_type = conf().get("model") | |||||
if model_type in ["text-davinci-003"]: | |||||
self.btype['chat'] = const.OPEN_AI | |||||
self.bots={} | self.bots={} | ||||
def get_bot(self,typename): | def get_bot(self,typename): | ||||
if self.bots.get(typename) is None: | if self.bots.get(typename) is None: | ||||
logger.info("create bot {} for {}".format(self.btype[typename],typename)) | logger.info("create bot {} for {}".format(self.btype[typename],typename)) | ||||
@@ -35,13 +37,7 @@ class Bridge(object): | |||||
def fetch_reply_content(self, query, context : Context) -> Reply: | def fetch_reply_content(self, query, context : Context) -> Reply: | ||||
bot_type = const.CHATGPT | |||||
model_type = conf().get("model") | |||||
if model_type in ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k"]: | |||||
bot_type = const.CHATGPT | |||||
elif model_type in ["text-davinci-003"]: | |||||
bot_type = const.OPEN_AI | |||||
return bot_factory.create_bot(bot_type).reply(query, context) | |||||
return self.get_bot("chat").reply(query, context) | |||||
def fetch_voice_to_text(self, voiceFile) -> Reply: | def fetch_voice_to_text(self, voiceFile) -> Reply: | ||||
@@ -56,14 +56,15 @@ class WechatChannel(Channel): | |||||
# start message listener | # start message listener | ||||
itchat.run() | itchat.run() | ||||
# handle_* 系列函数处理收到的消息后构造context,然后调用handle函数处理context | |||||
# context是一个字典,包含了消息的所有信息,包括以下key | |||||
# handle_* 系列函数处理收到的消息后构造Context,然后传入handle函数中处理Context和发送回复 | |||||
# Context包含了消息的所有信息,包括以下属性 | |||||
# type 消息类型, 包括TEXT、VOICE、IMAGE_CREATE | # type 消息类型, 包括TEXT、VOICE、IMAGE_CREATE | ||||
# content 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令 | # content 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令 | ||||
# session_id: 会话id | |||||
# isgroup: 是否是群聊 | |||||
# msg: 原始消息对象 | |||||
# receiver: 需要回复的对象 | |||||
# kwargs 附加参数字典,包含以下的key: | |||||
# session_id: 会话id | |||||
# isgroup: 是否是群聊 | |||||
# receiver: 需要回复的对象 | |||||
# msg: itchat的原始消息对象 | |||||
def handle_voice(self, msg): | def handle_voice(self, msg): | ||||
if conf().get('speech_recognition') != True: | if conf().get('speech_recognition') != True: | ||||
@@ -194,12 +194,12 @@ class WechatyChannel(Channel): | |||||
try: | try: | ||||
if not query: | if not query: | ||||
return | return | ||||
context = dict() | |||||
context = Context(ContextType.TEXT, query) | |||||
context['session_id'] = reply_user_id | context['session_id'] = reply_user_id | ||||
reply_text = super().build_reply_content(query, context) | |||||
reply_text = super().build_reply_content(query, context).content | |||||
if reply_text: | if reply_text: | ||||
# 转换 mp3 文件为 silk 格式 | # 转换 mp3 文件为 silk 格式 | ||||
mp3_file = super().build_text_to_voice(reply_text) | |||||
mp3_file = super().build_text_to_voice(reply_text).content | |||||
silk_file = mp3_file.replace(".mp3", ".silk") | silk_file = mp3_file.replace(".mp3", ".silk") | ||||
# Load the MP3 file | # Load the MP3 file | ||||
audio = AudioSegment.from_file(mp3_file, format="mp3") | audio = AudioSegment.from_file(mp3_file, format="mp3") | ||||
@@ -16,8 +16,26 @@ class ExpiredDict(dict): | |||||
def __setitem__(self, key, value): | def __setitem__(self, key, value): | ||||
expiry_time = datetime.now() + timedelta(seconds=self.expires_in_seconds) | expiry_time = datetime.now() + timedelta(seconds=self.expires_in_seconds) | ||||
super().__setitem__(key, (value, expiry_time)) | super().__setitem__(key, (value, expiry_time)) | ||||
def get(self, key, default=None): | def get(self, key, default=None): | ||||
try: | try: | ||||
return self[key] | return self[key] | ||||
except KeyError: | except KeyError: | ||||
return default | |||||
return default | |||||
def __contains__(self, key): | |||||
try: | |||||
self[key] | |||||
return True | |||||
except KeyError: | |||||
return False | |||||
def keys(self): | |||||
keys=list(super().keys()) | |||||
return [key for key in keys if key in self] | |||||
def items(self): | |||||
return [(key, self[key]) for key in self.keys()] | |||||
def __iter__(self): | |||||
return self.keys().__iter__() |
@@ -31,7 +31,7 @@ | |||||
context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id} | context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id} | ||||
``` | ``` | ||||
2. 产生回复 | 2. 产生回复 | ||||
本过程用于处理消息。目前默认处理逻辑是根据`Context`的类型交付给对应的bot: | |||||
本过程用于处理消息。目前默认处理逻辑如下,它根据`Context`的类型交付给对应的bot。如果本过程未产生任何回复,则会跳过之后的处理阶段。 | |||||
```python | ```python | ||||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: | if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: | ||||
reply = super().build_reply_content(context.content, context) #文字跟画图交付给chatgpt | reply = super().build_reply_content(context.content, context) #文字跟画图交付给chatgpt | ||||
@@ -133,9 +133,9 @@ class Hello(Plugin): | |||||
- `EventAction.BREAK`: 事件结束,不再给下个插件处理,交付给默认的处理逻辑。 | - `EventAction.BREAK`: 事件结束,不再给下个插件处理,交付给默认的处理逻辑。 | ||||
- `EventAction.BREAK_PASS`: 事件结束,不再给下个插件处理,跳过默认的处理逻辑。 | - `EventAction.BREAK_PASS`: 事件结束,不再给下个插件处理,跳过默认的处理逻辑。 | ||||
以`Hello`插件为例,它处理`Context`类型为`TEXT`的消息。 | |||||
以`Hello`插件为例,它处理`Context`类型为`TEXT`的消息: | |||||
- 如果内容是`Hello`,直接将回复设置为`Hello+用户昵称`,并跳过之后的插件和默认逻辑。 | - 如果内容是`Hello`,直接将回复设置为`Hello+用户昵称`,并跳过之后的插件和默认逻辑。 | ||||
- 如果内容是`End`,它会将`Context`的类型更改为`IMAGE_CREATE`,并让事件继续,如果最终交付到默认逻辑,会调用默认的画图Bot。 | |||||
- 如果内容是`End`,它会将`Context`的类型更改为`IMAGE_CREATE`,并让事件继续,如果最终交付到默认逻辑,会调用默认的画图Bot来画画。 | |||||
```python | ```python | ||||
def on_handle_context(self, e_context: EventContext): | def on_handle_context(self, e_context: EventContext): | ||||
if e_context['context'].type != ContextType.TEXT: | if e_context['context'].type != ContextType.TEXT: | ||||
@@ -8,6 +8,7 @@ from config import conf | |||||
import plugins | import plugins | ||||
from plugins import * | from plugins import * | ||||
from common.log import logger | from common.log import logger | ||||
from common import const | |||||
# https://github.com/bupticybee/ChineseAiDungeonChatGPT | # https://github.com/bupticybee/ChineseAiDungeonChatGPT | ||||
class StoryTeller(): | class StoryTeller(): | ||||
@@ -51,7 +52,7 @@ class Dungeon(Plugin): | |||||
if e_context['context'].type != ContextType.TEXT: | if e_context['context'].type != ContextType.TEXT: | ||||
return | return | ||||
bottype = Bridge().get_bot_type("chat") | bottype = Bridge().get_bot_type("chat") | ||||
if bottype != "chatGPT": | |||||
if bottype != const.CHATGPT: | |||||
return | return | ||||
bot = Bridge().get_bot("chat") | bot = Bridge().get_bot("chat") | ||||
content = e_context['context'].content[:] | content = e_context['context'].content[:] | ||||
@@ -10,6 +10,7 @@ from bridge.reply import Reply, ReplyType | |||||
from config import load_config | from config import load_config | ||||
import plugins | import plugins | ||||
from plugins import * | from plugins import * | ||||
from common import const | |||||
from common.log import logger | from common.log import logger | ||||
# 定义指令集 | # 定义指令集 | ||||
@@ -163,7 +164,7 @@ class Godcmd(Plugin): | |||||
elif cmd == "id": | elif cmd == "id": | ||||
ok, result = True, f"用户id=\n{user}" | ok, result = True, f"用户id=\n{user}" | ||||
elif cmd == "reset": | elif cmd == "reset": | ||||
if bottype == "chatGPT": | |||||
if bottype == const.CHATGPT: | |||||
bot.sessions.clear_session(session_id) | bot.sessions.clear_session(session_id) | ||||
ok, result = True, "会话已重置" | ok, result = True, "会话已重置" | ||||
else: | else: | ||||
@@ -185,7 +186,7 @@ class Godcmd(Plugin): | |||||
load_config() | load_config() | ||||
ok, result = True, "配置已重载" | ok, result = True, "配置已重载" | ||||
elif cmd == "resetall": | elif cmd == "resetall": | ||||
if bottype == "chatGPT": | |||||
if bottype == const.CHATGPT: | |||||
bot.sessions.clear_all_session() | bot.sessions.clear_all_session() | ||||
ok, result = True, "重置所有会话成功" | ok, result = True, "重置所有会话成功" | ||||
else: | else: | ||||
@@ -5,6 +5,7 @@ import os | |||||
from bridge.bridge import Bridge | from bridge.bridge import Bridge | ||||
from bridge.context import ContextType | from bridge.context import ContextType | ||||
from bridge.reply import Reply, ReplyType | from bridge.reply import Reply, ReplyType | ||||
from common import const | |||||
import plugins | import plugins | ||||
from plugins import * | from plugins import * | ||||
from common.log import logger | from common.log import logger | ||||
@@ -73,7 +74,7 @@ class Role(Plugin): | |||||
if e_context['context'].type != ContextType.TEXT: | if e_context['context'].type != ContextType.TEXT: | ||||
return | return | ||||
bottype = Bridge().get_bot_type("chat") | bottype = Bridge().get_bot_type("chat") | ||||
if bottype != "chatGPT": | |||||
if bottype != const.CHATGPT: | |||||
return | return | ||||
bot = Bridge().get_bot("chat") | bot = Bridge().get_bot("chat") | ||||
content = e_context['context'].content[:] | content = e_context['context'].content[:] | ||||
@@ -119,7 +120,7 @@ class Role(Plugin): | |||||
e_context.action = EventAction.CONTINUE | e_context.action = EventAction.CONTINUE | ||||
def get_help_text(self): | def get_help_text(self): | ||||
help_text = "输入\"$角色 (角色名)\"或\"$role (角色名)\"为我设定角色吧,#reset 可以清除设定的角色。\n\n目前可用角色列表:\n" | |||||
help_text = "输入\"$角色 (角色名)\"或\"$role (角色名)\"为我设定角色吧,\"$停止扮演 \" 可以清除设定的角色。\n\n目前可用角色列表:\n" | |||||
for role in self.roles: | for role in self.roles: | ||||
help_text += f"[{role}]: {self.roles[role]['remark']}\n" | help_text += f"[{role}]: {self.roles[role]['remark']}\n" | ||||
return help_text | return help_text |