Source code for fiesta.inference.prior.prior
from dataclasses import field
from typing import Callable, Optional

import astropy.units as u
import jax
import jax.numpy as jnp
from astropy.cosmology import Planck18, z_at_value
from beartype import beartype as typechecker
from jaxtyping import Array, Float, PRNGKeyArray, jaxtyped
[docs]
class Prior(object):
    """
    A thin base class to do book keeping.

    Should not be used directly since it does not implement any of the real methods
    (``sample`` and ``log_prob`` raise ``NotImplementedError``). The rationale behind
    this is to have a class that can be used to keep track of the names of the
    parameters and the transforms that are applied to them.
    """

    naming: list[str]
    # NOTE(review): ``field`` only has an effect inside a dataclass; here the class-level
    # value is a plain ``dataclasses.Field`` object that ``__init__`` shadows per instance.
    transforms: dict[str, tuple[str, Callable]] = field(default_factory=dict)

    @property
    def n_dim(self):
        """Number of parameters handled by this prior (the length of ``naming``)."""
        return len(self.naming)

    def __init__(
        self,
        naming: list[str],
        transforms: Optional[dict[str, tuple[str, Callable]]] = None,
    ):
        """
        Parameters
        ----------
        naming : list[str]
            A list of names for the parameters of the prior.
        transforms : dict[str, tuple[str, Callable]], optional
            A dictionary of transforms to apply to the parameters. The keys are
            the names of the parameters and the values are a tuple of the name
            of the transformed output and the transform itself. Parameters
            without an entry are passed through under their own name; entries
            for names not in ``naming`` are ignored. ``None`` (the default)
            means no transforms.
        """
        # ``None`` instead of a ``{}`` default: a dict default is created once at
        # definition time and shared by every call.
        if transforms is None:
            transforms = {}
        self.naming = naming
        self.transforms = {}

        def make_lambda(name):
            # Without the factory, the lambda would close over the loop variable
            # ``name`` itself and every entry would read the last name in ``naming``.
            return lambda x: x[name]

        for name in naming:
            if name in transforms:
                self.transforms[name] = transforms[name]
            else:
                self.transforms[name] = (name, make_lambda(name))

    def transform(self, x: dict[str, Float]) -> dict[str, Float]:
        """
        Apply the transforms to the parameters.

        Parameters
        ----------
        x : dict
            A dictionary of parameters. Names should match the ones in the prior.

        Returns
        -------
        dict
            A dictionary keyed by the transforms' output names; each transform
            is called with the full input dictionary ``x``.
        """
        output = {}
        for value in self.transforms.values():
            output[value[0]] = value[1](x)
        return output

    def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]:
        """
        Turn an array into a dictionary keyed by the parameter names.

        Parameters
        ----------
        x : Array
            An array of parameters whose leading axis is paired, in order, with
            ``naming`` (``zip`` stops at the shorter of the two).
        """
        return dict(zip(self.naming, x))

    def sample(
        self, rng_key: PRNGKeyArray, n_samples: int
    ) -> dict[str, Float[Array, " n_samples"]]:
        """Draw ``n_samples`` samples; must be implemented by subclasses."""
        raise NotImplementedError

    def log_prob(self, x: dict[str, Array]) -> Float:
        """Log-density of the prior at ``x``; must be implemented by subclasses."""
        raise NotImplementedError
