sheap.SheapModelFitting.SheapModelFitting module

Complex Fitting

This module defines SheapModelFitting, the main driver for fitting multi-component spectral regions with JAX-based minimization.

Main Features

  • Builds parameter initialization, constraints, and profile functions.

  • Performs iterative optimization using custom JAX minimizers.

  • Supports tied parameters, penalties, and continuum fitting.

  • Computes uncertainties via covariance matrices or samplers.

  • Packages results into SheapResult.
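The covariance route to uncertainties can be illustrated with a short JAX sketch. It assumes a residual function that returns (flux - model) / error for each pixel, and it uses the standard Gauss-Newton approximation cov ≈ (JᵀJ)⁻¹. The name covariance_uncertainties is hypothetical. sheap's own routine, and its sampler-based alternative, may work differently.

```python
import jax
import jax.numpy as jnp


def covariance_uncertainties(residual_fn, params):
    """Approximate 1-sigma parameter errors from the Jacobian of weighted residuals.

    residual_fn(params) is assumed to return (flux - model) / error per pixel,
    so J^T J is the Gauss-Newton approximation of the Hessian of chi^2 / 2.
    """
    jac = jax.jacfwd(residual_fn)(params)   # (n_pixels, n_params)
    fisher = jac.T @ jac                     # Gauss-Newton information matrix
    cov = jnp.linalg.inv(fisher)             # parameter covariance matrix
    return jnp.sqrt(jnp.diag(cov)), cov


# Usage: a straight line y = a * x + b with unit errors.
x = jnp.array([0.0, 1.0, 2.0, 3.0])
y = 2.0 * x + 1.0
errors, cov = covariance_uncertainties(lambda p: y - (p[0] * x + p[1]), jnp.array([2.0, 1.0]))
```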

Notes

  • All fitting runs through JAX, so it is GPU-accelerated when a GPU backend is available.

  • Residual loss is log-cosh with optional smoothness/curvature penalties.

  • Continuum slopes/intercepts can be initialized via weighted least squares.
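A log-cosh residual loss with an optional curvature penalty might look like the sketch below. The log-cosh part follows the note above. The penalty form, which uses squared second differences of the model, is an illustrative assumption and not sheap's documented penalty. The function names are hypothetical.

```python
import jax.numpy as jnp


def log_cosh(x):
    # Numerically stable log(cosh(x)) = |x| + log1p(exp(-2|x|)) - log(2)
    ax = jnp.abs(x)
    return ax + jnp.log1p(jnp.exp(-2.0 * ax)) - jnp.log(2.0)


def residual_loss(flux, model, error, curvature_weight=0.0):
    """Mean log-cosh of error-weighted residuals plus an optional curvature penalty.

    curvature_weight must be a plain Python number (it is tested with `if`).
    """
    loss = jnp.mean(log_cosh((flux - model) / error))
    if curvature_weight:
        # Second finite difference of the model penalizes sharp kinks
        curvature = model[2:] - 2.0 * model[1:-1] + model[:-2]
        loss = loss + curvature_weight * jnp.mean(curvature ** 2)
    return loss
```

log-cosh behaves like r²/2 for small residuals and like |r| - log 2 for large ones. As a result, a few outlying pixels pull on the fit less than they would under a squared loss.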

class SheapModelFitting(region_dict, *, profile='gaussian', limits_overrides={})

Bases: object

Fits a spectral region containing multiple emission lines.

This class wraps the workflow of:
  • Building parameterized line + continuum models

  • JAX‑based fitting (with optional penalty functions)

  • Uncertainty estimation via covariance or sampling

  • Post‑processing (renormalization, χ² calculation)

  • Packaging results into a SheapResult object

Parameters:
  • region_dict (dict) – Dictionary of attributes produced by a ComplexBuilder, including:
      - sheapmodel (with a .lines list)
      - fitting_routine (dict of fit steps and ties)
      - any other metadata needed for fitting

  • profile (str, optional) – Default line profile to use for unlabeled components (e.g. ‘gaussian’, ‘lorentzian’, ‘SPAF’), by default “gaussian”

  • limits_overrides (dict[str, FittingLimits], optional) – Overrides for the default parameter‑limit lookup, by species or region.

profile

Profile name used for unconstrained components.

Type:

str

limits_map

Per‑species limits (lower/upper) for fit parameters.

Type:

dict[str, FittingLimits]

params_dict

Mapping from parameter names to their index in the packed parameter vector.

Type:

dict[str, int]

initial_params

Initial guesses for all fit parameters.

Type:

jnp.ndarray, shape (n_params,)
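The relationship between params_dict and initial_params can be sketched as flattening per-component guesses into one vector while recording each name's index. In the sketch below, the component dictionaries, the "<param>_<line>" naming scheme, and pack_params are all hypothetical. In sheap, the real inputs come from sheapmodel.lines.

```python
import jax.numpy as jnp

# Hypothetical component specs (names, keys, and values are illustrative only)
components = [
    {"name": "Halpha", "init": {"amplitude": 1.0, "center": 6563.0, "fwhm": 5.0}},
    {"name": "NII", "init": {"amplitude": 0.3, "center": 6583.0, "fwhm": 5.0}},
]


def pack_params(components):
    """Return a name -> index mapping and the flat initial parameter vector."""
    params_dict, values = {}, []
    for comp in components:
        for pname, value in comp["init"].items():
            params_dict[f"{pname}_{comp['name']}"] = len(values)
            values.append(value)
    return params_dict, jnp.asarray(values)


params_dict, initial_params = pack_params(components)
```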

constraints

Lower and upper bounds for each parameter.

Type:

jnp.ndarray, shape (n_params, 2)

profile_functions

JAX‑compiled model functions for each line/continuum component.

Type:

list[callable]

model

Fused, jit‑compiled model combining all profile_functions.

Type:

callable
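One way to fuse profile functions into a single jit-compiled model is to give each profile its own contiguous slice of the packed parameter vector and sum the outputs. The sketch below assumes a (x, params_slice) calling convention, a FWHM-parameterized Gaussian, and a linear continuum. These choices and the helper fuse are hypothetical and are not sheap's actual profile signatures.

```python
import jax
import jax.numpy as jnp


def gaussian(x, p):
    # p = (amplitude, center, fwhm); FWHM converted to a standard deviation
    sigma = p[2] / (2.0 * jnp.sqrt(2.0 * jnp.log(2.0)))
    return p[0] * jnp.exp(-0.5 * ((x - p[1]) / sigma) ** 2)


def linear(x, p):
    # p = (slope, intercept)
    return p[0] * x + p[1]


def fuse(profile_functions, n_params):
    """Combine profiles into one jitted model that slices the packed vector."""
    starts = [sum(n_params[:i]) for i in range(len(n_params))]

    @jax.jit
    def model(x, params):
        total = jnp.zeros_like(x)
        # Static Python loop: unrolled once at trace time
        for fn, start, n in zip(profile_functions, starts, n_params):
            total = total + fn(x, params[start:start + n])
        return total

    return model


model = fuse([gaussian, gaussian, linear], [3, 3, 2])
```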

host_info

Extra metadata (e.g. stellar‑population grid) for penalty construction.

Type:

dict

sheapresult

Final results (parameters, uncertainties, residuals, χ², etc.), set after fitting.

Type:

SheapResult

__call__(spectra, force_cut=False, run_uncertainty_params=True, inner_limits=None, outer_limits=None, learning_rate=None, add_penalty_function=False)

Execute the full fit on one or more spectra. Raises an error if the limits or the fitting routine are mis-specified.

_fit(iteration_number, norm_spec, model, initial_params, tied, learning_rate=1e-1, weighted=True, num_steps=1000, non_optimize_in_axis=3, penalty_function=None)

Perform the JAX‑based minimization using Minimizer. Returns optimized parameters and final loss history.
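As a stand-in for the custom Minimizer, the sketch below uses plain projected gradient descent. It takes a gradient step and then clips every parameter back into its [lower, upper] bounds, recording the loss at each step. This is a swapped-in technique, not sheap's Minimizer, which also handles ties and penalties. fit_projected_gd is a hypothetical name, while learning_rate and num_steps mirror the defaults listed above.

```python
import jax
import jax.numpy as jnp


def fit_projected_gd(loss_fn, init_params, constraints, learning_rate=1e-1, num_steps=1000):
    """Gradient descent that clips parameters to [lower, upper] after every step.

    constraints has shape (n_params, 2), matching the class attribute.
    Returns the final parameters and the per-step loss history.
    """
    lower, upper = constraints[:, 0], constraints[:, 1]
    value_and_grad = jax.value_and_grad(loss_fn)

    def step(params, _):
        loss, grads = value_and_grad(params)
        params = jnp.clip(params - learning_rate * grads, lower, upper)
        return params, loss

    return jax.lax.scan(step, init_params, None, length=num_steps)


def quadratic_loss(p):
    # Minimum at p = 3
    return jnp.sum((p - 3.0) ** 2)


params, losses = fit_projected_gd(quadratic_loss, jnp.array([0.0]), jnp.array([[0.0, 2.0]]))
```

Here the unconstrained minimum (3.0) lies outside the bounds, so the fit stops at the upper bound 2.0.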

_prep_data(spectra, inner_limits, outer_limits, force_cut)

Preprocess spectra: mask bad pixels, cut to the fitting region, and normalize the flux of each spectrum by its maximum.
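A minimal preprocessing sketch follows. It assumes the (n, 3, m) layout holds (wavelength, flux, error) along axis 1 and that all spectra share one wavelength grid. It only applies outer_limits, because the exact roles of inner_limits and force_cut are not described here. prep_data, big_error, and the masking rule (non-finite values or non-positive errors) are hypothetical.

```python
import jax.numpy as jnp


def prep_data(spectra, outer_limits, big_error=1e10):
    """Cut, mask, and normalize a stack of spectra.

    spectra: (n, 3, m); axis 1 is assumed to hold (wavelength, flux, error),
    and all spectra are assumed to share the first spectrum's wavelength grid.
    Returns the normalized (n, 3, m_cut) array and the per-spectrum scale.
    """
    wave = spectra[0, 0]
    keep = jnp.nonzero((wave >= outer_limits[0]) & (wave <= outer_limits[1]))[0]
    cut = spectra[:, :, keep]                                 # (n, 3, m_cut)

    flux, error = cut[:, 1], cut[:, 2]
    bad = ~jnp.isfinite(flux) | ~jnp.isfinite(error) | (error <= 0)
    flux = jnp.where(bad, 0.0, flux)                          # neutralize bad flux
    scale = jnp.max(jnp.where(bad, -jnp.inf, flux), axis=1)   # max good flux per spectrum
    norm_flux = flux / scale[:, None]
    norm_error = jnp.where(bad, big_error, error / scale[:, None])  # huge error = ~zero weight
    return jnp.stack([cut[:, 0], norm_flux, norm_error], axis=1), scale
```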

_postprocess(norm_spec, params, uncertainty_params, scale)

Scale fitted parameters back to original flux units and compute residuals, χ², and package intermediate arrays.
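The postprocessing step can be sketched as follows. Flux-like parameters are multiplied back by the per-spectrum scale, residuals are computed in original units, and the reduced χ² uses m - n_params degrees of freedom. The is_flux_param flag array and the function name are hypothetical, because which parameters sheap rescales is not listed here.

```python
import jax.numpy as jnp


def postprocess(norm_spec, norm_model, params, uncertainty_params, scale, is_flux_param):
    """Undo the per-spectrum normalization and compute residuals and reduced chi^2.

    norm_spec: (n, 3, m) normalized spectra; norm_model: (n, m) model in the same units.
    params, uncertainty_params: (n, n_params); scale: (n,).
    is_flux_param: (n_params,) bool marking parameters that scale with flux.
    """
    factor = jnp.where(is_flux_param, scale[:, None], 1.0)      # (n, n_params)
    flux = norm_spec[:, 1] * scale[:, None]
    error = norm_spec[:, 2] * scale[:, None]
    residuals = flux - norm_model * scale[:, None]
    dof = flux.shape[1] - params.shape[1]
    chi2_red = jnp.sum((residuals / error) ** 2, axis=1) / dof
    return params * factor, uncertainty_params * factor, residuals, chi2_red
```

Flux, model, and error are all multiplied by the same factor, so χ² is unchanged by the renormalization. Only the parameters and residuals change units.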

_build_fit_components(profile="gaussian", **kwargs)

Build parameter initialization lists, constraints, and profile functions from sheapmodel.lines.

_build_tied(tied_params)

Convert user‑specified tie lists into dependency strings for the minimizer.
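The effect of a tie can be illustrated by overwriting each tied parameter from its source after every update. One example is fixing one line's amplitude to a multiple of another's. The (target, source, op, value) tuple format below is hypothetical. The actual dependency strings produced by _build_tied, and the way the Minimizer consumes them, are not documented here.

```python
import jax.numpy as jnp


def apply_ties(params, ties):
    """Overwrite tied parameters from their source parameters.

    ties: list of (target_idx, source_idx, op, value) with op "*" or "+".
    Ties are applied in order, so a later tie may depend on an earlier one.
    """
    for target, source, op, value in ties:
        if op == "*":
            tied = params[source] * value
        elif op == "+":
            tied = params[source] + value
        else:
            raise ValueError(f"unsupported tie operator: {op!r}")
        params = params.at[target].set(tied)
    return params
```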

_stack_constraints(low, high)

Stack the lower and upper bound lists into an (n_params, 2) JAX array.
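A minimal version of the stacking step is shown below. Column 0 holds lower bounds and column 1 holds upper bounds, matching the constraints attribute. The name stack_constraints and the validation checks are assumptions added for illustration.

```python
import jax.numpy as jnp


def stack_constraints(low, high):
    """Stack per-parameter bounds into an (n_params, 2) array of [lower, upper]."""
    low = jnp.asarray(low, dtype=jnp.float32)
    high = jnp.asarray(high, dtype=jnp.float32)
    if low.shape != high.shape:
        raise ValueError("low and high must have the same length")
    if bool(jnp.any(low > high)):
        raise ValueError("every lower bound must be <= its upper bound")
    return jnp.stack([low, high], axis=1)
```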

_add_linear(idx)

Add a linear continuum component if none was found in the region.

to_result()

Assemble a SheapResult object from the final attributes.

Return type:

SheapResult

from_builder(builder, *, profile='gaussian', limits_overrides=None, **builder_kwargs)

Alternate constructor: build region_dict via ComplexBuilder and return a new instance.

Parameters:
  • builder (ComplexBuilder)

  • profile (str)

Return type:

SheapModelFitting

init_linear(norm_spec, params)

Compute and insert weighted least‑squares continuum slopes/intercepts.

Notes

  • Uses JAX and a custom Minimizer for gradient‑based optimization.

  • Tied parameters are handled via the _build_tied helper.

  • Continuum can be injected via a linear fit or special continuum profiles.

Examples

>>> builder = ComplexBuilder(xmin=6500, xmax=6600, lines=['Halpha', 'NII'])
>>> rf = SheapModelFitting.from_builder(builder, profile='SPAF')
>>> rf(spectra_array, inner_limits=(6520, 6580), outer_limits=(6500, 6600))
>>> df = rf.pandas_params()
>>> print(df.head())

to_result()

Assemble and store the SheapResult object.

Return type:

None

classmethod from_builder(builder, *, profile='gaussian', limits_overrides=None, **builder_kwargs)

Construct SheapModelFitting directly from a ComplexBuilder.

Parameters:
  • builder (ComplexBuilder)

  • profile (str, optional) – Default profile name for unconstrained lines.

  • limits_overrides (dict, optional) – Per‑species parameter limits overrides.

  • **builder_kwargs – Passed to builder._make_fitting_routine(…)

Return type:

SheapModelFitting

init_linear(norm_spec, params)

Fit and insert a linear continuum via weighted least squares.

Parameters:
  • norm_spec (jnp.ndarray) – Array of normalized spectra, shape (n, 3, m).

  • params (jnp.ndarray) – Uninitialized parameter array to be updated.

Returns:

Updated parameters with continuum slopes/intercepts.

Return type:

jnp.ndarray
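A closed-form weighted least-squares line fit with weights 1/error² is sketched below. It is vectorized over spectra with jax.vmap and written into the parameter array. The sketch assumes axis 1 of norm_spec holds (wavelength, flux, error) and that params has shape (n, n_params). Both helper names and the slope_idx/intercept_idx arguments are hypothetical stand-ins for however sheap locates the continuum parameters.

```python
import jax
import jax.numpy as jnp


def weighted_linear_continuum(wave, flux, error):
    """Closed-form weighted least squares for flux ~ slope * wave + intercept.

    Weights are 1 / error**2; centring on the weighted mean wavelength keeps
    the float32 arithmetic well conditioned at optical wavelengths.
    """
    w = 1.0 / error ** 2
    x_mean = jnp.sum(w * wave) / jnp.sum(w)
    y_mean = jnp.sum(w * flux) / jnp.sum(w)
    dx = wave - x_mean
    slope = jnp.sum(w * dx * (flux - y_mean)) / jnp.sum(w * dx ** 2)
    return slope, y_mean - slope * x_mean


def insert_linear_continuum(norm_spec, params, slope_idx, intercept_idx):
    """Fit every spectrum in norm_spec (n, 3, m) and write the results into params (n, n_params)."""
    slopes, intercepts = jax.vmap(weighted_linear_continuum)(
        norm_spec[:, 0], norm_spec[:, 1], norm_spec[:, 2]
    )
    return params.at[:, slope_idx].set(slopes).at[:, intercept_idx].set(intercepts)
```

Pixels with very large errors, such as masked ones, get near-zero weight, so they barely affect the continuum estimate.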