Skip to content

Commit

Permalink
FEAT: more extensive gather_dimensions2 with unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rob-taggart committed Dec 11, 2023
1 parent 99cc2fa commit 017aa10
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 24 deletions.
77 changes: 54 additions & 23 deletions src/scores/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,14 @@ def gather_dimensions( # pylint: disable=too-many-branches
return reduce_dims


def standard_dims_check(fcst, obs, weights=None, reduce_dims=None, preserve_dims=None, special_fcst_dims=None):
def gather_dimensions2(fcst, obs, weights=None, reduce_dims=None, preserve_dims=None, special_fcst_dims=None):
"""
Performs a standard dimensions check for a function that calculates (mean) scores.
Returns a list of the dimensions to reduce.
special_fcst_dims are dims in fcst that will be collapsed while calculating individual scores
(e.g. the threshold dimension of a CDF, or the ensemble member dimension, when calculating CRPS)
and is never present for the step of calculating mean scores.
Checks that:
- reduce_dims and preserve_dims are not both specified
Expand All @@ -128,44 +130,73 @@ def standard_dims_check(fcst, obs, weights=None, reduce_dims=None, preserve_dims
- specified dims with a string value is handled correctly
Returns:
list of dimensions over which to take the mean
OR IS THIS REALLY NEEDED? Once the checks are done,
list of dimensions over which to take the mean once the checks are passed.
Raises:
ValueError:
- when `preserve_dims` and `reduce_dims` are both specified
- when `special_fcst_dims` is not a subset of `fcst.dims`
- when `obs.dims`, `weights.dims`, `reduce_dims` or `preserve_dims`
contains elements from `special_fcst_dims`
- when `preserve_dims` and `reduce_dims` contain elements not among dimensions
of the data (`fcst`, `obs` or `weights`)
"""
# all_scoring_dims is the set of dims remaining after individual scores are computed.
all_dims = set(fcst_dims).union(set(obs_dims))
all_data_dims = set(fcst.dims).union(set(obs.dims))
if weights is not None:
all_dims = all_dims.union(set(weights.dims))
all_data_dims = all_data_dims.union(set(weights.dims))

# all_scoring_dims is the set of dims remaining after individual scores are computed.
all_scoring_dims = all_dims.copy()
all_scoring_dims = all_data_dims.copy()

# Handle error conditions related to specified dimensions
if preserve_dims is not None and reduce_dims is not None:
raise ValueError(ERROR_OVERSPECIFIED_PRESERVE_REDUCE)

specified_dims = preserve_dims or reduce_dims

# check that special_fcst_dims are in in fcst.dims only
if specified_dims == "all":
if "all" in all_data_dims:
warnings.warn(WARN_ALL_DATA_CONFLICT_MSG)
elif specified_dims is not None:
if isinstance(specified_dims, str):
specified_dims = [specified_dims]

# check that special_fcst_dims are in fcst.dims only
if special_fcst_dims is not None:
if len(set(obs.dims).union(set(special_fcst_dims))) > 0:
if isinstance(special_fcst_dims, str):
special_fcst_dims = [special_fcst_dims]
if not set(special_fcst_dims).issubset(set(fcst.dims)):
raise ValueError("`special_fcst_dims` must be a subset of `fcst` dimensions")
if len(set(obs.dims).intersection(set(special_fcst_dims))) > 0:
raise ValueError("`obs.dims` must not contain any `special_fcst_dims`")
if weights is not None:
if len(set(weights.dims).union(set(special_fcst_dims))) > 0:
if len(set(weights.dims).intersection(set(special_fcst_dims))) > 0:
raise ValueError("`weights.dims` must not contain any `special_fcst_dims`")
if specified_dims is not None and specified_dims != "all":
if len(set(specified_dims).intersection(set(special_fcst_dims))) > 0:
raise ValueError("`reduce_dims` and `preserve_dims` must not contain any `special_fcst_dims`")
# remove special_fcst_dims from all_scoring_dims
all_scoring_dims = all_scoring_dims.difference(set(special_fcst_dims))

# Handle error conditions related to specified dimensions
if preserve_dims is not None and reduce_dims is not None:
raise ValueError(ERROR_OVERSPECIFIED_PRESERVE_REDUCE)

specified = preserve_dims or reduce_dims
if specified == "all":
if "all" in all_dims:
warnings.warn(WARN_ALL_DATA_CONFLICT_MSG)
elif specified is not None:
if isinstance(specified, str):
specified = [specified]

if not set(specified).issubset(all_scoring_dims):
if specified_dims is not None and specified_dims != "all":
if not set(specified_dims).issubset(all_scoring_dims):
if preserve_dims is not None:
raise ValueError(ERROR_SPECIFIED_NONPRESENT_PRESERVE_DIMENSION)
raise ValueError(ERROR_SPECIFIED_NONPRESENT_REDUCE_DIMENSION)

# all errors have been captured, so now return list of dims to reduce
if specified_dims is None:
return all_scoring_dims
elif reduce_dims is not None:
if reduce_dims == "all":
return all_scoring_dims
else:
return set(specified_dims)
elif preserve_dims == "all":
return set([])
else:
return all_scoring_dims.difference(set(specified_dims))


def dims_complement(data, dims=None) -> list[str]:
"""Returns the complement of data.dims and dims
Expand Down
156 changes: 155 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import xarray as xr

from scores import utils
from scores.utils import DimensionError
from scores.utils import DimensionError, gather_dimensions2
from scores.utils import gather_dimensions as gd
from tests import utils_test_data

Expand Down Expand Up @@ -504,6 +504,160 @@ def test_gather_dimensions_exceptions():
assert gd(fcst_dims_conflict, obs_dims, reduce_dims="all") == fcst_dims_conflict


@pytest.mark.parametrize(
    "fcst,obs,weights,reduce_dims,preserve_dims,special_fcst_dims,error_msg_snippet",
    [
        # `reduce_dims` and `preserve_dims` supplied together
        (
            utils_test_data.DA_RGB,
            utils_test_data.DA_R,
            None,
            ["red"],
            ["blue"],
            None,
            utils.ERROR_OVERSPECIFIED_PRESERVE_REDUCE,
        ),
        # `special_fcst_dims` names a dimension that is expected to be missing from `fcst`
        (
            utils_test_data.DA_RGB,
            utils_test_data.DA_R,
            None,
            None,
            None,
            ["black"],
            "`special_fcst_dims` must be a subset of `fcst` dimensions",
        ),
        # `special_fcst_dims` (given as a list) clashes with a dimension of `obs`
        (
            utils_test_data.DA_RGB,
            utils_test_data.DA_R,
            None,
            None,
            None,
            ["red"],
            "`obs.dims` must not contain any `special_fcst_dims`",
        ),
        # same clash, with `special_fcst_dims` given as a bare string
        (
            utils_test_data.DA_RGB,
            utils_test_data.DA_R,
            None,
            None,
            None,
            "red",
            "`obs.dims` must not contain any `special_fcst_dims`",
        ),
        # `special_fcst_dims` clashes with a dimension of `weights`
        (
            utils_test_data.DA_RGB,
            utils_test_data.DA_R,
            utils_test_data.DA_G,
            None,
            None,
            "green",
            "`weights.dims` must not contain any `special_fcst_dims`",
        ),
        # `reduce_dims` (string) overlaps `special_fcst_dims`
        (
            utils_test_data.DA_RGB,
            utils_test_data.DA_R,
            utils_test_data.DA_G,
            "blue",
            None,
            "blue",
            "`reduce_dims` and `preserve_dims` must not contain any `special_fcst_dims`",
        ),
        # `preserve_dims` (list) overlaps `special_fcst_dims`
        (
            utils_test_data.DA_RGB,
            utils_test_data.DA_R,
            utils_test_data.DA_G,
            None,
            ["blue"],
            "blue",
            "`reduce_dims` and `preserve_dims` must not contain any `special_fcst_dims`",
        ),
        # `preserve_dims` includes a dimension expected to be absent from the data
        (
            utils_test_data.DA_RGB,
            utils_test_data.DA_R,
            utils_test_data.DA_G,
            None,
            ["blue", "yellow"],
            None,
            utils.ERROR_SPECIFIED_NONPRESENT_PRESERVE_DIMENSION,
        ),
        # `reduce_dims` names a dimension expected to be absent, with `special_fcst_dims` also given
        (
            utils_test_data.DA_RGB,
            utils_test_data.DA_R,
            utils_test_data.DA_G,
            "yellow",
            None,
            "blue",
            utils.ERROR_SPECIFIED_NONPRESENT_REDUCE_DIMENSION,
        ),
    ],
)
def test_gather_dimensions2_exceptions(
    fcst, obs, weights, reduce_dims, preserve_dims, special_fcst_dims, error_msg_snippet
):
    """
    Check that `gather_dimensions2` raises `ValueError` with the expected message
    for each invalid combination of arguments.
    """
    dim_kwargs = {
        "weights": weights,
        "reduce_dims": reduce_dims,
        "preserve_dims": preserve_dims,
        "special_fcst_dims": special_fcst_dims,
    }
    with pytest.raises(ValueError) as raised:
        gather_dimensions2(fcst, obs, **dim_kwargs)

    # The expected text only needs to appear somewhere in the raised message.
    message = str(raised.value)
    assert error_msg_snippet in message


def test_gather_dimensions2_warnings():
    """
    Check that `gather_dimensions2` emits a `UserWarning`, and still returns the
    expected dimensions, when "all" is passed as a string while the forecast
    also has a dimension literally named "all".
    """
    scenarios = (
        # preserving "all" leaves nothing to reduce
        ({"preserve_dims": "all"}, set()),
        # reducing "all" reduces every dimension of the data
        ({"reduce_dims": "all"}, {"red", "all"}),
    )
    for dim_kwargs, expected in scenarios:
        with pytest.warns(UserWarning):
            outcome = gather_dimensions2(
                utils_test_data.DA_R.rename({"red": "all"}), utils_test_data.DA_R, **dim_kwargs
            )
        assert outcome == expected


@pytest.mark.parametrize(
    ("fcst", "obs", "weights", "reduce_dims", "preserve_dims", "special_fcst_dims", "expected"),
    [
        # nothing specified: every dimension of the data is reduced
        (utils_test_data.DA_B, utils_test_data.DA_R, None, None, None, None, {"blue", "red"}),
        (utils_test_data.DA_B, utils_test_data.DA_R, utils_test_data.DA_G, None, None, None, {"blue", "red", "green"}),
        # `special_fcst_dims` (string or list) is excluded from the returned dimensions
        (utils_test_data.DA_B, utils_test_data.DA_R, utils_test_data.DA_G, None, None, "blue", {"red", "green"}),
        (utils_test_data.DA_B, utils_test_data.DA_R, utils_test_data.DA_G, None, None, ["blue"], {"red", "green"}),
        # `reduce_dims` given as "all", as a single string, and as a list
        (utils_test_data.DA_B, utils_test_data.DA_R, utils_test_data.DA_G, "all", None, None, {"blue", "red", "green"}),
        (utils_test_data.DA_B, utils_test_data.DA_R, utils_test_data.DA_G, "blue", None, None, {"blue"}),
        (utils_test_data.DA_B, utils_test_data.DA_R, utils_test_data.DA_G, ["blue"], None, None, {"blue"}),
        (utils_test_data.DA_RGB, utils_test_data.DA_R, utils_test_data.DA_G, ["green"], None, "blue", {"green"}),
        # `preserve_dims` given as "all", as a single string, and as a list
        (utils_test_data.DA_RGB, utils_test_data.DA_B, None, None, "all", "red", set()),
        (utils_test_data.DA_RGB, utils_test_data.DA_R, utils_test_data.DA_G, None, "green", None, {"red", "blue"}),
        (utils_test_data.DA_RGB, utils_test_data.DA_R, None, None, ["green"], None, {"red", "blue"}),
        (utils_test_data.DA_RGB, utils_test_data.DA_B, None, None, ["green"], "red", {"blue"}),
    ],
)
def test_gather_dimensions2_examples(fcst, obs, weights, reduce_dims, preserve_dims, special_fcst_dims, expected):
    """
    Check that `gather_dimensions2` returns the expected set of dimensions to
    reduce for valid combinations of `reduce_dims`, `preserve_dims` and
    `special_fcst_dims`.
    """
    result = gather_dimensions2(
        fcst,
        obs,
        weights=weights,
        reduce_dims=reduce_dims,
        preserve_dims=preserve_dims,
        special_fcst_dims=special_fcst_dims,
    )
    assert result == expected


def test_tmp_coord_name():
"""
Tests that `tmp_coord_name` returns as expected.
Expand Down

0 comments on commit 017aa10

Please sign in to comment.