Skip to content

Commit

Permalink
Add save-load test, add aux head test
Browse files Browse the repository at this point in the history
  • Loading branch information
qubvel committed Dec 23, 2024
1 parent 2b113f0 commit b7ce422
Showing 1 changed file with 57 additions and 1 deletion.
58 changes: 57 additions & 1 deletion tests/models/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
import inspect
import tempfile
import unittest
from functools import lru_cache

import torch
Expand Down Expand Up @@ -89,3 +90,58 @@ def test_base_params_are_set(self, in_channels=1, depth=3, classes=7):
output = model(sample)

self.assertEqual(output.shape[1], classes)

def test_aux_params(self):
model = smp.create_model(
arch=self.model_type,
aux_params={
"pooling": "avg",
"classes": 10,
"dropout": 0.5,
"activation": "sigmoid",
},
)

self.assertIsNotNone(model.classification_head)
self.assertIsInstance(model.classification_head[0], torch.nn.AdaptiveAvgPool2d)
self.assertIsInstance(model.classification_head[1], torch.nn.Flatten)
self.assertIsInstance(model.classification_head[2], torch.nn.Dropout)
self.assertEqual(model.classification_head[2].p, 0.5)
self.assertIsInstance(model.classification_head[3], torch.nn.Linear)
self.assertIsInstance(model.classification_head[4].activation, torch.nn.Sigmoid)

sample = self._get_sample(
batch_size=self.default_batch_size,
num_channels=self.default_num_channels,
height=self.default_height,
width=self.default_width,
)

with torch.no_grad():
_, cls_probs = model(sample)

self.assertEqual(cls_probs.shape[1], 10)

def test_save_load(self):
# instantiate model
model = smp.create_model(arch=self.model_type)

# save model
with tempfile.TemporaryDirectory() as tmpdir:
model.save_pretrained(tmpdir)
restored_model = model.from_pretrained(tmpdir)

# check inference is correct
sample = self._get_sample(
batch_size=self.default_batch_size,
num_channels=self.default_num_channels,
height=self.default_height,
width=self.default_width,
)

with torch.no_grad():
output = model(sample)
restored_output = restored_model(sample)

self.assertEqual(output.shape, restored_output.shape)
self.assertEqual(output.shape[1], 1)

0 comments on commit b7ce422

Please sign in to comment.