[docs]
class InterpedPrior(Prior):
    """
    One-dimensional prior defined by a tabulated density.

    The density ``yy`` on the grid ``xx`` is normalised with the trapezoidal rule;
    sampling inverts the CDF built from the same trapezoids, and ``log_prob``
    linearly interpolates the normalised density (zero outside the grid).
    """

    xx: Array
    yy: Array

    def __repr__(self):
        return f"InterpedPrior(x={self.xx}, y={self.yy})"

    def __init__(self,
                 xx: Array,
                 yy: Array,
                 naming: list[str]):
        """
        Parameters
        ----------
        xx : Array
            Grid of parameter values. Presumably increasing, since it is used as
            the ``xp`` argument of ``jnp.interp`` — TODO confirm with callers.
        yy : Array
            Unnormalised density values at ``xx``.
        naming : list[str]
            Single-element list holding the parameter name.
        """
        normalization_factor = jnp.trapezoid(y=yy, x=xx)
        self.xx = jnp.array(xx)
        self.yy = jnp.array(yy) / normalization_factor
        # Cumulative trapezoid integral of the normalised density: starts at 0
        # and ends at (approximately) 1.
        dx = jnp.diff(self.xx)
        increments = 0.5 * (self.yy[:-1] + self.yy[1:]) * dx
        self.cdf = jnp.concatenate([jnp.array([0.0]), jnp.cumsum(increments)])
        super().__init__(naming, {})
        assert self.n_dim == 1, "InterpedPrior needs to be a 1D distribution"

    def sample(self, rng_key: PRNGKeyArray, n_samples: int) -> dict[str, Float[Array, " n_samples"]]:
        """
        Draw samples by inverse-CDF interpolation.

        Parameters
        ----------
        rng_key : PRNGKeyArray
            A random key to use for sampling.
        n_samples : int
            The number of samples to draw.
        """
        # ``shape`` must be a tuple; ``(n_samples)`` would just be the int itself.
        alpha = jax.random.uniform(rng_key, shape=(n_samples,))
        samples = jnp.interp(alpha, self.cdf, self.xx)
        return self.add_name(samples[None])

    def log_prob(self, x: dict[str, Array]) -> Float:
        """Log of the interpolated density; ``-inf`` outside ``[xx[0], xx[-1]]``."""
        variable = x[self.naming[0]]
        return jnp.log(jnp.interp(variable, self.xx, self.yy, left=0, right=0))
[docs]
@jaxtyped(typechecker=typechecker)
class Uniform(Prior):
    """Uniform prior on the closed interval ``[xmin, xmax]``."""

    xmin: float = 0.0
    xmax: float = 1.0

    def __repr__(self):
        return f"Uniform(xmin={self.xmin}, xmax={self.xmax})"

    def __init__(
        self,
        xmin: Float,
        xmax: Float,
        naming: list[str],
        transforms: dict[str, tuple[str, Callable]] = {},
        **kwargs,
    ):
        """
        Parameters
        ----------
        xmin, xmax : Float
            Lower and upper bounds of the support.
        naming : list[str]
            Single-element list holding the parameter name.
        transforms : dict
            Passed through to ``Prior.__init__`` (only read, never mutated).
        **kwargs
            Accepted and ignored.
        """
        super().__init__(naming, transforms)
        assert self.n_dim == 1, "Uniform needs to be 1D distributions"
        self.xmax = xmax
        self.xmin = xmin

    def sample(
        self, rng_key: PRNGKeyArray, n_samples: int
    ) -> dict[str, Float[Array, " n_samples"]]:
        """
        Sample from a uniform distribution.

        Parameters
        ----------
        rng_key : PRNGKeyArray
            A random key to use for sampling.
        n_samples : int
            The number of samples to draw.

        Returns
        -------
        samples : dict
            Samples from the distribution. The keys are the names of the parameters.
        """
        samples = jax.random.uniform(
            rng_key, (n_samples,), minval=self.xmin, maxval=self.xmax
        )
        return self.add_name(samples[None])

    def log_prob(self, x: dict[str, Array]) -> Float:
        """Log-density ``-log(xmax - xmin)`` inside ``[xmin, xmax]``, ``-inf`` outside."""
        variable = x[self.naming[0]]
        # Closed support: jax.random.uniform treats ``minval`` as inclusive, so a drawn
        # sample may equal ``xmin`` and must not get -inf. This also matches the closed
        # bounds used by TruncatedNormal and Constraint.
        output = jnp.where(
            (variable > self.xmax) | (variable < self.xmin),
            jnp.zeros_like(variable) - jnp.inf,
            jnp.zeros_like(variable),
        )
        return output + jnp.log(1.0 / (self.xmax - self.xmin))
