# Source code for sheap.SheaProducts.Utils.Helpers

"""
Helpers
=========================

This module provides helper routines to evaluate and integrate
spectral line profiles while propagating parameter and/or grid
uncertainties. These are used primarily in the ``ComplexParams``
pipeline when computing derived quantities (flux, FWHM, luminosity,
etc.) from fitted or sampled parameter sets.

Main Features
-------------
- Numerical integration of profile functions with uncertainty
  propagation via JAX autodiff.
- Batched integration and evaluation for multiple lines/objects.
- Support for error propagation from both parameter uncertainties
  and wavelength (x) uncertainties.
- JAX-compatible (vectorized with ``vmap``, differentiable).

Public API
----------
- :func:`trapz_jax`
- :func:`integrate_function_error`
- :func:`integrate_function_error_single`
- :func:`integrate_batch_with_error`
- :func:`evaluate_with_error`
- :func:`batched_evaluate`
"""

__author__ = "felavila"

# Public names exported by ``from <module> import *``; the order mirrors the
# "Public API" list in the module docstring.
__all__ = [
    "trapz_jax",
    "integrate_function_error",
    "integrate_function_error_single",
    "integrate_batch_with_error",
    "evaluate_with_error",
    "batched_evaluate",
]

import math
from typing import Callable, Optional, Tuple

import jax.numpy as jnp
from jax import grad, jit, value_and_grad, vmap


