Explorar el Código

调整

develop
H Vs hace 3 meses
padre
commit
160a028705
Se han modificado 8 ficheros con 0 adiciones y 859 borrados
  1. +0
    -17
      bot/bot.py
  2. +0
    -72
      bot/bot_factory.py
  3. +0
    -323
      bot/chatgpt/chat_gpt_bot.py
  4. +0
    -104
      bot/chatgpt/chat_gpt_session.py
  5. +0
    -122
      bot/openai/open_ai_bot.py
  6. +0
    -43
      bot/openai/open_ai_image.py
  7. +0
    -73
      bot/openai/open_ai_session.py
  8. +0
    -105
      bot/session_manager.py

+ 0
- 17
bot/bot.py Ver fichero

@@ -1,17 +0,0 @@
"""
Auto-replay chat robot abstract class
"""


from bridge.context import Context
from bridge.reply import Reply


class Bot(object):
def reply(self, query, context: Context = None) -> Reply:
"""
bot auto-reply content
:param req: received message
:return: reply content
"""
raise NotImplementedError

+ 0
- 72
bot/bot_factory.py Ver fichero

@@ -1,72 +0,0 @@
"""
channel factory
"""
from common import const


def create_bot(bot_type):
"""
create a bot_type instance
:param bot_type: bot type code
:return: bot instance
"""
# if bot_type == const.BAIDU:
# # 替换Baidu Unit为Baidu文心千帆对话接口
# # from bot.baidu.baidu_unit_bot import BaiduUnitBot
# # return BaiduUnitBot()
# from bot.baidu.baidu_wenxin import BaiduWenxinBot
# return BaiduWenxinBot()

if bot_type == const.CHATGPT:
# ChatGPT 网页端web接口
from bot.chatgpt.chat_gpt_bot import ChatGPTBot
return ChatGPTBot()

elif bot_type == const.OPEN_AI:
# OpenAI 官方对话模型API
from bot.openai.open_ai_bot import OpenAIBot
return OpenAIBot()

# elif bot_type == const.CHATGPTONAZURE:
# # Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
# from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot
# return AzureChatGPTBot()

# elif bot_type == const.XUNFEI:
# from bot.xunfei.xunfei_spark_bot import XunFeiBot
# return XunFeiBot()

# elif bot_type == const.LINKAI:
# from bot.linkai.link_ai_bot import LinkAIBot
# return LinkAIBot()

# elif bot_type == const.CLAUDEAI:
# from bot.claude.claude_ai_bot import ClaudeAIBot
# return ClaudeAIBot()
# elif bot_type == const.CLAUDEAPI:
# from bot.claudeapi.claude_api_bot import ClaudeAPIBot
# return ClaudeAPIBot()
# elif bot_type == const.QWEN:
# from bot.ali.ali_qwen_bot import AliQwenBot
# return AliQwenBot()
# elif bot_type == const.QWEN_DASHSCOPE:
# from bot.dashscope.dashscope_bot import DashscopeBot
# return DashscopeBot()
# elif bot_type == const.GEMINI:
# from bot.gemini.google_gemini_bot import GoogleGeminiBot
# return GoogleGeminiBot()

# elif bot_type == const.ZHIPU_AI:
# from bot.zhipuai.zhipuai_bot import ZHIPUAIBot
# return ZHIPUAIBot()

# elif bot_type == const.MOONSHOT:
# from bot.moonshot.moonshot_bot import MoonshotBot
# return MoonshotBot()
# elif bot_type == const.MiniMax:
# from bot.minimax.minimax_bot import MinimaxBot
# return MinimaxBot()


raise RuntimeError

+ 0
- 323
bot/chatgpt/chat_gpt_bot.py Ver fichero

@@ -1,323 +0,0 @@
# encoding:utf-8

import time

import openai
import openai.error
import requests
import json

from bot.bot import Bot
from bot.chatgpt.chat_gpt_session import ChatGPTSession
from bot.openai.open_ai_image import OpenAIImage
from bot.session_manager import SessionManager
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from common.token_bucket import TokenBucket
from config import conf, load_config
from channel.chat_message import ChatMessage

from common import memory


# OpenAI对话模型API (可用)
class ChatGPTBot(Bot, OpenAIImage):
def __init__(self):
super().__init__()
# set the default api_key
openai.api_key = conf().get("open_ai_api_key")
if conf().get("open_ai_api_base"):
openai.api_base = conf().get("open_ai_api_base")
proxy = conf().get("proxy")
if proxy:
openai.proxy = proxy
if conf().get("rate_limit_chatgpt"):
self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))

