# Source code for fiesta.inference.analytical_models.kilonova_models

"""Kilonova analytical light-curve models.

Reference:
    Redback: https://github.com/nikhil-sarin/redback/blob/master/redback/transient_models/kilonova_models.py
    NMMA: https://github.com/nuclear-multimessenger-astronomy/nmma/blob/main/nmma/em/lightcurve_generation.py
"""

import jax
import jax.numpy as jnp

from fiesta.constants import c_cgs, msun_cgs, days_to_seconds

from fiesta.inference.analytical_models.base import (
    AnalyticalModel,
    _magnetar_luminosity,
    _LOG10_MSUN, _LOG10_RSUN, _LOG10_CCGS, _LOG10_4PI,
    _barnes16_thermalisation_coefficients,
    _barnes16_e_th,
    _electron_fraction_from_kappa,
)


def _validate_times(times):
    """Validate time array for kilonova ODE/recurrence models.

    Requires at least 2 finite, positive, strictly increasing samples
    (needed for jnp.diff in the time-stepping logic).
    """
    import numpy as np
    arr = np.asarray(times)
    if arr.ndim == 0 or arr.shape[0] < 2:
        raise ValueError(
            f"Kilonova models require at least 2 time samples, got shape {arr.shape}"
        )
    if not np.all(np.isfinite(arr)):
        raise ValueError("times array contains non-finite values")
    if not np.all(arr > 0):
        raise ValueError("times array must be strictly positive")
    if not np.all(np.diff(arr) > 0):
        raise ValueError("times array must be strictly increasing")


