Update timm universal (support transformer-style model) #1004
Output Stride Not Matching
The following models were removed from the list as their output strides do not match the expected values:
test code
import torch
import segmentation_models_pytorch as smp

model_list = [
    "inception_resnet_v2",
    "inception_v3",
    "inception_v4",
    "legacy_xception",
    "nasnetalarge",
    "pnasnet5large",
]

if __name__ == "__main__":
    x = torch.rand(1, 3, 256, 256)
    for name in model_list:
        model = smp.encoders.get_encoder(f"tu-{name}", weights=None).eval()
        f = model(x)
        print(name, [f_.detach().numpy().shape[2:] for f_ in f])
output
inception_resnet_v2 [(256, 256), (125, 125), (60, 60), (29, 29), (14, 14), (6, 6)]
inception_v3 [(256, 256), (125, 125), (60, 60), (29, 29), (14, 14), (6, 6)]
inception_v4 [(256, 256), (125, 125), (62, 62), (29, 29), (14, 14), (6, 6)]
legacy_xception [(256, 256), (125, 125), (63, 63), (32, 32), (16, 16), (8, 8)]
nasnetalarge [(256, 256), (127, 127), (64, 64), (32, 32), (16, 16), (8, 8)]
pnasnet5large [(256, 256), (127, 127), (64, 64), (32, 32), (16, 16), (8, 8)]
Renamed / Deprecated Models
The following models remain functional but are deprecated in
test code
import torch
import segmentation_models_pytorch as smp

model_list = [
    "mnasnet_a1",
    "mnasnet_b1",
    "efficientnet_b2a",
    "efficientnet_b3a",
    "seresnext26tn_32x4d",
]

if __name__ == "__main__":
    x = torch.rand(1, 3, 256, 256)
    for name in model_list:
        model = smp.encoders.get_encoder(f"tu-{name}", weights=None).eval()
output
UserWarning: Mapping deprecated model name mnasnet_a1 to current semnasnet_100.
UserWarning: Mapping deprecated model name mnasnet_b1 to current mnasnet_100.
UserWarning: Mapping deprecated model name efficientnet_b2a to current efficientnet_b2.
UserWarning: Mapping deprecated model name efficientnet_b3a to current efficientnet_b3.
UserWarning: Mapping deprecated model name seresnext26tn_32x4d to current seresnext26t_32x4d.
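The warnings above imply a small old-to-new name table. As a minimal sketch, assuming you want to rewrite encoder names before passing them on: `DEPRECATED_TO_CURRENT` and `resolve_encoder_name` are hypothetical helpers (not part of timm or SMP), and the table holds only the five pairs printed in the warnings.

```python
# Old -> new names, copied from the UserWarning messages above.
DEPRECATED_TO_CURRENT = {
    "mnasnet_a1": "semnasnet_100",
    "mnasnet_b1": "mnasnet_100",
    "efficientnet_b2a": "efficientnet_b2",
    "efficientnet_b3a": "efficientnet_b3",
    "seresnext26tn_32x4d": "seresnext26t_32x4d",
}


def resolve_encoder_name(name: str) -> str:
    """Hypothetical helper: map a deprecated name to its current one.

    Accepts names with or without the "tu-" prefix and keeps the prefix as given.
    """
    prefix = "tu-"
    has_prefix = name.startswith(prefix)
    bare = name[len(prefix):] if has_prefix else name
    bare = DEPRECATED_TO_CURRENT.get(bare, bare)
    return prefix + bare if has_prefix else bare


print(resolve_encoder_name("tu-mnasnet_a1"))
print(resolve_encoder_name("tu-resnet34"))
```

Names that are not in the table pass through unchanged, so the helper is safe to apply to every encoder name.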
Add New Traditional-Style Models
new support models
cspdarknet53
darknet17
darknet21
darknet53
darknetaa53
sedarknet21
vgg11
vgg11_bn
vgg13
vgg13_bn
vgg16
vgg16_bn
vgg19
vgg19_bn
Add New Transformer-Style Models
Channel-First Models
new support models
Channel-Last Models
These models are clearly transformer-style models, but their format is channel-last.
new support models
Hi @brianhou0208! Thanks for working with this challenging feature! My main concerns are:
Let me know what you think?
Regarding NHWC format, I got the answer from Ross; the following should work:
getattr(model, "output_fmt", None) == "NHWC"
That attribute is only set if models have NHWC format, that's why we have to use
Also, there are some models that come with features in
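The `getattr` check above can be tried without loading any model. The sketch below is only an illustration under stated assumptions: `FakeFeatureModel` is a hypothetical stand-in for a timm `features_only` model, and `to_nchw_shape` works on shape tuples rather than tensors. With real tensors the equivalent step would be `feature.permute(0, 3, 1, 2)`.

```python
class FakeFeatureModel:
    """Hypothetical stand-in: sets output_fmt only when a format is given."""

    def __init__(self, output_fmt=None):
        if output_fmt is not None:
            self.output_fmt = output_fmt


def is_channel_last(model) -> bool:
    # The attribute may be missing, so fall back to None instead of raising.
    return getattr(model, "output_fmt", None) == "NHWC"


def to_nchw_shape(shape, channel_last: bool):
    """Reorder an (N, H, W, C) shape to (N, C, H, W) when channel_last is True."""
    if not channel_last:
        return tuple(shape)
    n, h, w, c = shape
    return (n, c, h, w)


nhwc_model = FakeFeatureModel("NHWC")
plain_model = FakeFeatureModel()

print(is_channel_last(nhwc_model), is_channel_last(plain_model))
print(to_nchw_shape((1, 56, 56, 96), is_channel_last(nhwc_model)))
```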
Hi @qubvel ,
Without using timm api test & result
test code
import torch
import timm
import segmentation_models_pytorch as smp

model_list = [
    ["dla34", 224],
    ["cspdarknet53", 224],
    ["efficientnet_x_b3", 224],
    ["efficientvit_m0", 224],
    ["inception_resnet_v2", 299],
    ["inception_v3", 299],
    ["inception_v4", 299],
    ["mambaout_tiny", 224],
    ["tresnet_m", 224],
    ["vit_tiny_patch16_224", 224],
]

if __name__ == "__main__":
    for model_name, img_size in model_list:
        x = torch.rand(1, 3, img_size, img_size)
        model = timm.create_model(f"{model_name}", features_only=True).eval()
        y = model(x)
        print(f"timm-{model_name}-(C, H, W) = {(3, img_size, img_size)}")
        print(f" Feature shape: {[f.detach().numpy().shape[1:] for f in y]}")
        print(f" Feature channels: {model.feature_info.channels()}")
        print(f" Feature reduction: {model.feature_info.reduction()}")
output
timm-dla34-(C, H, W) = (3, 224, 224)
Feature shape: [(32, 112, 112), (64, 56, 56), (128, 28, 28), (256, 14, 14), (512, 7, 7)]
Feature channels: [32, 64, 128, 256, 512]
Feature reduction: [2, 4, 8, 16, 32]
timm-cspdarknet53-(C, H, W) = (3, 224, 224)
Feature shape: [(32, 224, 224), (64, 112, 112), (128, 56, 56), (256, 28, 28), (512, 14, 14), (1024, 7, 7)]
Feature channels: [32, 64, 128, 256, 512, 1024]
Feature reduction: [1, 2, 4, 8, 16, 32]
timm-efficientnet_x_b3-(C, H, W) = (3, 224, 224)
Feature shape: [(96, 56, 56), (32, 56, 56), (48, 28, 28), (136, 14, 14), (384, 7, 7)]
Feature channels: [96, 32, 48, 136, 384]
Feature reduction: [2, 2, 4, 8, 16]
timm-efficientvit_m0-(C, H, W) = (3, 224, 224)
Feature shape: [(64, 14, 14), (128, 7, 7), (192, 4, 4)]
Feature channels: [64, 128, 192]
Feature reduction: [16, 32, 64]
timm-inception_resnet_v2-(C, H, W) = (3, 299, 299)
Feature shape: [(64, 147, 147), (192, 71, 71), (320, 35, 35), (1088, 17, 17), (1536, 8, 8)]
Feature channels: [64, 192, 320, 1088, 1536]
Feature reduction: [2, 4, 8, 16, 32]
timm-inception_v3-(C, H, W) = (3, 299, 299)
Feature shape: [(64, 147, 147), (192, 71, 71), (288, 35, 35), (768, 17, 17), (2048, 8, 8)]
Feature channels: [64, 192, 288, 768, 2048]
Feature reduction: [2, 4, 8, 16, 32]
timm-inception_v4-(C, H, W) = (3, 299, 299)
Feature shape: [(64, 147, 147), (160, 73, 73), (384, 35, 35), (1024, 17, 17), (1536, 8, 8)]
Feature channels: [64, 160, 384, 1024, 1536]
Feature reduction: [2, 4, 8, 16, 32]
timm-tresnet_m-(C, H, W) = (3, 224, 224)
Feature shape: [(64, 56, 56), (128, 28, 28), (1024, 14, 14), (2048, 7, 7)]
Feature channels: [64, 128, 1024, 2048]
Feature reduction: [4, 8, 16, 32]
timm-vit_tiny_patch16_224-(C, H, W) = (3, 224, 224)
Feature shape: [(192, 14, 14), (192, 14, 14), (192, 14, 14)]
Feature channels: [192, 192, 192]
Feature reduction: [16, 16, 16]
However, you might still encounter cases like:
out_indices test & result
test
import torch
import timm
import segmentation_models_pytorch as smp

model_list = [
    ["resnet18", 224, (0, 1, 2, 3, 4)],
    ["dla34", 224, (1, 2, 3, 4, 5)],
    ["mambaout_tiny", 224, (0, 1, 2, 3)],
    ["tresnet_m", 224, (1, 2, 3, 4)],
]

if __name__ == "__main__":
    for model_name, img_size, out_indices in model_list:
        x = torch.rand(1, 3, img_size, img_size)
        model = timm.create_model(f"{model_name}", features_only=True, out_indices=out_indices).eval()
        y = model(x)
        print(f"timm-{model_name}")
        print(f" Feature shape: {[f.detach().numpy().shape[1:] for f in y]}")
        model = smp.encoders.get_encoder(f"tu-{model_name}").eval()
        y = model(x)[1:]
        print(f"smp-{model_name}")
        print(f" Feature shape: {[f.detach().numpy().shape[1:] for f in y]}")
        print()
output
timm-resnet18
Feature shape: [(64, 112, 112), (64, 56, 56), (128, 28, 28), (256, 14, 14), (512, 7, 7)]
smp-resnet18
Feature shape: [(64, 112, 112), (64, 56, 56), (128, 28, 28), (256, 14, 14), (512, 7, 7)]
timm-dla34
Feature shape: [(32, 112, 112), (64, 56, 56), (128, 28, 28), (256, 14, 14), (512, 7, 7)]
smp-dla34
Feature shape: [(32, 112, 112), (64, 56, 56), (128, 28, 28), (256, 14, 14), (512, 7, 7)]
timm-mambaout_tiny
Feature shape: [(56, 56, 96), (28, 28, 192), (14, 14, 384), (7, 7, 576)]
smp-mambaout_tiny
Feature shape: [(0, 112, 112), (96, 56, 56), (192, 28, 28), (384, 14, 14), (576, 7, 7)]
timm-tresnet_m
Feature shape: [(64, 56, 56), (128, 28, 28), (1024, 14, 14), (2048, 7, 7)]
smp-tresnet_m
Feature shape: [(0, 112, 112), (64, 56, 56), (128, 28, 28), (1024, 14, 14), (2048, 7, 7)]
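The reductions reported by `feature_info.reduction()` can be cross-checked against the printed feature sizes without running a model. A minimal sketch, assuming "matching" means each feature side equals the input size divided by the reduction with no remainder; `strides_match` is a hypothetical helper, and the numbers are copied from the outputs in this thread.

```python
def strides_match(input_size, spatial_sizes, reductions):
    """True when every feature side equals input_size / reduction exactly."""
    if len(spatial_sizes) != len(reductions):
        return False
    return all(
        input_size % r == 0 and size == input_size // r
        for size, r in zip(spatial_sizes, reductions)
    )


# dla34 at 224 px: sizes line up with the reported reductions.
print(strides_match(224, [112, 56, 28, 14, 7], [2, 4, 8, 16, 32]))
# inception_v3 at 299 px: sizes do not line up.
print(strides_match(299, [147, 71, 35, 17, 8], [2, 4, 8, 16, 32]))
# inception_v4 at 256 px (sizes after the input-resolution entry).
print(strides_match(256, [125, 62, 29, 14, 6], [2, 4, 8, 16, 32]))
```

The check is deliberately strict: any rounding in the spatial size counts as a mismatch, which is the situation the Inception-family outputs show.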
I spent some time reviewing all the models in Timm Support Backbone
Unsupported feature extraction: 34
coat_lite_medium
coat_lite_medium_384
coat_lite_mini
coat_lite_small
coat_lite_tiny
coat_mini
coat_small
coat_tiny
convit_base
convit_small
convit_tiny
convmixer_768_32
convmixer_1024_20_ks9_p14
convmixer_1536_20
crossvit_9_240
crossvit_9_dagger_240
crossvit_15_240
crossvit_15_dagger_240
crossvit_15_dagger_408
crossvit_18_240
crossvit_18_dagger_240
crossvit_18_dagger_408
crossvit_base_240
crossvit_small_240
crossvit_tiny_240
gcvit_base
gcvit_small
gcvit_tiny
gcvit_xtiny
gcvit_xxtiny
tnt_b_patch16_224
tnt_s_patch16_224
visformer_small
visformer_tiny
Tests models: 14
test_byobnet
test_convnext
test_convnext2
test_convnext3
test_efficientnet
test_efficientnet_evos
test_efficientnet_gn
test_efficientnet_ln
test_mambaout
test_nfnet
test_resnet
test_vit
test_vit2
test_vit3
SMP Support Backbone
SMP Unsupported Backbone
Unsupported models: 310
beit_base_patch16_224
beit_base_patch16_384
beit_large_patch16_224
beit_large_patch16_384
beit_large_patch16_512
beitv2_base_patch16_224
beitv2_large_patch16_224
cait_m36_384
cait_m48_448
cait_s24_224
cait_s24_384
cait_s36_384
cait_xs24_384
cait_xxs24_224
cait_xxs24_384
cait_xxs36_224
cait_xxs36_384
deit3_base_patch16_224
deit3_base_patch16_384
deit3_huge_patch14_224
deit3_large_patch16_224
deit3_large_patch16_384
deit3_medium_patch16_224
deit3_small_patch16_224
deit3_small_patch16_384
deit_base_distilled_patch16_224
deit_base_distilled_patch16_384
deit_base_patch16_224
deit_base_patch16_384
deit_small_distilled_patch16_224
deit_small_patch16_224
deit_tiny_distilled_patch16_224
deit_tiny_patch16_224
efficientnet_h_b5
efficientnet_x_b3
efficientnet_x_b5
efficientvit_m0
efficientvit_m1
efficientvit_m2
efficientvit_m3
efficientvit_m4
efficientvit_m5
eva02_base_patch14_224
eva02_base_patch14_448
eva02_base_patch16_clip_224
eva02_enormous_patch14_clip_224
eva02_large_patch14_224
eva02_large_patch14_448
eva02_large_patch14_clip_224
eva02_large_patch14_clip_336
eva02_small_patch14_224
eva02_small_patch14_336
eva02_tiny_patch14_224
eva02_tiny_patch14_336
eva_giant_patch14_224
eva_giant_patch14_336
eva_giant_patch14_560
eva_giant_patch14_clip_224
eva_large_patch14_196
eva_large_patch14_336
flexivit_base
flexivit_large
flexivit_small
gmixer_12_224
gmixer_24_224
gmlp_b16_224
gmlp_s16_224
gmlp_ti16_224
inception_resnet_v2
inception_v3
inception_v4
legacy_xception
levit_128
levit_128s
levit_192
levit_256
levit_256d
levit_384
levit_384_s8
levit_512
levit_512_s8
levit_512d
levit_conv_128
levit_conv_128s
levit_conv_192
levit_conv_256
levit_conv_256d
levit_conv_384
levit_conv_384_s8
levit_conv_512
levit_conv_512_s8
levit_conv_512d
mixer_b16_224
mixer_b32_224
mixer_l16_224
mixer_l32_224
mixer_s16_224
mixer_s32_224
nasnetalarge
pit_b_224
pit_b_distilled_224
pit_s_224
pit_s_distilled_224
pit_ti_224
pit_ti_distilled_224
pit_xs_224
pit_xs_distilled_224
pnasnet5large
resmlp_12_224
resmlp_24_224
resmlp_36_224
resmlp_big_24_224
samvit_base_patch16
samvit_base_patch16_224
samvit_huge_patch16
samvit_large_patch16
sequencer2d_l
sequencer2d_m
sequencer2d_s
vit_base_mci_224
vit_base_patch8_224
vit_base_patch14_dinov2
vit_base_patch14_reg4_dinov2
vit_base_patch16_18x2_224
vit_base_patch16_224
vit_base_patch16_224_miil
vit_base_patch16_384
vit_base_patch16_clip_224
vit_base_patch16_clip_384
vit_base_patch16_clip_quickgelu_224
vit_base_patch16_gap_224
vit_base_patch16_plus_240
vit_base_patch16_plus_clip_240
vit_base_patch16_reg4_gap_256
vit_base_patch16_rope_reg1_gap_256
vit_base_patch16_rpn_224
vit_base_patch16_siglip_224
vit_base_patch16_siglip_256
vit_base_patch16_siglip_384
vit_base_patch16_siglip_512
vit_base_patch16_siglip_gap_224
vit_base_patch16_siglip_gap_256
vit_base_patch16_siglip_gap_384
vit_base_patch16_siglip_gap_512
vit_base_patch16_xp_224
vit_base_patch32_224
vit_base_patch32_384
vit_base_patch32_clip_224
vit_base_patch32_clip_256
vit_base_patch32_clip_384
vit_base_patch32_clip_448
vit_base_patch32_clip_quickgelu_224
vit_base_patch32_plus_256
vit_base_r26_s32_224
vit_base_r50_s16_224
vit_base_r50_s16_384
vit_base_resnet26d_224
vit_base_resnet50d_224
vit_betwixt_patch16_gap_256
vit_betwixt_patch16_reg1_gap_256
vit_betwixt_patch16_reg4_gap_256
vit_betwixt_patch16_reg4_gap_384
vit_betwixt_patch16_rope_reg4_gap_256
vit_betwixt_patch32_clip_224
vit_giant_patch14_224
vit_giant_patch14_clip_224
vit_giant_patch14_dinov2
vit_giant_patch14_reg4_dinov2
vit_giant_patch16_gap_224
vit_gigantic_patch14_224
vit_gigantic_patch14_clip_224
vit_gigantic_patch14_clip_quickgelu_224
vit_huge_patch14_224
vit_huge_patch14_clip_224
vit_huge_patch14_clip_336
vit_huge_patch14_clip_378
vit_huge_patch14_clip_quickgelu_224
vit_huge_patch14_clip_quickgelu_378
vit_huge_patch14_gap_224
vit_huge_patch14_xp_224
vit_huge_patch16_gap_448
vit_intern300m_patch14_448
vit_large_patch14_224
vit_large_patch14_clip_224
vit_large_patch14_clip_336
vit_large_patch14_clip_quickgelu_224
vit_large_patch14_clip_quickgelu_336
vit_large_patch14_dinov2
vit_large_patch14_reg4_dinov2
vit_large_patch14_xp_224
vit_large_patch16_224
vit_large_patch16_384
vit_large_patch16_siglip_256
vit_large_patch16_siglip_384
vit_large_patch16_siglip_gap_256
vit_large_patch16_siglip_gap_384
vit_large_patch32_224
vit_large_patch32_384
vit_large_r50_s32_224
vit_large_r50_s32_384
vit_little_patch16_reg1_gap_256
vit_little_patch16_reg4_gap_256
vit_medium_patch16_clip_224
vit_medium_patch16_gap_240
vit_medium_patch16_gap_256
vit_medium_patch16_gap_384
vit_medium_patch16_reg1_gap_256
vit_medium_patch16_reg4_gap_256
vit_medium_patch16_rope_reg1_gap_256
vit_medium_patch32_clip_224
vit_mediumd_patch16_reg4_gap_256
vit_mediumd_patch16_reg4_gap_384
vit_mediumd_patch16_rope_reg1_gap_256
vit_pwee_patch16_reg1_gap_256
vit_relpos_base_patch16_224
vit_relpos_base_patch16_cls_224
vit_relpos_base_patch16_clsgap_224
vit_relpos_base_patch16_plus_240
vit_relpos_base_patch16_rpn_224
vit_relpos_base_patch32_plus_rpn_256
vit_relpos_medium_patch16_224
vit_relpos_medium_patch16_cls_224
vit_relpos_medium_patch16_rpn_224
vit_relpos_small_patch16_224
vit_relpos_small_patch16_rpn_224
vit_small_patch8_224
vit_small_patch14_dinov2
vit_small_patch14_reg4_dinov2
vit_small_patch16_18x2_224
vit_small_patch16_36x1_224
vit_small_patch16_224
vit_small_patch16_384
vit_small_patch32_224
vit_small_patch32_384
vit_small_r26_s32_224
vit_small_r26_s32_384
vit_small_resnet26d_224
vit_small_resnet50d_s16_224
vit_so150m_patch16_reg4_gap_256
vit_so150m_patch16_reg4_map_256
vit_so400m_patch14_siglip_224
vit_so400m_patch14_siglip_378
vit_so400m_patch14_siglip_384
vit_so400m_patch14_siglip_gap_224
vit_so400m_patch14_siglip_gap_378
vit_so400m_patch14_siglip_gap_384
vit_so400m_patch14_siglip_gap_448
vit_so400m_patch14_siglip_gap_896
vit_so400m_patch16_siglip_256
vit_so400m_patch16_siglip_gap_256
vit_srelpos_medium_patch16_224
vit_srelpos_small_patch16_224
vit_tiny_patch16_224
vit_tiny_patch16_384
vit_tiny_r_s16_p8_224
vit_tiny_r_s16_p8_384
vit_wee_patch16_reg1_gap_256
vit_xsmall_patch16_clip_224
vitamin_base_224
vitamin_large2_224
vitamin_large2_256
vitamin_large2_336
vitamin_large2_384
vitamin_large_224
vitamin_large_256
vitamin_large_336
vitamin_large_384
vitamin_small_224
vitamin_xlarge_256
vitamin_xlarge_336
vitamin_xlarge_384
volo_d1_224
volo_d1_384
volo_d2_224
volo_d2_384
volo_d3_224
volo_d3_448
volo_d4_224
volo_d4_448
volo_d5_224
volo_d5_448
volo_d5_512
xcit_large_24_p8_224
xcit_large_24_p8_384
xcit_large_24_p16_224
xcit_large_24_p16_384
xcit_medium_24_p8_224
xcit_medium_24_p8_384
xcit_medium_24_p16_224
xcit_medium_24_p16_384
xcit_nano_12_p8_224
xcit_nano_12_p8_384
xcit_nano_12_p16_224
xcit_nano_12_p16_384
xcit_small_12_p8_224
xcit_small_12_p8_384
xcit_small_12_p16_224
xcit_small_12_p16_384
xcit_small_24_p8_224
xcit_small_24_p8_384
xcit_small_24_p16_224
xcit_small_24_p16_384
xcit_tiny_12_p8_224
xcit_tiny_12_p8_384
xcit_tiny_12_p16_224
xcit_tiny_12_p16_384
xcit_tiny_24_p8_224
xcit_tiny_24_p8_384
xcit_tiny_24_p16_224
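xcit_tiny_24_p16_384
Check for unsupported models in SMP
import torch
import timm

# model_list is not defined in the original snippet; it is assumed to hold
# the unsupported model names listed above (a short subset is shown here).
model_list = ["efficientnet_x_b3", "levit_128", "vit_tiny_patch16_224"]

if __name__ == "__main__":
    for model_name in model_list:
        model = timm.create_model(f"{model_name}", features_only=True).eval()
        # output_fmt is only set on models that emit NHWC features
        is_channel_last = getattr(model, "output_fmt", None) == "NHWC"
        print(f"{model_name} {is_channel_last}")
        print(model.feature_info.reduction())
With downsample feature: 46
efficientnet_h_b5 False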
[2, 2, 4, 8, 16]
efficientnet_x_b3 False
[2, 2, 4, 8, 16]
efficientnet_x_b5 False
[2, 2, 4, 8, 16]
efficientvit_m0 False
[16, 32, 64]
efficientvit_m1 False
[16, 32, 64]
efficientvit_m2 False
[16, 32, 64]
efficientvit_m3 False
[16, 32, 64]
efficientvit_m4 False
[16, 32, 64]
efficientvit_m5 False
[16, 32, 64]
inception_resnet_v2 False
[2, 4, 8, 16, 32]
inception_v3 False
[2, 4, 8, 16, 32]
inception_v4 False
[2, 4, 8, 16, 32]
legacy_xception False
[2, 4, 8, 16, 32]
levit_128 False
[16, 32, 64]
levit_128s False
[16, 32, 64]
levit_192 False
[16, 32, 64]
levit_256 False
[16, 32, 64]
levit_256d False
[16, 32, 64]
levit_384 False
[16, 32, 64]
levit_384_s8 False
[8, 16, 32]
levit_512 False
[16, 32, 64]
levit_512_s8 False
[8, 16, 32]
levit_512d False
[16, 32, 64]
levit_conv_128 False
[16, 32, 64]
levit_conv_128s False
[16, 32, 64]
levit_conv_192 False
[16, 32, 64]
levit_conv_256 False
[16, 32, 64]
levit_conv_256d False
[16, 32, 64]
levit_conv_384 False
[16, 32, 64]
levit_conv_384_s8 False
[8, 16, 32]
levit_conv_512 False
[16, 32, 64]
levit_conv_512_s8 False
[8, 16, 32]
levit_conv_512d False
[16, 32, 64]
nasnetalarge False
[2, 4, 8, 16, 32]
pit_b_224 False
[6, 12, 24]
pit_b_distilled_224 False
[6, 12, 24]
pit_s_224 False
[7, 14, 28]
pit_s_distilled_224 False
[7, 14, 28]
pit_ti_224 False
[7, 14, 28]
pit_ti_distilled_224 False
[7, 14, 28]
pit_xs_224 False
[7, 14, 28]
pit_xs_distilled_224 False
[7, 14, 28]
pnasnet5large False
[2, 4, 8, 16, 32]
sequencer2d_l True
[7, 14, 14]
sequencer2d_m True
[7, 14, 14]
sequencer2d_s True
[7, 14, 14]
Without downsample feature: 264
beit_base_patch16_224 False
[16, 16, 16]
beit_base_patch16_384 False
[16, 16, 16]
beit_large_patch16_224 False
[16, 16, 16]
beit_large_patch16_384 False
[16, 16, 16]
beit_large_patch16_512 False
[16, 16, 16]
beitv2_base_patch16_224 False
[16, 16, 16]
beitv2_large_patch16_224 False
[16, 16, 16]
cait_m36_384 False
[16, 16, 16]
cait_m48_448 False
[16, 16, 16]
cait_s24_224 False
[16, 16, 16]
cait_s24_384 False
[16, 16, 16]
cait_s36_384 False
[16, 16, 16]
cait_xs24_384 False
[16, 16, 16]
cait_xxs24_224 False
[16, 16, 16]
cait_xxs24_384 False
[16, 16, 16]
cait_xxs36_224 False
[16, 16, 16]
cait_xxs36_384 False
[16, 16, 16]
deit3_base_patch16_224 False
[16, 16, 16]
deit3_base_patch16_384 False
[16, 16, 16]
deit3_huge_patch14_224 False
[14, 14, 14]
deit3_large_patch16_224 False
[16, 16, 16]
deit3_large_patch16_384 False
[16, 16, 16]
deit3_medium_patch16_224 False
[16, 16, 16]
deit3_small_patch16_224 False
[16, 16, 16]
deit3_small_patch16_384 False
[16, 16, 16]
deit_base_distilled_patch16_224 False
[16, 16, 16]
deit_base_distilled_patch16_384 False
[16, 16, 16]
deit_base_patch16_224 False
[16, 16, 16]
deit_base_patch16_384 False
[16, 16, 16]
deit_small_distilled_patch16_224 False
[16, 16, 16]
deit_small_patch16_224 False
[16, 16, 16]
deit_tiny_distilled_patch16_224 False
[16, 16, 16]
deit_tiny_patch16_224 False
[16, 16, 16]
eva02_base_patch14_224 False
[14, 14, 14]
eva02_base_patch14_448 False
[14, 14, 14]
eva02_base_patch16_clip_224 False
[16, 16, 16]
eva02_enormous_patch14_clip_224 False
[14, 14, 14]
eva02_large_patch14_224 False
[14, 14, 14]
eva02_large_patch14_448 False
[14, 14, 14]
eva02_large_patch14_clip_224 False
[14, 14, 14]
eva02_large_patch14_clip_336 False
[14, 14, 14]
eva02_small_patch14_224 False
[14, 14, 14]
eva02_small_patch14_336 False
[14, 14, 14]
eva02_tiny_patch14_224 False
[14, 14, 14]
eva02_tiny_patch14_336 False
[14, 14, 14]
eva_giant_patch14_224 False
[14, 14, 14]
eva_giant_patch14_336 False
[14, 14, 14]
eva_giant_patch14_560 False
[14, 14, 14]
eva_giant_patch14_clip_224 False
[14, 14, 14]
eva_large_patch14_196 False
[14, 14, 14]
eva_large_patch14_336 False
[14, 14, 14]
flexivit_base False
[16, 16, 16]
flexivit_large False
[16, 16, 16]
flexivit_small False
[16, 16, 16]
gmixer_12_224 False
[16, 16, 16]
gmixer_24_224 False
[16, 16, 16]
gmlp_b16_224 False
[16, 16, 16]
gmlp_s16_224 False
[16, 16, 16]
gmlp_ti16_224 False
[16, 16, 16]
mixer_b16_224 False
[16, 16, 16]
mixer_b32_224 False
[32, 32, 32]
mixer_l16_224 False
[16, 16, 16]
mixer_l32_224 False
[32, 32, 32]
mixer_s16_224 False
[16, 16, 16]
mixer_s32_224 False
[32, 32, 32]
resmlp_12_224 False
[16, 16, 16]
resmlp_24_224 False
[16, 16, 16]
resmlp_36_224 False
[16, 16, 16]
resmlp_big_24_224 False
[8, 8, 8]
samvit_base_patch16 False
[16, 16, 16]
samvit_base_patch16_224 False
[16, 16, 16]
samvit_huge_patch16 False
[16, 16, 16]
samvit_large_patch16 False
[16, 16, 16]
vit_base_mci_224 False
[16, 16, 16]
vit_base_patch8_224 False
[8, 8, 8]
vit_base_patch14_dinov2 False
[14, 14, 14]
vit_base_patch14_reg4_dinov2 False
[14, 14, 14]
vit_base_patch16_18x2_224 False
[16, 16, 16]
vit_base_patch16_224 False
[16, 16, 16]
vit_base_patch16_224_miil False
[16, 16, 16]
vit_base_patch16_384 False
[16, 16, 16]
vit_base_patch16_clip_224 False
[16, 16, 16]
vit_base_patch16_clip_384 False
[16, 16, 16]
vit_base_patch16_clip_quickgelu_224 False
[16, 16, 16]
vit_base_patch16_gap_224 False
[16, 16, 16]
vit_base_patch16_plus_240 False
[16, 16, 16]
vit_base_patch16_plus_clip_240 False
[16, 16, 16]
vit_base_patch16_reg4_gap_256 False
[16, 16, 16]
vit_base_patch16_rope_reg1_gap_256 False
[16, 16, 16]
vit_base_patch16_rpn_224 False
[16, 16, 16]
vit_base_patch16_siglip_224 False
[16, 16, 16]
vit_base_patch16_siglip_256 False
[16, 16, 16]
vit_base_patch16_siglip_384 False
[16, 16, 16]
vit_base_patch16_siglip_512 False
[16, 16, 16]
vit_base_patch16_siglip_gap_224 False
[16, 16, 16]
vit_base_patch16_siglip_gap_256 False
[16, 16, 16]
vit_base_patch16_siglip_gap_384 False
[16, 16, 16]
vit_base_patch16_siglip_gap_512 False
[16, 16, 16]
vit_base_patch16_xp_224 False
[16, 16, 16]
vit_base_patch32_224 False
[32, 32, 32]
vit_base_patch32_384 False
[32, 32, 32]
vit_base_patch32_clip_224 False
[32, 32, 32]
vit_base_patch32_clip_256 False
[32, 32, 32]
vit_base_patch32_clip_384 False
[32, 32, 32]
vit_base_patch32_clip_448 False
[32, 32, 32]
vit_base_patch32_clip_quickgelu_224 False
[32, 32, 32]
vit_base_patch32_plus_256 False
[32, 32, 32]
vit_base_r26_s32_224 False
[32, 32, 32]
vit_base_r50_s16_224 False
[16, 16, 16]
vit_base_r50_s16_384 False
[16, 16, 16]
vit_base_resnet26d_224 False
[32, 32, 32]
vit_base_resnet50d_224 False
[32, 32, 32]
vit_betwixt_patch16_gap_256 False
[16, 16, 16]
vit_betwixt_patch16_reg1_gap_256 False
[16, 16, 16]
vit_betwixt_patch16_reg4_gap_256 False
[16, 16, 16]
vit_betwixt_patch16_reg4_gap_384 False
[16, 16, 16]
vit_betwixt_patch16_rope_reg4_gap_256 False
[16, 16, 16]
vit_betwixt_patch32_clip_224 False
[32, 32, 32]
vit_giant_patch14_224 False
[14, 14, 14]
vit_giant_patch14_clip_224 False
[14, 14, 14]
vit_giant_patch14_dinov2 False
[14, 14, 14]
vit_giant_patch14_reg4_dinov2 False
[14, 14, 14]
vit_giant_patch16_gap_224 False
[16, 16, 16]
vit_gigantic_patch14_224 False
[14, 14, 14]
vit_gigantic_patch14_clip_224 False
[14, 14, 14]
vit_gigantic_patch14_clip_quickgelu_224 False
[14, 14, 14]
vit_huge_patch14_224 False
[14, 14, 14]
vit_huge_patch14_clip_224 False
[14, 14, 14]
vit_huge_patch14_clip_336 False
[14, 14, 14]
vit_huge_patch14_clip_378 False
[14, 14, 14]
vit_huge_patch14_clip_quickgelu_224 False
[14, 14, 14]
vit_huge_patch14_clip_quickgelu_378 False
[14, 14, 14]
vit_huge_patch14_gap_224 False
[14, 14, 14]
vit_huge_patch14_xp_224 False
[14, 14, 14]
vit_huge_patch16_gap_448 False
[16, 16, 16]
vit_intern300m_patch14_448 False
[14, 14, 14]
vit_large_patch14_224 False
[14, 14, 14]
vit_large_patch14_clip_224 False
[14, 14, 14]
vit_large_patch14_clip_336 False
[14, 14, 14]
vit_large_patch14_clip_quickgelu_224 False
[14, 14, 14]
vit_large_patch14_clip_quickgelu_336 False
[14, 14, 14]
vit_large_patch14_dinov2 False
[14, 14, 14]
vit_large_patch14_reg4_dinov2 False
[14, 14, 14]
vit_large_patch14_xp_224 False
[14, 14, 14]
vit_large_patch16_224 False
[16, 16, 16]
vit_large_patch16_384 False
[16, 16, 16]
vit_large_patch16_siglip_256 False
[16, 16, 16]
vit_large_patch16_siglip_384 False
[16, 16, 16]
vit_large_patch16_siglip_gap_256 False
[16, 16, 16]
vit_large_patch16_siglip_gap_384 False
[16, 16, 16]
vit_large_patch32_224 False
[32, 32, 32]
vit_large_patch32_384 False
[32, 32, 32]
vit_large_r50_s32_224 False
[32, 32, 32]
vit_large_r50_s32_384 False
[32, 32, 32]
vit_little_patch16_reg1_gap_256 False
[16, 16, 16]
vit_little_patch16_reg4_gap_256 False
[16, 16, 16]
vit_medium_patch16_clip_224 False
[16, 16, 16]
vit_medium_patch16_gap_240 False
[16, 16, 16]
vit_medium_patch16_gap_256 False
[16, 16, 16]
vit_medium_patch16_gap_384 False
[16, 16, 16]
vit_medium_patch16_reg1_gap_256 False
[16, 16, 16]
vit_medium_patch16_reg4_gap_256 False
[16, 16, 16]
vit_medium_patch16_rope_reg1_gap_256 False
[16, 16, 16]
vit_medium_patch32_clip_224 False
[32, 32, 32]
vit_mediumd_patch16_reg4_gap_256 False
[16, 16, 16]
vit_mediumd_patch16_reg4_gap_384 False
[16, 16, 16]
vit_mediumd_patch16_rope_reg1_gap_256 False
[16, 16, 16]
vit_pwee_patch16_reg1_gap_256 False
[16, 16, 16]
vit_relpos_base_patch16_224 False
[16, 16, 16]
vit_relpos_base_patch16_cls_224 False
[16, 16, 16]
vit_relpos_base_patch16_clsgap_224 False
[16, 16, 16]
vit_relpos_base_patch16_plus_240 False
[16, 16, 16]
vit_relpos_base_patch16_rpn_224 False
[16, 16, 16]
vit_relpos_base_patch32_plus_rpn_256 False
[32, 32, 32]
vit_relpos_medium_patch16_224 False
[16, 16, 16]
vit_relpos_medium_patch16_cls_224 False
[16, 16, 16]
vit_relpos_medium_patch16_rpn_224 False
[16, 16, 16]
vit_relpos_small_patch16_224 False
[16, 16, 16]
vit_relpos_small_patch16_rpn_224 False
[16, 16, 16]
vit_small_patch8_224 False
[8, 8, 8]
vit_small_patch14_dinov2 False
[14, 14, 14]
vit_small_patch14_reg4_dinov2 False
[14, 14, 14]
vit_small_patch16_18x2_224 False
[16, 16, 16]
vit_small_patch16_36x1_224 False
[16, 16, 16]
vit_small_patch16_224 False
[16, 16, 16]
vit_small_patch16_384 False
[16, 16, 16]
vit_small_patch32_224 False
[32, 32, 32]
vit_small_patch32_384 False
[32, 32, 32]
vit_small_r26_s32_224 False
[32, 32, 32]
vit_small_r26_s32_384 False
[32, 32, 32]
vit_small_resnet26d_224 False
[32, 32, 32]
vit_small_resnet50d_s16_224 False
[16, 16, 16]
vit_so150m_patch16_reg4_gap_256 False
[16, 16, 16]
vit_so150m_patch16_reg4_map_256 False
[16, 16, 16]
vit_so400m_patch14_siglip_224 False
[14, 14, 14]
vit_so400m_patch14_siglip_378 False
[14, 14, 14]
vit_so400m_patch14_siglip_384 False
[14, 14, 14]
vit_so400m_patch14_siglip_gap_224 False
[14, 14, 14]
vit_so400m_patch14_siglip_gap_378 False
[14, 14, 14]
vit_so400m_patch14_siglip_gap_384 False
[14, 14, 14]
vit_so400m_patch14_siglip_gap_448 False
[14, 14, 14]
vit_so400m_patch14_siglip_gap_896 False
[14, 14, 14]
vit_so400m_patch16_siglip_256 False
[16, 16, 16]
vit_so400m_patch16_siglip_gap_256 False
[16, 16, 16]
vit_srelpos_medium_patch16_224 False
[16, 16, 16]
vit_srelpos_small_patch16_224 False
[16, 16, 16]
vit_tiny_patch16_224 False
[16, 16, 16]
vit_tiny_patch16_384 False
[16, 16, 16]
vit_tiny_r_s16_p8_224 False
[32, 32, 32]
vit_tiny_r_s16_p8_384 False
[32, 32, 32]
vit_wee_patch16_reg1_gap_256 False
[16, 16, 16]
vit_xsmall_patch16_clip_224 False
[16, 16, 16]
vitamin_base_224 False
[16, 16, 16]
vitamin_large2_224 False
[16, 16, 16]
vitamin_large2_256 False
[16, 16, 16]
vitamin_large2_336 False
[16, 16, 16]
vitamin_large2_384 False
[16, 16, 16]
vitamin_large_224 False
[16, 16, 16]
vitamin_large_256 False
[16, 16, 16]
vitamin_large_336 False
[16, 16, 16]
vitamin_large_384 False
[16, 16, 16]
vitamin_small_224 False
[16, 16, 16]
vitamin_xlarge_256 False
[16, 16, 16]
vitamin_xlarge_336 False
[16, 16, 16]
vitamin_xlarge_384 False
[16, 16, 16]
volo_d1_224 False
[16, 16, 16]
volo_d1_384 False
[16, 16, 16]
volo_d2_224 False
[16, 16, 16]
volo_d2_384 False
[16, 16, 16]
volo_d3_224 False
[16, 16, 16]
volo_d3_448 False
[16, 16, 16]
volo_d4_224 False
[16, 16, 16]
volo_d4_448 False
[16, 16, 16]
volo_d5_224 False
[16, 16, 16]
volo_d5_448 False
[16, 16, 16]
volo_d5_512 False
[16, 16, 16]
xcit_large_24_p8_224 False
[8, 8, 8]
xcit_large_24_p8_384 False
[8, 8, 8]
xcit_large_24_p16_224 False
[16, 16, 16]
xcit_large_24_p16_384 False
[16, 16, 16]
xcit_medium_24_p8_224 False
[8, 8, 8]
xcit_medium_24_p8_384 False
[8, 8, 8]
xcit_medium_24_p16_224 False
[16, 16, 16]
xcit_medium_24_p16_384 False
[16, 16, 16]
xcit_nano_12_p8_224 False
[8, 8, 8]
xcit_nano_12_p8_384 False
[8, 8, 8]
xcit_nano_12_p16_224 False
[16, 16, 16]
xcit_nano_12_p16_384 False
[16, 16, 16]
xcit_small_12_p8_224 False
[8, 8, 8]
xcit_small_12_p8_384 False
[8, 8, 8]
xcit_small_12_p16_224 False
[16, 16, 16]
xcit_small_12_p16_384 False
[16, 16, 16]
xcit_small_24_p8_224 False
[8, 8, 8]
xcit_small_24_p8_384 False
[8, 8, 8]
xcit_small_24_p16_224 False
[16, 16, 16]
xcit_small_24_p16_384 False
[16, 16, 16]
xcit_tiny_12_p8_224 False
[8, 8, 8]
xcit_tiny_12_p8_384 False
[8, 8, 8]
xcit_tiny_12_p16_224 False
[16, 16, 16]
xcit_tiny_12_p16_384 False
[16, 16, 16]
xcit_tiny_24_p8_224 False
[8, 8, 8]
xcit_tiny_24_p8_384 False
[8, 8, 8]
xcit_tiny_24_p16_224 False
[16, 16, 16]
xcit_tiny_24_p16_384 False
[16, 16, 16]
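The split between the "with downsample" and "without downsample" groups above can be reproduced from the reduction lists alone. A minimal sketch, assuming the rule is simply "not all reductions are equal" (which is consistent with every entry shown, though the thread does not state the rule); `has_downsample` is a hypothetical helper and the sample values are copied from the listing.

```python
def has_downsample(reductions):
    """True if the reported reductions are not all identical."""
    return len(set(reductions)) > 1


# A few entries copied from the listing above.
samples = {
    "efficientvit_m0": [16, 32, 64],
    "sequencer2d_s": [7, 14, 14],
    "vit_tiny_patch16_224": [16, 16, 16],
    "xcit_tiny_24_p8_224": [8, 8, 8],
}

with_downsample = [name for name, r in samples.items() if has_downsample(r)]
without_downsample = [name for name, r in samples.items() if not has_downsample(r)]

print("with downsample:", with_downsample)
print("without downsample:", without_downsample)
```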
Hi @brianhou0208! Thanks for continuing to work on this 🚀 It already looks really great. I just have one question:
    name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs)
)
# Load a temporary model to analyze its feature hierarchy
self.model = timm.create_model(name, features_only=True)
Why do we need to load a temporary model? I would try to avoid it if possible.
I believe that a temporary model is necessary because we need to determine feature_info.reduction() to classify the model as traditional, transformer, or VGG style. This affects the range of out_indices to be used:
common_kwargs["out_indices"] = tuple(range(depth))
- If depth == 5, out_indices is
  - traditional-style: (0, 1, 2, 3, 4)
  - transformer-style: (0, 1, 2, 3)
  - vgg-style: (0, 1, 2, 3, 4, 5)
- If depth == 3, out_indices is
  - traditional-style: (0, 1, 2)
  - transformer-style: (0, 1)
  - vgg-style: (0, 1, 2, 3)
Is there any other way to determine feature_info.reduction() in advance?
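To make the index ranges above concrete, here is a minimal sketch of the classification step. The thread says the style is derived from `feature_info.reduction()` but does not spell out the rule, so the power-of-two patterns below are an assumption chosen to match the outputs shown earlier (reductions starting at 2, 4, or 1). `classify_style` and `out_indices_for` are hypothetical names, and special cases such as dla are not covered.

```python
def classify_style(reductions):
    """Guess the encoder style from a reduction list (assumed rule)."""
    if not reductions:
        raise ValueError("empty reduction list")
    n = len(reductions)
    if reductions == [2 ** (i + 1) for i in range(n)]:
        return "traditional"  # 2, 4, 8, ...
    if reductions == [2 ** (i + 2) for i in range(n)]:
        return "transformer"  # 4, 8, 16, ...
    if reductions == [2 ** i for i in range(n)]:
        return "vgg"  # 1, 2, 4, ...
    raise ValueError(f"unsupported reductions: {reductions}")


def out_indices_for(style, depth):
    """Index ranges from the list above: depth, depth - 1, or depth + 1 entries."""
    offset = {"traditional": 0, "transformer": -1, "vgg": 1}[style]
    return tuple(range(depth + offset))


for reductions in ([2, 4, 8, 16, 32], [4, 8, 16, 32], [1, 2, 4, 8, 16, 32]):
    style = classify_style(reductions)
    print(style, out_indices_for(style, depth=5))
```

Lists such as `[16, 32, 64]` or `[16, 16, 16]` fall through to the `ValueError`, which lines up with those models appearing in the unsupported list.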
Can we slice features in forward instead of providing "out_indices"? Otherwise, I would recommend using pretrained=False for the tmp model and maybe initializing it on the meta device to avoid double memory consumption.
In timm.create_model(), the default is pretrained=False.
I think initializing the tmp model on torch.device("meta") is good:
self.model = timm.create_model(name, pretrained=False, features_only=True).to("meta")
What do you think?
Explicit pretrained=False would be nice; for meta it should be something like this:
with torch.device("meta"):
    tmp_model = timm.create_model(name, pretrained=False, features_only=True)
+ without self.
+ let's name it with tmp_
Let's leave it as is for now; it can be optimized later if needed.
If we don't use additional variable names, it shouldn't take up extra memory?
I renamed temp_model to self.model, although the variable names will be a little confusing.
As is I mean:
# Load a temporary model to analyze its feature hierarchy
try:
    with torch.device("meta"):
        tmp_model = timm.create_model(name, features_only=True)
except Exception:
    tmp_model = timm.create_model(name, features_only=True)
Sorry for the confusion.
If we don't use additional variable names, it shouldn't take up extra memory?
Don't think so, we still allocate twice:
- we have the tmp model initialized and linked to self.model
- we initialize the required model
- we unlink the tmp model from the self.model var name and link the required one
Two models exist at a time.
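The three steps above can be observed with a small counter, without loading any real model. This is a sketch under stated assumptions: `TrackedModel` and `Encoder` are hypothetical stand-ins, and the immediate release of the old object relies on CPython's reference counting.

```python
class TrackedModel:
    """Hypothetical stand-in for a model; counts how many instances are alive."""

    live = 0
    peak = 0

    def __init__(self):
        TrackedModel.live += 1
        TrackedModel.peak = max(TrackedModel.peak, TrackedModel.live)

    def __del__(self):
        TrackedModel.live -= 1


class Encoder:
    def __init__(self):
        # Step 1: the temporary model is created and bound to self.model.
        self.model = TrackedModel()
        # Steps 2-3: the right-hand side is built first, so both objects exist
        # until the assignment rebinds self.model and the old one is released.
        self.model = TrackedModel()


encoder = Encoder()
print("peak live models:", TrackedModel.peak)
print("live models afterwards:", TrackedModel.live)
```

Reusing the same attribute name therefore does not lower the peak; creating the temporary model on the meta device (as suggested in the thread) or deleting it before building the final one would.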
I think you are right, thanks for your explanation
1. rename temporary model
2. create temporary model on meta device to speed up
Hi @qubvel, thank you for your comment; it has made this PR more complete. However, I think they do not affect this PR. It's ready to be merged.
Thanks, can you please add one transformer-like and one vgg-like encoder to the tests? Then we are good to merge.
segmentation_models.pytorch/tests/test_models.py
Lines 7 to 17 in 34ee31d
def get_encoders():
    exclude_encoders = [
        "senet154",
        "resnext101_32x16d",
        "resnext101_32x32d",
        "resnext101_32x48d",
    ]
    encoders = smp.encoders.get_encoder_names()
    encoders = [e for e in encoders if e not in exclude_encoders]
    encoders.append("tu-resnet34")  # for timm universal encoder
    return encoders
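One way the requested test change could look, as a sketch only: the function below mirrors `get_encoders` but takes the base list as an argument so it runs without SMP installed. The two extra encoder names are illustrative picks from models discussed in this thread, not necessarily the ones added in the PR, and `extend_with_timm_universal` is a hypothetical name.

```python
def extend_with_timm_universal(encoders, exclude_encoders=()):
    """Filter the base list, then append one encoder per style."""
    encoders = [e for e in encoders if e not in exclude_encoders]
    encoders.append("tu-resnet34")       # traditional-style, already in the test
    encoders.append("tu-mambaout_tiny")  # assumed transformer-like pick
    encoders.append("tu-cspdarknet53")   # assumed vgg-like pick
    return encoders


# Stand-in for smp.encoders.get_encoder_names(), kept short for illustration.
base = ["resnet18", "senet154"]
print(extend_with_timm_universal(base, exclude_encoders=["senet154"]))
```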
@qubvel It's ready to merge, please check Since version |
Thanks for delivering this super important feature!
Yeah, I will do a release 👍
Hi, @qubvel,
This PR improves TimmUniversalEncoder to better support transformer-style models, updates the documentation, and ensures compatibility with timm==1.0.12.
Key Updates
- TimmUniversalEncoder to seamlessly handle both traditional (CNN-based) and transformer-style models.
- timm==1.0.12
Details of Changes
TimmUniversalEncoder
- feature_info.reduction() to determine whether a model is traditional or transformer-style.
- out_indices for models like tresnet and dla to ensure accurate feature extraction.
Documentation
Testing
- timm==1.0.12 for all models with features_only enabled.