# Source code for fiesta.inference.analytical_models.shock_powered_models

"""Shock-powered analytical light-curve models.

Reference:
    Redback: https://github.com/nikhil-sarin/redback/blob/master/redback/transient_models/shock_powered_models.py
    NMMA: https://github.com/nuclear-multimessenger-astronomy/nmma/blob/main/nmma/em/lightcurve_generation.py
"""

import jax
import jax.numpy as jnp

from fiesta.constants import c_cgs, msun_cgs, Rsun_cgs, days_to_seconds

from fiesta.inference.analytical_models.base import (
    AnalyticalModel,
    _LOG10E, _LOG10_MSUN, _LOG10_RSUN, _LOG10_CCGS, _LOG10_4PI,
    _LOG10_DAYS2SEC,
)


class ShockCoolingModel(AnalyticalModel):
    """Shock-cooling emission following Piro (2021).

    Reference:
        Redback: https://github.com/nikhil-sarin/redback/blob/master/redback/transient_models/shock_powered_models.py
        NMMA: https://github.com/nuclear-multimessenger-astronomy/nmma/blob/main/nmma/em/lightcurve_generation.py

    Parameters (all in ``x`` dict):
        log10_Menv – log10 envelope mass in solar masses
        log10_Renv – log10 envelope radius in solar radii
        log10_Ee   – log10 explosion energy in erg

    Class attributes (override in a subclass to change the model):
        _kappa – gray opacity in cm^2/g (default 0.2, electron scattering)
        _n     – outer density power-law index of the envelope (default 10);
                 must be > 5 (the formulas contain factors of (n - 5))
        _delta – inner density power-law index of the envelope (default 1.1);
                 must lie in (1, 3) (factors of (3 - delta) and
                 1 / (delta - 1) appear below)
    """

    parameter_names = ["log10_Menv", "log10_Renv", "log10_Ee"]

    _kappa = 0.2  # cm^2/g (electron scattering)
    # Density-profile indices of Piro (2021). The defaults (n=10, delta=1.1)
    # are the values NMMA's ``sc_bol_lc`` uses.
    _n = 10.0
    _delta = 1.1

    def __init__(self, filters, times=None):
        if times is None:
            # Default grid: 1 hour to 3.5 days, log-spaced, 100 points.
            times = jnp.geomspace(1.0 / 24.0, 3.5, 100)
        super().__init__(filters, times)

    def compute_log10_lbol_rphot(self, x, t_days):
        """Full Piro (2021) shock-cooling light curve.

        With the default ``_n = 10`` and ``_delta = 1.1`` this is intended to
        match NMMA's ``sc_bol_lc``. All overflow-prone quantities are computed
        in log10 space to stay within float32 range.

        Args:
            x: dict holding ``log10_Menv``, ``log10_Renv`` and ``log10_Ee``.
            t_days: time(s) since explosion in days; converted to seconds and
                floored at 1 s so that ``log10(t)`` stays finite.

        Returns:
            Tuple ``(log10_L, log10_R)``: log10 bolometric luminosity in erg/s
            (floored at 0, i.e. 1 erg/s) and log10 photospheric radius in cm.
        """
        n = self._n
        delta = self._delta
        kappa = self._kappa
        log10_kappa = jnp.log10(kappa)

        # Physical parameters in log10-CGS.
        log10_Me = x["log10_Menv"] + _LOG10_MSUN  # log10(g)
        log10_Re = x["log10_Renv"] + _LOG10_RSUN  # log10(cm)
        log10_Ee = x["log10_Ee"]  # log10(erg)

        # Piro (2021) density-profile constant.
        K = (n - 3.0) * (3.0 - delta) / (4.0 * jnp.pi * (n - delta))

        # vt = sqrt(((n-5)(5-delta)/((n-3)(3-delta))) * 2*Ee/Me)
        vel_coeff = (n - 5.0) * (5.0 - delta) / ((n - 3.0) * (3.0 - delta))
        log10_vt = 0.5 * (jnp.log10(vel_coeff * 2.0) + log10_Ee - log10_Me)

        # td = sqrt(3*kappa*K*Me / ((n-1)*vt*c))
        log10_td = 0.5 * (jnp.log10(3.0 * kappa * K) + log10_Me
                          - jnp.log10(n - 1.0) - log10_vt - _LOG10_CCGS)
        td = jnp.power(10.0, log10_td)

        # Time in seconds, floored at 1 s.
        t = jnp.maximum(t_days * days_to_seconds, 1.0)
        log10_t = jnp.log10(t)

        # prefactor = pi*(n-1)/(3*(n-5)) * c * Re * vt^2 / kappa
        # This can exceed float32 max (~3.4e38), so keep it in log10.
        log10_prefactor = (jnp.log10(jnp.pi * (n - 1.0) / (3.0 * (n - 5.0)))
                           + _LOG10_CCGS + log10_Re + 2.0 * log10_vt
                           - log10_kappa)

        # t < td:  L = prefactor * (td/t)^(4/(n-2))
        log10_L_early = log10_prefactor + 4.0 / (n - 2.0) * (log10_td - log10_t)
        # t >= td: L = prefactor * exp(-0.5*((t/td)^2 - 1)).
        # Squaring the ratio (not t and td separately) keeps intermediates small.
        exp_arg = -0.5 * ((t / td) ** 2 - 1.0)
        log10_L_late = log10_prefactor + exp_arg * _LOG10E

        log10_L = jnp.where(t < td, log10_L_early, log10_L_late)
        log10_L = jnp.maximum(log10_L, 0.0)  # floor at 1 erg/s

        # Photospheric radius.
        # tph = sqrt(3*kappa*K*Me / (2*(n-1)*vt^2))
        log10_tph = 0.5 * (jnp.log10(3.0 * kappa * K) + log10_Me
                           - jnp.log10(2.0 * (n - 1.0)) - 2.0 * log10_vt)
        tph = jnp.power(10.0, log10_tph)

        # t < tph:  R = (tph/t)^(2/(n-1)) * vt * t
        log10_R_early = (2.0 / (n - 1.0) * (log10_tph - log10_t)
                         + log10_vt + log10_t)
        # t >= tph: R = (1 + (delta-1)/(n-1)*((t/tph)^2-1))^(-1/(delta-1)) * vt*t
        inner = 1.0 + (delta - 1.0) / (n - 1.0) * ((t / tph) ** 2 - 1.0)
        # jnp.where evaluates both branches; keep log10 finite everywhere.
        inner = jnp.maximum(inner, 1e-10)
        log10_R_late = (-1.0 / (delta - 1.0) * jnp.log10(inner)
                        + log10_vt + log10_t)

        log10_R = jnp.where(t < tph, log10_R_early, log10_R_late)
        return log10_L, log10_R
