Source code for sheap.MasterSampler.Samplers.MonteCarloSampler

"""
Monte Carlo Sampler
===================

This module implements the :class:`MonteCarloSampler`, a simple
posterior approximation for spectral fits based on randomized parameter
initialization and local re-optimization.

Main Features
-------------
- Generates random draws of parameter vectors within their constraints.
- Converts parameters to raw space and re-optimizes them with
  :class:`Minimizer`.
- Handles tied/fixed parameters through :func:`build_Parameters` and
  dependency flattening utilities.
- Reconstructs physical parameters from optimized raw vectors.
- Computes physical quantities (fluxes, FWHM, luminosities, etc.)
  for each draw using :class:`SheaProducts`.

Public API
----------
- :class:`MonteCarloSampler`
    * :meth:`MonteCarloSampler.sample_params` —
      run the Monte Carlo sampler and return posterior dictionaries.
    * :meth:`MonteCarloSampler.make_minimizer` —
      construct a :class:`Minimizer` configured with penalties/weights.
    * :meth:`MonteCarloSampler._build_tied` —
      convert tied-parameter specifications into dependency strings.

Notes
-----
- This method approximates the posterior distribution by repeatedly
  optimizing from random starts (sometimes called a “poor man’s MCMC”).
- Actual uncertainty propagation is performed by analyzing the
  distribution of optimized solutions.
- Dependencies are flattened so that all tied parameters ultimately
  reference free parameters only.
"""

__author__ = 'felavila'

__all__ = ["MonteCarloSampler",]

from typing import Tuple, Dict, List

import jax.numpy as jnp
from jax import jit , random
import jax.numpy as jnp

import numpy as np 
import time


from sheap.Assistants.parser_mapper import descale_amp,scale_amp,make_get_param_coord_value,build_tied,parse_dependencies,flatten_tied_map
from sheap.SheaProducts.SheaProducts import SheaProducts
from sheap.Assistants.Parameters import build_Parameters
from sheap.Minimizer.Minimizer import Minimizer
from sheap.MasterSampler.Samplers.Utils.montecarlo_utils import phys_trust_region_inits,resample_spec_all 


