sheap.Minimizer.loss_builder module
Loss Function Builder
This module builds the flexible loss functions that sheap uses for spectral fitting and optimization.
Contents
- build_loss_function: factory for JAX-compatible scalar loss functions that combine residuals, penalties, and regularization.
Loss Components
The constructed loss may include the following terms:
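- Data fidelity: log-cosh of the residuals (their mean, plus a maximum term weighted by max_weight)
- Optional penalty on parameters, from penalty_function scaled by penalty_weight
- Curvature matching, scaled by curvature_weight
- Residual smoothness, scaled by smoothness_weight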
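A minimal sketch of how these four terms might combine into one scalar loss. The definitions of the curvature and smoothness terms are not given here, so curvature_term (squared mismatch of second finite differences of model and data) and smoothness_term (squared first differences of the residuals) are assumptions, as are the helper names and the model_fn(x, params) call signature. The default weights mirror those in the build_loss_function signature.

```python
import jax
import jax.numpy as jnp

def log_cosh(x):
    # Stable log(cosh(x)) = |x| + log1p(exp(-2|x|)) - log(2)
    ax = jnp.abs(x)
    return ax + jnp.log1p(jnp.exp(-2.0 * ax)) - jnp.log(2.0)

def data_fidelity(r, max_weight=0.1):
    # Mean log-cosh residual plus a scaled worst-case (max) term
    lc = log_cosh(r)
    return jnp.nanmean(lc) + max_weight * jnp.max(lc)

def curvature_term(y_model, y_obs):
    # Assumed definition: match second finite differences of model and data
    return jnp.mean((jnp.diff(y_model, n=2) - jnp.diff(y_obs, n=2)) ** 2)

def smoothness_term(r):
    # Assumed definition: penalize point-to-point jumps in the residuals
    return jnp.mean(jnp.diff(r) ** 2)

def total_loss(params, x, y, yerr, model_fn, penalty_fn=None, penalty_weight=0.01,
               curvature_weight=1e3, smoothness_weight=1e5, max_weight=0.1):
    # Hypothetical composite loss; model_fn(x, params) is an assumed signature
    y_model = model_fn(x, params)
    r = (y - y_model) / yerr  # error-weighted residuals
    loss = data_fidelity(r, max_weight)
    loss += curvature_weight * curvature_term(y_model, y)
    loss += smoothness_weight * smoothness_term(r)
    if penalty_fn is not None:
        loss += penalty_weight * penalty_fn(params)
    return loss
```

Every term is built from jax.numpy operations, so jax.grad can differentiate the total with respect to params.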
Notes
- penalty_function can enforce additional physics or priors.
- param_converter converts raw parameters to physical parameters.
- All terms are implemented in JAX and are fully differentiable.
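The two hooks can be illustrated with a hypothetical Gaussian line. Here to_physical plays the role of param_converter (exponentiating a log-amplitude and log-width so both stay positive) and width_prior plays the role of penalty_function (a quadratic penalty once the width leaves an allowed range). The names, the raw parameter layout and the call signatures are assumptions rather than sheap's API; the sketch only shows that plain JAX hooks keep the whole loss differentiable.

```python
import jax
import jax.numpy as jnp

def to_physical(raw):
    # Hypothetical param_converter: raw = [log_amp, center, log_width]
    log_amp, center, log_width = raw
    return jnp.array([jnp.exp(log_amp), center, jnp.exp(log_width)])

def width_prior(phys, lo=1.0, hi=10.0):
    # Hypothetical penalty_function: zero inside [lo, hi], quadratic outside
    width = phys[2]
    return jax.nn.relu(lo - width) ** 2 + jax.nn.relu(width - hi) ** 2

def gaussian(x, phys):
    amp, center, width = phys
    return amp * jnp.exp(-0.5 * ((x - center) / width) ** 2)

def loss(raw, x, y, yerr, penalty_weight=0.01):
    phys = to_physical(raw)  # raw -> physical before evaluating the model
    r = (y - gaussian(x, phys)) / yerr
    data = jnp.mean(jnp.logaddexp(r, -r) - jnp.log(2.0))  # mean log-cosh residual
    return data + penalty_weight * width_prior(phys)
```

Because the optimizer works on the raw parameters, positivity of amplitude and width holds by construction, while the prior only nudges the width.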
Example
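```python
from sheap.Minimizer.loss_builder import build_loss_function

# model_fn is the spectral model passed as func; params, x_grid, flux and flux_err are user-supplied
loss_fn = build_loss_function(model_fn, weighted=True, curvature_weight=1e4)
loss_val = loss_fn(params, x_grid, flux, flux_err)  # scalar JAX array
```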
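Since the returned loss is a scalar, differentiable JAX function, it can drive a plain gradient-descent loop through jax.value_and_grad. In this sketch toy_loss is a hypothetical stand-in with the same (params, x_grid, flux, flux_err) call signature, used only so the example runs without sheap; the linear model, learning rate and step count are arbitrary choices.

```python
import jax
import jax.numpy as jnp

def toy_loss(params, x_grid, flux, flux_err):
    # Hypothetical stand-in: mean log-cosh of weighted residuals from a linear model
    r = (flux - (params[0] + params[1] * x_grid)) / flux_err
    return jnp.mean(jnp.logaddexp(r, -r) - jnp.log(2.0))

@jax.jit
def step(params, x_grid, flux, flux_err, lr=1e-3):
    # One gradient-descent update; value and gradient come from a single call
    val, grad = jax.value_and_grad(toy_loss)(params, x_grid, flux, flux_err)
    return params - lr * grad, val

x_grid = jnp.linspace(0.0, 1.0, 100)
flux = 1.0 + 2.0 * x_grid
flux_err = jnp.full_like(x_grid, 0.1)
params = jnp.zeros(2)
for _ in range(500):
    params, val = step(params, x_grid, flux, flux_err)
```

In practice an optimizer library would replace the hand-written update; only the loss callable needs to be swapped in.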
- build_loss_function(func, weighted=True, penalty_function=None, penalty_weight=0.01, param_converter=None, curvature_weight=1000.0, smoothness_weight=100000.0, max_weight=0.1)
Optimizations relative to the original implementation:
- d2y_true is precomputed once at build time, since it is a constant that never changes.
- The log-cosh residual array is computed once and reused by both the jnp.nanmean and jnp.max reductions, instead of being computed twice.
- Four duplicated closure branches are collapsed into one.
- log_cosh uses the numerically stable form |x| + log1p(exp(-2|x|)) - log(2), implemented as jnp.abs(x) + jnp.log1p(jnp.exp(-2*jnp.abs(x))) - log(2), which avoids the logaddexp overhead for large |x|.
- The returned loss is wrapped in @jax.jit, so the optimizer reuses the compiled function instead of retracing it on every call.
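The log-cosh rewrite and the single jitted closure can be sketched as follows. This is not sheap's code: make_loss and the three log_cosh_* helpers are hypothetical names, the model_fn(x, params) signature is an assumption, and the curvature and smoothness terms are left out to keep the focus on the optimizations.

```python
import jax
import jax.numpy as jnp

LOG2 = jnp.log(2.0)

def log_cosh_naive(x):
    # Direct form: cosh(x) overflows float32 for |x| above roughly 89
    return jnp.log(jnp.cosh(x))

def log_cosh_logaddexp(x):
    # log(cosh x) = logaddexp(x, -x) - log(2)
    return jnp.logaddexp(x, -x) - LOG2

def log_cosh_stable(x):
    # |x| + log1p(exp(-2|x|)) - log(2); the exp argument is never positive, so it cannot overflow
    ax = jnp.abs(x)
    return ax + jnp.log1p(jnp.exp(-2.0 * ax)) - LOG2

def make_loss(model_fn, weighted=True, penalty_fn=None, penalty_weight=0.01, max_weight=0.1):
    # Hypothetical single-closure factory: `weighted` and `penalty_fn` are resolved at
    # build time, so one closure covers all four weighted/penalty combinations.
    use_penalty = penalty_fn is not None

    def loss(params, x, y, yerr):
        r = y - model_fn(x, params)
        if weighted:  # Python-level branch, fixed when the closure is built
            r = r / yerr
        lc = log_cosh_stable(r)  # computed once...
        total = jnp.nanmean(lc) + max_weight * jnp.max(lc)  # ...and reused by both reductions
        if use_penalty:
            total = total + penalty_weight * penalty_fn(params)
        return total

    # Compiled on first call; reused while argument shapes and dtypes stay the same
    return jax.jit(loss)
```

Resolving the options as Python values outside the traced function keeps one code path, and jit then specializes it once per configuration.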
- Parameters:
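func (Callable) – the spectral model to fit
weighted (bool) – whether the residuals are weighted
penalty_function (Callable | None) – optional penalty on the parameters, e.g. to enforce physics or priors
penalty_weight (float) – scale of the penalty term
param_converter (Callable | None) – optional conversion from raw to physical parameters
curvature_weight (float) – scale of the curvature-matching term
smoothness_weight (float) – scale of the residual-smoothness term
max_weight (float) – scale of the maximum log-cosh residual term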
- Return type:
Callable