# Source code for sheap.Profiles.profiles_templates

"""
Template-Based Profiles
=======================

This module provides template-driven spectral components used in *sheap*:

- **Fe II templates** (UV, optical, combined) read from ASCII files and broadened
    via FFT convolution.
- **Balmer high-order blends** represented as fixed templates.
- **Host galaxy templates** based on E-MILES SSP cubes, sub-selected in metallicity,
    age, and wavelength, and combined with free weights.

Functions
---------
- ``make_template_function`` :
    Factory for template models (Fe II, Balmer high-order) by name. Supports optional
    wavelength cuts and returns a JAX-ready profile function plus template metadata.
- ``make_host_function`` :
    Factory for host galaxy models from a precomputed SSP cube. Uses efficient
    memory mapping and a single FFT-based convolution of the weighted template sum.

Constants
---------
- ``TEMPLATES_PATH`` : Path to the bundled template data directory.
- ``TEMPLATES`` : Registry of available template definitions (Fe II and Balmer high-order).

Notes
-----
- All returned models are decorated with ``@with_param_names`` and are JAX-compatible.
- FFT-based Gaussian broadening quadratically subtracts the intrinsic template
    resolution before applying user-defined FWHM.
- Host models build parameter names dynamically as ``weight_Z{Z}_age{age}``
    for each included SSP grid point.

Todo
----
Rename ``make_feii_template_function`` for a more general function. 
"""

__author__ = 'felavila'

# Public API of this module. Every name listed here must exist at module
# level, otherwise ``from sheap.Profiles.profiles_templates import *`` raises
# AttributeError. The template factory defined in this file is
# ``make_template_function`` (the older ``make_feii_template_function`` name
# is not defined here), so that is the name exported.
__all__ = [
    "TEMPLATES_PATH",
    "make_template_function",
    "make_host_function",
]

from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from pathlib import Path

import jax 
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np

from sheap.Profiles.Utils import with_param_names
from sheap.Utils.Constants import DEFAULT_C_KMS

# Directory with the bundled template data files; it sits under the parent of
# the directory that contains this module.
TEMPLATES_PATH = Path(__file__).resolve().parent.parent / "SuportData" / "templates"

# Registry of 1-D spectral templates, keyed by template name.
#
# Per-entry keys:
#   file             -- path of the ASCII table (wavelength, flux columns).
#   central_wl       -- reference wavelength used to turn the pixel step into
#                       a km/s-per-pixel dispersion when none is given.
#   sigmatemplate    -- intrinsic Gaussian sigma of the template in km/s,
#                       written as FWHM / 2.355.
#   fixed_dispersion -- optional km/s-per-pixel dispersion; ``None`` or a
#                       missing key means "derive it from the grid step".
TEMPLATES: Dict[str, Dict[str, Any]] = {
    "feop": {
        "file": TEMPLATES_PATH / "fe2_Op.dat",
        "central_wl": 4650.0,
        "sigmatemplate": 900.0 / 2.355,
        "fixed_dispersion": None,
    },
    "feuv": {
        "file": TEMPLATES_PATH / "fe2_UV02.dat",
        "central_wl": 2795.0,
        "sigmatemplate": 900.0 / 2.355,
        "fixed_dispersion": 106.3,
    },
    "feuvop": {
        "file": TEMPLATES_PATH / "uvofeii1000kms.txt",
        "central_wl": 4570.0,
        "sigmatemplate": 1000.0 / 2.355,
    },
    "BalHiOrd": {
        "file": TEMPLATES_PATH / "BalHiOrd_FWHM1000.dat",
        "sigmatemplate": 1000.0 / 2.355,
        "central_wl": 3675.0,
    },
}

