# Source code for jesterTOV.inference.samplers.blackjax.nested_sampling.acceptance_walk

"""BlackJAX Nested Sampling with acceptance walk kernel.

This module provides nested sampling using the BlackJAX library (handley-lab fork)
with acceptance walk kernel for efficient exploration of the parameter space.

# FIXME: this is still being tested, use with care!
"""

from typing import Any
import time

import jax
import jax.numpy as jnp
import jax.random
from jax.experimental import io_callback
from jaxtyping import Array, PRNGKeyArray

from blackjax.ns.base import StateWithLogLikelihood

from jesterTOV.inference.base import (
    LikelihoodBase,
    Prior,
    BijectiveTransform,
    NtoMTransform,
)
from jesterTOV.inference.config.schema import BlackJAXNSAWConfig
from jesterTOV.inference.samplers.jester_sampler import SamplerOutput
from jesterTOV.inference.samplers.blackjax.base import BlackjaxSampler
from jesterTOV.logging_config import get_logger

logger = get_logger("jester")


class BlackJAXNSAWSampler(BlackjaxSampler):
    """BlackJAX Nested Sampling with acceptance walk kernel.

    This sampler implements nested sampling for Bayesian evidence calculation
    and posterior sampling. It uses unit cube transforms (all parameters mapped
    to [0, 1]) and the acceptance walk kernel for MCMC proposals.

    Unlike SMC samplers, NS-AW works directly with dict-based functions from
    BlackjaxSampler parent class (no flattening needed).

    Parameters
    ----------
    likelihood : LikelihoodBase
        Likelihood object with evaluate(params, data) method
    prior : Prior
        Prior object (CombinePrior of UniformPrior and/or MultivariateGaussianPrior)
    sample_transforms : list[BijectiveTransform]
        Unit cube transforms (created by transform_factory)
    likelihood_transforms : list[NtoMTransform]
        N-to-M transforms applied before likelihood evaluation
    config : BlackJAXNSAWConfig
        Nested sampling configuration
    seed : int, optional
        Random seed (default: 0)

    Attributes
    ----------
    config : BlackJAXNSAWConfig
        Sampler configuration
    final_state : Any | None
        Final nested sampling state (after sampling)
    metadata : dict
        Sampling metadata (evidence, time, etc.)
    _logprior_fn : callable
        Pre-compiled log prior function (unit cube → prior space)
    _loglikelihood_fn : callable
        Pre-compiled log likelihood function (unit cube → likelihood)

    Notes
    -----
    Requires BoundToBound [0,1] transforms for all parameters (created
    automatically by transform_factory for nested sampling).
    """

    config: BlackJAXNSAWConfig
    final_state: Any | None  # NSInfo-like result after sample(); None before
    metadata: dict  # evidence, timing and configuration summary
    _logprior_fn: Any  # Compiled JAX function
    _loglikelihood_fn: Any  # Compiled JAX function
    _unit_cube_stepper: Any  # Unit cube stepper function
    _filtered_samples_cache: dict | None  # Cache for filtered samples from anesthetic
    def __init__(
        self,
        likelihood: LikelihoodBase,
        prior: Prior,
        sample_transforms: list[BijectiveTransform],
        likelihood_transforms: list[NtoMTransform],
        config: BlackJAXNSAWConfig,
        seed: int = 0,
    ) -> None:
        """Initialize BlackJAX nested sampling sampler.

        Stores configuration, builds unit-cube transforms when none are
        provided, pre-compiles the dict-based log prior / log likelihood
        functions, and verifies the handley-lab BlackJAX fork is installed.
        """
        super().__init__(likelihood, prior, sample_transforms, likelihood_transforms)
        self.config = config
        self.final_state = None
        self.metadata = {}
        self._seed = seed
        self._filtered_samples_cache = None

        # Nested sampling requires unit cube transforms.
        # If not provided, create them automatically.
        if len(sample_transforms) == 0:
            logger.info(
                "No sample transforms provided - creating unit cube transforms for NS-AW"
            )
            sample_transforms = self._create_unit_cube_transforms(prior)

            # Update the sample_transforms in the parent class
            self.sample_transforms = sample_transforms

            # Recompute parameter names after adding transforms
            # NOTE(review): indentation reconstructed from a mangled scrape —
            # this propagation is assumed to run only when transforms were
            # auto-created; verify against the original file.
            for transform in self.sample_transforms:
                self.parameter_names = transform.propagate_name(self.parameter_names)

        logger.info("Initializing BlackJAX Nested Sampling (Acceptance Walk) sampler")
        logger.info(
            f"Configuration: {config.n_live} live points, "
            f"delete fraction {config.n_delete_frac}"
        )
        logger.info(f"Termination: dlogZ < {config.termination_dlogz}")

        # Use parent class methods to create log prior and log likelihood functions.
        # These work with dicts directly (no flattening needed for NS-AW).
        self._logprior_fn = self._create_logprior_fn_from_dict()
        self._loglikelihood_fn = self._create_loglikelihood_fn_from_dict()

        # Create unit cube stepper function (wraps at [0, 1] boundaries)
        self._create_unit_cube_stepper()

        # Import BlackJAX nested sampling (lazy import to avoid dependency issues)
        try:
            from blackjax.ns.utils import finalise

            self._finalise = finalise
        except ImportError as e:
            raise ImportError(
                "BlackJAX nested sampling not found. Install with: "
                "pip install git+https://github.com/handley-lab/blackjax@nested_sampling"
            ) from e
    # Note: The actual nested sampler is created in sample() method
    # since it requires the acceptance walk kernel implementation

    def _create_unit_cube_transforms(self, prior: Prior) -> list[BijectiveTransform]:
        """Create unit-cube transforms for all prior parameters.

        This is required for BlackJAX nested sampling with acceptance walk,
        which samples in unit cube space and applies the inverse transform
        to evaluate in prior space.

        - ``UniformPrior`` components → single ``BoundToBound`` mapping
          :math:`[a, b] \\to [0, 1]`.
        - ``MultivariateGaussianPrior`` components → ``MVGaussianToUnitCube``
          mapping :math:`\\mathcal{N}(\\mu, \\Sigma) \\to [0, 1]^n` via the
          probability integral transform.

        Parameters
        ----------
        prior : Prior
            Prior distribution (CombinePrior of UniformPrior and/or
            MultivariateGaussianPrior components).

        Returns
        -------
        list[BijectiveTransform]
            Transforms mapping all parameters to [0, 1].

        Raises
        ------
        ValueError
            If prior contains unsupported component types.
        """
        # Local imports keep module import time low and avoid circular imports.
        from jesterTOV.inference.base.prior import (
            UniformPrior,
            CombinePrior,
            MultivariateGaussianPrior,
        )
        from jesterTOV.inference.base.transform import (
            BoundToBound,
            MVGaussianToUnitCube,
        )

        # Handle both single UniformPrior, MultivariateGaussianPrior, and CombinePrior
        # by normalizing everything to a CombinePrior first.
        if isinstance(prior, UniformPrior):
            prior = CombinePrior([prior])
        elif isinstance(prior, MultivariateGaussianPrior):
            prior = CombinePrior([prior])
        elif not isinstance(prior, CombinePrior):
            raise ValueError(
                f"BlackJAX NS-AW requires UniformPrior, MultivariateGaussianPrior, or CombinePrior, "
                f"got {type(prior).__name__}. "
                "Ensure your prior is a (combination of) UniformPrior / MultivariateGaussianPrior."
            )

        transforms: list[BijectiveTransform] = []
        # Uniform components are accumulated so they can share one BoundToBound.
        uniform_names: list[str] = []
        uniform_lower: dict[str, float] = {}
        uniform_upper: dict[str, float] = {}

        for component_prior in prior.base_prior:
            if isinstance(component_prior, UniformPrior):
                # UniformPrior is assumed one-dimensional here (single name).
                param_name = component_prior.parameter_names[0]
                uniform_names.append(param_name)
                uniform_lower[param_name] = component_prior.xmin
                uniform_upper[param_name] = component_prior.xmax
            elif isinstance(component_prior, MultivariateGaussianPrior):
                param_names = component_prior.parameter_names
                transforms.append(
                    MVGaussianToUnitCube(
                        name_mapping=(param_names, param_names),
                        mean=component_prior.mean,
                        cov=component_prior.cov,
                    )
                )
            else:
                error_msg = (
                    f"BlackJAX NS-AW does not support prior component type "
                    f"{type(component_prior).__name__} "
                    f"(parameter: {component_prior.parameter_names}).\n"
                    "Supported types: UniformPrior, MultivariateGaussianPrior."
                )
                if isinstance(component_prior, CombinePrior):
                    error_msg += (
                        "\nHint: Nested CombinePrior detected. Flatten it with "
                        "CombinePrior(prior1.base_prior + prior2.base_prior)."
                    )
                raise ValueError(error_msg)

        # Collect all uniform parameters into a single BoundToBound transform
        if uniform_names:
            target_lower = {name: 0.0 for name in uniform_names}
            target_upper = {name: 1.0 for name in uniform_names}
            transforms.append(
                BoundToBound(
                    name_mapping=(uniform_names, uniform_names),
                    original_lower_bound=uniform_lower,
                    original_upper_bound=uniform_upper,
                    target_lower_bound=target_lower,
                    target_upper_bound=target_upper,
                )
            )

        return transforms

    def _create_unit_cube_stepper(self) -> None:
        """Create stepper function that wraps parameters at [0, 1] boundaries.

        For JESTER, all parameters are bounded but not periodic, so we use
        modulo wrapping for all parameters to keep them in [0, 1].

        Side effect: stores the stepper on ``self._unit_cube_stepper``.
        """

        def unit_cube_stepper(
            position: dict, direction: dict, step_size: float
        ) -> dict:
            """Take a step in unit-cube space, wrapping via modulo at [0, 1]."""
            proposed = jax.tree.map(
                lambda pos, d: pos + step_size * d,
                position,
                direction,
            )
            # Wrap all parameters to [0, 1] using modulo
            return jax.tree.map(lambda prop: jnp.mod(prop, 1.0), proposed)

        self._unit_cube_stepper = unit_cube_stepper
    def sample(self, key: PRNGKeyArray) -> None:
        """Run nested sampling until termination criterion.

        Parameters
        ----------
        key : PRNGKeyArray
            JAX random key

        Notes
        -----
        Initial live points are sampled from the prior and transformed to
        unit cube space internally.  On completion this method populates
        ``self.final_state`` (with particles transformed back to physical
        prior space) and ``self.metadata``.
        """
        logger.info("Starting nested sampling...")
        start_time = time.time()

        # Import acceptance walk sampler from kernels
        from .kernels import bilby_adaptive_de_sampler_unit_cube

        # Configure sampler: delete a fixed fraction of live points per iteration.
        n_delete = int(self.config.n_live * self.config.n_delete_frac)
        logger.info(f"Sampling {self.config.n_live} live points, batch size {n_delete}")

        # Sample initial positions from prior
        key, subkey = jax.random.split(key)
        initial_particles = self.prior.sample(subkey, self.config.n_live)

        # Transform to unit cube (forward direction: prior space -> [0, 1]^n)
        for transform in self.sample_transforms:
            initial_particles = jax.vmap(transform.forward)(initial_particles)

        # Initialize nested sampler with acceptance walk kernel
        nested_sampler = bilby_adaptive_de_sampler_unit_cube(
            logprior_fn=self._logprior_fn,
            loglikelihood_fn=self._loglikelihood_fn,
            nlive=self.config.n_live,
            n_target=self.config.n_target,
            max_mcmc=self.config.max_mcmc,
            num_delete=n_delete,
            stepper_fn=self._unit_cube_stepper,
            max_proposals=self.config.max_proposals,
        )

        # Initialize sampler state
        key, init_key = jax.random.split(key)
        state = nested_sampler.init(initial_particles, rng_key=init_key)

        def terminate(state):
            """Termination condition: stop when remaining evidence is small."""
            # AdaptiveNSState stores evidence info in integrator
            # dlogz = log(1 + Z_live/Z): estimated evidence left in live points.
            dlogz = jnp.logaddexp(0, state.integrator.logZ_live - state.integrator.logZ)
            # NOTE(review): Python `and` on JAX arrays relies on eager __bool__;
            # fine outside jit (as here), but would break under tracing.
            # At initialization logZ is -inf so dlogz is inf -> loop keeps running.
            return jnp.isfinite(dlogz) and dlogz < self.config.termination_dlogz

        # JIT compile step function for performance
        step_fn = jax.jit(nested_sampler.step)

        # Progress callback for live updates during sampling
        def progress_callback(iteration: int, logZ: float, dlogZ: float) -> None:
            """Print progress update during nested sampling (called via io_callback)."""
            # Format logZ and dlogZ with appropriate precision
            logZ_str = f"{logZ:+10.2f}" if jnp.isfinite(logZ) else " -inf"
            dlogZ_str = f"{dlogZ:8.4f}" if jnp.isfinite(dlogZ) else " inf"
            # Print update
            logger.info(
                f"Iteration {iteration:4d} | logZ={logZ_str} | dlogZ={dlogZ_str}"
            )

        # Run nested sampling loop (eager Python loop; only the step is jitted)
        logger.info("=" * 70)
        logger.info("STARTING NESTED SAMPLING")
        logger.info("=" * 70)
        logger.info(f"Live points: {self.config.n_live}")
        logger.info(
            f"Delete fraction: {self.config.n_delete_frac} ({n_delete} points per iteration)"
        )
        logger.info(f"Termination: dlogZ < {self.config.termination_dlogz}")
        logger.info(f"Max MCMC steps: {self.config.max_mcmc}")
        logger.info("Progress updates will be shown after each iteration")
        logger.info("=" * 70)

        dead = []  # per-iteration dead-point batches, consumed by finalise()
        n_iterations = 0
        while not terminate(state):
            key, subkey = jax.random.split(key)
            state, dead_info = step_fn(subkey, state)
            dead.append(dead_info)
            n_iterations += 1

            # Compute current evidence and termination criterion
            current_logZ = float(state.integrator.logZ)
            current_dlogZ = float(
                jnp.logaddexp(0, state.integrator.logZ_live - state.integrator.logZ)
            )

            # Print progress update using io_callback
            # NOTE(review): io_callback is not strictly required in this eager
            # loop (a direct call would also work); kept as-is.
            io_callback(
                progress_callback,
                None,  # No return value
                n_iterations,
                current_logZ,
                current_dlogZ,
            )

        # Store evidence from state before finalization
        # (AdaptiveNSState stores evidence in integrator - type narrowing not supported)
        integrator = state.integrator  # type: ignore[attr-defined]
        logZ = float(integrator.logZ)
        # Estimate uncertainty from remaining evidence in live points
        logZ_err = float(jnp.logaddexp(0, integrator.logZ_live - integrator.logZ))

        # Finalize nested sampling results
        logger.info("Finalizing nested sampling results...")
        final_info = self._finalise(state, dead)  # type: ignore[arg-type]

        # Transform particles back to prior space
        # Note: final_info.particles is StateWithLogLikelihood, we need just the position dict
        logger.info("Transforming samples back to prior space...")
        physical_particles = final_info.particles.position
        for transform in reversed(self.sample_transforms):
            # Type note: vmap preserves PyTree structure; physical_particles remains ArrayTree
            physical_particles = jax.vmap(transform.backward)(physical_particles)  # type: ignore[arg-type]

        # Store final info with physical parameters
        # Create new StateWithLogLikelihood with transformed positions
        transformed_particles = StateWithLogLikelihood(
            position=physical_particles,
            logdensity=final_info.particles.logdensity,
            loglikelihood=final_info.particles.loglikelihood,
            loglikelihood_birth=final_info.particles.loglikelihood_birth,
        )
        self.final_state = final_info._replace(particles=transformed_particles)

        # Store metadata
        end_time = time.time()
        sampling_time = end_time - start_time

        # Get number of samples from pytree (particles is a dict, not array)
        particles_leaves = jax.tree_util.tree_leaves(final_info.particles)
        n_samples = int(particles_leaves[0].shape[0]) if particles_leaves else 0

        self.metadata = {
            "sampler": "blackjax_ns_aw",
            "n_live": self.config.n_live,
            "n_delete": n_delete,
            "n_delete_frac": self.config.n_delete_frac,
            "n_target": self.config.n_target,
            "max_mcmc": self.config.max_mcmc,
            "max_proposals": self.config.max_proposals,
            "termination_dlogz": self.config.termination_dlogz,
            "sampling_time_seconds": sampling_time,
            "sampling_time_minutes": sampling_time / 60,
            "n_iterations": n_iterations,
            "n_samples": n_samples,
            "n_likelihood_evaluations": int(jnp.sum(final_info.update_info.n_likelihood_evals)),  # type: ignore[attr-defined]
            "logZ": logZ,
            "logZ_err": logZ_err,
        }

        logger.info("=" * 70)
        logger.info("NESTED SAMPLING COMPLETE")
        logger.info("=" * 70)
        logger.info(f"Total iterations: {n_iterations}")
        logger.info(f"Dead points generated: {len(dead) * n_delete}")
        logger.info(f"Final evidence: log(Z) = {logZ:.2f} ± {logZ_err:.2f}")
        logger.info(
            f"Final dlogZ: {logZ_err:.4f} (termination criterion: {self.config.termination_dlogz})"
        )
        logger.info(
            f"Sampling time: {(sampling_time)//60:.0f} minutes {(sampling_time)%60:.1f} seconds"
        )
        logger.info(f"Likelihood evaluations: {int(jnp.sum(final_info.update_info.n_likelihood_evals))}")  # type: ignore[attr-defined]
        logger.info("=" * 70)
