Skip to content

Commit

Permalink
Revert invertibility of RescaleIntensity (#1116)
Browse files Browse the repository at this point in the history
* Add failing test

* Revert invertibility of `RescaleIntensity` (#1120)

* Make RescaleIntensity non-invertible again

* Disable support to invert RescaleIntensity
  • Loading branch information
fepegar authored Nov 4, 2023
1 parent 80d71e0 commit dfd52bb
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 39 deletions.
46 changes: 17 additions & 29 deletions src/torchio/transforms/preprocessing/intensity/rescale.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch

from ....data.subject import Subject
from ....typing import TypeRangeFloat
from ....typing import TypeDoubleFloat
from .normalization_transform import NormalizationTransform
from .normalization_transform import TypeMaskingMethod

Expand Down Expand Up @@ -45,10 +45,10 @@ class RescaleIntensity(NormalizationTransform):

def __init__(
self,
out_min_max: TypeRangeFloat = (0, 1),
percentiles: TypeRangeFloat = (0, 100),
out_min_max: TypeDoubleFloat = (0, 1),
percentiles: TypeDoubleFloat = (0, 100),
masking_method: TypeMaskingMethod = None,
in_min_max: Optional[TypeRangeFloat] = None,
in_min_max: Optional[TypeDoubleFloat] = None,
**kwargs,
):
super().__init__(masking_method=masking_method, **kwargs)
Expand All @@ -65,24 +65,18 @@ def __init__(
max_constraint=100,
)

self.in_min: Optional[float]
self.in_max: Optional[float]
if self.in_min_max is not None:
self.in_min, self.in_max = self._parse_range(
self.in_min_max = self._parse_range(
self.in_min_max,
'in_min_max',
)
else:
self.in_min = None
self.in_max = None

self.args_names = [
'out_min_max',
'percentiles',
'masking_method',
'in_min_max',
]
self.invert_transform = False

def apply_normalization(
self,
Expand All @@ -109,34 +103,28 @@ def rescale(
)
warnings.warn(message, RuntimeWarning, stacklevel=2)
return tensor

values = array[mask]
cutoff = np.percentile(values, self.percentiles)
np.clip(array, *cutoff, out=array) # type: ignore[call-overload]

if self.in_min_max is None:
self.in_min_max = self._parse_range(
(array.min(), array.max()),
'in_min_max',
)
self.in_min, self.in_max = self.in_min_max
assert self.in_min is not None
assert self.in_max is not None
in_range = self.in_max - self.in_min
in_min, in_max = array.min(), array.max()
else:
in_min, in_max = self.in_min_max
in_range = in_max - in_min
if in_range == 0: # should this be compared using a tolerance?
message = (
f'Rescaling image "{image_name}" not possible'
' because all the intensity values are the same'
)
warnings.warn(message, RuntimeWarning, stacklevel=2)
return tensor

out_range = self.out_max - self.out_min
if self.invert_transform:
array -= self.out_min
array /= out_range
array *= in_range
array += self.in_min
else:
array -= self.in_min
array /= in_range
array *= out_range
array += self.out_min

array -= in_min
array /= in_range
array *= out_range
array += self.out_min
return torch.as_tensor(array)
1 change: 1 addition & 0 deletions src/torchio/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
TypeQuartetInt = Tuple[int, int, int, int]
TypeSextetInt = Tuple[int, int, int, int, int, int]

TypeDoubleFloat = Tuple[float, float]
TypeTripletFloat = Tuple[float, float, float]
TypeSextetFloat = Tuple[float, float, float, float, float, float]

Expand Down
24 changes: 14 additions & 10 deletions tests/transforms/preprocessing/test_rescale.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,17 @@ def test_empty_mask(self):
with pytest.warns(RuntimeWarning):
rescale(subject)

def test_invert_rescaling(self):
torch.manual_seed(0)
transform = tio.RescaleIntensity(out_min_max=(0, 1))
data = torch.rand(1, 2, 3, 4).double()
subject = tio.Subject(t1=tio.ScalarImage(tensor=data))
transformed = transform(subject)
assert transformed.t1.data.min() == 0
assert transformed.t1.data.max() == 1
inverted = transformed.apply_inverse_transform()
self.assert_tensor_almost_equal(inverted.t1.data, data)
def test_persistent_in_min_max(self):
# see https://github.com/fepegar/torchio/issues/1115
img1 = torch.tensor([[[[0, 1]]]])
img2 = torch.tensor([[[[0, 10]]]])

rescale = tio.RescaleIntensity(out_min_max=(0, 1))

assert rescale(img1).data.flatten().tolist() == [0, 1]
assert rescale(img2).data.flatten().tolist() == [0, 1]

rescale = tio.RescaleIntensity(out_min_max=(0, 1))

assert rescale(img2).data.flatten().tolist() == [0, 1]
assert rescale(img1).data.flatten().tolist() == [0, 1]

0 comments on commit dfd52bb

Please sign in to comment.