[docs]
@jaxtyped(typechecker=typechecker)
class Normal(Prior):
    """Gaussian prior with mean ``mu`` and standard deviation ``sigma``."""

    mu: float = 0.0
    sigma: float = 1.0

    def __repr__(self):
        return f"Normal(mu={self.mu}, sigma={self.sigma})"

    def __init__(
        self,
        mu: Float,
        sigma: Float,
        naming: list[str],
        transforms: dict[str, tuple[str, Callable]] = {},
        **kwargs,
    ):
        """
        Parameters
        ----------
        mu : Float
            Mean of the Gaussian.
        sigma : Float
            Standard deviation of the Gaussian.
        naming : list[str]
            Single-element list holding the parameter name.
        transforms : dict
            Forwarded to ``Prior.__init__``.
        **kwargs
            Accepted and ignored.
        """
        super().__init__(naming, transforms)
        assert self.n_dim == 1, "Normal needs to be 1D distributions"
        self.mu, self.sigma = mu, sigma

    def sample(
        self, rng_key: PRNGKeyArray, n_samples: int
    ) -> dict[str, Float[Array, " n_samples"]]:
        """
        Draw ``n_samples`` values as ``mu + sigma * z`` with ``z`` standard normal.

        Parameters
        ----------
        rng_key : PRNGKeyArray
            A random key to use for sampling.
        n_samples : int
            The number of samples to draw.

        Returns
        -------
        samples : dict
            Mapping from the parameter name to the drawn values.
        """
        standard_draws = jax.random.normal(rng_key, (n_samples,))
        draws = self.mu + self.sigma * standard_draws
        return self.add_name(draws[None])

    def log_prob(self, x: dict[str, Array]) -> Float:
        """Gaussian log-density of the named parameter in ``x``."""
        value = x[self.naming[0]]
        variance = self.sigma**2
        log_normalisation = 0.5 * jnp.log(2 * jnp.pi * variance)
        return -1 / (2 * variance) * (value - self.mu) ** 2 - log_normalisation
