sheap.Profiles.Utils module

Profile Utilities

This module provides utility functions to support the definition, integration, and composition of spectral line and continuum profiles in sheap.

Functions

  • make_fused_profiles(funcs) :

    Combine multiple profile functions into a single callable that evaluates the sum of all components.

  • with_param_names(param_names) :

    Decorator to attach parameter names and count metadata to profile functions for consistent handling across the codebase.

  • make_integrator(profile_fn, method="broadcast") :

    Factory for JAX-based integrators of profile functions. Supports either broadcasting across batches or nested vmap evaluation.

  • build_grid_penalty(weights_idx, n_Z, n_age) :

    Construct a Laplacian smoothness penalty on 2D grids of host template weights (Z × age), useful for regularization.

  • trapz_jax(y, x) :

    Lightweight trapezoidal integration implemented with JAX.

Notes

  • All utilities are JAX-compatible and designed for use in differentiable fitting pipelines.

  • with_param_names ensures that each profile function exposes .param_names and .n_params attributes for downstream bookkeeping.

  • make_integrator can be used to integrate profiles per spectrum, per component, or across batches without manual looping.

build_grid_penalty(weights_idx, n_Z, n_age)[source]

Construct a Laplacian smoothness penalty over a 2D template grid.

Parameters:
  • weights_idx (list[int]) – Indices of weight parameters in the global parameter vector.

  • n_Z (int) – Number of metallicity bins.

  • n_age (int) – Number of age bins.

Returns:

penalty – Function (params) -> float computing the smoothness penalty.

Return type:

callable
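As an illustration of what such a penalty can look like, here is a minimal sketch and not the sheap implementation. It assumes the selected weights are laid out row-major as an (n_Z, n_age) grid and that the penalty is the sum of squared 5-point discrete Laplacians over interior grid points. The name build_grid_penalty_sketch and both of those choices are assumptions made for this example.

```python
import jax
import jax.numpy as jnp


def build_grid_penalty_sketch(weights_idx, n_Z, n_age):
    """Hypothetical stand-in: Laplacian smoothness penalty on a Z x age grid."""
    idx = jnp.asarray(weights_idx)

    def penalty(params):
        # Assumption: weights are stored row-major as (n_Z, n_age).
        W = params[idx].reshape(n_Z, n_age)
        # 5-point discrete Laplacian evaluated on interior grid points only.
        lap = (
            W[:-2, 1:-1] + W[2:, 1:-1]
            + W[1:-1, :-2] + W[1:-1, 2:]
            - 4.0 * W[1:-1, 1:-1]
        )
        return jnp.sum(lap ** 2)

    return penalty


n_Z, n_age = 3, 3
# Two unrelated parameters first, then the 9 grid weights.
weights_idx = list(range(2, 2 + n_Z * n_age))
penalty = build_grid_penalty_sketch(weights_idx, n_Z, n_age)

flat = jnp.ones(11)                    # constant grid: perfectly smooth
spike = jnp.zeros(11).at[6].set(1.0)   # single spike at the grid centre

flat_value = penalty(flat)
spike_value = penalty(spike)
grad_spike = jax.grad(penalty)(spike)  # the penalty is differentiable
```

A constant grid is not penalized, while an isolated spike is, and the gradient pushes the spike back toward its neighbours.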

make_fused_profiles(funcs)[source]

Fuse multiple profile functions into a single callable.

Note (TODO): a fused profile currently loses the n_params and parameter-name metadata of its components. The names can be repeated across components, so they are not trustworthy, but at least the total n_params could be attached to the combined function.

Parameters:

funcs – Each function must have a .n_params attribute and a signature (x, params).

Returns:

fused_profile – A function that evaluates the sum of all profiles given a single concatenated parameter vector.

Return type:

callable
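The fusing step can be sketched as follows. This is an illustrative re-implementation under stated assumptions, not the sheap source: fuse_profiles_sketch, gaussian, and line are hypothetical names, and attaching n_params to the fused function goes beyond what the documentation above promises.

```python
import jax.numpy as jnp


def gaussian(x, params):
    amp, mu, sigma = params[0], params[1], params[2]
    return amp * jnp.exp(-0.5 * ((x - mu) / sigma) ** 2)


def line(x, params):
    slope, intercept = params[0], params[1]
    return slope * x + intercept


# Each component advertises how many parameters it consumes.
gaussian.n_params = 3
line.n_params = 2


def fuse_profiles_sketch(funcs):
    """Hypothetical stand-in: sum of profiles over one concatenated vector."""
    # Start/stop offsets of each component's slice in the parameter vector.
    offsets = [0]
    for f in funcs:
        offsets.append(offsets[-1] + f.n_params)

    def fused_profile(x, params):
        total = jnp.zeros_like(x)
        for f, start, stop in zip(funcs, offsets[:-1], offsets[1:]):
            total = total + f(x, params[start:stop])
        return total

    # Assumption: keep the total count on the fused callable.
    fused_profile.n_params = offsets[-1]
    return fused_profile


fused = fuse_profiles_sketch([gaussian, line])
x = jnp.array([0.0, 1.0, 2.0])
params = jnp.array([1.0, 0.0, 1.0, 2.0, 1.0])  # gaussian params, then line params
y = fused(x, params)
```

The slice boundaries are computed once from n_params, so the fused callable only needs static slicing at evaluation time.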

make_integrator(profile_fn, method='broadcast')[source]

Create an integrator for profile functions. It works with a 1D wavelength grid and a 3D parameter array of shape (n_sample, n_lines, n_params).

Parameters:
  • profile_fn (callable) – Profile function with signature (x, params) -> y.

  • method ({"broadcast", "vmap"}, optional) – Integration strategy:

      - "broadcast": expand x for broadcasting across batches
      - "vmap": use nested vectorization

Returns:

integrate – Function (x, params) -> integral returning integrated flux.

Return type:

callable
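The two strategies can be compared in a small sketch. It is not the sheap implementation: make_integrator_sketch, _trapezoid, and gaussian are hypothetical names, and the profile is assumed to read its parameters from the last axis of params so that it works in both modes.

```python
import jax
import jax.numpy as jnp


def _trapezoid(y, x):
    # Trapezoidal rule along the last axis.
    dx = x[..., 1:] - x[..., :-1]
    return jnp.sum(0.5 * (y[..., 1:] + y[..., :-1]) * dx, axis=-1)


def gaussian(x, params):
    # Assumption: parameters live on the last axis of `params`.
    amp, mu, sigma = params[..., 0], params[..., 1], params[..., 2]
    return amp * jnp.exp(-0.5 * ((x - mu) / sigma) ** 2)


def make_integrator_sketch(profile_fn, method="broadcast"):
    """Hypothetical stand-in for an integrator factory."""
    if method == "broadcast":
        def integrate(x, params):
            # (n_sample, n_lines, 1, n_params) broadcasts against x of shape (n_x,).
            y = profile_fn(x, params[..., None, :])
            return _trapezoid(y, x)
    elif method == "vmap":
        def integrate(x, params):
            per_line = lambda p: _trapezoid(profile_fn(x, p), x)
            # Outer vmap over samples, inner vmap over lines.
            return jax.vmap(jax.vmap(per_line))(params)
    else:
        raise ValueError(f"unknown method: {method!r}")
    return integrate


x = jnp.linspace(-10.0, 10.0, 2001)
amps = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
# params has shape (n_sample=2, n_lines=3, n_params=3): amp, mu=0, sigma=1.
params = jnp.stack([amps, jnp.zeros_like(amps), jnp.ones_like(amps)], axis=-1)

flux_broadcast = make_integrator_sketch(gaussian, "broadcast")(x, params)
flux_vmap = make_integrator_sketch(gaussian, "vmap")(x, params)
```

Both calls return an array of shape (n_sample, n_lines). For a unit-width Gaussian each entry should be close to amp × sqrt(2π).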

trapz_jax(y, x)[source]

JAX-compatible trapezoidal integration.

Parameters:
  • y (jnp.ndarray) – Function values.

  • x (jnp.ndarray) – Grid over which integration is performed.

Returns:

Approximate integral of y over x.

Return type:

jnp.ndarray
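For reference, a trapezoidal rule can be written in a few lines of JAX. The sketch below is an assumption about the general technique rather than the body of trapz_jax. In particular, integrating along the last axis is a choice made for this example, and trapz_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def trapz_sketch(y, x):
    """Trapezoidal rule: sum of 0.5 * (y[i] + y[i+1]) * (x[i+1] - x[i])."""
    dx = x[..., 1:] - x[..., :-1]
    return jnp.sum(0.5 * (y[..., 1:] + y[..., :-1]) * dx, axis=-1)


x = jnp.linspace(0.0, 1.0, 101)
y = 2.0 * x               # the exact integral over [0, 1] is 1
area = trapz_sketch(y, x)
```

Because the integrand here is linear, the trapezoidal rule is exact up to floating-point rounding.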

with_param_names(param_names)[source]

Decorator to attach parameter names and count to a profile function.

Parameters:

param_names (list of str) – Names of the parameters for the decorated profile function.

Returns:

decorator – A decorator that attaches .param_names and .n_params attributes to the target function.

Return type:

callable
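The decorator pattern described here can be sketched in plain Python. This is a hypothetical re-implementation for illustration only: with_param_names_sketch and the gaussian profile are names chosen for this example, and the real decorator may do more than set two attributes.

```python
import math


def with_param_names_sketch(param_names):
    """Hypothetical stand-in: attach .param_names and .n_params to a function."""
    def decorator(fn):
        fn.param_names = list(param_names)
        fn.n_params = len(param_names)
        return fn
    return decorator


@with_param_names_sketch(["amplitude", "center", "sigma"])
def gaussian(x, params):
    amplitude, center, sigma = params
    return [amplitude * math.exp(-0.5 * ((xi - center) / sigma) ** 2) for xi in x]


names = gaussian.param_names
count = gaussian.n_params
peak = gaussian([0.0], [2.0, 0.0, 1.0])[0]
```

Downstream code can then read gaussian.n_params to slice a shared parameter vector without inspecting the function body.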