Skip to content

Commit

Permalink
Not so None
Browse files Browse the repository at this point in the history
  • Loading branch information
ecomodeller committed Dec 14, 2023
1 parent 6d4f6ae commit 96050c3
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 50 deletions.
57 changes: 29 additions & 28 deletions modelskill/comparison/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def skill(
by: Optional[Union[str, List[str]]] = None,
metrics: Optional[List[str]] = None,
**kwargs,
) -> Optional[SkillTable]:
) -> SkillTable:
"""Aggregated skill assessment of model(s)
Parameters
Expand Down Expand Up @@ -497,8 +497,9 @@ def skill(
area=area,
)
if cmp.n_points == 0:
warnings.warn("No data!")
return None
raise ValueError("Dataset is empty, no data to compare.")

## ---- end of deprecated code ----

df = cmp.to_dataframe()
n_models = cmp.n_models # len(df.model.unique())
Expand Down Expand Up @@ -558,7 +559,7 @@ def gridded_skill(
metrics: Optional[list] = None,
n_min: Optional[int] = None,
**kwargs,
):
) -> SkillGrid:
"""Skill assessment of model(s) on a regular spatial grid.
Parameters
Expand Down Expand Up @@ -630,8 +631,9 @@ def gridded_skill(
)

if cmp.n_points == 0:
warnings.warn("No data!")
return
raise ValueError("Dataset is empty, no data to compare.")

## ---- end of deprecated code ----

df = cmp.to_dataframe()
df = _add_spatial_grid_to_df(df=df, bins=bins, binsize=binsize)
Expand Down Expand Up @@ -722,7 +724,7 @@ def mean_skill(
weights: Optional[Union[str, List[float], Dict[str, float]]] = None,
metrics: Optional[list] = None,
**kwargs,
) -> Optional[SkillTable]: # TODO raise error if no data?
) -> SkillTable:
"""Weighted mean of skills
First, the skill is calculated per observation,
Expand Down Expand Up @@ -774,16 +776,17 @@ def mean_skill(

# filter data
cmp = self.sel(
model=model,
observation=observation,
variable=variable,
start=start,
end=end,
area=area,
model=model, # deprecated
observation=observation, # deprecated
variable=variable, # deprecated
start=start, # deprecated
end=end, # deprecated
area=area, # deprecated
)
if cmp.n_points == 0:
warnings.warn("No data!")
return None
raise ValueError("Dataset is empty, no data to compare.")

## ---- end of deprecated code ----

df = cmp.to_dataframe()
mod_names = cmp.mod_names # df.model.unique()
Expand Down Expand Up @@ -947,7 +950,7 @@ def score(
self,
metric: str | Callable = mtr.rmse,
**kwargs,
) -> Dict[str, float] | None:
) -> Dict[str, float]:
"""Weighted mean score of model(s) over all observations
Wrapping mean_skill() with a single metric.
Expand All @@ -968,7 +971,7 @@ def score(
Returns
-------
float
Dict[str, float]
mean of skills score as a single number (for each model)
See also
Expand Down Expand Up @@ -1010,22 +1013,20 @@ def score(
models = [_get_name(m, self.mod_names) for m in models] # type: ignore

cmp = self.sel(
model=models,
observation=observation,
variable=variable,
start=start,
end=end,
area=area,
model=models, # deprecated
observation=observation, # deprecated
variable=variable, # deprecated
start=start, # deprecated
end=end, # deprecated
area=area, # deprecated
)

if cmp.n_points == 0:
warnings.warn("No data!")
return None
raise ValueError("Dataset is empty, no data to compare.")

skill = cmp.mean_skill(weights=weights, metrics=[metric])
if skill is None:
return None
## ---- end of deprecated code ----

skill = cmp.mean_skill(weights=weights, metrics=[metric])
df = skill.to_dataframe()

metric_name = metric if isinstance(metric, str) else metric.__name__
Expand Down
27 changes: 9 additions & 18 deletions modelskill/comparison/_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,15 @@


class Scoreable(Protocol):
# TODO should this be able to return None?
def score(self, metric: str | Callable, **kwargs) -> Dict[str, float] | None:
def score(self, metric: str | Callable, **kwargs) -> Dict[str, float]:
...

def skill(
self,
by: Optional[Union[str, List[str]]] = None,
metrics: Optional[List[str]] = None,
**kwargs,
) -> Optional[SkillTable]:
) -> SkillTable:
...

def gridded_skill(
Expand All @@ -65,7 +64,7 @@ def gridded_skill(
metrics: Optional[list] = None,
n_min: Optional[int] = None,
**kwargs,
) -> Optional[SkillGrid]:
) -> SkillGrid:
...


Expand Down Expand Up @@ -1063,26 +1062,18 @@ def score(
assert kwargs == {}, f"Unknown keyword arguments: {kwargs}"

s = self.skill(
by=["model", "observation"],
metrics=[metric],
model=model,
start=start,
end=end,
area=area,
model=model, # deprecated
start=start, # deprecated
end=end, # deprecated
area=area, # deprecated
)
# if s is None:
# return
df = s.to_dataframe()

# TODO clean up
metric_name = metric if isinstance(metric, str) else metric.__name__

values = df[metric_name].values
if len(values) == 1:
value = values[0]
return {self.mod_names[0]: value}
else:
# TODO check if this is correct
return {m: v for m, v in zip(self.mod_names, values)}
return df.reset_index().groupby("model")[metric_name].mean().to_dict()

def spatial_skill(
self,
Expand Down
12 changes: 8 additions & 4 deletions tests/test_aggregated_skill.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,13 @@ def cc2(o1, o2, o3):

def test_skill(cc1):
s = cc1.skill()
assert len(s.mod_names) == 0
assert len(s.obs_names) == 1
assert len(s.var_names) == 0

    # TODO a minimal skill assessment consists of 1 observation, 1 model and 1 variable
    # in this case model and variable are implicit since we only have one of each, but why do we have one observation? seems inconsistent

assert len(s.mod_names) == 0 # TODO seems wrong
assert len(s.obs_names) == 1 # makes sense
assert len(s.var_names) == 0 # TODO seems wrong

df = s.to_dataframe()
assert isinstance(df, pd.DataFrame)
Expand Down Expand Up @@ -141,7 +145,7 @@ def test_skill_sel_query(cc2):
s = cc2.skill(metrics=["rmse", "bias"])
with pytest.warns(FutureWarning, match="deprecated"):
s2 = s.sel(query="rmse>0.2")

assert len(s2.mod_names) == 2

# s2 = s.sel("rmse>0.2", model="SW_2", observation=[0, 2])
Expand Down

0 comments on commit 96050c3

Please sign in to comment.