-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmulti_resolution_segmentation.py
194 lines (171 loc) · 8.55 KB
/
multi_resolution_segmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import torch
from torch import nn, optim
from torch.utils.data.dataloader import DataLoader
from tqdm import trange
from torchvision.datasets import ImageFolder
from torch.cuda.amp import GradScaler
from efficientnet_pytorch.model import EfficientNet
from segmentation import num_channels, cam_resolution
from utils import eval_segmentation, to_device, train_or_eval, transform, train_transform
import argparse
class ClassActivationMap2d(nn.Module):
    """Generate a class activation map from an input tensor (format: N * C * H * W).

    A single 1x1 convolution is shared by two modes of operation:

    * ``return_cam=True``: the convolution is applied at every spatial
      position and the result is optionally upsampled to ``output_res``.
    * ``return_cam=False``: the input is average-pooled with a kernel of
      ``input_res`` first, then passed through the same convolution.
    """

    def __init__(self, in_channels: int, out_channels: int, input_res,
                 output_res=None, upsampling_mode: str = 'nearest'):
        """Build the layer.

        Args:
            in_channels: Number of channels of the incoming feature map.
            out_channels: Number of channels (classes) of the produced map.
            input_res: Kernel size handed to ``nn.AvgPool2d``; the source
                comments intend it to equal the spatial size of the input so
                that pooling yields a 1x1 map.
            output_res: Target size for ``nn.Upsample``. A falsy value
                (the default ``None``) disables upsampling.
            upsampling_mode: Interpolation mode forwarded to ``nn.Upsample``.
        """
        super().__init__()
        # Linear (1x1 convolution) layer to transform input feature maps into CAMs
        self.linear = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        # Average pooling, used only when scores (not a CAM) are requested
        self.avgpool = nn.AvgPool2d(input_res)
        # Optional upsampling of the CAM to a fixed size
        self.upsample = (
            nn.Upsample(size=output_res, mode=upsampling_mode)
            if output_res else None
        )

    def forward(self, inputs, return_cam: bool = False):
        """Return either the (optionally upsampled) CAM or the pooled scores.

        Args:
            inputs: Feature map tensor in N * C * H * W layout.
            return_cam: When true, return the per-position activation map;
                otherwise return the convolution of the average-pooled input.
        """
        if not return_cam:
            # Pool first, then apply the linear transformation
            return self.linear(self.avgpool(inputs))

        cam = self.linear(inputs)
        # Explicit None check: do not rely on the truthiness of an nn.Module
        if self.upsample is not None:
            cam = self.upsample(cam)
        return cam
