From 643c0b686ee3dabf048eaf751c5fdcf6b0f43ee5 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 12 Jan 2025 02:59:53 +0000 Subject: [PATCH 01/57] Move tests --- .../test_pretrainedmodels_encoders.py | 29 ----------------- tests/encoders/test_torchvision_encoders.py | 31 ++++++++++++++++++- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/tests/encoders/test_pretrainedmodels_encoders.py b/tests/encoders/test_pretrainedmodels_encoders.py index bbde576c..a19e335d 100644 --- a/tests/encoders/test_pretrainedmodels_encoders.py +++ b/tests/encoders/test_pretrainedmodels_encoders.py @@ -2,15 +2,6 @@ from tests.utils import RUN_ALL_ENCODERS -class TestDenseNetEncoder(base.BaseEncoderTester): - supports_dilated = False - encoder_names = ( - ["densenet121"] - if not RUN_ALL_ENCODERS - else ["densenet121", "densenet169", "densenet161"] - ) - - class TestDPNEncoder(base.BaseEncoderTester): encoder_names = ( ["dpn68"] @@ -31,26 +22,6 @@ class TestInceptionV4Encoder(base.BaseEncoderTester): encoder_names = ["inceptionv4"] if not RUN_ALL_ENCODERS else ["inceptionv4"] -class TestResNetEncoder(base.BaseEncoderTester): - encoder_names = ( - ["resnet18"] - if not RUN_ALL_ENCODERS - else [ - "resnet18", - "resnet34", - "resnet50", - "resnet101", - "resnet152", - "resnext50_32x4d", - "resnext101_32x4d", - "resnext101_32x8d", - "resnext101_32x16d", - "resnext101_32x32d", - "resnext101_32x48d", - ] - ) - - class TestSeNetEncoder(base.BaseEncoderTester): encoder_names = ( ["se_resnet50"] diff --git a/tests/encoders/test_torchvision_encoders.py b/tests/encoders/test_torchvision_encoders.py index 99b8b9d5..2ebaa86a 100644 --- a/tests/encoders/test_torchvision_encoders.py +++ b/tests/encoders/test_torchvision_encoders.py @@ -2,7 +2,36 @@ from tests.utils import RUN_ALL_ENCODERS -class TestMobileoneEncoder(base.BaseEncoderTester): +class TestResNetEncoder(base.BaseEncoderTester): + encoder_names = ( + ["resnet18"] + if not RUN_ALL_ENCODERS + else [ + "resnet18", + "resnet34", + "resnet50", + "resnet101", + "resnet152", + "resnext50_32x4d", + "resnext101_32x4d", + "resnext101_32x8d", + "resnext101_32x16d", + "resnext101_32x32d", + "resnext101_32x48d", + ] + ) + + +class TestDenseNetEncoder(base.BaseEncoderTester): + supports_dilated = False + encoder_names = ( + ["densenet121"] + if not RUN_ALL_ENCODERS + else ["densenet121", "densenet169", "densenet161"] + ) + + +class TestMobileNetEncoder(base.BaseEncoderTester): encoder_names = ["mobilenet_v2"] if not RUN_ALL_ENCODERS else ["mobilenet_v2"] From 7a937abb6e4f2dd01aa30d1fae9773a926c255e3 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 12 Jan 2025 03:00:29 +0000 Subject: [PATCH 02/57] Add compile test for encoders (to be optimized) --- tests/encoders/base.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/encoders/base.py b/tests/encoders/base.py index 0f762cf4..78c9e170 100644 --- a/tests/encoders/base.py +++ b/tests/encoders/base.py @@ -206,3 +206,23 @@ def test_dilated(self): expected_width_strides, f"Encoder `{encoder_name}` should have width output strides {expected_width_strides}, but has {width_strides}", ) + + @torch.inference_mode() + def test_compile(self): + sample = self._get_sample( + batch_size=self.default_batch_size, + num_channels=self.default_num_channels, + height=self.default_height, + width=self.default_width, + ).to(default_device) + + for encoder_name in self.encoder_names: + with self.subTest(encoder_name=encoder_name): + encoder = smp.encoders.get_encoder( + encoder_name, + 
in_channels=self.default_num_channels, + encoder_weights=None, + ).to(default_device) + encoder.eval() + compiled_encoder = torch.compile(encoder, fullgraph=True, dynamic=True) + compiled_encoder(sample) From 9a7c768d92b4611820063245b8f753f8af8595da Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 12 Jan 2025 03:00:54 +0000 Subject: [PATCH 03/57] densnet --- .../encoders/densenet.py | 125 ++++++++++++------ 1 file changed, 83 insertions(+), 42 deletions(-) diff --git a/segmentation_models_pytorch/encoders/densenet.py b/segmentation_models_pytorch/encoders/densenet.py index c4bd0ce2..6ec1773f 100644 --- a/segmentation_models_pytorch/encoders/densenet.py +++ b/segmentation_models_pytorch/encoders/densenet.py @@ -24,27 +24,14 @@ """ import re +import torch import torch.nn as nn -from pretrainedmodels.models.torchvision_models import pretrained_settings from torchvision.models.densenet import DenseNet from ._base import EncoderMixin -class TransitionWithSkip(nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - - def forward(self, x): - for module in self.module: - x = module(x) - if isinstance(module, nn.ReLU): - skip = x - return x, skip - - class DenseNetEncoder(DenseNet, EncoderMixin): def __init__(self, out_channels, depth=5, **kwargs): super().__init__(**kwargs) @@ -59,37 +46,44 @@ def make_dilated(self, *args, **kwargs): "due to pooling operation for downsampling!" ) - def get_stages(self): - return [ - nn.Identity(), - nn.Sequential( - self.features.conv0, self.features.norm0, self.features.relu0 - ), - nn.Sequential( - self.features.pool0, - self.features.denseblock1, - TransitionWithSkip(self.features.transition1), - ), - nn.Sequential( - self.features.denseblock2, TransitionWithSkip(self.features.transition2) - ), - nn.Sequential( - self.features.denseblock3, TransitionWithSkip(self.features.transition3) - ), - nn.Sequential(self.features.denseblock4, self.features.norm5), - ] + def apply_transition( + self, transition: torch.nn.Sequential, x: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + for module in transition: + x = module(x) + if isinstance(module, nn.ReLU): + intermediate = x + return x, intermediate def forward(self, x): - stages = self.get_stages() - features = [] - for i in range(self._depth + 1): - x = stages[i](x) - if isinstance(x, (list, tuple)): - x, skip = x - features.append(skip) - else: - features.append(x) + + if self._depth >= 1: + x = self.features.conv0(x) + x = self.features.norm0(x) + x = self.features.relu0(x) + features.append(x) + + if self._depth >= 2: + x = self.features.pool0(x) + x = self.features.denseblock1(x) + x, intermediate = self.apply_transition(self.features.transition1, x) + features.append(intermediate) + + if self._depth >= 3: + x = self.features.denseblock2(x) + x, intermediate = self.apply_transition(self.features.transition2, x) + features.append(intermediate) + + if self._depth >= 4: + x = self.features.denseblock3(x) + x, intermediate = self.apply_transition(self.features.transition3, x) + features.append(intermediate) + + if self._depth >= 5: + x = self.features.denseblock4(x) + x = self.features.norm5(x) + features.append(x) return features @@ -111,6 +105,53 @@ def load_state_dict(self, state_dict): super().load_state_dict(state_dict) +pretrained_settings = { + "densenet121": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/densenet121-fbdb23505.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 
0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "densenet169": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/densenet169-f470b90a4.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "densenet201": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/densenet201-5750cbb1e.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "densenet161": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/densenet161-347e6b360.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, +} + densenet_encoders = { "densenet121": { "encoder": DenseNetEncoder, From 34b853309007e5f2dff8cd84e909a770de87fc46 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 12 Jan 2025 03:01:01 +0000 Subject: [PATCH 04/57] dpn --- segmentation_models_pytorch/encoders/dpn.py | 52 +++++++++++---------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/segmentation_models_pytorch/encoders/dpn.py b/segmentation_models_pytorch/encoders/dpn.py index 220c66de..0ee8f04b 100644 --- a/segmentation_models_pytorch/encoders/dpn.py +++ b/segmentation_models_pytorch/encoders/dpn.py @@ -24,8 +24,8 @@ """ import torch -import torch.nn as nn import torch.nn.functional as F +from typing import List from pretrainedmodels.models.dpn import DPN from pretrainedmodels.models.dpn import pretrained_settings @@ -43,30 +43,34 @@ def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): del self.last_linear - def get_stages(self): - return [ - nn.Identity(), - nn.Sequential( - self.features[0].conv, self.features[0].bn, self.features[0].act - ), - nn.Sequential( - self.features[0].pool, self.features[1 : self._stage_idxs[0]] - ), - self.features[self._stage_idxs[0] : self._stage_idxs[1]], - self.features[self._stage_idxs[1] : self._stage_idxs[2]], - self.features[self._stage_idxs[2] : self._stage_idxs[3]], - ] - - def forward(self, x): - stages = self.get_stages() - + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: features = [] - for i in range(self._depth + 1): - x = stages[i](x) - if isinstance(x, (list, tuple)): - features.append(F.relu(torch.cat(x, dim=1), inplace=True)) - else: - features.append(x) + + if self._depth >= 1: + x = self.features[0].conv(x) + x = self.features[0].bn(x) + x = self.features[0].act(x) + features.append(x) + + if self._depth >= 2: + x = self.features[0].pool(x) + x = self.features[1 : self._stage_idxs[0]](x) + skip = F.relu(torch.cat(x, dim=1), inplace=True) + features.append(skip) + + if self._depth >= 3: + x = self.features[self._stage_idxs[0] : self._stage_idxs[1]](x) + skip = F.relu(torch.cat(x, dim=1), inplace=True) + features.append(skip) + + if self._depth >= 4: + x = self.features[self._stage_idxs[1] : self._stage_idxs[2]](x) + skip = F.relu(torch.cat(x, dim=1), inplace=True) + features.append(skip) + + if self._depth >= 5: + x = self.features[self._stage_idxs[2] : self._stage_idxs[3]](x) + features.append(x) return features From a3618fac238caa6b1bf4f736f95005983769013f Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 12 Jan 2025 03:01:08 +0000 Subject: [PATCH 05/57] efficientnet --- 
.../encoders/efficientnet.py | 57 +++++++++++-------- 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/segmentation_models_pytorch/encoders/efficientnet.py b/segmentation_models_pytorch/encoders/efficientnet.py index 4a7af6b4..2765af20 100644 --- a/segmentation_models_pytorch/encoders/efficientnet.py +++ b/segmentation_models_pytorch/encoders/efficientnet.py @@ -23,7 +23,9 @@ depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). """ -import torch.nn as nn +import torch +from typing import List + from efficientnet_pytorch import EfficientNet from efficientnet_pytorch.utils import url_map, url_map_advprop, get_model_params @@ -42,35 +44,40 @@ def __init__(self, stage_idxs, out_channels, model_name, depth=5): del self._fc - def get_stages(self): - return [ - nn.Identity(), - nn.Sequential(self._conv_stem, self._bn0, self._swish), - self._blocks[: self._stage_idxs[0]], - self._blocks[self._stage_idxs[0] : self._stage_idxs[1]], - self._blocks[self._stage_idxs[1] : self._stage_idxs[2]], - self._blocks[self._stage_idxs[2] :], - ] + def apply_blocks( + self, x: torch.Tensor, start_idx: int, end_idx: int + ) -> torch.Tensor: + drop_connect_rate = self._global_params.drop_connect_rate - def forward(self, x): - stages = self.get_stages() + for block_number in range(start_idx, end_idx): + drop_connect_prob = drop_connect_rate * block_number / len(self._blocks) + x = self._blocks[block_number](x, drop_connect_prob) - block_number = 0.0 - drop_connect_rate = self._global_params.drop_connect_rate + return x + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: features = [] - for i in range(self._depth + 1): - # Identity and Sequential stages - if i < 2: - x = stages[i](x) - - # Block stages need drop_connect rate - else: - for module in stages[i]: - drop_connect = drop_connect_rate * block_number / len(self._blocks) - block_number += 1.0 - x = module(x, drop_connect) + if self._depth >= 1: + x = self._conv_stem(x) + x = self._bn0(x) + x = self._swish(x) + features.append(x) + + if self._depth >= 2: + x = self.apply_blocks(x, 0, self._stage_idxs[0]) + features.append(x) + + if self._depth >= 3: + x = self.apply_blocks(x, self._stage_idxs[0], self._stage_idxs[1]) + features.append(x) + + if self._depth >= 4: + x = self.apply_blocks(x, self._stage_idxs[1], self._stage_idxs[2]) + features.append(x) + + if self._depth >= 5: + x = self.apply_blocks(x, self._stage_idxs[2], len(self._blocks)) features.append(x) return features From e3f6c7086e4d93ede9095aeb5dd2f0325f9caa41 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 12 Jan 2025 03:01:18 +0000 Subject: [PATCH 06/57] inceptionresnetv2 --- .../encoders/inceptionresnetv2.py | 65 ++++++++++++++----- 1 file changed, 49 insertions(+), 16 deletions(-) diff --git a/segmentation_models_pytorch/encoders/inceptionresnetv2.py b/segmentation_models_pytorch/encoders/inceptionresnetv2.py index 5d90c7f4..2c6b8de3 100644 --- a/segmentation_models_pytorch/encoders/inceptionresnetv2.py +++ b/segmentation_models_pytorch/encoders/inceptionresnetv2.py @@ -25,7 +25,6 @@ import torch.nn as nn from pretrainedmodels.models.inceptionresnetv2 import InceptionResNetV2 -from pretrainedmodels.models.inceptionresnetv2 import pretrained_settings from ._base import EncoderMixin @@ -56,22 +55,37 @@ def make_dilated(self, *args, **kwargs): "due to pooling operation for downsampling!" 
) - def get_stages(self): - return [ - nn.Identity(), - nn.Sequential(self.conv2d_1a, self.conv2d_2a, self.conv2d_2b), - nn.Sequential(self.maxpool_3a, self.conv2d_3b, self.conv2d_4a), - nn.Sequential(self.maxpool_5a, self.mixed_5b, self.repeat), - nn.Sequential(self.mixed_6a, self.repeat_1), - nn.Sequential(self.mixed_7a, self.repeat_2, self.block8, self.conv2d_7b), - ] - def forward(self, x): - stages = self.get_stages() - features = [] - for i in range(self._depth + 1): - x = stages[i](x) + + if self._depth >= 1: + x = self.conv2d_1a(x) + x = self.conv2d_2a(x) + x = self.conv2d_2b(x) + features.append(x) + + if self._depth >= 2: + x = self.maxpool_3a(x) + x = self.conv2d_3b(x) + x = self.conv2d_4a(x) + features.append(x) + + if self._depth >= 3: + x = self.maxpool_5a(x) + x = self.mixed_5b(x) + x = self.repeat(x) + features.append(x) + + if self._depth >= 4: + x = self.mixed_6a(x) + x = self.repeat_1(x) + features.append(x) + + if self._depth >= 5: + x = self.mixed_7a(x) + x = self.repeat_2(x) + x = self.block8(x) + x = self.conv2d_7b(x) features.append(x) return features @@ -85,7 +99,26 @@ def load_state_dict(self, state_dict, **kwargs): inceptionresnetv2_encoders = { "inceptionresnetv2": { "encoder": InceptionResNetV2Encoder, - "pretrained_settings": pretrained_settings["inceptionresnetv2"], + "pretrained_settings": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth", + "input_space": "RGB", + "input_size": [3, 299, 299], + "input_range": [0, 1], + "mean": [0.5, 0.5, 0.5], + "std": [0.5, 0.5, 0.5], + "num_classes": 1000, + }, + "imagenet+background": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth", + "input_space": "RGB", + "input_size": [3, 299, 299], + "input_range": [0, 1], + "mean": [0.5, 0.5, 0.5], + "std": [0.5, 0.5, 0.5], + "num_classes": 1001, + }, + }, "params": {"out_channels": (3, 64, 192, 320, 1088, 1536), "num_classes": 1000}, } } From 20b28beb3ec9f9668a6535a6c0563b55777aa829 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 12 Jan 2025 03:01:30 +0000 Subject: [PATCH 07/57] inceptionv4 --- .../encoders/inceptionv4.py | 45 ++++++++++++++++--- 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/segmentation_models_pytorch/encoders/inceptionv4.py b/segmentation_models_pytorch/encoders/inceptionv4.py index 96540f9a..4731053a 100644 --- a/segmentation_models_pytorch/encoders/inceptionv4.py +++ b/segmentation_models_pytorch/encoders/inceptionv4.py @@ -25,7 +25,6 @@ import torch.nn as nn from pretrainedmodels.models.inceptionv4 import InceptionV4 -from pretrainedmodels.models.inceptionv4 import pretrained_settings from ._base import EncoderMixin @@ -66,11 +65,26 @@ def get_stages(self): ] def forward(self, x): - stages = self.get_stages() - features = [] - for i in range(self._depth + 1): - x = stages[i](x) + + if self._depth >= 1: + x = self.features[: self._stage_idxs[0]](x) + features.append(x) + + if self._depth >= 2: + x = self.features[self._stage_idxs[0] : self._stage_idxs[1]](x) + features.append(x) + + if self._depth >= 3: + x = self.features[self._stage_idxs[1] : self._stage_idxs[2]](x) + features.append(x) + + if self._depth >= 4: + x = self.features[self._stage_idxs[2] : self._stage_idxs[3]](x) + features.append(x) + + if self._depth >= 5: + x = self.features[self._stage_idxs[3] :](x) features.append(x) return features @@ -84,7 +98,26 @@ def load_state_dict(self, state_dict, **kwargs): inceptionv4_encoders = { "inceptionv4": { "encoder": InceptionV4Encoder, 
- "pretrained_settings": pretrained_settings["inceptionv4"], + "pretrained_settings": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth", + "input_space": "RGB", + "input_size": [3, 299, 299], + "input_range": [0, 1], + "mean": [0.5, 0.5, 0.5], + "std": [0.5, 0.5, 0.5], + "num_classes": 1000, + }, + "imagenet+background": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth", + "input_space": "RGB", + "input_size": [3, 299, 299], + "input_range": [0, 1], + "mean": [0.5, 0.5, 0.5], + "std": [0.5, 0.5, 0.5], + "num_classes": 1001, + }, + }, "params": { "stage_idxs": (3, 5, 9, 15), "out_channels": (3, 64, 192, 384, 1024, 1536), From d996165e1e2a3b1701230a78cfe9dd6b69200bfa Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 12 Jan 2025 03:01:50 +0000 Subject: [PATCH 08/57] mix-transformer --- .../encoders/mix_transformer.py | 61 +++++++++++-------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/segmentation_models_pytorch/encoders/mix_transformer.py b/segmentation_models_pytorch/encoders/mix_transformer.py index cf4c3f33..1595ae0d 100644 --- a/segmentation_models_pytorch/encoders/mix_transformer.py +++ b/segmentation_models_pytorch/encoders/mix_transformer.py @@ -526,30 +526,43 @@ def __init__(self, out_channels, depth=5, **kwargs): self._depth = depth self._in_channels = 3 - def get_stages(self): - return [ - nn.Identity(), - nn.Identity(), - nn.Sequential(self.patch_embed1, self.block1, self.norm1), - nn.Sequential(self.patch_embed2, self.block2, self.norm2), - nn.Sequential(self.patch_embed3, self.block3, self.norm3), - nn.Sequential(self.patch_embed4, self.block4, self.norm4), - ] - def forward(self, x): - stages = self.get_stages() - # create dummy output for the first block - B, _, H, W = x.shape - dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device) - - features = [] - for i in range(self._depth + 1): - if i == 1: - features.append(dummy) - else: - x = stages[i](x).contiguous() - features.append(x) + batch_size, _, height, width = x.shape + dummy = torch.empty( + [batch_size, 0, height // 2, width // 2], dtype=x.dtype, device=x.device + ) + + features = [dummy] + + if self._depth >= 2: + x = self.patch_embed1(x) + x = self.block1(x) + x = self.norm1(x) + x = x.contiguous() + features.append(x) + + if self._depth >= 3: + x = self.patch_embed2(x) + x = self.block2(x) + x = self.norm2(x) + x = x.contiguous() + features.append(x) + + if self._depth >= 4: + x = self.patch_embed3(x) + x = self.block3(x) + x = self.norm3(x) + x = x.contiguous() + features.append(x) + + if self._depth >= 5: + x = self.patch_embed4(x) + x = self.block4(x) + x = self.norm4(x) + x = x.contiguous() + features.append(x) + return features def load_state_dict(self, state_dict): @@ -560,9 +573,7 @@ def load_state_dict(self, state_dict): def get_pretrained_cfg(name): return { - "url": "https://github.com/qubvel/segmentation_models.pytorch/releases/download/v0.0.2/{}.pth".format( - name - ), + "url": f"https://github.com/qubvel/segmentation_models.pytorch/releases/download/v0.0.2/{name}.pth", "input_space": "RGB", "input_size": [3, 224, 224], "input_range": [0, 1], From 9e381541423a57b3a323578d8a6ed1a38be39ade Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 12 Jan 2025 03:01:57 +0000 Subject: [PATCH 09/57] mobilenet --- .../encoders/mobilenet.py | 36 ++++++++++--------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/segmentation_models_pytorch/encoders/mobilenet.py 
b/segmentation_models_pytorch/encoders/mobilenet.py index dd30f142..f49175d1 100644 --- a/segmentation_models_pytorch/encoders/mobilenet.py +++ b/segmentation_models_pytorch/encoders/mobilenet.py @@ -24,7 +24,6 @@ """ import torchvision -import torch.nn as nn from ._base import EncoderMixin @@ -37,22 +36,27 @@ def __init__(self, out_channels, depth=5, **kwargs): self._in_channels = 3 del self.classifier - def get_stages(self): - return [ - nn.Identity(), - self.features[:2], - self.features[2:4], - self.features[4:7], - self.features[7:14], - self.features[14:], - ] - def forward(self, x): - stages = self.get_stages() - features = [] - for i in range(self._depth + 1): - x = stages[i](x) + + if self._depth >= 1: + x = self.features[:2](x) + features.append(x) + + if self._depth >= 2: + x = self.features[2:4](x) + features.append(x) + + if self._depth >= 3: + x = self.features[4:7](x) + features.append(x) + + if self._depth >= 4: + x = self.features[7:14](x) + features.append(x) + + if self._depth >= 5: + x = self.features[14:](x) features.append(x) return features @@ -68,9 +72,9 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": MobileNetV2Encoder, "pretrained_settings": { "imagenet": { + "url": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225], - "url": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", "input_space": "RGB", "input_range": [0, 1], } From c6e5d53e731bdef4a1111a0966b7877e9547b030 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 12 Jan 2025 03:02:05 +0000 Subject: [PATCH 10/57] mobileone --- .../encoders/mobileone.py | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/segmentation_models_pytorch/encoders/mobileone.py b/segmentation_models_pytorch/encoders/mobileone.py index 1c031d28..c6f7c391 100644 --- a/segmentation_models_pytorch/encoders/mobileone.py +++ b/segmentation_models_pytorch/encoders/mobileone.py @@ -355,16 +355,6 @@ def __init__( num_se_blocks=num_blocks_per_stage[3] if use_se else 0, ) - def get_stages(self): - return [ - nn.Identity(), - self.stage0, - self.stage1, - self.stage2, - self.stage3, - self.stage4, - ] - def _make_stage( self, planes: int, num_blocks: int, num_se_blocks: int ) -> nn.Sequential: @@ -417,13 +407,30 @@ def _make_stage( self.cur_layer_idx += 1 return nn.Sequential(*blocks) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: """Apply forward pass.""" - stages = self.get_stages() features = [] - for i in range(self._depth + 1): - x = stages[i](x) + + if self._depth >= 1: + x = self.stage0(x) + features.append(x) + + if self._depth >= 2: + x = self.stage1(x) + features.append(x) + + if self._depth >= 3: + x = self.stage2(x) + features.append(x) + + if self._depth >= 4: + x = self.stage3(x) features.append(x) + + if self._depth >= 5: + x = self.stage4(x) + features.append(x) + return features def load_state_dict(self, state_dict, **kwargs): From 5a76722da7caf6b286df01679fd879c61be7073f Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 12 Jan 2025 03:02:10 +0000 Subject: [PATCH 11/57] resnet --- .../encoders/resnet.py | 266 ++++++++++++++---- 1 file changed, 217 insertions(+), 49 deletions(-) diff --git a/segmentation_models_pytorch/encoders/resnet.py b/segmentation_models_pytorch/encoders/resnet.py index 2040a42c..fac29c96 100644 --- a/segmentation_models_pytorch/encoders/resnet.py +++ 
b/segmentation_models_pytorch/encoders/resnet.py @@ -23,19 +23,18 @@ depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). """ -from copy import deepcopy - -import torch.nn as nn +import torch from torchvision.models.resnet import ResNet from torchvision.models.resnet import BasicBlock from torchvision.models.resnet import Bottleneck -from pretrainedmodels.models.torchvision_models import pretrained_settings from ._base import EncoderMixin class ResNetEncoder(ResNet, EncoderMixin): + """ResNet encoder implementation.""" + def __init__(self, out_channels, depth=5, **kwargs): super().__init__(**kwargs) self._depth = depth @@ -45,22 +44,30 @@ def __init__(self, out_channels, depth=5, **kwargs): del self.fc del self.avgpool - def get_stages(self): - return [ - nn.Identity(), - nn.Sequential(self.conv1, self.bn1, self.relu), - nn.Sequential(self.maxpool, self.layer1), - self.layer2, - self.layer3, - self.layer4, - ] + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + features = [] - def forward(self, x): - stages = self.get_stages() + if self._depth >= 1: + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + features.append(x) - features = [] - for i in range(self._depth + 1): - x = stages[i](x) + if self._depth >= 2: + x = self.maxpool(x) + x = self.layer1(x) + features.append(x) + + if self._depth >= 3: + x = self.layer2(x) + features.append(x) + + if self._depth >= 4: + x = self.layer3(x) + features.append(x) + + if self._depth >= 5: + x = self.layer4(x) features.append(x) return features @@ -71,58 +78,219 @@ def load_state_dict(self, state_dict, **kwargs): super().load_state_dict(state_dict, **kwargs) -new_settings = { +pretrained_settings = { "resnet18": { - "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth", # noqa - "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth", # noqa + "imagenet": { + "url": "https://download.pytorch.org/models/resnet18-5c106cde.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "ssl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "swsl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + }, + "resnet34": { + "imagenet": { + "url": "https://download.pytorch.org/models/resnet34-333f7ec4.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } }, "resnet50": { - "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth", # noqa - "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth", # noqa + "imagenet": { + "url": "https://download.pytorch.org/models/resnet50-19c8e357.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + 
"std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "ssl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "swsl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + }, + "resnet101": { + "imagenet": { + "url": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "resnet152": { + "imagenet": { + "url": "https://download.pytorch.org/models/resnet152-b121ed2d.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } }, "resnext50_32x4d": { - "imagenet": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", - "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext50_32x4-ddb3e555.pth", # noqa - "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext50_32x4-72679e44.pth", # noqa + "imagenet": { + "url": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "ssl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext50_32x4-ddb3e555.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "swsl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext50_32x4-72679e44.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, }, "resnext101_32x4d": { - "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth", # noqa - "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth", # noqa + "ssl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "swsl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, }, "resnext101_32x8d": { - "imagenet": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", - "instagram": "https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth", - "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth", # noqa - "swsl": 
"https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth", # noqa + "imagenet": { + "url": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "instagram": { + "url": "https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "ssl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "swsl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, }, "resnext101_32x16d": { - "instagram": "https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth", - "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth", # noqa - "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth", # noqa + "instagram": { + "url": "https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "ssl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "swsl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, }, "resnext101_32x32d": { - "instagram": "https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth" + "instagram": { + "url": "https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } }, "resnext101_32x48d": { - "instagram": "https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth" - }, -} - -pretrained_settings = deepcopy(pretrained_settings) -for model_name, sources in new_settings.items(): - if model_name not in pretrained_settings: - pretrained_settings[model_name] = {} - - for source_name, source_url in sources.items(): - pretrained_settings[model_name][source_name] = { - "url": source_url, + "instagram": { + "url": "https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth", "input_size": [3, 224, 224], "input_range": [0, 1], "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225], "num_classes": 1000, } - + }, +} resnet_encoders = { "resnet18": { From 36d056be0d8e37f16a374efc4a5e0827eea7af84 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 12 
Jan 2025 03:02:15 +0000 Subject: [PATCH 12/57] senet --- segmentation_models_pytorch/encoders/senet.py | 95 ++++++++++++++++++- 1 file changed, 90 insertions(+), 5 deletions(-) diff --git a/segmentation_models_pytorch/encoders/senet.py b/segmentation_models_pytorch/encoders/senet.py index 8e0f6fd8..506a9717 100644 --- a/segmentation_models_pytorch/encoders/senet.py +++ b/segmentation_models_pytorch/encoders/senet.py @@ -30,7 +30,6 @@ SEBottleneck, SEResNetBottleneck, SEResNeXtBottleneck, - pretrained_settings, ) from ._base import EncoderMixin @@ -57,11 +56,27 @@ def get_stages(self): ] def forward(self, x): - stages = self.get_stages() - features = [] - for i in range(self._depth + 1): - x = stages[i](x) + + if self._depth >= 1: + x = self.layer0[:-1](x) + features.append(x) + + if self._depth >= 2: + x = self.layer0[-1](x) + x = self.layer1(x) + features.append(x) + + if self._depth >= 3: + x = self.layer2(x) + features.append(x) + + if self._depth >= 4: + x = self.layer3(x) + features.append(x) + + if self._depth >= 5: + x = self.layer4(x) features.append(x) return features @@ -72,6 +87,76 @@ def load_state_dict(self, state_dict, **kwargs): super().load_state_dict(state_dict, **kwargs) +pretrained_settings = { + "senet154": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "se_resnet50": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "se_resnet101": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "se_resnet152": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "se_resnext50_32x4d": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "se_resnext101_32x4d": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, +} + + senet_encoders = { "senet154": { "encoder": SENetEncoder, From aefcfd4d6b7dbc03085d32ce9bb22263b17b1173 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 12 Jan 2025 03:02:20 +0000 Subject: [PATCH 13/57] vgg --- segmentation_models_pytorch/encoders/vgg.py | 115 +++++++++++++++++++- 1 file changed, 109 insertions(+), 6 deletions(-) diff --git a/segmentation_models_pytorch/encoders/vgg.py b/segmentation_models_pytorch/encoders/vgg.py index cbc602c8..7ff7843a 100644 --- a/segmentation_models_pytorch/encoders/vgg.py 
+++ b/segmentation_models_pytorch/encoders/vgg.py @@ -26,7 +26,6 @@ import torch.nn as nn from torchvision.models.vgg import VGG from torchvision.models.vgg import make_layers -from pretrainedmodels.models.torchvision_models import pretrained_settings from ._base import EncoderMixin @@ -66,12 +65,25 @@ def get_stages(self): return stages def forward(self, x): - stages = self.get_stages() - features = [] - for i in range(self._depth + 1): - x = stages[i](x) - features.append(x) + depth = 0 + + for i, module in enumerate(self.features): + if isinstance(module, nn.MaxPool2d): + features.append(x) + depth += 1 + + # last layer is always maxpool, we just apply it and break + if i == len(self.features) - 1: + x = module(x) + features.append(x) + break + + # if depth is reached, break + if depth > self._depth: + break + + x = module(x) return features @@ -83,6 +95,97 @@ def load_state_dict(self, state_dict, **kwargs): super().load_state_dict(state_dict, **kwargs) +pretrained_settings = { + "vgg11": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "vgg11_bn": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "vgg13": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg13-c768596a.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "vgg13_bn": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "vgg16": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg16-397923af.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "vgg16_bn": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "vgg19": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "vgg19_bn": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, +} + vgg_encoders = { "vgg11": { "encoder": VGGEncoder, From e9628bf8a254a4c263b48dcc3d24a966c8a034c5 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 12 Jan 2025 03:02:26 +0000 Subject: [PATCH 14/57] xception --- .../encoders/xception.py | 83 ++++++++++++------- 1 file changed, 52 insertions(+), 31 deletions(-) diff --git 
a/segmentation_models_pytorch/encoders/xception.py b/segmentation_models_pytorch/encoders/xception.py index c8c476ce..d0ee22f9 100644 --- a/segmentation_models_pytorch/encoders/xception.py +++ b/segmentation_models_pytorch/encoders/xception.py @@ -1,6 +1,3 @@ -import torch.nn as nn - -from pretrainedmodels.models.xception import pretrained_settings from pretrainedmodels.models.xception import Xception from ._base import EncoderMixin @@ -26,36 +23,45 @@ def make_dilated(self, *args, **kwargs): "due to pooling operation for downsampling!" ) - def get_stages(self): - return [ - nn.Identity(), - nn.Sequential( - self.conv1, self.bn1, self.relu, self.conv2, self.bn2, self.relu - ), - self.block1, - self.block2, - nn.Sequential( - self.block3, - self.block4, - self.block5, - self.block6, - self.block7, - self.block8, - self.block9, - self.block10, - self.block11, - ), - nn.Sequential( - self.block12, self.conv3, self.bn3, self.relu, self.conv4, self.bn4 - ), - ] - def forward(self, x): - stages = self.get_stages() - features = [] - for i in range(self._depth + 1): - x = stages[i](x) + + if self._depth >= 1: + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + features.append(x) + + if self._depth >= 2: + x = self.block1(x) + features.append(x) + + if self._depth >= 3: + x = self.block2(x) + features.append(x) + + if self._depth >= 4: + x = self.block3(x) + x = self.block4(x) + x = self.block5(x) + x = self.block6(x) + x = self.block7(x) + x = self.block8(x) + x = self.block9(x) + x = self.block10(x) + x = self.block11(x) + features.append(x) + + if self._depth >= 5: + x = self.block12(x) + x = self.conv3(x) + x = self.bn3(x) + x = self.relu(x) + x = self.conv4(x) + x = self.bn4(x) features.append(x) return features @@ -68,6 +74,21 @@ def load_state_dict(self, state_dict): super().load_state_dict(state_dict) +pretrained_settings = { + "xception": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth", + "input_space": "RGB", + "input_size": [3, 299, 299], + "input_range": [0, 1], + "mean": [0.5, 0.5, 0.5], + "std": [0.5, 0.5, 0.5], + "num_classes": 1000, + "scale": 0.8975, # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 + } + } +} + xception_encoders = { "xception": { "encoder": XceptionEncoder, From 70262e51ac8a6d94876c5ac01720ed41264f826a Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 12 Jan 2025 12:37:08 +0000 Subject: [PATCH 15/57] Deprecate `timm-` encoders, remap to `tu-` most of them --- segmentation_models_pytorch/base/model.py | 16 + .../encoders/__init__.py | 62 +++- .../encoders/timm_efficientnet.py | 34 +- .../encoders/timm_gernet.py | 124 ------- .../encoders/timm_mobilenetv3.py | 151 -------- .../encoders/timm_regnet.py | 350 ------------------ .../encoders/timm_res2net.py | 163 -------- .../encoders/timm_resnest.py | 208 ----------- .../encoders/timm_sknet.py | 37 +- .../encoders/timm_universal.py | 1 + tests/encoders/test_smp_encoders.py | 3 + 11 files changed, 105 insertions(+), 1044 deletions(-) delete mode 100644 segmentation_models_pytorch/encoders/timm_gernet.py delete mode 100644 segmentation_models_pytorch/encoders/timm_mobilenetv3.py delete mode 100644 segmentation_models_pytorch/encoders/timm_regnet.py delete mode 100644 segmentation_models_pytorch/encoders/timm_res2net.py delete mode 100644 segmentation_models_pytorch/encoders/timm_resnest.py diff --git a/segmentation_models_pytorch/base/model.py 
b/segmentation_models_pytorch/base/model.py index 6d7bf643..cfd5a0a7 100644 --- a/segmentation_models_pytorch/base/model.py +++ b/segmentation_models_pytorch/base/model.py @@ -2,6 +2,7 @@ from . import initialization as init from .hub_mixin import SMPHubMixin +from ..encoders.timm_universal import TimmUniversalEncoder class SegmentationModel(torch.nn.Module, SMPHubMixin): @@ -73,3 +74,18 @@ def predict(self, x): x = self.forward(x) return x + + def load_state_dict(self, state_dict, **kwargs): + # for compatibility of weights for + # timm- ported encoders with TimmUniversalEncoder + if isinstance(self.encoder, TimmUniversalEncoder): + keys = list(state_dict.keys()) + for key in keys: + new_key = key + if key.startswith("encoder.") and not key.startswith("encoder.model."): + new_key = key.replace("encoder.", "encoder.model.") + if "gernet" in self.encoder.name: + new_key = new_key.replace(".stages.", ".stages_") + state_dict[new_key] = state_dict.pop(key) + + return super().load_state_dict(state_dict, **kwargs) diff --git a/segmentation_models_pytorch/encoders/__init__.py b/segmentation_models_pytorch/encoders/__init__.py index c4a4c037..2dbbf020 100644 --- a/segmentation_models_pytorch/encoders/__init__.py +++ b/segmentation_models_pytorch/encoders/__init__.py @@ -1,4 +1,6 @@ import timm +import copy +import warnings import functools import torch.utils.model_zoo as model_zoo @@ -13,12 +15,7 @@ from .mobilenet import mobilenet_encoders from .xception import xception_encoders from .timm_efficientnet import timm_efficientnet_encoders -from .timm_resnest import timm_resnest_encoders -from .timm_res2net import timm_res2net_encoders -from .timm_regnet import timm_regnet_encoders from .timm_sknet import timm_sknet_encoders -from .timm_mobilenetv3 import timm_mobilenetv3_encoders -from .timm_gernet import timm_gernet_encoders from .mix_transformer import mix_transformer_encoders from .mobileone import mobileone_encoders @@ -26,6 +23,14 @@ from ._preprocessing import preprocess_input +__all__ = [ + "encoders", + "get_encoder", + "get_encoder_names", + "get_preprocessing_params", + "get_preprocessing_fn", +] + encoders = {} encoders.update(resnet_encoders) encoders.update(dpn_encoders) @@ -38,17 +43,38 @@ encoders.update(mobilenet_encoders) encoders.update(xception_encoders) encoders.update(timm_efficientnet_encoders) -encoders.update(timm_resnest_encoders) -encoders.update(timm_res2net_encoders) -encoders.update(timm_regnet_encoders) encoders.update(timm_sknet_encoders) -encoders.update(timm_mobilenetv3_encoders) -encoders.update(timm_gernet_encoders) encoders.update(mix_transformer_encoders) encoders.update(mobileone_encoders) +def is_equivalent_to_timm_universal(name): + patterns = [ + "timm-regnet", + "timm-res2", + "timm-resnest", + "timm-mobilenetv3", + "timm-gernet", + ] + for pattern in patterns: + if name.startswith(pattern): + return True + return False + + def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs): + if name.startswith("timm-"): + warnings.warn( + "`timm-` encoders are deprecated and will be removed in the future. " + "Please use `tu-` encoders instead." 
+ ) + + # convert timm- models to tu- models + if is_equivalent_to_timm_universal(name): + name = name.replace("timm-", "tu-") + if "minimal" in name: + name = name.replace("tu-", "tu-tf_") + if name.startswith("tu-"): name = name[3:] encoder = TimmUniversalEncoder( @@ -61,18 +87,16 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, ** ) return encoder - try: - Encoder = encoders[name]["encoder"] - except KeyError: + if name not in encoders: raise KeyError( - "Wrong encoder name `{}`, supported encoders: {}".format( - name, list(encoders.keys()) - ) + f"Wrong encoder name `{name}`, supported encoders: {list(encoders.keys())}" ) - params = encoders[name]["params"] - params.update(depth=depth) - encoder = Encoder(**params) + params = copy.deepcopy(encoders[name]["params"]) + params["depth"] = depth + + EncoderClass = encoders[name]["encoder"] + encoder = EncoderClass(**params) if weights is not None: try: diff --git a/segmentation_models_pytorch/encoders/timm_efficientnet.py b/segmentation_models_pytorch/encoders/timm_efficientnet.py index fc248575..cf2811eb 100644 --- a/segmentation_models_pytorch/encoders/timm_efficientnet.py +++ b/segmentation_models_pytorch/encoders/timm_efficientnet.py @@ -105,22 +105,28 @@ def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): del self.classifier - def get_stages(self): - return [ - nn.Identity(), - nn.Sequential(self.conv_stem, self.bn1), - self.blocks[: self._stage_idxs[0]], - self.blocks[self._stage_idxs[0] : self._stage_idxs[1]], - self.blocks[self._stage_idxs[1] : self._stage_idxs[2]], - self.blocks[self._stage_idxs[2] :], - ] - def forward(self, x): - stages = self.get_stages() - features = [] - for i in range(self._depth + 1): - x = stages[i](x) + + if self._depth >= 1: + x = self.conv_stem(x) + x = self.bn1(x) + features.append(x) + + if self._depth >= 2: + x = self.blocks[: self._stage_idxs[0]](x) + features.append(x) + + if self._depth >= 3: + x = self.blocks[self._stage_idxs[0] : self._stage_idxs[1]](x) + features.append(x) + + if self._depth >= 4: + x = self.blocks[self._stage_idxs[1] : self._stage_idxs[2]](x) + features.append(x) + + if self._depth >= 5: + x = self.blocks[self._stage_idxs[2] :](x) features.append(x) return features diff --git a/segmentation_models_pytorch/encoders/timm_gernet.py b/segmentation_models_pytorch/encoders/timm_gernet.py deleted file mode 100644 index e0c3354d..00000000 --- a/segmentation_models_pytorch/encoders/timm_gernet.py +++ /dev/null @@ -1,124 +0,0 @@ -from timm.models import ByoModelCfg, ByoBlockCfg, ByobNet - -from ._base import EncoderMixin -import torch.nn as nn - - -class GERNetEncoder(ByobNet, EncoderMixin): - def __init__(self, out_channels, depth=5, **kwargs): - super().__init__(**kwargs) - self._depth = depth - self._out_channels = out_channels - self._in_channels = 3 - - del self.head - - def get_stages(self): - return [ - nn.Identity(), - self.stem, - self.stages[0], - self.stages[1], - self.stages[2], - nn.Sequential(self.stages[3], self.stages[4], self.final_conv), - ] - - def forward(self, x): - stages = self.get_stages() - - features = [] - for i in range(self._depth + 1): - x = stages[i](x) - features.append(x) - - return features - - def load_state_dict(self, state_dict, **kwargs): - state_dict.pop("head.fc.weight", None) - state_dict.pop("head.fc.bias", None) - super().load_state_dict(state_dict, **kwargs) - - -regnet_weights = { - "timm-gernet_s": { - "imagenet": 
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_s-756b4751.pth" # noqa - }, - "timm-gernet_m": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_m-0873c53a.pth" # noqa - }, - "timm-gernet_l": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_l-f31e2e8d.pth" # noqa - }, -} - -pretrained_settings = {} -for model_name, sources in regnet_weights.items(): - pretrained_settings[model_name] = {} - for source_name, source_url in sources.items(): - pretrained_settings[model_name][source_name] = { - "url": source_url, - "input_range": [0, 1], - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - "num_classes": 1000, - } - -timm_gernet_encoders = { - "timm-gernet_s": { - "encoder": GERNetEncoder, - "pretrained_settings": pretrained_settings["timm-gernet_s"], - "params": { - "out_channels": (3, 13, 48, 48, 384, 1920), - "cfg": ByoModelCfg( - blocks=( - ByoBlockCfg(type="basic", d=1, c=48, s=2, gs=0, br=1.0), - ByoBlockCfg(type="basic", d=3, c=48, s=2, gs=0, br=1.0), - ByoBlockCfg(type="bottle", d=7, c=384, s=2, gs=0, br=1 / 4), - ByoBlockCfg(type="bottle", d=2, c=560, s=2, gs=1, br=3.0), - ByoBlockCfg(type="bottle", d=1, c=256, s=1, gs=1, br=3.0), - ), - stem_chs=13, - stem_pool=None, - num_features=1920, - ), - }, - }, - "timm-gernet_m": { - "encoder": GERNetEncoder, - "pretrained_settings": pretrained_settings["timm-gernet_m"], - "params": { - "out_channels": (3, 32, 128, 192, 640, 2560), - "cfg": ByoModelCfg( - blocks=( - ByoBlockCfg(type="basic", d=1, c=128, s=2, gs=0, br=1.0), - ByoBlockCfg(type="basic", d=2, c=192, s=2, gs=0, br=1.0), - ByoBlockCfg(type="bottle", d=6, c=640, s=2, gs=0, br=1 / 4), - ByoBlockCfg(type="bottle", d=4, c=640, s=2, gs=1, br=3.0), - ByoBlockCfg(type="bottle", d=1, c=640, s=1, gs=1, br=3.0), - ), - stem_chs=32, - stem_pool=None, - num_features=2560, - ), - }, - }, - "timm-gernet_l": { - "encoder": GERNetEncoder, - "pretrained_settings": pretrained_settings["timm-gernet_l"], - "params": { - "out_channels": (3, 32, 128, 192, 640, 2560), - "cfg": ByoModelCfg( - blocks=( - ByoBlockCfg(type="basic", d=1, c=128, s=2, gs=0, br=1.0), - ByoBlockCfg(type="basic", d=2, c=192, s=2, gs=0, br=1.0), - ByoBlockCfg(type="bottle", d=6, c=640, s=2, gs=0, br=1 / 4), - ByoBlockCfg(type="bottle", d=5, c=640, s=2, gs=1, br=3.0), - ByoBlockCfg(type="bottle", d=4, c=640, s=1, gs=1, br=3.0), - ), - stem_chs=32, - stem_pool=None, - num_features=2560, - ), - }, - }, -} diff --git a/segmentation_models_pytorch/encoders/timm_mobilenetv3.py b/segmentation_models_pytorch/encoders/timm_mobilenetv3.py deleted file mode 100644 index ff733ab9..00000000 --- a/segmentation_models_pytorch/encoders/timm_mobilenetv3.py +++ /dev/null @@ -1,151 +0,0 @@ -import timm -import numpy as np -import torch.nn as nn - -from ._base import EncoderMixin - - -def _make_divisible(x, divisible_by=8): - return int(np.ceil(x * 1.0 / divisible_by) * divisible_by) - - -class MobileNetV3Encoder(nn.Module, EncoderMixin): - def __init__(self, model_name, width_mult, depth=5, **kwargs): - super().__init__() - if "large" not in model_name and "small" not in model_name: - raise ValueError("MobileNetV3 wrong model name {}".format(model_name)) - - self._mode = "small" if "small" in model_name else "large" - self._depth = depth - self._out_channels = self._get_channels(self._mode, width_mult) - self._in_channels = 3 - - # minimal models replace hardswish with relu - 
self.model = timm.create_model( - model_name=model_name, - scriptable=True, # torch.jit scriptable - exportable=True, # onnx export - features_only=True, - ) - - def _get_channels(self, mode, width_mult): - if mode == "small": - channels = [16, 16, 24, 48, 576] - else: - channels = [16, 24, 40, 112, 960] - channels = [3] + [_make_divisible(x * width_mult) for x in channels] - return tuple(channels) - - def get_stages(self): - if self._mode == "small": - return [ - nn.Identity(), - nn.Sequential(self.model.conv_stem, self.model.bn1, self.model.act1), - self.model.blocks[0], - self.model.blocks[1], - self.model.blocks[2:4], - self.model.blocks[4:], - ] - elif self._mode == "large": - return [ - nn.Identity(), - nn.Sequential( - self.model.conv_stem, - self.model.bn1, - self.model.act1, - self.model.blocks[0], - ), - self.model.blocks[1], - self.model.blocks[2], - self.model.blocks[3:5], - self.model.blocks[5:], - ] - else: - ValueError( - "MobileNetV3 mode should be small or large, got {}".format(self._mode) - ) - - def forward(self, x): - stages = self.get_stages() - - features = [] - for i in range(self._depth + 1): - x = stages[i](x) - features.append(x) - - return features - - def load_state_dict(self, state_dict, **kwargs): - state_dict.pop("conv_head.weight", None) - state_dict.pop("conv_head.bias", None) - state_dict.pop("classifier.weight", None) - state_dict.pop("classifier.bias", None) - self.model.load_state_dict(state_dict, **kwargs) - - -mobilenetv3_weights = { - "tf_mobilenetv3_large_075": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth" # noqa - }, - "tf_mobilenetv3_large_100": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth" # noqa - }, - "tf_mobilenetv3_large_minimal_100": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth" # noqa - }, - "tf_mobilenetv3_small_075": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth" # noqa - }, - "tf_mobilenetv3_small_100": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth" # noqa - }, - "tf_mobilenetv3_small_minimal_100": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth" # noqa - }, -} - -pretrained_settings = {} -for model_name, sources in mobilenetv3_weights.items(): - pretrained_settings[model_name] = {} - for source_name, source_url in sources.items(): - pretrained_settings[model_name][source_name] = { - "url": source_url, - "input_range": [0, 1], - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - "input_space": "RGB", - } - - -timm_mobilenetv3_encoders = { - "timm-mobilenetv3_large_075": { - "encoder": MobileNetV3Encoder, - "pretrained_settings": pretrained_settings["tf_mobilenetv3_large_075"], - "params": {"model_name": "tf_mobilenetv3_large_075", "width_mult": 0.75}, - }, - "timm-mobilenetv3_large_100": { - "encoder": MobileNetV3Encoder, - "pretrained_settings": pretrained_settings["tf_mobilenetv3_large_100"], - "params": {"model_name": "tf_mobilenetv3_large_100", "width_mult": 1.0}, - }, - "timm-mobilenetv3_large_minimal_100": { - "encoder": MobileNetV3Encoder, - "pretrained_settings": 
pretrained_settings["tf_mobilenetv3_large_minimal_100"], - "params": {"model_name": "tf_mobilenetv3_large_minimal_100", "width_mult": 1.0}, - }, - "timm-mobilenetv3_small_075": { - "encoder": MobileNetV3Encoder, - "pretrained_settings": pretrained_settings["tf_mobilenetv3_small_075"], - "params": {"model_name": "tf_mobilenetv3_small_075", "width_mult": 0.75}, - }, - "timm-mobilenetv3_small_100": { - "encoder": MobileNetV3Encoder, - "pretrained_settings": pretrained_settings["tf_mobilenetv3_small_100"], - "params": {"model_name": "tf_mobilenetv3_small_100", "width_mult": 1.0}, - }, - "timm-mobilenetv3_small_minimal_100": { - "encoder": MobileNetV3Encoder, - "pretrained_settings": pretrained_settings["tf_mobilenetv3_small_minimal_100"], - "params": {"model_name": "tf_mobilenetv3_small_minimal_100", "width_mult": 1.0}, - }, -} diff --git a/segmentation_models_pytorch/encoders/timm_regnet.py b/segmentation_models_pytorch/encoders/timm_regnet.py deleted file mode 100644 index cc60b8ba..00000000 --- a/segmentation_models_pytorch/encoders/timm_regnet.py +++ /dev/null @@ -1,350 +0,0 @@ -from ._base import EncoderMixin -from timm.models.regnet import RegNet, RegNetCfg -import torch.nn as nn - - -class RegNetEncoder(RegNet, EncoderMixin): - def __init__(self, out_channels, depth=5, **kwargs): - kwargs["cfg"] = RegNetCfg(**kwargs["cfg"]) - super().__init__(**kwargs) - self._depth = depth - self._out_channels = out_channels - self._in_channels = 3 - - del self.head - - def get_stages(self): - return [nn.Identity(), self.stem, self.s1, self.s2, self.s3, self.s4] - - def forward(self, x): - stages = self.get_stages() - - features = [] - for i in range(self._depth + 1): - x = stages[i](x) - features.append(x) - - return features - - def load_state_dict(self, state_dict, **kwargs): - state_dict.pop("head.fc.weight", None) - state_dict.pop("head.fc.bias", None) - super().load_state_dict(state_dict, **kwargs) - - -regnet_weights = { - "timm-regnetx_002": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_002-e7e85e5c.pth" # noqa - }, - "timm-regnetx_004": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_004-7d0e9424.pth" # noqa - }, - "timm-regnetx_006": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_006-85ec1baa.pth" # noqa - }, - "timm-regnetx_008": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_008-d8b470eb.pth" # noqa - }, - "timm-regnetx_016": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_016-65ca972a.pth" # noqa - }, - "timm-regnetx_032": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_032-ed0c7f7e.pth" # noqa - }, - "timm-regnetx_040": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_040-73c2a654.pth" # noqa - }, - "timm-regnetx_064": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_064-29278baa.pth" # noqa - }, - "timm-regnetx_080": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_080-7c7fcab1.pth" # noqa - }, - "timm-regnetx_120": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth" # noqa - }, - "timm-regnetx_160": { - 
"imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth" # noqa - }, - "timm-regnetx_320": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth" # noqa - }, - "timm-regnety_002": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth" # noqa - }, - "timm-regnety_004": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth" # noqa - }, - "timm-regnety_006": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth" # noqa - }, - "timm-regnety_008": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth" # noqa - }, - "timm-regnety_016": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth" # noqa - }, - "timm-regnety_032": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth" # noqa - }, - "timm-regnety_040": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth" # noqa - }, - "timm-regnety_064": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth" # noqa - }, - "timm-regnety_080": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth" # noqa - }, - "timm-regnety_120": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth" # noqa - }, - "timm-regnety_160": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_160-d64013cd.pth" # noqa - }, - "timm-regnety_320": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth" # noqa - }, -} - -pretrained_settings = {} -for model_name, sources in regnet_weights.items(): - pretrained_settings[model_name] = {} - for source_name, source_url in sources.items(): - pretrained_settings[model_name][source_name] = { - "url": source_url, - "input_size": [3, 224, 224], - "input_range": [0, 1], - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - "num_classes": 1000, - } - -# at this point I am too lazy to copy configs, so I just used the same configs from timm's repo - - -def _mcfg(**kwargs): - cfg = dict(se_ratio=0.0, bottle_ratio=1.0, stem_width=32) - cfg.update(**kwargs) - return cfg - - -timm_regnet_encoders = { - "timm-regnetx_002": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_002"], - "params": { - "out_channels": (3, 32, 24, 56, 152, 368), - "cfg": _mcfg(w0=24, wa=36.44, wm=2.49, group_size=8, depth=13), - }, - }, - "timm-regnetx_004": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_004"], - "params": { - "out_channels": (3, 32, 32, 64, 160, 384), - "cfg": _mcfg(w0=24, wa=24.48, wm=2.54, group_size=16, depth=22), - }, - }, - "timm-regnetx_006": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_006"], - "params": { - "out_channels": (3, 32, 48, 96, 240, 528), - "cfg": _mcfg(w0=48, wa=36.97, wm=2.24, 
group_size=24, depth=16), - }, - }, - "timm-regnetx_008": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_008"], - "params": { - "out_channels": (3, 32, 64, 128, 288, 672), - "cfg": _mcfg(w0=56, wa=35.73, wm=2.28, group_size=16, depth=16), - }, - }, - "timm-regnetx_016": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_016"], - "params": { - "out_channels": (3, 32, 72, 168, 408, 912), - "cfg": _mcfg(w0=80, wa=34.01, wm=2.25, group_size=24, depth=18), - }, - }, - "timm-regnetx_032": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_032"], - "params": { - "out_channels": (3, 32, 96, 192, 432, 1008), - "cfg": _mcfg(w0=88, wa=26.31, wm=2.25, group_size=48, depth=25), - }, - }, - "timm-regnetx_040": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_040"], - "params": { - "out_channels": (3, 32, 80, 240, 560, 1360), - "cfg": _mcfg(w0=96, wa=38.65, wm=2.43, group_size=40, depth=23), - }, - }, - "timm-regnetx_064": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_064"], - "params": { - "out_channels": (3, 32, 168, 392, 784, 1624), - "cfg": _mcfg(w0=184, wa=60.83, wm=2.07, group_size=56, depth=17), - }, - }, - "timm-regnetx_080": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_080"], - "params": { - "out_channels": (3, 32, 80, 240, 720, 1920), - "cfg": _mcfg(w0=80, wa=49.56, wm=2.88, group_size=120, depth=23), - }, - }, - "timm-regnetx_120": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_120"], - "params": { - "out_channels": (3, 32, 224, 448, 896, 2240), - "cfg": _mcfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19), - }, - }, - "timm-regnetx_160": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_160"], - "params": { - "out_channels": (3, 32, 256, 512, 896, 2048), - "cfg": _mcfg(w0=216, wa=55.59, wm=2.1, group_size=128, depth=22), - }, - }, - "timm-regnetx_320": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_320"], - "params": { - "out_channels": (3, 32, 336, 672, 1344, 2520), - "cfg": _mcfg(w0=320, wa=69.86, wm=2.0, group_size=168, depth=23), - }, - }, - # regnety - "timm-regnety_002": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_002"], - "params": { - "out_channels": (3, 32, 24, 56, 152, 368), - "cfg": _mcfg( - w0=24, wa=36.44, wm=2.49, group_size=8, depth=13, se_ratio=0.25 - ), - }, - }, - "timm-regnety_004": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_004"], - "params": { - "out_channels": (3, 32, 48, 104, 208, 440), - "cfg": _mcfg( - w0=48, wa=27.89, wm=2.09, group_size=8, depth=16, se_ratio=0.25 - ), - }, - }, - "timm-regnety_006": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_006"], - "params": { - "out_channels": (3, 32, 48, 112, 256, 608), - "cfg": _mcfg( - w0=48, wa=32.54, wm=2.32, group_size=16, depth=15, se_ratio=0.25 - ), - }, - }, - "timm-regnety_008": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_008"], - "params": { - "out_channels": (3, 32, 64, 128, 320, 768), - "cfg": _mcfg( - w0=56, wa=38.84, wm=2.4, group_size=16, depth=14, se_ratio=0.25 - ), - }, - }, - "timm-regnety_016": { - "encoder": RegNetEncoder, - "pretrained_settings": 
pretrained_settings["timm-regnety_016"], - "params": { - "out_channels": (3, 32, 48, 120, 336, 888), - "cfg": _mcfg( - w0=48, wa=20.71, wm=2.65, group_size=24, depth=27, se_ratio=0.25 - ), - }, - }, - "timm-regnety_032": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_032"], - "params": { - "out_channels": (3, 32, 72, 216, 576, 1512), - "cfg": _mcfg( - w0=80, wa=42.63, wm=2.66, group_size=24, depth=21, se_ratio=0.25 - ), - }, - }, - "timm-regnety_040": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_040"], - "params": { - "out_channels": (3, 32, 128, 192, 512, 1088), - "cfg": _mcfg( - w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25 - ), - }, - }, - "timm-regnety_064": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_064"], - "params": { - "out_channels": (3, 32, 144, 288, 576, 1296), - "cfg": _mcfg( - w0=112, wa=33.22, wm=2.27, group_size=72, depth=25, se_ratio=0.25 - ), - }, - }, - "timm-regnety_080": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_080"], - "params": { - "out_channels": (3, 32, 168, 448, 896, 2016), - "cfg": _mcfg( - w0=192, wa=76.82, wm=2.19, group_size=56, depth=17, se_ratio=0.25 - ), - }, - }, - "timm-regnety_120": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_120"], - "params": { - "out_channels": (3, 32, 224, 448, 896, 2240), - "cfg": _mcfg( - w0=168, wa=73.36, wm=2.37, group_size=112, depth=19, se_ratio=0.25 - ), - }, - }, - "timm-regnety_160": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_160"], - "params": { - "out_channels": (3, 32, 224, 448, 1232, 3024), - "cfg": _mcfg( - w0=200, wa=106.23, wm=2.48, group_size=112, depth=18, se_ratio=0.25 - ), - }, - }, - "timm-regnety_320": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_320"], - "params": { - "out_channels": (3, 32, 232, 696, 1392, 3712), - "cfg": _mcfg( - w0=232, wa=115.89, wm=2.53, group_size=232, depth=20, se_ratio=0.25 - ), - }, - }, -} diff --git a/segmentation_models_pytorch/encoders/timm_res2net.py b/segmentation_models_pytorch/encoders/timm_res2net.py deleted file mode 100644 index e97043e3..00000000 --- a/segmentation_models_pytorch/encoders/timm_res2net.py +++ /dev/null @@ -1,163 +0,0 @@ -from ._base import EncoderMixin -from timm.models.resnet import ResNet -from timm.models.res2net import Bottle2neck -import torch.nn as nn - - -class Res2NetEncoder(ResNet, EncoderMixin): - def __init__(self, out_channels, depth=5, **kwargs): - super().__init__(**kwargs) - self._depth = depth - self._out_channels = out_channels - self._in_channels = 3 - - del self.fc - del self.global_pool - - def get_stages(self): - return [ - nn.Identity(), - nn.Sequential(self.conv1, self.bn1, self.act1), - nn.Sequential(self.maxpool, self.layer1), - self.layer2, - self.layer3, - self.layer4, - ] - - def make_dilated(self, *args, **kwargs): - raise ValueError("Res2Net encoders do not support dilated mode") - - def forward(self, x): - stages = self.get_stages() - - features = [] - for i in range(self._depth + 1): - x = stages[i](x) - features.append(x) - - return features - - def load_state_dict(self, state_dict, **kwargs): - state_dict.pop("fc.bias", None) - state_dict.pop("fc.weight", None) - super().load_state_dict(state_dict, **kwargs) - - -res2net_weights = { - "timm-res2net50_26w_4s": { - "imagenet": 
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_4s-06e79181.pth" # noqa - }, - "timm-res2net50_48w_2s": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_48w_2s-afed724a.pth" # noqa - }, - "timm-res2net50_14w_8s": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_14w_8s-6527dddc.pth" # noqa - }, - "timm-res2net50_26w_6s": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_6s-19041792.pth" # noqa - }, - "timm-res2net50_26w_8s": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_8s-2c7c9f12.pth" # noqa - }, - "timm-res2net101_26w_4s": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net101_26w_4s-02a759a1.pth" # noqa - }, - "timm-res2next50": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next50_4s-6ef7e7bf.pth" # noqa - }, -} - -pretrained_settings = {} -for model_name, sources in res2net_weights.items(): - pretrained_settings[model_name] = {} - for source_name, source_url in sources.items(): - pretrained_settings[model_name][source_name] = { - "url": source_url, - "input_size": [3, 224, 224], - "input_range": [0, 1], - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - "num_classes": 1000, - } - - -timm_res2net_encoders = { - "timm-res2net50_26w_4s": { - "encoder": Res2NetEncoder, - "pretrained_settings": pretrained_settings["timm-res2net50_26w_4s"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - "block": Bottle2neck, - "layers": [3, 4, 6, 3], - "base_width": 26, - "block_args": {"scale": 4}, - }, - }, - "timm-res2net101_26w_4s": { - "encoder": Res2NetEncoder, - "pretrained_settings": pretrained_settings["timm-res2net101_26w_4s"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - "block": Bottle2neck, - "layers": [3, 4, 23, 3], - "base_width": 26, - "block_args": {"scale": 4}, - }, - }, - "timm-res2net50_26w_6s": { - "encoder": Res2NetEncoder, - "pretrained_settings": pretrained_settings["timm-res2net50_26w_6s"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - "block": Bottle2neck, - "layers": [3, 4, 6, 3], - "base_width": 26, - "block_args": {"scale": 6}, - }, - }, - "timm-res2net50_26w_8s": { - "encoder": Res2NetEncoder, - "pretrained_settings": pretrained_settings["timm-res2net50_26w_8s"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - "block": Bottle2neck, - "layers": [3, 4, 6, 3], - "base_width": 26, - "block_args": {"scale": 8}, - }, - }, - "timm-res2net50_48w_2s": { - "encoder": Res2NetEncoder, - "pretrained_settings": pretrained_settings["timm-res2net50_48w_2s"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - "block": Bottle2neck, - "layers": [3, 4, 6, 3], - "base_width": 48, - "block_args": {"scale": 2}, - }, - }, - "timm-res2net50_14w_8s": { - "encoder": Res2NetEncoder, - "pretrained_settings": pretrained_settings["timm-res2net50_14w_8s"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - "block": Bottle2neck, - "layers": [3, 4, 6, 3], - "base_width": 14, - "block_args": {"scale": 8}, - }, - }, - "timm-res2next50": { - "encoder": Res2NetEncoder, - "pretrained_settings": pretrained_settings["timm-res2next50"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - 
"block": Bottle2neck, - "layers": [3, 4, 6, 3], - "base_width": 4, - "cardinality": 8, - "block_args": {"scale": 4}, - }, - }, -} diff --git a/segmentation_models_pytorch/encoders/timm_resnest.py b/segmentation_models_pytorch/encoders/timm_resnest.py deleted file mode 100644 index 1599b6c8..00000000 --- a/segmentation_models_pytorch/encoders/timm_resnest.py +++ /dev/null @@ -1,208 +0,0 @@ -from ._base import EncoderMixin -from timm.models.resnet import ResNet -from timm.models.resnest import ResNestBottleneck -import torch.nn as nn - - -class ResNestEncoder(ResNet, EncoderMixin): - def __init__(self, out_channels, depth=5, **kwargs): - super().__init__(**kwargs) - self._depth = depth - self._out_channels = out_channels - self._in_channels = 3 - - del self.fc - del self.global_pool - - def get_stages(self): - return [ - nn.Identity(), - nn.Sequential(self.conv1, self.bn1, self.act1), - nn.Sequential(self.maxpool, self.layer1), - self.layer2, - self.layer3, - self.layer4, - ] - - def make_dilated(self, *args, **kwargs): - raise ValueError("ResNest encoders do not support dilated mode") - - def forward(self, x): - stages = self.get_stages() - - features = [] - for i in range(self._depth + 1): - x = stages[i](x) - features.append(x) - - return features - - def load_state_dict(self, state_dict, **kwargs): - state_dict.pop("fc.bias", None) - state_dict.pop("fc.weight", None) - super().load_state_dict(state_dict, **kwargs) - - -resnest_weights = { - "timm-resnest14d": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth" # noqa - }, - "timm-resnest26d": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth" # noqa - }, - "timm-resnest50d": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50-528c19ca.pth" # noqa - }, - "timm-resnest101e": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest101-22405ba7.pth" # noqa - }, - "timm-resnest200e": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest200-75117900.pth" # noqa - }, - "timm-resnest269e": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest269-0cc87c48.pth" # noqa - }, - "timm-resnest50d_4s2x40d": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_4s2x40d-41d14ed0.pth" # noqa - }, - "timm-resnest50d_1s4x24d": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_1s4x24d-d4a4f76f.pth" # noqa - }, -} - -pretrained_settings = {} -for model_name, sources in resnest_weights.items(): - pretrained_settings[model_name] = {} - for source_name, source_url in sources.items(): - pretrained_settings[model_name][source_name] = { - "url": source_url, - "input_size": [3, 224, 224], - "input_range": [0, 1], - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - "num_classes": 1000, - } - - -timm_resnest_encoders = { - "timm-resnest14d": { - "encoder": ResNestEncoder, - "pretrained_settings": pretrained_settings["timm-resnest14d"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - "block": ResNestBottleneck, - "layers": [1, 1, 1, 1], - "stem_type": "deep", - "stem_width": 32, - "avg_down": True, - "base_width": 64, - "cardinality": 1, - "block_args": {"radix": 
2, "avd": True, "avd_first": False}, - }, - }, - "timm-resnest26d": { - "encoder": ResNestEncoder, - "pretrained_settings": pretrained_settings["timm-resnest26d"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - "block": ResNestBottleneck, - "layers": [2, 2, 2, 2], - "stem_type": "deep", - "stem_width": 32, - "avg_down": True, - "base_width": 64, - "cardinality": 1, - "block_args": {"radix": 2, "avd": True, "avd_first": False}, - }, - }, - "timm-resnest50d": { - "encoder": ResNestEncoder, - "pretrained_settings": pretrained_settings["timm-resnest50d"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - "block": ResNestBottleneck, - "layers": [3, 4, 6, 3], - "stem_type": "deep", - "stem_width": 32, - "avg_down": True, - "base_width": 64, - "cardinality": 1, - "block_args": {"radix": 2, "avd": True, "avd_first": False}, - }, - }, - "timm-resnest101e": { - "encoder": ResNestEncoder, - "pretrained_settings": pretrained_settings["timm-resnest101e"], - "params": { - "out_channels": (3, 128, 256, 512, 1024, 2048), - "block": ResNestBottleneck, - "layers": [3, 4, 23, 3], - "stem_type": "deep", - "stem_width": 64, - "avg_down": True, - "base_width": 64, - "cardinality": 1, - "block_args": {"radix": 2, "avd": True, "avd_first": False}, - }, - }, - "timm-resnest200e": { - "encoder": ResNestEncoder, - "pretrained_settings": pretrained_settings["timm-resnest200e"], - "params": { - "out_channels": (3, 128, 256, 512, 1024, 2048), - "block": ResNestBottleneck, - "layers": [3, 24, 36, 3], - "stem_type": "deep", - "stem_width": 64, - "avg_down": True, - "base_width": 64, - "cardinality": 1, - "block_args": {"radix": 2, "avd": True, "avd_first": False}, - }, - }, - "timm-resnest269e": { - "encoder": ResNestEncoder, - "pretrained_settings": pretrained_settings["timm-resnest269e"], - "params": { - "out_channels": (3, 128, 256, 512, 1024, 2048), - "block": ResNestBottleneck, - "layers": [3, 30, 48, 8], - "stem_type": "deep", - "stem_width": 64, - "avg_down": True, - "base_width": 64, - "cardinality": 1, - "block_args": {"radix": 2, "avd": True, "avd_first": False}, - }, - }, - "timm-resnest50d_4s2x40d": { - "encoder": ResNestEncoder, - "pretrained_settings": pretrained_settings["timm-resnest50d_4s2x40d"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - "block": ResNestBottleneck, - "layers": [3, 4, 6, 3], - "stem_type": "deep", - "stem_width": 32, - "avg_down": True, - "base_width": 40, - "cardinality": 2, - "block_args": {"radix": 4, "avd": True, "avd_first": True}, - }, - }, - "timm-resnest50d_1s4x24d": { - "encoder": ResNestEncoder, - "pretrained_settings": pretrained_settings["timm-resnest50d_1s4x24d"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - "block": ResNestBottleneck, - "layers": [3, 4, 6, 3], - "stem_type": "deep", - "stem_width": 32, - "avg_down": True, - "base_width": 24, - "cardinality": 4, - "block_args": {"radix": 1, "avd": True, "avd_first": True}, - }, - }, -} diff --git a/segmentation_models_pytorch/encoders/timm_sknet.py b/segmentation_models_pytorch/encoders/timm_sknet.py index 14d6d2b0..50c760a1 100644 --- a/segmentation_models_pytorch/encoders/timm_sknet.py +++ b/segmentation_models_pytorch/encoders/timm_sknet.py @@ -1,7 +1,6 @@ from ._base import EncoderMixin from timm.models.resnet import ResNet from timm.models.sknet import SelectiveKernelBottleneck, SelectiveKernelBasic -import torch.nn as nn class SkNetEncoder(ResNet, EncoderMixin): @@ -14,22 +13,30 @@ def __init__(self, out_channels, depth=5, **kwargs): del self.fc 
del self.global_pool - def get_stages(self): - return [ - nn.Identity(), - nn.Sequential(self.conv1, self.bn1, self.act1), - nn.Sequential(self.maxpool, self.layer1), - self.layer2, - self.layer3, - self.layer4, - ] - def forward(self, x): - stages = self.get_stages() - features = [] - for i in range(self._depth + 1): - x = stages[i](x) + + if self._depth >= 1: + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + features.append(x) + + if self._depth >= 2: + x = self.maxpool(x) + x = self.layer1(x) + features.append(x) + + if self._depth >= 3: + x = self.layer2(x) + features.append(x) + + if self._depth >= 4: + x = self.layer3(x) + features.append(x) + + if self._depth >= 5: + x = self.layer4(x) features.append(x) return features diff --git a/segmentation_models_pytorch/encoders/timm_universal.py b/segmentation_models_pytorch/encoders/timm_universal.py index 9bdcb188..029aef8b 100644 --- a/segmentation_models_pytorch/encoders/timm_universal.py +++ b/segmentation_models_pytorch/encoders/timm_universal.py @@ -65,6 +65,7 @@ def __init__( **kwargs: Additional arguments passed to `timm.create_model`. """ super().__init__() + self.name = name # Default model configuration for feature extraction common_kwargs = dict( diff --git a/tests/encoders/test_smp_encoders.py b/tests/encoders/test_smp_encoders.py index 863537bf..cab7b53d 100644 --- a/tests/encoders/test_smp_encoders.py +++ b/tests/encoders/test_smp_encoders.py @@ -39,3 +39,6 @@ class TestEfficientNetEncoder(base.BaseEncoderTester): # "efficientnet-b7", # extra large model ] ) + + def test_compile(self): + self.skipTest("compile fullgraph is not supported for efficientnet encoders") From 0b0b1c4d0ca090fcab19b82804dbd5a356464998 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 12 Jan 2025 14:03:50 +0000 Subject: [PATCH 16/57] Add tiny encoders and compile mark --- pyproject.toml | 1 + tests/encoders/base.py | 21 ++++++------ .../test_pretrainedmodels_encoders.py | 33 +++++++++++++++++++ tests/encoders/test_smp_encoders.py | 20 +++++++++++ tests/encoders/test_timm_ported_encoders.py | 15 +++++++++ tests/encoders/test_timm_universal.py | 2 +- tests/encoders/test_torchvision_encoders.py | 27 +++++++++++++++ 7 files changed, 107 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3e44d723..77d1b15a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ markers = [ "unetplusplus", "upernet", "logits_match", + "compile", ] [tool.coverage.run] diff --git a/tests/encoders/base.py b/tests/encoders/base.py index 78c9e170..b2618e1f 100644 --- a/tests/encoders/base.py +++ b/tests/encoders/base.py @@ -1,3 +1,4 @@ +import pytest import unittest import torch import segmentation_models_pytorch as smp @@ -25,6 +26,9 @@ class BaseEncoderTester(unittest.TestCase): depth_to_test = [3, 4, 5] strides_to_test = [8, 16] # 32 is a default one + def get_tiny_encoder(self): + return smp.encoders.get_encoder(self.encoder_names[0], encoder_weights=None) + @lru_cache def _get_sample(self, batch_size=1, num_channels=3, height=32, width=32): return torch.rand(batch_size, num_channels, height, width) @@ -207,7 +211,7 @@ def test_dilated(self): f"Encoder `{encoder_name}` should have width output strides {expected_width_strides}, but has {width_strides}", ) - @torch.inference_mode() + @pytest.mark.compile def test_compile(self): sample = self._get_sample( batch_size=self.default_batch_size, @@ -216,13 +220,8 @@ def test_compile(self): width=self.default_width, ).to(default_device) - for encoder_name in 
self.encoder_names: - with self.subTest(encoder_name=encoder_name): - encoder = smp.encoders.get_encoder( - encoder_name, - in_channels=self.default_num_channels, - encoder_weights=None, - ).to(default_device) - encoder.eval() - compiled_encoder = torch.compile(encoder, fullgraph=True, dynamic=True) - compiled_encoder(sample) + encoder = self.get_tiny_encoder().eval().to(default_device) + compiled_encoder = torch.compile(encoder, fullgraph=True, dynamic=True) + + with torch.inference_mode(): + compiled_encoder(sample) diff --git a/tests/encoders/test_pretrainedmodels_encoders.py b/tests/encoders/test_pretrainedmodels_encoders.py index a19e335d..5756b349 100644 --- a/tests/encoders/test_pretrainedmodels_encoders.py +++ b/tests/encoders/test_pretrainedmodels_encoders.py @@ -1,3 +1,5 @@ +import segmentation_models_pytorch as smp + from tests.encoders import base from tests.utils import RUN_ALL_ENCODERS @@ -9,6 +11,21 @@ class TestDPNEncoder(base.BaseEncoderTester): else ["dpn68", "dpn68b", "dpn92", "dpn98", "dpn107", "dpn131"] ) + def get_tiny_encoder(self): + params = { + "stage_idxs": (2, 3, 4, 5), + "out_channels": None, + "groups": 2, + "inc_sec": (2, 2, 2, 2), + "k_r": 2, + "k_sec": (1, 1, 1, 1), + "num_classes": 1000, + "num_init_features": 2, + "small": True, + "test_time_pool": True, + } + return smp.encoders.dpn.DPNEncoder(**params) + class TestInceptionResNetV2Encoder(base.BaseEncoderTester): supports_dilated = False @@ -36,6 +53,22 @@ class TestSeNetEncoder(base.BaseEncoderTester): ] ) + def get_tiny_encoder(self): + params = { + "out_channels": None, + "block": smp.encoders.senet.SEResNetBottleneck, + "layers": [1, 1, 1, 1], + "downsample_kernel_size": 1, + "downsample_padding": 0, + "dropout_p": None, + "groups": 1, + "inplanes": 2, + "input_3x3": False, + "num_classes": 1000, + "reduction": 2, + } + return smp.encoders.senet.SENetEncoder(**params) + class TestXceptionEncoder(base.BaseEncoderTester): supports_dilated = False diff --git a/tests/encoders/test_smp_encoders.py b/tests/encoders/test_smp_encoders.py index cab7b53d..67a27126 100644 --- a/tests/encoders/test_smp_encoders.py +++ b/tests/encoders/test_smp_encoders.py @@ -1,3 +1,6 @@ +import segmentation_models_pytorch as smp +from functools import partial + from tests.encoders import base from tests.utils import RUN_ALL_ENCODERS @@ -23,6 +26,23 @@ class TestMixTransformerEncoder(base.BaseEncoderTester): else ["mit_b0", "mit_b1", "mit_b2", "mit_b3", "mit_b4", "mit_b5"] ) + def get_tiny_encoder(self): + params = { + "out_channels": (3, 0, 4, 4, 4, 4), + "patch_size": 4, + "embed_dims": [4, 4, 4, 4], + "num_heads": [1, 1, 1, 1], + "mlp_ratios": [1, 1, 1, 1], + "qkv_bias": True, + "norm_layer": partial(smp.encoders.mix_transformer.LayerNorm, eps=1e-6), + "depths": [1, 1, 1, 1], + "sr_ratios": [8, 4, 2, 1], + "drop_rate": 0.0, + "drop_path_rate": 0.1, + } + + return smp.encoders.mix_transformer.MixVisionTransformerEncoder(**params) + class TestEfficientNetEncoder(base.BaseEncoderTester): encoder_names = ( diff --git a/tests/encoders/test_timm_ported_encoders.py b/tests/encoders/test_timm_ported_encoders.py index b467c968..81e59f63 100644 --- a/tests/encoders/test_timm_ported_encoders.py +++ b/tests/encoders/test_timm_ported_encoders.py @@ -33,6 +33,9 @@ class TestTimmGERNetEncoder(base.BaseEncoderTester): else ["timm-gernet_s", "timm-gernet_m", "timm-gernet_l"] ) + def test_compile(self): + self.skipTest("Test to be removed") + class TestTimmMobileNetV3Encoder(base.BaseEncoderTester): encoder_names = ( @@ -48,6 +51,9 @@ class 
TestTimmMobileNetV3Encoder(base.BaseEncoderTester): ] ) + def test_compile(self): + self.skipTest("Test to be removed") + class TestTimmRegNetEncoder(base.BaseEncoderTester): encoder_names = ( @@ -81,6 +87,9 @@ class TestTimmRegNetEncoder(base.BaseEncoderTester): ] ) + def test_compile(self): + self.skipTest("Test to be removed") + class TestTimmRes2NetEncoder(base.BaseEncoderTester): supports_dilated = False @@ -98,6 +107,9 @@ class TestTimmRes2NetEncoder(base.BaseEncoderTester): ] ) + def test_compile(self): + self.skipTest("Test to be removed") + class TestTimmResnestEncoder(base.BaseEncoderTester): default_batch_size = 2 @@ -117,6 +129,9 @@ class TestTimmResnestEncoder(base.BaseEncoderTester): ] ) + def test_compile(self): + self.skipTest("Test to be removed") + class TestTimmSkNetEncoder(base.BaseEncoderTester): default_batch_size = 2 diff --git a/tests/encoders/test_timm_universal.py b/tests/encoders/test_timm_universal.py index 753ee4de..e6522618 100644 --- a/tests/encoders/test_timm_universal.py +++ b/tests/encoders/test_timm_universal.py @@ -9,7 +9,7 @@ ] if has_timm_test_models: - timm_encoders.append("tu-test_resnet.r160_in1k") + timm_encoders.insert(0, "tu-test_resnet.r160_in1k") class TestTimmUniversalEncoder(base.BaseEncoderTester): diff --git a/tests/encoders/test_torchvision_encoders.py b/tests/encoders/test_torchvision_encoders.py index 2ebaa86a..e2aa47b8 100644 --- a/tests/encoders/test_torchvision_encoders.py +++ b/tests/encoders/test_torchvision_encoders.py @@ -1,3 +1,5 @@ +import segmentation_models_pytorch as smp + from tests.encoders import base from tests.utils import RUN_ALL_ENCODERS @@ -21,6 +23,14 @@ class TestResNetEncoder(base.BaseEncoderTester): ] ) + def get_tiny_encoder(self): + params = { + "out_channels": None, + "block": smp.encoders.resnet.BasicBlock, + "layers": [1, 1, 1, 1], + } + return smp.encoders.resnet.ResNetEncoder(**params) + class TestDenseNetEncoder(base.BaseEncoderTester): supports_dilated = False @@ -30,6 +40,15 @@ class TestDenseNetEncoder(base.BaseEncoderTester): else ["densenet121", "densenet169", "densenet161"] ) + def get_tiny_encoder(self): + params = { + "out_channels": None, + "num_init_features": 2, + "growth_rate": 1, + "block_config": (1, 1, 1, 1), + } + return smp.encoders.densenet.DenseNetEncoder(**params) + class TestMobileNetEncoder(base.BaseEncoderTester): encoder_names = ["mobilenet_v2"] if not RUN_ALL_ENCODERS else ["mobilenet_v2"] @@ -51,3 +70,11 @@ class TestVggEncoder(base.BaseEncoderTester): "vgg19_bn", ] ) + + def get_tiny_encoder(self): + params = { + "out_channels": (4, 4, 4, 4, 4, 4), + "config": [4, "M", 4, "M", 4, "M", 4, "M", 4, "M"], + "batch_norm": False, + } + return smp.encoders.vgg.VGGEncoder(**params) From 4c1168222f109a80f1e41ee73dd36e0050c54ae8 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 12 Jan 2025 14:03:59 +0000 Subject: [PATCH 17/57] Add conftest --- tests/conftest.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..0562659c --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,15 @@ +def pytest_addoption(parser): + parser.addoption( + "--non-marked-only", action="store_true", help="Run only non-marked tests" + ) + +def pytest_collection_modifyitems(config, items): + if config.getoption("--non-marked-only"): + non_marked_items = [] + for item in items: + # Check if the test has no marks + if not item.own_markers: + non_marked_items.append(item) + + # Update the 
test collection to only include non-marked tests + items[:] = non_marked_items From 70168b463ac44e4f9aa24c83024e02fc73a5250a Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 12 Jan 2025 14:04:45 +0000 Subject: [PATCH 18/57] Fix features --- segmentation_models_pytorch/encoders/densenet.py | 2 +- segmentation_models_pytorch/encoders/dpn.py | 2 +- segmentation_models_pytorch/encoders/efficientnet.py | 2 +- .../encoders/inceptionresnetv2.py | 2 +- segmentation_models_pytorch/encoders/inceptionv4.py | 2 +- .../encoders/mix_transformer.py | 2 +- segmentation_models_pytorch/encoders/mobilenet.py | 2 +- segmentation_models_pytorch/encoders/mobileone.py | 2 +- segmentation_models_pytorch/encoders/resnet.py | 2 +- segmentation_models_pytorch/encoders/senet.py | 2 +- .../encoders/timm_efficientnet.py | 2 +- segmentation_models_pytorch/encoders/timm_sknet.py | 2 +- segmentation_models_pytorch/encoders/vgg.py | 11 ----------- segmentation_models_pytorch/encoders/xception.py | 2 +- 14 files changed, 13 insertions(+), 24 deletions(-) diff --git a/segmentation_models_pytorch/encoders/densenet.py b/segmentation_models_pytorch/encoders/densenet.py index 6ec1773f..4dd23f2f 100644 --- a/segmentation_models_pytorch/encoders/densenet.py +++ b/segmentation_models_pytorch/encoders/densenet.py @@ -56,7 +56,7 @@ def apply_transition( return x, intermediate def forward(self, x): - features = [] + features = [x] if self._depth >= 1: x = self.features.conv0(x) diff --git a/segmentation_models_pytorch/encoders/dpn.py b/segmentation_models_pytorch/encoders/dpn.py index 0ee8f04b..49ca6745 100644 --- a/segmentation_models_pytorch/encoders/dpn.py +++ b/segmentation_models_pytorch/encoders/dpn.py @@ -44,7 +44,7 @@ def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): del self.last_linear def forward(self, x: torch.Tensor) -> List[torch.Tensor]: - features = [] + features = [x] if self._depth >= 1: x = self.features[0].conv(x) diff --git a/segmentation_models_pytorch/encoders/efficientnet.py b/segmentation_models_pytorch/encoders/efficientnet.py index 2765af20..e1d051b0 100644 --- a/segmentation_models_pytorch/encoders/efficientnet.py +++ b/segmentation_models_pytorch/encoders/efficientnet.py @@ -56,7 +56,7 @@ def apply_blocks( return x def forward(self, x: torch.Tensor) -> List[torch.Tensor]: - features = [] + features = [x] if self._depth >= 1: x = self._conv_stem(x) diff --git a/segmentation_models_pytorch/encoders/inceptionresnetv2.py b/segmentation_models_pytorch/encoders/inceptionresnetv2.py index 2c6b8de3..4ef53404 100644 --- a/segmentation_models_pytorch/encoders/inceptionresnetv2.py +++ b/segmentation_models_pytorch/encoders/inceptionresnetv2.py @@ -56,7 +56,7 @@ def make_dilated(self, *args, **kwargs): ) def forward(self, x): - features = [] + features = [x] if self._depth >= 1: x = self.conv2d_1a(x) diff --git a/segmentation_models_pytorch/encoders/inceptionv4.py b/segmentation_models_pytorch/encoders/inceptionv4.py index 4731053a..6fce4306 100644 --- a/segmentation_models_pytorch/encoders/inceptionv4.py +++ b/segmentation_models_pytorch/encoders/inceptionv4.py @@ -65,7 +65,7 @@ def get_stages(self): ] def forward(self, x): - features = [] + features = [x] if self._depth >= 1: x = self.features[: self._stage_idxs[0]](x) diff --git a/segmentation_models_pytorch/encoders/mix_transformer.py b/segmentation_models_pytorch/encoders/mix_transformer.py index 1595ae0d..737a0b3f 100644 --- a/segmentation_models_pytorch/encoders/mix_transformer.py +++ 
b/segmentation_models_pytorch/encoders/mix_transformer.py @@ -533,7 +533,7 @@ def forward(self, x): [batch_size, 0, height // 2, width // 2], dtype=x.dtype, device=x.device ) - features = [dummy] + features = [x, dummy] if self._depth >= 2: x = self.patch_embed1(x) diff --git a/segmentation_models_pytorch/encoders/mobilenet.py b/segmentation_models_pytorch/encoders/mobilenet.py index f49175d1..80fad5ab 100644 --- a/segmentation_models_pytorch/encoders/mobilenet.py +++ b/segmentation_models_pytorch/encoders/mobilenet.py @@ -37,7 +37,7 @@ def __init__(self, out_channels, depth=5, **kwargs): del self.classifier def forward(self, x): - features = [] + features = [x] if self._depth >= 1: x = self.features[:2](x) diff --git a/segmentation_models_pytorch/encoders/mobileone.py b/segmentation_models_pytorch/encoders/mobileone.py index c6f7c391..881fa11b 100644 --- a/segmentation_models_pytorch/encoders/mobileone.py +++ b/segmentation_models_pytorch/encoders/mobileone.py @@ -409,7 +409,7 @@ def _make_stage( def forward(self, x: torch.Tensor) -> List[torch.Tensor]: """Apply forward pass.""" - features = [] + features = [x] if self._depth >= 1: x = self.stage0(x) diff --git a/segmentation_models_pytorch/encoders/resnet.py b/segmentation_models_pytorch/encoders/resnet.py index fac29c96..0d8f64a2 100644 --- a/segmentation_models_pytorch/encoders/resnet.py +++ b/segmentation_models_pytorch/encoders/resnet.py @@ -45,7 +45,7 @@ def __init__(self, out_channels, depth=5, **kwargs): del self.avgpool def forward(self, x: torch.Tensor) -> list[torch.Tensor]: - features = [] + features = [x] if self._depth >= 1: x = self.conv1(x) diff --git a/segmentation_models_pytorch/encoders/senet.py b/segmentation_models_pytorch/encoders/senet.py index 506a9717..4a4d9ef0 100644 --- a/segmentation_models_pytorch/encoders/senet.py +++ b/segmentation_models_pytorch/encoders/senet.py @@ -56,7 +56,7 @@ def get_stages(self): ] def forward(self, x): - features = [] + features = [x] if self._depth >= 1: x = self.layer0[:-1](x) diff --git a/segmentation_models_pytorch/encoders/timm_efficientnet.py b/segmentation_models_pytorch/encoders/timm_efficientnet.py index cf2811eb..33d4a3cb 100644 --- a/segmentation_models_pytorch/encoders/timm_efficientnet.py +++ b/segmentation_models_pytorch/encoders/timm_efficientnet.py @@ -106,7 +106,7 @@ def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): del self.classifier def forward(self, x): - features = [] + features = [x] if self._depth >= 1: x = self.conv_stem(x) diff --git a/segmentation_models_pytorch/encoders/timm_sknet.py b/segmentation_models_pytorch/encoders/timm_sknet.py index 50c760a1..fb5a6cf3 100644 --- a/segmentation_models_pytorch/encoders/timm_sknet.py +++ b/segmentation_models_pytorch/encoders/timm_sknet.py @@ -14,7 +14,7 @@ def __init__(self, out_channels, depth=5, **kwargs): del self.global_pool def forward(self, x): - features = [] + features = [x] if self._depth >= 1: x = self.conv1(x) diff --git a/segmentation_models_pytorch/encoders/vgg.py b/segmentation_models_pytorch/encoders/vgg.py index 7ff7843a..fb4502a5 100644 --- a/segmentation_models_pytorch/encoders/vgg.py +++ b/segmentation_models_pytorch/encoders/vgg.py @@ -53,17 +53,6 @@ def make_dilated(self, *args, **kwargs): " operations for downsampling!" 
) - def get_stages(self): - stages = [] - stage_modules = [] - for module in self.features: - if isinstance(module, nn.MaxPool2d): - stages.append(nn.Sequential(*stage_modules)) - stage_modules = [] - stage_modules.append(module) - stages.append(nn.Sequential(*stage_modules)) - return stages - def forward(self, x): features = [] depth = 0 diff --git a/segmentation_models_pytorch/encoders/xception.py b/segmentation_models_pytorch/encoders/xception.py index d0ee22f9..7ba1cdd6 100644 --- a/segmentation_models_pytorch/encoders/xception.py +++ b/segmentation_models_pytorch/encoders/xception.py @@ -24,7 +24,7 @@ def make_dilated(self, *args, **kwargs): ) def forward(self, x): - features = [] + features = [x] if self._depth >= 1: x = self.conv1(x) From 50c40d1ebd7a4fcaf27af2dd05fefd6f64c6001d Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 13 Jan 2025 18:07:59 +0000 Subject: [PATCH 19/57] Add triggering compile tests on diff --- requirements/test.txt | 3 +- tests/conftest.py | 3 +- tests/encoders/base.py | 8 +++- .../test_pretrainedmodels_encoders.py | 5 ++ tests/encoders/test_smp_encoders.py | 3 ++ tests/encoders/test_timm_ported_encoders.py | 2 + tests/encoders/test_timm_universal.py | 1 + tests/encoders/test_torchvision_encoders.py | 4 ++ tests/utils.py | 47 ++++++++++++------- 9 files changed, 56 insertions(+), 20 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index 755b851d..31a6be53 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,5 +1,6 @@ +gitpython==3.1.44 packaging==24.2 pytest==8.3.4 pytest-xdist==3.6.1 pytest-cov==6.0.0 -ruff==0.9.1 +ruff==0.9.1 \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 0562659c..688fd00b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ def pytest_addoption(parser): "--non-marked-only", action="store_true", help="Run only non-marked tests" ) + def pytest_collection_modifyitems(config, items): if config.getoption("--non-marked-only"): non_marked_items = [] @@ -10,6 +11,6 @@ def pytest_collection_modifyitems(config, items): # Check if the test has no marks if not item.own_markers: non_marked_items.append(item) - + # Update the test collection to only include non-marked tests items[:] = non_marked_items diff --git a/tests/encoders/base.py b/tests/encoders/base.py index 5e7c9d20..87ad2cfb 100644 --- a/tests/encoders/base.py +++ b/tests/encoders/base.py @@ -4,12 +4,15 @@ import segmentation_models_pytorch as smp from functools import lru_cache -from tests.utils import default_device +from tests.utils import default_device, check_run_test_on_diff_or_main class BaseEncoderTester(unittest.TestCase): encoder_names = [] + # some tests might be slow, running them only on diff + files_for_diff = [] + # standard encoder configuration num_output_features = 6 output_strides = [1, 2, 4, 8, 16, 32] @@ -213,6 +216,9 @@ def test_dilated(self): @pytest.mark.compile def test_compile(self): + if not check_run_test_on_diff_or_main(self.files_for_diff): + self.skipTest("No diff and not on `main`.") + sample = self._get_sample( batch_size=self.default_batch_size, num_channels=self.default_num_channels, diff --git a/tests/encoders/test_pretrainedmodels_encoders.py b/tests/encoders/test_pretrainedmodels_encoders.py index 5756b349..868f686d 100644 --- a/tests/encoders/test_pretrainedmodels_encoders.py +++ b/tests/encoders/test_pretrainedmodels_encoders.py @@ -10,6 +10,7 @@ class TestDPNEncoder(base.BaseEncoderTester): if not RUN_ALL_ENCODERS else ["dpn68", "dpn68b", "dpn92", 
"dpn98", "dpn107", "dpn131"] ) + files_for_diff = ["encoders/dpn.py"] def get_tiny_encoder(self): params = { @@ -32,11 +33,13 @@ class TestInceptionResNetV2Encoder(base.BaseEncoderTester): encoder_names = ( ["inceptionresnetv2"] if not RUN_ALL_ENCODERS else ["inceptionresnetv2"] ) + files_for_diff = ["encoders/inceptionresnetv2.py"] class TestInceptionV4Encoder(base.BaseEncoderTester): supports_dilated = False encoder_names = ["inceptionv4"] if not RUN_ALL_ENCODERS else ["inceptionv4"] + files_for_diff = ["encoders/inceptionv4.py"] class TestSeNetEncoder(base.BaseEncoderTester): @@ -52,6 +55,7 @@ class TestSeNetEncoder(base.BaseEncoderTester): # "senet154", # extra large model ] ) + files_for_diff = ["encoders/senet.py"] def get_tiny_encoder(self): params = { @@ -73,3 +77,4 @@ def get_tiny_encoder(self): class TestXceptionEncoder(base.BaseEncoderTester): supports_dilated = False encoder_names = ["xception"] if not RUN_ALL_ENCODERS else ["xception"] + files_for_diff = ["encoders/xception.py"] diff --git a/tests/encoders/test_smp_encoders.py b/tests/encoders/test_smp_encoders.py index 67a27126..876d9266 100644 --- a/tests/encoders/test_smp_encoders.py +++ b/tests/encoders/test_smp_encoders.py @@ -17,6 +17,7 @@ class TestMobileoneEncoder(base.BaseEncoderTester): "mobileone_s4", ] ) + files_for_diff = ["encoders/mobileone.py"] class TestMixTransformerEncoder(base.BaseEncoderTester): @@ -25,6 +26,7 @@ class TestMixTransformerEncoder(base.BaseEncoderTester): if not RUN_ALL_ENCODERS else ["mit_b0", "mit_b1", "mit_b2", "mit_b3", "mit_b4", "mit_b5"] ) + files_for_diff = ["encoders/mix_transformer.py"] def get_tiny_encoder(self): params = { @@ -59,6 +61,7 @@ class TestEfficientNetEncoder(base.BaseEncoderTester): # "efficientnet-b7", # extra large model ] ) + files_for_diff = ["encoders/efficientnet.py"] def test_compile(self): self.skipTest("compile fullgraph is not supported for efficientnet encoders") diff --git a/tests/encoders/test_timm_ported_encoders.py b/tests/encoders/test_timm_ported_encoders.py index 81e59f63..ff1f48ee 100644 --- a/tests/encoders/test_timm_ported_encoders.py +++ b/tests/encoders/test_timm_ported_encoders.py @@ -24,6 +24,7 @@ class TestTimmEfficientNetEncoder(base.BaseEncoderTester): "timm-tf_efficientnet_lite4", ] ) + files_for_diff = ["encoders/timm_efficientnet.py"] class TestTimmGERNetEncoder(base.BaseEncoderTester): @@ -144,3 +145,4 @@ class TestTimmSkNetEncoder(base.BaseEncoderTester): "timm-skresnext50_32x4d", ] ) + files_for_diff = ["encoders/timm_sknet.py"] diff --git a/tests/encoders/test_timm_universal.py b/tests/encoders/test_timm_universal.py index e6522618..99f8990f 100644 --- a/tests/encoders/test_timm_universal.py +++ b/tests/encoders/test_timm_universal.py @@ -14,3 +14,4 @@ class TestTimmUniversalEncoder(base.BaseEncoderTester): encoder_names = timm_encoders + files_for_diff = ["encoders/timm_universal.py"] diff --git a/tests/encoders/test_torchvision_encoders.py b/tests/encoders/test_torchvision_encoders.py index e2aa47b8..7c979689 100644 --- a/tests/encoders/test_torchvision_encoders.py +++ b/tests/encoders/test_torchvision_encoders.py @@ -22,6 +22,7 @@ class TestResNetEncoder(base.BaseEncoderTester): "resnext101_32x48d", ] ) + files_for_diff = ["encoders/resnet.py"] def get_tiny_encoder(self): params = { @@ -39,6 +40,7 @@ class TestDenseNetEncoder(base.BaseEncoderTester): if not RUN_ALL_ENCODERS else ["densenet121", "densenet169", "densenet161"] ) + files_for_diff = ["encoders/densenet.py"] def get_tiny_encoder(self): params = { @@ -52,6 +54,7 @@ def 
get_tiny_encoder(self): class TestMobileNetEncoder(base.BaseEncoderTester): encoder_names = ["mobilenet_v2"] if not RUN_ALL_ENCODERS else ["mobilenet_v2"] + files_for_diff = ["encoders/mobilenet.py"] class TestVggEncoder(base.BaseEncoderTester): @@ -70,6 +73,7 @@ class TestVggEncoder(base.BaseEncoderTester): "vgg19_bn", ] ) + files_for_diff = ["encoders/vgg.py"] def get_tiny_encoder(self): params = { diff --git a/tests/utils.py b/tests/utils.py index e8bce88e..6e201f1d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,31 +1,21 @@ import os +import re import timm import torch import unittest +from git import Repo +from typing import List from packaging.version import Version has_timm_test_models = Version(timm.__version__) >= Version("1.0.12") default_device = "cuda" if torch.cuda.is_available() else "cpu" - -def get_commit_message(): - commit_msg = os.getenv("COMMIT_MESSAGE", "") - return commit_msg.lower() - - -# Check both environment variables and commit message -commit_message = get_commit_message() -RUN_ALL_ENCODERS = ( - os.getenv("RUN_ALL_ENCODERS", "false").lower() in ["true", "1", "y", "yes"] - or "run-all-encoders" in commit_message -) - -RUN_SLOW = ( - os.getenv("RUN_SLOW", "false").lower() in ["true", "1", "y", "yes"] - or "run-slow" in commit_message -) +YES_LIST = ["true", "1", "y", "yes"] +RUN_ALL_ENCODERS = os.getenv("RUN_ALL_ENCODERS", "false").lower() in YES_LIST +RUN_SLOW = os.getenv("RUN_SLOW", "false").lower() in YES_LIST +RUN_ALL = os.getenv("RUN_ALL", "false").lower() in YES_LIST def slow_test(test_case): @@ -45,3 +35,26 @@ def requires_torch_greater_or_equal(version: str): torch_version >= provided_version, f"torch version {torch_version} is less than {provided_version}", ) + + +def check_run_test_on_diff_or_main(filepath_patterns: List[str]): + if RUN_ALL: + return True + + try: + repo = Repo(".") + current_branch = repo.active_branch.name + diff_files = repo.git.diff("main", name_only=True).splitlines() + + except Exception: + return True + + if current_branch == "main": + return True + + for pattern in filepath_patterns: + for file_path in diff_files: + if re.search(pattern, file_path): + return True + + return False From 0764d5eb82afd417885ef4cf5b25c737012367ae Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 13 Jan 2025 18:16:16 +0000 Subject: [PATCH 20/57] Remove marks --- pyproject.toml | 11 ----------- tests/models/test_deeplab.py | 3 --- tests/models/test_fpn.py | 2 -- tests/models/test_linknet.py | 2 -- tests/models/test_manet.py | 2 -- tests/models/test_pan.py | 2 -- tests/models/test_psp.py | 2 -- tests/models/test_segformer.py | 1 - tests/models/test_unet.py | 2 -- tests/models/test_unetplusplus.py | 2 -- tests/models/test_upernet.py | 2 -- 11 files changed, 31 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 77d1b15a..aab3379e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,17 +61,6 @@ include = ['segmentation_models_pytorch*'] [tool.pytest.ini_options] markers = [ - "deeplabv3", - "deeplabv3plus", - "fpn", - "linknet", - "manet", - "pan", - "psp", - "segformer", - "unet", - "unetplusplus", - "upernet", "logits_match", "compile", ] diff --git a/tests/models/test_deeplab.py b/tests/models/test_deeplab.py index d3d350e9..0b0e63d9 100644 --- a/tests/models/test_deeplab.py +++ b/tests/models/test_deeplab.py @@ -1,15 +1,12 @@ -import pytest from tests.models import base -@pytest.mark.deeplabv3 class TestDeeplabV3Model(base.BaseModelTester): test_model_type = "deeplabv3" default_batch_size = 2 
-@pytest.mark.deeplabv3plus class TestDeeplabV3PlusModel(base.BaseModelTester): test_model_type = "deeplabv3plus" diff --git a/tests/models/test_fpn.py b/tests/models/test_fpn.py index 15ae1f6a..cd0585fd 100644 --- a/tests/models/test_fpn.py +++ b/tests/models/test_fpn.py @@ -1,7 +1,5 @@ -import pytest from tests.models import base -@pytest.mark.fpn class TestFpnModel(base.BaseModelTester): test_model_type = "fpn" diff --git a/tests/models/test_linknet.py b/tests/models/test_linknet.py index 1ab5eb4e..fc76e0c0 100644 --- a/tests/models/test_linknet.py +++ b/tests/models/test_linknet.py @@ -1,7 +1,5 @@ -import pytest from tests.models import base -@pytest.mark.linknet class TestLinknetModel(base.BaseModelTester): test_model_type = "linknet" diff --git a/tests/models/test_manet.py b/tests/models/test_manet.py index 33a8ae3b..d8950df8 100644 --- a/tests/models/test_manet.py +++ b/tests/models/test_manet.py @@ -1,7 +1,5 @@ -import pytest from tests.models import base -@pytest.mark.manet class TestManetModel(base.BaseModelTester): test_model_type = "manet" diff --git a/tests/models/test_pan.py b/tests/models/test_pan.py index d66fefe0..634f1351 100644 --- a/tests/models/test_pan.py +++ b/tests/models/test_pan.py @@ -1,8 +1,6 @@ -import pytest from tests.models import base -@pytest.mark.pan class TestPanModel(base.BaseModelTester): test_model_type = "pan" diff --git a/tests/models/test_psp.py b/tests/models/test_psp.py index 2603cdda..e08e6cf9 100644 --- a/tests/models/test_psp.py +++ b/tests/models/test_psp.py @@ -1,8 +1,6 @@ -import pytest from tests.models import base -@pytest.mark.psp class TestPspModel(base.BaseModelTester): test_model_type = "pspnet" diff --git a/tests/models/test_segformer.py b/tests/models/test_segformer.py index 5195f050..4ad010c9 100644 --- a/tests/models/test_segformer.py +++ b/tests/models/test_segformer.py @@ -6,7 +6,6 @@ from tests.utils import slow_test, default_device, requires_torch_greater_or_equal -@pytest.mark.segformer class TestSegformerModel(base.BaseModelTester): test_model_type = "segformer" diff --git a/tests/models/test_unet.py b/tests/models/test_unet.py index 54c69bf0..c51c32bf 100644 --- a/tests/models/test_unet.py +++ b/tests/models/test_unet.py @@ -1,7 +1,5 @@ -import pytest from tests.models import base -@pytest.mark.unet class TestUnetModel(base.BaseModelTester): test_model_type = "unet" diff --git a/tests/models/test_unetplusplus.py b/tests/models/test_unetplusplus.py index 9e67f2ed..2b22e0c2 100644 --- a/tests/models/test_unetplusplus.py +++ b/tests/models/test_unetplusplus.py @@ -1,7 +1,5 @@ -import pytest from tests.models import base -@pytest.mark.unetplusplus class TestUnetPlusPlusModel(base.BaseModelTester): test_model_type = "unetplusplus" diff --git a/tests/models/test_upernet.py b/tests/models/test_upernet.py index 71d703f9..56fa8dff 100644 --- a/tests/models/test_upernet.py +++ b/tests/models/test_upernet.py @@ -1,8 +1,6 @@ -import pytest from tests.models import base -@pytest.mark.upernet class TestUnetModel(base.BaseModelTester): test_model_type = "upernet" default_batch_size = 2 From 7cab4be6caa32b9474a93cc16a9a22e2681b0a78 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 13 Jan 2025 18:19:41 +0000 Subject: [PATCH 21/57] Add test_compile stage to CI --- .github/workflows/tests.yml | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index bf04ab0a..83b02392 100644 --- a/.github/workflows/tests.yml +++ 
b/.github/workflows/tests.yml @@ -51,7 +51,7 @@ jobs: run: uv pip list - name: Test with PyTest - run: uv run pytest -v -rsx -n 2 --cov=segmentation_models_pytorch --cov-report=xml --cov-config=pyproject.toml -k "not logits_match" + run: uv run pytest -v -rsx -n 2 --cov=segmentation_models_pytorch --cov-report=xml --cov-config=pyproject.toml --non-marked-only - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v5 @@ -73,7 +73,22 @@ jobs: - name: Show installed packages run: uv pip list - name: Test with PyTest - run: RUN_SLOW=1 uv run pytest -v -rsx -n 2 -k "logits_match" + run: RUN_SLOW=1 uv run pytest -v -rsx -n 2 -m "logits_match" + + test_torch_compile: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: astral-sh/setup-uv@v5 + with: + python-version: "3.10" + - name: Install dependencies + run: uv pip install -r requirements/required.txt -r requirements/test.txt + - name: Show installed packages + run: uv pip list + - name: Test with PyTest + run: uv run pytest -v -rsx -n 2 -m "compile" minimum: runs-on: ubuntu-latest @@ -88,4 +103,4 @@ jobs: - name: Show installed packages run: uv pip list - name: Test with pytest - run: uv run pytest -v -rsx -n 2 -k "not logits_match" + run: uv run pytest -v -rsx -n 2 --non-marked-only From 2622e0efd6337eb85f4bf641245a030cb86a7e91 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 13 Jan 2025 18:33:13 +0000 Subject: [PATCH 22/57] Update requirements --- pyproject.toml | 2 ++ requirements/test.txt | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index aab3379e..4e34be59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,11 +39,13 @@ docs = [ 'sphinx-book-theme', ] test = [ + 'gitpython', 'packaging', 'pytest', 'pytest-cov', 'pytest-xdist', 'ruff>=0.9', + 'setuptools', ] [project.urls] diff --git a/requirements/test.txt b/requirements/test.txt index 31a6be53..e22c65e8 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -3,4 +3,5 @@ packaging==24.2 pytest==8.3.4 pytest-xdist==3.6.1 pytest-cov==6.0.0 -ruff==0.9.1 \ No newline at end of file +ruff==0.9.1 +setuptools==75.8.0 \ No newline at end of file From e12ee8db30f4bc54ba2861113052a07549c03c15 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 13 Jan 2025 19:23:44 +0000 Subject: [PATCH 23/57] Update makefile --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index a58d230f..478abd6c 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ install_dev: .venv .venv/bin/pip install -e ".[test]" test: .venv - .venv/bin/pytest -v -rsx -n 2 tests/ -k "not logits_match" + .venv/bin/pytest -v -rsx -n 2 tests/ --non-marked-only test_all: .venv RUN_SLOW=1 .venv/bin/pytest -v -rsx -n 2 tests/ From da0cd19a4c0af47b5e370520c0285dd1cb2178bf Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 13 Jan 2025 19:24:12 +0000 Subject: [PATCH 24/57] Update get_stages --- segmentation_models_pytorch/encoders/_base.py | 36 ++++++++---------- segmentation_models_pytorch/encoders/dpn.py | 6 +++ .../encoders/efficientnet.py | 6 +++ .../encoders/mix_transformer.py | 6 +++ .../encoders/mobilenet.py | 6 +++ .../encoders/mobileone.py | 6 +++ .../encoders/resnet.py | 6 +++ segmentation_models_pytorch/encoders/senet.py | 14 ++----- .../encoders/timm_efficientnet.py | 6 +++ .../encoders/timm_sknet.py | 6 +++ segmentation_models_pytorch/encoders/vgg.py | 37 +++++++++---------- tests/encoders/test_timm_ported_encoders.py | 2 - 12 files 
changed, 85 insertions(+), 52 deletions(-) diff --git a/segmentation_models_pytorch/encoders/_base.py b/segmentation_models_pytorch/encoders/_base.py index 3b877075..20b6aa4c 100644 --- a/segmentation_models_pytorch/encoders/_base.py +++ b/segmentation_models_pytorch/encoders/_base.py @@ -1,3 +1,6 @@ +import torch +from typing import Sequence + from . import _utils as utils @@ -31,28 +34,21 @@ def set_in_channels(self, in_channels, pretrained=True): model=self, new_in_channels=in_channels, pretrained=pretrained ) - def get_stages(self): - """Override it in your implementation""" + def get_stages(self) -> dict[int, Sequence[torch.nn.Module]]: + """Override it in your implementation, should return a dictionary with keys as + the output stride and values as the list of modules + """ raise NotImplementedError def make_dilated(self, output_stride): - if output_stride == 16: - stage_list = [5] - dilation_list = [2] - - elif output_stride == 8: - stage_list = [4, 5] - dilation_list = [2, 4] - - else: - raise ValueError( - "Output stride should be 16 or 8, got {}.".format(output_stride) - ) - - self._output_stride = output_stride + if output_stride not in [8, 16]: + raise ValueError(f"Output stride should be 16 or 8, got {output_stride}.") stages = self.get_stages() - for stage_indx, dilation_rate in zip(stage_list, dilation_list): - utils.replace_strides_with_dilation( - module=stages[stage_indx], dilation_rate=dilation_rate - ) + for stage_stride, stage_modules in stages.items(): + if stage_stride <= output_stride: + continue + + dilation_rate = stage_stride // output_stride + for module in stage_modules: + utils.replace_strides_with_dilation(module, dilation_rate) diff --git a/segmentation_models_pytorch/encoders/dpn.py b/segmentation_models_pytorch/encoders/dpn.py index 49ca6745..51c153a7 100644 --- a/segmentation_models_pytorch/encoders/dpn.py +++ b/segmentation_models_pytorch/encoders/dpn.py @@ -43,6 +43,12 @@ def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): del self.last_linear + def get_stages(self): + return { + 16: self.features[self._stage_idxs[1] : self._stage_idxs[2]], + 32: self.features[self._stage_idxs[2] : self._stage_idxs[3]], + } + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: features = [x] diff --git a/segmentation_models_pytorch/encoders/efficientnet.py b/segmentation_models_pytorch/encoders/efficientnet.py index e1d051b0..6c4a4b5f 100644 --- a/segmentation_models_pytorch/encoders/efficientnet.py +++ b/segmentation_models_pytorch/encoders/efficientnet.py @@ -44,6 +44,12 @@ def __init__(self, stage_idxs, out_channels, model_name, depth=5): del self._fc + def get_stages(self): + return { + 16: self._blocks[self._stage_idxs[1] : self._stage_idxs[2]], + 32: self._blocks[self._stage_idxs[2] :], + } + def apply_blocks( self, x: torch.Tensor, start_idx: int, end_idx: int ) -> torch.Tensor: diff --git a/segmentation_models_pytorch/encoders/mix_transformer.py b/segmentation_models_pytorch/encoders/mix_transformer.py index 737a0b3f..54bd747e 100644 --- a/segmentation_models_pytorch/encoders/mix_transformer.py +++ b/segmentation_models_pytorch/encoders/mix_transformer.py @@ -526,6 +526,12 @@ def __init__(self, out_channels, depth=5, **kwargs): self._depth = depth self._in_channels = 3 + def get_stages(self): + return { + 16: [self.patch_embed3, self.block3, self.norm3], + 32: [self.patch_embed4, self.block4, self.norm4], + } + def forward(self, x): # create dummy output for the first block batch_size, _, height, width = x.shape diff --git 
a/segmentation_models_pytorch/encoders/mobilenet.py b/segmentation_models_pytorch/encoders/mobilenet.py index 80fad5ab..52c87160 100644 --- a/segmentation_models_pytorch/encoders/mobilenet.py +++ b/segmentation_models_pytorch/encoders/mobilenet.py @@ -36,6 +36,12 @@ def __init__(self, out_channels, depth=5, **kwargs): self._in_channels = 3 del self.classifier + def get_stages(self): + return { + 16: self.features[7:14], + 32: self.features[14:], + } + def forward(self, x): features = [x] diff --git a/segmentation_models_pytorch/encoders/mobileone.py b/segmentation_models_pytorch/encoders/mobileone.py index 881fa11b..3605dcef 100644 --- a/segmentation_models_pytorch/encoders/mobileone.py +++ b/segmentation_models_pytorch/encoders/mobileone.py @@ -355,6 +355,12 @@ def __init__( num_se_blocks=num_blocks_per_stage[3] if use_se else 0, ) + def get_stages(self): + return { + 16: self.stage3, + 32: self.stage4, + } + def _make_stage( self, planes: int, num_blocks: int, num_se_blocks: int ) -> nn.Sequential: diff --git a/segmentation_models_pytorch/encoders/resnet.py b/segmentation_models_pytorch/encoders/resnet.py index 0d8f64a2..bfa37abd 100644 --- a/segmentation_models_pytorch/encoders/resnet.py +++ b/segmentation_models_pytorch/encoders/resnet.py @@ -44,6 +44,12 @@ def __init__(self, out_channels, depth=5, **kwargs): del self.fc del self.avgpool + def get_stages(self): + return { + 16: self.layer3, + 32: self.layer4, + } + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: features = [x] diff --git a/segmentation_models_pytorch/encoders/senet.py b/segmentation_models_pytorch/encoders/senet.py index 4a4d9ef0..123bf30d 100644 --- a/segmentation_models_pytorch/encoders/senet.py +++ b/segmentation_models_pytorch/encoders/senet.py @@ -23,8 +23,6 @@ depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
""" -import torch.nn as nn - from pretrainedmodels.models.senet import ( SENet, SEBottleneck, @@ -46,14 +44,10 @@ def __init__(self, out_channels, depth=5, **kwargs): del self.avg_pool def get_stages(self): - return [ - nn.Identity(), - self.layer0[:-1], - nn.Sequential(self.layer0[-1], self.layer1), - self.layer2, - self.layer3, - self.layer4, - ] + return { + 16: self.layer3, + 32: self.layer4, + } def forward(self, x): features = [x] diff --git a/segmentation_models_pytorch/encoders/timm_efficientnet.py b/segmentation_models_pytorch/encoders/timm_efficientnet.py index 33d4a3cb..3ab0c069 100644 --- a/segmentation_models_pytorch/encoders/timm_efficientnet.py +++ b/segmentation_models_pytorch/encoders/timm_efficientnet.py @@ -105,6 +105,12 @@ def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): del self.classifier + def get_stages(self): + return { + 16: self.blocks[self._stage_idxs[1] : self._stage_idxs[2]], + 32: self.blocks[self._stage_idxs[2] :], + } + def forward(self, x): features = [x] diff --git a/segmentation_models_pytorch/encoders/timm_sknet.py b/segmentation_models_pytorch/encoders/timm_sknet.py index fb5a6cf3..9b340d6e 100644 --- a/segmentation_models_pytorch/encoders/timm_sknet.py +++ b/segmentation_models_pytorch/encoders/timm_sknet.py @@ -13,6 +13,12 @@ def __init__(self, out_channels, depth=5, **kwargs): del self.fc del self.global_pool + def get_stages(self): + return { + 16: self.layer3, + 32: self.layer4, + } + def forward(self, x): features = [x] diff --git a/segmentation_models_pytorch/encoders/vgg.py b/segmentation_models_pytorch/encoders/vgg.py index fb4502a5..34eb1091 100644 --- a/segmentation_models_pytorch/encoders/vgg.py +++ b/segmentation_models_pytorch/encoders/vgg.py @@ -31,10 +31,10 @@ # fmt: off cfg = { - 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], - 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], - 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], - 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], + "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], + "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], + "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"], + "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"], } # fmt: on @@ -45,6 +45,7 @@ def __init__(self, out_channels, config, batch_norm=False, depth=5, **kwargs): self._out_channels = out_channels self._depth = depth self._in_channels = 3 + del self.classifier def make_dilated(self, *args, **kwargs): @@ -54,25 +55,21 @@ def make_dilated(self, *args, **kwargs): ) def forward(self, x): - features = [] - depth = 0 - - for i, module in enumerate(self.features): + # collect stages + stages = [] + stage_modules = [] + for module in self.features: if isinstance(module, nn.MaxPool2d): - features.append(x) - depth += 1 + stages.append(stage_modules) + stage_modules = [] + stage_modules.append(module) + stages.append(stage_modules) - # last layer is always maxpool, we just apply it and break - if i == len(self.features) - 1: + features = [] + for i in range(self._depth + 1): + for module in stages[i]: x = module(x) - features.append(x) - break - - # if depth is reached, break - if depth > self._depth: - break - - x = module(x) + features.append(x) return features diff --git 
a/tests/encoders/test_timm_ported_encoders.py b/tests/encoders/test_timm_ported_encoders.py index ff1f48ee..3793606e 100644 --- a/tests/encoders/test_timm_ported_encoders.py +++ b/tests/encoders/test_timm_ported_encoders.py @@ -93,7 +93,6 @@ def test_compile(self): class TestTimmRes2NetEncoder(base.BaseEncoderTester): - supports_dilated = False encoder_names = ( ["timm-res2net50_26w_4s"] if not RUN_ALL_ENCODERS @@ -114,7 +113,6 @@ def test_compile(self): class TestTimmResnestEncoder(base.BaseEncoderTester): default_batch_size = 2 - supports_dilated = False encoder_names = ( ["timm-resnest14d"] if not RUN_ALL_ENCODERS From 7752969736b994722df03a32d88630cb2147fdd5 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 13 Jan 2025 20:43:02 +0000 Subject: [PATCH 25/57] Fix weight loading for deprecate encoders --- segmentation_models_pytorch/base/model.py | 16 -------------- .../encoders/timm_universal.py | 22 +++++++++++++++++++ 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py index 5f559db4..a25ed30a 100644 --- a/segmentation_models_pytorch/base/model.py +++ b/segmentation_models_pytorch/base/model.py @@ -3,7 +3,6 @@ from . import initialization as init from .hub_mixin import SMPHubMixin -from ..encoders.timm_universal import TimmUniversalEncoder T = TypeVar("T", bound="SegmentationModel") @@ -82,18 +81,3 @@ def predict(self, x): x = self.forward(x) return x - - def load_state_dict(self, state_dict, **kwargs): - # for compatibility of weights for - # timm- ported encoders with TimmUniversalEncoder - if isinstance(self.encoder, TimmUniversalEncoder): - keys = list(state_dict.keys()) - for key in keys: - new_key = key - if key.startswith("encoder.") and not key.startswith("encoder.model."): - new_key = key.replace("encoder.", "encoder.model.") - if "gernet" in self.encoder.name: - new_key = new_key.replace(".stages.", ".stages_") - state_dict[new_key] = state_dict.pop(key) - - return super().load_state_dict(state_dict, **kwargs) diff --git a/segmentation_models_pytorch/encoders/timm_universal.py b/segmentation_models_pytorch/encoders/timm_universal.py index 029aef8b..15c64345 100644 --- a/segmentation_models_pytorch/encoders/timm_universal.py +++ b/segmentation_models_pytorch/encoders/timm_universal.py @@ -196,6 +196,28 @@ def output_stride(self) -> int: """ return min(self._output_stride, 2**self._depth) + def load_state_dict(self, state_dict, **kwargs): + # for compatibility of weights for + # timm- ported encoders with TimmUniversalEncoder + + patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"] + + is_deprecated_encoder = any( + self.name.startswith(pattern) for pattern in patterns + ) + + if is_deprecated_encoder: + keys = list(state_dict.keys()) + for key in keys: + new_key = key + if not key.startswith("model."): + new_key = "model." 
+ key + if "gernet" in self.name: + new_key = new_key.replace(".stages.", ".stages_") + state_dict[new_key] = state_dict.pop(key) + + return super().load_state_dict(state_dict, **kwargs) + def _merge_kwargs_no_duplicates(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]: """ From 409b82084de64e9142d6ffbc0ce6590036016b37 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 13 Jan 2025 20:53:50 +0000 Subject: [PATCH 26/57] Fix weight loading for mobilenetv3 --- segmentation_models_pytorch/encoders/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/segmentation_models_pytorch/encoders/__init__.py b/segmentation_models_pytorch/encoders/__init__.py index 2dbbf020..bc4fafce 100644 --- a/segmentation_models_pytorch/encoders/__init__.py +++ b/segmentation_models_pytorch/encoders/__init__.py @@ -72,7 +72,7 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, ** # convert timm- models to tu- models if is_equivalent_to_timm_universal(name): name = name.replace("timm-", "tu-") - if "minimal" in name: + if "mobilenetv3" in name: name = name.replace("tu-", "tu-tf_") if name.startswith("tu-"): From ae3cb8a2af5f2359d1fa73998921ca6448de1e14 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 13 Jan 2025 20:54:00 +0000 Subject: [PATCH 27/57] Format --- segmentation_models_pytorch/encoders/timm_universal.py | 1 - 1 file changed, 1 deletion(-) diff --git a/segmentation_models_pytorch/encoders/timm_universal.py b/segmentation_models_pytorch/encoders/timm_universal.py index 15c64345..5a48273a 100644 --- a/segmentation_models_pytorch/encoders/timm_universal.py +++ b/segmentation_models_pytorch/encoders/timm_universal.py @@ -199,7 +199,6 @@ def output_stride(self) -> int: def load_state_dict(self, state_dict, **kwargs): # for compatibility of weights for # timm- ported encoders with TimmUniversalEncoder - patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"] is_deprecated_encoder = any( From ff278c9e4dc5d20169f1834d57c3d26f3c948bb2 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 13 Jan 2025 21:19:22 +0000 Subject: [PATCH 28/57] Add compile test for models --- tests/models/base.py | 20 ++++++++++++++++++++ tests/models/test_deeplab.py | 2 ++ tests/models/test_fpn.py | 1 + tests/models/test_linknet.py | 1 + tests/models/test_manet.py | 1 + tests/models/test_pan.py | 1 + tests/models/test_psp.py | 1 + tests/models/test_segformer.py | 1 + tests/models/test_unet.py | 1 + tests/models/test_unetplusplus.py | 1 + tests/models/test_upernet.py | 2 ++ 11 files changed, 32 insertions(+) diff --git a/tests/models/base.py b/tests/models/base.py index 93aac7b0..ba246436 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -14,6 +14,7 @@ default_device, slow_test, requires_torch_greater_or_equal, + check_run_test_on_diff_or_main, ) @@ -21,6 +22,7 @@ class BaseModelTester(unittest.TestCase): test_encoder_name = ( "tu-test_resnet.r160_in1k" if has_timm_test_models else "resnet18" ) + files_for_diff = [r".*"] # should be overriden test_model_type = None @@ -234,3 +236,21 @@ def test_preserve_forward_output(self): is_close = torch.allclose(output, output_tensor, atol=5e-2) max_diff = torch.max(torch.abs(output - output_tensor)) self.assertTrue(is_close, f"Max diff: {max_diff}") + + @pytest.mark.compile + def test_compile(self): + if not check_run_test_on_diff_or_main(self.files_for_diff): + self.skipTest("No diff and not on `main`.") + + sample = self._get_sample( + batch_size=self.default_batch_size, + 
num_channels=self.default_num_channels, + height=self.default_height, + width=self.default_width, + ).to(default_device) + + model = self.get_default_model() + compiled_model = torch.compile(model, fullgraph=True, dynamic=True) + + with torch.inference_mode(): + compiled_model(sample) diff --git a/tests/models/test_deeplab.py b/tests/models/test_deeplab.py index 0b0e63d9..de112633 100644 --- a/tests/models/test_deeplab.py +++ b/tests/models/test_deeplab.py @@ -3,11 +3,13 @@ class TestDeeplabV3Model(base.BaseModelTester): test_model_type = "deeplabv3" + files_for_diff = [r"decoders/deeplabv3/", r"base/"] default_batch_size = 2 class TestDeeplabV3PlusModel(base.BaseModelTester): test_model_type = "deeplabv3plus" + files_for_diff = [r"decoders/deeplabv3plus/", r"base/"] default_batch_size = 2 diff --git a/tests/models/test_fpn.py b/tests/models/test_fpn.py index cd0585fd..28e7426b 100644 --- a/tests/models/test_fpn.py +++ b/tests/models/test_fpn.py @@ -3,3 +3,4 @@ class TestFpnModel(base.BaseModelTester): test_model_type = "fpn" + files_for_diff = [r"decoders/fpn/", r"base/"] diff --git a/tests/models/test_linknet.py b/tests/models/test_linknet.py index fc76e0c0..6f9490d9 100644 --- a/tests/models/test_linknet.py +++ b/tests/models/test_linknet.py @@ -3,3 +3,4 @@ class TestLinknetModel(base.BaseModelTester): test_model_type = "linknet" + files_for_diff = [r"decoders/linknet/", r"base/"] diff --git a/tests/models/test_manet.py b/tests/models/test_manet.py index d8950df8..459fe794 100644 --- a/tests/models/test_manet.py +++ b/tests/models/test_manet.py @@ -3,3 +3,4 @@ class TestManetModel(base.BaseModelTester): test_model_type = "manet" + files_for_diff = [r"decoders/manet/", r"base/"] diff --git a/tests/models/test_pan.py b/tests/models/test_pan.py index 634f1351..f2779eaf 100644 --- a/tests/models/test_pan.py +++ b/tests/models/test_pan.py @@ -3,6 +3,7 @@ class TestPanModel(base.BaseModelTester): test_model_type = "pan" + files_for_diff = [r"decoders/pan/", r"base/"] default_batch_size = 2 default_height = 128 diff --git a/tests/models/test_psp.py b/tests/models/test_psp.py index e08e6cf9..c29b5e99 100644 --- a/tests/models/test_psp.py +++ b/tests/models/test_psp.py @@ -3,5 +3,6 @@ class TestPspModel(base.BaseModelTester): test_model_type = "pspnet" + files_for_diff = [r"decoders/pspnet/", r"base/"] default_batch_size = 2 diff --git a/tests/models/test_segformer.py b/tests/models/test_segformer.py index 4ad010c9..b0f288ef 100644 --- a/tests/models/test_segformer.py +++ b/tests/models/test_segformer.py @@ -8,6 +8,7 @@ class TestSegformerModel(base.BaseModelTester): test_model_type = "segformer" + files_for_diff = [r"decoders/segformer/", r"base/"] @slow_test @requires_torch_greater_or_equal("2.0.1") diff --git a/tests/models/test_unet.py b/tests/models/test_unet.py index c51c32bf..4c2d7e4d 100644 --- a/tests/models/test_unet.py +++ b/tests/models/test_unet.py @@ -3,3 +3,4 @@ class TestUnetModel(base.BaseModelTester): test_model_type = "unet" + files_for_diff = [r"decoders/unet/", r"base/"] diff --git a/tests/models/test_unetplusplus.py b/tests/models/test_unetplusplus.py index 2b22e0c2..e2901483 100644 --- a/tests/models/test_unetplusplus.py +++ b/tests/models/test_unetplusplus.py @@ -3,3 +3,4 @@ class TestUnetPlusPlusModel(base.BaseModelTester): test_model_type = "unetplusplus" + files_for_diff = [r"decoders/unetplusplus/", r"base/"] diff --git a/tests/models/test_upernet.py b/tests/models/test_upernet.py index 56fa8dff..1c23406b 100644 --- a/tests/models/test_upernet.py +++ 
b/tests/models/test_upernet.py @@ -3,4 +3,6 @@ class TestUnetModel(base.BaseModelTester): test_model_type = "upernet" + files_for_diff = [r"decoders/upernet/", r"base/"] + default_batch_size = 2 From a806147676f9b4a523d2a75ed72a9065b6265542 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 13 Jan 2025 22:13:38 +0000 Subject: [PATCH 29/57] Add torch.export test --- .github/workflows/tests.yml | 15 ++++++++++++ pyproject.toml | 1 + segmentation_models_pytorch/base/model.py | 7 +++++- segmentation_models_pytorch/base/utils.py | 13 +++++++++++ tests/encoders/base.py | 28 +++++++++++++++++++++++ tests/models/base.py | 28 +++++++++++++++++++++++ 6 files changed, 91 insertions(+), 1 deletion(-) create mode 100644 segmentation_models_pytorch/base/utils.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 83b02392..3ebe6143 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -90,6 +90,21 @@ jobs: - name: Test with PyTest run: uv run pytest -v -rsx -n 2 -m "compile" + test_torch_export: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: astral-sh/setup-uv@v5 + with: + python-version: "3.10" + - name: Install dependencies + run: uv pip install -r requirements/required.txt -r requirements/test.txt + - name: Show installed packages + run: uv pip list + - name: Test with PyTest + run: uv run pytest -v -rsx -n 2 -m "torch_export" + minimum: runs-on: ubuntu-latest steps: diff --git a/pyproject.toml b/pyproject.toml index 4e34be59..8d9b2078 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ include = ['segmentation_models_pytorch*'] markers = [ "logits_match", "compile", + "torch_export", ] [tool.coverage.run] diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py index a25ed30a..e04c2d6e 100644 --- a/segmentation_models_pytorch/base/model.py +++ b/segmentation_models_pytorch/base/model.py @@ -3,6 +3,7 @@ from . 
import initialization as init from .hub_mixin import SMPHubMixin +from .utils import is_torch_compiling T = TypeVar("T", bound="SegmentationModel") @@ -50,7 +51,11 @@ def check_input_shape(self, x): def forward(self, x): """Sequentially pass `x` trough model`s encoder, decoder and heads""" - if not torch.jit.is_tracing() and self.requires_divisible_input_shape: + if ( + not torch.jit.is_tracing() + and not is_torch_compiling() + and self.requires_divisible_input_shape + ): self.check_input_shape(x) features = self.encoder(x) diff --git a/segmentation_models_pytorch/base/utils.py b/segmentation_models_pytorch/base/utils.py new file mode 100644 index 00000000..3fcba739 --- /dev/null +++ b/segmentation_models_pytorch/base/utils.py @@ -0,0 +1,13 @@ +import torch + + +def is_torch_compiling(): + try: + return torch.compiler.is_compiling() + except Exception: + try: + import torch._dynamo as dynamo # noqa: F401 + + return dynamo.is_compiling() + except Exception: + return False diff --git a/tests/encoders/base.py b/tests/encoders/base.py index 87ad2cfb..c1858cdb 100644 --- a/tests/encoders/base.py +++ b/tests/encoders/base.py @@ -231,3 +231,31 @@ def test_compile(self): with torch.inference_mode(): compiled_encoder(sample) + + @pytest.mark.torch_export + def test_torch_export(self): + if not check_run_test_on_diff_or_main(self.files_for_diff): + self.skipTest("No diff and not on `main`.") + + sample = self._get_sample( + batch_size=self.default_batch_size, + num_channels=self.default_num_channels, + height=self.default_height, + width=self.default_width, + ).to(default_device) + + encoder = self.get_tiny_encoder() + encoder = encoder.eval().to(default_device) + + exported_encoder = torch.export.export( + encoder, + args=(sample,), + strict=True, + ) + + with torch.inference_mode(): + eager_output = encoder(sample) + exported_output = exported_encoder.module().forward(sample) + + for eager_feature, exported_feature in zip(eager_output, exported_output): + torch.testing.assert_close(eager_feature, exported_feature) diff --git a/tests/models/base.py b/tests/models/base.py index ba246436..d6e19fd0 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -254,3 +254,31 @@ def test_compile(self): with torch.inference_mode(): compiled_model(sample) + + @pytest.mark.torch_export + def test_torch_export(self): + if not check_run_test_on_diff_or_main(self.files_for_diff): + self.skipTest("No diff and not on `main`.") + + sample = self._get_sample( + batch_size=self.default_batch_size, + num_channels=self.default_num_channels, + height=self.default_height, + width=self.default_width, + ).to(default_device) + + model = self.get_default_model() + model.eval() + + exported_model = torch.export.export( + model, + args=(sample,), + strict=True, + ) + + with torch.inference_mode(): + eager_output = model(sample) + exported_output = exported_model.module().forward(sample) + + self.assertEqual(eager_output.shape, exported_output.shape) + torch.testing.assert_close(eager_output, exported_output) From aa5b0887514bf033dd8ccd159a654f062ab0e5be Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 13 Jan 2025 22:41:08 +0000 Subject: [PATCH 30/57] Disable export tests for dpn and inceptionv4 --- tests/encoders/base.py | 67 ++++++++----------- .../test_pretrainedmodels_encoders.py | 12 +++- tests/encoders/test_smp_encoders.py | 4 +- 3 files changed, 41 insertions(+), 42 deletions(-) diff --git a/tests/encoders/base.py b/tests/encoders/base.py index c1858cdb..28b12ab8 100644 --- a/tests/encoders/base.py +++ 
b/tests/encoders/base.py @@ -4,7 +4,11 @@ import segmentation_models_pytorch as smp from functools import lru_cache -from tests.utils import default_device, check_run_test_on_diff_or_main +from tests.utils import ( + default_device, + check_run_test_on_diff_or_main, + requires_torch_greater_or_equal, +) class BaseEncoderTester(unittest.TestCase): @@ -29,11 +33,19 @@ class BaseEncoderTester(unittest.TestCase): depth_to_test = [3, 4, 5] strides_to_test = [8, 16] # 32 is a default one + # enable/disable tests + do_test_torch_compile = True + do_test_torch_export = True + def get_tiny_encoder(self): return smp.encoders.get_encoder(self.encoder_names[0], encoder_weights=None) @lru_cache - def _get_sample(self, batch_size=1, num_channels=3, height=32, width=32): + def _get_sample(self, batch_size=None, num_channels=None, height=None, width=None): + batch_size = batch_size or self.default_batch_size + num_channels = num_channels or self.default_num_channels + height = height or self.default_height + width = width or self.default_width return torch.rand(batch_size, num_channels, height, width) def get_features_output_strides(self, sample, features): @@ -43,12 +55,7 @@ def get_features_output_strides(self, sample, features): return height_strides, width_strides def test_forward_backward(self): - sample = self._get_sample( - batch_size=self.default_batch_size, - num_channels=self.default_num_channels, - height=self.default_height, - width=self.default_width, - ).to(default_device) + sample = self._get_sample().to(default_device) for encoder_name in self.encoder_names: with self.subTest(encoder_name=encoder_name): # init encoder @@ -75,12 +82,7 @@ def test_in_channels(self): ] for encoder_name, in_channels in cases: - sample = self._get_sample( - batch_size=self.default_batch_size, - num_channels=in_channels, - height=self.default_height, - width=self.default_width, - ).to(default_device) + sample = self._get_sample(num_channels=in_channels).to(default_device) with self.subTest(encoder_name=encoder_name, in_channels=in_channels): encoder = smp.encoders.get_encoder( @@ -93,12 +95,7 @@ def test_in_channels(self): encoder.forward(sample) def test_depth(self): - sample = self._get_sample( - batch_size=self.default_batch_size, - num_channels=self.default_num_channels, - height=self.default_height, - width=self.default_width, - ).to(default_device) + sample = self._get_sample().to(default_device) cases = [ (encoder_name, depth) @@ -157,12 +154,7 @@ def test_depth(self): ) def test_dilated(self): - sample = self._get_sample( - batch_size=self.default_batch_size, - num_channels=self.default_num_channels, - height=self.default_height, - width=self.default_width, - ).to(default_device) + sample = self._get_sample().to(default_device) cases = [ (encoder_name, stride) @@ -216,15 +208,15 @@ def test_dilated(self): @pytest.mark.compile def test_compile(self): + if not self.do_test_torch_compile: + self.skipTest( + f"torch_compile test is disabled for {self.encoder_names[0]}." 
+ ) + if not check_run_test_on_diff_or_main(self.files_for_diff): self.skipTest("No diff and not on `main`.") - sample = self._get_sample( - batch_size=self.default_batch_size, - num_channels=self.default_num_channels, - height=self.default_height, - width=self.default_width, - ).to(default_device) + sample = self._get_sample().to(default_device) encoder = self.get_tiny_encoder().eval().to(default_device) compiled_encoder = torch.compile(encoder, fullgraph=True, dynamic=True) @@ -233,16 +225,15 @@ def test_compile(self): compiled_encoder(sample) @pytest.mark.torch_export + @requires_torch_greater_or_equal("2.4.0") def test_torch_export(self): + if not self.do_test_torch_export: + self.skipTest(f"torch_export test is disabled for {self.encoder_names[0]}.") + if not check_run_test_on_diff_or_main(self.files_for_diff): self.skipTest("No diff and not on `main`.") - sample = self._get_sample( - batch_size=self.default_batch_size, - num_channels=self.default_num_channels, - height=self.default_height, - width=self.default_width, - ).to(default_device) + sample = self._get_sample().to(default_device) encoder = self.get_tiny_encoder() encoder = encoder.eval().to(default_device) diff --git a/tests/encoders/test_pretrainedmodels_encoders.py b/tests/encoders/test_pretrainedmodels_encoders.py index 868f686d..e77c3652 100644 --- a/tests/encoders/test_pretrainedmodels_encoders.py +++ b/tests/encoders/test_pretrainedmodels_encoders.py @@ -12,6 +12,10 @@ class TestDPNEncoder(base.BaseEncoderTester): ) files_for_diff = ["encoders/dpn.py"] + # works with torch 2.4.0, but not with torch 2.5.1 + # dynamo error, probably on Sequential + OrderedDict + do_test_torch_export = False + def get_tiny_encoder(self): params = { "stage_idxs": (2, 3, 4, 5), @@ -29,17 +33,21 @@ def get_tiny_encoder(self): class TestInceptionResNetV2Encoder(base.BaseEncoderTester): - supports_dilated = False encoder_names = ( ["inceptionresnetv2"] if not RUN_ALL_ENCODERS else ["inceptionresnetv2"] ) files_for_diff = ["encoders/inceptionresnetv2.py"] + supports_dilated = False class TestInceptionV4Encoder(base.BaseEncoderTester): - supports_dilated = False encoder_names = ["inceptionv4"] if not RUN_ALL_ENCODERS else ["inceptionv4"] files_for_diff = ["encoders/inceptionv4.py"] + supports_dilated = False + + # works with torch 2.4.0, but not with torch 2.5.1 + # dynamo error, probably on Sequential + OrderedDict + do_test_torch_export = False class TestSeNetEncoder(base.BaseEncoderTester): diff --git a/tests/encoders/test_smp_encoders.py b/tests/encoders/test_smp_encoders.py index 876d9266..f65a61b8 100644 --- a/tests/encoders/test_smp_encoders.py +++ b/tests/encoders/test_smp_encoders.py @@ -63,5 +63,5 @@ class TestEfficientNetEncoder(base.BaseEncoderTester): ) files_for_diff = ["encoders/efficientnet.py"] - def test_compile(self): - self.skipTest("compile fullgraph is not supported for efficientnet encoders") + # torch_compile is not supported for efficientnet encoders + do_test_torch_compile = False From df2f48417b264b819ac690f6cb4c6087720eef19 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 13 Jan 2025 22:46:19 +0000 Subject: [PATCH 31/57] Disable export for timm-eff-net --- tests/encoders/test_timm_ported_encoders.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/encoders/test_timm_ported_encoders.py b/tests/encoders/test_timm_ported_encoders.py index 3793606e..49578f73 100644 --- a/tests/encoders/test_timm_ported_encoders.py +++ b/tests/encoders/test_timm_ported_encoders.py @@ -26,6 +26,9 @@ class 
TestTimmEfficientNetEncoder(base.BaseEncoderTester): ) files_for_diff = ["encoders/timm_efficientnet.py"] + # works with torch 2.4.0, but not with torch 2.5.1 + do_test_torch_export = False + class TestTimmGERNetEncoder(base.BaseEncoderTester): encoder_names = ( From 7157501e1edf64a28d6fe155ed519a7440fd3279 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 14 Jan 2025 12:38:53 +0000 Subject: [PATCH 32/57] Huge fix for torch scripting (except Unet++ and UperNet) --- segmentation_models_pytorch/base/hub_mixin.py | 2 + segmentation_models_pytorch/base/model.py | 14 +- .../decoders/deeplabv3/decoder.py | 57 ++++-- .../decoders/fpn/decoder.py | 48 ++--- .../decoders/linknet/decoder.py | 18 +- .../decoders/manet/decoder.py | 74 +++++--- .../decoders/pan/decoder.py | 73 +++++--- .../decoders/pspnet/decoder.py | 32 +++- .../decoders/segformer/decoder.py | 19 +- .../decoders/unet/decoder.py | 4 +- .../decoders/unetplusplus/decoder.py | 39 ++-- .../encoders/__init__.py | 1 + segmentation_models_pytorch/encoders/_base.py | 11 +- .../encoders/densenet.py | 13 +- segmentation_models_pytorch/encoders/dpn.py | 44 +++-- .../encoders/efficientnet.py | 52 +++--- .../encoders/inceptionresnetv2.py | 17 +- .../encoders/inceptionv4.py | 32 ++-- .../encoders/mix_transformer.py | 169 +++++++++--------- .../encoders/mobilenet.py | 21 ++- .../encoders/mobileone.py | 26 +-- .../encoders/resnet.py | 38 ++-- segmentation_models_pytorch/encoders/senet.py | 34 ++-- .../encoders/timm_efficientnet.py | 86 +++++---- .../encoders/timm_sknet.py | 31 ++-- segmentation_models_pytorch/encoders/vgg.py | 35 ++-- .../encoders/xception.py | 7 +- .../test_pretrainedmodels_encoders.py | 2 +- tests/encoders/test_smp_encoders.py | 2 +- tests/encoders/test_torchvision_encoders.py | 2 +- 30 files changed, 602 insertions(+), 401 deletions(-) diff --git a/segmentation_models_pytorch/base/hub_mixin.py b/segmentation_models_pytorch/base/hub_mixin.py index 360aa521..1c9e8052 100644 --- a/segmentation_models_pytorch/base/hub_mixin.py +++ b/segmentation_models_pytorch/base/hub_mixin.py @@ -1,3 +1,4 @@ +import torch import json from pathlib import Path from typing import Optional, Union @@ -114,6 +115,7 @@ def save_pretrained( return result @property + @torch.jit.unused def config(self) -> dict: return self._hub_mixin_config diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py index e04c2d6e..b5f8abc5 100644 --- a/segmentation_models_pytorch/base/model.py +++ b/segmentation_models_pytorch/base/model.py @@ -11,8 +11,7 @@ class SegmentationModel(torch.nn.Module, SMPHubMixin): """Base class for all segmentation models.""" - # if model supports shape not divisible by 2 ^ n - # set to False + # if model supports shape not divisible by 2 ^ n set to False requires_divisible_input_shape = True # Fix type-hint for models, to avoid HubMixin signature @@ -30,6 +29,9 @@ def check_input_shape(self, x): """Check if the input shape is divisible by the output stride. If not, raise a RuntimeError. 
""" + if not self.requires_divisible_input_shape: + return + h, w = x.shape[-2:] output_stride = self.encoder.output_stride if h % output_stride != 0 or w % output_stride != 0: @@ -51,15 +53,13 @@ def check_input_shape(self, x): def forward(self, x): """Sequentially pass `x` trough model`s encoder, decoder and heads""" - if ( - not torch.jit.is_tracing() - and not is_torch_compiling() - and self.requires_divisible_input_shape + if not ( + torch.jit.is_scripting() or torch.jit.is_tracing() or is_torch_compiling() ): self.check_input_shape(x) features = self.encoder(x) - decoder_output = self.decoder(*features) + decoder_output = self.decoder(features) masks = self.segmentation_head(decoder_output) diff --git a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py index 3fd73786..15280043 100644 --- a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py @@ -31,7 +31,7 @@ """ from collections.abc import Iterable, Sequence -from typing import Literal +from typing import Literal, List import torch from torch import nn @@ -49,21 +49,42 @@ def __init__( aspp_separable: bool, aspp_dropout: float, ): - super().__init__( - ASPP( - in_channels, - out_channels, - atrous_rates, - separable=aspp_separable, - dropout=aspp_dropout, - ), - nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False), - nn.BatchNorm2d(out_channels), - nn.ReLU(), + super().__init__() + self.aspp = ASPP( + in_channels, + out_channels, + atrous_rates, + separable=aspp_separable, + dropout=aspp_dropout, ) - - def forward(self, *features): - return super().forward(features[-1]) + self.conv = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU() + + def forward(self, features: List[torch.Tensor]) -> torch.Tensor: + x = features[-1] + x = self.aspp(x) + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + def load_state_dict(self, state_dict, *args, **kwargs): + # For backward compatibility, previously this module was Sequential + # and was not scriptable. + keys = list(state_dict.keys()) + for key in keys: + new_key = key + if key.startswith("0."): + new_key = "aspp." + key[2:] + elif key.startswith("1."): + new_key = "conv." + key[2:] + elif key.startswith("2."): + new_key = "bn." + key[2:] + elif key.startswith("3."): + new_key = "relu." 
+ key[2:] + state_dict[new_key] = state_dict.pop(key) + super().load_state_dict(state_dict, *args, **kwargs) class DeepLabV3PlusDecoder(nn.Module): @@ -124,7 +145,7 @@ def __init__( nn.ReLU(), ) - def forward(self, *features): + def forward(self, features: List[torch.Tensor]) -> torch.Tensor: aspp_features = self.aspp(features[-1]) aspp_features = self.up(aspp_features) high_res_features = self.block1(features[2]) @@ -174,7 +195,7 @@ def __init__(self, in_channels: int, out_channels: int): nn.ReLU(), ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: size = x.shape[-2:] for mod in self: x = mod(x) @@ -216,7 +237,7 @@ def __init__( nn.Dropout(dropout), ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: res = [] for conv in self.convs: res.append(conv(x)) diff --git a/segmentation_models_pytorch/decoders/fpn/decoder.py b/segmentation_models_pytorch/decoders/fpn/decoder.py index 766190f4..23178623 100644 --- a/segmentation_models_pytorch/decoders/fpn/decoder.py +++ b/segmentation_models_pytorch/decoders/fpn/decoder.py @@ -2,9 +2,11 @@ import torch.nn as nn import torch.nn.functional as F +from typing import List, Literal + class Conv3x3GNReLU(nn.Module): - def __init__(self, in_channels, out_channels, upsample=False): + def __init__(self, in_channels: int, out_channels: int, upsample: bool = False): super().__init__() self.upsample = upsample self.block = nn.Sequential( @@ -15,27 +17,27 @@ def __init__(self, in_channels, out_channels, upsample=False): nn.ReLU(inplace=True), ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.block(x) if self.upsample: - x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + x = F.interpolate(x, scale_factor=2.0, mode="bilinear", align_corners=True) return x class FPNBlock(nn.Module): - def __init__(self, pyramid_channels, skip_channels): + def __init__(self, pyramid_channels: int, skip_channels: int): super().__init__() self.skip_conv = nn.Conv2d(skip_channels, pyramid_channels, kernel_size=1) - def forward(self, x, skip=None): - x = F.interpolate(x, scale_factor=2, mode="nearest") + def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor: + x = F.interpolate(x, scale_factor=2.0, mode="nearest") skip = self.skip_conv(skip) x = x + skip return x class SegmentationBlock(nn.Module): - def __init__(self, in_channels, out_channels, n_upsamples=0): + def __init__(self, in_channels: int, out_channels: int, n_upsamples: int = 0): super().__init__() blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))] @@ -51,7 +53,7 @@ def forward(self, x): class MergeBlock(nn.Module): - def __init__(self, policy): + def __init__(self, policy: Literal["add", "cat"]): super().__init__() if policy not in ["add", "cat"]: raise ValueError( @@ -59,28 +61,29 @@ def __init__(self, policy): ) self.policy = policy - def forward(self, x): + def forward(self, x: List[torch.Tensor]) -> torch.Tensor: if self.policy == "add": - return sum(x) + output = torch.stack(x).sum(dim=0) elif self.policy == "cat": - return torch.cat(x, dim=1) + output = torch.cat(x, dim=1) else: raise ValueError( "`merge_policy` must be one of: ['add', 'cat'], got {}".format( self.policy ) ) + return output class FPNDecoder(nn.Module): def __init__( self, - encoder_channels, - encoder_depth=5, - pyramid_channels=256, - segmentation_channels=128, - dropout=0.2, - merge_policy="add", + encoder_channels: List[int], + encoder_depth: int = 5, + pyramid_channels: int = 256, + 
segmentation_channels: int = 128, + dropout: float = 0.2, + merge_policy: Literal["add", "cat"] = "add", ): super().__init__() @@ -116,7 +119,7 @@ def __init__( self.merge = MergeBlock(merge_policy) self.dropout = nn.Dropout2d(p=dropout, inplace=True) - def forward(self, *features): + def forward(self, features: List[torch.Tensor]) -> torch.Tensor: c2, c3, c4, c5 = features[-4:] p5 = self.p5(c5) @@ -124,9 +127,12 @@ def forward(self, *features): p3 = self.p3(p4, c3) p2 = self.p2(p3, c2) - feature_pyramid = [ - seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2]) - ] + s5 = self.seg_blocks[0](p5) + s4 = self.seg_blocks[1](p4) + s3 = self.seg_blocks[2](p3) + s2 = self.seg_blocks[3](p2) + + feature_pyramid = [s5, s4, s3, s2] x = self.merge(feature_pyramid) x = self.dropout(x) diff --git a/segmentation_models_pytorch/decoders/linknet/decoder.py b/segmentation_models_pytorch/decoders/linknet/decoder.py index e16a32c8..8dfd8434 100644 --- a/segmentation_models_pytorch/decoders/linknet/decoder.py +++ b/segmentation_models_pytorch/decoders/linknet/decoder.py @@ -1,10 +1,12 @@ +import torch import torch.nn as nn +from typing import List, Optional from segmentation_models_pytorch.base import modules class TransposeX2(nn.Sequential): - def __init__(self, in_channels, out_channels, use_batchnorm=True): + def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True): super().__init__() layers = [ nn.ConvTranspose2d( @@ -20,7 +22,7 @@ def __init__(self, in_channels, out_channels, use_batchnorm=True): class DecoderBlock(nn.Module): - def __init__(self, in_channels, out_channels, use_batchnorm=True): + def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True): super().__init__() self.block = nn.Sequential( @@ -41,7 +43,9 @@ def __init__(self, in_channels, out_channels, use_batchnorm=True): ), ) - def forward(self, x, skip=None): + def forward( + self, x: torch.Tensor, skip: Optional[torch.Tensor] = None + ) -> torch.Tensor: x = self.block(x) if skip is not None: x = x + skip @@ -50,7 +54,11 @@ def forward(self, x, skip=None): class LinknetDecoder(nn.Module): def __init__( - self, encoder_channels, prefinal_channels=32, n_blocks=5, use_batchnorm=True + self, + encoder_channels: List[int], + prefinal_channels: int = 32, + n_blocks: int = 5, + use_batchnorm: bool = True, ): super().__init__() @@ -68,7 +76,7 @@ def __init__( ] ) - def forward(self, *features): + def forward(self, features: List[torch.Tensor]) -> torch.Tensor: features = features[1:] # remove first skip features = features[::-1] # reverse channels to start from head of encoder diff --git a/segmentation_models_pytorch/decoders/manet/decoder.py b/segmentation_models_pytorch/decoders/manet/decoder.py index 0f6af18d..61b1fe57 100644 --- a/segmentation_models_pytorch/decoders/manet/decoder.py +++ b/segmentation_models_pytorch/decoders/manet/decoder.py @@ -2,12 +2,15 @@ import torch.nn as nn import torch.nn.functional as F +from typing import List, Optional + from segmentation_models_pytorch.base import modules as md -class PAB(nn.Module): - def __init__(self, in_channels, out_channels, pab_channels=64): - super(PAB, self).__init__() +class PABBlock(nn.Module): + def __init__(self, in_channels: int, pab_channels: int = 64): + super().__init__() + # Series of 1x1 conv to generate attention feature maps self.pab_channels = pab_channels self.in_channels = in_channels @@ -17,10 +20,9 @@ def __init__(self, in_channels, out_channels, pab_channels=64): self.map_softmax = nn.Softmax(dim=1) 
self.out_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) - def forward(self, x): - bsize = x.size()[0] - h = x.size()[2] - w = x.size()[3] + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch_size, _, height, width = x.shape + x_top = self.top_conv(x) x_center = self.center_conv(x) x_bottom = self.bottom_conv(x) @@ -30,20 +32,28 @@ def forward(self, x): x_bottom = x_bottom.flatten(2).transpose(1, 2) sp_map = torch.matmul(x_center, x_top) - sp_map = self.map_softmax(sp_map.view(bsize, -1)).view(bsize, h * w, h * w) + sp_map = self.map_softmax(sp_map.view(batch_size, -1)) + sp_map = sp_map.view(batch_size, height * width, height * width) + sp_map = torch.matmul(sp_map, x_bottom) - sp_map = sp_map.reshape(bsize, self.in_channels, h, w) + sp_map = sp_map.reshape(batch_size, self.in_channels, height, width) + x = x + sp_map x = self.out_conv(x) return x -class MFAB(nn.Module): +class MFABBlock(nn.Module): def __init__( - self, in_channels, skip_channels, out_channels, use_batchnorm=True, reduction=16 + self, + in_channels: int, + skip_channels: int, + out_channels: int, + use_batchnorm: bool = True, + reduction: int = 16, ): - # MFAB is just a modified version of SE-blocks, one for skip, one for input - super(MFAB, self).__init__() + # MFABBlock is just a modified version of SE-blocks, one for skip, one for input + super().__init__() self.hl_conv = nn.Sequential( md.Conv2dReLU( in_channels, @@ -87,9 +97,11 @@ def __init__( use_batchnorm=use_batchnorm, ) - def forward(self, x, skip=None): + def forward( + self, x: torch.Tensor, skip: Optional[torch.Tensor] = None + ) -> torch.Tensor: x = self.hl_conv(x) - x = F.interpolate(x, scale_factor=2, mode="nearest") + x = F.interpolate(x, scale_factor=2.0, mode="nearest") attention_hl = self.SE_hl(x) if skip is not None: attention_ll = self.SE_ll(skip) @@ -102,7 +114,13 @@ def forward(self, x, skip=None): class DecoderBlock(nn.Module): - def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True): + def __init__( + self, + in_channels: int, + skip_channels: int, + out_channels: int, + use_batchnorm: bool = True, + ): super().__init__() self.conv1 = md.Conv2dReLU( in_channels + skip_channels, @@ -119,8 +137,10 @@ def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True) use_batchnorm=use_batchnorm, ) - def forward(self, x, skip=None): - x = F.interpolate(x, scale_factor=2, mode="nearest") + def forward( + self, x: torch.Tensor, skip: Optional[torch.Tensor] = None + ) -> torch.Tensor: + x = F.interpolate(x, scale_factor=2.0, mode="nearest") if skip is not None: x = torch.cat([x, skip], dim=1) x = self.conv1(x) @@ -131,12 +151,12 @@ def forward(self, x, skip=None): class MAnetDecoder(nn.Module): def __init__( self, - encoder_channels, - decoder_channels, - n_blocks=5, - reduction=16, - use_batchnorm=True, - pab_channels=64, + encoder_channels: List[int], + decoder_channels: List[int], + n_blocks: int = 5, + reduction: int = 16, + use_batchnorm: bool = True, + pab_channels: int = 64, ): super().__init__() @@ -159,12 +179,12 @@ def __init__( skip_channels = list(encoder_channels[1:]) + [0] out_channels = decoder_channels - self.center = PAB(head_channels, head_channels, pab_channels=pab_channels) + self.center = PABBlock(head_channels, pab_channels=pab_channels) # combine decoder keyword arguments kwargs = dict(use_batchnorm=use_batchnorm) # no attention type here blocks = [ - MFAB(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs) + MFABBlock(in_ch, skip_ch, out_ch, 
reduction=reduction, **kwargs) if skip_ch > 0 else DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels) @@ -172,7 +192,7 @@ def __init__( # for the last we dont have skip connection -> use simple decoder block self.blocks = nn.ModuleList(blocks) - def forward(self, *features): + def forward(self, features: List[torch.Tensor]) -> torch.Tensor: features = features[1:] # remove first skip with same spatial resolution features = features[::-1] # reverse channels to start from head of encoder diff --git a/segmentation_models_pytorch/decoders/pan/decoder.py b/segmentation_models_pytorch/decoders/pan/decoder.py index fa0bb261..ed8d0ee9 100644 --- a/segmentation_models_pytorch/decoders/pan/decoder.py +++ b/segmentation_models_pytorch/decoders/pan/decoder.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Literal +from typing import Literal, List import torch import torch.nn as nn @@ -31,18 +31,22 @@ def __init__( bias=bias, groups=groups, ) + self.activation = nn.ReLU(inplace=True) + self.bn = nn.BatchNorm2d(out_channels) + self.add_relu = add_relu self.interpolate = interpolate - self.bn = nn.BatchNorm2d(out_channels) - self.activation = nn.ReLU(inplace=True) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv(x) x = self.bn(x) + if self.add_relu: x = self.activation(x) + if self.interpolate: - x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + x = F.interpolate(x, scale_factor=2.0, mode="bilinear", align_corners=True) + return x @@ -50,7 +54,7 @@ class FPABlock(nn.Module): def __init__( self, in_channels: int, out_channels: int, upscale_mode: str = "bilinear" ): - super(FPABlock, self).__init__() + super().__init__() self.upscale_mode = upscale_mode if self.upscale_mode == "bilinear": @@ -70,7 +74,7 @@ def __init__( ), ) - # midddle branch + # middle branch self.mid = nn.Sequential( ConvBnRelu( in_channels=in_channels, @@ -112,30 +116,50 @@ def __init__( in_channels=1, out_channels=1, kernel_size=7, stride=1, padding=3 ) - def forward(self, x): - h, w = x.size(2), x.size(3) - b1 = self.branch1(x) - upscale_parameters = dict( - mode=self.upscale_mode, align_corners=self.align_corners + def forward(self, x: torch.Tensor) -> torch.Tensor: + _, _, height, width = x.shape + + branch1_output = self.branch1(x) + branch1_output = F.interpolate( + branch1_output, + size=(height, width), + mode=self.upscale_mode, + align_corners=self.align_corners, ) - b1 = F.interpolate(b1, size=(h, w), **upscale_parameters) - mid = self.mid(x) + middle_output = self.mid(x) + x1 = self.down1(x) x2 = self.down2(x1) x3 = self.down3(x2) - x3 = F.interpolate(x3, size=(h // 4, w // 4), **upscale_parameters) + x3 = F.interpolate( + x3, + size=(height // 4, width // 4), + mode=self.upscale_mode, + align_corners=self.align_corners, + ) x2 = self.conv2(x2) x = x2 + x3 - x = F.interpolate(x, size=(h // 2, w // 2), **upscale_parameters) + x = F.interpolate( + x, + size=(height // 2, width // 2), + mode=self.upscale_mode, + align_corners=self.align_corners, + ) x1 = self.conv1(x1) x = x + x1 - x = F.interpolate(x, size=(h, w), **upscale_parameters) + x = F.interpolate( + x, + size=(height, width), + mode=self.upscale_mode, + align_corners=self.align_corners, + ) + + x = torch.mul(x, middle_output) + x = x + branch1_output - x = torch.mul(x, mid) - x = x + b1 return x @@ -162,15 +186,18 @@ def __init__( in_channels=in_channels, out_channels=out_channels, kernel_size=3, 
padding=1 ) - def forward(self, x, y): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ Args: x: low level feature y: high level feature """ - h, w = x.size(2), x.size(3) + height, width = x.shape[2:] y_up = F.interpolate( - y, size=(h, w), mode=self.upscale_mode, align_corners=self.align_corners + y, + size=(height, width), + mode=self.upscale_mode, + align_corners=self.align_corners, ) x = self.conv2(x) y = self.conv1(y) @@ -220,7 +247,7 @@ def __init__( upscale_mode=upscale_mode, ) - def forward(self, *features): + def forward(self, features: List[torch.Tensor]) -> torch.Tensor: features = features[2:] # remove first and second skip out = self.fpa(features[-1]) # 1/16 or 1/32 diff --git a/segmentation_models_pytorch/decoders/pspnet/decoder.py b/segmentation_models_pytorch/decoders/pspnet/decoder.py index 40d2e945..99ec5f72 100644 --- a/segmentation_models_pytorch/decoders/pspnet/decoder.py +++ b/segmentation_models_pytorch/decoders/pspnet/decoder.py @@ -2,14 +2,23 @@ import torch.nn as nn import torch.nn.functional as F +from typing import List, Tuple from segmentation_models_pytorch.base import modules class PSPBlock(nn.Module): - def __init__(self, in_channels, out_channels, pool_size, use_bathcnorm=True): + def __init__( + self, + in_channels: int, + out_channels: int, + pool_size: int, + use_bathcnorm: bool = True, + ): super().__init__() + if pool_size == 1: use_bathcnorm = False # PyTorch does not support BatchNorm for 1x1 shape + self.pool = nn.Sequential( nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size)), modules.Conv2dReLU( @@ -17,15 +26,20 @@ def __init__(self, in_channels, out_channels, pool_size, use_bathcnorm=True): ), ) - def forward(self, x): - h, w = x.size(2), x.size(3) + def forward(self, x: torch.Tensor) -> torch.Tensor: + height, width = x.shape[2:] x = self.pool(x) - x = F.interpolate(x, size=(h, w), mode="bilinear", align_corners=True) + x = F.interpolate(x, size=(height, width), mode="bilinear", align_corners=True) return x class PSPModule(nn.Module): - def __init__(self, in_channels, sizes=(1, 2, 3, 6), use_bathcnorm=True): + def __init__( + self, + in_channels: int, + sizes: Tuple[int, ...] 
= (1, 2, 3, 6), + use_bathcnorm: bool = True, + ): super().__init__() self.blocks = nn.ModuleList( @@ -48,7 +62,11 @@ def forward(self, x): class PSPDecoder(nn.Module): def __init__( - self, encoder_channels, use_batchnorm=True, out_channels=512, dropout=0.2 + self, + encoder_channels: List[int], + use_batchnorm: bool = True, + out_channels: int = 512, + dropout: float = 0.2, ): super().__init__() @@ -67,7 +85,7 @@ def __init__( self.dropout = nn.Dropout2d(p=dropout) - def forward(self, *features): + def forward(self, features: List[torch.Tensor]) -> torch.Tensor: x = features[-1] x = self.psp(x) x = self.conv(x) diff --git a/segmentation_models_pytorch/decoders/segformer/decoder.py b/segmentation_models_pytorch/decoders/segformer/decoder.py index daa78b37..cd160a4c 100644 --- a/segmentation_models_pytorch/decoders/segformer/decoder.py +++ b/segmentation_models_pytorch/decoders/segformer/decoder.py @@ -2,11 +2,12 @@ import torch.nn as nn import torch.nn.functional as F +from typing import List from segmentation_models_pytorch.base import modules as md class MLP(nn.Module): - def __init__(self, skip_channels, segmentation_channels): + def __init__(self, skip_channels: int, segmentation_channels: int): super().__init__() self.linear = nn.Linear(skip_channels, segmentation_channels) @@ -22,9 +23,9 @@ def forward(self, x: torch.Tensor): class SegformerDecoder(nn.Module): def __init__( self, - encoder_channels, - encoder_depth=5, - segmentation_channels=256, + encoder_channels: List[int], + encoder_depth: int = 5, + segmentation_channels: int = 256, ): super().__init__() @@ -36,9 +37,9 @@ def __init__( ) if encoder_channels[1] == 0: - encoder_channels = tuple( + encoder_channels = [ channel for index, channel in enumerate(encoder_channels) if index != 1 - ) + ] encoder_channels = encoder_channels[::-1] self.mlp_stage = nn.ModuleList( @@ -52,7 +53,7 @@ def __init__( use_batchnorm=True, ) - def forward(self, *features): + def forward(self, features: List[torch.Tensor]) -> torch.Tensor: # Resize all features to the size of the largest feature target_size = [dim // 4 for dim in features[0].shape[2:]] @@ -60,8 +61,8 @@ def forward(self, *features): features = features[::-1] # reverse channels to start from head of encoder resized_features = [] - for feature, stage in zip(features, self.mlp_stage): - feature = stage(feature) + for i, mlp_layer in enumerate(self.mlp_stage): + feature = mlp_layer(features[i]) resized_feature = F.interpolate( feature, size=target_size, mode="bilinear", align_corners=False ) diff --git a/segmentation_models_pytorch/decoders/unet/decoder.py b/segmentation_models_pytorch/decoders/unet/decoder.py index e6bf4d16..0e4f35fd 100644 --- a/segmentation_models_pytorch/decoders/unet/decoder.py +++ b/segmentation_models_pytorch/decoders/unet/decoder.py @@ -2,7 +2,7 @@ import torch.nn as nn import torch.nn.functional as F -from typing import Optional, Sequence +from typing import Optional, Sequence, List from segmentation_models_pytorch.base import modules as md @@ -140,7 +140,7 @@ def __init__( ) self.blocks.append(block) - def forward(self, *features: torch.Tensor) -> torch.Tensor: + def forward(self, features: List[torch.Tensor]) -> torch.Tensor: # spatial shapes of features: [hw, hw/2, hw/4, hw/8, ...] 
spatial_shapes = [feature.shape[2:] for feature in features] spatial_shapes = spatial_shapes[::-1] diff --git a/segmentation_models_pytorch/decoders/unetplusplus/decoder.py b/segmentation_models_pytorch/decoders/unetplusplus/decoder.py index feafb5d4..3282849f 100644 --- a/segmentation_models_pytorch/decoders/unetplusplus/decoder.py +++ b/segmentation_models_pytorch/decoders/unetplusplus/decoder.py @@ -2,17 +2,19 @@ import torch.nn as nn import torch.nn.functional as F +from typing import Optional, List + from segmentation_models_pytorch.base import modules as md class DecoderBlock(nn.Module): def __init__( self, - in_channels, - skip_channels, - out_channels, - use_batchnorm=True, - attention_type=None, + in_channels: int, + skip_channels: int, + out_channels: int, + use_batchnorm: bool = True, + attention_type: Optional[str] = None, ): super().__init__() self.conv1 = md.Conv2dReLU( @@ -34,8 +36,10 @@ def __init__( ) self.attention2 = md.Attention(attention_type, in_channels=out_channels) - def forward(self, x, skip=None): - x = F.interpolate(x, scale_factor=2, mode="nearest") + def forward( + self, x: torch.Tensor, skip: Optional[torch.Tensor] = None + ) -> torch.Tensor: + x = F.interpolate(x, scale_factor=2.0, mode="nearest") if skip is not None: x = torch.cat([x, skip], dim=1) x = self.attention1(x) @@ -46,7 +50,7 @@ def forward(self, x, skip=None): class CenterBlock(nn.Sequential): - def __init__(self, in_channels, out_channels, use_batchnorm=True): + def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True): conv1 = md.Conv2dReLU( in_channels, out_channels, @@ -67,20 +71,18 @@ def __init__(self, in_channels, out_channels, use_batchnorm=True): class UnetPlusPlusDecoder(nn.Module): def __init__( self, - encoder_channels, - decoder_channels, - n_blocks=5, - use_batchnorm=True, - attention_type=None, - center=False, + encoder_channels: List[int], + decoder_channels: List[int], + n_blocks: int = 5, + use_batchnorm: bool = True, + attention_type: Optional[str] = None, + center: bool = False, ): super().__init__() if n_blocks != len(decoder_channels): raise ValueError( - "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( - n_blocks, len(decoder_channels) - ) + f"Model depth is {n_blocks}, but you provide `decoder_channels` for {len(decoder_channels)} blocks." 
) # remove first skip with same spatial resolution @@ -125,9 +127,10 @@ def __init__( self.blocks = nn.ModuleDict(blocks) self.depth = len(self.in_channels) - 1 - def forward(self, *features): + def forward(self, features: List[torch.Tensor]) -> torch.Tensor: features = features[1:] # remove first skip with same spatial resolution features = features[::-1] # reverse channels to start from head of encoder + # start building dense connections dense_x = {} for layer_idx in range(len(self.in_channels) - 1): diff --git a/segmentation_models_pytorch/encoders/__init__.py b/segmentation_models_pytorch/encoders/__init__.py index bc4fafce..12b106df 100644 --- a/segmentation_models_pytorch/encoders/__init__.py +++ b/segmentation_models_pytorch/encoders/__init__.py @@ -94,6 +94,7 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, ** params = copy.deepcopy(encoders[name]["params"]) params["depth"] = depth + params["output_stride"] = output_stride EncoderClass = encoders[name]["encoder"] encoder = EncoderClass(**params) diff --git a/segmentation_models_pytorch/encoders/_base.py b/segmentation_models_pytorch/encoders/_base.py index 20b6aa4c..335efed0 100644 --- a/segmentation_models_pytorch/encoders/_base.py +++ b/segmentation_models_pytorch/encoders/_base.py @@ -1,5 +1,5 @@ import torch -from typing import Sequence +from typing import Sequence, Dict from . import _utils as utils @@ -10,7 +10,10 @@ class EncoderMixin: - patching first convolution for arbitrary input channels """ - _output_stride = 32 + def __init__(self): + self._depth = 5 + self._in_channels = 3 + self._output_stride = 32 @property def out_channels(self): @@ -28,13 +31,13 @@ def set_in_channels(self, in_channels, pretrained=True): self._in_channels = in_channels if self._out_channels[0] == 3: - self._out_channels = tuple([in_channels] + list(self._out_channels)[1:]) + self._out_channels = [in_channels] + self._out_channels[1:] utils.patch_first_conv( model=self, new_in_channels=in_channels, pretrained=pretrained ) - def get_stages(self) -> dict[int, Sequence[torch.nn.Module]]: + def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: """Override it in your implementation, should return a dictionary with keys as the output stride and values as the list of modules """ diff --git a/segmentation_models_pytorch/encoders/densenet.py b/segmentation_models_pytorch/encoders/densenet.py index 4dd23f2f..c496d048 100644 --- a/segmentation_models_pytorch/encoders/densenet.py +++ b/segmentation_models_pytorch/encoders/densenet.py @@ -33,11 +33,12 @@ class DenseNetEncoder(DenseNet, EncoderMixin): - def __init__(self, out_channels, depth=5, **kwargs): + def __init__(self, out_channels, depth=5, output_stride=32, **kwargs): super().__init__(**kwargs) - self._out_channels = out_channels self._depth = depth self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride del self.classifier def make_dilated(self, *args, **kwargs): @@ -157,7 +158,7 @@ def load_state_dict(self, state_dict): "encoder": DenseNetEncoder, "pretrained_settings": pretrained_settings["densenet121"], "params": { - "out_channels": (3, 64, 256, 512, 1024, 1024), + "out_channels": [3, 64, 256, 512, 1024, 1024], "num_init_features": 64, "growth_rate": 32, "block_config": (6, 12, 24, 16), @@ -167,7 +168,7 @@ def load_state_dict(self, state_dict): "encoder": DenseNetEncoder, "pretrained_settings": pretrained_settings["densenet169"], "params": { - "out_channels": (3, 64, 256, 512, 1280, 1664), + "out_channels": [3, 64, 
256, 512, 1280, 1664], "num_init_features": 64, "growth_rate": 32, "block_config": (6, 12, 32, 32), @@ -177,7 +178,7 @@ def load_state_dict(self, state_dict): "encoder": DenseNetEncoder, "pretrained_settings": pretrained_settings["densenet201"], "params": { - "out_channels": (3, 64, 256, 512, 1792, 1920), + "out_channels": [3, 64, 256, 512, 1792, 1920], "num_init_features": 64, "growth_rate": 32, "block_config": (6, 12, 48, 32), @@ -187,7 +188,7 @@ def load_state_dict(self, state_dict): "encoder": DenseNetEncoder, "pretrained_settings": pretrained_settings["densenet161"], "params": { - "out_channels": (3, 96, 384, 768, 2112, 2208), + "out_channels": [3, 96, 384, 768, 2112, 2208], "num_init_features": 96, "growth_rate": 48, "block_config": (6, 12, 36, 24), diff --git a/segmentation_models_pytorch/encoders/dpn.py b/segmentation_models_pytorch/encoders/dpn.py index 51c153a7..e25acd7c 100644 --- a/segmentation_models_pytorch/encoders/dpn.py +++ b/segmentation_models_pytorch/encoders/dpn.py @@ -25,7 +25,7 @@ import torch import torch.nn.functional as F -from typing import List +from typing import List, Dict, Sequence from pretrainedmodels.models.dpn import DPN from pretrainedmodels.models.dpn import pretrained_settings @@ -34,19 +34,27 @@ class DPNEncoder(DPN, EncoderMixin): - def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): + def __init__( + self, + stage_idxs: List[int], + out_channels: List[int], + depth: int = 5, + output_stride: int = 32, + **kwargs, + ): super().__init__(**kwargs) self._stage_idxs = stage_idxs self._depth = depth - self._out_channels = out_channels self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride del self.last_linear - def get_stages(self): + def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: return { - 16: self.features[self._stage_idxs[1] : self._stage_idxs[2]], - 32: self.features[self._stage_idxs[2] : self._stage_idxs[3]], + 16: [self.features[self._stage_idxs[1] : self._stage_idxs[2]]], + 32: [self.features[self._stage_idxs[2] : self._stage_idxs[3]]], } def forward(self, x: torch.Tensor) -> List[torch.Tensor]: @@ -91,8 +99,8 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": DPNEncoder, "pretrained_settings": pretrained_settings["dpn68"], "params": { - "stage_idxs": (4, 8, 20, 24), - "out_channels": (3, 10, 144, 320, 704, 832), + "stage_idxs": [4, 8, 20, 24], + "out_channels": [3, 10, 144, 320, 704, 832], "groups": 32, "inc_sec": (16, 32, 32, 64), "k_r": 128, @@ -107,8 +115,8 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": DPNEncoder, "pretrained_settings": pretrained_settings["dpn68b"], "params": { - "stage_idxs": (4, 8, 20, 24), - "out_channels": (3, 10, 144, 320, 704, 832), + "stage_idxs": [4, 8, 20, 24], + "out_channels": [3, 10, 144, 320, 704, 832], "b": True, "groups": 32, "inc_sec": (16, 32, 32, 64), @@ -124,8 +132,8 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": DPNEncoder, "pretrained_settings": pretrained_settings["dpn92"], "params": { - "stage_idxs": (4, 8, 28, 32), - "out_channels": (3, 64, 336, 704, 1552, 2688), + "stage_idxs": [4, 8, 28, 32], + "out_channels": [3, 64, 336, 704, 1552, 2688], "groups": 32, "inc_sec": (16, 32, 24, 128), "k_r": 96, @@ -139,8 +147,8 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": DPNEncoder, "pretrained_settings": pretrained_settings["dpn98"], "params": { - "stage_idxs": (4, 10, 30, 34), - "out_channels": (3, 96, 336, 768, 1728, 2688), + "stage_idxs": [4, 10, 30, 34], + "out_channels": [3, 96, 
336, 768, 1728, 2688], "groups": 40, "inc_sec": (16, 32, 32, 128), "k_r": 160, @@ -154,8 +162,8 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": DPNEncoder, "pretrained_settings": pretrained_settings["dpn107"], "params": { - "stage_idxs": (5, 13, 33, 37), - "out_channels": (3, 128, 376, 1152, 2432, 2688), + "stage_idxs": [5, 13, 33, 37], + "out_channels": [3, 128, 376, 1152, 2432, 2688], "groups": 50, "inc_sec": (20, 64, 64, 128), "k_r": 200, @@ -169,8 +177,8 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": DPNEncoder, "pretrained_settings": pretrained_settings["dpn131"], "params": { - "stage_idxs": (5, 13, 41, 45), - "out_channels": (3, 128, 352, 832, 1984, 2688), + "stage_idxs": [5, 13, 41, 45], + "out_channels": [3, 128, 352, 832, 1984, 2688], "groups": 40, "inc_sec": (16, 32, 32, 128), "k_r": 160, diff --git a/segmentation_models_pytorch/encoders/efficientnet.py b/segmentation_models_pytorch/encoders/efficientnet.py index 6c4a4b5f..bb56abf2 100644 --- a/segmentation_models_pytorch/encoders/efficientnet.py +++ b/segmentation_models_pytorch/encoders/efficientnet.py @@ -24,7 +24,7 @@ """ import torch -from typing import List +from typing import List, Dict, Sequence from efficientnet_pytorch import EfficientNet from efficientnet_pytorch.utils import url_map, url_map_advprop, get_model_params @@ -33,21 +33,29 @@ class EfficientNetEncoder(EfficientNet, EncoderMixin): - def __init__(self, stage_idxs, out_channels, model_name, depth=5): + def __init__( + self, + stage_idxs: List[int], + out_channels: List[int], + model_name: str, + depth: int = 5, + output_stride: int = 32, + ): blocks_args, global_params = get_model_params(model_name, override_params=None) super().__init__(blocks_args, global_params) self._stage_idxs = stage_idxs - self._out_channels = out_channels self._depth = depth self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride del self._fc - def get_stages(self): + def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: return { - 16: self._blocks[self._stage_idxs[1] : self._stage_idxs[2]], - 32: self._blocks[self._stage_idxs[2] :], + 16: [self._blocks[self._stage_idxs[1] : self._stage_idxs[2]]], + 32: [self._blocks[self._stage_idxs[2] :]], } def apply_blocks( @@ -119,8 +127,8 @@ def _get_pretrained_settings(encoder): "encoder": EfficientNetEncoder, "pretrained_settings": _get_pretrained_settings("efficientnet-b0"), "params": { - "out_channels": (3, 32, 24, 40, 112, 320), - "stage_idxs": (3, 5, 9, 16), + "out_channels": [3, 32, 24, 40, 112, 320], + "stage_idxs": [3, 5, 9, 16], "model_name": "efficientnet-b0", }, }, @@ -128,8 +136,8 @@ def _get_pretrained_settings(encoder): "encoder": EfficientNetEncoder, "pretrained_settings": _get_pretrained_settings("efficientnet-b1"), "params": { - "out_channels": (3, 32, 24, 40, 112, 320), - "stage_idxs": (5, 8, 16, 23), + "out_channels": [3, 32, 24, 40, 112, 320], + "stage_idxs": [5, 8, 16, 23], "model_name": "efficientnet-b1", }, }, @@ -137,8 +145,8 @@ def _get_pretrained_settings(encoder): "encoder": EfficientNetEncoder, "pretrained_settings": _get_pretrained_settings("efficientnet-b2"), "params": { - "out_channels": (3, 32, 24, 48, 120, 352), - "stage_idxs": (5, 8, 16, 23), + "out_channels": [3, 32, 24, 48, 120, 352], + "stage_idxs": [5, 8, 16, 23], "model_name": "efficientnet-b2", }, }, @@ -146,8 +154,8 @@ def _get_pretrained_settings(encoder): "encoder": EfficientNetEncoder, "pretrained_settings": _get_pretrained_settings("efficientnet-b3"), "params": { - 
"out_channels": (3, 40, 32, 48, 136, 384), - "stage_idxs": (5, 8, 18, 26), + "out_channels": [3, 40, 32, 48, 136, 384], + "stage_idxs": [5, 8, 18, 26], "model_name": "efficientnet-b3", }, }, @@ -155,8 +163,8 @@ def _get_pretrained_settings(encoder): "encoder": EfficientNetEncoder, "pretrained_settings": _get_pretrained_settings("efficientnet-b4"), "params": { - "out_channels": (3, 48, 32, 56, 160, 448), - "stage_idxs": (6, 10, 22, 32), + "out_channels": [3, 48, 32, 56, 160, 448], + "stage_idxs": [6, 10, 22, 32], "model_name": "efficientnet-b4", }, }, @@ -164,8 +172,8 @@ def _get_pretrained_settings(encoder): "encoder": EfficientNetEncoder, "pretrained_settings": _get_pretrained_settings("efficientnet-b5"), "params": { - "out_channels": (3, 48, 40, 64, 176, 512), - "stage_idxs": (8, 13, 27, 39), + "out_channels": [3, 48, 40, 64, 176, 512], + "stage_idxs": [8, 13, 27, 39], "model_name": "efficientnet-b5", }, }, @@ -173,8 +181,8 @@ def _get_pretrained_settings(encoder): "encoder": EfficientNetEncoder, "pretrained_settings": _get_pretrained_settings("efficientnet-b6"), "params": { - "out_channels": (3, 56, 40, 72, 200, 576), - "stage_idxs": (9, 15, 31, 45), + "out_channels": [3, 56, 40, 72, 200, 576], + "stage_idxs": [9, 15, 31, 45], "model_name": "efficientnet-b6", }, }, @@ -182,8 +190,8 @@ def _get_pretrained_settings(encoder): "encoder": EfficientNetEncoder, "pretrained_settings": _get_pretrained_settings("efficientnet-b7"), "params": { - "out_channels": (3, 64, 48, 80, 224, 640), - "stage_idxs": (11, 18, 38, 55), + "out_channels": [3, 64, 48, 80, 224, 640], + "stage_idxs": [11, 18, 38, 55], "model_name": "efficientnet-b7", }, }, diff --git a/segmentation_models_pytorch/encoders/inceptionresnetv2.py b/segmentation_models_pytorch/encoders/inceptionresnetv2.py index 4ef53404..ced9c7f5 100644 --- a/segmentation_models_pytorch/encoders/inceptionresnetv2.py +++ b/segmentation_models_pytorch/encoders/inceptionresnetv2.py @@ -23,19 +23,28 @@ depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). """ +import torch import torch.nn as nn +from typing import List from pretrainedmodels.models.inceptionresnetv2 import InceptionResNetV2 from ._base import EncoderMixin class InceptionResNetV2Encoder(InceptionResNetV2, EncoderMixin): - def __init__(self, out_channels, depth=5, **kwargs): + def __init__( + self, + out_channels: List[int], + depth: int = 5, + output_stride: int = 32, + **kwargs, + ): super().__init__(**kwargs) - self._out_channels = out_channels self._depth = depth self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride # correct paddings for m in self.modules(): @@ -55,7 +64,7 @@ def make_dilated(self, *args, **kwargs): "due to pooling operation for downsampling!" 
) - def forward(self, x): + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: features = [x] if self._depth >= 1: @@ -119,6 +128,6 @@ def load_state_dict(self, state_dict, **kwargs): "num_classes": 1001, }, }, - "params": {"out_channels": (3, 64, 192, 320, 1088, 1536), "num_classes": 1000}, + "params": {"out_channels": [3, 64, 192, 320, 1088, 1536], "num_classes": 1000}, } } diff --git a/segmentation_models_pytorch/encoders/inceptionv4.py b/segmentation_models_pytorch/encoders/inceptionv4.py index 6fce4306..ccac34f9 100644 --- a/segmentation_models_pytorch/encoders/inceptionv4.py +++ b/segmentation_models_pytorch/encoders/inceptionv4.py @@ -23,19 +23,31 @@ depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). """ +import torch import torch.nn as nn + +from typing import List from pretrainedmodels.models.inceptionv4 import InceptionV4 from ._base import EncoderMixin class InceptionV4Encoder(InceptionV4, EncoderMixin): - def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): + def __init__( + self, + stage_idxs: List[int], + out_channels: List[int], + depth: int = 5, + output_stride: int = 32, + **kwargs, + ): super().__init__(**kwargs) + self._stage_idxs = stage_idxs - self._out_channels = out_channels self._depth = depth self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride # correct paddings for m in self.modules(): @@ -54,17 +66,7 @@ def make_dilated(self, *args, **kwargs): "due to pooling operation for downsampling!" ) - def get_stages(self): - return [ - nn.Identity(), - self.features[: self._stage_idxs[0]], - self.features[self._stage_idxs[0] : self._stage_idxs[1]], - self.features[self._stage_idxs[1] : self._stage_idxs[2]], - self.features[self._stage_idxs[2] : self._stage_idxs[3]], - self.features[self._stage_idxs[3] :], - ] - - def forward(self, x): + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: features = [x] if self._depth >= 1: @@ -119,8 +121,8 @@ def load_state_dict(self, state_dict, **kwargs): }, }, "params": { - "stage_idxs": (3, 5, 9, 15), - "out_channels": (3, 64, 192, 384, 1024, 1536), + "stage_idxs": [3, 5, 9, 15], + "out_channels": [3, 64, 192, 384, 1024, 1536], "num_classes": 1001, }, } diff --git a/segmentation_models_pytorch/encoders/mix_transformer.py b/segmentation_models_pytorch/encoders/mix_transformer.py index 54bd747e..2a1e068d 100644 --- a/segmentation_models_pytorch/encoders/mix_transformer.py +++ b/segmentation_models_pytorch/encoders/mix_transformer.py @@ -516,23 +516,28 @@ def forward(self, x, H, W): # End of NVIDIA code # --------------------------------------------------------------- +from typing import Dict, Sequence, List # noqa E402 from ._base import EncoderMixin # noqa E402 class MixVisionTransformerEncoder(MixVisionTransformer, EncoderMixin): - def __init__(self, out_channels, depth=5, **kwargs): + def __init__( + self, out_channels: List[int], depth: int = 5, output_stride: int = 32, **kwargs + ): super().__init__(**kwargs) - self._out_channels = out_channels + self._depth = depth self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride - def get_stages(self): + def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: return { 16: [self.patch_embed3, self.block3, self.norm3], 32: [self.patch_embed4, self.block4, self.norm4], } - def forward(self, x): + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: # create dummy output for the first block batch_size, _, height, width = x.shape 
dummy = torch.empty( @@ -592,103 +597,103 @@ def get_pretrained_cfg(name): "mit_b0": { "encoder": MixVisionTransformerEncoder, "pretrained_settings": {"imagenet": get_pretrained_cfg("mit_b0")}, - "params": dict( - out_channels=(3, 0, 32, 64, 160, 256), - patch_size=4, - embed_dims=[32, 64, 160, 256], - num_heads=[1, 2, 5, 8], - mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, - norm_layer=partial(LayerNorm, eps=1e-6), - depths=[2, 2, 2, 2], - sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, - drop_path_rate=0.1, - ), + "params": { + "out_channels": [3, 0, 32, 64, 160, 256], + "patch_size": 4, + "embed_dims": [32, 64, 160, 256], + "num_heads": [1, 2, 5, 8], + "mlp_ratios": [4, 4, 4, 4], + "qkv_bias": True, + "norm_layer": partial(LayerNorm, eps=1e-6), + "depths": [2, 2, 2, 2], + "sr_ratios": [8, 4, 2, 1], + "drop_rate": 0.0, + "drop_path_rate": 0.1, + }, }, "mit_b1": { "encoder": MixVisionTransformerEncoder, "pretrained_settings": {"imagenet": get_pretrained_cfg("mit_b1")}, - "params": dict( - out_channels=(3, 0, 64, 128, 320, 512), - patch_size=4, - embed_dims=[64, 128, 320, 512], - num_heads=[1, 2, 5, 8], - mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, - norm_layer=partial(LayerNorm, eps=1e-6), - depths=[2, 2, 2, 2], - sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, - drop_path_rate=0.1, - ), + "params": { + "out_channels": [3, 0, 64, 128, 320, 512], + "patch_size": 4, + "embed_dims": [64, 128, 320, 512], + "num_heads": [1, 2, 5, 8], + "mlp_ratios": [4, 4, 4, 4], + "qkv_bias": True, + "norm_layer": partial(LayerNorm, eps=1e-6), + "depths": [2, 2, 2, 2], + "sr_ratios": [8, 4, 2, 1], + "drop_rate": 0.0, + "drop_path_rate": 0.1, + }, }, "mit_b2": { "encoder": MixVisionTransformerEncoder, "pretrained_settings": {"imagenet": get_pretrained_cfg("mit_b2")}, - "params": dict( - out_channels=(3, 0, 64, 128, 320, 512), - patch_size=4, - embed_dims=[64, 128, 320, 512], - num_heads=[1, 2, 5, 8], - mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, - norm_layer=partial(LayerNorm, eps=1e-6), - depths=[3, 4, 6, 3], - sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, - drop_path_rate=0.1, - ), + "params": { + "out_channels": [3, 0, 64, 128, 320, 512], + "patch_size": 4, + "embed_dims": [64, 128, 320, 512], + "num_heads": [1, 2, 5, 8], + "mlp_ratios": [4, 4, 4, 4], + "qkv_bias": True, + "norm_layer": partial(LayerNorm, eps=1e-6), + "depths": [3, 4, 6, 3], + "sr_ratios": [8, 4, 2, 1], + "drop_rate": 0.0, + "drop_path_rate": 0.1, + }, }, "mit_b3": { "encoder": MixVisionTransformerEncoder, "pretrained_settings": {"imagenet": get_pretrained_cfg("mit_b3")}, - "params": dict( - out_channels=(3, 0, 64, 128, 320, 512), - patch_size=4, - embed_dims=[64, 128, 320, 512], - num_heads=[1, 2, 5, 8], - mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, - norm_layer=partial(LayerNorm, eps=1e-6), - depths=[3, 4, 18, 3], - sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, - drop_path_rate=0.1, - ), + "params": { + "out_channels": [3, 0, 64, 128, 320, 512], + "patch_size": 4, + "embed_dims": [64, 128, 320, 512], + "num_heads": [1, 2, 5, 8], + "mlp_ratios": [4, 4, 4, 4], + "qkv_bias": True, + "norm_layer": partial(LayerNorm, eps=1e-6), + "depths": [3, 4, 18, 3], + "sr_ratios": [8, 4, 2, 1], + "drop_rate": 0.0, + "drop_path_rate": 0.1, + }, }, "mit_b4": { "encoder": MixVisionTransformerEncoder, "pretrained_settings": {"imagenet": get_pretrained_cfg("mit_b4")}, - "params": dict( - out_channels=(3, 0, 64, 128, 320, 512), - patch_size=4, - embed_dims=[64, 128, 320, 512], - num_heads=[1, 2, 5, 8], - mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, - norm_layer=partial(LayerNorm, eps=1e-6), - 
depths=[3, 8, 27, 3], - sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, - drop_path_rate=0.1, - ), + "params": { + "out_channels": [3, 0, 64, 128, 320, 512], + "patch_size": 4, + "embed_dims": [64, 128, 320, 512], + "num_heads": [1, 2, 5, 8], + "mlp_ratios": [4, 4, 4, 4], + "qkv_bias": True, + "norm_layer": partial(LayerNorm, eps=1e-6), + "depths": [3, 8, 27, 3], + "sr_ratios": [8, 4, 2, 1], + "drop_rate": 0.0, + "drop_path_rate": 0.1, + }, }, "mit_b5": { "encoder": MixVisionTransformerEncoder, "pretrained_settings": {"imagenet": get_pretrained_cfg("mit_b5")}, - "params": dict( - out_channels=(3, 0, 64, 128, 320, 512), - patch_size=4, - embed_dims=[64, 128, 320, 512], - num_heads=[1, 2, 5, 8], - mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, - norm_layer=partial(LayerNorm, eps=1e-6), - depths=[3, 6, 40, 3], - sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, - drop_path_rate=0.1, - ), + "params": { + "out_channels": [3, 0, 64, 128, 320, 512], + "patch_size": 4, + "embed_dims": [64, 128, 320, 512], + "num_heads": [1, 2, 5, 8], + "mlp_ratios": [4, 4, 4, 4], + "qkv_bias": True, + "norm_layer": partial(LayerNorm, eps=1e-6), + "depths": [3, 6, 40, 3], + "sr_ratios": [8, 4, 2, 1], + "drop_rate": 0.0, + "drop_path_rate": 0.1, + }, }, } diff --git a/segmentation_models_pytorch/encoders/mobilenet.py b/segmentation_models_pytorch/encoders/mobilenet.py index 52c87160..482043e3 100644 --- a/segmentation_models_pytorch/encoders/mobilenet.py +++ b/segmentation_models_pytorch/encoders/mobilenet.py @@ -23,26 +23,33 @@ depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). """ +import torch import torchvision +from typing import Dict, Sequence, List from ._base import EncoderMixin class MobileNetV2Encoder(torchvision.models.MobileNetV2, EncoderMixin): - def __init__(self, out_channels, depth=5, **kwargs): + def __init__( + self, out_channels: List[int], depth: int = 5, output_stride: int = 32, **kwargs + ): super().__init__(**kwargs) + self._depth = depth - self._out_channels = out_channels self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride + del self.classifier - def get_stages(self): + def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: return { - 16: self.features[7:14], - 32: self.features[14:], + 16: [self.features[7:14]], + 32: [self.features[14:]], } - def forward(self, x): + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: features = [x] if self._depth >= 1: @@ -85,6 +92,6 @@ def load_state_dict(self, state_dict, **kwargs): "input_range": [0, 1], } }, - "params": {"out_channels": (3, 16, 24, 32, 96, 1280)}, + "params": {"out_channels": [3, 16, 24, 32, 96, 1280]}, } } diff --git a/segmentation_models_pytorch/encoders/mobileone.py b/segmentation_models_pytorch/encoders/mobileone.py index 3605dcef..20732ce3 100644 --- a/segmentation_models_pytorch/encoders/mobileone.py +++ b/segmentation_models_pytorch/encoders/mobileone.py @@ -3,7 +3,7 @@ # Copyright (C) 2022 Apple Inc. All Rights Reserved. 
# import copy -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Dict, Sequence import torch import torch.nn as nn @@ -298,13 +298,14 @@ class MobileOne(nn.Module, EncoderMixin): def __init__( self, - out_channels, + out_channels: List[int], num_blocks_per_stage: List[int] = [2, 8, 10, 1], width_multipliers: Optional[List[float]] = None, inference_mode: bool = False, use_se: bool = False, depth=5, in_channels=3, + output_stride=32, num_conv_branches: int = 1, ) -> None: """Construct MobileOne model. @@ -320,13 +321,14 @@ def __init__( assert len(width_multipliers) == 4 self.inference_mode = inference_mode - self._out_channels = out_channels self.in_planes = min(64, int(64 * width_multipliers[0])) self.use_se = use_se self.num_conv_branches = num_conv_branches + self._depth = depth self._in_channels = in_channels - self.set_in_channels(self._in_channels) + self._out_channels = out_channels + self._output_stride = output_stride # Build stages self.stage0 = MobileOneBlock( @@ -355,10 +357,10 @@ def __init__( num_se_blocks=num_blocks_per_stage[3] if use_se else 0, ) - def get_stages(self): + def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: return { - 16: self.stage3, - 32: self.stage4, + 16: [self.stage3], + 32: [self.stage4], } def _make_stage( @@ -492,7 +494,7 @@ def reparameterize_model(model: torch.nn.Module) -> nn.Module: } }, "params": { - "out_channels": (3, 48, 48, 128, 256, 1024), + "out_channels": [3, 48, 48, 128, 256, 1024], "width_multipliers": (0.75, 1.0, 1.0, 2.0), "num_conv_branches": 4, "inference_mode": False, @@ -510,7 +512,7 @@ def reparameterize_model(model: torch.nn.Module) -> nn.Module: } }, "params": { - "out_channels": (3, 64, 96, 192, 512, 1280), + "out_channels": [3, 64, 96, 192, 512, 1280], "width_multipliers": (1.5, 1.5, 2.0, 2.5), "inference_mode": False, }, @@ -527,7 +529,7 @@ def reparameterize_model(model: torch.nn.Module) -> nn.Module: } }, "params": { - "out_channels": (3, 64, 96, 256, 640, 2048), + "out_channels": [3, 64, 96, 256, 640, 2048], "width_multipliers": (1.5, 2.0, 2.5, 4.0), "inference_mode": False, }, @@ -544,7 +546,7 @@ def reparameterize_model(model: torch.nn.Module) -> nn.Module: } }, "params": { - "out_channels": (3, 64, 128, 320, 768, 2048), + "out_channels": [3, 64, 128, 320, 768, 2048], "width_multipliers": (2.0, 2.5, 3.0, 4.0), "inference_mode": False, }, @@ -561,7 +563,7 @@ def reparameterize_model(model: torch.nn.Module) -> nn.Module: } }, "params": { - "out_channels": (3, 64, 192, 448, 896, 2048), + "out_channels": [3, 64, 192, 448, 896, 2048], "width_multipliers": (3.0, 3.5, 3.5, 4.0), "use_se": True, "inference_mode": False, diff --git a/segmentation_models_pytorch/encoders/resnet.py b/segmentation_models_pytorch/encoders/resnet.py index bfa37abd..fc1665dd 100644 --- a/segmentation_models_pytorch/encoders/resnet.py +++ b/segmentation_models_pytorch/encoders/resnet.py @@ -24,7 +24,7 @@ """ import torch - +from typing import Dict, Sequence, List from torchvision.models.resnet import ResNet from torchvision.models.resnet import BasicBlock from torchvision.models.resnet import Bottleneck @@ -35,19 +35,23 @@ class ResNetEncoder(ResNet, EncoderMixin): """ResNet encoder implementation.""" - def __init__(self, out_channels, depth=5, **kwargs): + def __init__( + self, out_channels: List[int], depth: int = 5, output_stride: int = 32, **kwargs + ): super().__init__(**kwargs) + self._depth = depth - self._out_channels = out_channels self._in_channels = 3 + self._out_channels = out_channels + 
self._output_stride = output_stride del self.fc del self.avgpool - def get_stages(self): + def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: return { - 16: self.layer3, - 32: self.layer4, + 16: [self.layer3], + 32: [self.layer4], } def forward(self, x: torch.Tensor) -> list[torch.Tensor]: @@ -303,7 +307,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": ResNetEncoder, "pretrained_settings": pretrained_settings["resnet18"], "params": { - "out_channels": (3, 64, 64, 128, 256, 512), + "out_channels": [3, 64, 64, 128, 256, 512], "block": BasicBlock, "layers": [2, 2, 2, 2], }, @@ -312,7 +316,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": ResNetEncoder, "pretrained_settings": pretrained_settings["resnet34"], "params": { - "out_channels": (3, 64, 64, 128, 256, 512), + "out_channels": [3, 64, 64, 128, 256, 512], "block": BasicBlock, "layers": [3, 4, 6, 3], }, @@ -321,7 +325,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": ResNetEncoder, "pretrained_settings": pretrained_settings["resnet50"], "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": Bottleneck, "layers": [3, 4, 6, 3], }, @@ -330,7 +334,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": ResNetEncoder, "pretrained_settings": pretrained_settings["resnet101"], "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": Bottleneck, "layers": [3, 4, 23, 3], }, @@ -339,7 +343,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": ResNetEncoder, "pretrained_settings": pretrained_settings["resnet152"], "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": Bottleneck, "layers": [3, 8, 36, 3], }, @@ -348,7 +352,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": ResNetEncoder, "pretrained_settings": pretrained_settings["resnext50_32x4d"], "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": Bottleneck, "layers": [3, 4, 6, 3], "groups": 32, @@ -359,7 +363,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": ResNetEncoder, "pretrained_settings": pretrained_settings["resnext101_32x4d"], "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": Bottleneck, "layers": [3, 4, 23, 3], "groups": 32, @@ -370,7 +374,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": ResNetEncoder, "pretrained_settings": pretrained_settings["resnext101_32x8d"], "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": Bottleneck, "layers": [3, 4, 23, 3], "groups": 32, @@ -381,7 +385,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": ResNetEncoder, "pretrained_settings": pretrained_settings["resnext101_32x16d"], "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": Bottleneck, "layers": [3, 4, 23, 3], "groups": 32, @@ -392,7 +396,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": ResNetEncoder, "pretrained_settings": pretrained_settings["resnext101_32x32d"], "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": Bottleneck, "layers": [3, 4, 23, 3], "groups": 32, @@ -403,7 +407,7 @@ def load_state_dict(self, state_dict, **kwargs): 
"encoder": ResNetEncoder, "pretrained_settings": pretrained_settings["resnext101_32x48d"], "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": Bottleneck, "layers": [3, 4, 23, 3], "groups": 32, diff --git a/segmentation_models_pytorch/encoders/senet.py b/segmentation_models_pytorch/encoders/senet.py index 123bf30d..7c3a90fd 100644 --- a/segmentation_models_pytorch/encoders/senet.py +++ b/segmentation_models_pytorch/encoders/senet.py @@ -23,6 +23,9 @@ depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). """ +import torch +from typing import List, Dict, Sequence + from pretrainedmodels.models.senet import ( SENet, SEBottleneck, @@ -33,23 +36,30 @@ class SENetEncoder(SENet, EncoderMixin): - def __init__(self, out_channels, depth=5, **kwargs): + def __init__( + self, + out_channels: List[int], + depth: int = 5, + output_stride: int = 32, + **kwargs, + ): super().__init__(**kwargs) - self._out_channels = out_channels self._depth = depth self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride del self.last_linear del self.avg_pool - def get_stages(self): + def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: return { - 16: self.layer3, - 32: self.layer4, + 16: [self.layer3], + 32: [self.layer4], } - def forward(self, x): + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: features = [x] if self._depth >= 1: @@ -156,7 +166,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": SENetEncoder, "pretrained_settings": pretrained_settings["senet154"], "params": { - "out_channels": (3, 128, 256, 512, 1024, 2048), + "out_channels": [3, 128, 256, 512, 1024, 2048], "block": SEBottleneck, "dropout_p": 0.2, "groups": 64, @@ -169,7 +179,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": SENetEncoder, "pretrained_settings": pretrained_settings["se_resnet50"], "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": SEResNetBottleneck, "layers": [3, 4, 6, 3], "downsample_kernel_size": 1, @@ -186,7 +196,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": SENetEncoder, "pretrained_settings": pretrained_settings["se_resnet101"], "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": SEResNetBottleneck, "layers": [3, 4, 23, 3], "downsample_kernel_size": 1, @@ -203,7 +213,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": SENetEncoder, "pretrained_settings": pretrained_settings["se_resnet152"], "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": SEResNetBottleneck, "layers": [3, 8, 36, 3], "downsample_kernel_size": 1, @@ -220,7 +230,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": SENetEncoder, "pretrained_settings": pretrained_settings["se_resnext50_32x4d"], "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": SEResNeXtBottleneck, "layers": [3, 4, 6, 3], "downsample_kernel_size": 1, @@ -237,7 +247,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": SENetEncoder, "pretrained_settings": pretrained_settings["se_resnext101_32x4d"], "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": SEResNeXtBottleneck, "layers": [3, 4, 23, 3], "downsample_kernel_size": 
1, diff --git a/segmentation_models_pytorch/encoders/timm_efficientnet.py b/segmentation_models_pytorch/encoders/timm_efficientnet.py index 3ab0c069..dcc98e28 100644 --- a/segmentation_models_pytorch/encoders/timm_efficientnet.py +++ b/segmentation_models_pytorch/encoders/timm_efficientnet.py @@ -1,7 +1,9 @@ -from functools import partial - +import torch import torch.nn as nn +from typing import List, Dict, Sequence +from functools import partial + from timm.models.efficientnet import EfficientNet from timm.models.efficientnet import decode_arch_def, round_channels, default_cfgs from timm.layers.activations import Swish @@ -95,23 +97,31 @@ def gen_efficientnet_lite_kwargs( class EfficientNetBaseEncoder(EfficientNet, EncoderMixin): - def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): + def __init__( + self, + stage_idxs: List[int], + out_channels: List[int], + depth: int = 5, + output_stride: int = 32, + **kwargs, + ): super().__init__(**kwargs) self._stage_idxs = stage_idxs - self._out_channels = out_channels self._depth = depth self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride del self.classifier - def get_stages(self): + def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: return { - 16: self.blocks[self._stage_idxs[1] : self._stage_idxs[2]], - 32: self.blocks[self._stage_idxs[2] :], + 16: [self.blocks[self._stage_idxs[1] : self._stage_idxs[2]]], + 32: [self.blocks[self._stage_idxs[2] :]], } - def forward(self, x): + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: features = [x] if self._depth >= 1: @@ -200,8 +210,8 @@ def prepare_settings(settings): ), }, "params": { - "out_channels": (3, 32, 24, 40, 112, 320), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 32, 24, 40, 112, 320], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.0, "depth_multiplier": 1.0, "drop_rate": 0.2, @@ -221,8 +231,8 @@ def prepare_settings(settings): ), }, "params": { - "out_channels": (3, 32, 24, 40, 112, 320), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 32, 24, 40, 112, 320], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.0, "depth_multiplier": 1.1, "drop_rate": 0.2, @@ -242,8 +252,8 @@ def prepare_settings(settings): ), }, "params": { - "out_channels": (3, 32, 24, 48, 120, 352), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 32, 24, 48, 120, 352], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.1, "depth_multiplier": 1.2, "drop_rate": 0.3, @@ -263,8 +273,8 @@ def prepare_settings(settings): ), }, "params": { - "out_channels": (3, 40, 32, 48, 136, 384), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 40, 32, 48, 136, 384], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.2, "depth_multiplier": 1.4, "drop_rate": 0.3, @@ -284,8 +294,8 @@ def prepare_settings(settings): ), }, "params": { - "out_channels": (3, 48, 32, 56, 160, 448), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 48, 32, 56, 160, 448], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.4, "depth_multiplier": 1.8, "drop_rate": 0.4, @@ -305,8 +315,8 @@ def prepare_settings(settings): ), }, "params": { - "out_channels": (3, 48, 40, 64, 176, 512), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 48, 40, 64, 176, 512], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.6, "depth_multiplier": 2.2, "drop_rate": 0.4, @@ -326,8 +336,8 @@ def prepare_settings(settings): ), }, "params": { - "out_channels": (3, 56, 40, 72, 200, 576), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 56, 40, 72, 200, 576], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.8, 
"depth_multiplier": 2.6, "drop_rate": 0.5, @@ -347,8 +357,8 @@ def prepare_settings(settings): ), }, "params": { - "out_channels": (3, 64, 48, 80, 224, 640), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 64, 48, 80, 224, 640], + "stage_idxs": [2, 3, 5], "channel_multiplier": 2.0, "depth_multiplier": 3.1, "drop_rate": 0.5, @@ -365,8 +375,8 @@ def prepare_settings(settings): ), }, "params": { - "out_channels": (3, 72, 56, 88, 248, 704), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 72, 56, 88, 248, 704], + "stage_idxs": [2, 3, 5], "channel_multiplier": 2.2, "depth_multiplier": 3.6, "drop_rate": 0.5, @@ -383,8 +393,8 @@ def prepare_settings(settings): ), }, "params": { - "out_channels": (3, 136, 104, 176, 480, 1376), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 136, 104, 176, 480, 1376], + "stage_idxs": [2, 3, 5], "channel_multiplier": 4.3, "depth_multiplier": 5.3, "drop_rate": 0.5, @@ -398,8 +408,8 @@ def prepare_settings(settings): ) }, "params": { - "out_channels": (3, 32, 24, 40, 112, 320), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 32, 24, 40, 112, 320], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.0, "depth_multiplier": 1.0, "drop_rate": 0.2, @@ -413,8 +423,8 @@ def prepare_settings(settings): ) }, "params": { - "out_channels": (3, 32, 24, 40, 112, 320), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 32, 24, 40, 112, 320], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.0, "depth_multiplier": 1.1, "drop_rate": 0.2, @@ -428,8 +438,8 @@ def prepare_settings(settings): ) }, "params": { - "out_channels": (3, 32, 24, 48, 120, 352), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 32, 24, 48, 120, 352], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.1, "depth_multiplier": 1.2, "drop_rate": 0.3, @@ -443,8 +453,8 @@ def prepare_settings(settings): ) }, "params": { - "out_channels": (3, 32, 32, 48, 136, 384), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 32, 32, 48, 136, 384], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.2, "depth_multiplier": 1.4, "drop_rate": 0.3, @@ -458,8 +468,8 @@ def prepare_settings(settings): ) }, "params": { - "out_channels": (3, 32, 32, 56, 160, 448), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 32, 32, 56, 160, 448], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.4, "depth_multiplier": 1.8, "drop_rate": 0.4, diff --git a/segmentation_models_pytorch/encoders/timm_sknet.py b/segmentation_models_pytorch/encoders/timm_sknet.py index 9b340d6e..a28e6330 100644 --- a/segmentation_models_pytorch/encoders/timm_sknet.py +++ b/segmentation_models_pytorch/encoders/timm_sknet.py @@ -1,25 +1,36 @@ -from ._base import EncoderMixin +import torch +from typing import Dict, List, Sequence from timm.models.resnet import ResNet from timm.models.sknet import SelectiveKernelBottleneck, SelectiveKernelBasic +from ._base import EncoderMixin + class SkNetEncoder(ResNet, EncoderMixin): - def __init__(self, out_channels, depth=5, **kwargs): + def __init__( + self, + out_channels: List[int], + depth: int = 5, + output_stride: int = 32, + **kwargs, + ): super().__init__(**kwargs) + self._depth = depth - self._out_channels = out_channels self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride del self.fc del self.global_pool - def get_stages(self): + def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: return { - 16: self.layer3, - 32: self.layer4, + 16: [self.layer3], + 32: [self.layer4], } - def forward(self, x): + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: features = [x] if self._depth 
>= 1: @@ -83,7 +94,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": SkNetEncoder, "pretrained_settings": pretrained_settings["timm-skresnet18"], "params": { - "out_channels": (3, 64, 64, 128, 256, 512), + "out_channels": [3, 64, 64, 128, 256, 512], "block": SelectiveKernelBasic, "layers": [2, 2, 2, 2], "zero_init_last": False, @@ -94,7 +105,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": SkNetEncoder, "pretrained_settings": pretrained_settings["timm-skresnet34"], "params": { - "out_channels": (3, 64, 64, 128, 256, 512), + "out_channels": [3, 64, 64, 128, 256, 512], "block": SelectiveKernelBasic, "layers": [3, 4, 6, 3], "zero_init_last": False, @@ -105,7 +116,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": SkNetEncoder, "pretrained_settings": pretrained_settings["timm-skresnext50_32x4d"], "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": SelectiveKernelBottleneck, "layers": [3, 4, 6, 3], "zero_init_last": False, diff --git a/segmentation_models_pytorch/encoders/vgg.py b/segmentation_models_pytorch/encoders/vgg.py index 34eb1091..537b03dd 100644 --- a/segmentation_models_pytorch/encoders/vgg.py +++ b/segmentation_models_pytorch/encoders/vgg.py @@ -23,10 +23,13 @@ depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). """ +import torch import torch.nn as nn from torchvision.models.vgg import VGG from torchvision.models.vgg import make_layers +from typing import List, Union + from ._base import EncoderMixin # fmt: off @@ -40,11 +43,21 @@ class VGGEncoder(VGG, EncoderMixin): - def __init__(self, out_channels, config, batch_norm=False, depth=5, **kwargs): + def __init__( + self, + out_channels: List[int], + config: List[Union[int, str]], + batch_norm: bool = False, + depth: int = 5, + output_stride: int = 32, + **kwargs, + ): super().__init__(make_layers(config, batch_norm=batch_norm), **kwargs) - self._out_channels = out_channels + self._depth = depth self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride del self.classifier @@ -54,7 +67,7 @@ def make_dilated(self, *args, **kwargs): " operations for downsampling!" 
) - def forward(self, x): + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: # collect stages stages = [] stage_modules = [] @@ -177,7 +190,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": VGGEncoder, "pretrained_settings": pretrained_settings["vgg11"], "params": { - "out_channels": (64, 128, 256, 512, 512, 512), + "out_channels": [64, 128, 256, 512, 512, 512], "config": cfg["A"], "batch_norm": False, }, @@ -186,7 +199,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": VGGEncoder, "pretrained_settings": pretrained_settings["vgg11_bn"], "params": { - "out_channels": (64, 128, 256, 512, 512, 512), + "out_channels": [64, 128, 256, 512, 512, 512], "config": cfg["A"], "batch_norm": True, }, @@ -195,7 +208,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": VGGEncoder, "pretrained_settings": pretrained_settings["vgg13"], "params": { - "out_channels": (64, 128, 256, 512, 512, 512), + "out_channels": [64, 128, 256, 512, 512, 512], "config": cfg["B"], "batch_norm": False, }, @@ -204,7 +217,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": VGGEncoder, "pretrained_settings": pretrained_settings["vgg13_bn"], "params": { - "out_channels": (64, 128, 256, 512, 512, 512), + "out_channels": [64, 128, 256, 512, 512, 512], "config": cfg["B"], "batch_norm": True, }, @@ -213,7 +226,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": VGGEncoder, "pretrained_settings": pretrained_settings["vgg16"], "params": { - "out_channels": (64, 128, 256, 512, 512, 512), + "out_channels": [64, 128, 256, 512, 512, 512], "config": cfg["D"], "batch_norm": False, }, @@ -222,7 +235,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": VGGEncoder, "pretrained_settings": pretrained_settings["vgg16_bn"], "params": { - "out_channels": (64, 128, 256, 512, 512, 512), + "out_channels": [64, 128, 256, 512, 512, 512], "config": cfg["D"], "batch_norm": True, }, @@ -231,7 +244,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": VGGEncoder, "pretrained_settings": pretrained_settings["vgg19"], "params": { - "out_channels": (64, 128, 256, 512, 512, 512), + "out_channels": [64, 128, 256, 512, 512, 512], "config": cfg["E"], "batch_norm": False, }, @@ -240,7 +253,7 @@ def load_state_dict(self, state_dict, **kwargs): "encoder": VGGEncoder, "pretrained_settings": pretrained_settings["vgg19_bn"], "params": { - "out_channels": (64, 128, 256, 512, 512, 512), + "out_channels": [64, 128, 256, 512, 512, 512], "config": cfg["E"], "batch_norm": True, }, diff --git a/segmentation_models_pytorch/encoders/xception.py b/segmentation_models_pytorch/encoders/xception.py index 7ba1cdd6..ab78d6ac 100644 --- a/segmentation_models_pytorch/encoders/xception.py +++ b/segmentation_models_pytorch/encoders/xception.py @@ -4,12 +4,13 @@ class XceptionEncoder(Xception, EncoderMixin): - def __init__(self, out_channels, *args, depth=5, **kwargs): + def __init__(self, out_channels, *args, depth=5, output_stride=32, **kwargs): super().__init__(*args, **kwargs) - self._out_channels = out_channels self._depth = depth self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride # modify padding to maintain output shape self.conv1.padding = (1, 1) @@ -93,6 +94,6 @@ def load_state_dict(self, state_dict): "xception": { "encoder": XceptionEncoder, "pretrained_settings": pretrained_settings["xception"], - "params": {"out_channels": (3, 64, 128, 256, 728, 2048)}, + "params": {"out_channels": [3, 64, 128, 256, 728, 2048]}, } } diff --git 
a/tests/encoders/test_pretrainedmodels_encoders.py b/tests/encoders/test_pretrainedmodels_encoders.py index e77c3652..0b631601 100644 --- a/tests/encoders/test_pretrainedmodels_encoders.py +++ b/tests/encoders/test_pretrainedmodels_encoders.py @@ -18,7 +18,7 @@ class TestDPNEncoder(base.BaseEncoderTester): def get_tiny_encoder(self): params = { - "stage_idxs": (2, 3, 4, 5), + "stage_idxs": [2, 3, 4, 5], "out_channels": None, "groups": 2, "inc_sec": (2, 2, 2, 2), diff --git a/tests/encoders/test_smp_encoders.py b/tests/encoders/test_smp_encoders.py index f65a61b8..2bb061f8 100644 --- a/tests/encoders/test_smp_encoders.py +++ b/tests/encoders/test_smp_encoders.py @@ -30,7 +30,7 @@ class TestMixTransformerEncoder(base.BaseEncoderTester): def get_tiny_encoder(self): params = { - "out_channels": (3, 0, 4, 4, 4, 4), + "out_channels": [3, 0, 4, 4, 4, 4], "patch_size": 4, "embed_dims": [4, 4, 4, 4], "num_heads": [1, 1, 1, 1], diff --git a/tests/encoders/test_torchvision_encoders.py b/tests/encoders/test_torchvision_encoders.py index 7c979689..95500c9e 100644 --- a/tests/encoders/test_torchvision_encoders.py +++ b/tests/encoders/test_torchvision_encoders.py @@ -77,7 +77,7 @@ class TestVggEncoder(base.BaseEncoderTester): def get_tiny_encoder(self): params = { - "out_channels": (4, 4, 4, 4, 4, 4), + "out_channels": [4, 4, 4, 4, 4, 4], "config": [4, "M", 4, "M", 4, "M", 4, "M", 4, "M"], "batch_norm": False, } From 257da0be20416ebc37f5c67f26c35819719c5443 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 14 Jan 2025 13:03:26 +0000 Subject: [PATCH 33/57] Fix scripting --- segmentation_models_pytorch/base/utils.py | 1 + segmentation_models_pytorch/encoders/timm_universal.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/segmentation_models_pytorch/base/utils.py b/segmentation_models_pytorch/base/utils.py index 3fcba739..a0d41943 100644 --- a/segmentation_models_pytorch/base/utils.py +++ b/segmentation_models_pytorch/base/utils.py @@ -1,6 +1,7 @@ import torch +@torch.jit.unused def is_torch_compiling(): try: return torch.compiler.is_compiling() diff --git a/segmentation_models_pytorch/encoders/timm_universal.py b/segmentation_models_pytorch/encoders/timm_universal.py index 5a48273a..beea8794 100644 --- a/segmentation_models_pytorch/encoders/timm_universal.py +++ b/segmentation_models_pytorch/encoders/timm_universal.py @@ -194,7 +194,7 @@ def output_stride(self) -> int: Returns: int: The effective output stride. 
""" - return min(self._output_stride, 2**self._depth) + return int(min(self._output_stride, 2**self._depth)) def load_state_dict(self, state_dict, **kwargs): # for compatibility of weights for From d4d4cf6dc1e66000444d88fcf6d2e63eb4cb9ea9 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 14 Jan 2025 13:03:44 +0000 Subject: [PATCH 34/57] Add test for torch script --- tests/models/base.py | 69 +++++++++++++++++++------------------------- 1 file changed, 29 insertions(+), 40 deletions(-) diff --git a/tests/models/base.py b/tests/models/base.py index d6e19fd0..58048ff4 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -56,7 +56,11 @@ def decoder_channels(self): return None @lru_cache - def _get_sample(self, batch_size=1, num_channels=3, height=32, width=32): + def _get_sample(self, batch_size=None, num_channels=None, height=None, width=None): + batch_size = batch_size or self.default_batch_size + num_channels = num_channels or self.default_num_channels + height = height or self.default_height + width = width or self.default_width return torch.rand(batch_size, num_channels, height, width) @lru_cache @@ -66,12 +70,7 @@ def get_default_model(self): return model def test_forward_backward(self): - sample = self._get_sample( - batch_size=self.default_batch_size, - num_channels=self.default_num_channels, - height=self.default_height, - width=self.default_width, - ).to(default_device) + sample = self._get_sample().to(default_device) model = self.get_default_model() @@ -111,12 +110,7 @@ def test_in_channels_and_depth_and_out_classes( .eval() ) - sample = self._get_sample( - batch_size=self.default_batch_size, - num_channels=in_channels, - height=self.default_height, - width=self.default_width, - ).to(default_device) + sample = self._get_sample(num_channels=in_channels).to(default_device) # check in channels correctly set with torch.inference_mode(): @@ -145,12 +139,7 @@ def test_classification_head(self): self.assertIsInstance(model.classification_head[3], torch.nn.Linear) self.assertIsInstance(model.classification_head[4].activation, torch.nn.Sigmoid) - sample = self._get_sample( - batch_size=self.default_batch_size, - num_channels=self.default_num_channels, - height=self.default_height, - width=self.default_width, - ).to(default_device) + sample = self._get_sample().to(default_device) with torch.inference_mode(): _, cls_probs = model(sample) @@ -163,8 +152,6 @@ def test_any_resolution(self): self.skipTest("Model requires divisible input shape") sample = self._get_sample( - batch_size=self.default_batch_size, - num_channels=self.default_num_channels, height=self.default_height + 3, width=self.default_width + 7, ).to(default_device) @@ -193,12 +180,7 @@ def test_save_load_with_hub_mixin(self): readme = f.read() # check inference is correct - sample = self._get_sample( - batch_size=self.default_batch_size, - num_channels=self.default_num_channels, - height=self.default_height, - width=self.default_width, - ).to(default_device) + sample = self._get_sample().to(default_device) with torch.inference_mode(): output = model(sample) @@ -242,12 +224,7 @@ def test_compile(self): if not check_run_test_on_diff_or_main(self.files_for_diff): self.skipTest("No diff and not on `main`.") - sample = self._get_sample( - batch_size=self.default_batch_size, - num_channels=self.default_num_channels, - height=self.default_height, - width=self.default_width, - ).to(default_device) + sample = self._get_sample().to(default_device) model = self.get_default_model() compiled_model = torch.compile(model, 
fullgraph=True, dynamic=True) @@ -260,13 +237,7 @@ def test_torch_export(self): if not check_run_test_on_diff_or_main(self.files_for_diff): self.skipTest("No diff and not on `main`.") - sample = self._get_sample( - batch_size=self.default_batch_size, - num_channels=self.default_num_channels, - height=self.default_height, - width=self.default_width, - ).to(default_device) - + sample = self._get_sample().to(default_device) model = self.get_default_model() model.eval() @@ -282,3 +253,21 @@ def test_torch_export(self): self.assertEqual(eager_output.shape, exported_output.shape) torch.testing.assert_close(eager_output, exported_output) + + @pytest.mark.torch_script + def test_torch_script(self): + if not check_run_test_on_diff_or_main(self.files_for_diff): + self.skipTest("No diff and not on `main`.") + + sample = self._get_sample().to(default_device) + model = self.get_default_model() + model.eval() + + scripted_model = torch.jit.script(model) + + with torch.inference_mode(): + scripted_output = scripted_model(sample) + eager_output = model(sample) + + self.assertEqual(scripted_output.shape, eager_output.shape) + torch.testing.assert_close(scripted_output, eager_output) From 3cb81983dbc7b1a2fdf8a1498867dc6aee6ea7ab Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 14 Jan 2025 13:05:05 +0000 Subject: [PATCH 35/57] Add torch_script test to CI --- .github/workflows/tests.yml | 15 +++++++++++++++ pyproject.toml | 1 + 2 files changed, 16 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3ebe6143..fbce9a45 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -105,6 +105,21 @@ jobs: - name: Test with PyTest run: uv run pytest -v -rsx -n 2 -m "torch_export" + test_torch_script: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: astral-sh/setup-uv@v5 + with: + python-version: "3.10" + - name: Install dependencies + run: uv pip install -r requirements/required.txt -r requirements/test.txt + - name: Show installed packages + run: uv pip list + - name: Test with PyTest + run: uv run pytest -v -rsx -n 2 -m "torch_script" + minimum: runs-on: ubuntu-latest steps: diff --git a/pyproject.toml b/pyproject.toml index 8d9b2078..c7cf1958 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ markers = [ "logits_match", "compile", "torch_export", + "torch_script", ] [tool.coverage.run] From 4f65d8f1e4a81ce7543edc280b202ceb16f9980b Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 14 Jan 2025 13:08:01 +0000 Subject: [PATCH 36/57] Fix --- segmentation_models_pytorch/decoders/upernet/decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/segmentation_models_pytorch/decoders/upernet/decoder.py b/segmentation_models_pytorch/decoders/upernet/decoder.py index 092de36a..ebcb3d10 100644 --- a/segmentation_models_pytorch/decoders/upernet/decoder.py +++ b/segmentation_models_pytorch/decoders/upernet/decoder.py @@ -110,7 +110,7 @@ def __init__( use_batchnorm=True, ) - def forward(self, *features): + def forward(self, features): output_size = features[0].shape[2:] target_size = [size // 4 for size in output_size] From 70776eaa46775a5bf5fbac03744101eb2237b2c9 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 14 Jan 2025 13:14:10 +0000 Subject: [PATCH 37/57] Fix timm-effnet encoders --- .../encoders/timm_efficientnet.py | 42 ++++++++++++------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git 
a/segmentation_models_pytorch/encoders/timm_efficientnet.py b/segmentation_models_pytorch/encoders/timm_efficientnet.py index dcc98e28..9372ba31 100644 --- a/segmentation_models_pytorch/encoders/timm_efficientnet.py +++ b/segmentation_models_pytorch/encoders/timm_efficientnet.py @@ -156,33 +156,47 @@ def load_state_dict(self, state_dict, **kwargs): class EfficientNetEncoder(EfficientNetBaseEncoder): def __init__( self, - stage_idxs, - out_channels, - depth=5, - channel_multiplier=1.0, - depth_multiplier=1.0, - drop_rate=0.2, + stage_idxs: List[int], + out_channels: List[int], + depth: int = 5, + channel_multiplier: float = 1.0, + depth_multiplier: float = 1.0, + drop_rate: float = 0.2, + output_stride: int = 32, ): kwargs = get_efficientnet_kwargs( channel_multiplier, depth_multiplier, drop_rate ) - super().__init__(stage_idxs, out_channels, depth, **kwargs) + super().__init__( + stage_idxs=stage_idxs, + depth=depth, + out_channels=out_channels, + output_stride=output_stride, + **kwargs, + ) class EfficientNetLiteEncoder(EfficientNetBaseEncoder): def __init__( self, - stage_idxs, - out_channels, - depth=5, - channel_multiplier=1.0, - depth_multiplier=1.0, - drop_rate=0.2, + stage_idxs: List[int], + out_channels: List[int], + depth: int = 5, + channel_multiplier: float = 1.0, + depth_multiplier: float = 1.0, + drop_rate: float = 0.2, + output_stride: int = 32, ): kwargs = gen_efficientnet_lite_kwargs( channel_multiplier, depth_multiplier, drop_rate ) - super().__init__(stage_idxs, out_channels, depth, **kwargs) + super().__init__( + stage_idxs=stage_idxs, + depth=depth, + out_channels=out_channels, + output_stride=output_stride, + **kwargs, + ) def prepare_settings(settings): From 31bee792db1a06c46747b8e270a2473fa1fc50d2 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 14 Jan 2025 14:22:07 +0000 Subject: [PATCH 38/57] Make from_pretrained strict by default --- segmentation_models_pytorch/base/hub_mixin.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/segmentation_models_pytorch/base/hub_mixin.py b/segmentation_models_pytorch/base/hub_mixin.py index 1c9e8052..a18380d1 100644 --- a/segmentation_models_pytorch/base/hub_mixin.py +++ b/segmentation_models_pytorch/base/hub_mixin.py @@ -121,7 +121,9 @@ def config(self) -> dict: @wraps(PyTorchModelHubMixin.from_pretrained) -def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs): +def from_pretrained( + pretrained_model_name_or_path: str, *args, strict: bool = True, **kwargs +): config_path = Path(pretrained_model_name_or_path) / "config.json" if not config_path.exists(): config_path = hf_hub_download( @@ -137,7 +139,9 @@ def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs): import segmentation_models_pytorch as smp model_class = getattr(smp, model_class_name) - return model_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + return model_class.from_pretrained( + pretrained_model_name_or_path, *args, **kwargs, strict=strict + ) def supports_config_loading(func): From 556b3aa2a5fdd0b4e7aaaeff0f33aef6e411e349 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 14 Jan 2025 14:22:30 +0000 Subject: [PATCH 39/57] Fix DeepLabV3 BC --- .../decoders/deeplabv3/decoder.py | 19 +------------------ .../decoders/deeplabv3/model.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py index 15280043..6a801a70 
100644 --- a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py @@ -40,7 +40,7 @@ __all__ = ["DeepLabV3Decoder", "DeepLabV3PlusDecoder"] -class DeepLabV3Decoder(nn.Sequential): +class DeepLabV3Decoder(nn.Module): def __init__( self, in_channels: int, @@ -69,23 +69,6 @@ def forward(self, features: List[torch.Tensor]) -> torch.Tensor: x = self.relu(x) return x - def load_state_dict(self, state_dict, *args, **kwargs): - # For backward compatibility, previously this module was Sequential - # and was not scriptable. - keys = list(state_dict.keys()) - for key in keys: - new_key = key - if key.startswith("0."): - new_key = "aspp." + key[2:] - elif key.startswith("1."): - new_key = "conv." + key[2:] - elif key.startswith("2."): - new_key = "bn." + key[2:] - elif key.startswith("3."): - new_key = "relu." + key[2:] - state_dict[new_key] = state_dict.pop(key) - super().load_state_dict(state_dict, *args, **kwargs) - class DeepLabV3PlusDecoder(nn.Module): def __init__( diff --git a/segmentation_models_pytorch/decoders/deeplabv3/model.py b/segmentation_models_pytorch/decoders/deeplabv3/model.py index 654e38d4..c14776f3 100644 --- a/segmentation_models_pytorch/decoders/deeplabv3/model.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/model.py @@ -121,6 +121,21 @@ def __init__( else: self.classification_head = None + def load_state_dict(self, state_dict, *args, **kwargs): + # For backward compatibility, previously Decoder module was Sequential + # and was not scriptable. + keys = list(state_dict.keys()) + for key in keys: + new_key = key + if key.startswith("decoder.0."): + new_key = key.replace("decoder.0.", "decoder.aspp.") + elif key.startswith("decoder.1."): + new_key = key.replace("decoder.1.", "decoder.conv.") + elif key.startswith("decoder.2."): + new_key = key.replace("decoder.2.", "decoder.bn.") + state_dict[new_key] = state_dict.pop(key) + return super().load_state_dict(state_dict, *args, **kwargs) + class DeepLabV3Plus(SegmentationModel): """DeepLabV3+ implementation from "Encoder-Decoder with Atrous Separable From f70d861712dedd46e43cce5aaab44e2d1a08e9e4 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 14 Jan 2025 18:16:48 +0000 Subject: [PATCH 40/57] Fix scripting for encoders --- segmentation_models_pytorch/encoders/_base.py | 4 ++ .../encoders/densenet.py | 32 +++++---- segmentation_models_pytorch/encoders/dpn.py | 2 + .../encoders/efficientnet.py | 2 + .../encoders/inceptionresnetv2.py | 3 + .../encoders/inceptionv4.py | 34 ++++------ .../encoders/mix_transformer.py | 65 ++++++++++--------- .../encoders/mobilenet.py | 28 ++++---- .../encoders/mobileone.py | 6 +- segmentation_models_pytorch/encoders/senet.py | 8 ++- .../encoders/timm_efficientnet.py | 11 ++-- .../encoders/timm_universal.py | 4 ++ segmentation_models_pytorch/encoders/vgg.py | 36 ++++++---- tests/encoders/base.py | 21 ++++++ .../test_pretrainedmodels_encoders.py | 6 +- tests/encoders/test_torchvision_encoders.py | 4 +- 16 files changed, 156 insertions(+), 110 deletions(-) diff --git a/segmentation_models_pytorch/encoders/_base.py b/segmentation_models_pytorch/encoders/_base.py index 335efed0..98c431fb 100644 --- a/segmentation_models_pytorch/encoders/_base.py +++ b/segmentation_models_pytorch/encoders/_base.py @@ -10,6 +10,10 @@ class EncoderMixin: - patching first convolution for arbitrary input channels """ + _is_torch_scriptable = True + _is_torch_exportable = True + _is_torch_compilable = True + def __init__(self): self._depth = 5 
self._in_channels = 3 diff --git a/segmentation_models_pytorch/encoders/densenet.py b/segmentation_models_pytorch/encoders/densenet.py index c496d048..aa61db35 100644 --- a/segmentation_models_pytorch/encoders/densenet.py +++ b/segmentation_models_pytorch/encoders/densenet.py @@ -24,8 +24,6 @@ """ import re -import torch -import torch.nn as nn from torchvision.models.densenet import DenseNet @@ -47,15 +45,6 @@ def make_dilated(self, *args, **kwargs): "due to pooling operation for downsampling!" ) - def apply_transition( - self, transition: torch.nn.Sequential, x: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - for module in transition: - x = module(x) - if isinstance(module, nn.ReLU): - intermediate = x - return x, intermediate - def forward(self, x): features = [x] @@ -68,20 +57,29 @@ def forward(self, x): if self._depth >= 2: x = self.features.pool0(x) x = self.features.denseblock1(x) - x, intermediate = self.apply_transition(self.features.transition1, x) - features.append(intermediate) + x = self.features.transition1.norm(x) + x = self.features.transition1.relu(x) + features.append(x) if self._depth >= 3: + x = self.features.transition1.conv(x) + x = self.features.transition1.pool(x) x = self.features.denseblock2(x) - x, intermediate = self.apply_transition(self.features.transition2, x) - features.append(intermediate) + x = self.features.transition2.norm(x) + x = self.features.transition2.relu(x) + features.append(x) if self._depth >= 4: + x = self.features.transition2.conv(x) + x = self.features.transition2.pool(x) x = self.features.denseblock3(x) - x, intermediate = self.apply_transition(self.features.transition3, x) - features.append(intermediate) + x = self.features.transition3.norm(x) + x = self.features.transition3.relu(x) + features.append(x) if self._depth >= 5: + x = self.features.transition3.conv(x) + x = self.features.transition3.pool(x) x = self.features.denseblock4(x) x = self.features.norm5(x) features.append(x) diff --git a/segmentation_models_pytorch/encoders/dpn.py b/segmentation_models_pytorch/encoders/dpn.py index e25acd7c..bf768c8b 100644 --- a/segmentation_models_pytorch/encoders/dpn.py +++ b/segmentation_models_pytorch/encoders/dpn.py @@ -34,6 +34,8 @@ class DPNEncoder(DPN, EncoderMixin): + _is_torch_scriptable = False + def __init__( self, stage_idxs: List[int], diff --git a/segmentation_models_pytorch/encoders/efficientnet.py b/segmentation_models_pytorch/encoders/efficientnet.py index bb56abf2..c0483b39 100644 --- a/segmentation_models_pytorch/encoders/efficientnet.py +++ b/segmentation_models_pytorch/encoders/efficientnet.py @@ -33,6 +33,8 @@ class EfficientNetEncoder(EfficientNet, EncoderMixin): + _is_torch_scriptable = False + def __init__( self, stage_idxs: List[int], diff --git a/segmentation_models_pytorch/encoders/inceptionresnetv2.py b/segmentation_models_pytorch/encoders/inceptionresnetv2.py index ced9c7f5..df3da839 100644 --- a/segmentation_models_pytorch/encoders/inceptionresnetv2.py +++ b/segmentation_models_pytorch/encoders/inceptionresnetv2.py @@ -54,6 +54,9 @@ def __init__( if isinstance(m, nn.MaxPool2d): m.padding = (1, 1) + # for torchscript, block8 does not have relu defined + self.block8.relu = nn.Identity() + # remove linear layers del self.avgpool_1a del self.last_linear diff --git a/segmentation_models_pytorch/encoders/inceptionv4.py b/segmentation_models_pytorch/encoders/inceptionv4.py index ccac34f9..fa8da811 100644 --- a/segmentation_models_pytorch/encoders/inceptionv4.py +++ b/segmentation_models_pytorch/encoders/inceptionv4.py 
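For reference, the `_is_torch_scriptable` / `_is_torch_exportable` / `_is_torch_compilable` flags added in this patch mark which serialization paths an encoder is expected to support, and the same three paths are exercised for full models in `tests/models/base.py`. A minimal sketch of those paths at the model level, using only calls that appear in the test suite (the `Unet`/`resnet34` combination, the input size, and `strict=True` are illustrative assumptions, not taken from the patch):

.. code-block:: python

    import torch
    import segmentation_models_pytorch as smp

    model = smp.Unet(encoder_name="resnet34", encoder_weights=None).eval()
    sample = torch.rand(1, 3, 224, 224)

    # TorchScript path, guarded by `_is_torch_scriptable` in the tests
    scripted = torch.jit.script(model)

    # torch.compile path, with the same flags the encoder tests use
    compiled = torch.compile(model, fullgraph=True, dynamic=True)

    # torch.export path (the test suite requires torch >= 2.4 for this)
    exported = torch.export.export(model, args=(sample,), strict=True)

    with torch.inference_mode():
        scripted(sample)
        compiled(sample)
        exported.module()(sample)
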
@@ -35,7 +35,7 @@ class InceptionV4Encoder(InceptionV4, EncoderMixin): def __init__( self, - stage_idxs: List[int], + out_indexes: List[int], out_channels: List[int], depth: int = 5, output_stride: int = 32, @@ -43,11 +43,11 @@ def __init__( ): super().__init__(**kwargs) - self._stage_idxs = stage_idxs self._depth = depth self._in_channels = 3 self._out_channels = out_channels self._output_stride = output_stride + self._out_indexes = out_indexes # correct paddings for m in self.modules(): @@ -67,28 +67,22 @@ def make_dilated(self, *args, **kwargs): ) def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + depth = 0 features = [x] - if self._depth >= 1: - x = self.features[: self._stage_idxs[0]](x) - features.append(x) + for i, module in enumerate(self.features): + x = module(x) - if self._depth >= 2: - x = self.features[self._stage_idxs[0] : self._stage_idxs[1]](x) - features.append(x) + if i in self._out_indexes: + features.append(x) + depth += 1 - if self._depth >= 3: - x = self.features[self._stage_idxs[1] : self._stage_idxs[2]](x) - features.append(x) - - if self._depth >= 4: - x = self.features[self._stage_idxs[2] : self._stage_idxs[3]](x) - features.append(x) - - if self._depth >= 5: - x = self.features[self._stage_idxs[3] :](x) - features.append(x) + # torchscript does not support break in cycle, so we just + # go over all modules and then slice number of features + if not torch.jit.is_scripting() and depth > self._depth: + break + features = features[: self._depth + 1] return features def load_state_dict(self, state_dict, **kwargs): @@ -121,7 +115,7 @@ def load_state_dict(self, state_dict, **kwargs): }, }, "params": { - "stage_idxs": [3, 5, 9, 15], + "out_indexes": [2, 4, 8, 14], "out_channels": [3, 64, 192, 384, 1024, 1536], "num_classes": 1001, }, diff --git a/segmentation_models_pytorch/encoders/mix_transformer.py b/segmentation_models_pytorch/encoders/mix_transformer.py index 2a1e068d..479c3f09 100644 --- a/segmentation_models_pytorch/encoders/mix_transformer.py +++ b/segmentation_models_pytorch/encoders/mix_transformer.py @@ -11,20 +11,22 @@ import math import torch import torch.nn as nn +import torch.nn.functional as F from functools import partial +from typing import Dict, Sequence, List from timm.layers import DropPath, to_2tuple, trunc_normal_ class LayerNorm(nn.LayerNorm): - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: if x.ndim == 4: - B, C, H, W = x.shape - x = x.view(B, C, -1).transpose(1, 2) - x = super().forward(x) - x = x.transpose(1, 2).view(B, C, H, W) + batch_size, channels, height, width = x.shape + x = x.view(batch_size, channels, -1).transpose(1, 2) + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + x = x.transpose(1, 2).view(batch_size, channels, height, width) else: - x = super().forward(x) + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) return x @@ -60,9 +62,9 @@ def _init_weights(self, m): if m.bias is not None: m.bias.data.zero_() - def forward(self, x, H, W): + def forward(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor: x = self.fc1(x) - x = self.dwconv(x, H, W) + x = self.dwconv(x, height, width) x = self.act(x) x = self.drop(x) x = self.fc2(x) @@ -101,6 +103,10 @@ def __init__( if sr_ratio > 1: self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) self.norm = LayerNorm(dim) + else: + # for torchscript compatibility + self.sr = nn.Identity() + self.norm = nn.Identity() self.apply(self._init_weights) @@ -119,27 +125,27 @@ def 
_init_weights(self, m): if m.bias is not None: m.bias.data.zero_() - def forward(self, x, H, W): - B, N, C = x.shape + def forward(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor: + batch_size, N, C = x.shape q = ( self.q(x) - .reshape(B, N, self.num_heads, C // self.num_heads) + .reshape(batch_size, N, self.num_heads, C // self.num_heads) .permute(0, 2, 1, 3) ) if self.sr_ratio > 1: - x_ = x.permute(0, 2, 1).reshape(B, C, H, W) - x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) + x_ = x.permute(0, 2, 1).reshape(batch_size, C, height, width) + x_ = self.sr(x_).reshape(batch_size, C, -1).permute(0, 2, 1) x_ = self.norm(x_) kv = ( self.kv(x_) - .reshape(B, -1, 2, self.num_heads, C // self.num_heads) + .reshape(batch_size, -1, 2, self.num_heads, C // self.num_heads) .permute(2, 0, 3, 1, 4) ) else: kv = ( self.kv(x) - .reshape(B, -1, 2, self.num_heads, C // self.num_heads) + .reshape(batch_size, -1, 2, self.num_heads, C // self.num_heads) .permute(2, 0, 3, 1, 4) ) k, v = kv[0], kv[1] @@ -148,7 +154,7 @@ def forward(self, x, H, W): attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = (attn @ v).transpose(1, 2).reshape(batch_size, N, C) x = self.proj(x) x = self.proj_drop(x) @@ -209,12 +215,12 @@ def _init_weights(self, m): if m.bias is not None: m.bias.data.zero_() - def forward(self, x): - B, _, H, W = x.shape + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch_size, _, height, width = x.shape x = x.flatten(2).transpose(1, 2) - x = x + self.drop_path(self.attn(self.norm1(x), H, W)) - x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) - x = x.transpose(1, 2).view(B, -1, H, W) + x = x + self.drop_path(self.attn(self.norm1(x), height, width)) + x = x + self.drop_path(self.mlp(self.norm2(x), height, width)) + x = x.transpose(1, 2).view(batch_size, -1, height, width) return x @@ -256,7 +262,7 @@ def _init_weights(self, m): if m.bias is not None: m.bias.data.zero_() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) x = self.norm(x) return x @@ -462,7 +468,7 @@ def reset_classifier(self, num_classes, global_pool=""): nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() ) - def forward_features(self, x): + def forward_features(self, x: torch.Tensor) -> List[torch.Tensor]: outs = [] # stage 1 @@ -491,11 +497,11 @@ def forward_features(self, x): return outs - def forward(self, x): - x = self.forward_features(x) + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + features = self.forward_features(x) # x = self.head(x) - return x + return features class DWConv(nn.Module): @@ -503,9 +509,9 @@ def __init__(self, dim=768): super(DWConv, self).__init__() self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) - def forward(self, x, H, W): - B, _, C = x.shape - x = x.transpose(1, 2).view(B, C, H, W) + def forward(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor: + batch_size, _, channels = x.shape + x = x.transpose(1, 2).view(batch_size, channels, height, width) x = self.dwconv(x) x = x.flatten(2).transpose(1, 2) @@ -516,7 +522,6 @@ def forward(self, x, H, W): # End of NVIDIA code # --------------------------------------------------------------- -from typing import Dict, Sequence, List # noqa E402 from ._base import EncoderMixin # noqa E402 diff --git a/segmentation_models_pytorch/encoders/mobilenet.py b/segmentation_models_pytorch/encoders/mobilenet.py index 482043e3..2dfa4a63 100644 --- 
a/segmentation_models_pytorch/encoders/mobilenet.py +++ b/segmentation_models_pytorch/encoders/mobilenet.py @@ -40,6 +40,7 @@ def __init__( self._in_channels = 3 self._out_channels = out_channels self._output_stride = output_stride + self._out_indexes = [2, 4, 7, 14] del self.classifier @@ -52,25 +53,20 @@ def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: def forward(self, x: torch.Tensor) -> List[torch.Tensor]: features = [x] - if self._depth >= 1: - x = self.features[:2](x) - features.append(x) + depth = 0 + for i, module in enumerate(self.features): + x = module(x) - if self._depth >= 2: - x = self.features[2:4](x) - features.append(x) + if i in self._out_indexes: + features.append(x) + depth += 1 - if self._depth >= 3: - x = self.features[4:7](x) - features.append(x) + # torchscript does not support break in cycle, so we just + # go over all modules and then slice number of features + if not torch.jit.is_scripting() and depth > self._depth: + break - if self._depth >= 4: - x = self.features[7:14](x) - features.append(x) - - if self._depth >= 5: - x = self.features[14:](x) - features.append(x) + features = features[: self._depth + 1] return features diff --git a/segmentation_models_pytorch/encoders/mobileone.py b/segmentation_models_pytorch/encoders/mobileone.py index 20732ce3..131675cb 100644 --- a/segmentation_models_pytorch/encoders/mobileone.py +++ b/segmentation_models_pytorch/encoders/mobileone.py @@ -120,6 +120,8 @@ def __init__( bias=True, ) else: + self.reparam_conv = nn.Identity() + # Re-parameterizable skip connection self.rbr_skip = ( nn.BatchNorm2d(num_features=in_channels) @@ -157,8 +159,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Other branches out = scale_out + identity_out - for ix in range(self.num_conv_branches): - out += self.rbr_conv[ix](x) + for module in self.rbr_conv: + out += module(x) return self.activation(self.se(out)) diff --git a/segmentation_models_pytorch/encoders/senet.py b/segmentation_models_pytorch/encoders/senet.py index 7c3a90fd..a3b44877 100644 --- a/segmentation_models_pytorch/encoders/senet.py +++ b/segmentation_models_pytorch/encoders/senet.py @@ -50,6 +50,10 @@ def __init__( self._out_channels = out_channels self._output_stride = output_stride + # for compatibility with torchscript + self.layer0_pool = self.layer0.pool + self.layer0.pool = torch.nn.Identity() + del self.last_linear del self.avg_pool @@ -63,11 +67,11 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]: features = [x] if self._depth >= 1: - x = self.layer0[:-1](x) + x = self.layer0(x) features.append(x) if self._depth >= 2: - x = self.layer0[-1](x) + x = self.layer0_pool(x) x = self.layer1(x) features.append(x) diff --git a/segmentation_models_pytorch/encoders/timm_efficientnet.py b/segmentation_models_pytorch/encoders/timm_efficientnet.py index 9372ba31..7cd52923 100644 --- a/segmentation_models_pytorch/encoders/timm_efficientnet.py +++ b/segmentation_models_pytorch/encoders/timm_efficientnet.py @@ -130,19 +130,22 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]: features.append(x) if self._depth >= 2: - x = self.blocks[: self._stage_idxs[0]](x) + x = self.blocks[0](x) + x = self.blocks[1](x) features.append(x) if self._depth >= 3: - x = self.blocks[self._stage_idxs[0] : self._stage_idxs[1]](x) + x = self.blocks[2](x) features.append(x) if self._depth >= 4: - x = self.blocks[self._stage_idxs[1] : self._stage_idxs[2]](x) + x = self.blocks[3](x) + x = self.blocks[4](x) features.append(x) if self._depth >= 5: - x = 
self.blocks[self._stage_idxs[2] :](x) + x = self.blocks[5](x) + x = self.blocks[6](x) features.append(x) return features diff --git a/segmentation_models_pytorch/encoders/timm_universal.py b/segmentation_models_pytorch/encoders/timm_universal.py index beea8794..759ede51 100644 --- a/segmentation_models_pytorch/encoders/timm_universal.py +++ b/segmentation_models_pytorch/encoders/timm_universal.py @@ -44,6 +44,10 @@ class TimmUniversalEncoder(nn.Module): - Compatible with convolutional and transformer-like backbones. """ + _is_torch_scriptable = True + _is_torch_exportable = True + _is_torch_compilable = True + def __init__( self, name: str, diff --git a/segmentation_models_pytorch/encoders/vgg.py b/segmentation_models_pytorch/encoders/vgg.py index 537b03dd..82c9c431 100644 --- a/segmentation_models_pytorch/encoders/vgg.py +++ b/segmentation_models_pytorch/encoders/vgg.py @@ -25,6 +25,7 @@ import torch import torch.nn as nn + from torchvision.models.vgg import VGG from torchvision.models.vgg import make_layers @@ -58,6 +59,12 @@ def __init__( self._in_channels = 3 self._out_channels = out_channels self._output_stride = output_stride + self._out_indexes = [ + i - 1 + for i, module in enumerate(self.features) + if isinstance(module, nn.MaxPool2d) + ] + self._out_indexes.append(len(self.features) - 1) del self.classifier @@ -68,21 +75,22 @@ def make_dilated(self, *args, **kwargs): ) def forward(self, x: torch.Tensor) -> List[torch.Tensor]: - # collect stages - stages = [] - stage_modules = [] - for module in self.features: - if isinstance(module, nn.MaxPool2d): - stages.append(stage_modules) - stage_modules = [] - stage_modules.append(module) - stages.append(stage_modules) - features = [] - for i in range(self._depth + 1): - for module in stages[i]: - x = module(x) - features.append(x) + depth = 0 + + for i, module in enumerate(self.features): + x = module(x) + + if i in self._out_indexes: + features.append(x) + depth += 1 + + # torchscript does not support break in cycle, so we just + # go over all modules and then slice number of features + if not torch.jit.is_scripting() and depth > self._depth: + break + + features = features[: self._depth + 1] return features diff --git a/tests/encoders/base.py b/tests/encoders/base.py index 28b12ab8..fe12d58f 100644 --- a/tests/encoders/base.py +++ b/tests/encoders/base.py @@ -250,3 +250,24 @@ def test_torch_export(self): for eager_feature, exported_feature in zip(eager_output, exported_output): torch.testing.assert_close(eager_feature, exported_feature) + + @pytest.mark.torch_script + def test_torch_script(self): + sample = self._get_sample().to(default_device) + + encoder = self.get_tiny_encoder() + encoder = encoder.eval().to(default_device) + + if not encoder._is_torch_scriptable: + with self.assertRaises(RuntimeError, msg="not torch scriptable"): + scripted_encoder = torch.jit.script(encoder) + return + + scripted_encoder = torch.jit.script(encoder) + + with torch.inference_mode(): + eager_output = encoder(sample) + scripted_output = scripted_encoder(sample) + + for eager_feature, scripted_feature in zip(eager_output, scripted_output): + torch.testing.assert_close(eager_feature, scripted_feature) diff --git a/tests/encoders/test_pretrainedmodels_encoders.py b/tests/encoders/test_pretrainedmodels_encoders.py index 0b631601..486f7fe6 100644 --- a/tests/encoders/test_pretrainedmodels_encoders.py +++ b/tests/encoders/test_pretrainedmodels_encoders.py @@ -18,8 +18,8 @@ class TestDPNEncoder(base.BaseEncoderTester): def get_tiny_encoder(self): params = { - 
"stage_idxs": [2, 3, 4, 5], - "out_channels": None, + "stage_idxs": [2, 3, 4, 6], + "out_channels": [3, 2, 70, 134, 262, 518], "groups": 2, "inc_sec": (2, 2, 2, 2), "k_r": 2, @@ -67,7 +67,7 @@ class TestSeNetEncoder(base.BaseEncoderTester): def get_tiny_encoder(self): params = { - "out_channels": None, + "out_channels": [3, 2, 256, 512, 1024, 2048], "block": smp.encoders.senet.SEResNetBottleneck, "layers": [1, 1, 1, 1], "downsample_kernel_size": 1, diff --git a/tests/encoders/test_torchvision_encoders.py b/tests/encoders/test_torchvision_encoders.py index 95500c9e..c0d7c64f 100644 --- a/tests/encoders/test_torchvision_encoders.py +++ b/tests/encoders/test_torchvision_encoders.py @@ -26,7 +26,7 @@ class TestResNetEncoder(base.BaseEncoderTester): def get_tiny_encoder(self): params = { - "out_channels": None, + "out_channels": [3, 64, 64, 128, 256, 512], "block": smp.encoders.resnet.BasicBlock, "layers": [1, 1, 1, 1], } @@ -44,7 +44,7 @@ class TestDenseNetEncoder(base.BaseEncoderTester): def get_tiny_encoder(self): params = { - "out_channels": None, + "out_channels": [3, 2, 3, 2, 2, 2], "num_init_features": 2, "growth_rate": 1, "block_config": (1, 1, 1, 1), From ead24b41f9508214c2cd2ab3dad2eb80ad071a5c Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 14 Jan 2025 18:40:16 +0000 Subject: [PATCH 41/57] Refactor test do not skip --- segmentation_models_pytorch/base/model.py | 4 ++++ .../decoders/unetplusplus/model.py | 2 ++ .../decoders/upernet/model.py | 2 ++ .../encoders/efficientnet.py | 3 +++ tests/encoders/base.py | 21 +++++++------------ .../test_pretrainedmodels_encoders.py | 8 ------- tests/encoders/test_smp_encoders.py | 3 --- tests/encoders/test_timm_ported_encoders.py | 3 --- tests/models/base.py | 5 +++++ 9 files changed, 23 insertions(+), 28 deletions(-) diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py index b5f8abc5..25272f5f 100644 --- a/segmentation_models_pytorch/base/model.py +++ b/segmentation_models_pytorch/base/model.py @@ -11,6 +11,10 @@ class SegmentationModel(torch.nn.Module, SMPHubMixin): """Base class for all segmentation models.""" + _is_torch_scriptable = True + _is_torch_exportable = True + _is_torch_compilable = True + # if model supports shape not divisible by 2 ^ n set to False requires_divisible_input_shape = True diff --git a/segmentation_models_pytorch/decoders/unetplusplus/model.py b/segmentation_models_pytorch/decoders/unetplusplus/model.py index 9d4a1e35..5c3d3a91 100644 --- a/segmentation_models_pytorch/decoders/unetplusplus/model.py +++ b/segmentation_models_pytorch/decoders/unetplusplus/model.py @@ -56,6 +56,8 @@ class UnetPlusPlus(SegmentationModel): """ + _is_torch_scriptable = False + @supports_config_loading def __init__( self, diff --git a/segmentation_models_pytorch/decoders/upernet/model.py b/segmentation_models_pytorch/decoders/upernet/model.py index 076ed2de..7ffeee5b 100644 --- a/segmentation_models_pytorch/decoders/upernet/model.py +++ b/segmentation_models_pytorch/decoders/upernet/model.py @@ -48,6 +48,8 @@ class UPerNet(SegmentationModel): """ + _is_torch_scriptable = False + @supports_config_loading def __init__( self, diff --git a/segmentation_models_pytorch/encoders/efficientnet.py b/segmentation_models_pytorch/encoders/efficientnet.py index c0483b39..5c826a58 100644 --- a/segmentation_models_pytorch/encoders/efficientnet.py +++ b/segmentation_models_pytorch/encoders/efficientnet.py @@ -35,6 +35,9 @@ class EfficientNetEncoder(EfficientNet, EncoderMixin): _is_torch_scriptable = 
False + # works with torch 2.4.0, but not with torch 2.5.1 + _is_torch_compilable = False + def __init__( self, stage_idxs: List[int], diff --git a/tests/encoders/base.py b/tests/encoders/base.py index fe12d58f..1f1d26c9 100644 --- a/tests/encoders/base.py +++ b/tests/encoders/base.py @@ -33,10 +33,6 @@ class BaseEncoderTester(unittest.TestCase): depth_to_test = [3, 4, 5] strides_to_test = [8, 16] # 32 is a default one - # enable/disable tests - do_test_torch_compile = True - do_test_torch_export = True - def get_tiny_encoder(self): return smp.encoders.get_encoder(self.encoder_names[0], encoder_weights=None) @@ -208,28 +204,25 @@ def test_dilated(self): @pytest.mark.compile def test_compile(self): - if not self.do_test_torch_compile: - self.skipTest( - f"torch_compile test is disabled for {self.encoder_names[0]}." - ) - if not check_run_test_on_diff_or_main(self.files_for_diff): self.skipTest("No diff and not on `main`.") sample = self._get_sample().to(default_device) - encoder = self.get_tiny_encoder().eval().to(default_device) + encoder = self.get_tiny_encoder() + encoder = encoder.eval().to(default_device) + compiled_encoder = torch.compile(encoder, fullgraph=True, dynamic=True) - with torch.inference_mode(): + if encoder._is_torch_compilable: compiled_encoder(sample) + else: + with self.assertRaises(Exception): + compiled_encoder(sample) @pytest.mark.torch_export @requires_torch_greater_or_equal("2.4.0") def test_torch_export(self): - if not self.do_test_torch_export: - self.skipTest(f"torch_export test is disabled for {self.encoder_names[0]}.") - if not check_run_test_on_diff_or_main(self.files_for_diff): self.skipTest("No diff and not on `main`.") diff --git a/tests/encoders/test_pretrainedmodels_encoders.py b/tests/encoders/test_pretrainedmodels_encoders.py index 486f7fe6..2dcc7a52 100644 --- a/tests/encoders/test_pretrainedmodels_encoders.py +++ b/tests/encoders/test_pretrainedmodels_encoders.py @@ -12,10 +12,6 @@ class TestDPNEncoder(base.BaseEncoderTester): ) files_for_diff = ["encoders/dpn.py"] - # works with torch 2.4.0, but not with torch 2.5.1 - # dynamo error, probably on Sequential + OrderedDict - do_test_torch_export = False - def get_tiny_encoder(self): params = { "stage_idxs": [2, 3, 4, 6], @@ -45,10 +41,6 @@ class TestInceptionV4Encoder(base.BaseEncoderTester): files_for_diff = ["encoders/inceptionv4.py"] supports_dilated = False - # works with torch 2.4.0, but not with torch 2.5.1 - # dynamo error, probably on Sequential + OrderedDict - do_test_torch_export = False - class TestSeNetEncoder(base.BaseEncoderTester): encoder_names = ( diff --git a/tests/encoders/test_smp_encoders.py b/tests/encoders/test_smp_encoders.py index 2bb061f8..29e2f416 100644 --- a/tests/encoders/test_smp_encoders.py +++ b/tests/encoders/test_smp_encoders.py @@ -62,6 +62,3 @@ class TestEfficientNetEncoder(base.BaseEncoderTester): ] ) files_for_diff = ["encoders/efficientnet.py"] - - # torch_compile is not supported for efficientnet encoders - do_test_torch_compile = False diff --git a/tests/encoders/test_timm_ported_encoders.py b/tests/encoders/test_timm_ported_encoders.py index 49578f73..3793606e 100644 --- a/tests/encoders/test_timm_ported_encoders.py +++ b/tests/encoders/test_timm_ported_encoders.py @@ -26,9 +26,6 @@ class TestTimmEfficientNetEncoder(base.BaseEncoderTester): ) files_for_diff = ["encoders/timm_efficientnet.py"] - # works with torch 2.4.0, but not with torch 2.5.1 - do_test_torch_export = False - class TestTimmGERNetEncoder(base.BaseEncoderTester): encoder_names = ( diff --git 
a/tests/models/base.py b/tests/models/base.py index 58048ff4..433022e6 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -263,6 +263,11 @@ def test_torch_script(self): model = self.get_default_model() model.eval() + if not model._is_torch_scriptable: + with self.assertRaises(RuntimeError): + scripted_model = torch.jit.script(model) + return + scripted_model = torch.jit.script(model) with torch.inference_mode(): From d44509a6f8475f4c6a5c4cf5a6b4ee8b1b91606e Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 14 Jan 2025 18:53:36 +0000 Subject: [PATCH 42/57] Fix encoders (mobilenet, inceptionv4) --- segmentation_models_pytorch/encoders/inceptionv4.py | 4 +--- segmentation_models_pytorch/encoders/mobilenet.py | 2 +- tests/models/base.py | 7 +++++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/segmentation_models_pytorch/encoders/inceptionv4.py b/segmentation_models_pytorch/encoders/inceptionv4.py index fa8da811..cfa0b7c0 100644 --- a/segmentation_models_pytorch/encoders/inceptionv4.py +++ b/segmentation_models_pytorch/encoders/inceptionv4.py @@ -35,7 +35,6 @@ class InceptionV4Encoder(InceptionV4, EncoderMixin): def __init__( self, - out_indexes: List[int], out_channels: List[int], depth: int = 5, output_stride: int = 32, @@ -47,7 +46,7 @@ def __init__( self._in_channels = 3 self._out_channels = out_channels self._output_stride = output_stride - self._out_indexes = out_indexes + self._out_indexes = [2, 4, 8, 14, len(self.features) - 1] # correct paddings for m in self.modules(): @@ -115,7 +114,6 @@ def load_state_dict(self, state_dict, **kwargs): }, }, "params": { - "out_indexes": [2, 4, 8, 14], "out_channels": [3, 64, 192, 384, 1024, 1536], "num_classes": 1001, }, diff --git a/segmentation_models_pytorch/encoders/mobilenet.py b/segmentation_models_pytorch/encoders/mobilenet.py index 2dfa4a63..a803c475 100644 --- a/segmentation_models_pytorch/encoders/mobilenet.py +++ b/segmentation_models_pytorch/encoders/mobilenet.py @@ -40,7 +40,7 @@ def __init__( self._in_channels = 3 self._out_channels = out_channels self._output_stride = output_stride - self._out_indexes = [2, 4, 7, 14] + self._out_indexes = [1, 3, 6, 13, len(self.features) - 1] del self.classifier diff --git a/tests/models/base.py b/tests/models/base.py index 433022e6..b0d3670f 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -148,14 +148,17 @@ def test_classification_head(self): def test_any_resolution(self): model = self.get_default_model() - if model.requires_divisible_input_shape: - self.skipTest("Model requires divisible input shape") sample = self._get_sample( height=self.default_height + 3, width=self.default_width + 7, ).to(default_device) + if model.requires_divisible_input_shape: + with self.assertRaises(RuntimeError, msg="Wrong input shape"): + output = model(sample) + return + with torch.inference_mode(): output = model(sample) From b2c13f1212cfd026b8f258afca59cad33904417d Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 14 Jan 2025 19:30:40 +0000 Subject: [PATCH 43/57] Update encoders table --- docs/encoders.rst | 496 ++++++++++++----------------------------- misc/generate_table.py | 30 ++- 2 files changed, 158 insertions(+), 368 deletions(-) diff --git a/docs/encoders.rst b/docs/encoders.rst index 652745b7..35fb3845 100644 --- a/docs/encoders.rst +++ b/docs/encoders.rst @@ -1,363 +1,141 @@ 🔍 Available Encoders ===================== -ResNet -~~~~~~ +**Segmentation Models PyTorch** provides support for a wide range of encoders. 
+This flexibility allows you to use these encoders with any model in the library by +specifying the encoder name in the ``encoder_name`` parameter during model initialization. + +Here’s a quick example of using a ResNet34 encoder with the ``Unet`` model: + +.. code-block:: python + + from segmentation_models_pytorch import Unet + + # Initialize Unet with ResNet34 encoder pre-trained on ImageNet + model = Unet(encoder_name="resnet34", encoder_weights="imagenet") + + +The following encoder families are supported by the library, enabling you to choose the one that best fits your use case: + +- **Mix Vision Transformer (mit)** +- **MobileOne** +- **MobileNet** +- **EfficientNet** +- **ResNet** +- **ResNeXt** +- **SENet** +- **DPN** +- **VGG** +- **DenseNet** +- **Xception** +- **Inception** + +Choosing the Right Encoder +-------------------------- + +1. **Small Models for Edge Devices** + Consider encoders like **MobileNet** or **MobileOne**, which have a smaller parameter count and are optimized for lightweight deployment. + +2. **High Performance** + If you require state-of-the-art accuracy **Mix Vision Transformer (mit)**, **EfficientNet** families offer excellent balance between performance and computational efficiency. + +For each encoder, the table below provides detailed information: + +1. **Pretrained Weights** + Specifies the available pretrained weights (e.g., ``imagenet``, ``imagenet21k``). + +2. **Parameters (Params, M)** + The total number of parameters in the encoder, measured in millions. This metric helps you assess the model's size and computational requirements. + +3. **Scriptable** + Indicates whether the encoder can be scripted with ``torch.jit.script``. + +4. **Compilable** + Indicates whether the encoder is compatible with ``torch.jit.compile`` for enhanced performance. + +5. **Exportable** + Indicates whether the encoder can be exported using ``torch.export.export``, making it suitable for deployment in different environments (e.g., ONNX). + + +.. list-table:: ++----------------------------+--------------------------------------+-----------+--------+---------+--------+ +| Encoder | Pretrained weights | Params, M | Script | Compile | Export | ++============================+======================================+===========+========+=========+========+ +| resnet18 | imagenet
ssl
swsl | 11M | ✅ | ✅ | ✅ | +| resnet34 | imagenet | 21M | ✅ | ✅ | ✅ | +| resnet50 | imagenet
ssl
swsl | 23M | ✅ | ✅ | ✅ | +| resnet101 | imagenet | 42M | ✅ | ✅ | ✅ | +| resnet152 | imagenet | 58M | ✅ | ✅ | ✅ | +| resnext50_32x4d | imagenet
ssl
swsl | 22M | ✅ | ✅ | ✅ | +| resnext101_32x4d | ssl
swsl | 42M | ✅ | ✅ | ✅ | +| resnext101_32x8d | imagenet
instagram
ssl
swsl | 86M | ✅ | ✅ | ✅ | +| resnext101_32x16d | instagram
ssl
swsl | 191M | ✅ | ✅ | ✅ | +| resnext101_32x32d | instagram | 466M | ✅ | ✅ | ✅ | +| resnext101_32x48d | instagram | 826M | ✅ | ✅ | ✅ | +| dpn68 | imagenet | 11M | ❌ | ✅ | ✅ | +| dpn68b | imagenet+5k | 11M | ❌ | ✅ | ✅ | +| dpn92 | imagenet+5k | 34M | ❌ | ✅ | ✅ | +| dpn98 | imagenet | 58M | ❌ | ✅ | ✅ | +| dpn107 | imagenet+5k | 84M | ❌ | ✅ | ✅ | +| dpn131 | imagenet | 76M | ❌ | ✅ | ✅ | +| vgg11 | imagenet | 9M | ✅ | ✅ | ✅ | +| vgg11_bn | imagenet | 9M | ✅ | ✅ | ✅ | +| vgg13 | imagenet | 9M | ✅ | ✅ | ✅ | +| vgg13_bn | imagenet | 9M | ✅ | ✅ | ✅ | +| vgg16 | imagenet | 14M | ✅ | ✅ | ✅ | +| vgg16_bn | imagenet | 14M | ✅ | ✅ | ✅ | +| vgg19 | imagenet | 20M | ✅ | ✅ | ✅ | +| vgg19_bn | imagenet | 20M | ✅ | ✅ | ✅ | +| senet154 | imagenet | 113M | ✅ | ✅ | ✅ | +| se_resnet50 | imagenet | 26M | ✅ | ✅ | ✅ | +| se_resnet101 | imagenet | 47M | ✅ | ✅ | ✅ | +| se_resnet152 | imagenet | 64M | ✅ | ✅ | ✅ | +| se_resnext50_32x4d | imagenet | 25M | ✅ | ✅ | ✅ | +| se_resnext101_32x4d | imagenet | 46M | ✅ | ✅ | ✅ | +| densenet121 | imagenet | 6M | ✅ | ✅ | ✅ | +| densenet169 | imagenet | 12M | ✅ | ✅ | ✅ | +| densenet201 | imagenet | 18M | ✅ | ✅ | ✅ | +| densenet161 | imagenet | 26M | ✅ | ✅ | ✅ | +| inceptionresnetv2 | imagenet
imagenet+background | 54M | ✅ | ✅ | ✅ | +| inceptionv4 | imagenet
imagenet+background | 41M | ✅ | ✅ | ✅ | +| efficientnet-b0 | imagenet
advprop | 4M | ❌ | ❌ | ✅ | +| efficientnet-b1 | imagenet
advprop | 6M | ❌ | ❌ | ✅ | +| efficientnet-b2 | imagenet
advprop | 7M | ❌ | ❌ | ✅ | +| efficientnet-b3 | imagenet
advprop | 10M | ❌ | ❌ | ✅ | +| efficientnet-b4 | imagenet
advprop | 17M | ❌ | ❌ | ✅ | +| efficientnet-b5 | imagenet
advprop | 28M | ❌ | ❌ | ✅ | +| efficientnet-b6 | imagenet
advprop | 40M | ❌ | ❌ | ✅ | +| efficientnet-b7 | imagenet
advprop | 63M | ❌ | ❌ | ✅ | +| mobilenet_v2 | imagenet | 2M | ✅ | ✅ | ✅ | +| xception | imagenet | 20M | ✅ | ✅ | ✅ | +| timm-efficientnet-b0 | imagenet
advprop
noisy-student | 4M | ✅ | ✅ | ✅ | +| timm-efficientnet-b1 | imagenet
advprop
noisy-student | 6M | ✅ | ✅ | ✅ | +| timm-efficientnet-b2 | imagenet
advprop
noisy-student | 7M | ✅ | ✅ | ✅ | +| timm-efficientnet-b3 | imagenet
advprop
noisy-student | 10M | ✅ | ✅ | ✅ | +| timm-efficientnet-b4 | imagenet
advprop
noisy-student | 17M | ✅ | ✅ | ✅ | +| timm-efficientnet-b5 | imagenet
advprop
noisy-student | 28M | ✅ | ✅ | ✅ | +| timm-efficientnet-b6 | imagenet
advprop
noisy-student | 40M | ✅ | ✅ | ✅ | +| timm-efficientnet-b7 | imagenet
advprop
noisy-student | 63M | ✅ | ✅ | ✅ | +| timm-efficientnet-b8 | imagenet
advprop | 84M | ✅ | ✅ | ✅ | +| timm-efficientnet-l2 | noisy-student
noisy-student-475 | 474M | ✅ | ✅ | ✅ | +| timm-tf_efficientnet_lite0 | imagenet | 3M | ✅ | ✅ | ✅ | +| timm-tf_efficientnet_lite1 | imagenet | 4M | ✅ | ✅ | ✅ | +| timm-tf_efficientnet_lite2 | imagenet | 4M | ✅ | ✅ | ✅ | +| timm-tf_efficientnet_lite3 | imagenet | 6M | ✅ | ✅ | ✅ | +| timm-tf_efficientnet_lite4 | imagenet | 11M | ✅ | ✅ | ✅ | +| timm-skresnet18 | imagenet | 11M | ✅ | ✅ | ✅ | +| timm-skresnet34 | imagenet | 21M | ✅ | ✅ | ✅ | +| timm-skresnext50_32x4d | imagenet | 23M | ✅ | ✅ | ✅ | +| mit_b0 | imagenet | 3M | ✅ | ✅ | ✅ | +| mit_b1 | imagenet | 13M | ✅ | ✅ | ✅ | +| mit_b2 | imagenet | 24M | ✅ | ✅ | ✅ | +| mit_b3 | imagenet | 44M | ✅ | ✅ | ✅ | +| mit_b4 | imagenet | 60M | ✅ | ✅ | ✅ | +| mit_b5 | imagenet | 81M | ✅ | ✅ | ✅ | +| mobileone_s0 | imagenet | 4M | ✅ | ✅ | ✅ | +| mobileone_s1 | imagenet | 3M | ✅ | ✅ | ✅ | +| mobileone_s2 | imagenet | 5M | ✅ | ✅ | ✅ | +| mobileone_s3 | imagenet | 8M | ✅ | ✅ | ✅ | +| mobileone_s4 | imagenet | 12M | ✅ | ✅ | ✅ | ++----------------------------+--------------------------------------+-----------+--------+---------+--------+ -+-------------+-------------------------+-------------+ -| Encoder | Weights | Params, M | -+=============+=========================+=============+ -| resnet18 | imagenet / ssl / swsl | 11M | -+-------------+-------------------------+-------------+ -| resnet34 | imagenet | 21M | -+-------------+-------------------------+-------------+ -| resnet50 | imagenet / ssl / swsl | 23M | -+-------------+-------------------------+-------------+ -| resnet101 | imagenet | 42M | -+-------------+-------------------------+-------------+ -| resnet152 | imagenet | 58M | -+-------------+-------------------------+-------------+ - -ResNeXt -~~~~~~~ - -+----------------------+-------------------------------------+-------------+ -| Encoder | Weights | Params, M | -+======================+=====================================+=============+ -| resnext50\_32x4d | imagenet / ssl / swsl | 22M | -+----------------------+-------------------------------------+-------------+ -| resnext101\_32x4d | ssl / swsl | 42M | -+----------------------+-------------------------------------+-------------+ -| resnext101\_32x8d | imagenet / instagram / ssl / swsl | 86M | -+----------------------+-------------------------------------+-------------+ -| resnext101\_32x16d | instagram / ssl / swsl | 191M | -+----------------------+-------------------------------------+-------------+ -| resnext101\_32x32d | instagram | 466M | -+----------------------+-------------------------------------+-------------+ -| resnext101\_32x48d | instagram | 826M | -+----------------------+-------------------------------------+-------------+ - -ResNeSt -~~~~~~~ - -+----------------------------+------------+-------------+ -| Encoder | Weights | Params, M | -+============================+============+=============+ -| timm-resnest14d | imagenet | 8M | -+----------------------------+------------+-------------+ -| timm-resnest26d | imagenet | 15M | -+----------------------------+------------+-------------+ -| timm-resnest50d | imagenet | 25M | -+----------------------------+------------+-------------+ -| timm-resnest101e | imagenet | 46M | -+----------------------------+------------+-------------+ -| timm-resnest200e | imagenet | 68M | -+----------------------------+------------+-------------+ -| timm-resnest269e | imagenet | 108M | -+----------------------------+------------+-------------+ -| timm-resnest50d\_4s2x40d | imagenet | 28M | -+----------------------------+------------+-------------+ -| 
timm-resnest50d\_1s4x24d | imagenet | 23M | -+----------------------------+------------+-------------+ - -Res2Ne(X)t -~~~~~~~~~~ - -+----------------------------+------------+-------------+ -| Encoder | Weights | Params, M | -+============================+============+=============+ -| timm-res2net50\_26w\_4s | imagenet | 23M | -+----------------------------+------------+-------------+ -| timm-res2net101\_26w\_4s | imagenet | 43M | -+----------------------------+------------+-------------+ -| timm-res2net50\_26w\_6s | imagenet | 35M | -+----------------------------+------------+-------------+ -| timm-res2net50\_26w\_8s | imagenet | 46M | -+----------------------------+------------+-------------+ -| timm-res2net50\_48w\_2s | imagenet | 23M | -+----------------------------+------------+-------------+ -| timm-res2net50\_14w\_8s | imagenet | 23M | -+----------------------------+------------+-------------+ -| timm-res2next50 | imagenet | 22M | -+----------------------------+------------+-------------+ - -RegNet(x/y) -~~~~~~~~~~ - -+---------------------+------------+-------------+ -| Encoder | Weights | Params, M | -+=====================+============+=============+ -| timm-regnetx\_002 | imagenet | 2M | -+---------------------+------------+-------------+ -| timm-regnetx\_004 | imagenet | 4M | -+---------------------+------------+-------------+ -| timm-regnetx\_006 | imagenet | 5M | -+---------------------+------------+-------------+ -| timm-regnetx\_008 | imagenet | 6M | -+---------------------+------------+-------------+ -| timm-regnetx\_016 | imagenet | 8M | -+---------------------+------------+-------------+ -| timm-regnetx\_032 | imagenet | 14M | -+---------------------+------------+-------------+ -| timm-regnetx\_040 | imagenet | 20M | -+---------------------+------------+-------------+ -| timm-regnetx\_064 | imagenet | 24M | -+---------------------+------------+-------------+ -| timm-regnetx\_080 | imagenet | 37M | -+---------------------+------------+-------------+ -| timm-regnetx\_120 | imagenet | 43M | -+---------------------+------------+-------------+ -| timm-regnetx\_160 | imagenet | 52M | -+---------------------+------------+-------------+ -| timm-regnetx\_320 | imagenet | 105M | -+---------------------+------------+-------------+ -| timm-regnety\_002 | imagenet | 2M | -+---------------------+------------+-------------+ -| timm-regnety\_004 | imagenet | 3M | -+---------------------+------------+-------------+ -| timm-regnety\_006 | imagenet | 5M | -+---------------------+------------+-------------+ -| timm-regnety\_008 | imagenet | 5M | -+---------------------+------------+-------------+ -| timm-regnety\_016 | imagenet | 10M | -+---------------------+------------+-------------+ -| timm-regnety\_032 | imagenet | 17M | -+---------------------+------------+-------------+ -| timm-regnety\_040 | imagenet | 19M | -+---------------------+------------+-------------+ -| timm-regnety\_064 | imagenet | 29M | -+---------------------+------------+-------------+ -| timm-regnety\_080 | imagenet | 37M | -+---------------------+------------+-------------+ -| timm-regnety\_120 | imagenet | 49M | -+---------------------+------------+-------------+ -| timm-regnety\_160 | imagenet | 80M | -+---------------------+------------+-------------+ -| timm-regnety\_320 | imagenet | 141M | -+---------------------+------------+-------------+ - -GERNet -~~~~~~ - -+-------------------------+------------+-------------+ -| Encoder | Weights | Params, M | -+=========================+============+=============+ -| 
timm-gernet\_s | imagenet | 6M | -+-------------------------+------------+-------------+ -| timm-gernet\_m | imagenet | 18M | -+-------------------------+------------+-------------+ -| timm-gernet\_l | imagenet | 28M | -+-------------------------+------------+-------------+ - -SE-Net -~~~~~~ - -+-------------------------+------------+-------------+ -| Encoder | Weights | Params, M | -+=========================+============+=============+ -| senet154 | imagenet | 113M | -+-------------------------+------------+-------------+ -| se\_resnet50 | imagenet | 26M | -+-------------------------+------------+-------------+ -| se\_resnet101 | imagenet | 47M | -+-------------------------+------------+-------------+ -| se\_resnet152 | imagenet | 64M | -+-------------------------+------------+-------------+ -| se\_resnext50\_32x4d | imagenet | 25M | -+-------------------------+------------+-------------+ -| se\_resnext101\_32x4d | imagenet | 46M | -+-------------------------+------------+-------------+ - -SK-ResNe(X)t -~~~~~~~~~~~~ - -+---------------------------+------------+-------------+ -| Encoder | Weights | Params, M | -+===========================+============+=============+ -| timm-skresnet18 | imagenet | 11M | -+---------------------------+------------+-------------+ -| timm-skresnet34 | imagenet | 21M | -+---------------------------+------------+-------------+ -| timm-skresnext50\_32x4d | imagenet | 25M | -+---------------------------+------------+-------------+ - -DenseNet -~~~~~~~~ - -+---------------+------------+-------------+ -| Encoder | Weights | Params, M | -+===============+============+=============+ -| densenet121 | imagenet | 6M | -+---------------+------------+-------------+ -| densenet169 | imagenet | 12M | -+---------------+------------+-------------+ -| densenet201 | imagenet | 18M | -+---------------+------------+-------------+ -| densenet161 | imagenet | 26M | -+---------------+------------+-------------+ - -Inception -~~~~~~~~~ - -+---------------------+----------------------------------+-------------+ -| Encoder | Weights | Params, M | -+=====================+==================================+=============+ -| inceptionresnetv2 | imagenet / imagenet+background | 54M | -+---------------------+----------------------------------+-------------+ -| inceptionv4 | imagenet / imagenet+background | 41M | -+---------------------+----------------------------------+-------------+ -| xception | imagenet | 22M | -+---------------------+----------------------------------+-------------+ - -EfficientNet -~~~~~~~~~~~~ - -+------------------------+--------------------------------------+-------------+ -| Encoder | Weights | Params, M | -+========================+======================================+=============+ -| efficientnet-b0 | imagenet | 4M | -+------------------------+--------------------------------------+-------------+ -| efficientnet-b1 | imagenet | 6M | -+------------------------+--------------------------------------+-------------+ -| efficientnet-b2 | imagenet | 7M | -+------------------------+--------------------------------------+-------------+ -| efficientnet-b3 | imagenet | 10M | -+------------------------+--------------------------------------+-------------+ -| efficientnet-b4 | imagenet | 17M | -+------------------------+--------------------------------------+-------------+ -| efficientnet-b5 | imagenet | 28M | -+------------------------+--------------------------------------+-------------+ -| efficientnet-b6 | imagenet | 40M | 
-+------------------------+--------------------------------------+-------------+ -| efficientnet-b7 | imagenet | 63M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-b0 | imagenet / advprop / noisy-student | 4M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-b1 | imagenet / advprop / noisy-student | 6M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-b2 | imagenet / advprop / noisy-student | 7M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-b3 | imagenet / advprop / noisy-student | 10M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-b4 | imagenet / advprop / noisy-student | 17M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-b5 | imagenet / advprop / noisy-student | 28M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-b6 | imagenet / advprop / noisy-student | 40M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-b7 | imagenet / advprop / noisy-student | 63M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-b8 | imagenet / advprop | 84M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-l2 | noisy-student / noisy-student-475 | 474M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-lite0| imagenet | 4M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-lite1| imagenet | 4M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-lite2| imagenet | 6M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-lite3| imagenet | 8M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-lite4| imagenet | 13M | -+------------------------+--------------------------------------+-------------+ - -MobileNet -~~~~~~~~~ - -+---------------------------------------+------------+-------------+ -| Encoder | Weights | Params, M | -+=======================================+============+=============+ -| mobilenet\_v2 | imagenet | 2M | -+---------------------------------------+------------+-------------+ -| timm-mobilenetv3\_large\_075 | imagenet | 1.78M | -+---------------------------------------+------------+-------------+ -| timm-mobilenetv3\_large\_100 | imagenet | 2.97M | -+---------------------------------------+------------+-------------+ -| timm-mobilenetv3\_large\_minimal\_100 | imagenet | 1.41M | -+---------------------------------------+------------+-------------+ -| timm-mobilenetv3\_small\_075 | imagenet | 0.57M | -+---------------------------------------+------------+-------------+ -| timm-mobilenetv3\_small\_100 | imagenet | 0.93M | -+---------------------------------------+------------+-------------+ -| timm-mobilenetv3\_small\_minimal\_100 | imagenet | 0.43M | -+---------------------------------------+------------+-------------+ - -DPN -~~~ - -+-----------+---------------+-------------+ -| Encoder | Weights | Params, M | -+===========+===============+=============+ -| dpn68 | imagenet | 11M | 
-+-----------+---------------+-------------+ -| dpn68b | imagenet+5k | 11M | -+-----------+---------------+-------------+ -| dpn92 | imagenet+5k | 34M | -+-----------+---------------+-------------+ -| dpn98 | imagenet | 58M | -+-----------+---------------+-------------+ -| dpn107 | imagenet+5k | 84M | -+-----------+---------------+-------------+ -| dpn131 | imagenet | 76M | -+-----------+---------------+-------------+ - -VGG -~~~ - -+-------------+------------+-------------+ -| Encoder | Weights | Params, M | -+=============+============+=============+ -| vgg11 | imagenet | 9M | -+-------------+------------+-------------+ -| vgg11\_bn | imagenet | 9M | -+-------------+------------+-------------+ -| vgg13 | imagenet | 9M | -+-------------+------------+-------------+ -| vgg13\_bn | imagenet | 9M | -+-------------+------------+-------------+ -| vgg16 | imagenet | 14M | -+-------------+------------+-------------+ -| vgg16\_bn | imagenet | 14M | -+-------------+------------+-------------+ -| vgg19 | imagenet | 20M | -+-------------+------------+-------------+ -| vgg19\_bn | imagenet | 20M | -+-------------+------------+-------------+ - - -Mix Visual Transformer -~~~~~~~~~~~~~~~~~~~~~ - -+-----------+----------+------------+ -| Encoder | Weights | Params, M | -+===========+==========+============+ -| mit\_b0 | imagenet | 3M | -+-----------+----------+------------+ -| mit\_b1 | imagenet | 13M | -+-----------+----------+------------+ -| mit\_b2 | imagenet | 24M | -+-----------+----------+------------+ -| mit\_b3 | imagenet | 44M | -+-----------+----------+------------+ -| mit\_b4 | imagenet | 60M | -+-----------+----------+------------+ -| mit\_b5 | imagenet | 81M | -+-----------+----------+------------+ - -MobileOne -~~~~~~~~~~~~~~~~~~~~~ - -+-----------------+----------+------------+ -| Encoder | Weights | Params, M | -+=================+==========+============+ -| mobileone\_s0 | imagenet | 4.6M | -+-----------------+----------+------------+ -| mobileone\_s1 | imagenet | 4.0M | -+-----------------+----------+------------+ -| mobileone\_s2 | imagenet | 6.5M | -+-----------------+----------+------------+ -| mobileone\_s3 | imagenet | 8.8M | -+-----------------+----------+------------+ -| mobileone\_s4 | imagenet | 13.6M | -+-----------------+----------+------------+ diff --git a/misc/generate_table.py b/misc/generate_table.py index f14b1a3c..8635b102 100644 --- a/misc/generate_table.py +++ b/misc/generate_table.py @@ -1,11 +1,17 @@ +import os import segmentation_models_pytorch as smp +from tqdm import tqdm + encoders = smp.encoders.encoders WIDTH = 32 -COLUMNS = ["Encoder", "Weights", "Params, M"] +COLUMNS = ["Encoder", "Pretrained weights", "Params, M", "Script", "Compile", "Export"] +FILE = "encoders_table.md" +if os.path.exists(FILE): + os.remove(FILE) def wrap_row(r): return "|{}|".format(r) @@ -16,18 +22,24 @@ def wrap_row(r): ["-" * WIDTH] + [":" + "-" * (WIDTH - 2) + ":"] * (len(COLUMNS) - 1) ) -print(wrap_row(header)) -print(wrap_row(separator)) +print(wrap_row(header), file=open(FILE, "a")) +print(wrap_row(separator), file=open(FILE, "a")) + +for encoder_name, encoder in tqdm(encoders.items()): -for encoder_name, encoder in encoders.items(): weights = "
".join(encoder["pretrained_settings"].keys()) - encoder_name = encoder_name.ljust(WIDTH, " ") - weights = weights.ljust(WIDTH, " ") model = encoder["encoder"](**encoder["params"], depth=5) + + script = "✅" if model._is_torch_scriptable else "❌" + compile = "✅" if model._is_torch_compilable else "❌" + export = "✅" if model._is_torch_exportable else "❌" + params = sum(p.numel() for p in model.parameters()) params = str(params // 1000000) + "M" - params = params.ljust(WIDTH, " ") - row = "|".join([encoder_name, weights, params]) - print(wrap_row(row)) + row = [encoder_name, weights, params, script, compile, export] + row = [str(r).ljust(WIDTH, " ") for r in row] + row = "|".join(row) + + print(wrap_row(row), file=open(FILE, "a")) From 73809e36cf974684020c7efe8b032bd75329f17f Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 14 Jan 2025 19:49:02 +0000 Subject: [PATCH 44/57] Fix export test --- segmentation_models_pytorch/encoders/dpn.py | 1 + tests/encoders/base.py | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/segmentation_models_pytorch/encoders/dpn.py b/segmentation_models_pytorch/encoders/dpn.py index bf768c8b..b173a514 100644 --- a/segmentation_models_pytorch/encoders/dpn.py +++ b/segmentation_models_pytorch/encoders/dpn.py @@ -35,6 +35,7 @@ class DPNEncoder(DPN, EncoderMixin): _is_torch_scriptable = False + _is_torch_exportable = False def __init__( self, diff --git a/tests/encoders/base.py b/tests/encoders/base.py index 1f1d26c9..8dc1f21a 100644 --- a/tests/encoders/base.py +++ b/tests/encoders/base.py @@ -231,6 +231,15 @@ def test_torch_export(self): encoder = self.get_tiny_encoder() encoder = encoder.eval().to(default_device) + if not encoder._is_torch_exportable: + with self.assertRaises(Exception): + exported_encoder = torch.export.export( + encoder, + args=(sample,), + strict=True, + ) + return + exported_encoder = torch.export.export( encoder, args=(sample,), From bc1319e1dba99cfd4ed00bb704f37aa9f85a0e1d Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 14 Jan 2025 19:49:17 +0000 Subject: [PATCH 45/57] Fix docs --- docs/encoders.rst | 25 ++++++++++++------------- misc/generate_table.py | 2 +- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/docs/encoders.rst b/docs/encoders.rst index 35fb3845..a1aa9807 100644 --- a/docs/encoders.rst +++ b/docs/encoders.rst @@ -44,22 +44,21 @@ For each encoder, the table below provides detailed information: 1. **Pretrained Weights** Specifies the available pretrained weights (e.g., ``imagenet``, ``imagenet21k``). -2. **Parameters (Params, M)** +2. **Params, M**: The total number of parameters in the encoder, measured in millions. This metric helps you assess the model's size and computational requirements. -3. **Scriptable** +3. **Script**: Indicates whether the encoder can be scripted with ``torch.jit.script``. -4. **Compilable** - Indicates whether the encoder is compatible with ``torch.jit.compile`` for enhanced performance. +4. **Compile**: + Indicates whether the encoder is compatible with ``torch.compile`` for enhanced performance. -5. **Exportable** +5. **Export**: Indicates whether the encoder can be exported using ``torch.export.export``, making it suitable for deployment in different environments (e.g., ONNX). -.. 
list-table:: +----------------------------+--------------------------------------+-----------+--------+---------+--------+ -| Encoder | Pretrained weights | Params, M | Script | Compile | Export | +| Encoder | Pretrained Weights | Params, M | Script | Compile | Export | +============================+======================================+===========+========+=========+========+ | resnet18 | imagenet
ssl
swsl | 11M | ✅ | ✅ | ✅ | | resnet34 | imagenet | 21M | ✅ | ✅ | ✅ | @@ -72,12 +71,12 @@ For each encoder, the table below provides detailed information: | resnext101_32x16d | instagram
ssl
swsl | 191M | ✅ | ✅ | ✅ | | resnext101_32x32d | instagram | 466M | ✅ | ✅ | ✅ | | resnext101_32x48d | instagram | 826M | ✅ | ✅ | ✅ | -| dpn68 | imagenet | 11M | ❌ | ✅ | ✅ | -| dpn68b | imagenet+5k | 11M | ❌ | ✅ | ✅ | -| dpn92 | imagenet+5k | 34M | ❌ | ✅ | ✅ | -| dpn98 | imagenet | 58M | ❌ | ✅ | ✅ | -| dpn107 | imagenet+5k | 84M | ❌ | ✅ | ✅ | -| dpn131 | imagenet | 76M | ❌ | ✅ | ✅ | +| dpn68 | imagenet | 11M | ❌ | ✅ | ❌ | +| dpn68b | imagenet+5k | 11M | ❌ | ✅ | ❌ | +| dpn92 | imagenet+5k | 34M | ❌ | ✅ | ❌ | +| dpn98 | imagenet | 58M | ❌ | ✅ | ❌ | +| dpn107 | imagenet+5k | 84M | ❌ | ✅ | ❌ | +| dpn131 | imagenet | 76M | ❌ | ✅ | ❌ | | vgg11 | imagenet | 9M | ✅ | ✅ | ✅ | | vgg11_bn | imagenet | 9M | ✅ | ✅ | ✅ | | vgg13 | imagenet | 9M | ✅ | ✅ | ✅ | diff --git a/misc/generate_table.py b/misc/generate_table.py index 8635b102..4e0efed5 100644 --- a/misc/generate_table.py +++ b/misc/generate_table.py @@ -13,6 +13,7 @@ if os.path.exists(FILE): os.remove(FILE) + def wrap_row(r): return "|{}|".format(r) @@ -26,7 +27,6 @@ def wrap_row(r): print(wrap_row(separator), file=open(FILE, "a")) for encoder_name, encoder in tqdm(encoders.items()): - weights = "
".join(encoder["pretrained_settings"].keys()) model = encoder["encoder"](**encoder["params"], depth=5) From d25dd4742fbcd33fe505b157be7b929930ba85d3 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 14 Jan 2025 19:53:45 +0000 Subject: [PATCH 46/57] Update warning --- segmentation_models_pytorch/encoders/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/segmentation_models_pytorch/encoders/__init__.py b/segmentation_models_pytorch/encoders/__init__.py index 12b106df..9f1f3be5 100644 --- a/segmentation_models_pytorch/encoders/__init__.py +++ b/segmentation_models_pytorch/encoders/__init__.py @@ -66,7 +66,7 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, ** if name.startswith("timm-"): warnings.warn( "`timm-` encoders are deprecated and will be removed in the future. " - "Please use `tu-` encoders instead." + "Please use `tu-` equivalent encoders instead (see 'Timm encoders' section in the documentation)." ) # convert timm- models to tu- models From 4f3b37ee8c5b1a7858b41c31f45f156ab25001fa Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 14 Jan 2025 19:57:19 +0000 Subject: [PATCH 47/57] Move pretrained settings --- segmentation_models_pytorch/encoders/dpn.py | 70 ++++++++++++++++++++- 1 file changed, 69 insertions(+), 1 deletion(-) diff --git a/segmentation_models_pytorch/encoders/dpn.py b/segmentation_models_pytorch/encoders/dpn.py index b173a514..e5082cb4 100644 --- a/segmentation_models_pytorch/encoders/dpn.py +++ b/segmentation_models_pytorch/encoders/dpn.py @@ -28,7 +28,6 @@ from typing import List, Dict, Sequence from pretrainedmodels.models.dpn import DPN -from pretrainedmodels.models.dpn import pretrained_settings from ._base import EncoderMixin @@ -97,6 +96,75 @@ def load_state_dict(self, state_dict, **kwargs): super().load_state_dict(state_dict, **kwargs) +pretrained_settings = { + "dpn68": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/dpn68-4af7d88d2.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [124 / 255, 117 / 255, 104 / 255], + "std": [1 / (0.0167 * 255)] * 3, + "num_classes": 1000, + } + }, + "dpn68b": { + "imagenet+5k": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/dpn68b_extra-363ab9c19.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [124 / 255, 117 / 255, 104 / 255], + "std": [1 / (0.0167 * 255)] * 3, + "num_classes": 1000, + } + }, + "dpn92": { + "imagenet+5k": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-fda993c95.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [124 / 255, 117 / 255, 104 / 255], + "std": [1 / (0.0167 * 255)] * 3, + "num_classes": 1000, + } + }, + "dpn98": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/dpn98-722954780.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [124 / 255, 117 / 255, 104 / 255], + "std": [1 / (0.0167 * 255)] * 3, + "num_classes": 1000, + } + }, + "dpn131": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/dpn131-7af84be88.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [124 / 255, 117 / 255, 104 / 255], + "std": [1 / (0.0167 * 255)] * 3, + "num_classes": 1000, + } + }, + "dpn107": { + "imagenet+5k": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/dpn107_extra-b7f9f4cc9.pth", + "input_space": "RGB", + 
"input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [124 / 255, 117 / 255, 104 / 255], + "std": [1 / (0.0167 * 255)] * 3, + "num_classes": 1000, + } + }, +} + dpn_encoders = { "dpn68": { "encoder": DPNEncoder, From 06199b0b1262b4ec198e3d4cfe3fc6ad4e9ec5a4 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 14 Jan 2025 20:21:03 +0000 Subject: [PATCH 48/57] Add BC for timm- encoders --- segmentation_models_pytorch/base/model.py | 26 +++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py index 25272f5f..29820840 100644 --- a/segmentation_models_pytorch/base/model.py +++ b/segmentation_models_pytorch/base/model.py @@ -90,3 +90,29 @@ def predict(self, x): x = self.forward(x) return x + + def load_state_dict(self, state_dict, **kwargs): + # for compatibility of weights for + # timm- ported encoders with TimmUniversalEncoder + from segmentation_models_pytorch.encoders import TimmUniversalEncoder + + if not isinstance(self.encoder, TimmUniversalEncoder): + return super().load_state_dict(state_dict, **kwargs) + + patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"] + + is_deprecated_encoder = any( + self.encoder.name.startswith(pattern) for pattern in patterns + ) + + if is_deprecated_encoder: + keys = list(state_dict.keys()) + for key in keys: + new_key = key + if key.startswith("encoder.") and not key.startswith("encoder.model."): + new_key = "encoder.model." + key.removeprefix("encoder.") + if "gernet" in self.encoder.name: + new_key = new_key.replace(".stages.", ".stages_") + state_dict[new_key] = state_dict.pop(key) + + return super().load_state_dict(state_dict, **kwargs) From 51e0a67a76c112cfd58e57ef3fce35223610c849 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 14 Jan 2025 20:32:32 +0000 Subject: [PATCH 49/57] Fixing table --- docs/encoders.rst | 161 +++++++++++++++++++++++----------------------- 1 file changed, 80 insertions(+), 81 deletions(-) diff --git a/docs/encoders.rst b/docs/encoders.rst index a1aa9807..8d1bb9e4 100644 --- a/docs/encoders.rst +++ b/docs/encoders.rst @@ -57,84 +57,83 @@ For each encoder, the table below provides detailed information: Indicates whether the encoder can be exported using ``torch.export.export``, making it suitable for deployment in different environments (e.g., ONNX). -+----------------------------+--------------------------------------+-----------+--------+---------+--------+ -| Encoder | Pretrained Weights | Params, M | Script | Compile | Export | -+============================+======================================+===========+========+=========+========+ -| resnet18 | imagenet
ssl
swsl | 11M | ✅ | ✅ | ✅ | -| resnet34 | imagenet | 21M | ✅ | ✅ | ✅ | -| resnet50 | imagenet
ssl
swsl | 23M | ✅ | ✅ | ✅ | -| resnet101 | imagenet | 42M | ✅ | ✅ | ✅ | -| resnet152 | imagenet | 58M | ✅ | ✅ | ✅ | -| resnext50_32x4d | imagenet
ssl
swsl | 22M | ✅ | ✅ | ✅ | -| resnext101_32x4d | ssl
swsl | 42M | ✅ | ✅ | ✅ | -| resnext101_32x8d | imagenet
instagram
ssl
swsl | 86M | ✅ | ✅ | ✅ | -| resnext101_32x16d | instagram
ssl
swsl | 191M | ✅ | ✅ | ✅ | -| resnext101_32x32d | instagram | 466M | ✅ | ✅ | ✅ | -| resnext101_32x48d | instagram | 826M | ✅ | ✅ | ✅ | -| dpn68 | imagenet | 11M | ❌ | ✅ | ❌ | -| dpn68b | imagenet+5k | 11M | ❌ | ✅ | ❌ | -| dpn92 | imagenet+5k | 34M | ❌ | ✅ | ❌ | -| dpn98 | imagenet | 58M | ❌ | ✅ | ❌ | -| dpn107 | imagenet+5k | 84M | ❌ | ✅ | ❌ | -| dpn131 | imagenet | 76M | ❌ | ✅ | ❌ | -| vgg11 | imagenet | 9M | ✅ | ✅ | ✅ | -| vgg11_bn | imagenet | 9M | ✅ | ✅ | ✅ | -| vgg13 | imagenet | 9M | ✅ | ✅ | ✅ | -| vgg13_bn | imagenet | 9M | ✅ | ✅ | ✅ | -| vgg16 | imagenet | 14M | ✅ | ✅ | ✅ | -| vgg16_bn | imagenet | 14M | ✅ | ✅ | ✅ | -| vgg19 | imagenet | 20M | ✅ | ✅ | ✅ | -| vgg19_bn | imagenet | 20M | ✅ | ✅ | ✅ | -| senet154 | imagenet | 113M | ✅ | ✅ | ✅ | -| se_resnet50 | imagenet | 26M | ✅ | ✅ | ✅ | -| se_resnet101 | imagenet | 47M | ✅ | ✅ | ✅ | -| se_resnet152 | imagenet | 64M | ✅ | ✅ | ✅ | -| se_resnext50_32x4d | imagenet | 25M | ✅ | ✅ | ✅ | -| se_resnext101_32x4d | imagenet | 46M | ✅ | ✅ | ✅ | -| densenet121 | imagenet | 6M | ✅ | ✅ | ✅ | -| densenet169 | imagenet | 12M | ✅ | ✅ | ✅ | -| densenet201 | imagenet | 18M | ✅ | ✅ | ✅ | -| densenet161 | imagenet | 26M | ✅ | ✅ | ✅ | -| inceptionresnetv2 | imagenet
imagenet+background | 54M | ✅ | ✅ | ✅ | -| inceptionv4 | imagenet
imagenet+background | 41M | ✅ | ✅ | ✅ | -| efficientnet-b0 | imagenet
advprop | 4M | ❌ | ❌ | ✅ | -| efficientnet-b1 | imagenet
advprop | 6M | ❌ | ❌ | ✅ | -| efficientnet-b2 | imagenet
advprop | 7M | ❌ | ❌ | ✅ | -| efficientnet-b3 | imagenet
advprop | 10M | ❌ | ❌ | ✅ | -| efficientnet-b4 | imagenet
advprop | 17M | ❌ | ❌ | ✅ | -| efficientnet-b5 | imagenet
advprop | 28M | ❌ | ❌ | ✅ | -| efficientnet-b6 | imagenet
advprop | 40M | ❌ | ❌ | ✅ | -| efficientnet-b7 | imagenet
advprop | 63M | ❌ | ❌ | ✅ | -| mobilenet_v2 | imagenet | 2M | ✅ | ✅ | ✅ | -| xception | imagenet | 20M | ✅ | ✅ | ✅ | -| timm-efficientnet-b0 | imagenet
advprop
noisy-student | 4M | ✅ | ✅ | ✅ | -| timm-efficientnet-b1 | imagenet
advprop
noisy-student | 6M | ✅ | ✅ | ✅ | -| timm-efficientnet-b2 | imagenet
advprop
noisy-student | 7M | ✅ | ✅ | ✅ | -| timm-efficientnet-b3 | imagenet
advprop
noisy-student | 10M | ✅ | ✅ | ✅ | -| timm-efficientnet-b4 | imagenet
advprop
noisy-student | 17M | ✅ | ✅ | ✅ | -| timm-efficientnet-b5 | imagenet
advprop
noisy-student | 28M | ✅ | ✅ | ✅ | -| timm-efficientnet-b6 | imagenet
advprop
noisy-student | 40M | ✅ | ✅ | ✅ | -| timm-efficientnet-b7 | imagenet
advprop
noisy-student | 63M | ✅ | ✅ | ✅ | -| timm-efficientnet-b8 | imagenet
advprop | 84M | ✅ | ✅ | ✅ | -| timm-efficientnet-l2 | noisy-student
noisy-student-475 | 474M | ✅ | ✅ | ✅ | -| timm-tf_efficientnet_lite0 | imagenet | 3M | ✅ | ✅ | ✅ | -| timm-tf_efficientnet_lite1 | imagenet | 4M | ✅ | ✅ | ✅ | -| timm-tf_efficientnet_lite2 | imagenet | 4M | ✅ | ✅ | ✅ | -| timm-tf_efficientnet_lite3 | imagenet | 6M | ✅ | ✅ | ✅ | -| timm-tf_efficientnet_lite4 | imagenet | 11M | ✅ | ✅ | ✅ | -| timm-skresnet18 | imagenet | 11M | ✅ | ✅ | ✅ | -| timm-skresnet34 | imagenet | 21M | ✅ | ✅ | ✅ | -| timm-skresnext50_32x4d | imagenet | 23M | ✅ | ✅ | ✅ | -| mit_b0 | imagenet | 3M | ✅ | ✅ | ✅ | -| mit_b1 | imagenet | 13M | ✅ | ✅ | ✅ | -| mit_b2 | imagenet | 24M | ✅ | ✅ | ✅ | -| mit_b3 | imagenet | 44M | ✅ | ✅ | ✅ | -| mit_b4 | imagenet | 60M | ✅ | ✅ | ✅ | -| mit_b5 | imagenet | 81M | ✅ | ✅ | ✅ | -| mobileone_s0 | imagenet | 4M | ✅ | ✅ | ✅ | -| mobileone_s1 | imagenet | 3M | ✅ | ✅ | ✅ | -| mobileone_s2 | imagenet | 5M | ✅ | ✅ | ✅ | -| mobileone_s3 | imagenet | 8M | ✅ | ✅ | ✅ | -| mobileone_s4 | imagenet | 12M | ✅ | ✅ | ✅ | -+----------------------------+--------------------------------------+-----------+--------+---------+--------+ - +============================ ==================================== =========== ======== ========= ======== +Encoder Pretrained weights Params, M Script Compile Export +============================ ==================================== =========== ======== ========= ======== +resnet18 imagenet / ssl / swsl 11M ✅ ✅ ✅ +resnet34 imagenet 21M ✅ ✅ ✅ +resnet50 imagenet / ssl / swsl 23M ✅ ✅ ✅ +resnet101 imagenet 42M ✅ ✅ ✅ +resnet152 imagenet 58M ✅ ✅ ✅ +resnext50_32x4d imagenet / ssl / swsl 22M ✅ ✅ ✅ +resnext101_32x4d ssl / swsl 42M ✅ ✅ ✅ +resnext101_32x8d imagenet / instagram / ssl / swsl 86M ✅ ✅ ✅ +resnext101_32x16d instagram / ssl / swsl 191M ✅ ✅ ✅ +resnext101_32x32d instagram 466M ✅ ✅ ✅ +resnext101_32x48d instagram 826M ✅ ✅ ✅ +dpn68 imagenet 11M ❌ ✅ ✅ +dpn68b imagenet+5k 11M ❌ ✅ ✅ +dpn92 imagenet+5k 34M ❌ ✅ ✅ +dpn98 imagenet 58M ❌ ✅ ✅ +dpn107 imagenet+5k 84M ❌ ✅ ✅ +dpn131 imagenet 76M ❌ ✅ ✅ +vgg11 imagenet 9M ✅ ✅ ✅ +vgg11_bn imagenet 9M ✅ ✅ ✅ +vgg13 imagenet 9M ✅ ✅ ✅ +vgg13_bn imagenet 9M ✅ ✅ ✅ +vgg16 imagenet 14M ✅ ✅ ✅ +vgg16_bn imagenet 14M ✅ ✅ ✅ +vgg19 imagenet 20M ✅ ✅ ✅ +vgg19_bn imagenet 20M ✅ ✅ ✅ +senet154 imagenet 113M ✅ ✅ ✅ +se_resnet50 imagenet 26M ✅ ✅ ✅ +se_resnet101 imagenet 47M ✅ ✅ ✅ +se_resnet152 imagenet 64M ✅ ✅ ✅ +se_resnext50_32x4d imagenet 25M ✅ ✅ ✅ +se_resnext101_32x4d imagenet 46M ✅ ✅ ✅ +densenet121 imagenet 6M ✅ ✅ ✅ +densenet169 imagenet 12M ✅ ✅ ✅ +densenet201 imagenet 18M ✅ ✅ ✅ +densenet161 imagenet 26M ✅ ✅ ✅ +inceptionresnetv2 imagenet / imagenet+background 54M ✅ ✅ ✅ +inceptionv4 imagenet / imagenet+background 41M ✅ ✅ ✅ +efficientnet-b0 imagenet / advprop 4M ❌ ❌ ✅ +efficientnet-b1 imagenet / advprop 6M ❌ ❌ ✅ +efficientnet-b2 imagenet / advprop 7M ❌ ❌ ✅ +efficientnet-b3 imagenet / advprop 10M ❌ ❌ ✅ +efficientnet-b4 imagenet / advprop 17M ❌ ❌ ✅ +efficientnet-b5 imagenet / advprop 28M ❌ ❌ ✅ +efficientnet-b6 imagenet / advprop 40M ❌ ❌ ✅ +efficientnet-b7 imagenet / advprop 63M ❌ ❌ ✅ +mobilenet_v2 imagenet 2M ✅ ✅ ✅ +xception imagenet 20M ✅ ✅ ✅ +timm-efficientnet-b0 imagenet / advprop / noisy-student 4M ✅ ✅ ✅ +timm-efficientnet-b1 imagenet / advprop / noisy-student 6M ✅ ✅ ✅ +timm-efficientnet-b2 imagenet / advprop / noisy-student 7M ✅ ✅ ✅ +timm-efficientnet-b3 imagenet / advprop / noisy-student 10M ✅ ✅ ✅ +timm-efficientnet-b4 imagenet / advprop / noisy-student 17M ✅ ✅ ✅ +timm-efficientnet-b5 imagenet / advprop / noisy-student 28M ✅ ✅ ✅ +timm-efficientnet-b6 imagenet / advprop / noisy-student 40M ✅ ✅ ✅ 
+timm-efficientnet-b7 imagenet / advprop / noisy-student 63M ✅ ✅ ✅ +timm-efficientnet-b8 imagenet / advprop 84M ✅ ✅ ✅ +timm-efficientnet-l2 noisy-student / noisy-student-475 474M ✅ ✅ ✅ +timm-tf_efficientnet_lite0 imagenet 3M ✅ ✅ ✅ +timm-tf_efficientnet_lite1 imagenet 4M ✅ ✅ ✅ +timm-tf_efficientnet_lite2 imagenet 4M ✅ ✅ ✅ +timm-tf_efficientnet_lite3 imagenet 6M ✅ ✅ ✅ +timm-tf_efficientnet_lite4 imagenet 11M ✅ ✅ ✅ +timm-skresnet18 imagenet 11M ✅ ✅ ✅ +timm-skresnet34 imagenet 21M ✅ ✅ ✅ +timm-skresnext50_32x4d imagenet 23M ✅ ✅ ✅ +mit_b0 imagenet 3M ✅ ✅ ✅ +mit_b1 imagenet 13M ✅ ✅ ✅ +mit_b2 imagenet 24M ✅ ✅ ✅ +mit_b3 imagenet 44M ✅ ✅ ✅ +mit_b4 imagenet 60M ✅ ✅ ✅ +mit_b5 imagenet 81M ✅ ✅ ✅ +mobileone_s0 imagenet 4M ✅ ✅ ✅ +mobileone_s1 imagenet 3M ✅ ✅ ✅ +mobileone_s2 imagenet 5M ✅ ✅ ✅ +mobileone_s3 imagenet 8M ✅ ✅ ✅ +mobileone_s4 imagenet 12M ✅ ✅ ✅ +============================ ==================================== =========== ======== ========= ======== From 524bcae73e049b80a72bf93ed39b809bf995822e Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Tue, 14 Jan 2025 21:35:47 +0000 Subject: [PATCH 50/57] Update compile test --- docs/encoders.rst | 16 ++++++++-------- .../encoders/efficientnet.py | 3 --- tests/encoders/base.py | 1 + tests/models/base.py | 4 +++- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/encoders.rst b/docs/encoders.rst index 8d1bb9e4..c0016bc5 100644 --- a/docs/encoders.rst +++ b/docs/encoders.rst @@ -97,14 +97,14 @@ densenet201 imagenet 18M densenet161 imagenet 26M ✅ ✅ ✅ inceptionresnetv2 imagenet / imagenet+background 54M ✅ ✅ ✅ inceptionv4 imagenet / imagenet+background 41M ✅ ✅ ✅ -efficientnet-b0 imagenet / advprop 4M ❌ ❌ ✅ -efficientnet-b1 imagenet / advprop 6M ❌ ❌ ✅ -efficientnet-b2 imagenet / advprop 7M ❌ ❌ ✅ -efficientnet-b3 imagenet / advprop 10M ❌ ❌ ✅ -efficientnet-b4 imagenet / advprop 17M ❌ ❌ ✅ -efficientnet-b5 imagenet / advprop 28M ❌ ❌ ✅ -efficientnet-b6 imagenet / advprop 40M ❌ ❌ ✅ -efficientnet-b7 imagenet / advprop 63M ❌ ❌ ✅ +efficientnet-b0 imagenet / advprop 4M ❌ ✅ ✅ +efficientnet-b1 imagenet / advprop 6M ❌ ✅ ✅ +efficientnet-b2 imagenet / advprop 7M ❌ ✅ ✅ +efficientnet-b3 imagenet / advprop 10M ❌ ✅ ✅ +efficientnet-b4 imagenet / advprop 17M ❌ ✅ ✅ +efficientnet-b5 imagenet / advprop 28M ❌ ✅ ✅ +efficientnet-b6 imagenet / advprop 40M ❌ ✅ ✅ +efficientnet-b7 imagenet / advprop 63M ❌ ✅ ✅ mobilenet_v2 imagenet 2M ✅ ✅ ✅ xception imagenet 20M ✅ ✅ ✅ timm-efficientnet-b0 imagenet / advprop / noisy-student 4M ✅ ✅ ✅ diff --git a/segmentation_models_pytorch/encoders/efficientnet.py b/segmentation_models_pytorch/encoders/efficientnet.py index 5c826a58..c0483b39 100644 --- a/segmentation_models_pytorch/encoders/efficientnet.py +++ b/segmentation_models_pytorch/encoders/efficientnet.py @@ -35,9 +35,6 @@ class EfficientNetEncoder(EfficientNet, EncoderMixin): _is_torch_scriptable = False - # works with torch 2.4.0, but not with torch 2.5.1 - _is_torch_compilable = False - def __init__( self, stage_idxs: List[int], diff --git a/tests/encoders/base.py b/tests/encoders/base.py index 8dc1f21a..68f7c39b 100644 --- a/tests/encoders/base.py +++ b/tests/encoders/base.py @@ -212,6 +212,7 @@ def test_compile(self): encoder = self.get_tiny_encoder() encoder = encoder.eval().to(default_device) + torch.compiler.reset() compiled_encoder = torch.compile(encoder, fullgraph=True, dynamic=True) if encoder._is_torch_compilable: diff --git a/tests/models/base.py b/tests/models/base.py index b0d3670f..8434ced7 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ 
-228,8 +228,10 @@ def test_compile(self): self.skipTest("No diff and not on `main`.") sample = self._get_sample().to(default_device) - model = self.get_default_model() + model = model.eval().to(default_device) + + torch.compiler.reset() compiled_model = torch.compile(model, fullgraph=True, dynamic=True) with torch.inference_mode(): From a2b97d879f1e34b15c09996fa236262592c2c858 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Wed, 15 Jan 2025 10:42:40 +0000 Subject: [PATCH 51/57] Change compile backend to eager --- tests/encoders/base.py | 2 +- tests/models/base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/encoders/base.py b/tests/encoders/base.py index 68f7c39b..3b309b3a 100644 --- a/tests/encoders/base.py +++ b/tests/encoders/base.py @@ -213,7 +213,7 @@ def test_compile(self): encoder = encoder.eval().to(default_device) torch.compiler.reset() - compiled_encoder = torch.compile(encoder, fullgraph=True, dynamic=True) + compiled_encoder = torch.compile(encoder, fullgraph=True, dynamic=True, backend="eager") if encoder._is_torch_compilable: compiled_encoder(sample) diff --git a/tests/models/base.py b/tests/models/base.py index 8434ced7..fca0b36c 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -232,7 +232,7 @@ def test_compile(self): model = model.eval().to(default_device) torch.compiler.reset() - compiled_model = torch.compile(model, fullgraph=True, dynamic=True) + compiled_model = torch.compile(model, fullgraph=True, dynamic=True, backend="eager") with torch.inference_mode(): compiled_model(sample) From 17a4b70e26691688593f2416a9da93fe028e3c17 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Wed, 15 Jan 2025 10:43:06 +0000 Subject: [PATCH 52/57] Update docs --- docs/encoders.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/encoders.rst b/docs/encoders.rst index c0016bc5..0da221f9 100644 --- a/docs/encoders.rst +++ b/docs/encoders.rst @@ -51,7 +51,9 @@ For each encoder, the table below provides detailed information: Indicates whether the encoder can be scripted with ``torch.jit.script``. 4. **Compile**: - Indicates whether the encoder is compatible with ``torch.compile`` for enhanced performance. + Indicates whether the encoder is compatible with ``torch.compile(model, fullgraph=True, dynamic=True, backend="eager")``. + You may still get some issues with another backends, such as ``inductor``, depending on the torch/cuda/... dependencies version, + but most of the time it will work. 5. **Export**: Indicates whether the encoder can be exported using ``torch.export.export``, making it suitable for deployment in different environments (e.g., ONNX). 
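
For reference, a minimal usage sketch of the compile path that the three patches above describe. It assumes torch >= 2.1 and segmentation_models_pytorch are installed; the architecture, encoder name, and input size are arbitrary illustrative choices, not taken from the patches:

    import torch
    import segmentation_models_pytorch as smp

    # Any encoder marked "Compile ✅" in the docs table should pass through this path.
    model = smp.Unet(encoder_name="resnet18", encoder_weights=None).eval()

    torch.compiler.reset()
    compiled_model = torch.compile(model, fullgraph=True, dynamic=True, backend="eager")

    with torch.inference_mode():
        sample = torch.rand(1, 3, 64, 64)   # height/width must be divisible by 32
        output = compiled_model(sample)     # same spatial size as the input, 1 class by default
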
From 20564f231983ab8e7ea82671b132028aced72b73 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Wed, 15 Jan 2025 10:48:44 +0000 Subject: [PATCH 53/57] Fixup --- tests/encoders/base.py | 4 +++- tests/models/base.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/encoders/base.py b/tests/encoders/base.py index 3b309b3a..b1113437 100644 --- a/tests/encoders/base.py +++ b/tests/encoders/base.py @@ -213,7 +213,9 @@ def test_compile(self): encoder = encoder.eval().to(default_device) torch.compiler.reset() - compiled_encoder = torch.compile(encoder, fullgraph=True, dynamic=True, backend="eager") + compiled_encoder = torch.compile( + encoder, fullgraph=True, dynamic=True, backend="eager" + ) if encoder._is_torch_compilable: compiled_encoder(sample) diff --git a/tests/models/base.py b/tests/models/base.py index fca0b36c..f7492986 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -232,7 +232,9 @@ def test_compile(self): model = model.eval().to(default_device) torch.compiler.reset() - compiled_model = torch.compile(model, fullgraph=True, dynamic=True, backend="eager") + compiled_model = torch.compile( + model, fullgraph=True, dynamic=True, backend="eager" + ) with torch.inference_mode(): compiled_model(sample) From 5bbb1db7b9487b729d2dab5799ec82cf094958c7 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Wed, 15 Jan 2025 12:00:30 +0000 Subject: [PATCH 54/57] Fix batchnorm typo --- .../decoders/pspnet/decoder.py | 12 ++++++------ .../decoders/upernet/decoder.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/segmentation_models_pytorch/decoders/pspnet/decoder.py b/segmentation_models_pytorch/decoders/pspnet/decoder.py index 99ec5f72..42ac42d0 100644 --- a/segmentation_models_pytorch/decoders/pspnet/decoder.py +++ b/segmentation_models_pytorch/decoders/pspnet/decoder.py @@ -12,17 +12,17 @@ def __init__( in_channels: int, out_channels: int, pool_size: int, - use_bathcnorm: bool = True, + use_batchnorm: bool = True, ): super().__init__() if pool_size == 1: - use_bathcnorm = False # PyTorch does not support BatchNorm for 1x1 shape + use_batchnorm = False # PyTorch does not support BatchNorm for 1x1 shape self.pool = nn.Sequential( nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size)), modules.Conv2dReLU( - in_channels, out_channels, (1, 1), use_batchnorm=use_bathcnorm + in_channels, out_channels, (1, 1), use_batchnorm=use_batchnorm ), ) @@ -38,7 +38,7 @@ def __init__( self, in_channels: int, sizes: Tuple[int, ...] 
= (1, 2, 3, 6), - use_bathcnorm: bool = True, + use_batchnorm: bool = True, ): super().__init__() @@ -48,7 +48,7 @@ def __init__( in_channels, in_channels // len(sizes), size, - use_bathcnorm=use_bathcnorm, + use_batchnorm=use_batchnorm, ) for size in sizes ] @@ -73,7 +73,7 @@ def __init__( self.psp = PSPModule( in_channels=encoder_channels[-1], sizes=(1, 2, 3, 6), - use_bathcnorm=use_batchnorm, + use_batchnorm=use_batchnorm, ) self.conv = modules.Conv2dReLU( diff --git a/segmentation_models_pytorch/decoders/upernet/decoder.py b/segmentation_models_pytorch/decoders/upernet/decoder.py index ebcb3d10..99c74fb1 100644 --- a/segmentation_models_pytorch/decoders/upernet/decoder.py +++ b/segmentation_models_pytorch/decoders/upernet/decoder.py @@ -48,14 +48,14 @@ def forward(self, x): class FPNBlock(nn.Module): - def __init__(self, skip_channels, pyramid_channels, use_bathcnorm=True): + def __init__(self, skip_channels, pyramid_channels, use_batchnorm=True): super().__init__() self.skip_conv = ( md.Conv2dReLU( skip_channels, pyramid_channels, kernel_size=1, - use_batchnorm=use_bathcnorm, + use_batchnorm=use_batchnorm, ) if skip_channels != 0 else nn.Identity() From d121fecab27dfa8a54510201e1db8cdfad3d38c1 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Wed, 15 Jan 2025 12:33:47 +0000 Subject: [PATCH 55/57] Add depth validation --- segmentation_models_pytorch/encoders/densenet.py | 6 ++++++ segmentation_models_pytorch/encoders/dpn.py | 5 +++++ .../encoders/efficientnet.py | 5 +++++ .../encoders/inceptionresnetv2.py | 5 +++++ .../encoders/inceptionv4.py | 4 ++++ .../encoders/mix_transformer.py | 4 ++++ segmentation_models_pytorch/encoders/mobilenet.py | 4 ++++ segmentation_models_pytorch/encoders/mobileone.py | 5 +++++ segmentation_models_pytorch/encoders/resnet.py | 4 ++++ segmentation_models_pytorch/encoders/senet.py | 4 ++++ .../encoders/timm_efficientnet.py | 4 ++++ segmentation_models_pytorch/encoders/timm_sknet.py | 4 ++++ .../encoders/timm_universal.py | 7 +++++++ segmentation_models_pytorch/encoders/vgg.py | 4 ++++ segmentation_models_pytorch/encoders/xception.py | 14 +++++++++++++- tests/encoders/base.py | 6 ++++++ 16 files changed, 84 insertions(+), 1 deletion(-) diff --git a/segmentation_models_pytorch/encoders/densenet.py b/segmentation_models_pytorch/encoders/densenet.py index aa61db35..3ce9b3d0 100644 --- a/segmentation_models_pytorch/encoders/densenet.py +++ b/segmentation_models_pytorch/encoders/densenet.py @@ -32,7 +32,13 @@ class DenseNetEncoder(DenseNet, EncoderMixin): def __init__(self, out_channels, depth=5, output_stride=32, **kwargs): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) + super().__init__(**kwargs) + self._depth = depth self._in_channels = 3 self._out_channels = out_channels diff --git a/segmentation_models_pytorch/encoders/dpn.py b/segmentation_models_pytorch/encoders/dpn.py index e5082cb4..1034540d 100644 --- a/segmentation_models_pytorch/encoders/dpn.py +++ b/segmentation_models_pytorch/encoders/dpn.py @@ -44,6 +44,11 @@ def __init__( output_stride: int = 32, **kwargs, ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) + super().__init__(**kwargs) self._stage_idxs = stage_idxs self._depth = depth diff --git a/segmentation_models_pytorch/encoders/efficientnet.py b/segmentation_models_pytorch/encoders/efficientnet.py index c0483b39..96edb4fe 100644 --- 
a/segmentation_models_pytorch/encoders/efficientnet.py +++ b/segmentation_models_pytorch/encoders/efficientnet.py @@ -43,6 +43,11 @@ def __init__( depth: int = 5, output_stride: int = 32, ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) + blocks_args, global_params = get_model_params(model_name, override_params=None) super().__init__(blocks_args, global_params) diff --git a/segmentation_models_pytorch/encoders/inceptionresnetv2.py b/segmentation_models_pytorch/encoders/inceptionresnetv2.py index df3da839..3ac662e2 100644 --- a/segmentation_models_pytorch/encoders/inceptionresnetv2.py +++ b/segmentation_models_pytorch/encoders/inceptionresnetv2.py @@ -39,6 +39,11 @@ def __init__( output_stride: int = 32, **kwargs, ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) + super().__init__(**kwargs) self._depth = depth diff --git a/segmentation_models_pytorch/encoders/inceptionv4.py b/segmentation_models_pytorch/encoders/inceptionv4.py index cfa0b7c0..c5b79b02 100644 --- a/segmentation_models_pytorch/encoders/inceptionv4.py +++ b/segmentation_models_pytorch/encoders/inceptionv4.py @@ -40,6 +40,10 @@ def __init__( output_stride: int = 32, **kwargs, ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) super().__init__(**kwargs) self._depth = depth diff --git a/segmentation_models_pytorch/encoders/mix_transformer.py b/segmentation_models_pytorch/encoders/mix_transformer.py index 479c3f09..7430dd4d 100644 --- a/segmentation_models_pytorch/encoders/mix_transformer.py +++ b/segmentation_models_pytorch/encoders/mix_transformer.py @@ -529,6 +529,10 @@ class MixVisionTransformerEncoder(MixVisionTransformer, EncoderMixin): def __init__( self, out_channels: List[int], depth: int = 5, output_stride: int = 32, **kwargs ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) super().__init__(**kwargs) self._depth = depth diff --git a/segmentation_models_pytorch/encoders/mobilenet.py b/segmentation_models_pytorch/encoders/mobilenet.py index a803c475..af7fc122 100644 --- a/segmentation_models_pytorch/encoders/mobilenet.py +++ b/segmentation_models_pytorch/encoders/mobilenet.py @@ -34,6 +34,10 @@ class MobileNetV2Encoder(torchvision.models.MobileNetV2, EncoderMixin): def __init__( self, out_channels: List[int], depth: int = 5, output_stride: int = 32, **kwargs ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) super().__init__(**kwargs) self._depth = depth diff --git a/segmentation_models_pytorch/encoders/mobileone.py b/segmentation_models_pytorch/encoders/mobileone.py index 131675cb..3430b978 100644 --- a/segmentation_models_pytorch/encoders/mobileone.py +++ b/segmentation_models_pytorch/encoders/mobileone.py @@ -319,6 +319,11 @@ def __init__( :param use_se: Whether to use SE-ReLU activations. :param num_conv_branches: Number of linear conv branches. 
""" + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) + super().__init__() assert len(width_multipliers) == 4 diff --git a/segmentation_models_pytorch/encoders/resnet.py b/segmentation_models_pytorch/encoders/resnet.py index fc1665dd..383af002 100644 --- a/segmentation_models_pytorch/encoders/resnet.py +++ b/segmentation_models_pytorch/encoders/resnet.py @@ -38,6 +38,10 @@ class ResNetEncoder(ResNet, EncoderMixin): def __init__( self, out_channels: List[int], depth: int = 5, output_stride: int = 32, **kwargs ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) super().__init__(**kwargs) self._depth = depth diff --git a/segmentation_models_pytorch/encoders/senet.py b/segmentation_models_pytorch/encoders/senet.py index a3b44877..ff900742 100644 --- a/segmentation_models_pytorch/encoders/senet.py +++ b/segmentation_models_pytorch/encoders/senet.py @@ -43,6 +43,10 @@ def __init__( output_stride: int = 32, **kwargs, ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) super().__init__(**kwargs) self._depth = depth diff --git a/segmentation_models_pytorch/encoders/timm_efficientnet.py b/segmentation_models_pytorch/encoders/timm_efficientnet.py index 7cd52923..0dbb90b0 100644 --- a/segmentation_models_pytorch/encoders/timm_efficientnet.py +++ b/segmentation_models_pytorch/encoders/timm_efficientnet.py @@ -105,6 +105,10 @@ def __init__( output_stride: int = 32, **kwargs, ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) super().__init__(**kwargs) self._stage_idxs = stage_idxs diff --git a/segmentation_models_pytorch/encoders/timm_sknet.py b/segmentation_models_pytorch/encoders/timm_sknet.py index a28e6330..12fdd822 100644 --- a/segmentation_models_pytorch/encoders/timm_sknet.py +++ b/segmentation_models_pytorch/encoders/timm_sknet.py @@ -14,6 +14,10 @@ def __init__( output_stride: int = 32, **kwargs, ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) super().__init__(**kwargs) self._depth = depth diff --git a/segmentation_models_pytorch/encoders/timm_universal.py b/segmentation_models_pytorch/encoders/timm_universal.py index 759ede51..299178f2 100644 --- a/segmentation_models_pytorch/encoders/timm_universal.py +++ b/segmentation_models_pytorch/encoders/timm_universal.py @@ -68,6 +68,13 @@ def __init__( output_stride (int): Desired output stride (default: 32). **kwargs: Additional arguments passed to `timm.create_model`. """ + # At the moment we do not support models with more than 5 stages, + # but can be reconfigured in the future. 
+ if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) + super().__init__() self.name = name diff --git a/segmentation_models_pytorch/encoders/vgg.py b/segmentation_models_pytorch/encoders/vgg.py index 82c9c431..5b89a50a 100644 --- a/segmentation_models_pytorch/encoders/vgg.py +++ b/segmentation_models_pytorch/encoders/vgg.py @@ -53,6 +53,10 @@ def __init__( output_stride: int = 32, **kwargs, ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) super().__init__(make_layers(config, batch_norm=batch_norm), **kwargs) self._depth = depth diff --git a/segmentation_models_pytorch/encoders/xception.py b/segmentation_models_pytorch/encoders/xception.py index ab78d6ac..f81dc959 100644 --- a/segmentation_models_pytorch/encoders/xception.py +++ b/segmentation_models_pytorch/encoders/xception.py @@ -1,10 +1,22 @@ +from typing import List from pretrainedmodels.models.xception import Xception from ._base import EncoderMixin class XceptionEncoder(Xception, EncoderMixin): - def __init__(self, out_channels, *args, depth=5, output_stride=32, **kwargs): + def __init__( + self, + out_channels: List[int], + *args, + depth: int = 5, + output_stride: int = 32, + **kwargs, + ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) super().__init__(*args, **kwargs) self._depth = depth diff --git a/tests/encoders/base.py b/tests/encoders/base.py index b1113437..b18be2a9 100644 --- a/tests/encoders/base.py +++ b/tests/encoders/base.py @@ -149,6 +149,12 @@ def test_depth(self): f"Encoder `{encoder_name}` should have {depth + 1} out_channels, but has {len(encoder.out_channels)}", ) + def test_invalid_depth(self): + with self.assertRaises(ValueError): + smp.encoders.get_encoder(self.encoder_names[0], depth=6) + with self.assertRaises(ValueError): + smp.encoders.get_encoder(self.encoder_names[0], depth=0) + def test_dilated(self): sample = self._get_sample().to(default_device) From 7bb9d378cd33a261e6bff16440848d49d97e2188 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Wed, 15 Jan 2025 12:53:33 +0000 Subject: [PATCH 56/57] Update segmentation_models_pytorch/encoders/__init__.py Co-authored-by: Adam J. Stewart --- segmentation_models_pytorch/encoders/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/segmentation_models_pytorch/encoders/__init__.py b/segmentation_models_pytorch/encoders/__init__.py index 9f1f3be5..c087b17d 100644 --- a/segmentation_models_pytorch/encoders/__init__.py +++ b/segmentation_models_pytorch/encoders/__init__.py @@ -66,7 +66,8 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, ** if name.startswith("timm-"): warnings.warn( "`timm-` encoders are deprecated and will be removed in the future. " - "Please use `tu-` equivalent encoders instead (see 'Timm encoders' section in the documentation)." 
+ "Please use `tu-` equivalent encoders instead (see 'Timm encoders' section in the documentation).", + DeprecationWarning ) # convert timm- models to tu- models From da24de9b6b9c8307364ae0cd1068c481dcccb94f Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Wed, 15 Jan 2025 12:57:17 +0000 Subject: [PATCH 57/57] Style --- segmentation_models_pytorch/encoders/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/segmentation_models_pytorch/encoders/__init__.py b/segmentation_models_pytorch/encoders/__init__.py index c087b17d..3d71f49a 100644 --- a/segmentation_models_pytorch/encoders/__init__.py +++ b/segmentation_models_pytorch/encoders/__init__.py @@ -67,7 +67,7 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, ** warnings.warn( "`timm-` encoders are deprecated and will be removed in the future. " "Please use `tu-` equivalent encoders instead (see 'Timm encoders' section in the documentation).", - DeprecationWarning + DeprecationWarning, ) # convert timm- models to tu- models
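
For reference, a minimal sketch of the two behaviours added at the end of this series, the encoder depth validation and the `timm-` deprecation warning. It assumes segmentation_models_pytorch with these patches applied (plus timm for the deprecated encoder name); the encoder names and the depth value are arbitrary illustrative choices:

    import warnings
    import segmentation_models_pytorch as smp

    # Encoder depth outside [1, 5] is now rejected at construction time.
    try:
        smp.encoders.get_encoder("resnet18", depth=6, weights=None)
    except ValueError as err:
        print(err)  # e.g. "ResNetEncoder depth should be in range [1, 5], got 6"

    # Deprecated "timm-" names still resolve, but now emit a DeprecationWarning.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        smp.encoders.get_encoder("timm-efficientnet-b0", weights=None)
    print([w.category.__name__ for w in caught])  # expect "DeprecationWarning" in the list
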