class MetzgerModel(AnalyticalModel):
    """300-shell kilonova model matching NMMA ``eff_metzger_lc``.

    Reference:
        Redback: https://github.com/nikhil-sarin/redback/blob/master/redback/transient_models/kilonova_models.py
        NMMA: https://github.com/nuclear-multimessenger-astronomy/nmma/blob/main/nmma/em/lightcurve_generation.py

    Parameters (in ``x`` dict):
        log10_mej      – log10 ejecta mass in solar masses
        log10_vej      – log10 ejecta velocity in units of c
        beta           – velocity power-law index
        log10_kappa_r  – log10 opacity in cm^2/g

    The ODE is solved per shell in per-gram units (erg/g), which keeps the
    state within float32 range. It uses 300 geometrically spaced mass shells
    with a power-law velocity profile, neutron fractions, and
    shell-dependent opacities matching the NMMA implementation.
    """

    parameter_names = ["log10_mej", "log10_vej", "beta", "log10_kappa_r"]
    _n_shells = 300
    # NOTE(review): not referenced anywhere in this module; confirm whether
    # the base class reads it before removing.
    _n_internal = 500

    def __init__(self, filters, times=None):
        # Default grid: 100 log-spaced samples between 0.1 and 30.
        if times is None:
            times = jnp.geomspace(0.1, 30.0, 100)
        _validate_times(times)
        super().__init__(filters, times)

    def compute_log10_lbol_rphot(self, x, t_days):
        """Return ``(log10_L, log10_R)`` evaluated on ``t_days``.

        ``log10_L`` is the log10 bolometric luminosity in erg/s, floored at 0.
        ``log10_R`` is the log10 photospheric radius in cm, from the radius
        ``v * t`` of the shell whose optical depth is closest to 1; it is
        floored at 0.

        Args:
            x: mapping with the keys listed in ``parameter_names``.
            t_days: time grid in days (converted to seconds internally).
        """
        M0 = jnp.power(10.0, x["log10_mej"]) * msun_cgs  # total ejecta mass (g)
        v0 = jnp.power(10.0, x["log10_vej"]) * c_cgs  # minimum escape velocity (cm/s)
        beta_val = x["beta"]
        kappa_r = jnp.power(10.0, x["log10_kappa_r"])
        n_shells = self._n_shells  # 300

        # Use the output time grid directly (matching NMMA's approach)
        t_sec = t_days * days_to_seconds
        # Variable time steps (matches NMMA's geomspace dt)
        dt_arr = jnp.diff(t_sec)
        # Pad with the first dt so arrays match (first step is approximate).
        # Step i (at t_i) advances with dt_padded[i] = t_i - t_{i-1}, or with
        # t_1 - t_0 for i = 0.
        dt_padded = jnp.concatenate([dt_arr[:1], dt_arr])

        # Thermalization efficiency pre-computed at output times (matching NMMA)
        def _eth(t_day):
            timescale_factor = 2.0 * 0.17 * jnp.power(jnp.maximum(t_day, 1e-6), 0.74)
            return 0.36 * (jnp.exp(-0.56 * t_day)
                           + jnp.log(1.0 + timescale_factor) / timescale_factor)

        eth_arr = _eth(t_days)  # (n_t,)

        # Mass shells: geometric spacing (solar masses) matching NMMA
        m = jnp.geomspace(1e-8, jnp.power(10.0, x["log10_mej"]), n_shells)  # Msun
        dm = jnp.diff(m)  # (n_shells-1,) in Msun

        # Velocity profile: v = v0 * (m*msun/M0)^(-1/beta), capped at c
        vm = v0 * jnp.power(m * msun_cgs / M0, -1.0 / beta_val)
        vm = jnp.minimum(vm, c_cgs)

        # Neutron fractions (NMMA: Mn=1e-8, Ye=0.1, Xn0max=0.8)
        Mn = 1e-8
        Xn0 = 0.8 * 2.0 / jnp.pi * jnp.arctan(Mn / m)  # (n_shells,)
        Xr = 1.0 - Xn0  # r-process fraction

        # Pre-compute time-dependent arrays: Xn(t), edot(t), kappa(t).
        # Only the inner n_shells-1 shells are evolved, hence the [:-1].
        Xn_t = Xn0[:-1][None, :] * jnp.exp(-t_sec[:, None] / 900.0)  # (n_t, ns-1)
        edotn = 3.2e14 * Xn_t  # (n_t, ns-1)
        # Shape (n_t, 1): identical for every shell; broadcasts against edotn.
        edotr = (2.1e10 * eth_arr[:, None]
                 * jnp.power(jnp.maximum(t_days, 1e-6), -1.3)[:, None])
        edot_all = edotn + edotr  # (n_t, ns-1)
        kappan = 0.4 * (1.0 - Xn_t - Xr[:-1][None, :])
        kappa_all = kappan + kappa_r * Xr[:-1][None, :]  # (n_t, ns-1)

        # The ODE uses per-gram quantities (ene in erg/g), which are moderate
        # (~1e10 to 1e18) and safe for float32.
        def _scan_step(ene, inputs):
            # ene: (n_shells-1,) energy per gram in erg/g.
            # Outputs are computed from the state at the start of the step.
            t_i, dt_i, edot_i, kappa_i = inputs
            t_safe = jnp.maximum(t_i, 1.0)
            # Diffusion timescale per shell (NMMA formula; 0.08 ~ 1/(4*pi))
            tdiff = 0.08 * kappa_i * m[:-1] * msun_cgs * 3.0 / (
                vm[:-1] * c_cgs * t_safe * beta_val)
            # Luminosity per unit mass (erg/s/g)
            lum_j = ene / (tdiff + t_i * vm[:-1] / c_cgs)
            # Total luminosity: sum(lum_j * dm) in Msun*erg/s/g (moderate);
            # the msun_cgs factor is applied later in log space.
            L_dm_total = jnp.sum(lum_j * dm)
            # Optical depth for photosphere
            tau_shells = m[:-1] * msun_cgs * kappa_i / (
                4.0 * jnp.pi * (t_safe * vm[:-1])**2)
            # Photosphere index: shell whose optical depth is closest to 1
            pig = jnp.argmin(jnp.abs(tau_shells - 1.0))
            R_ph = vm[pig] * t_i
            # Energy ODE (NMMA: ene += dt*(edot - ene/t - lum_j)), explicit
            # Euler step clamped at zero
            dene = dt_i * (edot_i - ene / t_safe - lum_j)
            ene_new = jnp.maximum(ene + dene, 0.0)
            return ene_new, (L_dm_total, R_ph)

        # Initial energy per shell: zero (NMMA starts from zero)
        E0 = jnp.zeros(n_shells - 1)
        _, (L_dm_arr, R_arr) = jax.lax.scan(
            _scan_step, E0, (t_sec, dt_padded, edot_all, kappa_all))

        # Convert total luminosity to log10:
        # L_total = L_dm_total * msun_cgs
        log10_L = jnp.log10(jnp.maximum(jnp.abs(L_dm_arr), 1e-30)) + _LOG10_MSUN
        log10_L = jnp.maximum(log10_L, 0.0)
        log10_R = jnp.log10(jnp.maximum(R_arr, 1.0))
        return log10_L, log10_R
