# Source code for sheap.SheaProducts.SheaProducts

"""
SheaProducts Handling
============================

Routines to post-process fitted or sampled parameter sets and compute
derived physical quantities.

This module provides the :class:`SheaProducts` class, which acts as a
bridge between raw fitting/sampling outputs (parameter vectors) and
scientifically useful quantities such as line fluxes, widths, equivalent
widths, luminosities, and single-epoch black hole mass estimators.

Main Features
-------------
- Unified interface to handle both:
  * **single best-fit parameters** (deterministic optimization), and
  * **sampled parameters** (Monte Carlo / MCMC posterior draws).
- Automatic grouping of parameters by spectral region and profile.
- Computation of:
  * line flux, FWHM, velocity width (km/s),
  * line centers, amplitudes, shape parameters,
  * equivalent width (EQW),
  * monochromatic and bolometric luminosities,
  * combined quantities (e.g. Hα+Hβ, Mg II+Fe, CIV blends).
- Uncertainty propagation via :mod:`uncertainties`.

Public API
----------
- :class:`SheaProducts`:
    High-level handler that connects a :class:`ComplexSampler` result
    to physical parameter extraction.

Typical Workflow
----------------
1. Fit or sample spectra with :class:`RegionFitting` or a sampler.
2. Wrap the result in a :class:`ComplexSampler` instance.
3. Construct ``SheaProducts(samplerclass=...)`` from it (keyword-only).
4. Call :meth:`SheaProducts.extract_params` to obtain dictionaries
    of physical line quantities, optionally summarized across samples.

Notes
-----
- The attribute ``method`` determines whether results are handled as
    ``"single"`` (best fit) or ``"sampled"`` (posterior draws).
- Many helpers internally rely on
    :func:`make_batch_fwhm_split[_with_error]`,
    :func:`make_integrator`, and profile-specific shape functions.
"""

__author__ = 'felavila'

__all__ = ["SheaProducts",]
from typing import Any, Callable, Dict, List, Optional, Tuple, Union,Iterable
from collections import defaultdict

import numpy as np 
import jax.numpy as jnp 
from jax import vmap,jit

from sheap.Profiles.Utils import make_fused_profiles
from sheap.SheaProducts.Utils.MasterCombineProfile import MasterCombineProfile
from sheap.SheaProducts.Utils.ExtractBasicParamsSampled import extract_basic_params_sampled
from sheap.SheaProducts.Utils.ExtractBasicParams import extract_basic_params_single
from sheap.Utils.Constants import DEFAULT_BOL_CORRECTIONS, DEFAULT_SINGLE_EPOCH_ESTIMATORS,DEFAULT_C_KMS

from sheap.SheaProducts.Utils.ExtractExtraParams import extra_params_functions


# TODO: is the ``method`` attribute still useful?
# TODO: Fe ratio and stellar-continuum ratio, using model-spectra reconstruction.

