import jax
import jax.numpy as jnp
from jaxtyping import Array, Float
import numpy as np
from fiesta.constants import pc_to_cm, h_erg_s, c, H0, Omega_m
#######################
# DISTANCE CONVERSION #
#######################
[docs]
def Mpc_to_cm(d: float):
    """
    Convert a distance from megaparsec to centimetre.

    Args:
        d (float): Distance in Mpc.
    Returns:
        The distance in cm, i.e. d scaled by 1e6 (Mpc -> pc) and by the pc_to_cm constant.
    """
    parsecs = d * 1e6  # Mpc -> pc
    return parsecs * pc_to_cm
[docs]
def redshift_to_luminosity_distance(z: Array, H0=H0, Omega_m=Omega_m):
    """
    JAX-compatible conversion of redshift to luminosity distance.

    Evaluates d_L = c / H0 * (1 + z) * int_0^z dz' / sqrt(Omega_m * (1 + z')^3 + (1 - Omega_m)),
    i.e. a universe that contains only matter (Omega_m) and a constant dark energy density (1 - Omega_m).
    The integral is approximated with the trapezoidal rule on 100 linearly spaced nodes between 0 and z.

    Args:
        z (Array): Redshift(s). Scalars and arrays of any shape are accepted; the output has the same shape.
        H0: Hubble constant, defaults to the value from fiesta.constants.
        Omega_m: Matter density parameter, defaults to the value from fiesta.constants.
    Returns:
        luminosity_distance (Array): Luminosity distance in the length unit implied by c / H0
            (NOTE(review): depends on the units of c and H0 in fiesta.constants -- confirm there).
    """
    z = jnp.asarray(z)

    def comoving_integral(z_single: Float):
        # integration nodes between the observer (z' = 0) and the source (z' = z_single)
        z_grid = jnp.linspace(0, z_single, 100)
        inverse_E = (Omega_m * (1 + z_grid)**3 + (1 - Omega_m))**(-0.5)
        return jnp.trapezoid(x=z_grid, y=inverse_E)

    # jnp.vectorize maps the scalar integral over every element of z, so 0-d and N-d inputs work
    # in addition to the 1D arrays a bare jax.vmap would accept
    correction = jnp.vectorize(comoving_integral)(z)
    luminosity_distance = c / H0 * (1 + z) * correction
    return luminosity_distance
# Lookup table built once at import time: 250 redshifts, log-spaced from 1e-6 to 15,
# together with their luminosity distances. luminosity_distance_to_redshift interpolates in it.
z_arr = jnp.logspace(-6, jnp.log10(15), num=250)
dL_arr = redshift_to_luminosity_distance(z_arr)
[docs]
def luminosity_distance_to_redshift(dL: Array):
    """
    JAX-compatible conversion of luminosity distance to redshift.

    Linearly interpolates in the module-level table (dL_arr, z_arr), which was computed with
    redshift_to_luminosity_distance and its default cosmology for redshifts between 1e-6 and 15.
    Distances outside the tabulated range are clamped to the first / last tabulated redshift by jnp.interp.

    Args:
        dL (Array): Luminosity distance(s), in the same unit as returned by redshift_to_luminosity_distance.
    Returns:
        Array of redshifts with the same shape as dL.
    """
    redshift = jnp.interp(dL, dL_arr, z_arr)
    return redshift
###################
# FLUX CONVERSION #
###################
[docs]
def Flambda_to_Fnu(F_lambda: Float[Array, "n_lambdas n_times"], lambdas: Float[Array, "n_lambdas"]) -> tuple[Float[Array, "n_lambdas n_times"], Float[Array, "n_lambdas"]]:
    """
    JAX-compatible conversion of wavelength flux in erg cm^{-2} s^{-1} Angström^{-1} to spectral flux density in mJys.

    The output rows are in reversed order with respect to the input rows, so for wavelengths given in ascending order
    the returned frequencies (and the corresponding flux rows) are ascending as well.

    Args:
        F_lambda (Float[Array]): 2D flux density array in erg cm^{-2} s^{-1} Angström^{-1}. The rows correspond to the wavelengths provided in lambdas.
        lambdas (Float[Array]): 1D wavelength array in Angström.
    Returns:
        tuple:
            mJys (Float[Array]): 2D spectral flux density array in mJys, rows reversed with respect to F_lambda.
            nus (Float[Array]): 1D frequency array in Hz, reversed with respect to lambdas.
    """
    F_lambda = F_lambda.reshape(lambdas.shape[0], -1)
    log_F_lambda = jnp.log10(F_lambda)  # go to log because of large factors
    # F_nu [Jy] = 3.3356e4 * (lambda [Angström])^2 * F_lambda, see https://en.wikipedia.org/wiki/AB_magnitude
    log_F_nu = log_F_lambda + 2 * jnp.log10(lambdas[:, None]) + jnp.log10(3.3356) + 4
    F_nu = 10**(log_F_nu)
    F_nu = F_nu[::-1, :]  # reverse the order to get lowest frequencies in first row
    mJys = 1e3 * F_nu  # convert Jys to mJys
    # lambdas * 1e-10 is the wavelength in metres, so c is presumably in m/s -- confirm in fiesta.constants
    nus = c / (lambdas * 1e-10)
    nus = nus[::-1]
    return mJys, nus
