Skip to content

Commit

Permalink
Fix ruff style
Browse files Browse the repository at this point in the history
  • Loading branch information
brianhou0208 committed Dec 18, 2024
1 parent 8b0fece commit 330e6e5
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions segmentation_models_pytorch/encoders/timm_universal.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
"""
TimmUniversalEncoder provides a unified feature extraction interface built on the
`timm` library, supporting both traditional-style (e.g., ResNet) and transformer-style
TimmUniversalEncoder provides a unified feature extraction interface built on the
`timm` library, supporting both traditional-style (e.g., ResNet) and transformer-style
models (e.g., Swin Transformer, ConvNeXt).
This encoder produces consistent multi-level feature maps for semantic segmentation tasks.
It allows configuring the number of feature extraction stages (`depth`) and adjusting
This encoder produces consistent multi-level feature maps for semantic segmentation tasks.
It allows configuring the number of feature extraction stages (`depth`) and adjusting
`output_stride` when supported.
Key Features:
- Flexible model selection using `timm.create_model`.
- Unified multi-level output across different model hierarchies.
- Unified multi-level output across different model hierarchies.
- Automatic alignment for inconsistent feature scales:
- Transformer-style models (start at 1/4 scale): Insert dummy features for 1/2 scale.
- VGG-style models (include scale-1 features): Align outputs for compatibility.
- Transformer-style models (start at 1/4 scale): Insert dummy features for 1/2 scale.
- VGG-style models (include scale-1 features): Align outputs for compatibility.
- Easy access to feature scale information via the `reduction` property.
Feature Scale Differences:
- Traditional-style models (e.g., ResNet): Scales at 1/2, 1/4, 1/8, 1/16, 1/32.
- Transformer-style models (e.g., Swin Transformer): Start at 1/4 scale, skip 1/2 scale.
- Traditional-style models (e.g., ResNet): Scales at 1/2, 1/4, 1/8, 1/16, 1/32.
- Transformer-style models (e.g., Swin Transformer): Start at 1/4 scale, skip 1/2 scale.
- VGG-style models: Include scale-1 features (input resolution).
Notes:
- `output_stride` is unsupported in some models, especially transformer-based architectures.
- Special handling for models like TResNet and DLA to ensure correct feature indexing.
- VGG-style models use `_is_skip_first` to align scale-1 features with standard outputs.
- `output_stride` is unsupported in some models, especially transformer-based architectures.
- Special handling for models like TResNet and DLA to ensure correct feature indexing.
- VGG-style models use `_is_vgg_style` to align scale-1 features with standard outputs.
"""

from typing import Any
Expand All @@ -35,7 +35,7 @@

class TimmUniversalEncoder(nn.Module):
"""
A universal encoder leveraging the `timm` library for feature extraction from
A universal encoder leveraging the `timm` library for feature extraction from
various model architectures, including traditional-style and transformer-style models.
Features:
Expand Down Expand Up @@ -92,15 +92,15 @@ def __init__(
if reduction_scales == [2 ** (i + 2) for i in range(encoder_stage)]:
# Transformer-style downsampling: scales (4, 8, 16, 32)
self._is_transformer_style = True
self._is_skip_first = False
self._is_vgg_style = False
elif reduction_scales == [2 ** (i + 1) for i in range(encoder_stage)]:
# Traditional-style downsampling: scales (2, 4, 8, 16, 32)
self._is_transformer_style = False
self._is_skip_first = False
elif reduction_scales == [2 ** i for i in range(encoder_stage)]:
# Models including scale 1: scales (1, 2, 4, 8, 16, 32)
self._is_vgg_style = False
elif reduction_scales == [2**i for i in range(encoder_stage)]:
# VGG-style models including scale 1: scales (1, 2, 4, 8, 16, 32)
self._is_transformer_style = False
self._is_skip_first = True
self._is_vgg_style = True
else:
raise ValueError("Unsupported model downsampling pattern.")

Expand All @@ -125,14 +125,14 @@ def __init__(
if "dla" in name:
# For 'dla' models, out_indices starts at 0 and matches the input size.
common_kwargs["out_indices"] = tuple(range(1, depth + 1))
if self._is_skip_first:
if self._is_vgg_style:
common_kwargs["out_indices"] = tuple(range(depth + 1))

self.model = timm.create_model(
name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs)
)

if self._is_skip_first:
if self._is_vgg_style:
self._out_channels = self.model.feature_info.channels()
else:
self._out_channels = [in_channels] + self.model.feature_info.channels()
Expand Down Expand Up @@ -164,9 +164,9 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
B, _, H, W = x.shape
dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device)
features = [dummy] + features
# Add input tensor as scale 1 feature if `self._is_skip_first` is False
if not self._is_skip_first:

# Add input tensor as scale 1 feature if `self._is_vgg_style` is False
if not self._is_vgg_style:
features = [x] + features

return features
Expand Down

0 comments on commit 330e6e5

Please sign in to comment.