Source code for fiesta.inference.plot

import matplotlib
import matplotlib.pyplot as plt

# Matplotlib rc settings used for every figure produced by this module.
pltparams = {
    # Grid and text rendering
    "axes.grid": False,
    "text.usetex": True,
    "font.family": "serif",
    # All ticks, labels and axes edges are drawn in black
    "ytick.color": "black",
    "xtick.color": "black",
    "axes.labelcolor": "black",
    "axes.edgecolor": "black",
    "font.serif": ["Computer Modern Serif"],
    # One common font size for ticks, labels, legends and figure titles
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "axes.labelsize": 16,
    "legend.fontsize": 16,
    "legend.title_fontsize": 16,
    "figure.titlesize": 16,
    "figure.constrained_layout.use": False,
}

# Apply the settings above to matplotlib's global rcParams at import time,
# so they affect every figure created after this module has been imported.
plt.rcParams.update(pltparams)

import numpy as np
import pandas as pd
from jaxtyping import Array
import jax
import jax.numpy as jnp

from fiesta.logging import logger
from fiesta.inference.systematic import process_file
from fiesta.inference.likelihood import LikelihoodBase


#############################
# DEFAULT SETTINGS / LABELS #
#############################


# Keyword arguments handed to corner.corner() unless overridden by the caller
# of corner_plot().
default_corner_kwargs = {
    "bins": 40,
    "smooth": True,
    "label_kwargs": {"fontsize": 14},
    "title_kwargs": {"fontsize": 14},
    "quantiles": [],
    "levels": [0.68, 0.95],
    "plot_density": False,
    "plot_datapoints": False,
    "fill_contours": True,
    "max_n_ticks": 3,
    "min_n_ticks": 3,
    "save": False,
    "truth_color": "darkorange",
    "labelpad": 0.2,
}

# Maps parameter names to the LaTeX axis labels used in corner plots.
# Raw strings keep the LaTeX backslashes readable.
latex_labels = {
    "inclination_EM": r"$\iota$ [rad]",
    "log10_E0": r"$\log_{10}(E_0)$ [erg]",
    "thetaCore": r"$\theta_{\mathrm{c}}$ [rad]",
    "thetaWing": r"$\theta_{\mathrm{w}}$ [rad]",
    "alphaWing": r"$\alpha_{\mathrm{w}}$",
    "log10_n0": r"$\log_{10}(n_{\mathrm{ism}})$ [cm$^{-3}$]",
    "p": r"$p$",
    "log10_epsilon_e": r"$\log_{10}(\epsilon_e)$",
    "log10_epsilon_B": r"$\log_{10}(\epsilon_B)$",
    "epsilon_e": r"$\epsilon_e$",
    "epsilon_B": r"$\epsilon_B$",
    "log10_mej_dyn": r"$\log_{10}(m_{\mathrm{ej,dyn}})$ [$M_\odot$]",
    "log10_mej_wind": r"$\log_{10}(m_{\mathrm{ej,wind}})$ [$M_\odot$]",
    "v_ej_dyn": r"$\bar{v}_{\mathrm{ej,dyn}}$ [$c$]",
    "v_ej_wind": r"$\bar{v}_{\mathrm{ej,wind}}$ [$c$]",
    "Ye_dyn": r"$\bar{Y}_{e,\mathrm{dyn}}$",
    "Ye_wind": r"$Y_{e,\mathrm{wind}}$",
    "luminosity_distance": r"$d_L$ [Mpc]",
    "redshift": r"$z$",
    "sys_err": r"$\sigma_{\mathrm{sys}}$ [mag]",
    "Gamma0": r"$\Gamma_0$",
    "timeshift": r"$t_0$ [days]",
    "log10_Menv": r"$\log_{10}(M_{\rm e})$ [$M_\odot$]",
    "log10_Renv": r"$\log_{10}(R_{\rm e})$ [cm]",
    "log10_Ee": r"$\log_{10}(E_e)$ [erg]",
    "em_syserr": r"$\sigma_{\rm sys}$ [mag]",
    "Ebv": r"$E(B-V)$ [mag]",
    "amplitude": r"$A$",
    "supernova_mag_stretch": r"$s$",
}



#######################
#  PLOTTING FUNCTIONS #
#######################