class MultiResolutionSegmentation(nn.Module):
    """Generate a class activation map by combining multiple CAMs at different resolutions.

    An EfficientNet backbone provides hidden states at several endpoints.
    Each selected endpoint gets its own small segmentation branch
    (3x3 convolution + ReLU) followed by a ``ClassActivationMap2d`` layer;
    the per-endpoint outputs are summed.

    Recognised keyword arguments (all optional, stored verbatim in
    ``constructor_params`` so they can be written to a save file):

    * ``backbone``: EfficientNet variant name (default ``'efficientnet-b0'``).
    * ``endpoints``: endpoint indices to attach branches to
      (default ``[1, 2, 3, 4]``).
    * ``pos_class_weight``: loss weight of class 1 (default ``2.0``).
    * ``upsampling_mode``: mode for CAM upsampling (default ``'nearest'``).
    * ``output_res``: size every sub-CAM is upsampled to (default ``112``).
    """

    def __init__(self, from_pretrained=True, **kwargs):
        """Build the backbone, the per-endpoint branches and the loss.

        Args:
            from_pretrained: If true, create the backbone with
                ``EfficientNet.from_pretrained``; otherwise with
                ``EfficientNet.from_name``.
            **kwargs: See the class docstring.
        """
        super().__init__()
        # Keep the keyword arguments as given so to_save_file() can persist them
        self.constructor_params = kwargs
        backbone = kwargs.get('backbone', 'efficientnet-b0')
        endpoints = kwargs.get('endpoints', [1, 2, 3, 4])
        pos_class_weight = kwargs.get('pos_class_weight', 2.0)
        upsampling_mode = kwargs.get('upsampling_mode', 'nearest')
        output_res = kwargs.get('output_res', 112)

        # Use EfficientNet classifier as a backbone
        if from_pretrained:
            self.backbone = EfficientNet.from_pretrained(backbone, num_classes=2)
        else:
            self.backbone = EfficientNet.from_name(backbone, num_classes=2)

        # Create a segmentation branch and CAM layer for each endpoint.
        # Both dicts share the key 'reduction_<endpoint>', which forward()
        # matches against the keys returned by the backbone.
        self.seg_branches = nn.ModuleDict()
        self.activation_map = nn.ModuleDict()
        for endpoint in endpoints:
            key = f'reduction_{endpoint}'
            in_channels = num_channels(endpoint=endpoint)
            self.seg_branches[key] = nn.Sequential(
                nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
                nn.ReLU(),
            )
            in_res = cam_resolution(endpoint=endpoint)
            self.activation_map[key] = ClassActivationMap2d(
                16, 2, in_res, output_res, upsampling_mode)

        # Softmax over channels if outputting CAM
        self.softmax = nn.Softmax(dim=1)
        # Loss function
        self.set_alpha(pos_class_weight)

    def set_alpha(self, pos_class_weight):
        """Replace the loss criterion with a cross-entropy weighted ``[1.0, pos_class_weight]``."""
        class_weights = torch.tensor([1.0, pos_class_weight], dtype=torch.float32)
        self.loss_criterion = to_device(nn.CrossEntropyLoss(weight=class_weights))

    def forward(self, inputs, return_cam=False):
        """Run the backbone and combine the per-endpoint outputs.

        Args:
            inputs: Batch passed to ``backbone.extract_endpoints``.
            return_cam: If true, return the softmax probability of channel 1
                for every pixel of the summed activation map. Otherwise
                return the summed scores with the two trailing dimensions
                squeezed.

        Raises:
            RuntimeError: If none of the backbone endpoints has a
                segmentation branch, so there is nothing to combine.
        """
        # Extract hidden states from EfficientNet endpoints
        hidden_states = self.backbone.extract_endpoints(inputs)

        # Put each hidden state through its segmentation branch (if one
        # exists) and then through the matching CAM layer
        sub_cams = [
            self.activation_map[name](self.seg_branches[name](hidden_state),
                                      return_cam=return_cam)
            for name, hidden_state in hidden_states.items()
            if name in self.seg_branches
        ]
        if not sub_cams:
            raise RuntimeError(
                'No backbone endpoint matches a segmentation branch; '
                f'branches: {list(self.seg_branches.keys())}, '
                f'endpoints: {list(hidden_states.keys())}')

        # Sum all CAMs together
        cam = torch.stack(sub_cams, dim=0).sum(dim=0)

        if return_cam:
            # Softmax channel-wise, keep only the positive class probability
            return self.softmax(cam)[:, 1]
        # Remove the (H, W) dimensions
        return cam.squeeze(-1).squeeze(-1)

    def freeze_backbone(self):
        """Disable gradients for every backbone parameter."""
        self.backbone.requires_grad_(False)

    def to_save_file(self, save_file):
        """Save this model to a save file in .pt format.

        Both the constructor parameters and the state_dict are saved.
        """
        save_json = {
            'constructor_params': self.constructor_params,
            'state_dict': self.state_dict(),
        }
        torch.save(save_json, save_file)

    @classmethod
    def from_save_file(cls, save_file):
        """Create a new model from an existing save file.

        The save file contains both the state_dict and the constructor
        parameters needed to initialize the model correctly. The backbone is
        created without pretrained weights because the state_dict supplies
        them.
        """
        # Load onto the CPU first so that a checkpoint written on a GPU can be
        # restored on a host without one; to_device() moves the model below.
        # NOTE(review): torch.load unpickles data, so only load trusted files.
        save_json = torch.load(save_file, map_location='cpu')
        params = save_json['constructor_params']
        state_dict = save_json['state_dict']

        # Create model according to params and load state_dict
        model = cls(from_pretrained=False, **params)
        model.load_state_dict(state_dict)
        return to_device(model)
