Source code for sheap.Minimizer.Minimizer

"""
Minimization Routines
=====================

This module contains the main minimization routines in *sheap*.
It defines the `Minimizer` class, which wraps JAX and Optax
optimizers for constrained spectral model fitting.

Contents
--------
- **Minimizer**: high-level interface for fitting spectral models with
  Adam or LBFGS optimizers.
- **Loss Function**: constructed via `loss_builder.build_loss_function`,
  supporting weighted residuals, penalties, and regularization terms.
- **Vectorization**: optimization can be run across batches via `jax.vmap`.
- **Constraints & Dependencies**: supports tied parameters and physical
  constraints through `Parameters` converters.

Notes
-----
- Optimization supports two methods:
  - `"adam"` (gradient descent with adaptive moments, default)
  - `"lbfgs"` (quasi-Newton optimizer via Optax)
- Regularization options include:
  curvature matching, smoothness penalties, and maximum residual weighting.
- `non_optimize_in_axis` controls how constraints and initial conditions
  are shared across batched spectra:
  
  * 3 → same initial values and constraints  
  * 4 → same constraints, different initial values  
  * 5 → both constraints and initial values vary

Example
-------
.. code-block:: python

   from sheap.Minimizer.Minimizer import Minimizer

   minimizer = Minimizer(model_fn, num_steps=2000, learning_rate=1e-2)
   final_params, loss_history = minimizer(
       initial_params, flux, wavelength, errors, constraints
   )
"""

__author__ = 'felavila'

__all__ = [
    "Minimizer",
]

from typing import Callable, Dict, List, Optional, Tuple

import jax.numpy as jnp
import optax
from jax import jit, vmap, lax, value_and_grad

#from sheap.Assistants.parser_mapper import parse_dependencies, project_params
from .loss_builder import build_loss_function
from .LB_nonlinear import build_varpro_loss_from_profile_and_params_obj


