integrate SAM (segment anything) encoder with Unet #757

Status: Open. Wants to merge 26 commits into base: main (changes shown from 14 commits).
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -4,8 +4,8 @@ repos:
    hooks:
      - id: black
        args: [ --config=pyproject.toml ]
  - repo: https://gitlab.com/pycqa/flake8
    rev: 4.0.1
  - repo: https://github.com/PyCQA/flake8
    rev: 6.0.0
    hooks:
      - id: flake8
        args: [ --config=.flake8 ]
14 changes: 14 additions & 0 deletions README.md
@@ -94,6 +94,7 @@ Congratulations! You are done! Now you can train your model with your favorite f
- PAN [[paper](https://arxiv.org/abs/1805.10180)] [[docs](https://smp.readthedocs.io/en/latest/models.html#pan)]
- DeepLabV3 [[paper](https://arxiv.org/abs/1706.05587)] [[docs](https://smp.readthedocs.io/en/latest/models.html#deeplabv3)]
- DeepLabV3+ [[paper](https://arxiv.org/abs/1802.02611)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id9)]
- SAM [[paper](https://ai.facebook.com/research/publications/segment-anything/)] [[docs](https://github.com/facebookresearch/segment-anything)]
Collaborator: pls, remove this, no decoder any more

#### Encoders <a name="encoders"></a>

@@ -394,6 +395,19 @@ Note: In the official github repo the s0 variant has additional num_conv_branche
</div>
</details>

<details>
<summary style="margin-left: 25px;">SAM</summary>
<div style="margin-left: 25px;">

| Encoder | Weights | Params, M |
|-----------|:--------:|:---------:|
| sam-vit_b | sa-1b | 91M |
| sam-vit_l | sa-1b | 308M |
| sam-vit_h | sa-1b | 636M |

</div>
</details>
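A rough usage sketch (not part of this diff), assuming the SAM ViT-B encoder registers under the name `sam-vit_b` with `sa-1b` weights as listed above; whether the Unet decoder accepts this single-scale ViT backbone depends on the `SamVitEncoder` implementation in this PR:

```python
# Hypothetical sketch: plug the SAM ViT-B image encoder into a Unet, as the PR title suggests.
# The encoder name and the "sa-1b" weights key are taken from the table above.
import torch
import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="sam-vit_b",   # SAM ViT-B image encoder as the backbone
    encoder_weights="sa-1b",    # pretrained on the SA-1B dataset (downloads a checkpoint)
    in_channels=3,
    classes=1,
)

# SAM ViT encoders are configured for 1024x1024 inputs by default.
with torch.no_grad():
    mask = model(torch.rand(1, 3, 1024, 1024))
```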


\* `ssl`, `swsl` - semi-supervised and weakly-supervised learning on ImageNet ([repo](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models)).

13 changes: 13 additions & 0 deletions docs/encoders.rst
@@ -361,3 +361,16 @@ MobileOne
+-----------------+----------+------------+
| mobileone\_s4 | imagenet | 13.6M |
+-----------------+----------+------------+

SAM
~~~~~~~~~~~~~~~~~~~~~

+-----------------+----------+------------+
| Encoder | Weights | Params, M |
+=================+==========+============+
| sam-vit_b | sa-1b | 91M |
+-----------------+----------+------------+
| sam-vit_l | sa-1b | 308M |
+-----------------+----------+------------+
| sam-vit_h | sa-1b | 636M |
+-----------------+----------+------------+
3 changes: 3 additions & 0 deletions docs/models.rst
@@ -37,4 +37,7 @@ DeepLabV3+
~~~~~~~~~~
.. autoclass:: segmentation_models_pytorch.DeepLabV3Plus

SAM
Collaborator: pls remove this, no decoder anymore
~~~~~~~~~~
.. autoclass:: segmentation_models_pytorch.SAM

2 changes: 2 additions & 0 deletions segmentation_models_pytorch/__init__.py
@@ -12,6 +12,7 @@
from .decoders.pspnet import PSPNet
from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus
from .decoders.pan import PAN
from .decoders.sam import SAM

from .__version__ import __version__

@@ -42,6 +43,7 @@ def create_model(
        DeepLabV3,
        DeepLabV3Plus,
        PAN,
        SAM,
    ]
    archs_dict = {a.__name__.lower(): a for a in archs}
    try:
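A quick sketch of how the new entry would be reached through `create_model` (not part of this diff; it assumes `create_model`'s first argument is the architecture name, which is lower-cased and looked up in `archs_dict`):

```python
# Hypothetical usage of the registered SAM architecture via create_model.
import segmentation_models_pytorch as smp

# "sam" maps to the SAM class through archs_dict; by default SAM downloads the
# pretrained "sa-1b" checkpoint (see weights="sa-1b" in SAM.__init__ below).
model = smp.create_model("sam", encoder_name="sam-vit_b", classes=1)
```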
1 change: 1 addition & 0 deletions segmentation_models_pytorch/decoders/sam/__init__.py
@@ -0,0 +1 @@
from .model import SAM
213 changes: 213 additions & 0 deletions segmentation_models_pytorch/decoders/sam/model.py
@@ -0,0 +1,213 @@
import logging
from typing import Optional, Union, Tuple

import torch
from segment_anything.modeling import MaskDecoder, TwoWayTransformer, PromptEncoder
Collaborator: Is it a pip package? probably need to add to reqs
Author: just added it to reqs, or should we make it optional?
from segment_anything.modeling.prompt_encoder import PositionEmbeddingRandom
from torch import nn
from torch.nn import functional as F
from torch.utils import model_zoo

from segmentation_models_pytorch.base import (
    SegmentationModel,
    SegmentationHead,
)
from segmentation_models_pytorch.encoders import get_encoder, sam_vit_encoders, get_pretrained_settings

logger = logging.getLogger("sam")
logger.setLevel(logging.WARNING)
stream = logging.StreamHandler()
logger.addHandler(stream)
logger.propagate = False


class SAM(SegmentationModel):
    """SAM_ (Segment Anything Model) is a vision-transformer-based encoder-decoder segmentation
    model that can be used to produce high quality segmentation masks from images and prompts.
    It consists of an *image encoder*, a *prompt encoder* and a *mask decoder*. A *segmentation head*
    is added after the *mask decoder* to define the final number of classes for the output mask.

    Args:
        encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
            to extract features of different spatial resolution
        encoder_depth: A number of stages used in encoder in range [6, 24]. Each stage generates features
            two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have
            features with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
            Default is **None** (use the encoder's native depth)
        encoder_weights: One of **None** (random initialization), **"sa-1b"** (pre-training on SA-1B dataset).
        decoder_channels: Number of output channels of the image encoder, also used as the mask decoder
            transformer dimension. Default is 256.
        decoder_multimask_output: If **True**, the mask decoder predicts 3 masks that are combined by the
            segmentation head; if **False**, a single mask is predicted. Default is **True**.
        in_channels: A number of input channels for the model, default is 3 (RGB images)
        image_size: Spatial size of the (square) input image expected by the encoder. Default is 1024.
        vit_patch_size: Patch size of the ViT image encoder. Default is 16.
        classes: A number of classes for output mask (or you can think as a number of channels of output mask)
        weights: One of **None**, **"sa-1b"**; if set, the full pretrained SAM checkpoint
            (image encoder and mask decoder) is loaded. Default is **"sa-1b"**.
        activation: An activation function to apply after the final convolution layer.
            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
            **callable** and **None**.
            Default is **None**
        aux_params: Dictionary with parameters of the auxiliary output (classification head). The auxiliary
            output is built on top of the encoder if **aux_params** is not **None** (default). Currently not
            implemented for SAM. Supported params:
                - classes (int): A number of classes
                - pooling (str): One of "max", "avg". Default is "avg"
                - dropout (float): Dropout factor in [0, 1)
                - activation (str): An activation function to apply "sigmoid"/"softmax"
                  (could be **None** to return logits)

    Returns:
        ``torch.nn.Module``: SAM

    .. _SAM:
        https://github.com/facebookresearch/segment-anything

    """

    def __init__(
        self,
        encoder_name: str = "sam-vit_h",
        encoder_depth: Optional[int] = None,
        encoder_weights: Optional[str] = None,
        decoder_channels: int = 256,
        decoder_multimask_output: bool = True,
        in_channels: int = 3,
        image_size: int = 1024,
        vit_patch_size: int = 16,
        classes: int = 1,
        weights: Optional[str] = "sa-1b",
        activation: Optional[Union[str, callable]] = None,
        aux_params: Optional[dict] = None,
    ):
        super().__init__()

        self.register_buffer("pixel_mean", torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1), False)

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
            img_size=image_size,
            patch_size=vit_patch_size,
            out_chans=decoder_channels,
        )

        # these parameters are used instead of a prompt encoder (the model is trained without prompts)
        image_embedding_size = image_size // vit_patch_size
        self.embed_dim = decoder_channels
        self.image_embedding_size = (image_embedding_size, image_embedding_size)
        self.pe_layer = PositionEmbeddingRandom(decoder_channels // 2)
        self.no_mask_embed = nn.Embedding(1, decoder_channels)

        self.decoder = MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=decoder_channels,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=decoder_channels,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        )
        self._decoder_multimask_output = decoder_multimask_output

        if weights is not None:
            self._load_pretrained_weights(encoder_name, weights)

        self.segmentation_head = SegmentationHead(
            in_channels=3 if decoder_multimask_output else 1,
            out_channels=classes,
            activation=activation,
            kernel_size=3,
        )

        if aux_params is not None:
            raise NotImplementedError("Auxiliary output is not supported yet")
        self.classification_head = None

        self.name = encoder_name
        self.initialize()

    def _load_pretrained_weights(self, encoder_name: str, weights: str):
        settings = get_pretrained_settings(sam_vit_encoders, encoder_name, weights)
        state_dict = model_zoo.load_url(settings["url"])
        state_dict = {k.replace("image_encoder", "encoder"): v for k, v in state_dict.items()}
        state_dict = {k.replace("mask_decoder", "decoder"): v for k, v in state_dict.items()}
        missing, unused = self.load_state_dict(state_dict, strict=False)
        if len(missing) > 0 or len(unused) > 0:
            n_loaded = len(state_dict) - len(missing) - len(unused)
            logger.warning(
                f"Only {n_loaded} out of pretrained {len(state_dict)} SAM modules are loaded. "
                f"Missing modules: {missing}. Unused modules: {unused}."
            )

    def preprocess(self, x):
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        h, w = x.shape[-2:]
        padh = self.encoder.img_size - h
        padw = self.encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
            masks (torch.Tensor): Batched masks from the mask_decoder,
                in BxCxHxW format.
            input_size (tuple(int, int)): The size of the image input to the
                model, in (H, W) format. Used to remove padding.
            original_size (tuple(int, int)): The original size of the image
                before resizing for input to the model, in (H, W) format.

        Returns:
            (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
                is given by original_size.
        """
        masks = F.interpolate(
            masks,
            (self.encoder.img_size, self.encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def forward(self, x):
        img_size = x.shape[-2:]
        x = torch.stack([self.preprocess(img) for img in x])
        features = self.encoder(x)
        # sparse_embeddings, dense_embeddings = self.prompt_encoder(points=None, boxes=None, masks=None)
        sparse_embeddings, dense_embeddings = self._get_dummy_prompt_encoder_output(x.size(0))
        low_res_masks, iou_predictions = self.decoder(
            image_embeddings=features,
            image_pe=self._get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=self._decoder_multimask_output,
        )
        masks = self.postprocess_masks(low_res_masks, input_size=img_size, original_size=img_size)
        # use scaling below in order to make it work with torch DDP
        masks = masks * iou_predictions.view(-1, masks.size(1), 1, 1)
        output = self.segmentation_head(masks)
        return output

    def _get_dummy_prompt_encoder_output(self, bs):
        """Use this dummy output as we're training without prompts."""
        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self.no_mask_embed.weight.device)
        dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
            bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
        )
        return sparse_embeddings, dense_embeddings

    def _get_dense_pe(self):
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)
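For reference, a usage sketch consistent with the constructor and `forward` above (not part of the diff); `weights=None` is passed here only to avoid downloading the pretrained `sa-1b` checkpoint, which is the default:

```python
# Sketch: instantiate the SAM model added above and run a forward pass.
import torch
import segmentation_models_pytorch as smp

model = smp.SAM(
    encoder_name="sam-vit_b",
    encoder_weights=None,  # or "sa-1b" to load the pretrained image encoder
    weights=None,          # or "sa-1b" to load the full pretrained SAM checkpoint
    image_size=1024,
    classes=1,
)

x = torch.rand(1, 3, 1024, 1024)
with torch.no_grad():
    y = model(x)  # preprocess -> image encoder -> mask decoder -> segmentation head
print(y.shape)    # torch.Size([1, 1, 1024, 1024]): masks are resized back to the input size
```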
48 changes: 37 additions & 11 deletions segmentation_models_pytorch/encoders/__init__.py
@@ -4,6 +4,7 @@

from .resnet import resnet_encoders
from .dpn import dpn_encoders
from .sam import sam_vit_encoders, SamVitEncoder
from .vgg import vgg_encoders
from .senet import senet_encoders
from .densenet import densenet_encoders
@@ -46,6 +47,34 @@
encoders.update(timm_gernet_encoders)
encoders.update(mix_transformer_encoders)
encoders.update(mobileone_encoders)
encoders.update(sam_vit_encoders)


def get_pretrained_settings(encoders: dict, encoder_name: str, weights: str) -> dict:
    """Get pretrained settings for encoder from encoders collection.

    Args:
        encoders: collection of encoders
        encoder_name: name of encoder in collection
        weights: one of ``None`` (random initialization), ``imagenet`` or other pretrained settings

    Returns:
        pretrained settings for encoder

    Raises:
        KeyError: in case of wrong encoder name or pretrained settings name
    """
    try:
        settings = encoders[encoder_name]["pretrained_settings"][weights]
    except KeyError:
        raise KeyError(
            "Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
                weights,
                encoder_name,
                list(encoders[encoder_name]["pretrained_settings"].keys()),
            )
        )
    return settings
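For illustration (not part of the diff), the helper can be exercised like this, assuming `sam_vit_encoders` exposes an `sa-1b` entry as the tables in this PR suggest:

```python
# Sketch of get_pretrained_settings behaviour for the new SAM encoders.
from segmentation_models_pytorch.encoders import get_pretrained_settings, sam_vit_encoders

settings = get_pretrained_settings(sam_vit_encoders, "sam-vit_b", "sa-1b")
print(settings["url"])  # checkpoint URL later consumed by model_zoo.load_url

try:
    get_pretrained_settings(sam_vit_encoders, "sam-vit_b", "imagenet")  # wrong weights key
except KeyError as err:
    print(err)  # message lists the available pretrained settings
```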


def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
@@ -68,20 +97,17 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **
        raise KeyError("Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys())))

    params = encoders[name]["params"]
    params.update(depth=depth)
    if name.startswith("sam-"):
        params.update(**kwargs)
        params.update(dict(name=name[4:]))
        if depth is not None:
            params.update(depth=depth)
    else:
        params.update(depth=depth)
Rusteam marked this conversation as resolved.
    encoder = Encoder(**params)

    if weights is not None:
        try:
            settings = encoders[name]["pretrained_settings"][weights]
        except KeyError:
            raise KeyError(
                "Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
                    weights,
                    name,
                    list(encoders[name]["pretrained_settings"].keys()),
                )
            )
        settings = get_pretrained_settings(encoders, name, weights)
        encoder.load_state_dict(model_zoo.load_url(settings["url"]))

    encoder.set_in_channels(in_channels, pretrained=weights is not None)
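A short sketch of the new `sam-` branch above (not part of the diff): extra keyword arguments are forwarded to the ViT encoder and `depth` is only overridden when given explicitly, matching how `SAM.__init__` calls `get_encoder`:

```python
# Hypothetical call mirroring the way the SAM model builds its encoder.
from segmentation_models_pytorch.encoders import get_encoder

encoder = get_encoder(
    "sam-vit_b",     # the "sam-" prefix is stripped before the encoder params are built
    in_channels=3,
    depth=None,      # keep the encoder's native depth
    weights=None,    # or "sa-1b" to load pretrained image-encoder weights
    img_size=1024,   # extra kwargs accepted by the SAM ViT encoder
    patch_size=16,
    out_chans=256,
)
```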
5 changes: 0 additions & 5 deletions segmentation_models_pytorch/encoders/_base.py
@@ -1,8 +1,3 @@
import torch
import torch.nn as nn
from typing import List
from collections import OrderedDict

from . import _utils as utils

