Skip to content

Commit

Permalink
Implement missing restore feature for RandomGhosting
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar committed Sep 23, 2024
1 parent 4f1e1d6 commit c9024b7
Showing 1 changed file with 73 additions and 45 deletions.
118 changes: 73 additions & 45 deletions src/torchio/transforms/augmentation/intensity/random_ghosting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import defaultdict
from typing import Dict
from typing import Iterable
from typing import Optional
from typing import Tuple
from typing import Union

Expand Down Expand Up @@ -43,7 +44,10 @@ class RandomGhosting(RandomTransform, IntensityTransform):
:math:`s \sim \mathcal{U}(0, d)`.
restore: Number between ``0`` and ``1`` indicating how much of the
:math:`k`-space center should be restored after removing the planes
that generate the artifact.
that generate the artifact. If ``None``, only the central slice
will be restored. If a tuple :math:`(a, b)` is provided then
:math:`r \sim \mathcal{U}(a, b)`. If only one value :math:`d` is
provided, :math:`r \sim \mathcal{U}(0, d)`.
**kwargs: See :class:`~torchio.transforms.Transform` for additional
keyword arguments.
Expand All @@ -56,7 +60,7 @@ def __init__(
num_ghosts: Union[int, Tuple[int, int]] = (4, 10),
axes: Union[int, Tuple[int, ...]] = (0, 1, 2),
intensity: Union[float, Tuple[float, float]] = (0.5, 1),
restore: float = 0.02,
restore: Optional[float] = None,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -81,7 +85,15 @@ def __init__(
'intensity_range',
min_constraint=0,
)
self.restore = _parse_restore(restore)
if restore is None:
self.restore = None
else:
self.restore = self._parse_range(
restore,
'restore',
min_constraint=0,
max_constraint=1,
)

def apply_transform(self, subject: Subject) -> Subject:
arguments: Dict[str, dict] = defaultdict(dict)
Expand All @@ -95,12 +107,13 @@ def apply_transform(self, subject: Subject) -> Subject:
(int(min_ghosts), int(max_ghosts)),
axes, # type: ignore[arg-type]
self.intensity_range,
self.restore,
)
num_ghosts_param, axis_param, intensity_param = params
num_ghosts_param, axis_param, intensity_param, restore_param = params
arguments['num_ghosts'][name] = num_ghosts_param
arguments['axis'][name] = axis_param
arguments['intensity'][name] = intensity_param
arguments['restore'][name] = self.restore
arguments['restore'][name] = restore_param
transform = Ghosting(**self.add_include_exclude(arguments))
transformed = transform(subject)
assert isinstance(transformed, Subject)
Expand All @@ -111,12 +124,17 @@ def get_params(
num_ghosts_range: Tuple[int, int],
axes: Tuple[int, ...],
intensity_range: Tuple[float, float],
) -> Tuple:
restore_range: Optional[Tuple[float, float]],
) -> Tuple[int, int, float, Optional[float]]:
ng_min, ng_max = num_ghosts_range
num_ghosts = torch.randint(ng_min, ng_max + 1, (1,)).item()
num_ghosts = int(torch.randint(ng_min, ng_max + 1, (1,)).item())
axis = axes[torch.randint(0, len(axes), (1,))]
intensity = self.sample_uniform(*intensity_range)
return num_ghosts, axis, intensity
if restore_range is None:
restore = None
else:
restore = self.sample_uniform(*restore_range)
return num_ghosts, axis, intensity, restore


