"""
Parameter / Parameters with built-in support for *shared* (global) free parameters.
Key behavior
------------
- If you do NOT use shared params:
raw_init() returns:
- (n_free,) for single spectrum
- (n_spec, n_free) for batched spectra (same as your original code)
- If you DO use shared params (any Parameter with shared=True and not fixed and not tied):
raw_init() returns a *packed 1D vector*:
raw_packed = [ raw_shared (n_shared,) | raw_local_flat (n_spec * n_local,) ]
raw_to_phys(raw_packed) returns:
- (n_spec, n_total) (batched physical vectors, with shared params identical across spectra)
phys_to_raw(phys) returns the same packed 1D vector.
Notes
-----
- In shared-mode, at least ONE local (shared=False) parameter should carry a vector value
of shape (n_spec,) so the class can infer n_spec.
- Shared free parameters should usually be provided as scalars (or 0-d arrays).
"""
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Iterable
import jax
import jax.numpy as jnp
# Positive infinity as a plain Python float; module-level shorthand for "no bound".
default_inf = float("inf")
class Parameter:
    """
    Represents a single fit parameter with optional bounds, ties, fixed status,
    and optional shared behavior across a batch (shared=True).

    Parameters
    ----------
    name : str
        Identifier of the parameter.
    value : float or array-like
        Initial physical value. jax arrays, lists, tuples and any other object
        exposing ``ndim > 0`` (e.g. a NumPy array) are stored as a jax array;
        everything else is converted with ``float()``.
    min, max : float, keyword-only
        Lower / upper bound. Infinite bounds mean "unbounded" on that side.
    tie : tuple(str, str, str, float), optional
        Tie definition ``(target, source, op, operand)``; stored as given.
    fixed : bool
        Stored as ``bool(fixed)``.
    shared : bool
        Stored as ``bool(shared)``.

    Attributes
    ----------
    transform : str
        Chosen from the finiteness of the bounds:
        ``"logistic"`` (both finite), ``"lower_bound_square"`` (only ``min`` finite),
        ``"upper_bound_square"`` (only ``max`` finite) or ``"linear"`` (unbounded).
    """

    def __init__(
        self,
        name: str,
        value: Union[float, jnp.ndarray, List[float], Tuple[float, ...]],
        *,
        min: float = -math.inf,
        max: float = math.inf,
        tie: Optional[Tuple[str, str, str, float]] = None,
        fixed: bool = False,
        shared: bool = False,
    ):
        self.name = name

        # Lists, tuples and jax arrays were always stored as jax arrays. Other
        # array-likes with at least one dimension (NumPy arrays in particular)
        # used to hit float(value) and raise; treat them the same way instead.
        is_sequence = isinstance(value, (jnp.ndarray, list, tuple))
        is_nd_array_like = getattr(value, "ndim", 0) > 0
        if is_sequence or is_nd_array_like:
            self.value = jnp.array(value)
        else:
            self.value = float(value)

        self.min = float(min)
        self.max = float(max)
        self.tie = tie
        self.fixed = bool(fixed)
        self.shared = bool(shared)

        # The transform only depends on which of the two bounds are finite.
        has_lower = math.isfinite(self.min)
        has_upper = math.isfinite(self.max)
        if has_lower and has_upper:
            self.transform = "logistic"
        elif has_lower:
            self.transform = "lower_bound_square"
        elif has_upper:
            self.transform = "upper_bound_square"
        else:
            self.transform = "linear"

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(name={self.name!r}, value={self.value!r}, "
            f"min={self.min!r}, max={self.max!r}, tie={self.tie!r}, "
            f"fixed={self.fixed!r}, shared={self.shared!r})"
        )
