Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

295 lignes
9.7KB

  1. # encoding:utf-8
  2. import json
  3. import os
  4. import uuid
  5. import requests
  6. from bridge.context import ContextType
  7. from bridge.reply import Reply, ReplyType
  8. from common.log import logger
  9. import plugins
  10. from plugins import *
  11. from uuid import getnode as get_mac
  12. """利用百度UNIT实现智能对话
  13. 如果命中意图,返回意图对应的回复,否则返回继续交付给下个插件处理
  14. """
  15. @plugins.register(name="BDunit", desc="Baidu unit bot system", version="0.1", author="jackson", desire_priority=0)
  16. class BDunit(Plugin):
  17. def __init__(self):
  18. super().__init__()
  19. try:
  20. curdir = os.path.dirname(__file__)
  21. config_path = os.path.join(curdir, "config.json")
  22. conf = None
  23. if not os.path.exists(config_path):
  24. raise Exception("config.json not found")
  25. else:
  26. with open(config_path, "r") as f:
  27. conf = json.load(f)
  28. self.service_id = conf["service_id"]
  29. self.api_key = conf["api_key"]
  30. self.secret_key = conf["secret_key"]
  31. self.access_token = self.get_token()
  32. self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
  33. logger.info("[BDunit] inited")
  34. except Exception as e:
  35. logger.warn(
  36. "BDunit init failed: %s, ignore " % e)
  37. def on_handle_context(self, e_context: EventContext):
  38. if e_context['context'].type != ContextType.TEXT:
  39. return
  40. content = e_context['context'].content
  41. logger.debug("[BDunit] on_handle_context. content: %s" % content)
  42. parsed = self.getUnit2(content)
  43. intent = self.getIntent(parsed)
  44. if intent: # 找到意图
  45. logger.debug("[BDunit] Baidu_AI Intent= %s", intent)
  46. reply = Reply()
  47. reply.type = ReplyType.TEXT
  48. reply.content = self.getSay(parsed)
  49. e_context['reply'] = reply
  50. e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
  51. else:
  52. e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
  53. def get_help_text(self, **kwargs):
  54. help_text = "本插件会处理询问实时日期时间,天气,数学运算等问题,这些技能由您的百度智能对话UNIT决定\n"
  55. return help_text
  56. def get_token(self):
  57. """获取访问百度UUNIT 的access_token
  58. #param api_key: UNIT apk_key
  59. #param secret_key: UNIT secret_key
  60. Returns:
  61. string: access_token
  62. """
  63. url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(
  64. self.api_key, self.secret_key)
  65. payload = ""
  66. headers = {
  67. 'Content-Type': 'application/json',
  68. 'Accept': 'application/json'
  69. }
  70. response = requests.request("POST", url, headers=headers, data=payload)
  71. # print(response.text)
  72. return response.json()['access_token']
  73. def getUnit(self, query):
  74. """
  75. NLU 解析version 3.0
  76. :param query: 用户的指令字符串
  77. :returns: UNIT 解析结果。如果解析失败,返回 None
  78. """
  79. url = (
  80. 'https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token='
  81. + self.access_token
  82. )
  83. request = {"query": query, "user_id": str(
  84. get_mac())[:32], "terminal_id": "88888"}
  85. body = {
  86. "log_id": str(uuid.uuid1()),
  87. "version": "3.0",
  88. "service_id": self.service_id,
  89. "session_id": str(uuid.uuid1()),
  90. "request": request,
  91. }
  92. try:
  93. headers = {"Content-Type": "application/json"}
  94. response = requests.post(url, json=body, headers=headers)
  95. return json.loads(response.text)
  96. except Exception:
  97. return None
  98. def getUnit2(self, query):
  99. """
  100. NLU 解析 version 2.0
  101. :param query: 用户的指令字符串
  102. :returns: UNIT 解析结果。如果解析失败,返回 None
  103. """
  104. url = (
  105. "https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token="
  106. + self.access_token
  107. )
  108. request = {"query": query, "user_id": str(get_mac())[:32]}
  109. body = {
  110. "log_id": str(uuid.uuid1()),
  111. "version": "2.0",
  112. "service_id": self.service_id,
  113. "session_id": str(uuid.uuid1()),
  114. "request": request,
  115. }
  116. try:
  117. headers = {"Content-Type": "application/json"}
  118. response = requests.post(url, json=body, headers=headers)
  119. return json.loads(response.text)
  120. except Exception:
  121. return None
  122. def getIntent(self, parsed):
  123. """
  124. 提取意图
  125. :param parsed: UNIT 解析结果
  126. :returns: 意图数组
  127. """
  128. if (
  129. parsed
  130. and "result" in parsed
  131. and "response_list" in parsed["result"]
  132. ):
  133. try:
  134. return parsed["result"]["response_list"][0]["schema"]["intent"]
  135. except Exception as e:
  136. logger.warning(e)
  137. return ""
  138. else:
  139. return ""
  140. def hasIntent(self, parsed, intent):
  141. """
  142. 判断是否包含某个意图
  143. :param parsed: UNIT 解析结果
  144. :param intent: 意图的名称
  145. :returns: True: 包含; False: 不包含
  146. """
  147. if (
  148. parsed
  149. and "result" in parsed
  150. and "response_list" in parsed["result"]
  151. ):
  152. response_list = parsed["result"]["response_list"]
  153. for response in response_list:
  154. if (
  155. "schema" in response
  156. and "intent" in response["schema"]
  157. and response["schema"]["intent"] == intent
  158. ):
  159. return True
  160. return False
  161. else:
  162. return False
  163. def getSlots(self, parsed, intent=""):
  164. """
  165. 提取某个意图的所有词槽
  166. :param parsed: UNIT 解析结果
  167. :param intent: 意图的名称
  168. :returns: 词槽列表。你可以通过 name 属性筛选词槽,
  169. 再通过 normalized_word 属性取出相应的值
  170. """
  171. if (
  172. parsed
  173. and "result" in parsed
  174. and "response_list" in parsed["result"]
  175. ):
  176. response_list = parsed["result"]["response_list"]
  177. if intent == "":
  178. try:
  179. return parsed["result"]["response_list"][0]["schema"]["slots"]
  180. except Exception as e:
  181. logger.warning(e)
  182. return []
  183. for response in response_list:
  184. if (
  185. "schema" in response
  186. and "intent" in response["schema"]
  187. and "slots" in response["schema"]
  188. and response["schema"]["intent"] == intent
  189. ):
  190. return response["schema"]["slots"]
  191. return []
  192. else:
  193. return []
  194. def getSlotWords(self, parsed, intent, name):
  195. """
  196. 找出命中某个词槽的内容
  197. :param parsed: UNIT 解析结果
  198. :param intent: 意图的名称
  199. :param name: 词槽名
  200. :returns: 命中该词槽的值的列表。
  201. """
  202. slots = self.getSlots(parsed, intent)
  203. words = []
  204. for slot in slots:
  205. if slot["name"] == name:
  206. words.append(slot["normalized_word"])
  207. return words
  208. def getSayByConfidence(self, parsed):
  209. """
  210. 提取 UNIT 置信度最高的回复文本
  211. :param parsed: UNIT 解析结果
  212. :returns: UNIT 的回复文本
  213. """
  214. if (
  215. parsed
  216. and "result" in parsed
  217. and "response_list" in parsed["result"]
  218. ):
  219. response_list = parsed["result"]["response_list"]
  220. answer = {}
  221. for response in response_list:
  222. if (
  223. "schema" in response
  224. and "intent_confidence" in response["schema"]
  225. and (
  226. not answer
  227. or response["schema"]["intent_confidence"]
  228. > answer["schema"]["intent_confidence"]
  229. )
  230. ):
  231. answer = response
  232. return answer["action_list"][0]["say"]
  233. else:
  234. return ""
  235. def getSay(self, parsed, intent=""):
  236. """
  237. 提取 UNIT 的回复文本
  238. :param parsed: UNIT 解析结果
  239. :param intent: 意图的名称
  240. :returns: UNIT 的回复文本
  241. """
  242. if (
  243. parsed
  244. and "result" in parsed
  245. and "response_list" in parsed["result"]
  246. ):
  247. response_list = parsed["result"]["response_list"]
  248. if intent == "":
  249. try:
  250. return response_list[0]["action_list"][0]["say"]
  251. except Exception as e:
  252. logger.warning(e)
  253. return ""
  254. for response in response_list:
  255. if (
  256. "schema" in response
  257. and "intent" in response["schema"]
  258. and response["schema"]["intent"] == intent
  259. ):
  260. try:
  261. return response["action_list"][0]["say"]
  262. except Exception as e:
  263. logger.warning(e)
  264. return ""
  265. return ""
  266. else:
  267. return ""