Skip to content

Commit

Permalink
🐛 Fix model_to device specification
Browse files Browse the repository at this point in the history
  • Loading branch information
shaneahmed committed Nov 17, 2024
1 parent 7b4f496 commit dba269c
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions tests/models/test_arch_vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from tiatoolbox.models.architecture.vanilla import CNNModel, TimmModel
from tiatoolbox.models.models_abc import model_to
from tiatoolbox.utils.misc import select_device

ON_GPU = False
RNG = np.random.default_rng() # Numpy Random Generator
Expand Down Expand Up @@ -46,7 +45,7 @@ def test_functional() -> None:
for backbone in backbones:
model = CNNModel(backbone, num_classes=1)
model_ = model_to(device=device, model=model)
model.infer_batch(model_, samples, device=select_device(on_gpu=ON_GPU))
model.infer_batch(model_, samples, device=device)
except ValueError as exc:
msg = f"Model {backbone} failed."
raise AssertionError(msg) from exc
Expand All @@ -72,8 +71,8 @@ def test_timm_functional() -> None:
try:
for backbone in backbones:
model = TimmModel(backbone=backbone, num_classes=1, pretrained=False)
model_ = model_to(on_gpu=ON_GPU, model=model)
model.infer_batch(model_, samples, on_gpu=ON_GPU)
model_ = model_to(device=device, model=model)
model.infer_batch(model_, samples, device=device)
except ValueError as exc:
msg = f"Model {backbone} failed."
raise AssertionError(msg) from exc
Expand Down

0 comments on commit dba269c

Please sign in to comment.