Skip to content

Commit

Permalink
continuous.standard metrics now handle angular data (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasloveday authored and tennlee committed Jan 19, 2024
1 parent ebf3aab commit 6ebf427
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 8 deletions.
43 changes: 35 additions & 8 deletions src/scores/continuous/standard_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@
import scores.utils
from scores.typing import FlexibleArrayType, FlexibleDimensionTypes

import xarray as xr


def mse(
fcst: FlexibleArrayType,
obs: FlexibleArrayType,
reduce_dims: FlexibleDimensionTypes = None,
preserve_dims: FlexibleDimensionTypes = None,
weights=None,
weights: xr.DataArray = None,
angular: bool = False,
):
"""Calculates the mean squared error from forecast and observed data.
Expand All @@ -39,6 +42,12 @@ def mse(
must match precisely.
weights: Optionally provide an array for weighted averaging (e.g. by area, by latitude,
by population, custom)
angular: specifies whether `fcst` and `obs` are angular
data (e.g. wind direction). If True, a different function is used
to calculate the difference between `fcst` and `obs`, which
accounts for circularity. Angular `fcst` and `obs` data should be in
degrees rather than radians.
Returns:
Union[xr.Dataset, xr.DataArray, pd.Dataframe, pd.Series]: An object containing
Expand All @@ -47,8 +56,10 @@ def mse(
Otherwise: Returns an object representing the mean squared error,
reduced along the relevant dimensions and weighted appropriately.
"""

error = fcst - obs
if angular:
error = scores.functions.angular_difference(fcst, obs)
else:
error = fcst - obs
squared = error * error
squared = scores.functions.apply_weights(squared, weights)

Expand All @@ -70,7 +81,8 @@ def rmse(
obs: FlexibleArrayType,
reduce_dims: FlexibleDimensionTypes = None,
preserve_dims: FlexibleDimensionTypes = None,
weights=None,
weights: xr.DataArray = None,
angular: bool = False,
) -> FlexibleArrayType:
"""Calculate the Root Mean Squared Error from xarray or pandas objects.
Expand Down Expand Up @@ -100,6 +112,11 @@ def rmse(
must match precisely.
weights: Optionally provide an array for weighted averaging (e.g. by area, by latitude,
by population, custom)
angular: specifies whether `fcst` and `obs` are angular
data (e.g. wind direction). If True, a different function is used
to calculate the difference between `fcst` and `obs`, which
accounts for circularity. Angular `fcst` and `obs` data should be in
degrees rather than radians.
Returns:
An object containing
Expand All @@ -109,7 +126,9 @@ def rmse(
reduced along the relevant dimensions and weighted appropriately.
"""
_mse = mse(fcst=fcst, obs=obs, reduce_dims=reduce_dims, preserve_dims=preserve_dims, weights=weights)
_mse = mse(
fcst=fcst, obs=obs, reduce_dims=reduce_dims, preserve_dims=preserve_dims, weights=weights, angular=angular
)

_rmse = pow(_mse, (1 / 2))

Expand All @@ -121,7 +140,8 @@ def mae(
obs: FlexibleArrayType,
reduce_dims: FlexibleDimensionTypes = None,
preserve_dims: FlexibleDimensionTypes = None,
weights=None,
weights: xr.DataArray = None,
angular: bool = False,
) -> FlexibleArrayType:
"""Calculates the mean absolute error from forecast and observed data.
Expand All @@ -146,6 +166,11 @@ def mae(
forecast and observed dimensions must match precisely.
weights: Optionally provide an array for weighted averaging (e.g. by area, by latitude,
by population, custom)
angular: specifies whether `fcst` and `obs` are angular
data (e.g. wind direction). If True, a different function is used
to calculate the difference between `fcst` and `obs`, which
accounts for circularity. Angular `fcst` and `obs` data should be in
degrees rather than radians.
Returns:
By default an xarray DataArray containing
Expand All @@ -155,8 +180,10 @@ def mae(
Alternatively, an xarray structure with dimensions preserved as appropriate
containing the score along reduced dimensions
"""

error = fcst - obs
if angular:
error = scores.functions.angular_difference(fcst, obs)
else:
error = fcst - obs
ae = abs(error)
ae = scores.functions.apply_weights(ae, weights)

Expand Down
51 changes: 51 additions & 0 deletions tests/continuous/test_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ def test_xarray_dimension_preservations_with_arrays():
).chunk()


# Dask tests
def test_mse_with_dask():
"""
Test that mse works with dask
Expand Down Expand Up @@ -498,3 +499,53 @@ def test_rmse_with_dask():
assert isinstance(result.data, np.ndarray)
expected = xr.DataArray(data=[np.sqrt(5), 2], dims=["dim2"], coords={"dim2": [1, 2]})
xr.testing.assert_equal(result, expected)


# Angular / directional tests
DA1_ANGULAR = xr.DataArray([[10, 10], [90, 90]], coords=[[0, 1], [0, 1]], dims=["i", "j"])
DA2_ANGULAR = xr.DataArray([[350, 180], [270, 280]], coords=[[0, 1], [0, 1]], dims=["i", "j"])


def test_mse_angular():
"""Tests that `mse` returns the expected object with `angular` is True"""

expected = xr.DataArray(
[[20**2, 170**2], [180**2, 170**2]],
coords=[[0, 1], [0, 1]],
dims=["i", "j"],
name="mean_squared_error",
)

result = scores.continuous.mse(DA1_ANGULAR, DA2_ANGULAR, preserve_dims=["i", "j"], angular=True)

xr.testing.assert_equal(result, expected)


def test_mae_angular():
"""Tests that `mae` returns the expected object with `angular` is True"""

expected = xr.DataArray(
[[20, 170], [180, 170]],
coords=[[0, 1], [0, 1]],
dims=["i", "j"],
name="mean_squared_error",
)

result = scores.continuous.mae(DA1_ANGULAR, DA2_ANGULAR, preserve_dims=["i", "j"], angular=True)

xr.testing.assert_equal(result, expected)


def test_rmse_angular():
"""Tests that `rmse` returns the expected object with `angular` is True"""

expected = xr.DataArray(
[((20**2 + 170**2) / 2) ** 0.5, ((180**2 + 170**2) / 2) ** 0.5],
coords={"i": [0, 1]},
dims=["i"],
name="mean_squared_error",
)

result = scores.continuous.rmse(DA1_ANGULAR, DA2_ANGULAR, preserve_dims=["i"], angular=True)

xr.testing.assert_equal(result, expected)

0 comments on commit 6ebf427

Please sign in to comment.