-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathDemo.py
117 lines (98 loc) · 4.1 KB
/
Demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import itertools
import json
import os
import sys
from pathlib import Path
from typing import List, Optional

import fire

sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/../..")
import llama.modeling.Loader as Loader
from Tokenizer import Tokenizer
from ModelParams import ModelParams
def _build_prompt_tokens(tokenizer, batch, unaligned_batch, seqlen_scale_up):
    """Return exactly ``batch`` tokenized demo prompts.

    With ``unaligned_batch`` the four prompts below (different lengths) are
    cycled until ``batch`` entries exist; ``seqlen_scale_up`` is ignored in that
    mode. Otherwise one chat-formatted prompt is encoded, repeated
    ``seqlen_scale_up`` times back to back, and that same token list is used for
    every batch entry. As in the original loops, entries may share list objects.
    """
    if unaligned_batch:
        texts = [  # For these prompts, the expected answer is the natural continuation of the prompt
            "I believe the meaning of life is",
            "Simply put, the theory of relativity states that ",
            "A brief message congratulating the team on the launch:\nHi everyone,\nI just ",
            # Few-shot prompt (a few examples are given before asking the model to complete more)
            "Translate English to French:\n"
            "sea otter => loutre de mer\n"
            "peppermint => menthe poivrée\n"
            "plush girafe => girafe peluche\n"
            "cheese =>",
        ]
        encoded = [tokenizer.encode(t, bos=True, eos=False) for t in texts]
        return list(itertools.islice(itertools.cycle(encoded), batch))

    # Chat-formatted prompt ("[Round 1] Q: Hello  A:"); the format suggests a
    # ChatGLM-style model — NOTE(review): confirm against the target checkpoint.
    prompt = tokenizer.encode("[Round 1]\n\n问:你好\n\n答:", bos=True, eos=False)
    # Repeat the whole encoded prompt (BOS included) to lengthen the context.
    prompt = prompt * seqlen_scale_up
    return [prompt] * batch


def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.0,
    top_p: float = 0.95,
    batch: int = 4,
    seqlen_scale_up: int = 1,
    unaligned_batch: bool = False,
    max_gen_len: int = 256,
    friendly_gqa: bool = False, # do GQA by repeating key and value inside the key_value_cache op
    fused_qkv: bool = True, # fuse the q/k/v linear layers
    fused_kvcache: bool = True, # fuse key_value_cache and multi_head_attention
    fused_ffn_glu: bool = True, # fuse the feed-forward gated linear unit
    auto_causal: bool = True, # causal mask is applied by the attention op; no extra mask is passed to the model
    quantized_cache: bool = True, # 8-bit kv cache quantization
    cache_layout: int = 0, # kv cache layout, chosen for hardware performance
    cache_mode: int = 0, # kv cache indexing mode, chosen for memory management; only takes effect when dynamic_batching == True
    dynamic_batching: bool = True, # use dynamic batching scheduling
    context_chunking: bool = True, # enable context chunking (only applied together with dynamic batching)
    dump_tensor_path: Optional[str] = None,
    dump_steps: Optional[List[int]] = None,
):
    """Load a checkpoint, generate completions for a batch of demo prompts, and print them.

    Model hyperparameters are read from ``<ckpt_dir>/opmx_params.json``. The
    generator is then built via ``Loader.load`` with the flags above. Each decoded
    result is printed, followed by a separator line.

    Args:
        ckpt_dir: Checkpoint directory; must contain ``opmx_params.json``.
        tokenizer_path: Path handed to ``Tokenizer(model_path=...)``.
        temperature: Sampling temperature forwarded to ``generator.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``generator.generate``.
        batch: Number of prompts to generate for; must be >= 1.
        seqlen_scale_up: How many times the chat prompt is repeated to form one
            longer sequence (aligned mode only); must be >= 1 in that mode.
        unaligned_batch: Use four prompts of different lengths instead of one
            repeated prompt.
        max_gen_len: Maximum number of generated tokens.
        dump_tensor_path: Forwarded to ``Loader.load``; semantics are defined
            there (presumably a tensor-dump directory — confirm).
        dump_steps: Forwarded to ``Loader.load``; ``None`` means an empty list.

    Raises:
        ValueError: If ``batch`` < 1, or ``seqlen_scale_up`` < 1 while
            ``unaligned_batch`` is False.
    """
    # Validate before the expensive model load so bad flags fail fast.
    if batch < 1:
        raise ValueError(f"batch must be >= 1, got {batch}")
    if not unaligned_batch and seqlen_scale_up < 1:
        raise ValueError(f"seqlen_scale_up must be >= 1, got {seqlen_scale_up}")
    # Fresh list per call instead of a mutable default shared across calls.
    if dump_steps is None:
        dump_steps = []

    tokenizer = Tokenizer(model_path=tokenizer_path)
    with open(Path(ckpt_dir) / "opmx_params.json", "r", encoding="utf-8") as f:
        params: ModelParams = ModelParams(**json.load(f))
    head_dim = params.hidden_dim // params.num_heads

    generator = Loader.load(
        ckpt_dir, params,
        friendly_gqa=friendly_gqa,
        fused_qkv=fused_qkv,
        fused_kvcache=fused_kvcache,
        fused_ffn_glu=fused_ffn_glu,
        fused_alibi=False,
        auto_causal=auto_causal,
        with_rope=True,
        with_alibi=False,
        quantized_cache=quantized_cache,
        cache_layout=cache_layout,
        cache_mode=cache_mode,
        dynamic_batching=dynamic_batching,
        # Bias flags are fixed for the architecture this demo targets
        # (bias on the qkv projection only) — NOTE(review): confirm per model.
        attn_wqkv_bias_term=True,
        attn_wo_bias_term=False,
        ffn_linear_bias_term=False,
        load_to_cpu=False,
        # Rotary embedding over half of each head's dims — presumably partial
        # rotary as used by the target model; confirm against its config.
        rotary_dim=head_dim // 2,
        dump_tensor_path=dump_tensor_path,
        dump_steps=dump_steps,
    )
    # Context chunking is only enabled together with dynamic batching.
    generator.context_chunking = context_chunking if dynamic_batching else False

    prompt_tokens = _build_prompt_tokens(tokenizer, batch, unaligned_batch, seqlen_scale_up)
    print(f"prepared {len(prompt_tokens)} prompts")

    results = generator.generate(
        prompt_tokens, tokenizer.get_eos_id(), tokenizer.get_pad_id(),
        # top_k=0 presumably disables top-k filtering — confirm in generator.
        max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, top_k=0
    )
    for result in results:
        print(tokenizer.decode(result))
        print("\n==================================\n")
if __name__ == "__main__":
    # Run as a script: python-fire exposes main()'s parameters as command-line flags.
    fire.Fire(component=main)