Browse Source

Refactor: inherit ChatChannel

master
JS00000 1 year ago
parent
commit
1a981ea970
3 changed files with 135 additions and 221 deletions
  1. +2
    -2
      channel/channel_factory.py
  2. +32
    -36
      channel/wechatmp/receive.py
  3. +101
    -183
      channel/wechatmp/wechatmp_channel.py

+ 2
- 2
channel/channel_factory.py View File

@@ -18,6 +18,6 @@ def create_channel(channel_type):
from channel.terminal.terminal_channel import TerminalChannel from channel.terminal.terminal_channel import TerminalChannel
return TerminalChannel() return TerminalChannel()
elif channel_type == 'wechatmp': elif channel_type == 'wechatmp':
from channel.wechatmp.wechatmp_channel import WechatMPServer
return WechatMPServer()
from channel.wechatmp.wechatmp_channel import WechatMPChannel
return WechatMPChannel()
raise RuntimeError raise RuntimeError

+ 32
- 36
channel/wechatmp/receive.py View File

@@ -1,47 +1,43 @@
# -*- coding: utf-8 -*-# # -*- coding: utf-8 -*-#
# filename: receive.py # filename: receive.py
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from bridge.context import ContextType
from channel.chat_message import ChatMessage
from common.tmp_dir import TmpDir
from common.log import logger




def parse_xml(web_data): def parse_xml(web_data):
if len(web_data) == 0: if len(web_data) == 0:
return None return None
xmlData = ET.fromstring(web_data) xmlData = ET.fromstring(web_data)
msg_type = xmlData.find('MsgType').text
if msg_type == 'text':
return TextMsg(xmlData)
elif msg_type == 'image':
return ImageMsg(xmlData)
elif msg_type == 'event':
return Event(xmlData)
return WeChatMPMessage(xmlData)



class Msg(object):
def __init__(self, xmlData):
self.ToUserName = xmlData.find('ToUserName').text
self.FromUserName = xmlData.find('FromUserName').text
self.CreateTime = xmlData.find('CreateTime').text
self.MsgType = xmlData.find('MsgType').text
self.MsgId = xmlData.find('MsgId').text


class TextMsg(Msg):
def __init__(self, xmlData):
Msg.__init__(self, xmlData)
self.Content = xmlData.find('Content').text.encode("utf-8")


class ImageMsg(Msg):
def __init__(self, xmlData):
Msg.__init__(self, xmlData)
self.PicUrl = xmlData.find('PicUrl').text
self.MediaId = xmlData.find('MediaId').text


class Event(object):
class WeChatMPMessage(ChatMessage):
def __init__(self, xmlData): def __init__(self, xmlData):
self.ToUserName = xmlData.find('ToUserName').text
self.FromUserName = xmlData.find('FromUserName').text
self.CreateTime = xmlData.find('CreateTime').text
self.MsgType = xmlData.find('MsgType').text
self.Event = xmlData.find('Event').text
super().__init__(xmlData)
self.to_user_id = xmlData.find('ToUserName').text
self.from_user_id = xmlData.find('FromUserName').text
self.create_time = xmlData.find('CreateTime').text
self.msg_type = xmlData.find('MsgType').text
self.msg_id = xmlData.find('MsgId').text
self.is_group = False
# reply to other_user_id
self.other_user_id = self.from_user_id

if self.msg_type == 'text':
self.ctype = ContextType.TEXT
self.content = xmlData.find('Content').text.encode("utf-8")
elif self.msg_type == 'voice':
self.ctype = ContextType.TEXT
self.content = xmlData.find('Recognition').text.encode("utf-8") # 接收语音识别结果
elif self.msg_type == 'image':
# not implemented
self.pic_url = xmlData.find('PicUrl').text
self.media_id = xmlData.find('MediaId').text
elif self.msg_type == 'event':
self.event = xmlData.find('Event').text
else: # video, shortvideo, location, link
# not implemented
pass

