"""Supernova analytical light-curve models.
Reference:
Redback: https://github.com/nikhil-sarin/redback/blob/master/redback/transient_models/supernova_models.py
NMMA: https://github.com/nuclear-multimessenger-astronomy/nmma/blob/main/nmma/em/lightcurve_generation.py
"""
import jax
import jax.numpy as jnp
from fiesta.constants import days_to_seconds
from fiesta.inference.analytical_models.base import (
AnalyticalModel,
_gauss_legendre_nodes_weights,
_magnetar_luminosity,
_compute_diffusion_constants,
_arnett_diffusion_ode,
_arnett_diffusion_integral,
_csm_diffusion_integral,
_LOG10E, _LOG10_MSUN, _LOG10_CCGS, _LOG10_4PI, _LOG10_PI,
_LOG10_KM_CGS, _LOG10_AU_CGS, _LOG10_DAYS2SEC,
)
class ArnettModel(AnalyticalModel):
    """Arnett (1982) Ni56/Co56-powered supernova bolometric model.

    Reference:
        Redback: https://github.com/nikhil-sarin/redback/blob/master/redback/transient_models/supernova_models.py
        NMMA: https://github.com/nuclear-multimessenger-astronomy/nmma/blob/main/nmma/em/lightcurve_generation.py

    Parameters (in ``x`` dict):
        tau_m     – diffusion timescale in days
        log10_mni – log10 of Ni56 mass in solar masses
        v_phot    – photospheric velocity in units of 1e9 cm/s
        t_0       – (modified variant only) gamma-ray trapping timescale in days

    Constructor kwargs:
        times             – output time grid in days; defaults to 100
                            log-spaced points between 0.1 and 60 days
        modified          – if True, ``t_0`` becomes a required parameter and
                            the luminosity is multiplied by the trapping
                            factor ``1 - exp(-(t_0 / t)^2)``
        temperature_floor – forwarded to the base-class constructor only when
                            it is not None, mirroring the other models in
                            this module
    """

    parameter_names = ["tau_m", "log10_mni", "v_phot"]

    # Ni56 and Co56 decay timescales (days)
    _tau_ni = 8.77
    _tau_co = 111.3

    # log10 of energy release per gram (erg/s/g)
    _log10_eps_ni = jnp.log10(3.90e10)
    _log10_eps_co = jnp.log10(6.78e9)

    def __init__(self, filters, times=None, modified=False,
                 temperature_floor=None):
        self.modified = modified
        # Use a per-instance list so the class-level default is never mutated.
        if modified:
            self.parameter_names = ["tau_m", "log10_mni", "v_phot", "t_0"]
        else:
            self.parameter_names = list(ArnettModel.parameter_names)

        if times is None:
            times = jnp.geomspace(0.1, 60.0, 100)

        self._gl_nodes, self._gl_weights = _gauss_legendre_nodes_weights()

        # Only forward temperature_floor when the caller supplied one, so that
        # existing call sites keep whatever default the base class defines.
        if temperature_floor is None:
            super().__init__(filters, times)
        else:
            super().__init__(filters, times,
                             temperature_floor=temperature_floor)

    def compute_log10_lbol_rphot(self, x, t_days):
        """Return ``(log10_L, log10_R)`` evaluated at ``t_days``.

        ``log10_L`` is the log10 bolometric luminosity (floored at 0) and
        ``log10_R`` is log10 of ``v_phot * 1e9 * t`` with ``t`` in seconds.
        """
        tau_m = x["tau_m"]                        # days
        log10_Mni = x["log10_mni"] + _LOG10_MSUN  # log10(grams)
        v_phot = x["v_phot"]                      # 1e9 cm/s

        tau_ni = self._tau_ni
        tau_co = self._tau_co
        nodes = self._gl_nodes
        weights = self._gl_weights

        eps_ni = jnp.power(10.0, self._log10_eps_ni)
        eps_co = jnp.power(10.0, self._log10_eps_co)

        def _log10_lbol_single(t_d):
            # ``u`` is the dimensionless time t / tau_m.  It is deliberately
            # not called ``x`` so the parameter dict is not shadowed.
            u = t_d / tau_m
            z = u * nodes  # GL nodes on [0,1] → z on [0, u]

            # Arnett (1982) integrals with exp(-u^2) absorbed for stability:
            #   exp(-u^2) * A(u) = int_0^u 2z exp(z^2 - u^2 - z*tau_m/tau_ni) dz
            #   exp(-u^2) * B(u) = int_0^u 2z exp(z^2 - u^2 - z*tau_m/tau_co) dz
            # Since z <= u, the exponent z^2 - u^2 <= 0, so bounded in float32.
            A_ni = 2.0 * z * jnp.exp(z**2 - u**2 - z * tau_m / tau_ni)
            I_A = u * jnp.dot(weights, A_ni)

            A_co = 2.0 * z * jnp.exp(z**2 - u**2 - z * tau_m / tau_co)
            I_B = u * jnp.dot(weights, A_co)

            # L = M_ni * [(eps_ni - eps_co)*A + eps_co*B]  (Ni -> Co -> Fe chain)
            total_heating = (eps_ni - eps_co) * I_A + eps_co * I_B
            return log10_Mni + jnp.log10(jnp.maximum(total_heating, 1e-30))

        log10_L = jax.vmap(_log10_lbol_single)(t_days)

        # Modified variant: multiply by trapping factor
        if self.modified:
            t_0 = x["t_0"]
            trap = 1.0 - jnp.exp(-(t_0 / t_days)**2)
            log10_L = log10_L + jnp.log10(jnp.maximum(trap, 1e-30))

        log10_L = jnp.maximum(log10_L, 0.0)

        # Photospheric radius: R = v_phot * 1e9 * t_sec
        #   log10(R) = log10(v_phot) + 9 + log10(t_sec)
        t_sec = t_days * days_to_seconds
        log10_R = (jnp.log10(jnp.maximum(v_phot, 1e-10)) + 9.0
                   + jnp.log10(t_sec))

        return log10_L, log10_R
