From 238dd11631dc466aaaf27728ae0a5f6da456eaa2 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 14:13:18 +0000 Subject: [PATCH] [run-slow] Fixing minimum --- tests/models/base.py | 3 ++- tests/models/test_segformer.py | 3 ++- tests/utils.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/models/base.py b/tests/models/base.py index 15ca6751..739f40eb 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -137,6 +137,7 @@ def test_classification_head(self): self.assertEqual(cls_probs.shape[1], 10) + @requires_torch_greater_or_equal("2.0.0") def test_save_load_with_hub_mixin(self): # instantiate model model = smp.create_model( @@ -172,7 +173,7 @@ def test_save_load_with_hub_mixin(self): self.assertIn("my_awesome_metric", readme) @slow_test - @requires_torch_greater_or_equal("2.0.1") + @requires_torch_greater_or_equal("2.0.0") def test_preserve_forward_output(self): from huggingface_hub import hf_hub_download diff --git a/tests/models/test_segformer.py b/tests/models/test_segformer.py index 3d073763..f59a0fcc 100644 --- a/tests/models/test_segformer.py +++ b/tests/models/test_segformer.py @@ -3,7 +3,7 @@ import segmentation_models_pytorch as smp from tests.models import base -from tests.utils import slow_test, default_device +from tests.utils import slow_test, default_device, requires_torch_greater_or_equal @pytest.mark.segformer @@ -11,6 +11,7 @@ class TestSegformerModel(base.BaseModelTester): test_model_type = "segformer" @slow_test + @requires_torch_greater_or_equal("2.0.0") def test_load_pretrained(self): hub_checkpoint = "smp-hub/segformer-b0-512x512-ade-160k" diff --git a/tests/utils.py b/tests/utils.py index e8bce88e..e87874f3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -42,6 +42,6 @@ def requires_torch_greater_or_equal(version: str): torch_version = Version(torch.__version__) provided_version = Version(version) return unittest.skipUnless( - torch_version >= provided_version, + torch_version < provided_version, f"torch version {torch_version} is less than {provided_version}", )