Update timm universal (support transformer-style model) #1004

Merged
merged 14 commits into from
Dec 19, 2024

brianhou0208
Contributor

@brianhou0208 brianhou0208 commented Dec 7, 2024

Hi, @qubvel ,

This PR improves TimmUniversalEncoder to better support transformer-style models, updates the documentation, and ensures compatibility with timm==1.0.12.


Key Updates

  1. Enhanced Model Compatibility
    • Enhanced TimmUniversalEncoder to handle both traditional (CNN-based) and transformer-style models.
    • Standardized feature resolutions to output strides from 1/2 to 1/32, which represent the encoder’s progressive downsampling.
  2. Testing with timm==1.0.12
    • Exhaustively tested all models with features_only enabled, classifying them as either traditional or transformer-style models.

Details of Changes

TimmUniversalEncoder

  1. Feature Enhancements
    • Improved type annotations, added detailed comments, and updated the documentation.
    • Leveraged feature_info.reduction() to determine whether a model is traditional or transformer-style.
    • Introduced dummy channels and features for transformer-style models, inspired by MixVisionTransformer.
    • Customized out_indices for models like tresnet and dla to ensure accurate feature extraction.
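
To illustrate the feature_info.reduction() idea above, here is a minimal sketch of how the reported strides could be mapped to a model family. The function name and the returned labels are hypothetical and are not taken from the PR diff; only the stride patterns come from this conversation.

```python
def classify_by_reduction(reduction):
    """Label an encoder from the strides reported by feature_info.reduction().

    Hypothetical helper: the labels are illustrative, not the encoder's own.
    """
    reduction = list(reduction)
    n = len(reduction)
    if n == 0:
        return "unsupported"
    if reduction == [2 ** (i + 2) for i in range(n)]:
        return "transformer-style"  # strides 4, 8, 16, 32
    if reduction == [2 ** (i + 1) for i in range(n)]:
        return "traditional"  # strides 2, 4, 8, 16, 32
    if reduction == [2 ** i for i in range(n)]:
        return "traditional (stride-1 first stage)"  # strides 1, 2, 4, ...
    return "unsupported"  # e.g. repeated or non power-of-two strides
```

Note that the reported strides alone are not sufficient: some models report 2, 4, 8, 16, 32 while their actual feature maps are smaller than those strides imply, so the shapes still have to be checked with a forward pass.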

Documentation

  1. Updated the total supported model count:
    • Original documentation mentioned 549 models, but the actual number is 533.
  2. Removed models that are deprecated in the latest version of timm.
  3. Updated model names to reflect the renamed models in timm.
  4. Removed models with output resolutions not conforming to output strides (1/2 to 1/32).
  5. Added a new table listing the currently supported transformer-style models.

Testing

  1. Conducted exhaustive tests on timm==1.0.12 for all models with features_only enabled.
  2. Verified that feature extraction aligns with the expected resolution standards for both traditional and transformer-style models.

@brianhou0208
Contributor Author

brianhou0208 commented Dec 8, 2024

Output Stride Not Matching

The following models were removed from the list as their output strides do not match the expected values:

  • inception_resnet_v2, inception_v3, inception_v4, legacy_xception, nasnetalarge, pnasnet5large
encoder \ depth      input shape  1           2         3         4         5
inception_resnet_v2  (256, 256)   (125, 125)  (60, 60)  (29, 29)  (14, 14)  (6, 6)
inception_v3         (256, 256)   (125, 125)  (60, 60)  (29, 29)  (14, 14)  (6, 6)
inception_v4         (256, 256)   (125, 125)  (62, 62)  (29, 29)  (14, 14)  (6, 6)
legacy_xception      (256, 256)   (125, 125)  (63, 63)  (32, 32)  (16, 16)  (8, 8)
nasnetalarge         (256, 256)   (127, 127)  (64, 64)  (32, 32)  (16, 16)  (8, 8)
pnasnet5large        (256, 256)   (127, 127)  (64, 64)  (32, 32)  (16, 16)  (8, 8)
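
As a rough sketch of the check behind this table, the spatial sizes can be compared against the sizes implied by the expected strides. The helper name is hypothetical, and it assumes a square input whose side is divisible by the largest stride.

```python
def matches_output_strides(input_size, feature_sizes, strides=(1, 2, 4, 8, 16, 32)):
    """Return True if each feature map side equals input_size // stride.

    Assumes a square input whose side is divisible by every stride.
    """
    expected = [input_size // s for s in strides]
    return list(feature_sizes) == expected


# Sides taken from the table above (input 256).
inception_v3 = [256, 125, 60, 29, 14, 6]
print(matches_output_strides(256, inception_v3))
```

For a 256 x 256 input the expected sides are 256, 128, 64, 32, 16, 8, so every model in the table fails at depth 1 or earlier.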

test code

import torch
import segmentation_models_pytorch as smp

model_list = [
    "inception_resnet_v2",
    "inception_v3",
    "inception_v4",
    "legacy_xception",
    "nasnetalarge",
    "pnasnet5large",
]

if __name__ == "__main__":
    x = torch.rand(1, 3, 256, 256)
    for name in model_list:
        model = smp.encoders.get_encoder(f"tu-{name}", weights=None).eval()
        f = model(x)
        print(name, [f_.detach().numpy().shape[2:] for f_ in f])

output

inception_resnet_v2 [(256, 256), (125, 125), (60, 60), (29, 29), (14, 14), (6, 6)]
inception_v3 [(256, 256), (125, 125), (60, 60), (29, 29), (14, 14), (6, 6)]
inception_v4 [(256, 256), (125, 125), (62, 62), (29, 29), (14, 14), (6, 6)]
legacy_xception [(256, 256), (125, 125), (63, 63), (32, 32), (16, 16), (8, 8)]
nasnetalarge [(256, 256), (127, 127), (64, 64), (32, 32), (16, 16), (8, 8)]
pnasnet5large [(256, 256), (127, 127), (64, 64), (32, 32), (16, 16), (8, 8)]

Renamed / Deprecated Models

The following models remain functional but are deprecated in timm==1.0.12, as indicated by warnings during testing. These models were removed from the list to avoid confusion:

  • mnasnet_a1, mnasnet_b1, efficientnet_b2a, efficientnet_b3a, seresnext26tn_32x4d

test code

import torch
import segmentation_models_pytorch as smp

model_list = [
    "mnasnet_a1",
    "mnasnet_b1",
    "efficientnet_b2a",
    "efficientnet_b3a",
    "seresnext26tn_32x4d",
]

if __name__ == "__main__":
    x = torch.rand(1, 3, 256, 256)
    for name in model_list:
        model = smp.encoders.get_encoder(f"tu-{name}", weights=None).eval()

output

UserWarning: Mapping deprecated model name mnasnet_a1 to current semnasnet_100.
UserWarning: Mapping deprecated model name mnasnet_b1 to current mnasnet_100.
UserWarning: Mapping deprecated model name efficientnet_b2a to current efficientnet_b2.
UserWarning: Mapping deprecated model name efficientnet_b3a to current efficientnet_b3.
UserWarning: Mapping deprecated model name seresnext26tn_32x4d to current seresnext26t_32x4d.
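
The warnings above amount to a small rename table. A sketch of how documentation or user code could normalize the old names is shown here; the dictionary and function names are hypothetical, and the entries are copied from the warnings only.

```python
# Deprecated name -> current name, as reported by the warnings above.
DEPRECATED_TO_CURRENT = {
    "mnasnet_a1": "semnasnet_100",
    "mnasnet_b1": "mnasnet_100",
    "efficientnet_b2a": "efficientnet_b2",
    "efficientnet_b3a": "efficientnet_b3",
    "seresnext26tn_32x4d": "seresnext26t_32x4d",
}


def resolve_model_name(name):
    """Return the current name for a deprecated one, or the name unchanged."""
    return DEPRECATED_TO_CURRENT.get(name, name)
```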

@brianhou0208
Contributor Author

brianhou0208 commented Dec 8, 2024

Add New Traditional-Style Models

  • shape format: (B, C, H, W)
  • output stride: (1/2, 1/4, 1/8, 1/16, 1/32)
  • number of models: 57
newly supported models
efficientnet_blur_b0
ghostnetv2_100
ghostnetv2_130
ghostnetv2_160
mobilenet_edgetpu_100
mobilenet_edgetpu_v2_l
mobilenet_edgetpu_v2_m
mobilenet_edgetpu_v2_s
mobilenet_edgetpu_v2_xs
mobilenetv1_100
mobilenetv1_100h
mobilenetv1_125
mobilenetv3_large_150d
mobilenetv4_conv_aa_large
mobilenetv4_conv_aa_medium
mobilenetv4_conv_blur_medium
mobilenetv4_conv_large
mobilenetv4_conv_medium
mobilenetv4_conv_small
mobilenetv4_conv_small_035
mobilenetv4_conv_small_050
mobilenetv4_hybrid_large
mobilenetv4_hybrid_large_075
mobilenetv4_hybrid_medium
mobilenetv4_hybrid_medium_075
mobileone_s0
mobileone_s1
mobileone_s2
mobileone_s3
mobileone_s4
repghostnet_050
repghostnet_058
repghostnet_080
repghostnet_100
repghostnet_111
repghostnet_130
repghostnet_150
repghostnet_200
repvgg_a0
repvgg_a1
repvgg_d2se
resnet50_clip
resnet50_clip_gap
resnet50_mlp
resnet50x4_clip
resnet50x4_clip_gap
resnet50x16_clip
resnet50x16_clip_gap
resnet50x64_clip
resnet50x64_clip_gap
resnet101_clip
resnet101_clip_gap
resnetv2_18
resnetv2_18d
resnetv2_34
resnetv2_34d
seresnextaa201d_32x8d
  • shape format: (B, C, H, W)
  • output stride: (1/1, 1/2, 1/4, 1/8, 1/16, 1/32)
  • number of models: 14
newly supported models
cspdarknet53
darknet17
darknet21
darknet53
darknetaa53
sedarknet21
vgg11
vgg11_bn
vgg13
vgg13_bn
vgg16
vgg16_bn
vgg19
vgg19_bn

Add New Transformer-Style Models

Channel-First Models

  • shape format: (B, C, H, W)
  • output stride: (1/4, 1/8, 1/16, 1/32)
  • number of models: 188
newly supported models
caformer_b36
caformer_m36
caformer_s18
caformer_s36
convformer_b36
convformer_m36
convformer_s18
convformer_s36
convnext_atto
convnext_atto_ols
convnext_atto_rms
convnext_base
convnext_femto
convnext_femto_ols
convnext_large
convnext_large_mlp
convnext_nano
convnext_nano_ols
convnext_pico
convnext_pico_ols
convnext_small
convnext_tiny
convnext_tiny_hnf
convnext_xlarge
convnext_xxlarge
convnext_zepto_rms
convnext_zepto_rms_ols
convnextv2_atto
convnextv2_base
convnextv2_femto
convnextv2_huge
convnextv2_large
convnextv2_nano
convnextv2_pico
convnextv2_small
convnextv2_tiny
davit_base
davit_base_fl
davit_giant
davit_huge
davit_huge_fl
davit_large
davit_small
davit_tiny
edgenext_base
edgenext_small
edgenext_small_rw
edgenext_x_small
edgenext_xx_small
efficientformer_l1
efficientformer_l3
efficientformer_l7
efficientformerv2_l
efficientformerv2_s0
efficientformerv2_s1
efficientformerv2_s2
efficientvit_b0
efficientvit_b1
efficientvit_b2
efficientvit_b3
efficientvit_l1
efficientvit_l2
efficientvit_l3
fastvit_ma36
fastvit_mci0
fastvit_mci1
fastvit_mci2
fastvit_s12
fastvit_sa12
fastvit_sa24
fastvit_sa36
fastvit_t8
fastvit_t12
focalnet_base_lrf
focalnet_base_srf
focalnet_huge_fl3
focalnet_huge_fl4
focalnet_large_fl3
focalnet_large_fl4
focalnet_small_lrf
focalnet_small_srf
focalnet_tiny_lrf
focalnet_tiny_srf
focalnet_xlarge_fl3
focalnet_xlarge_fl4
hgnet_base
hgnet_small
hgnet_tiny
hgnetv2_b0
hgnetv2_b1
hgnetv2_b2
hgnetv2_b3
hgnetv2_b4
hgnetv2_b5
hgnetv2_b6
hiera_base_224
hiera_base_abswin_256
hiera_base_plus_224
hiera_huge_224
hiera_large_224
hiera_small_224
hiera_small_abswin_256
hiera_tiny_224
hieradet_small
inception_next_base
inception_next_small
inception_next_tiny
mvitv2_base
mvitv2_base_cls
mvitv2_huge_cls
mvitv2_large
mvitv2_large_cls
mvitv2_small
mvitv2_small_cls
mvitv2_tiny
nest_base
nest_base_jx
nest_small
nest_small_jx
nest_tiny
nest_tiny_jx
nextvit_base
nextvit_large
nextvit_small
poolformer_m36
poolformer_m48
poolformer_s12
poolformer_s24
poolformer_s36
poolformerv2_m36
poolformerv2_m48
poolformerv2_s12
poolformerv2_s24
poolformerv2_s36
pvt_v2_b0
pvt_v2_b1
pvt_v2_b2
pvt_v2_b2_li
pvt_v2_b3
pvt_v2_b4
pvt_v2_b5
rdnet_base
rdnet_large
rdnet_small
rdnet_tiny
repvit_m0_9
repvit_m1
repvit_m1_0
repvit_m1_1
repvit_m1_5
repvit_m2
repvit_m2_3
repvit_m3
sam2_hiera_base_plus
sam2_hiera_large
sam2_hiera_small
sam2_hiera_tiny
swinv2_cr_base_224
swinv2_cr_base_384
swinv2_cr_base_ns_224
swinv2_cr_giant_224
swinv2_cr_giant_384
swinv2_cr_huge_224
swinv2_cr_huge_384
swinv2_cr_large_224
swinv2_cr_large_384
swinv2_cr_small_224
swinv2_cr_small_384
swinv2_cr_small_ns_224
swinv2_cr_small_ns_256
swinv2_cr_tiny_224
swinv2_cr_tiny_384
swinv2_cr_tiny_ns_224
tiny_vit_5m_224
tiny_vit_11m_224
tiny_vit_21m_224
tiny_vit_21m_384
tiny_vit_21m_512
tresnet_l
tresnet_m
tresnet_v2_l
tresnet_xl
twins_pcpvt_base
twins_pcpvt_large
twins_pcpvt_small
twins_svt_base
twins_svt_large
twins_svt_small

Channel-Last Models

  • shape format: (B, H, W, C)
  • output stride: (1/4, 1/8, 1/16, 1/32)
  • number of models : 31

These models are clearly transformer-style models, but their format is channel-last. Additional processing might be required. I'm not sure how to handle this.

newly supported models
mambaout_base
mambaout_base_plus_rw
mambaout_base_short_rw
mambaout_base_tall_rw
mambaout_base_wide_rw
mambaout_femto
mambaout_kobe
mambaout_small
mambaout_small_rw
mambaout_tiny
swin_base_patch4_window7_224
swin_base_patch4_window12_384
swin_large_patch4_window7_224
swin_large_patch4_window12_384
swin_s3_base_224
swin_s3_small_224
swin_s3_tiny_224
swin_small_patch4_window7_224
swin_tiny_patch4_window7_224
swinv2_base_window8_256
swinv2_base_window12_192
swinv2_base_window12to16_192to256
swinv2_base_window12to24_192to384
swinv2_base_window16_256
swinv2_large_window12_192
swinv2_large_window12to16_192to256
swinv2_large_window12to24_192to384
swinv2_small_window8_256
swinv2_small_window16_256
swinv2_tiny_window8_256
swinv2_tiny_window16_256

@qubvel
Collaborator

qubvel commented Dec 9, 2024

Hi @brianhou0208! Thanks for working on this challenging feature!

My main concerns are:

  1. Reliably determining whether features come in BCHW or BHWC format. Ideally, this should be done during model construction time. I will ask Ross if there is any way to get it at that time.
  2. Timm transformer models can return more than three or four features if indices of output features are provided. Your implementation is nice, but I want to make sure we can extend it to more backbones and provide feature indices without breaking backward compatibility in the future.

Let me know what you think?

@qubvel
Collaborator

qubvel commented Dec 9, 2024

Regarding the NHWC format, I got an answer from Ross. The following should work:

getattr(model, "output_fmt", None) == "NHWC"

That attribute is only set if models have NHWC format, that's why we have to use getattr.

Also, there are some models that come with features in NLC format, but I suppose they can be ignored for now.
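
A minimal sketch of how that check could be combined with a layout conversion is shown below. The function name is hypothetical; the features are assumed to be tensor-like objects with a permute method (as torch.Tensor has), and NLC features are not handled.

```python
def to_channel_first(model, features):
    """Convert (B, H, W, C) features to (B, C, H, W) when the model reports NHWC.

    Sketch only: `features` must expose `.permute(*dims)` like torch.Tensor.
    """
    # The attribute is only set on NHWC models, hence getattr with a default.
    is_channel_last = getattr(model, "output_fmt", None) == "NHWC"
    if not is_channel_last:
        return list(features)
    # Move the channel axis (index 3) to position 1.
    return [f.permute(0, 3, 1, 2) for f in features]
```

With real tensors it may also be worth calling .contiguous() after the permute, since later layers can be sensitive to memory layout.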

@brianhou0208
Contributor Author

brianhou0208 commented Dec 12, 2024

Hi @qubvel ,

  1. Support for channel-last models
    I have updated the code to support channels in the format (B, H, W, C).
  2. Support for more timm backbones
    I think this is a bit challenging. Models are updated very quickly, and not all of them follow a fixed architectural design the way ResNet or ConvNeXt do. However, I believe Ross will ensure that any newly added models stay aligned with the current timm API structure, so the current PR should support most models.

Instead of feature_info.channels(), we can also use feature_info.reduction(), which is described in the timm documentation, to achieve similar functionality.

timm api test & result

test code

import torch
import timm
import segmentation_models_pytorch as smp

model_list = [
    ["dla34", 224],
    ["cspdarknet53", 224],
    ["efficientnet_x_b3", 224],
    ["efficientvit_m0", 224],
    ["inception_resnet_v2", 299],
    ["inception_v3", 299],
    ["inception_v4", 299],
    ["mambaout_tiny", 224],
    ["tresnet_m", 224],
    ["vit_tiny_patch16_224", 224],
]

if __name__ == "__main__":
    for model_name, img_size in model_list:
        x = torch.rand(1, 3, img_size, img_size)
        model = timm.create_model(f"{model_name}", features_only=True).eval()
        y = model(x)
        
        print(f"timm-{model_name}-(C, H, W) = {(3, img_size, img_size)}")
        print(f"    Feature shape: {[f.detach().numpy().shape[1:] for f in y]}")
        print(f"    Feature channels: {model.feature_info.channels()}")
        print(f"    Feature reduction: {model.feature_info.reduction()}")

output

timm-dla34-(C, H, W) = (3, 224, 224)
    Feature shape: [(32, 112, 112), (64, 56, 56), (128, 28, 28), (256, 14, 14), (512, 7, 7)]
    Feature channels: [32, 64, 128, 256, 512]
    Feature reduction: [2, 4, 8, 16, 32]
timm-cspdarknet53-(C, H, W) = (3, 224, 224)
    Feature shape: [(32, 224, 224), (64, 112, 112), (128, 56, 56), (256, 28, 28), (512, 14, 14), (1024, 7, 7)]
    Feature channels: [32, 64, 128, 256, 512, 1024]
    Feature reduction: [1, 2, 4, 8, 16, 32]
timm-efficientnet_x_b3-(C, H, W) = (3, 224, 224)
    Feature shape: [(96, 56, 56), (32, 56, 56), (48, 28, 28), (136, 14, 14), (384, 7, 7)]
    Feature channels: [96, 32, 48, 136, 384]
    Feature reduction: [2, 2, 4, 8, 16]
timm-efficientvit_m0-(C, H, W) = (3, 224, 224)
    Feature shape: [(64, 14, 14), (128, 7, 7), (192, 4, 4)]
    Feature channels: [64, 128, 192]
    Feature reduction: [16, 32, 64]
timm-inception_resnet_v2-(C, H, W) = (3, 299, 299)
    Feature shape: [(64, 147, 147), (192, 71, 71), (320, 35, 35), (1088, 17, 17), (1536, 8, 8)]
    Feature channels: [64, 192, 320, 1088, 1536]
    Feature reduction: [2, 4, 8, 16, 32]
timm-inception_v3-(C, H, W) = (3, 299, 299)
    Feature shape: [(64, 147, 147), (192, 71, 71), (288, 35, 35), (768, 17, 17), (2048, 8, 8)]
    Feature channels: [64, 192, 288, 768, 2048]
    Feature reduction: [2, 4, 8, 16, 32]
timm-inception_v4-(C, H, W) = (3, 299, 299)
    Feature shape: [(64, 147, 147), (160, 73, 73), (384, 35, 35), (1024, 17, 17), (1536, 8, 8)]
    Feature channels: [64, 160, 384, 1024, 1536]
    Feature reduction: [2, 4, 8, 16, 32]
timm-tresnet_m-(C, H, W) = (3, 224, 224)
    Feature shape: [(64, 56, 56), (128, 28, 28), (1024, 14, 14), (2048, 7, 7)]
    Feature channels: [64, 128, 1024, 2048]
    Feature reduction: [4, 8, 16, 32]
timm-vit_tiny_patch16_224-(C, H, W) = (3, 224, 224)
    Feature shape: [(192, 14, 14), (192, 14, 14), (192, 14, 14)]
    Feature channels: [192, 192, 192]
    Feature reduction: [16, 16, 16]

However, you might still encounter cases like:

  • inception_resnet_v2, inception_v3, whose output strides do not match.
  • dla34, tresnet_m which require explicit setting of out_indices to correctly output features.
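
One way to express that special-casing is a small lookup keyed on the model family. This is a sketch under assumptions: the helper name is hypothetical, the prefix match is my own choice, and the index ranges are the ones used for dla34 and tresnet_m in the out_indices test in this comment with an encoder depth of 5.

```python
def pick_out_indices(model_name, depth=5):
    """Return explicit out_indices for families that need them, else None.

    None means "do not override timm's default indices".
    """
    if model_name.startswith("dla"):
        # dla34 in the test uses (1, 2, 3, 4, 5).
        return tuple(range(1, depth + 1))
    if model_name.startswith("tresnet"):
        # tresnet_m in the test uses (1, 2, 3, 4).
        return tuple(range(1, depth))
    return None
```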
out_indices test & result

test

import torch
import timm
import segmentation_models_pytorch as smp

model_list = [
    ["resnet18", 224, (0, 1, 2, 3, 4)],
    ["dla34", 224, (1, 2, 3, 4, 5)],
    ["mambaout_tiny", 224, (0, 1, 2, 3)],
    ["tresnet_m", 224, (1, 2, 3, 4)],
]

if __name__ == "__main__":
    for model_name, img_size, out_indices in model_list:
        x = torch.rand(1, 3, img_size, img_size)
        model = timm.create_model(f"{model_name}", features_only=True, out_indices=out_indices).eval()
        y = model(x)
        
        print(f"timm-{model_name}")
        print(f"    Feature shape: {[f.detach().numpy().shape[1:] for f in y]}")

        model = smp.encoders.get_encoder(f"tu-{model_name}").eval()
        y = model(x)[1:]
        print(f"smp-{model_name}")
        print(f"    Feature shape: {[f.detach().numpy().shape[1:] for f in y]}")
        print()

output

timm-resnet18
    Feature shape: [(64, 112, 112), (64, 56, 56), (128, 28, 28), (256, 14, 14), (512, 7, 7)]
smp-resnet18
    Feature shape: [(64, 112, 112), (64, 56, 56), (128, 28, 28), (256, 14, 14), (512, 7, 7)]

timm-dla34
    Feature shape: [(32, 112, 112), (64, 56, 56), (128, 28, 28), (256, 14, 14), (512, 7, 7)]
smp-dla34
    Feature shape: [(32, 112, 112), (64, 56, 56), (128, 28, 28), (256, 14, 14), (512, 7, 7)]

timm-mambaout_tiny
    Feature shape: [(56, 56, 96), (28, 28, 192), (14, 14, 384), (7, 7, 576)]
smp-mambaout_tiny
    Feature shape: [(0, 112, 112), (96, 56, 56), (192, 28, 28), (384, 14, 14), (576, 7, 7)]

timm-tresnet_m
    Feature shape: [(64, 56, 56), (128, 28, 28), (1024, 14, 14), (2048, 7, 7)]
smp-tresnet_m
    Feature shape: [(0, 112, 112), (64, 56, 56), (128, 28, 28), (1024, 14, 14), (2048, 7, 7)]
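
The (0, 112, 112) entries above are the dummy features added for models whose first stage already has stride 4. A shape-level sketch of that rule, with a hypothetical helper name and no tensors involved:

```python
def expected_smp_feature_shapes(channels, reduction, height, width):
    """Predict (C, H, W) per feature, prepending a zero-channel dummy at
    stride 2 when the first reported stride is 4."""
    shapes = [(c, height // r, width // r) for c, r in zip(channels, reduction)]
    if reduction and reduction[0] == 4:
        # Placeholder for the missing 1/2-resolution stage.
        shapes.insert(0, (0, height // 2, width // 2))
    return shapes


# tresnet_m values from the timm api test output in this comment.
print(expected_smp_feature_shapes([64, 128, 1024, 2048], [4, 8, 16, 32], 224, 224))
```

Keeping a zero-channel placeholder means decoders that index features by depth see the same number of stages for both model families.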

@brianhou0208
Contributor Author

brianhou0208 commented Dec 12, 2024

I spent some time reviewing all the models in timm==1.0.12.
In fact, this PR already supports the majority of models that produce downsampled features, covering 94.638% ((593 + 219) / (1122 - 264)).
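
For reference, the arithmetic behind that percentage can be reproduced from the counts reported in this comment. The variable names are mine; the 219 transformer-style encoders are the 188 channel-first plus 31 channel-last models listed earlier.

```python
total_timm = 1170             # all models in timm 1.0.12
no_feature_extraction = 34    # models without features_only support
test_models = 14              # timm's internal test_* models
candidates = total_timm - no_feature_extraction - test_models

traditional = 593
transformer_style = 188 + 31  # channel-first + channel-last
supported = traditional + transformer_style

unsupported = candidates - supported
without_downsample = 264      # e.g. plain ViTs with a constant stride

# Coverage is measured against models that do downsample.
coverage = supported / (candidates - without_downsample)
print(candidates, supported, unsupported, f"{coverage:.2%}")
```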

Timm Support Backbone

  • timm version: 1.0.12
  • Total models in timm: 1170
  • Feature extraction (supported / unsupported): 1136 / 34
  • timm test models: 14
  • Number of models: 1122 (1170 - 34 - 14)
Unsupported feature extraction: 34
coat_lite_medium
coat_lite_medium_384
coat_lite_mini
coat_lite_small
coat_lite_tiny
coat_mini
coat_small
coat_tiny
convit_base
convit_small
convit_tiny
convmixer_768_32
convmixer_1024_20_ks9_p14
convmixer_1536_20
crossvit_9_240
crossvit_9_dagger_240
crossvit_15_240
crossvit_15_dagger_240
crossvit_15_dagger_408
crossvit_18_240
crossvit_18_dagger_240
crossvit_18_dagger_408
crossvit_base_240
crossvit_small_240
crossvit_tiny_240
gcvit_base
gcvit_small
gcvit_tiny
gcvit_xtiny
gcvit_xxtiny
tnt_b_patch16_224
tnt_s_patch16_224
visformer_small
visformer_tiny
Test models: 14
test_byobnet
test_convnext
test_convnext2
test_convnext3
test_efficientnet
test_efficientnet_evos
test_efficientnet_gn
test_efficientnet_ln
test_mambaout
test_nfnet
test_resnet
test_vit
test_vit2
test_vit3

SMP Support Backbone

  • SMP support for traditional encoders: 593
  • SMP support for transformer-style encoders: 219 (188 + 31)

SMP Unsupported Backbone

  • Number of unsupported models: 310 (1122 - 593 - 219)
Unsupported models: 310
beit_base_patch16_224
beit_base_patch16_384
beit_large_patch16_224
beit_large_patch16_384
beit_large_patch16_512
beitv2_base_patch16_224
beitv2_large_patch16_224
cait_m36_384
cait_m48_448
cait_s24_224
cait_s24_384
cait_s36_384
cait_xs24_384
cait_xxs24_224
cait_xxs24_384
cait_xxs36_224
cait_xxs36_384
deit3_base_patch16_224
deit3_base_patch16_384
deit3_huge_patch14_224
deit3_large_patch16_224
deit3_large_patch16_384
deit3_medium_patch16_224
deit3_small_patch16_224
deit3_small_patch16_384
deit_base_distilled_patch16_224
deit_base_distilled_patch16_384
deit_base_patch16_224
deit_base_patch16_384
deit_small_distilled_patch16_224
deit_small_patch16_224
deit_tiny_distilled_patch16_224
deit_tiny_patch16_224
efficientnet_h_b5
efficientnet_x_b3
efficientnet_x_b5
efficientvit_m0
efficientvit_m1
efficientvit_m2
efficientvit_m3
efficientvit_m4
efficientvit_m5
eva02_base_patch14_224
eva02_base_patch14_448
eva02_base_patch16_clip_224
eva02_enormous_patch14_clip_224
eva02_large_patch14_224
eva02_large_patch14_448
eva02_large_patch14_clip_224
eva02_large_patch14_clip_336
eva02_small_patch14_224
eva02_small_patch14_336
eva02_tiny_patch14_224
eva02_tiny_patch14_336
eva_giant_patch14_224
eva_giant_patch14_336
eva_giant_patch14_560
eva_giant_patch14_clip_224
eva_large_patch14_196
eva_large_patch14_336
flexivit_base
flexivit_large
flexivit_small
gmixer_12_224
gmixer_24_224
gmlp_b16_224
gmlp_s16_224
gmlp_ti16_224
inception_resnet_v2
inception_v3
inception_v4
legacy_xception
levit_128
levit_128s
levit_192
levit_256
levit_256d
levit_384
levit_384_s8
levit_512
levit_512_s8
levit_512d
levit_conv_128
levit_conv_128s
levit_conv_192
levit_conv_256
levit_conv_256d
levit_conv_384
levit_conv_384_s8
levit_conv_512
levit_conv_512_s8
levit_conv_512d
mixer_b16_224
mixer_b32_224
mixer_l16_224
mixer_l32_224
mixer_s16_224
mixer_s32_224
nasnetalarge
pit_b_224
pit_b_distilled_224
pit_s_224
pit_s_distilled_224
pit_ti_224
pit_ti_distilled_224
pit_xs_224
pit_xs_distilled_224
pnasnet5large
resmlp_12_224
resmlp_24_224
resmlp_36_224
resmlp_big_24_224
samvit_base_patch16
samvit_base_patch16_224
samvit_huge_patch16
samvit_large_patch16
sequencer2d_l
sequencer2d_m
sequencer2d_s
vit_base_mci_224
vit_base_patch8_224
vit_base_patch14_dinov2
vit_base_patch14_reg4_dinov2
vit_base_patch16_18x2_224
vit_base_patch16_224
vit_base_patch16_224_miil
vit_base_patch16_384
vit_base_patch16_clip_224
vit_base_patch16_clip_384
vit_base_patch16_clip_quickgelu_224
vit_base_patch16_gap_224
vit_base_patch16_plus_240
vit_base_patch16_plus_clip_240
vit_base_patch16_reg4_gap_256
vit_base_patch16_rope_reg1_gap_256
vit_base_patch16_rpn_224
vit_base_patch16_siglip_224
vit_base_patch16_siglip_256
vit_base_patch16_siglip_384
vit_base_patch16_siglip_512
vit_base_patch16_siglip_gap_224
vit_base_patch16_siglip_gap_256
vit_base_patch16_siglip_gap_384
vit_base_patch16_siglip_gap_512
vit_base_patch16_xp_224
vit_base_patch32_224
vit_base_patch32_384
vit_base_patch32_clip_224
vit_base_patch32_clip_256
vit_base_patch32_clip_384
vit_base_patch32_clip_448
vit_base_patch32_clip_quickgelu_224
vit_base_patch32_plus_256
vit_base_r26_s32_224
vit_base_r50_s16_224
vit_base_r50_s16_384
vit_base_resnet26d_224
vit_base_resnet50d_224
vit_betwixt_patch16_gap_256
vit_betwixt_patch16_reg1_gap_256
vit_betwixt_patch16_reg4_gap_256
vit_betwixt_patch16_reg4_gap_384
vit_betwixt_patch16_rope_reg4_gap_256
vit_betwixt_patch32_clip_224
vit_giant_patch14_224
vit_giant_patch14_clip_224
vit_giant_patch14_dinov2
vit_giant_patch14_reg4_dinov2
vit_giant_patch16_gap_224
vit_gigantic_patch14_224
vit_gigantic_patch14_clip_224
vit_gigantic_patch14_clip_quickgelu_224
vit_huge_patch14_224
vit_huge_patch14_clip_224
vit_huge_patch14_clip_336
vit_huge_patch14_clip_378
vit_huge_patch14_clip_quickgelu_224
vit_huge_patch14_clip_quickgelu_378
vit_huge_patch14_gap_224
vit_huge_patch14_xp_224
vit_huge_patch16_gap_448
vit_intern300m_patch14_448
vit_large_patch14_224
vit_large_patch14_clip_224
vit_large_patch14_clip_336
vit_large_patch14_clip_quickgelu_224
vit_large_patch14_clip_quickgelu_336
vit_large_patch14_dinov2
vit_large_patch14_reg4_dinov2
vit_large_patch14_xp_224
vit_large_patch16_224
vit_large_patch16_384
vit_large_patch16_siglip_256
vit_large_patch16_siglip_384
vit_large_patch16_siglip_gap_256
vit_large_patch16_siglip_gap_384
vit_large_patch32_224
vit_large_patch32_384
vit_large_r50_s32_224
vit_large_r50_s32_384
vit_little_patch16_reg1_gap_256
vit_little_patch16_reg4_gap_256
vit_medium_patch16_clip_224
vit_medium_patch16_gap_240
vit_medium_patch16_gap_256
vit_medium_patch16_gap_384
vit_medium_patch16_reg1_gap_256
vit_medium_patch16_reg4_gap_256
vit_medium_patch16_rope_reg1_gap_256
vit_medium_patch32_clip_224
vit_mediumd_patch16_reg4_gap_256
vit_mediumd_patch16_reg4_gap_384
vit_mediumd_patch16_rope_reg1_gap_256
vit_pwee_patch16_reg1_gap_256
vit_relpos_base_patch16_224
vit_relpos_base_patch16_cls_224
vit_relpos_base_patch16_clsgap_224
vit_relpos_base_patch16_plus_240
vit_relpos_base_patch16_rpn_224
vit_relpos_base_patch32_plus_rpn_256
vit_relpos_medium_patch16_224
vit_relpos_medium_patch16_cls_224
vit_relpos_medium_patch16_rpn_224
vit_relpos_small_patch16_224
vit_relpos_small_patch16_rpn_224
vit_small_patch8_224
vit_small_patch14_dinov2
vit_small_patch14_reg4_dinov2
vit_small_patch16_18x2_224
vit_small_patch16_36x1_224
vit_small_patch16_224
vit_small_patch16_384
vit_small_patch32_224
vit_small_patch32_384
vit_small_r26_s32_224
vit_small_r26_s32_384
vit_small_resnet26d_224
vit_small_resnet50d_s16_224
vit_so150m_patch16_reg4_gap_256
vit_so150m_patch16_reg4_map_256
vit_so400m_patch14_siglip_224
vit_so400m_patch14_siglip_378
vit_so400m_patch14_siglip_384
vit_so400m_patch14_siglip_gap_224
vit_so400m_patch14_siglip_gap_378
vit_so400m_patch14_siglip_gap_384
vit_so400m_patch14_siglip_gap_448
vit_so400m_patch14_siglip_gap_896
vit_so400m_patch16_siglip_256
vit_so400m_patch16_siglip_gap_256
vit_srelpos_medium_patch16_224
vit_srelpos_small_patch16_224
vit_tiny_patch16_224
vit_tiny_patch16_384
vit_tiny_r_s16_p8_224
vit_tiny_r_s16_p8_384
vit_wee_patch16_reg1_gap_256
vit_xsmall_patch16_clip_224
vitamin_base_224
vitamin_large2_224
vitamin_large2_256
vitamin_large2_336
vitamin_large2_384
vitamin_large_224
vitamin_large_256
vitamin_large_336
vitamin_large_384
vitamin_small_224
vitamin_xlarge_256
vitamin_xlarge_336
vitamin_xlarge_384
volo_d1_224
volo_d1_384
volo_d2_224
volo_d2_384
volo_d3_224
volo_d3_448
volo_d4_224
volo_d4_448
volo_d5_224
volo_d5_448
volo_d5_512
xcit_large_24_p8_224
xcit_large_24_p8_384
xcit_large_24_p16_224
xcit_large_24_p16_384
xcit_medium_24_p8_224
xcit_medium_24_p8_384
xcit_medium_24_p16_224
xcit_medium_24_p16_384
xcit_nano_12_p8_224
xcit_nano_12_p8_384
xcit_nano_12_p16_224
xcit_nano_12_p16_384
xcit_small_12_p8_224
xcit_small_12_p8_384
xcit_small_12_p16_224
xcit_small_12_p16_384
xcit_small_24_p8_224
xcit_small_24_p8_384
xcit_small_24_p16_224
xcit_small_24_p16_384
xcit_tiny_12_p8_224
xcit_tiny_12_p8_384
xcit_tiny_12_p16_224
xcit_tiny_12_p16_384
xcit_tiny_24_p8_224
xcit_tiny_24_p8_384
xcit_tiny_24_p16_224
xcit_tiny_24_p16_384

Check for unsupported models in SMP

import torch
import timm

# model_list is expected to hold the 310 unsupported model names listed above
if __name__ == "__main__":
    for model_name in model_list:
        model = timm.create_model(f"{model_name}", features_only=True).eval()
        is_channel_last = getattr(model, "output_fmt", None) == "NHWC"
        print(f"{model_name} {is_channel_last}")
        print(model.feature_info.reduction())
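
The two groups reported below ("with" and "without" downsample feature) can be separated from the printed reduction lists alone. A sketch of that split, with hypothetical function names:

```python
def has_downsample(reduction):
    """True if the reported strides are not all identical."""
    return len(set(reduction)) > 1


def split_by_downsample(reductions):
    """Split {model_name: reduction} into (with_downsample, without_downsample)."""
    with_ds = [name for name, r in reductions.items() if has_downsample(r)]
    without_ds = [name for name, r in reductions.items() if not has_downsample(r)]
    return with_ds, without_ds
```

For example, [7, 14, 14] counts as downsampling while [16, 16, 16] does not.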
With downsample feature: 46
efficientnet_h_b5 False
[2, 2, 4, 8, 16]
efficientnet_x_b3 False
[2, 2, 4, 8, 16]
efficientnet_x_b5 False
[2, 2, 4, 8, 16]
efficientvit_m0 False
[16, 32, 64]
efficientvit_m1 False
[16, 32, 64]
efficientvit_m2 False
[16, 32, 64]
efficientvit_m3 False
[16, 32, 64]
efficientvit_m4 False
[16, 32, 64]
efficientvit_m5 False
[16, 32, 64]
inception_resnet_v2 False
[2, 4, 8, 16, 32]
inception_v3 False
[2, 4, 8, 16, 32]
inception_v4 False
[2, 4, 8, 16, 32]
legacy_xception False
[2, 4, 8, 16, 32]
levit_128 False
[16, 32, 64]
levit_128s False
[16, 32, 64]
levit_192 False
[16, 32, 64]
levit_256 False
[16, 32, 64]
levit_256d False
[16, 32, 64]
levit_384 False
[16, 32, 64]
levit_384_s8 False
[8, 16, 32]
levit_512 False
[16, 32, 64]
levit_512_s8 False
[8, 16, 32]
levit_512d False
[16, 32, 64]
levit_conv_128 False
[16, 32, 64]
levit_conv_128s False
[16, 32, 64]
levit_conv_192 False
[16, 32, 64]
levit_conv_256 False
[16, 32, 64]
levit_conv_256d False
[16, 32, 64]
levit_conv_384 False
[16, 32, 64]
levit_conv_384_s8 False
[8, 16, 32]
levit_conv_512 False
[16, 32, 64]
levit_conv_512_s8 False
[8, 16, 32]
levit_conv_512d False
[16, 32, 64]
nasnetalarge False
[2, 4, 8, 16, 32]
pit_b_224 False
[6, 12, 24]
pit_b_distilled_224 False
[6, 12, 24]
pit_s_224 False
[7, 14, 28]
pit_s_distilled_224 False
[7, 14, 28]
pit_ti_224 False
[7, 14, 28]
pit_ti_distilled_224 False
[7, 14, 28]
pit_xs_224 False
[7, 14, 28]
pit_xs_distilled_224 False
[7, 14, 28]
pnasnet5large False
[2, 4, 8, 16, 32]
sequencer2d_l True
[7, 14, 14]
sequencer2d_m True
[7, 14, 14]
sequencer2d_s True
[7, 14, 14]
Without downsample feature: 264
beit_base_patch16_224 False
[16, 16, 16]
beit_base_patch16_384 False
[16, 16, 16]
beit_large_patch16_224 False
[16, 16, 16]
beit_large_patch16_384 False
[16, 16, 16]
beit_large_patch16_512 False
[16, 16, 16]
beitv2_base_patch16_224 False
[16, 16, 16]
beitv2_large_patch16_224 False
[16, 16, 16]
cait_m36_384 False
[16, 16, 16]
cait_m48_448 False
[16, 16, 16]
cait_s24_224 False
[16, 16, 16]
cait_s24_384 False
[16, 16, 16]
cait_s36_384 False
[16, 16, 16]
cait_xs24_384 False
[16, 16, 16]
cait_xxs24_224 False
[16, 16, 16]
cait_xxs24_384 False
[16, 16, 16]
cait_xxs36_224 False
[16, 16, 16]
cait_xxs36_384 False
[16, 16, 16]
deit3_base_patch16_224 False
[16, 16, 16]
deit3_base_patch16_384 False
[16, 16, 16]
deit3_huge_patch14_224 False
[14, 14, 14]
deit3_large_patch16_224 False
[16, 16, 16]
deit3_large_patch16_384 False
[16, 16, 16]
deit3_medium_patch16_224 False
[16, 16, 16]
deit3_small_patch16_224 False
[16, 16, 16]
deit3_small_patch16_384 False
[16, 16, 16]
deit_base_distilled_patch16_224 False
[16, 16, 16]
deit_base_distilled_patch16_384 False
[16, 16, 16]
deit_base_patch16_224 False
[16, 16, 16]
deit_base_patch16_384 False
[16, 16, 16]
deit_small_distilled_patch16_224 False
[16, 16, 16]
deit_small_patch16_224 False
[16, 16, 16]
deit_tiny_distilled_patch16_224 False
[16, 16, 16]
deit_tiny_patch16_224 False
[16, 16, 16]
eva02_base_patch14_224 False
[14, 14, 14]
eva02_base_patch14_448 False
[14, 14, 14]
eva02_base_patch16_clip_224 False
[16, 16, 16]
eva02_enormous_patch14_clip_224 False
[14, 14, 14]
eva02_large_patch14_224 False
[14, 14, 14]
eva02_large_patch14_448 False
[14, 14, 14]
eva02_large_patch14_clip_224 False
[14, 14, 14]
eva02_large_patch14_clip_336 False
[14, 14, 14]
eva02_small_patch14_224 False
[14, 14, 14]
eva02_small_patch14_336 False
[14, 14, 14]
eva02_tiny_patch14_224 False
[14, 14, 14]
eva02_tiny_patch14_336 False
[14, 14, 14]
eva_giant_patch14_224 False
[14, 14, 14]
eva_giant_patch14_336 False
[14, 14, 14]
eva_giant_patch14_560 False
[14, 14, 14]
eva_giant_patch14_clip_224 False
[14, 14, 14]
eva_large_patch14_196 False
[14, 14, 14]
eva_large_patch14_336 False
[14, 14, 14]
flexivit_base False
[16, 16, 16]
flexivit_large False
[16, 16, 16]
flexivit_small False
[16, 16, 16]
gmixer_12_224 False
[16, 16, 16]
gmixer_24_224 False
[16, 16, 16]
gmlp_b16_224 False
[16, 16, 16]
gmlp_s16_224 False
[16, 16, 16]
gmlp_ti16_224 False
[16, 16, 16]
mixer_b16_224 False
[16, 16, 16]
mixer_b32_224 False
[32, 32, 32]
mixer_l16_224 False
[16, 16, 16]
mixer_l32_224 False
[32, 32, 32]
mixer_s16_224 False
[16, 16, 16]
mixer_s32_224 False
[32, 32, 32]
resmlp_12_224 False
[16, 16, 16]
resmlp_24_224 False
[16, 16, 16]
resmlp_36_224 False
[16, 16, 16]
resmlp_big_24_224 False
[8, 8, 8]
samvit_base_patch16 False
[16, 16, 16]
samvit_base_patch16_224 False
[16, 16, 16]
samvit_huge_patch16 False
[16, 16, 16]
samvit_large_patch16 False
[16, 16, 16]
vit_base_mci_224 False
[16, 16, 16]
vit_base_patch8_224 False
[8, 8, 8]
vit_base_patch14_dinov2 False
[14, 14, 14]
vit_base_patch14_reg4_dinov2 False
[14, 14, 14]
vit_base_patch16_18x2_224 False
[16, 16, 16]
vit_base_patch16_224 False
[16, 16, 16]
vit_base_patch16_224_miil False
[16, 16, 16]
vit_base_patch16_384 False
[16, 16, 16]
vit_base_patch16_clip_224 False
[16, 16, 16]
vit_base_patch16_clip_384 False
[16, 16, 16]
vit_base_patch16_clip_quickgelu_224 False
[16, 16, 16]
vit_base_patch16_gap_224 False
[16, 16, 16]
vit_base_patch16_plus_240 False
[16, 16, 16]
vit_base_patch16_plus_clip_240 False
[16, 16, 16]
vit_base_patch16_reg4_gap_256 False
[16, 16, 16]
vit_base_patch16_rope_reg1_gap_256 False
[16, 16, 16]
vit_base_patch16_rpn_224 False
[16, 16, 16]
vit_base_patch16_siglip_224 False
[16, 16, 16]
vit_base_patch16_siglip_256 False
[16, 16, 16]
vit_base_patch16_siglip_384 False
[16, 16, 16]
vit_base_patch16_siglip_512 False
[16, 16, 16]
vit_base_patch16_siglip_gap_224 False
[16, 16, 16]
vit_base_patch16_siglip_gap_256 False
[16, 16, 16]
vit_base_patch16_siglip_gap_384 False
[16, 16, 16]
vit_base_patch16_siglip_gap_512 False
[16, 16, 16]
vit_base_patch16_xp_224 False
[16, 16, 16]
vit_base_patch32_224 False
[32, 32, 32]
vit_base_patch32_384 False
[32, 32, 32]
vit_base_patch32_clip_224 False
[32, 32, 32]
vit_base_patch32_clip_256 False
[32, 32, 32]
vit_base_patch32_clip_384 False
[32, 32, 32]
vit_base_patch32_clip_448 False
[32, 32, 32]
vit_base_patch32_clip_quickgelu_224 False
[32, 32, 32]
vit_base_patch32_plus_256 False
[32, 32, 32]
vit_base_r26_s32_224 False
[32, 32, 32]
vit_base_r50_s16_224 False
[16, 16, 16]
vit_base_r50_s16_384 False
[16, 16, 16]
vit_base_resnet26d_224 False
[32, 32, 32]
vit_base_resnet50d_224 False
[32, 32, 32]
vit_betwixt_patch16_gap_256 False
[16, 16, 16]
vit_betwixt_patch16_reg1_gap_256 False
[16, 16, 16]
vit_betwixt_patch16_reg4_gap_256 False
[16, 16, 16]
vit_betwixt_patch16_reg4_gap_384 False
[16, 16, 16]
vit_betwixt_patch16_rope_reg4_gap_256 False
[16, 16, 16]
vit_betwixt_patch32_clip_224 False
[32, 32, 32]
vit_giant_patch14_224 False
[14, 14, 14]
vit_giant_patch14_clip_224 False
[14, 14, 14]
vit_giant_patch14_dinov2 False
[14, 14, 14]
vit_giant_patch14_reg4_dinov2 False
[14, 14, 14]
vit_giant_patch16_gap_224 False
[16, 16, 16]
vit_gigantic_patch14_224 False
[14, 14, 14]
vit_gigantic_patch14_clip_224 False
[14, 14, 14]
vit_gigantic_patch14_clip_quickgelu_224 False
[14, 14, 14]
vit_huge_patch14_224 False
[14, 14, 14]
vit_huge_patch14_clip_224 False
[14, 14, 14]
vit_huge_patch14_clip_336 False
[14, 14, 14]
vit_huge_patch14_clip_378 False
[14, 14, 14]
vit_huge_patch14_clip_quickgelu_224 False
[14, 14, 14]
vit_huge_patch14_clip_quickgelu_378 False
[14, 14, 14]
vit_huge_patch14_gap_224 False
[14, 14, 14]
vit_huge_patch14_xp_224 False
[14, 14, 14]
vit_huge_patch16_gap_448 False
[16, 16, 16]
vit_intern300m_patch14_448 False
[14, 14, 14]
vit_large_patch14_224 False
[14, 14, 14]
vit_large_patch14_clip_224 False
[14, 14, 14]
vit_large_patch14_clip_336 False
[14, 14, 14]
vit_large_patch14_clip_quickgelu_224 False
[14, 14, 14]
vit_large_patch14_clip_quickgelu_336 False
[14, 14, 14]
vit_large_patch14_dinov2 False
[14, 14, 14]
vit_large_patch14_reg4_dinov2 False
[14, 14, 14]
vit_large_patch14_xp_224 False
[14, 14, 14]
vit_large_patch16_224 False
[16, 16, 16]
vit_large_patch16_384 False
[16, 16, 16]
vit_large_patch16_siglip_256 False
[16, 16, 16]
vit_large_patch16_siglip_384 False
[16, 16, 16]
vit_large_patch16_siglip_gap_256 False
[16, 16, 16]
vit_large_patch16_siglip_gap_384 False
[16, 16, 16]
vit_large_patch32_224 False
[32, 32, 32]
vit_large_patch32_384 False
[32, 32, 32]
vit_large_r50_s32_224 False
[32, 32, 32]
vit_large_r50_s32_384 False
[32, 32, 32]
vit_little_patch16_reg1_gap_256 False
[16, 16, 16]
vit_little_patch16_reg4_gap_256 False
[16, 16, 16]
vit_medium_patch16_clip_224 False
[16, 16, 16]
vit_medium_patch16_gap_240 False
[16, 16, 16]
vit_medium_patch16_gap_256 False
[16, 16, 16]
vit_medium_patch16_gap_384 False
[16, 16, 16]
vit_medium_patch16_reg1_gap_256 False
[16, 16, 16]
vit_medium_patch16_reg4_gap_256 False
[16, 16, 16]
vit_medium_patch16_rope_reg1_gap_256 False
[16, 16, 16]
vit_medium_patch32_clip_224 False
[32, 32, 32]
vit_mediumd_patch16_reg4_gap_256 False
[16, 16, 16]
vit_mediumd_patch16_reg4_gap_384 False
[16, 16, 16]
vit_mediumd_patch16_rope_reg1_gap_256 False
[16, 16, 16]
vit_pwee_patch16_reg1_gap_256 False
[16, 16, 16]
vit_relpos_base_patch16_224 False
[16, 16, 16]
vit_relpos_base_patch16_cls_224 False
[16, 16, 16]
vit_relpos_base_patch16_clsgap_224 False
[16, 16, 16]
vit_relpos_base_patch16_plus_240 False
[16, 16, 16]
vit_relpos_base_patch16_rpn_224 False
[16, 16, 16]
vit_relpos_base_patch32_plus_rpn_256 False
[32, 32, 32]
vit_relpos_medium_patch16_224 False
[16, 16, 16]
vit_relpos_medium_patch16_cls_224 False
[16, 16, 16]
vit_relpos_medium_patch16_rpn_224 False
[16, 16, 16]
vit_relpos_small_patch16_224 False
[16, 16, 16]
vit_relpos_small_patch16_rpn_224 False
[16, 16, 16]
vit_small_patch8_224 False
[8, 8, 8]
vit_small_patch14_dinov2 False
[14, 14, 14]
vit_small_patch14_reg4_dinov2 False
[14, 14, 14]
vit_small_patch16_18x2_224 False
[16, 16, 16]
vit_small_patch16_36x1_224 False
[16, 16, 16]
vit_small_patch16_224 False
[16, 16, 16]
vit_small_patch16_384 False
[16, 16, 16]
vit_small_patch32_224 False
[32, 32, 32]
vit_small_patch32_384 False
[32, 32, 32]
vit_small_r26_s32_224 False
[32, 32, 32]
vit_small_r26_s32_384 False
[32, 32, 32]
vit_small_resnet26d_224 False
[32, 32, 32]
vit_small_resnet50d_s16_224 False
[16, 16, 16]
vit_so150m_patch16_reg4_gap_256 False
[16, 16, 16]
vit_so150m_patch16_reg4_map_256 False
[16, 16, 16]
vit_so400m_patch14_siglip_224 False
[14, 14, 14]
vit_so400m_patch14_siglip_378 False
[14, 14, 14]
vit_so400m_patch14_siglip_384 False
[14, 14, 14]
vit_so400m_patch14_siglip_gap_224 False
[14, 14, 14]
vit_so400m_patch14_siglip_gap_378 False
[14, 14, 14]
vit_so400m_patch14_siglip_gap_384 False
[14, 14, 14]
vit_so400m_patch14_siglip_gap_448 False
[14, 14, 14]
vit_so400m_patch14_siglip_gap_896 False
[14, 14, 14]
vit_so400m_patch16_siglip_256 False
[16, 16, 16]
vit_so400m_patch16_siglip_gap_256 False
[16, 16, 16]
vit_srelpos_medium_patch16_224 False
[16, 16, 16]
vit_srelpos_small_patch16_224 False
[16, 16, 16]
vit_tiny_patch16_224 False
[16, 16, 16]
vit_tiny_patch16_384 False
[16, 16, 16]
vit_tiny_r_s16_p8_224 False
[32, 32, 32]
vit_tiny_r_s16_p8_384 False
[32, 32, 32]
vit_wee_patch16_reg1_gap_256 False
[16, 16, 16]
vit_xsmall_patch16_clip_224 False
[16, 16, 16]
vitamin_base_224 False
[16, 16, 16]
vitamin_large2_224 False
[16, 16, 16]
vitamin_large2_256 False
[16, 16, 16]
vitamin_large2_336 False
[16, 16, 16]
vitamin_large2_384 False
[16, 16, 16]
vitamin_large_224 False
[16, 16, 16]
vitamin_large_256 False
[16, 16, 16]
vitamin_large_336 False
[16, 16, 16]
vitamin_large_384 False
[16, 16, 16]
vitamin_small_224 False
[16, 16, 16]
vitamin_xlarge_256 False
[16, 16, 16]
vitamin_xlarge_336 False
[16, 16, 16]
vitamin_xlarge_384 False
[16, 16, 16]
volo_d1_224 False
[16, 16, 16]
volo_d1_384 False
[16, 16, 16]
volo_d2_224 False
[16, 16, 16]
volo_d2_384 False
[16, 16, 16]
volo_d3_224 False
[16, 16, 16]
volo_d3_448 False
[16, 16, 16]
volo_d4_224 False
[16, 16, 16]
volo_d4_448 False
[16, 16, 16]
volo_d5_224 False
[16, 16, 16]
volo_d5_448 False
[16, 16, 16]
volo_d5_512 False
[16, 16, 16]
xcit_large_24_p8_224 False
[8, 8, 8]
xcit_large_24_p8_384 False
[8, 8, 8]
xcit_large_24_p16_224 False
[16, 16, 16]
xcit_large_24_p16_384 False
[16, 16, 16]
xcit_medium_24_p8_224 False
[8, 8, 8]
xcit_medium_24_p8_384 False
[8, 8, 8]
xcit_medium_24_p16_224 False
[16, 16, 16]
xcit_medium_24_p16_384 False
[16, 16, 16]
xcit_nano_12_p8_224 False
[8, 8, 8]
xcit_nano_12_p8_384 False
[8, 8, 8]
xcit_nano_12_p16_224 False
[16, 16, 16]
xcit_nano_12_p16_384 False
[16, 16, 16]
xcit_small_12_p8_224 False
[8, 8, 8]
xcit_small_12_p8_384 False
[8, 8, 8]
xcit_small_12_p16_224 False
[16, 16, 16]
xcit_small_12_p16_384 False
[16, 16, 16]
xcit_small_24_p8_224 False
[8, 8, 8]
xcit_small_24_p8_384 False
[8, 8, 8]
xcit_small_24_p16_224 False
[16, 16, 16]
xcit_small_24_p16_384 False
[16, 16, 16]
xcit_tiny_12_p8_224 False
[8, 8, 8]
xcit_tiny_12_p8_384 False
[8, 8, 8]
xcit_tiny_12_p16_224 False
[16, 16, 16]
xcit_tiny_12_p16_384 False
[16, 16, 16]
xcit_tiny_24_p8_224 False
[8, 8, 8]
xcit_tiny_24_p8_384 False
[8, 8, 8]
xcit_tiny_24_p16_224 False
[16, 16, 16]
xcit_tiny_24_p16_384 False
[16, 16, 16]

Collaborator

@qubvel qubvel left a comment

Hi @brianhou0208! Thanks for continuing to work on this 🚀 It already looks really great. I just have one question:

name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs)
)
# Load a temporary model to analyze its feature hierarchy
self.model = timm.create_model(name, features_only=True)
Collaborator

Why do we need to load a temporary model? I would try to avoid it if possible.

Contributor Author

@brianhou0208 brianhou0208 Dec 18, 2024

I believe that a temporary model is necessary because we need to determine feature_info.reduction() to classify the model as traditional, transformer, or VGG style. This affects the range of out_indices to be used:

common_kwargs["out_indices"] = tuple(range(depth))
  • If depth == 5, out_indices is
    • traditional-style (0, 1, 2, 3, 4)
    • transformer-style (0, 1, 2, 3)
    • vgg-style (0, 1, 2, 3, 4, 5)
  • If depth == 3, out_indices is
    • traditional-style (0, 1, 2)
    • transformer-style (0, 1)
    • vgg-style (0, 1, 2, 3)

Is there any other way to determine feature_info.reduction() in advance?
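
The index ranges listed above can be sketched as a small pure-Python helper. This is only an illustration under stated assumptions: `classify_reduction` and `out_indices_for` are hypothetical names, not functions from timm or segmentation_models_pytorch, and the mapping from reduction lists to styles (strides starting at 2 for traditional, 4 for transformer, 1 for VGG) is an assumption about how the classification might work. The input is meant to be whatever `feature_info.reduction()` returns for a model.

```python
from typing import Optional, Sequence, Tuple


def classify_reduction(reduction: Sequence[int]) -> Optional[str]:
    """Guess the encoder style from a list of per-stage reductions.

    Assumed patterns (hypothetical, for illustration only):
      traditional: 2, 4, 8, 16, 32
      transformer: 4, 8, 16, 32
      vgg:         1, 2, 4, 8, 16, 32
    Returns None when the list matches none of them.
    """
    scales = list(reduction)
    n = len(scales)
    if n == 0:
        return None
    if scales == [2 ** (i + 2) for i in range(n)]:
        return "transformer"
    if scales == [2 ** (i + 1) for i in range(n)]:
        return "traditional"
    if scales == [2 ** i for i in range(n)]:
        return "vgg"
    return None


def out_indices_for(style: str, depth: int) -> Tuple[int, ...]:
    """Return the out_indices range per style, following the list above."""
    if style == "traditional":
        return tuple(range(depth))
    if style == "transformer":
        return tuple(range(depth - 1))
    if style == "vgg":
        return tuple(range(depth + 1))
    raise ValueError(f"unknown style: {style!r}")


style = classify_reduction([4, 8, 16, 32])
indices = out_indices_for(style, depth=5)  # (0, 1, 2, 3)
```

Under these assumptions, a reduction list such as [16, 16, 16] matches none of the three patterns, which would be consistent with the models reported as False in the log above.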

Collaborator

Can we slice features in forward instead of providing "out_indices"? Otherwise, I would recommend using pretrained=False for the tmp model and maybe initializing it on the meta device to avoid double memory consumption.

Contributor Author

In timm.create_model(), the default is pretrained=False.
I think initializing the tmp model on torch.device("meta") is a good idea:

self.model = timm.create_model(name, pretrained=False, features_only=True).to("meta")

What do you think?

Collaborator

@qubvel qubvel Dec 18, 2024

An explicit pretrained=False would be nice; for meta it should be something like this:

with torch.device("meta"):
    tmp_model = timm.create_model(name, pretrained=False, features_only=True)

+ without self.
+ let's name it with tmp_

Collaborator

Let's leave it as is for now; it can be optimized later if needed.

Contributor Author

If we don't use additional variable names, it shouldn't take up extra memory?
I renamed temp_model to self.model, although the variable names will be a little confusing.

Collaborator

By "as is" I mean:

# Load a temporary model to analyze its feature hierarchy
try:
    with torch.device("meta"):
        tmp_model = timm.create_model(name, features_only=True)
except Exception:
    tmp_model = timm.create_model(name, features_only=True)

Sorry for the confusion.

Collaborator

@qubvel qubvel Dec 18, 2024

If we don't use additional variable names, it shouldn't take up extra memory?

I don't think so; we still allocate twice:

  1. The tmp model is initialized and bound to self.model.
  2. The required model is initialized.
  3. The tmp model is unbound from the self.model name and the required one is bound to it.

So two models exist at the same time.
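
The three steps above can be illustrated with a toy stand-in that needs neither torch nor timm. This is a minimal sketch: `FakeModel` and `Encoder` are hypothetical classes invented for the illustration, and the live-instance counter relies on CPython's reference counting, where `__del__` runs as soon as the last reference to an object goes away.

```python
class FakeModel:
    """Stand-in for a large model; tracks how many instances are alive."""

    live = 0  # instances currently alive
    peak = 0  # highest number of instances alive at the same time

    def __init__(self) -> None:
        FakeModel.live += 1
        FakeModel.peak = max(FakeModel.peak, FakeModel.live)

    def __del__(self) -> None:
        FakeModel.live -= 1


class Encoder:
    def __init__(self) -> None:
        # Step 1: the temporary model is created and bound to self.model.
        self.model = FakeModel()
        # Step 2: the right-hand side is evaluated first, so the required
        # model is fully built while the temporary one is still referenced.
        # Step 3: only then is self.model rebound and the old object released.
        self.model = FakeModel()


encoder = Encoder()
peak_when_rebinding = FakeModel.peak  # 2 on CPython: both models coexisted
alive_after = FakeModel.live          # 1 on CPython: only the required model remains
```

Reusing the attribute name therefore does not lower the peak. The old object is released only after its replacement exists, which is why building the temporary model on the meta device is the part that avoids the second allocation.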

Contributor Author

@brianhou0208 brianhou0208 Dec 18, 2024

I think you are right, thanks for your explanation

1. rename temporary model
2. create temporary model on meta device to speed up
@brianhou0208 brianhou0208 marked this pull request as ready for review December 18, 2024 17:37
Contributor Author

brianhou0208 commented Dec 18, 2024

Hi @qubvel ,

Thank you for your comment; it has made this PR more complete.
Some models in timm still have a few bugs, which I have reported to Ross.

However, I think they do not affect this PR. It's ready to be merged.

Collaborator

@qubvel qubvel left a comment

Thanks, can you please add one transformer-like and one VGG-like encoder to the tests? Then we are good to merge.

def get_encoders():
    exclude_encoders = [
        "senet154",
        "resnext101_32x16d",
        "resnext101_32x32d",
        "resnext101_32x48d",
    ]
    encoders = smp.encoders.get_encoder_names()
    encoders = [e for e in encoders if e not in exclude_encoders]
    encoders.append("tu-resnet34")  # for timm universal encoder
    return encoders

segmentation_models_pytorch/encoders/timm_universal.py Outdated
Contributor Author

brianhou0208 commented Dec 19, 2024

@qubvel It's ready to merge, please check

Since version 0.3.4, we have added new decoders and supported more timm encoders.
Shall we bump the version?

Collaborator

@qubvel qubvel left a comment

Thanks for delivering this super important feature!

@qubvel qubvel merged commit aff25b1 into qubvel-org:main Dec 19, 2024
12 checks passed
Collaborator

qubvel commented Dec 19, 2024

Shall we bump the version?

Yeah, I will do a release 👍

@brianhou0208 brianhou0208 deleted the update_timm_universal branch December 25, 2024 14:27