# Source code for fiesta.inference.analytical_models.base

"""Base classes, constants, and shared helpers for analytical light-curve models.

Each model is fully JIT-compilable and differentiable so that flowMC's MALA
sampler can compute ``jax.grad`` through the likelihood.  The models follow
the same ``predict()`` contract as the surrogate models:

    (source_frame_times, {filter_name: apparent_mag_array})

This makes them drop-in replacements inside ``CombinedSurrogate`` and
``EMLikelihood``.

All internal physics computations use log10 space to avoid float32 overflow
(e.g. explosion energies ~1e49 erg exceed float32 max ~3.4e38).
"""

from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

from fiesta import filters as fiesta_filters
from fiesta.conversions import mag_app_from_mag_abs
from fiesta.constants import (
    c_cgs, msun_cgs, sigSB, kb, Rsun_cgs, days_to_seconds,
    pc_to_cm, h_erg_s, c, au_cgs, km_cgs,
)
from fiesta.logging import logger

# ---------------------------------------------------------------------------
# Pre-computed log10 constants (avoids recomputation inside JIT)
# ---------------------------------------------------------------------------

# Evaluated once at import time; inside jitted code these are plain constants.
_LOG10E = jnp.log10(jnp.e)           # log10(e) for converting ln -> log10
_LN10 = jnp.log(10.0)                 # ln(10) for converting log10 -> ln
_LOG10_MSUN = jnp.log10(msun_cgs)    # ~33.30
_LOG10_RSUN = jnp.log10(Rsun_cgs)    # ~10.84
_LOG10_CCGS = jnp.log10(c_cgs)       # ~10.48
_LOG10_SIGMASB = jnp.log10(sigSB)    # ~-4.25
_LOG10_KB = jnp.log10(kb)            # ~-15.86
_LOG10_H_ERG_S = jnp.log10(h_erg_s)  # ~-26.18
_LOG10_4PI = jnp.log10(4.0 * jnp.pi)  # ~1.10
_LOG10_PI = jnp.log10(jnp.pi)          # ~0.50
_LOG10_D10PC = jnp.log10(10.0 * pc_to_cm)  # log10 of 10 pc in cm (absolute-magnitude distance)
_LOG10_DAYS2SEC = jnp.log10(days_to_seconds)  # log10 of seconds per day
_LOG10_AU_CGS = jnp.log10(au_cgs)  # log10 of 1 AU in cm
_LOG10_KM_CGS = jnp.log10(km_cgs)  # log10 of 1 km in cm


# ---------------------------------------------------------------------------
# Barnes+16 thermalisation (interpolated, matching Redback exactly)
# ---------------------------------------------------------------------------

# Grid axes.  _barnes16_thermalisation_coefficients interpolates with
# grid_x=_BARNES_M (ejecta mass, solar masses) and grid_y=_BARNES_V (ejecta
# velocity, units of c), so every table below is indexed [mass, velocity]
# and has shape (len(_BARNES_M), len(_BARNES_V)) = (5, 4).
_BARNES_V = jnp.array([0.1, 0.2, 0.3, 0.4])
_BARNES_M = jnp.array([1.e-3, 5.e-3, 1.e-2, 5.e-2, 1.e-1])

# Coefficient a: rate (per day) of the exp(-a t) term in _barnes16_e_th.
_BARNES_A = jnp.array([[2.01, 4.52, 8.16, 16.3],
                        [0.81, 1.9, 3.2, 5.0],
                        [0.56, 1.31, 2.19, 3.0],
                        [0.27, 0.55, 0.95, 2.0],
                        [0.20, 0.39, 0.65, 0.9]])

# Coefficient b: prefactor of the 2 b t^d term in _barnes16_e_th.
_BARNES_B = jnp.array([[0.28, 0.62, 1.19, 2.4],
                        [0.19, 0.28, 0.45, 0.65],
                        [0.17, 0.21, 0.31, 0.45],
                        [0.10, 0.13, 0.15, 0.17],
                        [0.06, 0.11, 0.12, 0.12]])