def corner_plot(posterior: dict | pd.DataFrame,
                parameter_names: list[str],
                truths: dict | None = None,
                color: str = "blue",
                legend_label: str | None = None,
                fig: matplotlib.figure.Figure | None = None,
                ax: matplotlib.axes.Axes | None = None,
                **kwargs):
    """
    Make a nice corner plot from the posterior with automated parameter labels.

    Args:
        posterior (dict | pd.DataFrame): posterior samples for which to do the corner plot.
        parameter_names (list[str]): parameters from posterior that should be included in the corner plot.
        truths (dict[str, float] | None): True (injected values) for some of the parameters. Parameters without an entry get no truth marker. Defaults to None.
        color (str): color for the corner plot contours. Defaults to blue.
        legend_label (str): Label for the legend. If not set, no legend will be shown. Defaults to None.
        fig (matplotlib.figure.Figure): Figure over which to do the corner plot. If set, ax must also be provided. Defaults to None.
        ax (matplotlib.axes.Axes): Axes over which to do the corner plot. If set, fig must also be provided. Defaults to None.
        **kwargs: Passed to corner.corner; they override the entries of default_corner_kwargs. A ``hist_kwargs`` entry is merged into the default histogram settings (density=True, color=color).

    Returns:
        fig (matplotlib.figure.Figure): Figure with the corner plot.
        ax (matplotlib.axes.Axes): array of axes

        Both are None if the corner package is not installed.

    Raises:
        ValueError: If only one of fig and ax is given.
    """
    try:
        import corner
    except ImportError:
        logger.warning("Install corner to create corner plots.")
        return None, None

    # fig and ax only make sense as a pair.
    if (fig is None) != (ax is None):
        raise ValueError("fig and ax must be either both be specified or both be None.")

    if truths is None:
        truths = {}
    posterior = pd.DataFrame(posterior)

    # Fall back to the raw parameter name when no LaTeX label is known.
    labels = [latex_labels.get(p, p) for p in parameter_names]
    truths_list = [truths.get(p) for p in parameter_names]

    if fig is None:
        n = len(parameter_names)
        fig, ax = plt.subplots(n, n, figsize=(n*1.5, n*1.5))

    corner_args = default_corner_kwargs.copy()
    corner_args.update(kwargs)

    # Merge user-supplied hist_kwargs instead of passing the keyword twice,
    # which would raise a TypeError.
    hist_kwargs = dict(density=True, color=color)
    hist_kwargs.update(corner_args.pop("hist_kwargs", None) or {})

    corner.corner(posterior[parameter_names],
                  fig=fig,
                  color=color,
                  labels=labels,
                  truths=truths_list,
                  hist_kwargs=hist_kwargs,
                  **corner_args)

    if legend_label is not None:
        from matplotlib.lines import Line2D

        # plt.subplots(1, 1) returns a bare Axes; normalise to a 2D grid
        # for indexing only (the returned ax is left untouched).
        axes_grid = np.atleast_2d(ax)

        # ax[1, 4] only exists for grids with at least 5 columns; smaller
        # grids use the top-right panel.
        if len(parameter_names) < 5:
            lx, ly = 0, -1
        else:
            lx, ly = 1, 4

        # Proxy artist: gives the legend a handle without drawing an empty
        # line onto whichever axes happens to be current.
        handle = Line2D([], [], color=color)
        axes_grid[lx, ly].legend(handles=[handle],
                                 labels=[legend_label],
                                 fontsize=15,
                                 fancybox=False,
                                 framealpha=1)

    return fig, ax
