Source code for fiesta.inference.prior.prior_dict

import jax
import jax.numpy as jnp

from typing import Callable
from jaxtyping import Array, Float, PRNGKeyArray, jaxtyped
from .prior import Prior, Constraint, CompositePrior


[docs] class ConstrainedPrior(CompositePrior): priors: CompositePrior constraints: list[Constraint] conversion: Callable factor: Float def __init__(self, priors: list, conversion_function: Callable=lambda x: x, transforms: dict[str, tuple[str, Callable]] = {}): super().__init__([prior for prior in priors if not isinstance(prior, Constraint)]) self.constraints = [constraint for constraint in priors if isinstance(constraint, Constraint)] self.conversion = conversion_function self._estimate_normalization() def _estimate_normalization(self, nrepeats: int = 10, sampling_chunk: int = 50_000): rng_key = jax.random.key(314159265) factor_estimates = [] for _ in range(nrepeats): rng_key, subkey = jax.random.split(rng_key) samples = super().sample(subkey, n_samples = sampling_chunk) constr = ~jnp.isneginf(self.evaluate_constraints(samples)) factor_estimates.append(sampling_chunk/jnp.sum(constr)) factor_estimates = jnp.array(factor_estimates) decimals = min(16, -jnp.floor(jnp.log10(3*jnp.std(factor_estimates)))) decimals = max(0, decimals) self.factor = jnp.round(jnp.mean(factor_estimates), int(decimals))
[docs] def evaluate_constraints(self, samples): converted_sample = self.conversion(samples) log_prob = jnp.zeros_like(samples[self.naming[0]]) for constraint in self.constraints: log_prob+=constraint.log_prob(converted_sample) return log_prob
[docs] def sample( self, rng_key: PRNGKeyArray, n_samples: int ) -> dict[str, Float[Array, "n_samples"]]: rng_key, subkey = jax.random.split(rng_key) samples = super().sample(subkey, n_samples) constr = ~jnp.isneginf(self.evaluate_constraints(samples)) while jnp.any(~constr): # not really jax-y but no idea atm how to do implement this logic better idx = jnp.where(~constr, jnp.arange(n_samples), 0) idx = jnp.unique(idx)# problems with jit here rng_key, subkey = jax.random.split(rng_key) new_samples = super().sample(subkey, idx.shape[0]) new_constr = ~jnp.isneginf(self.evaluate_constraints(new_samples)) def update_arrays(old_arr, new_arr): return old_arr.at[idx].set(new_arr) samples = jax.tree_util.tree_map(update_arrays, samples, new_samples) # update the samples dic by mapping update_arrays function over it constr = constr.at[idx].set(new_constr) for constraint in self.constraints: if constraint.naming[0] in samples.keys(): del samples[constraint.naming[0]] return samples """ def check_constraint(state): _, constr, _ , _ = state return jnp.all(constr) def update_samples(state): samples, constr, rng_key, super = state idx = jnp.where(~constr, jnp.arange(constr.shape[0]), 0) rng_key, subkey = jax.random.split(rng_key) new_samples = super.sample(subkey, jnp.sum(idx!=0)) new_constr = self.evaluate_constraints(new_samples) samples = jax.tree_util.tree_map(update_arrays, samples, new_samples) constr = constr.at[idx].set(new_constr) return samples, constr, rng_key, super rng_key, subkey = jax.random.split(rng_key) init_sample = super().sample(subkey, n_samples) init_constr = ~jnp.isneginf(self.evaluate_constraints(init_sample)) init_state = (init_sample, init_constr, rng_key, super()) final_state = jax.lax.while_loop(check_constraint, update_samples, init_state) return final_state[0] """
[docs] def log_prob(self, x: dict[str, Float]) -> Float: output = self.evaluate_constraints(x) for prior in self.priors: output += prior.log_prob(x) output += jnp.log(self.factor) return output