sheap.Minimizer.loss_builder module

Loss Function Builder

This module builds the flexible loss functions that sheap uses for spectral fitting and optimization.

Contents

  • build_loss_function: Factory for JAX-compatible scalar loss functions combining residuals, penalties, and regularization.

Loss Components

The constructed loss may include the following terms:

  1. Data fidelity (log-cosh residuals)

\[\mathcal{L}_\text{data} = \langle \log\cosh(r) \rangle + \alpha \, \max(\log\cosh(r)), \quad r = \frac{y_\text{pred} - y}{\sigma}\]
  2. Optional penalty on parameters

\[\mathcal{L}_\text{penalty} = \beta \, \text{penalty\_function}(x, \theta)\]
  3. Curvature matching

\[\mathcal{L}_\text{curvature} = \gamma \, \langle (f''_\text{pred} - f''_\text{true})^2 \rangle\]
  4. Residual smoothness

\[\mathcal{L}_\text{smoothness} = \delta \, \langle (\nabla r)^2 \rangle\]
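
The four terms above can be sketched in a few lines of JAX. This is a minimal illustration, not the sheap implementation: the function names are hypothetical, the derivatives are assumed to be plain finite differences on a uniform grid, and the mapping of alpha, beta, gamma, delta onto max_weight, penalty_weight, curvature_weight, smoothness_weight is inferred from the defaults in the signature rather than confirmed by the source.

```python
import jax.numpy as jnp


def log_cosh(r):
    # log(cosh(r)) written as |r| + log1p(exp(-2|r|)) - log(2)
    a = jnp.abs(r)
    return a + jnp.log1p(jnp.exp(-2.0 * a)) - jnp.log(2.0)


def second_diff(y):
    # Assumption: second derivative approximated by a central
    # finite difference with unit grid spacing.
    return y[2:] - 2.0 * y[1:-1] + y[:-2]


def sketch_loss(y_pred, y, sigma, penalty=0.0,
                alpha=0.1, beta=0.01, gamma=1e3, delta=1e5):
    # Normalized residual r = (y_pred - y) / sigma
    r = (y_pred - y) / sigma
    lc = log_cosh(r)

    # 1. Data fidelity: mean plus alpha times the worst-case residual
    data = jnp.nanmean(lc) + alpha * jnp.max(lc)
    # 2. Penalty: a scalar supplied by the caller, scaled by beta
    pen = beta * penalty
    # 3. Curvature matching between prediction and data
    curv = gamma * jnp.nanmean((second_diff(y_pred) - second_diff(y)) ** 2)
    # 4. Residual smoothness: assumed first difference of r
    smooth = delta * jnp.nanmean(jnp.diff(r) ** 2)

    return data + pen + curv + smooth


y = jnp.linspace(0.0, 1.0, 16) ** 2
sigma = jnp.full_like(y, 0.1)
perfect = sketch_loss(y, y, sigma)          # every term vanishes
shifted = sketch_loss(y + 0.05, y, sigma)   # constant offset: data term only
```

A constant offset leaves the curvature and smoothness terms at essentially zero, so only the log-cosh data term responds; the two regularizers react to shape mismatches rather than level mismatches.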

Notes

  • penalty_function can enforce additional physics or priors.

  • param_converter allows transformation from raw to physical parameters.

  • All terms are implemented using JAX and are fully differentiable.
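
As a rough illustration of the first two notes, the sketch below pairs a converter with a penalty. Both functions, the three-element parameter layout, and the way they are composed are hypothetical; the only detail taken from the source is the penalty_function(x, theta) argument order. How build_loss_function actually invokes param_converter is not documented here, so treat the composition as an assumption.

```python
import jax
import jax.numpy as jnp


def to_physical(raw):
    # Hypothetical converter: raw = [amplitude, center, log_width].
    # Exponentiating the last entry keeps the physical width positive.
    return jnp.stack([raw[0], raw[1], jnp.exp(raw[2])])


def positive_amplitude_penalty(x, theta):
    # Hypothetical prior: quadratic cost whenever the amplitude is negative.
    # x is unused here but kept to match penalty_function(x, theta).
    return jnp.maximum(-theta[0], 0.0) ** 2


x = jnp.linspace(4990.0, 5010.0, 5)
raw = jnp.array([-0.5, 5000.0, jnp.log(3.0)])

pen = positive_amplitude_penalty(x, to_physical(raw))

# Because both pieces are plain JAX, the composition is differentiable
# with respect to the raw parameters.
grad = jax.grad(lambda p: positive_amplitude_penalty(x, to_physical(p)))(raw)
```

Under these assumptions, such callables would be passed as the penalty_function and param_converter keyword arguments of build_loss_function.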

Example

from sheap.Minimizer.loss_builder import build_loss_function

loss_fn = build_loss_function(model_fn, weighted=True, curvature_weight=1e4)
loss_val = loss_fn(params, x_grid, flux, flux_err)
build_loss_function(func, weighted=True, penalty_function=None, penalty_weight=0.01, param_converter=None, curvature_weight=1000.0, smoothness_weight=100000.0, max_weight=0.1)

Optimizations vs original

  1. d2y_true is pre-computed once at build time; it is a constant and never changes.

  2. The log_cosh array is computed once, and both jnp.nanmean and jnp.max are applied to that same array instead of recomputing it for each reduction.

  3. Four duplicated closure branches collapsed to one.

  4. log_cosh uses a numerically stable form that avoids the logaddexp overhead for large |x|: jnp.abs(x) + jnp.log1p(jnp.exp(-2*jnp.abs(x))) - log2, where log2 is the constant log(2).

  5. @jax.jit on the returned loss so the optimizer doesn’t retrace.
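
The stable log_cosh form from item 4 can be checked against the direct definition with a short sketch. The function names and the LOG2 constant are illustrative only; the expression itself is the one quoted above.

```python
import jax
import jax.numpy as jnp

LOG2 = jnp.log(2.0)  # constant log(2), computed once


@jax.jit
def log_cosh_stable(x):
    # |x| + log1p(exp(-2|x|)) - log(2): the exponential argument is
    # never positive, so it cannot overflow.
    a = jnp.abs(x)
    return a + jnp.log1p(jnp.exp(-2.0 * a)) - LOG2


def log_cosh_naive(x):
    # Direct definition; cosh overflows for large |x|.
    return jnp.log(jnp.cosh(x))


x = jnp.array([0.0, 1.0, 1000.0])
stable = log_cosh_stable(x)
naive = log_cosh_naive(x)
```

For moderate inputs the two agree, while at x = 1000 the naive form overflows and the stable form stays close to |x| - log(2).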

Parameters:
  • func (Callable)

  • weighted (bool)

  • penalty_function (Callable | None)

  • penalty_weight (float)

  • curvature_weight (float)

  • smoothness_weight (float)

  • max_weight (float)

Return type:

Callable