"""
FITS Spectrum Readers
=====================
This module provides utilities to read spectra from different survey
and simulation formats (SDSS, DESI, PyQSO, and custom simulations).
It also includes parallel and batched readers to handle multiple files
efficiently, with fallbacks for sequential reading.
Readers
-------
- fits_reader_sdss: SDSS spectra (PLATE-MJD-FIBERID format)
- fits_reader_desi: DESI spectra
- fits_reader_pyqso: PyQSO pipeline spectra
- fits_reader_simulation: Simulated spectra
Batching / Parallel utilities
-----------------------------
- parallel_reader
- batched_reader
- sequential_reader
"""
__author__ = 'felavila'
__all__ = [
"READER_FUNCTIONS",
"batched_reader",
"fits_reader_desi",
"fits_reader_pyqso",
"fits_reader_sdss",
"fits_reader_simulation",
"n_cpu",
"parallel_reader",
"sequential_reader",
]
import os
import numpy as np
from multiprocessing import Pool, set_start_method
from astropy.io import fits
from functools import partial
#from sheap.Utils.SpectralSetup import resize_and_fill_with_nans
# Limit CPUs for safety
n_cpu = min(4, os.cpu_count())
[docs]
def fits_reader_desi(file: str):
"""
Read a DESI FITS spectrum.
Parameters
----------
file : str
Path to DESI FITS file.
Returns
-------
data_array : np.ndarray
Array with shape (3, n_pix): [wavelength, flux, error].
header_array : np.ndarray
Array with RA and DEC from header.
"""
hdul = fits.open(file)
flux_scale = float(hdul[1].header["TUNIT2"].split(" ")[0])
ivar_scale = float(hdul[1].header["TUNIT3"].split(" ")[0])
data = hdul[1].data
data_array = np.array([
data["WAVELENGTH"],
data["FLUX"] * flux_scale,
1 / np.sqrt(data["IVAR"] * ivar_scale)
])
data_array[np.isinf(data_array)] = 1e20
header_array = np.array([hdul[0].header["RA"], hdul[0].header["DEC"]])
return data_array, header_array
[docs]
def fits_reader_simulation(file: str, chanel: int = 1, template: bool = False):
"""
Read a simulated spectrum from a FITS file.
Parameters
----------
file : str
Path to simulation FITS file.
chanel : int, default=1
HDU extension index to read.
template : bool, default=False
If True, reads template arrays.
Returns
-------
data_array : np.ndarray
Array with shape (n_channels, n_pix).
header_array : list
Empty or metadata, depending on template.
"""
hdul = fits.open(file)
header_array = []
if template:
data_array = np.array([
hdul[chanel].data['LAMBDA'],
hdul[chanel].data["FLUX_DENSITY"]
])
return data_array.squeeze(), header_array
if chanel == 1:
data_array = np.array([
hdul[chanel].data["WAVE"],
hdul[chanel].data["FLUX"],
hdul[chanel].data["ERR_FLUX"],
])
else:
data_array = np.array([
hdul[chanel].data["WAVE"],
hdul[chanel].data["FLUX"],
hdul[chanel].data["ERR"],
])
return data_array.squeeze(), header_array
[docs]
def fits_reader_sdss(file: str):
"""
Read an SDSS FITS spectrum.
Parameters
----------
file : str
Path to SDSS FITS file.
Returns
-------
data_array : np.ndarray
Array with shape (4, n_pix): [wavelength, flux, error, wdisp].
header_array : np.ndarray
Array with RA and DEC from header.
"""
hdul = fits.open(file)
flux_scale = float(hdul[0].header["BUNIT"].split(" ")[0])
data = hdul[1].data
data_array = np.array([
10 ** data["loglam"],
data["flux"] * flux_scale,
flux_scale / np.sqrt(data["ivar"]),
data["wdisp"]
])
data_array[np.isinf(data_array)] = 1e20
header_array = np.array([hdul[0].header["PLUG_RA"], hdul[0].header["PLUG_DEC"]])
return data_array, header_array
[docs]
def fits_reader_pyqso(file: str):
"""
Read a PyQSO-format spectrum.
Parameters
----------
file : str
Path to PyQSO FITS file.
Returns
-------
spectra : np.ndarray
Array with shape (3, n_pix): [wavelength, flux, error].
header_array : list
Empty list (no coords stored).
"""
hdul = fits.open(file)
spectra = np.array([
hdul[3].data["wave_prereduced"],
hdul[3].data["flux_prereduced"],
hdul[3].data["err_prereduced"],
])
return spectra, []
READER_FUNCTIONS = {
"fits_reader_sdss": fits_reader_sdss,
"fits_reader_simulation": fits_reader_simulation,
"fits_reader_pyqso": fits_reader_pyqso,
"fits_reader_desi": fits_reader_desi,
}
[docs]
def parallel_reader(paths, n_cpu=n_cpu, function=fits_reader_sdss, **kwargs):
"""
Parallel reader using multiprocessing.
Parameters
----------
paths : list of str
Paths to FITS files.
n_cpu : int, optional
Number of processes to use (default=min(4, os.cpu_count())).
function : callable or str, optional
Reader function or key in `READER_FUNCTIONS`.
Returns
-------
coords : np.ndarray
Coordinates from headers (RA, DEC).
spectra_reshaped : list
Placeholder for reshaped spectra (currently empty).
spectra : list of np.ndarray
Raw spectra arrays.
"""
if isinstance(function, str):
function = READER_FUNCTIONS[function]
func_with_args = partial(function, **kwargs)
with Pool(processes=min(n_cpu, len(paths))) as pool:
results = pool.map(func_with_args, paths, chunksize=1)
spectra = [result[0] for result in results]
coords = np.array([result[1] for result in results])
shapes_max = max(s.shape[1] for s in spectra)
spectra_reshaped = [] # TODO: enable resize_and_fill_with_nans
return coords, spectra_reshaped, spectra
[docs]
def batched_reader(paths, batch_size=8, function=fits_reader_sdss):
"""
Batch reader for safer memory usage.
Parameters
----------
paths : list of str
Paths to FITS files.
batch_size : int, optional
Number of files to read per batch.
function : callable or str, optional
Reader function or key in `READER_FUNCTIONS`.
Returns
-------
coords : np.ndarray
Stacked coordinates from all batches.
spectra_reshaped : str
Placeholder (currently unused).
spectra_raw : list of np.ndarray
All raw spectra arrays.
"""
all_coords, all_reshaped, all_raw = [], [], []
for i in range(0, len(paths), batch_size):
batch = paths[i:i + batch_size]
coords, reshaped, raw = parallel_reader(
batch, n_cpu=min(n_cpu, len(batch)), function=function
)
all_coords.append(coords)
all_reshaped.append(reshaped)
all_raw.extend(raw)
coords = np.vstack(all_coords)
return coords, "unused", all_raw
[docs]
def sequential_reader(paths, function=fits_reader_sdss):
"""
Sequential FITS reader (fallback for debugging).
Parameters
----------
paths : list of str
Paths to FITS files.
function : callable or str, optional
Reader function or key in `READER_FUNCTIONS`.
Returns
-------
coords : np.ndarray
Coordinates from headers (RA, DEC).
spectra_reshaped : np.ndarray
Reshaped spectra array.
spectra : list of np.ndarray
Raw spectra arrays.
"""
results = []
for i in paths:
try:
results.append(function(i))
except Exception as e:
print(f"Failed to read {i}: {e}")
spectra = [result[0] for result in results]
coords = np.array([result[1] for result in results])
shapes_max = max(s.shape[1] for s in spectra)
spectra_reshaped = [] # TODO: enable resize_and_fill_with_nans
return coords, spectra_reshaped, spectra
# Ensure start method is set safely when calling as a script
if __name__ == '__main__':
try:
set_start_method("spawn", force=True)
except RuntimeError:
pass # already set