sentence_augmentation_mapper.py
from typing import Optional

from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model

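# torch and transformers are heavy imports; LazyLoader defers them until
# first attribute access, keeping the module import itself cheap.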
torch = LazyLoader('torch', 'torch')
transformers = LazyLoader('transformers', 'transformers')

DEFAULT_SYSTEM_PROMPT = ("A chat between a curious user and an artificial "
                         "intelligence assistant. The assistant gives helpful, "
                         "detailed, and polite answers to the user's questions.")

OP_NAME = 'sentence_augmentation_mapper'

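# Registering under OP_NAME makes the op addressable from Data-Juicer
# configs as 'sentence_augmentation_mapper'.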
@OPERATORS.register_module(OP_NAME)
class SentenceAugmentationMapper(Mapper):
"""Mapper to augment sentences.
The purpose of this operation is to enhance sentences.
If the input text is at the document level, the enhancement
effect may not be optimal. Therefore, please consider the
length of the input text carefully.
Recommended model list: [
lmsys/vicuna-13b-v1.5
Qwen/Qwen2-7B-Instruct
]
"""
_accelerator = 'cuda'

    def __init__(self,
                 hf_model: str = 'Qwen/Qwen2-7B-Instruct',
                 system_prompt: Optional[str] = None,
                 task_sentence: Optional[str] = None,
                 max_new_tokens: int = 256,
                 temperature: float = 0.2,
                 top_p: Optional[float] = None,
                 num_beams: int = 1,
                 *args,
                 **kwargs):
"""
Initialization method.
:param hf_model: Hugginface model id.
:param system_prompt: System prompt.
:param task_sentence: The instruction for the current task.
:param max_new_tokens: the maximum number of new tokens
generated by the model.
:param temperature: used to control the randomness of
generated text. The higher the temperature, the more
random and creative the generated text will be.
:param top_p: randomly select the next word from the group
of words whose cumulative probability reaches p.
:param num_beams: the larger the beam search size, the higher
the quality of the generated text.
:param args: extra args
:param kwargs: extra args
"""
kwargs.setdefault('mem_required', '31GB')
kwargs.setdefault('num_proc', 1)
super().__init__(*args, **kwargs)
if system_prompt is None:
system_prompt = DEFAULT_SYSTEM_PROMPT
self.system_prompt = system_prompt
self.hf_model = hf_model
self.max_new_tokens = max_new_tokens
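        # prepare_model() registers the HF model and returns a key; the
        # per-rank model/tokenizer pair is fetched later via get_model().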
self.model_key = prepare_model(model_type='huggingface',
pretrained_model_name_or_path=hf_model)
self.temperature = temperature
self.top_p = top_p
self.num_beams = num_beams
self.task_sentence = task_sentence

    def process_single(self, sample=None, rank=None):
if self.task_sentence is None:
print('[Warning] task_sentence is None!')
sample[self.text_key] = ''
return sample
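        # Fetch the per-rank model and tokenizer, loading them on first use.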
model, processor = get_model(model_key=self.model_key,
rank=rank,
use_cuda=self.use_cuda())
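        # Build the prompt: Vicuna expects a plain "USER: ... ASSISTANT:"
        # string, while other chat models use the tokenizer's chat template.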
        if 'vicuna' in self.hf_model:
            input_prompt = (self.system_prompt + ' USER: Here is a sentence: "'
                            + sample[self.text_key] + '". '
                            + self.task_sentence + ' ASSISTANT:')
        else:
            messages = [{
                'role': 'system',
                'content': self.system_prompt
            }, {
                'role': 'user',
                'content': 'Here is a sentence: "' + sample[self.text_key] +
                           '". ' + self.task_sentence
            }]
            input_prompt = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True)
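        # Tokenize the assembled prompt and move the tensors to the model's
        # device.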
inputs = processor(input_prompt, return_tensors='pt').to(model.device)
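        # Note: temperature and top_p only influence generation when sampling
        # is enabled; with the default greedy/beam search they are ignored.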
response = model.generate(**inputs,
max_new_tokens=self.max_new_tokens,
eos_token_id=processor.eos_token_id,
top_p=self.top_p,
temperature=self.temperature,
num_beams=self.num_beams)
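        # generate() returns the prompt followed by the completion; decode
        # only the newly generated tail, and warn if the echoed prompt ids
        # differ from the input ids.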
input_token_len = inputs.input_ids.shape[1]
n_diff_input_output = (inputs.input_ids !=
response[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not '
                  'the same as the input_ids')
output = processor.batch_decode(response[:, input_token_len:],
skip_special_tokens=True)[0]
output = output.strip().strip("\"")
sample[self.text_key] = output
return sample
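

# A minimal usage sketch (illustrative only, not part of this module): it
# assumes a working data_juicer install, enough GPU memory for the chosen
# model, and the default text_key of 'text' used by Data-Juicer mappers.
#
#   op = SentenceAugmentationMapper(
#       hf_model='Qwen/Qwen2-7B-Instruct',
#       task_sentence='Please paraphrase this sentence.')
#   sample = {'text': 'The quick brown fox jumps over the lazy dog.'}
#   augmented = op.process_single(sample, rank=0)
#   print(augmented['text'])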