[docs]
def Fnu_to_Flambda(F_nu: Float[Array, "n_nus n_times"], nus: Float[Array, "n_nus"]) -> tuple[Float[Array, "n_nus n_times"], Float[Array, "n_nus"]]:
    """
    JAX-compatible conversion of spectral flux density in mJys to wavelength flux in erg cm^{-2} s^{-1} Angström^{-1}.

    The output rows are in reversed order with respect to the input rows, so for frequencies given in ascending order
    the returned wavelengths (and the corresponding flux rows) are ascending as well.

    Args:
        F_nu (Float[Array]): 2D flux density array in mJys. The rows correspond to the frequencies provided in nus.
        nus (Float[Array]): 1D frequency array in Hz.
    Returns:
        tuple:
            F_lambda (Float[Array]): 2D wavelength flux density array in erg cm^{-2} s^{-1} Angström^{-1}, rows reversed with respect to F_nu.
            lambdas (Float[Array]): 1D wavelength array in Angström, reversed with respect to nus.
    """
    F_nu = F_nu.reshape(nus.shape[0], -1)
    log_F_nu = jnp.log10(F_nu)  # go to log because of large factors
    log_F_nu = log_F_nu - 3  # convert mJys to Jys
    # F_lambda = 3.3356e-42 * (nu [Hz])^2 * F_nu [Jy], the frequency-space form of the relation used in Flambda_to_Fnu
    log_F_lambda = log_F_nu + 2 * jnp.log10(nus[:, None]) + jnp.log10(3.3356) - 42
    F_lambda = 10**(log_F_lambda)
    F_lambda = F_lambda[::-1, :]  # reverse the order to get the lowest wavelengths in first row
    lambdas = c / nus
    lambdas = lambdas[::-1] * 1e10  # the 1e10 factor yields Angström if c / nus is in metres -- confirm units of c
    return F_lambda, lambdas
[docs]
def apply_redshift(F_nu: Float[Array, "n_nus n_times"], times: Float[Array, "n_times"], nus: Float[Array, "n_nus"], z: Float):
    """
    Shift a 2D flux density array and its time and frequency axes from the source frame to the observer frame.

    Only the redshift of the frequency axis is applied to the flux, i.e. the flux is multiplied by (1 + z).
    No distance factor is applied here; cosmological energy loss and time elongation are meant to be
    accounted for separately through the luminosity distance.

    Args:
        F_nu (Float[Array]): 2D flux density array in mJy in source frame. The rows correspond to the frequencies provided in nus, the columns to times.
        times (Float[Array]): 1D time array in source frame.
        nus (Float[Array]): 1D frequency array in source frame.
        z (Float): Redshift of the source.
    Returns:
        tuple:
            times (Float[Array]): 1D time array in observer frame, times * (1 + z).
            nus (Float[Array]): 1D frequency array in observer frame, nus / (1 + z).
            F_nu (Float[Array]): 2D flux density redshifted to observer frame, F_nu * (1 + z).
    """
    observed_flux = F_nu * (1 + z)  # frequency redshift only, see docstring
    observed_times = times * (1 + z)
    observed_nus = nus / (1 + z)
    return observed_times, observed_nus, observed_flux
