Skip to content

Commit

Permalink
Update indexes
Browse files Browse the repository at this point in the history
  • Loading branch information
qubvel committed Jan 16, 2025
1 parent 22e6b2e commit ebe45dc
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions segmentation_models_pytorch/encoders/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
class EfficientNetEncoder(EfficientNet, EncoderMixin):
def __init__(
self,
stage_idxs: List[int],
out_indexes: List[int],
out_channels: List[int],
model_name: str,
depth: int = 5,
Expand All @@ -47,8 +47,7 @@ def __init__(
blocks_args, global_params = get_model_params(model_name, override_params=None)
super().__init__(blocks_args, global_params)

self._stage_idxs = stage_idxs
self._out_indexes = [x - 1 for x in stage_idxs]
self._out_indexes = out_indexes
self._depth = depth
self._in_channels = 3
self._out_channels = out_channels
Expand Down Expand Up @@ -109,7 +108,7 @@ def load_state_dict(self, state_dict, **kwargs):
},
"params": {
"out_channels": [3, 32, 24, 40, 112, 320],
"stage_idxs": [3, 5, 9, 16],
"out_indexes": [2, 4, 8, 15],
"model_name": "efficientnet-b0",
},
},
Expand All @@ -127,7 +126,7 @@ def load_state_dict(self, state_dict, **kwargs):
},
"params": {
"out_channels": [3, 32, 24, 40, 112, 320],
"stage_idxs": [5, 8, 16, 23],
"out_indexes": [4, 7, 15, 22],
"model_name": "efficientnet-b1",
},
},
Expand All @@ -145,7 +144,7 @@ def load_state_dict(self, state_dict, **kwargs):
},
"params": {
"out_channels": [3, 32, 24, 48, 120, 352],
"stage_idxs": [5, 8, 16, 23],
"out_indexes": [4, 7, 15, 22],
"model_name": "efficientnet-b2",
},
},
Expand All @@ -163,7 +162,7 @@ def load_state_dict(self, state_dict, **kwargs):
},
"params": {
"out_channels": [3, 40, 32, 48, 136, 384],
"stage_idxs": [5, 8, 18, 26],
"out_indexes": [4, 7, 17, 25],
"model_name": "efficientnet-b3",
},
},
Expand All @@ -181,7 +180,7 @@ def load_state_dict(self, state_dict, **kwargs):
},
"params": {
"out_channels": [3, 48, 32, 56, 160, 448],
"stage_idxs": [6, 10, 22, 32],
"out_indexes": [5, 9, 21, 31],
"model_name": "efficientnet-b4",
},
},
Expand All @@ -199,7 +198,7 @@ def load_state_dict(self, state_dict, **kwargs):
},
"params": {
"out_channels": [3, 48, 40, 64, 176, 512],
"stage_idxs": [8, 13, 27, 39],
"out_indexes": [7, 12, 26, 38],
"model_name": "efficientnet-b5",
},
},
Expand All @@ -217,7 +216,7 @@ def load_state_dict(self, state_dict, **kwargs):
},
"params": {
"out_channels": [3, 56, 40, 72, 200, 576],
"stage_idxs": [9, 15, 31, 45],
"out_indexes": [8, 14, 30, 44],
"model_name": "efficientnet-b6",
},
},
Expand All @@ -235,7 +234,7 @@ def load_state_dict(self, state_dict, **kwargs):
},
"params": {
"out_channels": [3, 64, 48, 80, 224, 640],
"stage_idxs": [11, 18, 38, 55],
"out_indexes": [10, 17, 37, 54],
"model_name": "efficientnet-b7",
},
},
Expand Down

0 comments on commit ebe45dc

Please sign in to comment.