"""Functions for creating and handling injections"""
import argparse
import os
import h5py
from jaxtyping import Float, Array
import numpy as np
from fiesta.inference.lightcurve_model import LightcurveModel
from fiesta.conversions import mag_app_from_mag_abs, apply_redshift
from fiesta.filters import Filter
from fiesta.utils import write_event_data
from fiesta.logging import logger
from fiesta.train.AfterglowData import RunAfterglowpy, RunPyblastafterglow
# TODO: get the parser going
[docs]
def get_parser(**kwargs):
    """Build the command line argument parser for injection runs.

    Args:
        **kwargs: Optional settings. Only ``add_help`` (bool, default True)
            is read; it is forwarded to ``argparse.ArgumentParser``.

    Returns:
        argparse.ArgumentParser: The configured parser. No arguments are
        registered on it yet.
    """
    add_help = kwargs.get("add_help", True)
    parser = argparse.ArgumentParser(
        description="Inference on kilonova and GRB parameters.",
        add_help=add_help,
    )
    # Hand the parser back to the caller; without this the function
    # silently evaluated to None and the parser was discarded.
    return parser
[docs]
class InjectionBase:
    """
    Base class to create synthetic injection lightcurves.

    The injection model is first initialized with the following parameters:
        filters (list): List of filters in which the synthetic data should be given out.
        trigger_time (float): Reference trigger time (e.g. MJD or GPS seconds) added as an offset to all detection time stamps. Required.
        tmin (float): Time of earliest synthetic detection possible in days. Defaults to 0.1.
        tmax (float): Time of latest synthetic detection possible in days. Defaults to 10.0
        N_datapoints (int): Total number of datapoints (across all filters) for the synthetic lightcurve. Defaults to 10.
        t_detect (dict[str, Array]): Detection time points in each filter. If none is specified, then the detection times will be sampled randomly.
        error_budget (float): Typical measurement error scale of the synthetic data. Defaults to 1.
        detection_limit (float | dict[str, float]): Synthetic datapoints with magnitude higher than this value (i.e. fainter) will be turned into nondetections. A single number is applied to every filter; a dict is looked up by filter name. Defaults to np.inf.
        nondetections (bool): Additional to detection_limit, this turns some of the synthetic datapoints to nondetections. Defaults to False.
        nondetections_fraction (float): If nondetections is True, then this will determine the fraction of N_datapoints turned into nondetections. Defaults to 0.1.

    Then one can call the .create_injection() method to get synthetic lightcurve data.
    The method .write_to_file() writes the synthetic lightcurve data to file.

    Subclasses supply the actual light curve through ``_get_injection_lc``.
    """

    def __init__(self,
                 filters: list[str],
                 trigger_time: float,
                 tmin: Float = 0.1,
                 tmax: Float = 10.0,
                 N_datapoints: int = 10,
                 t_detect: dict[str, Array] = None,
                 error_budget: Float = 1.0,
                 nondetections: bool = False,
                 nondetections_fraction: Float = 0.1,
                 detection_limit: Float | dict[str, Float] = np.inf):
        self.Filters = [Filter(filt) for filt in filters]
        logger.info(f"Creating injection with filters: {filters}")

        self.trigger_time = trigger_time
        self.tmin = tmin
        self.tmax = tmax

        # Use the caller's detection times if given, otherwise draw them.
        if t_detect is not None:
            self.t_detect = t_detect
        else:
            self.create_t_detect(tmin, tmax, N_datapoints)

        self.error_budget = error_budget
        self.nondetections = nondetections
        self.nondetections_fraction = nondetections_fraction
        self.detection_limit = self._per_filter_detection_limit(detection_limit)

    def _per_filter_detection_limit(self, detection_limit):
        """Return the detection limit as a dict keyed by filter name."""
        if isinstance(detection_limit, dict):
            return detection_limit
        # Any scalar (float, int, numpy scalar) is broadcast to all filters.
        return {filt.name: detection_limit for filt in self.Filters}

    def create_t_detect(self, tmin, tmax, N):
        """Create a time grid for the injection data.

        Every filter receives at least one time point; the remaining
        ``N - n_filters`` points are distributed with a uniform multinomial
        draw. The result is stored in ``self.t_detect`` (filter name ->
        sorted array of times within ``[tmin, tmax]``).

        Raises:
            ValueError: If ``N`` is smaller than the number of filters.
        """
        self.t_detect = {}
        n_filters = len(self.Filters)
        if N < n_filters:
            raise ValueError("Number of injected data points needs to be larger than number of filters.")

        # each filter at least one point, plus a random share of the rest
        extra = np.random.multinomial(N - n_filters, [1 / n_filters] * n_filters)
        points_list = (np.ones(n_filters) + extra).astype(int)

        for points, filt in zip(points_list, self.Filters):
            # log-uniform draw between tmin and tmax
            t = np.exp(np.random.uniform(np.log(tmin), np.log(tmax), size=points))
            t = np.sort(t)
            stretch = (tmax / tmin) ** (1 / points)
            t[::2] *= np.random.uniform(1, stretch, size=len(t[::2]))  # correlate the time points
            t[::3] *= np.random.uniform(1, stretch, size=len(t[::3]))  # correlate the time points
            # redraw anything the stretching pushed outside the allowed window
            mask = (t < tmin) | (t > tmax)
            t[mask] = np.exp(np.random.uniform(np.log(tmin), np.log(tmax), size=np.sum(mask)))
            self.t_detect[filt.name] = np.sort(t)

    def _get_injection_lc(self, injection_dict):
        """Return ``(times, mag_app)`` for the given parameters.

        Must be provided by subclasses; ``mag_app`` is a dict keyed by
        filter name.
        """
        raise NotImplementedError(f"{type(self).__name__} does not implement _get_injection_lc.")

    def create_injection(self,
                         injection_dict: dict[str, Float],
                         file: str = None):
        """
        Creates an injection that is stored as a ``.data`` attribute.

        Args:
            injection_dict (dict): Parameters for the synthetic light curve.
            file (str, optional): Training data file that stores light curves from the physical base model of the surrogate.
                If provided, the method will take the test element closest to ``injection_dict`` and base the injection on it.
                In this case, the ``.injection_dict`` attribute is updated to contain the real parameters used to generate the light curve.
        """
        if file is None:
            times, mag_app = self._get_injection_lc(injection_dict)
        else:
            times, mag_app, injection_dict = self._get_injection_lc_from_file(injection_dict, file)
        self.injection_dict = injection_dict
        self._set_data_from_lightcurve(times, mag_app)

    def create_injection_from_mags(self,
                                   times: Array,
                                   mag_app: Array):
        """
        Creates an injection from an already computed light curve.

        Args:
            times (Array): Time grid of the light curve.
            mag_app: Magnitudes on ``times``, indexed by filter name.

        The result is stored in ``.data``; ``.injection_dict`` is set to an empty dict.
        """
        self.injection_dict = dict()
        self._set_data_from_lightcurve(times, mag_app)

    def _set_data_from_lightcurve(self, times, mag_app):
        """Sample noisy data at the detection times and store it in ``self.data``."""
        self.data = {}
        for filt in self.Filters:
            t_detect = np.asarray(self.t_detect[filt.name])
            limit = self.detection_limit[filt.name]

            mu = np.interp(t_detect, times, mag_app[filt.name])
            # error bars: scaled sqrt of a chi-square(1) draw, kept within [0.01, 1]
            sigma = self.error_budget * np.sqrt(np.random.chisquare(df=1, size=len(t_detect)))
            sigma = np.clip(sigma, 0.01, 1)
            mag_measured = np.random.normal(loc=mu, scale=sigma)

            # apply detection limit
            not_detected = mag_measured > limit
            mag_measured[not_detected] = limit
            sigma[not_detected] = np.inf

            # columns: time (offset by trigger time), magnitude, error
            self.data[filt.name] = np.array([t_detect + self.trigger_time, mag_measured, sigma]).T

        # add additional non detections
        self.randomize_nondetections()

    def _get_injection_lc_from_file(self, injection_dict, file):
        """Create a synthetic lightcurve from training data file given the parameters in injection_dict.

        The entry of the file's ``test`` set closest to the requested
        parameters (squared distance, each parameter scaled by its range in
        the test set) is selected. ``injection_dict`` is updated in place
        with the parameters of that entry.

        Returns:
            tuple: ``(times_obs, mags, injection_dict)`` where ``mags`` maps
            filter name to apparent magnitudes.

        Raises:
            ValueError: If ``[tmin, tmax]`` is not covered by the redshifted time grid.
        """
        with h5py.File(file) as f:
            times = f["times"][:]
            nus = f["nus"][:]
            parameter_names = f["parameter_names"][:].astype(str).tolist()
            test_X_raw = f["test"]["X"][:]

            X = np.array([injection_dict[p] for p in parameter_names])
            span = np.max(test_X_raw, axis=0) - np.min(test_X_raw, axis=0)
            # A parameter that is constant over the test set has zero range;
            # it shifts every distance equally, so use unit scale instead of
            # dividing by zero.
            span = np.where(span > 0, span, 1.0)
            ind = np.argmin(np.sum(((test_X_raw - X) / span) ** 2, axis=1))

            X = test_X_raw[ind]
            log_flux = f["test"]["y"][ind]

        injection_dict.update(dict(zip(parameter_names, X)))
        injection_dict["redshift"] = injection_dict.get("redshift", 0.0)
        logger.info(f"Found suitable injection with {injection_dict}")

        mJys = np.exp(log_flux).reshape(len(nus), len(times))
        mJys, times_obs, nus = apply_redshift(mJys, times, nus, injection_dict["redshift"])

        if self.tmin < times_obs[0] or self.tmax > times_obs[-1]:
            raise ValueError(f"Time range {(self.tmin, self.tmax)} is too large for file {file} with time range {(times[0], times[-1])} at redshift {injection_dict['redshift']}.")

        mags = {}
        for filt in self.Filters:
            mag_abs = filt.get_mag(mJys, nus)
            mags[filt.name] = mag_app_from_mag_abs(mag_abs, injection_dict["luminosity_distance"])
        return times_obs, mags, injection_dict

    def randomize_nondetections(self,):
        """Turn a random subset of the data points in ``self.data`` into nondetections.

        Does nothing unless ``self.nondetections`` is set. The selected rows
        get their magnitude shifted by -5 and their error set to infinity.
        """
        if not self.nondetections:
            return

        N = np.sum([len(self.t_detect[filt.name]) for filt in self.Filters])
        n_filters = len(self.Filters)
        # random number of non detections in each filter
        nondets_list = np.random.multinomial(int(N * self.nondetections_fraction), [1 / n_filters] * n_filters)

        for nondets, filt in zip(nondets_list, self.Filters):
            n_points = len(self.data[filt.name])
            # The multinomial draw ignores how many points a filter has;
            # sampling without replacement needs size <= population.
            nondets = min(int(nondets), n_points)
            inds = np.random.choice(np.arange(n_points), size=nondets, replace=False)
            self.data[filt.name][inds] += np.array([0, -5., np.inf])

    def write_to_file(self, file: str):
        """Write the data to ``file`` and the injection parameters next to it.

        The parameters are written as ``str(self.injection_dict)`` to
        ``param_dict.dat`` in the directory of ``file``.
        """
        write_event_data(file, self.data)
        outdir = os.path.dirname(file)
        with open(os.path.join(outdir, "param_dict.dat"), "w") as o:
            o.write(str(self.injection_dict))
