# Source code for sheap.Profiles.Utils

"""
Profile Utilities
=================

This module provides utility functions to support the definition,
integration, and composition of spectral line and continuum profiles
in *sheap*.

Functions
---------
- ``make_fused_profiles(funcs)`` :
    Combine multiple profile functions into a single callable that
    evaluates the sum of all components.

- ``with_param_names(param_names)`` :
    Decorator to attach parameter names and count metadata to profile
    functions for consistent handling across the codebase.

- ``make_integrator(profile_fn, method="broadcast")`` :
    Factory for JAX-based integrators of profile functions. Supports
    either broadcasting across batches or nested `vmap` evaluation.

- ``build_grid_penalty(weights_idx, n_Z, n_age)`` :
    Construct a Laplacian smoothness penalty on 2D grids of host
    template weights (Z × age), useful for regularization.

- ``trapz_jax(y, x)`` :
    Lightweight trapezoidal integration implemented with JAX.

Notes
-----
- All utilities are JAX-compatible and designed for use in differentiable
  fitting pipelines.
- ``with_param_names`` ensures that each profile function exposes
  ``.param_names`` and ``.n_params`` attributes for downstream
  bookkeeping.
- ``make_integrator`` can be used to integrate profiles per spectrum,
  per component, or across batches without manual looping.
"""

# Module author handle.
__author__ = 'felavila'


# Public names exported by a star-import of this module.
__all__ = [
    "build_grid_penalty",
    "make_fused_profiles",
    "make_integrator",
    "trapz_jax",
    "with_param_names",
    "GaussianSum",
]

from typing import Any, Callable, Dict, List, Optional, Tuple, Union,Iterable 

import numpy as np 
import jax.numpy as jnp 
from jax.scipy.integrate import trapezoid
from jax import vmap, jit,nn


def with_param_names(
    param_names: Iterable[str],
    linear_param_names: Optional[Iterable[str]] = None,
    linear_param_mask: Optional[Iterable[bool]] = None,
    profile_name = None):
    """
    Build a decorator that attaches parameter metadata to a profile function.

    Parameters
    ----------
    param_names : iterable of str
        Ordered names of the parameters the profile expects.
    linear_param_names : iterable of str, optional
        Names of the parameters that enter the profile linearly. Names that
        do not occur in ``param_names`` have no effect.
    linear_param_mask : iterable of bool, optional
        One flag per entry of ``param_names``; truthy marks a linear
        parameter.
    profile_name : optional
        Stored unchanged on the decorated function as ``profile_name``.

    Returns
    -------
    callable
        Decorator that sets the following attributes on the function it
        receives and returns that same function object:

        - ``param_names``
        - ``n_params``
        - ``linear_param_mask``
        - ``linear_param_indices``
        - ``nonlinear_param_indices``
        - ``linear_param_names``
        - ``nonlinear_param_names``
        - ``profile_name``

    Raises
    ------
    ValueError
        If both ``linear_param_names`` and ``linear_param_mask`` are given,
        or if the mask length differs from the number of parameter names.

    Notes
    -----
    When neither ``linear_param_names`` nor ``linear_param_mask`` is given,
    every parameter is treated as nonlinear.
    """
    names = tuple(param_names)
    count = len(names)

    names_given = linear_param_names is not None
    mask_given = linear_param_mask is not None
    if names_given and mask_given:
        raise ValueError("Use only one of linear_param_names or linear_param_mask.")

    if names_given:
        linear_set = set(linear_param_names)
        flags = tuple(name in linear_set for name in names)
    elif mask_given:
        raw_mask = list(linear_param_mask)
        if len(raw_mask) != count:
            raise ValueError("linear_param_mask must have the same length as param_names.")
        flags = tuple(bool(flag) for flag in raw_mask)
    else:
        flags = (False,) * count

    linear_positions = tuple(pos for pos, is_linear in enumerate(flags) if is_linear)
    nonlinear_positions = tuple(pos for pos, is_linear in enumerate(flags) if not is_linear)

    # Everything is computed once here; the decorator only copies it over.
    metadata = {
        "param_names": names,
        "n_params": count,
        "linear_param_mask": flags,
        "linear_param_indices": linear_positions,
        "nonlinear_param_indices": nonlinear_positions,
        "linear_param_names": tuple(names[pos] for pos in linear_positions),
        "nonlinear_param_names": tuple(names[pos] for pos in nonlinear_positions),
        "profile_name": profile_name,
    }

    def decorator(func):
        for attribute, value in metadata.items():
            setattr(func, attribute, value)
        return func

    return decorator
