jesterTOV.inference.samplers.blackjax.smc.random_walk.BlackJAXSMCRandomWalkSampler

class BlackJAXSMCRandomWalkSampler(
likelihood,
prior,
sample_transforms,
likelihood_transforms,
config,
seed=0,
)

Bases: BlackjaxSMCSampler

SMC with Gaussian Random Walk Metropolis-Hastings kernel.

This sampler uses a simple random-walk proposal whose scale is adapted to the current particle population. It is recommended for most use cases because it is simple and robust.

At each tempering step, the proposal covariance is estimated from the current particles and then scaled by a fixed sigma^2 parameter.
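The covariance-adapted random-walk update described above can be sketched in JAX as follows. This is a minimal illustration under stated assumptions, not the sampler's actual kernel. It assumes the particles are equally weighted (as after resampling), uses `sigma` as a stand-in for the sigma^2 setting, and treats `log_prior`, `log_likelihood`, and `lam` as hypothetical arguments for the prior, the likelihood, and the current tempering parameter λ.

```python
import jax
import jax.numpy as jnp


def rwmh_step(key, particles, log_prior, log_likelihood, lam, sigma):
    """One Gaussian random-walk Metropolis-Hastings update for every particle.

    Sketch only (hypothetical helper): the proposal covariance is the empirical
    covariance of the (assumed equally weighted) particles, scaled by sigma**2.
    The target is the tempered density log_prior(x) + lam * log_likelihood(x).
    """
    n, d = particles.shape
    # Empirical covariance across particles; small jitter keeps it positive definite.
    cov = jnp.atleast_2d(jnp.cov(particles, rowvar=False)) + 1e-9 * jnp.eye(d)
    chol = jnp.linalg.cholesky(sigma**2 * cov)

    def log_target(x):
        return log_prior(x) + lam * log_likelihood(x)

    key_prop, key_acc = jax.random.split(key)
    # Correlated Gaussian proposals: covariance of noise @ chol.T is sigma**2 * cov.
    noise = jax.random.normal(key_prop, (n, d))
    proposals = particles + noise @ chol.T

    # Symmetric proposal, so the MH log-ratio is just the tempered target log-ratio.
    log_ratio = jax.vmap(log_target)(proposals) - jax.vmap(log_target)(particles)
    log_u = jnp.log(jax.random.uniform(key_acc, (n,)))
    accept = log_u < log_ratio
    new_particles = jnp.where(accept[:, None], proposals, particles)
    return new_particles, accept.mean()
```

Because the Gaussian proposal is symmetric, the acceptance test only needs the tempered target densities. Estimating the covariance from the particles lets the proposal follow the shape and scale of the current tempered distribution, which is one reason this kind of kernel needs little manual tuning.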

Parameters:
__init__(
likelihood,
prior,
sample_transforms,
likelihood_transforms,
config,
seed=0,
)

Initialize Random Walk SMC sampler.

Methods

__init__(likelihood, prior, ...[, seed])

Initialize Random Walk SMC sampler.

add_name(x)

Turn a flat parameter array into a dictionary keyed by parameter name.

get_log_prob()

Get log posterior probabilities from SMC.

get_n_samples()

Get number of particles from SMC.

get_sampler_output()

Get standardized sampler output.

get_samples()

Return final particle positions.

plot_diagnostics([outdir, filename])

Generate diagnostic plots for SMC sampling run.

posterior(params, data)

Evaluate posterior log probability from flat array.

posterior_from_dict(named_params, data)

Evaluate posterior log probability from parameter dict.

print_summary([transform])

Print summary of sampling run.

sample(key)

Run tempered SMC until the inverse temperature λ reaches 1 (the posterior).
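`sample(key)` advances the tempering parameter λ until it reaches 1. One common way to pick each next λ in adaptive tempered SMC is bisection on the effective sample size (ESS) of the incremental weights exp((λ' - λ) log L). The sketch below assumes that scheme with a hypothetical target ESS fraction of 0.5; the documentation above does not state which schedule this sampler actually uses, and `ess_fraction` and `next_lambda` are illustrative names.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def ess_fraction(log_weights):
    """Normalized ESS, (1/N) * 1 / sum(w_i**2), for unnormalized log-weights."""
    log_w = log_weights - logsumexp(log_weights)  # normalize so sum(w) == 1
    return float(jnp.exp(-logsumexp(2.0 * log_w))) / log_weights.shape[0]


def next_lambda(log_likelihoods, lam, target_ess=0.5, tol=1e-6):
    """Bisect for the next inverse temperature in (lam, 1].

    Hypothetical helper: incremental weights are exp((lam_next - lam) * log L_i).
    Returns 1.0 immediately if jumping to the posterior keeps ESS >= target.
    """
    def ess_at(lam_next):
        return ess_fraction((lam_next - lam) * log_likelihoods)

    if ess_at(1.0) >= target_ess:
        return 1.0
    # Invariant: ess_at(lo) >= target_ess and ess_at(hi) < target_ess.
    lo, hi = lam, 1.0
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if ess_at(mid) >= target_ess:
            lo = mid
        else:
            hi = mid
    return lo
```

In a full SMC iteration, the particles would then be reweighted with these incremental weights, resampled, and moved with a mutation kernel such as a random-walk update; the loop ends once the chosen λ equals 1.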

Attributes

config

final_state

metadata

likelihood

prior

sample_transforms

likelihood_transforms

parameter_names

sampler
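The `add_name`, `posterior`, and `posterior_from_dict` methods and the `parameter_names` attribute fit together roughly as in the following sketch. The parameter names, prior, and likelihood are hypothetical. The sketch also omits the `sample_transforms` and `likelihood_transforms` that the real sampler holds as attributes; it only illustrates mapping a flat array to named parameters and then summing the log prior and log likelihood.

```python
import jax.numpy as jnp

# Hypothetical parameter names, standing in for the sampler's parameter_names.
PARAMETER_NAMES = ["x0", "x1"]


def add_name(x, parameter_names=PARAMETER_NAMES):
    """Map a flat parameter array onto a dict keyed by parameter name."""
    return dict(zip(parameter_names, x))


def log_prior(named):
    # Hypothetical prior: independent standard normals (up to a constant).
    return sum(-0.5 * named[k] ** 2 for k in named)


def log_likelihood(named, data):
    # Hypothetical Gaussian likelihood on "x1" (up to a constant).
    return -0.5 * jnp.sum((named["x1"] - data) ** 2)


def posterior_from_dict(named, data):
    """Log posterior (unnormalized) from named parameters."""
    return log_prior(named) + log_likelihood(named, data)


def posterior(x, data):
    """Log posterior (unnormalized) from a flat array, via add_name."""
    return posterior_from_dict(add_name(x), data)
```

Keeping a flat-array entry point alongside a dict-based one is convenient because samplers operate on arrays of particles, while priors and likelihoods are easier to write against named parameters.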