Skip to content

Commit

Permalink
Fix KeyError with LabelMap-only subjects (#1218)
Browse files Browse the repository at this point in the history
* Fix KeyError with LabelMap-only subjects

* Fix mypy error

* Use double precision for params
  • Loading branch information
fepegar authored Oct 5, 2024
1 parent dfba02b commit ca17ba0
Show file tree
Hide file tree
Showing 12 changed files with 70 additions and 32 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ dev = [
"pytest",
"pytest-cov",
"pytest-sugar",
"tox",
"tox-uv",
"types-Deprecated",
]
doc = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,17 @@ def __init__(
self.order = _parse_order(order)

def apply_transform(self, subject: Subject) -> Subject:
images_dict = self.get_images_dict(subject)
if not images_dict:
return subject

arguments: Dict[str, dict] = defaultdict(dict)
for image_name in self.get_images_dict(subject):
coefficients = self.get_params(
self.order,
self.coefficients_range,
)
for image_name in images_dict:
coefficients = self.get_params(self.order, self.coefficients_range)
arguments['coefficients'][image_name] = coefficients
arguments['order'][image_name] = self.order
transform = BiasField(**self.add_include_exclude(arguments))
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed

def get_params(
Expand Down
12 changes: 8 additions & 4 deletions src/torchio/transforms/augmentation/intensity/random_blur.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,22 @@ def __init__(self, std: Union[float, Tuple[float, float]] = (0, 2), **kwargs):
self.std_ranges = self.parse_params(std, None, 'std', min_constraint=0)

def apply_transform(self, subject: Subject) -> Subject:
images_dict = self.get_images_dict(subject)
if not images_dict:
return subject

arguments: Dict[str, dict] = defaultdict(dict)
for name in self.get_images_dict(subject):
std = self.get_params(self.std_ranges)
for name in images_dict:
std = self.get_params(self.std_ranges) # type: ignore[arg-type]
arguments['std'][name] = std
transform = Blur(**self.add_include_exclude(arguments))
transformed = transform(subject)
assert isinstance(transformed, Subject)
return transformed

def get_params(self, std_ranges: TypeSextetFloat) -> TypeTripletFloat:
std = self.sample_uniform_sextet(std_ranges)
return std
sx, sy, sz = self.sample_uniform_sextet(std_ranges)
return sx, sy, sz


class Blur(IntensityTransform):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,12 @@ def __init__(self, log_gamma: TypeRangeFloat = (-0.3, 0.3), **kwargs):
self.log_gamma_range = self._parse_range(log_gamma, 'log_gamma')

def apply_transform(self, subject: Subject) -> Subject:
images_dict = self.get_images_dict(subject)
if not images_dict:
return subject

arguments: Dict[str, dict] = defaultdict(dict)
for name, image in self.get_images_dict(subject).items():
for name, image in images_dict.items():
gammas = [self.get_params(self.log_gamma_range) for _ in image.data]
arguments['gamma'][name] = gammas
transform = Gamma(**self.add_include_exclude(arguments))
Expand Down
11 changes: 8 additions & 3 deletions src/torchio/transforms/augmentation/intensity/random_ghosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,15 @@ def __init__(
self.restore = _parse_restore(restore)

def apply_transform(self, subject: Subject) -> Subject:
arguments: Dict[str, dict] = defaultdict(dict)
if any(isinstance(n, str) for n in self.axes):
images_dict = self.get_images_dict(subject)
if not images_dict:
return subject

if any(isinstance(axis, str) for axis in self.axes):
subject.check_consistent_orientation()
for name, image in self.get_images_dict(subject).items():

arguments: Dict[str, dict] = defaultdict(dict)
for name, image in images_dict.items():
is_2d = image.is_2d()
axes = [a for a in self.axes if a != 2] if is_2d else self.axes
min_ghosts, max_ghosts = self.num_ghosts_range
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,12 @@ def __init__(
)

def apply_transform(self, subject: Subject) -> Subject:
images_dict = self.get_images_dict(subject)
if not images_dict:
return subject

arguments: Dict[str, dict] = defaultdict(dict)
for name, image in self.get_images_dict(subject).items():
for name, image in images_dict.items():
params = self.get_params(
self.degrees_range,
self.translation_range,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,12 @@ def __init__(
self.std_range = self._parse_range(std, 'std', min_constraint=0)

def apply_transform(self, subject: Subject) -> Subject:
images_dict = self.get_images_dict(subject)
if not images_dict:
return subject

arguments: Dict[str, dict] = defaultdict(dict)
for image_name in self.get_images_dict(subject):
for image_name in images_dict:
mean, std, seed = self.get_params(self.mean_range, self.std_range)
arguments['mean'][image_name] = mean
arguments['std'][image_name] = std
Expand Down
8 changes: 6 additions & 2 deletions src/torchio/transforms/augmentation/intensity/random_spike.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,12 @@ def __init__(
)

def apply_transform(self, subject: Subject) -> Subject:
images_dict = self.get_images_dict(subject)
if not images_dict:
return subject

arguments: Dict[str, dict] = defaultdict(dict)
for image_name in self.get_images_dict(subject):
for image_name in images_dict:
spikes_positions_param, intensity_param = self.get_params(
self.num_spikes_range,
self.intensity_range,
Expand Down Expand Up @@ -90,7 +94,7 @@ class Spike(IntensityTransform, FourierTransform):
r"""Add MRI spike artifacts.
Also known as `Herringbone artifact
<https://radiopaedia.org/articles/herringbone-artifact?lang=gb>`_,
<https://radiopaedia.org/articles/herringbone-artifact>`_,
crisscross artifact or corduroy artifact, it creates stripes in different
directions in image space due to spikes in k-space.
Expand Down
6 changes: 5 additions & 1 deletion src/torchio/transforms/augmentation/intensity/random_swap.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,12 @@ def get_params(
return locations # type: ignore[return-value]

def apply_transform(self, subject: Subject) -> Subject:
images_dict = self.get_images_dict(subject)
if not images_dict:
return subject

arguments: Dict[str, dict] = defaultdict(dict)
for name, image in self.get_images_dict(subject).items():
for name, image in images_dict.items():
locations = self.get_params(
image.data,
self.patch_size,
Expand Down
7 changes: 5 additions & 2 deletions src/torchio/transforms/augmentation/random_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import torch

from ...typing import TypeRangeFloat
from ...typing import TypeSextetFloat
from ...typing import TypeTripletFloat
from ..transform import Transform


Expand Down Expand Up @@ -49,8 +51,9 @@ def _get_random_seed() -> int:
"""
return int(torch.randint(0, 2**31, (1,)).item())

def sample_uniform_sextet(self, params):
def sample_uniform_sextet(self, params: TypeSextetFloat) -> TypeTripletFloat:
results = []
for a, b in zip(params[::2], params[1::2]):
results.append(self.sample_uniform(a, b))
return torch.Tensor(results)
sx, sy, sz = results
return sx, sy, sz
21 changes: 15 additions & 6 deletions src/torchio/transforms/augmentation/spatial/random_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,20 @@ def get_params(
translation: TypeSextetFloat,
isotropic: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
scaling_params = self.sample_uniform_sextet(scales)
scaling_params = torch.as_tensor(
self.sample_uniform_sextet(scales),
dtype=torch.float64,
)
if isotropic:
scaling_params.fill_(scaling_params[0])
rotation_params = self.sample_uniform_sextet(degrees)
translation_params = self.sample_uniform_sextet(translation)
rotation_params = torch.as_tensor(
self.sample_uniform_sextet(degrees),
dtype=torch.float64,
)
translation_params = torch.as_tensor(
self.sample_uniform_sextet(translation),
dtype=torch.float64,
)
return scaling_params, rotation_params, translation_params

def apply_transform(self, subject: Subject) -> Subject:
Expand All @@ -166,9 +175,9 @@ def apply_transform(self, subject: Subject) -> Subject:
self.isotropic,
)
arguments = {
'scales': scaling_params.tolist(),
'degrees': rotation_params.tolist(),
'translation': translation_params.tolist(),
'scales': scaling_params,
'degrees': rotation_params,
'translation': translation_params,
'center': self.center,
'default_pad_value': self.default_pad_value,
'image_interpolation': self.image_interpolation,
Expand Down
5 changes: 1 addition & 4 deletions src/torchio/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,7 @@ def __init__(
# used to invert invertible transforms
self.args_names: List[str] = []

def __call__(
self,
data: InputType,
) -> InputType:
def __call__(self, data: InputType) -> InputType:
"""Transform data and return a result of the same type.
Args:
Expand Down

0 comments on commit ca17ba0

Please sign in to comment.