# def make_fused_profiles(funcs):
#     """
#     Fuse multiple profile functions into a single callable.
#     # TODO: a fused profile loses the information about n_params and the
#     # parameter names (names can be repeated, so they are not trustworthy);
#     # we can at least attach n_params to the combined profile.
#     Parameters
#     ----------
#     funcs : list of callables
#         Each function must have a `.n_params` attribute
#         and a signature `(x, params)`.
#     Returns
#     -------
#     fused_profile : callable
#         A function that evaluates the sum of all profiles given
#         a single concatenated parameter vector.
#     """
#     n_params = [f.n_params for f in funcs]
#     param_splits = np.cumsum([0] + n_params)  # [0, 3, 6, ...]
#     def fused_profile(x, all_args):
#         result = 0.0
#         for i, f in enumerate(funcs):
#             fargs = all_args[param_splits[i]:param_splits[i+1]]
#             result = result + f(x, fargs)
#         return result
#     return fused_profile
# def make_g(list):
#     amplitudes, centers = list.amplitude, list.center
#     return PROFILE_FUNC_MAP["Gsum_model"](centers, amplitudes)
# TODO: add here the function that reconstructs sum_gaussian_amplitude_free.
def make_fused_profiles(funcs):
    """
    Fuse several profile functions into one callable that returns their sum.

    The fused callable takes a single concatenated parameter vector and
    hands each sub-profile its own contiguous slice. Parameter names and the
    linear-parameter mask of every sub-profile are concatenated in order and
    attached to the fused callable through ``with_param_names``.

    Parameters
    ----------
    funcs : iterable of callables
        Profile functions with signature ``(x, params)``. Each must expose
        ``n_params``. ``param_names`` and ``linear_param_mask`` are used
        when present; otherwise names default to ``p{i}_{k}`` and all
        parameters are treated as nonlinear.

    Returns
    -------
    fused_profile : callable
        Function ``(x, all_args)`` returning the sum of all sub-profiles.
        In addition to the ``with_param_names`` metadata it carries:

        - ``param_splits`` : tuple of int, slice boundaries into ``all_args``
        - ``sub_funcs`` : tuple of the fused profile functions

    Raises
    ------
    ValueError
        If a function's ``param_names`` or ``linear_param_mask`` length
        differs from its ``n_params``.
    """
    # Snapshot once: ``funcs`` is iterated several times (here and on every
    # call of the fused profile), so a generator would be exhausted and a
    # list mutated later would go out of sync with ``param_splits``.
    funcs = tuple(funcs)

    n_params = [f.n_params for f in funcs]
    # Plain Python ints keep the slice bounds static under ``jit``.
    param_splits = tuple(int(edge) for edge in np.cumsum([0] + n_params))

    fused_param_names = []
    fused_linear_mask = []
    for i, f in enumerate(funcs):
        names = list(getattr(f, "param_names", [f"p{i}_{k}" for k in range(f.n_params)]))
        linear_mask = list(getattr(f, "linear_param_mask", [False] * f.n_params))

        if len(names) != f.n_params:
            raise ValueError(f"Function {i} has inconsistent param_names length.")
        if len(linear_mask) != f.n_params:
            raise ValueError(f"Function {i} has inconsistent linear_param_mask length.")

        fused_param_names.extend(names)
        fused_linear_mask.extend(linear_mask)

    @with_param_names(
        fused_param_names,
        linear_param_mask=fused_linear_mask,
    )
    def fused_profile(x, all_args):
        all_args = jnp.asarray(all_args)
        # Start from a zero of x's dtype (float32 when x has no dtype).
        result = jnp.array(0.0, dtype=x.dtype if hasattr(x, "dtype") else jnp.float32)
        for f, start, stop in zip(funcs, param_splits[:-1], param_splits[1:]):
            result = result + f(x, all_args[start:stop])
        return result

    fused_profile.param_splits = param_splits
    fused_profile.sub_funcs = funcs
    return fused_profile
