Skip to content

Commit

Permalink
Add torch.export test
Browse files Browse the repository at this point in the history
  • Loading branch information
qubvel committed Jan 13, 2025
1 parent ff278c9 commit a806147
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 1 deletion.
15 changes: 15 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,21 @@ jobs:
- name: Test with PyTest
run: uv run pytest -v -rsx -n 2 -m "compile"

test_torch_export:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: astral-sh/setup-uv@v5
with:
python-version: "3.10"
- name: Install dependencies
run: uv pip install -r requirements/required.txt -r requirements/test.txt
- name: Show installed packages
run: uv pip list
- name: Test with PyTest
run: uv run pytest -v -rsx -n 2 -m "torch_export"

minimum:
runs-on: ubuntu-latest
steps:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ include = ['segmentation_models_pytorch*']
markers = [
"logits_match",
"compile",
"torch_export",
]

[tool.coverage.run]
Expand Down
7 changes: 6 additions & 1 deletion segmentation_models_pytorch/base/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from . import initialization as init
from .hub_mixin import SMPHubMixin
from .utils import is_torch_compiling

T = TypeVar("T", bound="SegmentationModel")

Expand Down Expand Up @@ -50,7 +51,11 @@ def check_input_shape(self, x):
def forward(self, x):
"""Sequentially pass `x` trough model`s encoder, decoder and heads"""

if not torch.jit.is_tracing() and self.requires_divisible_input_shape:
if (
not torch.jit.is_tracing()
and not is_torch_compiling()
and self.requires_divisible_input_shape
):
self.check_input_shape(x)

features = self.encoder(x)
Expand Down
13 changes: 13 additions & 0 deletions segmentation_models_pytorch/base/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import torch


def is_torch_compiling():
try:
return torch.compiler.is_compiling()
except Exception:
try:
import torch._dynamo as dynamo # noqa: F401

Check warning on line 9 in segmentation_models_pytorch/base/utils.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/base/utils.py#L7-L9

Added lines #L7 - L9 were not covered by tests

return dynamo.is_compiling()
except Exception:
return False

Check warning on line 13 in segmentation_models_pytorch/base/utils.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/base/utils.py#L11-L13

Added lines #L11 - L13 were not covered by tests
28 changes: 28 additions & 0 deletions tests/encoders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,31 @@ def test_compile(self):

with torch.inference_mode():
compiled_encoder(sample)

@pytest.mark.torch_export
def test_torch_export(self):
if not check_run_test_on_diff_or_main(self.files_for_diff):
self.skipTest("No diff and not on `main`.")

sample = self._get_sample(
batch_size=self.default_batch_size,
num_channels=self.default_num_channels,
height=self.default_height,
width=self.default_width,
).to(default_device)

encoder = self.get_tiny_encoder()
encoder = encoder.eval().to(default_device)

exported_encoder = torch.export.export(
encoder,
args=(sample,),
strict=True,
)

with torch.inference_mode():
eager_output = encoder(sample)
exported_output = exported_encoder.module().forward(sample)

for eager_feature, exported_feature in zip(eager_output, exported_output):
torch.testing.assert_close(eager_feature, exported_feature)
28 changes: 28 additions & 0 deletions tests/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,31 @@ def test_compile(self):

with torch.inference_mode():
compiled_model(sample)

@pytest.mark.torch_export
def test_torch_export(self):
if not check_run_test_on_diff_or_main(self.files_for_diff):
self.skipTest("No diff and not on `main`.")

sample = self._get_sample(
batch_size=self.default_batch_size,
num_channels=self.default_num_channels,
height=self.default_height,
width=self.default_width,
).to(default_device)

model = self.get_default_model()
model.eval()

exported_model = torch.export.export(
model,
args=(sample,),
strict=True,
)

with torch.inference_mode():
eager_output = model(sample)
exported_output = exported_model.module().forward(sample)

self.assertEqual(eager_output.shape, exported_output.shape)
torch.testing.assert_close(eager_output, exported_output)

0 comments on commit a806147

Please sign in to comment.