diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 83b02392..3ebe6143 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -90,6 +90,21 @@ jobs: - name: Test with PyTest run: uv run pytest -v -rsx -n 2 -m "compile" + test_torch_export: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: astral-sh/setup-uv@v5 + with: + python-version: "3.10" + - name: Install dependencies + run: uv pip install -r requirements/required.txt -r requirements/test.txt + - name: Show installed packages + run: uv pip list + - name: Test with PyTest + run: uv run pytest -v -rsx -n 2 -m "torch_export" + minimum: runs-on: ubuntu-latest steps: diff --git a/pyproject.toml b/pyproject.toml index 4e34be59..8d9b2078 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ include = ['segmentation_models_pytorch*'] markers = [ "logits_match", "compile", + "torch_export", ] [tool.coverage.run] diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py index a25ed30a..e04c2d6e 100644 --- a/segmentation_models_pytorch/base/model.py +++ b/segmentation_models_pytorch/base/model.py @@ -3,6 +3,7 @@ from . import initialization as init from .hub_mixin import SMPHubMixin +from .utils import is_torch_compiling T = TypeVar("T", bound="SegmentationModel") @@ -50,7 +51,11 @@ def check_input_shape(self, x): def forward(self, x): """Sequentially pass `x` trough model`s encoder, decoder and heads""" - if not torch.jit.is_tracing() and self.requires_divisible_input_shape: + if ( + not torch.jit.is_tracing() + and not is_torch_compiling() + and self.requires_divisible_input_shape + ): self.check_input_shape(x) features = self.encoder(x) diff --git a/segmentation_models_pytorch/base/utils.py b/segmentation_models_pytorch/base/utils.py new file mode 100644 index 00000000..3fcba739 --- /dev/null +++ b/segmentation_models_pytorch/base/utils.py @@ -0,0 +1,13 @@ +import torch + + +def is_torch_compiling(): + try: + return torch.compiler.is_compiling() + except Exception: + try: + import torch._dynamo as dynamo # noqa: F401 + + return dynamo.is_compiling() + except Exception: + return False diff --git a/tests/encoders/base.py b/tests/encoders/base.py index 87ad2cfb..c1858cdb 100644 --- a/tests/encoders/base.py +++ b/tests/encoders/base.py @@ -231,3 +231,31 @@ def test_compile(self): with torch.inference_mode(): compiled_encoder(sample) + + @pytest.mark.torch_export + def test_torch_export(self): + if not check_run_test_on_diff_or_main(self.files_for_diff): + self.skipTest("No diff and not on `main`.") + + sample = self._get_sample( + batch_size=self.default_batch_size, + num_channels=self.default_num_channels, + height=self.default_height, + width=self.default_width, + ).to(default_device) + + encoder = self.get_tiny_encoder() + encoder = encoder.eval().to(default_device) + + exported_encoder = torch.export.export( + encoder, + args=(sample,), + strict=True, + ) + + with torch.inference_mode(): + eager_output = encoder(sample) + exported_output = exported_encoder.module().forward(sample) + + for eager_feature, exported_feature in zip(eager_output, exported_output): + torch.testing.assert_close(eager_feature, exported_feature) diff --git a/tests/models/base.py b/tests/models/base.py index ba246436..d6e19fd0 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -254,3 +254,31 @@ def test_compile(self): with torch.inference_mode(): compiled_model(sample) + + @pytest.mark.torch_export + def test_torch_export(self): + if not check_run_test_on_diff_or_main(self.files_for_diff): + self.skipTest("No diff and not on `main`.") + + sample = self._get_sample( + batch_size=self.default_batch_size, + num_channels=self.default_num_channels, + height=self.default_height, + width=self.default_width, + ).to(default_device) + + model = self.get_default_model() + model.eval() + + exported_model = torch.export.export( + model, + args=(sample,), + strict=True, + ) + + with torch.inference_mode(): + eager_output = model(sample) + exported_output = exported_model.module().forward(sample) + + self.assertEqual(eager_output.shape, exported_output.shape) + torch.testing.assert_close(eager_output, exported_output)