jesterTOV.inference.samplers.blackjax.smc.random_walk.BlackJAXSMCRandomWalkSampler
class BlackJAXSMCRandomWalkSampler(likelihood, prior, sample_transforms, likelihood_transforms, config, seed=0)
Bases: BlackjaxSMCSampler
SMC with Gaussian Random Walk Metropolis-Hastings kernel.
This sampler uses a simple random walk proposal with adaptive sigma tuning. Recommended for most use cases due to simplicity and robustness.
The proposal covariance is adapted from current particles at each tempering step and scaled by a fixed sigma^2 parameter.
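The covariance adaptation described above can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not the sampler's actual implementation: the helper names `random_walk_proposal` and `mh_accept`, the jitter value, and the toy Gaussian target are all hypothetical. It shows the empirical covariance of the current particles being scaled by a fixed `sigma**2`, followed by a standard Metropolis-Hastings accept/reject step for a symmetric proposal.

```python
import numpy as np


def random_walk_proposal(particles, sigma, rng):
    """Propose one Gaussian random-walk move per particle.

    particles: array of shape (n_particles, n_dim)
    sigma: fixed scale; the particle covariance is multiplied by sigma**2
    """
    n_particles, n_dim = particles.shape
    # Empirical covariance of the current particle cloud (rows are samples).
    cov = np.atleast_2d(np.cov(particles, rowvar=False))
    # Small jitter keeps the Cholesky factorisation stable (illustrative value).
    chol = np.linalg.cholesky(sigma**2 * cov + 1e-12 * np.eye(n_dim))
    noise = rng.standard_normal((n_particles, n_dim))
    # noise @ chol.T has covariance chol @ chol.T = sigma**2 * cov.
    return particles + noise @ chol.T


def mh_accept(log_density, current, proposed, rng):
    """Metropolis-Hastings accept/reject for a symmetric proposal."""
    log_ratio = log_density(proposed) - log_density(current)
    accept = np.log(rng.uniform(size=current.shape[0])) < log_ratio
    return np.where(accept[:, None], proposed, current), accept


# Toy target (hypothetical): independent Gaussians with very different scales.
scales = np.array([1.0, 0.1])


def log_density(x):
    return -0.5 * np.sum((x / scales) ** 2, axis=1)


rng = np.random.default_rng(0)
particles = rng.standard_normal((500, 2)) * scales
proposed = random_walk_proposal(particles, sigma=0.5, rng=rng)
new_particles, accepted = mh_accept(log_density, particles, proposed, rng)
```

Because the proposal inherits the shape of the particle cloud, a single scalar `sigma` can work across parameters with very different scales, which is one plausible reason a random walk kernel stays robust without per-dimension tuning.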
Parameters:
likelihood (LikelihoodBase) – Likelihood object
prior (Prior) – Prior object
sample_transforms (list[BijectiveTransform]) – Sample transforms (typically empty for SMC)
likelihood_transforms (list[NtoMTransform]) – Likelihood transforms
config (SMCRandomWalkSamplerConfig) – Random walk SMC configuration
seed (int, optional) – Random seed (default: 0)
__init__(likelihood, prior, sample_transforms, likelihood_transforms, config, seed=0)
Initialize Random Walk SMC sampler.
Methods
__init__(likelihood, prior, ...[, seed]) – Initialize Random Walk SMC sampler.
add_name(x) – Turn an array into a dictionary.
get_log_prob() – Get log posterior probabilities from SMC.
get_n_samples() – Get number of particles from SMC.
get_sampler_output() – Get standardized sampler output.
get_samples() – Return final particle positions.
plot_diagnostics([outdir, filename]) – Generate diagnostic plots for SMC sampling run.
posterior(params, data) – Evaluate posterior log probability from flat array.
posterior_from_dict(named_params, data) – Evaluate posterior log probability from parameter dict.
print_summary([transform]) – Print summary of sampling run.
sample(key) – Run SMC until λ = 1 (posterior).
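The λ in the `sample(key)` entry is the tempering parameter. In tempered SMC the target at each step is conventionally the prior times the likelihood raised to the power λ, so λ = 0 is the prior and λ = 1 is the posterior. The sketch below illustrates that convention only. The helper names and the toy densities are hypothetical, and it is an assumption that this class follows exactly this form.

```python
import numpy as np


def tempered_log_density(log_prior, log_likelihood, lam):
    """Return x -> log_prior(x) + lam * log_likelihood(x) (unnormalised)."""
    def log_density(x):
        return log_prior(x) + lam * log_likelihood(x)
    return log_density


def incremental_log_weights(log_likelihood, particles, lam_old, lam_new):
    """Log importance weights for moving particles from lam_old to lam_new."""
    return (lam_new - lam_old) * log_likelihood(particles)


# Toy one-dimensional densities (hypothetical).
def log_prior(x):
    return -0.5 * x**2


def log_likelihood(x):
    return -0.5 * ((x - 1.0) / 0.5) ** 2


x = np.array([0.0, 1.0])
at_prior = tempered_log_density(log_prior, log_likelihood, 0.0)(x)
at_posterior = tempered_log_density(log_prior, log_likelihood, 1.0)(x)
log_w = incremental_log_weights(log_likelihood, x, 0.25, 0.5)
```

At λ = 0 the likelihood drops out entirely, which is why the particles can be initialised directly from the prior.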
Attributes
config
final_state
metadata
likelihood
prior
sample_transforms
likelihood_transforms
parameter_names
sampler