# Source code for sheap.Profiles.profiles_lines_loglambda

"""
Log-Lambda Emission Line Profiles
=================================

This module defines a collection of **emission-line profile functions**
evaluated in **logarithmic wavelength space**, ensuring **velocity symmetry**
for Doppler-broadened spectral features.

All profiles accept a **linear wavelength grid** as input but internally
operate in log(lambda) space via the transformation:

.. math::

    v = c \\, \\ln(\\lambda / \\lambda_0)

where :math:`\\lambda_0` is the rest-frame wavelength of the line and
:math:`c` is the speed of light.

The profiles defined here are intended for use in *sheap* spectral fitting
pipelines where velocity-based parameterizations (FWHM, centroid shifts)
are preferred over linear wavelength widths.

Available profiles
------------------
- ``gaussian_fwhm_loglambda`` :
  Velocity-symmetric Gaussian profile.
- ``lorentzian_fwhm_loglambda`` :
  Lorentzian (Cauchy) profile in velocity space.
- ``skewed_gaussian_loglambda`` :
  Skew-normal Gaussian profile allowing asymmetric line shapes.
- ``top_hat_loglambda`` :
  Boxcar (top-hat) profile in velocity space.
- ``voigt_pseudo_loglambda`` :
  Pseudo-Voigt profile (Gaussian–Lorentzian mixture).
- ``emg_fwhm_loglambda`` :
  Exponentially Modified Gaussian (EMG) profile.

All profiles share a common parameterization:

- ``amplitude`` : Linear amplitude of the line.
- ``vshift_kms`` : Velocity shift of the centroid [km/s].
- ``fwhm_v_kms`` : log10 of the velocity FWHM [km/s].
- ``lambda0`` : Rest-frame wavelength of the line (fixed).

Some profiles include additional shape parameters (e.g. skewness,
Lorentzian mixing fraction, or exponential decay scale).

Notes
-----
- Using log(lambda) ensures exact symmetry in velocity space and avoids
  wavelength-dependent distortions present in linear-lambda profiles.
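- These profiles are written with JAX operations and are differentiable,
  enabling GPU-accelerated optimization and sampling. The top-hat profile
  is piecewise constant, so its gradients with respect to the velocity
  shift and width are zero almost everywhere.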
- Physical bounds and initial values for these profiles are handled by
  the corresponding constraint-building utilities in *sheap*.

Examples
--------
.. code-block:: python

    import jax.numpy as jnp
    from sheap.Profiles.profiles_lines_loglambda import gaussian_fwhm_loglambda

    x_lambda = jnp.linspace(6500.0, 6600.0, 2048)
    params = [1.0, 0.0, 3.0, 6563.0]  # amp, vshift, log10(FWHM), lambda0

    y = gaussian_fwhm_loglambda(x_lambda, params)
"""

__author__ = "felavila"



from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import jax.numpy as jnp
from jax import jit, vmap,lax 
from jax.scipy.special import erfc
from jax.scipy.stats import norm #maybe dosent exist xd

from sheap.Profiles.Utils import with_param_names,trapz_jax
from sheap.Utils.Constants import DEFAULT_C_KMS,FWHM_TO_SIGMA



__all__ = ["gaussian_fwhm_loglambda",
    "lorentzian_fwhm_loglambda",
    "skewed_gaussian_loglambda",
    "top_hat_loglambda",
    "voigt_pseudo_loglambda",
    "emg_fwhm_loglambda",]

@with_param_names(["amplitude", "vshift_kms", "fwhm_v_kms", "lambda0"])
def gaussian_fwhm_loglambda(x_lambda, params):
    """
    Velocity-symmetric Gaussian evaluated in log(lambda) space.

    The wavelength grid is mapped to velocity with
    ``v = DEFAULT_C_KMS * ln(x_lambda / lambda0)`` and the function returns
    ``amplitude * exp(-0.5 * ((v - vshift_kms) / sigma_v) ** 2)`` where
    ``sigma_v = 10 ** fwhm_v_kms * FWHM_TO_SIGMA``.

    Parameters
    ----------
    x_lambda : array-like
        Linear wavelength grid.
    params : sequence
        ``[amplitude, vshift_kms, fwhm_v_kms, lambda0]``:

        - amplitude : Linear amplitude (value of the profile at its peak).
        - vshift_kms : Centroid velocity shift [km/s].
        - fwhm_v_kms : log10 of the velocity FWHM [km/s].
        - lambda0 : Rest wavelength of the line, in the units of ``x_lambda``.

    Returns
    -------
    jnp.ndarray
        Profile evaluated at every point of ``x_lambda``.
    """
    amp, vshift_kms, fwhm_v_kms, lambda0 = params

    # True division always yields a floating array, so taking ``finfo`` of
    # the ratio (rather than of the input grid) also works for integer grids.
    ratio = jnp.asarray(x_lambda) / lambda0
    # Floor non-positive ratios so the logarithm stays finite. The original
    # upper clip bound was +inf, so a plain maximum is equivalent and avoids
    # the deprecated ``a_min``/``a_max`` keywords of ``jnp.clip``.
    ratio = jnp.maximum(ratio, jnp.finfo(ratio.dtype).tiny)
    velocity = DEFAULT_C_KMS * jnp.log(ratio)

    # The width is sampled in log10 space; convert FWHM -> Gaussian sigma.
    sigma_v = (10.0 ** fwhm_v_kms) * FWHM_TO_SIGMA
    z = (velocity - vshift_kms) / sigma_v
    return amp * jnp.exp(-0.5 * z * z)
@with_param_names(["amplitude", "vshift_kms", "fwhm_v_kms", "lambda0"])
def lorentzian_fwhm_loglambda(x_lambda, params):
    """
    Lorentzian (Cauchy) profile evaluated in log(lambda) velocity space.

    The wavelength grid is mapped to velocity with
    ``v = DEFAULT_C_KMS * ln(x_lambda / lambda0)`` and the function returns
    ``amplitude / (1 + ((v - vshift_kms) / gamma) ** 2)`` where ``gamma`` is
    half of the linear FWHM, ``gamma = 0.5 * 10 ** fwhm_v_kms``.

    Parameters
    ----------
    x_lambda : array-like
        Linear wavelength grid.
    params : sequence
        ``[amplitude, vshift_kms, fwhm_v_kms, lambda0]``:

        - amplitude : Linear amplitude (value of the profile at its peak).
        - vshift_kms : Velocity shift of the centroid [km/s].
        - fwhm_v_kms : log10 of the velocity FWHM [km/s].
        - lambda0 : Rest-frame wavelength of the line, in the units of
          ``x_lambda``.

    Returns
    -------
    jnp.ndarray
        Profile evaluated at every point of ``x_lambda``.
    """
    amp, vshift_kms, fwhm_v_kms, lambda0 = params

    # True division always yields a floating array, so taking ``finfo`` of
    # the ratio (rather than of the input grid) also works for integer grids.
    ratio = jnp.asarray(x_lambda) / lambda0
    # Floor non-positive ratios so the logarithm stays finite. The original
    # upper clip bound was +inf, so a plain maximum is equivalent and avoids
    # the deprecated ``a_min``/``a_max`` keywords of ``jnp.clip``.
    ratio = jnp.maximum(ratio, jnp.finfo(ratio.dtype).tiny)
    velocity = DEFAULT_C_KMS * jnp.log(ratio)

    # Half width at half maximum of the Lorentzian.
    gamma = 0.5 * (10.0 ** fwhm_v_kms)
    z = (velocity - vshift_kms) / gamma
    return amp / (1.0 + z * z)
@with_param_names(["amplitude", "vshift_kms", "fwhm_v_kms", "alpha", "lambda0"])
def skewed_gaussian_loglambda(x_lambda, params):
    """
    Skewed Gaussian (skew-normal) profile evaluated in log(lambda) space.

    With ``t = (v - vshift_kms) / sigma`` and
    ``v = DEFAULT_C_KMS * ln(x_lambda / lambda0)``, the function returns
    ``amplitude * 2 * phi(t) * Phi(alpha * t)`` where ``phi`` is the unit
    normal density (including its ``1 / sqrt(2 * pi)`` factor) and ``Phi``
    the unit normal cumulative distribution.

    Parameters
    ----------
    x_lambda : array-like
        Linear wavelength grid.
    params : sequence
        ``[amplitude, vshift_kms, fwhm_v_kms, alpha, lambda0]``:

        - amplitude : Linear scale factor. Because ``phi`` keeps its
          ``1 / sqrt(2 * pi)`` normalisation, the peak of the returned
          profile is not equal to ``amplitude``.
        - vshift_kms : Velocity shift of the location parameter [km/s]
          (equals the centroid only when ``alpha == 0``).
        - fwhm_v_kms : log10 of the velocity FWHM [km/s] of the underlying
          (unskewed) Gaussian.
        - alpha : Skewness parameter (positive = red skew).
        - lambda0 : Rest-frame wavelength of the line, in the units of
          ``x_lambda``.

    Returns
    -------
    jnp.ndarray
        Profile evaluated at every point of ``x_lambda``.
    """
    amp, vshift_kms, fwhm_v_kms, alpha, lambda0 = params

    # True division always yields a floating array, so taking ``finfo`` of
    # the ratio (rather than of the input grid) also works for integer grids.
    ratio = jnp.asarray(x_lambda) / lambda0
    # Floor non-positive ratios so the logarithm stays finite. The original
    # upper clip bound was +inf, so a plain maximum is equivalent and avoids
    # the deprecated ``a_min``/``a_max`` keywords of ``jnp.clip``.
    ratio = jnp.maximum(ratio, jnp.finfo(ratio.dtype).tiny)
    velocity = DEFAULT_C_KMS * jnp.log(ratio)

    sigma = (10.0 ** fwhm_v_kms) * FWHM_TO_SIGMA
    t = (velocity - vshift_kms) / sigma

    phi = jnp.exp(-0.5 * t * t) / jnp.sqrt(2.0 * jnp.pi)
    # ``jax.numpy`` has no ``erf``; use the already-imported ``erfc`` via the
    # identity 0.5 * (1 + erf(x)) == 0.5 * erfc(-x).
    Phi = 0.5 * erfc(-alpha * t / jnp.sqrt(2.0))
    return amp * 2.0 * phi * Phi
@with_param_names(["amplitude", "vshift_kms", "fwhm_v_kms", "lambda0"])
def top_hat_loglambda(x_lambda, params):
    """
    Top-hat (boxcar) profile defined in log(lambda) velocity space.

    The wavelength grid is mapped to velocity with
    ``v = DEFAULT_C_KMS * ln(x_lambda / lambda0)``; the function returns
    ``amplitude`` wherever ``|v - vshift_kms| <= 0.5 * 10 ** fwhm_v_kms``
    and zero elsewhere.

    Parameters
    ----------
    x_lambda : array-like
        Linear wavelength grid.
    params : sequence
        ``[amplitude, vshift_kms, fwhm_v_kms, lambda0]``:

        - amplitude : Linear amplitude (height of the box).
        - vshift_kms : Velocity shift of the box centre [km/s].
        - fwhm_v_kms : log10 of the full velocity width of the box [km/s].
        - lambda0 : Rest-frame wavelength of the line, in the units of
          ``x_lambda``.

    Returns
    -------
    jnp.ndarray
        Profile evaluated at every point of ``x_lambda``.

    Notes
    -----
    The output is piecewise constant in ``vshift_kms`` and ``fwhm_v_kms``,
    so gradients with respect to those parameters are zero almost
    everywhere.
    """
    amp, vshift_kms, fwhm_v_kms, lambda0 = params

    # True division always yields a floating array, so taking ``finfo`` of
    # the ratio (rather than of the input grid) also works for integer grids.
    ratio = jnp.asarray(x_lambda) / lambda0
    # Floor non-positive ratios so the logarithm stays finite. The original
    # upper clip bound was +inf, so a plain maximum is equivalent and avoids
    # the deprecated ``a_min``/``a_max`` keywords of ``jnp.clip``.
    ratio = jnp.maximum(ratio, jnp.finfo(ratio.dtype).tiny)
    velocity = DEFAULT_C_KMS * jnp.log(ratio)

    half_width = 0.5 * (10.0 ** fwhm_v_kms)
    inside = jnp.abs(velocity - vshift_kms) <= half_width
    # Multiplying by the boolean mask promotes it to 0/1 in amp's dtype.
    return amp * inside
@with_param_names(["amplitude", "vshift_kms", "fwhm_v_kms", "eta", "lambda0"])
def voigt_pseudo_loglambda(x_lambda, params):
    """
    Pseudo-Voigt profile (Gaussian + Lorentzian mixture) evaluated in
    log(lambda) velocity space.

    Both components share the same centre and the same FWHM and are each
    normalised to a peak of one, so the function returns
    ``amplitude * (eta * L + (1 - eta) * G)`` with
    ``G = exp(-0.5 * (dv / sigma) ** 2)`` and
    ``L = 1 / (1 + (dv / gamma) ** 2)``, where
    ``dv = DEFAULT_C_KMS * ln(x_lambda / lambda0) - vshift_kms``,
    ``sigma = 10 ** fwhm_v_kms * FWHM_TO_SIGMA`` and
    ``gamma = 0.5 * 10 ** fwhm_v_kms``.

    Parameters
    ----------
    x_lambda : array-like
        Linear wavelength grid.
    params : sequence
        ``[amplitude, vshift_kms, fwhm_v_kms, eta, lambda0]``:

        - amplitude : Linear amplitude (value of the profile at its peak).
        - vshift_kms : Velocity shift of the centroid [km/s].
        - fwhm_v_kms : log10 of the velocity FWHM [km/s].
        - eta : Mixing fraction (eta=1 Lorentzian, eta=0 Gaussian). The
          value is used as given; it is not clipped to [0, 1] here.
        - lambda0 : Rest-frame wavelength of the line, in the units of
          ``x_lambda``.

    Returns
    -------
    jnp.ndarray
        Profile evaluated at every point of ``x_lambda``.
    """
    amp, vshift_kms, fwhm_v_kms, eta, lambda0 = params

    # True division always yields a floating array, so taking ``finfo`` of
    # the ratio (rather than of the input grid) also works for integer grids.
    ratio = jnp.asarray(x_lambda) / lambda0
    # Floor non-positive ratios so the logarithm stays finite. The original
    # upper clip bound was +inf, so a plain maximum is equivalent and avoids
    # the deprecated ``a_min``/``a_max`` keywords of ``jnp.clip``.
    ratio = jnp.maximum(ratio, jnp.finfo(ratio.dtype).tiny)
    dv = DEFAULT_C_KMS * jnp.log(ratio) - vshift_kms

    fwhm = 10.0 ** fwhm_v_kms
    sigma = fwhm * FWHM_TO_SIGMA  # Gaussian standard deviation
    gamma = 0.5 * fwhm            # Lorentzian half width at half maximum

    tg = dv / sigma
    tl = dv / gamma
    gauss = jnp.exp(-0.5 * tg * tg)
    lorentz = 1.0 / (1.0 + tl * tl)
    return amp * (eta * lorentz + (1.0 - eta) * gauss)
@with_param_names(["amplitude", "vshift_kms", "fwhm_v_kms", "tau_kms", "lambda0"])
def emg_fwhm_loglambda(x_lambda, params):
    """
    Exponentially Modified Gaussian (EMG) profile evaluated in log(lambda)
    velocity space.

    With ``dv = DEFAULT_C_KMS * ln(x_lambda / lambda0) - vshift_kms`` and
    ``sigma = 10 ** fwhm_v_kms * FWHM_TO_SIGMA`` the function returns::

        amplitude / (2 * tau)
            * exp(sigma**2 / (2 * tau**2) - dv / tau)
            * erfc((sigma / tau - dv / sigma) / sqrt(2))

    Parameters
    ----------
    x_lambda : array-like
        Linear wavelength grid.
    params : sequence
        ``[amplitude, vshift_kms, fwhm_v_kms, tau_kms, lambda0]``:

        - amplitude : Linear scale factor. The expression above carries a
          ``1 / (2 * tau)`` factor, so the peak of the returned profile is
          not equal to ``amplitude``.
        - vshift_kms : Velocity shift of the Gaussian component [km/s].
        - fwhm_v_kms : log10 of the velocity FWHM [km/s] of the Gaussian
          component (before exponential modification).
        - tau_kms : Exponential decay scale [km/s]; floored at the smallest
          positive normal float to avoid division by zero.
        - lambda0 : Rest-frame wavelength of the line, in the units of
          ``x_lambda``.

    Returns
    -------
    jnp.ndarray
        Profile evaluated at every point of ``x_lambda``.

    Notes
    -----
    The product ``exp(...) * erfc(...)`` is evaluated in log space. When
    ``tau`` is small compared with ``sigma`` the exponential overflows while
    ``erfc`` underflows, and the direct product yields ``inf * 0 = NaN``;
    summing the exponents before exponentiating keeps the result finite.
    """
    amp, vshift_kms, fwhm_v_kms, tau_kms, lambda0 = params

    # True division always yields a floating array, so taking ``finfo`` of
    # the ratio (rather than of the input grid) also works for integer grids.
    ratio = jnp.asarray(x_lambda) / lambda0
    tiny = jnp.finfo(ratio.dtype).tiny
    # Floor non-positive ratios so the logarithm stays finite. The original
    # upper clip bound was +inf, so a plain maximum is equivalent and avoids
    # the deprecated ``a_min``/``a_max`` keywords of ``jnp.clip``.
    ratio = jnp.maximum(ratio, tiny)
    dv = DEFAULT_C_KMS * jnp.log(ratio) - vshift_kms

    sigma = (10.0 ** fwhm_v_kms) * FWHM_TO_SIGMA
    tau = jnp.maximum(tau_kms, tiny)

    # 0.5 * erfc(u / sqrt(2)) is the standard normal CDF evaluated at -u, so
    # its logarithm is available in a numerically stable form.
    u = sigma / tau - dv / sigma
    log_core = (sigma * sigma) / (2.0 * tau * tau) - dv / tau
    return amp / tau * jnp.exp(log_core + norm.logcdf(-u))