class MetzgerFullModel(AnalyticalModel):
    """Multi-shell kilonova model (Metzger 2017), matching Redback.

    Reference:
        Redback: _metzger_kilonova_model in kilonova_models.py

    200 shells with linear velocity spacing, Barnes+16 thermalisation,
    an optional neutron precursor, and a per-shell energy ODE in Redback's
    mixed units (see ``compute_log10_lbol_rphot``).

    Parameters (in ``x`` dict):
        log10_mej      – log10 ejecta mass in solar masses
        log10_vej      – log10 ejecta velocity (vmin) in units of c
        beta           – velocity power-law index
        log10_kappa_r  – log10 opacity in cm^2/g
    """

    parameter_names = ["log10_mej", "log10_vej", "beta", "log10_kappa_r"]
    _n_shells = 200

    def __init__(self, filters, times=None, neutron_precursor=True, vmax=0.7):
        # neutron_precursor: include free-neutron heating/opacity terms.
        # vmax: outermost shell velocity in units of c.
        self._neutron_precursor = neutron_precursor
        self._vmax = vmax
        if times is None:
            times = jnp.geomspace(0.1, 30.0, 100)
        _validate_times(times)
        super().__init__(filters, times)

    def compute_log10_lbol_rphot(self, x, t_days):
        """Return ``(log10_L, log10_R)`` evaluated on ``t_days``.

        ``log10_L`` is the log10 bolometric luminosity in erg/s, floored at 0.
        ``log10_R`` is the log10 photospheric radius in cm, floored at 0.
        The ODE takes len(t_days)-1 steps; the final sample repeats the
        previous one.

        Args:
            x: mapping with the keys listed in ``parameter_names``.
            t_days: time grid in days (converted to seconds internally).
        """
        mej = jnp.power(10.0, x["log10_mej"])  # solar masses
        vej = jnp.power(10.0, x["log10_vej"])  # fraction of c
        beta_val = x["beta"]
        kappa_r = jnp.power(10.0, x["log10_kappa_r"])
        mass_len = self._n_shells  # 200
        t_sec = t_days * days_to_seconds
        t_d = t_days

        # Barnes+16 thermalisation
        av, bv, dv = _barnes16_thermalisation_coefficients(mej, vej)
        e_th = _barnes16_e_th(t_d, av, bv, dv)  # (time_len,)

        # Heating rate (dual regime, matching Redback lines 1867-1871):
        # power law after t0, arctan-smoothed plateau before it.
        t0 = 1.3  # seconds
        sig = 0.11  # seconds
        edotr_late = 2.1e10 * e_th * t_d ** (-1.3)
        edotr_early = 4.0e18 * jnp.power(
            0.5 - jnp.arctan((t_sec - t0) / sig) / jnp.pi, 1.3) * e_th
        edotr_base = jnp.where(t_sec > t0, edotr_late, edotr_early)

        # Shell layout: linear velocity spacing (matching Redback)
        vmin = vej
        vmax = self._vmax
        vel = jnp.linspace(vmin, vmax, mass_len)
        m_array = mej * (vel / vmin) ** (-beta_val)  # solar masses
        v_m = vel * c_cgs  # cm/s
        # m_array decreases with velocity, so take the absolute difference
        dm = jnp.abs(jnp.diff(m_array))  # (mass_len-1,) solar masses

        # Neutron precursor (matching Redback)
        neutron_precursor = self._neutron_precursor
        if neutron_precursor:
            Ye = _electron_fraction_from_kappa(kappa_r)
            neutron_mass = 1e-8 * msun_cgs
            Xn0 = 1.0 - 2.0 * Ye * 2.0 * jnp.arctan(
                neutron_mass / m_array / msun_cgs) / jnp.pi
            Xr = 1.0 - Xn0

        # Initial conditions in Redback mixed units: energy_v = m_array * v^2/2
        # Units: Msun * (cm/s)^2 — moderate values, safe for float32
        E0 = 0.5 * m_array * v_m ** 2
        dt_arr = jnp.diff(t_sec)

        def _scan_step(ene, inputs):
            # e_th_i is unpacked but unused: thermalisation is already folded
            # into edotr_i.
            t_i, dt_i, edotr_i, e_th_i = inputs
            if neutron_precursor:
                Xn_t = Xn0 * jnp.exp(-t_i / 900.0)
                # NOTE(review): neutron heating is quadratic in Xn_t here,
                # whereas MetzgerModel uses a linear dependence; confirm this
                # matches Redback.
                edotn = 3.2e14 * Xn_t * Xn_t
                edot = edotn[:-1] + edotr_i
                kappa_n = 0.4 * (1.0 - Xn_t - Xr)
                kappa_total = kappa_n + kappa_r * Xr
            else:
                edot = edotr_i * jnp.ones(mass_len - 1)
                kappa_total = kappa_r * jnp.ones(mass_len)
            # Diffusion timescale and radiated luminosity for the inner shells
            td_v = (kappa_total[:-1] * m_array[:-1] * msun_cgs * 3.0) / (
                4.0 * jnp.pi * v_m[:-1] * c_cgs * t_i * beta_val)
            lum_rad = ene[:-1] / (td_v + t_i * (v_m[:-1] / c_cgs))
            # Explicit Euler step; the outermost shell (index -1) is carried
            # through unchanged.
            ene_new = jnp.concatenate([
                ene[:-1] + (edot - ene[:-1] / t_i - lum_rad) * dt_i,
                ene[-1:],
            ])
            ene_new = jnp.maximum(ene_new, 0.0)
            # Sum lum * dm in mixed units (moderate), defer msun to log10
            L_dm_total = jnp.sum(lum_rad * dm)
            tau = m_array[:-1] * msun_cgs * kappa_total[:-1] / (
                4.0 * jnp.pi * (t_i * v_m[:-1]) ** 2)
            # Outermost shell reuses its neighbour's optical depth
            tau_full = jnp.concatenate([tau, tau[-1:]])
            # Photosphere index: shell whose optical depth is closest to 1
            pig = jnp.argmin(jnp.abs(tau_full - 1.0))
            R_ph = v_m[pig] * t_i
            return ene_new, (L_dm_total, R_ph)

        _, (L_arr, R_arr) = jax.lax.scan(
            _scan_step, E0, (t_sec[:-1], dt_arr, edotr_base[:-1], e_th[:-1]))
        # n-1 steps produce n-1 outputs; repeat the last to match t_days
        L_arr = jnp.concatenate([L_arr, L_arr[-1:]])
        R_arr = jnp.concatenate([R_arr, R_arr[-1:]])

        # Convert: L_erg_s = L_dm_total * msun_cgs
        log10_L = jnp.log10(jnp.maximum(jnp.abs(L_arr), 1e-30)) + _LOG10_MSUN
        log10_L = jnp.maximum(log10_L, 0.0)
        log10_R = jnp.log10(jnp.maximum(R_arr, 1.0))
        return log10_L, log10_R
