Source code for sheap.Sheapectral.Utils.SpectralSetup

"""
Spectral preparation & masking utilities.

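This module provides small, JAX-friendly helpers to:
- hard-cut spectra to wavelength windows,
- pad a missing error channel,
- reshape batches to a common pixel length (padding with NaNs),
- build robust masks and uncertainty arrays for fitting,
- download the Schlegel, Finkbeiner & Davis (1998) dust maps on demand,
- rebuild profile model functions from a list of region components.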

Notes
-----
- All functions are written to be light on copies and JAX‑compatible
    (using `jnp` where the arrays participate in later JAX code).
"""

# Module author metadata.
__author__ = 'felavila'

# Public API of this module: the names exported by `from ... import *`.
# Keep this list in sync with the public function definitions below.
__all__ = [
    "cut_spectra",
    "mask_builder",
    "pad_error_channel",
    "prepare_spectra",
    "prepare_uncertainties",
    "resize_and_fill_with_nans",
    "ensure_sfd_data",
    "profile_functions_from_region_list"
]

from typing import Optional, Sequence, Tuple

import jax.numpy as jnp
import numpy as np
from pathlib import Path
import requests
from sheap.Core import ArrayLike


def resize_and_fill_with_nans(
    original_array: np.ndarray,
    new_xaxis_length: int,
    number_columns: Optional[int] = None,
) -> np.ndarray:
    """
    Copy a 2-D (C, N) array into a NaN-filled array of a new size.

    The overlapping top-left region of the input is copied verbatim. Rows or
    columns beyond the input's extent are left as NaN. Rows or columns beyond
    the requested size are dropped.

    Parameters
    ----------
    original_array
        Input array with shape (C, N).
    new_xaxis_length
        Number of pixels (columns) in the output.
    number_columns
        Number of channels (rows) in the output. When None (or any falsy
        value), the input's channel count C is kept.

    Returns
    -------
    np.ndarray
        Float array of shape (C_out, new_xaxis_length).
    """
    n_rows_in, n_cols_in = original_array.shape
    n_rows_out = number_columns if number_columns else n_rows_in

    # Size of the region shared by input and output.
    keep_rows = min(n_rows_in, n_rows_out)
    keep_cols = min(n_cols_in, new_xaxis_length)

    resized = np.full((n_rows_out, new_xaxis_length), np.nan, dtype=float)
    resized[:keep_rows, :keep_cols] = original_array[:keep_rows, :keep_cols]
    return resized
def prepare_spectra(
    spectra_list: Sequence[np.ndarray],
    outer_limits: Tuple[float, float],
):
    """
    Cut each spectrum to `outer_limits`, pad to a common shape, and build masks.

    Parameters
    ----------
    spectra_list
        Non-empty iterable of per-object arrays with shape (C, N). Row 0 is
        wavelength and rows 1 and 2 are flux and error, as `mask_builder`
        expects. Spectra may differ in C and N; shorter ones are NaN-padded.
        A spectrum lacking an error row gets a NaN-padded one, which
        `mask_builder` turns into a very large uncertainty. Use
        `pad_error_channel` first if a real error estimate is wanted.
    outer_limits
        (xmin, xmax) hard window to keep.

    Returns
    -------
    spectral_region : jnp.ndarray
        Batched spectra of shape (n_obj, C_max, N_max) after cutting/padding,
        with the error channel replaced by prepared uncertainties.
    mask_region : jnp.ndarray
        Boolean mask with True where samples are masked (ignored in fit).

    Raises
    ------
    ValueError
        If `spectra_list` is empty.
    """
    xmin, xmax = outer_limits
    clipped = [cut_spectra(s, xmin, xmax) for s in spectra_list]
    if not clipped:
        raise ValueError("prepare_spectra received an empty spectra_list.")

    # Common target shape. Padding channels as well as pixels lets spectra
    # with a different number of rows be stacked into a single batch.
    c_max = max(s.shape[0] for s in clipped)
    n_max = max(s.shape[1] for s in clipped)
    stacked = jnp.array(
        [resize_and_fill_with_nans(s, n_max, number_columns=c_max) for s in clipped]
    )

    spectral_region, _, _, mask_region = mask_builder(
        stacked, outer_limits=outer_limits
    )
    return spectral_region, mask_region
def cut_spectra(spectra: ArrayLike, xmin: float, xmax: float) -> ArrayLike:
    """
    Keep only the pixels whose wavelength lies inside the closed interval.

    Parameters
    ----------
    spectra
        Array with shape (C, N) whose first row holds the wavelengths.
    xmin, xmax
        Inclusive lower and upper wavelength bounds.

    Returns
    -------
    ArrayLike
        All C rows restricted to the columns with xmin <= wavelength <= xmax.
        NaN wavelengths never satisfy the comparison and are therefore dropped.
    """
    wavelength = spectra[0, :]
    in_window = (xmin <= wavelength) & (wavelength <= xmax)
    return spectra[:, in_window]
def mask_builder(
    sheap_array: jnp.ndarray,
    inner_limits: Tuple[float, float] = (0.0, 0.0),
    outer_limits: Optional[Tuple[float, float]] = None,
    instrumental_limit: float = 1e50,  # kept for API compatibility; not used internally
):
    """
    Build a robust mask and uncertainty channel for a batch of spectra.

    Rules (a pixel is masked if any applies):

    - its wavelength lies inside `inner_limits` (inclusive), e.g. to exclude
      strong tellurics. The default (0, 0) only matches a wavelength of exactly 0;
    - `outer_limits` is given and its wavelength lies outside it;
    - its wavelength, flux or error is NaN or infinite;
    - its flux is non-positive.

    Masked pixels get their error set to NaN. `prepare_uncertainties` then
    converts every NaN error (and every NaN-flux position) to 1e31.

    Parameters
    ----------
    sheap_array
        Array with shape (n_obj, C, N), channels: [λ, flux, err, (opt ...)].
    inner_limits
        (xmin, xmax) wavelengths to *mask out* inside this window.
    outer_limits
        If given, wavelengths outside (xmin, xmax) are masked.
    instrumental_limit
        Unused placeholder (kept to avoid breaking callers).

    Returns
    -------
    array : jnp.ndarray
        Copy of input with the error channel replaced by prepared uncertainties.
    prepared_uncertainties : jnp.ndarray
        The prepared error channel, shape (n_obj, N).
    original_array
        The `sheap_array` object passed in, returned unchanged.
    mask
        Boolean array of shape (n_obj, N); True = masked / ignored.
    """
    copy_array = jnp.array(sheap_array)
    wl = sheap_array[:, 0, :]
    flux = sheap_array[:, 1, :]
    err = sheap_array[:, 2, :]

    # Mask inside inner limits.
    mask = (wl >= inner_limits[0]) & (wl <= inner_limits[1])
    # And outside outer limits, if provided.
    if outer_limits is not None:
        mask = mask | (wl < outer_limits[0]) | (wl > outer_limits[1])

    # Invalidate bad samples. NaN/inf flux must be masked too: its
    # uncertainty is inflated below, but a NaN residual would still poison
    # any reduction that relies on this mask to skip it.
    mask = (
        mask
        | ~jnp.isfinite(wl)
        | ~jnp.isfinite(flux)
        | ~jnp.isfinite(err)
        | (flux <= 0)
    )

    # NaN-out masked errors so prepare_uncertainties maps them to 1e31.
    err_masked = jnp.where(mask, jnp.nan, err)
    prepared = prepare_uncertainties(err_masked, flux)
    copy_array = copy_array.at[:, 2, :].set(prepared)
    return copy_array, prepared, sheap_array, mask
def prepare_uncertainties(
    y_uncertainties: Optional[jnp.ndarray],
    y_data: jnp.ndarray,
) -> jnp.ndarray:
    """
    Return an uncertainty array where invalid positions carry a huge sigma.

    If no uncertainties are supplied, unit uncertainties are used. Wherever
    either the uncertainty or the data value is NaN, the uncertainty becomes
    1e31. Inverse-variance weighting then effectively ignores that pixel.

    Parameters
    ----------
    y_uncertainties
        Error channel, or None for unit errors.
    y_data
        Flux channel; only its NaN positions are used.

    Returns
    -------
    jnp.ndarray
        Prepared uncertainties with the same shape as `y_data`.
    """
    sigma = jnp.ones_like(y_data) if y_uncertainties is None else y_uncertainties
    invalid = jnp.logical_or(jnp.isnan(y_data), jnp.isnan(sigma))
    return jnp.where(invalid, 1e31, sigma)
def pad_error_channel(spectra: ArrayLike, frac: float = 0.01) -> ArrayLike:
    """
    Ensure a third channel (error) by padding with a fraction of the signal.

    The channel axis is taken to be the second-to-last one. Both a single
    spectrum of shape (C, N) and a batch of shape (n_obj, C, N) are therefore
    handled.

    Parameters
    ----------
    spectra
        Array with shape (C, N) or (n_obj, C, N). If C == 2 (λ, flux), an
        error channel equal to `frac * flux` is appended.
    frac
        Error fraction applied to the flux to fabricate uncertainties.

    Returns
    -------
    ArrayLike
        If C == 2, a new array with shape (..., 3, N). Otherwise the input
        object is returned unchanged.
    """
    # Use axis -2 rather than 1: with a 2-D (C, N) input, axis 1 is the pixel
    # axis, so the old check silently skipped (2, N) spectra.
    if spectra.shape[-2] != 2:
        return spectra  # not a (λ, flux) pair: leave untouched
    flux = spectra[..., 1, :]
    error = jnp.expand_dims(flux * frac, axis=-2)
    return jnp.concatenate((spectra, error), axis=-2)
def ensure_sfd_data(sfd_path: Optional[Path] = None, timeout: float = 60.0):
    """
    Ensure the Schlegel, Finkbeiner & Davis (1998) dust maps are available locally.

    Any of the four required FITS files missing from `sfd_path` are
    downloaded from the `kbarbary/sfddata` GitHub repository.

    Parameters
    ----------
    sfd_path : Path or str, optional
        Directory where the SFD data should be stored (created if needed).
        Defaults to `SuportData/sfddata`, two directory levels above the
        directory containing this file.
    timeout : float, optional
        Per-request connect/read timeout in seconds passed to `requests.get`.

    Raises
    ------
    requests.HTTPError
        If a download returns an HTTP error status.
    requests.RequestException
        On connection failures or timeouts.

    Files
    -----
    - SFD_dust_4096_ngp.fits
    - SFD_dust_4096_sgp.fits
    - SFD_mask_4096_ngp.fits
    - SFD_mask_4096_sgp.fits
    """
    if sfd_path is None:
        sfd_path = Path(__file__).resolve().parent.parent / "SuportData" / "sfddata"
    else:
        sfd_path = Path(sfd_path)  # also accept str / os.PathLike
    sfd_path.mkdir(parents=True, exist_ok=True)

    files = [
        "SFD_dust_4096_ngp.fits",
        "SFD_dust_4096_sgp.fits",
        "SFD_mask_4096_ngp.fits",
        "SFD_mask_4096_sgp.fits",
    ]
    base_url = "https://raw.githubusercontent.com/kbarbary/sfddata/master"

    missing = [fname for fname in files if not (sfd_path / fname).exists()]
    if not missing:
        return

    print(f"For the SFD correction is necessary download a list of files ({missing}) this will be done just ones")
    for fname in missing:
        url = f"{base_url}/{fname}"
        outpath = sfd_path / fname
        # Stream into a temporary sibling and rename only after success.
        # An interrupted transfer therefore never leaves a truncated file
        # that the `exists()` check above would later accept as complete.
        tmp_path = outpath.with_name(outpath.name + ".part")
        try:
            with requests.get(url, stream=True, timeout=timeout) as r:
                r.raise_for_status()
                with open(tmp_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            tmp_path.replace(outpath)
        except BaseException:
            tmp_path.unlink(missing_ok=True)
            raise
def profile_functions_from_region_list(region_list):
    """
    Recreate the profile model function for each region component.

    Parameters
    ----------
    region_list
        Iterable of region objects. Each must expose a string `profile`.
        Depending on the profile, `center`, `amplitude_relations` and
        `template_info` are also read.

    Returns
    -------
    list
        One entry per region, in input order. Profiles not handled by a
        special branch are looked up with `PROFILE_FUNC_MAP.get`, so an
        unknown name yields None.

    Raises
    ------
    ValueError
        If a profile containing "SPAF" does not split on "_" into exactly
        two parts (e.g. "SPAF_<subprofile>").
    """
    from sheap.Profiles.Profiles import PROFILE_FUNC_MAP, PROFILE_CONTINUUM_FUNC_MAP

    profile_functions = []
    for sp in region_list:
        profile = sp.profile  # expected never to be None
        if "SPAF" in profile:
            parts = profile.split("_")
            if len(parts) != 2:
                # Previously this only printed a warning and then used an
                # undefined (or stale, from an earlier region) `subprofile`.
                raise ValueError(
                    f"SPAF profile {profile!r} must have exactly one '_' "
                    "separating it from its subprofile, e.g. 'SPAF_<subprofile>'."
                )
            subprofile = parts[1]
            sm = PROFILE_FUNC_MAP["SPAF"](sp.center, sp.amplitude_relations, subprofile)
        elif profile in ("hostmiles", "template"):
            # Template factories return a mapping; only its "model" entry is kept.
            sm = PROFILE_FUNC_MAP[profile](**sp.template_info)["model"]
        elif profile in PROFILE_CONTINUUM_FUNC_MAP:
            # Retroactive solution: fall back to default keywords when the
            # region carries no template_info.
            keywords = (
                sp.template_info["keywords"]
                if sp.template_info is not None
                else {"delta0": 5500.0}
            )
            sm = PROFILE_FUNC_MAP.get(profile)(**keywords)
        else:
            sm = PROFILE_FUNC_MAP.get(profile)
        profile_functions.append(sm)
    return profile_functions
# """ # Spectral preparation & masking utilities. # This module provides small, JAX‑friendly helpers to: # - hard‑cut spectra to wavelength windows, # - pad a missing error channel, # - normalize/reshape batches to a common pixel length, # - build robust masks and uncertainty arrays for fitting. # Notes # ----- # - All functions are written to be light on copies and JAX‑compatible # (using `jnp` where the arrays participate in later JAX code). # """ # __author__ = 'felavila' # __all__ = [ # "cut_spectra", # "mask_builder", # "pad_error_channel", # "prepare_spectra", # "prepare_uncertainties", # "resize_and_fill_with_nans", # ] # from typing import Callable, Dict, Optional, Tuple, Union # import jax.numpy as jnp # import numpy as np # from sheap.Core import ArrayLike # def resize_and_fill_with_nans(original_array, new_xaxis_length, number_columns=4): # """ # Resize an array to the target shape, filling new entries with NaNs. # """ # new_array = np.full((number_columns, new_xaxis_length), np.nan, dtype=float) # slices = tuple( # slice(0, min(o, t)) # for o, t in zip(original_array.shape, (number_columns, new_xaxis_length)) # ) # new_array[slices] = original_array[slices] # return new_array # def prepare_spectra(spectra_list, outer_limits): # list_cut = [cut_spectra(s, *outer_limits) for s in spectra_list] # shapes_max = max(s.shape[1] for s in list_cut) # spectra_reshaped = jnp.array([resize_and_fill_with_nans(s, shapes_max) for s in list_cut]) # spectral_region, _, _, mask_region = mask_builder( # spectra_reshaped, outer_limits=outer_limits # ) # return spectral_region, mask_region # def cut_spectra(spectra, xmin, xmax): # """hard cut of the spectra""" # mask = (spectra[0, :] >= xmin) & (spectra[0, :] <= xmax) # spectra = spectra[:, mask] # return spectra # def mask_builder( # sheap_array, inner_limits=[0, 0], outer_limits=None, instrumental_limit=10e50 # ): # """ # -full nan the error matrix # if outer_limits is not None: # mask_outside_outer = (sheap_array[:, 0, 
:] < outer_limits[0]) | (sheap_array[:, 0, :] > outer_limits[1]) # Parameters: # - sheap_array: Input array with shape (N, 3, M). # - inner_limits: List of two values [min, max] for the inner limits. # - outer_limits: Optional list of two values [min, max] for the outer limits. # - instrumental_limit: in units of flux this defines the limit that can reach the instrument after understimate the error # Returns: # - array: Array with masked values based on the limits. # - mask: Prepared uncertainties array. # - original_array: The original sheap_array. # - masked_uncertainties: The mask applied to the array this means the error in these regions go to 1e11 # comment: # # Combine masks to mask values inside inner_limits or outside outter_limits # # take the uncertainties and put it to nan in the region that we wan to not take in account # #place in where we want to not fit # """ # copy_array = jnp.copy(sheap_array) # mask = (sheap_array[:, 0, :] >= inner_limits[0]) & ( # sheap_array[:, 0, :] <= inner_limits[1] # ) # if outer_limits is not None: # mask_outside_outter = (sheap_array[:, 0, :] < outer_limits[0]) | ( # sheap_array[:, 0, :] > outer_limits[1] # ) # mask = mask | mask_outside_outter # mask = ( # mask # | (jnp.isnan(sheap_array[:, 0, :]) | jnp.isinf(sheap_array[:, 2, :])) # | (sheap_array[:, 1, :] <= 0) # ) # copy_array = copy_array.at[:, 2, :].set(jnp.where(mask, jnp.nan, copy_array[:, 2, :])) # masked_uncertainties = prepare_uncertainties(copy_array[:, 2, :], copy_array[:, 1, :]) # copy_array = copy_array.at[:, 2, :].set(masked_uncertainties) # # masked_uncertainties = masked_uncertainties == 1.e+31 # return copy_array, masked_uncertainties, sheap_array, mask # def prepare_uncertainties( # y_uncertainties: Optional[jnp.ndarray], y_data: jnp.ndarray # ) -> jnp.ndarray: # """ # Prepare the y_uncertainties array. If None, return an array of ones. # If there are NaN values in y_data, set the corresponding uncertainties to 1e11. 
# Parameters: # - y_uncertainties: Provided uncertainties or None. # - y_data: The target data array. # Returns: # - y_uncertainties: An array of uncertainties. # """ # if y_uncertainties is None: # y_uncertainties = jnp.ones_like(y_data) # # Identify positions where y_data has NaN values # nan_positions = jnp.isnan(y_data) | jnp.isnan(y_uncertainties) # # Set uncertainties to 1e11 at positions where y_data is NaN/here i have some corncerns about is it is weight or not # y_uncertainties = jnp.where(nan_positions, 1e31, y_uncertainties) # return y_uncertainties # # TODO Add multiple models to the reading. # def pad_error_channel(spectra: ArrayLike, frac: float = 0.01) -> ArrayLike: # """Ensure *spectra* has a third channel (error) by padding with *frac* × signal.""" # if spectra.shape[1] != 2: # return spectra # already 3‑channel # signal = spectra[:, 1, :] # error = jnp.expand_dims(signal * frac, axis=1) # return jnp.concatenate((spectra, error), axis=1)