Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite of Model.plot() #281

Merged
merged 5 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 0 additions & 191 deletions pahfit/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
import matplotlib as mpl

from pahfit.instrument import within_segment, fwhm
from pahfit.errors import PAHFITModelError
Expand Down Expand Up @@ -265,196 +264,6 @@ def model_from_param_info(param_info):

return model

@staticmethod
def plot(axs, x, y, yerr, model, model_samples=1000, scalefac_resid=2):
"""
Plot model using axis object.

Parameters
----------
axs : matplotlib.axis objects
where to put the plot
x : floats
wavelength points
y : floats
observed spectrum
yerr: floats
observed spectrum uncertainties
model : PAHFITBase model (astropy modeling CompoundModel)
model giving all the components and parameters
model_samples : int
Total number of wavelength points to allocate to the model display
scalefac_resid : float
Factor multiplying the standard deviation of the residuals to adjust plot limits
"""
# remove units if they are present
if hasattr(x, "value"):
x = x.value
if hasattr(y, "value"):
y = y.value
if hasattr(yerr, "value"):
yerr = yerr.value

# Fine x samples for model fit
x_mod = np.logspace(np.log10(min(x)), np.log10(max(x)), model_samples)

# spectrum and best fit model
ax = axs[0]
ax.set_yscale("linear")
ax.set_xscale("log")
ax.minorticks_on()
ax.tick_params(axis="both",
which="major",
top="on",
right="on",
direction="in",
length=10)
ax.tick_params(axis="both",
which="minor",
top="on",
right="on",
direction="in",
length=5)

ax_att = ax.twinx() # axis for plotting the extinction curve
ax_att.tick_params(which="minor", direction="in", length=5)
ax_att.tick_params(which="major", direction="in", length=10)
ax_att.minorticks_on()

# get the extinction model (probably a better way to do this)
ext_model = None
for cmodel in model:
if isinstance(cmodel, S07_attenuation):
ext_model = cmodel(x_mod)

# get additional extinction components that can be
# characterized by functional forms (Drude profile in this case)
for cmodel in model:
if isinstance(cmodel, att_Drude1D):
if ext_model is not None:
ext_model *= cmodel(x_mod)
else:
ext_model = cmodel(x_mod)
ax_att.plot(x_mod, ext_model, "k--", alpha=0.5)
ax_att.set_ylabel("Attenuation")
ax_att.set_ylim(0, 1.1)

# Define legend lines
Leg_lines = [
mpl.lines.Line2D([0], [0], color="k", linestyle="--", lw=2),
mpl.lines.Line2D([0], [0], color="#FE6100", lw=2),
mpl.lines.Line2D([0], [0], color="#648FFF", lw=2, alpha=0.5),
mpl.lines.Line2D([0], [0], color="#DC267F", lw=2, alpha=0.5),
mpl.lines.Line2D([0], [0], color="#785EF0", lw=2, alpha=1),
mpl.lines.Line2D([0], [0], color="#FFB000", lw=2, alpha=0.5),
]

# create the continum compound model (base for plotting lines)
cont_components = []

for cmodel in model:
if isinstance(cmodel, BlackBody1D):
cont_components.append(cmodel)
# plot as we go
ax.plot(x_mod,
cmodel(x_mod) * ext_model / x_mod,
"#FFB000",
alpha=0.5)
cont_model = cont_components[0]
for cmodel in cont_components[1:]:
cont_model += cmodel
cont_y = cont_model(x_mod)

# now plot the dust bands and lines
for cmodel in model:
if isinstance(cmodel, Gaussian1D):
ax.plot(
x_mod,
(cont_y + cmodel(x_mod)) * ext_model / x_mod,
"#DC267F",
alpha=0.5,
)
if isinstance(cmodel, Drude1D):
ax.plot(
x_mod,
(cont_y + cmodel(x_mod)) * ext_model / x_mod,
"#648FFF",
alpha=0.5,
)

ax.plot(x_mod, cont_y * ext_model / x_mod, "#785EF0", alpha=1)

ax.plot(x_mod, model(x_mod) / x_mod, "#FE6100", alpha=1)
ax.errorbar(
x,
y / x,
yerr=yerr / x,
fmt="o",
markeredgecolor="k",
markerfacecolor="none",
ecolor="k",
elinewidth=0.2,
capsize=0.5,
markersize=6,
)

ax.set_ylim(0)
ax.set_ylabel(r"$\nu F_{\nu}$")

ax.legend(
Leg_lines,
[
"S07_attenuation",
"Spectrum Fit",
"Dust Features",
r"Atomic and $H_2$ Lines",
"Total Continuum Emissions",
"Continuum Components",
],
prop={"size": 10},
loc="best",
facecolor="white",
framealpha=1,
ncol=3,
)

# residuals, lower sub-figure
res = (y - model(x)) / x
std = np.std(res)
ax = axs[1]

ax.set_yscale("linear")
ax.set_xscale("log")
ax.tick_params(axis="both",
which="major",
top="on",
right="on",
direction="in",
length=10)
ax.tick_params(axis="both",
which="minor",
top="on",
right="on",
direction="in",
length=5)
ax.minorticks_on()

# Custom X axis ticks
ax.xaxis.set_ticks(
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 20, 25, 30, 40])

ax.axhline(0, linestyle="--", color="gray", zorder=0)
ax.plot(x, res, "ko-", fillstyle="none", zorder=1)
ax.set_ylim(-scalefac_resid * std, scalefac_resid * std)
ax.set_xlim(np.floor(np.amin(x)), np.ceil(np.amax(x)))
ax.set_xlabel(r"$\lambda$ [$\mu m$]")
ax.set_ylabel("Residuals [%]")

# scalar x-axis marks
ax.xaxis.set_minor_formatter(mpl.ticker.ScalarFormatter())
ax.xaxis.set_major_formatter(mpl.ticker.ScalarFormatter())

@staticmethod
def update_dictionary(feature_dict, instrumentname, update_fwhms=False, redshift=0):
"""
Update parameter dictionary based on the instrument name.
Expand Down
Loading
Loading