self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
self.args = {
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
# "max_tokens":4096, # 回复最大的字符数
"top_p": conf().get("top_p", 1),
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
}

def reply(self, query, context=None):
# acquire reply content
if context.type == ContextType.TEXT:
# print(context.__dict__)
msg: ChatMessage = context.kwargs['msg']
# print(msg.from_user_nickname)
logger.info("[CHATGPT] {} query={}".format(msg.from_user_nickname,query))
session_id = context["session_id"]
# print(f'会话id:{session_id}')
reply = None
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
if query in clear_memory_commands:
self.sessions.clear_session(session_id)
reply = Reply(ReplyType.INFO, "记忆已清除")
elif query == "#清除所有":
self.sessions.clear_all_session()
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
elif query == "#更新配置":
load_config()
reply = Reply(ReplyType.INFO, "配置已更新")
if reply:
return reply
session = self.sessions.session_query(query, session_id)
logger.debug("[CHATGPT] session query={}".format(session.messages))

api_key = context.get("openai_api_key")
model = context.get("gpt_model")
new_args = None
if model:
new_args = self.args.copy()
new_args["model"] = model
# if context.get('stream'):
# # reply in stream
# return self.reply_text_stream(query, new_query, session_id)

reply_content = self.reply_text(session, api_key, args=new_args)
logger.debug(
"[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
session.messages,
session_id,
reply_content["content"],
reply_content["completion_tokens"],
)
)
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
reply = Reply(ReplyType.ERROR, reply_content["content"])
elif reply_content["completion_tokens"] > 0:
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
reply = Reply(ReplyType.TEXT, reply_content["content"])
else:
reply = Reply(ReplyType.ERROR, reply_content["content"])
logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content))
return reply

elif context.type == ContextType.IMAGE_CREATE:
ok, retstring = self.create_img(query, 0)
reply = None
if ok:
reply = Reply(ReplyType.IMAGE_URL, retstring)
else:
reply = Reply(ReplyType.ERROR, retstring)
return reply
else:
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
return reply

def reply_text(self, session: ChatGPTSession, api_key=None, args=None, retry_count=0) -> dict:
"""
call openai's ChatCompletion to get the answer
:param session: a conversation session
:param session_id: session id
:param retry_count: retry count
:return: {}
"""
try:
if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
# if api_key == None, the default openai.api_key will be used
if args is None:
args = self.args

# Define additional parameters
additional_params = {
"chatId": session.session_id,
"detail": True
}

# Combine the additional params with the existing args (if any)
args.update(additional_params)
# msgs=session.messages

# cache_data = memory.USER_INTERACTIVE_CACHE.get(session.session_id)

# # Determine messages to send based on cache data
# messages_to_send = msgs[-1] if cache_data and cache_data.get('interactive') else msgs
# print(msgs[-1])
# print('----------------')
# # Send the response using OpenAI API
# response = openai.ChatCompletion.create(api_key=api_key, messages=messages_to_send, **args)
messages_to_send=session.messages

cache_data = memory.USER_INTERACTIVE_CACHE.get(session.session_id)
if cache_data and cache_data.get('interactive'):
messages_to_send=[session.messages[-1]]
print(messages_to_send)
response = openai.ChatCompletion.create(api_key=api_key, messages=messages_to_send, **args)
# print("{}".format(session.__dict__))
logger.info("[CHATGPT] 请求={}".format(messages_to_send))
# print(f'会话id:{session.session_id}')
# logger.info("[CHATGPT] 响应={}".format(response))
logger.info("[CHATGPT] 响应={}".format(json.dumps(response, separators=(',', ':'),ensure_ascii=False)))
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
content=response.choices[0]["message"]["content"]
description = ''
userSelectOptions = []
if isinstance(content, list) and any(item.get("type") == "interactive" for item in content):
# print(content)
for item in content:
if item["type"] == "interactive" and item["interactive"]["type"] == "userSelect":
params = item["interactive"]["params"]
description = params.get("description")
userSelectOptions = params.get("userSelectOptions", [])
values_string = "\n".join(option["value"] for option in userSelectOptions)
if description is not None:
memory.USER_INTERACTIVE_CACHE[session.session_id] = {
"interactive":True
}
return {
"total_tokens": response["usage"]["total_tokens"],
"completion_tokens": response["usage"]["completion_tokens"],
"content": description + '------------------------------\n'+values_string,
}
elif isinstance(content, list) and any(item.get("type") == "text" for item in content):
memory.USER_INTERACTIVE_CACHE[session.session_id] = {
"interactive":False
}
text=''
for item in content:
if item["type"] == "text":
text=item["text"]["content"]
if text=='':
args.pop('chatId', None) # The second argument (None) is the default return value if the key doesn't exist
args.pop('detail', None)
response = openai.ChatCompletion.create(api_key=api_key, messages=messages_to_send, **args)
text=response.choices[0]["message"]["content"]
return {
"total_tokens": response["usage"]["total_tokens"],
"completion_tokens": response["usage"]["completion_tokens"],
"content": text,
}

