forked from alec-tschantz/predcoding
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate.py
117 lines (91 loc) · 3.65 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# pylint: disable=not-callable
# pylint: disable=no-member
import numpy as np
import torch
import mnist_utils
import functions as F
from network import PredictiveCodingNetwork
class AttrDict(dict):
    """A ``dict`` whose keys can also be read, written and deleted as attributes.

    ``cf.lr = 1e-3`` is equivalent to ``cf["lr"] = 1e-3`` and ``cf.lr`` to
    ``cf["lr"]``. A missing key raises ``AttributeError`` on attribute access
    (not ``KeyError``). This keeps ``hasattr``, ``getattr(obj, name, default)``
    and ``copy.deepcopy`` working, since all of them rely on the
    ``AttributeError`` protocol.
    """

    __setattr__ = dict.__setitem__

    def __getattr__(self, name):
        # Only invoked when normal lookup fails, so real dict methods
        # (keys, items, ...) and dunders are resolved as usual.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

    def __delattr__(self, name):
        # Mirror __setattr__: attributes live in the dict items.
        try:
            del self[name]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
def _load_data(cf):
    """Load the MNIST splits, optionally truncated to ``cf.data_size`` samples.

    Returns:
        ``(img_train, img_test, label_train, label_test)`` as produced by
        ``mnist_utils``. Samples are indexed along axis 1 (the axis sliced
        here). When ``cf.data_size`` is set, the training split keeps its
        first ``cf.data_size`` samples and the test split its first
        ``cf.data_size // 5``.
    """
    train_set = mnist_utils.get_mnist_train_set()
    test_set = mnist_utils.get_mnist_test_set()

    img_train = mnist_utils.get_imgs(train_set)
    img_test = mnist_utils.get_imgs(test_set)
    label_train = mnist_utils.get_labels(train_set)
    label_test = mnist_utils.get_labels(test_set)

    if cf.data_size is not None:
        test_size = cf.data_size // 5
        img_train = img_train[:, 0 : cf.data_size]
        label_train = label_train[:, 0 : cf.data_size]
        img_test = img_test[:, 0:test_size]
        label_test = label_test[:, 0:test_size]

    return img_train, img_test, label_train, label_test


def _preprocess(cf, img_train, img_test, label_train, label_test):
    """Apply the optional scaling / inverse-activation transforms selected in ``cf``.

    Returns the four arrays in the same order they were passed in.
    """
    if cf.apply_scaling:
        img_train = mnist_utils.scale_imgs(img_train, cf.img_scale)
        img_test = mnist_utils.scale_imgs(img_test, cf.img_scale)
        label_train = mnist_utils.scale_labels(label_train, cf.label_scale)
        label_test = mnist_utils.scale_labels(label_test, cf.label_scale)

    # Images are mapped through the inverse activation, except for ReLU
    # (presumably because ReLU has no inverse -- confirm in functions.py).
    if cf.apply_inv and cf.act_fn != F.RELU:
        img_train = F.f_inv(img_train, cf.act_fn)
        img_test = F.f_inv(img_test, cf.act_fn)

    return img_train, img_test, label_train, label_test


