r"""Prior base classes for JESTER inference system.
This module contains prior classes that were originally from Jim (jimgw v0.2.0).
They are copied here to remove the dependency on jimgw.
Note: These classes follow the Jim/jimgw architecture with a dict-based interface
for named parameters. The flowMC Distribution inheritance was removed to avoid
interface conflicts while maintaining compatibility with JESTER's sampling backends.
"""
from dataclasses import field
from typing import Any
import jax
import jax.numpy as jnp
from beartype import beartype as typechecker
from jaxtyping import Array, Float, PRNGKeyArray, jaxtyped
from .transform import (
BijectiveTransform,
LogitTransform,
ScaleTransform,
OffsetTransform,
)
class Prior:
    """Base class for JESTER prior distributions.

    Note: This class follows the Jim/jimgw architecture. Should not be used
    directly since it does not implement any of the real methods. Its purpose
    is to keep track of the names of the parameters and the transforms that
    are applied to them, exposing a dict-based interface (parameter name ->
    value).

    This class previously inherited from flowMC's Distribution; that
    dependency has been removed to avoid interface conflicts.
    """

    parameter_names: list[str]
    composite: bool

    @property
    def n_dim(self) -> int:
        # Dimensionality is defined by how many parameters this prior names.
        return len(self.parameter_names)

    def __init__(self, parameter_names: list[str]) -> None:
        """
        Parameters
        ----------
        parameter_names : list[str]
            A list of names for the parameters of the prior.
        """
        self.parameter_names = parameter_names
        # Subclasses that wrap other priors set this to True.
        self.composite = False

    def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]:
        """
        Turn an array into a dictionary.

        Parameters
        ----------
        x : Array
            An array of parameters. Shape (n_dim,).

        Returns
        -------
        dict[str, Float]
            Dictionary mapping parameter names to values.
        """
        return dict(zip(self.parameter_names, x))

    def sample(
        self, rng_key: PRNGKeyArray, n_samples: int
    ) -> dict[str, Float[Array, " n_samples"]]:
        """
        Sample from the prior distribution.

        Parameters
        ----------
        rng_key : PRNGKeyArray
            A random key to use for sampling.
        n_samples : int
            The number of samples to draw.

        Returns
        -------
        samples : dict[str, Float[Array, " n_samples"]]
            Samples from the distribution. The keys are the names of the
            parameters.
        """
        raise NotImplementedError

    def log_prob(self, z: dict[str, Float | Array]) -> Float:
        """
        Evaluate the log probability of the prior.

        Parameters
        ----------
        z : dict[str, Array]
            Dictionary of parameter names to values.

        Returns
        -------
        log_prob : Float
            The log probability.
        """
        raise NotImplementedError
@jaxtyped(typechecker=typechecker)
class LogisticDistribution(Prior):
    """One-dimensional standard logistic distribution prior.

    Note: This class follows the Jim/jimgw architecture.
    """

    def __repr__(self) -> str:
        return f"LogisticDistribution(parameter_names={self.parameter_names})"

    def __init__(self, parameter_names: list[str], **kwargs: Any) -> None:
        super().__init__(parameter_names)
        self.composite = False
        assert self.n_dim == 1, "LogisticDistribution needs to be 1D distributions"

    def sample(
        self, rng_key: PRNGKeyArray, n_samples: int
    ) -> dict[str, Float[Array, " n_samples"]]:
        """
        Sample from a logistic distribution.

        Parameters
        ----------
        rng_key : PRNGKeyArray
            A random key to use for sampling.
        n_samples : int
            The number of samples to draw.

        Returns
        -------
        samples : dict
            Samples from the distribution. The keys are the names of the
            parameters.
        """
        # Inverse-CDF sampling: logit of a uniform draw is standard logistic.
        u = jax.random.uniform(rng_key, (n_samples,), minval=0.0, maxval=1.0)
        draws = jnp.log(u / (1 - u))
        # draws[None] has shape (1, n_samples); add_name pairs the single
        # parameter name with the full sample vector.
        return self.add_name(draws[None])

    def log_prob(self, z: dict[str, Float]) -> Float:
        """
        Evaluate the log probability.

        Parameters
        ----------
        z : dict[str, Float]
            Dictionary of parameter names to values.

        Returns
        -------
        log_prob : Float
            The log probability.
        """
        x = z[self.parameter_names[0]]
        # log f(x) = -x - 2 log(1 + e^{-x}) for the standard logistic density.
        return -x - 2 * jnp.log(1 + jnp.exp(-x))
