sheap.SheaProducts.Utils.Helpers module
Helpers
This module provides helper routines to evaluate and integrate
spectral line profiles while propagating parameter and/or grid
uncertainties. These are used primarily in the ComplexParams
pipeline when computing derived quantities (flux, FWHM, luminosity,
etc.) from fitted or sampled parameter sets.
Main Features
Numerical integration of profile functions with uncertainty propagation via JAX autodiff.
Batched integration and evaluation for multiple lines/objects.
Support for error propagation from both parameter uncertainties and wavelength (x) uncertainties.
JAX-compatible (vectorized with vmap, differentiable).
Public API
- trapz_jax(y, x)
Trapezoidal integration along a 1D grid using JAX.
- Parameters:
y (jnp.ndarray) – Function values on the grid x.
x (jnp.ndarray) – Monotonic grid points.
- Returns:
Scalar integral \(\int y(x) \, dx\) approximated with the trapezoid rule.
- Return type:
jnp.ndarray
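A minimal sketch of the trapezoid rule that trapz_jax describes, written with jax.numpy. The name trapz_sketch is hypothetical, and this is not taken from the module's source.

```python
import jax.numpy as jnp

def trapz_sketch(y, x):
    """Trapezoid rule on a 1D grid (illustrative sketch, not the module's code)."""
    # Width of each interval between neighbouring grid points
    dx = jnp.diff(x)
    # Average of the two endpoint values times the interval width, summed
    return jnp.sum(0.5 * (y[1:] + y[:-1]) * dx)

x = jnp.linspace(0.0, 1.0, 101)
area_linear = trapz_sketch(2.0 * x + 1.0, x)  # exact integral is 2
area_square = trapz_sketch(x ** 2, x)         # exact integral is 1/3
```

The rule is exact for piecewise-linear y, so the linear case reproduces 2 up to float32 rounding. For x**2 there is a small O(h²) discretization error.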
- integrate_function_error(function, x, p, sigma_p=None)
Integrate a profile and propagate parameter uncertainties.
\[F = \int f(\lambda; p) \, d\lambda\]
- Parameters:
function (Callable) – Profile function f(x, p).
x (jnp.ndarray) – Grid over which to integrate.
p (jnp.ndarray) – Parameters.
sigma_p (jnp.ndarray, optional) – 1σ parameter uncertainties. If None, treated as zero.
- Returns:
y_int (jnp.ndarray) – Integral value.
sigma_f (jnp.ndarray) – Propagated uncertainty.
- Return type:
Tuple[jax.numpy.ndarray, jax.numpy.ndarray]
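This is what linearized propagation would amount to for the integral, under two assumptions. First, the parameters are treated as uncorrelated, the same quadrature assumption used in the evaluate_with_error formula. Second, the grid does not depend on p, so the derivative can move inside the integral:

\[\sigma_F^2 = \sum_i \left( \frac{\partial F}{\partial p_i} \sigma_{p_i} \right)^2, \qquad \frac{\partial F}{\partial p_i} = \int \frac{\partial f(\lambda; p)}{\partial p_i} \, d\lambda\]

With sigma_p=None every \(\sigma_{p_i}\) is zero, so sigma_f is zero as well.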
- integrate_function_error_single(function, x, p, sigma_p)
Integrate a single profile and propagate parameter errors.
\[F = \int f(\lambda; p) \, d\lambda\]
The uncertainty on F is propagated by linearization in p.
- Parameters:
function (Callable) – Profile function with signature function(x, p).
x (jnp.ndarray) – 1D integration grid.
p (jnp.ndarray) – Parameter vector.
sigma_p (jnp.ndarray) – 1σ uncertainty per parameter.
- Returns:
y_int (jnp.ndarray) – Integrated value.
sigma_f (jnp.ndarray) – Propagated 1σ uncertainty.
- Return type:
Tuple[jax.numpy.ndarray, jax.numpy.ndarray]
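One way to implement single-profile integration with linearized error propagation using JAX autodiff. This is a sketch only. The Gaussian profile, its parameter order [amplitude, center, width], the helper names, and the assumption of uncorrelated parameters are all illustrative, not the module's code.

```python
import jax
import jax.numpy as jnp

def gaussian(x, p):
    # Hypothetical profile for illustration: p = [amplitude, center, width]
    return p[0] * jnp.exp(-0.5 * ((x - p[1]) / p[2]) ** 2)

def trapz(y, x):
    # Trapezoid rule on a 1D grid
    return jnp.sum(0.5 * (y[1:] + y[:-1]) * jnp.diff(x))

def integrate_single_sketch(function, x, p, sigma_p):
    # Integral F(p) on a fixed grid x
    def integral(params):
        return trapz(function(x, params), x)

    y_int = integral(p)
    # dF/dp_i for every parameter via reverse-mode autodiff
    grad_p = jax.grad(integral)(p)
    # Linearized, uncorrelated propagation: sigma_F^2 = sum_i (dF/dp_i * sigma_p_i)^2
    sigma_f = jnp.sqrt(jnp.sum((grad_p * sigma_p) ** 2))
    return y_int, sigma_f

x = jnp.linspace(-20.0, 20.0, 2001)
p = jnp.array([1.0, 0.0, 2.0])
sigma_p = jnp.array([0.1, 0.5, 0.2])
flux, flux_err = integrate_single_sketch(gaussian, x, p, sigma_p)
```

For a Gaussian, the analytic flux is A·w·√(2π), so dF/dA = w√(2π) and dF/dw = A√(2π). That gives an easy check on both the value and the propagated error.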
- integrate_batch_with_error(function, x, p, sigma_p)
Batched integration with parameter uncertainty propagation.
- Parameters:
function (Callable) – Profile function.
x (jnp.ndarray) – 1D integration grid.
p (jnp.ndarray) – Parameters, shape (N, L, P).
sigma_p (jnp.ndarray) – Uncertainties, shape (N, L, P).
- Returns:
y_batch (jnp.ndarray) – Integrated values, shape (N, L).
sigma_batch (jnp.ndarray) – Propagated uncertainties, shape (N, L).
- Return type:
Tuple[jax.numpy.ndarray, jax.numpy.ndarray]
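A sketch of how the (N, L, P) batching could be expressed with two nested jax.vmap calls over a shared grid x. The helpers integrate_single and gaussian are hypothetical and are redefined here so the snippet runs on its own. Parameters are again assumed uncorrelated.

```python
import jax
import jax.numpy as jnp

def gaussian(x, p):
    # Hypothetical profile: p = [amplitude, center, width]
    return p[0] * jnp.exp(-0.5 * ((x - p[1]) / p[2]) ** 2)

def integrate_single(function, x, p, sigma_p):
    # Trapezoid integral of function(x, p) plus its linearized 1-sigma error
    def integral(params):
        y = function(x, params)
        return jnp.sum(0.5 * (y[1:] + y[:-1]) * jnp.diff(x))
    grad_p = jax.grad(integral)(p)
    return integral(p), jnp.sqrt(jnp.sum((grad_p * sigma_p) ** 2))

def integrate_batch_sketch(function, x, p, sigma_p):
    # One (integral, error) pair per parameter vector p[n, l, :]
    def single(pp, sp):
        return integrate_single(function, x, pp, sp)
    # Inner vmap maps over lines (L), outer vmap over objects (N); x is shared
    return jax.vmap(jax.vmap(single))(p, sigma_p)

N, L = 2, 3
amps = jnp.arange(1.0, 1.0 + N * L).reshape(N, L)
p = jnp.stack([amps, jnp.zeros((N, L)), jnp.full((N, L), 2.0)], axis=-1)  # (N, L, 3)
sigma_p = jnp.full((N, L, 3), 0.1)
x = jnp.linspace(-20.0, 20.0, 2001)
y_batch, sigma_batch = integrate_batch_sketch(gaussian, x, p, sigma_p)
```

Because vmap carries the tuple output through both levels, each returned array has the documented (N, L) shape.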
- evaluate_with_error(function, x, p, sigma_x=None, sigma_p=None)
Evaluate a profile and propagate 1σ errors in both x and p.
\[\sigma_y^2 = \left( \frac{\partial f}{\partial x} \sigma_x \right)^2 + \sum_i \left( \frac{\partial f}{\partial p_i} \sigma_{p_i} \right)^2\]
- Parameters:
function (Callable) – Profile function f(x, p).
x (jnp.ndarray) – Grid, shape (N, L).
p (jnp.ndarray) – Parameters, shape (N, P).
sigma_x (jnp.ndarray, optional) – Uncertainty on x, shape (N, L).
sigma_p (jnp.ndarray, optional) – Uncertainty on p, shape (N, P).
- Returns:
y (jnp.ndarray) – Function values.
yerr (jnp.ndarray) – Propagated uncertainties.
- Return type:
Tuple[jax.numpy.ndarray, jax.numpy.ndarray]
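A sketch of the quadrature formula above. It assumes the profile acts pointwise on x, so ∂f/∂x is simply the derivative at each grid point. It also assumes missing uncertainties are replaced by zeros. The names evaluate_with_error_sketch and gaussian are hypothetical.

```python
import jax
import jax.numpy as jnp

def gaussian(x, p):
    # Hypothetical profile: p = [amplitude, center, width]
    return p[0] * jnp.exp(-0.5 * ((x - p[1]) / p[2]) ** 2)

def evaluate_with_error_sketch(function, x, p, sigma_x=None, sigma_p=None):
    # Missing uncertainties are treated as zero
    sigma_x = jnp.zeros_like(x) if sigma_x is None else sigma_x
    sigma_p = jnp.zeros_like(p) if sigma_p is None else sigma_p

    def per_object(xn, pn, sxn, spn):
        y = function(xn, pn)  # shape (L,)
        # df/dx at each grid point (assumes function acts elementwise on x)
        dfdx = jax.vmap(jax.grad(function, argnums=0), in_axes=(0, None))(xn, pn)
        # Jacobian df/dp_i at each grid point, shape (L, P)
        jac_p = jax.jacfwd(function, argnums=1)(xn, pn)
        # Quadrature sum of the x term and the parameter terms
        var = (dfdx * sxn) ** 2 + jnp.sum((jac_p * spn) ** 2, axis=-1)
        return y, jnp.sqrt(var)

    # Map over the N objects
    return jax.vmap(per_object)(x, p, sigma_x, sigma_p)

x = jnp.array([[0.0, 2.0]])          # N=1 object, L=2 grid points
p = jnp.array([[1.0, 0.0, 2.0]])     # P=3 parameters
sigma_x = jnp.array([[0.0, 0.1]])
sigma_p = jnp.array([[0.1, 0.0, 0.0]])
y, yerr = evaluate_with_error_sketch(gaussian, x, p, sigma_x, sigma_p)
_, yerr_none = evaluate_with_error_sketch(gaussian, x, p)  # no uncertainties given
```

At the line center every derivative except ∂f/∂A vanishes, so the first error equals the amplitude uncertainty. One width away from the center, the x term and the amplitude term combine in quadrature.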
- batched_evaluate(function, x, p, sigma_p)
Batched evaluation with parameter uncertainties only.
- Parameters:
function (Callable) – Profile function.
x (jnp.ndarray) – Independent variable(s).
p (jnp.ndarray) – Parameters, shape (N, L, P).
sigma_p (jnp.ndarray) – Parameter uncertainties, shape (N, L, P).
- Returns:
f_batch (jnp.ndarray) – Function values, shape (N, L).
err_batch (jnp.ndarray) – Propagated errors, shape (N, L).
- Return type:
Tuple[jax.numpy.ndarray, jax.numpy.ndarray]
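A sketch of batched evaluation with parameter-only propagation. The docstring does not pin down the shape of x. Here x is assumed to hold one evaluation point per line, shape (N, L), which matches the documented (N, L) output. The function names and the Gaussian parameter order are hypothetical.

```python
import jax
import jax.numpy as jnp

def gaussian(x, p):
    # Hypothetical profile: p = [amplitude, center, width]
    return p[0] * jnp.exp(-0.5 * ((x - p[1]) / p[2]) ** 2)

def batched_evaluate_sketch(function, x, p, sigma_p):
    # Assumes x has shape (N, L): one evaluation point per line
    def single(xi, pi, si):
        value = function(xi, pi)
        # Gradient with respect to the parameters only (no x uncertainty here)
        grad_p = jax.grad(function, argnums=1)(xi, pi)
        return value, jnp.sqrt(jnp.sum((grad_p * si) ** 2))
    # Outer vmap over objects (N), inner vmap over lines (L)
    return jax.vmap(jax.vmap(single))(x, p, sigma_p)

N, L = 2, 3
amps = jnp.arange(1.0, 1.0 + N * L).reshape(N, L)
centers = jnp.linspace(6500.0, 6600.0, N * L).reshape(N, L)
p = jnp.stack([amps, centers, jnp.full((N, L), 5.0)], axis=-1)  # (N, L, 3)
sigma_p = jnp.stack([0.1 * amps, jnp.ones((N, L)), jnp.full((N, L), 0.5)], axis=-1)
# Evaluate each line at its own center
f_batch, err_batch = batched_evaluate_sketch(gaussian, centers, p, sigma_p)
```

Evaluating each line at its own center keeps the check simple. The derivatives with respect to center and width vanish there, so each value equals the amplitude and each error equals the amplitude uncertainty.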