class Parameters:
    r"""
    Container for managing a list of `Parameter` instances.

    Supports:
      - fixed params (excluded from optimization)
      - tied params (computed from sources)
      - bounded transforms (logistic / square / linear)
      - batched parameters (values as arrays)
      - shared free parameters (shared=True) as global hyper-parameters for batched fits,
        with packing/unpacking handled internally.

    Notes
    -----
    - If *any* parameter has an array value with shape (n_spec,), we treat the container as batched.
    - In shared-mode (at least one free shared parameter), raw space is a packed 1D vector:
          [raw_shared..., raw_local(spec0)... raw_local(spec1)...]
      and `raw_to_phys` returns a full physical array of shape (n_spec, n_total).
    - IMPORTANT: In shared-mode, `raw_to_phys` broadcasts shared parameters across spectra in the
      returned `phys` array, so `phys[:, idx_shared]` is always shape (n_spec,) and can be fed into
      `vmap(lambda A,m,s: ...)` directly (shared means identical values across the batch).
    - Role lists and jitted converters are built lazily on first use and discarded by `add`.
    """

    # Arithmetic operators accepted in a tie definition ``(target, source, op, operand)``.
    _TIE_OPS = {"*": jnp.multiply, "+": jnp.add, "-": jnp.subtract, "/": jnp.divide}

    def __init__(self):
        self._list: List["Parameter"] = []
        self._jit_raw_to_phys = None
        self._jit_phys_to_raw = None
        self._jit_raw_to_phys_mixed = None
        self._jit_phys_to_raw_mixed = None
        # caches set in _finalize
        self._raw_list = None
        self._tied_list = None
        self._fixed_list = None
        self._raw_shared_list = None
        self._raw_local_list = None
        self._n_spectra = None

    # -------------------------
    # mutation / cache invalidation
    # -------------------------
    def add(
        self,
        name: str,
        value: Union[float, jnp.ndarray, List[float], Tuple[float, ...]],
        *,
        min: Optional[float] = None,
        max: Optional[float] = None,
        tie: Optional[Tuple[str, str, str, float]] = None,
        fixed: bool = False,
        shared: bool = False,
    ):
        """
        Append a new `Parameter` and drop every cached list / jitted function.

        ``min`` / ``max`` of ``None`` mean "unbounded" on that side.
        """
        lo = -jnp.inf if min is None else float(min)
        hi = jnp.inf if max is None else float(max)
        self._list.append(
            Parameter(
                name=name,
                value=value,
                min=lo,
                max=hi,
                tie=tie,
                fixed=fixed,
                shared=shared,
            )
        )
        self._invalidate()

    def _invalidate(self):
        """Reset all caches so the next use re-runs `_finalize`."""
        self._jit_raw_to_phys = None
        self._jit_phys_to_raw = None
        self._jit_raw_to_phys_mixed = None
        self._jit_phys_to_raw_mixed = None
        self._raw_list = None
        self._tied_list = None
        self._fixed_list = None
        self._raw_shared_list = None
        self._raw_local_list = None
        self._n_spectra = None

    # -------------------------
    # shape inference / finalize
    # -------------------------
    def _infer_n_spectra(self) -> int:
        """
        Return the batch length implied by the array-valued parameters.

        Returns 1 when no parameter holds an array with at least one dimension.

        Raises
        ------
        ValueError
            If array-valued parameters disagree on their leading dimension.
        """
        lens = [
            int(p.value.shape[0])
            for p in self._list
            if isinstance(p.value, jnp.ndarray) and p.value.ndim > 0
        ]
        if not lens:
            return 1
        if len(set(lens)) != 1:
            raise ValueError(f"Inconsistent batch lengths in parameters: {sorted(set(lens))}")
        return lens[0]

    @property
    def names(self) -> List[str]:
        """Parameter names in insertion order."""
        return [p.name for p in self._list]

    def _finalize(self):
        """Partition parameters by role, infer the batch size and build the jitted cores."""
        # partition params by role
        self._raw_list = [p for p in self._list if p.tie is None and not p.fixed]
        self._tied_list = [p for p in self._list if p.tie is not None and not p.fixed]
        self._fixed_list = [p for p in self._list if p.fixed]
        # split free params into shared vs local
        self._raw_shared_list = [p for p in self._raw_list if p.shared]
        self._raw_local_list = [p for p in self._raw_list if not p.shared]
        # infer batch size from ANY vector-valued parameter
        self._n_spectra = self._infer_n_spectra()
        # NOTE: shared free params in a single-spectrum run are allowed; the packed
        # layout is then simply [raw_shared | raw_local] for one spectrum.
        # JIT compile both "classic" and "mixed" cores
        self._jit_raw_to_phys = jax.jit(self._raw_to_phys_core)
        self._jit_phys_to_raw = jax.jit(self._phys_to_raw_core)
        self._jit_raw_to_phys_mixed = jax.jit(self._raw_to_phys_core_mixed)
        self._jit_phys_to_raw_mixed = jax.jit(self._phys_to_raw_core_mixed)

    def _ensure_finalized(self):
        """Run `_finalize` if the caches are empty (first use, or after `add`)."""
        if self._jit_raw_to_phys is None:
            self._finalize()

    def _has_shared_free(self) -> bool:
        """True when at least one free (not fixed, not tied) parameter is shared."""
        self._ensure_finalized()
        return len(self._raw_shared_list) > 0

    def _mixed_sizes(self) -> Tuple[int, int, int]:
        """
        Returns (n_spec, n_shared_free, n_local_free) for shared-mode.
        """
        self._ensure_finalized()
        n_spec = int(self._n_spectra) if self._n_spectra is not None else 1
        n_shared = len(self._raw_shared_list)
        n_local = len(self._raw_local_list)
        return n_spec, n_shared, n_local

    # -------------------------
    # transforms (helpers)
    # -------------------------
    @staticmethod
    def _apply_transform(p: "Parameter", rv: jnp.ndarray) -> jnp.ndarray:
        """Map a raw (unconstrained) value to physical space according to ``p.transform``."""
        if p.transform == "logistic":
            return p.min + (p.max - p.min) * jax.nn.sigmoid(rv)
        elif p.transform == "lower_bound_square":
            return p.min + rv**2
        elif p.transform == "upper_bound_square":
            return p.max - rv**2
        else:
            return rv

    @staticmethod
    def _inv_transform(p: "Parameter", vv: jnp.ndarray) -> jnp.ndarray:
        """Map a physical value back to raw space (inverse of `_apply_transform`)."""
        if p.transform == "logistic":
            frac = (vv - p.min) / (p.max - p.min)
            # keep the logit finite when the value sits on (or outside) a bound
            frac = jnp.clip(frac, 1e-6, 1 - 1e-6)
            return jnp.log(frac / (1 - frac))
        elif p.transform == "lower_bound_square":
            # values below the bound are clamped onto it
            return jnp.sqrt(jnp.maximum(vv - p.min, 0))
        elif p.transform == "upper_bound_square":
            # values above the bound are clamped onto it
            return jnp.sqrt(jnp.maximum(p.max - vv, 0))
        else:
            return vv

    def _fill_fixed_and_tied(self, ctx: Dict[str, Any], spec_idx) -> None:
        """
        Add fixed and tied parameter values for one spectrum to ``ctx`` (in place).

        ``ctx`` must already contain every free parameter, because ties read
        their source value from it.
        """
        for p in self._fixed_list:
            v = p.value
            # Only arrays with a batch axis are indexed; floats and 0-d arrays
            # are used as-is (indexing a 0-d array would raise).
            per_spectrum = isinstance(v, jnp.ndarray) and v.ndim > 0
            ctx[p.name] = v[spec_idx] if per_spectrum else v
        for p in self._tied_list:
            tgt, src, op, operand = p.tie
            ctx[tgt] = self._TIE_OPS[op](ctx[src], operand)

    # -------------------------
    # public API
    # -------------------------
    def raw_init(self) -> jnp.ndarray:
        """
        Initial raw vector from stored physical values.

        Returns the classic layout without shared free parameters, and the
        packed 1D layout otherwise.
        """
        self._ensure_finalized()
        init_phys = self.phys_init()
        # Dispatch to the right inverse-map
        if not self._has_shared_free():
            return self._jit_phys_to_raw(init_phys)
        if init_phys.ndim == 1:
            init_phys = init_phys[None, :]
        return self._jit_phys_to_raw_mixed(init_phys)

    def phys_init(self) -> jnp.ndarray:
        """
        Build the initial physical parameter array from stored `Parameter.value`.

        Returns
        -------
        jnp.ndarray
            - If n_spec == 1: shape (n_total,)
            - If n_spec > 1: shape (n_spec, n_total)

        Notes
        -----
        - Scalar values are broadcast across spectra.
        - Vector values (shape (n_spec,)) are taken per spectrum.
        - For n_spec == 1, length-1 arrays are collapsed to scalars so the
          result is always 1-D.
        """
        self._ensure_finalized()
        n_spec = int(self._n_spectra)
        if not self._list:
            return jnp.array([])
        if n_spec == 1:
            # Values may be floats, 0-d arrays or length-1 arrays; reshape each
            # to a scalar so mixing them still yields shape (n_total,).
            scalars = [jnp.reshape(jnp.asarray(p.value), ()) for p in self._list]
            return jnp.stack(scalars)
        # One column per parameter: scalars are broadcast, vectors used as given.
        columns = [jnp.broadcast_to(jnp.asarray(p.value), (n_spec,)) for p in self._list]
        return jnp.stack(columns, axis=1)  # (n_spec, n_total)

    def raw_to_phys(self, raw_params: jnp.ndarray) -> jnp.ndarray:
        """
        Convert raw parameter vector(s) to physical space.

        - No shared free params: classic behavior.
        - Shared free params: expects packed 1D raw and returns (n_spec, n_total),
          with shared params broadcast across spectra in the returned `phys`.
        """
        self._ensure_finalized()
        raw_params = jnp.asarray(raw_params)
        if not self._has_shared_free():
            return self._jit_raw_to_phys(raw_params)
        return self._jit_raw_to_phys_mixed(raw_params)

    def phys_to_raw(self, phys_params: jnp.ndarray) -> jnp.ndarray:
        """
        Convert physical parameter vector(s) to raw space.

        - No shared free params: classic behavior.
        - Shared free params: expects (n_spec, n_total) and returns packed 1D raw.
          A 1-D input is treated as a batch of one.
        """
        self._ensure_finalized()
        phys_params = jnp.asarray(phys_params)
        if not self._has_shared_free():
            return self._jit_phys_to_raw(phys_params)
        if phys_params.ndim == 1:
            phys_params = phys_params[None, :]
        return self._jit_phys_to_raw_mixed(phys_params)

    # -------------------------
    # classic cores (original behavior)
    # -------------------------
    def _raw_to_phys_core(self, raw: jnp.ndarray) -> jnp.ndarray:
        """Classic raw->phys: (n_free,) -> (n_total,) or (N, n_free) -> (N, n_total)."""

        def convert_one(r_vec, spec_idx):
            ctx: Dict[str, jnp.ndarray] = {
                p.name: self._apply_transform(p, r_vec[k])
                for k, p in enumerate(self._raw_list)
            }
            self._fill_fixed_and_tied(ctx, spec_idx)
            return jnp.stack([ctx[p.name] for p in self._list])

        if raw.ndim == 1:
            return convert_one(raw, 0)
        spec_ids = jnp.arange(raw.shape[0])
        return jax.vmap(convert_one, in_axes=(0, 0))(raw, spec_ids)

    def _phys_to_raw_core(self, phys: jnp.ndarray) -> jnp.ndarray:
        """Classic phys->raw: keeps only the free parameters, inverse-transformed."""

        def invert_one(v_vec):
            by_name = {p.name: v_vec[i] for i, p in enumerate(self._list)}
            raws = [self._inv_transform(p, by_name[p.name]) for p in self._raw_list]
            # no free parameter -> empty raw vector (same convention as the mixed core)
            return jnp.stack(raws) if raws else jnp.zeros((0,), dtype=phys.dtype)

        if phys.ndim == 1:
            return invert_one(phys)
        return jax.vmap(invert_one)(phys)

    # -------------------------
    # mixed cores (shared + local packed raw, internal)
    # -------------------------
    def _raw_to_phys_core_mixed(self, raw_packed: jnp.ndarray) -> jnp.ndarray:
        """
        Shared-mode raw->phys.

        Input
        -----
        raw_packed : shape (n_shared + n_spec*n_local,)

        Output
        ------
        phys : shape (n_spec, n_total)

        Guarantee
        ---------
        In the returned `phys`, any shared parameter column is broadcast to all spectra,
        so `phys[:, idx_shared]` has shape (n_spec,) and identical values.
        """
        n_spec, n_shared, n_local = self._mixed_sizes()
        if n_shared == 0:
            return self._raw_to_phys_core(raw_packed)

        raw_shared = raw_packed[:n_shared]  # (n_shared,)
        raw_local = raw_packed[n_shared:].reshape(n_spec, n_local)  # (n_spec, n_local)

        # Compute shared phys scalars once; every spectrum row reuses them.
        shared_vals: Dict[str, jnp.ndarray] = {
            p.name: self._apply_transform(p, raw_shared[j])
            for j, p in enumerate(self._raw_shared_list)
        }

        def convert_one(r_loc, spec_idx):
            # shared free params: same scalar for every spectrum
            ctx: Dict[str, jnp.ndarray] = dict(shared_vals)
            # local free params for this spectrum
            for j, p in enumerate(self._raw_local_list):
                ctx[p.name] = self._apply_transform(p, r_loc[j])
            # fixed params (may be per-spec arrays) and tied params
            self._fill_fixed_and_tied(ctx, spec_idx)
            return jnp.stack([ctx[p.name] for p in self._list])

        spec_ids = jnp.arange(n_spec)
        phys = jax.vmap(convert_one, in_axes=(0, 0))(raw_local, spec_ids)  # (n_spec, n_total)

        # Broadcast shared columns explicitly to guarantee `phys[:, idx_shared]` is (n_spec,)
        # and identical across rows (even if later logic changes).
        all_names = self.names
        for p in self._raw_shared_list:
            col = all_names.index(p.name)
            phys = phys.at[:, col].set(jnp.full((n_spec,), shared_vals[p.name]))
        return phys

    def _phys_to_raw_core_mixed(self, phys: jnp.ndarray) -> jnp.ndarray:
        """
        Shared-mode phys->raw (inverse of packed layout).

        Shared raws are read from the first spectrum only; local raws are
        computed per spectrum and flattened row by row.
        """
        n_spec, n_shared, n_local = self._mixed_sizes()
        if n_shared == 0:
            return self._phys_to_raw_core(phys)

        def ctx_from_phys(v_vec):
            return {p.name: v_vec[i] for i, p in enumerate(self._list)}

        # shared raws: read from first spectrum
        first = ctx_from_phys(phys[0])
        shared_raws = [self._inv_transform(p, first[p.name]) for p in self._raw_shared_list]
        shared_raws = jnp.stack(shared_raws) if shared_raws else jnp.zeros((0,), dtype=phys.dtype)

        # local raws per spectrum
        def local_raws_one(v_vec):
            ctx = ctx_from_phys(v_vec)
            raws = [self._inv_transform(p, ctx[p.name]) for p in self._raw_local_list]
            return jnp.stack(raws) if raws else jnp.zeros((0,), dtype=phys.dtype)

        local_raws = jax.vmap(local_raws_one)(phys)  # (n_spec, n_local)
        return jnp.concatenate([shared_raws, local_raws.ravel()])

    # -------------------------
    # convenience
    # -------------------------
    @property
    def specs(self) -> List[Tuple[str, Any, float, float, str, bool, bool]]:
        """
        (name, value, min, max, transform, fixed, shared)
        """
        return [(p.name, p.value, p.min, p.max, p.transform, p.fixed, p.shared) for p in self._list]

    @property
    def summary(self) -> List[Dict[str, Any]]:
        """
        User-facing summary of parameters and definitions.

        One dict per parameter. Array values are reported as a dict with
        ``shape``, ``dtype`` and a ``preview`` (first five entries of a 1-D
        array longer than five, otherwise the full array).
        """
        self._ensure_finalized()
        rows: List[Dict[str, Any]] = []
        for p in self._list:
            # fixed wins over tied, matching the role partition in _finalize
            if p.fixed:
                status = "fixed"
            elif p.tie is not None:
                status = "tied"
            else:
                status = "free"
            v = p.value
            if isinstance(v, jnp.ndarray):
                v_out: Any = {
                    "shape": tuple(v.shape),
                    "dtype": str(v.dtype),
                    "preview": v[:5] if (v.ndim == 1 and v.size > 5) else v,
                }
            else:
                v_out = float(v)
            rows.append(
                {
                    "name": p.name,
                    "value": v_out,
                    "min": float(p.min),
                    "max": float(p.max),
                    "transform": p.transform,
                    "fixed": bool(p.fixed),
                    "shared": bool(p.shared),
                    "tie": p.tie,
                    "status": status,
                }
            )
        return rows