+ 101
- 183
channel/wechatmp/wechatmp_channel.py View File

@@ -4,9 +4,10 @@ import time
import math import math
import hashlib import hashlib
import textwrap import textwrap
from channel.channel import Channel
from channel.chat_channel import ChatChannel
import channel.wechatmp.reply as reply import channel.wechatmp.reply as reply
import channel.wechatmp.receive as receive import channel.wechatmp.receive as receive
from common.singleton import singleton
from common.log import logger from common.log import logger
from config import conf from config import conf
from bridge.reply import * from bridge.reply import *
@@ -21,202 +22,125 @@ import traceback
# certificate='/ssl/cert.pem', # certificate='/ssl/cert.pem',
# private_key='/ssl/cert.key') # private_key='/ssl/cert.key')


class WechatMPServer():

# from concurrent.futures import ThreadPoolExecutor
# thread_pool = ThreadPoolExecutor(max_workers=8)

@singleton
class WechatMPChannel(ChatChannel):
def __init__(self): def __init__(self):
pass
super().__init__()
self.cache_dict = dict()
self.query1 = dict()
self.query2 = dict()
self.query3 = dict()



def startup(self):
def startup(self):
urls = ( urls = (
'/wx', 'WechatMPChannel',
'/wx', 'SubsribeAccountQuery',
) )
app = web.application(urls, globals()) app = web.application(urls, globals())
app.run() app.run()


cache_dict = dict()
query1 = dict()
query2 = dict()
query3 = dict()

from concurrent.futures import ThreadPoolExecutor
thread_pool = ThreadPoolExecutor(max_workers=8)

class WechatMPChannel(Channel):


def GET(self):
try:
data = web.input()
if len(data) == 0:
return "hello, this is handle view"
signature = data.signature
timestamp = data.timestamp
nonce = data.nonce
echostr = data.echostr
token = conf().get('wechatmp_token') #请按照公众平台官网\基本配置中信息填写

data_list = [token, timestamp, nonce]
data_list.sort()
sha1 = hashlib.sha1()
# map(sha1.update, data_list) #python2
sha1.update("".join(data_list).encode('utf-8'))
hashcode = sha1.hexdigest()
print("handle/GET func: hashcode, signature: ", hashcode, signature)
if hashcode == signature:
return echostr
else:
return ""
except Exception as Argument:
return Argument


def _do_build_reply(self, cache_key, fromUser, message):
context = dict()
context['session_id'] = fromUser
reply_text = super().build_reply_content(message, context)
# The query is done, record the cache
logger.info("[threaded] Get reply for {}: {} \nA: {}".format(fromUser, message, reply_text))
global cache_dict
reply_cnt = math.ceil(len(reply_text) / 600)
cache_dict[cache_key] = (reply_cnt, reply_text)


def send(self, reply : Reply, cache_key):
global cache_dict
def send(self, reply: Reply, context: Context):
reply_cnt = math.ceil(len(reply.content) / 600) reply_cnt = math.ceil(len(reply.content) / 600)
cache_dict[cache_key] = (reply_cnt, reply.content)


def handle(self, context):
global cache_dict
try:
reply = Reply()
logger.debug('[wechatmp] ready to handle context: {}'.format(context))

# reply的构建步骤
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply}))
reply = e_context['reply']
if not e_context.is_pass():
logger.debug('[wechatmp] ready to handle context: type={}, content={}'.format(context.type, context.content))
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE:
reply = super().build_reply_content(context.content, context)
# elif context.type == ContextType.VOICE:
# msg = context['msg']
# file_name = TmpDir().path() + context.content
# msg.download(file_name)
# reply = super().build_voice_to_text(file_name)
# if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO:
# context.content = reply.content # 语音转文字后,将文字内容作为新的context
# context.type = ContextType.TEXT
# reply = super().build_reply_content(context.content, context)
# if reply.type == ReplyType.TEXT:
# if conf().get('voice_reply_voice'):
# reply = super().build_text_to_voice(reply.content)
else:
logger.error('[wechatmp] unknown context type: {}'.format(context.type))
return

