Source code for sheap.Sheapectral.Sheapectral

"""
Main sheap Interface
====================

Provides the `Sheapectral` class, the high-level entry point for
loading, correcting, fitting, and analyzing AGN spectra with *sheap*.

Contents
--------
- **Spectral I/O**: load spectra from arrays or files.
- **Corrections**: apply Galactic extinction and redshift corrections.
- **Modeling**: build complex spectral regions via `SheapModelBuilder`.
- **Fitting**: run JAX-based optimization with `SheapModelFitting`.
- **Posterior Sampling**: estimate parameters using single, pseudo-MC, MC, or MCMC.
- **Persistence**: save/load full state with pickle.
- **Visualization**: quicklook plotting and model visualization with `SheapPlot`.

Notes
-----
- Input spectra are expected in shape `(n_objects, 3[,4], n_pixels)`,
  with channels = (wavelength, flux, error[, wdisp]).
- Velocity resolution (FWHM) is computed from dispersion when available.
- Main workflow:

  .. code-block:: python

	 sheap = Sheapectral("spectrum.fits", z=0.5, coords=(l, b))
	 sheap.makemodel(6500, 6600, n_narrow=1, n_broad=2)
	 sheap.fitmodel()
	 sheap.posteriors(sampling_method="montecarlo")

- Results are stored in `self.result` (`SheapResult`).
"""

from __future__ import annotations

__author__ = 'felavila'


__all__ = ["Sheapectral","logger",]

import logging
import pickle
import sys
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
import time

import jax.numpy as jnp
import numpy as np

from sheap.Core import SpectralLine,SheapResult,ArrayLike

from sheap.Sheapectral.Utils.SpectralSetup import pad_error_channel,ensure_sfd_data,profile_functions_from_region_list
from sheap.SheapModelFitting.SheapModelFitting import SheapModelFitting
from sheap.SheapModelBuilder.SheapModelBuilder import SheapModelBuilder
from sheap.SheaProducts.SheaProducts import SheaProducts
from sheap.Plotting.SheapPlot import SheapPlot
from sheap.Utils.Constants  import DEFAULT_C_KMS



logger = logging.getLogger(__name__)