class OneComponentKilonovaModel(AnalyticalModel):
    """Single-component kilonova with diffusion-integral heating.

    Reference:
        Redback: _one_component_kilonova_model in kilonova_models.py

    Follows Redback's cumulative-trapezoid algorithm,
    L(t) = cumtrapz(f * exp(t^2/td^2), t) * exp(-t^2/td^2) / td,
    evaluated in a float32-safe O(n^2) vectorised form. Every exponential
    factor is exp(-(t_j^2 - t_k^2)/td^2) <= 1, which avoids exp(t^2/td^2)
    overflow.

    Parameters (in ``x`` dict):
        log10_mej    – log10 ejecta mass in solar masses
        log10_vej    – log10 ejecta velocity in units of c
        log10_kappa  – log10 gray opacity in cm^2/g
    """

    parameter_names = ["log10_mej", "log10_vej", "log10_kappa"]

    def __init__(self, filters, times=None, temperature_floor=4000.0):
        if times is None:
            times = jnp.geomspace(0.1, 30.0, 100)
        _validate_times(times)
        super().__init__(filters, times, temperature_floor=temperature_floor)

    def compute_log10_lbol_rphot(self, x, t_days):
        """Return ``(log10_L, log10_R)`` evaluated on ``t_days``.

        ``log10_L`` is the log10 bolometric luminosity in erg/s, floored at 0.
        ``log10_R`` is ``log10(vej * t)`` in cm.

        Args:
            x: mapping with the keys listed in ``parameter_names``.
            t_days: time grid in days (converted to seconds internally).
        """
        log10_mej_g = x["log10_mej"] + _LOG10_MSUN
        log10_vej_cms = x["log10_vej"] + _LOG10_CCGS
        log10_kappa = x["log10_kappa"]
        mej_solar = jnp.power(10.0, x["log10_mej"])
        vej_c = jnp.power(10.0, x["log10_vej"])
        t_sec = t_days * days_to_seconds

        # Diffusion timescale: td = sqrt(2 * kappa * mej / (13.7 * vej * c))
        log10_td = 0.5 * (jnp.log10(2.0) + log10_kappa + log10_mej_g
                          - jnp.log10(13.7) - log10_vej_cms - _LOG10_CCGS)
        td = jnp.power(10.0, log10_td)

        # Normalization: Q_scale = 4e18 * mej_g
        log10_Q_scale = jnp.log10(4.0e18) + log10_mej_g

        # Barnes+16 thermalisation
        av, bv, dv = _barnes16_thermalisation_coefficients(mej_solar, vej_c)
        e_th = _barnes16_e_th(t_days, av, bv, dv)

        # Normalized integrand (without exp factor):
        #   f_n = (0.5 - arctan((t - t0)/sig)/pi)^1.3 * e_th(t) * (t/td)
        t0_heat = 1.3  # seconds
        sig_heat = 0.11  # seconds
        dt0 = t_sec - t0_heat
        after_t0 = dt0 > 0.0
        # For t > t0 use the identity 0.5 - arctan(u)/pi = arctan(1/u)/pi,
        # which avoids float32 catastrophic cancellation at late times. The
        # identity is only valid for u > 0, so for t <= t0 use the direct
        # form (no cancellation there, value in [0.5, 1)). The dummy
        # denominator keeps the unselected branch finite, because jnp.where
        # evaluates both branches.
        dt0_safe = jnp.where(after_t0, dt0, 1.0)
        heating_shape = jnp.where(
            after_t0,
            jnp.arctan(sig_heat / dt0_safe) / jnp.pi,
            0.5 - jnp.arctan(dt0 / sig_heat) / jnp.pi,
        )
        f_n = jnp.power(heating_shape, 1.3) * e_th * (t_sec / td)

        # Float32-safe O(n^2) vectorized cumulative trapezoid.
        # Redback: L = cumtrapz(f*exp(t^2/td^2), t) * exp(-t^2/td^2) / td
        # Equivalent:
        #   L[j] = (1/td) * sum_i 0.5*(f[i]*phi[j,i] + f[i+1]*phi[j,i+1])*dt[i]
        # where phi[j,k] = exp(-(t[j]^2 - t[k]^2)/td^2) lies in [0, 1] for j >= k.
        # Each output is computed independently, so errors do not accumulate
        # sequentially.
        n = t_sec.shape[0]
        dt = jnp.diff(t_sec)  # (n-1,)
        t_sq = t_sec ** 2
        # Damping factors: phi[j+1, k] for j=0..n-2, k=0..n-1
        exp_arg = -(t_sq[1:, None] - t_sq[None, :]) / td**2  # (n-1, n)
        # Upper clip to 0 only affects entries with k > j+1, which are masked
        # out below.
        phi = jnp.exp(jnp.clip(exp_arg, -80.0, 0.0))
        # Trapezoid: h[j, i] = 0.5*(f[i]*phi[j+1,i] + f[i+1]*phi[j+1,i+1]) * dt[i]
        h = 0.5 * (f_n[:-1][None, :] * phi[:, :-1]
                   + f_n[1:][None, :] * phi[:, 1:]) * dt[None, :]
        # Mask: only i <= j (lower triangular incl. diagonal)
        mask = jnp.tril(jnp.ones((n - 1, n - 1), dtype=bool))
        h = jnp.where(mask, h, 0.0)
        W_arr = jnp.sum(h, axis=1)  # (n-1,)

        # L[0] = L[1] matching Redback
        W_full = jnp.concatenate([W_arr[:1], W_arr])
        L_n = jnp.maximum(W_full / td, 0.0)
        log10_L = jnp.log10(jnp.maximum(L_n, 1e-30)) + log10_Q_scale
        log10_L = jnp.maximum(log10_L, 0.0)

        # Photospheric radius: R = vej * t
        log10_R = log10_vej_cms + jnp.log10(t_sec)
        return log10_L, log10_R
