-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy path soft_patterns_test.py
executable file
·130 lines (106 loc) · 4.42 KB
/
soft_patterns_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#!/usr/bin/env python3
"""
Script to evaluate the accuracy of a model.
"""
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from collections import OrderedDict
from soft_patterns import MaxPlusSemiring, LogSpaceMaxTimesSemiring, evaluate_accuracy, SoftPatternClassifier, ProbSemiring, \
soft_pattern_arg_parser, general_arg_parser
from baselines.cnn import PooledCnnClassifier, max_pool_seq, cnn_arg_parser
from baselines.dan import DanClassifier
from baselines.lstm import AveragingRnnClassifier
import sys
import torch
import numpy as np
from torch.nn import LSTM
from data import vocab_from_text, read_embeddings, read_docs, read_labels
from rnn import Rnn
# Positional indices, presumably into (score, start index, end index) records.
# NOTE(review): none of them is referenced in this file; kept for other modules
# that may import them.
SCORE_IDX, START_IDX_IDX, END_IDX_IDX = 0, 1, 2
# TODO: refactor duplicate code with soft_patterns.py
def main(args):
    """Evaluate a saved model on a labelled dataset and print its accuracy.

    Reads documents from ``args.vd`` and labels from ``args.vl``, builds the
    architecture selected by the command-line flags, loads its weights from
    ``args.input_model`` and reports accuracy as computed by
    ``evaluate_accuracy``.

    :param args: parsed command-line namespace (see the ``__main__`` block).
    :return: 0, used as the process exit status.
    """
    print(args)

    dev_vocab = vocab_from_text(args.vd)
    print("Dev vocab size:", len(dev_vocab))
    # The dev vocabulary is handed to read_embeddings, presumably so that only
    # the vectors needed for evaluation are loaded.
    vocab, embeddings, word_dim = read_embeddings(args.embedding_file, dev_vocab)

    # -1 means "do not seed".
    if args.seed != -1:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

    # Padding depends on the architecture: 1 token for DAN/BiLSTM,
    # window_size - 1 for the CNN, and (longest pattern length - 1) for
    # soft patterns.
    pattern_specs = None
    if args.dan or args.bilstm:
        num_padding_tokens = 1
    elif args.cnn:
        num_padding_tokens = args.window_size - 1
    else:
        pattern_specs = _parse_pattern_specs(args.patterns)
        num_padding_tokens = max(pattern_specs) - 1

    dev_input, _ = read_docs(args.vd, vocab, num_padding_tokens=num_padding_tokens)
    dev_labels = read_labels(args.vl)
    dev_data = list(zip(dev_input, dev_labels))
    n = args.num_train_instances
    if n is not None:
        dev_data = dev_data[:n]

    # Counted over *all* labels, before truncation to n instances.
    # NOTE(review): this must equal the class count the checkpoint was trained
    # with, otherwise load_state_dict will presumably fail on a shape mismatch.
    num_classes = len(set(dev_labels))
    print("num_classes:", num_classes)

    model = _build_model(args, pattern_specs, num_classes, embeddings, vocab, word_dim)
    model.load_state_dict(_load_checkpoint(args.input_model, args.gpu))
    if args.gpu:
        # NOTE(review): `to_cuda` is an attribute of the model classes, defined
        # outside this file; presumably it moves its argument to the GPU.
        model.to_cuda(model)

    test_acc = evaluate_accuracy(model, dev_data, args.batch_size, args.gpu)
    print("Test accuracy: {:>8,.3f}%".format(100 * test_acc))
    return 0


def _parse_pattern_specs(patterns):
    """Parse an underscore-separated list of ``A-B`` integer pairs.

    ``"5-10_3-2"`` becomes ``OrderedDict([(3, 2), (5, 10)])``. Pairs are
    sorted (stably) by their first number, so when a first number repeats the
    later pair's value wins. ``main`` treats the keys as pattern lengths (it
    pads by ``max(key) - 1``). Presumably the values are the number of
    patterns of each length; confirm against soft_patterns.py.

    :raises ValueError: if a piece is not exactly two integers.
    """
    pairs = (tuple(int(part) for part in spec.split("-")) for spec in patterns.split("_"))
    return OrderedDict(sorted(pairs, key=lambda pair: pair[0]))


def _build_model(args, pattern_specs, num_classes, embeddings, vocab, word_dim):
    """Construct the untrained classifier selected by the command-line flags.

    Precedence: --dan, then --bilstm, then --cnn. With none of them set, a
    SoftPatternClassifier over ``pattern_specs`` is built.
    """
    mlp_hidden_dim = args.mlp_hidden_dim
    num_mlp_layers = args.num_mlp_layers

    if args.dan:
        return DanClassifier(mlp_hidden_dim,
                             num_mlp_layers,
                             num_classes,
                             embeddings,
                             args.gpu)
    if args.bilstm:
        return AveragingRnnClassifier(args.hidden_dim,
                                      mlp_hidden_dim,
                                      num_mlp_layers,
                                      num_classes,
                                      embeddings,
                                      cell_type=LSTM,
                                      gpu=args.gpu)
    if args.cnn:
        return PooledCnnClassifier(args.window_size,
                                   args.num_cnn_layers,
                                   args.cnn_hidden_dim,
                                   num_mlp_layers,
                                   mlp_hidden_dim,
                                   num_classes,
                                   embeddings,
                                   pooling=max_pool_seq,
                                   gpu=args.gpu)

    if args.maxplus:
        semiring = MaxPlusSemiring
    elif args.maxtimes:
        semiring = LogSpaceMaxTimesSemiring
    else:
        semiring = ProbSemiring

    rnn = Rnn(word_dim, args.hidden_dim, cell_type=LSTM, gpu=args.gpu) if args.use_rnn else None

    return SoftPatternClassifier(pattern_specs, mlp_hidden_dim, num_mlp_layers, num_classes, embeddings, vocab,
                                 semiring, args.bias_scale_param, args.gpu, rnn, None, args.no_sl, args.shared_sl,
                                 args.no_eps, args.eps_scale, args.self_loop_scale)


def _load_checkpoint(path, gpu):
    """Load a saved state dict from ``path``.

    When not running on the GPU, every storage is kept on the CPU, so
    checkpoints saved from GPU tensors can still be loaded.
    """
    if gpu:
        return torch.load(path)
    return torch.load(path, map_location=lambda storage, loc: storage)
if __name__ == '__main__':
    parser = ArgumentParser(description=__doc__,
                            formatter_class=ArgumentDefaultsHelpFormatter,
                            parents=[soft_pattern_arg_parser(), cnn_arg_parser(), general_arg_parser()])
    # At most one baseline may be selected. With none, the soft-pattern model
    # is evaluated. main() checks --dan, then --bilstm, then --cnn, so a
    # combination would silently pick one; argparse rejects it instead.
    model_choice = parser.add_mutually_exclusive_group()
    model_choice.add_argument("--dan", help="Dan classifier", action='store_true')
    model_choice.add_argument("--cnn", help="CNN classifier", action='store_true')
    model_choice.add_argument("--bilstm", help="BiLSTM classifier", action='store_true')
    sys.exit(main(parser.parse_args()))