import os
import ast
import warnings
import h5py
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from matplotlib.cm import ScalarMappable
from scipy.integrate import trapezoid
from scipy.interpolate import interp1d
from fiesta.inference.lightcurve_model import LightcurveModel, FluxModel
from fiesta.inference.plot import latex_labels
[docs]
class Benchmarker:
def __init__(self,
             model: LightcurveModel,
             data_file: str,
             filters: list = None,
             outdir: str = "./benchmarks",
             metric_name: str = "Linf",
             ) -> None:
    """Benchmark a trained surrogate model against a held-out test set.

    Loads the test data, evaluates the surrogate, computes per-filter
    errors and the error distribution over the parameter space.

    Args:
        model: Trained surrogate whose predictions are benchmarked.
        data_file: Path to the HDF5 file holding the test data
            (groups/datasets ``nus``, ``times``, ``test/X``, ``test/y``).
        filters: Optional list of filter names to benchmark; ``None``
            selects every filter the model provides.
        outdir: Directory benchmark plots are written to (created if missing).
        metric_name: ``"L2"`` selects an integrated L2 norm over log-time;
            any other value selects the L-infinity (max-abs) norm.
    """
    self.model = model
    self.times = self.model.times
    self.file = data_file
    self.outdir = outdir
    # Fix: ensure the output directory exists up front, otherwise every later
    # fig.savefig(os.path.join(self.outdir, ...)) call raises FileNotFoundError.
    os.makedirs(self.outdir, exist_ok=True)
    # Load filters
    if filters is None:
        self.Filters = model.Filters
    else:
        self.Filters = [Filt for Filt in model.Filters if Filt.name in filters]
    print(f"Loaded filters are: {[Filt.name for Filt in self.Filters]}.")
    # Load metric
    if metric_name == "L2":
        self.metric_name = "$\\mathcal{L}_2$"
        # L2: sqrt of residual^2 integrated over log-time, normalized by the log-time span
        self.metric = lambda y: np.sqrt(trapezoid(x= np.log(self.times) ,y=y**2, axis = -1)) / (np.log(self.times[-1]) - np.log(self.times[0]))
        # 2D variant integrates over both the frequency and time grids
        self.metric2d = lambda y: np.sqrt(trapezoid(x = self.nus, y =trapezoid(x = self.times, y = (y**2).reshape(-1, len(self.nus), len(self.times)) ) ))
        self.file_ending = "L2"
    else:
        self.metric_name = "$\\mathcal{L}_\\infty$"
        # Linf: maximum absolute residual along the last (time) axis
        self.metric = lambda y: np.max(np.abs(y), axis = -1)
        self.metric2d = lambda y: np.max(np.abs(y), axis = (1,2))
        self.file_ending = "Linf"
    # Run the full benchmark pipeline at construction time.
    self.get_data()
    self.calculate_error()
    self.get_error_distribution()
