# Source code for fiesta.inference.analytical_models.salt3_models
"""SALT3 spectral-template supernova model via jax-bandflux.
Uses ``jax_supernovae`` (PyPI: ``jax-bandflux``) for JAX-native, JIT-compiled,
differentiable SALT3 light-curve evaluation. Unlike the physics-based models
that compute L_bol + R_phot -> blackbody SED, SALT3 uses spectral templates
(M0, M1, colour law) to compute per-band fluxes directly.
The ``jax_supernovae`` import is kept lazy to avoid loading heavy dependencies
for users who don't use SALT3.
"""
from functools import partial
import jax
import jax.numpy as jnp
from jaxtyping import Array
class SALT3Model:
    """SALT3 spectral-template model for Type Ia supernova light curves.

    Parameters
    ----------
    filters : str or list[str]
        Band names recognised by ``jax_supernovae.bandpasses`` (e.g.
        ``"ztfg"``, ``"ztfr"``, ``"bessellb"``). A single band may be given
        as a bare string; it is wrapped into a one-element list.
    times : Array
        Observer-frame times (days) at which to evaluate the model.
        Required despite the ``None`` default: omitting it raises
        ``ValueError``.
    redshift : float
        Source redshift (fixed, not sampled).

    Raises
    ------
    ValueError
        If ``filters`` is empty or ``times`` is ``None``. Both checks run
        before any optional dependency is imported.
    ModuleNotFoundError
        If the ``jax-bandflux`` package (module ``jax_supernovae``) cannot
        be imported.

    Notes
    -----
    Sampled parameters (passed via ``predict(x)``):

    * ``log10_x0`` -- log10 of the SALT3 amplitude parameter *x0*
    * ``x1`` -- SALT3 stretch
    * ``c`` -- SALT3 colour
    * ``t0`` -- time of B-band maximum (days, same frame as *times*)

    ``predict`` is jit-compiled with the instance as a *static* argument, so
    ``times``, ``redshift`` and the per-band data are baked into the compiled
    function the first time it is traced. Reassigning those attributes
    afterwards does not trigger a retrace; build a new ``SALT3Model``
    instead.
    """

    parameter_names: list[str]
    filters: list[str]
    times: Array

    # Annotations are quoted so the ``|`` unions are never evaluated at
    # definition time (keeps the signature importable on Python 3.9).
    def __init__(self, filters: "str | list[str]",
                 times: "Array | None" = None,
                 redshift: float = 0.0):
        # Validate the cheap arguments first: bad input should fail fast,
        # before the heavy import and the global bandpass registration.
        if isinstance(filters, str):
            filters = [filters]
        if not filters:
            raise ValueError("At least one filter must be provided")
        if times is None:
            raise ValueError(
                "times must be provided (observer-frame days array)"
            )

        # Lazy import to avoid loading heavy dependencies at module level
        try:
            from jax_supernovae.salt3 import (
                optimized_salt3_multiband_flux,
                precompute_bandflux_bridge,
            )
            from jax_supernovae.bandpasses import (
                get_bandpass,
                register_all_bandpasses,
            )
        except ModuleNotFoundError as exc:
            raise ModuleNotFoundError(
                "SALT3Model requires the 'jax-bandflux' package. "
                "Install it with: pip install jax-bandflux"
            ) from exc

        register_all_bandpasses()

        self.filters = list(filters)
        self.times = jnp.asarray(times)
        self.redshift = redshift
        self.parameter_names = ["log10_x0", "x1", "c", "t0"]

        # Pre-compute bridges (expensive, done once)
        self._bridges = tuple(
            precompute_bandflux_bridge(get_bandpass(f)) for f in self.filters
        )
        # Cache zpbandflux_ab per band for AB mag conversion
        self._zpbandflux_ab = jnp.array(
            [b['zpbandflux_ab'] for b in self._bridges]
        )
        # Store reference to the flux function
        self._multiband_flux = optimized_salt3_multiband_flux

    @partial(jax.jit, static_argnums=(0,))
    def predict(self, x: dict[str, Array]) -> tuple[Array, dict[str, Array]]:
        """Evaluate apparent AB magnitudes in every filter at ``self.times``.

        Parameters
        ----------
        x : dict[str, Array]
            Sampled parameters; must contain the keys ``"log10_x0"``,
            ``"x1"``, ``"c"`` and ``"t0"``.

        Returns
        -------
        t_days : Array
            The observer-frame times stored on the instance.
        mag_app : dict[str, Array]
            One magnitude array per entry of ``self.filters``, keyed by the
            filter name.
        """
        t_days = self.times  # observer-frame days

        # Build SALT3 params dict (z and t0 handled internally by the
        # flux function); the amplitude is sampled in log10 space.
        salt3_params = {
            "x0": jnp.power(10.0, x["log10_x0"]),
            "x1": x["x1"],
            "c": x["c"],
            "z": jnp.array(self.redshift),
            "t0": x["t0"],
        }

        # Compute fluxes: (n_times, n_bands) in photons/s/cm^2
        flux_matrix = self._multiband_flux(t_days, self._bridges, salt3_params)

        # Convert to apparent AB magnitudes. The 1e-30 floor keeps log10
        # finite for zero/negative model flux (such points come out at
        # magnitude 75).
        mag_app = {
            fname: -2.5 * jnp.log10(
                jnp.maximum(flux_matrix[:, i] / self._zpbandflux_ab[i], 1e-30)
            )
            for i, fname in enumerate(self.filters)
        }
        return t_days, mag_app