Skip to content

Commit

Permalink
Fix torch compile, script, export (#1031)
Browse files Browse the repository at this point in the history
* Move tests

* Add compile test for encoders (to be optimized)

* densnet

* dpn

* efficientnet

* inceptionresnetv2

* inceptionv4

* mix-transformer

* mobilenet

* mobileone

* resnet

* senet

* vgg

* xception

* Deprecate `timm-` encoders, remap to `tu-` most of them

* Add tiny encoders and compile mark

* Add conftest

* Fix features

* Add triggering compile tests on diff

* Remove marks

* Add test_compile stage to CI

* Update requirements

* Update makefile

* Update get_stages

* Fix weight loading for deprecate encoders

* Fix weight loading for mobilenetv3

* Format

* Add compile test for models

* Add torch.export test

* Disable export tests for dpn and inceptionv4

* Disable export for timm-eff-net

* Huge fix for torch scripting (except Unet++ and UperNet)

* Fix scripting

* Add test for torch script

* Add torch_script test to CI

* Fix

* Fix timm-effnet encoders

* Make from_pretrained strict by default

* Fix DeepLabV3 BC

* Fix scripting for encoders

* Refactor test do not skip

* Fix encoders (mobilenet, inceptionv4)

* Update encoders table

* Fix export test

* Fix docs

* Update warning

* Move pretrained settings

* Add BC for timm- encoders

* Fixing table

* Update compile test

* Change compile backend to eager

* Update docs

* Fixup

* Fix batchnorm typo

* Add depth validation

* Update segmentation_models_pytorch/encoders/__init__.py

Co-authored-by: Adam J. Stewart <[email protected]>

* Style

---------

Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
qubvel and adamjstewart authored Jan 15, 2025
1 parent 93b19d3 commit 456871a
Show file tree
Hide file tree
Showing 63 changed files with 2,418 additions and 2,289 deletions.
51 changes: 48 additions & 3 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
run: uv pip list

- name: Test with PyTest
run: uv run pytest -v -rsx -n 2 --cov=segmentation_models_pytorch --cov-report=xml --cov-config=pyproject.toml -k "not logits_match"
run: uv run pytest -v -rsx -n 2 --cov=segmentation_models_pytorch --cov-report=xml --cov-config=pyproject.toml --non-marked-only

- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
Expand All @@ -73,7 +73,52 @@ jobs:
- name: Show installed packages
run: uv pip list
- name: Test with PyTest
run: RUN_SLOW=1 uv run pytest -v -rsx -n 2 -k "logits_match"
run: RUN_SLOW=1 uv run pytest -v -rsx -n 2 -m "logits_match"

test_torch_compile:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: astral-sh/setup-uv@v5
with:
python-version: "3.10"
- name: Install dependencies
run: uv pip install -r requirements/required.txt -r requirements/test.txt
- name: Show installed packages
run: uv pip list
- name: Test with PyTest
run: uv run pytest -v -rsx -n 2 -m "compile"

test_torch_export:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: astral-sh/setup-uv@v5
with:
python-version: "3.10"
- name: Install dependencies
run: uv pip install -r requirements/required.txt -r requirements/test.txt
- name: Show installed packages
run: uv pip list
- name: Test with PyTest
run: uv run pytest -v -rsx -n 2 -m "torch_export"

test_torch_script:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: astral-sh/setup-uv@v5
with:
python-version: "3.10"
- name: Install dependencies
run: uv pip install -r requirements/required.txt -r requirements/test.txt
- name: Show installed packages
run: uv pip list
- name: Test with PyTest
run: uv run pytest -v -rsx -n 2 -m "torch_script"

minimum:
runs-on: ubuntu-latest
Expand All @@ -88,4 +133,4 @@ jobs:
- name: Show installed packages
run: uv pip list
- name: Test with pytest
run: uv run pytest -v -rsx -n 2 -k "not logits_match"
run: uv run pytest -v -rsx -n 2 --non-marked-only
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ install_dev: .venv
.venv/bin/pip install -e ".[test]"

test: .venv
.venv/bin/pytest -v -rsx -n 2 tests/ -k "not logits_match"
.venv/bin/pytest -v -rsx -n 2 tests/ --non-marked-only

test_all: .venv
RUN_SLOW=1 .venv/bin/pytest -v -rsx -n 2 tests/
Expand Down
498 changes: 138 additions & 360 deletions docs/encoders.rst

Large diffs are not rendered by default.

30 changes: 21 additions & 9 deletions misc/generate_table.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import os
import segmentation_models_pytorch as smp

from tqdm import tqdm

encoders = smp.encoders.encoders


WIDTH = 32
COLUMNS = ["Encoder", "Weights", "Params, M"]
COLUMNS = ["Encoder", "Pretrained weights", "Params, M", "Script", "Compile", "Export"]
FILE = "encoders_table.md"

if os.path.exists(FILE):
os.remove(FILE)


def wrap_row(r):
Expand All @@ -16,18 +23,23 @@ def wrap_row(r):
["-" * WIDTH] + [":" + "-" * (WIDTH - 2) + ":"] * (len(COLUMNS) - 1)
)

