diff --git a/segmentation_models_pytorch/encoders/efficientnet.py b/segmentation_models_pytorch/encoders/efficientnet.py index 3ea9f1d7..54f47371 100644 --- a/segmentation_models_pytorch/encoders/efficientnet.py +++ b/segmentation_models_pytorch/encoders/efficientnet.py @@ -31,8 +31,6 @@ class EfficientNetEncoder(EfficientNet, EncoderMixin): - _is_torch_scriptable = False - def __init__( self, stage_idxs: List[int], @@ -41,7 +39,7 @@ def __init__( depth: int = 5, output_stride: int = 32, ): - if depth > 5 or depth < 1: + if depth > 5 or depth < 2: raise ValueError( f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" ) @@ -50,11 +48,13 @@ def __init__( super().__init__(blocks_args, global_params) self._stage_idxs = stage_idxs + self._out_indexes = [x - 1 for x in stage_idxs] self._depth = depth self._in_channels = 3 self._out_channels = out_channels self._output_stride = output_stride + self._drop_connect_rate = self._global_params.drop_connect_rate del self._fc def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: @@ -63,17 +63,6 @@ def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: 32: [self._blocks[self._stage_idxs[2] :]], } - def apply_blocks( - self, x: torch.Tensor, start_idx: int, end_idx: int - ) -> torch.Tensor: - drop_connect_rate = self._global_params.drop_connect_rate - - for block_number in range(start_idx, end_idx): - drop_connect_prob = drop_connect_rate * block_number / len(self._blocks) - x = self._blocks[block_number](x, drop_connect_prob) - - return x - def forward(self, x: torch.Tensor) -> List[torch.Tensor]: features = [x] @@ -83,21 +72,19 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]: x = self._swish(x) features.append(x) - if self._depth >= 2: - x = self.apply_blocks(x, 0, self._stage_idxs[0]) - features.append(x) + depth = 1 + for i, block in enumerate(self._blocks): + drop_connect_prob = self._drop_connect_rate * i / len(self._blocks) + x = block(x, drop_connect_prob) - if self._depth >= 3: - x = self.apply_blocks(x, self._stage_idxs[0], self._stage_idxs[1]) - features.append(x) + if i in self._out_indexes: + features.append(x) + depth += 1 - if self._depth >= 4: - x = self.apply_blocks(x, self._stage_idxs[1], self._stage_idxs[2]) - features.append(x) + if not torch.jit.is_scripting() and depth > self._depth: + break - if self._depth >= 5: - x = self.apply_blocks(x, self._stage_idxs[2], len(self._blocks)) - features.append(x) + features = features[: self._depth + 1] return features