[docs]
class InjectionSurrogate(InjectionBase):
    """
    Injection class that draws its light curves from a surrogate model.

    Once instantiated, call .create_injection() to generate the synthetic
    lightcurve data and .write_to_file() to store it on disk.
    """

    def __init__(self,
                 model: LightcurveModel,
                 *args,
                 **kwargs):
        """
        Args:
            model: The surrogate used for creating the injection light curves.
            *args, **kwargs: Forwarded unchanged to ``InjectionBase.__init__``
                (filters, trigger_time, tmin, tmax, N_datapoints, t_detect,
                error_budget, nondetections, nondetections_fraction,
                detection_limit).
        """
        # The model is stored before the base initializer runs.
        self.model = model
        super().__init__(*args, **kwargs)

    def _get_injection_lc(self, injection_dict):
        """Create a synthetic lightcurve from a surrogate given the parameters in injection_dict.

        Missing ``luminosity_distance`` and ``redshift`` entries are filled
        in place with 1e-5 and 0 before the model is evaluated.

        Returns:
            tuple: ``(times, mags)`` as returned by ``self.model.predict``.

        Raises:
            ValueError: If ``[tmin, tmax]`` is not covered by the returned times.
        """
        injection_dict.setdefault("luminosity_distance", 1e-5)
        injection_dict.setdefault("redshift", 0)

        times, mags = self.model.predict(injection_dict)

        starts_before_model = self.tmin < times[0]
        ends_after_model = self.tmax > times[-1]
        if starts_before_model or ends_after_model:
            raise ValueError(f"Time range {(self.tmin, self.tmax)} is too large for model {self.model} with time range {(self.model.times[0], self.model.times[-1])} at redshift {injection_dict['redshift']}.")

        return times, mags
