Source code for sheap.Minimizer.LB_nonlinear

from typing import Optional, Tuple, Callable

import jax
import jax.numpy as jnp


dtype = jnp.float32

# varpro.py  — optimized
from functools import partial
from typing import Optional, Tuple, Callable

import jax
import jax.numpy as jnp

# Floating-point dtype that the functions below pass to jnp.asarray / jnp.zeros /
# jnp.full when they build data, design-matrix and parameter arrays
# (index arrays use jnp.int32 explicitly instead).
dtype = jnp.float32


def _solve_wlls(
    A: jnp.ndarray,
    y: jnp.ndarray,
    yerr: jnp.ndarray,
    lambda_reg: float = 0.0,
    reg_matrix: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
    """Solve a weighted linear least-squares problem for its coefficients.

    Every row of ``A`` and every entry of ``y`` is scaled by ``1 / yerr``
    (with ``yerr`` floored at 1e-10) and the scaled system is passed to
    ``jnp.linalg.lstsq``.  When ``lambda_reg > 0`` the rows
    ``sqrt(lambda_reg) * R`` with zero targets are stacked below the data
    rows, where ``R`` is ``reg_matrix`` or, if that is ``None``, the identity.

    Args:
        A: Design matrix; the size of its second axis is the number of
            coefficients.
        y: Target values, flattened to 1-D.
        yerr: Per-sample uncertainties, flattened to 1-D.
        lambda_reg: Regularization strength; values ``<= 0`` disable it.
        reg_matrix: Optional regularization operator.

    Returns:
        1-D coefficient vector in the module ``dtype``.  It is filled with
        NaN when ``A`` has no rows.
    """
    design = jnp.asarray(A, dtype=dtype)
    target = jnp.asarray(y, dtype=dtype).reshape(-1)
    sigma = jnp.asarray(yerr, dtype=dtype).reshape(-1)
    n_coeffs = design.shape[1]

    # Guard clause: without data rows there is nothing to fit.
    if design.shape[0] == 0:
        return jnp.full((n_coeffs,), jnp.nan, dtype=dtype)

    # Whitening: divide both sides by sigma, floored to avoid division by zero.
    inv_sigma = 1.0 / jnp.clip(sigma, 1e-10)
    lhs = design * inv_sigma[:, None]
    rhs = target * inv_sigma

    if lambda_reg > 0.0:
        if reg_matrix is None:
            penalty = jnp.eye(n_coeffs, dtype=dtype)
        else:
            penalty = jnp.asarray(reg_matrix, dtype=dtype)
        # Tikhonov-style augmentation: extra rows pull (penalty @ coeffs) to zero.
        scale = jnp.sqrt(jnp.asarray(lambda_reg, dtype=dtype))
        lhs = jnp.vstack([lhs, scale * penalty])
        rhs = jnp.concatenate([rhs, jnp.zeros(penalty.shape[0], dtype=dtype)])

    coeffs = jnp.linalg.lstsq(lhs, rhs, rcond=None)[0]
    return jnp.asarray(coeffs, dtype=dtype)


def make_build_linear_system_from_phys_profile(fused_profile: Callable):
    """Create a builder for the linear sub-problem of ``fused_profile``.

    The index sets are read once from ``fused_profile.linear_param_indices``
    and ``fused_profile.nonlinear_param_indices`` and captured by the returned
    closure.

    Args:
        fused_profile: Callable invoked as ``fused_profile(x, params)`` that
            also exposes the two index attributes named above.

    Returns:
        A function ``build_linear_system(x, params_phys)`` returning
        ``(A, y_base, linear_phys_idx, nonlinear_phys_idx)``.

    Note:
        ``jax.vmap`` by itself neither compiles nor caches anything; the
        column function is re-traced on every call unless the caller wraps
        the builder (or a function calling it) in ``jax.jit``.
    """
    linear_phys_idx = jnp.array(fused_profile.linear_param_indices, dtype=jnp.int32)
    nonlinear_phys_idx = jnp.array(fused_profile.nonlinear_param_indices, dtype=jnp.int32)
    n_linear = int(linear_phys_idx.shape[0])
    # Column positions 0..n_linear-1, built once and shared by every call.
    col_idx = jnp.arange(n_linear, dtype=jnp.int32)

    def build_linear_system(
        x: jnp.ndarray,
        params_phys: jnp.ndarray,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Evaluate the baseline model and the design matrix at ``params_phys``.

        Args:
            x: Abscissa forwarded unchanged (after dtype conversion) to
                ``fused_profile``.
            params_phys: Full parameter vector, flattened to 1-D.  The values
                stored at the linear indices are ignored.

        Returns:
            ``A``: matrix of shape ``(len(y_base), n_linear)``; column ``j``
            is the profile with only the ``j``-th linear parameter set to 1,
            minus ``y_base``.
            ``y_base``: flattened profile with all linear parameters set to 0.
            ``linear_phys_idx`` / ``nonlinear_phys_idx``: the captured index
            arrays.
        """
        x = jnp.asarray(x, dtype=dtype)
        params_phys = jnp.asarray(params_phys, dtype=dtype).reshape(-1)

        # Baseline: switch off every linear parameter.
        params_base = params_phys.at[linear_phys_idx].set(0.0)
        y_base = jnp.asarray(fused_profile(x, params_base), dtype=dtype).reshape(-1)

        if n_linear == 0:
            # No linear parameters: return an empty design matrix.
            A = jnp.zeros((y_base.shape[0], 0), dtype=dtype)
            return A, y_base, linear_phys_idx, nonlinear_phys_idx

        def one_column(j):
            # Response to a unit value of the j-th linear parameter alone.
            params_j = params_base.at[linear_phys_idx[j]].set(1.0)
            return jnp.asarray(fused_profile(x, params_j), dtype=dtype).reshape(-1) - y_base

        # vmap stacks the columns along axis 0, so transpose to (n_data, n_linear).
        A = jax.vmap(one_column)(col_idx).T
        return A, y_base, linear_phys_idx, nonlinear_phys_idx

    return build_linear_system
def build_varpro_loss_from_profile_and_params_obj(
    fused_profile: Callable,
    params_obj,
    nonlinear_raw_idx: jnp.ndarray,
    linear_phys_idx: Optional[jnp.ndarray] = None,
    weighted: bool = True,
    lambda_reg: float = 0.0,
    reg_matrix: Optional[jnp.ndarray] = None,
):
    """Build a variable-projection loss and the matching full-parameter solver.

    For a given vector of raw nonlinear parameters the linear parameters are
    obtained by ``_solve_wlls`` on the design matrix produced by
    ``make_build_linear_system_from_phys_profile``; the loss is the mean
    squared (optionally error-normalised) residual of the resulting model.

    Args:
        fused_profile: Callable ``fused_profile(x, params)`` exposing
            ``n_params`` and ``linear_param_indices`` (plus whatever the
            linear-system builder reads).
        params_obj: Object providing ``raw_to_phys(raw_full)`` and, optionally,
            ``phys_to_raw(phys_full)``.
        nonlinear_raw_idx: Positions in the full raw vector that receive the
            entries of ``raw_nonlinear``.
        linear_phys_idx: Positions in the physical vector that receive the
            solved linear coefficients; defaults to
            ``fused_profile.linear_param_indices``.
        weighted: If true, ``yerr`` weights both the linear solve and the
            residual; otherwise unit errors are used and the residual is not
            normalised.
        lambda_reg: Regularization strength forwarded to ``_solve_wlls``.
        reg_matrix: Regularization operator forwarded to ``_solve_wlls``.

    Returns:
        ``(loss, solve_full_parameters)`` where ``loss(raw_nonlinear, x, y,
        yerr)`` is wrapped in ``jax.jit`` and returns a scalar, and
        ``solve_full_parameters(raw_nonlinear, x, y, yerr)`` returns
        ``(phys_full_best, raw_full_best, a_star, y_model, y_base, A)``.
    """
    nonlinear_raw_idx = jnp.asarray(nonlinear_raw_idx, dtype=jnp.int32)
    if linear_phys_idx is None:
        linear_phys_idx = jnp.asarray(fused_profile.linear_param_indices, dtype=jnp.int32)
    else:
        linear_phys_idx = jnp.asarray(linear_phys_idx, dtype=jnp.int32)

    n_params = int(fused_profile.n_params)
    build_linear_system = make_build_linear_system_from_phys_profile(fused_profile)

    def unpack_raw_nonlinear(raw_nonlinear):
        # Scatter the nonlinear entries into a zero-filled full raw vector.
        raw_nonlinear = jnp.asarray(raw_nonlinear, dtype=dtype).reshape(-1)
        raw_full = jnp.zeros((n_params,), dtype=dtype)
        return raw_full.at[nonlinear_raw_idx].set(raw_nonlinear)

    def build_from_raw_nonlinear(raw_nonlinear, x):
        # raw -> physical parameters, then baseline model and design matrix.
        raw_full = unpack_raw_nonlinear(raw_nonlinear)
        phys_full = jnp.asarray(params_obj.raw_to_phys(raw_full), dtype=dtype).reshape(-1)
        A, y_base, _, _ = build_linear_system(x, phys_full)
        return raw_full, phys_full, A, y_base

    def solve_full_parameters(raw_nonlinear, x, y, yerr):
        """Solve the linear coefficients for ``raw_nonlinear`` and assemble the model."""
        x = jnp.asarray(x, dtype=dtype)
        y = jnp.asarray(y, dtype=dtype).reshape(-1)
        yerr = jnp.asarray(yerr, dtype=dtype).reshape(-1)

        raw_full, phys_full_nonlin, A, y_base = build_from_raw_nonlinear(raw_nonlinear, x)
        y_target = y - y_base

        # Unit errors are created fresh on every call.  Memoising them in a
        # closure-level dict would store a tracer whenever this function first
        # runs under jax.jit (via `loss`), and reusing that tracer in a later
        # eager call or re-trace raises UnexpectedTracerError.
        _yerr = yerr if weighted else jnp.ones((y.shape[0],), dtype=dtype)

        a_star = _solve_wlls(A, y_target, _yerr, lambda_reg=lambda_reg, reg_matrix=reg_matrix)
        y_model = y_base + A @ a_star
        phys_full_best = phys_full_nonlin.at[linear_phys_idx].set(a_star)

        if hasattr(params_obj, "phys_to_raw"):
            raw_full_best = jnp.asarray(
                params_obj.phys_to_raw(phys_full_best), dtype=dtype
            ).reshape(-1)
        else:
            # No inverse transform available: return the raw vector that only
            # carries the nonlinear entries (linear slots stay zero).
            raw_full_best = raw_full

        return phys_full_best, raw_full_best, a_star, y_model, y_base, A

    @jax.jit
    def loss(raw_nonlinear, x, y, yerr):
        """Mean squared residual of the best-fit model for ``raw_nonlinear``."""
        x = jnp.asarray(x, dtype=dtype)
        y = jnp.asarray(y, dtype=dtype).reshape(-1)
        yerr = jnp.asarray(yerr, dtype=dtype).reshape(-1)
        _, _, _, y_model, _, _ = solve_full_parameters(raw_nonlinear, x, y, yerr)
        if weighted:
            # NOTE(review): the floor here is 1e-8 whereas _solve_wlls floors
            # yerr at 1e-10 — confirm whether the two are meant to differ.
            resid = (y_model - y) / jnp.clip(yerr, 1e-8)
        else:
            resid = y_model - y
        return jnp.mean(resid ** 2)

    return loss, solve_full_parameters