Add Beit segmentation model #1024

Open
qubvel opened this issue Jan 7, 2025 · 8 comments
Labels
good first issue Good for newcomers help wanted Extra attention is needed new-model

Comments

@qubvel
Collaborator

qubvel commented Jan 7, 2025

Add BEiT to SMP

BEiT-3 is a general-purpose multimodal foundation model developed by Microsoft that excels in various vision and vision-language tasks, including semantic segmentation. It employs a unified architecture with Multiway Transformers, enabling both deep fusion and modality-specific encoding. Pretrained using a masked "language" modeling approach on images ("Imglish"), texts, and image-text pairs, BEiT-3 effectively models images as another language. This design allows it to achieve state-of-the-art performance across a wide range of tasks, such as object detection, image classification, and semantic segmentation.

  • Achieves the top-ranked result on ADE20K val

Papers with Code:
https://paperswithcode.com/paper/image-as-a-foreign-language-beit-pretraining

Paper:
https://arxiv.org/abs/2208.10442

HF reference implementation:
https://huggingface.co/docs/transformers/model_doc/beit
https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/beit/modeling_beit.py

Comments

As an example, please see the latest model additions:

@qubvel qubvel added help wanted Extra attention is needed good first issue Good for newcomers new-model labels Jan 7, 2025
@qubvel
Collaborator Author

qubvel commented Jan 7, 2025

cc @brianhou0208 in case you are interested! It was a pleasure working with you, very clean PRs 🤗

@brianhou0208
Contributor

brianhou0208 commented Jan 8, 2025

Hi @qubvel, I’m not very familiar with multimodal models yet, but after reviewing the BEiT-3 paper and code, here are my thoughts:

  1. Although BEiT-3 supports multiple input modalities and tasks, there is currently no available implementation in either timm or HF transformers. We have to rely on the official source code, which unfortunately does not provide an out-of-the-box example for semantic segmentation.
  2. On page 8 of the paper, the ADE20K segmentation setup requires not only BEiT-3 but also ViT-Adapter and Mask2Former:
     "We directly follow the task transfer settings of ViT-Adapter [CDW+22]. We use a dense prediction task adapter and employ Mask2Former [CMS+21] as the segmentation framework."
  3. Based on my review of the official repos for BEiT and BEiTv2, semantic segmentation can be achieved with these three main components:
  • Encoder: BEiT/BEiTv2 (it's already in timm)
  • Skip-Connection: FPN block
  • Decoder: UPerNet Head
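The three-part layout above (encoder, FPN neck, decoder head) can be sketched end to end in plain PyTorch. This is a toy under stated assumptions: `ToyPlainEncoder`, `SimpleNeck`, `SimpleDecoder` and `Segmenter` are hypothetical names, the encoder is a small convolutional stand-in that keeps every output at stride 16 (as a plain ViT/BEiT does), and the decoder is a simplified FPN-style fusion, not a real UPerNet head (it has no pyramid pooling module).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyPlainEncoder(nn.Module):
    """Stand-in for a plain ViT/BEiT: patch embedding, then blocks that keep stride 16."""

    def __init__(self, dim=64, depth=4):
        super().__init__()
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=16, stride=16)
        self.blocks = nn.ModuleList([nn.Conv2d(dim, dim, 3, padding=1) for _ in range(depth)])

    def forward(self, x):
        x = self.patch_embed(x)
        features = []
        for block in self.blocks:
            x = x + F.gelu(block(x))  # residual block, resolution unchanged
            features.append(x)
        return features  # every map is at 1/16 of the input


class SimpleNeck(nn.Module):
    """FPN-style neck: rescale four stride-16 maps to strides 4, 8, 16 and 32."""

    def __init__(self, dim):
        super().__init__()
        self.ops = nn.ModuleList([
            nn.Sequential(nn.ConvTranspose2d(dim, dim, 2, stride=2), nn.GELU(),
                          nn.ConvTranspose2d(dim, dim, 2, stride=2)),  # x4
            nn.ConvTranspose2d(dim, dim, 2, stride=2),                   # x2
            nn.Identity(),                                               # x1
            nn.MaxPool2d(2, stride=2),                                   # x0.5
        ])

    def forward(self, features):
        return [op(f) for op, f in zip(self.ops, features)]


class SimpleDecoder(nn.Module):
    """Simplified stand-in for a UPerNet head: 1x1 projections, upsample to 1/4, sum, classify."""

    def __init__(self, dim, num_classes):
        super().__init__()
        self.lateral = nn.ModuleList([nn.Conv2d(dim, dim, 1) for _ in range(4)])
        self.classifier = nn.Conv2d(dim, num_classes, 1)

    def forward(self, features):
        size = features[0].shape[-2:]  # the 1/4 level is the largest map
        fused = sum(
            F.interpolate(conv(f), size=size, mode="bilinear", align_corners=False)
            for conv, f in zip(self.lateral, features)
        )
        return self.classifier(fused)


class Segmenter(nn.Module):
    def __init__(self, num_classes=5, dim=64):
        super().__init__()
        self.encoder = ToyPlainEncoder(dim)
        self.neck = SimpleNeck(dim)
        self.decoder = SimpleDecoder(dim, num_classes)

    def forward(self, x):
        logits = self.decoder(self.neck(self.encoder(x)))
        # Upsample logits from 1/4 back to the input resolution
        return F.interpolate(logits, size=x.shape[-2:], mode="bilinear", align_corners=False)


model = Segmenter().eval()
with torch.no_grad():
    out = model(torch.rand(1, 3, 224, 224))
print(tuple(out.shape))  # (1, 5, 224, 224)
```

Swapping the toy encoder for a timm ViT/BEiT and the toy decoder for a UPerNet head would give the layout described in the list; the neck is the only part that depends on the encoder having a single stride.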

If we adopt the approach above in a more flexible manner, it could support a broader range of ViT-based models for segmentation tasks.
For instance, we could implement an FPNWrapper that outputs multi-scale features (from 1/4 to 1/32) and supports models that do not use downsampling.
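For an encoder whose features all share one stride, the rescaling each pyramid level needs follows directly from that stride. A minimal pure-Python sketch; `is_hierarchical` and `pyramid_scale_factors` are hypothetical helper names, not part of timm or smp:

```python
def is_hierarchical(reductions):
    """True if each feature map is strictly more downsampled than the previous one."""
    return all(b > a for a, b in zip(reductions, reductions[1:]))


def pyramid_scale_factors(encoder_stride, target_strides=(4, 8, 16, 32)):
    """Resize factor that moves a feature map from encoder_stride to each target stride.

    A factor > 1 means upsampling (e.g. ConvTranspose2d), < 1 means downsampling (e.g. MaxPool2d).
    """
    return [encoder_stride / t for t in target_strides]


print(is_hierarchical([16, 16, 16]))    # False: plain ViT/BEiT-style encoder
print(is_hierarchical([4, 8, 16, 32]))  # True: ResNet-style encoder
print(pyramid_scale_factors(16))        # [4.0, 2.0, 1.0, 0.5]
print(pyramid_scale_factors(14))        # [3.5, 1.75, 0.875, 0.4375]
```

For patch-16 models the factors are powers of two (4, 2, 1, 0.5), which stacked 2x transposed convolutions and a 2x max pool can provide. Patch-14 models give fractional factors, so 2x operations alone cannot reproduce exact 1/4 to 1/32 strides for them.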

Encoders without downsampled features: 264 (PR #1004)
beit_base_patch16_224 False
[16, 16, 16]
beit_base_patch16_384 False
[16, 16, 16]
beit_large_patch16_224 False
[16, 16, 16]
beit_large_patch16_384 False
[16, 16, 16]
beit_large_patch16_512 False
[16, 16, 16]
beitv2_base_patch16_224 False
[16, 16, 16]
beitv2_large_patch16_224 False
[16, 16, 16]
cait_m36_384 False
[16, 16, 16]
cait_m48_448 False
[16, 16, 16]
cait_s24_224 False
[16, 16, 16]
cait_s24_384 False
[16, 16, 16]
cait_s36_384 False
[16, 16, 16]
cait_xs24_384 False
[16, 16, 16]
cait_xxs24_224 False
[16, 16, 16]
cait_xxs24_384 False
[16, 16, 16]
cait_xxs36_224 False
[16, 16, 16]
cait_xxs36_384 False
[16, 16, 16]
deit3_base_patch16_224 False
[16, 16, 16]
deit3_base_patch16_384 False
[16, 16, 16]
deit3_huge_patch14_224 False
[14, 14, 14]
deit3_large_patch16_224 False
[16, 16, 16]
deit3_large_patch16_384 False
[16, 16, 16]
deit3_medium_patch16_224 False
[16, 16, 16]
deit3_small_patch16_224 False
[16, 16, 16]
deit3_small_patch16_384 False
[16, 16, 16]
deit_base_distilled_patch16_224 False
[16, 16, 16]
deit_base_distilled_patch16_384 False
[16, 16, 16]
deit_base_patch16_224 False
[16, 16, 16]
deit_base_patch16_384 False
[16, 16, 16]
deit_small_distilled_patch16_224 False
[16, 16, 16]
deit_small_patch16_224 False
[16, 16, 16]
deit_tiny_distilled_patch16_224 False
[16, 16, 16]
deit_tiny_patch16_224 False
[16, 16, 16]
eva02_base_patch14_224 False
[14, 14, 14]
eva02_base_patch14_448 False
[14, 14, 14]
eva02_base_patch16_clip_224 False
[16, 16, 16]
eva02_enormous_patch14_clip_224 False
[14, 14, 14]
eva02_large_patch14_224 False
[14, 14, 14]
eva02_large_patch14_448 False
[14, 14, 14]
eva02_large_patch14_clip_224 False
[14, 14, 14]
eva02_large_patch14_clip_336 False
[14, 14, 14]
eva02_small_patch14_224 False
[14, 14, 14]
eva02_small_patch14_336 False
[14, 14, 14]
eva02_tiny_patch14_224 False
[14, 14, 14]
eva02_tiny_patch14_336 False
[14, 14, 14]
eva_giant_patch14_224 False
[14, 14, 14]
eva_giant_patch14_336 False
[14, 14, 14]
eva_giant_patch14_560 False
[14, 14, 14]
eva_giant_patch14_clip_224 False
[14, 14, 14]
eva_large_patch14_196 False
[14, 14, 14]
eva_large_patch14_336 False
[14, 14, 14]
flexivit_base False
[16, 16, 16]
flexivit_large False
[16, 16, 16]
flexivit_small False
[16, 16, 16]
gmixer_12_224 False
[16, 16, 16]
gmixer_24_224 False
[16, 16, 16]
gmlp_b16_224 False
[16, 16, 16]
gmlp_s16_224 False
[16, 16, 16]
gmlp_ti16_224 False
[16, 16, 16]
mixer_b16_224 False
[16, 16, 16]
mixer_b32_224 False
[32, 32, 32]
mixer_l16_224 False
[16, 16, 16]
mixer_l32_224 False
[32, 32, 32]
mixer_s16_224 False
[16, 16, 16]
mixer_s32_224 False
[32, 32, 32]
resmlp_12_224 False
[16, 16, 16]
resmlp_24_224 False
[16, 16, 16]
resmlp_36_224 False
[16, 16, 16]
resmlp_big_24_224 False
[8, 8, 8]
samvit_base_patch16 False
[16, 16, 16]
samvit_base_patch16_224 False
[16, 16, 16]
samvit_huge_patch16 False
[16, 16, 16]
samvit_large_patch16 False
[16, 16, 16]
vit_base_mci_224 False
[16, 16, 16]
vit_base_patch8_224 False
[8, 8, 8]
vit_base_patch14_dinov2 False
[14, 14, 14]
vit_base_patch14_reg4_dinov2 False
[14, 14, 14]
vit_base_patch16_18x2_224 False
[16, 16, 16]
vit_base_patch16_224 False
[16, 16, 16]
vit_base_patch16_224_miil False
[16, 16, 16]
vit_base_patch16_384 False
[16, 16, 16]
vit_base_patch16_clip_224 False
[16, 16, 16]
vit_base_patch16_clip_384 False
[16, 16, 16]
vit_base_patch16_clip_quickgelu_224 False
[16, 16, 16]
vit_base_patch16_gap_224 False
[16, 16, 16]
vit_base_patch16_plus_240 False
[16, 16, 16]
vit_base_patch16_plus_clip_240 False
[16, 16, 16]
vit_base_patch16_reg4_gap_256 False
[16, 16, 16]
vit_base_patch16_rope_reg1_gap_256 False
[16, 16, 16]
vit_base_patch16_rpn_224 False
[16, 16, 16]
vit_base_patch16_siglip_224 False
[16, 16, 16]
vit_base_patch16_siglip_256 False
[16, 16, 16]
vit_base_patch16_siglip_384 False
[16, 16, 16]
vit_base_patch16_siglip_512 False
[16, 16, 16]
vit_base_patch16_siglip_gap_224 False
[16, 16, 16]
vit_base_patch16_siglip_gap_256 False
[16, 16, 16]
vit_base_patch16_siglip_gap_384 False
[16, 16, 16]
vit_base_patch16_siglip_gap_512 False
[16, 16, 16]
vit_base_patch16_xp_224 False
[16, 16, 16]
vit_base_patch32_224 False
[32, 32, 32]
vit_base_patch32_384 False
[32, 32, 32]
vit_base_patch32_clip_224 False
[32, 32, 32]
vit_base_patch32_clip_256 False
[32, 32, 32]
vit_base_patch32_clip_384 False
[32, 32, 32]
vit_base_patch32_clip_448 False
[32, 32, 32]
vit_base_patch32_clip_quickgelu_224 False
[32, 32, 32]
vit_base_patch32_plus_256 False
[32, 32, 32]
vit_base_r26_s32_224 False
[32, 32, 32]
vit_base_r50_s16_224 False
[16, 16, 16]
vit_base_r50_s16_384 False
[16, 16, 16]
vit_base_resnet26d_224 False
[32, 32, 32]
vit_base_resnet50d_224 False
[32, 32, 32]
vit_betwixt_patch16_gap_256 False
[16, 16, 16]
vit_betwixt_patch16_reg1_gap_256 False
[16, 16, 16]
vit_betwixt_patch16_reg4_gap_256 False
[16, 16, 16]
vit_betwixt_patch16_reg4_gap_384 False
[16, 16, 16]
vit_betwixt_patch16_rope_reg4_gap_256 False
[16, 16, 16]
vit_betwixt_patch32_clip_224 False
[32, 32, 32]
vit_giant_patch14_224 False
[14, 14, 14]
vit_giant_patch14_clip_224 False
[14, 14, 14]
vit_giant_patch14_dinov2 False
[14, 14, 14]
vit_giant_patch14_reg4_dinov2 False
[14, 14, 14]
vit_giant_patch16_gap_224 False
[16, 16, 16]
vit_gigantic_patch14_224 False
[14, 14, 14]
vit_gigantic_patch14_clip_224 False
[14, 14, 14]
vit_gigantic_patch14_clip_quickgelu_224 False
[14, 14, 14]
vit_huge_patch14_224 False
[14, 14, 14]
vit_huge_patch14_clip_224 False
[14, 14, 14]
vit_huge_patch14_clip_336 False
[14, 14, 14]
vit_huge_patch14_clip_378 False
[14, 14, 14]
vit_huge_patch14_clip_quickgelu_224 False
[14, 14, 14]
vit_huge_patch14_clip_quickgelu_378 False
[14, 14, 14]
vit_huge_patch14_gap_224 False
[14, 14, 14]
vit_huge_patch14_xp_224 False
[14, 14, 14]
vit_huge_patch16_gap_448 False
[16, 16, 16]
vit_intern300m_patch14_448 False
[14, 14, 14]
vit_large_patch14_224 False
[14, 14, 14]
vit_large_patch14_clip_224 False
[14, 14, 14]
vit_large_patch14_clip_336 False
[14, 14, 14]
vit_large_patch14_clip_quickgelu_224 False
[14, 14, 14]
vit_large_patch14_clip_quickgelu_336 False
[14, 14, 14]
vit_large_patch14_dinov2 False
[14, 14, 14]
vit_large_patch14_reg4_dinov2 False
[14, 14, 14]
vit_large_patch14_xp_224 False
[14, 14, 14]
vit_large_patch16_224 False
[16, 16, 16]
vit_large_patch16_384 False
[16, 16, 16]
vit_large_patch16_siglip_256 False
[16, 16, 16]
vit_large_patch16_siglip_384 False
[16, 16, 16]
vit_large_patch16_siglip_gap_256 False
[16, 16, 16]
vit_large_patch16_siglip_gap_384 False
[16, 16, 16]
vit_large_patch32_224 False
[32, 32, 32]
vit_large_patch32_384 False
[32, 32, 32]
vit_large_r50_s32_224 False
[32, 32, 32]
vit_large_r50_s32_384 False
[32, 32, 32]
vit_little_patch16_reg1_gap_256 False
[16, 16, 16]
vit_little_patch16_reg4_gap_256 False
[16, 16, 16]
vit_medium_patch16_clip_224 False
[16, 16, 16]
vit_medium_patch16_gap_240 False
[16, 16, 16]
vit_medium_patch16_gap_256 False
[16, 16, 16]
vit_medium_patch16_gap_384 False
[16, 16, 16]
vit_medium_patch16_reg1_gap_256 False
[16, 16, 16]
vit_medium_patch16_reg4_gap_256 False
[16, 16, 16]
vit_medium_patch16_rope_reg1_gap_256 False
[16, 16, 16]
vit_medium_patch32_clip_224 False
[32, 32, 32]
vit_mediumd_patch16_reg4_gap_256 False
[16, 16, 16]
vit_mediumd_patch16_reg4_gap_384 False
[16, 16, 16]
vit_mediumd_patch16_rope_reg1_gap_256 False
[16, 16, 16]
vit_pwee_patch16_reg1_gap_256 False
[16, 16, 16]
vit_relpos_base_patch16_224 False
[16, 16, 16]
vit_relpos_base_patch16_cls_224 False
[16, 16, 16]
vit_relpos_base_patch16_clsgap_224 False
[16, 16, 16]
vit_relpos_base_patch16_plus_240 False
[16, 16, 16]
vit_relpos_base_patch16_rpn_224 False
[16, 16, 16]
vit_relpos_base_patch32_plus_rpn_256 False
[32, 32, 32]
vit_relpos_medium_patch16_224 False
[16, 16, 16]
vit_relpos_medium_patch16_cls_224 False
[16, 16, 16]
vit_relpos_medium_patch16_rpn_224 False
[16, 16, 16]
vit_relpos_small_patch16_224 False
[16, 16, 16]
vit_relpos_small_patch16_rpn_224 False
[16, 16, 16]
vit_small_patch8_224 False
[8, 8, 8]
vit_small_patch14_dinov2 False
[14, 14, 14]
vit_small_patch14_reg4_dinov2 False
[14, 14, 14]
vit_small_patch16_18x2_224 False
[16, 16, 16]
vit_small_patch16_36x1_224 False
[16, 16, 16]
vit_small_patch16_224 False
[16, 16, 16]
vit_small_patch16_384 False
[16, 16, 16]
vit_small_patch32_224 False
[32, 32, 32]
vit_small_patch32_384 False
[32, 32, 32]
vit_small_r26_s32_224 False
[32, 32, 32]
vit_small_r26_s32_384 False
[32, 32, 32]
vit_small_resnet26d_224 False
[32, 32, 32]
vit_small_resnet50d_s16_224 False
[16, 16, 16]
vit_so150m_patch16_reg4_gap_256 False
[16, 16, 16]
vit_so150m_patch16_reg4_map_256 False
[16, 16, 16]
vit_so400m_patch14_siglip_224 False
[14, 14, 14]
vit_so400m_patch14_siglip_378 False
[14, 14, 14]
vit_so400m_patch14_siglip_384 False
[14, 14, 14]
vit_so400m_patch14_siglip_gap_224 False
[14, 14, 14]
vit_so400m_patch14_siglip_gap_378 False
[14, 14, 14]
vit_so400m_patch14_siglip_gap_384 False
[14, 14, 14]
vit_so400m_patch14_siglip_gap_448 False
[14, 14, 14]
vit_so400m_patch14_siglip_gap_896 False
[14, 14, 14]
vit_so400m_patch16_siglip_256 False
[16, 16, 16]
vit_so400m_patch16_siglip_gap_256 False
[16, 16, 16]
vit_srelpos_medium_patch16_224 False
[16, 16, 16]
vit_srelpos_small_patch16_224 False
[16, 16, 16]
vit_tiny_patch16_224 False
[16, 16, 16]
vit_tiny_patch16_384 False
[16, 16, 16]
vit_tiny_r_s16_p8_224 False
[32, 32, 32]
vit_tiny_r_s16_p8_384 False
[32, 32, 32]
vit_wee_patch16_reg1_gap_256 False
[16, 16, 16]
vit_xsmall_patch16_clip_224 False
[16, 16, 16]
vitamin_base_224 False
[16, 16, 16]
vitamin_large2_224 False
[16, 16, 16]
vitamin_large2_256 False
[16, 16, 16]
vitamin_large2_336 False
[16, 16, 16]
vitamin_large2_384 False
[16, 16, 16]
vitamin_large_224 False
[16, 16, 16]
vitamin_large_256 False
[16, 16, 16]
vitamin_large_336 False
[16, 16, 16]
vitamin_large_384 False
[16, 16, 16]
vitamin_small_224 False
[16, 16, 16]
vitamin_xlarge_256 False
[16, 16, 16]
vitamin_xlarge_336 False
[16, 16, 16]
vitamin_xlarge_384 False
[16, 16, 16]
volo_d1_224 False
[16, 16, 16]
volo_d1_384 False
[16, 16, 16]
volo_d2_224 False
[16, 16, 16]
volo_d2_384 False
[16, 16, 16]
volo_d3_224 False
[16, 16, 16]
volo_d3_448 False
[16, 16, 16]
volo_d4_224 False
[16, 16, 16]
volo_d4_448 False
[16, 16, 16]
volo_d5_224 False
[16, 16, 16]
volo_d5_448 False
[16, 16, 16]
volo_d5_512 False
[16, 16, 16]
xcit_large_24_p8_224 False
[8, 8, 8]
xcit_large_24_p8_384 False
[8, 8, 8]
xcit_large_24_p16_224 False
[16, 16, 16]
xcit_large_24_p16_384 False
[16, 16, 16]
xcit_medium_24_p8_224 False
[8, 8, 8]
xcit_medium_24_p8_384 False
[8, 8, 8]
xcit_medium_24_p16_224 False
[16, 16, 16]
xcit_medium_24_p16_384 False
[16, 16, 16]
xcit_nano_12_p8_224 False
[8, 8, 8]
xcit_nano_12_p8_384 False
[8, 8, 8]
xcit_nano_12_p16_224 False
[16, 16, 16]
xcit_nano_12_p16_384 False
[16, 16, 16]
xcit_small_12_p8_224 False
[8, 8, 8]
xcit_small_12_p8_384 False
[8, 8, 8]
xcit_small_12_p16_224 False
[16, 16, 16]
xcit_small_12_p16_384 False
[16, 16, 16]
xcit_small_24_p8_224 False
[8, 8, 8]
xcit_small_24_p8_384 False
[8, 8, 8]
xcit_small_24_p16_224 False
[16, 16, 16]
xcit_small_24_p16_384 False
[16, 16, 16]
xcit_tiny_12_p8_224 False
[8, 8, 8]
xcit_tiny_12_p8_384 False
[8, 8, 8]
xcit_tiny_12_p16_224 False
[16, 16, 16]
xcit_tiny_12_p16_384 False
[16, 16, 16]
xcit_tiny_24_p8_224 False
[8, 8, 8]
xcit_tiny_24_p8_384 False
[8, 8, 8]
xcit_tiny_24_p16_224 False
[16, 16, 16]
xcit_tiny_24_p16_384 False
[16, 16, 16]

References:

@brianhou0208
Contributor

brianhou0208 commented Jan 8, 2025

@qubvel, here is a minimal FPNWrapper implementation:

code

import torch
import torch.nn as nn
import timm


class FPNWrapper(nn.Module):
    """Turns single-stride ViT/BEiT features into a 4-level pyramid (1/4, 1/8, 1/16, 1/32)."""

    def __init__(self, model_name="beitv2_base_patch16_224"):
        super().__init__()

        # Outputs of blocks 3, 5, 7 and 11; all at stride 16 for patch16 models
        self.model = timm.create_model(model_name, features_only=True, out_indices=(3, 5, 7, 11))

        # Transformer width (768 for base models) instead of a hardcoded value
        embed_dim = self.model.feature_info.channels()[-1]

        # 1/16 -> 1/4: two 2x transposed convolutions
        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            # BatchNorm2d also runs on CPU / a single GPU; for multi-GPU training
            # convert it with nn.SyncBatchNorm.convert_sync_batchnorm(model)
            nn.BatchNorm2d(embed_dim),
            nn.GELU(),
            nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
        )

        # 1/16 -> 1/8
        self.fpn2 = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
        )

        # 1/16 -> 1/16
        self.fpn3 = nn.Identity()

        # 1/16 -> 1/32
        self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        features = self.model(x)
        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
        return [op(f) for op, f in zip(ops, features)]


if __name__ == "__main__":
    x = torch.rand(1, 3, 224, 224)
    model = FPNWrapper("beitv2_base_patch16_224").eval()
    with torch.no_grad():
        y = model(x)
    print([tuple(f.shape) for f in y])

output

 >>[(1, 768, 56, 56), (1, 768, 28, 28), (1, 768, 14, 14), (1, 768, 7, 7)]
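These spatial sizes follow from the output-size formulas in the PyTorch documentation for `nn.ConvTranspose2d` and `nn.MaxPool2d`. A small worked check, assuming the default padding, dilation and output padding used in the wrapper (`conv_transpose_out` and `max_pool_out` are illustrative helper names):

```python
def conv_transpose_out(size, kernel_size=2, stride=2, padding=0, output_padding=0, dilation=1):
    # Output-size formula from the torch.nn.ConvTranspose2d documentation
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1


def max_pool_out(size, kernel_size=2, stride=2, padding=0, dilation=1):
    # Output-size formula from the torch.nn.MaxPool2d documentation (ceil_mode=False)
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


grid = 224 // 16  # 14x14 patch grid for a patch16 model at 224x224 input
sizes = [
    conv_transpose_out(conv_transpose_out(grid)),  # fpn1: two 2x upsamplings
    conv_transpose_out(grid),                      # fpn2: one 2x upsampling
    grid,                                          # fpn3: identity
    max_pool_out(grid),                            # fpn4: 2x downsampling
]
print(sizes)  # [56, 28, 14, 7]
```

With an odd patch grid the max-pool level rounds down, so the 1/32 level is only exact when the grid size is even.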

@qubvel
Collaborator Author

qubvel commented Jan 9, 2025

I have previously considered how to support encoders without downsampled feature maps, and there are several possible options:

One option might be to adapt existing models, such as UNet and FPN, to these feature maps by replacing interpolation calls that use a scale factor with calls that use a target size.
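The difference between scale-based and size-based upsampling can be shown with `torch.nn.functional.interpolate`. A minimal sketch; `upsample_like` is a hypothetical helper name and the tensors are random placeholders for decoder and skip features:

```python
import torch
import torch.nn.functional as F


def upsample_like(x, skip):
    """Resize x to the spatial size of skip (size-based, not scale-based)."""
    return F.interpolate(x, size=skip.shape[-2:], mode="nearest")


deep = torch.rand(1, 8, 14, 14)  # stride-16 map from a plain ViT
skip = torch.rand(1, 8, 14, 14)  # another stride-16 map: no 2x relation between them

# scale_factor=2 assumes every skip map is exactly 2x larger, which is wrong here
scaled = F.interpolate(deep, scale_factor=2, mode="nearest")
print(tuple(scaled.shape))                     # (1, 8, 28, 28): no longer matches skip
print(tuple(upsample_like(deep, skip).shape))  # (1, 8, 14, 14): matches skip

# The same size-based call still works for a hierarchical encoder
skip_hier = torch.rand(1, 8, 28, 28)
fused = upsample_like(deep, skip_hier) + skip_hier  # shapes agree, so the sum works
print(tuple(fused.shape))                      # (1, 8, 28, 28)
```

Because `size=` reads the target resolution from the skip tensor, one decoder code path covers both hierarchical and constant-stride encoders.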

Another option could be to introduce an intermediate neck, which will ensure that features are downsampled correctly (similar to your suggestion).

The third option might be to support these encoders only with specific decoders, such as DPT or BEiT. This should be explicitly explained in the documentation and validated in the code to provide a clear error message.
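For the third option, the validation could look roughly like this. Everything here is hypothetical (the set of decoders that accept constant-stride features, the function name and the message wording); it only illustrates failing early with an actionable error instead of a shape mismatch deep inside the decoder:

```python
# Hypothetical set of decoders that can consume equal-stride (non-hierarchical) features
CONSTANT_STRIDE_DECODERS = {"dpt", "beit"}


def validate_encoder_for_decoder(decoder_name, reductions):
    """Raise a clear error if a decoder cannot handle the encoder's feature strides.

    reductions: stride of each encoder feature map relative to the input, e.g. [4, 8, 16, 32].
    """
    hierarchical = all(b > a for a, b in zip(reductions, reductions[1:]))
    if not hierarchical and decoder_name.lower() not in CONSTANT_STRIDE_DECODERS:
        raise ValueError(
            f"Decoder '{decoder_name}' expects progressively downsampled encoder features, "
            f"but the encoder returns features with reductions {list(reductions)}. "
            f"Use one of {sorted(CONSTANT_STRIDE_DECODERS)} or a hierarchical encoder."
        )


validate_encoder_for_decoder("dpt", [16, 16, 16])     # OK: decoder accepts equal strides
validate_encoder_for_decoder("unet", [4, 8, 16, 32])  # OK: hierarchical encoder
try:
    validate_encoder_for_decoder("unet", [16, 16, 16])
except ValueError as err:
    print(err)
```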

As for BEiT, in this issue I meant adding only a semantic segmentation variant of the model, similar to the one in the 🤗 Transformers library. It would support timm encoders whose features all have the same size (like ViT) and could be extended to support other encoders (e.g. with an optional FPN module or a configurable stride).
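A BEiT-style segmentation model needs the transformer's token sequence as 2D feature maps before any FPN or decoder can run. A minimal sketch of that reshape, assuming a single leading CLS token (`num_prefix_tokens=1`) and a known patch grid; `tokens_to_feature_map` is a hypothetical helper name. timm's `features_only=True` path appears to return spatial maps already (the FPNWrapper above applies 2D convolutions to them directly), so this step mainly matters when consuming raw token outputs:

```python
import torch


def tokens_to_feature_map(tokens, grid_size, num_prefix_tokens=1):
    """Reshape (B, N, C) ViT tokens into a (B, C, H, W) feature map.

    num_prefix_tokens: number of leading non-patch tokens (e.g. one CLS token) to drop.
    """
    patch_tokens = tokens[:, num_prefix_tokens:, :]
    b, n, c = patch_tokens.shape
    h, w = grid_size
    if n != h * w:
        raise ValueError(f"expected {h * w} patch tokens, got {n}")
    # Patch tokens are stored row by row, so (B, N, C) -> (B, C, N) -> (B, C, H, W)
    return patch_tokens.transpose(1, 2).reshape(b, c, h, w)


tokens = torch.rand(2, 1 + 14 * 14, 768)  # CLS + 14x14 patches, as for a 224px patch16 model
fmap = tokens_to_feature_map(tokens, (14, 14))
print(tuple(fmap.shape))  # (2, 768, 14, 14)
```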

@brianhou0208
Contributor

Hi @qubvel,

I have reimplemented the BEiT3 code in the timm style, but it will take some time to verify.
Would you suggest creating a new pull request to submit to timm or adding it here?

@qubvel
Collaborator Author

qubvel commented Jan 15, 2025

Hi @brianhou0208! Great news!

Not sure I understood the question about a PR to timm vs. smp, but I would appreciate a PR for the segmentation model here 🤗

@brianhou0208
Copy link
Contributor

Sorry for the confusion in my previous message. What I meant was that if the code is submitted to timm and maintained by Ross, it could be used in the following way:

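import timm
# 'beit3_base_patch16_224' is a hypothetical name: BEiT-3 is not in timm yet
model = timm.create_model('beit3_base_patch16_224', features_only=True)

import segmentation_models_pytorch as smp
# encoder names with the 'tu-' prefix are resolved through timm
encoder = smp.encoders.get_encoder('tu-beit3_base_patch16_224')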

Would this make maintaining smp easier?

@qubvel
Copy link
Collaborator Author

qubvel commented Jan 15, 2025

Do you mean the backbone? Yes, it would be much better to have it on the timm side.
