# -*- coding: utf-8 -*-
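"""Interactive CLI demo for XuanYuan-70B-Chat.

Uses vLLM for generation when it is installed and falls back to plain
transformers otherwise. Example invocation (the checkpoint path is a placeholder):

    python cli_demo.py -c /path/to/XuanYuan-70B-Chat
"""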
import argparse
from conversation import get_conv_template
try:
    from vllm import LLM, SamplingParams
    is_vllm_available = True
    print("use vllm.generate to infer...")
except ImportError:
    from transformers import LlamaForCausalLM, LlamaTokenizer
    is_vllm_available = False
    print("use transformers.generate to infer...")
def infer_vllm(llm, sampling_params, prompt):
    assert llm is not None
    assert sampling_params is not None
    generation = llm.generate(prompt, sampling_params, use_tqdm=False)
    outputs = generation[0].outputs[0].text.strip()
    return outputs
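

# Fallback path: plain transformers generation on GPU. Generation settings are
# read from the module-level `args` parsed in the __main__ block below.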
def infer(model, tokenizer, prompt):
    assert model is not None
    assert tokenizer is not None
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    outputs = model.generate(
        **inputs,
        max_new_tokens=args.max_new_tokens,
        do_sample=True,
        temperature=args.temperature,
        top_p=args.top_p
    )
    outputs = tokenizer.decode(outputs.cpu()[0][len(inputs.input_ids[0]):], skip_special_tokens=True).strip()
    return outputs
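

# Entry point: parse CLI flags, load the checkpoint with whichever backend is
# available, then chat in a simple read-eval-print loop.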
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Chat with XuanYuan-70B-Chat (vLLM if available, otherwise transformers)")
    parser.add_argument("-c", "--checkpoint_path", type=str, help="Checkpoint path")
    parser.add_argument("-n", "--max_new_tokens", type=int, default=1000)
    parser.add_argument("-t", "--temperature", type=float, default=0.95)
    parser.add_argument("-p", "--top_p", type=float, default=0.95)
    args = parser.parse_args()
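
    # Only one backend is initialized; the handles for the other one stay None.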
    llm = None
    sampling_params = None
    model = None
    tokenizer = None
    if is_vllm_available:
        print("loading weight with vLLM...")
        sampling_params = SamplingParams(
            temperature=args.temperature,
            top_p=args.top_p,
            stop=["</s>"],
            max_tokens=args.max_new_tokens
        )
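        # tensor_parallel_size=8 assumes an 8-GPU node; adjust it to the local setup.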
        llm = LLM(args.checkpoint_path, tensor_parallel_size=8)
    else:
        print("loading weight with transformers ...")
        tokenizer = LlamaTokenizer.from_pretrained(args.checkpoint_path, use_fast=False, legacy=True)
        model = LlamaForCausalLM.from_pretrained(args.checkpoint_path, device_map="auto")
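
    # The conversation template keeps the multi-turn history and renders the prompt string.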
    conv = get_conv_template("XuanYuan-Chat")
    print("########")
    print("Type EXIT!! to quit")
    print("Type CLEAR!! to clear the conversation history")
    print("########")
    while True:
        content = input("Input: ")
        if content.strip() == "EXIT!!":
            print("exit....")
            break
        if content.strip() == "CLEAR!!":
            conv = get_conv_template("XuanYuan-Chat")
            print("clear...")
            continue
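        # Append the user turn and an empty assistant turn, then render the full prompt.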
        conv.append_message(conv.roles[0], content.strip())
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        if is_vllm_available:
            outputs = infer_vllm(llm, sampling_params, prompt)
        else:
            outputs = infer(model, tokenizer, prompt)
        print(f"Output: {outputs}")
        conv.update_last_message(outputs)