[docs]
class InjectionAfterglowpy(InjectionBase):
    """Injection class that computes its light curves with ``RunAfterglowpy``."""

    def __init__(self,
                 jet_type: int = -1,
                 *args,
                 **kwargs):
        """
        Args:
            jet_type: Jet type identifier handed to ``RunAfterglowpy``. Defaults to -1.
            *args, **kwargs: Forwarded unchanged to ``InjectionBase.__init__``.
        """
        self.jet_type = jet_type
        super().__init__(*args, **kwargs)

    def _get_injection_lc(self, injection_dict):
        """Create a synthetic lightcurve from afterglowpy given the parameters in injection_dict.

        The model is evaluated on the sorted union of all filter frequencies
        and all detection times.

        Returns:
            tuple: ``(times, mags)`` with ``mags`` keyed by filter name.
        """
        nus = np.sort([nu for filt in self.Filters for nu in filt.nus])
        times = np.sort([t for filt in self.Filters for t in self.t_detect[filt.name]])

        parameter_values = [list(injection_dict.values())]
        runner = RunAfterglowpy(self.jet_type, times, nus, parameter_values, injection_dict.keys())
        _, log_flux = runner(0)
        mJys = np.exp(log_flux).reshape((len(nus), len(times)))

        # even when 'luminosity_distance' is passed to RunAfterglowpy, it will return the abs mag (with redshift)
        mags = {
            filt.name: mag_app_from_mag_abs(filt.get_mag(mJys, nus), injection_dict["luminosity_distance"])
            for filt in self.Filters
        }
        return times, mags
