"""
FITS Spectrum Readers
=====================
This module provides utilities to read spectra from different survey
and simulation formats (SDSS, DESI, PyQSO, and custom simulations).
It also includes parallel and batched readers to handle multiple files
efficiently, with fallbacks for sequential reading.
Readers
-------
- fits_reader_sdss: SDSS spectra (PLATE-MJD-FIBERID format)
- fits_reader_desi: DESI spectra
- fits_reader_pyqso: PyQSO pipeline spectra
- fits_reader_simulation: Simulated spectra
Batching / Parallel utilities
-----------------------------
- parallel_reader
- batched_reader
- sequential_reader
"""
# Module author tag.
__author__ = 'felavila'
# Names exported by ``from <module> import *``.
# NOTE(review): fits_reader_4most, rebin_one_spectrum and rebin_spectra_list
# are defined in this file but are not listed here -- confirm whether they
# are meant to be part of the public API.
__all__ = [
"READER_FUNCTIONS",
"batched_reader",
"fits_reader_desi",
"fits_reader_pyqso",
"fits_reader_sdss",
"fits_reader_simulation",
"n_cpu",
"parallel_reader",
"sequential_reader",
]
import os
import numpy as np
from multiprocessing import Pool, set_start_method
from astropy.io import fits
from functools import partial
#from sheap.Utils.SpectralSetup import resize_and_fill_with_nans
# Limit CPUs for safety. os.cpu_count() returns None when the number of
# CPUs cannot be determined, so fall back to a single worker in that case
# instead of failing at import time with min(4, None).
n_cpu = min(4, os.cpu_count() or 1)
def fits_reader_desi(file: str):
    """
    Read a DESI FITS spectrum.

    The flux and inverse-variance scale factors are parsed from the leading
    numeric token of the ``TUNIT2`` / ``TUNIT3`` cards of HDU 1.

    Parameters
    ----------
    file : str
        Path to DESI FITS file.

    Returns
    -------
    data_array : np.ndarray
        Array with rows [wavelength, flux, error], where
        ``error = 1 / sqrt(IVAR * ivar_scale)``. Infinite values
        (e.g. from ``IVAR == 0``) are replaced by ``1e20``.
    header_array : np.ndarray
        Array with RA and DEC from the primary header.
    """
    # The context manager closes the file even if a keyword/column is
    # missing; the original left the handle open on every call.
    with fits.open(file) as hdul:
        table_header = hdul[1].header
        flux_scale = float(table_header["TUNIT2"].split(" ")[0])
        ivar_scale = float(table_header["TUNIT3"].split(" ")[0])
        data = hdul[1].data
        # Division by zero is expected for IVAR == 0 and handled below, so
        # silence only that warning. np.array copies the columns out of the
        # file-backed table before it is closed.
        with np.errstate(divide="ignore"):
            data_array = np.array([
                data["WAVELENGTH"],
                data["FLUX"] * flux_scale,
                1 / np.sqrt(data["IVAR"] * ivar_scale),
            ])
        header_array = np.array([hdul[0].header["RA"], hdul[0].header["DEC"]])
    # Replace +/-inf with a large finite sentinel.
    data_array[np.isinf(data_array)] = 1e20
    return data_array, header_array
def fits_reader_4most(file: str):
    """
    Read a 4MOST FITS spectrum (SPV).

    Parameters
    ----------
    file : str
        Path to 4MOST FITS file.

    Returns
    -------
    data_array : np.ndarray
        Array with rows [wavelength, flux, error] taken from the ``WAVE``,
        ``FLUX`` and ``ERR`` columns of HDU 1, with length-one axes removed.
    header_array : np.ndarray
        Array with RA and DEC from the primary header.
    """
    # Use a context manager so the file handle is always released.
    with fits.open(file) as hdul:
        data = hdul[1].data
        header_array = np.array([hdul[0].header["RA"], hdul[0].header["DEC"]])
        # np.array copies the columns, so the result outlives the open file.
        data_array = np.array([
            data["WAVE"],
            data["FLUX"],
            data["ERR"],
        ])
    return data_array.squeeze(), header_array