class NickelCobaltModel(AnalyticalModel):
    """Ni56/Co56 radioactive decay with Arnett (1982) diffusion.

    Reference:
        Redback: https://github.com/nikhil-sarin/redback/blob/master/redback/transient_models/supernova_models.py

    Parameters (in ``x`` dict):
        f_nickel          – fraction of ejecta mass in Ni56
        log10_mej         – log10 of ejecta mass in solar masses
        log10_vej         – log10 of ejecta velocity in km/s
        log10_kappa       – log10 of opacity (cm^2/g)
        log10_kappa_gamma – log10 of gamma-ray opacity (cm^2/g)
    """

    parameter_names = ["f_nickel", "log10_mej", "log10_vej",
                       "log10_kappa", "log10_kappa_gamma"]

    # Number of points on the internal time grid used for the diffusion solve.
    _n_internal = 500

    # Decay parameters
    _ni56_life = 8.8    # days
    _co56_life = 111.3  # days

    # Luminosity per solar mass of Ni56 (erg/s/Msun):
    #   ni56: 6.45e43, co56: 1.45e43
    # Kept as log10 so the large value (beyond float32 max) is never formed.
    _log10_ni56_lum = 43.0 + jnp.log10(6.45)  # 43.8096
    _log10_co56_lum = 43.0 + jnp.log10(1.45)  # 43.1614

    def __init__(self, filters, times=None, temperature_floor=None):
        # Default output grid: 100 log-spaced epochs from 0.1 to 150 days.
        if times is None:
            times = jnp.geomspace(0.1, 150.0, 100)
        super().__init__(filters, times, temperature_floor=temperature_floor)

    def compute_log10_lbol_rphot(self, x, t_days):
        """Return ``(log10_L, log10_R)`` evaluated at ``t_days``.

        The Ni/Co heating rate is built on a dense log-spaced grid, passed
        through ``_arnett_diffusion_integral`` and interpolated back onto
        ``t_days``; the radius is ``vej * t``.
        """
        log10_mej_msun = x["log10_mej"]
        log10_vej_kms = x["log10_vej"]

        # Nickel mass stays in solar masses because the 6.45e43 / 1.45e43
        # constants are luminosities per solar mass of Ni56.
        log10_mni_msun = (jnp.log10(jnp.maximum(x["f_nickel"], 1e-10))
                          + log10_mej_msun)

        # Dense internal grid, padded slightly beyond the requested epochs.
        grid_days = jnp.geomspace(
            jnp.maximum(t_days[0] * 0.1, 0.01),
            t_days[-1] * 1.1,
            self._n_internal,
        )

        # Engine luminosity, summed in log space for float32 safety:
        #   L(t) = M_ni * (6.45e43 * exp(-t/8.8) + 1.45e43 * exp(-t/111.3))
        log10_ni_term = (self._log10_ni56_lum
                         + (-grid_days / self._ni56_life) * _LOG10E)
        log10_co_term = (self._log10_co56_lum
                         + (-grid_days / self._co56_life) * _LOG10E)
        log10_peak = jnp.maximum(log10_ni_term, log10_co_term)
        log10_chain = log10_peak + jnp.log10(
            jnp.power(10.0, log10_ni_term - log10_peak)
            + jnp.power(10.0, log10_co_term - log10_peak))
        log10_engine = log10_mni_msun + log10_chain

        # Diffusion constants (ejecta mass converted to log10 grams).
        log10_td, log10_A_trap = _compute_diffusion_constants(
            x["log10_kappa"], x["log10_kappa_gamma"],
            log10_mej_msun + _LOG10_MSUN, log10_vej_kms)

        # Normalisation that keeps the linear-space interpolation in range.
        log10_norm = jnp.maximum(jnp.max(log10_engine), 30.0)

        log10_diffused = _arnett_diffusion_integral(
            log10_engine, grid_days, log10_td, log10_A_trap, log10_norm)

        # Interpolate the normalised luminosity onto the output epochs.
        scaled_on_grid = jnp.power(10.0, log10_diffused - log10_norm)
        scaled_out = jnp.interp(t_days, grid_days, scaled_on_grid)
        log10_L = jnp.log10(jnp.maximum(scaled_out, 1e-30)) + log10_norm
        log10_L = jnp.maximum(log10_L, 0.0)

        # Photospheric radius: R = vej * t (cm/s times seconds).
        elapsed_sec = t_days * days_to_seconds
        log10_R = (log10_vej_kms + _LOG10_KM_CGS) + jnp.log10(elapsed_sec)

        return log10_L, log10_R