def train_multi_segmentation(model: MultiResolutionSegmentation,
                             train_loader: DataLoader,
                             val_loader: DataLoader,
                             optimizer: optim.Optimizer,
                             val_loader_type: str = 'class',
                             scaler: GradScaler = None,
                             num_epochs: int = 3):
    """Train the segmentation branches with a frozen backbone, then validate once.

    Args:
        model: Model to train; its backbone is frozen before training.
        train_loader: Loader passed to ``train_or_eval`` for every epoch.
        val_loader: Loader used for the single validation pass at the end.
        optimizer: Optimizer passed to ``train_or_eval`` during training.
        val_loader_type: ``'class'`` validates with ``train_or_eval``,
            ``'seg'`` validates with ``eval_segmentation``.
        scaler: Optional gradient scaler forwarded to ``train_or_eval``.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        A ``(validation_result, history)`` tuple, where ``history`` holds one
        dict per epoch (the epoch index merged with that epoch's result).
        For any other ``val_loader_type`` the function returns ``None``.
    """
    # Don't train the backbone
    model.freeze_backbone()

    # One history entry per epoch: the epoch index plus whatever
    # train_or_eval reports for that pass over the training data
    history = [
        {'epoch': epoch_idx,
         **train_or_eval(model, train_loader, optimizer, scaler)}
        for epoch_idx in trange(num_epochs, desc='Training')
    ]

    # Evaluate on validation set once at the end
    if val_loader_type == 'seg':
        return eval_segmentation(model, val_loader), history
    if val_loader_type == 'class':
        return train_or_eval(model, val_loader, scaler=scaler), history
    return None
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read the process arguments, so existing
            ``parse_args()`` callers are unaffected.

    Returns:
        An ``argparse.Namespace`` with the attributes ``out``,
        ``pos_class_weight``, ``num_epochs``, ``batch_size``,
        ``mixed_precision``, ``train_dir`` and ``val_dir``.
    """
    parser = argparse.ArgumentParser(description='Train and store the model')
    parser.add_argument('-o', '--out', metavar='model.pt', default='model.pt',
                        help='file the trained model is saved to')
    parser.add_argument('-w', '--pos-class-weight', type=float, default=8.0,
                        help='loss weight of the positive class')
    parser.add_argument('-e', '--num-epochs', type=int, default=3,
                        help='number of training epochs')
    parser.add_argument('-b', '--batch-size', type=int, default=48,
                        help='batch size of the training and validation loaders')
    parser.add_argument('-m', '--mixed-precision', action='store_true',
                        help='train with a gradient scaler (mixed precision)')
    parser.add_argument('--train-dir', default='./SPI_train/',
                        help='root directory of the training image folder')
    parser.add_argument('--val-dir', default='./SPI_val/',
                        help='root directory of the validation image folder')
    return parser.parse_args(argv)
def main():
    """Train a MultiResolutionSegmentation model and save it.

    Reads the command-line options, builds the training and validation
    image-folder loaders, trains the model with RMSprop and writes the result
    to the path given by ``--out``.
    """
    args = parse_args()

    train_set = ImageFolder(root=args.train_dir, transform=train_transform)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4)
    val_set = ImageFolder(root=args.val_dir, transform=transform)
    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=True, num_workers=4)

    model = to_device(MultiResolutionSegmentation(pos_class_weight=args.pos_class_weight))

    # Use RMSProp parameters from the DeepSolar paper (alpha = second moment discount rate)
    # except for learning rate decay and epsilon
    optimizer = optim.RMSprop(model.parameters(), alpha=0.9, momentum=0.9, eps=0.001, lr=1e-3)
    scaler = GradScaler() if args.mixed_precision else None

    # Pass the scaler by keyword: the fifth positional parameter of
    # train_multi_segmentation is val_loader_type, so a positional scaler
    # would disable mixed precision and skip the final validation pass.
    # NOTE(review): the returned (validation result, history) is not used here.
    train_multi_segmentation(model, train_loader, val_loader, optimizer,
                             scaler=scaler, num_epochs=args.num_epochs)
    model.to_save_file(args.out)


if __name__ == '__main__':
    main()