diff --git a/tests/transforms/augmentation/test_random_affine.py b/tests/transforms/augmentation/test_random_affine.py index bb86c159c..a54992f7f 100644 --- a/tests/transforms/augmentation/test_random_affine.py +++ b/tests/transforms/augmentation/test_random_affine.py @@ -167,3 +167,11 @@ def test_no_inverse(self): ) transformed = apply_affine(tensor) self.assertTensorAlmostEqual(transformed, expected) + + def test_different_spaces(self): + t1 = self.sample_subject.t1 + label = tio.Resample(2)(self.sample_subject.label) + new_subject = tio.Subject(t1=t1, label=label) + with self.assertRaises(RuntimeError): + tio.RandomAffine()(new_subject) + tio.RandomAffine(check_shape=False)(new_subject) diff --git a/torchio/transforms/augmentation/spatial/random_affine.py b/torchio/transforms/augmentation/spatial/random_affine.py index 74a3bf4ce..d6c1d1f0e 100644 --- a/torchio/transforms/augmentation/spatial/random_affine.py +++ b/torchio/transforms/augmentation/spatial/random_affine.py @@ -75,6 +75,9 @@ class RandomAffine(RandomTransform, SpatialTransform): `Otsu threshold `_. If it is a number, that value will be used. image_interpolation: See :ref:`Interpolation`. + check_shape: If ``True`` an error will be raised if the images are in + different physical spaces. If ``False``, :attr:`center` should + probably not be ``'image'`` but ``'center'``. **kwargs: See :class:`~torchio.transforms.Transform` for additional keyword arguments. @@ -112,6 +115,7 @@ def __init__( center: str = 'image', default_pad_value: Union[str, float] = 'minimum', image_interpolation: str = 'linear', + check_shape: bool = True, **kwargs ): super().__init__(**kwargs) @@ -129,6 +133,7 @@ def __init__( self.center = center self.default_pad_value = _parse_default_value(default_pad_value) self.image_interpolation = self.parse_interpolation(image_interpolation) + self.check_shape = check_shape def get_params( self, @@ -145,7 +150,6 @@ def get_params( return scaling_params, rotation_params, translation_params def apply_transform(self, subject: Subject) -> Subject: - subject.check_consistent_spatial_shape() scaling_params, rotation_params, translation_params = self.get_params( self.scales, self.degrees, @@ -159,6 +163,7 @@ def apply_transform(self, subject: Subject) -> Subject: center=self.center, default_pad_value=self.default_pad_value, image_interpolation=self.image_interpolation, + check_shape=self.check_shape, ) transform = Affine(**self.add_include_exclude(arguments)) transformed = transform(subject) @@ -187,6 +192,9 @@ class Affine(SpatialTransform): `Otsu threshold `_. If it is a number, that value will be used. image_interpolation: See :ref:`Interpolation`. + check_shape: If ``True`` an error will be raised if the images are in + different physical spaces. If ``False``, :attr:`center` should + probably not be ``'image'`` but ``'center'``. **kwargs: See :class:`~torchio.transforms.Transform` for additional keyword arguments. """ @@ -198,6 +206,7 @@ def __init__( center: str = 'image', default_pad_value: Union[str, float] = 'minimum', image_interpolation: str = 'linear', + check_shape: bool = True, **kwargs ): super().__init__(**kwargs) @@ -231,6 +240,7 @@ def __init__( self.default_pad_value = _parse_default_value(default_pad_value) self.image_interpolation = self.parse_interpolation(image_interpolation) self.invert_transform = False + self.check_shape = check_shape self.args_names = ( 'scales', 'degrees', @@ -238,6 +248,7 @@ def __init__( 'center', 'default_pad_value', 'image_interpolation', + 'check_shape', ) @staticmethod @@ -322,7 +333,8 @@ def get_affine_transform(self, image): return transform def apply_transform(self, subject: Subject) -> Subject: - subject.check_consistent_spatial_shape() + if self.check_shape: + subject.check_consistent_spatial_shape() for image in self.get_images(subject): transform = self.get_affine_transform(image) transformed_tensors = []