import re
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, Int
from sncosmo.bandpasses import _BANDPASSES, _BANDPASS_INTERPOLATORS
from sncosmo import get_bandpass
from fiesta.conversions import monochromatic_AB_mag, bandpass_AB_mag, integrated_AB_mag
import fiesta.constants as constants
#########################
### Filters ###
#########################
class Filter:
    """
    Photometric filter used to convert spectral flux densities into AB magnitudes.

    Depending on how it is constructed, a filter has one of three ``filt_type`` values:

    * ``"bandpass"``: a transmission curve, taken from sncosmo or passed explicitly via ``nus``/``trans``.
    * ``"monochromatic"``: a single frequency parsed from a name ending in ``GHz`` or ``<E>keV``.
    * ``"integrated"``: a flat band between two energies parsed from a name ending in ``<E1>-<E2>keV``.

    After construction the instance exposes ``name``, ``nu``, ``nus``, ``trans``, ``filt_type``,
    ``wavelength``, ``ref_flux`` and the callable ``get_mag(Fnu, nus)``.
    """

    # Reference flux density used for the AB magnitude conversion, in mJy.
    _AB_REF_FLUX_MJY = 3631000.
    # Bandpasses with more frequency samples than this are resampled (to avoid memory issues later).
    _MAX_BANDPASS_SAMPLES = 100
    # Number of frequency samples used for an integrated (energy range) filter.
    _N_INTEGRATED_SAMPLES = 20

    # Name formats for X-ray filters; the capture groups hold the energies in keV.
    _KEV_SINGLE = re.compile(r'^.*[^0-9.]-(\d+(?:\.\d*)?)keV$')
    _KEV_RANGE = re.compile(r'^.*[^0-9.]-(\d+(?:\.\d*)?)-(\d+(?:\.\d*)?)keV$')

    def __init__(self,
                 name: str,
                 nus: Array = None,
                 trans: Array = None):
        """
        Filter class that uses the bandpass properties from sncosmo or just a simple monochromatic filter based on the name.

        The name is checked in this order: sncosmo bandpass, sncosmo bandpass interpolator,
        name ending in "GHz", name ending in "keV". Only if none of these apply are ``nus``/``trans`` used.

        Args:
            name (str): Name of the filter. Will be either passed to sncosmo to get the optical bandpass, or the unit at the end will be used to create a monochromatic filter. Supported units are keV and GHz.
            nus (Array, optional): Frequencies of a custom bandpass. Only used when the name is not recognized otherwise.
            trans (Array, optional): Transmission at ``nus`` for a custom bandpass. Defaults to ones (flat transmission) when ``nus`` is given without ``trans``.

        Raises:
            ValueError: If the name ends in "GHz" or "keV" but cannot be parsed, or if the name is not recognized and no ``nus`` are given.
        """
        self.name = name

        if self._is_registered(self.name, _BANDPASSES):
            self._init_from_bandpass(get_bandpass(self.name))  # sncosmo bandpass
        elif self._is_registered(self.name, _BANDPASS_INTERPOLATORS):
            # these bandpass interpolators require a radius (here by default 0 cm)
            self._init_from_bandpass(get_bandpass(self.name, 0))
        elif self.name.endswith("GHz"):
            self._init_from_ghz_name()
        elif self.name.endswith("keV"):
            self._init_from_kev_name()
        elif nus is not None:
            self._init_from_arrays(nus, trans)
        else:
            raise ValueError(f"Filter {self.name} not recognized")

        self.wavelength = constants.c / self.nu * 1e10
        self._calculate_ref_flux()

        # The lambdas read the attributes at call time, so they always use the current arrays.
        if self.filt_type == "bandpass":
            self.get_mag = lambda Fnu, nus: bandpass_AB_mag(Fnu, nus, self.nus, self.trans, self.ref_flux)
        elif self.filt_type == "monochromatic":
            self.get_mag = lambda Fnu, nus: monochromatic_AB_mag(Fnu, nus, self.nus, self.trans, self.ref_flux)
        elif self.filt_type == "integrated":
            self.get_mag = lambda Fnu, nus: integrated_AB_mag(Fnu, nus, self.nus, self.trans)

    @staticmethod
    def _is_registered(name, registry) -> bool:
        """Check whether ``name`` equals the first element of any entry in ``registry._primary_loaders``."""
        # NOTE(review): assumes each entry is a sequence whose first element is the filter name -- confirm against sncosmo.
        return any(entry[0] == name for entry in registry._primary_loaders)

    def _init_from_bandpass(self, bandpass) -> None:
        """Set nu, nus, trans and filt_type from an sncosmo bandpass object."""
        # presumably wavelengths are in Angstrom and constants.c in m/s, hence the 1e-10 factor -- TODO confirm
        self.nu = constants.c / (bandpass.wave_eff * 1e-10)
        self.nus = constants.c / (bandpass.wave[::-1] * 1e-10)
        # reverse the array to get the transmission as function of frequency (not wavelength)
        self.trans = bandpass.trans[::-1]
        if len(self.nus) > self._MAX_BANDPASS_SAMPLES:  # to avoid memory issues later
            self.nus = jnp.linspace(self.nus[0], self.nus[-1], self._MAX_BANDPASS_SAMPLES)
            # re-evaluate the transmission on the new grid (the bandpass is called with wavelengths)
            self.trans = bandpass(constants.c / self.nus * 1e10)
        self.filt_type = "bandpass"

    def _init_from_ghz_name(self) -> None:
        """Set up a monochromatic filter at the frequency given by the last number in a '...GHz' name."""
        # Hyphens are stripped first so a separator as in 'radio-3GHz' is not read as a minus sign.
        numbers = re.findall(r"[-+]?(?:\d*\.*\d+)", self.name.replace("-", ""))
        if not numbers:
            raise ValueError(f"Radio filter {self.name} must contain a frequency in front of 'GHz'")
        self.nu = float(numbers[-1]) * 1e9
        self.nus = jnp.array([self.nu])
        self.trans = jnp.ones(1)
        self.filt_type = "monochromatic"

    @classmethod
    def _parse_kev_energies(cls, name: str) -> tuple:
        """
        Extract the energies (in keV) from an X-ray filter name.

        Returns a 1-tuple for the format '...-<E>keV' and a 2-tuple for '...-<E1>-<E2>keV'.
        The energies are taken from the end of the name, so digits elsewhere in the name are ignored.

        Raises:
            ValueError: If the name matches neither format.
        """
        for pattern in (cls._KEV_SINGLE, cls._KEV_RANGE):
            found = pattern.match(name)
            if found is not None:
                return tuple(float(energy) for energy in found.groups())
        raise ValueError(f"X-ray filter {name} must either be in format 'X-ray-*keV' or 'X-ray-*-*keV' ")

    def _init_from_kev_name(self) -> None:
        """Set up a monochromatic or integrated X-ray filter from a '...keV' name."""
        energies = self._parse_kev_energies(self.name)
        if len(energies) == 1:
            # keV -> eV -> frequency via E = h * nu
            self.nu = energies[0]*1000*constants.eV / constants.h
            self.nus = jnp.array([self.nu])
            self.trans = jnp.ones(1)
            self.filt_type = "monochromatic"
        else:
            nu1 = energies[0]*1000*constants.eV / constants.h
            nu2 = energies[1]*1000*constants.eV / constants.h
            self.nus = jnp.linspace(nu1, nu2, self._N_INTEGRATED_SAMPLES)
            self.trans = jnp.ones_like(self.nus)
            self.nu = jnp.mean(self.nus)
            self.filt_type = "integrated"

    def _init_from_arrays(self, nus, trans) -> None:
        """Set up a custom bandpass from explicit frequencies and an optional transmission curve."""
        self.nus = nus
        self.nu = jnp.mean(self.nus)
        # Without an explicit transmission curve the bandpass is flat. This has to be stored on the
        # instance, since the reference flux and get_mag both read self.trans.
        self.trans = trans if trans is not None else jnp.ones_like(nus)
        self.filt_type = "bandpass"

    def _calculate_ref_flux(self,):
        """method to determine the reference flux for the magnitude conversion."""
        if self.filt_type in ["monochromatic", "integrated"]:
            self.ref_flux = self._AB_REF_FLUX_MJY  # mJy
        elif self.filt_type == "bandpass":
            # transmission-weighted integral over frequency, see https://en.wikipedia.org/wiki/AB_magnitude
            integrand = self.trans / (constants.h_erg_s * self.nus)
            integral = jnp.trapezoid(y=integrand, x=self.nus)
            self.ref_flux = self._AB_REF_FLUX_MJY * integral.item()  # mJy

    def get_mags(self, fluxes: Float[Array, "n_samples n_nus n_times"], nus: Float[Array, "n_nus"]) -> Float[Array, "n_samples n_times"]:
        """
        Apply ``get_mag`` to every sample of a batch of flux arrays.

        Args:
            fluxes: Flux densities with shape (n_samples, n_nus, n_times).
            nus: Frequencies belonging to the second axis of ``fluxes``.

        Returns:
            Magnitudes with shape (n_samples, n_times).
        """
        def get_single(flux):
            return self.get_mag(flux, nus)

        # vectorize over the leading (sample) axis
        mags = jax.vmap(get_single)(fluxes)
        return mags