import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

# All tensors are created on (or moved to) this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### General Utils ###

def boolcheck(x):
    # Interpret common truthy strings ("true", "1", "yes") as True.
    return str(x).lower() in ["true", "1", "yes"]

def set_tensor(xs):
    # Cast to float32 and move to the global device.
    return xs.float().to(DEVICE)

def edge_zero_pad(img, d):
    # Zero-pad an [N, C, H, W] image batch by d pixels on every edge.
    N, C, h, w = img.shape
    x = torch.zeros((N, C, h + (d * 2), w + (d * 2))).to(DEVICE)
    x[:, :, d:h + d, d:w + d] = img
    return x

def accuracy(out, L):
    # Fraction of rows where the prediction's argmax matches the label's;
    # out and L are both [B, n_classes].
    B, _ = out.shape
    total = 0
    for i in range(B):
        if torch.argmax(out[i, :]) == torch.argmax(L[i, :]):
            total += 1
    return total / B

def sequence_accuracy(model, target_batch):
    # Per-token accuracy over a sequence; target_batch is a list of
    # [vocab, B] one-hot tensors and model.mu_y holds the predictions.
    correct = 0
    L = len(target_batch)
    _, B = target_batch[0].shape
    s = ""
    for i in range(L):  # loop over the sequence length
        # Log predicted vs. target indices for the first batch element.
        s += str(torch.argmax(model.mu_y[i][:, 0]).item()) + " " + str(torch.argmax(target_batch[i][:, 0]).item()) + " "
        for b in range(B):
            if torch.argmax(target_batch[i][:, b]) == torch.argmax(model.mu_y[i][:, b]):
                correct += 1
    print("accs: ", s)
    return correct / (L * B)

def custom_onehot(idx, shape):
    # One-hot tensor of the given shape with a 1 at index idx.
    ret = set_tensor(torch.zeros(shape))
    ret[idx] = 1
    return ret

def onehot(arr, vocab_size):
    # Convert an [L, B] array of indices into [L, vocab_size, B] one-hots.
    L, B = arr.shape
    ret = np.zeros([L, vocab_size, B])
    for l in range(L):
        for b in range(B):
            ret[l, int(arr[l, b]), b] = 1
    return ret

def inverse_list_onehot(arr):
    # Invert a list of [V, B] one-hot tensors into an [L, B] index array.
    L = len(arr)
    V, B = arr[0].shape
    ret = np.zeros([L, B])
    for l in range(L):
        for b in range(B):
            for v in range(V):
                if arr[l][v, b] == 1:
                    ret[l, b] = v
    return ret

def decode_ypreds(ypreds):
    # Argmax-decode a list of [V, B] predictions into an [L, B] index array.
    L = len(ypreds)
    V, B = ypreds[0].shape
    ret = np.zeros([L, B])
    for l in range(L):
        for b in range(B):
            ret[l, b] = torch.argmax(ypreds[l][:, b]).item()
    return ret

def inverse_onehot(arr):
    # Invert one-hot encodings, accepting either a list of [V, B] tensors
    # or a single [L, V, B] array.
    if isinstance(arr, list):
        return inverse_list_onehot(arr)
    L, V, B = arr.shape
    ret = np.zeros([L, B])
    for l in range(L):
        for b in range(B):
            for v in range(V):
                if arr[l, v, b] == 1:
                    ret[l, b] = v
    return ret

### Activation Functions ###

def tanh(xs):
    return torch.tanh(xs)

def linear(x):
    return x

def tanh_deriv(xs):
    return 1.0 - torch.tanh(xs) ** 2.0

def linear_deriv(x):
    # Derivative of the identity; a single-element tensor of ones
    # broadcasts against whatever it multiplies.
    return set_tensor(torch.ones((1,)))

def relu(xs):
    return torch.clamp(xs, min=0)

def relu_deriv(xs):
    # 1 where the input is positive, 0 elsewhere.
    rel = relu(xs)
    rel[rel > 0] = 1
    return rel

def softmax(xs):
    # These utilities use [vocab, batch] layouts, so normalize explicitly
    # over the class dimension (the implicit-dim form is deprecated).
    return F.softmax(xs, dim=0)

def sigmoid(xs):
    return torch.sigmoid(xs)

def sigmoid_deriv(xs):
    s = torch.sigmoid(xs)
    return s * (torch.ones_like(xs) - s)

### Loss Functions ###

def mse_loss(out, label):
    return torch.sum((out - label) ** 2)

def mse_deriv(out, label):
    return 2 * (out - label)

ce_loss = nn.CrossEntropyLoss()

def cross_entropy_loss(out, label):
    return ce_loss(out, label)

def my_cross_entropy(out, label):
    # Cross entropy against a one-hot (or soft) label; the 1e-6 guards
    # against log(0).
    return -torch.sum(label * torch.log(out + 1e-6))

def cross_entropy_deriv(out, label):
    # Gradient of cross entropy composed with a softmax output layer.
    return out - label

def parse_loss_function(loss_arg):
    if loss_arg == "mse":
        return mse_loss, mse_deriv
    elif loss_arg == "crossentropy":
        return my_cross_entropy, cross_entropy_deriv
    else:
        raise ValueError("loss argument not expected. Can be one of 'mse' and 'crossentropy'. You inputted " + str(loss_arg))
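
# Usage sketch: loss_fn, loss_deriv = parse_loss_function("mse")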

### Initialization Functions ###

def gaussian_init(W, mean=0.0, std=0.05):
    # Fill W in place from N(mean, std).
    return W.normal_(mean=mean, std=std)

def zeros_init(W):
    return torch.zeros_like(W)

def kaiming_init(W, a=math.sqrt(5), **kwargs):
    return init.kaiming_uniform_(W, a)

def glorot_init(W):
    return init.xavier_normal_(W)

def kaiming_bias_init(b, W, **kwargs):
    # Match PyTorch's default Linear bias init: uniform in
    # [-1/sqrt(fan_in), 1/sqrt(fan_in)] of the associated weight W.
    fan_in, _ = init._calculate_fan_in_and_fan_out(W)
    bound = 1 / math.sqrt(fan_in)
    return init.uniform_(b, -bound, bound)

# The initialization PyTorch uses for LSTMs.
def std_uniform_init(W, hidden_size):
    stdv = 1.0 / math.sqrt(hidden_size)
    return init.uniform_(W, -stdv, stdv)
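
# Minimal usage sketch: round-trip indices through onehot/inverse_onehot and
# exercise the LSTM-style initializer above.
if __name__ == "__main__":
    # One-hot round trip on a random [L, B] index array.
    idxs = np.random.randint(0, 10, size=(5, 3))
    encoded = onehot(idxs, vocab_size=10)  # [L, vocab, B]
    decoded = inverse_onehot(encoded)      # back to [L, B]
    assert (decoded == idxs).all()

    # std_uniform_init fills the tensor in place, uniform in [-stdv, stdv].
    W = set_tensor(torch.empty(4, 4))
    std_uniform_init(W, hidden_size=4)
    print("init range:", W.min().item(), W.max().item())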