[docs]
class TruncatedNormal(Prior):
    """Truncated normal distribution with explicit bounds.

    Useful for informed priors from population studies (e.g., superphot+).
    The SVISampler uses xmin/xmax for its guide constraints and mu/sigma
    for the model's TruncatedNormal distribution.

    The density is the normal density with mean ``mu`` and standard deviation
    ``sigma``, restricted to ``[xmin, xmax]`` and renormalised over that interval.
    """

    mu: float = 0.0
    sigma: float = 1.0
    xmin: float = -10.0
    xmax: float = 10.0

    def __repr__(self):
        return f"TruncatedNormal(mu={self.mu}, sigma={self.sigma}, xmin={self.xmin}, xmax={self.xmax})"

    def __init__(
        self,
        mu: Float,
        sigma: Float,
        xmin: Float,
        xmax: Float,
        naming: list[str],
        transforms: dict[str, tuple[str, Callable]] = {},
        **kwargs,
    ):
        """
        Parameters
        ----------
        mu, sigma : Float
            Mean and standard deviation of the untruncated normal.
        xmin, xmax : Float
            Truncation bounds (inclusive).
        naming : list[str]
            Single-element list holding the parameter name.
        transforms : dict
            Forwarded to ``Prior.__init__``.
        **kwargs
            Accepted and ignored.

        Raises
        ------
        AssertionError
            Unless exactly one name is given, ``sigma > 0`` and ``xmax > xmin``.
        """
        super().__init__(naming, transforms)
        assert self.n_dim == 1, "TruncatedNormal needs to be 1D distributions"
        assert sigma > 0, f"Provided sigma {sigma} needs to be positive."
        assert xmax > xmin, f"Provided xmax {xmax} is smaller than xmin, needs to be larger than {xmin}."
        self.mu = mu
        self.sigma = sigma
        self.xmin = xmin
        self.xmax = xmax

    def _standardized_bounds(self):
        """Return the truncation bounds in units of standard deviations from ``mu``."""
        lower = (self.xmin - self.mu) / self.sigma
        upper = (self.xmax - self.mu) / self.sigma
        return lower, upper

    def sample(
        self, rng_key: PRNGKeyArray, n_samples: int
    ) -> dict[str, Float[Array, " n_samples"]]:
        """Draw ``n_samples`` values from the truncated normal."""
        # Sample a truncated *standard* normal on the standardized bounds, then shift/scale.
        lower, upper = self._standardized_bounds()
        samples = jax.random.truncated_normal(
            rng_key, lower, upper, (n_samples,),
        )
        samples = self.mu + self.sigma * samples
        return self.add_name(samples[None])

    def log_prob(self, x: dict[str, Array]) -> Float:
        """Truncated-normal log-density; ``-inf`` outside ``[xmin, xmax]``."""
        variable = x[self.naming[0]]
        in_bounds = (variable >= self.xmin) & (variable <= self.xmax)
        lower, upper = self._standardized_bounds()
        # truncnorm.logpdf evaluates log(Phi(upper) - Phi(lower)) stably. The naive
        # log(cdf(upper) - cdf(lower)) cancels to log(0) when both bounds lie far in
        # the same tail, which turned every in-bounds log-density into +inf.
        log_p = jax.scipy.stats.truncnorm.logpdf(
            variable, lower, upper, loc=self.mu, scale=self.sigma
        )
        return jnp.where(in_bounds, log_p, -jnp.inf)
[docs]
@jaxtyped(typechecker=typechecker)
class UniformVolume(Prior):
    """
    Radial prior uniform in (Euclidean) volume ``V = 4/3 * pi * x**3``, i.e. with
    density proportional to ``x**2`` on the closed interval ``[xmin, xmax]``.
    """

    xmin: float = 10.
    xmax: float = 1e5

    def __repr__(self):
        return f"UniformVolume(xmin={self.xmin}, xmax={self.xmax})"

    def __init__(
        self,
        xmin: Float,
        xmax: Float,
        naming: list[str],
        transforms: dict[str, tuple[str, Callable]] = {},
        **kwargs,
    ):
        """
        Parameters
        ----------
        xmin, xmax : Float
            Bounds of the radial coordinate.
        naming : list[str]
            Single-element list holding the parameter name.
        transforms : dict
            Forwarded to ``Prior.__init__``.
        **kwargs
            Accepted and ignored.
        """
        super().__init__(naming, transforms)
        assert self.n_dim == 1, "UniformVolume needs to be 1D distributions"
        self.xmax = xmax
        self.xmin = xmin

    def sample(
        self, rng_key: PRNGKeyArray, n_samples: int
    ) -> dict[str, Float[Array, " n_samples"]]:
        """
        Sample luminosity distance from a distribution uniform in volume.

        Parameters
        ----------
        rng_key : PRNGKeyArray
            A random key to use for sampling.
        n_samples : int
            The number of samples to draw.

        Returns
        -------
        samples : dict
            Samples from the distribution. The keys are the names of the parameters.
        """
        vol_max = 4/3 * jnp.pi * self.xmax**3
        vol_min = 4/3 * jnp.pi * self.xmin**3
        # Draw a volume uniformly, then invert V = 4/3 * pi * r**3 for the radius.
        samples = jax.random.uniform(
            rng_key, (n_samples,), minval=vol_min, maxval=vol_max
        )
        samples = (3 / (4*jnp.pi) * samples)**(1/3)
        return self.add_name(samples[None])

    def log_prob(self, x: dict[str, Array]) -> Float:
        """Log of ``4 * pi * x**2 / (V(xmax) - V(xmin))`` inside the bounds, ``-inf`` outside."""
        variable = x[self.naming[0]]
        vol_max = 4/3 * jnp.pi * self.xmax**3
        vol_min = 4/3 * jnp.pi * self.xmin**3
        # Closed support, consistent with Uniform/TruncatedNormal/Constraint: the
        # sampler's lower volume bound is inclusive, so x == xmin is attainable.
        output = jnp.where(
            (variable > self.xmax) | (variable < self.xmin),
            jnp.zeros_like(variable) - jnp.inf,
            jnp.log(
                4*jnp.pi*variable**2 / (vol_max-vol_min)
            ),
        )
        return output
