Source code for sheap.MasterSampler.Samplers.Utils.numpyro_utils

r"""
NumPyro Model Utilities
=======================
.. note::
    This module requires cleaning.
    
Helper functions to construct NumPyro models from sheap parameters,
including dictionary conversion, handling of tied parameters, and
safe initialization near constraints.

Main Features
-------------
- Apply arithmetic ties between parameters.
- Convert flat parameter arrays to dictionaries for NumPyro.
- Initialize values safely away from hard bounds.
- Build a full NumPyro model for flux fitting with uncertainty.


Notes
-----
- Dependencies are expressed as tuples
  ``(tag, target_idx, src_idx, op, val)`` where ``op`` is one of
  ``{'+','-','*','/'}``.
- Parameters tied via dependencies are excluded from sampling and
  reconstructed after free parameters are sampled.
"""

__author__ = 'felavila'

__all__ = [
    "NumPyro_apply_arithmetic_ties",
    "make_numpyro_model",
    "NumPyro_params_to_dict",
]

from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist




def make_numpyro_model(
    name_list,
    wl,
    flux,
    sigma,
    constraints,
    params_i,
    theta_to_sheap,
    fixed_params,
    dependencies,
    model_func,
):
    r"""
    Build a NumPyro probabilistic model for spectrum fitting.

    Parameters
    ----------
    name_list : list of str
        Names of NumPyro parameters (e.g. ``["theta_0", "theta_1", ...]``).
    wl : jnp.ndarray
        Wavelength grid, shape (n_pix,).
    flux : jnp.ndarray
        Observed flux, shape (n_pix,).
    sigma : jnp.ndarray
        Flux uncertainties, shape (n_pix,).
    constraints : list of tuple
        Bounds for each parameter as (low, high), aligned with ``name_list``.
    params_i : array-like
        Initial parameter values.
    theta_to_sheap : dict
        Mapping from NumPyro parameter name to SHEAP parameter name.
    fixed_params : dict or None
        Dictionary of fixed parameter values keyed by SHEAP names. A value
        of ``None`` means "use the initial value of that parameter".
        Passing ``None`` for the whole argument means nothing is fixed.
    dependencies : list of tuple or None
        List of ties ``(tag, target_idx, src_idx, op, val)``
        (see :func:`NumPyro_apply_arithmetic_ties`).
    model_func : Callable
        Spectral model function with signature ``model_func(wl, theta)``.

    Returns
    -------
    numpyro_model : Callable
        A zero-argument function defining the NumPyro model.
    init_value : dict
        Dictionary of initial parameter values for ``init_to_value``, as
        produced by :func:`NumPyro_params_to_dict` (tied targets excluded).

    Raises
    ------
    ValueError
        Raised when the model is traced if a fixed parameter is ``None``
        and no initial value is available for it.

    Notes
    -----
    The model:

    .. math::

        \theta_i \sim \mathrm{Uniform}(low_i, high_i)

        f(\lambda) = \mathrm{model\_func}(\lambda, \theta)

        y(\lambda) \sim \mathcal{N}(f(\lambda), \sigma)

    - Dependent (tied) parameters are excluded from sampling and
      reconstructed afterward.
    - Fixed parameters are held at their provided values.
    """
    # Initial values keyed by SHEAP name; used as the fallback for fixed
    # parameters whose provided value is None.
    init_values = {
        sheap_key: params_i[pos]
        for pos, sheap_key in enumerate(theta_to_sheap.values())
    }
    init_value = NumPyro_params_to_dict(params_i, dependencies, constraints)

    if not dependencies:
        dependencies = ()
    if fixed_params is None:
        fixed_params = {}

    # Built once here rather than on every trace of the model; a frozenset
    # gives O(1) membership tests inside the sampling loop.
    tied_targets = frozenset(dep[1] for dep in dependencies)

    def numpyro_model():
        params = {}
        for i, (name, (low, high)) in enumerate(zip(name_list, constraints)):
            sheap_name = theta_to_sheap[name]
            if i in tied_targets:
                # Tied targets are not sampled; they are filled in below.
                continue
            if sheap_name in fixed_params:
                val = fixed_params[sheap_name]
                if val is None:
                    val = init_values.get(sheap_name)
                if val is None:
                    raise ValueError(f"Fixed param '{sheap_name}' is None and not found in init_values.")
            else:
                val = numpyro.sample(name, dist.Uniform(low, high))
            params[name] = val

        # Reconstruct tied parameters from the sampled / fixed ones.
        params = NumPyro_apply_arithmetic_ties(params, dependencies)
        theta = jnp.array([params[name] for name in name_list])
        pred = model_func(wl, theta)
        numpyro.sample("obs", dist.Normal(pred, sigma), obs=flux)

    return numpyro_model, init_value
