jesterTOV.inference.samplers.blackjax.smc.nuts.BlackJAXSMCNUTSSampler

class BlackJAXSMCNUTSSampler(likelihood, prior, sample_transforms, likelihood_transforms, config, seed=0)

Bases: BlackjaxSMCSampler

Sequential Monte Carlo (SMC) sampler with a No-U-Turn Sampler (NUTS) kernel and Hessian-based mass matrix adaptation.

WARNING: This sampler is EXPERIMENTAL. Use with caution and validate results carefully. The NUTS kernel with Hessian adaptation has not been thoroughly tested.
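The summary above names Hessian-based mass matrix adaptation but does not spell out the rule. One common reading, assumed here and not taken from jesterTOV, is a Laplace-style estimate: take the Hessian H of the log density at a reference point and use (-H)^-1 as the NUTS inverse mass matrix, so that the momentum scale matches the local posterior covariance. The dependency-free sketch below uses a central finite-difference Hessian where a JAX implementation would use automatic differentiation (for example `jax.hessian`). The 2-D Gaussian target and all function names are hypothetical; for a Gaussian, (-H)^-1 should recover the covariance up to finite-difference error.

```python
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]


def _shifted(f: Callable[[Vector], float], x: Vector, i: int, di: float, j: int, dj: float) -> float:
    # Evaluate f after nudging coordinate i by di and coordinate j by dj.
    y = list(x)
    y[i] += di
    y[j] += dj
    return f(y)


def finite_difference_hessian(f: Callable[[Vector], float], x: Vector, h: float = 1e-4) -> Matrix:
    """Central finite-difference Hessian of the scalar function f at x."""
    n = len(x)
    hess = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            hess[i][j] = (
                _shifted(f, x, i, h, j, h)
                - _shifted(f, x, i, h, j, -h)
                - _shifted(f, x, i, -h, j, h)
                + _shifted(f, x, i, -h, j, -h)
            ) / (4.0 * h * h)
    return hess


def invert(m: Matrix) -> Matrix:
    """Gauss-Jordan inversion with partial pivoting."""
    n = len(m)
    aug = [list(row) + [1.0 if i == j else 0.0 for j in range(n)] for i, row in enumerate(m)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        p = aug[col][col]
        aug[col] = [v / p for v in aug[col]]
        for r in range(n):
            if r != col:
                factor = aug[r][col]
                aug[r] = [a - factor * b for a, b in zip(aug[r], aug[col])]
    return [row[n:] for row in aug]


def hessian_inverse_mass_matrix(log_density: Callable[[Vector], float], x: Vector) -> Matrix:
    # Assumed rule: inverse mass matrix = (-H)^-1, with H the Hessian of log_density at x.
    hess = finite_difference_hessian(log_density, x)
    return invert([[-v for v in row] for row in hess])


# Hypothetical 2-D Gaussian target, chosen so the expected answer is known.
COV = [[2.0, 0.6], [0.6, 1.0]]
PRECISION = invert(COV)
MEAN = [1.0, -0.5]


def gaussian_log_density(x: Vector) -> float:
    d = [xi - mi for xi, mi in zip(x, MEAN)]
    return -0.5 * sum(d[i] * PRECISION[i][j] * d[j] for i in range(2) for j in range(2))


inv_mass = hessian_inverse_mass_matrix(gaussian_log_density, MEAN)
# inv_mass approximates COV, up to finite-difference error.
```

Which reference point the sampler actually uses (the mode, the weighted particle mean at each tempering stage, or something else) is not documented on this page; the sketch simply evaluates at the mode of the toy target.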

__init__(likelihood, prior, sample_transforms, likelihood_transforms, config, seed=0)

Initialize the sampler and issue a warning that it is EXPERIMENTAL.
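The `__init__` entry says the sampler is initialized with an EXPERIMENTAL warning, but this page does not say how that warning is emitted (it could go through Python's `warnings` module, a logger, or a print). The sketch below assumes the `warnings` module and shows the general pattern of a subclass that warns on construction and then defers to its base class, plus how a caller or test suite could capture the warning. Both classes are hypothetical stand-ins, not jesterTOV classes.

```python
import warnings


class BaseSamplerStub:
    """Hypothetical stand-in for the BlackjaxSMCSampler base class."""

    def __init__(self, config, seed=0):
        self.config = config
        self.seed = seed


class ExperimentalSamplerStub(BaseSamplerStub):
    """Hypothetical subclass that warns on construction, then defers to the base."""

    def __init__(self, config, seed=0):
        warnings.warn(
            "This sampler is EXPERIMENTAL; validate results carefully.",
            UserWarning,
            stacklevel=2,
        )
        super().__init__(config, seed=seed)


# Record the warning instead of printing it, e.g. inside a test suite.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    sampler = ExperimentalSamplerStub(config={}, seed=1)
```

If the real sampler logs instead of warning, the capture step would need a logging handler rather than `warnings.catch_warnings`.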

Methods

__init__(likelihood, prior, ...[, seed])

Initialize the sampler and issue a warning that it is EXPERIMENTAL.

add_name(x)

Turn a flat array into a dictionary keyed by parameter name.

get_log_prob()

Get the log posterior probabilities from the SMC run.

get_n_samples()

Get the number of SMC particles.

get_sampler_output()

Get standardized sampler output.

get_samples()

Return final particle positions.

plot_diagnostics([outdir, filename])

Generate diagnostic plots for the SMC sampling run.

posterior(params, data)

Evaluate the log posterior probability from a flat parameter array.

posterior_from_dict(named_params, data)

Evaluate the log posterior probability from a dictionary of named parameters.

print_summary([transform])

Print a summary of the sampling run.

sample(key)

Run SMC until the tempering parameter λ reaches 1 (the full posterior).
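The `sample` entry describes likelihood tempering: particles move through intermediate targets proportional to prior × likelihood^λ until λ = 1. The pure-Python sketch below illustrates such a loop under assumptions that are not taken from jesterTOV: the next λ is found by bisection so the effective sample size (ESS) stays at half the particle count (adaptive tempering), resampling is multinomial, and the mutation step is random-walk Metropolis, swapped in for the NUTS kernel this class uses so the example stays dependency-free. The toy model (Gaussian prior and Gaussian likelihood with one observation) is hypothetical and has an analytic posterior to check against.

```python
import math
import random
from typing import Callable, List, Tuple


def ess(log_weights: List[float]) -> float:
    """Effective sample size (sum w)^2 / sum w^2 of unnormalised log weights."""
    m = max(log_weights)
    w = [math.exp(lw - m) for lw in log_weights]
    return sum(w) ** 2 / sum(v * v for v in w)


def next_lambda(lam: float, loglik: List[float], target_ess: float) -> float:
    """Bisect for the largest next temperature that keeps ESS >= target_ess."""
    if ess([(1.0 - lam) * ll for ll in loglik]) >= target_ess:
        return 1.0
    lo, hi = lam, 1.0
    for _ in range(40):
        mid = 0.5 * (lo + hi)
        if ess([(mid - lam) * ll for ll in loglik]) >= target_ess:
            lo = mid
        else:
            hi = mid
    return lo


def tempered_smc(
    log_prior: Callable[[float], float],
    log_lik: Callable[[float], float],
    sample_prior: Callable[[random.Random], float],
    n_particles: int = 2000,
    target_frac: float = 0.5,
    n_mcmc: int = 10,
    step: float = 0.5,
    seed: int = 0,
) -> Tuple[List[float], List[float]]:
    rng = random.Random(seed)
    particles = [sample_prior(rng) for _ in range(n_particles)]
    lam, schedule = 0.0, [0.0]
    while lam < 1.0:
        loglik = [log_lik(x) for x in particles]
        new_lam = next_lambda(lam, loglik, target_frac * n_particles)
        # Reweight by likelihood^(new_lam - lam), then resample multinomially.
        log_w = [(new_lam - lam) * ll for ll in loglik]
        m = max(log_w)
        weights = [math.exp(lw - m) for lw in log_w]
        particles = rng.choices(particles, weights=weights, k=n_particles)
        lam = new_lam
        schedule.append(lam)
        # Mutate: random-walk Metropolis on prior * likelihood^lam (stand-in for NUTS).
        moved = []
        for x in particles:
            lp = log_prior(x) + lam * log_lik(x)
            for _ in range(n_mcmc):
                prop = x + step * rng.gauss(0.0, 1.0)
                lp_prop = log_prior(prop) + lam * log_lik(prop)
                if rng.random() < math.exp(min(0.0, lp_prop - lp)):
                    x, lp = prop, lp_prop
            moved.append(x)
        particles = moved
    return particles, schedule


PRIOR_SD, Y_OBS = 3.0, 2.0  # hypothetical toy model

particles, schedule = tempered_smc(
    log_prior=lambda x: -0.5 * (x / PRIOR_SD) ** 2,
    log_lik=lambda x: -0.5 * (Y_OBS - x) ** 2,
    sample_prior=lambda rng: rng.gauss(0.0, PRIOR_SD),
)
mean = sum(particles) / len(particles)
var = sum((x - mean) ** 2 for x in particles) / len(particles)
# Analytic posterior for this toy model: mean 1.8, variance 0.9.
```

Choosing each λ adaptively from the ESS keeps the particle weights from degenerating between stages; a fixed λ schedule would also reach λ = 1 but needs hand tuning.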

Attributes

config

final_state

metadata

likelihood

prior

sample_transforms

likelihood_transforms

parameter_names

sampler
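The flat-array and dictionary entry points (`posterior`, `posterior_from_dict`, `add_name`) together with the `parameter_names` attribute suggest a simple relationship: `add_name` pairs each array entry with its name in order, and `posterior` can then defer to `posterior_from_dict`. The sketch below assumes exactly that. The class, its toy standard-normal prior and unit-noise Gaussian likelihood, and the parameter names are hypothetical; the real sampler may additionally apply `sample_transforms`, `likelihood_transforms`, and Jacobian terms, which are omitted here.

```python
from typing import Dict, List


class PosteriorSketch:
    """Hypothetical stand-in showing how the flat-array and dict APIs could relate."""

    def __init__(self, parameter_names: List[str]):
        self.parameter_names = parameter_names

    def add_name(self, x: List[float]) -> Dict[str, float]:
        # Pair each entry of the flat array with its parameter name, in order.
        return dict(zip(self.parameter_names, x))

    def log_prior(self, named: Dict[str, float]) -> float:
        # Hypothetical standard-normal prior on every parameter.
        return -0.5 * sum(v * v for v in named.values())

    def log_likelihood(self, named: Dict[str, float], data: Dict[str, float]) -> float:
        # Hypothetical Gaussian likelihood: each parameter observed once with unit noise.
        return -0.5 * sum((data[k] - v) ** 2 for k, v in named.items())

    def posterior_from_dict(self, named_params: Dict[str, float], data: Dict[str, float]) -> float:
        return self.log_prior(named_params) + self.log_likelihood(named_params, data)

    def posterior(self, params: List[float], data: Dict[str, float]) -> float:
        # Flat array -> named dict, then reuse the dict-based evaluation.
        return self.posterior_from_dict(self.add_name(params), data)


sketch = PosteriorSketch(["theta_0", "theta_1"])
data = {"theta_0": 0.0, "theta_1": 0.0}
flat_value = sketch.posterior([0.5, -1.0], data)
dict_value = sketch.posterior_from_dict({"theta_0": 0.5, "theta_1": -1.0}, data)
```

Keeping one dictionary-based implementation and a thin flat-array wrapper means the name-to-index mapping lives in a single place, which matters when a sampler internally works with flat arrays.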