else:
memory.USER_INTERACTIVE_CACHE[session.session_id] = {
"interactive":False
}
return {
"total_tokens": response["usage"]["total_tokens"],
"completion_tokens": response["usage"]["completion_tokens"],
"content": content.lstrip("\n"),
}
except Exception as e:
need_retry = retry_count < 2
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
if isinstance(e, openai.error.RateLimitError):
logger.warn("[CHATGPT] RateLimitError: {}".format(e))
result["content"] = "提问太快啦,请休息一下再问我吧"
if need_retry:
time.sleep(20)
elif isinstance(e, openai.error.Timeout):
logger.warn("[CHATGPT] Timeout: {}".format(e))
result["content"] = "我没有收到你的消息"
if need_retry:
time.sleep(5)
elif isinstance(e, openai.error.APIError):
logger.warn("[CHATGPT] Bad Gateway: {}".format(e))
result["content"] = "请再问我一次"
if need_retry:
time.sleep(10)
elif isinstance(e, openai.error.APIConnectionError):
logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
result["content"] = "我连接不到你的网络"
if need_retry:
time.sleep(5)
else:
logger.exception("[CHATGPT] Exception: {}".format(e))
need_retry = False
self.sessions.clear_session(session.session_id)

if need_retry:
logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1))
return self.reply_text(session, api_key, args, retry_count + 1)
else:
return result


class AzureChatGPTBot(ChatGPTBot):
def __init__(self):
super().__init__()
openai.api_type = "azure"
openai.api_version = conf().get("azure_api_version", "2023-06-01-preview")
self.args["deployment_id"] = conf().get("azure_deployment_id")

def create_img(self, query, retry_count=0, api_key=None):
text_to_image_model = conf().get("text_to_image")
if text_to_image_model == "dall-e-2":
api_version = "2023-06-01-preview"
endpoint = conf().get("azure_openai_dalle_api_base","open_ai_api_base")
# 检查endpoint是否以/结尾
if not endpoint.endswith("/"):
endpoint = endpoint + "/"
url = "{}openai/images/generations:submit?api-version={}".format(endpoint, api_version)
api_key = conf().get("azure_openai_dalle_api_key","open_ai_api_key")
headers = {"api-key": api_key, "Content-Type": "application/json"}
try:
body = {"prompt": query, "size": conf().get("image_create_size", "256x256"),"n": 1}
submission = requests.post(url, headers=headers, json=body)
operation_location = submission.headers['operation-location']
status = ""
while (status != "succeeded"):
if retry_count > 3:
return False, "图片生成失败"
response = requests.get(operation_location, headers=headers)
status = response.json()['status']
retry_count += 1
image_url = response.json()['result']['data'][0]['url']
return True, image_url
except Exception as e:
logger.error("create image error: {}".format(e))
return False, "图片生成失败"
elif text_to_image_model == "dall-e-3":
api_version = conf().get("azure_api_version", "2024-02-15-preview")
endpoint = conf().get("azure_openai_dalle_api_base","open_ai_api_base")
# 检查endpoint是否以/结尾
if not endpoint.endswith("/"):
endpoint = endpoint + "/"
url = "{}openai/deployments/{}/images/generations?api-version={}".format(endpoint, conf().get("azure_openai_dalle_deployment_id","text_to_image"),api_version)
api_key = conf().get("azure_openai_dalle_api_key","open_ai_api_key")
headers = {"api-key": api_key, "Content-Type": "application/json"}
try:
body = {"prompt": query, "size": conf().get("image_create_size", "1024x1024"), "quality": conf().get("dalle3_image_quality", "standard")}
response = requests.post(url, headers=headers, json=body)
response.raise_for_status() # 检查请求是否成功
data = response.json()

