Source code for sheap.Minimizer.loss_builder

r"""
Loss Function Builder
=====================

This module defines the construction of flexible loss functions used in *sheap*
for spectral fitting and optimization.

Contents
--------
- **build_loss_function**: Factory for JAX-compatible scalar loss functions
    combining residuals, penalties, and regularization.

Loss Components
---------------
The constructed loss may include the following terms:

1. **Data fidelity (log-cosh residuals)**

.. math::
    \mathcal{L}_\text{data} =
    \langle \log\cosh(r) \rangle + \alpha \, \max(\log\cosh(r)),
    \quad r = \frac{y_\text{pred} - y}{\sigma}

2. **Optional penalty on parameters**

.. math::
    \mathcal{L}_\text{penalty} =
    \beta \, \text{penalty\_function}(x, \theta)

3. **Curvature matching**

.. math::
    \mathcal{L}_\text{curvature} =
        \gamma \, \langle (f''_\text{pred} - f''_\text{true})^2 \rangle

4. **Residual smoothness**

.. math::
    \mathcal{L}_\text{smoothness} =
        \delta \, \langle (\nabla r)^2 \rangle

Notes
-----
- `penalty_function` can enforce additional physics or priors.
- `param_converter` allows transformation from raw to physical parameters.
- All terms are implemented using JAX and are fully differentiable.

Example
-------
.. code-block:: python

   from sheap.Minimizer.loss_builder import build_loss_function

   loss_fn = build_loss_function(model_fn, weighted=True, curvature_weight=1e4)
   loss_val = loss_fn(params, x_grid, flux, flux_err)
"""

__author__ = 'felavila'

__all__ = [
    "build_loss_function",
]

from typing import Callable, Dict, List, Optional, Tuple

import jax
import jax.numpy as jnp
import optax
from jax import jit, vmap,lax

