jesterTOV.inference.samplers.blackjax.smc.base.BlackjaxSMCSampler
- class BlackjaxSMCSampler(likelihood, prior, sample_transforms, likelihood_transforms, config, seed=0)
Bases: BlackjaxSampler
Base class for BlackJAX Sequential Monte Carlo (SMC) with adaptive tempering.
This abstract base class implements all the shared SMC functionality (adaptive tempering, particle management, result handling) and delegates only the kernel-specific parts to subclasses.
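Adaptive tempering chooses each new temperature λ so that the incremental importance weights do not collapse. A common rule, used for example by BlackJAX's adaptive tempered SMC through its target ESS setting, takes the largest step that keeps the effective sample size (ESS) at a target fraction of the particle count. The NumPy sketch below finds that step by bisection. The function names, the 0.5 target and the use of the Kish ESS are illustrative assumptions, not jesterTOV's implementation.

```python
import numpy as np


def ess_fraction(log_weights: np.ndarray) -> float:
    """Kish effective sample size divided by the number of particles."""
    # Subtract the max before exponentiating for numerical stability.
    w = np.exp(log_weights - np.max(log_weights))
    w /= w.sum()
    return float(1.0 / np.sum(w**2) / w.size)


def next_lambda(loglik: np.ndarray, lam: float, target: float = 0.5,
                tol: float = 1e-10, max_iter: int = 200) -> float:
    """Pick the next temperature so that ESS/N of the incremental weights is ~target.

    Assumes likelihood tempering: moving from lam to lam + d gives particle i
    the incremental log weight d * loglik[i].
    """
    remaining = 1.0 - lam
    # If jumping straight to lambda = 1 keeps enough ESS, finish in one step.
    if ess_fraction(remaining * loglik) >= target:
        return 1.0
    lo, hi = 0.0, remaining  # invariant: step `lo` is acceptable, step `hi` is not
    for _ in range(max_iter):
        if hi - lo < tol:
            break
        mid = 0.5 * (lo + hi)
        if ess_fraction(mid * loglik) >= target:
            lo = mid  # ESS still above target: try a larger step
        else:
            hi = mid  # ESS dropped too far: shrink the step
    return lam + lo
```

Bisection is safe here because the ESS of weights proportional to exp(dλ · log L) never increases as the step dλ grows, so the acceptable steps form a single interval starting at zero.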
Key differences from the parent BlackjaxSampler:
- Adds flattening/unflattening utilities (SMC requires flat arrays)
- Wraps dict-based functions from the parent for the flat-array API
- Implements the full SMC sampling loop
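The flatten/unflatten pair can be pictured as a fixed parameter ordering that maps a dict of per-parameter arrays to an (n_particles, n_params) array and back, plus a wrapper that lets a dict-based log density accept flat arrays. The sketch below is hypothetical: make_flatten_fns and wrap_for_flat are not jesterTOV names, and the parameter names "x" and "y" are placeholders. In JAX, jax.flatten_util.ravel_pytree provides a comparable round trip for a single parameter set.

```python
from typing import Callable, Dict, List, Tuple

import numpy as np


def make_flatten_fns(names: List[str]) -> Tuple[Callable, Callable]:
    """Hypothetical flatten/unflatten pair for a fixed parameter order."""
    names = list(names)

    def flatten(params: Dict[str, np.ndarray]) -> np.ndarray:
        # Stack per-parameter arrays on a new last axis: shape (..., n_params).
        return np.stack([np.asarray(params[n], dtype=float) for n in names], axis=-1)

    def unflatten(flat: np.ndarray) -> Dict[str, np.ndarray]:
        # Split the last axis back into one entry per parameter name.
        flat = np.asarray(flat)
        return {n: flat[..., i] for i, n in enumerate(names)}

    return flatten, unflatten


def wrap_for_flat(dict_fn: Callable, unflatten: Callable) -> Callable:
    """Adapt a dict-based log density so it accepts flat arrays instead."""
    return lambda flat: dict_fn(unflatten(flat))
```

Fixing the order once (here, the order of `names`) is what keeps flat columns and dict keys consistent across every likelihood and prior call.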
Subclasses must implement:
- _setup_mcmc_kernel(): return (mcmc_step_fn, mcmc_init_fn, init_params, mcmc_parameter_update_fn)
- _get_kernel_name(): return a string name for logging/plotting
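This split between shared logic and kernel-specific hooks is the template-method pattern. The toy classes below are a hypothetical illustration: the placeholder functions stand in for real BlackJAX kernel functions, and the argument lists (including the (state, info) signature of the update function) are assumptions rather than jesterTOV's actual signatures.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Tuple


class ToySMCBase(ABC):
    """Hypothetical base class: shared SMC code calls the two hooks below."""

    def describe(self) -> str:
        # Shared code relies only on the hook contract, not on a specific kernel.
        _step_fn, _init_fn, init_params, _update_fn = self._setup_mcmc_kernel()
        return f"{self._get_kernel_name()} kernel, tunable params: {sorted(init_params)}"

    @abstractmethod
    def _setup_mcmc_kernel(self) -> Tuple[Callable, Callable, dict, Callable]:
        """Return (mcmc_step_fn, mcmc_init_fn, init_params, mcmc_parameter_update_fn)."""

    @abstractmethod
    def _get_kernel_name(self) -> str:
        """Return a short kernel name for logging/plotting."""


class ToyRandomWalkSMC(ToySMCBase):
    """Hypothetical subclass that supplies only the kernel-specific parts."""

    def _setup_mcmc_kernel(self):
        def mcmc_step_fn(rng_key: Any, state: Any, logdensity_fn: Callable, **params: Any) -> Any:
            return state  # placeholder: a real kernel would propose and accept/reject

        def mcmc_init_fn(position: Any, logdensity_fn: Callable) -> Any:
            return position  # placeholder: a real kernel would build its state here

        def mcmc_parameter_update_fn(state: Any, info: Any) -> dict:
            return {"step_size": 0.1}  # placeholder: could adapt from acceptance rates

        init_params = {"step_size": 0.1}
        return mcmc_step_fn, mcmc_init_fn, init_params, mcmc_parameter_update_fn

    def _get_kernel_name(self) -> str:
        return "toy_random_walk"
```

Because the base class is abstract, forgetting either hook makes instantiation fail with a TypeError instead of failing later during sampling.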
- Parameters:
likelihood (LikelihoodBase) – Likelihood object with evaluate(params, data) method
prior (Prior) – Prior object
sample_transforms (list[BijectiveTransform]) – Should be empty for SMC, which samples directly in prior space
likelihood_transforms (list[NtoMTransform]) – N-to-M transforms applied before likelihood evaluation
config (SMCRandomWalkSamplerConfig | SMCNUTSSamplerConfig) – SMC configuration
seed (int, optional) – Random seed (default: 0)
- Variables:
config (SMCRandomWalkSamplerConfig | SMCNUTSSamplerConfig) – Sampler configuration
final_state (Any | None) – Final SMC state (after sampling)
metadata (dict) – Sampling metadata (ESS, time, etc.)
_unflatten_fn (callable) – Function to convert flat arrays back to parameter dicts
_flatten_fn (callable) – Function to convert parameter dicts to flat arrays
_particles_flat (Array) – Final particle positions (flat arrays)
_weights (Array) – Final particle weights
- __init__(likelihood, prior, sample_transforms, likelihood_transforms, config, seed=0)
Initialize the BlackJAX SMC sampler.
Methods
__init__(likelihood, prior, ...[, seed]) – Initialize the BlackJAX SMC sampler.
add_name(x) – Turn an array into a dictionary.
get_log_prob() – Get log posterior probabilities from SMC.
get_n_samples() – Get the number of particles from SMC.
get_sampler_output() – Get standardized sampler output.
get_samples() – Return the final particle positions.
plot_diagnostics([outdir, filename]) – Generate diagnostic plots for the SMC sampling run.
posterior(params, data) – Evaluate the posterior log probability from a flat array.
posterior_from_dict(named_params, data) – Evaluate the posterior log probability from a parameter dict.
print_summary([transform]) – Print a summary of the sampling run.
sample(key) – Run SMC until λ = 1 (posterior).
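Running SMC until λ = 1 repeats one basic iteration: reweight the particles for the new temperature, resample, then move them with an MCMC kernel targeting prior × likelihood^λ. The NumPy sketch below shows one such iteration using systematic resampling and a random-walk Metropolis move. These choices and all function names are illustrative assumptions; the real sampler selects λ adaptively and delegates the move to the BlackJAX kernel supplied by the subclass.

```python
from typing import Callable, Tuple

import numpy as np


def systematic_resample(rng: np.random.Generator, weights: np.ndarray) -> np.ndarray:
    """Return ancestor indices drawn by systematic resampling (weights sum to 1)."""
    n = weights.size
    positions = (rng.random() + np.arange(n)) / n
    cumulative = np.cumsum(weights)
    cumulative[-1] = 1.0  # guard against round-off so every index stays < n
    return np.searchsorted(cumulative, positions)


def smc_step(rng: np.random.Generator, particles: np.ndarray,
             log_prior: Callable, log_lik: Callable,
             lam_old: float, lam_new: float,
             step_size: float = 0.5, n_mcmc: int = 5) -> Tuple[np.ndarray, float]:
    """One tempered SMC iteration on flat particles of shape (n_particles, n_dim)."""
    # 1. Reweight: the incremental log weight is (lam_new - lam_old) * log-likelihood.
    logw = (lam_new - lam_old) * log_lik(particles)
    w = np.exp(logw - logw.max())
    w /= w.sum()
    # 2. Resample so that all particles carry equal weight again.
    particles = particles[systematic_resample(rng, w)]
    # 3. Mutate with random-walk Metropolis targeting prior * likelihood**lam_new.
    logp = log_prior(particles) + lam_new * log_lik(particles)
    n_accept = 0
    for _ in range(n_mcmc):
        proposal = particles + step_size * rng.standard_normal(particles.shape)
        logp_prop = log_prior(proposal) + lam_new * log_lik(proposal)
        accept = np.log(rng.random(len(particles))) < logp_prop - logp
        particles = np.where(accept[:, None], proposal, particles)
        logp = np.where(accept, logp_prop, logp)
        n_accept += int(accept.sum())
    return particles, n_accept / (n_mcmc * len(particles))
```

For example, with a N(0, 3²) prior and a unit-variance Gaussian likelihood centred at 1, stepping λ through 0, 0.1, ..., 1 should leave the particle mean close to the analytic posterior mean of 0.9.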
Attributes
likelihood, prior, sample_transforms, likelihood_transforms, parameter_names, sampler
- get_log_prob()
Get log posterior probabilities from SMC.
- Returns:
Array – Log posterior probability values (1D array).
Note: at λ = 1 (the final tempering step), these are the true posterior values.
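Assuming the standard likelihood-tempering scheme used by BlackJAX's adaptive tempered SMC (prior at λ = 0, posterior at λ = 1), the intermediate targets and incremental particle weights take the general form below. This is a textbook sketch, not taken from the jesterTOV source; p is the prior, 𝓛 the likelihood and Z_λ the normalizing constant at temperature λ.

```latex
\begin{align}
\pi_\lambda(\theta) &\propto p(\theta)\,\mathcal{L}(\theta)^{\lambda}, \qquad 0 \le \lambda \le 1, \\
\log \pi_\lambda(\theta) &= \log p(\theta) + \lambda \log \mathcal{L}(\theta) - \log Z_\lambda, \\
w_i^{(t)} &\propto \mathcal{L}(\theta_i)^{\lambda_t - \lambda_{t-1}}.
\end{align}
```

Under this assumption, setting λ = 1 in the second line gives log p(θ) + log 𝓛(θ) up to the constant log Z_1 (the evidence), which is why the log probabilities reported at the final tempering step are posterior values.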
- get_n_samples()
Get the number of particles from SMC.
- Returns:
int – Number of particles
- get_sampler_output()
Get standardized sampler output.
- Returns:
SamplerOutput – with fields:
samples: parameter samples (dict of arrays, without weights/ess)
log_prob: log posterior at λ = 1 (final tempering step)
metadata: {"weights": Array, "ess": float}
- Raises:
RuntimeError – If sampling has not been run yet.
- get_samples()
Return the final particle positions.
- Returns:
dict – Dictionary with:
- parameter samples (transformed back to prior space)
- 'weights': particle weights
- 'ess': effective sample size
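Given a dict shaped like the one described above, weighted summaries can be computed directly from the particles. The helpers below are a hypothetical sketch: they assume the 'weights' entry is an array aligned with each parameter array, and they use the Kish estimate ESS = (Σw)² / Σw², which may differ from how the sampler computes its own 'ess' value. The parameter name "x" in the usage is a placeholder.

```python
from typing import Dict, Tuple

import numpy as np


def weighted_summary(samples: dict) -> Dict[str, Tuple[float, float]]:
    """Weighted (mean, std) per parameter from a get_samples()-style dict."""
    w = np.asarray(samples["weights"], dtype=float)
    w = w / w.sum()  # normalize defensively in case weights are unnormalized
    summary = {}
    for name, values in samples.items():
        if name in ("weights", "ess"):
            continue  # metadata entries, not parameters
        v = np.asarray(values, dtype=float)
        mean = np.sum(w * v)
        std = np.sqrt(np.sum(w * (v - mean) ** 2))
        summary[name] = (float(mean), float(std))
    return summary


def kish_ess(weights) -> float:
    """Kish effective sample size: (sum w)^2 / sum w^2."""
    w = np.asarray(weights, dtype=float)
    return float(w.sum() ** 2 / np.sum(w ** 2))
```

With uniform weights the Kish ESS equals the number of particles, and it drops toward 1 as a single particle dominates.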
- plot_diagnostics(outdir='.', filename='smc_diagnostics.png')
Generate diagnostic plots for an SMC sampling run.
Creates a 3-panel figure showing:
- temperature (λ) progression from 0 to 1
- effective sample size (ESS) evolution
- acceptance rate evolution
- Parameters:
outdir (str, optional) – Output directory for the figure (default: '.')
filename (str, optional) – File name of the figure (default: 'smc_diagnostics.png')
Notes
This method requires matplotlib to be installed. It should be called after sampling is complete (after calling sample()).
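If a custom layout is needed, a figure similar to the one plot_diagnostics() describes can be drawn with matplotlib from per-step histories. The sketch assumes you already have lists of λ, ESS and acceptance-rate values, one entry per SMC step; plot_smc_history is a hypothetical helper, and how the sampler stores these histories internally is not documented here.

```python
import os

import matplotlib

matplotlib.use("Agg")  # headless backend: render to file without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_smc_history(lambdas, ess, acceptance, outdir=".", filename="smc_diagnostics.png"):
    """Save a 3-panel figure of temperature, ESS and acceptance rate per SMC step."""
    steps = np.arange(len(lambdas))
    fig, axes = plt.subplots(3, 1, figsize=(6, 8), sharex=True)
    axes[0].plot(steps, lambdas, marker="o")
    axes[0].set_ylabel("temperature λ")
    axes[1].plot(steps, ess, marker="o")
    axes[1].set_ylabel("ESS")
    axes[2].plot(steps, acceptance, marker="o")
    axes[2].set_ylabel("acceptance rate")
    axes[2].set_xlabel("SMC step")
    fig.tight_layout()
    os.makedirs(outdir, exist_ok=True)  # create the output directory if needed
    path = os.path.join(outdir, filename)
    fig.savefig(path)
    plt.close(fig)  # free the figure so repeated calls do not accumulate memory
    return path
```

Sharing the x-axis keeps the three panels aligned by step, so a drop in ESS can be read directly against the λ jump and acceptance rate at the same iteration.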