# Coefficient d: power-law index of t in the 2 b t^d term.
_BARNES_D = jnp.array([[1.12, 1.39, 1.52, 1.65],
                        [0.86, 1.21, 1.39, 1.5],
                        [0.74, 1.13, 1.32, 1.4],
                        [0.6, 0.9, 1.13, 1.25],
                        [0.63, 0.79, 1.04, 1.5]])


def _bilinear_interp(grid_x, grid_y, values, x, y):
    """JAX-compatible bilinear interpolation on a 2-D rectilinear grid.

    Query points outside the grid are *clamped* to the nearest edge, so the
    result is constant (and its gradient zero) beyond the tabulated range.
    This is NOT what scipy's ``RegularGridInterpolator(..., fill_value=None)``
    does -- that extrapolates linearly outside the grid -- so the two agree
    only for in-range inputs.

    Parameters
    ----------
    grid_x : 1-D array (Nx,) -- strictly increasing, Nx >= 2.
    grid_y : 1-D array (Ny,) -- strictly increasing, Ny >= 2.
    values : 2-D array (Nx, Ny) -- ``values[i, j]`` is the value at
        ``(grid_x[i], grid_y[j])``.
    x, y : scalars -- query coordinates.

    Returns
    -------
    The interpolated value.
    """
    # Clamp first so out-of-range queries reuse the edge values.
    x = jnp.clip(x, grid_x[0], grid_x[-1])
    y = jnp.clip(y, grid_y[0], grid_y[-1])

    # Index of the left cell edge; clipping to len-2 keeps a query sitting
    # exactly on the last grid point inside the final cell.
    ix = jnp.searchsorted(grid_x, x, side='right') - 1
    ix = jnp.clip(ix, 0, len(grid_x) - 2)
    iy = jnp.searchsorted(grid_y, y, side='right') - 1
    iy = jnp.clip(iy, 0, len(grid_y) - 2)

    x0 = grid_x[ix]
    x1 = grid_x[ix + 1]
    y0 = grid_y[iy]
    y1 = grid_y[iy + 1]

    # Fractional position inside the cell, in [0, 1].
    tx = (x - x0) / (x1 - x0)
    ty = (y - y0) / (y1 - y0)

    v00 = values[ix, iy]
    v10 = values[ix + 1, iy]
    v01 = values[ix, iy + 1]
    v11 = values[ix + 1, iy + 1]

    return (v00 * (1 - tx) * (1 - ty) + v10 * tx * (1 - ty)
            + v01 * (1 - tx) * ty + v11 * tx * ty)


def _barnes16_thermalisation_coefficients(mej_solar, vej_c):
    """Interpolate the Barnes+16 fit coefficients at (mass, velocity).

    ``mej_solar`` is the ejecta mass in solar masses and ``vej_c`` the ejecta
    velocity in units of c.  Returns the tuple ``(a, b, d)`` read bilinearly
    from the module-level Barnes tables, in that order.
    """
    return tuple(
        _bilinear_interp(_BARNES_M, _BARNES_V, table, mej_solar, vej_c)
        for table in (_BARNES_A, _BARNES_B, _BARNES_D)
    )


def _barnes16_e_th(t_days, av, bv, dv):
    """Barnes+16 thermalisation efficiency (Metzger 2017, Eq. 25).

    e_th(t) = 0.36 * [exp(-a t) + ln(1 + 2 b t^d) / (2 b t^d)], t in days.
    Time is floored at 1e-12 d, the exponent clipped to [-80, 0] and the
    denominator floored at 1e-30 so the expression stays finite as t -> 0.
    """
    t = jnp.maximum(t_days, 1e-12)
    two_b_td = 2.0 * bv * t ** dv
    decay_term = jnp.exp(jnp.clip(-av * t, -80.0, 0.0))
    log_term = jnp.log1p(two_b_td) / jnp.maximum(two_b_td, 1e-30)
    return 0.36 * (decay_term + log_term)


