# Source code for sheap.Plotting.ploting

"""Plotting helpers for sheap: annotated line regions, single spectra and fit summaries."""
__author__ = 'felavila'

# Public names exported by ``from <this module> import *``.
__all__ = [
    "SheapPlot_old",
    "plot_a_spectra",
    "plot_region",
]

# import seaborn as sns
import jax.numpy as jnp
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

#from sheap.Minimizer.functions import linear_combination


# class Sheap_pca_ploting:
#     import numpy as np

#     def __init__(
#         self, test_clase, masked_uncertainties, fit_array, eigenvectors, params_linear
#     ):
#         self.test_clase = test_clase
#         self.masked_uncertainties = masked_uncertainties
#         self.fit_array = fit_array
#         self.eigenvectors = eigenvectors
#         self.params_linear = params_linear
#         self.combination = self.eigenvectors.T * 100 * self.params_linear.T
#         self.negatives_per_column = jnp.nansum(self.combination < 0, axis=0).T

#     def plot(self, n, save="", **kwargs):
#         # save = False
#         # for n in range(len(dr_filtered)):

#         # Create subplots with shared x-axis
#         # if save and os.path.isfile(f"images/images_pca/{n}.jpg"):
#         # continue
#         fig, (ax1, ax2) = plt.subplots(
#             2, 1, sharex=True, figsize=(35, 15), gridspec_kw={'height_ratios': [2, 1]}
#         )

#         # Set axis labels and their font sizes
#         ax1.set_ylabel("Flux", fontsize=20)
#         # ax1.set_xlabel("Wavelength", fontsize=40)  # Even though ax1 and ax2 share x-axis, you can label ax1
#         ax2.set_xlabel("wavelength A", fontsize=20)
#         ax2.set_ylabel("Normalized Residuals", fontsize=20)

#         # Define the x-axis based on the length of the spectrum
#         x_axis_pix = np.arange(len(self.test_clase.spectra[n, 0, :]))
#         x_limit_pix = x_axis_pix[self.masked_uncertainties[n] != 1e11][[0, -1]]
#         n_pixels = x_axis_pix.shape[0]
#         # Create an array of pixel indices
#         indices = jnp.arange(n_pixels)

#         # Create a boolean mask for indices outside the desired range
#         mask = (indices < x_limit_pix[0]) | (indices > x_limit_pix[1])
#         x_axis = self.test_clase.spectra[n, 0, :]
#         x_limit = [np.nanmin(x_axis), np.nanmax(x_axis)]
#         x_limit = [x_axis[x_limit_pix[0]], x_axis[x_limit_pix[-1]]]
#         obj = self.fit_array[:, 1, :][n]
#         linear_model = linear_combination(self.eigenvectors[n], self.params_linear[n])

#         residual = (self.fit_array[:, 1, :][n] - linear_model) / self.fit_array[:, 2, :][n]
#         model_qso = jnp.nansum(
#             self.eigenvectors[n][10:].T * 100 * self.params_linear[n][10:], axis=1
#         )
#         model_galaxy = jnp.nansum(
#             self.eigenvectors[n][:10].T * 100 * self.params_linear[n][:10], axis=1
#         )
#         linear_model = linear_model.at[mask].set(jnp.nan)
#         model_galaxy = model_galaxy.at[mask].set(jnp.nan)
#         model_qso = model_qso.at[mask].set(jnp.nan)
#         residual = residual.at[mask].set(jnp.nan)
#         maxs = [
#             np.nanmax(obj),
#             np.nanmax(linear_model),
#             np.nanmax(model_qso),
#             np.nanmax(model_galaxy),
#         ]
#         minx = [
#             np.nanmin(obj),
#             np.nanmin(linear_model),
#             np.nanmin(model_qso),
#             np.nanmin(model_galaxy),
#         ]
#         # Compute the model using your linear combination function

#         # Plot the observed object spectrum
#         ax1.plot(x_axis, obj, alpha=1, label=f"object {n}", color='grey')
#         # Plot the model spectrum
#         ax1.plot(x_axis, linear_model, label="model", color='r')
#         # Plot the PCA components
#         ax1.plot(x_axis, model_qso, label="pca_qso")
#         ax1.plot(x_axis, model_galaxy, label="pca_galaxy", color="g")