########################
# MAGNITUDE CONVERSION #
########################
[docs]
def monochromatic_AB_mag(flux: Float[Array, "n_nus n_times"],
                         nus: Float[Array, "n_nus"],
                         nus_filt: Float[Array, "n_nus_filt"],
                         trans_filt: Float[Array, "n_nus_filt"],
                         ref_flux: Float) -> Float[Array, "n_times"]:
    """
    AB magnitude light curve at the frequency of a monochromatic filter.

    The flux is interpolated linearly (in frequency) onto nus_filt for every time step, multiplied by the
    filter transmissivity and converted to AB magnitudes. Only the row belonging to the first entry of
    nus_filt is returned.

    Args:
        flux (Float[Array, "n_nus n_times"]): Spectral flux density as a 2D array in mJys.
        nus (Float[Array, "n_nus"]): Associated frequencies.
        nus_filt (Float[Array, "n_nus_filt"]): frequency array of the filter, same unit as nus.
        trans_filt (Float[Array, "n_nus_filt"]): transmissivity array of the filter.
        ref_flux (Float): Not used by this function; presumably kept so the signature matches bandpass_AB_mag.
    Returns:
        mag (Float[Array, "n_times"]): AB magnitude at nus_filt[0] for every time step.
    """
    def resample_on_filter(spectrum):
        # spectrum is one column of flux, i.e. the flux density over nus at a single time
        return jnp.interp(nus_filt, nus, spectrum)

    filter_flux = jax.vmap(resample_on_filter, in_axes=1, out_axes=1)(flux)  # shape (n_nus_filt, n_times)
    filter_flux = filter_flux * trans_filt[:, None]
    return mJys_to_mag_jnp(filter_flux)[0]
[docs]
def bandpass_AB_mag(flux: Float[Array, "n_nus n_times"],
                    nus: Float[Array, "n_nus"],
                    nus_filt: Float[Array, "n_nus_filt"],
                    trans_filt: Float[Array, "n_nus_filt"],
                    ref_flux: Float) -> Float[Array, "n_times"]:
    """
    JAX-compatible equivalent of sncosmo.TimeSeriesSource.bandmag(). Unlike sncosmo, the frequency flux is used
    instead of the wavelength flux; according to the original authors this function is tested to yield the
    same results as the sncosmo version.

    For every time step the flux is interpolated onto the filter frequencies, weighted with the transmissivity,
    divided by h * nu, and integrated over nus_filt with the trapezoidal rule. The magnitude is
    -2.5 * log10(integrated flux / ref_flux).

    Args:
        flux (Float[Array, "n_nus n_times"]): Spectral flux density as a 2D array in mJys.
        nus (Float[Array, "n_nus"]): Associated frequencies in Hz
        nus_filt (Float[Array, "n_nus_filt"]): frequency array of the filter in Hz
        trans_filt (Float[Array, "n_nus_filt"]): transmissivity array of the filter in transmitted photons / incoming photons
        ref_flux (Float): flux in mJy for which the filter is 0 mag
    Returns:
        mag (Float[Array, "n_times"]): AB magnitude in the filter for every time step.
    """
    def resample_on_filter(spectrum):
        # spectrum is one column of flux, i.e. the flux density over nus at a single time
        return jnp.interp(nus_filt, nus, spectrum)

    def integrate_over_filter(column):
        return jnp.trapezoid(y=column, x=nus_filt)

    filter_flux = jax.vmap(resample_on_filter, in_axes=1, out_axes=1)(flux)  # shape (n_nus_filt, n_times)

    # work in log space because of the large factors involved
    log_flux = jnp.log10(filter_flux)
    log_flux = log_flux + jnp.log10(trans_filt[:, None])
    log_flux = log_flux - jnp.log10(h_erg_s) - jnp.log10(nus_filt[:, None])  # https://en.wikipedia.org/wiki/AB_magnitude

    # scale by the global maximum so that the integrand lies between 0 and 1, otherwise infs could appear
    log_scale = jnp.max(log_flux)
    scaled_integrand = 10**(log_flux - log_scale)
    scaled_band_flux = jax.vmap(integrate_over_filter, in_axes=1)(scaled_integrand)
    log_band_flux = jnp.log10(scaled_band_flux) + log_scale  # reintroduce the scale

    return -2.5 * log_band_flux + 2.5 * jnp.log10(ref_flux)
