Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable any resolution for Unet #1029

Merged
merged 7 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion segmentation_models_pytorch/base/model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import torch
from typing import TypeVar, Type

from . import initialization as init
from .hub_mixin import SMPHubMixin

T = TypeVar("T", bound="SegmentationModel")


class SegmentationModel(torch.nn.Module, SMPHubMixin):
"""Base class for all segmentation models."""
Expand All @@ -11,6 +14,11 @@ class SegmentationModel(torch.nn.Module, SMPHubMixin):
# set to False
requires_divisible_input_shape = True

# Override __new__ so type checkers infer the concrete model class instead of the HubMixin signature
def __new__(cls: Type[T], *args, **kwargs) -> T:
instance = super().__new__(cls, *args, **kwargs)
return instance

def initialize(self):
init.initialize_decoder(self.decoder)
init.initialize_head(self.segmentation_head)
Expand Down Expand Up @@ -42,7 +50,7 @@ def check_input_shape(self, x):
def forward(self, x):
"""Sequentially pass `x` through the model's encoder, decoder and heads"""

if not torch.jit.is_tracing() or self.requires_divisible_input_shape:
qubvel marked this conversation as resolved.
Show resolved Hide resolved
if not torch.jit.is_tracing() and self.requires_divisible_input_shape:
self.check_input_shape(x)

features = self.encoder(x)
Expand Down
110 changes: 74 additions & 36 deletions segmentation_models_pytorch/decoders/unet/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,24 @@
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional, Sequence
from segmentation_models_pytorch.base import modules as md


class DecoderBlock(nn.Module):
class UnetDecoderBlock(nn.Module):
"""A decoder block in the U-Net architecture that performs upsampling and feature fusion."""

def __init__(
self,
in_channels,
skip_channels,
out_channels,
use_batchnorm=True,
attention_type=None,
in_channels: int,
skip_channels: int,
out_channels: int,
use_batchnorm: bool = True,
attention_type: Optional[str] = None,
interpolation_mode: str = "nearest",
):
super().__init__()
self.interpolation_mode = interpolation_mode
self.conv1 = md.Conv2dReLU(
in_channels + skip_channels,
out_channels,
Expand All @@ -34,19 +39,31 @@
)
self.attention2 = md.Attention(attention_type, in_channels=out_channels)

def forward(self, x, skip=None):
x = F.interpolate(x, scale_factor=2, mode="nearest")
if skip is not None:
x = torch.cat([x, skip], dim=1)
x = self.attention1(x)
x = self.conv1(x)
x = self.conv2(x)
x = self.attention2(x)
return x
def forward(
self,
feature_map: torch.Tensor,
target_height: int,
target_width: int,
skip_connection: Optional[torch.Tensor] = None,
) -> torch.Tensor:
feature_map = F.interpolate(
feature_map,
size=(target_height, target_width),
mode=self.interpolation_mode,
)
if skip_connection is not None:
feature_map = torch.cat([feature_map, skip_connection], dim=1)
feature_map = self.attention1(feature_map)
feature_map = self.conv1(feature_map)
feature_map = self.conv2(feature_map)
feature_map = self.attention2(feature_map)
return feature_map


class UnetCenterBlock(nn.Sequential):
"""Center block of the Unet decoder. Applied to the last feature map of the encoder."""

