simclr.py
import torch
import torch.nn as nn

from AK_SSL.vision.models.modules.heads import SimCLRProjectionHead


class SimCLR(nn.Module):
    """
    SimCLR: A Simple Framework for Contrastive Learning of Visual Representations
    Paper: https://arxiv.org/abs/2002.05709
    Reference implementation: https://github.com/google-research/simclr
    """

    def __init__(
        self,
        backbone: nn.Module,
        feature_size: int,
        projection_dim: int = 128,
        projection_num_layers: int = 2,
        projection_batch_norm: bool = True,
        **kwargs,
    ):
        """
        Args:
            backbone (nn.Module): Backbone used to extract features.
            feature_size (int): Dimension of the backbone's output features.
            projection_dim (int): Output dimension of the projection head.
            projection_num_layers (int): Number of layers in the projection head.
            projection_batch_norm (bool): Whether to use batch norm in the projection head.
        """
        super().__init__()

        self.feature_size = feature_size
        self.projection_dim = projection_dim
        self.projection_num_layers = projection_num_layers
        self.projection_batch_norm = projection_batch_norm

        self.backbone = backbone
        # MLP projection head that maps backbone features into the space
        # where the contrastive loss is applied.
        self.projection_head = SimCLRProjectionHead(
            input_dim=self.feature_size,
            hidden_dim=self.feature_size,
            output_dim=self.projection_dim,
            num_layers=self.projection_num_layers,
            batch_norm=self.projection_batch_norm,
        )
        self.encoder = nn.Sequential(self.backbone, self.projection_head)

    def forward(self, x0: torch.Tensor, x1: torch.Tensor | None = None):
        """
        Project one or two augmented views. Given a single view, returns its
        projection; given two views, returns both projections for use in the
        contrastive loss.
        """
        # Flatten backbone features to (batch, feature_size) before projecting.
        f0 = self.backbone(x0).flatten(start_dim=1)
        out0 = self.projection_head(f0)
        if x1 is None:
            return out0
        f1 = self.backbone(x1).flatten(start_dim=1)
        out1 = self.projection_head(f1)
        return out0, out1
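

# Minimal usage sketch. Assumptions (not taken from this module): a
# torchvision ResNet-18 backbone with its classifier replaced by nn.Identity,
# so the backbone emits 512-d features; batch size and image shape are
# illustrative only.
if __name__ == "__main__":
    import torchvision

    backbone = torchvision.models.resnet18()
    backbone.fc = nn.Identity()  # strip the classification head

    model = SimCLR(backbone, feature_size=512, projection_dim=128)

    # Two augmented views of the same batch of images.
    x0 = torch.randn(4, 3, 224, 224)
    x1 = torch.randn(4, 3, 224, 224)

    z0, z1 = model(x0, x1)  # projections fed to the contrastive (NT-Xent) loss
    print(z0.shape, z1.shape)  # torch.Size([4, 128]) torch.Size([4, 128])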