[docs]
class InjectionPyblastafterglow(InjectionBase):
    """Injection class that computes its light curves with ``RunPyblastafterglow``."""

    def __init__(self,
                 jet_type: str = "tophat",
                 *args,
                 **kwargs):
        """
        Args:
            jet_type: Jet type identifier handed to ``RunPyblastafterglow``. Defaults to "tophat".
            *args, **kwargs: Forwarded unchanged to ``InjectionBase.__init__``.
        """
        self.jet_type = jet_type
        super().__init__(*args, **kwargs)

    def _get_injection_lc(self, injection_dict):
        """Create a synthetic lightcurve from pyblastafterglow given the parameters in injection_dict.

        The model is evaluated on log-spaced grids (128 frequencies, 100
        times) spanning the filter frequencies and the detection times.

        Returns:
            tuple: ``(times, mags)`` with ``mags`` a dict keyed by filter name.
        """
        nus = np.sort([nu for filt in self.Filters for nu in filt.nus])
        times = np.sort([t for filt in self.Filters for t in self.t_detect[filt.name]])

        nus = np.logspace(np.log10(nus[0]), np.log10(nus[-1]), 128)  # pbag only takes log (or linear) spaced arrays
        times = np.logspace(np.log10(times[0]), np.log10(times[-1]), 100)

        pbag = RunPyblastafterglow(self.jet_type, times, nus, [list(injection_dict.values())], injection_dict.keys())
        _, log_flux = pbag(0)
        mJys = np.exp(log_flux).reshape(len(nus), len(times))

        # Must be a dict: the result is indexed by filter name here and by
        # the caller. A list raised TypeError on the first assignment.
        mags = {}
        for filt in self.Filters:
            # NOTE(review): the afterglowpy and file based paths convert the
            # result of get_mag with mag_app_from_mag_abs; this path returns
            # it unconverted. Confirm whether RunPyblastafterglow already
            # accounts for 'luminosity_distance'.
            mags[filt.name] = filt.get_mag(mJys, nus)
        return times, mags
[docs]
class InjectionKN(InjectionBase):
    """Injection class for kilonovae; light curves must come from a training data file."""

    def __init__(self,
                 *args,
                 **kwargs):
        """Forward all arguments unchanged to ``InjectionBase.__init__``."""
        super().__init__(*args, **kwargs)

    def _get_injection_lc(self, injection_dict):
        """Always raises, since no direct KN light curve calculation is implemented."""
        message = "No direct calculation for KN injection available, use a training data file instead."
        raise NotImplementedError(message)