[docs]
def get_data(self,):
# Load the benchmark test set from the HDF5 file and evaluate the surrogate on it.
# Populates: self.test_X_raw (test parameters), self.test_mag (per-filter
# ground-truth magnitudes), self.test_log_flux / self.data_nus (raw log10-flux
# grid and frequency grid, kept for FluxModel error calculation), and
# self.pred_mag (surrogate predictions on the same parameter sets).
# get the test data
self.test_mag = {}
with h5py.File(self.file, "r") as f:
self.parameter_distributions = self.model.parameter_distributions
self.parameter_names = self.model.parameter_names
nus = f["nus"][:]
self.test_X_raw = f["test"]["X"][:]
test_y_raw = f["test"]["y"][:]
# reshape the flat (n_samples, n_nu * n_t) array into (n_samples, n_nu, n_t)
test_y_raw = test_y_raw.reshape(len(self.test_X_raw), len(f["nus"]), len(f["times"]) )
test_y_raw = interp1d(f["times"][:], test_y_raw, axis = 2)(self.times) # interpolate the test data over the time range of the model
self.test_log_flux = test_y_raw # store log10 flux for FluxModel error calculation
self.data_nus = nus # store raw data frequency grid
# convert log10 flux to linear flux density (presumably mJy, per the variable name -- TODO confirm)
mJys = np.power(10, test_y_raw)
if "redshift" in self.parameter_names:
from fiesta.train.DataManager import concatenate_redshift, redshifted_magnitude
# append a redshift column to the test parameters, then compute redshifted magnitudes per filter
self.test_X_raw = concatenate_redshift(self.test_X_raw, max_z=self.parameter_distributions["redshift"][1])
for Filt in self.Filters:
self.test_mag[Filt.name] = jnp.array(redshifted_magnitude(Filt, mJys.copy(), nus, self.test_X_raw[:,-1]))
else:
for Filt in self.Filters:
self.test_mag[Filt.name] = Filt.get_mags(mJys, nus)
# get the model prediction on the test data
param_dict = dict(zip(self.parameter_names, self.test_X_raw.T))
# distance of 1e-5 -- presumably 10 pc in Mpc so magnitudes are absolute; TODO confirm unit convention
param_dict["luminosity_distance"] = np.ones(len(self.test_X_raw)) * 1e-5
if "redshift" not in param_dict.keys():
param_dict["redshift"] = np.zeros(len(self.test_X_raw))
_, self.pred_mag = self.model.vpredict(param_dict)
[docs]
def calculate_error(self,):
    """Compute benchmark errors of the surrogate against the test data.

    Fills ``self.error`` with one array per filter (the chosen metric applied
    to the magnitude residuals, one value per test sample) plus a ``"total"``
    entry: for ``FluxModel`` surrogates an error over the full (nu, time)
    log-flux grid, otherwise the error array of the worst-performing filter.
    Also sets ``self.nan_fraction`` for FluxModel surrogates.
    """
    self.error = {}
    for Filt in self.Filters:
        test_y = self.test_mag[Filt.name]
        pred_y = self.pred_mag[Filt.name]
        # Fix: mask every non-finite entry (NaN as well as +/-inf). The original
        # mask used np.isinf only, so a single NaN magnitude propagated through
        # self.metric and turned the whole filter's error into NaN, while the
        # FluxModel branch below already handles NaNs explicitly.
        mask = ~(np.isfinite(test_y) & np.isfinite(pred_y))
        test_y = test_y.at[mask].set(0.)
        pred_y = pred_y.at[mask].set(0.)
        self.error[Filt.name] = self.metric(test_y - pred_y)
    if isinstance(self.model, FluxModel):
        self.nus = self.model.nus
        # Predict the raw log10-flux grid sample by sample
        # (predict_log_flux takes a single parameter dict at a time).
        log_flux_pred = []
        for j in range(len(self.test_X_raw)):
            param_dict_j = dict(zip(self.parameter_names, self.test_X_raw[j], strict=True))
            param_dict_j["luminosity_distance"] = 1e-5
            param_dict_j["redshift"] = 0.0
            _, pred_nus, log_flux = self.model.predict_log_flux(param_dict_j)
            log_flux_pred.append(log_flux)
        log_flux_pred = np.array(log_flux_pred)
        # Interpolate ground truth onto the prediction's nu/time grid;
        # frequencies outside the data grid become NaN (fill_value=np.nan).
        pred_nus = np.array(pred_nus)
        test_log_interp = interp1d(self.data_nus, self.test_log_flux, axis=1,
                                   bounds_error=False, fill_value=np.nan)(pred_nus)
        log_flux_residual = log_flux_pred - test_log_interp
        # Mask non-finite entries before clipping
        nan_mask = ~np.isfinite(log_flux_residual)
        n_nan = np.count_nonzero(nan_mask)
        n_total = log_flux_residual.size
        self.nan_fraction = n_nan / n_total if n_total > 0 else 0.0
        if n_nan > 0:
            warnings.warn(
                f"Benchmarker: {n_nan}/{n_total} ({100*self.nan_fraction:.1f}%) "
                f"residual entries are NaN/Inf (likely from frequency grid "
                f"extrapolation). These entries are excluded from the total "
                f"error calculation.",
                stacklevel=2)
        # Set non-finite entries to NaN, then clip physical residuals
        log_flux_residual = np.where(nan_mask, np.nan, log_flux_residual)
        log_flux_residual = np.clip(log_flux_residual, -100, 100)
        # Exclude NaN/Inf entries from error calculation
        if self.file_ending == "Linf":
            self.error["total"] = np.nanmax(np.abs(log_flux_residual), axis=(1, 2))
        else:
            # NaN-aware L2: integrate only over finite entries per sample
            r2 = np.where(nan_mask, np.nan, log_flux_residual ** 2)
            self.error["total"] = np.sqrt(np.nanmean(r2, axis=(1, 2)))
        # Replace NaN/Inf (from all-NaN samples or overflow) with 0
        self.error["total"] = np.nan_to_num(
            self.error["total"], nan=0.0, posinf=0.0, neginf=0.0)
    else:
        # Magnitude-only surrogate: report the worst filter's error as "total".
        max_errors = {key: np.max(value) for key, value in self.error.items()}
        max_key = max(max_errors, key=max_errors.get)
        self.error["total"] = self.error[max_key]
