Source code for sheap.Profiles.Combine

"""
Log-Lambda Profile Combination Utilities
========================================

This module defines utilities for **combining multiple emission-line components**
evaluated in **logarithmic wavelength space**, using velocity-based
parameterizations.

It provides:

- ``PROFILE_LINE_FUNC_MAP_loglambda`` :
  A registry mapping canonical profile names (e.g. ``"gaussian"``,
  ``"lorentzian"``, ``"skewed_gaussian"``) to their corresponding
  log-lambda profile functions.

- ``SPAF_loglambda`` :
  A SPAF (Sum Profiles Amplitude Free) constructor for log-lambda profiles,
  enabling physically motivated combinations of multiple emission lines
  with shared kinematic parameters.

The profiles referenced here operate internally in log(lambda) space via
the transformation:

.. math::

    v = c \\, \\ln(\\lambda / \\lambda_0)

ensuring exact symmetry in velocity space for Doppler-broadened features.

SPAF allows multiple lines to:
- Share kinematic parameters (velocity shift, FWHM, and shape parameters)
- Enforce fixed or semi-fixed amplitude ratios (e.g. doublets, multiplets)
- Be modeled with a reduced number of free parameters

Notes
-----
- Only profiles registered in ``PROFILE_LINE_FUNC_MAP_loglambda`` can be
  combined using ``SPAF_loglambda``.
- Base profiles must be decorated with ``@with_param_names`` and include
  at least ``"amplitude"`` and ``"lambda0"`` in their parameter list.
- Physical bounds and initial values for the combined parameters are
  handled by the constraint-building utilities elsewhere in *sheap*.

Examples
--------
.. code-block:: python

    from sheap.Profiles.combine import SPAF_loglambda

    # Hα + [NII] doublet with fixed 3:1 ratio
    centers = [6548.05, 6583.45]
    rules = [(0, 1.0, 0), (1, 3.0, 0)]

    G = SPAF_loglambda(
        centers=centers,
        amplitude_rules=rules,
        profile_name="gaussian",
    )

    # params = [amplitude0, vshift_kms, fwhm_v_kms]
    y = G(x_lambda, params)
"""

__author__ = "felavila"


from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from sheap.Core import ProfileFunc
import jax.numpy as jnp
from typing import List, Tuple, Callable
from sheap.Profiles.Utils import with_param_names,trapz_jax

from sheap.Profiles.profiles_lines_loglambda import (gaussian_fwhm_loglambda, lorentzian_fwhm_loglambda, skewed_gaussian_loglambda, top_hat_loglambda, voigt_pseudo_loglambda, emg_fwhm_loglambda,)

PROFILE_LINE_FUNC_MAP_loglambda: Dict[str, ProfileFunc] = {
    "gaussian": gaussian_fwhm_loglambda,
    "lorentzian": lorentzian_fwhm_loglambda,
    "skewed_gaussian": skewed_gaussian_loglambda,
    "top_hat": top_hat_loglambda,
    "voigt_pseudo": voigt_pseudo_loglambda,
    "emg": emg_fwhm_loglambda,
}


