# Source code for sheap.Utils.UncertaintyFunction

r"""
Uncertainty Estimation via Residuals
====================================

This module provides utilities to estimate parameter uncertainties
after a fit, using residuals and Jacobian-based covariance
approximations.

Main Features
-------------
- Compute normalized residuals between data and model
  (:func:`residuals`).
- Build residual functions for *free parameters only*, taking into
  account tied and fixed relationships
  (:func:`make_residuals_free_fn`).
- Estimate covariance matrices from the Jacobian of residuals with
  respect to free parameters
  (:func:`error_covariance_matrix`).
- Loop over spectra to propagate uncertainties back into the full
  parameter vector, respecting ties/fixed constraints
  (:func:`Errorfromloop`, :func:`error_for_loop_s`).

Public API
----------
- :func:`residuals`:
    Compute (y - model)/σ residuals for a given parameter vector.
- :func:`make_residuals_free_fn`:
    Construct a callable that maps free parameters → residuals,
    restoring tied/fixed values internally.
- :func:`error_covariance_matrix`:
    Estimate uncertainties for free parameters from the JTJ matrix.
- :func:`Errorfromloop`:
    Iterate over multiple spectra, returning uncertainty arrays
    mapped back into the full parameter space.
- :func:`error_for_loop_s`:
    Simplified variant of :func:`Errorfromloop`.

Notes
-----
- The covariance is estimated with the usual
  :math:`(J^T J)^{-1} \, s^2` approximation, where :math:`J` is the
  Jacobian of the residuals with respect to the free parameters and
  :math:`s^2 = \sum_i r_i^2 / \mathrm{dof}` is the reduced residual variance.
- Tied and fixed parameters are reconstructed using
  :func:`sheap.Assistants.parser_mapper.apply_tied_and_fixed_params`.
- Regularization is applied to stabilize ill-conditioned inversions.
"""

# Module author metadata.
__author__ = 'felavila'

# Public API of this module: the names exported by a star-import.
__all__ = [
    "Errorfromloop",
    "error_covariance_matrix",
    "error_for_loop_s",
    "make_residuals_free_fn",
    "residuals",
]

from typing import Callable, Tuple, Union

import jax
import jax.numpy as jnp
from jax import random,vmap

from sheap.Assistants.parser_mapper import apply_tied_and_fixed_params

#TODO This requires major updates 

def residuals(
    func: Callable,
    params: jnp.ndarray,
    xs: jnp.ndarray,
    y: jnp.ndarray,
    y_uncertainties: jnp.ndarray,
) -> jnp.ndarray:
    """Return uncertainty-normalized residuals ``(y - func(xs, params)) / y_uncertainties``.

    Parameters
    ----------
    func : Callable
        Model, evaluated as ``func(xs, params)``.
    params : jnp.ndarray
        Full parameter vector handed to ``func``.
    xs : jnp.ndarray
        Independent variable passed straight through to ``func``.
    y : jnp.ndarray
        Observed data.
    y_uncertainties : jnp.ndarray
        Per-point uncertainties used as the divisor.

    Returns
    -------
    jnp.ndarray
        Element-wise normalized residuals. No masking is applied, so zero
        uncertainties yield inf/nan entries.
    """
    model_values = func(xs, params)
    deviation = y - model_values
    return deviation / y_uncertainties
def make_residuals_free_fn(
    model_func: Callable,
    xs: jnp.ndarray,
    y: jnp.ndarray,
    yerr: jnp.ndarray,
    template_params: jnp.ndarray,
    dependencies,
) -> Callable:
    """Build a residual function that takes only the free parameters.

    The returned callable expands its argument into a full parameter vector
    with ``apply_tied_and_fixed_params(free_params, template_params,
    dependencies)``. It then evaluates :func:`residuals` for ``model_func``
    on ``(xs, y, yerr)``.

    Parameters
    ----------
    model_func : Callable
        Model, evaluated as ``model_func(xs, full_params)``.
    xs, y, yerr : jnp.ndarray
        Independent variable, data and uncertainties captured by the closure.
    template_params : jnp.ndarray
        Full parameter vector supplying the non-free entries.
    dependencies :
        Tie/fix description forwarded unchanged to
        ``apply_tied_and_fixed_params``.

    Returns
    -------
    Callable
        ``f(free_params) -> residual array``.
    """

    def free_residuals(free_params: jnp.ndarray) -> jnp.ndarray:
        expanded = apply_tied_and_fixed_params(
            free_params, template_params, dependencies
        )
        return residuals(model_func, expanded, xs, y, yerr)

    return free_residuals
