No puede seleccionar más de 25 temas. Los temas deben comenzar con una letra o número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.

270 líneas
9.4KB

  1. # encoding:utf-8
  2. import requests, json
  3. from bot.bot import Bot
  4. from bot.session_manager import SessionManager
  5. from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
  6. from bridge.context import ContextType, Context
  7. from bridge.reply import Reply, ReplyType
  8. from common.log import logger
  9. from config import conf
  10. from common import const
  11. import time
  12. import _thread as thread
  13. import datetime
  14. from datetime import datetime
  15. from wsgiref.handlers import format_date_time
  16. from urllib.parse import urlencode
  17. import base64
  18. import ssl
  19. import hashlib
  20. import hmac
  21. import json
  22. from time import mktime
  23. from urllib.parse import urlparse
  24. import websocket
  25. import queue
  26. import threading
  27. import random
  28. # 消息队列 map
  29. queue_map = dict()
  30. # 响应队列 map
  31. reply_map = dict()
  32. class XunFeiBot(Bot):
  33. def __init__(self):
  34. super().__init__()
  35. self.app_id = conf().get("xunfei_app_id")
  36. self.api_key = conf().get("xunfei_api_key")
  37. self.api_secret = conf().get("xunfei_api_secret")
  38. # 默认使用v2.0版本: "generalv2"
  39. # Spark Lite请求地址(spark_url): wss://spark-api.xf-yun.com/v1.1/chat, 对应的domain参数为: "general"
  40. # Spark V2.0请求地址(spark_url): wss://spark-api.xf-yun.com/v2.1/chat, 对应的domain参数为: "generalv2"
  41. # Spark Pro 请求地址(spark_url): wss://spark-api.xf-yun.com/v3.1/chat, 对应的domain参数为: "generalv3"
  42. # Spark Pro-128K请求地址(spark_url): wss://spark-api.xf-yun.com/chat/pro-128k, 对应的domain参数为: "pro-128k"
  43. # Spark Max 请求地址(spark_url): wss://spark-api.xf-yun.com/v3.5/chat, 对应的domain参数为: "generalv3.5"
  44. # Spark4.0 Ultra 请求地址(spark_url): wss://spark-api.xf-yun.com/v4.0/chat, 对应的domain参数为: "4.0Ultra"
  45. # 后续模型更新,对应的参数可以参考官网文档获取:https://www.xfyun.cn/doc/spark/Web.html
  46. self.domain = conf().get("xunfei_domain", "generalv3.5")
  47. self.spark_url = conf().get("xunfei_spark_url", "wss://spark-api.xf-yun.com/v3.5/chat")
  48. self.host = urlparse(self.spark_url).netloc
  49. self.path = urlparse(self.spark_url).path
  50. # 和wenxin使用相同的session机制
  51. self.sessions = SessionManager(BaiduWenxinSession, model=const.XUNFEI)
  52. def reply(self, query, context: Context = None) -> Reply:
  53. if context.type == ContextType.TEXT:
  54. logger.info("[XunFei] query={}".format(query))
  55. session_id = context["session_id"]
  56. request_id = self.gen_request_id(session_id)
  57. reply_map[request_id] = ""
  58. session = self.sessions.session_query(query, session_id)
  59. threading.Thread(target=self.create_web_socket,
  60. args=(session.messages, request_id)).start()
  61. depth = 0
  62. time.sleep(0.1)
  63. t1 = time.time()
  64. usage = {}
  65. while depth <= 300:
  66. try:
  67. data_queue = queue_map.get(request_id)
  68. if not data_queue:
  69. depth += 1
  70. time.sleep(0.1)
  71. continue
  72. data_item = data_queue.get(block=True, timeout=0.1)
  73. if data_item.is_end:
  74. # 请求结束
  75. del queue_map[request_id]
  76. if data_item.reply:
  77. reply_map[request_id] += data_item.reply
  78. usage = data_item.usage
  79. break
  80. reply_map[request_id] += data_item.reply
  81. depth += 1
  82. except Exception as e:
  83. depth += 1
  84. continue
  85. t2 = time.time()
  86. logger.info(
  87. f"[XunFei-API] response={reply_map[request_id]}, time={t2 - t1}s, usage={usage}"
  88. )
  89. self.sessions.session_reply(reply_map[request_id], session_id,
  90. usage.get("total_tokens"))
  91. reply = Reply(ReplyType.TEXT, reply_map[request_id])
  92. del reply_map[request_id]
  93. return reply
  94. else:
  95. reply = Reply(ReplyType.ERROR,
  96. "Bot不支持处理{}类型的消息".format(context.type))
  97. return reply
  98. def create_web_socket(self, prompt, session_id, temperature=0.5):
  99. logger.info(f"[XunFei] start connect, prompt={prompt}")
  100. websocket.enableTrace(False)
  101. wsUrl = self.create_url()
  102. ws = websocket.WebSocketApp(wsUrl,
  103. on_message=on_message,
  104. on_error=on_error,
  105. on_close=on_close,
  106. on_open=on_open)
  107. data_queue = queue.Queue(1000)
  108. queue_map[session_id] = data_queue
  109. ws.appid = self.app_id
  110. ws.question = prompt
  111. ws.domain = self.domain
  112. ws.session_id = session_id
  113. ws.temperature = temperature
  114. ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
  115. def gen_request_id(self, session_id: str):
  116. return session_id + "_" + str(int(time.time())) + "" + str(
  117. random.randint(0, 100))
  118. # 生成url
  119. def create_url(self):
  120. # 生成RFC1123格式的时间戳
  121. now = datetime.now()
  122. date = format_date_time(mktime(now.timetuple()))
  123. # 拼接字符串
  124. signature_origin = "host: " + self.host + "\n"
  125. signature_origin += "date: " + date + "\n"
  126. signature_origin += "GET " + self.path + " HTTP/1.1"
  127. # 进行hmac-sha256进行加密
  128. signature_sha = hmac.new(self.api_secret.encode('utf-8'),
  129. signature_origin.encode('utf-8'),
  130. digestmod=hashlib.sha256).digest()
  131. signature_sha_base64 = base64.b64encode(signature_sha).decode(
  132. encoding='utf-8')
  133. authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", ' \
  134. f'signature="{signature_sha_base64}"'
  135. authorization = base64.b64encode(
  136. authorization_origin.encode('utf-8')).decode(encoding='utf-8')
  137. # 将请求的鉴权参数组合为字典
  138. v = {"authorization": authorization, "date": date, "host": self.host}
  139. # 拼接鉴权参数,生成url
  140. url = self.spark_url + '?' + urlencode(v)
  141. # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
  142. return url
  143. def gen_params(self, appid, domain, question):
  144. """
  145. 通过appid和用户的提问来生成请参数
  146. """
  147. data = {
  148. "header": {
  149. "app_id": appid,
  150. "uid": "1234"
  151. },
  152. "parameter": {
  153. "chat": {
  154. "domain": domain,
  155. "random_threshold": 0.5,
  156. "max_tokens": 2048,
  157. "auditing": "default"
  158. }
  159. },
  160. "payload": {
  161. "message": {
  162. "text": question
  163. }
  164. }
  165. }
  166. return data
  167. class ReplyItem:
  168. def __init__(self, reply, usage=None, is_end=False):
  169. self.is_end = is_end
  170. self.reply = reply
  171. self.usage = usage
  172. # 收到websocket错误的处理
  173. def on_error(ws, error):
  174. logger.error(f"[XunFei] error: {str(error)}")
  175. # 收到websocket关闭的处理
  176. def on_close(ws, one, two):
  177. data_queue = queue_map.get(ws.session_id)
  178. data_queue.put("END")
  179. # 收到websocket连接建立的处理
  180. def on_open(ws):
  181. logger.info(f"[XunFei] Start websocket, session_id={ws.session_id}")
  182. thread.start_new_thread(run, (ws, ))
  183. def run(ws, *args):
  184. data = json.dumps(
  185. gen_params(appid=ws.appid,
  186. domain=ws.domain,
  187. question=ws.question,
  188. temperature=ws.temperature))
  189. ws.send(data)
  190. # Websocket 操作
  191. # 收到websocket消息的处理
  192. def on_message(ws, message):
  193. data = json.loads(message)
  194. code = data['header']['code']
  195. if code != 0:
  196. logger.error(f'请求错误: {code}, {data}')
  197. ws.close()
  198. else:
  199. choices = data["payload"]["choices"]
  200. status = choices["status"]
  201. content = choices["text"][0]["content"]
  202. data_queue = queue_map.get(ws.session_id)
  203. if not data_queue:
  204. logger.error(
  205. f"[XunFei] can't find data queue, session_id={ws.session_id}")
  206. return
  207. reply_item = ReplyItem(content)
  208. if status == 2:
  209. usage = data["payload"].get("usage")
  210. reply_item = ReplyItem(content, usage)
  211. reply_item.is_end = True
  212. ws.close()
  213. data_queue.put(reply_item)
  214. def gen_params(appid, domain, question, temperature=0.5):
  215. """
  216. 通过appid和用户的提问来生成请参数
  217. """
  218. data = {
  219. "header": {
  220. "app_id": appid,
  221. "uid": "1234"
  222. },
  223. "parameter": {
  224. "chat": {
  225. "domain": domain,
  226. "temperature": temperature,
  227. "random_threshold": 0.5,
  228. "max_tokens": 2048,
  229. "auditing": "default"
  230. }
  231. },
  232. "payload": {
  233. "message": {
  234. "text": question
  235. }
  236. }
  237. }
  238. return data