jesterTOV.inference.samplers.blackjax.smc.base.BlackjaxSMCSampler

class BlackjaxSMCSampler(likelihood, prior, sample_transforms, likelihood_transforms, config, seed=0)

Bases: BlackjaxSampler

Base class for BlackJAX Sequential Monte Carlo with adaptive tempering.

This abstract base class implements all the shared SMC functionality (adaptive tempering, particle management, result handling) and delegates only the kernel-specific parts to subclasses.
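As background, adaptive tempering is usually formulated as below. This is the textbook form in our own notation (p for the prior, \mathcal{L} for the likelihood, N particles). The page does not state the exact criterion this class uses to choose each λ.

```latex
\begin{aligned}
\pi_{\lambda}(\theta) &\propto p(\theta)\,\mathcal{L}(\theta)^{\lambda},
  \qquad 0 = \lambda_0 < \lambda_1 < \cdots < \lambda_T = 1, \\
w_i^{(t)} &\propto \mathcal{L}(\theta_i)^{\lambda_t - \lambda_{t-1}},
  \qquad
\mathrm{ESS}_t = \frac{\bigl(\sum_{i=1}^{N} w_i^{(t)}\bigr)^2}{\sum_{i=1}^{N} \bigl(w_i^{(t)}\bigr)^2}.
\end{aligned}
```

Each new λ_t is typically chosen as the largest value for which ESS_t stays above a target fraction of N. At λ_T = 1 the target distribution is the posterior itself.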

Key differences from the parent BlackjaxSampler:

  • Adds flattening/unflattening utilities (SMC requires flat arrays)

  • Wraps the parent's dict-based functions in a flat-array API

  • Implements the full SMC sampling loop
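The flat-array wrapping described above can be sketched with jax.flatten_util.ravel_pytree. This illustration rests on assumptions and is not this class's code. The parameter names, the toy log density, and the helper names (flatten_fn, log_posterior_flat, and so on) are hypothetical. The real _flatten_fn/_unflatten_fn may order or shape parameters differently.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameter dict; the names are illustrative only.
example_params = {"mass": jnp.array(1.4), "radius": jnp.array(12.0)}

# ravel_pytree returns a 1D array plus a function that undoes the flattening.
flat, unflatten_fn = ravel_pytree(example_params)


def flatten_fn(params):
    """Parameter dict -> flat array, using the same layout as `flat`."""
    return ravel_pytree(params)[0]


def log_posterior_from_dict(named_params):
    """Toy dict-based log density (independent standard normals)."""
    return sum(-0.5 * v**2 for v in named_params.values())


def log_posterior_flat(x):
    """Flat-array wrapper: rebuild the dict, then call the dict-based function."""
    return log_posterior_from_dict(unflatten_fn(x))


# The flat wrapper vectorizes over an (n_particles, n_dim) batch with vmap.
particles = jnp.stack([flat, flat + 1.0])
log_probs = jax.vmap(log_posterior_flat)(particles)
```

ravel_pytree flattens dict entries in sorted key order, so the flat layout is deterministic for a fixed set of parameter names.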

Subclasses must implement:

  • _setup_mcmc_kernel(): return (mcmc_step_fn, mcmc_init_fn, init_params, mcmc_parameter_update_fn)

  • _get_kernel_name(): return the kernel's name as a string, used for logging and plotting
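A self-contained toy version of this subclass contract follows. ToySMCBase, ToyRandomWalkSMC, and KernelSpec are hypothetical stand-ins. The call signatures chosen for the step, init, and update functions are assumptions for illustration only. They are not the signatures BlackJAX or this package expects.

```python
from abc import ABC, abstractmethod
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class KernelSpec(NamedTuple):
    """Hypothetical container for the four items _setup_mcmc_kernel returns."""
    mcmc_step_fn: Callable
    mcmc_init_fn: Callable
    init_params: dict
    mcmc_parameter_update_fn: Callable


class ToySMCBase(ABC):
    """Toy stand-in for the base class: shared code calls the two hooks."""

    @abstractmethod
    def _setup_mcmc_kernel(self) -> KernelSpec: ...

    @abstractmethod
    def _get_kernel_name(self) -> str: ...

    def describe(self) -> str:
        spec = self._setup_mcmc_kernel()
        return f"{self._get_kernel_name()} with params {sorted(spec.init_params)}"


class ToyRandomWalkSMC(ToySMCBase):
    def _setup_mcmc_kernel(self) -> KernelSpec:
        # All signatures below are assumptions chosen for this sketch.
        def init_fn(position, logdensity_fn):
            return {"position": position, "logdensity": logdensity_fn(position)}

        def step_fn(key, state, logdensity_fn, step_size):
            # Gaussian random-walk Metropolis step on a flat position array.
            k_prop, k_acc = jax.random.split(key)
            pos = state["position"]
            proposal = pos + step_size * jax.random.normal(k_prop, pos.shape)
            new_logp = logdensity_fn(proposal)
            # Symmetric proposal, so the acceptance ratio is the density ratio.
            accept = jnp.log(jax.random.uniform(k_acc)) < new_logp - state["logdensity"]
            new_state = {
                "position": jnp.where(accept, proposal, pos),
                "logdensity": jnp.where(accept, new_logp, state["logdensity"]),
            }
            return new_state, accept

        def update_fn(state, info):
            # Placeholder: a real kernel might adapt step_size from acceptance info.
            return {"step_size": 0.1}

        return KernelSpec(step_fn, init_fn, {"step_size": 0.1}, update_fn)

    def _get_kernel_name(self) -> str:
        return "random_walk"
```

Keeping the kernel behind two small hooks lets the base class own tempering, particle bookkeeping, and result handling. A subclass then only decides how particles move at a fixed λ.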

Parameters:
Variables:
  • config (SMCRandomWalkSamplerConfig | SMCNUTSSamplerConfig) – Sampler configuration

  • final_state (Any | None) – Final SMC state (after sampling)

  • metadata (dict) – Sampling metadata (ESS, time, etc.)

  • _unflatten_fn (callable) – Function to convert flat arrays back to parameter dicts

  • _flatten_fn (callable) – Function to convert parameter dicts to flat arrays

  • _particles_flat (Array) – Final particle positions (flat arrays)

  • _weights (Array) – Final particle weights

__init__(likelihood, prior, sample_transforms, likelihood_transforms, config, seed=0)

Initialize BlackJAX SMC sampler.

Methods

  • __init__(likelihood, prior, ...[, seed]) – Initialize BlackJAX SMC sampler.

  • add_name(x) – Turn an array into a dictionary.

  • get_log_prob() – Get log posterior probabilities from SMC.

  • get_n_samples() – Get number of particles from SMC.

  • get_sampler_output() – Get standardized sampler output.

  • get_samples() – Return final particle positions.

  • plot_diagnostics([outdir, filename]) – Generate diagnostic plots for SMC sampling run.

  • posterior(params, data) – Evaluate posterior log probability from flat array.

  • posterior_from_dict(named_params, data) – Evaluate posterior log probability from parameter dict.

  • print_summary([transform]) – Print summary of sampling run.

  • sample(key) – Run SMC until λ = 1 (posterior).

Attributes

config

final_state

metadata

likelihood

prior

sample_transforms

likelihood_transforms

parameter_names

sampler

config: SMCRandomWalkSamplerConfig | SMCNUTSSamplerConfig

final_state: Any | None

get_log_prob()

Get log posterior probabilities from SMC.

Return type:

Array

Returns:

Array – Log posterior probability values (1D array). Note: at λ = 1 (the final tempering stage), these are the untempered log posterior values.

get_n_samples()

Get number of particles from SMC.

Return type:

int

Returns:

int – Number of particles

get_sampler_output()

Get standardized sampler output.

Return type:

SamplerOutput

Returns:

SamplerOutput

  • samples: Parameter samples (dict of arrays, no weights/ess)

  • log_prob: Log posterior at λ=1 (final tempering)

  • metadata: {"weights": Array, "ess": float}

Raises:

RuntimeError – If sampling has not been run yet.
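Samples and weights are returned separately, so a caller who wants equally weighted draws can resample. The sketch below uses systematic resampling. The function name is hypothetical, and the sketch assumes samples is a dict of 1D arrays with one entry per particle. The page does not say whether the final weights are already uniform; if they are, this step is unnecessary.

```python
import jax
import jax.numpy as jnp


def systematic_resample(key, weights, samples):
    """Draw equally weighted copies of `samples` in proportion to `weights`.

    `samples` maps parameter names to 1D arrays (one entry per particle);
    `weights` may be unnormalized.
    """
    n = weights.shape[0]
    cdf = jnp.cumsum(weights / jnp.sum(weights))
    # A single uniform offset gives n evenly spaced points in [0, 1).
    u = (jax.random.uniform(key) + jnp.arange(n)) / n
    # Particle i is picked when cdf[i-1] <= u < cdf[i].
    idx = jnp.searchsorted(cdf, u, side="right")
    idx = jnp.minimum(idx, n - 1)  # guard against cdf[-1] rounding below 1
    return {name: values[idx] for name, values in samples.items()}
```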

get_samples()

Return final particle positions.

Return type:

dict

Returns:

dict – Dictionary containing:

  • Parameter samples (transformed back to prior space)

  • 'weights': particle weights

  • 'ess': effective sample size
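Given a dictionary shaped like this one, weighted posterior means and a Kish effective sample size can be computed as below. weighted_summary is a hypothetical helper that assumes 1D parameter arrays and possibly unnormalized weights. This page does not document whether the stored 'ess' uses the same formula or reports it as a fraction of the particle count, so compare the two before relying on either.

```python
import jax.numpy as jnp


def weighted_summary(result):
    """Weighted posterior means and Kish ESS from a get_samples()-style dict.

    Assumes every key other than 'weights' and 'ess' maps to a 1D array with
    one entry per particle, and that 'weights' may be unnormalized.
    """
    w = result["weights"] / jnp.sum(result["weights"])
    means = {
        name: jnp.sum(w * values)
        for name, values in result.items()
        if name not in ("weights", "ess")
    }
    # Kish ESS: (sum w)^2 / sum(w^2); the numerator is 1 after normalizing.
    kish_ess = 1.0 / jnp.sum(w**2)
    return means, kish_ess
```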

metadata: dict

plot_diagnostics(outdir='.', filename='smc_diagnostics.png')

Generate diagnostic plots for SMC sampling run.

Creates a 3-panel figure showing:

  • Temperature (λ) progression from 0 to 1

  • Effective sample size (ESS) evolution

  • Acceptance rate evolution

Parameters:
  • outdir (str or Path, optional) – Output directory for saving the plot (default: current directory)

  • filename (str, optional) – Filename for the diagnostic plot (default: "smc_diagnostics.png")

Return type:

None

Notes

This method requires matplotlib. Call it only after sampling has completed, i.e. after sample() has run.

sample(key)

Run SMC until λ = 1 (posterior).

Parameters:

key (PRNGKeyArray) – JAX random key

Return type:

None

Notes

Initial particles are sampled from the prior internally.
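The adaptive choice of each λ inside sample() can be illustrated with a bisection on the normalized ESS of the incremental weights. This is a simplified stand-in and not the solver this class or BlackJAX uses. next_lambda and the 0.5 target fraction are assumptions.

```python
import jax
import jax.numpy as jnp


def next_lambda(loglik, lam, target_ess_frac=0.5, n_iter=50):
    """Choose the next temperature by bisection on the normalized ESS.

    loglik: 1D array of per-particle log-likelihood values.
    lam: current temperature in [0, 1).
    Returns (approximately) the largest lambda in [lam, 1] whose normalized
    ESS is still >= target_ess_frac, or 1.0 if jumping to 1 already is.
    """
    n = loglik.shape[0]

    def ess_frac(new_lam):
        # Incremental log-weights for moving from lam to new_lam.
        logw = (new_lam - lam) * loglik
        w = jnp.exp(logw - jnp.max(logw))  # subtract the max for stability
        return jnp.sum(w) ** 2 / jnp.sum(w**2) / n

    def body(_, bounds):
        lo, hi = bounds  # lo always meets the target; the bracket shrinks
        mid = 0.5 * (lo + hi)
        ok = ess_frac(mid) >= target_ess_frac
        return jnp.where(ok, mid, lo), jnp.where(ok, hi, mid)

    init = (jnp.asarray(lam, dtype=jnp.float32), jnp.asarray(1.0, dtype=jnp.float32))
    lo, _ = jax.lax.fori_loop(0, n_iter, body, init)
    return jnp.where(ess_frac(1.0) >= target_ess_frac, 1.0, lo)
```

Because the normalized ESS falls monotonically as the increment grows, bisection keeps lo feasible while narrowing the bracket. The final check returns 1.0 when the whole remaining step already meets the target, which is how the loop reaches the posterior.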