Nie możesz wybrać więcej niż 25 tematów. Tematy muszą się zaczynać od litery lub cyfry, mogą zawierać myślniki ('-') i mogą mieć do 35 znaków.

259 lines
9.1KB

  1. # encoding:utf-8
  2. import json
  3. import os
  4. import uuid
  5. from uuid import getnode as get_mac
  6. import requests
  7. import plugins
  8. from bridge.context import ContextType
  9. from bridge.reply import Reply, ReplyType
  10. from common.log import logger
  11. from plugins import *
  12. """利用百度UNIT实现智能对话
  13. 如果命中意图,返回意图对应的回复,否则返回继续交付给下个插件处理
  14. """
  15. @plugins.register(
  16. name="BDunit",
  17. desire_priority=0,
  18. hidden=True,
  19. desc="Baidu unit bot system",
  20. version="0.1",
  21. author="jackson",
  22. )
  23. class BDunit(Plugin):
  24. def __init__(self):
  25. super().__init__()
  26. try:
  27. curdir = os.path.dirname(__file__)
  28. config_path = os.path.join(curdir, "config.json")
  29. conf = super().load_config()
  30. if not conf:
  31. if not os.path.exists(config_path):
  32. raise Exception("config.json not found")
  33. else:
  34. with open(config_path, "r") as f:
  35. conf = json.load(f)
  36. self.service_id = conf["service_id"]
  37. self.api_key = conf["api_key"]
  38. self.secret_key = conf["secret_key"]
  39. self.access_token = self.get_token()
  40. self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
  41. logger.info("[BDunit] inited")
  42. except Exception as e:
  43. logger.warn("[BDunit] init failed, ignore ")
  44. raise e
  45. def on_handle_context(self, e_context: EventContext):
  46. if e_context["context"].type != ContextType.TEXT:
  47. return
  48. content = e_context["context"].content
  49. logger.debug("[BDunit] on_handle_context. content: %s" % content)
  50. parsed = self.getUnit2(content)
  51. intent = self.getIntent(parsed)
  52. if intent: # 找到意图
  53. logger.debug("[BDunit] Baidu_AI Intent= %s", intent)
  54. reply = Reply()
  55. reply.type = ReplyType.TEXT
  56. reply.content = self.getSay(parsed)
  57. e_context["reply"] = reply
  58. e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
  59. else:
  60. e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
  61. def get_help_text(self, **kwargs):
  62. help_text = "本插件会处理询问实时日期时间,天气,数学运算等问题,这些技能由您的百度智能对话UNIT决定\n"
  63. return help_text
  64. def get_token(self):
  65. """获取访问百度UUNIT 的access_token
  66. #param api_key: UNIT apk_key
  67. #param secret_key: UNIT secret_key
  68. Returns:
  69. string: access_token
  70. """
  71. url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(self.api_key, self.secret_key)
  72. payload = ""
  73. headers = {"Content-Type": "application/json", "Accept": "application/json"}
  74. response = requests.request("POST", url, headers=headers, data=payload)
  75. # print(response.text)
  76. return response.json()["access_token"]
  77. def getUnit(self, query):
  78. """
  79. NLU 解析version 3.0
  80. :param query: 用户的指令字符串
  81. :returns: UNIT 解析结果。如果解析失败,返回 None
  82. """
  83. url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + self.access_token
  84. request = {
  85. "query": query,
  86. "user_id": str(get_mac())[:32],
  87. "terminal_id": "88888",
  88. }
  89. body = {
  90. "log_id": str(uuid.uuid1()),
  91. "version": "3.0",
  92. "service_id": self.service_id,
  93. "session_id": str(uuid.uuid1()),
  94. "request": request,
  95. }
  96. try:
  97. headers = {"Content-Type": "application/json"}
  98. response = requests.post(url, json=body, headers=headers)
  99. return json.loads(response.text)
  100. except Exception:
  101. return None
  102. def getUnit2(self, query):
  103. """
  104. NLU 解析 version 2.0
  105. :param query: 用户的指令字符串
  106. :returns: UNIT 解析结果。如果解析失败,返回 None
  107. """
  108. url = "https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token=" + self.access_token
  109. request = {"query": query, "user_id": str(get_mac())[:32]}
  110. body = {
  111. "log_id": str(uuid.uuid1()),
  112. "version": "2.0",
  113. "service_id": self.service_id,
  114. "session_id": str(uuid.uuid1()),
  115. "request": request,
  116. }
  117. try:
  118. headers = {"Content-Type": "application/json"}
  119. response = requests.post(url, json=body, headers=headers)
  120. return json.loads(response.text)
  121. except Exception:
  122. return None
  123. def getIntent(self, parsed):
  124. """
  125. 提取意图
  126. :param parsed: UNIT 解析结果
  127. :returns: 意图数组
  128. """
  129. if parsed and "result" in parsed and "response_list" in parsed["result"]:
  130. try:
  131. return parsed["result"]["response_list"][0]["schema"]["intent"]
  132. except Exception as e:
  133. logger.warning(e)
  134. return ""
  135. else:
  136. return ""
  137. def hasIntent(self, parsed, intent):
  138. """
  139. 判断是否包含某个意图
  140. :param parsed: UNIT 解析结果
  141. :param intent: 意图的名称
  142. :returns: True: 包含; False: 不包含
  143. """
  144. if parsed and "result" in parsed and "response_list" in parsed["result"]:
  145. response_list = parsed["result"]["response_list"]
  146. for response in response_list:
  147. if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
  148. return True
  149. return False
  150. else:
  151. return False
  152. def getSlots(self, parsed, intent=""):
  153. """
  154. 提取某个意图的所有词槽
  155. :param parsed: UNIT 解析结果
  156. :param intent: 意图的名称
  157. :returns: 词槽列表。你可以通过 name 属性筛选词槽,
  158. 再通过 normalized_word 属性取出相应的值
  159. """
  160. if parsed and "result" in parsed and "response_list" in parsed["result"]:
  161. response_list = parsed["result"]["response_list"]
  162. if intent == "":
  163. try:
  164. return parsed["result"]["response_list"][0]["schema"]["slots"]
  165. except Exception as e:
  166. logger.warning(e)
  167. return []
  168. for response in response_list:
  169. if "schema" in response and "intent" in response["schema"] and "slots" in response["schema"] and response["schema"]["intent"] == intent:
  170. return response["schema"]["slots"]
  171. return []
  172. else:
  173. return []
  174. def getSlotWords(self, parsed, intent, name):
  175. """
  176. 找出命中某个词槽的内容
  177. :param parsed: UNIT 解析结果
  178. :param intent: 意图的名称
  179. :param name: 词槽名
  180. :returns: 命中该词槽的值的列表。
  181. """
  182. slots = self.getSlots(parsed, intent)
  183. words = []
  184. for slot in slots:
  185. if slot["name"] == name:
  186. words.append(slot["normalized_word"])
  187. return words
  188. def getSayByConfidence(self, parsed):
  189. """
  190. 提取 UNIT 置信度最高的回复文本
  191. :param parsed: UNIT 解析结果
  192. :returns: UNIT 的回复文本
  193. """
  194. if parsed and "result" in parsed and "response_list" in parsed["result"]:
  195. response_list = parsed["result"]["response_list"]
  196. answer = {}
  197. for response in response_list:
  198. if (
  199. "schema" in response
  200. and "intent_confidence" in response["schema"]
  201. and (not answer or response["schema"]["intent_confidence"] > answer["schema"]["intent_confidence"])
  202. ):
  203. answer = response
  204. return answer["action_list"][0]["say"]
  205. else:
  206. return ""
  207. def getSay(self, parsed, intent=""):
  208. """
  209. 提取 UNIT 的回复文本
  210. :param parsed: UNIT 解析结果
  211. :param intent: 意图的名称
  212. :returns: UNIT 的回复文本
  213. """
  214. if parsed and "result" in parsed and "response_list" in parsed["result"]:
  215. response_list = parsed["result"]["response_list"]
  216. if intent == "":
  217. try:
  218. return response_list[0]["action_list"][0]["say"]
  219. except Exception as e:
  220. logger.warning(e)
  221. return ""
  222. for response in response_list:
  223. if "schema" in response and "intent" in response["schema"] and response["schema"]["intent"] == intent:
  224. try:
  225. return response["action_list"][0]["say"]
  226. except Exception as e:
  227. logger.warning(e)
  228. return ""
  229. return ""
  230. else:
  231. return ""