[docs]
@jaxtyped(typechecker=typechecker)
class UniformSourceFrame(InterpedPrior):
    """
    Prior that is uniform in comoving volume and source frame time, analogue to the corresponding bilby prior.
    Uses the default cosmology in fiesta which is Planck18.
    """

    # Overwritten per instance in ``__init__`` with ``dmin``/``dmax``; the class
    # values only serve as fallbacks.
    xmin: float = 10.
    xmax: float = 1e5

    def __repr__(self):
        return f"UniformSourceFrame(xmin={self.xmin}, xmax={self.xmax})"

    def __init__(
        self,
        dmin: Float,
        dmax: Float,
        naming: list[str],
        cosmology = Planck18,
        **kwargs,
    ):
        """
        Args:
            dmin (Float): Minimum luminosity distance in Mpc.
            dmax (Float): Maximum luminosity distance in Mpc.
            naming (list[str]): Parameter name. Must be ['luminosity_distance'].
            cosmology (astropy.cosmology): Astropy cosmology. Defaults to Planck18.

        Raises:
            NotImplementedError: If ``naming`` is not ``['luminosity_distance']``.
        """
        # Validate before the comparatively expensive redshift inversion below.
        # Format the whole list so an empty ``naming`` does not raise IndexError here.
        if naming != ["luminosity_distance"]:
            raise NotImplementedError(f"For now, parameter must be 'luminosity_distance'. {naming} not yet supported.")
        self.dmax = dmax
        self.dmin = dmin
        # Keep the generic bound attributes in sync with the distance range. They were
        # never set before, so __repr__ (and anything reading xmin/xmax as bounds) saw
        # the class defaults 10 and 1e5 instead of the actual range.
        self.xmin = dmin
        self.xmax = dmax
        xx = jnp.linspace(self.dmin, self.dmax, 500)
        redshift_arr = jnp.array(z_at_value(cosmology.luminosity_distance, xx * u.Mpc))
        # d(d_L)/dz on the non-uniform redshift grid.
        ddl_dz = jnp.gradient(xx, redshift_arr)
        # Density in d_L: dV_c/dz / (1 + z) * dz/d(d_L); the 1/(1 + z) factor gives the
        # "uniform in source frame time" weighting described in the class docstring.
        yy = cosmology.differential_comoving_volume(redshift_arr).value / (1 + redshift_arr) * 1 / ddl_dz
        super().__init__(xx, yy, naming)
