Skip to content

Commit

Permalink
Fix DeepLabV3 BC
Browse files Browse the repository at this point in the history
  • Loading branch information
qubvel committed Jan 14, 2025
1 parent 31bee79 commit 556b3aa
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 18 deletions.
19 changes: 1 addition & 18 deletions segmentation_models_pytorch/decoders/deeplabv3/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
__all__ = ["DeepLabV3Decoder", "DeepLabV3PlusDecoder"]


class DeepLabV3Decoder(nn.Sequential):
class DeepLabV3Decoder(nn.Module):
def __init__(
self,
in_channels: int,
Expand Down Expand Up @@ -69,23 +69,6 @@ def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
x = self.relu(x)
return x

def load_state_dict(self, state_dict, *args, **kwargs):
# For backward compatibility, previously this module was Sequential
# and was not scriptable.
keys = list(state_dict.keys())
for key in keys:
new_key = key
if key.startswith("0."):
new_key = "aspp." + key[2:]
elif key.startswith("1."):
new_key = "conv." + key[2:]
elif key.startswith("2."):
new_key = "bn." + key[2:]
elif key.startswith("3."):
new_key = "relu." + key[2:]
state_dict[new_key] = state_dict.pop(key)
super().load_state_dict(state_dict, *args, **kwargs)


class DeepLabV3PlusDecoder(nn.Module):
def __init__(
Expand Down
15 changes: 15 additions & 0 deletions segmentation_models_pytorch/decoders/deeplabv3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,21 @@ def __init__(
else:
self.classification_head = None

def load_state_dict(self, state_dict, *args, **kwargs):
# For backward compatibility, previously Decoder module was Sequential
# and was not scriptable.
keys = list(state_dict.keys())
for key in keys:
new_key = key
if key.startswith("decoder.0."):
new_key = key.replace("decoder.0.", "decoder.aspp.")

Check warning on line 131 in segmentation_models_pytorch/decoders/deeplabv3/model.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/decoders/deeplabv3/model.py#L131

Added line #L131 was not covered by tests
elif key.startswith("decoder.1."):
new_key = key.replace("decoder.1.", "decoder.conv.")

Check warning on line 133 in segmentation_models_pytorch/decoders/deeplabv3/model.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/decoders/deeplabv3/model.py#L133

Added line #L133 was not covered by tests
elif key.startswith("decoder.2."):
new_key = key.replace("decoder.2.", "decoder.bn.")

Check warning on line 135 in segmentation_models_pytorch/decoders/deeplabv3/model.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/decoders/deeplabv3/model.py#L135

Added line #L135 was not covered by tests
state_dict[new_key] = state_dict.pop(key)
return super().load_state_dict(state_dict, *args, **kwargs)


class DeepLabV3Plus(SegmentationModel):
"""DeepLabV3+ implementation from "Encoder-Decoder with Atrous Separable
Expand Down

0 comments on commit 556b3aa

Please sign in to comment.