Browse Source

Merge branch 'master' into master

master
zhayujie GitHub 9 months ago
parent
commit
eda3ba92fd
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
13 changed files with 310 additions and 46 deletions
  1. +19
    -14
      app.py
  2. +1
    -0
      bot/bot_factory.py
  3. +6
    -2
      bot/linkai/link_ai_bot.py
  4. +155
    -0
      bot/zhipu/chat_glm_bot.py
  5. +48
    -0
      bot/zhipu/chat_glm_session.py
  6. +4
    -2
      channel/chat_channel.py
  7. +38
    -25
      channel/wechat/wechat_channel.py
  8. +1
    -0
      common/const.py
  9. +26
    -1
      common/linkai_client.py
  10. +0
    -1
      config.py
  11. +8
    -0
      plugins/godcmd/godcmd.py
  12. +3
    -0
      plugins/plugin.py
  13. +1
    -1
      requirements-optional.txt

+ 19
- 14
app.py View File

@@ -3,6 +3,7 @@
import os import os
import signal import signal
import sys import sys
import time


from channel import channel_factory from channel import channel_factory
from common import const from common import const
@@ -24,6 +25,21 @@ def sigterm_handler_wrap(_signo):
signal.signal(_signo, func) signal.signal(_signo, func)




def start_channel(channel_name: str):
channel = channel_factory.create_channel(channel_name)
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework",
const.FEISHU, const.DINGTALK]:
PluginManager().load_plugins()

if conf().get("use_linkai"):
try:
from common import linkai_client
threading.Thread(target=linkai_client.start, args=(channel,)).start()
except Exception as e:
pass
channel.startup()


def run(): def run():
try: try:
# load config # load config
@@ -41,22 +57,11 @@ def run():


if channel_name == "wxy": if channel_name == "wxy":
os.environ["WECHATY_LOG"] = "warn" os.environ["WECHATY_LOG"] = "warn"
# os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'

channel = channel_factory.create_channel(channel_name)
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework", const.FEISHU,const.DINGTALK]:
PluginManager().load_plugins()

if conf().get("use_linkai"):
try:
from common import linkai_client
threading.Thread(target=linkai_client.start, args=(channel, )).start()
except Exception as e:
pass


# startup channel
channel.startup()
start_channel(channel_name)


while True:
time.sleep(1)
except Exception as e: except Exception as e:
logger.error("App startup failed!") logger.error("App startup failed!")
logger.exception(e) logger.exception(e)


+ 1
- 0
bot/bot_factory.py View File

@@ -56,4 +56,5 @@ def create_bot(bot_type):
from bot.zhipuai.zhipuai_bot import ZHIPUAIBot from bot.zhipuai.zhipuai_bot import ZHIPUAIBot
return ZHIPUAIBot() return ZHIPUAIBot()



raise RuntimeError raise RuntimeError

+ 6
- 2
bot/linkai/link_ai_bot.py View File

@@ -107,7 +107,11 @@ class LinkAIBot(Bot):
body["group_name"] = context.kwargs.get("msg").from_user_nickname body["group_name"] = context.kwargs.get("msg").from_user_nickname
body["sender_name"] = context.kwargs.get("msg").actual_user_nickname body["sender_name"] = context.kwargs.get("msg").actual_user_nickname
else: else:
body["sender_name"] = context.kwargs.get("msg").from_user_nickname
if body.get("channel_type") in ["wechatcom_app"]:
body["sender_name"] = context.kwargs.get("msg").from_user_id
else:
body["sender_name"] = context.kwargs.get("msg").from_user_nickname

except Exception as e: except Exception as e:
pass pass
file_id = context.kwargs.get("file_id") file_id = context.kwargs.get("file_id")
@@ -396,7 +400,7 @@ class LinkAIBot(Bot):
i += 1 i += 1
if url.endswith(".mp4"): if url.endswith(".mp4"):
reply_type = ReplyType.VIDEO_URL reply_type = ReplyType.VIDEO_URL
elif url.endswith(".pdf") or url.endswith(".doc") or url.endswith(".docx"):
elif url.endswith(".pdf") or url.endswith(".doc") or url.endswith(".docx") or url.endswith(".csv"):
reply_type = ReplyType.FILE reply_type = ReplyType.FILE
url = _download_file(url) url = _download_file(url)
if not url: if not url:


+ 155
- 0
bot/zhipu/chat_glm_bot.py View File

@@ -0,0 +1,155 @@
# encoding:utf-8

import time

import openai
import openai.error
import requests

from bot.bot import Bot
from bot.zhipu.chat_glm_session import ChatGLMSession
from bot.openai.open_ai_image import OpenAIImage
from bot.session_manager import SessionManager
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
# from common.token_bucket import TokenBucket
from config import conf, load_config
from zhipuai import ZhipuAI