[docs] def print_summary(self, transform: bool = True) -> None: """Print summary of nested sampling run. Parameters ---------- transform : bool, optional Not used for nested sampling (always returns physical parameters) """ logger.info("=" * 70) logger.info("NESTED SAMPLING SUMMARY") logger.info("=" * 70) if self.final_state is None: logger.warning("No samples yet - run sample() first") return # Print evidence if "logZ" in self.metadata: logger.info( f"log(Z) = {self.metadata['logZ']:.2f} ± {self.metadata['logZ_err']:.2f}" ) # Print sampling info logger.info(f"Live points: {self.config.n_live}") logger.info( f"Sampling time: {self.metadata.get('sampling_time_seconds', 0):.1f}s" ) if "n_samples" in self.metadata: logger.info(f"Posterior samples: {self.metadata['n_samples']}")
    def get_samples(self) -> dict:
        """Return unweighted posterior samples from nested sampling.

        This method computes importance weights using anesthetic, then
        resamples to produce approximately ESS (effective sample size)
        unweighted posterior samples. This ensures downstream analysis
        (plotting, postprocessing) treats all samples as equally weighted,
        which is the expected behavior.

        Returns
        -------
        dict
            Dictionary with:
            - Parameter samples (resampled, unweighted)
            - 'logL': log likelihood values (resampled)
            - 'logL_birth': birth log likelihoods (resampled)

        Notes
        -----
        The original weighted samples are cached in _filtered_samples_cache
        for advanced users who need access to the full weighted set.
        """
        if self.final_state is None:
            raise RuntimeError("No samples available - run sample() first")

        # Handle birth likelihoods (replace NaN with -inf)
        logL_birth = self.final_state.particles.loglikelihood_birth.copy()
        logL_birth = jnp.where(jnp.isnan(logL_birth), -jnp.inf, logL_birth)

        # Compute importance weights using anesthetic (if available)
        # Note: anesthetic may drop invalid samples (logL <= logL_birth)
        try:
            from anesthetic.samples import NestedSamples
            import warnings

            # Note: self.final_state is NSInfo, which has particles in physical (prior) space
            # Suppress the logL <= logL_birth warning (it's handled internally by anesthetic)
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore", "out of .* samples have logL <= logL_birth"
                )
                ns_samples = NestedSamples(
                    self.final_state.particles.position,
                    logL=self.final_state.particles.loglikelihood,
                    logL_birth=logL_birth,
                    logzero=jnp.nan,
                    dtype=jnp.float64,
                )

            # Get posterior weights (anesthetic computes from logL and logL_birth)
            # NOTE: anesthetic may drop invalid samples, so we need to use the filtered data
            weights = ns_samples.get_weights()

            # Extract filtered samples from anesthetic (it's a DataFrame)
            # Get all columns except metadata columns
            param_cols = [
                col
                for col in ns_samples.columns
                if col not in ["logL", "logL_birth", "weights"]
            ]
            samples = {col: jnp.array(ns_samples[col].values) for col in param_cols}

            # Add metadata
            samples["weights"] = jnp.array(weights)
            samples["logL"] = jnp.array(ns_samples["logL"].values)
            samples["logL_birth"] = jnp.array(ns_samples["logL_birth"].values)

            # Cache filtered samples BEFORE resampling (for get_log_prob() to use if needed)
            self._filtered_samples_cache = samples.copy()

            # Resample weighted samples to produce unweighted posterior samples
            # This is critical for downstream analysis that assumes equal weights
            logger.info(
                "Resampling weighted NS samples to produce unweighted posterior..."
            )

            # Compute effective sample size (Kish ESS: (sum w)^2 / sum w^2)
            weights_array = samples["weights"]
            ess = jnp.sum(weights_array) ** 2 / jnp.sum(weights_array**2)
            logger.info(
                f"Effective sample size: {ess:.1f} / {len(weights_array)} raw samples"
            )

            # Number of samples to draw: use ESS as target
            # Round to nearest integer, minimum 100 samples
            n_resample = max(100, int(jnp.round(ess)))
            logger.info(f"Resampling to {n_resample} unweighted posterior samples...")

            # Normalize weights for sampling
            normalized_weights = weights_array / jnp.sum(weights_array)

            # Resample with replacement using weighted sampling
            key = jax.random.PRNGKey(
                self._seed + 1000
            )  # Offset seed for reproducibility
            indices = jax.random.choice(
                key,
                len(weights_array),
                shape=(n_resample,),
                replace=True,
                p=normalized_weights,
            )

            # Create resampled samples dict
            resampled_samples = {}
            for key_name, value in samples.items():
                if key_name == "weights":
                    # All samples now have equal weight
                    continue
                elif key_name in ["logL", "logL_birth"]:
                    # Keep metadata
                    resampled_samples[key_name] = value[indices]
                else:
                    # Resample parameter arrays
                    resampled_samples[key_name] = value[indices]

            # Replace samples with resampled version
            samples = resampled_samples
            logger.info(
                f"Resampling complete: {len(samples[list(samples.keys())[0]])} unweighted samples"
            )

            # Update cache to match resampled data (for get_log_prob() consistency)
            self._filtered_samples_cache = samples.copy()

            # Store evidence from anesthetic computation (more accurate than our estimate)
            try:
                # Note: logZ() returns the evidence value; std() is accessed on the samples
                logZ_result = ns_samples.logZ()
                logZ_anesthetic = float(logZ_result)  # type: ignore[arg-type]

                # Get standard deviation from the logZ samples
                # Note: anesthetic stores logZ values in the samples dataframe
                # NOTE(review): ns_samples.logZ() with no argument returns a
                # scalar, so .std() likely raises and lands in the except
                # below — verify; logZ(nsamples) would return samples instead.
                logZ_err_anesthetic = float(ns_samples.logZ().std())  # type: ignore[union-attr]

                # Only set if both succeed
                self.metadata["logZ_anesthetic"] = logZ_anesthetic
                self.metadata["logZ_err_anesthetic"] = logZ_err_anesthetic
            except Exception as e:
                logger.warning(f"Could not compute anesthetic evidence: {e}")

        except ImportError:
            logger.warning(
                "anesthetic not available - using all samples without resampling"
            )
            # Use all samples without filtering or resampling
            samples = dict(self.final_state.particles.position)
            samples["logL"] = self.final_state.particles.loglikelihood
            samples["logL_birth"] = logL_birth
            self._filtered_samples_cache = samples
            logger.warning(
                "Without anesthetic, cannot compute proper weights. "
                "All samples treated equally (may include low-weight samples)."
            )
        except Exception as e:
            logger.warning(
                f"anesthetic weight computation failed: {e} - using all samples without resampling"
            )
            # Use all samples without filtering or resampling
            samples = dict(self.final_state.particles.position)
            samples["logL"] = self.final_state.particles.loglikelihood
            samples["logL_birth"] = logL_birth
            self._filtered_samples_cache = samples
            logger.warning(
                "Without proper weights, all samples treated equally (may include low-weight samples)."
            )

        return samples
