@@ -13,7 +13,8 @@ def sigterm_handler_wrap(_signo): | |||||
def func(_signo, _stack_frame): | def func(_signo, _stack_frame): | ||||
logger.info("signal {} received, exiting...".format(_signo)) | logger.info("signal {} received, exiting...".format(_signo)) | ||||
conf().save_user_datas() | conf().save_user_datas() | ||||
return old_handler(_signo, _stack_frame) | |||||
if callable(old_handler): # check old_handler | |||||
return old_handler(_signo, _stack_frame) | |||||
signal.signal(_signo, func) | signal.signal(_signo, func) | ||||
def run(): | def run(): | ||||
@@ -3,13 +3,12 @@ | |||||
from bot.bot import Bot | from bot.bot import Bot | ||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession | from bot.chatgpt.chat_gpt_session import ChatGPTSession | ||||
from bot.openai.open_ai_image import OpenAIImage | from bot.openai.open_ai_image import OpenAIImage | ||||
from bot.session_manager import Session, SessionManager | |||||
from bot.session_manager import SessionManager | |||||
from bridge.context import ContextType | from bridge.context import ContextType | ||||
from bridge.reply import Reply, ReplyType | from bridge.reply import Reply, ReplyType | ||||
from config import conf, load_config | from config import conf, load_config | ||||
from common.log import logger | from common.log import logger | ||||
from common.token_bucket import TokenBucket | from common.token_bucket import TokenBucket | ||||
from common.expired_dict import ExpiredDict | |||||
import openai | import openai | ||||
import openai.error | import openai.error | ||||
import time | import time | ||||
@@ -91,8 +90,8 @@ class ChatGPTBot(Bot,OpenAIImage): | |||||
"top_p":1, | "top_p":1, | ||||
"frequency_penalty":conf().get('frequency_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | "frequency_penalty":conf().get('frequency_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | ||||
"presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | "presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容 | ||||
"request_timeout": conf().get('request_timeout', 60), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 | |||||
"timeout": conf().get('request_timeout', 120), #重试超时时间,在这个时间内,将会自动重试 | |||||
"request_timeout": conf().get('request_timeout', None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 | |||||
"timeout": conf().get('request_timeout', None), #重试超时时间,在这个时间内,将会自动重试 | |||||
} | } | ||||
def reply_text(self, session:ChatGPTSession, session_id, api_key, retry_count=0) -> dict: | def reply_text(self, session:ChatGPTSession, session_id, api_key, retry_count=0) -> dict: | ||||
@@ -151,6 +150,7 @@ class AzureChatGPTBot(ChatGPTBot): | |||||
def compose_args(self): | def compose_args(self): | ||||
args = super().compose_args() | args = super().compose_args() | ||||
args["engine"] = args["model"] | |||||
del(args["model"]) | |||||
return args | |||||
args["deployment_id"] = conf().get("azure_deployment_id") | |||||
#args["engine"] = args["model"] | |||||
#del(args["model"]) | |||||
return args |
@@ -55,7 +55,7 @@ def num_tokens_from_messages(messages, model): | |||||
except KeyError: | except KeyError: | ||||
logger.debug("Warning: model not found. Using cl100k_base encoding.") | logger.debug("Warning: model not found. Using cl100k_base encoding.") | ||||
encoding = tiktoken.get_encoding("cl100k_base") | encoding = tiktoken.get_encoding("cl100k_base") | ||||
if model == "gpt-3.5-turbo": | |||||
if model == "gpt-3.5-turbo" or model == "gpt-35-turbo": | |||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301") | return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301") | ||||
elif model == "gpt-4": | elif model == "gpt-4": | ||||
return num_tokens_from_messages(messages, model="gpt-4-0314") | return num_tokens_from_messages(messages, model="gpt-4-0314") | ||||
@@ -76,4 +76,4 @@ def num_tokens_from_messages(messages, model): | |||||
if key == "name": | if key == "name": | ||||
num_tokens += tokens_per_name | num_tokens += tokens_per_name | ||||
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> | num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> | ||||
return num_tokens | |||||
return num_tokens |
@@ -19,7 +19,7 @@ class Bridge(object): | |||||
model_type = conf().get("model") | model_type = conf().get("model") | ||||
if model_type in ["text-davinci-003"]: | if model_type in ["text-davinci-003"]: | ||||
self.btype['chat'] = const.OPEN_AI | self.btype['chat'] = const.OPEN_AI | ||||
if conf().get("use_azure_chatgpt"): | |||||
if conf().get("use_azure_chatgpt", False): | |||||
self.btype['chat'] = const.CHATGPTONAZURE | self.btype['chat'] = const.CHATGPTONAZURE | ||||
self.bots={} | self.bots={} | ||||
@@ -233,6 +233,9 @@ class ChatChannel(Channel): | |||||
time.sleep(3+3*retry_cnt) | time.sleep(3+3*retry_cnt) | ||||
self._send(reply, context, retry_cnt+1) | self._send(reply, context, retry_cnt+1) | ||||
def _success_callback(self, session_id, **kwargs):# 线程正常结束时的回调函数 | |||||
logger.debug("Worker return success, session_id = {}".format(session_id)) | |||||
def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数 | def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数 | ||||
logger.exception("Worker return exception: {}".format(exception)) | logger.exception("Worker return exception: {}".format(exception)) | ||||
@@ -242,6 +245,8 @@ class ChatChannel(Channel): | |||||
worker_exception = worker.exception() | worker_exception = worker.exception() | ||||
if worker_exception: | if worker_exception: | ||||
self._fail_callback(session_id, exception = worker_exception, **kwargs) | self._fail_callback(session_id, exception = worker_exception, **kwargs) | ||||
else: | |||||
self._success_callback(session_id, **kwargs) | |||||
except CancelledError as e: | except CancelledError as e: | ||||
logger.info("Worker cancelled, session_id = {}".format(session_id)) | logger.info("Worker cancelled, session_id = {}".format(session_id)) | ||||
except Exception as e: | except Exception as e: | ||||
@@ -147,6 +147,8 @@ class WechatChannel(ChatChannel): | |||||
if conf().get('speech_recognition') != True: | if conf().get('speech_recognition') != True: | ||||
return | return | ||||
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content)) | logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content)) | ||||
elif cmsg.ctype == ContextType.IMAGE: | |||||
logger.debug("[WX]receive image for group msg: {}".format(cmsg.content)) | |||||
else: | else: | ||||
# logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg)) | # logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg)) | ||||
pass | pass | ||||
@@ -97,7 +97,6 @@ class WechatMPChannel(ChatChannel): | |||||
if self.passive_reply: | if self.passive_reply: | ||||
receiver = context["receiver"] | receiver = context["receiver"] | ||||
self.cache_dict[receiver] = reply.content | self.cache_dict[receiver] = reply.content | ||||
self.running.remove(receiver) | |||||
logger.info("[send] reply to {} saved to cache: {}".format(receiver, reply)) | logger.info("[send] reply to {} saved to cache: {}".format(receiver, reply)) | ||||
else: | else: | ||||
receiver = context["receiver"] | receiver = context["receiver"] | ||||
@@ -116,8 +115,14 @@ class WechatMPChannel(ChatChannel): | |||||
return | return | ||||
def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数 | |||||
logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context['msg'].msg_id)) | |||||
if self.passive_reply: | |||||
self.running.remove(session_id) | |||||
def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数 | def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数 | ||||
logger.exception("[wechatmp] Fail to generation message to user, msgId={}, exception={}".format(context['msg'].msg_id, exception)) | |||||
logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context['msg'].msg_id, exception)) | |||||
if self.passive_reply: | if self.passive_reply: | ||||
assert session_id not in self.cache_dict | assert session_id not in self.cache_dict | ||||
self.running.remove(session_id) | self.running.remove(session_id) | ||||
@@ -12,7 +12,7 @@ class TmpDir(object): | |||||
def __init__(self): | def __init__(self): | ||||
pathExists = os.path.exists(self.tmpFilePath) | pathExists = os.path.exists(self.tmpFilePath) | ||||
if not pathExists and conf().get('speech_recognition') == True: | |||||
if not pathExists: | |||||
os.makedirs(self.tmpFilePath) | os.makedirs(self.tmpFilePath) | ||||
def path(self): | def path(self): | ||||
@@ -2,7 +2,6 @@ | |||||
"open_ai_api_key": "YOUR API KEY", | "open_ai_api_key": "YOUR API KEY", | ||||
"model": "gpt-3.5-turbo", | "model": "gpt-3.5-turbo", | ||||
"proxy": "", | "proxy": "", | ||||
"use_azure_chatgpt": false, | |||||
"single_chat_prefix": ["bot", "@bot"], | "single_chat_prefix": ["bot", "@bot"], | ||||
"single_chat_reply_prefix": "[bot] ", | "single_chat_reply_prefix": "[bot] ", | ||||
"group_chat_prefix": ["@bot"], | "group_chat_prefix": ["@bot"], | ||||
@@ -16,6 +16,7 @@ available_setting = { | |||||
# chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称 | # chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称 | ||||
"model": "gpt-3.5-turbo", | "model": "gpt-3.5-turbo", | ||||
"use_azure_chatgpt": False, # 是否使用azure的chatgpt | "use_azure_chatgpt": False, # 是否使用azure的chatgpt | ||||
"azure_deployment_id": "", #azure 模型部署名称 | |||||
# Bot触发配置 | # Bot触发配置 | ||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 | "single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复 | ||||
@@ -1,7 +1,7 @@ | |||||
**Table of Content** | **Table of Content** | ||||
- [插件化初衷](#插件化初衷) | - [插件化初衷](#插件化初衷) | ||||
- [插件安装方法](#插件化安装方法) | |||||
- [插件安装方法](#插件安装方法) | |||||
- [插件化实现](#插件化实现) | - [插件化实现](#插件化实现) | ||||
- [插件编写示例](#插件编写示例) | - [插件编写示例](#插件编写示例) | ||||
- [插件设计建议](#插件设计建议) | - [插件设计建议](#插件设计建议) | ||||
@@ -52,6 +52,8 @@ | |||||
以下是它们的默认处理逻辑(太长不看,可跳到[插件编写示例](#插件编写示例)): | 以下是它们的默认处理逻辑(太长不看,可跳到[插件编写示例](#插件编写示例)): | ||||
**注意以下包含的代码是`v1.1.0`中的片段,已过时,只可用于理解事件,最新的默认代码逻辑请参考[chat_channel](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/chat_channel.py)** | |||||
#### 1. 收到消息 | #### 1. 收到消息 | ||||
负责接收用户消息,根据用户的配置,判断本条消息是否触发机器人。如果触发,则会判断该消息的类型(声音、文本、画图命令等),将消息包装成如下的`Context`交付给下一个步骤。 | 负责接收用户消息,根据用户的配置,判断本条消息是否触发机器人。如果触发,则会判断该消息的类型(声音、文本、画图命令等),将消息包装成如下的`Context`交付给下一个步骤。 | ||||
@@ -91,9 +93,9 @@ | |||||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: | if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: | ||||
reply = super().build_reply_content(context.content, context) #文字跟画图交付给chatgpt | reply = super().build_reply_content(context.content, context) #文字跟画图交付给chatgpt | ||||
elif context.type == ContextType.VOICE: # 声音先进行语音转文字后,修改Context类型为文字后,再交付给chatgpt | elif context.type == ContextType.VOICE: # 声音先进行语音转文字后,修改Context类型为文字后,再交付给chatgpt | ||||
msg = context['msg'] | |||||
file_name = TmpDir().path() + context.content | |||||
msg.download(file_name) | |||||
cmsg = context['msg'] | |||||
cmsg.prepare() | |||||
file_name = context.content | |||||
reply = super().build_voice_to_text(file_name) | reply = super().build_voice_to_text(file_name) | ||||
if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO: | if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO: | ||||
context.content = reply.content # 语音转文字后,将文字内容作为新的context | context.content = reply.content # 语音转文字后,将文字内容作为新的context | ||||
@@ -21,4 +21,4 @@ web.py | |||||
# chatgpt-tool-hub plugin | # chatgpt-tool-hub plugin | ||||
--extra-index-url https://pypi.python.org/simple | --extra-index-url https://pypi.python.org/simple | ||||
chatgpt_tool_hub>=0.3.5 | |||||
chatgpt_tool_hub>=0.3.7 |
@@ -1,4 +1,4 @@ | |||||
openai>=0.27.2 | |||||
openai==0.27.2 | |||||
HTMLParser>=0.0.2 | HTMLParser>=0.0.2 | ||||
PyQRCode>=1.2.1 | PyQRCode>=1.2.1 | ||||
qrcode>=7.4.2 | qrcode>=7.4.2 | ||||