diff --git a/modelskill/skill.py b/modelskill/skill.py index 444cb2aaa..6d34a1acd 100644 --- a/modelskill/skill.py +++ b/modelskill/skill.py @@ -25,6 +25,14 @@ def _validate_multi_index(index, min_levels=2, max_levels=2): class SkillArrayPlotter: + """SkillArrayPlotter object for visualization of a single metric (SkillArray) + + plot.line() : line plot + plot.bar() : bar chart + plot.barh() : horizontal bar chart + plot.grid() : colored grid + """ + def __init__(self, skillarray): self.skillarray = skillarray @@ -34,11 +42,11 @@ def _name_to_title_in_kwargs(self, kwargs): kwargs["title"] = self.skillarray.name def _get_plot_df(self, level: int | str = 0) -> pd.DataFrame: - s = self.skillarray.data - if isinstance(s.index, pd.MultiIndex): - df = s.unstack(level=level) + ser = self.skillarray.data + if isinstance(ser.index, pd.MultiIndex): + df = ser.unstack(level=level) else: - df = s.to_frame() + df = ser.to_frame() return df def line( @@ -46,7 +54,7 @@ def line( level: int | str = 0, **kwargs, ): - """plot statistic as a lines using pd.DataFrame.plot.line() + """Plot statistic as lines using pd.DataFrame.plot.line() Primarily for MultiIndex skill objects, e.g. multiple models and multiple observations @@ -81,7 +89,7 @@ def line( return axes def bar(self, level: int | str = 0, **kwargs): - """plot statistic as bar chart using pd.DataFrame.plot.bar() + """Plot statistic as bar chart using pd.DataFrame.plot.bar() Parameters ---------- @@ -108,7 +116,7 @@ def bar(self, level: int | str = 0, **kwargs): return df.plot.bar(**kwargs) def barh(self, level: int | str = 0, **kwargs): - """plot statistic as horizontal bar chart using pd.DataFrame.plot.barh() + """Plot statistic as horizontal bar chart using pd.DataFrame.plot.barh() Parameters ---------- @@ -143,7 +151,7 @@ def grid( title=None, cmap=None, ): - """plot statistic as a colored grid, optionally with values in the cells. + """Plot statistic as a colored grid, optionally with values in the cells. 
Primarily for MultiIndex skill objects, e.g. multiple models and multiple observations @@ -269,14 +277,14 @@ def grid(self, field: str, **kwargs): class SkillArray: - """SkillArray object for visualization and analysis obtained by - selecting a single metric from a SkillTable. The object wraps pd.Series + """SkillArray object for visualization obtained by + selecting a single metric from a SkillTable. Examples -------- >>> s = cc.skill() # SkillTable >>> s.rmse # SkillArray - + >>> s.rmse.plot.line() """ def __init__(self, data: pd.Series) -> None: @@ -285,6 +293,7 @@ def __init__(self, data: pd.Series) -> None: self.plot = SkillArrayPlotter(self) def to_dataframe(self) -> pd.DataFrame: + """Output as pd.DataFrame""" return self.data.to_dataframe() def __repr__(self): @@ -295,18 +304,23 @@ def _repr_html_(self): @property def name(self): + """Name of the metric""" return self.data.name class SkillTable: """ SkillTable object for visualization and analysis returned by - the comparer's skill method. The object wraps the pd.Dataframe + the comparer's `skill` method. The object wraps the pd.DataFrame class which can be accessed from the attribute `data`. + The columns are assumed to be metrics and data for a single metric + can be accessed by e.g. `s.rmse` or `s["rmse"]`. The resulting object + can be used for plotting. 
+ Examples -------- - >>> s = comparer.skill() + >>> s = cc.skill() >>> s.mod_names ['SW_1', 'SW_2'] >>> s.style() @@ -434,35 +448,9 @@ def _get_index_level_by_name(self, name): return [] # raise ValueError(f"name {name} not in index {list(self.index.names)}") - def _idx_to_name(self, index, idx) -> str: - """Assumes that index is valid and idx is int""" - names = self._get_index_level_by_name(index) - n = len(names) - if (idx < 0) or (idx >= n): - raise KeyError(f"Id {idx} is out of bounds for index {index} (0, {n})") - return names[idx] - - def _sel_from_index(self, df, key, value): - if (not isinstance(value, str)) and isinstance(value, Iterable): - for i, v in enumerate(value): - dfi = self._sel_from_index(df, key, v) - if i == 0: - dfout = dfi - else: - dfout = pd.concat([dfout, dfi]) - return dfout - - if isinstance(value, int): - value = self._idx_to_name(key, value) - - if isinstance(df.index, pd.MultiIndex): - df = df.xs(value, level=key, drop_level=False) - else: - df = df[df.index == value] # .copy() - return df - def query(self, query: str) -> SkillTable: """Select a subset of the SkillTable by a query string + wrapping pd.DataFrame.query() Parameters @@ -539,6 +527,33 @@ def sel(self, query=None, reduce_index=True, **kwargs): df = self._reduce_index(df) return self.__class__(df) + def _sel_from_index(self, df, key, value): + if (not isinstance(value, str)) and isinstance(value, Iterable): + for i, v in enumerate(value): + dfi = self._sel_from_index(df, key, v) + if i == 0: + dfout = dfi + else: + dfout = pd.concat([dfout, dfi]) + return dfout + + if isinstance(value, int): + value = self._idx_to_name(key, value) + + if isinstance(df.index, pd.MultiIndex): + df = df.xs(value, level=key, drop_level=False) + else: + df = df[df.index == value] # .copy() + return df + + def _idx_to_name(self, index, idx) -> str: + """Assumes that index is valid and idx is int""" + names = self._get_index_level_by_name(index) + n = len(names) + if (idx < 0) or (idx >= n): + 
raise KeyError(f"Id {idx} is out of bounds for index {index} (0, {n})") + return names[idx] + def _reduce_index(self, df): """Remove unnecessary levels of MultiIndex""" df.index = df.index.remove_unused_levels() @@ -548,20 +563,6 @@ def _reduce_index(self, df): levels_to_reset.append(j) return df.reset_index(level=levels_to_reset) - # TODO: remove plot_* methods in v1.1; warnings are not needed - # as the refering method is also deprecated - def plot_line(self, **kwargs): - return self.plot.line(**kwargs) - - def plot_bar(self, **kwargs): - return self.plot.bar(**kwargs) - - def plot_barh(self, **kwargs): - return self.plot.barh(**kwargs) - - def plot_grid(self, **kwargs): - return self.plot.grid(**kwargs) - def round(self, decimals=3): """Round all values in SkillTable @@ -605,7 +606,7 @@ def style( Examples -------- - >>> s = comparer.skill() + >>> s = cc.skill() >>> s.style() >>> s.style(precision=1, metrics="rmse") >>> s.style(cmap="Blues", show_best=False) @@ -712,3 +713,17 @@ def _style_max(self, s): "text-decoration: underline; font-style: italic; font-weight: bold;" ) return [cell_style if v else "" for v in (s == s.max())] + + # TODO: remove plot_* methods in v1.1; warnings are not needed + # as the referring method is also deprecated + def plot_line(self, **kwargs): + return self.plot.line(**kwargs) + + def plot_bar(self, **kwargs): + return self.plot.bar(**kwargs) + + def plot_barh(self, **kwargs): + return self.plot.barh(**kwargs) + + def plot_grid(self, **kwargs): + return self.plot.grid(**kwargs)