#         ax1.fill_between(
#             x_axis,
#             0,
#             max(maxs),
#             where=self.masked_uncertainties[n] != 1e11,
#             color="grey",
#             alpha=0.1,
#             zorder=10,
#             label="eigenvalues coverage",
#         )
#         # ax2.fill_between(x_axis, -0.5, 0.5,where=masked_uncertainties[n] != 1e11, color="grey", alpha=0.5,zorder=1, label="eigenvalues coverage")
#         ax1.axhline(0, ls="--", linewidth=5, c="k")
#         # Set the x-axis limits based on the non-masked region
#         ax1.set_xlim(x_limit)
#         ax2.set_xlim(x_limit)
#         ax1.set_ylim(min(minx), max(maxs))
#         # Place the legend for ax1 outside the plot area
#         ax1.legend(bbox_to_anchor=(1.01, 1), loc='upper left', fontsize=30)

#         # For ax2, plot the normalized residuals (observed - model)/error
#         ax2.scatter(x_axis, residual, alpha=0.5, zorder=10)
#         ax2.axhline(0, ls="--", linewidth=5, c="k")

#         ax2.text(
#             0.05,
#             0.95,
#             rf'{(sum(jnp.where(abs(residual)<=0.4,True,False))/sum(~jnp.isnan(residual)))*100:.3f} % between abs(0.4)',
#             transform=ax2.transAxes,
#             fontsize=30,
#             verticalalignment='top',
#         )
#         ax2.set_ylim(-0.4, 0.4)
#         plt.tight_layout(rect=[0, 0, 0.85, 1])
#         # Save or show figure
#         if save:
#             plt.savefig(f"images/{save}.jpg", dpi=300, bbox_inches='tight')
#             plt.close()
#         else:
#             plt.show()

#     def plot_valeu(self, n):
#         plt.plot(self.params_linear[n])
#         plt.axhline(0)
#         plt.axvline(10)
#         plt.axvspan(0, 10, alpha=0.2, color="r", label="galaxy linear paramters")
#         plt.axvspan(10, 60, alpha=0.2, label="qso linear paramters")
#         plt.xlim(0, 59)
#         plt.ylabel("parameter valeu")
#         plt.xlabel("parameter number")
#         plt.legend()
#         plt.show()

#     def plot_n_negatives(self, n):
#         # combination = self.eigenvectors[n].T*100*self.params_linear[n]
#         # negatives_per_column = jnp.nansum(combination < 0, axis=0)
#         plt.plot(self.negatives_per_column[n])
#         plt.axhline(0, ls="--", alpha=0.5)
#         plt.axvline(10)
#         plt.axvspan(0, 10, alpha=0.2, color="r", label="galaxy linear paramters")
#         plt.axvspan(10, 60, alpha=0.2, label="qso linear paramters")
#         plt.xlim(0, 59)
#         plt.ylabel("number of negatives by parameter x eigvector")
#         plt.xlabel("parameter number")
#         plt.legend()
#         plt.show()

#     def sep_componentes(self, n):
#         # combination = eigenvectors[0].T * 100 * params_linear[0]
#         fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(20, 10))
#         combination = self.combination[:, :, n]
#         for i, spec in enumerate(combination.T):
#             if i < 10:
#                 # Plot on the left subplot
#                 ax1.plot(spec, c="r", alpha=0.2)
#             else:
#                 # Plot on the right subplot
#                 ax2.plot(spec, c="b", alpha=0.1)

#         # (Optional) adjust x-limits or other axes properties if you wish:
#         ax1.set_xlim(left=0, right=combination.shape[0])  # Or omit for auto-limits
#         ax2.set_xlim(left=0, right=combination.shape[0])  # Or omit for auto-limits

#         ax1.set_title("galaxy(10)")
#         ax2.set_title("qso(50)")

#         plt.tight_layout()
#         plt.show()

#     # if save:
#     #     plt.savefig(f"images/images_pca/{n}.jpg", dpi=300, bbox_inches='tight')
#     #     plt.close()
#     # else:
#     #     plt.show()
#     # brea


