fiesta.inference#

Components for Bayesian parameter estimation.

Analytical Models#

Base classes, constants, and shared helpers for analytical light-curve models.

Each model is fully JIT-compilable and differentiable so that flowMC’s MALA sampler can compute jax.grad through the likelihood. The models follow the same predict() contract as the surrogate models:

(source_frame_times, {filter_name: apparent_mag_array})

This makes them drop-in replacements inside CombinedSurrogate and EMLikelihood.

All internal physics computations use log10 space to avoid float32 overflow (e.g. explosion energies ~1e49 erg exceed float32 max ~3.4e38).
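A quick standalone sketch of why log10 space is needed; numpy stands in for jax.numpy here, which behaves the same way for float32:

```python
import numpy as np

# Direct storage of ~1e49 erg overflows float32 (max ~3.4e38) to inf.
E_erg = 1e49
assert np.isinf(np.float32(E_erg))

# In log10 space, every intermediate quantity stays comfortably in range.
log10_E = np.float32(np.log10(E_erg))              # 49.0
log10_L = log10_E - np.float32(np.log10(86400.0))  # e.g. E / (1 day), in erg/s
```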

class fiesta.inference.analytical_models.base.AnalyticalModel(filters, times=None, temperature_floor=None)[source]#

Bases: object

Base class for analytical (non-surrogate) light-curve models.

Subclasses must implement compute_log10_lbol_rphot(self, x, t_days) which returns (log10_Lbol, log10_Rphot) — log10 of bolometric luminosity in erg/s and photospheric radius in cm.
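A minimal standalone sketch of this contract. The class and its physics (a power-law luminosity and a freely expanding photosphere) are illustrative, not a fiesta model, and numpy stands in for jax.numpy:

```python
import numpy as np  # stand-in for jax.numpy in this sketch

DAY_S = 86400.0

class ToyFireballModel:
    """Illustrative model: L_bol ~ t^-1, photosphere R = v * t."""
    parameter_names = ["log10_L1", "v_phot_cm_s"]

    def compute_log10_lbol_rphot(self, x, t_days):
        t_s = t_days * DAY_S
        # log10 L_bol [erg/s]: power-law decay from the luminosity at 1 day
        log10_lbol = x["log10_L1"] - np.log10(t_days)
        # log10 R_phot [cm]: free expansion at constant velocity
        log10_rphot = np.log10(x["v_phot_cm_s"]) + np.log10(t_s)
        return log10_lbol, log10_rphot

t = np.array([1.0, 2.0, 4.0])
lL, lR = ToyFireballModel().compute_log10_lbol_rphot(
    {"log10_L1": 42.0, "v_phot_cm_s": 1e9}, t)
```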

add_filter(filters)[source]#
compute_log10_lbol_rphot(x, t_days)[source]#

Return (log10_L_bol, log10_R_phot) arrays at each time in t_days.

L_bol in erg/s, R_phot in cm.

Return type:

tuple[Array, Array]

filters: list[str]#
parameter_names: list[str]#
predict(x)[source]#
Return type:

tuple[Array, dict[str, Array]]

times: Array#

Phenomenological (flux-shape) light-curve models.

Reference:

Redback: nikhil-sarin/redback
Boom: boom-astro/boom

class fiesta.inference.analytical_models.phenomenological_models.AfterglowModel(filters, times=None)[source]#

Bases: PhenomenologicalModel

Smooth broken power-law afterglow model.

Reference:

Boom: boom-astro/boom

Transitions from t^(-alpha_1) at early times to t^(-alpha_2) at late times.

Shape parameters: t0, log10_t_break, alpha_1, alpha_2

compute_shape(x, t_days)[source]#

Return the temporal shape function S(t) >= 0.

has_baseline: bool = False#
shape_parameter_names: list[str] = ['t0', 'log10_t_break', 'alpha_1', 'alpha_2']#
class fiesta.inference.analytical_models.phenomenological_models.BazinModel(filters, times=None)[source]#

Bases: PhenomenologicalModel

Bazin et al. phenomenological light-curve model.

Reference:

Boom: boom-astro/boom

Shape: exp(-dt/tau_fall) * sigmoid(dt/tau_rise)

Parameters (per-band): amp_mag_{filter}, base_mag_{filter}
Shape parameters: t0, log10_tau_rise, log10_tau_fall

compute_shape(x, t_days)[source]#

Return the temporal shape function S(t) >= 0.

has_baseline: bool = True#
shape_parameter_names: list[str] = ['t0', 'log10_tau_rise', 'log10_tau_fall']#
class fiesta.inference.analytical_models.phenomenological_models.EvolvingBlackbodyModel(filters, times=None, reference_time=1.0)[source]#

Bases: AnalyticalModel

Phenomenological model with piecewise power-law T and R evolution.

Reference:

Redback: nikhil-sarin/redback

Model-agnostic — useful for fast empirical fitting of any thermal transient. Based on the evolving_blackbody model from Redback.

Parameters (in x dict):

  • log10_temperature_0 – log10 initial temperature (K) at reference_time

  • log10_radius_0 – log10 initial radius (cm) at reference_time

  • temp_rise_index – T rise power-law index for t <= temp_peak_time

  • temp_decline_index – T decline power-law index for t > temp_peak_time

  • temp_peak_time – time (days) when temperature peaks

  • radius_rise_index – R rise power-law index for t <= radius_peak_time

  • radius_decline_index – R decline power-law index for t > radius_peak_time

  • radius_peak_time – time (days) when radius peaks

compute_log10_lbol_rphot(x, t_days)[source]#

Return (log10_L_bol, log10_R_phot) arrays at each time in t_days.

L_bol in erg/s, R_phot in cm.

parameter_names: list[str] = ['log10_temperature_0', 'log10_radius_0', 'temp_rise_index', 'temp_decline_index', 'temp_peak_time', 'radius_rise_index', 'radius_decline_index', 'radius_peak_time']#
class fiesta.inference.analytical_models.phenomenological_models.PhenomenologicalModel(filters, times=None)[source]#

Bases: AnalyticalModel

Base class for phenomenological light-curve models.

Unlike physics-based models that compute L_bol + R_phot and pass through a blackbody SED, phenomenological models compute a temporal shape function S(t) and convert directly to per-band apparent magnitudes.

Subclasses must set:

shape_parameter_names : list[str] — shared temporal shape parameters
has_baseline : bool — whether the model has a baseline flux component

and implement compute_shape(self, x, t_days) -> Array.
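The compute_shape contract can be sketched standalone; the Bazin shape exp(-dt/tau_fall) * sigmoid(dt/tau_rise), documented above under BazinModel, serves as the example. This is an illustrative sketch with numpy standing in for jax.numpy, not fiesta's implementation:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def bazin_shape(x, t_days):
    """S(t) = exp(-dt/tau_fall) * sigmoid(dt/tau_rise), with S(t) >= 0."""
    dt = t_days - x["t0"]
    tau_rise = 10.0 ** x["log10_tau_rise"]
    tau_fall = 10.0 ** x["log10_tau_fall"]
    return np.exp(-dt / tau_fall) * sigmoid(dt / tau_rise)

t = np.linspace(-5.0, 50.0, 111)
S = bazin_shape({"t0": 0.0, "log10_tau_rise": 0.5, "log10_tau_fall": 1.3}, t)
```

The rise timescale sets how quickly the sigmoid turns on around t0, while the fall timescale controls the exponential decay, so S(t) peaks shortly after t0 and stays non-negative everywhere.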

add_filter(filters)[source]#
compute_shape(x, t_days)[source]#

Return the temporal shape function S(t) >= 0.

Return type:

Array

has_baseline: bool = False#
predict(x)[source]#
Return type:

tuple[Array, dict[str, Array]]

shape_parameter_names: list[str]#
class fiesta.inference.analytical_models.phenomenological_models.PhenomenologicalTDEModel(filters, times=None)[source]#

Bases: PhenomenologicalModel

Phenomenological TDE light-curve model.

Reference:

Boom: boom-astro/boom

Sigmoid rise with power-law decay.

Shape parameters: t0, log10_tau_rise, log10_tau_fall, alpha_decay

compute_shape(x, t_days)[source]#

Return the temporal shape function S(t) >= 0.

has_baseline: bool = True#
shape_parameter_names: list[str] = ['t0', 'log10_tau_rise', 'log10_tau_fall', 'alpha_decay']#
class fiesta.inference.analytical_models.phenomenological_models.VillarModel(filters, times=None)[source]#

Bases: PhenomenologicalModel

Villar et al. phenomenological light-curve model.

Reference:

Boom: boom-astro/boom

Piecewise shape with smooth sigmoid transition at gamma.

Shape parameters: t0, log10_tau_rise, log10_tau_fall, beta_slope, log10_gamma

compute_shape(x, t_days)[source]#

Return the temporal shape function S(t) >= 0.

constraint_penalty(x)[source]#

Physical validity penalty (de Soto et al. 2024).

Returns 0 for valid parameters, positive for invalid. Multiply by a large negative factor and add to log-likelihood to enforce.
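The hookup described above can be sketched as follows; the scale factor is a hypothetical tuning choice, not a fiesta constant:

```python
PENALTY_SCALE = 1e4  # the "large factor"; its exact value is a tuning choice

def penalized_log_likelihood(log_l, penalty):
    # penalty == 0 for valid parameters leaves log_l unchanged;
    # penalty > 0 pushes the sampler away from invalid regions.
    return log_l - PENALTY_SCALE * penalty
```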

has_baseline: bool = False#
shape_parameter_names: list[str] = ['t0', 'log10_tau_rise', 'log10_tau_fall', 'beta_slope', 'log10_gamma']#

Supernova analytical light-curve models.

Reference:

Redback: nikhil-sarin/redback
NMMA: nuclear-multimessenger-astronomy/nmma

class fiesta.inference.analytical_models.supernova_models.ArnettModel(filters, times=None, modified=False)[source]#

Bases: AnalyticalModel

Arnett (1982) Ni56/Co56-powered supernova bolometric model.

Reference:

Redback: nikhil-sarin/redback
NMMA: nuclear-multimessenger-astronomy/nmma

Parameters (in x dict):

  • tau_m – diffusion timescale in days

  • log10_mni – log10 of Ni56 mass in solar masses

  • v_phot – photospheric velocity in units of 1e9 cm/s

  • t_0 – (modified variant only) gamma-ray trapping timescale in days

compute_log10_lbol_rphot(x, t_days)[source]#

Return (log10_L_bol, log10_R_phot) arrays at each time in t_days.

L_bol in erg/s, R_phot in cm.

parameter_names: list[str] = ['tau_m', 'log10_mni', 'v_phot']#
class fiesta.inference.analytical_models.supernova_models.CSMInteractionModel(filters, times=None, nn=12, delta=1, efficiency=0.5, temperature_floor=None)[source]#

Bases: AnalyticalModel

Circumstellar medium interaction model (Chevalier 1982).

Reference:

Redback: nikhil-sarin/redback

Forward + reverse shock luminosity from Chevalier self-similar solution, with optional CSM diffusion.

Parameters (in x dict):

  • log10_mej – log10 of ejecta mass in solar masses

  • log10_csm_mass – log10 of CSM mass in solar masses

  • log10_vej – log10 of ejecta velocity in km/s

  • eta – CSM density profile exponent

  • log10_rho – log10 of CSM density amplitude (g/cm^{eta+3})

  • log10_kappa – log10 of opacity (cm^2/g)

  • log10_r0 – log10 of CSM inner radius in AU

Constructor kwargs:

  • nn – ejecta power-law index (default 12)

  • delta – inner density exponent (default 1)

  • efficiency – kinetic-to-luminosity conversion (default 0.5)

compute_log10_lbol_rphot(x, t_days)[source]#

Return (log10_L_bol, log10_R_phot) arrays at each time in t_days.

L_bol in erg/s, R_phot in cm.

parameter_names: list[str] = ['log10_mej', 'log10_csm_mass', 'log10_vej', 'eta', 'log10_rho', 'log10_kappa', 'log10_r0']#
class fiesta.inference.analytical_models.supernova_models.MagnetarPoweredSNModel(filters, times=None, temperature_floor=None)[source]#

Bases: AnalyticalModel

Magnetar spin-down powered supernova with Arnett (1982) diffusion.

Reference:

Redback: nikhil-sarin/redback

Parameters (in x dict):

  • log10_p0 – log10 initial spin period in ms

  • log10_bp – log10 polar B-field in 1e14 G

  • mass_ns – neutron star mass in solar masses

  • theta_pb – angle between spin and B-field in radians

  • log10_mej – log10 of ejecta mass in solar masses

  • log10_vej – log10 of ejecta velocity in km/s

  • log10_kappa – log10 of opacity (cm^2/g)

  • log10_kappa_gamma – log10 of gamma-ray opacity (cm^2/g)

compute_log10_lbol_rphot(x, t_days)[source]#

Return (log10_L_bol, log10_R_phot) arrays at each time in t_days.

L_bol in erg/s, R_phot in cm.

parameter_names: list[str] = ['log10_p0', 'log10_bp', 'mass_ns', 'theta_pb', 'log10_mej', 'log10_vej', 'log10_kappa', 'log10_kappa_gamma']#
class fiesta.inference.analytical_models.supernova_models.NickelCobaltModel(filters, times=None, temperature_floor=None)[source]#

Bases: AnalyticalModel

Ni56/Co56 radioactive decay with Arnett (1982) diffusion.

Reference:

Redback: nikhil-sarin/redback

Parameters (in x dict):

  • f_nickel – fraction of ejecta mass in Ni56

  • log10_mej – log10 of ejecta mass in solar masses

  • log10_vej – log10 of ejecta velocity in km/s

  • log10_kappa – log10 of opacity (cm^2/g)

  • log10_kappa_gamma – log10 of gamma-ray opacity (cm^2/g)

compute_log10_lbol_rphot(x, t_days)[source]#

Return (log10_L_bol, log10_R_phot) arrays at each time in t_days.

L_bol in erg/s, R_phot in cm.

parameter_names: list[str] = ['f_nickel', 'log10_mej', 'log10_vej', 'log10_kappa', 'log10_kappa_gamma']#

Kilonova analytical light-curve models.

Reference:

Redback: nikhil-sarin/redback
NMMA: nuclear-multimessenger-astronomy/nmma

class fiesta.inference.analytical_models.kilonova_models.MagnetarBoostedKilonovaModel(filters, times=None, neutron_precursor=True, pair_cascade=True, vmax=0.7, magnetar_heating='first_layer')[source]#

Bases: AnalyticalModel

Multi-shell kilonova with magnetar spin-down heating, matching Redback.

Reference:

Redback: _general_metzger_magnetar_driven_kilonova_model

200-shell ODE with magnetar injection into bottom layer, velocity evolution, optional pair cascade and neutron precursor.

Parameters (in x dict):

  • log10_mej – log10 ejecta mass in solar masses

  • log10_vej – log10 ejecta velocity (vmin) in units of c

  • beta – velocity power-law index

  • log10_kappa_r – log10 opacity in cm^2/g

  • log10_p0 – log10 initial spin period in ms

  • log10_bp – log10 polar B-field in 1e14 G

  • mass_ns – neutron star mass in solar masses

  • theta_pb – angle between spin and B-field in radians

  • thermalisation_efficiency – magnetar thermalisation efficiency

compute_log10_lbol_rphot(x, t_days)[source]#

Return (log10_L_bol, log10_R_phot) arrays at each time in t_days.

L_bol in erg/s, R_phot in cm.

parameter_names: list[str] = ['log10_mej', 'log10_vej', 'beta', 'log10_kappa_r', 'log10_p0', 'log10_bp', 'mass_ns', 'theta_pb', 'thermalisation_efficiency']#
class fiesta.inference.analytical_models.kilonova_models.MetzgerFullModel(filters, times=None, neutron_precursor=True, vmax=0.7)[source]#

Bases: AnalyticalModel

Multi-shell kilonova model (Metzger 2017), matching Redback exactly.

Reference:

Redback: _metzger_kilonova_model in kilonova_models.py

200 shells with linear velocity spacing, Barnes+16 thermalisation, optional neutron precursor, per-gram energy ODE.

Parameters (in x dict):

  • log10_mej – log10 ejecta mass in solar masses

  • log10_vej – log10 ejecta velocity (vmin) in units of c

  • beta – velocity power-law index

  • log10_kappa_r – log10 opacity in cm^2/g

compute_log10_lbol_rphot(x, t_days)[source]#

Return (log10_L_bol, log10_R_phot) arrays at each time in t_days.

L_bol in erg/s, R_phot in cm.

parameter_names: list[str] = ['log10_mej', 'log10_vej', 'beta', 'log10_kappa_r']#
class fiesta.inference.analytical_models.kilonova_models.MetzgerModel(filters, times=None)[source]#

Bases: AnalyticalModel

300-shell kilonova model matching NMMA eff_metzger_lc.

Reference:

Redback: nikhil-sarin/redback
NMMA: nuclear-multimessenger-astronomy/nmma

Parameters (in x dict):

  • log10_mej – log10 ejecta mass in solar masses

  • log10_vej – log10 ejecta velocity in units of c

  • beta – velocity power-law index

  • log10_kappa_r – log10 opacity in cm^2/g

The ODE is solved per-shell in normalized units to avoid float32 overflow. Uses 300 mass shells with velocity profile, neutron fractions, and shell-dependent opacities matching the NMMA implementation.

compute_log10_lbol_rphot(x, t_days)[source]#

Return (log10_L_bol, log10_R_phot) arrays at each time in t_days.

L_bol in erg/s, R_phot in cm.

parameter_names: list[str] = ['log10_mej', 'log10_vej', 'beta', 'log10_kappa_r']#
class fiesta.inference.analytical_models.kilonova_models.OneComponentKilonovaModel(filters, times=None, temperature_floor=4000.0)[source]#

Bases: AnalyticalModel

Single-component kilonova with diffusion-integral heating.

Reference:

Redback: _one_component_kilonova_model in kilonova_models.py

Matches Redback’s cumulative trapezoid algorithm exactly, using a float32-safe damped recurrence that avoids exp(t^2/td^2) overflow.

Parameters (in x dict):

  • log10_mej – log10 ejecta mass in solar masses

  • log10_vej – log10 ejecta velocity in units of c

  • log10_kappa – log10 gray opacity in cm^2/g

compute_log10_lbol_rphot(x, t_days)[source]#

Return (log10_L_bol, log10_R_phot) arrays at each time in t_days.

L_bol in erg/s, R_phot in cm.

parameter_names: list[str] = ['log10_mej', 'log10_vej', 'log10_kappa']#

Shock-powered analytical light-curve models.

Reference:

Redback: nikhil-sarin/redback
NMMA: nuclear-multimessenger-astronomy/nmma

class fiesta.inference.analytical_models.shock_powered_models.ShockCoolingModel(filters, times=None)[source]#

Bases: AnalyticalModel

Shock-cooling emission following Piro (2021).

Reference:

Redback: nikhil-sarin/redback
NMMA: nuclear-multimessenger-astronomy/nmma

Parameters (all in x dict):

  • log10_Menv – log10 envelope mass in solar masses

  • log10_Renv – log10 envelope radius in solar radii

  • log10_Ee – log10 explosion energy in erg

compute_log10_lbol_rphot(x, t_days)[source]#

Full Piro (2021) shock cooling with n=10, delta=1.1.

Matches NMMA sc_bol_lc exactly. All overflow-prone quantities are computed in log10 space to stay within float32 range.

parameter_names: list[str] = ['log10_Menv', 'log10_Renv', 'log10_Ee']#
class fiesta.inference.analytical_models.shock_powered_models.ShockedCocoonModel(filters, times=None)[source]#

Bases: AnalyticalModel

Analytical jet cocoon cooling model.

Reference:

Redback: nikhil-sarin/redback

Fully algebraic (no ODE) — power-law luminosity decay with diffusion timescale. Based on the shocked cocoon model from Redback.

Parameters (in x dict):

  • log10_mej – log10 ejecta mass in solar masses

  • log10_vej – log10 ejecta velocity in units of c

  • eta – slope for ejecta density profile

  • log10_tshock – log10 shock time in seconds

  • shocked_fraction – fraction of ejecta mass shocked

  • cos_theta_cocoon – cosine of cocoon opening angle

  • log10_kappa – log10 gray opacity in cm^2/g

compute_log10_lbol_rphot(x, t_days)[source]#

Return (log10_L_bol, log10_R_phot) arrays at each time in t_days.

L_bol in erg/s, R_phot in cm.

parameter_names: list[str] = ['log10_mej', 'log10_vej', 'eta', 'log10_tshock', 'shocked_fraction', 'cos_theta_cocoon', 'log10_kappa']#

Tidal disruption event (TDE) analytical light-curve models.

Reference:

Redback: nikhil-sarin/redback

class fiesta.inference.analytical_models.tde_models.TDEAnalyticalModel(filters, times=None, temperature_floor=None)[source]#

Bases: AnalyticalModel

TDE analytical model with t^{-5/3} fallback + Arnett diffusion.

Reference:

Redback: nikhil-sarin/redback

Parameters (in x dict):

  • log10_l0 – log10 of luminosity at 1 second (erg/s)

  • t_0_turn – turn-on time in days

  • log10_mej – log10 of ejecta mass in solar masses

  • log10_vej – log10 of ejecta velocity in km/s

  • log10_kappa – log10 of opacity (cm^2/g)

  • log10_kappa_gamma – log10 of gamma-ray opacity (cm^2/g)

compute_log10_lbol_rphot(x, t_days)[source]#

Return (log10_L_bol, log10_R_phot) arrays at each time in t_days.

L_bol in erg/s, R_phot in cm.

parameter_names: list[str] = ['log10_l0', 't_0_turn', 'log10_mej', 'log10_vej', 'log10_kappa', 'log10_kappa_gamma']#

SALT3 spectral-template supernova model via jax-bandflux.

Uses jax_supernovae (PyPI: jax-bandflux) for JAX-native, JIT-compiled, differentiable SALT3 light-curve evaluation. Unlike the physics-based models that compute L_bol + R_phot -> blackbody SED, SALT3 uses spectral templates (M0, M1, colour law) to compute per-band fluxes directly.

The jax_supernovae import is kept lazy to avoid loading heavy dependencies for users who don’t use SALT3.

class fiesta.inference.analytical_models.salt3_models.SALT3Model(filters, times=None, redshift=0.0)[source]#

Bases: object

SALT3 spectral-template model for Type Ia supernova light curves.

Parameters:
  • filters (list[str]) – Band names recognised by jax_supernovae.bandpasses (e.g. "ztfg", "ztfr", "bessellb").

  • times (Array, optional) – Observer-frame times (days) at which to evaluate the model.

  • redshift (float) – Source redshift (fixed, not sampled).

Sampled parameters (passed via predict(x)):

  • log10_x0 – log10 of the SALT3 amplitude parameter x0

  • x1 – SALT3 stretch

  • c – SALT3 colour

  • t0 – time of B-band maximum (days, same frame as times)

filters: list[str]#
parameter_names: list[str]#
predict(x)[source]#
Return type:

tuple[Array, dict[str, Array]]

times: Array#

Lightcurve Models#

Classes that load trained surrogate models and provide routines to generate light curves from them.

class fiesta.inference.lightcurve_model.AfterglowFlux(*args, **kwargs)[source]#

Bases: FluxModel

class fiesta.inference.lightcurve_model.BullaFlux(*args, **kwargs)[source]#

Bases: FluxModel

class fiesta.inference.lightcurve_model.BullaLightcurveModel(*args, **kwargs)[source]#

Bases: LightcurveModel

class fiesta.inference.lightcurve_model.CombinedSurrogate(models, sample_times)[source]#

Bases: SurrogateModel

add_filter(filters)[source]#
predict(x)[source]#
class fiesta.inference.lightcurve_model.FluxModel(name, filters, directory=None)[source]#

Bases: SurrogateModel

Class of surrogate models that predicts the 2D spectral flux density array.

compute_output(x)[source]#

Apply the trained flax neural network on the given input x.

Parameters:

x (dict[str, Array]) – Input array of parameters per filter

Returns:

The raw neural-network output per filter.

Return type:

dict[str, Array]

convert_to_mag(y, x)[source]#
Return type:

tuple[Array, dict[str, Array]]

load_filters(filters=None)[source]#
Return type:

None

load_networks()[source]#
Return type:

None

predict_log_flux(x)[source]#

Predict the total log10 flux array for the parameters x.

Parameters:

x (dict[str, Array]) – Input parameters, unnormalized and untransformed.

Returns:

times [Array]: time array in observer frame
nus [Array]: frequency array in observer frame
log10_flux [Array]: Array of log10-fluxes in mJy.

Return type:

tuple

project_input(x)[source]#

Project the given input to whatever preprocessed input space we are in.

Parameters:

x (Array) – Original input array

Returns:

Transformed input array

Return type:

Array

project_output(y)[source]#

Project the computed output to whatever preprocessed output space we are in.

Parameters:

y (dict[str, Array]) – Output array

Returns:

Output array transformed to the preprocessed space.

Return type:

dict[str, Array]

class fiesta.inference.lightcurve_model.LightcurveModel(name, filters, directory=None)[source]#

Bases: SurrogateModel

Class of surrogate models that predicts the magnitudes per filter.

X_scaler: object#
compute_output(x)[source]#

Apply the trained flax neural network on the given input x.

Parameters:

x (dict[str, Array]) – Input array of parameters per filter

Returns:

The raw neural-network output per filter.

Return type:

dict[str, Array]

convert_to_mag(y, x)[source]#
Return type:

tuple[Array, dict[str, Array]]

directory: str#
load_filters(filters_args=None)[source]#
Return type:

None

load_networks()[source]#
Return type:

None

metadata: dict#
models: dict[str, TrainState]#
project_input(x)[source]#

Project the given input to whatever preprocessed input space we are in.

Parameters:

x (dict[str, Array]) – Original input array

Returns:

Transformed input array

Return type:

dict[str, Array]

project_output(y)[source]#

Project the computed output to whatever preprocessed output space we are in.

Parameters:

y (dict[str, Array]) – Output array

Returns:

Output array transformed to the preprocessed space.

Return type:

dict[str, Array]

y_scaler: dict[str, object]#
class fiesta.inference.lightcurve_model.SurrogateModel(name, directory=None)[source]#

Bases: object

Abstract class for general surrogate models

add_filter(filters)[source]#
add_name(x)[source]#
compute_output(x)[source]#

Compute the output (untransformed) from the given, transformed input. This is the main method that needs to be implemented by subclasses.

Parameters:

x (Array) – Input array

Returns:

Output array

Return type:

Array

convert_to_mag(y, x)[source]#
Return type:

tuple[Array, dict[str, Array]]

directory: str#
filters: list[str]#
load_metadata()[source]#
Return type:

None

name: str#
parameter_names: list[str]#
predict(x)[source]#

Generate the apparent magnitudes from the unnormalized and untransformed input x. This chains the input/output projections with the actual computation of the output: if the model is a trained surrogate neural network, the network represents the map from x tilde to y tilde, while the mappings from x to x tilde and from y tilde to y take care of projections (e.g. SVD projections) and normalizations.

Parameters:

x (dict[str, Array]) – Input array, unnormalized and untransformed.

Returns:

times (Array): time array in observer frame
mag (dict[str, Array]): The desired magnitudes per filter

Return type:

tuple
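The projection chain can be sketched with a stub model; the composition order is an assumption inferred from the method docstrings on this page, not a copy of fiesta's source:

```python
import numpy as np

def predict(model, x):
    """Sketch of the projection chain: x -> x_tilde -> y_tilde -> y -> (times, mags)."""
    x_tilde = model.project_input(x)         # normalize / project the input
    y_tilde = model.compute_output(x_tilde)  # the trained network acts here
    y = model.project_output(y_tilde)        # back to the physical output space
    return model.convert_to_mag(y, x)        # (times, {filter: apparent mags})

class StubModel:
    """Identity projections (the base-class default) and a trivial 'network'."""
    def project_input(self, x):
        return x
    def compute_output(self, x):
        return {"ztfg": x["amp"] * np.ones(3)}
    def project_output(self, y):
        return y
    def convert_to_mag(self, y, x):
        return np.arange(3.0), y

times, mags = predict(StubModel(), {"amp": 20.0})
```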

predict_abs_mag(x)[source]#
Return type:

tuple[Array, dict[str, Array]]

project_input(x)[source]#

Project the given input to whatever preprocessed input space we are in. By default (i.e., in this base class), the projection is the identity function.

Parameters:

x (Array) – Input array

Returns:

Input array transformed to the preprocessed space.

Return type:

Array

project_output(y)[source]#

Project the computed output to whatever preprocessed output space we are in. By default (i.e., in this base class), the projection is the identity function.

Parameters:

y (Array) – Output array

Returns:

Output array transformed to the preprocessed space.

Return type:

Array

times: Array#
vpredict(X)[source]#

Vectorized prediction function to calculate the apparent magnitudes for several inputs x at the same time.

Return type:

tuple[Array, dict[str, Array]]

fiesta.inference.lightcurve_model.get_default_directory(name)[source]#

Likelihood#

Functions for computing likelihoods of data given a model.

class fiesta.inference.likelihood.EMLikelihood(model, data, trigger_time, data_tmin=0.0, data_tmax=999.0, filters=None, error_budget=0.3, conversion_function=<function EMLikelihood.<lambda>>, fixed_params={}, detection_limit=None)[source]#

Bases: LikelihoodBase

Likelihood object to compute likelihoods for the model parameters and a set of magnitude data points.

Parameters:
  • model (LightcurveModel | AnalyticalModel) – Light curve model that generates the estimated light curve from the parameters passed to evaluate.

  • data (dict[str, Float[Array, "ntimes 3"]]) – Dictionary with photometric filters as keys and arrays as values. The first column of the array are the detection times in MJD. The second column the magnitude data points. The third column are the Gaussian measurement errors. If an error is np.inf, the data point will be treated as an upper limit on the light curve.

  • trigger_time (Float) – Trigger time or start point of the light curve in MJD.

  • data_tmin (Float) – Time point (in observer frame, relative to trigger_time) before any data point from data will be cropped. Defaults to 0.0.

  • data_tmax (Float, default: 999.0) – Time point (in observer frame, relative to trigger_time) after which any data point from data will be cropped. Defaults to 999.0

  • filters (list[str]) – Filters that should be used for the likelihood evaluation. If None, will take filters from data. Defaults to None.

  • error_budget (Float) – Fixed error budget for the systematic uncertainty. Defaults to 0.3.

  • conversion_function (Callable) – Function that will be called on the params before model predicts the light curve. Defaults to the identity.

  • fixed_params (dict[str, Float]) – Fixed parameters. These are added to the params before model predicts the light curve. Defaults to {}.

  • detection_limit (Float) – Detection limit of the telescope. If set, a truncated gaussian likelihood will be used. Defaults to None.
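A standalone numpy sketch of the expected data layout and the detection / upper-limit split described above; the filter name and values are illustrative:

```python
import numpy as np

# Each filter maps to an (ntimes, 3) array: [time_MJD, magnitude, error].
# An infinite error marks the row as an upper limit (non-detection).
data = {
    "ztfg": np.array([
        [60000.0, 21.5, 0.1],     # detection
        [60001.0, 21.2, 0.08],    # detection
        [60005.0, 22.5, np.inf],  # upper limit
    ]),
}

trigger_time = 59999.5  # MJD
arr = data["ztfg"]
det = arr[np.isfinite(arr[:, 2])]      # rows treated as detections
nondet = arr[~np.isfinite(arr[:, 2])]  # rows treated as upper limits
times_det = det[:, 0] - trigger_time   # times relative to the trigger
```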

times_det#

The time points of the detected magnitudes per filter relative to the trigger time.

Type:

dict[str, Array]

times_nondet#

The time points of the non-detected magnitudes (upper limits) per filter relative to the trigger time.

Type:

dict[str, Array]

datapoints_det#

The detected magnitudes per filter.

Type:

dict[str, Array]

datapoints_nondet#

The non-detection magnitudes (upper limits) per filter.

Type:

dict[str, Array]

datapoints_err#

The gaussian measurement error of the detected magnitudes per filter.

Type:

dict[str, Array]

evaluate(theta)[source]#

Evaluate the log-likelihood of the data given the model and the parameters theta, at a single point.

Parameters:

theta (dict[str, Array]) – A dictionary containing the parameters used to generate the model light curve that is then used to compute the loglikelihood.

Returns:

The log-likelihood value at this parameter point.

Return type:

Float

class fiesta.inference.likelihood.FluxLikelihood(model, data, trigger_time, data_tmin=0.0, data_tmax=999.0, filters=None, error_budget=1, conversion_function=<function FluxLikelihood.<lambda>>, fixed_params={}, detection_limit=None, zero_point_mag=16.4)[source]#

Bases: LikelihoodBase

Likelihood object to compute likelihoods for the model parameters and a set of flux data points. Note that the data passed in the data argument must still be magnitudes; they are converted internally to fluxes.

Parameters:
  • model (LightcurveModel | AnalyticalModel) – Light curve model that generates the estimated light curve from the parameters passed to evaluate.

  • data (dict[str, Float[Array, "ntimes 3"]]) – Dictionary with photometric filters as keys and arrays as values. The first column of the array are the detection times in MJD. The second column the magnitude data points. The third column are the Gaussian measurement errors. If an error is np.inf, the data point will be treated as an upper limit on the light curve.

  • trigger_time (Float) – Trigger time or start point of the light curve in MJD.

  • data_tmin (Float) – Time point (in observer frame, relative to trigger_time) before any data point from data will be cropped. Defaults to 0.0.

  • data_tmax (Float, default: 999.0) – Time point (in observer frame, relative to trigger_time) after which any data point from data will be cropped. Defaults to 999.0

  • filters (list[str]) – Filters that should be used for the likelihood evaluation. If None, will take filters from data. Defaults to None.

  • error_budget (Float) – Fixed error budget for the systematic uncertainty. Defaults to 1 mJy.

  • conversion_function (Callable) – Function that will be called on the params before model predicts the light curve. Defaults to the identity.

  • fixed_params (dict[str, Float]) – Fixed parameters. These are added to the params before model predicts the light curve. Defaults to {}.

  • detection_limit (Float) – Detection limit of the telescope. If set, a truncated gaussian likelihood will be used. Defaults to None.

  • zero_point_mag (Float, default: 16.4) – Zero-point for mag-to-flux conversion, specifically to mJy (defaults to 16.4 for AB mag).

times_det#

The time points of the detected fluxes per filter relative to the trigger time.

Type:

dict[str, Array]

times_nondet#

The time points of the non-detected fluxes (upper limits) per filter relative to the trigger time.

Type:

dict[str, Array]

datapoints_det#

The detected fluxes per filter.

Type:

dict[str, Array]

datapoints_nondet#

The non-detection fluxes (upper limits) per filter.

Type:

dict[str, Array]

datapoints_err#

The gaussian measurement error of the detected fluxes per filter.

Type:

dict[str, Array]

evaluate(theta)[source]#

Evaluate the log-likelihood of the data given the model and the parameters theta, at a single point.

Parameters:

theta (dict[str, Array]) – A dictionary containing the parameters used to generate the model light curve that is then used to compute the loglikelihood.

Returns:

The log-likelihood value at this parameter point.

Return type:

Float

mag_to_flux(mag_arr)[source]#

Converts mag_arr to fluxes in mJy.
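The conversion implied by the default zero point can be sketched as follows; the function name is illustrative, and the body is an assumption consistent with zero_point_mag = 16.4, under which a 16.4 AB magnitude corresponds to 1 mJy:

```python
import numpy as np

ZERO_POINT_MAG = 16.4  # AB zero point such that 16.4 mag <-> 1 mJy

def mag_to_flux_mjy(mag_arr, zp=ZERO_POINT_MAG):
    # AB system: m = -2.5 * log10(f_nu / 3631 Jy); with 3631 Jy = 3.631e6 mJy,
    # 2.5 * log10(3.631e6) ~= 16.4, giving the compact form below.
    return 10.0 ** ((zp - np.asarray(mag_arr)) / 2.5)
```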

class fiesta.inference.likelihood.LikelihoodBase(model, data, trigger_time, data_tmin=0.0, data_tmax=999.0, filters=None, error_budget=0.3, conversion_function=<function LikelihoodBase.<lambda>>, fixed_params={}, detection_limit=None)[source]#

Bases: object

Base class for likelihoods.

static compute_gaussian_likelihood(y_est, y_data, sigma, lim)[source]#

Return the log likelihood of the chisquare part of the likelihood function, without truncation (no detection limit is given), i.e. a Gaussian pdf.

Return type:

Float

static compute_gaussian_survival(y_est, y_data, error_budget)[source]#
Return type:

Float

static compute_trunc_gaussian_likelihood(y_est, y_data, sigma, lim)[source]#

Return the log likelihood of the chisquare part of the likelihood function, with truncation of the Gaussian (detection limit is given).

Return type:

Float

cut_data_to_time_range(data, data_tmin, data_tmax)[source]#
Return type:

dict[str, Float[Array, "ntimes 3"]]

data_tmax: Float#
data_tmin: Float#
datapoints_det: dict[str, Array]#
datapoints_err: dict[str, Array]#
datapoints_nondet: dict[str, Array]#
detection_limit: dict[str, Array]#
error_budget: dict[str, Array]#
evaluate(theta)[source]#

Evaluate the log-likelihood of the data given the parameters in theta and the underlying model.

Return type:

Float

filters: list[str]#
get_gaussprob_det(y_est, y_data, sigma, lim)[source]#

Return the log likelihood of the Gaussian likelihood function for a single filter. A jax.lax.cond branches on the provided detection limit (lim): if the limit is infinite, the likelihood is computed without truncation (and without resorting to scipy, for faster evaluation); if the limit is finite, the truncated likelihood is computed with scipy.

Parameters:
  • y_est (Array) – The estimated data from the model at detection times.

  • y_data (Array) – The detected data.

  • sigma (Array) – The uncertainties on the detected apparent magnitudes, including the error budget.

  • lim (Float) – The detection limit for this filter.

Returns:

The gaussian log-likelihood for this filter.

Return type:

Float

get_gaussprob_nondet(y_est, y_data, error_budget)[source]#

Return the log likelihood of the non-detections (upper limits) for a single filter, computed with the Gaussian survival function, i.e. the probability that the model light curve is fainter than the reported upper limit.

Parameters:
  • y_est (Array) – The estimated data from the model at the non-detection times.

  • y_data (Array) – The non-detection data points (upper limits).

  • error_budget (Array) – The systematic error budget used as the Gaussian width.

Returns:

The gaussian log-likelihood for this filter.

Return type:

Float

model: LightcurveModel | AnalyticalModel#
setup_filters_and_data(filters, data)[source]#
Return type:

dict[str, Float[Array, 'ntimes 3']]

times_det: dict[str, Array]#
times_nondet: dict[str, Array]#
trigger_time: Float#
vectorized_evaluate(theta)[source]#

Priors#

class fiesta.inference.prior.CompositePrior(priors, transforms={}, **kwargs)[source]#

Bases: Prior

log_prob(x)[source]#
Return type:

Float

priors: list[Prior]#
sample(rng_key, n_samples)[source]#
Return type:

dict[str, Float[Array, 'n_samples']]
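
A composite prior's log probability is the sum of its components' log probabilities. A minimal pure-Python sketch (simplified stand-ins for the fiesta classes, without transforms):

```python
import math

class ToyUniform:
    # simplified stand-in for fiesta.inference.prior.Uniform
    def __init__(self, xmin, xmax, name):
        self.xmin, self.xmax, self.naming = xmin, xmax, [name]

    def log_prob(self, x):
        v = x[self.naming[0]]
        if self.xmin <= v <= self.xmax:
            return -math.log(self.xmax - self.xmin)
        return -math.inf

class ToyCompositePrior:
    # the joint log-prob factorizes into a sum over the component priors
    def __init__(self, priors):
        self.priors = priors
        self.naming = [n for p in priors for n in p.naming]

    def log_prob(self, x):
        return sum(p.log_prob(x) for p in self.priors)
```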

class fiesta.inference.prior.ConstrainedPrior(priors, conversion_function=<function ConstrainedPrior.<lambda>>, transforms={})[source]#

Bases: CompositePrior

constraints: list[Constraint]#
conversion: Callable#
evaluate_constraints(samples)[source]#
factor: Float#
log_prob(x)[source]#
Return type:

Float

sample(rng_key, n_samples)[source]#
Return type:

dict[str, Float[Array, 'n_samples']]

class fiesta.inference.prior.Constraint(naming, xmin, xmax, transforms={})[source]#

Bases: Prior

log_prob(x)[source]#
Return type:

Float

xmax: float#
xmin: float#
class fiesta.inference.prior.InterpedPrior(xx, yy, naming)[source]#

Bases: Prior

log_prob(x)[source]#
Return type:

Float

sample(rng_key, n_samples)[source]#
Return type:

dict[str, Float[Array, 'n_samples']]

xx: Array#
yy: Array#
class fiesta.inference.prior.LogUniform(xmin, xmax, naming, transforms={}, **kwargs)[source]#

Bases: Prior

log_prob(x)[source]#
Return type:

Float

sample(rng_key, n_samples)[source]#

Sample from a log-uniform distribution (uniform in the logarithm of the parameter).

Parameters:
  • rng_key (PRNGKeyArray) – A random key to use for sampling.

  • n_samples (int) – The number of samples to draw.

Returns:

samples – Samples from the distribution. The keys are the names of the parameters.

Return type:

dict

xmax: float = 1.0#
xmin: float = 0.0#
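
Log-uniform samples can be drawn uniformly in log space and exponentiated; the density is then proportional to 1/x. A sketch (not the fiesta implementation):

```python
import math
import random

def sample_loguniform(rng, xmin, xmax, n):
    # draw uniformly in log space, then exponentiate
    return [math.exp(rng.uniform(math.log(xmin), math.log(xmax))) for _ in range(n)]

def loguniform_log_prob(x, xmin, xmax):
    # p(x) = 1 / (x * log(xmax / xmin)) on [xmin, xmax]
    if not (xmin <= x <= xmax):
        return -math.inf
    return -math.log(x) - math.log(math.log(xmax / xmin))
```
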
class fiesta.inference.prior.Normal(mu, sigma, naming, transforms={}, **kwargs)[source]#

Bases: Prior

log_prob(x)[source]#
Return type:

Float

mu: float = 0.0#
sample(rng_key, n_samples)[source]#

Sample from a normal distribution.

Parameters:
  • rng_key (PRNGKeyArray) – A random key to use for sampling.

  • n_samples (int) – The number of samples to draw.

Returns:

samples – Samples from the distribution. The keys are the names of the parameters.

Return type:

dict

sigma: float = 1.0#
class fiesta.inference.prior.Prior(naming, transforms={})[source]#

Bases: object

A thin base class for bookkeeping.

Should not be used directly, since it does not implement any of the actual methods.

The rationale is to have a class that keeps track of the parameter names and the transforms that are applied to them.

add_name(x)[source]#

Turn an array into a dictionary

Parameters:

x (Array) – An array of parameters. Shape (n_dim,).

Return type:

dict[str, Float]
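
In essence, add_name pairs the parameter names with the entries of a flat sample array; a minimal sketch (the parameter names in the example are hypothetical):

```python
def add_name(naming, x):
    # pair each parameter name with the corresponding entry of a flat array
    return dict(zip(naming, x))
```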

log_prob(x)[source]#
Return type:

Float

property n_dim#
naming: list[str]#
sample(rng_key, n_samples)[source]#
Return type:

dict[str, Float[Array, 'n_samples']]

transform(x)[source]#

Apply the transforms to the parameters.

Parameters:

x (dict) – A dictionary of parameters. Names should match the ones in the prior.

Returns:

x – A dictionary of parameters with the transforms applied.

Return type:

dict

transforms: dict[str, tuple[str, Callable]]#
class fiesta.inference.prior.Sine(naming, xmin=0, xmax=3.141592653589793, transforms={})[source]#

Bases: Prior

log_prob(x)[source]#
Return type:

Float

sample(rng_key, n_samples)[source]#

Sample from a sine distribution on [xmin, xmax].

Parameters:
  • rng_key (PRNGKeyArray) – A random key to use for sampling.

  • n_samples (int) – The number of samples to draw.

Returns:

samples – Samples from the distribution. The keys are the names of the parameters.

Return type:

dict

xmax: float = 1.0#
xmin: float = 0.0#
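
On the default range [0, pi], the sine prior has density p(theta) = sin(theta)/2 and CDF F(theta) = (1 - cos(theta))/2, so inverse-CDF sampling reduces to an arccos. A sketch (not necessarily how fiesta implements it):

```python
import math
import random

def sample_sine(rng, n):
    # inverse-CDF sampling for p(theta) = sin(theta) / 2 on [0, pi]:
    # F(theta) = (1 - cos(theta)) / 2, so theta = arccos(1 - 2u)
    return [math.acos(1.0 - 2.0 * rng.random()) for _ in range(n)]
```
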
class fiesta.inference.prior.TruncatedNormal(mu, sigma, xmin, xmax, naming, transforms={}, **kwargs)[source]#

Bases: Prior

Truncated normal distribution with explicit bounds.

Useful for informed priors from population studies (e.g., superphot+). The SVISampler uses xmin/xmax for its guide constraints and mu/sigma for the model’s TruncatedNormal distribution.

log_prob(x)[source]#
Return type:

Float

mu: float = 0.0#
sample(rng_key, n_samples)[source]#
Return type:

dict[str, Float[Array, 'n_samples']]

sigma: float = 1.0#
xmax: float = 10.0#
xmin: float = -10.0#
class fiesta.inference.prior.Uniform(xmin, xmax, naming, transforms={}, **kwargs)[source]#

Bases: Prior

log_prob(x)[source]#
Return type:

Float

sample(rng_key, n_samples)[source]#

Sample from a uniform distribution.

Parameters:
  • rng_key (PRNGKeyArray) – A random key to use for sampling.

  • n_samples (int) – The number of samples to draw.

Returns:

samples – Samples from the distribution. The keys are the names of the parameters.

Return type:

dict

xmax: float = 1.0#
xmin: float = 0.0#
class fiesta.inference.prior.UniformSourceFrame(dmin, dmax, naming, cosmology=FlatLambdaCDM(name='Planck18', H0=<Quantity 67.66 km / (Mpc s)>, Om0=0.30966, Tcmb0=<Quantity 2.7255 K>, Neff=3.046, m_nu=<Quantity [0., 0., 0.06] eV>, Ob0=0.04897), **kwargs)[source]#

Bases: InterpedPrior

Prior that is uniform in comoving volume and source-frame time, analogous to the corresponding bilby prior. Uses the default cosmology in fiesta, which is Planck18.

xmax: float = 100000.0#
xmin: float = 10.0#
class fiesta.inference.prior.UniformVolume(xmin, xmax, naming, transforms={}, **kwargs)[source]#

Bases: Prior

log_prob(x)[source]#
Return type:

Float

sample(rng_key, n_samples)[source]#

Sample luminosity distance from a distribution uniform in volume.

Parameters:
  • rng_key (PRNGKeyArray) – A random key to use for sampling.

  • n_samples (int) – The number of samples to draw.

Returns:

samples – Samples from the distribution. The keys are the names of the parameters.

Return type:

dict

xmax: float = 100000.0#
xmin: float = 10.0#
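
Uniform in volume means the distance density grows as d^2, which again admits closed-form inverse-CDF sampling. A sketch (not the fiesta implementation):

```python
import random

def sample_uniform_volume(rng, dmin, dmax, n):
    # p(d) ∝ d^2  →  F(d) = (d^3 - dmin^3) / (dmax^3 - dmin^3)
    # inverting: d = (dmin^3 + u * (dmax^3 - dmin^3))^(1/3)
    return [
        (dmin**3 + rng.random() * (dmax**3 - dmin**3)) ** (1.0 / 3.0)
        for _ in range(n)
    ]
```

Because the density is weighted toward large distances, samples pile up near dmax rather than being spread evenly.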

Systematics#

fiesta.inference.systematic.check_filter_compatability(yaml_dict, filters)[source]#
fiesta.inference.systematic.fetch_prior_params(yaml_entry)[source]#
fiesta.inference.systematic.process_file(systematic_file, filters)[source]#
fiesta.inference.systematic.setup_systematic_from_file(likelihood, prior, systematics_file)[source]#
fiesta.inference.systematic.setup_systematics_basic(likelihood, prior, error_budget=0.3)[source]#

Sampler#

Utilities#

Functions for creating and handling injections

class fiesta.inference.injection.InjectionAfterglowpy(jet_type=-1, *args, **kwargs)[source]#

Bases: InjectionBase

class fiesta.inference.injection.InjectionBase(filters, trigger_time, tmin=0.1, tmax=10.0, N_datapoints=10, t_detect=None, error_budget=1.0, nondetections=False, nondetections_fraction=0.1, detection_limit=inf)[source]#

Bases: object

Base class to create synthetic injection lightcurves. The injection model is first initialized with the following parameters:

  • filters (list) – List of filters in which the synthetic data should be generated.

  • trigger_time (float) – Reference trigger time (e.g. MJD or GPS seconds) added as an offset to all detection time stamps. Required.

  • tmin (float) – Earliest possible synthetic detection time in days. Defaults to 0.1.

  • tmax (float) – Latest possible synthetic detection time in days. Defaults to 10.0.

  • N_datapoints (int) – Total number of data points (across all filters) for the synthetic light curve. Defaults to 10.

  • t_detect (dict[str, Array]) – Detection time points in each filter. If not specified, the detection times are sampled randomly.

  • error_budget (float) – Typical measurement error scale of the synthetic data. Defaults to 1.

  • detection_limit (float) – Synthetic data points with magnitude higher than this value (i.e. fainter) are turned into non-detections. Defaults to np.inf.

  • nondetections (bool) – In addition to detection_limit, this turns some of the synthetic data points into non-detections. Defaults to False.

  • nondetections_fraction (float) – If nondetections is True, the fraction of N_datapoints turned into non-detections. Defaults to 0.1.

Then one can call the .create_injection() method to get synthetic lightcurve data. The method .write_to_file() writes the synthetic lightcurve data to file.
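
The detection_limit behaviour can be sketched as follows (a simplified stand-in with hypothetical names, not the fiesta implementation): points whose scattered magnitude falls beyond the limit become non-detections.

```python
import random

def make_synthetic_points(rng, true_mags, error_budget, detection_limit):
    """Add Gaussian scatter and split points into detections / non-detections."""
    detections, nondetections = [], []
    for mag in true_mags:
        scattered = mag + rng.gauss(0.0, error_budget)
        if scattered > detection_limit:  # numerically larger magnitude = fainter
            nondetections.append(detection_limit)
        else:
            detections.append(scattered)
    return detections, nondetections
```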

create_injection(injection_dict, file=None)[source]#

Creates an injection that is stored as a .data attribute.

Parameters:
  • injection_dict (dict) – Parameters for the synthetic light curve.

  • file (str, optional) – Training data file that stores light curves from the physical base model of the surrogate. If provided, the method will take a random test element and base the injection on it. In this case, the .injection_parameter attribute is updated to contain the real parameters used to generate the light curve.

create_injection_from_mags(times, mag_app)[source]#
create_t_detect(tmin, tmax, N)[source]#

Create a time grid for the injection data.

randomize_nondetections()[source]#
write_to_file(file)[source]#
class fiesta.inference.injection.InjectionKN(*args, **kwargs)[source]#

Bases: InjectionBase

class fiesta.inference.injection.InjectionPyblastafterglow(jet_type='tophat', *args, **kwargs)[source]#

Bases: InjectionBase

class fiesta.inference.injection.InjectionSurrogate(model, *args, **kwargs)[source]#

Bases: InjectionBase

Class to create synthetic injection lightcurves from a surrogate. After instantiation one can call the .create_injection() method to get synthetic lightcurve data. The method .write_to_file() writes the synthetic lightcurve data to file.

fiesta.inference.injection.get_parser(**kwargs)[source]#
class fiesta.inference.plot.LightcurvePlotter(posterior, likelihood, systematics_file=None, free_syserr=False)[source]#

Bases: object

Interface to plot lightcurves from a given posterior.

Parameters:
  • posterior (dict | pd.DataFrame) – Posterior samples for which the light curves should be plotted.

  • likelihood (EMLikelihood) – Likelihood object that was used to sample the posterior.

  • systematics_file (str) – Systematics file that was used to sample the posterior. Defaults to None.

  • free_syserr (bool) – Whether a global systematic uncertainty was sampled freely. Defaults to False. Will overwrite systematics_file.

get_chisquared(per_dof=False)[source]#

Get the total chi-squared value and the chi-squared values per filter. This differs from the log_likelihood value in the posterior, because the likelihood function contains the normalization factor (2 pi sigma^2)^(-1/2).

Parameters:

per_dof (bool) – Whether to return reduced chi-squared values, i.e., per number of data points.

Returns:

The total chi-squared value across all data points and a dict with the chi-squared value in each filter.

Return type:

tuple[float, dict]
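
The chi-squared described above amounts to summing the sigma-weighted squared residuals (a sketch with hypothetical names, without the per-filter split):

```python
def chisquared(y_model, y_data, sigma, per_dof=False):
    # sum of squared, sigma-weighted residuals; optionally reduced per data point
    chi2 = sum(((m - d) / s) ** 2 for m, d, s in zip(y_model, y_data, sigma))
    return chi2 / len(y_data) if per_dof else chi2
```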

plot_best_fit_lc(ax, filt, zorder=2, **kwargs)[source]#

Plots one filter from the best fit light curve from the posterior over ax.

Parameters:
  • ax (matplotlib.axes.Axes) – ax to plot the light curve onto.

  • filt (str) – Which filter from the best fit lightcurve should be plotted on ax.

  • zorder (int) – zorder with which the lightcurve should be plotted.

  • **kwargs – kwargs to be passed to plot.

plot_data(ax, filt, zorder=3, **kwargs)[source]#

Plots data points from a filter over ax.

Parameters:
  • ax (matplotlib.axes.Axes) – ax to plot the data points to.

  • filt (str) – Which filter from the data should be plotted on ax.

  • zorder (int) – zorder with which the data points should be plotted.

  • **kwargs – kwargs to be passed to errorbar and scatter.

plot_sample_lc(ax, filt, zorder=1)[source]#

Plots background light curves from the posterior over ax.

Parameters:
  • ax (matplotlib.axes.Axes) – ax to plot the light curve onto.

  • filt (str) – Which filter from the background light curves should be plotted on ax.

  • zorder (int) – zorder with which the lightcurve should be plotted.

plot_sys_uncertainty_band(ax, filt, zorder=2, **kwargs)[source]#

Plots systematic uncertainty band from the best fit light curve for one filter over ax.

Parameters:
  • ax (matplotlib.axes.Axes) – ax to plot the band onto.

  • filt (str) – Which filter from the band should be plotted on ax.

  • zorder (int) – zorder with which the band should be plotted.

  • **kwargs – kwargs to be passed to fill_between.

fiesta.inference.plot.corner_plot(posterior, parameter_names, truths=None, color='blue', legend_label=None, fig=None, ax=None, **kwargs)[source]#

Make a nice corner plot from the posterior with automated parameter labels.

Parameters:
  • posterior (dict | pd.DataFrame) – posterior samples for which to do the corner plot.

  • parameter_names (list[str]) – parameters from posterior that should be included in the corner plot.

  • truths (dict[str, float] | None) – True (injected) values for some of the parameters. Defaults to None.

  • color (str) – color for the corner plot contours. Defaults to blue.

  • legend_label (str) – Label for the legend. If not set, no legend will be shown. Defaults to None.

  • fig (matplotlib.figure.Figure) – Figure over which to do the corner plot. If set, ax must also be provided. Defaults to None.

  • ax (matplotlib.axes.Axes) – Axes over which to do the corner plot. If set, fig must also be provided. Defaults to None.

Returns:

  • fig (matplotlib.figure.Figure) – Figure with the corner plot.

  • ax (matplotlib.axes.Axes) – Array of axes.