Source code for sheap.SheaProducts.Utils.Sample_handlers

"""
Flattening and Summarizing Utilities
====================================

This module provides helpers to:
- Concatenate lists of dictionaries with array-like leaves.
- Flatten nested “masses” / parameter dictionaries to pandas DataFrames.
- Pivot nested results per object into per-object dictionaries.
- Summarize posterior samples via 16/50/84 percentiles.

Notes
-----
- All functions are NumPy/JAX friendly; arrays may be ``numpy.ndarray`` or ``jax.numpy.ndarray``.
- Percentile summaries follow a common convention:
    median = 50th percentile, err_minus = 50th - 16th, err_plus = 84th - 50th.
"""

__author__ = 'felavila'

__all__ = [
    "flatten_mass_dict",
    "flatten_mass_samples_to_df",
    "flatten_param_dict",
    "flatten_scalar_dict",
    "pivot_and_split",
    "summarize_nested_samples",
    "summarize_samples",
    "concat_dicts",
    "concat_dicts_combine"
]

from typing import Dict, Any
import pandas as pd
import warnings
from collections import defaultdict

import numpy as np 
import jax.numpy as jnp
from uncertainties import unumpy

#TODO clean up 

def concat_dicts_combine(list_of_dicts):
    """
    Merge per-line dictionaries into a single combined ``"broad"`` entry.

    Parameters
    ----------
    list_of_dicts : dict
        Despite the name, this is a mapping ``line_name -> {key: value}``
        (it is iterated with ``.items()``).

    Returns
    -------
    dict
        ``{"broad": merged}`` where ``merged`` holds, for every key found in
        the inner dictionaries, the values collected in iteration order.
        Keys whose name does not contain ``"component"`` are stacked with
        ``np.stack(..., axis=1)`` and squeezed; keys containing
        ``"component"`` stay as plain lists. ``merged["lines"]`` is the list
        of outer keys and ``merged["combined"]`` is ``True``.
    """
    line_names = []
    collected = {}
    for line_name, entry in list_of_dicts.items():
        for key, value in entry.items():
            collected.setdefault(key, []).append(value)
        line_names.append(line_name)

    # Stack array-like entries; "component" entries are kept as lists.
    merged = {}
    for key, values in collected.items():
        if "component" in key:
            merged[key] = values
        else:
            merged[key] = np.stack(values, axis=1).squeeze()

    merged["lines"] = line_names
    merged["combined"] = True
    return {"broad": dict(merged)}
def concat_dicts(list_of_dicts):
    """
    Concatenate array-like values key by key across a list of dictionaries.

    Parameters
    ----------
    list_of_dicts : list of dict
        Dictionaries whose values can be joined with ``np.concatenate``
        along their first axis.

    Returns
    -------
    dict
        For every key seen in the input, the values gathered in list order,
        concatenated with ``np.concatenate`` and then transposed (``.T``).
        An empty input list yields an empty dictionary.

    Notes
    -----
    Every value must be concatenable; non-numeric leaves should be removed
    before calling.
    """
    grouped = {}
    for entry in list_of_dicts:
        for key, value in entry.items():
            grouped.setdefault(key, []).append(value)

    return {key: np.concatenate(chunks).T for key, chunks in grouped.items()}
