diff --git a/src/scores/probability/crps_impl.py b/src/scores/probability/crps_impl.py index 3854a581c..ba3eab084 100644 --- a/src/scores/probability/crps_impl.py +++ b/src/scores/probability/crps_impl.py @@ -753,14 +753,22 @@ def crps_for_ensemble( where X, X' are independent samples of the predictive distribution F and Y is the observation (possibly unknown). Samples from F and Y are drawn from the fcst_sample_dim and obs_sample_dim respectively. Other dimensions are broadcast using xr broadcast rules. + + Args: + fcst: Forecast data. + obs: Observation data. + weights: Weights for calculating a weighted mean of scores. + reduce_dims: Dimensions to reduce. Can be "all" to reduce all dimensions. + preserve_dims: Dimensions to preserve. Can be "all" to preserve all dimensions. + special_fcst_dims: Dimension(s) in `fcst` that are reduced to calculate individual scores. + Must not appear as a dimension in `obs`, `weights`, `reduce_dims` or `preserve_dims`. + e.g. the ensemble member dimension if calculating CRPS for ensembles, or the + threshold dimension if calculating CRPS for CDFs. 
""" - # all_dims = - # if weight not None: - # all_dims = set(weights) - # if method not in ["ecdf", "fair"]: - # raise ValueError("`method` must be one of 'ecdf' or 'fair'") - # if ensemble_member_dim in obs.dims or (weight is not None and weights.dims): - # raise ValueError("`ensemble_member_dim` cannot be a dimension of `obs` or `weights`") + if method not in ["ecdf", "fair"]: + raise ValueError("`method` must be one of 'ecdf' or 'fair'") + + dims_for_mean = scores.utils.gather_dimensions2(fcst, obs, weights, reduce_dims, preserve_dims, ensemble_member_dim) ensemble_member_dim1 = scores.utils.tmp_coord_name(fcst) @@ -779,9 +787,6 @@ def crps_for_ensemble( result = fcst_obs_term - fcst_spread_term # apply weights and take means across specified dims - fcst_dims = [x for x in fcst.dims if x != ensemble_member_dim] - reduce_dims = scores.utils.gather_dimensions(fcst_dims, obs.dims, reduce_dims, preserve_dims) # type: ignore[assignment] - result = scores.functions.apply_weights(result, weights) - result = result.mean(dim=reduce_dims) + result = scores.functions.apply_weights(result, weights).mean(dim=dims_for_mean) return result diff --git a/src/scores/utils.py b/src/scores/utils.py index 259689d86..c1d2e6b03 100644 --- a/src/scores/utils.py +++ b/src/scores/utils.py @@ -112,34 +112,39 @@ def gather_dimensions( # pylint: disable=too-many-branches return reduce_dims -def gather_dimensions2(fcst, obs, weights=None, reduce_dims=None, preserve_dims=None, special_fcst_dims=None): +def gather_dimensions2( + fcst: XarrayLike, + obs: XarrayLike, + weights: XarrayLike = None, + reduce_dims: FlexibleDimensionTypes = None, + preserve_dims: FlexibleDimensionTypes = None, + special_fcst_dims: FlexibleDimensionTypes = None, +) -> set[Hashable]: """ - Performs a standard dimensions check for a function that calculates (mean) scores. - Returns a list of the dimensions to reduce. + Performs a standard dimensions check for inputs of a function that calculates (mean) scores. 
+ Returns a set of the dimensions to reduce. - special_fcst_dims are dims in fcst that will be collapsed while calculating individual scores - (e.g. the threshold dimension of a CDF, or the ensemble member dimesnion, when calculating CRPS) - and is never present for the step of calculating mean scores. - - Checks that: - - reduce_dims and preserve_dims are both not specified - - specified dims (reduce_dims or preserve_dims) are a subset of fcst.dims, - obs.dims and (if not None) weight.dims. - - special_fcst_dims are not in obs.dims or weights.dims - - specified dims with a value of "all" is handled correctly - - specified dims with a string value is handled correctly + Args: + fcst: Forecast data + obs: Observation data + weights: Weights for calculating a weighted mean of scores + reduce_dims: Dimensions to reduce. Can be "all" to reduce all dimensions. + preserve_dims: Dimensions to preserve. Can be "all" to preserve all dimensions. + special_fcst_dims: Dimension(s) in `fcst` that are reduced to calculate individual scores. + Must not appear as a dimension in `obs`, `weights`, `reduce_dims` or `preserve_dims`. + e.g. the ensemble member dimension if calculating CRPS for ensembles, or the + threshold dimension if calculating CRPS for CDFs. Returns: - list of dimensions over which to take the mean once the checks are passed. + Set of dimensions over which to take the mean once the checks are passed. Raises: - ValueError: - - when `preserve_dims and `reduce_dims` are both specified - - when `special_fcst_dims` is not a subset of `fcst.dims` - - when `obs.dims`, `weights.dims`, `reduce_dims` or `preserve_dims` - contains elements from `special_fcst_dims` - - when `preserve_dims and `reduce_dims` contain elements not among dimensions - of the data (`fcst`, `obs` or `weights`) + ValueError: when `preserve_dims` and `reduce_dims` are both specified. + ValueError: when `special_fcst_dims` is not a subset of `fcst.dims`. 
+ ValueError: when `obs.dims`, `weights.dims`, `reduce_dims` or `preserve_dims` + contains elements from `special_fcst_dims`. + ValueError: when `preserve_dims` or `reduce_dims` contains elements not among dimensions + of the data (`fcst`, `obs` or `weights`). """ all_data_dims = set(fcst.dims).union(set(obs.dims)) if weights is not None: diff --git a/tests/probabilty/test_crps.py b/tests/probabilty/test_crps.py index 504f85ae4..1098222a7 100644 --- a/tests/probabilty/test_crps.py +++ b/tests/probabilty/test_crps.py @@ -637,8 +637,15 @@ def test_crps_for_ensemble(): result_ecdf = crps_for_ensemble(fcst, obs, "ens_member", method="ecdf", preserve_dims="all") result_fair = crps_for_ensemble(fcst, obs, "ens_member", method="fair", preserve_dims="all") - result_weighted_mean = crps_for_ensemble(fcst, obs, "ens_member", method="ecdf", preserve_dims=None, weights=weight) + result_weighted_mean = crps_for_ensemble(fcst, obs, "ens_member", method="ecdf", weights=weight) assert_dataarray_equal(result_ecdf, expected_ecdf, decimals=7) assert_dataarray_equal(result_fair, expected_fair, decimals=7) assert_dataarray_equal(result_weighted_mean, expected_weighted_mean, decimals=7) + + +def test_crps_for_ensemble_raises(): + """Tests `crps_for_ensemble` raises exception as expected.""" + with pytest.raises(ValueError) as excinfo: + crps_for_ensemble(xr.DataArray(data=[1]), xr.DataArray(data=[1]), "ens_member", "unfair") + assert "`method` must be one of 'ecdf' or 'fair'" in str(excinfo.value)