-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgenerate_synthetic_dataset.py
124 lines (110 loc) · 7.57 KB
/
generate_synthetic_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import re

import numpy as np
import torch
from tqdm import tqdm

from query_llm import QueryLLM
from utils import *
def evaluate_multiple_choice_answers(args, i, curr_new_question, correct_answer=None, correct_answer_baseline=None):
pattern_a = r"\(a\)\s*(.*?)\s*(?=\(b\))"
choice_a = re.findall(pattern_a, curr_new_question)
if choice_a[0][-1] == '.':
choice_a[0] = choice_a[0][:-1]
choice_a[0] = '(a) ' + choice_a[0]
if args['datasets']['linda_problem_variant'] == 'variant_six':
assert correct_answer is not None and correct_answer_baseline is not None
# it has three options
pattern_b = r"\(b\)\s*(.*)\s*(?=\(c\))"
pattern_c = r"\(c\)\s*(.*)"
choice_b = re.findall(pattern_b, curr_new_question)
choice_c = re.findall(pattern_c, curr_new_question)
if choice_b[0][-1] == '.':
choice_b[0] = choice_b[0][:-1]
if choice_c[0][-1] == '.':
choice_c[0] = choice_c[0][:-1]
choice_b[0] = '(b) ' + choice_b[0]
choice_c[0] = '(c) ' + choice_c[0]
if i == 0: # golden
target_answer = choice_a[0] if correct_answer == 0 else choice_b[0] if correct_answer == 1 else choice_c[0]
incorrect_answer = [choice_b[0], choice_c[0]] if correct_answer == 0 else [choice_a[0], choice_c[0]] if correct_answer == 1 else [choice_a[0], choice_b[0]]
else: # baseline
target_answer = choice_a[0] if correct_answer_baseline == 0 else choice_b[0] if correct_answer_baseline == 1 else choice_c[0]
incorrect_answer = [choice_b[0], choice_c[0]] if correct_answer_baseline == 0 else [choice_a[0], choice_c[0]] if correct_answer_baseline == 1 else [choice_a[0], choice_b[0]]
else:
# two options
pattern_b = r"\(b\)\s*(.*)"
choice_b = re.findall(pattern_b, curr_new_question)
if choice_b[0][-1] == '.':
choice_b[0] = choice_b[0][:-1]
choice_b[0] = '(b) ' + choice_b[0]
target_answer = choice_a[0] if len(choice_a[0]) < len(choice_b[0]) else choice_b[0]
incorrect_answer = choice_a[0] if target_answer == choice_b[0] else choice_b[0]
return target_answer, incorrect_answer
def data_generation(device, args):
    """Generate synthetic logical-fallacy questions with an LLM and save them.

    For each of ``num_synthetic_examples`` iterations, queries the LLM once
    (step='generate_data'), which returns one or more question variants
    depending on ``fallacy_type`` and ``generate_mode``:

      * 'linda' - a gold question plus a random/baseline rewrite (variant_six
                  also returns the correct option index for each version);
      * 'sets'  - gold/control questions plus their framing counterparts;
      * 'math'  - three randomized questions together with their answers.

    Each variant is paired with its target and incorrect answers and appended
    to the synthetic-data JSON file via ``write_response_to_json``.

    Args:
        device: unused here; kept so existing callers that pass a device keep
            working.
        args: nested config dict; reads the 'datasets', 'models' and
            'inference' sections.

    Raises:
        ValueError: if args['datasets']['fallacy_type'] is not one of
            'linda', 'sets', or 'math'.
    """
    synthetic_data_filename = args['datasets']['synthetic_data_filename']
    fallacy_type = args['datasets']['fallacy_type']
    linda_problem_variant = args['datasets']['linda_problem_variant']
    generate_mode = args['datasets']['generate_mode']
    logical_connector = args['datasets']['connector']
    llm_model = args['models']['llm_model']
    verbose = args['inference']['verbose']
    LLM = QueryLLM(args)

    with torch.no_grad():  # the LLM may run locally; inference never needs gradients
        ########### In-Context Learning ###########
        for n in tqdm(range(args['datasets']['num_synthetic_examples'])):
            if fallacy_type == 'linda':
                if generate_mode == 'baseline':
                    new_question_baseline = LLM.query_llm(llm_model=llm_model, step='generate_data', verbose=verbose)
                    new_questions = [new_question_baseline]
                elif linda_problem_variant == 'variant_six':
                    # variant_six additionally returns the correct option index
                    # for both the gold and the baseline question.
                    new_question_gold, new_question_baseline, correct_answer, correct_answer_baseline = LLM.query_llm(llm_model=llm_model, step='generate_data', verbose=verbose)
                    new_questions = [new_question_gold, new_question_baseline]
                else:
                    new_question_gold, new_question_random = LLM.query_llm(llm_model=llm_model, step='generate_data', verbose=verbose)
                    new_questions = [new_question_gold, new_question_random]
            elif fallacy_type == 'sets':
                new_question_gold, new_question_control, new_question_framing_gold, new_question_framing_control = LLM.query_llm(llm_model=llm_model, step='generate_data', verbose=verbose)
                new_questions = [new_question_gold, new_question_control, new_question_framing_gold, new_question_framing_control]
            elif fallacy_type == 'math':
                new_question_random_animal, new_question_random_number, new_question_random, target_random_animal, target_random_number, target_random = LLM.query_llm(llm_model=llm_model, step='generate_data', verbose=verbose)
                new_questions = [new_question_random_animal, new_question_random_number, new_question_random]
            else:
                raise ValueError(f"Invalid fallacy type: {fallacy_type}")

            for i, curr_new_question in enumerate(new_questions):
                # Determine the target (correct) and incorrect answers for this variant.
                if fallacy_type == 'linda':
                    if linda_problem_variant == 'variant_six':
                        target_answer, incorrect_answer = evaluate_multiple_choice_answers(args, i, curr_new_question, correct_answer=correct_answer, correct_answer_baseline=correct_answer_baseline)
                    else:
                        target_answer, incorrect_answer = evaluate_multiple_choice_answers(args, i, curr_new_question)
                elif fallacy_type == 'sets':
                    target_answer, incorrect_answer = '[No]', '[Yes]'
                else:  # 'math' — any other value was rejected above
                    if i == 0:
                        target_answer, incorrect_answer = target_random_animal, "Any value other than " + target_random_animal
                    elif i == 1:
                        # Fixed copy-paste slip: this question's answer is a number,
                        # so the incorrect answer is any other *value*, not any
                        # other animal.
                        target_answer, incorrect_answer = target_random_number, "Any value other than " + target_random_number
                    else:
                        target_answer, incorrect_answer = target_random, "Any value other than " + target_random

                ########### Record New Data Entry ###########
                if generate_mode == 'baseline':
                    generation_mode = 'baseline'
                elif fallacy_type == 'math':
                    generation_mode = 'animals_random' if i == 0 else 'numbers_random' if i == 1 else 'random'
                else:
                    # Even indices hold gold questions, odd indices the random rewrites.
                    generation_mode = 'gold' if i % 2 == 0 else 'random'

                # For 'sets', questions beyond the first two are the framing counterparts.
                framing = 'framing' if fallacy_type == 'sets' and i > 1 else None
                response_dict = {'question_idx': n, 'question': curr_new_question, 'target_answer': target_answer, 'incorrect_answer': incorrect_answer, 'generation_mode': generation_mode}
                write_response_to_json(n, response_dict, synthetic_data_filename, fallacy_type=fallacy_type, framing=framing,
                                       generation_mode=generation_mode, linda_problem_variant=linda_problem_variant, logical_connector=logical_connector)