NCOD.py
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

cross_entropy_val = nn.CrossEntropyLoss()  # instantiated loss; not used directly in this file
mean = 1e-8             # mean for initialising the per-sample parameter u
std = 1e-9              # std for initialising the per-sample parameter u
encoder_features = 512  # dimensionality of the encoder features passed to the loss
total_epochs = 150      # total training epochs; used to anneal the per-class selection percentage
class NCODLoss(nn.Module):
    def __init__(self, sample_labels, num_examp=50000, num_classes=100,
                 ratio_consistency=0, ratio_balance=0):
        super(NCODLoss, self).__init__()
        self.num_classes = num_classes
        self.USE_CUDA = torch.cuda.is_available()
        self.num_examp = num_examp
        self.ratio_consistency = ratio_consistency
        self.ratio_balance = ratio_balance

        # One learnable scalar u per training example, initialised close to zero.
        self.u = nn.Parameter(torch.empty(num_examp, 1, dtype=torch.float32))
        self.init_param(mean=mean, std=std)

        self.beginning = True
        # Cached encoder features of every sample and the per-class centroids ("master vectors").
        self.prevSimilarity = torch.rand((num_examp, encoder_features))
        self.masterVector = torch.rand((num_classes, encoder_features))

        self.sample_labels = sample_labels
        # Indices of the training samples that carry each class label.
        self.bins = []
        for i in range(0, num_classes):
            self.bins.append(np.where(self.sample_labels == i)[0])

    def init_param(self, mean=1e-8, std=1e-9):
        torch.nn.init.normal_(self.u, mean=mean, std=std)
    def forward(self, index, outputs, label, out, flag, epoch):
        # When two augmented views are concatenated in one batch, split logits and features.
        if len(outputs) > len(index):
            output, output2 = torch.chunk(outputs, 2)
            out1, out2 = torch.chunk(out, 2)
        else:
            output = outputs
            out1 = out

        eps = 1e-4
        u = self.u[index]

        if flag == 0:
            if self.beginning:
                # Percentage of lowest-u samples kept per class, annealed from 100% towards 50%.
                percent = math.ceil((50 - (50 / total_epochs) * epoch) + 50)
                for i in range(0, len(self.bins)):
                    class_u = self.u.detach()[self.bins[i]]
                    bottomK = int((len(class_u) / 100) * percent)
                    # Average the cached features of the bottomK samples with the smallest u
                    # to form the class centroid (master vector).
                    important_indexs = torch.topk(class_u, bottomK, largest=False, dim=0)[1]
                    self.masterVector[i] = torch.mean(
                        self.prevSimilarity[self.bins[i]][important_indexs.view(-1)], dim=0
                    )
                # L2-normalise the centroids and keep the transpose for similarity lookups.
                masterVector_norm = self.masterVector.norm(p=2, dim=1, keepdim=True)
                masterVector_normalized = self.masterVector.div(masterVector_norm)
                self.masterVector_transpose = torch.transpose(masterVector_normalized, 0, 1)
                self.beginning = True

        # Cache the current encoder features for the next centroid update.
        self.prevSimilarity[index] = out1.detach()

        prediction = F.softmax(output, dim=1)

        # Cosine similarity between normalised sample features and the class centroids,
        # restricted to the labelled class and clipped at zero.
        out_norm = out1.detach().norm(p=2, dim=1, keepdim=True)
        out_normalized = out1.detach().div(out_norm)
        similarity = torch.mm(out_normalized, self.masterVector_transpose)
        similarity = similarity * label
        sim_mask = (similarity > 0.000).type(torch.float32)
        similarity = similarity * sim_mask

        # Similarity-weighted cross-entropy on the prediction shifted by the (detached) u.
        u = u * label
        prediction = torch.clamp((prediction + u.detach()), min=eps, max=1.0)
        loss = torch.mean(-torch.sum((similarity) * torch.log(prediction), dim=1))

        # MSE term that trains u: hard pseudo-label plus u should reproduce the given label.
        label_one_hot = self.soft_to_hard(output.detach())
        MSE_loss = F.mse_loss((label_one_hot + u), label, reduction="sum") / len(label)
        loss += MSE_loss

        if self.ratio_balance > 0:
            # Class-balance regulariser pushing the average batch prediction towards a uniform prior.
            avg_prediction = torch.mean(prediction, dim=0)
            prior_distr = 1.0 / self.num_classes * torch.ones_like(avg_prediction)
            avg_prediction = torch.clamp(avg_prediction, min=eps, max=1.0)
            balance_kl = torch.mean(-(prior_distr * torch.log(avg_prediction)).sum(dim=0))
            loss += self.ratio_balance * balance_kl

        if (len(outputs) > len(index)) and (self.ratio_consistency > 0):
            # Consistency regulariser between the two augmented views.
            consistency_loss = self.consistency_loss(output, output2)
            loss += self.ratio_consistency * torch.mean(consistency_loss)

        return loss
    def consistency_loss(self, output1, output2):
        # Per-sample KL divergence between the two views' predictions,
        # with the first view's (detached) softmax as the target.
        preds1 = F.softmax(output1, dim=1).detach()
        preds2 = F.log_softmax(output2, dim=1)
        loss_kldiv = F.kl_div(preds2, preds1, reduction="none")
        loss_kldiv = torch.sum(loss_kldiv, dim=1)
        return loss_kldiv

    def soft_to_hard(self, x):
        # One-hot encode the argmax class of each row, on the same device as x.
        with torch.no_grad():
            return torch.zeros(len(x), self.num_classes, device=x.device).scatter_(
                1, x.argmax(dim=1).view(-1, 1), 1
            )
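

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The dataset size, class count,
# batch construction, and the flag/epoch values below are assumptions made
# for demonstration; the real training script supplies its own encoder
# features, classifier logits, and one-hot labels.
if __name__ == "__main__":
    num_examp, num_classes, batch = 1000, 10, 8
    # Balanced toy label assignment so every class bin is non-empty.
    sample_labels = np.repeat(np.arange(num_classes), num_examp // num_classes)
    criterion = NCODLoss(sample_labels, num_examp=num_examp, num_classes=num_classes)

    index = torch.randint(0, num_examp, (batch,))        # sample indices in this batch
    outputs = torch.randn(batch, num_classes)            # classifier logits
    features = torch.randn(batch, encoder_features)      # encoder features
    targets = torch.from_numpy(sample_labels[index.numpy()]).long()
    label = F.one_hot(targets, num_classes).float()      # one-hot labels

    # flag=0 recomputes the per-class master vectors; epoch anneals the selection percentage.
    loss = criterion(index, outputs, label, features, flag=0, epoch=0)
    print(loss.item())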