Skip to content

Commit

Permalink
can now choose if thresholds for firm are left or right aligned
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasloveday committed Dec 6, 2023
1 parent ee69685 commit 6dd7f45
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 30 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/run-pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,6 @@ jobs:
pip install .[all]
- name: pre-commit checks
run: |
pre-commit run --all-files
pre-commit run black --all-files
pre-commit run isort --all-files
pre-commit run pylint --all-files
46 changes: 29 additions & 17 deletions src/scores/categorical/multicategorical_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import xarray as xr

from scores.functions import apply_weights
from scores.typing import FlexibleDimensionTypes
from scores.utils import check_dims, gather_dimensions


Expand All @@ -18,9 +19,10 @@ def firm( # pylint: disable=too-many-arguments
categorical_thresholds: Sequence[float],
threshold_weights: Sequence[Union[float, xr.DataArray]],
discount_distance: Optional[float] = 0,
reduce_dims: Optional[Sequence[str]] = None,
preserve_dims: Optional[Sequence[str]] = None,
reduce_dims: FlexibleDimensionTypes = None,
preserve_dims: FlexibleDimensionTypes = None,
weights: Optional[xr.DataArray] = None,
threshold_assignment: Optional[str] = "lower",
) -> xr.Dataset:
"""
Calculates the FIxed Risk Multicategorical (FIRM) score including the
Expand Down Expand Up @@ -64,6 +66,9 @@ def firm( # pylint: disable=too-many-arguments
supplied. The default behaviour if neither are supplied is to reduce all dims.
weights: Optionally provide an array for weighted averaging (e.g. by area, by latitude,
by population, custom)
threshold_assignment: Specifies whether the intervals defining the categories are
left or right closed. Must be either "upper" or "lower". With "lower" (the default) a
value equal to a decision threshold belongs to the lower category (right closed
intervals); with "upper" it belongs to the upper category (left closed intervals).
Returns:
An xarray Dataset with data vars:
Expand Down Expand Up @@ -98,15 +103,13 @@ def firm( # pylint: disable=too-many-arguments
Journal of the Royal Meteorological Society, 148(744), pp.1389-1406.
"""
_check_firm_inputs(
obs,
risk_parameter,
categorical_thresholds,
threshold_weights,
discount_distance,
obs, risk_parameter, categorical_thresholds, threshold_weights, discount_distance, threshold_assignment
)
total_score = []
for categorical_threshold, weight in zip(categorical_thresholds, threshold_weights):
score = weight * _single_category_score(fcst, obs, risk_parameter, categorical_threshold, discount_distance)
score = weight * _single_category_score(
fcst, obs, risk_parameter, categorical_threshold, discount_distance, threshold_assignment
)
total_score.append(score)
summed_score = sum(total_score)
reduce_dims = gather_dimensions(fcst.dims, obs.dims, reduce_dims, preserve_dims) # type: ignore[assignment]
Expand All @@ -117,11 +120,7 @@ def firm( # pylint: disable=too-many-arguments


def _check_firm_inputs(
obs,
risk_parameter,
categorical_thresholds,
threshold_weights,
discount_distance,
obs, risk_parameter, categorical_thresholds, threshold_weights, discount_distance, threshold_assignment
):
"""
Checks that the FIRM inputs are suitable
Expand Down Expand Up @@ -150,13 +149,17 @@ def _check_firm_inputs(
if discount_distance < 0:
raise ValueError("`discount_distance` must be >= 0")

if threshold_assignment not in ["upper", "lower"]:
raise ValueError(""" `threshold_assignment` must be either \"upper\" or \"lower\" """)


