# Source code for jesterTOV.inference.samplers.blackjax.smc.base

"""Base class for BlackJAX Sequential Monte Carlo (SMC) samplers.

This module provides BlackjaxSMCSampler, which implements shared SMC functionality
(adaptive tempering, particle management, result handling) and delegates
only the kernel-specific parts to subclasses.
"""

from abc import abstractmethod
from typing import Any, Callable, cast
import time
from pathlib import Path
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import jax.random
from jax import flatten_util
from jax.tree_util import tree_map
from jax.experimental import io_callback
from jaxtyping import Array, PRNGKeyArray

from jesterTOV.inference.base import (
    LikelihoodBase,
    Prior,
    BijectiveTransform,
    NtoMTransform,
)
from jesterTOV.inference.config.schema import (
    SMCRandomWalkSamplerConfig,
    SMCNUTSSamplerConfig,
)
from jesterTOV.inference.samplers.jester_sampler import SamplerOutput
from jesterTOV.inference.samplers.blackjax.base import BlackjaxSampler
from jesterTOV.logging_config import get_logger

from blackjax import inner_kernel_tuning, adaptive_tempered_smc
from blackjax.smc import extend_params
from blackjax.smc.base import SMCInfo
from blackjax.smc.inner_kernel_tuning import StateWithParameterOverride
from blackjax.smc.resampling import systematic
from blackjax.smc.tempered import TemperedSMCState

logger = get_logger("jester")


class BlackjaxSMCSampler(BlackjaxSampler):
    """Abstract base for BlackJAX Sequential Monte Carlo with adaptive tempering.

    Implements the shared SMC machinery — the adaptive tempering loop,
    particle bookkeeping, and result handling — while deferring the
    kernel-specific pieces to subclasses.

    Compared with the parent ``BlackjaxSampler`` this class:

    - adds flatten/unflatten utilities (BlackJAX SMC operates on flat arrays),
    - wraps the parent's dict-based log-density functions for the flat API,
    - implements the full SMC sampling loop.

    Subclasses must provide:

    - ``_setup_mcmc_kernel()``: returns ``(mcmc_step_fn, mcmc_init_fn,
      init_params, mcmc_parameter_update_fn)``
    - ``_get_kernel_name()``: string name used for logging/plotting

    Parameters
    ----------
    likelihood : LikelihoodBase
        Likelihood object with ``evaluate(params, data)`` method.
    prior : Prior
        Prior object.
    sample_transforms : list[BijectiveTransform]
        Should be empty for SMC (SMC works in prior space).
    likelihood_transforms : list[NtoMTransform]
        N-to-M transforms applied before likelihood evaluation.
    config : SMCRandomWalkSamplerConfig | SMCNUTSSamplerConfig
        SMC configuration.
    seed : int, optional
        Random seed (default: 0).

    Attributes
    ----------
    config : SMCRandomWalkSamplerConfig | SMCNUTSSamplerConfig
        Sampler configuration.
    final_state : Any | None
        Final SMC state (populated after ``sample()``).
    metadata : dict
        Sampling metadata (ESS, timing, evidence, tempering histories).
    _unflatten_fn : callable
        Maps a flat array back to a parameter dict.
    _flatten_fn : callable
        Maps a parameter dict to a flat array.
    _particles_flat : Array
        Final particle positions as flat arrays.
    _weights : Array
        Final particle weights.
    """

    # Class-level annotations for attributes assigned in __init__ / sample().
    config: SMCRandomWalkSamplerConfig | SMCNUTSSamplerConfig
    final_state: Any | None
    metadata: dict
    _unflatten_fn: Any  # Callable[[Array], dict]
    _flatten_fn: Any  # Callable[[dict], Array]
    _particles_flat: Array | None
    _weights: Array | None
