
Enable any resolution for Unet #1029

Merged: 7 commits, Jan 13, 2025

Changes from 6 commits
10 changes: 9 additions & 1 deletion segmentation_models_pytorch/base/model.py
@@ -1,8 +1,11 @@
import torch
from typing import TypeVar, Type

from . import initialization as init
from .hub_mixin import SMPHubMixin

T = TypeVar("T", bound="SegmentationModel")


class SegmentationModel(torch.nn.Module, SMPHubMixin):
"""Base class for all segmentation models."""
@@ -11,6 +14,11 @@ class SegmentationModel(torch.nn.Module, SMPHubMixin):
# set to False
requires_divisible_input_shape = True

# Fix type-hint for models, to avoid HubMixin signature
def __new__(cls: Type[T], *args, **kwargs) -> T:
instance = super().__new__(cls, *args, **kwargs)
return instance

def initialize(self):
init.initialize_decoder(self.decoder)
init.initialize_head(self.segmentation_head)
@@ -42,7 +50,7 @@ def check_input_shape(self, x):
def forward(self, x):
"""Sequentially pass `x` trough model`s encoder, decoder and heads"""

-        if not torch.jit.is_tracing() or self.requires_divisible_input_shape:
+        if not torch.jit.is_tracing() and self.requires_divisible_input_shape:
self.check_input_shape(x)

features = self.encoder(x)
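The change to `forward` above is a one-word fix: `or` became `and`, so the divisibility check runs only in eager mode and only for models that actually require divisible input shapes. A standalone sketch of the corrected predicate (not part of the diff, just an illustration):

```python
import torch

# Corrected guard: check input shape only when NOT tracing AND the model
# requires inputs divisible by its output stride. With the previous `or`,
# the check also ran for models that set requires_divisible_input_shape
# to False, defeating the purpose of the flag.
def should_check_input_shape(requires_divisible_input_shape: bool) -> bool:
    return (not torch.jit.is_tracing()) and requires_divisible_input_shape

assert should_check_input_shape(True)       # eager mode, divisibility required
assert not should_check_input_shape(False)  # eager mode, flag disabled -> skip
```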
95 changes: 62 additions & 33 deletions segmentation_models_pytorch/decoders/unet/decoder.py
@@ -2,19 +2,22 @@
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional, Sequence
from segmentation_models_pytorch.base import modules as md


class DecoderBlock(nn.Module):
def __init__(
self,
-        in_channels,
-        skip_channels,
-        out_channels,
-        use_batchnorm=True,
-        attention_type=None,
+        in_channels: int,
+        skip_channels: int,
+        out_channels: int,
+        use_batchnorm: bool = True,
+        attention_type: Optional[str] = None,
+        interpolation_mode: str = "nearest",
):
super().__init__()
+        self.interpolate_mode = interpolation_mode

Review comment (Collaborator): It's a bit confusing to me to see both "interpolate" and "interpolation", maybe we can make these consistent?

Reply (Author): Thanks for the catch! Will fix it.


self.conv1 = md.Conv2dReLU(
in_channels + skip_channels,
out_channels,
@@ -34,19 +37,32 @@ def __init__(
)
self.attention2 = md.Attention(attention_type, in_channels=out_channels)

-    def forward(self, x, skip=None):
-        x = F.interpolate(x, scale_factor=2, mode="nearest")
-        if skip is not None:
-            x = torch.cat([x, skip], dim=1)
-            x = self.attention1(x)
-        x = self.conv1(x)
-        x = self.conv2(x)
-        x = self.attention2(x)
-        return x
+    def forward(
+        self,
+        feature_map: torch.Tensor,
+        target_height: int,
+        target_width: int,
+        skip_connection: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Upsample feature map to the given spatial shape, concatenate with skip connection,
+        apply attention block (if specified) and then apply two convolutions.
+        """
+        feature_map = F.interpolate(
+            feature_map, size=(target_height, target_width), mode=self.interpolate_mode
+        )
+        if skip_connection is not None:
+            feature_map = torch.cat([feature_map, skip_connection], dim=1)
+            feature_map = self.attention1(feature_map)
+        feature_map = self.conv1(feature_map)
+        feature_map = self.conv2(feature_map)
+        feature_map = self.attention2(feature_map)
+        return feature_map


class CenterBlock(nn.Sequential):
-    def __init__(self, in_channels, out_channels, use_batchnorm=True):
+    """Center block of the Unet decoder. Applied to the last feature map of the encoder."""
+
+    def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
conv1 = md.Conv2dReLU(
in_channels,
out_channels,
@@ -67,12 +83,12 @@ def __init__(self, in_channels, out_channels, use_batchnorm=True):
class UnetDecoder(nn.Module):
def __init__(
self,
-        encoder_channels,
-        decoder_channels,
-        n_blocks=5,
-        use_batchnorm=True,
-        attention_type=None,
-        center=False,
+        encoder_channels: Sequence[int],
+        decoder_channels: Sequence[int],
+        n_blocks: int = 5,
+        use_batchnorm: bool = True,
+        attention_type: Optional[str] = None,
+        add_center_block: bool = False,
):
super().__init__()

@@ -94,31 +110,44 @@ def __init__(
skip_channels = list(encoder_channels[1:]) + [0]
out_channels = decoder_channels

-        if add_center_block := center:
+        if add_center_block:
self.center = CenterBlock(
head_channels, head_channels, use_batchnorm=use_batchnorm
)
else:
self.center = nn.Identity()

-        # combine decoder keyword arguments
-        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
-        blocks = [
-            DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
-            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
-        ]
-        self.blocks = nn.ModuleList(blocks)
+        self.blocks = nn.ModuleList()
+        for block_in_channels, block_skip_channels, block_out_channels in zip(
+            in_channels, skip_channels, out_channels
+        ):
+            block = DecoderBlock(
+                block_in_channels,
+                block_skip_channels,
+                block_out_channels,
+                use_batchnorm=use_batchnorm,
+                attention_type=attention_type,
+            )
+            self.blocks.append(block)

-    def forward(self, *features):
+    def forward(self, *features: torch.Tensor) -> torch.Tensor:
+        # spatial shapes of features: [hw, hw/2, hw/4, hw/8, ...]
+        spatial_shapes = [feature.shape[2:] for feature in features]
+        spatial_shapes = spatial_shapes[::-1]

features = features[1:] # remove first skip with same spatial resolution
features = features[::-1] # reverse channels to start from head of encoder

head = features[0]
-        skips = features[1:]
+        skip_connections = features[1:]

x = self.center(head)

for i, decoder_block in enumerate(self.blocks):
-            skip = skips[i] if i < len(skips) else None
-            x = decoder_block(x, skip)
+            # upsample to the next spatial shape
+            height, width = spatial_shapes[i + 1]
+            skip_connection = skip_connections[i] if i < len(skip_connections) else None
+            x = decoder_block(x, height, width, skip_connection=skip_connection)

return x
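The reworked decoder above is the core of "any resolution" support: each `DecoderBlock` now upsamples to the exact spatial shape of the next encoder feature instead of assuming a clean factor of 2. A standalone sketch of why that matters (shapes are illustrative; stride-2 encoder stages round odd sizes up via padding):

```python
import torch
import torch.nn.functional as F

skip = torch.rand(1, 64, 125, 125)  # shallower encoder feature (skip connection)
deep = torch.rand(1, 128, 63, 63)   # deeper feature: a stride-2 stage maps 125 -> 63

# Old behavior: doubling yields 126x126, which cannot be concatenated
# with the 125x125 skip connection.
doubled = F.interpolate(deep, scale_factor=2, mode="nearest")
print(doubled.shape[2:])  # torch.Size([126, 126])

# New behavior: upsample to the skip's exact spatial shape, so torch.cat works.
matched = F.interpolate(deep, size=skip.shape[2:], mode="nearest")
fused = torch.cat([matched, skip], dim=1)
print(fused.shape)  # torch.Size([1, 192, 125, 125])
```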
46 changes: 39 additions & 7 deletions segmentation_models_pytorch/decoders/unet/model.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union, Tuple, Callable
+from typing import Any, Optional, Union, Callable, Sequence

from segmentation_models_pytorch.base import (
ClassificationHead,
@@ -12,10 +12,21 @@


class Unet(SegmentationModel):
"""Unet_ is a fully convolution neural network for image semantic segmentation. Consist of *encoder*
and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial
resolution (skip connections) which are used by decoder to define accurate segmentation mask. Use *concatenation*
for fusing decoder blocks with skip connections.
"""
U-Net is a fully convolutional neural network architecture designed for semantic image segmentation.

It consists of two main parts:

1. An encoder (downsampling path) that extracts increasingly abstract features
2. A decoder (upsampling path) that gradually recovers spatial details

The key is the use of skip connections between corresponding encoder and decoder layers.
These connections allow the decoder to access fine-grained details from earlier encoder layers,
which helps produce more precise segmentation masks.

The skip connections work by concatenating feature maps from the encoder directly into the decoder
at corresponding resolutions. This helps preserve important spatial information that would
otherwise be lost during the encoding process.

Args:
encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
@@ -51,19 +62,39 @@ class Unet(SegmentationModel):
Returns:
``torch.nn.Module``: Unet

Example:
.. code-block:: python

import torch
import segmentation_models_pytorch as smp

model = smp.Unet("resnet18", encoder_weights="imagenet", classes=5)
model.eval()

# generate random images
images = torch.rand(2, 3, 256, 256)

with torch.inference_mode():
mask = model(images)

print(mask.shape)
# torch.Size([2, 5, 256, 256])

.. _Unet:
https://arxiv.org/abs/1505.04597

"""

requires_divisible_input_shape = False

@supports_config_loading
def __init__(
self,
encoder_name: str = "resnet34",
encoder_depth: int = 5,
encoder_weights: Optional[str] = "imagenet",
decoder_use_batchnorm: bool = True,
-        decoder_channels: Tuple[int, ...] = (256, 128, 64, 32, 16),
+        decoder_channels: Sequence[int] = (256, 128, 64, 32, 16),
decoder_attention_type: Optional[str] = None,
in_channels: int = 3,
classes: int = 1,
@@ -81,12 +112,13 @@ def __init__(
**kwargs,
)

+        add_center_block = encoder_name.startswith("vgg")
        self.decoder = UnetDecoder(
            encoder_channels=self.encoder.out_channels,
            decoder_channels=decoder_channels,
            n_blocks=encoder_depth,
            use_batchnorm=decoder_use_batchnorm,
-            center=True if encoder_name.startswith("vgg") else False,
+            add_center_block=add_center_block,
            attention_type=decoder_attention_type,
        )

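Together with `requires_divisible_input_shape = False`, the new decoder lets `Unet` accept inputs whose sides are not divisible by the encoder output stride (32 for most backbones). A quick sketch of the behavior this PR enables, mirroring the docstring example above:

```python
import torch
import segmentation_models_pytorch as smp

model = smp.Unet("resnet18", encoder_weights=None, classes=2)
model.eval()

# 250 is not divisible by 32; previously check_input_shape rejected such inputs.
images = torch.rand(1, 3, 250, 250)

with torch.inference_mode():
    mask = model(images)

print(mask.shape)  # torch.Size([1, 2, 250, 250])
```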
6 changes: 3 additions & 3 deletions tests/encoders/base.py
@@ -82,7 +82,7 @@ def test_in_channels(self):
encoder.eval()

# forward
-        with torch.no_grad():
+        with torch.inference_mode():
encoder.forward(sample)

def test_depth(self):
@@ -110,7 +110,7 @@ def test_depth(self):
encoder.eval()

# forward
-        with torch.no_grad():
+        with torch.inference_mode():
features = encoder.forward(sample)

# check number of features
@@ -187,7 +187,7 @@ def test_dilated(self):
encoder.eval()

# forward
-        with torch.no_grad():
+        with torch.inference_mode():
features = encoder.forward(sample)

height_strides, width_strides = self.get_features_output_strides(
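The test updates swap `torch.no_grad()` for `torch.inference_mode()`. Both disable gradient tracking, but inference mode is the stricter (and usually faster) context: tensors created inside it also skip view tracking and version counting. A minimal illustration in plain PyTorch, unrelated to smp internals:

```python
import torch

x = torch.rand(2, 3)

with torch.inference_mode():
    y = x * 2
print(y.requires_grad, y.is_inference())  # False True

with torch.no_grad():
    z = x * 2
print(z.requires_grad, z.is_inference())  # False False
```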