# ZhipuAI对话模型API
class ChatGLMBot(Bot):
def __init__(self):
super().__init__()
# set the default api_key
self.api_key = conf().get("zhipu_ai_api_key")
if conf().get("zhipu_ai_api_base"):
openai.api_base = conf().get("zhipu_ai_api_base")
# if conf().get("rate_limit_chatgpt"):
# self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))

self.sessions = SessionManager(ChatGLMSession, model=conf().get("model") or "chatglm")
self.args = {
"model": "glm-4", # 对话模型的名称
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
# "max_tokens":4096, # 回复最大的字符数
"top_p": conf().get("top_p", 0.7),
# "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
# "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
# "request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
# "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
}
self.client = ZhipuAI(api_key=self.api_key)

def reply(self, query, context=None):
# acquire reply content
if context.type == ContextType.TEXT:
logger.info("[CHATGLM] query={}".format(query))

session_id = context["session_id"]
reply = None
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
if query in clear_memory_commands:
self.sessions.clear_session(session_id)
reply = Reply(ReplyType.INFO, "记忆已清除")
elif query == "#清除所有":
self.sessions.clear_all_session()
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
elif query == "#更新配置":
load_config()
reply = Reply(ReplyType.INFO, "配置已更新")
if reply:
return reply
session = self.sessions.session_query(query, session_id)
logger.debug("[CHATGLM] session query={}".format(session.messages))

api_key = context.get("openai_api_key") or openai.api_key
model = context.get("gpt_model")
new_args = None
if model:
new_args = self.args.copy()
new_args["model"] = model
# if context.get('stream'):
# # reply in stream
# return self.reply_text_stream(query, new_query, session_id)

reply_content = self.reply_text(session, api_key, args=new_args)
logger.debug(
"[CHATGLM] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
session.messages,
session_id,
reply_content["content"],
reply_content["completion_tokens"],
)
)
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
reply = Reply(ReplyType.ERROR, reply_content["content"])
elif reply_content["completion_tokens"] > 0:
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
reply = Reply(ReplyType.TEXT, reply_content["content"])
else:
reply = Reply(ReplyType.ERROR, reply_content["content"])
logger.debug("[CHATGLM] reply {} used 0 tokens.".format(reply_content))
return reply
else:
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
return reply

def reply_text(self, session: ChatGLMSession, api_key=None, args=None, retry_count=0) -> dict:
"""
call openai's ChatCompletion to get the answer
:param session: a conversation session
:param session_id: session id
:param retry_count: retry count
:return: {}
"""
try:
# if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
# raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
# if api_key == None, the default openai.api_key will be used
if args is None:
args = self.args
# response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args)
response = self.client.chat.completions.create(messages=session.messages, **args)
# logger.debug("[CHATGLM] response={}".format(response))
# logger.info("[CHATGLM] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
return {
"total_tokens": response.usage.total_tokens,
"completion_tokens": response.usage.completion_tokens,
"content": response.choices[0].message.content,
}
except Exception as e:
need_retry = retry_count < 2
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
if isinstance(e, openai.error.RateLimitError):
logger.warn("[CHATGLM] RateLimitError: {}".format(e))
result["content"] = "提问太快啦,请休息一下再问我吧"
if need_retry:
time.sleep(20)
elif isinstance(e, openai.error.Timeout):
logger.warn("[CHATGLM] Timeout: {}".format(e))
result["content"] = "我没有收到你的消息"
if need_retry:
time.sleep(5)
elif isinstance(e, openai.error.APIError):
logger.warn("[CHATGLM] Bad Gateway: {}".format(e))
result["content"] = "请再问我一次"
if need_retry:
time.sleep(10)
elif isinstance(e, openai.error.APIConnectionError):
logger.warn("[CHATGLM] APIConnectionError: {}".format(e))
result["content"] = "我连接不到你的网络"
if need_retry:
time.sleep(5)
else:
logger.exception("[CHATGLM] Exception: {}".format(e), e)
need_retry = False
self.sessions.clear_session(session.session_id)

if need_retry:
logger.warn("[CHATGLM] 第{}次重试".format(retry_count + 1))
return self.reply_text(session, api_key, args, retry_count + 1)
else:
return result


+ 48
- 0
bot/zhipu/chat_glm_session.py View File

@@ -0,0 +1,48 @@
from bot.session_manager import Session
from common.log import logger

class ChatGLMSession(Session):
def __init__(self, session_id, system_prompt=None, model="glm-4"):
super().__init__(session_id, system_prompt)
self.model = model
self.reset()

