From d5a80df0bf1f19892e79af4907ee8f556d603403 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 13 Jan 2025 16:42:54 +0000 Subject: [PATCH] Interpolation for unet --- .../decoders/unet/decoder.py | 27 ++++++++++++------- .../decoders/unet/model.py | 4 +++ 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/segmentation_models_pytorch/decoders/unet/decoder.py b/segmentation_models_pytorch/decoders/unet/decoder.py index 4c2a6711..e6bf4d16 100644 --- a/segmentation_models_pytorch/decoders/unet/decoder.py +++ b/segmentation_models_pytorch/decoders/unet/decoder.py @@ -6,7 +6,9 @@ from segmentation_models_pytorch.base import modules as md -class DecoderBlock(nn.Module): +class UnetDecoderBlock(nn.Module): + """A decoder block in the U-Net architecture that performs upsampling and feature fusion.""" + def __init__( self, in_channels: int, @@ -17,7 +19,7 @@ def __init__( interpolation_mode: str = "nearest", ): super().__init__() - self.interpolate_mode = interpolation_mode + self.interpolation_mode = interpolation_mode self.conv1 = md.Conv2dReLU( in_channels + skip_channels, out_channels, @@ -44,11 +46,10 @@ def forward( target_width: int, skip_connection: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """Upsample feature map to the given spatial shape, concatenate with skip connection, - apply attention block (if specified) and then apply two convolutions. - """ feature_map = F.interpolate( - feature_map, size=(target_height, target_width), mode=self.interpolate_mode + feature_map, + size=(target_height, target_width), + mode=self.interpolation_mode, ) if skip_connection is not None: feature_map = torch.cat([feature_map, skip_connection], dim=1) @@ -59,7 +60,7 @@ def forward( return feature_map -class CenterBlock(nn.Sequential): +class UnetCenterBlock(nn.Sequential): """Center block of the Unet decoder. Applied to the last feature map of the encoder.""" def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True): @@ -81,6 +82,12 @@ def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = Tr class UnetDecoder(nn.Module): + """The decoder part of the U-Net architecture. + + Takes encoded features from different stages of the encoder and progressively upsamples them while + combining with skip connections. This helps preserve fine-grained details in the final segmentation. + """ + def __init__( self, encoder_channels: Sequence[int], @@ -89,6 +96,7 @@ def __init__( use_batchnorm: bool = True, attention_type: Optional[str] = None, add_center_block: bool = False, + interpolation_mode: str = "nearest", ): super().__init__() @@ -111,7 +119,7 @@ def __init__( out_channels = decoder_channels if add_center_block: - self.center = CenterBlock( + self.center = UnetCenterBlock( head_channels, head_channels, use_batchnorm=use_batchnorm ) else: @@ -122,12 +130,13 @@ def __init__( for block_in_channels, block_skip_channels, block_out_channels in zip( in_channels, skip_channels, out_channels ): - block = DecoderBlock( + block = UnetDecoderBlock( block_in_channels, block_skip_channels, block_out_channels, use_batchnorm=use_batchnorm, attention_type=attention_type, + interpolation_mode=interpolation_mode, ) self.blocks.append(block) diff --git a/segmentation_models_pytorch/decoders/unet/model.py b/segmentation_models_pytorch/decoders/unet/model.py index 660eb21d..4b30527d 100644 --- a/segmentation_models_pytorch/decoders/unet/model.py +++ b/segmentation_models_pytorch/decoders/unet/model.py @@ -44,6 +44,8 @@ class Unet(SegmentationModel): Available options are **True, False, "inplace"** decoder_attention_type: Attention module used in decoder of the model. Available options are **None** and **scse** (https://arxiv.org/abs/1808.08127). + decoder_interpolation_mode: Interpolation mode used in decoder of the model. Available options are + **"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"nearest"**. in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) activation: An activation function to apply after the final convolution layer. @@ -96,6 +98,7 @@ def __init__( decoder_use_batchnorm: bool = True, decoder_channels: Sequence[int] = (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None, + decoder_interpolation_mode: str = "nearest", in_channels: int = 3, classes: int = 1, activation: Optional[Union[str, Callable]] = None, @@ -120,6 +123,7 @@ def __init__( use_batchnorm=decoder_use_batchnorm, add_center_block=add_center_block, attention_type=decoder_attention_type, + interpolation_mode=decoder_interpolation_mode, ) self.segmentation_head = SegmentationHead(