# Tanaka+19 kappa-to-Ye table (matching Redback electron_fraction_from_kappa)
# Gray opacity (presumably cm^2/g -- confirm against Tanaka+19) paired
# entry-by-entry with electron fraction Ye.  Kappa is stored in DECREASING
# order, so both arrays must be reversed before use with jnp.interp.
_TANAKA_KAPPA = jnp.array([35.0, 32.2, 22.3, 5.60, 5.36, 3.30, 0.96, 0.5])
_TANAKA_YE = jnp.array([0.10, 0.15, 0.2, 0.25, 0.30, 0.35, 0.4, 0.5])


def _electron_fraction_from_kappa(kappa):
    """Map gray opacity to electron fraction via the Tanaka+19 table (JAX).

    Piecewise-linear in kappa; inputs beyond the table ends return the end
    values (``jnp.interp`` clamping).
    """
    # jnp.interp needs ascending abscissae, but the table stores kappa
    # in decreasing order.
    kappa_ascending = jnp.flip(_TANAKA_KAPPA)
    ye_matching = jnp.flip(_TANAKA_YE)
    return jnp.interp(kappa, kappa_ascending, ye_matching)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _log10_blackbody_mJy_at_10pc(T, log10_R, nus, x_max=80.0):
    """log10 of Planck spectral flux density at 10 pc in mJy.

    Works entirely in log10 space for the radius to avoid float32 overflow.

    Parameters
    ----------
    T : scalar  – temperature in K (linear, typically 1e3–1e5, safe for float32)
    log10_R : scalar  – log10 of photospheric radius in cm
    nus : 1-D array  – frequencies in Hz
    x_max : float, optional – upper clip on x = h nu / (k T).  The default
        of 80 reproduces the historical behaviour.  Beyond the clip the Wien
        tail is frozen at its x = x_max value, so the returned flux stops
        falling with decreasing T and its T-gradient vanishes there.
        log10(expm1(x)) is evaluated in an overflow-free form, so a larger
        x_max is safe inside this function.  Callers that exponentiate the
        result must then guard against float32 underflow themselves.

    Returns
    -------
    log10_mJy : 1-D array (n_nus,) – log10 of spectral flux density in mJy
    """
    # Planck function  B_nu = 2 h nu^3 / c^2  /  (exp(h nu / kT) - 1)
    x = h_erg_s * nus / (kb * T)
    x = jnp.clip(x, 0.0, x_max)

    # log10(expm1(x)) = [x + ln(1 - exp(-x))] * log10(e).  Unlike expm1(x),
    # which overflows float32 just above x ~ 88, this form never overflows.
    # It also stays accurate for small x because -expm1(-x) evaluates
    # 1 - exp(-x) without cancellation.
    log10_expm1_x = (x + jnp.log(-jnp.expm1(-x))) * _LOG10E

    log10_Bnu = (_LOG10_H_ERG_S + jnp.log10(2.0) + 3.0 * jnp.log10(nus)
                 - 2.0 * _LOG10_CCGS - log10_expm1_x)

    # F_nu = pi * (R / d)^2 * B_nu   (erg/cm^2/s/Hz at 10 pc)
    log10_Fnu = _LOG10_PI + 2.0 * log10_R - 2.0 * _LOG10_D10PC + log10_Bnu

    # -> mJy:  1 Jy = 1e-23 erg/cm^2/s/Hz, 1 mJy = 1e-3 Jy
    log10_mJy = log10_Fnu + 23.0 + 3.0
    return log10_mJy


def _gauss_legendre_nodes_weights(n=32, a=0.0, b=1.0):
    """Return ``n``-point Gauss-Legendre nodes and weights on ``[a, b]``.

    With the defaults the rule lives on [0, 1]: nodes lie strictly inside the
    interval and the weights sum to 1.  The rule integrates polynomials of
    degree <= 2n - 1 exactly.  ``n`` must be a positive integer.

    Returns
    -------
    (nodes, weights) : pair of 1-D JAX arrays of length ``n``.
    """
    nodes, weights = np.polynomial.legendre.leggauss(n)
    half_width = 0.5 * (b - a)
    # Affine map from NumPy's canonical interval [-1, 1] onto [a, b].
    nodes = half_width * (nodes + 1.0) + a
    weights = weights * half_width
    return jnp.array(nodes), jnp.array(weights)


