sheap.Minimizer.Minimizer module

Minimization Routines

This module contains the main minimization routines in sheap. It defines the Minimizer class, which wraps JAX and Optax optimizers for constrained spectral model fitting.

Contents

  • Minimizer: high-level interface for fitting spectral models with Adam or LBFGS optimizers.

  • Loss Function: constructed via loss_builder.build_loss_function, supporting weighted residuals, penalties, and regularization terms.

  • Vectorization: optimization can be run across batches via jax.vmap.

  • Constraints & Dependencies: supports tied parameters and physical constraints through Parameters converters.
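A common way to enforce box constraints during gradient-based fitting is to optimize unconstrained raw values and map them into the allowed range before evaluating the model. The sketch below assumes a sigmoid transform. The names `to_physical` and `to_raw` are hypothetical and do not reproduce the actual `Parameters` converter API.

```python
import jax
import jax.numpy as jnp

# Hypothetical raw <-> physical mapping for box constraints lo < p < hi.
# The optimizer updates unconstrained raw values; the model only sees physical ones.
def to_physical(raw, lo, hi):
    return lo + (hi - lo) * jax.nn.sigmoid(raw)

def to_raw(phys, lo, hi):
    # Inverse of the sigmoid map (logit); phys must lie strictly inside (lo, hi).
    u = (phys - lo) / (hi - lo)
    return jnp.log(u) - jnp.log1p(-u)

# Example: an amplitude in (0, 10) and a line centre in (4000, 7000)
lo = jnp.array([0.0, 4000.0])
hi = jnp.array([10.0, 7000.0])
p0 = jnp.array([2.5, 5000.0])
raw = to_raw(p0, lo, hi)
recovered = to_physical(raw, lo, hi)
```

Any raw value, however large, maps back inside the bounds, so the optimizer never has to handle the constraint explicitly.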

Notes

  • Optimization supports two methods:

    • “adam” (gradient descent with adaptive moments; default)

    • “lbfgs” (quasi-Newton optimizer via Optax)

  • Regularization options include: curvature matching, smoothness penalties, and maximum residual weighting.

  • non_optimize_in_axis controls how constraints and initial conditions are shared across batched spectra:

    • 3 → same initial values and constraints

    • 4 → same constraints, different initial values

    • 5 → both constraints and initial values vary
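One plausible reading of these three modes is that the integer selects the `in_axes` tuple passed to `jax.vmap`. The sketch assumes the positional order `(initial_params, flux, wavelength, errors, constraints)` from the example below. `IN_AXES`, `fit_one` and `batched` are hypothetical stand-ins, not sheap internals: `None` marks an argument shared by every spectrum, `0` an argument with one entry per spectrum.

```python
import jax
import jax.numpy as jnp

# Hypothetical mapping from non_optimize_in_axis to jax.vmap in_axes.
IN_AXES = {
    3: (None, 0, 0, 0, None),  # same initial values and constraints
    4: (0, 0, 0, 0, None),     # same constraints, different initial values
    5: (0, 0, 0, 0, 0),        # both initial values and constraints vary
}

def fit_one(params, flux, wavelength, errors, constraints):
    # Stand-in for a single-spectrum objective: weighted chi^2 of a linear model,
    # with params clipped to the [lo, hi] columns of constraints.
    p = jnp.clip(params, constraints[:, 0], constraints[:, 1])
    model = p[0] + p[1] * wavelength
    return jnp.sum(((flux - model) / errors) ** 2)

def batched(non_optimize_in_axis):
    return jax.vmap(fit_one, in_axes=IN_AXES[non_optimize_in_axis])

n_spec, n_pix = 4, 50
wl = jnp.tile(jnp.linspace(0.0, 1.0, n_pix), (n_spec, 1))
flux = 1.0 + 2.0 * wl
err = jnp.ones_like(flux)
p0 = jnp.array([1.0, 2.0])
cons = jnp.array([[-10.0, 10.0], [-10.0, 10.0]])

losses3 = batched(3)(p0, flux, wl, err, cons)                       # shared p0 and cons
losses4 = batched(4)(jnp.tile(p0, (n_spec, 1)), flux, wl, err, cons)  # per-spectrum p0
losses5 = batched(5)(jnp.tile(p0, (n_spec, 1)), flux, wl, err,
                     jnp.tile(cons, (n_spec, 1, 1)))                # everything per-spectrum
```

Each call returns one loss per spectrum; only the leading batch dimension of the mapped arguments changes between modes.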

Example

from sheap.Minimizer.Minimizer import Minimizer

# model_fn, initial_params, flux, wavelength, errors and constraints are supplied by the user
minimizer = Minimizer(model_fn, num_steps=2000, learning_rate=1e-2)
final_params, loss_history = minimizer(
    initial_params, flux, wavelength, errors, constraints
)
class Minimizer(func, non_optimize_in_axis=3, num_steps=1000, learning_rate=None, list_dependencies=[], weighted=True, method='adam', lbfgs_options=None, penalty_function=None, param_converter=None, penalty_weight=0.01, curvature_weight=1000.0, smoothness_weight=100000.0, max_weight=0.1, **kwargs)

Bases: object

Handles constrained optimization for a given model function using JAX and Optax.

func

The model function to be optimized.

Type:

Callable

Parameters:
  • func (Callable)

  • non_optimize_in_axis (int)

  • num_steps (int)

  • learning_rate (float | None)

  • list_dependencies (List[str])

  • weighted (bool)

  • method (str)

  • lbfgs_options (Dict | None)

  • penalty_function (Callable | None)

  • param_converter (Parameters | None)

  • penalty_weight (float)

  • curvature_weight (float)

  • smoothness_weight (float)

  • max_weight (float)

non_optimize_in_axis

Determines vmap axis behavior:

  • 3 → same initial values and constraints across data

  • 4 → same constraints, different initial values

  • 5 → different initial values and constraints

Type:

int

num_steps

Number of optimization iterations.

Type:

int

learning_rate

Learning rate for the optimizer (ignored for LBFGS).

Type:

float or None

list_dependencies

Parameter dependency specifications for tied parameters.

Type:

list of str
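The string syntax of these dependency specifications is not documented on this page, so the sketch below encodes ties as hypothetical `(target, source, factor)` tuples, for example fixing one line amplitude to a multiple of another. `apply_ties` is an illustrative helper, not part of sheap.

```python
import jax.numpy as jnp

# Hypothetical tie format: (target_index, source_index, factor) means
# params[target] = factor * params[source]. This is not sheap's string syntax.
def apply_ties(params, ties):
    for target, source, factor in ties:
        # .at[...].set returns a new array; the input is left unchanged
        params = params.at[target].set(factor * params[source])
    return params

# Tie the amplitude at index 2 to one third of the amplitude at index 0
params = jnp.array([3.0, 1.5, 99.0])
tied = apply_ties(params, [(2, 0, 1.0 / 3.0)])
```

Because the target entry is overwritten, its gradient through `apply_ties` is zero, so an optimizer applied to the untied vector effectively moves only the free parameters.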

method

Optimization method to use (‘adam’ or ‘lbfgs’).

Type:

str

lbfgs_options

Options specific to LBFGS optimization (e.g., maxiter, tolerance_grad).

Type:

dict

optimizer

Optax optimizer instance.

Type:

optax.GradientTransformation

loss_function

JIT-compiled loss function including penalties.

Type:

Callable

optimize_model

Function that performs the optimization loop.

Type:

Callable

static minimization_function(func, weighted, penalty_function, penalty_weight, param_converter, curvature_weight, learning_rate, smoothness_weight, max_weight, method, lbfgs_options, num_steps)

Builds the loss function and corresponding optimization routine.

Parameters:
  • func (Callable) – The model function.

  • weighted (bool) – Whether to apply inverse variance weighting.

  • penalty_function (Callable, optional) – Optional penalty function for parameters.

  • penalty_weight (float) – Scalar penalty strength.

  • param_converter (Parameters, optional) – Object to convert raw to physical parameters.

  • curvature_weight (float) – Strength of curvature matching regularization.

  • smoothness_weight (float) – Strength of smoothness regularization.

  • max_weight (float) – Weight of the penalty on the largest (worst) residual.

  • method (str) – Optimizer method (‘adam’ or ‘lbfgs’).

  • lbfgs_options (dict) – Dictionary of LBFGS-specific options.

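  • learning_rate (float) – Learning rate for the optimizer (ignored for LBFGS).

  • num_steps (int) – Number of optimization iterations.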

Returns:

The compiled loss function and optimization routine.

Return type:

Tuple[Callable, Callable]
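The regularization terms listed above can be combined into a single scalar objective. The exact formulas used by `loss_builder.build_loss_function` are not given here, so the sketch below is a hypothetical version: the model is assumed to be called as `func(wavelength, params)`, “curvature matching” is taken to compare second differences of data and model, “smoothness” to penalize pixel-to-pixel jumps in the residuals, and the max term to penalize the largest absolute residual.

```python
import jax.numpy as jnp

def build_loss_sketch(func, weighted=True, penalty_function=None, penalty_weight=0.01,
                      curvature_weight=1e3, smoothness_weight=1e5, max_weight=0.1):
    # Hypothetical loss builder; the term definitions are assumptions, not sheap's.
    def loss(params, wavelength, flux, errors):
        model = func(wavelength, params)
        r = flux - model
        if weighted:
            r = r / errors  # inverse-variance weighting of the squared residuals
        total = jnp.mean(r ** 2)
        # curvature matching: second differences of data vs model
        total += curvature_weight * jnp.mean(
            (jnp.diff(flux, n=2) - jnp.diff(model, n=2)) ** 2)
        # smoothness: penalize jumps between neighbouring residuals
        total += smoothness_weight * jnp.mean(jnp.diff(r) ** 2)
        # worst-residual term
        total += max_weight * jnp.max(jnp.abs(r))
        if penalty_function is not None:
            total += penalty_weight * penalty_function(params)
        return total
    return loss

line = lambda x, p: p[0] + p[1] * x
x = jnp.linspace(0.0, 1.0, 40)
y = 1.0 + 2.0 * x
err = jnp.ones_like(y)
loss_fn = build_loss_sketch(line, penalty_function=lambda p: jnp.sum(p ** 2))
```

With a perfect model every data term vanishes and only `penalty_weight * penalty_function(params)` remains, which makes the weights easy to sanity-check in isolation.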

static minimization_function2(func, weighted, penalty_function, penalty_weight, param_converter, curvature_weight, learning_rate, smoothness_weight, max_weight, method, lbfgs_options, num_steps)

Builds the loss function and corresponding optimization routine.

Parameters:
  • func (Callable) – The model function.

  • weighted (bool) – Whether to apply inverse variance weighting.

  • penalty_function (Callable, optional) – Optional penalty function for parameters.

  • penalty_weight (float) – Scalar penalty strength.

  • param_converter (Parameters, optional) – Object to convert raw to physical parameters.

  • curvature_weight (float) – Strength of curvature matching regularization.

  • smoothness_weight (float) – Strength of smoothness regularization.

  • max_weight (float) – Weight of the penalty on the largest (worst) residual.

  • method (str) – Optimizer method (‘adam’ or ‘lbfgs’).

  • lbfgs_options (dict) – Dictionary of LBFGS-specific options.

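  • learning_rate (float) – Learning rate for the optimizer (ignored for LBFGS).

  • num_steps (int) – Number of optimization iterations.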

Returns:

The compiled loss function and optimization routine.

Return type:

Tuple[Callable, Callable]