diff --git a/docs/encoders.rst b/docs/encoders.rst
index 0da221f9..2de35dec 100644
--- a/docs/encoders.rst
+++ b/docs/encoders.rst
@@ -99,14 +99,14 @@ densenet201 imagenet 18M
 densenet161              imagenet                              26M    ✅    ✅    ✅
 inceptionresnetv2        imagenet / imagenet+background        54M    ✅    ✅    ✅
 inceptionv4              imagenet / imagenet+background        41M    ✅    ✅    ✅
-efficientnet-b0          imagenet / advprop                    4M     ❌    ✅    ✅
-efficientnet-b1          imagenet / advprop                    6M     ❌    ✅    ✅
-efficientnet-b2          imagenet / advprop                    7M     ❌    ✅    ✅
-efficientnet-b3          imagenet / advprop                    10M    ❌    ✅    ✅
-efficientnet-b4          imagenet / advprop                    17M    ❌    ✅    ✅
-efficientnet-b5          imagenet / advprop                    28M    ❌    ✅    ✅
-efficientnet-b6          imagenet / advprop                    40M    ❌    ✅    ✅
-efficientnet-b7          imagenet / advprop                    63M    ❌    ✅    ✅
+efficientnet-b0          imagenet / advprop                    4M     ✅    ✅    ✅
+efficientnet-b1          imagenet / advprop                    6M     ✅    ✅    ✅
+efficientnet-b2          imagenet / advprop                    7M     ✅    ✅    ✅
+efficientnet-b3          imagenet / advprop                    10M    ✅    ✅    ✅
+efficientnet-b4          imagenet / advprop                    17M    ✅    ✅    ✅
+efficientnet-b5          imagenet / advprop                    28M    ✅    ✅    ✅
+efficientnet-b6          imagenet / advprop                    40M    ✅    ✅    ✅
+efficientnet-b7          imagenet / advprop                    63M    ✅    ✅    ✅
 mobilenet_v2             imagenet                              2M     ✅    ✅    ✅
 xception                 imagenet                              20M    ✅    ✅    ✅
 timm-efficientnet-b0     imagenet / advprop / noisy-student    4M     ✅    ✅    ✅
diff --git a/segmentation_models_pytorch/encoders/_efficientnet.py b/segmentation_models_pytorch/encoders/_efficientnet.py
index dcc4e268..95f91631 100644
--- a/segmentation_models_pytorch/encoders/_efficientnet.py
+++ b/segmentation_models_pytorch/encoders/_efficientnet.py
@@ -13,6 +13,44 @@
 import math
 import collections
 from functools import partial
+from typing import List, Optional
+
+# Parameters for the entire model (stem, all blocks, and head)
+GlobalParams = collections.namedtuple(
+    "GlobalParams",
+    [
+        "width_coefficient",
+        "depth_coefficient",
+        "image_size",
+        "dropout_rate",
+        "num_classes",
+        "batch_norm_momentum",
+        "batch_norm_epsilon",
+        "drop_connect_rate",
+        "depth_divisor",
+        "min_depth",
+        "include_top",
+    ],
+)
+
+# Parameters for an individual model block
+BlockArgs = collections.namedtuple(
+    "BlockArgs",
+    [
+        "num_repeat",
+        "kernel_size",
+        "stride",
+        "expand_ratio",
+        "input_filters",
+        "output_filters",
+        "se_ratio",
+        "id_skip",
+    ],
+)
+
+# Set GlobalParams and BlockArgs's defaults
+GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
+BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
 
 
 class MBConvBlock(nn.Module):
@@ -29,77 +67,94 @@ class MBConvBlock(nn.Module):
     [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
     """
 
-    def __init__(self, block_args, global_params, image_size=None):
+    def __init__(
+        self, block_args: BlockArgs, global_params: GlobalParams, image_size=None
+    ):
         super().__init__()
-        self._block_args = block_args
-        self._bn_mom = (
-            1 - global_params.batch_norm_momentum
-        )  # pytorch's difference from tensorflow
-        self._bn_eps = global_params.batch_norm_epsilon
-        self.has_se = (self._block_args.se_ratio is not None) and (
-            0 < self._block_args.se_ratio <= 1
-        )
-        self.id_skip = (
+
+        self._has_expansion = block_args.expand_ratio != 1
+        self._has_se = block_args.se_ratio is not None and 0 < block_args.se_ratio <= 1
+        self._has_drop_connect = (
             block_args.id_skip
-        )  # whether to use skip connection and drop connect
+            and block_args.stride == 1
+            and block_args.input_filters == block_args.output_filters
+        )
+
+        # Pytorch's difference from tensorflow
+        bn_momentum = 1 - global_params.batch_norm_momentum
+        bn_eps = global_params.batch_norm_epsilon
 
         # Expansion phase (Inverted Bottleneck)
-        inp = self._block_args.input_filters  # number of input channels
-        oup = (
-            self._block_args.input_filters * self._block_args.expand_ratio
-        )  # number of output channels
-        if self._block_args.expand_ratio != 1:
+        input_channels = block_args.input_filters
+        expanded_channels = input_channels * block_args.expand_ratio
+
+        if self._has_expansion:
             Conv2d = get_same_padding_conv2d(image_size=image_size)
             self._expand_conv = Conv2d(
-                in_channels=inp, out_channels=oup, kernel_size=1, bias=False
+                input_channels, expanded_channels, kernel_size=1, bias=False
             )
             self._bn0 = nn.BatchNorm2d(
-                num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
+                expanded_channels,
+                momentum=bn_momentum,
+                eps=bn_eps,
             )
-            # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size
+        else:
+            # for torchscript compatibility
+            self._expand_conv = nn.Identity()
+            self._bn0 = nn.Identity()
 
         # Depthwise convolution phase
-        k = self._block_args.kernel_size
-        s = self._block_args.stride
+        kernel_size = block_args.kernel_size
+        stride = block_args.stride
         Conv2d = get_same_padding_conv2d(image_size=image_size)
         self._depthwise_conv = Conv2d(
-            in_channels=oup,
-            out_channels=oup,
-            groups=oup,  # groups makes it depthwise
-            kernel_size=k,
-            stride=s,
+            in_channels=expanded_channels,
+            out_channels=expanded_channels,
+            groups=expanded_channels,  # groups makes it depthwise
+            kernel_size=kernel_size,
+            stride=stride,
             bias=False,
         )
         self._bn1 = nn.BatchNorm2d(
-            num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
+            expanded_channels,
+            momentum=bn_momentum,
+            eps=bn_eps,
         )
-        image_size = calculate_output_image_size(image_size, s)
+        image_size = calculate_output_image_size(image_size, stride)
 
         # Squeeze and Excitation layer, if desired
-        if self.has_se:
+        if self._has_se:
+            squeezed_channels = int(input_channels * block_args.se_ratio)
+            squeezed_channels = max(1, squeezed_channels)
             Conv2d = get_same_padding_conv2d(image_size=(1, 1))
-            num_squeezed_channels = max(
-                1, int(self._block_args.input_filters * self._block_args.se_ratio)
-            )
             self._se_reduce = Conv2d(
-                in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1
+                in_channels=expanded_channels,
+                out_channels=squeezed_channels,
+                kernel_size=1,
             )
             self._se_expand = Conv2d(
-                in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1
+                in_channels=squeezed_channels,
+                out_channels=expanded_channels,
+                kernel_size=1,
             )
 
         # Pointwise convolution phase
-        final_oup = self._block_args.output_filters
+        output_channels = block_args.output_filters
         Conv2d = get_same_padding_conv2d(image_size=image_size)
         self._project_conv = Conv2d(
-            in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False
+            in_channels=expanded_channels,
+            out_channels=output_channels,
+            kernel_size=1,
+            bias=False,
         )
         self._bn2 = nn.BatchNorm2d(
-            num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps
+            num_features=output_channels,
+            momentum=bn_momentum,
+            eps=bn_eps,
         )
         self._swish = nn.SiLU()
 
-    def forward(self, inputs, drop_connect_rate=None):
+    def forward(self, inputs: torch.Tensor, drop_connect_rate: Optional[float] = None):
         """MBConvBlock's forward function.
 
         Args:
@@ -112,7 +167,7 @@ def forward(self, inputs, drop_connect_rate=None):
 
         # Expansion and Depthwise Convolution
         x = inputs
-        if self._block_args.expand_ratio != 1:
+        if self._has_expansion:
             x = self._expand_conv(inputs)
             x = self._bn0(x)
             x = self._swish(x)
@@ -122,7 +177,7 @@
         x = self._swish(x)
 
         # Squeeze and Excitation
-        if self.has_se:
+        if self._has_se:
             x_squeezed = F.adaptive_avg_pool2d(x, 1)
             x_squeezed = self._se_reduce(x_squeezed)
             x_squeezed = self._swish(x_squeezed)
@@ -134,17 +189,9 @@
         x = self._bn2(x)
 
         # Skip connection and drop connect
-        input_filters, output_filters = (
-            self._block_args.input_filters,
-            self._block_args.output_filters,
-        )
-        if (
-            self.id_skip
-            and self._block_args.stride == 1
-            and input_filters == output_filters
-        ):
+        if self._has_drop_connect:
             # The combination of skip connection and drop connect brings about stochastic depth.
-            if drop_connect_rate:
+            if drop_connect_rate is not None and drop_connect_rate > 0:
                 x = drop_connect(x, p=drop_connect_rate, training=self.training)
             x = x + inputs  # skip connection
         return x
@@ -169,10 +216,14 @@ class EfficientNet(nn.Module):
         >>> outputs = model(inputs)
         """
 
-    def __init__(self, blocks_args=None, global_params=None):
+    def __init__(self, blocks_args: List[BlockArgs], global_params: GlobalParams):
         super().__init__()
-        assert isinstance(blocks_args, list), "blocks_args should be a list"
-        assert len(blocks_args) > 0, "block args must be greater than 0"
+
+        if not isinstance(blocks_args, list):
+            raise ValueError("blocks_args should be a list")
+        if len(blocks_args) == 0:
+            raise ValueError("block args must be greater than 0")
+
         self._global_params = global_params
         self._blocks_args = blocks_args
 
@@ -186,20 +237,16 @@ def __init__(self, blocks_args=None, global_params=None):
 
         # Stem
         in_channels = 3  # rgb
-        out_channels = round_filters(
-            32, self._global_params
-        )  # number of output channels
+        out_channels = round_filters(32, self._global_params)
         self._conv_stem = Conv2d(
             in_channels, out_channels, kernel_size=3, stride=2, bias=False
         )
-        self._bn0 = nn.BatchNorm2d(
-            num_features=out_channels, momentum=bn_mom, eps=bn_eps
-        )
+        self._bn0 = nn.BatchNorm2d(out_channels, momentum=bn_mom, eps=bn_eps)
         image_size = calculate_output_image_size(image_size, 2)
 
         # Build blocks
         self._blocks = nn.ModuleList([])
-        for block_args in self._blocks_args:
+        for block_args in blocks_args:
             # Update block input and output filters based on depth multiplier.
             block_args = block_args._replace(
                 input_filters=round_filters(
@@ -243,57 +290,8 @@ def __init__(self, blocks_args=None, global_params=None):
         self._swish = nn.SiLU()
 
-    def extract_endpoints(self, inputs):
-        """Use convolution layer to extract features
-        from reduction levels i in [1, 2, 3, 4, 5].
-
-        Args:
-            inputs (tensor): Input tensor.
-
-        Returns:
-            Dictionary of last intermediate features
-            with reduction levels i in [1, 2, 3, 4, 5].
-
-        Example:
-            >>> import torch
-            >>> from efficientnet.model import EfficientNet
-            >>> inputs = torch.rand(1, 3, 224, 224)
-            >>> model = EfficientNet.from_pretrained('efficientnet-b0')
-            >>> endpoints = model.extract_endpoints(inputs)
-            >>> print(endpoints['reduction_1'].shape)  # torch.Size([1, 16, 112, 112])
-            >>> print(endpoints['reduction_2'].shape)  # torch.Size([1, 24, 56, 56])
-            >>> print(endpoints['reduction_3'].shape)  # torch.Size([1, 40, 28, 28])
-            >>> print(endpoints['reduction_4'].shape)  # torch.Size([1, 112, 14, 14])
-            >>> print(endpoints['reduction_5'].shape)  # torch.Size([1, 320, 7, 7])
-            >>> print(endpoints['reduction_6'].shape)  # torch.Size([1, 1280, 7, 7])
-        """
-        endpoints = dict()
-
-        # Stem
-        x = self._swish(self._bn0(self._conv_stem(inputs)))
-        prev_x = x
-
-        # Blocks
-        for idx, block in enumerate(self._blocks):
-            drop_connect_rate = self._global_params.drop_connect_rate
-            if drop_connect_rate:
-                drop_connect_rate *= float(idx) / len(
-                    self._blocks
-                )  # scale drop connect_rate
-            x = block(x, drop_connect_rate=drop_connect_rate)
-            if prev_x.size(2) > x.size(2):
-                endpoints["reduction_{}".format(len(endpoints) + 1)] = prev_x
-            elif idx == len(self._blocks) - 1:
-                endpoints["reduction_{}".format(len(endpoints) + 1)] = x
-            prev_x = x
-
-        # Head
-        x = self._swish(self._bn1(self._conv_head(x)))
-        endpoints["reduction_{}".format(len(endpoints) + 1)] = x
-
-        return endpoints
-
     def extract_features(self, inputs):
-        """use convolution layer to extract feature .
+        """Use convolution layer to extract feature.
 
         Args:
             inputs (tensor): Input tensor.
@@ -309,9 +307,8 @@
         for idx, block in enumerate(self._blocks):
             drop_connect_rate = self._global_params.drop_connect_rate
             if drop_connect_rate:
-                drop_connect_rate *= float(idx) / len(
-                    self._blocks
-                )  # scale drop connect_rate
+                # scale drop connect_rate
+                drop_connect_rate *= float(idx) / len(self._blocks)
             x = block(x, drop_connect_rate=drop_connect_rate)
 
         # Head
@@ -321,7 +318,7 @@
 
     def forward(self, inputs):
         """EfficientNet's forward function.
-        Calls extract_features to extract features, applies final linear layer, and returns logits. 
+        Calls extract_features to extract features, applies final linear layer, and returns logits.
 
         Args:
             inputs (tensor): Input tensor.
@@ -331,6 +328,7 @@
         """
         # Convolution layers
         x = self.extract_features(inputs)
+
         # Pooling and final linear layer
         x = self._avg_pooling(x)
         if self._global_params.include_top:
@@ -358,43 +356,6 @@
 # It's an additional function, not used in EfficientNet,
 # but can be used in other model (such as EfficientDet).
 
-# Parameters for the entire model (stem, all blocks, and head)
-GlobalParams = collections.namedtuple(
-    "GlobalParams",
-    [
-        "width_coefficient",
-        "depth_coefficient",
-        "image_size",
-        "dropout_rate",
-        "num_classes",
-        "batch_norm_momentum",
-        "batch_norm_epsilon",
-        "drop_connect_rate",
-        "depth_divisor",
-        "min_depth",
-        "include_top",
-    ],
-)
-
-# Parameters for an individual model block
-BlockArgs = collections.namedtuple(
-    "BlockArgs",
-    [
-        "num_repeat",
-        "kernel_size",
-        "stride",
-        "expand_ratio",
-        "input_filters",
-        "output_filters",
-        "se_ratio",
-        "id_skip",
-    ],
-)
-
-# Set GlobalParams and BlockArgs's defaults
-GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
-BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
-
 
 def round_filters(filters, global_params):
     """Calculate and round number of filters based on width multiplier.
@@ -442,7 +403,7 @@ def round_repeats(repeats, global_params):
     return int(math.ceil(multiplier * repeats))
 
 
-def drop_connect(inputs, p, training):
+def drop_connect(inputs: torch.Tensor, p: float, training: bool) -> torch.Tensor:
     """Drop connect.
 
     Args:
diff --git a/segmentation_models_pytorch/encoders/efficientnet.py b/segmentation_models_pytorch/encoders/efficientnet.py
index 3ea9f1d7..70046e44 100644
--- a/segmentation_models_pytorch/encoders/efficientnet.py
+++ b/segmentation_models_pytorch/encoders/efficientnet.py
@@ -31,17 +31,15 @@ class EfficientNetEncoder(EfficientNet, EncoderMixin):
-    _is_torch_scriptable = False
-
     def __init__(
         self,
-        stage_idxs: List[int],
+        out_indexes: List[int],
         out_channels: List[int],
         model_name: str,
         depth: int = 5,
         output_stride: int = 32,
     ):
         if depth > 5 or depth < 1:
             raise ValueError(
                 f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
             )
@@ -49,31 +47,21 @@ def __init__(
         blocks_args, global_params = get_model_params(model_name, override_params=None)
         super().__init__(blocks_args, global_params)
 
-        self._stage_idxs = stage_idxs
+        self._out_indexes = out_indexes
         self._depth = depth
         self._in_channels = 3
         self._out_channels = out_channels
         self._output_stride = output_stride
+        self._drop_connect_rate = self._global_params.drop_connect_rate
 
         del self._fc
 
     def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]:
         return {
-            16: [self._blocks[self._stage_idxs[1] : self._stage_idxs[2]]],
-            32: [self._blocks[self._stage_idxs[2] :]],
+            16: [self._blocks[self._out_indexes[1] + 1 : self._out_indexes[2] + 1]],
+            32: [self._blocks[self._out_indexes[2] + 1 :]],
         }
 
-    def apply_blocks(
-        self, x: torch.Tensor, start_idx: int, end_idx: int
-    ) -> torch.Tensor:
-        drop_connect_rate = self._global_params.drop_connect_rate
-
-        for block_number in range(start_idx, end_idx):
-            drop_connect_prob = drop_connect_rate * block_number / len(self._blocks)
-            x = self._blocks[block_number](x, drop_connect_prob)
-
-        return x
-
     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
         features = [x]
 
@@ -83,21 +71,19 @@
             x = self._conv_stem(x)
             x = self._bn0(x)
             x = self._swish(x)
             features.append(x)
 
-        if self._depth >= 2:
-            x = self.apply_blocks(x, 0, self._stage_idxs[0])
-            features.append(x)
+        depth = 1
+        for i, block in enumerate(self._blocks):
+            drop_connect_prob = self._drop_connect_rate * i / len(self._blocks)
+            x = block(x, drop_connect_prob)
 
-        if self._depth >= 3:
-            x = self.apply_blocks(x, self._stage_idxs[0], self._stage_idxs[1])
-            features.append(x)
+            if i in self._out_indexes:
+                features.append(x)
+                depth += 1
 
-        if self._depth >= 4:
-            x = self.apply_blocks(x, self._stage_idxs[1], self._stage_idxs[2])
-            features.append(x)
+            if not torch.jit.is_scripting() and depth > self._depth:
+                break
 
-        if self._depth >= 5:
-            x = self.apply_blocks(x, self._stage_idxs[2], len(self._blocks))
-            features.append(x)
+        features = features[: self._depth + 1]
 
         return features
 
@@ -122,7 +108,7 @@ def load_state_dict(self, state_dict, **kwargs):
         },
         "params": {
             "out_channels": [3, 32, 24, 40, 112, 320],
-            "stage_idxs": [3, 5, 9, 16],
+            "out_indexes": [2, 4, 8, 15],
             "model_name": "efficientnet-b0",
         },
     },
@@ -140,7 +126,7 @@
         },
         "params": {
             "out_channels": [3, 32, 24, 40, 112, 320],
-            "stage_idxs": [5, 8, 16, 23],
+            "out_indexes": [4, 7, 15, 22],
             "model_name": "efficientnet-b1",
         },
     },
@@ -158,7 +144,7 @@
         },
         "params": {
             "out_channels": [3, 32, 24, 48, 120, 352],
-            "stage_idxs": [5, 8, 16, 23],
+            "out_indexes": [4, 7, 15, 22],
             "model_name": "efficientnet-b2",
         },
     },
@@ -176,7 +162,7 @@
         },
         "params": {
             "out_channels": [3, 40, 32, 48, 136, 384],
-            "stage_idxs": [5, 8, 18, 26],
+            "out_indexes": [4, 7, 17, 25],
             "model_name": "efficientnet-b3",
         },
     },
@@ -194,7 +180,7 @@
         },
         "params": {
             "out_channels": [3, 48, 32, 56, 160, 448],
-            "stage_idxs": [6, 10, 22, 32],
+            "out_indexes": [5, 9, 21, 31],
             "model_name": "efficientnet-b4",
         },
     },
@@ -212,7 +198,7 @@
         },
         "params": {
             "out_channels": [3, 48, 40, 64, 176, 512],
-            "stage_idxs": [8, 13, 27, 39],
+            "out_indexes": [7, 12, 26, 38],
             "model_name": "efficientnet-b5",
         },
     },
@@ -230,7 +216,7 @@
         },
         "params": {
             "out_channels": [3, 56, 40, 72, 200, 576],
-            "stage_idxs": [9, 15, 31, 45],
+            "out_indexes": [8, 14, 30, 44],
             "model_name": "efficientnet-b6",
         },
     },
@@ -248,7 +234,7 @@
         },
        "params": {
             "out_channels": [3, 64, 48, 80, 224, 640],
-            "stage_idxs": [11, 18, 38, 55],
+            "out_indexes": [10, 17, 37, 54],
             "model_name": "efficientnet-b7",
         },
     },
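
A note on the params migration above: the old "stage_idxs" entries were exclusive stage boundaries used to slice self._blocks, while the new "out_indexes" entries point at the last block of each stage, which is why every value in the params dicts shifts down by one. A quick sanity check in Python, with the values copied from the hunks above (this snippet is illustrative only, not part of the patch):

# "out_indexes[k]" is the index of the last block of stage k; the old
# "stage_idxs[k]" was the index where the next stage started.
old_stage_idxs = {
    "efficientnet-b0": [3, 5, 9, 16],
    "efficientnet-b4": [6, 10, 22, 32],
    "efficientnet-b7": [11, 18, 38, 55],
}
new_out_indexes = {
    "efficientnet-b0": [2, 4, 8, 15],
    "efficientnet-b4": [5, 9, 21, 31],
    "efficientnet-b7": [10, 17, 37, 54],
}
for name, stage_idxs in old_stage_idxs.items():
    assert [i - 1 for i in stage_idxs] == new_out_indexes[name], name

This is also why get_stages adds "+ 1" to the new indexes: slicing still needs the exclusive boundary, recovered as out_indexes[k] + 1.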
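
The rewritten encoder forward folds the old apply_blocks helper into a single loop but keeps its stochastic-depth behaviour: block i is called with drop_connect_prob = drop_connect_rate * i / len(blocks), and MBConvBlock only applies drop connect when the block has an identity skip (id_skip, stride 1, equal input/output filters). A small sketch of the schedule this produces; the 0.2 base rate and the 16-block count for efficientnet-b0 are assumptions for illustration, not values taken from this diff:

# Linearly increasing drop-connect schedule: the first block is never
# dropped, later blocks are dropped with increasing probability.
drop_connect_rate = 0.2  # assumed EfficientNet default
num_blocks = 16          # assumed block count for efficientnet-b0

schedule = [drop_connect_rate * i / num_blocks for i in range(num_blocks)]
print(schedule[0])   # 0.0
print(schedule[-1])  # 0.1875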
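
Finally, a minimal smoke test for the behaviour the docs table change advertises. This is a hedged sketch, assuming the flipped column tracks torch.jit.script support (both the removal of "_is_torch_scriptable = False" and the new torch.jit.is_scripting() guard point that way) and that the encoder is exercised through the usual segmentation_models_pytorch entry point:

import torch
import segmentation_models_pytorch as smp

model = smp.Unet(encoder_name="efficientnet-b0", encoder_weights=None)
model.eval()

# The encoder returns one feature map per depth level (six with depth=5),
# collected at the out_indexes listed in the params dicts above.
with torch.no_grad():
    features = model.encoder(torch.rand(1, 3, 224, 224))
print([f.shape for f in features])

# With this patch applied, scripting the encoder should no longer be blocked.
scripted = torch.jit.script(model.encoder)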