# 检查响应中是否包含图像 URL
if 'data' in data and len(data['data']) > 0 and 'url' in data['data'][0]:
image_url = data['data'][0]['url']
return True, image_url
else:
error_message = "响应中没有图像 URL"
logger.error(error_message)
return False, "图片生成失败"

except requests.exceptions.RequestException as e:
# 捕获所有请求相关的异常
try:
error_detail = response.json().get('error', {}).get('message', str(e))
except ValueError:
error_detail = str(e)
error_message = f"{error_detail}"
logger.error(error_message)
return False, error_message

except Exception as e:
# 捕获所有其他异常
error_message = f"生成图像时发生错误: {e}"
logger.error(error_message)
return False, "图片生成失败"
else:
return False, "图片生成失败,未配置text_to_image参数"

+ 0
- 104
bot/chatgpt/chat_gpt_session.py Ver fichero

@@ -1,104 +0,0 @@
from bot.session_manager import Session
from common.log import logger
from common import const

"""
e.g. [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Who won the world series in 2020?"},
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
{"role": "user", "content": "Where was it played?"}
]
"""


class ChatGPTSession(Session):
def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
super().__init__(session_id, system_prompt)
self.model = model
self.reset()

def discard_exceeding(self, max_tokens, cur_tokens=None):
precise = True
try:
cur_tokens = self.calc_tokens()
except Exception as e:
precise = False
if cur_tokens is None:
raise e
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
while cur_tokens > max_tokens:
if len(self.messages) > 2:
self.messages.pop(1)
elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
self.messages.pop(1)
if precise:
cur_tokens = self.calc_tokens()
else:
cur_tokens = cur_tokens - max_tokens
break
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
break
else:
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
break
if precise:
cur_tokens = self.calc_tokens()
else:
cur_tokens = cur_tokens - max_tokens
return cur_tokens

def calc_tokens(self):
return num_tokens_from_messages(self.messages, self.model)


# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def num_tokens_from_messages(messages, model):
"""Returns the number of tokens used by a list of messages."""

if model in ["wenxin", "xunfei", const.GEMINI]:
return num_tokens_by_character(messages)

import tiktoken

