Source code for sheap.Profiles.profiles_lines

r"""
Line Profiles
=============

This module implements all emission- and absorption-line profile functions 
available in *sheap*. These functions define the mathematical shapes of 
spectral lines (Gaussian, Lorentzian, Voigt, skewed Gaussian, EMG, etc.) 
and provide consistent JAX-compatible implementations for fitting routines.

Profiles are parameterized in terms of a linear **amplitude**, a **center**, and
width measures (FWHM, \sigma, \gamma), with extensions for skewness, exponential
decay, or Hermite expansions.

Functions
---------
- ``gaussian_fwhm`` : Standard Gaussian profile with FWHM parameterization.
- ``lorentzian_fwhm`` : Lorentzian profile with FWHM.
- ``voigt_pseudo`` : Pseudo-Voigt (linear combination of Gaussian and Lorentzian).
- ``skewed_gaussian`` : Skew-normal Gaussian with shape parameter α.
- ``emg_fwhm`` : Exponentially Modified Gaussian (Gaussian ⊗ exponential decay).
- ``top_hat`` : Rectangular (boxcar) profile.
- ``eval_hermite`` : Recursive Hermite polynomial evaluator.
- ``gauss_hermite_losvd_jax`` : Gauss–Hermite line-of-sight velocity distribution.

Notes
-----
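- All line-profile functions (everything except the Hermite helpers
    ``eval_hermite`` and ``gauss_hermite_losvd_jax``) are decorated with
    ``@with_param_names`` to provide consistent parameter naming across
    the codebase.
- Amplitudes are linear: the ``amplitude`` parameter multiplies the
    profile directly, with no base-10 exponentiation.
- Functions are written in JAX and are differentiable, suitable for
    gradient-based fitting and uncertainty propagation. ``top_hat`` is
    piecewise constant, so its gradient with respect to ``center`` and
    ``width`` is zero wherever it is defined.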

Examples
--------
.. code-block:: python

    import jax.numpy as jnp
    from sheap.Profiles.profiles_lines import gaussian_fwhm

    x = jnp.linspace(6500, 6600, 1000)
    params = jnp.array([1.0, 6563.0, 10.0])  # amplitude=1 (peak), center=6563Å, FWHM=10Å
    y = gaussian_fwhm(x, params)

"""

__author__ = 'felavila'


# Public API of this module: ``from <module> import *`` exports exactly these names.
__all__ = [
    "emg_fwhm",
    "eval_hermite",
    "gauss_hermite_losvd_jax",
    "gaussian_fwhm",
    "lorentzian_fwhm",
    "skewed_gaussian",
    "top_hat",
    "voigt_pseudo",
]

from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import jax.numpy as jnp
from jax import jit, lax, vmap
from jax.scipy.special import erfc, log_ndtr
from jax.scipy.stats import norm

from sheap.Profiles.Utils import trapz_jax, with_param_names

@with_param_names(["amplitude", "center", "fwhm"])
def gaussian_fwhm(x, params):
    r"""
    Gaussian line profile parameterized by its FWHM.

    .. math::
        f(x) = A \exp\left( -\frac{1}{2} \left( \frac{x - \mu}{\sigma} \right)^2 \right)

    where :math:`A` is the ``amplitude`` parameter used linearly (it is the
    peak value, reached at :math:`x = \mu`) and
    :math:`\sigma = \mathrm{fwhm} / 2.355`. The integrated line flux is
    therefore :math:`A \sigma \sqrt{2\pi}`.

    Parameters
    ----------
    x : jnp.ndarray
        Input wavelength array.
    params : array-like
        Three values ``(amplitude, center, fwhm)``:

        - ``amplitude``: peak height (linear, not logarithmic).
        - ``center``: line center :math:`\mu`, in the units of ``x``.
        - ``fwhm``: full width at half maximum, in the units of ``x``.

    Returns
    -------
    jnp.ndarray
        Profile evaluated at ``x``.
    """
    amplitude, center, fwhm = params
    # 2.355 approximates 2*sqrt(2*ln 2) ~= 2.35482, the FWHM/sigma ratio of a Gaussian.
    sigma = fwhm / 2.355
    return amplitude * jnp.exp(-0.5 * ((x - center) / sigma) ** 2)
@with_param_names(["amplitude", "center", "fwhm"])
def lorentzian_fwhm(x, params):
    r"""
    Lorentzian line profile parameterized by its FWHM.

    .. math::
        f(x) = \frac{A}{1 + \left( \frac{x - \mu}{\gamma} \right)^2}

    where :math:`A` is the ``amplitude`` parameter used linearly (it is the
    peak value, reached at :math:`x = \mu`) and
    :math:`\gamma = \mathrm{fwhm} / 2` is the half width at half maximum.
    The integrated line flux is :math:`A \pi \gamma`.

    Parameters
    ----------
    x : jnp.ndarray
        Input wavelength array.
    params : array-like
        Three values ``(amplitude, center, fwhm)``:

        - ``amplitude``: peak height (linear, not logarithmic).
        - ``center``: line center :math:`\mu`, in the units of ``x``.
        - ``fwhm``: full width at half maximum, in the units of ``x``.

    Returns
    -------
    jnp.ndarray
        Profile evaluated at ``x``.
    """
    amplitude, center, fwhm = params
    hwhm = fwhm / 2.0
    return amplitude / (1.0 + ((x - center) / hwhm) ** 2)
#################### Exotic ##############
@with_param_names(["amplitude", "center", "fwhm_g", "fwhm_l"])
def voigt_pseudo(x, params):
    r"""
    Pseudo-Voigt profile: weighted sum of a Gaussian and a Lorentzian.

    .. math::
        f(x) = A \left[ \eta \, L(x) + (1 - \eta) \, G(x) \right]

    with peak-normalised components

    .. math::
        G(x) = \exp\left(-\frac{(x - \mu)^2}{2\sigma^2}\right), \qquad
        L(x) = \frac{1}{1 + \left(\frac{x - \mu}{\gamma}\right)^2},

    where :math:`A` is the ``amplitude`` parameter used linearly,
    :math:`\sigma = \mathrm{fwhm_g} / 2.355`,
    :math:`\gamma = \mathrm{fwhm_l} / 2`, and the mixing weight is

    .. math::
        \eta = 1.36603\, r - 0.47719\, r^2 + 0.11116\, r^3, \qquad
        r = \frac{\gamma}{\gamma + \sigma\sqrt{2\ln 2}}.

    Both components equal 1 at :math:`x = \mu`, so the peak value is
    :math:`A`. A zero ``fwhm_l`` gives :math:`\eta = 0` (pure Gaussian); a
    zero ``fwhm_g`` gives :math:`\eta \approx 1` (pure Lorentzian, up to
    rounding of the coefficients). Zero-width components are evaluated via
    their limit (1 at the center, 0 elsewhere) instead of producing NaN.
    If both widths are zero, :math:`r` is 0/0 and the result is NaN.

    NOTE(review): the cubic coefficients match the Thompson-Cox-Hastings
    pseudo-Voigt, where they multiply :math:`f_L / f` and both components
    share one combined FWHM :math:`f`. Here each component keeps its own
    width, so this is only an approximation of a true Voigt profile.

    Parameters
    ----------
    x : jnp.ndarray
        Input wavelength array.
    params : array-like
        Four values ``(amplitude, center, fwhm_g, fwhm_l)``:

        - ``amplitude``: peak height (linear, not logarithmic).
        - ``center``: line center :math:`\mu`, in the units of ``x``.
        - ``fwhm_g``: FWHM of the Gaussian component.
        - ``fwhm_l``: FWHM of the Lorentzian component.

    Returns
    -------
    jnp.ndarray
        Profile evaluated at ``x``.
    """
    amplitude, center, fwhm_g, fwhm_l = params
    sigma = fwhm_g / 2.355
    gamma = fwhm_l / 2.0
    dx = x - center

    # sigma * sqrt(2 ln 2) is the Gaussian HWHM, so r is the Lorentzian share
    # of the summed half widths.
    r = gamma / (gamma + sigma * jnp.sqrt(2 * jnp.log(2)))
    eta = 1.36603 * r - 0.47719 * r**2 + 0.11116 * r**3

    # A zero width would make dx / width equal 0/0 = NaN at the center. The
    # "safe" divisors keep the unused branch finite so gradients stay NaN-free.
    zero_width_limit = jnp.where(dx == 0, 1.0, 0.0)
    safe_sigma = jnp.where(sigma != 0, sigma, 1.0)
    safe_gamma = jnp.where(gamma != 0, gamma, 1.0)
    gauss = jnp.where(
        sigma != 0, jnp.exp(-0.5 * (dx / safe_sigma) ** 2), zero_width_limit
    )
    lorentz = jnp.where(
        gamma != 0, 1.0 / (1.0 + (dx / safe_gamma) ** 2), zero_width_limit
    )

    return amplitude * (eta * lorentz + (1.0 - eta) * gauss)
@with_param_names(["amplitude", "center", "fwhm", "alpha"])
def skewed_gaussian(x, params):
    r"""
    Skewed Gaussian profile (Azzalini skew-normal shape).

    .. math::
        f(x) = 2A \, \phi(t) \, \Phi(\alpha t), \qquad t = \frac{x - \xi}{\sigma}

    where :math:`\phi` and :math:`\Phi` are the standard normal PDF and CDF,
    :math:`\xi` is ``center``, :math:`\sigma = \mathrm{fwhm} / 2.355` and
    :math:`A` is the ``amplitude`` parameter used linearly.

    Unlike the other profiles in this module, ``amplitude`` is *not* the
    peak height. For :math:`\alpha = 0` the peak is
    :math:`A / \sqrt{2\pi} \approx 0.399 A`. For any :math:`\alpha` the
    profile integrates to :math:`A \sigma`. Likewise ``center`` and ``fwhm``
    are the location and width of the underlying Gaussian. They coincide
    with the mode and FWHM of the profile only when :math:`\alpha = 0`.

    Parameters
    ----------
    x : jnp.ndarray
        Input wavelength array.
    params : array-like
        Four values ``(amplitude, center, fwhm, alpha)``:

        - ``amplitude``: linear scale factor (see above).
        - ``center``: location parameter :math:`\xi`, in the units of ``x``.
        - ``fwhm``: FWHM of the underlying Gaussian.
        - ``alpha``: skewness parameter. ``alpha > 0`` extends the tail
          toward larger ``x``.

    Returns
    -------
    jnp.ndarray
        Profile evaluated at ``x``.
    """
    amplitude, center, fwhm, alpha = params
    sigma = fwhm / 2.355
    t = (x - center) / sigma
    return 2 * amplitude * norm.pdf(t) * norm.cdf(alpha * t)
@with_param_names(["amplitude", "center", "fwhm", "lambda"])
def emg_fwhm(x, params):
    r"""
    Exponentially Modified Gaussian (EMG) profile.

    .. math::
        f(x) = \frac{A \lambda}{2}
            \exp\left( \frac{\lambda}{2}(2\mu + \lambda\sigma^2 - 2x) \right)
            \mathrm{erfc}\left( \frac{\mu + \lambda\sigma^2 - x}{\sqrt{2}\sigma} \right)

    where :math:`A` is the ``amplitude`` parameter used linearly and
    :math:`\sigma = \mathrm{fwhm} / 2.355`. For :math:`\lambda > 0` the
    factor multiplying :math:`A` is the EMG probability density. The profile
    therefore integrates to :math:`A`, and ``amplitude`` is not the peak
    height. ``center`` is the mean of the Gaussian component, not the peak
    position; the exponential tail extends toward larger ``x``.

    Numerically the product :math:`\exp(\cdot)\,\mathrm{erfc}(\cdot)`
    overflows to ``inf * 0 = NaN`` for :math:`x \ll \mu`, which is easily
    reached in float32. The identity
    :math:`\mathrm{erfc}(u) = 2\Phi(-\sqrt{2}\,u)` is used instead to
    evaluate :math:`A \lambda \exp(\mathrm{arg} + \log\Phi(z))` with
    :math:`z = (x - \mu - \lambda\sigma^2)/\sigma`. This is the same function,
    evaluated stably in log space.

    Parameters
    ----------
    x : jnp.ndarray
        Input wavelength array.
    params : array-like
        Four values ``(amplitude, center, fwhm, lambda)``:

        - ``amplitude``: integrated flux (linear, not logarithmic).
        - ``center``: Gaussian mean :math:`\mu`.
        - ``fwhm``: Gaussian full width at half maximum.
        - ``lambda``: exponential decay rate :math:`1/\tau`, in inverse
          units of ``x``.

    Returns
    -------
    jnp.ndarray
        Profile evaluated at ``x``.
    """
    amplitude, mu, fwhm, lambda_ = params
    sigma = fwhm / 2.355
    exp_arg = 0.5 * lambda_ * (2 * mu + lambda_ * sigma**2 - 2 * x)
    # -sqrt(2) times the original erfc argument simplifies to z.
    z = (x - mu - lambda_ * sigma**2) / sigma
    return amplitude * lambda_ * jnp.exp(exp_arg + log_ndtr(z))
@with_param_names(["amplitude", "center", "width"])
def top_hat(x, params):
    r"""
    Rectangular (top-hat / boxcar) profile.

    .. math::
        f(x) = \begin{cases}
            A & \text{if } |x - \mu| \leq w / 2 \\
            0 & \text{otherwise}
        \end{cases}

    where :math:`A` is the ``amplitude`` parameter used linearly,
    :math:`\mu` is ``center`` and :math:`w` is ``width``. Both edges are
    inclusive.

    The profile is piecewise constant, so its gradient with respect to
    ``center`` and ``width`` is zero wherever it is defined. Only
    ``amplitude`` receives a non-zero gradient.

    Parameters
    ----------
    x : jnp.ndarray
        Input wavelength array.
    params : array-like
        Three values ``(amplitude, center, width)``:

        - ``amplitude``: height of the box (linear, not logarithmic).
        - ``center``: center of the box, in the units of ``x``.
        - ``width``: full width of the box, in the units of ``x``.

    Returns
    -------
    jnp.ndarray
        ``amplitude`` where ``x`` lies inside the box and ``0`` elsewhere.
    """
    amplitude, center, width = params
    half_width = width / 2.0
    inside = (x >= center - half_width) & (x <= center + half_width)
    # jnp.where lets the result follow amplitude's dtype instead of routing
    # through a hard-coded float32 mask (which downcasts under jax_enable_x64).
    return jnp.where(inside, amplitude, 0.0)
""" Experimental """
def eval_hermite(n: int, x: jnp.ndarray) -> jnp.ndarray:
    r"""
    Evaluate the physicists' Hermite polynomial :math:`H_n(x)` with JAX.

    Uses the three-term recurrence

    .. math::
        H_0(x) = 1, \qquad H_1(x) = 2x, \qquad
        H_n(x) = 2x \, H_{n-1}(x) - 2(n-1) \, H_{n-2}(x).

    The recurrence is seeded with :math:`H_{-1} = 0` and :math:`H_0 = 1`, so
    its first step yields :math:`H_1 = 2x`. A single ``lax.fori_loop``
    therefore covers every :math:`n \geq 0` without special-casing
    ``n = 0`` or ``n = 1``.

    Parameters
    ----------
    n : int
        Non-negative order of the polynomial. A traced integer (inside
        ``jit``) is also accepted, but it is not validated.
    x : jnp.ndarray
        Points at which the polynomial is evaluated.

    Returns
    -------
    jnp.ndarray
        Values of :math:`H_n(x)` with the same shape as ``x``.

    Raises
    ------
    ValueError
        If ``n`` is a Python ``int`` smaller than zero.
    """
    # Previously a negative order silently returned 2x; reject it when n is concrete.
    if isinstance(n, int) and n < 0:
        raise ValueError(f"Hermite order must be non-negative, got n={n}")

    x = jnp.asarray(x)

    def step(i, carry):
        # carry holds (H_{i-2}, H_{i-1}); return (H_{i-1}, H_i).
        h_prev, h_curr = carry
        return h_curr, 2 * x * h_curr - 2 * (i - 1) * h_prev

    _, h_n = lax.fori_loop(1, n + 1, step, (jnp.zeros_like(x), jnp.ones_like(x)))
    return h_n
#@jit # 3. Gauss-Hermite LOSVD #@jit
def gauss_hermite_losvd_jax(v, v0, sigma, h3=0.0, h4=0.0):
    r"""
    Line-of-sight velocity distribution (LOSVD) as a Gauss-Hermite expansion.

    Follows the van der Marel & Franx (1993) parameterization:

    .. math::
        \mathcal{L}(v) \propto \frac{1}{\sqrt{2\pi}\,\sigma}
            \exp\left(-\frac{y^2}{2}\right)
            \left[ 1 + h_3 H_3(y) + h_4 H_4(y) \right],
        \qquad y = \frac{v - v_0}{\sigma},

    with their normalised Hermite polynomials
    :math:`H_n(y) = H_n^{\mathrm{phys}}(y) / \sqrt{2^n n!}`, i.e.

    .. math::
        H_3(y) = \frac{2y^3 - 3y}{\sqrt{3}}, \qquad
        H_4(y) = \frac{4y^4 - 12y^2 + 3}{\sqrt{24}}.

    The result is divided by its trapezoidal integral over ``v`` (via
    ``trapz_jax``), so it integrates to 1 on the supplied grid. For large
    :math:`|h_3|` or :math:`|h_4|` the bracket, and hence the LOSVD, can
    become negative.

    Parameters
    ----------
    v : jnp.ndarray
        Velocity grid (e.g. km/s).
    v0 : float
        Mean velocity (center of the Gaussian core).
    sigma : float
        Dispersion of the Gaussian core, in the units of ``v``.
    h3 : float, optional
        Third Gauss-Hermite coefficient (asymmetric deviations).
    h4 : float, optional
        Fourth Gauss-Hermite coefficient (symmetric deviations).

    Returns
    -------
    jnp.ndarray
        Normalised LOSVD with the same shape as ``v``.
    """
    y = (v - v0) / sigma
    gaussian = jnp.exp(-0.5 * y**2) / (sigma * jnp.sqrt(2 * jnp.pi))
    # Normalise by sqrt(2^n n!) = sqrt(48) and sqrt(384). Dividing by sqrt(n!)
    # alone inflates H_3 by 2*sqrt(2) and H_4 by 4, which makes the fitted
    # h3/h4 incomparable with the literature convention.
    herm3 = eval_hermite(3, y) / jnp.sqrt(48.0)
    herm4 = eval_hermite(4, y) / jnp.sqrt(384.0)
    losvd = gaussian * (1 + h3 * herm3 + h4 * herm4)
    return losvd / trapz_jax(losvd, v)