Interpolation for unet
qubvel committed Jan 13, 2025
1 parent eb81c1f commit d5a80df
Showing 2 changed files with 22 additions and 9 deletions.
27 changes: 18 additions & 9 deletions segmentation_models_pytorch/decoders/unet/decoder.py
@@ -6,7 +6,9 @@
 from segmentation_models_pytorch.base import modules as md


-class DecoderBlock(nn.Module):
+class UnetDecoderBlock(nn.Module):
+    """A decoder block in the U-Net architecture that performs upsampling and feature fusion."""
+
     def __init__(
         self,
         in_channels: int,
@@ -17,7 +19,7 @@ def __init__(
         interpolation_mode: str = "nearest",
     ):
         super().__init__()
-        self.interpolate_mode = interpolation_mode
+        self.interpolation_mode = interpolation_mode
         self.conv1 = md.Conv2dReLU(
             in_channels + skip_channels,
             out_channels,
@@ -44,11 +46,10 @@ def forward(
         target_width: int,
         skip_connection: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        """Upsample feature map to the given spatial shape, concatenate with skip connection,
-        apply attention block (if specified) and then apply two convolutions.
-        """
         feature_map = F.interpolate(
-            feature_map, size=(target_height, target_width), mode=self.interpolate_mode
+            feature_map,
+            size=(target_height, target_width),
+            mode=self.interpolation_mode,
         )
         if skip_connection is not None:
             feature_map = torch.cat([feature_map, skip_connection], dim=1)
@@ -59,7 +60,7 @@ def forward(
         return feature_map


-class CenterBlock(nn.Sequential):
+class UnetCenterBlock(nn.Sequential):
     """Center block of the Unet decoder. Applied to the last feature map of the encoder."""

     def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
@@ -81,6 +82,12 @@ def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):


 class UnetDecoder(nn.Module):
+    """The decoder part of the U-Net architecture.
+
+    Takes encoded features from different stages of the encoder and progressively upsamples them while
+    combining with skip connections. This helps preserve fine-grained details in the final segmentation.
+    """
+
     def __init__(
         self,
         encoder_channels: Sequence[int],
@@ -89,6 +96,7 @@ def __init__(
         use_batchnorm: bool = True,
         attention_type: Optional[str] = None,
         add_center_block: bool = False,
+        interpolation_mode: str = "nearest",
     ):
         super().__init__()

@@ -111,7 +119,7 @@ def __init__(
         out_channels = decoder_channels

         if add_center_block:
-            self.center = CenterBlock(
+            self.center = UnetCenterBlock(
                 head_channels, head_channels, use_batchnorm=use_batchnorm
             )
         else:
@@ -122,12 +130,13 @@ def __init__(
         for block_in_channels, block_skip_channels, block_out_channels in zip(
             in_channels, skip_channels, out_channels
         ):
-            block = DecoderBlock(
+            block = UnetDecoderBlock(
                 block_in_channels,
                 block_skip_channels,
                 block_out_channels,
                 use_batchnorm=use_batchnorm,
                 attention_type=attention_type,
+                interpolation_mode=interpolation_mode,
             )
             self.blocks.append(block)
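For readers unfamiliar with the upsampling step, the standalone sketch below (not part of the commit; tensor sizes are illustrative) shows what the configurable interpolation_mode controls: UnetDecoderBlock.forward resizes the incoming feature map to the skip connection's spatial size with F.interpolate, and the mode only selects the resampling kernel, leaving the output shape unchanged.

# Standalone illustration of the decoder's upsampling step; shapes are made up.
import torch
import torch.nn.functional as F

feature_map = torch.randn(1, 256, 16, 16)   # (batch, channels, height, width)
target_height, target_width = 32, 32        # spatial size of the skip connection

for mode in ("nearest", "bilinear", "bicubic", "nearest-exact"):
    upsampled = F.interpolate(
        feature_map, size=(target_height, target_width), mode=mode
    )
    # Every mode yields the same output shape; only the interpolation kernel differs.
    print(mode, tuple(upsampled.shape))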
4 changes: 4 additions & 0 deletions segmentation_models_pytorch/decoders/unet/model.py
@@ -44,6 +44,8 @@ class Unet(SegmentationModel):
             Available options are **True, False, "inplace"**
         decoder_attention_type: Attention module used in decoder of the model. Available options are
             **None** and **scse** (https://arxiv.org/abs/1808.08127).
+        decoder_interpolation_mode: Interpolation mode used in decoder of the model. Available options are
+            **"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"nearest"**.
         in_channels: A number of input channels for the model, default is 3 (RGB images)
         classes: A number of classes for output mask (or you can think as a number of channels of output mask)
         activation: An activation function to apply after the final convolution layer.
@@ -96,6 +98,7 @@ def __init__(
         decoder_use_batchnorm: bool = True,
         decoder_channels: Sequence[int] = (256, 128, 64, 32, 16),
         decoder_attention_type: Optional[str] = None,
+        decoder_interpolation_mode: str = "nearest",
         in_channels: int = 3,
         classes: int = 1,
         activation: Optional[Union[str, Callable]] = None,
@@ -120,6 +123,7 @@ def __init__(
             use_batchnorm=decoder_use_batchnorm,
             add_center_block=add_center_block,
             attention_type=decoder_attention_type,
+            interpolation_mode=decoder_interpolation_mode,
         )

         self.segmentation_head = SegmentationHead(
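A usage sketch of the new argument (assuming the package is importable as segmentation_models_pytorch; the encoder choice and input size are illustrative, and encoder_weights=None avoids downloading pretrained weights):

import torch
import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=3,
    classes=1,
    decoder_interpolation_mode="bilinear",  # added in this commit; default remains "nearest"
)
model.eval()

with torch.inference_mode():
    mask = model(torch.randn(1, 3, 256, 256))
print(mask.shape)  # torch.Size([1, 1, 256, 256])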