class SheaProducts:
    """Turn fitted or sampled parameter sets into derived line products.

    The instance is configured either from a source object
    (``samplerclass``) exposing the attributes listed in
    ``_BASE_REQUIRED``, from explicit keyword arguments, or from a mix of
    both.  Values taken from ``samplerclass`` win; keyword arguments only
    fill fields that are still ``None`` afterwards.

    Two entry points are provided:

    * :meth:`calculate_sheap_products` works on ``sheapmodel.params`` and
      ``sheapmodel.uncertainty_params``.
    * :meth:`calculate_sheap_products_sampled` works on an array of
      parameter samples for one object selected by index.
    """

    # Attributes that must be non-None once construction has finished.
    _BASE_REQUIRED = ("model", "dependencies", "spectra", "mask", "sheapmodel", "method", "d")

    # Region names whose profiles/parameters are fused into the "full
    # continuum" model.
    _CONT_LIKE_REGIONS = ("fe", "continuum", "host", "balmer")

    def __init__(
        self,
        *,
        samplerclass: Optional[object] = None,
        model=None,
        dependencies=None,
        spectra=None,
        mask=None,
        sheapmodel=None,
        method=None,
        d=None,
        BOL_CORRECTIONS=None,
        SINGLE_EPOCH_ESTIMATORS=None,
        C_KMS=None,
        **extra,
    ):
        """Build the handler and pre-compile the continuum profile closures.

        Parameters
        ----------
        samplerclass:
            Optional object to copy the ``_BASE_REQUIRED`` attributes (and,
            when present, ``BOL_CORRECTIONS``, ``SINGLE_EPOCH_ESTIMATORS``
            and ``C_KMS``) from.
        model, dependencies, spectra, mask, sheapmodel, method, d:
            Explicit values for the required fields.  Each is used only if
            the field is still ``None`` after reading ``samplerclass``.
        BOL_CORRECTIONS, SINGLE_EPOCH_ESTIMATORS, C_KMS:
            Overrides for the module defaults.  Note that a value found on
            ``samplerclass`` replaces these.
        **extra:
            Additional attributes to attach to the instance, under the same
            "only if not already set" rule.

        Raises
        ------
        ValueError
            If any field in ``_BASE_REQUIRED`` is still ``None``.
        """
        self.BOL_CORRECTIONS = DEFAULT_BOL_CORRECTIONS if BOL_CORRECTIONS is None else BOL_CORRECTIONS
        self.SINGLE_EPOCH_ESTIMATORS = (
            DEFAULT_SINGLE_EPOCH_ESTIMATORS if SINGLE_EPOCH_ESTIMATORS is None else SINGLE_EPOCH_ESTIMATORS
        )
        self.C_KMS = DEFAULT_C_KMS if C_KMS is None else C_KMS

        if samplerclass is not None:
            self._from_any(samplerclass)

        manual = dict(
            model=model,
            dependencies=dependencies,
            spectra=spectra,
            mask=mask,
            sheapmodel=sheapmodel,
            method=method,
            d=d,
        )
        manual.update(extra)
        for name, value in manual.items():
            if value is None:
                continue
            # Never clobber something already provided by ``samplerclass``.
            if getattr(self, name, None) is None:
                setattr(self, name, value)

        self._require(self._BASE_REQUIRED)

        self.wavelength_grid = jnp.linspace(0, 20_000, 20_000)
        self.LINES_TO_COMBINE = ["Halpha", "Hbeta", "MgII", "CIV"]
        self.limit_velocity = 150.0

        self.by_region = self.sheapmodel.group_by("region")

        # Same fused continuum model, vmapped two ways: over both arguments,
        # and over the second argument only (first argument shared).
        self.full_cont_profile = jit(vmap(self._build_fused_cont_profile(), in_axes=(0, 0)))
        self.full_cont_profile_NONE = jit(vmap(self._build_fused_cont_profile(), in_axes=(None, 0)))
        self.full_cont_idx = np.concatenate(
            [group.flat_param_indices_global for group in self._cont_like_groups()]
        )
        self.cont_profile = jit(vmap(self.by_region["continuum"].combined_profile, in_axes=(None, 0)))

        # A missing "broad"/"narrow" region counts as zero lines.
        n_broad = len(getattr(self.by_region.get("broad"), "lines", []))
        n_narrow = len(getattr(self.by_region.get("narrow"), "lines", []))

        self.MC = MasterCombineProfile(
            LINES_TO_COMBINE=self.LINES_TO_COMBINE,
            limit_velocity=self.limit_velocity,
            C_KMS=self.C_KMS,
            full_cont_profile=self.full_cont_profile,
            ucont_params=None,
            full_cont_profile_NONE=self.full_cont_profile_NONE,
            n_broad=n_broad,
            n_narrow=n_narrow,
        )
        # TODO: a model-spectra-reconstruction helper should be attached at
        # the sampler level so everything can be produced in a single run.

    def _cont_like_groups(self) -> list:
        """Return the region groups listed in ``_CONT_LIKE_REGIONS``.

        Groups are returned in the iteration order of ``self.by_region``.
        """
        return [
            self.by_region[key]
            for key in self.by_region.keys()
            if key in self._CONT_LIKE_REGIONS
        ]

    def _build_fused_cont_profile(self):
        """Fuse the profile functions of every continuum-like region.

        Returns the plain (un-vmapped, un-jitted) callable produced by
        ``make_fused_profiles``.
        """
        return make_fused_profiles(
            np.concatenate([group.profile_functions for group in self._cont_like_groups()])
        )

    def calculate_sheap_products_sampled(self, idx, samples, combine=True, extra_products=True, **kwargs):
        """Compute products for one object from an array of parameter samples.

        Parameters
        ----------
        idx:
            Index of the object; selects ``self.spectra[idx, 0, :]`` (passed
            on as the wavelength array), ``self.mask[idx, :]`` and
            ``self.d[idx]`` (passed on as the luminosity distance).
        samples:
            Parameter samples; indexed as ``samples[:, full_cont_idx]``, so
            parameters are expected along the second axis.
        combine:
            When true, also run ``MasterCombineProfile.combine_both`` and
            merge its output into the result.
        extra_products:
            When true, add ``extra_params`` entries via
            :meth:`_get_extraparams`.
        **kwargs:
            Accepted for call-site compatibility; currently unused.

        Returns
        -------
        dict
            Product dictionary.  On a key clash the basic-extraction values
            win over the combined ones.
        """
        wavelength = self.spectra[idx, 0, :]
        mask = self.mask[idx, :]
        luminosity_distance = self.d[idx]

        products = extract_basic_params_sampled(
            sheapmodel=self.sheapmodel,
            wavelength=wavelength,
            mask=mask,
            samples=samples,
            continuum_idx_all=self.full_cont_idx,
            cont_profile_all=self.full_cont_profile,
            cont_profile=self.cont_profile,
            luminosity_distance=luminosity_distance,
            BOL_CORRECTIONS=self.BOL_CORRECTIONS,
            C_KMS=self.C_KMS,
            wavelength_grid=self.wavelength_grid,
        )

        if combine:
            full_cont_params = samples[:, self.full_cont_idx]
            all_params = self.MC.combine_both(
                products["basic_params"], luminosity_distance, full_cont_params
            )
            products = {**all_params, **products}

        if extra_products:
            products = self._get_extraparams(products)
        return products

    def calculate_sheap_products(self, combine=True, extra_products=True, **kwargs):
        """Compute products from the stored best-fit parameters.

        Uses ``sheapmodel.params`` and ``sheapmodel.uncertainty_params``
        (both cast to ``float32``) together with ``self.spectra``,
        ``self.mask`` and ``self.d``.

        Parameters
        ----------
        combine:
            When true, also run ``MasterCombineProfile.combine_both`` and
            merge its output into the result.
        extra_products:
            When true, add ``extra_params`` entries via
            :meth:`_get_extraparams`.
        **kwargs:
            Accepted for call-site compatibility; currently unused.

        Returns
        -------
        dict
            Product dictionary.  On a key clash the basic-extraction values
            win over the combined ones.
        """
        params = jnp.asarray(self.sheapmodel.params, dtype=jnp.float32)
        uncertainty_params = jnp.asarray(self.sheapmodel.uncertainty_params, dtype=jnp.float32)
        spectra = jnp.asarray(self.spectra, dtype=jnp.float32)

        # This path uses the plain fused profile rather than the
        # vmapped/jitted closures built in ``__init__``.
        full_cont_profile = self._build_fused_cont_profile()

        products = extract_basic_params_single(
            spectra,
            self.mask,
            params,
            uncertainty_params,
            continuum_idx_all=self.full_cont_idx,
            luminosity_distance=self.d,
            sheapmodel=self.sheapmodel,
            cont_profile_all=full_cont_profile,
            BOL_CORRECTIONS=self.BOL_CORRECTIONS,
            C_KMS=self.C_KMS,
            wavelength_grid=self.wavelength_grid,
        )

        if combine:
            full_cont_params = params[:, self.full_cont_idx]
            ucont_params = uncertainty_params[:, self.full_cont_idx]

            # The combiner is shared with the sampled path.  Override its
            # continuum settings only for the duration of this call and put
            # them back afterwards, so a later call to
            # ``calculate_sheap_products_sampled`` does not silently run with
            # single-fit state.
            previous_ucont = getattr(self.MC, "ucont_params", None)
            previous_profile = getattr(self.MC, "full_cont_profile", self.full_cont_profile)
            self.MC.ucont_params = ucont_params
            self.MC.full_cont_profile = full_cont_profile
            try:
                all_params = self.MC.combine_both(products["basic_params"], self.d, full_cont_params)
            finally:
                self.MC.ucont_params = previous_ucont
                self.MC.full_cont_profile = previous_profile
            products = {**all_params, **products}

        if extra_products:
            products = self._get_extraparams(products)
        return products

    def _get_extraparams(self, products):
        """Add an ``extra`` entry for every ``basic_params``-like entry.

        For each key containing ``"basic_params"`` the value is passed to
        ``extra_params_functions`` and stored under the same key with
        ``"basic"`` replaced by ``"extra"``.  ``products`` is updated in
        place and also returned.

        Raises
        ------
        KeyError
            If ``products`` has no ``"L_w"`` or ``"L_bol"`` entry.
        """
        L_w, L_bol = products["L_w"], products["L_bol"]
        new_products = {
            key.replace("basic", "extra"): extra_params_functions(
                local_result, L_w, L_bol, self.SINGLE_EPOCH_ESTIMATORS, self.C_KMS
            )
            for key, local_result in products.items()
            if "basic_params" in key
        }
        products.update(new_products)
        return products

    def _from_any(self, src: object) -> None:
        """Copy configuration from ``src`` onto this instance.

        Every name in ``_BASE_REQUIRED`` is set (to ``None`` when ``src``
        lacks it).  The three constant tables are overwritten only if
        ``src`` actually defines them.
        """
        for name in self._BASE_REQUIRED:
            setattr(self, name, getattr(src, name, None))
        for name in ("BOL_CORRECTIONS", "SINGLE_EPOCH_ESTIMATORS", "C_KMS"):
            if hasattr(src, name):
                setattr(self, name, getattr(src, name))

    def _require(self, names: Iterable[str]) -> None:
        """Raise ``ValueError`` if any attribute in ``names`` is missing or ``None``."""
        missing = [n for n in names if getattr(self, n, None) is None]
        if missing:
            raise ValueError(f"SheaProducts is missing required fields: {missing}")