# Source code for jesterTOV.inference.samplers.blackjax.smc.random_walk

"""SMC with Gaussian Random Walk Metropolis-Hastings kernel."""

from typing import Callable, cast

import jax
import jax.numpy as jnp
from jax import flatten_util
from jaxtyping import Array

from jesterTOV.inference.base import (
    LikelihoodBase,
    Prior,
    BijectiveTransform,
    NtoMTransform,
)
from jesterTOV.inference.config.schema import SMCRandomWalkSamplerConfig
from jesterTOV.inference.samplers.blackjax.smc.base import BlackjaxSMCSampler
from jesterTOV.logging_config import get_logger

from blackjax.mcmc import random_walk
from blackjax.smc import extend_params
from blackjax.smc.tuning.from_particles import particles_covariance_matrix

# Module-level logger named "jester", obtained from the project's logging config.
logger = get_logger("jester")


class BlackJAXSMCRandomWalkSampler(BlackjaxSMCSampler):
    """Tempered SMC sampler whose mutation step is Gaussian random-walk MH.

    At every tempering stage the Gaussian proposal covariance is re-estimated
    from the current particle cloud and multiplied by
    ``config.random_walk_sigma ** 2``. Only the *shape* of the proposal (its
    covariance structure) adapts. The overall scale ``random_walk_sigma`` is a
    fixed configuration value and is never tuned.

    Parameters
    ----------
    likelihood : LikelihoodBase
        Likelihood object.
    prior : Prior
        Prior object.
    sample_transforms : list[BijectiveTransform]
        Sample transforms (typically empty for SMC).
    likelihood_transforms : list[NtoMTransform]
        Likelihood transforms.
    config : SMCRandomWalkSamplerConfig
        Random-walk SMC configuration; its ``random_walk_sigma`` field sets
        the fixed proposal scale.
    seed : int, optional
        Random seed (default: 0).
    """

    def __init__(
        self,
        likelihood: LikelihoodBase,
        prior: Prior,
        sample_transforms: list[BijectiveTransform],
        likelihood_transforms: list[NtoMTransform],
        config: SMCRandomWalkSamplerConfig,
        seed: int = 0,
    ) -> None:
        """Create the sampler; every argument is forwarded to the base class."""
        super().__init__(
            likelihood,
            prior,
            sample_transforms,
            likelihood_transforms,
            config,
            seed,
        )

    def _get_kernel_name(self) -> str:
        """Identifier of the MCMC mutation kernel used by this sampler."""
        return "random_walk"

    def _setup_mcmc_kernel(
        self,
        logprior_fn: Callable,
        loglikelihood_fn: Callable,
        logposterior_fn: Callable,
        initial_particles: Array,
    ) -> tuple[Callable, Callable, dict, Callable]:
        """Build the random-walk MH kernel and its covariance-adaptation hook.

        The proposal covariance is computed as follows:

        1. Take the empirical covariance of the particles.
        2. Promote it to at least 2-D, so a one-dimensional problem still
           yields a ``(1, 1)`` matrix.
        3. Multiply it by ``config.random_walk_sigma ** 2``.

        It is computed once from ``initial_particles`` and then recomputed
        from the current particles by the returned update function.

        Parameters
        ----------
        logprior_fn : Callable
            Not used by the random-walk kernel.
        loglikelihood_fn : Callable
            Not used by the random-walk kernel.
        logposterior_fn : Callable
            Not used by the random-walk kernel.
        initial_particles : Array
            Starting particle positions, used for the initial covariance.

        Returns
        -------
        tuple[Callable, Callable, dict, Callable]
            ``(mcmc_step_fn, mcmc_init_fn, init_params, mcmc_parameter_update_fn)``
        """
        # Static narrowing only: typing.cast performs no runtime conversion.
        rw_config = cast(SMCRandomWalkSamplerConfig, self.config)

        logger.info("Using random walk kernel")
        logger.info(f"Fixed sigma scaling: {rw_config.random_walk_sigma}")

        def scaled_covariance(particles):
            """Empirical particle covariance, promoted to >= 2-D, times sigma**2."""
            # random_walk_sigma is read on every call rather than captured once.
            empirical_cov = jnp.atleast_2d(particles_covariance_matrix(particles))
            return empirical_cov * (rw_config.random_walk_sigma**2)

        # Additive-step random walk: the proposal callable returns an increment
        # that the kernel combines with the current position.
        additive_kernel = random_walk.build_additive_step()

        init_cov = scaled_covariance(initial_particles)
        # NOTE(review): the update function below wraps its result with
        # extend_params, but this initial value is returned unwrapped;
        # presumably the base class accounts for that — confirm.
        init_params = {"cov": init_cov}

        def mcmc_parameter_update_fn(key, state, info):
            """Re-estimate the proposal covariance from the current particles.

            ``key`` and ``info`` are ignored. ``state.particles`` (a tempered
            SMC state) supplies the particle cloud. The new covariance is
            returned under the ``"cov"`` key, wrapped by ``extend_params``.
            """
            refreshed_cov = scaled_covariance(state.particles)
            return extend_params({"cov": refreshed_cov})  # type: ignore[arg-type]

        def mcmc_step_fn(rng_key, state, logdensity_fn, **params):
            """Run one random-walk MH step with a zero-mean Gaussian increment."""
            # Use the initial covariance when no "cov" parameter is supplied.
            step_cov = params.get("cov", init_cov)

            def gaussian_increment(key, position):
                """Sample N(0, step_cov) on the flattened position, then unflatten."""
                flat_position, unflatten = flatten_util.ravel_pytree(position)
                flat_step = jax.random.multivariate_normal(
                    key, jnp.zeros_like(flat_position), step_cov
                )
                return unflatten(flat_step)

            return additive_kernel(rng_key, state, logdensity_fn, gaussian_increment)

        return mcmc_step_fn, random_walk.init, init_params, mcmc_parameter_update_fn