def fits_reader_simulation(file: str, chanel: int = 1, template: bool = False):
    """
    Read a simulated spectrum from a FITS file.

    Parameters
    ----------
    file : str
        Path to simulation FITS file.
    chanel : int, default=1
        HDU extension index to read.
    template : bool, default=False
        If True, read the template columns ``LAMBDA`` and ``FLUX_DENSITY``
        instead of a wavelength/flux/error triplet.

    Returns
    -------
    data_array : np.ndarray
        Stacked columns with length-one axes removed. Rows are
        [LAMBDA, FLUX_DENSITY] for templates, [WAVE, FLUX, ERR_FLUX] when
        ``chanel == 1`` and [WAVE, FLUX, ERR] otherwise.
    header_array : list
        Always an empty list (no header metadata is read).
    """
    # The three layouts differ only in the column names to read.
    if template:
        columns = ("LAMBDA", "FLUX_DENSITY")
    elif chanel == 1:
        columns = ("WAVE", "FLUX", "ERR_FLUX")
    else:
        columns = ("WAVE", "FLUX", "ERR")

    # Context manager releases the file handle, including on errors.
    with fits.open(file) as hdul:
        table = hdul[chanel].data
        # np.array copies the columns before the file is closed.
        data_array = np.array([table[name] for name in columns])

    header_array = []
    return data_array.squeeze(), header_array
def fits_reader_sdss(file: str):
    """
    Read an SDSS FITS spectrum.

    The flux scale factor is parsed from the leading numeric token of the
    ``BUNIT`` card of the primary header.

    Parameters
    ----------
    file : str
        Path to SDSS FITS file.

    Returns
    -------
    data_array : np.ndarray
        Array with rows [wavelength, flux, error, wdisp], where
        ``wavelength = 10 ** loglam`` and
        ``error = flux_scale / sqrt(ivar)``. Infinite values
        (e.g. from ``ivar == 0``) are replaced by ``1e20``.
    header_array : np.ndarray
        Array with ``PLUG_RA`` and ``PLUG_DEC`` from the primary header.
    """
    # Context manager guarantees the file is closed; the original leaked
    # one open handle per spectrum read.
    with fits.open(file) as hdul:
        primary_header = hdul[0].header
        flux_scale = float(primary_header["BUNIT"].split(" ")[0])
        data = hdul[1].data
        # ivar == 0 legitimately produces inf here (replaced below), so the
        # divide-by-zero warning is noise.
        with np.errstate(divide="ignore"):
            data_array = np.array([
                10 ** data["loglam"],
                data["flux"] * flux_scale,
                flux_scale / np.sqrt(data["ivar"]),
                data["wdisp"],
            ])
        header_array = np.array(
            [primary_header["PLUG_RA"], primary_header["PLUG_DEC"]]
        )
    # Replace +/-inf with a large finite sentinel.
    data_array[np.isinf(data_array)] = 1e20
    return data_array, header_array
def fits_reader_pyqso(file: str):
    """
    Read a PyQSO-format spectrum.

    Parameters
    ----------
    file : str
        Path to PyQSO FITS file.

    Returns
    -------
    spectra : np.ndarray
        Array with rows [wavelength, flux, error] taken from the
        ``wave_prereduced``, ``flux_prereduced`` and ``err_prereduced``
        columns of HDU 3.
    header_array : list
        Empty list (no coordinates are read).
    """
    # Context manager releases the file handle, including on errors.
    with fits.open(file) as hdul:
        table = hdul[3].data
        # np.array copies the columns before the file is closed.
        spectra = np.array([
            table["wave_prereduced"],
            table["flux_prereduced"],
            table["err_prereduced"],
        ])
    return spectra, []
# Registry of reader callables keyed by their function name, so that the
# batch readers can accept either a callable or its name as a string.
READER_FUNCTIONS = {
    reader.__name__: reader
    for reader in (
        fits_reader_sdss,
        fits_reader_simulation,
        fits_reader_pyqso,
        fits_reader_desi,
        fits_reader_4most,
    )
}
def parallel_reader(paths, n_cpu=n_cpu, function=fits_reader_sdss, **kwargs):
    """
    Parallel reader using multiprocessing.

    Parameters
    ----------
    paths : iterable of str
        Paths to FITS files.
    n_cpu : int, optional
        Maximum number of worker processes. The pool never uses more
        workers than there are paths, and never fewer than one.
    function : callable or str, optional
        Reader function or key in `READER_FUNCTIONS`.
    **kwargs
        Extra keyword arguments forwarded to the reader for every file.

    Returns
    -------
    coords : np.ndarray
        Second element of each reader result, stacked into an array.
    n_pixels : list of int
        ``shape[1]`` of each spectrum array, in the order of `paths`.
    spectra : list of np.ndarray
        Raw spectra arrays, in the order of `paths`.

    Raises
    ------
    KeyError
        If `function` is a string that is not a key of `READER_FUNCTIONS`.
    """
    if isinstance(function, str):
        function = READER_FUNCTIONS[function]

    paths = list(paths)
    # Pool(processes=0) raises ValueError, so short-circuit on empty input.
    if not paths:
        return np.array([]), [], []

    func_with_args = partial(function, **kwargs)
    n_workers = max(1, min(n_cpu, len(paths)))
    with Pool(processes=n_workers) as pool:
        # chunksize=1 hands out one file at a time to balance uneven sizes.
        results = pool.map(func_with_args, paths, chunksize=1)

    spectra = [result[0] for result in results]
    coords = np.array([result[1] for result in results])
    n_pixels = [s.shape[1] for s in spectra]
    # TODO: enable resize_and_fill_with_nans to return reshaped spectra.
    return coords, n_pixels, spectra