def flatten_mass_samples_to_df(dict_samples: Dict[str, Dict[str, Any]]) -> pd.DataFrame:
    """
    Flatten nested mass sample summaries into a tidy DataFrame.

    Parameters
    ----------
    dict_samples : dict
        Mapping ``object_name -> {"masses": {line: {quantity: stats_dict}}}``.
        Each ``stats_dict`` must have keys ``median``, ``err_minus``,
        ``err_plus``. Values may be plain Python numbers, NumPy scalars,
        0-d arrays or single-element arrays.

    Returns
    -------
    pandas.DataFrame
        Columns: ``object``, ``line``, ``quantity``, ``median``,
        ``err_minus``, ``err_plus``. The columns are present even when no
        record is produced.

    Raises
    ------
    ValueError
        If a statistic holds more than one element.

    Notes
    -----
    Entries of ``dict_samples`` that are not dictionaries, or that lack a
    ``"masses"`` key, are skipped.
    """
    columns = ["object", "line", "quantity", "median", "err_minus", "err_plus"]
    stat_keys = ("median", "err_minus", "err_plus")

    records = []
    for object_key, item in dict_samples.items():
        if not isinstance(item, dict) or "masses" not in item:
            continue
        for line_name, stats in item["masses"].items():
            for stat_name, values in stats.items():
                record = {
                    "object": object_key,
                    "line": line_name,
                    "quantity": stat_name,
                }
                # np.asarray(...).item() accepts plain Python numbers as well
                # as NumPy/JAX scalars and size-1 arrays; a bare ``.item()``
                # fails on built-in floats.
                for stat_key in stat_keys:
                    record[stat_key] = np.asarray(values[stat_key]).item()
                records.append(record)

    return pd.DataFrame(records, columns=columns)
def flatten_param_dict(dict_basic_params):
    """
    Convert a nested parameter dictionary into a tidy table.

    Parameters
    ----------
    dict_basic_params : dict
        Structure like:
        ``{kind: {"lines": [...], "component": [...],
        <param>: {"median": [...], "err_minus": [...], "err_plus": [...]}, ...}}``

    Returns
    -------
    pandas.DataFrame
        One row per (line, component, kind, parameter) with columns
        ``line_name``, ``component``, ``kind``, ``parameter``, ``median``,
        ``err_minus``, ``err_plus``. The columns are present even when no
        row is produced.

    Notes
    -----
    - Entries other than ``"lines"``/``"component"`` that are not
      dictionaries containing a ``"median"`` key (for example a boolean
      ``"combined"`` flag) are treated as metadata and skipped.
    - Missing ``err_minus`` / ``err_plus`` are filled with ``None``.
    - Rows are built with ``zip``, so the shortest of lines, components,
      medians and errors determines the number of rows per parameter.
    """
    columns = [
        "line_name", "component", "kind", "parameter",
        "median", "err_minus", "err_plus",
    ]
    metadata_keys = ("lines", "component")

    rows = []
    for kind, values in dict_basic_params.items():
        lines = values["lines"]
        components = values["component"]
        for param_name, param_values in values.items():
            if param_name in metadata_keys:
                continue
            # Only summary dictionaries can be tabulated; flags and other
            # pass-through metadata would otherwise fail on ["median"].
            if not isinstance(param_values, dict) or "median" not in param_values:
                continue

            medians = param_values["median"]
            # Build the None placeholders only when an error bar is absent.
            if "err_minus" in param_values:
                err_minus = param_values["err_minus"]
            else:
                err_minus = [None] * len(medians)
            if "err_plus" in param_values:
                err_plus = param_values["err_plus"]
            else:
                err_plus = [None] * len(medians)

            for line, comp, med, err_m, err_p in zip(
                lines, components, medians, err_minus, err_plus
            ):
                rows.append({
                    "line_name": line,
                    "component": comp,
                    "kind": kind,
                    "parameter": param_name,
                    "median": med,
                    "err_minus": err_m,
                    "err_plus": err_p,
                })

    return pd.DataFrame(rows, columns=columns)