[docs] def get_log_prob(self) -> Array: """Get log likelihoods from nested sampling. Returns ------- Array Log likelihood values (1D array) Note: For NS, this is log likelihood, not log posterior. Use weights separately for posterior inference. Notes ----- This method returns filtered log likelihoods (matching get_samples()). If anesthetic has dropped invalid samples, the length will be less than the raw NSInfo.loglikelihood array. """ if self.final_state is None: raise RuntimeError("No samples available - run sample() first") # Use cached filtered samples if available (from get_samples()) # This ensures get_log_prob() and get_samples() have consistent lengths if self._filtered_samples_cache is not None: return self._filtered_samples_cache["logL"] # Fallback: return all log likelihoods (unfiltered) # This shouldn't happen in normal usage since get_samples() is called first logger.warning( "get_log_prob() called before get_samples() - returning unfiltered logL" ) return self.final_state.particles.loglikelihood
[docs] def get_n_samples(self) -> int: """Get number of posterior samples from nested sampling. Returns ------- int Number of posterior samples Notes ----- This method returns the number of filtered samples (matching get_samples()). If anesthetic has dropped invalid samples, the count will be less than the raw NSInfo particle count. """ if self.final_state is None: return 0 # Use cached filtered samples if available (from get_samples()) # This ensures get_n_samples() matches get_samples() and get_log_prob() if self._filtered_samples_cache is not None: # Get length from any parameter array first_param = next(iter(self._filtered_samples_cache.keys())) return len(self._filtered_samples_cache[first_param]) # Fallback: return all particles (unfiltered) # Get length from parameter array in position dict n_samples = len(next(iter(self.final_state.particles.position.values()))) return n_samples
[docs] def get_sampler_output(self) -> SamplerOutput: """ Get standardized sampler output. Returns ------- SamplerOutput - samples: Unweighted parameter samples (dict of arrays, no metadata fields) - log_prob: Log likelihood (NOT log posterior - NS works in likelihood space) - metadata: {"logL": Array, "logL_birth": Array} Raises ------ RuntimeError If sampling has not been run yet. Notes ----- Samples are resampled using importance weights to produce unweighted posterior samples. This ensures downstream analysis treats all samples equally. log_prob contains log likelihood, not log posterior (standard for NS). """ if self.final_state is None: raise RuntimeError("No samples available. Run sample() first.") # Get current samples dict (includes weights, logL, logL_birth) all_data = self.get_samples() # Separate parameters from metadata samples: dict[str, Array] = {} metadata: dict[str, Any] = {} metadata_keys = {"weights", "logL", "logL_birth"} for key, value in all_data.items(): if key in metadata_keys: metadata[key] = value else: samples[key] = value # Get log probabilities (log likelihood for NS-AW) log_prob = self.get_log_prob() return SamplerOutput( samples=samples, log_prob=log_prob, metadata=metadata, )