print(wrap_row(header))
print(wrap_row(separator))
print(wrap_row(header), file=open(FILE, "a"))
print(wrap_row(separator), file=open(FILE, "a"))

for encoder_name, encoder in encoders.items():
for encoder_name, encoder in tqdm(encoders.items()):
weights = "<br>".join(encoder["pretrained_settings"].keys())
encoder_name = encoder_name.ljust(WIDTH, " ")
weights = weights.ljust(WIDTH, " ")

model = encoder["encoder"](**encoder["params"], depth=5)

script = "✅" if model._is_torch_scriptable else "❌"
compile = "✅" if model._is_torch_compilable else "❌"
export = "✅" if model._is_torch_exportable else "❌"

params = sum(p.numel() for p in model.parameters())
params = str(params // 1000000) + "M"
params = params.ljust(WIDTH, " ")

row = "|".join([encoder_name, weights, params])
print(wrap_row(row))
row = [encoder_name, weights, params, script, compile, export]
row = [str(r).ljust(WIDTH, " ") for r in row]
row = "|".join(row)

print(wrap_row(row), file=open(FILE, "a"))
16 changes: 5 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@ docs = [
'sphinx-book-theme',
]
test = [
'gitpython',
'packaging',
'pytest',
'pytest-cov',
'pytest-xdist',
'ruff>=0.9',
'setuptools',
]

[project.urls]
Expand All @@ -61,18 +63,10 @@ include = ['segmentation_models_pytorch*']

[tool.pytest.ini_options]
markers = [
"deeplabv3",
"deeplabv3plus",
"fpn",
"linknet",
"manet",
"pan",
"psp",
"segformer",
"unet",
"unetplusplus",
"upernet",
"logits_match",
"compile",
"torch_export",
"torch_script",
]

[tool.coverage.run]
Expand Down
2 changes: 2 additions & 0 deletions requirements/test.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
gitpython==3.1.44
packaging==24.2
pytest==8.3.4
pytest-xdist==3.6.1
pytest-cov==6.0.0
ruff==0.9.1
setuptools==75.8.0
10 changes: 8 additions & 2 deletions segmentation_models_pytorch/base/hub_mixin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
import json
from pathlib import Path
from typing import Optional, Union
Expand Down Expand Up @@ -114,12 +115,15 @@ def save_pretrained(
return result

@property
@torch.jit.unused
def config(self) -> dict:
return self._hub_mixin_config


@wraps(PyTorchModelHubMixin.from_pretrained)
def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs):
def from_pretrained(
pretrained_model_name_or_path: str, *args, strict: bool = True, **kwargs
):
config_path = Path(pretrained_model_name_or_path) / "config.json"
if not config_path.exists():
config_path = hf_hub_download(
Expand All @@ -135,7 +139,9 @@ def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs):
import segmentation_models_pytorch as smp

model_class = getattr(smp, model_class_name)
return model_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
return model_class.from_pretrained(
pretrained_model_name_or_path, *args, **kwargs, strict=strict
)


def supports_config_loading(func):
Expand Down
43 changes: 39 additions & 4 deletions segmentation_models_pytorch/base/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@

from . import initialization as init
from .hub_mixin import SMPHubMixin
from .utils import is_torch_compiling

T = TypeVar("T", bound="SegmentationModel")


class SegmentationModel(torch.nn.Module, SMPHubMixin):
"""Base class for all segmentation models."""

# if model supports shape not divisible by 2 ^ n
# set to False
_is_torch_scriptable = True
_is_torch_exportable = True
_is_torch_compilable = True

# if model supports shape not divisible by 2 ^ n set to False
requires_divisible_input_shape = True

# Fix type-hint for models, to avoid HubMixin signature
Expand All @@ -29,6 +33,9 @@ def check_input_shape(self, x):
"""Check if the input shape is divisible by the output stride.
If not, raise a RuntimeError.
"""
if not self.requires_divisible_input_shape:
return

h, w = x.shape[-2:]
output_stride = self.encoder.output_stride
if h % output_stride != 0 or w % output_stride != 0:
Expand All @@ -50,11 +57,13 @@ def check_input_shape(self, x):
def forward(self, x):
"""Sequentially pass `x` trough model`s encoder, decoder and heads"""

if not torch.jit.is_tracing() and self.requires_divisible_input_shape:
if not (
torch.jit.is_scripting() or torch.jit.is_tracing() or is_torch_compiling()
):
self.check_input_shape(x)

features = self.encoder(x)
decoder_output = self.decoder(*features)
decoder_output = self.decoder(features)

masks = self.segmentation_head(decoder_output)

Expand All @@ -81,3 +90,29 @@ def predict(self, x):
x = self.forward(x)

return x

def load_state_dict(self, state_dict, **kwargs):
# for compatibility of weights for
# timm- ported encoders with TimmUniversalEncoder
from segmentation_models_pytorch.encoders import TimmUniversalEncoder

if not isinstance(self.encoder, TimmUniversalEncoder):
return super().load_state_dict(state_dict, **kwargs)

patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"]

is_deprecated_encoder = any(
self.encoder.name.startswith(pattern) for pattern in patterns
)

if is_deprecated_encoder:
keys = list(state_dict.keys())
for key in keys:
new_key = key
if key.startswith("encoder.") and not key.startswith("encoder.model."):
new_key = "encoder.model." + key.removeprefix("encoder.")
if "gernet" in self.encoder.name:
new_key = new_key.replace(".stages.", ".stages_")
state_dict[new_key] = state_dict.pop(key)

return super().load_state_dict(state_dict, **kwargs)
14 changes: 14 additions & 0 deletions segmentation_models_pytorch/base/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import torch


@torch.jit.unused
def is_torch_compiling():
try:
return torch.compiler.is_compiling()
except Exception:
try:
import torch._dynamo as dynamo # noqa: F401

return dynamo.is_compiling()
except Exception:
return False
40 changes: 22 additions & 18 deletions segmentation_models_pytorch/decoders/deeplabv3/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"""

from collections.abc import Iterable, Sequence
from typing import Literal
from typing import Literal, List

import torch
from torch import nn
Expand All @@ -40,7 +40,7 @@
__all__ = ["DeepLabV3Decoder", "DeepLabV3PlusDecoder"]


class DeepLabV3Decoder(nn.Sequential):
class DeepLabV3Decoder(nn.Module):
def __init__(
self,
in_channels: int,
Expand All @@ -49,21 +49,25 @@ def __init__(
aspp_separable: bool,
aspp_dropout: float,
):
super().__init__(
ASPP(
in_channels,
out_channels,
atrous_rates,
separable=aspp_separable,
dropout=aspp_dropout,
),
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
super().__init__()
self.aspp = ASPP(
in_channels,
out_channels,
atrous_rates,
separable=aspp_separable,
dropout=aspp_dropout,
)
self.conv = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU()

def forward(self, *features):
return super().forward(features[-1])
def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
x = features[-1]
x = self.aspp(x)
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x


class DeepLabV3PlusDecoder(nn.Module):
Expand Down Expand Up @@ -124,7 +128,7 @@ def __init__(
nn.ReLU(),
)

def forward(self, *features):
def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
aspp_features = self.aspp(features[-1])
aspp_features = self.up(aspp_features)
high_res_features = self.block1(features[2])
Expand Down Expand Up @@ -174,7 +178,7 @@ def __init__(self, in_channels: int, out_channels: int):
nn.ReLU(),
)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
size = x.shape[-2:]
for mod in self:
x = mod(x)
Expand Down Expand Up @@ -216,7 +220,7 @@ def __init__(
nn.Dropout(dropout),
)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
res = []
for conv in self.convs:
res.append(conv(x))
Expand Down
Loading

0 comments on commit 456871a

Please sign in to comment.