logger.debug('[wechatmp] ready to decorate reply: {}'.format(reply))

# reply的包装步骤
if reply and reply.type:
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
reply=e_context['reply']
if not e_context.is_pass() and reply and reply.type:
if reply.type == ReplyType.TEXT:
pass
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
reply.content = str(reply.type)+":\n" + reply.content
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
pass
else:
logger.error('[wechatmp] unknown reply type: {}'.format(reply.type))
return

# reply的发送步骤
if reply and reply.type:
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
reply=e_context['reply']
if not e_context.is_pass() and reply and reply.type:
logger.debug('[wechatmp] ready to send reply: {} to {}'.format(reply, context['receiver']))
self.send(reply, context['receiver'])
else:
cache_dict[context['receiver']] = (1, "No reply")

logger.info("[threaded] Get reply for {}: {} \nA: {}".format(context['receiver'], context.content, reply.content))
except Exception as exc:
print(traceback.format_exc())
cache_dict[context['receiver']] = (1, "ERROR")

receiver = context["receiver"]
self.cache_dict[receiver] = (reply_cnt, reply.content)
logger.debug("[send] reply to {} saved to cache: {}".format(receiver, reply))


def verify_server():
try:
data = web.input()
if len(data) == 0:
return "None"
signature = data.signature
timestamp = data.timestamp
nonce = data.nonce
echostr = data.echostr
token = conf().get('wechatmp_token') #请按照公众平台官网\基本配置中信息填写

data_list = [token, timestamp, nonce]
data_list.sort()
sha1 = hashlib.sha1()
# map(sha1.update, data_list) #python2
sha1.update("".join(data_list).encode('utf-8'))
hashcode = sha1.hexdigest()
print("handle/GET func: hashcode, signature: ", hashcode, signature)
if hashcode == signature:
return echostr
else:
return ""
except Exception as Argument:
return Argument


# This class is instantiated once per query
class SubsribeAccountQuery():


def GET(self):
return verify_server()


def POST(self): def POST(self):
channel_instance = WechatMPChannel()
try: try:
queryTime = time.time()
query_time = time.time()
webData = web.data() webData = web.data()
# logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8")) # logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
recMsg = receive.parse_xml(webData)
if isinstance(recMsg, receive.Msg) and recMsg.MsgType == 'text':
fromUser = recMsg.FromUserName
toUser = recMsg.ToUserName
createTime = recMsg.CreateTime
message = recMsg.Content.decode("utf-8")
message_id = recMsg.MsgId
wechat_msg = receive.parse_xml(webData)
if wechat_msg.msg_type == 'text':
from_user = wechat_msg.from_user_id
to_user = wechat_msg.to_user_id
message = wechat_msg.content.decode("utf-8")
message_id = wechat_msg.msg_id


logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), fromUser, message_id, message))
logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message))


global cache_dict
global query1
global query2
global query3
cache_key = fromUser
cache = cache_dict.get(cache_key)
cache_key = from_user
cache = channel_instance.cache_dict.get(cache_key)


reply_text = "" reply_text = ""
# New request # New request
if cache == None: if cache == None:
# The first query begin, reset the cache # The first query begin, reset the cache
cache_dict[cache_key] = (0, "")
# thread_pool.submit(self._do_build_reply, cache_key, fromUser, message)

context = Context()
context.kwargs = {'isgroup': False, 'receiver': fromUser, 'session_id': fromUser}
channel_instance.cache_dict[cache_key] = (0, "")


user_data = conf().get_user_data(fromUser)
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key
context = channel_instance._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechat_msg)
if context:
# set private openai_api_key
# if from_user is not changed in itchat, this can be placed at chat_channel
user_data = conf().get_user_data(from_user)
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key
channel_instance.produce(context)