def main(cf):
    """Train a PredictiveCodingNetwork on MNIST and save generated images each epoch.

    Each epoch:
      1. Trains on batches of the training split.
      2. Generates images conditioned on the first batch of test labels and
         saves them to ``cf.img_path.format(epoch)``.
      3. Shuffles the training samples (axis 1) for the next epoch.

    Args:
        cf: attribute-style config (e.g. ``AttrDict``). It is read here for
            data, preprocessing, seeding, batching and output-path settings,
            and is also passed whole to ``PredictiveCodingNetwork``.
    """
    import os

    print(f"device [{cf.device}]")

    # Seed once, before any random draws, so the whole run is reproducible.
    # This covers the per-epoch shuffle below and any numpy/torch randomness
    # inside mnist_utils or the model (not visible here). Reseeding right
    # before each shuffle would instead replay the same draws every epoch.
    np.random.seed(cf.seed)
    torch.manual_seed(cf.seed)

    print("loading MNIST data...")
    img_train, img_test, label_train, label_test = _load_data(cf)
    msg = "img_train {} img_test {} label_train {} label_test {}"
    print(msg.format(img_train.shape, img_test.shape, label_train.shape, label_test.shape))

    print("performing preprocessing...")
    img_train, img_test, label_train, label_test = _preprocess(
        cf, img_train, img_test, label_train, label_test
    )

    # Make sure the directory for the generated images exists before plotting.
    img_dir = os.path.dirname(cf.img_path.format(0))
    if img_dir:
        os.makedirs(img_dir, exist_ok=True)

    model = PredictiveCodingNetwork(cf)
    # NOTE(review): training runs without autograd; the network's updates
    # presumably don't rely on it -- confirm in network.py.
    with torch.no_grad():
        for epoch in range(cf.n_epochs):
            print(f"\nepoch {epoch}")

            img_batches, label_batches = mnist_utils.get_batches(
                img_train, label_train, cf.batch_size, cf.percent_data_used
            )
            print(f"training on {len(img_batches)} batches of size {cf.batch_size}")
            model.train_epoch(label_batches, img_batches, epoch_num=epoch)

            img_batches, label_batches = mnist_utils.get_batches(
                img_test, label_test, cf.batch_size, cf.percent_data_used
            )
            if len(label_batches) == 0:
                # A small data_size / percent_data_used can leave no test batch.
                print("no test batches available; skipping image generation")
            else:
                print("generating images...")
                pred_imgs = model.generate_data(label_batches[0])
                mnist_utils.plot_imgs(pred_imgs, cf.img_path.format(epoch))

            # Shuffle samples (axis 1) so the next epoch sees a new order.
            perm = np.random.permutation(img_train.shape[1])
            img_train = img_train[:, perm]
            label_train = label_train[:, perm]
# calculate the inception score for p(y|x)
def calculate_inception_score(p_yx, eps=1E-16):
    """Compute the Inception Score from per-image class probabilities.

    IS = exp( mean over images of KL( p(y|x) || p(y) ) ), where the marginal
    p(y) is the mean of p(y|x) over images.

    Args:
        p_yx: array-like of shape (n_images, n_classes). Each row is the
            predicted class distribution for one image. Lists and other
            array-likes are accepted; they are converted with ``np.asarray``.
        eps: small constant added inside the logs to avoid log(0).

    Returns:
        The inception score as a NumPy float.

    Raises:
        ValueError: if ``p_yx`` is not 2-D or contains no images (the mean
            over zero images would otherwise silently produce ``nan``).
    """
    p_yx = np.asarray(p_yx, dtype=float)
    if p_yx.ndim != 2 or p_yx.shape[0] == 0:
        raise ValueError(
            f"p_yx must be a non-empty 2-D array (n_images, n_classes), got shape {p_yx.shape}"
        )
    p_y = p_yx.mean(axis=0, keepdims=True)  # marginal p(y), shape (1, n_classes)
    kl_d = p_yx * (np.log(p_yx + eps) - np.log(p_y + eps))  # per-image, per-class KL terms
    avg_kl_d = kl_d.sum(axis=1).mean()  # sum over classes, average over images
    return np.exp(avg_kl_d)  # undo the logs
if __name__ == "__main__":
    # Default hyper-parameters for a single run, passed to main().
    _neurons = [10, 500, 500, 784]
    _n_layers = len(_neurons)
    cf = AttrDict(
        # output paths
        img_path="imgs/{}.png",
        img_path_og="imgs/{}_og.png",
        # data / training schedule
        seed=20,
        percent_data_used=0.2,
        n_epochs=10,
        data_size=None,
        batch_size=128,
        # preprocessing
        apply_inv=True,
        apply_scaling=True,
        label_scale=0.94,
        img_scale=1.0,
        # network
        neurons=_neurons,
        n_layers=_n_layers,
        act_fn=F.RELU,
        var_out=1,
        vars=torch.ones(_n_layers),
        itr_max=50,
        beta=0.1,
        div=2,
        condition=1e-6,
        d_rate=0,
        # optim parameters
        l_rate=1e-3,
        optim="ADAM",
        eps=1e-8,
        decay_r=0.9,
        beta_1=0.9,
        beta_2=0.999,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )
    main(cf)