-
Notifications
You must be signed in to change notification settings - Fork 197
/
Copy pathpair_preference_mapper.py
131 lines (114 loc) · 5.21 KB
/
pair_preference_mapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import re
from typing import Dict, Optional
from loguru import logger
from pydantic import PositiveInt
from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.model_utils import get_model, prepare_model
# Registry key for this operator; passed to OPERATORS.register_module below.
OP_NAME: str = 'pair_preference_mapper'
# TODO: Extend LLM-based OPs into API-based implementation.
@OPERATORS.register_module(OP_NAME)
class PairPreferenceMapper(Mapper):
    """
    Mapper to construct paired preference samples.

    For every sample, an API-hosted LLM is prompted with the sample's query,
    response and (optional) reference text. It is asked to produce a new
    response plus a reason for it. With the default system prompt, the new
    response should be the opposite of the original in some aspect (language
    style, factuality, persona, stance, ...). The parsed new response and
    reason are written back into the sample under ``rejected_key`` and
    ``reason_key``.
    """

    # avoid leading whitespace
    # Default system prompt (Chinese). Roughly: "Your task is to modify the
    # answer of the QA pair according to the reference information so that it
    # is opposite to the original answer in any one aspect such as language
    # style, factuality, persona or stance. You must output strictly in the
    # following marked format and nothing else:
    #   【回答】 (answer)  <the generated new answer>
    #   【原因】 (reason)  <the reason for generating this answer>"
    DEFAULT_SYSTEM_PROMPT = (
        '你的任务是根据参考信息修改问答对中的回答,在语言风格、事实性、人物身份、立场等任一方面与原回答相反。'
        '必须按照以下标记格式输出,不要输出其他多余内容。\n'
        '【回答】\n'
        '生成的新回答\n'
        '【原因】\n'
        '生成该回答的原因')
    # Default user-message template. Sections: 【参考信息】 (reference info),
    # "以下是原始问答对:" (the original QA pair follows), 【问题】 (question),
    # 【回答】 (answer). Filled via str.format_map in build_input.
    DEFAULT_INPUT_TEMPLATE = ('【参考信息】\n'
                              '{reference}\n'
                              '\n'
                              '以下是原始问答对:\n'
                              '【问题】\n'
                              '{query}\n'
                              '【回答】\n'
                              '{response}')
    # Group 1: text between the 【回答】 (answer) and 【原因】 (reason)
    # markers; group 2: everything after the 【原因】 marker.
    DEFAULT_OUTPUT_PATTERN = r'.*?【回答】\s*(.*?)\s*【原因】\s*(.*)'

    def __init__(self,
                 api_model: str = 'gpt-4o',
                 *,
                 api_endpoint: Optional[str] = None,
                 response_path: Optional[str] = None,
                 system_prompt: Optional[str] = None,
                 input_template: Optional[str] = None,
                 output_pattern: Optional[str] = None,
                 rejected_key: str = 'rejected_response',
                 reason_key: str = 'reason',
                 try_num: PositiveInt = 3,
                 model_params: Optional[Dict] = None,
                 sampling_params: Optional[Dict] = None,
                 **kwargs):
        """
        Initialization method.

        :param api_model: API model name.
        :param api_endpoint: URL endpoint for the API.
        :param response_path: Path to extract content from the API response.
            Defaults to 'choices.0.message.content'.
        :param system_prompt: System prompt for guiding the generation task.
            Falls back to DEFAULT_SYSTEM_PROMPT when None or empty.
        :param input_template: Template for building the model input. It must
            contain placeholders '{query}' and '{response}', and can optionally
            include '{reference}'. Falls back to DEFAULT_INPUT_TEMPLATE when
            None or empty.
        :param output_pattern: Regular expression for parsing model output.
            It is matched from the start of the output with re.DOTALL and
            must define two capture groups: the rejected response first, then
            the reason. Falls back to DEFAULT_OUTPUT_PATTERN when None or
            empty.
        :param rejected_key: The field name in the sample to store the
            generated rejected response. Defaults to 'rejected_response'.
        :param reason_key: The field name in the sample to store the reason for
            generating the response. Defaults to 'reason'.
        :param try_num: The maximum number of API call attempts per sample.
            Another attempt is made when the call raises or when parsing yields
            an empty response or reason. Defaults to 3.
        :param model_params: Parameters for initializing the API model.
            Defaults to no extra parameters.
        :param sampling_params: Extra parameters passed to the API call,
            e.g. {'temperature': 0.9, 'top_p': 0.95}. Defaults to none.
        :param kwargs: Extra keyword arguments, forwarded to the base Mapper.
        """
        super().__init__(**kwargs)
        self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
        self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
        self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN
        self.rejected_key = rejected_key
        self.reason_key = reason_key
        # Copy the dicts so this op never aliases a caller-owned dict, and so
        # separate instances never share one (the old `= {}` defaults did).
        model_params = dict(model_params or {})
        self.model_key = prepare_model(model_type='api',
                                       model=api_model,
                                       endpoint=api_endpoint,
                                       response_path=response_path,
                                       **model_params)
        self.try_num = try_num
        self.sampling_params = dict(sampling_params or {})

    def build_input(self, sample):
        """
        Render ``input_template`` for one sample.

        The query and response are read from ``self.query_key`` and
        ``self.response_key`` and must be present. The reference is read from
        ``self.text_key`` and defaults to '' when missing. (These key attributes
        are presumably set up by the base Mapper from ``kwargs``; confirm
        against the base class.)

        :param sample: The sample dict.
        :return: The formatted user message.
        """
        mapping = {
            'query': sample[self.query_key],
            'response': sample[self.response_key],
            'reference': sample.get(self.text_key, '')
        }
        return self.input_template.format_map(mapping)

    def parse_output(self, raw_output):
        """
        Extract (rejected_response, reason) from the raw model output.

        :param raw_output: Text returned by the API model.
        :return: A tuple of the two stripped capture groups of
            ``output_pattern``, or ('', '') when the pattern does not match.
        """
        logger.debug(raw_output)
        match = re.match(self.output_pattern, raw_output, re.DOTALL)
        if match:
            return match.group(1).strip(), match.group(2).strip()
        return ('', '')

    def process_single(self, sample, rank=None):
        """
        Generate a rejected response and a reason for one sample.

        The API is called up to ``try_num`` times. The loop stops at the first
        attempt whose output parses into a non-empty response and reason.
        Exceptions from the call or the parsing are logged and another attempt
        is made. If no attempt succeeds, the values from the last successful
        parse (possibly empty strings) are stored.

        :param sample: The sample dict; it is updated in place and returned.
        :param rank: Optional worker rank forwarded to ``get_model``.
        :return: The sample with ``rejected_key`` and ``reason_key`` set.
        """
        client = get_model(self.model_key, rank=rank)
        messages = [{
            'role': 'system',
            'content': self.system_prompt
        }, {
            'role': 'user',
            'content': self.build_input(sample)
        }]
        parsed_rejected, parsed_reason = '', ''
        for _ in range(self.try_num):
            try:
                output = client(messages, **self.sampling_params)
                parsed_rejected, parsed_reason = self.parse_output(output)
                if parsed_rejected and parsed_reason:
                    break
            except Exception as e:
                # Best-effort boundary around a remote call: log and retry.
                logger.warning(f'Exception: {e}')
        else:
            # No attempt produced both fields; make the failure visible
            # instead of silently storing empty/partial results.
            logger.warning(
                f'Failed to obtain a parsable rejected response and reason '
                f'after {self.try_num} attempt(s).')
        sample[self.rejected_key] = parsed_rejected
        sample[self.reason_key] = parsed_reason
        return sample