def discard_exceeding(self, max_tokens, cur_tokens=None):
precise = True
try:
cur_tokens = self.calc_tokens()
except Exception as e:
precise = False
if cur_tokens is None:
raise e
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
while cur_tokens > max_tokens:
if len(self.messages) > 2:
self.messages.pop(1)
elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
self.messages.pop(1)
if precise:
cur_tokens = self.calc_tokens()
else:
cur_tokens = cur_tokens - max_tokens
break
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
break
else:
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
break
if precise:
cur_tokens = self.calc_tokens()
else:
cur_tokens = cur_tokens - max_tokens
return cur_tokens

def calc_tokens(self):
return num_tokens_from_messages(self.messages, self.model)

def num_tokens_from_messages(messages, model):
tokens = 0
for msg in messages:
tokens += len(msg["content"])
return tokens

+ 4
- 2
channel/chat_channel.py View File

@@ -4,6 +4,7 @@ import threading
import time import time
from asyncio import CancelledError from asyncio import CancelledError
from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor
from concurrent import futures


from bridge.context import * from bridge.context import *
from bridge.reply import * from bridge.reply import *
@@ -17,6 +18,8 @@ try:
except Exception as e: except Exception as e:
pass pass


handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池



# 抽象类, 它包含了与消息通道无关的通用处理逻辑 # 抽象类, 它包含了与消息通道无关的通用处理逻辑
class ChatChannel(Channel): class ChatChannel(Channel):
@@ -25,7 +28,6 @@ class ChatChannel(Channel):
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消 futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理 sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
lock = threading.Lock() # 用于控制对sessions的访问 lock = threading.Lock() # 用于控制对sessions的访问
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池


def __init__(self): def __init__(self):
_thread = threading.Thread(target=self.consume) _thread = threading.Thread(target=self.consume)
@@ -339,7 +341,7 @@ class ChatChannel(Channel):
if not context_queue.empty(): if not context_queue.empty():
context = context_queue.get() context = context_queue.get()
logger.debug("[WX] consume context: {}".format(context)) logger.debug("[WX] consume context: {}".format(context))
future: Future = self.handler_pool.submit(self._handle, context)
future: Future = handler_pool.submit(self._handle, context)
future.add_done_callback(self._thread_pool_callback(session_id, context=context)) future.add_done_callback(self._thread_pool_callback(session_id, context=context))
if session_id not in self.futures: if session_id not in self.futures:
self.futures[session_id] = [] self.futures[session_id] = []


+ 38
- 25
channel/wechat/wechat_channel.py View File

@@ -15,6 +15,7 @@ import requests
from bridge.context import * from bridge.context import *
from bridge.reply import * from bridge.reply import *
from channel.chat_channel import ChatChannel from channel.chat_channel import ChatChannel
from channel import chat_channel
from channel.wechat.wechat_message import * from channel.wechat.wechat_message import *
from common.expired_dict import ExpiredDict from common.expired_dict import ExpiredDict
from common.log import logger from common.log import logger
@@ -112,30 +113,39 @@ class WechatChannel(ChatChannel):
self.auto_login_times = 0 self.auto_login_times = 0


def startup(self): def startup(self):
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
# login by scan QRCode
hotReload = conf().get("hot_reload", False)
status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
itchat.auto_login(
enableCmdQR=2,
hotReload=hotReload,
statusStorageDir=status_path,
qrCallback=qrCallback,
exitCallback=self.exitCallback,
loginCallback=self.loginCallback
)
self.user_id = itchat.instance.storageClass.userName
self.name = itchat.instance.storageClass.nickName
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
# start message listener
itchat.run()
try:
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
# login by scan QRCode
hotReload = conf().get("hot_reload", False)
status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
itchat.auto_login(
enableCmdQR=2,
hotReload=hotReload,
statusStorageDir=status_path,
qrCallback=qrCallback,
exitCallback=self.exitCallback,
loginCallback=self.loginCallback
)
self.user_id = itchat.instance.storageClass.userName
self.name = itchat.instance.storageClass.nickName
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
# start message listener
itchat.run()
except Exception as e:
logger.error(e)


def exitCallback(self): def exitCallback(self):
_send_logout()
time.sleep(3)
self.auto_login_times += 1
if self.auto_login_times < 100:
self.startup()
try:
from common.linkai_client import chat_client
if chat_client.client_id and conf().get("use_linkai"):
_send_logout()
time.sleep(2)
self.auto_login_times += 1
if self.auto_login_times < 100:
chat_channel.handler_pool._shutdown = False
self.startup()
except Exception as e:
pass