def make_integrator(profile_fn, method="broadcast"):
    """
    Build a jitted trapezoidal integrator for a profile function.

    Intended for a 1D wavelength grid and a 3D parameter array laid out as
    ``(n_sample, n_lines, n_params)``.

    Parameters
    ----------
    profile_fn : callable
        Profile function with signature ``(x, params) -> y``.
    method : {"broadcast", "vmap"}, optional
        How the batch dimensions are handled:

        - ``"broadcast"``: ``x`` is reshaped to ``(n_pixels, 1, 1)`` and
          ``profile_fn`` is called once with the full parameter array; the
          result is integrated along axis 0.
        - ``"vmap"``: ``profile_fn`` is called on one parameter vector at a
          time and vectorised with two nested ``vmap`` calls over the two
          leading axes of ``params``.

    Returns
    -------
    integrate : callable
        Jitted function ``(x, params) -> integral``.

    Raises
    ------
    ValueError
        If ``method`` is neither ``"broadcast"`` nor ``"vmap"``.
    """
    if method == "broadcast":

        def integrate(x, params):
            grid = jnp.asarray(x)
            theta = jnp.asarray(params)
            # Two trailing singleton axes let the grid broadcast against the
            # two leading axes of the parameter array.
            flux = profile_fn(grid[:, None, None], theta)
            return trapezoid(flux, grid, axis=0)

        return jit(integrate)

    if method == "vmap":

        def integrate_one(x, p):
            # One parameter vector in, one integral out.
            return trapezoid(profile_fn(x, p), x)

        # Inner vmap runs over lines, outer vmap over spectra; the grid is
        # shared by both.
        over_lines = vmap(integrate_one, in_axes=(None, 0))
        over_spectra = vmap(over_lines, in_axes=(None, 0))

        def integrate(x, params):
            return over_spectra(x, params)

        return jit(integrate)

    raise ValueError(f"unknown method {method!r}")
def build_grid_penalty(
    weights_idx,
    n_Z: int,
    n_age: int,
) -> Callable[[jnp.ndarray], jnp.ndarray]:
    """
    Construct a second-difference smoothness penalty over a 2D template grid.

    The selected weights are reshaped to ``(n_Z, n_age)`` and the penalty is
    the sum of squared discrete second differences along both grid axes.

    Parameters
    ----------
    weights_idx : sequence of int
        Indices of the weight parameters in the global parameter vector,
        ordered so that they reshape to ``(n_Z, n_age)``.
    n_Z : int
        Number of metallicity bins.
    n_age : int
        Number of age bins.

    Returns
    -------
    penalty : callable
        Function ``(params) -> scalar array`` computing the penalty. An axis
        with fewer than three bins contributes zero, because it has no
        second differences.

    Raises
    ------
    ValueError
        If ``len(weights_idx)`` differs from ``n_Z * n_age``.
    """
    if len(weights_idx) != n_Z * n_age:
        raise ValueError(f"Expected {n_Z * n_age} weight indices, got {len(weights_idx)}")

    # Build the index array once: this avoids rebuilding it on every call
    # and freezes the validated indices, so later mutation of the caller's
    # list cannot change which parameters are penalised.
    index_array = jnp.array(weights_idx)

    def penalty(params: jnp.ndarray) -> jnp.ndarray:
        weights_grid = params[index_array].reshape(n_Z, n_age)
        # Second differences along the age axis (columns) ...
        d2_age = weights_grid[:, :-2] - 2 * weights_grid[:, 1:-1] + weights_grid[:, 2:]
        # ... and along the metallicity axis (rows).
        d2_Z = weights_grid[:-2, :] - 2 * weights_grid[1:-1, :] + weights_grid[2:, :]
        return jnp.sum(d2_age**2) + jnp.sum(d2_Z**2)

    return penalty
