From aff25b10c1737d542079adfed7477a96fd990408 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Thu, 19 Dec 2024 23:16:34 +0800 Subject: [PATCH] Update timm universal (support transformer-style model) (#1004) * Update timm_universal.py * Fix ruff style and typing * Update encoders_timm.rst * Fix typo error * Fix typo error & Update doc * Fix typo error * Support channel-last format * Update encoders_timm.rst * Update timm_universal.py * Fix ruff style * Update timm_universal.py 1. rename temporary model 2. create temporary model on meta device to speed up * Add tests/test_models & fix type * Update test_models.py * Update test_models.py --- docs/encoders_timm.rst | 839 +++++++++++++++--- .../encoders/timm_universal.py | 174 +++- tests/test_models.py | 21 +- 3 files changed, 879 insertions(+), 155 deletions(-) diff --git a/docs/encoders_timm.rst b/docs/encoders_timm.rst index 26a18a64..31c8396e 100644 --- a/docs/encoders_timm.rst +++ b/docs/encoders_timm.rst @@ -1,5 +1,5 @@ 🎯 Timm Encoders -~~~~~~~~~~~~~~~~ +================ Pytorch Image Models (a.k.a. timm) has a lot of pretrained models and interface which allows using these models as encoders in smp, however, not all models are supported @@ -9,26 +9,20 @@ however, not all models are supported Below is a table of suitable encoders (for DeepLabV3, DeepLabV3+, and PAN dilation support is needed also) -Total number of encoders: 549 +Total number of encoders: 812 (593+219) .. note:: To use following encoders you have to add prefix ``tu-``, e.g. 
``tu-adv_inception_v3`` +Traditional-Style +~~~~~~~~~~~~~~~~~ + +These models typically produce feature maps at the following downsampling scales relative to the input resolution: 1/2, 1/4, 1/8, 1/16, and 1/32 +----------------------------------+------------------+ | Encoder name | Support dilation | +==================================+==================+ -| SelecSls42 | | -+----------------------------------+------------------+ -| SelecSls42b | | -+----------------------------------+------------------+ -| SelecSls60 | | -+----------------------------------+------------------+ -| SelecSls60b | | -+----------------------------------+------------------+ -| SelecSls84 | | -+----------------------------------+------------------+ | bat_resnext26ts | ✅ | +----------------------------------+------------------+ | botnet26t_256 | ✅ | @@ -105,6 +99,8 @@ Total number of encoders: 549 +----------------------------------+------------------+ | cs3sedarknet_xdw | ✅ | +----------------------------------+------------------+ +| cspdarknet53 | ✅ | ++----------------------------------+------------------+ | cspresnet50 | ✅ | +----------------------------------+------------------+ | cspresnet50d | ✅ | @@ -113,6 +109,14 @@ Total number of encoders: 549 +----------------------------------+------------------+ | cspresnext50 | ✅ | +----------------------------------+------------------+ +| darknet17 | ✅ | ++----------------------------------+------------------+ +| darknet21 | ✅ | ++----------------------------------+------------------+ +| darknet53 | ✅ | ++----------------------------------+------------------+ +| darknetaa53 | ✅ | ++----------------------------------+------------------+ | densenet121 | | +----------------------------------+------------------+ | densenet161 | | @@ -125,14 +129,6 @@ Total number of encoders: 549 +----------------------------------+------------------+ | densenetblur121d | | +----------------------------------+------------------+ -| dla102 | | 
-+----------------------------------+------------------+ -| dla102x | | -+----------------------------------+------------------+ -| dla102x2 | | -+----------------------------------+------------------+ -| dla169 | | -+----------------------------------+------------------+ | dla34 | | +----------------------------------+------------------+ | dla46_c | | @@ -149,6 +145,14 @@ Total number of encoders: 549 +----------------------------------+------------------+ | dla60x_c | | +----------------------------------+------------------+ +| dla102 | | ++----------------------------------+------------------+ +| dla102x | | ++----------------------------------+------------------+ +| dla102x2 | | ++----------------------------------+------------------+ +| dla169 | | ++----------------------------------+------------------+ | dm_nfnet_f0 | ✅ | +----------------------------------+------------------+ | dm_nfnet_f1 | ✅ | @@ -163,10 +167,6 @@ Total number of encoders: 549 +----------------------------------+------------------+ | dm_nfnet_f6 | ✅ | +----------------------------------+------------------+ -| dpn107 | | -+----------------------------------+------------------+ -| dpn131 | | -+----------------------------------+------------------+ | dpn48b | | +----------------------------------+------------------+ | dpn68 | | @@ -177,6 +177,10 @@ Total number of encoders: 549 +----------------------------------+------------------+ | dpn98 | | +----------------------------------+------------------+ +| dpn107 | | ++----------------------------------+------------------+ +| dpn131 | | ++----------------------------------+------------------+ | eca_botnext26ts_256 | ✅ | +----------------------------------+------------------+ | eca_halonext26ts | ✅ | @@ -195,14 +199,6 @@ Total number of encoders: 549 +----------------------------------+------------------+ | eca_vovnet39b | | +----------------------------------+------------------+ -| ecaresnet101d | ✅ | 
-+----------------------------------+------------------+ -| ecaresnet101d_pruned | ✅ | -+----------------------------------+------------------+ -| ecaresnet200d | ✅ | -+----------------------------------+------------------+ -| ecaresnet269d | ✅ | -+----------------------------------+------------------+ | ecaresnet26t | ✅ | +----------------------------------+------------------+ | ecaresnet50d | ✅ | @@ -211,6 +207,14 @@ Total number of encoders: 549 +----------------------------------+------------------+ | ecaresnet50t | ✅ | +----------------------------------+------------------+ +| ecaresnet101d | ✅ | ++----------------------------------+------------------+ +| ecaresnet101d_pruned | ✅ | ++----------------------------------+------------------+ +| ecaresnet200d | ✅ | ++----------------------------------+------------------+ +| ecaresnet269d | ✅ | ++----------------------------------+------------------+ | ecaresnetlight | ✅ | +----------------------------------+------------------+ | ecaresnext26t_32x4d | ✅ | @@ -219,10 +223,10 @@ Total number of encoders: 549 +----------------------------------+------------------+ | efficientnet_b0 | ✅ | +----------------------------------+------------------+ -| efficientnet_b0_g16_evos | ✅ | -+----------------------------------+------------------+ | efficientnet_b0_g8_gn | ✅ | +----------------------------------+------------------+ +| efficientnet_b0_g16_evos | ✅ | ++----------------------------------+------------------+ | efficientnet_b0_gn | ✅ | +----------------------------------+------------------+ | efficientnet_b1 | ✅ | @@ -233,8 +237,6 @@ Total number of encoders: 549 +----------------------------------+------------------+ | efficientnet_b2_pruned | ✅ | +----------------------------------+------------------+ -| efficientnet_b2a | ✅ | -+----------------------------------+------------------+ | efficientnet_b3 | ✅ | +----------------------------------+------------------+ | efficientnet_b3_g8_gn | ✅ | @@ -243,8 +245,6 @@ Total 
number of encoders: 549 +----------------------------------+------------------+ | efficientnet_b3_pruned | ✅ | +----------------------------------+------------------+ -| efficientnet_b3a | ✅ | -+----------------------------------+------------------+ | efficientnet_b4 | ✅ | +----------------------------------+------------------+ | efficientnet_b5 | ✅ | @@ -255,6 +255,8 @@ Total number of encoders: 549 +----------------------------------+------------------+ | efficientnet_b8 | ✅ | +----------------------------------+------------------+ +| efficientnet_blur_b0 | ✅ | ++----------------------------------+------------------+ | efficientnet_cc_b0_4e | ✅ | +----------------------------------+------------------+ | efficientnet_cc_b0_8e | ✅ | @@ -341,6 +343,12 @@ Total number of encoders: 549 +----------------------------------+------------------+ | ghostnet_130 | | +----------------------------------+------------------+ +| ghostnetv2_100 | | ++----------------------------------+------------------+ +| ghostnetv2_130 | | ++----------------------------------+------------------+ +| ghostnetv2_160 | | ++----------------------------------+------------------+ | halo2botnet50ts_256 | ✅ | +----------------------------------+------------------+ | halonet26t | ✅ | @@ -385,12 +393,6 @@ Total number of encoders: 549 +----------------------------------+------------------+ | hrnet_w64 | | +----------------------------------+------------------+ -| inception_resnet_v2 | | -+----------------------------------+------------------+ -| inception_v3 | | -+----------------------------------+------------------+ -| inception_v4 | | -+----------------------------------+------------------+ | lambda_resnet26rpt_256 | ✅ | +----------------------------------+------------------+ | lambda_resnet26t | ✅ | @@ -411,23 +413,21 @@ Total number of encoders: 549 +----------------------------------+------------------+ | legacy_senet154 | | +----------------------------------+------------------+ -| 
legacy_seresnet101 | | -+----------------------------------+------------------+ -| legacy_seresnet152 | | -+----------------------------------+------------------+ | legacy_seresnet18 | | +----------------------------------+------------------+ | legacy_seresnet34 | | +----------------------------------+------------------+ | legacy_seresnet50 | | +----------------------------------+------------------+ -| legacy_seresnext101_32x4d | | +| legacy_seresnet101 | | ++----------------------------------+------------------+ +| legacy_seresnet152 | | +----------------------------------+------------------+ | legacy_seresnext26_32x4d | | +----------------------------------+------------------+ | legacy_seresnext50_32x4d | | +----------------------------------+------------------+ -| legacy_xception | | +| legacy_seresnext101_32x4d | | +----------------------------------+------------------+ | maxvit_base_tf_224 | | +----------------------------------+------------------+ @@ -515,11 +515,23 @@ Total number of encoders: 549 +----------------------------------+------------------+ | mnasnet_140 | ✅ | +----------------------------------+------------------+ -| mnasnet_a1 | ✅ | +| mnasnet_small | ✅ | ++----------------------------------+------------------+ +| mobilenet_edgetpu_100 | ✅ | +----------------------------------+------------------+ -| mnasnet_b1 | ✅ | +| mobilenet_edgetpu_v2_l | ✅ | +----------------------------------+------------------+ -| mnasnet_small | ✅ | +| mobilenet_edgetpu_v2_m | ✅ | ++----------------------------------+------------------+ +| mobilenet_edgetpu_v2_s | ✅ | ++----------------------------------+------------------+ +| mobilenet_edgetpu_v2_xs | ✅ | ++----------------------------------+------------------+ +| mobilenetv1_100 | ✅ | ++----------------------------------+------------------+ +| mobilenetv1_100h | ✅ | ++----------------------------------+------------------+ +| mobilenetv1_125 | ✅ | +----------------------------------+------------------+ | 
mobilenetv2_035 | ✅ | +----------------------------------+------------------+ @@ -539,6 +551,8 @@ Total number of encoders: 549 +----------------------------------+------------------+ | mobilenetv3_large_100 | ✅ | +----------------------------------+------------------+ +| mobilenetv3_large_150d | ✅ | ++----------------------------------+------------------+ | mobilenetv3_rw | ✅ | +----------------------------------+------------------+ | mobilenetv3_small_050 | ✅ | @@ -547,6 +561,40 @@ Total number of encoders: 549 +----------------------------------+------------------+ | mobilenetv3_small_100 | ✅ | +----------------------------------+------------------+ +| mobilenetv4_conv_aa_large | ✅ | ++----------------------------------+------------------+ +| mobilenetv4_conv_aa_medium | ✅ | ++----------------------------------+------------------+ +| mobilenetv4_conv_blur_medium | ✅ | ++----------------------------------+------------------+ +| mobilenetv4_conv_large | ✅ | ++----------------------------------+------------------+ +| mobilenetv4_conv_medium | ✅ | ++----------------------------------+------------------+ +| mobilenetv4_conv_small | ✅ | ++----------------------------------+------------------+ +| mobilenetv4_conv_small_035 | ✅ | ++----------------------------------+------------------+ +| mobilenetv4_conv_small_050 | ✅ | ++----------------------------------+------------------+ +| mobilenetv4_hybrid_large | ✅ | ++----------------------------------+------------------+ +| mobilenetv4_hybrid_large_075 | ✅ | ++----------------------------------+------------------+ +| mobilenetv4_hybrid_medium | ✅ | ++----------------------------------+------------------+ +| mobilenetv4_hybrid_medium_075 | ✅ | ++----------------------------------+------------------+ +| mobileone_s0 | ✅ | ++----------------------------------+------------------+ +| mobileone_s1 | ✅ | ++----------------------------------+------------------+ +| mobileone_s2 | ✅ | 
++----------------------------------+------------------+ +| mobileone_s3 | ✅ | ++----------------------------------+------------------+ +| mobileone_s4 | ✅ | ++----------------------------------+------------------+ | mobilevit_s | ✅ | +----------------------------------+------------------+ | mobilevit_xs | ✅ | @@ -567,14 +615,12 @@ Total number of encoders: 549 +----------------------------------+------------------+ | mobilevitv2_200 | ✅ | +----------------------------------+------------------+ -| nasnetalarge | | -+----------------------------------+------------------+ -| nf_ecaresnet101 | ✅ | -+----------------------------------+------------------+ | nf_ecaresnet26 | ✅ | +----------------------------------+------------------+ | nf_ecaresnet50 | ✅ | +----------------------------------+------------------+ +| nf_ecaresnet101 | ✅ | ++----------------------------------+------------------+ | nf_regnet_b0 | ✅ | +----------------------------------+------------------+ | nf_regnet_b1 | ✅ | @@ -587,18 +633,18 @@ Total number of encoders: 549 +----------------------------------+------------------+ | nf_regnet_b5 | ✅ | +----------------------------------+------------------+ -| nf_resnet101 | ✅ | -+----------------------------------+------------------+ | nf_resnet26 | ✅ | +----------------------------------+------------------+ | nf_resnet50 | ✅ | +----------------------------------+------------------+ -| nf_seresnet101 | ✅ | +| nf_resnet101 | ✅ | +----------------------------------+------------------+ | nf_seresnet26 | ✅ | +----------------------------------+------------------+ | nf_seresnet50 | ✅ | +----------------------------------+------------------+ +| nf_seresnet101 | ✅ | ++----------------------------------+------------------+ | nfnet_f0 | ✅ | +----------------------------------+------------------+ | nfnet_f1 | ✅ | @@ -617,8 +663,6 @@ Total number of encoders: 549 +----------------------------------+------------------+ | nfnet_l0 | ✅ | 
+----------------------------------+------------------+ -| pnasnet5large | | -+----------------------------------+------------------+ | regnetv_040 | ✅ | +----------------------------------+------------------+ | regnetv_064 | ✅ | @@ -675,16 +719,16 @@ Total number of encoders: 549 +----------------------------------+------------------+ | regnety_120 | ✅ | +----------------------------------+------------------+ -| regnety_1280 | ✅ | -+----------------------------------+------------------+ | regnety_160 | ✅ | +----------------------------------+------------------+ -| regnety_2560 | ✅ | -+----------------------------------+------------------+ | regnety_320 | ✅ | +----------------------------------+------------------+ | regnety_640 | ✅ | +----------------------------------+------------------+ +| regnety_1280 | ✅ | ++----------------------------------+------------------+ +| regnety_2560 | ✅ | ++----------------------------------+------------------+ | regnetz_005 | ✅ | +----------------------------------+------------------+ | regnetz_040 | ✅ | @@ -699,14 +743,34 @@ Total number of encoders: 549 +----------------------------------+------------------+ | regnetz_c16_evos | ✅ | +----------------------------------+------------------+ -| regnetz_d32 | ✅ | -+----------------------------------+------------------+ | regnetz_d8 | ✅ | +----------------------------------+------------------+ | regnetz_d8_evos | ✅ | +----------------------------------+------------------+ +| regnetz_d32 | ✅ | ++----------------------------------+------------------+ | regnetz_e8 | ✅ | +----------------------------------+------------------+ +| repghostnet_050 | | ++----------------------------------+------------------+ +| repghostnet_058 | | ++----------------------------------+------------------+ +| repghostnet_080 | | ++----------------------------------+------------------+ +| repghostnet_100 | | ++----------------------------------+------------------+ +| repghostnet_111 | | 
++----------------------------------+------------------+ +| repghostnet_130 | | ++----------------------------------+------------------+ +| repghostnet_150 | | ++----------------------------------+------------------+ +| repghostnet_200 | | ++----------------------------------+------------------+ +| repvgg_a0 | ✅ | ++----------------------------------+------------------+ +| repvgg_a1 | ✅ | ++----------------------------------+------------------+ | repvgg_a2 | ✅ | +----------------------------------+------------------+ | repvgg_b0 | ✅ | @@ -723,9 +787,7 @@ Total number of encoders: 549 +----------------------------------+------------------+ | repvgg_b3g4 | ✅ | +----------------------------------+------------------+ -| res2net101_26w_4s | ✅ | -+----------------------------------+------------------+ -| res2net101d | ✅ | +| repvgg_d2se | ✅ | +----------------------------------+------------------+ | res2net50_14w_8s | ✅ | +----------------------------------+------------------+ @@ -739,15 +801,13 @@ Total number of encoders: 549 +----------------------------------+------------------+ | res2net50d | ✅ | +----------------------------------+------------------+ -| res2next50 | ✅ | -+----------------------------------+------------------+ -| resnest101e | ✅ | +| res2net101_26w_4s | ✅ | +----------------------------------+------------------+ -| resnest14d | ✅ | +| res2net101d | ✅ | +----------------------------------+------------------+ -| resnest200e | ✅ | +| res2next50 | ✅ | +----------------------------------+------------------+ -| resnest269e | ✅ | +| resnest14d | ✅ | +----------------------------------+------------------+ | resnest26d | ✅ | +----------------------------------+------------------+ @@ -757,34 +817,20 @@ Total number of encoders: 549 +----------------------------------+------------------+ | resnest50d_4s2x40d | ✅ | +----------------------------------+------------------+ -| resnet101 | ✅ | -+----------------------------------+------------------+ -| resnet101c | 
✅ | +| resnest101e | ✅ | +----------------------------------+------------------+ -| resnet101d | ✅ | +| resnest200e | ✅ | +----------------------------------+------------------+ -| resnet101s | ✅ | +| resnest269e | ✅ | +----------------------------------+------------------+ | resnet10t | ✅ | +----------------------------------+------------------+ | resnet14t | ✅ | +----------------------------------+------------------+ -| resnet152 | ✅ | -+----------------------------------+------------------+ -| resnet152c | ✅ | -+----------------------------------+------------------+ -| resnet152d | ✅ | -+----------------------------------+------------------+ -| resnet152s | ✅ | -+----------------------------------+------------------+ | resnet18 | ✅ | +----------------------------------+------------------+ | resnet18d | ✅ | +----------------------------------+------------------+ -| resnet200 | ✅ | -+----------------------------------+------------------+ -| resnet200d | ✅ | -+----------------------------------+------------------+ | resnet26 | ✅ | +----------------------------------+------------------+ | resnet26d | ✅ | @@ -801,8 +847,14 @@ Total number of encoders: 549 +----------------------------------+------------------+ | resnet50 | ✅ | +----------------------------------+------------------+ +| resnet50_clip | ✅ | ++----------------------------------+------------------+ +| resnet50_clip_gap | ✅ | ++----------------------------------+------------------+ | resnet50_gn | ✅ | +----------------------------------+------------------+ +| resnet50_mlp | ✅ | ++----------------------------------+------------------+ | resnet50c | ✅ | +----------------------------------+------------------+ | resnet50d | ✅ | @@ -811,11 +863,45 @@ Total number of encoders: 549 +----------------------------------+------------------+ | resnet50t | ✅ | +----------------------------------+------------------+ +| resnet50x4_clip | ✅ | ++----------------------------------+------------------+ +| resnet50x4_clip_gap 
| ✅ | ++----------------------------------+------------------+ +| resnet50x16_clip | ✅ | ++----------------------------------+------------------+ +| resnet50x16_clip_gap | ✅ | ++----------------------------------+------------------+ +| resnet50x64_clip | ✅ | ++----------------------------------+------------------+ +| resnet50x64_clip_gap | ✅ | ++----------------------------------+------------------+ | resnet51q | ✅ | +----------------------------------+------------------+ | resnet61q | ✅ | +----------------------------------+------------------+ -| resnetaa101d | ✅ | +| resnet101 | ✅ | ++----------------------------------+------------------+ +| resnet101_clip | ✅ | ++----------------------------------+------------------+ +| resnet101_clip_gap | ✅ | ++----------------------------------+------------------+ +| resnet101c | ✅ | ++----------------------------------+------------------+ +| resnet101d | ✅ | ++----------------------------------+------------------+ +| resnet101s | ✅ | ++----------------------------------+------------------+ +| resnet152 | ✅ | ++----------------------------------+------------------+ +| resnet152c | ✅ | ++----------------------------------+------------------+ +| resnet152d | ✅ | ++----------------------------------+------------------+ +| resnet152s | ✅ | ++----------------------------------+------------------+ +| resnet200 | ✅ | ++----------------------------------+------------------+ +| resnet200d | ✅ | +----------------------------------+------------------+ | resnetaa34d | ✅ | +----------------------------------+------------------+ @@ -823,7 +909,7 @@ Total number of encoders: 549 +----------------------------------+------------------+ | resnetaa50d | ✅ | +----------------------------------+------------------+ -| resnetblur101d | ✅ | +| resnetaa101d | ✅ | +----------------------------------+------------------+ | resnetblur18 | ✅ | +----------------------------------+------------------+ @@ -831,6 +917,10 @@ Total number of encoders: 549 
+----------------------------------+------------------+ | resnetblur50d | ✅ | +----------------------------------+------------------+ +| resnetblur101d | ✅ | ++----------------------------------+------------------+ +| resnetrs50 | ✅ | ++----------------------------------+------------------+ | resnetrs101 | ✅ | +----------------------------------+------------------+ | resnetrs152 | ✅ | @@ -843,23 +933,13 @@ Total number of encoders: 549 +----------------------------------+------------------+ | resnetrs420 | ✅ | +----------------------------------+------------------+ -| resnetrs50 | ✅ | +| resnetv2_18 | ✅ | +----------------------------------+------------------+ -| resnetv2_101 | ✅ | -+----------------------------------+------------------+ -| resnetv2_101d | ✅ | -+----------------------------------+------------------+ -| resnetv2_101x1_bit | ✅ | -+----------------------------------+------------------+ -| resnetv2_101x3_bit | ✅ | -+----------------------------------+------------------+ -| resnetv2_152 | ✅ | +| resnetv2_18d | ✅ | +----------------------------------+------------------+ -| resnetv2_152d | ✅ | +| resnetv2_34 | ✅ | +----------------------------------+------------------+ -| resnetv2_152x2_bit | ✅ | -+----------------------------------+------------------+ -| resnetv2_152x4_bit | ✅ | +| resnetv2_34d | ✅ | +----------------------------------+------------------+ | resnetv2_50 | ✅ | +----------------------------------+------------------+ @@ -877,15 +957,21 @@ Total number of encoders: 549 +----------------------------------+------------------+ | resnetv2_50x3_bit | ✅ | +----------------------------------+------------------+ -| resnext101_32x16d | ✅ | +| resnetv2_101 | ✅ | +----------------------------------+------------------+ -| resnext101_32x32d | ✅ | +| resnetv2_101d | ✅ | +----------------------------------+------------------+ -| resnext101_32x4d | ✅ | +| resnetv2_101x1_bit | ✅ | +----------------------------------+------------------+ -| resnext101_32x8d | ✅ 
| +| resnetv2_101x3_bit | ✅ | +----------------------------------+------------------+ -| resnext101_64x4d | ✅ | +| resnetv2_152 | ✅ | ++----------------------------------+------------------+ +| resnetv2_152d | ✅ | ++----------------------------------+------------------+ +| resnetv2_152x2_bit | ✅ | ++----------------------------------+------------------+ +| resnetv2_152x4_bit | ✅ | +----------------------------------+------------------+ | resnext26ts | ✅ | +----------------------------------+------------------+ @@ -893,6 +979,16 @@ Total number of encoders: 549 +----------------------------------+------------------+ | resnext50d_32x4d | ✅ | +----------------------------------+------------------+ +| resnext101_32x4d | ✅ | ++----------------------------------+------------------+ +| resnext101_32x8d | ✅ | ++----------------------------------+------------------+ +| resnext101_32x16d | ✅ | ++----------------------------------+------------------+ +| resnext101_32x32d | ✅ | ++----------------------------------+------------------+ +| resnext101_64x4d | ✅ | ++----------------------------------+------------------+ | rexnet_100 | ✅ | +----------------------------------+------------------+ | rexnet_130 | ✅ | @@ -915,8 +1011,20 @@ Total number of encoders: 549 +----------------------------------+------------------+ | sebotnet33ts_256 | ✅ | +----------------------------------+------------------+ +| sedarknet21 | ✅ | ++----------------------------------+------------------+ | sehalonet33ts | ✅ | +----------------------------------+------------------+ +| selecsls42 | | ++----------------------------------+------------------+ +| selecsls42b | | ++----------------------------------+------------------+ +| selecsls60 | | ++----------------------------------+------------------+ +| selecsls60b | | ++----------------------------------+------------------+ +| selecsls84 | | ++----------------------------------+------------------+ | semnasnet_050 | ✅ | 
+----------------------------------+------------------+ | semnasnet_075 | ✅ | @@ -927,18 +1035,8 @@ Total number of encoders: 549 +----------------------------------+------------------+ | senet154 | ✅ | +----------------------------------+------------------+ -| seresnet101 | ✅ | -+----------------------------------+------------------+ -| seresnet152 | ✅ | -+----------------------------------+------------------+ -| seresnet152d | ✅ | -+----------------------------------+------------------+ | seresnet18 | ✅ | +----------------------------------+------------------+ -| seresnet200d | ✅ | -+----------------------------------+------------------+ -| seresnet269d | ✅ | -+----------------------------------+------------------+ | seresnet33ts | ✅ | +----------------------------------+------------------+ | seresnet34 | ✅ | @@ -947,28 +1045,38 @@ Total number of encoders: 549 +----------------------------------+------------------+ | seresnet50t | ✅ | +----------------------------------+------------------+ -| seresnetaa50d | ✅ | +| seresnet101 | ✅ | +----------------------------------+------------------+ -| seresnext101_32x4d | ✅ | +| seresnet152 | ✅ | +----------------------------------+------------------+ -| seresnext101_32x8d | ✅ | +| seresnet152d | ✅ | +----------------------------------+------------------+ -| seresnext101_64x4d | ✅ | +| seresnet200d | ✅ | +----------------------------------+------------------+ -| seresnext101d_32x8d | ✅ | +| seresnet269d | ✅ | ++----------------------------------+------------------+ +| seresnetaa50d | ✅ | +----------------------------------+------------------+ | seresnext26d_32x4d | ✅ | +----------------------------------+------------------+ | seresnext26t_32x4d | ✅ | +----------------------------------+------------------+ -| seresnext26tn_32x4d | ✅ | -+----------------------------------+------------------+ | seresnext26ts | ✅ | +----------------------------------+------------------+ | seresnext50_32x4d | ✅ | 
+----------------------------------+------------------+ +| seresnext101_32x4d | ✅ | ++----------------------------------+------------------+ +| seresnext101_32x8d | ✅ | ++----------------------------------+------------------+ +| seresnext101_64x4d | ✅ | ++----------------------------------+------------------+ +| seresnext101d_32x8d | ✅ | ++----------------------------------+------------------+ | seresnextaa101d_32x8d | ✅ | +----------------------------------+------------------+ +| seresnextaa201d_32x8d | ✅ | ++----------------------------------+------------------+ | skresnet18 | ✅ | +----------------------------------+------------------+ | skresnet34 | ✅ | @@ -1067,14 +1175,30 @@ Total number of encoders: 549 +----------------------------------+------------------+ | tinynet_e | ✅ | +----------------------------------+------------------+ +| vgg11 | | ++----------------------------------+------------------+ +| vgg11_bn | | ++----------------------------------+------------------+ +| vgg13 | | ++----------------------------------+------------------+ +| vgg13_bn | | ++----------------------------------+------------------+ +| vgg16 | | ++----------------------------------+------------------+ +| vgg16_bn | | ++----------------------------------+------------------+ +| vgg19 | | ++----------------------------------+------------------+ +| vgg19_bn | | ++----------------------------------+------------------+ | vovnet39a | | +----------------------------------+------------------+ | vovnet57a | | +----------------------------------+------------------+ -| wide_resnet101_2 | ✅ | -+----------------------------------+------------------+ | wide_resnet50_2 | ✅ | +----------------------------------+------------------+ +| wide_resnet101_2 | ✅ | ++----------------------------------+------------------+ | xception41 | ✅ | +----------------------------------+------------------+ | xception41p | ✅ | @@ -1086,3 +1210,450 @@ Total number of encoders: 549 | xception71 | ✅ | 
+----------------------------------+------------------+ +Transformer-Style +~~~~~~~~~~~~~~~~~ + +Transformer-style models (e.g., Swin Transformer, ConvNeXt) typically produce feature maps starting at a 1/4 scale, followed by 1/8, 1/16, and 1/32 scales + ++------------------------------------+------------------+ +| Encoder name | Support dilation | ++====================================+==================+ +| caformer_b36 | | ++------------------------------------+------------------+ +| caformer_m36 | | ++------------------------------------+------------------+ +| caformer_s18 | | ++------------------------------------+------------------+ +| caformer_s36 | | ++------------------------------------+------------------+ +| convformer_b36 | | ++------------------------------------+------------------+ +| convformer_m36 | | ++------------------------------------+------------------+ +| convformer_s18 | | ++------------------------------------+------------------+ +| convformer_s36 | | ++------------------------------------+------------------+ +| convnext_atto | ✅ | ++------------------------------------+------------------+ +| convnext_atto_ols | ✅ | ++------------------------------------+------------------+ +| convnext_atto_rms | ✅ | ++------------------------------------+------------------+ +| convnext_base | ✅ | ++------------------------------------+------------------+ +| convnext_femto | ✅ | ++------------------------------------+------------------+ +| convnext_femto_ols | ✅ | ++------------------------------------+------------------+ +| convnext_large | ✅ | ++------------------------------------+------------------+ +| convnext_large_mlp | ✅ | ++------------------------------------+------------------+ +| convnext_nano | ✅ | ++------------------------------------+------------------+ +| convnext_nano_ols | ✅ | ++------------------------------------+------------------+ +| convnext_pico | ✅ | ++------------------------------------+------------------+ +| convnext_pico_ols | ✅ 
| ++------------------------------------+------------------+ +| convnext_small | ✅ | ++------------------------------------+------------------+ +| convnext_tiny | ✅ | ++------------------------------------+------------------+ +| convnext_tiny_hnf | ✅ | ++------------------------------------+------------------+ +| convnext_xlarge | ✅ | ++------------------------------------+------------------+ +| convnext_xxlarge | ✅ | ++------------------------------------+------------------+ +| convnext_zepto_rms | ✅ | ++------------------------------------+------------------+ +| convnext_zepto_rms_ols | ✅ | ++------------------------------------+------------------+ +| convnextv2_atto | ✅ | ++------------------------------------+------------------+ +| convnextv2_base | ✅ | ++------------------------------------+------------------+ +| convnextv2_femto | ✅ | ++------------------------------------+------------------+ +| convnextv2_huge | ✅ | ++------------------------------------+------------------+ +| convnextv2_large | ✅ | ++------------------------------------+------------------+ +| convnextv2_nano | ✅ | ++------------------------------------+------------------+ +| convnextv2_pico | ✅ | ++------------------------------------+------------------+ +| convnextv2_small | ✅ | ++------------------------------------+------------------+ +| convnextv2_tiny | ✅ | ++------------------------------------+------------------+ +| davit_base | | ++------------------------------------+------------------+ +| davit_base_fl | | ++------------------------------------+------------------+ +| davit_giant | | ++------------------------------------+------------------+ +| davit_huge | | ++------------------------------------+------------------+ +| davit_huge_fl | | ++------------------------------------+------------------+ +| davit_large | | ++------------------------------------+------------------+ +| davit_small | | ++------------------------------------+------------------+ +| davit_tiny | | 
++------------------------------------+------------------+ +| edgenext_base | | ++------------------------------------+------------------+ +| edgenext_small | | ++------------------------------------+------------------+ +| edgenext_small_rw | | ++------------------------------------+------------------+ +| edgenext_x_small | | ++------------------------------------+------------------+ +| edgenext_xx_small | | ++------------------------------------+------------------+ +| efficientformer_l1 | | ++------------------------------------+------------------+ +| efficientformer_l3 | | ++------------------------------------+------------------+ +| efficientformer_l7 | | ++------------------------------------+------------------+ +| efficientformerv2_l | | ++------------------------------------+------------------+ +| efficientformerv2_s0 | | ++------------------------------------+------------------+ +| efficientformerv2_s1 | | ++------------------------------------+------------------+ +| efficientformerv2_s2 | | ++------------------------------------+------------------+ +| efficientvit_b0 | | ++------------------------------------+------------------+ +| efficientvit_b1 | | ++------------------------------------+------------------+ +| efficientvit_b2 | | ++------------------------------------+------------------+ +| efficientvit_b3 | | ++------------------------------------+------------------+ +| efficientvit_l1 | | ++------------------------------------+------------------+ +| efficientvit_l2 | | ++------------------------------------+------------------+ +| efficientvit_l3 | | ++------------------------------------+------------------+ +| fastvit_ma36 | | ++------------------------------------+------------------+ +| fastvit_mci0 | | ++------------------------------------+------------------+ +| fastvit_mci1 | | ++------------------------------------+------------------+ +| fastvit_mci2 | | ++------------------------------------+------------------+ +| fastvit_s12 | | 
++------------------------------------+------------------+ +| fastvit_sa12 | | ++------------------------------------+------------------+ +| fastvit_sa24 | | ++------------------------------------+------------------+ +| fastvit_sa36 | | ++------------------------------------+------------------+ +| fastvit_t8 | | ++------------------------------------+------------------+ +| fastvit_t12 | | ++------------------------------------+------------------+ +| focalnet_base_lrf | | ++------------------------------------+------------------+ +| focalnet_base_srf | | ++------------------------------------+------------------+ +| focalnet_huge_fl3 | | ++------------------------------------+------------------+ +| focalnet_huge_fl4 | | ++------------------------------------+------------------+ +| focalnet_large_fl3 | | ++------------------------------------+------------------+ +| focalnet_large_fl4 | | ++------------------------------------+------------------+ +| focalnet_small_lrf | | ++------------------------------------+------------------+ +| focalnet_small_srf | | ++------------------------------------+------------------+ +| focalnet_tiny_lrf | | ++------------------------------------+------------------+ +| focalnet_tiny_srf | | ++------------------------------------+------------------+ +| focalnet_xlarge_fl3 | | ++------------------------------------+------------------+ +| focalnet_xlarge_fl4 | | ++------------------------------------+------------------+ +| hgnet_base | | ++------------------------------------+------------------+ +| hgnet_small | | ++------------------------------------+------------------+ +| hgnet_tiny | | ++------------------------------------+------------------+ +| hgnetv2_b0 | | ++------------------------------------+------------------+ +| hgnetv2_b1 | | ++------------------------------------+------------------+ +| hgnetv2_b2 | | ++------------------------------------+------------------+ +| hgnetv2_b3 | | 
++------------------------------------+------------------+ +| hgnetv2_b4 | | ++------------------------------------+------------------+ +| hgnetv2_b5 | | ++------------------------------------+------------------+ +| hgnetv2_b6 | | ++------------------------------------+------------------+ +| hiera_base_224 | | ++------------------------------------+------------------+ +| hiera_base_abswin_256 | | ++------------------------------------+------------------+ +| hiera_base_plus_224 | | ++------------------------------------+------------------+ +| hiera_huge_224 | | ++------------------------------------+------------------+ +| hiera_large_224 | | ++------------------------------------+------------------+ +| hiera_small_224 | | ++------------------------------------+------------------+ +| hiera_small_abswin_256 | | ++------------------------------------+------------------+ +| hiera_tiny_224 | | ++------------------------------------+------------------+ +| hieradet_small | | ++------------------------------------+------------------+ +| inception_next_base | | ++------------------------------------+------------------+ +| inception_next_small | | ++------------------------------------+------------------+ +| inception_next_tiny | | ++------------------------------------+------------------+ +| mambaout_base | | ++------------------------------------+------------------+ +| mambaout_base_plus_rw | | ++------------------------------------+------------------+ +| mambaout_base_short_rw | | ++------------------------------------+------------------+ +| mambaout_base_tall_rw | | ++------------------------------------+------------------+ +| mambaout_base_wide_rw | | ++------------------------------------+------------------+ +| mambaout_femto | | ++------------------------------------+------------------+ +| mambaout_kobe | | ++------------------------------------+------------------+ +| mambaout_small | | ++------------------------------------+------------------+ +| mambaout_small_rw | | 
++------------------------------------+------------------+ +| mambaout_tiny | | ++------------------------------------+------------------+ +| mvitv2_base | | ++------------------------------------+------------------+ +| mvitv2_base_cls | | ++------------------------------------+------------------+ +| mvitv2_huge_cls | | ++------------------------------------+------------------+ +| mvitv2_large | | ++------------------------------------+------------------+ +| mvitv2_large_cls | | ++------------------------------------+------------------+ +| mvitv2_small | | ++------------------------------------+------------------+ +| mvitv2_small_cls | | ++------------------------------------+------------------+ +| mvitv2_tiny | | ++------------------------------------+------------------+ +| nest_base | | ++------------------------------------+------------------+ +| nest_base_jx | | ++------------------------------------+------------------+ +| nest_small | | ++------------------------------------+------------------+ +| nest_small_jx | | ++------------------------------------+------------------+ +| nest_tiny | | ++------------------------------------+------------------+ +| nest_tiny_jx | | ++------------------------------------+------------------+ +| nextvit_base | | ++------------------------------------+------------------+ +| nextvit_large | | ++------------------------------------+------------------+ +| nextvit_small | | ++------------------------------------+------------------+ +| poolformer_m36 | | ++------------------------------------+------------------+ +| poolformer_m48 | | ++------------------------------------+------------------+ +| poolformer_s12 | | ++------------------------------------+------------------+ +| poolformer_s24 | | ++------------------------------------+------------------+ +| poolformer_s36 | | ++------------------------------------+------------------+ +| poolformerv2_m36 | | ++------------------------------------+------------------+ +| poolformerv2_m48 | 
| ++------------------------------------+------------------+ +| poolformerv2_s12 | | ++------------------------------------+------------------+ +| poolformerv2_s24 | | ++------------------------------------+------------------+ +| poolformerv2_s36 | | ++------------------------------------+------------------+ +| pvt_v2_b0 | | ++------------------------------------+------------------+ +| pvt_v2_b1 | | ++------------------------------------+------------------+ +| pvt_v2_b2 | | ++------------------------------------+------------------+ +| pvt_v2_b2_li | | ++------------------------------------+------------------+ +| pvt_v2_b3 | | ++------------------------------------+------------------+ +| pvt_v2_b4 | | ++------------------------------------+------------------+ +| pvt_v2_b5 | | ++------------------------------------+------------------+ +| rdnet_base | | ++------------------------------------+------------------+ +| rdnet_large | | ++------------------------------------+------------------+ +| rdnet_small | | ++------------------------------------+------------------+ +| rdnet_tiny | | ++------------------------------------+------------------+ +| repvit_m0_9 | | ++------------------------------------+------------------+ +| repvit_m1 | | ++------------------------------------+------------------+ +| repvit_m1_0 | | ++------------------------------------+------------------+ +| repvit_m1_1 | | ++------------------------------------+------------------+ +| repvit_m1_5 | | ++------------------------------------+------------------+ +| repvit_m2 | | ++------------------------------------+------------------+ +| repvit_m2_3 | | ++------------------------------------+------------------+ +| repvit_m3 | | ++------------------------------------+------------------+ +| sam2_hiera_base_plus | | ++------------------------------------+------------------+ +| sam2_hiera_large | | ++------------------------------------+------------------+ +| sam2_hiera_small | | 
++------------------------------------+------------------+ +| sam2_hiera_tiny | | ++------------------------------------+------------------+ +| swin_base_patch4_window7_224 | | ++------------------------------------+------------------+ +| swin_base_patch4_window12_384 | | ++------------------------------------+------------------+ +| swin_large_patch4_window7_224 | | ++------------------------------------+------------------+ +| swin_large_patch4_window12_384 | | ++------------------------------------+------------------+ +| swin_s3_base_224 | | ++------------------------------------+------------------+ +| swin_s3_small_224 | | ++------------------------------------+------------------+ +| swin_s3_tiny_224 | | ++------------------------------------+------------------+ +| swin_small_patch4_window7_224 | | ++------------------------------------+------------------+ +| swin_tiny_patch4_window7_224 | | ++------------------------------------+------------------+ +| swinv2_base_window8_256 | | ++------------------------------------+------------------+ +| swinv2_base_window12_192 | | ++------------------------------------+------------------+ +| swinv2_base_window12to16_192to256 | | ++------------------------------------+------------------+ +| swinv2_base_window12to24_192to384 | | ++------------------------------------+------------------+ +| swinv2_base_window16_256 | | ++------------------------------------+------------------+ +| swinv2_cr_base_224 | | ++------------------------------------+------------------+ +| swinv2_cr_base_384 | | ++------------------------------------+------------------+ +| swinv2_cr_base_ns_224 | | ++------------------------------------+------------------+ +| swinv2_cr_giant_224 | | ++------------------------------------+------------------+ +| swinv2_cr_giant_384 | | ++------------------------------------+------------------+ +| swinv2_cr_huge_224 | | ++------------------------------------+------------------+ +| swinv2_cr_huge_384 | | 
++------------------------------------+------------------+ +| swinv2_cr_large_224 | | ++------------------------------------+------------------+ +| swinv2_cr_large_384 | | ++------------------------------------+------------------+ +| swinv2_cr_small_224 | | ++------------------------------------+------------------+ +| swinv2_cr_small_384 | | ++------------------------------------+------------------+ +| swinv2_cr_small_ns_224 | | ++------------------------------------+------------------+ +| swinv2_cr_small_ns_256 | | ++------------------------------------+------------------+ +| swinv2_cr_tiny_224 | | ++------------------------------------+------------------+ +| swinv2_cr_tiny_384 | | ++------------------------------------+------------------+ +| swinv2_cr_tiny_ns_224 | | ++------------------------------------+------------------+ +| swinv2_large_window12_192 | | ++------------------------------------+------------------+ +| swinv2_large_window12to16_192to256 | | ++------------------------------------+------------------+ +| swinv2_large_window12to24_192to384 | | ++------------------------------------+------------------+ +| swinv2_small_window8_256 | | ++------------------------------------+------------------+ +| swinv2_small_window16_256 | | ++------------------------------------+------------------+ +| swinv2_tiny_window8_256 | | ++------------------------------------+------------------+ +| swinv2_tiny_window16_256 | | ++------------------------------------+------------------+ +| tiny_vit_5m_224 | | ++------------------------------------+------------------+ +| tiny_vit_11m_224 | | ++------------------------------------+------------------+ +| tiny_vit_21m_224 | | ++------------------------------------+------------------+ +| tiny_vit_21m_384 | | ++------------------------------------+------------------+ +| tiny_vit_21m_512 | | ++------------------------------------+------------------+ +| tresnet_l | | ++------------------------------------+------------------+ +| tresnet_m 
| | ++------------------------------------+------------------+ +| tresnet_v2_l | | ++------------------------------------+------------------+ +| tresnet_xl | | ++------------------------------------+------------------+ +| twins_pcpvt_base | | ++------------------------------------+------------------+ +| twins_pcpvt_large | | ++------------------------------------+------------------+ +| twins_pcpvt_small | | ++------------------------------------+------------------+ +| twins_svt_base | | ++------------------------------------+------------------+ +| twins_svt_large | | ++------------------------------------+------------------+ +| twins_svt_small | | ++------------------------------------+------------------+ + diff --git a/segmentation_models_pytorch/encoders/timm_universal.py b/segmentation_models_pytorch/encoders/timm_universal.py index eb008221..9bdcb188 100644 --- a/segmentation_models_pytorch/encoders/timm_universal.py +++ b/segmentation_models_pytorch/encoders/timm_universal.py @@ -1,10 +1,49 @@ +""" +TimmUniversalEncoder provides a unified feature extraction interface built on the +`timm` library, supporting both traditional-style (e.g., ResNet) and transformer-style +models (e.g., Swin Transformer, ConvNeXt). + +This encoder produces consistent multi-level feature maps for semantic segmentation tasks. +It allows configuring the number of feature extraction stages (`depth`) and adjusting +`output_stride` when supported. + +Key Features: +- Flexible model selection using `timm.create_model`. +- Unified multi-level output across different model hierarchies. +- Automatic alignment for inconsistent feature scales: + - Transformer-style models (start at 1/4 scale): Insert dummy features for 1/2 scale. + - VGG-style models (include scale-1 features): Align outputs for compatibility. +- Easy access to feature scale information via the `reduction` property. + +Feature Scale Differences: +- Traditional-style models (e.g., ResNet): Scales at 1/2, 1/4, 1/8, 1/16, 1/32. 
+- Transformer-style models (e.g., Swin Transformer): Start at 1/4 scale, skip 1/2 scale. +- VGG-style models: Include scale-1 features (input resolution). + +Notes: +- `output_stride` is unsupported in some models, especially transformer-based architectures. +- Special handling for models like TResNet and DLA to ensure correct feature indexing. +- VGG-style models use `_is_vgg_style` to align scale-1 features with standard outputs. +""" + from typing import Any import timm +import torch import torch.nn as nn class TimmUniversalEncoder(nn.Module): + """ + A universal encoder leveraging the `timm` library for feature extraction from + various model architectures, including traditional-style and transformer-style models. + + Features: + - Supports configurable depth and output stride. + - Ensures consistent multi-level feature extraction across diverse models. + - Compatible with convolutional and transformer-like backbones. + """ + def __init__( self, name: str, @@ -14,7 +53,20 @@ def __init__( output_stride: int = 32, **kwargs: dict[str, Any], ): + """ + Initialize the encoder. + + Args: + name (str): Model name to load from `timm`. + pretrained (bool): Load pretrained weights (default: True). + in_channels (int): Number of input channels (default: 3 for RGB). + depth (int): Number of feature stages to extract (default: 5). + output_stride (int): Desired output stride (default: 32). + **kwargs: Additional arguments passed to `timm.create_model`. 
+ """ super().__init__() + + # Default model configuration for feature extraction common_kwargs = dict( in_chans=in_channels, features_only=True, @@ -23,34 +75,138 @@ def __init__( out_indices=tuple(range(depth)), ) - # not all models support output stride argument, drop it by default + # Not all models support output stride argument, drop it by default if output_stride == 32: common_kwargs.pop("output_stride") - self.model = timm.create_model( - name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs) - ) + # Load a temporary model to analyze its feature hierarchy + try: + with torch.device("meta"): + tmp_model = timm.create_model(name, features_only=True) + except Exception: + tmp_model = timm.create_model(name, features_only=True) + + # Check if model output is in channel-last format (NHWC) + self._is_channel_last = getattr(tmp_model, "output_fmt", None) == "NHWC" + + # Determine the model's downsampling pattern and set hierarchy flags + encoder_stage = len(tmp_model.feature_info.reduction()) + reduction_scales = list(tmp_model.feature_info.reduction()) + + if reduction_scales == [2 ** (i + 2) for i in range(encoder_stage)]: + # Transformer-style downsampling: scales (4, 8, 16, 32) + self._is_transformer_style = True + self._is_vgg_style = False + elif reduction_scales == [2 ** (i + 1) for i in range(encoder_stage)]: + # Traditional-style downsampling: scales (2, 4, 8, 16, 32) + self._is_transformer_style = False + self._is_vgg_style = False + elif reduction_scales == [2**i for i in range(encoder_stage)]: + # Vgg-style models including scale 1: scales (1, 2, 4, 8, 16, 32) + self._is_transformer_style = False + self._is_vgg_style = True + else: + raise ValueError("Unsupported model downsampling pattern.") + + if self._is_transformer_style: + # Transformer-like models (start at scale 4) + if "tresnet" in name: + # 'tresnet' models start feature extraction at stage 1, + # so out_indices=(1, 2, 3, 4) for depth=5. 
+ common_kwargs["out_indices"] = tuple(range(1, depth)) + else: + # Most transformer-like models use out_indices=(0, 1, 2, 3) for depth=5. + common_kwargs["out_indices"] = tuple(range(depth - 1)) + + self.model = timm.create_model( + name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs) + ) + # Add a dummy output channel (0) to align with traditional encoder structures. + self._out_channels = ( + [in_channels] + [0] + self.model.feature_info.channels() + ) + else: + if "dla" in name: + # For 'dla' models, feature index 0 is at the input resolution, so out_indices start at 1. + common_kwargs["out_indices"] = tuple(range(1, depth + 1)) + if self._is_vgg_style: + common_kwargs["out_indices"] = tuple(range(depth + 1)) + + self.model = timm.create_model( + name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs) + ) + + if self._is_vgg_style: + self._out_channels = self.model.feature_info.channels() + else: + self._out_channels = [in_channels] + self.model.feature_info.channels() self._in_channels = in_channels - self._out_channels = [in_channels] + self.model.feature_info.channels() self._depth = depth self._output_stride = output_stride - def forward(self, x): + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + """ + Forward pass to extract multi-stage features. + + Args: + x (torch.Tensor): Input tensor of shape (B, C, H, W). + + Returns: + list[torch.Tensor]: List of feature maps at different scales.
+ """ features = self.model(x) - features = [x] + features + + # Convert NHWC to NCHW if needed + if self._is_channel_last: + features = [ + feature.permute(0, 3, 1, 2).contiguous() for feature in features + ] + + # Add dummy feature for scale 1/2 if missing (transformer-style models) + if self._is_transformer_style: + B, _, H, W = x.shape + dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device) + features = [dummy] + features + + # Add input tensor as scale 1 feature if `self._is_vgg_style` is False + if not self._is_vgg_style: + features = [x] + features + return features @property - def out_channels(self): + def out_channels(self) -> list[int]: + """ + Returns the number of output channels for each feature stage. + + Returns: + list[int]: A list of channel dimensions at each scale. + """ return self._out_channels @property - def output_stride(self): + def output_stride(self) -> int: + """ + Returns the effective output stride based on the model depth. + + Returns: + int: The effective output stride. + """ return min(self._output_stride, 2**self._depth) def _merge_kwargs_no_duplicates(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]: + """ + Merge two dictionaries, ensuring no duplicate keys exist. + + Args: + a (dict): Base dictionary. + b (dict): Additional parameters to merge. + + Returns: + dict: A merged dictionary. 
+ """ duplicates = a.keys() & b.keys() if duplicates: raise ValueError(f"'{duplicates}' already specified internally") diff --git a/tests/test_models.py b/tests/test_models.py index 68d12c43..460dcdf2 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -13,7 +13,9 @@ def get_encoders(): ] encoders = smp.encoders.get_encoder_names() encoders = [e for e in encoders if e not in exclude_encoders] - encoders.append("tu-resnet34") # for timm universal encoder + encoders.append("tu-resnet34") # for timm universal traditional-like encoder + encoders.append("tu-convnext_atto") # for timm universal transformer-like encoder + encoders.append("tu-darknet17") # for timm universal vgg-like encoder return encoders @@ -78,16 +80,12 @@ def test_forward(model_class, encoder_name, encoder_depth, **kwargs): or model_class is smp.MAnet ): kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:] - if model_class in [smp.UnetPlusPlus, smp.Linknet] and encoder_name.startswith( - "mit_b" - ): - return # skip mit_b* - if ( - model_class is smp.FPN - and encoder_name.startswith("mit_b") - and encoder_depth != 5 - ): - return # skip mit_b* + if model_class in [smp.UnetPlusPlus, smp.Linknet]: + if encoder_name.startswith("mit_b") or encoder_name.startswith("tu-convnext"): + return # skip transformer-like models (mit_b*, tu-convnext*) + if model_class is smp.FPN and encoder_depth != 5: + if encoder_name.startswith("mit_b") or encoder_name.startswith("tu-convnext"): + return # skip transformer-like models (mit_b*, tu-convnext*) model = model_class( encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs ) @@ -178,7 +176,6 @@ def test_dilation(encoder_name): or encoder_name.startswith("vgg") or encoder_name.startswith("densenet") or encoder_name.startswith("timm-res") - or encoder_name.startswith("mit_b") ): return