class MagnetarPoweredSNModel(AnalyticalModel):
    """Magnetar spin-down powered supernova with Arnett (1982) diffusion.

    Reference:
        Redback: https://github.com/nikhil-sarin/redback/blob/master/redback/transient_models/supernova_models.py

    Parameters (in ``x`` dict):
        log10_p0          – log10 initial spin period in ms
        log10_bp          – log10 polar B-field in 1e14 G
        mass_ns           – neutron star mass in solar masses
        theta_pb          – angle between spin and B-field in radians
        log10_mej         – log10 of ejecta mass in solar masses
        log10_vej         – log10 of ejecta velocity in km/s
        log10_kappa       – log10 of opacity (cm^2/g)
        log10_kappa_gamma – log10 of gamma-ray opacity (cm^2/g)
    """

    parameter_names = ["log10_p0", "log10_bp", "mass_ns", "theta_pb",
                       "log10_mej", "log10_vej", "log10_kappa",
                       "log10_kappa_gamma"]

    # Number of points on the internal time grid used for the diffusion solve.
    _n_internal = 500

    def __init__(self, filters, times=None, temperature_floor=None):
        # Default output grid: 100 log-spaced epochs from 0.1 to 200 days.
        if times is None:
            times = jnp.geomspace(0.1, 200.0, 100)
        super().__init__(filters, times, temperature_floor=temperature_floor)

    def compute_log10_lbol_rphot(self, x, t_days):
        """Return ``(log10_L, log10_R)`` evaluated at ``t_days``.

        The spin-down luminosity from ``_magnetar_luminosity`` is evaluated
        on a dense log-spaced grid, passed through
        ``_arnett_diffusion_integral`` and interpolated back onto ``t_days``;
        the radius is ``vej * t``.
        """
        log10_vej_kms = x["log10_vej"]

        # Dense internal grid, padded slightly beyond the requested epochs.
        grid_days = jnp.geomspace(
            jnp.maximum(t_days[0] * 0.1, 0.01),
            t_days[-1] * 1.1,
            self._n_internal,
        )
        grid_sec = grid_days * days_to_seconds

        # Engine: magnetar spin-down luminosity (already in log10 erg/s).
        log10_engine = _magnetar_luminosity(
            grid_sec, x["log10_p0"], x["log10_bp"],
            x["mass_ns"], x["theta_pb"])

        # Diffusion constants (ejecta mass converted to log10 grams).
        log10_td, log10_A_trap = _compute_diffusion_constants(
            x["log10_kappa"], x["log10_kappa_gamma"],
            x["log10_mej"] + _LOG10_MSUN, log10_vej_kms)

        # Normalisation that keeps the linear-space interpolation in range.
        log10_norm = jnp.maximum(jnp.max(log10_engine), 30.0)

        log10_diffused = _arnett_diffusion_integral(
            log10_engine, grid_days, log10_td, log10_A_trap, log10_norm)

        # Interpolate the normalised luminosity onto the output epochs.
        scaled_on_grid = jnp.power(10.0, log10_diffused - log10_norm)
        scaled_out = jnp.interp(t_days, grid_days, scaled_on_grid)
        log10_L = jnp.log10(jnp.maximum(scaled_out, 1e-30)) + log10_norm
        log10_L = jnp.maximum(log10_L, 0.0)

        # Photospheric radius: R = vej * t (cm/s times seconds).
        elapsed_sec = t_days * days_to_seconds
        log10_R = (log10_vej_kms + _LOG10_KM_CGS) + jnp.log10(elapsed_sec)

        return log10_L, log10_R
