Skip to content

Commit

Permalink
ROC tutorial notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasloveday committed Oct 31, 2023
1 parent c4acfe6 commit 9e50fcb
Show file tree
Hide file tree
Showing 7 changed files with 1,458 additions and 32 deletions.
26 changes: 12 additions & 14 deletions src/scores/categorical/binary_impl.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
"""
This module contains methods for binary categories
"""
from typing import Optional, Sequence
from typing import Optional

import numpy as np
import xarray as xr

from scores.functions import apply_weights
from scores.processing import check_binary
from scores.typing import FlexibleDimensionTypes, XarrayLike
from scores.utils import gather_dimensions


def probability_of_detection(
fcst: xr.DataArray,
obs: xr.DataArray,
reduce_dims: Optional[Sequence[str]] = None,
preserve_dims: Optional[Sequence[str]] = None,
fcst: XarrayLike,
obs: XarrayLike,
reduce_dims: FlexibleDimensionTypes = None,
preserve_dims: FlexibleDimensionTypes = None,
weights: Optional[xr.DataArray] = None,
check_args: Optional[bool] = True,
) -> xr.DataArray:
) -> XarrayLike:
"""
Calculates the Probability of Detection (POD), also known as the Hit Rate.
This is the proportion of observed events (obs = 1) that were correctly
Expand Down Expand Up @@ -75,19 +76,17 @@ def probability_of_detection(
hits = hits.sum(dim=dims_to_sum)

pod = hits / (hits + misses)

pod.name = "ctable_probability_of_detection"
return pod


def probability_of_false_detection(
fcst: xr.DataArray,
obs: xr.DataArray,
reduce_dims: Optional[Sequence[str]] = None,
preserve_dims: Optional[Sequence[str]] = None,
fcst: XarrayLike,
obs: XarrayLike,
reduce_dims: FlexibleDimensionTypes = None,
preserve_dims: FlexibleDimensionTypes = None,
weights: Optional[xr.DataArray] = None,
check_args: Optional[bool] = True,
) -> xr.DataArray:
) -> XarrayLike:
"""
Calculates the Probability of False Detection (POFD), also known as False
Alarm Rate (not to be confused with the False Alarm Ratio). The POFD is
Expand Down Expand Up @@ -144,5 +143,4 @@ def probability_of_false_detection(
correct_negatives = correct_negatives.sum(dim=dims_to_sum)

pofd = false_alarms / (false_alarms + correct_negatives)
pofd.name = "ctable_probability_of_false_detection"
return pofd
3 changes: 2 additions & 1 deletion src/scores/probability/roc_impl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""
Implementation of Reciever Operating Characteristic (ROC) calculations
"""
from typing import Iterable, Optional, Sequence
from collections.abc import Iterable, Sequence
from typing import Optional

import numpy as np
import xarray as xr
Expand Down
43 changes: 28 additions & 15 deletions src/scores/processing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Tools for processing data for verification"""
import operator
from typing import Union
from typing import Optional, Union

import numpy as np
import pandas as pd
import xarray as xr

from scores.typing import FlexibleDimensionTypes, XarrayLike

INEQUALITY_MODES = {
">=": (operator.ge, -1),
">": (operator.gt, 1),
Expand All @@ -17,7 +19,7 @@
EQUALITY_MODES = {"==": (operator.le), "!=": (operator.gt)}


def check_binary(data: Union[xr.DataArray, xr.Dataset], name: str):
def check_binary(data: XarrayLike, name: str):
"""
Checks that data does not have any non-NaN values out of the set {0, 1}
Expand All @@ -27,25 +29,30 @@ def check_binary(data: Union[xr.DataArray, xr.Dataset], name: str):
ValueError: if there are values in `fcst` and `obs` that are not in the
set {0, 1, np.nan} and `check_args` is true.
"""
unique_values = pd.unique(data.values.flatten())
if isinstance(data, xr.DataArray):
unique_values = pd.unique(data.values.flatten())
else:
unique_values = pd.unique(data.to_array().values.flatten())
unique_values = unique_values[~np.isnan(unique_values)]
binary_set = {0, 1}

if not set(unique_values).issubset(binary_set):
raise ValueError(f"`{name}` contains values that are not in the set {{0, 1, np.nan}}")


def comparative_discretise(data, comparison, mode, abs_tolerance=None):
def comparative_discretise(
data: XarrayLike, comparison: Union[xr.DataArray, float, int], mode: str, abs_tolerance: Optional[float] = None
) -> XarrayLike:
"""
Converts the values of `data` to 0 or 1 based on how they relate to the specified
values in `comparison` via the `mode` operator.
Args:
data (xarray.DataArray or xarray.Dataset): The data to convert to
data: The data to convert to
discrete values.
comparison (xarray.DataArray, float or int): The values to which
comparison: The values to which
to compare `data`.
mode (str): Specifies the required relation of `data` to `thresholds`
mode: Specifies the required relation of `data` to `thresholds`
for a value to fall in the 'event' category (i.e. assigned to 1).
Allowed modes are:
- '>=' values in `data` greater than or equal to the
Expand All @@ -60,7 +67,7 @@ def comparative_discretise(data, comparison, mode, abs_tolerance=None):
are assigned as 1
- '!=' values in `data` not equal to the corresponding threshold
are assigned as 1.
abs_tolerance (Optional[float]): If supplied, values in data that are
abs_tolerance: If supplied, values in data that are
within abs_tolerance of a threshold are considered to be equal to
that threshold. This is generally used to correct for floating
point rounding, e.g. we may want to consider 1.0000000000000002 as
Expand Down Expand Up @@ -108,17 +115,23 @@ def comparative_discretise(data, comparison, mode, abs_tolerance=None):
return discrete_data


def binary_discretise(data, thresholds, mode, abs_tolerance=None, autosqueeze=False):
def binary_discretise(
data: XarrayLike,
thresholds: FlexibleDimensionTypes,
mode: str,
abs_tolerance: Optional[float] = None,
autosqueeze: Optional[bool] = False,
):
"""
Converts the values of `data` to 0 or 1 for each threshold in `thresholds`
according to the operation defined by `mode`.
Args:
data (xarray.DataArray or xarray.Dataset): The data to convert to
data: The data to convert to
discrete values.
thresholds (float or iterable): Threshold(s) at which to convert the
thresholds: Threshold(s) at which to convert the
values of `data` to 0 or 1.
mode (str): Specifies the required relation of `data` to `thresholds`
mode: Specifies the required relation of `data` to `thresholds`
for a value to fall in the 'event' category (i.e. assigned to 1).
Allowed modes are:
Expand All @@ -135,13 +148,13 @@ def binary_discretise(data, thresholds, mode, abs_tolerance=None, autosqueeze=Fa
- '!=' values in `data` not equal to the corresponding threshold
are assigned as 1.
abs_tolerance (Optional[float]): If supplied, values in data that are
abs_tolerance: If supplied, values in data that are
within abs_tolerance of a threshold are considered to be equal to
that threshold. This is generally used to correct for floating
point rounding, e.g. we may want to consider 1.0000000000000002 as
equal to 1
autosqueeze (Optional[bool]): If True and only one threshold is
autosqueeze: If True and only one threshold is
supplied, then the dimension 'threshold' is squeezed out of the
output. If `thresholds` is float-like, then this is forced to
True, otherwise defaults to False.
Expand Down Expand Up @@ -184,7 +197,7 @@ def binary_discretise(data, thresholds, mode, abs_tolerance=None, autosqueeze=Fa
return discrete_data


def broadcast_and_match_nan(*args: Union[xr.DataArray, xr.Dataset]):
def broadcast_and_match_nan(*args: XarrayLike) -> XarrayLike:
"""
Input xarray data objects are 'matched' - they are broadcast against each
other (forced to have the same dimensions), and the position of nans are
Expand Down
16 changes: 16 additions & 0 deletions tests/categorical/test_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@
(fcst_mix, obs1, "a", True, None, expected_poda), # Fcst mix, obs ones, only reduce one dim
(fcst_bad, obs0, None, False, None, expected_pod0), # Don't check for bad data
(fcst_mix, obs1, None, True, weight_array, expected_pod_weighted), # Fcst mixed, obs ones, with weights
(
xr.Dataset({"array1": fcst0, "array2": fcst1}),
xr.Dataset({"array1": obs0, "array2": obs1}),
None,
True,
None,
xr.Dataset({"array1": expected_pod0, "array2": expected_pod1}),
), # Test with DataSet for inputs
],
)
def test_probability_of_detection(fcst, obs, reduce_dims, check_args, weights, expected):
Expand Down Expand Up @@ -90,6 +98,14 @@ def test_probability_of_detection_raises(fcst, obs, error_msg):
(fcst_mix, obs0, "a", True, None, expected_poda), # Fcst mix, obs ones, only reduce one dim
(fcst_bad, obs0, None, False, None, expected_pofd0), # Don't check for bad data
(fcst_mix, obs0, None, True, weight_array, expected_pofd_weighted), # Fcst mixed, obs ones, with weights
(
xr.Dataset({"array1": fcst0, "array2": fcst1}),
xr.Dataset({"array1": obs0, "array2": obs1}),
None,
True,
None,
xr.Dataset({"array1": expected_pofd0, "array2": expected_pofd1}),
), # Test with DataSet for inputs
],
)
def test_probability_of_false_detection(fcst, obs, reduce_dims, check_args, weights, expected):
Expand Down
2 changes: 1 addition & 1 deletion tests/probabilty/test_roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import dask
import numpy as np
import pytest
import roc_test_data as rtd
import xarray as xr

from scores.probability import roc_curve_data
from tests.probabilty import roc_test_data as rtd


@pytest.mark.parametrize(
Expand Down
3 changes: 2 additions & 1 deletion tests/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import pytest
import xarray as xr

from tests import test_processing_data as xtd
from scores.processing import (
binary_discretise,
broadcast_and_match_nan,
check_binary,
comparative_discretise,
)
from tests import test_processing_data as xtd


@pytest.mark.parametrize(
("args", "expected"),
Expand Down
Loading

0 comments on commit 9e50fcb

Please sign in to comment.