[docs] def SPAF_loglambda( centers: List[float], amplitude_rules: List[Tuple[int, float, int]], profile_name: str, ): """ SPAF (Sum Profiles Amplitude Free) wrapper for *log-lambda* line profiles. This builds a composite profile made of multiple lines that share the same *shape parameters* (e.g., ``vshift_kms``, ``fwhm_v_kms``, and any extra shape params like ``alpha``, ``eta``, or ``tau_kms``), while allowing a flexible set of **free amplitudes** combined through ``amplitude_rules``. Parameters ---------- centers : list[float] Per-line rest wavelengths :math:`\\lambda_0` (Å). These are required and injected as the last parameter of the base profile for each line. amplitude_rules : list[(line_idx, coefficient, free_amp_idx)] For each line: ``amp_line = coefficient * free_amplitudes[free_amp_idx]``. Example for a doublet with fixed 2:1 ratio sharing the same free amp 0:: [(0, 1.0, 0), (1, 0.5, 0)] profile_name : str Name of the base profile to use. It must exist in ``PROFILE_LINE_FUNC_MAP_loglambda`` and be decorated with ``@with_param_names``. The base profile must include at least these parameter names: ``"amplitude"`` and ``"lambda0"``. Any additional parameters are treated as *shared* across all lines. Returns ------- ProfileFunc A callable ``G(x_lambda, params)`` decorated with ``@with_param_names``. The parameter layout is: - ``[amplitude0, ..., amplitude{Nfree-1}, <shared_params...>]`` where ``<shared_params...>`` are all base parameters except ``amplitude`` and ``lambda0`` (in the same order as the base profile's ``param_names``). Notes ----- - This works for any log-lambda base profile with signature ``base_func(x_lambda, params)`` and ``param_names`` containing ``"amplitude"`` and ``"lambda0"``. - Shape parameters are shared across all lines; only amplitudes are combined via ``amplitude_rules``. """ centers = jnp.asarray(centers, dtype=jnp.float32) base_func = PROFILE_LINE_FUNC_MAP_loglambda.get(profile_name) if base_func is None: raise ValueError( f"Profile '{profile_name}' not found in PROFILE_LINE_FUNC_MAP_loglambda." ) base_param_names = getattr(base_func, "param_names", None) if not base_param_names: raise ValueError( f"Base profile '{profile_name}' must be decorated with @with_param_names " f"and expose 'param_names'." ) if "amplitude" not in base_param_names or "lambda0" not in base_param_names: raise ValueError( f"Base profile '{profile_name}' must include parameter names " f"'amplitude' and 'lambda0'. Got: {base_param_names}" ) # Shared params are everything except amplitude + lambda0, in base order shared_names = [n for n in base_param_names if n not in ("amplitude", "lambda0")] # Normalize/compact free amplitude indices raw_free = [r[2] for r in amplitude_rules] uniq = sorted({int(i) for i in raw_free}) idx_map = {orig: new for new, orig in enumerate(uniq)} rules = [(int(li), float(coef), idx_map[int(fi)]) for li, coef, fi in amplitude_rules] n_free = len(uniq) # Public param names for the composite param_names = [f"amplitude{k}" for k in range(n_free)] + shared_names @with_param_names(param_names) def G(x_lambda, params): x_dtype = x_lambda.dtype amps_linear = params[:n_free] # linear free amplitudes shared_vals = params[n_free:] # shared shape params in shared_names order total = jnp.array(0.0, dtype=x_dtype) # Build each line with correct base param ordering for line_idx, coef, free_idx in rules: amp_line = coef * amps_linear[free_idx] lambda0_i = centers[line_idx].astype(x_dtype) # map name->value for the base params pdict = {"amplitude": amp_line, "lambda0": lambda0_i} for name, val in zip(shared_names, shared_vals): pdict[name] = val p_line = jnp.array([pdict[name] for name in base_param_names], dtype=x_dtype) total = total + base_func(x_lambda, p_line) return total return G
[docs] def SPAF_loglambda_old( centers: List[float], amplitude_rules: List[Tuple[int, float, int]], profile_name: str, ): """ SPAF (Sum Profiles Amplitude Free) for log-lambda profiles. Parameters ---------- centers : list[float] Per-line rest wavelengths λ0 (Å). These are *required* and injected as the last parameter of the base profile. amplitude_rules : list[(line_idx, coefficient, free_amp_idx)] For each line: amp_line = coefficient * free_amplitudes[free_amp_idx]. Example for a doublet with fixed 2:1 ratio sharing the same free amp 0: [(0, 1.0, 0), (1, 0.5, 0)] base_func : Callable A profile with param_names == ["amp","vshift_kms","fwhm_v_kms","lambda0"]. Returns ------- ProfileFunc G(x, params) params layout: [ amplitude0, amplitude1, ..., amplitude_{Nfree-1}, shift_kms, # shared Δv for the whole group fwhm_v_kms ] # shared FWHM in km/s """ centers = jnp.asarray(centers, dtype=jnp.float32) base_func = PROFILE_LINE_FUNC_MAP_loglambda.get(profile_name) if base_func is None: raise ValueError(f"Profile '{profile_name}' not found in PROFILE_LINE_FUNC_MAP_loglambda.") # normalize/compact free amplitude indices raw_free = [r[2] for r in amplitude_rules] uniq = sorted({int(i) for i in raw_free}) idx_map = {orig: new for new, orig in enumerate(uniq)} rules = [(li, coef, idx_map[int(fi)]) for li, coef, fi in amplitude_rules] n_free = len(uniq) # Public param names (self-documenting) param_names = [f"amplitude{k}" for k in range(n_free)] + ["vshift_kms", "fwhm_v_kms"] @with_param_names(param_names) def G(x, params): amps_linear = params[:n_free] # linear amplitudes vshift = params[n_free + 0] # shared Δv [km/s] #fwhm_vkms = 10**params[n_free + 1] # shared FWHM_v [km/s] fwhm_vkms = params[n_free + 1] # stored as log10(FWHM [km/s]) #fwhm_vkms = jnp.maximum(jnp.power(10.0, log10_fwhm), jnp.finfo(x.dtype).tiny) total = 0.0 for line_idx, coef, free_idx in rules: amp_line = coef * amps_linear[free_idx] lambda0_i = centers[line_idx] # base expects [amp, vshift_kms, fwhm_v_kms, lambda0] p_line = jnp.array([amp_line, vshift, fwhm_vkms, lambda0_i], dtype=x.dtype) total += base_func(x, p_line) return total return G