diff --git a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py index 15280043..6a801a70 100644 --- a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py @@ -40,7 +40,7 @@ __all__ = ["DeepLabV3Decoder", "DeepLabV3PlusDecoder"] -class DeepLabV3Decoder(nn.Sequential): +class DeepLabV3Decoder(nn.Module): def __init__( self, in_channels: int, @@ -69,23 +69,6 @@ def forward(self, features: List[torch.Tensor]) -> torch.Tensor: x = self.relu(x) return x - def load_state_dict(self, state_dict, *args, **kwargs): - # For backward compatibility, previously this module was Sequential - # and was not scriptable. - keys = list(state_dict.keys()) - for key in keys: - new_key = key - if key.startswith("0."): - new_key = "aspp." + key[2:] - elif key.startswith("1."): - new_key = "conv." + key[2:] - elif key.startswith("2."): - new_key = "bn." + key[2:] - elif key.startswith("3."): - new_key = "relu." + key[2:] - state_dict[new_key] = state_dict.pop(key) - super().load_state_dict(state_dict, *args, **kwargs) - class DeepLabV3PlusDecoder(nn.Module): def __init__( diff --git a/segmentation_models_pytorch/decoders/deeplabv3/model.py b/segmentation_models_pytorch/decoders/deeplabv3/model.py index 654e38d4..c14776f3 100644 --- a/segmentation_models_pytorch/decoders/deeplabv3/model.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/model.py @@ -121,6 +121,21 @@ def __init__( else: self.classification_head = None + def load_state_dict(self, state_dict, *args, **kwargs): + # For backward compatibility, previously Decoder module was Sequential + # and was not scriptable. + keys = list(state_dict.keys()) + for key in keys: + new_key = key + if key.startswith("decoder.0."): + new_key = key.replace("decoder.0.", "decoder.aspp.") + elif key.startswith("decoder.1."): + new_key = key.replace("decoder.1.", "decoder.conv.") + elif key.startswith("decoder.2."): + new_key = key.replace("decoder.2.", "decoder.bn.") + state_dict[new_key] = state_dict.pop(key) + return super().load_state_dict(state_dict, *args, **kwargs) + class DeepLabV3Plus(SegmentationModel): """DeepLabV3+ implementation from "Encoder-Decoder with Atrous Separable