def _single_category_score(
fcst: xr.DataArray,
obs: xr.DataArray,
risk_parameter: float,
categorical_threshold: float,
discount_distance: Optional[float] = None,
threshold_assignment: Optional[str] = "lower",
) -> xr.Dataset:
"""
Calculates the score for a single category for the `firm` metric at each
Expand All @@ -175,6 +178,9 @@ def _single_category_score(
discounted whenever the observation is within distance
`discount_distance` of the forecast category. A value of 0
will not apply any discounting.
threshold_assignment: Specifies whether the intervals defining the categories are
left or right closed. With "lower" (the default) a value equal to
`categorical_threshold` belongs to the lower category (right closed intervals);
with any other value (callers validate it as "upper") it belongs to the upper
category (left closed intervals).
Returns:
An xarray Dataset with data vars:
Expand All @@ -187,10 +193,16 @@ def _single_category_score(
# pylint: disable=unbalanced-tuple-unpacking
fcst, obs = xr.align(fcst, obs)

# False Alarms
condition1 = (obs <= categorical_threshold) & (categorical_threshold < fcst)
# Misses
condition2 = (fcst <= categorical_threshold) & (categorical_threshold < obs)
if threshold_assignment == "lower":
# False Alarms
condition1 = (obs <= categorical_threshold) & (categorical_threshold < fcst)
# Misses
condition2 = (fcst <= categorical_threshold) & (categorical_threshold < obs)
else:
# False Alarms
condition1 = (obs < categorical_threshold) & (categorical_threshold <= fcst)
# Misses
condition2 = (fcst < categorical_threshold) & (categorical_threshold <= obs)

# Bring back NaNs
condition1 = condition1.where(~np.isnan(fcst))
Expand Down
75 changes: 75 additions & 0 deletions tests/categorical/multicategorical_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@
"k": [10, 11, 12],
},
)
# 1-D forecast values along dim "i", used by the threshold-assignment cases of
# test__single_category_score (paired with DA_OBS_SC2, threshold 2).
# Entries i=4 and i=5 sit exactly on that threshold, so their category depends
# on `threshold_assignment`.
DA_FCST_SC2 = xr.DataArray(
data=[3, 3, 1, 2, 2, 1],
dims=["i"],
coords={
"i": [1, 2, 3, 4, 5, 6],
},
)

DA_OBS_SC = xr.DataArray(
data=[[10, np.nan], [0, 1]],
Expand All @@ -23,6 +30,14 @@
},
)

# 1-D observations along dim "i", used by the threshold-assignment cases of
# test__single_category_score (paired with DA_FCST_SC2, threshold 2).
# Entries i=2 and i=6 sit exactly on that threshold, so their category depends
# on `threshold_assignment`.
DA_OBS_SC2 = xr.DataArray(
data=[1, 2, 3, 4, 1, 2],
dims=["i"],
coords={
"i": [1, 2, 3, 4, 5, 6],
},
)

EXP_SC_TOTAL_CASE0 = xr.DataArray(
data=[[[np.nan, 0.3], [0.7, np.nan]], [[0.0, 0], [0, np.nan]]],
dims=["i", "j", "k"],
Expand Down Expand Up @@ -147,6 +162,66 @@
)


# Expected _single_category_score output for DA_FCST_SC2 vs DA_OBS_SC2 with
# categorical threshold 2 and threshold_assignment="lower" (the test uses
# risk_parameter 0.7).
# With "lower", a value equal to the threshold belongs to the lower category:
#   i=1, i=2: obs <= 2 < fcst  -> overforecast penalty
#   i=3, i=4: fcst <= 2 < obs  -> underforecast penalty
#   i=5, i=6: fcst and obs are both <= 2 -> no penalty
EXP_SC_TOTAL_CASE4 = xr.DataArray(
data=[0.3, 0.3, 0.7, 0.7, 0, 0],
dims=["i"],
coords={
"i": [1, 2, 3, 4, 5, 6],
},
)
# Underforecast (miss) component of the case above.
EXP_SC_UNDER_CASE4 = xr.DataArray(
data=[0, 0, 0.7, 0.7, 0, 0],
dims=["i"],
coords={
"i": [1, 2, 3, 4, 5, 6],
},
)
# Overforecast (false alarm) component of the case above.
EXP_SC_OVER_CASE4 = xr.DataArray(
data=[0.3, 0.3, 0, 0, 0, 0],
dims=["i"],
coords={
"i": [1, 2, 3, 4, 5, 6],
},
)

# Dataset bundling the three expected arrays under the data variable names
# the test compares against.
EXP_SC_CASE4 = xr.Dataset(
{
"firm_score": EXP_SC_TOTAL_CASE4,
"underforecast_penalty": EXP_SC_UNDER_CASE4,
"overforecast_penalty": EXP_SC_OVER_CASE4,
}
)

# Expected _single_category_score output for DA_FCST_SC2 vs DA_OBS_SC2 with
# categorical threshold 2 and threshold_assignment="upper" (the test uses
# risk_parameter 0.7).
# With "upper", a value equal to the threshold belongs to the upper category:
#   i=1, i=5: obs < 2 <= fcst  -> overforecast penalty
#   i=3, i=6: fcst < 2 <= obs  -> underforecast penalty
#   i=2, i=4: fcst and obs are both >= 2 -> no penalty
EXP_SC_TOTAL_CASE5 = xr.DataArray(
data=[0.3, 0, 0.7, 0.0, 0.3, 0.7],
dims=["i"],
coords={
"i": [1, 2, 3, 4, 5, 6],
},
)
# Underforecast (miss) component of the case above.
EXP_SC_UNDER_CASE5 = xr.DataArray(
data=[0, 0, 0.7, 0, 0, 0.7],
dims=["i"],
coords={
"i": [1, 2, 3, 4, 5, 6],
},
)
# Overforecast (false alarm) component of the case above.
EXP_SC_OVER_CASE5 = xr.DataArray(
data=[0.3, 0.0, 0, 0, 0.3, 0],
dims=["i"],
coords={
"i": [1, 2, 3, 4, 5, 6],
},
)

# Dataset bundling the three expected arrays under the data variable names
# the test compares against.
EXP_SC_CASE5 = xr.Dataset(
{
"firm_score": EXP_SC_TOTAL_CASE5,
"underforecast_penalty": EXP_SC_UNDER_CASE5,
"overforecast_penalty": EXP_SC_OVER_CASE5,
}
)

DA_FCST_FIRM = xr.DataArray(
data=[
[[np.nan, 7, 4], [-100, 0, 1], [0, -100, 1]],
Expand Down
50 changes: 38 additions & 12 deletions tests/categorical/test_multicategorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,32 @@


@pytest.mark.parametrize(
("fcst", "obs", "categorical_threshold", "discount_distance", "expected"),
("fcst", "obs", "categorical_threshold", "discount_distance", "threshold_assignment", "expected"),
[
# Threshold 5, discount = 0, preserve all dims
(mtd.DA_FCST_SC, mtd.DA_OBS_SC, 5, 0, mtd.EXP_SC_CASE0),
(mtd.DA_FCST_SC, mtd.DA_OBS_SC, 5, 0, "lower", mtd.EXP_SC_CASE0),
# Threshold -200, discount = 0, preserve 1 dim
(mtd.DA_FCST_SC, mtd.DA_OBS_SC, -200, 0, mtd.EXP_SC_CASE1),
(mtd.DA_FCST_SC, mtd.DA_OBS_SC, -200, 0, "lower", mtd.EXP_SC_CASE1),
# Threshold 200, discount = 0, preserve 1 dim
(mtd.DA_FCST_SC, mtd.DA_OBS_SC, 200, 0, mtd.EXP_SC_CASE1),
(mtd.DA_FCST_SC, mtd.DA_OBS_SC, 200, 0, "lower", mtd.EXP_SC_CASE1),
# Threshold 5, discount = 7, preserve all dims.
# discount_distance is maximum for both false alarms and misses
(mtd.DA_FCST_SC, mtd.DA_OBS_SC, 5, 7, mtd.EXP_SC_CASE2),
(mtd.DA_FCST_SC, mtd.DA_OBS_SC, 5, 7, "lower", mtd.EXP_SC_CASE2),
# Threshold 5, discount = 0.5, preserve all dims.
# discount_distance is minimum for both false alarms and misses
(mtd.DA_FCST_SC, mtd.DA_OBS_SC, 5, 0.5, mtd.EXP_SC_CASE3),
(mtd.DA_FCST_SC, mtd.DA_OBS_SC, 5, 0.5, "lower", mtd.EXP_SC_CASE3),
# Test lower/right assignment
(mtd.DA_FCST_SC2, mtd.DA_OBS_SC2, 2, None, "lower", mtd.EXP_SC_CASE4),
# Test upper/left assignment
(mtd.DA_FCST_SC2, mtd.DA_OBS_SC2, 2, None, "upper", mtd.EXP_SC_CASE5),
],
)
def test__single_category_score(fcst, obs, categorical_threshold, discount_distance, expected):
def test__single_category_score(fcst, obs, categorical_threshold, discount_distance, threshold_assignment, expected):
"""Tests _single_category_score"""
risk_parameter = 0.7

calculated = _single_category_score(
fcst,
obs,
risk_parameter,
categorical_threshold,
discount_distance,
fcst, obs, risk_parameter, categorical_threshold, discount_distance, threshold_assignment
)
xr.testing.assert_allclose(calculated, expected)

Expand Down Expand Up @@ -286,6 +286,7 @@ def test_firm_dask():
"weights",
"preserve_dims",
"discount_distance",
"threshold_assignment",
"error_type",
"error_msg_snippet",
),
Expand All @@ -299,6 +300,7 @@ def test_firm_dask():
[],
["i", "j", "k"],
0,
"upper",
ValueError,
"`categorical_thresholds` must have at least",
),
Expand All @@ -311,6 +313,7 @@ def test_firm_dask():
[1, 2],
["i", "j", "k"],
0,
"upper",
ValueError,
"`categorical_thresholds` and `weights`",
),
Expand All @@ -323,6 +326,7 @@ def test_firm_dask():
[1],
["i", "j", "k"],
0,
"upper",
ValueError,
"0 < `risk_parameter` < 1 must",
),
Expand All @@ -335,6 +339,7 @@ def test_firm_dask():
[1],
["i", "j", "k"],
0,
"upper",
ValueError,
"0 < `risk_parameter` < 1 must",
),
Expand All @@ -347,6 +352,7 @@ def test_firm_dask():
[1, -1],
["i", "j", "k"],
0,
"upper",
ValueError,
"`weights` must be > 0",
),
Expand All @@ -359,6 +365,7 @@ def test_firm_dask():
mtd.LIST_WEIGHTS_FIRM3,
["i", "j", "k"],
0,
"upper",
ValueError,
"value was found in index 0 of `weights",
),
Expand All @@ -371,6 +378,7 @@ def test_firm_dask():
[1, 0],
["i", "j", "k"],
0,
"upper",
ValueError,
"`weights` must be > 0",
),
Expand All @@ -383,6 +391,7 @@ def test_firm_dask():
mtd.LIST_WEIGHTS_FIRM4,
["i", "j", "k"],
0,
"upper",
ValueError,
"No values <= 0 are allowed in `weights`",
),
Expand All @@ -395,6 +404,7 @@ def test_firm_dask():
mtd.LIST_WEIGHTS_FIRM5,
["i", "j", "k"],
0,
"upper",
DimensionError,
"of data object are not subset to",
),
Expand All @@ -407,9 +417,23 @@ def test_firm_dask():
[1],
["i", "j", "k"],
-1,
"upper",
ValueError,
"`discount_distance` must be >= 0",
),
# wrong threshold assignment
(
mtd.DA_FCST_FIRM,
mtd.DA_OBS_FIRM,
0.5,
[5],
[1],
["i", "j", "k"],
0.0,
"up",
ValueError,
""" `threshold_assignment` must be either \"upper\" or \"lower\" """,
),
],
)
def test_firm_raises(
Expand All @@ -420,6 +444,7 @@ def test_firm_raises(
weights,
preserve_dims,
discount_distance,
threshold_assignment,
error_type,
error_msg_snippet,
):
Expand All @@ -436,4 +461,5 @@ def test_firm_raises(
discount_distance,
None,
preserve_dims,
threshold_assignment=threshold_assignment,
)

0 comments on commit 6dd7f45

Please sign in to comment.