Source code for sheap.Plotting.SheapPlot

"""This module handles ?."""

__author__ = 'felavila'
__all__ = ["SheapPlot",]

from typing import Optional, List, Any
from dataclasses import dataclass
import jax.numpy as jnp
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
from jax import jit
 
from sheap.Profiles.Utils import make_fused_profiles
from sheap.Utils.Constants import DEFAULT_C_KMS


[docs] class SheapPlot: def __init__( self, sheap: Optional["Sheapectral"] = None, fit_result: Optional["FitResult"] = None, spectra: Optional[jnp.ndarray] = None, ): """ Initialize SheapPlot using: - a full Sheapectral object (preferred), or - a FitResult + spectra. """ if sheap is not None: self._from_sheap(sheap) elif fit_result is not None and spectra is not None: self._from_fit_result(fit_result, spectra) else: raise ValueError("Provide either `sheap` or (`fit_result` + `spectra`).") def _from_sheap(self, sheap): self.spec = sheap.spectra #self.max_flux = sheap.max_flux self.result = sheap.result # keep reference if needed result = sheap.result # for convenience self.params = result.params self.scale = result.scale self.uncertainty_params = result.uncertainty_params self.profile_params_index_list = result.profile_params_index_list self.profile_functions = result.profile_functions self.profile_names = result.profile_names self.region_list = result.region_list self.xlim = result.outer_limits self.mask = result.mask self.names = sheap.names self.model_keywords = result.model_keywords or {} self.z = sheap.z self.snr = sheap.snr self.chi2_red = result.chi2_red self.model = jit(make_fused_profiles(self.profile_functions)) def _from_fit_result(self, result, spectra): self.spec = spectra self.scale = jnp.nanmax(spectra[:, 1, :], axis=1) self.params = result.params self.uncertainty_params = result.uncertainty_params self.profile_params_index_list = result.profile_params_index_list self.profile_functions = result.profile_functions self.profile_names = result.profile_names self.region_list = result.region_list self.xlim = result.outer_limits self.mask = result.mask self.names = [str(i) for i in range(self.params.shape[0])] self.model_keywords = result.model_keywords or {} self.z = result.z #self.snr = result.snr self.chi2_red = result.chi2_red #self.fe_mode = self.model_keywords.get("fe_mode") self.model = jit(make_fused_profiles(self.profile_functions))
[docs] def plot(self, n, save=None, add_lines_name=False, residual=True,params=None,add_xline=None, flux_unit=r"$\mathrm{erg\,s^{-1}\,cm^{-2}\,\AA^{-1}}$",add_legend=True,add_extra=True, **kwargs): """Plot spectrum, model components, and residuals for a given index `n`.""" # TODO is time to update this. default_colors = list(plt.rcParams['axes.prop_cycle'].by_key()['color']) filtered_colors = [ c for c in default_colors if c not in ['black', 'red', 'grey', '#7f7f7f',"blue","green"] ] * 50 ylim = kwargs.get("ylim", [0,self.scale[n]]) xlim = kwargs.get("xlim", self.xlim) x_axis, y_axis, yerr = self.spec[n, :] params = params if params is not None else self.params[n] mask = self.mask[n] fit_y = self.model(x_axis, params) if residual: fig, (ax1, ax2) = plt.subplots( 2, 1, sharex=True, figsize=(30, 8), gridspec_kw={'height_ratios': [2, 1], 'hspace': 0.1}, ) else: fig, ax1 = plt.subplots(1, 1, sharex=True, figsize=(30, 8)) trans = mtransforms.blended_transform_factory(ax1.transData, ax1.transAxes) colors_by_region = {"model":"#d62728","broad":"#0f6fb4","narrow":"#559e46","outflow":"#bcbd22","winds":"#17becf","fe":"#8f220c","host":"#9467bd", "continuum":"#000000","data":"#1B1B1B","balmer":"#2C2424","bal":"#803939"} component_ls = {1: "-",2: "--",3: "-.",4: ":", 5: (0, (5, 5)), 6: (0, (3, 5, 1, 5)), 7: (0, (1, 5))} cont_counter = 1 cont_names = {"balmercontinuum":"Balmer Cont.","balmerhighorder":"Higher-order Balmer"} for i, (profile_name, profile_func, region, idxs) in enumerate(zip(self.profile_names,self.profile_functions,self.region_list,self.profile_params_index_list,)): #print(profile_name, profile_func, region, idxs) values = params[idxs] #print(profile_name) if region.region == "continuum" or region.region=="balmer": #print(region.line_name) component_y = profile_func(x_axis, values) line_name = cont_names.get(region.line_name,"Cont.") #region.line_name #print(region) ax1.plot(x_axis, component_y, zorder=3, label = line_name, color= colors_by_region["continuum"],ls = component_ls[cont_counter]) cont_counter += 1 elif "Fe" in profile_name or "fe" in region.region.lower() or region.region == "fe": component_y = profile_func(x_axis, values) ax1.plot(x_axis, component_y, ls=component_ls[1], zorder=3, color=colors_by_region[region.region.lower()],label="Fe II", linewidth=3) elif "host" in region.region.lower(): f = 1.0/0.6028481012658228 component_y = profile_func(x_axis, values) ax1.plot(x_axis, component_y, ls=component_ls[1], zorder=3, color=colors_by_region["host"],label="Host", linewidth=3) else: component_y = profile_func(x_axis, values) label = region.region.capitalize() #print(region.component) zorder = 0 if region.component>1: label = f"{region.region.capitalize()} {region.component}" if "broad" == region.region: zorder = 10 ax1.plot(x_axis, component_y, ls=component_ls[region.component], zorder=zorder, color=colors_by_region[region.region], label=label, linewidth=3) #ax1.axvline(values[1], ls="--", linewidth=1, color="k") if add_lines_name and isinstance(region.region_lines,list): import numpy as np idx_shift = np.where("vshift_kms" == np.array(profile_func.param_names))[0] #print(idx_shift) centers = np.array(region.center) *(1+values[*idx_shift]/DEFAULT_C_KMS)#This is only true for gaussian for ii,c in enumerate(centers): #ax1.axvline(c) if min(xlim) < c < max(xlim): #print(f"- {region.region_lines[ii]}_{region.region}_{region.component}".replace("_", " "),c) label = f"- {region.region_lines[ii]}_{region.region}_{region.component}".replace("_", " ") ypos = 0.25 if "broad" in label else 0.75 ax1.text( c, ypos, label, transform=trans, rotation=90, fontsize=20, zorder=10, ha = "center") elif add_lines_name and min(xlim) < values[1] < max(xlim): label = f"- {region.line_name}_{region.region}_{region.component}".replace( "_", " " ) ypos = 0.25 if "broad" in label else 0.75 ax1.text( values[1], ypos, label, transform=trans, rotation=90, fontsize=20, zorder=10, ha = "center" ) ax1.plot(x_axis, fit_y, linewidth=3, zorder=2, color=colors_by_region["model"],label="Model")# ax1.errorbar(x_axis, y_axis, yerr=yerr, ecolor='dimgray', color=colors_by_region["data"], zorder=1,label="Obs.") ax1.fill_between(x_axis, *ylim, where=mask, color="grey", alpha=0.3, zorder=10) if add_xline: if isinstance(add_xline,(float,int)): add_xline = [add_xline] for KK in add_xline: ax1.axvline(KK,c='#A020F0',linewidth=3) ax1.set_ylabel(f"Flux [{flux_unit}]", fontsize=25) ax1.set_ylim(ylim) ax1.set_xlim(xlim) if add_extra: x0, y0 = 0.65, 1.21 dx = 0.24 # horizontal separation in axes coords (tune) left_lines = f"ID {self.names[n]} ({n})\n z = {self.z[n]}" right_lines = f"SNR = {self.snr[n]:.2f}\n$\\chi_{{\\rm red}}$ = {self.chi2_red[n]:.2f}" # Left column ax1.text( x0, y0, left_lines, fontsize=25, transform=ax1.transAxes, ha="left", va="top", ) # Right column ax1.text( x0 + dx, y0, right_lines, fontsize=25, transform=ax1.transAxes, ha="left", va="top", ) else: ax1.text( 0.75, 1.0, f"ID {self.names[n]} ({n}) \n z = {self.z[n]}", fontsize=25, transform=ax1.transAxes, ha='left', va='bottom', ) #font_legend = ax1.tick_params(axis='both', labelsize=25) ax1.yaxis.offsetText.set_fontsize(25) if add_legend: handles, labels = ax1.get_legend_handles_labels() # Remove duplicates while keeping order unique = {} for h, l in zip(handles, labels): if l not in unique: unique[l] = h ax1.legend(handles=list(unique.values()),labels=list(unique.keys()),fontsize=25, markerscale=0.8,labelspacing=0.5,frameon=False,ncol=3,columnspacing=1.5,handletextpad=0.4,) if residual: residuals = (fit_y - y_axis) / yerr residuals = residuals.at[mask].set(0.0) ax2.axhline(0, ls="--", linewidth=5, color="black") ax2.scatter(x_axis, residuals, alpha=0.9, zorder=10,c="#4C72B0") ax2.set_ylabel("Norm. Res.", fontsize=25) ax2.set_xlabel("Rest wavelength [Å]", fontsize=25) ax2.tick_params(axis='both', labelsize=25, pad=10) else: ax1.set_xlabel("Rest wavelength [Å]", fontsize=25) if save: plt.savefig(save, dpi=300, bbox_inches='tight') #plt.close() else: plt.show()