def batched_reader(paths, batch_size=8, function=fits_reader_sdss, **kwargs):
    """
    Batch reader for safer memory usage.

    The paths are split into consecutive batches and each batch is read
    with `parallel_reader`.

    Parameters
    ----------
    paths : iterable of str
        Paths to FITS files.
    batch_size : int, optional
        Number of files to read per batch. Must be at least 1.
    function : callable or str, optional
        Reader function or key in `READER_FUNCTIONS`.
    **kwargs
        Extra keyword arguments forwarded to the reader for every file.

    Returns
    -------
    coords : np.ndarray
        Coordinates from all batches, stacked with ``np.vstack``
        (an empty array when `paths` is empty).
    spectra_reshaped : str
        Placeholder, always the string ``"unused"``.
    spectra_raw : list of np.ndarray
        All raw spectra arrays, in the order of `paths`.

    Raises
    ------
    ValueError
        If `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    paths = list(paths)
    # np.vstack([]) raises, so return empty results for empty input.
    if not paths:
        return np.array([]), "unused", []

    all_coords, all_raw = [], []
    for start in range(0, len(paths), batch_size):
        batch = paths[start:start + batch_size]
        # The middle return value is not used by this function.
        coords, _, raw = parallel_reader(
            batch, n_cpu=min(n_cpu, len(batch)), function=function, **kwargs
        )
        all_coords.append(coords)
        all_raw.extend(raw)

    return np.vstack(all_coords), "unused", all_raw
def sequential_reader(paths, function=fits_reader_sdss, **kwargs):
    """
    Sequential FITS reader (fallback for debugging).

    Files that fail to read are reported on stdout and skipped, so the
    outputs may contain fewer entries than `paths`.

    Parameters
    ----------
    paths : iterable of str
        Paths to FITS files.
    function : callable or str, optional
        Reader function or key in `READER_FUNCTIONS`.
    **kwargs
        Extra keyword arguments forwarded to the reader for every file.

    Returns
    -------
    coords : np.ndarray
        Second element of each successful reader result, as an array.
    spectra_reshaped : list
        Placeholder, currently always an empty list.
    spectra : list of np.ndarray
        Raw spectra arrays of the files that were read successfully.

    Raises
    ------
    KeyError
        If `function` is a string that is not a key of `READER_FUNCTIONS`.
    """
    # Resolve a registry key up front; previously a string was "called"
    # for every path and each call failed inside the try block.
    if isinstance(function, str):
        function = READER_FUNCTIONS[function]

    results = []
    for path in paths:
        # Deliberately best-effort: one unreadable file must not abort the
        # whole run, so any reader error is reported and the file skipped.
        try:
            results.append(function(path, **kwargs))
        except Exception as e:
            print(f"Failed to read {path}: {e}")

    spectra = [result[0] for result in results]
    coords = np.array([result[1] for result in results])
    # The unused max-length computation was dropped: it raised ValueError
    # whenever no file could be read.
    spectra_reshaped = []  # TODO: enable resize_and_fill_with_nans
    return coords, spectra_reshaped, spectra
def rebin_one_spectrum(
    spectrum,
    n_pix,
    fill=np.nan,
    clean_invalid=True,
):
    """
    Rebin one spectrum to a fixed number of pixels.

    The spectrum is sorted by wavelength and resampled with ``spectres``
    onto ``n_pix`` points evenly spaced between the smallest and largest
    input wavelength.

    Parameters
    ----------
    spectrum : tuple or list
        Spectrum in the form ``(wl, flux, error)``.
    n_pix : int
        Number of pixels in the output spectrum.
    fill : float, optional
        Value passed to ``spectres`` as its ``fill`` argument.
    clean_invalid : bool, optional
        If True, non-finite rebinned flux/error values are replaced by 0.
        If False they are kept, and the conservation ratio is NaN whenever
        the rebinned flux contains NaN.

    Returns
    -------
    original : ndarray
        Wavelength-sorted input spectrum with shape ``(3, n_original)``.
        Rows are wavelength, flux, error.
    new : ndarray
        Rebinned spectrum with shape ``(4, n_pix)``.
        Rows are wavelength, flux, error, mask (1.0 where both rebinned
        flux and error were finite before cleaning, 0.0 otherwise).
    conservation_ratio : float
        Percentage ratio between rebinned and original integrated flux,
        or NaN when the original integrated flux is exactly zero.
    """
    wl, flux, error = spectrum
    wl = np.asarray(wl, dtype=float)
    flux = np.asarray(flux, dtype=float)
    error = np.asarray(error, dtype=float)

    # Sort all three arrays by increasing wavelength.
    order = np.argsort(wl)
    wl = wl[order]
    flux = flux[order]
    error = error[order]

    regrid = np.linspace(wl[0], wl[-1], int(n_pix))

    # NOTE(review): `spectres` is not imported anywhere in the visible part
    # of this file; presumably `from spectres import spectres` is intended.
    # Confirm and add the import at module level, otherwise this call
    # raises NameError.
    new_flux, new_error = spectres(
        regrid,
        wl,
        flux,
        error,
        fill=fill,
        verbose=False,
    )

    # Mask is computed before cleaning so it records which pixels were bad.
    valid = np.isfinite(new_flux) & np.isfinite(new_error)
    if clean_invalid:
        new_flux = np.where(valid, new_flux, 0.0)
        new_error = np.where(valid, new_error, 0.0)

    # np.trapezoid only exists in NumPy >= 2.0; older releases provide the
    # same integrator as np.trapz.
    integrate = getattr(np, "trapezoid", None)
    if integrate is None:
        integrate = np.trapz

    original_flux = integrate(flux, wl)
    rebinned_flux = integrate(new_flux, regrid)
    if original_flux != 0:
        conservation_ratio = 100.0 * rebinned_flux / original_flux
    else:
        conservation_ratio = np.nan

    original = np.stack([wl, flux, error], axis=0)
    new = np.stack(
        [
            regrid,
            new_flux,
            new_error,
            valid.astype(float),
        ],
        axis=0,
    )
    return original, new, conservation_ratio
def rebin_spectra_list(
    spectra_list,
    n_pix=None,
    file_names=None,
    redshifts=None,
    coords=None,
    npix_original=None,
):
    """
    Rebin a list of spectra to the same number of pixels.

    Parameters
    ----------
    spectra_list : list
        List of spectra. Each item must be ``(wl, flux, error)``.
    n_pix : int or None, optional
        Number of pixels for the output grid. If None, uses the minimum
        input spectrum length.
    file_names : list or None, optional
        Names used as dictionary keys. Defaults to ``spectrum_<index>``.
        Repeated names overwrite earlier entries.
    redshifts : list or None, optional
        Redshifts associated with each spectrum. Defaults to NaN.
    coords : list or None, optional
        Coordinates associated with each spectrum. Defaults to None.
    npix_original : list or None, optional
        Original number of pixels for each spectrum. Defaults to the
        length of each spectrum's wavelength array.

    Returns
    -------
    new_dic : dict
        Maps each file name to a dict with keys ``"original"``, ``"new"``,
        ``"conservation_ratio"``, ``"coords"``, ``"z"``, ``"dr_name"`` and
        ``"npix"``. Empty when `spectra_list` is empty.
    """
    n_spectra = len(spectra_list)
    # Nothing to rebin; also avoids min() failing on an empty sequence
    # when n_pix is None.
    if n_spectra == 0:
        return {}

    if n_pix is None:
        n_pix = min(len(s[0]) for s in spectra_list)
    n_pix = int(n_pix)

    if file_names is None:
        file_names = [f"spectrum_{i}" for i in range(n_spectra)]
    if redshifts is None:
        redshifts = [np.nan] * n_spectra
    if coords is None:
        coords = [None] * n_spectra
    if npix_original is None:
        npix_original = [len(s[0]) for s in spectra_list]

    new_dic = {}
    for i, spectrum in enumerate(spectra_list):
        original, new, conservation_ratio = rebin_one_spectrum(
            spectrum=spectrum,
            n_pix=n_pix,
        )
        file_name = file_names[i]
        new_dic[file_name] = {
            "original": original,
            "new": new,
            "conservation_ratio": conservation_ratio,
            "coords": coords[i],
            "z": redshifts[i],
            "dr_name": file_name,
            "npix": npix_original[i],
        }
    return new_dic