class SequentialTransformPrior(Prior):
    """Prior obtained by pushing a base prior through a chain of transforms.

    Note: This class follows the Jim/jimgw architecture.
    The space before the transform is named as x,
    and the space after the transform is named as z.
    """

    base_prior: Prior
    transforms: list[BijectiveTransform]

    def __repr__(self) -> str:
        return f"Sequential(priors={self.base_prior}, parameter_names={self.parameter_names})"

    def __init__(
        self,
        base_prior: Prior,
        transforms: list[BijectiveTransform],
    ) -> None:
        self.base_prior = base_prior
        self.transforms = transforms
        # Propagate the names through each transform in order, so that
        # parameter_names describes the final (z) space.
        names = base_prior.parameter_names
        for t in transforms:
            names = t.propagate_name(names)
        self.parameter_names = names
        self.composite = True

    def sample(
        self, rng_key: PRNGKeyArray, n_samples: int
    ) -> dict[str, Float[Array, " n_samples"]]:
        """
        Sample from the transformed prior.

        Parameters
        ----------
        rng_key : PRNGKeyArray
            A random key to use for sampling.
        n_samples : int
            The number of samples to draw.

        Returns
        -------
        samples : dict
            Samples from the transformed distribution.
        """
        base_draws = self.base_prior.sample(rng_key, n_samples)
        # vmap maps the per-sample transform over the batch dimension of
        # every array in the sample dictionary.
        return jax.vmap(self.transform)(base_draws)

    def log_prob(self, z: dict[str, Float]) -> Float:
        """
        Evaluate the probability of the transformed variable z.
        This is what flowMC should sample from.

        Parameters
        ----------
        z : dict[str, Float]
            Dictionary of parameter names to values in transformed space.

        Returns
        -------
        log_prob : Float
            The log probability including Jacobian correction.
        """
        # Walk the chain backwards, accumulating the log-Jacobian of each
        # inverse, until z lives in the base prior's space.
        log_jacobian = 0
        for t in reversed(self.transforms):
            z, log_det = t.inverse(z)
            log_jacobian += log_det
        return log_jacobian + self.base_prior.log_prob(z)

    def transform(self, x: dict[str, Float]) -> dict[str, Float]:
        """
        Apply forward transforms to x.

        Parameters
        ----------
        x : dict[str, Float]
            Dictionary of parameter names to values.

        Returns
        -------
        z : dict[str, Float]
            Transformed dictionary.
        """
        for t in self.transforms:
            x = t.forward(x)
        return x
class CombinePrior(Prior):
    """
    A prior class constructed by joining multiple priors together to form a
    multivariate prior.

    Note: This class follows the Jim/jimgw architecture.
    This assumes the priors composing the Combine class are independent.
    """

    # Component priors; populated in __init__.
    # NOTE: this was previously declared as
    # ``base_prior: list[Prior] = field(default_factory=list)``, but this
    # class is not a dataclass, so that left a bare ``dataclasses.Field``
    # object as the class attribute instead of a list default. A plain
    # annotation is the correct form here.
    base_prior: list[Prior]

    def __repr__(self) -> str:
        return (
            f"Combine(priors={self.base_prior}, parameter_names={self.parameter_names})"
        )

    def __init__(
        self,
        priors: list[Prior],
    ) -> None:
        """
        Parameters
        ----------
        priors : list[Prior]
            List of independent prior distributions to combine.
        """
        # Concatenate the names of all components, preserving order.
        parameter_names: list[str] = []
        for prior in priors:
            parameter_names += prior.parameter_names
        self.base_prior = priors
        self.parameter_names = parameter_names
        self.composite = True

    def sample(
        self, rng_key: PRNGKeyArray, n_samples: int
    ) -> dict[str, Float[Array, " n_samples"]]:
        """
        Sample from the combined prior by sampling from each component.

        Parameters
        ----------
        rng_key : PRNGKeyArray
            A random key to use for sampling.
        n_samples : int
            The number of samples to draw.

        Returns
        -------
        samples : dict
            Combined samples from all priors.
        """
        output = {}
        for prior in self.base_prior:
            # Split so every component gets an independent subkey.
            rng_key, subkey = jax.random.split(rng_key)
            output.update(prior.sample(subkey, n_samples))
        return output

    def log_prob(self, z: dict[str, Float]) -> Float:
        """
        Evaluate the log probability by summing over independent priors.

        Parameters
        ----------
        z : dict[str, Float]
            Dictionary of parameter names to values.

        Returns
        -------
        log_prob : Float
            The combined log probability.
        """
        # Independence assumption: joint log-density is the sum of the
        # component log-densities.
        output = 0.0
        for prior in self.base_prior:
            output += prior.log_prob(z)
        return output