def plot_region(x, function, region, save=''):
    """Plot a curve and annotate it with the lines of a spectral region.

    Parameters
    ----------
    x : sequence
        Abscissa values; the first and last entries set the x-limits.
    function : sequence
        Ordinate values drawn against ``x``.
    region : iterable
        Entries exposing ``.get`` (e.g. dicts); the keys read here are
        ``"center"``, ``"kind"`` and ``"line_name"``.
    save : str, optional
        When non-empty, the figure is written to ``f"{save}.jpg"`` before
        it is shown.
    """
    fig, ax1 = plt.subplots(1, 1, figsize=(35, 15))
    ax1.plot(x, function)
    min_y, max_y = ax1.get_ylim()

    # line_name -> vertical offset (in label steps) to use for its next label.
    next_offset = {}

    for entry in region:
        center = entry.get("center")
        kind = entry.get("kind")
        line_name = entry.get("line_name")

        # Dashed marker at the line centre.
        ax1.axvline(center, linestyle="--", color="red", linewidth=2, alpha=0.5)

        # Repeated names are stacked: each repeat moves two steps further down.
        offset = next_offset.get(line_name, 0)
        next_offset[line_name] = offset + 2

        # Names containing "h" start lower (max_y / 6 instead of max_y).
        n = 6 if "h" in line_name else 1

        # Each offset step lowers the label by 5% of max_y.
        text_y = max_y / n - (offset * 0.05 * max_y)
        ax1.text(
            center,
            text_y,
            f" {line_name}\n {kind}",
            fontsize=16,
            rotation=0,
            verticalalignment="top",
            color="k",
            zorder=10,
            horizontalalignment="left",
        )

    ax1.set_xlim(x[0], x[-1])
    ax1.set_ylim(0.0, max_y + max_y / 10)

    # Axis labels with larger fonts.
    ax1.set_xlabel("rest-wavelength", fontsize=20)
    ax1.set_ylabel("Flux", fontsize=20)

    if save:
        plt.savefig(f"{save}.jpg")
    plt.show()