def flatten_scalar_dict(name, scalar_dict):
    """
    Flatten a scalar-valued dictionary (e.g., L_bol/L_w summaries) into a DataFrame.

    Parameters
    ----------
    name : str
        Label for the quantity (e.g., ``"L_bol"`` or ``"L_w"``).
    scalar_dict : dict
        Mapping ``key -> {"median": scalar, "err_minus": scalar, "err_plus": scalar}``.
        Each scalar may be a plain Python number, a NumPy scalar, a 0-d
        array or a single-element array.

    Returns
    -------
    pandas.DataFrame
        Columns: ``quantity``, ``wavelength_or_line``, ``median``,
        ``err_minus``, ``err_plus``. The columns are present even when
        ``scalar_dict`` is empty.

    Raises
    ------
    ValueError
        If a statistic holds more than one element.
    """
    columns = ["quantity", "wavelength_or_line", "median", "err_minus", "err_plus"]

    rows = []
    for key, stats in scalar_dict.items():
        rows.append({
            "quantity": name,
            "wavelength_or_line": key,
            # np.asarray(...).item() also handles built-in floats/ints,
            # which have no ``.item()`` method of their own.
            "median": np.asarray(stats["median"]).item(),
            "err_minus": np.asarray(stats["err_minus"]).item(),
            "err_plus": np.asarray(stats["err_plus"]).item(),
        })

    return pd.DataFrame(rows, columns=columns)
def flatten_mass_dict(masses):
    """
    Flatten a masses dictionary into a DataFrame.

    Parameters
    ----------
    masses : dict
        Mapping ``line -> {stat_name: {"median": scalar, "err_minus": scalar, "err_plus": scalar}}``.
        Each scalar may be a plain Python number, a NumPy scalar, a 0-d
        array or a single-element array.

    Returns
    -------
    pandas.DataFrame
        Columns: ``line_name``, ``quantity``, ``median``, ``err_minus``,
        ``err_plus``. The columns are present even when ``masses`` is empty.

    Raises
    ------
    ValueError
        If a statistic holds more than one element.
    """
    columns = ["line_name", "quantity", "median", "err_minus", "err_plus"]

    rows = []
    for line, metrics in masses.items():
        for stat_name, stats in metrics.items():
            rows.append({
                "line_name": line,
                "quantity": stat_name,
                # np.asarray(...).item() also handles built-in floats/ints,
                # which have no ``.item()`` method of their own.
                "median": np.asarray(stats["median"]).item(),
                "err_minus": np.asarray(stats["err_minus"]).item(),
                "err_plus": np.asarray(stats["err_plus"]).item(),
            })

    return pd.DataFrame(rows, columns=columns)
def pivot_and_split(obj_names, result):
    """
    Split a batched result tree into one tree per object.

    The work is done in two passes:

    1) Normalize the tree once. Non-empty object-dtype arrays are passed to
       ``unumpy.nominal_values`` / ``unumpy.std_devs`` and replaced by
       ``{'median': values, 'error': std_devs}``; other NumPy arrays whose
       leading dimension equals ``len(obj_names)`` become
       ``{'median': array, 'error': 0}``.
    2) Build one slice per object from the normalized tree, without calling
       ``unumpy`` again.

    Parameters
    ----------
    obj_names : sequence
        Object names; entry ``i`` labels index ``i`` along the leading axis
        of the batched arrays.
    result : object
        Nested structure of dicts, lists and tuples with arbitrary leaves.

    Returns
    -------
    dict
        ``{name: tree}`` where ``tree`` mirrors ``result`` with every batched
        leaf replaced by ``{'median': ..., 'error': ...}`` taken at that
        object's index and squeezed. Everything else is passed through.
    """
    n_objects = len(obj_names)
    cache = {}

    def _has_batch_axis(arr):
        # True for NumPy arrays whose first axis runs over the objects.
        return isinstance(arr, np.ndarray) and arr.ndim >= 1 and arr.shape[0] == n_objects

    def _rebuild(original, items):
        # Keep tuples as tuples; lists stay lists.
        return tuple(items) if isinstance(original, tuple) else items

    def _normalize(node):
        key = id(node)
        if key in cache:
            return cache[key]

        if isinstance(node, dict):
            converted = {name: _normalize(child) for name, child in node.items()}
        elif isinstance(node, (str, int, float, type(None))):
            converted = node
        elif isinstance(node, np.ndarray) and node.dtype == object and node.size:
            # Object arrays are assumed to hold uncertainties values; if the
            # extraction fails the array is kept untouched.
            try:
                nominal = unumpy.nominal_values(node)
                spread = unumpy.std_devs(node)
                converted = {'median': nominal, 'error': spread}
            except Exception:
                converted = node
        elif _has_batch_axis(node):
            # A scalar 0 stands in for the error to avoid large zero arrays.
            converted = {'median': node, 'error': 0}
        elif isinstance(node, (list, tuple)):
            converted = _rebuild(node, [_normalize(child) for child in node])
        else:
            converted = node

        cache[key] = converted
        return converted

    def _take(node, position):
        if isinstance(node, dict):
            if 'median' in node and 'error' in node:
                center, spread = node['median'], node['error']
                if _has_batch_axis(center):
                    # The error is either a scalar or an array aligned with
                    # the median; only the aligned array is indexed.
                    if _has_batch_axis(spread):
                        spread_at = spread[position].squeeze()
                    else:
                        spread_at = spread
                    return {'median': center[position].squeeze(), 'error': spread_at}
            # Not a batched leaf: descend into the dictionary.
            return {name: _take(child, position) for name, child in node.items()}
        if isinstance(node, (list, tuple)):
            return _rebuild(node, [_take(child, position) for child in node])
        return node

    normalized = _normalize(result)
    return {
        name: _take(normalized, position)
        for position, name in enumerate(obj_names)
    }