[docs]
def get_error_distribution(self,):
    """Histogram the total benchmark error over each parameter's prior range.

    Stores in ``self.error_distribution`` one ``np.histogram`` result
    (counts, bin edges) per parameter, weighted by the per-sample total error.
    """
    total_error = self.error["total"]
    # Rescale the weights by their largest magnitude so intermediate sums stay
    # small; with density=True the normalized histogram itself is unaffected.
    peak = np.max(np.abs(total_error))
    scaled_weights = total_error / peak if peak > 0 else np.ones_like(total_error)
    distribution = {}
    for idx, name in enumerate(self.parameter_names):
        samples = self.test_X_raw[:, idx]
        lo = self.parameter_distributions[name][0]
        hi = self.parameter_distributions[name][1]
        edges = np.linspace(lo, hi, 12)
        distribution[name] = np.histogram(samples, weights=scaled_weights,
                                          bins=edges, density=True)
    self.error_distribution = distribution
[docs]
def benchmark(self,):
    """Run the full benchmark report: correlations plus all diagnostic plots."""
    steps = (
        self.print_correlations,
        self.plot_worst_lightcurves,
        self.plot_error_over_time,
        self.plot_error_distribution,
    )
    for step in steps:
        step()
[docs]
def plot_lightcurves_mismatch(self):
# Corner-style scatter of the per-filter mismatch across all parameter pairs,
# one figure per filter, saved to self.outdir. NOTE(review): ax is indexed as a
# 2D array, which assumes at least 3 model parameters -- confirm for 2-parameter models.
if self.metric_name == "$\\mathcal{L}_2$":
# reference mismatch of a constant 1-mag offset, drawn as a dashed guide line
vline = self.metric(np.ones(len(self.times)))
vmin, vmax = 0, vline*2
bins = np.linspace(vmin, vmax, 25)
else:
vline = 1.
vmin, vmax = 0, 2*vline
bins = np.linspace(vmin, vmax, 20)
# two-color ramp: small mismatch -> lightblue, large -> darkred
cmap = colors.LinearSegmentedColormap.from_list(name = "mymap", colors = [(0, "lightblue"), (1, "darkred")])
label_dic = {p: latex_labels.get(p, p) for p in self.parameter_names}
for Filt in self.Filters:
mismatch = self.error[Filt.name]
colored_mismatch = cmap(mismatch/vmax)
fig, ax = plt.subplots(len(self.parameter_names)-1, len(self.parameter_names)-1)
fig.suptitle(f"{Filt.name}: {self.metric_name} norm")
for j, p in enumerate(self.parameter_names[1:]):
for k, pp in enumerate(self.parameter_names[:j+1]):
# draw points sorted by mismatch so the worst cases end up on top
sort = np.argsort(mismatch)
ax[j,k].scatter(self.test_X_raw[sort,k], self.test_X_raw[sort,j+1], c = colored_mismatch[sort], s = 1, rasterized = True)
ax[j,k].set_xlim((self.test_X_raw[:,k].min(), self.test_X_raw[:,k].max()))
ax[j,k].set_ylim((self.test_X_raw[:,j+1].min(), self.test_X_raw[:,j+1].max()))
# only the outermost panels keep tick labels
if k!=0:
ax[j,k].set_yticklabels([])
if j!=len(self.parameter_names)-2:
ax[j,k].set_xticklabels([])
ax[-1,k].set_xlabel(label_dic[pp])
ax[j,0].set_ylabel(label_dic[p])
# hide the unused upper-triangle panels
for cax in ax[j, j+1:]:
cax.set_axis_off()
# top-right corner panel shows the mismatch histogram instead
ax[0,-1].set_axis_on()
ax[0,-1].hist(mismatch, density = True, histtype = "step", bins = bins,)
ax[0,-1].vlines([vline], *ax[0,-1].get_ylim(), colors = ["lightgrey"], linestyles = "dashed")
ax[0,-1].set_yticks([])
fig.colorbar(ScalarMappable(norm=colors.Normalize(vmin = vmin, vmax = vmax), cmap = cmap), ax = ax[1:-1, -1])
outfile = f"benchmark_{Filt.name}_{self.file_ending}.pdf"
fig.savefig(os.path.join(self.outdir, outfile))
plt.close(fig)
[docs]
def plot_worst_lightcurves(self,):
# For each filter, plot the test light curve with the largest benchmark error
# (truth vs. surrogate prediction, with a +/-1 mag band around the prediction)
# and annotate the corresponding test parameters. One panel per filter, saved
# to self.outdir as worst_lightcurves_<metric>.pdf.
label_dic = {p: latex_labels.get(p, p) for p in self.parameter_names}
MAG_FAINT_CLIP = 40 # magnitudes fainter than this are unphysical
n_filters = len(self.Filters)
ncols = min(n_filters, 3)
nrows = int(np.ceil(n_filters / ncols))
fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 4.5 * nrows))
# atleast_2d so a single-row layout can still be indexed axes[r, c]
axes = np.atleast_2d(axes)
fig.subplots_adjust(hspace=0.55, wspace=0.35, bottom=0.06, top=0.94, left=0.07, right=0.97)
for i, filt in enumerate(self.Filters):
cax = axes[i // ncols, i % ncols]
# index of the worst test sample for this filter
ind = np.argmax(self.error[filt.name])
prediction = np.array(self.pred_mag[filt.name][ind])
truth = np.array(self.test_mag[filt.name][ind])
cax.plot(self.times, truth, color="red", lw=1.8, label="Baseline", zorder=3)
cax.plot(self.times, prediction, color="royalblue", lw=1.0, alpha=0.85, label="Surrogate", zorder=2)
cax.fill_between(self.times, prediction - 1, prediction + 1,
color="royalblue", alpha=0.12, zorder=1)
# Y-limits from truth only, clamped to physical range
truth_finite = truth[np.isfinite(truth)]
truth_clipped = truth_finite[truth_finite < MAG_FAINT_CLIP]
if len(truth_clipped) > 0:
ylo = np.min(truth_clipped)
yhi = np.max(truth_clipped)
elif len(truth_finite) > 0:
ylo, yhi = np.min(truth_finite), MAG_FAINT_CLIP
else:
# no finite truth values at all: fall back to a generic magnitude window
ylo, yhi = -5, MAG_FAINT_CLIP
pad = max(2.0, (yhi - ylo) * 0.12)
cax.set_ylim(yhi + pad, ylo - pad) # inverted for magnitudes
cax.set(xscale="log", xlim=(self.times[0], self.times[-1]))
cax.set_xlabel("$t$ [days]", fontsize=9)
cax.set_ylabel("mag", fontsize=9)
cax.set_title(filt.name, fontsize=11, fontweight="bold")
cax.grid(True, alpha=0.25, lw=0.5)
cax.tick_params(labelsize=8)
# Multi-line parameter annotation (4 params per line)
params_per_line = 4
items = [f"{label_dic.get(p, p)}={self.test_X_raw[ind, j]:.2g}"
for j, p in enumerate(self.parameter_names)]
lines = [", ".join(items[k:k + params_per_line])
for k in range(0, len(items), params_per_line)]
param_str = "\n".join(lines)
cax.text(0.03, 0.04, param_str, transform=cax.transAxes,
fontsize=6.5, color="0.35", va="bottom", family="monospace",
bbox=dict(facecolor="white", alpha=0.85, edgecolor="0.8",
pad=2, boxstyle="round,pad=0.3"))
if i == 0:
cax.legend(fontsize=9, loc="upper right",
framealpha=0.9, edgecolor="0.8")
# blank out any unused trailing panels in the grid
for i in range(n_filters, nrows * ncols):
axes[i // ncols, i % ncols].set_visible(False)
fig.savefig(os.path.join(self.outdir, f"worst_lightcurves_{self.file_ending}.pdf"), dpi=200)
plt.close(fig)
[docs]
def plot_error_over_time(self,):
# Violin plots of the absolute magnitude error |prediction - truth| across the
# whole test set, evaluated at ~10 times spaced evenly in log10(t); one panel
# per filter. Saved to self.outdir as error_over_time.pdf.
n_filters = len(self.Filters)
ncols = min(n_filters, 3)
nrows = int(np.ceil(n_filters / ncols))
fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 4.5 * nrows))
axes = np.atleast_2d(axes)
fig.subplots_adjust(hspace=0.55, wspace=0.35, bottom=0.06, top=0.94, left=0.07, right=0.97)
# Pick time indices evenly in log-space
log_times = np.log10(self.times)
target_log = np.linspace(log_times[0], log_times[-1], 10)
indices = np.array([np.argmin(np.abs(log_times - t)) for t in target_log])
# nearest-index snapping can produce duplicates on coarse grids; deduplicate
indices = np.unique(indices)
for i, filt in enumerate(self.Filters):
cax = axes[i // ncols, i % ncols]
error = np.abs(np.array(self.pred_mag[filt.name]) - np.array(self.test_mag[filt.name]))
# replace non-finite residuals with zero so violins stay well-defined
error = np.where(np.isfinite(error), error, 0.0)
# Clip outliers at 99th percentile across all times for cleaner violins
all_err = error[:, indices].ravel()
clip_val = np.percentile(all_err[all_err > 0], 99) if np.any(all_err > 0) else 1.0
error_clipped = np.clip(error, 0, clip_val)
# Use log-space positions for the violin plot
log_pos = np.log10(self.times[indices])
# width of each violin scales with the gap to its left neighbor
spacing = np.diff(np.concatenate([[log_pos[0] - 0.5], log_pos]))
width = spacing * 0.55
width = np.clip(width, 0.08, None)
data_list = [error_clipped[:, idx] for idx in indices]
parts = cax.violinplot(data_list, positions=log_pos, widths=width, points=300,
showmedians=True, showextrema=False)
for pc in parts["bodies"]:
pc.set_facecolor("steelblue")
pc.set_edgecolor("steelblue")
pc.set_alpha(0.5)
parts["cmedians"].set_color("darkred")
parts["cmedians"].set_linewidth(1.5)
# Manual log-scale tick labels (the x-axis is linear in log10(t))
tick_vals = np.array([1e-4, 1e-3, 1e-2, 1e-1, 1, 10, 100, 1000])
tick_vals = tick_vals[(tick_vals >= self.times[0]) & (tick_vals <= self.times[-1])]
cax.set_xticks(np.log10(tick_vals))
cax.set_xticklabels([f"$10^{{{int(np.log10(v))}}}$" for v in tick_vals],
fontsize=8)
cax.set_xlim(log_times[0] - 0.3, log_times[-1] + 0.3)
# Y-limit from the clipped data median + a few sigma
medians = np.array([np.median(d) for d in data_list])
p90 = np.percentile(error_clipped[:, indices].ravel(), 90)
cax.set_ylim(0, max(p90 * 1.5, np.max(medians) * 3, 0.5))
cax.set_xlabel("$t$ [days]", fontsize=9)
cax.set_ylabel("error [mag]", fontsize=9)
cax.set_title(filt.name, fontsize=11, fontweight="bold")
cax.grid(True, axis="y", alpha=0.25, lw=0.5)
cax.tick_params(labelsize=8)
# blank out any unused trailing panels in the grid
for i in range(n_filters, nrows * ncols):
axes[i // ncols, i % ncols].set_visible(False)
fig.savefig(os.path.join(self.outdir, "error_over_time.pdf"), dpi=200)
plt.close(fig)
[docs]
def print_correlations(self, ):
    """Print the Pearson correlation between each parameter and the per-filter error."""
    for filt in self.Filters:
        filter_error = self.error[filt.name]
        print(f"\n \n \nCorrelations for filter {filt.name}:\n")
        for idx, name in enumerate(self.parameter_names):
            correlation = np.corrcoef(self.test_X_raw[:, idx], filter_error)[0, 1]
            print(f"{name}: {correlation}")
[docs]
def plot_error_distribution(self,):
# Bar chart of the mean "total" benchmark error binned over each parameter's
# prior range, one panel per parameter. The title reports the fraction of
# NaN/Inf flux residuals (self.nan_fraction, set by calculate_error for
# FluxModel surrogates). Saved to self.outdir as error_distribution.pdf.
label_dic = {p: latex_labels.get(p, p) for p in self.parameter_names}
n_params = len(self.parameter_names)
ncols = min(n_params, 4)
nrows = int(np.ceil(n_params / ncols))
fig, axes = plt.subplots(nrows, ncols, figsize=(4.5 * ncols, 3.5 * nrows))
axes = np.atleast_2d(axes)
fig.subplots_adjust(hspace=0.6, wspace=0.4, bottom=0.10, top=0.90, left=0.07, right=0.97)
# nan_fraction only exists for FluxModel benchmarks; default to 0 otherwise
nan_frac = getattr(self, 'nan_fraction', 0.0)
title = "Total error distribution per parameter"
if nan_frac > 0:
title += f" ({100*nan_frac:.1f}% of flux residual entries were NaN/Inf, excluded)"
fig.suptitle(title, fontsize=10)
for j, p in enumerate(self.parameter_names):
cax = axes[j // ncols, j % ncols]
p_array = self.test_X_raw[:, j]
pmin, pmax = self.parameter_distributions[p][0], self.parameter_distributions[p][1]
bins = np.linspace(pmin, pmax, 15)
# Mean error per bin: weighted histogram (sum of errors) over raw counts
counts, _ = np.histogram(p_array, bins=bins)
weighted, _ = np.histogram(p_array, bins=bins, weights=self.error["total"])
mean_error = np.where(counts > 0, weighted / counts, 0)
bin_centers = 0.5 * (bins[:-1] + bins[1:])
cax.bar(bin_centers, mean_error, width=np.diff(bins) * 0.85,
color="steelblue", edgecolor="white", linewidth=0.5)
cax.set_xlabel(label_dic.get(p, p), fontsize=9)
cax.set_ylabel(f"mean {self.metric_name}", fontsize=9)
cax.set_xlim(pmin, pmax)
cax.grid(True, axis="y", alpha=0.25, lw=0.5)
cax.tick_params(labelsize=8)
# blank out any unused trailing panels in the grid
for i in range(n_params, nrows * ncols):
axes[i // ncols, i % ncols].set_visible(False)
fig.savefig(os.path.join(self.outdir, "error_distribution.pdf"), dpi=200)
plt.close(fig)