Skip to content

Commit

Permalink
use matplotlib also for momentum correction
Browse files Browse the repository at this point in the history
  • Loading branch information
rettigl committed Nov 12, 2024
1 parent 855d611 commit 13e4d6a
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 26 deletions.
2 changes: 2 additions & 0 deletions .cspell/custom-dictionary.txt
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ xpos
xratio
xrng
xscale
xticks
xtrans
Xuser
xval
Expand All @@ -414,6 +415,7 @@ ylabel
ypos
yratio
yscale
yticks
ytrans
zain
Zenodo
Expand Down
22 changes: 14 additions & 8 deletions src/sed/calibrator/momentum.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,7 +1348,9 @@ def view(

if annotated:
tsr, tsc = kwds.pop("textshift", (3, 3))
txtsize = kwds.pop("textsize", 12)
txtsize = kwds.pop("textsize", 10)

title = kwds.pop("title", "")

# Handle unexpected kwds:
handled_kwds = {"figsize"}
Expand All @@ -1358,7 +1360,7 @@ def view(
)

if backend == "matplotlib":
fig_plt, ax = plt.subplots(figsize=figsize)
_, ax = plt.subplots(figsize=figsize)
ax.imshow(image.T, origin=origin, cmap=cmap, **imkwds)

if cross:
Expand All @@ -1368,15 +1370,12 @@ def view(

# Add annotation to the figure
if annotated:
for (
p_keys, # pylint: disable=unused-variable
p_vals,
) in points.items():
for p_keys, p_vals in points.items():
try:
ax.scatter(p_vals[:, 0], p_vals[:, 1], **scatterkwds)
ax.scatter(p_vals[:, 0], p_vals[:, 1], s=15, **scatterkwds)
except IndexError:
try:
ax.scatter(p_vals[0], p_vals[1], **scatterkwds)
ax.scatter(p_vals[0], p_vals[1], s=15, **scatterkwds)
except IndexError:
pass

Expand All @@ -1389,6 +1388,13 @@ def view(
fontsize=txtsize,
)

if crosshair and self.pcent is not None:
for radius in crosshair_radii:
circle = plt.Circle(self.pcent, radius, color="k", fill=False)
ax.add_patch(circle)

ax.set_title(title)

elif backend == "bokeh":
output_notebook(hide_banner=True)
colors = it.cycle(ColorCycle[10])
Expand Down
20 changes: 12 additions & 8 deletions src/sed/core/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,24 +649,28 @@ def generate_splinewarp(
self.mc.spline_warp_estimate(use_center=use_center, **kwds)

if self.mc.slice is not None and self._verbose:
print("Original slice with reference features")
self.mc.view(annotated=True, backend="bokeh", crosshair=True)
self.mc.view(
annotated=True,
backend="matplotlib",
crosshair=True,
title="Original slice with reference features",
)

print("Corrected slice with target features")
self.mc.view(
image=self.mc.slice_corrected,
annotated=True,
points={"feats": self.mc.ptargs},
backend="bokeh",
backend="matplotlib",
crosshair=True,
title="Corrected slice with target features",
)

print("Original slice with target features")
self.mc.view(
image=self.mc.slice,
points={"feats": self.mc.ptargs},
annotated=True,
backend="bokeh",
backend="matplotlib",
title="Original slice with target features",
)

# 3a. Save spline-warp parameters to config file.
Expand Down Expand Up @@ -2384,7 +2388,7 @@ def view_event_histogram(
bins: Sequence[int] = None,
axes: Sequence[str] = None,
ranges: Sequence[tuple[float, float]] = None,
backend: str = "bokeh",
backend: str = "matplotlib",
legend: bool = True,
histkwds: dict = None,
legkwds: dict = None,
Expand All @@ -2403,7 +2407,7 @@ def view_event_histogram(
ranges (Sequence[tuple[float, float]], optional): Value ranges of all
specified axes. Defaults to config["histogram"]["ranges"].
backend (str, optional): Backend of the plotting library
('matplotlib' or 'bokeh'). Defaults to "bokeh".
("matplotlib" or "bokeh"). Defaults to "matplotlib".
legend (bool, optional): Option to include a legend in the histogram plots.
Defaults to True.
histkwds (dict, optional): Keyword arguments for histograms
Expand Down
23 changes: 13 additions & 10 deletions src/sed/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def grid_histogram(
rvs: Sequence,
rvbins: Sequence,
rvranges: Sequence[tuple[float, float]],
backend: str = "bokeh",
backend: str = "matplotlib",
legend: bool = True,
histkwds: dict = None,
legkwds: dict = None,
Expand All @@ -73,30 +73,30 @@ def grid_histogram(
rvs (Sequence): List of names for the random variables (rvs).
rvbins (Sequence): Bin values for all random variables.
rvranges (Sequence[tuple[float, float]]): Value ranges of all random variables.
backend (str, optional): Backend for making the plot ('matplotlib' or 'bokeh').
Defaults to "bokeh".
backend (str, optional): Backend for making the plot ("matplotlib" or "bokeh").
Defaults to "matplotlib".
legend (bool, optional): Option to include a legend in each histogram plot.
Defaults to True.
histkwds (dict, optional): Keyword arguments for histogram plots.
Defaults to None.
legkwds (dict, optional): Keyword arguments for legends. Defaults to None.
**kwds:
- *figsize*: Figure size. Defaults to (14, 8)
- *figsize*: Figure size. Defaults to (6, 4)
"""
if histkwds is None:
histkwds = {}
if legkwds is None:
legkwds = {}

figsz = kwds.pop("figsize", (10, 7))
figsz = kwds.pop("figsize", (6, 4))

if len(kwds) > 0:
raise TypeError(f"grid_histogram() got unexpected keyword arguments {kwds.keys()}.")

if backend == "matplotlib":
nrv = len(rvs)
nrow = int(np.ceil(nrv / ncol))
histtype = kwds.pop("histtype", "step")
histtype = kwds.pop("histtype", "bar")

fig, ax = plt.subplots(nrow, ncol, figsize=figsz)
otherax = ax.copy()
Expand All @@ -114,7 +114,7 @@ def grid_histogram(
**histkwds,
)
if legend:
ax[axind].legend(fontsize=15, **legkwds)
ax[axind].legend(fontsize=10, **legkwds)

otherax[axind] = None

Expand All @@ -128,13 +128,16 @@ def grid_histogram(
**histkwds,
)
if legend:
ax[i].legend(fontsize=15, **legkwds)
ax[i].legend(fontsize=10, **legkwds)

otherax[i] = None

for oax in otherax.flatten():
if oax is not None:
fig.delaxes(oax)
plt.xticks(fontsize=8)
plt.yticks(fontsize=8)
plt.tight_layout()

elif backend == "bokeh":
output_notebook(hide_banner=True)
Expand Down Expand Up @@ -163,7 +166,7 @@ def grid_histogram(
gridplot(
plots, # type: ignore
ncols=ncol,
width=figsz[0] * 30,
height=figsz[1] * 28,
width=figsz[0] * 100 // ncol,
height=figsz[1] * 100 // (len(plots) // ncol),
),
)

0 comments on commit 13e4d6a

Please sign in to comment.