sheap.Minimizer.loss_builder module
Loss Function Builder
This module defines the construction of flexible loss functions used in sheap for spectral fitting and optimization.
Contents
- build_loss_function: Factory for JAX-compatible scalar loss functions combining residuals, penalties, and regularization.
Loss Components
The constructed loss may include the following terms:
- Data fidelity (log-cosh residuals)
- Optional penalty on parameters
- Curvature matching
- Residual smoothness
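The data-fidelity term is the least familiar of these. Below is a minimal sketch of it. It assumes error-normalised residuals, matching the weighted=True form of the data term in build_loss_function. The helper names log_cosh and data_term are hypothetical and are not part of sheap.

```python
import jax
import jax.numpy as jnp


def log_cosh(r):
    # Stable log(cosh(r)) = log(e^r + e^-r) - log(2); a naive
    # jnp.log(jnp.cosh(r)) overflows once cosh(r) exceeds the float range.
    return jnp.logaddexp(r, -r) - jnp.log(2.0)


def data_term(y_pred, y, y_err, max_weight=0.1):
    # Error-normalised residuals (the weighted=True case).
    r = (y_pred - y) / y_err
    lc = log_cosh(r)
    # Mean log-cosh plus a weighted worst-case (max) residual.
    return jnp.mean(lc) + max_weight * jnp.max(lc)


print(log_cosh(10.0))            # ≈ 9.307, versus 0.5 * 10**2 = 50
print(jax.grad(log_cosh)(10.0))  # tanh(10) ≈ 1.0: the gradient stays bounded
```

Log-cosh behaves like r²/2 for small residuals and like |r| − log 2 for large ones. A 10σ outlier therefore contributes about 9.31 rather than 50. Its gradient, tanh(r), never exceeds 1 in magnitude, which is why the max term is needed to keep the worst residual visible.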
Notes
- penalty_function can encode additional physical constraints or priors on the parameters.
- param_converter maps the raw parameters passed to the loss to physical parameters.
- All terms are implemented in JAX and are fully differentiable.
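Here is a sketch of what these two hooks might look like. It assumes a single spectral line with params = [amplitude, center, width]. SoftplusConverter and center_in_range_penalty are hypothetical examples. They only mimic the documented raw_to_phys method and the penalty(xs, params) signature, and they do not reproduce sheap's own Parameters object.

```python
import jax
import jax.numpy as jnp


class SoftplusConverter:
    """Hypothetical param_converter for params = [amplitude, center, width].

    Amplitude and width pass through softplus so they stay positive;
    the center passes through unchanged.
    """

    def raw_to_phys(self, raw):
        return jnp.stack([jax.nn.softplus(raw[0]), raw[1], jax.nn.softplus(raw[2])])


def center_in_range_penalty(xs, params):
    # Hypothetical prior: quadratic penalty once the line center leaves the
    # sampled x range; zero inside it.
    center = params[1]
    below = jnp.maximum(xs.min() - center, 0.0)
    above = jnp.maximum(center - xs.max(), 0.0)
    return below**2 + above**2


converter = SoftplusConverter()
xs = jnp.linspace(0.0, 10.0, 11)


def penalty_of_raw(raw):
    # Both pieces are plain JAX, so gradients flow from the penalty back
    # through the conversion to the raw parameters.
    return center_in_range_penalty(xs, converter.raw_to_phys(raw))


raw = jnp.array([0.0, 12.0, -3.0])
print(converter.raw_to_phys(raw))     # amplitude ≈ 0.693, center 12.0, width ≈ 0.0486
print(jax.grad(penalty_of_raw)(raw))  # [0. 4. 0.]
```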
Example
from sheap.Minimizer.loss_builder import build_loss_function
loss_fn = build_loss_function(model_fn, weighted=True, curvature_weight=1e4)
loss_val = loss_fn(params, x_grid, flux, flux_err)
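The example assumes that model_fn, params, x_grid, flux, and flux_err already exist. A hypothetical setup for them could look like the following, using a single Gaussian line on a synthetic, noisy spectrum. It assumes that with no param_converter the raw params are passed to model_fn unchanged. All names and values are illustrative only.

```python
import jax
import jax.numpy as jnp


def model_fn(xs, params):
    # Hypothetical single Gaussian emission line; params = [amp, center, sigma].
    amp, center, sigma = params[0], params[1], params[2]
    return amp * jnp.exp(-0.5 * ((xs - center) / sigma) ** 2)


# Synthetic spectrum: a known line plus Gaussian noise (illustration only).
x_grid = jnp.linspace(6500.0, 6620.0, 241)
true_params = jnp.array([2.0, 6563.0, 5.0])
flux_err = jnp.full_like(x_grid, 0.05)
key = jax.random.PRNGKey(0)
flux = model_fn(x_grid, true_params) + flux_err * jax.random.normal(key, x_grid.shape)

# Starting guess; without a param_converter, raw and physical values coincide.
params = jnp.array([1.5, 6560.0, 4.0])
```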
- build_loss_function(func, weighted=True, penalty_function=None, penalty_weight=0.01, param_converter=None, curvature_weight=1000.0, smoothness_weight=100000.0, max_weight=0.1)
Build a flexible JAX-compatible loss function for regression-style modeling tasks.
This loss function combines several components:
1. Data term using log-cosh residuals:
\[\text{data} = \operatorname{mean}(\log\cosh(r)) + \alpha \cdot \max(\log\cosh(r)), \quad \text{where } r = \frac{y_\text{pred} - y}{y_\text{err}}\]
Here \(\alpha\) is max_weight, and the division by \(y_\text{err}\) applies when weighted=True.
2. Optional penalty term on parameters, with \(\beta\) = penalty_weight:
\[\text{penalty} = \beta \cdot \text{penalty\_function}(x, \theta)\]
3. Optional curvature matching (difference of second derivatives), with \(\gamma\) = curvature_weight:
\[\text{curvature} = \gamma \cdot \operatorname{mean}[(f''_\text{pred} - f''_\text{true})^2]\]
4. Optional smoothness penalty on the residuals, with \(\delta\) = smoothness_weight:
\[\text{smoothness} = \delta \cdot \operatorname{mean}[(\nabla r)^2]\]
- Parameters:
func (Callable) – The prediction function, called as func(xs, phys_params), returning y_pred.
weighted (bool, default=True) – Whether to apply inverse error weighting to the residuals.
penalty_function (Callable, optional) – A callable penalty term penalty(xs, params) → scalar loss, scaled by penalty_weight.
penalty_weight (float, default=0.01) – Coefficient for the penalty function term.
param_converter (Parameters, optional) – Object with a raw_to_phys method to convert raw to physical parameters.
curvature_weight (float, default=1e3) – Coefficient for the second-derivative matching term.
smoothness_weight (float, default=1e5) – Coefficient for smoothness of the residuals.
max_weight (float, default=0.1) – Weight for the maximum log-cosh residual relative to the mean.
- Returns:
A loss function with signature (params, xs, y, yerr) → scalar, where params are raw parameters (converted to physical parameters when a param_converter is given).
- Return type:
Callable
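To make the four formulas concrete, here is an illustrative stand-in that returns a callable with the documented (params, xs, y, yerr) → scalar signature. It is a sketch, not sheap's implementation. build_loss_sketch and line_model are hypothetical names. Several points the documentation leaves open are treated as assumptions here: the enabled terms are summed, f'' and ∇r are approximated with jnp.gradient at unit sample spacing, penalty_function receives physical parameters, and unscaled residuals are used when weighted=False.

```python
import jax
import jax.numpy as jnp


def build_loss_sketch(func, weighted=True, penalty_function=None,
                      penalty_weight=0.01, param_converter=None,
                      curvature_weight=1e3, smoothness_weight=1e5,
                      max_weight=0.1):
    """Illustrative stand-in for the documented loss; NOT sheap's code."""

    def loss(params, xs, y, yerr):
        # Optional raw -> physical conversion (assumed identity when None).
        if param_converter is None:
            phys = params
        else:
            phys = param_converter.raw_to_phys(params)
        y_pred = func(xs, phys)

        # 1. Data term: mean + max_weight * max of log-cosh residuals.
        r = (y_pred - y) / yerr if weighted else y_pred - y
        lc = jnp.logaddexp(r, -r) - jnp.log(2.0)  # stable log(cosh(r))
        total = jnp.mean(lc) + max_weight * jnp.max(lc)

        # 2. Optional parameter penalty (assumed to receive physical params).
        if penalty_function is not None:
            total = total + penalty_weight * penalty_function(xs, phys)

        # 3. Curvature matching via second differences (unit spacing assumed).
        if curvature_weight > 0:
            d2_pred = jnp.gradient(jnp.gradient(y_pred))
            d2_true = jnp.gradient(jnp.gradient(y))
            total = total + curvature_weight * jnp.mean((d2_pred - d2_true) ** 2)

        # 4. Smoothness of the residuals (finite-difference gradient of r).
        if smoothness_weight > 0:
            total = total + smoothness_weight * jnp.mean(jnp.gradient(r) ** 2)

        return total

    return loss


def line_model(xs, p):
    # Hypothetical straight-line model: p = [slope, intercept].
    return p[0] * xs + p[1]


xs = jnp.linspace(0.0, 1.0, 50)
y = 2.0 * xs + 1.0
yerr = jnp.full_like(xs, 0.1)

loss_fn = jax.jit(build_loss_sketch(line_model))
value, grads = jax.value_and_grad(loss_fn)(jnp.array([2.0, 1.0]), xs, y, yerr)
worse = loss_fn(jnp.array([2.5, 1.0]), xs, y, yerr)  # mis-specified slope
```

Because the returned loss is pure JAX, it can be wrapped in jax.jit and differentiated with jax.value_and_grad. The weights are Python floats captured in the closure, so a term whose weight is zero is never traced at all. Whether sheap skips zero-weight terms in the same way is not stated in this documentation.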