"""
Continuum Profiles
==================
This module defines all continuum profile functions available in *sheap*.
Each function is JAX-compatible and decorated with ``@with_param_names``
to provide consistent parameter naming for fitting routines.
Profiles
--------
- ``linear`` : Linear continuum (slope + intercept).
- ``powerlaw`` : Standard power law anchored at λ₀=5500 Å.
- ``brokenpowerlaw`` : Two-slope power law with a break wavelength.
- ``logparabola`` : Log-parabolic shape with curvature term.
- ``exp_cutoff`` : Power law with exponential cutoff.
- ``polynomial`` : Polynomial expansion.
Constants
---------
- ``delta0`` : Reference wavelength (5500 Å) used for continuum scaling.
Notes
-----
- All functions take wavelength arrays in Ångström and return dimensionless
continuum templates scaled by their amplitude parameter.
- The reference wavelength ``delta0`` ensures consistent normalization
across continuum forms.
"""
__author__ = 'felavila'
__all__ = [
"brokenpowerlaw",
"delta0",
"exp_cutoff",
"linear",
"logparabola",
"polynomial",
"powerlaw",
]
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import jax
import jax.numpy as jnp
from jax import jit, vmap
from sheap.Profiles.Utils import with_param_names
"""
Note
--------
delta0 : Reference wavelength (5500 Å) used for continuum scaling.
"""
delta0 = 5500.0 #: Normalization wavelength in Ångström used for continuum models (λ/λ₀)
#TODO Check in the profiles with only one amplitude -> move all to logamp
[docs]
@with_param_names(["amplitude_slope", "amplitude_intercept"])
def linear(xs: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray:
r"""
Linear continuum profile.
.. math::
f(\lambda) = \text{intercept} + \text{slope} \cdot \left(\frac{\lambda}{\lambda_0}\right)
Parameters
----------
xs : jnp.ndarray
Wavelengths in Ångström.
params : array-like
- `params[0]`: Slope
- `params[1]`: Intercept
Returns
-------
jnp.ndarray
Evaluated flux.
"""
slope, intercept = params
x = xs / delta0
return intercept + slope * x
[docs]
@with_param_names(["logamp","alpha"])
def powerlaw(xs: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray:
r"""
Power-law continuum profile.
.. math::
f(\lambda) = A \cdot \left(\frac{\lambda}{\lambda_0}\right)^{\alpha}
Parameters
----------
xs : jnp.ndarray
Wavelengths in Ångström.
params : array-like
- `params[0]`: Slope :math:`\alpha`
- `params[1]`: Amplitude :math:`A`
Returns
-------
jnp.ndarray
Evaluated flux.
"""
A, alpha = params
x = xs / delta0
return 10**A * x ** alpha
[docs]
@with_param_names(["logamp", "alpha1", "alpha2", "x_break"])
def brokenpowerlaw(xs: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray:
r"""
Broken power-law continuum profile.
.. math::
f(\lambda) =
\begin{cases}
A \left(\dfrac{\lambda}{\lambda_0}\right)^{\alpha_1}, & \text{if } \lambda < x_{\text{break}} \\
A \, x_{\text{break}}^{\alpha_1 - \alpha_2} \left(\dfrac{\lambda}{\lambda_0}\right)^{\alpha_2}, & \text{otherwise}
\end{cases}
Parameters
----------
xs : jnp.ndarray
Wavelengths in Ångström.
params : array-like
- `params[0]`: Amplitude :math:`A`
- `params[1]`: Slope below break :math:`\alpha_1`
- `params[2]`: Slope above break :math:`\alpha_2`
- `params[3]`: Break wavelength :math:`x_{break}` in Ångström
Returns
-------
jnp.ndarray
Evaluated flux.
"""
A, alpha1, alpha2, xbr = params
x = xs / delta0
xbr = xbr / delta0
low = 10**A * x ** alpha1
high = 10**A * (xbr ** (alpha1 - alpha2)) * x ** alpha2
return jnp.where(x < xbr, low, high)
[docs]
@with_param_names(["amplitude", "alpha", "beta"])
def logparabola(xs: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray:
r"""
Log-parabolic continuum profile.
.. math::
f(\lambda) = A \cdot \left(\frac{\lambda}{\lambda_0}\right)^{-\alpha - \beta \cdot \log(\lambda / \lambda_0)}
Parameters
----------
xs : jnp.ndarray
Wavelengths in Ångström.
params : array-like
- `params[0]`: Amplitude :math:`A`
- `params[1]`: Spectral index :math:`\alpha`
- `params[2]`: Curvature parameter :math:`\beta`
Returns
-------
jnp.ndarray
Evaluated flux.
"""
A, alpha, beta = params
x = xs / delta0
return A * x ** (-alpha - beta * jnp.log(x))
[docs]
@with_param_names(["amplitude", "alpha", "x_cut"])
def exp_cutoff(xs: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray:
r"""
Power-law with exponential cutoff.
.. math::
f(\lambda) = A \cdot \left(\frac{\lambda}{\lambda_0}\right)^{-\alpha} \cdot \exp\left(-\frac{\lambda}{x_{cut}}\right)
Parameters
----------
xs : jnp.ndarray
Wavelengths in Ångström.
params : array-like
- `params[0]`: Amplitude :math:`A`
- `params[1]`: Slope :math:`\alpha`
- `params[2]`: Cutoff wavelength :math:`x_{cut}` in Ångström
Returns
-------
jnp.ndarray
Evaluated flux.
"""
A, alpha, xcut = params
x = xs / delta0
return A * x ** (-alpha) * jnp.exp(-xs / xcut)
########################################################################################################################
def make_linear_function(delta0: float = 5500.0):
@with_param_names(["amplitude_slope", "amplitude_intercept"])
def linear(xs: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray:
r"""
Linear continuum profile.
.. math::
f(\lambda) = \text{intercept} + \text{slope} \cdot \left(\frac{\lambda}{\lambda_0}\right)
Parameters
----------
xs : jnp.ndarray
Wavelengths in Ångström.
params : array-like
- `params[0]`: Slope
- `params[1]`: Intercept
Returns
-------
jnp.ndarray
Evaluated flux.
"""
slope, intercept = params
x = xs / delta0
return intercept + slope * x
return linear
def make_powerlaw_function(delta0: float = 5500.0):
@with_param_names(["logamp","alpha"])
def powerlaw(xs: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray:
r"""
Power-law continuum profile.
.. math::
f(\lambda) = A \cdot \left(\frac{\lambda}{\lambda_0}\right)^{\alpha}
Parameters
----------
xs : jnp.ndarray
Wavelengths in Ångström.
params : array-like
- `params[0]`: Slope :math:`\alpha`
- `params[1]`: Amplitude :math:`A`
Returns
-------
jnp.ndarray
Evaluated flux.
"""
A, alpha = params
x = xs / delta0
return 10**A * x ** alpha
return powerlaw
def make_brokenpowerlaw_function(delta0: float = 5500.0):
@with_param_names(["logamp", "alpha1", "alpha2", "x_break"])
def brokenpowerlaw(xs: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray:
r"""
Broken power-law continuum profile.
.. math::
f(\lambda) =
\begin{cases}
A \left(\dfrac{\lambda}{\lambda_0}\right)^{\alpha_1}, & \text{if } \lambda < x_{\text{break}} \\
A \, x_{\text{break}}^{\alpha_1 - \alpha_2} \left(\dfrac{\lambda}{\lambda_0}\right)^{\alpha_2}, & \text{otherwise}
\end{cases}
Parameters
----------
xs : jnp.ndarray
Wavelengths in Ångström.
params : array-like
- `params[0]`: Amplitude :math:`A`
- `params[1]`: Slope below break :math:`\alpha_1`
- `params[2]`: Slope above break :math:`\alpha_2`
- `params[3]`: Break wavelength :math:`x_{break}` in Ångström
Returns
-------
jnp.ndarray
Evaluated flux.
"""
A, alpha1, alpha2, xbr = params
x = xs / delta0
xbr = xbr / delta0
low = 10**A * x ** alpha1
high = 10**A * (xbr ** (alpha1 - alpha2)) * x ** alpha2
return jnp.where(x < xbr, low, high)
return brokenpowerlaw
def make_exp_cutoff_function(delta0: float = 5500.0):
@with_param_names(["amplitude", "alpha", "x_cut"])
def exp_cutoff(xs: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray:
r"""
Power-law with exponential cutoff.
.. math::
f(\lambda) = A \cdot \left(\frac{\lambda}{\lambda_0}\right)^{-\alpha} \cdot \exp\left(-\frac{\lambda}{x_{cut}}\right)
Parameters
----------
xs : jnp.ndarray
Wavelengths in Ångström.
params : array-like
- `params[0]`: Amplitude :math:`A`
- `params[1]`: Slope :math:`\alpha`
- `params[2]`: Cutoff wavelength :math:`x_{cut}` in Ångström
Returns
-------
jnp.ndarray
Evaluated flux.
"""
A, alpha, xcut = params
x = xs / delta0
return A * x ** (-alpha) * jnp.exp(-xs / xcut)
return exp_cutoff
def make_polynomial_function(degree: int, delta0: float = 5500.0):
if degree < 0:
raise ValueError("degree must be >= 0")
param_names = ["logamp"] + [f"c{i}" for i in range(1, degree + 1)]
@with_param_names(param_names, profile_name="polynomial")
def polynomial(xs: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray:
A = params[0]
coeffs = params[1:]
x = (xs - delta0) / delta0
corr = 0.0
for i, c in enumerate(coeffs, start=1):
corr = corr + c * x**i
return 10**A * jnp.exp(corr)
return polynomial
def make_logparabola_function(delta0: float = 5500.0):
@with_param_names(["amplitude", "alpha", "beta"])
def logparabola(xs: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray:
r"""
Log-parabolic continuum profile.
.. math::
f(\lambda) = A \cdot \left(\frac{\lambda}{\lambda_0}\right)^{-\alpha - \beta \cdot \log(\lambda / \lambda_0)}
Parameters
----------
xs : jnp.ndarray
Wavelengths in Ångström.
params : array-like
- `params[0]`: Amplitude :math:`A`
- `params[1]`: Spectral index :math:`\alpha`
- `params[2]`: Curvature parameter :math:`\beta`
Returns
-------
jnp.ndarray
Evaluated flux.
"""
A, alpha, beta = params
x = xs / delta0
return A * x ** (-alpha - beta * jnp.log(x))
return logparabola