Source code for sheap.Plotting.SheapPlot

"""Plotting helper that draws a fitted spectrum, its model components and the residuals."""

__author__ = 'felavila'
__all__ = ["SheapPlot",]

from typing import Optional, List, Any
from dataclasses import dataclass
import jax.numpy as jnp
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
from jax import jit
import numpy as np 
 
from sheap.Profiles.Utils import make_fused_profiles
from sheap.Utils.Constants import DEFAULT_C_KMS


class SheapPlot:
    """Draw one fitted spectrum with its model components and residuals.

    An instance is built either from a full ``Sheapectral`` object or from a
    ``FitResult`` plus the spectra array that was fitted.  ``self.spec[n]`` is
    unpacked into three rows that :meth:`plot` draws as rest wavelength, flux
    and flux error.
    """

    # Colour used for each kind of curve drawn on the main panel.
    _COLORS = {
        "model": "#d62728",
        "broad": "#0f6fb4",
        "narrow": "#559e46",
        "outflow": "#bcbd22",
        "winds": "#17becf",
        "fe": "#8f220c",
        "host": "#9467bd",
        "continuum": "#000000",
        "data": "#1B1B1B",
        "balmer": "#2C2424",
        "bal": "#803939",
    }
    # Line style keyed by the 1-based component number.
    _LINESTYLES = {
        1: "-",
        2: "--",
        3: "-.",
        4: ":",
        5: (0, (5, 5)),
        6: (0, (3, 5, 1, 5)),
        7: (0, (1, 5)),
    }
    # Legend labels for the named continuum-like components.
    _CONTINUUM_LABELS = {
        "balmercontinuum": "Balmer Cont.",
        "balmerhighorder": "Higher-order Balmer",
    }

    def __init__(
        self,
        sheap: Optional["Sheapectral"] = None,
        fit_result: Optional["FitResult"] = None,
        spectra: Optional[jnp.ndarray] = None,
    ):
        """
        Initialize SheapPlot using:
        - a full Sheapectral object (preferred), or
        - a FitResult + spectra.

        Raises
        ------
        ValueError
            If neither ``sheap`` nor both ``fit_result`` and ``spectra`` are
            given.
        """
        if sheap is not None:
            self._from_sheap(sheap)
        elif fit_result is not None and spectra is not None:
            self._from_fit_result(fit_result, spectra)
        else:
            raise ValueError("Provide either `sheap` or (`fit_result` + `spectra`).")

    def _load_result(self, result):
        """Copy the fit attributes that both construction paths share."""
        self.result = result  # keep a reference to the full result object
        self.params = result.params
        self.uncertainty_params = result.uncertainty_params
        self.profile_params_index_list = result.profile_params_index_list
        self.profile_functions = result.profile_functions
        self.profile_names = result.profile_names
        self.region_list = result.region_list
        self.xlim = result.outer_limits
        self.mask = result.mask
        self.model_keywords = result.model_keywords or {}
        self.chi2_red = result.chi2_red
        # Single jitted callable that evaluates the full model.
        self.model = jit(make_fused_profiles(self.profile_functions))

    def _from_sheap(self, sheap):
        """Populate the instance from a ``Sheapectral`` object and its result."""
        self._load_result(sheap.result)
        self.spec = sheap.spectra
        self.scale = sheap.result.scale
        self.names = sheap.names
        self.z = sheap.z
        self.snr = sheap.snr

    def _from_fit_result(self, result, spectra):
        """Populate the instance from a bare fit result and its spectra.

        Object names default to the row index and the flux scale is the NaN-safe
        maximum of row 1 of each spectrum.
        """
        self._load_result(result)
        self.spec = spectra
        self.scale = jnp.nanmax(spectra[:, 1, :], axis=1)
        self.names = [str(i) for i in range(self.params.shape[0])]
        self.z = result.z
        # The result may not carry an SNR; store None so plot() can skip the
        # SNR line instead of failing on a missing attribute.
        self.snr = getattr(result, "snr", None)

    def _info_text(self, n):
        """Return the (left, right) header strings shown when ``add_extra`` is on.

        The SNR line is left out when no SNR is available for this instance.
        """
        left = f"ID {self.names[n]} ({n})\n z = {self.z[n]}"
        right = f"$\\chi_{{\\rm red}}$ = {self.chi2_red[n]:.2f}"
        snr = getattr(self, "snr", None)
        if snr is not None:
            right = f"SNR = {snr[n]:.2f}\n" + right
        return left, right

    @staticmethod
    def _unique_legend_entries(handles, labels):
        """Drop repeated legend labels, keeping the first handle of each."""
        first_handle = {}
        for handle, label in zip(handles, labels):
            first_handle.setdefault(label, handle)
        return list(first_handle.values()), list(first_handle.keys())

    @staticmethod
    def _line_annotations(region, profile_func, values):
        """Return ``(x, text)`` pairs naming the line(s) of one profile.

        For a multi-line region the tabulated centres are shifted by the
        profile's ``vshift_kms`` parameter; when the profile has no such
        parameter the centres are used unshifted.  For a single line the
        second parameter (``values[1]``) is used as the centre.
        """
        suffix = f"{region.region}_{region.component}"
        if isinstance(region.region_lines, list):
            names = list(profile_func.param_names)
            if "vshift_kms" in names:
                shift = float(values[names.index("vshift_kms")])
            else:
                shift = 0.0
            # NOTE: the author marks this shift formula as valid for Gaussian profiles only.
            centers = np.atleast_1d(np.asarray(region.center, dtype=float))
            centers = centers * (1.0 + shift / DEFAULT_C_KMS)
            return [
                (float(center), f"- {line}_{suffix}".replace("_", " "))
                for line, center in zip(region.region_lines, centers)
            ]
        return [(float(values[1]), f"- {region.line_name}_{suffix}".replace("_", " "))]

    def plot(self, n, save=None, add_lines_name=False, residual=True, params=None, add_xline=None,
             flux_unit=r"$\mathrm{erg\,s^{-1}\,cm^{-2}\,\AA^{-1}}$", add_legend=True, add_extra=True,
             **kwargs):
        """Plot spectrum, model components, and residuals for a given index `n`.

        Parameters
        ----------
        n : int
            Index of the object to plot.
        save : path-like, optional
            If given, the figure is written there (300 dpi); otherwise it is shown.
        add_lines_name : bool
            Write the line names as vertical text on the main panel.
        residual : bool
            Add a lower panel with the error-normalised residuals.
        params : array, optional
            Parameter vector to draw instead of the stored ``self.params[n]``.
        add_xline : scalar or sequence, optional
            Position(s) at which a vertical marker line is drawn.
        flux_unit : str
            Unit string placed in the flux axis label.
        add_legend : bool
            Draw a legend with one entry per distinct label.
        add_extra : bool
            Show the two-column header with SNR (when known) and reduced
            chi-square in addition to the ID and redshift.
        **kwargs
            Only ``ylim`` and ``xlim`` are read; they override the defaults
            ``[0, self.scale[n]]`` and ``self.xlim``.
        """
        colors = self._COLORS
        linestyles = self._LINESTYLES

        ylim = kwargs.get("ylim", [0, self.scale[n]])
        xlim = kwargs.get("xlim", self.xlim)
        x_axis, y_axis, yerr = self.spec[n, :]
        params = params if params is not None else self.params[n]
        mask = self.mask[n]
        fit_y = self.model(x_axis, params)

        ax2 = None
        if residual:
            fig, (ax1, ax2) = plt.subplots(
                2, 1, sharex=True, figsize=(30, 8),
                gridspec_kw={'height_ratios': [2, 1], 'hspace': 0.1},
            )
        else:
            fig, ax1 = plt.subplots(1, 1, sharex=True, figsize=(30, 8))
        # x in data coordinates, y in axes fraction: used for the line-name text.
        trans = mtransforms.blended_transform_factory(ax1.transData, ax1.transAxes)

        cont_counter = 1
        for profile_name, profile_func, region, idxs in zip(
            self.profile_names,
            self.profile_functions,
            self.region_list,
            self.profile_params_index_list,
        ):
            values = params[idxs]
            component_y = profile_func(x_axis, values)
            kind = region.region
            kind_lower = kind.lower()

            if kind in ("continuum", "balmer"):
                ax1.plot(
                    x_axis, component_y, zorder=3,
                    label=self._CONTINUUM_LABELS.get(region.line_name, "Cont."),
                    color=colors["continuum"],
                    # Each continuum-like component gets the next line style;
                    # fall back to solid when the table runs out.
                    ls=linestyles.get(cont_counter, "-"),
                )
                cont_counter += 1
            elif "Fe" in profile_name or "fe" in kind_lower:
                ax1.plot(
                    x_axis, component_y, ls=linestyles[1], zorder=3,
                    # Unknown Fe-like region names fall back to the Fe colour.
                    color=colors.get(kind_lower, colors["fe"]),
                    label="Fe II", linewidth=3,
                )
            elif "host" in kind_lower:
                ax1.plot(
                    x_axis, component_y, ls=linestyles[1], zorder=3,
                    color=colors["host"], label="Host", linewidth=3,
                )
            else:
                label = kind.capitalize()
                if region.component > 1:
                    label = f"{label} {region.component}"
                # Broad components are drawn on top of the other curves.
                zorder = 10 if kind == "broad" else 0
                ax1.plot(
                    x_axis, component_y,
                    ls=linestyles.get(region.component, "-"),
                    zorder=zorder,
                    # None lets matplotlib pick a colour for unknown regions.
                    color=colors.get(kind),
                    label=label, linewidth=3,
                )
                if add_lines_name:
                    for x_text, text in self._line_annotations(region, profile_func, values):
                        if min(xlim) < x_text < max(xlim):
                            ypos = 0.25 if "broad" in text else 0.75
                            ax1.text(
                                x_text, ypos, text, transform=trans, rotation=90,
                                fontsize=20, zorder=10, ha="center",
                            )

        ax1.plot(x_axis, fit_y, linewidth=3, zorder=2, color=colors["model"], label="Model")
        ax1.errorbar(x_axis, y_axis, yerr=yerr, ecolor='dimgray', color=colors["data"],
                     zorder=1, label="Obs.")
        # Shade the pixels flagged by the mask over the full y-range.
        ax1.fill_between(x_axis, *ylim, where=mask, color="grey", alpha=0.3, zorder=10)

        # Accept a scalar, a list or an array; booleans are ignored so that
        # add_xline=False keeps meaning "no marker".
        if add_xline is not None and not isinstance(add_xline, bool):
            for x_mark in np.atleast_1d(add_xline):
                ax1.axvline(float(x_mark), c='#A020F0', linewidth=3)

        ax1.set_ylabel(f"Flux [{flux_unit}]", fontsize=25)
        ax1.set_ylim(ylim)
        ax1.set_xlim(xlim)

        if add_extra:
            x0, y0 = 0.65, 1.21
            dx = 0.24  # horizontal separation of the two columns, in axes coords
            left_lines, right_lines = self._info_text(n)
            ax1.text(x0, y0, left_lines, fontsize=25, transform=ax1.transAxes,
                     ha="left", va="top")
            ax1.text(x0 + dx, y0, right_lines, fontsize=25, transform=ax1.transAxes,
                     ha="left", va="top")
        else:
            ax1.text(
                0.75, 1.0, f"ID {self.names[n]} ({n}) \n z = {self.z[n]}",
                fontsize=25, transform=ax1.transAxes, ha='left', va='bottom',
            )

        ax1.tick_params(axis='both', labelsize=25)
        ax1.yaxis.offsetText.set_fontsize(25)

        if add_legend:
            handles, labels = self._unique_legend_entries(*ax1.get_legend_handles_labels())
            ax1.legend(
                handles=handles, labels=labels, fontsize=25, markerscale=0.8,
                labelspacing=0.5, frameon=False, ncol=3, columnspacing=1.5,
                handletextpad=0.4,
            )

        if residual:
            residuals = (fit_y - y_axis) / yerr
            # Masked pixels are shown at zero.
            residuals = residuals.at[mask].set(0.0)
            ax2.axhline(0, ls="--", linewidth=5, color="black")
            ax2.scatter(x_axis, residuals, alpha=0.9, zorder=10, c="#4C72B0")
            ax2.set_ylabel("Norm. Res.", fontsize=25)
            ax2.set_xlabel("Rest wavelength [Å]", fontsize=25)
            ax2.tick_params(axis='both', labelsize=25, pad=10)
        else:
            ax1.set_xlabel("Rest wavelength [Å]", fontsize=25)

        if save:
            fig.savefig(save, dpi=300, bbox_inches='tight')
        else:
            plt.show()