sheap.Minimizer.loss_builder module

Loss Function Builder

This module builds the flexible loss functions that sheap uses for spectral fitting and optimization.

Contents

  • build_loss_function: Factory for JAX-compatible scalar loss functions combining residuals, penalties, and regularization.

Loss Components

The constructed loss may include the following terms:

  1. Data fidelity (log-cosh residuals)

\[\mathcal{L}_\text{data} = \langle \log\cosh(r) \rangle + \alpha \, \max(\log\cosh(r)), \quad r = \frac{y_\text{pred} - y}{\sigma}\]

  2. Optional penalty on parameters

\[\mathcal{L}_\text{penalty} = \beta \, \text{penalty\_function}(x, \theta)\]

  3. Curvature matching

\[\mathcal{L}_\text{curvature} = \gamma \, \langle (f''_\text{pred} - f''_\text{true})^2 \rangle\]

  4. Residual smoothness

\[\mathcal{L}_\text{smoothness} = \delta \, \langle (\nabla r)^2 \rangle\]
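A minimal sketch of the data-fidelity term, assuming 1-D JAX arrays for y_pred, y, and σ. The helper names log_cosh and data_term are hypothetical and are not part of sheap's API. log cosh(r) is computed as logaddexp(r, -r) - log 2, which is mathematically identical but does not overflow for large |r| the way jnp.cosh does in float32.

```python
import jax.numpy as jnp


def log_cosh(x):
    # Stable log(cosh(x)): log(e^x + e^-x) - log(2) via logaddexp,
    # so large |x| does not overflow as jnp.cosh(x) would.
    return jnp.logaddexp(x, -x) - jnp.log(2.0)


def data_term(y_pred, y, sigma, alpha=0.1):
    # Hypothetical helper: normalized residuals r = (y_pred - y) / sigma.
    r = (y_pred - y) / sigma
    lc = log_cosh(r)
    # Mean log-cosh plus alpha times the largest (worst-pixel) log-cosh.
    return jnp.mean(lc) + alpha * jnp.max(lc)
```

Since log cosh r behaves like r²/2 for small r and like |r| - log 2 for large r, the mean term acts as a robust, Huber-style loss, while the α·max term adds an explicit cost for the single worst-fitting pixel.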

Notes

  • penalty_function can enforce additional physics or priors.

  • param_converter allows transformation from raw to physical parameters.

  • All terms are implemented in JAX, so the loss is differentiable with JAX autodiff (e.g. jax.grad).
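Hypothetical examples of the two hooks in the notes, for illustration only: bounds_penalty follows the penalty(xs, params) → scalar contract, and ExpConverter provides a raw_to_phys method. Neither exists in sheap, and this page does not state whether sheap passes raw or physical parameters to penalty_function; the sketch simply works on whatever array it receives.

```python
import jax.numpy as jnp


def bounds_penalty(xs, params, lower=0.0, upper=10.0):
    # Hypothetical prior: zero inside [lower, upper], quadratic outside.
    # xs is accepted to match penalty(xs, params) but is unused here.
    below = jnp.maximum(lower - params, 0.0)
    above = jnp.maximum(params - upper, 0.0)
    return jnp.sum(below**2 + above**2)


class ExpConverter:
    # Hypothetical converter: the optimizer works on log-parameters,
    # while the model receives strictly positive physical values.
    def raw_to_phys(self, raw):
        return jnp.exp(raw)
```

A log/exp reparameterization like this is one common way to keep parameters such as amplitudes or widths positive without hard constraints in the optimizer.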

Example

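from sheap.Minimizer.loss_builder import build_loss_function

# model_fn(xs, phys_params) -> y_pred is a user-supplied model;
# params holds the raw parameters, and x_grid, flux, flux_err are the
# spectral grid, the observed flux, and the flux uncertainties.
loss_fn = build_loss_function(model_fn, weighted=True, curvature_weight=1e4)
loss_val = loss_fn(params, x_grid, flux, flux_err)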
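The example assumes a user-supplied model_fn. One possible model with the func(xs, phys_params) → y_pred contract is sketched below; the Gaussian-plus-continuum shape, the parameter order, and the synthetic x_grid, flux, and flux_err arrays are all hypothetical.

```python
import jax.numpy as jnp


def model_fn(xs, phys_params):
    # Hypothetical model: one Gaussian line on a flat continuum.
    # Assumed parameter order: [amplitude, center, width, continuum].
    amp, center, width, cont = phys_params
    return cont + amp * jnp.exp(-0.5 * ((xs - center) / width) ** 2)


# Synthetic, noise-free data on an illustrative grid.
x_grid = jnp.linspace(4800.0, 4920.0, 200)
true_params = jnp.array([5.0, 4861.0, 3.0, 1.0])
flux = model_fn(x_grid, true_params)
flux_err = jnp.full_like(flux, 0.1)
```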

build_loss_function(func, weighted=True, penalty_function=None, penalty_weight=0.01, param_converter=None, curvature_weight=1000.0, smoothness_weight=100000.0, max_weight=0.1)

Build a flexible JAX-compatible loss function for regression-style modeling tasks.

The returned loss combines up to four components:

1. Data term using log-cosh residuals

\[\text{data} = \operatorname{mean}(\log\cosh(r)) + \alpha \cdot \max(\log\cosh(r)), \quad \text{where } r = \frac{y_\text{pred} - y}{y_\text{err}}\]

2. Optional penalty term on parameters

\[\text{penalty} = \beta \cdot \text{penalty\_function}(x, \theta)\]

3. Optional curvature matching (second derivative difference)

\[\text{curvature} = \gamma \cdot \operatorname{mean}[(f''_\text{pred} - f''_\text{true})^2]\]

4. Optional smoothness penalty on the residuals

\[\text{smoothness} = \delta \cdot \operatorname{mean}[(\nabla r)^2]\]
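A sketch of the curvature and smoothness terms, assuming second derivatives and residual gradients are estimated with finite differences on the pixel grid (jnp.gradient with unit spacing). This page does not say how sheap discretizes f'' or ∇r, nor what f''_true is computed from; here it is assumed to be the second derivative of the observed y. The function names are hypothetical.

```python
import jax.numpy as jnp


def curvature_term(y_pred, y, gamma=1e3):
    # Second derivatives via repeated central differences (unit spacing,
    # one-sided differences at the edges), i.e. per-pixel derivatives.
    d2_pred = jnp.gradient(jnp.gradient(y_pred))
    d2_true = jnp.gradient(jnp.gradient(y))
    return gamma * jnp.mean((d2_pred - d2_true) ** 2)


def smoothness_term(r, delta=1e5):
    # Penalize pixel-to-pixel changes in the normalized residuals.
    dr = jnp.gradient(r)
    return delta * jnp.mean(dr ** 2)
```

Both terms vanish when y_pred and y differ by at most a straight line (curvature) or when the residuals are constant (smoothness).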

Parameters:

  • func (Callable) – The prediction function, called as func(xs, phys_params), returning y_pred.

  • weighted (bool, default=True) – Whether to apply inverse error weighting to the residuals.

  • penalty_function (Callable, optional) – A callable penalty(xs, params) → scalar whose value is scaled by penalty_weight.

  • penalty_weight (float, default=0.01) – Coefficient β for the penalty term.

  • param_converter (Parameters, optional) – Object with a raw_to_phys method that converts raw parameters to physical parameters.

  • curvature_weight (float, default=1e3) – Coefficient γ for the second-derivative matching term.

  • smoothness_weight (float, default=1e5) – Coefficient δ for the residual smoothness term.

  • max_weight (float, default=0.1) – Coefficient α on the maximum log-cosh residual, relative to the mean log-cosh.

Returns:

A loss function with signature (params, xs, y, yerr) → scalar, where params are the raw parameters; they are converted to physical parameters with param_converter when one is provided.

Return type:

Callable
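To see how the pieces fit together under the (params, xs, y, yerr) → scalar contract, here is a hypothetical stand-in, make_loss, that mirrors the documented signature. It is not sheap's implementation: the finite-difference derivatives, the behavior of weighted=False, and passing physical parameters to penalty_function are all assumptions.

```python
import jax
import jax.numpy as jnp


def make_loss(func, weighted=True, penalty_function=None, penalty_weight=0.01,
              param_converter=None, curvature_weight=1e3,
              smoothness_weight=1e5, max_weight=0.1):
    # Hypothetical stand-in mirroring build_loss_function's signature.
    def log_cosh(x):
        return jnp.logaddexp(x, -x) - jnp.log(2.0)

    def loss_fn(params, xs, y, yerr):
        # Raw -> physical parameters only if a converter is given.
        if param_converter is not None:
            phys = param_converter.raw_to_phys(params)
        else:
            phys = params
        y_pred = func(xs, phys)
        # Assumption: weighted=False uses unscaled residuals.
        r = (y_pred - y) / yerr if weighted else (y_pred - y)
        lc = log_cosh(r)
        loss = jnp.mean(lc) + max_weight * jnp.max(lc)
        if penalty_function is not None:
            loss = loss + penalty_weight * penalty_function(xs, phys)
        # Assumption: f''_true is the second difference of the observed y.
        d2 = jnp.gradient(jnp.gradient(y_pred)) - jnp.gradient(jnp.gradient(y))
        loss = loss + curvature_weight * jnp.mean(d2 ** 2)
        loss = loss + smoothness_weight * jnp.mean(jnp.gradient(r) ** 2)
        return loss

    return loss_fn


def line(xs, p):
    # Hypothetical Gaussian model: p = [amplitude, center, width].
    return p[0] * jnp.exp(-0.5 * ((xs - p[1]) / p[2]) ** 2)


xs = jnp.linspace(-5.0, 5.0, 101)
true = jnp.array([2.0, 0.0, 1.0])
y = line(xs, true)
yerr = jnp.full_like(y, 0.1)

loss_fn = make_loss(line)
value, grad = jax.value_and_grad(loss_fn)(true, xs, y, yerr)
```

Because the returned function is plain JAX, jax.value_and_grad (or jax.jit) can wrap it directly. At the parameters that generated the noise-free data every term is zero, so the gradient there vanishes.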