class CenterBlock(nn.Sequential):
def __init__(self, in_channels, out_channels, use_batchnorm=True):
def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
conv1 = md.Conv2dReLU(
in_channels,
out_channels,
Expand All @@ -65,14 +82,21 @@


class UnetDecoder(nn.Module):
"""The decoder part of the U-Net architecture.

Takes encoded features from different stages of the encoder and progressively upsamples them while
combining with skip connections. This helps preserve fine-grained details in the final segmentation.
"""

def __init__(
self,
encoder_channels,
decoder_channels,
n_blocks=5,
use_batchnorm=True,
attention_type=None,
center=False,
encoder_channels: Sequence[int],
decoder_channels: Sequence[int],
n_blocks: int = 5,
use_batchnorm: bool = True,
attention_type: Optional[str] = None,
add_center_block: bool = False,
interpolation_mode: str = "nearest",
):
super().__init__()

Expand All @@ -94,31 +118,45 @@
skip_channels = list(encoder_channels[1:]) + [0]
out_channels = decoder_channels

if center:
self.center = CenterBlock(
if add_center_block:
self.center = UnetCenterBlock(

Check warning on line 122 in segmentation_models_pytorch/decoders/unet/decoder.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/decoders/unet/decoder.py#L122

Added line #L122 was not covered by tests
head_channels, head_channels, use_batchnorm=use_batchnorm
)
else:
self.center = nn.Identity()

# combine decoder keyword arguments
kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
blocks = [
DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
]
self.blocks = nn.ModuleList(blocks)

def forward(self, *features):
self.blocks = nn.ModuleList()
for block_in_channels, block_skip_channels, block_out_channels in zip(
in_channels, skip_channels, out_channels
):
block = UnetDecoderBlock(
block_in_channels,
block_skip_channels,
block_out_channels,
use_batchnorm=use_batchnorm,
attention_type=attention_type,
interpolation_mode=interpolation_mode,
)
self.blocks.append(block)

def forward(self, *features: torch.Tensor) -> torch.Tensor:
# spatial shapes of features: [hw, hw/2, hw/4, hw/8, ...]
spatial_shapes = [feature.shape[2:] for feature in features]
spatial_shapes = spatial_shapes[::-1]

features = features[1:] # remove first skip with same spatial resolution
features = features[::-1] # reverse channels to start from head of encoder

head = features[0]
skips = features[1:]
skip_connections = features[1:]

x = self.center(head)

for i, decoder_block in enumerate(self.blocks):
skip = skips[i] if i < len(skips) else None
x = decoder_block(x, skip)
# upsample to the next spatial shape
height, width = spatial_shapes[i + 1]
skip_connection = skip_connections[i] if i < len(skip_connections) else None
x = decoder_block(x, height, width, skip_connection=skip_connection)

return x
50 changes: 43 additions & 7 deletions segmentation_models_pytorch/decoders/unet/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional, Union, Tuple, Callable
from typing import Any, Optional, Union, Callable, Sequence

from segmentation_models_pytorch.base import (
ClassificationHead,
Expand All @@ -12,10 +12,21 @@


class Unet(SegmentationModel):
"""Unet_ is a fully convolutional neural network for semantic image segmentation. It consists of *encoder*
and *decoder* parts connected with *skip connections*. The encoder extracts features of different spatial
resolutions (skip connections), which are used by the decoder to define an accurate segmentation mask. *Concatenation*
is used for fusing decoder blocks with skip connections.
"""
U-Net is a fully convolutional neural network architecture designed for semantic image segmentation.

It consists of two main parts:

1. An encoder (downsampling path) that extracts increasingly abstract features
2. A decoder (upsampling path) that gradually recovers spatial details

The key is the use of skip connections between corresponding encoder and decoder layers.
These connections allow the decoder to access fine-grained details from earlier encoder layers,
which helps produce more precise segmentation masks.

The skip connections work by concatenating feature maps from the encoder directly into the decoder
at corresponding resolutions. This helps preserve important spatial information that would
otherwise be lost during the encoding process.

Args:
encoder_name: Name of the classification model that will be used as an encoder (a.k.a. backbone)
Expand All @@ -33,6 +44,8 @@ class Unet(SegmentationModel):
Available options are **True, False, "inplace"**
decoder_attention_type: Attention module used in decoder of the model. Available options are
**None** and **scse** (https://arxiv.org/abs/1808.08127).
decoder_interpolation_mode: Interpolation mode used in decoder of the model. Available options are
**"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"nearest"**.
in_channels: The number of input channels for the model; default is 3 (RGB images)
classes: The number of classes for the output mask (equivalently, the number of channels of the output mask)
activation: An activation function to apply after the final convolution layer.
Expand All @@ -51,20 +64,41 @@ class Unet(SegmentationModel):
Returns:
``torch.nn.Module``: Unet

Example:
.. code-block:: python

import torch
import segmentation_models_pytorch as smp

model = smp.Unet("resnet18", encoder_weights="imagenet", classes=5)
model.eval()

# generate random images
images = torch.rand(2, 3, 256, 256)

with torch.inference_mode():
mask = model(images)

print(mask.shape)
# torch.Size([2, 5, 256, 256])

.. _Unet:
https://arxiv.org/abs/1505.04597

"""

requires_divisible_input_shape = False

@supports_config_loading
def __init__(
self,
encoder_name: str = "resnet34",
encoder_depth: int = 5,
encoder_weights: Optional[str] = "imagenet",
decoder_use_batchnorm: bool = True,
decoder_channels: Tuple[int, ...] = (256, 128, 64, 32, 16),
decoder_channels: Sequence[int] = (256, 128, 64, 32, 16),
decoder_attention_type: Optional[str] = None,
decoder_interpolation_mode: str = "nearest",
in_channels: int = 3,
classes: int = 1,
activation: Optional[Union[str, Callable]] = None,
Expand All @@ -81,13 +115,15 @@ def __init__(
**kwargs,
)

add_center_block = encoder_name.startswith("vgg")
self.decoder = UnetDecoder(
encoder_channels=self.encoder.out_channels,
decoder_channels=decoder_channels,
n_blocks=encoder_depth,
use_batchnorm=decoder_use_batchnorm,
center=True if encoder_name.startswith("vgg") else False,
add_center_block=add_center_block,
attention_type=decoder_attention_type,
interpolation_mode=decoder_interpolation_mode,
)

self.segmentation_head = SegmentationHead(
Expand Down
6 changes: 3 additions & 3 deletions tests/encoders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_in_channels(self):
encoder.eval()

# forward
with torch.no_grad():
with torch.inference_mode():
encoder.forward(sample)

def test_depth(self):
Expand Down Expand Up @@ -110,7 +110,7 @@ def test_depth(self):
encoder.eval()

# forward
with torch.no_grad():
with torch.inference_mode():
features = encoder.forward(sample)

# check number of features
Expand Down Expand Up @@ -187,7 +187,7 @@ def test_dilated(self):
encoder.eval()

# forward
with torch.no_grad():
with torch.inference_mode():
features = encoder.forward(sample)

height_strides, width_strides = self.get_features_output_strides(
Expand Down
Loading
Loading