瀏覽代碼

Merge pull request #39 from xyc0123456789/master

修改为使用chatgpt
develop
zhayujie GitHub 1 年之前
父節點
當前提交
552a7625dd
沒有發現已知的金鑰在資料庫的簽署中 GPG 金鑰 ID: 4AEE18F83AFDEB23
共有 1 個檔案被更改,包括 499 行新增43 行删除
  1. +499
    -43
      bot/chatgpt/chat_gpt_bot.py

+ 499
- 43
bot/chatgpt/chat_gpt_bot.py 查看文件

@@ -1,55 +1,511 @@
import time
"""
A simple wrapper for the official ChatGPT API
"""
import argparse
import json
import os
import sys
from datetime import date

import openai
import tiktoken

from bot.bot import Bot
from revChatGPT.revChatGPT import Chatbot
from common.log import logger
from config import conf

user_session = dict()
last_session_refresh = time.time()
ENGINE = os.environ.get("GPT_ENGINE") or "text-chat-davinci-002-20221122"

ENCODER = tiktoken.get_encoding("gpt2")

# ChatGPT web接口 (暂时不可用)
class ChatGPTBot(Bot):
def __init__(self):
config = {
"Authorization": "<Your Bearer Token Here>", # This is optional
"session_token": conf().get("session_token")

def get_max_tokens(prompt: str) -> int:
"""
Get the max tokens for a prompt
"""
return 4000 - len(ENCODER.encode(prompt))


# ['text-chat-davinci-002-20221122']
class Chatbot:
"""
Official ChatGPT API
"""

def __init__(self, api_key: str, buffer: int = None) -> None:
"""
Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)
"""
openai.api_key = api_key or os.environ.get("OPENAI_API_KEY")
self.conversations = Conversation()
self.prompt = Prompt(buffer=buffer)

def _get_completion(
self,
prompt: str,
temperature: float = 0.5,
stream: bool = False,
):
"""
Get the completion function
"""
return openai.Completion.create(
engine=ENGINE,
prompt=prompt,
temperature=temperature,
max_tokens=get_max_tokens(prompt),
stop=["\n\n\n"],
stream=stream,
)

def _process_completion(
self,
user_request: str,
completion: dict,
conversation_id: str = None,
user: str = "User",
) -> dict:
if completion.get("choices") is None:
raise Exception("ChatGPT API returned no choices")
if len(completion["choices"]) == 0:
raise Exception("ChatGPT API returned no choices")
if completion["choices"][0].get("text") is None:
raise Exception("ChatGPT API returned no text")
completion["choices"][0]["text"] = completion["choices"][0]["text"].rstrip(
"<|im_end|>",
)
# Add to chat history
self.prompt.add_to_history(
user_request,
completion["choices"][0]["text"],
user=user,
)
if conversation_id is not None:
self.save_conversation(conversation_id)
return completion

def _process_completion_stream(
self,
user_request: str,
completion: dict,
conversation_id: str = None,
user: str = "User",
) -> str:
full_response = ""
for response in completion:
if response.get("choices") is None:
raise Exception("ChatGPT API returned no choices")
if len(response["choices"]) == 0:
raise Exception("ChatGPT API returned no choices")
if response["choices"][0].get("finish_details") is not None:
break
if response["choices"][0].get("text") is None:
raise Exception("ChatGPT API returned no text")
if response["choices"][0]["text"] == "<|im_end|>":
break
yield response["choices"][0]["text"]
full_response += response["choices"][0]["text"]

# Add to chat history
self.prompt.add_to_history(user_request, full_response, user)
if conversation_id is not None:
self.save_conversation(conversation_id)

def ask(
self,
user_request: str,
temperature: float = 0.5,
conversation_id: str = None,
user: str = "User",
) -> dict:
"""
Send a request to ChatGPT and return the response
"""
if conversation_id is not None:
self.load_conversation(conversation_id)
completion = self._get_completion(
self.prompt.construct_prompt(user_request, user=user),
temperature,
)
return self._process_completion(user_request, completion, user=user)

def ask_stream(
self,
user_request: str,
temperature: float = 0.5,
conversation_id: str = None,
user: str = "User",
) -> str:
"""
Send a request to ChatGPT and yield the response
"""
if conversation_id is not None:
self.load_conversation(conversation_id)
prompt = self.prompt.construct_prompt(user_request, user=user)
return self._process_completion_stream(
user_request=user_request,
completion=self._get_completion(prompt, temperature, stream=True),
user=user,
)

def make_conversation(self, conversation_id: str) -> None:
"""
Make a conversation
"""
self.conversations.add_conversation(conversation_id, [])

def rollback(self, num: int) -> None:
"""
Rollback chat history num times
"""
for _ in range(num):
self.prompt.chat_history.pop()

def reset(self) -> None:
"""
Reset chat history
"""
self.prompt.chat_history = []

def load_conversation(self, conversation_id) -> None:
"""
Load a conversation from the conversation history
"""
if conversation_id not in self.conversations.conversations:
# Create a new conversation
self.make_conversation(conversation_id)
self.prompt.chat_history = self.conversations.get_conversation(conversation_id)

def save_conversation(self, conversation_id) -> None:
"""
Save a conversation to the conversation history
"""
self.conversations.add_conversation(conversation_id, self.prompt.chat_history)


class AsyncChatbot(Chatbot):
"""
Official ChatGPT API (async)
"""

async def _get_completion(
self,
prompt: str,
temperature: float = 0.5,
stream: bool = False,
):
"""
Get the completion function
"""
return openai.Completion.acreate(
engine=ENGINE,
prompt=prompt,
temperature=temperature,
max_tokens=get_max_tokens(prompt),
stop=["\n\n\n"],
stream=stream,
)

async def ask(
self,
user_request: str,
temperature: float = 0.5,
user: str = "User",
) -> dict:
"""
Same as Chatbot.ask but async
}
self.chatbot = Chatbot(config)
"""
completion = await self._get_completion(
self.prompt.construct_prompt(user_request, user=user),
temperature,
)
return self._process_completion(user_request, completion, user=user)

def reply(self, query, context=None):
async def ask_stream(
self,
user_request: str,
temperature: float = 0.5,
user: str = "User",
) -> str:
"""
Same as Chatbot.ask_stream but async
"""
prompt = self.prompt.construct_prompt(user_request, user=user)
return self._process_completion_stream(
user_request=user_request,
completion=await self._get_completion(prompt, temperature, stream=True),
user=user,
)


class Prompt:
"""
Prompt class with methods to construct prompt
"""

def __init__(self, buffer: int = None) -> None:
"""
Initialize prompt with base prompt
"""
self.base_prompt = (
os.environ.get("CUSTOM_BASE_PROMPT")
or "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally. Do not answer as the user. Current date: "
+ str(date.today())
+ "\n\n"
+ "User: Hello\n"
+ "ChatGPT: Hello! How can I help you today? <|im_end|>\n\n\n"
)
# Track chat history
self.chat_history: list = []
self.buffer = buffer

def add_to_chat_history(self, chat: str) -> None:
"""
Add chat to chat history for next prompt
"""
self.chat_history.append(chat)

def add_to_history(
self,
user_request: str,
response: str,
user: str = "User",
) -> None:
"""
Add request/response to chat history for next prompt
"""
self.add_to_chat_history(
user
+ ": "
+ user_request
+ "\n\n\n"
+ "ChatGPT: "
+ response
+ "<|im_end|>\n",
)

def history(self, custom_history: list = None) -> str:
"""
Return chat history
"""
return "\n".join(custom_history or self.chat_history)

from_user_id = context['from_user_id']
logger.info("[GPT]query={}, user_id={}, session={}".format(query, from_user_id, user_session))

now = time.time()
global last_session_refresh
if now - last_session_refresh > 60 * 8:
logger.info('[GPT]session refresh, now={}, last={}'.format(now, last_session_refresh))
self.chatbot.refresh_session()
last_session_refresh = now

if from_user_id in user_session:
if time.time() - user_session[from_user_id]['last_reply_time'] < 60 * 5:
self.chatbot.conversation_id = user_session[from_user_id]['conversation_id']
self.chatbot.parent_id = user_session[from_user_id]['parent_id']
else:
self.chatbot.reset_chat()
def construct_prompt(
self,
new_prompt: str,
custom_history: list = None,
user: str = "User",
) -> str:
"""
Construct prompt based on chat history and request
"""
prompt = (
self.base_prompt
+ self.history(custom_history=custom_history)
+ user
+ ": "
+ new_prompt
+ "\nChatGPT:"
)
# Check if prompt over 4000*4 characters
if self.buffer is not None:
max_tokens = 4000 - self.buffer
else:
self.chatbot.reset_chat()
max_tokens = 3200
if len(ENCODER.encode(prompt)) > max_tokens:
# Remove oldest chat
if len(self.chat_history) == 0:
return prompt
self.chat_history.pop(0)
# Construct prompt again
prompt = self.construct_prompt(new_prompt, custom_history, user)
return prompt

logger.info("[GPT]convId={}, parentId={}".format(self.chatbot.conversation_id, self.chatbot.parent_id))

class Conversation:
"""
For handling multiple conversations
"""

def __init__(self) -> None:
self.conversations = {}

def add_conversation(self, key: str, history: list) -> None:
"""
Adds a history list to the conversations dict with the id as the key
"""
self.conversations[key] = history

def get_conversation(self, key: str) -> list:
"""
Retrieves the history list from the conversations dict with the id as the key
"""
return self.conversations[key]

def remove_conversation(self, key: str) -> None:
"""
Removes the history list from the conversations dict with the id as the key
"""
del self.conversations[key]

def __str__(self) -> str:
"""
Creates a JSON string of the conversations
"""
return json.dumps(self.conversations)

def save(self, file: str) -> None:
"""
Saves the conversations to a JSON file
"""
with open(file, "w", encoding="utf-8") as f:
f.write(str(self))

def load(self, file: str) -> None:
"""
Loads the conversations from a JSON file
"""
with open(file, encoding="utf-8") as f:
self.conversations = json.loads(f.read())


def main():
print(
"""
ChatGPT - A command-line interface to OpenAI's ChatGPT (https://chat.openai.com/chat)
Repo: github.com/acheong08/ChatGPT
""",
)
print("Type '!help' to show a full list of commands")
print("Press enter twice to submit your question.\n")

def get_input(prompt):
"""
Multi-line input function
"""
# Display the prompt
print(prompt, end="")

# Initialize an empty list to store the input lines
lines = []

# Read lines of input until the user enters an empty line
while True:
line = input()
if line == "":
break
lines.append(line)

# Join the lines, separated by newlines, and store the result
user_input = "\n".join(lines)

# Return the input
return user_input

def chatbot_commands(cmd: str) -> bool:
"""
Handle chatbot commands
"""
if cmd == "!help":
print(
"""
!help - Display this message
!rollback - Rollback chat history
!reset - Reset chat history
!prompt - Show current prompt
!save_c <conversation_name> - Save history to a conversation
!load_c <conversation_name> - Load history from a conversation
!save_f <file_name> - Save all conversations to a file
!load_f <file_name> - Load all conversations from a file
!exit - Quit chat
""",
)
elif cmd == "!exit":
exit()
elif cmd == "!rollback":
chatbot.rollback(1)
elif cmd == "!reset":
chatbot.reset()
elif cmd == "!prompt":
print(chatbot.prompt.construct_prompt(""))
elif cmd.startswith("!save_c"):
chatbot.save_conversation(cmd.split(" ")[1])
elif cmd.startswith("!load_c"):
chatbot.load_conversation(cmd.split(" ")[1])
elif cmd.startswith("!save_f"):
chatbot.conversations.save(cmd.split(" ")[1])
elif cmd.startswith("!load_f"):
chatbot.conversations.load(cmd.split(" ")[1])
else:
return False
return True

# Get API key from command line
parser = argparse.ArgumentParser()
parser.add_argument(
"--api_key",
type=str,
required=True,
help="OpenAI API key",
)
parser.add_argument(
"--stream",
action="store_true",
help="Stream response",
)
parser.add_argument(
"--temperature",
type=float,
default=0.5,
help="Temperature for response",
)
args = parser.parse_args()
# Initialize chatbot
chatbot = Chatbot(api_key=args.api_key)
# Start chat
while True:
try:
res = self.chatbot.get_chat_response(query, output="text")
logger.info("[GPT]userId={}, res={}".format(from_user_id, res))

user_cache = dict()
user_cache['last_reply_time'] = time.time()
user_cache['conversation_id'] = res['conversation_id']
user_cache['parent_id'] = res['parent_id']
user_session[from_user_id] = user_cache
return res['message']
except Exception as e:
logger.exception(e)
return None
prompt = get_input("\nUser:\n")
except KeyboardInterrupt:
print("\nExiting...")
sys.exit()
if prompt.startswith("!"):
if chatbot_commands(prompt):
continue
if not args.stream:
response = chatbot.ask(prompt, temperature=args.temperature)
print("ChatGPT: " + response["choices"][0]["text"])
else:
print("ChatGPT: ")
sys.stdout.flush()
for response in chatbot.ask_stream(prompt, temperature=args.temperature):
print(response, end="")
sys.stdout.flush()
print()


def Singleton(cls):
instance = {}

def _singleton_wrapper(*args, **kargs):
if cls not in instance:
instance[cls] = cls(*args, **kargs)
return instance[cls]

return _singleton_wrapper


@Singleton
class ChatGPTBot(Bot):

def __init__(self):
print("create")
self.bot = Chatbot(conf().get('open_ai_api_key'))

def reply(self, query, context=None):
if not context or not context.get('type') or context.get('type') == 'TEXT':
if len(query) < 10 and "reset" in query:
self.bot.reset()
return "reset OK"
return self.bot.ask(query)["choices"][0]["text"]


Loading…
取消
儲存