class MagnetarBoostedKilonovaModel(AnalyticalModel):
    """Multi-shell kilonova with magnetar spin-down heating, matching Redback.

    Reference:
        Redback: _general_metzger_magnetar_driven_kilonova_model

    200-shell ODE with magnetar energy injection (into the bottom layer by
    default), evolution of the ejecta velocity, an optional pair-cascade
    correction, and an optional neutron precursor.

    Parameters (in ``x`` dict):
        log10_mej                  – log10 ejecta mass in solar masses
        log10_vej                  – log10 ejecta velocity (vmin) in units of c
        beta                       – velocity power-law index
        log10_kappa_r              – log10 opacity in cm^2/g
        log10_p0                   – log10 initial spin period in ms
        log10_bp                   – log10 polar B-field in 1e14 G
        mass_ns                    – neutron star mass in solar masses
        theta_pb                   – angle between spin and B-field in radians
        thermalisation_efficiency  – magnetar thermalisation efficiency
    """

    parameter_names = ["log10_mej", "log10_vej", "beta", "log10_kappa_r",
                       "log10_p0", "log10_bp", "mass_ns", "theta_pb",
                       "thermalisation_efficiency"]
    _n_shells = 200

    def __init__(self, filters, times=None, neutron_precursor=True,
                 pair_cascade=True, vmax=0.7, magnetar_heating='first_layer'):
        # magnetar_heating: 'first_layer' injects the magnetar power into
        # shell 0 only; any other value adds it to every evolved shell.
        self._neutron_precursor = neutron_precursor
        self._pair_cascade = pair_cascade
        self._vmax = vmax
        self._magnetar_heating = magnetar_heating
        if times is None:
            times = jnp.geomspace(0.1, 30.0, 100)
        _validate_times(times)
        super().__init__(filters, times)

    def compute_log10_lbol_rphot(self, x, t_days):
        """Return ``(log10_L, log10_R)`` evaluated on ``t_days``.

        ``log10_L`` is the log10 bolometric luminosity in erg/s (after the
        optional pair-cascade suppression), floored at 0. ``log10_R`` is the
        log10 photospheric radius in cm, floored at 0. The ODE takes
        len(t_days)-1 steps; the final sample repeats the previous one.

        Args:
            x: mapping with the keys listed in ``parameter_names``.
            t_days: time grid in days (converted to seconds internally).
        """
        mej = jnp.power(10.0, x["log10_mej"])  # solar masses
        vej = jnp.power(10.0, x["log10_vej"])  # fraction of c
        beta_val = x["beta"]
        kappa_r = jnp.power(10.0, x["log10_kappa_r"])
        th_eff = x["thermalisation_efficiency"]
        mass_len = self._n_shells  # 200
        t_sec = t_days * days_to_seconds
        t_d = t_days

        # Barnes+16 thermalisation for r-process
        av, bv, dv = _barnes16_thermalisation_coefficients(mej, vej)
        e_th = _barnes16_e_th(t_d, av, bv, dv)

        # Heating rate (dual regime)
        t0 = 1.3
        sig = 0.11
        edotr_late = 2.1e10 * e_th * t_d ** (-1.3)
        edotr_early = 4.0e18 * jnp.power(
            0.5 - jnp.arctan((t_sec - t0) / sig) / jnp.pi, 1.3) * e_th
        edotr_base = jnp.where(t_sec > t0, edotr_late, edotr_early)

        # Shell layout with mass normalization (matching Redback)
        vmin = vej
        vmax_val = self._vmax
        vel = jnp.linspace(vmin, vmax_val, mass_len)
        m_array = mej * (vel / vmin) ** (-beta_val)
        total_mass = jnp.sum(m_array)
        m_array = m_array * (mej / total_mass)
        v_m_init = vel * c_cgs
        dm = jnp.abs(jnp.diff(m_array))

        # Magnetar luminosity in log10 (never materialized in linear)
        log10_L_mag = _magnetar_luminosity(
            t_sec, x["log10_p0"], x["log10_bp"], x["mass_ns"], x["theta_pb"])

        # Energy normalization via log10_E_scale (never materialized as 10^x).
        # Redback's energy_v is in erg. We define ene_n = energy_v / E_scale.
        log10_E_scale = jnp.maximum(log10_L_mag[0], 30.0)
        # Precompute msun / E_scale; since log10_E_scale >= 30 this is at most
        # 10^(LOG10_MSUN - 30), well within float32 range.
        msun_per_E = jnp.power(10.0, _LOG10_MSUN - log10_E_scale)
        # Precompute E_scale / m0 for velocity evolution (float32-safe)
        # E_over_m0 = 10^(log10_E_scale - LOG10_MSUN - log10_mej)
        E_over_m0 = jnp.power(10.0, log10_E_scale - _LOG10_MSUN - x["log10_mej"])
        # Mass power-law exponent (precomputed, constant)
        mass_power = (m_array / mej) ** (-1.0 / beta_val)

        # Neutron precursor
        neutron_precursor = self._neutron_precursor
        if neutron_precursor:
            Ye = _electron_fraction_from_kappa(kappa_r)
            neutron_mass = 1e-8 * msun_cgs
            Xn0 = 1.0 - 2.0 * Ye * 2.0 * jnp.arctan(
                neutron_mass / m_array / msun_cgs) / jnp.pi
            Xr = 1.0 - Xn0

        # Initial conditions (normalized via log10)
        E0_raw = 0.5 * m_array * v_m_init ** 2  # Msun*(cm/s)^2, safe float32
        E0_n = jnp.power(10.0, jnp.log10(jnp.maximum(E0_raw, 1e-30)) - log10_E_scale)
        # Initial kinetic energy (normalized): ek_n = 0.5*m0*v0^2 / E_scale
        log10_ek_0 = (jnp.log10(0.5) + x["log10_mej"] + _LOG10_MSUN
                      + 2.0 * (x["log10_vej"] + _LOG10_CCGS))
        ek_n_0 = jnp.power(10.0, log10_ek_0 - log10_E_scale)
        dt_arr = jnp.diff(t_sec)
        first_layer = (self._magnetar_heating == 'first_layer')

        def _scan_step(state, inputs):
            ene_n, ek_n = state
            t_i, dt_i, edotr_i, log10_lsd_i = inputs
            # Velocity evolution (matching Redback lines 682-689):
            #   kinetic_energy += energy_v[0]/t * dt
            #   v0 = sqrt(2*KE/m0)
            ek_n = ek_n + (ene_n[0] / t_i) * dt_i
            v0_sq = 2.0 * ek_n * E_over_m0
            v0_new = jnp.sqrt(jnp.maximum(v0_sq, 0.0))
            v_m_t = jnp.minimum(v0_new * mass_power, c_cgs)
            # Magnetar heating (normalized via log10)
            log10_qdot_n = (jnp.log10(jnp.maximum(th_eff, 1e-30))
                            + log10_lsd_i - log10_E_scale)
            qdot_n = jnp.power(10.0, jnp.clip(log10_qdot_n, -30.0, 30.0))
            if neutron_precursor:
                Xn_t = Xn0 * jnp.exp(-t_i / 900.0)
                # NOTE(review): quadratic in Xn_t (MetzgerModel uses linear);
                # confirm this matches Redback.
                edotn = 3.2e14 * Xn_t * Xn_t
                kappa_n = 0.4 * (1.0 - Xn_t - Xr)
                kappa_total = kappa_n + kappa_r * Xr
            else:
                edotn = jnp.zeros(mass_len)
                kappa_total = kappa_r * jnp.ones(mass_len)
            # Heating terms normalized: edotr * dm * msun / E_scale
            edotr_dm_n = edotr_i * dm * msun_per_E
            edotn_dm_n = edotn[:-1] * dm * msun_per_E
            # Diffusion timescale (using evolved v_m_t)
            td_v = (kappa_total[:-1] * m_array[:-1] * msun_cgs * 3.0) / (
                4.0 * jnp.pi * v_m_t[:-1] * c_cgs * t_i * beta_val)
            lum_n = ene_n[:-1] / (td_v + t_i * (v_m_t[:-1] / c_cgs))
            # Explicit Euler step; the outermost shell is carried unchanged.
            if first_layer:
                ene_0_new = ene_n[0] + (qdot_n + edotr_dm_n[0] + edotn_dm_n[0]
                                        - ene_n[0] / t_i - lum_n[0]) * dt_i
                ene_mid_new = ene_n[1:-1] + (edotr_dm_n[1:] + edotn_dm_n[1:]
                                             - ene_n[1:-1] / t_i
                                             - lum_n[1:]) * dt_i
                ene_n_new = jnp.concatenate([
                    ene_0_new[None], ene_mid_new, ene_n[-1:]])
            else:
                ene_n_new = jnp.concatenate([
                    ene_n[:-1] + (qdot_n + edotr_dm_n + edotn_dm_n
                                  - ene_n[:-1] / t_i - lum_n) * dt_i,
                    ene_n[-1:],
                ])
            ene_n_new = jnp.maximum(ene_n_new, 0.0)
            L_n_total = jnp.sum(lum_n)
            tau = m_array[:-1] * msun_cgs * kappa_total[:-1] / (
                4.0 * jnp.pi * (t_i * v_m_t[:-1]) ** 2)
            tau_full = jnp.concatenate([tau, tau[-1:]])
            # Photosphere index: shell whose optical depth is closest to 1
            pig = jnp.argmin(jnp.abs(tau_full - 1.0))
            R_ph = v_m_t[pig] * t_i
            return (ene_n_new, ek_n), (L_n_total, R_ph)

        # Keep the final kinetic-energy state: the pair-cascade correction
        # needs the evolved ejecta velocity.
        (_, ek_n_final), (L_n_arr, R_arr) = jax.lax.scan(
            _scan_step, (E0_n, ek_n_0),
            (t_sec[:-1], dt_arr, edotr_base[:-1], log10_L_mag[:-1]))
        # n-1 steps produce n-1 outputs; repeat the last to match t_days
        L_n_arr = jnp.concatenate([L_n_arr, L_n_arr[-1:]])
        R_arr = jnp.concatenate([R_arr, R_arr[-1:]])

        # Convert: L_erg_s = L_n * E_scale
        log10_L = jnp.log10(jnp.maximum(jnp.abs(L_n_arr), 1e-30)) + log10_E_scale

        # Pair cascade (matching Redback) — uses v0 from the last ODE step,
        # recovered exactly as inside the scan: v0 = sqrt(2*KE/m0), here in
        # units of c.
        if self._pair_cascade:
            ejecta_albedo = 0.5
            pair_cascade_fraction = 0.01
            v0_final_c = jnp.sqrt(
                jnp.maximum(2.0 * ek_n_final * E_over_m0, 0.0)) / c_cgs
            log10_tlife = (jnp.log10(0.6 / (1.0 - ejecta_albedo))
                           + 0.5 * jnp.log10(pair_cascade_fraction / 0.1)
                           + 0.5 * (log10_L_mag - 45.0)
                           + 0.5 * jnp.log10(jnp.maximum(v0_final_c, 1e-30) / 0.3)
                           - 0.5 * jnp.log10(jnp.maximum(t_d, 1e-10)))
            tlife_t = jnp.power(10.0, jnp.clip(log10_tlife, -20.0, 20.0))
            log10_L = log10_L - jnp.log10(1.0 + tlife_t)

        log10_L = jnp.maximum(log10_L, 0.0)
        log10_R = jnp.log10(jnp.maximum(R_arr, 1.0))
        return log10_L, log10_R