def make_template_function(
    name: str,
    x_min: Optional[float] = None,  # Angstroms (linear)
    x_max: Optional[float] = None,
) -> Dict[str, Any]:
    """
    Build a broadened, shiftable template model from an entry of ``TEMPLATES``.

    The template table (wavelength, flux) is read from the registered file,
    optionally trimmed to ``[x_min, x_max]`` widened by 50 Å on each side
    (clamped to the template's own range), and normalized so that the kept
    flux samples sum to one.

    Parameters
    ----------
    name
        Key into ``TEMPLATES``.
    x_min, x_max
        Optional wavelength limits in Angstrom. Either may be given alone.

    Returns
    -------
    dict
        ``{"model": callable, "template_info": dict}`` where

        - ``model(x, params)`` takes ``params = [logamp, logFWHM, vshift_kms]``
          and returns the template flux evaluated at ``x``;
        - ``template_info`` holds ``name``, ``file``, ``central_wl``,
          ``sigmatemplate``, ``fixed_dispersion``, ``x_min``, ``x_max``, ``dl``.

    Raises
    ------
    KeyError
        If ``name`` is not registered.
    ValueError
        If the cut leaves no samples, or fewer than three samples remain.

    Notes
    -----
    - Broadening is a Gaussian applied in pixel space through an FFT product
      (no padding, so the convolution is circular). The Gaussian width is the
      requested sigma with the template's intrinsic sigma removed in
      quadrature; a softplus keeps that difference positive.
    - The pixel step ``dl`` is taken from the first two samples. When the
      registry gives no ``fixed_dispersion`` the km/s-per-pixel value is
      ``dl / central_wl * DEFAULT_C_KMS``.
    - The velocity shift stretches the template grid:
      ``wl_shifted = wl * (1 + vshift_kms / DEFAULT_C_KMS)``.
    - Outside the (shifted) template range the model returns zero.
    """
    if name not in TEMPLATES:
        raise KeyError(f"No such template: {name}")
    entry = TEMPLATES[name]

    template_file = entry["file"]
    central_wl = entry["central_wl"]
    sigmatemplate = entry["sigmatemplate"]
    registry_dispersion = entry.get("fixed_dispersion", None)

    table = np.loadtxt(template_file, comments="#").T
    wl = np.array(table[0], dtype=np.float64)
    flux = np.array(table[1], dtype=np.float64)

    # Trim to the requested window plus a 50 Å guard band on each side.
    cut_requested = not (x_min is None and x_max is None)
    if cut_requested:
        keep = np.ones_like(wl, dtype=bool)
        if x_min is not None:
            lower_edge = max(x_min - 50.0, wl.min())
            keep &= wl >= lower_edge
        if x_max is not None:
            upper_edge = min(x_max + 50.0, wl.max())
            keep &= wl <= upper_edge
        if not np.any(keep):
            raise ValueError(f"No wavelength values left after applying x_min/x_max cut for {name}")
        wl, flux = wl[keep], flux[keep]

    if wl.size < 3:
        raise ValueError("Template too short after cutting; need at least 3 points.")

    # Pixel step, read from the first two samples only.
    dl = float(wl[1] - wl[0])

    # Unit-sum normalization of whatever survived the cut.
    unit_flux = flux / np.clip(np.sum(flux), 1e-10, np.inf)

    # km/s per pixel: registry value if present, otherwise from the grid step.
    fixed_dispersion = (
        (dl / central_wl) * DEFAULT_C_KMS
        if registry_dispersion is None
        else float(registry_dispersion)
    )

    param_names = ["logamp", "logFWHM", "vshift_kms"]

    # Constants captured by the closure, converted to JAX arrays once.
    wl_jax = jnp.asarray(wl)
    unit_flux_jax = jnp.asarray(unit_flux)

    @with_param_names(param_names)
    def model(x: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray:
        logamp, logFWHM, vshift_kms = params

        amp = 10.0 ** logamp
        FWHM = 10.0 ** logFWHM              # km/s
        sigma_model = FWHM / 2.355          # km/s

        # Extra broadening needed on top of the template's own width;
        # softplus keeps the squared difference positive and smooth.
        diff_sq = sigma_model**2 - sigmatemplate**2
        diff_sq_safe = jax.nn.softplus(diff_sq / 10.0) * 10.0 + 1e-12
        delta_sigma = jnp.sqrt(diff_sq_safe)  # km/s

        # km/s -> pixels
        sigma_pix = delta_sigma / fixed_dispersion

        # Gaussian transfer function on the pixel-frequency axis.
        n_pix = unit_flux_jax.shape[0]
        freq = jnp.fft.fftfreq(n_pix, d=1.0)
        gauss_tf = jnp.exp(-2.0 * (jnp.pi * freq * sigma_pix) ** 2)

        spec_fft = jnp.fft.fft(unit_flux_jax)
        broadened = jnp.real(jnp.fft.ifft(spec_fft * gauss_tf))

        # Positive velocity moves template features to longer wavelengths.
        beta = vshift_kms / DEFAULT_C_KMS
        xp = wl_jax * (1.0 + beta)

        interp = jnp.interp(x, xp, broadened, left=0.0, right=0.0)
        return amp * interp

    template_info = {
        "name": name,
        "file": str(template_file),
        "central_wl": float(central_wl),
        "sigmatemplate": float(sigmatemplate),
        "fixed_dispersion": float(fixed_dispersion),
        "x_min": None if x_min is None else float(x_min),
        "x_max": None if x_max is None else float(x_max),
        "dl": dl,
    }
    return {"model": model, "template_info": template_info}

def _validate_log_wave(wave: np.ndarray) -> None:
    """Raise ``ValueError`` unless *wave* is a usable template wavelength axis."""
    if wave.size < 4:
        raise ValueError("Template wavelength grid too short.")
    if not np.all(np.isfinite(wave)) or np.any(wave <= 0):
        raise ValueError("wave_log must be finite and strictly > 0 (Angstrom).")
    if np.any(np.diff(wave) <= 0):
        raise ValueError("wave_log must be strictly increasing.")


def make_host_function(
    filename: str | Path = TEMPLATES_PATH / "miles_cube_log.npz",
    z_include: Optional[Union[tuple[float, float], list[float]]] = (-0.7, 0.22),
    age_include: Optional[Union[tuple[float, float], list[float]]] = (0.1, 10.0),
    xmin: Optional[float] = None,
    xmax: Optional[float] = None,
    verbose: Optional[bool] = None,
    **kwargs,
) -> dict:
    r"""
    Factory for host-galaxy SSP models stored on a log-sampled wavelength grid.

    The returned model builds a weighted sum of SSP templates, then applies a
    single FFT-based Gaussian broadening in **velocity space** (LOSVD-like),
    and finally interpolates to the user grid ``x``.

    Parameters
    ----------
    filename
        Either ``"miles"``, ``"xsl"``, a path to a ``*_cube_log.npz`` file,
        or ``None`` (defaults to MILES).
    z_include
        Metallicity selection; the min and max of the given values bound the
        kept grid points. ``None`` keeps all.
    age_include
        Age selection; the min and max of the given values bound the kept
        grid points. ``None`` keeps all.
    xmin, xmax
        Optional wavelength window in Angstrom. A ±50 Å guard band is applied.
    verbose
        Print selected grid size.
    **kwargs
        ``data`` : optional mapping with keys ``filename``, ``cube_log``,
        ``wave_log`` and ``sigmatemplate``. When given (and truthy) it
        replaces the cube, wavelength axis and template sigma read from
        ``filename`` (experimental). The weights are then named by integer
        grid index instead of metallicity/age value.

    Returns
    -------
    dict
        ``{"model": callable, "host_info": dict}``

    Raises
    ------
    KeyError
        If ``filename`` is neither a known alias nor an existing path.
    ValueError
        If the wavelength axis is invalid (too short, non-finite,
        non-positive, not strictly increasing), if a metallicity/age/
        wavelength selection leaves nothing, or if an external cube and
        wavelength axis disagree in length.

    Notes
    -----
    - **Important:** The template wavelength axis in the NPZ file must be
      sampled approximately uniformly in :math:`\ln(\lambda)` (i.e.
      *log-spaced in wavelength*). The convolution kernel is constructed in
      :math:`\ln(\lambda)`, so the model assumes a constant velocity step per
      pixel.
    - The input ``x`` passed to ``model(x, params)`` is in **Angstrom**, but
      should be sampled on a grid that is also approximately uniform in
      :math:`\ln(\lambda)` if you want the broadening to behave exactly like a
      Gaussian LOSVD on your evaluation grid. (Interpolation to arbitrary
      ``x`` is allowed, but the kernel is defined on the template log-grid.)
    - Velocity shift is applied as
      :math:`\lambda \rightarrow \lambda \exp(v/c)`.
    - The FFT product is not zero-padded, so the broadening is a circular
      convolution on the (cut) template grid.

    The parameter order is:
    ``[logamp, logFWHM, vshift_kms, weight_0, weight_1, ...]``.
    """
    allowed = {
        "miles": TEMPLATES_PATH / "miles_cube_log.npz",
        "xsl": TEMPLATES_PATH / "xsl_cube_log.npz",
        None: TEMPLATES_PATH / "miles_cube_log.npz",
    }
    if filename in allowed:
        filename = allowed[filename]
    else:
        filename = Path(filename)
        if not filename.exists():
            raise KeyError(
                f"file_name '{filename}' is not available. Available: ['miles', 'xsl'] or a valid path."
            )

    # Copy everything we need out of the archive and close it right away so
    # the underlying file handle is not leaked.
    # NOTE(review): mmap_mode appears to have no effect for .npz archives;
    # kept so the np.load call is unchanged.
    with np.load(filename, mmap_mode="r") as data:
        cube = np.asarray(data["cube_log"], dtype=np.float32)  # (n_Z, n_age, n_pix)
        wave = np.asarray(data["wave_log"], dtype=np.float32)  # (n_pix,) Angstrom, log-sampled
        all_ages = np.asarray(data["ages_sub"], dtype=np.float32)
        all_zs = np.asarray(data["zs_sub"], dtype=np.float32)
        sigmatemplate = float(data["sigmatemplate"])  # km/s

    _validate_log_wave(wave)

    if z_include is not None:
        z_min, z_max = float(np.min(z_include)), float(np.max(z_include))
        z_mask = (all_zs >= z_min) & (all_zs <= z_max)
        if not np.any(z_mask):
            raise ValueError(f"No metallicities in range {z_min} to {z_max}")
        zs = all_zs[z_mask]
        cube = cube[z_mask, :, :]
    else:
        zs = all_zs

    if age_include is not None:
        a_min, a_max = float(np.min(age_include)), float(np.max(age_include))
        a_mask = (all_ages >= a_min) & (all_ages <= a_max)
        if not np.any(a_mask):
            raise ValueError(f"No ages in range {a_min} to {a_max}")
        ages = all_ages[a_mask]
        cube = cube[:, a_mask, :]
    else:
        ages = all_ages

    grid_metadata = [(float(Z), float(age)) for Z in zs for age in ages]

    external = kwargs.get("data")
    if external:
        print("We will use external template experimental")
        filename = external["filename"]
        cube = np.asarray(external["cube_log"], dtype=np.float32)
        wave = np.asarray(external["wave_log"], dtype=np.float32)  # (n_pix,) Angstrom, log-sampled
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if wave.shape[0] != cube.shape[-1]:
            raise ValueError("check the shapes of wave_log and cube_log")
        # External grids get the same sanity checks as the bundled ones.
        _validate_log_wave(wave)
        sigmatemplate = external["sigmatemplate"]
        n_Z, n_age = cube.shape[:-1]
        # No physical labels are supplied, so grid points are named by index.
        grid_metadata = [
            (float(Z), float(age)) for Z in np.arange(n_Z) for age in np.arange(n_age)
        ]

    f = 1.0  # keep if you later want external scaling

    if xmin is not None or xmax is not None:
        mask = np.ones_like(wave, dtype=bool)
        if xmin is not None:
            mask &= wave >= max(float(xmin) * f - 50.0, float(wave.min()))
        if xmax is not None:
            mask &= wave <= min(float(xmax) * f + 50.0, float(wave.max()))
        if not np.any(mask):
            raise ValueError("No wavelength values left after applying xmin/xmax.")
        wave = wave[mask].astype(np.float32, copy=False)
        cube = cube[:, :, mask].astype(np.float32, copy=False)

    # The log-step estimate below needs at least two samples.
    if wave.size < 2:
        raise ValueError("Template wavelength grid too short after applying xmin/xmax.")

    n_Z, n_age, n_pix = cube.shape
    if verbose:
        print(f"Host added with n_Z={n_Z}, n_age={n_age}, n_pix={n_pix}")

    # Normalize every SSP to unit summed flux; eps guards all-zero templates.
    eps = 1e-30
    flux_int = np.nansum(cube, axis=-1, keepdims=True)
    cube = cube / (flux_int + eps)

    templates_flat = cube.reshape(-1, n_pix)  # (n_Z*n_age, n_pix)

    # One linear weight per SSP grid point; "." becomes "p" in the name.
    linear_params = [
        "weight_Z{}_age{}".format(str(Z).replace(".", "p"), str(age).replace(".", "p"))
        for Z, age in grid_metadata
    ]
    param_names = ["logamp", "logFWHM", "vshift_kms", *linear_params]

    templates_jax = jnp.asarray(templates_flat)  # (Ntemp, n_pix) float32
    wave_jax = jnp.asarray(wave)                 # (n_pix,) float32

    # Convolution kernel lives in ln(lambda): dln is ~constant when the grid
    # is log-sampled, so one pixel is one fixed velocity step (c * dln).
    dln = float(np.mean(np.gradient(np.log(wave.astype(np.float64)))))
    freq = jnp.fft.fftfreq(n_pix, d=dln).astype(jnp.float32)

    @with_param_names(param_names, linear_param_names=linear_params)
    def model(x: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray:
        logamp = params[0]
        logFWHM = params[1]
        vshift_kms = params[2]
        weights = params[3:]  # (Ntemp,)

        amplitude = 10.0 ** logamp
        # Sum the weighted templates first so only one FFT pair is needed.
        base = jnp.tensordot(weights, templates_jax, axes=(0, 0))  # (n_pix,)

        FWHM = 10.0 ** logFWHM
        sigma_model = FWHM / 2.355  # km/s

        # Remove the template's own sigma in quadrature; softplus keeps the
        # squared difference positive and differentiable.
        diff_sq = sigma_model**2 - sigmatemplate**2
        diff_sq_safe = jax.nn.softplus(diff_sq / 10.0) * 10.0 + 1e-12
        delta_sigma = jnp.sqrt(diff_sq_safe)  # km/s

        # Gaussian width expressed in ln(lambda) units (v / c).
        sigma_y = (delta_sigma / DEFAULT_C_KMS).astype(jnp.float32)
        gauss_tf = jnp.exp(-2.0 * (jnp.pi * freq * sigma_y) ** 2)

        conv = jnp.real(jnp.fft.ifft(jnp.fft.fft(base) * gauss_tf))

        xp = wave_jax * jnp.exp(vshift_kms / DEFAULT_C_KMS)
        return amplitude * jnp.interp(x * f, xp, conv, left=0.0, right=0.0)

    return {
        "model": model,
        "host_info": {
            "z_include": zs,
            "age_include": ages,
            "n_Z": n_Z,
            "n_age": n_age,
            "xmin": xmin,
            "xmax": xmax,
            "file_name": str(filename),
            "dln": dln,
            "dv_pix_kms": float(DEFAULT_C_KMS * dln),
            "sigmatemplate": sigmatemplate,
            "data": kwargs.get("data", None),
        },
    }
def make_host_function_classic(
    filename: Union[str, Path, None] = TEMPLATES_PATH / "miles_cube_log.npz",
    z_include: Optional[Union[tuple[float, float], list[float]]] = (-0.7, 0.22),
    age_include: Optional[Union[tuple[float, float], list[float]]] = (0.1, 10.0),
    xmin: Optional[float] = None,
    xmax: Optional[float] = None,
    verbose: Optional[bool] = None,
    **kwargs,
) -> dict:
    """
    Memory-lean host model using the dispersion stored in the template file.

    - sums weighted templates first, then does a single FFT-based convolution
    - keeps arrays in float32
    - converts the km/s broadening to pixels with the file's
      ``fixed_dispersion`` (km/s per pixel) and applies it on a grid whose
      step is taken from the first two wavelength samples

    Parameters
    ----------
    filename
        ``"miles"``, ``"xsl"``, ``None`` (MILES), or the bundled MILES / XSL
        cube path itself. Anything else is rejected.
    z_include
        Metallicity selection; min and max of the given values bound the kept
        grid points. ``None`` keeps all.
    age_include
        Age selection; min and max of the given values bound the kept grid
        points. ``None`` keeps all.
    xmin, xmax
        Optional wavelength window in Angstrom, widened by 50 Å on each side.
    verbose
        Print the selected grid size and the weight parameter names.
    **kwargs
        Accepted for signature compatibility; not used.

    Returns
    -------
    dict
        ``{"model": callable, "host_info": dict}``. The model parameters are
        ``[logamp, logFWHM, vshift_kms, weight_..., ...]``; the third
        parameter is a velocity shift in km/s applied as
        ``wave * (1 + vshift_kms / DEFAULT_C_KMS)``.

    Raises
    ------
    KeyError
        If ``filename`` is not one of the accepted values.
    ValueError
        If a metallicity/age/wavelength selection leaves nothing, or fewer
        than two wavelength samples remain.
    """
    miles_path = TEMPLATES_PATH / "miles_cube_log.npz"
    xsl_path = TEMPLATES_PATH / "xsl_cube_log.npz"

    # Resolve aliases AND explicit bundled paths. Previously the XSL path
    # passed the whitelist but then fell through to the MILES default.
    if filename in ("xsl", xsl_path):
        filename = xsl_path
    elif filename in ("miles", miles_path, None):
        filename = miles_path
    else:
        raise KeyError(
            f"file_name '{filename}' is not available file_name."
            f"Available parameters: {['miles','xsl']}")

    f = 1.  # wavelength scale factor applied to xmin/xmax and x; fixed to 1

    # Copy the arrays out and close the archive so its file handle is released.
    # NOTE(review): mmap_mode appears to have no effect for .npz archives;
    # kept so the np.load call is unchanged.
    with np.load(filename, mmap_mode="r") as data:
        cube = np.asarray(data["cube_log"], dtype=np.float32)  # (n_Z, n_age, n_pix)
        wave = np.asarray(data["wave_log"], dtype=np.float32)
        all_ages = np.asarray(data["ages_sub"], dtype=np.float32)
        all_zs = np.asarray(data["zs_sub"], dtype=np.float32)
        sigmatemplate = float(data["sigmatemplate"])
        fixed_dispersion = float(data["fixed_dispersion"])

    if z_include is not None:
        z_min, z_max = np.min(z_include), np.max(z_include)
        z_mask = (all_zs >= z_min) & (all_zs <= z_max)
        if not np.any(z_mask):
            raise ValueError(f"No metallicities in range {z_min} to {z_max}")
        zs = all_zs[z_mask]
        cube = cube[z_mask, :, :]
    else:
        zs = all_zs

    if age_include is not None:
        a_min, a_max = np.min(age_include), np.max(age_include)
        a_mask = (all_ages >= a_min) & (all_ages <= a_max)
        if not np.any(a_mask):
            raise ValueError(f"No ages in range {a_min} to {a_max}")
        ages = all_ages[a_mask]
        cube = cube[:, a_mask, :]
    else:
        ages = all_ages

    if xmin is not None or xmax is not None:
        mask = np.ones_like(wave, dtype=bool)
        if xmin is not None:
            mask &= wave >= max([xmin * f - 50.0, float(wave.min())])
        if xmax is not None:
            mask &= wave <= min([xmax * f + 50.0, float(wave.max())])
        if not np.any(mask):
            raise ValueError("No wavelength values left after applying x_min/x_max cut.")
        wave = wave[mask].astype(np.float32, copy=False)
        cube = cube[:, :, mask].astype(np.float32, copy=False)

    # The grid step below is read from the first two samples.
    if wave.size < 2:
        raise ValueError("Template wavelength grid too short; need at least 2 points.")
    dx = float(wave[1] - wave[0])

    n_Z, n_age, n_pix = cube.shape
    if verbose:
        print(f"Host added with n_Z: {n_Z} and n_age: {n_age}")

    # Normalize every SSP to unit summed flux; eps guards all-zero templates.
    eps = 1e-30
    flux_int = np.nansum(cube, axis=-1, keepdims=True)
    cube = cube / (flux_int + eps)

    templates_flat = cube.reshape(-1, n_pix)

    grid_metadata = [(float(Z), float(age)) for Z in zs for age in ages]

    # One linear weight per SSP grid point; "." becomes "p" in the name.
    linear_params = [
        "weight_Z{}_age{}".format(str(Z).replace(".", "p"), str(age).replace(".", "p"))
        for Z, age in grid_metadata
    ]
    param_names = ["logamp", "logFWHM", "vshift_kms", *linear_params]

    templates_jax = jnp.asarray(templates_flat)  # (N, P) float32
    wave_jax = jnp.asarray(wave)
    freq = jnp.fft.fftfreq(n_pix, d=dx).astype(jnp.float32)  # (P,)

    # Formerly an unconditional debug print; now only shown on request.
    if verbose:
        print(linear_params)

    @with_param_names(param_names, linear_param_names=linear_params)
    def model(x: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray:
        logamp = params[0]
        amplitude = 10.0 ** logamp
        logFWHM = params[1]
        vshift_kms = params[2]
        weights = params[3:]  # (N,)

        # Sum the weighted templates first so only one FFT pair is needed.
        base = jnp.tensordot(weights, templates_jax, axes=(0, 0))  # (P,)

        # --- broadening ---
        FWHM = 10.0 ** logFWHM            # total FWHM in km/s
        sigma_model = FWHM / 2.355        # km/s

        # Remove the template's own sigma in quadrature; softplus keeps the
        # squared difference positive and differentiable.
        diff_sq = sigma_model**2 - sigmatemplate**2
        diff_sq_safe = jax.nn.softplus(diff_sq / 10.0) * 10.0 + 1e-12
        delta_sigma = jnp.sqrt(diff_sq_safe)  # km/s

        # km/s -> pixels -> same units as the frequency axis spacing dx.
        sigma_pix = delta_sigma / fixed_dispersion
        sigma_lambda = sigma_pix * dx

        gauss_tf = jnp.exp(-2.0 * (jnp.pi * freq * sigma_lambda) ** 2)

        base_fft = jnp.fft.fft(base)
        conv = jnp.real(jnp.fft.ifft(base_fft * gauss_tf))

        beta = vshift_kms / DEFAULT_C_KMS
        xp = wave_jax * (1.0 + beta)
        return amplitude * jnp.interp(x * f, xp, conv, left=0.0, right=0.0)

    return {
        "model": model,
        "host_info": {
            "z_include": zs,
            "age_include": ages,
            "n_Z": n_Z,
            "n_age": n_age,
            "xmin": xmin,
            "xmax": xmax,
            "file_name": filename,
        },
    }