[docs] def __init__( self, likelihood: LikelihoodBase, prior: Prior, sample_transforms: list[BijectiveTransform], likelihood_transforms: list[NtoMTransform], config: SMCRandomWalkSamplerConfig | SMCNUTSSamplerConfig, seed: int = 0, ) -> None: """Initialize BlackJAX SMC sampler.""" super().__init__(likelihood, prior, sample_transforms, likelihood_transforms) self.config = config self.final_state = None self.metadata = {} self._unflatten_fn = None self._flatten_fn = None self._particles_flat = None self._weights = None self._seed = seed # Validate that we don't have sample transforms (SMC works in prior space) if len(sample_transforms) > 0: logger.warning( "SMC sampler received sample transforms. SMC typically works best " "without sample transforms (in prior space). Proceeding anyway." ) logger.info( f"Initializing BlackJAX SMC sampler with {self._get_kernel_name()} kernel" ) logger.info( f"Configuration: {config.n_particles} particles, " f"{config.n_mcmc_steps} MCMC steps per tempering stage" ) logger.info(f"Target ESS: {config.target_ess}")
def _create_flatten_unflatten_utilities( self, initial_position_dict: dict[str, Array] ) -> None: """Create flatten/unflatten functions for SMC's flat array API. Parameters ---------- initial_position_dict : dict[str, Array] Dictionary of initial particle positions (each value is array of shape (n_particles,)) """ # Extract single sample to determine structure single_sample_dict = tree_map(lambda x: x[0], initial_position_dict) # Create unflatten function using ravel_pytree (alphabetical ordering) _, self._unflatten_fn = flatten_util.ravel_pytree(single_sample_dict) # Create flatten function self._flatten_fn = lambda x: flatten_util.ravel_pytree(x)[0] def _wrap_dict_fn_for_flat_arrays( self, dict_fn: Callable[[dict], float] ) -> Callable[[Array], float]: """Wrap a dict-based function to work with flat arrays. This is the bridge between BlackjaxSampler's dict functions and SMC's flat array API. Parameters ---------- dict_fn : Callable[[dict], float] Function that takes parameter dict and returns float Returns ------- Callable[[Array], float] Function that takes flat array and returns float Examples -------- >>> logprior_dict = self._create_logprior_fn_from_dict() >>> logprior_flat = self._wrap_dict_fn_for_flat_arrays(logprior_dict) >>> # Now logprior_flat can be passed to BlackJAX SMC """ def flat_fn(x_flat: Array) -> float: """Convert flat array to dict, evaluate function.""" x_flat = jnp.atleast_1d(x_flat) x_dict = self._unflatten_fn(x_flat) return dict_fn(x_dict) return flat_fn @abstractmethod def _setup_mcmc_kernel( self, logprior_fn: Callable, loglikelihood_fn: Callable, logposterior_fn: Callable, initial_particles: Array, ) -> tuple[Callable, Callable, dict, Callable]: """Setup kernel-specific components. 
Parameters ---------- logprior_fn : Callable Log prior function for single particle (flat array) loglikelihood_fn : Callable Log likelihood function for single particle (flat array) logposterior_fn : Callable Log posterior function for single particle (flat array, for NUTS Hessian) initial_particles : Array Initial particle positions (flat arrays, shape: (n_particles, n_dim)) Returns ------- tuple[Callable, Callable, dict, Callable] - mcmc_step_fn: MCMC step function - mcmc_init_fn: MCMC initialization function - init_params: Initial parameter dict for the kernel - mcmc_parameter_update_fn: Function to adapt parameters """ pass @abstractmethod def _get_kernel_name(self) -> str: """Return the kernel name for logging/plotting.""" pass
    def sample(self, key: PRNGKeyArray) -> None:
        """Run SMC until λ = 1 (posterior).

        Parameters
        ----------
        key : PRNGKeyArray
            JAX random key

        Notes
        -----
        Initial particles are sampled from the prior internally.
        Results are stored on the instance (``final_state``,
        ``_particles_flat``, ``_weights``, ``metadata``) rather than returned.
        """
        logger.info(f"Starting SMC sampling with {self._get_kernel_name()} kernel...")
        start_time = time.time()

        # Sample initial particles from the prior
        key, subkey = jax.random.split(key)
        initial_position_dict: dict[str, Array] = self.prior.sample(
            subkey, self.config.n_particles
        )

        # Apply sample transforms if any (per-particle Python loop; SMC
        # normally runs with no sample transforms, so this path is rare)
        for transform in self.sample_transforms:
            initial_position_list = []
            for i in range(self.config.n_particles):
                particle_dict = {
                    name: initial_position_dict[name][i]
                    for name in self.prior.parameter_names
                }
                transformed_dict, _ = transform.transform(particle_dict)
                initial_position_list.append(transformed_dict)
            # Reconstruct dict of arrays from the list of per-particle dicts
            initial_position_dict = {
                name: jnp.array([p[name] for p in initial_position_list])
                for name in initial_position_list[0].keys()
            }

        # Create flatten/unflatten utilities (SMC needs flat arrays)
        self._create_flatten_unflatten_utilities(initial_position_dict)

        # Flatten all particles; vmap maps over the leading (particle) axis
        # of every leaf in the dict
        initial_position_flat = jax.vmap(self._flatten_fn)(initial_position_dict)

        # Ensure float dtype for compatibility
        # NOTE(review): astype(jnp.float64) yields float32 unless x64 is
        # enabled in the JAX config — presumably enabled elsewhere; confirm.
        if not jnp.issubdtype(initial_position_flat.dtype, jnp.floating):
            logger.warning(
                f"Converting initial_position_flat from {initial_position_flat.dtype} to float64"
            )
            initial_position_flat = initial_position_flat.astype(jnp.float64)

        # Create dict-based functions from the parent class
        logprior_dict = self._create_logprior_fn_from_dict()
        loglikelihood_dict = self._create_loglikelihood_fn_from_dict()

        # Wrap for flat arrays (SMC requirement)
        logprior_fn = self._wrap_dict_fn_for_flat_arrays(logprior_dict)
        loglikelihood_fn = self._wrap_dict_fn_for_flat_arrays(loglikelihood_dict)

        # Posterior for kernel setup (e.g. NUTS Hessian)
        logposterior_fn = lambda x: logprior_fn(x) + loglikelihood_fn(x)

        # Kernel-specific components supplied by the subclass
        mcmc_step_fn, mcmc_init_fn, init_params, mcmc_parameter_update_fn = (
            self._setup_mcmc_kernel(
                logprior_fn, loglikelihood_fn, logposterior_fn, initial_position_flat
            )
        )

        # Adaptive tempered SMC with inner-kernel parameter tuning
        smc_alg = inner_kernel_tuning(
            smc_algorithm=adaptive_tempered_smc,
            logprior_fn=logprior_fn,
            loglikelihood_fn=loglikelihood_fn,
            mcmc_step_fn=mcmc_step_fn,
            mcmc_init_fn=mcmc_init_fn,
            resampling_fn=systematic,
            mcmc_parameter_update_fn=mcmc_parameter_update_fn,
            initial_parameter_value=extend_params(init_params),  # type: ignore[arg-type]
            target_ess=self.config.target_ess,
            num_mcmc_steps=self.config.n_mcmc_steps,
        )

        # Initialize SMC state
        key, subkey = jax.random.split(key)
        state = smc_alg.init(initial_position_flat, subkey)

        # Progress callback for live updates during sampling (runs host-side
        # via io_callback, since the loop itself is traced)
        def progress_callback(
            step: int, tempering_param: float, ess: float, acceptance: float
        ) -> None:
            """Print progress update during sampling (called via io_callback)."""
            # Render a text progress bar for the tempering parameter
            bar_length = 30
            filled = int(tempering_param * bar_length)
            bar = "█" * filled + "░" * (bar_length - filled)

            logger.info(
                f"Step {step:4d} | λ={tempering_param:.6f} | ESS={ess*100:5.1f}% | "
                f"Accept={acceptance*100:5.1f}% | {bar}"
            )

        # Loop carry: (StateWithParameterOverride, key, step_count,
        # tempering_param_history, ess_history, acceptance_history, log_evidence)
        def cond_fn(
            carry: tuple[
                StateWithParameterOverride,
                PRNGKeyArray,
                int,
                Array,
                Array,
                Array,
                float,
            ]
        ) -> bool:
            # Continue annealing until the inverse temperature reaches 1
            state, _, _, _, _, _, _ = carry
            # Cast to proper type for the type checker (runtime type is correct)
            sampler_state = cast(TemperedSMCState, state.sampler_state)
            # Type checker sees this as potentially returning Array, but at
            # runtime tempering_param is a scalar, so comparison is boolean
            return sampler_state.tempering_param < 1  # type: ignore[return-value]

        def body_fn(
            carry: tuple[
                StateWithParameterOverride,
                PRNGKeyArray,
                int,
                Array,
                Array,
                Array,
                float,
            ]
        ):
            (
                state,
                key,
                step_count,
                tempering_param_history,
                ess_history,
                acceptance_history,
                log_evidence,
            ) = carry
            key, subkey = jax.random.split(key, 2)
            state, info = smc_alg.step(subkey, state)

            # Casts for the type checker (runtime types are correct)
            state = cast(StateWithParameterOverride, state)
            info = cast(SMCInfo, info)
            sampler_state = cast(TemperedSMCState, state.sampler_state)

            # Accumulate log evidence from this stage's increment
            log_evidence = log_evidence + info.log_likelihood_increment

            # Normalized ESS in [0, 1] from the current particle weights
            weights = sampler_state.weights
            ess_value = (
                jnp.sum(weights) ** 2 / jnp.sum(weights**2) / self.config.n_particles
            )

            # Extract acceptance rate
            # Note: update_info is a kernel-specific NamedTuple, not fully
            # typed in blackjax
            acceptance_rate = info.update_info.acceptance_rate.mean()  # type: ignore[attr-defined]

            # Record per-stage diagnostics; .at[].set keeps this traceable.
            # NOTE(review): if step_count ever exceeds max_steps, JAX clamps
            # the out-of-bounds index silently — history would be truncated.
            tempering_param_history = tempering_param_history.at[step_count].set(
                sampler_state.tempering_param
            )
            ess_history = ess_history.at[step_count].set(ess_value)
            acceptance_history = acceptance_history.at[step_count].set(acceptance_rate)

            # Host-side progress print from inside the traced loop
            io_callback(
                progress_callback,
                None,  # No return value
                step_count,
                sampler_state.tempering_param,
                ess_value,
                acceptance_rate,
            )

            return (
                state,
                key,
                step_count + 1,
                tempering_param_history,
                ess_history,
                acceptance_history,
                log_evidence,
            )

        # Run SMC with JAX while_loop
        logger.info("=" * 70)
        logger.info("STARTING ADAPTIVE TEMPERING")
        logger.info("=" * 70)
        logger.info(f"Kernel: {self._get_kernel_name().upper()}")
        logger.info(f"Particles: {self.config.n_particles}")
        logger.info(f"MCMC steps per tempering: {self.config.n_mcmc_steps}")
        logger.info(f"Target ESS: {self.config.target_ess * 100:.0f}%")
        logger.info("Temperature progression: lambda = 0 (prior) -> 1 (posterior)")
        logger.info("Progress updates will be shown after each annealing step")
        logger.info("=" * 70)

        # Fixed-size history buffers (while_loop carries must be fixed shape)
        max_steps = 1000
        tempering_param_history = jnp.zeros(max_steps)
        ess_history = jnp.zeros(max_steps)
        acceptance_history = jnp.zeros(max_steps)
        log_evidence = 0.0  # Initialize log evidence accumulator

        init_carry = (
            state,
            key,
            0,
            tempering_param_history,
            ess_history,
            acceptance_history,
            log_evidence,
        )

        logger.info("Running SMC loop (this may take several minutes)...")
        loop_start_time = time.time()
        (
            state,
            key,
            steps,
            tempering_param_history,
            ess_history,
            acceptance_history,
            log_evidence,
        ) = jax.lax.while_loop(
            cond_fn, body_fn, init_carry  # type: ignore[arg-type]
        )
        loop_end_time = time.time()
        steps = int(steps)
        end_time = time.time()

        # Extract final particles
        # Cast to proper type for the type checker (runtime type is correct)
        final_sampler_state = cast(TemperedSMCState, state.sampler_state)
        self._particles_flat = cast(Array, final_sampler_state.particles)
        self._weights = final_sampler_state.weights
        self.final_state = state

        # Compute final ESS (weights guaranteed non-None after assignment above)
        assert self._weights is not None
        ess = jnp.sum(self._weights) ** 2 / jnp.sum(self._weights**2)

        # Summary statistics over the annealing run
        mean_ess = float(jnp.mean(ess_history[:steps]))
        min_ess = float(jnp.min(ess_history[:steps]))
        mean_acceptance = float(jnp.mean(acceptance_history[:steps]))

        # FIXME: Need to implement a way to compute evidence error estimate
        log_evidence_err = 0.0  # Placeholder for now

        # Store metadata (kernel name will be set by subclass)
        self.metadata = {
            "sampler": f"blackjax_smc_{self._get_kernel_name()}",
            "kernel_type": self._get_kernel_name(),
            "n_particles": self.config.n_particles,
            "n_mcmc_steps": self.config.n_mcmc_steps,
            "target_ess": self.config.target_ess,
            "annealing_steps": steps,
            "final_ess": float(ess),
            "final_ess_percent": float(ess / self.config.n_particles * 100),
            "mean_ess": mean_ess,
            "min_ess": min_ess,
            "mean_acceptance": mean_acceptance,
            "logZ": float(log_evidence),
            "logZ_err": float(log_evidence_err),
            "sampling_time_seconds": end_time - start_time,
            "loop_time_seconds": loop_end_time - loop_start_time,
            "tempering_param_history": tempering_param_history[:steps].tolist(),
            "ess_history": ess_history[:steps].tolist(),
            "acceptance_history": acceptance_history[:steps].tolist(),
        }
[docs] def plot_diagnostics( self, outdir: str | Path = ".", filename: str = "smc_diagnostics.png" ) -> None: """Generate diagnostic plots for SMC sampling run. Creates a 3-panel figure showing: - Temperature (lambda) progression from 0 to 1 - Effective Sample Size (ESS) evolution - Acceptance rate evolution Parameters ---------- outdir : str or Path, optional Output directory for saving the plot (default: current directory) filename : str, optional Filename for the diagnostic plot (default: "smc_diagnostics.png") Notes ----- This method requires matplotlib to be installed. It should be called after sampling is complete (after calling `sample()`). """ if self.final_state is None: logger.warning("No samples yet - run sample() first") return # Extract histories from metadata tempering_param_history = self.metadata["tempering_param_history"] ess_history = self.metadata["ess_history"] acceptance_history = self.metadata["acceptance_history"] n_steps = self.metadata["annealing_steps"] # Create figure with 3 subplots fig, axes = plt.subplots(3, 1, figsize=(10, 9), sharex=True) kernel_name = self._get_kernel_name().upper() fig.suptitle( f"SMC Diagnostics ({kernel_name} kernel)", fontsize=14, fontweight="bold", ) # Plot 1: Lambda (temperature) progression axes[0].plot(range(n_steps), tempering_param_history, "b-o", linewidth=2) axes[0].set_ylabel(r"Inverse temperature $\lambda$", fontsize=12) axes[0].grid(True, alpha=0.3) axes[0].set_ylim(-0.05, 1.05) axes[0].axhline(y=0, color="black", linestyle="--", alpha=0.3, linewidth=1) axes[0].axhline(y=1, color="black", linestyle="--", alpha=0.3, linewidth=1) # Plot 2: ESS evolution ess_percent = [ess * 100 for ess in ess_history] axes[1].plot(range(n_steps), ess_percent, "g-o", linewidth=2) axes[1].axhline( y=self.config.target_ess * 100, color="black", linestyle="--", alpha=0.5, linewidth=1.5, label=f"Target ({self.config.target_ess*100:.0f}%)", ) axes[1].set_ylabel("ESS (%)", fontsize=12) axes[1].grid(True, alpha=0.3) 
axes[1].legend(loc="best", fontsize=10) axes[1].set_ylim(0, 105) # Plot 3: Acceptance rate evolution acceptance_percent = [acc * 100 for acc in acceptance_history] axes[2].plot( range(n_steps), acceptance_percent, "orange", linestyle="-", marker="o", linewidth=2, ) axes[2].set_ylabel("Acceptance Rate (%)", fontsize=12) axes[2].set_xlabel("Annealing Step", fontsize=12) axes[2].grid(True, alpha=0.3) axes[2].set_ylim(0, 105) plt.tight_layout() # Save figure outdir_path = Path(outdir) outdir_path.mkdir(parents=True, exist_ok=True) output_path = outdir_path / filename plt.savefig(output_path, dpi=150, bbox_inches="tight") logger.info(f"Saved diagnostic plot to {output_path}") plt.close(fig)
[docs] def get_samples(self) -> dict: """Return final particle positions. Returns ------- dict Dictionary with: - Parameter samples (transformed back to prior space) - 'weights': particle weights - 'ess': effective sample size """ if ( self.final_state is None or self._particles_flat is None or self._weights is None ): raise RuntimeError("No samples available - run sample() first") # Transform particles back to structured format assert self._particles_flat is not None assert self._weights is not None particles_dict = jax.vmap(self._unflatten_fn)(self._particles_flat) # Apply inverse sample transforms if any for transform in reversed(self.sample_transforms): particles_list = [] n_particles = len(self._particles_flat) for i in range(n_particles): particle_dict = { name: particles_dict[name][i] for name in particles_dict.keys() } transformed_dict, _ = transform.inverse(particle_dict) particles_list.append(transformed_dict) # Reconstruct dict of arrays particles_dict = { name: jnp.array([p[name] for p in particles_list]) for name in particles_list[0].keys() } # Add weights and ESS to output particles_dict["weights"] = self._weights particles_dict["ess"] = self.metadata["final_ess"] return particles_dict
[docs] def get_log_prob(self) -> Array: """Get log posterior probabilities from SMC. Returns ------- Array Log posterior probability values (1D array) Note: At λ=1 (final tempering), these are true posterior values. """ if self.final_state is None or self._particles_flat is None: raise RuntimeError("No samples available - run sample() first") assert self._particles_flat is not None def compute_log_prob(particle_flat): # Convert from flat array (alphabetical order) to dict using _unflatten_fn x_dict = self._unflatten_fn(particle_flat) # Use base class method to compute posterior from dict return self.posterior_from_dict(x_dict, {}) # Use batched processing for efficiency log_probs = jax.lax.map( compute_log_prob, self._particles_flat, batch_size=self.config.log_prob_batch_size, ) logger.info(f"Computed {len(log_probs)} log probability values") return log_probs
[docs] def get_n_samples(self) -> int: """Get number of particles from SMC. Returns ------- int Number of particles """ if self._particles_flat is None: return 0 return len(self._particles_flat)
[docs] def get_sampler_output(self) -> SamplerOutput: """Get standardized sampler output. Returns ------- SamplerOutput - samples: Parameter samples (dict of arrays, no weights/ess) - log_prob: Log posterior at λ=1 (final tempering) - metadata: {"weights": Array, "ess": float} Raises ------ RuntimeError If sampling has not been run yet. """ if self._particles_flat is None: raise RuntimeError("No samples available. Run sample() first.") # Get current samples dict (includes weights, ess) all_data = self.get_samples() # Separate parameters from metadata samples: dict[str, Array] = {} metadata: dict[str, Any] = {} metadata_keys = {"weights", "ess"} for key, value in all_data.items(): if key in metadata_keys: metadata[key] = value else: samples[key] = value # Get log probabilities log_prob = self.get_log_prob() return SamplerOutput( samples=samples, log_prob=log_prob, metadata=metadata, )