@@ -27,5 +27,5 @@
 ### Environment
 - Operating system (Mac/Windows/Linux):
 - Python version (run `python3 -V`):
 - pip version (required for dependency issues; run `pip3 -V`):

@@ -49,9 +49,9 @@ jobs:
           file: ./docker/Dockerfile.latest
           tags: ${{ steps.meta.outputs.tags }}
           labels: ${{ steps.meta.outputs.labels }}
       - uses: actions/delete-package-versions@v4
         with:
           package-name: 'chatgpt-on-wechat'
           package-type: 'container'
           min-versions-to-keep: 10

@@ -120,7 +120,7 @@ pip3 install azure-cognitiveservices-speech
 ```bash
 # example config.json contents
 {
   "open_ai_api_key": "YOUR API KEY",  # the OpenAI API key created above
   "model": "gpt-3.5-turbo",  # model name; when use_azure_chatgpt is true, this is the model deployment name on Azure
   "proxy": "127.0.0.1:7890",  # IP and port of the proxy client
@@ -128,7 +128,7 @@ pip3 install azure-cognitiveservices-speech
   "single_chat_reply_prefix": "[bot] ",  # prefix added to auto-replies in private chats, to tell the bot apart from a human
   "group_chat_prefix": ["@bot"],  # in group chats, a message containing one of these prefixes triggers a bot reply
   "group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"],  # list of group names where auto-reply is enabled
   "group_chat_in_one_session": ["ChatGPT测试群"],  # groups that share a single conversation context
   "image_create_prefix": ["画", "看", "找"],  # prefixes that trigger image generation
   "conversation_max_tokens": 1000,  # maximum number of characters kept as conversation memory
   "speech_recognition": false,  # whether to enable speech recognition
@@ -160,7 +160,7 @@ pip3 install azure-cognitiveservices-speech
 **4. Other settings**
 + `model`: model name; currently `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, and `gpt-4-32k` are supported (the gpt-4 API is not yet generally available)
 + `temperature`, `frequency_penalty`, `presence_penalty`: parameters of the Chat API; see the [official OpenAI documentation](https://platform.openai.com/docs/api-reference/chat) for details
 + `proxy`: since the `openai` API is currently unreachable from mainland China, the address of a proxy client must be configured; see [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351) for details
 + For image generation, an extra keyword prefix is required on top of the private-chat or group trigger conditions; this is configured via `image_create_prefix`
 + Parameters of the OpenAI chat and image APIs (creativity, reply length limit, image size, etc.) can be adjusted directly in the [code](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/bot/openai/open_ai_bot.py) at `bot/openai/open_ai_bot.py`; see the [chat API](https://beta.openai.com/docs/api-reference/completions) and [image API](https://beta.openai.com/docs/api-reference/images) documentation, and the sketch after this list.
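To make the last point concrete, here is a minimal sketch of the kind of adjustment it refers to, based on the `openai.Image.create` call that appears in `bot/openai/open_ai_image.py` later in this diff; the prompt text and the `512x512` size are illustrative values only (the repository default is `256x256`):

```python
import openai

from config import conf, load_config

load_config()  # config.json must be loaded before conf() is queried
openai.api_key = conf().get("open_ai_api_key")

# Same DALL·E endpoint used by OpenAIImage.create_img, with a larger output size.
# The endpoint accepts 256x256, 512x512 or 1024x1024.
response = openai.Image.create(
    prompt="a watercolor painting of a cat",  # image description
    n=1,  # number of images generated per request
    size="512x512",
)
print(response["data"][0]["url"])
```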
@@ -181,7 +181,7 @@ pip3 install azure-cognitiveservices-speech
 ```bash
 python3 app.py
 ```
 After a QR code is printed in the terminal, scan it with WeChat. Once "Start auto replying" is printed, the auto-reply program is running (note: the WeChat account used to log in must have completed real-name verification in WeChat Pay). After scanning, your account becomes the bot, and auto-replies can be triggered from the WeChat mobile client with the configured keywords (any friend sending you a message, or you sending a message to a friend); see [#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142).

 ### 2. Server deployment
@@ -189,7 +189,7 @@ python3 app.py
 Use the nohup command to run the program in the background:
 ```bash
 touch nohup.out                             # create the log file on the first run
 nohup python3 app.py & tail -f nohup.out    # run in the background and print the QR code via the log
 ```
 After scanning the QR code to log in, the program keeps running in the background on the server, and the log view can be closed with `ctrl+c` without affecting it. Use `ps -ef | grep app.py | grep -v grep` to find the background process; to restart the program, `kill` that process first. To reopen the log after closing it, run `tail -f nohup.out`. The `scripts` directory also provides one-click start and stop scripts.

@@ -1,23 +1,28 @@
 # encoding:utf-8

 import os
-from config import conf, load_config
+import signal
+import sys
 from channel import channel_factory
 from common.log import logger
+from config import conf, load_config
 from plugins import *
-import signal
-import sys


 def sigterm_handler_wrap(_signo):
     old_handler = signal.getsignal(_signo)

     def func(_signo, _stack_frame):
         logger.info("signal {} received, exiting...".format(_signo))
         conf().save_user_datas()
         if callable(old_handler):  # check old_handler
             return old_handler(_signo, _stack_frame)
         sys.exit(0)

     signal.signal(_signo, func)


 def run():
     try:
         # load config
@@ -28,17 +33,17 @@ def run():
         sigterm_handler_wrap(signal.SIGTERM)
         # create channel
-        channel_name=conf().get('channel_type', 'wx')
+        channel_name = conf().get("channel_type", "wx")
         if "--cmd" in sys.argv:
-            channel_name = 'terminal'
-        if channel_name == 'wxy':
-            os.environ['WECHATY_LOG']="warn"
+            channel_name = "terminal"
+        if channel_name == "wxy":
+            os.environ["WECHATY_LOG"] = "warn"
             # os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
         channel = channel_factory.create_channel(channel_name)
-        if channel_name in ['wx','wxy','terminal','wechatmp','wechatmp_service']:
+        if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service"]:
             PluginManager().load_plugins()
         # startup channel
@@ -47,5 +52,6 @@ def run():
         logger.error("App startup failed!")
         logger.exception(e)

-if __name__ == '__main__':
-    run()
+if __name__ == "__main__":
+    run()

@@ -1,6 +1,7 @@
 # encoding:utf-8

 import requests

 from bot.bot import Bot
 from bridge.reply import Reply, ReplyType
@@ -9,20 +10,35 @@ from bridge.reply import Reply, ReplyType
 class BaiduUnitBot(Bot):
     def reply(self, query, context=None):
         token = self.get_token()
-        url = 'https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=' + token
-        post_data = "{\"version\":\"3.0\",\"service_id\":\"S73177\",\"session_id\":\"\",\"log_id\":\"7758521\",\"skill_ids\":[\"1221886\"],\"request\":{\"terminal_id\":\"88888\",\"query\":\"" + query + "\", \"hyper_params\": {\"chat_custom_bot_profile\": 1}}}"
+        url = (
+            "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token="
+            + token
+        )
+        post_data = (
+            '{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"'
+            + query
+            + '", "hyper_params": {"chat_custom_bot_profile": 1}}}'
+        )
         print(post_data)
-        headers = {'content-type': 'application/x-www-form-urlencoded'}
+        headers = {"content-type": "application/x-www-form-urlencoded"}
         response = requests.post(url, data=post_data.encode(), headers=headers)
         if response:
-            reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1])
+            reply = Reply(
+                ReplyType.TEXT,
+                response.json()["result"]["context"]["SYS_PRESUMED_HIST"][1],
+            )
             return reply

     def get_token(self):
-        access_key = 'YOUR_ACCESS_KEY'
-        secret_key = 'YOUR_SECRET_KEY'
-        host = 'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=' + access_key + '&client_secret=' + secret_key
+        access_key = "YOUR_ACCESS_KEY"
+        secret_key = "YOUR_SECRET_KEY"
+        host = (
+            "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id="
+            + access_key
+            + "&client_secret="
+            + secret_key
+        )
         response = requests.get(host)
         if response:
             print(response.json())
-            return response.json()['access_token']
+            return response.json()["access_token"]

@@ -8,7 +8,7 @@ from bridge.reply import Reply


 class Bot(object):
-    def reply(self, query, context : Context =None) -> Reply:
+    def reply(self, query, context: Context = None) -> Reply:
         """
         bot auto-reply content
         :param req: received message
@@ -13,20 +13,24 @@ def create_bot(bot_type):
     if bot_type == const.BAIDU:
         # Baidu Unit对话接口
         from bot.baidu.baidu_unit_bot import BaiduUnitBot
         return BaiduUnitBot()
     elif bot_type == const.CHATGPT:
         # ChatGPT 网页端web接口
         from bot.chatgpt.chat_gpt_bot import ChatGPTBot
         return ChatGPTBot()
     elif bot_type == const.OPEN_AI:
         # OpenAI 官方对话模型API
         from bot.openai.open_ai_bot import OpenAIBot
         return OpenAIBot()
     elif bot_type == const.CHATGPTONAZURE:
         # Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
         from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot
         return AzureChatGPTBot()
     raise RuntimeError
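For illustration only, a minimal sketch of how these factory pieces are consumed together, mirroring the `Bridge` code further down in this diff; the query text and session id are made-up values:

```python
from bot import bot_factory
from bridge.context import Context, ContextType
from common import const
from config import load_config

load_config()  # the bots read their API keys and options from config.json

# Pick an implementation by type constant, then ask it for a reply.
bot = bot_factory.create_bot(const.CHATGPT)
context = Context(ContextType.TEXT, "hello", {"session_id": "demo-user"})
print(bot.reply("hello", context))
```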
@@ -1,42 +1,53 @@
 # encoding:utf-8

+import time
+import openai
+import openai.error
 from bot.bot import Bot
 from bot.chatgpt.chat_gpt_session import ChatGPTSession
 from bot.openai.open_ai_image import OpenAIImage
 from bot.session_manager import SessionManager
 from bridge.context import ContextType
 from bridge.reply import Reply, ReplyType
-from config import conf, load_config
 from common.log import logger
 from common.token_bucket import TokenBucket
-import openai
-import openai.error
-import time
+from config import conf, load_config


 # OpenAI对话模型API (可用)
-class ChatGPTBot(Bot,OpenAIImage):
+class ChatGPTBot(Bot, OpenAIImage):
     def __init__(self):
         super().__init__()
         # set the default api_key
-        openai.api_key = conf().get('open_ai_api_key')
-        if conf().get('open_ai_api_base'):
-            openai.api_base = conf().get('open_ai_api_base')
-        proxy = conf().get('proxy')
+        openai.api_key = conf().get("open_ai_api_key")
+        if conf().get("open_ai_api_base"):
+            openai.api_base = conf().get("open_ai_api_base")
+        proxy = conf().get("proxy")
         if proxy:
             openai.proxy = proxy
-        if conf().get('rate_limit_chatgpt'):
-            self.tb4chatgpt = TokenBucket(conf().get('rate_limit_chatgpt', 20))
-        self.sessions = SessionManager(ChatGPTSession, model= conf().get("model") or "gpt-3.5-turbo")
-        self.args ={
+        if conf().get("rate_limit_chatgpt"):
+            self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))
+        self.sessions = SessionManager(
+            ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo"
+        )
+        self.args = {
             "model": conf().get("model") or "gpt-3.5-turbo",  # 对话模型的名称
-            "temperature":conf().get('temperature', 0.9),  # 值在[0,1]之间,越大表示回复越具有不确定性
+            "temperature": conf().get("temperature", 0.9),  # 值在[0,1]之间,越大表示回复越具有不确定性
             # "max_tokens":4096,  # 回复最大的字符数
-            "top_p":1,
-            "frequency_penalty":conf().get('frequency_penalty', 0.0),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
-            "presence_penalty":conf().get('presence_penalty', 0.0),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
-            "request_timeout": conf().get('request_timeout', None),  # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
-            "timeout": conf().get('request_timeout', None),  #重试超时时间,在这个时间内,将会自动重试
+            "top_p": 1,
+            "frequency_penalty": conf().get(
+                "frequency_penalty", 0.0
+            ),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
+            "presence_penalty": conf().get(
+                "presence_penalty", 0.0
+            ),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
+            "request_timeout": conf().get(
+                "request_timeout", None
+            ),  # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
+            "timeout": conf().get("request_timeout", None),  # 重试超时时间,在这个时间内,将会自动重试
         }

     def reply(self, query, context=None):
@@ -44,39 +55,50 @@ class ChatGPTBot(Bot,OpenAIImage):
         if context.type == ContextType.TEXT:
             logger.info("[CHATGPT] query={}".format(query))
-            session_id = context['session_id']
+            session_id = context["session_id"]
             reply = None
-            clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆'])
+            clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
             if query in clear_memory_commands:
                 self.sessions.clear_session(session_id)
-                reply = Reply(ReplyType.INFO, '记忆已清除')
-            elif query == '#清除所有':
+                reply = Reply(ReplyType.INFO, "记忆已清除")
+            elif query == "#清除所有":
                 self.sessions.clear_all_session()
-                reply = Reply(ReplyType.INFO, '所有人记忆已清除')
-            elif query == '#更新配置':
+                reply = Reply(ReplyType.INFO, "所有人记忆已清除")
+            elif query == "#更新配置":
                 load_config()
-                reply = Reply(ReplyType.INFO, '配置已更新')
+                reply = Reply(ReplyType.INFO, "配置已更新")
             if reply:
                 return reply
             session = self.sessions.session_query(query, session_id)
             logger.debug("[CHATGPT] session query={}".format(session.messages))
-            api_key = context.get('openai_api_key')
+            api_key = context.get("openai_api_key")
             # if context.get('stream'):
             #     # reply in stream
             #     return self.reply_text_stream(query, new_query, session_id)
             reply_content = self.reply_text(session, api_key)
-            logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"]))
-            if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
-                reply = Reply(ReplyType.ERROR, reply_content['content'])
+            logger.debug(
+                "[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
+                    session.messages,
+                    session_id,
+                    reply_content["content"],
+                    reply_content["completion_tokens"],
+                )
+            )
+            if (
+                reply_content["completion_tokens"] == 0
+                and len(reply_content["content"]) > 0
+            ):
+                reply = Reply(ReplyType.ERROR, reply_content["content"])
             elif reply_content["completion_tokens"] > 0:
-                self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
+                self.sessions.session_reply(
+                    reply_content["content"], session_id, reply_content["total_tokens"]
+                )
                 reply = Reply(ReplyType.TEXT, reply_content["content"])
             else:
-                reply = Reply(ReplyType.ERROR, reply_content['content'])
+                reply = Reply(ReplyType.ERROR, reply_content["content"])
                 logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content))
             return reply
@@ -89,53 +111,55 @@ class ChatGPTBot(Bot,OpenAIImage):
             reply = Reply(ReplyType.ERROR, retstring)
             return reply
         else:
-            reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type))
+            reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
             return reply

-    def reply_text(self, session:ChatGPTSession, api_key=None, retry_count=0) -> dict:
-        '''
+    def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> dict:
+        """
         call openai's ChatCompletion to get the answer
         :param session: a conversation session
         :param session_id: session id
         :param retry_count: retry count
         :return: {}
-        '''
+        """
         try:
-            if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token():
+            if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
                 raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
             # if api_key == None, the default openai.api_key will be used
             response = openai.ChatCompletion.create(
                 api_key=api_key, messages=session.messages, **self.args
             )
             # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
-            return {"total_tokens": response["usage"]["total_tokens"],
-                    "completion_tokens": response["usage"]["completion_tokens"],
-                    "content": response.choices[0]['message']['content']}
+            return {
+                "total_tokens": response["usage"]["total_tokens"],
+                "completion_tokens": response["usage"]["completion_tokens"],
+                "content": response.choices[0]["message"]["content"],
+            }
         except Exception as e:
             need_retry = retry_count < 2
             result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
             if isinstance(e, openai.error.RateLimitError):
                 logger.warn("[CHATGPT] RateLimitError: {}".format(e))
-                result['content'] = "提问太快啦,请休息一下再问我吧"
+                result["content"] = "提问太快啦,请休息一下再问我吧"
                 if need_retry:
                     time.sleep(5)
             elif isinstance(e, openai.error.Timeout):
                 logger.warn("[CHATGPT] Timeout: {}".format(e))
-                result['content'] = "我没有收到你的消息"
+                result["content"] = "我没有收到你的消息"
                 if need_retry:
                     time.sleep(5)
             elif isinstance(e, openai.error.APIConnectionError):
                 logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
                 need_retry = False
-                result['content'] = "我连接不到你的网络"
+                result["content"] = "我连接不到你的网络"
             else:
                 logger.warn("[CHATGPT] Exception: {}".format(e))
                 need_retry = False
                 self.sessions.clear_session(session.session_id)

             if need_retry:
-                logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1))
-                return self.reply_text(session, api_key, retry_count+1)
+                logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1))
+                return self.reply_text(session, api_key, retry_count + 1)
             else:
                 return result
@@ -145,4 +169,4 @@ class AzureChatGPTBot(ChatGPTBot):
         super().__init__()
         openai.api_type = "azure"
         openai.api_version = "2023-03-15-preview"
         self.args["deployment_id"] = conf().get("azure_deployment_id")

@@ -1,20 +1,23 @@
 from bot.session_manager import Session
 from common.log import logger

-'''
+"""
 e.g.  [
     {"role": "system", "content": "You are a helpful assistant."},
     {"role": "user", "content": "Who won the world series in 2020?"},
     {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
     {"role": "user", "content": "Where was it played?"}
 ]
-'''
+"""


 class ChatGPTSession(Session):
-    def __init__(self, session_id, system_prompt=None, model= "gpt-3.5-turbo"):
+    def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
         super().__init__(session_id, system_prompt)
         self.model = model
         self.reset()

-    def discard_exceeding(self, max_tokens, cur_tokens= None):
+    def discard_exceeding(self, max_tokens, cur_tokens=None):
         precise = True
         try:
             cur_tokens = self.calc_tokens()
@@ -22,7 +25,9 @@ class ChatGPTSession(Session):
             precise = False
             if cur_tokens is None:
                 raise e
-            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
+            logger.debug(
+                "Exception when counting tokens precisely for query: {}".format(e)
+            )
         while cur_tokens > max_tokens:
             if len(self.messages) > 2:
                 self.messages.pop(1)
@@ -34,25 +39,32 @@ class ChatGPTSession(Session):
                 cur_tokens = cur_tokens - max_tokens
                 break
             elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
-                logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
+                logger.warn(
+                    "user message exceed max_tokens. total_tokens={}".format(cur_tokens)
+                )
                 break
             else:
-                logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
+                logger.debug(
+                    "max_tokens={}, total_tokens={}, len(messages)={}".format(
+                        max_tokens, cur_tokens, len(self.messages)
+                    )
+                )
                 break
         if precise:
             cur_tokens = self.calc_tokens()
         else:
             cur_tokens = cur_tokens - max_tokens
         return cur_tokens

     def calc_tokens(self):
         return num_tokens_from_messages(self.messages, self.model)


 # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 def num_tokens_from_messages(messages, model):
     """Returns the number of tokens used by a list of messages."""
     import tiktoken
     try:
         encoding = tiktoken.encoding_for_model(model)
     except KeyError:
@@ -63,13 +75,17 @@ def num_tokens_from_messages(messages, model):
     elif model == "gpt-4":
         return num_tokens_from_messages(messages, model="gpt-4-0314")
     elif model == "gpt-3.5-turbo-0301":
-        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
+        tokens_per_message = (
+            4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
+        )
         tokens_per_name = -1  # if there's a name, the role is omitted
     elif model == "gpt-4-0314":
         tokens_per_message = 3
         tokens_per_name = 1
     else:
-        logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.")
+        logger.warn(
+            f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301."
+        )
         return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
     num_tokens = 0
     for message in messages:

@@ -1,41 +1,52 @@
 # encoding:utf-8

+import time
+import openai
+import openai.error
 from bot.bot import Bot
 from bot.openai.open_ai_image import OpenAIImage
 from bot.openai.open_ai_session import OpenAISession
 from bot.session_manager import SessionManager
 from bridge.context import ContextType
 from bridge.reply import Reply, ReplyType
-from config import conf
 from common.log import logger
-import openai
-import openai.error
-import time
+from config import conf

 user_session = dict()


 # OpenAI对话模型API (可用)
 class OpenAIBot(Bot, OpenAIImage):
     def __init__(self):
         super().__init__()
-        openai.api_key = conf().get('open_ai_api_key')
-        if conf().get('open_ai_api_base'):
-            openai.api_base = conf().get('open_ai_api_base')
-        proxy = conf().get('proxy')
+        openai.api_key = conf().get("open_ai_api_key")
+        if conf().get("open_ai_api_base"):
+            openai.api_base = conf().get("open_ai_api_base")
+        proxy = conf().get("proxy")
         if proxy:
             openai.proxy = proxy
-        self.sessions = SessionManager(OpenAISession, model= conf().get("model") or "text-davinci-003")
+        self.sessions = SessionManager(
+            OpenAISession, model=conf().get("model") or "text-davinci-003"
+        )
         self.args = {
             "model": conf().get("model") or "text-davinci-003",  # 对话模型的名称
-            "temperature":conf().get('temperature', 0.9),  # 值在[0,1]之间,越大表示回复越具有不确定性
-            "max_tokens":1200,  # 回复最大的字符数
-            "top_p":1,
-            "frequency_penalty":conf().get('frequency_penalty', 0.0),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
-            "presence_penalty":conf().get('presence_penalty', 0.0),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
-            "request_timeout": conf().get('request_timeout', None),  # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
-            "timeout": conf().get('request_timeout', None),  #重试超时时间,在这个时间内,将会自动重试
-            "stop":["\n\n\n"]
+            "temperature": conf().get("temperature", 0.9),  # 值在[0,1]之间,越大表示回复越具有不确定性
+            "max_tokens": 1200,  # 回复最大的字符数
+            "top_p": 1,
+            "frequency_penalty": conf().get(
+                "frequency_penalty", 0.0
+            ),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
+            "presence_penalty": conf().get(
+                "presence_penalty", 0.0
+            ),  # [-2,2]之间,该值越大则更倾向于产生不同的内容
+            "request_timeout": conf().get(
+                "request_timeout", None
+            ),  # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
+            "timeout": conf().get("request_timeout", None),  # 重试超时时间,在这个时间内,将会自动重试
+            "stop": ["\n\n\n"],
         }

     def reply(self, query, context=None):
@@ -43,24 +54,34 @@ class OpenAIBot(Bot, OpenAIImage):
         if context and context.type:
             if context.type == ContextType.TEXT:
                 logger.info("[OPEN_AI] query={}".format(query))
-                session_id = context['session_id']
+                session_id = context["session_id"]
                 reply = None
-                if query == '#清除记忆':
+                if query == "#清除记忆":
                     self.sessions.clear_session(session_id)
-                    reply = Reply(ReplyType.INFO, '记忆已清除')
-                elif query == '#清除所有':
+                    reply = Reply(ReplyType.INFO, "记忆已清除")
+                elif query == "#清除所有":
                     self.sessions.clear_all_session()
-                    reply = Reply(ReplyType.INFO, '所有人记忆已清除')
+                    reply = Reply(ReplyType.INFO, "所有人记忆已清除")
                 else:
                     session = self.sessions.session_query(query, session_id)
                     result = self.reply_text(session)
-                    total_tokens, completion_tokens, reply_content = result['total_tokens'], result['completion_tokens'], result['content']
-                    logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens))
+                    total_tokens, completion_tokens, reply_content = (
+                        result["total_tokens"],
+                        result["completion_tokens"],
+                        result["content"],
+                    )
+                    logger.debug(
+                        "[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
+                            str(session), session_id, reply_content, completion_tokens
+                        )
+                    )

-                    if total_tokens == 0 :
+                    if total_tokens == 0:
                         reply = Reply(ReplyType.ERROR, reply_content)
                     else:
-                        self.sessions.session_reply(reply_content, session_id, total_tokens)
+                        self.sessions.session_reply(
+                            reply_content, session_id, total_tokens
+                        )
                         reply = Reply(ReplyType.TEXT, reply_content)
                 return reply
             elif context.type == ContextType.IMAGE_CREATE:
@@ -72,42 +93,44 @@ class OpenAIBot(Bot, OpenAIImage):
             reply = Reply(ReplyType.ERROR, retstring)
             return reply

-    def reply_text(self, session:OpenAISession, retry_count=0):
+    def reply_text(self, session: OpenAISession, retry_count=0):
         try:
-            response = openai.Completion.create(
-                prompt=str(session), **self.args
+            response = openai.Completion.create(prompt=str(session), **self.args)
+            res_content = (
+                response.choices[0]["text"].strip().replace("<|endoftext|>", "")
             )
-            res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
             total_tokens = response["usage"]["total_tokens"]
             completion_tokens = response["usage"]["completion_tokens"]
             logger.info("[OPEN_AI] reply={}".format(res_content))
-            return {"total_tokens": total_tokens,
-                    "completion_tokens": completion_tokens,
-                    "content": res_content}
+            return {
+                "total_tokens": total_tokens,
+                "completion_tokens": completion_tokens,
+                "content": res_content,
+            }
         except Exception as e:
             need_retry = retry_count < 2
             result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
             if isinstance(e, openai.error.RateLimitError):
                 logger.warn("[OPEN_AI] RateLimitError: {}".format(e))
-                result['content'] = "提问太快啦,请休息一下再问我吧"
+                result["content"] = "提问太快啦,请休息一下再问我吧"
                 if need_retry:
                     time.sleep(5)
             elif isinstance(e, openai.error.Timeout):
                 logger.warn("[OPEN_AI] Timeout: {}".format(e))
-                result['content'] = "我没有收到你的消息"
+                result["content"] = "我没有收到你的消息"
                 if need_retry:
                     time.sleep(5)
             elif isinstance(e, openai.error.APIConnectionError):
                 logger.warn("[OPEN_AI] APIConnectionError: {}".format(e))
                 need_retry = False
-                result['content'] = "我连接不到你的网络"
+                result["content"] = "我连接不到你的网络"
             else:
                 logger.warn("[OPEN_AI] Exception: {}".format(e))
                 need_retry = False
                 self.sessions.clear_session(session.session_id)

             if need_retry:
-                logger.warn("[OPEN_AI] 第{}次重试".format(retry_count+1))
-                return self.reply_text(session, retry_count+1)
+                logger.warn("[OPEN_AI] 第{}次重试".format(retry_count + 1))
+                return self.reply_text(session, retry_count + 1)
             else:
                 return result

@@ -1,38 +1,45 @@
 import time
 import openai
 import openai.error
-from common.token_bucket import TokenBucket
 from common.log import logger
+from common.token_bucket import TokenBucket
 from config import conf


 # OPENAI提供的画图接口
 class OpenAIImage(object):
     def __init__(self):
-        openai.api_key = conf().get('open_ai_api_key')
-        if conf().get('rate_limit_dalle'):
-            self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50))
+        openai.api_key = conf().get("open_ai_api_key")
+        if conf().get("rate_limit_dalle"):
+            self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50))

     def create_img(self, query, retry_count=0):
         try:
-            if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token():
+            if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token():
                 return False, "请求太快了,请休息一下再问我吧"
             logger.info("[OPEN_AI] image_query={}".format(query))
             response = openai.Image.create(
-                prompt=query,  #图片描述
-                n=1,  #每次生成图片的数量
-                size="256x256"  #图片大小,可选有 256x256, 512x512, 1024x1024
+                prompt=query,  # 图片描述
+                n=1,  # 每次生成图片的数量
+                size="256x256",  # 图片大小,可选有 256x256, 512x512, 1024x1024
             )
-            image_url = response['data'][0]['url']
+            image_url = response["data"][0]["url"]
             logger.info("[OPEN_AI] image_url={}".format(image_url))
             return True, image_url
         except openai.error.RateLimitError as e:
             logger.warn(e)
             if retry_count < 1:
                 time.sleep(5)
-                logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
-                return self.create_img(query, retry_count+1)
+                logger.warn(
+                    "[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(
+                        retry_count + 1
+                    )
+                )
+                return self.create_img(query, retry_count + 1)
             else:
                 return False, "提问太快啦,请休息一下再问我吧"
         except Exception as e:
             logger.exception(e)
             return False, str(e)

@@ -1,32 +1,34 @@
 from bot.session_manager import Session
 from common.log import logger


 class OpenAISession(Session):
-    def __init__(self, session_id, system_prompt=None, model= "text-davinci-003"):
+    def __init__(self, session_id, system_prompt=None, model="text-davinci-003"):
         super().__init__(session_id, system_prompt)
         self.model = model
         self.reset()

     def __str__(self):
         # 构造对话模型的输入
-        '''
+        """
         e.g.  Q: xxx
               A: xxx
               Q: xxx
-        '''
+        """
         prompt = ""
         for item in self.messages:
-            if item['role'] == 'system':
-                prompt += item['content'] + "<|endoftext|>\n\n\n"
-            elif item['role'] == 'user':
-                prompt += "Q: " + item['content'] + "\n"
-            elif item['role'] == 'assistant':
-                prompt += "\n\nA: " + item['content'] + "<|endoftext|>\n"
+            if item["role"] == "system":
+                prompt += item["content"] + "<|endoftext|>\n\n\n"
+            elif item["role"] == "user":
+                prompt += "Q: " + item["content"] + "\n"
+            elif item["role"] == "assistant":
+                prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n"

-        if len(self.messages) > 0 and self.messages[-1]['role'] == 'user':
+        if len(self.messages) > 0 and self.messages[-1]["role"] == "user":
             prompt += "A: "
         return prompt

-    def discard_exceeding(self, max_tokens, cur_tokens= None):
+    def discard_exceeding(self, max_tokens, cur_tokens=None):
         precise = True
         try:
             cur_tokens = self.calc_tokens()
@@ -34,7 +36,9 @@ class OpenAISession(Session):
             precise = False
             if cur_tokens is None:
                 raise e
-            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
+            logger.debug(
+                "Exception when counting tokens precisely for query: {}".format(e)
+            )
         while cur_tokens > max_tokens:
             if len(self.messages) > 1:
                 self.messages.pop(0)
@@ -46,24 +50,34 @@ class OpenAISession(Session):
                 cur_tokens = len(str(self))
                 break
             elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
-                logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
+                logger.warn(
+                    "user question exceed max_tokens. total_tokens={}".format(
+                        cur_tokens
+                    )
+                )
                 break
             else:
-                logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
+                logger.debug(
+                    "max_tokens={}, total_tokens={}, len(conversation)={}".format(
+                        max_tokens, cur_tokens, len(self.messages)
+                    )
+                )
                 break
         if precise:
             cur_tokens = self.calc_tokens()
         else:
             cur_tokens = len(str(self))
         return cur_tokens

     def calc_tokens(self):
         return num_tokens_from_string(str(self), self.model)


 # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 def num_tokens_from_string(string: str, model: str) -> int:
     """Returns the number of tokens in a text string."""
     import tiktoken
     encoding = tiktoken.encoding_for_model(model)
-    num_tokens = len(encoding.encode(string,disallowed_special=()))
-    return num_tokens
+    num_tokens = len(encoding.encode(string, disallowed_special=()))
+    return num_tokens

@@ -2,6 +2,7 @@ from common.expired_dict import ExpiredDict
 from common.log import logger
 from config import conf


 class Session(object):
     def __init__(self, session_id, system_prompt=None):
         self.session_id = session_id
@@ -13,7 +14,7 @@ class Session(object):
     # 重置会话
     def reset(self):
-        system_item = {'role': 'system', 'content': self.system_prompt}
+        system_item = {"role": "system", "content": self.system_prompt}
         self.messages = [system_item]

     def set_system_prompt(self, system_prompt):
@@ -21,13 +22,13 @@ class Session(object):
         self.reset()

     def add_query(self, query):
-        user_item = {'role': 'user', 'content': query}
+        user_item = {"role": "user", "content": query}
         self.messages.append(user_item)

     def add_reply(self, reply):
-        assistant_item = {'role': 'assistant', 'content': reply}
+        assistant_item = {"role": "assistant", "content": reply}
         self.messages.append(assistant_item)

     def discard_exceeding(self, max_tokens=None, cur_tokens=None):
         raise NotImplementedError
@@ -37,8 +38,8 @@ class Session(object):
 class SessionManager(object):
     def __init__(self, sessioncls, **session_args):
-        if conf().get('expires_in_seconds'):
-            sessions = ExpiredDict(conf().get('expires_in_seconds'))
+        if conf().get("expires_in_seconds"):
+            sessions = ExpiredDict(conf().get("expires_in_seconds"))
         else:
             sessions = dict()
         self.sessions = sessions
@@ -46,20 +47,22 @@ class SessionManager(object):
         self.session_args = session_args

     def build_session(self, session_id, system_prompt=None):
-        '''
-        如果session_id不在sessions中,创建一个新的session并添加到sessions中
-        如果system_prompt不会空,会更新session的system_prompt并重置session
-        '''
+        """
+        如果session_id不在sessions中,创建一个新的session并添加到sessions中
+        如果system_prompt不会空,会更新session的system_prompt并重置session
+        """
         if session_id is None:
             return self.sessioncls(session_id, system_prompt, **self.session_args)
         if session_id not in self.sessions:
-            self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args)
+            self.sessions[session_id] = self.sessioncls(
+                session_id, system_prompt, **self.session_args
+            )
         elif system_prompt is not None:  # 如果有新的system_prompt,更新并重置session
             self.sessions[session_id].set_system_prompt(system_prompt)
         session = self.sessions[session_id]
         return session

     def session_query(self, query, session_id):
         session = self.build_session(session_id)
         session.add_query(query)
@@ -68,23 +71,33 @@ class SessionManager(object):
             total_tokens = session.discard_exceeding(max_tokens, None)
             logger.debug("prompt tokens used={}".format(total_tokens))
         except Exception as e:
-            logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
+            logger.debug(
+                "Exception when counting tokens precisely for prompt: {}".format(str(e))
+            )
         return session

-    def session_reply(self, reply, session_id, total_tokens = None):
+    def session_reply(self, reply, session_id, total_tokens=None):
         session = self.build_session(session_id)
         session.add_reply(reply)
         try:
             max_tokens = conf().get("conversation_max_tokens", 1000)
             tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
-            logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
+            logger.debug(
+                "raw total_tokens={}, savesession tokens={}".format(
+                    total_tokens, tokens_cnt
+                )
+            )
         except Exception as e:
-            logger.debug("Exception when counting tokens precisely for session: {}".format(str(e)))
+            logger.debug(
+                "Exception when counting tokens precisely for session: {}".format(
+                    str(e)
+                )
+            )
         return session

     def clear_session(self, session_id):
         if session_id in self.sessions:
-            del(self.sessions[session_id])
+            del self.sessions[session_id]

     def clear_all_session(self):
         self.sessions.clear()
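
As a usage illustration of the session machinery above, a minimal sketch mirroring how `ChatGPTBot` drives it earlier in this diff; the session id, reply text and token count are made-up values:

```python
from bot.chatgpt.chat_gpt_session import ChatGPTSession
from bot.session_manager import SessionManager
from config import load_config

load_config()  # expires_in_seconds / conversation_max_tokens are read from config.json

sessions = SessionManager(ChatGPTSession, model="gpt-3.5-turbo")

# Append the user query; old messages are trimmed if the context grows too large.
session = sessions.session_query("你好", "demo-user")
print(session.messages)

# After the model answers, store the reply together with the reported token usage.
sessions.session_reply("你好!有什么可以帮你?", "demo-user", total_tokens=42)
```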
@@ -1,31 +1,31 @@
+from bot import bot_factory
 from bridge.context import Context
 from bridge.reply import Reply
+from common import const
 from common.log import logger
-from bot import bot_factory
 from common.singleton import singleton
-from voice import voice_factory
 from config import conf
-from common import const
+from voice import voice_factory


 @singleton
 class Bridge(object):
     def __init__(self):
-        self.btype={
+        self.btype = {
             "chat": const.CHATGPT,
             "voice_to_text": conf().get("voice_to_text", "openai"),
-            "text_to_voice": conf().get("text_to_voice", "google")
+            "text_to_voice": conf().get("text_to_voice", "google"),
         }
         model_type = conf().get("model")
         if model_type in ["text-davinci-003"]:
-            self.btype['chat'] = const.OPEN_AI
+            self.btype["chat"] = const.OPEN_AI
         if conf().get("use_azure_chatgpt", False):
-            self.btype['chat'] = const.CHATGPTONAZURE
-        self.bots={}
+            self.btype["chat"] = const.CHATGPTONAZURE
+        self.bots = {}

-    def get_bot(self,typename):
+    def get_bot(self, typename):
         if self.bots.get(typename) is None:
-            logger.info("create bot {} for {}".format(self.btype[typename],typename))
+            logger.info("create bot {} for {}".format(self.btype[typename], typename))
             if typename == "text_to_voice":
                 self.bots[typename] = voice_factory.create_voice(self.btype[typename])
             elif typename == "voice_to_text":
@@ -33,18 +33,15 @@ class Bridge(object):
             elif typename == "chat":
                 self.bots[typename] = bot_factory.create_bot(self.btype[typename])
         return self.bots[typename]

-    def get_bot_type(self,typename):
-        return self.btype[typename]
+    def get_bot_type(self, typename):
+        return self.btype[typename]

-    def fetch_reply_content(self, query, context : Context) -> Reply:
+    def fetch_reply_content(self, query, context: Context) -> Reply:
         return self.get_bot("chat").reply(query, context)

     def fetch_voice_to_text(self, voiceFile) -> Reply:
         return self.get_bot("voice_to_text").voiceToText(voiceFile)

     def fetch_text_to_voice(self, text) -> Reply:
         return self.get_bot("text_to_voice").textToVoice(text)

@@ -2,36 +2,39 @@ | |||||
from enum import Enum | from enum import Enum | ||||
class ContextType (Enum): | |||||
TEXT = 1 # 文本消息 | |||||
VOICE = 2 # 音频消息 | |||||
IMAGE = 3 # 图片消息 | |||||
IMAGE_CREATE = 10 # 创建图片命令 | |||||
class ContextType(Enum): | |||||
TEXT = 1 # 文本消息 | |||||
VOICE = 2 # 音频消息 | |||||
IMAGE = 3 # 图片消息 | |||||
IMAGE_CREATE = 10 # 创建图片命令 | |||||
def __str__(self): | def __str__(self): | ||||
return self.name | return self.name | ||||
class Context: | class Context: | ||||
def __init__(self, type : ContextType = None , content = None, kwargs = dict()): | |||||
def __init__(self, type: ContextType = None, content=None, kwargs=dict()): | |||||
self.type = type | self.type = type | ||||
self.content = content | self.content = content | ||||
self.kwargs = kwargs | self.kwargs = kwargs | ||||
def __contains__(self, key): | def __contains__(self, key): | ||||
if key == 'type': | |||||
if key == "type": | |||||
return self.type is not None | return self.type is not None | ||||
elif key == 'content': | |||||
elif key == "content": | |||||
return self.content is not None | return self.content is not None | ||||
else: | else: | ||||
return key in self.kwargs | return key in self.kwargs | ||||
def __getitem__(self, key): | def __getitem__(self, key): | ||||
if key == 'type': | |||||
if key == "type": | |||||
return self.type | return self.type | ||||
elif key == 'content': | |||||
elif key == "content": | |||||
return self.content | return self.content | ||||
else: | else: | ||||
return self.kwargs[key] | return self.kwargs[key] | ||||
def get(self, key, default=None): | def get(self, key, default=None): | ||||
try: | try: | ||||
return self[key] | return self[key] | ||||
@@ -39,20 +42,22 @@ class Context: | |||||
return default | return default | ||||
def __setitem__(self, key, value): | def __setitem__(self, key, value): | ||||
if key == 'type': | |||||
if key == "type": | |||||
self.type = value | self.type = value | ||||
elif key == 'content': | |||||
elif key == "content": | |||||
self.content = value | self.content = value | ||||
else: | else: | ||||
self.kwargs[key] = value | self.kwargs[key] = value | ||||
def __delitem__(self, key): | def __delitem__(self, key): | ||||
if key == 'type': | |||||
if key == "type": | |||||
self.type = None | self.type = None | ||||
elif key == 'content': | |||||
elif key == "content": | |||||
self.content = None | self.content = None | ||||
else: | else: | ||||
del self.kwargs[key] | del self.kwargs[key] | ||||
def __str__(self): | def __str__(self): | ||||
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs) | |||||
return "Context(type={}, content={}, kwargs={})".format( | |||||
self.type, self.content, self.kwargs | |||||
) |
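The `Context` class above behaves like a small mapping: `"type"` and `"content"` resolve to attributes, every other key lives in `kwargs`. Note that `kwargs=dict()` is a shared mutable default; `_compose_context` in `chat_channel.py` sidesteps it by assigning `context.kwargs` explicitly. A short illustration of the dunder methods in this hunk (mine; it assumes `get()` swallows the `KeyError`, as the elided `except` branch suggests):

```python
from bridge.context import Context, ContextType

ctx = Context(ContextType.TEXT, "画 一只猫", {"isgroup": False})
assert "type" in ctx and "isgroup" in ctx   # __contains__ covers the attributes and kwargs
assert ctx["content"] == "画 一只猫"         # "type"/"content" resolve to the attributes
ctx["receiver"] = "filehelper"               # any other key is stored in ctx.kwargs
print(ctx.get("session_id", "n/a"))          # missing keys fall back to the default: n/a
```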
@@ -1,22 +1,25 @@ | |||||
# encoding:utf-8 | # encoding:utf-8 | ||||
from enum import Enum | from enum import Enum | ||||
class ReplyType(Enum): | class ReplyType(Enum): | ||||
TEXT = 1 # 文本 | |||||
VOICE = 2 # 音频文件 | |||||
IMAGE = 3 # 图片文件 | |||||
IMAGE_URL = 4 # 图片URL | |||||
TEXT = 1 # 文本 | |||||
VOICE = 2 # 音频文件 | |||||
IMAGE = 3 # 图片文件 | |||||
IMAGE_URL = 4 # 图片URL | |||||
INFO = 9 | INFO = 9 | ||||
ERROR = 10 | ERROR = 10 | ||||
def __str__(self): | def __str__(self): | ||||
return self.name | return self.name | ||||
class Reply: | class Reply: | ||||
def __init__(self, type : ReplyType = None , content = None): | |||||
def __init__(self, type: ReplyType = None, content=None): | |||||
self.type = type | self.type = type | ||||
self.content = content | self.content = content | ||||
def __str__(self): | def __str__(self): | ||||
return "Reply(type={}, content={})".format(self.type, self.content) | |||||
return "Reply(type={}, content={})".format(self.type, self.content) |
@@ -6,8 +6,10 @@ from bridge.bridge import Bridge | |||||
from bridge.context import Context | from bridge.context import Context | ||||
from bridge.reply import * | from bridge.reply import * | ||||
class Channel(object): | class Channel(object): | ||||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE] | NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE] | ||||
def startup(self): | def startup(self): | ||||
""" | """ | ||||
init channel | init channel | ||||
@@ -27,15 +29,15 @@ class Channel(object): | |||||
send message to user | send message to user | ||||
:param msg: message content | :param msg: message content | ||||
:param receiver: receiver channel account | :param receiver: receiver channel account | ||||
:return: | |||||
:return: | |||||
""" | """ | ||||
raise NotImplementedError | raise NotImplementedError | ||||
def build_reply_content(self, query, context : Context=None) -> Reply: | |||||
def build_reply_content(self, query, context: Context = None) -> Reply: | |||||
return Bridge().fetch_reply_content(query, context) | return Bridge().fetch_reply_content(query, context) | ||||
def build_voice_to_text(self, voice_file) -> Reply: | def build_voice_to_text(self, voice_file) -> Reply: | ||||
return Bridge().fetch_voice_to_text(voice_file) | return Bridge().fetch_voice_to_text(voice_file) | ||||
def build_text_to_voice(self, text) -> Reply: | def build_text_to_voice(self, text) -> Reply: | ||||
return Bridge().fetch_text_to_voice(text) | return Bridge().fetch_text_to_voice(text) |
@@ -2,25 +2,31 @@ | |||||
channel factory | channel factory | ||||
""" | """ | ||||
def create_channel(channel_type): | def create_channel(channel_type): | ||||
""" | """ | ||||
create a channel instance | create a channel instance | ||||
:param channel_type: channel type code | :param channel_type: channel type code | ||||
:return: channel instance | :return: channel instance | ||||
""" | """ | ||||
if channel_type == 'wx': | |||||
if channel_type == "wx": | |||||
from channel.wechat.wechat_channel import WechatChannel | from channel.wechat.wechat_channel import WechatChannel | ||||
return WechatChannel() | return WechatChannel() | ||||
elif channel_type == 'wxy': | |||||
elif channel_type == "wxy": | |||||
from channel.wechat.wechaty_channel import WechatyChannel | from channel.wechat.wechaty_channel import WechatyChannel | ||||
return WechatyChannel() | return WechatyChannel() | ||||
elif channel_type == 'terminal': | |||||
elif channel_type == "terminal": | |||||
from channel.terminal.terminal_channel import TerminalChannel | from channel.terminal.terminal_channel import TerminalChannel | ||||
return TerminalChannel() | return TerminalChannel() | ||||
elif channel_type == 'wechatmp': | |||||
elif channel_type == "wechatmp": | |||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel | from channel.wechatmp.wechatmp_channel import WechatMPChannel | ||||
return WechatMPChannel(passive_reply = True) | |||||
elif channel_type == 'wechatmp_service': | |||||
return WechatMPChannel(passive_reply=True) | |||||
elif channel_type == "wechatmp_service": | |||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel | from channel.wechatmp.wechatmp_channel import WechatMPChannel | ||||
return WechatMPChannel(passive_reply = False) | |||||
return WechatMPChannel(passive_reply=False) | |||||
raise RuntimeError | raise RuntimeError |
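The two `wechatmp` variants above differ only in `passive_reply`. A hypothetical driver for the factory — the `channel_type` config key and the entry-point code are assumptions, not shown in this diff:

```python
from channel import channel_factory
from config import conf

channel = channel_factory.create_channel(conf().get("channel_type", "wx"))
channel.startup()   # e.g. the wx channel logs in via QR code and starts listening
```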
@@ -1,137 +1,172 @@ | |||||
from asyncio import CancelledError | |||||
from concurrent.futures import Future, ThreadPoolExecutor | |||||
import os | import os | ||||
import re | import re | ||||
import threading | import threading | ||||
import time | import time | ||||
from common.dequeue import Dequeue | |||||
from channel.channel import Channel | |||||
from bridge.reply import * | |||||
from asyncio import CancelledError | |||||
from concurrent.futures import Future, ThreadPoolExecutor | |||||
from bridge.context import * | from bridge.context import * | ||||
from config import conf | |||||
from bridge.reply import * | |||||
from channel.channel import Channel | |||||
from common.dequeue import Dequeue | |||||
from common.log import logger | from common.log import logger | ||||
from config import conf | |||||
from plugins import * | from plugins import * | ||||
try: | try: | ||||
from voice.audio_convert import any_to_wav | from voice.audio_convert import any_to_wav | ||||
except Exception as e: | except Exception as e: | ||||
pass | pass | ||||
# 抽象类, 它包含了与消息通道无关的通用处理逻辑 | # 抽象类, 它包含了与消息通道无关的通用处理逻辑 | ||||
class ChatChannel(Channel): | class ChatChannel(Channel): | ||||
name = None # 登录的用户名 | |||||
user_id = None # 登录的用户id | |||||
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消 | |||||
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理 | |||||
lock = threading.Lock() # 用于控制对sessions的访问 | |||||
name = None # 登录的用户名 | |||||
user_id = None # 登录的用户id | |||||
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消 | |||||
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理 | |||||
lock = threading.Lock() # 用于控制对sessions的访问 | |||||
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池 | handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池 | ||||
def __init__(self): | def __init__(self): | ||||
_thread = threading.Thread(target=self.consume) | _thread = threading.Thread(target=self.consume) | ||||
_thread.setDaemon(True) | _thread.setDaemon(True) | ||||
_thread.start() | _thread.start() | ||||
# 根据消息构造context,消息内容相关的触发项写在这里 | # 根据消息构造context,消息内容相关的触发项写在这里 | ||||
def _compose_context(self, ctype: ContextType, content, **kwargs): | def _compose_context(self, ctype: ContextType, content, **kwargs): | ||||
context = Context(ctype, content) | context = Context(ctype, content) | ||||
context.kwargs = kwargs | context.kwargs = kwargs | ||||
# context首次传入时,origin_ctype是None, | |||||
# context首次传入时,origin_ctype是None, | |||||
# 引入的起因是:当输入语音时,会嵌套生成两个context,第一步语音转文本,第二步通过文本生成文字回复。 | # 引入的起因是:当输入语音时,会嵌套生成两个context,第一步语音转文本,第二步通过文本生成文字回复。 | ||||
# origin_ctype用于第二步文本回复时,判断是否需要匹配前缀,如果是私聊的语音,就不需要匹配前缀 | # origin_ctype用于第二步文本回复时,判断是否需要匹配前缀,如果是私聊的语音,就不需要匹配前缀 | ||||
if 'origin_ctype' not in context: | |||||
context['origin_ctype'] = ctype | |||||
if "origin_ctype" not in context: | |||||
context["origin_ctype"] = ctype | |||||
# context首次传入时,receiver是None,根据类型设置receiver | # context首次传入时,receiver是None,根据类型设置receiver | ||||
first_in = 'receiver' not in context | |||||
first_in = "receiver" not in context | |||||
# 群名匹配过程,设置session_id和receiver | # 群名匹配过程,设置session_id和receiver | ||||
if first_in: # context首次传入时,receiver是None,根据类型设置receiver | |||||
if first_in: # context首次传入时,receiver是None,根据类型设置receiver | |||||
config = conf() | config = conf() | ||||
cmsg = context['msg'] | |||||
cmsg = context["msg"] | |||||
if context.get("isgroup", False): | if context.get("isgroup", False): | ||||
group_name = cmsg.other_user_nickname | group_name = cmsg.other_user_nickname | ||||
group_id = cmsg.other_user_id | group_id = cmsg.other_user_id | ||||
group_name_white_list = config.get('group_name_white_list', []) | |||||
group_name_keyword_white_list = config.get('group_name_keyword_white_list', []) | |||||
if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]): | |||||
group_chat_in_one_session = conf().get('group_chat_in_one_session', []) | |||||
group_name_white_list = config.get("group_name_white_list", []) | |||||
group_name_keyword_white_list = config.get( | |||||
"group_name_keyword_white_list", [] | |||||
) | |||||
if any( | |||||
[ | |||||
group_name in group_name_white_list, | |||||
"ALL_GROUP" in group_name_white_list, | |||||
check_contain(group_name, group_name_keyword_white_list), | |||||
] | |||||
): | |||||
group_chat_in_one_session = conf().get( | |||||
"group_chat_in_one_session", [] | |||||
) | |||||
session_id = cmsg.actual_user_id | session_id = cmsg.actual_user_id | ||||
if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]): | |||||
if any( | |||||
[ | |||||
group_name in group_chat_in_one_session, | |||||
"ALL_GROUP" in group_chat_in_one_session, | |||||
] | |||||
): | |||||
session_id = group_id | session_id = group_id | ||||
else: | else: | ||||
return None | return None | ||||
context['session_id'] = session_id | |||||
context['receiver'] = group_id | |||||
context["session_id"] = session_id | |||||
context["receiver"] = group_id | |||||
else: | else: | ||||
context['session_id'] = cmsg.other_user_id | |||||
context['receiver'] = cmsg.other_user_id | |||||
e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {'channel': self, 'context': context})) | |||||
context = e_context['context'] | |||||
context["session_id"] = cmsg.other_user_id | |||||
context["receiver"] = cmsg.other_user_id | |||||
e_context = PluginManager().emit_event( | |||||
EventContext( | |||||
Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context} | |||||
) | |||||
) | |||||
context = e_context["context"] | |||||
if e_context.is_pass() or context is None: | if e_context.is_pass() or context is None: | ||||
return context | return context | ||||
if cmsg.from_user_id == self.user_id and not config.get('trigger_by_self', True): | |||||
if cmsg.from_user_id == self.user_id and not config.get( | |||||
"trigger_by_self", True | |||||
): | |||||
logger.debug("[WX]self message skipped") | logger.debug("[WX]self message skipped") | ||||
return None | return None | ||||
# 消息内容匹配过程,并处理content | # 消息内容匹配过程,并处理content | ||||
if ctype == ContextType.TEXT: | if ctype == ContextType.TEXT: | ||||
if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息 | |||||
if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息 | |||||
logger.debug("[WX]reference query skipped") | logger.debug("[WX]reference query skipped") | ||||
return None | return None | ||||
if context.get("isgroup", False): # 群聊 | |||||
if context.get("isgroup", False): # 群聊 | |||||
# 校验关键字 | # 校验关键字 | ||||
match_prefix = check_prefix(content, conf().get('group_chat_prefix')) | |||||
match_contain = check_contain(content, conf().get('group_chat_keyword')) | |||||
match_prefix = check_prefix(content, conf().get("group_chat_prefix")) | |||||
match_contain = check_contain(content, conf().get("group_chat_keyword")) | |||||
flag = False | flag = False | ||||
if match_prefix is not None or match_contain is not None: | if match_prefix is not None or match_contain is not None: | ||||
flag = True | flag = True | ||||
if match_prefix: | if match_prefix: | ||||
content = content.replace(match_prefix, '', 1).strip() | |||||
if context['msg'].is_at: | |||||
content = content.replace(match_prefix, "", 1).strip() | |||||
if context["msg"].is_at: | |||||
logger.info("[WX]receive group at") | logger.info("[WX]receive group at") | ||||
if not conf().get("group_at_off", False): | if not conf().get("group_at_off", False): | ||||
flag = True | flag = True | ||||
pattern = f'@{self.name}(\u2005|\u0020)' | |||||
content = re.sub(pattern, r'', content) | |||||
pattern = f"@{self.name}(\u2005|\u0020)" | |||||
content = re.sub(pattern, r"", content) | |||||
if not flag: | if not flag: | ||||
if context["origin_ctype"] == ContextType.VOICE: | if context["origin_ctype"] == ContextType.VOICE: | ||||
logger.info("[WX]receive group voice, but checkprefix didn't match") | |||||
logger.info( | |||||
"[WX]receive group voice, but checkprefix didn't match" | |||||
) | |||||
return None | return None | ||||
else: # 单聊 | |||||
match_prefix = check_prefix(content, conf().get('single_chat_prefix',[''])) | |||||
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容 | |||||
content = content.replace(match_prefix, '', 1).strip() | |||||
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件 | |||||
else: # 单聊 | |||||
match_prefix = check_prefix( | |||||
content, conf().get("single_chat_prefix", [""]) | |||||
) | |||||
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容 | |||||
content = content.replace(match_prefix, "", 1).strip() | |||||
elif ( | |||||
context["origin_ctype"] == ContextType.VOICE | |||||
): # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件 | |||||
pass | pass | ||||
else: | else: | ||||
return None | |||||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) | |||||
return None | |||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix")) | |||||
if img_match_prefix: | if img_match_prefix: | ||||
content = content.replace(img_match_prefix, '', 1) | |||||
content = content.replace(img_match_prefix, "", 1) | |||||
context.type = ContextType.IMAGE_CREATE | context.type = ContextType.IMAGE_CREATE | ||||
else: | else: | ||||
context.type = ContextType.TEXT | context.type = ContextType.TEXT | ||||
context.content = content.strip() | context.content = content.strip() | ||||
if 'desire_rtype' not in context and conf().get('always_reply_voice') and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: | |||||
context['desire_rtype'] = ReplyType.VOICE | |||||
elif context.type == ContextType.VOICE: | |||||
if 'desire_rtype' not in context and conf().get('voice_reply_voice') and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: | |||||
context['desire_rtype'] = ReplyType.VOICE | |||||
if ( | |||||
"desire_rtype" not in context | |||||
and conf().get("always_reply_voice") | |||||
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE | |||||
): | |||||
context["desire_rtype"] = ReplyType.VOICE | |||||
elif context.type == ContextType.VOICE: | |||||
if ( | |||||
"desire_rtype" not in context | |||||
and conf().get("voice_reply_voice") | |||||
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE | |||||
): | |||||
context["desire_rtype"] = ReplyType.VOICE | |||||
return context | return context | ||||
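`_compose_context` above is where all triggering happens: group whitelist checks, `group_chat_prefix`/keyword/@ matching, `single_chat_prefix` stripping, `image_create_prefix` detection, and the voice `desire_rtype` defaults. A standalone sketch of the prefix rules — my re-implementation for illustration; `check_prefix`'s full body is truncated later in this diff, so its exact behaviour here is an assumption:

```python
def check_prefix(content, prefix_list):
    # assumed behaviour: return the first matching prefix, else None
    for prefix in prefix_list or []:
        if content.startswith(prefix):
            return prefix
    return None

content = "@bot 画 一只猫"
prefix = check_prefix(content, ["@bot"])            # group_chat_prefix hit
content = content.replace(prefix, "", 1).strip()     # -> "画 一只猫"
if check_prefix(content, ["画", "看", "找"]):          # image_create_prefix hit
    print("IMAGE_CREATE:", content.replace("画", "", 1).strip())   # -> 一只猫
```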
def _handle(self, context: Context): | def _handle(self, context: Context): | ||||
if context is None or not context.content: | if context is None or not context.content: | ||||
return | return | ||||
logger.debug('[WX] ready to handle context: {}'.format(context)) | |||||
logger.debug("[WX] ready to handle context: {}".format(context)) | |||||
# reply的构建步骤 | # reply的构建步骤 | ||||
reply = self._generate_reply(context) | reply = self._generate_reply(context) | ||||
logger.debug('[WX] ready to decorate reply: {}'.format(reply)) | |||||
logger.debug("[WX] ready to decorate reply: {}".format(reply)) | |||||
# reply的包装步骤 | # reply的包装步骤 | ||||
reply = self._decorate_reply(context, reply) | reply = self._decorate_reply(context, reply) | ||||
@@ -139,20 +174,31 @@ class ChatChannel(Channel): | |||||
self._send_reply(context, reply) | self._send_reply(context, reply) | ||||
def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply: | def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply: | ||||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, { | |||||
'channel': self, 'context': context, 'reply': reply})) | |||||
reply = e_context['reply'] | |||||
e_context = PluginManager().emit_event( | |||||
EventContext( | |||||
Event.ON_HANDLE_CONTEXT, | |||||
{"channel": self, "context": context, "reply": reply}, | |||||
) | |||||
) | |||||
reply = e_context["reply"] | |||||
if not e_context.is_pass(): | if not e_context.is_pass(): | ||||
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content)) | |||||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息 | |||||
logger.debug( | |||||
"[WX] ready to handle context: type={}, content={}".format( | |||||
context.type, context.content | |||||
) | |||||
) | |||||
if ( | |||||
context.type == ContextType.TEXT | |||||
or context.type == ContextType.IMAGE_CREATE | |||||
): # 文字和图片消息 | |||||
reply = super().build_reply_content(context.content, context) | reply = super().build_reply_content(context.content, context) | ||||
elif context.type == ContextType.VOICE: # 语音消息 | elif context.type == ContextType.VOICE: # 语音消息 | ||||
cmsg = context['msg'] | |||||
cmsg = context["msg"] | |||||
cmsg.prepare() | cmsg.prepare() | ||||
file_path = context.content | file_path = context.content | ||||
wav_path = os.path.splitext(file_path)[0] + '.wav' | |||||
wav_path = os.path.splitext(file_path)[0] + ".wav" | |||||
try: | try: | ||||
any_to_wav(file_path, wav_path) | |||||
any_to_wav(file_path, wav_path) | |||||
except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别 | except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别 | ||||
logger.warning("[WX]any to wav error, use raw path. " + str(e)) | logger.warning("[WX]any to wav error, use raw path. " + str(e)) | ||||
wav_path = file_path | wav_path = file_path | ||||
@@ -169,7 +215,8 @@ class ChatChannel(Channel): | |||||
if reply.type == ReplyType.TEXT: | if reply.type == ReplyType.TEXT: | ||||
new_context = self._compose_context( | new_context = self._compose_context( | ||||
ContextType.TEXT, reply.content, **context.kwargs) | |||||
ContextType.TEXT, reply.content, **context.kwargs | |||||
) | |||||
if new_context: | if new_context: | ||||
reply = self._generate_reply(new_context) | reply = self._generate_reply(new_context) | ||||
else: | else: | ||||
@@ -177,18 +224,21 @@ class ChatChannel(Channel): | |||||
elif context.type == ContextType.IMAGE: # 图片消息,当前无默认逻辑 | elif context.type == ContextType.IMAGE: # 图片消息,当前无默认逻辑 | ||||
pass | pass | ||||
else: | else: | ||||
logger.error('[WX] unknown context type: {}'.format(context.type)) | |||||
logger.error("[WX] unknown context type: {}".format(context.type)) | |||||
return | return | ||||
return reply | return reply | ||||
def _decorate_reply(self, context: Context, reply: Reply) -> Reply: | def _decorate_reply(self, context: Context, reply: Reply) -> Reply: | ||||
if reply and reply.type: | if reply and reply.type: | ||||
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, { | |||||
'channel': self, 'context': context, 'reply': reply})) | |||||
reply = e_context['reply'] | |||||
desire_rtype = context.get('desire_rtype') | |||||
e_context = PluginManager().emit_event( | |||||
EventContext( | |||||
Event.ON_DECORATE_REPLY, | |||||
{"channel": self, "context": context, "reply": reply}, | |||||
) | |||||
) | |||||
reply = e_context["reply"] | |||||
desire_rtype = context.get("desire_rtype") | |||||
if not e_context.is_pass() and reply and reply.type: | if not e_context.is_pass() and reply and reply.type: | ||||
if reply.type in self.NOT_SUPPORT_REPLYTYPE: | if reply.type in self.NOT_SUPPORT_REPLYTYPE: | ||||
logger.error("[WX]reply type not support: " + str(reply.type)) | logger.error("[WX]reply type not support: " + str(reply.type)) | ||||
reply.type = ReplyType.ERROR | reply.type = ReplyType.ERROR | ||||
@@ -196,59 +246,91 @@ class ChatChannel(Channel): | |||||
if reply.type == ReplyType.TEXT: | if reply.type == ReplyType.TEXT: | ||||
reply_text = reply.content | reply_text = reply.content | ||||
if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: | |||||
if ( | |||||
desire_rtype == ReplyType.VOICE | |||||
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE | |||||
): | |||||
reply = super().build_text_to_voice(reply.content) | reply = super().build_text_to_voice(reply.content) | ||||
return self._decorate_reply(context, reply) | return self._decorate_reply(context, reply) | ||||
if context.get("isgroup", False): | if context.get("isgroup", False): | ||||
reply_text = '@' + context['msg'].actual_user_nickname + ' ' + reply_text.strip() | |||||
reply_text = conf().get("group_chat_reply_prefix", "") + reply_text | |||||
reply_text = ( | |||||
"@" | |||||
+ context["msg"].actual_user_nickname | |||||
+ " " | |||||
+ reply_text.strip() | |||||
) | |||||
reply_text = ( | |||||
conf().get("group_chat_reply_prefix", "") + reply_text | |||||
) | |||||
else: | else: | ||||
reply_text = conf().get("single_chat_reply_prefix", "") + reply_text | |||||
reply_text = ( | |||||
conf().get("single_chat_reply_prefix", "") + reply_text | |||||
) | |||||
reply.content = reply_text | reply.content = reply_text | ||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: | elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: | ||||
reply.content = "["+str(reply.type)+"]\n" + reply.content | |||||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE: | |||||
reply.content = "[" + str(reply.type) + "]\n" + reply.content | |||||
elif ( | |||||
reply.type == ReplyType.IMAGE_URL | |||||
or reply.type == ReplyType.VOICE | |||||
or reply.type == ReplyType.IMAGE | |||||
): | |||||
pass | pass | ||||
else: | else: | ||||
logger.error('[WX] unknown reply type: {}'.format(reply.type)) | |||||
logger.error("[WX] unknown reply type: {}".format(reply.type)) | |||||
return | return | ||||
if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]: | |||||
logger.warning('[WX] desire_rtype: {}, but reply type: {}'.format(context.get('desire_rtype'), reply.type)) | |||||
if ( | |||||
desire_rtype | |||||
and desire_rtype != reply.type | |||||
and reply.type not in [ReplyType.ERROR, ReplyType.INFO] | |||||
): | |||||
logger.warning( | |||||
"[WX] desire_rtype: {}, but reply type: {}".format( | |||||
context.get("desire_rtype"), reply.type | |||||
) | |||||
) | |||||
return reply | return reply | ||||
def _send_reply(self, context: Context, reply: Reply): | def _send_reply(self, context: Context, reply: Reply): | ||||
if reply and reply.type: | if reply and reply.type: | ||||
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, { | |||||
'channel': self, 'context': context, 'reply': reply})) | |||||
reply = e_context['reply'] | |||||
e_context = PluginManager().emit_event( | |||||
EventContext( | |||||
Event.ON_SEND_REPLY, | |||||
{"channel": self, "context": context, "reply": reply}, | |||||
) | |||||
) | |||||
reply = e_context["reply"] | |||||
if not e_context.is_pass() and reply and reply.type: | if not e_context.is_pass() and reply and reply.type: | ||||
logger.debug('[WX] ready to send reply: {}, context: {}'.format(reply, context)) | |||||
logger.debug( | |||||
"[WX] ready to send reply: {}, context: {}".format(reply, context) | |||||
) | |||||
self._send(reply, context) | self._send(reply, context) | ||||
def _send(self, reply: Reply, context: Context, retry_cnt = 0): | |||||
def _send(self, reply: Reply, context: Context, retry_cnt=0): | |||||
try: | try: | ||||
self.send(reply, context) | self.send(reply, context) | ||||
except Exception as e: | except Exception as e: | ||||
logger.error('[WX] sendMsg error: {}'.format(str(e))) | |||||
logger.error("[WX] sendMsg error: {}".format(str(e))) | |||||
if isinstance(e, NotImplementedError): | if isinstance(e, NotImplementedError): | ||||
return | return | ||||
logger.exception(e) | logger.exception(e) | ||||
if retry_cnt < 2: | if retry_cnt < 2: | ||||
time.sleep(3+3*retry_cnt) | |||||
self._send(reply, context, retry_cnt+1) | |||||
time.sleep(3 + 3 * retry_cnt) | |||||
self._send(reply, context, retry_cnt + 1) | |||||
def _success_callback(self, session_id, **kwargs):# 线程正常结束时的回调函数 | |||||
def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数 | |||||
logger.debug("Worker return success, session_id = {}".format(session_id)) | logger.debug("Worker return success, session_id = {}".format(session_id)) | ||||
def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数 | |||||
def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数 | |||||
logger.exception("Worker return exception: {}".format(exception)) | logger.exception("Worker return exception: {}".format(exception)) | ||||
def _thread_pool_callback(self, session_id, **kwargs): | def _thread_pool_callback(self, session_id, **kwargs): | ||||
def func(worker:Future): | |||||
def func(worker: Future): | |||||
try: | try: | ||||
worker_exception = worker.exception() | worker_exception = worker.exception() | ||||
if worker_exception: | if worker_exception: | ||||
self._fail_callback(session_id, exception = worker_exception, **kwargs) | |||||
self._fail_callback( | |||||
session_id, exception=worker_exception, **kwargs | |||||
) | |||||
else: | else: | ||||
self._success_callback(session_id, **kwargs) | self._success_callback(session_id, **kwargs) | ||||
except CancelledError as e: | except CancelledError as e: | ||||
@@ -257,15 +339,19 @@ class ChatChannel(Channel): | |||||
logger.exception("Worker raise exception: {}".format(e)) | logger.exception("Worker raise exception: {}".format(e)) | ||||
with self.lock: | with self.lock: | ||||
self.sessions[session_id][1].release() | self.sessions[session_id][1].release() | ||||
return func | return func | ||||
def produce(self, context: Context): | def produce(self, context: Context): | ||||
session_id = context['session_id'] | |||||
session_id = context["session_id"] | |||||
with self.lock: | with self.lock: | ||||
if session_id not in self.sessions: | if session_id not in self.sessions: | ||||
self.sessions[session_id] = [Dequeue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 4))] | |||||
if context.type == ContextType.TEXT and context.content.startswith("#"): | |||||
self.sessions[session_id][0].putleft(context) # 优先处理管理命令 | |||||
self.sessions[session_id] = [ | |||||
Dequeue(), | |||||
threading.BoundedSemaphore(conf().get("concurrency_in_session", 4)), | |||||
] | |||||
if context.type == ContextType.TEXT and context.content.startswith("#"): | |||||
self.sessions[session_id][0].putleft(context) # 优先处理管理命令 | |||||
else: | else: | ||||
self.sessions[session_id][0].put(context) | self.sessions[session_id][0].put(context) | ||||
@@ -276,44 +362,58 @@ class ChatChannel(Channel): | |||||
session_ids = list(self.sessions.keys()) | session_ids = list(self.sessions.keys()) | ||||
for session_id in session_ids: | for session_id in session_ids: | ||||
context_queue, semaphore = self.sessions[session_id] | context_queue, semaphore = self.sessions[session_id] | ||||
if semaphore.acquire(blocking = False): # 等线程处理完毕才能删除 | |||||
if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除 | |||||
if not context_queue.empty(): | if not context_queue.empty(): | ||||
context = context_queue.get() | context = context_queue.get() | ||||
logger.debug("[WX] consume context: {}".format(context)) | logger.debug("[WX] consume context: {}".format(context)) | ||||
future:Future = self.handler_pool.submit(self._handle, context) | |||||
future.add_done_callback(self._thread_pool_callback(session_id, context = context)) | |||||
future: Future = self.handler_pool.submit( | |||||
self._handle, context | |||||
) | |||||
future.add_done_callback( | |||||
self._thread_pool_callback(session_id, context=context) | |||||
) | |||||
if session_id not in self.futures: | if session_id not in self.futures: | ||||
self.futures[session_id] = [] | self.futures[session_id] = [] | ||||
self.futures[session_id].append(future) | self.futures[session_id].append(future) | ||||
elif semaphore._initial_value == semaphore._value+1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕 | |||||
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()] | |||||
assert len(self.futures[session_id]) == 0, "thread pool error" | |||||
elif ( | |||||
semaphore._initial_value == semaphore._value + 1 | |||||
): # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕 | |||||
self.futures[session_id] = [ | |||||
t for t in self.futures[session_id] if not t.done() | |||||
] | |||||
assert ( | |||||
len(self.futures[session_id]) == 0 | |||||
), "thread pool error" | |||||
del self.sessions[session_id] | del self.sessions[session_id] | ||||
else: | else: | ||||
semaphore.release() | semaphore.release() | ||||
time.sleep(0.1) | time.sleep(0.1) | ||||
# 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务 | # 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务 | ||||
def cancel_session(self, session_id): | |||||
def cancel_session(self, session_id): | |||||
with self.lock: | with self.lock: | ||||
if session_id in self.sessions: | if session_id in self.sessions: | ||||
for future in self.futures[session_id]: | for future in self.futures[session_id]: | ||||
future.cancel() | future.cancel() | ||||
cnt = self.sessions[session_id][0].qsize() | cnt = self.sessions[session_id][0].qsize() | ||||
if cnt>0: | |||||
logger.info("Cancel {} messages in session {}".format(cnt, session_id)) | |||||
if cnt > 0: | |||||
logger.info( | |||||
"Cancel {} messages in session {}".format(cnt, session_id) | |||||
) | |||||
self.sessions[session_id][0] = Dequeue() | self.sessions[session_id][0] = Dequeue() | ||||
def cancel_all_session(self): | def cancel_all_session(self): | ||||
with self.lock: | with self.lock: | ||||
for session_id in self.sessions: | for session_id in self.sessions: | ||||
for future in self.futures[session_id]: | for future in self.futures[session_id]: | ||||
future.cancel() | future.cancel() | ||||
cnt = self.sessions[session_id][0].qsize() | cnt = self.sessions[session_id][0].qsize() | ||||
if cnt>0: | |||||
logger.info("Cancel {} messages in session {}".format(cnt, session_id)) | |||||
if cnt > 0: | |||||
logger.info( | |||||
"Cancel {} messages in session {}".format(cnt, session_id) | |||||
) | |||||
self.sessions[session_id][0] = Dequeue() | self.sessions[session_id][0] = Dequeue() | ||||
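`produce`/`consume` above bound concurrent work per `session_id` with a `Dequeue` plus a `BoundedSemaphore`, handing contexts to the thread pool and releasing the semaphore from the future's done-callback. A simplified, standalone illustration of the same pattern — not the class itself; `queue.Queue` stands in for `common.dequeue.Dequeue`, which additionally offers `putleft` for `#` admin commands:

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from queue import Queue   # stand-in for common.dequeue.Dequeue

pool = ThreadPoolExecutor(max_workers=8)
sessions = {}              # session_id -> (queue, semaphore), as in ChatChannel
lock = threading.Lock()

def produce(session_id, payload, limit=4):
    with lock:
        if session_id not in sessions:
            sessions[session_id] = (Queue(), threading.BoundedSemaphore(limit))
        sessions[session_id][0].put(payload)

def consume():
    while True:
        with lock:
            snapshot = list(sessions.items())
        for session_id, (q, sem) in snapshot:
            if sem.acquire(blocking=False):
                if not q.empty():
                    future = pool.submit(print, session_id, q.get())  # real code submits _handle
                    future.add_done_callback(lambda f, s=sem: s.release())
                else:
                    sem.release()   # nothing queued; the real class also prunes idle sessions here
        time.sleep(0.1)

threading.Thread(target=consume, daemon=True).start()
produce("user-1", "hello")
produce("user-1", "world")
time.sleep(1)   # give the worker time to drain the queue
```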
def check_prefix(content, prefix_list): | def check_prefix(content, prefix_list): | ||||
if not prefix_list: | if not prefix_list: | ||||
@@ -323,6 +423,7 @@ def check_prefix(content, prefix_list): | |||||
return prefix | return prefix | ||||
return None | return None | ||||
def check_contain(content, keyword_list): | def check_contain(content, keyword_list): | ||||
if not keyword_list: | if not keyword_list: | ||||
return None | return None | ||||
@@ -1,5 +1,4 @@ | |||||
""" | |||||
""" | |||||
本类表示聊天消息,用于对itchat和wechaty的消息进行统一的封装。 | 本类表示聊天消息,用于对itchat和wechaty的消息进行统一的封装。 | ||||
填好必填项(群聊6个,非群聊8个),即可接入ChatChannel,并支持插件,参考TerminalChannel | 填好必填项(群聊6个,非群聊8个),即可接入ChatChannel,并支持插件,参考TerminalChannel | ||||
@@ -20,7 +19,7 @@ other_user_id: 对方的id,如果你是发送者,那这个就是接收者id | |||||
other_user_nickname: 同上 | other_user_nickname: 同上 | ||||
is_group: 是否是群消息 (群聊必填) | is_group: 是否是群消息 (群聊必填) | ||||
is_at: 是否被at | |||||
is_at: 是否被at | |||||
- (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在) | - (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在) | ||||
actual_user_id: 实际发送者id (群聊必填) | actual_user_id: 实际发送者id (群聊必填) | ||||
@@ -34,20 +33,22 @@ _prepared: 是否已经调用过准备函数 | |||||
_rawmsg: 原始消息对象 | _rawmsg: 原始消息对象 | ||||
""" | """ | ||||
class ChatMessage(object): | class ChatMessage(object): | ||||
msg_id = None | msg_id = None | ||||
create_time = None | create_time = None | ||||
ctype = None | ctype = None | ||||
content = None | content = None | ||||
from_user_id = None | from_user_id = None | ||||
from_user_nickname = None | from_user_nickname = None | ||||
to_user_id = None | to_user_id = None | ||||
to_user_nickname = None | to_user_nickname = None | ||||
other_user_id = None | other_user_id = None | ||||
other_user_nickname = None | other_user_nickname = None | ||||
is_group = False | is_group = False | ||||
is_at = False | is_at = False | ||||
actual_user_id = None | actual_user_id = None | ||||
@@ -57,8 +58,7 @@ class ChatMessage(object): | |||||
_prepared = False | _prepared = False | ||||
_rawmsg = None | _rawmsg = None | ||||
def __init__(self,_rawmsg): | |||||
def __init__(self, _rawmsg): | |||||
self._rawmsg = _rawmsg | self._rawmsg = _rawmsg | ||||
def prepare(self): | def prepare(self): | ||||
@@ -67,7 +67,7 @@ class ChatMessage(object): | |||||
self._prepare_fn() | self._prepare_fn() | ||||
def __str__(self): | def __str__(self): | ||||
return 'ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}'.format( | |||||
return "ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}".format( | |||||
self.msg_id, | self.msg_id, | ||||
self.create_time, | self.create_time, | ||||
self.ctype, | self.ctype, | ||||
@@ -82,4 +82,4 @@ class ChatMessage(object): | |||||
self.is_at, | self.is_at, | ||||
self.actual_user_id, | self.actual_user_id, | ||||
self.actual_user_nickname, | self.actual_user_nickname, | ||||
) | |||||
) |
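Per the docstring above, filling in the required fields is enough to plug a new message source into `ChatChannel`; `TerminalMessage` in the next file is the in-tree reference. Below is a hedged sketch of a hypothetical group-message subclass; the exact set of mandatory fields is not spelled out in this excerpt, so the selection is an assumption based on what `_compose_context` and `_decorate_reply` read:

```python
import time

from bridge.context import ContextType
from channel.chat_message import ChatMessage


class FakeGroupMessage(ChatMessage):   # hypothetical, for illustration only
    def __init__(self, msg_id, content, actual_user_id, actual_user_nickname,
                 group_id="FakeGroupId", group_name="ChatGPT测试群"):
        super().__init__(None)                      # no raw message to wrap
        self.msg_id = msg_id
        self.create_time = int(time.time())
        self.ctype = ContextType.TEXT
        self.content = content
        self.from_user_id = actual_user_id          # sender inside the group
        self.other_user_id = group_id               # the group itself is the "other side"
        self.other_user_nickname = group_name       # matched against group_name_white_list
        self.is_group = True
        self.is_at = True
        self.actual_user_id = actual_user_id
        self.actual_user_nickname = actual_user_nickname   # used when @-prefixing the reply
```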
@@ -1,14 +1,23 @@ | |||||
import sys | |||||
from bridge.context import * | from bridge.context import * | ||||
from bridge.reply import Reply, ReplyType | from bridge.reply import Reply, ReplyType | ||||
from channel.chat_channel import ChatChannel, check_prefix | from channel.chat_channel import ChatChannel, check_prefix | ||||
from channel.chat_message import ChatMessage | from channel.chat_message import ChatMessage | ||||
import sys | |||||
from config import conf | |||||
from common.log import logger | from common.log import logger | ||||
from config import conf | |||||
class TerminalMessage(ChatMessage): | class TerminalMessage(ChatMessage): | ||||
def __init__(self, msg_id, content, ctype = ContextType.TEXT, from_user_id = "User", to_user_id = "Chatgpt", other_user_id = "Chatgpt"): | |||||
def __init__( | |||||
self, | |||||
msg_id, | |||||
content, | |||||
ctype=ContextType.TEXT, | |||||
from_user_id="User", | |||||
to_user_id="Chatgpt", | |||||
other_user_id="Chatgpt", | |||||
): | |||||
self.msg_id = msg_id | self.msg_id = msg_id | ||||
self.ctype = ctype | self.ctype = ctype | ||||
self.content = content | self.content = content | ||||
@@ -16,6 +25,7 @@ class TerminalMessage(ChatMessage): | |||||
self.to_user_id = to_user_id | self.to_user_id = to_user_id | ||||
self.other_user_id = other_user_id | self.other_user_id = other_user_id | ||||
class TerminalChannel(ChatChannel): | class TerminalChannel(ChatChannel): | ||||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE] | NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE] | ||||
@@ -23,14 +33,18 @@ class TerminalChannel(ChatChannel): | |||||
print("\nBot:") | print("\nBot:") | ||||
if reply.type == ReplyType.IMAGE: | if reply.type == ReplyType.IMAGE: | ||||
from PIL import Image | from PIL import Image | ||||
image_storage = reply.content | image_storage = reply.content | ||||
image_storage.seek(0) | image_storage.seek(0) | ||||
img = Image.open(image_storage) | img = Image.open(image_storage) | ||||
print("<IMAGE>") | print("<IMAGE>") | ||||
img.show() | img.show() | ||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 | |||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 | |||||
import io | |||||
import requests | |||||
from PIL import Image | from PIL import Image | ||||
import requests,io | |||||
img_url = reply.content | img_url = reply.content | ||||
pic_res = requests.get(img_url, stream=True) | pic_res = requests.get(img_url, stream=True) | ||||
image_storage = io.BytesIO() | image_storage = io.BytesIO() | ||||
@@ -59,11 +73,13 @@ class TerminalChannel(ChatChannel): | |||||
print("\nExiting...") | print("\nExiting...") | ||||
sys.exit() | sys.exit() | ||||
msg_id += 1 | msg_id += 1 | ||||
trigger_prefixs = conf().get("single_chat_prefix",[""]) | |||||
trigger_prefixs = conf().get("single_chat_prefix", [""]) | |||||
if check_prefix(prompt, trigger_prefixs) is None: | if check_prefix(prompt, trigger_prefixs) is None: | ||||
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀 | |||||
context = self._compose_context(ContextType.TEXT, prompt, msg = TerminalMessage(msg_id, prompt)) | |||||
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀 | |||||
context = self._compose_context( | |||||
ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt) | |||||
) | |||||
if context: | if context: | ||||
self.produce(context) | self.produce(context) | ||||
else: | else: | ||||
@@ -4,40 +4,45 @@ | |||||
wechat channel | wechat channel | ||||
""" | """ | ||||
import io | |||||
import json | |||||
import os | import os | ||||
import threading | import threading | ||||
import requests | |||||
import io | |||||
import time | import time | ||||
import json | |||||
import requests | |||||
from bridge.context import * | |||||
from bridge.reply import * | |||||
from channel.chat_channel import ChatChannel | from channel.chat_channel import ChatChannel | ||||
from channel.wechat.wechat_message import * | from channel.wechat.wechat_message import * | ||||
from common.singleton import singleton | |||||
from common.expired_dict import ExpiredDict | |||||
from common.log import logger | from common.log import logger | ||||
from common.singleton import singleton | |||||
from common.time_check import time_checker | |||||
from config import conf | |||||
from lib import itchat | from lib import itchat | ||||
from lib.itchat.content import * | from lib.itchat.content import * | ||||
from bridge.reply import * | |||||
from bridge.context import * | |||||
from config import conf | |||||
from common.time_check import time_checker | |||||
from common.expired_dict import ExpiredDict | |||||
from plugins import * | from plugins import * | ||||
@itchat.msg_register([TEXT,VOICE,PICTURE]) | |||||
@itchat.msg_register([TEXT, VOICE, PICTURE]) | |||||
def handler_single_msg(msg): | def handler_single_msg(msg): | ||||
# logger.debug("handler_single_msg: {}".format(msg)) | # logger.debug("handler_single_msg: {}".format(msg)) | ||||
if msg['Type'] == PICTURE and msg['MsgType'] == 47: | |||||
if msg["Type"] == PICTURE and msg["MsgType"] == 47: | |||||
return None | return None | ||||
WechatChannel().handle_single(WeChatMessage(msg)) | WechatChannel().handle_single(WeChatMessage(msg)) | ||||
return None | return None | ||||
@itchat.msg_register([TEXT,VOICE,PICTURE], isGroupChat=True) | |||||
@itchat.msg_register([TEXT, VOICE, PICTURE], isGroupChat=True) | |||||
def handler_group_msg(msg): | def handler_group_msg(msg): | ||||
if msg['Type'] == PICTURE and msg['MsgType'] == 47: | |||||
if msg["Type"] == PICTURE and msg["MsgType"] == 47: | |||||
return None | return None | ||||
WechatChannel().handle_group(WeChatMessage(msg,True)) | |||||
WechatChannel().handle_group(WeChatMessage(msg, True)) | |||||
return None | return None | ||||
def _check(func): | def _check(func): | ||||
def wrapper(self, cmsg: ChatMessage): | def wrapper(self, cmsg: ChatMessage): | ||||
msgId = cmsg.msg_id | msgId = cmsg.msg_id | ||||
@@ -45,21 +50,27 @@ def _check(func): | |||||
logger.info("Wechat message {} already received, ignore".format(msgId)) | logger.info("Wechat message {} already received, ignore".format(msgId)) | ||||
return | return | ||||
self.receivedMsgs[msgId] = cmsg | self.receivedMsgs[msgId] = cmsg | ||||
create_time = cmsg.create_time # 消息时间戳 | |||||
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息 | |||||
create_time = cmsg.create_time # 消息时间戳 | |||||
if ( | |||||
conf().get("hot_reload") == True | |||||
and int(create_time) < int(time.time()) - 60 | |||||
): # 跳过1分钟前的历史消息 | |||||
logger.debug("[WX]history message {} skipped".format(msgId)) | logger.debug("[WX]history message {} skipped".format(msgId)) | ||||
return | return | ||||
return func(self, cmsg) | return func(self, cmsg) | ||||
return wrapper | return wrapper | ||||
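The `_check` decorator above drops duplicate `msg_id`s via `ExpiredDict(60 * 60 * 24)` and, when `hot_reload` is on, skips messages older than 60 seconds. `ExpiredDict` itself is not part of this diff, so the stand-in below is an assumption (a plain dict in place of a TTL dict), kept only to show the two guards:

```python
import time

received = {}   # stand-in for ExpiredDict(60 * 60 * 24), assumed to behave like a TTL dict

def should_handle(msg_id, create_time, hot_reload=True):
    if msg_id in received:
        return False                                      # duplicate delivery, ignore
    received[msg_id] = True
    if hot_reload and int(create_time) < int(time.time()) - 60:
        return False                                      # stale message replayed after hot reload
    return True

print(should_handle("1001", time.time()))   # True
print(should_handle("1001", time.time()))   # False: already seen
```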
#可用的二维码生成接口 | |||||
#https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com | |||||
#https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com | |||||
def qrCallback(uuid,status,qrcode): | |||||
# 可用的二维码生成接口 | |||||
# https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com | |||||
# https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com | |||||
def qrCallback(uuid, status, qrcode): | |||||
# logger.debug("qrCallback: {} {}".format(uuid,status)) | # logger.debug("qrCallback: {} {}".format(uuid,status)) | ||||
if status == '0': | |||||
if status == "0": | |||||
try: | try: | ||||
from PIL import Image | from PIL import Image | ||||
img = Image.open(io.BytesIO(qrcode)) | img = Image.open(io.BytesIO(qrcode)) | ||||
_thread = threading.Thread(target=img.show, args=("QRCode",)) | _thread = threading.Thread(target=img.show, args=("QRCode",)) | ||||
_thread.setDaemon(True) | _thread.setDaemon(True) | ||||
@@ -68,35 +79,43 @@ def qrCallback(uuid,status,qrcode): | |||||
pass | pass | ||||
import qrcode | import qrcode | ||||
url = f"https://login.weixin.qq.com/l/{uuid}" | url = f"https://login.weixin.qq.com/l/{uuid}" | ||||
qr_api1="https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url) | |||||
qr_api2="https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url) | |||||
qr_api3="https://api.pwmqr.com/qrcode/create/?url={}".format(url) | |||||
qr_api4="https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url) | |||||
qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url) | |||||
qr_api2 = ( | |||||
"https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format( | |||||
url | |||||
) | |||||
) | |||||
qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url) | |||||
qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format( | |||||
url | |||||
) | |||||
print("You can also scan QRCode in any website below:") | print("You can also scan QRCode in any website below:") | ||||
print(qr_api3) | print(qr_api3) | ||||
print(qr_api4) | print(qr_api4) | ||||
print(qr_api2) | print(qr_api2) | ||||
print(qr_api1) | print(qr_api1) | ||||
qr = qrcode.QRCode(border=1) | qr = qrcode.QRCode(border=1) | ||||
qr.add_data(url) | qr.add_data(url) | ||||
qr.make(fit=True) | qr.make(fit=True) | ||||
qr.print_ascii(invert=True) | qr.print_ascii(invert=True) | ||||
@singleton | @singleton | ||||
class WechatChannel(ChatChannel): | class WechatChannel(ChatChannel): | ||||
NOT_SUPPORT_REPLYTYPE = [] | NOT_SUPPORT_REPLYTYPE = [] | ||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
self.receivedMsgs = ExpiredDict(60*60*24) | |||||
self.receivedMsgs = ExpiredDict(60 * 60 * 24) | |||||
def startup(self): | def startup(self): | ||||
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间 | itchat.instance.receivingRetryCount = 600 # 修改断线超时时间 | ||||
# login by scan QRCode | # login by scan QRCode | ||||
hotReload = conf().get('hot_reload', False) | |||||
hotReload = conf().get("hot_reload", False) | |||||
try: | try: | ||||
itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback) | itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback) | ||||
except Exception as e: | except Exception as e: | ||||
@@ -104,12 +123,18 @@ class WechatChannel(ChatChannel): | |||||
logger.error("Hot reload failed, try to login without hot reload") | logger.error("Hot reload failed, try to login without hot reload") | ||||
itchat.logout() | itchat.logout() | ||||
os.remove("itchat.pkl") | os.remove("itchat.pkl") | ||||
itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback) | |||||
itchat.auto_login( | |||||
enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback | |||||
) | |||||
else: | else: | ||||
raise e | raise e | ||||
self.user_id = itchat.instance.storageClass.userName | self.user_id = itchat.instance.storageClass.userName | ||||
self.name = itchat.instance.storageClass.nickName | self.name = itchat.instance.storageClass.nickName | ||||
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name)) | |||||
logger.info( | |||||
"Wechat login success, user_id: {}, nickname: {}".format( | |||||
self.user_id, self.name | |||||
) | |||||
) | |||||
# start message listener | # start message listener | ||||
itchat.run() | itchat.run() | ||||
@@ -127,24 +152,30 @@ class WechatChannel(ChatChannel): | |||||
@time_checker | @time_checker | ||||
@_check | @_check | ||||
def handle_single(self, cmsg : ChatMessage): | |||||
def handle_single(self, cmsg: ChatMessage): | |||||
if cmsg.ctype == ContextType.VOICE: | if cmsg.ctype == ContextType.VOICE: | ||||
if conf().get('speech_recognition') != True: | |||||
if conf().get("speech_recognition") != True: | |||||
return | return | ||||
logger.debug("[WX]receive voice msg: {}".format(cmsg.content)) | logger.debug("[WX]receive voice msg: {}".format(cmsg.content)) | ||||
elif cmsg.ctype == ContextType.IMAGE: | elif cmsg.ctype == ContextType.IMAGE: | ||||
logger.debug("[WX]receive image msg: {}".format(cmsg.content)) | logger.debug("[WX]receive image msg: {}".format(cmsg.content)) | ||||
else: | else: | ||||
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg)) | |||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg) | |||||
logger.debug( | |||||
"[WX]receive text msg: {}, cmsg={}".format( | |||||
json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg | |||||
) | |||||
) | |||||
context = self._compose_context( | |||||
cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg | |||||
) | |||||
if context: | if context: | ||||
self.produce(context) | self.produce(context) | ||||
@time_checker | @time_checker | ||||
@_check | @_check | ||||
def handle_group(self, cmsg : ChatMessage): | |||||
def handle_group(self, cmsg: ChatMessage): | |||||
if cmsg.ctype == ContextType.VOICE: | if cmsg.ctype == ContextType.VOICE: | ||||
if conf().get('speech_recognition') != True: | |||||
if conf().get("speech_recognition") != True: | |||||
return | return | ||||
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content)) | logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content)) | ||||
elif cmsg.ctype == ContextType.IMAGE: | elif cmsg.ctype == ContextType.IMAGE: | ||||
@@ -152,23 +183,25 @@ class WechatChannel(ChatChannel): | |||||
else: | else: | ||||
# logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg)) | # logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg)) | ||||
pass | pass | ||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg) | |||||
context = self._compose_context( | |||||
cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg | |||||
) | |||||
if context: | if context: | ||||
self.produce(context) | self.produce(context) | ||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 | # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 | ||||
def send(self, reply: Reply, context: Context): | def send(self, reply: Reply, context: Context): | ||||
receiver = context["receiver"] | receiver = context["receiver"] | ||||
if reply.type == ReplyType.TEXT: | if reply.type == ReplyType.TEXT: | ||||
itchat.send(reply.content, toUserName=receiver) | itchat.send(reply.content, toUserName=receiver) | ||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) | |||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver)) | |||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: | elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: | ||||
itchat.send(reply.content, toUserName=receiver) | itchat.send(reply.content, toUserName=receiver) | ||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) | |||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver)) | |||||
elif reply.type == ReplyType.VOICE: | elif reply.type == ReplyType.VOICE: | ||||
itchat.send_file(reply.content, toUserName=receiver) | itchat.send_file(reply.content, toUserName=receiver) | ||||
logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver)) | |||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 | |||||
logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver)) | |||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 | |||||
img_url = reply.content | img_url = reply.content | ||||
pic_res = requests.get(img_url, stream=True) | pic_res = requests.get(img_url, stream=True) | ||||
image_storage = io.BytesIO() | image_storage = io.BytesIO() | ||||
@@ -176,9 +209,9 @@ class WechatChannel(ChatChannel): | |||||
image_storage.write(block) | image_storage.write(block) | ||||
image_storage.seek(0) | image_storage.seek(0) | ||||
itchat.send_image(image_storage, toUserName=receiver) | itchat.send_image(image_storage, toUserName=receiver) | ||||
logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver)) | |||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片 | |||||
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver)) | |||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片 | |||||
image_storage = reply.content | image_storage = reply.content | ||||
image_storage.seek(0) | image_storage.seek(0) | ||||
itchat.send_image(image_storage, toUserName=receiver) | itchat.send_image(image_storage, toUserName=receiver) | ||||
logger.info('[WX] sendImage, receiver={}'.format(receiver)) | |||||
logger.info("[WX] sendImage, receiver={}".format(receiver)) |
@@ -1,54 +1,54 @@ | |||||
from bridge.context import ContextType | from bridge.context import ContextType | ||||
from channel.chat_message import ChatMessage | from channel.chat_message import ChatMessage | ||||
from common.tmp_dir import TmpDir | |||||
from common.log import logger | from common.log import logger | ||||
from lib.itchat.content import * | |||||
from common.tmp_dir import TmpDir | |||||
from lib import itchat | from lib import itchat | ||||
from lib.itchat.content import * | |||||
class WeChatMessage(ChatMessage): | |||||
class WeChatMessage(ChatMessage): | |||||
def __init__(self, itchat_msg, is_group=False): | def __init__(self, itchat_msg, is_group=False): | ||||
super().__init__( itchat_msg) | |||||
self.msg_id = itchat_msg['MsgId'] | |||||
self.create_time = itchat_msg['CreateTime'] | |||||
super().__init__(itchat_msg) | |||||
self.msg_id = itchat_msg["MsgId"] | |||||
self.create_time = itchat_msg["CreateTime"] | |||||
self.is_group = is_group | self.is_group = is_group | ||||
if itchat_msg['Type'] == TEXT: | |||||
if itchat_msg["Type"] == TEXT: | |||||
self.ctype = ContextType.TEXT | self.ctype = ContextType.TEXT | ||||
self.content = itchat_msg['Text'] | |||||
elif itchat_msg['Type'] == VOICE: | |||||
self.content = itchat_msg["Text"] | |||||
elif itchat_msg["Type"] == VOICE: | |||||
self.ctype = ContextType.VOICE | self.ctype = ContextType.VOICE | ||||
self.content = TmpDir().path() + itchat_msg['FileName'] # content直接存临时目录路径 | |||||
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径 | |||||
self._prepare_fn = lambda: itchat_msg.download(self.content) | self._prepare_fn = lambda: itchat_msg.download(self.content) | ||||
elif itchat_msg['Type'] == PICTURE and itchat_msg['MsgType'] == 3: | |||||
elif itchat_msg["Type"] == PICTURE and itchat_msg["MsgType"] == 3: | |||||
self.ctype = ContextType.IMAGE | self.ctype = ContextType.IMAGE | ||||
self.content = TmpDir().path() + itchat_msg['FileName'] # content直接存临时目录路径 | |||||
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径 | |||||
self._prepare_fn = lambda: itchat_msg.download(self.content) | self._prepare_fn = lambda: itchat_msg.download(self.content) | ||||
else: | else: | ||||
raise NotImplementedError("Unsupported message type: {}".format(itchat_msg['Type'])) | |||||
self.from_user_id = itchat_msg['FromUserName'] | |||||
self.to_user_id = itchat_msg['ToUserName'] | |||||
raise NotImplementedError( | |||||
"Unsupported message type: {}".format(itchat_msg["Type"]) | |||||
) | |||||
self.from_user_id = itchat_msg["FromUserName"] | |||||
self.to_user_id = itchat_msg["ToUserName"] | |||||
user_id = itchat.instance.storageClass.userName | user_id = itchat.instance.storageClass.userName | ||||
nickname = itchat.instance.storageClass.nickName | nickname = itchat.instance.storageClass.nickName | ||||
# 虽然from_user_id和to_user_id用的少,但是为了保持一致性,还是要填充一下 | # 虽然from_user_id和to_user_id用的少,但是为了保持一致性,还是要填充一下 | ||||
# 以下很繁琐,一句话总结:能填的都填了。 | # 以下很繁琐,一句话总结:能填的都填了。 | ||||
if self.from_user_id == user_id: | if self.from_user_id == user_id: | ||||
self.from_user_nickname = nickname | self.from_user_nickname = nickname | ||||
if self.to_user_id == user_id: | if self.to_user_id == user_id: | ||||
self.to_user_nickname = nickname | self.to_user_nickname = nickname | ||||
try: # 陌生人时候, 'User'字段可能不存在 | |||||
self.other_user_id = itchat_msg['User']['UserName'] | |||||
self.other_user_nickname = itchat_msg['User']['NickName'] | |||||
try: # 陌生人时候, 'User'字段可能不存在 | |||||
self.other_user_id = itchat_msg["User"]["UserName"] | |||||
self.other_user_nickname = itchat_msg["User"]["NickName"] | |||||
if self.other_user_id == self.from_user_id: | if self.other_user_id == self.from_user_id: | ||||
self.from_user_nickname = self.other_user_nickname | self.from_user_nickname = self.other_user_nickname | ||||
if self.other_user_id == self.to_user_id: | if self.other_user_id == self.to_user_id: | ||||
self.to_user_nickname = self.other_user_nickname | self.to_user_nickname = self.other_user_nickname | ||||
except KeyError as e: # 处理偶尔没有对方信息的情况 | |||||
except KeyError as e: # 处理偶尔没有对方信息的情况 | |||||
logger.warn("[WX]get other_user_id failed: " + str(e)) | logger.warn("[WX]get other_user_id failed: " + str(e)) | ||||
if self.from_user_id == user_id: | if self.from_user_id == user_id: | ||||
self.other_user_id = self.to_user_id | self.other_user_id = self.to_user_id | ||||
@@ -56,6 +56,6 @@ class WeChatMessage(ChatMessage): | |||||
self.other_user_id = self.from_user_id | self.other_user_id = self.from_user_id | ||||
if self.is_group: | if self.is_group: | ||||
self.is_at = itchat_msg['IsAt'] | |||||
self.actual_user_id = itchat_msg['ActualUserName'] | |||||
self.actual_user_nickname = itchat_msg['ActualNickName'] | |||||
self.is_at = itchat_msg["IsAt"] | |||||
self.actual_user_id = itchat_msg["ActualUserName"] | |||||
self.actual_user_nickname = itchat_msg["ActualNickName"] |
@@ -4,104 +4,118 @@ | |||||
wechaty channel | wechaty channel | ||||
Python Wechaty - https://github.com/wechaty/python-wechaty | Python Wechaty - https://github.com/wechaty/python-wechaty | ||||
""" | """ | ||||
import asyncio | |||||
import base64 | import base64 | ||||
import os | import os | ||||
import time | import time | ||||
import asyncio | |||||
from bridge.context import Context | |||||
from wechaty_puppet import FileBox | |||||
from wechaty import Wechaty, Contact | |||||
from wechaty import Contact, Wechaty | |||||
from wechaty.user import Message | from wechaty.user import Message | ||||
from bridge.reply import * | |||||
from wechaty_puppet import FileBox | |||||
from bridge.context import * | from bridge.context import * | ||||
from bridge.context import Context | |||||
from bridge.reply import * | |||||
from channel.chat_channel import ChatChannel | from channel.chat_channel import ChatChannel | ||||
from channel.wechat.wechaty_message import WechatyMessage | from channel.wechat.wechaty_message import WechatyMessage | ||||
from common.log import logger | from common.log import logger | ||||
from common.singleton import singleton | from common.singleton import singleton | ||||
from config import conf | from config import conf | ||||
try: | try: | ||||
from voice.audio_convert import any_to_sil | from voice.audio_convert import any_to_sil | ||||
except Exception as e: | except Exception as e: | ||||
pass | pass | ||||
@singleton | @singleton | ||||
class WechatyChannel(ChatChannel): | class WechatyChannel(ChatChannel): | ||||
NOT_SUPPORT_REPLYTYPE = [] | NOT_SUPPORT_REPLYTYPE = [] | ||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
def startup(self): | def startup(self): | ||||
config = conf() | config = conf() | ||||
token = config.get('wechaty_puppet_service_token') | |||||
os.environ['WECHATY_PUPPET_SERVICE_TOKEN'] = token | |||||
token = config.get("wechaty_puppet_service_token") | |||||
os.environ["WECHATY_PUPPET_SERVICE_TOKEN"] = token | |||||
asyncio.run(self.main()) | asyncio.run(self.main()) | ||||
async def main(self): | async def main(self): | ||||
loop = asyncio.get_event_loop() | loop = asyncio.get_event_loop() | ||||
#将asyncio的loop传入处理线程 | |||||
self.handler_pool._initializer= lambda: asyncio.set_event_loop(loop) | |||||
# 将asyncio的loop传入处理线程 | |||||
self.handler_pool._initializer = lambda: asyncio.set_event_loop(loop) | |||||
self.bot = Wechaty() | self.bot = Wechaty() | ||||
self.bot.on('login', self.on_login) | |||||
self.bot.on('message', self.on_message) | |||||
self.bot.on("login", self.on_login) | |||||
self.bot.on("message", self.on_message) | |||||
await self.bot.start() | await self.bot.start() | ||||
async def on_login(self, contact: Contact): | async def on_login(self, contact: Contact): | ||||
self.user_id = contact.contact_id | self.user_id = contact.contact_id | ||||
self.name = contact.name | self.name = contact.name | ||||
logger.info('[WX] login user={}'.format(contact)) | |||||
logger.info("[WX] login user={}".format(contact)) | |||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 | # 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 | ||||
def send(self, reply: Reply, context: Context): | def send(self, reply: Reply, context: Context): | ||||
receiver_id = context['receiver'] | |||||
receiver_id = context["receiver"] | |||||
loop = asyncio.get_event_loop() | loop = asyncio.get_event_loop() | ||||
if context['isgroup']: | |||||
receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id),loop).result() | |||||
if context["isgroup"]: | |||||
receiver = asyncio.run_coroutine_threadsafe( | |||||
self.bot.Room.find(receiver_id), loop | |||||
).result() | |||||
else: | else: | ||||
receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id),loop).result() | |||||
receiver = asyncio.run_coroutine_threadsafe( | |||||
self.bot.Contact.find(receiver_id), loop | |||||
).result() | |||||
msg = None | msg = None | ||||
if reply.type == ReplyType.TEXT: | if reply.type == ReplyType.TEXT: | ||||
msg = reply.content | msg = reply.content | ||||
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result() | |||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) | |||||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result() | |||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver)) | |||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: | elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: | ||||
msg = reply.content | msg = reply.content | ||||
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result() | |||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) | |||||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result() | |||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver)) | |||||
elif reply.type == ReplyType.VOICE: | elif reply.type == ReplyType.VOICE: | ||||
voiceLength = None | voiceLength = None | ||||
file_path = reply.content | file_path = reply.content | ||||
sil_file = os.path.splitext(file_path)[0] + '.sil' | |||||
sil_file = os.path.splitext(file_path)[0] + ".sil" | |||||
voiceLength = int(any_to_sil(file_path, sil_file)) | voiceLength = int(any_to_sil(file_path, sil_file)) | ||||
if voiceLength >= 60000: | if voiceLength >= 60000: | ||||
voiceLength = 60000 | voiceLength = 60000 | ||||
logger.info('[WX] voice too long, length={}, set to 60s'.format(voiceLength)) | |||||
logger.info( | |||||
"[WX] voice too long, length={}, set to 60s".format(voiceLength) | |||||
) | |||||
# 发送语音 | # 发送语音 | ||||
t = int(time.time()) | t = int(time.time()) | ||||
msg = FileBox.from_file(sil_file, name=str(t) + '.sil') | |||||
msg = FileBox.from_file(sil_file, name=str(t) + ".sil") | |||||
if voiceLength is not None: | if voiceLength is not None: | ||||
msg.metadata['voiceLength'] = voiceLength | |||||
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result() | |||||
msg.metadata["voiceLength"] = voiceLength | |||||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result() | |||||
try: | try: | ||||
os.remove(file_path) | os.remove(file_path) | ||||
if sil_file != file_path: | if sil_file != file_path: | ||||
os.remove(sil_file) | os.remove(sil_file) | ||||
except Exception as e: | except Exception as e: | ||||
pass | pass | ||||
logger.info('[WX] sendVoice={}, receiver={}'.format(reply.content, receiver)) | |||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 | |||||
logger.info( | |||||
"[WX] sendVoice={}, receiver={}".format(reply.content, receiver) | |||||
) | |||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 | |||||
img_url = reply.content | img_url = reply.content | ||||
t = int(time.time()) | t = int(time.time()) | ||||
msg = FileBox.from_url(url=img_url, name=str(t) + '.png') | |||||
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result() | |||||
logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver)) | |||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片 | |||||
msg = FileBox.from_url(url=img_url, name=str(t) + ".png") | |||||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result() | |||||
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver)) | |||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片 | |||||
image_storage = reply.content | image_storage = reply.content | ||||
image_storage.seek(0) | image_storage.seek(0) | ||||
t = int(time.time()) | t = int(time.time()) | ||||
msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + '.png') | |||||
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result() | |||||
logger.info('[WX] sendImage, receiver={}'.format(receiver)) | |||||
msg = FileBox.from_base64( | |||||
base64.b64encode(image_storage.read()), str(t) + ".png" | |||||
) | |||||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result() | |||||
logger.info("[WX] sendImage, receiver={}".format(receiver)) | |||||
async def on_message(self, msg: Message): | async def on_message(self, msg: Message): | ||||
""" | """ | ||||
@@ -110,16 +124,16 @@ class WechatyChannel(ChatChannel): | |||||
try: | try: | ||||
cmsg = await WechatyMessage(msg) | cmsg = await WechatyMessage(msg) | ||||
except NotImplementedError as e: | except NotImplementedError as e: | ||||
logger.debug('[WX] {}'.format(e)) | |||||
logger.debug("[WX] {}".format(e)) | |||||
return | return | ||||
except Exception as e: | except Exception as e: | ||||
logger.exception('[WX] {}'.format(e)) | |||||
logger.exception("[WX] {}".format(e)) | |||||
return | return | ||||
logger.debug('[WX] message:{}'.format(cmsg)) | |||||
logger.debug("[WX] message:{}".format(cmsg)) | |||||
room = msg.room() # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None | room = msg.room() # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None | ||||
isgroup = room is not None | isgroup = room is not None | ||||
ctype = cmsg.ctype | ctype = cmsg.ctype | ||||
context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg) | context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg) | ||||
if context: | if context: | ||||
logger.info('[WX] receiveMsg={}, context={}'.format(cmsg, context)) | |||||
self.produce(context) | |||||
logger.info("[WX] receiveMsg={}, context={}".format(cmsg, context)) | |||||
self.produce(context) |
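
上面 WechatyChannel 的改动里,send() 运行在同步的处理线程中,而 wechaty 的接口是异步的,因此统一用 `asyncio.run_coroutine_threadsafe(..., loop).result()` 把协程提交到 main() 所在的事件循环并阻塞等待结果。下面是该模式的一个独立示意(`say`、`start_loop` 等名称均为演示用的假设,并非仓库中的实际实现):

```python
import asyncio
import threading


async def say(receiver: str, msg: str) -> str:
    # 模拟 wechaty 中 receiver.say(msg) 这类异步发送接口
    await asyncio.sleep(0.1)
    return "sent {} to {}".format(msg, receiver)


def start_loop() -> asyncio.AbstractEventLoop:
    # 在后台线程常驻一个事件循环,对应 WechatyChannel.main() 所在的线程
    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()
    return loop


if __name__ == "__main__":
    loop = start_loop()
    # 工作线程侧的写法:把协程提交到事件循环并同步等待结果,
    # 与 send() 中 asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result() 的用法一致
    future = asyncio.run_coroutine_threadsafe(say("filehelper", "hello"), loop)
    print(future.result(timeout=5))
    loop.call_soon_threadsafe(loop.stop)
```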
@@ -1,17 +1,21 @@ | |||||
import asyncio | import asyncio | ||||
import re | import re | ||||
from wechaty import MessageType | from wechaty import MessageType | ||||
from wechaty.user import Message | |||||
from bridge.context import ContextType | from bridge.context import ContextType | ||||
from channel.chat_message import ChatMessage | from channel.chat_message import ChatMessage | ||||
from common.tmp_dir import TmpDir | |||||
from common.log import logger | from common.log import logger | ||||
from wechaty.user import Message | |||||
from common.tmp_dir import TmpDir | |||||
class aobject(object): | class aobject(object): | ||||
"""Inheriting this class allows you to define an async __init__. | """Inheriting this class allows you to define an async __init__. | ||||
So you can create objects by doing something like `await MyClass(params)` | So you can create objects by doing something like `await MyClass(params)` | ||||
""" | """ | ||||
async def __new__(cls, *a, **kw): | async def __new__(cls, *a, **kw): | ||||
instance = super().__new__(cls) | instance = super().__new__(cls) | ||||
await instance.__init__(*a, **kw) | await instance.__init__(*a, **kw) | ||||
@@ -19,17 +23,18 @@ class aobject(object): | |||||
async def __init__(self): | async def __init__(self): | ||||
pass | pass | ||||
class WechatyMessage(ChatMessage, aobject): | |||||
class WechatyMessage(ChatMessage, aobject): | |||||
async def __init__(self, wechaty_msg: Message): | async def __init__(self, wechaty_msg: Message): | ||||
super().__init__(wechaty_msg) | super().__init__(wechaty_msg) | ||||
room = wechaty_msg.room() | room = wechaty_msg.room() | ||||
self.msg_id = wechaty_msg.message_id | self.msg_id = wechaty_msg.message_id | ||||
self.create_time = wechaty_msg.payload.timestamp | self.create_time = wechaty_msg.payload.timestamp | ||||
self.is_group = room is not None | self.is_group = room is not None | ||||
if wechaty_msg.type() == MessageType.MESSAGE_TYPE_TEXT: | if wechaty_msg.type() == MessageType.MESSAGE_TYPE_TEXT: | ||||
self.ctype = ContextType.TEXT | self.ctype = ContextType.TEXT | ||||
self.content = wechaty_msg.text() | self.content = wechaty_msg.text() | ||||
@@ -40,12 +45,17 @@ class WechatyMessage(ChatMessage, aobject): | |||||
def func(): | def func(): | ||||
loop = asyncio.get_event_loop() | loop = asyncio.get_event_loop() | ||||
asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content),loop).result() | |||||
asyncio.run_coroutine_threadsafe( | |||||
voice_file.to_file(self.content), loop | |||||
).result() | |||||
self._prepare_fn = func | self._prepare_fn = func | ||||
else: | else: | ||||
raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type())) | |||||
raise NotImplementedError( | |||||
"Unsupported message type: {}".format(wechaty_msg.type()) | |||||
) | |||||
from_contact = wechaty_msg.talker() # 获取消息的发送者 | from_contact = wechaty_msg.talker() # 获取消息的发送者 | ||||
self.from_user_id = from_contact.contact_id | self.from_user_id = from_contact.contact_id | ||||
self.from_user_nickname = from_contact.name | self.from_user_nickname = from_contact.name | ||||
@@ -54,7 +64,7 @@ class WechatyMessage(ChatMessage, aobject): | |||||
# wechaty: from是消息实际发送者, to:所在群 | # wechaty: from是消息实际发送者, to:所在群 | ||||
# itchat: 如果是你发送群消息,from和to是你自己和所在群,如果是别人发群消息,from和to是所在群和你自己 | # itchat: 如果是你发送群消息,from和to是你自己和所在群,如果是别人发群消息,from和to是所在群和你自己 | ||||
# 但这个差别不影响逻辑,group中只使用到:1.用from来判断是否是自己发的,2.actual_user_id来判断实际发送用户 | # 但这个差别不影响逻辑,group中只使用到:1.用from来判断是否是自己发的,2.actual_user_id来判断实际发送用户 | ||||
if self.is_group: | if self.is_group: | ||||
self.to_user_id = room.room_id | self.to_user_id = room.room_id | ||||
self.to_user_nickname = await room.topic() | self.to_user_nickname = await room.topic() | ||||
@@ -63,22 +73,22 @@ class WechatyMessage(ChatMessage, aobject): | |||||
self.to_user_id = to_contact.contact_id | self.to_user_id = to_contact.contact_id | ||||
self.to_user_nickname = to_contact.name | self.to_user_nickname = to_contact.name | ||||
if self.is_group or wechaty_msg.is_self(): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。 | |||||
if ( | |||||
self.is_group or wechaty_msg.is_self() | |||||
): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。 | |||||
self.other_user_id = self.to_user_id | self.other_user_id = self.to_user_id | ||||
self.other_user_nickname = self.to_user_nickname | self.other_user_nickname = self.to_user_nickname | ||||
else: | else: | ||||
self.other_user_id = self.from_user_id | self.other_user_id = self.from_user_id | ||||
self.other_user_nickname = self.from_user_nickname | self.other_user_nickname = self.from_user_nickname | ||||
if self.is_group: # wechaty群聊中,实际发送用户就是from_user | |||||
if self.is_group: # wechaty群聊中,实际发送用户就是from_user | |||||
self.is_at = await wechaty_msg.mention_self() | self.is_at = await wechaty_msg.mention_self() | ||||
if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容 | |||||
if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容 | |||||
name = wechaty_msg.wechaty.user_self().name | name = wechaty_msg.wechaty.user_self().name | ||||
pattern = f'@{name}(\u2005|\u0020)' | |||||
if re.search(pattern,self.content): | |||||
logger.debug(f'wechaty message {self.msg_id} include at') | |||||
pattern = f"@{name}(\u2005|\u0020)" | |||||
if re.search(pattern, self.content): | |||||
logger.debug(f"wechaty message {self.msg_id} include at") | |||||
self.is_at = True | self.is_at = True | ||||
self.actual_user_id = self.from_user_id | self.actual_user_id = self.from_user_id | ||||
@@ -21,12 +21,12 @@ pip3 install web.py | |||||
相关的服务器验证代码已经写好,你不需要再添加任何代码。你只需要在本项目根目录的`config.json`中添加 | 相关的服务器验证代码已经写好,你不需要再添加任何代码。你只需要在本项目根目录的`config.json`中添加 | ||||
``` | ``` | ||||
"channel_type": "wechatmp", | |||||
"channel_type": "wechatmp", | |||||
"wechatmp_token": "Token", # 微信公众平台的Token | "wechatmp_token": "Token", # 微信公众平台的Token | ||||
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443 | "wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443 | ||||
"wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要 | "wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要 | ||||
"wechatmp_app_secret": "", # 微信公众平台的appsecret,仅服务号需要 | "wechatmp_app_secret": "", # 微信公众平台的appsecret,仅服务号需要 | ||||
``` | |||||
``` | |||||
然后运行`python3 app.py`启动web服务器。这里会默认监听8080端口,但是微信公众号的服务器配置只支持80/443端口,有两种方法来解决这个问题。第一个是推荐的方法,使用端口转发命令将80端口转发到8080端口(443同理,注意需要支持SSL,也就是https的访问,在`wechatmp_channel.py`需要修改相应的证书路径): | 然后运行`python3 app.py`启动web服务器。这里会默认监听8080端口,但是微信公众号的服务器配置只支持80/443端口,有两种方法来解决这个问题。第一个是推荐的方法,使用端口转发命令将80端口转发到8080端口(443同理,注意需要支持SSL,也就是https的访问,在`wechatmp_channel.py`需要修改相应的证书路径): | ||||
``` | ``` | ||||
sudo iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to-port 8080 | sudo iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to-port 8080 | ||||
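
若选择 443 端口,除了相应的端口转发外,还需要让 web 服务支持 SSL。`wechatmp_channel.py` 顶部保留了一段被注释掉的示例,启用方式大致如下(基于 cheroot 的 `BuiltinSSLAdapter`;证书路径 `/ssl/cert.pem`、`/ssl/cert.key` 仅为占位示例,请替换为自己的证书,具体以仓库代码中的注释为准):

```python
# 放在 wechatmp_channel.py 顶部、启动 web.httpserver 之前
from cheroot.server import HTTPServer
from cheroot.ssl.builtin import BuiltinSSLAdapter

HTTPServer.ssl_adapter = BuiltinSSLAdapter(
    certificate="/ssl/cert.pem",  # 证书文件,占位示例
    private_key="/ssl/cert.key",  # 私钥文件,占位示例
)
```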
@@ -1,46 +1,66 @@ | |||||
import web | |||||
import time | import time | ||||
import channel.wechatmp.reply as reply | |||||
import web | |||||
import channel.wechatmp.receive as receive | import channel.wechatmp.receive as receive | ||||
from config import conf | |||||
from common.log import logger | |||||
import channel.wechatmp.reply as reply | |||||
from bridge.context import * | from bridge.context import * | ||||
from channel.wechatmp.common import * | |||||
from channel.wechatmp.common import * | |||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel | from channel.wechatmp.wechatmp_channel import WechatMPChannel | ||||
from common.log import logger | |||||
from config import conf | |||||
# This class is instantiated once per query | |||||
class Query(): | |||||
# This class is instantiated once per query | |||||
class Query: | |||||
def GET(self): | def GET(self): | ||||
return verify_server(web.input()) | return verify_server(web.input()) | ||||
def POST(self): | def POST(self): | ||||
# Make sure to return the instance that first created, @singleton will do that. | |||||
# Make sure to return the instance that first created, @singleton will do that. | |||||
channel = WechatMPChannel() | channel = WechatMPChannel() | ||||
try: | try: | ||||
webData = web.data() | webData = web.data() | ||||
# logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8")) | # logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8")) | ||||
wechatmp_msg = receive.parse_xml(webData) | wechatmp_msg = receive.parse_xml(webData) | ||||
if wechatmp_msg.msg_type == 'text' or wechatmp_msg.msg_type == 'voice': | |||||
if wechatmp_msg.msg_type == "text" or wechatmp_msg.msg_type == "voice": | |||||
from_user = wechatmp_msg.from_user_id | from_user = wechatmp_msg.from_user_id | ||||
message = wechatmp_msg.content.decode("utf-8") | message = wechatmp_msg.content.decode("utf-8") | ||||
message_id = wechatmp_msg.msg_id | message_id = wechatmp_msg.msg_id | ||||
logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message)) | |||||
context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg) | |||||
logger.info( | |||||
"[wechatmp] {}:{} Receive post query {} {}: {}".format( | |||||
web.ctx.env.get("REMOTE_ADDR"), | |||||
web.ctx.env.get("REMOTE_PORT"), | |||||
from_user, | |||||
message_id, | |||||
message, | |||||
) | |||||
) | |||||
context = channel._compose_context( | |||||
ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg | |||||
) | |||||
if context: | if context: | ||||
# set private openai_api_key | # set private openai_api_key | ||||
# if from_user is not changed in itchat, this can be placed at chat_channel | # if from_user is not changed in itchat, this can be placed at chat_channel | ||||
user_data = conf().get_user_data(from_user) | user_data = conf().get_user_data(from_user) | ||||
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key | |||||
context["openai_api_key"] = user_data.get( | |||||
"openai_api_key" | |||||
) # None or user openai_api_key | |||||
channel.produce(context) | channel.produce(context) | ||||
# The reply will be sent by channel.send() in another thread | # The reply will be sent by channel.send() in another thread | ||||
return "success" | return "success" | ||||
elif wechatmp_msg.msg_type == 'event': | |||||
logger.info("[wechatmp] Event {} from {}".format(wechatmp_msg.Event, wechatmp_msg.from_user_id)) | |||||
elif wechatmp_msg.msg_type == "event": | |||||
logger.info( | |||||
"[wechatmp] Event {} from {}".format( | |||||
wechatmp_msg.Event, wechatmp_msg.from_user_id | |||||
) | |||||
) | |||||
content = subscribe_msg() | content = subscribe_msg() | ||||
replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content) | |||||
replyMsg = reply.TextMsg( | |||||
wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content | |||||
) | |||||
return replyMsg.send() | return replyMsg.send() | ||||
else: | else: | ||||
logger.info("暂且不处理") | logger.info("暂且不处理") | ||||
@@ -48,4 +68,3 @@ class Query(): | |||||
except Exception as exc: | except Exception as exc: | ||||
logger.exception(exc) | logger.exception(exc) | ||||
return exc | return exc | ||||
@@ -1,81 +1,117 @@ | |||||
import web | |||||
import time | import time | ||||
import channel.wechatmp.reply as reply | |||||
import web | |||||
import channel.wechatmp.receive as receive | import channel.wechatmp.receive as receive | ||||
from config import conf | |||||
from common.log import logger | |||||
import channel.wechatmp.reply as reply | |||||
from bridge.context import * | from bridge.context import * | ||||
from channel.wechatmp.common import * | |||||
from channel.wechatmp.common import * | |||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel | from channel.wechatmp.wechatmp_channel import WechatMPChannel | ||||
from common.log import logger | |||||
from config import conf | |||||
# This class is instantiated once per query | |||||
class Query(): | |||||
# This class is instantiated once per query | |||||
class Query: | |||||
def GET(self): | def GET(self): | ||||
return verify_server(web.input()) | return verify_server(web.input()) | ||||
def POST(self): | def POST(self): | ||||
# Make sure to return the instance that first created, @singleton will do that. | |||||
# Make sure to return the instance that first created, @singleton will do that. | |||||
channel = WechatMPChannel() | channel = WechatMPChannel() | ||||
try: | try: | ||||
query_time = time.time() | query_time = time.time() | ||||
webData = web.data() | webData = web.data() | ||||
logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8")) | logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8")) | ||||
wechatmp_msg = receive.parse_xml(webData) | wechatmp_msg = receive.parse_xml(webData) | ||||
if wechatmp_msg.msg_type == 'text' or wechatmp_msg.msg_type == 'voice': | |||||
if wechatmp_msg.msg_type == "text" or wechatmp_msg.msg_type == "voice": | |||||
from_user = wechatmp_msg.from_user_id | from_user = wechatmp_msg.from_user_id | ||||
to_user = wechatmp_msg.to_user_id | to_user = wechatmp_msg.to_user_id | ||||
message = wechatmp_msg.content.decode("utf-8") | message = wechatmp_msg.content.decode("utf-8") | ||||
message_id = wechatmp_msg.msg_id | message_id = wechatmp_msg.msg_id | ||||
logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message)) | |||||
logger.info( | |||||
"[wechatmp] {}:{} Receive post query {} {}: {}".format( | |||||
web.ctx.env.get("REMOTE_ADDR"), | |||||
web.ctx.env.get("REMOTE_PORT"), | |||||
from_user, | |||||
message_id, | |||||
message, | |||||
) | |||||
) | |||||
supported = True | supported = True | ||||
if "【收到不支持的消息类型,暂无法显示】" in message: | if "【收到不支持的消息类型,暂无法显示】" in message: | ||||
supported = False # not supported, used to refresh | |||||
supported = False # not supported, used to refresh | |||||
cache_key = from_user | cache_key = from_user | ||||
reply_text = "" | reply_text = "" | ||||
# New request | # New request | ||||
if cache_key not in channel.cache_dict and cache_key not in channel.running: | |||||
if ( | |||||
cache_key not in channel.cache_dict | |||||
and cache_key not in channel.running | |||||
): | |||||
# The first query begin, reset the cache | # The first query begin, reset the cache | ||||
context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg) | |||||
logger.debug("[wechatmp] context: {} {}".format(context, wechatmp_msg)) | |||||
if message_id in channel.received_msgs: # received and finished | |||||
context = channel._compose_context( | |||||
ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg | |||||
) | |||||
logger.debug( | |||||
"[wechatmp] context: {} {}".format(context, wechatmp_msg) | |||||
) | |||||
if message_id in channel.received_msgs: # received and finished | |||||
# no return because of bandwords or other reasons | # no return because of bandwords or other reasons | ||||
return "success" | return "success" | ||||
if supported and context: | if supported and context: | ||||
# set private openai_api_key | # set private openai_api_key | ||||
# if from_user is not changed in itchat, this can be placed at chat_channel | # if from_user is not changed in itchat, this can be placed at chat_channel | ||||
user_data = conf().get_user_data(from_user) | user_data = conf().get_user_data(from_user) | ||||
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key | |||||
context["openai_api_key"] = user_data.get( | |||||
"openai_api_key" | |||||
) # None or user openai_api_key | |||||
channel.received_msgs[message_id] = wechatmp_msg | channel.received_msgs[message_id] = wechatmp_msg | ||||
channel.running.add(cache_key) | channel.running.add(cache_key) | ||||
channel.produce(context) | channel.produce(context) | ||||
else: | else: | ||||
trigger_prefix = conf().get('single_chat_prefix',[''])[0] | |||||
trigger_prefix = conf().get("single_chat_prefix", [""])[0] | |||||
if trigger_prefix or not supported: | if trigger_prefix or not supported: | ||||
if trigger_prefix: | if trigger_prefix: | ||||
content = textwrap.dedent(f"""\ | |||||
content = textwrap.dedent( | |||||
f"""\ | |||||
请输入'{trigger_prefix}'接你想说的话跟我说话。 | 请输入'{trigger_prefix}'接你想说的话跟我说话。 | ||||
例如: | 例如: | ||||
{trigger_prefix}你好,很高兴见到你。""") | |||||
{trigger_prefix}你好,很高兴见到你。""" | |||||
) | |||||
else: | else: | ||||
content = textwrap.dedent("""\ | |||||
content = textwrap.dedent( | |||||
"""\ | |||||
你好,很高兴见到你。 | 你好,很高兴见到你。 | ||||
请跟我说话吧。""") | |||||
请跟我说话吧。""" | |||||
) | |||||
else: | else: | ||||
logger.error(f"[wechatmp] unknown error") | logger.error(f"[wechatmp] unknown error") | ||||
content = textwrap.dedent("""\ | |||||
未知错误,请稍后再试""") | |||||
replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content) | |||||
content = textwrap.dedent( | |||||
"""\ | |||||
未知错误,请稍后再试""" | |||||
) | |||||
replyMsg = reply.TextMsg( | |||||
wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content | |||||
) | |||||
return replyMsg.send() | return replyMsg.send() | ||||
channel.query1[cache_key] = False | channel.query1[cache_key] = False | ||||
channel.query2[cache_key] = False | channel.query2[cache_key] = False | ||||
channel.query3[cache_key] = False | channel.query3[cache_key] = False | ||||
# User request again, and the answer is not ready | # User request again, and the answer is not ready | ||||
elif cache_key in channel.running and channel.query1.get(cache_key) == True and channel.query2.get(cache_key) == True and channel.query3.get(cache_key) == True: | |||||
channel.query1[cache_key] = False #To improve waiting experience, this can be set to True. | |||||
channel.query2[cache_key] = False #To improve waiting experience, this can be set to True. | |||||
elif ( | |||||
cache_key in channel.running | |||||
and channel.query1.get(cache_key) == True | |||||
and channel.query2.get(cache_key) == True | |||||
and channel.query3.get(cache_key) == True | |||||
): | |||||
channel.query1[ | |||||
cache_key | |||||
] = False # To improve waiting experience, this can be set to True. | |||||
channel.query2[ | |||||
cache_key | |||||
] = False # To improve waiting experience, this can be set to True. | |||||
channel.query3[cache_key] = False | channel.query3[cache_key] = False | ||||
# User request again, and the answer is ready | # User request again, and the answer is ready | ||||
elif cache_key in channel.cache_dict: | elif cache_key in channel.cache_dict: | ||||
@@ -84,7 +120,9 @@ class Query(): | |||||
channel.query2[cache_key] = True | channel.query2[cache_key] = True | ||||
channel.query3[cache_key] = True | channel.query3[cache_key] = True | ||||
assert not (cache_key in channel.cache_dict and cache_key in channel.running) | |||||
assert not ( | |||||
cache_key in channel.cache_dict and cache_key in channel.running | |||||
) | |||||
if channel.query1.get(cache_key) == False: | if channel.query1.get(cache_key) == False: | ||||
# The first query from wechat official server | # The first query from wechat official server | ||||
@@ -128,14 +166,20 @@ class Query(): | |||||
# Have waiting for 3x5 seconds | # Have waiting for 3x5 seconds | ||||
# return timeout message | # return timeout message | ||||
reply_text = "【正在思考中,回复任意文字尝试获取回复】" | reply_text = "【正在思考中,回复任意文字尝试获取回复】" | ||||
logger.info("[wechatmp] Three queries has finished For {}: {}".format(from_user, message_id)) | |||||
logger.info( | |||||
"[wechatmp] Three queries has finished For {}: {}".format( | |||||
from_user, message_id | |||||
) | |||||
) | |||||
replyPost = reply.TextMsg(from_user, to_user, reply_text).send() | replyPost = reply.TextMsg(from_user, to_user, reply_text).send() | ||||
return replyPost | return replyPost | ||||
else: | else: | ||||
pass | pass | ||||
if cache_key not in channel.cache_dict and cache_key not in channel.running: | |||||
if ( | |||||
cache_key not in channel.cache_dict | |||||
and cache_key not in channel.running | |||||
): | |||||
# no return because of bandwords or other reasons | # no return because of bandwords or other reasons | ||||
return "success" | return "success" | ||||
@@ -147,26 +191,42 @@ class Query(): | |||||
if cache_key in channel.cache_dict: | if cache_key in channel.cache_dict: | ||||
content = channel.cache_dict[cache_key] | content = channel.cache_dict[cache_key] | ||||
if len(content.encode('utf8'))<=MAX_UTF8_LEN: | |||||
if len(content.encode("utf8")) <= MAX_UTF8_LEN: | |||||
reply_text = channel.cache_dict[cache_key] | reply_text = channel.cache_dict[cache_key] | ||||
channel.cache_dict.pop(cache_key) | channel.cache_dict.pop(cache_key) | ||||
else: | else: | ||||
continue_text = "\n【未完待续,回复任意文字以继续】" | continue_text = "\n【未完待续,回复任意文字以继续】" | ||||
splits = split_string_by_utf8_length(content, MAX_UTF8_LEN - len(continue_text.encode('utf-8')), max_split= 1) | |||||
splits = split_string_by_utf8_length( | |||||
content, | |||||
MAX_UTF8_LEN - len(continue_text.encode("utf-8")), | |||||
max_split=1, | |||||
) | |||||
reply_text = splits[0] + continue_text | reply_text = splits[0] + continue_text | ||||
channel.cache_dict[cache_key] = splits[1] | channel.cache_dict[cache_key] = splits[1] | ||||
logger.info("[wechatmp] {}:{} Do send {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), reply_text)) | |||||
logger.info( | |||||
"[wechatmp] {}:{} Do send {}".format( | |||||
web.ctx.env.get("REMOTE_ADDR"), | |||||
web.ctx.env.get("REMOTE_PORT"), | |||||
reply_text, | |||||
) | |||||
) | |||||
replyPost = reply.TextMsg(from_user, to_user, reply_text).send() | replyPost = reply.TextMsg(from_user, to_user, reply_text).send() | ||||
return replyPost | return replyPost | ||||
elif wechatmp_msg.msg_type == 'event': | |||||
logger.info("[wechatmp] Event {} from {}".format(wechatmp_msg.content, wechatmp_msg.from_user_id)) | |||||
elif wechatmp_msg.msg_type == "event": | |||||
logger.info( | |||||
"[wechatmp] Event {} from {}".format( | |||||
wechatmp_msg.content, wechatmp_msg.from_user_id | |||||
) | |||||
) | |||||
content = subscribe_msg() | content = subscribe_msg() | ||||
replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content) | |||||
replyMsg = reply.TextMsg( | |||||
wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content | |||||
) | |||||
return replyMsg.send() | return replyMsg.send() | ||||
else: | else: | ||||
logger.info("暂且不处理") | logger.info("暂且不处理") | ||||
return "success" | return "success" | ||||
except Exception as exc: | except Exception as exc: | ||||
logger.exception(exc) | logger.exception(exc) | ||||
return exc | |||||
return exc |
@@ -1,9 +1,11 @@ | |||||
from config import conf | |||||
import hashlib | import hashlib | ||||
import textwrap | import textwrap | ||||
from config import conf | |||||
MAX_UTF8_LEN = 2048 | MAX_UTF8_LEN = 2048 | ||||
class WeChatAPIException(Exception): | class WeChatAPIException(Exception): | ||||
pass | pass | ||||
@@ -16,13 +18,13 @@ def verify_server(data): | |||||
timestamp = data.timestamp | timestamp = data.timestamp | ||||
nonce = data.nonce | nonce = data.nonce | ||||
echostr = data.echostr | echostr = data.echostr | ||||
token = conf().get('wechatmp_token') #请按照公众平台官网\基本配置中信息填写 | |||||
token = conf().get("wechatmp_token") # 请按照公众平台官网\基本配置中信息填写 | |||||
data_list = [token, timestamp, nonce] | data_list = [token, timestamp, nonce] | ||||
data_list.sort() | data_list.sort() | ||||
sha1 = hashlib.sha1() | sha1 = hashlib.sha1() | ||||
# map(sha1.update, data_list) #python2 | # map(sha1.update, data_list) #python2 | ||||
sha1.update("".join(data_list).encode('utf-8')) | |||||
sha1.update("".join(data_list).encode("utf-8")) | |||||
hashcode = sha1.hexdigest() | hashcode = sha1.hexdigest() | ||||
print("handle/GET func: hashcode, signature: ", hashcode, signature) | print("handle/GET func: hashcode, signature: ", hashcode, signature) | ||||
if hashcode == signature: | if hashcode == signature: | ||||
@@ -32,9 +34,11 @@ def verify_server(data): | |||||
except Exception as Argument: | except Exception as Argument: | ||||
return Argument | return Argument | ||||
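
verify_server 的校验逻辑即:把 `wechatmp_token`、`timestamp`、`nonce` 三个字符串排序后拼接,做一次 sha1,再与微信服务器带来的 `signature` 比较。下面是同一计算过程的独立示意(其中 token、timestamp、nonce 均为假设的示例值):

```python
import hashlib


def make_signature(token: str, timestamp: str, nonce: str) -> str:
    # 与 verify_server 相同的步骤:排序 -> 拼接 -> sha1
    data_list = sorted([token, timestamp, nonce])
    return hashlib.sha1("".join(data_list).encode("utf-8")).hexdigest()


# 示例值仅用于演示;实际值由微信服务器在验证请求的 GET 参数中携带
print(make_signature("Token", "1680000000", "462373"))
```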
def subscribe_msg(): | def subscribe_msg(): | ||||
trigger_prefix = conf().get('single_chat_prefix',[''])[0] | |||||
msg = textwrap.dedent(f"""\ | |||||
trigger_prefix = conf().get("single_chat_prefix", [""])[0] | |||||
msg = textwrap.dedent( | |||||
f"""\ | |||||
感谢您的关注! | 感谢您的关注! | ||||
这里是ChatGPT,可以自由对话。 | 这里是ChatGPT,可以自由对话。 | ||||
资源有限,回复较慢,请勿着急。 | 资源有限,回复较慢,请勿着急。 | ||||
@@ -42,22 +46,23 @@ def subscribe_msg(): | |||||
暂时不支持图片输入。 | 暂时不支持图片输入。 | ||||
支持图片输出,画字开头的问题将回复图片链接。 | 支持图片输出,画字开头的问题将回复图片链接。 | ||||
支持角色扮演和文字冒险两种定制模式对话。 | 支持角色扮演和文字冒险两种定制模式对话。 | ||||
输入'{trigger_prefix}#帮助' 查看详细指令。""") | |||||
输入'{trigger_prefix}#帮助' 查看详细指令。""" | |||||
) | |||||
return msg | return msg | ||||
def split_string_by_utf8_length(string, max_length, max_split=0): | def split_string_by_utf8_length(string, max_length, max_split=0): | ||||
encoded = string.encode('utf-8') | |||||
encoded = string.encode("utf-8") | |||||
start, end = 0, 0 | start, end = 0, 0 | ||||
result = [] | result = [] | ||||
while end < len(encoded): | while end < len(encoded): | ||||
if max_split > 0 and len(result) >= max_split: | if max_split > 0 and len(result) >= max_split: | ||||
result.append(encoded[start:].decode('utf-8')) | |||||
result.append(encoded[start:].decode("utf-8")) | |||||
break | break | ||||
end = start + max_length | end = start + max_length | ||||
# 如果当前字节不是 UTF-8 编码的开始字节,则向前查找直到找到开始字节为止 | # 如果当前字节不是 UTF-8 编码的开始字节,则向前查找直到找到开始字节为止 | ||||
while end < len(encoded) and (encoded[end] & 0b11000000) == 0b10000000: | while end < len(encoded) and (encoded[end] & 0b11000000) == 0b10000000: | ||||
end -= 1 | end -= 1 | ||||
result.append(encoded[start:end].decode('utf-8')) | |||||
result.append(encoded[start:end].decode("utf-8")) | |||||
start = end | start = end | ||||
return result | |||||
return result |
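
split_string_by_utf8_length 以 UTF-8 字节长度为上限切分字符串,并保证不会把一个多字节字符从中间切断;被动回复超出 MAX_UTF8_LEN 时,上面的 Query 类就用它把回复分页发送。下面是一个独立的用法示意(content 为虚构文本):

```python
from channel.wechatmp.common import MAX_UTF8_LEN, split_string_by_utf8_length

# 假设 content 是一段较长的中文回复(一个汉字占 3 个 UTF-8 字节)
content = "很长的回复" * 300
continue_text = "\n【未完待续,回复任意文字以继续】"

# max_split=1 只切一刀:第一段立即返回,剩余部分缓存等待下次请求
splits = split_string_by_utf8_length(
    content, MAX_UTF8_LEN - len(continue_text.encode("utf-8")), max_split=1
)
first_reply = splits[0] + continue_text  # 本次返回给用户的内容
remaining = splits[1]  # 留待用户下一次查询时继续发送
assert len(first_reply.encode("utf-8")) <= MAX_UTF8_LEN
```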
@@ -1,6 +1,7 @@ | |||||
# -*- coding: utf-8 -*-# | # -*- coding: utf-8 -*-# | ||||
# filename: receive.py | # filename: receive.py | ||||
import xml.etree.ElementTree as ET | import xml.etree.ElementTree as ET | ||||
from bridge.context import ContextType | from bridge.context import ContextType | ||||
from channel.chat_message import ChatMessage | from channel.chat_message import ChatMessage | ||||
from common.log import logger | from common.log import logger | ||||
@@ -12,34 +13,35 @@ def parse_xml(web_data): | |||||
xmlData = ET.fromstring(web_data) | xmlData = ET.fromstring(web_data) | ||||
return WeChatMPMessage(xmlData) | return WeChatMPMessage(xmlData) | ||||
class WeChatMPMessage(ChatMessage): | class WeChatMPMessage(ChatMessage): | ||||
def __init__(self, xmlData): | def __init__(self, xmlData): | ||||
super().__init__(xmlData) | super().__init__(xmlData) | ||||
self.to_user_id = xmlData.find('ToUserName').text | |||||
self.from_user_id = xmlData.find('FromUserName').text | |||||
self.create_time = xmlData.find('CreateTime').text | |||||
self.msg_type = xmlData.find('MsgType').text | |||||
self.to_user_id = xmlData.find("ToUserName").text | |||||
self.from_user_id = xmlData.find("FromUserName").text | |||||
self.create_time = xmlData.find("CreateTime").text | |||||
self.msg_type = xmlData.find("MsgType").text | |||||
try: | try: | ||||
self.msg_id = xmlData.find('MsgId').text | |||||
self.msg_id = xmlData.find("MsgId").text | |||||
except: | except: | ||||
self.msg_id = self.from_user_id+self.create_time | |||||
self.msg_id = self.from_user_id + self.create_time | |||||
self.is_group = False | self.is_group = False | ||||
# reply to other_user_id | # reply to other_user_id | ||||
self.other_user_id = self.from_user_id | self.other_user_id = self.from_user_id | ||||
if self.msg_type == 'text': | |||||
if self.msg_type == "text": | |||||
self.ctype = ContextType.TEXT | self.ctype = ContextType.TEXT | ||||
self.content = xmlData.find('Content').text.encode("utf-8") | |||||
elif self.msg_type == 'voice': | |||||
self.content = xmlData.find("Content").text.encode("utf-8") | |||||
elif self.msg_type == "voice": | |||||
self.ctype = ContextType.TEXT | self.ctype = ContextType.TEXT | ||||
self.content = xmlData.find('Recognition').text.encode("utf-8") # 接收语音识别结果 | |||||
elif self.msg_type == 'image': | |||||
self.content = xmlData.find("Recognition").text.encode("utf-8") # 接收语音识别结果 | |||||
elif self.msg_type == "image": | |||||
# not implemented | # not implemented | ||||
self.pic_url = xmlData.find('PicUrl').text | |||||
self.media_id = xmlData.find('MediaId').text | |||||
elif self.msg_type == 'event': | |||||
self.content = xmlData.find('Event').text | |||||
else: # video, shortvideo, location, link | |||||
self.pic_url = xmlData.find("PicUrl").text | |||||
self.media_id = xmlData.find("MediaId").text | |||||
elif self.msg_type == "event": | |||||
self.content = xmlData.find("Event").text | |||||
else: # video, shortvideo, location, link | |||||
# not implemented | # not implemented | ||||
pass | |||||
pass |
@@ -2,6 +2,7 @@ | |||||
# filename: reply.py | # filename: reply.py | ||||
import time | import time | ||||
class Msg(object): | class Msg(object): | ||||
def __init__(self): | def __init__(self): | ||||
pass | pass | ||||
@@ -9,13 +10,14 @@ class Msg(object): | |||||
def send(self): | def send(self): | ||||
return "success" | return "success" | ||||
class TextMsg(Msg): | class TextMsg(Msg): | ||||
def __init__(self, toUserName, fromUserName, content): | def __init__(self, toUserName, fromUserName, content): | ||||
self.__dict = dict() | self.__dict = dict() | ||||
self.__dict['ToUserName'] = toUserName | |||||
self.__dict['FromUserName'] = fromUserName | |||||
self.__dict['CreateTime'] = int(time.time()) | |||||
self.__dict['Content'] = content | |||||
self.__dict["ToUserName"] = toUserName | |||||
self.__dict["FromUserName"] = fromUserName | |||||
self.__dict["CreateTime"] = int(time.time()) | |||||
self.__dict["Content"] = content | |||||
def send(self): | def send(self): | ||||
XmlForm = """ | XmlForm = """ | ||||
@@ -29,13 +31,14 @@ class TextMsg(Msg): | |||||
""" | """ | ||||
return XmlForm.format(**self.__dict) | return XmlForm.format(**self.__dict) | ||||
class ImageMsg(Msg): | class ImageMsg(Msg): | ||||
def __init__(self, toUserName, fromUserName, mediaId): | def __init__(self, toUserName, fromUserName, mediaId): | ||||
self.__dict = dict() | self.__dict = dict() | ||||
self.__dict['ToUserName'] = toUserName | |||||
self.__dict['FromUserName'] = fromUserName | |||||
self.__dict['CreateTime'] = int(time.time()) | |||||
self.__dict['MediaId'] = mediaId | |||||
self.__dict["ToUserName"] = toUserName | |||||
self.__dict["FromUserName"] = fromUserName | |||||
self.__dict["CreateTime"] = int(time.time()) | |||||
self.__dict["MediaId"] = mediaId | |||||
def send(self): | def send(self): | ||||
XmlForm = """ | XmlForm = """ | ||||
@@ -49,4 +52,4 @@ class ImageMsg(Msg): | |||||
</Image> | </Image> | ||||
</xml> | </xml> | ||||
""" | """ | ||||
return XmlForm.format(**self.__dict) | |||||
return XmlForm.format(**self.__dict) |
@@ -1,17 +1,19 @@ | |||||
# -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
import web | |||||
import time | |||||
import json | import json | ||||
import requests | |||||
import threading | import threading | ||||
from common.singleton import singleton | |||||
from common.log import logger | |||||
from common.expired_dict import ExpiredDict | |||||
from config import conf | |||||
from bridge.reply import * | |||||
import time | |||||
import requests | |||||
import web | |||||
from bridge.context import * | from bridge.context import * | ||||
from bridge.reply import * | |||||
from channel.chat_channel import ChatChannel | from channel.chat_channel import ChatChannel | ||||
from channel.wechatmp.common import * | |||||
from channel.wechatmp.common import * | |||||
from common.expired_dict import ExpiredDict | |||||
from common.log import logger | |||||
from common.singleton import singleton | |||||
from config import conf | |||||
# If using SSL, uncomment the following lines, and modify the certificate path. | # If using SSL, uncomment the following lines, and modify the certificate path. | ||||
# from cheroot.server import HTTPServer | # from cheroot.server import HTTPServer | ||||
@@ -20,13 +22,14 @@ from channel.wechatmp.common import * | |||||
# certificate='/ssl/cert.pem', | # certificate='/ssl/cert.pem', | ||||
# private_key='/ssl/cert.key') | # private_key='/ssl/cert.key') | ||||
@singleton | @singleton | ||||
class WechatMPChannel(ChatChannel): | class WechatMPChannel(ChatChannel): | ||||
def __init__(self, passive_reply = True): | |||||
def __init__(self, passive_reply=True): | |||||
super().__init__() | super().__init__() | ||||
self.passive_reply = passive_reply | self.passive_reply = passive_reply | ||||
self.running = set() | self.running = set() | ||||
self.received_msgs = ExpiredDict(60*60*24) | |||||
self.received_msgs = ExpiredDict(60 * 60 * 24) | |||||
if self.passive_reply: | if self.passive_reply: | ||||
self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE] | self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE] | ||||
self.cache_dict = dict() | self.cache_dict = dict() | ||||
@@ -36,8 +39,8 @@ class WechatMPChannel(ChatChannel): | |||||
else: | else: | ||||
# TODO support image | # TODO support image | ||||
self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE] | self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE] | ||||
self.app_id = conf().get('wechatmp_app_id') | |||||
self.app_secret = conf().get('wechatmp_app_secret') | |||||
self.app_id = conf().get("wechatmp_app_id") | |||||
self.app_secret = conf().get("wechatmp_app_secret") | |||||
self.access_token = None | self.access_token = None | ||||
self.access_token_expires_time = 0 | self.access_token_expires_time = 0 | ||||
self.access_token_lock = threading.Lock() | self.access_token_lock = threading.Lock() | ||||
@@ -45,13 +48,12 @@ class WechatMPChannel(ChatChannel): | |||||
def startup(self): | def startup(self): | ||||
if self.passive_reply: | if self.passive_reply: | ||||
urls = ('/wx', 'channel.wechatmp.SubscribeAccount.Query') | |||||
urls = ("/wx", "channel.wechatmp.SubscribeAccount.Query") | |||||
else: | else: | ||||
urls = ('/wx', 'channel.wechatmp.ServiceAccount.Query') | |||||
urls = ("/wx", "channel.wechatmp.ServiceAccount.Query") | |||||
app = web.application(urls, globals(), autoreload=False) | app = web.application(urls, globals(), autoreload=False) | ||||
port = conf().get('wechatmp_port', 8080) | |||||
web.httpserver.runsimple(app.wsgifunc(), ('0.0.0.0', port)) | |||||
port = conf().get("wechatmp_port", 8080) | |||||
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port)) | |||||
def wechatmp_request(self, method, url, **kwargs): | def wechatmp_request(self, method, url, **kwargs): | ||||
r = requests.request(method=method, url=url, **kwargs) | r = requests.request(method=method, url=url, **kwargs) | ||||
@@ -63,7 +65,6 @@ class WechatMPChannel(ChatChannel): | |||||
return ret | return ret | ||||
def get_access_token(self): | def get_access_token(self): | ||||
# return the access_token | # return the access_token | ||||
if self.access_token: | if self.access_token: | ||||
if self.access_token_expires_time - time.time() > 60: | if self.access_token_expires_time - time.time() > 60: | ||||
@@ -76,15 +77,15 @@ class WechatMPChannel(ChatChannel): | |||||
# This happens every 2 hours, so it doesn't affect the experience very much | # This happens every 2 hours, so it doesn't affect the experience very much | ||||
time.sleep(1) | time.sleep(1) | ||||
self.access_token = None | self.access_token = None | ||||
url="https://api.weixin.qq.com/cgi-bin/token" | |||||
params={ | |||||
url = "https://api.weixin.qq.com/cgi-bin/token" | |||||
params = { | |||||
"grant_type": "client_credential", | "grant_type": "client_credential", | ||||
"appid": self.app_id, | "appid": self.app_id, | ||||
"secret": self.app_secret | |||||
"secret": self.app_secret, | |||||
} | } | ||||
data = self.wechatmp_request(method='get', url=url, params=params) | |||||
self.access_token = data['access_token'] | |||||
self.access_token_expires_time = int(time.time()) + data['expires_in'] | |||||
data = self.wechatmp_request(method="get", url=url, params=params) | |||||
self.access_token = data["access_token"] | |||||
self.access_token_expires_time = int(time.time()) + data["expires_in"] | |||||
logger.info("[wechatmp] access_token: {}".format(self.access_token)) | logger.info("[wechatmp] access_token: {}".format(self.access_token)) | ||||
self.access_token_lock.release() | self.access_token_lock.release() | ||||
else: | else: | ||||
@@ -101,29 +102,37 @@ class WechatMPChannel(ChatChannel): | |||||
else: | else: | ||||
receiver = context["receiver"] | receiver = context["receiver"] | ||||
reply_text = reply.content | reply_text = reply.content | ||||
url="https://api.weixin.qq.com/cgi-bin/message/custom/send" | |||||
params = { | |||||
"access_token": self.get_access_token() | |||||
} | |||||
url = "https://api.weixin.qq.com/cgi-bin/message/custom/send" | |||||
params = {"access_token": self.get_access_token()} | |||||
json_data = { | json_data = { | ||||
"touser": receiver, | "touser": receiver, | ||||
"msgtype": "text", | "msgtype": "text", | ||||
"text": {"content": reply_text} | |||||
"text": {"content": reply_text}, | |||||
} | } | ||||
self.wechatmp_request(method='post', url=url, params=params, data=json.dumps(json_data, ensure_ascii=False).encode('utf8')) | |||||
self.wechatmp_request( | |||||
method="post", | |||||
url=url, | |||||
params=params, | |||||
data=json.dumps(json_data, ensure_ascii=False).encode("utf8"), | |||||
) | |||||
logger.info("[send] Do send to {}: {}".format(receiver, reply_text)) | logger.info("[send] Do send to {}: {}".format(receiver, reply_text)) | ||||
return | return | ||||
def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数 | |||||
logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context['msg'].msg_id)) | |||||
def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数 | |||||
logger.debug( | |||||
"[wechatmp] Success to generate reply, msgId={}".format( | |||||
context["msg"].msg_id | |||||
) | |||||
) | |||||
if self.passive_reply: | if self.passive_reply: | ||||
self.running.remove(session_id) | self.running.remove(session_id) | ||||
def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数 | |||||
logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context['msg'].msg_id, exception)) | |||||
def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数 | |||||
logger.exception( | |||||
"[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format( | |||||
context["msg"].msg_id, exception | |||||
) | |||||
) | |||||
if self.passive_reply: | if self.passive_reply: | ||||
assert session_id not in self.cache_dict | assert session_id not in self.cache_dict | ||||
self.running.remove(session_id) | self.running.remove(session_id) | ||||
@@ -2,4 +2,4 @@ | |||||
OPEN_AI = "openAI" | OPEN_AI = "openAI" | ||||
CHATGPT = "chatGPT" | CHATGPT = "chatGPT" | ||||
BAIDU = "baidu" | BAIDU = "baidu" | ||||
CHATGPTONAZURE = "chatGPTOnAzure" | |||||
CHATGPTONAZURE = "chatGPTOnAzure" |
@@ -1,7 +1,7 @@ | |||||
from queue import Full, Queue | from queue import Full, Queue | ||||
from time import monotonic as time | from time import monotonic as time | ||||
# add implementation of putleft to Queue | # add implementation of putleft to Queue | ||||
class Dequeue(Queue): | class Dequeue(Queue): | ||||
def putleft(self, item, block=True, timeout=None): | def putleft(self, item, block=True, timeout=None): | ||||
@@ -30,4 +30,4 @@ class Dequeue(Queue): | |||||
return self.putleft(item, block=False) | return self.putleft(item, block=False) | ||||
def _putleft(self, item): | def _putleft(self, item): | ||||
self.queue.appendleft(item) | |||||
self.queue.appendleft(item) |
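
Dequeue 在标准库 Queue 的基础上补充了 putleft,用于把元素插回队首,需要优先处理的任务不必排到队尾。下面是一个独立的用法示意(模块导入路径按该文件在仓库中的位置推断,仅供参考):

```python
from common.dequeue import Dequeue  # 假设的模块路径,以仓库实际位置为准

q = Dequeue()
q.put("task-1")
q.put("task-2")
q.putleft("urgent-task")  # 插到队首

print(q.get())  # -> urgent-task
print(q.get())  # -> task-1
print(q.get())  # -> task-2
```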
@@ -39,4 +39,4 @@ class ExpiredDict(dict): | |||||
return [(key, self[key]) for key in self.keys()] | return [(key, self[key]) for key in self.keys()] | ||||
def __iter__(self): | def __iter__(self): | ||||
return self.keys().__iter__() | |||||
return self.keys().__iter__() |
@@ -10,20 +10,29 @@ def _reset_logger(log): | |||||
log.handlers.clear() | log.handlers.clear() | ||||
log.propagate = False | log.propagate = False | ||||
console_handle = logging.StreamHandler(sys.stdout) | console_handle = logging.StreamHandler(sys.stdout) | ||||
console_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s', | |||||
datefmt='%Y-%m-%d %H:%M:%S')) | |||||
file_handle = logging.FileHandler('run.log', encoding='utf-8') | |||||
file_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s', | |||||
datefmt='%Y-%m-%d %H:%M:%S')) | |||||
console_handle.setFormatter( | |||||
logging.Formatter( | |||||
"[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s", | |||||
datefmt="%Y-%m-%d %H:%M:%S", | |||||
) | |||||
) | |||||
file_handle = logging.FileHandler("run.log", encoding="utf-8") | |||||
file_handle.setFormatter( | |||||
logging.Formatter( | |||||
"[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s", | |||||
datefmt="%Y-%m-%d %H:%M:%S", | |||||
) | |||||
) | |||||
log.addHandler(file_handle) | log.addHandler(file_handle) | ||||
log.addHandler(console_handle) | log.addHandler(console_handle) | ||||
def _get_logger(): | def _get_logger(): | ||||
log = logging.getLogger('log') | |||||
log = logging.getLogger("log") | |||||
_reset_logger(log) | _reset_logger(log) | ||||
log.setLevel(logging.INFO) | log.setLevel(logging.INFO) | ||||
return log | return log | ||||
# 日志句柄 | # 日志句柄 | ||||
logger = _get_logger() | |||||
logger = _get_logger() |
@@ -1,15 +1,20 @@ | |||||
import time | import time | ||||
import pip | import pip | ||||
from pip._internal import main as pipmain | from pip._internal import main as pipmain | ||||
from common.log import logger,_reset_logger | |||||
from common.log import _reset_logger, logger | |||||
def install(package): | def install(package): | ||||
pipmain(['install', package]) | |||||
pipmain(["install", package]) | |||||
def install_requirements(file): | def install_requirements(file): | ||||
pipmain(['install', '-r', file, "--upgrade"]) | |||||
pipmain(["install", "-r", file, "--upgrade"]) | |||||
_reset_logger(logger) | _reset_logger(logger) | ||||
def check_dulwich(): | def check_dulwich(): | ||||
needwait = False | needwait = False | ||||
for i in range(2): | for i in range(2): | ||||
@@ -18,13 +23,14 @@ def check_dulwich(): | |||||
needwait = False | needwait = False | ||||
try: | try: | ||||
import dulwich | import dulwich | ||||
return | return | ||||
except ImportError: | except ImportError: | ||||
try: | try: | ||||
install('dulwich') | |||||
install("dulwich") | |||||
except: | except: | ||||
needwait = True | needwait = True | ||||
try: | try: | ||||
import dulwich | import dulwich | ||||
except ImportError: | except ImportError: | ||||
raise ImportError("Unable to import dulwich") | |||||
raise ImportError("Unable to import dulwich") |
@@ -62,4 +62,4 @@ class SortedDict(dict): | |||||
return iter(self.keys()) | return iter(self.keys()) | ||||
def __repr__(self): | def __repr__(self): | ||||
return f'{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})' | |||||
return f"{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})" |
@@ -1,7 +1,11 @@ | |||||
import time,re,hashlib | |||||
import hashlib | |||||
import re | |||||
import time | |||||
import config | import config | ||||
from common.log import logger | from common.log import logger | ||||
def time_checker(f): | def time_checker(f): | ||||
def _time_checker(self, *args, **kwargs): | def _time_checker(self, *args, **kwargs): | ||||
_config = config.conf() | _config = config.conf() | ||||
@@ -9,17 +13,25 @@ def time_checker(f): | |||||
if chat_time_module: | if chat_time_module: | ||||
chat_start_time = _config.get("chat_start_time", "00:00") | chat_start_time = _config.get("chat_start_time", "00:00") | ||||
chat_stopt_time = _config.get("chat_stop_time", "24:00") | chat_stopt_time = _config.get("chat_stop_time", "24:00") | ||||
time_regex = re.compile(r'^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$') #时间匹配,包含24:00 | |||||
time_regex = re.compile( | |||||
r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$" | |||||
) # 时间匹配,包含24:00 | |||||
starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式 | starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式 | ||||
stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式 | stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式 | ||||
chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间 | |||||
chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间 | |||||
# 时间格式检查 | # 时间格式检查 | ||||
if not (starttime_format_check and stoptime_format_check and chat_time_check): | |||||
logger.warn('时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})'.format(starttime_format_check,stoptime_format_check)) | |||||
if chat_start_time>"23:59": | |||||
logger.error('启动时间可能存在问题,请修改!') | |||||
if not ( | |||||
starttime_format_check and stoptime_format_check and chat_time_check | |||||
): | |||||
logger.warn( | |||||
"时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format( | |||||
starttime_format_check, stoptime_format_check | |||||
) | |||||
) | |||||
if chat_start_time > "23:59": | |||||
logger.error("启动时间可能存在问题,请修改!") | |||||
# 服务时间检查 | # 服务时间检查 | ||||
now_time = time.strftime("%H:%M", time.localtime()) | now_time = time.strftime("%H:%M", time.localtime()) | ||||
@@ -27,12 +39,12 @@ def time_checker(f): | |||||
f(self, *args, **kwargs) | f(self, *args, **kwargs) | ||||
return None | return None | ||||
else: | else: | ||||
if args[0]['Content'] == "#更新配置": # 不在服务时间内也可以更新配置 | |||||
if args[0]["Content"] == "#更新配置": # 不在服务时间内也可以更新配置 | |||||
f(self, *args, **kwargs) | f(self, *args, **kwargs) | ||||
else: | else: | ||||
logger.info('非服务时间内,不接受访问') | |||||
logger.info("非服务时间内,不接受访问") | |||||
return None | return None | ||||
else: | else: | ||||
f(self, *args, **kwargs) # 未开启时间模块则直接回答 | f(self, *args, **kwargs) # 未开启时间模块则直接回答 | ||||
return _time_checker | |||||
return _time_checker |
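
time_checker 是套在消息处理函数外层的装饰器:开启 chat_time_module 后,只有落在 chat_start_time 与 chat_stop_time 之间的请求才会放行,"#更新配置" 消息除外。下面是一个示意用法(Handler 类与 handle 方法为演示而虚构,并假设 config.json 已按 README 加载;msg 取 itchat 消息字典的形式,实际以仓库中被装饰的函数为准):

```python
from common.time_check import time_checker  # 假设的模块路径,以仓库实际位置为准


class Handler:
    @time_checker
    def handle(self, msg):
        # msg 至少包含 'Content' 字段,对应装饰器中 args[0]['Content'] 的用法
        print("处理消息:", msg["Content"])


Handler().handle({"Content": "你好"})  # 仅在服务时间窗口内才会真正执行
```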
@@ -1,20 +1,18 @@ | |||||
import os | import os | ||||
import pathlib | import pathlib | ||||
from config import conf | from config import conf | ||||
class TmpDir(object): | class TmpDir(object): | ||||
"""A temporary directory that is deleted when the object is destroyed. | |||||
""" | |||||
"""A temporary directory that is deleted when the object is destroyed.""" | |||||
tmpFilePath = pathlib.Path("./tmp/") | |||||
tmpFilePath = pathlib.Path('./tmp/') | |||||
def __init__(self): | def __init__(self): | ||||
pathExists = os.path.exists(self.tmpFilePath) | pathExists = os.path.exists(self.tmpFilePath) | ||||
if not pathExists: | if not pathExists: | ||||
os.makedirs(self.tmpFilePath) | os.makedirs(self.tmpFilePath) | ||||
def path(self): | def path(self): | ||||
return str(self.tmpFilePath) + '/' | |||||
return str(self.tmpFilePath) + "/" |
@@ -2,16 +2,30 @@ | |||||
"open_ai_api_key": "YOUR API KEY", | "open_ai_api_key": "YOUR API KEY", | ||||
"model": "gpt-3.5-turbo", | "model": "gpt-3.5-turbo", | ||||
"proxy": "", | "proxy": "", | ||||
"single_chat_prefix": ["bot", "@bot"], | |||||
"single_chat_prefix": [ | |||||
"bot", | |||||
"@bot" | |||||
], | |||||
"single_chat_reply_prefix": "[bot] ", | "single_chat_reply_prefix": "[bot] ", | ||||
"group_chat_prefix": ["@bot"], | |||||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], | |||||
"group_chat_in_one_session": ["ChatGPT测试群"], | |||||
"image_create_prefix": ["画", "看", "找"], | |||||
"group_chat_prefix": [ | |||||
"@bot" | |||||
], | |||||
"group_name_white_list": [ | |||||
"ChatGPT测试群", | |||||
"ChatGPT测试群2" | |||||
], | |||||
"group_chat_in_one_session": [ | |||||
"ChatGPT测试群" | |||||
], | |||||
"image_create_prefix": [ | |||||
"画", | |||||
"看", | |||||
"找" | |||||
], | |||||
"speech_recognition": false, | "speech_recognition": false, | ||||
"group_speech_recognition": false, | "group_speech_recognition": false, | ||||
"voice_reply_voice": false, | "voice_reply_voice": false, | ||||
"conversation_max_tokens": 1000, | "conversation_max_tokens": 1000, | ||||
"expires_in_seconds": 3600, | "expires_in_seconds": 3600, | ||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。" | "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。" | ||||
} | |||||
} |
@@ -3,9 +3,10 @@ | |||||
import json | import json | ||||
import logging | import logging | ||||
import os | import os | ||||
from common.log import logger | |||||
import pickle | import pickle | ||||
from common.log import logger | |||||
# 将所有可用的配置项写在字典里, 请使用小写字母 | # 将所有可用的配置项写在字典里, 请使用小写字母 | ||||
available_setting = { | available_setting = { | ||||
# openai api配置 | # openai api配置 | ||||
@@ -16,8 +17,7 @@ available_setting = { | |||||
# chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称 | # chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称 | ||||
"model": "gpt-3.5-turbo", | "model": "gpt-3.5-turbo", | ||||
"use_azure_chatgpt": False, # 是否使用azure的chatgpt | "use_azure_chatgpt": False, # 是否使用azure的chatgpt | ||||
"azure_deployment_id": "", #azure 模型部署名称 | |||||
"azure_deployment_id": "", # azure 模型部署名称 | |||||
# Bot触发配置 | # Bot触发配置 | ||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 | "single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 | ||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 | "single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 | ||||
@@ -30,25 +30,21 @@ available_setting = { | |||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称 | "group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称 | ||||
"trigger_by_self": False, # 是否允许机器人触发 | "trigger_by_self": False, # 是否允许机器人触发 | ||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀 | "image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀 | ||||
"concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中,大于1可能乱序 | |||||
"concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中,大于1可能乱序 | |||||
# chatgpt会话参数 | # chatgpt会话参数 | ||||
"expires_in_seconds": 3600, # 无操作会话的过期时间 | "expires_in_seconds": 3600, # 无操作会话的过期时间 | ||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述 | "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述 | ||||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数 | "conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数 | ||||
# chatgpt限流配置 | # chatgpt限流配置 | ||||
"rate_limit_chatgpt": 20, # chatgpt的调用频率限制 | "rate_limit_chatgpt": 20, # chatgpt的调用频率限制 | ||||
"rate_limit_dalle": 50, # openai dalle的调用频率限制 | "rate_limit_dalle": 50, # openai dalle的调用频率限制 | ||||
# chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create | # chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create | ||||
"temperature": 0.9, | "temperature": 0.9, | ||||
"top_p": 1, | "top_p": 1, | ||||
"frequency_penalty": 0, | "frequency_penalty": 0, | ||||
"presence_penalty": 0, | "presence_penalty": 0, | ||||
"request_timeout": 60, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 | |||||
"timeout": 120, # chatgpt重试超时时间,在这个时间内,将会自动重试 | |||||
"request_timeout": 60, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 | |||||
"timeout": 120, # chatgpt重试超时时间,在这个时间内,将会自动重试 | |||||
# 语音设置 | # 语音设置 | ||||
"speech_recognition": False, # 是否开启语音识别 | "speech_recognition": False, # 是否开启语音识别 | ||||
"group_speech_recognition": False, # 是否开启群组语音识别 | "group_speech_recognition": False, # 是否开启群组语音识别 | ||||
@@ -56,50 +52,40 @@ available_setting = { | |||||
"always_reply_voice": False, # 是否一直使用语音回复 | "always_reply_voice": False, # 是否一直使用语音回复 | ||||
"voice_to_text": "openai", # 语音识别引擎,支持openai,baidu,google,azure | "voice_to_text": "openai", # 语音识别引擎,支持openai,baidu,google,azure | ||||
"text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline),azure | "text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline),azure | ||||
# baidu 语音api配置, 使用百度语音识别和语音合成时需要 | # baidu 语音api配置, 使用百度语音识别和语音合成时需要 | ||||
"baidu_app_id": "", | "baidu_app_id": "", | ||||
"baidu_api_key": "", | "baidu_api_key": "", | ||||
"baidu_secret_key": "", | "baidu_secret_key": "", | ||||
# 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场 | # 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场 | ||||
"baidu_dev_pid": "1536", | "baidu_dev_pid": "1536", | ||||
# azure 语音api配置, 使用azure语音识别和语音合成时需要 | # azure 语音api配置, 使用azure语音识别和语音合成时需要 | ||||
"azure_voice_api_key": "", | "azure_voice_api_key": "", | ||||
"azure_voice_region": "japaneast", | "azure_voice_region": "japaneast", | ||||
# 服务时间限制,目前支持itchat | # 服务时间限制,目前支持itchat | ||||
"chat_time_module": False, # 是否开启服务时间限制 | "chat_time_module": False, # 是否开启服务时间限制 | ||||
"chat_start_time": "00:00", # 服务开始时间 | "chat_start_time": "00:00", # 服务开始时间 | ||||
"chat_stop_time": "24:00", # 服务结束时间 | "chat_stop_time": "24:00", # 服务结束时间 | ||||
# itchat的配置 | # itchat的配置 | ||||
"hot_reload": False, # 是否开启热重载 | "hot_reload": False, # 是否开启热重载 | ||||
# wechaty的配置 | # wechaty的配置 | ||||
"wechaty_puppet_service_token": "", # wechaty的token | "wechaty_puppet_service_token": "", # wechaty的token | ||||
# wechatmp的配置 | # wechatmp的配置 | ||||
"wechatmp_token": "", # 微信公众平台的Token | |||||
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443 | |||||
"wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要 | |||||
"wechatmp_token": "", # 微信公众平台的Token | |||||
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443 | |||||
"wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要 | |||||
"wechatmp_app_secret": "", # 微信公众平台的appsecret,仅服务号需要 | "wechatmp_app_secret": "", # 微信公众平台的appsecret,仅服务号需要 | ||||
# chatgpt指令自定义触发词 | # chatgpt指令自定义触发词 | ||||
"clear_memory_commands": ['#清除记忆'], # 重置会话指令,必须以#开头 | |||||
"clear_memory_commands": ["#清除记忆"], # 重置会话指令,必须以#开头 | |||||
# channel配置 | # channel配置 | ||||
"channel_type": "wx", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service} | |||||
"channel_type": "wx", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service} | |||||
"debug": False, # 是否开启debug模式,开启后会打印更多日志 | "debug": False, # 是否开启debug模式,开启后会打印更多日志 | ||||
# 插件配置 | # 插件配置 | ||||
"plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突 | "plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突 | ||||
} | } | ||||
class Config(dict): | class Config(dict): | ||||
def __init__(self, d:dict={}): | |||||
def __init__(self, d: dict = {}): | |||||
super().__init__(d) | super().__init__(d) | ||||
# user_datas: 用户数据,key为用户名,value为用户数据,也是dict | # user_datas: 用户数据,key为用户名,value为用户数据,也是dict | ||||
self.user_datas = {} | self.user_datas = {} | ||||
@@ -130,7 +116,7 @@ class Config(dict): | |||||
def load_user_datas(self): | def load_user_datas(self): | ||||
try: | try: | ||||
with open('user_datas.pkl', 'rb') as f: | |||||
with open("user_datas.pkl", "rb") as f: | |||||
self.user_datas = pickle.load(f) | self.user_datas = pickle.load(f) | ||||
logger.info("[Config] User datas loaded.") | logger.info("[Config] User datas loaded.") | ||||
except FileNotFoundError as e: | except FileNotFoundError as e: | ||||
@@ -141,12 +127,13 @@ class Config(dict): | |||||
def save_user_datas(self): | def save_user_datas(self): | ||||
try: | try: | ||||
with open('user_datas.pkl', 'wb') as f: | |||||
with open("user_datas.pkl", "wb") as f: | |||||
pickle.dump(self.user_datas, f) | pickle.dump(self.user_datas, f) | ||||
logger.info("[Config] User datas saved.") | logger.info("[Config] User datas saved.") | ||||
except Exception as e: | except Exception as e: | ||||
logger.info("[Config] User datas error: {}".format(e)) | logger.info("[Config] User datas error: {}".format(e)) | ||||
config = Config() | config = Config() | ||||
@@ -154,7 +141,7 @@ def load_config(): | |||||
global config | global config | ||||
config_path = "./config.json" | config_path = "./config.json" | ||||
if not os.path.exists(config_path): | if not os.path.exists(config_path): | ||||
logger.info('配置文件不存在,将使用config-template.json模板') | |||||
logger.info("配置文件不存在,将使用config-template.json模板") | |||||
config_path = "./config-template.json" | config_path = "./config-template.json" | ||||
config_str = read_file(config_path) | config_str = read_file(config_path) | ||||
@@ -169,7 +156,8 @@ def load_config(): | |||||
name = name.lower() | name = name.lower() | ||||
if name in available_setting: | if name in available_setting: | ||||
logger.info( | logger.info( | ||||
"[INIT] override config by environ args: {}={}".format(name, value)) | |||||
"[INIT] override config by environ args: {}={}".format(name, value) | |||||
) | |||||
try: | try: | ||||
config[name] = eval(value) | config[name] = eval(value) | ||||
except: | except: | ||||
@@ -182,18 +170,19 @@ def load_config(): | |||||
if config.get("debug", False): | if config.get("debug", False): | ||||
logger.setLevel(logging.DEBUG) | logger.setLevel(logging.DEBUG) | ||||
logger.debug("[INIT] set log level to DEBUG") | |||||
logger.debug("[INIT] set log level to DEBUG") | |||||
logger.info("[INIT] load config: {}".format(config)) | logger.info("[INIT] load config: {}".format(config)) | ||||
config.load_user_datas() | config.load_user_datas() | ||||
def get_root(): | def get_root(): | ||||
return os.path.dirname(os.path.abspath(__file__)) | return os.path.dirname(os.path.abspath(__file__)) | ||||
def read_file(path): | def read_file(path): | ||||
with open(path, mode='r', encoding='utf-8') as f: | |||||
with open(path, mode="r", encoding="utf-8") as f: | |||||
return f.read() | return f.read() | ||||
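
A side note on the `load_config()` hunk above: any key listed in `available_setting` can also be overridden by an environment variable of the same name (matched case-insensitively), and the value is passed through `eval()`, so Python literals such as booleans or lists are parsed; plain strings appear to be kept as-is in the `except` branch that is cut off in this excerpt. A rough usage sketch, assuming the script is run from the project root so `config.json` / `config-template.json` can be found:

```python
# Rough sketch only: overriding config keys via environment variables
# before load_config() runs. Paths assume the project root as working directory.
import os

from config import conf, load_config

os.environ["DEBUG"] = "True"                          # eval("True") -> True
os.environ["SINGLE_CHAT_PREFIX"] = '["bot", "@bot"]'  # eval(...) -> list

load_config()
print(conf().get("debug"), conf().get("single_chat_prefix"))
```
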
@@ -33,7 +33,7 @@ ADD ./entrypoint.sh /entrypoint.sh | |||||
RUN chmod +x /entrypoint.sh \ | RUN chmod +x /entrypoint.sh \ | ||||
&& groupadd -r noroot \ | && groupadd -r noroot \ | ||||
&& useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \ | && useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \ | ||||
&& chown -R noroot:noroot ${BUILD_PREFIX} | |||||
&& chown -R noroot:noroot ${BUILD_PREFIX} | |||||
USER noroot | USER noroot | ||||
@@ -18,7 +18,7 @@ RUN apt-get update \ | |||||
&& pip install --no-cache -r requirements.txt \ | && pip install --no-cache -r requirements.txt \ | ||||
&& pip install --no-cache -r requirements-optional.txt \ | && pip install --no-cache -r requirements-optional.txt \ | ||||
&& pip install azure-cognitiveservices-speech | && pip install azure-cognitiveservices-speech | ||||
WORKDIR ${BUILD_PREFIX} | WORKDIR ${BUILD_PREFIX} | ||||
ADD docker/entrypoint.sh /entrypoint.sh | ADD docker/entrypoint.sh /entrypoint.sh | ||||
@@ -11,6 +11,5 @@ docker build -f Dockerfile.alpine \ | |||||
-t zhayujie/chatgpt-on-wechat . | -t zhayujie/chatgpt-on-wechat . | ||||
# tag image | # tag image | ||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:alpine | |||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:alpine | |||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-alpine | docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-alpine | ||||
@@ -11,5 +11,5 @@ docker build -f Dockerfile.debian \ | |||||
-t zhayujie/chatgpt-on-wechat . | -t zhayujie/chatgpt-on-wechat . | ||||
# tag image | # tag image | ||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:debian | |||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:debian | |||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-debian | docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-debian |
@@ -9,7 +9,7 @@ RUN apk add --no-cache \ | |||||
ffmpeg \ | ffmpeg \ | ||||
espeak \ | espeak \ | ||||
&& pip install --no-cache \ | && pip install --no-cache \ | ||||
baidu-aip \ | |||||
baidu-aip \ | |||||
chardet \ | chardet \ | ||||
SpeechRecognition | SpeechRecognition | ||||
@@ -10,7 +10,7 @@ RUN apt-get update \ | |||||
ffmpeg \ | ffmpeg \ | ||||
espeak \ | espeak \ | ||||
&& pip install --no-cache \ | && pip install --no-cache \ | ||||
baidu-aip \ | |||||
baidu-aip \ | |||||
chardet \ | chardet \ | ||||
SpeechRecognition | SpeechRecognition | ||||
@@ -11,13 +11,13 @@ run_d: | |||||
docker rm $(CONTAINER_NAME) || echo | docker rm $(CONTAINER_NAME) || echo | ||||
docker run -dt --name $(CONTAINER_NAME) $(PORT_MAP) \ | docker run -dt --name $(CONTAINER_NAME) $(PORT_MAP) \ | ||||
--env-file=$(DOTENV) \ | --env-file=$(DOTENV) \ | ||||
$(MOUNT) $(IMG) | |||||
$(MOUNT) $(IMG) | |||||
run_i: | run_i: | ||||
docker rm $(CONTAINER_NAME) || echo | docker rm $(CONTAINER_NAME) || echo | ||||
docker run -it --name $(CONTAINER_NAME) $(PORT_MAP) \ | docker run -it --name $(CONTAINER_NAME) $(PORT_MAP) \ | ||||
--env-file=$(DOTENV) \ | --env-file=$(DOTENV) \ | ||||
$(MOUNT) $(IMG) | |||||
$(MOUNT) $(IMG) | |||||
stop: | stop: | ||||
docker stop $(CONTAINER_NAME) | docker stop $(CONTAINER_NAME) | ||||
@@ -24,17 +24,17 @@ | |||||
在本仓库中预置了一些插件,如果要安装其他仓库的插件,有两种方法。 | 在本仓库中预置了一些插件,如果要安装其他仓库的插件,有两种方法。 | ||||
- 第一种方法是在将下载的插件文件都解压到"plugins"文件夹的一个单独的文件夹,最终插件的代码都位于"plugins/PLUGIN_NAME/*"中。启动程序后,如果插件的目录结构正确,插件会自动被扫描加载。除此以外,注意你还需要安装文件夹中`requirements.txt`中的依赖。 | - 第一种方法是在将下载的插件文件都解压到"plugins"文件夹的一个单独的文件夹,最终插件的代码都位于"plugins/PLUGIN_NAME/*"中。启动程序后,如果插件的目录结构正确,插件会自动被扫描加载。除此以外,注意你还需要安装文件夹中`requirements.txt`中的依赖。 | ||||
- 第二种方法是`Godcmd`插件,它是预置的管理员插件,能够让程序在运行时就能安装插件,它能够自动安装依赖。 | - 第二种方法是`Godcmd`插件,它是预置的管理员插件,能够让程序在运行时就能安装插件,它能够自动安装依赖。 | ||||
安装插件的命令是"#installp [仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件名/仓库地址"。这是管理员命令,认证方法在[这里](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/godcmd)。 | 安装插件的命令是"#installp [仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件名/仓库地址"。这是管理员命令,认证方法在[这里](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/godcmd)。 | ||||
- 安装[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件:#installp sdwebui | - 安装[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件:#installp sdwebui | ||||
- 安装指定仓库的插件:#installp https://github.com/lanvent/plugin_sdwebui.git | - 安装指定仓库的插件:#installp https://github.com/lanvent/plugin_sdwebui.git | ||||
在安装之后,需要执行"#scanp"命令来扫描加载新安装的插件(或者重新启动程序)。 | 在安装之后,需要执行"#scanp"命令来扫描加载新安装的插件(或者重新启动程序)。 | ||||
安装插件后需要注意有些插件有自己的配置模板,一般要去掉".template"新建一个配置文件。 | 安装插件后需要注意有些插件有自己的配置模板,一般要去掉".template"新建一个配置文件。 | ||||
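
Putting the second install path above together, a typical `Godcmd` session looks roughly like the following. The plugin name `sdwebui` is only the example already given above, the password is whatever you set in the Godcmd config, and `#auth` is needed once per the authentication doc linked above:

```
#auth 你的管理员口令     (认证为管理员,口令见 godcmd 配置)
#installp sdwebui       (安装仓库源中登记的插件,也可填仓库地址)
#scanp                  (扫描并加载新安装的插件,或直接重启程序)
```
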
## 插件化实现 | ## 插件化实现 | ||||
@@ -107,14 +107,14 @@ | |||||
``` | ``` | ||||
回复`Reply`的定义如下所示,它允许Bot可以回复多类不同的消息。同时也加入了`INFO`和`ERROR`消息类型区分系统提示和系统错误。 | 回复`Reply`的定义如下所示,它允许Bot可以回复多类不同的消息。同时也加入了`INFO`和`ERROR`消息类型区分系统提示和系统错误。 | ||||
```python | ```python | ||||
class ReplyType(Enum): | class ReplyType(Enum): | ||||
TEXT = 1 # 文本 | TEXT = 1 # 文本 | ||||
VOICE = 2 # 音频文件 | VOICE = 2 # 音频文件 | ||||
IMAGE = 3 # 图片文件 | IMAGE = 3 # 图片文件 | ||||
IMAGE_URL = 4 # 图片URL | IMAGE_URL = 4 # 图片URL | ||||
INFO = 9 | INFO = 9 | ||||
ERROR = 10 | ERROR = 10 | ||||
class Reply: | class Reply: | ||||
@@ -159,12 +159,12 @@ | |||||
目前支持三类触发事件: | 目前支持三类触发事件: | ||||
``` | ``` | ||||
1.收到消息 | |||||
---> `ON_HANDLE_CONTEXT` | |||||
2.产生回复 | |||||
---> `ON_DECORATE_REPLY` | |||||
3.装饰回复 | |||||
---> `ON_SEND_REPLY` | |||||
1.收到消息 | |||||
---> `ON_HANDLE_CONTEXT` | |||||
2.产生回复 | |||||
---> `ON_DECORATE_REPLY` | |||||
3.装饰回复 | |||||
---> `ON_SEND_REPLY` | |||||
4.发送回复 | 4.发送回复 | ||||
``` | ``` | ||||
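
As a concrete reference for the event flow above, here is a minimal plugin skeleton assembled from the `Hello`/`Finish` plugins elsewhere in this diff. The plugin name `Echo`, its description, and the author string are placeholders, not an existing plugin in the repository:

```python
# Minimal plugin sketch based on the Hello/Finish plugins shown in this diff.
# "Echo" is a placeholder plugin, not part of the repository.
import plugins
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from plugins import *


@plugins.register(
    name="Echo",
    desire_priority=0,
    hidden=False,
    desc="A toy plugin that echoes text messages",
    version="0.1",
    author="example",
)
class Echo(Plugin):
    def __init__(self):
        super().__init__()
        # 1.收到消息 ---> ON_HANDLE_CONTEXT:在这里决定是否接管回复
        self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
        logger.info("[Echo] inited")

    def on_handle_context(self, e_context: EventContext):
        if e_context["context"].type != ContextType.TEXT:
            return
        reply = Reply(ReplyType.TEXT, e_context["context"].content)
        e_context["reply"] = reply
        e_context.action = EventAction.BREAK_PASS  # 事件结束,并跳过默认逻辑

    def get_help_text(self, **kwargs):
        return "回显收到的文本消息。"
```
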
@@ -268,6 +268,6 @@ class Hello(Plugin): | |||||
- 一个插件目录建议只注册一个插件类。建议使用单独的仓库维护插件,便于更新。 | - 一个插件目录建议只注册一个插件类。建议使用单独的仓库维护插件,便于更新。 | ||||
在测试调试好后提交`PR`,把自己的仓库加入到[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)中。 | 在测试调试好后提交`PR`,把自己的仓库加入到[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)中。 | ||||
- 插件的config文件、使用说明`README.md`、`requirement.txt`等放置在插件目录中。 | - 插件的config文件、使用说明`README.md`、`requirement.txt`等放置在插件目录中。 | ||||
- 默认优先级不要超过管理员插件`Godcmd`的优先级(999),`Godcmd`插件提供了配置管理、插件管理等功能。 | - 默认优先级不要超过管理员插件`Godcmd`的优先级(999),`Godcmd`插件提供了配置管理、插件管理等功能。 |
@@ -1,9 +1,9 @@ | |||||
from .plugin_manager import PluginManager | |||||
from .event import * | from .event import * | ||||
from .plugin import * | from .plugin import * | ||||
from .plugin_manager import PluginManager | |||||
instance = PluginManager() | instance = PluginManager() | ||||
register = instance.register | |||||
register = instance.register | |||||
# load_plugins = instance.load_plugins | # load_plugins = instance.load_plugins | ||||
# emit_event = instance.emit_event | # emit_event = instance.emit_event |
@@ -1 +1 @@ | |||||
from .banwords import * | |||||
from .banwords import * |
@@ -2,56 +2,67 @@ | |||||
import json | import json | ||||
import os | import os | ||||
import plugins | |||||
from bridge.context import ContextType | from bridge.context import ContextType | ||||
from bridge.reply import Reply, ReplyType | from bridge.reply import Reply, ReplyType | ||||
import plugins | |||||
from plugins import * | |||||
from common.log import logger | from common.log import logger | ||||
from plugins import * | |||||
from .lib.WordsSearch import WordsSearch | from .lib.WordsSearch import WordsSearch | ||||
@plugins.register(name="Banwords", desire_priority=100, hidden=True, desc="判断消息中是否有敏感词、决定是否回复。", version="1.0", author="lanvent") | |||||
@plugins.register( | |||||
name="Banwords", | |||||
desire_priority=100, | |||||
hidden=True, | |||||
desc="判断消息中是否有敏感词、决定是否回复。", | |||||
version="1.0", | |||||
author="lanvent", | |||||
) | |||||
class Banwords(Plugin): | class Banwords(Plugin): | ||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
try: | try: | ||||
curdir=os.path.dirname(__file__) | |||||
config_path=os.path.join(curdir,"config.json") | |||||
conf=None | |||||
curdir = os.path.dirname(__file__) | |||||
config_path = os.path.join(curdir, "config.json") | |||||
conf = None | |||||
if not os.path.exists(config_path): | if not os.path.exists(config_path): | ||||
conf={"action":"ignore"} | |||||
with open(config_path,"w") as f: | |||||
json.dump(conf,f,indent=4) | |||||
conf = {"action": "ignore"} | |||||
with open(config_path, "w") as f: | |||||
json.dump(conf, f, indent=4) | |||||
else: | else: | ||||
with open(config_path,"r") as f: | |||||
conf=json.load(f) | |||||
with open(config_path, "r") as f: | |||||
conf = json.load(f) | |||||
self.searchr = WordsSearch() | self.searchr = WordsSearch() | ||||
self.action = conf["action"] | self.action = conf["action"] | ||||
banwords_path = os.path.join(curdir,"banwords.txt") | |||||
with open(banwords_path, 'r', encoding='utf-8') as f: | |||||
words=[] | |||||
banwords_path = os.path.join(curdir, "banwords.txt") | |||||
with open(banwords_path, "r", encoding="utf-8") as f: | |||||
words = [] | |||||
for line in f: | for line in f: | ||||
word = line.strip() | word = line.strip() | ||||
if word: | if word: | ||||
words.append(word) | words.append(word) | ||||
self.searchr.SetKeywords(words) | self.searchr.SetKeywords(words) | ||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context | self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context | ||||
if conf.get("reply_filter",True): | |||||
if conf.get("reply_filter", True): | |||||
self.handlers[Event.ON_DECORATE_REPLY] = self.on_decorate_reply | self.handlers[Event.ON_DECORATE_REPLY] = self.on_decorate_reply | ||||
self.reply_action = conf.get("reply_action","ignore") | |||||
self.reply_action = conf.get("reply_action", "ignore") | |||||
logger.info("[Banwords] inited") | logger.info("[Banwords] inited") | ||||
except Exception as e: | except Exception as e: | ||||
logger.warn("[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords .") | |||||
logger.warn( | |||||
"[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords ." | |||||
) | |||||
raise e | raise e | ||||
def on_handle_context(self, e_context: EventContext): | def on_handle_context(self, e_context: EventContext): | ||||
if e_context['context'].type not in [ContextType.TEXT,ContextType.IMAGE_CREATE]: | |||||
if e_context["context"].type not in [ | |||||
ContextType.TEXT, | |||||
ContextType.IMAGE_CREATE, | |||||
]: | |||||
return | return | ||||
content = e_context['context'].content | |||||
content = e_context["context"].content | |||||
logger.debug("[Banwords] on_handle_context. content: %s" % content) | logger.debug("[Banwords] on_handle_context. content: %s" % content) | ||||
if self.action == "ignore": | if self.action == "ignore": | ||||
f = self.searchr.FindFirst(content) | f = self.searchr.FindFirst(content) | ||||
@@ -61,31 +72,34 @@ class Banwords(Plugin): | |||||
return | return | ||||
elif self.action == "replace": | elif self.action == "replace": | ||||
if self.searchr.ContainsAny(content): | if self.searchr.ContainsAny(content): | ||||
reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n"+self.searchr.Replace(content)) | |||||
e_context['reply'] = reply | |||||
reply = Reply( | |||||
ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content) | |||||
) | |||||
e_context["reply"] = reply | |||||
e_context.action = EventAction.BREAK_PASS | e_context.action = EventAction.BREAK_PASS | ||||
return | return | ||||
def on_decorate_reply(self, e_context: EventContext): | |||||
if e_context['reply'].type not in [ReplyType.TEXT]: | |||||
def on_decorate_reply(self, e_context: EventContext): | |||||
if e_context["reply"].type not in [ReplyType.TEXT]: | |||||
return | return | ||||
reply = e_context['reply'] | |||||
reply = e_context["reply"] | |||||
content = reply.content | content = reply.content | ||||
if self.reply_action == "ignore": | if self.reply_action == "ignore": | ||||
f = self.searchr.FindFirst(content) | f = self.searchr.FindFirst(content) | ||||
if f: | if f: | ||||
logger.info("[Banwords] %s in reply" % f["Keyword"]) | logger.info("[Banwords] %s in reply" % f["Keyword"]) | ||||
e_context['reply'] = None | |||||
e_context["reply"] = None | |||||
e_context.action = EventAction.BREAK_PASS | e_context.action = EventAction.BREAK_PASS | ||||
return | return | ||||
elif self.reply_action == "replace": | elif self.reply_action == "replace": | ||||
if self.searchr.ContainsAny(content): | if self.searchr.ContainsAny(content): | ||||
reply = Reply(ReplyType.INFO, "已替换回复中的敏感词: \n"+self.searchr.Replace(content)) | |||||
e_context['reply'] = reply | |||||
reply = Reply( | |||||
ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content) | |||||
) | |||||
e_context["reply"] = reply | |||||
e_context.action = EventAction.CONTINUE | e_context.action = EventAction.CONTINUE | ||||
return | return | ||||
def get_help_text(self, **kwargs): | def get_help_text(self, **kwargs): | ||||
return Banwords.desc | |||||
return Banwords.desc |
@@ -1,5 +1,5 @@ | |||||
{ | { | ||||
"action": "replace", | |||||
"reply_filter": true, | |||||
"reply_action": "ignore" | |||||
} | |||||
"action": "replace", | |||||
"reply_filter": true, | |||||
"reply_action": "ignore" | |||||
} |
@@ -24,7 +24,7 @@ see https://ai.baidu.com/unit/home#/home?track=61fe1b0d3407ce3face1d92cb5c291087 | |||||
``` json | ``` json | ||||
{ | { | ||||
"service_id": "s...", #"机器人ID" | "service_id": "s...", #"机器人ID" | ||||
"api_key": "", | |||||
"api_key": "", | |||||
"secret_key": "" | "secret_key": "" | ||||
} | } | ||||
``` | ``` |
@@ -1 +1 @@ | |||||
from .bdunit import * | |||||
from .bdunit import * |
@@ -2,21 +2,29 @@ | |||||
import json | import json | ||||
import os | import os | ||||
import uuid | import uuid | ||||
from uuid import getnode as get_mac | |||||
import requests | import requests | ||||
import plugins | |||||
from bridge.context import ContextType | from bridge.context import ContextType | ||||
from bridge.reply import Reply, ReplyType | from bridge.reply import Reply, ReplyType | ||||
from common.log import logger | from common.log import logger | ||||
import plugins | |||||
from plugins import * | from plugins import * | ||||
from uuid import getnode as get_mac | |||||
"""利用百度UNIT实现智能对话 | """利用百度UNIT实现智能对话 | ||||
如果命中意图,返回意图对应的回复,否则返回继续交付给下个插件处理 | 如果命中意图,返回意图对应的回复,否则返回继续交付给下个插件处理 | ||||
""" | """ | ||||
@plugins.register(name="BDunit", desire_priority=0, hidden=True, desc="Baidu unit bot system", version="0.1", author="jackson") | |||||
@plugins.register( | |||||
name="BDunit", | |||||
desire_priority=0, | |||||
hidden=True, | |||||
desc="Baidu unit bot system", | |||||
version="0.1", | |||||
author="jackson", | |||||
) | |||||
class BDunit(Plugin): | class BDunit(Plugin): | ||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
@@ -40,11 +48,10 @@ class BDunit(Plugin): | |||||
raise e | raise e | ||||
def on_handle_context(self, e_context: EventContext): | def on_handle_context(self, e_context: EventContext): | ||||
if e_context['context'].type != ContextType.TEXT: | |||||
if e_context["context"].type != ContextType.TEXT: | |||||
return | return | ||||
content = e_context['context'].content | |||||
content = e_context["context"].content | |||||
logger.debug("[BDunit] on_handle_context. content: %s" % content) | logger.debug("[BDunit] on_handle_context. content: %s" % content) | ||||
parsed = self.getUnit2(content) | parsed = self.getUnit2(content) | ||||
intent = self.getIntent(parsed) | intent = self.getIntent(parsed) | ||||
@@ -53,7 +60,7 @@ class BDunit(Plugin): | |||||
reply = Reply() | reply = Reply() | ||||
reply.type = ReplyType.TEXT | reply.type = ReplyType.TEXT | ||||
reply.content = self.getSay(parsed) | reply.content = self.getSay(parsed) | ||||
e_context['reply'] = reply | |||||
e_context["reply"] = reply | |||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | ||||
else: | else: | ||||
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 | e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 | ||||
@@ -70,17 +77,15 @@ class BDunit(Plugin): | |||||
string: access_token | string: access_token | ||||
""" | """ | ||||
url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format( | url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format( | ||||
self.api_key, self.secret_key) | |||||
self.api_key, self.secret_key | |||||
) | |||||
payload = "" | payload = "" | ||||
headers = { | |||||
'Content-Type': 'application/json', | |||||
'Accept': 'application/json' | |||||
} | |||||
headers = {"Content-Type": "application/json", "Accept": "application/json"} | |||||
response = requests.request("POST", url, headers=headers, data=payload) | response = requests.request("POST", url, headers=headers, data=payload) | ||||
# print(response.text) | # print(response.text) | ||||
return response.json()['access_token'] | |||||
return response.json()["access_token"] | |||||
def getUnit(self, query): | def getUnit(self, query): | ||||
""" | """ | ||||
@@ -90,11 +95,14 @@ class BDunit(Plugin): | |||||
""" | """ | ||||
url = ( | url = ( | ||||
'https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=' | |||||
"https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" | |||||
+ self.access_token | + self.access_token | ||||
) | ) | ||||
request = {"query": query, "user_id": str( | |||||
get_mac())[:32], "terminal_id": "88888"} | |||||
request = { | |||||
"query": query, | |||||
"user_id": str(get_mac())[:32], | |||||
"terminal_id": "88888", | |||||
} | |||||
body = { | body = { | ||||
"log_id": str(uuid.uuid1()), | "log_id": str(uuid.uuid1()), | ||||
"version": "3.0", | "version": "3.0", | ||||
@@ -142,11 +150,7 @@ class BDunit(Plugin): | |||||
:param parsed: UNIT 解析结果 | :param parsed: UNIT 解析结果 | ||||
:returns: 意图数组 | :returns: 意图数组 | ||||
""" | """ | ||||
if ( | |||||
parsed | |||||
and "result" in parsed | |||||
and "response_list" in parsed["result"] | |||||
): | |||||
if parsed and "result" in parsed and "response_list" in parsed["result"]: | |||||
try: | try: | ||||
return parsed["result"]["response_list"][0]["schema"]["intent"] | return parsed["result"]["response_list"][0]["schema"]["intent"] | ||||
except Exception as e: | except Exception as e: | ||||
@@ -163,11 +167,7 @@ class BDunit(Plugin): | |||||
:param intent: 意图的名称 | :param intent: 意图的名称 | ||||
:returns: True: 包含; False: 不包含 | :returns: True: 包含; False: 不包含 | ||||
""" | """ | ||||
if ( | |||||
parsed | |||||
and "result" in parsed | |||||
and "response_list" in parsed["result"] | |||||
): | |||||
if parsed and "result" in parsed and "response_list" in parsed["result"]: | |||||
response_list = parsed["result"]["response_list"] | response_list = parsed["result"]["response_list"] | ||||
for response in response_list: | for response in response_list: | ||||
if ( | if ( | ||||
@@ -189,11 +189,7 @@ class BDunit(Plugin): | |||||
:returns: 词槽列表。你可以通过 name 属性筛选词槽, | :returns: 词槽列表。你可以通过 name 属性筛选词槽, | ||||
再通过 normalized_word 属性取出相应的值 | 再通过 normalized_word 属性取出相应的值 | ||||
""" | """ | ||||
if ( | |||||
parsed | |||||
and "result" in parsed | |||||
and "response_list" in parsed["result"] | |||||
): | |||||
if parsed and "result" in parsed and "response_list" in parsed["result"]: | |||||
response_list = parsed["result"]["response_list"] | response_list = parsed["result"]["response_list"] | ||||
if intent == "": | if intent == "": | ||||
try: | try: | ||||
@@ -236,11 +232,7 @@ class BDunit(Plugin): | |||||
:param parsed: UNIT 解析结果 | :param parsed: UNIT 解析结果 | ||||
:returns: UNIT 的回复文本 | :returns: UNIT 的回复文本 | ||||
""" | """ | ||||
if ( | |||||
parsed | |||||
and "result" in parsed | |||||
and "response_list" in parsed["result"] | |||||
): | |||||
if parsed and "result" in parsed and "response_list" in parsed["result"]: | |||||
response_list = parsed["result"]["response_list"] | response_list = parsed["result"]["response_list"] | ||||
answer = {} | answer = {} | ||||
for response in response_list: | for response in response_list: | ||||
@@ -266,11 +258,7 @@ class BDunit(Plugin): | |||||
:param intent: 意图的名称 | :param intent: 意图的名称 | ||||
:returns: UNIT 的回复文本 | :returns: UNIT 的回复文本 | ||||
""" | """ | ||||
if ( | |||||
parsed | |||||
and "result" in parsed | |||||
and "response_list" in parsed["result"] | |||||
): | |||||
if parsed and "result" in parsed and "response_list" in parsed["result"]: | |||||
response_list = parsed["result"]["response_list"] | response_list = parsed["result"]["response_list"] | ||||
if intent == "": | if intent == "": | ||||
try: | try: | ||||
@@ -1,5 +1,5 @@ | |||||
{ | { | ||||
"service_id": "s...", | |||||
"api_key": "", | |||||
"secret_key": "" | |||||
} | |||||
"service_id": "s...", | |||||
"api_key": "", | |||||
"secret_key": "" | |||||
} |
@@ -1 +1 @@ | |||||
from .dungeon import * | |||||
from .dungeon import * |
@@ -1,17 +1,18 @@ | |||||
# encoding:utf-8 | # encoding:utf-8 | ||||
import plugins | |||||
from bridge.bridge import Bridge | from bridge.bridge import Bridge | ||||
from bridge.context import ContextType | from bridge.context import ContextType | ||||
from bridge.reply import Reply, ReplyType | from bridge.reply import Reply, ReplyType | ||||
from common import const | |||||
from common.expired_dict import ExpiredDict | from common.expired_dict import ExpiredDict | ||||
from common.log import logger | |||||
from config import conf | from config import conf | ||||
import plugins | |||||
from plugins import * | from plugins import * | ||||
from common.log import logger | |||||
from common import const | |||||
# https://github.com/bupticybee/ChineseAiDungeonChatGPT | # https://github.com/bupticybee/ChineseAiDungeonChatGPT | ||||
class StoryTeller(): | |||||
class StoryTeller: | |||||
def __init__(self, bot, sessionid, story): | def __init__(self, bot, sessionid, story): | ||||
self.bot = bot | self.bot = bot | ||||
self.sessionid = sessionid | self.sessionid = sessionid | ||||
@@ -27,67 +28,85 @@ class StoryTeller(): | |||||
if user_action[-1] != "。": | if user_action[-1] != "。": | ||||
user_action = user_action + "。" | user_action = user_action + "。" | ||||
if self.first_interact: | if self.first_interact: | ||||
prompt = """现在来充当一个文字冒险游戏,描述时候注意节奏,不要太快,仔细描述各个人物的心情和周边环境。一次只需写四到六句话。 | |||||
开头是,""" + self.story + " " + user_action | |||||
prompt = ( | |||||
"""现在来充当一个文字冒险游戏,描述时候注意节奏,不要太快,仔细描述各个人物的心情和周边环境。一次只需写四到六句话。 | |||||
开头是,""" | |||||
+ self.story | |||||
+ " " | |||||
+ user_action | |||||
) | |||||
self.first_interact = False | self.first_interact = False | ||||
else: | else: | ||||
prompt = """继续,一次只需要续写四到六句话,总共就只讲5分钟内发生的事情。""" + user_action | prompt = """继续,一次只需要续写四到六句话,总共就只讲5分钟内发生的事情。""" + user_action | ||||
return prompt | return prompt | ||||
@plugins.register(name="Dungeon", desire_priority=0, namecn="文字冒险", desc="A plugin to play dungeon game", version="1.0", author="lanvent") | |||||
@plugins.register( | |||||
name="Dungeon", | |||||
desire_priority=0, | |||||
namecn="文字冒险", | |||||
desc="A plugin to play dungeon game", | |||||
version="1.0", | |||||
author="lanvent", | |||||
) | |||||
class Dungeon(Plugin): | class Dungeon(Plugin): | ||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context | self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context | ||||
logger.info("[Dungeon] inited") | logger.info("[Dungeon] inited") | ||||
# 目前没有设计session过期事件,这里先暂时使用过期字典 | # 目前没有设计session过期事件,这里先暂时使用过期字典 | ||||
if conf().get('expires_in_seconds'): | |||||
self.games = ExpiredDict(conf().get('expires_in_seconds')) | |||||
if conf().get("expires_in_seconds"): | |||||
self.games = ExpiredDict(conf().get("expires_in_seconds")) | |||||
else: | else: | ||||
self.games = dict() | self.games = dict() | ||||
def on_handle_context(self, e_context: EventContext): | def on_handle_context(self, e_context: EventContext): | ||||
if e_context['context'].type != ContextType.TEXT: | |||||
if e_context["context"].type != ContextType.TEXT: | |||||
return | return | ||||
bottype = Bridge().get_bot_type("chat") | bottype = Bridge().get_bot_type("chat") | ||||
if bottype not in (const.CHATGPT, const.OPEN_AI): | if bottype not in (const.CHATGPT, const.OPEN_AI): | ||||
return | return | ||||
bot = Bridge().get_bot("chat") | bot = Bridge().get_bot("chat") | ||||
content = e_context['context'].content[:] | |||||
clist = e_context['context'].content.split(maxsplit=1) | |||||
sessionid = e_context['context']['session_id'] | |||||
content = e_context["context"].content[:] | |||||
clist = e_context["context"].content.split(maxsplit=1) | |||||
sessionid = e_context["context"]["session_id"] | |||||
logger.debug("[Dungeon] on_handle_context. content: %s" % clist) | logger.debug("[Dungeon] on_handle_context. content: %s" % clist) | ||||
trigger_prefix = conf().get('plugin_trigger_prefix', "$") | |||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$") | |||||
if clist[0] == f"{trigger_prefix}停止冒险": | if clist[0] == f"{trigger_prefix}停止冒险": | ||||
if sessionid in self.games: | if sessionid in self.games: | ||||
self.games[sessionid].reset() | self.games[sessionid].reset() | ||||
del self.games[sessionid] | del self.games[sessionid] | ||||
reply = Reply(ReplyType.INFO, "冒险结束!") | reply = Reply(ReplyType.INFO, "冒险结束!") | ||||
e_context['reply'] = reply | |||||
e_context["reply"] = reply | |||||
e_context.action = EventAction.BREAK_PASS | e_context.action = EventAction.BREAK_PASS | ||||
elif clist[0] == f"{trigger_prefix}开始冒险" or sessionid in self.games: | elif clist[0] == f"{trigger_prefix}开始冒险" or sessionid in self.games: | ||||
if sessionid not in self.games or clist[0] == f"{trigger_prefix}开始冒险": | if sessionid not in self.games or clist[0] == f"{trigger_prefix}开始冒险": | ||||
if len(clist)>1 : | |||||
if len(clist) > 1: | |||||
story = clist[1] | story = clist[1] | ||||
else: | else: | ||||
story = "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。" | |||||
story = ( | |||||
"你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。" | |||||
) | |||||
self.games[sessionid] = StoryTeller(bot, sessionid, story) | self.games[sessionid] = StoryTeller(bot, sessionid, story) | ||||
reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story) | reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story) | ||||
e_context['reply'] = reply | |||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | |||||
e_context["reply"] = reply | |||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | |||||
else: | else: | ||||
prompt = self.games[sessionid].action(content) | prompt = self.games[sessionid].action(content) | ||||
e_context['context'].type = ContextType.TEXT | |||||
e_context['context'].content = prompt | |||||
e_context.action = EventAction.BREAK # 事件结束,不跳过处理context的默认逻辑 | |||||
e_context["context"].type = ContextType.TEXT | |||||
e_context["context"].content = prompt | |||||
e_context.action = EventAction.BREAK # 事件结束,不跳过处理context的默认逻辑 | |||||
def get_help_text(self, **kwargs): | def get_help_text(self, **kwargs): | ||||
help_text = "可以和机器人一起玩文字冒险游戏。\n" | help_text = "可以和机器人一起玩文字冒险游戏。\n" | ||||
if kwargs.get('verbose') != True: | |||||
if kwargs.get("verbose") != True: | |||||
return help_text | return help_text | ||||
trigger_prefix = conf().get('plugin_trigger_prefix', "$") | |||||
help_text = f"{trigger_prefix}开始冒险 "+"背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n"+f"{trigger_prefix}停止冒险: 结束游戏。\n" | |||||
if kwargs.get('verbose') == True: | |||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$") | |||||
help_text = ( | |||||
f"{trigger_prefix}开始冒险 " | |||||
+ "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n" | |||||
+ f"{trigger_prefix}停止冒险: 结束游戏。\n" | |||||
) | |||||
if kwargs.get("verbose") == True: | |||||
help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'" | help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'" | ||||
return help_text | |||||
return help_text |
@@ -9,17 +9,17 @@ class Event(Enum): | |||||
e_context = { "channel": 消息channel, "context" : 本次消息的context} | e_context = { "channel": 消息channel, "context" : 本次消息的context} | ||||
""" | """ | ||||
ON_HANDLE_CONTEXT = 2 # 处理消息前 | |||||
ON_HANDLE_CONTEXT = 2 # 处理消息前 | |||||
""" | """ | ||||
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空 } | e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空 } | ||||
""" | """ | ||||
ON_DECORATE_REPLY = 3 # 得到回复后准备装饰 | |||||
ON_DECORATE_REPLY = 3 # 得到回复后准备装饰 | |||||
""" | """ | ||||
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 } | e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 } | ||||
""" | """ | ||||
ON_SEND_REPLY = 4 # 发送回复前 | |||||
ON_SEND_REPLY = 4 # 发送回复前 | |||||
""" | """ | ||||
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 } | e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 } | ||||
""" | """ | ||||
@@ -28,9 +28,9 @@ class Event(Enum): | |||||
class EventAction(Enum): | class EventAction(Enum): | ||||
CONTINUE = 1 # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑 | |||||
BREAK = 2 # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑 | |||||
BREAK_PASS = 3 # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑 | |||||
CONTINUE = 1 # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑 | |||||
BREAK = 2 # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑 | |||||
BREAK_PASS = 3 # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑 | |||||
class EventContext: | class EventContext: | ||||
@@ -1 +1 @@ | |||||
from .finish import * | |||||
from .finish import * |
@@ -1,14 +1,21 @@ | |||||
# encoding:utf-8 | # encoding:utf-8 | ||||
import plugins | |||||
from bridge.context import ContextType | from bridge.context import ContextType | ||||
from bridge.reply import Reply, ReplyType | from bridge.reply import Reply, ReplyType | ||||
from common.log import logger | |||||
from config import conf | from config import conf | ||||
import plugins | |||||
from plugins import * | from plugins import * | ||||
from common.log import logger | |||||
@plugins.register(name="Finish", desire_priority=-999, hidden=True, desc="A plugin that check unknown command", version="1.0", author="js00000") | |||||
@plugins.register( | |||||
name="Finish", | |||||
desire_priority=-999, | |||||
hidden=True, | |||||
desc="A plugin that check unknown command", | |||||
version="1.0", | |||||
author="js00000", | |||||
) | |||||
class Finish(Plugin): | class Finish(Plugin): | ||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
@@ -16,19 +23,18 @@ class Finish(Plugin): | |||||
logger.info("[Finish] inited") | logger.info("[Finish] inited") | ||||
def on_handle_context(self, e_context: EventContext): | def on_handle_context(self, e_context: EventContext): | ||||
if e_context['context'].type != ContextType.TEXT: | |||||
if e_context["context"].type != ContextType.TEXT: | |||||
return | return | ||||
content = e_context['context'].content | |||||
content = e_context["context"].content | |||||
logger.debug("[Finish] on_handle_context. content: %s" % content) | logger.debug("[Finish] on_handle_context. content: %s" % content) | ||||
trigger_prefix = conf().get('plugin_trigger_prefix',"$") | |||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$") | |||||
if content.startswith(trigger_prefix): | if content.startswith(trigger_prefix): | ||||
reply = Reply() | reply = Reply() | ||||
reply.type = ReplyType.ERROR | reply.type = ReplyType.ERROR | ||||
reply.content = "未知插件命令\n查看插件命令列表请输入#help 插件名\n" | reply.content = "未知插件命令\n查看插件命令列表请输入#help 插件名\n" | ||||
e_context['reply'] = reply | |||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | |||||
e_context["reply"] = reply | |||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | |||||
def get_help_text(self, **kwargs): | def get_help_text(self, **kwargs): | ||||
return "" | return "" |
@@ -1 +1 @@ | |||||
from .godcmd import * | |||||
from .godcmd import * |
@@ -1,4 +1,4 @@ | |||||
{ | { | ||||
"password": "", | |||||
"admin_users": [] | |||||
} | |||||
"password": "", | |||||
"admin_users": [] | |||||
} |
@@ -6,14 +6,16 @@ import random | |||||
import string | import string | ||||
import traceback | import traceback | ||||
from typing import Tuple | from typing import Tuple | ||||
import plugins | |||||
from bridge.bridge import Bridge | from bridge.bridge import Bridge | ||||
from bridge.context import ContextType | from bridge.context import ContextType | ||||
from bridge.reply import Reply, ReplyType | from bridge.reply import Reply, ReplyType | ||||
from config import conf, load_config | |||||
import plugins | |||||
from plugins import * | |||||
from common import const | from common import const | ||||
from common.log import logger | from common.log import logger | ||||
from config import conf, load_config | |||||
from plugins import * | |||||
# 定义指令集 | # 定义指令集 | ||||
COMMANDS = { | COMMANDS = { | ||||
"help": { | "help": { | ||||
@@ -41,7 +43,7 @@ COMMANDS = { | |||||
}, | }, | ||||
"id": { | "id": { | ||||
"alias": ["id", "用户"], | "alias": ["id", "用户"], | ||||
"desc": "获取用户id", # wechaty和wechatmp的用户id不会变化,可用于绑定管理员 | |||||
"desc": "获取用户id", # wechaty和wechatmp的用户id不会变化,可用于绑定管理员 | |||||
}, | }, | ||||
"reset": { | "reset": { | ||||
"alias": ["reset", "重置会话"], | "alias": ["reset", "重置会话"], | ||||
@@ -114,18 +116,20 @@ ADMIN_COMMANDS = { | |||||
"desc": "开启机器调试日志", | "desc": "开启机器调试日志", | ||||
}, | }, | ||||
} | } | ||||
# 定义帮助函数 | # 定义帮助函数 | ||||
def get_help_text(isadmin, isgroup): | def get_help_text(isadmin, isgroup): | ||||
help_text = "通用指令:\n" | help_text = "通用指令:\n" | ||||
for cmd, info in COMMANDS.items(): | for cmd, info in COMMANDS.items(): | ||||
if cmd=="auth": #不提示认证指令 | |||||
if cmd == "auth": # 不提示认证指令 | |||||
continue | continue | ||||
if cmd=="id" and conf().get("channel_type","wx") not in ["wxy","wechatmp"]: | |||||
if cmd == "id" and conf().get("channel_type", "wx") not in ["wxy", "wechatmp"]: | |||||
continue | continue | ||||
alias=["#"+a for a in info['alias'][:1]] | |||||
alias = ["#" + a for a in info["alias"][:1]] | |||||
help_text += f"{','.join(alias)} " | help_text += f"{','.join(alias)} " | ||||
if 'args' in info: | |||||
args=[a for a in info['args']] | |||||
if "args" in info: | |||||
args = [a for a in info["args"]] | |||||
help_text += f"{' '.join(args)}" | help_text += f"{' '.join(args)}" | ||||
help_text += f": {info['desc']}\n" | help_text += f": {info['desc']}\n" | ||||
@@ -135,39 +139,48 @@ def get_help_text(isadmin, isgroup): | |||||
for plugin in plugins: | for plugin in plugins: | ||||
if plugins[plugin].enabled and not plugins[plugin].hidden: | if plugins[plugin].enabled and not plugins[plugin].hidden: | ||||
namecn = plugins[plugin].namecn | namecn = plugins[plugin].namecn | ||||
help_text += "\n%s:"%namecn | |||||
help_text += PluginManager().instances[plugin].get_help_text(verbose=False).strip() | |||||
help_text += "\n%s:" % namecn | |||||
help_text += ( | |||||
PluginManager().instances[plugin].get_help_text(verbose=False).strip() | |||||
) | |||||
if ADMIN_COMMANDS and isadmin: | if ADMIN_COMMANDS and isadmin: | ||||
help_text += "\n\n管理员指令:\n" | help_text += "\n\n管理员指令:\n" | ||||
for cmd, info in ADMIN_COMMANDS.items(): | for cmd, info in ADMIN_COMMANDS.items(): | ||||
alias=["#"+a for a in info['alias'][:1]] | |||||
alias = ["#" + a for a in info["alias"][:1]] | |||||
help_text += f"{','.join(alias)} " | help_text += f"{','.join(alias)} " | ||||
if 'args' in info: | |||||
args=[a for a in info['args']] | |||||
if "args" in info: | |||||
args = [a for a in info["args"]] | |||||
help_text += f"{' '.join(args)}" | help_text += f"{' '.join(args)}" | ||||
help_text += f": {info['desc']}\n" | help_text += f": {info['desc']}\n" | ||||
return help_text | return help_text | ||||
@plugins.register(name="Godcmd", desire_priority=999, hidden=True, desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证", version="1.0", author="lanvent") | |||||
class Godcmd(Plugin): | |||||
@plugins.register( | |||||
name="Godcmd", | |||||
desire_priority=999, | |||||
hidden=True, | |||||
desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证", | |||||
version="1.0", | |||||
author="lanvent", | |||||
) | |||||
class Godcmd(Plugin): | |||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
curdir=os.path.dirname(__file__) | |||||
config_path=os.path.join(curdir,"config.json") | |||||
gconf=None | |||||
curdir = os.path.dirname(__file__) | |||||
config_path = os.path.join(curdir, "config.json") | |||||
gconf = None | |||||
if not os.path.exists(config_path): | if not os.path.exists(config_path): | ||||
gconf={"password":"","admin_users":[]} | |||||
with open(config_path,"w") as f: | |||||
json.dump(gconf,f,indent=4) | |||||
gconf = {"password": "", "admin_users": []} | |||||
with open(config_path, "w") as f: | |||||
json.dump(gconf, f, indent=4) | |||||
else: | else: | ||||
with open(config_path,"r") as f: | |||||
gconf=json.load(f) | |||||
with open(config_path, "r") as f: | |||||
gconf = json.load(f) | |||||
if gconf["password"] == "": | if gconf["password"] == "": | ||||
self.temp_password = "".join(random.sample(string.digits, 4)) | self.temp_password = "".join(random.sample(string.digits, 4)) | ||||
logger.info("[Godcmd] 因未设置口令,本次的临时口令为%s。"%self.temp_password) | |||||
logger.info("[Godcmd] 因未设置口令,本次的临时口令为%s。" % self.temp_password) | |||||
else: | else: | ||||
self.temp_password = None | self.temp_password = None | ||||
custom_commands = conf().get("clear_memory_commands", []) | custom_commands = conf().get("clear_memory_commands", []) | ||||
@@ -178,41 +191,42 @@ class Godcmd(Plugin): | |||||
COMMANDS["reset"]["alias"].append(custom_command) | COMMANDS["reset"]["alias"].append(custom_command) | ||||
self.password = gconf["password"] | self.password = gconf["password"] | ||||
self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用 | |||||
self.isrunning = True # 机器人是否运行中 | |||||
self.admin_users = gconf[ | |||||
"admin_users" | |||||
] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用 | |||||
self.isrunning = True # 机器人是否运行中 | |||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context | self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context | ||||
logger.info("[Godcmd] inited") | logger.info("[Godcmd] inited") | ||||
def on_handle_context(self, e_context: EventContext): | def on_handle_context(self, e_context: EventContext): | ||||
context_type = e_context['context'].type | |||||
context_type = e_context["context"].type | |||||
if context_type != ContextType.TEXT: | if context_type != ContextType.TEXT: | ||||
if not self.isrunning: | if not self.isrunning: | ||||
e_context.action = EventAction.BREAK_PASS | e_context.action = EventAction.BREAK_PASS | ||||
return | return | ||||
content = e_context['context'].content | |||||
content = e_context["context"].content | |||||
logger.debug("[Godcmd] on_handle_context. content: %s" % content) | logger.debug("[Godcmd] on_handle_context. content: %s" % content) | ||||
if content.startswith("#"): | if content.startswith("#"): | ||||
# msg = e_context['context']['msg'] | # msg = e_context['context']['msg'] | ||||
channel = e_context['channel'] | |||||
user = e_context['context']['receiver'] | |||||
session_id = e_context['context']['session_id'] | |||||
isgroup = e_context['context'].get("isgroup", False) | |||||
channel = e_context["channel"] | |||||
user = e_context["context"]["receiver"] | |||||
session_id = e_context["context"]["session_id"] | |||||
isgroup = e_context["context"].get("isgroup", False) | |||||
bottype = Bridge().get_bot_type("chat") | bottype = Bridge().get_bot_type("chat") | ||||
bot = Bridge().get_bot("chat") | bot = Bridge().get_bot("chat") | ||||
# 将命令和参数分割 | # 将命令和参数分割 | ||||
command_parts = content[1:].strip().split() | command_parts = content[1:].strip().split() | ||||
cmd = command_parts[0] | cmd = command_parts[0] | ||||
args = command_parts[1:] | args = command_parts[1:] | ||||
isadmin=False | |||||
isadmin = False | |||||
if user in self.admin_users: | if user in self.admin_users: | ||||
isadmin=True | |||||
ok=False | |||||
result="string" | |||||
if any(cmd in info['alias'] for info in COMMANDS.values()): | |||||
cmd = next(c for c, info in COMMANDS.items() if cmd in info['alias']) | |||||
isadmin = True | |||||
ok = False | |||||
result = "string" | |||||
if any(cmd in info["alias"] for info in COMMANDS.values()): | |||||
cmd = next(c for c, info in COMMANDS.items() if cmd in info["alias"]) | |||||
if cmd == "auth": | if cmd == "auth": | ||||
ok, result = self.authenticate(user, args, isadmin, isgroup) | ok, result = self.authenticate(user, args, isadmin, isgroup) | ||||
elif cmd == "help" or cmd == "helpp": | elif cmd == "help" or cmd == "helpp": | ||||
@@ -224,10 +238,14 @@ class Godcmd(Plugin): | |||||
query_name = args[0].upper() | query_name = args[0].upper() | ||||
# search name and namecn | # search name and namecn | ||||
for name, plugincls in plugins.items(): | for name, plugincls in plugins.items(): | ||||
if not plugincls.enabled : | |||||
if not plugincls.enabled: | |||||
continue | continue | ||||
if query_name == name or query_name == plugincls.namecn: | if query_name == name or query_name == plugincls.namecn: | ||||
ok, result = True, PluginManager().instances[name].get_help_text(isgroup=isgroup, isadmin=isadmin, verbose=True) | |||||
ok, result = True, PluginManager().instances[ | |||||
name | |||||
].get_help_text( | |||||
isgroup=isgroup, isadmin=isadmin, verbose=True | |||||
) | |||||
break | break | ||||
if not ok: | if not ok: | ||||
result = "插件不存在或未启用" | result = "插件不存在或未启用" | ||||
@@ -236,14 +254,14 @@ class Godcmd(Plugin): | |||||
elif cmd == "set_openai_api_key": | elif cmd == "set_openai_api_key": | ||||
if len(args) == 1: | if len(args) == 1: | ||||
user_data = conf().get_user_data(user) | user_data = conf().get_user_data(user) | ||||
user_data['openai_api_key'] = args[0] | |||||
user_data["openai_api_key"] = args[0] | |||||
ok, result = True, "你的OpenAI私有api_key已设置为" + args[0] | ok, result = True, "你的OpenAI私有api_key已设置为" + args[0] | ||||
else: | else: | ||||
ok, result = False, "请提供一个api_key" | ok, result = False, "请提供一个api_key" | ||||
elif cmd == "reset_openai_api_key": | elif cmd == "reset_openai_api_key": | ||||
try: | try: | ||||
user_data = conf().get_user_data(user) | user_data = conf().get_user_data(user) | ||||
user_data.pop('openai_api_key') | |||||
user_data.pop("openai_api_key") | |||||
ok, result = True, "你的OpenAI私有api_key已清除" | ok, result = True, "你的OpenAI私有api_key已清除" | ||||
except Exception as e: | except Exception as e: | ||||
ok, result = False, "你没有设置私有api_key" | ok, result = False, "你没有设置私有api_key" | ||||
@@ -255,12 +273,16 @@ class Godcmd(Plugin): | |||||
else: | else: | ||||
ok, result = False, "当前对话机器人不支持重置会话" | ok, result = False, "当前对话机器人不支持重置会话" | ||||
logger.debug("[Godcmd] command: %s by %s" % (cmd, user)) | logger.debug("[Godcmd] command: %s by %s" % (cmd, user)) | ||||
elif any(cmd in info['alias'] for info in ADMIN_COMMANDS.values()): | |||||
elif any(cmd in info["alias"] for info in ADMIN_COMMANDS.values()): | |||||
if isadmin: | if isadmin: | ||||
if isgroup: | if isgroup: | ||||
ok, result = False, "群聊不可执行管理员指令" | ok, result = False, "群聊不可执行管理员指令" | ||||
else: | else: | ||||
cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info['alias']) | |||||
cmd = next( | |||||
c | |||||
for c, info in ADMIN_COMMANDS.items() | |||||
if cmd in info["alias"] | |||||
) | |||||
if cmd == "stop": | if cmd == "stop": | ||||
self.isrunning = False | self.isrunning = False | ||||
ok, result = True, "服务已暂停" | ok, result = True, "服务已暂停" | ||||
@@ -278,13 +300,13 @@ class Godcmd(Plugin): | |||||
else: | else: | ||||
ok, result = False, "当前对话机器人不支持重置会话" | ok, result = False, "当前对话机器人不支持重置会话" | ||||
elif cmd == "debug": | elif cmd == "debug": | ||||
logger.setLevel('DEBUG') | |||||
logger.setLevel("DEBUG") | |||||
ok, result = True, "DEBUG模式已开启" | ok, result = True, "DEBUG模式已开启" | ||||
elif cmd == "plist": | elif cmd == "plist": | ||||
plugins = PluginManager().list_plugins() | plugins = PluginManager().list_plugins() | ||||
ok = True | ok = True | ||||
result = "插件列表:\n" | result = "插件列表:\n" | ||||
for name,plugincls in plugins.items(): | |||||
for name, plugincls in plugins.items(): | |||||
result += f"{plugincls.name}_v{plugincls.version} {plugincls.priority} - " | result += f"{plugincls.name}_v{plugincls.version} {plugincls.priority} - " | ||||
if plugincls.enabled: | if plugincls.enabled: | ||||
result += "已启用\n" | result += "已启用\n" | ||||
@@ -294,16 +316,20 @@ class Godcmd(Plugin): | |||||
new_plugins = PluginManager().scan_plugins() | new_plugins = PluginManager().scan_plugins() | ||||
ok, result = True, "插件扫描完成" | ok, result = True, "插件扫描完成" | ||||
PluginManager().activate_plugins() | PluginManager().activate_plugins() | ||||
if len(new_plugins) >0 : | |||||
if len(new_plugins) > 0: | |||||
result += "\n发现新插件:\n" | result += "\n发现新插件:\n" | ||||
result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins]) | |||||
else : | |||||
result +=", 未发现新插件" | |||||
result += "\n".join( | |||||
[f"{p.name}_v{p.version}" for p in new_plugins] | |||||
) | |||||
else: | |||||
result += ", 未发现新插件" | |||||
elif cmd == "setpri": | elif cmd == "setpri": | ||||
if len(args) != 2: | if len(args) != 2: | ||||
ok, result = False, "请提供插件名和优先级" | ok, result = False, "请提供插件名和优先级" | ||||
else: | else: | ||||
ok = PluginManager().set_plugin_priority(args[0], int(args[1])) | |||||
ok = PluginManager().set_plugin_priority( | |||||
args[0], int(args[1]) | |||||
) | |||||
if ok: | if ok: | ||||
result = "插件" + args[0] + "优先级已设置为" + args[1] | result = "插件" + args[0] + "优先级已设置为" + args[1] | ||||
else: | else: | ||||
@@ -350,42 +376,42 @@ class Godcmd(Plugin): | |||||
else: | else: | ||||
ok, result = False, "需要管理员权限才能执行该指令" | ok, result = False, "需要管理员权限才能执行该指令" | ||||
else: | else: | ||||
trigger_prefix = conf().get('plugin_trigger_prefix',"$") | |||||
if trigger_prefix == "#": # 跟插件聊天指令前缀相同,继续递交 | |||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$") | |||||
if trigger_prefix == "#": # 跟插件聊天指令前缀相同,继续递交 | |||||
return | return | ||||
ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n" | ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n" | ||||
reply = Reply() | reply = Reply() | ||||
if ok: | if ok: | ||||
reply.type = ReplyType.INFO | reply.type = ReplyType.INFO | ||||
else: | else: | ||||
reply.type = ReplyType.ERROR | reply.type = ReplyType.ERROR | ||||
reply.content = result | reply.content = result | ||||
e_context['reply'] = reply | |||||
e_context["reply"] = reply | |||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | |||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | |||||
elif not self.isrunning: | elif not self.isrunning: | ||||
e_context.action = EventAction.BREAK_PASS | e_context.action = EventAction.BREAK_PASS | ||||
def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool,str] : | |||||
def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool, str]: | |||||
if isgroup: | if isgroup: | ||||
return False,"请勿在群聊中认证" | |||||
return False, "请勿在群聊中认证" | |||||
if isadmin: | if isadmin: | ||||
return False,"管理员账号无需认证" | |||||
return False, "管理员账号无需认证" | |||||
if len(args) != 1: | if len(args) != 1: | ||||
return False,"请提供口令" | |||||
return False, "请提供口令" | |||||
password = args[0] | password = args[0] | ||||
if password == self.password: | if password == self.password: | ||||
self.admin_users.append(userid) | self.admin_users.append(userid) | ||||
return True,"认证成功" | |||||
return True, "认证成功" | |||||
elif password == self.temp_password: | elif password == self.temp_password: | ||||
self.admin_users.append(userid) | self.admin_users.append(userid) | ||||
return True,"认证成功,请尽快设置口令" | |||||
return True, "认证成功,请尽快设置口令" | |||||
else: | else: | ||||
return False,"认证失败" | |||||
return False, "认证失败" | |||||
def get_help_text(self, isadmin = False, isgroup = False, **kwargs): | |||||
return get_help_text(isadmin, isgroup) | |||||
def get_help_text(self, isadmin=False, isgroup=False, **kwargs): | |||||
return get_help_text(isadmin, isgroup) |
@@ -1 +1 @@
from .hello import *
@@ -1,14 +1,21 @@ | |||||
# encoding:utf-8 | # encoding:utf-8 | ||||
import plugins | |||||
from bridge.context import ContextType | from bridge.context import ContextType | ||||
from bridge.reply import Reply, ReplyType | from bridge.reply import Reply, ReplyType | ||||
from channel.chat_message import ChatMessage | from channel.chat_message import ChatMessage | ||||
import plugins | |||||
from plugins import * | |||||
from common.log import logger | from common.log import logger | ||||
from plugins import * | |||||
@plugins.register(name="Hello", desire_priority=-1, hidden=True, desc="A simple plugin that says hello", version="0.1", author="lanvent") | |||||
@plugins.register( | |||||
name="Hello", | |||||
desire_priority=-1, | |||||
hidden=True, | |||||
desc="A simple plugin that says hello", | |||||
version="0.1", | |||||
author="lanvent", | |||||
) | |||||
class Hello(Plugin): | class Hello(Plugin): | ||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
@@ -16,33 +23,34 @@ class Hello(Plugin): | |||||
logger.info("[Hello] inited") | logger.info("[Hello] inited") | ||||
def on_handle_context(self, e_context: EventContext): | def on_handle_context(self, e_context: EventContext): | ||||
if e_context['context'].type != ContextType.TEXT: | |||||
if e_context["context"].type != ContextType.TEXT: | |||||
return | return | ||||
content = e_context['context'].content | |||||
content = e_context["context"].content | |||||
logger.debug("[Hello] on_handle_context. content: %s" % content) | logger.debug("[Hello] on_handle_context. content: %s" % content) | ||||
if content == "Hello": | if content == "Hello": | ||||
reply = Reply() | reply = Reply() | ||||
reply.type = ReplyType.TEXT | reply.type = ReplyType.TEXT | ||||
msg:ChatMessage = e_context['context']['msg'] | |||||
if e_context['context']['isgroup']: | |||||
reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}" | |||||
msg: ChatMessage = e_context["context"]["msg"] | |||||
if e_context["context"]["isgroup"]: | |||||
reply.content = ( | |||||
f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}" | |||||
) | |||||
else: | else: | ||||
reply.content = f"Hello, {msg.from_user_nickname}" | reply.content = f"Hello, {msg.from_user_nickname}" | ||||
e_context['reply'] = reply | |||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | |||||
e_context["reply"] = reply | |||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | |||||
if content == "Hi": | if content == "Hi": | ||||
reply = Reply() | reply = Reply() | ||||
reply.type = ReplyType.TEXT | reply.type = ReplyType.TEXT | ||||
reply.content = "Hi" | reply.content = "Hi" | ||||
e_context['reply'] = reply | |||||
e_context["reply"] = reply | |||||
e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑,一般会覆写reply | e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑,一般会覆写reply | ||||
if content == "End": | if content == "End": | ||||
# 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World" | # 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World" | ||||
e_context['context'].type = ContextType.IMAGE_CREATE | |||||
e_context["context"].type = ContextType.IMAGE_CREATE | |||||
content = "The World" | content = "The World" | ||||
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 | e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 | ||||
@@ -3,4 +3,4 @@ class Plugin:
self.handlers = {}
def get_help_text(self, **kwargs):
return "暂无帮助信息"
@@ -5,17 +5,19 @@ import importlib.util | |||||
import json | import json | ||||
import os | import os | ||||
import sys | import sys | ||||
from common.log import logger | |||||
from common.singleton import singleton | from common.singleton import singleton | ||||
from common.sorted_dict import SortedDict | from common.sorted_dict import SortedDict | ||||
from .event import * | |||||
from common.log import logger | |||||
from config import conf | from config import conf | ||||
from .event import * | |||||
@singleton | @singleton | ||||
class PluginManager: | class PluginManager: | ||||
def __init__(self): | def __init__(self): | ||||
self.plugins = SortedDict(lambda k,v: v.priority,reverse=True) | |||||
self.plugins = SortedDict(lambda k, v: v.priority, reverse=True) | |||||
self.listening_plugins = {} | self.listening_plugins = {} | ||||
self.instances = {} | self.instances = {} | ||||
self.pconf = {} | self.pconf = {} | ||||
@@ -26,17 +28,27 @@ class PluginManager: | |||||
def wrapper(plugincls): | def wrapper(plugincls): | ||||
plugincls.name = name | plugincls.name = name | ||||
plugincls.priority = desire_priority | plugincls.priority = desire_priority | ||||
plugincls.desc = kwargs.get('desc') | |||||
plugincls.author = kwargs.get('author') | |||||
plugincls.desc = kwargs.get("desc") | |||||
plugincls.author = kwargs.get("author") | |||||
plugincls.path = self.current_plugin_path | plugincls.path = self.current_plugin_path | ||||
plugincls.version = kwargs.get('version') if kwargs.get('version') != None else "1.0" | |||||
plugincls.namecn = kwargs.get('namecn') if kwargs.get('namecn') != None else name | |||||
plugincls.hidden = kwargs.get('hidden') if kwargs.get('hidden') != None else False | |||||
plugincls.version = ( | |||||
kwargs.get("version") if kwargs.get("version") != None else "1.0" | |||||
) | |||||
plugincls.namecn = ( | |||||
kwargs.get("namecn") if kwargs.get("namecn") != None else name | |||||
) | |||||
plugincls.hidden = ( | |||||
kwargs.get("hidden") if kwargs.get("hidden") != None else False | |||||
) | |||||
plugincls.enabled = True | plugincls.enabled = True | ||||
if self.current_plugin_path == None: | if self.current_plugin_path == None: | ||||
raise Exception("Plugin path not set") | raise Exception("Plugin path not set") | ||||
self.plugins[name.upper()] = plugincls | self.plugins[name.upper()] = plugincls | ||||
logger.info("Plugin %s_v%s registered, path=%s" % (name, plugincls.version, plugincls.path)) | |||||
logger.info( | |||||
"Plugin %s_v%s registered, path=%s" | |||||
% (name, plugincls.version, plugincls.path) | |||||
) | |||||
return wrapper | return wrapper | ||||
def save_config(self): | def save_config(self): | ||||
@@ -50,10 +62,12 @@ class PluginManager: | |||||
if os.path.exists("./plugins/plugins.json"): | if os.path.exists("./plugins/plugins.json"): | ||||
with open("./plugins/plugins.json", "r", encoding="utf-8") as f: | with open("./plugins/plugins.json", "r", encoding="utf-8") as f: | ||||
pconf = json.load(f) | pconf = json.load(f) | ||||
pconf['plugins'] = SortedDict(lambda k,v: v["priority"],pconf['plugins'],reverse=True) | |||||
pconf["plugins"] = SortedDict( | |||||
lambda k, v: v["priority"], pconf["plugins"], reverse=True | |||||
) | |||||
else: | else: | ||||
modified = True | modified = True | ||||
pconf = {"plugins": SortedDict(lambda k,v: v["priority"],reverse=True)} | |||||
pconf = {"plugins": SortedDict(lambda k, v: v["priority"], reverse=True)} | |||||
self.pconf = pconf | self.pconf = pconf | ||||
if modified: | if modified: | ||||
self.save_config() | self.save_config() | ||||
@@ -67,7 +81,7 @@ class PluginManager: | |||||
plugin_path = os.path.join(plugins_dir, plugin_name) | plugin_path = os.path.join(plugins_dir, plugin_name) | ||||
if os.path.isdir(plugin_path): | if os.path.isdir(plugin_path): | ||||
# 判断插件是否包含同名__init__.py文件 | # 判断插件是否包含同名__init__.py文件 | ||||
main_module_path = os.path.join(plugin_path,"__init__.py") | |||||
main_module_path = os.path.join(plugin_path, "__init__.py") | |||||
if os.path.isfile(main_module_path): | if os.path.isfile(main_module_path): | ||||
# 导入插件 | # 导入插件 | ||||
import_path = "plugins.{}".format(plugin_name) | import_path = "plugins.{}".format(plugin_name) | ||||
@@ -76,16 +90,26 @@ class PluginManager: | |||||
if plugin_path in self.loaded: | if plugin_path in self.loaded: | ||||
if self.loaded[plugin_path] == None: | if self.loaded[plugin_path] == None: | ||||
logger.info("reload module %s" % plugin_name) | logger.info("reload module %s" % plugin_name) | ||||
self.loaded[plugin_path] = importlib.reload(sys.modules[import_path]) | |||||
dependent_module_names = [name for name in sys.modules.keys() if name.startswith( import_path+ '.')] | |||||
self.loaded[plugin_path] = importlib.reload( | |||||
sys.modules[import_path] | |||||
) | |||||
dependent_module_names = [ | |||||
name | |||||
for name in sys.modules.keys() | |||||
if name.startswith(import_path + ".") | |||||
] | |||||
for name in dependent_module_names: | for name in dependent_module_names: | ||||
logger.info("reload module %s" % name) | logger.info("reload module %s" % name) | ||||
importlib.reload(sys.modules[name]) | importlib.reload(sys.modules[name]) | ||||
else: | else: | ||||
self.loaded[plugin_path] = importlib.import_module(import_path) | |||||
self.loaded[plugin_path] = importlib.import_module( | |||||
import_path | |||||
) | |||||
self.current_plugin_path = None | self.current_plugin_path = None | ||||
except Exception as e: | except Exception as e: | ||||
logger.exception("Failed to import plugin %s: %s" % (plugin_name, e)) | |||||
logger.exception( | |||||
"Failed to import plugin %s: %s" % (plugin_name, e) | |||||
) | |||||
continue | continue | ||||
pconf = self.pconf | pconf = self.pconf | ||||
news = [self.plugins[name] for name in self.plugins] | news = [self.plugins[name] for name in self.plugins] | ||||
@@ -95,21 +119,28 @@ class PluginManager: | |||||
rawname = plugincls.name | rawname = plugincls.name | ||||
if rawname not in pconf["plugins"]: | if rawname not in pconf["plugins"]: | ||||
modified = True | modified = True | ||||
logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name) | |||||
pconf["plugins"][rawname] = {"enabled": plugincls.enabled, "priority": plugincls.priority} | |||||
logger.info( | |||||
"Plugin %s not found in pconfig, adding to pconfig..." % name | |||||
) | |||||
pconf["plugins"][rawname] = { | |||||
"enabled": plugincls.enabled, | |||||
"priority": plugincls.priority, | |||||
} | |||||
else: | else: | ||||
self.plugins[name].enabled = pconf["plugins"][rawname]["enabled"] | self.plugins[name].enabled = pconf["plugins"][rawname]["enabled"] | ||||
self.plugins[name].priority = pconf["plugins"][rawname]["priority"] | self.plugins[name].priority = pconf["plugins"][rawname]["priority"] | ||||
self.plugins._update_heap(name) # 更新下plugins中的顺序 | |||||
self.plugins._update_heap(name) # 更新下plugins中的顺序 | |||||
if modified: | if modified: | ||||
self.save_config() | self.save_config() | ||||
return new_plugins | return new_plugins | ||||
def refresh_order(self): | def refresh_order(self): | ||||
for event in self.listening_plugins.keys(): | for event in self.listening_plugins.keys(): | ||||
self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True) | |||||
self.listening_plugins[event].sort( | |||||
key=lambda name: self.plugins[name].priority, reverse=True | |||||
) | |||||
def activate_plugins(self): # 生成新开启的插件实例 | |||||
def activate_plugins(self): # 生成新开启的插件实例 | |||||
failed_plugins = [] | failed_plugins = [] | ||||
for name, plugincls in self.plugins.items(): | for name, plugincls in self.plugins.items(): | ||||
if plugincls.enabled: | if plugincls.enabled: | ||||
@@ -129,7 +160,7 @@ class PluginManager: | |||||
self.refresh_order() | self.refresh_order() | ||||
return failed_plugins | return failed_plugins | ||||
def reload_plugin(self, name:str): | |||||
def reload_plugin(self, name: str): | |||||
name = name.upper() | name = name.upper() | ||||
if name in self.instances: | if name in self.instances: | ||||
for event in self.listening_plugins: | for event in self.listening_plugins: | ||||
@@ -139,13 +170,13 @@ class PluginManager: | |||||
self.activate_plugins() | self.activate_plugins() | ||||
return True | return True | ||||
return False | return False | ||||
def load_plugins(self): | def load_plugins(self): | ||||
self.load_config() | self.load_config() | ||||
self.scan_plugins() | self.scan_plugins() | ||||
pconf = self.pconf | pconf = self.pconf | ||||
logger.debug("plugins.json config={}".format(pconf)) | logger.debug("plugins.json config={}".format(pconf)) | ||||
for name,plugin in pconf["plugins"].items(): | |||||
for name, plugin in pconf["plugins"].items(): | |||||
if name.upper() not in self.plugins: | if name.upper() not in self.plugins: | ||||
logger.error("Plugin %s not found, but found in plugins.json" % name) | logger.error("Plugin %s not found, but found in plugins.json" % name) | ||||
self.activate_plugins() | self.activate_plugins() | ||||
@@ -153,13 +184,18 @@ class PluginManager: | |||||
def emit_event(self, e_context: EventContext, *args, **kwargs): | def emit_event(self, e_context: EventContext, *args, **kwargs): | ||||
if e_context.event in self.listening_plugins: | if e_context.event in self.listening_plugins: | ||||
for name in self.listening_plugins[e_context.event]: | for name in self.listening_plugins[e_context.event]: | ||||
if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE: | |||||
logger.debug("Plugin %s triggered by event %s" % (name,e_context.event)) | |||||
if ( | |||||
self.plugins[name].enabled | |||||
and e_context.action == EventAction.CONTINUE | |||||
): | |||||
logger.debug( | |||||
"Plugin %s triggered by event %s" % (name, e_context.event) | |||||
) | |||||
instance = self.instances[name] | instance = self.instances[name] | ||||
instance.handlers[e_context.event](e_context, *args, **kwargs) | instance.handlers[e_context.event](e_context, *args, **kwargs) | ||||
return e_context | return e_context | ||||
def set_plugin_priority(self, name:str, priority:int): | |||||
def set_plugin_priority(self, name: str, priority: int): | |||||
name = name.upper() | name = name.upper() | ||||
if name not in self.plugins: | if name not in self.plugins: | ||||
return False | return False | ||||
@@ -174,11 +210,11 @@ class PluginManager: | |||||
self.refresh_order() | self.refresh_order() | ||||
return True | return True | ||||
def enable_plugin(self, name:str): | |||||
def enable_plugin(self, name: str): | |||||
name = name.upper() | name = name.upper() | ||||
if name not in self.plugins: | if name not in self.plugins: | ||||
return False, "插件不存在" | return False, "插件不存在" | ||||
if not self.plugins[name].enabled : | |||||
if not self.plugins[name].enabled: | |||||
self.plugins[name].enabled = True | self.plugins[name].enabled = True | ||||
rawname = self.plugins[name].name | rawname = self.plugins[name].name | ||||
self.pconf["plugins"][rawname]["enabled"] = True | self.pconf["plugins"][rawname]["enabled"] = True | ||||
@@ -188,43 +224,47 @@ class PluginManager: | |||||
return False, "插件开启失败" | return False, "插件开启失败" | ||||
return True, "插件已开启" | return True, "插件已开启" | ||||
return True, "插件已开启" | return True, "插件已开启" | ||||
def disable_plugin(self, name:str): | |||||
def disable_plugin(self, name: str): | |||||
name = name.upper() | name = name.upper() | ||||
if name not in self.plugins: | if name not in self.plugins: | ||||
return False | return False | ||||
if self.plugins[name].enabled : | |||||
if self.plugins[name].enabled: | |||||
self.plugins[name].enabled = False | self.plugins[name].enabled = False | ||||
rawname = self.plugins[name].name | rawname = self.plugins[name].name | ||||
self.pconf["plugins"][rawname]["enabled"] = False | self.pconf["plugins"][rawname]["enabled"] = False | ||||
self.save_config() | self.save_config() | ||||
return True | return True | ||||
return True | return True | ||||
def list_plugins(self): | def list_plugins(self): | ||||
return self.plugins | return self.plugins | ||||
def install_plugin(self, repo:str): | |||||
def install_plugin(self, repo: str): | |||||
try: | try: | ||||
import common.package_manager as pkgmgr | import common.package_manager as pkgmgr | ||||
pkgmgr.check_dulwich() | pkgmgr.check_dulwich() | ||||
except Exception as e: | except Exception as e: | ||||
logger.error("Failed to install plugin, {}".format(e)) | logger.error("Failed to install plugin, {}".format(e)) | ||||
return False, "无法导入dulwich,安装插件失败" | return False, "无法导入dulwich,安装插件失败" | ||||
import re | import re | ||||
from dulwich import porcelain | from dulwich import porcelain | ||||
logger.info("clone git repo: {}".format(repo)) | logger.info("clone git repo: {}".format(repo)) | ||||
match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo) | match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo) | ||||
if not match: | if not match: | ||||
try: | try: | ||||
with open("./plugins/source.json","r", encoding="utf-8") as f: | |||||
with open("./plugins/source.json", "r", encoding="utf-8") as f: | |||||
source = json.load(f) | source = json.load(f) | ||||
if repo in source["repo"]: | if repo in source["repo"]: | ||||
repo = source["repo"][repo]["url"] | repo = source["repo"][repo]["url"] | ||||
match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo) | |||||
match = re.match( | |||||
r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo | |||||
) | |||||
if not match: | if not match: | ||||
return False, "安装插件失败,source中的仓库地址不合法" | return False, "安装插件失败,source中的仓库地址不合法" | ||||
else: | else: | ||||
@@ -232,42 +272,53 @@ class PluginManager: | |||||
except Exception as e: | except Exception as e: | ||||
logger.error("Failed to install plugin, {}".format(e)) | logger.error("Failed to install plugin, {}".format(e)) | ||||
return False, "安装插件失败,请检查仓库地址是否正确" | return False, "安装插件失败,请检查仓库地址是否正确" | ||||
dirname = os.path.join("./plugins",match.group(4)) | |||||
dirname = os.path.join("./plugins", match.group(4)) | |||||
try: | try: | ||||
repo = porcelain.clone(repo, dirname, checkout=True) | repo = porcelain.clone(repo, dirname, checkout=True) | ||||
if os.path.exists(os.path.join(dirname,"requirements.txt")): | |||||
if os.path.exists(os.path.join(dirname, "requirements.txt")): | |||||
logger.info("detect requirements.txt,installing...") | logger.info("detect requirements.txt,installing...") | ||||
pkgmgr.install_requirements(os.path.join(dirname,"requirements.txt")) | |||||
pkgmgr.install_requirements(os.path.join(dirname, "requirements.txt")) | |||||
return True, "安装插件成功,请使用 #scanp 命令扫描插件或重启程序,开启前请检查插件是否需要配置" | return True, "安装插件成功,请使用 #scanp 命令扫描插件或重启程序,开启前请检查插件是否需要配置" | ||||
except Exception as e: | except Exception as e: | ||||
logger.error("Failed to install plugin, {}".format(e)) | logger.error("Failed to install plugin, {}".format(e)) | ||||
return False, "安装插件失败,"+str(e) | |||||
def update_plugin(self, name:str): | |||||
return False, "安装插件失败," + str(e) | |||||
def update_plugin(self, name: str): | |||||
try: | try: | ||||
import common.package_manager as pkgmgr | import common.package_manager as pkgmgr | ||||
pkgmgr.check_dulwich() | pkgmgr.check_dulwich() | ||||
except Exception as e: | except Exception as e: | ||||
logger.error("Failed to install plugin, {}".format(e)) | logger.error("Failed to install plugin, {}".format(e)) | ||||
return False, "无法导入dulwich,更新插件失败" | return False, "无法导入dulwich,更新插件失败" | ||||
from dulwich import porcelain | from dulwich import porcelain | ||||
name = name.upper() | name = name.upper() | ||||
if name not in self.plugins: | if name not in self.plugins: | ||||
return False, "插件不存在" | return False, "插件不存在" | ||||
if name in ["HELLO","GODCMD","ROLE","TOOL","BDUNIT","BANWORDS","FINISH","DUNGEON"]: | |||||
if name in [ | |||||
"HELLO", | |||||
"GODCMD", | |||||
"ROLE", | |||||
"TOOL", | |||||
"BDUNIT", | |||||
"BANWORDS", | |||||
"FINISH", | |||||
"DUNGEON", | |||||
]: | |||||
return False, "预置插件无法更新,请更新主程序仓库" | return False, "预置插件无法更新,请更新主程序仓库" | ||||
dirname = self.plugins[name].path | dirname = self.plugins[name].path | ||||
try: | try: | ||||
porcelain.pull(dirname, "origin") | porcelain.pull(dirname, "origin") | ||||
if os.path.exists(os.path.join(dirname,"requirements.txt")): | |||||
if os.path.exists(os.path.join(dirname, "requirements.txt")): | |||||
logger.info("detect requirements.txt,installing...") | logger.info("detect requirements.txt,installing...") | ||||
pkgmgr.install_requirements(os.path.join(dirname,"requirements.txt")) | |||||
pkgmgr.install_requirements(os.path.join(dirname, "requirements.txt")) | |||||
return True, "更新插件成功,请重新运行程序" | return True, "更新插件成功,请重新运行程序" | ||||
except Exception as e: | except Exception as e: | ||||
logger.error("Failed to update plugin, {}".format(e)) | logger.error("Failed to update plugin, {}".format(e)) | ||||
return False, "更新插件失败,"+str(e) | |||||
def uninstall_plugin(self, name:str): | |||||
return False, "更新插件失败," + str(e) | |||||
def uninstall_plugin(self, name: str): | |||||
name = name.upper() | name = name.upper() | ||||
if name not in self.plugins: | if name not in self.plugins: | ||||
return False, "插件不存在" | return False, "插件不存在" | ||||
@@ -276,6 +327,7 @@ class PluginManager: | |||||
dirname = self.plugins[name].path | dirname = self.plugins[name].path | ||||
try: | try: | ||||
import shutil | import shutil | ||||
shutil.rmtree(dirname) | shutil.rmtree(dirname) | ||||
rawname = self.plugins[name].name | rawname = self.plugins[name].name | ||||
for event in self.listening_plugins: | for event in self.listening_plugins: | ||||
@@ -288,4 +340,4 @@ class PluginManager: | |||||
return True, "卸载插件成功" | return True, "卸载插件成功" | ||||
except Exception as e: | except Exception as e: | ||||
logger.error("Failed to uninstall plugin, {}".format(e)) | logger.error("Failed to uninstall plugin, {}".format(e)) | ||||
return False, "卸载插件失败,请手动删除文件夹完成卸载,"+str(e) | |||||
return False, "卸载插件失败,请手动删除文件夹完成卸载," + str(e) |
@@ -1 +1 @@
from .role import *
@@ -2,17 +2,18 @@ | |||||
import json | import json | ||||
import os | import os | ||||
import plugins | |||||
from bridge.bridge import Bridge | from bridge.bridge import Bridge | ||||
from bridge.context import ContextType | from bridge.context import ContextType | ||||
from bridge.reply import Reply, ReplyType | from bridge.reply import Reply, ReplyType | ||||
from common import const | from common import const | ||||
from common.log import logger | |||||
from config import conf | from config import conf | ||||
import plugins | |||||
from plugins import * | from plugins import * | ||||
from common.log import logger | |||||
class RolePlay(): | |||||
class RolePlay: | |||||
def __init__(self, bot, sessionid, desc, wrapper=None): | def __init__(self, bot, sessionid, desc, wrapper=None): | ||||
self.bot = bot | self.bot = bot | ||||
self.sessionid = sessionid | self.sessionid = sessionid | ||||
@@ -25,12 +26,20 @@ class RolePlay(): | |||||
def action(self, user_action): | def action(self, user_action): | ||||
session = self.bot.sessions.build_session(self.sessionid) | session = self.bot.sessions.build_session(self.sessionid) | ||||
if session.system_prompt != self.desc: # 目前没有触发session过期事件,这里先简单判断,然后重置 | |||||
if session.system_prompt != self.desc: # 目前没有触发session过期事件,这里先简单判断,然后重置 | |||||
session.set_system_prompt(self.desc) | session.set_system_prompt(self.desc) | ||||
prompt = self.wrapper % user_action | prompt = self.wrapper % user_action | ||||
return prompt | return prompt | ||||
@plugins.register(name="Role", desire_priority=0, namecn="角色扮演", desc="为你的Bot设置预设角色", version="1.0", author="lanvent") | |||||
@plugins.register( | |||||
name="Role", | |||||
desire_priority=0, | |||||
namecn="角色扮演", | |||||
desc="为你的Bot设置预设角色", | |||||
version="1.0", | |||||
author="lanvent", | |||||
) | |||||
class Role(Plugin): | class Role(Plugin): | ||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
@@ -39,7 +48,7 @@ class Role(Plugin): | |||||
try: | try: | ||||
with open(config_path, "r", encoding="utf-8") as f: | with open(config_path, "r", encoding="utf-8") as f: | ||||
config = json.load(f) | config = json.load(f) | ||||
self.tags = { tag:(desc,[]) for tag,desc in config["tags"].items()} | |||||
self.tags = {tag: (desc, []) for tag, desc in config["tags"].items()} | |||||
self.roles = {} | self.roles = {} | ||||
for role in config["roles"]: | for role in config["roles"]: | ||||
self.roles[role["title"].lower()] = role | self.roles[role["title"].lower()] = role | ||||
@@ -60,12 +69,16 @@ class Role(Plugin): | |||||
logger.info("[Role] inited") | logger.info("[Role] inited") | ||||
except Exception as e: | except Exception as e: | ||||
if isinstance(e, FileNotFoundError): | if isinstance(e, FileNotFoundError): | ||||
logger.warn(f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .") | |||||
logger.warn( | |||||
f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ." | |||||
) | |||||
else: | else: | ||||
logger.warn("[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .") | |||||
logger.warn( | |||||
"[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ." | |||||
) | |||||
raise e | raise e | ||||
def get_role(self, name, find_closest=True, min_sim = 0.35): | |||||
def get_role(self, name, find_closest=True, min_sim=0.35): | |||||
name = name.lower() | name = name.lower() | ||||
found_role = None | found_role = None | ||||
if name in self.roles: | if name in self.roles: | ||||
@@ -75,6 +88,7 @@ class Role(Plugin): | |||||
def str_simularity(a, b): | def str_simularity(a, b): | ||||
return difflib.SequenceMatcher(None, a, b).ratio() | return difflib.SequenceMatcher(None, a, b).ratio() | ||||
max_sim = min_sim | max_sim = min_sim | ||||
max_role = None | max_role = None | ||||
for role in self.roles: | for role in self.roles: | ||||
@@ -86,25 +100,24 @@ class Role(Plugin): | |||||
return found_role | return found_role | ||||
def on_handle_context(self, e_context: EventContext): | def on_handle_context(self, e_context: EventContext): | ||||
if e_context['context'].type != ContextType.TEXT: | |||||
if e_context["context"].type != ContextType.TEXT: | |||||
return | return | ||||
bottype = Bridge().get_bot_type("chat") | bottype = Bridge().get_bot_type("chat") | ||||
if bottype not in (const.CHATGPT, const.OPEN_AI): | if bottype not in (const.CHATGPT, const.OPEN_AI): | ||||
return | return | ||||
bot = Bridge().get_bot("chat") | bot = Bridge().get_bot("chat") | ||||
content = e_context['context'].content[:] | |||||
clist = e_context['context'].content.split(maxsplit=1) | |||||
content = e_context["context"].content[:] | |||||
clist = e_context["context"].content.split(maxsplit=1) | |||||
desckey = None | desckey = None | ||||
customize = False | customize = False | ||||
sessionid = e_context['context']['session_id'] | |||||
trigger_prefix = conf().get('plugin_trigger_prefix', "$") | |||||
sessionid = e_context["context"]["session_id"] | |||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$") | |||||
if clist[0] == f"{trigger_prefix}停止扮演": | if clist[0] == f"{trigger_prefix}停止扮演": | ||||
if sessionid in self.roleplays: | if sessionid in self.roleplays: | ||||
self.roleplays[sessionid].reset() | self.roleplays[sessionid].reset() | ||||
del self.roleplays[sessionid] | del self.roleplays[sessionid] | ||||
reply = Reply(ReplyType.INFO, "角色扮演结束!") | reply = Reply(ReplyType.INFO, "角色扮演结束!") | ||||
e_context['reply'] = reply | |||||
e_context["reply"] = reply | |||||
e_context.action = EventAction.BREAK_PASS | e_context.action = EventAction.BREAK_PASS | ||||
return | return | ||||
elif clist[0] == f"{trigger_prefix}角色": | elif clist[0] == f"{trigger_prefix}角色": | ||||
@@ -114,10 +127,10 @@ class Role(Plugin): | |||||
elif clist[0] == f"{trigger_prefix}设定扮演": | elif clist[0] == f"{trigger_prefix}设定扮演": | ||||
customize = True | customize = True | ||||
elif clist[0] == f"{trigger_prefix}角色类型": | elif clist[0] == f"{trigger_prefix}角色类型": | ||||
if len(clist) >1: | |||||
if len(clist) > 1: | |||||
tag = clist[1].strip() | tag = clist[1].strip() | ||||
help_text = "角色列表:\n" | help_text = "角色列表:\n" | ||||
for key,value in self.tags.items(): | |||||
for key, value in self.tags.items(): | |||||
if value[0] == tag: | if value[0] == tag: | ||||
tag = key | tag = key | ||||
break | break | ||||
@@ -130,57 +143,75 @@ class Role(Plugin): | |||||
else: | else: | ||||
help_text = f"未知角色类型。\n" | help_text = f"未知角色类型。\n" | ||||
help_text += "目前的角色类型有: \n" | help_text += "目前的角色类型有: \n" | ||||
help_text += ",".join([self.tags[tag][0] for tag in self.tags])+"\n" | |||||
help_text += ( | |||||
",".join([self.tags[tag][0] for tag in self.tags]) + "\n" | |||||
) | |||||
else: | else: | ||||
help_text = f"请输入角色类型。\n" | help_text = f"请输入角色类型。\n" | ||||
help_text += "目前的角色类型有: \n" | help_text += "目前的角色类型有: \n" | ||||
help_text += ",".join([self.tags[tag][0] for tag in self.tags])+"\n" | |||||
help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "\n" | |||||
reply = Reply(ReplyType.INFO, help_text) | reply = Reply(ReplyType.INFO, help_text) | ||||
e_context['reply'] = reply | |||||
e_context["reply"] = reply | |||||
e_context.action = EventAction.BREAK_PASS | e_context.action = EventAction.BREAK_PASS | ||||
return | return | ||||
elif sessionid not in self.roleplays: | elif sessionid not in self.roleplays: | ||||
return | return | ||||
logger.debug("[Role] on_handle_context. content: %s" % content) | logger.debug("[Role] on_handle_context. content: %s" % content) | ||||
if desckey is not None: | if desckey is not None: | ||||
if len(clist) == 1 or (len(clist) > 1 and clist[1].lower() in ["help", "帮助"]): | |||||
if len(clist) == 1 or ( | |||||
len(clist) > 1 and clist[1].lower() in ["help", "帮助"] | |||||
): | |||||
reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True)) | reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True)) | ||||
e_context['reply'] = reply | |||||
e_context["reply"] = reply | |||||
e_context.action = EventAction.BREAK_PASS | e_context.action = EventAction.BREAK_PASS | ||||
return | return | ||||
role = self.get_role(clist[1]) | role = self.get_role(clist[1]) | ||||
if role is None: | if role is None: | ||||
reply = Reply(ReplyType.ERROR, "角色不存在") | reply = Reply(ReplyType.ERROR, "角色不存在") | ||||
e_context['reply'] = reply | |||||
e_context["reply"] = reply | |||||
e_context.action = EventAction.BREAK_PASS | e_context.action = EventAction.BREAK_PASS | ||||
return | return | ||||
else: | else: | ||||
self.roleplays[sessionid] = RolePlay(bot, sessionid, self.roles[role][desckey], self.roles[role].get("wrapper","%s")) | |||||
reply = Reply(ReplyType.INFO, f"预设角色为 {role}:\n"+self.roles[role][desckey]) | |||||
e_context['reply'] = reply | |||||
self.roleplays[sessionid] = RolePlay( | |||||
bot, | |||||
sessionid, | |||||
self.roles[role][desckey], | |||||
self.roles[role].get("wrapper", "%s"), | |||||
) | |||||
reply = Reply( | |||||
ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey] | |||||
) | |||||
e_context["reply"] = reply | |||||
e_context.action = EventAction.BREAK_PASS | e_context.action = EventAction.BREAK_PASS | ||||
elif customize == True: | elif customize == True: | ||||
self.roleplays[sessionid] = RolePlay(bot, sessionid, clist[1], "%s") | self.roleplays[sessionid] = RolePlay(bot, sessionid, clist[1], "%s") | ||||
reply = Reply(ReplyType.INFO, f"角色设定为:\n{clist[1]}") | reply = Reply(ReplyType.INFO, f"角色设定为:\n{clist[1]}") | ||||
e_context['reply'] = reply | |||||
e_context["reply"] = reply | |||||
e_context.action = EventAction.BREAK_PASS | e_context.action = EventAction.BREAK_PASS | ||||
else: | else: | ||||
prompt = self.roleplays[sessionid].action(content) | prompt = self.roleplays[sessionid].action(content) | ||||
e_context['context'].type = ContextType.TEXT | |||||
e_context['context'].content = prompt | |||||
e_context["context"].type = ContextType.TEXT | |||||
e_context["context"].content = prompt | |||||
e_context.action = EventAction.BREAK | e_context.action = EventAction.BREAK | ||||
def get_help_text(self, verbose=False, **kwargs): | def get_help_text(self, verbose=False, **kwargs): | ||||
help_text = "让机器人扮演不同的角色。\n" | help_text = "让机器人扮演不同的角色。\n" | ||||
if not verbose: | if not verbose: | ||||
return help_text | return help_text | ||||
trigger_prefix = conf().get('plugin_trigger_prefix', "$") | |||||
help_text = f"使用方法:\n{trigger_prefix}角色"+" 预设角色名: 设定角色为{预设角色名}。\n"+f"{trigger_prefix}role"+" 预设角色名: 同上,但使用英文设定。\n" | |||||
help_text += f"{trigger_prefix}设定扮演"+" 角色设定: 设定自定义角色人设为{角色设定}。\n" | |||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$") | |||||
help_text = ( | |||||
f"使用方法:\n{trigger_prefix}角色" | |||||
+ " 预设角色名: 设定角色为{预设角色名}。\n" | |||||
+ f"{trigger_prefix}role" | |||||
+ " 预设角色名: 同上,但使用英文设定。\n" | |||||
) | |||||
help_text += f"{trigger_prefix}设定扮演" + " 角色设定: 设定自定义角色人设为{角色设定}。\n" | |||||
help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n" | help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n" | ||||
help_text += f"{trigger_prefix}角色类型"+" 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n" | |||||
help_text += ( | |||||
f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n" | |||||
) | |||||
help_text += "\n目前的角色类型有: \n" | help_text += "\n目前的角色类型有: \n" | ||||
help_text += ",".join([self.tags[tag][0] for tag in self.tags])+"。\n" | |||||
help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "。\n" | |||||
help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n" | help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n" | ||||
help_text += f"{trigger_prefix}角色类型 所有\n" | help_text += f"{trigger_prefix}角色类型 所有\n" | ||||
help_text += f"{trigger_prefix}停止扮演\n" | help_text += f"{trigger_prefix}停止扮演\n" | ||||
@@ -428,4 +428,4 @@
]
}
]
}
@@ -1,16 +1,16 @@
{
    "repo": {
        "sdwebui": {
            "url": "https://github.com/lanvent/plugin_sdwebui.git",
            "desc": "利用stable-diffusion画图的插件"
        },
        "replicate": {
            "url": "https://github.com/lanvent/plugin_replicate.git",
            "desc": "利用replicate api画图的插件"
        },
        "summary": {
            "url": "https://github.com/lanvent/plugin_summary.git",
            "desc": "总结聊天记录的插件"
        }
    }
}
@@ -1,14 +1,14 @@
## 插件描述
一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力
使用该插件需在机器人回复你的前提下,在对话内容前加$tool;仅输入$tool将返回tool插件帮助信息,用于测试插件是否加载成功
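例如,在机器人会正常回复的会话中发送 `$tool 使用python查询今天日期`,即可让机器人借助工具完成这条指令;后文“使用本插件对话(prompt)技巧”一节给出了更多可参考的指令写法。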
### 本插件所有工具同步存放至专用仓库:[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)
## 使用说明
使用该插件后将默认使用4个工具, 无需额外配置长期生效:
### 1. python
###### python解释器,使用它来解释执行python指令,可以配合你想要chatgpt生成的代码输出结果或执行事务
### 2. url-get
###### 往往用来获取某个网站具体内容,结果可能会被反爬策略影响
@@ -23,16 +23,16 @@
> meteo调优记录:https://github.com/zhayujie/chatgpt-on-wechat/issues/776#issuecomment-1500771334
## 使用本插件对话(prompt)技巧
### 1. 有指引的询问
#### 例如:
- 总结这个链接的内容 https://github.com/goldfishh/chatgpt-tool-hub
- 使用Terminal执行curl cip.cc
- 使用python查询今天日期
### 2. 使用搜索引擎工具
- 如果有搜索工具就能让chatgpt获取到你的未传达清楚的上下文信息,比如chatgpt不知道你的地理位置,现在时间等,所以无法查询到天气
## 其他工具
### 5. wikipedia
@@ -55,9 +55,9 @@
### 10. google-search *
###### google搜索引擎,申请流程较bing-search繁琐
###### 注1:带*工具需要获取api-key才能使用,部分工具需要外网支持
#### [申请方法](https://github.com/goldfishh/chatgpt-tool-hub/blob/master/docs/apply_optional_tool.md)
## config.json 配置说明
###### 默认工具无需配置,其它工具需手动配置,一个例子:
```json
@@ -71,15 +71,15 @@
}
```
注:config.json文件非必须,未创建仍可使用本tool;带*工具需在kwargs填入对应api-key键值对(配置示例见下)
- `tools`:本插件初始化时加载的工具, 目前可选集:["wikipedia", "wolfram-alpha", "bing-search", "google-search", "news", "morning-news"] & 默认工具,除wikipedia工具之外均需要申请api-key
- `kwargs`:工具执行时的配置,一般在这里存放**api-key**,或环境配置
- `request_timeout`: 访问openai接口的超时时间,默认与chatgpt-on-wechat主配置一致,可单独配置
- `no_default`: 用于配置默认加载4个工具的行为,如果为true则仅使用tools列表工具,不加载默认工具
- `top_k_results`: 控制所有有关搜索的工具返回条目数,数字越高则参考信息越多,但无用信息可能干扰判断,该值一般为2
- `model_name`: 用于控制tool插件底层使用的llm模型,目前暂未测试3.5以外的模型,一般保持默认
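下面是一个在默认4个工具之外额外加载 wikipedia 工具(无需api-key)的最小示例配置,数值仅作示意,可按需调整;带*工具所需的 api-key 键名请以上文“申请方法”文档为准,填入 kwargs 即可:

```json
{
    "tools": ["wikipedia"],
    "kwargs": {
        "request_timeout": 120,
        "no_default": false,
        "top_k_results": 2,
        "model_name": "gpt-3.5-turbo"
    }
}
```

该文件与 tool.py 放在同一目录(plugins/tool/)下即可被读取。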
## 备注
- 强烈建议申请搜索工具搭配使用,推荐bing-search
- 虽然我会有意加入一些限制,但请不要使用本插件做危害他人的事情,请提前了解清楚某些内容是否会违反相关规定,建议提前做好过滤
@@ -1 +1 @@
from .tool import *
@@ -1,8 +1,13 @@
{
    "tools": [
        "python",
        "url-get",
        "terminal",
        "meteo-weather"
    ],
    "kwargs": {
        "top_k_results": 2,
        "no_default": false,
        "model_name": "gpt-3.5-turbo"
    }
}
@@ -4,6 +4,7 @@ import os | |||||
from chatgpt_tool_hub.apps import load_app | from chatgpt_tool_hub.apps import load_app | ||||
from chatgpt_tool_hub.apps.app import App | from chatgpt_tool_hub.apps.app import App | ||||
from chatgpt_tool_hub.tools.all_tool_list import get_all_tool_names | from chatgpt_tool_hub.tools.all_tool_list import get_all_tool_names | ||||
import plugins | import plugins | ||||
from bridge.bridge import Bridge | from bridge.bridge import Bridge | ||||
from bridge.context import ContextType | from bridge.context import ContextType | ||||
@@ -14,7 +15,13 @@ from config import conf | |||||
from plugins import * | from plugins import * | ||||
@plugins.register(name="tool", desc="Arming your ChatGPT bot with various tools", version="0.3", author="goldfishh", desire_priority=0) | |||||
@plugins.register( | |||||
name="tool", | |||||
desc="Arming your ChatGPT bot with various tools", | |||||
version="0.3", | |||||
author="goldfishh", | |||||
desire_priority=0, | |||||
) | |||||
class Tool(Plugin): | class Tool(Plugin): | ||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
@@ -28,22 +35,26 @@ class Tool(Plugin): | |||||
help_text = "这是一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力。" | help_text = "这是一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力。" | ||||
if not verbose: | if not verbose: | ||||
return help_text | return help_text | ||||
trigger_prefix = conf().get('plugin_trigger_prefix', "$") | |||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$") | |||||
help_text += "使用说明:\n" | help_text += "使用说明:\n" | ||||
help_text += f"{trigger_prefix}tool "+"命令: 根据给出的{命令}使用一些可用工具尽力为你得到结果。\n" | |||||
help_text += f"{trigger_prefix}tool " + "命令: 根据给出的{命令}使用一些可用工具尽力为你得到结果。\n" | |||||
help_text += f"{trigger_prefix}tool reset: 重置工具。\n" | help_text += f"{trigger_prefix}tool reset: 重置工具。\n" | ||||
return help_text | return help_text | ||||
def on_handle_context(self, e_context: EventContext): | def on_handle_context(self, e_context: EventContext): | ||||
if e_context['context'].type != ContextType.TEXT: | |||||
if e_context["context"].type != ContextType.TEXT: | |||||
return | return | ||||
# 暂时不支持未来扩展的bot | # 暂时不支持未来扩展的bot | ||||
if Bridge().get_bot_type("chat") not in (const.CHATGPT, const.OPEN_AI, const.CHATGPTONAZURE): | |||||
if Bridge().get_bot_type("chat") not in ( | |||||
const.CHATGPT, | |||||
const.OPEN_AI, | |||||
const.CHATGPTONAZURE, | |||||
): | |||||
return | return | ||||
content = e_context['context'].content | |||||
content_list = e_context['context'].content.split(maxsplit=1) | |||||
content = e_context["context"].content | |||||
content_list = e_context["context"].content.split(maxsplit=1) | |||||
if not content or len(content_list) < 1: | if not content or len(content_list) < 1: | ||||
e_context.action = EventAction.CONTINUE | e_context.action = EventAction.CONTINUE | ||||
@@ -52,13 +63,13 @@ class Tool(Plugin): | |||||
logger.debug("[tool] on_handle_context. content: %s" % content) | logger.debug("[tool] on_handle_context. content: %s" % content) | ||||
reply = Reply() | reply = Reply() | ||||
reply.type = ReplyType.TEXT | reply.type = ReplyType.TEXT | ||||
trigger_prefix = conf().get('plugin_trigger_prefix', "$") | |||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$") | |||||
# todo: 有些工具必须要api-key,需要修改config文件,所以这里没有实现query增删tool的功能 | # todo: 有些工具必须要api-key,需要修改config文件,所以这里没有实现query增删tool的功能 | ||||
if content.startswith(f"{trigger_prefix}tool"): | if content.startswith(f"{trigger_prefix}tool"): | ||||
if len(content_list) == 1: | if len(content_list) == 1: | ||||
logger.debug("[tool]: get help") | logger.debug("[tool]: get help") | ||||
reply.content = self.get_help_text() | reply.content = self.get_help_text() | ||||
e_context['reply'] = reply | |||||
e_context["reply"] = reply | |||||
e_context.action = EventAction.BREAK_PASS | e_context.action = EventAction.BREAK_PASS | ||||
return | return | ||||
elif len(content_list) > 1: | elif len(content_list) > 1: | ||||
@@ -66,12 +77,14 @@ class Tool(Plugin): | |||||
logger.debug("[tool]: reset config") | logger.debug("[tool]: reset config") | ||||
self.app = self._reset_app() | self.app = self._reset_app() | ||||
reply.content = "重置工具成功" | reply.content = "重置工具成功" | ||||
e_context['reply'] = reply | |||||
e_context["reply"] = reply | |||||
e_context.action = EventAction.BREAK_PASS | e_context.action = EventAction.BREAK_PASS | ||||
return | return | ||||
elif content_list[1].startswith("reset"): | elif content_list[1].startswith("reset"): | ||||
logger.debug("[tool]: remind") | logger.debug("[tool]: remind") | ||||
e_context['context'].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符" | |||||
e_context[ | |||||
"context" | |||||
].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符" | |||||
e_context.action = EventAction.BREAK | e_context.action = EventAction.BREAK | ||||
return | return | ||||
@@ -80,34 +93,35 @@ class Tool(Plugin): | |||||
# Don't modify bot name | # Don't modify bot name | ||||
all_sessions = Bridge().get_bot("chat").sessions | all_sessions = Bridge().get_bot("chat").sessions | ||||
user_session = all_sessions.session_query(query, e_context['context']['session_id']).messages | |||||
user_session = all_sessions.session_query( | |||||
query, e_context["context"]["session_id"] | |||||
).messages | |||||
# chatgpt-tool-hub will reply you with many tools | # chatgpt-tool-hub will reply you with many tools | ||||
logger.debug("[tool]: just-go") | logger.debug("[tool]: just-go") | ||||
try: | try: | ||||
_reply = self.app.ask(query, user_session) | _reply = self.app.ask(query, user_session) | ||||
e_context.action = EventAction.BREAK_PASS | e_context.action = EventAction.BREAK_PASS | ||||
all_sessions.session_reply(_reply, e_context['context']['session_id']) | |||||
all_sessions.session_reply( | |||||
_reply, e_context["context"]["session_id"] | |||||
) | |||||
except Exception as e: | except Exception as e: | ||||
logger.exception(e) | logger.exception(e) | ||||
logger.error(str(e)) | logger.error(str(e)) | ||||
e_context['context'].content = "请你随机用一种聊天风格,提醒用户:这个问题tool插件暂时无法处理" | |||||
e_context["context"].content = "请你随机用一种聊天风格,提醒用户:这个问题tool插件暂时无法处理" | |||||
reply.type = ReplyType.ERROR | reply.type = ReplyType.ERROR | ||||
e_context.action = EventAction.BREAK | e_context.action = EventAction.BREAK | ||||
return | return | ||||
reply.content = _reply | reply.content = _reply | ||||
e_context['reply'] = reply | |||||
e_context["reply"] = reply | |||||
return | return | ||||
def _read_json(self) -> dict: | def _read_json(self) -> dict: | ||||
curdir = os.path.dirname(__file__) | curdir = os.path.dirname(__file__) | ||||
config_path = os.path.join(curdir, "config.json") | config_path = os.path.join(curdir, "config.json") | ||||
tool_config = { | |||||
"tools": [], | |||||
"kwargs": {} | |||||
} | |||||
tool_config = {"tools": [], "kwargs": {}} | |||||
if not os.path.exists(config_path): | if not os.path.exists(config_path): | ||||
return tool_config | return tool_config | ||||
else: | else: | ||||
@@ -123,7 +137,9 @@ class Tool(Plugin): | |||||
"proxy": conf().get("proxy", ""), | "proxy": conf().get("proxy", ""), | ||||
"request_timeout": conf().get("request_timeout", 60), | "request_timeout": conf().get("request_timeout", 60), | ||||
# note: 目前tool暂未对其他模型测试,但这里仍对配置来源做了优先级区分,一般插件配置可覆盖全局配置 | # note: 目前tool暂未对其他模型测试,但这里仍对配置来源做了优先级区分,一般插件配置可覆盖全局配置 | ||||
"model_name": tool_model_name if tool_model_name else conf().get("model", "gpt-3.5-turbo"), | |||||
"model_name": tool_model_name | |||||
if tool_model_name | |||||
else conf().get("model", "gpt-3.5-turbo"), | |||||
"no_default": kwargs.get("no_default", False), | "no_default": kwargs.get("no_default", False), | ||||
"top_k_results": kwargs.get("top_k_results", 2), | "top_k_results": kwargs.get("top_k_results", 2), | ||||
# for news tool | # for news tool | ||||
@@ -160,4 +176,7 @@ class Tool(Plugin): | |||||
# filter not support tool | # filter not support tool | ||||
tool_list = self._filter_tool_list(tool_config.get("tools", [])) | tool_list = self._filter_tool_list(tool_config.get("tools", [])) | ||||
return load_app(tools_list=tool_list, **self._build_tool_kwargs(tool_config.get("kwargs", {}))) | |||||
return load_app( | |||||
tools_list=tool_list, | |||||
**self._build_tool_kwargs(tool_config.get("kwargs", {})), | |||||
) |
@@ -4,3 +4,4 @@ PyQRCode>=1.2.1
qrcode>=7.4.2
requests>=2.28.2
chardet>=5.1.0
pre-commit |
@@ -8,7 +8,7 @@ echo $BASE_DIR
# check the nohup.out log output file
if [ ! -f "${BASE_DIR}/nohup.out" ]; then
touch "${BASE_DIR}/nohup.out"
echo "create file ${BASE_DIR}/nohup.out"
fi
nohup python3 "${BASE_DIR}/app.py" & tail -f "${BASE_DIR}/nohup.out"
@@ -7,7 +7,7 @@ echo $BASE_DIR
# check the nohup.out log output file
if [ ! -f "${BASE_DIR}/nohup.out" ]; then
echo "No file ${BASE_DIR}/nohup.out"
exit -1;
fi
@@ -1,9 +1,12 @@ | |||||
import shutil | import shutil | ||||
import wave | import wave | ||||
import pysilk | import pysilk | ||||
from pydub import AudioSegment | from pydub import AudioSegment | ||||
sil_supports=[8000, 12000, 16000, 24000, 32000, 44100, 48000] # slk转wav时,支持的采样率 | |||||
sil_supports = [8000, 12000, 16000, 24000, 32000, 44100, 48000] # slk转wav时,支持的采样率 | |||||
def find_closest_sil_supports(sample_rate): | def find_closest_sil_supports(sample_rate): | ||||
""" | """ | ||||
找到最接近的支持的采样率 | 找到最接近的支持的采样率 | ||||
@@ -19,6 +22,7 @@ def find_closest_sil_supports(sample_rate): | |||||
mindiff = diff | mindiff = diff | ||||
return closest | return closest | ||||
def get_pcm_from_wav(wav_path): | def get_pcm_from_wav(wav_path): | ||||
""" | """ | ||||
从 wav 文件中读取 pcm | 从 wav 文件中读取 pcm | ||||
@@ -29,31 +33,42 @@ def get_pcm_from_wav(wav_path): | |||||
wav = wave.open(wav_path, "rb") | wav = wave.open(wav_path, "rb") | ||||
return wav.readframes(wav.getnframes()) | return wav.readframes(wav.getnframes()) | ||||
def any_to_wav(any_path, wav_path): | def any_to_wav(any_path, wav_path): | ||||
""" | """ | ||||
把任意格式转成wav文件 | 把任意格式转成wav文件 | ||||
""" | """ | ||||
if any_path.endswith('.wav'): | |||||
if any_path.endswith(".wav"): | |||||
shutil.copy2(any_path, wav_path) | shutil.copy2(any_path, wav_path) | ||||
return | return | ||||
if any_path.endswith('.sil') or any_path.endswith('.silk') or any_path.endswith('.slk'): | |||||
if ( | |||||
any_path.endswith(".sil") | |||||
or any_path.endswith(".silk") | |||||
or any_path.endswith(".slk") | |||||
): | |||||
return sil_to_wav(any_path, wav_path) | return sil_to_wav(any_path, wav_path) | ||||
audio = AudioSegment.from_file(any_path) | audio = AudioSegment.from_file(any_path) | ||||
audio.export(wav_path, format="wav") | audio.export(wav_path, format="wav") | ||||
def any_to_sil(any_path, sil_path): | def any_to_sil(any_path, sil_path): | ||||
""" | """ | ||||
把任意格式转成sil文件 | 把任意格式转成sil文件 | ||||
""" | """ | ||||
if any_path.endswith('.sil') or any_path.endswith('.silk') or any_path.endswith('.slk'): | |||||
if ( | |||||
any_path.endswith(".sil") | |||||
or any_path.endswith(".silk") | |||||
or any_path.endswith(".slk") | |||||
): | |||||
shutil.copy2(any_path, sil_path) | shutil.copy2(any_path, sil_path) | ||||
return 10000 | return 10000 | ||||
if any_path.endswith('.wav'): | |||||
if any_path.endswith(".wav"): | |||||
return pcm_to_sil(any_path, sil_path) | return pcm_to_sil(any_path, sil_path) | ||||
if any_path.endswith('.mp3'): | |||||
if any_path.endswith(".mp3"): | |||||
return mp3_to_sil(any_path, sil_path) | return mp3_to_sil(any_path, sil_path) | ||||
raise NotImplementedError("Not support file type: {}".format(any_path)) | raise NotImplementedError("Not support file type: {}".format(any_path)) | ||||
def mp3_to_wav(mp3_path, wav_path): | def mp3_to_wav(mp3_path, wav_path): | ||||
""" | """ | ||||
把mp3格式转成pcm文件 | 把mp3格式转成pcm文件 | ||||
@@ -61,6 +76,7 @@ def mp3_to_wav(mp3_path, wav_path): | |||||
audio = AudioSegment.from_mp3(mp3_path) | audio = AudioSegment.from_mp3(mp3_path) | ||||
audio.export(wav_path, format="wav") | audio.export(wav_path, format="wav") | ||||
def pcm_to_sil(pcm_path, silk_path): | def pcm_to_sil(pcm_path, silk_path): | ||||
""" | """ | ||||
wav 文件转成 silk | wav 文件转成 silk | ||||
@@ -72,12 +88,12 @@ def pcm_to_sil(pcm_path, silk_path): | |||||
pcm_s16 = audio.set_sample_width(2) | pcm_s16 = audio.set_sample_width(2) | ||||
pcm_s16 = pcm_s16.set_frame_rate(rate) | pcm_s16 = pcm_s16.set_frame_rate(rate) | ||||
wav_data = pcm_s16.raw_data | wav_data = pcm_s16.raw_data | ||||
silk_data = pysilk.encode( | |||||
wav_data, data_rate=rate, sample_rate=rate) | |||||
silk_data = pysilk.encode(wav_data, data_rate=rate, sample_rate=rate) | |||||
with open(silk_path, "wb") as f: | with open(silk_path, "wb") as f: | ||||
f.write(silk_data) | f.write(silk_data) | ||||
return audio.duration_seconds * 1000 | return audio.duration_seconds * 1000 | ||||
def mp3_to_sil(mp3_path, silk_path): | def mp3_to_sil(mp3_path, silk_path): | ||||
""" | """ | ||||
mp3 文件转成 silk | mp3 文件转成 silk | ||||
@@ -95,6 +111,7 @@ def mp3_to_sil(mp3_path, silk_path): | |||||
f.write(silk_data) | f.write(silk_data) | ||||
return audio.duration_seconds * 1000 | return audio.duration_seconds * 1000 | ||||
def sil_to_wav(silk_path, wav_path, rate: int = 24000): | def sil_to_wav(silk_path, wav_path, rate: int = 24000): | ||||
""" | """ | ||||
silk 文件转 wav | silk 文件转 wav | ||||
@@ -1,16 +1,18 @@
"""
azure voice service
"""
import json
import os
import time
import azure.cognitiveservices.speech as speechsdk
from bridge.reply import Reply, ReplyType
from common.log import logger
from common.tmp_dir import TmpDir
from voice.voice import Voice
from config import conf
from voice.voice import Voice
"""
Azure voice
主目录设置文件中需填写azure_voice_api_key和azure_voice_region
@@ -19,50 +21,68 @@ Azure voice
"""
class AzureVoice(Voice):
    def __init__(self):
        try:
            curdir = os.path.dirname(__file__)
            config_path = os.path.join(curdir, "config.json")
            config = None
            if not os.path.exists(config_path): #如果没有配置文件,创建本地配置文件
                config = { "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural", "speech_recognition_language": "zh-CN"}
            if not os.path.exists(config_path):  # 如果没有配置文件,创建本地配置文件
                config = {
                    "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural",
                    "speech_recognition_language": "zh-CN",
                }
                with open(config_path, "w") as fw:
                    json.dump(config, fw, indent=4)
            else:
                with open(config_path, "r") as fr:
                    config = json.load(fr)
            self.api_key = conf().get('azure_voice_api_key')
            self.api_region = conf().get('azure_voice_region')
            self.speech_config = speechsdk.SpeechConfig(subscription=self.api_key, region=self.api_region)
            self.speech_config.speech_synthesis_voice_name = config["speech_synthesis_voice_name"]
            self.speech_config.speech_recognition_language = config["speech_recognition_language"]
            self.api_key = conf().get("azure_voice_api_key")
            self.api_region = conf().get("azure_voice_region")
            self.speech_config = speechsdk.SpeechConfig(
                subscription=self.api_key, region=self.api_region
            )
            self.speech_config.speech_synthesis_voice_name = config[
                "speech_synthesis_voice_name"
            ]
            self.speech_config.speech_recognition_language = config[
                "speech_recognition_language"
            ]
        except Exception as e:
            logger.warn("AzureVoice init failed: %s, ignore " % e)
    def voiceToText(self, voice_file):
        audio_config = speechsdk.AudioConfig(filename=voice_file)
        speech_recognizer = speechsdk.SpeechRecognizer(speech_config=self.speech_config, audio_config=audio_config)
        speech_recognizer = speechsdk.SpeechRecognizer(
            speech_config=self.speech_config, audio_config=audio_config
        )
        result = speech_recognizer.recognize_once()
        if result.reason == speechsdk.ResultReason.RecognizedSpeech:
            logger.info('[Azure] voiceToText voice file name={} text={}'.format(voice_file, result.text))
            logger.info(
                "[Azure] voiceToText voice file name={} text={}".format(
                    voice_file, result.text
                )
            )
            reply = Reply(ReplyType.TEXT, result.text)
        else:
            logger.error('[Azure] voiceToText error, result={}'.format(result))
            logger.error("[Azure] voiceToText error, result={}".format(result))
            reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败")
        return reply
    def textToVoice(self, text):
        fileName = TmpDir().path() + 'reply-' + str(int(time.time())) + '.wav'
        fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav"
        audio_config = speechsdk.AudioConfig(filename=fileName)
        speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=self.speech_config, audio_config=audio_config)
        speech_synthesizer = speechsdk.SpeechSynthesizer(
            speech_config=self.speech_config, audio_config=audio_config
        )
        result = speech_synthesizer.speak_text(text)
        if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
            logger.info(
                '[Azure] textToVoice text={} voice file name={}'.format(text, fileName))
                "[Azure] textToVoice text={} voice file name={}".format(text, fileName)
            )
            reply = Reply(ReplyType.VOICE, fileName)
        else:
            logger.error('[Azure] textToVoice error, result={}'.format(result))
            logger.error("[Azure] textToVoice error, result={}".format(result))
            reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
        return reply
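For context, a hypothetical smoke test of this adapter might look as follows. It assumes `azure_voice_api_key` and `azure_voice_region` are already set in the root config.json, and that `Reply` exposes its payload as `.type`/`.content`; both are assumptions about surrounding code, not shown in this diff.

```python
# Sketch only: the sample sentence is arbitrary and .type/.content are assumed
# attribute names on Reply.
from bridge.reply import ReplyType
from voice.azure.azure_voice import AzureVoice

voice = AzureVoice()
reply = voice.textToVoice("你好,这是一条测试语音")
if reply.type == ReplyType.VOICE:
    print("synthesized wav:", reply.content)
    echo = voice.voiceToText(reply.content)  # feed the wav back into recognition
    print("recognized text:", echo.content)
```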
@@ -1,4 +1,4 @@
{
    "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural",
    "speech_recognition_language": "zh-CN"
}
@@ -29,7 +29,7 @@ dev_pid 必填 语言选择,填写语言对应的dev_pid值
2. For the Baidu speech-synthesis API called in `def textToVoice(self, text)`, the arguments of the `synthesis()` call are configured in the `config.json` file in this directory.
Parameter  Required  Description
tex  required  Text to synthesize, UTF-8 encoded; the text must be shorter than 1024 bytes
lan  required  Fixed value zh. Language selection; only the mixed Chinese-English mode is currently available, so always pass zh
spd  optional  Speed, 0-15, default 5 (medium)
pit  optional  Pitch, 0-15, default 5 (medium)
@@ -40,14 +40,14 @@ aue 选填 3为mp3格式(默认); 4为pcm-16k;5为pcm-8k;6为wav
About the per parameter: fill in the value that matches the voice library you have purchased, otherwise the call returns an error. With the basic library, per may only be 0 to 4; with the premium library, per may only be 5003, 5118, 106, 110, 111, 103 or 5; any other value results in an error.
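To make the parameter mapping concrete, here is a minimal sketch of the underlying baidu-aip call. The credentials and output path are placeholders; the plugin itself reads them from the main config.json and from the values above, as the baidu_voice.py diff further down shows.

```python
# Sketch only: credentials are placeholders, and reply.mp3 is an arbitrary output path.
from aip import AipSpeech

client = AipSpeech("YOUR_APP_ID", "YOUR_API_KEY", "YOUR_SECRET_KEY")
result = client.synthesis(
    "你好,这是一条测试语音",  # tex: UTF-8 text, must stay under 1024 bytes
    "zh",  # lan: fixed value zh
    1,  # ctp: fixed value 1
    {"spd": 5, "pit": 5, "vol": 5, "per": 0},  # speed / pitch / volume / voice library
)
if not isinstance(result, dict):  # the SDK returns raw audio bytes on success, an error dict on failure
    with open("reply.mp3", "wb") as f:
        f.write(result)
```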
### Configuration file
Copy `config.json.template` in this folder to `config.json`.
``` json
{
    "lang": "zh",
    "ctp": 1,
    "spd": 5,
    "pit": 5,
    "vol": 5,
    "per": 0
}
```
@@ -1,17 +1,19 @@
"""
baidu voice service
"""
import json
import os
import time
from aip import AipSpeech
from bridge.reply import Reply, ReplyType
from common.log import logger
from common.tmp_dir import TmpDir
from voice.voice import Voice
from voice.audio_convert import get_pcm_from_wav
from config import conf
from voice.audio_convert import get_pcm_from_wav
from voice.voice import Voice
"""
百度的语音识别API.
dev_pid:
@@ -28,40 +30,37 @@ from config import conf
class BaiduVoice(Voice):
    def __init__(self):
        try:
            curdir = os.path.dirname(__file__)
            config_path = os.path.join(curdir, "config.json")
            bconf = None
            if not os.path.exists(config_path): #如果没有配置文件,创建本地配置文件
                bconf = { "lang": "zh", "ctp": 1, "spd": 5,
                    "pit": 5, "vol": 5, "per": 0}
            if not os.path.exists(config_path):  # 如果没有配置文件,创建本地配置文件
                bconf = {"lang": "zh", "ctp": 1, "spd": 5, "pit": 5, "vol": 5, "per": 0}
                with open(config_path, "w") as fw:
                    json.dump(bconf, fw, indent=4)
            else:
                with open(config_path, "r") as fr:
                    bconf = json.load(fr)
            self.app_id = conf().get('baidu_app_id')
            self.api_key = conf().get('baidu_api_key')
            self.secret_key = conf().get('baidu_secret_key')
            self.dev_id = conf().get('baidu_dev_pid')
            self.app_id = conf().get("baidu_app_id")
            self.api_key = conf().get("baidu_api_key")
            self.secret_key = conf().get("baidu_secret_key")
            self.dev_id = conf().get("baidu_dev_pid")
            self.lang = bconf["lang"]
            self.ctp = bconf["ctp"]
            self.spd = bconf["spd"]
            self.pit = bconf["pit"]
            self.vol = bconf["vol"]
            self.per = bconf["per"]
            self.client = AipSpeech(self.app_id, self.api_key, self.secret_key)
        except Exception as e:
            logger.warn("BaiduVoice init failed: %s, ignore " % e)
    def voiceToText(self, voice_file):
        # 识别本地文件
        logger.debug('[Baidu] voice file name={}'.format(voice_file))
        logger.debug("[Baidu] voice file name={}".format(voice_file))
        pcm = get_pcm_from_wav(voice_file)
        res = self.client.asr(pcm, "pcm", 16000, {"dev_pid": self.dev_id})
        if res["err_no"] == 0:
@@ -72,21 +71,25 @@ class BaiduVoice(Voice):
            logger.info("百度语音识别出错了: {}".format(res["err_msg"]))
            if res["err_msg"] == "request pv too much":
                logger.info(" 出现这个原因很可能是你的百度语音服务调用量超出限制,或未开通付费")
            reply = Reply(ReplyType.ERROR,
                "百度语音识别出错了;{0}".format(res["err_msg"]))
            reply = Reply(ReplyType.ERROR, "百度语音识别出错了;{0}".format(res["err_msg"]))
        return reply
    def textToVoice(self, text):
        result = self.client.synthesis(text, self.lang, self.ctp, {
            'spd': self.spd, 'pit': self.pit, 'vol': self.vol, 'per': self.per})
        result = self.client.synthesis(
            text,
            self.lang,
            self.ctp,
            {"spd": self.spd, "pit": self.pit, "vol": self.vol, "per": self.per},
        )
        if not isinstance(result, dict):
            fileName = TmpDir().path() + 'reply-' + str(int(time.time())) + '.mp3'
            with open(fileName, 'wb') as f:
            fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3"
            with open(fileName, "wb") as f:
                f.write(result)
            logger.info(
                '[Baidu] textToVoice text={} voice file name={}'.format(text, fileName))
                "[Baidu] textToVoice text={} voice file name={}".format(text, fileName)
            )
            reply = Reply(ReplyType.VOICE, fileName)
        else:
            logger.error('[Baidu] textToVoice error={}'.format(result))
            logger.error("[Baidu] textToVoice error={}".format(result))
            reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
        return reply
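For the recognition direction, a comparable sketch of the `asr()` call used above. Credentials and the wav path are placeholders; dev_pid 1536 corresponds to Mandarin in Baidu's documentation, which this module's docstring refers to.

```python
# Sketch only: credentials and the wav path are placeholders.
from aip import AipSpeech
from voice.audio_convert import get_pcm_from_wav

client = AipSpeech("YOUR_APP_ID", "YOUR_API_KEY", "YOUR_SECRET_KEY")
pcm = get_pcm_from_wav("voice-demo.wav")  # raw 16 kHz PCM, as expected by asr()
res = client.asr(pcm, "pcm", 16000, {"dev_pid": 1536})
if res["err_no"] == 0:
    print(res["result"][0])  # recognized text
else:
    print("ASR failed:", res["err_msg"])
```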
@@ -1,8 +1,8 @@
{
    "lang": "zh",
    "ctp": 1,
    "spd": 5,
    "pit": 5,
    "vol": 5,
    "per": 0
}
@@ -1,11 +1,12 @@
"""
google voice service
"""
import time
import speech_recognition
from gtts import gTTS
from bridge.reply import Reply, ReplyType
from common.log import logger
from common.tmp_dir import TmpDir
@@ -22,9 +23,12 @@ class GoogleVoice(Voice):
        with speech_recognition.AudioFile(voice_file) as source:
            audio = self.recognizer.record(source)
        try:
            text = self.recognizer.recognize_google(audio, language='zh-CN')
            text = self.recognizer.recognize_google(audio, language="zh-CN")
            logger.info(
                '[Google] voiceToText text={} voice file name={}'.format(text, voice_file))
                "[Google] voiceToText text={} voice file name={}".format(
                    text, voice_file
                )
            )
            reply = Reply(ReplyType.TEXT, text)
        except speech_recognition.UnknownValueError:
            reply = Reply(ReplyType.ERROR, "抱歉,我听不懂")
@@ -32,13 +36,15 @@ class GoogleVoice(Voice):
            reply = Reply(ReplyType.ERROR, "抱歉,无法连接到 Google 语音识别服务;{0}".format(e))
        finally:
            return reply
    def textToVoice(self, text):
        try:
            mp3File = TmpDir().path() + 'reply-' + str(int(time.time())) + '.mp3'
            tts = gTTS(text=text, lang='zh')
            tts.save(mp3File)
            mp3File = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3"
            tts = gTTS(text=text, lang="zh")
            tts.save(mp3File)
            logger.info(
                '[Google] textToVoice text={} voice file name={}'.format(text, mp3File))
                "[Google] textToVoice text={} voice file name={}".format(text, mp3File)
            )
            reply = Reply(ReplyType.VOICE, mp3File)
        except Exception as e:
            reply = Reply(ReplyType.ERROR, str(e))
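As a standalone illustration of the same two libraries (gTTS for synthesis, SpeechRecognition for recognition), with placeholder file names: recognize_google needs wav/flac/aiff input, so the mp3 produced by gTTS would first have to be converted, for example with the audio_convert helpers earlier in this PR.

```python
# Sketch only: demo.wav is a placeholder assumed to exist already.
import speech_recognition
from gtts import gTTS

gTTS(text="你好", lang="zh").save("demo.mp3")  # synthesis to an mp3 file

recognizer = speech_recognition.Recognizer()
with speech_recognition.AudioFile("demo.wav") as source:
    audio = recognizer.record(source)
print(recognizer.recognize_google(audio, language="zh-CN"))
```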
@@ -1,29 +1,32 @@
"""
google voice service
"""
import json
import openai
from bridge.reply import Reply, ReplyType
from config import conf
from common.log import logger
from config import conf
from voice.voice import Voice
class OpenaiVoice(Voice):
    def __init__(self):
        openai.api_key = conf().get('open_ai_api_key')
        openai.api_key = conf().get("open_ai_api_key")
    def voiceToText(self, voice_file):
        logger.debug(
            '[Openai] voice file name={}'.format(voice_file))
        logger.debug("[Openai] voice file name={}".format(voice_file))
        try:
            file = open(voice_file, "rb")
            result = openai.Audio.transcribe("whisper-1", file)
            text = result["text"]
            reply = Reply(ReplyType.TEXT, text)
            logger.info(
                '[Openai] voiceToText text={} voice file name={}'.format(text, voice_file))
                "[Openai] voiceToText text={} voice file name={}".format(
                    text, voice_file
                )
            )
        except Exception as e:
            reply = Reply(ReplyType.ERROR, str(e))
        finally:
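The Whisper call wrapped above can also be exercised on its own. A minimal sketch in the same pre-1.0 openai client style as this file; the API key and file name are placeholders.

```python
# Sketch only: key and file name are placeholders; mirrors the
# openai.Audio.transcribe("whisper-1", ...) usage in the diff above.
import openai

openai.api_key = "YOUR_API_KEY"
with open("voice-demo.mp3", "rb") as f:
    result = openai.Audio.transcribe("whisper-1", f)
print(result["text"])
```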
@@ -1,10 +1,11 @@
"""
pytts voice service (offline)
"""
import time
import pyttsx3
from bridge.reply import Reply, ReplyType
from common.log import logger
from common.tmp_dir import TmpDir
@@ -16,20 +17,21 @@ class PyttsVoice(Voice):
    def __init__(self):
        # 语速
        self.engine.setProperty('rate', 125)
        self.engine.setProperty("rate", 125)
        # 音量
        self.engine.setProperty('volume', 1.0)
        for voice in self.engine.getProperty('voices'):
        self.engine.setProperty("volume", 1.0)
        for voice in self.engine.getProperty("voices"):
            if "Chinese" in voice.name:
                self.engine.setProperty('voice', voice.id)
                self.engine.setProperty("voice", voice.id)
    def textToVoice(self, text):
        try:
            wavFile = TmpDir().path() + 'reply-' + str(int(time.time())) + '.wav'
            wavFile = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav"
            self.engine.save_to_file(text, wavFile)
            self.engine.runAndWait()
            logger.info(
                '[Pytts] textToVoice text={} voice file name={}'.format(text, wavFile))
                "[Pytts] textToVoice text={} voice file name={}".format(text, wavFile)
            )
            reply = Reply(ReplyType.VOICE, wavFile)
        except Exception as e:
            reply = Reply(ReplyType.ERROR, str(e))
@@ -2,6 +2,7 @@
Voice service abstract class
"""
class Voice(object):
    def voiceToText(self, voice_file):
        """
@@ -13,4 +14,4 @@ class Voice(object):
        """
        Send text to voice service and get voice
        """
        raise NotImplementedError
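The abstract class above is the whole extension point: a hypothetical subclass (not part of the project) only needs to return Reply objects from the two methods.

```python
# Hypothetical example of a custom engine behind this interface.
from bridge.reply import Reply, ReplyType
from voice.voice import Voice


class EchoVoice(Voice):
    def voiceToText(self, voice_file):
        # A real engine would run speech recognition here.
        return Reply(ReplyType.TEXT, "transcript of {}".format(voice_file))

    def textToVoice(self, text):
        # A real engine would write an audio file and return ReplyType.VOICE.
        return Reply(ReplyType.ERROR, "synthesis not implemented in this demo")
```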
@@ -2,25 +2,31 @@
voice factory
"""
def create_voice(voice_type):
    """
    create a voice instance
    :param voice_type: voice type code
    :return: voice instance
    """
    if voice_type == 'baidu':
    if voice_type == "baidu":
        from voice.baidu.baidu_voice import BaiduVoice
        return BaiduVoice()
    elif voice_type == 'google':
    elif voice_type == "google":
        from voice.google.google_voice import GoogleVoice
        return GoogleVoice()
    elif voice_type == 'openai':
    elif voice_type == "openai":
        from voice.openai.openai_voice import OpenaiVoice
        return OpenaiVoice()
    elif voice_type == 'pytts':
    elif voice_type == "pytts":
        from voice.pytts.pytts_voice import PyttsVoice
        return PyttsVoice()
    elif voice_type == 'azure':
    elif voice_type == "azure":
        from voice.azure.azure_voice import AzureVoice
        return AzureVoice()
    raise RuntimeError
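A usage sketch of the factory follows. The factory module's import path is assumed to be `voice.voice_factory`, and `.content` is assumed to be where Reply stores its payload; neither detail is stated in this diff.

```python
# Sketch only: the wav path is a placeholder.
from voice.voice_factory import create_voice

voice = create_voice("baidu")  # returns a BaiduVoice instance, per the branch above
reply = voice.voiceToText("voice-demo.wav")
print(reply.content)

# Any unrecognized type falls through to `raise RuntimeError`.
```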