[docs]
def integrated_AB_mag(flux: Float[Array, "n_nus n_times"],
                      nus: Float[Array, "n_nus"],
                      nus_filt: Float[Array, "n_nus_filt"],
                      trans_filt: Float[Array, "n_nus_filt"]) -> Float[Array, "n_times"]:
    """
    AB magnitude of the transmissivity-weighted flux density averaged over the frequency range of a filter.

    For every time step the flux is interpolated onto the filter frequencies, weighted with the transmissivity,
    integrated over nus_filt with the trapezoidal rule and divided by nus_filt[-1] - nus_filt[0].
    The resulting mean flux density is converted to an AB magnitude.

    Args:
        flux (Float[Array, "n_nus n_times"]): Spectral flux density as a 2D array in mJys.
        nus (Float[Array, "n_nus"]): Associated frequencies.
        nus_filt (Float[Array, "n_nus_filt"]): frequency array of the filter, same unit as nus.
        trans_filt (Float[Array, "n_nus_filt"]): transmissivity array of the filter.
    Returns:
        mag (Float[Array, "n_times"]): AB magnitude for every time step.
    """
    def resample_on_filter(spectrum):
        # spectrum is one column of flux, i.e. the flux density over nus at a single time
        return jnp.interp(nus_filt, nus, spectrum)

    def integrate_over_filter(column):
        return jnp.trapezoid(y=column, x=nus_filt)

    filter_flux = jax.vmap(resample_on_filter, in_axes=1, out_axes=1)(flux)  # shape (n_nus_filt, n_times)

    # work in log space because of the large factors involved
    log_flux = jnp.log10(filter_flux)
    log_flux = log_flux + jnp.log10(trans_filt[:, None])

    # scale by the global maximum so that the integrand lies between 0 and 1, otherwise infs could appear
    log_scale = jnp.max(log_flux)
    scaled_integrand = 10**(log_flux - log_scale)
    scaled_band_flux = jax.vmap(integrate_over_filter, in_axes=1)(scaled_integrand)
    log_mean_flux = jnp.log10(scaled_band_flux) + log_scale  # reintroduce the scale
    log_mean_flux = log_mean_flux - jnp.log10(nus_filt[-1] - nus_filt[0])  # divide by integration range

    mean_flux = 10**log_mean_flux
    return mJys_to_mag_jnp(mean_flux)
[docs]
@jax.jit
def mJys_to_mag_jnp(mJys: Array):
    """
    JAX-compatible (jitted) conversion of spectral flux density in mJys to AB magnitudes.

    Evaluates mag = -2.5 * log10(mJys) + 26 * 2.5 - 48.6, where the 26 * 2.5 term accounts for
    1 mJy = 1e-26 erg cm^{-2} s^{-1} Hz^{-1}, see https://en.wikipedia.org/wiki/AB_magnitude.

    Args:
        mJys (Array): Spectral flux density in mJys, any shape.
    Returns:
        Array of AB magnitudes with the same shape as mJys.
    """
    zero_point = -48.6
    log_flux_term = -1 * jnp.log10(mJys) * 2.5
    unit_offset = 26 * 2.5  # mJy -> erg cm^{-2} s^{-1} Hz^{-1} inside the logarithm
    return zero_point + log_flux_term + unit_offset
# TODO: need a np and jnp version?
# TODO: account for extinction
[docs]
def mJys_to_mag_np(mJys: np.ndarray):
    """
    NumPy conversion of spectral flux density in mJys to AB magnitudes.

    Evaluates mag = -2.5 * log10(F_nu) - 48.6 with F_nu in erg cm^{-2} s^{-1} Hz^{-1},
    where 1 Jy = 1e-23 erg cm^{-2} s^{-1} Hz^{-1}.

    Args:
        mJys (np.ndarray): Spectral flux density in mJys. Scalars and plain Python sequences are accepted as well.
    Returns:
        AB magnitudes with the same shape as mJys.
    """
    # np.multiply (instead of `1e-3 * mJys`) also handles lists and tuples
    Jys = np.multiply(1e-3, mJys)
    mag = -48.6 + -1 * np.log10(Jys / 1e23) * 2.5
    return mag
[docs]
def mag_app_from_mag_abs(mag_abs: Array,
                         luminosity_distance: Float) -> Array:
    """
    Convert absolute magnitudes to apparent magnitudes by adding the distance modulus.

    Evaluates mag_abs + 5 * log10(luminosity_distance * 1e6 / 10).

    Args:
        mag_abs (Array): Absolute magnitude(s).
        luminosity_distance (Float): Luminosity distance, presumably in Mpc (the factor 1e6 converts
            it to pc before dividing by the 10 pc reference distance).
    Returns:
        Array of apparent magnitudes.
    """
    distance_in_10pc = luminosity_distance * 1e6 / 10.0
    distance_modulus = 5.0 * jnp.log10(distance_in_10pc)
    return mag_abs + distance_modulus