class Ghosting(IntensityTransform, FourierTransform):
Expand All @@ -139,7 +157,8 @@ class Ghosting(IntensityTransform, FourierTransform):
If ``0``, the ghosts will not be visible.
restore: Number between ``0`` and ``1`` indicating how much of the
:math:`k`-space center should be restored after removing the planes
that generate the artifact.
that generate the artifact. If ``None``, only the central slice
will be restored.
**kwargs: See :class:`~torchio.transforms.Transform` for additional
keyword arguments.
Expand All @@ -152,7 +171,7 @@ def __init__(
num_ghosts: Union[int, Dict[str, int]],
axis: Union[int, Dict[str, int]],
intensity: Union[float, Dict[str, float]],
restore: Union[float, Dict[str, float]],
restore: Union[Optional[float], Dict[str, Optional[float]]],
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -163,10 +182,6 @@ def __init__(
self.args_names = ['num_ghosts', 'axis', 'intensity', 'restore']

def apply_transform(self, subject: Subject) -> Subject:
axis = self.axis
num_ghosts = self.num_ghosts
intensity = self.intensity
restore = self.restore
for name, image in self.get_images_dict(subject).items():
if self.arguments_are_dict():
assert isinstance(self.axis, dict)
Expand All @@ -177,12 +192,18 @@ def apply_transform(self, subject: Subject) -> Subject:
num_ghosts = self.num_ghosts[name]
intensity = self.intensity[name]
restore = self.restore[name]
else:
axis = self.axis
num_ghosts = self.num_ghosts
intensity = self.intensity
restore = self.restore
transformed_tensors = []
for tensor in image.data:
assert isinstance(num_ghosts, int)
assert isinstance(axis, int)
assert isinstance(intensity, float)
assert isinstance(restore, float)
assert isinstance(intensity, (int, float))
if restore is not None:
assert isinstance(restore, float)
transformed_tensor = self.add_artifact(
tensor,
num_ghosts,
Expand All @@ -200,49 +221,56 @@ def add_artifact(
num_ghosts: int,
axis: int,
intensity: float,
restore_center: float,
restore_center: Optional[float],
):
if not num_ghosts or not intensity:
return tensor

spectrum = self.fourier_transform(tensor)

shape = np.array(tensor.shape)
ri, rj, rk = np.round(restore_center * shape).astype(np.uint16)
mi, mj, mk = np.array(tensor.shape) // 2

# Variable "planes" is the part of the spectrum that will be modified
if axis == 0:
planes = spectrum[::num_ghosts, :, :]
restore = spectrum[mi, :, :].clone()
elif axis == 1:
planes = spectrum[:, ::num_ghosts, :]
restore = spectrum[:, mj, :].clone()
elif axis == 2:
planes = spectrum[:, :, ::num_ghosts]
restore = spectrum[:, :, mk].clone()
# Variable "restore" is the part of the spectrum that will be restored
planes = self._get_planes_to_modify(spectrum, axis, num_ghosts)
tensor_restore, slices = self._get_slices_to_restore(
spectrum, axis, restore_center
)
tensor_restore = tensor_restore.clone()

# Multiply by 0 if intensity is 1
planes *= 1 - intensity

# Restore the center of k-space to avoid extreme artifacts
if axis == 0:
spectrum[mi, :, :] = restore
elif axis == 1:
spectrum[:, mj, :] = restore
elif axis == 2:
spectrum[:, :, mk] = restore
spectrum[slices] = tensor_restore

tensor_ghosts = self.inv_fourier_transform(spectrum)
return tensor_ghosts.real.float()

@staticmethod
def _get_planes_to_modify(
spectrum: torch.Tensor,
axis: int,
num_ghosts: int,
) -> torch.Tensor:
slices = [slice(None)] * spectrum.ndim
slices[axis] = slice(None, None, num_ghosts)
slices = tuple(slices)
return spectrum[slices]

def _parse_restore(restore):
try:
restore = float(restore)
except ValueError as e:
raise TypeError(f'Restore must be a float, not "{restore}"') from e
if not 0 <= restore <= 1:
message = f'Restore must be a number between 0 and 1, not {restore}'
raise ValueError(message)
return restore
@staticmethod
def _get_slices_to_restore(
spectrum: torch.Tensor,
axis: int,
restore_center: Optional[float],
) -> Tuple[torch.Tensor, Tuple[slice, ...]]:
dim_shape = spectrum.shape[axis]
mid_idx = dim_shape // 2
slices = [slice(None)] * spectrum.ndim
if restore_center is None:
slice_ = slice(mid_idx, mid_idx + 1)
else:
size_restore = int(np.round(restore_center * dim_shape))
slice_ = slice(mid_idx - size_restore // 2, mid_idx + size_restore // 2)
slices[axis] = slice_
slices = tuple(slices)
restore_tensor = spectrum[slices]
return restore_tensor, slices

0 comments on commit c9024b7

Please sign in to comment.