[docs]
@jaxtyped(typechecker=typechecker)
class Sine(Prior):
    """Prior with density proportional to ``sin(x)`` on ``(xmin, xmax)``, by default ``(0, pi)``."""

    xmin: float = 0.0
    xmax: float = 1.0

    def __repr__(self):
        return f"Sine(xmin={self.xmin}, xmax={self.xmax})"

    def __init__(
        self,
        naming: list[str],
        xmin: Float = 0,
        xmax: Float = jnp.pi,
        transforms: dict[str, tuple[str, Callable]] = {},
    ):
        """
        Parameters
        ----------
        naming : list[str]
            Single-element list holding the parameter name.
        xmin, xmax : Float
            Bounds of the support; ``xmax`` must exceed ``xmin``.
        transforms : dict
            Forwarded to ``Prior.__init__``.
        """
        super().__init__(naming, transforms)
        assert self.n_dim == 1, "Sine needs to be 1D distributions"
        assert xmax > xmin, f"Provided xmax {xmax} is smaller than xmin, needs to be larger than {xmin}."
        self.xmax = xmax
        self.xmin = xmin

    def sample(
        self, rng_key: PRNGKeyArray, n_samples: int
    ) -> dict[str, Float[Array, " n_samples"]]:
        """
        Sample from the sine distribution by inverse-CDF sampling.

        ``cos(x)`` is drawn uniformly between ``cos(xmax)`` and ``cos(xmin)`` and
        mapped back with ``arccos``.

        NOTE(review): ``arccos`` returns values in ``[0, pi]``, so this matches the
        density only when ``[xmin, xmax]`` lies within ``[0, pi]``.

        Parameters
        ----------
        rng_key : PRNGKeyArray
            A random key to use for sampling.
        n_samples : int
            The number of samples to draw.

        Returns
        -------
        samples : dict
            Mapping from the parameter name to the drawn values.
        """
        cos_draws = jax.random.uniform(
            rng_key,
            (n_samples,),
            minval=jnp.cos(self.xmax),
            maxval=jnp.cos(self.xmin),
        )
        angles = jnp.arccos(cos_draws)
        return self.add_name(angles[None])

    def log_prob(self, x: dict[str, Array]) -> Float:
        """Log of ``sin(x) / (cos(xmin) - cos(xmax))`` on the open interval, ``-inf`` elsewhere."""
        value = x[self.naming[0]]
        # Open interval: with float32, sin(jnp.pi) is slightly negative, so evaluating
        # the density exactly at xmax = pi would give log(negative) = NaN.
        outside = (value >= self.xmax) | (value <= self.xmin)
        normalization = jnp.cos(self.xmin) - jnp.cos(self.xmax)
        return jnp.where(
            outside,
            jnp.zeros_like(value) - jnp.inf,
            jnp.log(jnp.sin(value) / normalization),
        )
[docs]
@jaxtyped(typechecker=typechecker)
class LogUniform(Prior):
    """
    Prior uniform in ``log(x)`` on ``[xmin, xmax]`` with ``0 < xmin < xmax``, i.e.
    with density ``1 / (x * (log(xmax) - log(xmin)))``.
    """

    xmin: float = 0.0
    xmax: float = 1.0

    def __repr__(self):
        return f"LogUniform(xmin={self.xmin}, xmax={self.xmax})"

    def __init__(
        self,
        xmin: Float,
        xmax: Float,
        naming: list[str],
        transforms: dict[str, tuple[str, Callable]] = {},
        **kwargs,
    ):
        """
        Parameters
        ----------
        xmin, xmax : Float
            Bounds of the support; requires ``0 < xmin < xmax``.
        naming : list[str]
            Single-element list holding the parameter name.
        transforms : dict
            Forwarded to ``Prior.__init__``.
        **kwargs
            Accepted and ignored.
        """
        super().__init__(naming, transforms)
        assert self.n_dim == 1, "LogUniform needs to be 1D distributions"
        assert xmin > 0, f"Provided xmin {xmin} is not positive, needs to be larger than 0."
        assert xmax > xmin, f"Provided xmax {xmax} is smaller than xmin, needs to be larger than {xmin}."
        self.xmax = xmax
        self.xmin = xmin

    def sample(
        self, rng_key: PRNGKeyArray, n_samples: int
    ) -> dict[str, Float[Array, " n_samples"]]:
        """
        Sample from a log-uniform distribution.

        Parameters
        ----------
        rng_key : PRNGKeyArray
            A random key to use for sampling.
        n_samples : int
            The number of samples to draw.

        Returns
        -------
        samples : dict
            Samples from the distribution. The keys are the names of the parameters.
        """
        samples = jax.random.uniform(
            rng_key, (n_samples,), minval=jnp.log(self.xmin), maxval=jnp.log(self.xmax)
        )
        samples = jnp.exp(samples)
        return self.add_name(samples[None])

    def log_prob(self, x: dict[str, Array]) -> Float:
        """Log-uniform log-density inside ``[xmin, xmax]``, ``-inf`` outside."""
        variable = x[self.naming[0]]
        outside = (variable > self.xmax) | (variable < self.xmin)
        # Only take log(x) on in-support values. Previously ``-inf - log(x)`` yielded
        # NaN (not -inf) for x <= 0. The double ``where`` also avoids NaN gradients
        # coming from the unselected branch.
        safe_variable = jnp.where(outside, self.xmin, variable)
        log_density = jnp.log(1.0 / (jnp.log(self.xmax) - jnp.log(self.xmin))) - jnp.log(safe_variable)
        return jnp.where(outside, jnp.zeros_like(variable) - jnp.inf, log_density)
