Skip to content

Commit

Permalink
Merge pull request #329 from DHI/xy-in-skilltable-alt
Browse files Browse the repository at this point in the history
Include x, y in SkillTable - alternative way
  • Loading branch information
ecomodeller authored Dec 19, 2023
2 parents 7064f02 + 36e15be commit 1868b8c
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 18 deletions.
5 changes: 4 additions & 1 deletion modelskill/comparison/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,10 @@ def skill(
) # len(df.variable.unique()) if (self.n_variables > 1) else 1
by = _parse_groupby(by, n_models, n_obs, n_var)

res = _groupby_df(df.drop(columns=["x", "y"]), by, metrics)
res = _groupby_df(df, by, metrics)
res["x"] = df.groupby(by=by, observed=False).x.first()
res["y"] = df.groupby(by=by, observed=False).y.first()
# TODO: set x,y to NaN if TrackObservation
res = cmp._add_as_col_if_not_in_index(df, skilldf=res)
return SkillTable(res)

Expand Down
7 changes: 5 additions & 2 deletions modelskill/comparison/_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,8 +1017,11 @@ def skill(

by = _parse_groupby(by, cmp.n_models, n_obs=1, n_var=1)

df = cmp.to_dataframe() # TODO: avoid df if possible?
res = _groupby_df(df.drop(columns=["x", "y"]), by, metrics)
df = cmp.to_dataframe()
res = _groupby_df(df, by, metrics)
res["x"] = df.groupby(by=by, observed=False).x.first()
res["y"] = df.groupby(by=by, observed=False).y.first()
# TODO: set x,y to NaN if TrackObservation
res = self._add_as_col_if_not_in_index(df, skilldf=res)
return SkillTable(res)

Expand Down
89 changes: 75 additions & 14 deletions modelskill/skill.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations
import warnings
from typing import Iterable, Collection, overload, Hashable
from typing import Iterable, Collection, overload, Hashable, TYPE_CHECKING
import numpy as np
import pandas as pd

if TYPE_CHECKING:
import geopandas as gpd

from .plotting._misc import _get_fig_ax


Expand Down Expand Up @@ -42,13 +45,28 @@ def _name_to_title_in_kwargs(self, kwargs):
kwargs["title"] = self.skillarray.name

def _get_plot_df(self, level: int | str = 0) -> pd.DataFrame:
ser = self.skillarray.data
ser = self.skillarray._ser
if isinstance(ser.index, pd.MultiIndex):
df = ser.unstack(level=level)
else:
df = ser.to_frame()
return df

# TODO: map() is disabled (commented out) for now; re-enable once we are certain about the API
# def map(self, **kwargs):
# if "model" in self.skillarray.data.index.names:
# n_models = len(self.skillarray.data.reset_index().model.unique())
# if n_models > 1:
# raise ValueError(
# "map() is only possible for single model skill. Use .sel(model=...) to select a single model."
# )

# gdf = self.skillarray.to_geodataframe()
# column = self.skillarray.name
# kwargs = {"marker_kwds": {"radius": 10}} | kwargs

# return gdf.explore(column=column, **kwargs)

def line(
self,
level: int | str = 0,
Expand Down Expand Up @@ -187,7 +205,7 @@ def grid(
"""

s = self.skillarray
ser = s.data
ser = s._ser

errors = _validate_multi_index(ser.index)
if len(errors) > 0:
Expand Down Expand Up @@ -287,14 +305,17 @@ class SkillArray:
>>> s.rmse.plot.line()
"""

def __init__(self, data: pd.Series) -> None:
assert isinstance(data, pd.Series)
def __init__(self, data: pd.DataFrame) -> None:
self.data = data
self._ser = data.iloc[:, -1] # last column is the metric
self.plot = SkillArrayPlotter(self)

def to_dataframe(self) -> pd.DataFrame:
def to_dataframe(self, drop_xy=True) -> pd.DataFrame:
"""Output as pd.DataFrame"""
return self.data.to_dataframe()
if drop_xy:
return self._ser.to_frame()
else:
return self.data.copy()

def __repr__(self):
return repr(self.to_dataframe())
Expand All @@ -305,7 +326,21 @@ def _repr_html_(self):
@property
def name(self):
"""Name of the metric"""
return self.data.name
return self._ser.name

def to_geodataframe(self, crs="EPSG:4326") -> gpd.GeoDataFrame:
    """Output the metric as a GeoDataFrame with point geometry.

    The metric column (``self._ser``) is used as the data of the
    GeoDataFrame, and a point geometry is built from the ``x`` and ``y``
    columns of ``self.data``.

    Parameters
    ----------
    crs : str, optional
        Coordinate reference system passed on to the GeoDataFrame,
        by default "EPSG:4326"

    Returns
    -------
    gpd.GeoDataFrame
        The metric values with a ``geometry`` of points.

    Raises
    ------
    ValueError
        If ``self.data`` does not contain both an ``x`` and a ``y`` column.
    """
    # Validate with an explicit raise instead of ``assert`` so the check is
    # not stripped when Python runs with -O, and so the caller is told
    # exactly which coordinate column is missing.
    missing = [col for col in ("x", "y") if col not in self.data.columns]
    if missing:
        raise ValueError(
            f"Cannot create GeoDataFrame; missing coordinate column(s): {missing}"
        )

    # Imported lazily (only under TYPE_CHECKING at module level), so the
    # module itself can be imported without geopandas.
    import geopandas as gpd

    points = gpd.points_from_xy(self.data.x, self.data.y)
    return gpd.GeoDataFrame(self._ser, geometry=points, crs=crs)


class SkillTable:
Expand Down Expand Up @@ -360,21 +395,41 @@ def __init__(self, data: pd.DataFrame):
self.plot = DeprecatedSkillPlotter(self) # TODO remove in v1.1

# TODO: remove?
# data without xy columns
@property
def _df(self) -> pd.DataFrame:
return self.data
return self.data.drop(columns=["x", "y"], errors="ignore")

@property
def metrics(self) -> Collection[str]:
"""List of metrics (columns) in the SkillTable"""
return list(self.data.columns)
return list(self._df.columns)

# TODO: remove?
def __len__(self) -> int:
return len(self._df)

def to_dataframe(self) -> pd.DataFrame:
return self._df.copy()
def to_dataframe(self, drop_xy=True) -> pd.DataFrame:
if drop_xy:
return self.data.drop(columns=["x", "y"], errors="ignore")
else:
return self.data.copy()

def to_geodataframe(self, crs="EPSG:4326") -> gpd.GeoDataFrame:
    """Output the skill table as a GeoDataFrame with point geometry.

    All columns of ``self.data`` (including ``x`` and ``y``) are kept, and
    a point geometry is built from the ``x`` and ``y`` columns.

    Parameters
    ----------
    crs : str, optional
        Coordinate reference system passed on to the GeoDataFrame,
        by default "EPSG:4326"

    Returns
    -------
    gpd.GeoDataFrame
        The skill table with an added ``geometry`` of points.

    Raises
    ------
    ValueError
        If ``self.data`` does not contain both an ``x`` and a ``y`` column.
    """
    # Validate with an explicit raise instead of ``assert`` so the check is
    # not stripped when Python runs with -O, and so the caller is told
    # exactly which coordinate column is missing.
    missing = [col for col in ("x", "y") if col not in self.data.columns]
    if missing:
        raise ValueError(
            f"Cannot create GeoDataFrame; missing coordinate column(s): {missing}"
        )

    # Imported lazily (only under TYPE_CHECKING at module level), so the
    # module itself can be imported without geopandas.
    import geopandas as gpd

    # Work on a copy that still holds the x/y columns.
    df = self.to_dataframe(drop_xy=False)
    points = gpd.points_from_xy(df.x, df.y)
    return gpd.GeoDataFrame(df, geometry=points, crs=crs)

def __repr__(self):
return repr(self._df)
Expand All @@ -395,7 +450,12 @@ def __getitem__(self, key) -> SkillArray | SkillTable:
key = list(self.data.columns)[key]
result = self.data[key]
if isinstance(result, pd.Series):
return SkillArray(result)
# I don't think this should be necessary, but in some cases the input doesn't contain x and y
if "x" in self.data.columns and "y" in self.data.columns:
cols = ["x", "y", key]
return SkillArray(self.data[cols])
else:
return SkillArray(result.to_frame())
elif isinstance(result, pd.DataFrame):
return SkillTable(result)
else:
Expand Down Expand Up @@ -511,7 +571,8 @@ def sel(self, query=None, reduce_index=True, **kwargs):
)
return self[value]

df = self._df
# df = self._df
df = self.to_dataframe(drop_xy=False)

for key, value in kwargs.items():
if key in df.index.names:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_aggregated_skill.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_skill_sel_metrics_str(cc1):

with pytest.warns(FutureWarning, match="deprecated"):
s2 = s.sel(metrics="rmse")
assert s2.data.name == "rmse"
assert s2.name == "rmse"


def test_skill_sel_metrics_list(cc2):
Expand Down

0 comments on commit 1868b8c

Please sign in to comment.