class ShockedCocoonModel(AnalyticalModel):
    """Analytical jet cocoon cooling model.

    Reference:
        Redback: https://github.com/nikhil-sarin/redback/blob/master/redback/transient_models/shock_powered_models.py

    Fully algebraic (no ODE) — power-law luminosity decay with diffusion
    timescale. Based on the shocked cocoon model from Redback.

    Parameters (in ``x`` dict):
        log10_mej        – log10 ejecta mass in solar masses
        log10_vej        – log10 ejecta velocity in units of c
        eta              – slope for ejecta density profile
        log10_tshock     – log10 shock time in seconds
        shocked_fraction – fraction of ejecta mass shocked
        cos_theta_cocoon – cosine of cocoon opening angle
        log10_kappa      – log10 gray opacity in cm^2/g
    """

    parameter_names = ["log10_mej", "log10_vej", "eta", "log10_tshock",
                       "shocked_fraction", "cos_theta_cocoon", "log10_kappa"]

    def __init__(self, filters, times=None):
        if times is None:
            # Default grid: 0.01 to 30 days, log-spaced, 100 points.
            times = jnp.geomspace(0.01, 30.0, 100)
        super().__init__(filters, times)

    def compute_log10_lbol_rphot(self, x, t_days):
        """Bolometric luminosity and photospheric radius of the shocked cocoon.

        Args:
            x: dict holding the entries listed in ``parameter_names``.
            t_days: time(s) in days. Must be positive: ``log10(t_days)`` in
                the radius is not guarded.

        Returns:
            Tuple ``(log10_L, log10_R)``: log10 bolometric luminosity in erg/s
            (floored at 0, i.e. 1 erg/s) and log10 photospheric radius in cm.
        """
        log10_mej = x["log10_mej"]  # solar masses
        log10_vej = x["log10_vej"]  # units of c
        eta = x["eta"]
        log10_tshock = x["log10_tshock"]  # seconds
        log10_kappa = x["log10_kappa"]

        # Keep log10(f_sh) finite, and keep theta strictly inside (0, pi) so
        # that log10(theta**2 / 2) below stays finite.
        f_sh = jnp.maximum(x["shocked_fraction"], 1e-12)
        cos_theta = jnp.clip(x["cos_theta_cocoon"], -1.0 + 1e-6, 1.0 - 1e-6)
        theta = jnp.arccos(cos_theta)

        # log10 of CGS quantities.
        log10_vej_cms = log10_vej + _LOG10_CCGS  # cm/s
        log10_rshock = log10_tshock + _LOG10_CCGS  # r_shock = c * t_shock, cm
        log10_Msh_g = jnp.log10(f_sh) + log10_mej + _LOG10_MSUN  # shocked mass, g
        vej_cms = jnp.power(10.0, log10_vej_cms)  # moderate: ~1e9-1e10

        # Diffusion timescale:
        # tau_diff_s = sqrt(Msun * kappa * f_sh * mej / (4pi * c * vej_cms))
        log10_tau_diff_s = 0.5 * (_LOG10_MSUN + log10_kappa + jnp.log10(f_sh)
                                  + log10_mej - _LOG10_4PI - _LOG10_CCGS
                                  - log10_vej_cms)
        tau_diff = jnp.power(10.0, log10_tau_diff_s) / days_to_seconds  # days

        # Transition time (days): t_thin = sqrt(c/vej) * tau_diff
        t_thin = jnp.sqrt(c_cgs / vej_cms) * tau_diff

        # L_scale = (theta^2/2)^(1/3) * M_sh_g * vej_cms * rshock / tau_diff_s^2
        log10_theta_factor = (1.0 / 3.0) * jnp.log10(theta ** 2 / 2.0)
        log10_L_scale = (log10_theta_factor + log10_Msh_g + log10_vej_cms
                         + log10_rshock - 2.0 * log10_tau_diff_s)

        # L_bol = L_scale * (t/tau_diff)^(-4/(eta+2)) * (1 + tanh(t_thin - t))/2
        power_term = (-4.0 / (eta + 2.0)
                      * jnp.log10(jnp.maximum(t_days / tau_diff, 1e-10)))
        # (1 + tanh(y)) / 2 == sigmoid(2y). Evaluating log(sigmoid) directly
        # avoids the float32 cancellation in 1 + tanh(y) for y << 0. It also
        # removes the need for a hard floor, which would create a plateau
        # and zero the gradient.
        tanh_term = jax.nn.log_sigmoid(2.0 * (t_thin - t_days)) * _LOG10E
        log10_L = jnp.maximum(log10_L_scale + power_term + tanh_term, 0.0)

        # Photospheric velocity: v_phot = vej * (t/t_thin)^(-2/(eta+3))
        log10_v_phot = (log10_vej_cms
                        + (-2.0 / (eta + 3.0))
                        * jnp.log10(jnp.maximum(t_days / t_thin, 1e-10)))
        # R_phot = v_phot * t_days * day_to_s
        log10_R = log10_v_phot + jnp.log10(t_days) + _LOG10_DAYS2SEC
        return log10_L, log10_R