class Fixed(Prior):
    """A parameter fixed to a constant value, excluded from the sampling space.

    This is not a proper prior distribution — it has no ``log_prob`` or
    ``sample`` implementation. Use it in ``.prior`` files to pin a parameter
    to a specific value while keeping the specification co-located with the
    sampled priors:

    .. code-block:: python

        lambda_BL = Fixed(0.0, parameter_names=["lambda_BL"])

    The parser will extract ``Fixed`` entries into a separate
    ``fixed_params`` dict and will not add them to the ``CombinePrior`` that
    defines the sampling space.

    Parameters
    ----------
    value : float
        The fixed value for the parameter.
    parameter_names : list[str]
        Must contain exactly one parameter name.
    """

    value: float

    def __repr__(self) -> str:
        return f"Fixed(value={self.value}, parameter_names={self.parameter_names})"

    def __init__(self, value: float, parameter_names: list[str]) -> None:
        super().__init__(parameter_names)
        assert self.n_dim == 1, "Fixed must be 1D (one parameter per Fixed instance)"
        # Coerce to a plain Python float so the stored constant is hashable
        # and JSON-friendly.
        self.value = float(value)

    def sample(
        self, rng_key: PRNGKeyArray, n_samples: int
    ) -> dict[str, Float[Array, " n_samples"]]:
        # Deliberately unimplemented: fixed parameters never enter the
        # sampling space.
        raise NotImplementedError(
            "Fixed parameters are not sampled. "
            "They are injected as constants into the transform."
        )

    def log_prob(self, z: dict[str, Float | Array]) -> Float:
        # Deliberately unimplemented: fixed parameters carry no density.
        raise NotImplementedError(
            "Fixed parameters have no log_prob. "
            "They do not contribute to the posterior density."
        )
@jaxtyped(typechecker=typechecker)
class MultivariateGaussianPrior(Prior):
    r"""Multivariate Gaussian prior :math:`\mathcal{N}(\mu, \Sigma)`.

    By default this is the standard multivariate normal :math:`\mathcal{N}(0, I_d)`,
    i.e. independent standard normals for each dimension. Arbitrary mean and
    covariance can be supplied for general use.

    Attributes
    ----------
    mean : Float[Array, " n_dim"]
        Mean vector.
    cov : Float[Array, "n_dim n_dim"]
        Covariance matrix (must be positive definite).
    """

    mean: Float[Array, " n_dim"]
    cov: Float[Array, "n_dim n_dim"]

    def __repr__(self) -> str:
        return (
            f"MultivariateGaussianPrior(n_dim={self.n_dim}, "
            f"parameter_names={self.parameter_names})"
        )

    def __init__(
        self,
        parameter_names: list[str],
        mean: Float[Array, " n_dim"] | None = None,
        cov: Float[Array, "n_dim n_dim"] | None = None,
    ) -> None:
        r"""
        Parameters
        ----------
        parameter_names : list[str]
            Names for each dimension.
        mean : array-like, optional
            Mean vector of length ``len(parameter_names)``. Defaults to zeros.
        cov : array-like, optional
            Covariance matrix of shape ``(n_dim, n_dim)``. Defaults to identity.

        Raises
        ------
        ValueError
            If ``mean`` or ``cov`` does not match the dimensionality implied
            by ``parameter_names``.
        """
        super().__init__(parameter_names)
        d = self.n_dim
        self.mean = jnp.zeros(d) if mean is None else jnp.asarray(mean, dtype=float)
        self.cov = jnp.eye(d) if cov is None else jnp.asarray(cov, dtype=float)
        # Fail fast on mis-sized inputs instead of silently broadcasting or
        # producing NaNs downstream.
        if self.mean.shape != (d,):
            raise ValueError(
                f"mean must have shape ({d},), got {self.mean.shape}"
            )
        if self.cov.shape != (d, d):
            raise ValueError(
                f"cov must have shape ({d}, {d}), got {self.cov.shape}"
            )
        # The covariance is fixed at construction, so factor it once here
        # rather than on every sample() call. NOTE: jnp.linalg.cholesky does
        # not raise on a non-positive-definite matrix — it yields NaNs.
        self._chol = jnp.linalg.cholesky(self.cov)
        self.composite = False

    def sample(
        self, rng_key: PRNGKeyArray, n_samples: int
    ) -> dict[str, Float[Array, " n_samples"]]:
        r"""Sample from :math:`\mathcal{N}(\mu, \Sigma)`.

        Parameters
        ----------
        rng_key : PRNGKeyArray
            JAX random key.
        n_samples : int
            Number of samples to draw.

        Returns
        -------
        dict[str, Float[Array, " n_samples"]]
            Samples keyed by parameter name.
        """
        # Draw standard normals then apply affine transform: x = mu + L z
        z = jax.random.normal(rng_key, shape=(n_samples, self.n_dim))
        samples = self.mean + z @ self._chol.T  # (n_samples, n_dim)
        return {name: samples[:, i] for i, name in enumerate(self.parameter_names)}

    def log_prob(self, z: dict[str, Float | Array]) -> Float:
        r"""Evaluate :math:`\log \mathcal{N}(z \mid \mu, \Sigma)`.

        Parameters
        ----------
        z : dict[str, Float]
            Dictionary mapping parameter names to scalar values.

        Returns
        -------
        Float
            Log probability (scalar).
        """
        x = jnp.array([z[name] for name in self.parameter_names])
        return jax.scipy.stats.multivariate_normal.logpdf(x, self.mean, self.cov)