Você não pode selecionar mais de 25 tópicos Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

268 linhas
12KB

  1. # -*- coding: utf-8 -*-
  2. import web
  3. import time
  4. import math
  5. import hashlib
  6. import textwrap
  7. from channel.chat_channel import ChatChannel
  8. import channel.wechatmp.reply as reply
  9. import channel.wechatmp.receive as receive
  10. from common.expired_dict import ExpiredDict
  11. from common.singleton import singleton
  12. from common.log import logger
  13. from config import conf
  14. from bridge.reply import *
  15. from bridge.context import *
  16. from plugins import *
  17. import traceback
  18. # If using SSL, uncomment the following lines, and modify the certificate path.
  19. # from cheroot.server import HTTPServer
  20. # from cheroot.ssl.builtin import BuiltinSSLAdapter
  21. # HTTPServer.ssl_adapter = BuiltinSSLAdapter(
  22. # certificate='/ssl/cert.pem',
  23. # private_key='/ssl/cert.key')
  24. # from concurrent.futures import ThreadPoolExecutor
  25. # thread_pool = ThreadPoolExecutor(max_workers=8)
# Maximum size, in UTF-8 encoded bytes, of the text returned in one reply.
# The POST handler splits a cached reply that exceeds this limit and keeps the
# remainder in the cache for a later request.
MAX_UTF8_LEN = 2048
  27. @singleton
  28. class WechatMPChannel(ChatChannel):
  29. NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE]
  30. def __init__(self):
  31. super().__init__()
  32. self.cache_dict = dict()
  33. self.running = set()
  34. self.query1 = dict()
  35. self.query2 = dict()
  36. self.query3 = dict()
  37. self.received_msgs = ExpiredDict(60*60*24)
  38. def startup(self):
  39. urls = (
  40. '/wx', 'SubsribeAccountQuery',
  41. )
  42. app = web.application(urls, globals(), autoreload=False)
  43. port = conf().get('wechatmp_port', 8080)
  44. web.httpserver.runsimple(app.wsgifunc(), ('0.0.0.0', port))
  45. def send(self, reply: Reply, context: Context):
  46. receiver = context["receiver"]
  47. self.cache_dict[receiver] = reply.content
  48. self.running.remove(receiver)
  49. logger.debug("[send] reply to {} saved to cache: {}".format(receiver, reply))
  50. def _fail_callback(self, session_id, exception, context, **kwargs):
  51. logger.exception("[wechatmp] Fail to generation message to user, msgId={}, exception={}".format(context['msg'].msg_id, exception))
  52. assert session_id not in self.cache_dict
  53. self.running.remove(session_id)
  54. def verify_server():
  55. try:
  56. data = web.input()
  57. if len(data) == 0:
  58. return "None"
  59. signature = data.signature
  60. timestamp = data.timestamp
  61. nonce = data.nonce
  62. echostr = data.echostr
  63. token = conf().get('wechatmp_token') #请按照公众平台官网\基本配置中信息填写
  64. data_list = [token, timestamp, nonce]
  65. data_list.sort()
  66. sha1 = hashlib.sha1()
  67. # map(sha1.update, data_list) #python2
  68. sha1.update("".join(data_list).encode('utf-8'))
  69. hashcode = sha1.hexdigest()
  70. print("handle/GET func: hashcode, signature: ", hashcode, signature)
  71. if hashcode == signature:
  72. return echostr
  73. else:
  74. return ""
  75. except Exception as Argument:
  76. return Argument
  77. # This class is instantiated once per query
  78. class SubsribeAccountQuery():
  79. def GET(self):
  80. return verify_server()
  81. def POST(self):
  82. channel = WechatMPChannel()
  83. try:
  84. query_time = time.time()
  85. webData = web.data()
  86. logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
  87. wechat_msg = receive.parse_xml(webData)
  88. if wechat_msg.msg_type == 'text':
  89. from_user = wechat_msg.from_user_id
  90. to_user = wechat_msg.to_user_id
  91. message = wechat_msg.content.decode("utf-8")
  92. message_id = wechat_msg.msg_id
  93. logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message))
  94. supported = True
  95. if "【收到不支持的消息类型,暂无法显示】" in message:
  96. supported = False # not supported, used to refresh
  97. cache_key = from_user
  98. reply_text = ""
  99. # New request
  100. if cache_key not in channel.cache_dict and cache_key not in channel.running:
  101. # The first query begin, reset the cache
  102. context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechat_msg)
  103. logger.debug("[wechatmp] context: {} {}".format(context, wechat_msg))
  104. if message_id in channel.received_msgs: # received and finished
  105. return
  106. if supported and context:
  107. # set private openai_api_key
  108. # if from_user is not changed in itchat, this can be placed at chat_channel
  109. user_data = conf().get_user_data(from_user)
  110. context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key
  111. channel.received_msgs[message_id] = wechat_msg
  112. channel.running.add(cache_key)
  113. channel.produce(context)
  114. else:
  115. trigger_prefix = conf().get('single_chat_prefix',[''])[0]
  116. if trigger_prefix or not supported:
  117. if trigger_prefix:
  118. content = textwrap.dedent(f"""\
  119. 请输入'{trigger_prefix}'接你想说的话跟我说话。
  120. 例如:
  121. {trigger_prefix}你好,很高兴见到你。""")
  122. else:
  123. content = textwrap.dedent("""\
  124. 你好,很高兴见到你。
  125. 请跟我说话吧。""")
  126. else:
  127. logger.error(f"[wechatmp] unknown error")
  128. content = textwrap.dedent("""\
  129. 未知错误,请稍后再试""")
  130. replyMsg = reply.TextMsg(wechat_msg.from_user_id, wechat_msg.to_user_id, content)
  131. return replyMsg.send()
  132. channel.query1[cache_key] = False
  133. channel.query2[cache_key] = False
  134. channel.query3[cache_key] = False
  135. # Request again
  136. elif cache_key in channel.running and channel.query1.get(cache_key) == True and channel.query2.get(cache_key) == True and channel.query3.get(cache_key) == True:
  137. channel.query1[cache_key] = False #To improve waiting experience, this can be set to True.
  138. channel.query2[cache_key] = False #To improve waiting experience, this can be set to True.
  139. channel.query3[cache_key] = False
  140. elif cache_key in channel.cache_dict:
  141. # Skip the waiting phase
  142. channel.query1[cache_key] = True
  143. channel.query2[cache_key] = True
  144. channel.query3[cache_key] = True
  145. assert not (cache_key in channel.cache_dict and cache_key in channel.running)
  146. if channel.query1.get(cache_key) == False:
  147. # The first query from wechat official server
  148. logger.debug("[wechatmp] query1 {}".format(cache_key))
  149. channel.query1[cache_key] = True
  150. cnt = 0
  151. while cache_key not in channel.cache_dict and cnt < 45:
  152. cnt = cnt + 1
  153. time.sleep(0.1)
  154. if cnt == 45:
  155. # waiting for timeout (the POST query will be closed by wechat official server)
  156. time.sleep(1)
  157. # and do nothing
  158. return
  159. else:
  160. pass
  161. elif channel.query2.get(cache_key) == False:
  162. # The second query from wechat official server
  163. logger.debug("[wechatmp] query2 {}".format(cache_key))
  164. channel.query2[cache_key] = True
  165. cnt = 0
  166. while cache_key not in channel.cache_dict and cnt < 45:
  167. cnt = cnt + 1
  168. time.sleep(0.1)
  169. if cnt == 45:
  170. # waiting for timeout (the POST query will be closed by wechat official server)
  171. time.sleep(1)
  172. # and do nothing
  173. return
  174. else:
  175. pass
  176. elif channel.query3.get(cache_key) == False:
  177. # The third query from wechat official server
  178. logger.debug("[wechatmp] query3 {}".format(cache_key))
  179. channel.query3[cache_key] = True
  180. cnt = 0
  181. while cache_key not in channel.cache_dict and cnt < 40:
  182. cnt = cnt + 1
  183. time.sleep(0.1)
  184. if cnt == 40:
  185. # Have waiting for 3x5 seconds
  186. # return timeout message
  187. reply_text = "【正在思考中,回复任意文字尝试获取回复】"
  188. logger.info("[wechatmp] Three queries has finished For {}: {}".format(from_user, message_id))
  189. replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
  190. return replyPost
  191. else:
  192. pass
  193. if float(time.time()) - float(query_time) > 4.8:
  194. reply_text = "【正在思考中,回复任意文字尝试获取回复】"
  195. logger.info("[wechatmp] Timeout for {} {}, return".format(from_user, message_id))
  196. replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
  197. return replyPost
  198. if cache_key in channel.cache_dict:
  199. content = channel.cache_dict[cache_key]
  200. if len(content.encode('utf8'))<=MAX_UTF8_LEN:
  201. reply_text = channel.cache_dict[cache_key]
  202. channel.cache_dict.pop(cache_key)
  203. else:
  204. continue_text = "\n【未完待续,回复任意文字以继续】"
  205. splits = split_string_by_utf8_length(content, MAX_UTF8_LEN - len(continue_text.encode('utf-8')), max_split= 1)
  206. reply_text = splits[0] + continue_text
  207. channel.cache_dict[cache_key] = splits[1]
  208. logger.info("[wechatmp] {}:{} Do send {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), reply_text))
  209. replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
  210. return replyPost
  211. elif wechat_msg.msg_type == 'event':
  212. logger.info("[wechatmp] Event {} from {}".format(wechat_msg.content, wechat_msg.from_user_id))
  213. trigger_prefix = conf().get('single_chat_prefix',[''])[0]
  214. content = textwrap.dedent(f"""\
  215. 感谢您的关注!
  216. 这里是ChatGPT,可以自由对话。
  217. 资源有限,回复较慢,请勿着急。
  218. 支持通用表情输入。
  219. 暂时不支持图片输入。
  220. 支持图片输出,画字开头的问题将回复图片链接。
  221. 支持角色扮演和文字冒险两种定制模式对话。
  222. 输入'{trigger_prefix}#帮助' 查看详细指令。""")
  223. replyMsg = reply.TextMsg(wechat_msg.from_user_id, wechat_msg.to_user_id, content)
  224. return replyMsg.send()
  225. else:
  226. logger.info("暂且不处理")
  227. return "success"
  228. except Exception as exc:
  229. logger.exception(exc)
  230. return exc
  231. def split_string_by_utf8_length(string, max_length, max_split=0):
  232. encoded = string.encode('utf-8')
  233. start, end = 0, 0
  234. result = []
  235. while end < len(encoded):
  236. if max_split > 0 and len(result) >= max_split:
  237. result.append(encoded[start:].decode('utf-8'))
  238. break
  239. end = start + max_length
  240. # 如果当前字节不是 UTF-8 编码的开始字节,则向前查找直到找到开始字节为止
  241. while end < len(encoded) and (encoded[end] & 0b11000000) == 0b10000000:
  242. end -= 1
  243. result.append(encoded[start:end].decode('utf-8'))
  244. start = end
  245. return result