Skip to content

Commit

Permalink
better docstrings and re-ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
jsmariegaard committed Dec 13, 2023
1 parent 5071718 commit 294684a
Showing 1 changed file with 70 additions and 55 deletions.
125 changes: 70 additions & 55 deletions modelskill/skill.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ def _validate_multi_index(index, min_levels=2, max_levels=2):


class SkillArrayPlotter:
"""SkillArrayPlotter object for visualization of a single metric (SkillArray)
plot.line() : line plot
plot.bar() : bar chart
plot.barh() : horizontal bar chart
plot.grid() : colored grid
"""

def __init__(self, skillarray):
self.skillarray = skillarray

Expand All @@ -34,19 +42,19 @@ def _name_to_title_in_kwargs(self, kwargs):
kwargs["title"] = self.skillarray.name

def _get_plot_df(self, level: int | str = 0) -> pd.DataFrame:
s = self.skillarray.data
if isinstance(s.index, pd.MultiIndex):
df = s.unstack(level=level)
ser = self.skillarray.data
if isinstance(ser.index, pd.MultiIndex):
df = ser.unstack(level=level)
else:
df = s.to_frame()
df = ser.to_frame()
return df

def line(
self,
level: int | str = 0,
**kwargs,
):
"""plot statistic as a lines using pd.DataFrame.plot.line()
"""Plot statistic as a lines using pd.DataFrame.plot.line()
Primarily for MultiIndex skill objects, e.g. multiple models and multiple observations
Expand Down Expand Up @@ -81,7 +89,7 @@ def line(
return axes

def bar(self, level: int | str = 0, **kwargs):
"""plot statistic as bar chart using pd.DataFrame.plot.bar()
"""Plot statistic as bar chart using pd.DataFrame.plot.bar()
Parameters
----------
Expand All @@ -108,7 +116,7 @@ def bar(self, level: int | str = 0, **kwargs):
return df.plot.bar(**kwargs)

def barh(self, level: int | str = 0, **kwargs):
"""plot statistic as horizontal bar chart using pd.DataFrame.plot.barh()
"""Plot statistic as horizontal bar chart using pd.DataFrame.plot.barh()
Parameters
----------
Expand Down Expand Up @@ -143,7 +151,7 @@ def grid(
title=None,
cmap=None,
):
"""plot statistic as a colored grid, optionally with values in the cells.
"""Plot statistic as a colored grid, optionally with values in the cells.
Primarily for MultiIndex skill objects, e.g. multiple models and multiple observations
Expand Down Expand Up @@ -269,14 +277,14 @@ def grid(self, field: str, **kwargs):


class SkillArray:
"""SkillArray object for visualization and analysis obtained by
selecting a single metric from a SkillTable. The object wraps pd.Series
"""SkillArray object for visualization obtained by
selecting a single metric from a SkillTable.
Examples
--------
>>> s = cc.skill() # SkillTable
>>> s.rmse # SkillArray
>>> s.rmse.plot.line()
"""

def __init__(self, data: pd.Series) -> None:
Expand All @@ -285,6 +293,7 @@ def __init__(self, data: pd.Series) -> None:
self.plot = SkillArrayPlotter(self)

def to_dataframe(self) -> pd.DataFrame:
"""Output as pd.DataFrame"""
return self.data.to_dataframe()

def __repr__(self):
Expand All @@ -295,18 +304,23 @@ def _repr_html_(self):

@property
def name(self):
"""Name of the metric"""
return self.data.name


class SkillTable:
"""
SkillTable object for visualization and analysis returned by
the comparer's skill method. The object wraps the pd.Dataframe
the comparer's `skill` method. The object wraps the pd.DataFrame
class which can be accessed from the attribute `data`.
The columns are assumed to be metrics and data for a single metric
can be accessed by e.g. `s.rmse` or `s["rmse"]`. The resulting object
can be used for plotting.
Examples
--------
>>> s = comparer.skill()
>>> s = cc.skill()
>>> s.mod_names
['SW_1', 'SW_2']
>>> s.style()
Expand Down Expand Up @@ -434,35 +448,9 @@ def _get_index_level_by_name(self, name):
return []
# raise ValueError(f"name {name} not in index {list(self.index.names)}")

def _idx_to_name(self, index, idx) -> str:
"""Assumes that index is valid and idx is int"""
names = self._get_index_level_by_name(index)
n = len(names)
if (idx < 0) or (idx >= n):
raise KeyError(f"Id {idx} is out of bounds for index {index} (0, {n})")
return names[idx]

def _sel_from_index(self, df, key, value):
if (not isinstance(value, str)) and isinstance(value, Iterable):
for i, v in enumerate(value):
dfi = self._sel_from_index(df, key, v)
if i == 0:
dfout = dfi
else:
dfout = pd.concat([dfout, dfi])
return dfout

if isinstance(value, int):
value = self._idx_to_name(key, value)

if isinstance(df.index, pd.MultiIndex):
df = df.xs(value, level=key, drop_level=False)
else:
df = df[df.index == value] # .copy()
return df

def query(self, query: str) -> SkillTable:
"""Select a subset of the SkillTable by a query string
wrapping pd.DataFrame.query()
Parameters
Expand Down Expand Up @@ -539,6 +527,33 @@ def sel(self, query=None, reduce_index=True, **kwargs):
df = self._reduce_index(df)
return self.__class__(df)

def _sel_from_index(self, df, key, value):
if (not isinstance(value, str)) and isinstance(value, Iterable):
for i, v in enumerate(value):
dfi = self._sel_from_index(df, key, v)
if i == 0:
dfout = dfi
else:
dfout = pd.concat([dfout, dfi])
return dfout

if isinstance(value, int):
value = self._idx_to_name(key, value)

if isinstance(df.index, pd.MultiIndex):
df = df.xs(value, level=key, drop_level=False)
else:
df = df[df.index == value] # .copy()
return df

def _idx_to_name(self, index, idx) -> str:
"""Assumes that index is valid and idx is int"""
names = self._get_index_level_by_name(index)
n = len(names)
if (idx < 0) or (idx >= n):
raise KeyError(f"Id {idx} is out of bounds for index {index} (0, {n})")
return names[idx]

def _reduce_index(self, df):
"""Remove unnecessary levels of MultiIndex"""
df.index = df.index.remove_unused_levels()
Expand All @@ -548,20 +563,6 @@ def _reduce_index(self, df):
levels_to_reset.append(j)
return df.reset_index(level=levels_to_reset)

# TODO: remove plot_* methods in v1.1; warnings are not needed
# as the refering method is also deprecated
def plot_line(self, **kwargs):
return self.plot.line(**kwargs)

def plot_bar(self, **kwargs):
return self.plot.bar(**kwargs)

def plot_barh(self, **kwargs):
return self.plot.barh(**kwargs)

def plot_grid(self, **kwargs):
return self.plot.grid(**kwargs)

def round(self, decimals=3):
"""Round all values in SkillTable
Expand Down Expand Up @@ -605,7 +606,7 @@ def style(
Examples
--------
>>> s = comparer.skill()
>>> s = cc.skill()
>>> s.style()
>>> s.style(precision=1, metrics="rmse")
>>> s.style(cmap="Blues", show_best=False)
Expand Down Expand Up @@ -712,3 +713,17 @@ def _style_max(self, s):
"text-decoration: underline; font-style: italic; font-weight: bold;"
)
return [cell_style if v else "" for v in (s == s.max())]

# TODO: remove plot_* methods in v1.1; warnings are not needed
# as the refering method is also deprecated
def plot_line(self, **kwargs):
return self.plot.line(**kwargs)

def plot_bar(self, **kwargs):
return self.plot.bar(**kwargs)

def plot_barh(self, **kwargs):
return self.plot.barh(**kwargs)

def plot_grid(self, **kwargs):
return self.plot.grid(**kwargs)

0 comments on commit 294684a

Please sign in to comment.