formatting code

master
lanvent · 1 year ago
parent commit 8f72e8c3e6
92 changed files with 1843 additions and 1181 deletions
1. .github/ISSUE_TEMPLATE.md (+1, -1)
2. .github/workflows/deploy-image.yml (+2, -2)
3. README.md (+5, -5)
4. app.py (+17, -11)
5. bot/baidu/baidu_unit_bot.py (+24, -8)
6. bot/bot.py (+1, -1)
7. bot/bot_factory.py (+4, -0)
8. bot/chatgpt/chat_gpt_bot.py (+72, -48)
9. bot/chatgpt/chat_gpt_session.py (+28, -12)
10. bot/openai/open_ai_bot.py (+62, -39)
11. bot/openai/open_ai_image.py (+20, -13)
12. bot/openai/open_ai_session.py (+31, -17)
13. bot/session_manager.py (+31, -18)
14. bridge/bridge.py (+13, -16)
15. bridge/context.py (+24, -19)
16. bridge/reply.py (+11, -8)
17. channel/channel.py (+5, -3)
18. channel/channel_factory.py (+13, -7)
19. channel/chat_channel.py (+213, -112)
20. channel/chat_message.py (+10, -10)
21. channel/terminal/terminal_channel.py (+26, -10)
22. channel/wechat/wechat_channel.py (+80, -47)
23. channel/wechat/wechat_message.py (+28, -28)
24. channel/wechat/wechaty_channel.py (+54, -40)
25. channel/wechat/wechaty_message.py (+28, -18)
26. channel/wechatmp/README.md (+2, -2)
27. channel/wechatmp/ServiceAccount.py (+35, -16)
28. channel/wechatmp/SubscribeAccount.py (+98, -38)
29. channel/wechatmp/common.py (+15, -10)
30. channel/wechatmp/receive.py (+20, -18)
31. channel/wechatmp/reply.py (+12, -9)
32. channel/wechatmp/wechatmp_channel.py (+47, -38)
33. common/const.py (+1, -1)
34. common/dequeue.py (+2, -2)
35. common/expired_dict.py (+1, -1)
36. common/log.py (+16, -7)
37. common/package_manager.py (+11, -5)
38. common/sorted_dict.py (+1, -1)
39. common/time_check.py (+22, -10)
40. common/tmp_dir.py (+5, -7)
41. config-template.json (+20, -6)
42. config.py (+21, -32)
43. docker/Dockerfile.debian (+1, -1)
44. docker/Dockerfile.debian.latest (+1, -1)
45. docker/build.alpine.sh (+1, -2)
46. docker/build.debian.sh (+1, -1)
47. docker/chatgpt-on-wechat-voice-reply/Dockerfile.alpine (+1, -1)
48. docker/chatgpt-on-wechat-voice-reply/Dockerfile.debian (+1, -1)
49. docker/sample-chatgpt-on-wechat/Makefile (+2, -2)
50. plugins/README.md (+14, -14)
51. plugins/__init__.py (+2, -2)
52. plugins/banwords/__init__.py (+1, -1)
53. plugins/banwords/banwords.py (+49, -35)
54. plugins/banwords/config.json.template (+4, -4)
55. plugins/bdunit/README.md (+1, -1)
56. plugins/bdunit/__init__.py (+1, -1)
57. plugins/bdunit/bdunit.py (+30, -42)
58. plugins/bdunit/config.json.template (+4, -4)
59. plugins/dungeon/__init__.py (+1, -1)
60. plugins/dungeon/dungeon.py (+47, -28)
61. plugins/event.py (+6, -6)
62. plugins/finish/__init__.py (+1, -1)
63. plugins/finish/finish.py (+15, -9)
64. plugins/godcmd/__init__.py (+1, -1)
65. plugins/godcmd/config.json.template (+3, -3)
66. plugins/godcmd/godcmd.py (+96, -70)
67. plugins/hello/__init__.py (+1, -1)
68. plugins/hello/hello.py (+22, -14)
69. plugins/plugin.py (+1, -1)
70. plugins/plugin_manager.py (+104, -52)
71. plugins/role/__init__.py (+1, -1)
72. plugins/role/role.py (+66, -35)
73. plugins/role/roles.json (+1, -1)
74. plugins/source.json (+14, -14)
75. plugins/tool/README.md (+16, -16)
76. plugins/tool/__init__.py (+1, -1)
77. plugins/tool/config.json.template (+10, -5)
78. plugins/tool/tool.py (+40, -21)
79. requirements.txt (+1, -0)
80. scripts/start.sh (+1, -1)
81. scripts/tout.sh (+1, -1)
82. voice/audio_convert.py (+25, -8)
83. voice/azure/azure_voice.py (+37, -17)
84. voice/azure/config.json.template (+3, -3)
85. voice/baidu/README.md (+4, -4)
86. voice/baidu/baidu_voice.py (+26, -23)
87. voice/baidu/config.json.template (+8, -8)
88. voice/google/google_voice.py (+13, -7)
89. voice/openai/openai_voice.py (+9, -6)
90. voice/pytts/pytts_voice.py (+9, -7)
91. voice/voice.py (+2, -1)
92. voice/voice_factory.py (+11, -5)

.github/ISSUE_TEMPLATE.md (+1, -1)

@@ -27,5 +27,5 @@
### 环境

- 操作系统类型 (Mac/Windows/Linux):
- Python版本 ( 执行 `python3 -V` ):
- pip版本 ( 依赖问题此项必填,执行 `pip3 -V`):

.github/workflows/deploy-image.yml (+2, -2)

@@ -49,9 +49,9 @@ jobs:
file: ./docker/Dockerfile.latest
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
- uses: actions/delete-package-versions@v4
with:
package-name: 'chatgpt-on-wechat'
package-type: 'container'
min-versions-to-keep: 10


README.md (+5, -5)

@@ -120,7 +120,7 @@ pip3 install azure-cognitiveservices-speech

```bash
# config.json文件内容示例
{
"open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY
"model": "gpt-3.5-turbo", # 模型名称。当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
"proxy": "127.0.0.1:7890", # 代理客户端的ip和端口
@@ -128,7 +128,7 @@ pip3 install azure-cognitiveservices-speech
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
"speech_recognition": false, # 是否开启语音识别
@@ -160,7 +160,7 @@ pip3 install azure-cognitiveservices-speech
**4.其他配置**

+ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k` (其中gpt-4 api暂未开放)
+ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat)
+ `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
+ 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix `
+ 关于OpenAI对话及图片接口的参数配置(内容自由度、回复字数限制、图片大小等),可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/completions) 文档直接在 [代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/bot/openai/open_ai_bot.py) `bot/openai/open_ai_bot.py` 中进行调整。
@@ -181,7 +181,7 @@ pip3 install azure-cognitiveservices-speech
```bash
python3 app.py
```
终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。


### 2.服务器部署
@@ -189,7 +189,7 @@ python3 app.py
使用nohup命令在后台运行程序:

```bash
touch nohup.out # 首次运行需要新建日志文件
nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通过日志输出二维码
```
扫码登录后程序即可运行于服务器后台,此时可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。此外,`scripts` 目录下有一键运行、关闭程序的脚本供使用。
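The sample `config.json` in the README hunk above uses `#` line comments, which standard JSON parsers reject. Below is only a minimal sketch of how such a commented template could be stripped before parsing; the project's own `config.py` (also touched by this commit) may handle it differently, and the file path is hypothetical.

```python
import json
import re


def load_commented_json(path: str) -> dict:
    """Illustrative loader for a JSON template containing '#' line comments.

    Naive on purpose: it assumes '#' never appears inside a quoted string value.
    """
    with open(path, encoding="utf-8") as f:
        raw = f.read()
    cleaned = re.sub(r"#.*$", "", raw, flags=re.MULTILINE)  # drop comment tails
    return json.loads(cleaned)


if __name__ == "__main__":
    conf = load_commented_json("config.json")  # hypothetical local file
    print(conf.get("model", "gpt-3.5-turbo"))
```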


app.py (+17, -11)

@@ -1,23 +1,28 @@
# encoding:utf-8

import os
from config import conf, load_config
import signal
import sys

from channel import channel_factory
from common.log import logger
from config import conf, load_config
from plugins import *
import signal
import sys


def sigterm_handler_wrap(_signo):
old_handler = signal.getsignal(_signo)

def func(_signo, _stack_frame):
logger.info("signal {} received, exiting...".format(_signo))
conf().save_user_datas()
if callable(old_handler): # check old_handler
if callable(old_handler): # check old_handler
return old_handler(_signo, _stack_frame)
sys.exit(0)

signal.signal(_signo, func)


def run():
try:
# load config
@@ -28,17 +33,17 @@ def run():
sigterm_handler_wrap(signal.SIGTERM)

# create channel
channel_name=conf().get('channel_type', 'wx')
channel_name = conf().get("channel_type", "wx")

if "--cmd" in sys.argv:
channel_name = 'terminal'
channel_name = "terminal"

if channel_name == 'wxy':
os.environ['WECHATY_LOG']="warn"
if channel_name == "wxy":
os.environ["WECHATY_LOG"] = "warn"
# os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'

channel = channel_factory.create_channel(channel_name)
if channel_name in ['wx','wxy','terminal','wechatmp','wechatmp_service']:
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service"]:
PluginManager().load_plugins()

# startup channel
@@ -47,5 +52,6 @@ def run():
logger.error("App startup failed!")
logger.exception(e)

if __name__ == '__main__':
run()

if __name__ == "__main__":
run()
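The rewritten `sigterm_handler_wrap` above saves user data on SIGTERM and then defers to whatever handler was registered before it. A standalone sketch of that handler-chaining pattern follows; the `cleanup` callback and the demo at the bottom are illustrative, not part of the repository.

```python
import signal
import sys


def chain_sigterm(cleanup):
    """Run `cleanup` on SIGTERM, then fall through to the previous handler."""
    old_handler = signal.getsignal(signal.SIGTERM)

    def handler(signo, frame):
        cleanup()
        if callable(old_handler):  # a prior Python-level handler was installed
            return old_handler(signo, frame)
        sys.exit(0)  # otherwise exit ourselves, as app.py does

    signal.signal(signal.SIGTERM, handler)


if __name__ == "__main__":
    chain_sigterm(lambda: print("saving user data before exit..."))
    signal.raise_signal(signal.SIGTERM)  # triggers the handler and exits
```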

bot/baidu/baidu_unit_bot.py (+24, -8)

@@ -1,6 +1,7 @@
# encoding:utf-8

import requests

from bot.bot import Bot
from bridge.reply import Reply, ReplyType

@@ -9,20 +10,35 @@ from bridge.reply import Reply, ReplyType
class BaiduUnitBot(Bot):
def reply(self, query, context=None):
token = self.get_token()
url = 'https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=' + token
post_data = "{\"version\":\"3.0\",\"service_id\":\"S73177\",\"session_id\":\"\",\"log_id\":\"7758521\",\"skill_ids\":[\"1221886\"],\"request\":{\"terminal_id\":\"88888\",\"query\":\"" + query + "\", \"hyper_params\": {\"chat_custom_bot_profile\": 1}}}"
url = (
"https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token="
+ token
)
post_data = (
'{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"'
+ query
+ '", "hyper_params": {"chat_custom_bot_profile": 1}}}'
)
print(post_data)
headers = {'content-type': 'application/x-www-form-urlencoded'}
headers = {"content-type": "application/x-www-form-urlencoded"}
response = requests.post(url, data=post_data.encode(), headers=headers)
if response:
reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1])
reply = Reply(
ReplyType.TEXT,
response.json()["result"]["context"]["SYS_PRESUMED_HIST"][1],
)
return reply

def get_token(self):
access_key = 'YOUR_ACCESS_KEY'
secret_key = 'YOUR_SECRET_KEY'
host = 'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=' + access_key + '&client_secret=' + secret_key
access_key = "YOUR_ACCESS_KEY"
secret_key = "YOUR_SECRET_KEY"
host = (
"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id="
+ access_key
+ "&client_secret="
+ secret_key
)
response = requests.get(host)
if response:
print(response.json())
return response.json()['access_token']
return response.json()["access_token"]

bot/bot.py (+1, -1)

@@ -8,7 +8,7 @@ from bridge.reply import Reply


class Bot(object):
def reply(self, query, context : Context =None) -> Reply:
def reply(self, query, context: Context = None) -> Reply:
"""
bot auto-reply content
:param req: received message


bot/bot_factory.py (+4, -0)

@@ -13,20 +13,24 @@ def create_bot(bot_type):
if bot_type == const.BAIDU:
# Baidu Unit对话接口
from bot.baidu.baidu_unit_bot import BaiduUnitBot

return BaiduUnitBot()

elif bot_type == const.CHATGPT:
# ChatGPT 网页端web接口
from bot.chatgpt.chat_gpt_bot import ChatGPTBot

return ChatGPTBot()

elif bot_type == const.OPEN_AI:
# OpenAI 官方对话模型API
from bot.openai.open_ai_bot import OpenAIBot

return OpenAIBot()

elif bot_type == const.CHATGPTONAZURE:
# Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot

return AzureChatGPTBot()
raise RuntimeError

bot/chatgpt/chat_gpt_bot.py (+72, -48)

@@ -1,42 +1,53 @@
# encoding:utf-8

import time

import openai
import openai.error

from bot.bot import Bot
from bot.chatgpt.chat_gpt_session import ChatGPTSession
from bot.openai.open_ai_image import OpenAIImage
from bot.session_manager import SessionManager
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from config import conf, load_config
from common.log import logger
from common.token_bucket import TokenBucket
import openai
import openai.error
import time
from config import conf, load_config


# OpenAI对话模型API (可用)
class ChatGPTBot(Bot,OpenAIImage):
class ChatGPTBot(Bot, OpenAIImage):
def __init__(self):
super().__init__()
# set the default api_key
openai.api_key = conf().get('open_ai_api_key')
if conf().get('open_ai_api_base'):
openai.api_base = conf().get('open_ai_api_base')
proxy = conf().get('proxy')
openai.api_key = conf().get("open_ai_api_key")
if conf().get("open_ai_api_base"):
openai.api_base = conf().get("open_ai_api_base")
proxy = conf().get("proxy")
if proxy:
openai.proxy = proxy
if conf().get('rate_limit_chatgpt'):
self.tb4chatgpt = TokenBucket(conf().get('rate_limit_chatgpt', 20))
self.sessions = SessionManager(ChatGPTSession, model= conf().get("model") or "gpt-3.5-turbo")
self.args ={
if conf().get("rate_limit_chatgpt"):
self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))

self.sessions = SessionManager(
ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo"
)
self.args = {
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
"temperature":conf().get('temperature', 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
# "max_tokens":4096, # 回复最大的字符数
"top_p":1,
"frequency_penalty":conf().get('frequency_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"request_timeout": conf().get('request_timeout', None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
"timeout": conf().get('request_timeout', None), #重试超时时间,在这个时间内,将会自动重试
"top_p": 1,
"frequency_penalty": conf().get(
"frequency_penalty", 0.0
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"presence_penalty": conf().get(
"presence_penalty", 0.0
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"request_timeout": conf().get(
"request_timeout", None
), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
}

def reply(self, query, context=None):
@@ -44,39 +55,50 @@ class ChatGPTBot(Bot,OpenAIImage):
if context.type == ContextType.TEXT:
logger.info("[CHATGPT] query={}".format(query))


session_id = context['session_id']
session_id = context["session_id"]
reply = None
clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆'])
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
if query in clear_memory_commands:
self.sessions.clear_session(session_id)
reply = Reply(ReplyType.INFO, '记忆已清除')
elif query == '#清除所有':
reply = Reply(ReplyType.INFO, "记忆已清除")
elif query == "#清除所有":
self.sessions.clear_all_session()
reply = Reply(ReplyType.INFO, '所有人记忆已清除')
elif query == '#更新配置':
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
elif query == "#更新配置":
load_config()
reply = Reply(ReplyType.INFO, '配置已更新')
reply = Reply(ReplyType.INFO, "配置已更新")
if reply:
return reply
session = self.sessions.session_query(query, session_id)
logger.debug("[CHATGPT] session query={}".format(session.messages))

api_key = context.get('openai_api_key')
api_key = context.get("openai_api_key")

# if context.get('stream'):
# # reply in stream
# return self.reply_text_stream(query, new_query, session_id)

reply_content = self.reply_text(session, api_key)
logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"]))
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
reply = Reply(ReplyType.ERROR, reply_content['content'])
logger.debug(
"[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
session.messages,
session_id,
reply_content["content"],
reply_content["completion_tokens"],
)
)
if (
reply_content["completion_tokens"] == 0
and len(reply_content["content"]) > 0
):
reply = Reply(ReplyType.ERROR, reply_content["content"])
elif reply_content["completion_tokens"] > 0:
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
self.sessions.session_reply(
reply_content["content"], session_id, reply_content["total_tokens"]
)
reply = Reply(ReplyType.TEXT, reply_content["content"])
else:
reply = Reply(ReplyType.ERROR, reply_content['content'])
reply = Reply(ReplyType.ERROR, reply_content["content"])
logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content))
return reply

@@ -89,53 +111,55 @@ class ChatGPTBot(Bot,OpenAIImage):
reply = Reply(ReplyType.ERROR, retstring)
return reply
else:
reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type))
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
return reply

def reply_text(self, session:ChatGPTSession, api_key=None, retry_count=0) -> dict:
'''
def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> dict:
"""
call openai's ChatCompletion to get the answer
:param session: a conversation session
:param session_id: session id
:param retry_count: retry count
:return: {}
'''
"""
try:
if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token():
if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
# if api_key == None, the default openai.api_key will be used
response = openai.ChatCompletion.create(
api_key=api_key, messages=session.messages, **self.args
)
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
return {"total_tokens": response["usage"]["total_tokens"],
"completion_tokens": response["usage"]["completion_tokens"],
"content": response.choices[0]['message']['content']}
return {
"total_tokens": response["usage"]["total_tokens"],
"completion_tokens": response["usage"]["completion_tokens"],
"content": response.choices[0]["message"]["content"],
}
except Exception as e:
need_retry = retry_count < 2
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
if isinstance(e, openai.error.RateLimitError):
logger.warn("[CHATGPT] RateLimitError: {}".format(e))
result['content'] = "提问太快啦,请休息一下再问我吧"
result["content"] = "提问太快啦,请休息一下再问我吧"
if need_retry:
time.sleep(5)
elif isinstance(e, openai.error.Timeout):
logger.warn("[CHATGPT] Timeout: {}".format(e))
result['content'] = "我没有收到你的消息"
result["content"] = "我没有收到你的消息"
if need_retry:
time.sleep(5)
elif isinstance(e, openai.error.APIConnectionError):
logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
need_retry = False
result['content'] = "我连接不到你的网络"
result["content"] = "我连接不到你的网络"
else:
logger.warn("[CHATGPT] Exception: {}".format(e))
need_retry = False
self.sessions.clear_session(session.session_id)

if need_retry:
logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1))
return self.reply_text(session, api_key, retry_count+1)
logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1))
return self.reply_text(session, api_key, retry_count + 1)
else:
return result

@@ -145,4 +169,4 @@ class AzureChatGPTBot(ChatGPTBot):
super().__init__()
openai.api_type = "azure"
openai.api_version = "2023-03-15-preview"
self.args["deployment_id"] = conf().get("azure_deployment_id")
self.args["deployment_id"] = conf().get("azure_deployment_id")

bot/chatgpt/chat_gpt_session.py (+28, -12)

@@ -1,20 +1,23 @@
from bot.session_manager import Session
from common.log import logger
'''

"""
e.g. [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Who won the world series in 2020?"},
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
{"role": "user", "content": "Where was it played?"}
]
'''
"""


class ChatGPTSession(Session):
def __init__(self, session_id, system_prompt=None, model= "gpt-3.5-turbo"):
def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
super().__init__(session_id, system_prompt)
self.model = model
self.reset()
def discard_exceeding(self, max_tokens, cur_tokens= None):
def discard_exceeding(self, max_tokens, cur_tokens=None):
precise = True
try:
cur_tokens = self.calc_tokens()
@@ -22,7 +25,9 @@ class ChatGPTSession(Session):
precise = False
if cur_tokens is None:
raise e
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
logger.debug(
"Exception when counting tokens precisely for query: {}".format(e)
)
while cur_tokens > max_tokens:
if len(self.messages) > 2:
self.messages.pop(1)
@@ -34,25 +39,32 @@ class ChatGPTSession(Session):
cur_tokens = cur_tokens - max_tokens
break
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
logger.warn(
"user message exceed max_tokens. total_tokens={}".format(cur_tokens)
)
break
else:
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
logger.debug(
"max_tokens={}, total_tokens={}, len(messages)={}".format(
max_tokens, cur_tokens, len(self.messages)
)
)
break
if precise:
cur_tokens = self.calc_tokens()
else:
cur_tokens = cur_tokens - max_tokens
return cur_tokens
def calc_tokens(self):
return num_tokens_from_messages(self.messages, self.model)

# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def num_tokens_from_messages(messages, model):
"""Returns the number of tokens used by a list of messages."""
import tiktoken

try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
@@ -63,13 +75,17 @@ def num_tokens_from_messages(messages, model):
elif model == "gpt-4":
return num_tokens_from_messages(messages, model="gpt-4-0314")
elif model == "gpt-3.5-turbo-0301":
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_message = (
4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
)
tokens_per_name = -1 # if there's a name, the role is omitted
elif model == "gpt-4-0314":
tokens_per_message = 3
tokens_per_name = 1
else:
logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.")
logger.warn(
f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301."
)
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
num_tokens = 0
for message in messages:
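The hunk stops at the start of the counting loop because the remaining lines were not changed by this commit. For reference, the OpenAI cookbook recipe that the comment cites completes the count roughly as follows; this is a sketch of the gpt-3.5-turbo-0301 framing, not necessarily byte-for-byte what the repository contains.

```python
import tiktoken


def count_message_tokens(messages, model="gpt-3.5-turbo-0301"):
    """Approximate ChatCompletion prompt tokens, following the OpenAI cookbook."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")
    tokens_per_message, tokens_per_name = 4, -1  # per-message framing overhead
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":  # a name replaces the role token
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens


if __name__ == "__main__":
    msgs = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Who won the world series in 2020?"},
    ]
    print(count_message_tokens(msgs))
```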


bot/openai/open_ai_bot.py (+62, -39)

@@ -1,41 +1,52 @@
# encoding:utf-8

import time

import openai
import openai.error

from bot.bot import Bot
from bot.openai.open_ai_image import OpenAIImage
from bot.openai.open_ai_session import OpenAISession
from bot.session_manager import SessionManager
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from config import conf
from common.log import logger
import openai
import openai.error
import time
from config import conf

user_session = dict()


# OpenAI对话模型API (可用)
class OpenAIBot(Bot, OpenAIImage):
def __init__(self):
super().__init__()
openai.api_key = conf().get('open_ai_api_key')
if conf().get('open_ai_api_base'):
openai.api_base = conf().get('open_ai_api_base')
proxy = conf().get('proxy')
openai.api_key = conf().get("open_ai_api_key")
if conf().get("open_ai_api_base"):
openai.api_base = conf().get("open_ai_api_base")
proxy = conf().get("proxy")
if proxy:
openai.proxy = proxy

self.sessions = SessionManager(OpenAISession, model= conf().get("model") or "text-davinci-003")
self.sessions = SessionManager(
OpenAISession, model=conf().get("model") or "text-davinci-003"
)
self.args = {
"model": conf().get("model") or "text-davinci-003", # 对话模型的名称
"temperature":conf().get('temperature', 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
"max_tokens":1200, # 回复最大的字符数
"top_p":1,
"frequency_penalty":conf().get('frequency_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"request_timeout": conf().get('request_timeout', None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
"timeout": conf().get('request_timeout', None), #重试超时时间,在这个时间内,将会自动重试
"stop":["\n\n\n"]
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
"max_tokens": 1200, # 回复最大的字符数
"top_p": 1,
"frequency_penalty": conf().get(
"frequency_penalty", 0.0
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"presence_penalty": conf().get(
"presence_penalty", 0.0
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"request_timeout": conf().get(
"request_timeout", None
), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
"stop": ["\n\n\n"],
}

def reply(self, query, context=None):
@@ -43,24 +54,34 @@ class OpenAIBot(Bot, OpenAIImage):
if context and context.type:
if context.type == ContextType.TEXT:
logger.info("[OPEN_AI] query={}".format(query))
session_id = context['session_id']
session_id = context["session_id"]
reply = None
if query == '#清除记忆':
if query == "#清除记忆":
self.sessions.clear_session(session_id)
reply = Reply(ReplyType.INFO, '记忆已清除')
elif query == '#清除所有':
reply = Reply(ReplyType.INFO, "记忆已清除")
elif query == "#清除所有":
self.sessions.clear_all_session()
reply = Reply(ReplyType.INFO, '所有人记忆已清除')
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
else:
session = self.sessions.session_query(query, session_id)
result = self.reply_text(session)
total_tokens, completion_tokens, reply_content = result['total_tokens'], result['completion_tokens'], result['content']
logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens))
total_tokens, completion_tokens, reply_content = (
result["total_tokens"],
result["completion_tokens"],
result["content"],
)
logger.debug(
"[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
str(session), session_id, reply_content, completion_tokens
)
)

if total_tokens == 0 :
if total_tokens == 0:
reply = Reply(ReplyType.ERROR, reply_content)
else:
self.sessions.session_reply(reply_content, session_id, total_tokens)
self.sessions.session_reply(
reply_content, session_id, total_tokens
)
reply = Reply(ReplyType.TEXT, reply_content)
return reply
elif context.type == ContextType.IMAGE_CREATE:
@@ -72,42 +93,44 @@ class OpenAIBot(Bot, OpenAIImage):
reply = Reply(ReplyType.ERROR, retstring)
return reply

def reply_text(self, session:OpenAISession, retry_count=0):
def reply_text(self, session: OpenAISession, retry_count=0):
try:
response = openai.Completion.create(
prompt=str(session), **self.args
response = openai.Completion.create(prompt=str(session), **self.args)
res_content = (
response.choices[0]["text"].strip().replace("<|endoftext|>", "")
)
res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
total_tokens = response["usage"]["total_tokens"]
completion_tokens = response["usage"]["completion_tokens"]
logger.info("[OPEN_AI] reply={}".format(res_content))
return {"total_tokens": total_tokens,
"completion_tokens": completion_tokens,
"content": res_content}
return {
"total_tokens": total_tokens,
"completion_tokens": completion_tokens,
"content": res_content,
}
except Exception as e:
need_retry = retry_count < 2
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
if isinstance(e, openai.error.RateLimitError):
logger.warn("[OPEN_AI] RateLimitError: {}".format(e))
result['content'] = "提问太快啦,请休息一下再问我吧"
result["content"] = "提问太快啦,请休息一下再问我吧"
if need_retry:
time.sleep(5)
elif isinstance(e, openai.error.Timeout):
logger.warn("[OPEN_AI] Timeout: {}".format(e))
result['content'] = "我没有收到你的消息"
result["content"] = "我没有收到你的消息"
if need_retry:
time.sleep(5)
elif isinstance(e, openai.error.APIConnectionError):
logger.warn("[OPEN_AI] APIConnectionError: {}".format(e))
need_retry = False
result['content'] = "我连接不到你的网络"
result["content"] = "我连接不到你的网络"
else:
logger.warn("[OPEN_AI] Exception: {}".format(e))
need_retry = False
self.sessions.clear_session(session.session_id)

if need_retry:
logger.warn("[OPEN_AI] 第{}次重试".format(retry_count+1))
return self.reply_text(session, retry_count+1)
logger.warn("[OPEN_AI] 第{}次重试".format(retry_count + 1))
return self.reply_text(session, retry_count + 1)
else:
return result
return result

bot/openai/open_ai_image.py (+20, -13)

@@ -1,38 +1,45 @@
import time

import openai
import openai.error
from common.token_bucket import TokenBucket
from common.log import logger
from common.token_bucket import TokenBucket
from config import conf


# OPENAI提供的画图接口
class OpenAIImage(object):
def __init__(self):
openai.api_key = conf().get('open_ai_api_key')
if conf().get('rate_limit_dalle'):
self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50))
openai.api_key = conf().get("open_ai_api_key")
if conf().get("rate_limit_dalle"):
self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50))
def create_img(self, query, retry_count=0):
try:
if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token():
if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token():
return False, "请求太快了,请休息一下再问我吧"
logger.info("[OPEN_AI] image_query={}".format(query))
response = openai.Image.create(
prompt=query, #图片描述
n=1, #每次生成图片的数量
size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024
prompt=query, # 图片描述
n=1, # 每次生成图片的数量
size="256x256", # 图片大小,可选有 256x256, 512x512, 1024x1024
)
image_url = response['data'][0]['url']
image_url = response["data"][0]["url"]
logger.info("[OPEN_AI] image_url={}".format(image_url))
return True, image_url
except openai.error.RateLimitError as e:
logger.warn(e)
if retry_count < 1:
time.sleep(5)
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
return self.create_img(query, retry_count+1)
logger.warn(
"[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(
retry_count + 1
)
)
return self.create_img(query, retry_count + 1)
else:
return False, "提问太快啦,请休息一下再问我吧"
except Exception as e:
logger.exception(e)
return False, str(e)
return False, str(e)

bot/openai/open_ai_session.py (+31, -17)

@@ -1,32 +1,34 @@
from bot.session_manager import Session
from common.log import logger


class OpenAISession(Session):
def __init__(self, session_id, system_prompt=None, model= "text-davinci-003"):
def __init__(self, session_id, system_prompt=None, model="text-davinci-003"):
super().__init__(session_id, system_prompt)
self.model = model
self.reset()

def __str__(self):
# 构造对话模型的输入
'''
"""
e.g. Q: xxx
A: xxx
Q: xxx
'''
"""
prompt = ""
for item in self.messages:
if item['role'] == 'system':
prompt += item['content'] + "<|endoftext|>\n\n\n"
elif item['role'] == 'user':
prompt += "Q: " + item['content'] + "\n"
elif item['role'] == 'assistant':
prompt += "\n\nA: " + item['content'] + "<|endoftext|>\n"
if item["role"] == "system":
prompt += item["content"] + "<|endoftext|>\n\n\n"
elif item["role"] == "user":
prompt += "Q: " + item["content"] + "\n"
elif item["role"] == "assistant":
prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n"

if len(self.messages) > 0 and self.messages[-1]['role'] == 'user':
if len(self.messages) > 0 and self.messages[-1]["role"] == "user":
prompt += "A: "
return prompt

def discard_exceeding(self, max_tokens, cur_tokens= None):
def discard_exceeding(self, max_tokens, cur_tokens=None):
precise = True
try:
cur_tokens = self.calc_tokens()
@@ -34,7 +36,9 @@ class OpenAISession(Session):
precise = False
if cur_tokens is None:
raise e
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
logger.debug(
"Exception when counting tokens precisely for query: {}".format(e)
)
while cur_tokens > max_tokens:
if len(self.messages) > 1:
self.messages.pop(0)
@@ -46,24 +50,34 @@ class OpenAISession(Session):
cur_tokens = len(str(self))
break
elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
logger.warn(
"user question exceed max_tokens. total_tokens={}".format(
cur_tokens
)
)
break
else:
logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
logger.debug(
"max_tokens={}, total_tokens={}, len(conversation)={}".format(
max_tokens, cur_tokens, len(self.messages)
)
)
break
if precise:
cur_tokens = self.calc_tokens()
else:
cur_tokens = len(str(self))
return cur_tokens
def calc_tokens(self):
return num_tokens_from_string(str(self), self.model)


# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def num_tokens_from_string(string: str, model: str) -> int:
"""Returns the number of tokens in a text string."""
import tiktoken

encoding = tiktoken.encoding_for_model(model)
num_tokens = len(encoding.encode(string,disallowed_special=()))
return num_tokens
num_tokens = len(encoding.encode(string, disallowed_special=()))
return num_tokens

bot/session_manager.py (+31, -18)

@@ -2,6 +2,7 @@ from common.expired_dict import ExpiredDict
from common.log import logger
from config import conf


class Session(object):
def __init__(self, session_id, system_prompt=None):
self.session_id = session_id
@@ -13,7 +14,7 @@ class Session(object):

# 重置会话
def reset(self):
system_item = {'role': 'system', 'content': self.system_prompt}
system_item = {"role": "system", "content": self.system_prompt}
self.messages = [system_item]

def set_system_prompt(self, system_prompt):
@@ -21,13 +22,13 @@ class Session(object):
self.reset()

def add_query(self, query):
user_item = {'role': 'user', 'content': query}
user_item = {"role": "user", "content": query}
self.messages.append(user_item)

def add_reply(self, reply):
assistant_item = {'role': 'assistant', 'content': reply}
assistant_item = {"role": "assistant", "content": reply}
self.messages.append(assistant_item)
def discard_exceeding(self, max_tokens=None, cur_tokens=None):
raise NotImplementedError

@@ -37,8 +38,8 @@ class Session(object):

class SessionManager(object):
def __init__(self, sessioncls, **session_args):
if conf().get('expires_in_seconds'):
sessions = ExpiredDict(conf().get('expires_in_seconds'))
if conf().get("expires_in_seconds"):
sessions = ExpiredDict(conf().get("expires_in_seconds"))
else:
sessions = dict()
self.sessions = sessions
@@ -46,20 +47,22 @@ class SessionManager(object):
self.session_args = session_args

def build_session(self, session_id, system_prompt=None):
'''
如果session_id不在sessions中,创建一个新的session并添加到sessions中
如果system_prompt不会空,会更新session的system_prompt并重置session
'''
"""
如果session_id不在sessions中,创建一个新的session并添加到sessions中
如果system_prompt不会空,会更新session的system_prompt并重置session
"""
if session_id is None:
return self.sessioncls(session_id, system_prompt, **self.session_args)
if session_id not in self.sessions:
self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args)
self.sessions[session_id] = self.sessioncls(
session_id, system_prompt, **self.session_args
)
elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session
self.sessions[session_id].set_system_prompt(system_prompt)
session = self.sessions[session_id]
return session
def session_query(self, query, session_id):
session = self.build_session(session_id)
session.add_query(query)
@@ -68,23 +71,33 @@ class SessionManager(object):
total_tokens = session.discard_exceeding(max_tokens, None)
logger.debug("prompt tokens used={}".format(total_tokens))
except Exception as e:
logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
logger.debug(
"Exception when counting tokens precisely for prompt: {}".format(str(e))
)
return session

def session_reply(self, reply, session_id, total_tokens = None):
def session_reply(self, reply, session_id, total_tokens=None):
session = self.build_session(session_id)
session.add_reply(reply)
try:
max_tokens = conf().get("conversation_max_tokens", 1000)
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
logger.debug(
"raw total_tokens={}, savesession tokens={}".format(
total_tokens, tokens_cnt
)
)
except Exception as e:
logger.debug("Exception when counting tokens precisely for session: {}".format(str(e)))
logger.debug(
"Exception when counting tokens precisely for session: {}".format(
str(e)
)
)
return session

def clear_session(self, session_id):
if session_id in self.sessions:
del(self.sessions[session_id])
del self.sessions[session_id]

def clear_all_session(self):
self.sessions.clear()
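Taken together, this session layer is what the bots call into: `session_query` appends the user message and trims history, `session_reply` appends the answer, and `clear_session` is what the `#清除记忆` command triggers. A minimal usage sketch, assuming it is run from the repository root with `tiktoken` installed; the session ID and texts here are made up.

```python
from bot.chatgpt.chat_gpt_session import ChatGPTSession
from bot.session_manager import SessionManager

sessions = SessionManager(ChatGPTSession, model="gpt-3.5-turbo")

# One session per chat: queries and replies accumulate in the same history.
session = sessions.session_query("你好", session_id="friend-123")
print(session.messages)  # [{'role': 'system', ...}, {'role': 'user', 'content': '你好'}]

sessions.session_reply("你好,有什么可以帮你?", session_id="friend-123", total_tokens=30)
sessions.clear_session("friend-123")  # what '#清除记忆' does in the bots above
```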

bridge/bridge.py (+13, -16)

@@ -1,31 +1,31 @@
from bot import bot_factory
from bridge.context import Context
from bridge.reply import Reply
from common import const
from common.log import logger
from bot import bot_factory
from common.singleton import singleton
from voice import voice_factory
from config import conf
from common import const
from voice import voice_factory


@singleton
class Bridge(object):
def __init__(self):
self.btype={
self.btype = {
"chat": const.CHATGPT,
"voice_to_text": conf().get("voice_to_text", "openai"),
"text_to_voice": conf().get("text_to_voice", "google")
"text_to_voice": conf().get("text_to_voice", "google"),
}
model_type = conf().get("model")
if model_type in ["text-davinci-003"]:
self.btype['chat'] = const.OPEN_AI
self.btype["chat"] = const.OPEN_AI
if conf().get("use_azure_chatgpt", False):
self.btype['chat'] = const.CHATGPTONAZURE
self.bots={}
self.btype["chat"] = const.CHATGPTONAZURE
self.bots = {}

def get_bot(self,typename):
def get_bot(self, typename):
if self.bots.get(typename) is None:
logger.info("create bot {} for {}".format(self.btype[typename],typename))
logger.info("create bot {} for {}".format(self.btype[typename], typename))
if typename == "text_to_voice":
self.bots[typename] = voice_factory.create_voice(self.btype[typename])
elif typename == "voice_to_text":
@@ -33,18 +33,15 @@ class Bridge(object):
elif typename == "chat":
self.bots[typename] = bot_factory.create_bot(self.btype[typename])
return self.bots[typename]
def get_bot_type(self,typename):
return self.btype[typename]

def get_bot_type(self, typename):
return self.btype[typename]

def fetch_reply_content(self, query, context : Context) -> Reply:
def fetch_reply_content(self, query, context: Context) -> Reply:
return self.get_bot("chat").reply(query, context)


def fetch_voice_to_text(self, voiceFile) -> Reply:
return self.get_bot("voice_to_text").voiceToText(voiceFile)

def fetch_text_to_voice(self, text) -> Reply:
return self.get_bot("text_to_voice").textToVoice(text)


bridge/context.py (+24, -19)

@@ -2,36 +2,39 @@

from enum import Enum

class ContextType (Enum):
TEXT = 1 # 文本消息
VOICE = 2 # 音频消息
IMAGE = 3 # 图片消息
IMAGE_CREATE = 10 # 创建图片命令

class ContextType(Enum):
TEXT = 1 # 文本消息
VOICE = 2 # 音频消息
IMAGE = 3 # 图片消息
IMAGE_CREATE = 10 # 创建图片命令

def __str__(self):
return self.name


class Context:
def __init__(self, type : ContextType = None , content = None, kwargs = dict()):
def __init__(self, type: ContextType = None, content=None, kwargs=dict()):
self.type = type
self.content = content
self.kwargs = kwargs

def __contains__(self, key):
if key == 'type':
if key == "type":
return self.type is not None
elif key == 'content':
elif key == "content":
return self.content is not None
else:
return key in self.kwargs
def __getitem__(self, key):
if key == 'type':
if key == "type":
return self.type
elif key == 'content':
elif key == "content":
return self.content
else:
return self.kwargs[key]
def get(self, key, default=None):
try:
return self[key]
@@ -39,20 +42,22 @@ class Context:
return default

def __setitem__(self, key, value):
if key == 'type':
if key == "type":
self.type = value
elif key == 'content':
elif key == "content":
self.content = value
else:
self.kwargs[key] = value

def __delitem__(self, key):
if key == 'type':
if key == "type":
self.type = None
elif key == 'content':
elif key == "content":
self.content = None
else:
del self.kwargs[key]
def __str__(self):
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)
return "Context(type={}, content={}, kwargs={})".format(
self.type, self.content, self.kwargs
)
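`Context` is a small dict-like wrapper: `type` and `content` are first-class fields, while everything else (session_id, receiver, the raw msg, ...) lands in `kwargs`. A short usage sketch, assuming it is run from the repository root; the values are made up.

```python
from bridge.context import Context, ContextType

ctx = Context(ContextType.TEXT, "你好")
ctx["session_id"] = "friend-123"  # arbitrary keys are stored in ctx.kwargs
print("type" in ctx)              # True: handled by __contains__
print(ctx["content"])             # 你好
print(ctx.get("receiver", "unknown"))  # falls back to the default
print(ctx)  # Context(type=TEXT, content=你好, kwargs={'session_id': 'friend-123'})
```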

bridge/reply.py (+11, -8)

@@ -1,22 +1,25 @@

# encoding:utf-8

from enum import Enum


class ReplyType(Enum):
TEXT = 1 # 文本
VOICE = 2 # 音频文件
IMAGE = 3 # 图片文件
IMAGE_URL = 4 # 图片URL
TEXT = 1 # 文本
VOICE = 2 # 音频文件
IMAGE = 3 # 图片文件
IMAGE_URL = 4 # 图片URL
INFO = 9
ERROR = 10

def __str__(self):
return self.name


class Reply:
def __init__(self, type : ReplyType = None , content = None):
def __init__(self, type: ReplyType = None, content=None):
self.type = type
self.content = content

def __str__(self):
return "Reply(type={}, content={})".format(self.type, self.content)
return "Reply(type={}, content={})".format(self.type, self.content)

channel/channel.py (+5, -3)

@@ -6,8 +6,10 @@ from bridge.bridge import Bridge
from bridge.context import Context
from bridge.reply import *


class Channel(object):
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE]

def startup(self):
"""
init channel
@@ -27,15 +29,15 @@ class Channel(object):
send message to user
:param msg: message content
:param receiver: receiver channel account
:return:
:return:
"""
raise NotImplementedError

def build_reply_content(self, query, context : Context=None) -> Reply:
def build_reply_content(self, query, context: Context = None) -> Reply:
return Bridge().fetch_reply_content(query, context)

def build_voice_to_text(self, voice_file) -> Reply:
return Bridge().fetch_voice_to_text(voice_file)
def build_text_to_voice(self, text) -> Reply:
return Bridge().fetch_text_to_voice(text)

channel/channel_factory.py (+13, -7)

@@ -2,25 +2,31 @@
channel factory
"""


def create_channel(channel_type):
"""
create a channel instance
:param channel_type: channel type code
:return: channel instance
"""
if channel_type == 'wx':
if channel_type == "wx":
from channel.wechat.wechat_channel import WechatChannel

return WechatChannel()
elif channel_type == 'wxy':
elif channel_type == "wxy":
from channel.wechat.wechaty_channel import WechatyChannel

return WechatyChannel()
elif channel_type == 'terminal':
elif channel_type == "terminal":
from channel.terminal.terminal_channel import TerminalChannel

return TerminalChannel()
elif channel_type == 'wechatmp':
elif channel_type == "wechatmp":
from channel.wechatmp.wechatmp_channel import WechatMPChannel
return WechatMPChannel(passive_reply = True)
elif channel_type == 'wechatmp_service':

return WechatMPChannel(passive_reply=True)
elif channel_type == "wechatmp_service":
from channel.wechatmp.wechatmp_channel import WechatMPChannel
return WechatMPChannel(passive_reply = False)

return WechatMPChannel(passive_reply=False)
raise RuntimeError

channel/chat_channel.py (+213, -112)

@@ -1,137 +1,172 @@


from asyncio import CancelledError
from concurrent.futures import Future, ThreadPoolExecutor
import os
import re
import threading
import time
from common.dequeue import Dequeue
from channel.channel import Channel
from bridge.reply import *
from asyncio import CancelledError
from concurrent.futures import Future, ThreadPoolExecutor
from bridge.context import *
from config import conf
from bridge.reply import *
from channel.channel import Channel
from common.dequeue import Dequeue
from common.log import logger
from config import conf
from plugins import *

try:
from voice.audio_convert import any_to_wav
except Exception as e:
pass


# 抽象类, 它包含了与消息通道无关的通用处理逻辑
class ChatChannel(Channel):
name = None # 登录的用户名
user_id = None # 登录的用户id
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
lock = threading.Lock() # 用于控制对sessions的访问
name = None # 登录的用户名
user_id = None # 登录的用户id
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
lock = threading.Lock() # 用于控制对sessions的访问
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池

def __init__(self):
_thread = threading.Thread(target=self.consume)
_thread.setDaemon(True)
_thread.start()

# 根据消息构造context,消息内容相关的触发项写在这里
def _compose_context(self, ctype: ContextType, content, **kwargs):
context = Context(ctype, content)
context.kwargs = kwargs
# context首次传入时,origin_ctype是None,
# context首次传入时,origin_ctype是None,
# 引入的起因是:当输入语音时,会嵌套生成两个context,第一步语音转文本,第二步通过文本生成文字回复。
# origin_ctype用于第二步文本回复时,判断是否需要匹配前缀,如果是私聊的语音,就不需要匹配前缀
if 'origin_ctype' not in context:
context['origin_ctype'] = ctype
if "origin_ctype" not in context:
context["origin_ctype"] = ctype
# context首次传入时,receiver是None,根据类型设置receiver
first_in = 'receiver' not in context
first_in = "receiver" not in context
# 群名匹配过程,设置session_id和receiver
if first_in: # context首次传入时,receiver是None,根据类型设置receiver
if first_in: # context首次传入时,receiver是None,根据类型设置receiver
config = conf()
cmsg = context['msg']
cmsg = context["msg"]
if context.get("isgroup", False):
group_name = cmsg.other_user_nickname
group_id = cmsg.other_user_id

group_name_white_list = config.get('group_name_white_list', [])
group_name_keyword_white_list = config.get('group_name_keyword_white_list', [])
if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]):
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
group_name_white_list = config.get("group_name_white_list", [])
group_name_keyword_white_list = config.get(
"group_name_keyword_white_list", []
)
if any(
[
group_name in group_name_white_list,
"ALL_GROUP" in group_name_white_list,
check_contain(group_name, group_name_keyword_white_list),
]
):
group_chat_in_one_session = conf().get(
"group_chat_in_one_session", []
)
session_id = cmsg.actual_user_id
if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]):
if any(
[
group_name in group_chat_in_one_session,
"ALL_GROUP" in group_chat_in_one_session,
]
):
session_id = group_id
else:
return None
context['session_id'] = session_id
context['receiver'] = group_id
context["session_id"] = session_id
context["receiver"] = group_id
else:
context['session_id'] = cmsg.other_user_id
context['receiver'] = cmsg.other_user_id
e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {'channel': self, 'context': context}))
context = e_context['context']
context["session_id"] = cmsg.other_user_id
context["receiver"] = cmsg.other_user_id
e_context = PluginManager().emit_event(
EventContext(
Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}
)
)
context = e_context["context"]
if e_context.is_pass() or context is None:
return context
if cmsg.from_user_id == self.user_id and not config.get('trigger_by_self', True):
if cmsg.from_user_id == self.user_id and not config.get(
"trigger_by_self", True
):
logger.debug("[WX]self message skipped")
return None

# 消息内容匹配过程,并处理content
if ctype == ContextType.TEXT:
if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息
if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息
logger.debug("[WX]reference query skipped")
return None
if context.get("isgroup", False): # 群聊
if context.get("isgroup", False): # 群聊
# 校验关键字
match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
match_contain = check_contain(content, conf().get('group_chat_keyword'))
match_prefix = check_prefix(content, conf().get("group_chat_prefix"))
match_contain = check_contain(content, conf().get("group_chat_keyword"))
flag = False
if match_prefix is not None or match_contain is not None:
flag = True
if match_prefix:
content = content.replace(match_prefix, '', 1).strip()
if context['msg'].is_at:
content = content.replace(match_prefix, "", 1).strip()
if context["msg"].is_at:
logger.info("[WX]receive group at")
if not conf().get("group_at_off", False):
flag = True
pattern = f'@{self.name}(\u2005|\u0020)'
content = re.sub(pattern, r'', content)
pattern = f"@{self.name}(\u2005|\u0020)"
content = re.sub(pattern, r"", content)
if not flag:
if context["origin_ctype"] == ContextType.VOICE:
logger.info("[WX]receive group voice, but checkprefix didn't match")
logger.info(
"[WX]receive group voice, but checkprefix didn't match"
)
return None
else: # 单聊
match_prefix = check_prefix(content, conf().get('single_chat_prefix',['']))
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
content = content.replace(match_prefix, '', 1).strip()
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
else: # 单聊
match_prefix = check_prefix(
content, conf().get("single_chat_prefix", [""])
)
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
content = content.replace(match_prefix, "", 1).strip()
elif (
context["origin_ctype"] == ContextType.VOICE
): # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
pass
else:
return None
img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
return None
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
if img_match_prefix:
content = content.replace(img_match_prefix, '', 1)
content = content.replace(img_match_prefix, "", 1)
context.type = ContextType.IMAGE_CREATE
else:
context.type = ContextType.TEXT
context.content = content.strip()
if 'desire_rtype' not in context and conf().get('always_reply_voice') and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
context['desire_rtype'] = ReplyType.VOICE
elif context.type == ContextType.VOICE:
if 'desire_rtype' not in context and conf().get('voice_reply_voice') and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
context['desire_rtype'] = ReplyType.VOICE
if (
"desire_rtype" not in context
and conf().get("always_reply_voice")
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
):
context["desire_rtype"] = ReplyType.VOICE
elif context.type == ContextType.VOICE:
if (
"desire_rtype" not in context
and conf().get("voice_reply_voice")
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
):
context["desire_rtype"] = ReplyType.VOICE

return context

def _handle(self, context: Context):
if context is None or not context.content:
return
logger.debug('[WX] ready to handle context: {}'.format(context))
logger.debug("[WX] ready to handle context: {}".format(context))
# reply的构建步骤
reply = self._generate_reply(context)

logger.debug('[WX] ready to decorate reply: {}'.format(reply))
logger.debug("[WX] ready to decorate reply: {}".format(reply))
# reply的包装步骤
reply = self._decorate_reply(context, reply)

@@ -139,20 +174,31 @@ class ChatChannel(Channel):
self._send_reply(context, reply)

def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
'channel': self, 'context': context, 'reply': reply}))
reply = e_context['reply']
e_context = PluginManager().emit_event(
EventContext(
Event.ON_HANDLE_CONTEXT,
{"channel": self, "context": context, "reply": reply},
)
)
reply = e_context["reply"]
if not e_context.is_pass():
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
logger.debug(
"[WX] ready to handle context: type={}, content={}".format(
context.type, context.content
)
)
if (
context.type == ContextType.TEXT
or context.type == ContextType.IMAGE_CREATE
): # 文字和图片消息
reply = super().build_reply_content(context.content, context)
elif context.type == ContextType.VOICE: # 语音消息
cmsg = context['msg']
cmsg = context["msg"]
cmsg.prepare()
file_path = context.content
wav_path = os.path.splitext(file_path)[0] + '.wav'
wav_path = os.path.splitext(file_path)[0] + ".wav"
try:
any_to_wav(file_path, wav_path)
any_to_wav(file_path, wav_path)
except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
logger.warning("[WX]any to wav error, use raw path. " + str(e))
wav_path = file_path
@@ -169,7 +215,8 @@ class ChatChannel(Channel):

if reply.type == ReplyType.TEXT:
new_context = self._compose_context(
ContextType.TEXT, reply.content, **context.kwargs)
ContextType.TEXT, reply.content, **context.kwargs
)
if new_context:
reply = self._generate_reply(new_context)
else:
@@ -177,18 +224,21 @@ class ChatChannel(Channel):
elif context.type == ContextType.IMAGE: # 图片消息,当前无默认逻辑
pass
else:
logger.error('[WX] unknown context type: {}'.format(context.type))
logger.error("[WX] unknown context type: {}".format(context.type))
return
return reply

def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
if reply and reply.type:
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {
'channel': self, 'context': context, 'reply': reply}))
reply = e_context['reply']
desire_rtype = context.get('desire_rtype')
e_context = PluginManager().emit_event(
EventContext(
Event.ON_DECORATE_REPLY,
{"channel": self, "context": context, "reply": reply},
)
)
reply = e_context["reply"]
desire_rtype = context.get("desire_rtype")
if not e_context.is_pass() and reply and reply.type:
if reply.type in self.NOT_SUPPORT_REPLYTYPE:
logger.error("[WX]reply type not support: " + str(reply.type))
reply.type = ReplyType.ERROR
@@ -196,59 +246,91 @@ class ChatChannel(Channel):

if reply.type == ReplyType.TEXT:
reply_text = reply.content
if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
if (
desire_rtype == ReplyType.VOICE
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
):
reply = super().build_text_to_voice(reply.content)
return self._decorate_reply(context, reply)
if context.get("isgroup", False):
reply_text = '@' + context['msg'].actual_user_nickname + ' ' + reply_text.strip()
reply_text = conf().get("group_chat_reply_prefix", "") + reply_text
reply_text = (
"@"
+ context["msg"].actual_user_nickname
+ " "
+ reply_text.strip()
)
reply_text = (
conf().get("group_chat_reply_prefix", "") + reply_text
)
else:
reply_text = conf().get("single_chat_reply_prefix", "") + reply_text
reply_text = (
conf().get("single_chat_reply_prefix", "") + reply_text
)
reply.content = reply_text
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
reply.content = "["+str(reply.type)+"]\n" + reply.content
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
reply.content = "[" + str(reply.type) + "]\n" + reply.content
elif (
reply.type == ReplyType.IMAGE_URL
or reply.type == ReplyType.VOICE
or reply.type == ReplyType.IMAGE
):
pass
else:
logger.error('[WX] unknown reply type: {}'.format(reply.type))
logger.error("[WX] unknown reply type: {}".format(reply.type))
return
if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
logger.warning('[WX] desire_rtype: {}, but reply type: {}'.format(context.get('desire_rtype'), reply.type))
if (
desire_rtype
and desire_rtype != reply.type
and reply.type not in [ReplyType.ERROR, ReplyType.INFO]
):
logger.warning(
"[WX] desire_rtype: {}, but reply type: {}".format(
context.get("desire_rtype"), reply.type
)
)
return reply

def _send_reply(self, context: Context, reply: Reply):
if reply and reply.type:
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {
'channel': self, 'context': context, 'reply': reply}))
reply = e_context['reply']
e_context = PluginManager().emit_event(
EventContext(
Event.ON_SEND_REPLY,
{"channel": self, "context": context, "reply": reply},
)
)
reply = e_context["reply"]
if not e_context.is_pass() and reply and reply.type:
logger.debug('[WX] ready to send reply: {}, context: {}'.format(reply, context))
logger.debug(
"[WX] ready to send reply: {}, context: {}".format(reply, context)
)
self._send(reply, context)

def _send(self, reply: Reply, context: Context, retry_cnt = 0):
def _send(self, reply: Reply, context: Context, retry_cnt=0):
try:
self.send(reply, context)
except Exception as e:
logger.error('[WX] sendMsg error: {}'.format(str(e)))
logger.error("[WX] sendMsg error: {}".format(str(e)))
if isinstance(e, NotImplementedError):
return
logger.exception(e)
if retry_cnt < 2:
time.sleep(3+3*retry_cnt)
self._send(reply, context, retry_cnt+1)
time.sleep(3 + 3 * retry_cnt)
self._send(reply, context, retry_cnt + 1)

def _success_callback(self, session_id, **kwargs):# 线程正常结束时的回调函数
def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数
logger.debug("Worker return success, session_id = {}".format(session_id))

def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数
def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数
logger.exception("Worker return exception: {}".format(exception))

def _thread_pool_callback(self, session_id, **kwargs):
def func(worker:Future):
def func(worker: Future):
try:
worker_exception = worker.exception()
if worker_exception:
self._fail_callback(session_id, exception = worker_exception, **kwargs)
self._fail_callback(
session_id, exception=worker_exception, **kwargs
)
else:
self._success_callback(session_id, **kwargs)
except CancelledError as e:
@@ -257,15 +339,19 @@ class ChatChannel(Channel):
logger.exception("Worker raise exception: {}".format(e))
with self.lock:
self.sessions[session_id][1].release()

return func

def produce(self, context: Context):
session_id = context['session_id']
session_id = context["session_id"]
with self.lock:
if session_id not in self.sessions:
self.sessions[session_id] = [Dequeue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 4))]
if context.type == ContextType.TEXT and context.content.startswith("#"):
self.sessions[session_id][0].putleft(context) # 优先处理管理命令
self.sessions[session_id] = [
Dequeue(),
threading.BoundedSemaphore(conf().get("concurrency_in_session", 4)),
]
if context.type == ContextType.TEXT and context.content.startswith("#"):
self.sessions[session_id][0].putleft(context) # 优先处理管理命令
else:
self.sessions[session_id][0].put(context)

@@ -276,44 +362,58 @@ class ChatChannel(Channel):
session_ids = list(self.sessions.keys())
for session_id in session_ids:
context_queue, semaphore = self.sessions[session_id]
if semaphore.acquire(blocking = False): # 等线程处理完毕才能删除
if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除
if not context_queue.empty():
context = context_queue.get()
logger.debug("[WX] consume context: {}".format(context))
future:Future = self.handler_pool.submit(self._handle, context)
future.add_done_callback(self._thread_pool_callback(session_id, context = context))
future: Future = self.handler_pool.submit(
self._handle, context
)
future.add_done_callback(
self._thread_pool_callback(session_id, context=context)
)
if session_id not in self.futures:
self.futures[session_id] = []
self.futures[session_id].append(future)
elif semaphore._initial_value == semaphore._value+1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
assert len(self.futures[session_id]) == 0, "thread pool error"
elif (
semaphore._initial_value == semaphore._value + 1
): # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
self.futures[session_id] = [
t for t in self.futures[session_id] if not t.done()
]
assert (
len(self.futures[session_id]) == 0
), "thread pool error"
del self.sessions[session_id]
else:
semaphore.release()
time.sleep(0.1)

# 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务
def cancel_session(self, session_id):
def cancel_session(self, session_id):
with self.lock:
if session_id in self.sessions:
for future in self.futures[session_id]:
future.cancel()
cnt = self.sessions[session_id][0].qsize()
if cnt>0:
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
if cnt > 0:
logger.info(
"Cancel {} messages in session {}".format(cnt, session_id)
)
self.sessions[session_id][0] = Dequeue()
def cancel_all_session(self):
with self.lock:
for session_id in self.sessions:
for future in self.futures[session_id]:
future.cancel()
cnt = self.sessions[session_id][0].qsize()
if cnt>0:
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
if cnt > 0:
logger.info(
"Cancel {} messages in session {}".format(cnt, session_id)
)
self.sessions[session_id][0] = Dequeue()

def check_prefix(content, prefix_list):
if not prefix_list:
@@ -323,6 +423,7 @@ def check_prefix(content, prefix_list):
return prefix
return None


def check_contain(content, keyword_list):
if not keyword_list:
return None
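
The produce/consume pair above gives each session its own Dequeue plus a BoundedSemaphore sized by `concurrency_in_session`: admin commands starting with `#` jump the queue via `putleft`, and a session is deleted only once the semaphore shows that every submitted task has finished. Below is a stripped-down, self-contained sketch of that pattern under those assumptions; the class and function names are illustrative, not the project's own.

```
import threading
import time
from collections import deque
from concurrent.futures import ThreadPoolExecutor

# Minimal stand-ins for the project's Dequeue/session bookkeeping, for illustration only.
class PriorityDeque:
    """A deque where urgent items (e.g. '#' admin commands) can jump the line."""
    def __init__(self):
        self._items = deque()
        self._lock = threading.Lock()

    def put(self, item):
        with self._lock:
            self._items.append(item)

    def putleft(self, item):
        with self._lock:
            self._items.appendleft(item)

    def get(self):
        with self._lock:
            return self._items.popleft() if self._items else None


def handle(sid, item):
    print(f"[{sid}] handling: {item}")


def consume(sessions, pool, lock, stop):
    """Poll every session; run at most N tasks per session concurrently."""
    while not stop.is_set():
        with lock:
            session_ids = list(sessions.keys())
        for sid in session_ids:
            q, sem = sessions[sid]
            if sem.acquire(blocking=False):          # a slot is free for this session
                item = q.get()
                if item is None:                     # nothing queued, give the slot back
                    sem.release()
                    continue
                fut = pool.submit(handle, sid, item)
                fut.add_done_callback(lambda _f, s=sem: s.release())
        time.sleep(0.1)


if __name__ == "__main__":
    lock = threading.Lock()
    stop = threading.Event()
    sessions = {"room1": (PriorityDeque(), threading.BoundedSemaphore(1))}
    sessions["room1"][0].put("hello")
    sessions["room1"][0].putleft("#reset")           # admin command is handled first
    with ThreadPoolExecutor(max_workers=4) as pool:
        t = threading.Thread(target=consume, args=(sessions, pool, lock, stop), daemon=True)
        t.start()
        time.sleep(1)
        stop.set()
```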


channel/chat_message.py (+10 -10)

@@ -1,5 +1,4 @@

"""
"""
This class represents a chat message and provides a unified wrapper around itchat and wechaty messages.

Fill in the required fields (6 for group chats, 8 for private chats) to plug into ChatChannel with plugin support; see TerminalChannel for a reference.
@@ -20,7 +19,7 @@ other_user_id: id of the other party; if you are the sender, this is the receiver's id
other_user_nickname: same as above

is_group: whether this is a group message (required for group chats)
is_at: whether the bot was @-mentioned
is_at: whether the bot was @-mentioned

- (Group messages usually carry an actual sender, i.e. the id and nickname of a group member; the following fields exist only for group messages)
actual_user_id: id of the actual sender (required for group chats)
@@ -34,20 +33,22 @@ _prepared: whether the prepare function has already been called
_rawmsg: the raw message object

"""


class ChatMessage(object):
msg_id = None
create_time = None
ctype = None
content = None
from_user_id = None
from_user_nickname = None
to_user_id = None
to_user_nickname = None
other_user_id = None
other_user_nickname = None
is_group = False
is_at = False
actual_user_id = None
@@ -57,8 +58,7 @@ class ChatMessage(object):
_prepared = False
_rawmsg = None


def __init__(self,_rawmsg):
def __init__(self, _rawmsg):
self._rawmsg = _rawmsg

def prepare(self):
@@ -67,7 +67,7 @@ class ChatMessage(object):
self._prepare_fn()

def __str__(self):
return 'ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}'.format(
return "ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}".format(
self.msg_id,
self.create_time,
self.ctype,
@@ -82,4 +82,4 @@ class ChatMessage(object):
self.is_at,
self.actual_user_id,
self.actual_user_nickname,
)
)
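
`prepare()` and `_prepare_fn` implement lazy loading: a channel stores a download closure when the message object is built, and the pipeline calls `prepare()` only when the content (for example a voice or image file) is actually needed. A minimal sketch of that idiom, with a made-up downloader standing in for the channel's real download call:

```
class LazyMessage:
    """Illustration of the prepare()/_prepare_fn idiom used by ChatMessage."""

    def __init__(self, remote_url, local_path, downloader):
        self.content = local_path        # the local path is known immediately
        self._prepared = False
        self._prepare_fn = lambda: downloader(remote_url, local_path)  # deferred work

    def prepare(self):
        if self._prepare_fn and not self._prepared:
            self._prepared = True
            self._prepare_fn()           # the download happens only here


def fake_downloader(url, path):
    print(f"downloading {url} -> {path}")

msg = LazyMessage("https://example.com/voice.mp3", "./tmp/voice.mp3", fake_downloader)
# ... later, only when a handler actually needs the file on disk:
msg.prepare()
```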

channel/terminal/terminal_channel.py (+26 -10)

@@ -1,14 +1,23 @@
import sys

from bridge.context import *
from bridge.reply import Reply, ReplyType
from channel.chat_channel import ChatChannel, check_prefix
from channel.chat_message import ChatMessage
import sys

from config import conf
from common.log import logger
from config import conf


class TerminalMessage(ChatMessage):
def __init__(self, msg_id, content, ctype = ContextType.TEXT, from_user_id = "User", to_user_id = "Chatgpt", other_user_id = "Chatgpt"):
def __init__(
self,
msg_id,
content,
ctype=ContextType.TEXT,
from_user_id="User",
to_user_id="Chatgpt",
other_user_id="Chatgpt",
):
self.msg_id = msg_id
self.ctype = ctype
self.content = content
@@ -16,6 +25,7 @@ class TerminalMessage(ChatMessage):
self.to_user_id = to_user_id
self.other_user_id = other_user_id


class TerminalChannel(ChatChannel):
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE]

@@ -23,14 +33,18 @@ class TerminalChannel(ChatChannel):
print("\nBot:")
if reply.type == ReplyType.IMAGE:
from PIL import Image

image_storage = reply.content
image_storage.seek(0)
img = Image.open(image_storage)
print("<IMAGE>")
img.show()
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
import io

import requests
from PIL import Image
import requests,io
img_url = reply.content
pic_res = requests.get(img_url, stream=True)
image_storage = io.BytesIO()
@@ -59,11 +73,13 @@ class TerminalChannel(ChatChannel):
print("\nExiting...")
sys.exit()
msg_id += 1
trigger_prefixs = conf().get("single_chat_prefix",[""])
trigger_prefixs = conf().get("single_chat_prefix", [""])
if check_prefix(prompt, trigger_prefixs) is None:
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
context = self._compose_context(ContextType.TEXT, prompt, msg = TerminalMessage(msg_id, prompt))
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀

context = self._compose_context(
ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt)
)
if context:
self.produce(context)
else:


channel/wechat/wechat_channel.py (+80 -47)

@@ -4,40 +4,45 @@
wechat channel
"""

import io
import json
import os
import threading
import requests
import io
import time
import json

import requests

from bridge.context import *
from bridge.reply import *
from channel.chat_channel import ChatChannel
from channel.wechat.wechat_message import *
from common.singleton import singleton
from common.expired_dict import ExpiredDict
from common.log import logger
from common.singleton import singleton
from common.time_check import time_checker
from config import conf
from lib import itchat
from lib.itchat.content import *
from bridge.reply import *
from bridge.context import *
from config import conf
from common.time_check import time_checker
from common.expired_dict import ExpiredDict
from plugins import *

@itchat.msg_register([TEXT,VOICE,PICTURE])

@itchat.msg_register([TEXT, VOICE, PICTURE])
def handler_single_msg(msg):
# logger.debug("handler_single_msg: {}".format(msg))
if msg['Type'] == PICTURE and msg['MsgType'] == 47:
if msg["Type"] == PICTURE and msg["MsgType"] == 47:
return None
WechatChannel().handle_single(WeChatMessage(msg))
return None

@itchat.msg_register([TEXT,VOICE,PICTURE], isGroupChat=True)

@itchat.msg_register([TEXT, VOICE, PICTURE], isGroupChat=True)
def handler_group_msg(msg):
if msg['Type'] == PICTURE and msg['MsgType'] == 47:
if msg["Type"] == PICTURE and msg["MsgType"] == 47:
return None
WechatChannel().handle_group(WeChatMessage(msg,True))
WechatChannel().handle_group(WeChatMessage(msg, True))
return None


def _check(func):
def wrapper(self, cmsg: ChatMessage):
msgId = cmsg.msg_id
@@ -45,21 +50,27 @@ def _check(func):
logger.info("Wechat message {} already received, ignore".format(msgId))
return
self.receivedMsgs[msgId] = cmsg
create_time = cmsg.create_time # 消息时间戳
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
create_time = cmsg.create_time # 消息时间戳
if (
conf().get("hot_reload") == True
and int(create_time) < int(time.time()) - 60
): # 跳过1分钟前的历史消息
logger.debug("[WX]history message {} skipped".format(msgId))
return
return func(self, cmsg)

return wrapper

#可用的二维码生成接口
#https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com
#https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com
def qrCallback(uuid,status,qrcode):

# 可用的二维码生成接口
# https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com
# https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com
def qrCallback(uuid, status, qrcode):
# logger.debug("qrCallback: {} {}".format(uuid,status))
if status == '0':
if status == "0":
try:
from PIL import Image

img = Image.open(io.BytesIO(qrcode))
_thread = threading.Thread(target=img.show, args=("QRCode",))
_thread.setDaemon(True)
@@ -68,35 +79,43 @@ def qrCallback(uuid,status,qrcode):
pass

import qrcode

url = f"https://login.weixin.qq.com/l/{uuid}"
qr_api1="https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
qr_api2="https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url)
qr_api3="https://api.pwmqr.com/qrcode/create/?url={}".format(url)
qr_api4="https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url)

qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
qr_api2 = (
"https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(
url
)
)
qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url)
qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(
url
)
print("You can also scan QRCode in any website below:")
print(qr_api3)
print(qr_api4)
print(qr_api2)
print(qr_api1)
qr = qrcode.QRCode(border=1)
qr.add_data(url)
qr.make(fit=True)
qr.print_ascii(invert=True)


@singleton
class WechatChannel(ChatChannel):
NOT_SUPPORT_REPLYTYPE = []

def __init__(self):
super().__init__()
self.receivedMsgs = ExpiredDict(60*60*24)
self.receivedMsgs = ExpiredDict(60 * 60 * 24)

def startup(self):

itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
# login by scan QRCode
hotReload = conf().get('hot_reload', False)
hotReload = conf().get("hot_reload", False)
try:
itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback)
except Exception as e:
@@ -104,12 +123,18 @@ class WechatChannel(ChatChannel):
logger.error("Hot reload failed, try to login without hot reload")
itchat.logout()
os.remove("itchat.pkl")
itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback)
itchat.auto_login(
enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback
)
else:
raise e
self.user_id = itchat.instance.storageClass.userName
self.name = itchat.instance.storageClass.nickName
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
logger.info(
"Wechat login success, user_id: {}, nickname: {}".format(
self.user_id, self.name
)
)
# start message listener
itchat.run()

@@ -127,24 +152,30 @@ class WechatChannel(ChatChannel):

@time_checker
@_check
def handle_single(self, cmsg : ChatMessage):
def handle_single(self, cmsg: ChatMessage):
if cmsg.ctype == ContextType.VOICE:
if conf().get('speech_recognition') != True:
if conf().get("speech_recognition") != True:
return
logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.IMAGE:
logger.debug("[WX]receive image msg: {}".format(cmsg.content))
else:
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
logger.debug(
"[WX]receive text msg: {}, cmsg={}".format(
json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg
)
)
context = self._compose_context(
cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg
)
if context:
self.produce(context)

@time_checker
@_check
def handle_group(self, cmsg : ChatMessage):
def handle_group(self, cmsg: ChatMessage):
if cmsg.ctype == ContextType.VOICE:
if conf().get('speech_recognition') != True:
if conf().get("speech_recognition") != True:
return
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
elif cmsg.ctype == ContextType.IMAGE:
@@ -152,23 +183,25 @@ class WechatChannel(ChatChannel):
else:
# logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
pass
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
context = self._compose_context(
cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg
)
if context:
self.produce(context)
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
def send(self, reply: Reply, context: Context):
receiver = context["receiver"]
if reply.type == ReplyType.TEXT:
itchat.send(reply.content, toUserName=receiver)
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
itchat.send(reply.content, toUserName=receiver)
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
elif reply.type == ReplyType.VOICE:
itchat.send_file(reply.content, toUserName=receiver)
logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver))
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
img_url = reply.content
pic_res = requests.get(img_url, stream=True)
image_storage = io.BytesIO()
@@ -176,9 +209,9 @@ class WechatChannel(ChatChannel):
image_storage.write(block)
image_storage.seek(0)
itchat.send_image(image_storage, toUserName=receiver)
logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver))
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
image_storage = reply.content
image_storage.seek(0)
itchat.send_image(image_storage, toUserName=receiver)
logger.info('[WX] sendImage, receiver={}'.format(receiver))
logger.info("[WX] sendImage, receiver={}".format(receiver))

channel/wechat/wechat_message.py (+28 -28)

@@ -1,54 +1,54 @@


from bridge.context import ContextType
from channel.chat_message import ChatMessage
from common.tmp_dir import TmpDir
from common.log import logger
from lib.itchat.content import *
from common.tmp_dir import TmpDir
from lib import itchat
from lib.itchat.content import *

class WeChatMessage(ChatMessage):

class WeChatMessage(ChatMessage):
def __init__(self, itchat_msg, is_group=False):
super().__init__( itchat_msg)
self.msg_id = itchat_msg['MsgId']
self.create_time = itchat_msg['CreateTime']
super().__init__(itchat_msg)
self.msg_id = itchat_msg["MsgId"]
self.create_time = itchat_msg["CreateTime"]
self.is_group = is_group
if itchat_msg['Type'] == TEXT:
if itchat_msg["Type"] == TEXT:
self.ctype = ContextType.TEXT
self.content = itchat_msg['Text']
elif itchat_msg['Type'] == VOICE:
self.content = itchat_msg["Text"]
elif itchat_msg["Type"] == VOICE:
self.ctype = ContextType.VOICE
self.content = TmpDir().path() + itchat_msg['FileName'] # content直接存临时目录路径
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
self._prepare_fn = lambda: itchat_msg.download(self.content)
elif itchat_msg['Type'] == PICTURE and itchat_msg['MsgType'] == 3:
elif itchat_msg["Type"] == PICTURE and itchat_msg["MsgType"] == 3:
self.ctype = ContextType.IMAGE
self.content = TmpDir().path() + itchat_msg['FileName'] # content直接存临时目录路径
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
self._prepare_fn = lambda: itchat_msg.download(self.content)
else:
raise NotImplementedError("Unsupported message type: {}".format(itchat_msg['Type']))
self.from_user_id = itchat_msg['FromUserName']
self.to_user_id = itchat_msg['ToUserName']
raise NotImplementedError(
"Unsupported message type: {}".format(itchat_msg["Type"])
)

self.from_user_id = itchat_msg["FromUserName"]
self.to_user_id = itchat_msg["ToUserName"]

user_id = itchat.instance.storageClass.userName
nickname = itchat.instance.storageClass.nickName
# 虽然from_user_id和to_user_id用的少,但是为了保持一致性,还是要填充一下
# 以下很繁琐,一句话总结:能填的都填了。
if self.from_user_id == user_id:
self.from_user_nickname = nickname
if self.to_user_id == user_id:
self.to_user_nickname = nickname
try: # 陌生人时候, 'User'字段可能不存在
self.other_user_id = itchat_msg['User']['UserName']
self.other_user_nickname = itchat_msg['User']['NickName']
try: # 陌生人时候, 'User'字段可能不存在
self.other_user_id = itchat_msg["User"]["UserName"]
self.other_user_nickname = itchat_msg["User"]["NickName"]
if self.other_user_id == self.from_user_id:
self.from_user_nickname = self.other_user_nickname
if self.other_user_id == self.to_user_id:
self.to_user_nickname = self.other_user_nickname
except KeyError as e: # 处理偶尔没有对方信息的情况
except KeyError as e: # 处理偶尔没有对方信息的情况
logger.warn("[WX]get other_user_id failed: " + str(e))
if self.from_user_id == user_id:
self.other_user_id = self.to_user_id
@@ -56,6 +56,6 @@ class WeChatMessage(ChatMessage):
self.other_user_id = self.from_user_id

if self.is_group:
self.is_at = itchat_msg['IsAt']
self.actual_user_id = itchat_msg['ActualUserName']
self.actual_user_nickname = itchat_msg['ActualNickName']
self.is_at = itchat_msg["IsAt"]
self.actual_user_id = itchat_msg["ActualUserName"]
self.actual_user_nickname = itchat_msg["ActualNickName"]

channel/wechat/wechaty_channel.py (+54 -40)

@@ -4,104 +4,118 @@
wechaty channel
Python Wechaty - https://github.com/wechaty/python-wechaty
"""
import asyncio
import base64
import os
import time
import asyncio
from bridge.context import Context
from wechaty_puppet import FileBox
from wechaty import Wechaty, Contact

from wechaty import Contact, Wechaty
from wechaty.user import Message
from bridge.reply import *
from wechaty_puppet import FileBox

from bridge.context import *
from bridge.context import Context
from bridge.reply import *
from channel.chat_channel import ChatChannel
from channel.wechat.wechaty_message import WechatyMessage
from common.log import logger
from common.singleton import singleton
from config import conf

try:
from voice.audio_convert import any_to_sil
except Exception as e:
pass


@singleton
class WechatyChannel(ChatChannel):
NOT_SUPPORT_REPLYTYPE = []

def __init__(self):
super().__init__()

def startup(self):
config = conf()
token = config.get('wechaty_puppet_service_token')
os.environ['WECHATY_PUPPET_SERVICE_TOKEN'] = token
token = config.get("wechaty_puppet_service_token")
os.environ["WECHATY_PUPPET_SERVICE_TOKEN"] = token
asyncio.run(self.main())

async def main(self):
loop = asyncio.get_event_loop()
#将asyncio的loop传入处理线程
self.handler_pool._initializer= lambda: asyncio.set_event_loop(loop)
# 将asyncio的loop传入处理线程
self.handler_pool._initializer = lambda: asyncio.set_event_loop(loop)
self.bot = Wechaty()
self.bot.on('login', self.on_login)
self.bot.on('message', self.on_message)
self.bot.on("login", self.on_login)
self.bot.on("message", self.on_message)
await self.bot.start()

async def on_login(self, contact: Contact):
self.user_id = contact.contact_id
self.name = contact.name
logger.info('[WX] login user={}'.format(contact))
logger.info("[WX] login user={}".format(contact))

# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
def send(self, reply: Reply, context: Context):
receiver_id = context['receiver']
receiver_id = context["receiver"]
loop = asyncio.get_event_loop()
if context['isgroup']:
receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id),loop).result()
if context["isgroup"]:
receiver = asyncio.run_coroutine_threadsafe(
self.bot.Room.find(receiver_id), loop
).result()
else:
receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id),loop).result()
receiver = asyncio.run_coroutine_threadsafe(
self.bot.Contact.find(receiver_id), loop
).result()
msg = None
if reply.type == ReplyType.TEXT:
msg = reply.content
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
msg = reply.content
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
elif reply.type == ReplyType.VOICE:
voiceLength = None
file_path = reply.content
sil_file = os.path.splitext(file_path)[0] + '.sil'
sil_file = os.path.splitext(file_path)[0] + ".sil"
voiceLength = int(any_to_sil(file_path, sil_file))
if voiceLength >= 60000:
voiceLength = 60000
logger.info('[WX] voice too long, length={}, set to 60s'.format(voiceLength))
logger.info(
"[WX] voice too long, length={}, set to 60s".format(voiceLength)
)
# 发送语音
t = int(time.time())
msg = FileBox.from_file(sil_file, name=str(t) + '.sil')
msg = FileBox.from_file(sil_file, name=str(t) + ".sil")
if voiceLength is not None:
msg.metadata['voiceLength'] = voiceLength
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
msg.metadata["voiceLength"] = voiceLength
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
try:
os.remove(file_path)
if sil_file != file_path:
os.remove(sil_file)
except Exception as e:
pass
logger.info('[WX] sendVoice={}, receiver={}'.format(reply.content, receiver))
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
logger.info(
"[WX] sendVoice={}, receiver={}".format(reply.content, receiver)
)
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
img_url = reply.content
t = int(time.time())
msg = FileBox.from_url(url=img_url, name=str(t) + '.png')
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver))
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
msg = FileBox.from_url(url=img_url, name=str(t) + ".png")
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
image_storage = reply.content
image_storage.seek(0)
t = int(time.time())
msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + '.png')
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
logger.info('[WX] sendImage, receiver={}'.format(receiver))
msg = FileBox.from_base64(
base64.b64encode(image_storage.read()), str(t) + ".png"
)
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
logger.info("[WX] sendImage, receiver={}".format(receiver))

async def on_message(self, msg: Message):
"""
@@ -110,16 +124,16 @@ class WechatyChannel(ChatChannel):
try:
cmsg = await WechatyMessage(msg)
except NotImplementedError as e:
logger.debug('[WX] {}'.format(e))
logger.debug("[WX] {}".format(e))
return
except Exception as e:
logger.exception('[WX] {}'.format(e))
logger.exception("[WX] {}".format(e))
return
logger.debug('[WX] message:{}'.format(cmsg))
logger.debug("[WX] message:{}".format(cmsg))
room = msg.room() # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None
isgroup = room is not None
ctype = cmsg.ctype
context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg)
if context:
logger.info('[WX] receiveMsg={}, context={}'.format(cmsg, context))
self.produce(context)
logger.info("[WX] receiveMsg={}, context={}".format(cmsg, context))
self.produce(context)
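
The worker threads in this channel are synchronous, so every wechaty call is scheduled onto the bot's event loop with `asyncio.run_coroutine_threadsafe(...)` and awaited via `.result()`; the loop itself is handed to the pool through `handler_pool._initializer`. A minimal standalone sketch of that thread-to-loop handoff (no wechaty involved; `say` is a stand-in for an async API call such as `receiver.say(msg)`):

```
import asyncio
from concurrent.futures import ThreadPoolExecutor

async def say(text):
    # Stand-in for an async wechaty call.
    await asyncio.sleep(0.1)
    return f"sent: {text}"

async def main():
    loop = asyncio.get_running_loop()
    # Mirror: handler_pool._initializer = lambda: asyncio.set_event_loop(loop)
    pool = ThreadPoolExecutor(
        max_workers=2, initializer=lambda: asyncio.set_event_loop(loop)
    )

    def worker(text):
        # Runs in a plain worker thread: schedule the coroutine on the bot's
        # loop and block until it completes.
        fut = asyncio.run_coroutine_threadsafe(say(text), asyncio.get_event_loop())
        return fut.result()

    print(await loop.run_in_executor(pool, worker, "hello"))  # -> sent: hello

asyncio.run(main())
```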

channel/wechat/wechaty_message.py (+28 -18)

@@ -1,17 +1,21 @@
import asyncio
import re

from wechaty import MessageType
from wechaty.user import Message

from bridge.context import ContextType
from channel.chat_message import ChatMessage
from common.tmp_dir import TmpDir
from common.log import logger
from wechaty.user import Message
from common.tmp_dir import TmpDir


class aobject(object):
"""Inheriting this class allows you to define an async __init__.

So you can create objects by doing something like `await MyClass(params)`
"""

async def __new__(cls, *a, **kw):
instance = super().__new__(cls)
await instance.__init__(*a, **kw)
@@ -19,17 +23,18 @@ class aobject(object):

async def __init__(self):
pass
class WechatyMessage(ChatMessage, aobject):


class WechatyMessage(ChatMessage, aobject):
async def __init__(self, wechaty_msg: Message):
super().__init__(wechaty_msg)
room = wechaty_msg.room()

self.msg_id = wechaty_msg.message_id
self.create_time = wechaty_msg.payload.timestamp
self.is_group = room is not None
if wechaty_msg.type() == MessageType.MESSAGE_TYPE_TEXT:
self.ctype = ContextType.TEXT
self.content = wechaty_msg.text()
@@ -40,12 +45,17 @@ class WechatyMessage(ChatMessage, aobject):

def func():
loop = asyncio.get_event_loop()
asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content),loop).result()
asyncio.run_coroutine_threadsafe(
voice_file.to_file(self.content), loop
).result()

self._prepare_fn = func
else:
raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type()))
raise NotImplementedError(
"Unsupported message type: {}".format(wechaty_msg.type())
)

from_contact = wechaty_msg.talker() # 获取消息的发送者
self.from_user_id = from_contact.contact_id
self.from_user_nickname = from_contact.name
@@ -54,7 +64,7 @@ class WechatyMessage(ChatMessage, aobject):
# wecahty: from是消息实际发送者, to:所在群
# itchat: 如果是你发送群消息,from和to是你自己和所在群,如果是别人发群消息,from和to是所在群和你自己
# 但这个差别不影响逻辑,group中只使用到:1.用from来判断是否是自己发的,2.actual_user_id来判断实际发送用户
if self.is_group:
self.to_user_id = room.room_id
self.to_user_nickname = await room.topic()
@@ -63,22 +73,22 @@ class WechatyMessage(ChatMessage, aobject):
self.to_user_id = to_contact.contact_id
self.to_user_nickname = to_contact.name

if self.is_group or wechaty_msg.is_self(): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
if (
self.is_group or wechaty_msg.is_self()
): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
self.other_user_id = self.to_user_id
self.other_user_nickname = self.to_user_nickname
else:
self.other_user_id = self.from_user_id
self.other_user_nickname = self.from_user_nickname


if self.is_group: # wechaty群聊中,实际发送用户就是from_user
if self.is_group: # wechaty群聊中,实际发送用户就是from_user
self.is_at = await wechaty_msg.mention_self()
if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容
if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容
name = wechaty_msg.wechaty.user_self().name
pattern = f'@{name}(\u2005|\u0020)'
if re.search(pattern,self.content):
logger.debug(f'wechaty message {self.msg_id} include at')
pattern = f"@{name}(\u2005|\u0020)"
if re.search(pattern, self.content):
logger.debug(f"wechaty message {self.msg_id} include at")
self.is_at = True

self.actual_user_id = self.from_user_id


channel/wechatmp/README.md (+2 -2)

@@ -21,12 +21,12 @@ pip3 install web.py

The server verification code is already in place, so you do not need to add any code. You only need to add the following to `config.json` in the project root:
```
"channel_type": "wechatmp",
"channel_type": "wechatmp",
"wechatmp_token": "Token", # 微信公众平台的Token
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
"wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要
"wechatmp_app_secret": "", # 微信公众平台的appsecret,仅服务号需要
```
```
Then run `python3 app.py` to start the web server. It listens on port 8080 by default, but WeChat Official Account server configuration only accepts ports 80/443, and there are two ways to handle this. The first (recommended) approach is to forward port 80 to port 8080 with a port-forwarding rule (likewise for 443; note that 443 requires SSL, i.e. HTTPS access, and the certificate paths in `wechatmp_channel.py` must be adjusted accordingly):
```
sudo iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to-port 8080


channel/wechatmp/ServiceAccount.py (+35 -16)

@@ -1,46 +1,66 @@
import web
import time
import channel.wechatmp.reply as reply

import web

import channel.wechatmp.receive as receive
from config import conf
from common.log import logger
import channel.wechatmp.reply as reply
from bridge.context import *
from channel.wechatmp.common import *
from channel.wechatmp.common import *
from channel.wechatmp.wechatmp_channel import WechatMPChannel
from common.log import logger
from config import conf

# This class is instantiated once per query
class Query():

# This class is instantiated once per query
class Query:
def GET(self):
return verify_server(web.input())

def POST(self):
# Make sure to return the instance that first created, @singleton will do that.
# Make sure to return the instance that first created, @singleton will do that.
channel = WechatMPChannel()
try:
webData = web.data()
# logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
wechatmp_msg = receive.parse_xml(webData)
if wechatmp_msg.msg_type == 'text' or wechatmp_msg.msg_type == 'voice':
if wechatmp_msg.msg_type == "text" or wechatmp_msg.msg_type == "voice":
from_user = wechatmp_msg.from_user_id
message = wechatmp_msg.content.decode("utf-8")
message_id = wechatmp_msg.msg_id

logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message))
context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg)
logger.info(
"[wechatmp] {}:{} Receive post query {} {}: {}".format(
web.ctx.env.get("REMOTE_ADDR"),
web.ctx.env.get("REMOTE_PORT"),
from_user,
message_id,
message,
)
)
context = channel._compose_context(
ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg
)
if context:
# set private openai_api_key
# if from_user is not changed in itchat, this can be placed at chat_channel
user_data = conf().get_user_data(from_user)
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key
context["openai_api_key"] = user_data.get(
"openai_api_key"
) # None or user openai_api_key
channel.produce(context)
# The reply will be sent by channel.send() in another thread
return "success"

elif wechatmp_msg.msg_type == 'event':
logger.info("[wechatmp] Event {} from {}".format(wechatmp_msg.Event, wechatmp_msg.from_user_id))
elif wechatmp_msg.msg_type == "event":
logger.info(
"[wechatmp] Event {} from {}".format(
                        wechatmp_msg.content, wechatmp_msg.from_user_id
)
)
content = subscribe_msg()
replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content)
replyMsg = reply.TextMsg(
wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content
)
return replyMsg.send()
else:
logger.info("暂且不处理")
@@ -48,4 +68,3 @@ class Query():
except Exception as exc:
logger.exception(exc)
return exc


channel/wechatmp/SubscribeAccount.py (+98 -38)

@@ -1,81 +1,117 @@
import web
import time
import channel.wechatmp.reply as reply

import web

import channel.wechatmp.receive as receive
from config import conf
from common.log import logger
import channel.wechatmp.reply as reply
from bridge.context import *
from channel.wechatmp.common import *
from channel.wechatmp.common import *
from channel.wechatmp.wechatmp_channel import WechatMPChannel
from common.log import logger
from config import conf

# This class is instantiated once per query
class Query():

# This class is instantiated once per query
class Query:
def GET(self):
return verify_server(web.input())

def POST(self):
# Make sure to return the instance that first created, @singleton will do that.
# Make sure to return the instance that first created, @singleton will do that.
channel = WechatMPChannel()
try:
query_time = time.time()
webData = web.data()
logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
wechatmp_msg = receive.parse_xml(webData)
if wechatmp_msg.msg_type == 'text' or wechatmp_msg.msg_type == 'voice':
if wechatmp_msg.msg_type == "text" or wechatmp_msg.msg_type == "voice":
from_user = wechatmp_msg.from_user_id
to_user = wechatmp_msg.to_user_id
message = wechatmp_msg.content.decode("utf-8")
message_id = wechatmp_msg.msg_id

logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message))
logger.info(
"[wechatmp] {}:{} Receive post query {} {}: {}".format(
web.ctx.env.get("REMOTE_ADDR"),
web.ctx.env.get("REMOTE_PORT"),
from_user,
message_id,
message,
)
)
supported = True
if "【收到不支持的消息类型,暂无法显示】" in message:
supported = False # not supported, used to refresh
supported = False # not supported, used to refresh
cache_key = from_user

reply_text = ""
# New request
if cache_key not in channel.cache_dict and cache_key not in channel.running:
if (
cache_key not in channel.cache_dict
and cache_key not in channel.running
):
# The first query begin, reset the cache
context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg)
logger.debug("[wechatmp] context: {} {}".format(context, wechatmp_msg))
if message_id in channel.received_msgs: # received and finished
context = channel._compose_context(
ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg
)
logger.debug(
"[wechatmp] context: {} {}".format(context, wechatmp_msg)
)
if message_id in channel.received_msgs: # received and finished
# no return because of bandwords or other reasons
return "success"
if supported and context:
# set private openai_api_key
# if from_user is not changed in itchat, this can be placed at chat_channel
user_data = conf().get_user_data(from_user)
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key
context["openai_api_key"] = user_data.get(
"openai_api_key"
) # None or user openai_api_key
channel.received_msgs[message_id] = wechatmp_msg
channel.running.add(cache_key)
channel.produce(context)
else:
trigger_prefix = conf().get('single_chat_prefix',[''])[0]
trigger_prefix = conf().get("single_chat_prefix", [""])[0]
if trigger_prefix or not supported:
if trigger_prefix:
content = textwrap.dedent(f"""\
content = textwrap.dedent(
f"""\
请输入'{trigger_prefix}'接你想说的话跟我说话。
例如:
{trigger_prefix}你好,很高兴见到你。""")
{trigger_prefix}你好,很高兴见到你。"""
)
else:
content = textwrap.dedent("""\
content = textwrap.dedent(
"""\
你好,很高兴见到你。
请跟我说话吧。""")
请跟我说话吧。"""
)
else:
logger.error(f"[wechatmp] unknown error")
content = textwrap.dedent("""\
未知错误,请稍后再试""")
replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content)
content = textwrap.dedent(
"""\
未知错误,请稍后再试"""
)
replyMsg = reply.TextMsg(
wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content
)
return replyMsg.send()
channel.query1[cache_key] = False
channel.query2[cache_key] = False
channel.query3[cache_key] = False
# User request again, and the answer is not ready
elif cache_key in channel.running and channel.query1.get(cache_key) == True and channel.query2.get(cache_key) == True and channel.query3.get(cache_key) == True:
channel.query1[cache_key] = False #To improve waiting experience, this can be set to True.
channel.query2[cache_key] = False #To improve waiting experience, this can be set to True.
elif (
cache_key in channel.running
and channel.query1.get(cache_key) == True
and channel.query2.get(cache_key) == True
and channel.query3.get(cache_key) == True
):
channel.query1[
cache_key
] = False # To improve waiting experience, this can be set to True.
channel.query2[
cache_key
] = False # To improve waiting experience, this can be set to True.
channel.query3[cache_key] = False
# User request again, and the answer is ready
elif cache_key in channel.cache_dict:
@@ -84,7 +120,9 @@ class Query():
channel.query2[cache_key] = True
channel.query3[cache_key] = True

assert not (cache_key in channel.cache_dict and cache_key in channel.running)
assert not (
cache_key in channel.cache_dict and cache_key in channel.running
)

if channel.query1.get(cache_key) == False:
# The first query from wechat official server
@@ -128,14 +166,20 @@ class Query():
# Have waiting for 3x5 seconds
# return timeout message
reply_text = "【正在思考中,回复任意文字尝试获取回复】"
logger.info("[wechatmp] Three queries has finished For {}: {}".format(from_user, message_id))
logger.info(
"[wechatmp] Three queries has finished For {}: {}".format(
from_user, message_id
)
)
replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
return replyPost
else:
pass


if cache_key not in channel.cache_dict and cache_key not in channel.running:
if (
cache_key not in channel.cache_dict
and cache_key not in channel.running
):
# no return because of bandwords or other reasons
return "success"

@@ -147,26 +191,42 @@ class Query():

if cache_key in channel.cache_dict:
content = channel.cache_dict[cache_key]
if len(content.encode('utf8'))<=MAX_UTF8_LEN:
if len(content.encode("utf8")) <= MAX_UTF8_LEN:
reply_text = channel.cache_dict[cache_key]
channel.cache_dict.pop(cache_key)
else:
continue_text = "\n【未完待续,回复任意文字以继续】"
splits = split_string_by_utf8_length(content, MAX_UTF8_LEN - len(continue_text.encode('utf-8')), max_split= 1)
splits = split_string_by_utf8_length(
content,
MAX_UTF8_LEN - len(continue_text.encode("utf-8")),
max_split=1,
)
reply_text = splits[0] + continue_text
channel.cache_dict[cache_key] = splits[1]
logger.info("[wechatmp] {}:{} Do send {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), reply_text))
logger.info(
"[wechatmp] {}:{} Do send {}".format(
web.ctx.env.get("REMOTE_ADDR"),
web.ctx.env.get("REMOTE_PORT"),
reply_text,
)
)
replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
return replyPost

elif wechatmp_msg.msg_type == 'event':
logger.info("[wechatmp] Event {} from {}".format(wechatmp_msg.content, wechatmp_msg.from_user_id))
elif wechatmp_msg.msg_type == "event":
logger.info(
"[wechatmp] Event {} from {}".format(
wechatmp_msg.content, wechatmp_msg.from_user_id
)
)
content = subscribe_msg()
replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content)
replyMsg = reply.TextMsg(
wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content
)
return replyMsg.send()
else:
logger.info("暂且不处理")
return "success"
except Exception as exc:
logger.exception(exc)
return exc
return exc
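
The flow above works around the passive-reply limits of WeChat subscription accounts: the platform resends the same POST up to three times, roughly five seconds apart, and each attempt polls the shared cache before either returning the finished answer or letting the request time out so the next retry can pick it up. The sketch below compresses that contract into one function; it is a simplification of the logic above, and `start_worker` is a hypothetical stand-in for handing the message to the LLM worker.

```
import time

running = set()     # session ids whose answer is still being generated
cache = {}          # session id -> finished reply text

def start_worker(session_id, message):
    # Hypothetical stand-in: hand the message to the LLM worker, which will
    # eventually do cache[session_id] = reply and running.discard(session_id).
    pass

def on_wechat_post(session_id, message, attempt):
    """Called once per WeChat delivery attempt (the platform resends up to 3 times)."""
    if session_id not in cache and session_id not in running:
        running.add(session_id)
        start_worker(session_id, message)

    deadline = time.time() + 4             # stay under WeChat's ~5 second reply window
    while time.time() < deadline:
        if session_id in cache:
            return cache.pop(session_id)   # answer ready: send it as the passive reply
        time.sleep(0.1)

    if attempt >= 3:                       # last retry: ask the user to poke the bot again
        return "【正在思考中,回复任意文字尝试获取回复】"
    return None                            # let this attempt expire; WeChat will retry
```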

channel/wechatmp/common.py (+15 -10)

@@ -1,9 +1,11 @@
from config import conf
import hashlib
import textwrap

from config import conf

MAX_UTF8_LEN = 2048


class WeChatAPIException(Exception):
pass

@@ -16,13 +18,13 @@ def verify_server(data):
timestamp = data.timestamp
nonce = data.nonce
echostr = data.echostr
token = conf().get('wechatmp_token') #请按照公众平台官网\基本配置中信息填写
token = conf().get("wechatmp_token") # 请按照公众平台官网\基本配置中信息填写

data_list = [token, timestamp, nonce]
data_list.sort()
sha1 = hashlib.sha1()
# map(sha1.update, data_list) #python2
sha1.update("".join(data_list).encode('utf-8'))
sha1.update("".join(data_list).encode("utf-8"))
hashcode = sha1.hexdigest()
print("handle/GET func: hashcode, signature: ", hashcode, signature)
if hashcode == signature:
@@ -32,9 +34,11 @@ def verify_server(data):
except Exception as Argument:
return Argument


def subscribe_msg():
trigger_prefix = conf().get('single_chat_prefix',[''])[0]
msg = textwrap.dedent(f"""\
trigger_prefix = conf().get("single_chat_prefix", [""])[0]
msg = textwrap.dedent(
f"""\
感谢您的关注!
这里是ChatGPT,可以自由对话。
资源有限,回复较慢,请勿着急。
@@ -42,22 +46,23 @@ def subscribe_msg():
暂时不支持图片输入。
支持图片输出,画字开头的问题将回复图片链接。
支持角色扮演和文字冒险两种定制模式对话。
输入'{trigger_prefix}#帮助' 查看详细指令。""")
输入'{trigger_prefix}#帮助' 查看详细指令。"""
)
return msg


def split_string_by_utf8_length(string, max_length, max_split=0):
encoded = string.encode('utf-8')
encoded = string.encode("utf-8")
start, end = 0, 0
result = []
while end < len(encoded):
if max_split > 0 and len(result) >= max_split:
result.append(encoded[start:].decode('utf-8'))
result.append(encoded[start:].decode("utf-8"))
break
end = start + max_length
# 如果当前字节不是 UTF-8 编码的开始字节,则向前查找直到找到开始字节为止
while end < len(encoded) and (encoded[end] & 0b11000000) == 0b10000000:
end -= 1
result.append(encoded[start:end].decode('utf-8'))
result.append(encoded[start:end].decode("utf-8"))
start = end
return result
return result
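
`split_string_by_utf8_length` splits on encoded byte length (passive replies are capped at `MAX_UTF8_LEN` = 2048 bytes) and steps back to a UTF-8 start byte so multi-byte characters are never cut in half. A quick usage example, assuming the function as defined above:

```
text = "你好" * 400                        # 800 characters, 2400 bytes in UTF-8
continue_text = "\n【未完待续,回复任意文字以继续】"
first, rest = split_string_by_utf8_length(
    text, MAX_UTF8_LEN - len(continue_text.encode("utf-8")), max_split=1
)
print(len(first.encode("utf-8")))         # fits the limit and ends on a character boundary
print(first + continue_text)              # what the first passive reply actually contains
```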

channel/wechatmp/receive.py (+20 -18)

@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-#
# filename: receive.py
import xml.etree.ElementTree as ET

from bridge.context import ContextType
from channel.chat_message import ChatMessage
from common.log import logger
@@ -12,34 +13,35 @@ def parse_xml(web_data):
xmlData = ET.fromstring(web_data)
return WeChatMPMessage(xmlData)


class WeChatMPMessage(ChatMessage):
def __init__(self, xmlData):
super().__init__(xmlData)
self.to_user_id = xmlData.find('ToUserName').text
self.from_user_id = xmlData.find('FromUserName').text
self.create_time = xmlData.find('CreateTime').text
self.msg_type = xmlData.find('MsgType').text
self.to_user_id = xmlData.find("ToUserName").text
self.from_user_id = xmlData.find("FromUserName").text
self.create_time = xmlData.find("CreateTime").text
self.msg_type = xmlData.find("MsgType").text
try:
self.msg_id = xmlData.find('MsgId').text
self.msg_id = xmlData.find("MsgId").text
except:
self.msg_id = self.from_user_id+self.create_time
self.msg_id = self.from_user_id + self.create_time
self.is_group = False
# reply to other_user_id
self.other_user_id = self.from_user_id

if self.msg_type == 'text':
if self.msg_type == "text":
self.ctype = ContextType.TEXT
self.content = xmlData.find('Content').text.encode("utf-8")
elif self.msg_type == 'voice':
self.content = xmlData.find("Content").text.encode("utf-8")
elif self.msg_type == "voice":
self.ctype = ContextType.TEXT
self.content = xmlData.find('Recognition').text.encode("utf-8") # 接收语音识别结果
elif self.msg_type == 'image':
self.content = xmlData.find("Recognition").text.encode("utf-8") # 接收语音识别结果
elif self.msg_type == "image":
# not implemented
self.pic_url = xmlData.find('PicUrl').text
self.media_id = xmlData.find('MediaId').text
elif self.msg_type == 'event':
self.content = xmlData.find('Event').text
else: # video, shortvideo, location, link
self.pic_url = xmlData.find("PicUrl").text
self.media_id = xmlData.find("MediaId").text
elif self.msg_type == "event":
self.content = xmlData.find("Event").text
else: # video, shortvideo, location, link
# not implemented
pass
pass

channel/wechatmp/reply.py (+12 -9)

@@ -2,6 +2,7 @@
# filename: reply.py
import time


class Msg(object):
def __init__(self):
pass
@@ -9,13 +10,14 @@ class Msg(object):
def send(self):
return "success"


class TextMsg(Msg):
def __init__(self, toUserName, fromUserName, content):
self.__dict = dict()
self.__dict['ToUserName'] = toUserName
self.__dict['FromUserName'] = fromUserName
self.__dict['CreateTime'] = int(time.time())
self.__dict['Content'] = content
self.__dict["ToUserName"] = toUserName
self.__dict["FromUserName"] = fromUserName
self.__dict["CreateTime"] = int(time.time())
self.__dict["Content"] = content

def send(self):
XmlForm = """
@@ -29,13 +31,14 @@ class TextMsg(Msg):
"""
return XmlForm.format(**self.__dict)


class ImageMsg(Msg):
def __init__(self, toUserName, fromUserName, mediaId):
self.__dict = dict()
self.__dict['ToUserName'] = toUserName
self.__dict['FromUserName'] = fromUserName
self.__dict['CreateTime'] = int(time.time())
self.__dict['MediaId'] = mediaId
self.__dict["ToUserName"] = toUserName
self.__dict["FromUserName"] = fromUserName
self.__dict["CreateTime"] = int(time.time())
self.__dict["MediaId"] = mediaId

def send(self):
XmlForm = """
@@ -49,4 +52,4 @@ class ImageMsg(Msg):
</Image>
</xml>
"""
return XmlForm.format(**self.__dict)
return XmlForm.format(**self.__dict)

channel/wechatmp/wechatmp_channel.py (+47 -38)

@@ -1,17 +1,19 @@
# -*- coding: utf-8 -*-
import web
import time
import json
import requests
import threading
from common.singleton import singleton
from common.log import logger
from common.expired_dict import ExpiredDict
from config import conf
from bridge.reply import *
import time
import requests
import web
from bridge.context import *
from bridge.reply import *
from channel.chat_channel import ChatChannel
from channel.wechatmp.common import *
from channel.wechatmp.common import *
from common.expired_dict import ExpiredDict
from common.log import logger
from common.singleton import singleton
from config import conf

# If using SSL, uncomment the following lines, and modify the certificate path.
# from cheroot.server import HTTPServer
@@ -20,13 +22,14 @@ from channel.wechatmp.common import *
# certificate='/ssl/cert.pem',
# private_key='/ssl/cert.key')


@singleton
class WechatMPChannel(ChatChannel):
def __init__(self, passive_reply = True):
def __init__(self, passive_reply=True):
super().__init__()
self.passive_reply = passive_reply
self.running = set()
self.received_msgs = ExpiredDict(60*60*24)
self.received_msgs = ExpiredDict(60 * 60 * 24)
if self.passive_reply:
self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE]
self.cache_dict = dict()
@@ -36,8 +39,8 @@ class WechatMPChannel(ChatChannel):
else:
# TODO support image
self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE]
self.app_id = conf().get('wechatmp_app_id')
self.app_secret = conf().get('wechatmp_app_secret')
self.app_id = conf().get("wechatmp_app_id")
self.app_secret = conf().get("wechatmp_app_secret")
self.access_token = None
self.access_token_expires_time = 0
self.access_token_lock = threading.Lock()
@@ -45,13 +48,12 @@ class WechatMPChannel(ChatChannel):

def startup(self):
if self.passive_reply:
urls = ('/wx', 'channel.wechatmp.SubscribeAccount.Query')
urls = ("/wx", "channel.wechatmp.SubscribeAccount.Query")
else:
urls = ('/wx', 'channel.wechatmp.ServiceAccount.Query')
urls = ("/wx", "channel.wechatmp.ServiceAccount.Query")
app = web.application(urls, globals(), autoreload=False)
port = conf().get('wechatmp_port', 8080)
web.httpserver.runsimple(app.wsgifunc(), ('0.0.0.0', port))

port = conf().get("wechatmp_port", 8080)
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))

def wechatmp_request(self, method, url, **kwargs):
r = requests.request(method=method, url=url, **kwargs)
@@ -63,7 +65,6 @@ class WechatMPChannel(ChatChannel):
return ret

def get_access_token(self):

# return the access_token
if self.access_token:
if self.access_token_expires_time - time.time() > 60:
@@ -76,15 +77,15 @@ class WechatMPChannel(ChatChannel):
# This happens every 2 hours, so it doesn't affect the experience very much
time.sleep(1)
self.access_token = None
url="https://api.weixin.qq.com/cgi-bin/token"
params={
url = "https://api.weixin.qq.com/cgi-bin/token"
params = {
"grant_type": "client_credential",
"appid": self.app_id,
"secret": self.app_secret
"secret": self.app_secret,
}
data = self.wechatmp_request(method='get', url=url, params=params)
self.access_token = data['access_token']
self.access_token_expires_time = int(time.time()) + data['expires_in']
data = self.wechatmp_request(method="get", url=url, params=params)
self.access_token = data["access_token"]
self.access_token_expires_time = int(time.time()) + data["expires_in"]
logger.info("[wechatmp] access_token: {}".format(self.access_token))
self.access_token_lock.release()
else:
@@ -101,29 +102,37 @@ class WechatMPChannel(ChatChannel):
else:
receiver = context["receiver"]
reply_text = reply.content
url="https://api.weixin.qq.com/cgi-bin/message/custom/send"
params = {
"access_token": self.get_access_token()
}
url = "https://api.weixin.qq.com/cgi-bin/message/custom/send"
params = {"access_token": self.get_access_token()}
json_data = {
"touser": receiver,
"msgtype": "text",
"text": {"content": reply_text}
"text": {"content": reply_text},
}
self.wechatmp_request(method='post', url=url, params=params, data=json.dumps(json_data, ensure_ascii=False).encode('utf8'))
self.wechatmp_request(
method="post",
url=url,
params=params,
data=json.dumps(json_data, ensure_ascii=False).encode("utf8"),
)
logger.info("[send] Do send to {}: {}".format(receiver, reply_text))
return


def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数
logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context['msg'].msg_id))
    def _success_callback(self, session_id, context, **kwargs):  # 线程正常结束时的回调函数
logger.debug(
"[wechatmp] Success to generate reply, msgId={}".format(
context["msg"].msg_id
)
)
if self.passive_reply:
self.running.remove(session_id)


def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数
logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context['msg'].msg_id, exception))
def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数
logger.exception(
"[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(
context["msg"].msg_id, exception
)
)
if self.passive_reply:
assert session_id not in self.cache_dict
self.running.remove(session_id)
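
`get_access_token` above caches the WeChat access_token and refreshes it under a non-blocking lock shortly before expiry, so only one thread hits the token endpoint at a time. A compact sketch of that cache-with-lock idiom; `fake_refresh` stands in for the actual request to `https://api.weixin.qq.com/cgi-bin/token`:

```
import threading
import time

class TokenCache:
    def __init__(self, refresh):
        self._refresh = refresh          # callable returning (token, expires_in_seconds)
        self._token = None
        self._expires_at = 0
        self._lock = threading.Lock()

    def get(self):
        while True:
            # Fast path: token still valid for at least a minute.
            if self._token and self._expires_at - time.time() > 60:
                return self._token
            if self._lock.acquire(blocking=False):
                try:
                    token, expires_in = self._refresh()
                    self._token = token
                    self._expires_at = time.time() + expires_in
                finally:
                    self._lock.release()
            else:
                time.sleep(1)            # another thread is refreshing; poll again

def fake_refresh():
    # Stand-in for GET https://api.weixin.qq.com/cgi-bin/token with appid/secret.
    return "EXAMPLE_ACCESS_TOKEN", 7200

cache = TokenCache(fake_refresh)
print(cache.get())
```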


common/const.py (+1 -1)

@@ -2,4 +2,4 @@
OPEN_AI = "openAI"
CHATGPT = "chatGPT"
BAIDU = "baidu"
CHATGPTONAZURE = "chatGPTOnAzure"
CHATGPTONAZURE = "chatGPTOnAzure"

common/dequeue.py (+2 -2)

@@ -1,7 +1,7 @@

from queue import Full, Queue
from time import monotonic as time


# add implementation of putleft to Queue
class Dequeue(Queue):
def putleft(self, item, block=True, timeout=None):
@@ -30,4 +30,4 @@ class Dequeue(Queue):
return self.putleft(item, block=False)

def _putleft(self, item):
self.queue.appendleft(item)
self.queue.appendleft(item)

common/expired_dict.py (+1 -1)

@@ -39,4 +39,4 @@ class ExpiredDict(dict):
return [(key, self[key]) for key in self.keys()]

def __iter__(self):
return self.keys().__iter__()
return self.keys().__iter__()

common/log.py (+16 -7)

@@ -10,20 +10,29 @@ def _reset_logger(log):
log.handlers.clear()
log.propagate = False
console_handle = logging.StreamHandler(sys.stdout)
console_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'))
file_handle = logging.FileHandler('run.log', encoding='utf-8')
file_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'))
console_handle.setFormatter(
logging.Formatter(
"[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
)
file_handle = logging.FileHandler("run.log", encoding="utf-8")
file_handle.setFormatter(
logging.Formatter(
"[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
)
log.addHandler(file_handle)
log.addHandler(console_handle)


def _get_logger():
log = logging.getLogger('log')
log = logging.getLogger("log")
_reset_logger(log)
log.setLevel(logging.INFO)
return log


# 日志句柄
logger = _get_logger()
logger = _get_logger()

common/package_manager.py (+11 -5)

@@ -1,15 +1,20 @@
import time

import pip
from pip._internal import main as pipmain
from common.log import logger,_reset_logger

from common.log import _reset_logger, logger


def install(package):
pipmain(['install', package])
pipmain(["install", package])


def install_requirements(file):
pipmain(['install', '-r', file, "--upgrade"])
pipmain(["install", "-r", file, "--upgrade"])
_reset_logger(logger)


def check_dulwich():
needwait = False
for i in range(2):
@@ -18,13 +23,14 @@ def check_dulwich():
needwait = False
try:
import dulwich

return
except ImportError:
try:
install('dulwich')
install("dulwich")
except:
needwait = True
try:
import dulwich
except ImportError:
raise ImportError("Unable to import dulwich")
raise ImportError("Unable to import dulwich")

common/sorted_dict.py (+1 -1)

@@ -62,4 +62,4 @@ class SortedDict(dict):
return iter(self.keys())

def __repr__(self):
return f'{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})'
return f"{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})"

common/time_check.py (+22 -10)

@@ -1,7 +1,11 @@
import time,re,hashlib
import hashlib
import re
import time

import config
from common.log import logger


def time_checker(f):
def _time_checker(self, *args, **kwargs):
_config = config.conf()
@@ -9,17 +13,25 @@ def time_checker(f):
if chat_time_module:
chat_start_time = _config.get("chat_start_time", "00:00")
chat_stopt_time = _config.get("chat_stop_time", "24:00")
time_regex = re.compile(r'^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$') #时间匹配,包含24:00
time_regex = re.compile(
r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$"
) # 时间匹配,包含24:00

starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式
stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式
chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间
chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间

# 时间格式检查
if not (starttime_format_check and stoptime_format_check and chat_time_check):
logger.warn('时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})'.format(starttime_format_check,stoptime_format_check))
if chat_start_time>"23:59":
logger.error('启动时间可能存在问题,请修改!')
if not (
starttime_format_check and stoptime_format_check and chat_time_check
):
logger.warn(
"时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(
starttime_format_check, stoptime_format_check
)
)
if chat_start_time > "23:59":
logger.error("启动时间可能存在问题,请修改!")

# 服务时间检查
now_time = time.strftime("%H:%M", time.localtime())
@@ -27,12 +39,12 @@ def time_checker(f):
f(self, *args, **kwargs)
return None
else:
if args[0]['Content'] == "#更新配置": # 不在服务时间内也可以更新配置
if args[0]["Content"] == "#更新配置": # 不在服务时间内也可以更新配置
f(self, *args, **kwargs)
else:
logger.info('非服务时间内,不接受访问')
logger.info("非服务时间内,不接受访问")
return None
else:
f(self, *args, **kwargs) # 未开启时间模块则直接回答
return _time_checker

return _time_checker
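
The decorator compares `HH:MM` strings lexicographically, which works because zero-padded time strings sort in the same order as the times they represent. A tiny illustration of the window check it performs (assuming, as the code above does, that the start time is earlier than the stop time):

```
import time

def in_service_window(start="08:00", stop="22:00"):
    now = time.strftime("%H:%M", time.localtime())
    return start <= now <= stop          # zero-padded HH:MM strings compare like times

print(in_service_window())               # True only between 08:00 and 22:00 local time
```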

common/tmp_dir.py (+5 -7)

@@ -1,20 +1,18 @@

import os
import pathlib

from config import conf


class TmpDir(object):
"""A temporary directory that is deleted when the object is destroyed.
"""
"""A temporary directory that is deleted when the object is destroyed."""

tmpFilePath = pathlib.Path("./tmp/")

tmpFilePath = pathlib.Path('./tmp/')
def __init__(self):
pathExists = os.path.exists(self.tmpFilePath)
if not pathExists:
os.makedirs(self.tmpFilePath)

def path(self):
return str(self.tmpFilePath) + '/'
return str(self.tmpFilePath) + "/"

config-template.json (+20 -6)

@@ -2,16 +2,30 @@
"open_ai_api_key": "YOUR API KEY",
"model": "gpt-3.5-turbo",
"proxy": "",
"single_chat_prefix": ["bot", "@bot"],
"single_chat_prefix": [
"bot",
"@bot"
],
"single_chat_reply_prefix": "[bot] ",
"group_chat_prefix": ["@bot"],
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"],
"group_chat_in_one_session": ["ChatGPT测试群"],
"image_create_prefix": ["画", "看", "找"],
"group_chat_prefix": [
"@bot"
],
"group_name_white_list": [
"ChatGPT测试群",
"ChatGPT测试群2"
],
"group_chat_in_one_session": [
"ChatGPT测试群"
],
"image_create_prefix": [
"画",
"看",
"找"
],
"speech_recognition": false,
"group_speech_recognition": false,
"voice_reply_voice": false,
"conversation_max_tokens": 1000,
"expires_in_seconds": 3600,
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。"
}
}

config.py (+21 -32)

@@ -3,9 +3,10 @@
import json
import logging
import os
from common.log import logger
import pickle

from common.log import logger

# 将所有可用的配置项写在字典里, 请使用小写字母
available_setting = {
# openai api配置
@@ -16,8 +17,7 @@ available_setting = {
# chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
"model": "gpt-3.5-turbo",
"use_azure_chatgpt": False, # 是否使用azure的chatgpt
"azure_deployment_id": "", #azure 模型部署名称

"azure_deployment_id": "", # azure 模型部署名称
# Bot触发配置
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
@@ -30,25 +30,21 @@ available_setting = {
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
"trigger_by_self": False, # 是否允许机器人触发
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
"concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中,大于1可能乱序

"concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中,大于1可能乱序
# chatgpt会话参数
"expires_in_seconds": 3600, # 无操作会话的过期时间
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数

# chatgpt限流配置
"rate_limit_chatgpt": 20, # chatgpt的调用频率限制
"rate_limit_dalle": 50, # openai dalle的调用频率限制

# chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create
"temperature": 0.9,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"request_timeout": 60, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
"timeout": 120, # chatgpt重试超时时间,在这个时间内,将会自动重试

"request_timeout": 60, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
"timeout": 120, # chatgpt重试超时时间,在这个时间内,将会自动重试
# 语音设置
"speech_recognition": False, # 是否开启语音识别
"group_speech_recognition": False, # 是否开启群组语音识别
@@ -56,50 +52,40 @@ available_setting = {
"always_reply_voice": False, # 是否一直使用语音回复
"voice_to_text": "openai", # 语音识别引擎,支持openai,baidu,google,azure
"text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline),azure

# baidu 语音api配置, 使用百度语音识别和语音合成时需要
"baidu_app_id": "",
"baidu_api_key": "",
"baidu_secret_key": "",
# 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场
"baidu_dev_pid": "1536",

# azure 语音api配置, 使用azure语音识别和语音合成时需要
"azure_voice_api_key": "",
"azure_voice_region": "japaneast",

# 服务时间限制,目前支持itchat
"chat_time_module": False, # 是否开启服务时间限制
"chat_start_time": "00:00", # 服务开始时间
"chat_stop_time": "24:00", # 服务结束时间

# itchat的配置
"hot_reload": False, # 是否开启热重载

# wechaty的配置
"wechaty_puppet_service_token": "", # wechaty的token

# wechatmp的配置
"wechatmp_token": "", # 微信公众平台的Token
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
"wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要
"wechatmp_token": "", # 微信公众平台的Token
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
"wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要
"wechatmp_app_secret": "", # 微信公众平台的appsecret,仅服务号需要

# chatgpt指令自定义触发词
"clear_memory_commands": ['#清除记忆'], # 重置会话指令,必须以#开头

"clear_memory_commands": ["#清除记忆"], # 重置会话指令,必须以#开头
# channel配置
"channel_type": "wx", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service}

"channel_type": "wx", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service}
"debug": False, # 是否开启debug模式,开启后会打印更多日志

# 插件配置
"plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突
}


class Config(dict):
def __init__(self, d:dict={}):
def __init__(self, d: dict = {}):
super().__init__(d)
# user_datas: 用户数据,key为用户名,value为用户数据,也是dict
self.user_datas = {}
@@ -130,7 +116,7 @@ class Config(dict):

def load_user_datas(self):
try:
with open('user_datas.pkl', 'rb') as f:
with open("user_datas.pkl", "rb") as f:
self.user_datas = pickle.load(f)
logger.info("[Config] User datas loaded.")
except FileNotFoundError as e:
@@ -141,12 +127,13 @@ class Config(dict):

def save_user_datas(self):
try:
with open('user_datas.pkl', 'wb') as f:
with open("user_datas.pkl", "wb") as f:
pickle.dump(self.user_datas, f)
logger.info("[Config] User datas saved.")
except Exception as e:
logger.info("[Config] User datas error: {}".format(e))


config = Config()


@@ -154,7 +141,7 @@ def load_config():
global config
config_path = "./config.json"
if not os.path.exists(config_path):
logger.info('配置文件不存在,将使用config-template.json模板')
logger.info("配置文件不存在,将使用config-template.json模板")
config_path = "./config-template.json"

config_str = read_file(config_path)
@@ -169,7 +156,8 @@ def load_config():
name = name.lower()
if name in available_setting:
logger.info(
"[INIT] override config by environ args: {}={}".format(name, value))
"[INIT] override config by environ args: {}={}".format(name, value)
)
try:
config[name] = eval(value)
except:
@@ -182,18 +170,19 @@ def load_config():

if config.get("debug", False):
logger.setLevel(logging.DEBUG)
logger.debug("[INIT] set log level to DEBUG")
logger.debug("[INIT] set log level to DEBUG")

logger.info("[INIT] load config: {}".format(config))

config.load_user_datas()


def get_root():
return os.path.dirname(os.path.abspath(__file__))


def read_file(path):
with open(path, mode='r', encoding='utf-8') as f:
with open(path, mode="r", encoding="utf-8") as f:
return f.read()
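As a companion to the environment-override loop above (the "[INIT] override config by environ args" branch), here is a small sketch of overriding settings at startup. It assumes `load_config()` scans `os.environ` for names present in `available_setting`, as the log message indicates; because values pass through `eval()`, non-string settings use Python literal syntax.

```python
import os

# Set overrides before load_config() runs; names are matched case-insensitively
# against available_setting, and eval() turns the strings into Python values.
os.environ["DEBUG"] = "True"
os.environ["SINGLE_CHAT_PREFIX"] = '["bot", "@bot"]'

from config import conf, load_config

load_config()
assert conf().get("debug") is True
assert conf().get("single_chat_prefix") == ["bot", "@bot"]
```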




+ 1
- 1
docker/Dockerfile.debian

@@ -33,7 +33,7 @@ ADD ./entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh \
&& groupadd -r noroot \
&& useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \
&& chown -R noroot:noroot ${BUILD_PREFIX}
&& chown -R noroot:noroot ${BUILD_PREFIX}

USER noroot



+ 1
- 1
docker/Dockerfile.debian.latest

@@ -18,7 +18,7 @@ RUN apt-get update \
&& pip install --no-cache -r requirements.txt \
&& pip install --no-cache -r requirements-optional.txt \
&& pip install azure-cognitiveservices-speech
WORKDIR ${BUILD_PREFIX}

ADD docker/entrypoint.sh /entrypoint.sh


+ 1
- 2
docker/build.alpine.sh

@@ -11,6 +11,5 @@ docker build -f Dockerfile.alpine \
-t zhayujie/chatgpt-on-wechat .

# tag image
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:alpine
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:alpine
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-alpine

+ 1
- 1
docker/build.debian.sh

@@ -11,5 +11,5 @@ docker build -f Dockerfile.debian \
-t zhayujie/chatgpt-on-wechat .

# tag image
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:debian
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:debian
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-debian

+ 1
- 1
docker/chatgpt-on-wechat-voice-reply/Dockerfile.alpine

@@ -9,7 +9,7 @@ RUN apk add --no-cache \
ffmpeg \
espeak \
&& pip install --no-cache \
baidu-aip \
baidu-aip \
chardet \
SpeechRecognition



+ 1
- 1
docker/chatgpt-on-wechat-voice-reply/Dockerfile.debian

@@ -10,7 +10,7 @@ RUN apt-get update \
ffmpeg \
espeak \
&& pip install --no-cache \
baidu-aip \
baidu-aip \
chardet \
SpeechRecognition



+ 2
- 2
docker/sample-chatgpt-on-wechat/Makefile

@@ -11,13 +11,13 @@ run_d:
docker rm $(CONTAINER_NAME) || echo
docker run -dt --name $(CONTAINER_NAME) $(PORT_MAP) \
--env-file=$(DOTENV) \
$(MOUNT) $(IMG)
$(MOUNT) $(IMG)

run_i:
docker rm $(CONTAINER_NAME) || echo
docker run -it --name $(CONTAINER_NAME) $(PORT_MAP) \
--env-file=$(DOTENV) \
$(MOUNT) $(IMG)
$(MOUNT) $(IMG)

stop:
docker stop $(CONTAINER_NAME)


+ 14
- 14
plugins/README.md

@@ -24,17 +24,17 @@
在本仓库中预置了一些插件,如果要安装其他仓库的插件,有两种方法。

- 第一种方法是在将下载的插件文件都解压到"plugins"文件夹的一个单独的文件夹,最终插件的代码都位于"plugins/PLUGIN_NAME/*"中。启动程序后,如果插件的目录结构正确,插件会自动被扫描加载。除此以外,注意你还需要安装文件夹中`requirements.txt`中的依赖。
- 第二种方法是`Godcmd`插件,它是预置的管理员插件,能够让程序在运行时就能安装插件,它能够自动安装依赖。
安装插件的命令是"#installp [仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件名/仓库地址"。这是管理员命令,认证方法在[这里](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/godcmd)。
- 安装[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件:#installp sdwebui

- 安装指定仓库的插件:#installp https://github.com/lanvent/plugin_sdwebui.git
在安装之后,需要执行"#scanp"命令来扫描加载新安装的插件(或者重新启动程序)。
安装插件后需要注意有些插件有自己的配置模板,一般要去掉".template"新建一个配置文件。

## 插件化实现
@@ -107,14 +107,14 @@
```

回复`Reply`的定义如下所示,它允许Bot可以回复多类不同的消息。同时也加入了`INFO`和`ERROR`消息类型区分系统提示和系统错误。
```python
class ReplyType(Enum):
TEXT = 1 # 文本
VOICE = 2 # 音频文件
IMAGE = 3 # 图片文件
IMAGE_URL = 4 # 图片URL
INFO = 9
ERROR = 10
class Reply:
@@ -159,12 +159,12 @@

目前支持三类触发事件:
```
1.收到消息
---> `ON_HANDLE_CONTEXT`
2.产生回复
---> `ON_DECORATE_REPLY`
3.装饰回复
---> `ON_SEND_REPLY`
1.收到消息
---> `ON_HANDLE_CONTEXT`
2.产生回复
---> `ON_DECORATE_REPLY`
3.装饰回复
---> `ON_SEND_REPLY`
4.发送回复
```

@@ -268,6 +268,6 @@ class Hello(Plugin):
- 一个插件目录建议只注册一个插件类。建议使用单独的仓库维护插件,便于更新。

在测试调试好后提交`PR`,把自己的仓库加入到[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)中。
- 插件的config文件、使用说明`README.md`、`requirement.txt`等放置在插件目录中。
- 默认优先级不要超过管理员插件`Godcmd`的优先级(999),`Godcmd`插件提供了配置管理、插件管理等功能。

+ 2
- 2
plugins/__init__.py

@@ -1,9 +1,9 @@
from .plugin_manager import PluginManager
from .event import *
from .plugin import *
from .plugin_manager import PluginManager

instance = PluginManager()

register = instance.register
register = instance.register
# load_plugins = instance.load_plugins
# emit_event = instance.emit_event

+ 1
- 1
plugins/banwords/__init__.py

@@ -1 +1 @@
from .banwords import *
from .banwords import *

+ 49
- 35
plugins/banwords/banwords.py

@@ -2,56 +2,67 @@

import json
import os

import plugins
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
import plugins
from plugins import *
from common.log import logger
from plugins import *

from .lib.WordsSearch import WordsSearch


@plugins.register(name="Banwords", desire_priority=100, hidden=True, desc="判断消息中是否有敏感词、决定是否回复。", version="1.0", author="lanvent")
@plugins.register(
name="Banwords",
desire_priority=100,
hidden=True,
desc="判断消息中是否有敏感词、决定是否回复。",
version="1.0",
author="lanvent",
)
class Banwords(Plugin):
def __init__(self):
super().__init__()
try:
curdir=os.path.dirname(__file__)
config_path=os.path.join(curdir,"config.json")
conf=None
curdir = os.path.dirname(__file__)
config_path = os.path.join(curdir, "config.json")
conf = None
if not os.path.exists(config_path):
conf={"action":"ignore"}
with open(config_path,"w") as f:
json.dump(conf,f,indent=4)
conf = {"action": "ignore"}
with open(config_path, "w") as f:
json.dump(conf, f, indent=4)
else:
with open(config_path,"r") as f:
conf=json.load(f)
with open(config_path, "r") as f:
conf = json.load(f)
self.searchr = WordsSearch()
self.action = conf["action"]
banwords_path = os.path.join(curdir,"banwords.txt")
with open(banwords_path, 'r', encoding='utf-8') as f:
words=[]
banwords_path = os.path.join(curdir, "banwords.txt")
with open(banwords_path, "r", encoding="utf-8") as f:
words = []
for line in f:
word = line.strip()
if word:
words.append(word)
self.searchr.SetKeywords(words)
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
if conf.get("reply_filter",True):
if conf.get("reply_filter", True):
self.handlers[Event.ON_DECORATE_REPLY] = self.on_decorate_reply
self.reply_action = conf.get("reply_action","ignore")
self.reply_action = conf.get("reply_action", "ignore")
logger.info("[Banwords] inited")
except Exception as e:
logger.warn("[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords .")
logger.warn(
"[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords ."
)
raise e


def on_handle_context(self, e_context: EventContext):

if e_context['context'].type not in [ContextType.TEXT,ContextType.IMAGE_CREATE]:
if e_context["context"].type not in [
ContextType.TEXT,
ContextType.IMAGE_CREATE,
]:
return
content = e_context['context'].content
content = e_context["context"].content
logger.debug("[Banwords] on_handle_context. content: %s" % content)
if self.action == "ignore":
f = self.searchr.FindFirst(content)
@@ -61,31 +72,34 @@ class Banwords(Plugin):
return
elif self.action == "replace":
if self.searchr.ContainsAny(content):
reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n"+self.searchr.Replace(content))
e_context['reply'] = reply
reply = Reply(
ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content)
)
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
return
def on_decorate_reply(self, e_context: EventContext):

if e_context['reply'].type not in [ReplyType.TEXT]:
def on_decorate_reply(self, e_context: EventContext):
if e_context["reply"].type not in [ReplyType.TEXT]:
return
reply = e_context['reply']
reply = e_context["reply"]
content = reply.content
if self.reply_action == "ignore":
f = self.searchr.FindFirst(content)
if f:
logger.info("[Banwords] %s in reply" % f["Keyword"])
e_context['reply'] = None
e_context["reply"] = None
e_context.action = EventAction.BREAK_PASS
return
elif self.reply_action == "replace":
if self.searchr.ContainsAny(content):
reply = Reply(ReplyType.INFO, "已替换回复中的敏感词: \n"+self.searchr.Replace(content))
e_context['reply'] = reply
reply = Reply(
ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content)
)
e_context["reply"] = reply
e_context.action = EventAction.CONTINUE
return
def get_help_text(self, **kwargs):
return Banwords.desc
return Banwords.desc

+ 4
- 4
plugins/banwords/config.json.template

@@ -1,5 +1,5 @@
{
"action": "replace",
"reply_filter": true,
"reply_action": "ignore"
}
"action": "replace",
"reply_filter": true,
"reply_action": "ignore"
}

+ 1
- 1
plugins/bdunit/README.md

@@ -24,7 +24,7 @@ see https://ai.baidu.com/unit/home#/home?track=61fe1b0d3407ce3face1d92cb5c291087
``` json
{
"service_id": "s...", #"机器人ID"
"api_key": "",
"api_key": "",
"secret_key": ""
}
```

+ 1
- 1
plugins/bdunit/__init__.py

@@ -1 +1 @@
from .bdunit import *
from .bdunit import *

+ 30
- 42
plugins/bdunit/bdunit.py

@@ -2,21 +2,29 @@
import json
import os
import uuid
from uuid import getnode as get_mac

import requests

import plugins
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
import plugins
from plugins import *
from uuid import getnode as get_mac


"""利用百度UNIT实现智能对话
如果命中意图,返回意图对应的回复,否则返回继续交付给下个插件处理
"""


@plugins.register(name="BDunit", desire_priority=0, hidden=True, desc="Baidu unit bot system", version="0.1", author="jackson")
@plugins.register(
name="BDunit",
desire_priority=0,
hidden=True,
desc="Baidu unit bot system",
version="0.1",
author="jackson",
)
class BDunit(Plugin):
def __init__(self):
super().__init__()
@@ -40,11 +48,10 @@ class BDunit(Plugin):
raise e

def on_handle_context(self, e_context: EventContext):

if e_context['context'].type != ContextType.TEXT:
if e_context["context"].type != ContextType.TEXT:
return

content = e_context['context'].content
content = e_context["context"].content
logger.debug("[BDunit] on_handle_context. content: %s" % content)
parsed = self.getUnit2(content)
intent = self.getIntent(parsed)
@@ -53,7 +60,7 @@ class BDunit(Plugin):
reply = Reply()
reply.type = ReplyType.TEXT
reply.content = self.getSay(parsed)
e_context['reply'] = reply
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
else:
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
@@ -70,17 +77,15 @@ class BDunit(Plugin):
string: access_token
"""
url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(
self.api_key, self.secret_key)
self.api_key, self.secret_key
)
payload = ""
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json'
}
headers = {"Content-Type": "application/json", "Accept": "application/json"}

response = requests.request("POST", url, headers=headers, data=payload)

# print(response.text)
return response.json()['access_token']
return response.json()["access_token"]

def getUnit(self, query):
"""
@@ -90,11 +95,14 @@ class BDunit(Plugin):
"""

url = (
'https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token='
"https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token="
+ self.access_token
)
request = {"query": query, "user_id": str(
get_mac())[:32], "terminal_id": "88888"}
request = {
"query": query,
"user_id": str(get_mac())[:32],
"terminal_id": "88888",
}
body = {
"log_id": str(uuid.uuid1()),
"version": "3.0",
@@ -142,11 +150,7 @@ class BDunit(Plugin):
:param parsed: UNIT 解析结果
:returns: 意图数组
"""
if (
parsed
and "result" in parsed
and "response_list" in parsed["result"]
):
if parsed and "result" in parsed and "response_list" in parsed["result"]:
try:
return parsed["result"]["response_list"][0]["schema"]["intent"]
except Exception as e:
@@ -163,11 +167,7 @@ class BDunit(Plugin):
:param intent: 意图的名称
:returns: True: 包含; False: 不包含
"""
if (
parsed
and "result" in parsed
and "response_list" in parsed["result"]
):
if parsed and "result" in parsed and "response_list" in parsed["result"]:
response_list = parsed["result"]["response_list"]
for response in response_list:
if (
@@ -189,11 +189,7 @@ class BDunit(Plugin):
:returns: 词槽列表。你可以通过 name 属性筛选词槽,
再通过 normalized_word 属性取出相应的值
"""
if (
parsed
and "result" in parsed
and "response_list" in parsed["result"]
):
if parsed and "result" in parsed and "response_list" in parsed["result"]:
response_list = parsed["result"]["response_list"]
if intent == "":
try:
@@ -236,11 +232,7 @@ class BDunit(Plugin):
:param parsed: UNIT 解析结果
:returns: UNIT 的回复文本
"""
if (
parsed
and "result" in parsed
and "response_list" in parsed["result"]
):
if parsed and "result" in parsed and "response_list" in parsed["result"]:
response_list = parsed["result"]["response_list"]
answer = {}
for response in response_list:
@@ -266,11 +258,7 @@ class BDunit(Plugin):
:param intent: 意图的名称
:returns: UNIT 的回复文本
"""
if (
parsed
and "result" in parsed
and "response_list" in parsed["result"]
):
if parsed and "result" in parsed and "response_list" in parsed["result"]:
response_list = parsed["result"]["response_list"]
if intent == "":
try:


+ 4
- 4
plugins/bdunit/config.json.template

@@ -1,5 +1,5 @@
{
"service_id": "s...",
"api_key": "",
"secret_key": ""
}
"service_id": "s...",
"api_key": "",
"secret_key": ""
}

+ 1
- 1
plugins/dungeon/__init__.py

@@ -1 +1 @@
from .dungeon import *
from .dungeon import *

+ 47
- 28
plugins/dungeon/dungeon.py

@@ -1,17 +1,18 @@
# encoding:utf-8

import plugins
from bridge.bridge import Bridge
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common import const
from common.expired_dict import ExpiredDict
from common.log import logger
from config import conf
import plugins
from plugins import *
from common.log import logger
from common import const


# https://github.com/bupticybee/ChineseAiDungeonChatGPT
class StoryTeller():
class StoryTeller:
def __init__(self, bot, sessionid, story):
self.bot = bot
self.sessionid = sessionid
@@ -27,67 +28,85 @@ class StoryTeller():
if user_action[-1] != "。":
user_action = user_action + "。"
if self.first_interact:
prompt = """现在来充当一个文字冒险游戏,描述时候注意节奏,不要太快,仔细描述各个人物的心情和周边环境。一次只需写四到六句话。
开头是,""" + self.story + " " + user_action
prompt = (
"""现在来充当一个文字冒险游戏,描述时候注意节奏,不要太快,仔细描述各个人物的心情和周边环境。一次只需写四到六句话。
开头是,"""
+ self.story
+ " "
+ user_action
)
self.first_interact = False
else:
prompt = """继续,一次只需要续写四到六句话,总共就只讲5分钟内发生的事情。""" + user_action
return prompt


@plugins.register(name="Dungeon", desire_priority=0, namecn="文字冒险", desc="A plugin to play dungeon game", version="1.0", author="lanvent")
@plugins.register(
name="Dungeon",
desire_priority=0,
namecn="文字冒险",
desc="A plugin to play dungeon game",
version="1.0",
author="lanvent",
)
class Dungeon(Plugin):
def __init__(self):
super().__init__()
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
logger.info("[Dungeon] inited")
# 目前没有设计session过期事件,这里先暂时使用过期字典
if conf().get('expires_in_seconds'):
self.games = ExpiredDict(conf().get('expires_in_seconds'))
if conf().get("expires_in_seconds"):
self.games = ExpiredDict(conf().get("expires_in_seconds"))
else:
self.games = dict()

def on_handle_context(self, e_context: EventContext):

if e_context['context'].type != ContextType.TEXT:
if e_context["context"].type != ContextType.TEXT:
return
bottype = Bridge().get_bot_type("chat")
if bottype not in (const.CHATGPT, const.OPEN_AI):
return
bot = Bridge().get_bot("chat")
content = e_context['context'].content[:]
clist = e_context['context'].content.split(maxsplit=1)
sessionid = e_context['context']['session_id']
content = e_context["context"].content[:]
clist = e_context["context"].content.split(maxsplit=1)
sessionid = e_context["context"]["session_id"]
logger.debug("[Dungeon] on_handle_context. content: %s" % clist)
trigger_prefix = conf().get('plugin_trigger_prefix', "$")
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
if clist[0] == f"{trigger_prefix}停止冒险":
if sessionid in self.games:
self.games[sessionid].reset()
del self.games[sessionid]
reply = Reply(ReplyType.INFO, "冒险结束!")
e_context['reply'] = reply
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
elif clist[0] == f"{trigger_prefix}开始冒险" or sessionid in self.games:
if sessionid not in self.games or clist[0] == f"{trigger_prefix}开始冒险":
if len(clist)>1 :
if len(clist) > 1:
story = clist[1]
else:
story = "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。"
story = (
"你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。"
)
self.games[sessionid] = StoryTeller(bot, sessionid, story)
reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story)
e_context['reply'] = reply
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
else:
prompt = self.games[sessionid].action(content)
e_context['context'].type = ContextType.TEXT
e_context['context'].content = prompt
e_context.action = EventAction.BREAK # 事件结束,不跳过处理context的默认逻辑
e_context["context"].type = ContextType.TEXT
e_context["context"].content = prompt
e_context.action = EventAction.BREAK # 事件结束,不跳过处理context的默认逻辑

def get_help_text(self, **kwargs):
help_text = "可以和机器人一起玩文字冒险游戏。\n"
if kwargs.get('verbose') != True:
if kwargs.get("verbose") != True:
return help_text
trigger_prefix = conf().get('plugin_trigger_prefix', "$")
help_text = f"{trigger_prefix}开始冒险 "+"背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n"+f"{trigger_prefix}停止冒险: 结束游戏。\n"
if kwargs.get('verbose') == True:
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
help_text = (
f"{trigger_prefix}开始冒险 "
+ "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n"
+ f"{trigger_prefix}停止冒险: 结束游戏。\n"
)
if kwargs.get("verbose") == True:
help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'"
return help_text
return help_text

+ 6
- 6
plugins/event.py

@@ -9,17 +9,17 @@ class Event(Enum):
e_context = { "channel": 消息channel, "context" : 本次消息的context}
"""

ON_HANDLE_CONTEXT = 2 # 处理消息前
ON_HANDLE_CONTEXT = 2 # 处理消息前
"""
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空 }
"""

ON_DECORATE_REPLY = 3 # 得到回复后准备装饰
ON_DECORATE_REPLY = 3 # 得到回复后准备装饰
"""
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
"""

ON_SEND_REPLY = 4 # 发送回复前
ON_SEND_REPLY = 4 # 发送回复前
"""
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
"""
@@ -28,9 +28,9 @@ class Event(Enum):


class EventAction(Enum):
CONTINUE = 1 # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑
BREAK = 2 # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑
BREAK_PASS = 3 # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑
CONTINUE = 1 # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑
BREAK = 2 # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑
BREAK_PASS = 3 # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑


class EventContext:


+ 1
- 1
plugins/finish/__init__.py

@@ -1 +1 @@
from .finish import *
from .finish import *

+ 15
- 9
plugins/finish/finish.py

@@ -1,14 +1,21 @@
# encoding:utf-8

import plugins
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from config import conf
import plugins
from plugins import *
from common.log import logger


@plugins.register(name="Finish", desire_priority=-999, hidden=True, desc="A plugin that check unknown command", version="1.0", author="js00000")
@plugins.register(
name="Finish",
desire_priority=-999,
hidden=True,
desc="A plugin that check unknown command",
version="1.0",
author="js00000",
)
class Finish(Plugin):
def __init__(self):
super().__init__()
@@ -16,19 +23,18 @@ class Finish(Plugin):
logger.info("[Finish] inited")

def on_handle_context(self, e_context: EventContext):

if e_context['context'].type != ContextType.TEXT:
if e_context["context"].type != ContextType.TEXT:
return

content = e_context['context'].content
content = e_context["context"].content
logger.debug("[Finish] on_handle_context. content: %s" % content)
trigger_prefix = conf().get('plugin_trigger_prefix',"$")
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
if content.startswith(trigger_prefix):
reply = Reply()
reply.type = ReplyType.ERROR
reply.content = "未知插件命令\n查看插件命令列表请输入#help 插件名\n"
e_context['reply'] = reply
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑

def get_help_text(self, **kwargs):
return ""

+ 1
- 1
plugins/godcmd/__init__.py

@@ -1 +1 @@
from .godcmd import *
from .godcmd import *

+ 3
- 3
plugins/godcmd/config.json.template

@@ -1,4 +1,4 @@
{
"password": "",
"admin_users": []
}
"password": "",
"admin_users": []
}

+ 96
- 70
plugins/godcmd/godcmd.py

@@ -6,14 +6,16 @@ import random
import string
import traceback
from typing import Tuple

import plugins
from bridge.bridge import Bridge
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from config import conf, load_config
import plugins
from plugins import *
from common import const
from common.log import logger
from config import conf, load_config
from plugins import *

# 定义指令集
COMMANDS = {
"help": {
@@ -41,7 +43,7 @@ COMMANDS = {
},
"id": {
"alias": ["id", "用户"],
"desc": "获取用户id", # wechaty和wechatmp的用户id不会变化,可用于绑定管理员
"desc": "获取用户id", # wechaty和wechatmp的用户id不会变化,可用于绑定管理员
},
"reset": {
"alias": ["reset", "重置会话"],
@@ -114,18 +116,20 @@ ADMIN_COMMANDS = {
"desc": "开启机器调试日志",
},
}


# 定义帮助函数
def get_help_text(isadmin, isgroup):
help_text = "通用指令:\n"
for cmd, info in COMMANDS.items():
if cmd=="auth": #不提示认证指令
if cmd == "auth": # 不提示认证指令
continue
if cmd=="id" and conf().get("channel_type","wx") not in ["wxy","wechatmp"]:
if cmd == "id" and conf().get("channel_type", "wx") not in ["wxy", "wechatmp"]:
continue
alias=["#"+a for a in info['alias'][:1]]
alias = ["#" + a for a in info["alias"][:1]]
help_text += f"{','.join(alias)} "
if 'args' in info:
args=[a for a in info['args']]
if "args" in info:
args = [a for a in info["args"]]
help_text += f"{' '.join(args)}"
help_text += f": {info['desc']}\n"

@@ -135,39 +139,48 @@ def get_help_text(isadmin, isgroup):
for plugin in plugins:
if plugins[plugin].enabled and not plugins[plugin].hidden:
namecn = plugins[plugin].namecn
help_text += "\n%s:"%namecn
help_text += PluginManager().instances[plugin].get_help_text(verbose=False).strip()
help_text += "\n%s:" % namecn
help_text += (
PluginManager().instances[plugin].get_help_text(verbose=False).strip()
)

if ADMIN_COMMANDS and isadmin:
help_text += "\n\n管理员指令:\n"
for cmd, info in ADMIN_COMMANDS.items():
alias=["#"+a for a in info['alias'][:1]]
alias = ["#" + a for a in info["alias"][:1]]
help_text += f"{','.join(alias)} "
if 'args' in info:
args=[a for a in info['args']]
if "args" in info:
args = [a for a in info["args"]]
help_text += f"{' '.join(args)}"
help_text += f": {info['desc']}\n"
return help_text

@plugins.register(name="Godcmd", desire_priority=999, hidden=True, desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证", version="1.0", author="lanvent")
class Godcmd(Plugin):

@plugins.register(
name="Godcmd",
desire_priority=999,
hidden=True,
desc="为你的机器人添加指令集,有用户和管理员两种角色,加载顺序请放在首位,初次运行后插件目录会生成配置文件, 填充管理员密码后即可认证",
version="1.0",
author="lanvent",
)
class Godcmd(Plugin):
def __init__(self):
super().__init__()

curdir=os.path.dirname(__file__)
config_path=os.path.join(curdir,"config.json")
gconf=None
curdir = os.path.dirname(__file__)
config_path = os.path.join(curdir, "config.json")
gconf = None
if not os.path.exists(config_path):
gconf={"password":"","admin_users":[]}
with open(config_path,"w") as f:
json.dump(gconf,f,indent=4)
gconf = {"password": "", "admin_users": []}
with open(config_path, "w") as f:
json.dump(gconf, f, indent=4)
else:
with open(config_path,"r") as f:
gconf=json.load(f)
with open(config_path, "r") as f:
gconf = json.load(f)
if gconf["password"] == "":
self.temp_password = "".join(random.sample(string.digits, 4))
logger.info("[Godcmd] 因未设置口令,本次的临时口令为%s。"%self.temp_password)
logger.info("[Godcmd] 因未设置口令,本次的临时口令为%s。" % self.temp_password)
else:
self.temp_password = None
custom_commands = conf().get("clear_memory_commands", [])
@@ -178,41 +191,42 @@ class Godcmd(Plugin):
COMMANDS["reset"]["alias"].append(custom_command)

self.password = gconf["password"]
self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用
self.isrunning = True # 机器人是否运行中
self.admin_users = gconf[
"admin_users"
] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用
self.isrunning = True # 机器人是否运行中

self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
logger.info("[Godcmd] inited")


def on_handle_context(self, e_context: EventContext):
context_type = e_context['context'].type
context_type = e_context["context"].type
if context_type != ContextType.TEXT:
if not self.isrunning:
e_context.action = EventAction.BREAK_PASS
return

content = e_context['context'].content
content = e_context["context"].content
logger.debug("[Godcmd] on_handle_context. content: %s" % content)
if content.startswith("#"):
# msg = e_context['context']['msg']
channel = e_context['channel']
user = e_context['context']['receiver']
session_id = e_context['context']['session_id']
isgroup = e_context['context'].get("isgroup", False)
channel = e_context["channel"]
user = e_context["context"]["receiver"]
session_id = e_context["context"]["session_id"]
isgroup = e_context["context"].get("isgroup", False)
bottype = Bridge().get_bot_type("chat")
bot = Bridge().get_bot("chat")
# 将命令和参数分割
command_parts = content[1:].strip().split()
cmd = command_parts[0]
args = command_parts[1:]
isadmin=False
isadmin = False
if user in self.admin_users:
isadmin=True
ok=False
result="string"
if any(cmd in info['alias'] for info in COMMANDS.values()):
cmd = next(c for c, info in COMMANDS.items() if cmd in info['alias'])
isadmin = True
ok = False
result = "string"
if any(cmd in info["alias"] for info in COMMANDS.values()):
cmd = next(c for c, info in COMMANDS.items() if cmd in info["alias"])
if cmd == "auth":
ok, result = self.authenticate(user, args, isadmin, isgroup)
elif cmd == "help" or cmd == "helpp":
@@ -224,10 +238,14 @@ class Godcmd(Plugin):
query_name = args[0].upper()
# search name and namecn
for name, plugincls in plugins.items():
if not plugincls.enabled :
if not plugincls.enabled:
continue
if query_name == name or query_name == plugincls.namecn:
ok, result = True, PluginManager().instances[name].get_help_text(isgroup=isgroup, isadmin=isadmin, verbose=True)
ok, result = True, PluginManager().instances[
name
].get_help_text(
isgroup=isgroup, isadmin=isadmin, verbose=True
)
break
if not ok:
result = "插件不存在或未启用"
@@ -236,14 +254,14 @@ class Godcmd(Plugin):
elif cmd == "set_openai_api_key":
if len(args) == 1:
user_data = conf().get_user_data(user)
user_data['openai_api_key'] = args[0]
user_data["openai_api_key"] = args[0]
ok, result = True, "你的OpenAI私有api_key已设置为" + args[0]
else:
ok, result = False, "请提供一个api_key"
elif cmd == "reset_openai_api_key":
try:
user_data = conf().get_user_data(user)
user_data.pop('openai_api_key')
user_data.pop("openai_api_key")
ok, result = True, "你的OpenAI私有api_key已清除"
except Exception as e:
ok, result = False, "你没有设置私有api_key"
@@ -255,12 +273,16 @@ class Godcmd(Plugin):
else:
ok, result = False, "当前对话机器人不支持重置会话"
logger.debug("[Godcmd] command: %s by %s" % (cmd, user))
elif any(cmd in info['alias'] for info in ADMIN_COMMANDS.values()):
elif any(cmd in info["alias"] for info in ADMIN_COMMANDS.values()):
if isadmin:
if isgroup:
ok, result = False, "群聊不可执行管理员指令"
else:
cmd = next(c for c, info in ADMIN_COMMANDS.items() if cmd in info['alias'])
cmd = next(
c
for c, info in ADMIN_COMMANDS.items()
if cmd in info["alias"]
)
if cmd == "stop":
self.isrunning = False
ok, result = True, "服务已暂停"
@@ -278,13 +300,13 @@ class Godcmd(Plugin):
else:
ok, result = False, "当前对话机器人不支持重置会话"
elif cmd == "debug":
logger.setLevel('DEBUG')
logger.setLevel("DEBUG")
ok, result = True, "DEBUG模式已开启"
elif cmd == "plist":
plugins = PluginManager().list_plugins()
ok = True
result = "插件列表:\n"
for name,plugincls in plugins.items():
for name, plugincls in plugins.items():
result += f"{plugincls.name}_v{plugincls.version} {plugincls.priority} - "
if plugincls.enabled:
result += "已启用\n"
@@ -294,16 +316,20 @@ class Godcmd(Plugin):
new_plugins = PluginManager().scan_plugins()
ok, result = True, "插件扫描完成"
PluginManager().activate_plugins()
if len(new_plugins) >0 :
if len(new_plugins) > 0:
result += "\n发现新插件:\n"
result += "\n".join([f"{p.name}_v{p.version}" for p in new_plugins])
else :
result +=", 未发现新插件"
result += "\n".join(
[f"{p.name}_v{p.version}" for p in new_plugins]
)
else:
result += ", 未发现新插件"
elif cmd == "setpri":
if len(args) != 2:
ok, result = False, "请提供插件名和优先级"
else:
ok = PluginManager().set_plugin_priority(args[0], int(args[1]))
ok = PluginManager().set_plugin_priority(
args[0], int(args[1])
)
if ok:
result = "插件" + args[0] + "优先级已设置为" + args[1]
else:
@@ -350,42 +376,42 @@ class Godcmd(Plugin):
else:
ok, result = False, "需要管理员权限才能执行该指令"
else:
trigger_prefix = conf().get('plugin_trigger_prefix',"$")
if trigger_prefix == "#": # 跟插件聊天指令前缀相同,继续递交
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
if trigger_prefix == "#": # 跟插件聊天指令前缀相同,继续递交
return
ok, result = False, f"未知指令:{cmd}\n查看指令列表请输入#help \n"
reply = Reply()
if ok:
reply.type = ReplyType.INFO
else:
reply.type = ReplyType.ERROR
reply.content = result
e_context['reply'] = reply
e_context["reply"] = reply

e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
elif not self.isrunning:
e_context.action = EventAction.BREAK_PASS

def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool,str] :
def authenticate(self, userid, args, isadmin, isgroup) -> Tuple[bool, str]:
if isgroup:
return False,"请勿在群聊中认证"
return False, "请勿在群聊中认证"
if isadmin:
return False,"管理员账号无需认证"
return False, "管理员账号无需认证"
if len(args) != 1:
return False,"请提供口令"
return False, "请提供口令"
password = args[0]
if password == self.password:
self.admin_users.append(userid)
return True,"认证成功"
return True, "认证成功"
elif password == self.temp_password:
self.admin_users.append(userid)
return True,"认证成功,请尽快设置口令"
return True, "认证成功,请尽快设置口令"
else:
return False,"认证失败"
return False, "认证失败"

def get_help_text(self, isadmin = False, isgroup = False, **kwargs):
return get_help_text(isadmin, isgroup)
def get_help_text(self, isadmin=False, isgroup=False, **kwargs):
return get_help_text(isadmin, isgroup)
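As a small companion to the `#set_openai_api_key` branch above, a sketch of reading the stored per-user key back elsewhere in the code. The user id is made up, and it assumes `conf().get_user_data()` returns the same mutable dict the command writes into (an empty dict when the user has stored nothing yet).

```python
from config import conf

user_data = conf().get_user_data("wxid_demo")  # "wxid_demo" is a hypothetical user id
api_key = user_data.get("openai_api_key")      # None until the user ran #set_openai_api_key
print("per-user key configured:", api_key is not None)
```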

+ 1
- 1
plugins/hello/__init__.py

@@ -1 +1 @@
from .hello import *
from .hello import *

+ 22
- 14
plugins/hello/hello.py

@@ -1,14 +1,21 @@
# encoding:utf-8

import plugins
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from channel.chat_message import ChatMessage
import plugins
from plugins import *
from common.log import logger
from plugins import *


@plugins.register(name="Hello", desire_priority=-1, hidden=True, desc="A simple plugin that says hello", version="0.1", author="lanvent")
@plugins.register(
name="Hello",
desire_priority=-1,
hidden=True,
desc="A simple plugin that says hello",
version="0.1",
author="lanvent",
)
class Hello(Plugin):
def __init__(self):
super().__init__()
@@ -16,33 +23,34 @@ class Hello(Plugin):
logger.info("[Hello] inited")

def on_handle_context(self, e_context: EventContext):

if e_context['context'].type != ContextType.TEXT:
if e_context["context"].type != ContextType.TEXT:
return
content = e_context['context'].content
content = e_context["context"].content
logger.debug("[Hello] on_handle_context. content: %s" % content)
if content == "Hello":
reply = Reply()
reply.type = ReplyType.TEXT
msg:ChatMessage = e_context['context']['msg']
if e_context['context']['isgroup']:
reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
msg: ChatMessage = e_context["context"]["msg"]
if e_context["context"]["isgroup"]:
reply.content = (
f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
)
else:
reply.content = f"Hello, {msg.from_user_nickname}"
e_context['reply'] = reply
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑

if content == "Hi":
reply = Reply()
reply.type = ReplyType.TEXT
reply.content = "Hi"
e_context['reply'] = reply
e_context["reply"] = reply
e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑,一般会覆写reply

if content == "End":
# 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World"
e_context['context'].type = ContextType.IMAGE_CREATE
e_context["context"].type = ContextType.IMAGE_CREATE
content = "The World"
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑



+ 1
- 1
plugins/plugin.py

@@ -3,4 +3,4 @@ class Plugin:
self.handlers = {}

def get_help_text(self, **kwargs):
return "暂无帮助信息"
return "暂无帮助信息"

+ 104
- 52
plugins/plugin_manager.py

@@ -5,17 +5,19 @@ import importlib.util
import json
import os
import sys

from common.log import logger
from common.singleton import singleton
from common.sorted_dict import SortedDict
from .event import *
from common.log import logger
from config import conf

from .event import *


@singleton
class PluginManager:
def __init__(self):
self.plugins = SortedDict(lambda k,v: v.priority,reverse=True)
self.plugins = SortedDict(lambda k, v: v.priority, reverse=True)
self.listening_plugins = {}
self.instances = {}
self.pconf = {}
@@ -26,17 +28,27 @@ class PluginManager:
def wrapper(plugincls):
plugincls.name = name
plugincls.priority = desire_priority
plugincls.desc = kwargs.get('desc')
plugincls.author = kwargs.get('author')
plugincls.desc = kwargs.get("desc")
plugincls.author = kwargs.get("author")
plugincls.path = self.current_plugin_path
plugincls.version = kwargs.get('version') if kwargs.get('version') != None else "1.0"
plugincls.namecn = kwargs.get('namecn') if kwargs.get('namecn') != None else name
plugincls.hidden = kwargs.get('hidden') if kwargs.get('hidden') != None else False
plugincls.version = (
kwargs.get("version") if kwargs.get("version") != None else "1.0"
)
plugincls.namecn = (
kwargs.get("namecn") if kwargs.get("namecn") != None else name
)
plugincls.hidden = (
kwargs.get("hidden") if kwargs.get("hidden") != None else False
)
plugincls.enabled = True
if self.current_plugin_path == None:
raise Exception("Plugin path not set")
self.plugins[name.upper()] = plugincls
logger.info("Plugin %s_v%s registered, path=%s" % (name, plugincls.version, plugincls.path))
logger.info(
"Plugin %s_v%s registered, path=%s"
% (name, plugincls.version, plugincls.path)
)

return wrapper

def save_config(self):
@@ -50,10 +62,12 @@ class PluginManager:
if os.path.exists("./plugins/plugins.json"):
with open("./plugins/plugins.json", "r", encoding="utf-8") as f:
pconf = json.load(f)
pconf['plugins'] = SortedDict(lambda k,v: v["priority"],pconf['plugins'],reverse=True)
pconf["plugins"] = SortedDict(
lambda k, v: v["priority"], pconf["plugins"], reverse=True
)
else:
modified = True
pconf = {"plugins": SortedDict(lambda k,v: v["priority"],reverse=True)}
pconf = {"plugins": SortedDict(lambda k, v: v["priority"], reverse=True)}
self.pconf = pconf
if modified:
self.save_config()
@@ -67,7 +81,7 @@ class PluginManager:
plugin_path = os.path.join(plugins_dir, plugin_name)
if os.path.isdir(plugin_path):
# 判断插件是否包含同名__init__.py文件
main_module_path = os.path.join(plugin_path,"__init__.py")
main_module_path = os.path.join(plugin_path, "__init__.py")
if os.path.isfile(main_module_path):
# 导入插件
import_path = "plugins.{}".format(plugin_name)
@@ -76,16 +90,26 @@ class PluginManager:
if plugin_path in self.loaded:
if self.loaded[plugin_path] == None:
logger.info("reload module %s" % plugin_name)
self.loaded[plugin_path] = importlib.reload(sys.modules[import_path])
dependent_module_names = [name for name in sys.modules.keys() if name.startswith( import_path+ '.')]
self.loaded[plugin_path] = importlib.reload(
sys.modules[import_path]
)
dependent_module_names = [
name
for name in sys.modules.keys()
if name.startswith(import_path + ".")
]
for name in dependent_module_names:
logger.info("reload module %s" % name)
importlib.reload(sys.modules[name])
else:
self.loaded[plugin_path] = importlib.import_module(import_path)
self.loaded[plugin_path] = importlib.import_module(
import_path
)
self.current_plugin_path = None
except Exception as e:
logger.exception("Failed to import plugin %s: %s" % (plugin_name, e))
logger.exception(
"Failed to import plugin %s: %s" % (plugin_name, e)
)
continue
pconf = self.pconf
news = [self.plugins[name] for name in self.plugins]
@@ -95,21 +119,28 @@ class PluginManager:
rawname = plugincls.name
if rawname not in pconf["plugins"]:
modified = True
logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name)
pconf["plugins"][rawname] = {"enabled": plugincls.enabled, "priority": plugincls.priority}
logger.info(
"Plugin %s not found in pconfig, adding to pconfig..." % name
)
pconf["plugins"][rawname] = {
"enabled": plugincls.enabled,
"priority": plugincls.priority,
}
else:
self.plugins[name].enabled = pconf["plugins"][rawname]["enabled"]
self.plugins[name].priority = pconf["plugins"][rawname]["priority"]
self.plugins._update_heap(name) # 更新下plugins中的顺序
self.plugins._update_heap(name) # 更新下plugins中的顺序
if modified:
self.save_config()
return new_plugins

def refresh_order(self):
for event in self.listening_plugins.keys():
self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True)
self.listening_plugins[event].sort(
key=lambda name: self.plugins[name].priority, reverse=True
)

def activate_plugins(self): # 生成新开启的插件实例
def activate_plugins(self): # 生成新开启的插件实例
failed_plugins = []
for name, plugincls in self.plugins.items():
if plugincls.enabled:
@@ -129,7 +160,7 @@ class PluginManager:
self.refresh_order()
return failed_plugins

def reload_plugin(self, name:str):
def reload_plugin(self, name: str):
name = name.upper()
if name in self.instances:
for event in self.listening_plugins:
@@ -139,13 +170,13 @@ class PluginManager:
self.activate_plugins()
return True
return False
def load_plugins(self):
self.load_config()
self.scan_plugins()
pconf = self.pconf
logger.debug("plugins.json config={}".format(pconf))
for name,plugin in pconf["plugins"].items():
for name, plugin in pconf["plugins"].items():
if name.upper() not in self.plugins:
logger.error("Plugin %s not found, but found in plugins.json" % name)
self.activate_plugins()
@@ -153,13 +184,18 @@ class PluginManager:
def emit_event(self, e_context: EventContext, *args, **kwargs):
if e_context.event in self.listening_plugins:
for name in self.listening_plugins[e_context.event]:
if self.plugins[name].enabled and e_context.action == EventAction.CONTINUE:
logger.debug("Plugin %s triggered by event %s" % (name,e_context.event))
if (
self.plugins[name].enabled
and e_context.action == EventAction.CONTINUE
):
logger.debug(
"Plugin %s triggered by event %s" % (name, e_context.event)
)
instance = self.instances[name]
instance.handlers[e_context.event](e_context, *args, **kwargs)
return e_context

def set_plugin_priority(self, name:str, priority:int):
def set_plugin_priority(self, name: str, priority: int):
name = name.upper()
if name not in self.plugins:
return False
@@ -174,11 +210,11 @@ class PluginManager:
self.refresh_order()
return True

def enable_plugin(self, name:str):
def enable_plugin(self, name: str):
name = name.upper()
if name not in self.plugins:
return False, "插件不存在"
if not self.plugins[name].enabled :
if not self.plugins[name].enabled:
self.plugins[name].enabled = True
rawname = self.plugins[name].name
self.pconf["plugins"][rawname]["enabled"] = True
@@ -188,43 +224,47 @@ class PluginManager:
return False, "插件开启失败"
return True, "插件已开启"
return True, "插件已开启"
def disable_plugin(self, name:str):
def disable_plugin(self, name: str):
name = name.upper()
if name not in self.plugins:
return False
if self.plugins[name].enabled :
if self.plugins[name].enabled:
self.plugins[name].enabled = False
rawname = self.plugins[name].name
self.pconf["plugins"][rawname]["enabled"] = False
self.save_config()
return True
return True
def list_plugins(self):
return self.plugins
def install_plugin(self, repo:str):
def install_plugin(self, repo: str):
try:
import common.package_manager as pkgmgr

pkgmgr.check_dulwich()
except Exception as e:
logger.error("Failed to install plugin, {}".format(e))
return False, "无法导入dulwich,安装插件失败"
import re

from dulwich import porcelain

logger.info("clone git repo: {}".format(repo))
match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo)
if not match:
try:
with open("./plugins/source.json","r", encoding="utf-8") as f:
with open("./plugins/source.json", "r", encoding="utf-8") as f:
source = json.load(f)
if repo in source["repo"]:
repo = source["repo"][repo]["url"]
match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo)
match = re.match(
r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo
)
if not match:
return False, "安装插件失败,source中的仓库地址不合法"
else:
@@ -232,42 +272,53 @@ class PluginManager:
except Exception as e:
logger.error("Failed to install plugin, {}".format(e))
return False, "安装插件失败,请检查仓库地址是否正确"
dirname = os.path.join("./plugins",match.group(4))
dirname = os.path.join("./plugins", match.group(4))
try:
repo = porcelain.clone(repo, dirname, checkout=True)
if os.path.exists(os.path.join(dirname,"requirements.txt")):
if os.path.exists(os.path.join(dirname, "requirements.txt")):
logger.info("detect requirements.txt,installing...")
pkgmgr.install_requirements(os.path.join(dirname,"requirements.txt"))
pkgmgr.install_requirements(os.path.join(dirname, "requirements.txt"))
return True, "安装插件成功,请使用 #scanp 命令扫描插件或重启程序,开启前请检查插件是否需要配置"
except Exception as e:
logger.error("Failed to install plugin, {}".format(e))
return False, "安装插件失败,"+str(e)
def update_plugin(self, name:str):
return False, "安装插件失败," + str(e)
def update_plugin(self, name: str):
try:
import common.package_manager as pkgmgr

pkgmgr.check_dulwich()
except Exception as e:
logger.error("Failed to install plugin, {}".format(e))
return False, "无法导入dulwich,更新插件失败"
from dulwich import porcelain

name = name.upper()
if name not in self.plugins:
return False, "插件不存在"
if name in ["HELLO","GODCMD","ROLE","TOOL","BDUNIT","BANWORDS","FINISH","DUNGEON"]:
if name in [
"HELLO",
"GODCMD",
"ROLE",
"TOOL",
"BDUNIT",
"BANWORDS",
"FINISH",
"DUNGEON",
]:
return False, "预置插件无法更新,请更新主程序仓库"
dirname = self.plugins[name].path
try:
porcelain.pull(dirname, "origin")
if os.path.exists(os.path.join(dirname,"requirements.txt")):
if os.path.exists(os.path.join(dirname, "requirements.txt")):
logger.info("detect requirements.txt,installing...")
pkgmgr.install_requirements(os.path.join(dirname,"requirements.txt"))
pkgmgr.install_requirements(os.path.join(dirname, "requirements.txt"))
return True, "更新插件成功,请重新运行程序"
except Exception as e:
logger.error("Failed to update plugin, {}".format(e))
return False, "更新插件失败,"+str(e)
def uninstall_plugin(self, name:str):
return False, "更新插件失败," + str(e)
def uninstall_plugin(self, name: str):
name = name.upper()
if name not in self.plugins:
return False, "插件不存在"
@@ -276,6 +327,7 @@ class PluginManager:
dirname = self.plugins[name].path
try:
import shutil

shutil.rmtree(dirname)
rawname = self.plugins[name].name
for event in self.listening_plugins:
@@ -288,4 +340,4 @@ class PluginManager:
return True, "卸载插件成功"
except Exception as e:
logger.error("Failed to uninstall plugin, {}".format(e))
return False, "卸载插件失败,请手动删除文件夹完成卸载,"+str(e)
return False, "卸载插件失败,请手动删除文件夹完成卸载," + str(e)

+ 1
- 1
plugins/role/__init__.py

@@ -1 +1 @@
from .role import *
from .role import *

+ 66
- 35
plugins/role/role.py

@@ -2,17 +2,18 @@

import json
import os

import plugins
from bridge.bridge import Bridge
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common import const
from common.log import logger
from config import conf
import plugins
from plugins import *
from common.log import logger


class RolePlay():
class RolePlay:
def __init__(self, bot, sessionid, desc, wrapper=None):
self.bot = bot
self.sessionid = sessionid
@@ -25,12 +26,20 @@ class RolePlay():

def action(self, user_action):
session = self.bot.sessions.build_session(self.sessionid)
if session.system_prompt != self.desc: # 目前没有触发session过期事件,这里先简单判断,然后重置
if session.system_prompt != self.desc: # 目前没有触发session过期事件,这里先简单判断,然后重置
session.set_system_prompt(self.desc)
prompt = self.wrapper % user_action
return prompt

@plugins.register(name="Role", desire_priority=0, namecn="角色扮演", desc="为你的Bot设置预设角色", version="1.0", author="lanvent")

@plugins.register(
name="Role",
desire_priority=0,
namecn="角色扮演",
desc="为你的Bot设置预设角色",
version="1.0",
author="lanvent",
)
class Role(Plugin):
def __init__(self):
super().__init__()
@@ -39,7 +48,7 @@ class Role(Plugin):
try:
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
self.tags = { tag:(desc,[]) for tag,desc in config["tags"].items()}
self.tags = {tag: (desc, []) for tag, desc in config["tags"].items()}
self.roles = {}
for role in config["roles"]:
self.roles[role["title"].lower()] = role
@@ -60,12 +69,16 @@ class Role(Plugin):
logger.info("[Role] inited")
except Exception as e:
if isinstance(e, FileNotFoundError):
logger.warn(f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .")
logger.warn(
f"[Role] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ."
)
else:
logger.warn("[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role .")
logger.warn(
"[Role] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/role ."
)
raise e

def get_role(self, name, find_closest=True, min_sim = 0.35):
def get_role(self, name, find_closest=True, min_sim=0.35):
name = name.lower()
found_role = None
if name in self.roles:
@@ -75,6 +88,7 @@ class Role(Plugin):

def str_simularity(a, b):
return difflib.SequenceMatcher(None, a, b).ratio()

max_sim = min_sim
max_role = None
for role in self.roles:
@@ -86,25 +100,24 @@ class Role(Plugin):
return found_role

def on_handle_context(self, e_context: EventContext):

if e_context['context'].type != ContextType.TEXT:
if e_context["context"].type != ContextType.TEXT:
return
bottype = Bridge().get_bot_type("chat")
if bottype not in (const.CHATGPT, const.OPEN_AI):
return
bot = Bridge().get_bot("chat")
content = e_context['context'].content[:]
clist = e_context['context'].content.split(maxsplit=1)
content = e_context["context"].content[:]
clist = e_context["context"].content.split(maxsplit=1)
desckey = None
customize = False
sessionid = e_context['context']['session_id']
trigger_prefix = conf().get('plugin_trigger_prefix', "$")
sessionid = e_context["context"]["session_id"]
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
if clist[0] == f"{trigger_prefix}停止扮演":
if sessionid in self.roleplays:
self.roleplays[sessionid].reset()
del self.roleplays[sessionid]
reply = Reply(ReplyType.INFO, "角色扮演结束!")
e_context['reply'] = reply
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
return
elif clist[0] == f"{trigger_prefix}角色":
@@ -114,10 +127,10 @@ class Role(Plugin):
elif clist[0] == f"{trigger_prefix}设定扮演":
customize = True
elif clist[0] == f"{trigger_prefix}角色类型":
if len(clist) >1:
if len(clist) > 1:
tag = clist[1].strip()
help_text = "角色列表:\n"
for key,value in self.tags.items():
for key, value in self.tags.items():
if value[0] == tag:
tag = key
break
@@ -130,57 +143,75 @@ class Role(Plugin):
else:
help_text = f"未知角色类型。\n"
help_text += "目前的角色类型有: \n"
help_text += ",".join([self.tags[tag][0] for tag in self.tags])+"\n"
help_text += (
",".join([self.tags[tag][0] for tag in self.tags]) + "\n"
)
else:
help_text = f"请输入角色类型。\n"
help_text += "目前的角色类型有: \n"
help_text += ",".join([self.tags[tag][0] for tag in self.tags])+"\n"
help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "\n"
reply = Reply(ReplyType.INFO, help_text)
e_context['reply'] = reply
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
return
elif sessionid not in self.roleplays:
return
logger.debug("[Role] on_handle_context. content: %s" % content)
if desckey is not None:
if len(clist) == 1 or (len(clist) > 1 and clist[1].lower() in ["help", "帮助"]):
if len(clist) == 1 or (
len(clist) > 1 and clist[1].lower() in ["help", "帮助"]
):
reply = Reply(ReplyType.INFO, self.get_help_text(verbose=True))
e_context['reply'] = reply
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
return
role = self.get_role(clist[1])
if role is None:
reply = Reply(ReplyType.ERROR, "角色不存在")
e_context['reply'] = reply
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
return
else:
self.roleplays[sessionid] = RolePlay(bot, sessionid, self.roles[role][desckey], self.roles[role].get("wrapper","%s"))
reply = Reply(ReplyType.INFO, f"预设角色为 {role}:\n"+self.roles[role][desckey])
e_context['reply'] = reply
self.roleplays[sessionid] = RolePlay(
bot,
sessionid,
self.roles[role][desckey],
self.roles[role].get("wrapper", "%s"),
)
reply = Reply(
ReplyType.INFO, f"预设角色为 {role}:\n" + self.roles[role][desckey]
)
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
elif customize == True:
self.roleplays[sessionid] = RolePlay(bot, sessionid, clist[1], "%s")
reply = Reply(ReplyType.INFO, f"角色设定为:\n{clist[1]}")
e_context['reply'] = reply
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
else:
prompt = self.roleplays[sessionid].action(content)
e_context['context'].type = ContextType.TEXT
e_context['context'].content = prompt
e_context["context"].type = ContextType.TEXT
e_context["context"].content = prompt
e_context.action = EventAction.BREAK

def get_help_text(self, verbose=False, **kwargs):
help_text = "让机器人扮演不同的角色。\n"
if not verbose:
return help_text
trigger_prefix = conf().get('plugin_trigger_prefix', "$")
help_text = f"使用方法:\n{trigger_prefix}角色"+" 预设角色名: 设定角色为{预设角色名}。\n"+f"{trigger_prefix}role"+" 预设角色名: 同上,但使用英文设定。\n"
help_text += f"{trigger_prefix}设定扮演"+" 角色设定: 设定自定义角色人设为{角色设定}。\n"
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
help_text = (
f"使用方法:\n{trigger_prefix}角色"
+ " 预设角色名: 设定角色为{预设角色名}。\n"
+ f"{trigger_prefix}role"
+ " 预设角色名: 同上,但使用英文设定。\n"
)
help_text += f"{trigger_prefix}设定扮演" + " 角色设定: 设定自定义角色人设为{角色设定}。\n"
help_text += f"{trigger_prefix}停止扮演: 清除设定的角色。\n"
help_text += f"{trigger_prefix}角色类型"+" 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n"
help_text += (
f"{trigger_prefix}角色类型" + " 角色类型: 查看某类{角色类型}的所有预设角色,为所有时输出所有预设角色。\n"
)
help_text += "\n目前的角色类型有: \n"
help_text += ",".join([self.tags[tag][0] for tag in self.tags])+"。\n"
help_text += ",".join([self.tags[tag][0] for tag in self.tags]) + "。\n"
help_text += f"\n命令例子: \n{trigger_prefix}角色 写作助理\n"
help_text += f"{trigger_prefix}角色类型 所有\n"
help_text += f"{trigger_prefix}停止扮演\n"


+ 1
- 1
plugins/role/roles.json

@@ -428,4 +428,4 @@
]
}
]
}
}

+ 14
- 14
plugins/source.json

@@ -1,16 +1,16 @@
{
"repo": {
"sdwebui": {
"url": "https://github.com/lanvent/plugin_sdwebui.git",
"desc": "利用stable-diffusion画图的插件"
},
"replicate": {
"url": "https://github.com/lanvent/plugin_replicate.git",
"desc": "利用replicate api画图的插件"
},
"summary": {
"url": "https://github.com/lanvent/plugin_summary.git",
"desc": "总结聊天记录的插件"
}
"repo": {
"sdwebui": {
"url": "https://github.com/lanvent/plugin_sdwebui.git",
"desc": "利用stable-diffusion画图的插件"
},
"replicate": {
"url": "https://github.com/lanvent/plugin_replicate.git",
"desc": "利用replicate api画图的插件"
},
"summary": {
"url": "https://github.com/lanvent/plugin_summary.git",
"desc": "总结聊天记录的插件"
}
}
}
}

+ 16
- 16
plugins/tool/README.md

@@ -1,14 +1,14 @@
## 插件描述
一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力
一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力
使用该插件需在机器人回复你的前提下,在对话内容前加$tool;仅输入$tool将返回tool插件帮助信息,用于测试插件是否加载成功
### 本插件所有工具同步存放至专用仓库:[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)
## 使用说明
使用该插件后将默认使用4个工具, 无需额外配置长期生效:
### 1. python
使用该插件后将默认使用4个工具, 无需额外配置长期生效:
### 1. python
###### python解释器,使用它来解释执行python指令,可以配合你想要chatgpt生成的代码输出结果或执行事务
### 2. url-get
###### 往往用来获取某个网站具体内容,结果可能会被反爬策略影响

@@ -23,16 +23,16 @@

> meteo调优记录:https://github.com/zhayujie/chatgpt-on-wechat/issues/776#issuecomment-1500771334

## 使用本插件对话(prompt)技巧
### 1. 有指引的询问
## 使用本插件对话(prompt)技巧
### 1. 有指引的询问
#### 例如:
- 总结这个链接的内容 https://github.com/goldfishh/chatgpt-tool-hub
- 总结这个链接的内容 https://github.com/goldfishh/chatgpt-tool-hub
- 使用Terminal执行curl cip.cc
- 使用python查询今天日期
### 2. 使用搜索引擎工具
- 如果有搜索工具就能让chatgpt获取到你的未传达清楚的上下文信息,比如chatgpt不知道你的地理位置,现在时间等,所以无法查询到天气
## 其他工具

### 5. wikipedia
@@ -55,9 +55,9 @@
### 10. google-search *
###### google搜索引擎,申请流程较bing-search繁琐

###### 注1:带*工具需要获取api-key才能使用,部分工具需要外网支持
###### 注1:带*工具需要获取api-key才能使用,部分工具需要外网支持
#### [申请方法](https://github.com/goldfishh/chatgpt-tool-hub/blob/master/docs/apply_optional_tool.md)
## config.json 配置说明
###### 默认工具无需配置,其它工具需手动配置,一个例子:
```json
@@ -71,15 +71,15 @@
}

```
注:config.json文件非必须,未创建仍可使用本tool;带*工具需在kwargs填入对应api-key键值对
注:config.json文件非必须,未创建仍可使用本tool;带*工具需在kwargs填入对应api-key键值对
- `tools`:本插件初始化时加载的工具, 目前可选集:["wikipedia", "wolfram-alpha", "bing-search", "google-search", "news", "morning-news"] & 默认工具,除wikipedia工具之外均需要申请api-key
- `kwargs`:工具执行时的配置,一般在这里存放**api-key**,或环境配置
- `request_timeout`: 访问openai接口的超时时间,默认与wechat-on-chatgpt配置一致,可单独配置
- `no_default`: 用于配置默认加载4个工具的行为,如果为true则仅使用tools列表工具,不加载默认工具
- `top_k_results`: 控制所有有关搜索的工具返回条目数,数字越高则参考信息越多,但无用信息可能干扰判断,该值一般为2
- `model_name`: 用于控制tool插件底层使用的llm模型,目前暂未测试3.5以外的模型,一般保持默认
## 备注
- 强烈建议申请搜索工具搭配使用,推荐bing-search
- 虽然我会有意加入一些限制,但请不要使用本插件做危害他人的事情,请提前了解清楚某些内容是否会违反相关规定,建议提前做好过滤


+1 -1 plugins/tool/__init__.py

@@ -1 +1 @@
from .tool import *
from .tool import *

+10 -5 plugins/tool/config.json.template

@@ -1,8 +1,13 @@
{
"tools": ["python", "url-get", "terminal", "meteo-weather"],
"tools": [
"python",
"url-get",
"terminal",
"meteo-weather"
],
"kwargs": {
"top_k_results": 2,
"no_default": false,
"model_name": "gpt-3.5-turbo"
"top_k_results": 2,
"no_default": false,
"model_name": "gpt-3.5-turbo"
}
}
}

+40 -21 plugins/tool/tool.py

@@ -4,6 +4,7 @@ import os
from chatgpt_tool_hub.apps import load_app
from chatgpt_tool_hub.apps.app import App
from chatgpt_tool_hub.tools.all_tool_list import get_all_tool_names

import plugins
from bridge.bridge import Bridge
from bridge.context import ContextType
@@ -14,7 +15,13 @@ from config import conf
from plugins import *


@plugins.register(name="tool", desc="Arming your ChatGPT bot with various tools", version="0.3", author="goldfishh", desire_priority=0)
@plugins.register(
name="tool",
desc="Arming your ChatGPT bot with various tools",
version="0.3",
author="goldfishh",
desire_priority=0,
)
class Tool(Plugin):
def __init__(self):
super().__init__()
@@ -28,22 +35,26 @@ class Tool(Plugin):
help_text = "这是一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力。"
if not verbose:
return help_text
trigger_prefix = conf().get('plugin_trigger_prefix', "$")
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
help_text += "使用说明:\n"
help_text += f"{trigger_prefix}tool "+"命令: 根据给出的{命令}使用一些可用工具尽力为你得到结果。\n"
help_text += f"{trigger_prefix}tool " + "命令: 根据给出的{命令}使用一些可用工具尽力为你得到结果。\n"
help_text += f"{trigger_prefix}tool reset: 重置工具。\n"
return help_text

def on_handle_context(self, e_context: EventContext):
if e_context['context'].type != ContextType.TEXT:
if e_context["context"].type != ContextType.TEXT:
return

# 暂时不支持未来扩展的bot
if Bridge().get_bot_type("chat") not in (const.CHATGPT, const.OPEN_AI, const.CHATGPTONAZURE):
if Bridge().get_bot_type("chat") not in (
const.CHATGPT,
const.OPEN_AI,
const.CHATGPTONAZURE,
):
return

content = e_context['context'].content
content_list = e_context['context'].content.split(maxsplit=1)
content = e_context["context"].content
content_list = e_context["context"].content.split(maxsplit=1)

if not content or len(content_list) < 1:
e_context.action = EventAction.CONTINUE
@@ -52,13 +63,13 @@ class Tool(Plugin):
logger.debug("[tool] on_handle_context. content: %s" % content)
reply = Reply()
reply.type = ReplyType.TEXT
trigger_prefix = conf().get('plugin_trigger_prefix', "$")
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
# todo: 有些工具必须要api-key,需要修改config文件,所以这里没有实现query增删tool的功能
if content.startswith(f"{trigger_prefix}tool"):
if len(content_list) == 1:
logger.debug("[tool]: get help")
reply.content = self.get_help_text()
e_context['reply'] = reply
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
return
elif len(content_list) > 1:
@@ -66,12 +77,14 @@ class Tool(Plugin):
logger.debug("[tool]: reset config")
self.app = self._reset_app()
reply.content = "重置工具成功"
e_context['reply'] = reply
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
return
elif content_list[1].startswith("reset"):
logger.debug("[tool]: remind")
e_context['context'].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符"
e_context[
"context"
].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符"

e_context.action = EventAction.BREAK
return
@@ -80,34 +93,35 @@ class Tool(Plugin):

# Don't modify bot name
all_sessions = Bridge().get_bot("chat").sessions
user_session = all_sessions.session_query(query, e_context['context']['session_id']).messages
user_session = all_sessions.session_query(
query, e_context["context"]["session_id"]
).messages

# chatgpt-tool-hub will reply you with many tools
logger.debug("[tool]: just-go")
try:
_reply = self.app.ask(query, user_session)
e_context.action = EventAction.BREAK_PASS
all_sessions.session_reply(_reply, e_context['context']['session_id'])
all_sessions.session_reply(
_reply, e_context["context"]["session_id"]
)
except Exception as e:
logger.exception(e)
logger.error(str(e))

e_context['context'].content = "请你随机用一种聊天风格,提醒用户:这个问题tool插件暂时无法处理"
e_context["context"].content = "请你随机用一种聊天风格,提醒用户:这个问题tool插件暂时无法处理"
reply.type = ReplyType.ERROR
e_context.action = EventAction.BREAK
return

reply.content = _reply
e_context['reply'] = reply
e_context["reply"] = reply
return

def _read_json(self) -> dict:
curdir = os.path.dirname(__file__)
config_path = os.path.join(curdir, "config.json")
tool_config = {
"tools": [],
"kwargs": {}
}
tool_config = {"tools": [], "kwargs": {}}
if not os.path.exists(config_path):
return tool_config
else:
@@ -123,7 +137,9 @@ class Tool(Plugin):
"proxy": conf().get("proxy", ""),
"request_timeout": conf().get("request_timeout", 60),
# note: 目前tool暂未对其他模型测试,但这里仍对配置来源做了优先级区分,一般插件配置可覆盖全局配置
"model_name": tool_model_name if tool_model_name else conf().get("model", "gpt-3.5-turbo"),
"model_name": tool_model_name
if tool_model_name
else conf().get("model", "gpt-3.5-turbo"),
"no_default": kwargs.get("no_default", False),
"top_k_results": kwargs.get("top_k_results", 2),
# for news tool
@@ -160,4 +176,7 @@ class Tool(Plugin):
# filter not support tool
tool_list = self._filter_tool_list(tool_config.get("tools", []))

return load_app(tools_list=tool_list, **self._build_tool_kwargs(tool_config.get("kwargs", {})))
return load_app(
tools_list=tool_list,
**self._build_tool_kwargs(tool_config.get("kwargs", {})),
)
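For orientation, a minimal sketch of the config precedence this hunk implements: keys from the plugin's config.json (documented in plugins/tool/README.md above) take priority, with model_name falling back to the global config. This is an illustrative stand-alone version, not the plugin's own code; GLOBAL_CONF stands in for conf(), and all values are examples.

```python
import json
import os

# Illustrative stand-in for the global config accessor (conf() in the repo).
GLOBAL_CONF = {"model": "gpt-3.5-turbo", "proxy": "", "request_timeout": 60}


def read_tool_config(config_dir):
    """Return {"tools": [...], "kwargs": {...}}; empty when config.json is absent."""
    path = os.path.join(config_dir, "config.json")
    if not os.path.exists(path):
        return {"tools": [], "kwargs": {}}
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def build_tool_kwargs(kwargs):
    """Plugin-level model_name wins; other values fall back to the global config."""
    tool_model_name = kwargs.get("model_name")
    return {
        "proxy": GLOBAL_CONF.get("proxy", ""),
        "request_timeout": GLOBAL_CONF.get("request_timeout", 60),
        "model_name": tool_model_name if tool_model_name else GLOBAL_CONF.get("model", "gpt-3.5-turbo"),
        "no_default": kwargs.get("no_default", False),
        "top_k_results": kwargs.get("top_k_results", 2),
    }
```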

+1 -0 requirements.txt

@@ -4,3 +4,4 @@ PyQRCode>=1.2.1
qrcode>=7.4.2
requests>=2.28.2
chardet>=5.1.0
pre-commit

+1 -1 scripts/start.sh

@@ -8,7 +8,7 @@ echo $BASE_DIR
# check the nohup.out log output file
if [ ! -f "${BASE_DIR}/nohup.out" ]; then
touch "${BASE_DIR}/nohup.out"
echo "create file ${BASE_DIR}/nohup.out"
echo "create file ${BASE_DIR}/nohup.out"
fi

nohup python3 "${BASE_DIR}/app.py" & tail -f "${BASE_DIR}/nohup.out"


+1 -1 scripts/tout.sh

@@ -7,7 +7,7 @@ echo $BASE_DIR

# check the nohup.out log output file
if [ ! -f "${BASE_DIR}/nohup.out" ]; then
echo "No file ${BASE_DIR}/nohup.out"
echo "No file ${BASE_DIR}/nohup.out"
exit -1;
fi



+25 -8 voice/audio_convert.py

@@ -1,9 +1,12 @@
import shutil
import wave

import pysilk
from pydub import AudioSegment

sil_supports=[8000, 12000, 16000, 24000, 32000, 44100, 48000] # slk转wav时,支持的采样率
sil_supports = [8000, 12000, 16000, 24000, 32000, 44100, 48000] # slk转wav时,支持的采样率


def find_closest_sil_supports(sample_rate):
"""
找到最接近的支持的采样率
@@ -19,6 +22,7 @@ def find_closest_sil_supports(sample_rate):
mindiff = diff
return closest


def get_pcm_from_wav(wav_path):
"""
从 wav 文件中读取 pcm
@@ -29,31 +33,42 @@ def get_pcm_from_wav(wav_path):
wav = wave.open(wav_path, "rb")
return wav.readframes(wav.getnframes())


def any_to_wav(any_path, wav_path):
"""
把任意格式转成wav文件
"""
if any_path.endswith('.wav'):
if any_path.endswith(".wav"):
shutil.copy2(any_path, wav_path)
return
if any_path.endswith('.sil') or any_path.endswith('.silk') or any_path.endswith('.slk'):
if (
any_path.endswith(".sil")
or any_path.endswith(".silk")
or any_path.endswith(".slk")
):
return sil_to_wav(any_path, wav_path)
audio = AudioSegment.from_file(any_path)
audio.export(wav_path, format="wav")


def any_to_sil(any_path, sil_path):
"""
把任意格式转成sil文件
"""
if any_path.endswith('.sil') or any_path.endswith('.silk') or any_path.endswith('.slk'):
if (
any_path.endswith(".sil")
or any_path.endswith(".silk")
or any_path.endswith(".slk")
):
shutil.copy2(any_path, sil_path)
return 10000
if any_path.endswith('.wav'):
if any_path.endswith(".wav"):
return pcm_to_sil(any_path, sil_path)
if any_path.endswith('.mp3'):
if any_path.endswith(".mp3"):
return mp3_to_sil(any_path, sil_path)
raise NotImplementedError("Not support file type: {}".format(any_path))


def mp3_to_wav(mp3_path, wav_path):
"""
把mp3格式转成pcm文件
@@ -61,6 +76,7 @@ def mp3_to_wav(mp3_path, wav_path):
audio = AudioSegment.from_mp3(mp3_path)
audio.export(wav_path, format="wav")


def pcm_to_sil(pcm_path, silk_path):
"""
wav 文件转成 silk
@@ -72,12 +88,12 @@ def pcm_to_sil(pcm_path, silk_path):
pcm_s16 = audio.set_sample_width(2)
pcm_s16 = pcm_s16.set_frame_rate(rate)
wav_data = pcm_s16.raw_data
silk_data = pysilk.encode(
wav_data, data_rate=rate, sample_rate=rate)
silk_data = pysilk.encode(wav_data, data_rate=rate, sample_rate=rate)
with open(silk_path, "wb") as f:
f.write(silk_data)
return audio.duration_seconds * 1000


def mp3_to_sil(mp3_path, silk_path):
"""
mp3 文件转成 silk
@@ -95,6 +111,7 @@ def mp3_to_sil(mp3_path, silk_path):
f.write(silk_data)
return audio.duration_seconds * 1000


def sil_to_wav(silk_path, wav_path, rate: int = 24000):
"""
silk 文件转 wav
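A quick usage sketch of the converters in this file; the file paths below are illustrative placeholders.

```python
# Illustrative usage of voice/audio_convert.py; paths are placeholders.
from voice.audio_convert import any_to_sil, any_to_wav, find_closest_sil_supports

any_to_wav("record.silk", "record.wav")             # silk/sil/slk input is routed through sil_to_wav
duration_ms = any_to_sil("reply.mp3", "reply.sil")  # mp3 -> silk, returns the duration in milliseconds
print(find_closest_sil_supports(22050))             # 24000: nearest sample rate the silk encoder supports
```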


+37 -17 voice/azure/azure_voice.py

@@ -1,16 +1,18 @@

"""
azure voice service
"""
import json
import os
import time

import azure.cognitiveservices.speech as speechsdk

from bridge.reply import Reply, ReplyType
from common.log import logger
from common.tmp_dir import TmpDir
from voice.voice import Voice
from config import conf
from voice.voice import Voice

"""
Azure voice
主目录设置文件中需填写azure_voice_api_key和azure_voice_region
@@ -19,50 +21,68 @@ Azure voice

"""

class AzureVoice(Voice):

class AzureVoice(Voice):
def __init__(self):
try:
curdir = os.path.dirname(__file__)
config_path = os.path.join(curdir, "config.json")
config = None
if not os.path.exists(config_path): #如果没有配置文件,创建本地配置文件
config = { "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural", "speech_recognition_language": "zh-CN"}
if not os.path.exists(config_path): # 如果没有配置文件,创建本地配置文件
config = {
"speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural",
"speech_recognition_language": "zh-CN",
}
with open(config_path, "w") as fw:
json.dump(config, fw, indent=4)
else:
with open(config_path, "r") as fr:
config = json.load(fr)
self.api_key = conf().get('azure_voice_api_key')
self.api_region = conf().get('azure_voice_region')
self.speech_config = speechsdk.SpeechConfig(subscription=self.api_key, region=self.api_region)
self.speech_config.speech_synthesis_voice_name = config["speech_synthesis_voice_name"]
self.speech_config.speech_recognition_language = config["speech_recognition_language"]
self.api_key = conf().get("azure_voice_api_key")
self.api_region = conf().get("azure_voice_region")
self.speech_config = speechsdk.SpeechConfig(
subscription=self.api_key, region=self.api_region
)
self.speech_config.speech_synthesis_voice_name = config[
"speech_synthesis_voice_name"
]
self.speech_config.speech_recognition_language = config[
"speech_recognition_language"
]
except Exception as e:
logger.warn("AzureVoice init failed: %s, ignore " % e)

def voiceToText(self, voice_file):
audio_config = speechsdk.AudioConfig(filename=voice_file)
speech_recognizer = speechsdk.SpeechRecognizer(speech_config=self.speech_config, audio_config=audio_config)
speech_recognizer = speechsdk.SpeechRecognizer(
speech_config=self.speech_config, audio_config=audio_config
)
result = speech_recognizer.recognize_once()
if result.reason == speechsdk.ResultReason.RecognizedSpeech:
logger.info('[Azure] voiceToText voice file name={} text={}'.format(voice_file, result.text))
logger.info(
"[Azure] voiceToText voice file name={} text={}".format(
voice_file, result.text
)
)
reply = Reply(ReplyType.TEXT, result.text)
else:
logger.error('[Azure] voiceToText error, result={}'.format(result))
logger.error("[Azure] voiceToText error, result={}".format(result))
reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败")
return reply

def textToVoice(self, text):
fileName = TmpDir().path() + 'reply-' + str(int(time.time())) + '.wav'
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav"
audio_config = speechsdk.AudioConfig(filename=fileName)
speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=self.speech_config, audio_config=audio_config)
speech_synthesizer = speechsdk.SpeechSynthesizer(
speech_config=self.speech_config, audio_config=audio_config
)
result = speech_synthesizer.speak_text(text)
if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
logger.info(
'[Azure] textToVoice text={} voice file name={}'.format(text, fileName))
"[Azure] textToVoice text={} voice file name={}".format(text, fileName)
)
reply = Reply(ReplyType.VOICE, fileName)
else:
logger.error('[Azure] textToVoice error, result={}'.format(result))
logger.error("[Azure] textToVoice error, result={}".format(result))
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
return reply

+3 -3 voice/azure/config.json.template

@@ -1,4 +1,4 @@
{
"speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural",
"speech_recognition_language": "zh-CN"
}
"speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural",
"speech_recognition_language": "zh-CN"
}

+4 -4 voice/baidu/README.md

@@ -29,7 +29,7 @@ dev_pid 必填 语言选择,填写语言对应的dev_pid值

2、对于def textToVoice(self, text)函数中调用的百度语音合成API,中接口调用synthesis(参数)在本目录下的`config.json`文件中进行配置。
参数 可需 描述
tex 必填 合成的文本,使用UTF-8编码,请注意文本长度必须小于1024字节
tex 必填 合成的文本,使用UTF-8编码,请注意文本长度必须小于1024字节
lan 必填 固定值zh。语言选择,目前只有中英文混合模式,填写固定值zh
spd 选填 语速,取值0-15,默认为5中语速
pit 选填 音调,取值0-15,默认为5中语调
@@ -40,14 +40,14 @@ aue 选填 3为mp3格式(默认); 4为pcm-16k;5为pcm-8k;6为wav

关于per参数的说明,注意您购买的哪个音库,就填写哪个音库的参数,否则会报错。如果您购买的是基础音库,那么per参数只能填写0到4,如果您购买的是精品音库,那么per参数只能填写5003,5118,106,110,111,103,5其他的都会报错。
### 配置文件
将文件夹中`config.json.template`复制为`config.json`。

``` json
{
"lang": "zh",
"lang": "zh",
"ctp": 1,
"spd": 5,
"spd": 5,
"pit": 5,
"vol": 5,
"per": 0


+26 -23 voice/baidu/baidu_voice.py

@@ -1,17 +1,19 @@

"""
baidu voice service
"""
import json
import os
import time

from aip import AipSpeech

from bridge.reply import Reply, ReplyType
from common.log import logger
from common.tmp_dir import TmpDir
from voice.voice import Voice
from voice.audio_convert import get_pcm_from_wav
from config import conf
from voice.audio_convert import get_pcm_from_wav
from voice.voice import Voice

"""
百度的语音识别API.
dev_pid:
@@ -28,40 +30,37 @@ from config import conf


class BaiduVoice(Voice):

def __init__(self):
try:
curdir = os.path.dirname(__file__)
config_path = os.path.join(curdir, "config.json")
bconf = None
if not os.path.exists(config_path): #如果没有配置文件,创建本地配置文件
bconf = { "lang": "zh", "ctp": 1, "spd": 5,
"pit": 5, "vol": 5, "per": 0}
if not os.path.exists(config_path): # 如果没有配置文件,创建本地配置文件
bconf = {"lang": "zh", "ctp": 1, "spd": 5, "pit": 5, "vol": 5, "per": 0}
with open(config_path, "w") as fw:
json.dump(bconf, fw, indent=4)
else:
with open(config_path, "r") as fr:
bconf = json.load(fr)
self.app_id = conf().get('baidu_app_id')
self.api_key = conf().get('baidu_api_key')
self.secret_key = conf().get('baidu_secret_key')
self.dev_id = conf().get('baidu_dev_pid')
self.app_id = conf().get("baidu_app_id")
self.api_key = conf().get("baidu_api_key")
self.secret_key = conf().get("baidu_secret_key")
self.dev_id = conf().get("baidu_dev_pid")
self.lang = bconf["lang"]
self.ctp = bconf["ctp"]
self.spd = bconf["spd"]
self.pit = bconf["pit"]
self.vol = bconf["vol"]
self.per = bconf["per"]
self.client = AipSpeech(self.app_id, self.api_key, self.secret_key)
except Exception as e:
logger.warn("BaiduVoice init failed: %s, ignore " % e)

def voiceToText(self, voice_file):
# 识别本地文件
logger.debug('[Baidu] voice file name={}'.format(voice_file))
logger.debug("[Baidu] voice file name={}".format(voice_file))
pcm = get_pcm_from_wav(voice_file)
res = self.client.asr(pcm, "pcm", 16000, {"dev_pid": self.dev_id})
if res["err_no"] == 0:
@@ -72,21 +71,25 @@ class BaiduVoice(Voice):
logger.info("百度语音识别出错了: {}".format(res["err_msg"]))
if res["err_msg"] == "request pv too much":
logger.info(" 出现这个原因很可能是你的百度语音服务调用量超出限制,或未开通付费")
reply = Reply(ReplyType.ERROR,
"百度语音识别出错了;{0}".format(res["err_msg"]))
reply = Reply(ReplyType.ERROR, "百度语音识别出错了;{0}".format(res["err_msg"]))
return reply

def textToVoice(self, text):
result = self.client.synthesis(text, self.lang, self.ctp, {
'spd': self.spd, 'pit': self.pit, 'vol': self.vol, 'per': self.per})
result = self.client.synthesis(
text,
self.lang,
self.ctp,
{"spd": self.spd, "pit": self.pit, "vol": self.vol, "per": self.per},
)
if not isinstance(result, dict):
fileName = TmpDir().path() + 'reply-' + str(int(time.time())) + '.mp3'
with open(fileName, 'wb') as f:
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3"
with open(fileName, "wb") as f:
f.write(result)
logger.info(
'[Baidu] textToVoice text={} voice file name={}'.format(text, fileName))
"[Baidu] textToVoice text={} voice file name={}".format(text, fileName)
)
reply = Reply(ReplyType.VOICE, fileName)
else:
logger.error('[Baidu] textToVoice error={}'.format(result))
logger.error("[Baidu] textToVoice error={}".format(result))
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
return reply
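For reference, a minimal sketch of how the parameters documented in voice/baidu/README.md map onto the synthesis call, mirroring the textToVoice change above; the credentials, text, and output path are placeholders.

```python
# Minimal sketch of the Baidu TTS call; credentials and paths are placeholders.
from aip import AipSpeech

client = AipSpeech("your_app_id", "your_api_key", "your_secret_key")
result = client.synthesis(
    "你好,世界",  # tex: UTF-8 text, must stay under 1024 bytes
    "zh",          # lan: fixed value "zh"
    1,             # ctp: fixed value 1
    {"spd": 5, "pit": 5, "vol": 5, "per": 0},  # speed, pitch, volume, voice
)
if not isinstance(result, dict):  # a dict response indicates an error
    with open("reply.mp3", "wb") as f:
        f.write(result)
```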

+8 -8 voice/baidu/config.json.template

@@ -1,8 +1,8 @@
{
"lang": "zh",
"ctp": 1,
"spd": 5,
"pit": 5,
"vol": 5,
"per": 0
}
{
"lang": "zh",
"ctp": 1,
"spd": 5,
"pit": 5,
"vol": 5,
"per": 0
}

+13 -7 voice/google/google_voice.py

@@ -1,11 +1,12 @@

"""
google voice service
"""

import time

import speech_recognition
from gtts import gTTS

from bridge.reply import Reply, ReplyType
from common.log import logger
from common.tmp_dir import TmpDir
@@ -22,9 +23,12 @@ class GoogleVoice(Voice):
with speech_recognition.AudioFile(voice_file) as source:
audio = self.recognizer.record(source)
try:
text = self.recognizer.recognize_google(audio, language='zh-CN')
text = self.recognizer.recognize_google(audio, language="zh-CN")
logger.info(
'[Google] voiceToText text={} voice file name={}'.format(text, voice_file))
"[Google] voiceToText text={} voice file name={}".format(
text, voice_file
)
)
reply = Reply(ReplyType.TEXT, text)
except speech_recognition.UnknownValueError:
reply = Reply(ReplyType.ERROR, "抱歉,我听不懂")
@@ -32,13 +36,15 @@ class GoogleVoice(Voice):
reply = Reply(ReplyType.ERROR, "抱歉,无法连接到 Google 语音识别服务;{0}".format(e))
finally:
return reply

def textToVoice(self, text):
try:
mp3File = TmpDir().path() + 'reply-' + str(int(time.time())) + '.mp3'
tts = gTTS(text=text, lang='zh')
tts.save(mp3File)
mp3File = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3"
tts = gTTS(text=text, lang="zh")
tts.save(mp3File)
logger.info(
'[Google] textToVoice text={} voice file name={}'.format(text, mp3File))
"[Google] textToVoice text={} voice file name={}".format(text, mp3File)
)
reply = Reply(ReplyType.VOICE, mp3File)
except Exception as e:
reply = Reply(ReplyType.ERROR, str(e))


+9 -6 voice/openai/openai_voice.py

@@ -1,29 +1,32 @@

"""
google voice service
"""
import json

import openai

from bridge.reply import Reply, ReplyType
from config import conf
from common.log import logger
from config import conf
from voice.voice import Voice


class OpenaiVoice(Voice):
def __init__(self):
openai.api_key = conf().get('open_ai_api_key')
openai.api_key = conf().get("open_ai_api_key")

def voiceToText(self, voice_file):
logger.debug(
'[Openai] voice file name={}'.format(voice_file))
logger.debug("[Openai] voice file name={}".format(voice_file))
try:
file = open(voice_file, "rb")
result = openai.Audio.transcribe("whisper-1", file)
text = result["text"]
reply = Reply(ReplyType.TEXT, text)
logger.info(
'[Openai] voiceToText text={} voice file name={}'.format(text, voice_file))
"[Openai] voiceToText text={} voice file name={}".format(
text, voice_file
)
)
except Exception as e:
reply = Reply(ReplyType.ERROR, str(e))
finally:


+9 -7 voice/pytts/pytts_voice.py

@@ -1,10 +1,11 @@

"""
pytts voice service (offline)
"""

import time

import pyttsx3

from bridge.reply import Reply, ReplyType
from common.log import logger
from common.tmp_dir import TmpDir
@@ -16,20 +17,21 @@ class PyttsVoice(Voice):

def __init__(self):
# 语速
self.engine.setProperty('rate', 125)
self.engine.setProperty("rate", 125)
# 音量
self.engine.setProperty('volume', 1.0)
for voice in self.engine.getProperty('voices'):
self.engine.setProperty("volume", 1.0)
for voice in self.engine.getProperty("voices"):
if "Chinese" in voice.name:
self.engine.setProperty('voice', voice.id)
self.engine.setProperty("voice", voice.id)

def textToVoice(self, text):
try:
wavFile = TmpDir().path() + 'reply-' + str(int(time.time())) + '.wav'
wavFile = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav"
self.engine.save_to_file(text, wavFile)
self.engine.runAndWait()
logger.info(
'[Pytts] textToVoice text={} voice file name={}'.format(text, wavFile))
"[Pytts] textToVoice text={} voice file name={}".format(text, wavFile)
)
reply = Reply(ReplyType.VOICE, wavFile)
except Exception as e:
reply = Reply(ReplyType.ERROR, str(e))


+2 -1 voice/voice.py

@@ -2,6 +2,7 @@
Voice service abstract class
"""


class Voice(object):
def voiceToText(self, voice_file):
"""
@@ -13,4 +14,4 @@ class Voice(object):
"""
Send text to voice service and get voice
"""
raise NotImplementedError
raise NotImplementedError

+11 -5 voice/voice_factory.py

@@ -2,25 +2,31 @@
voice factory
"""


def create_voice(voice_type):
"""
create a voice instance
:param voice_type: voice type code
:return: voice instance
"""
if voice_type == 'baidu':
if voice_type == "baidu":
from voice.baidu.baidu_voice import BaiduVoice

return BaiduVoice()
elif voice_type == 'google':
elif voice_type == "google":
from voice.google.google_voice import GoogleVoice

return GoogleVoice()
elif voice_type == 'openai':
elif voice_type == "openai":
from voice.openai.openai_voice import OpenaiVoice

return OpenaiVoice()
elif voice_type == 'pytts':
elif voice_type == "pytts":
from voice.pytts.pytts_voice import PyttsVoice

return PyttsVoice()
elif voice_type == 'azure':
elif voice_type == "azure":
from voice.azure.azure_voice import AzureVoice

return AzureVoice()
raise RuntimeError
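A short usage sketch of the factory; "baidu" is just one of the supported voice_type values, and the chosen service still needs its API keys in the main config.

```python
# Illustrative usage of voice/voice_factory.py; the wav path is a placeholder.
from voice.voice_factory import create_voice

voice = create_voice("baidu")                  # also: "google", "openai", "pytts", "azure"
reply = voice.textToVoice("你好")               # Reply(ReplyType.VOICE, <audio file path>) on success
text_reply = voice.voiceToText("record.wav")   # Reply(ReplyType.TEXT, <recognized text>) on success
```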
