sheap.Minimizer.Minimizer module
Minimization Routines
This module contains the main minimization routines in sheap. It defines the Minimizer class, which wraps JAX and Optax optimizers for constrained spectral model fitting.
Contents
Minimizer: high-level interface for fitting spectral models with Adam or LBFGS optimizers.
Loss Function: constructed via loss_builder.build_loss_function, supporting weighted residuals, penalties, and regularization terms.
Vectorization: optimization can be run across batches via jax.vmap.
Constraints & Dependencies: supports tied parameters and physical constraints through Parameters converters.
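This page says constraints are applied through Parameters converters, but it does not show the transform. A common way to enforce box bounds in gradient-based fitting is to optimize an unconstrained "raw" value and map it into [lower, upper] with a scaled sigmoid. The sketch below illustrates that technique under that assumption. raw_to_physical and physical_to_raw are hypothetical names and are not sheap's API.

```python
import jax
import jax.numpy as jnp

# Hypothetical converter (assumption, not sheap's Parameters class):
# a scaled sigmoid maps any raw value into [lower, upper], so the optimizer
# can move freely while the model only sees parameters inside the bounds.
def raw_to_physical(raw, lower, upper):
    return lower + (upper - lower) * jax.nn.sigmoid(raw)

def physical_to_raw(phys, lower, upper):
    # Inverse map (logit); only valid strictly inside the bounds.
    frac = (phys - lower) / (upper - lower)
    return jnp.log(frac) - jnp.log1p(-frac)

lower = jnp.array([0.0, 6500.0])      # e.g. amplitude, line center
upper = jnp.array([10.0, 6600.0])
init_phys = jnp.array([2.5, 6563.0])

raw = physical_to_raw(init_phys, lower, upper)      # what the optimizer updates
recovered = raw_to_physical(raw, lower, upper)      # what the model receives
```

Because the bounds are built into the parameterization, even very large raw steps cannot push a parameter outside its allowed range.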
Notes
Optimization supports two methods:
- “adam” (gradient descent with adaptive moments; the default)
- “lbfgs” (quasi-Newton optimizer via Optax)
Regularization options include: curvature matching, smoothness penalties, and maximum residual weighting.
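The exact terms built by loss_builder.build_loss_function are not shown on this page. The function below is a minimal sketch, assuming simple forms for each listed term: an error-weighted chi-square, curvature matching through second differences, a smoothness penalty on first differences of the residual, and a worst-pixel term. The default weights are copied from the Minimizer signature. The formulas themselves are illustrative assumptions.

```python
import jax.numpy as jnp

def regularized_loss(model_flux, flux, errors, weighted=True,
                     curvature_weight=1000.0, smoothness_weight=100000.0,
                     max_weight=0.1):
    """Hypothetical loss combining the listed terms (not sheap's code)."""
    resid = model_flux - flux
    if weighted:
        resid = resid / errors  # becomes inverse-variance weighting once squared
    chi2 = jnp.mean(resid**2)
    # Curvature matching: model and data should bend the same way.
    curvature = jnp.mean((jnp.diff(model_flux, n=2) - jnp.diff(flux, n=2)) ** 2)
    # Smoothness: discourage jumps between neighboring residuals.
    smoothness = jnp.mean(jnp.diff(resid) ** 2)
    # Maximum-residual term: pulls down the single worst pixel.
    worst = jnp.max(jnp.abs(resid))
    return (chi2 + curvature_weight * curvature
            + smoothness_weight * smoothness + max_weight * worst)
```

For a constant offset the curvature and smoothness terms vanish, and only the chi-square and worst-pixel terms remain.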
non_optimize_in_axis controls how constraints and initial conditions are shared across batched spectra:
3 → same initial values and constraints
4 → same constraints, different initial values
5 → both constraints and initial values vary
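One plausible reading of these values is that they count how many of the call's arguments carry a batch axis under jax.vmap, taken in the order of the example call below (initial_params, flux, wavelength, errors, constraints). This is an assumption, not documented behavior. The sketch encodes that reading as in_axes tuples. fit_one is a hypothetical toy stand-in for a single-spectrum fit, used only to show the shapes.

```python
import jax
import jax.numpy as jnp

# Assumed mapping: None broadcasts one shared value, 0 maps over the batch axis.
IN_AXES = {
    3: (None, 0, 0, 0, None),  # shared initial values and constraints
    4: (0, 0, 0, 0, None),     # per-spectrum initial values, shared constraints
    5: (0, 0, 0, 0, 0),        # initial values and constraints both per spectrum
}

def fit_one(initial_params, flux, wavelength, errors, constraints):
    """Toy stand-in (hypothetical): shift the guess by the weighted mean flux
    and clip it into [lower, upper] from constraints of shape (P, 2)."""
    w = 1.0 / errors**2
    mean_flux = jnp.sum(w * flux) / jnp.sum(w)
    return jnp.clip(initial_params + mean_flux, constraints[:, 0], constraints[:, 1])

B, N, P = 4, 100, 2  # spectra, pixels per spectrum, parameters
flux = jnp.ones((B, N))
wavelength = jnp.tile(jnp.linspace(4000.0, 7000.0, N), (B, 1))
errors = jnp.full((B, N), 0.1)
init = jnp.zeros(P)                             # shared initial values
bounds = jnp.array([[-5.0, 5.0], [0.0, 0.5]])   # shared constraints

batched_fit = jax.vmap(fit_one, in_axes=IN_AXES[3])
params = batched_fit(init, flux, wavelength, errors, bounds)
print(params.shape)  # (4, 2)
```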
Example
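from sheap.Minimizer.Minimizer import Minimizer

# model_fn, initial_params, flux, wavelength, errors and constraints are
# user-supplied: the model function and the arrays describing the spectra.
minimizer = Minimizer(model_fn, num_steps=2000, learning_rate=1e-2)
# Returns the optimized parameters and the loss recorded during optimization.
final_params, loss_history = minimizer(
    initial_params, flux, wavelength, errors, constraints
)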
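To make the call above more concrete, here is a hypothetical model_fn with matching input arrays: one Gaussian emission line on a linear continuum. The call signature model_fn(wavelength, params) and the parameter layout are assumptions, because this page does not document how Minimizer invokes func. The constraints argument is omitted because its format is not described here.

```python
import jax.numpy as jnp

def model_fn(wavelength, params):
    """Hypothetical model: one Gaussian emission line on a linear continuum.

    params = [amplitude, center, sigma, slope, intercept]
    """
    amplitude, center, sigma, slope, intercept = params
    line = amplitude * jnp.exp(-0.5 * ((wavelength - center) / sigma) ** 2)
    continuum = slope * (wavelength - center) + intercept
    return line + continuum

wavelength = jnp.linspace(6400.0, 6700.0, 301)
initial_params = jnp.array([2.0, 6563.0, 5.0, 0.0, 1.0])
flux = model_fn(wavelength, initial_params)   # synthetic, noise-free spectrum
errors = jnp.full_like(flux, 0.05)
```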
- class Minimizer(func, non_optimize_in_axis=3, num_steps=1000, learning_rate=None, list_dependencies=[], weighted=True, method='adam', lbfgs_options=None, penalty_function=None, param_converter=None, penalty_weight=0.01, curvature_weight=1000.0, smoothness_weight=100000.0, max_weight=0.1, **kwargs)
Bases: object
Handles constrained optimization for a given model function using JAX and Optax.
- func
The model function to be optimized.
- Type:
Callable
- Parameters:
func (Callable)
non_optimize_in_axis (int)
num_steps (int)
learning_rate (float | None)
list_dependencies (List[str])
weighted (bool)
method (str)
lbfgs_options (Dict | None)
penalty_function (Callable | None)
param_converter (Parameters | None)
penalty_weight (float)
curvature_weight (float)
smoothness_weight (float)
max_weight (float)
- non_optimize_in_axis
Determines the vmap axis behavior across batched spectra:
3 → same initial values and constraints
4 → same constraints, different initial values
5 → different initial values and constraints
- Type:
int
- num_steps
Number of optimization iterations.
- Type:
int
- learning_rate
Learning rate for the optimizer (ignored for LBFGS).
- Type:
float
- list_dependencies
Parameter dependency specifications for tied parameters.
- Type:
list of str
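The string format used by list_dependencies is not documented on this page, so the following sketch skips parsing and starts from already-parsed ties. Each tie is a hypothetical (target_index, source_index, factor) tuple meaning params[target] = factor * params[source]. Only the free parameters are optimized, and the full vector is rebuilt before the model is evaluated. Gradients flow back to each source through the tie.

```python
import jax
import jax.numpy as jnp

# Hypothetical, already-parsed ties (not sheap's format):
# params[2] = (1/3) * params[0], e.g. a fixed amplitude ratio between two lines.
TIES = [(2, 0, 1.0 / 3.0)]

def expand_tied(free_params, n_total, ties):
    """Rebuild the full parameter vector from the free (optimized) entries."""
    targets = {t for t, _, _ in ties}
    free_idx = jnp.array([i for i in range(n_total) if i not in targets])
    full = jnp.zeros(n_total, dtype=free_params.dtype).at[free_idx].set(free_params)
    # Ties are applied in list order: a tie whose source is itself tied
    # must come after the tie that sets that source.
    for target, source, factor in ties:
        full = full.at[target].set(factor * full[source])
    return full

free = jnp.array([3.0, 6563.0])    # only these are seen by the optimizer
full = expand_tied(free, 3, TIES)  # approximately [3.0, 6563.0, 1.0]
```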
- method
Optimization method to use (‘adam’ or ‘lbfgs’).
- Type:
str
- lbfgs_options
Options specific to LBFGS optimization (e.g., maxiter, tolerance_grad).
- Type:
dict
- optimizer
Optax optimizer instance.
- Type:
optax.GradientTransformation
- loss_function
JIT-compiled loss function including penalties.
- Type:
Callable
- optimize_model
Function that performs the optimization loop.
- Type:
Callable
- static minimization_function(func, weighted, penalty_function, penalty_weight, param_converter, curvature_weight, learning_rate, smoothness_weight, max_weight, method, lbfgs_options, num_steps)
Builds the loss function and corresponding optimization routine.
- Parameters:
func (Callable) – The model function.
weighted (bool) – Whether to apply inverse variance weighting.
penalty_function (Callable, optional) – Optional penalty function for parameters.
penalty_weight (float) – Scalar penalty strength.
param_converter (Parameters, optional) – Object to convert raw to physical parameters.
curvature_weight (float) – Strength of curvature matching regularization.
smoothness_weight (float) – Strength of smoothness regularization.
max_weight (float) – Penalty on worst residual.
method (str) – Optimizer method (‘adam’ or ‘lbfgs’).
lbfgs_options (dict) – Dictionary of LBFGS-specific options.
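learning_rate (float) – Learning rate for the Adam optimizer (ignored for LBFGS).
num_steps (int) – Number of optimization iterations.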
- Returns:
The compiled loss function and optimization routine.
- Return type:
Tuple[Callable, Callable]