Commit
Enable any resolution for Unet (#1029)
* Fix type hint for models

* Use inference mode in tests

* Add test for any resolution (not divisible by 32)

* Use inference mode in tests

* Enable any res for Unet and better docs

* Fix check_input_shape condition

* Interpolation for unet
qubvel authored Jan 13, 2025
1 parent eaf8be6 commit 93b19d3
Showing 7 changed files with 187 additions and 75 deletions.
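The headline change: `Unet` no longer requires input height and width divisible by 32. A minimal sketch of what the commit enables, assuming the behavior described by the commit message and tests (the output shape is the expected result, not verified here):

import torch
import segmentation_models_pytorch as smp

model = smp.Unet("resnet34", encoder_weights=None, classes=1)
model.eval()

# 250 is not divisible by 32; before this commit the input shape check
# rejected such inputs for Unet
images = torch.rand(1, 3, 250, 250)
with torch.inference_mode():
    mask = model(images)

print(mask.shape)  # expected: torch.Size([1, 1, 250, 250])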
10 changes: 9 additions & 1 deletion segmentation_models_pytorch/base/model.py
@@ -1,8 +1,11 @@
import torch
+from typing import TypeVar, Type

from . import initialization as init
from .hub_mixin import SMPHubMixin

+T = TypeVar("T", bound="SegmentationModel")
+

class SegmentationModel(torch.nn.Module, SMPHubMixin):
    """Base class for all segmentation models."""
@@ -11,6 +14,11 @@ class SegmentationModel(torch.nn.Module, SMPHubMixin):
    # set to False
    requires_divisible_input_shape = True

+    # Fix type-hint for models, to avoid HubMixin signature
+    def __new__(cls: Type[T], *args, **kwargs) -> T:
+        instance = super().__new__(cls, *args, **kwargs)
+        return instance
+
    def initialize(self):
        init.initialize_decoder(self.decoder)
        init.initialize_head(self.segmentation_head)
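As context for the `__new__` override above: annotating `cls: Type[T]` and returning `T` keeps a subclass's constructor typed as the subclass instead of inheriting a looser signature from a mixin such as `SMPHubMixin`. A generic, self-contained sketch of the pattern with hypothetical `Base`/`Child` classes (not from this repo):

from typing import TypeVar, Type

T = TypeVar("T", bound="Base")

class Base:
    def __new__(cls: Type[T], *args, **kwargs) -> T:
        # object.__new__ takes no extra arguments, so only cls is
        # forwarded in this simplified sketch
        return super().__new__(cls)

class Child(Base):
    def __init__(self, value: int):
        self.value = value

child = Child(42)  # static checkers infer Child, not a mixin return type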
@@ -42,7 +50,7 @@ def check_input_shape(self, x):
    def forward(self, x):
        """Sequentially pass `x` through the model's encoder, decoder and heads"""

-        if not torch.jit.is_tracing() or self.requires_divisible_input_shape:
+        if not torch.jit.is_tracing() and self.requires_divisible_input_shape:
            self.check_input_shape(x)

        features = self.encoder(x)
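The operator fix above is subtle: with `or`, the shape check ran on every non-traced forward pass, even for models that opt out via `requires_divisible_input_shape = False`. A quick check with illustrative values:

is_tracing = False
requires_divisible = False  # e.g. a model that opts out of the check

old = (not is_tracing) or requires_divisible   # True  -> check ran anyway
new = (not is_tracing) and requires_divisible  # False -> check skipped
print(old, new)  # True False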
110 changes: 74 additions & 36 deletions segmentation_models_pytorch/decoders/unet/decoder.py
@@ -2,19 +2,24 @@
import torch.nn as nn
import torch.nn.functional as F

+from typing import Optional, Sequence
from segmentation_models_pytorch.base import modules as md


-class DecoderBlock(nn.Module):
+class UnetDecoderBlock(nn.Module):
+    """A decoder block in the U-Net architecture that performs upsampling and feature fusion."""
+
    def __init__(
        self,
-        in_channels,
-        skip_channels,
-        out_channels,
-        use_batchnorm=True,
-        attention_type=None,
+        in_channels: int,
+        skip_channels: int,
+        out_channels: int,
+        use_batchnorm: bool = True,
+        attention_type: Optional[str] = None,
+        interpolation_mode: str = "nearest",
    ):
        super().__init__()
+        self.interpolation_mode = interpolation_mode
        self.conv1 = md.Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
@@ -34,19 +39,31 @@ def __init__(
        )
        self.attention2 = md.Attention(attention_type, in_channels=out_channels)

-    def forward(self, x, skip=None):
-        x = F.interpolate(x, scale_factor=2, mode="nearest")
-        if skip is not None:
-            x = torch.cat([x, skip], dim=1)
-            x = self.attention1(x)
-        x = self.conv1(x)
-        x = self.conv2(x)
-        x = self.attention2(x)
-        return x
+    def forward(
+        self,
+        feature_map: torch.Tensor,
+        target_height: int,
+        target_width: int,
+        skip_connection: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        feature_map = F.interpolate(
+            feature_map,
+            size=(target_height, target_width),
+            mode=self.interpolation_mode,
+        )
+        if skip_connection is not None:
+            feature_map = torch.cat([feature_map, skip_connection], dim=1)
+            feature_map = self.attention1(feature_map)
+        feature_map = self.conv1(feature_map)
+        feature_map = self.conv2(feature_map)
+        feature_map = self.attention2(feature_map)
+        return feature_map
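The block now interpolates to an explicit target size rather than a fixed `scale_factor=2`, which is what lets odd-sized skip connections line up. A standalone illustration with plain tensors (shapes chosen arbitrarily):

import torch
import torch.nn.functional as F

x = torch.rand(1, 64, 32, 32)     # coarse feature map
skip = torch.rand(1, 32, 63, 63)  # odd-sized skip from the encoder

# scale_factor=2 would produce 64x64, which cannot be concatenated with
# the 63x63 skip; an explicit size matches it exactly
x = F.interpolate(x, size=skip.shape[2:], mode="nearest")
fused = torch.cat([x, skip], dim=1)
print(fused.shape)  # torch.Size([1, 96, 63, 63])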


-class CenterBlock(nn.Sequential):
-    def __init__(self, in_channels, out_channels, use_batchnorm=True):
+class UnetCenterBlock(nn.Sequential):
+    """Center block of the Unet decoder. Applied to the last feature map of the encoder."""
+
+    def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
        conv1 = md.Conv2dReLU(
            in_channels,
            out_channels,
@@ -65,14 +82,21 @@ def __init__(self, in_channels, out_channels, use_batchnorm=True):


class UnetDecoder(nn.Module):
+    """The decoder part of the U-Net architecture.
+
+    Takes encoded features from different stages of the encoder and progressively upsamples them while
+    combining with skip connections. This helps preserve fine-grained details in the final segmentation.
+    """
+
    def __init__(
        self,
-        encoder_channels,
-        decoder_channels,
-        n_blocks=5,
-        use_batchnorm=True,
-        attention_type=None,
-        center=False,
+        encoder_channels: Sequence[int],
+        decoder_channels: Sequence[int],
+        n_blocks: int = 5,
+        use_batchnorm: bool = True,
+        attention_type: Optional[str] = None,
+        add_center_block: bool = False,
+        interpolation_mode: str = "nearest",
    ):
        super().__init__()

@@ -94,31 +118,45 @@ def __init__(
        skip_channels = list(encoder_channels[1:]) + [0]
        out_channels = decoder_channels

-        if center:
-            self.center = CenterBlock(
+        if add_center_block:
+            self.center = UnetCenterBlock(
                head_channels, head_channels, use_batchnorm=use_batchnorm
            )
        else:
            self.center = nn.Identity()

-        # combine decoder keyword arguments
-        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
-        blocks = [
-            DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
-            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
-        ]
-        self.blocks = nn.ModuleList(blocks)
-
-    def forward(self, *features):
+        self.blocks = nn.ModuleList()
+        for block_in_channels, block_skip_channels, block_out_channels in zip(
+            in_channels, skip_channels, out_channels
+        ):
+            block = UnetDecoderBlock(
+                block_in_channels,
+                block_skip_channels,
+                block_out_channels,
+                use_batchnorm=use_batchnorm,
+                attention_type=attention_type,
+                interpolation_mode=interpolation_mode,
+            )
+            self.blocks.append(block)
+
+    def forward(self, *features: torch.Tensor) -> torch.Tensor:
+        # spatial shapes of features: [hw, hw/2, hw/4, hw/8, ...]
+        spatial_shapes = [feature.shape[2:] for feature in features]
+        spatial_shapes = spatial_shapes[::-1]

        features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder

        head = features[0]
-        skips = features[1:]
+        skip_connections = features[1:]

        x = self.center(head)

        for i, decoder_block in enumerate(self.blocks):
-            skip = skips[i] if i < len(skips) else None
-            x = decoder_block(x, skip)
+            # upsample to the next spatial shape
+            height, width = spatial_shapes[i + 1]
+            skip_connection = skip_connections[i] if i < len(skip_connections) else None
+            x = decoder_block(x, height, width, skip_connection=skip_connection)

        return x
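To see why the recorded spatial shapes matter, here is a sketch of the encoder pyramid for a non-divisible input. The sizes shown are what a resnet34-style encoder would typically produce; exact values depend on the encoder's strides and padding:

import torch
import segmentation_models_pytorch as smp

model = smp.Unet("resnet34", encoder_weights=None)
with torch.inference_mode():
    features = model.encoder(torch.rand(1, 3, 250, 250))

print([tuple(f.shape[2:]) for f in features])
# e.g. [(250, 250), (125, 125), (63, 63), (32, 32), (16, 16), (8, 8)]
# 63 != 2 * 32, so a fixed 2x upsample could never realign these skips;
# interpolating to each recorded shape can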
50 changes: 43 additions & 7 deletions segmentation_models_pytorch/decoders/unet/model.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union, Tuple, Callable
+from typing import Any, Optional, Union, Callable, Sequence

from segmentation_models_pytorch.base import (
    ClassificationHead,
@@ -12,10 +21,10 @@


class Unet(SegmentationModel):
-    """Unet_ is a fully convolution neural network for image semantic segmentation. Consist of *encoder*
-    and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial
-    resolution (skip connections) which are used by decoder to define accurate segmentation mask. Use *concatenation*
-    for fusing decoder blocks with skip connections.
+    """
+    U-Net is a fully convolutional neural network architecture designed for semantic image segmentation.
+
+    It consists of two main parts:
+
+    1. An encoder (downsampling path) that extracts increasingly abstract features
+    2. A decoder (upsampling path) that gradually recovers spatial details
+
+    The key is the use of skip connections between corresponding encoder and decoder layers.
+    These connections allow the decoder to access fine-grained details from earlier encoder layers,
+    which helps produce more precise segmentation masks.
+
+    The skip connections work by concatenating feature maps from the encoder directly into the decoder
+    at corresponding resolutions. This helps preserve important spatial information that would
+    otherwise be lost during the encoding process.

    Args:
        encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
@@ -33,6 +44,8 @@ class Unet(SegmentationModel):
            Available options are **True, False, "inplace"**
        decoder_attention_type: Attention module used in decoder of the model. Available options are
            **None** and **scse** (https://arxiv.org/abs/1808.08127).
+        decoder_interpolation_mode: Interpolation mode used in decoder of the model. Available options are
+            **"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"nearest"**.
        in_channels: A number of input channels for the model, default is 3 (RGB images)
        classes: A number of classes for output mask (or you can think as a number of channels of output mask)
        activation: An activation function to apply after the final convolution layer.
@@ -51,20 +64,41 @@ class Unet(SegmentationModel):
    Returns:
        ``torch.nn.Module``: Unet

+    Example:
+
+        .. code-block:: python
+
+            import torch
+            import segmentation_models_pytorch as smp
+
+            model = smp.Unet("resnet18", encoder_weights="imagenet", classes=5)
+            model.eval()
+
+            # generate random images
+            images = torch.rand(2, 3, 256, 256)
+
+            with torch.inference_mode():
+                mask = model(images)
+
+            print(mask.shape)
+            # torch.Size([2, 5, 256, 256])
+
    .. _Unet:
        https://arxiv.org/abs/1505.04597

    """

+    requires_divisible_input_shape = False
+
    @supports_config_loading
    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_use_batchnorm: bool = True,
-        decoder_channels: Tuple[int, ...] = (256, 128, 64, 32, 16),
+        decoder_channels: Sequence[int] = (256, 128, 64, 32, 16),
        decoder_attention_type: Optional[str] = None,
+        decoder_interpolation_mode: str = "nearest",
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, Callable]] = None,
@@ -81,13 +115,15 @@ def __init__(
            **kwargs,
        )

+        add_center_block = encoder_name.startswith("vgg")
        self.decoder = UnetDecoder(
            encoder_channels=self.encoder.out_channels,
            decoder_channels=decoder_channels,
            n_blocks=encoder_depth,
            use_batchnorm=decoder_use_batchnorm,
-            center=True if encoder_name.startswith("vgg") else False,
+            add_center_block=add_center_block,
            attention_type=decoder_attention_type,
+            interpolation_mode=decoder_interpolation_mode,
        )

        self.segmentation_head = SegmentationHead(
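A short usage sketch for the new `decoder_interpolation_mode` parameter introduced above. Bilinear upsampling tends to give smoother mask boundaries at some extra cost; `"nearest"` remains the default:

import segmentation_models_pytorch as smp

model = smp.Unet(
    "resnet34",
    encoder_weights=None,
    classes=3,
    # "nearest", "bilinear", "bicubic", "area", "nearest-exact" are accepted
    decoder_interpolation_mode="bilinear",
)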
6 changes: 3 additions & 3 deletions tests/encoders/base.py
@@ -82,7 +82,7 @@ def test_in_channels(self):
        encoder.eval()

        # forward
-        with torch.no_grad():
+        with torch.inference_mode():
            encoder.forward(sample)

    def test_depth(self):
@@ -110,7 +110,7 @@ def test_depth(self):
        encoder.eval()

        # forward
-        with torch.no_grad():
+        with torch.inference_mode():
            features = encoder.forward(sample)

        # check number of features
@@ -187,7 +187,7 @@ def test_dilated(self):
        encoder.eval()

        # forward
-        with torch.no_grad():
+        with torch.inference_mode():
            features = encoder.forward(sample)

        height_strides, width_strides = self.get_features_output_strides(
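The test updates swap `torch.no_grad()` for `torch.inference_mode()`, which disables autograd bookkeeping more aggressively (no version counters or view tracking) and is typically a little faster for pure inference. A minimal illustration:

import torch

x = torch.ones(3, requires_grad=True)
with torch.inference_mode():
    y = x * 2

print(y.requires_grad)   # False
print(y.is_inference())  # True: y can never be used in autograd later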
