# Source code for sheap.MasterSampler.MasterSampler

"""
MasterSampler
===============

Post–fitting interface for extracting physical parameters and running
posterior sampling on spectral models.

This module defines :class:`MasterSampler`, which acts as a high-level
wrapper around results from :class:`Sheapectral <sheap.Sheapectral.Sheapectral>` or :class:`SheapResult <sheap.Core.SheapResult>`.
It provides multiple strategies to handle parameters after fitting:

- **Single best-fit mode**  
  Extract physical quantities (flux, FWHM, EQW, luminosity, etc.)
  directly from the optimized parameters with propagated uncertainties.

- **Monte Carlo mode**  
  Generate pseudo-random realizations of parameters around the covariance
  matrix for uncertainty propagation.

- **Pseudo Monte Carlo mode**  
  Fast approximate sampler for posterior parameter exploration.

- **MCMC mode (NumPyro)**  
  Full Bayesian sampling of the posterior distribution using Hamiltonian
  Monte Carlo / NUTS.

Key Features
------------
- Organizes parameter arrays, constraints, and dependencies after a fit.
- Provides consistent access to spectra, masks, scaling factors, and model functions.
- Computes luminosity distances given redshifts and a cosmology.
- Interfaces with downstream tools (:class:`AfterFitParams`) to compute
  line and continuum physical properties.
- Exposes convenience methods:

  * :meth:`MasterSampler.sample_single`
  * :meth:`MasterSampler.montecarlosampler`
  * :meth:`MasterSampler.sample_pseudomontecarlosampler`
  * :meth:`MasterSampler.sample_mcmc`

Notes
-----
- By default, cosmology is set to ``FlatLambdaCDM(H0=70, Om0=0.3)``.
- Bolometric corrections and single-epoch estimators are loaded from
  :mod:`sheap.Utils.Constants`.
- This module will eventually centralize routines now duplicated in
  samplers and parameter estimation helpers.
"""

# Explicit public API of this module, followed by authorship metadata.
__all__ = ["MasterSampler"]

__author__ = "felavila"

from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from pathlib import Path

import jax.numpy as jnp
import numpy as np
import pandas as pd
from astropy.cosmology import FlatLambdaCDM

from jax import grad, jit,vmap


from sheap.Sheapectral.Sheapectral import Sheapectral
from sheap.Core import SheapResult

from sheap.Profiles.Utils import make_fused_profiles


from sheap.Utils.Constants import DEFAULT_BOL_CORRECTIONS, DEFAULT_SINGLE_EPOCH_ESTIMATORS,DEFAULT_C_KMS,cm_per_mpc


# TODO: the meaning of `flat_param_indices_global` is very hard to infer; rename or document it.
# TODO: move the subroutines used by montecarlosampler/mcmcsampler and ParametersSingle into their respective modules, because in general they require different helper functions.

