From 2d3048405fc0dc06e64b910d58d34fb67f6f0d5f Mon Sep 17 00:00:00 2001 From: "F. Dangel" Date: Mon, 3 Aug 2020 18:05:38 +0200 Subject: [PATCH 1/2] Add failing test: evaluating loss on cifar100_wrn404 with regularization --- tests/pytorch/testproblems/test_cifar100_wrn404.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/pytorch/testproblems/test_cifar100_wrn404.py b/tests/pytorch/testproblems/test_cifar100_wrn404.py index 1dc2ac32..9d8bb73c 100644 --- a/tests/pytorch/testproblems/test_cifar100_wrn404.py +++ b/tests/pytorch/testproblems/test_cifar100_wrn404.py @@ -24,12 +24,12 @@ def setUp(self): """Sets up CIFAR-100 dataset for the tests.""" self.batch_size = 100 self.cifar100_wrn404 = testproblems.cifar100_wrn404(self.batch_size) - - def test_num_param(self): - """Tests the number of parameters.""" torch.manual_seed(42) self.cifar100_wrn404.set_up() + self.cifar100_wrn404.train_init_op() + def test_num_param(self): + """Tests the number of parameters.""" num_param = [] for parameter in self.cifar100_wrn404.net.parameters(): num_param.append(parameter.numel()) @@ -155,6 +155,11 @@ def test_num_param(self): self.assertEqual(num_param, expected_num_param) + def test_forward_pass_with_regularization(self): + loss, _ = self.cifar100_wrn404.get_batch_loss_and_accuracy( + add_regularization_if_available=True + ) + if __name__ == "__main__": unittest.main() From 0be39fef55800b73fc1d9e0d99e20b40cd9ecbc2 Mon Sep 17 00:00:00 2001 From: "F. Dangel" Date: Mon, 3 Aug 2020 18:07:09 +0200 Subject: [PATCH 2/2] Fix missing return of regularization_groups --- deepobs/pytorch/testproblems/cifar100_wrn404.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepobs/pytorch/testproblems/cifar100_wrn404.py b/deepobs/pytorch/testproblems/cifar100_wrn404.py index fa8ce3c2..84252a50 100644 --- a/deepobs/pytorch/testproblems/cifar100_wrn404.py +++ b/deepobs/pytorch/testproblems/cifar100_wrn404.py @@ -66,4 +66,4 @@ def get_regularization_groups(self): group_dict[l2].append(parameters) else: group_dict[no].append(parameters) - + return group_dict