def summarize_samples(samples) -> Dict[str, np.ndarray]:
    """
    Summarize samples by their 16th/50th/84th percentiles.

    Parameters
    ----------
    samples : array-like
        1D vector of draws, or 2D array where each column is a separate
        variable. JAX arrays are converted to NumPy first.

    Returns
    -------
    dict
        ``{"median": p50, "err_minus": p50 - p16, "err_plus": p84 - p50}``.

    Notes
    -----
    - Percentiles are always computed with ``np.nanpercentile``; a warning
      is emitted when more than 20% of the entries are NaN.
    - The input is passed through ``np.atleast_2d(...).T``. When the result
      has at most one column (1D input, or a 2D input with a single row)
      the percentiles are taken along axis 0 and each returned entry has
      length 1; otherwise they are taken along axis 1, so a 2D ``(n, m)``
      input yields length-``m`` arrays.
    """
    draws = np.asarray(samples) if isinstance(samples, jnp.ndarray) else samples
    draws = np.atleast_2d(draws).T

    nan_fraction = np.isnan(draws).sum() / draws.size
    if nan_fraction > 0.2:
        warnings.warn("High fraction of NaNs; uncertainty estimates may be biased.")

    # A single column means the draws run down axis 0; otherwise along axis 1.
    reduce_axis = 0 if draws.shape[1] <= 1 else 1
    lower, middle, upper = np.nanpercentile(draws, [16, 50, 84], axis=reduce_axis)

    return {
        "median": middle,
        "err_minus": middle - lower,
        "err_plus": upper - middle,
    }
def summarize_nested_samples(d: dict, run_summarize: bool = True) -> dict:
    """
    Recursively replace array leaves with :func:`summarize_samples` summaries.

    Parameters
    ----------
    d : dict
        Nested dictionary whose leaves may be arrays to summarize.
    run_summarize : bool, default True
        If False, ``d`` itself is returned unchanged.

    Returns
    -------
    dict
        A new dictionary with the same keys. Nested dictionaries are
        processed recursively, NumPy/JAX arrays with at least one dimension
        are replaced by their percentile summary, and every other value is
        passed through as-is.

    Notes
    -----
    Values stored under a key equal to ``"component"`` are never summarized.
    """
    if not run_summarize:
        return d

    def _convert(key, value):
        # Sub-dictionaries are walked; qualifying arrays are summarized.
        if isinstance(value, dict):
            return summarize_nested_samples(value)
        is_array = isinstance(value, (np.ndarray, jnp.ndarray))
        if is_array and np.ndim(value) >= 1 and key != 'component':
            return summarize_samples(value)
        return value

    return {key: _convert(key, value) for key, value in d.items()}