Add Beit segmentation model #1024
Comments
cc @brianhou0208 in case you are interested! It was a pleasure working with you, very clean PRs 🤗
Hi @qubvel, I’m not very familiar with multimodal models yet, but after reviewing the BEiT-3 paper and code, here are my thoughts:
If we adopt the approach above in a more flexible manner, it could support a broader range of ViT-based models for segmentation tasks.
Models without a downsampled feature pyramid: 264 (see PR #1004)
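A rough sketch of how one might arrive at such a count with timm's feature-info API; the filter below is an assumption about what "without a downsampled feature pyramid" means, and instantiating every model this way is slow:

```python
import timm

def has_flat_feature_pyramid(name: str) -> bool:
    # True if all returned feature maps share a single reduction factor,
    # i.e. the backbone produces no downsampled pyramid on its own.
    model = timm.create_model(name, features_only=True, pretrained=False)
    return len(set(model.feature_info.reduction())) == 1

# e.g. ViT/BEiT-style backbones return only stride-16 maps:
# print(has_flat_feature_pyramid("beitv2_base_patch16_224"))  # True
```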
@qubvel, here is a minimal example:

```python
import torch
import torch.nn as nn
import timm


class FPNWrapper(nn.Module):
    """Wraps a timm ViT backbone and builds a 4-level feature pyramid
    from its single-stride (1/16) feature maps, as in the BEiT repo."""

    def __init__(self, model_name="beitv2_base_patch16_224"):
        super().__init__()
        self.model = timm.create_model(
            model_name, features_only=True, out_indices=(3, 5, 7, 11)
        )
        embed_dim = 768  # hidden size of the base model; all levels share it

        # 1/16 -> 1/4: two stride-2 transposed convolutions.
        # SyncBatchNorm matches the original BEiT code but needs a
        # distributed run for training; nn.BatchNorm2d works single-device.
        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
            nn.SyncBatchNorm(embed_dim),
            nn.GELU(),
            nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
        )
        # 1/16 -> 1/8
        self.fpn2 = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
        )
        # 1/16 kept as-is
        self.fpn3 = nn.Identity()
        # 1/16 -> 1/32
        self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        features = self.model(x)
        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
        for i in range(len(features)):
            features[i] = ops[i](features[i])
        return features


if __name__ == "__main__":
    x = torch.rand(1, 3, 224, 224)
    model = FPNWrapper("beitv2_base_patch16_224").eval()
    with torch.no_grad():
        y = model(x)
    print([tuple(f.shape) for f in y])
```

Output:

```
[(1, 768, 56, 56), (1, 768, 28, 28), (1, 768, 14, 14), (1, 768, 7, 7)]
```
Previously, I considered support for encoders without downsampled feature maps, and there might be several options available:

- One option might be to adapt current models, such as UNet and FPN, to work with these types of feature maps by replacing interpolation operations that use a fixed scale factor with interpolation to an explicit target size.
- Another option could be to introduce an intermediate neck module (like the FPN wrapper above) that rescales the single-stride features into a standard pyramid before the decoder.
- The third option might be to support these encoders only with specific decoders, such as DPT or BEiT. This should be explicitly explained in the documentation and validated in the code to provide a clear error message.

As for Beit, in this issue I meant to add only a semantic segmentation variant of the model, similar to the 🤗 Transformers library.
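A minimal sketch of the first option, assuming only that a decoder block has access to the tensor it should size against; the names here are illustrative, not smp API:

```python
import torch
import torch.nn.functional as F

def upsample_to(x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Resize to the target's spatial size instead of a fixed scale_factor,
    # so same-stride ViT features still align with the skip connection.
    return F.interpolate(
        x, size=target.shape[2:], mode="bilinear", align_corners=False
    )

# e.g. inside a UNet-style decoder block:
# x = upsample_to(x, skip)
# x = torch.cat([x, skip], dim=1)
```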
Hi @qubvel, I have reimplemented the BEiT3 code in the timm format. Would it be better to submit the PR to timm or to smp?
Hi @brianhou0208! Great news! Not sure I got the question re PR to timm or smp... but I would appreciate a PR for the segmentation model here 🤗
Sorry for the confusion in my previous message. What I meant was that if the code is submitted to timm and maintained by Ross, it could be used in the following way:

```python
import timm
model = timm.create_model("beit3_base_patch16_224", features_only=True)

import segmentation_models_pytorch as smp
encoder = smp.encoders.get_encoder("tu-beit3_base_patch16_224")
```

Would this make maintaining smp easier?
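If that route is taken, the encoder would come through smp's `tu-` (timm-universal) prefix like any other timm backbone; a quick sanity check might look like this, keeping in mind the `beit3` model name is hypothetical until such a timm PR lands:

```python
import torch
import segmentation_models_pytorch as smp

# "tu-beit3_base_patch16_224" assumes the beit3 model exists in timm
encoder = smp.encoders.get_encoder("tu-beit3_base_patch16_224")
x = torch.rand(1, 3, 224, 224)
features = encoder(x)

print(encoder.out_channels)                 # channels per pyramid level
print([tuple(f.shape) for f in features])   # spatial size per level
```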
Do you mean the backbone? Yes, that would be way better to have on the timm side.
Add Beit to SMP
BEiT-3 is a general-purpose multimodal foundation model developed by Microsoft that excels in various vision and vision-language tasks, including semantic segmentation. It employs a unified architecture with Multiway Transformers, enabling both deep fusion and modality-specific encoding. Pretrained using a masked "language" modeling approach on images ("Imglish"), texts, and image-text pairs, BEiT-3 effectively models images as another language. This design allows it to achieve state-of-the-art performance across a wide range of tasks, such as object detection, image classification, and semantic segmentation.
Papers with Code:
https://paperswithcode.com/paper/image-as-a-foreign-language-beit-pretraining
Paper:
https://arxiv.org/abs/2208.10442
HF reference implementation:
https://huggingface.co/docs/transformers/model_doc/beit
https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/beit/modeling_beit.py
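For reference, a minimal sketch of running the 🤗 Transformers segmentation variant linked above; the ADE20K checkpoint name is taken from the HF model hub, and the shapes are indicative:

```python
import torch
from transformers import BeitForSemanticSegmentation

# ADE20K-finetuned checkpoint from the HF hub (150 classes)
model = BeitForSemanticSegmentation.from_pretrained(
    "microsoft/beit-base-finetuned-ade-640-640"
)
model.eval()

pixel_values = torch.rand(1, 3, 640, 640)  # dummy (unnormalized) batch
with torch.no_grad():
    logits = model(pixel_values=pixel_values).logits

print(logits.shape)  # (1, 150, 160, 160): class logits at 1/4 resolution
```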
As an example, please see the latest model additions.