[docs] def error_covariance_matrix( residual_fn: Callable, params_i: jnp.ndarray, xs_i: jnp.ndarray, y_i: jnp.ndarray, yerr_i: jnp.ndarray, free_params: int, return_full: bool = False, regularization: float = 1e-6, overboost_threshold: float = 1e10, ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]: """ Estimate uncertainty for free parameters using JTJ approximation. TODO: CHECK IF THIS CAN BE UPGRADED """ mask = yerr_i < overboost_threshold if jnp.sum(mask) == 0: fallback = jnp.abs(params_i) * 5.0 + 1.0 return (fallback, jnp.diag(fallback**2)) if return_full else fallback #xs_valid, y_valid, yerr_valid = xs_i[mask], y_i[mask], yerr_i[mask] residual = residual_fn(params_i)[mask]#.astype(jnp.float32) if jnp.any(jnp.isnan(residual)) or jnp.any(jnp.isinf(residual)): fallback = jnp.abs(params_i) * 5.0 + 1.0 return (fallback, jnp.diag(fallback**2)) if return_full else fallback jacobian = jax.jacobian(residual_fn)(params_i)#.astype(jnp.float32) JTJ = jacobian.T @ jacobian dof = max(residual.size - free_params, 1) #to avoid fall back in negatives values s_sq = jnp.sum(residual**2) / dof reg = regularization * jnp.eye(JTJ.shape[0]) try: cov = jnp.linalg.inv(JTJ + reg) * s_sq except: cov = jnp.linalg.pinv(JTJ) * s_sq diag_cov = jnp.clip(jnp.diag(cov), a_min=1e-20) std_error = jnp.sqrt(diag_cov) return (std_error, cov) if return_full else std_error
def Errorfromloop(model, spectra, params, dependencies):
    """Estimate per-object parameter uncertainties for a batch of spectra.

    For every object, a free-parameter residual function is built with
    :func:`make_residuals_free_fn`. Its uncertainties come from
    :func:`error_covariance_matrix` and are expanded back into the full
    parameter space with ``apply_tied_and_fixed_params``.

    Parameters
    ----------
    model : Callable
        Model, evaluated as ``model(xs, full_params)``.
    spectra : array-like
        Batch of spectra. Axis 1 must have length 3 (wavelength, flux, flux
        uncertainty), e.g. shape ``(batch, 3, n_pixels)``.
    params : array-like
        Full parameter vectors, one row per object.
    dependencies : sequence
        Tie/fix descriptors. Element ``[1]`` of each entry is the index of a
        non-free (dependent) parameter.

    Returns
    -------
    jnp.ndarray
        Array shaped like ``params`` holding the expanded uncertainties.
        Rows beyond the shortest of ``params``/``spectra`` stay zero.

    Notes
    -----
    Inputs are requested as float64. JAX only honours float64 when x64 mode
    (``jax_enable_x64``) is enabled; otherwise the arrays are float32.
    """
    spectra = jnp.asarray(spectra, dtype=jnp.float64)
    params = jnp.asarray(params, dtype=jnp.float64)

    # (batch, 3, n_pix) -> (3, batch, n_pix), then unpack the three channels.
    wl, flux, yerr = jnp.moveaxis(spectra, 0, 1)

    # Every parameter that is the target of a dependency is not free.
    # NOTE(review): the order comes from a set difference and must match the
    # order apply_tied_and_fixed_params expects for the free vector — confirm.
    dependent_idx = set(dep[1] for dep in dependencies)
    free_idx = list(set(range(params.shape[-1])) - dependent_idx)
    free_selector = jnp.array(free_idx)  # loop-invariant

    # Same dtype as ``params`` (float64 only if x64 mode is enabled).
    out = jnp.zeros_like(params)

    for row, (p_row, wl_row, flux_row, err_row) in enumerate(
        zip(params, wl, flux, yerr)
    ):
        p_free = p_row[free_selector]
        fn = make_residuals_free_fn(
            model_func=model,
            xs=wl_row,
            y=flux_row,
            yerr=err_row,
            template_params=p_row,
            dependencies=dependencies,
        )
        sigma_free, _ = error_covariance_matrix(
            residual_fn=fn,
            params_i=p_free,
            xs_i=wl_row,
            y_i=flux_row,
            yerr_i=err_row,
            free_params=p_free.shape[0],
            return_full=True,
        )
        # NOTE(review): the template here is params[0] (the first object) for
        # every row, while the residual function above uses p_row — confirm
        # this is intended.
        sigma_full = apply_tied_and_fixed_params(sigma_free, params[0], dependencies)
        out = out.at[row].set(sigma_full)

    return out
def error_for_loop_s(model, spectra, params, dependencies):
    """Estimate per-object parameter uncertainties (lighter ``Errorfromloop`` variant).

    Unlike ``Errorfromloop``, the inputs are not cast to float64. The
    returned array is always float32.

    Parameters
    ----------
    model : Callable
        Model, evaluated as ``model(xs, full_params)``.
    spectra : array-like
        Batch of spectra. Axis 1 must have length 3 (wavelength, flux, flux
        uncertainty).
    params : array
        Full parameter vectors, one row per object.
    dependencies : sequence
        Tie/fix descriptors. Element ``[1]`` of each entry is the index of a
        non-free parameter.

    Returns
    -------
    jnp.ndarray
        float32 array shaped like ``params`` with the expanded uncertainties.

    Notes
    -----
    Original author note (wording unclear): "save the samples could increase
    the number of stuff." This is presumably a warning that keeping
    per-object samples would grow memory/output size — confirm with the
    author.
    """
    # Split axis 1 of ``spectra`` into the three channels.
    wl, flux, yerr = jnp.moveaxis(spectra, 0, 1)

    dependent_idx = set(entry[1] for entry in dependencies)
    n_params = len(params[0])
    free_selector = jnp.array(list(set(range(n_params)) - dependent_idx))

    out = jnp.zeros_like(params).astype(jnp.float32)

    rows = zip(params, wl, flux, yerr)
    for row, (p_row, wl_row, flux_row, err_row) in enumerate(rows):
        p_free = p_row[free_selector]
        fn = make_residuals_free_fn(
            model_func=model,
            xs=wl_row,
            y=flux_row,
            yerr=err_row,
            template_params=p_row,
            dependencies=dependencies,
        )
        sigma_free, _ = error_covariance_matrix(
            residual_fn=fn,
            params_i=p_free,
            xs_i=wl_row,
            y_i=flux_row,
            yerr_i=err_row,
            free_params=len(p_free),
            return_full=True,
        )
        # NOTE(review): params[0] is used as the template for every row while
        # the residual function uses p_row — confirm this is intended.
        out = out.at[row].set(
            apply_tied_and_fixed_params(sigma_free, params[0], dependencies)
        )

    return out
#@jax.jit