def _magnetar_luminosity(t_sec, log10_p0, log10_bp, mass_ns, theta_pb):
    """Magnetar spin-down luminosity (dipole radiation).

    All intermediate quantities computed in log10 space to avoid float32
    overflow (E_rot ~ 1e52 exceeds float32 max).

    Parameters
    ----------
    t_sec : array – time in seconds
    log10_p0 : scalar – log10 of initial spin period in milliseconds
    log10_bp : scalar – log10 of polar B-field in units of 1e14 G
    mass_ns : scalar – neutron star mass in solar masses
    theta_pb : scalar – angle between spin and B-field axes in radians.
        Only sin(theta)^2 enters, so any real angle is accepted; |sin| is
        floored at 1e-6 to keep the spin-down time finite.

    Returns
    -------
    log10_L_mag : array – log10 of magnetar luminosity in erg/s
    """
    # log10(E_rot) = log10(2.6e52) + 1.5*log10(M/1.4) - 2*log10(p0_ms)
    log10_mass_ratio = jnp.log10(mass_ns / 1.4)
    log10_E_rot = 52.0 + jnp.log10(2.6) + 1.5 * log10_mass_ratio - 2.0 * log10_p0

    # tau_p = 1.3e5 * Bp14^{-2} * P_ms^2 * (M/1.4)^{3/2} * sin(theta)^{-2}
    # sin^-2 is even in sin(theta), so take |sin| before flooring it; a bare
    # maximum(sin, 1e-6) would send every angle with sin < 0 to the floor.
    sin_abs = jnp.maximum(jnp.abs(jnp.sin(theta_pb)), 1e-6)
    log10_tau_p = (5.0 + jnp.log10(1.3)
                   - 2.0 * log10_bp + 2.0 * log10_p0
                   + 1.5 * log10_mass_ratio
                   - 2.0 * jnp.log10(sin_abs))
    tau_p = jnp.power(10.0, log10_tau_p)

    # L_mag = E_rot / tau_p / (1 + t/tau_p)^2
    # log1p keeps full precision while t << tau_p.
    log10_L_mag = (log10_E_rot - log10_tau_p
                   - 2.0 * jnp.log1p(t_sec / tau_p) / jnp.log(10.0))
    return log10_L_mag


def _compute_diffusion_constants(log10_kappa, log10_kappa_gamma,
                                 log10_mej_g, log10_vej_kms):
    """Arnett (1982) diffusion timescale and gamma-ray trapping coefficient.

    Everything is evaluated in log10 space.  Opacities are passed as log10
    (presumably cm^2/g -- confirm with callers), ejecta mass as log10 grams
    and ejecta velocity as log10 km/s.  The velocity is converted to cm/s
    internally.

    Returns
    -------
    log10_td_days : log10 of t_d = sqrt(2 kappa M / (13.7 c v)), in days
    log10_A_trap_days2 : log10 of A = 3 kappa_gamma M / (4 pi v^2), in days^2
    """
    # Diffusion time in seconds (v in cm/s = km/s * km_cgs).
    log10_td_sec = 0.5 * (jnp.log10(2.0) + log10_kappa + log10_mej_g
                          - jnp.log10(13.7) - _LOG10_CCGS
                          - log10_vej_kms - _LOG10_KM_CGS)

    # Trapping coefficient in s^2.
    log10_vej_cms = log10_vej_kms + _LOG10_KM_CGS
    log10_A_sec2 = (jnp.log10(3.0) + log10_kappa_gamma + log10_mej_g
                    - _LOG10_4PI - 2.0 * log10_vej_cms)

    # Seconds -> days (once for t_d, squared for A).
    return (log10_td_sec - _LOG10_DAYS2SEC,
            log10_A_sec2 - 2.0 * _LOG10_DAYS2SEC)