def build_Parameters(
    tied_map: Dict[int, Tuple[int, str, float]],
    params_dict: Dict[str, int],
    initial_params: Iterable[float],
    constraints: jnp.ndarray,
) -> "Parameters":
    r"""
    Construct a :class:`Parameters` object from initialization arrays, constraints,
    and tie definitions.

    This helper builds a container of :class:`Parameter` instances ready for fitting,
    applying bounds and tied relationships.

    Parameters
    ----------
    tied_map : dict[int, tuple[int, str, float]]
        Mapping of parameter indices to tie definitions.
        Each entry is of the form
        ``idx_target -> (idx_source, op, operand)``, where:

        * ``idx_target`` is the index of the tied parameter,
        * ``idx_source`` is the index of the source parameter,
        * ``op`` is an arithmetic operator string (``'*'``, ``'/'``, ``'+'``, ``'-'``),
        * ``operand`` is a numeric factor or offset.
    params_dict : dict[str, int]
        Dictionary mapping parameter names to their index positions
        in the parameter vector.
    initial_params : array-like, shape (n_params,) or (n_spec, n_params)
        Initial physical parameter values. A 1-D input is treated as a single
        row, so every parameter value is stored as an array of shape (n_spec,)
        (with ``n_spec == 1`` for 1-D input).
    constraints : array-like, shape (n_params, 2)
        Lower and upper bounds per parameter.

    Returns
    -------
    Parameters
        A populated container with names, values, bounds, and tie definitions.

    Raises
    ------
    IndexError
        If a tie refers to a source index that is neither a value of
        ``params_dict`` nor a valid position in it.

    Notes
    -----
    * Tied parameters are added with their ``tie`` attribute and are not optimized
      directly; their values are reconstructed from the source parameter during
      raw→physical mapping.
    * Untied parameters are added with their initial value and bounds taken from
      ``constraints``.
    * Fixed parameters are not assigned here; add them via
      :meth:`Parameters.add(..., fixed=True) <Parameters.add>` if needed.
    * The tie source name is resolved through the index values stored in
      ``params_dict`` rather than through the dict's insertion order, so the two
      do not have to coincide.
    """
    params_obj = Parameters()

    # (n_spec, n_params) view, computed once instead of once per parameter.
    init_2d = jnp.atleast_2d(jnp.asarray(initial_params))

    # Ties refer to parameters by index value, so invert name -> index once.
    name_by_index = {idx: name for name, idx in params_dict.items()}
    ordered_names = list(params_dict)

    for name, idx in params_dict.items():
        val = init_2d[:, idx]
        min_val, max_val = constraints[idx]

        tie = None
        if idx in tied_map:
            src_idx, op, operand = tied_map[idx]
            if src_idx in name_by_index:
                src_name = name_by_index[src_idx]
            else:
                # Backward-compatible fallback: positional lookup in insertion
                # order (raises IndexError for an invalid position).
                src_name = ordered_names[src_idx]
            tie = (name, src_name, op, operand)

        params_obj.add(name, val, min=min_val, max=max_val, tie=tie)
    return params_obj