From ca17ba035ecea6f5ea73fb3e308be6cfd52634ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fernando=20P=C3=A9rez-Garc=C3=ADa?= Date: Sun, 6 Oct 2024 00:21:02 +0200 Subject: [PATCH] Fix KeyError with LabelMap-only subjects (#1218) * Fix KeyError with LabelMap-only subjects * Fix mypy error * Use double precision for params --- pyproject.toml | 2 +- .../intensity/random_bias_field.py | 12 +++++------ .../augmentation/intensity/random_blur.py | 12 +++++++---- .../augmentation/intensity/random_gamma.py | 6 +++++- .../augmentation/intensity/random_ghosting.py | 11 +++++++--- .../augmentation/intensity/random_motion.py | 6 +++++- .../augmentation/intensity/random_noise.py | 6 +++++- .../augmentation/intensity/random_spike.py | 8 +++++-- .../augmentation/intensity/random_swap.py | 6 +++++- .../augmentation/random_transform.py | 7 +++++-- .../augmentation/spatial/random_affine.py | 21 +++++++++++++------ src/torchio/transforms/transform.py | 5 +---- 12 files changed, 70 insertions(+), 32 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9bb3074c..0855bf21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ dev = [ "pytest", "pytest-cov", "pytest-sugar", - "tox", + "tox-uv", "types-Deprecated", ] doc = [ diff --git a/src/torchio/transforms/augmentation/intensity/random_bias_field.py b/src/torchio/transforms/augmentation/intensity/random_bias_field.py index f1e797c7..c623d0c1 100644 --- a/src/torchio/transforms/augmentation/intensity/random_bias_field.py +++ b/src/torchio/transforms/augmentation/intensity/random_bias_field.py @@ -51,17 +51,17 @@ def __init__( self.order = _parse_order(order) def apply_transform(self, subject: Subject) -> Subject: + images_dict = self.get_images_dict(subject) + if not images_dict: + return subject + arguments: Dict[str, dict] = defaultdict(dict) - for image_name in self.get_images_dict(subject): - coefficients = self.get_params( - self.order, - self.coefficients_range, - ) + for image_name in images_dict: + coefficients = self.get_params(self.order, self.coefficients_range) arguments['coefficients'][image_name] = coefficients arguments['order'][image_name] = self.order transform = BiasField(**self.add_include_exclude(arguments)) transformed = transform(subject) - assert isinstance(transformed, Subject) return transformed def get_params( diff --git a/src/torchio/transforms/augmentation/intensity/random_blur.py b/src/torchio/transforms/augmentation/intensity/random_blur.py index 704def8f..e1406dea 100644 --- a/src/torchio/transforms/augmentation/intensity/random_blur.py +++ b/src/torchio/transforms/augmentation/intensity/random_blur.py @@ -39,9 +39,13 @@ def __init__(self, std: Union[float, Tuple[float, float]] = (0, 2), **kwargs): self.std_ranges = self.parse_params(std, None, 'std', min_constraint=0) def apply_transform(self, subject: Subject) -> Subject: + images_dict = self.get_images_dict(subject) + if not images_dict: + return subject + arguments: Dict[str, dict] = defaultdict(dict) - for name in self.get_images_dict(subject): - std = self.get_params(self.std_ranges) + for name in images_dict: + std = self.get_params(self.std_ranges) # type: ignore[arg-type] arguments['std'][name] = std transform = Blur(**self.add_include_exclude(arguments)) transformed = transform(subject) @@ -49,8 +53,8 @@ def apply_transform(self, subject: Subject) -> Subject: return transformed def get_params(self, std_ranges: TypeSextetFloat) -> TypeTripletFloat: - std = self.sample_uniform_sextet(std_ranges) - return std + sx, sy, sz = self.sample_uniform_sextet(std_ranges) + return sx, sy, sz class Blur(IntensityTransform): diff --git a/src/torchio/transforms/augmentation/intensity/random_gamma.py b/src/torchio/transforms/augmentation/intensity/random_gamma.py index bae5a177..5a2d2721 100644 --- a/src/torchio/transforms/augmentation/intensity/random_gamma.py +++ b/src/torchio/transforms/augmentation/intensity/random_gamma.py @@ -69,8 +69,12 @@ def __init__(self, log_gamma: TypeRangeFloat = (-0.3, 0.3), **kwargs): self.log_gamma_range = self._parse_range(log_gamma, 'log_gamma') def apply_transform(self, subject: Subject) -> Subject: + images_dict = self.get_images_dict(subject) + if not images_dict: + return subject + arguments: Dict[str, dict] = defaultdict(dict) - for name, image in self.get_images_dict(subject).items(): + for name, image in images_dict.items(): gammas = [self.get_params(self.log_gamma_range) for _ in image.data] arguments['gamma'][name] = gammas transform = Gamma(**self.add_include_exclude(arguments)) diff --git a/src/torchio/transforms/augmentation/intensity/random_ghosting.py b/src/torchio/transforms/augmentation/intensity/random_ghosting.py index 2316f0b1..8a827076 100644 --- a/src/torchio/transforms/augmentation/intensity/random_ghosting.py +++ b/src/torchio/transforms/augmentation/intensity/random_ghosting.py @@ -84,10 +84,15 @@ def __init__( self.restore = _parse_restore(restore) def apply_transform(self, subject: Subject) -> Subject: - arguments: Dict[str, dict] = defaultdict(dict) - if any(isinstance(n, str) for n in self.axes): + images_dict = self.get_images_dict(subject) + if not images_dict: + return subject + + if any(isinstance(axis, str) for axis in self.axes): subject.check_consistent_orientation() - for name, image in self.get_images_dict(subject).items(): + + arguments: Dict[str, dict] = defaultdict(dict) + for name, image in images_dict.items(): is_2d = image.is_2d() axes = [a for a in self.axes if a != 2] if is_2d else self.axes min_ghosts, max_ghosts = self.num_ghosts_range diff --git a/src/torchio/transforms/augmentation/intensity/random_motion.py b/src/torchio/transforms/augmentation/intensity/random_motion.py index 3ff81f0d..1e179f31 100644 --- a/src/torchio/transforms/augmentation/intensity/random_motion.py +++ b/src/torchio/transforms/augmentation/intensity/random_motion.py @@ -73,8 +73,12 @@ def __init__( ) def apply_transform(self, subject: Subject) -> Subject: + images_dict = self.get_images_dict(subject) + if not images_dict: + return subject + arguments: Dict[str, dict] = defaultdict(dict) - for name, image in self.get_images_dict(subject).items(): + for name, image in images_dict.items(): params = self.get_params( self.degrees_range, self.translation_range, diff --git a/src/torchio/transforms/augmentation/intensity/random_noise.py b/src/torchio/transforms/augmentation/intensity/random_noise.py index 0914810f..2560b4dd 100644 --- a/src/torchio/transforms/augmentation/intensity/random_noise.py +++ b/src/torchio/transforms/augmentation/intensity/random_noise.py @@ -44,8 +44,12 @@ def __init__( self.std_range = self._parse_range(std, 'std', min_constraint=0) def apply_transform(self, subject: Subject) -> Subject: + images_dict = self.get_images_dict(subject) + if not images_dict: + return subject + arguments: Dict[str, dict] = defaultdict(dict) - for image_name in self.get_images_dict(subject): + for image_name in images_dict: mean, std, seed = self.get_params(self.mean_range, self.std_range) arguments['mean'][image_name] = mean arguments['std'][image_name] = std diff --git a/src/torchio/transforms/augmentation/intensity/random_spike.py b/src/torchio/transforms/augmentation/intensity/random_spike.py index 183f5b6e..5557bcaa 100644 --- a/src/torchio/transforms/augmentation/intensity/random_spike.py +++ b/src/torchio/transforms/augmentation/intensity/random_spike.py @@ -61,8 +61,12 @@ def __init__( ) def apply_transform(self, subject: Subject) -> Subject: + images_dict = self.get_images_dict(subject) + if not images_dict: + return subject + arguments: Dict[str, dict] = defaultdict(dict) - for image_name in self.get_images_dict(subject): + for image_name in images_dict: spikes_positions_param, intensity_param = self.get_params( self.num_spikes_range, self.intensity_range, @@ -90,7 +94,7 @@ class Spike(IntensityTransform, FourierTransform): r"""Add MRI spike artifacts. Also known as `Herringbone artifact - `_, + `_, crisscross artifact or corduroy artifact, it creates stripes in different directions in image space due to spikes in k-space. diff --git a/src/torchio/transforms/augmentation/intensity/random_swap.py b/src/torchio/transforms/augmentation/intensity/random_swap.py index 65d7d74d..3eac412c 100644 --- a/src/torchio/transforms/augmentation/intensity/random_swap.py +++ b/src/torchio/transforms/augmentation/intensity/random_swap.py @@ -89,8 +89,12 @@ def get_params( return locations # type: ignore[return-value] def apply_transform(self, subject: Subject) -> Subject: + images_dict = self.get_images_dict(subject) + if not images_dict: + return subject + arguments: Dict[str, dict] = defaultdict(dict) - for name, image in self.get_images_dict(subject).items(): + for name, image in images_dict.items(): locations = self.get_params( image.data, self.patch_size, diff --git a/src/torchio/transforms/augmentation/random_transform.py b/src/torchio/transforms/augmentation/random_transform.py index 0da6e211..83c27937 100644 --- a/src/torchio/transforms/augmentation/random_transform.py +++ b/src/torchio/transforms/augmentation/random_transform.py @@ -5,6 +5,8 @@ import torch from ...typing import TypeRangeFloat +from ...typing import TypeSextetFloat +from ...typing import TypeTripletFloat from ..transform import Transform @@ -49,8 +51,9 @@ def _get_random_seed() -> int: """ return int(torch.randint(0, 2**31, (1,)).item()) - def sample_uniform_sextet(self, params): + def sample_uniform_sextet(self, params: TypeSextetFloat) -> TypeTripletFloat: results = [] for a, b in zip(params[::2], params[1::2]): results.append(self.sample_uniform(a, b)) - return torch.Tensor(results) + sx, sy, sz = results + return sx, sy, sz diff --git a/src/torchio/transforms/augmentation/spatial/random_affine.py b/src/torchio/transforms/augmentation/spatial/random_affine.py index 2cdcb38d..973b2ef0 100644 --- a/src/torchio/transforms/augmentation/spatial/random_affine.py +++ b/src/torchio/transforms/augmentation/spatial/random_affine.py @@ -151,11 +151,20 @@ def get_params( translation: TypeSextetFloat, isotropic: bool, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - scaling_params = self.sample_uniform_sextet(scales) + scaling_params = torch.as_tensor( + self.sample_uniform_sextet(scales), + dtype=torch.float64, + ) if isotropic: scaling_params.fill_(scaling_params[0]) - rotation_params = self.sample_uniform_sextet(degrees) - translation_params = self.sample_uniform_sextet(translation) + rotation_params = torch.as_tensor( + self.sample_uniform_sextet(degrees), + dtype=torch.float64, + ) + translation_params = torch.as_tensor( + self.sample_uniform_sextet(translation), + dtype=torch.float64, + ) return scaling_params, rotation_params, translation_params def apply_transform(self, subject: Subject) -> Subject: @@ -166,9 +175,9 @@ def apply_transform(self, subject: Subject) -> Subject: self.isotropic, ) arguments = { - 'scales': scaling_params.tolist(), - 'degrees': rotation_params.tolist(), - 'translation': translation_params.tolist(), + 'scales': scaling_params, + 'degrees': rotation_params, + 'translation': translation_params, 'center': self.center, 'default_pad_value': self.default_pad_value, 'image_interpolation': self.image_interpolation, diff --git a/src/torchio/transforms/transform.py b/src/torchio/transforms/transform.py index 20daccbe..7d0ef894 100644 --- a/src/torchio/transforms/transform.py +++ b/src/torchio/transforms/transform.py @@ -128,10 +128,7 @@ def __init__( # used to invert invertible transforms self.args_names: List[str] = [] - def __call__( - self, - data: InputType, - ) -> InputType: + def __call__(self, data: InputType) -> InputType: """Transform data and return a result of the same type. Args: