Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add qq and residual_hist plots on comparer collection #456

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ venv/
ENV/
env.bak/
venv.bak/
.envrc

# Spyder and vscode project settings
.spyderproject
Expand Down
183 changes: 183 additions & 0 deletions modelskill/comparison/_collection_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
if TYPE_CHECKING:
from ._collection import ComparerCollection

import numpy as np
import pandas as pd

from .. import metrics as mtr
from ..utils import _get_idx
from ..settings import options
from ..plotting import taylor_diagram, scatter, TaylorPoint
from ..plotting._misc import _xtick_directional, _ytick_directional, _get_fig_ax
from ._comparer_plotter import quantiles_xy


def _default_univarate_title(kind: str, cc: ComparerCollection) -> str:
Expand Down Expand Up @@ -595,3 +598,183 @@ def box(self, *, ax=None, figsize=None, title=None, **kwargs) -> Axes:
_ytick_directional(ax)

return ax

def qq(
self,
quantiles: int | Sequence[float] | None = None,
*,
title=None,
ax=None,
figsize=None,
**kwargs,
):
"""Make quantile-quantile (q-q) plot of model data and observations.

Primarily used to compare multiple models.

Parameters
----------
quantiles: (int, sequence), optional
number of quantiles for QQ-plot, by default None and will depend on the scatter data length (10, 100 or 1000)
if int, this is the number of points
if sequence (list of floats), represents the desired quantiles (from 0 to 1)
title : str, optional
plot title, default: "Q-Q plot for [observation name]"
ax : matplotlib.axes.Axes, optional
axes to plot on, by default None
figsize : tuple, optional
figure size, by default None
**kwargs
other keyword arguments to plt.plot()

Returns
-------
matplotlib axes

Examples
--------
>>> cc.plot.qq()

"""
cc = self.cc

_, ax = _get_fig_ax(ax, figsize)

df = cc._to_long_dataframe()

xmin, xmax, ymin, ymax = np.inf, -np.inf, np.inf, -np.inf

for model in cc.mod_names:
df_model = df[df.model == model]

x = df_model.obs_val.values
y = df_model.mod_val.values
xq, yq = quantiles_xy(x, y, quantiles)

xmin = min([x.min(), xmin])
xmax = max([x.max(), xmax])
ymin = min([y.min(), ymin])
ymax = max([y.max(), ymax])

ax.plot(xq, yq, ".-", label=model, zorder=4, **kwargs)

# 1:1 line
ax.plot(
[xmin, xmax],
[ymin, ymax],
label=options.plot.scatter.oneone_line.label,
c=options.plot.scatter.oneone_line.color,
zorder=3,
)

ax.axis("square")
ax.set_xlim([xmin, xmax])
ax.set_ylim([ymin, ymax])
ax.minorticks_on()
ax.grid(which="both", axis="both", linewidth="0.2", color="k", alpha=0.6)

ax.legend()
ax.set_xlabel("Observation, " + cc._unit_text)
ax.set_ylabel("Model, " + cc._unit_text)
title = (
_default_univarate_title("Q-Q plot for ", self.cc)
if title is None
else title
)
ax.set_title(title)

if self.is_directional:
_xtick_directional(ax)
_ytick_directional(ax)

return ax

def residual_hist(
self, bins=100, title=None, color=None, figsize=None, ax=None, **kwargs
) -> Axes | list[Axes]:
"""plot histogram of residual values

Parameters
----------
bins : int, optional
specification of bins, by default 100
title : str, optional
plot title, default: Residuals, [name]
color : str, optional
residual color, by default "#8B8D8E"
figsize : tuple, optional
figure size, by default None
ax : Axes | list[Axes], optional
axes to plot on, by default None
**kwargs
other keyword arguments to plt.hist()

Returns
-------
Axes | list[Axes]
"""
cc = self.cc

if cc.n_models == 1:
return self._residual_hist_one_model(
bins=bins,
title=title,
color=color,
figsize=figsize,
ax=ax,
mod_name=cc.mod_names[0],
**kwargs,
)

if ax is not None and len(ax) != len(cc.mod_names):
raise ValueError("Number of axes must match number of models")

axs = ax if ax is not None else [None] * len(cc.mod_names)

for i, mod_name in enumerate(cc.mod_names):
cc_model = cc.sel(model=mod_name)
ax_mod = cc_model.plot.residual_hist(
bins=bins,
title=title,
color=color,
figsize=figsize,
ax=axs[i],
**kwargs,
)
axs[i] = ax_mod

return axs

def _residual_hist_one_model(
self,
bins=100,
title=None,
color=None,
figsize=None,
ax=None,
mod_name=None,
**kwargs,
) -> Axes:
"""Residual histogram for one model only"""
_, ax = _get_fig_ax(ax, figsize)

df = self.cc.sel(model=mod_name)._to_long_dataframe()
residuals = df.mod_val.values - df.obs_val.values

default_color = "#8B8D8E"
color = default_color if color is None else color
title = (
_default_univarate_title(f"Residuals, Model {mod_name}", self.cc)
if title is None
else title
)
ax.hist(residuals, bins=bins, color=color, **kwargs)
ax.set_title(title)
ax.set_xlabel(f"Residuals of {self.cc._unit_text}")

if self.is_directional:
ticks = np.linspace(-180, 180, 9)
ax.set_xticks(ticks)
ax.set_xlim(-180, 180)

return ax
16 changes: 15 additions & 1 deletion tests/test_comparercollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,13 +491,18 @@ def test_plots_directional(cc):
assert ax.get_xlim() == (0.0, 360.0)


PLOT_FUNCS_RETURNING_MANY_AX = ["scatter", "hist", "residual_hist"]


@pytest.fixture(
params=[
"scatter",
"kde",
"hist",
"taylor",
"box",
"qq",
"residual_hist",
]
)
def cc_plot_function(cc, request):
Expand All @@ -509,7 +514,7 @@ def cc_plot_function(cc, request):

func = getattr(cc.plot, request.param)
# special cases require selecting a model
if request.param in ["scatter", "hist"]:
if request.param in PLOT_FUNCS_RETURNING_MANY_AX:

def func(**kwargs):
wrapped_func = getattr(cc.sel(model=[0]).plot, request.param)
Expand All @@ -518,6 +523,15 @@ def func(**kwargs):
return func


@pytest.mark.parametrize("kind", PLOT_FUNCS_RETURNING_MANY_AX)
def test_plots_returning_multiple_axes(pc, kind):
n_models = 2
func = getattr(pc.plot, kind)
ax = func()
assert len(ax) == n_models
assert all(isinstance(a, plt.Axes) for a in ax)


def test_plot_returns_an_object(cc_plot_function):
obj = cc_plot_function()
assert obj is not None
Expand Down
Loading