class MasterSampler:
    """
    Compute best-fit physical parameters and uncertainties for spectral regions.

    Wraps the outcome of a fit (either a ``Sheapectral`` instance or a
    ``SheapResult`` plus its spectra) and dispatches to the single-shot,
    Monte Carlo, pseudo Monte Carlo and MCMC parameter extractors.

    Parameters
    ----------
    sheap : Sheapectral, optional
        Configured Sheapectral instance with fit results.
    sheapresult : SheapResult, optional
        SheapResult object containing parameters, uncertainties, and metadata.
        Only used when ``sheap`` is not given; requires ``spectra``.
    spectra : jnp.ndarray, optional
        Normalized spectra array, shape (n_objects, 3, n_pixels).
    z : jnp.ndarray, optional
        Redshift array for each object. When no redshift is available,
        zeros are assumed (one per spectrum).
    cosmo : astropy.cosmology instance, optional
        Cosmology for distance calculations; defaults to
        ``FlatLambdaCDM(H0=70, Om0=0.3)``.
    BOL_CORRECTIONS : dict, optional
        Bolometric correction factors; a missing or empty value falls back to
        ``DEFAULT_BOL_CORRECTIONS``.
    SINGLE_EPOCH_ESTIMATORS : dict, optional
        Single-epoch estimators; a missing or empty value falls back to
        ``DEFAULT_SINGLE_EPOCH_ESTIMATORS``.
    C_KMS : float, optional
        Speed of light constant; defaults to ``DEFAULT_C_KMS``.

    Attributes
    ----------
    spectra : jnp.ndarray
        Spectra array used for estimation.
    z : array-like
        Redshifts for each spectrum.
    params : jnp.ndarray
        Best-fit parameter values.
    uncertainty_params : jnp.ndarray
        Uncertainty estimates for parameters.
    cosmo : astropy.cosmology instance
        Cosmology object for computing distances.
    d : array-like
        ``cosmo.luminosity_distance(z).value * cm_per_mpc`` (luminosity
        distances converted with the ``cm_per_mpc`` constant).
    BOL_CORRECTIONS : dict
        Bolometric correction lookup.
    SINGLE_EPOCH_ESTIMATORS : dict
        Single-epoch parameter estimators.
    C_KMS : float
        Speed of light.
    method : str
        Name of the last extraction strategy requested; only set once one of
        the sampling methods has been called.

    Methods
    -------
    sample_single(summarize=True)
        Compute parameter estimates without posterior sampling.
    montecarlosampler(num_samples=2000, key_seed=0, summarize=True, ...)
        Run Monte Carlo sampling of physical parameters.
    sample_pseudomontecarlosampler(num_samples=2000, key_seed=0, summarize=True)
        Run the pseudo Monte Carlo sampler.
    sample_mcmc(n_random=0, num_warmup=500, num_samples=1000, summarize=True)
        Run MCMC sampling via NumPyro.
    """

    def __init__(
        self,
        sheap: Optional["Sheapectral"] = None,
        sheapresult: Optional["SheapResult"] = None,
        spectra: Optional[jnp.ndarray] = None,
        z: Optional[jnp.ndarray] = None,
        cosmo=None,
        BOL_CORRECTIONS=None,
        SINGLE_EPOCH_ESTIMATORS=None,
        C_KMS=DEFAULT_C_KMS,
    ):
        """
        Initialize the sampling context from a fit.

        Parameters
        ----------
        sheap : Sheapectral, optional
            Use attributes from this spectral fitting interface.
        sheapresult : SheapResult, optional
            Use direct fit results if `sheap` is not provided.
        spectra : jnp.ndarray, optional
            Spectra corresponding to `sheapresult`.
        z : jnp.ndarray, optional
            Redshifts for each spectrum.
        cosmo : astropy.cosmology instance, optional
            Cosmology for computing luminosity distances.
        BOL_CORRECTIONS : dict, optional
            Bolometric corrections mapping.
        SINGLE_EPOCH_ESTIMATORS : dict, optional
            Single-epoch estimators mapping.
        C_KMS : float, optional
            Speed of light constant.

        Raises
        ------
        ValueError
            If neither `sheap` nor the pair (`sheapresult`, `spectra`) is given.
        """
        if sheap is not None:
            self._from_sheap(sheap)
        elif sheapresult is not None and spectra is not None:
            self._from_sheapresult(sheapresult, spectra, z)
        else:
            raise ValueError("Provide either `sheap` or (`sheapresult` + `spectra`).")

        # Always bind these attributes: a caller-supplied mapping is kept,
        # and a missing/empty one falls back to the module defaults.
        self.BOL_CORRECTIONS = BOL_CORRECTIONS or DEFAULT_BOL_CORRECTIONS
        self.SINGLE_EPOCH_ESTIMATORS = (
            SINGLE_EPOCH_ESTIMATORS or DEFAULT_SINGLE_EPOCH_ESTIMATORS
        )
        self.C_KMS = C_KMS

        if self.z is None:
            print("None informed redshift, assuming zero.")
            self.z = np.zeros(self.spectra.shape[0])

        self.cosmo = FlatLambdaCDM(H0=70, Om0=0.3) if cosmo is None else cosmo
        self.d = self.cosmo.luminosity_distance(self.z).value * cm_per_mpc

    def sample_pseudomontecarlosampler(
        self, num_samples: int = 2000, key_seed: int = 0, summarize=True
    ):
        """
        Run the pseudo Monte Carlo parameter sampler.

        Parameters
        ----------
        num_samples : int, optional
            Number of samples to draw.
        key_seed : int, optional
            Seed for random number generator.
        summarize : bool, optional
            If True, summarize posterior distributions.

        Returns
        -------
        object
            Whatever ``PseudoMonteCarloSampler.sample_params`` returns for the
            given arguments.
        """
        # Imported lazily so the sampler backend is only loaded when used.
        from sheap.MasterSampler.Samplers.PseudoMonteCarloSampler import PseudoMonteCarloSampler

        self.method = "pseudomontecarlos"
        sampler = PseudoMonteCarloSampler(self)
        if summarize:
            print("The samples will be summarize is you want to keep the samples summarize=False")
        return sampler.sample_params(
            num_samples=num_samples,
            key_seed=key_seed,
            summarize=summarize,
        )

    def montecarlosampler(
        self,
        num_samples: int = 2000,
        key_seed: int = 0,
        summarize=True,
        return_only_draws=False,
        frac_box_sigma=0.02,
        k_sigma=0.3,
    ):
        """
        Run the Monte Carlo parameter sampler.

        Parameters
        ----------
        num_samples : int, optional
            Number of samples to draw.
        key_seed : int, optional
            Seed for random number generator.
        summarize : bool, optional
            If True, summarize posterior distributions.
        return_only_draws : bool, optional
            Forwarded unchanged to ``MonteCarloSampler.sample_params``.
        frac_box_sigma : float, optional
            Forwarded unchanged to ``MonteCarloSampler.sample_params``.
        k_sigma : float, optional
            Forwarded unchanged to ``MonteCarloSampler.sample_params``.

        Returns
        -------
        object
            Whatever ``MonteCarloSampler.sample_params`` returns for the given
            arguments.
        """
        # Imported lazily so the sampler backend is only loaded when used.
        from sheap.MasterSampler.Samplers.MonteCarloSampler import MonteCarloSampler

        self.method = "montecarlo"
        sampler = MonteCarloSampler(self)
        if summarize:
            print("The samples will be summarize is you want to keep the samples summarize=False")
        return sampler.sample_params(
            num_samples=num_samples,
            key_seed=key_seed,
            summarize=summarize,
            return_only_draws=return_only_draws,
            frac_box_sigma=frac_box_sigma,
            k_sigma=k_sigma,
        )

    def sample_mcmc(self, n_random=0, num_warmup=500, num_samples=1000, summarize=True):
        """
        Run MCMC sampling using NumPyro.

        Parameters
        ----------
        n_random : int, optional
            Number of random initial chains.
        num_warmup : int, optional
            Number of warmup steps.
        num_samples : int, optional
            Number of MCMC samples.
        summarize : bool, optional
            If True, summarize the chains.

        Returns
        -------
        object
            Whatever ``McMcSampler.sample_params`` returns for the given
            arguments.
        """
        # Imported lazily so the sampler backend is only loaded when used.
        from sheap.MasterSampler.Samplers.McMcSampler import McMcSampler

        self.method = "mcmc"
        sampler = McMcSampler(self)
        return sampler.sample_params(
            n_random=n_random,
            num_warmup=num_warmup,
            num_samples=num_samples,
            summarize=summarize,
        )

    def sample_single(self, summarize=True):
        """
        Compute parameter estimates and uncertainties without sampling.

        Parameters
        ----------
        summarize : bool, optional
            Forwarded to ``SheaProducts.extract_params``.

        Returns
        -------
        object
            Whatever ``SheaProducts.extract_params`` returns.
        """
        from sheap.SheaProducts.SheaProducts import SheaProducts

        self.method = "single"
        # Use a distinct local name so the imported class is not shadowed
        # by its own instance.
        products = SheaProducts(self)
        return products.extract_params(summarize=summarize)

    def _bind_result(self, result):
        """Copy the fit-result fields shared by both construction paths onto ``self``."""
        self.result = result
        self.constraints = result.constraints
        self.params = result.params
        self.uncertainty_params = result.uncertainty_params
        self.profile_params_index_list = result.profile_params_index_list
        self.profile_functions = result.profile_functions
        self.profile_names = result.profile_names
        self.region_list = result.region_list
        self.xlim = result.outer_limits
        self.mask = result.mask
        self.model_keywords = result.model_keywords or {}
        # Single jitted callable built from all profile functions.
        self.model = jit(make_fused_profiles(self.profile_functions))
        self.fitkwargs = result.fitkwargs
        self.initial_params = result.initial_params

    def _from_sheap(self, sheap):
        """
        Initialize internal state from a Sheapectral instance.

        Parameters
        ----------
        sheap : Sheapectral
            Source of fit results and spectra.
        """
        self.spectra = sheap.spectra
        self.z = sheap.z
        self.names = sheap.names

        fit = sheap.result
        self._bind_result(fit)
        self.scale = fit.scale
        self.dependencies = fit.dependencies
        self.sheapmodel = fit.sheapmodel

    def _from_sheapresult(self, result, spectra, z):
        """
        Initialize internal state from SheapResult and spectra.

        Parameters
        ----------
        result : SheapResult
            SheapResult containing parameters and metadata.
        spectra : jnp.ndarray
            Spectra array corresponding to `result`.
        z : jnp.ndarray
            Redshifts for each spectrum.
        """
        self.spectra = spectra
        self.z = z

        self._bind_result(result)
        # No object names are available on this path: label by row index.
        self.names = [str(i) for i in range(self.params.shape[0])]

        # Keep the attribute set consistent with `_from_sheap`. Tolerant
        # lookups are used so a result object lacking one of these fields
        # does not break construction on this path.
        self.scale = getattr(result, "scale", None)
        self.dependencies = getattr(result, "dependencies", None)
        self.sheapmodel = getattr(result, "sheapmodel", None)
        # NOTE(review): confirm that `params_dict` is still expected by
        # downstream consumers; it is not set by `_from_sheap`.
        self.params_dict = getattr(result, "params_dict", None)