[docs] class Minimizer: """ Handles constrained optimization for a given model function using JAX and Optax. #TODO maybe for one object remove the JIT Attributes ---------- func : Callable The model function to be optimized. non_optimize_in_axis : int Determines vmap axis behavior: - 3: same initial values and constraints across data - 4: same constraints, different initial values - 5: different initial values and constraints num_steps : int Number of optimization iterations. learning_rate : float Learning rate for the optimizer (ignored for LBFGS). list_dependencies : list of str Parameter dependency specifications for tied parameters. method : str Optimization method to use ('adam' or 'lbfgs'). lbfgs_options : dict Options specific to LBFGS optimization (e.g., maxiter, tolerance_grad). optimizer : optax.GradientTransformation Optax optimizer instance. loss_function : Callable JIT-compiled loss function including penalties. optimize_model : Callable Function that performs the optimization loop. """ def __init__( self, func: Callable, non_optimize_in_axis: int = 3, num_steps: int = 1_000, learning_rate: Optional[float] = None, list_dependencies: List[str] = [], weighted: bool = True, method: str = "adam", lbfgs_options: Optional[Dict] = None, penalty_function: Optional[Callable] = None, param_converter: Optional["Parameters"] = None, penalty_weight: float = 0.01, curvature_weight: float = 1e3, smoothness_weight: float = 1e5, max_weight: float = 0.1, **kwargs, ): self.func = func self.non_optimize_in_axis = non_optimize_in_axis self.num_steps = num_steps self.learning_rate = learning_rate or 1e-2 self.list_dependencies = list_dependencies self.param_converter = param_converter self.method = method.lower() self.lbfgs_options = lbfgs_options or {} #self.optimizer = kwargs.get("optimizer", optax.adam(self.learning_rate)) #print(method,penalty_weight,curvature_weight,smoothness_weight,max_weight) #self.parsed_dependencies_tuple = parse_dependencies(self.list_dependencies) self.loss_function, self.optimize_model = Minimizer.minimization_function(self.func, weighted=weighted, penalty_function=penalty_function, penalty_weight=penalty_weight,param_converter=self.param_converter, curvature_weight=curvature_weight, learning_rate = learning_rate, smoothness_weight=smoothness_weight, max_weight=max_weight, method=self.method, lbfgs_options=self.lbfgs_options, num_steps = num_steps) def __call__(self, initial_params, y, x, yerror, constraints,): """ Execute the optimization process across batches. Parameters ---------- initial_params : jnp.ndarray Initial parameters for optimization. y : jnp.ndarray Observed data values. x : jnp.ndarray Wavelength or independent variable. yerror : jnp.ndarray Uncertainty for each observation. constraints : jnp.ndarray Parameter constraints, shape (N_params, 2). Returns ------- jnp.ndarray Optimized parameters. list Final loss history. """ optimize_in_axis = ( (None, 0, 0, 0, None) if self.non_optimize_in_axis == 3 else (0, 0, 0, 0, None) ) vmap_optimize_model = vmap(self.optimize_model, in_axes=optimize_in_axis, out_axes=0) if self.param_converter: initial_params = self.param_converter.phys_to_raw(initial_params) raw_params,loss = vmap_optimize_model(initial_params,y,x,yerror,constraints,) return self.param_converter.raw_to_phys(raw_params),loss else: #print warning sayng about no param class is defined return vmap_optimize_model(initial_params,y,x,yerror,constraints,)
[docs] @staticmethod def minimization_function( func: Callable, weighted: bool, penalty_function: Optional[Callable], penalty_weight: float, param_converter: Optional["Parameters"], curvature_weight: float, learning_rate : float, smoothness_weight: float, max_weight: float, method: str, lbfgs_options: dict, num_steps ) -> Tuple[Callable, Callable]: """ Builds the loss function and corresponding optimization routine. Parameters ---------- func : Callable The model function. weighted : bool Whether to apply inverse variance weighting. penalty_function : Callable, optional Optional penalty function for parameters. penalty_weight : float Scalar penalty strength. param_converter : Parameters, optional Object to convert raw to physical parameters. curvature_weight : float Strength of curvature matching regularization. smoothness_weight : float Strength of smoothness regularization. max_weight : float Penalty on worst residual. method : str Optimizer method ('adam' or 'lbfgs'). lbfgs_options : dict Dictionary of LBFGS-specific options. Returns ------- Tuple[Callable, Callable] The compiled loss function and optimization routine. """ #build_varpro_loss_function loss_function = jit(build_loss_function(func,weighted,penalty_function,penalty_weight,param_converter,curvature_weight,smoothness_weight,max_weight,)) #loss_function = jit(build_varpro_loss_function(func,weighted,penalty_function,penalty_weight,param_converter,curvature_weight,smoothness_weight,max_weight,)) #loss_function = jit(loss_function) loss_and_grad = jit(value_and_grad(loss_function)) def optimize_model(initial_params, xs, y, y_uncertainties, constraints): #Why this works slow? loss_history = [] if method == "lbfgs": optimizer = optax.lbfgs(**lbfgs_options) state = optimizer.init(initial_params) def lbfgs_step(carry): params, state = carry loss, grads = value_and_grad(loss_function)(params, xs, y, y_uncertainties) updates, state = optimizer.update( grads, state, params, value=loss, grad=grads, value_fn=lambda p: loss_function(p, xs, y, y_uncertainties) ) params = optax.apply_updates(params, updates) return (params, state), loss def cond_fn(carry): (_, _), _, i = carry return i < lbfgs_options.get("maxiter", 200) def body_fn(carry): (params, state), loss_hist, i = carry (params, state), loss = lbfgs_step((params, state)) loss_hist = loss_hist.at[i].set(loss) # Store into preallocated array return (params, state), loss_hist, i + 1 # Preallocate the history buffer maxiter = lbfgs_options.get("maxiter", 200) loss_hist_init = jnp.zeros((maxiter,), dtype=jnp.float64) # Run loop ((final_params, _), loss_history, _i) = lax.while_loop( cond_fn, body_fn, ((initial_params, state), loss_hist_init, 0) ) else: # adam #here should go a way to choose as a dictionary the name of the optimizer. optimizer = optax.adam(learning_rate=learning_rate) opt_state = optimizer.init(initial_params) def step_fn(carry, _): params, opt_state = carry loss, grads = loss_and_grad(params, xs, y, y_uncertainties) #value_and_grad(loss_function) updates, opt_state = optimizer.update(grads, opt_state, params) params = optax.apply_updates(params, updates) return (params, opt_state), loss (final_params, _), loss_history = lax.scan( step_fn, (initial_params, opt_state), None, length=num_steps ) return final_params, loss_history optimize_model = jit(optimize_model) #powerfull when we apply montecarlo-in in 1-2 objects sample not much impact +3 sec return loss_function, optimize_model
[docs] @staticmethod def minimization_function2( func: Callable, weighted: bool, penalty_function: Optional[Callable], penalty_weight: float, param_converter: Optional["Parameters"], curvature_weight: float, learning_rate : float, smoothness_weight: float, max_weight: float, method: str, lbfgs_options: dict, num_steps ) -> Tuple[Callable, Callable]: """ Builds the loss function and corresponding optimization routine. Parameters ---------- func : Callable The model function. weighted : bool Whether to apply inverse variance weighting. penalty_function : Callable, optional Optional penalty function for parameters. penalty_weight : float Scalar penalty strength. param_converter : Parameters, optional Object to convert raw to physical parameters. curvature_weight : float Strength of curvature matching regularization. smoothness_weight : float Strength of smoothness regularization. max_weight : float Penalty on worst residual. method : str Optimizer method ('adam' or 'lbfgs'). lbfgs_options : dict Dictionary of LBFGS-specific options. Returns ------- Tuple[Callable, Callable] The compiled loss function and optimization routine. """ loss_function = jit(build_varpro_loss_function(func,weighted,penalty_function,penalty_weight,param_converter,curvature_weight,smoothness_weight,max_weight,)) #loss_function = jit(loss_function) loss_and_grad = jit(value_and_grad(loss_function)) def optimize_model(initial_params, xs, y, y_uncertainties, constraints): #Why this works slow? loss_history = [] if method == "lbfgs": optimizer = optax.lbfgs(**lbfgs_options) state = optimizer.init(initial_params) def lbfgs_step(carry): params, state = carry loss, grads = value_and_grad(loss_function)(params, xs, y, y_uncertainties) updates, state = optimizer.update( grads, state, params, value=loss, grad=grads, value_fn=lambda p: loss_function(p, xs, y, y_uncertainties) ) params = optax.apply_updates(params, updates) return (params, state), loss def cond_fn(carry): (_, _), _, i = carry return i < lbfgs_options.get("maxiter", 200) def body_fn(carry): (params, state), loss_hist, i = carry (params, state), loss = lbfgs_step((params, state)) loss_hist = loss_hist.at[i].set(loss) # Store into preallocated array return (params, state), loss_hist, i + 1 # Preallocate the history buffer maxiter = lbfgs_options.get("maxiter", 200) loss_hist_init = jnp.zeros((maxiter,), dtype=jnp.float64) # Run loop ((final_params, _), loss_history, _i) = lax.while_loop( cond_fn, body_fn, ((initial_params, state), loss_hist_init, 0) ) else: # adam #here should go a way to choose as a dictionary the name of the optimizer. optimizer = optax.adam(learning_rate=learning_rate) opt_state = optimizer.init(initial_params) def step_fn(carry, _): params, opt_state = carry loss, grads = loss_and_grad(params, xs, y, y_uncertainties) #value_and_grad(loss_function) updates, opt_state = optimizer.update(grads, opt_state, params) params = optax.apply_updates(params, updates) return (params, opt_state), loss (final_params, _), loss_history = lax.scan( step_fn, (initial_params, opt_state), None, length=num_steps ) return final_params, loss_history optimize_model = jit(optimize_model) #powerfull when we apply montecarlo-in in 1-2 objects sample not much impact +3 sec return loss_function, optimize_model
class Minimizer_: """ Handles constrained optimization for a given model function using JAX and Optax. #TODO maybe for one object remove the JIT Attributes ---------- func : Callable The model function to be optimized. non_optimize_in_axis : int Determines vmap axis behavior: - 3: same initial values and constraints across data - 4: same constraints, different initial values - 5: different initial values and constraints num_steps : int Number of optimization iterations. learning_rate : float Learning rate for the optimizer (ignored for LBFGS). list_dependencies : list of str Parameter dependency specifications for tied parameters. method : str Optimization method to use ('adam' or 'lbfgs'). lbfgs_options : dict Options specific to LBFGS optimization (e.g., maxiter, tolerance_grad). optimizer : optax.GradientTransformation Optax optimizer instance. loss_function : Callable JIT-compiled loss function including penalties. optimize_model : Callable Function that performs the optimization loop. """ def __init__( self, func: Callable, non_optimize_in_axis: int = 3, num_steps: int = 1_000, learning_rate: Optional[float] = None, list_dependencies: List[str] = [], weighted: bool = True, method: str = "adam", lbfgs_options: Optional[Dict] = None, penalty_function: Optional[Callable] = None, param_converter: Optional["Parameters"] = None, penalty_weight: float = 0.01, curvature_weight: float = 1e3, smoothness_weight: float = 1e5, max_weight: float = 0.1, **kwargs, ): print("The experimental one") self.func = func self.non_optimize_in_axis = non_optimize_in_axis self.num_steps = num_steps self.learning_rate = learning_rate or 1e-2 self.list_dependencies = list_dependencies self.param_converter = param_converter self.method = method.lower() self.lbfgs_options = lbfgs_options or {} #self.optimizer = kwargs.get("optimizer", optax.adam(self.learning_rate)) #print(method,penalty_weight,curvature_weight,smoothness_weight,max_weight) #self.parsed_dependencies_tuple = parse_dependencies(self.list_dependencies) self.nonlinear_raw_idx = self.func.nonlinear_param_indices self.loss_fn, self.solve_fn = build_varpro_loss_from_profile_and_params_obj(fused_profile=self.func, params_obj=self.param_converter, nonlinear_raw_idx=self.nonlinear_raw_idx, weighted=True,lambda_reg=0.0,reg_matrix=None,) self.loss_function, self.optimize_model = Minimizer.minimization_function(self.func, weighted=weighted, penalty_function=penalty_function, penalty_weight=penalty_weight,param_converter=self.param_converter, curvature_weight=curvature_weight, learning_rate = learning_rate, smoothness_weight=smoothness_weight, max_weight=max_weight, method=self.method, lbfgs_options=self.lbfgs_options, num_steps = num_steps,loss_fn=self.loss_fn) def __call__(self, initial_params, y, x, yerror, constraints,): """ Execute the optimization process across batches. Parameters ---------- initial_params : jnp.ndarray Initial parameters for optimization. y : jnp.ndarray Observed data values. x : jnp.ndarray Wavelength or independent variable. yerror : jnp.ndarray Uncertainty for each observation. constraints : jnp.ndarray Parameter constraints, shape (N_params, 2). Returns ------- jnp.ndarray Optimized parameters. list Final loss history. """ optimize_in_axis = ((None, 0, 0, 0, None) if self.non_optimize_in_axis == 3 else (0, 0, 0, 0, None)) vmap_optimize_model = vmap(self.optimize_model, in_axes=optimize_in_axis, out_axes=0) vmap_solve_fn = vmap(self.solve_fn, in_axes=(0, 0, 0, 0), out_axes=0) print(initial_params.shape) if self.param_converter: raw0_full = jnp.asarray(self.param_converter.phys_to_raw(initial_params)) initial_params = raw0_full.at[:,jnp.array([*self.nonlinear_raw_idx ])].get() #initial_params = self.param_converter.phys_to_raw(initial_params) raw_params,loss = vmap_optimize_model(initial_params,y,x,yerror,constraints,) phys_best, _, _, _, _, _ = vmap_solve_fn(raw_params,y,x,yerror ) return phys_best,loss else: #print warning sayng about no param class is defined return vmap_optimize_model(initial_params,y,x,yerror,constraints,) @staticmethod def minimization_function( func: Callable, weighted: bool, penalty_function: Optional[Callable], penalty_weight: float, param_converter: Optional["Parameters"], curvature_weight: float, learning_rate : float, smoothness_weight: float, max_weight: float, method: str, lbfgs_options: dict, num_steps,loss_fn ) -> Tuple[Callable, Callable]: """ Builds the loss function and corresponding optimization routine. Parameters ---------- func : Callable The model function. weighted : bool Whether to apply inverse variance weighting. penalty_function : Callable, optional Optional penalty function for parameters. penalty_weight : float Scalar penalty strength. param_converter : Parameters, optional Object to convert raw to physical parameters. curvature_weight : float Strength of curvature matching regularization. smoothness_weight : float Strength of smoothness regularization. max_weight : float Penalty on worst residual. method : str Optimizer method ('adam' or 'lbfgs'). lbfgs_options : dict Dictionary of LBFGS-specific options. Returns ------- Tuple[Callable, Callable] The compiled loss function and optimization routine. """ #build_varpro_loss_function loss_function = jit(loss_fn) #loss_function = jit(build_varpro_loss_function(func,weighted,penalty_function,penalty_weight,param_converter,curvature_weight,smoothness_weight,max_weight,)) #loss_function = jit(loss_function) loss_and_grad = jit(value_and_grad(loss_function)) def optimize_model(initial_params, xs, y, y_uncertainties, constraints): #Why this works slow? loss_history = [] #else: # adam #here should go a way to choose as a dictionary the name of the optimizer. optimizer = optax.adam(learning_rate=learning_rate) opt_state = optimizer.init(initial_params) def step_fn(carry, _): params, opt_state = carry loss, grads = loss_and_grad(params, xs, y, y_uncertainties) #value_and_grad(loss_function) updates, opt_state = optimizer.update(grads, opt_state, params) params = optax.apply_updates(params, updates) return (params, opt_state), loss (final_params, _), loss_history = lax.scan( step_fn, (initial_params, opt_state), None, length=num_steps ) return final_params, loss_history optimize_model = jit(optimize_model) #powerfull when we apply montecarlo-in in 1-2 objects sample not much impact +3 sec return loss_function, optimize_model