if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo", "gpt-3.5-turbo-1106", "moonshot", const.LINKAI_35]:
return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k", "gpt-4-turbo-preview",
"gpt-4-1106-preview",const.GPT4_TURBO_PREVIEW, const.GPT4_VISION_PREVIEW, const.GPT4_TURBO_01_25,
const.GPT_4o, const.GPT_4O_0806, const.GPT_4o_MINI, const.LINKAI_4o, const.LINKAI_4_TURBO]:
return num_tokens_from_messages(messages, model="gpt-4")
elif model.startswith("claude-3"):
return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
logger.debug("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
if model == "gpt-3.5-turbo":
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_name = -1 # if there's a name, the role is omitted
elif model == "gpt-4":
tokens_per_message = 3
tokens_per_name = 1
else:
logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo.")
return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
num_tokens = 0
for message in messages:
num_tokens += tokens_per_message
for key, value in message.items():
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens


def num_tokens_by_character(messages):
"""Returns the number of tokens used by a list of messages."""
tokens = 0
for msg in messages:
tokens += len(msg["content"])
return tokens

+ 0
- 122
bot/openai/open_ai_bot.py Ver fichero

@@ -1,122 +0,0 @@
# encoding:utf-8

import time

import openai
import openai.error

from bot.bot import Bot
from bot.openai.open_ai_image import OpenAIImage
from bot.openai.open_ai_session import OpenAISession
from bot.session_manager import SessionManager
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from config import conf

user_session = dict()


# OpenAI对话模型API (可用)
class OpenAIBot(Bot, OpenAIImage):
def __init__(self):
super().__init__()
openai.api_key = conf().get("open_ai_api_key")
if conf().get("open_ai_api_base"):
openai.api_base = conf().get("open_ai_api_base")
proxy = conf().get("proxy")
if proxy:
openai.proxy = proxy

self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003")
self.args = {
"model": conf().get("model") or "text-davinci-003", # 对话模型的名称
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
"max_tokens": 1200, # 回复最大的字符数
"top_p": 1,
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
"stop": ["\n\n\n"],
}

def reply(self, query, context=None):
# acquire reply content
if context and context.type:
if context.type == ContextType.TEXT:
logger.info("[OPEN_AI] query={}".format(query))
session_id = context["session_id"]
reply = None
if query == "#清除记忆":
self.sessions.clear_session(session_id)
reply = Reply(ReplyType.INFO, "记忆已清除")
elif query == "#清除所有":
self.sessions.clear_all_session()
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
else:
session = self.sessions.session_query(query, session_id)
result = self.reply_text(session)
total_tokens, completion_tokens, reply_content = (
result["total_tokens"],
result["completion_tokens"],
result["content"],
)
logger.debug(
"[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
)

if total_tokens == 0:
reply = Reply(ReplyType.ERROR, reply_content)
else:
self.sessions.session_reply(reply_content, session_id, total_tokens)
reply = Reply(ReplyType.TEXT, reply_content)
return reply
elif context.type == ContextType.IMAGE_CREATE:
ok, retstring = self.create_img(query, 0)
reply = None
if ok:
reply = Reply(ReplyType.IMAGE_URL, retstring)
else:
reply = Reply(ReplyType.ERROR, retstring)
return reply

def reply_text(self, session: OpenAISession, retry_count=0):
try:
response = openai.Completion.create(prompt=str(session), **self.args)
res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "")
total_tokens = response["usage"]["total_tokens"]
completion_tokens = response["usage"]["completion_tokens"]
logger.info("[OPEN_AI] reply={}".format(res_content))
return {
"total_tokens": total_tokens,
"completion_tokens": completion_tokens,
"content": res_content,
}
except Exception as e:
need_retry = retry_count < 2
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
if isinstance(e, openai.error.RateLimitError):
logger.warn("[OPEN_AI] RateLimitError: {}".format(e))
result["content"] = "提问太快啦,请休息一下再问我吧"
if need_retry:
time.sleep(20)
elif isinstance(e, openai.error.Timeout):
logger.warn("[OPEN_AI] Timeout: {}".format(e))
result["content"] = "我没有收到你的消息"
if need_retry:
time.sleep(5)
elif isinstance(e, openai.error.APIConnectionError):
logger.warn("[OPEN_AI] APIConnectionError: {}".format(e))
need_retry = False
result["content"] = "我连接不到你的网络"
else:
logger.warn("[OPEN_AI] Exception: {}".format(e))
need_retry = False
self.sessions.clear_session(session.session_id)

if need_retry:
logger.warn("[OPEN_AI] 第{}次重试".format(retry_count + 1))
return self.reply_text(session, retry_count + 1)
else:
return result

+ 0
- 43
bot/openai/open_ai_image.py Ver fichero

@@ -1,43 +0,0 @@
import time

import openai
import openai.error

from common.log import logger
from common.token_bucket import TokenBucket
from config import conf


# OPENAI提供的画图接口
class OpenAIImage(object):
def __init__(self):
openai.api_key = conf().get("open_ai_api_key")
if conf().get("rate_limit_dalle"):
self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50))

def create_img(self, query, retry_count=0, api_key=None, api_base=None):
try:
if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token():
return False, "请求太快了,请休息一下再问我吧"
logger.info("[OPEN_AI] image_query={}".format(query))
response = openai.Image.create(
api_key=api_key,
prompt=query, # 图片描述
n=1, # 每次生成图片的数量
model=conf().get("text_to_image") or "dall-e-2",
# size=conf().get("image_create_size", "256x256"), # 图片大小,可选有 256x256, 512x512, 1024x1024
)
image_url = response["data"][0]["url"]
logger.info("[OPEN_AI] image_url={}".format(image_url))
return True, image_url
except openai.error.RateLimitError as e:
logger.warn(e)
if retry_count < 1:
time.sleep(5)
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1))
return self.create_img(query, retry_count + 1)
else:
return False, "画图出现问题,请休息一下再问我吧"
except Exception as e:
logger.exception(e)
return False, "画图出现问题,请休息一下再问我吧"

+ 0
- 73
bot/openai/open_ai_session.py Ver fichero

@@ -1,73 +0,0 @@
from bot.session_manager import Session
from common.log import logger


class OpenAISession(Session):
def __init__(self, session_id, system_prompt=None, model="text-davinci-003"):
super().__init__(session_id, system_prompt)
self.model = model
self.reset()

