From ebe45dc1c315a9153a191369b544dc3ce446b37a Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Thu, 16 Jan 2025 14:17:48 +0000 Subject: [PATCH] Update indexes --- .../encoders/efficientnet.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/segmentation_models_pytorch/encoders/efficientnet.py b/segmentation_models_pytorch/encoders/efficientnet.py index 54f47371..7836a5b7 100644 --- a/segmentation_models_pytorch/encoders/efficientnet.py +++ b/segmentation_models_pytorch/encoders/efficientnet.py @@ -33,7 +33,7 @@ class EfficientNetEncoder(EfficientNet, EncoderMixin): def __init__( self, - stage_idxs: List[int], + out_indexes: List[int], out_channels: List[int], model_name: str, depth: int = 5, @@ -47,8 +47,7 @@ def __init__( blocks_args, global_params = get_model_params(model_name, override_params=None) super().__init__(blocks_args, global_params) - self._stage_idxs = stage_idxs - self._out_indexes = [x - 1 for x in stage_idxs] + self._out_indexes = out_indexes self._depth = depth self._in_channels = 3 self._out_channels = out_channels @@ -109,7 +108,7 @@ def load_state_dict(self, state_dict, **kwargs): }, "params": { "out_channels": [3, 32, 24, 40, 112, 320], - "stage_idxs": [3, 5, 9, 16], + "out_indexes": [2, 4, 8, 15], "model_name": "efficientnet-b0", }, }, @@ -127,7 +126,7 @@ def load_state_dict(self, state_dict, **kwargs): }, "params": { "out_channels": [3, 32, 24, 40, 112, 320], - "stage_idxs": [5, 8, 16, 23], + "out_indexes": [4, 7, 15, 22], "model_name": "efficientnet-b1", }, }, @@ -145,7 +144,7 @@ def load_state_dict(self, state_dict, **kwargs): }, "params": { "out_channels": [3, 32, 24, 48, 120, 352], - "stage_idxs": [5, 8, 16, 23], + "out_indexes": [4, 7, 15, 22], "model_name": "efficientnet-b2", }, }, @@ -163,7 +162,7 @@ def load_state_dict(self, state_dict, **kwargs): }, "params": { "out_channels": [3, 40, 32, 48, 136, 384], - "stage_idxs": [5, 8, 18, 26], + "out_indexes": [4, 7, 17, 25], "model_name": "efficientnet-b3", }, }, @@ -181,7 +180,7 @@ def load_state_dict(self, state_dict, **kwargs): }, "params": { "out_channels": [3, 48, 32, 56, 160, 448], - "stage_idxs": [6, 10, 22, 32], + "out_indexes": [5, 9, 21, 31], "model_name": "efficientnet-b4", }, }, @@ -199,7 +198,7 @@ def load_state_dict(self, state_dict, **kwargs): }, "params": { "out_channels": [3, 48, 40, 64, 176, 512], - "stage_idxs": [8, 13, 27, 39], + "out_indexes": [7, 12, 26, 38], "model_name": "efficientnet-b5", }, }, @@ -217,7 +216,7 @@ def load_state_dict(self, state_dict, **kwargs): }, "params": { "out_channels": [3, 56, 40, 72, 200, 576], - "stage_idxs": [9, 15, 31, 45], + "out_indexes": [8, 14, 30, 44], "model_name": "efficientnet-b6", }, }, @@ -235,7 +234,7 @@ def load_state_dict(self, state_dict, **kwargs): }, "params": { "out_channels": [3, 64, 48, 80, 224, 640], - "stage_idxs": [11, 18, 38, 55], + "out_indexes": [10, 17, 37, 54], "model_name": "efficientnet-b7", }, },