def trapz_jax(y: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    """
    Approximate the integral of sampled values with the trapezoid rule.

    Each interval ``[x[i], x[i+1]]`` contributes the area of a trapezoid with
    heights ``y[i]`` and ``y[i+1]``; the contributions are summed. Slicing is
    done along the first axis, so the routine is intended for 1D inputs.

    Parameters
    ----------
    y : jnp.ndarray
        Samples of the integrand at the points ``x``.
    x : jnp.ndarray
        Monotonic sample locations.

    Returns
    -------
    jnp.ndarray
        Scalar approximation of :math:`\\int y(x) \\, dx`.
    """
    heights = y[1:] + y[:-1]
    widths = x[1:] - x[:-1]
    return jnp.sum(heights * widths / 2)


def integrate_function_error_single(
    function: Callable,
    x: jnp.ndarray,
    p: jnp.ndarray,
    sigma_p: Optional[jnp.ndarray] = None,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Integrate a single profile and propagate parameter errors.

    .. math::
        F = \\int f(\\lambda; p) \\, d\\lambda

    The uncertainty is propagated by linearization, treating parameters as
    uncorrelated:
    :math:`\\sigma_F = \\sqrt{\\sum_i (\\partial F / \\partial p_i \\, \\sigma_{p_i})^2}`.

    Parameters
    ----------
    function : Callable
        Profile function with signature ``function(x, p)``.
    x : jnp.ndarray
        1D integration grid.
    p : jnp.ndarray
        Parameter vector. Integer inputs are promoted to the default floating
        dtype, because ``jax.grad`` only accepts inexact inputs.
    sigma_p : jnp.ndarray, optional
        1σ uncertainty per parameter. If None, treated as zero.

    Returns
    -------
    y_int : jnp.ndarray
        Integrated value.
    sigma_f : jnp.ndarray
        Propagated 1σ uncertainty.
    """
    p = jnp.asarray(p)
    # jax.grad raises on integer inputs; promote so integer guesses still work.
    if not jnp.issubdtype(p.dtype, jnp.inexact):
        p = p.astype(float)
    sigma_p = jnp.zeros_like(p) if sigma_p is None else jnp.asarray(sigma_p)

    def int_function(pp):
        return trapz_jax(function(x, pp), x)

    # A single forward+backward pass yields both the integral and its gradient.
    y_int, grad_f = value_and_grad(int_function)(p)
    sigma_f = jnp.sqrt(jnp.sum((grad_f * sigma_p) ** 2))
    return y_int, sigma_f


def integrate_batch_with_error(
    function: Callable,
    x: jnp.ndarray,
    p: jnp.ndarray,
    sigma_p: Optional[jnp.ndarray] = None,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Batched integration with parameter uncertainty propagation.

    Every parameter vector along the last axis of ``p`` is integrated
    independently with :func:`integrate_function_error_single`, vectorized
    with ``vmap``.

    Parameters
    ----------
    function : Callable
        Profile function ``f(x, p)``.
    x : jnp.ndarray
        1D integration grid shared by all parameter vectors.
    p : jnp.ndarray
        Parameters, shape (..., P); typically (N, L, P).
    sigma_p : jnp.ndarray, optional
        1σ uncertainties, broadcastable to the shape of ``p``. If None,
        treated as zero.

    Returns
    -------
    y_batch : jnp.ndarray
        Integrated values, shape ``p.shape[:-1]`` (e.g. (N, L)).
    sigma_batch : jnp.ndarray
        Propagated uncertainties, same shape as ``y_batch``.
    """
    p = jnp.asarray(p)
    if sigma_p is None:
        sigma_p = jnp.zeros_like(p)
    else:
        sigma_p = jnp.broadcast_to(sigma_p, p.shape)

    # Collapse all leading (batch) axes into one so a single vmap suffices.
    batch_shape = p.shape[:-1]
    n_params = p.shape[-1]
    n_flat = math.prod(batch_shape)
    p_flat = p.reshape((n_flat, n_params))
    sigma_flat = sigma_p.reshape((n_flat, n_params))

    batched_integrator = vmap(
        lambda pp, sp: integrate_function_error_single(function, x, pp, sp),
        in_axes=(0, 0),
        out_axes=(0, 0),
    )
    y_flat, sigma_flat_out = batched_integrator(p_flat, sigma_flat)
    return y_flat.reshape(batch_shape), sigma_flat_out.reshape(batch_shape)


def integrate_function_error(
    function: Callable,
    x: jnp.ndarray,
    p: jnp.ndarray,
    sigma_p: Optional[jnp.ndarray] = None,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Integrate a profile and propagate parameter uncertainties.

    .. math::
        F = \\int f(\\lambda; p) \\, d\\lambda

    Thin front-end to :func:`integrate_function_error_single` that first
    normalizes its inputs: scalars become 1-element vectors, integer
    parameters are promoted to float (``jax.grad`` rejects integer inputs),
    and a missing ``sigma_p`` becomes zeros.

    Parameters
    ----------
    function : Callable
        Profile function ``f(x, p)``.
    x : jnp.ndarray
        Grid over which to integrate.
    p : jnp.ndarray
        Parameters (scalar or vector).
    sigma_p : jnp.ndarray, optional
        1σ parameter uncertainties. If None, treated as zero.

    Returns
    -------
    y_int : jnp.ndarray
        Integral value.
    sigma_f : jnp.ndarray
        Propagated uncertainty.
    """
    p = jnp.atleast_1d(p)
    if not jnp.issubdtype(p.dtype, jnp.inexact):
        p = p.astype(float)
    sigma_p = jnp.zeros_like(p) if sigma_p is None else jnp.atleast_1d(sigma_p)
    # Share the integration/propagation logic instead of duplicating it.
    return integrate_function_error_single(function, x, p, sigma_p)


def evaluate_with_error(
    function: Callable,
    x: jnp.ndarray,
    p: jnp.ndarray,
    sigma_x: Optional[jnp.ndarray] = None,
    sigma_p: Optional[jnp.ndarray] = None,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Evaluate a profile pointwise and propagate 1σ errors in both x and p.

    .. math::
        \\sigma_y^2 = \\left( \\frac{\\partial f}{\\partial x} \\sigma_x \\right)^2
        + \\sum_i \\left( \\frac{\\partial f}{\\partial p_i} \\sigma_{p_i} \\right)^2

    ``function`` is called on one scalar ``x`` value at a time (through
    ``vmap``) and differentiated, so it must return a scalar there.

    Parameters
    ----------
    function : Callable
        Profile function ``f(x, p)``.
    x : jnp.ndarray
        Grid, shape (N, L). Integer grids are promoted to float so that
        derivatives with respect to x can be taken.
    p : jnp.ndarray
        Parameters, shape (N, P), or (1, P) / (P,) to share one parameter
        vector across all N rows. Integer inputs are promoted to float.
    sigma_x : jnp.ndarray, optional
        Uncertainty on x, broadcastable to shape (N, L). Zero if None.
    sigma_p : jnp.ndarray, optional
        Uncertainty on p, shape (N, P), (1, P) or (P,). Zero if None.

    Returns
    -------
    y : jnp.ndarray
        Function values, shape (N, L).
    yerr : jnp.ndarray
        Propagated uncertainties, shape (N, L).
    """
    x = jnp.asarray(x)
    if not jnp.issubdtype(x.dtype, jnp.inexact):
        x = x.astype(float)
    # A 1D parameter vector is treated as a single row shared by every grid row.
    p = jnp.atleast_2d(p)
    if not jnp.issubdtype(p.dtype, jnp.inexact):
        p = p.astype(float)
    sigma_x = jnp.zeros_like(x) if sigma_x is None else jnp.broadcast_to(sigma_x, x.shape)
    sigma_p = jnp.zeros_like(p) if sigma_p is None else jnp.atleast_2d(sigma_p)

    n, lines = x.shape
    n_params = p.shape[-1]
    # Repeat each row's parameters for every grid point in that row.
    p_exp = jnp.broadcast_to(p[:, None, :], (n, lines, n_params))
    sp_exp = jnp.broadcast_to(sigma_p[:, None, :], (n, lines, n_params))

    flat_size = n * lines
    x_flat = x.reshape((flat_size,))
    sx_flat = sigma_x.reshape((flat_size,))
    p_flat = p_exp.reshape((flat_size, n_params))
    sp_flat = sp_exp.reshape((flat_size, n_params))

    def single_eval(xv, pv, sxv, spv):
        # One pass gives f and both partial derivatives (w.r.t. x and p).
        y, (dyx, dyp) = value_and_grad(function, argnums=(0, 1))(xv, pv)
        var = (dyx * sxv) ** 2 + jnp.sum((dyp * spv) ** 2)
        return y, jnp.sqrt(var)

    y_flat, err_flat = vmap(single_eval, in_axes=(0, 0, 0, 0), out_axes=(0, 0))(
        x_flat, p_flat, sx_flat, sp_flat
    )
    return y_flat.reshape((n, lines)), err_flat.reshape((n, lines))


def batched_evaluate(
    function: Callable,
    x: jnp.ndarray,
    p: jnp.ndarray,
    sigma_p: Optional[jnp.ndarray] = None,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Batched evaluation with parameter uncertainties only.

    Each parameter vector ``p[i, j]`` is passed (as a single-row batch) to
    :func:`evaluate_with_error` together with the shared ``x``; size-1 axes
    of each result are squeezed away.

    Parameters
    ----------
    function : Callable
        Profile function ``f(x, p)``.
    x : jnp.ndarray
        Independent variable(s), in the 2D layout expected by
        :func:`evaluate_with_error`.
    p : jnp.ndarray
        Parameters, shape (N, L, P).
    sigma_p : jnp.ndarray, optional
        Parameter uncertainties, shape (N, L, P). If None, treated as zero.

    Returns
    -------
    f_batch : jnp.ndarray
        Function values, shape (N, L) when ``x`` holds a single point;
        otherwise (N, L, ...) with the squeezed evaluation shape appended.
    err_batch : jnp.ndarray
        Propagated errors, same shape as ``f_batch``.
    """
    p = jnp.asarray(p)
    sigma_p = jnp.zeros_like(p) if sigma_p is None else jnp.asarray(sigma_p)

    n, lines, n_params = p.shape
    p_flat = p.reshape((n * lines, n_params))
    sigma_flat = sigma_p.reshape((n * lines, n_params))

    def single_eval(pp, sp):
        f_val, f_err = evaluate_with_error(function, x, pp[None], sigma_p=sp[None])
        return f_val.squeeze(), f_err.squeeze()

    f_flat, err_flat = vmap(single_eval, in_axes=(0, 0), out_axes=(0, 0))(p_flat, sigma_flat)
    # Keep any per-evaluation dims instead of forcing (n, lines), which only
    # fit when each evaluation collapsed to a scalar.
    out_shape = (n, lines) + f_flat.shape[1:]
    return f_flat.reshape(out_shape), err_flat.reshape(out_shape)