
Fix torch compile, script, export #1031

Merged · 58 commits · Jan 15, 2025
Changes from 38 commits
Commits (58)
643c0b6
Move tests
qubvel Jan 12, 2025
7a937ab
Add compile test for encoders (to be optimized)
qubvel Jan 12, 2025
9a7c768
densnet
qubvel Jan 12, 2025
34b8533
dpn
qubvel Jan 12, 2025
a3618fa
efficientnet
qubvel Jan 12, 2025
e3f6c70
inceptionresnetv2
qubvel Jan 12, 2025
20b28be
inceptionv4
qubvel Jan 12, 2025
d996165
mix-transformer
qubvel Jan 12, 2025
9e38154
mobilenet
qubvel Jan 12, 2025
c6e5d53
mobileone
qubvel Jan 12, 2025
5a76722
resnet
qubvel Jan 12, 2025
36d056b
senet
qubvel Jan 12, 2025
aefcfd4
vgg
qubvel Jan 12, 2025
e9628bf
xception
qubvel Jan 12, 2025
70262e5
Deprecate `timm-` encoders, remap to `tu-` most of them
qubvel Jan 12, 2025
0b0b1c4
Add tiny encoders and compile mark
qubvel Jan 12, 2025
4c11682
Add conftest
qubvel Jan 12, 2025
70168b4
Fix features
qubvel Jan 12, 2025
8aed7ef
Merge branch 'main' into torch-compile-export
qubvel Jan 13, 2025
50c40d1
Add triggering compile tests on diff
qubvel Jan 13, 2025
0764d5e
Remove marks
qubvel Jan 13, 2025
7cab4be
Add test_compile stage to CI
qubvel Jan 13, 2025
2622e0e
Update requirements
qubvel Jan 13, 2025
e12ee8d
Update makefile
qubvel Jan 13, 2025
da0cd19
Update get_stages
qubvel Jan 13, 2025
7752969
Fix weight loading for deprecate encoders
qubvel Jan 13, 2025
409b820
Fix weight loading for mobilenetv3
qubvel Jan 13, 2025
ae3cb8a
Format
qubvel Jan 13, 2025
ff278c9
Add compile test for models
qubvel Jan 13, 2025
a806147
Add torch.export test
qubvel Jan 13, 2025
aa5b088
Disable export tests for dpn and inceptionv4
qubvel Jan 13, 2025
df2f484
Disable export for timm-eff-net
qubvel Jan 13, 2025
7157501
Huge fix for torch scripting (except Unet++ and UperNet)
qubvel Jan 14, 2025
257da0b
Fix scripting
qubvel Jan 14, 2025
d4d4cf6
Add test for torch script
qubvel Jan 14, 2025
3cb8198
Add torch_script test to CI
qubvel Jan 14, 2025
4f65d8f
Fix
qubvel Jan 14, 2025
70776ea
Fix timm-effnet encoders
qubvel Jan 14, 2025
31bee79
Make from_pretrained strict by default
qubvel Jan 14, 2025
556b3aa
Fix DeepLabV3 BC
qubvel Jan 14, 2025
f70d861
Fix scripting for encoders
qubvel Jan 14, 2025
ead24b4
Refactor test do not skip
qubvel Jan 14, 2025
d44509a
Fix encoders (mobilenet, inceptionv4)
qubvel Jan 14, 2025
b2c13f1
Update encoders table
qubvel Jan 14, 2025
73809e3
Fix export test
qubvel Jan 14, 2025
bc1319e
Fix docs
qubvel Jan 14, 2025
d25dd47
Update warning
qubvel Jan 14, 2025
4f3b37e
Move pretrained settings
qubvel Jan 14, 2025
06199b0
Add BC for timm- encoders
qubvel Jan 14, 2025
51e0a67
Fixing table
qubvel Jan 14, 2025
524bcae
Update compile test
qubvel Jan 14, 2025
a2b97d8
Change compile backend to eager
qubvel Jan 15, 2025
17a4b70
Update docs
qubvel Jan 15, 2025
20564f2
Fixup
qubvel Jan 15, 2025
5bbb1db
Fix batchnorm typo
qubvel Jan 15, 2025
d121fec
Add depth validation
qubvel Jan 15, 2025
7bb9d37
Update segmentation_models_pytorch/encoders/__init__.py
qubvel Jan 15, 2025
da24de9
Style
qubvel Jan 15, 2025
51 changes: 48 additions & 3 deletions .github/workflows/tests.yml
@@ -51,7 +51,7 @@ jobs:
run: uv pip list

- name: Test with PyTest
run: uv run pytest -v -rsx -n 2 --cov=segmentation_models_pytorch --cov-report=xml --cov-config=pyproject.toml -k "not logits_match"
run: uv run pytest -v -rsx -n 2 --cov=segmentation_models_pytorch --cov-report=xml --cov-config=pyproject.toml --non-marked-only

- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
@@ -73,7 +73,52 @@
- name: Show installed packages
run: uv pip list
- name: Test with PyTest
run: RUN_SLOW=1 uv run pytest -v -rsx -n 2 -k "logits_match"
run: RUN_SLOW=1 uv run pytest -v -rsx -n 2 -m "logits_match"

test_torch_compile:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: astral-sh/setup-uv@v5
with:
python-version: "3.10"
- name: Install dependencies
run: uv pip install -r requirements/required.txt -r requirements/test.txt
- name: Show installed packages
run: uv pip list
- name: Test with PyTest
run: uv run pytest -v -rsx -n 2 -m "compile"

test_torch_export:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: astral-sh/setup-uv@v5
with:
python-version: "3.10"
- name: Install dependencies
run: uv pip install -r requirements/required.txt -r requirements/test.txt
- name: Show installed packages
run: uv pip list
- name: Test with PyTest
run: uv run pytest -v -rsx -n 2 -m "torch_export"

test_torch_script:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: astral-sh/setup-uv@v5
with:
python-version: "3.10"
- name: Install dependencies
run: uv pip install -r requirements/required.txt -r requirements/test.txt
- name: Show installed packages
run: uv pip list
- name: Test with PyTest
run: uv run pytest -v -rsx -n 2 -m "torch_script"

minimum:
runs-on: ubuntu-latest
@@ -88,4 +133,4 @@ jobs:
- name: Show installed packages
run: uv pip list
- name: Test with pytest
run: uv run pytest -v -rsx -n 2 -k "not logits_match"
run: uv run pytest -v -rsx -n 2 --non-marked-only
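The new `test_torch_compile` job boils down to a smoke test that models still capture under `torch.compile`. A minimal sketch of that kind of test, using a tiny stand-in module rather than a real SMP model; `backend="eager"` mirrors commit a2b97d8, which switched the compile backend to eager so CI only exercises graph capture, not codegen:

```python
import torch
import torch.nn as nn


class TinySegHead(nn.Module):
    """Illustrative stand-in for a segmentation model."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


model = TinySegHead().eval()
# fullgraph=True makes any graph break fail loudly, which is what a
# compile smoke test wants to catch.
compiled = torch.compile(model, backend="eager", fullgraph=True)

x = torch.rand(1, 3, 32, 32)
with torch.inference_mode():
    out = compiled(x)

assert out.shape == (1, 1, 32, 32)
```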
2 changes: 1 addition & 1 deletion Makefile
@@ -7,7 +7,7 @@ install_dev: .venv
.venv/bin/pip install -e ".[test]"

test: .venv
.venv/bin/pytest -v -rsx -n 2 tests/ -k "not logits_match"
.venv/bin/pytest -v -rsx -n 2 tests/ --non-marked-only

test_all: .venv
RUN_SLOW=1 .venv/bin/pytest -v -rsx -n 2 tests/
16 changes: 5 additions & 11 deletions pyproject.toml
@@ -39,11 +39,13 @@ docs = [
'sphinx-book-theme',
]
test = [
'gitpython',
'packaging',
'pytest',
'pytest-cov',
'pytest-xdist',
'ruff>=0.9',
'setuptools',
]

[project.urls]
@@ -61,18 +63,10 @@ include = ['segmentation_models_pytorch*']

[tool.pytest.ini_options]
markers = [
"deeplabv3",
"deeplabv3plus",
"fpn",
"linknet",
"manet",
"pan",
"psp",
"segformer",
"unet",
"unetplusplus",
"upernet",
"logits_match",
"compile",
"torch_export",
"torch_script",
]

[tool.coverage.run]
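The markers declared above are consumed per job: tests tagged with a marker are selected with `pytest -m`, while the default jobs exclude all marked tests via `--non-marked-only`. A sketch of what a marked test looks like (the name and body are illustrative, not taken from this PR's test files):

```python
import pytest


@pytest.mark.compile
def test_forward_under_compile():
    # A real test in this PR would build a model, wrap it with
    # torch.compile, and run a forward pass; this body is a placeholder.
    assert True
```

With this in place, `uv run pytest -m "compile"` runs only the compile-marked tests, matching the `test_torch_compile` CI job above.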
2 changes: 2 additions & 0 deletions requirements/test.txt
@@ -1,5 +1,7 @@
gitpython==3.1.44
packaging==24.2
pytest==8.3.4
pytest-xdist==3.6.1
pytest-cov==6.0.0
ruff==0.9.1
setuptools==75.8.0
2 changes: 2 additions & 0 deletions segmentation_models_pytorch/base/hub_mixin.py
@@ -1,3 +1,4 @@
import torch
import json
from pathlib import Path
from typing import Optional, Union
@@ -114,6 +115,7 @@ def save_pretrained(
return result

@property
@torch.jit.unused
def config(self) -> dict:
return self._hub_mixin_config

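The `@torch.jit.unused` added to the `config` property tells TorchScript to skip compiling it, so scripting a model no longer trips over the Hub mixin's Python-only config handling. A minimal sketch of the pattern with a made-up module (the `_cfg` attribute here is illustrative, not the real mixin field):

```python
import torch


class WithConfig(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self._cfg = {"arch": "unet"}  # illustrative stand-in config

    @property
    @torch.jit.unused
    def config(self) -> dict:
        # Excluded from scripting; calling it from scripted code would
        # raise, but eager Python access still works.
        return self._cfg

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


m = WithConfig()
scripted = torch.jit.script(m)  # succeeds because `config` is skipped

assert scripted(torch.zeros(1)).item() == 1.0
assert m.config == {"arch": "unet"}
```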
13 changes: 9 additions & 4 deletions segmentation_models_pytorch/base/model.py
@@ -3,15 +3,15 @@

from . import initialization as init
from .hub_mixin import SMPHubMixin
from .utils import is_torch_compiling

T = TypeVar("T", bound="SegmentationModel")


class SegmentationModel(torch.nn.Module, SMPHubMixin):
"""Base class for all segmentation models."""

# if model supports shape not divisible by 2 ^ n
# set to False
# if model supports shape not divisible by 2 ^ n set to False
requires_divisible_input_shape = True

# Fix type-hint for models, to avoid HubMixin signature
@@ -29,6 +29,9 @@ def check_input_shape(self, x):
"""Check if the input shape is divisible by the output stride.
If not, raise a RuntimeError.
"""
if not self.requires_divisible_input_shape:
return

h, w = x.shape[-2:]
output_stride = self.encoder.output_stride
if h % output_stride != 0 or w % output_stride != 0:
@@ -50,11 +53,13 @@
def forward(self, x):
"""Sequentially pass `x` trough model`s encoder, decoder and heads"""

if not torch.jit.is_tracing() and self.requires_divisible_input_shape:
if not (
torch.jit.is_scripting() or torch.jit.is_tracing() or is_torch_compiling()
):
self.check_input_shape(x)

features = self.encoder(x)
decoder_output = self.decoder(*features)
decoder_output = self.decoder(features)

masks = self.segmentation_head(decoder_output)

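The divisibility rule that `check_input_shape` enforces, sketched numerically: with the usual encoder `output_stride` of 32, H and W must be multiples of 32, and the error message suggests the nearest valid sizes, which can be computed like this (the helper name is illustrative):

```python
import math


def nearest_valid(size: int, output_stride: int = 32) -> tuple:
    """Nearest input sizes below and above `size` that the encoder accepts."""
    lower = (size // output_stride) * output_stride
    upper = math.ceil(size / output_stride) * output_stride
    return lower, upper


assert nearest_valid(250) == (224, 256)  # 250 is rejected; 224 or 256 works
assert nearest_valid(256) == (256, 256)  # already divisible
```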
14 changes: 14 additions & 0 deletions segmentation_models_pytorch/base/utils.py
@@ -0,0 +1,14 @@
import torch


@torch.jit.unused
def is_torch_compiling():
try:
return torch.compiler.is_compiling()
except Exception:
try:
import torch._dynamo as dynamo # noqa: F401

return dynamo.is_compiling()
except Exception:
return False

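What the new helper reports in practice: in plain eager execution it returns `False`, so the shape check in `forward` still runs; while `torch.compile` is capturing the graph, `torch.compiler.is_compiling()` returns `True` and the check is skipped. The `torch._dynamo` fallback in the PR covers older torch versions where `torch.compiler` is missing. A trimmed sketch without the fallback:

```python
import torch


def is_torch_compiling() -> bool:
    # Same idea as the PR's helper, minus the dynamo fallback for old torch.
    try:
        return torch.compiler.is_compiling()
    except Exception:
        return False


# In eager mode none of the guards fire, so check_input_shape still runs.
skip_check = (
    torch.jit.is_scripting() or torch.jit.is_tracing() or is_torch_compiling()
)
assert skip_check is False
```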
57 changes: 39 additions & 18 deletions segmentation_models_pytorch/decoders/deeplabv3/decoder.py
@@ -31,7 +31,7 @@
"""

from collections.abc import Iterable, Sequence
from typing import Literal
from typing import Literal, List

import torch
from torch import nn
@@ -49,21 +49,42 @@
aspp_separable: bool,
aspp_dropout: float,
):
super().__init__(
ASPP(
in_channels,
out_channels,
atrous_rates,
separable=aspp_separable,
dropout=aspp_dropout,
),
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
super().__init__()
self.aspp = ASPP(
in_channels,
out_channels,
atrous_rates,
separable=aspp_separable,
dropout=aspp_dropout,
)

def forward(self, *features):
return super().forward(features[-1])
self.conv = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU()

def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
x = features[-1]
x = self.aspp(x)
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x

def load_state_dict(self, state_dict, *args, **kwargs):
# For backward compatibility, previously this module was Sequential
# and was not scriptable.
keys = list(state_dict.keys())
for key in keys:
new_key = key
if key.startswith("0."):
new_key = "aspp." + key[2:]
elif key.startswith("1."):
new_key = "conv." + key[2:]
elif key.startswith("2."):
new_key = "bn." + key[2:]
elif key.startswith("3."):
new_key = "relu." + key[2:]
state_dict[new_key] = state_dict.pop(key)
super().load_state_dict(state_dict, *args, **kwargs)

class DeepLabV3PlusDecoder(nn.Module):
@@ -124,7 +145,7 @@
nn.ReLU(),
)

def forward(self, *features):
def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
aspp_features = self.aspp(features[-1])
aspp_features = self.up(aspp_features)
high_res_features = self.block1(features[2])
@@ -174,7 +195,7 @@
nn.ReLU(),
)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
size = x.shape[-2:]
for mod in self:
x = mod(x)
@@ -216,7 +237,7 @@
nn.Dropout(dropout),
)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
res = []
for conv in self.convs:
res.append(conv(x))
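Two of this diff's scripting fixes in one miniature: forward signatures change from variadic `*features` to a typed `List[torch.Tensor]`, which TorchScript can compile, and `load_state_dict` remaps old `Sequential`-style keys (`"0."`, `"1."`, …) to the new named submodules so existing checkpoints keep loading. The module below is an illustrative stand-in, not the real `DeepLabV3Decoder`:

```python
from typing import List

import torch
from torch import nn


class MiniDecoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(4, 2, kernel_size=1, bias=False)  # was "0."
        self.bn = nn.BatchNorm2d(2)                             # was "1."

    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        # Deepest feature map only, as in DeepLabV3Decoder's forward.
        return self.bn(self.conv(features[-1]))

    def load_state_dict(self, state_dict, *args, **kwargs):
        # Backward compatibility: rename Sequential-era keys in place.
        for key in list(state_dict.keys()):
            if key.startswith("0."):
                state_dict["conv." + key[2:]] = state_dict.pop(key)
            elif key.startswith("1."):
                state_dict["bn." + key[2:]] = state_dict.pop(key)
        super().load_state_dict(state_dict, *args, **kwargs)


decoder = MiniDecoder()

# Simulate a checkpoint saved by the old Sequential layout and load it.
old_style = {"0." + k: v for k, v in decoder.conv.state_dict().items()}
old_style.update({"1." + k: v for k, v in decoder.bn.state_dict().items()})
decoder.load_state_dict(old_style)

# Scripting succeeds because the forward signature is explicitly typed.
scripted = torch.jit.script(decoder)
out = scripted([torch.rand(1, 4, 8, 8)])
assert out.shape == (1, 2, 8, 8)
```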