jesterTOV.inference.samplers.blackjax.smc.nuts.BlackJAXSMCNUTSSampler
- class BlackJAXSMCNUTSSampler(likelihood, prior, sample_transforms, likelihood_transforms, config, seed=0)
Bases: BlackjaxSMCSampler
Sequential Monte Carlo (SMC) sampler with a NUTS kernel and Hessian-based mass matrix adaptation.
WARNING: This sampler is EXPERIMENTAL. Use with caution and validate results carefully. The NUTS kernel with Hessian adaptation has not been thoroughly tested.
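The class summary does not say how the Hessian is turned into a mass matrix. One common recipe, sketched below under that assumption, is a diagonal approximation: estimate the diagonal of the Hessian of the log posterior at a reference point and use the reciprocal of its negative as the inverse mass matrix, so each direction is scaled by its local variance. The helpers `hessian_diagonal` and `inverse_mass_from_hessian` are hypothetical and not part of jesterTOV; the actual sampler may use the full Hessian, automatic differentiation (for example `jax.hessian`), or a different safeguard for regions where the log density is not concave.

```python
from typing import Callable, List, Sequence


def hessian_diagonal(
    logp: Callable[[List[float]], float], x: Sequence[float], h: float = 1e-4
) -> List[float]:
    """Central finite-difference estimate of d^2 logp / dx_i^2 at x."""
    f0 = logp(list(x))
    diag = []
    for i in range(len(x)):
        xp = list(x)
        xm = list(x)
        xp[i] += h
        xm[i] -= h
        diag.append((logp(xp) - 2.0 * f0 + logp(xm)) / (h * h))
    return diag


def inverse_mass_from_hessian(
    logp: Callable[[List[float]], float], x: Sequence[float], floor: float = 1e-8
) -> List[float]:
    """Hypothetical helper: diagonal inverse mass matrix with entries 1 / (-H_ii).

    Where the log density is not locally concave (-H_ii <= floor), fall back
    to 1.0 (an identity entry) instead of dividing by a tiny or negative number.
    """
    inv_mass = []
    for d in hessian_diagonal(logp, x):
        curvature = -d
        inv_mass.append(1.0 / curvature if curvature > floor else 1.0)
    return inv_mass


# Toy log posterior: independent Gaussians with standard deviations 1.0 and 0.1.
sigmas = [1.0, 0.1]


def toy_logp(theta: List[float]) -> float:
    return -0.5 * sum((t / s) ** 2 for t, s in zip(theta, sigmas))


inv_mass = inverse_mass_from_hessian(toy_logp, [0.3, -0.2])
print(inv_mass)
```

For this toy Gaussian the result should be close to the true variances [1.0, 0.01], which is the property that makes a Hessian-based mass matrix useful for badly scaled posteriors.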
- Parameters:
likelihood (LikelihoodBase) – Likelihood object
prior (Prior) – Prior object
sample_transforms (list[BijectiveTransform]) – Sample transforms (typically empty for SMC)
likelihood_transforms (list[NtoMTransform]) – Likelihood transforms
config (SMCNUTSSamplerConfig) – NUTS SMC configuration
seed (int, optional) – Random seed (default: 0)
- __init__(likelihood, prior, sample_transforms, likelihood_transforms, config, seed=0)
Initialize with EXPERIMENTAL warning.
Methods
__init__(likelihood, prior, ...[, seed]) – Initialize with EXPERIMENTAL warning.
add_name(x) – Turn an array into a dictionary.
get_log_prob() – Get the log posterior probabilities from SMC.
get_n_samples() – Get the number of particles from SMC.
get_sampler_output() – Get the standardized sampler output.
get_samples() – Return the final particle positions.
plot_diagnostics([outdir, filename]) – Generate diagnostic plots for the SMC sampling run.
posterior(params, data) – Evaluate the posterior log probability from a flat array.
posterior_from_dict(named_params, data) – Evaluate the posterior log probability from a parameter dict.
print_summary([transform]) – Print a summary of the sampling run.
sample(key) – Run SMC until λ = 1 (the posterior).
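`sample(key)` is summarized as running SMC until λ = 1. In tempered SMC the particles target a sequence of densities log π_λ(θ) = log p(θ) + λ log L(θ), moving from the prior (λ = 0) to the posterior (λ = 1). The sketch below assumes an adaptive schedule, a common choice in which each new λ is found by bisection so that the effective sample size (ESS) of the incremental weights stays above a target fraction of the particle count. The function names and the `target_fraction` default are hypothetical; the source does not state whether this sampler uses adaptive or fixed tempering steps, and resampling and the NUTS mutation step are omitted.

```python
import math
from typing import Sequence


def tempered_log_density(log_prior: float, log_likelihood: float, lam: float) -> float:
    """Log of the intermediate target: prior times likelihood raised to the power lam."""
    return log_prior + lam * log_likelihood


def ess(log_weights: Sequence[float]) -> float:
    """Effective sample size (sum w)^2 / sum w^2, computed stably in log space."""
    m = max(log_weights)
    w = [math.exp(lw - m) for lw in log_weights]
    return sum(w) ** 2 / sum(wi * wi for wi in w)


def next_lambda(
    log_likelihoods: Sequence[float],
    lam: float,
    target_fraction: float = 0.5,
    tol: float = 1e-10,
) -> float:
    """Approximately the largest new lambda in (lam, 1] that keeps ESS >= target."""
    target = target_fraction * len(log_likelihoods)

    def ess_at(new_lam: float) -> float:
        # Moving from lam to new_lam reweights particle i by L_i ** (new_lam - lam).
        return ess([(new_lam - lam) * ll for ll in log_likelihoods])

    if ess_at(1.0) >= target:
        return 1.0  # the particles can jump straight to the posterior
    lo, hi = lam, 1.0  # invariant: ESS(lo) >= target, ESS(hi) < target
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if ess_at(mid) >= target:
            lo = mid
        else:
            hi = mid
    return lo


# Illustrative log-likelihoods of four particles at lambda = 0.
lls = [0.0, -10.0, -20.0, -30.0]
lam1 = next_lambda(lls, 0.0)
print(lam1, ess([lam1 * ll for ll in lls]))
```

Choosing λ from the ESS rather than from a fixed grid takes small steps where the likelihood is sharply peaked relative to the current particles and large steps elsewhere; in a full sampler each step would be followed by resampling and MCMC moves before the next λ is chosen.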
Attributes
config
final_state
metadata
likelihood
prior
sample_transforms
likelihood_transforms
parameter_names
sampler
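The method summaries describe two ways to evaluate the same posterior: `posterior(params, data)` takes a flat array and `posterior_from_dict(named_params, data)` takes a parameter dict, while `add_name(x)` turns an array into a dictionary. A minimal sketch of how these could fit together follows, assuming `add_name` pairs array entries with `parameter_names` in order. The stand-alone functions, the parameter names `alpha` and `beta`, and the toy log posterior are hypothetical illustrations, not jesterTOV code.

```python
from typing import Any, Callable, Dict, Sequence


def add_name(x: Sequence[float], parameter_names: Sequence[str]) -> Dict[str, float]:
    """Pair each entry of a flat array with its parameter name, in order."""
    if len(x) != len(parameter_names):
        raise ValueError(f"expected {len(parameter_names)} values, got {len(x)}")
    return dict(zip(parameter_names, x))


def posterior_from_flat(
    x: Sequence[float],
    data: Any,
    parameter_names: Sequence[str],
    posterior_from_dict: Callable[[Dict[str, float], Any], float],
) -> float:
    """Evaluate a dict-based log posterior on a flat array by naming it first."""
    return posterior_from_dict(add_name(x, parameter_names), data)


# Toy dict-based log posterior (unnormalized standard normal); data is unused.
def toy_posterior_from_dict(named: Dict[str, float], data: Any) -> float:
    return -0.5 * (named["alpha"] ** 2 + named["beta"] ** 2)


names = ["alpha", "beta"]
print(add_name([1.0, 2.0], names))  # {'alpha': 1.0, 'beta': 2.0}
print(posterior_from_flat([1.0, 2.0], None, names, toy_posterior_from_dict))  # -2.5
```

Keeping the flat-array form for the sampler and the named form for the likelihood lets the sampling kernel work on plain arrays while model code refers to parameters by name.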