[docs] def build_loss_function( func: Callable, weighted: bool = True, penalty_function: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None, penalty_weight: float = 0.01, param_converter: Optional["Parameters"] = None, curvature_weight: float = 1e3, # γ: second-derivative match 1e5 smoothness_weight: float = 1e5, # δ: first-derivative smoothness 0.0 max_weight: float = 0.1, # α: weight on worst‐pixel term ) -> Callable: r""" Build a flexible JAX-compatible loss function for regression-style modeling tasks. This loss function combines several components: **1. Data term using log-cosh residuals** .. math:: \text{data} = \operatorname{mean}(\log\cosh(r)) + \alpha \cdot \max(\log\cosh(r)), \quad \text{where } r = \frac{y_\text{pred} - y}{y_\text{err}} **2. Optional penalty term on parameters** .. math:: \text{penalty} = \beta \cdot \text{penalty\_function}(x, \theta) **3. Optional curvature matching (second derivative difference)** .. math:: \text{curvature} = \gamma \cdot \operatorname{mean}[(f''_\text{pred} - f''_\text{true})^2] **4. Optional smoothness penalty on the residuals** .. math:: \text{smoothness} = \delta \cdot \operatorname{mean}[(\nabla r)^2] Parameters ---------- func : Callable The prediction function, called as ``func(xs, phys_params)``, returning ``y_pred``. weighted : bool, default=True Whether to apply inverse error weighting to the residuals. penalty_function : Callable, optional A callable penalty term ``penalty(xs, params) → scalar loss``, scaled by ``penalty_weight``. penalty_weight : float, default=0.01 Coefficient for the penalty function term. param_converter : Parameters, optional Object with a ``raw_to_phys`` method to convert raw to physical parameters. curvature_weight : float, default=1e3 Coefficient for the second-derivative matching term. smoothness_weight : float, default=1e5 Coefficient for smoothness of the residuals. max_weight : float, default=0.1 Weight for the maximum log-cosh residual relative to the mean. Returns ------- Callable A loss function with signature ``(params, xs, y, yerr) → scalar``, where ``params`` are raw parameters (optionally converted to physical). """ #print("smoothness_weight =",smoothness_weight,"penalty_weight =",penalty_weight,"max_weight=",max_weight,"curvature_weight=",curvature_weight) def log_cosh(x): # numerically stable log(cosh(x)) return jnp.logaddexp(x, -x) - jnp.log(2.0) def wrapped(xs, raw_params): phys = param_converter.raw_to_phys(raw_params) if param_converter else raw_params return func(xs, phys) def curvature_term(y_pred, y): d2p = jnp.gradient(jnp.gradient(y_pred, axis=-1), axis=-1) d2o = jnp.gradient(jnp.gradient(y, axis=-1), axis=-1) return jnp.nanmean((d2p - d2o)**2) def smoothness_term(y_pred, y): dr = y_pred - y dp = jnp.gradient(dr, axis=-1) return jnp.nanmean(dp**2) if weighted and penalty_function: def loss(params, xs, y, yerr): y_pred = wrapped(xs, params) r = (y_pred - y) / jnp.clip(yerr, 1e-8) # data term = mean + max Lmean = jnp.nanmean(log_cosh(r)) Lmax = jnp.max (log_cosh(r)) data_term = Lmean + max_weight * Lmax # penalty on params reg_term = penalty_weight * penalty_function(xs, params) * 1e3 # curvature & smoothness curv_term = curvature_weight * curvature_term(y_pred, y) smooth_term = smoothness_weight * smoothness_term(y_pred, y) return data_term + reg_term + curv_term + smooth_term return loss elif weighted: def loss(params, xs, y, yerr): y_pred = wrapped(xs, params) r = (y_pred - y) / jnp.clip(yerr, 1e-8) Lmean = jnp.nanmean(log_cosh(r)) Lmax = jnp.max (log_cosh(r)) data_term = Lmean + max_weight * Lmax curv_term = curvature_weight * curvature_term(y_pred, y) smooth_term = smoothness_weight * smoothness_term(y_pred, y) return data_term + curv_term + smooth_term return loss elif penalty_function: def loss(params, xs, y, yerr): y_pred = wrapped(xs, params) r = (y_pred - y) Lmean = jnp.nanmean(log_cosh(r)) Lmax = jnp.max (log_cosh(r)) data_term = Lmean + max_weight * Lmax reg_term = penalty_weight * penalty_function(xs, params) * 1e3 curv_term = curvature_weight * curvature_term(y_pred, y) smooth_term = smoothness_weight * smoothness_term(y_pred, y) return data_term + reg_term + curv_term + smooth_term return loss else: def loss(params, xs, y, yerr): y_pred = wrapped(xs, params) r = (y_pred - y) Lmean = jnp.nanmean(log_cosh(r)) Lmax = jnp.max (log_cosh(r)) data_term = Lmean + max_weight * Lmax curv_term = curvature_weight * curvature_term(y_pred, y) smooth_term = smoothness_weight * smoothness_term(y_pred, y) return data_term + curv_term + smooth_term return loss
def _solve_weighted_linear_least_squares( A: jnp.ndarray, y: jnp.ndarray, yerr: jnp.ndarray, lambda_reg: float = 0.0, reg_matrix: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: r""" Solve the weighted linear least-squares problem for amplitudes. Minimizes .. math:: \| W (y - A a) \|^2 + \lambda \| R a \|^2, where :math:`W = \mathrm{diag}(1/\sigma)`. Parameters ---------- A : jnp.ndarray Design matrix, shape (n_pix, n_lin). y : jnp.ndarray Observed spectrum, shape (n_pix,). yerr : jnp.ndarray 1-sigma uncertainties, shape (n_pix,). lambda_reg : float, default=0.0 Regularization strength :math:`\lambda`. reg_matrix : jnp.ndarray, optional Regularization operator :math:`R`; if None, the identity is used. Returns ------- jnp.ndarray Optimal amplitudes ``a_star``, shape (n_lin,). """ w = 1.0 / jnp.clip(yerr, 1e-10) Aw = A * w[:, None] # (n_pix, n_lin) yw = y * w # (n_pix,) ATA = Aw.T @ Aw # (n_lin, n_lin) ATy = Aw.T @ yw # (n_lin,) if lambda_reg > 0.0: n_lin = ATA.shape[0] if reg_matrix is None: R = jnp.eye(n_lin, dtype=A.dtype) else: R = reg_matrix ATA = ATA + lambda_reg * (R.T @ R) a_star = jnp.linalg.solve(ATA, ATy) return a_star def build_varpro_loss_function( func: Callable, weighted: bool = True, penalty_function: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None, penalty_weight: float = 0.01, param_converter: Optional["Parameters"] = None, curvature_weight: float = 0.0, smoothness_weight: float = 0.0, max_weight: float = 0.0, lambda_reg: float = 0.0, reg_matrix: Optional[jnp.ndarray] = None, ) -> Callable: r""" Build a loss function using variable projection for linear amplitudes. This variant assumes that a subset of parameters in ``param_converter`` are linear-in-the-model (e.g., amplitudes, weights). Those are *not* optimized directly; instead, for each evaluation of the non-linear parameters, the optimal amplitudes are solved by weighted linear least squares. The optimizer operates only on the non-linear subset of the raw parameter vector, but the loss function signature remains compatible with :func:`build_loss_function`:: loss(params, xs, y, yerr) where here ``params`` are the *non-linear* raw parameters. Parameters ---------- func : Callable Model function called as ``func(xs, phys_params) -> y_pred``. weighted : bool, default=True Whether to use inverse-variance weighting via ``yerr``. penalty_function : Callable, optional Penalty term evaluated as ``penalty(xs, phys_full)`` and scaled by ``penalty_weight``. penalty_weight : float, default=0.01 Global weight for the penalty term. param_converter : Parameters, optional Parameter container with ``raw_to_phys`` and linear/non-linear indices. curvature_weight : float, default=0.0 Coefficient for curvature-matching term. smoothness_weight : float, default=0.0 Coefficient for smoothness of residuals. max_weight : float, default=0.0 Coefficient for the max-logcosh term relative to the mean. lambda_reg : float, default=0.0 Regularization strength for the linear amplitudes. reg_matrix : jnp.ndarray, optional Regularization operator for amplitudes (defaults to identity if None and ``lambda_reg > 0``). Returns ------- Callable A loss function with signature ``loss(raw_nl, xs, y, yerr) -> scalar``, where ``raw_nl`` are the non-linear raw parameters. """ if param_converter is None: raise ValueError("build_varpro_loss_function requires a param_converter.") linear_phys_idx = param_converter.linear_phys_indices nonlinear_raw_idx = param_converter.nonlinear_raw_indices n_free = len(param_converter._raw_list) def log_cosh(x): return jnp.logaddexp(x, -x) - jnp.log(2.0) def curvature_term(y_pred, y): d2p = jnp.gradient(jnp.gradient(y_pred, axis=-1), axis=-1) d2o = jnp.gradient(jnp.gradient(y, axis=-1), axis=-1) return jnp.nanmean((d2p - d2o) ** 2) def smoothness_term(y_pred, y): dr = y_pred - y dp = jnp.gradient(dr, axis=-1) return jnp.nanmean(dp ** 2) def build_design_matrix_and_base( xs: jnp.ndarray, raw_nl: jnp.ndarray, ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """ Given non-linear raw params, build: - A: design matrix (n_pix, n_lin), - y_base: baseline spectrum with all linear amplitudes = 0, - phys_full: full physical parameter vector corresponding to raw_nl (with linear entries set to 0 in phys space). """ # 1) Build a full raw vector with only non-linear entries filled. full_raw = jnp.zeros((n_free,), dtype=raw_nl.dtype) full_raw = full_raw.at[nonlinear_raw_idx].set(raw_nl) # 2) Convert to physical parameters phys_full = param_converter.raw_to_phys(full_raw) # (n_total,) # 3) Zero out physical linear amplitudes to define the baseline phys_base = phys_full.at[linear_phys_idx].set(0.0) # 4) Baseline spectrum with all linear amps = 0 y_base = func(xs, phys_base) # (n_pix,) n_lin = linear_phys_idx.shape[0] j_idx = jnp.arange(n_lin, dtype=int) # 5) For each linear parameter j, set phys_base[lin_phys_idx[j]] = 1 # and compute its basis contribution. def basis_col(j): phys_j = phys_base.at[linear_phys_idx[j]].set(1.0) y_j = func(xs, phys_j) return y_j - y_base # contribution per unit amplitude cols = jax.vmap(basis_col)(j_idx) # (n_lin, n_pix) A = cols.T # (n_pix, n_lin) return A, y_base, phys_base def loss(raw_nl: jnp.ndarray, xs: jnp.ndarray, y: jnp.ndarray, yerr: jnp.ndarray) -> jnp.ndarray: # 1) Build design matrix and baseline A, y_base, phys_base = build_design_matrix_and_base(xs, raw_nl) # 2) Solve for optimal amplitudes if weighted: a_star = _solve_weighted_linear_least_squares( A, y, yerr, lambda_reg=lambda_reg, reg_matrix=reg_matrix ) y_pred = y_base + A @ a_star r = (y_pred - y) / jnp.clip(yerr, 1e-8) else: a_star = _solve_weighted_linear_least_squares( A, y, jnp.ones_like(yerr), lambda_reg=lambda_reg, reg_matrix=reg_matrix ) y_pred = y_base + A @ a_star r = (y_pred - y) # 3) Data term = mean + max log-cosh Lmean = jnp.nanmean(log_cosh(r)) Lmax = jnp.max(log_cosh(r)) data_term = Lmean + max_weight * Lmax # 4) Optional penalty on the physical parameters (using phys_base, i.e. # with linear amps set to 0; you can adjust if you prefer otherwise) penalty_term = 0.0 if penalty_function is not None and penalty_weight != 0.0: penalty_term = penalty_weight * penalty_function(xs, phys_base) # 5) Optional curvature & smoothness terms curv_term = curvature_weight * curvature_term(y_pred, y) smooth_term = smoothness_weight * smoothness_term(y_pred, y) return data_term + penalty_term + curv_term + smooth_term return loss