class SheapPlot_old:
    """Legacy plotting front-end for sheap fit results.

    The instance simply stores the fit products handed to it; the only
    structure assumed is what :meth:`plot_combined` indexes:

    * ``test_clase.spectra[n, 0|1|2, :]`` -> x-axis, flux, flux error of object ``n``.
    * ``fit_region_g[n][0|1]`` -> x and y of the fitted region.
    * ``mask_fit[n]``, ``mask_fit_g[n]`` -> boolean masks over the x-axis.
    * ``masked_uncertainties_g[n]`` -> uncertainties of the fitted region.
    * ``Master_Gaussian.func(x, params_g[n])`` -> model evaluated on ``x``.
    * ``Baselines[n]`` -> baseline added to the model on request.
    * ``AN[n]``, ``EWfin[n]``, ``signal_noise_region[n]``, ``vel[n]`` -> scalars
      printed in the annotation box.
    * ``host_detected`` -> container tested with ``n in host_detected``.
    * ``host_flux`` -> per-object host flux; treated as absent when all zeros.
    """

    def __init__(
        self,
        test_clase,
        fit_region_g,
        mask_fit,
        mask_fit_g,
        masked_uncertainties_g,
        Master_Gaussian,
        params_g,
        Baselines,
        outer_limits,
        AN,
        EWfin,
        signal_noise_region,
        host_detected,
        host_flux,
        vel=None,
    ):
        """Store the fit products.

        ``vel`` is optional and defaults to ``None``. It was previously read
        by :meth:`plot_combined` without ever being set, which raised
        ``AttributeError``; when it is ``None`` the velocity line is simply
        left out of the annotation.
        """
        self.test_clase = test_clase
        self.fit_region_g = fit_region_g
        self.mask_fit = mask_fit
        self.mask_fit_g = mask_fit_g
        self.masked_uncertainties_g = masked_uncertainties_g
        self.Master_Gaussian = Master_Gaussian
        self.params_g = params_g
        self.Baselines = Baselines
        self.outer_limits = outer_limits
        self.AN = AN
        self.EWfin = EWfin
        self.signal_noise_region = signal_noise_region
        self.host_detected = host_detected  # indices of objects with a detected host
        self.host_flux = host_flux
        self.vel = vel

    def _summary_text(self, n, table=None):
        """Build the annotation text for object ``n``.

        When ``table`` is a ``pandas.DataFrame`` the values come from row
        ``n`` of that table; otherwise they come from the instance attributes.
        """
        if isinstance(table, pd.DataFrame):
            row = table.iloc[n]
            pieces = []
            for key in (
                "EWfin",
                "vel",
                "host_detected",
                "AN",
                "signal_noise_region",
                "rAGN_stars_5100",
                "rAGN_Cont_5100",
                "stars_Cont_5100",
                "agn_slope",
            ):
                if key == "host_detected":
                    pieces.append(f"{key} :{bool(row[key])}\n")
                else:
                    pieces.append(f"{key} :{row[key]:.2f}\n")
            return "".join(pieces)

        lines = [
            f"AN: {self.AN[n]:.2f}",
            f"EWfin: {self.EWfin[n]:.2f}",
        ]
        # getattr keeps instances created via __new__ (or older pickles) usable.
        vel = getattr(self, "vel", None)
        if vel is not None:
            lines.append(f"vel: {vel[n]:.2f}")
        lines.append(f"Signal noise region: {self.signal_noise_region[n]:.2f}")
        lines.append(f"is the host detected?: {n in self.host_detected}")
        return "\n".join(lines)

    def plot_combined(self, n, save='', add_baseline=True, pandas=None, **kwargs):
        """Draw the full spectrum (top) and the local fit (bottom) of object ``n``.

        Parameters
        ----------
        n : int
            Index of the object to plot.
        save : str, optional
            When non-empty the figure is written to ``images/{save}.jpg``
            (the directory is created if needed) and closed; otherwise it is
            shown.
        add_baseline : bool, optional
            Add ``self.Baselines[n]`` to the Gaussian model before plotting.
        pandas : pandas.DataFrame, optional
            Table whose row ``n`` supplies the annotation values. Any other
            value makes the annotation fall back to the instance attributes.
        **kwargs
            ``add_all`` (bool): when a host flux is present, also draw the
            host and the host-subtracted spectrum in the top panel.
            ``xlim_sheap`` / ``ylim_sheap``: limits of the top panel.
            ``mask``: values whose min/max delimit a shaded band in the top
            panel.
            ``xlim_local`` / ``ylim_local``: limits of the bottom panel.
        """
        import os

        fig, (ax1, ax2) = plt.subplots(
            2, 1, figsize=(35, 15), gridspec_kw={'height_ratios': [2, 1]}
        )

        x_axis = self.test_clase.spectra[n, 0, :]  # already redshift corrected
        y_spectrum = self.test_clase.spectra[n, 1, :]
        y_err = self.test_clase.spectra[n, 2, :]

        # ---- Top panel: full spectrum -------------------------------------
        host_present = self.host_flux is not None and not jnp.all(self.host_flux == 0)
        if host_present and kwargs.get("add_all", False):
            host = self.host_flux[n]
            layers = (
                ('Spectra redshift corrected', y_spectrum, y_err),
                ("host", host, None),
                ("AGN(spectra-host)", y_spectrum - host, None),
            )
            for label, flux, flux_err in layers:
                ax1.errorbar(
                    x_axis,
                    flux,
                    yerr=flux_err,
                    ecolor='lightskyblue',
                    label=label,
                    zorder=3,
                    alpha=0.8,
                )
        else:
            # Always draw the spectrum itself, so the top panel is never
            # empty regardless of whether a host flux exists.
            ax1.errorbar(
                x_axis,
                y_spectrum,
                yerr=y_err,
                color='b',
                ecolor='lightskyblue',
                label='Spectra redshift corrected',
                zorder=1,
            )

        ax1.set_ylabel(
            r'$ f_{\lambda}$ ($\rm 10^{-17} {\rm erg\;s^{-1}\;cm^{-2}\;\AA^{-1}}$)',
            fontsize=40,
        )
        ax1.set_xlabel(r'$\rm Rest \ Wavelength$ ($\rm \AA$)', fontsize=40)
        ax1.tick_params(which="both", length=10, width=2, labelsize=35)

        ax1.set_xlim(np.nanmin(x_axis), np.nanmax(x_axis))
        if "xlim_sheap" in kwargs:
            ax1.set_xlim(*kwargs["xlim_sheap"])
        if "ylim_sheap" in kwargs:
            ax1.set_ylim(*kwargs["ylim_sheap"])
        # Freeze the y-range now so the shaded band below cannot rescale it.
        ylimit = ax1.get_ylim()

        if "mask" in kwargs:
            mask = kwargs["mask"]
            mask_x = np.logical_and(x_axis >= min(mask), x_axis <= max(mask))
            ax1.fill_between(
                x_axis, *ylimit, where=mask_x, color='grey', alpha=0.5, label='Mask', zorder=1
            )
        ax1.legend(
            loc='lower center',
            bbox_to_anchor=(0.5, 1),
            fancybox=True,
            shadow=False,
            ncol=4,
            fontsize=30,
        )
        ax1.set_ylim(*ylimit)

        # ---- Bottom panel: local fit --------------------------------------
        unmasked = ~self.mask_fit[n]
        fit_x = self.fit_region_g[n][0][unmasked]
        fit_y = self.fit_region_g[n][1][unmasked]
        fit_err = self.masked_uncertainties_g[n][unmasked]

        ax2.plot(x_axis, y_spectrum, color='red', label='Spectra')
        ax2.errorbar(
            fit_x,
            fit_y,
            yerr=fit_err,
            fmt='o',
            color='red',
            alpha=0.5,
            label='Fit Region with Uncertainties',
        )

        model = self.Master_Gaussian.func(x_axis, self.params_g[n])
        if add_baseline:
            # Rebind instead of ``+=`` so the array returned by ``func`` is
            # never modified in place.
            model = model + self.Baselines[n]
        ax2.plot(x_axis, model, label='Baseline + Gaussian Fit')

        band_top = np.nanmax(y_spectrum)
        ax2.fill_between(
            x_axis,
            0,
            band_top,
            where=self.mask_fit[n],
            color='grey',
            alpha=0.5,
            label='Mask for Linear Fit',
            zorder=10,
        )
        ax2.fill_between(
            x_axis,
            0,
            band_top,
            where=self.mask_fit_g[n],
            color='green',
            alpha=0.5,
            label='Mask for Gaussian Fit',
            zorder=10,
        )
        ax2.set_xlim(self.outer_limits)

        # The y-range is centred on the median flux of what is displayed.
        median_val = np.median(fit_y)
        if "xlim_local" in kwargs:
            xlim_local = kwargs["xlim_local"]
            arg_min = np.nanargmin(abs(x_axis - min(xlim_local)))
            arg_max = np.nanargmin(abs(x_axis - max(xlim_local)))
            median_val = np.nanmedian(y_spectrum[arg_min:arg_max])
            ax2.set_xlim(*xlim_local)
        if np.isnan(median_val):
            median_val = np.nanmedian(y_spectrum)
        ax2.set_ylim([median_val * 0.1, median_val * 1.9])
        if "ylim_local" in kwargs:
            ax2.set_ylim(*kwargs["ylim_local"])

        fig.text(
            0.71,
            0.7,
            self._summary_text(n, pandas),
            fontsize=20,
            color='blue',
            verticalalignment='center',
            horizontalalignment='left',
            bbox=dict(facecolor='white', alpha=0.5, edgecolor='none'),
        )
        ax2.set_title(f"Local Spectra for n={n}", fontsize=20)
        ax2.legend(loc='upper left', bbox_to_anchor=(1.05, 1), borderaxespad=0.0, fontsize=15)
        plt.tight_layout(rect=[0, 0, 0.85, 1])

        if save:
            out_path = f"images/{save}.jpg"
            # savefig does not create missing directories.
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            plt.savefig(out_path, dpi=300, bbox_inches='tight')
            plt.close(fig)
        else:
            plt.show()