img_match_prefix = check_prefix(message, conf().get('image_create_prefix'))
if img_match_prefix:
message = message.replace(img_match_prefix, '', 1).strip()
context.type = ContextType.IMAGE_CREATE
else:
context.type = ContextType.TEXT
context.content = message
thread_pool.submit(self.handle, context)


query1[cache_key] = False
query2[cache_key] = False
query3[cache_key] = False
channel_instance.query1[cache_key] = False
channel_instance.query2[cache_key] = False
channel_instance.query3[cache_key] = False
# Request again # Request again
elif cache[0] == 0 and query1.get(cache_key) == True and query2.get(cache_key) == True and query3.get(cache_key) == True:
query1[cache_key] = False #To improve waiting experience, this can be set to True.
query2[cache_key] = False #To improve waiting experience, this can be set to True.
query3[cache_key] = False
elif cache[0] == 0 and channel_instance.query1.get(cache_key) == True and channel_instance.query2.get(cache_key) == True and channel_instance.query3.get(cache_key) == True:
channel_instance.query1[cache_key] = False #To improve waiting experience, this can be set to True.
channel_instance.query2[cache_key] = False #To improve waiting experience, this can be set to True.
channel_instance.query3[cache_key] = False
elif cache[0] >= 1: elif cache[0] >= 1:
# Skip the waiting phase # Skip the waiting phase
query1[cache_key] = True
query2[cache_key] = True
query3[cache_key] = True
channel_instance.query1[cache_key] = True
channel_instance.query2[cache_key] = True
channel_instance.query3[cache_key] = True




cache = cache_dict.get(cache_key)
if query1.get(cache_key) == False:
cache = channel_instance.cache_dict.get(cache_key)
if channel_instance.query1.get(cache_key) == False:
# The first query from wechat official server # The first query from wechat official server
logger.debug("[wechatmp] query1 {}".format(cache_key)) logger.debug("[wechatmp] query1 {}".format(cache_key))
query1[cache_key] = True
channel_instance.query1[cache_key] = True
cnt = 0 cnt = 0
while cache[0] == 0 and cnt < 45: while cache[0] == 0 and cnt < 45:
cnt = cnt + 1 cnt = cnt + 1
time.sleep(0.1) time.sleep(0.1)
cache = cache_dict.get(cache_key)
cache = channel_instance.cache_dict.get(cache_key)
if cnt == 45: if cnt == 45:
# waiting for timeout (the POST query will be closed by wechat official server) # waiting for timeout (the POST query will be closed by wechat official server)
time.sleep(5) time.sleep(5)
@@ -224,15 +148,15 @@ class WechatMPChannel(Channel):
return return
else: else:
pass pass
elif query2.get(cache_key) == False:
elif channel_instance.query2.get(cache_key) == False:
# The second query from wechat official server # The second query from wechat official server
logger.debug("[wechatmp] query2 {}".format(cache_key)) logger.debug("[wechatmp] query2 {}".format(cache_key))
query2[cache_key] = True
channel_instance.query2[cache_key] = True
cnt = 0 cnt = 0
while cache[0] == 0 and cnt < 45: while cache[0] == 0 and cnt < 45:
cnt = cnt + 1 cnt = cnt + 1
time.sleep(0.1) time.sleep(0.1)
cache = cache_dict.get(cache_key)
cache = channel_instance.cache_dict.get(cache_key)
if cnt == 45: if cnt == 45:
# waiting for timeout (the POST query will be closed by wechat official server) # waiting for timeout (the POST query will be closed by wechat official server)
time.sleep(5) time.sleep(5)
@@ -240,42 +164,42 @@ class WechatMPChannel(Channel):
return return
else: else:
pass pass
elif query3.get(cache_key) == False:
elif channel_instance.query3.get(cache_key) == False:
# The third query from wechat official server # The third query from wechat official server
logger.debug("[wechatmp] query3 {}".format(cache_key)) logger.debug("[wechatmp] query3 {}".format(cache_key))
query3[cache_key] = True
channel_instance.query3[cache_key] = True
cnt = 0 cnt = 0
while cache[0] == 0 and cnt < 45: while cache[0] == 0 and cnt < 45:
cnt = cnt + 1 cnt = cnt + 1
time.sleep(0.1) time.sleep(0.1)
cache = cache_dict.get(cache_key)
cache = channel_instance.cache_dict.get(cache_key)
if cnt == 45: if cnt == 45:
# Have waiting for 3x5 seconds # Have waiting for 3x5 seconds
# return timeout message # return timeout message
reply_text = "【正在响应中,回复任意文字尝试获取回复】" reply_text = "【正在响应中,回复任意文字尝试获取回复】"
logger.info("[wechatmp] Three queries has finished For {}: {}".format(fromUser, message_id))
replyPost = reply.TextMsg(fromUser, toUser, reply_text).send()
logger.info("[wechatmp] Three queries has finished For {}: {}".format(from_user, message_id))
replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
return replyPost return replyPost
else: else:
pass pass


