Source code for sheap.MasterSampler.Samplers.PseudoMonteCarloSampler

"""
Pseudo Monte Carlo Sampler
==========================

.. note::
   the docs require update now.

?
"""
from __future__ import annotations
__author__ = 'felavila'

__all__ = ["PseudoMonteCarloSampler"]

from typing import Callable, Optional, Tuple, Any, Dict, Dict, List
#import warning 

import jax.numpy as jnp
from jax import vmap, random,jacobian
import jax

import numpy as np 

from typing import Any, Callable, Dict, Optional, Tuple


from sheap.Assistants.parser_mapper import descale_amp,scale_amp,apply_tied_and_fixed_params,make_get_param_coord_value,build_tied,parse_dependencies,flatten_tied_map
from sheap.Assistants.Parameters import build_Parameters
from sheap.SheaProducts.SheaProducts import SheaProducts


[docs] class PseudoMonteCarloSampler: """ Laplace (pseudo Monte Carlo) sampler in PHYSICAL space, robust and batched. ? """ def __init__(self, estimator: Any, dtype=jnp.float32): self.estimator = estimator self.dtype = dtype self.SheaProducts = SheaProducts(estimator) self.obj_params = getattr(estimator, "obj_params", None) or getattr(estimator, "params", None) self.names = getattr(estimator, "names", None) self.constraints = getattr(estimator, "constraints", None) self.params_dict = getattr(estimator, "params_dict", None) self.scale = getattr(estimator, "scale", None) #self.constraints = if self.obj_params is None or isinstance(self.obj_params, jnp.ndarray): self.initial_params = getattr(estimator, "initial_params", None) #self.constraints self.get_param_coord_value = make_get_param_coord_value(self.params_dict, self.initial_params) self.fitkwargs = getattr(estimator, "fitkwargs", None) self.obj_params = self._make_params_obj() if not (hasattr(self.obj_params, "raw_to_phys") and hasattr(self.obj_params, "phys_to_raw")): raise TypeError("obj_params must provide raw_to_phys(x) and phys_to_raw(x).") params_phys = getattr(estimator.result, "params", None) #params_dict, params, scale #(params_dict, params, scale) params_phys = descale_amp(self.params_dict,params_phys,self.scale) if params_phys is None: raise RuntimeError("estimator.result.params not found (expected physical mode).") mode = jnp.asarray(params_phys, dtype=self.dtype) self.mode_phys, self.is_batched = self._ensure_batched(mode) # (N, D), bool
[docs] def sample_params( self, num_samples: int, key_seed: int = 0, cov_phys: Optional[jnp.ndarray] = None, residuals_fn_phys: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, eps: float = 1e-8, summarize: bool = True, # kept for compatibility; not used here ) -> Dict[str, Any]: """ Returns ------- dict with keys: - "samples_raw": (S, D) or (N, S, D) - "samples_phys": (S, D) or (N, S, D) (always included) - "products": postprocess_fn(samples_phys) if provided """ key = random.PRNGKey(key_seed) lo_b, hi_b = self._get_bounds_phys_broadcast() # (N, D) covP = self._get_cov_phys_broadcast(cov_phys, residuals_fn_phys) # (N, D, D) phys_samples = self._laplace_phys_samples_safe_batched( mean_phys=self.mode_phys, cov_phys=covP, lo_phys=lo_b, hi_phys=hi_b, key=key, num_samples=num_samples, eps=eps,) #bottle neck. dic_posterior_params = {} for n,name_i in enumerate(self.names): if n % 100 == 0: print(f"{n} of {len(self.names)}") full_samples = scale_amp(self.params_dict,phys_samples[n],np.array(self.scale[n])) dic_posterior_params[name_i] = self.SheaProducts.extract_params(full_samples,n,summarize=summarize) dic_posterior_params[name_i].update({"samples_phys":full_samples}) return dic_posterior_params
def _get_cov_phys_broadcast( self, cov_phys: Optional[jnp.ndarray], residuals_fn_phys: Optional[Callable[[jnp.ndarray], jnp.ndarray]], ) -> jnp.ndarray: """ Returns (N, D, D). Broadcasts shared (D, D) to all N. """ if cov_phys is not None: cov = jnp.asarray(cov_phys, dtype=self.dtype) else: maybe_cov = getattr(self.estimator.result, "cov_phys", None) cov = jnp.asarray(maybe_cov, dtype=self.dtype) if maybe_cov is not None else None if cov is not None: if cov.ndim == 2: cov = jnp.broadcast_to(cov, (self.mode_phys.shape[0],) + cov.shape) # (N, D, D) elif cov.ndim == 3: if cov.shape[0] != self.mode_phys.shape[0]: raise ValueError(f"cov_phys batch dim {cov.shape[0]} != N {self.mode_phys.shape[0]}") else: raise ValueError("cov_phys must be (D,D) or (N,D,D)") return self._symmetrize_jitter_batched(cov) # Fallback: compute via Gauss–Newton per object if residuals_fn_phys is None: residuals_fn_phys = getattr(self.estimator, "residuals_fn_phys", None) if residuals_fn_phys is None: # Tiny-diagonal per object N, D = self.mode_phys.shape diag_vecs = jnp.ones((N, D), dtype=self.dtype) * 1e-6 cov = vmap(jnp.diag, in_axes=0, out_axes=0)(diag_vecs) # (N, D, D) return self._symmetrize_jitter_batched(cov) def _gn(x_phys: jnp.ndarray) -> jnp.ndarray: J = jacobian(residuals_fn_phys)(x_phys) # (L, D) H = J.T @ J H = 0.5 * (H + H.T) + jnp.eye(H.shape[0], dtype=H.dtype) * 1e-8 return jnp.linalg.pinv(H, rcond=1e-12) cov = vmap(_gn, in_axes=0)(self.mode_phys) # (N, D, D) return self._symmetrize_jitter_batched(cov) def _symmetrize_jitter_batched(self, cov: jnp.ndarray, jitter: float = 0.0) -> jnp.ndarray: cov = 0.5 * (cov + jnp.swapaxes(cov, -1, -2)) if jitter > 0: I = jnp.eye(cov.shape[-1], dtype=cov.dtype) cov = cov + I[None, ...] * jitter return cov def _raw_to_phys(self, x: jnp.ndarray) -> jnp.ndarray: return self.obj_params.raw_to_phys(x) def _phys_to_raw(self, x: jnp.ndarray) -> jnp.ndarray: return self.obj_params.phys_to_raw(x) def _get_bounds_phys(self) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Build (lo, hi) in physical space from Parameters.specs (preferred), else from estimator.result.constraints, else unbounded. """ if hasattr(self.obj_params, "specs") and self.obj_params.specs is not None: specs = self.obj_params.specs los, his = [], [] for s in specs: if isinstance(s, dict): lo = s.get("min", None) hi = s.get("max", None) else: lo = getattr(s, "min", None) hi = getattr(s, "max", None) los.append(-jnp.inf if lo is None else lo) his.append( jnp.inf if hi is None else hi) return jnp.asarray(los, self.dtype), jnp.asarray(his, self.dtype) res_constraints = getattr(self.estimator.result, "constraints", None) names = getattr(self.obj_params, "names", None) if (res_constraints is not None) and (names is not None): los, his = [], [] for nm in names: Limit = res_constraints.get(nm, {}) if isinstance(res_constraints, dict) else {} lo = Limit.get("min", None) hi = Limit.get("max", None) los.append(-jnp.inf if lo is None else lo) his.append( jnp.inf if hi is None else hi) return jnp.asarray(los, self.dtype), jnp.asarray(his, self.dtype) D = self.mode_phys.shape[-1] return (jnp.full((D,), -jnp.inf, dtype=self.dtype), jnp.full((D,), jnp.inf, dtype=self.dtype)) def _get_bounds_phys_broadcast(self) -> Tuple[jnp.ndarray, jnp.ndarray]: lo, hi = self._get_bounds_phys() # (D,), (D,) N, D = self.mode_phys.shape return jnp.broadcast_to(lo, (N, D)), jnp.broadcast_to(hi, (N, D)) @staticmethod def _make_psd(cov: jnp.ndarray, min_eig: float = 1e-12) -> jnp.ndarray: cov = 0.5 * (cov + jnp.swapaxes(cov, -1, -2)) w, V = jnp.linalg.eigh(cov) # w: (N,D), V: (N,D,D) cov_psd = (V * w[..., None, :]) @ jnp.swapaxes(V, -1, -2) return cov_psd @classmethod def _safe_cholesky_batched(cls, cov: jnp.ndarray, jitter0: float = 1e-10, max_tries: int = 6) -> jnp.ndarray: def chol_one(C: jnp.ndarray) -> jnp.ndarray: C = 0.5 * (C + C.T) # try direct try: return jnp.linalg.cholesky(C) except Exception: pass # jitter backoff j = jitter0 for _ in range(max_tries): try: return jnp.linalg.cholesky(C + j * jnp.eye(C.shape[0], dtype=C.dtype)) except Exception: j *= 10.0 C_psd = PseudoMonteCarloSampler._make_psd(C[None, ...])[0] return jnp.linalg.cholesky(C_psd + 1e-8 * jnp.eye(C.shape[0], dtype=C.dtype)) return vmap(chol_one, in_axes=0)(cov) # (N, D, D) @classmethod def _laplace_phys_samples_safe_batched( cls, mean_phys: jnp.ndarray, # (N, D) cov_phys: jnp.ndarray, # (N, D, D) lo_phys: jnp.ndarray, # (N, D) hi_phys: jnp.ndarray, # (N, D) key: jax.Array, num_samples: int, eps: float = 1e-8, ) -> jnp.ndarray: N, D = mean_phys.shape L = cls._safe_cholesky_batched(cov_phys) # (N, D, D) # Random normals per object → (N, S, D) subkeys = random.split(key, N) # (N, 2) def draw_one(k: jax.Array, mu: jnp.ndarray, Lm: jnp.ndarray) -> jnp.ndarray: z = random.normal(k, (num_samples, D), dtype=mu.dtype) return mu[None, :] + z @ jnp.swapaxes(Lm, -1, -2) phys = vmap(draw_one, in_axes=(0, 0, 0))(subkeys, mean_phys, L) # (N, S, D) # Enforce bounds per object reflect = vmap(PseudoMonteCarloSampler._reflect_into_bounds, in_axes=(0, 0, 0)) push = vmap(PseudoMonteCarloSampler._push_interior, in_axes=(0, 0, 0, None)) sanitize= vmap(PseudoMonteCarloSampler._sanitize_nonfinite, in_axes=(0, 0)) phys = reflect(phys, lo_phys[:, None, :], hi_phys[:, None, :]) # (N, S, D) phys = push(phys, lo_phys[:, None, :], hi_phys[:, None, :], eps) # NOTE: eps is NOT mapped phys = sanitize(phys, mean_phys[:, None, :]) return phys @staticmethod def _reflect_into_bounds(x: jnp.ndarray, lo: jnp.ndarray, hi: jnp.ndarray) -> jnp.ndarray: finite_lo = jnp.isfinite(lo) finite_hi = jnp.isfinite(hi) width = hi - lo finite_interval = jnp.logical_and(finite_lo, finite_hi) & (width > 0) t = (x - lo) / jnp.where(finite_interval, width, 1.0) t_wrapped = t - jnp.floor(t / 2.0) * 2.0 x_reflected = jnp.where( t_wrapped <= 1.0, lo + t_wrapped * width, hi - (t_wrapped - 1.0) * width, ) x = jnp.where(finite_interval, x_reflected, x) x = jnp.where(jnp.logical_and(finite_lo, ~finite_hi), jnp.maximum(x, lo), x) x = jnp.where(jnp.logical_and(~finite_lo, finite_hi), jnp.minimum(x, hi), x) return x @staticmethod def _push_interior(x: jnp.ndarray, lo: jnp.ndarray, hi: jnp.ndarray, eps: float = 1e-8) -> jnp.ndarray: finite_lo = jnp.isfinite(lo) finite_hi = jnp.isfinite(hi) both = jnp.logical_and(finite_lo, finite_hi) lo_i = jnp.where(both, lo + eps, lo) hi_i = jnp.where(both, hi - eps, hi) return jnp.minimum(jnp.maximum(x, lo_i), hi_i) @staticmethod def _sanitize_nonfinite(x: jnp.ndarray, fallback: jnp.ndarray) -> jnp.ndarray: bad = ~jnp.isfinite(x) return jnp.where(bad, fallback, x) @staticmethod def _ensure_batched(mode_phys: jnp.ndarray) -> Tuple[jnp.ndarray, bool]: if mode_phys.ndim == 1: return mode_phys[None, :], False if mode_phys.ndim == 2: return mode_phys, True raise ValueError("mode_phys must be (D,) or (N, D)") def _make_params_obj(self): list_dependencies = parse_dependencies( build_tied(self.fitkwargs[-1]["tied"], self.get_param_coord_value) ) tied_map = {T[1]: T[2:] for T in list_dependencies} tied_map = flatten_tied_map(tied_map) return build_Parameters(tied_map, self.params_dict, self.initial_params, self.constraints)