def trapz_jax(y: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    """
    Integrate ``y`` over ``x`` with the trapezoidal rule.

    Parameters
    ----------
    y : jnp.ndarray
        Function values sampled on ``x``.
    x : jnp.ndarray
        Grid points; spacing may be non-uniform.

    Returns
    -------
    jnp.ndarray
        Sum over all intervals of ``(y[i+1] + y[i]) * (x[i+1] - x[i]) / 2``.
    """
    interval_widths = x[1:] - x[:-1]
    endpoint_sums = y[1:] + y[:-1]
    return jnp.sum(endpoint_sums * interval_widths / 2)
class GaussianSum:
    """
    Sum of ``n`` Gaussians with optional equality and inequality constraints.

    The free-parameter vector is laid out as::

        [free amplitudes | free centres | free sigmas | one delta per inequality]

    Attributes
    ----------
    n : int
        Number of Gaussian components.
    param_mapping : dict[str, list[int]]
        For each of ``'amp'``, ``'mu'`` and ``'sigma'``, the index of the
        free-parameter slot used by every component. Slots are contiguous
        (``0 .. n_free - 1``) and numbered by first appearance.
    num_free_params, n_params : int
        Length of the free-parameter vector.
    sum_gaussians_jit : callable
        Jitted function ``(x, params) -> sum of Gaussians``.
    """

    # Order in which parameter types are packed into the free vector.
    _PARAM_TYPES = ('amp', 'mu', 'sigma')

    def __init__(self, n, constraints=None, inequalities=None):
        """
        Initialize the GaussianSum with parameter constraints.

        Parameters:
        - n (int): Number of Gaussian functions.
        - constraints (dict): Optional equality constraints on parameters.
          Example:
            {
                'amp': [('amp0', 'amp1')],  # amp0 == amp1
                'mu': [('mu2', 'mu3')],
                'sigma': [('sigma1', 'sigma2')]
            }
        - inequalities (dict): Optional inequality constraints on parameters.
          Example:
            {
                'sigma': [('sigma1', 'sigma2')]  # sigma2 > sigma1
            }
          NOTE: only the 'sigma' entries are applied when evaluating; pairs
          listed under any other key still add one (unused) free parameter.
        """
        self.n = n
        self.constraints = constraints or {}
        self.inequalities = inequalities or {}
        self.param_mapping = self._build_param_mapping()
        self.num_free_params = self._count_free_params()
        self.n_params = self.num_free_params
        self.sum_gaussians_jit = self._build_gaussian_sum()

    def _build_param_mapping(self):
        """
        Map every component parameter to its free-parameter slot.

        Equality constraints are merged into groups with a union-find, so
        chains such as ``amp1 == amp2`` followed by ``amp0 == amp1`` tie all
        three together regardless of the order they are listed in. Group
        representatives are then renumbered to contiguous slots so that each
        slot index is valid for the sliced free-parameter array.

        Raises:
        - KeyError: a constraint pair is given for an unknown parameter type.
        - ValueError: a parameter name has no integer suffix.
        - IndexError: a parameter index lies outside ``0 .. n - 1``.
        """
        parents = {ptype: list(range(self.n)) for ptype in self._PARAM_TYPES}

        def find(parent, i):
            # Root lookup with path halving.
            while parent[i] != i:
                parent[i] = parent[parent[i]]
                i = parent[i]
            return i

        for param_type, pairs in self.constraints.items():
            for (p1, p2) in pairs:
                parent = parents[param_type]
                idx1 = int(p1.replace(param_type, ''))
                idx2 = int(p2.replace(param_type, ''))
                for idx in (idx1, idx2):
                    if not 0 <= idx < self.n:
                        raise IndexError(
                            f"{param_type}{idx} is out of range for {self.n} Gaussians."
                        )
                root1, root2 = find(parent, idx1), find(parent, idx2)
                if root1 != root2:
                    # Keep the smaller index as representative so slot order
                    # follows component order.
                    parent[max(root1, root2)] = min(root1, root2)

        mapping = {}
        for ptype, parent in parents.items():
            slot_of_root = {}
            slots = []
            for i in range(self.n):
                root = find(parent, i)
                if root not in slot_of_root:
                    slot_of_root[root] = len(slot_of_root)
                slots.append(slot_of_root[root])
            mapping[ptype] = slots
        return mapping

    def _count_free_params(self):
        """
        Count the number of free parameters after applying constraints.
        """
        n_free = sum(len(set(self.param_mapping[ptype])) for ptype in self._PARAM_TYPES)
        return n_free + self._count_inequality_free_params()

    def _count_inequality_free_params(self):
        """
        Count additional free parameters required for inequality constraints.
        One extra free parameter (the offset) is counted per listed pair,
        across all keys of ``self.inequalities``.
        """
        return sum(len(pairs) for pairs in self.inequalities.values())

    def _apply_constraints(self, params):
        """
        Expand the free-parameter vector into full per-component arrays.

        Parameters:
        - params (jnp.ndarray): Free parameters vector.

        Returns:
        - amps, mus, sigmas (tuple of jnp.ndarray): Full parameter sets,
          each of length ``n``.
        """
        expanded = []
        start = 0
        for ptype in self._PARAM_TYPES:
            slots = self.param_mapping[ptype]
            n_free = len(set(slots))
            free_values = jnp.asarray(params[start:start + n_free])
            start += n_free
            # Gather: component k reads free slot slots[k].
            expanded.append(free_values[jnp.asarray(slots, dtype=jnp.int32)])
        amps, mus, sigmas = expanded
        return amps, mus, sigmas

    def _apply_inequality_constraints(self, sigmas, params):
        """
        Apply inequality constraints to sigmas.
        For each pair ``(s1, s2)`` enforce ``s2 > s1`` by setting
        ``s2 = s1 + softplus(delta)``; deltas are read from ``params`` in
        the order the pairs are listed.

        Parameters:
        - sigmas (jnp.ndarray): Current sigma parameters.
        - params (jnp.ndarray): Delta parameters, one per sigma inequality.

        Returns:
        - jnp.ndarray: Transformed sigma parameters satisfying inequalities.
        """
        if not self.inequalities:
            return sigmas

        # Only inequality constraints on 'sigma' are supported.
        for position, (s1, s2) in enumerate(self.inequalities.get('sigma', [])):
            idx1 = int(s1.replace('sigma', ''))
            idx2 = int(s2.replace('sigma', ''))
            shifted = sigmas[idx1] + nn.softplus(params[position])
            sigmas = sigmas.at[idx2].set(shifted)

        return sigmas

    def _build_gaussian_sum(self):
        """
        Build the JIT-compiled Gaussian sum function.

        Returns:
        - sum_gaussians_jit (function): JIT-compiled function.
        """
        def gaussian(x, amp, mu, sigma):
            return amp * jnp.exp(-0.5 * ((x - mu) / sigma) ** 2)

        def sum_gaussians(x, params):
            # Shapes are static under jit, so this check runs at trace time.
            if params.shape[0] != self.num_free_params:
                raise ValueError(f"Expected {self.num_free_params} parameters, got {params.shape[0]}.")

            amps, mus, sigmas = self._apply_constraints(params)

            # The sigma deltas occupy the tail of the parameter vector.
            n_sigma_ineq = len(self.inequalities.get('sigma', []))
            if n_sigma_ineq:
                sigmas = self._apply_inequality_constraints(sigmas, params[-n_sigma_ineq:])

            # Vectorise over components while keeping x fixed.
            gaussians = vmap(lambda amp, mu, sigma: gaussian(x, amp, mu, sigma))(amps, mus, sigmas)
            return jnp.sum(gaussians, axis=0)

        return jit(sum_gaussians)

    def __call__(self, x, params):
        """
        Compute the sum of Gaussians at points x with given parameters.

        Parameters:
        - x (jnp.ndarray): Points at which to evaluate the sum.
        - params (jnp.ndarray): Free parameters vector.

        Returns:
        - jnp.ndarray: Sum of Gaussians evaluated at x.
        """
        return self.sum_gaussians_jit(x, params)