def _arnett_diffusion_ode(log10_L_engine, t_days_grid,
                          log10_td_days, log10_A_trap_days2,
                          log10_L_scale):
    """Arnett (1982) diffusion integral via a first-order ODE (forward Euler).

    Solves dV/dt = L_engine(t)*t - 2*t/td^2 * V  using ``jax.lax.scan``,
    then L_obs = 2/td^2 * V * (1 - exp(-A_trap/t^2)).

    All luminosity values are normalized by ``10^log10_L_scale`` to keep
    float32-safe O(1) intermediates.

    Parameters
    ----------
    log10_L_engine : 1-D array – log10 of engine luminosity (erg/s)
    t_days_grid : 1-D array – time grid in days (must be uniformly spaced;
        the step ``dt`` is taken from its first two entries only)
    log10_td_days : scalar – log10 of diffusion timescale in days
    log10_A_trap_days2 : scalar – log10 of trapping coefficient in days^2
    log10_L_scale : scalar – normalization scale (log10 erg/s)

    Returns
    -------
    log10_L_obs : 1-D array – log10 of observed luminosity (erg/s)

    NOTE(review): the value emitted at index i is computed from V *after*
    the Euler update that uses (L[i], t[i]), i.e. it approximates V at
    t[i] + dt.  V0 already holds one L[0]*t[0]*dt increment, which the
    first scan step adds again.  ``_arnett_diffusion_integral`` evaluates
    at t[i] itself.  Confirm whether this one-step offset is intended.
    """
    td = jnp.power(10.0, log10_td_days)
    A_trap = jnp.power(10.0, log10_A_trap_days2)
    dt = t_days_grid[1] - t_days_grid[0]

    # Normalized engine luminosity
    L_engine_n = jnp.power(10.0, log10_L_engine - log10_L_scale)

    def _scan_step(V_n, inputs):
        # One forward-Euler step of dV/dt = L*t - 2t/td^2 * V at time t_i.
        L_n_i, t_i = inputs
        source = L_n_i * t_i
        loss = 2.0 * t_i / td**2 * V_n
        # Clamped at 0 so an over-large loss term cannot make V negative.
        V_n_new = jnp.maximum(V_n + (source - loss) * dt, 0.0)
        # Observed luminosity (normalized)
        # Trapping factor 1 - exp(-A_trap / t^2), via expm1 for small arguments.
        trap_factor = -jnp.expm1(-A_trap / jnp.maximum(t_i**2, 1e-30))
        L_obs_n = 2.0 / td**2 * V_n_new * trap_factor
        return V_n_new, L_obs_n

    V0 = L_engine_n[0] * t_days_grid[0] * dt
    _, L_obs_n = jax.lax.scan(_scan_step, V0, (L_engine_n, t_days_grid))

    # Floor before log10 so zero luminosity maps to 1e-30 (in scaled units).
    log10_L_obs = jnp.log10(jnp.maximum(L_obs_n, 1e-30)) + log10_L_scale
    return log10_L_obs


def _build_log_mirror_quad(n_half=50, minimum_log_spacing=-3):
    """Log-spaced quadrature nodes on [0, 1], mirrored about 0.5.

    Follows Redback's ``Diffusion.convert_input_luminosity`` layout: take
    ``n_half`` points log-spaced from ``10**minimum_log_spacing`` to 1, add
    their reflections ``1 - x``, then sort and drop duplicates.  The result
    includes both 0 and 1 and is returned as a 1-D JAX array.
    """
    left = np.logspace(minimum_log_spacing, 0, n_half)
    mirrored = 1.0 - left
    nodes = np.unique(np.hstack([left, mirrored]))
    return jnp.asarray(nodes)


# Pre-computed quadrature nodes (module-level, outside JIT)
# Built once at import time with NumPy so the jitted integrators only see
# constant arrays.  Each holds at most 2 * n_half unique nodes on [0, 1].
_ARNETT_QUAD_NODES = _build_log_mirror_quad(n_half=50)   # default for _arnett_diffusion_integral
_CSM_QUAD_NODES = _build_log_mirror_quad(n_half=1500)    # default for _csm_diffusion_integral