class CSMInteractionModel(AnalyticalModel):
    """Circumstellar medium interaction model (Chevalier 1982).

    Reference:
        Redback: https://github.com/nikhil-sarin/redback/blob/master/redback/transient_models/supernova_models.py

    Forward + reverse shock luminosity from Chevalier self-similar solution,
    with optional CSM diffusion.

    Parameters (in ``x`` dict):
        log10_mej      – log10 of ejecta mass in solar masses
        log10_csm_mass – log10 of CSM mass in solar masses
        log10_vej      – log10 of ejecta velocity in km/s
        eta            – CSM density profile exponent
        log10_rho      – log10 of CSM density amplitude (g/cm^{eta+3})
        log10_kappa    – log10 of opacity (cm^2/g)
        log10_r0       – log10 of CSM inner radius in AU

    Constructor kwargs:
        nn         – ejecta power-law index (default 12); must be > 5
        delta      – inner density exponent (default 1); must be < 3
        efficiency – kinetic-to-luminosity conversion (default 0.5)

    Raises:
        ValueError: if ``nn <= 5`` or ``delta >= 3``; the closed-form
            expressions take log10 of ``nn - 5`` and ``3 - delta``.
    """

    parameter_names = ["log10_mej", "log10_csm_mass", "log10_vej", "eta",
                       "log10_rho", "log10_kappa", "log10_r0"]

    # Number of points on the internal time grid used for the diffusion solve.
    _n_internal = 500

    def __init__(self, filters, times=None, nn=12, delta=1, efficiency=0.5,
                 temperature_floor=None):
        # Outside this range the log10 arguments in compute_log10_lbol_rphot
        # are non-positive and every output would silently become NaN.
        if not nn > 5:
            raise ValueError(f"nn must be greater than 5, got {nn!r}")
        if not delta < 3:
            raise ValueError(f"delta must be less than 3, got {delta!r}")

        self.nn = nn
        self.delta = delta
        self.efficiency = efficiency

        # Load CSM table and pre-interpolate along nn axis
        self._load_csm_table(nn)

        if times is None:
            times = jnp.geomspace(0.1, 300.0, 100)
        super().__init__(filters, times, temperature_floor=temperature_floor)

    def _load_csm_table(self, nn_val):
        """Load CSM coefficient table and pre-interpolate for fixed nn.

        Stores ``_eta_grid``, ``_AA_1d``, ``_Bf_1d`` and ``_Br_1d`` as 1-D
        arrays over the unique values of the table's first column.
        """
        import numpy as np
        import os

        table_path = os.path.join(os.path.dirname(os.path.dirname(__file__)),
                                  "tables", "csm_table.txt")
        data = np.loadtxt(table_path, delimiter=',')

        # Columns: eta_col(10 unique), nn_col(30 unique), Bf, Br, AA
        # (matches Redback's column order in utils.get_csm_properties)
        eta_vals = np.unique(data[:, 0])
        nn_vals = np.unique(data[:, 1])
        grid_shape = (len(eta_vals), len(nn_vals))

        # Requests outside the tabulated nn range use the nearest edge.
        nn_val_clipped = np.clip(nn_val, nn_vals[0], nn_vals[-1])

        def _at_fixed_nn(column):
            # Reshape one coefficient column to (n_eta, n_nn), then
            # interpolate every eta-row at the requested nn.
            grid = data[:, column].reshape(grid_shape)
            return np.array([np.interp(nn_val_clipped, nn_vals, row)
                             for row in grid])

        self._eta_grid = jnp.array(eta_vals)
        self._AA_1d = jnp.array(_at_fixed_nn(4))
        self._Bf_1d = jnp.array(_at_fixed_nn(2))
        self._Br_1d = jnp.array(_at_fixed_nn(3))

    def compute_log10_lbol_rphot(self, x, t_days):
        """Return ``(log10_L, log10_R)`` evaluated at ``t_days``.

        ``log10_L`` is the diffused forward + reverse shock luminosity
        (floored at 0); ``log10_R`` is the time-independent photospheric
        radius derived from the CSM geometry, broadcast to ``t_days``.
        """
        nn = float(self.nn)
        delta = float(self.delta)
        eff = self.efficiency

        log10_mej = x["log10_mej"] + _LOG10_MSUN       # grams
        log10_csm = x["log10_csm_mass"] + _LOG10_MSUN  # grams
        log10_vej = x["log10_vej"] + _LOG10_KM_CGS     # cm/s
        eta = x["eta"]
        log10_rho = x["log10_rho"]
        log10_kappa = x["log10_kappa"]
        log10_r0 = x["log10_r0"] + _LOG10_AU_CGS       # cm

        # Chevalier coefficients via 1D interp on eta
        AA = jnp.interp(eta, self._eta_grid, self._AA_1d)
        Bf = jnp.interp(eta, self._eta_grid, self._Bf_1d)
        Br = jnp.interp(eta, self._eta_grid, self._Br_1d)
        log10_AA = jnp.log10(jnp.maximum(AA, 1e-30))
        log10_Bf = jnp.log10(jnp.maximum(Bf, 1e-30))
        log10_Br = jnp.log10(jnp.maximum(Br, 1e-30))

        # --- Geometry in log10 space to avoid float32 overflow ---
        log10_qq = log10_rho + eta * log10_r0

        # Nudge eta away from singularities at 1 and 3 to avoid division by zero
        eps = 1e-4
        eta = jnp.where(jnp.abs(eta - 1.0) < eps, 1.0 + eps, eta)
        eta = jnp.where(jnp.abs(eta - 3.0) < eps, 3.0 - eps, eta)
        one_minus_eta = 1.0 - eta
        three_minus_eta = 3.0 - eta

        # radius_csm = ((3-eta)/(4*pi*qq)*csm + r0^(3-eta))^(1/(3-eta))
        # Compute each term in log10, then combine
        log10_term1 = (jnp.log10(three_minus_eta) - _LOG10_4PI
                       - log10_qq + log10_csm)
        log10_term2 = three_minus_eta * log10_r0
        log10_max_rc = jnp.maximum(log10_term1, log10_term2)
        log10_rc_inner = log10_max_rc + jnp.log10(
            jnp.power(10.0, log10_term1 - log10_max_rc)
            + jnp.power(10.0, log10_term2 - log10_max_rc))
        log10_radius_csm = log10_rc_inner / three_minus_eta

        # r_photosphere = |(-2(1-eta)/(3*kappa*qq) + R_csm^(1-eta))^(1/(1-eta))|
        #   term_a = -2(1-eta)/(3*kappa*qq)  [negative for eta < 1,
        #                                     positive for eta > 1]
        #   term_b = R_csm^(1-eta)           [positive]
        log10_abs_a = (jnp.log10(2.0 * jnp.abs(one_minus_eta) / 3.0)
                       - log10_kappa - log10_qq)
        log10_b = one_minus_eta * log10_radius_csm

        # eta < 1: inner = term_b - |term_a|, via
        #   log10(b - a) = log10(b) + log10(1 - a/b).
        # The exponent is clamped at 0 before the power so an overflow cannot
        # inject inf/NaN into gradients; ratios above 0.999 are clipped anyway
        # to keep the result positive.
        ratio = jnp.power(10.0, jnp.minimum(log10_abs_a - log10_b, 0.0))
        ratio = jnp.minimum(ratio, 0.999)
        log10_inner_sub = log10_b + jnp.log10(1.0 - ratio)

        # eta > 1: term_a is positive, so inner = term_b + |term_a|
        # (log-sum form; always well defined).
        log10_hi = jnp.maximum(log10_abs_a, log10_b)
        log10_inner_add = log10_hi + jnp.log10(
            jnp.power(10.0, log10_abs_a - log10_hi)
            + jnp.power(10.0, log10_b - log10_hi))

        log10_rph_inner = jnp.where(one_minus_eta > 0.0,
                                    log10_inner_sub, log10_inner_add)
        log10_r_ph = log10_rph_inner / one_minus_eta

        # Optically thick CSM mass:
        #   mcst = |4*pi*qq/(3-eta) * (r_ph^(3-eta) - r0^(3-eta))|
        # The absolute value is taken by always subtracting the smaller term
        # from the larger one; a vanishing difference is floored at 1e-30.
        log10_mcst_coeff = _LOG10_4PI + log10_qq - jnp.log10(three_minus_eta)
        log10_rph_term = three_minus_eta * log10_r_ph
        log10_r0_term = three_minus_eta * log10_r0
        log10_big = jnp.maximum(log10_rph_term, log10_r0_term)
        log10_small = jnp.minimum(log10_rph_term, log10_r0_term)
        log10_diff = log10_big + jnp.log10(jnp.maximum(
            1.0 - jnp.power(10.0, log10_small - log10_big), 1e-30))
        log10_mcst = log10_mcst_coeff + log10_diff

        # --- Overflow-prone quantities in log10 space ---
        # Esn = 0.3 * vej^2 * mej
        log10_Esn = jnp.log10(0.3) + 2.0 * log10_vej + log10_mej

        # g_n = 1/(4*pi*(nn-delta)) * (2*(5-delta)*(nn-5)*Esn)^((nn-3)/2)
        #       / ((3-delta)*(nn-3)*mej)^((nn-5)/2)
        log10_g_n = (-_LOG10_4PI - jnp.log10(nn - delta)
                     + 0.5 * (nn - 3.0) * (jnp.log10(
                         2.0 * (5.0 - delta) * (nn - 5.0)) + log10_Esn)
                     - 0.5 * (nn - 5.0) * (jnp.log10(
                         (3.0 - delta) * (nn - 3.0)) + log10_mej))

        # --- Shock breakout times in log10(seconds) ---
        ab = nn - eta

        # t_FS
        p_FS = ab / ((nn - 3.0) * three_minus_eta)
        log10_t_FS_inner = (jnp.log10(three_minus_eta)
                            + (3.0 - nn) / ab * log10_qq
                            + (eta - 3.0) / ab * (log10_AA + log10_g_n)
                            - _LOG10_4PI - three_minus_eta * log10_Bf)
        log10_t_FS = p_FS * log10_t_FS_inner + p_FS * log10_mcst

        # t_RS: reverse shock sweep-up time (Chevalier 1982)
        #   t_RS = (vej / (Br*(AA*g_n/qq)^(1/ab)) * corr)^(ab/(eta-3))
        #   corr = (1 - (3-nn)*mej / (4*pi*vej^(3-nn)*g_n))^(1/(3-nn))
        log10_inner_RS = (log10_vej - log10_Br
                          - (1.0 / ab) * (log10_AA + log10_g_n - log10_qq))
        # Correction factor: for nn > 3, (3-nn) < 0, so the ratio term is
        # negative and the argument is 1 + |ratio| > 1.
        #   |ratio| = (nn-3)*mej / (4*pi*vej^(3-nn)*g_n)
        log10_abs_ratio = (jnp.log10(nn - 3.0) + log10_mej
                           - _LOG10_4PI + (nn - 3.0) * log10_vej - log10_g_n)
        abs_ratio = jnp.power(10.0, log10_abs_ratio)
        # arg = 1 + abs_ratio; corr = arg^(1/(3-nn)) = arg^(-1/(nn-3))
        log10_corr = -1.0 / (nn - 3.0) * jnp.log10(1.0 + abs_ratio)
        log10_t_RS = ab / (eta - 3.0) * (log10_inner_RS + log10_corr)

        t_FS_sec = jnp.power(10.0, log10_t_FS)
        t_RS_sec = jnp.power(10.0, log10_t_RS)

        # --- Shock luminosities in log10 space ---
        # Linearly spaced internal grid, padded beyond the requested epochs.
        t_start = jnp.maximum(t_days[0] * 0.1, 0.01)
        t_end = t_days[-1] * 1.1
        t_int = jnp.linspace(t_start, t_end, self._n_internal)
        t_int_sec = t_int * days_to_seconds + 1.0  # regularization

        alpha_L = (2.0 * nn + 6.0 * eta - nn * eta - 15.0) / ab
        log10_t_sec = jnp.log10(t_int_sec)

        # log10(lbol_FS_coeff) — time-independent part
        log10_FS_coeff = (jnp.log10(2.0 * jnp.pi) - 3.0 * jnp.log10(ab)
                          + (5.0 - eta) / ab * log10_g_n
                          + (nn - 5.0) / ab * log10_qq
                          + 2.0 * jnp.log10(nn - 3.0)
                          + jnp.log10(nn - 5.0)
                          + (5.0 - eta) * log10_Bf
                          + (5.0 - eta) / ab * log10_AA)

        # log10(lbol_RS_coeff)
        log10_RS_coeff = (jnp.log10(2.0 * jnp.pi)
                          + (5.0 - nn) / ab
                          * (log10_AA + log10_g_n - log10_qq)
                          + (5.0 - nn) * log10_Br
                          + log10_g_n
                          + 3.0 * jnp.log10(three_minus_eta / ab))

        log10_lbol_FS = log10_FS_coeff + alpha_L * log10_t_sec
        log10_lbol_RS = log10_RS_coeff + alpha_L * log10_t_sec

        # Mask by breakout times (set to -30 when inactive)
        mask_FS = t_int_sec < t_FS_sec
        mask_RS = t_int_sec < t_RS_sec
        log10_lbol_FS = jnp.where(mask_FS, log10_lbol_FS, -30.0)
        log10_lbol_RS = jnp.where(mask_RS, log10_lbol_RS, -30.0)

        # Combine: lbol = eff * (lbol_FS + lbol_RS) via log-sum-exp
        log10_max = jnp.maximum(log10_lbol_FS, log10_lbol_RS)
        lbol_sum = (jnp.power(10.0, log10_lbol_FS - log10_max)
                    + jnp.power(10.0, log10_lbol_RS - log10_max))
        log10_lbol = (jnp.log10(eff) + log10_max
                      + jnp.log10(jnp.maximum(lbol_sum, 1e-30)))

        # --- CSM diffusion kernel (trapezoidal integral, Redback-matched) ---
        # t0 = kappa * mcst / (beta * c * r_ph), with beta = 4*pi^3/9
        log10_beta_csm = jnp.log10(4.0 * jnp.pi**3 / 9.0)
        log10_t0_csm_sec = (log10_kappa + log10_mcst
                            - log10_beta_csm - _LOG10_CCGS - log10_r_ph)
        t0_csm_days = jnp.maximum(
            jnp.power(10.0, log10_t0_csm_sec - _LOG10_DAYS2SEC), 1e-6)

        # Normalisation that keeps the linear-space interpolation in range.
        log10_L_scale = jnp.maximum(jnp.max(log10_lbol), 30.0)

        log10_L_diff = _csm_diffusion_integral(
            log10_lbol, t_int, t0_csm_days, log10_L_scale)

        # Interpolate to output times
        L_diff_n = jnp.power(10.0, log10_L_diff - log10_L_scale)
        L_diff_out = jnp.interp(t_days, t_int, L_diff_n)
        log10_L = jnp.log10(jnp.maximum(L_diff_out, 1e-30)) + log10_L_scale
        log10_L = jnp.maximum(log10_L, 0.0)

        # Photosphere: constant from CSM geometry
        log10_R = jnp.full_like(t_days, log10_r_ph)

        return log10_L, log10_R