
Merge pull request #330 from DHI/scoreable
Ensure Comparer / ComparerCollection consistency
jsmariegaard authored Dec 15, 2023
2 parents 832f3a6 + dd85573 commit cf765a0
Showing 7 changed files with 332 additions and 277 deletions.
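
The thread running through these changes: Comparer and ComparerCollection now share a common Scoreable interface, and score() returns a per-model dict on both. A minimal sketch of the unified surface (the obs/mod inputs and the illustrative values are placeholders, not from this PR):

    import modelskill as ms

    cc = ms.match(obs, mod)       # -> ComparerCollection, per the docstrings below
    cc.score(metric="rmse")       # {"model1": 0.23, ...}: one value per model
    cc[0].score(metric="rmse")    # a single Comparer now returns the same shape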
1 change: 1 addition & 0 deletions modelskill/comparison/__init__.py
@@ -4,4 +4,5 @@
 from ._comparison import Comparer
 from ._collection import ComparerCollection
 
+
 __all__ = ["Comparer", "ComparerCollection"]
85 changes: 41 additions & 44 deletions modelskill/comparison/_collection.py
@@ -3,6 +3,7 @@
 from pathlib import Path
 import tempfile
 from typing import (
+    Callable,
     Dict,
     List,
     Union,
@@ -27,7 +28,7 @@
 from ..settings import options, reset_option
 
 from ..utils import _get_idx, _get_name
-from ._comparison import Comparer
+from ._comparison import Comparer, Scoreable
 from ._utils import (
     _parse_metric,
     _add_spatial_grid_to_df,
@@ -79,7 +80,7 @@ def _all_df_template(n_variables: int = 1):
     return res
 
 
-class ComparerCollection(Mapping):
+class ComparerCollection(Mapping, Scoreable):
     """
     Collection of comparers, constructed by calling the `modelskill.match` method.
@@ -435,7 +436,7 @@ def skill(
         by: Optional[Union[str, List[str]]] = None,
         metrics: Optional[List[str]] = None,
         **kwargs,
-    ) -> Optional[SkillTable]:
+    ) -> SkillTable:
         """Aggregated skill assessment of model(s)
 
         Parameters
@@ -496,8 +497,9 @@ def skill(
             area=area,
         )
         if cmp.n_points == 0:
-            warnings.warn("No data!")
-            return None
+            raise ValueError("Dataset is empty, no data to compare.")
+
+        ## ---- end of deprecated code ----
 
         df = cmp.to_dataframe()
         n_models = cmp.n_models  # len(df.model.unique())
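
Behavior change worth noting: skill() no longer warns and returns None on an empty selection; it raises. Callers that checked for None adapt roughly like this (cc is a ComparerCollection, as in the docstring examples further down):

    # before: sk = cc.skill(); if sk is None: ...
    try:
        sk = cc.skill()
    except ValueError:
        ...  # empty selection: nothing to assess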
@@ -557,7 +559,7 @@ def gridded_skill(
         metrics: Optional[list] = None,
         n_min: Optional[int] = None,
         **kwargs,
-    ):
+    ) -> SkillGrid:
         """Skill assessment of model(s) on a regular spatial grid.
 
         Parameters
@@ -629,8 +631,9 @@
         )
 
         if cmp.n_points == 0:
-            warnings.warn("No data!")
-            return
+            raise ValueError("Dataset is empty, no data to compare.")
+
+        ## ---- end of deprecated code ----
 
         df = cmp.to_dataframe()
         df = _add_spatial_grid_to_df(df=df, bins=bins, binsize=binsize)
@@ -721,7 +724,7 @@ def mean_skill(
         weights: Optional[Union[str, List[float], Dict[str, float]]] = None,
         metrics: Optional[list] = None,
         **kwargs,
-    ) -> Optional[SkillTable]:  # TODO raise error if no data?
+    ) -> SkillTable:
         """Weighted mean of skills
 
         First, the skill is calculated per observation,
@@ -773,16 +776,17 @@

         # filter data
         cmp = self.sel(
-            model=model,
-            observation=observation,
-            variable=variable,
-            start=start,
-            end=end,
-            area=area,
+            model=model,  # deprecated
+            observation=observation,  # deprecated
+            variable=variable,  # deprecated
+            start=start,  # deprecated
+            end=end,  # deprecated
+            area=area,  # deprecated
         )
         if cmp.n_points == 0:
-            warnings.warn("No data!")
-            return None
+            raise ValueError("Dataset is empty, no data to compare.")
+
+        ## ---- end of deprecated code ----
 
         df = cmp.to_dataframe()
         mod_names = cmp.mod_names  # df.model.unique()
@@ -813,6 +817,9 @@ def weighted_mean(x):
             agg[metric.__name__] = weighted_mean  # type: ignore
         res = skilldf.groupby(by).agg(agg)
 
+        # TODO is this correct?
+        res.index.name = "model"
+
         # output
         res = cmp._add_as_col_if_not_in_index(df, res, fields=["model", "variable"])
         return SkillTable(res.astype({"n": int}))
@@ -944,11 +951,9 @@ def _parse_weights(self, weights, observations):

     def score(
         self,
-        *,
-        weights: Optional[Union[str, List[float], Dict[str, float]]] = None,
-        metric=mtr.rmse,
+        metric: str | Callable = mtr.rmse,
         **kwargs,
-    ) -> Optional[float]:  # TODO raise error if no data?
+    ) -> Dict[str, float]:
         """Weighted mean score of model(s) over all observations
 
         Wrapping mean_skill() with a single metric.
@@ -969,7 +974,7 @@
         Returns
         -------
-        float
+        Dict[str, float]
             mean of skills score as a single number (for each model)
 
         See also
@@ -993,6 +998,8 @@
         >>> cc.score(weights='points', metric="mape")
         8.414442957854142
         """
+
+        weights = kwargs.pop("weights", None)
         metric = _parse_metric(metric, self.metrics)
         if not (callable(metric) or isinstance(metric, str)):
             raise ValueError("metric must be a string or a function")
@@ -1007,37 +1014,27 @@
         # TODO: these two lines looks familiar, extract to function
         models = [model] if np.isscalar(model) else model  # type: ignore
         models = [_get_name(m, self.mod_names) for m in models]  # type: ignore
-        n_models = len(models)
 
         cmp = self.sel(
-            model=models,
-            observation=observation,
-            variable=variable,
-            start=start,
-            end=end,
-            area=area,
+            model=models,  # deprecated
+            observation=observation,  # deprecated
+            variable=variable,  # deprecated
+            start=start,  # deprecated
+            end=end,  # deprecated
+            area=area,  # deprecated
         )
 
         if cmp.n_points == 0:
-            warnings.warn("No data!")
-            return None
+            raise ValueError("Dataset is empty, no data to compare.")
 
-        skill = cmp.mean_skill(weights=weights, metrics=[metric])
-        if skill is None:
-            return None
+        ## ---- end of deprecated code ----
 
+        skill = cmp.mean_skill(weights=weights, metrics=[metric])
         df = skill.to_dataframe()
 
-        if n_models == 1:
-            score = df[metric.__name__].values.mean()
-        else:
-            score = {}
-            for model in models:
-                mtr_val = df.loc[model][metric.__name__]
-                if not np.isscalar(mtr_val):
-                    # e.g. mean over different variables!
-                    mtr_val = mtr_val.values.mean()
-                score[model] = mtr_val
+        metric_name = metric if isinstance(metric, str) else metric.__name__
 
+        score = df[metric_name].to_dict()
 
         return score
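
The rework above replaces the scalar-vs-dict branching with a uniform per-model dict. A toy illustration of the new core step, df[metric_name].to_dict(), with invented numbers:

    import pandas as pd

    # mean_skill() yields one row per model, so the metric column maps
    # directly to a {model: value} dict
    df = pd.DataFrame(
        {"rmse": [0.23, 0.31]},
        index=pd.Index(["mod1", "mod2"], name="model"),
    )
    assert df["rmse"].to_dict() == {"mod1": 0.23, "mod2": 0.31}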

54 changes: 40 additions & 14 deletions modelskill/comparison/_comparison.py
@@ -2,12 +2,14 @@
 from dataclasses import dataclass
 from pathlib import Path
 from typing import (
+    Callable,
     Dict,
     List,
     Mapping,
     Optional,
     Union,
     Iterable,
+    Protocol,
     Sequence,
     TYPE_CHECKING,
 )
@@ -42,6 +44,30 @@
     from ._collection import ComparerCollection
 
 
+class Scoreable(Protocol):
+    def score(self, metric: str | Callable, **kwargs) -> Dict[str, float]:
+        ...
+
+    def skill(
+        self,
+        by: Optional[Union[str, List[str]]] = None,
+        metrics: Optional[List[str]] = None,
+        **kwargs,
+    ) -> SkillTable:
+        ...
+
+    def gridded_skill(
+        self,
+        bins=5,
+        binsize: Optional[float] = None,
+        by: Optional[Union[str, List[str]]] = None,
+        metrics: Optional[list] = None,
+        n_min: Optional[int] = None,
+        **kwargs,
+    ) -> SkillGrid:
+        ...
+
+
 def _parse_dataset(data) -> xr.Dataset:
     if not isinstance(data, xr.Dataset):
         raise ValueError("matched_data must be an xarray.Dataset")
@@ -372,7 +398,7 @@ def _matched_data_to_xarray(
     return ds
 
 
-class Comparer:
+class Comparer(Scoreable):
     """
     Comparer class for comparing model and observation data.
@@ -472,7 +498,8 @@ def __repr__(self):
f"Observation: {self.name}, n_points={self.n_points}",
]
for model in self.mod_names:
out.append(f" Model: {model}, rmse={self.sel(model=model).score():.3f}")
out.append(f" Model: {model}, rmse={self.score()[model]:.3f}")

for var in self.aux_names:
out.append(f" Auxiliary: {var}")
return str.join("\n", out)
@@ -1010,9 +1037,9 @@ def _add_as_col_if_not_in_index(self, df, skilldf):

     def score(
         self,
-        metric=mtr.rmse,
+        metric: str | Callable = mtr.rmse,
         **kwargs,
-    ) -> float:
+    ) -> Dict[str, float]:
         """Model skill score
 
         Parameters
@@ -1049,19 +1076,18 @@ def score(
         assert kwargs == {}, f"Unknown keyword arguments: {kwargs}"
 
         s = self.skill(
             by=["model", "observation"],
             metrics=[metric],
-            model=model,
-            start=start,
-            end=end,
-            area=area,
+            model=model,  # deprecated
+            start=start,  # deprecated
+            end=end,  # deprecated
+            area=area,  # deprecated
         )
-        # if s is None:
-        #     return
         df = s.to_dataframe()
-        values = df[metric.__name__].values
-        if len(values) == 1:
-            values = values[0]
-        return values
+
+        metric_name = metric if isinstance(metric, str) else metric.__name__
+
+        return df.reset_index().groupby("model")[metric_name].mean().to_dict()
 
     def spatial_skill(
         self,
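
The last changed line of score() above is the new per-model aggregation. A toy illustration of that pandas pattern, with invented numbers:

    import pandas as pd

    # skill(by=["model", "observation"]) yields one row per
    # (model, observation) pair; averaging over observations
    # leaves one score per model
    df = pd.DataFrame(
        {"model": ["m1", "m1", "m2", "m2"], "rmse": [1.0, 3.0, 2.0, 4.0]}
    )
    assert df.groupby("model")["rmse"].mean().to_dict() == {"m1": 2.0, "m2": 3.0}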
(diffs for the remaining 4 changed files not shown)
