From dba269cc503035e6bbf1e935046419a4a2a44a07 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Sun, 17 Nov 2024 14:51:51 +0000 Subject: [PATCH] :bug: Fix `model_to` device specification --- tests/models/test_arch_vanilla.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/models/test_arch_vanilla.py b/tests/models/test_arch_vanilla.py index 894bd2ef3..a87424dfd 100644 --- a/tests/models/test_arch_vanilla.py +++ b/tests/models/test_arch_vanilla.py @@ -6,7 +6,6 @@ from tiatoolbox.models.architecture.vanilla import CNNModel, TimmModel from tiatoolbox.models.models_abc import model_to -from tiatoolbox.utils.misc import select_device ON_GPU = False RNG = np.random.default_rng() # Numpy Random Generator @@ -46,7 +45,7 @@ def test_functional() -> None: for backbone in backbones: model = CNNModel(backbone, num_classes=1) model_ = model_to(device=device, model=model) - model.infer_batch(model_, samples, device=select_device(on_gpu=ON_GPU)) + model.infer_batch(model_, samples, device=device) except ValueError as exc: msg = f"Model {backbone} failed." raise AssertionError(msg) from exc @@ -72,8 +71,8 @@ def test_timm_functional() -> None: try: for backbone in backbones: model = TimmModel(backbone=backbone, num_classes=1, pretrained=False) - model_ = model_to(on_gpu=ON_GPU, model=model) - model.infer_batch(model_, samples, on_gpu=ON_GPU) + model_ = model_to(device=device, model=model) + model.infer_batch(model_, samples, device=device) except ValueError as exc: msg = f"Model {backbone} failed." raise AssertionError(msg) from exc