if float(time.time()) - float(queryTime) > 4.8:
logger.info("[wechatmp] Timeout for {} {}".format(fromUser, message_id))
if float(time.time()) - float(query_time) > 4.8:
logger.info("[wechatmp] Timeout for {} {}".format(from_user, message_id))
return return




if cache[0] > 1: if cache[0] > 1:
reply_text = cache[1][:600] + "\n【未完待续,回复任意文字以继续】" #wechatmp auto_reply length limit reply_text = cache[1][:600] + "\n【未完待续,回复任意文字以继续】" #wechatmp auto_reply length limit
cache_dict[cache_key] = (cache[0] - 1, cache[1][600:])
channel_instance.cache_dict[cache_key] = (cache[0] - 1, cache[1][600:])
elif cache[0] == 1: elif cache[0] == 1:
reply_text = cache[1] reply_text = cache[1]
cache_dict.pop(cache_key)
channel_instance.cache_dict.pop(cache_key)
logger.info("[wechatmp] {}:{} Do send {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), reply_text)) logger.info("[wechatmp] {}:{} Do send {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), reply_text))
replyPost = reply.TextMsg(fromUser, toUser, reply_text).send()
replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
return replyPost return replyPost


elif isinstance(recMsg, receive.Event) and recMsg.MsgType == 'event':
logger.info("[wechatmp] Event {} from {}".format(recMsg.Event, recMsg.FromUserName))
elif wechat_msg.msg_type == 'event':
logger.info("[wechatmp] Event {} from {}".format(wechat_msg.Event, wechat_msg.from_user_id))
content = textwrap.dedent("""\ content = textwrap.dedent("""\
感谢您的关注! 感谢您的关注!
这里是ChatGPT,可以自由对话。 这里是ChatGPT,可以自由对话。
@@ -285,7 +209,7 @@ class WechatMPChannel(Channel):
支持图片输出,画字开头的问题将回复图片链接。 支持图片输出,画字开头的问题将回复图片链接。
支持角色扮演和文字冒险两种定制模式对话。 支持角色扮演和文字冒险两种定制模式对话。
输入'#帮助' 查看详细指令。""") 输入'#帮助' 查看详细指令。""")
replyMsg = reply.TextMsg(recMsg.FromUserName, recMsg.ToUserName, content)
replyMsg = reply.TextMsg(wechat_msg.from_user_id, wechat_msg.to_user_id, content)
return replyMsg.send() return replyMsg.send()
else: else:
logger.info("暂且不处理") logger.info("暂且不处理")
@@ -294,9 +218,3 @@ class WechatMPChannel(Channel):
logger.exception(exc) logger.exception(exc)
return exc return exc



def check_prefix(content, prefix_list):
for prefix in prefix_list:
if content.startswith(prefix):
return prefix
return None

Loading…
Cancel
Save