add_ChatHaruhiTrain_for_roleLLM #54

Open
wants to merge 1 commit into base: main
3 changes: 3 additions & 0 deletions ChatHaruhi2.0/ChatHaruhi/ChatGLM2GPT.py
@@ -6,6 +6,7 @@
tokenizer_GLM = None
model_GLM = None


def initialize_GLM2LORA():
global model_GLM, tokenizer_GLM

@@ -30,9 +31,11 @@ def initialize_GLM2LORA():

return model_GLM, tokenizer_GLM


def GLM_tokenizer(text):
return len(tokenizer_GLM.encode(text))


class ChatGLM2GPT(BaseLLM):
def __init__(self, model = "haruhi-fusion"):
super(ChatGLM2GPT, self).__init__()
44 changes: 22 additions & 22 deletions ChatHaruhi2.0/ChatHaruhi/ChatHaruhi.py
@@ -1,26 +1,20 @@
from .ChromaDB import ChromaDB
import os

from .utils import luotuo_openai_embedding, tiktokenizer

from .utils import response_postprocess


class ChatHaruhi:

def __init__(self, system_prompt = None, \
role_name = None, role_from_hf = None, \
story_db=None, story_text_folder = None, \
llm = 'openai', \
embedding = 'luotuo_openai', \
max_len_story = None, max_len_history = None,
verbose = False):
def __init__(self, system_prompt=None, role_name=None, role_from_hf=None, story_db=None, story_text_folder=None,
llm='openai', embedding='luotuo_openai', max_len_story=None, max_len_history=None, verbose=False):
super(ChatHaruhi, self).__init__()
self.verbose = verbose

# constants
self.story_prefix_prompt = "Classic scenes for the role are as follows:\n"
self.story_prefix_prompt = "\nClassic scenes for the role are as follows:"
self.k_search = 19
self.narrator = ['旁白', '', 'scene','Scene','narrator' , 'Narrator']
self.narrator = ['旁白', '', 'scene', 'Scene', 'narrator', 'Narrator']
self.dialogue_divide_token = '\n###\n'
self.dialogue_bra_token = '「'
self.dialogue_ket_token = '」'
@@ -45,6 +39,8 @@ def __init__(self, system_prompt = None, \
self.llm, self.tokenizer = self.get_models('BaiChuan2GPT')
elif llm == "ernie":
self.llm, self.tokenizer = self.get_models('ernie')
elif llm == "Llama2GPT":
self.llm, self.tokenizer = self.get_models('Llama2GPT')
else:
print(f'warning! undefined llm {llm}, use openai instead.')
self.llm, self.tokenizer = self.get_models('openai')
@@ -128,14 +124,20 @@ def __init__(self, system_prompt = None, \
elif story_db:
self.db = ChromaDB()
self.db.load(story_db)

elif story_text_folder:
# print("Building story database from texts...")
self.db = self.build_story_db(story_text_folder)
db_name = "db_" + story_text_folder.split("/")[-1].replace(" ", "_")
if not os.path.exists(db_name):
self.db = self.build_story_db(story_text_folder)
else:
self.db = ChromaDB()
self.db.load(db_name)

else:
self.db = None
print('warning! database not yet figured out, both story_db and story_text_folder are not inputted.')
# raise ValueError("Either story_db or story_text_folder must be provided")


self.max_len_story, self.max_len_history = self.get_tokenlen_setting('openai')

@@ -149,8 +151,6 @@ def __init__(self, system_prompt = None, \

self.dialogue_history = []



def check_system_prompt(self, system_prompt):
# if system_prompt end with .txt, read the file with utf-8
# else, return the string directly
@@ -159,7 +159,6 @@ def check_system_prompt(self, system_prompt):
return f.read()
else:
return system_prompt


def get_models(self, model_name):

@@ -187,6 +186,9 @@ def get_models(self, model_name):
elif model_name == "BaiChuan2GPT":
from .BaiChuan2GPT import BaiChuan2GPT, BaiChuan_tokenizer
return (BaiChuan2GPT(), BaiChuan_tokenizer)
elif model_name == "Llama2GPT":
from .Llama2GPT import Llama2GPT, Llama_tokenizer
return (Llama2GPT(), Llama_tokenizer)
else:
print(f'warning! undefined model {model_name}, use openai instead.')
from .LangChainGPT import LangChainGPT
@@ -232,7 +234,8 @@ def build_story_db(self, text_folder):
for mystr in strs:
vecs.append(self.embedding(mystr))

db.init_from_docs(vecs, strs)
db_name = "db_" + text_folder.split("/")[-1].replace(" ", "_")
db.init_from_docs(vecs, strs, db_name)

return db

@@ -243,11 +246,10 @@ def chat(self, text, role):
# add system prompt
self.llm.initialize_message()
self.llm.system_message(self.system_prompt)


# add story
query = self.get_query_string(text, role)
self.add_story( query )
self.add_story(query)

# add history
self.add_history()
@@ -263,8 +265,6 @@ def chat(self, text, role):
# record dialogue history
self.dialogue_history.append((query, response))



return response

def get_query_string(self, text, role):
@@ -321,4 +321,4 @@ def add_history(self):
if query is not None:
self.llm.user_message(query)
if response is not None:
self.llm.ai_message(response)
self.llm.ai_message(response)
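
With these changes, building a ChatHaruhi instance from story_text_folder persists the vector database as db_<folder name> and reloads it on later runs instead of re-embedding, and "Llama2GPT" becomes a valid llm option. A minimal usage sketch under hypothetical paths (note that actually generating a reply on the Llama backend needs the LoRA weights, which the committed code leaves disabled):

from ChatHaruhi import ChatHaruhi

# First run: embeds the texts and persists them under ./db_haruhi.
# Later runs: finds db_haruhi on disk and loads it instead of re-embedding.
chatbot = ChatHaruhi(system_prompt="prompts/haruhi.txt",      # hypothetical path
                     llm="Llama2GPT",                         # backend added by this PR
                     story_text_folder="characters/haruhi")   # hypothetical folder

response = chatbot.chat(role="Kyon", text="What is the plan for today?")
print(response)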
80 changes: 80 additions & 0 deletions ChatHaruhi2.0/ChatHaruhi/ChatHaruhiTrain.py
@@ -0,0 +1,80 @@
from .ChatHaruhi import ChatHaruhi
from .ChromaDB import ChromaDB
from transformers import AutoTokenizer
import os


class ChatHaruhiTrain(ChatHaruhi):
def add_story_with_expire(self, query):
if self.db is None:
print("No vec DB!")
return

query_vec = self.embedding(query)
stories = self.db.search(query_vec, self.k_search)

story_string = self.story_prefix_prompt + self.dialogue_divide_token

sum_story_token = self.tokenizer(story_string)

for story in stories:
if query.strip() in story.strip():
continue

story_token = self.tokenizer(story.strip()) + self.tokenizer(self.dialogue_divide_token)
if sum_story_token + story_token > self.max_len_story:
break
else:
sum_story_token += story_token
story_string += story.strip() + self.dialogue_divide_token

self.llm.user_message(story_string)

def generate_prompt(self, query, history, target):
# Adjust the other hyperparameters here; this ad-hoc override is not standardized and may be removed later
self.k_search = 5
self.max_len_story = 1500
self.max_len_history = 1200
self.story_prefix_prompt = "\nClassic scenes for the role are as follows:"

self.llm.initialize_message()

self.llm.system_message(self.system_prompt)

self.add_story_with_expire(query)

self.add_history(history)

self.llm.user_message(query)

self.llm.user_message(target)

return self.llm.messages

def add_history(self, history_list):

if len(history_list) == 0:
return

sum_history_token = 0
flag = 0
for history in history_list:
current_count = 0
if history is not None:
current_count += self.tokenizer(history)

sum_history_token += current_count
if sum_history_token > self.max_len_history:
break
else:
flag += 1

if flag == 0:
print('warning! no history added. the last dialogue is too long.')
return

# TODO: decide whether to prepend a history prefix
history_message = ""
for history in history_list[-flag:]:
history_message += history
self.llm.user_message(history_message)
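
ChatHaruhiTrain reuses the retrieval pipeline to assemble supervised training prompts rather than live chat turns: add_story_with_expire drops any retrieved scene that already contains the query, so the target line is not leaked into the context, and generate_prompt concatenates system prompt, retrieved scenes, truncated history, query, and target. A rough sketch of building one training sample, with hypothetical paths and dialogue:

from ChatHaruhi import ChatHaruhiTrain

trainer_bot = ChatHaruhiTrain(system_prompt="prompts/haruhi.txt",     # hypothetical
                              llm="Llama2GPT",
                              story_text_folder="characters/haruhi")  # hypothetical

history = ["Kyon:「Another perfectly ordinary day.」",
           "Haruhi:「Nothing is ordinary around me!」"]
query = "Kyon:「So what is the plan this time?」"
target = "Haruhi:「We are shooting a movie, and I am the star!」"

# For the Llama backend, llm.messages is one newline-joined string,
# so the return value can be tokenized directly as a training example.
prompt = trainer_bot.generate_prompt(query, history, target)
print(prompt)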

26 changes: 12 additions & 14 deletions ChatHaruhi2.0/ChatHaruhi/ChromaDB.py
@@ -4,27 +4,26 @@
import string
import os


class ChromaDB(BaseDB):

def __init__(self):
self.client = None
self.collection = None
self.path = None

def init_db(self):
def init_db(self, folder_name=None):

if self.client is not None:
print('ChromaDB has already been initialized')
return

folder_name = ''

while os.path.exists(folder_name) or folder_name == '':
# try to create a folder named temp_<random string> which is not yet existed
folder_name = "tempdb_" + ''.join(random.sample(string.ascii_letters + string.digits, 8))
if folder_name is None:
while folder_name is None or os.path.exists(folder_name):
# pick a folder named tempdb_<random string> that does not exist yet
folder_name = "tempdb_" + ''.join(random.sample(string.ascii_letters + string.digits, 8))

self.path = folder_name
self.client = chromadb.PersistentClient(path = folder_name)
self.client = chromadb.PersistentClient(path=folder_name)

self.collection = self.client.get_or_create_collection("search")

@@ -38,24 +37,23 @@ def save(self, file_path):
# remove previous path if it start with tempdb
if previous_path.startswith("tempdb"):
os.system("rm -rf " + previous_path)


def load(self, file_path):
self.path = file_path
self.client = chromadb.PersistentClient(path = file_path)
self.client = chromadb.PersistentClient(path=file_path)
self.collection = self.client.get_collection("search")

def search(self, vector, n_results):
results = self.collection.query(query_embeddings=[vector], n_results=n_results)
return results['documents'][0]

def init_from_docs(self, vectors, documents):
def init_from_docs(self, vectors, documents, folder_name=None):
if self.client is None:
self.init_db()
self.init_db(folder_name)

ids = []
for i, doc in enumerate(documents):
first_four_chat = doc[:min(4, len(doc))]
ids.append( str(i) + "_" + doc)
self.collection.add(embeddings=vectors, documents=documents, ids = ids)
ids.append(str(i) + "_" + first_four_chat)
self.collection.add(embeddings=vectors, documents=documents, ids=ids)
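
The new folder_name parameter lets callers persist a collection under a stable name instead of a random tempdb_* directory, which is what the caching in ChatHaruhi.__init__ relies on. A small self-contained sketch with toy embeddings (all names are illustrative only):

from ChatHaruhi.ChromaDB import ChromaDB

docs = ["scene one", "scene two"]
vecs = [[0.1, 0.2, 0.3], [0.2, 0.1, 0.0]]   # toy 3-d embeddings

db = ChromaDB()
db.init_from_docs(vecs, docs, folder_name="db_demo")   # persisted under ./db_demo

db2 = ChromaDB()
db2.load("db_demo")            # reopen the same collection in a later session
print(db2.search([0.1, 0.2, 0.3], 1))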

95 changes: 95 additions & 0 deletions ChatHaruhi2.0/ChatHaruhi/Llama2GPT.py
@@ -0,0 +1,95 @@
import torch
from .BaseLLM import BaseLLM
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
from peft import PeftModel

tokenizer_Llama = AutoTokenizer.from_pretrained("../../../llm/Llama-2-7b-hf/",
use_fast=True, trust_remote_code=True)
model_Llama = None


def initialize_Llama2LORA():
global model_Llama, tokenizer_Llama

if model_Llama is None:
model_Llama = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b",
device_map="auto",
torch_dtype=torch.bfloat16,
trust_remote_code=True,
)
model_Llama = PeftModel.from_pretrained(
model_Llama,
"silk-road/RoleLLM_Llama2_7B"
)
model_Llama.generation_config = GenerationConfig.from_pretrained(
"meta-llama/Llama-2-7b"
)

if tokenizer_Llama is None:
tokenizer_Llama = AutoTokenizer.from_pretrained(
"meta-llama/Llama-2-7b",
use_fast=True,
trust_remote_code=True
)

return model_Llama, tokenizer_Llama


def Llama_tokenizer(text):
return len(tokenizer_Llama.encode(text))


class Llama2GPT(BaseLLM):
def __init__(self, model="roleLLM-llama"):
super(Llama2GPT, self).__init__()
if model == "llama2-7b":
self.tokenizer = AutoTokenizer.from_pretrained(
"meta-llama/Llama-2-7b",
use_fast=True,
trust_remote_code=True
)
self.model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b",
device_map="auto",
torch_dtype=torch.bfloat16,
trust_remote_code=True,
)
self.model.generation_config = GenerationConfig.from_pretrained(
"meta-llama/Llama-2-7b"
)
elif model == "roleLLM-llama":
# self.model, self.tokenizer = initialize_Llama2LORA()
# TODO: change this back (re-enable initialize_Llama2LORA above)
self.model = None
self.tokenizer = AutoTokenizer.from_pretrained(
"../../../llm/Llama-2-7b-hf/",
use_fast=True, trust_remote_code=True)
else:
raise Exception("Unknown Llama Model! Currently supported: [llama2-7b, roleLLM-llama]")
self.messages = ""

def initialize_message(self):
self.messages = ""

# TODO: to be revised
def ai_message(self, payload):
self.messages = self.messages + "\n" + payload

# TODO: to be revised
def system_message(self, payload):
self.messages = self.messages + "\n" + payload

# TODO: to be revised
def user_message(self, payload):
self.messages = self.messages + "\n" + payload

def get_response(self):
with torch.no_grad():
response = self.model.chat(self.tokenizer, self.messages)
return response

def print_prompt(self):
print(type(self.messages))
print(self.messages)
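
Unlike the OpenAI-backed classes, Llama2GPT flattens every system/user/ai message into one newline-joined string, matching the flat prompts the RoleLLM fine-tune was presumably trained on. A minimal sketch of the call pattern, assuming the hard-coded local tokenizer path resolves on your machine (get_response additionally needs self.model loaded, which the committed code leaves as None):

from ChatHaruhi.Llama2GPT import Llama2GPT

bot = Llama2GPT(model="roleLLM-llama")
bot.initialize_message()
bot.system_message("You are Haruhi Suzumiya.")
bot.user_message("Kyon:「What is the plan for today?」")

bot.print_prompt()             # prints the accumulated flat prompt string
# reply = bot.get_response()   # requires self.model; see the TODO above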
1 change: 1 addition & 0 deletions ChatHaruhi2.0/ChatHaruhi/__init__.py
@@ -24,3 +24,4 @@
# }

from .ChatHaruhi import ChatHaruhi
from .ChatHaruhiTrain import ChatHaruhiTrain