Skip to content

Commit

Permalink
Update encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
qubvel committed Jan 16, 2025
1 parent ce1ae43 commit 4aa7cd1
Showing 1 changed file with 13 additions and 26 deletions.
39 changes: 13 additions & 26 deletions segmentation_models_pytorch/encoders/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@


class EfficientNetEncoder(EfficientNet, EncoderMixin):
_is_torch_scriptable = False

def __init__(
self,
stage_idxs: List[int],
Expand All @@ -41,7 +39,7 @@ def __init__(
depth: int = 5,
output_stride: int = 32,
):
if depth > 5 or depth < 1:
if depth > 5 or depth < 2:
raise ValueError(
f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}"
)
Expand All @@ -50,11 +48,13 @@ def __init__(
super().__init__(blocks_args, global_params)

self._stage_idxs = stage_idxs
self._out_indexes = [x - 1 for x in stage_idxs]
self._depth = depth
self._in_channels = 3
self._out_channels = out_channels
self._output_stride = output_stride

self._drop_connect_rate = self._global_params.drop_connect_rate
del self._fc

def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]:
Expand All @@ -63,17 +63,6 @@ def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]:
32: [self._blocks[self._stage_idxs[2] :]],
}

def apply_blocks(
self, x: torch.Tensor, start_idx: int, end_idx: int
) -> torch.Tensor:
drop_connect_rate = self._global_params.drop_connect_rate

for block_number in range(start_idx, end_idx):
drop_connect_prob = drop_connect_rate * block_number / len(self._blocks)
x = self._blocks[block_number](x, drop_connect_prob)

return x

def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
features = [x]

Expand All @@ -83,21 +72,19 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
x = self._swish(x)
features.append(x)

if self._depth >= 2:
x = self.apply_blocks(x, 0, self._stage_idxs[0])
features.append(x)
depth = 1
for i, block in enumerate(self._blocks):
drop_connect_prob = self._drop_connect_rate * i / len(self._blocks)
x = block(x, drop_connect_prob)

if self._depth >= 3:
x = self.apply_blocks(x, self._stage_idxs[0], self._stage_idxs[1])
features.append(x)
if i in self._out_indexes:
features.append(x)
depth += 1

if self._depth >= 4:
x = self.apply_blocks(x, self._stage_idxs[1], self._stage_idxs[2])
features.append(x)
if not torch.jit.is_scripting() and depth > self._depth:
break

if self._depth >= 5:
x = self.apply_blocks(x, self._stage_idxs[2], len(self._blocks))
features.append(x)
features = features[: self._depth + 1]

return features

Expand Down

0 comments on commit 4aa7cd1

Please sign in to comment.