def plot_a_spectra(spectra, save=None, **kwargs):
    """Plot one spectrum, with zoom panels for the line regions it covers.

    Parameters
    ----------
    spectra : sequence of three items
        ``(x_axis, y_axis, yerr)``; ``yerr`` may be ``None``, in which case
        plain lines are drawn instead of error bars.
    save : str or path-like, optional
        When given, the figure is written there (300 dpi) and closed;
        otherwise it is shown.
    **kwargs
        ``xlim``: x-limits of the main panel (defaults to the finite data
        range). ``object_name``: label written on the main panel (defaults
        to ``"Unknown Object"``).
    """
    x_axis, y_axis, yerr = spectra
    # Accept plain sequences as well as arrays: the boolean masks below need
    # array semantics.
    x_axis = np.asanyarray(x_axis)
    y_axis = np.asanyarray(y_axis)
    if yerr is not None:
        yerr = np.asanyarray(yerr)

    xlim = kwargs.get("xlim", [np.nanmin(x_axis), np.nanmax(x_axis)])
    object_name = kwargs.get("object_name", "Unknown Object")

    # Named spectral regions and their wavelength ranges.
    regions = {
        'La': [1000, 1500],
        'CIV,cIII': [1100, 2000],
        'MgII': [2500, 3000],
        'Hbeta': [4400, 5600],
        'Halpha': [5600, 7300],
    }

    # Only regions lying completely inside the x-limits get a zoom panel.
    regions_to_plot = [
        (region, start, end)
        for region, (start, end) in regions.items()
        if xlim[0] <= start and xlim[1] >= end
    ]
    n_regions = len(regions_to_plot)

    if n_regions > 0:
        fig = plt.figure(figsize=(20, 10))
        # Two rows: one zoom panel per region on top, the main plot spanning
        # every column below with twice the height.
        gs = gridspec.GridSpec(2, n_regions, height_ratios=[1, 2], hspace=0.1)

        for i, (region, start, end) in enumerate(regions_to_plot):
            ax = fig.add_subplot(gs[0, i])
            mask = (x_axis >= start) & (x_axis <= end)
            if np.any(mask):
                region_flux = y_axis[mask]
                if yerr is not None:
                    ax.errorbar(
                        x_axis[mask], region_flux, yerr=yerr[mask], fmt='-', lw=1, c="k"
                    )
                else:
                    ax.plot(x_axis[mask], region_flux, '-', lw=1, c="k")
                # set_ylim rejects NaN limits, so only rescale when the
                # region holds at least one finite flux value.
                if np.any(np.isfinite(region_flux)):
                    ax.set_ylim(np.nanmin(region_flux), np.nanmax(region_flux) * 1.1)
            ax.set_xlim(start, end)

            # Region label in the top-left corner of the zoom panel.
            annotation = f"Region: {region}"
            ax.text(
                0.05,
                0.95,
                annotation,
                transform=ax.transAxes,
                fontsize=10,
                verticalalignment='top',
                bbox=dict(facecolor='white', alpha=0.5, edgecolor='none'),
            )
            ax.set_xlabel("Wavelength", fontsize=10)
            ax.set_ylabel("Flux", fontsize=10)
            ax.tick_params(axis='both', labelsize=8)

        ax_main = fig.add_subplot(gs[1, :])
    else:
        # No region qualifies: a single axis is enough.
        fig, ax_main = plt.subplots(figsize=(20, 10))

    if yerr is not None:
        ax_main.errorbar(x_axis, y_axis, yerr=yerr, c="k", fmt='-', ecolor='dimgray', lw=1)
    else:
        # The format string must be positional here: unlike ``errorbar``,
        # ``Axes.plot`` has no ``fmt`` keyword and raises if one is passed.
        ax_main.plot(x_axis, y_axis, '-', c="k", lw=1)

    ax_main.set_xlim(xlim)
    ax_main.set_xlabel('Wavelength', fontsize=12)
    ax_main.set_ylabel('Flux', fontsize=12)

    # Object label near the top-left of the main panel (right-aligned so the
    # text ends at x = 0.15 in axes coordinates).
    ax_main.text(
        0.15,
        0.95,
        f"Object: {object_name}",
        transform=ax_main.transAxes,
        fontsize=12,
        verticalalignment='top',
        horizontalalignment='right',
        bbox=dict(facecolor='white', alpha=0.5, edgecolor='none'),
    )

    plt.tight_layout()

    if save:
        plt.savefig(save, dpi=300, bbox_inches='tight')
        plt.close()
    else:
        plt.show()