r"""
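Loss Function Builder
=====================

This module defines the construction of flexible loss functions used in *sheap*
for spectral fitting and optimization.

Contents
--------
- **build_loss_function**: Factory for JAX-compatible scalar loss functions
  combining residuals, penalties, and regularization.
- **build_varpro_loss_function**: Variable-projection variant that solves the
  linear (amplitude) parameters in closed form at every call (test version; not
  exported in ``__all__``).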

Loss Components
---------------
The constructed loss may include the following terms:

1. **Data fidelity (log-cosh residuals)**

   .. math::

      \mathcal{L}_\text{data} =
      \langle \log\cosh(r) \rangle + \alpha \, \max(\log\cosh(r)),
      \quad r = \frac{y_\text{pred} - y}{\sigma}

   where :math:`\alpha` is ``max_weight``; with ``weighted=False`` the
   residual is unscaled, :math:`r = y_\text{pred} - y`.
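
2. **Optional penalty on parameters**

   .. math::

      \mathcal{L}_\text{penalty} =
      \beta \, \text{penalty\_function}(x, \theta)

   where :math:`\beta` is ``penalty_weight``. ``build_loss_function`` scales
   this term by an additional factor :math:`10^3` and passes the raw
   (unconverted) parameters to ``penalty_function``.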
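
3. **Curvature matching**

   .. math::

      \mathcal{L}_\text{curvature} =
      \gamma \, \langle (f''_\text{pred} - f''_\text{true})^2 \rangle

   where second derivatives are finite differences (``jnp.gradient`` applied
   twice along the last axis) and :math:`\gamma` is ``curvature_weight``.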
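
4. **Residual smoothness**

   .. math::

      \mathcal{L}_\text{smoothness} =
      \delta \, \langle (\nabla (y_\text{pred} - y))^2 \rangle

   computed on the *unweighted* residual, with :math:`\delta` equal to
   ``smoothness_weight``.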

Notes
-----
- ``penalty_function`` can enforce additional physics or priors.
- ``param_converter`` allows transformation from raw to physical parameters.
- All terms are implemented using JAX and are fully differentiable.

Example
-------
.. code-block:: python

   from sheap.Minimizer.loss_builder import build_loss_function

   loss_fn = build_loss_function(model_fn, weighted=True, curvature_weight=1e4)
   loss_val = loss_fn(params, x_grid, flux, flux_err)
"""
__author__ = "felavila"

# Names exported by ``from <module> import *``. build_varpro_loss_function is
# not listed, so it is only reachable through an explicit import.
__all__ = ["build_loss_function"]
from typing import Callable, Dict, List, Optional, Tuple
import jax
import jax.numpy as jnp
import optax
from jax import jit, vmap,lax
# loss.py — optimized
from typing import Optional, Callable
import jax
import jax.numpy as jnp
def build_loss_function(
    func: Callable,
    weighted: bool = True,
    penalty_function: Optional[Callable] = None,
    penalty_weight: float = 0.01,
    param_converter=None,
    curvature_weight: float = 1e3,
    smoothness_weight: float = 1e5,
    max_weight: float = 0.1,
) -> Callable:
    """
    Build a jitted scalar loss ``loss(params, xs, y, yerr)`` for spectral fitting.

    The returned loss is the sum of

    * a log-cosh data term ``nanmean(lc) + max_weight * nanmax(lc)`` with
      ``lc = log(cosh(r))`` and ``r = (y_pred - y) / clip(yerr, 1e-8)``
      (``r = y_pred - y`` when ``weighted`` is False);
    * ``curvature_weight * nanmean((y_pred'' - y'')**2)`` (skipped when 0);
    * ``smoothness_weight * nanmean(((y_pred - y)')**2)`` on the unweighted
      residual (skipped when 0);
    * ``penalty_weight * penalty_function(xs, params) * 1e3`` (skipped when
      ``penalty_function`` is None).

    Derivatives are finite differences (``jnp.gradient``) along the last axis.
    NaN entries are ignored by every reduction (``nanmean`` / ``nanmax``).

    Parameters
    ----------
    func : Callable
        Model ``func(xs, phys_params) -> y_pred``.
    weighted : bool, default=True
        Divide residuals by ``yerr`` (floored at 1e-8) before the log-cosh.
    penalty_function : Callable, optional
        ``penalty_function(xs, params) -> scalar``. It receives the *raw*
        parameters given to the loss, not the converted ones.
    penalty_weight : float, default=0.01
        Weight of the penalty term (further multiplied by 1e3).
    param_converter : object, optional
        If given, ``param_converter.raw_to_phys(params)`` is applied before
        calling ``func``.
    curvature_weight : float, default=1e3
        Weight of the curvature-matching term.
    smoothness_weight : float, default=1e5
        Weight of the residual-smoothness term.
    max_weight : float, default=0.1
        Weight of the worst-pixel (max log-cosh) contribution.

    Returns
    -------
    Callable
        ``jax.jit``-compiled ``loss(params, xs, y, yerr) -> scalar``.
    """
    log2 = jnp.log(jnp.array(2.0))

    def log_cosh(x):
        # Stable identity log(cosh(x)) = |x| + log1p(exp(-2|x|)) - log(2);
        # cosh itself overflows for large |x|.
        ax = jnp.abs(x)
        return ax + jnp.log1p(jnp.exp(-2.0 * ax)) - log2

    def second_diff(arr):
        # Finite-difference second derivative along the last axis.
        return jnp.gradient(jnp.gradient(arr, axis=-1), axis=-1)

    def model(xs, raw_params):
        phys = raw_params if param_converter is None else param_converter.raw_to_phys(raw_params)
        return func(xs, phys)

    def _loss_body(params, xs, y, yerr):
        y_pred = model(xs, params)
        diff = y_pred - y
        r = diff / jnp.clip(yerr, 1e-8) if weighted else diff
        lc = log_cosh(r)
        # nanmax (not max): a NaN pixel must be ignored here exactly as nanmean
        # ignores it, otherwise one NaN makes the loss NaN even for max_weight=0.
        data_term = jnp.nanmean(lc) + max_weight * jnp.nanmax(lc)

        curv_term = 0.0
        if curvature_weight != 0.0:
            # The observed curvature is recomputed inside the traced function on
            # purpose: `y` is a tracer under jit, and caching it in Python leaks
            # tracers across retraces. XLA evaluates it as part of the graph.
            curv_term = curvature_weight * jnp.nanmean(
                (second_diff(y_pred) - second_diff(y)) ** 2
            )

        smooth_term = 0.0
        if smoothness_weight != 0.0:
            smooth_term = smoothness_weight * jnp.nanmean(jnp.gradient(diff, axis=-1) ** 2)

        penalty_term = 0.0
        if penalty_function is not None:
            # NOTE(review): the extra 1e3 factor is specific to this builder;
            # build_varpro_loss_function does not apply it.
            penalty_term = penalty_weight * penalty_function(xs, params) * 1e3

        return data_term + curv_term + smooth_term + penalty_term

    return jax.jit(_loss_body)
def _solve_weighted_linear_least_squares(
    A: jnp.ndarray,
    y: jnp.ndarray,
    yerr: jnp.ndarray,
    lambda_reg: float = 0.0,
    reg_matrix: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
    r"""
    Solve the weighted linear least-squares problem for amplitudes.

    Minimizes

    .. math::

        \| W (y - A a) \|^2 + \lambda \| R a \|^2,

    where :math:`W = \mathrm{diag}(1/\sigma)`, by solving the normal equations
    :math:`(A_w^T A_w + \lambda R^T R)\, a = A_w^T y_w` with
    :math:`A_w = W A` and :math:`y_w = W y`.

    Pixels whose ``y`` or ``yerr`` is not finite (NaN/inf) receive zero weight,
    so a single bad pixel does not turn every amplitude into NaN.

    Parameters
    ----------
    A : jnp.ndarray
        Design matrix, shape (n_pix, n_lin).
    y : jnp.ndarray
        Observed spectrum, shape (n_pix,).
    yerr : jnp.ndarray
        1-sigma uncertainties, shape (n_pix,). Floored at 1e-10.
    lambda_reg : float, default=0.0
        Regularization strength :math:`\lambda`; ignored unless > 0.
    reg_matrix : jnp.ndarray, optional
        Regularization operator :math:`R`; if None, the identity is used.

    Returns
    -------
    jnp.ndarray
        Optimal amplitudes ``a_star``, shape (n_lin,). Empty when ``n_lin == 0``.
    """
    if A.shape[1] == 0:
        return jnp.zeros((0,), dtype=A.dtype)

    valid = jnp.isfinite(y) & jnp.isfinite(yerr)
    # Double-where: substitute harmless values *before* the arithmetic so that
    # neither the forward pass nor any gradient sees NaN/inf at masked pixels.
    safe_err = jnp.where(valid, yerr, 1.0)
    safe_y = jnp.where(valid, y, 0.0)
    w = jnp.where(valid, 1.0 / jnp.clip(safe_err, 1e-10), 0.0)

    Aw = A * w[:, None]  # (n_pix, n_lin)
    yw = safe_y * w  # (n_pix,)
    ATA = Aw.T @ Aw  # (n_lin, n_lin)
    ATy = Aw.T @ yw  # (n_lin,)

    if lambda_reg > 0.0:
        R = jnp.eye(ATA.shape[0], dtype=A.dtype) if reg_matrix is None else reg_matrix
        ATA = ATA + lambda_reg * (R.T @ R)

    return jnp.linalg.solve(ATA, ATy)
def build_varpro_loss_function(
    func: Callable,
    weighted: bool = True,
    penalty_function: Optional[Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]] = None,
    penalty_weight: float = 0.01,
    param_converter: Optional["Parameters"] = None,
    curvature_weight: float = 0.0,
    smoothness_weight: float = 0.0,
    max_weight: float = 0.0,
    lambda_reg: float = 0.0,
    reg_matrix: Optional[jnp.ndarray] = None,
) -> Callable:
    """
    Build a variable-projection (VarPro) loss ``loss(raw_full, xs, y, yerr)``.

    Raw parameters flagged ``linear_param`` are not used directly. On every call:

    1. the model is evaluated with the linear raw values zeroed and the linear
       physical parameters set to 0 (``y_base``);
    2. one design-matrix column per linear physical parameter is built as
       ``func(phys with that parameter = 1) - y_base``;
    3. the optimal amplitudes are solved in closed form with
       ``_solve_weighted_linear_least_squares``;
    4. the loss is evaluated on ``y_base + A @ a_star``.

    This assumes the model is linear in the flagged parameters.

    Parameters
    ----------
    func : Callable
        Model ``func(xs, phys_params) -> y_pred``; must be ``jax.vmap``-able
        over ``phys_params`` (the design columns are built with vmap).
    weighted : bool, default=True
        Weight both the linear solve and the residuals by ``1 / yerr``.
    penalty_function : Callable, optional
        ``penalty_function(xs, phys_full) -> scalar``, evaluated on the
        *physical* parameters with the solved amplitudes inserted.
    penalty_weight : float, default=0.01
        Weight of the penalty term.
    param_converter : Parameters
        Required. Supplies ``_raw_list``, ``raw_to_phys`` and
        ``linear_phys_indices()``.
    curvature_weight, smoothness_weight : float, default=0.0
        Weights of the curvature / residual-smoothness terms (only applied
        when > 0).
    max_weight : float, default=0.0
        Weight of the worst-pixel (max log-cosh) contribution.
    lambda_reg : float, default=0.0
        Tikhonov strength for the linear solve.
    reg_matrix : jnp.ndarray, optional
        Regularization operator for the linear solve (identity if None).

    Returns
    -------
    Callable
        ``loss(raw_full, xs, y, yerr) -> scalar``. It is not jitted here; wrap
        it with ``jax.jit`` at the call site if desired.

    Raises
    ------
    ValueError
        If ``param_converter`` is None or has tied or shared parameters. The
        returned loss raises if ``raw_full`` is not ``(1, n_free)`` after
        promoting a 1-D vector to 2-D.
    """
    if param_converter is None:
        raise ValueError("build_varpro_loss_function requires a param_converter.")
    param_converter._ensure_finalized()
    if len(param_converter._tied_list) > 0:
        raise ValueError("This test varpro version does not support tied parameters.")
    if len(param_converter._raw_shared_list) > 0:
        raise ValueError("This test varpro version does not support shared parameters.")

    raw_list = param_converter._raw_list
    # Single-object only for now (a future version may use param_converter.n_obj).
    n_obj = 1
    dtype_default = jnp.float32

    nonlinear_raw_idx = jnp.array(
        [i for i, p in enumerate(raw_list) if not p.linear_param],
        dtype=jnp.int32,
    )
    linear_phys_idx = jnp.array(
        param_converter.linear_phys_indices(),
        dtype=jnp.int32,
    )
    n_lin = int(linear_phys_idx.shape[0])  # static: fixed at build time

    def log_cosh(x: jnp.ndarray) -> jnp.ndarray:
        # log(cosh(x)) = log(e^x + e^-x) - log 2, overflow-safe via logaddexp.
        return jnp.logaddexp(x, -x) - jnp.log(2.0)

    def second_diff(arr: jnp.ndarray) -> jnp.ndarray:
        # Finite-difference second derivative along the last axis.
        return jnp.gradient(jnp.gradient(arr, axis=-1), axis=-1)

    def ensure_2d(arr: jnp.ndarray) -> jnp.ndarray:
        arr = jnp.asarray(arr, dtype=dtype_default)
        return arr[None, :] if arr.ndim == 1 else arr

    def solve_linear_one(
        A: jnp.ndarray,
        y_target: jnp.ndarray,
        yerr_i: jnp.ndarray,
    ) -> jnp.ndarray:
        # Closed-form amplitudes for one object; vmapped over the object axis.
        if A.shape[1] == 0:
            return jnp.zeros((0,), dtype=dtype_default)
        err = yerr_i if weighted else jnp.ones_like(y_target, dtype=dtype_default)
        return _solve_weighted_linear_least_squares(
            A,
            y_target,
            err,
            lambda_reg=lambda_reg,
            reg_matrix=reg_matrix,
        )

    def build_design_matrix_and_base(
        xs: jnp.ndarray,
        raw_full: jnp.ndarray,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        # Returns (A_all (n_obj, n_pix, n_lin), y_base (n_obj, n_pix), phys_base).
        raw_full = ensure_2d(raw_full)
        if raw_full.shape[0] != n_obj:
            raise ValueError(
                f"raw_full first dimension must be n_obj={n_obj}, got {raw_full.shape[0]}."
            )
        if raw_full.shape[1] != len(raw_list):
            raise ValueError(
                f"raw_full second dimension must be n_free={len(raw_list)}, got {raw_full.shape[1]}."
            )

        # Keep only nonlinear raw values; zero the linear ones.
        raw_used = jnp.zeros_like(raw_full)
        raw_used = raw_used.at[:, nonlinear_raw_idx].set(raw_full[:, nonlinear_raw_idx])
        phys_full = jnp.asarray(param_converter.raw_to_phys(raw_used), dtype=dtype_default)
        if phys_full.ndim != 2:
            raise ValueError("raw_to_phys(raw_used) must return shape (n_obj, n_params).")

        # Base model: every linear (amplitude) parameter switched off.
        phys_base = phys_full.at[:, linear_phys_idx].set(0.0)
        y_base = ensure_2d(func(xs, phys_base))

        if n_lin == 0:
            A_all = jnp.zeros((y_base.shape[0], y_base.shape[1], 0), dtype=dtype_default)
            return A_all, y_base, phys_base

        def basis_col(j: int) -> jnp.ndarray:
            # Response to a unit amplitude of the j-th linear parameter.
            phys_j = phys_base.at[:, linear_phys_idx[j]].set(1.0)
            return ensure_2d(func(xs, phys_j)) - y_base

        cols = jax.vmap(basis_col)(jnp.arange(n_lin))  # (n_lin, n_obj, n_pix)
        return jnp.moveaxis(cols, 0, -1), y_base, phys_base

    def loss(
        raw_full: jnp.ndarray,
        xs: jnp.ndarray,
        y: jnp.ndarray,
        yerr: jnp.ndarray,
    ) -> jnp.ndarray:
        y = ensure_2d(y)
        yerr = ensure_2d(yerr)
        A_all, y_base, phys_base = build_design_matrix_and_base(xs, raw_full)

        a_star = jax.vmap(solve_linear_one)(A_all, y - y_base, yerr)
        y_pred = y_base + jnp.einsum("opl,ol->op", A_all, a_star)
        phys_full = phys_base.at[:, linear_phys_idx].set(a_star)

        diff = y_pred - y
        r = diff / jnp.clip(yerr, 1e-8) if weighted else diff
        lc = log_cosh(r)
        # nanmax keeps NaN pixels ignored consistently with nanmean.
        total = jnp.nanmean(lc) + max_weight * jnp.nanmax(lc)
        if penalty_function is not None:
            total = total + penalty_weight * penalty_function(xs, phys_full)
        if curvature_weight > 0.0:
            total = total + curvature_weight * jnp.nanmean(
                (second_diff(y_pred) - second_diff(y)) ** 2
            )
        if smoothness_weight > 0.0:
            total = total + smoothness_weight * jnp.nanmean(jnp.gradient(diff, axis=-1) ** 2)
        return total

    return loss