def _arnett_diffusion_integral(log10_L_engine, t_days_grid,
                               log10_td_days, log10_A_trap_days2,
                               log10_L_scale, xm_quad=_ARNETT_QUAD_NODES):
    """Arnett (1982) diffusion evaluated as a trapezoidal integral (Redback-matched).

    At every grid time t_e:

        L_obs(t_e) = (2/td^2) * (1 - exp(-A_trap/t_e^2))
                     * integral_0^{t_e} L(t) * t * exp((t^2 - t_e^2)/td^2) dt

    The integral uses ``jnp.trapezoid`` on the nodes ``t_e * xm_quad``.  The
    engine luminosity at those nodes comes from ``jnp.interp`` on the grid.

    Parameters
    ----------
    log10_L_engine : 1-D array – log10 of engine luminosity (erg/s)
    t_days_grid : 1-D array – time grid in days
    log10_td_days : scalar – log10 of diffusion timescale in days
    log10_A_trap_days2 : scalar – log10 of trapping coefficient in days^2
    log10_L_scale : scalar – normalization scale (log10 erg/s)
    xm_quad : 1-D array – quadrature nodes on [0, 1]

    Returns
    -------
    log10_L_obs : 1-D array – log10 of observed luminosity (erg/s)
    """
    td = jnp.power(10.0, log10_td_days)
    td_sq = td**2
    A_trap = jnp.power(10.0, log10_A_trap_days2)
    prefactor = 2.0 / td_sq

    # Work with O(1) numbers: divide luminosities by 10**log10_L_scale.
    lum_scaled = jnp.power(10.0, log10_L_engine - log10_L_scale)

    def _observed_at(t_eval):
        tau = t_eval * xm_quad
        lum_tau = jnp.interp(tau, t_days_grid, lum_scaled)
        weight = jnp.exp(jnp.clip((tau**2 - t_eval**2) / td_sq, -80.0, 0.0))
        integral = jnp.trapezoid(lum_tau * tau * weight, tau)
        trapping = -jnp.expm1(-A_trap / jnp.maximum(t_eval**2, 1e-30))
        return jnp.maximum(prefactor * integral * trapping, 0.0)

    lum_obs_scaled = jax.vmap(_observed_at)(t_days_grid)
    return jnp.log10(jnp.maximum(lum_obs_scaled, 1e-30)) + log10_L_scale


def _csm_diffusion_integral(log10_L_input, t_days_grid,
                            t0_csm_days, log10_L_scale,
                            xm_quad=_CSM_QUAD_NODES):
    """CSM diffusion kernel evaluated as a trapezoidal integral (Redback-matched).

    At every grid time t_e:

        L_obs(t_e) = (1/t0) * integral_0^{t_e} L(t) * exp((t - t_e)/t0) dt

    Parameters
    ----------
    log10_L_input : 1-D array – log10 of input luminosity (erg/s)
    t_days_grid : 1-D array – time grid in days
    t0_csm_days : scalar – CSM diffusion timescale in days
    log10_L_scale : scalar – normalization scale (log10 erg/s)
    xm_quad : 1-D array – quadrature nodes on [0, 1]

    Returns
    -------
    log10_L_obs : 1-D array – log10 of observed luminosity (erg/s)
    """
    lum_scaled = jnp.power(10.0, log10_L_input - log10_L_scale)

    def _diffused_at(t_eval):
        tau = t_eval * xm_quad
        lum_tau = jnp.interp(tau, t_days_grid, lum_scaled)
        # For t_eval >= 0 and t0 > 0 the exponent (tau - t_eval)/t0 is <= 0,
        # so exp() cannot overflow.
        kernel = jnp.exp((tau - t_eval) / t0_csm_days)
        integral = jnp.trapezoid(lum_tau * kernel, tau)
        return jnp.maximum(integral / t0_csm_days, 0.0)

    lum_diff_scaled = jax.vmap(_diffused_at)(t_days_grid)
    return jnp.log10(jnp.maximum(lum_diff_scaled, 1e-30)) + log10_L_scale


