Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

278 Zeilen
12KB

  1. # access LinkAI knowledge base platform
  2. # docs: https://link-ai.tech/platform/link-app/wechat
  3. import time
  4. import requests
  5. from bot.bot import Bot
  6. from bot.chatgpt.chat_gpt_session import ChatGPTSession
  7. from bot.session_manager import SessionManager
  8. from bridge.context import Context, ContextType
  9. from bridge.reply import Reply, ReplyType
  10. from common.log import logger
  11. from config import conf, pconf
  12. import threading
  13. class LinkAIBot(Bot):
  14. # authentication failed
  15. AUTH_FAILED_CODE = 401
  16. NO_QUOTA_CODE = 406
  17. def __init__(self):
  18. super().__init__()
  19. self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
  20. self.args = {}
  21. def reply(self, query, context: Context = None) -> Reply:
  22. if context.type == ContextType.TEXT:
  23. return self._chat(query, context)
  24. elif context.type == ContextType.IMAGE_CREATE:
  25. ok, res = self.create_img(query, 0)
  26. if ok:
  27. reply = Reply(ReplyType.IMAGE_URL, res)
  28. else:
  29. reply = Reply(ReplyType.ERROR, res)
  30. return reply
  31. else:
  32. reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
  33. return reply
  34. def _chat(self, query, context, retry_count=0) -> Reply:
  35. """
  36. 发起对话请求
  37. :param query: 请求提示词
  38. :param context: 对话上下文
  39. :param retry_count: 当前递归重试次数
  40. :return: 回复
  41. """
  42. if retry_count > 2:
  43. # exit from retry 2 times
  44. logger.warn("[LINKAI] failed after maximum number of retry times")
  45. return Reply(ReplyType.TEXT, "请再问我一次吧")
  46. try:
  47. # load config
  48. if context.get("generate_breaked_by"):
  49. logger.info(f"[LINKAI] won't set appcode because a plugin ({context['generate_breaked_by']}) affected the context")
  50. app_code = None
  51. else:
  52. app_code = context.kwargs.get("app_code") or conf().get("linkai_app_code")
  53. linkai_api_key = conf().get("linkai_api_key")
  54. session_id = context["session_id"]
  55. session = self.sessions.session_query(query, session_id)
  56. model = conf().get("model")
  57. # remove system message
  58. if session.messages[0].get("role") == "system":
  59. if app_code or model == "wenxin":
  60. session.messages.pop(0)
  61. body = {
  62. "app_code": app_code,
  63. "messages": session.messages,
  64. "model": model, # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
  65. "temperature": conf().get("temperature"),
  66. "top_p": conf().get("top_p", 1),
  67. "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
  68. "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
  69. }
  70. file_id = context.kwargs.get("file_id")
  71. if file_id:
  72. body["file_id"] = file_id
  73. logger.info(f"[LINKAI] query={query}, app_code={app_code}, mode={body.get('model')}, file_id={file_id}")
  74. headers = {"Authorization": "Bearer " + linkai_api_key}
  75. # do http request
  76. base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
  77. res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers,
  78. timeout=conf().get("request_timeout", 180))
  79. if res.status_code == 200:
  80. # execute success
  81. response = res.json()
  82. reply_content = response["choices"][0]["message"]["content"]
  83. total_tokens = response["usage"]["total_tokens"]
  84. logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
  85. self.sessions.session_reply(reply_content, session_id, total_tokens)
  86. agent_suffix = self._fetch_agent_suffix(response)
  87. if agent_suffix:
  88. reply_content += agent_suffix
  89. if not agent_suffix:
  90. knowledge_suffix = self._fetch_knowledge_search_suffix(response)
  91. if knowledge_suffix:
  92. reply_content += knowledge_suffix
  93. # image process
  94. if response["choices"][0].get("img_urls"):
  95. thread = threading.Thread(target=self._send_image, args=(context.get("channel"), context, response["choices"][0].get("img_urls")))
  96. thread.start()
  97. return Reply(ReplyType.TEXT, reply_content)
  98. else:
  99. response = res.json()
  100. error = response.get("error")
  101. logger.error(f"[LINKAI] chat failed, status_code={res.status_code}, "
  102. f"msg={error.get('message')}, type={error.get('type')}")
  103. if res.status_code >= 500:
  104. # server error, need retry
  105. time.sleep(2)
  106. logger.warn(f"[LINKAI] do retry, times={retry_count}")
  107. return self._chat(query, context, retry_count + 1)
  108. return Reply(ReplyType.TEXT, "提问太快啦,请休息一下再问我吧")
  109. except Exception as e:
  110. logger.exception(e)
  111. # retry
  112. time.sleep(2)
  113. logger.warn(f"[LINKAI] do retry, times={retry_count}")
  114. return self._chat(query, context, retry_count + 1)
  115. def reply_text(self, session: ChatGPTSession, app_code="", retry_count=0) -> dict:
  116. if retry_count >= 2:
  117. # exit from retry 2 times
  118. logger.warn("[LINKAI] failed after maximum number of retry times")
  119. return {
  120. "total_tokens": 0,
  121. "completion_tokens": 0,
  122. "content": "请再问我一次吧"
  123. }
  124. try:
  125. body = {
  126. "app_code": app_code,
  127. "messages": session.messages,
  128. "model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
  129. "temperature": conf().get("temperature"),
  130. "top_p": conf().get("top_p", 1),
  131. "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
  132. "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
  133. }
  134. if self.args.get("max_tokens"):
  135. body["max_tokens"] = self.args.get("max_tokens")
  136. headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
  137. # do http request
  138. base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
  139. res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers,
  140. timeout=conf().get("request_timeout", 180))
  141. if res.status_code == 200:
  142. # execute success
  143. response = res.json()
  144. reply_content = response["choices"][0]["message"]["content"]
  145. total_tokens = response["usage"]["total_tokens"]
  146. logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
  147. return {
  148. "total_tokens": total_tokens,
  149. "completion_tokens": response["usage"]["completion_tokens"],
  150. "content": reply_content,
  151. }
  152. else:
  153. response = res.json()
  154. error = response.get("error")
  155. logger.error(f"[LINKAI] chat failed, status_code={res.status_code}, "
  156. f"msg={error.get('message')}, type={error.get('type')}")
  157. if res.status_code >= 500:
  158. # server error, need retry
  159. time.sleep(2)
  160. logger.warn(f"[LINKAI] do retry, times={retry_count}")
  161. return self.reply_text(session, app_code, retry_count + 1)
  162. return {
  163. "total_tokens": 0,
  164. "completion_tokens": 0,
  165. "content": "提问太快啦,请休息一下再问我吧"
  166. }
  167. except Exception as e:
  168. logger.exception(e)
  169. # retry
  170. time.sleep(2)
  171. logger.warn(f"[LINKAI] do retry, times={retry_count}")
  172. return self.reply_text(session, app_code, retry_count + 1)
  173. def create_img(self, query, retry_count=0, api_key=None):
  174. try:
  175. logger.info("[LinkImage] image_query={}".format(query))
  176. headers = {
  177. "Content-Type": "application/json",
  178. "Authorization": f"Bearer {conf().get('linkai_api_key')}"
  179. }
  180. data = {
  181. "prompt": query,
  182. "n": 1,
  183. "model": conf().get("text_to_image") or "dall-e-2",
  184. "response_format": "url",
  185. "img_proxy": conf().get("image_proxy")
  186. }
  187. url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/images/generations"
  188. res = requests.post(url, headers=headers, json=data, timeout=(5, 90))
  189. t2 = time.time()
  190. image_url = res.json()["data"][0]["url"]
  191. logger.info("[OPEN_AI] image_url={}".format(image_url))
  192. return True, image_url
  193. except Exception as e:
  194. logger.error(format(e))
  195. return False, "画图出现问题,请休息一下再问我吧"
  196. def _fetch_knowledge_search_suffix(self, response) -> str:
  197. try:
  198. if response.get("knowledge_base"):
  199. search_hit = response.get("knowledge_base").get("search_hit")
  200. first_similarity = response.get("knowledge_base").get("first_similarity")
  201. logger.info(f"[LINKAI] knowledge base, search_hit={search_hit}, first_similarity={first_similarity}")
  202. plugin_config = pconf("linkai")
  203. if plugin_config and plugin_config.get("knowledge_base") and plugin_config.get("knowledge_base").get("search_miss_text_enabled"):
  204. search_miss_similarity = plugin_config.get("knowledge_base").get("search_miss_similarity")
  205. search_miss_text = plugin_config.get("knowledge_base").get("search_miss_suffix")
  206. if not search_hit:
  207. return search_miss_text
  208. if search_miss_similarity and float(search_miss_similarity) > first_similarity:
  209. return search_miss_text
  210. except Exception as e:
  211. logger.exception(e)
  212. def _fetch_agent_suffix(self, response):
  213. try:
  214. plugin_list = []
  215. logger.debug(f"[LinkAgent] res={response}")
  216. if response.get("agent") and response.get("agent").get("chain") and response.get("agent").get("need_show_plugin"):
  217. chain = response.get("agent").get("chain")
  218. suffix = "\n\n- - - - - - - - - - - -"
  219. i = 0
  220. for turn in chain:
  221. plugin_name = turn.get('plugin_name')
  222. suffix += "\n"
  223. need_show_thought = response.get("agent").get("need_show_thought")
  224. if turn.get("thought") and plugin_name and need_show_thought:
  225. suffix += f"{turn.get('thought')}\n"
  226. if plugin_name:
  227. plugin_list.append(turn.get('plugin_name'))
  228. suffix += f"{turn.get('plugin_icon')} {turn.get('plugin_name')}"
  229. if turn.get('plugin_input'):
  230. suffix += f":{turn.get('plugin_input')}"
  231. if i < len(chain) - 1:
  232. suffix += "\n"
  233. i += 1
  234. logger.info(f"[LinkAgent] use plugins: {plugin_list}")
  235. return suffix
  236. except Exception as e:
  237. logger.exception(e)
  238. def _send_image(self, channel, context, image_urls):
  239. if not image_urls:
  240. return
  241. try:
  242. for url in image_urls:
  243. reply = Reply(ReplyType.IMAGE_URL, url)
  244. channel.send(reply, context)
  245. except Exception as e:
  246. logger.error(e)