"""Phenomenological (flux-shape) light-curve models.
Reference:
Redback: https://github.com/nikhil-sarin/redback/blob/master/redback/transient_models/phenomenological_models.py
Boom: https://github.com/boom-astro/boom/blob/main/src/fitting/parametric.rs
"""
from functools import partial
import jax
import jax.numpy as jnp
from jaxtyping import Array
from fiesta.inference.analytical_models.base import (
AnalyticalModel,
_LOG10_4PI, _LOG10_SIGMASB,
)
class PhenomenologicalModel(AnalyticalModel):
    """Base class for phenomenological light-curve models.

    Unlike physics-based models that compute L_bol + R_phot and pass through
    a blackbody SED, phenomenological models compute a temporal shape function
    S(t) and convert directly to per-band apparent magnitudes.

    Subclasses must set:
        shape_parameter_names : list[str] -- shared temporal shape parameters
        has_baseline : bool -- whether the model has a baseline flux component
    and implement ``compute_shape(self, x, t_days) -> Array``.

    ``parameter_names`` is built as the shape parameters followed, for each
    filter, by ``amp_mag_{filter}`` and (when ``has_baseline``) by
    ``base_mag_{filter}``.
    """

    shape_parameter_names: list[str]
    has_baseline: bool = False

    # Lower bound applied before log10 so a zero/underflowed flux maps to a
    # large but finite magnitude instead of +inf.
    _FLUX_FLOOR = 1e-30

    def __init__(self, filters: list[str], times: Array = None):
        # Build filter list via parent; predict() later reads self.filters
        # and self.times, which the parent constructor is expected to set.
        super().__init__(filters, times)
        # Build parameter_names dynamically from shape params + per-band params
        self._build_parameter_names()

    def _build_parameter_names(self):
        """Rebuild ``self.parameter_names`` from shape and per-band names."""
        names = list(self.shape_parameter_names)
        for fname in self.filters:
            names.append(f"amp_mag_{fname}")
            if self.has_baseline:
                names.append(f"base_mag_{fname}")
        self.parameter_names = names

    def add_filter(self, filters):
        """Add filters via the parent class, then rebuild ``parameter_names``.

        NOTE(review): ``predict`` is jitted with ``self`` as a static
        argument. Unless ``AnalyticalModel`` defines value-based
        ``__hash__``/``__eq__``, a trace compiled before this call may be
        reused afterwards and omit the new filters -- confirm.
        """
        super().add_filter(filters)
        self._build_parameter_names()

    def compute_shape(self, x: dict[str, Array],
                      t_days: Array) -> Array:
        """Return the temporal shape function S(t) >= 0.

        Args:
            x: parameter dict containing at least ``shape_parameter_names``.
            t_days: evaluation times.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement compute_shape()")

    @partial(jax.jit, static_argnums=(0,))
    def predict(self, x: dict[str, Array]) -> tuple[Array, dict[str, Array]]:
        """Evaluate per-band apparent magnitudes on ``self.times``.

        Without a baseline:  mag = amp_mag - 2.5 * log10(max(S, floor)).
        With a baseline:     mag = -2.5 * log10(max(10^(-0.4 amp_mag) * S
                                                    + 10^(-0.4 base_mag), floor)).

        Args:
            x: dict with the shape parameters plus ``amp_mag_{filter}`` (and
               ``base_mag_{filter}`` when ``has_baseline``) for every filter.

        Returns:
            ``(t_days, mag_app)`` where ``mag_app`` maps filter name to a
            magnitude array aligned with ``t_days``.
        """
        t_days = self.times
        shape = self.compute_shape(x, t_days)
        floor = self._FLUX_FLOOR

        mag_app = {}
        if self.has_baseline:
            for fname in self.filters:
                amp_mag = x[f"amp_mag_{fname}"]
                base_mag = x[f"base_mag_{fname}"]
                # total_flux = 10^(-0.4*amp_mag) * shape + 10^(-0.4*base_mag)
                total_flux = (jnp.power(10.0, -0.4 * amp_mag) * shape
                              + jnp.power(10.0, -0.4 * base_mag))
                mag_app[fname] = -2.5 * jnp.log10(
                    jnp.maximum(total_flux, floor))
        else:
            # The shape term is identical for every band: compute it once.
            shape_mag = -2.5 * jnp.log10(jnp.maximum(shape, floor))
            for fname in self.filters:
                mag_app[fname] = x[f"amp_mag_{fname}"] + shape_mag
        return t_days, mag_app
class EvolvingBlackbodyModel(AnalyticalModel):
    """Phenomenological model with piecewise power-law T and R evolution.

    Reference:
        Redback: https://github.com/nikhil-sarin/redback/blob/master/redback/transient_models/phenomenological_models.py

    Model-agnostic -- useful for fast empirical fitting of any thermal
    transient. Based on the evolving_blackbody model from Redback.

    Temperature and radius each follow a broken power law that is continuous
    at its peak time:

        Q(t) = Q0 * (t / t_ref)^rise                           for t <= t_peak
        Q(t) = Q0 * (t_peak / t_ref)^rise * (t / t_peak)^(-decline)   otherwise

    where t_ref is the constructor's ``reference_time`` (same units as the
    model times).

    Parameters (in ``x`` dict):
        log10_temperature_0  -- log10 initial temperature (K) at reference_time
        log10_radius_0       -- log10 initial radius (cm) at reference_time
        temp_rise_index      -- T rise power-law index for t <= temp_peak_time
        temp_decline_index   -- T decline power-law index for t > temp_peak_time
        temp_peak_time       -- time (days) when temperature peaks
        radius_rise_index    -- R rise power-law index for t <= radius_peak_time
        radius_decline_index -- R decline power-law index for t > radius_peak_time
        radius_peak_time     -- time (days) when radius peaks
    """

    parameter_names = [
        "log10_temperature_0", "log10_radius_0",
        "temp_rise_index", "temp_decline_index", "temp_peak_time",
        "radius_rise_index", "radius_decline_index", "radius_peak_time",
    ]

    def __init__(self, filters, times=None, reference_time=1.0):
        self.reference_time = reference_time
        default_grid = jnp.geomspace(0.1, 30.0, 100) if times is None else None
        super().__init__(filters, default_grid if times is None else times)

    @staticmethod
    def _piecewise_power_law(t, q0, t_ref, t_peak, rise_index, decline_index):
        """Rise as (t/t_ref)^rise_index up to t_peak, then decline continuously."""
        q_at_peak = q0 * jnp.power(t_peak / t_ref, rise_index)
        rising = q0 * jnp.power(t / t_ref, rise_index)
        declining = q_at_peak * jnp.power(t / t_peak, -decline_index)
        return jnp.where(t <= t_peak, rising, declining)

    def compute_log10_lbol_rphot(self, x, t_days):
        """Return ``(log10 L_bol, log10 R_phot)`` evaluated at ``t_days``.

        L = 4 pi R^2 sigma_SB T^4, assembled in log10 space from the
        module constants ``_LOG10_4PI`` and ``_LOG10_SIGMASB`` (presumably
        cgs, giving L in erg/s -- confirm against the base module).
        R and T are floored at 1.0 before taking log10.
        """
        t_ref = self.reference_time

        temperature = self._piecewise_power_law(
            t_days,
            jnp.power(10.0, x["log10_temperature_0"]),
            t_ref,
            x["temp_peak_time"],
            x["temp_rise_index"],
            x["temp_decline_index"],
        )
        radius = self._piecewise_power_law(
            t_days,
            jnp.power(10.0, x["log10_radius_0"]),
            t_ref,
            x["radius_peak_time"],
            x["radius_rise_index"],
            x["radius_decline_index"],
        )

        log10_R = jnp.log10(jnp.maximum(radius, 1.0))
        log10_T = jnp.log10(jnp.maximum(temperature, 1.0))
        # Stefan-Boltzmann in log space; term order kept for identical rounding.
        log10_L = _LOG10_4PI + 2.0 * log10_R + _LOG10_SIGMASB + 4.0 * log10_T
        return log10_L, log10_R
class BazinModel(PhenomenologicalModel):
    """Bazin et al. phenomenological light-curve model.

    Reference:
        Boom: https://github.com/boom-astro/boom/blob/main/src/fitting/parametric.rs

    Shape: S(t) = exp(-dt / tau_fall) * sigmoid(dt / tau_rise), dt = t - t0
    Parameters (per-band): amp_mag_{filter}, base_mag_{filter}
    Shape parameters: t0, log10_tau_rise, log10_tau_fall
    """

    shape_parameter_names = ["t0", "log10_tau_rise", "log10_tau_fall"]
    has_baseline = True

    def __init__(self, filters, times=None):
        if times is None:
            times = jnp.linspace(0.01, 100.0, 200)
        super().__init__(filters, times)

    def compute_shape(self, x, t_days):
        """Return the Bazin shape S(t) at ``t_days``.

        The product exp(-dt/tau_fall) * sigmoid(dt/tau_rise) is formed in log
        space (``log_sigmoid``) and only the combined log-shape is clipped to
        [-80, 80]. Clipping the exponential factor alone, then multiplying by
        an underflowing sigmoid, badly misestimates S before t0 (e.g. with
        tau_rise == tau_fall the true S -> 1 but the factored form -> ~0).
        """
        t0 = x["t0"]
        tau_rise = jnp.power(10.0, x["log10_tau_rise"])
        tau_fall = jnp.power(10.0, x["log10_tau_fall"])
        dt = t_days - t0
        log_shape = -dt / tau_fall + jax.nn.log_sigmoid(dt / tau_rise)
        # Cap keeps S finite (e^80 fits float32) when tau_rise > tau_fall makes
        # the Bazin form diverge toward early times.
        return jnp.exp(jnp.clip(log_shape, -80.0, 80.0))
class VillarModel(PhenomenologicalModel):
    """Villar et al. phenomenological light-curve model.

    Reference:
        Boom: https://github.com/boom-astro/boom/blob/main/src/fitting/parametric.rs

    Piecewise shape with a smooth sigmoid transition at phase = gamma:

        S(t) = sigmoid(phase / tau_rise)
               * max((1 - w) * (1 - beta * phase)
                     + w * (1 - beta * gamma) * exp((gamma - phase) / tau_fall),
                     1e-30)

    with phase = t - t0 and w = sigmoid(k * (phase - gamma)), where k is the
    constructor's ``transition_sharpness`` (default 10.0, inverse time units).

    Shape parameters: t0, log10_tau_rise, log10_tau_fall, beta_slope, log10_gamma
    """

    shape_parameter_names = ["t0", "log10_tau_rise", "log10_tau_fall",
                             "beta_slope", "log10_gamma"]
    has_baseline = False

    def __init__(self, filters, times=None, *, transition_sharpness=10.0):
        # Read inside compute_shape; baked in at jit-trace time.
        self.transition_sharpness = transition_sharpness
        if times is None:
            times = jnp.linspace(0.01, 150.0, 200)
        super().__init__(filters, times)

    def compute_shape(self, x, t_days):
        """Return the Villar shape S(t) at ``t_days``.

        The falling branch w * exp((gamma - phase)/tau_fall) is formed in log
        space. Evaluated naively, w underflows to 0 while the exponential
        overflows at early phases, giving 0 * inf = NaN that propagates to
        every magnitude. The log-space exponent is capped at 80 (matching
        the Bazin clip), so S stays finite.
        """
        t0 = x["t0"]
        tau_rise = jnp.power(10.0, x["log10_tau_rise"])
        tau_fall = jnp.power(10.0, x["log10_tau_fall"])
        beta = x["beta_slope"]
        gamma = jnp.power(10.0, x["log10_gamma"])
        phase = t_days - t0

        sig_rise = jax.nn.sigmoid(phase / tau_rise)
        step_arg = self.transition_sharpness * (phase - gamma)

        # (1 - w) written as sigmoid(-z): exact identity, no cancellation.
        plateau = jax.nn.sigmoid(-step_arg) * (1.0 - beta * phase)
        log_decay = jax.nn.log_sigmoid(step_arg) + (gamma - phase) / tau_fall
        decay = (1.0 - beta * gamma) * jnp.exp(jnp.minimum(log_decay, 80.0))

        return sig_rise * jnp.maximum(plateau + decay, 1e-30)

    def constraint_penalty(self, x):
        """Physical validity penalty (de Soto et al. 2024).

        Returns 0 for valid parameters, positive for invalid. Multiply by
        a large negative factor and add to log-likelihood to enforce.

        The returned value is the sum of hinge terms max(g_i, 0) for:
            g_1 = gamma * beta - 1
            g_2 = exp(-gamma / tau_rise) * (tau_fall / tau_rise - 1) - 1
            g_3 = beta * tau_fall - 1 + beta * gamma
        """
        beta = x["beta_slope"]
        gamma = jnp.power(10.0, x["log10_gamma"])
        tau_rise = jnp.power(10.0, x["log10_tau_rise"])
        tau_fall = jnp.power(10.0, x["log10_tau_fall"])
        return (
            jnp.maximum(gamma * beta - 1.0, 0.0)
            + jnp.maximum(
                jnp.exp(-gamma / tau_rise) * (tau_fall / tau_rise - 1.0) - 1.0,
                0.0,
            )
            + jnp.maximum(beta * tau_fall - 1.0 + beta * gamma, 0.0)
        )
class PhenomenologicalTDEModel(PhenomenologicalModel):
    """Phenomenological TDE light-curve model.

    Reference:
        Boom: https://github.com/boom-astro/boom/blob/main/src/fitting/parametric.rs

    Sigmoid rise multiplied by a power-law decay:

        S(t) = sigmoid(phase / tau_rise)
               * (1 + (softplus(phase) + 1e-6) / tau_fall)^(-alpha_decay)

    with phase = t - t0. The softplus keeps the decay's base smooth and
    strictly greater than 1 before t0.

    Shape parameters: t0, log10_tau_rise, log10_tau_fall, alpha_decay
    """

    shape_parameter_names = ["t0", "log10_tau_rise", "log10_tau_fall",
                             "alpha_decay"]
    has_baseline = True

    def __init__(self, filters, times=None):
        grid = jnp.linspace(0.01, 200.0, 200) if times is None else times
        super().__init__(filters, grid)

    def compute_shape(self, x, t_days):
        """Return the TDE shape S(t) at ``t_days``."""
        since_t0 = t_days - x["t0"]

        rise_scale = jnp.power(10.0, x["log10_tau_rise"])
        rising_part = jax.nn.sigmoid(since_t0 / rise_scale)

        fall_scale = jnp.power(10.0, x["log10_tau_fall"])
        positive_phase = jax.nn.softplus(since_t0) + 1e-6
        decaying_part = jnp.power(1.0 + positive_phase / fall_scale,
                                  -x["alpha_decay"])

        return rising_part * decaying_part
class AfterglowModel(PhenomenologicalModel):
    """Smooth broken power-law afterglow model.

    Reference:
        Boom: https://github.com/boom-astro/boom/blob/main/src/fitting/parametric.rs

    Shape: S(t) = (r^(2 alpha_1) + r^(2 alpha_2))^(-1/2), with
    r = (softplus(t - t0) + 1e-6) / t_break.

    When alpha_1 < alpha_2 this transitions from r^(-alpha_1) at early times
    (r << 1) to r^(-alpha_2) at late times (r >> 1); for alpha_1 > alpha_2
    the roles of the two slopes swap.

    Shape parameters: t0, log10_t_break, alpha_1, alpha_2
    """

    shape_parameter_names = ["t0", "log10_t_break", "alpha_1", "alpha_2"]
    has_baseline = False

    def __init__(self, filters, times=None):
        if times is None:
            times = jnp.linspace(0.01, 300.0, 200)
        super().__init__(filters, times)

    def compute_shape(self, x, t_days):
        """Return the smooth broken power-law shape S(t) at ``t_days``.

        Evaluated as exp(-0.5 * logaddexp(2 a1 ln r, 2 a2 ln r)). Computing
        r^(2a) directly can underflow both terms to 0 at small r (giving
        0^(-1/2) = inf) or overflow them at large r, which breaks gradients.
        """
        t0 = x["t0"]
        t_break = jnp.power(10.0, x["log10_t_break"])
        alpha_1 = x["alpha_1"]
        alpha_2 = x["alpha_2"]
        phase = t_days - t0
        # softplus + epsilon keeps r strictly positive so log(r) is finite.
        phase_soft = jax.nn.softplus(phase) + 1e-6
        ln_r = jnp.log(phase_soft / t_break)
        log_sum = jnp.logaddexp(2.0 * alpha_1 * ln_r, 2.0 * alpha_2 * ln_r)
        return jnp.exp(-0.5 * log_sum)