[docs]
class CompositePrior(Prior):
    """
    Joint prior assembled from a list of component priors.

    ``naming`` is the concatenation of the components' names, ``transforms`` the
    merge of their transforms, and the joint log-density is the sum of the
    components' log-densities (i.e. the components are treated as independent).
    """

    priors: list[Prior] = field(default_factory=list)

    def __repr__(self):
        return f"Composite(priors={self.priors}, naming={self.naming})"

    def __init__(
        self,
        priors: list[Prior],
        transforms: Optional[dict[str, tuple[str, Callable]]] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        priors : list[Prior]
            Component priors; their parameter names are concatenated in order.
        transforms : dict[str, tuple[str, Callable]], optional
            Extra transforms keyed by parameter name; they override the ones
            inherited from the components. ``None`` (the default) adds none.
        **kwargs
            Accepted and ignored.
        """
        # Prior.__init__ is not called; naming and transforms are assembled from
        # the components instead.
        naming = []
        self.transforms = {}
        for prior in priors:
            naming += prior.naming
            self.transforms.update(prior.transforms)
        self.priors = priors
        self.naming = naming
        if transforms is not None:
            self.transforms.update(transforms)

    def sample(
        self, rng_key: PRNGKeyArray, n_samples: int
    ) -> dict[str, Float[Array, " n_samples"]]:
        """
        Sample every component with its own split key and merge the results.

        If two components share a parameter name, the later component's draw wins.
        """
        output = {}
        for prior in self.priors:
            rng_key, subkey = jax.random.split(rng_key)
            output.update(prior.sample(subkey, n_samples))
        return output

    def log_prob(self, x: dict[str, Float]) -> Float:
        """Sum of the components' log-densities at ``x``."""
        output = 0.0
        for prior in self.priors:
            output += prior.log_prob(x)
        return output
[docs]
class Constraint(Prior):
    """
    Unnormalised indicator prior on the first named parameter.

    ``log_prob`` is ``-inf`` when the value lies strictly outside ``[xmin, xmax]``
    and ``0`` otherwise (including NaN inputs, which fail both comparisons).
    ``sample`` is not overridden.
    """

    xmin: float
    xmax: float

    def __init__(self,
                 naming: list[str],
                 xmin: Float,
                 xmax: Float,
                 transforms: dict[str, tuple[str, Callable]] = {}) -> None:
        """
        Parameters
        ----------
        naming : list[str]
            Parameter names; only the first one is checked by ``log_prob``.
        xmin, xmax : Float
            Inclusive bounds of the allowed region.
        transforms : dict
            Forwarded to ``Prior.__init__``.
        """
        super().__init__(naming=naming, transforms=transforms)
        self.xmin, self.xmax = xmin, xmax

    def log_prob(self, x: dict[str, Array]) -> Float:
        """Return ``-inf`` where the constraint is violated and ``0`` elsewhere."""
        value = x[self.naming[0]]
        violated = (value > self.xmax) | (value < self.xmin)
        zero = jnp.zeros_like(value)
        return jnp.where(violated, zero - jnp.inf, zero)