# TODO: superpose multiple posteriors in one corner plot
class LightcurvePlotter:
    """
    Interface to plot lightcurves from a given posterior.

    Args:
        posterior (dict | pd.DataFrame): Posterior samples for which the light curves should be plotted.
        likelihood (EMLikelihood): Likelihood object that was used to sample the posterior.
        systematics_file (str): Systematics file that was used to sample the posterior. Defaults to None.
        free_syserr (bool): Whether a global systematic uncertainty was sampled freely. Defaults to False. Will overwrite systematics_file.
    """

    # Maximum number of posterior draws used for the background light curves.
    N_SAMPLE_LCS = 200

    def __init__(self,
                 posterior: dict | pd.DataFrame,
                 likelihood: LikelihoodBase,
                 systematics_file: str | None = None,
                 free_syserr=False):

        # How the systematic uncertainty is determined:
        # "fixed", "from_file" or "free".
        self.systematics = "fixed"
        self.systematics_file = systematics_file

        if systematics_file is not None:
            self.systematics = "from_file"
            sys_params_per_filter, t_nodes_per_filter, _ = process_file(systematics_file, filters=likelihood.filters)
            likelihood._setup_sys_uncertainty_from_file(sys_params_per_filter, t_nodes_per_filter)

        if free_syserr:
            self.systematics = "free"
            likelihood._setup_sys_uncertainty_free()

        self.likelihood = likelihood

        self.tmin = likelihood.data_tmin
        self.tmax = likelihood.data_tmax

        self.times_det = likelihood.times_det
        self.times_nondet = likelihood.times_nondet

        if hasattr(likelihood, "zero_point_mag"):
            # The likelihood carries a zero point, so its data points are
            # treated as fluxes and converted to magnitudes here.
            zero_point_mag = likelihood.zero_point_mag

            def flux_to_mag(flux_arr):
                return -2.5*np.log10(flux_arr) + zero_point_mag

            self.mag_det = jax.tree.map(flux_to_mag, likelihood.datapoints_det)
            self.mag_nondet = jax.tree.map(flux_to_mag, likelihood.datapoints_nondet)

            # First-order error propagation of m = -2.5 log10(F) + zp:
            # sigma_m = 2.5 / ln(10) * sigma_F / |F|, applied per filter.
            self.mag_err = jax.tree.map(lambda err, flux: 2.5 / np.log(10) * np.abs(err / flux),
                                        likelihood.datapoints_err,
                                        likelihood.datapoints_det)
        else:
            self.mag_det = likelihood.datapoints_det
            self.mag_err = likelihood.datapoints_err
            self.mag_nondet = likelihood.datapoints_nondet

        self.model = likelihood.model
        self.posterior = pd.DataFrame(posterior)
        self.fixed_params = likelihood.fixed_params

    def plot_data(self,
                  ax: matplotlib.axes.Axes,
                  filt: str,
                  zorder=3,
                  **kwargs):
        """
        Plots data points from a filter over ax.

        Args:
            ax (matplotlib.axes.Axes): ax to plot the data points to.
            filt (str): Which filter from the data should be plotted on ax.
            zorder (int): zorder with which the data points should be plotted.
            **kwargs: kwargs to be passed to errorbar and scatter.
        """
        # Detections
        t, mag, err = self.times_det[filt], self.mag_det[filt], self.mag_err[filt]
        ax.errorbar(t, mag, yerr=err, fmt="o", zorder=zorder, **kwargs)

        # Non-detections
        t, mag = self.times_nondet[filt], self.mag_nondet[filt]
        ax.scatter(t, mag, zorder=zorder, marker="v", **kwargs)

    def plot_best_fit_lc(self,
                         ax: matplotlib.axes.Axes,
                         filt: str,
                         zorder=2,
                         **kwargs):
        """
        Plots one filter from the best fit light curve from the posterior over ax.

        Args:
            ax (matplotlib.axes.Axes): ax to plot the light curve onto.
            filt (str): Which filter from the best fit lightcurve should be plotted on ax.
            zorder (int): zorder with which the lightcurve should be plotted.
            **kwargs: kwargs to be passed to plot.
        """
        self._get_best_fit_lc()
        ax.plot(self.t_best_fit, self.best_fit_lc[filt], zorder=zorder, **kwargs)

    def _get_best_fit_lc(self,):
        """Compute and cache the light curve of the posterior sample with the highest log_likelihood."""
        if hasattr(self, "_best_fit_lc_determined"):
            return

        best_ind = int(np.argmax(self.posterior["log_likelihood"].to_numpy()))
        self.best_fit_params = self.posterior.iloc[best_ind].to_dict()
        self.best_fit_params.update(self.fixed_params)
        self.best_fit_params = self.likelihood.conversion(self.best_fit_params)

        t, model_mag = self.model.predict(self.best_fit_params)

        # Restrict the model output to the time range set by the likelihood.
        mask = (t >= self.tmin) & (t <= self.tmax)
        self.t_best_fit = t[mask]
        self.best_fit_lc = {filt: model_mag[filt][mask] for filt in model_mag.keys()}

        self._best_fit_lc_determined = True

    def plot_sample_lc(self,
                       ax: matplotlib.axes.Axes,
                       filt: str,
                       zorder=1):
        """
        Plots background light curves from the posterior over ax.

        Args:
            ax (matplotlib.axes.Axes): ax to plot the light curve onto.
            filt (str): Which filter from the background light curves should be plotted on ax.
            zorder (int): zorder with which the lightcurve should be plotted.
        """
        self._get_samples_lcs()

        # Iterate over however many curves were drawn, which is fewer than
        # N_SAMPLE_LCS for small posteriors.
        for t, lc in zip(self.t_sample_lc, self.sample_lc[filt]):
            ax.plot(t, lc, color="grey", alpha=0.05, zorder=zorder, rasterized=True)

    def _get_samples_lcs(self,):
        """Draw up to N_SAMPLE_LCS random posterior samples and cache their light curves."""
        if hasattr(self, "_sample_lcs_determined"):
            return

        total_nb_samples = self.posterior.shape[0]
        # Drawing without replacement cannot exceed the number of samples.
        nb_draws = min(self.N_SAMPLE_LCS, total_nb_samples)
        ind = np.random.choice(total_nb_samples, nb_draws, replace=False)

        # .iloc selects rows by position, independent of the DataFrame index.
        params = {key: self.posterior[key].iloc[ind].to_numpy() for key in self.posterior.keys()}

        # Broadcast each fixed parameter to one value per drawn sample.
        for key, value in self.fixed_params.items():
            params[key] = np.ones(nb_draws) * value
        params = self.likelihood.conversion(params)

        self.t_sample_lc, self.sample_lc = self.model.vpredict(params)
        self._sample_lcs_determined = True

    def plot_sys_uncertainty_band(self,
                                  ax: matplotlib.axes.Axes,
                                  filt: str,
                                  zorder=2,
                                  **kwargs):
        """
        Plots systematic uncertainty band from the best fit light curve for one filter over ax.

        Args:
            ax (matplotlib.axes.Axes): ax to plot the band onto.
            filt (str): Which filter from the band should be plotted on ax.
            zorder (int): zorder with which the band should be plotted.
            **kwargs: kwargs to be passed to fill_between.
        """
        self._get_best_fit_lc()

        if self.systematics == "from_file":
            sys_params_per_filter, t_nodes_per_filter, _ = process_file(self.systematics_file, [filt])
            sys_params_per_filter = sys_params_per_filter[filt]
            t_nodes_per_filter = t_nodes_per_filter[filt]

            # Without explicit time nodes, spread them evenly over the data time range.
            if t_nodes_per_filter is None:
                t_nodes_per_filter = np.linspace(self.tmin, self.tmax, len(sys_params_per_filter))

            sys_param_array = np.array([self.best_fit_params[p] for p in sys_params_per_filter])
            sigma_sys = np.interp(self.t_best_fit, t_nodes_per_filter, sys_param_array)

        elif self.systematics == "free":
            sigma_sys = self.best_fit_params["sys_err"]

        else:
            sigma_sys = self.likelihood.error_budget

        ax.fill_between(self.t_best_fit,
                        self.best_fit_lc[filt] + sigma_sys,
                        self.best_fit_lc[filt] - sigma_sys,
                        alpha=0.1,
                        zorder=zorder,
                        **kwargs)

    def get_chisquared(self, per_dof: bool = False):
        """
        Get the total chisquared value and the chisquared values per filter.
        This is different from the log_likelihood value in the posterior, because the likelihood function contains (2 pi sigma)^(-1/2).
        Only detections enter the sum.

        Args:
            per_dof (bool): Whether to return reduced chi-squared values, i.e., per number of data points.

        Returns:
            dict: The chi-squared value in each filter, plus the total across all data points under the key "chisqu_total".
        """
        self._get_best_fit_lc()

        def interp_to_data(t, m):
            # TODO extrapolation is maybe problematic here
            return jnp.interp(t, self.t_best_fit, m, left="extrapolate", right="extrapolate")

        mag_est_det = jax.tree.map(interp_to_data, self.times_det, self.best_fit_lc)

        # Get the systematic uncertainty + data uncertainty
        sigma = self.likelihood.get_sigma(self.best_fit_params)

        # Get chisq
        chisq_dict = jax.tree.map(lambda mstar, md, s: jnp.sum((mstar-md)**2/s**2),
                                  mag_est_det, self.mag_det, sigma)
        chisq_total = sum(chisq_dict.values())

        if per_dof:
            n_data_total = 0
            for key in self.mag_det.keys():
                n_data = self.mag_det[key].shape[0]
                chisq_dict[key] = chisq_dict[key] / n_data
                n_data_total += n_data
            chisq_total /= n_data_total

        chisq_dict["chisqu_total"] = chisq_total

        return chisq_dict