Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Fix/Inference mode for Imputers #75

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/anemoi/models/preprocessing/imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,11 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
if not in_place:
x = x.clone()

# Initialize nan mask once
# Reset NaN locations outside of training for validation and inference.
if not self.training:
self.nan_locations = None

# Initialise mask if not cached.
if self.nan_locations is None:

# Get NaN locations
Expand Down
113 changes: 59 additions & 54 deletions tests/preprocessing/test_preprocessor_imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,17 @@ def default_constant_data():
return base, expected


fixture_combinations = (
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
)


@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_imputer_not_inplace(imputer_fixture, data_fixture, request) -> None:
"""Check that the imputer does not modify the input tensor when in_place=False."""
Expand All @@ -150,12 +153,7 @@ def test_imputer_not_inplace(imputer_fixture, data_fixture, request) -> None:

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_imputer_inplace(imputer_fixture, data_fixture, request) -> None:
"""Check that the imputer modifies the input tensor when in_place=True."""
Expand All @@ -169,12 +167,7 @@ def test_imputer_inplace(imputer_fixture, data_fixture, request) -> None:

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_transform_with_nan(imputer_fixture, data_fixture, request):
"""Check that the imputer correctly transforms a tensor with NaNs."""
Expand All @@ -186,12 +179,7 @@ def test_transform_with_nan(imputer_fixture, data_fixture, request):

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_transform_with_nan_small(imputer_fixture, data_fixture, request):
"""Check that the imputer correctly transforms a tensor with NaNs."""
Expand All @@ -211,12 +199,7 @@ def test_transform_with_nan_small(imputer_fixture, data_fixture, request):

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_transform_with_nan_inference(imputer_fixture, data_fixture, request):
"""Check that the imputer correctly transforms a tensor with NaNs in inference."""
Expand Down Expand Up @@ -244,12 +227,7 @@ def test_transform_with_nan_inference(imputer_fixture, data_fixture, request):

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_transform_noop(imputer_fixture, data_fixture, request):
"""Check that the imputer does not modify a tensor without NaNs."""
Expand All @@ -262,12 +240,7 @@ def test_transform_noop(imputer_fixture, data_fixture, request):

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_inverse_transform(imputer_fixture, data_fixture, request):
"""Check that the imputer correctly inverts the transformation."""
Expand All @@ -281,12 +254,7 @@ def test_inverse_transform(imputer_fixture, data_fixture, request):

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_mask_saving(imputer_fixture, data_fixture, request):
"""Check that the imputer saves the NaN mask correctly."""
Expand All @@ -299,12 +267,7 @@ def test_mask_saving(imputer_fixture, data_fixture, request):

@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
[
("default_constant_imputer", "default_constant_data"),
("non_default_constant_imputer", "non_default_constant_data"),
("default_input_imputer", "default_input_data"),
("non_default_input_imputer", "non_default_input_data"),
],
fixture_combinations,
)
def test_loss_nan_mask(imputer_fixture, data_fixture, request):
"""Check that the imputer correctly transforms a tensor with NaNs."""
Expand Down Expand Up @@ -336,3 +299,45 @@ def test_reuse_imputer(imputer_fixture, data_fixture, request):
assert torch.allclose(
transformed2, expected, equal_nan=True
), "Imputer does not reuse mask correctly on subsequent runs."


@pytest.mark.parametrize(
("imputer_fixture", "data_fixture"),
fixture_combinations,
)
def test_inference_imputer(imputer_fixture, data_fixture, request):
"""Check that the imputer resets its mask during inference."""
x, expected = request.getfixturevalue(data_fixture)
imputer = request.getfixturevalue(imputer_fixture)

# Check training flag
assert imputer.training, "Imputer is not set to training mode."

expected_mask = torch.isnan(x)
transformed = imputer.transform(x, in_place=False)
assert torch.allclose(transformed, expected, equal_nan=True), "Transform does not handle NaNs correctly."
restored = imputer.inverse_transform(transformed, in_place=False)
assert torch.allclose(restored, x, equal_nan=True), "Inverse transform does not restore NaNs correctly."
assert torch.equal(imputer.nan_locations, expected_mask), "Mask not saved correctly after first run."

imputer.eval()
with torch.no_grad():
x2 = x.roll(-1, dims=0)
expected2 = expected.roll(-1, dims=0)
expected_mask2 = torch.isnan(x2)

assert torch.equal(imputer.nan_locations, expected_mask), "Mask not saved correctly after first run."

# Check training flag
assert not imputer.training, "Imputer is not set to evaluation mode."

assert not torch.allclose(x, x2, equal_nan=True), "Failed to modify the input data."
assert not torch.allclose(expected, expected2, equal_nan=True), "Failed to modify the expected data."
assert not torch.allclose(expected_mask, expected_mask2, equal_nan=True), "Failed to modify the nan mask."

transformed = imputer.transform(x2, in_place=False)
assert torch.allclose(transformed, expected2, equal_nan=True), "Transform does not handle NaNs correctly."
restored = imputer.inverse_transform(transformed, in_place=False)
assert torch.allclose(restored, x2, equal_nan=True), "Inverse transform does not restore NaNs correctly."

assert torch.equal(imputer.nan_locations, expected_mask2), "Mask not saved correctly after evaluation run."
Loading