def loginCallback(self): def loginCallback(self):
logger.debug("Login success") logger.debug("Login success")
@@ -251,20 +261,23 @@ class WechatChannel(ChatChannel):
def _send_login_success(): def _send_login_success():
try: try:
from common.linkai_client import chat_client from common.linkai_client import chat_client
chat_client.send_login_success()
if chat_client.client_id:
chat_client.send_login_success()
except Exception as e: except Exception as e:
pass pass


def _send_logout(): def _send_logout():
try: try:
from common.linkai_client import chat_client from common.linkai_client import chat_client
chat_client.send_logout()
if chat_client.client_id:
chat_client.send_logout()
except Exception as e: except Exception as e:
pass pass


def _send_qr_code(qrcode_list: list): def _send_qr_code(qrcode_list: list):
try: try:
from common.linkai_client import chat_client from common.linkai_client import chat_client
chat_client.send_qrcode(qrcode_list)
if chat_client.client_id:
chat_client.send_qrcode(qrcode_list)
except Exception as e: except Exception as e:
pass pass

+ 1
- 0
common/const.py View File

@@ -10,6 +10,7 @@ QWEN = "qwen"
GEMINI = "gemini" GEMINI = "gemini"
ZHIPU_AI = "glm-4" ZHIPU_AI = "glm-4"



# model # model
GPT35 = "gpt-3.5-turbo" GPT35 = "gpt-3.5-turbo"
GPT4 = "gpt-4" GPT4 = "gpt-4"


+ 26
- 1
common/linkai_client.py View File

@@ -2,7 +2,9 @@ from bridge.context import Context, ContextType
from bridge.reply import Reply, ReplyType from bridge.reply import Reply, ReplyType
from common.log import logger from common.log import logger
from linkai import LinkAIClient, PushMsg from linkai import LinkAIClient, PushMsg
from config import conf
from config import conf, pconf, plugin_config
from plugins import PluginManager



chat_client: LinkAIClient chat_client: LinkAIClient


@@ -22,6 +24,29 @@ class ChatClient(LinkAIClient):
context["isgroup"] = push_msg.is_group context["isgroup"] = push_msg.is_group
self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context) self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context)


def on_config(self, config: dict):
if not self.client_id:
return
logger.info(f"从控制台加载配置: {config}")
local_config = conf()
for key in local_config.keys():
if config.get(key) is not None:
local_config[key] = config.get(key)
if config.get("reply_voice_mode"):
if config.get("reply_voice_mode") == "voice_reply_voice":
local_config["voice_reply_voice"] = True
elif config.get("reply_voice_mode") == "always_reply_voice":
local_config["always_reply_voice"] = True
# if config.get("admin_password") and plugin_config["Godcmd"]:
# plugin_config["Godcmd"]["password"] = config.get("admin_password")
# PluginManager().instances["Godcmd"].reload()
# if config.get("group_app_map") and pconf("linkai"):
# local_group_map = {}
# for mapping in config.get("group_app_map"):
# local_group_map[mapping.get("group_name")] = mapping.get("app_code")
# pconf("linkai")["group_app_map"] = local_group_map
# PluginManager().instances["linkai"].reload()



def start(channel): def start(channel):
global chat_client global chat_client


+ 0
- 1
config.py View File

@@ -159,7 +159,6 @@ available_setting = {
# 智谱AI 平台配置 # 智谱AI 平台配置
"zhipu_ai_api_key": "", "zhipu_ai_api_key": "",
"zhipu_ai_api_base": "https://open.bigmodel.cn/api/paas/v4", "zhipu_ai_api_base": "https://open.bigmodel.cn/api/paas/v4",

} }






+ 8
- 0
plugins/godcmd/godcmd.py View File

@@ -475,3 +475,11 @@ class Godcmd(Plugin):
if model == "gpt-4-turbo": if model == "gpt-4-turbo":
return const.GPT4_TURBO_PREVIEW return const.GPT4_TURBO_PREVIEW
return model return model

def reload(self):
gconf = plugin_config[self.name]
if gconf:
if gconf.get("password"):
self.password = gconf["password"]
if gconf.get("admin_users"):
self.admin_users = gconf["admin_users"]

+ 3
- 0
plugins/plugin.py View File

@@ -46,3 +46,6 @@ class Plugin:


def get_help_text(self, **kwargs): def get_help_text(self, **kwargs):
return "暂无帮助信息" return "暂无帮助信息"

def reload(self):
pass

+ 1
- 1
requirements-optional.txt View File

@@ -39,4 +39,4 @@ linkai
dingtalk_stream dingtalk_stream


# zhipuai # zhipuai
zhipuai>=2.0.1
zhipuai>=2.0.1

Loading…
Cancel
Save