integrate SAM (segment anything) encoder with Unet #757
base: main
Changes from 14 commits
@@ -37,4 +37,7 @@ DeepLabV3+
~~~~~~~~~~
.. autoclass:: segmentation_models_pytorch.DeepLabV3Plus

SAM
Review comment: pls remove this, no decoder anymore
~~~~~~~~~~
.. autoclass:: segmentation_models_pytorch.SAM
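For orientation, a minimal instantiation sketch of the class this docs entry points at, assuming the SAM class is exported at the package root as the autoclass directive implies; argument names follow the `__init__` further down in this diff, and weights are disabled here to avoid checkpoint downloads:

```python
import segmentation_models_pytorch as smp

# hypothetical usage sketch, not part of this PR's diff
model = smp.SAM(
    encoder_name="sam-vit_h",  # SAM ViT-H image encoder
    encoder_weights=None,      # skip SA-1B encoder weights for this sketch
    weights=None,              # skip loading the full pretrained checkpoint
    image_size=1024,           # SAM's native input resolution
    classes=2,                 # number of output mask channels
)
```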
@@ -0,0 +1 @@
from .model import SAM
@@ -0,0 +1,213 @@
import logging
from typing import Optional, Union, List, Tuple

import torch
from segment_anything.modeling import MaskDecoder, TwoWayTransformer, PromptEncoder
Review comment: Is it a pip package? probably need to add to reqs
Reply: just added it to reqs, or should we make it optional?
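For context, if the dependency were made optional, a common pattern is a guarded import with an actionable error message; this is only a sketch of that option, not something present in the diff:

```python
# sketch of an optional-dependency guard (not part of this PR)
try:
    from segment_anything.modeling import MaskDecoder, TwoWayTransformer, PromptEncoder
except ImportError as e:
    raise ImportError(
        "The SAM model requires the `segment-anything` package. Install it with "
        "`pip install git+https://github.com/facebookresearch/segment-anything.git`"
    ) from e
```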
from segment_anything.modeling.prompt_encoder import PositionEmbeddingRandom
from torch import nn
from torch.nn import functional as F
from torch.utils import model_zoo

from segmentation_models_pytorch.base import (
    SegmentationModel,
    SegmentationHead,
)
from segmentation_models_pytorch.encoders import get_encoder, sam_vit_encoders, get_pretrained_settings

logger = logging.getLogger("sam")
logger.setLevel(logging.WARNING)
stream = logging.StreamHandler()
logger.addHandler(stream)
logger.propagate = False


class SAM(SegmentationModel):
    """SAM_ (Segment Anything Model) is a visual transformer based encoder-decoder segmentation
    model that can be used to produce high quality segmentation masks from images and prompts.
    Consists of *image encoder*, *prompt encoder* and *mask decoder*. *Segmentation head* is
    added after the *mask decoder* to define the final number of classes for the output mask.

    Args:
        encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
            to extract features of different spatial resolution
        encoder_depth: A number of stages used in encoder in range [6, 24]. Each stage generates features
            two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features
            with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
            Default is **None**
        encoder_weights: One of **None** (random initialization), **"sa-1b"** (pre-training on SA-1B dataset).
        decoder_channels: How many output channels the image encoder will have. Default is 256.
        in_channels: A number of input channels for the model, default is 3 (RGB images)
        classes: A number of classes for output mask (or you can think of it as a number of channels of output mask)
        activation: An activation function to apply after the final convolution layer.
            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
            **callable** and **None**.
            Default is **None**
        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
            on top of encoder if **aux_params** is not **None** (default). Supported params:
                - classes (int): A number of classes
                - pooling (str): One of "max", "avg". Default is "avg"
                - dropout (float): Dropout factor in [0, 1)
                - activation (str): An activation function to apply "sigmoid"/"softmax"
                    (could be **None** to return logits)

    Returns:
        ``torch.nn.Module``: SAM

    .. _SAM:
        https://github.com/facebookresearch/segment-anything

    """

    def __init__(
        self,
        encoder_name: str = "sam-vit_h",
        encoder_depth: Optional[int] = None,
        encoder_weights: Optional[str] = None,
        decoder_channels: int = 256,
        decoder_multimask_output: bool = True,
        in_channels: int = 3,
        image_size: int = 1024,
        vit_patch_size: int = 16,
        classes: int = 1,
        weights: Optional[str] = "sa-1b",
        activation: Optional[Union[str, callable]] = None,
        aux_params: Optional[dict] = None,
    ):
        super().__init__()

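        # SAM's default pixel statistics: the ImageNet per-channel mean/std expressed in the
        # 0-255 range, so inputs are expected as raw 0-255 RGB values.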
        self.register_buffer("pixel_mean", torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1), False)

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
            img_size=image_size,
            patch_size=vit_patch_size,
            out_chans=decoder_channels,
        )

        # these params are used instead of the prompt_encoder
        image_embedding_size = image_size // vit_patch_size
        self.embed_dim = decoder_channels
        self.image_embedding_size = (image_embedding_size, image_embedding_size)
        self.pe_layer = PositionEmbeddingRandom(decoder_channels // 2)
        self.no_mask_embed = nn.Embedding(1, decoder_channels)

        self.decoder = MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=decoder_channels,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=decoder_channels,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        )
        self._decoder_multimask_output = decoder_multimask_output

        if weights is not None:
            self._load_pretrained_weights(encoder_name, weights)

        self.segmentation_head = SegmentationHead(
            in_channels=3 if decoder_multimask_output else 1,
            out_channels=classes,
            activation=activation,
            kernel_size=3,
        )

        if aux_params is not None:
            raise NotImplementedError("Auxiliary output is not supported yet")
        self.classification_head = None

        self.name = encoder_name
        self.initialize()

    def _load_pretrained_weights(self, encoder_name: str, weights: str):
        settings = get_pretrained_settings(sam_vit_encoders, encoder_name, weights)
        state_dict = model_zoo.load_url(settings["url"])
        state_dict = {k.replace("image_encoder", "encoder"): v for k, v in state_dict.items()}
        state_dict = {k.replace("mask_decoder", "decoder"): v for k, v in state_dict.items()}
        missing, unused = self.load_state_dict(state_dict, strict=False)
        if len(missing) > 0 or len(unused) > 0:
            n_loaded = len(state_dict) - len(missing) - len(unused)
            logger.warning(
                f"Only {n_loaded} out of pretrained {len(state_dict)} SAM modules are loaded. "
                f"Missing modules: {missing}. Unused modules: {unused}."
            )

    def preprocess(self, x):
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        h, w = x.shape[-2:]
        padh = self.encoder.img_size - h
        padw = self.encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxHxW format.
          input_size (tuple(int, int)): The size of the image input to the
            model, in (H, W) format. Used to remove padding.
          original_size (tuple(int, int)): The original size of the image
            before resizing for input to the model, in (H, W) format.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
        """
        masks = F.interpolate(
            masks,
            (self.encoder.img_size, self.encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def forward(self, x):
        img_size = x.shape[-2:]
        x = torch.stack([self.preprocess(img) for img in x])
        features = self.encoder(x)
        # sparse_embeddings, dense_embeddings = self.prompt_encoder(points=None, boxes=None, masks=None)
        sparse_embeddings, dense_embeddings = self._get_dummy_prompt_encoder_output(x.size(0))
        low_res_masks, iou_predictions = self.decoder(
            image_embeddings=features,
            image_pe=self._get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=self._decoder_multimask_output,
        )
        masks = self.postprocess_masks(low_res_masks, input_size=img_size, original_size=img_size)
        # use scaling below in order to make it work with torch DDP
        masks = masks * iou_predictions.view(-1, masks.size(1), 1, 1)
        output = self.segmentation_head(masks)
        return output

    def _get_dummy_prompt_encoder_output(self, bs):
        """Use this dummy output as we're training without prompts."""
        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self.no_mask_embed.weight.device)
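        # The dense embedding below mirrors what segment_anything's PromptEncoder returns when
        # no mask prompt is given: the learned no_mask_embed broadcast over the embedding grid.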
        dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
            bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
        )
        return sparse_embeddings, dense_embeddings

    def _get_dense_pe(self):
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)
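As a small illustration of the prompt-free path above, here is a standalone sketch of the shapes produced by the dummy prompt embeddings; the numbers are chosen to match the defaults in this diff (decoder_channels=256, image_size=1024, vit_patch_size=16, giving a 64x64 embedding grid):

```python
import torch
from torch import nn

bs, embed_dim, grid = 2, 256, 64  # batch, decoder_channels, image_size // vit_patch_size

no_mask_embed = nn.Embedding(1, embed_dim)

# no point/box prompts: an empty sparse embedding of shape (B, 0, C)
sparse = torch.empty((bs, 0, embed_dim))

# the learned "no mask" embedding broadcast over the embedding grid: (B, C, 64, 64)
dense = no_mask_embed.weight.reshape(1, -1, 1, 1).expand(bs, -1, grid, grid)

print(sparse.shape, dense.shape)  # torch.Size([2, 0, 256]) torch.Size([2, 256, 64, 64])
```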
Review comment: pls, remove this, no decoder any more