def __str__(self):
# 构造对话模型的输入
"""
e.g. Q: xxx
A: xxx
Q: xxx
"""
prompt = ""
for item in self.messages:
if item["role"] == "system":
prompt += item["content"] + "<|endoftext|>\n\n\n"
elif item["role"] == "user":
prompt += "Q: " + item["content"] + "\n"
elif item["role"] == "assistant":
prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n"

if len(self.messages) > 0 and self.messages[-1]["role"] == "user":
prompt += "A: "
return prompt

def discard_exceeding(self, max_tokens, cur_tokens=None):
precise = True
try:
cur_tokens = self.calc_tokens()
except Exception as e:
precise = False
if cur_tokens is None:
raise e
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
while cur_tokens > max_tokens:
if len(self.messages) > 1:
self.messages.pop(0)
elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant":
self.messages.pop(0)
if precise:
cur_tokens = self.calc_tokens()
else:
cur_tokens = len(str(self))
break
elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
break
else:
logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
break
if precise:
cur_tokens = self.calc_tokens()
else:
cur_tokens = len(str(self))
return cur_tokens

def calc_tokens(self):
return num_tokens_from_string(str(self), self.model)


# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def num_tokens_from_string(string: str, model: str) -> int:
"""Returns the number of tokens in a text string."""
import tiktoken

encoding = tiktoken.encoding_for_model(model)
num_tokens = len(encoding.encode(string, disallowed_special=()))
return num_tokens

+ 0
- 105
bot/session_manager.py Ver fichero

@@ -1,105 +0,0 @@
from common.expired_dict import ExpiredDict
from common.log import logger
from config import conf
import json


class Session(object):
def __init__(self, session_id, system_prompt=None):
self.session_id = session_id
self.messages = []
if system_prompt is None:
self.system_prompt = conf().get("character_desc", "")
else:
self.system_prompt = system_prompt

# 重置会话
def reset(self):
system_item = {"role": "system", "content": self.system_prompt}
self.messages = [system_item]

def set_system_prompt(self, system_prompt):
self.system_prompt = system_prompt
self.reset()

# def add_query(self, query):
# user_item = {"role": "user", "content": query}
# self.messages.append(user_item)

def add_query(self, query):
try:
# 判断是否为 JSON 字符串,如果是则转换为 Python 字典
json_data = json.loads(query)
if isinstance(json_data, dict) or isinstance(json_data, list): # 检查是否为字典格式
user_item = {"role": "user", "content": json_data}
else:
user_item = {"role": "user", "content": query}
except json.JSONDecodeError:
# 如果不是 JSON 字符串,直接保存为字符串
user_item = {"role": "user", "content": query}
self.messages.append(user_item)

def add_reply(self, reply):
assistant_item = {"role": "assistant", "content": reply}
self.messages.append(assistant_item)

def discard_exceeding(self, max_tokens=None, cur_tokens=None):
raise NotImplementedError

def calc_tokens(self):
raise NotImplementedError


class SessionManager(object):
def __init__(self, sessioncls, **session_args):
if conf().get("expires_in_seconds"):
sessions = ExpiredDict(conf().get("expires_in_seconds"))
else:
sessions = dict()
self.sessions = sessions
self.sessioncls = sessioncls
self.session_args = session_args

def build_session(self, session_id, system_prompt=None):
"""
如果session_id不在sessions中,创建一个新的session并添加到sessions中
如果system_prompt不会空,会更新session的system_prompt并重置session
"""
if session_id is None:
return self.sessioncls(session_id, system_prompt, **self.session_args)

if session_id not in self.sessions:
self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args)
elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session
self.sessions[session_id].set_system_prompt(system_prompt)
session = self.sessions[session_id]
return session

def session_query(self, query, session_id):
session = self.build_session(session_id)
session.add_query(query)
try:
max_tokens = conf().get("conversation_max_tokens", 1000)
total_tokens = session.discard_exceeding(max_tokens, None)
logger.debug("prompt tokens used={}".format(total_tokens))
except Exception as e:
logger.warning("Exception when counting tokens precisely for prompt: {}".format(str(e)))
return session

def session_reply(self, reply, session_id, total_tokens=None):
session = self.build_session(session_id)
session.add_reply(reply)
try:
max_tokens = conf().get("conversation_max_tokens", 1000)
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
except Exception as e:
logger.warning("Exception when counting tokens precisely for session: {}".format(str(e)))
return session

def clear_session(self, session_id):
if session_id in self.sessions:
del self.sessions[session_id]

def clear_all_session(self):
self.sessions.clear()

Cargando…
Cancelar
Guardar