@@ -27,5 +27,5 @@ | |||
### 环境 | |||
- 操作系统类型 (Mac/Windows/Linux): | |||
- Python版本 ( 执行 `python3 -V` ): | |||
- Python版本 ( 执行 `python3 -V` ): | |||
- pip版本 ( 依赖问题此项必填,执行 `pip3 -V`): |
@@ -49,9 +49,9 @@ jobs: | |||
file: ./docker/Dockerfile.latest | |||
tags: ${{ steps.meta.outputs.tags }} | |||
labels: ${{ steps.meta.outputs.labels }} | |||
- uses: actions/delete-package-versions@v4 | |||
with: | |||
with: | |||
package-name: 'chatgpt-on-wechat' | |||
package-type: 'container' | |||
min-versions-to-keep: 10 | |||
@@ -120,7 +120,7 @@ pip3 install azure-cognitiveservices-speech | |||
```bash | |||
# config.json文件内容示例 | |||
{ | |||
{ | |||
"open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY | |||
"model": "gpt-3.5-turbo", # 模型名称。当use_azure_chatgpt为true时,其名称为Azure上model deployment名称 | |||
"proxy": "127.0.0.1:7890", # 代理客户端的ip和端口 | |||
@@ -128,7 +128,7 @@ pip3 install azure-cognitiveservices-speech | |||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 | |||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复 | |||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表 | |||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称 | |||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称 | |||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀 | |||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数 | |||
"speech_recognition": false, # 是否开启语音识别 | |||
@@ -160,7 +160,7 @@ pip3 install azure-cognitiveservices-speech | |||
**4.其他配置** | |||
+ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k` (其中gpt-4 api暂未开放) | |||
+ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat) | |||
+ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat) | |||
+ `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351) | |||
+ 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix ` | |||
+ 关于OpenAI对话及图片接口的参数配置(内容自由度、回复字数限制、图片大小等),可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/completions) 文档直接在 [代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/bot/openai/open_ai_bot.py) `bot/openai/open_ai_bot.py` 中进行调整。 | |||
@@ -181,7 +181,7 @@ pip3 install azure-cognitiveservices-speech | |||
```bash | |||
python3 app.py | |||
``` | |||
终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。 | |||
终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。 | |||
### 2.服务器部署 | |||
@@ -189,7 +189,7 @@ python3 app.py | |||
使用nohup命令在后台运行程序: | |||
```bash | |||
touch nohup.out # 首次运行需要新建日志文件 | |||
touch nohup.out # 首次运行需要新建日志文件 | |||
nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通过日志输出二维码 | |||
``` | |||
扫码登录后程序即可运行于服务器后台,此时可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。此外,`scripts` 目录下有一键运行、关闭程序的脚本供使用。 | |||
@@ -1,23 +1,28 @@ | |||
# encoding:utf-8 | |||
import os | |||
from config import conf, load_config | |||
import signal | |||
import sys | |||
from channel import channel_factory | |||
from common.log import logger | |||
from config import conf, load_config | |||
from plugins import * | |||
import signal | |||
import sys | |||
def sigterm_handler_wrap(_signo): | |||
old_handler = signal.getsignal(_signo) | |||
def func(_signo, _stack_frame): | |||
logger.info("signal {} received, exiting...".format(_signo)) | |||
conf().save_user_datas() | |||
if callable(old_handler): # check old_handler | |||
if callable(old_handler): # check old_handler | |||
return old_handler(_signo, _stack_frame) | |||
sys.exit(0) | |||
signal.signal(_signo, func) | |||
def run(): | |||
try: | |||
# load config | |||
@@ -28,17 +33,17 @@ def run(): | |||
sigterm_handler_wrap(signal.SIGTERM) | |||
# create channel | |||
channel_name=conf().get('channel_type', 'wx') | |||
channel_name = conf().get("channel_type", "wx") | |||
if "--cmd" in sys.argv: | |||
channel_name = 'terminal' | |||
channel_name = "terminal" | |||
if channel_name == 'wxy': | |||
os.environ['WECHATY_LOG']="warn" | |||
if channel_name == "wxy": | |||
os.environ["WECHATY_LOG"] = "warn" | |||
# os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001' | |||
channel = channel_factory.create_channel(channel_name) | |||
if channel_name in ['wx','wxy','terminal','wechatmp','wechatmp_service']: | |||
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service"]: | |||
PluginManager().load_plugins() | |||
# startup channel | |||
@@ -47,5 +52,6 @@ def run(): | |||
logger.error("App startup failed!") | |||
logger.exception(e) | |||
if __name__ == '__main__': | |||
run() | |||
if __name__ == "__main__": | |||
run() |
@@ -1,6 +1,7 @@ | |||
# encoding:utf-8 | |||
import requests | |||
from bot.bot import Bot | |||
from bridge.reply import Reply, ReplyType | |||
@@ -9,20 +10,35 @@ from bridge.reply import Reply, ReplyType | |||
class BaiduUnitBot(Bot): | |||
def reply(self, query, context=None): | |||
token = self.get_token() | |||
url = 'https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=' + token | |||
post_data = "{\"version\":\"3.0\",\"service_id\":\"S73177\",\"session_id\":\"\",\"log_id\":\"7758521\",\"skill_ids\":[\"1221886\"],\"request\":{\"terminal_id\":\"88888\",\"query\":\"" + query + "\", \"hyper_params\": {\"chat_custom_bot_profile\": 1}}}" | |||
url = ( | |||
"https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" | |||
+ token | |||
) | |||
post_data = ( | |||
'{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"' | |||
+ query | |||
+ '", "hyper_params": {"chat_custom_bot_profile": 1}}}' | |||
) | |||
print(post_data) | |||
headers = {'content-type': 'application/x-www-form-urlencoded'} | |||
headers = {"content-type": "application/x-www-form-urlencoded"} | |||
response = requests.post(url, data=post_data.encode(), headers=headers) | |||
if response: | |||
reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1]) | |||
reply = Reply( | |||
ReplyType.TEXT, | |||
response.json()["result"]["context"]["SYS_PRESUMED_HIST"][1], | |||
) | |||
return reply | |||
def get_token(self): | |||
access_key = 'YOUR_ACCESS_KEY' | |||
secret_key = 'YOUR_SECRET_KEY' | |||
host = 'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=' + access_key + '&client_secret=' + secret_key | |||
access_key = "YOUR_ACCESS_KEY" | |||
secret_key = "YOUR_SECRET_KEY" | |||
host = ( | |||
"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" | |||
+ access_key | |||
+ "&client_secret=" | |||
+ secret_key | |||
) | |||
response = requests.get(host) | |||
if response: | |||
print(response.json()) | |||
return response.json()['access_token'] | |||
return response.json()["access_token"] |
@@ -8,7 +8,7 @@ from bridge.reply import Reply | |||
class Bot(object): | |||
def reply(self, query, context : Context =None) -> Reply: | |||
def reply(self, query, context: Context = None) -> Reply: | |||
""" | |||
bot auto-reply content | |||
:param req: received message | |||
@@ -13,20 +13,24 @@ def create_bot(bot_type): | |||
if bot_type == const.BAIDU: | |||
# Baidu Unit对话接口 | |||
from bot.baidu.baidu_unit_bot import BaiduUnitBot | |||
return BaiduUnitBot() | |||
elif bot_type == const.CHATGPT: | |||
# ChatGPT 网页端web接口 | |||
from bot.chatgpt.chat_gpt_bot import ChatGPTBot | |||
return ChatGPTBot() | |||
elif bot_type == const.OPEN_AI: | |||
# OpenAI 官方对话模型API | |||
from bot.openai.open_ai_bot import OpenAIBot | |||
return OpenAIBot() | |||
elif bot_type == const.CHATGPTONAZURE: | |||
# Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/ | |||
from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot | |||
return AzureChatGPTBot() | |||
raise RuntimeError |
@@ -1,42 +1,53 @@ | |||
# encoding:utf-8 | |||
import time | |||
import openai | |||
import openai.error | |||
from bot.bot import Bot | |||
from bot.chatgpt.chat_gpt_session import ChatGPTSession | |||
from bot.openai.open_ai_image import OpenAIImage | |||
from bot.session_manager import SessionManager | |||
from bridge.context import ContextType | |||
from bridge.reply import Reply, ReplyType | |||
from config import conf, load_config | |||
from common.log import logger | |||
from common.token_bucket import TokenBucket | |||
import openai | |||
import openai.error | |||
import time | |||
from config import conf, load_config | |||
# OpenAI对话模型API (可用) | |||
class ChatGPTBot(Bot,OpenAIImage): | |||
class ChatGPTBot(Bot, OpenAIImage): | |||
def __init__(self): | |||
super().__init__() | |||
# set the default api_key | |||
openai.api_key = conf().get('open_ai_api_key') | |||
if conf().get('open_ai_api_base'): | |||
openai.api_base = conf().get('open_ai_api_base') | |||
proxy = conf().get('proxy') | |||
openai.api_key = conf().get("open_ai_api_key") | |||
if conf().get("open_ai_api_base"): | |||
openai.api_base = conf().get("open_ai_api_base") | |||
proxy = conf().get("proxy") | |||
if proxy: | |||
openai.proxy = proxy | |||
if conf().get('rate_limit_chatgpt'): | |||
self.tb4chatgpt = TokenBucket(conf().get('rate_limit_chatgpt', 20)) | |||
self.sessions = SessionManager(ChatGPTSession, model= conf().get("model") or "gpt-3.5-turbo") | |||
self.args ={ | |||
if conf().get("rate_limit_chatgpt"): | |||
self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20)) | |||
self.sessions = SessionManager( | |||
ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo" | |||
) | |||
self.args = { | |||
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称 | |||
"temperature":conf().get('temperature', 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 | |||
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 | |||
# "max_tokens":4096, # 回复最大的字符数 | |||
"top_p":1, | |||
"frequency_penalty":conf().get('frequency_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | |||
"presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | |||
"request_timeout": conf().get('request_timeout', None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 | |||
"timeout": conf().get('request_timeout', None), #重试超时时间,在这个时间内,将会自动重试 | |||
"top_p": 1, | |||
"frequency_penalty": conf().get( | |||
"frequency_penalty", 0.0 | |||
), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | |||
"presence_penalty": conf().get( | |||
"presence_penalty", 0.0 | |||
), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | |||
"request_timeout": conf().get( | |||
"request_timeout", None | |||
), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 | |||
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试 | |||
} | |||
def reply(self, query, context=None): | |||
@@ -44,39 +55,50 @@ class ChatGPTBot(Bot,OpenAIImage): | |||
if context.type == ContextType.TEXT: | |||
logger.info("[CHATGPT] query={}".format(query)) | |||
session_id = context['session_id'] | |||
session_id = context["session_id"] | |||
reply = None | |||
clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆']) | |||
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"]) | |||
if query in clear_memory_commands: | |||
self.sessions.clear_session(session_id) | |||
reply = Reply(ReplyType.INFO, '记忆已清除') | |||
elif query == '#清除所有': | |||
reply = Reply(ReplyType.INFO, "记忆已清除") | |||
elif query == "#清除所有": | |||
self.sessions.clear_all_session() | |||
reply = Reply(ReplyType.INFO, '所有人记忆已清除') | |||
elif query == '#更新配置': | |||
reply = Reply(ReplyType.INFO, "所有人记忆已清除") | |||
elif query == "#更新配置": | |||
load_config() | |||
reply = Reply(ReplyType.INFO, '配置已更新') | |||
reply = Reply(ReplyType.INFO, "配置已更新") | |||
if reply: | |||
return reply | |||
session = self.sessions.session_query(query, session_id) | |||
logger.debug("[CHATGPT] session query={}".format(session.messages)) | |||
api_key = context.get('openai_api_key') | |||
api_key = context.get("openai_api_key") | |||
# if context.get('stream'): | |||
# # reply in stream | |||
# return self.reply_text_stream(query, new_query, session_id) | |||
reply_content = self.reply_text(session, api_key) | |||
logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"])) | |||
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0: | |||
reply = Reply(ReplyType.ERROR, reply_content['content']) | |||
logger.debug( | |||
"[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( | |||
session.messages, | |||
session_id, | |||
reply_content["content"], | |||
reply_content["completion_tokens"], | |||
) | |||
) | |||
if ( | |||
reply_content["completion_tokens"] == 0 | |||
and len(reply_content["content"]) > 0 | |||
): | |||
reply = Reply(ReplyType.ERROR, reply_content["content"]) | |||
elif reply_content["completion_tokens"] > 0: | |||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"]) | |||
self.sessions.session_reply( | |||
reply_content["content"], session_id, reply_content["total_tokens"] | |||
) | |||
reply = Reply(ReplyType.TEXT, reply_content["content"]) | |||
else: | |||
reply = Reply(ReplyType.ERROR, reply_content['content']) | |||
reply = Reply(ReplyType.ERROR, reply_content["content"]) | |||
logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content)) | |||
return reply | |||
@@ -89,53 +111,55 @@ class ChatGPTBot(Bot,OpenAIImage): | |||
reply = Reply(ReplyType.ERROR, retstring) | |||
return reply | |||
else: | |||
reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type)) | |||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type)) | |||
return reply | |||
def reply_text(self, session:ChatGPTSession, api_key=None, retry_count=0) -> dict: | |||
''' | |||
def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> dict: | |||
""" | |||
call openai's ChatCompletion to get the answer | |||
:param session: a conversation session | |||
:param session_id: session id | |||
:param retry_count: retry count | |||
:return: {} | |||
''' | |||
""" | |||
try: | |||
if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token(): | |||
if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token(): | |||
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded") | |||
# if api_key == None, the default openai.api_key will be used | |||
response = openai.ChatCompletion.create( | |||
api_key=api_key, messages=session.messages, **self.args | |||
) | |||
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) | |||
return {"total_tokens": response["usage"]["total_tokens"], | |||
"completion_tokens": response["usage"]["completion_tokens"], | |||
"content": response.choices[0]['message']['content']} | |||
return { | |||
"total_tokens": response["usage"]["total_tokens"], | |||
"completion_tokens": response["usage"]["completion_tokens"], | |||
"content": response.choices[0]["message"]["content"], | |||
} | |||
except Exception as e: | |||
need_retry = retry_count < 2 | |||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} | |||
if isinstance(e, openai.error.RateLimitError): | |||
logger.warn("[CHATGPT] RateLimitError: {}".format(e)) | |||
result['content'] = "提问太快啦,请休息一下再问我吧" | |||
result["content"] = "提问太快啦,请休息一下再问我吧" | |||
if need_retry: | |||
time.sleep(5) | |||
elif isinstance(e, openai.error.Timeout): | |||
logger.warn("[CHATGPT] Timeout: {}".format(e)) | |||
result['content'] = "我没有收到你的消息" | |||
result["content"] = "我没有收到你的消息" | |||
if need_retry: | |||
time.sleep(5) | |||
elif isinstance(e, openai.error.APIConnectionError): | |||
logger.warn("[CHATGPT] APIConnectionError: {}".format(e)) | |||
need_retry = False | |||
result['content'] = "我连接不到你的网络" | |||
result["content"] = "我连接不到你的网络" | |||
else: | |||
logger.warn("[CHATGPT] Exception: {}".format(e)) | |||
need_retry = False | |||
self.sessions.clear_session(session.session_id) | |||
if need_retry: | |||
logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1)) | |||
return self.reply_text(session, api_key, retry_count+1) | |||
logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1)) | |||
return self.reply_text(session, api_key, retry_count + 1) | |||
else: | |||
return result | |||
@@ -145,4 +169,4 @@ class AzureChatGPTBot(ChatGPTBot): | |||
super().__init__() | |||
openai.api_type = "azure" | |||
openai.api_version = "2023-03-15-preview" | |||
self.args["deployment_id"] = conf().get("azure_deployment_id") | |||
self.args["deployment_id"] = conf().get("azure_deployment_id") |
@@ -1,20 +1,23 @@ | |||
from bot.session_manager import Session | |||
from common.log import logger | |||
''' | |||
""" | |||
e.g. [ | |||
{"role": "system", "content": "You are a helpful assistant."}, | |||
{"role": "user", "content": "Who won the world series in 2020?"}, | |||
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."}, | |||
{"role": "user", "content": "Where was it played?"} | |||
] | |||
''' | |||
""" | |||
class ChatGPTSession(Session): | |||
def __init__(self, session_id, system_prompt=None, model= "gpt-3.5-turbo"): | |||
def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"): | |||
super().__init__(session_id, system_prompt) | |||
self.model = model | |||
self.reset() | |||
def discard_exceeding(self, max_tokens, cur_tokens= None): | |||
def discard_exceeding(self, max_tokens, cur_tokens=None): | |||
precise = True | |||
try: | |||
cur_tokens = self.calc_tokens() | |||
@@ -22,7 +25,9 @@ class ChatGPTSession(Session): | |||
precise = False | |||
if cur_tokens is None: | |||
raise e | |||
logger.debug("Exception when counting tokens precisely for query: {}".format(e)) | |||
logger.debug( | |||
"Exception when counting tokens precisely for query: {}".format(e) | |||
) | |||
while cur_tokens > max_tokens: | |||
if len(self.messages) > 2: | |||
self.messages.pop(1) | |||
@@ -34,25 +39,32 @@ class ChatGPTSession(Session): | |||
cur_tokens = cur_tokens - max_tokens | |||
break | |||
elif len(self.messages) == 2 and self.messages[1]["role"] == "user": | |||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens)) | |||
logger.warn( | |||
"user message exceed max_tokens. total_tokens={}".format(cur_tokens) | |||
) | |||
break | |||
else: | |||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages))) | |||
logger.debug( | |||
"max_tokens={}, total_tokens={}, len(messages)={}".format( | |||
max_tokens, cur_tokens, len(self.messages) | |||
) | |||
) | |||
break | |||
if precise: | |||
cur_tokens = self.calc_tokens() | |||
else: | |||
cur_tokens = cur_tokens - max_tokens | |||
return cur_tokens | |||
def calc_tokens(self): | |||
return num_tokens_from_messages(self.messages, self.model) | |||
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | |||
def num_tokens_from_messages(messages, model): | |||
"""Returns the number of tokens used by a list of messages.""" | |||
import tiktoken | |||
try: | |||
encoding = tiktoken.encoding_for_model(model) | |||
except KeyError: | |||
@@ -63,13 +75,17 @@ def num_tokens_from_messages(messages, model): | |||
elif model == "gpt-4": | |||
return num_tokens_from_messages(messages, model="gpt-4-0314") | |||
elif model == "gpt-3.5-turbo-0301": | |||
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n | |||
tokens_per_message = ( | |||
4 # every message follows <|start|>{role/name}\n{content}<|end|>\n | |||
) | |||
tokens_per_name = -1 # if there's a name, the role is omitted | |||
elif model == "gpt-4-0314": | |||
tokens_per_message = 3 | |||
tokens_per_name = 1 | |||
else: | |||
logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.") | |||
logger.warn( | |||
f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301." | |||
) | |||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301") | |||
num_tokens = 0 | |||
for message in messages: | |||
@@ -1,41 +1,52 @@ | |||
# encoding:utf-8 | |||
import time | |||
import openai | |||
import openai.error | |||
from bot.bot import Bot | |||
from bot.openai.open_ai_image import OpenAIImage | |||
from bot.openai.open_ai_session import OpenAISession | |||
from bot.session_manager import SessionManager | |||
from bridge.context import ContextType | |||
from bridge.reply import Reply, ReplyType | |||
from config import conf | |||
from common.log import logger | |||
import openai | |||
import openai.error | |||
import time | |||
from config import conf | |||
user_session = dict() | |||
# OpenAI对话模型API (可用) | |||
class OpenAIBot(Bot, OpenAIImage): | |||
def __init__(self): | |||
super().__init__() | |||
openai.api_key = conf().get('open_ai_api_key') | |||
if conf().get('open_ai_api_base'): | |||
openai.api_base = conf().get('open_ai_api_base') | |||
proxy = conf().get('proxy') | |||
openai.api_key = conf().get("open_ai_api_key") | |||
if conf().get("open_ai_api_base"): | |||
openai.api_base = conf().get("open_ai_api_base") | |||
proxy = conf().get("proxy") | |||
if proxy: | |||
openai.proxy = proxy | |||
self.sessions = SessionManager(OpenAISession, model= conf().get("model") or "text-davinci-003") | |||
self.sessions = SessionManager( | |||
OpenAISession, model=conf().get("model") or "text-davinci-003" | |||
) | |||
self.args = { | |||
"model": conf().get("model") or "text-davinci-003", # 对话模型的名称 | |||
"temperature":conf().get('temperature', 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 | |||
"max_tokens":1200, # 回复最大的字符数 | |||
"top_p":1, | |||
"frequency_penalty":conf().get('frequency_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | |||
"presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | |||
"request_timeout": conf().get('request_timeout', None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 | |||
"timeout": conf().get('request_timeout', None), #重试超时时间,在这个时间内,将会自动重试 | |||
"stop":["\n\n\n"] | |||
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 | |||
"max_tokens": 1200, # 回复最大的字符数 | |||
"top_p": 1, | |||
"frequency_penalty": conf().get( | |||
"frequency_penalty", 0.0 | |||
), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | |||
"presence_penalty": conf().get( | |||
"presence_penalty", 0.0 | |||
), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | |||
"request_timeout": conf().get( | |||
"request_timeout", None | |||
), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 | |||
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试 | |||
"stop": ["\n\n\n"], | |||
} | |||
def reply(self, query, context=None): | |||
@@ -43,24 +54,34 @@ class OpenAIBot(Bot, OpenAIImage): | |||
if context and context.type: | |||
if context.type == ContextType.TEXT: | |||
logger.info("[OPEN_AI] query={}".format(query)) | |||
session_id = context['session_id'] | |||
session_id = context["session_id"] | |||
reply = None | |||
if query == '#清除记忆': | |||
if query == "#清除记忆": | |||
self.sessions.clear_session(session_id) | |||
reply = Reply(ReplyType.INFO, '记忆已清除') | |||
elif query == '#清除所有': | |||
reply = Reply(ReplyType.INFO, "记忆已清除") | |||
elif query == "#清除所有": | |||
self.sessions.clear_all_session() | |||
reply = Reply(ReplyType.INFO, '所有人记忆已清除') | |||
reply = Reply(ReplyType.INFO, "所有人记忆已清除") | |||
else: | |||
session = self.sessions.session_query(query, session_id) | |||
result = self.reply_text(session) | |||
total_tokens, completion_tokens, reply_content = result['total_tokens'], result['completion_tokens'], result['content'] | |||
logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)) | |||
total_tokens, completion_tokens, reply_content = ( | |||
result["total_tokens"], | |||
result["completion_tokens"], | |||
result["content"], | |||
) | |||
logger.debug( | |||
"[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( | |||
str(session), session_id, reply_content, completion_tokens | |||
) | |||
) | |||
if total_tokens == 0 : | |||
if total_tokens == 0: | |||
reply = Reply(ReplyType.ERROR, reply_content) | |||
else: | |||
self.sessions.session_reply(reply_content, session_id, total_tokens) | |||
self.sessions.session_reply( | |||
reply_content, session_id, total_tokens | |||
) | |||
reply = Reply(ReplyType.TEXT, reply_content) | |||
return reply | |||
elif context.type == ContextType.IMAGE_CREATE: | |||
@@ -72,42 +93,44 @@ class OpenAIBot(Bot, OpenAIImage): | |||
reply = Reply(ReplyType.ERROR, retstring) | |||
return reply | |||
def reply_text(self, session:OpenAISession, retry_count=0): | |||
def reply_text(self, session: OpenAISession, retry_count=0): | |||
try: | |||
response = openai.Completion.create( | |||
prompt=str(session), **self.args | |||
response = openai.Completion.create(prompt=str(session), **self.args) | |||
res_content = ( | |||
response.choices[0]["text"].strip().replace("<|endoftext|>", "") | |||
) | |||
res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '') | |||
total_tokens = response["usage"]["total_tokens"] | |||
completion_tokens = response["usage"]["completion_tokens"] | |||
logger.info("[OPEN_AI] reply={}".format(res_content)) | |||
return {"total_tokens": total_tokens, | |||
"completion_tokens": completion_tokens, | |||
"content": res_content} | |||
return { | |||
"total_tokens": total_tokens, | |||
"completion_tokens": completion_tokens, | |||
"content": res_content, | |||
} | |||
except Exception as e: | |||
need_retry = retry_count < 2 | |||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} | |||
if isinstance(e, openai.error.RateLimitError): | |||
logger.warn("[OPEN_AI] RateLimitError: {}".format(e)) | |||
result['content'] = "提问太快啦,请休息一下再问我吧" | |||
result["content"] = "提问太快啦,请休息一下再问我吧" | |||
if need_retry: | |||
time.sleep(5) | |||
elif isinstance(e, openai.error.Timeout): | |||
logger.warn("[OPEN_AI] Timeout: {}".format(e)) | |||
result['content'] = "我没有收到你的消息" | |||
result["content"] = "我没有收到你的消息" | |||
if need_retry: | |||
time.sleep(5) | |||
elif isinstance(e, openai.error.APIConnectionError): | |||
logger.warn("[OPEN_AI] APIConnectionError: {}".format(e)) | |||
need_retry = False | |||
result['content'] = "我连接不到你的网络" | |||
result["content"] = "我连接不到你的网络" | |||
else: | |||
logger.warn("[OPEN_AI] Exception: {}".format(e)) | |||
need_retry = False | |||
self.sessions.clear_session(session.session_id) | |||
if need_retry: | |||
logger.warn("[OPEN_AI] 第{}次重试".format(retry_count+1)) | |||
return self.reply_text(session, retry_count+1) | |||
logger.warn("[OPEN_AI] 第{}次重试".format(retry_count + 1)) | |||
return self.reply_text(session, retry_count + 1) | |||
else: | |||
return result | |||
return result |
@@ -1,38 +1,45 @@ | |||
import time | |||
import openai | |||
import openai.error | |||
from common.token_bucket import TokenBucket | |||
from common.log import logger | |||
from common.token_bucket import TokenBucket | |||
from config import conf | |||
# OPENAI提供的画图接口 | |||
class OpenAIImage(object): | |||
def __init__(self): | |||
openai.api_key = conf().get('open_ai_api_key') | |||
if conf().get('rate_limit_dalle'): | |||
self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50)) | |||
openai.api_key = conf().get("open_ai_api_key") | |||
if conf().get("rate_limit_dalle"): | |||
self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50)) | |||
def create_img(self, query, retry_count=0): | |||
try: | |||
if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token(): | |||
if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token(): | |||
return False, "请求太快了,请休息一下再问我吧" | |||
logger.info("[OPEN_AI] image_query={}".format(query)) | |||
response = openai.Image.create( | |||
prompt=query, #图片描述 | |||
n=1, #每次生成图片的数量 | |||
size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024 | |||
prompt=query, # 图片描述 | |||
n=1, # 每次生成图片的数量 | |||
size="256x256", # 图片大小,可选有 256x256, 512x512, 1024x1024 | |||
) | |||
image_url = response['data'][0]['url'] | |||
image_url = response["data"][0]["url"] | |||
logger.info("[OPEN_AI] image_url={}".format(image_url)) | |||
return True, image_url | |||
except openai.error.RateLimitError as e: | |||
logger.warn(e) | |||
if retry_count < 1: | |||
time.sleep(5) | |||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1)) | |||
return self.create_img(query, retry_count+1) | |||
logger.warn( | |||
"[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format( | |||
retry_count + 1 | |||
) | |||
) | |||
return self.create_img(query, retry_count + 1) | |||
else: | |||
return False, "提问太快啦,请休息一下再问我吧" | |||
except Exception as e: | |||
logger.exception(e) | |||
return False, str(e) | |||
return False, str(e) |
@@ -1,32 +1,34 @@ | |||
from bot.session_manager import Session | |||
from common.log import logger | |||
class OpenAISession(Session): | |||
def __init__(self, session_id, system_prompt=None, model= "text-davinci-003"): | |||
def __init__(self, session_id, system_prompt=None, model="text-davinci-003"): | |||
super().__init__(session_id, system_prompt) | |||
self.model = model | |||
self.reset() | |||
def __str__(self): | |||
# 构造对话模型的输入 | |||
''' | |||
""" | |||
e.g. Q: xxx | |||
A: xxx | |||
Q: xxx | |||
''' | |||
""" | |||
prompt = "" | |||
for item in self.messages: | |||
if item['role'] == 'system': | |||
prompt += item['content'] + "<|endoftext|>\n\n\n" | |||
elif item['role'] == 'user': | |||
prompt += "Q: " + item['content'] + "\n" | |||
elif item['role'] == 'assistant': | |||
prompt += "\n\nA: " + item['content'] + "<|endoftext|>\n" | |||
if item["role"] == "system": | |||
prompt += item["content"] + "<|endoftext|>\n\n\n" | |||
elif item["role"] == "user": | |||
prompt += "Q: " + item["content"] + "\n" | |||
elif item["role"] == "assistant": | |||
prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n" | |||
if len(self.messages) > 0 and self.messages[-1]['role'] == 'user': | |||
if len(self.messages) > 0 and self.messages[-1]["role"] == "user": | |||
prompt += "A: " | |||
return prompt | |||
def discard_exceeding(self, max_tokens, cur_tokens= None): | |||
def discard_exceeding(self, max_tokens, cur_tokens=None): | |||
precise = True | |||
try: | |||
cur_tokens = self.calc_tokens() | |||
@@ -34,7 +36,9 @@ class OpenAISession(Session): | |||
precise = False | |||
if cur_tokens is None: | |||
raise e | |||
logger.debug("Exception when counting tokens precisely for query: {}".format(e)) | |||
logger.debug( | |||
"Exception when counting tokens precisely for query: {}".format(e) | |||
) | |||
while cur_tokens > max_tokens: | |||
if len(self.messages) > 1: | |||
self.messages.pop(0) | |||
@@ -46,24 +50,34 @@ class OpenAISession(Session): | |||
cur_tokens = len(str(self)) | |||
break | |||
elif len(self.messages) == 1 and self.messages[0]["role"] == "user": | |||
logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens)) | |||
logger.warn( | |||
"user question exceed max_tokens. total_tokens={}".format( | |||
cur_tokens | |||
) | |||
) | |||
break | |||
else: | |||
logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages))) | |||
logger.debug( | |||
"max_tokens={}, total_tokens={}, len(conversation)={}".format( | |||
max_tokens, cur_tokens, len(self.messages) | |||
) | |||
) | |||
break | |||
if precise: | |||
cur_tokens = self.calc_tokens() | |||
else: | |||
cur_tokens = len(str(self)) | |||
return cur_tokens | |||
def calc_tokens(self): | |||
return num_tokens_from_string(str(self), self.model) | |||
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | |||
def num_tokens_from_string(string: str, model: str) -> int: | |||
"""Returns the number of tokens in a text string.""" | |||
import tiktoken | |||
encoding = tiktoken.encoding_for_model(model) | |||
num_tokens = len(encoding.encode(string,disallowed_special=())) | |||
return num_tokens | |||
num_tokens = len(encoding.encode(string, disallowed_special=())) | |||
return num_tokens |
@@ -2,6 +2,7 @@ from common.expired_dict import ExpiredDict | |||
from common.log import logger | |||
from config import conf | |||
class Session(object): | |||
def __init__(self, session_id, system_prompt=None): | |||
self.session_id = session_id | |||
@@ -13,7 +14,7 @@ class Session(object): | |||
# 重置会话 | |||
def reset(self): | |||
system_item = {'role': 'system', 'content': self.system_prompt} | |||
system_item = {"role": "system", "content": self.system_prompt} | |||
self.messages = [system_item] | |||
def set_system_prompt(self, system_prompt): | |||
@@ -21,13 +22,13 @@ class Session(object): | |||
self.reset() | |||
def add_query(self, query): | |||
user_item = {'role': 'user', 'content': query} | |||
user_item = {"role": "user", "content": query} | |||
self.messages.append(user_item) | |||
def add_reply(self, reply): | |||
assistant_item = {'role': 'assistant', 'content': reply} | |||
assistant_item = {"role": "assistant", "content": reply} | |||
self.messages.append(assistant_item) | |||
def discard_exceeding(self, max_tokens=None, cur_tokens=None): | |||
raise NotImplementedError | |||
@@ -37,8 +38,8 @@ class Session(object): | |||
class SessionManager(object): | |||
def __init__(self, sessioncls, **session_args): | |||
if conf().get('expires_in_seconds'): | |||
sessions = ExpiredDict(conf().get('expires_in_seconds')) | |||
if conf().get("expires_in_seconds"): | |||
sessions = ExpiredDict(conf().get("expires_in_seconds")) | |||
else: | |||
sessions = dict() | |||
self.sessions = sessions | |||
@@ -46,20 +47,22 @@ class SessionManager(object): | |||
self.session_args = session_args | |||
def build_session(self, session_id, system_prompt=None): | |||
''' | |||
如果session_id不在sessions中,创建一个新的session并添加到sessions中 | |||
如果system_prompt不会空,会更新session的system_prompt并重置session | |||
''' | |||
""" | |||
如果session_id不在sessions中,创建一个新的session并添加到sessions中 | |||
如果system_prompt不会空,会更新session的system_prompt并重置session | |||
""" | |||
if session_id is None: | |||
return self.sessioncls(session_id, system_prompt, **self.session_args) | |||
if session_id not in self.sessions: | |||
self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args) | |||
self.sessions[session_id] = self.sessioncls( | |||
session_id, system_prompt, **self.session_args | |||
) | |||
elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session | |||
self.sessions[session_id].set_system_prompt(system_prompt) | |||
session = self.sessions[session_id] | |||
return session | |||
def session_query(self, query, session_id): | |||
session = self.build_session(session_id) | |||
session.add_query(query) | |||
@@ -68,23 +71,33 @@ class SessionManager(object): | |||
total_tokens = session.discard_exceeding(max_tokens, None) | |||
logger.debug("prompt tokens used={}".format(total_tokens)) | |||
except Exception as e: | |||
logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e))) | |||
logger.debug( | |||
"Exception when counting tokens precisely for prompt: {}".format(str(e)) | |||
) | |||
return session | |||
def session_reply(self, reply, session_id, total_tokens = None): | |||
def session_reply(self, reply, session_id, total_tokens=None): | |||
session = self.build_session(session_id) | |||
session.add_reply(reply) | |||
try: | |||
max_tokens = conf().get("conversation_max_tokens", 1000) | |||
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens) | |||
logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt)) | |||
logger.debug( | |||
"raw total_tokens={}, savesession tokens={}".format( | |||
total_tokens, tokens_cnt | |||
) | |||
) | |||
except Exception as e: | |||
logger.debug("Exception when counting tokens precisely for session: {}".format(str(e))) | |||
logger.debug( | |||
"Exception when counting tokens precisely for session: {}".format( | |||
str(e) | |||
) | |||
) | |||
return session | |||
def clear_session(self, session_id): | |||
if session_id in self.sessions: | |||
del(self.sessions[session_id]) | |||
del self.sessions[session_id] | |||
def clear_all_session(self): | |||
self.sessions.clear() |
@@ -1,31 +1,31 @@ | |||
from bot import bot_factory | |||
from bridge.context import Context | |||
from bridge.reply import Reply | |||
from common import const | |||
from common.log import logger | |||
from bot import bot_factory | |||
from common.singleton import singleton | |||
from voice import voice_factory | |||
from config import conf | |||
from common import const | |||
from voice import voice_factory | |||
@singleton | |||
class Bridge(object): | |||
def __init__(self): | |||
self.btype={ | |||
self.btype = { | |||
"chat": const.CHATGPT, | |||
"voice_to_text": conf().get("voice_to_text", "openai"), | |||
"text_to_voice": conf().get("text_to_voice", "google") | |||
"text_to_voice": conf().get("text_to_voice", "google"), | |||
} | |||
model_type = conf().get("model") | |||
if model_type in ["text-davinci-003"]: | |||
self.btype['chat'] = const.OPEN_AI | |||
self.btype["chat"] = const.OPEN_AI | |||
if conf().get("use_azure_chatgpt", False): | |||
self.btype['chat'] = const.CHATGPTONAZURE | |||
self.bots={} | |||
self.btype["chat"] = const.CHATGPTONAZURE | |||
self.bots = {} | |||
def get_bot(self,typename): | |||
def get_bot(self, typename): | |||
if self.bots.get(typename) is None: | |||
logger.info("create bot {} for {}".format(self.btype[typename],typename)) | |||
logger.info("create bot {} for {}".format(self.btype[typename], typename)) | |||
if typename == "text_to_voice": | |||
self.bots[typename] = voice_factory.create_voice(self.btype[typename]) | |||
elif typename == "voice_to_text": | |||
@@ -33,18 +33,15 @@ class Bridge(object): | |||
elif typename == "chat": | |||
self.bots[typename] = bot_factory.create_bot(self.btype[typename]) | |||
return self.bots[typename] | |||
def get_bot_type(self,typename): | |||
return self.btype[typename] | |||
def get_bot_type(self, typename): | |||
return self.btype[typename] | |||
def fetch_reply_content(self, query, context : Context) -> Reply: | |||
def fetch_reply_content(self, query, context: Context) -> Reply: | |||
return self.get_bot("chat").reply(query, context) | |||
def fetch_voice_to_text(self, voiceFile) -> Reply: | |||
return self.get_bot("voice_to_text").voiceToText(voiceFile) | |||
def fetch_text_to_voice(self, text) -> Reply: | |||
return self.get_bot("text_to_voice").textToVoice(text) | |||
@@ -2,36 +2,39 @@ | |||
from enum import Enum | |||
class ContextType (Enum): | |||
TEXT = 1 # 文本消息 | |||
VOICE = 2 # 音频消息 | |||
IMAGE = 3 # 图片消息 | |||
IMAGE_CREATE = 10 # 创建图片命令 | |||
class ContextType(Enum): | |||
TEXT = 1 # 文本消息 | |||
VOICE = 2 # 音频消息 | |||
IMAGE = 3 # 图片消息 | |||
IMAGE_CREATE = 10 # 创建图片命令 | |||
def __str__(self): | |||
return self.name | |||
class Context: | |||
def __init__(self, type : ContextType = None , content = None, kwargs = dict()): | |||
def __init__(self, type: ContextType = None, content=None, kwargs=dict()): | |||
self.type = type | |||
self.content = content | |||
self.kwargs = kwargs | |||
def __contains__(self, key): | |||
if key == 'type': | |||
if key == "type": | |||
return self.type is not None | |||
elif key == 'content': | |||
elif key == "content": | |||
return self.content is not None | |||
else: | |||
return key in self.kwargs | |||
def __getitem__(self, key): | |||
if key == 'type': | |||
if key == "type": | |||
return self.type | |||
elif key == 'content': | |||
elif key == "content": | |||
return self.content | |||
else: | |||
return self.kwargs[key] | |||
def get(self, key, default=None): | |||
try: | |||
return self[key] | |||
@@ -39,20 +42,22 @@ class Context: | |||
return default | |||
def __setitem__(self, key, value): | |||
if key == 'type': | |||
if key == "type": | |||
self.type = value | |||
elif key == 'content': | |||
elif key == "content": | |||
self.content = value | |||
else: | |||
self.kwargs[key] = value | |||
def __delitem__(self, key): | |||
if key == 'type': | |||
if key == "type": | |||
self.type = None | |||
elif key == 'content': | |||
elif key == "content": | |||
self.content = None | |||
else: | |||
del self.kwargs[key] | |||
def __str__(self): | |||
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs) | |||
return "Context(type={}, content={}, kwargs={})".format( | |||
self.type, self.content, self.kwargs | |||
) |
@@ -1,22 +1,25 @@ | |||
# encoding:utf-8 | |||
from enum import Enum | |||
class ReplyType(Enum): | |||
TEXT = 1 # 文本 | |||
VOICE = 2 # 音频文件 | |||
IMAGE = 3 # 图片文件 | |||
IMAGE_URL = 4 # 图片URL | |||
TEXT = 1 # 文本 | |||
VOICE = 2 # 音频文件 | |||
IMAGE = 3 # 图片文件 | |||
IMAGE_URL = 4 # 图片URL | |||
INFO = 9 | |||
ERROR = 10 | |||
def __str__(self): | |||
return self.name | |||
class Reply: | |||
def __init__(self, type : ReplyType = None , content = None): | |||
def __init__(self, type: ReplyType = None, content=None): | |||
self.type = type | |||
self.content = content | |||
def __str__(self): | |||
return "Reply(type={}, content={})".format(self.type, self.content) | |||
return "Reply(type={}, content={})".format(self.type, self.content) |
@@ -6,8 +6,10 @@ from bridge.bridge import Bridge | |||
from bridge.context import Context | |||
from bridge.reply import * | |||
class Channel(object): | |||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE] | |||
def startup(self): | |||
""" | |||
init channel | |||
@@ -27,15 +29,15 @@ class Channel(object): | |||
send message to user | |||
:param msg: message content | |||
:param receiver: receiver channel account | |||
:return: | |||
:return: | |||
""" | |||
raise NotImplementedError | |||
def build_reply_content(self, query, context : Context=None) -> Reply: | |||
def build_reply_content(self, query, context: Context = None) -> Reply: | |||
return Bridge().fetch_reply_content(query, context) | |||
def build_voice_to_text(self, voice_file) -> Reply: | |||
return Bridge().fetch_voice_to_text(voice_file) | |||
def build_text_to_voice(self, text) -> Reply: | |||
return Bridge().fetch_text_to_voice(text) |
@@ -2,25 +2,31 @@ | |||
channel factory | |||
""" | |||
def create_channel(channel_type): | |||
""" | |||
create a channel instance | |||
:param channel_type: channel type code | |||
:return: channel instance | |||
""" | |||
if channel_type == 'wx': | |||
if channel_type == "wx": | |||
from channel.wechat.wechat_channel import WechatChannel | |||
return WechatChannel() | |||
elif channel_type == 'wxy': | |||
elif channel_type == "wxy": | |||
from channel.wechat.wechaty_channel import WechatyChannel | |||
return WechatyChannel() | |||
elif channel_type == 'terminal': | |||
elif channel_type == "terminal": | |||
from channel.terminal.terminal_channel import TerminalChannel | |||
return TerminalChannel() | |||
elif channel_type == 'wechatmp': | |||
elif channel_type == "wechatmp": | |||
from channel.wechatmp.wechatmp_channel import WechatMPChannel | |||
return WechatMPChannel(passive_reply = True) | |||
elif channel_type == 'wechatmp_service': | |||
return WechatMPChannel(passive_reply=True) | |||
elif channel_type == "wechatmp_service": | |||
from channel.wechatmp.wechatmp_channel import WechatMPChannel | |||
return WechatMPChannel(passive_reply = False) | |||
return WechatMPChannel(passive_reply=False) | |||
raise RuntimeError |
@@ -1,137 +1,172 @@ | |||
from asyncio import CancelledError | |||
from concurrent.futures import Future, ThreadPoolExecutor | |||
import os | |||
import re | |||
import threading | |||
import time | |||
from common.dequeue import Dequeue | |||
from channel.channel import Channel | |||
from bridge.reply import * | |||
from asyncio import CancelledError | |||
from concurrent.futures import Future, ThreadPoolExecutor | |||
from bridge.context import * | |||
from config import conf | |||
from bridge.reply import * | |||
from channel.channel import Channel | |||
from common.dequeue import Dequeue | |||
from common.log import logger | |||
from config import conf | |||
from plugins import * | |||
try: | |||
from voice.audio_convert import any_to_wav | |||
except Exception as e: | |||
pass | |||
# 抽象类, 它包含了与消息通道无关的通用处理逻辑 | |||
class ChatChannel(Channel): | |||
name = None # 登录的用户名 | |||
user_id = None # 登录的用户id | |||
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消 | |||
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理 | |||
lock = threading.Lock() # 用于控制对sessions的访问 | |||
name = None # 登录的用户名 | |||
user_id = None # 登录的用户id | |||
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消 | |||
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理 | |||
lock = threading.Lock() # 用于控制对sessions的访问 | |||
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池 | |||
def __init__(self): | |||
_thread = threading.Thread(target=self.consume) | |||
_thread.setDaemon(True) | |||
_thread.start() | |||
# 根据消息构造context,消息内容相关的触发项写在这里 | |||
def _compose_context(self, ctype: ContextType, content, **kwargs): | |||
context = Context(ctype, content) | |||
context.kwargs = kwargs | |||
# context首次传入时,origin_ctype是None, | |||
# context首次传入时,origin_ctype是None, | |||
# 引入的起因是:当输入语音时,会嵌套生成两个context,第一步语音转文本,第二步通过文本生成文字回复。 | |||
# origin_ctype用于第二步文本回复时,判断是否需要匹配前缀,如果是私聊的语音,就不需要匹配前缀 | |||
if 'origin_ctype' not in context: | |||
context['origin_ctype'] = ctype | |||
if "origin_ctype" not in context: | |||
context["origin_ctype"] = ctype | |||
# context首次传入时,receiver是None,根据类型设置receiver | |||
first_in = 'receiver' not in context | |||
first_in = "receiver" not in context | |||
# 群名匹配过程,设置session_id和receiver | |||
if first_in: # context首次传入时,receiver是None,根据类型设置receiver | |||
if first_in: # context首次传入时,receiver是None,根据类型设置receiver | |||
config = conf() | |||
cmsg = context['msg'] | |||
cmsg = context["msg"] | |||
if context.get("isgroup", False): | |||
group_name = cmsg.other_user_nickname | |||
group_id = cmsg.other_user_id | |||
group_name_white_list = config.get('group_name_white_list', []) | |||
group_name_keyword_white_list = config.get('group_name_keyword_white_list', []) | |||
if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]): | |||
group_chat_in_one_session = conf().get('group_chat_in_one_session', []) | |||
group_name_white_list = config.get("group_name_white_list", []) | |||
group_name_keyword_white_list = config.get( | |||
"group_name_keyword_white_list", [] | |||
) | |||
if any( | |||
[ | |||
group_name in group_name_white_list, | |||
"ALL_GROUP" in group_name_white_list, | |||
check_contain(group_name, group_name_keyword_white_list), | |||
] | |||
): | |||
group_chat_in_one_session = conf().get( | |||
"group_chat_in_one_session", [] | |||
) | |||
session_id = cmsg.actual_user_id | |||
if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]): | |||
if any( | |||
[ | |||
group_name in group_chat_in_one_session, | |||
"ALL_GROUP" in group_chat_in_one_session, | |||
] | |||
): | |||
session_id = group_id | |||
else: | |||
return None | |||
context['session_id'] = session_id | |||
context['receiver'] = group_id | |||
context["session_id"] = session_id | |||
context["receiver"] = group_id | |||
else: | |||
context['session_id'] = cmsg.other_user_id | |||
context['receiver'] = cmsg.other_user_id | |||
e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {'channel': self, 'context': context})) | |||
context = e_context['context'] | |||
context["session_id"] = cmsg.other_user_id | |||
context["receiver"] = cmsg.other_user_id | |||
e_context = PluginManager().emit_event( | |||
EventContext( | |||
Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context} | |||
) | |||
) | |||
context = e_context["context"] | |||
if e_context.is_pass() or context is None: | |||
return context | |||
if cmsg.from_user_id == self.user_id and not config.get('trigger_by_self', True): | |||
if cmsg.from_user_id == self.user_id and not config.get( | |||
"trigger_by_self", True | |||
): | |||
logger.debug("[WX]self message skipped") | |||
return None | |||
# 消息内容匹配过程,并处理content | |||
if ctype == ContextType.TEXT: | |||
if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息 | |||
if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息 | |||
logger.debug("[WX]reference query skipped") | |||
return None | |||
if context.get("isgroup", False): # 群聊 | |||
if context.get("isgroup", False): # 群聊 | |||
# 校验关键字 | |||
match_prefix = check_prefix(content, conf().get('group_chat_prefix')) | |||
match_contain = check_contain(content, conf().get('group_chat_keyword')) | |||
match_prefix = check_prefix(content, conf().get("group_chat_prefix")) | |||
match_contain = check_contain(content, conf().get("group_chat_keyword")) | |||
flag = False | |||
if match_prefix is not None or match_contain is not None: | |||
flag = True | |||
if match_prefix: | |||
content = content.replace(match_prefix, '', 1).strip() | |||
if context['msg'].is_at: | |||
content = content.replace(match_prefix, "", 1).strip() | |||
if context["msg"].is_at: | |||
logger.info("[WX]receive group at") | |||
if not conf().get("group_at_off", False): | |||
flag = True | |||
pattern = f'@{self.name}(\u2005|\u0020)' | |||
content = re.sub(pattern, r'', content) | |||
pattern = f"@{self.name}(\u2005|\u0020)" | |||
content = re.sub(pattern, r"", content) | |||
if not flag: | |||
if context["origin_ctype"] == ContextType.VOICE: | |||
logger.info("[WX]receive group voice, but checkprefix didn't match") | |||
logger.info( | |||
"[WX]receive group voice, but checkprefix didn't match" | |||
) | |||
return None | |||
else: # 单聊 | |||
match_prefix = check_prefix(content, conf().get('single_chat_prefix',[''])) | |||
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容 | |||
content = content.replace(match_prefix, '', 1).strip() | |||
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件 | |||
else: # 单聊 | |||
match_prefix = check_prefix( | |||
content, conf().get("single_chat_prefix", [""]) | |||
) | |||
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容 | |||
content = content.replace(match_prefix, "", 1).strip() | |||
elif ( | |||
context["origin_ctype"] == ContextType.VOICE | |||
): # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件 | |||
pass | |||
else: | |||
return None | |||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix')) | |||
return None | |||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix")) | |||
if img_match_prefix: | |||
content = content.replace(img_match_prefix, '', 1) | |||
content = content.replace(img_match_prefix, "", 1) | |||
context.type = ContextType.IMAGE_CREATE | |||
else: | |||
context.type = ContextType.TEXT | |||
context.content = content.strip() | |||
if 'desire_rtype' not in context and conf().get('always_reply_voice') and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: | |||
context['desire_rtype'] = ReplyType.VOICE | |||
elif context.type == ContextType.VOICE: | |||
if 'desire_rtype' not in context and conf().get('voice_reply_voice') and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: | |||
context['desire_rtype'] = ReplyType.VOICE | |||
if ( | |||
"desire_rtype" not in context | |||
and conf().get("always_reply_voice") | |||
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE | |||
): | |||
context["desire_rtype"] = ReplyType.VOICE | |||
elif context.type == ContextType.VOICE: | |||
if ( | |||
"desire_rtype" not in context | |||
and conf().get("voice_reply_voice") | |||
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE | |||
): | |||
context["desire_rtype"] = ReplyType.VOICE | |||
return context | |||
def _handle(self, context: Context): | |||
if context is None or not context.content: | |||
return | |||
logger.debug('[WX] ready to handle context: {}'.format(context)) | |||
logger.debug("[WX] ready to handle context: {}".format(context)) | |||
# reply的构建步骤 | |||
reply = self._generate_reply(context) | |||
logger.debug('[WX] ready to decorate reply: {}'.format(reply)) | |||
logger.debug("[WX] ready to decorate reply: {}".format(reply)) | |||
# reply的包装步骤 | |||
reply = self._decorate_reply(context, reply) | |||
@@ -139,20 +174,31 @@ class ChatChannel(Channel): | |||
self._send_reply(context, reply) | |||
def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply: | |||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, { | |||
'channel': self, 'context': context, 'reply': reply})) | |||
reply = e_context['reply'] | |||
e_context = PluginManager().emit_event( | |||
EventContext( | |||
Event.ON_HANDLE_CONTEXT, | |||
{"channel": self, "context": context, "reply": reply}, | |||
) | |||
) | |||
reply = e_context["reply"] | |||
if not e_context.is_pass(): | |||
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content)) | |||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息 | |||
logger.debug( | |||
"[WX] ready to handle context: type={}, content={}".format( | |||
context.type, context.content | |||
) | |||
) | |||
if ( | |||
context.type == ContextType.TEXT | |||
or context.type == ContextType.IMAGE_CREATE | |||
): # 文字和图片消息 | |||
reply = super().build_reply_content(context.content, context) | |||
elif context.type == ContextType.VOICE: # 语音消息 | |||
cmsg = context['msg'] | |||
cmsg = context["msg"] | |||
cmsg.prepare() | |||
file_path = context.content | |||
wav_path = os.path.splitext(file_path)[0] + '.wav' | |||
wav_path = os.path.splitext(file_path)[0] + ".wav" | |||
try: | |||
any_to_wav(file_path, wav_path) | |||
any_to_wav(file_path, wav_path) | |||
except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别 | |||
logger.warning("[WX]any to wav error, use raw path. " + str(e)) | |||
wav_path = file_path | |||
@@ -169,7 +215,8 @@ class ChatChannel(Channel): | |||
if reply.type == ReplyType.TEXT: | |||
new_context = self._compose_context( | |||
ContextType.TEXT, reply.content, **context.kwargs) | |||
ContextType.TEXT, reply.content, **context.kwargs | |||
) | |||
if new_context: | |||
reply = self._generate_reply(new_context) | |||
else: | |||
@@ -177,18 +224,21 @@ class ChatChannel(Channel): | |||
elif context.type == ContextType.IMAGE: # 图片消息,当前无默认逻辑 | |||
pass | |||
else: | |||
logger.error('[WX] unknown context type: {}'.format(context.type)) | |||
logger.error("[WX] unknown context type: {}".format(context.type)) | |||
return | |||
return reply | |||
def _decorate_reply(self, context: Context, reply: Reply) -> Reply: | |||
if reply and reply.type: | |||
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, { | |||
'channel': self, 'context': context, 'reply': reply})) | |||
reply = e_context['reply'] | |||
desire_rtype = context.get('desire_rtype') | |||
e_context = PluginManager().emit_event( | |||
EventContext( | |||
Event.ON_DECORATE_REPLY, | |||
{"channel": self, "context": context, "reply": reply}, | |||
) | |||
) | |||
reply = e_context["reply"] | |||
desire_rtype = context.get("desire_rtype") | |||
if not e_context.is_pass() and reply and reply.type: | |||
if reply.type in self.NOT_SUPPORT_REPLYTYPE: | |||
logger.error("[WX]reply type not support: " + str(reply.type)) | |||
reply.type = ReplyType.ERROR | |||
@@ -196,59 +246,91 @@ class ChatChannel(Channel): | |||
if reply.type == ReplyType.TEXT: | |||
reply_text = reply.content | |||
if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: | |||
if ( | |||
desire_rtype == ReplyType.VOICE | |||
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE | |||
): | |||
reply = super().build_text_to_voice(reply.content) | |||
return self._decorate_reply(context, reply) | |||
if context.get("isgroup", False): | |||
reply_text = '@' + context['msg'].actual_user_nickname + ' ' + reply_text.strip() | |||
reply_text = conf().get("group_chat_reply_prefix", "") + reply_text | |||
reply_text = ( | |||
"@" | |||
+ context["msg"].actual_user_nickname | |||
+ " " | |||
+ reply_text.strip() | |||
) | |||
reply_text = ( | |||
conf().get("group_chat_reply_prefix", "") + reply_text | |||
) | |||
else: | |||
reply_text = conf().get("single_chat_reply_prefix", "") + reply_text | |||
reply_text = ( | |||
conf().get("single_chat_reply_prefix", "") + reply_text | |||
) | |||
reply.content = reply_text | |||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: | |||
reply.content = "["+str(reply.type)+"]\n" + reply.content | |||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE: | |||
reply.content = "[" + str(reply.type) + "]\n" + reply.content | |||
elif ( | |||
reply.type == ReplyType.IMAGE_URL | |||
or reply.type == ReplyType.VOICE | |||
or reply.type == ReplyType.IMAGE | |||
): | |||
pass | |||
else: | |||
logger.error('[WX] unknown reply type: {}'.format(reply.type)) | |||
logger.error("[WX] unknown reply type: {}".format(reply.type)) | |||
return | |||
if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]: | |||
logger.warning('[WX] desire_rtype: {}, but reply type: {}'.format(context.get('desire_rtype'), reply.type)) | |||
if ( | |||
desire_rtype | |||
and desire_rtype != reply.type | |||
and reply.type not in [ReplyType.ERROR, ReplyType.INFO] | |||
): | |||
logger.warning( | |||
"[WX] desire_rtype: {}, but reply type: {}".format( | |||
context.get("desire_rtype"), reply.type | |||
) | |||
) | |||
return reply | |||
def _send_reply(self, context: Context, reply: Reply): | |||
if reply and reply.type: | |||
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, { | |||
'channel': self, 'context': context, 'reply': reply})) | |||
reply = e_context['reply'] | |||
e_context = PluginManager().emit_event( | |||
EventContext( | |||
Event.ON_SEND_REPLY, | |||
{"channel": self, "context": context, "reply": reply}, | |||
) | |||
) | |||
reply = e_context["reply"] | |||
if not e_context.is_pass() and reply and reply.type: | |||
logger.debug('[WX] ready to send reply: {}, context: {}'.format(reply, context)) | |||
logger.debug( | |||
"[WX] ready to send reply: {}, context: {}".format(reply, context) | |||
) | |||
self._send(reply, context) | |||
def _send(self, reply: Reply, context: Context, retry_cnt = 0): | |||
def _send(self, reply: Reply, context: Context, retry_cnt=0): | |||
try: | |||
self.send(reply, context) | |||
except Exception as e: | |||
logger.error('[WX] sendMsg error: {}'.format(str(e))) | |||
logger.error("[WX] sendMsg error: {}".format(str(e))) | |||
if isinstance(e, NotImplementedError): | |||
return | |||
logger.exception(e) | |||
if retry_cnt < 2: | |||
time.sleep(3+3*retry_cnt) | |||
self._send(reply, context, retry_cnt+1) | |||
time.sleep(3 + 3 * retry_cnt) | |||
self._send(reply, context, retry_cnt + 1) | |||
def _success_callback(self, session_id, **kwargs):# 线程正常结束时的回调函数 | |||
def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数 | |||
logger.debug("Worker return success, session_id = {}".format(session_id)) | |||
def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数 | |||
def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数 | |||
logger.exception("Worker return exception: {}".format(exception)) | |||
def _thread_pool_callback(self, session_id, **kwargs): | |||
def func(worker:Future): | |||
def func(worker: Future): | |||
try: | |||
worker_exception = worker.exception() | |||
if worker_exception: | |||
self._fail_callback(session_id, exception = worker_exception, **kwargs) | |||
self._fail_callback( | |||
session_id, exception=worker_exception, **kwargs | |||
) | |||
else: | |||
self._success_callback(session_id, **kwargs) | |||
except CancelledError as e: | |||
@@ -257,15 +339,19 @@ class ChatChannel(Channel): | |||
logger.exception("Worker raise exception: {}".format(e)) | |||
with self.lock: | |||
self.sessions[session_id][1].release() | |||
return func | |||
def produce(self, context: Context): | |||
session_id = context['session_id'] | |||
session_id = context["session_id"] | |||
with self.lock: | |||
if session_id not in self.sessions: | |||
self.sessions[session_id] = [Dequeue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 4))] | |||
if context.type == ContextType.TEXT and context.content.startswith("#"): | |||
self.sessions[session_id][0].putleft(context) # 优先处理管理命令 | |||
self.sessions[session_id] = [ | |||
Dequeue(), | |||
threading.BoundedSemaphore(conf().get("concurrency_in_session", 4)), | |||
] | |||
if context.type == ContextType.TEXT and context.content.startswith("#"): | |||
self.sessions[session_id][0].putleft(context) # 优先处理管理命令 | |||
else: | |||
self.sessions[session_id][0].put(context) | |||
@@ -276,44 +362,58 @@ class ChatChannel(Channel): | |||
session_ids = list(self.sessions.keys()) | |||
for session_id in session_ids: | |||
context_queue, semaphore = self.sessions[session_id] | |||
if semaphore.acquire(blocking = False): # 等线程处理完毕才能删除 | |||
if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除 | |||
if not context_queue.empty(): | |||
context = context_queue.get() | |||
logger.debug("[WX] consume context: {}".format(context)) | |||
future:Future = self.handler_pool.submit(self._handle, context) | |||
future.add_done_callback(self._thread_pool_callback(session_id, context = context)) | |||
future: Future = self.handler_pool.submit( | |||
self._handle, context | |||
) | |||
future.add_done_callback( | |||
self._thread_pool_callback(session_id, context=context) | |||
) | |||
if session_id not in self.futures: | |||
self.futures[session_id] = [] | |||
self.futures[session_id].append(future) | |||
elif semaphore._initial_value == semaphore._value+1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕 | |||
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()] | |||
assert len(self.futures[session_id]) == 0, "thread pool error" | |||
elif ( | |||
semaphore._initial_value == semaphore._value + 1 | |||
): # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕 | |||
self.futures[session_id] = [ | |||
t for t in self.futures[session_id] if not t.done() | |||
] | |||
assert ( | |||
len(self.futures[session_id]) == 0 | |||
), "thread pool error" | |||
del self.sessions[session_id] | |||
else: | |||
semaphore.release() | |||
time.sleep(0.1) | |||
# 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务 | |||
def cancel_session(self, session_id): | |||
def cancel_session(self, session_id): | |||
with self.lock: | |||
if session_id in self.sessions: | |||
for future in self.futures[session_id]: | |||
future.cancel() | |||
cnt = self.sessions[session_id][0].qsize() | |||
if cnt>0: | |||
logger.info("Cancel {} messages in session {}".format(cnt, session_id)) | |||
if cnt > 0: | |||
logger.info( | |||
"Cancel {} messages in session {}".format(cnt, session_id) | |||
) | |||
self.sessions[session_id][0] = Dequeue() | |||
def cancel_all_session(self): | |||
with self.lock: | |||
for session_id in self.sessions: | |||
for future in self.futures[session_id]: | |||
future.cancel() | |||
cnt = self.sessions[session_id][0].qsize() | |||
if cnt>0: | |||
logger.info("Cancel {} messages in session {}".format(cnt, session_id)) | |||
if cnt > 0: | |||
logger.info( | |||
"Cancel {} messages in session {}".format(cnt, session_id) | |||
) | |||
self.sessions[session_id][0] = Dequeue() | |||
def check_prefix(content, prefix_list): | |||
if not prefix_list: | |||
@@ -323,6 +423,7 @@ def check_prefix(content, prefix_list): | |||
return prefix | |||
return None | |||
def check_contain(content, keyword_list): | |||
if not keyword_list: | |||
return None | |||
@@ -1,5 +1,4 @@ | |||
""" | |||
""" | |||
本类表示聊天消息,用于对itchat和wechaty的消息进行统一的封装。 | |||
填好必填项(群聊6个,非群聊8个),即可接入ChatChannel,并支持插件,参考TerminalChannel | |||
@@ -20,7 +19,7 @@ other_user_id: 对方的id,如果你是发送者,那这个就是接收者id | |||
other_user_nickname: 同上 | |||
is_group: 是否是群消息 (群聊必填) | |||
is_at: 是否被at | |||
is_at: 是否被at | |||
- (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在) | |||
actual_user_id: 实际发送者id (群聊必填) | |||
@@ -34,20 +33,22 @@ _prepared: 是否已经调用过准备函数 | |||
_rawmsg: 原始消息对象 | |||
""" | |||
class ChatMessage(object): | |||
msg_id = None | |||
create_time = None | |||
ctype = None | |||
content = None | |||
from_user_id = None | |||
from_user_nickname = None | |||
to_user_id = None | |||
to_user_nickname = None | |||
other_user_id = None | |||
other_user_nickname = None | |||
is_group = False | |||
is_at = False | |||
actual_user_id = None | |||
@@ -57,8 +58,7 @@ class ChatMessage(object): | |||
_prepared = False | |||
_rawmsg = None | |||
def __init__(self,_rawmsg): | |||
def __init__(self, _rawmsg): | |||
self._rawmsg = _rawmsg | |||
def prepare(self): | |||
@@ -67,7 +67,7 @@ class ChatMessage(object): | |||
self._prepare_fn() | |||
def __str__(self): | |||
return 'ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}'.format( | |||
return "ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}".format( | |||
self.msg_id, | |||
self.create_time, | |||
self.ctype, | |||
@@ -82,4 +82,4 @@ class ChatMessage(object): | |||
self.is_at, | |||
self.actual_user_id, | |||
self.actual_user_nickname, | |||
) | |||
) |
@@ -1,14 +1,23 @@ | |||
import sys | |||
from bridge.context import * | |||
from bridge.reply import Reply, ReplyType | |||
from channel.chat_channel import ChatChannel, check_prefix | |||
from channel.chat_message import ChatMessage | |||
import sys | |||
from config import conf | |||
from common.log import logger | |||
from config import conf | |||
class TerminalMessage(ChatMessage): | |||
def __init__(self, msg_id, content, ctype = ContextType.TEXT, from_user_id = "User", to_user_id = "Chatgpt", other_user_id = "Chatgpt"): | |||
def __init__( | |||
self, | |||
msg_id, | |||
content, | |||
ctype=ContextType.TEXT, | |||
from_user_id="User", | |||
to_user_id="Chatgpt", | |||
other_user_id="Chatgpt", | |||
): | |||
self.msg_id = msg_id | |||
self.ctype = ctype | |||
self.content = content | |||
@@ -16,6 +25,7 @@ class TerminalMessage(ChatMessage): | |||
self.to_user_id = to_user_id | |||
self.other_user_id = other_user_id | |||
class TerminalChannel(ChatChannel): | |||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE] | |||
@@ -23,14 +33,18 @@ class TerminalChannel(ChatChannel): | |||
print("\nBot:") | |||
if reply.type == ReplyType.IMAGE: | |||
from PIL import Image | |||
image_storage = reply.content | |||
image_storage.seek(0) | |||
img = Image.open(image_storage) | |||
print("<IMAGE>") | |||
img.show() | |||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 | |||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 | |||
import io | |||
import requests | |||
from PIL import Image | |||
import requests,io | |||
img_url = reply.content | |||
pic_res = requests.get(img_url, stream=True) | |||
image_storage = io.BytesIO() | |||
@@ -59,11 +73,13 @@ class TerminalChannel(ChatChannel): | |||
print("\nExiting...") | |||
sys.exit() | |||
msg_id += 1 | |||
trigger_prefixs = conf().get("single_chat_prefix",[""]) | |||
trigger_prefixs = conf().get("single_chat_prefix", [""]) | |||
if check_prefix(prompt, trigger_prefixs) is None: | |||
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀 | |||
context = self._compose_context(ContextType.TEXT, prompt, msg = TerminalMessage(msg_id, prompt)) | |||
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀 | |||
context = self._compose_context( | |||
ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt) | |||
) | |||
if context: | |||
self.produce(context) | |||
else: | |||
@@ -4,40 +4,45 @@ | |||
wechat channel | |||
""" | |||
import io | |||
import json | |||
import os | |||
import threading | |||
import requests | |||
import io | |||
import time | |||
import json | |||
import requests | |||
from bridge.context import * | |||
from bridge.reply import * | |||
from channel.chat_channel import ChatChannel | |||
from channel.wechat.wechat_message import * | |||
from common.singleton import singleton | |||
from common.expired_dict import ExpiredDict | |||
from common.log import logger | |||
from common.singleton import singleton | |||
from common.time_check import time_checker | |||
from config import conf | |||
from lib import itchat | |||
from lib.itchat.content import * | |||
from bridge.reply import * | |||
from bridge.context import * | |||
from config import conf | |||
from common.time_check import time_checker | |||
from common.expired_dict import ExpiredDict | |||
from plugins import * | |||
@itchat.msg_register([TEXT,VOICE,PICTURE]) | |||
@itchat.msg_register([TEXT, VOICE, PICTURE]) | |||
def handler_single_msg(msg): | |||
# logger.debug("handler_single_msg: {}".format(msg)) | |||
if msg['Type'] == PICTURE and msg['MsgType'] == 47: | |||
if msg["Type"] == PICTURE and msg["MsgType"] == 47: | |||
return None | |||
WechatChannel().handle_single(WeChatMessage(msg)) | |||
return None | |||
@itchat.msg_register([TEXT,VOICE,PICTURE], isGroupChat=True) | |||
@itchat.msg_register([TEXT, VOICE, PICTURE], isGroupChat=True) | |||
def handler_group_msg(msg): | |||
if msg['Type'] == PICTURE and msg['MsgType'] == 47: | |||
if msg["Type"] == PICTURE and msg["MsgType"] == 47: | |||
return None | |||
WechatChannel().handle_group(WeChatMessage(msg,True)) | |||
WechatChannel().handle_group(WeChatMessage(msg, True)) | |||
return None | |||
def _check(func): | |||
def wrapper(self, cmsg: ChatMessage): | |||
msgId = cmsg.msg_id | |||
@@ -45,21 +50,27 @@ def _check(func): | |||
logger.info("Wechat message {} already received, ignore".format(msgId)) | |||
return | |||
self.receivedMsgs[msgId] = cmsg | |||
create_time = cmsg.create_time # 消息时间戳 | |||
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息 | |||
create_time = cmsg.create_time # 消息时间戳 | |||
if ( | |||
conf().get("hot_reload") == True | |||
and int(create_time) < int(time.time()) - 60 | |||
): # 跳过1分钟前的历史消息 | |||
logger.debug("[WX]history message {} skipped".format(msgId)) | |||
return | |||
return func(self, cmsg) | |||
return wrapper | |||
#可用的二维码生成接口 | |||
#https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com | |||
#https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com | |||
def qrCallback(uuid,status,qrcode): | |||
# 可用的二维码生成接口 | |||
# https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com | |||
# https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com | |||
def qrCallback(uuid, status, qrcode): | |||
# logger.debug("qrCallback: {} {}".format(uuid,status)) | |||
if status == '0': | |||
if status == "0": | |||
try: | |||
from PIL import Image | |||
img = Image.open(io.BytesIO(qrcode)) | |||
_thread = threading.Thread(target=img.show, args=("QRCode",)) | |||
_thread.setDaemon(True) | |||
@@ -68,35 +79,43 @@ def qrCallback(uuid,status,qrcode): | |||
pass | |||
import qrcode | |||
url = f"https://login.weixin.qq.com/l/{uuid}" | |||
qr_api1="https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url) | |||
qr_api2="https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url) | |||
qr_api3="https://api.pwmqr.com/qrcode/create/?url={}".format(url) | |||
qr_api4="https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url) | |||
qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url) | |||
qr_api2 = ( | |||
"https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format( | |||
url | |||
) | |||
) | |||
qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url) | |||
qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format( | |||
url | |||
) | |||
print("You can also scan QRCode in any website below:") | |||
print(qr_api3) | |||
print(qr_api4) | |||
print(qr_api2) | |||
print(qr_api1) | |||
qr = qrcode.QRCode(border=1) | |||
qr.add_data(url) | |||
qr.make(fit=True) | |||
qr.print_ascii(invert=True) | |||
@singleton | |||
class WechatChannel(ChatChannel): | |||
NOT_SUPPORT_REPLYTYPE = [] | |||
def __init__(self): | |||
super().__init__() | |||
self.receivedMsgs = ExpiredDict(60*60*24) | |||
self.receivedMsgs = ExpiredDict(60 * 60 * 24) | |||
def startup(self): | |||
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间 | |||
# login by scan QRCode | |||
hotReload = conf().get('hot_reload', False) | |||
hotReload = conf().get("hot_reload", False) | |||
try: | |||
itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback) | |||
except Exception as e: | |||
@@ -104,12 +123,18 @@ class WechatChannel(ChatChannel): | |||
logger.error("Hot reload failed, try to login without hot reload") | |||
itchat.logout() | |||
os.remove("itchat.pkl") | |||
itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback) | |||
itchat.auto_login( | |||
enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback | |||
) | |||
else: | |||
raise e | |||
self.user_id = itchat.instance.storageClass.userName | |||
self.name = itchat.instance.storageClass.nickName | |||
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name)) | |||
logger.info( | |||
"Wechat login success, user_id: {}, nickname: {}".format( | |||
self.user_id, self.name | |||
) | |||
) | |||
# start message listener | |||
itchat.run() | |||
@@ -127,24 +152,30 @@ class WechatChannel(ChatChannel): | |||
@time_checker | |||
@_check | |||
def handle_single(self, cmsg : ChatMessage): | |||
def handle_single(self, cmsg: ChatMessage): | |||
if cmsg.ctype == ContextType.VOICE: | |||
if conf().get('speech_recognition') != True: | |||
if conf().get("speech_recognition") != True: | |||
return | |||
logger.debug("[WX]receive voice msg: {}".format(cmsg.content)) | |||
elif cmsg.ctype == ContextType.IMAGE: | |||
logger.debug("[WX]receive image msg: {}".format(cmsg.content)) | |||
else: | |||
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg)) | |||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg) | |||
logger.debug( | |||
"[WX]receive text msg: {}, cmsg={}".format( | |||
json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg | |||
) | |||
) | |||
context = self._compose_context( | |||
cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg | |||
) | |||
if context: | |||
self.produce(context) | |||
@time_checker | |||
@_check | |||
def handle_group(self, cmsg : ChatMessage): | |||
def handle_group(self, cmsg: ChatMessage): | |||
if cmsg.ctype == ContextType.VOICE: | |||
if conf().get('speech_recognition') != True: | |||
if conf().get("speech_recognition") != True: | |||
return | |||
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content)) | |||
elif cmsg.ctype == ContextType.IMAGE: | |||
@@ -152,23 +183,25 @@ class WechatChannel(ChatChannel): | |||
else: | |||
# logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg)) | |||
pass | |||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg) | |||
context = self._compose_context( | |||
cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg | |||
) | |||
if context: | |||
self.produce(context) | |||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 | |||
def send(self, reply: Reply, context: Context): | |||
receiver = context["receiver"] | |||
if reply.type == ReplyType.TEXT: | |||
itchat.send(reply.content, toUserName=receiver) | |||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) | |||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver)) | |||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: | |||
itchat.send(reply.content, toUserName=receiver) | |||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) | |||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver)) | |||
elif reply.type == ReplyType.VOICE: | |||
itchat.send_file(reply.content, toUserName=receiver) | |||
logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver)) | |||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 | |||
logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver)) | |||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 | |||
img_url = reply.content | |||
pic_res = requests.get(img_url, stream=True) | |||
image_storage = io.BytesIO() | |||
@@ -176,9 +209,9 @@ class WechatChannel(ChatChannel): | |||
image_storage.write(block) | |||
image_storage.seek(0) | |||
itchat.send_image(image_storage, toUserName=receiver) | |||
logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver)) | |||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片 | |||
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver)) | |||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片 | |||
image_storage = reply.content | |||
image_storage.seek(0) | |||
itchat.send_image(image_storage, toUserName=receiver) | |||
logger.info('[WX] sendImage, receiver={}'.format(receiver)) | |||
logger.info("[WX] sendImage, receiver={}".format(receiver)) |
@@ -1,54 +1,54 @@ | |||
from bridge.context import ContextType | |||
from channel.chat_message import ChatMessage | |||
from common.tmp_dir import TmpDir | |||
from common.log import logger | |||
from lib.itchat.content import * | |||
from common.tmp_dir import TmpDir | |||
from lib import itchat | |||
from lib.itchat.content import * | |||
class WeChatMessage(ChatMessage): | |||
class WeChatMessage(ChatMessage): | |||
def __init__(self, itchat_msg, is_group=False): | |||
super().__init__( itchat_msg) | |||
self.msg_id = itchat_msg['MsgId'] | |||
self.create_time = itchat_msg['CreateTime'] | |||
super().__init__(itchat_msg) | |||
self.msg_id = itchat_msg["MsgId"] | |||
self.create_time = itchat_msg["CreateTime"] | |||
self.is_group = is_group | |||
if itchat_msg['Type'] == TEXT: | |||
if itchat_msg["Type"] == TEXT: | |||
self.ctype = ContextType.TEXT | |||
self.content = itchat_msg['Text'] | |||
elif itchat_msg['Type'] == VOICE: | |||
self.content = itchat_msg["Text"] | |||
elif itchat_msg["Type"] == VOICE: | |||
self.ctype = ContextType.VOICE | |||
self.content = TmpDir().path() + itchat_msg['FileName'] # content直接存临时目录路径 | |||
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径 | |||
self._prepare_fn = lambda: itchat_msg.download(self.content) | |||
elif itchat_msg['Type'] == PICTURE and itchat_msg['MsgType'] == 3: | |||
elif itchat_msg["Type"] == PICTURE and itchat_msg["MsgType"] == 3: | |||
self.ctype = ContextType.IMAGE | |||
self.content = TmpDir().path() + itchat_msg['FileName'] # content直接存临时目录路径 | |||
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径 | |||
self._prepare_fn = lambda: itchat_msg.download(self.content) | |||
else: | |||
raise NotImplementedError("Unsupported message type: {}".format(itchat_msg['Type'])) | |||
self.from_user_id = itchat_msg['FromUserName'] | |||
self.to_user_id = itchat_msg['ToUserName'] | |||
raise NotImplementedError( | |||
"Unsupported message type: {}".format(itchat_msg["Type"]) | |||
) | |||
self.from_user_id = itchat_msg["FromUserName"] | |||
self.to_user_id = itchat_msg["ToUserName"] | |||
user_id = itchat.instance.storageClass.userName | |||
nickname = itchat.instance.storageClass.nickName | |||
# 虽然from_user_id和to_user_id用的少,但是为了保持一致性,还是要填充一下 | |||
# 以下很繁琐,一句话总结:能填的都填了。 | |||
if self.from_user_id == user_id: | |||
self.from_user_nickname = nickname | |||
if self.to_user_id == user_id: | |||
self.to_user_nickname = nickname | |||
try: # 陌生人时候, 'User'字段可能不存在 | |||
self.other_user_id = itchat_msg['User']['UserName'] | |||
self.other_user_nickname = itchat_msg['User']['NickName'] | |||
try: # 陌生人时候, 'User'字段可能不存在 | |||
self.other_user_id = itchat_msg["User"]["UserName"] | |||
self.other_user_nickname = itchat_msg["User"]["NickName"] | |||
if self.other_user_id == self.from_user_id: | |||
self.from_user_nickname = self.other_user_nickname | |||
if self.other_user_id == self.to_user_id: | |||
self.to_user_nickname = self.other_user_nickname | |||
except KeyError as e: # 处理偶尔没有对方信息的情况 | |||
except KeyError as e: # 处理偶尔没有对方信息的情况 | |||
logger.warn("[WX]get other_user_id failed: " + str(e)) | |||
if self.from_user_id == user_id: | |||
self.other_user_id = self.to_user_id | |||
@@ -56,6 +56,6 @@ class WeChatMessage(ChatMessage): | |||
self.other_user_id = self.from_user_id | |||
if self.is_group: | |||
self.is_at = itchat_msg['IsAt'] | |||
self.actual_user_id = itchat_msg['ActualUserName'] | |||
self.actual_user_nickname = itchat_msg['ActualNickName'] | |||
self.is_at = itchat_msg["IsAt"] | |||
self.actual_user_id = itchat_msg["ActualUserName"] | |||
self.actual_user_nickname = itchat_msg["ActualNickName"] |
@@ -4,104 +4,118 @@ | |||
wechaty channel | |||
Python Wechaty - https://github.com/wechaty/python-wechaty | |||
""" | |||
import asyncio | |||
import base64 | |||
import os | |||
import time | |||
import asyncio | |||
from bridge.context import Context | |||
from wechaty_puppet import FileBox | |||
from wechaty import Wechaty, Contact | |||
from wechaty import Contact, Wechaty | |||
from wechaty.user import Message | |||
from bridge.reply import * | |||
from wechaty_puppet import FileBox | |||
from bridge.context import * | |||
from bridge.context import Context | |||
from bridge.reply import * | |||
from channel.chat_channel import ChatChannel | |||
from channel.wechat.wechaty_message import WechatyMessage | |||
from common.log import logger | |||
from common.singleton import singleton | |||
from config import conf | |||
try: | |||
from voice.audio_convert import any_to_sil | |||
except Exception as e: | |||
pass | |||
@singleton | |||
class WechatyChannel(ChatChannel): | |||
NOT_SUPPORT_REPLYTYPE = [] | |||
def __init__(self): | |||
super().__init__() | |||
def startup(self): | |||
config = conf() | |||
token = config.get('wechaty_puppet_service_token') | |||
os.environ['WECHATY_PUPPET_SERVICE_TOKEN'] = token | |||
token = config.get("wechaty_puppet_service_token") | |||
os.environ["WECHATY_PUPPET_SERVICE_TOKEN"] = token | |||
asyncio.run(self.main()) | |||
async def main(self): | |||
loop = asyncio.get_event_loop() | |||
#将asyncio的loop传入处理线程 | |||
self.handler_pool._initializer= lambda: asyncio.set_event_loop(loop) | |||
# 将asyncio的loop传入处理线程 | |||
self.handler_pool._initializer = lambda: asyncio.set_event_loop(loop) | |||
self.bot = Wechaty() | |||
self.bot.on('login', self.on_login) | |||
self.bot.on('message', self.on_message) | |||
self.bot.on("login", self.on_login) | |||
self.bot.on("message", self.on_message) | |||
await self.bot.start() | |||
async def on_login(self, contact: Contact): | |||
self.user_id = contact.contact_id | |||
self.name = contact.name | |||
logger.info('[WX] login user={}'.format(contact)) | |||
logger.info("[WX] login user={}".format(contact)) | |||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息 | |||
def send(self, reply: Reply, context: Context): | |||
receiver_id = context['receiver'] | |||
receiver_id = context["receiver"] | |||
loop = asyncio.get_event_loop() | |||
if context['isgroup']: | |||
receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id),loop).result() | |||
if context["isgroup"]: | |||
receiver = asyncio.run_coroutine_threadsafe( | |||
self.bot.Room.find(receiver_id), loop | |||
).result() | |||
else: | |||
receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id),loop).result() | |||
receiver = asyncio.run_coroutine_threadsafe( | |||
self.bot.Contact.find(receiver_id), loop | |||
).result() | |||
msg = None | |||
if reply.type == ReplyType.TEXT: | |||
msg = reply.content | |||
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result() | |||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) | |||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result() | |||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver)) | |||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO: | |||
msg = reply.content | |||
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result() | |||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver)) | |||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result() | |||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver)) | |||
elif reply.type == ReplyType.VOICE: | |||
voiceLength = None | |||
file_path = reply.content | |||
sil_file = os.path.splitext(file_path)[0] + '.sil' | |||
sil_file = os.path.splitext(file_path)[0] + ".sil" | |||
voiceLength = int(any_to_sil(file_path, sil_file)) | |||
if voiceLength >= 60000: | |||
voiceLength = 60000 | |||
logger.info('[WX] voice too long, length={}, set to 60s'.format(voiceLength)) | |||
logger.info( | |||
"[WX] voice too long, length={}, set to 60s".format(voiceLength) | |||
) | |||
# 发送语音 | |||
t = int(time.time()) | |||
msg = FileBox.from_file(sil_file, name=str(t) + '.sil') | |||
msg = FileBox.from_file(sil_file, name=str(t) + ".sil") | |||
if voiceLength is not None: | |||
msg.metadata['voiceLength'] = voiceLength | |||
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result() | |||
msg.metadata["voiceLength"] = voiceLength | |||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result() | |||
try: | |||
os.remove(file_path) | |||
if sil_file != file_path: | |||
os.remove(sil_file) | |||
except Exception as e: | |||
pass | |||
logger.info('[WX] sendVoice={}, receiver={}'.format(reply.content, receiver)) | |||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 | |||
logger.info( | |||
"[WX] sendVoice={}, receiver={}".format(reply.content, receiver) | |||
) | |||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片 | |||
img_url = reply.content | |||
t = int(time.time()) | |||
msg = FileBox.from_url(url=img_url, name=str(t) + '.png') | |||
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result() | |||
logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver)) | |||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片 | |||
msg = FileBox.from_url(url=img_url, name=str(t) + ".png") | |||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result() | |||
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver)) | |||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片 | |||
image_storage = reply.content | |||
image_storage.seek(0) | |||
t = int(time.time()) | |||
msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + '.png') | |||
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result() | |||
logger.info('[WX] sendImage, receiver={}'.format(receiver)) | |||
msg = FileBox.from_base64( | |||
base64.b64encode(image_storage.read()), str(t) + ".png" | |||
) | |||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result() | |||
logger.info("[WX] sendImage, receiver={}".format(receiver)) | |||
async def on_message(self, msg: Message): | |||
""" | |||
@@ -110,16 +124,16 @@ class WechatyChannel(ChatChannel): | |||
try: | |||
cmsg = await WechatyMessage(msg) | |||
except NotImplementedError as e: | |||
logger.debug('[WX] {}'.format(e)) | |||
logger.debug("[WX] {}".format(e)) | |||
return | |||
except Exception as e: | |||
logger.exception('[WX] {}'.format(e)) | |||
logger.exception("[WX] {}".format(e)) | |||
return | |||
logger.debug('[WX] message:{}'.format(cmsg)) | |||
logger.debug("[WX] message:{}".format(cmsg)) | |||
room = msg.room() # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None | |||
isgroup = room is not None | |||
ctype = cmsg.ctype | |||
context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg) | |||
if context: | |||
logger.info('[WX] receiveMsg={}, context={}'.format(cmsg, context)) | |||
self.produce(context) | |||
logger.info("[WX] receiveMsg={}, context={}".format(cmsg, context)) | |||
self.produce(context) |
@@ -1,17 +1,21 @@ | |||
import asyncio | |||
import re | |||
from wechaty import MessageType | |||
from wechaty.user import Message | |||
from bridge.context import ContextType | |||
from channel.chat_message import ChatMessage | |||
from common.tmp_dir import TmpDir | |||
from common.log import logger | |||
from wechaty.user import Message | |||
from common.tmp_dir import TmpDir | |||
class aobject(object): | |||
"""Inheriting this class allows you to define an async __init__. | |||
So you can create objects by doing something like `await MyClass(params)` | |||
""" | |||
async def __new__(cls, *a, **kw): | |||
instance = super().__new__(cls) | |||
await instance.__init__(*a, **kw) | |||
@@ -19,17 +23,18 @@ class aobject(object): | |||
async def __init__(self): | |||
pass | |||
class WechatyMessage(ChatMessage, aobject): | |||
class WechatyMessage(ChatMessage, aobject): | |||
async def __init__(self, wechaty_msg: Message): | |||
super().__init__(wechaty_msg) | |||
room = wechaty_msg.room() | |||
self.msg_id = wechaty_msg.message_id | |||
self.create_time = wechaty_msg.payload.timestamp | |||
self.is_group = room is not None | |||
if wechaty_msg.type() == MessageType.MESSAGE_TYPE_TEXT: | |||
self.ctype = ContextType.TEXT | |||
self.content = wechaty_msg.text() | |||
@@ -40,12 +45,17 @@ class WechatyMessage(ChatMessage, aobject): | |||
def func(): | |||
loop = asyncio.get_event_loop() | |||
asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content),loop).result() | |||
asyncio.run_coroutine_threadsafe( | |||
voice_file.to_file(self.content), loop | |||
).result() | |||
self._prepare_fn = func | |||
else: | |||
raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type())) | |||
raise NotImplementedError( | |||
"Unsupported message type: {}".format(wechaty_msg.type()) | |||
) | |||
from_contact = wechaty_msg.talker() # 获取消息的发送者 | |||
self.from_user_id = from_contact.contact_id | |||
self.from_user_nickname = from_contact.name | |||
@@ -54,7 +64,7 @@ class WechatyMessage(ChatMessage, aobject): | |||
# wecahty: from是消息实际发送者, to:所在群 | |||
# itchat: 如果是你发送群消息,from和to是你自己和所在群,如果是别人发群消息,from和to是所在群和你自己 | |||
# 但这个差别不影响逻辑,group中只使用到:1.用from来判断是否是自己发的,2.actual_user_id来判断实际发送用户 | |||
if self.is_group: | |||
self.to_user_id = room.room_id | |||
self.to_user_nickname = await room.topic() | |||
@@ -63,22 +73,22 @@ class WechatyMessage(ChatMessage, aobject): | |||
self.to_user_id = to_contact.contact_id | |||
self.to_user_nickname = to_contact.name | |||
if self.is_group or wechaty_msg.is_self(): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。 | |||
if ( | |||
self.is_group or wechaty_msg.is_self() | |||
): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。 | |||
self.other_user_id = self.to_user_id | |||
self.other_user_nickname = self.to_user_nickname | |||
else: | |||
self.other_user_id = self.from_user_id | |||
self.other_user_nickname = self.from_user_nickname | |||
if self.is_group: # wechaty群聊中,实际发送用户就是from_user | |||
if self.is_group: # wechaty群聊中,实际发送用户就是from_user | |||
self.is_at = await wechaty_msg.mention_self() | |||
if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容 | |||
if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容 | |||
name = wechaty_msg.wechaty.user_self().name | |||
pattern = f'@{name}(\u2005|\u0020)' | |||
if re.search(pattern,self.content): | |||
logger.debug(f'wechaty message {self.msg_id} include at') | |||
pattern = f"@{name}(\u2005|\u0020)" | |||
if re.search(pattern, self.content): | |||
logger.debug(f"wechaty message {self.msg_id} include at") | |||
self.is_at = True | |||
self.actual_user_id = self.from_user_id | |||
@@ -21,12 +21,12 @@ pip3 install web.py | |||
相关的服务器验证代码已经写好,你不需要再添加任何代码。你只需要在本项目根目录的`config.json`中添加 | |||
``` | |||
"channel_type": "wechatmp", | |||
"channel_type": "wechatmp", | |||
"wechatmp_token": "Token", # 微信公众平台的Token | |||
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443 | |||
"wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要 | |||
"wechatmp_app_secret": "", # 微信公众平台的appsecret,仅服务号需要 | |||
``` | |||
``` | |||
然后运行`python3 app.py`启动web服务器。这里会默认监听8080端口,但是微信公众号的服务器配置只支持80/443端口,有两种方法来解决这个问题。第一个是推荐的方法,使用端口转发命令将80端口转发到8080端口(443同理,注意需要支持SSL,也就是https的访问,在`wechatmp_channel.py`需要修改相应的证书路径): | |||
``` | |||
sudo iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to-port 8080 | |||
@@ -1,46 +1,66 @@ | |||
import web | |||
import time | |||
import channel.wechatmp.reply as reply | |||
import web | |||
import channel.wechatmp.receive as receive | |||
from config import conf | |||
from common.log import logger | |||
import channel.wechatmp.reply as reply | |||
from bridge.context import * | |||
from channel.wechatmp.common import * | |||
from channel.wechatmp.common import * | |||
from channel.wechatmp.wechatmp_channel import WechatMPChannel | |||
from common.log import logger | |||
from config import conf | |||
# This class is instantiated once per query | |||
class Query(): | |||
# This class is instantiated once per query | |||
class Query: | |||
def GET(self): | |||
return verify_server(web.input()) | |||
def POST(self): | |||
# Make sure to return the instance that first created, @singleton will do that. | |||
# Make sure to return the instance that first created, @singleton will do that. | |||
channel = WechatMPChannel() | |||
try: | |||
webData = web.data() | |||
# logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8")) | |||
wechatmp_msg = receive.parse_xml(webData) | |||
if wechatmp_msg.msg_type == 'text' or wechatmp_msg.msg_type == 'voice': | |||
if wechatmp_msg.msg_type == "text" or wechatmp_msg.msg_type == "voice": | |||
from_user = wechatmp_msg.from_user_id | |||
message = wechatmp_msg.content.decode("utf-8") | |||
message_id = wechatmp_msg.msg_id | |||
logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message)) | |||
context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg) | |||
logger.info( | |||
"[wechatmp] {}:{} Receive post query {} {}: {}".format( | |||
web.ctx.env.get("REMOTE_ADDR"), | |||
web.ctx.env.get("REMOTE_PORT"), | |||
from_user, | |||
message_id, | |||
message, | |||
) | |||
) | |||
context = channel._compose_context( | |||
ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg | |||
) | |||
if context: | |||
# set private openai_api_key | |||
# if from_user is not changed in itchat, this can be placed at chat_channel | |||
user_data = conf().get_user_data(from_user) | |||
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key | |||
context["openai_api_key"] = user_data.get( | |||
"openai_api_key" | |||
) # None or user openai_api_key | |||
channel.produce(context) | |||
# The reply will be sent by channel.send() in another thread | |||
return "success" | |||
elif wechatmp_msg.msg_type == 'event': | |||
logger.info("[wechatmp] Event {} from {}".format(wechatmp_msg.Event, wechatmp_msg.from_user_id)) | |||
elif wechatmp_msg.msg_type == "event": | |||
logger.info( | |||
"[wechatmp] Event {} from {}".format( | |||
wechatmp_msg.Event, wechatmp_msg.from_user_id | |||
) | |||
) | |||
content = subscribe_msg() | |||
replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content) | |||
replyMsg = reply.TextMsg( | |||
wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content | |||
) | |||
return replyMsg.send() | |||
else: | |||
logger.info("暂且不处理") | |||
@@ -48,4 +68,3 @@ class Query(): | |||
except Exception as exc: | |||
logger.exception(exc) | |||
return exc | |||
@@ -1,81 +1,117 @@ | |||
import web | |||
import time | |||
import channel.wechatmp.reply as reply | |||
import web | |||
import channel.wechatmp.receive as receive | |||
from config import conf | |||
from common.log import logger | |||
import channel.wechatmp.reply as reply | |||
from bridge.context import * | |||
from channel.wechatmp.common import * | |||
from channel.wechatmp.common import * | |||
from channel.wechatmp.wechatmp_channel import WechatMPChannel | |||
from common.log import logger | |||
from config import conf | |||
# This class is instantiated once per query | |||
class Query(): | |||
# This class is instantiated once per query | |||
class Query: | |||
def GET(self): | |||
return verify_server(web.input()) | |||
def POST(self): | |||
# Make sure to return the instance that first created, @singleton will do that. | |||
# Make sure to return the instance that first created, @singleton will do that. | |||
channel = WechatMPChannel() | |||
try: | |||
query_time = time.time() | |||
webData = web.data() | |||
logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8")) | |||
wechatmp_msg = receive.parse_xml(webData) | |||
if wechatmp_msg.msg_type == 'text' or wechatmp_msg.msg_type == 'voice': | |||
if wechatmp_msg.msg_type == "text" or wechatmp_msg.msg_type == "voice": | |||
from_user = wechatmp_msg.from_user_id | |||
to_user = wechatmp_msg.to_user_id | |||
message = wechatmp_msg.content.decode("utf-8") | |||
message_id = wechatmp_msg.msg_id | |||
logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message)) | |||
logger.info( | |||
"[wechatmp] {}:{} Receive post query {} {}: {}".format( | |||
web.ctx.env.get("REMOTE_ADDR"), | |||
web.ctx.env.get("REMOTE_PORT"), | |||
from_user, | |||
message_id, | |||
message, | |||
) | |||
) | |||
supported = True | |||
if "【收到不支持的消息类型,暂无法显示】" in message: | |||
supported = False # not supported, used to refresh | |||
supported = False # not supported, used to refresh | |||
cache_key = from_user | |||
reply_text = "" | |||
# New request | |||
if cache_key not in channel.cache_dict and cache_key not in channel.running: | |||
if ( | |||
cache_key not in channel.cache_dict | |||
and cache_key not in channel.running | |||
): | |||
# The first query begin, reset the cache | |||
context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg) | |||
logger.debug("[wechatmp] context: {} {}".format(context, wechatmp_msg)) | |||
if message_id in channel.received_msgs: # received and finished | |||
context = channel._compose_context( | |||
ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg | |||
) | |||
logger.debug( | |||
"[wechatmp] context: {} {}".format(context, wechatmp_msg) | |||
) | |||
if message_id in channel.received_msgs: # received and finished | |||
# no return because of bandwords or other reasons | |||
return "success" | |||
if supported and context: | |||
# set private openai_api_key | |||
# if from_user is not changed in itchat, this can be placed at chat_channel | |||
user_data = conf().get_user_data(from_user) | |||
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key | |||
context["openai_api_key"] = user_data.get( | |||
"openai_api_key" | |||
) # None or user openai_api_key | |||
channel.received_msgs[message_id] = wechatmp_msg | |||
channel.running.add(cache_key) | |||
channel.produce(context) | |||
else: | |||
trigger_prefix = conf().get('single_chat_prefix',[''])[0] | |||
trigger_prefix = conf().get("single_chat_prefix", [""])[0] | |||
if trigger_prefix or not supported: | |||
if trigger_prefix: | |||
content = textwrap.dedent(f"""\ | |||
content = textwrap.dedent( | |||
f"""\ | |||
请输入'{trigger_prefix}'接你想说的话跟我说话。 | |||
例如: | |||
{trigger_prefix}你好,很高兴见到你。""") | |||
{trigger_prefix}你好,很高兴见到你。""" | |||
) | |||
else: | |||
content = textwrap.dedent("""\ | |||
content = textwrap.dedent( | |||
"""\ | |||
你好,很高兴见到你。 | |||
请跟我说话吧。""") | |||
请跟我说话吧。""" | |||
) | |||
else: | |||
logger.error(f"[wechatmp] unknown error") | |||
content = textwrap.dedent("""\ | |||
未知错误,请稍后再试""") | |||
replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content) | |||
content = textwrap.dedent( | |||
"""\ | |||
未知错误,请稍后再试""" | |||
) | |||
replyMsg = reply.TextMsg( | |||
wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content | |||
) | |||
return replyMsg.send() | |||
channel.query1[cache_key] = False | |||
channel.query2[cache_key] = False | |||
channel.query3[cache_key] = False | |||
# User request again, and the answer is not ready | |||
elif cache_key in channel.running and channel.query1.get(cache_key) == True and channel.query2.get(cache_key) == True and channel.query3.get(cache_key) == True: | |||
channel.query1[cache_key] = False #To improve waiting experience, this can be set to True. | |||
channel.query2[cache_key] = False #To improve waiting experience, this can be set to True. | |||
elif ( | |||
cache_key in channel.running | |||
and channel.query1.get(cache_key) == True | |||
and channel.query2.get(cache_key) == True | |||
and channel.query3.get(cache_key) == True | |||
): | |||
channel.query1[ | |||
cache_key | |||
] = False # To improve waiting experience, this can be set to True. | |||
channel.query2[ | |||
cache_key | |||
] = False # To improve waiting experience, this can be set to True. | |||
channel.query3[cache_key] = False | |||
# User request again, and the answer is ready | |||
elif cache_key in channel.cache_dict: | |||
@@ -84,7 +120,9 @@ class Query(): | |||
channel.query2[cache_key] = True | |||
channel.query3[cache_key] = True | |||
assert not (cache_key in channel.cache_dict and cache_key in channel.running) | |||
assert not ( | |||
cache_key in channel.cache_dict and cache_key in channel.running | |||
) | |||
if channel.query1.get(cache_key) == False: | |||
# The first query from wechat official server | |||
@@ -128,14 +166,20 @@ class Query(): | |||
# Have waiting for 3x5 seconds | |||
# return timeout message | |||
reply_text = "【正在思考中,回复任意文字尝试获取回复】" | |||
logger.info("[wechatmp] Three queries has finished For {}: {}".format(from_user, message_id)) | |||
logger.info( | |||
"[wechatmp] Three queries has finished For {}: {}".format( | |||
from_user, message_id | |||
) | |||
) | |||
replyPost = reply.TextMsg(from_user, to_user, reply_text).send() | |||
return replyPost | |||
else: | |||
pass | |||
if cache_key not in channel.cache_dict and cache_key not in channel.running: | |||
if ( | |||
cache_key not in channel.cache_dict | |||
and cache_key not in channel.running | |||
): | |||
# no return because of bandwords or other reasons | |||
return "success" | |||
@@ -147,26 +191,42 @@ class Query(): | |||
if cache_key in channel.cache_dict: | |||
content = channel.cache_dict[cache_key] | |||
if len(content.encode('utf8'))<=MAX_UTF8_LEN: | |||
if len(content.encode("utf8")) <= MAX_UTF8_LEN: | |||
reply_text = channel.cache_dict[cache_key] | |||
channel.cache_dict.pop(cache_key) | |||
else: | |||
continue_text = "\n【未完待续,回复任意文字以继续】" | |||
splits = split_string_by_utf8_length(content, MAX_UTF8_LEN - len(continue_text.encode('utf-8')), max_split= 1) | |||
splits = split_string_by_utf8_length( | |||
content, | |||
MAX_UTF8_LEN - len(continue_text.encode("utf-8")), | |||
max_split=1, | |||
) | |||
reply_text = splits[0] + continue_text | |||
channel.cache_dict[cache_key] = splits[1] | |||
logger.info("[wechatmp] {}:{} Do send {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), reply_text)) | |||
logger.info( | |||
"[wechatmp] {}:{} Do send {}".format( | |||
web.ctx.env.get("REMOTE_ADDR"), | |||
web.ctx.env.get("REMOTE_PORT"), | |||
reply_text, | |||
) | |||
) | |||
replyPost = reply.TextMsg(from_user, to_user, reply_text).send() | |||
return replyPost | |||
elif wechatmp_msg.msg_type == 'event': | |||
logger.info("[wechatmp] Event {} from {}".format(wechatmp_msg.content, wechatmp_msg.from_user_id)) | |||
elif wechatmp_msg.msg_type == "event": | |||
logger.info( | |||
"[wechatmp] Event {} from {}".format( | |||
wechatmp_msg.content, wechatmp_msg.from_user_id | |||
) | |||
) | |||
content = subscribe_msg() | |||
replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content) | |||
replyMsg = reply.TextMsg( | |||
wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content | |||
) | |||
return replyMsg.send() | |||
else: | |||
logger.info("暂且不处理") | |||
return "success" | |||
except Exception as exc: | |||
logger.exception(exc) | |||
return exc | |||
return exc |
@@ -1,9 +1,11 @@ | |||
from config import conf | |||
import hashlib | |||
import textwrap | |||
from config import conf | |||
MAX_UTF8_LEN = 2048 | |||
class WeChatAPIException(Exception): | |||
pass | |||
@@ -16,13 +18,13 @@ def verify_server(data): | |||
timestamp = data.timestamp | |||
nonce = data.nonce | |||
echostr = data.echostr | |||
token = conf().get('wechatmp_token') #请按照公众平台官网\基本配置中信息填写 | |||
token = conf().get("wechatmp_token") # 请按照公众平台官网\基本配置中信息填写 | |||
data_list = [token, timestamp, nonce] | |||
data_list.sort() | |||
sha1 = hashlib.sha1() | |||
# map(sha1.update, data_list) #python2 | |||
sha1.update("".join(data_list).encode('utf-8')) | |||
sha1.update("".join(data_list).encode("utf-8")) | |||
hashcode = sha1.hexdigest() | |||
print("handle/GET func: hashcode, signature: ", hashcode, signature) | |||
if hashcode == signature: | |||
@@ -32,9 +34,11 @@ def verify_server(data): | |||
except Exception as Argument: | |||
return Argument | |||
def subscribe_msg(): | |||
trigger_prefix = conf().get('single_chat_prefix',[''])[0] | |||
msg = textwrap.dedent(f"""\ | |||
trigger_prefix = conf().get("single_chat_prefix", [""])[0] | |||
msg = textwrap.dedent( | |||
f"""\ | |||
感谢您的关注! | |||
这里是ChatGPT,可以自由对话。 | |||
资源有限,回复较慢,请勿着急。 | |||
@@ -42,22 +46,23 @@ def subscribe_msg(): | |||
暂时不支持图片输入。 | |||
支持图片输出,画字开头的问题将回复图片链接。 | |||
支持角色扮演和文字冒险两种定制模式对话。 | |||
输入'{trigger_prefix}#帮助' 查看详细指令。""") | |||
输入'{trigger_prefix}#帮助' 查看详细指令。""" | |||
) | |||
return msg | |||
def split_string_by_utf8_length(string, max_length, max_split=0): | |||
encoded = string.encode('utf-8') | |||
encoded = string.encode("utf-8") | |||
start, end = 0, 0 | |||
result = [] | |||
while end < len(encoded): | |||
if max_split > 0 and len(result) >= max_split: | |||
result.append(encoded[start:].decode('utf-8')) | |||
result.append(encoded[start:].decode("utf-8")) | |||
break | |||
end = start + max_length | |||
# 如果当前字节不是 UTF-8 编码的开始字节,则向前查找直到找到开始字节为止 | |||
while end < len(encoded) and (encoded[end] & 0b11000000) == 0b10000000: | |||
end -= 1 | |||
result.append(encoded[start:end].decode('utf-8')) | |||
result.append(encoded[start:end].decode("utf-8")) | |||
start = end | |||
return result | |||
return result |
@@ -1,6 +1,7 @@ | |||
# -*- coding: utf-8 -*-# | |||
# filename: receive.py | |||
import xml.etree.ElementTree as ET | |||
from bridge.context import ContextType | |||
from channel.chat_message import ChatMessage | |||
from common.log import logger | |||
@@ -12,34 +13,35 @@ def parse_xml(web_data): | |||
xmlData = ET.fromstring(web_data) | |||
return WeChatMPMessage(xmlData) | |||
class WeChatMPMessage(ChatMessage): | |||
def __init__(self, xmlData): | |||
super().__init__(xmlData) | |||
self.to_user_id = xmlData.find('ToUserName').text | |||
self.from_user_id = xmlData.find('FromUserName').text | |||
self.create_time = xmlData.find('CreateTime').text | |||
self.msg_type = xmlData.find('MsgType').text | |||
self.to_user_id = xmlData.find("ToUserName").text | |||
self.from_user_id = xmlData.find("FromUserName").text | |||
self.create_time = xmlData.find("CreateTime").text | |||
self.msg_type = xmlData.find("MsgType").text | |||
try: | |||
self.msg_id = xmlData.find('MsgId').text | |||
self.msg_id = xmlData.find("MsgId").text | |||
except: | |||
self.msg_id = self.from_user_id+self.create_time | |||
self.msg_id = self.from_user_id + self.create_time | |||
self.is_group = False | |||
# reply to other_user_id | |||
self.other_user_id = self.from_user_id | |||
if self.msg_type == 'text': | |||
if self.msg_type == "text": | |||
self.ctype = ContextType.TEXT | |||
self.content = xmlData.find('Content').text.encode("utf-8") | |||
elif self.msg_type == 'voice': | |||
self.content = xmlData.find("Content").text.encode("utf-8") | |||
elif self.msg_type == "voice": | |||
self.ctype = ContextType.TEXT | |||
self.content = xmlData.find('Recognition').text.encode("utf-8") # 接收语音识别结果 | |||
elif self.msg_type == 'image': | |||
self.content = xmlData.find("Recognition").text.encode("utf-8") # 接收语音识别结果 | |||
elif self.msg_type == "image": | |||
# not implemented | |||
self.pic_url = xmlData.find('PicUrl').text | |||
self.media_id = xmlData.find('MediaId').text | |||
elif self.msg_type == 'event': | |||
self.content = xmlData.find('Event').text | |||
else: # video, shortvideo, location, link | |||
self.pic_url = xmlData.find("PicUrl").text | |||
self.media_id = xmlData.find("MediaId").text | |||
elif self.msg_type == "event": | |||
self.content = xmlData.find("Event").text | |||
else: # video, shortvideo, location, link | |||
# not implemented | |||
pass | |||
pass |
@@ -2,6 +2,7 @@ | |||
# filename: reply.py | |||
import time | |||
class Msg(object): | |||
def __init__(self): | |||
pass | |||
@@ -9,13 +10,14 @@ class Msg(object): | |||
def send(self): | |||
return "success" | |||
class TextMsg(Msg): | |||
def __init__(self, toUserName, fromUserName, content): | |||
self.__dict = dict() | |||
self.__dict['ToUserName'] = toUserName | |||
self.__dict['FromUserName'] = fromUserName | |||
self.__dict['CreateTime'] = int(time.time()) | |||
self.__dict['Content'] = content | |||
self.__dict["ToUserName"] = toUserName | |||
self.__dict["FromUserName"] = fromUserName | |||
self.__dict["CreateTime"] = int(time.time()) | |||
self.__dict["Content"] = content | |||
def send(self): | |||
XmlForm = """ | |||
@@ -29,13 +31,14 @@ class TextMsg(Msg): | |||
""" | |||
return XmlForm.format(**self.__dict) | |||
class ImageMsg(Msg): | |||
def __init__(self, toUserName, fromUserName, mediaId): | |||
self.__dict = dict() | |||
self.__dict['ToUserName'] = toUserName | |||
self.__dict['FromUserName'] = fromUserName | |||
self.__dict['CreateTime'] = int(time.time()) | |||
self.__dict['MediaId'] = mediaId | |||
self.__dict["ToUserName"] = toUserName | |||
self.__dict["FromUserName"] = fromUserName | |||
self.__dict["CreateTime"] = int(time.time()) | |||
self.__dict["MediaId"] = mediaId | |||
def send(self): | |||
XmlForm = """ | |||
@@ -49,4 +52,4 @@ class ImageMsg(Msg): | |||
</Image> | |||
</xml> | |||
""" | |||
return XmlForm.format(**self.__dict) | |||
return XmlForm.format(**self.__dict) |
@@ -1,17 +1,19 @@ | |||
# -*- coding: utf-8 -*- | |||
import web | |||
import time | |||
import json | |||
import requests | |||
import threading | |||
from common.singleton import singleton | |||
from common.log import logger | |||
from common.expired_dict import ExpiredDict | |||
from config import conf | |||
from bridge.reply import * | |||
import time | |||
import requests | |||
import web | |||
from bridge.context import * | |||
from bridge.reply import * | |||
from channel.chat_channel import ChatChannel | |||
from channel.wechatmp.common import * | |||
from channel.wechatmp.common import * | |||
from common.expired_dict import ExpiredDict | |||
from common.log import logger | |||
from common.singleton import singleton | |||
from config import conf | |||
# If using SSL, uncomment the following lines, and modify the certificate path. | |||
# from cheroot.server import HTTPServer | |||
@@ -20,13 +22,14 @@ from channel.wechatmp.common import * | |||
# certificate='/ssl/cert.pem', | |||
# private_key='/ssl/cert.key') | |||
@singleton | |||
class WechatMPChannel(ChatChannel): | |||
def __init__(self, passive_reply = True): | |||
def __init__(self, passive_reply=True): | |||
super().__init__() | |||
self.passive_reply = passive_reply | |||
self.running = set() | |||
self.received_msgs = ExpiredDict(60*60*24) | |||
self.received_msgs = ExpiredDict(60 * 60 * 24) | |||
if self.passive_reply: | |||
self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE] | |||
self.cache_dict = dict() | |||
@@ -36,8 +39,8 @@ class WechatMPChannel(ChatChannel): | |||
else: | |||
# TODO support image | |||
self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE] | |||
self.app_id = conf().get('wechatmp_app_id') | |||
self.app_secret = conf().get('wechatmp_app_secret') | |||
self.app_id = conf().get("wechatmp_app_id") | |||
self.app_secret = conf().get("wechatmp_app_secret") | |||
self.access_token = None | |||
self.access_token_expires_time = 0 | |||
self.access_token_lock = threading.Lock() | |||
@@ -45,13 +48,12 @@ class WechatMPChannel(ChatChannel): | |||
def startup(self): | |||
if self.passive_reply: | |||
urls = ('/wx', 'channel.wechatmp.SubscribeAccount.Query') | |||
urls = ("/wx", "channel.wechatmp.SubscribeAccount.Query") | |||
else: | |||
urls = ('/wx', 'channel.wechatmp.ServiceAccount.Query') | |||
urls = ("/wx", "channel.wechatmp.ServiceAccount.Query") | |||
app = web.application(urls, globals(), autoreload=False) | |||
port = conf().get('wechatmp_port', 8080) | |||
web.httpserver.runsimple(app.wsgifunc(), ('0.0.0.0', port)) | |||
port = conf().get("wechatmp_port", 8080) | |||
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port)) | |||
def wechatmp_request(self, method, url, **kwargs): | |||
r = requests.request(method=method, url=url, **kwargs) | |||
@@ -63,7 +65,6 @@ class WechatMPChannel(ChatChannel): | |||
return ret | |||
def get_access_token(self): | |||
# return the access_token | |||
if self.access_token: | |||
if self.access_token_expires_time - time.time() > 60: | |||
@@ -76,15 +77,15 @@ class WechatMPChannel(ChatChannel): | |||
# This happens every 2 hours, so it doesn't affect the experience very much | |||
time.sleep(1) | |||
self.access_token = None | |||
url="https://api.weixin.qq.com/cgi-bin/token" | |||
params={ | |||
url = "https://api.weixin.qq.com/cgi-bin/token" | |||
params = { | |||
"grant_type": "client_credential", | |||
"appid": self.app_id, | |||
"secret": self.app_secret | |||
"secret": self.app_secret, | |||
} | |||
data = self.wechatmp_request(method='get', url=url, params=params) | |||
self.access_token = data['access_token'] | |||
self.access_token_expires_time = int(time.time()) + data['expires_in'] | |||
data = self.wechatmp_request(method="get", url=url, params=params) | |||
self.access_token = data["access_token"] | |||
self.access_token_expires_time = int(time.time()) + data["expires_in"] | |||
logger.info("[wechatmp] access_token: {}".format(self.access_token)) | |||
self.access_token_lock.release() | |||
else: | |||
@@ -101,29 +102,37 @@ class WechatMPChannel(ChatChannel): | |||
else: | |||
receiver = context["receiver"] | |||
reply_text = reply.content | |||
url="https://api.weixin.qq.com/cgi-bin/message/custom/send" | |||
params = { | |||
"access_token": self.get_access_token() | |||
} | |||
url = "https://api.weixin.qq.com/cgi-bin/message/custom/send" | |||
params = {"access_token": self.get_access_token()} | |||
json_data = { | |||
"touser": receiver, | |||
"msgtype": "text", | |||
"text": {"content": reply_text} | |||
"text": {"content": reply_text}, | |||
} | |||
self.wechatmp_request(method='post', url=url, params=params, data=json.dumps(json_data, ensure_ascii=False).encode('utf8')) | |||
self.wechatmp_request( | |||
method="post", | |||
url=url, | |||
params=params, | |||
data=json.dumps(json_data, ensure_ascii=False).encode("utf8"), | |||
) | |||
logger.info("[send] Do send to {}: {}".format(receiver, reply_text)) | |||
return | |||
def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数 | |||
logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context['msg'].msg_id)) | |||
def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数 | |||
logger.debug( | |||
"[wechatmp] Success to generate reply, msgId={}".format( | |||
context["msg"].msg_id | |||
) | |||
) | |||
if self.passive_reply: | |||
self.running.remove(session_id) | |||
def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数 | |||
logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context['msg'].msg_id, exception)) | |||
def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数 | |||
logger.exception( | |||
"[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format( | |||
context["msg"].msg_id, exception | |||
) | |||
) | |||
if self.passive_reply: | |||
assert session_id not in self.cache_dict | |||
self.running.remove(session_id) | |||
@@ -2,4 +2,4 @@ | |||
OPEN_AI = "openAI" | |||
CHATGPT = "chatGPT" | |||
BAIDU = "baidu" | |||
CHATGPTONAZURE = "chatGPTOnAzure" | |||
CHATGPTONAZURE = "chatGPTOnAzure" |
@@ -1,7 +1,7 @@ | |||
from queue import Full, Queue | |||
from time import monotonic as time | |||
# add implementation of putleft to Queue | |||
class Dequeue(Queue): | |||
def putleft(self, item, block=True, timeout=None): | |||
@@ -30,4 +30,4 @@ class Dequeue(Queue): | |||
return self.putleft(item, block=False) | |||
def _putleft(self, item): | |||
self.queue.appendleft(item) | |||
self.queue.appendleft(item) |
@@ -39,4 +39,4 @@ class ExpiredDict(dict): | |||
return [(key, self[key]) for key in self.keys()] | |||
def __iter__(self): | |||
return self.keys().__iter__() | |||
return self.keys().__iter__() |
@@ -10,20 +10,29 @@ def _reset_logger(log): | |||
log.handlers.clear() | |||
log.propagate = False | |||
console_handle = logging.StreamHandler(sys.stdout) | |||
console_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s', | |||
datefmt='%Y-%m-%d %H:%M:%S')) | |||
file_handle = logging.FileHandler('run.log', encoding='utf-8') | |||
file_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s', | |||
datefmt='%Y-%m-%d %H:%M:%S')) | |||
console_handle.setFormatter( | |||
logging.Formatter( | |||
"[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s", | |||
datefmt="%Y-%m-%d %H:%M:%S", | |||
) | |||
) | |||
file_handle = logging.FileHandler("run.log", encoding="utf-8") | |||
file_handle.setFormatter( | |||
logging.Formatter( | |||
"[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s", | |||
datefmt="%Y-%m-%d %H:%M:%S", | |||
) | |||
) | |||
log.addHandler(file_handle) | |||
log.addHandler(console_handle) | |||
def _get_logger(): | |||
log = logging.getLogger('log') | |||
log = logging.getLogger("log") | |||
_reset_logger(log) | |||
log.setLevel(logging.INFO) | |||
return log | |||
# 日志句柄 | |||
logger = _get_logger() | |||
logger = _get_logger() |
@@ -1,15 +1,20 @@ | |||
import time | |||
import pip | |||
from pip._internal import main as pipmain | |||
from common.log import logger,_reset_logger | |||
from common.log import _reset_logger, logger | |||
def install(package): | |||
pipmain(['install', package]) | |||
pipmain(["install", package]) | |||
def install_requirements(file): | |||
pipmain(['install', '-r', file, "--upgrade"]) | |||
pipmain(["install", "-r", file, "--upgrade"]) | |||
_reset_logger(logger) | |||
def check_dulwich(): | |||
needwait = False | |||
for i in range(2): | |||
@@ -18,13 +23,14 @@ def check_dulwich(): | |||
needwait = False | |||
try: | |||
import dulwich | |||
return | |||
except ImportError: | |||
try: | |||
install('dulwich') | |||
install("dulwich") | |||
except: | |||
needwait = True | |||
try: | |||
import dulwich | |||
except ImportError: | |||
raise ImportError("Unable to import dulwich") | |||
raise ImportError("Unable to import dulwich") |
@@ -62,4 +62,4 @@ class SortedDict(dict): | |||
return iter(self.keys()) | |||
def __repr__(self): | |||
return f'{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})' | |||
return f"{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})" |
@@ -1,7 +1,11 @@ | |||
import time,re,hashlib | |||
import hashlib | |||
import re | |||
import time | |||
import config | |||
from common.log import logger | |||
def time_checker(f): | |||
def _time_checker(self, *args, **kwargs): | |||
_config = config.conf() | |||
@@ -9,17 +13,25 @@ def time_checker(f): | |||
if chat_time_module: | |||
chat_start_time = _config.get("chat_start_time", "00:00") | |||
chat_stopt_time = _config.get("chat_stop_time", "24:00") | |||
time_regex = re.compile(r'^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$') #时间匹配,包含24:00 | |||
time_regex = re.compile( | |||
r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$" | |||
) # 时间匹配,包含24:00 | |||
starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式 | |||
stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式 | |||
chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间 | |||
chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间 | |||
# 时间格式检查 | |||
if not (starttime_format_check and stoptime_format_check and chat_time_check): | |||
logger.warn('时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})'.format(starttime_format_check,stoptime_format_check)) | |||
if chat_start_time>"23:59": | |||
logger.error('启动时间可能存在问题,请修改!') | |||
if not ( | |||
starttime_format_check and stoptime_format_check and chat_time_check | |||
): | |||
logger.warn( | |||
"时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format( | |||
starttime_format_check, stoptime_format_check | |||
) | |||
) | |||
if chat_start_time > "23:59": | |||
logger.error("启动时间可能存在问题,请修改!") | |||
# 服务时间检查 | |||
now_time = time.strftime("%H:%M", time.localtime()) | |||
@@ -27,12 +39,12 @@ def time_checker(f): | |||
f(self, *args, **kwargs) | |||
return None | |||
else: | |||
if args[0]['Content'] == "#更新配置": # 不在服务时间内也可以更新配置 | |||
if args[0]["Content"] == "#更新配置": # 不在服务时间内也可以更新配置 | |||
f(self, *args, **kwargs) | |||
else: | |||
logger.info('非服务时间内,不接受访问') | |||
logger.info("非服务时间内,不接受访问") | |||
return None | |||
else: | |||
f(self, *args, **kwargs) # 未开启时间模块则直接回答 | |||
return _time_checker | |||
return _time_checker |
@@ -1,20 +1,18 @@ | |||
import os | |||
import pathlib | |||
from config import conf | |||
class TmpDir(object): | |||
"""A temporary directory that is deleted when the object is destroyed. | |||
""" | |||
"""A temporary directory that is deleted when the object is destroyed.""" | |||
tmpFilePath = pathlib.Path("./tmp/") | |||
tmpFilePath = pathlib.Path('./tmp/') | |||
def __init__(self): | |||
pathExists = os.path.exists(self.tmpFilePath) | |||
if not pathExists: | |||
os.makedirs(self.tmpFilePath) | |||
def path(self): | |||
return str(self.tmpFilePath) + '/' | |||
return str(self.tmpFilePath) + "/" |
@@ -2,16 +2,30 @@ | |||
"open_ai_api_key": "YOUR API KEY", | |||
"model": "gpt-3.5-turbo", | |||
"proxy": "", | |||
"single_chat_prefix": ["bot", "@bot"], | |||
"single_chat_prefix": [ | |||
"bot", | |||
"@bot" | |||
], | |||
"single_chat_reply_prefix": "[bot] ", | |||
"group_chat_prefix": ["@bot"], | |||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], | |||
"group_chat_in_one_session": ["ChatGPT测试群"], | |||
"image_create_prefix": ["画", "看", "找"], | |||
"group_chat_prefix": [ | |||
"@bot" | |||
], | |||
"group_name_white_list": [ | |||
"ChatGPT测试群", | |||
"ChatGPT测试群2" | |||
], | |||
"group_chat_in_one_session": [ | |||
"ChatGPT测试群" | |||
], | |||
"image_create_prefix": [ | |||
"画", | |||
"看", | |||
"找" | |||
], | |||
"speech_recognition": false, | |||
"group_speech_recognition": false, | |||
"voice_reply_voice": false, | |||
"conversation_max_tokens": 1000, | |||
"expires_in_seconds": 3600, | |||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。" | |||
} | |||
} |
@@ -3,9 +3,10 @@ | |||
import json | |||
import logging | |||
import os | |||
from common.log import logger | |||
import pickle | |||
from common.log import logger | |||
# 将所有可用的配置项写在字典里, 请使用小写字母 | |||
available_setting = { | |||
# openai api配置 | |||
@@ -16,8 +17,7 @@ available_setting = { | |||
# chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称 | |||
"model": "gpt-3.5-turbo", | |||
"use_azure_chatgpt": False, # 是否使用azure的chatgpt | |||
"azure_deployment_id": "", #azure 模型部署名称 | |||
"azure_deployment_id": "", # azure 模型部署名称 | |||
# Bot触发配置 | |||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 | |||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人 | |||
@@ -30,25 +30,21 @@ available_setting = { | |||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称 | |||
"trigger_by_self": False, # 是否允许机器人触发 | |||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀 | |||
"concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中,大于1可能乱序 | |||
"concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中,大于1可能乱序 | |||
# chatgpt会话参数 | |||
"expires_in_seconds": 3600, # 无操作会话的过期时间 | |||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述 | |||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数 | |||
# chatgpt限流配置 | |||
"rate_limit_chatgpt": 20, # chatgpt的调用频率限制 | |||
"rate_limit_dalle": 50, # openai dalle的调用频率限制 | |||
# chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create | |||
"temperature": 0.9, | |||
"top_p": 1, | |||
"frequency_penalty": 0, | |||
"presence_penalty": 0, | |||
"request_timeout": 60, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 | |||
"timeout": 120, # chatgpt重试超时时间,在这个时间内,将会自动重试 | |||
"request_timeout": 60, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 | |||
"timeout": 120, # chatgpt重试超时时间,在这个时间内,将会自动重试 | |||
# 语音设置 | |||
"speech_recognition": False, # 是否开启语音识别 | |||
"group_speech_recognition": False, # 是否开启群组语音识别 | |||
@@ -56,50 +52,40 @@ available_setting = { | |||
"always_reply_voice": False, # 是否一直使用语音回复 | |||
"voice_to_text": "openai", # 语音识别引擎,支持openai,baidu,google,azure | |||
"text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline),azure | |||
# baidu 语音api配置, 使用百度语音识别和语音合成时需要 | |||
"baidu_app_id": "", | |||
"baidu_api_key": "", | |||
"baidu_secret_key": "", | |||
# 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场 | |||
"baidu_dev_pid": "1536", | |||
# azure 语音api配置, 使用azure语音识别和语音合成时需要 | |||
"azure_voice_api_key": "", | |||
"azure_voice_region": "japaneast", | |||
# 服务时间限制,目前支持itchat | |||
"chat_time_module": False, # 是否开启服务时间限制 | |||
"chat_start_time": "00:00", # 服务开始时间 | |||
"chat_stop_time": "24:00", # 服务结束时间 | |||
# itchat的配置 | |||
"hot_reload": False, # 是否开启热重载 | |||
# wechaty的配置 | |||
"wechaty_puppet_service_token": "", # wechaty的token | |||
# wechatmp的配置 | |||
"wechatmp_token": "", # 微信公众平台的Token | |||
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443 | |||
"wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要 | |||
"wechatmp_token": "", # 微信公众平台的Token | |||
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443 | |||
"wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要 | |||
"wechatmp_app_secret": "", # 微信公众平台的appsecret,仅服务号需要 | |||
# chatgpt指令自定义触发词 | |||
"clear_memory_commands": ['#清除记忆'], # 重置会话指令,必须以#开头 | |||
"clear_memory_commands": ["#清除记忆"], # 重置会话指令,必须以#开头 | |||
# channel配置 | |||
"channel_type": "wx", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service} | |||
"channel_type": "wx", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service} | |||
"debug": False, # 是否开启debug模式,开启后会打印更多日志 | |||
# 插件配置 | |||
"plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突 | |||
} | |||
class Config(dict): | |||
def __init__(self, d:dict={}): | |||
def __init__(self, d: dict = {}): | |||
super().__init__(d) | |||
# user_datas: 用户数据,key为用户名,value为用户数据,也是dict | |||
self.user_datas = {} | |||
@@ -130,7 +116,7 @@ class Config(dict): | |||
def load_user_datas(self): | |||
try: | |||
with open('user_datas.pkl', 'rb') as f: | |||
with open("user_datas.pkl", "rb") as f: | |||
self.user_datas = pickle.load(f) | |||
logger.info("[Config] User datas loaded.") | |||
except FileNotFoundError as e: | |||
@@ -141,12 +127,13 @@ class Config(dict): | |||
def save_user_datas(self): | |||
try: | |||
with open('user_datas.pkl', 'wb') as f: | |||
with open("user_datas.pkl", "wb") as f: | |||
pickle.dump(self.user_datas, f) | |||
logger.info("[Config] User datas saved.") | |||
except Exception as e: | |||
logger.info("[Config] User datas error: {}".format(e)) | |||
config = Config() | |||
@@ -154,7 +141,7 @@ def load_config(): | |||
global config | |||
config_path = "./config.json" | |||
if not os.path.exists(config_path): | |||
logger.info('配置文件不存在,将使用config-template.json模板') | |||
logger.info("配置文件不存在,将使用config-template.json模板") | |||
config_path = "./config-template.json" | |||
config_str = read_file(config_path) | |||
@@ -169,7 +156,8 @@ def load_config(): | |||
name = name.lower() | |||
if name in available_setting: | |||
logger.info( | |||
"[INIT] override config by environ args: {}={}".format(name, value)) | |||
"[INIT] override config by environ args: {}={}".format(name, value) | |||
) | |||
try: | |||
config[name] = eval(value) | |||
except: | |||
@@ -182,18 +170,19 @@ def load_config(): | |||
if config.get("debug", False): | |||
logger.setLevel(logging.DEBUG) | |||
logger.debug("[INIT] set log level to DEBUG") | |||
logger.debug("[INIT] set log level to DEBUG") | |||
logger.info("[INIT] load config: {}".format(config)) | |||
config.load_user_datas() | |||
def get_root(): | |||
return os.path.dirname(os.path.abspath(__file__)) | |||
def read_file(path): | |||
with open(path, mode='r', encoding='utf-8') as f: | |||
with open(path, mode="r", encoding="utf-8") as f: | |||
return f.read() | |||
@@ -33,7 +33,7 @@ ADD ./entrypoint.sh /entrypoint.sh | |||
RUN chmod +x /entrypoint.sh \ | |||
&& groupadd -r noroot \ | |||
&& useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \ | |||
&& chown -R noroot:noroot ${BUILD_PREFIX} | |||
&& chown -R noroot:noroot ${BUILD_PREFIX} | |||
USER noroot | |||
@@ -18,7 +18,7 @@ RUN apt-get update \ | |||
&& pip install --no-cache -r requirements.txt \ | |||
&& pip install --no-cache -r requirements-optional.txt \ | |||
&& pip install azure-cognitiveservices-speech | |||
WORKDIR ${BUILD_PREFIX} | |||
ADD docker/entrypoint.sh /entrypoint.sh | |||
@@ -11,6 +11,5 @@ docker build -f Dockerfile.alpine \ | |||
-t zhayujie/chatgpt-on-wechat . | |||
# tag image | |||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:alpine | |||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:alpine | |||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-alpine | |||
@@ -11,5 +11,5 @@ docker build -f Dockerfile.debian \ | |||
-t zhayujie/chatgpt-on-wechat . | |||
# tag image | |||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:debian | |||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:debian | |||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-debian |
@@ -9,7 +9,7 @@ RUN apk add --no-cache \ | |||
ffmpeg \ | |||
espeak \ | |||
&& pip install --no-cache \ | |||
baidu-aip \ | |||
baidu-aip \ | |||
chardet \ | |||
SpeechRecognition | |||
@@ -10,7 +10,7 @@ RUN apt-get update \ | |||
ffmpeg \ | |||
espeak \ | |||
&& pip install --no-cache \ | |||
baidu-aip \ | |||
baidu-aip \ | |||
chardet \ | |||
SpeechRecognition | |||
@@ -11,13 +11,13 @@ run_d: | |||
docker rm $(CONTAINER_NAME) || echo | |||
docker run -dt --name $(CONTAINER_NAME) $(PORT_MAP) \ | |||
--env-file=$(DOTENV) \ | |||
$(MOUNT) $(IMG) | |||
$(MOUNT) $(IMG) | |||
run_i: | |||
docker rm $(CONTAINER_NAME) || echo | |||
docker run -it --name $(CONTAINER_NAME) $(PORT_MAP) \ | |||
--env-file=$(DOTENV) \ | |||
$(MOUNT) $(IMG) | |||
$(MOUNT) $(IMG) | |||
stop: | |||
docker stop $(CONTAINER_NAME) | |||
@@ -24,17 +24,17 @@ | |||
在本仓库中预置了一些插件,如果要安装其他仓库的插件,有两种方法。 | |||
- 第一种方法是在将下载的插件文件都解压到"plugins"文件夹的一个单独的文件夹,最终插件的代码都位于"plugins/PLUGIN_NAME/*"中。启动程序后,如果插件的目录结构正确,插件会自动被扫描加载。除此以外,注意你还需要安装文件夹中`requirements.txt`中的依赖。 | |||
- 第二种方法是`Godcmd`插件,它是预置的管理员插件,能够让程序在运行时就能安装插件,它能够自动安装依赖。 | |||
安装插件的命令是"#installp [仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件名/仓库地址"。这是管理员命令,认证方法在[这里](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/godcmd)。 | |||
- 安装[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件:#installp sdwebui | |||
- 安装指定仓库的插件:#installp https://github.com/lanvent/plugin_sdwebui.git | |||
在安装之后,需要执行"#scanp"命令来扫描加载新安装的插件(或者重新启动程序)。 | |||
安装插件后需要注意有些插件有自己的配置模板,一般要去掉".template"新建一个配置文件。 | |||
## 插件化实现 | |||
@@ -107,14 +107,14 @@ | |||
``` | |||
回复`Reply`的定义如下所示,它允许Bot可以回复多类不同的消息。同时也加入了`INFO`和`ERROR`消息类型区分系统提示和系统错误。 | |||
```python | |||
class ReplyType(Enum): | |||
TEXT = 1 # 文本 | |||
VOICE = 2 # 音频文件 | |||
IMAGE = 3 # 图片文件 | |||
IMAGE_URL = 4 # 图片URL | |||
INFO = 9 | |||
ERROR = 10 | |||
class Reply: | |||
@@ -159,12 +159,12 @@ | |||
目前支持三类触发事件: | |||
``` | |||
1.收到消息 | |||
---> `ON_HANDLE_CONTEXT` | |||
2.产生回复 | |||
---> `ON_DECORATE_REPLY` | |||
3.装饰回复 | |||
---> `ON_SEND_REPLY` | |||
1.收到消息 | |||
---> `ON_HANDLE_CONTEXT` | |||
2.产生回复 | |||
---> `ON_DECORATE_REPLY` | |||
3.装饰回复 | |||
---> `ON_SEND_REPLY` | |||
4.发送回复 | |||
``` | |||
@@ -268,6 +268,6 @@ class Hello(Plugin): | |||
- 一个插件目录建议只注册一个插件类。建议使用单独的仓库维护插件,便于更新。 | |||
在测试调试好后提交`PR`,把自己的仓库加入到[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)中。 | |||
- 插件的config文件、使用说明`README.md`、`requirement.txt`等放置在插件目录中。 | |||
- 默认优先级不要超过管理员插件`Godcmd`的优先级(999),`Godcmd`插件提供了配置管理、插件管理等功能。 |
@@ -1,9 +1,9 @@ | |||
from .plugin_manager import PluginManager | |||
from .event import * | |||
from .plugin import * | |||
from .plugin_manager import PluginManager | |||
instance = PluginManager() | |||
register = instance.register | |||
register = instance.register | |||
# load_plugins = instance.load_plugins | |||
# emit_event = instance.emit_event |
@@ -1 +1 @@ | |||
from .banwords import * | |||
from .banwords import * |
@@ -2,56 +2,67 @@ | |||
import json | |||
import os | |||
import plugins | |||
from bridge.context import ContextType | |||
from bridge.reply import Reply, ReplyType | |||
import plugins | |||
from plugins import * | |||
from common.log import logger | |||
from plugins import * | |||
from .lib.WordsSearch import WordsSearch | |||
@plugins.register(name="Banwords", desire_priority=100, hidden=True, desc="判断消息中是否有敏感词、决定是否回复。", version="1.0", author="lanvent") | |||
@plugins.register( | |||
name="Banwords", | |||
desire_priority=100, | |||
hidden=True, | |||
desc="判断消息中是否有敏感词、决定是否回复。", | |||
version="1.0", | |||
author="lanvent", | |||
) | |||
class Banwords(Plugin): | |||
def __init__(self): | |||
super().__init__() | |||
try: | |||
curdir=os.path.dirname(__file__) | |||
config_path=os.path.join(curdir,"config.json") | |||
conf=None | |||
curdir = os.path.dirname(__file__) | |||
config_path = os.path.join(curdir, "config.json") | |||
conf = None | |||
if not os.path.exists(config_path): | |||
conf={"action":"ignore"} | |||
with open(config_path,"w") as f: | |||
json.dump(conf,f,indent=4) | |||
conf = {"action": "ignore"} | |||
with open(config_path, "w") as f: | |||
json.dump(conf, f, indent=4) | |||
else: | |||
with open(config_path,"r") as f: | |||
conf=json.load(f) | |||
with open(config_path, "r") as f: | |||
conf = json.load(f) | |||
self.searchr = WordsSearch() | |||
self.action = conf["action"] | |||
banwords_path = os.path.join(curdir,"banwords.txt") | |||
with open(banwords_path, 'r', encoding='utf-8') as f: | |||
words=[] | |||
banwords_path = os.path.join(curdir, "banwords.txt") | |||
with open(banwords_path, "r", encoding="utf-8") as f: | |||
words = [] | |||
for line in f: | |||
word = line.strip() | |||
if word: | |||
words.append(word) | |||
self.searchr.SetKeywords(words) | |||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context | |||
if conf.get("reply_filter",True): | |||
if conf.get("reply_filter", True): | |||
self.handlers[Event.ON_DECORATE_REPLY] = self.on_decorate_reply | |||
self.reply_action = conf.get("reply_action","ignore") | |||
self.reply_action = conf.get("reply_action", "ignore") | |||
logger.info("[Banwords] inited") | |||
except Exception as e: | |||
logger.warn("[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords .") | |||
logger.warn( | |||
"[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords ." | |||
) | |||
raise e | |||
def on_handle_context(self, e_context: EventContext): | |||
if e_context['context'].type not in [ContextType.TEXT,ContextType.IMAGE_CREATE]: | |||
if e_context["context"].type not in [ | |||
ContextType.TEXT, | |||
ContextType.IMAGE_CREATE, | |||
]: | |||
return | |||
content = e_context['context'].content | |||
content = e_context["context"].content | |||
logger.debug("[Banwords] on_handle_context. content: %s" % content) | |||
if self.action == "ignore": | |||
f = self.searchr.FindFirst(content) | |||
@@ -61,31 +72,34 @@ class Banwords(Plugin): | |||
return | |||
elif self.action == "replace": | |||
if self.searchr.ContainsAny(content): | |||
reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n"+self.searchr.Replace(content)) | |||
e_context['reply'] = reply | |||
reply = Reply( | |||
ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content) | |||
) | |||
e_context["reply"] = reply | |||
e_context.action = EventAction.BREAK_PASS | |||
return | |||
def on_decorate_reply(self, e_context: EventContext): | |||
if e_context['reply'].type not in [ReplyType.TEXT]: | |||
def on_decorate_reply(self, e_context: EventContext): | |||
if e_context["reply"].type not in [ReplyType.TEXT]: | |||
return | |||
reply = e_context['reply'] | |||
reply = e_context["reply"] | |||
content = reply.content | |||
if self.reply_action == "ignore": | |||
f = self.searchr.FindFirst(content) | |||
if f: | |||
logger.info("[Banwords] %s in reply" % f["Keyword"]) | |||
e_context['reply'] = None | |||
e_context["reply"] = None | |||
e_context.action = EventAction.BREAK_PASS | |||
return | |||
elif self.reply_action == "replace": | |||
if self.searchr.ContainsAny(content): | |||
reply = Reply(ReplyType.INFO, "已替换回复中的敏感词: \n"+self.searchr.Replace(content)) | |||
e_context['reply'] = reply | |||
reply = Reply( | |||
ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content) | |||
) | |||
e_context["reply"] = reply | |||
e_context.action = EventAction.CONTINUE | |||
return | |||
def get_help_text(self, **kwargs): | |||
return Banwords.desc | |||
return Banwords.desc |
@@ -1,5 +1,5 @@ | |||
{ | |||
"action": "replace", | |||
"reply_filter": true, | |||
"reply_action": "ignore" | |||
} | |||
"action": "replace", | |||
"reply_filter": true, | |||
"reply_action": "ignore" | |||
} |
@@ -24,7 +24,7 @@ see https://ai.baidu.com/unit/home#/home?track=61fe1b0d3407ce3face1d92cb5c291087 | |||
``` json | |||
{ | |||
"service_id": "s...", #"机器人ID" | |||
"api_key": "", | |||
"api_key": "", | |||
"secret_key": "" | |||
} | |||
``` |
@@ -1 +1 @@ | |||
from .bdunit import * | |||
from .bdunit import * |
@@ -2,21 +2,29 @@ | |||
import json | |||
import os | |||
import uuid | |||
from uuid import getnode as get_mac | |||
import requests | |||
import plugins | |||
from bridge.context import ContextType | |||
from bridge.reply import Reply, ReplyType | |||
from common.log import logger | |||
import plugins | |||
from plugins import * | |||
from uuid import getnode as get_mac | |||
"""利用百度UNIT实现智能对话 | |||
如果命中意图,返回意图对应的回复,否则返回继续交付给下个插件处理 | |||
""" | |||
@plugins.register(name="BDunit", desire_priority=0, hidden=True, desc="Baidu unit bot system", version="0.1", author="jackson") | |||
@plugins.register( | |||
name="BDunit", | |||
desire_priority=0, | |||
hidden=True, | |||
desc="Baidu unit bot system", | |||
version="0.1", | |||
author="jackson", | |||
) | |||
class BDunit(Plugin): | |||
def __init__(self): | |||
super().__init__() | |||
@@ -40,11 +48,10 @@ class BDunit(Plugin): | |||
raise e | |||
def on_handle_context(self, e_context: EventContext): | |||
if e_context['context'].type != ContextType.TEXT: | |||
if e_context["context"].type != ContextType.TEXT: | |||
return | |||
content = e_context['context'].content | |||
content = e_context["context"].content | |||
logger.debug("[BDunit] on_handle_context. content: %s" % content) | |||
parsed = self.getUnit2(content) | |||
intent = self.getIntent(parsed) | |||
@@ -53,7 +60,7 @@ class BDunit(Plugin): | |||
reply = Reply() | |||
reply.type = ReplyType.TEXT | |||
reply.content = self.getSay(parsed) | |||
e_context['reply'] = reply | |||
e_context["reply"] = reply | |||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | |||
else: | |||
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 | |||
@@ -70,17 +77,15 @@ class BDunit(Plugin): | |||
string: access_token | |||
""" | |||
url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format( | |||
self.api_key, self.secret_key) | |||
self.api_key, self.secret_key | |||
) | |||
payload = "" | |||
headers = { | |||
'Content-Type': 'application/json', | |||
'Accept': 'application/json' | |||
} | |||
headers = {"Content-Type": "application/json", "Accept": "application/json"} | |||
response = requests.request("POST", url, headers=headers, data=payload) | |||
# print(response.text) | |||
return response.json()['access_token'] | |||
return response.json()["access_token"] | |||
def getUnit(self, query): | |||
""" | |||
@@ -90,11 +95,14 @@ class BDunit(Plugin): | |||
""" | |||
url = ( | |||
'https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=' | |||
"https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" | |||
+ self.access_token | |||
) | |||
request = {"query": query, "user_id": str( | |||
get_mac())[:32], "terminal_id": "88888"} | |||
request = { | |||
"query": query, | |||
"user_id": str(get_mac())[:32], | |||
"terminal_id": "88888", | |||
} | |||
body = { | |||
"log_id": str(uuid.uuid1()), | |||
"version": "3.0", | |||
@@ -142,11 +150,7 @@ class BDunit(Plugin): | |||
:param parsed: UNIT 解析结果 | |||
:returns: 意图数组 | |||
""" | |||
if ( | |||
parsed | |||
and "result" in parsed | |||
and "response_list" in parsed["result"] | |||
): | |||
if parsed and "result" in parsed and "response_list" in parsed["result"]: | |||
try: | |||
return parsed["result"]["response_list"][0]["schema"]["intent"] | |||
except Exception as e: | |||
@@ -163,11 +167,7 @@ class BDunit(Plugin): | |||
:param intent: 意图的名称 | |||
:returns: True: 包含; False: 不包含 | |||
""" | |||
if ( | |||
parsed | |||
and "result" in parsed | |||
and "response_list" in parsed["result"] | |||
): | |||
if parsed and "result" in parsed and "response_list" in parsed["result"]: | |||
response_list = parsed["result"]["response_list"] | |||
for response in response_list: | |||
if ( | |||
@@ -189,11 +189,7 @@ class BDunit(Plugin): | |||
:returns: 词槽列表。你可以通过 name 属性筛选词槽, | |||
再通过 normalized_word 属性取出相应的值 | |||
""" | |||
if ( | |||
parsed | |||
and "result" in parsed | |||
and "response_list" in parsed["result"] | |||
): | |||
if parsed and "result" in parsed and "response_list" in parsed["result"]: | |||
response_list = parsed["result"]["response_list"] | |||
if intent == "": | |||
try: | |||
@@ -236,11 +232,7 @@ class BDunit(Plugin): | |||
:param parsed: UNIT 解析结果 | |||
:returns: UNIT 的回复文本 | |||
""" | |||
if ( | |||
parsed | |||
and "result" in parsed | |||
and "response_list" in parsed["result"] | |||
): | |||
if parsed and "result" in parsed and "response_list" in parsed["result"]: | |||
response_list = parsed["result"]["response_list"] | |||
answer = {} | |||
for response in response_list: | |||
@@ -266,11 +258,7 @@ class BDunit(Plugin): | |||
:param intent: 意图的名称 | |||
:returns: UNIT 的回复文本 | |||
""" | |||
if ( | |||
parsed | |||
and "result" in parsed | |||
and "response_list" in parsed["result"] | |||
): | |||
if parsed and "result" in parsed and "response_list" in parsed["result"]: | |||
response_list = parsed["result"]["response_list"] | |||
if intent == "": | |||
try: | |||
@@ -1,5 +1,5 @@ | |||
{ | |||
"service_id": "s...", | |||
"api_key": "", | |||
"secret_key": "" | |||
} | |||
"service_id": "s...", | |||
"api_key": "", | |||
"secret_key": "" | |||
} |
@@ -1 +1 @@ | |||
from .dungeon import * | |||
from .dungeon import * |
@@ -1,17 +1,18 @@ | |||
# encoding:utf-8 | |||
import plugins | |||
from bridge.bridge import Bridge | |||
from bridge.context import ContextType | |||
from bridge.reply import Reply, ReplyType | |||
from common import const | |||
from common.expired_dict import ExpiredDict | |||
from common.log import logger | |||
from config import conf | |||
import plugins | |||
from plugins import * | |||
from common.log import logger | |||
from common import const | |||
# https://github.com/bupticybee/ChineseAiDungeonChatGPT | |||
class StoryTeller(): | |||
class StoryTeller: | |||
def __init__(self, bot, sessionid, story): | |||
self.bot = bot | |||
self.sessionid = sessionid | |||
@@ -27,67 +28,85 @@ class StoryTeller(): | |||
if user_action[-1] != "。": | |||
user_action = user_action + "。" | |||
if self.first_interact: | |||
prompt = """现在来充当一个文字冒险游戏,描述时候注意节奏,不要太快,仔细描述各个人物的心情和周边环境。一次只需写四到六句话。 | |||
开头是,""" + self.story + " " + user_action | |||
prompt = ( | |||
"""现在来充当一个文字冒险游戏,描述时候注意节奏,不要太快,仔细描述各个人物的心情和周边环境。一次只需写四到六句话。 | |||
开头是,""" | |||
+ self.story | |||
+ " " | |||
+ user_action | |||
) | |||
self.first_interact = False | |||
else: | |||
prompt = """继续,一次只需要续写四到六句话,总共就只讲5分钟内发生的事情。""" + user_action | |||
return prompt | |||
@plugins.register(name="Dungeon", desire_priority=0, namecn="文字冒险", desc="A plugin to play dungeon game", version="1.0", author="lanvent") | |||
@plugins.register( | |||
name="Dungeon", | |||
desire_priority=0, | |||
namecn="文字冒险", | |||
desc="A plugin to play dungeon game", | |||
version="1.0", | |||
author="lanvent", | |||
) | |||
class Dungeon(Plugin): | |||
def __init__(self): | |||
super().__init__() | |||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context | |||
logger.info("[Dungeon] inited") | |||
# 目前没有设计session过期事件,这里先暂时使用过期字典 | |||
if conf().get('expires_in_seconds'): | |||
self.games = ExpiredDict(conf().get('expires_in_seconds')) | |||
if conf().get("expires_in_seconds"): | |||
self.games = ExpiredDict(conf().get("expires_in_seconds")) | |||
else: | |||
self.games = dict() | |||
def on_handle_context(self, e_context: EventContext): | |||
if e_context['context'].type != ContextType.TEXT: | |||
if e_context["context"].type != ContextType.TEXT: | |||
return | |||
bottype = Bridge().get_bot_type("chat") | |||
if bottype not in (const.CHATGPT, const.OPEN_AI): | |||
return | |||
bot = Bridge().get_bot("chat") | |||
content = e_context['context'].content[:] | |||
clist = e_context['context'].content.split(maxsplit=1) | |||
sessionid = e_context['context']['session_id'] | |||
content = e_context["context"].content[:] | |||
clist = e_context["context"].content.split(maxsplit=1) | |||
sessionid = e_context["context"]["session_id"] | |||
logger.debug("[Dungeon] on_handle_context. content: %s" % clist) | |||
trigger_prefix = conf().get('plugin_trigger_prefix', "$") | |||
trigger_prefix = conf().get("plugin_trigger_prefix", "$") | |||
if clist[0] == f"{trigger_prefix}停止冒险": | |||
if sessionid in self.games: | |||
self.games[sessionid].reset() | |||
del self.games[sessionid] | |||
reply = Reply(ReplyType.INFO, "冒险结束!") | |||
e_context['reply'] = reply | |||
e_context["reply"] = reply | |||
e_context.action = EventAction.BREAK_PASS | |||
elif clist[0] == f"{trigger_prefix}开始冒险" or sessionid in self.games: | |||
if sessionid not in self.games or clist[0] == f"{trigger_prefix}开始冒险": | |||
if len(clist)>1 : | |||
if len(clist) > 1: | |||
story = clist[1] | |||
else: | |||
story = "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。" | |||
story = ( | |||
"你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。" | |||
) | |||
self.games[sessionid] = StoryTeller(bot, sessionid, story) | |||
reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story) | |||
e_context['reply'] = reply | |||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | |||
e_context["reply"] = reply | |||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | |||
else: | |||
prompt = self.games[sessionid].action(content) | |||
e_context['context'].type = ContextType.TEXT | |||
e_context['context'].content = prompt | |||
e_context.action = EventAction.BREAK # 事件结束,不跳过处理context的默认逻辑 | |||
e_context["context"].type = ContextType.TEXT | |||
e_context["context"].content = prompt | |||
e_context.action = EventAction.BREAK # 事件结束,不跳过处理context的默认逻辑 | |||
def get_help_text(self, **kwargs): | |||
help_text = "可以和机器人一起玩文字冒险游戏。\n" | |||
if kwargs.get('verbose') != True: | |||
if kwargs.get("verbose") != True: | |||
return help_text | |||
trigger_prefix = conf().get('plugin_trigger_prefix', "$") | |||
help_text = f"{trigger_prefix}开始冒险 "+"背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n"+f"{trigger_prefix}停止冒险: 结束游戏。\n" | |||
if kwargs.get('verbose') == True: | |||
trigger_prefix = conf().get("plugin_trigger_prefix", "$") | |||
help_text = ( | |||
f"{trigger_prefix}开始冒险 " | |||
+ "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n" | |||
+ f"{trigger_prefix}停止冒险: 结束游戏。\n" | |||
) | |||
if kwargs.get("verbose") == True: | |||
help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'" | |||
return help_text | |||
return help_text |
@@ -9,17 +9,17 @@ class Event(Enum): | |||
e_context = { "channel": 消息channel, "context" : 本次消息的context} | |||
""" | |||
ON_HANDLE_CONTEXT = 2 # 处理消息前 | |||
ON_HANDLE_CONTEXT = 2 # 处理消息前 | |||
""" | |||
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空 } | |||
""" | |||
ON_DECORATE_REPLY = 3 # 得到回复后准备装饰 | |||
ON_DECORATE_REPLY = 3 # 得到回复后准备装饰 | |||
""" | |||
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 } | |||
""" | |||
ON_SEND_REPLY = 4 # 发送回复前 | |||
ON_SEND_REPLY = 4 # 发送回复前 | |||
""" | |||
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 } | |||
""" | |||
@@ -28,9 +28,9 @@ class Event(Enum): | |||
class EventAction(Enum): | |||
CONTINUE = 1 # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑 | |||
BREAK = 2 # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑 | |||
BREAK_PASS = 3 # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑 | |||
CONTINUE = 1 # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑 | |||
BREAK = 2 # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑 | |||
BREAK_PASS = 3 # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑 | |||
class EventContext: | |||
@@ -1 +1 @@ | |||
from .finish import * | |||
from .finish import * |
@@ -1,14 +1,21 @@ | |||
# encoding:utf-8 | |||
import plugins | |||
from bridge.context import ContextType | |||
from bridge.reply import Reply, ReplyType | |||
from common.log import logger | |||
from config import conf | |||
import plugins | |||
from plugins import * | |||
from common.log import logger | |||
@plugins.register(name="Finish", desire_priority=-999, hidden=True, desc="A plugin that check unknown command", version="1.0", author="js00000") | |||
@plugins.register( | |||
name="Finish", | |||
desire_priority=-999, | |||
hidden=True, | |||
desc="A plugin that check unknown command", | |||
version="1.0", | |||
author="js00000", | |||
) | |||
class Finish(Plugin): | |||
def __init__(self): | |||
super().__init__() | |||
@@ -16,19 +23,18 @@ class Finish(Plugin): | |||
logger.info("[Finish] inited") | |||
def on_handle_context(self, e_context: EventContext): | |||
if e_context['context'].type != ContextType.TEXT: | |||
if e_context["context"].type != ContextType.TEXT: | |||
return | |||
content = e_context['context'].content | |||
content = e_context["context"].content | |||
logger.debug("[Finish] on_handle_context. content: %s" % content) | |||
trigger_prefix = conf().get('plugin_trigger_prefix',"$") | |||
trigger_prefix = conf().get("plugin_trigger_prefix", "$") | |||
if content.startswith(trigger_prefix): | |||
reply = Reply() | |||
reply.type = ReplyType.ERROR | |||
reply.content = "未知插件命令\n查看插件命令列表请输入#help 插件名\n" | |||
e_context['reply'] = reply | |||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | |||
e_context["reply"] = reply | |||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | |||
def get_help_text(self, **kwargs): | |||
return "" |
@@ -1 +1 @@ | |||
from .godcmd import * | |||
from .godcmd import * |
@@ -1,4 +1,4 @@ | |||
{ | |||
"password": "", | |||
"admin_users": [] | |||
} | |||
"password": "", | |||
"admin_users": [] | |||
} |
@@ -6,14 +6,16 @@ import random | |||
import string | |||
import traceback | |||
from typing import Tuple | |||
import plugins | |||
from bridge.bridge import Bridge | |||
from bridge.context import ContextType | |||
from bridge.reply import Reply, ReplyType | |||
from config import conf, load_config | |||
import plugins | |||
from plugins import * | |||
from common import const | |||
from common.log import logger | |||
from config import conf, load_config | |||
from plugins import * | |||
# 定义指令集 | |||
COMMANDS = { | |||
"help": { | |||
@@ -41,7 +43,7 @@ COMMANDS = { | |||
}, | |||
"id": { | |||
"alias": ["id", "用户"], | |||
"desc": "获取用户id", # wechaty和wechatmp的用户id不会变化,可用于绑定管理员 | |||
"desc": "获取用户id", # wechaty和wechatmp的用户id不会变化,可用于绑定管理员 | |||
}, | |||
"reset": { | |||
"alias": ["reset", "重置会话"], | |||
@@ -114,18 +116,20 @@ ADMIN_COMMANDS = { | |||
"desc": "开启机器调试日志", | |||
}, | |||
} | |||
# 定义帮助函数 | |||
def get_help_text(isadmin, isgroup): | |||
help_text = "通用指令:\n" | |||
for cmd, info in COMMANDS.items(): | |||
if cmd=="auth": #不提示认证指令 | |||
if cmd == "auth": # 不提示认证指令 | |||
continue | |||
if cmd=="id" and conf().get("channel_type","wx") not in ["wxy","wechatmp"]: | |||
if cmd == "id" and conf().get("channel_type", "wx") not in ["wxy", "wechatmp"]: | |||
continue | |||
alias=["#"+a for a in info['alias'][:1]] | |||
alias = ["#" + a for a in info["alias"][:1]] | |||
help_text += f"{','.join(alias)} " | |||
if 'args' in info: | |||
args=[a for a in info['args']] | |||
if "args" in info: | |||
args = [a for a in info["args"]] | |||
help_text += f"{' '.join(args)}" | |||
help_text += f": {info['desc']}\n" | |||
@@ -135,39 +139,48 @@ def get_help_text(isadmin, isgroup): | |||
for plugin in plugins: | |||
if plugins[plugin].enabled and not plugins[plugin].hidden: | |||
namecn = plugins[plugin].namecn | |||
help_text += "\n%s:"%namecn | |||
help_text += PluginManager().instances[plugin].get_help_text(verbose=False).strip() | |||
help_text += "\n%s:" % namecn | |||
help_text += ( | |||
PluginManager().instances[plugin].get_help_text(verbose=False).strip() | |||
) | |||
if ADMIN_COMMANDS and isadmin: | |||
help_text += "\n\n管理员指令:\n" | |||
for cmd, info in ADMIN_COMMANDS.items(): | |||
alias=["#"+a for a in info['alias'][:1]] | |||
alias = ["#" + a for a in info["alias"][:1]] | |||
help_text += f"{','.join(alias)} " | |||
if 'args' in info: | |||
args=[a for a in info['args']] | |||
if "args" in info: | |||
args = [a for a in info["args"]] | |||
help_text += f"{' '.join(args)}" | |||
help_text += f": {info['desc']}\n" | |||
return help_text | |||
@plugins.register(name="Godcmd", desire_priority=999, hidden=True, desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证", version="1.0", author="lanvent") | |||
class Godcmd(Plugin): | |||
@plugins.register( | |||
name="Godcmd", | |||
desire_priority=999, | |||
hidden=True, | |||
desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证", | |||
version="1.0", | |||
author="lanvent", | |||
) | |||
class Godcmd(Plugin): | |||
def __init__(self): | |||
super().__init__() | |||
curdir=os.path.dirname(__file__) | |||
config_path=os.path.join(curdir,"config.json") | |||
gconf=None | |||
curdir = os.path.dirname(__file__) | |||
config_path = os.path.join(curdir, "config.json") | |||
gconf = None | |||
if not os.path.exists(config_path): | |||
gconf={"password":"","admin_users":[]} | |||
with open(config_path,"w") as f: | |||
json.dump(gconf,f,indent=4) | |||
gconf = {"password": "", "admin_users": []} | |||
with open(config_path, "w") as f: | |||
json.dump(gconf, f, indent=4) | |||
else: | |||
with open(config_path,"r") as f: | |||
gconf=json.load(f) | |||
with open(config_path, "r") as f: | |||
gconf = json.load(f) | |||
if gconf["password"] == "": | |||
self.temp_password = "".join(random.sample(string.digits, 4)) | |||
logger.info("[Godcmd] 因未设置口令,本次的临时口令为%s。"%self.temp_password) | |||
logger.info("[Godcmd] 因未设置口令,本次的临时口令为%s。" % self.temp_password) | |||
else: | |||
self.temp_password = None | |||
custom_commands = conf().get("clear_memory_commands", []) | |||
@@ -178,41 +191,42 @@ class Godcmd(Plugin): | |||
COMMANDS["reset"]["alias"].append(custom_command) | |||
self.password = gconf["password"] | |||
self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用 | |||
self.isrunning = True # 机器人是否运行中 | |||
self.admin_users = gconf[ | |||
"admin_users" | |||
] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用 | |||
self.isrunning = True # 机器人是否运行中 | |||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context | |||
logger.info("[Godcmd] inited") | |||
def on_handle_context(self, e_context: EventContext): | |||
context_type = e_context['context'].type | |||
context_type = e_context["context"].type | |||
if context_type != ContextType.TEXT: | |||
if not self.isrunning: | |||
e_context.action = EventAction.BREAK_PASS | |||
return | |||
content = e_context['context'].content | |||
content = e_context["context"].content | |||
logger.debug("[Godcmd] on_handle_context. content: %s" % content) | |||
if content.startswith("#"): | |||
# msg = e_context['context']['msg'] | |||
channel = e_context['channel'] | |||
user = e_context['context']['receiver'] | |||
session_id = e_context['context']['session_id'] | |||
isgroup = e_context['context'].get("isgroup", False) | |||
channel = e_context["channel"] | |||
user = e_context["context"]["receiver"] | |||
session_id = e_context["context"]["session_id"] | |||
isgroup = e_context["context"].get("isgroup", False) | |||
bottype = Bridge().get_bot_type("chat") | |||
bot = Bridge().get_bot("chat") | |||
# 将命令和参数分割 | |||
command_parts = content[1:].strip().split() | |||
cmd = command_parts[0] | |||
args = command_parts[1:] | |||
isadmin=False | |||
isadmin = False | |||
if user in self.admin_users: | |||
isadmin=True | |||
ok=False | |||
result="string" | |||
if any(cmd in info['alias'] for info in COMMANDS.values()): | |||
cmd = next(c for c, info in COMMANDS.items() if cmd in info['alias']) | |||
isadmin = True | |||
ok = False | |||
result = "string" | |||
if any(cmd in info["alias"] for info in COMMANDS.values()): | |||
cmd = next(c for c, info in COMMANDS.items() if cmd in info["alias"]) | |||
if cmd == "auth": | |||
ok, result = self.authenticate(user, args, isadmin, isgroup) | |||
elif cmd == "help" or cmd == "helpp": | |||
@@ -224,10 +238,14 @@ class Godcmd(Plugin): | |||
query_name = args[0].upper() | |||
# search name and namecn | |||
for name, plugincls in plugins.items(): | |||
if not plugincls.enabled : | |||
if not plugincls.enabled: | |||
continue | |||
if query_name == name or query_name == plugincls.namecn: | |||
ok, result = True, PluginManager().instances[name].get_help_text(isgroup=isgroup, isadmin=isadmin, verbose=True) | |||
ok, result = True, PluginManager().instances[ | |||
name | |||
].get_help_text( | |||
isgroup=isgroup, isadmin=isadmin, verbose=True | |||
) | |||
break | |||
if not ok: | |||
result = "插件不存在或未启用" | |||
@@ -236,14 +254,14 @@ class Godcmd(Plugin): | |||
elif cmd == "set_openai_api_key": | |||
if len(args) == 1: | |||
user_data = conf().get_user_data(user) | |||
user_data['openai_api_key'] = args[0] | |||
user_data["openai_api_key"] = args[0] | |||
ok, result = True, "你的OpenAI私有api_key已设置为" + args[0] | |||
else: | |||
ok, result = False, "请提供一个api_key" | |||
elif cmd == "reset_openai_api_key": | |||
try: | |||
user_data = conf().get_user_data(user) | |||
user_data.pop('openai_api_key') | |||
user_data.pop("openai_api_key") | |||
ok, result = True, "你的OpenAI私有api_key已清除" | |||
except Exception as e: | |||
ok, result = False, "你没有设置私有api_key" | |||
@@ -255,12 +273,16 @@ class Godcmd(Plugin): | |||
else: | |||
ok, result = False, "当前对话机器人不支持重置会话" | |||
logger.debug("[Godcmd] command: %s by %s" % (cmd, user)) | |||
elif any(cmd in info['alias'] for info in ADMIN_COMMANDS.values()): | |||
elif any(cmd in info["alias"] for info in ADMIN_COMMANDS.values()): | |||
if isadmin: | |||
if isgroup: | |||
ok, result = False, "群聊不可执行管理员指令" | |||
else: | |||
cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info['alias']) | |||
cmd = next( | |||
c | |||
for c, info in ADMIN_COMMANDS.items() | |||
if cmd in info["alias"] | |||
) | |||
if cmd == "stop": | |||
self.isrunning = False | |||
ok, result = True, "服务已暂停" | |||
@@ -278,13 +300,13 @@ class Godcmd(Plugin): | |||
else: | |||
ok, result = False, "当前对话机器人不支持重置会话" | |||
elif cmd == "debug": | |||
logger.setLevel('DEBUG') | |||
logger.setLevel("DEBUG") | |||
ok, result = True, "DEBUG模式已开启" | |||
elif cmd == "plist": | |||
plugins = PluginManager().list_plugins() | |||
ok = True | |||
result = "插件列表:\n" | |||
for name,plugincls in plugins.items(): | |||
for name, plugincls in plugins.items(): | |||
result += f"{plugincls.name}_v{plugincls.version} {plugincls.priority} - " | |||
if plugincls.enabled: | |||
result += "已启用\n" | |||
@@ -294,16 +316,20 @@ class Godcmd(Plugin): | |||
new_plugins = PluginManager().scan_plugins() | |||
ok, result = True, "插件扫描完成" | |||
PluginManager().activate_plugins() | |||
if len(new_plugins) >0 : | |||
if len(new_plugins) > 0: | |||
result += "\n发现新插件:\n" | |||
result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins]) | |||
else : | |||
result +=", 未发现新插件" | |||
result += "\n".join( | |||
[f"{p.name}_v{p.version}" for p in new_plugins] | |||
) | |||
else: | |||
result += ", 未发现新插件" | |||
elif cmd == "setpri": | |||
if len(args) != 2: | |||
ok, result = False, "请提供插件名和优先级" | |||
else: | |||
ok = PluginManager().set_plugin_priority(args[0], int(args[1])) | |||
ok = PluginManager().set_plugin_priority( | |||
args[0], int(args[1]) | |||
) | |||
if ok: | |||
result = "插件" + args[0] + "优先级已设置为" + args[1] | |||
else: | |||
@@ -350,42 +376,42 @@ class Godcmd(Plugin): | |||
else: | |||
ok, result = False, "需要管理员权限才能执行该指令" | |||
else: | |||
trigger_prefix = conf().get('plugin_trigger_prefix',"$") | |||
if trigger_prefix == "#": # 跟插件聊天指令前缀相同,继续递交 | |||
trigger_prefix = conf().get("plugin_trigger_prefix", "$") | |||
if trigger_prefix == "#": # 跟插件聊天指令前缀相同,继续递交 | |||
return | |||
ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n" | |||
reply = Reply() | |||
if ok: | |||
reply.type = ReplyType.INFO | |||
else: | |||
reply.type = ReplyType.ERROR | |||
reply.content = result | |||
e_context['reply'] = reply | |||
e_context["reply"] = reply | |||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | |||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | |||
elif not self.isrunning: | |||
e_context.action = EventAction.BREAK_PASS | |||
def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool,str] : | |||
def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool, str]: | |||
if isgroup: | |||
return False,"请勿在群聊中认证" | |||
return False, "请勿在群聊中认证" | |||
if isadmin: | |||
return False,"管理员账号无需认证" | |||
return False, "管理员账号无需认证" | |||
if len(args) != 1: | |||
return False,"请提供口令" | |||
return False, "请提供口令" | |||
password = args[0] | |||
if password == self.password: | |||
self.admin_users.append(userid) | |||
return True,"认证成功" | |||
return True, "认证成功" | |||
elif password == self.temp_password: | |||
self.admin_users.append(userid) | |||
return True,"认证成功,请尽快设置口令" | |||
return True, "认证成功,请尽快设置口令" | |||
else: | |||
return False,"认证失败" | |||
return False, "认证失败" | |||
def get_help_text(self, isadmin = False, isgroup = False, **kwargs): | |||
return get_help_text(isadmin, isgroup) | |||
def get_help_text(self, isadmin=False, isgroup=False, **kwargs): | |||
return get_help_text(isadmin, isgroup) |
@@ -1 +1 @@ | |||
from .hello import * | |||
from .hello import * |
@@ -1,14 +1,21 @@ | |||
# encoding:utf-8 | |||
import plugins | |||
from bridge.context import ContextType | |||
from bridge.reply import Reply, ReplyType | |||
from channel.chat_message import ChatMessage | |||
import plugins | |||
from plugins import * | |||
from common.log import logger | |||
from plugins import * | |||
@plugins.register(name="Hello", desire_priority=-1, hidden=True, desc="A simple plugin that says hello", version="0.1", author="lanvent") | |||
@plugins.register( | |||
name="Hello", | |||
desire_priority=-1, | |||
hidden=True, | |||
desc="A simple plugin that says hello", | |||
version="0.1", | |||
author="lanvent", | |||
) | |||
class Hello(Plugin): | |||
def __init__(self): | |||
super().__init__() | |||
@@ -16,33 +23,34 @@ class Hello(Plugin): | |||
logger.info("[Hello] inited") | |||
def on_handle_context(self, e_context: EventContext): | |||
if e_context['context'].type != ContextType.TEXT: | |||
if e_context["context"].type != ContextType.TEXT: | |||
return | |||
content = e_context['context'].content | |||
content = e_context["context"].content | |||
logger.debug("[Hello] on_handle_context. content: %s" % content) | |||
if content == "Hello": | |||
reply = Reply() | |||
reply.type = ReplyType.TEXT | |||
msg:ChatMessage = e_context['context']['msg'] | |||
if e_context['context']['isgroup']: | |||
reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}" | |||
msg: ChatMessage = e_context["context"]["msg"] | |||
if e_context["context"]["isgroup"]: | |||
reply.content = ( | |||
f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}" | |||
) | |||
else: | |||
reply.content = f"Hello, {msg.from_user_nickname}" | |||
e_context['reply'] = reply | |||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | |||
e_context["reply"] = reply | |||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 | |||
if content == "Hi": | |||
reply = Reply() | |||
reply.type = ReplyType.TEXT | |||
reply.content = "Hi" | |||
e_context['reply'] = reply | |||
e_context["reply"] = reply | |||
e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑,一般会覆写reply | |||
if content == "End": | |||
# 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World" | |||
e_context['context'].type = ContextType.IMAGE_CREATE | |||
e_context["context"].type = ContextType.IMAGE_CREATE | |||
content = "The World" | |||
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 | |||
@@ -3,4 +3,4 @@ class Plugin: | |||
self.handlers = {} | |||
def get_help_text(self, **kwargs): | |||
return "暂无帮助信息" | |||
return "暂无帮助信息" |
@@ -5,17 +5,19 @@ import importlib.util | |||
import json | |||
import os | |||
import sys | |||
from common.log import logger | |||
from common.singleton import singleton | |||
from common.sorted_dict import SortedDict | |||
from .event import * | |||
from common.log import logger | |||
from config import conf | |||
from .event import * | |||
@singleton | |||
class PluginManager: | |||
def __init__(self): | |||
self.plugins = SortedDict(lambda k,v: v.priority,reverse=True) | |||
self.plugins = SortedDict(lambda k, v: v.priority, reverse=True) | |||
self.listening_plugins = {} | |||
self.instances = {} | |||
self.pconf = {} | |||
@@ -26,17 +28,27 @@ class PluginManager: | |||
def wrapper(plugincls): | |||
plugincls.name = name | |||
plugincls.priority = desire_priority | |||
plugincls.desc = kwargs.get('desc') | |||
plugincls.author = kwargs.get('author') | |||
plugincls.desc = kwargs.get("desc") | |||
plugincls.author = kwargs.get("author") | |||
plugincls.path = self.current_plugin_path | |||
plugincls.version = kwargs.get('version') if kwargs.get('version') != None else "1.0" | |||
plugincls.namecn = kwargs.get('namecn') if kwargs.get('namecn') != None else name | |||
plugincls.hidden = kwargs.get('hidden') if kwargs.get('hidden') != None else False | |||
plugincls.version = ( | |||
kwargs.get("version") if kwargs.get("version") != None else "1.0" | |||
) | |||
plugincls.namecn = ( | |||
kwargs.get("namecn") if kwargs.get("namecn") != None else name | |||
) | |||
plugincls.hidden = ( | |||
kwargs.get("hidden") if kwargs.get("hidden") != None else False | |||
) | |||
plugincls.enabled = True | |||
if self.current_plugin_path == None: | |||
raise Exception("Plugin path not set") | |||
self.plugins[name.upper()] = plugincls | |||
logger.info("Plugin %s_v%s registered, path=%s" % (name, plugincls.version, plugincls.path)) | |||
logger.info( | |||
"Plugin %s_v%s registered, path=%s" | |||
% (name, plugincls.version, plugincls.path) | |||
) | |||
return wrapper | |||
def save_config(self): | |||
@@ -50,10 +62,12 @@ class PluginManager: | |||
if os.path.exists("./plugins/plugins.json"): | |||
with open("./plugins/plugins.json", "r", encoding="utf-8") as f: | |||
pconf = json.load(f) | |||
pconf['plugins'] = SortedDict(lambda k,v: v["priority"],pconf['plugins'],reverse=True) | |||
pconf["plugins"] = SortedDict( | |||
lambda k, v: v["priority"], pconf["plugins"], reverse=True | |||
) | |||
else: | |||
modified = True | |||
pconf = {"plugins": SortedDict(lambda k,v: v["priority"],reverse=True)} | |||
pconf = {"plugins": SortedDict(lambda k, v: v["priority"], reverse=True)} | |||
self.pconf = pconf | |||
if modified: | |||
self.save_config() | |||
@@ -67,7 +81,7 @@ class PluginManager: | |||
plugin_path = os.path.join(plugins_dir, plugin_name) | |||
if os.path.isdir(plugin_path): | |||
# 判断插件是否包含同名__init__.py文件 | |||
main_module_path = os.path.join(plugin_path,"__init__.py") | |||
main_module_path = os.path.join(plugin_path, "__init__.py") | |||
if os.path.isfile(main_module_path): | |||
# 导入插件 | |||
import_path = "plugins.{}".format(plugin_name) | |||
@@ -76,16 +90,26 @@ class PluginManager: | |||
if plugin_path in self.loaded: | |||
if self.loaded[plugin_path] == None: | |||
logger.info("reload module %s" % plugin_name) | |||
self.loaded[plugin_path] = importlib.reload(sys.modules[import_path]) | |||
dependent_module_names = [name for name in sys.modules.keys() if name.startswith( import_path+ '.')] | |||
self.loaded[plugin_path] = importlib.reload( | |||
sys.modules[import_path] | |||
) | |||
dependent_module_names = [ | |||
name | |||
for name in sys.modules.keys() | |||
if name.startswith(import_path + ".") | |||
] | |||
for name in dependent_module_names: | |||
logger.info("reload module %s" % name) | |||
importlib.reload(sys.modules[name]) | |||
else: | |||
self.loaded[plugin_path] = importlib.import_module(import_path) | |||
self.loaded[plugin_path] = importlib.import_module( | |||
import_path | |||
) | |||
self.current_plugin_path = None | |||
except Exception as e: | |||
logger.exception("Failed to import plugin %s: %s" % (plugin_name, e)) | |||
logger.exception( | |||
"Failed to import plugin %s: %s" % (plugin_name, e) | |||
) | |||
continue | |||
pconf = self.pconf | |||
news = [self.plugins[name] for name in self.plugins] | |||
@@ -95,21 +119,28 @@ class PluginManager: | |||
rawname = plugincls.name | |||
if rawname not in pconf["plugins"]: | |||
modified = True | |||
logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name) | |||
pconf["plugins"][rawname] = {"enabled": plugincls.enabled, "priority": plugincls.priority} | |||
logger.info( | |||
"Plugin %s not found in pconfig, adding to pconfig..." % name | |||
) | |||
pconf["plugins"][rawname] = { | |||
"enabled": plugincls.enabled, | |||
"priority": plugincls.priority, | |||
} | |||
else: | |||
self.plugins[name].enabled = pconf["plugins"][rawname]["enabled"] | |||
self.plugins[name].priority = pconf["plugins"][rawname]["priority"] | |||
self.plugins._update_heap(name) # 更新下plugins中的顺序 | |||
self.plugins._update_heap(name) # 更新下plugins中的顺序 | |||
if modified: | |||
self.save_config() | |||
return new_plugins | |||
def refresh_order(self): | |||
for event in self.listening_plugins.keys(): | |||
self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True) | |||
self.listening_plugins[event].sort( | |||
key=lambda name: self.plugins[name].priority, reverse=True | |||
) | |||
def activate_plugins(self): # 生成新开启的插件实例 | |||
def activate_plugins(self): # 生成新开启的插件实例 | |||
failed_plugins = [] | |||
for name, plugincls in self.plugins.items(): | |||
if plugincls.enabled: | |||
@@ -129,7 +160,7 @@ class PluginManager: | |||
self.refresh_order() | |||
return failed_plugins | |||
def reload_plugin(self, name:str): | |||
def reload_plugin(self, name: str): | |||
name = name.upper() | |||
if name in self.instances: | |||
for event in self.listening_plugins: | |||
@@ -139,13 +170,13 @@ class PluginManager: | |||
self.activate_plugins() | |||
return True | |||
return False | |||
def load_plugins(self): | |||
self.load_config() | |||
self.scan_plugins() | |||
pconf = self.pconf | |||
logger.debug("plugins.json config={}".format(pconf)) | |||
for name,plugin in pconf["plugins"].items(): | |||
for name, plugin in pconf["plugins"].items(): | |||
if name.upper() not in self.plugins: | |||
logger.error("Plugin %s not found, but found in plugins.json" % name) | |||
self.activate_plugins() | |||
@@ -153,13 +184,18 @@ class PluginManager: | |||
def emit_event(self, e_context: EventContext, *args, **kwargs): | |||
if e_context.event in self.listening_plugins: | |||
for name in self.listening_plugins[e_context.event]: | |||
if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE: | |||
logger.debug("Plugin %s triggered by event %s" % (name,e_context.event)) | |||
if ( | |||
self.plugins[name].enabled | |||
and e_context.action == EventAction.CONTINUE | |||
): | |||
logger.debug( | |||
"Plugin %s triggered by event %s" % (name, e_context.event) | |||
) | |||
instance = self.instances[name] | |||
instance.handlers[e_context.event](e_context, *args, **kwargs) | |||
return e_context | |||
def set_plugin_priority(self, name:str, priority:int): | |||
def set_plugin_priority(self, name: str, priority: int): | |||
name = name.upper() | |||
if name not in self.plugins: | |||
return False | |||
@@ -174,11 +210,11 @@ class PluginManager: | |||
self.refresh_order() | |||
return True | |||
def enable_plugin(self, name:str): | |||
def enable_plugin(self, name: str): | |||
name = name.upper() | |||
if name not in self.plugins: | |||
return False, "插件不存在" | |||
if not self.plugins[name].enabled : | |||
if not self.plugins[name].enabled: | |||
self.plugins[name].enabled = True | |||
rawname = self.plugins[name].name | |||
self.pconf["plugins"][rawname]["enabled"] = True | |||
@@ -188,43 +224,47 @@ class PluginManager: | |||
return False, "插件开启失败" | |||
return True, "插件已开启" | |||
return True, "插件已开启" | |||
def disable_plugin(self, name:str): | |||
def disable_plugin(self, name: str): | |||
name = name.upper() | |||
if name not in self.plugins: | |||
return False | |||
if self.plugins[name].enabled : | |||
if self.plugins[name].enabled: | |||
self.plugins[name].enabled = False | |||
rawname = self.plugins[name].name | |||
self.pconf["plugins"][rawname]["enabled"] = False | |||
self.save_config() | |||
return True | |||
return True | |||
def list_plugins(self): | |||
return self.plugins | |||
def install_plugin(self, repo:str): | |||
def install_plugin(self, repo: str): | |||
try: | |||
import common.package_manager as pkgmgr | |||
pkgmgr.check_dulwich() | |||
except Exception as e: | |||
logger.error("Failed to install plugin, {}".format(e)) | |||
return False, "无法导入dulwich,安装插件失败" | |||
import re | |||
from dulwich import porcelain | |||
logger.info("clone git repo: {}".format(repo)) | |||
match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo) | |||
if not match: | |||
try: | |||
with open("./plugins/source.json","r", encoding="utf-8") as f: | |||
with open("./plugins/source.json", "r", encoding="utf-8") as f: | |||
source = json.load(f) | |||
if repo in source["repo"]: | |||
repo = source["repo"][repo]["url"] | |||
match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo) | |||
match = re.match( | |||
r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo | |||
) | |||
if not match: | |||
return False, "安装插件失败,source中的仓库地址不合法" | |||
else: | |||
@@ -232,42 +272,53 @@ class PluginManager: | |||
except Exception as e: | |||
logger.error("Failed to install plugin, {}".format(e)) | |||
return False, "安装插件失败,请检查仓库地址是否正确" | |||
dirname = os.path.join("./plugins",match.group(4)) | |||
dirname = os.path.join("./plugins", match.group(4)) | |||
try: | |||
repo = porcelain.clone(repo, dirname, checkout=True) | |||
if os.path.exists(os.path.join(dirname,"requirements.txt")): | |||
if os.path.exists(os.path.join(dirname, "requirements.txt")): | |||
logger.info("detect requirements.txt,installing...") | |||
pkgmgr.install_requirements(os.path.join(dirname,"requirements.txt")) | |||
pkgmgr.install_requirements(os.path.join(dirname, "requirements.txt")) | |||
return True, "安装插件成功,请使用 #scanp 命令扫描插件或重启程序,开启前请检查插件是否需要配置" | |||
except Exception as e: | |||
logger.error("Failed to install plugin, {}".format(e)) | |||
return False, "安装插件失败,"+str(e) | |||
def update_plugin(self, name:str): | |||
return False, "安装插件失败," + str(e) | |||
def update_plugin(self, name: str): | |||
try: | |||
import common.package_manager as pkgmgr | |||
pkgmgr.check_dulwich() | |||
except Exception as e: | |||
logger.error("Failed to install plugin, {}".format(e)) | |||
return False, "无法导入dulwich,更新插件失败" | |||
from dulwich import porcelain | |||
name = name.upper() | |||
if name not in self.plugins: | |||
return False, "插件不存在" | |||
if name in ["HELLO","GODCMD","ROLE","TOOL","BDUNIT","BANWORDS","FINISH","DUNGEON"]: | |||
if name in [ | |||
"HELLO", | |||
"GODCMD", | |||
"ROLE", | |||
"TOOL", | |||
"BDUNIT", | |||
"BANWORDS", | |||
"FINISH", | |||
"DUNGEON", | |||
]: | |||
return False, "预置插件无法更新,请更新主程序仓库" | |||
dirname = self.plugins[name].path | |||
try: | |||
porcelain.pull(dirname, "origin") | |||
if os.path.exists(os.path.join(dirname,"requirements.txt")): | |||
if os.path.exists(os.path.join(dirname, "requirements.txt")): | |||
logger.info("detect requirements.txt,installing...") | |||
pkgmgr.install_requirements(os.path.join(dirname,"requirements.txt")) | |||
pkgmgr.install_requirements(os.path.join(dirname, "requirements.txt")) | |||
return True, "更新插件成功,请重新运行程序" | |||
except Exception as e: | |||
logger.error("Failed to update plugin, {}".format(e)) | |||
return False, "更新插件失败,"+str(e) | |||
def uninstall_plugin(self, name:str): | |||
return False, "更新插件失败," + str(e) | |||
def uninstall_plugin(self, name: str): | |||
name = name.upper() | |||
if name not in self.plugins: | |||
return False, "插件不存在" | |||
@@ -276,6 +327,7 @@ class PluginManager: | |||
dirname = self.plugins[name].path | |||
try: | |||
import shutil | |||
shutil.rmtree(dirname) | |||
rawname = self.plugins[name].name | |||
for event in self.listening_plugins: | |||
@@ -288,4 +340,4 @@ class PluginManager: | |||
return True, "卸载插件成功" | |||
except Exception as e: | |||
logger.error("Failed to uninstall plugin, {}".format(e)) | |||
return False, "卸载插件失败,请手动删除文件夹完成卸载,"+str(e) | |||
return False, "卸载插件失败,请手动删除文件夹完成卸载," + str(e) |
@@ -1 +1 @@ | |||
from .role import * | |||
from .role import * |
@@ -2,17 +2,18 @@ | |||
import json | |||
import os | |||
import plugins | |||
from bridge.bridge import Bridge | |||
from bridge.context import ContextType | |||
from bridge.reply import Reply, ReplyType | |||
from common import const | |||
from common.log import logger | |||
from config import conf | |||
import plugins | |||
from plugins import * | |||
from common.log import logger | |||
class RolePlay(): | |||
class RolePlay: | |||
def __init__(self, bot, sessionid, desc, wrapper=None): | |||
self.bot = bot | |||
self.sessionid = sessionid | |||
@@ -25,12 +26,20 @@ class RolePlay(): | |||
def action(self, user_action): | |||
session = self.bot.sessions.build_session(self.sessionid) | |||
if session.system_prompt != self.desc: # 目前没有触发session过期事件,这里先简单判断,然后重置 | |||
if session.system_prompt != self.desc: # 目前没有触发session过期事件,这里先简单判断,然后重置 | |||
session.set_system_prompt(self.desc) | |||
prompt = self.wrapper % user_action | |||
return prompt | |||
@plugins.register(name="Role", desire_priority=0, namecn="角色扮演", desc="为你的Bot设置预设角色", version="1.0", author="lanvent") | |||
@plugins.register( | |||
name="Role", | |||
desire_priority=0, | |||
namecn="角色扮演", | |||
desc="为你的Bot设置预设角色", | |||
version="1.0", | |||
author="lanvent", | |||
) | |||
class Role(Plugin): | |||
def __init__(self): | |||
super().__init__() | |||
@@ -39,7 +48,7 @@ class Role(Plugin): | |||
try: | |||
with open(config_path, "r", encoding="utf-8") as f: | |||
config = json.load(f) | |||
self.tags = { tag:(desc,[]) for tag,desc in config["tags"].items()} | |||
self.tags = {tag: (desc, []) for tag, desc in config["tags"].items()} | |||
self.roles = {} | |||
for role in config["roles"]: | |||
self.roles[role["title"].lower()] = role | |||
@@ -60,12 +69,16 @@ class Role(Plugin): | |||
logger.info("[Role] inited") | |||
except Exception as e: | |||
if isinstance(e, FileNotFoundError): | |||
logger.warn(f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .") | |||
logger.warn( | |||
f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ." | |||
) | |||
else: | |||
logger.warn("[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .") | |||
logger.warn( | |||
"[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ." | |||
) | |||
raise e | |||
def get_role(self, name, find_closest=True, min_sim = 0.35): | |||
def get_role(self, name, find_closest=True, min_sim=0.35): | |||
name = name.lower() | |||
found_role = None | |||
if name in self.roles: | |||
@@ -75,6 +88,7 @@ class Role(Plugin): | |||
def str_simularity(a, b): | |||
return difflib.SequenceMatcher(None, a, b).ratio() | |||
max_sim = min_sim | |||
max_role = None | |||
for role in self.roles: | |||
@@ -86,25 +100,24 @@ class Role(Plugin): | |||
return found_role | |||
def on_handle_context(self, e_context: EventContext): | |||
if e_context['context'].type != ContextType.TEXT: | |||
if e_context["context"].type != ContextType.TEXT: | |||
return | |||
bottype = Bridge().get_bot_type("chat") | |||
if bottype not in (const.CHATGPT, const.OPEN_AI): | |||
return | |||
bot = Bridge().get_bot("chat") | |||
content = e_context['context'].content[:] | |||
clist = e_context['context'].content.split(maxsplit=1) | |||
content = e_context["context"].content[:] | |||
clist = e_context["context"].content.split(maxsplit=1) | |||
desckey = None | |||
customize = False | |||
sessionid = e_context['context']['session_id'] | |||
trigger_prefix = conf().get('plugin_trigger_prefix', "$") | |||
sessionid = e_context["context"]["session_id"] | |||
trigger_prefix = conf().get("plugin_trigger_prefix", "$") | |||
if clist[0] == f"{trigger_prefix}停止扮演": | |||
if sessionid in self.roleplays: | |||
self.roleplays[sessionid].reset() | |||
del self.roleplays[sessionid] | |||
reply = Reply(ReplyType.INFO, "角色扮演结束!") | |||
e_context['reply'] = reply | |||
e_context["reply"] = reply | |||
e_context.action = EventAction.BREAK_PASS | |||
return | |||
elif clist[0] == f"{trigger_prefix}角色": | |||
@@ -114,10 +127,10 @@ class Role(Plugin): | |||
elif clist[0] == f"{trigger_prefix}设定扮演": | |||
customize = True | |||
elif clist[0] == f"{trigger_prefix}角色类型": | |||
if len(clist) >1: | |||
if len(clist) > 1: | |||
tag = clist[1].strip() | |||
help_text = "角色列表:\n" | |||
for key,value in self.tags.items(): | |||
for key, value in self.tags.items(): | |||
if value[0] == tag: | |||
tag = key | |||
break | |||
@@ -130,57 +143,75 @@ class Role(Plugin): | |||
else: | |||
help_text = f"未知角色类型。\n" | |||
help_text += "目前的角色类型有: \n" | |||
help_text += ",".join([self.tags[tag][0] for tag in self.tags])+"\n" | |||
help_text += ( | |||
",".join([self.tags[tag][0] for tag in self.tags]) + "\n" | |||
) | |||
else: | |||
help_text = f"请输入角色类型。\n" | |||
help_text += "目前的角色类型有: \n" | |||
help_text += ",".join([self.tags[tag][0] for tag in self.tags])+"\n" | |||
help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "\n" | |||
reply = Reply(ReplyType.INFO, help_text) | |||
e_context['reply'] = reply | |||
e_context["reply"] = reply | |||
e_context.action = EventAction.BREAK_PASS | |||
return | |||
elif sessionid not in self.roleplays: | |||
return | |||
logger.debug("[Role] on_handle_context. content: %s" % content) | |||
if desckey is not None: | |||
if len(clist) == 1 or (len(clist) > 1 and clist[1].lower() in ["help", "帮助"]): | |||
if len(clist) == 1 or ( | |||
len(clist) > 1 and clist[1].lower() in ["help", "帮助"] | |||
): | |||
reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True)) | |||
e_context['reply'] = reply | |||
e_context["reply"] = reply | |||
e_context.action = EventAction.BREAK_PASS | |||
return | |||
role = self.get_role(clist[1]) | |||
if role is None: | |||
reply = Reply(ReplyType.ERROR, "角色不存在") | |||
e_context['reply'] = reply | |||
e_context["reply"] = reply | |||
e_context.action = EventAction.BREAK_PASS | |||
return | |||
else: | |||
self.roleplays[sessionid] = RolePlay(bot, sessionid, self.roles[role][desckey], self.roles[role].get("wrapper","%s")) | |||
reply = Reply(ReplyType.INFO, f"预设角色为 {role}:\n"+self.roles[role][desckey]) | |||
e_context['reply'] = reply | |||
self.roleplays[sessionid] = RolePlay( | |||
bot, | |||
sessionid, | |||
self.roles[role][desckey], | |||
self.roles[role].get("wrapper", "%s"), | |||
) | |||
reply = Reply( | |||
ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey] | |||
) | |||
e_context["reply"] = reply | |||
e_context.action = EventAction.BREAK_PASS | |||
elif customize == True: | |||
self.roleplays[sessionid] = RolePlay(bot, sessionid, clist[1], "%s") | |||
reply = Reply(ReplyType.INFO, f"角色设定为:\n{clist[1]}") | |||
e_context['reply'] = reply | |||
e_context["reply"] = reply | |||
e_context.action = EventAction.BREAK_PASS | |||
else: | |||
prompt = self.roleplays[sessionid].action(content) | |||
e_context['context'].type = ContextType.TEXT | |||
e_context['context'].content = prompt | |||
e_context["context"].type = ContextType.TEXT | |||
e_context["context"].content = prompt | |||
e_context.action = EventAction.BREAK | |||
def get_help_text(self, verbose=False, **kwargs): | |||
help_text = "让机器人扮演不同的角色。\n" | |||
if not verbose: | |||
return help_text | |||
trigger_prefix = conf().get('plugin_trigger_prefix', "$") | |||
help_text = f"使用方法:\n{trigger_prefix}角色"+" 预设角色名: 设定角色为{预设角色名}。\n"+f"{trigger_prefix}role"+" 预设角色名: 同上,但使用英文设定。\n" | |||
help_text += f"{trigger_prefix}设定扮演"+" 角色设定: 设定自定义角色人设为{角色设定}。\n" | |||
trigger_prefix = conf().get("plugin_trigger_prefix", "$") | |||
help_text = ( | |||
f"使用方法:\n{trigger_prefix}角色" | |||
+ " 预设角色名: 设定角色为{预设角色名}。\n" | |||
+ f"{trigger_prefix}role" | |||
+ " 预设角色名: 同上,但使用英文设定。\n" | |||
) | |||
help_text += f"{trigger_prefix}设定扮演" + " 角色设定: 设定自定义角色人设为{角色设定}。\n" | |||
help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n" | |||
help_text += f"{trigger_prefix}角色类型"+" 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n" | |||
help_text += ( | |||
f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n" | |||
) | |||
help_text += "\n目前的角色类型有: \n" | |||
help_text += ",".join([self.tags[tag][0] for tag in self.tags])+"。\n" | |||
help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "。\n" | |||
help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n" | |||
help_text += f"{trigger_prefix}角色类型 所有\n" | |||
help_text += f"{trigger_prefix}停止扮演\n" | |||
@@ -428,4 +428,4 @@ | |||
] | |||
} | |||
] | |||
} | |||
} |
@@ -1,16 +1,16 @@ | |||
{ | |||
"repo": { | |||
"sdwebui": { | |||
"url": "https://github.com/lanvent/plugin_sdwebui.git", | |||
"desc": "利用stable-diffusion画图的插件" | |||
}, | |||
"replicate": { | |||
"url": "https://github.com/lanvent/plugin_replicate.git", | |||
"desc": "利用replicate api画图的插件" | |||
}, | |||
"summary": { | |||
"url": "https://github.com/lanvent/plugin_summary.git", | |||
"desc": "总结聊天记录的插件" | |||
} | |||
"repo": { | |||
"sdwebui": { | |||
"url": "https://github.com/lanvent/plugin_sdwebui.git", | |||
"desc": "利用stable-diffusion画图的插件" | |||
}, | |||
"replicate": { | |||
"url": "https://github.com/lanvent/plugin_replicate.git", | |||
"desc": "利用replicate api画图的插件" | |||
}, | |||
"summary": { | |||
"url": "https://github.com/lanvent/plugin_summary.git", | |||
"desc": "总结聊天记录的插件" | |||
} | |||
} | |||
} | |||
} |
@@ -1,14 +1,14 @@ | |||
## 插件描述 | |||
一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力 | |||
一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力 | |||
使用该插件需在机器人回复你的前提下,在对话内容前加$tool;仅输入$tool将返回tool插件帮助信息,用于测试插件是否加载成功 | |||
### 本插件所有工具同步存放至专用仓库:[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub) | |||
## 使用说明 | |||
使用该插件后将默认使用4个工具, 无需额外配置长期生效: | |||
### 1. python | |||
使用该插件后将默认使用4个工具, 无需额外配置长期生效: | |||
### 1. python | |||
###### python解释器,使用它来解释执行python指令,可以配合你想要chatgpt生成的代码输出结果或执行事务 | |||
### 2. url-get | |||
###### 往往用来获取某个网站具体内容,结果可能会被反爬策略影响 | |||
@@ -23,16 +23,16 @@ | |||
> meteo调优记录:https://github.com/zhayujie/chatgpt-on-wechat/issues/776#issuecomment-1500771334 | |||
## 使用本插件对话(prompt)技巧 | |||
### 1. 有指引的询问 | |||
## 使用本插件对话(prompt)技巧 | |||
### 1. 有指引的询问 | |||
#### 例如: | |||
- 总结这个链接的内容 https://github.com/goldfishh/chatgpt-tool-hub | |||
- 总结这个链接的内容 https://github.com/goldfishh/chatgpt-tool-hub | |||
- 使用Terminal执行curl cip.cc | |||
- 使用python查询今天日期 | |||
### 2. 使用搜索引擎工具 | |||
- 如果有搜索工具就能让chatgpt获取到你的未传达清楚的上下文信息,比如chatgpt不知道你的地理位置,现在时间等,所以无法查询到天气 | |||
## 其他工具 | |||
### 5. wikipedia | |||
@@ -55,9 +55,9 @@ | |||
### 10. google-search * | |||
###### google搜索引擎,申请流程较bing-search繁琐 | |||
###### 注1:带*工具需要获取api-key才能使用,部分工具需要外网支持 | |||
###### 注1:带*工具需要获取api-key才能使用,部分工具需要外网支持 | |||
#### [申请方法](https://github.com/goldfishh/chatgpt-tool-hub/blob/master/docs/apply_optional_tool.md) | |||
## config.json 配置说明 | |||
###### 默认工具无需配置,其它工具需手动配置,一个例子: | |||
```json | |||
@@ -71,15 +71,15 @@ | |||
} | |||
``` | |||
注:config.json文件非必须,未创建仍可使用本tool;带*工具需在kwargs填入对应api-key键值对 | |||
注:config.json文件非必须,未创建仍可使用本tool;带*工具需在kwargs填入对应api-key键值对 | |||
- `tools`:本插件初始化时加载的工具, 目前可选集:["wikipedia", "wolfram-alpha", "bing-search", "google-search", "news", "morning-news"] & 默认工具,除wikipedia工具之外均需要申请api-key | |||
- `kwargs`:工具执行时的配置,一般在这里存放**api-key**,或环境配置 | |||
- `request_timeout`: 访问openai接口的超时时间,默认与wechat-on-chatgpt配置一致,可单独配置 | |||
- `no_default`: 用于配置默认加载4个工具的行为,如果为true则仅使用tools列表工具,不加载默认工具 | |||
- `top_k_results`: 控制所有有关搜索的工具返回条目数,数字越高则参考信息越多,但无用信息可能干扰判断,该值一般为2 | |||
- `model_name`: 用于控制tool插件底层使用的llm模型,目前暂未测试3.5以外的模型,一般保持默认 | |||
## 备注 | |||
- 强烈建议申请搜索工具搭配使用,推荐bing-search | |||
- 虽然我会有意加入一些限制,但请不要使用本插件做危害他人的事情,请提前了解清楚某些内容是否会违反相关规定,建议提前做好过滤 | |||
@@ -1 +1 @@ | |||
from .tool import * | |||
from .tool import * |
@@ -1,8 +1,13 @@ | |||
{ | |||
"tools": ["python", "url-get", "terminal", "meteo-weather"], | |||
"tools": [ | |||
"python", | |||
"url-get", | |||
"terminal", | |||
"meteo-weather" | |||
], | |||
"kwargs": { | |||
"top_k_results": 2, | |||
"no_default": false, | |||
"model_name": "gpt-3.5-turbo" | |||
"top_k_results": 2, | |||
"no_default": false, | |||
"model_name": "gpt-3.5-turbo" | |||
} | |||
} | |||
} |
@@ -4,6 +4,7 @@ import os | |||
from chatgpt_tool_hub.apps import load_app | |||
from chatgpt_tool_hub.apps.app import App | |||
from chatgpt_tool_hub.tools.all_tool_list import get_all_tool_names | |||
import plugins | |||
from bridge.bridge import Bridge | |||
from bridge.context import ContextType | |||
@@ -14,7 +15,13 @@ from config import conf | |||
from plugins import * | |||
@plugins.register(name="tool", desc="Arming your ChatGPT bot with various tools", version="0.3", author="goldfishh", desire_priority=0) | |||
@plugins.register( | |||
name="tool", | |||
desc="Arming your ChatGPT bot with various tools", | |||
version="0.3", | |||
author="goldfishh", | |||
desire_priority=0, | |||
) | |||
class Tool(Plugin): | |||
def __init__(self): | |||
super().__init__() | |||
@@ -28,22 +35,26 @@ class Tool(Plugin): | |||
help_text = "这是一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力。" | |||
if not verbose: | |||
return help_text | |||
trigger_prefix = conf().get('plugin_trigger_prefix', "$") | |||
trigger_prefix = conf().get("plugin_trigger_prefix", "$") | |||
help_text += "使用说明:\n" | |||
help_text += f"{trigger_prefix}tool "+"命令: 根据给出的{命令}使用一些可用工具尽力为你得到结果。\n" | |||
help_text += f"{trigger_prefix}tool " + "命令: 根据给出的{命令}使用一些可用工具尽力为你得到结果。\n" | |||
help_text += f"{trigger_prefix}tool reset: 重置工具。\n" | |||
return help_text | |||
def on_handle_context(self, e_context: EventContext): | |||
if e_context['context'].type != ContextType.TEXT: | |||
if e_context["context"].type != ContextType.TEXT: | |||
return | |||
# 暂时不支持未来扩展的bot | |||
if Bridge().get_bot_type("chat") not in (const.CHATGPT, const.OPEN_AI, const.CHATGPTONAZURE): | |||
if Bridge().get_bot_type("chat") not in ( | |||
const.CHATGPT, | |||
const.OPEN_AI, | |||
const.CHATGPTONAZURE, | |||
): | |||
return | |||
content = e_context['context'].content | |||
content_list = e_context['context'].content.split(maxsplit=1) | |||
content = e_context["context"].content | |||
content_list = e_context["context"].content.split(maxsplit=1) | |||
if not content or len(content_list) < 1: | |||
e_context.action = EventAction.CONTINUE | |||
@@ -52,13 +63,13 @@ class Tool(Plugin): | |||
logger.debug("[tool] on_handle_context. content: %s" % content) | |||
reply = Reply() | |||
reply.type = ReplyType.TEXT | |||
trigger_prefix = conf().get('plugin_trigger_prefix', "$") | |||
trigger_prefix = conf().get("plugin_trigger_prefix", "$") | |||
# todo: 有些工具必须要api-key,需要修改config文件,所以这里没有实现query增删tool的功能 | |||
if content.startswith(f"{trigger_prefix}tool"): | |||
if len(content_list) == 1: | |||
logger.debug("[tool]: get help") | |||
reply.content = self.get_help_text() | |||
e_context['reply'] = reply | |||
e_context["reply"] = reply | |||
e_context.action = EventAction.BREAK_PASS | |||
return | |||
elif len(content_list) > 1: | |||
@@ -66,12 +77,14 @@ class Tool(Plugin): | |||
logger.debug("[tool]: reset config") | |||
self.app = self._reset_app() | |||
reply.content = "重置工具成功" | |||
e_context['reply'] = reply | |||
e_context["reply"] = reply | |||
e_context.action = EventAction.BREAK_PASS | |||
return | |||
elif content_list[1].startswith("reset"): | |||
logger.debug("[tool]: remind") | |||
e_context['context'].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符" | |||
e_context[ | |||
"context" | |||
].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符" | |||
e_context.action = EventAction.BREAK | |||
return | |||
@@ -80,34 +93,35 @@ class Tool(Plugin): | |||
# Don't modify bot name | |||
all_sessions = Bridge().get_bot("chat").sessions | |||
user_session = all_sessions.session_query(query, e_context['context']['session_id']).messages | |||
user_session = all_sessions.session_query( | |||
query, e_context["context"]["session_id"] | |||
).messages | |||
# chatgpt-tool-hub will reply you with many tools | |||
logger.debug("[tool]: just-go") | |||
try: | |||
_reply = self.app.ask(query, user_session) | |||
e_context.action = EventAction.BREAK_PASS | |||
all_sessions.session_reply(_reply, e_context['context']['session_id']) | |||
all_sessions.session_reply( | |||
_reply, e_context["context"]["session_id"] | |||
) | |||
except Exception as e: | |||
logger.exception(e) | |||
logger.error(str(e)) | |||
e_context['context'].content = "请你随机用一种聊天风格,提醒用户:这个问题tool插件暂时无法处理" | |||
e_context["context"].content = "请你随机用一种聊天风格,提醒用户:这个问题tool插件暂时无法处理" | |||
reply.type = ReplyType.ERROR | |||
e_context.action = EventAction.BREAK | |||
return | |||
reply.content = _reply | |||
e_context['reply'] = reply | |||
e_context["reply"] = reply | |||
return | |||
def _read_json(self) -> dict: | |||
curdir = os.path.dirname(__file__) | |||
config_path = os.path.join(curdir, "config.json") | |||
tool_config = { | |||
"tools": [], | |||
"kwargs": {} | |||
} | |||
tool_config = {"tools": [], "kwargs": {}} | |||
if not os.path.exists(config_path): | |||
return tool_config | |||
else: | |||
@@ -123,7 +137,9 @@ class Tool(Plugin): | |||
"proxy": conf().get("proxy", ""), | |||
"request_timeout": conf().get("request_timeout", 60), | |||
# note: 目前tool暂未对其他模型测试,但这里仍对配置来源做了优先级区分,一般插件配置可覆盖全局配置 | |||
"model_name": tool_model_name if tool_model_name else conf().get("model", "gpt-3.5-turbo"), | |||
"model_name": tool_model_name | |||
if tool_model_name | |||
else conf().get("model", "gpt-3.5-turbo"), | |||
"no_default": kwargs.get("no_default", False), | |||
"top_k_results": kwargs.get("top_k_results", 2), | |||
# for news tool | |||
@@ -160,4 +176,7 @@ class Tool(Plugin): | |||
# filter not support tool | |||
tool_list = self._filter_tool_list(tool_config.get("tools", [])) | |||
return load_app(tools_list=tool_list, **self._build_tool_kwargs(tool_config.get("kwargs", {}))) | |||
return load_app( | |||
tools_list=tool_list, | |||
**self._build_tool_kwargs(tool_config.get("kwargs", {})), | |||
) |
@@ -4,3 +4,4 @@ PyQRCode>=1.2.1 | |||
qrcode>=7.4.2 | |||
requests>=2.28.2 | |||
chardet>=5.1.0 | |||
pre-commit |
@@ -8,7 +8,7 @@ echo $BASE_DIR | |||
# check the nohup.out log output file | |||
if [ ! -f "${BASE_DIR}/nohup.out" ]; then | |||
touch "${BASE_DIR}/nohup.out" | |||
echo "create file ${BASE_DIR}/nohup.out" | |||
echo "create file ${BASE_DIR}/nohup.out" | |||
fi | |||
nohup python3 "${BASE_DIR}/app.py" & tail -f "${BASE_DIR}/nohup.out" | |||
@@ -7,7 +7,7 @@ echo $BASE_DIR | |||
# check the nohup.out log output file | |||
if [ ! -f "${BASE_DIR}/nohup.out" ]; then | |||
echo "No file ${BASE_DIR}/nohup.out" | |||
echo "No file ${BASE_DIR}/nohup.out" | |||
exit -1; | |||
fi | |||
@@ -1,9 +1,12 @@ | |||
import shutil | |||
import wave | |||
import pysilk | |||
from pydub import AudioSegment | |||
sil_supports=[8000, 12000, 16000, 24000, 32000, 44100, 48000] # slk转wav时,支持的采样率 | |||
sil_supports = [8000, 12000, 16000, 24000, 32000, 44100, 48000] # slk转wav时,支持的采样率 | |||
def find_closest_sil_supports(sample_rate): | |||
""" | |||
找到最接近的支持的采样率 | |||
@@ -19,6 +22,7 @@ def find_closest_sil_supports(sample_rate): | |||
mindiff = diff | |||
return closest | |||
def get_pcm_from_wav(wav_path): | |||
""" | |||
从 wav 文件中读取 pcm | |||
@@ -29,31 +33,42 @@ def get_pcm_from_wav(wav_path): | |||
wav = wave.open(wav_path, "rb") | |||
return wav.readframes(wav.getnframes()) | |||
def any_to_wav(any_path, wav_path): | |||
""" | |||
把任意格式转成wav文件 | |||
""" | |||
if any_path.endswith('.wav'): | |||
if any_path.endswith(".wav"): | |||
shutil.copy2(any_path, wav_path) | |||
return | |||
if any_path.endswith('.sil') or any_path.endswith('.silk') or any_path.endswith('.slk'): | |||
if ( | |||
any_path.endswith(".sil") | |||
or any_path.endswith(".silk") | |||
or any_path.endswith(".slk") | |||
): | |||
return sil_to_wav(any_path, wav_path) | |||
audio = AudioSegment.from_file(any_path) | |||
audio.export(wav_path, format="wav") | |||
def any_to_sil(any_path, sil_path): | |||
""" | |||
把任意格式转成sil文件 | |||
""" | |||
if any_path.endswith('.sil') or any_path.endswith('.silk') or any_path.endswith('.slk'): | |||
if ( | |||
any_path.endswith(".sil") | |||
or any_path.endswith(".silk") | |||
or any_path.endswith(".slk") | |||
): | |||
shutil.copy2(any_path, sil_path) | |||
return 10000 | |||
if any_path.endswith('.wav'): | |||
if any_path.endswith(".wav"): | |||
return pcm_to_sil(any_path, sil_path) | |||
if any_path.endswith('.mp3'): | |||
if any_path.endswith(".mp3"): | |||
return mp3_to_sil(any_path, sil_path) | |||
raise NotImplementedError("Not support file type: {}".format(any_path)) | |||
def mp3_to_wav(mp3_path, wav_path): | |||
""" | |||
把mp3格式转成pcm文件 | |||
@@ -61,6 +76,7 @@ def mp3_to_wav(mp3_path, wav_path): | |||
audio = AudioSegment.from_mp3(mp3_path) | |||
audio.export(wav_path, format="wav") | |||
def pcm_to_sil(pcm_path, silk_path): | |||
""" | |||
wav 文件转成 silk | |||
@@ -72,12 +88,12 @@ def pcm_to_sil(pcm_path, silk_path): | |||
pcm_s16 = audio.set_sample_width(2) | |||
pcm_s16 = pcm_s16.set_frame_rate(rate) | |||
wav_data = pcm_s16.raw_data | |||
silk_data = pysilk.encode( | |||
wav_data, data_rate=rate, sample_rate=rate) | |||
silk_data = pysilk.encode(wav_data, data_rate=rate, sample_rate=rate) | |||
with open(silk_path, "wb") as f: | |||
f.write(silk_data) | |||
return audio.duration_seconds * 1000 | |||
def mp3_to_sil(mp3_path, silk_path): | |||
""" | |||
mp3 文件转成 silk | |||
@@ -95,6 +111,7 @@ def mp3_to_sil(mp3_path, silk_path): | |||
f.write(silk_data) | |||
return audio.duration_seconds * 1000 | |||
def sil_to_wav(silk_path, wav_path, rate: int = 24000): | |||
""" | |||
silk 文件转 wav | |||
@@ -1,16 +1,18 @@ | |||
""" | |||
azure voice service | |||
""" | |||
import json | |||
import os | |||
import time | |||
import azure.cognitiveservices.speech as speechsdk | |||
from bridge.reply import Reply, ReplyType | |||
from common.log import logger | |||
from common.tmp_dir import TmpDir | |||
from voice.voice import Voice | |||
from config import conf | |||
from voice.voice import Voice | |||
""" | |||
Azure voice | |||
主目录设置文件中需填写azure_voice_api_key和azure_voice_region | |||
@@ -19,50 +21,68 @@ Azure voice | |||
""" | |||
class AzureVoice(Voice): | |||
class AzureVoice(Voice): | |||
def __init__(self): | |||
try: | |||
curdir = os.path.dirname(__file__) | |||
config_path = os.path.join(curdir, "config.json") | |||
config = None | |||
if not os.path.exists(config_path): #如果没有配置文件,创建本地配置文件 | |||
config = { "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural", "speech_recognition_language": "zh-CN"} | |||
if not os.path.exists(config_path): # 如果没有配置文件,创建本地配置文件 | |||
config = { | |||
"speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural", | |||
"speech_recognition_language": "zh-CN", | |||
} | |||
with open(config_path, "w") as fw: | |||
json.dump(config, fw, indent=4) | |||
else: | |||
with open(config_path, "r") as fr: | |||
config = json.load(fr) | |||
self.api_key = conf().get('azure_voice_api_key') | |||
self.api_region = conf().get('azure_voice_region') | |||
self.speech_config = speechsdk.SpeechConfig(subscription=self.api_key, region=self.api_region) | |||
self.speech_config.speech_synthesis_voice_name = config["speech_synthesis_voice_name"] | |||
self.speech_config.speech_recognition_language = config["speech_recognition_language"] | |||
self.api_key = conf().get("azure_voice_api_key") | |||
self.api_region = conf().get("azure_voice_region") | |||
self.speech_config = speechsdk.SpeechConfig( | |||
subscription=self.api_key, region=self.api_region | |||
) | |||
self.speech_config.speech_synthesis_voice_name = config[ | |||
"speech_synthesis_voice_name" | |||
] | |||
self.speech_config.speech_recognition_language = config[ | |||
"speech_recognition_language" | |||
] | |||
except Exception as e: | |||
logger.warn("AzureVoice init failed: %s, ignore " % e) | |||
def voiceToText(self, voice_file): | |||
audio_config = speechsdk.AudioConfig(filename=voice_file) | |||
speech_recognizer = speechsdk.SpeechRecognizer(speech_config=self.speech_config, audio_config=audio_config) | |||
speech_recognizer = speechsdk.SpeechRecognizer( | |||
speech_config=self.speech_config, audio_config=audio_config | |||
) | |||
result = speech_recognizer.recognize_once() | |||
if result.reason == speechsdk.ResultReason.RecognizedSpeech: | |||
logger.info('[Azure] voiceToText voice file name={} text={}'.format(voice_file, result.text)) | |||
logger.info( | |||
"[Azure] voiceToText voice file name={} text={}".format( | |||
voice_file, result.text | |||
) | |||
) | |||
reply = Reply(ReplyType.TEXT, result.text) | |||
else: | |||
logger.error('[Azure] voiceToText error, result={}'.format(result)) | |||
logger.error("[Azure] voiceToText error, result={}".format(result)) | |||
reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败") | |||
return reply | |||
def textToVoice(self, text): | |||
fileName = TmpDir().path() + 'reply-' + str(int(time.time())) + '.wav' | |||
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav" | |||
audio_config = speechsdk.AudioConfig(filename=fileName) | |||
speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=self.speech_config, audio_config=audio_config) | |||
speech_synthesizer = speechsdk.SpeechSynthesizer( | |||
speech_config=self.speech_config, audio_config=audio_config | |||
) | |||
result = speech_synthesizer.speak_text(text) | |||
if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted: | |||
logger.info( | |||
'[Azure] textToVoice text={} voice file name={}'.format(text, fileName)) | |||
"[Azure] textToVoice text={} voice file name={}".format(text, fileName) | |||
) | |||
reply = Reply(ReplyType.VOICE, fileName) | |||
else: | |||
logger.error('[Azure] textToVoice error, result={}'.format(result)) | |||
logger.error("[Azure] textToVoice error, result={}".format(result)) | |||
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败") | |||
return reply |
@@ -1,4 +1,4 @@ | |||
{ | |||
"speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural", | |||
"speech_recognition_language": "zh-CN" | |||
} | |||
"speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural", | |||
"speech_recognition_language": "zh-CN" | |||
} |
@@ -29,7 +29,7 @@ dev_pid 必填 语言选择,填写语言对应的dev_pid值 | |||
2、对于def textToVoice(self, text)函数中调用的百度语音合成API,中接口调用synthesis(参数)在本目录下的`config.json`文件中进行配置。 | |||
参数 可需 描述 | |||
tex 必填 合成的文本,使用UTF-8编码,请注意文本长度必须小于1024字节 | |||
tex 必填 合成的文本,使用UTF-8编码,请注意文本长度必须小于1024字节 | |||
lan 必填 固定值zh。语言选择,目前只有中英文混合模式,填写固定值zh | |||
spd 选填 语速,取值0-15,默认为5中语速 | |||
pit 选填 音调,取值0-15,默认为5中语调 | |||
@@ -40,14 +40,14 @@ aue 选填 3为mp3格式(默认); 4为pcm-16k;5为pcm-8k;6为wav | |||
关于per参数的说明,注意您购买的哪个音库,就填写哪个音库的参数,否则会报错。如果您购买的是基础音库,那么per参数只能填写0到4,如果您购买的是精品音库,那么per参数只能填写5003,5118,106,110,111,103,5其他的都会报错。 | |||
### 配置文件 | |||
将文件夹中`config.json.template`复制为`config.json`。 | |||
``` json | |||
{ | |||
"lang": "zh", | |||
"lang": "zh", | |||
"ctp": 1, | |||
"spd": 5, | |||
"spd": 5, | |||
"pit": 5, | |||
"vol": 5, | |||
"per": 0 | |||
@@ -1,17 +1,19 @@ | |||
""" | |||
baidu voice service | |||
""" | |||
import json | |||
import os | |||
import time | |||
from aip import AipSpeech | |||
from bridge.reply import Reply, ReplyType | |||
from common.log import logger | |||
from common.tmp_dir import TmpDir | |||
from voice.voice import Voice | |||
from voice.audio_convert import get_pcm_from_wav | |||
from config import conf | |||
from voice.audio_convert import get_pcm_from_wav | |||
from voice.voice import Voice | |||
""" | |||
百度的语音识别API. | |||
dev_pid: | |||
@@ -28,40 +30,37 @@ from config import conf | |||
class BaiduVoice(Voice): | |||
def __init__(self): | |||
try: | |||
curdir = os.path.dirname(__file__) | |||
config_path = os.path.join(curdir, "config.json") | |||
bconf = None | |||
if not os.path.exists(config_path): #如果没有配置文件,创建本地配置文件 | |||
bconf = { "lang": "zh", "ctp": 1, "spd": 5, | |||
"pit": 5, "vol": 5, "per": 0} | |||
if not os.path.exists(config_path): # 如果没有配置文件,创建本地配置文件 | |||
bconf = {"lang": "zh", "ctp": 1, "spd": 5, "pit": 5, "vol": 5, "per": 0} | |||
with open(config_path, "w") as fw: | |||
json.dump(bconf, fw, indent=4) | |||
else: | |||
with open(config_path, "r") as fr: | |||
bconf = json.load(fr) | |||
self.app_id = conf().get('baidu_app_id') | |||
self.api_key = conf().get('baidu_api_key') | |||
self.secret_key = conf().get('baidu_secret_key') | |||
self.dev_id = conf().get('baidu_dev_pid') | |||
self.app_id = conf().get("baidu_app_id") | |||
self.api_key = conf().get("baidu_api_key") | |||
self.secret_key = conf().get("baidu_secret_key") | |||
self.dev_id = conf().get("baidu_dev_pid") | |||
self.lang = bconf["lang"] | |||
self.ctp = bconf["ctp"] | |||
self.spd = bconf["spd"] | |||
self.pit = bconf["pit"] | |||
self.vol = bconf["vol"] | |||
self.per = bconf["per"] | |||
self.client = AipSpeech(self.app_id, self.api_key, self.secret_key) | |||
except Exception as e: | |||
logger.warn("BaiduVoice init failed: %s, ignore " % e) | |||
def voiceToText(self, voice_file): | |||
# 识别本地文件 | |||
logger.debug('[Baidu] voice file name={}'.format(voice_file)) | |||
logger.debug("[Baidu] voice file name={}".format(voice_file)) | |||
pcm = get_pcm_from_wav(voice_file) | |||
res = self.client.asr(pcm, "pcm", 16000, {"dev_pid": self.dev_id}) | |||
if res["err_no"] == 0: | |||
@@ -72,21 +71,25 @@ class BaiduVoice(Voice): | |||
logger.info("百度语音识别出错了: {}".format(res["err_msg"])) | |||
if res["err_msg"] == "request pv too much": | |||
logger.info(" 出现这个原因很可能是你的百度语音服务调用量超出限制,或未开通付费") | |||
reply = Reply(ReplyType.ERROR, | |||
"百度语音识别出错了;{0}".format(res["err_msg"])) | |||
reply = Reply(ReplyType.ERROR, "百度语音识别出错了;{0}".format(res["err_msg"])) | |||
return reply | |||
def textToVoice(self, text): | |||
result = self.client.synthesis(text, self.lang, self.ctp, { | |||
'spd': self.spd, 'pit': self.pit, 'vol': self.vol, 'per': self.per}) | |||
result = self.client.synthesis( | |||
text, | |||
self.lang, | |||
self.ctp, | |||
{"spd": self.spd, "pit": self.pit, "vol": self.vol, "per": self.per}, | |||
) | |||
if not isinstance(result, dict): | |||
fileName = TmpDir().path() + 'reply-' + str(int(time.time())) + '.mp3' | |||
with open(fileName, 'wb') as f: | |||
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3" | |||
with open(fileName, "wb") as f: | |||
f.write(result) | |||
logger.info( | |||
'[Baidu] textToVoice text={} voice file name={}'.format(text, fileName)) | |||
"[Baidu] textToVoice text={} voice file name={}".format(text, fileName) | |||
) | |||
reply = Reply(ReplyType.VOICE, fileName) | |||
else: | |||
logger.error('[Baidu] textToVoice error={}'.format(result)) | |||
logger.error("[Baidu] textToVoice error={}".format(result)) | |||
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败") | |||
return reply |
@@ -1,8 +1,8 @@ | |||
{ | |||
"lang": "zh", | |||
"ctp": 1, | |||
"spd": 5, | |||
"pit": 5, | |||
"vol": 5, | |||
"per": 0 | |||
} | |||
{ | |||
"lang": "zh", | |||
"ctp": 1, | |||
"spd": 5, | |||
"pit": 5, | |||
"vol": 5, | |||
"per": 0 | |||
} |
@@ -1,11 +1,12 @@ | |||
""" | |||
google voice service | |||
""" | |||
import time | |||
import speech_recognition | |||
from gtts import gTTS | |||
from bridge.reply import Reply, ReplyType | |||
from common.log import logger | |||
from common.tmp_dir import TmpDir | |||
@@ -22,9 +23,12 @@ class GoogleVoice(Voice): | |||
with speech_recognition.AudioFile(voice_file) as source: | |||
audio = self.recognizer.record(source) | |||
try: | |||
text = self.recognizer.recognize_google(audio, language='zh-CN') | |||
text = self.recognizer.recognize_google(audio, language="zh-CN") | |||
logger.info( | |||
'[Google] voiceToText text={} voice file name={}'.format(text, voice_file)) | |||
"[Google] voiceToText text={} voice file name={}".format( | |||
text, voice_file | |||
) | |||
) | |||
reply = Reply(ReplyType.TEXT, text) | |||
except speech_recognition.UnknownValueError: | |||
reply = Reply(ReplyType.ERROR, "抱歉,我听不懂") | |||
@@ -32,13 +36,15 @@ class GoogleVoice(Voice): | |||
reply = Reply(ReplyType.ERROR, "抱歉,无法连接到 Google 语音识别服务;{0}".format(e)) | |||
finally: | |||
return reply | |||
def textToVoice(self, text): | |||
try: | |||
mp3File = TmpDir().path() + 'reply-' + str(int(time.time())) + '.mp3' | |||
tts = gTTS(text=text, lang='zh') | |||
tts.save(mp3File) | |||
mp3File = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3" | |||
tts = gTTS(text=text, lang="zh") | |||
tts.save(mp3File) | |||
logger.info( | |||
'[Google] textToVoice text={} voice file name={}'.format(text, mp3File)) | |||
"[Google] textToVoice text={} voice file name={}".format(text, mp3File) | |||
) | |||
reply = Reply(ReplyType.VOICE, mp3File) | |||
except Exception as e: | |||
reply = Reply(ReplyType.ERROR, str(e)) | |||
@@ -1,29 +1,32 @@ | |||
""" | |||
google voice service | |||
""" | |||
import json | |||
import openai | |||
from bridge.reply import Reply, ReplyType | |||
from config import conf | |||
from common.log import logger | |||
from config import conf | |||
from voice.voice import Voice | |||
class OpenaiVoice(Voice): | |||
def __init__(self): | |||
openai.api_key = conf().get('open_ai_api_key') | |||
openai.api_key = conf().get("open_ai_api_key") | |||
def voiceToText(self, voice_file): | |||
logger.debug( | |||
'[Openai] voice file name={}'.format(voice_file)) | |||
logger.debug("[Openai] voice file name={}".format(voice_file)) | |||
try: | |||
file = open(voice_file, "rb") | |||
result = openai.Audio.transcribe("whisper-1", file) | |||
text = result["text"] | |||
reply = Reply(ReplyType.TEXT, text) | |||
logger.info( | |||
'[Openai] voiceToText text={} voice file name={}'.format(text, voice_file)) | |||
"[Openai] voiceToText text={} voice file name={}".format( | |||
text, voice_file | |||
) | |||
) | |||
except Exception as e: | |||
reply = Reply(ReplyType.ERROR, str(e)) | |||
finally: | |||
@@ -1,10 +1,11 @@ | |||
""" | |||
pytts voice service (offline) | |||
""" | |||
import time | |||
import pyttsx3 | |||
from bridge.reply import Reply, ReplyType | |||
from common.log import logger | |||
from common.tmp_dir import TmpDir | |||
@@ -16,20 +17,21 @@ class PyttsVoice(Voice): | |||
def __init__(self): | |||
# 语速 | |||
self.engine.setProperty('rate', 125) | |||
self.engine.setProperty("rate", 125) | |||
# 音量 | |||
self.engine.setProperty('volume', 1.0) | |||
for voice in self.engine.getProperty('voices'): | |||
self.engine.setProperty("volume", 1.0) | |||
for voice in self.engine.getProperty("voices"): | |||
if "Chinese" in voice.name: | |||
self.engine.setProperty('voice', voice.id) | |||
self.engine.setProperty("voice", voice.id) | |||
def textToVoice(self, text): | |||
try: | |||
wavFile = TmpDir().path() + 'reply-' + str(int(time.time())) + '.wav' | |||
wavFile = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav" | |||
self.engine.save_to_file(text, wavFile) | |||
self.engine.runAndWait() | |||
logger.info( | |||
'[Pytts] textToVoice text={} voice file name={}'.format(text, wavFile)) | |||
"[Pytts] textToVoice text={} voice file name={}".format(text, wavFile) | |||
) | |||
reply = Reply(ReplyType.VOICE, wavFile) | |||
except Exception as e: | |||
reply = Reply(ReplyType.ERROR, str(e)) | |||
@@ -2,6 +2,7 @@ | |||
Voice service abstract class | |||
""" | |||
class Voice(object): | |||
def voiceToText(self, voice_file): | |||
""" | |||
@@ -13,4 +14,4 @@ class Voice(object): | |||
""" | |||
Send text to voice service and get voice | |||
""" | |||
raise NotImplementedError | |||
raise NotImplementedError |
@@ -2,25 +2,31 @@ | |||
voice factory | |||
""" | |||
def create_voice(voice_type): | |||
""" | |||
create a voice instance | |||
:param voice_type: voice type code | |||
:return: voice instance | |||
""" | |||
if voice_type == 'baidu': | |||
if voice_type == "baidu": | |||
from voice.baidu.baidu_voice import BaiduVoice | |||
return BaiduVoice() | |||
elif voice_type == 'google': | |||
elif voice_type == "google": | |||
from voice.google.google_voice import GoogleVoice | |||
return GoogleVoice() | |||
elif voice_type == 'openai': | |||
elif voice_type == "openai": | |||
from voice.openai.openai_voice import OpenaiVoice | |||
return OpenaiVoice() | |||
elif voice_type == 'pytts': | |||
elif voice_type == "pytts": | |||
from voice.pytts.pytts_voice import PyttsVoice | |||
return PyttsVoice() | |||
elif voice_type == 'azure': | |||
elif voice_type == "azure": | |||
from voice.azure.azure_voice import AzureVoice | |||
return AzureVoice() | |||
raise RuntimeError |