[docs] class MonteCarloSampler: """ Montecarlo sampler still under developmen. """ def __init__(self, estimator: "MasterSampler"): self.estimator = estimator # ParameterEstimation instance self.SheaProducts = SheaProducts(samplerclass=estimator) self.names = estimator.names self.model = jit(estimator.model) #####norm_spectra#### self.scale = estimator.scale self.spectra = estimator.spectra self.mask = estimator.mask self.norm_spectra = self._normalize_spectra() ######## self.params = estimator.params # are this in the normal scale ######## self.dependencies = estimator.dependencies self.params_dict = estimator.params_dict self.sheapmodel = estimator.sheapmodel self.fitkwargs = estimator.fitkwargs self.initial_params = estimator.initial_params self.get_param_coord_value = make_get_param_coord_value(self.params_dict, self.initial_params) # important self.tied_params = self.fitkwargs[-1]["tied"] #the tied params of the last iteration. self.constraints = jnp.asarray(estimator.constraints, dtype=jnp.float32) #this will be moved self.params_class = self._build_params_class() self.best_params = descale_amp(self.params_dict,self.params,self.scale).astype(jnp.float32) #thescaled
[docs] def sample_params(self, num_samples: int = 100, key_seed: int = 0, summarize=True,**kwargs) -> jnp.ndarray: from tqdm import tqdm print(f"Running Monte Carlo with JAX.,sample over the spectra using init params") norm_spectra = self.norm_spectra model = self.model _minimizer = self.make_minimizer(model=model, **self.fitkwargs[-1]) iterator = tqdm(range(num_samples), total=num_samples, desc="Sampling obj") params = jnp.tile(self.initial_params, (norm_spectra.shape[1], 1)) key = random.PRNGKey(key_seed) monte_params = [] for n in iterator: key, ki = random.split(key) norm_spectra_local = resample_spec_all(ki,norm_spectra) t0 = time.perf_counter() params_m, _ = _minimizer(params , *norm_spectra_local, self.constraints)# t1 = time.perf_counter() monte_params.append(params_m) iterator.set_postfix({"it_s": f"{(t1 - t0):.4f}"}) _monte_params = np.moveaxis(np.stack(monte_params),0,1) dic_posterior_params = {} iterator = tqdm(self.names, total=len(self.names), desc="Getting posterior-params") for n, name_i in enumerate(iterator): full_samples = scale_amp(self.params_dict,_monte_params[n],self.scale[n]) dic_posterior_params[name_i] = {"samples_phys":full_samples} dic_posterior_params[name_i] = self.SheaProducts.calculate_sheap_products_sampled(n,full_samples) dic_posterior_params[name_i].update({"samples_phys":full_samples}) return dic_posterior_params
[docs] def sample_params_st(self, num_samples: int = 100, key_seed: int = 0, summarize=True,**kwargs) -> jnp.ndarray: "stable start from the best result." from tqdm import tqdm print(f"Running Monte Carlo with JAX.,sample over the spectra") norm_spectra = self.norm_spectra model = self.model _minimizer = self.make_minimizer(model=model, **self.fitkwargs[-1]) iterator = tqdm(range(num_samples), total=num_samples, desc="Sampling obj") key = random.PRNGKey(key_seed) monte_params = [] for n in iterator: key, ki = random.split(key) norm_spectra_local = resample_spec_all(ki,norm_spectra) t0 = time.perf_counter() params_m, _ = _minimizer(self.best_params, *norm_spectra_local, self.constraints) t1 = time.perf_counter() monte_params.append(params_m) iterator.set_postfix({"it_s": f"{(t1 - t0):.4f}"}) _monte_params = np.moveaxis(np.stack(monte_params),0,1) dic_posterior_params = {} iterator = tqdm(self.names, total=len(self.names), desc="Getting posterior-params") for n, name_i in enumerate(iterator): full_samples = scale_amp(self.params_dict,_monte_params[n],self.scale[n]) dic_posterior_params[name_i] = {"samples_phys":full_samples} dic_posterior_params[name_i] = self.SheaProducts.calculate_sheap_products_sampled(n,full_samples) dic_posterior_params[name_i].update({"samples_phys":full_samples}) return dic_posterior_params
[docs] def sample_params_experimental(self, num_samples: int = 100, key_seed: int = 0, summarize=True ,return_only_draws=False,frac_box_sigma=0.5, k_sigma= 0.5 ) -> jnp.ndarray: #it looks like this only work for frac_box_sigma=0.02,k_sigma=0.3 limits from tqdm import tqdm print(f"Running Monte Carlo with JAX.,frac_box_sigma={frac_box_sigma},k_sigma={k_sigma}") model = self.model norm_spectra = self.norm_spectra best_params = self.best_params _, draws_phys = phys_trust_region_inits(key_seed, params_class=self.params_class, best_params=best_params, phys_bounds=self.constraints, num_samples=num_samples, frac_box_sigma= frac_box_sigma,k_sigma=k_sigma ) draws_phys = draws_phys.astype(jnp.float32) # ensure consistent dtype if return_only_draws: iterator = tqdm(self.names, total=len(self.names), desc="Getting draws") _draws_phys = np.moveaxis(draws_phys,0,1) dic_posterior_params = {} for n, name_i in enumerate(iterator): draws_phys_n = scale_amp(self.params_dict,np.array(_draws_phys[n]),self.scale[n]) dic_posterior_params[name_i]=({"draws_phys":draws_phys_n}) return dic_posterior_params _minimizer = self.make_minimizer(model=model, **self.fitkwargs[-1]) iterator = tqdm(range(num_samples), total=num_samples, desc="Sampling obj") monte_params = [] for n in iterator: draws_phys_local = draws_phys[n] # already float32 t0 = time.perf_counter() params_m, _ = _minimizer(draws_phys_local, *norm_spectra, self.constraints) t1 = time.perf_counter() monte_params.append(params_m) iterator.set_postfix({"it_s": f"{(t1 - t0):.4f}"}) _monte_params = np.moveaxis(np.stack(monte_params),0,1) _draws_phys = np.moveaxis(draws_phys,0,1) dic_posterior_params = {} iterator = tqdm(self.names, total=len(self.names), desc="Getting posterior-params") for n, name_i in enumerate(iterator): full_samples = scale_amp(self.params_dict,_monte_params[n],self.scale[n]) draws_phys_n = scale_amp(self.params_dict,np.array(_draws_phys[n]),self.scale[n]) dic_posterior_params[name_i] = self.SheaProducts.calculate_sheap_products_sampled(n,full_samples) dic_posterior_params[name_i].update({"samples_phys":full_samples,"draws_phys":draws_phys_n}) return dic_posterior_params
[docs] def make_minimizer(self,model,non_optimize_in_axis,num_steps,learning_rate, method,penalty_weight,curvature_weight,smoothness_weight,max_weight,penalty_function=None,weighted=True,**kwargs): minimizer = Minimizer(model,non_optimize_in_axis=non_optimize_in_axis,num_steps=num_steps,weighted=weighted, learning_rate=learning_rate,param_converter= self.params_class,penalty_function = penalty_function,method=method, penalty_weight= penalty_weight,curvature_weight= curvature_weight,smoothness_weight= smoothness_weight,max_weight= max_weight) return minimizer
def _normalize_spectra(self): "from the clasical shape to the one that is use during the fitting" scale = jnp.atleast_1d(self.scale.astype(jnp.float32)) spectra = self.spectra.astype(jnp.float32) norm_spectra = spectra.at[:, [1, 2], :].divide(jnp.moveaxis(jnp.tile(scale, (2, 1)), 0, 1)[:, :, None]) norm_spectra = norm_spectra.at[:, 2, :].set(jnp.where(self.mask, 1e31, norm_spectra[:, 2, :])) return norm_spectra.astype(jnp.float32).transpose(1, 0, 2) def _build_params_class(self): dependencies = build_tied(self.tied_params,self.get_param_coord_value) list_dependencies = parse_dependencies(dependencies) tied_map = {T[1]: T[2:] for T in list_dependencies} tied_map = flatten_tied_map(tied_map) params_class = build_Parameters(tied_map,self.params_dict,self.initial_params,self.constraints) return params_class