def NumPyro_apply_arithmetic_ties(
    params: Dict[str, float],
    dependencies: Optional[List[Tuple]],
) -> Dict[str, float]:
    r"""
    Apply arithmetic ties to update dependent parameters in place.

    Parameters
    ----------
    params : dict
        Dictionary of parameter values keyed by ``"theta_{i}"``. It is
        modified in place and also returned.
    dependencies : list of tuple or None
        List of constraints of the form
        ``(tag, target_idx, src_idx, op, val)``. ``None`` or an empty
        sequence means there are no ties and ``params`` is returned as is.

    Returns
    -------
    dict
        The same ``params`` dictionary with tied targets assigned.

    Raises
    ------
    KeyError
        If the source parameter ``theta_{src_idx}`` is not in ``params``.
    ValueError
        If ``op`` is not one of the supported operations.

    Notes
    -----
    - Supported operations are ``+``, ``-``, ``*``, ``/``.
    - Ties are applied in the order given, so a tie whose source is itself
      a tied target must come after the tie that defines that source.
    - Example: if ``theta_3 = theta_1 * 2``, then dependency is
      ``("arith", 3, 1, "*", 2.0)``.
    """
    # TODO: this is already in other place?
    if not dependencies:
        # Treat None / empty the same way the rest of the module does.
        return params

    for _tag, target_idx, src_idx, op, val in dependencies:
        src = params[f"theta_{src_idx}"]
        if op == '+':
            result = src + val
        elif op == '-':
            result = src - val
        elif op == '*':
            result = src * val
        elif op == '/':
            result = src / val
        else:
            raise ValueError(f"Unsupported operation: {op}")
        params[f"theta_{target_idx}"] = result
    return params
def NumPyro_params_to_dict(params, dependencies, constraints=None, eps=1e-2):
    r"""
    Convert flat parameters into a dictionary for NumPyro initialization.

    Excludes dependent parameters and applies jitter to avoid initializing
    exactly at bounds.

    Parameters
    ----------
    params : array-like
        Flat parameter values (e.g. from a minimizer).
    dependencies : list of tuple or None
        Ties of the form ``(tag, target_idx, src_idx, op, val)``. Every
        ``target_idx`` is left out of the result.
    constraints : dict or list, optional
        Parameter bounds. Either a dict mapping ``theta_i -> (low, high)``
        (missing keys are treated as unbounded) or a list aligned with
        ``params``.
    eps : float, default=1e-2
        Minimum offset from bounds to avoid initialization at the edge.

    Returns
    -------
    dict
        Dictionary of free parameters, keyed as ``theta_i``.

    Notes
    -----
    - If a parameter value is invalid (NaN, inf, on or outside its bounds)
      and both bounds are finite, it is reset to the midpoint of its
      allowed range. With an infinite bound there is no midpoint, so the
      value is left as is and only the edge offset below is applied.
    - The offset kept from each bound is ``eps``, capped at half the
      interval width. Without the cap, an interval narrower than
      ``2 * eps`` would push the value onto or past the opposite bound.
    - Exact zeros are nudged to ``eps`` to avoid degenerate starts, or to
      half the upper bound when ``eps`` itself would not be strictly below
      that bound.
    """
    # Indices of tied targets; a set keeps the per-parameter lookup O(1).
    tied_targets = {dep[1] for dep in dependencies} if dependencies else set()

    init_dict = {}
    for idx, raw in enumerate(params):
        if idx in tied_targets:
            continue  # skip dependent (tied) parameters

        name = f"theta_{idx}"
        val = float(raw)
        # Unbounded unless constraints say otherwise; `high` is also used
        # by the zero nudge at the end of the loop body.
        low, high = -jnp.inf, jnp.inf

        if constraints is not None:
            if isinstance(constraints, dict):
                low, high = constraints.get(name, (-jnp.inf, jnp.inf))
            else:
                # assume list aligned with params
                low, high = constraints[idx]
            low, high = float(low), float(high)

            invalid = (not jnp.isfinite(val)) or val <= low or val >= high
            if invalid and jnp.isfinite(low) and jnp.isfinite(high):
                val = 0.5 * (low + high)

            # Equals eps for any interval at least 2*eps wide (including
            # unbounded ones), so the usual case is unchanged; narrower
            # intervals collapse to their midpoint instead of overshooting.
            margin = min(eps, 0.5 * (high - low))
            if val - low < margin:
                val = low + margin
            elif high - val < margin:
                val = high - margin

        # Prevent an exact zero start while staying strictly below `high`.
        if val == 0:
            val = eps if eps < high else 0.5 * high

        init_dict[name] = val
    return init_dict
# def params_to_dict_old(params,dependencies): # r""" # Convert parameters to a dict (legacy version). # Parameters # ---------- # params : array-like # Flat list of parameter values. # dependencies : list of tuple # Constraints (same format as in :func:`apply_arithmetic_ties`). # Returns # ------- # dict # Dictionary mapping free parameter indices to values, # keyed as ``theta_i``. # """ # #tag,target_idx,src_idx, op, val = dependencies # target_idx_list = [d[1] for d in dependencies] # init_value = {} # for idx,val in enumerate(params): # if idx in target_idx_list: # continue # else: # init_value[f"theta_{idx}"] = val # return init_value