"""
Core Data Structures
====================
This module defines the core data classes used across **sheap** to describe
spectral lines, grouped regions, fitting outputs, per‑profile constraints,
and per‑kind fitting limits.
Exposed classes
---------------
- :class:`SpectralLine` — a single (or composite) emission/absorption component.
- :class:`SheapModel` — a container of lines with profile functions, parameter
maps, and convenient subsetting/grouping utilities.
- :class:`SheapResult` — a structured record of a completed fit (parameters,
uncertainties, residuals, χ², etc.).
- :class:`ProfileConstraintSet` — per‑profile initial values and bounds.
- :class:`FittingLimits` — canonical velocity/shift/amplitude limits by kind.
Main Features
-------------
- Dataclass APIs with typed fields and `.to_dict()` helpers.
- Region‑level table view (`as_df()`), filtering, grouping, and safe subsetting
that preserves global↔local parameter index mappings.
- Lazy assembly of fused profile functions for fast evaluation with JAX.
- Seamless attachment of fitted parameter matrices and uncertainties.
Notes
-----
- Arrays may be NumPy or JAX arrays; fused profile evaluation is JAX‑friendly.
- Global parameter indices are preserved when subsetting regions so results can
be traced back to the original packed parameter vector.
"""
from __future__ import annotations
__author__ = 'felavila'
__all__ = [
"SheapModel",
"SheapResult",
"FittingLimits",
"ProfileConstraintSet",
"SpectralLine",
]
from dataclasses import asdict, dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import jax.numpy as jnp
import numpy as np
import pandas as pd
from sheap.Profiles.Utils import make_fused_profiles
[docs]
@dataclass
class SpectralLine:
"""
Represents a single spectral emission or absorption line component.
Parameters
----------
line_name : str or list of str
Identifier(s) for the spectral line (e.g., 'Halpha'), or for a composite region
the region name plus component number.
center : float or list of float, optional
Central wavelength(s) of the line in Angstroms.
region : str, optional
Spatial region of the line, one of 'narrow', 'broad', 'outflow', or 'fe'.
component : int, optional
Integer identifier for this component within its region.
subregion : str, optional
Element + spatial subregion tag, useful for complex templates (e.g. FeII sub‐regions).
amplitude : float or list of float, optional
Initial or fixed amplitude(s) for the line(s).
element : str, optional
Chemical identifier of the line (e.g., 'H', 'FeII').
profile : str, optional
Name of the profile function to use ('gaussian', 'lorentzian', etc.).
region_lines : list of str, optional
Explicit list of line identifiers included in a composite region.
amplitude_relations : list of list, optional
Parameter‐tying definitions (e.g. fixed ratios) among amplitudes.
subprofile : str, optional
Sub‐profile name for compound models (e.g. a secondary kernel).
rarity : str or list of str, optional
Qualitative frequency label for the line (e.g. 'common', 'rare').
template_info : dict, optional
Additional template metadata (e.g. for 'hostmiles' or 'fetemplate' profiles).
Attributes
----------
(all parameters become attributes of this dataclass)
Methods
-------
to_dict()
Convert the SpectralLine instance into a plain dictionary via `asdict`.
Examples
--------
>>> line = SpectralLine(
... line_name='Halpha',
... center=6563.0,
... region='narrow',
... component=0,
... profile='gaussian',
... amplitude=1.0
... )
>>> d = line.to_dict()
>>> print(d['center'])
6563.0
"""
line_name: Union[str, List[str]]
center: Optional[Union[float, List[float]]] = None
region: Optional[str] = None
component: Optional[int] = None
subregion: Optional[str] = None
amplitude: Optional[Union[float, List[float]]] = None
element: Optional[str] = None
profile: Optional[str] = None
region_lines: Optional[List[str]] = None
amplitude_relations: Optional[List[List]] = None
subprofile: Optional[str] = None
rarity: Optional[Union[str, List[str]]] = None
template_info: Optional[Dict] = None
[docs]
def to_dict(self) -> dict:
"""
Convert the SpectralLine to a dictionary.
Returns
-------
dict
A dict representation of all fields of the dataclass.
"""
return asdict(self)
#this still require a few changes
[docs]
@dataclass
class SheapModel:
"""
Holds SpectralLines + (optionally) their profile functions & parameters.
You can slice/filter/group arbitrarily, and still recover both the
original (“global”) and per‐subset (“local”) parameter mappings.
"""
lines: List[SpectralLine]
profile_functions: List[Callable] = field(default_factory=list)
profile_names: List[str] = field(default_factory=list)
params_dict: Dict[str, int] = field(default_factory=dict)
profile_params_index_list: List[List[int]] = field(default_factory=list)
params: Optional[np.ndarray] = None
uncertainty_params: Optional[np.ndarray] = None
original_idx: List[int] = field(init=False)
_df: pd.DataFrame = field(init=False, repr=False)
_combined_func: Optional[Callable] = field(init=False, repr=False)
_master_param_names: List[str] = field(init=False, default_factory=list)
global_profile_params_index_list: List[List[int]] = field(init=False, default_factory=list)
def __post_init__(self):
self.original_idx = list(range(len(self.lines)))
if self.params_dict:
self._master_param_names = list(self.params_dict.keys())
# and record the original full index‐lists
self.global_profile_params_index_list = [
lst.copy() for lst in self.profile_params_index_list
]
# 3) build metadata DF: local index = .index, orig_idx column
fallback = [ln.profile for ln in self.lines]
prof_names = self.profile_names or fallback
rows = []
for i, ln in enumerate(self.lines):
rows.append({
"orig_idx": self.original_idx[i],
"line_name": ln.line_name,
"region": ln.region,
"subregion": ln.subregion,
"element": ln.element,
"component": ln.component,
"profile_name": prof_names[i],
})
self._df = pd.DataFrame(rows)
# 4) pre‐combine if profiles exist
self._combined_func = (
make_fused_profiles(self.profile_functions)
if self.profile_functions else None
)
[docs]
def attach_profiles(
self,
profile_functions: List[Callable],
profile_names: List[str],
params: np.ndarray,
uncertainty_params: np.ndarray,
profile_params_index_list: List[List[int]],
params_dict: Dict[str, int],
) -> None:
"""
Supply the full fit‐machinery. Must provide exactly one profile
& name per line, and a params_dict mapping each param_name->col.
"""
N = len(self.lines)
if not (len(profile_functions) == len(profile_names) == N):
raise ValueError("Need exactly one profile per line")
self.profile_functions = profile_functions
self.profile_names = profile_names
self.params = params
self.uncertainty_params = uncertainty_params
self.profile_params_index_list = [lst.copy() for lst in profile_params_index_list]
self.params_dict = params_dict
# record master list of all param‐names in the order of params_dict.keys()
self._master_param_names = list(params_dict.keys())
# record the original global index-lists once and for all
self.global_profile_params_index_list = [lst.copy() for lst in profile_params_index_list]
# rebuild combined profile
self._combined_func = make_fused_profiles(self.profile_functions)
# update DF’s profile_name column
self._df["profile_name"] = self.profile_names
@property
def combined_profile(self) -> Callable[[np.ndarray, np.ndarray], np.ndarray]:
if self._combined_func is None:
raise RuntimeError("No profiles attached")
return self._combined_func
@property
def flat_param_indices_global(self) -> np.ndarray:
"""
All the *global* parameter columns (original indices), in order.
"""
if not self.global_profile_params_index_list:
return np.array([], dtype=int)
return np.concatenate(self.global_profile_params_index_list).astype(int)
@property
def flat_param_indices_local(self) -> np.ndarray:
"""
All the *local* parameter columns (subset indices), in order.
"""
if not self.profile_params_index_list:
return np.array([], dtype=int)
return np.concatenate(self.profile_params_index_list).astype(int)
[docs]
def as_df(self) -> pd.DataFrame:
"""Local‐index DataFrame with columns including orig_idx, kind, component, etc."""
return self._df.copy()
[docs]
def filter(self, **conds) -> "SheapModel":
mask = np.ones(len(self.lines), dtype=bool)
for k, v in conds.items():
if k not in self._df.columns:
raise KeyError(f"No metadata column {k!r}")
col = self._df[k].values
mask &= np.isin(col, v) if isinstance(v, (list,tuple,np.ndarray)) else (col == v)
return self._subset(mask)
def _subset(self, mask: np.ndarray) -> "SheapModel":
# slice the lines + original indices
lines2 = [ln for ln, keep in zip(self.lines, mask) if keep]
orig2 = [oi for oi, keep in zip(self.original_idx, mask) if keep]
# slice the DF & reset local index
df2 = self._df[mask].reset_index(drop=False)
df2.rename(columns={"index": "local_idx"}, inplace=True)
df2["orig_idx"] = orig2
# if no profiles attached, return minimal
if not self.profile_functions:
new = SheapModel(lines=lines2)
new.original_idx = orig2
new._df = df2
return new
# slice profiles + names
funcs2 = [f for f, keep in zip(self.profile_functions, mask) if keep]
names2 = [n for n, keep in zip(self.profile_names, mask) if keep]
# build subset of the *global* index‐lists
glob_lists2 = [
self.global_profile_params_index_list[i]
for i, keep in enumerate(mask) if keep
]
flat_global = np.concatenate(glob_lists2).astype(int)
# slice the *original* params by global indices
params2 = self.params[:, flat_global]
u2 = self.uncertainty_params[:, flat_global]
# build local map: global→new‐local
local_map = { g: i for i, g in enumerate(flat_global) }
local_lists2 = [[ local_map[g] for g in lst ] for lst in glob_lists2]
# rebuild the subsetted params_dict from master names
master = np.array(self._master_param_names)
names_global = master[flat_global]
filtered_dict2 = { nm: i for i, nm in enumerate(names_global) }
# assemble the child
new = SheapModel(
lines=lines2,
profile_functions=funcs2,
profile_names=names2,
params_dict=filtered_dict2,
profile_params_index_list=local_lists2,
params=params2,
uncertainty_params=u2,
)
new._master_param_names = self._master_param_names
new.global_profile_params_index_list = glob_lists2
new.original_idx = orig2
new._df = df2
new._combined_func = make_fused_profiles(funcs2)
return new
def __getitem__(self, key: Union[int, slice, np.ndarray, List[int]]) -> "SheapModel":
mask = (np.zeros(len(self.lines), bool) if not isinstance(key,int)
else np.zeros(len(self.lines), bool))
mask[key] = True
return self._subset(mask)
[docs]
def group_by(self, field: str) -> Dict[Any, "SheapModel"]:
if field not in self._df.columns:
raise KeyError(f"No metadata column {field!r}")
return {
val: self.filter(**{field: val})
for val in np.unique(self._df[field].values)
}
[docs]
def param_subdict(self) -> Dict[str, np.ndarray]:
"""
Map each param name → its column in this instance’s params,
using the *local* flattened indices.
"""
names = np.array(list(self.params_dict.keys()))
return {nm: self.params[:, idx] for nm, idx in self.params_dict.items()}
[docs]
def unique(self, field: str) -> List[Any]:
if field not in self._df.columns:
raise KeyError(f"No metadata column {field!r}")
return sorted(pd.unique(self._df[field].dropna()).tolist())
@property
def regions(self) -> List[Any]:
return self.unique("region")
@property
def components(self) -> List[Any]:
return self.unique("component")
@property
def subregions(self) -> List[Any]:
return self.unique("subregion")
@property
def elements(self) -> List[Any]:
return self.unique("element")
[docs]
def characteristics(self) -> Dict[str, Any]:
by_region_component = (
self._df.groupby("region")["component"]
.nunique()
.sort_index()
.to_dict()
)
return {
"components": self.components,
"regions": self.regions,
#"profile_names": self.profile_names_list,
"elements": self.elements,
"subregions": self.subregions,
"n_components_per_region": by_region_component,
}
#still useffull?
[docs]
@dataclass
class SheapResult:
"""
Data class to store results from spectral region fitting.
Attributes:
region_list (List[SpectralLine]): List of spectral line configurations.
params (Optional[jnp.ndarray]): Optimized parameters from fitting.
uncertainty_params (Optional[jnp.ndarray]): Estimated uncertainties for each parameter.
mask (Optional[jnp.ndarray]): Mask used during the fitting process.
profile_functions (Optional[List[Callable]]): Functions describing each spectral profile.
profile_names (Optional[List[str]]): Names of spectral profiles used in fitting.
loss (Optional[List]): Values of the loss function during optimization.
profile_params_index_list (Optional[List]): Indices mapping profile parameters.
initial_params (Optional[jnp.ndarray]): Initial guess parameters before fitting.
scale (Optional[jnp.ndarray]): scale used for normalization.
params_dict (Optional[Dict[str, int]]): Mapping from parameter names to indices.
outer_limits (Optional[List]): Outer wavelength limits of the fitting region.
inner_limits (Optional[List]): Inner wavelength limits defining the region of interest.
model_keywords (Optional[dict]): Additional keywords for model configuration.
kind_list (List[str]): Unique types of spectral lines (computed post-init).
constraints same as constrains from fit
"""
region_list: List[SpectralLine] # can be mode to sheapmodel at the moment after reading.
fitting_routine: Optional[dict] = None
params: Optional[jnp.ndarray] = None
uncertainty_params: Optional[jnp.ndarray] = None
mask: Optional[jnp.ndarray] = None
constraints: Optional[jnp.ndarray] = None
profile_functions: Optional[List[Callable]] = None
profile_names: Optional[List[str]] = None
loss: Optional[List] = None
profile_params_index_list: Optional[List] = None
initial_params: Optional[jnp.ndarray] = None
scale: Optional[jnp.ndarray] = None
params_dict: Optional[Dict[str, int]] = None
outer_limits: Optional[List] = None
inner_limits: Optional[List] = None
model_keywords: Optional[dict] = None
source:Optional[dict] = None
dependencies:Optional[List] = None
free_params:Optional[jnp.ndarray] = None
residuals:Optional[jnp.ndarray] = None
chi2_red:Optional[jnp.ndarray] = None
posterior:Optional[dict] = None
fitkwargs:Optional[List[Dict]] = None
elapsed_time:Optional[List[Dict]] = None
params_class: Optional = None
def __post_init__(self):
#this should be an intermediate step in some cases it should be already done
self.sheapmodel = SheapModel(self.region_list)
self.sheapmodel.attach_profiles(self.profile_functions,self.profile_names,self.params,self.uncertainty_params,self.profile_params_index_list,self.params_dict)
[docs]
def to_dict(self) -> dict:
return asdict(self)
[docs]
@dataclass
class ProfileConstraintSet:
init: List[float]
upper: List[float]
lower: List[float]
profile: str
param_names: List[str]
profile_fn: Optional[Callable] = None
def __post_init__(self):
# Skip length check for SPAF profiles
if self.profile.startswith("SPAF"):
return
n = len(self.init)
if not (len(self.upper) == len(self.lower) == len(self.param_names) == n):
raise ValueError(
f"ConstraintSet mismatch: "
f"got init[{n}], upper[{len(self.upper)}], "
f"lower[{len(self.lower)}], param_names[{len(self.param_names)}]"
)
[docs]
@dataclass
class FittingLimits:
"""
Stores FWHM and shift limits for a line component kind.
Attributes:
upper_fwhm_kms (float): Maximum velocity FWHM (km/s).
lower_fwhm_kms (float): Minimum velocity FWHM (km/s).
vshift_kms (float): Maximum velocity shift (km/s).
max_amplitude (float): Maximum allowed amplitude.
"""
upper_fwhm_kms: float
lower_fwhm_kms: float
init_fwhm_kms:Optional[float] = None
vshift_kms: Optional[float] = None
max_amplitude: Optional[float] = None
references: Optional[list] = None
[docs]
@classmethod
def from_dict(cls, d: Dict[str, float]) -> "FittingLimits":
"""
Create FittingLimits from a dictionary with keys matching the attributes.
Args:
d (Dict[str, float]): Dictionary with keys:
'upper_fwhm_kms', 'lower_fwhm_kms', 'vshift_kms', 'max_amplitude'.
Returns:
FittingLimits: Instance created from the dictionary.
Raises:
ValueError: If any required key is missing from the dictionary.
"""
required_keys = {'upper_fwhm_kms', 'lower_fwhm_kms', 'vshift_kms', 'max_amplitude'}
missing = required_keys - d.keys()
if missing:
raise ValueError(f"Missing keys for FittingLimits: {missing}")
return cls(
upper_fwhm_kms=d['upper_fwhm_kms']
,lower_fwhm_kms=d['lower_fwhm_kms'],
init_fwhm_kms = d.get('init_fwhm_kms'),
vshift_kms =d['vshift_kms'],
max_amplitude=d['max_amplitude'],
references=d.get('references'),
)