#WE CAN mode fast, we will stay in this 32 to go faster. 
#TODO -> helper to given "samples_phys" params can calculate again the SheaProducts
#TODO -> update this part to be able to show the actual name of the parameter lets say line with 0,1,2,3,4 is easier for code reason but for visualitation could be messy 
#TODO -> check the buge associated to have nan in the flux chanel is best change for 0 and just bust the error
[docs] class Sheapectral: """ Main interface class for loading, correcting, fitting, and analyzing AGN spectra. This class handles: - Spectral I/O and validation - Extinction and redshift correction - Spectral region definition and model building - JAX-based optimization and uncertainty estimation - Posterior sampling using Monte Carlo or MCMC - Saving/loading results from pickle - Quick visualization and result summaries Parameters ---------- spectra : str or jnp.ndarray Input spectra. If a string, it should be a path to a file readable by `np.loadtxt`. Expected shape after parsing: (n_objects, 3[, or 4], n_pixels). z : float or jnp.ndarray, optional Redshift(s) for each spectrum. Scalar or 1D array of shape (n_objects,). coords : jnp.ndarray, optional Galactic coordinates (l, b) for extinction correction, shape (n_objects, 2). ebv : jnp.ndarray, optional E(B-V) values. If not provided, estimated from `coords` using the SFD map. names : list of str, optional Object names. Defaults to stringified index if not given. extinction_correction : {'pending', 'done'}, optional Whether to apply extinction correction during initialization. redshift_correction : {'pending', 'done'}, optional Whether to apply redshift correction during initialization. **kwargs Additional arguments passed to underlying utilities. Attributes ---------- spectra : jnp.ndarray 3D array with shape (n_objects, 3, n_pixels) [wavelength, flux, error]. wdisp : jnp.ndarray or None Wavelength dispersion (if available) per pixel. fwhm_lambda : jnp.ndarray Instrumental resolution in Angstroms, if `wdisp` provided. fwhm_kms : jnp.ndarray Instrumental resolution in km/s. result : SheapResult #TODO maybe change this name Output of the fitting routine, including parameters and metadata. modelbuild : SheapModelBuilder Configuration used to build the model region. plotter : SheapPlot Plotting backend object. Methods ------- makemodel(xmin, xmax, n_narrow=1, n_broad=1, **kwargs) Create a model from line and continuum definitions. fitmodel(...) Perform spectral model fitting using the configured model. estimate_posteriors(...) Estimate posterior distributions using MC or MCMC or just give and estimation of the params. save_to_pickle(filepath) Save object state to a pickle file. from_pickle(filepath) Load a `Sheapectral` instance from a saved pickle. result_panda(n) Return a Pandas DataFrame of the fit parameters for object `n`. quicklook(idx, ax=None, xlim=None, ylim=None) Plot flux + error for spectrum `idx`. modelplot Return or initialize plotting interface (SheapPlot). """ def __init__( self, spectra: Union[str, jnp.ndarray], z: Optional[Union[float, jnp.ndarray]] = None, coords: Optional[jnp.ndarray] = None, ebv: Optional[jnp.ndarray] = None, names: Optional[list[str]] = None, extinction_correction: str = "pending", # this only can be pending or done redshift_correction: str = "pending", # this only can be pending or done **kwargs,): """ Initialize Sheapectral object, load and optionally correct spectra. Parameters ---------- spectra : str or jnp.ndarray Path to data file or array of raw spectra. z : float or jnp.ndarray (N,) shape,optional Redshift(s) to apply; repeated if scalar. coords : ? Coordinates for extinction map lookup. ebv : ? E(B-V) values, overrides coords-based estimation. names : list of str, optional Names for each spectrum. extinction_correction : {'pending', 'done'}, optional Control flag for extinction step. redshift_correction : {'pending', 'done'}, optional Control flag for redshift step. **kwargs : ? Additional parameters passed internally. """ self.log = logging.getLogger(self.__class__.__name__) self.extinction_correction = extinction_correction self.redshift_correction = redshift_correction self.wdisp = None spec_arr = self._load_spectra(spectra) if spec_arr.shape[1] == 4: self.wdisp = spec_arr[:,3,:] spec_arr = spec_arr[:,[0,1,2],:] spec_arr = pad_error_channel(spec_arr) self.spectra = spec_arr#.astype(jnp.float32)# if self.wdisp is not None: #Velocity scale in km/s per pixel (eq.8 of Cappellari 2017) #This aprouch only is usseful for log sample spectra # Resolution fwhm_lambda of every pixel, in Angstroms self.velscale = np.log(np.atleast_2d(self.spectra[:,0,-1]/self.spectra[:,0,0]).T)/(self.spectra.shape[2]- 1 ) * DEFAULT_C_KMS self.dlam = np.gradient(self.spectra[:,0,:],axis=1) self.fwhm_lambda = 2.355 * self.wdisp * self.dlam #A self.fwhm_kms = self.fwhm_lambda / self.spectra[:,0,:] * DEFAULT_C_KMS # in cases without wdisp # self.coords = coords # may be None – handle carefully downstream if self.coords is None: print("no inform coords") self.extinction_correction = "done"#< self.ebv = ebv self.z = self._prepare_z(z, self.spectra.shape[0]) self.names = (np.atleast_1d(names) if names is not None else np.arange(self.spectra.shape[0]).astype(str)) if self.names.shape[0] !=self.spectra.shape[0]: print(f"The number of names ({len(self.names.shape[0])}) is different from the number of spectra ({self.spectra.shape[0]}) the code will use the inner names") self.names = np.arange(self.spectra.shape[0]).astype(str) #print(self.names.shape,self.spectra.shape) if self.extinction_correction == "pending" and (self.coords is not None or self.ebv is not None): print("extinction correction will be do it, change 'extinction_correction' to done if you want to avoid this step") self._apply_extinction() self.extinction_correction = "done" if self.redshift_correction == "pending" and self.z is not None: print("redshift correction will be do it, change 'redshift_correction' to done if you want to avoid this step") self._apply_redshift() self.redshift_correction = "done" self.sheap_set_up() self.default_limits = (float(np.min(self.spectra[:,0,:])),float(np.max(self.spectra[:,0,:]))) self.snr = np.nanmean(self.spectra[:, 1, :] / self.spectra[:, 2, :], axis=1) #signal to noise def _load_spectra(self, spectra: Union[str, ArrayLike]) -> jnp.ndarray: """ Load spectra from file or array. Parameters ---------- spectra : str, Path, np.ndarray, list, or jnp.ndarray Input data source. Returns ------- jnp.ndarray Array of shape (n_objects, channels, n_pixels). Raises ------ TypeError If input type is unsupported. """ if isinstance(spectra, (str, Path)): arr = np.loadtxt(spectra) return jnp.array(arr).T # ensure (c, λ) then transpose later elif isinstance(spectra, np.ndarray): return jnp.array(spectra) elif isinstance(spectra,list): return jnp.array(spectra) elif isinstance(spectra, jnp.ndarray): return spectra raise TypeError("spectra must be a path or ndarray") def _prepare_z(self, z: Optional[Union[float, ArrayLike]], nobj: int) -> Optional[jnp.ndarray]: """ Normalize redshift input to array form. Parameters ---------- z : float, array-like, or None Input redshift(s). nobj : int Number of spectra objects. Returns ------- jnp.ndarray or None Array of length nobj or None if z was None. """ if z is None: return jnp.repeat(0, nobj) if isinstance(z, (int, float)): return jnp.repeat(z, nobj) return jnp.array(z) def _apply_extinction(self) -> None: """ Apply Galactic extinction correction to the flux and error channels. Uses Cardelli et al. (1989) law; if coords provided, uses SFD map. Parameters ---------- None Returns ------- None """ from sfdmap2 import sfdmap from sheap.Sheapectral.Utils.BasicCorrections import unred ebv = self.ebv if self.coords is not None: self.coords = jnp.array(self.coords) l, b = self.coords.T # type: ignore[union-attr] sfd_path = Path(__file__).resolve().parent.parent / "SuportData" / "sfddata/" ensure_sfd_data(sfd_path) ebv_func = sfdmap.SFDMap(sfd_path).ebv ebv = ebv_func(l, b) corrected = unred(*np.swapaxes(self.spectra[:, [0, 1], :], 0, 1), ebv) # propagate to error channel proportionally as pyqso ratio = corrected / self.spectra[:, 1, :] self.spectra = self.spectra.at[:, 1, :].set(corrected) self.spectra = self.spectra.at[:, 2, :].multiply(ratio) def _apply_redshift(self) -> None: """ Apply redshift correction (deredshift) to wavelength axis. Parameters ---------- None Returns ------- None """ from sheap.Sheapectral.Utils.BasicCorrections import deredshift self.spectra = deredshift(self.spectra, self.z)
[docs] def sheap_set_up(self): """ Ensure spectra have leading object axis, record shape and NaN mask. Parameters ---------- None Returns ------- None """ if len(self.spectra.shape) <= 2: self.spectra = self.spectra[jnp.newaxis, :] self.spectra_shape = self.spectra.shape # ? self.spectra_nans = jnp.isnan(self.spectra) c1 = self.spectra[:, 1, :] c1 = jnp.where(jnp.isnan(c1), 0.0, c1) c2 = self.spectra[:, 2, :] c2 = jnp.where(jnp.isnan(c2) | jnp.isnan(c1), 3.4028235e+38, c2)#max float32 # write back functionally self.spectra = self.spectra.at[:, 1, :].set(c1) self.spectra = self.spectra.at[:, 2, :].set(c2)
#self.spectra = spec
[docs] def makemodel(self,limits: tuple = None ,n_narrow: int = 1,n_broad: int = 1,group_method=True, add_balmer_continuum = True ,add_balmerhighorder_continuum = True ,**kwargs): """ Initialize a SheapModelBuilder for later fitting. Parameters ---------- limits : tuple xmin,xmax n_narrow : int, optional Number of narrow components per line. n_broad : int, optional Number of broad components per line. **kwargs : ? Additional SheapModelBuilder options. Returns ------- None """ if not limits: print(f"We will use the defualt limits {self.default_limits}") xmin,xmax = self.default_limits else: xmin,xmax = min(limits),max(limits) if xmin < 3600 and add_balmer_continuum: #3000–3646 add_balmer_continuum = add_balmer_continuum if (3700 > xmin and 3910 < xmax) and add_balmerhighorder_continuum: #3646-3910 add_balmerhighorder_continuum = add_balmerhighorder_continuum self.modelbuild = SheapModelBuilder(xmin=xmin,xmax=xmax,n_narrow=n_narrow,n_broad=n_broad,group_method=group_method, add_balmerhighorder_continuum=add_balmerhighorder_continuum, add_balmer_continuum= add_balmer_continuum, **kwargs)
[docs] def fitmodel(self,run_fit=True, list_num_steps=[1_000],list_learning_rate = [1e-2] ,covariance_error = False,profile: str ='gaussian' ,add_penalty_function=False,method="adam",penalty_weight: float = 0.00 ,curvature_weight: float = 0.0,smoothness_weight: float = 0.0,max_weight: float = 0.0,limits_overrides={}): """ Execute fitting of the prepared region on the spectra. Parameters ---------- list_num_steps : list of int, optional Maximum optimization steps per routine stage. run_uncertainty_params : bool, optional Whether to compute parameter uncertainties. profile : str, optional Line profile type for fitting. list_learning_rate : list of float, optional Learning rates for each stage. run_fit : bool, optional If False, construct the SheapModelFitting object without fitting. add_penalty_function : bool, optional If True, include host-model penalty. Raises ------ RuntimeError If makemodel() was not called first. Returns ------- None """ if not hasattr(self, "modelbuild"): raise RuntimeError("makemodel() must be called before fitmodel()") self.fitting_class = SheapModelFitting.from_builder(self.modelbuild,limits_overrides=limits_overrides,profile=profile) #until here only uses the things that it knows from modelbuild spectra = self.spectra.astype(jnp.float32) if run_fit: Warning(f"You selected run_fit = {run_fit}, if you want run the fit change to True") self.fitting_class(spectra,list_num_steps = list_num_steps,list_learning_rate =list_learning_rate, covariance_error= covariance_error,add_penalty_function=add_penalty_function,method=method, penalty_weight= penalty_weight, curvature_weight= curvature_weight, smoothness_weight= smoothness_weight,max_weight= max_weight) self.spectral_model = self.fitting_class.model #the actual model is self.params_obj = self.fitting_class.params_obj fit_output = self.fitting_class.sheapresult fit_output.source = "computed" self.result = SheapResult( params=fit_output.params, uncertainty_params=fit_output.uncertainty_params, mask=fit_output.mask, profile_functions=fit_output.profile_functions, profile_names=fit_output.profile_names, loss=fit_output.loss, profile_params_index_list=fit_output.profile_params_index_list, initial_params=fit_output.initial_params.astype(jnp.float32), scale=fit_output.scale, params_dict=fit_output.params_dict, region_list=fit_output.region_list, outer_limits=fit_output.outer_limits, inner_limits=fit_output.inner_limits, model_keywords= fit_output.model_keywords, fitting_routine = fit_output.fitting_routine, constraints = fit_output.constraints.astype(jnp.float32), source=fit_output.source, dependencies=fit_output.dependencies, residuals = fit_output.residuals, free_params = fit_output.free_params, chi2_red = fit_output.chi2_red, fitkwargs = fit_output.fitkwargs, elapsed_time= fit_output.elapsed_time) #del self.fitting_class #free memory self.plotter = SheapPlot(sheap=self)
[docs] def estimate_posteriors(self,sampling_method="single", num_samples: int = 50, key_seed: int = 0,summarize=False ,overwrite=False, num_warmup=500,n_random=1_000,frac_box_sigma=0.02,k_sigma=0.3 ,only_sheaproducts=False,**kwargs): """ Estimate or sample posterior distributions of fit parameters. Parameters ---------- sampling_method : {'single', 'pseudomontecarlo', 'mcmc',"montecarlo"} Sampling algorithm to use. num_samples : int, optional Number of samples to draw. key_seed : int, optional Random seed for reproducibility. summarize : bool, optional If True, compute summary statistics. overwrite : bool, optional If True, rerun even if posterior exists. num_warmup : int, optional Warm-up steps for MCMC. n_random : int, optional Number of initial random positions for MCMC. extra_products : bool, optional Return additional diagnostics. Returns ------- ParameterEstimation or dict Posterior object or results dictionary. Raises ------ RuntimeError If fit has not been run (`self.result` missing). """ avalaible_methods = ["single","none","montecarlo"] #TODO implement the other methods. from sheap.MasterSampler.MasterSampler import MasterSampler import time if not hasattr(self, "result"): raise RuntimeError("self.result should exist to run this.") if sampling_method not in avalaible_methods: raise ValueError( f"Unknown sampling method '{sampling_method}'. " f"Available methods: single, {avalaible_methods}") if self.result.posterior is None: self.result.posterior = {} if sampling_method in self.result.posterior and not overwrite: raise RuntimeError(f"Posterior for method '{sampling_method}' already exists. " "Use overwrite=True to recompute it.") PM = MasterSampler(sheap = self) if sampling_method == "none": print("Nothing will run if you dont choose between sampling_method [single, pseudomontecarlo, montecarlo, mcmc") return PM if sampling_method == "single": self.result.posterior[sampling_method] = {} print("You chose single: parameter estimation using " "only fitting uncertainties.") SP = SheaProducts(samplerclass=PM,method="direct") dic_posterior_params = SP.calculate_sheap_products(summarize=summarize) self.result.posterior[sampling_method] = {"posterior_result": dic_posterior_params,"summarize": summarize,} elif sampling_method == "montecarlo": time_init = time.time() dic_posterior_params = PM.montecarlosampler(num_samples=num_samples,key_seed=key_seed,summarize=summarize,frac_box_sigma=frac_box_sigma,k_sigma=k_sigma) self.result.posterior[sampling_method] = {"posterior_result": dic_posterior_params,"num_samples": num_samples,"key_seed": key_seed,"summarize": summarize,"time_elapsed": time.time() - time_init} else: raise ValueError( f"Unknown sampling method '{sampling_method}'. " "Available methods: single, pseudomontecarlo, montecarlo, mcmc." )
def _recalculate_products(self,sampling_method=None): from sheap.MasterSampler.MasterSampler import MasterSampler import time from tqdm import tqdm if hasattr(self.result, "posterior"): keys = list(self.result.posterior.keys()) if len(keys)==1: print(keys[0], "will be recalculated") sampling_method = keys[0] elif len(keys) >= 2: self.result.posterior[sampling_method] PM = MasterSampler(sheap = self) PM.method = sampling_method if sampling_method == "single": SP = SheaProducts(samplerclass=PM,method="direct") dic_posterior_params = SP.calculate_sheap_products(summarize=False) self.result.posterior[sampling_method] = {"posterior_result": dic_posterior_params} previous_producs = self.result.posterior[sampling_method]["posterior_result"] names = list(previous_producs.keys()) iterator = tqdm(names, total=len(names), desc="Re-Getting posterior-params") for n, name_i in enumerate(iterator):#SP.calculate_sheap_products_sampled(n,samples,extra_products=True) t0 = time.perf_counter() self.result.posterior["montecarlo"]["posterior_result"][name_i].update(SP.calculate_sheap_products_sampled(n,previous_producs[name_i]["samples_phys"])) t1 = time.perf_counter() iterator.set_postfix({"it_s": f"{(t1 - t0):.4f}"})
[docs] @classmethod def from_pickle(cls, filepath: Union[str, Path]) -> Sheapectral: """ Load a saved Sheapectral instance from a pickle file. Parameters ---------- filepath : str or Path Path to the pickle file created by save_to_pickle(). Returns ------- Sheapectral Restored object with loaded spectra and results. """ from sheap.Profiles.Utils import make_fused_profiles filepath = Path(filepath) with open(filepath, "rb") as f: data = pickle.load(f) obj = cls( spectra=data["spectra"], z=data["z"], names=data["names"], coords=data["coords"], extinction_correction=data["extinction_correction"], redshift_correction=data["redshift_correction"],) # small helper to read older versions -> removing it soon region_list = data.get("region_list", []) if "complex_region" in data.keys(): region_list = data.get("complex_region", []) obj.region_list = [SpectralLine(**i) for i in region_list] profile_names = data.get("profile_names", []) obj.result = SheapResult( params=jnp.array(data.get("params")), uncertainty_params=jnp.array(data.get("uncertainty_params", jnp.zeros_like(data.get("params")))), initial_params=jnp.array(data.get("initial_params")), mask=jnp.array(data.get("mask")), profile_functions= profile_functions_from_region_list(obj.region_list) , profile_names=profile_names, loss=None, # Not saved currently, could be added if needed profile_params_index_list=data.get("profile_params_index_list"), scale=data.get("scale"), # Not saved currently, could be added if needed params_dict=data.get("params_dict"), region_list=obj.region_list, outer_limits=data.get("outer_limits"), inner_limits=data.get("inner_limits"), model_keywords=data.get("model_keywords"), dependencies = data.get("dependencies"), source=data.get("source", "pickle"), constraints = data.get('constraints'), fitting_routine = data.get("fitting_routine"), free_params= data.get("free_params"), residuals= data.get("residuals"), posterior = data.get("posterior"), chi2_red = data.get("chi2_red"), fitkwargs = data.get("fitkwargs"), elapsed_time= data.get("elapsed_time") ) obj.plotter = SheapPlot(sheap=obj) obj.spectral_model = make_fused_profiles(obj.result.profile_functions) return obj
def _save(self): """ Internal: assemble a dict of object state for pickling. Returns ------- dict Keys/values for spectra, results, and metadata. """ _region_list = [i.to_dict() for i in self.result.region_list] dic_ = { "names": self.names, "spectra": np.array(self.spectra), "coords": np.array(self.coords),#mmmm "z": np.array(self.z), "extinction_correction": self.extinction_correction, "redshift_correction": self.redshift_correction, "params": np.array(self.result.params), "uncertainty_params": np.array(self.result.uncertainty_params), "initial_params": np.array(self.result.initial_params), # explicitly saved "params_dict": self.result.params_dict, "mask": np.array(self.result.mask), "region_list": _region_list, "profile_params_index_list": self.result.profile_params_index_list, "profile_names": self.result.profile_names, "fitting_routine": self.result.fitting_routine, "outer_limits": self.result.outer_limits, "inner_limits": self.result.inner_limits, "model_keywords": self.result.model_keywords, "source": self.result.source, "scale":np.array(self.result.scale), 'constraints':np.array(self.result.constraints), 'dependencies': self.result.dependencies, 'residuals' : np.array(self.result.residuals), 'free_params' : self.result.free_params, 'chi2_red' : np.array(self.result.chi2_red), "posterior" : self.result.posterior, "fitkwargs":self.result.fitkwargs, "elapsed_time":self.result.elapsed_time } estimated_size = sys.getsizeof(pickle.dumps(dic_)) print(f"Estimated pickle size: {estimated_size / 1024:.2f} KB") return dic_
[docs] def save_to_pickle(self, filepath: Union[str, Path]): """ Save the current object state to a pickle file (.pkl). Parameters ---------- filepath : str or Path Destination path for the pickle. Returns ------- None """ filepath = Path(filepath) with open(filepath, "wb") as f: pickle.dump(self._save(), f)
@property def modelplot(self): """ Get or initialize the SheapPlot plotting interface. TODO modelplot or plotter? Returns ------- SheapPlot Plotting backend for spectra and fit results. Raises ------ RuntimeError If no fit result exists. """ if not hasattr(self, "plotter"): if hasattr(self, "result"): self.plotter = SheapPlot(sheap=self) else: raise RuntimeError("No fit result found. Run `fitmodel()` first.") return self.plotter
[docs] def result_panda(self, n: int, param_filter: str | None = None, regex: bool = False, case: bool = True) -> pd.DataFrame: """ #TODO update this part to be able to show the actual name of the parameter lets say line with 0,1,2,3,4 is easier for code reason but for visualitation could be messy #TODO say if the paramters are at scale or not. Return a pandas DataFrame of fit parameters for a given spectrum. Parameters ---------- n : int Index of the spectrum object. param_filter : str, optional If provided, return only rows whose parameter name matches this pattern. Uses pandas .str.contains, so it can be a substring or a regex (see `regex` and `case`). regex : bool, default False If True, `param_filter` is interpreted as a regular expression. case : bool, default True If False, ignore case when matching `param_filter`. Returns ------- pandas.DataFrame Index: parameter name. Columns: ['value', 'error', 'max_constraint', 'init_value', 'min_constraint']. """ import pandas as pd data = [] scale = self.result.scale[n] for param_index,(param_name, i) in enumerate(self.result.params_dict.items()): param = float(self.result.params[n][i]) init_value = float(self.result.initial_params[i]) uncertainty = float(self.result.uncertainty_params[n][i]) if "amplitude" in param_name: param /= scale uncertainty /= scale elif "logamp" in param_name: param -= np.log10(scale) # uncertainty is usually left as-is in log-space constraints = self.result.constraints[i] data.append({"param_index":param_index,"param_name": param_name, "value": param, "error": uncertainty, "max_constraint": constraints[1], "init_value": init_value,"min_constraint": constraints[0],}) df = pd.DataFrame(data) # Optional filtering by parameter name (now using the column) if param_filter is not None: mask = df["param_name"].str.contains(param_filter, case=case, regex=regex, na=False) df = df[mask] return df
[docs] def quicklook(self, idx: int, ax=None, xlim=None, ylim=None,add_text=True,figsize= (15, 5)): """ Produce a quick errorbar plot of flux vs. wavelength. Parameters ---------- idx : int Spectrum index to plot. ax : matplotlib.axes.Axes, optional Axes to plot into (creates new if None). xlim : tuple, optional X-axis limits as (xmin, xmax). ylim : tuple, optional Y-axis limits as (ymin, ymax). Returns ------- matplotlib.axes.Axes The axes containing the plot. """ import matplotlib.pyplot as plt from matplotlib.ticker import FixedLocator lam, flux, err = self.spectra[idx] if ax is None: fig, ax = plt.subplots(figsize=figsize) ax.errorbar(lam, flux, yerr=err, ecolor='dimgray', color="black", zorder=1) # Default xlim and ylim if not provided if xlim is None: xlim = (jnp.nanmin(lam), jnp.nanmax(lam)) if ylim is None: ylim = (0, jnp.nanmax(flux) * 1.02) ax.set_xlim(*xlim) ax.set_ylim(*ylim) ax.set_xlabel("Rest wavelength [Å]", fontsize=25) flux_unit=r"$\mathrm{erg\,s^{-1}\,cm^{-2}\,\AA^{-1}}$" ax.set_ylabel(f"Flux [{flux_unit}]", fontsize=25) ax.tick_params(axis='both', labelsize=25) ax.yaxis.offsetText.set_fontsize(25) # Plot ID label outside main plot area, above-left if add_text: ax.text( 0.7, 1.0, f"ID {self.names[idx]} ({idx}) \nz = {self.z[idx]} ", fontsize=20, transform=ax.transAxes, ha='left', va='bottom', ) ax.yaxis.set_major_locator(FixedLocator(ax.get_yticks())) return ax
#extra plots @property def plot_redshift_signal2noise_distribution(self): import matplotlib.pyplot as plt fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True) axes[0].hist(self.z, bins=20, color="#214994", edgecolor='black', alpha=0.8) axes[0].set_xlabel("Redshift (z)", fontsize=16) axes[0].set_ylabel("Number of objects", fontsize=16) #axes[0].set_title("Redshift Distribution", fontsize=18) axes[0].tick_params(axis='both', which='major', labelsize=14) snr = np.nanmean(self.spectra[:, 1, :] / self.spectra[:, 2, :], axis=1) axes[1].hist(snr, bins=20, color="#f5b041", edgecolor='black', alpha=0.8) axes[1].set_xlabel("Mean Signal-to-Noise Ratio", fontsize=16) #axes[1].set_title("S/N Distribution", fontsize=18) axes[1].tick_params(axis='both', which='major', labelsize=14) plt.tight_layout() return axes @property def plot_chi2(self): import matplotlib.pyplot as plt if not hasattr(self, "result"): raise RuntimeError("self.result should exist to run this.") chi2_model = self.result.chi2_red # or whatever your model object is chi2_model = np.asarray(chi2_model) nan_mask = ~np.isfinite(chi2_model) nan_idx = np.where(nan_mask)[0] chi2_model = chi2_model[~nan_mask] print(f"NaN / non-finite entries at indices: {nan_idx.tolist()}") print(f"Number of NaNs / non-finite values: {nan_idx.size}") #mask_range_model = (chi2_model > 0.) & (chi2_model < 5.) #frac_model_0_5 = np.nanmean(mask_range_model) * 100.0 # --- Define bins from model only --- chi2_min = chi2_model.min() chi2_max = chi2_model.max() bins = np.linspace(chi2_min, chi2_max, 40) # --- Plot --- fig, ax = plt.subplots(figsize=(10, 6)) ax.hist( chi2_model, bins=bins, alpha=0.7, color="#d62728", # label=fr"Model chi2 (median = {np.nanmedian(chi2_model):.2f})" ) ax.set_xlabel(r"Reduced $\chi^2$", fontsize=18) ax.set_ylabel("Number of spectra", fontsize=18) ax.tick_params(axis="both", labelsize=14) ax.axvline(np.nanmedian(chi2_model), linestyle="--",label=fr"Reduced $\chi^2$ median = {np.nanmedian(chi2_model):.2f}",c="b") ax.axvline(1.0, linestyle="--",label=fr"Reduced $\chi^2$ = 1",c="k") ax.set_xlim(0.1,ax.get_xlim()[-1]) #textstr = ( # fr"$0.<\chi^2_{{\rm red}}<5$ fraction: " # fr"Model: {frac_model_0_5:.1f}%" #) # ax.text( # 0.3, 0.5, # textstr, # transform=ax.transAxes, # fontsize=20, # va="top", # bbox=dict(facecolor="white", alpha=0.7, edgecolor="none") # ) ax.legend(fontsize=14, frameon=False) fig.tight_layout() plt.show() #plt.tight_layout() #plt.show()
[docs] def plot_param_distribution(self,param_name,no_log=False): #TODO what about this plot but for posterior-> params_keys = self.result.params_dict.keys() if param_name not in params_keys: raise KeyError(f"param_name '{param_name}' is not available. " f"Available param names: {list(params_keys)}") import pandas as pd import matplotlib.pyplot as plt #data = [] scale = self.result.scale param_index = self.result.params_dict[param_name] param_values = self.result.params[:,param_index] init_value = float(self.result.initial_params[param_index]) if "amplitude" in param_name: param_values /= scale elif "logamp" in param_name: param_values -= np.log10(scale) constraints = self.result.constraints[param_index] if no_log: constraints = 10**constraints #constraints[1] = 10**constraints[1] init_value = 10**init_value param_values = 10**param_values param_name = param_name.replace("log","") bins = np.linspace(np.min(param_values),np.max(param_values), 40) fig, ax = plt.subplots(figsize=(10, 6)) ax.hist( param_values, bins=bins, alpha=0.7, color="#2781d65c",) ax.set_xlabel(param_name, fontsize=18) ax.set_ylabel("Number of spectra", fontsize=18) ax.set_title(f"param number {param_index}") ax.tick_params(axis="both", labelsize=14) ax.axvline(init_value, linestyle="--",label="init value",c="r") ax.axvline(constraints[0], linestyle="--",label="min value",c="k") ax.axvline(constraints[1], linestyle="--",label="max value",c="k") ax.legend(fontsize=14, frameon=False) fig.tight_layout() plt.show() return param_values
@property def modelplot(self): """ Get or initialize the SheapPlot plotting interface. TODO modelplot or plotter? Returns ------- SheapPlot Plotting backend for spectra and fit results. Raises ------ RuntimeError If no fit result exists. """ if not hasattr(self, "plotter"): if hasattr(self, "result"): self.plotter = SheapPlot(sheap=self) else: raise RuntimeError("No fit result found. Run `fitmodel()` first.") return self.plotter # def result_panda(self, n:number {param_index}") # ax.tick_params(axis="both", labelsize=14) # ax.axvline(init_value, linestyle="--",label="init value",c="r") # ax.axvline(constraints[0], linestyle="--",label="min value",c="k") # ax.axvline(constraints[1], linestyle="--",label="max value",c="k") # ax.legend(fontsize=14, frameon=False) # fig.tight_layout() # plt.show() # return param_values def _calcualte_d(self,cosmo=None,H0=70,Om0=0.3): from astropy.cosmology import FlatLambdaCDM from sheap.Utils.Constants import cm_per_mpc if cosmo is None: self.cosmo = FlatLambdaCDM(H0=H0, Om0=Om0) else: self.cosmo = cosmo #depending on the version this could change after 7.0.0 this change #self.d = self.cosmo.luminosity_distance(self.z).value * cm_per_mpc d = self.cosmo.luminosity_distance(self.z).value * cm_per_mpc return d