# ---------------------------------------------------------------------------
# Base class
# ---------------------------------------------------------------------------

class AnalyticalModel:
    """Base class for analytical (non-surrogate) light-curve models.

    Subclasses must implement ``compute_log10_lbol_rphot(self, x, t_days)``,
    which returns ``(log10_Lbol, log10_Rphot)``: log10 of the bolometric
    luminosity in erg/s and of the photospheric radius in cm.  ``predict``
    turns these into blackbody apparent magnitudes for every loaded filter.

    JIT caching: the traced body of ``predict`` takes ``self`` as a static
    argument, which JAX caches by hash/equality (here, object identity).
    Without extra care, mutating the model after the first call would be
    silently ignored.  Assigning ``times`` or ``temperature_floor``, or
    calling ``add_filter``, therefore bumps an internal trace version.  That
    version is part of the JIT cache key, so the next ``predict`` re-traces
    with the new state.
    """

    parameter_names: list[str]
    filters: list[str]                       # filter names, insertion order
    Filters: list["fiesta_filters.Filter"]   # matching Filter objects, same order
    _nus: Array | None                       # shared log-spaced frequency grid (Hz)

    def __init__(self, filters: list[str], times: Array | None = None,
                 temperature_floor: float | None = None):
        self._trace_version = 0
        self.filters = []
        self.Filters = []
        self._nus = None
        self.times = None
        self.temperature_floor = temperature_floor
        self.add_filter(filters)
        if times is not None:
            self.times = jnp.asarray(times)

    # -- JIT cache invalidation ----------------------------------------------

    def _invalidate_trace(self) -> None:
        """Make the next ``predict`` call re-trace instead of reusing a stale trace."""
        self._trace_version = getattr(self, "_trace_version", 0) + 1

    @property
    def times(self) -> Array | None:
        """Source-frame evaluation times in days (``None`` until set)."""
        return getattr(self, "_times", None)

    @times.setter
    def times(self, value) -> None:
        self._times = value
        self._invalidate_trace()

    @property
    def temperature_floor(self) -> float | None:
        """Minimum photospheric temperature in K, or ``None`` for no floor."""
        return getattr(self, "_temperature_floor", None)

    @temperature_floor.setter
    def temperature_floor(self, value: float | None) -> None:
        self._temperature_floor = value
        self._invalidate_trace()

    # -- filter management ---------------------------------------------------

    def add_filter(self, filters) -> None:
        """Load one or more filters given as names or ``Filter`` objects.

        Filters whose name is already loaded are skipped.  All entries are
        resolved before anything is added, so an invalid entry leaves the
        model unchanged.  The shared frequency grid is rebuilt afterwards.

        Raises
        ------
        TypeError
            If an entry is neither a string nor a ``Filter``.
        """
        if isinstance(filters, (str, fiesta_filters.Filter)):
            filters = [filters]

        resolved = []
        for filt in filters:
            if isinstance(filt, str):
                resolved.append(fiesta_filters.Filter(filt))
            elif isinstance(filt, fiesta_filters.Filter):
                resolved.append(filt)
            else:
                raise TypeError("Filter must be a name string or Filter object.")

        for F in resolved:
            if F.name not in self.filters:
                self.filters.append(F.name)
                self.Filters.append(F)

        self._build_nu_grid()
        self._invalidate_trace()

    def _build_nu_grid(self) -> None:
        """Build a 100-point log-spaced frequency grid spanning all loaded filters.

        The span is padded by 5% on each side.  Each ``Filter.nus`` is
        assumed to be sorted ascending (its first/last entries are used as
        bounds) -- TODO confirm against fiesta.filters.
        """
        if not self.Filters:
            return
        nu_min = min(F.nus[0] for F in self.Filters)
        nu_max = max(F.nus[-1] for F in self.Filters)
        self._nus = jnp.logspace(jnp.log10(nu_min * 0.95), jnp.log10(nu_max * 1.05), 100)

    # -- physics (to be overridden) ------------------------------------------

    def compute_log10_lbol_rphot(self, x: dict[str, Array],
                                 t_days: Array) -> tuple[Array, Array]:
        """Return (log10_L_bol, log10_R_phot) arrays at each time in *t_days*.

        L_bol in erg/s, R_phot in cm.
        """
        raise NotImplementedError

    # -- predict (JIT-compiled) ----------------------------------------------

    def predict(self, x: dict[str, Array]) -> tuple[Array, dict[str, Array]]:
        """Return ``(t_days, {filter_name: apparent_mag})`` for parameters *x*.

        *x* must contain ``luminosity_distance`` (forwarded to
        ``mag_app_from_mag_abs``; presumably Mpc -- confirm).  It may contain
        ``redshift`` (default 0), plus whatever the subclass's
        ``compute_log10_lbol_rphot`` reads.  The returned times are the
        source-frame ``times``.

        Raises
        ------
        ValueError
            If ``times`` has not been set or no filter has been loaded.
        """
        if self.times is None:
            raise ValueError(
                "times must be set before calling predict(). "
                "Pass times= to the constructor or set model.times directly."
            )
        if self._nus is None:
            raise ValueError(
                "At least one filter must be added before calling predict()."
            )
        return self._predict_jit(x, getattr(self, "_trace_version", 0))

    @partial(jax.jit, static_argnums=(0, 2))
    def _predict_jit(self, x, trace_version):
        """Traced body of ``predict``; *trace_version* only keys the JIT cache."""
        del trace_version  # part of the cache key, unused in the computation

        t_days = self.times
        log10_L, log10_R = self.compute_log10_lbol_rphot(x, t_days)
        log10_T, log10_R = self._photosphere(log10_L, log10_R)
        T = jnp.power(10.0, log10_T)

        # Blackbody SED at each time: (n_times, n_nus) of log10(mJy).
        nus = self._nus
        log10_sed = jax.vmap(
            lambda T_i, log10_R_i: _log10_blackbody_mJy_at_10pc(T_i, log10_R_i, nus)
        )(T, log10_R)

        # Linear mJy, transposed to (n_nus, n_times) for Filter.get_mag.
        sed = jnp.power(10.0, log10_sed).T

        # Redshift: observed frequencies are nu_src / (1 + z) and f_nu gains a
        # factor (1 + z) from bandwidth compression.  Times stay source-frame.
        z = x.get("redshift", 0.0)
        nus_obs = nus / (1.0 + z)
        sed_obs = sed * (1.0 + z)

        dL = x["luminosity_distance"]
        mag_app = {
            F.name: mag_app_from_mag_abs(F.get_mag(sed_obs, nus_obs), dL)
            for F in self.Filters
        }
        return t_days, mag_app

    def _photosphere(self, log10_L, log10_R):
        """Return ``(log10_T, log10_R)`` of the emitting blackbody.

        T follows from Stefan-Boltzmann, L = 4 pi R^2 sigma T^4.  With a
        ``temperature_floor``, cooler points are raised to the floor and
        their radius is reduced so that L is conserved.  T is finally
        clamped to [1e2, 1e6] K.  That clamp does NOT adjust R, so L is not
        conserved where it is active.
        """
        log10_T = 0.25 * (log10_L - _LOG10_4PI - 2.0 * log10_R - _LOG10_SIGMASB)

        floor = self.temperature_floor
        if floor is not None:
            # Branch-free (jnp.where) so it traces once under JIT.
            log10_T_floor = jnp.log10(floor)
            below_floor = log10_T < log10_T_floor
            log10_T = jnp.where(below_floor, log10_T_floor, log10_T)
            # R such that L = 4*pi*R^2*sigma*T_floor^4.
            log10_R_eff = 0.5 * (log10_L - _LOG10_4PI - _LOG10_SIGMASB
                                 - 4.0 * log10_T_floor)
            log10_R = jnp.where(below_floor, log10_R_eff, log10_R)

        return jnp.clip(log10_T, 2.0, 6.0), log10_R