jesterTOV.inference.samplers.blackjax.base.BlackjaxSampler
- class BlackjaxSampler(likelihood, prior, sample_transforms=None, likelihood_transforms=None)
Bases: JesterSampler
Base class for BlackJAX samplers with shared transform logic.
This class provides common functionality for all BlackJAX-based samplers:
- Creating dict-based log prior functions (applying the inverse sample transforms and adding the log-Jacobian)
- Creating dict-based log likelihood functions (applying the inverse sample transforms, then the likelihood transforms)
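The sketch below illustrates the transform logic described above in plain Python. It is not the jesterTOV implementation. A single hypothetical log transform (the sampler works with `z = log(x)`) stands in for a `BijectiveTransform`, and a hypothetical derived quantity stands in for an N-to-M likelihood transform. The toy prior, likelihood, and function names are assumptions made for illustration. The key point is that the log prior in sampling space is the prior evaluated at the inverse-transformed point plus the log of the Jacobian of the inverse map.

```python
import math

# Hypothetical sample transform: the sampler works with z = log(x),
# so the inverse map back to prior space is x = exp(z).
def inverse_sample_transform(named_z):
    x = math.exp(named_z["log_x"])
    log_jac = named_z["log_x"]  # log|dx/dz| = log(exp(z)) = z
    return {"x": x}, log_jac

# Hypothetical prior in the original space: Exponential(rate=1) on x > 0.
def log_prior_x(named_x):
    x = named_x["x"]
    return -x if x > 0 else -math.inf

# Hypothetical N-to-M likelihood transform: add a derived quantity.
def likelihood_transform(named_x):
    return {**named_x, "x_squared": named_x["x"] ** 2}

# Hypothetical likelihood: unit-variance Gaussian in x_squared around `data`.
def log_likelihood(named_params, data):
    return -0.5 * (named_params["x_squared"] - data) ** 2

def log_prior_dict(named_z):
    """Log prior in sampling space: log p(inverse(z)) + log|Jacobian|."""
    named_x, log_jac = inverse_sample_transform(named_z)
    return log_prior_x(named_x) + log_jac

def log_likelihood_dict(named_z, data):
    """Apply the inverse sample transform, then the likelihood transform."""
    named_x, _ = inverse_sample_transform(named_z)
    return log_likelihood(likelihood_transform(named_x), data)
```

Dropping the Jacobian term would make `exp(log_prior_dict(...))` fail to integrate to one over `z`. With the term included, it is exactly the density of `log(x)` for `x ~ Exponential(1)`.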
Different BlackJAX algorithms have different API requirements:
- SMC requires flat-array functions, so its subclass wraps these dict functions.
- NS-AW requires dict functions, so its subclass uses them directly.
This design maximizes code reuse while respecting each algorithm’s API.
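For algorithms such as SMC that need flat-array functions, a dict-based function can be adapted by mapping array positions to parameter names. This is the role that the `add_name(x)` method in the Methods list suggests. The following is a minimal sketch under assumptions: `PARAMETER_NAMES`, `to_named_dict`, `make_flat`, and the toy log density are hypothetical names and do not belong to the jesterTOV or BlackJAX API.

```python
# Hypothetical parameter ordering, playing the role of `parameter_names`.
PARAMETER_NAMES = ["mass", "radius"]

def to_named_dict(x):
    """Turn a flat sequence into a dict keyed by PARAMETER_NAMES (order matters)."""
    if len(x) != len(PARAMETER_NAMES):
        raise ValueError(f"expected {len(PARAMETER_NAMES)} values, got {len(x)}")
    return dict(zip(PARAMETER_NAMES, x))

# Hypothetical dict-based log density (the form a dict-consuming sampler uses directly).
def log_density_dict(named):
    return -0.5 * (named["mass"] ** 2 + named["radius"] ** 2)

def make_flat(dict_fn):
    """Wrap a dict-based function so that it accepts a flat array instead."""
    def flat_fn(x):
        return dict_fn(to_named_dict(x))
    return flat_fn

flat_log_density = make_flat(log_density_dict)
```

Because the wrapper only reorders inputs, both samplers evaluate the same underlying dict functions, which is where the code reuse comes from.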
- Parameters:
likelihood (LikelihoodBase) – Likelihood object exposing an evaluate(params, data) method
prior (Prior) – Prior object
sample_transforms (list[BijectiveTransform] | None, optional) – Bijective transforms applied during sampling (with Jacobians)
likelihood_transforms (list[NtoMTransform] | None, optional) – N-to-M transforms applied before likelihood evaluation
Notes
Subclasses must implement:
- sample(): Run the sampling algorithm
- get_samples(): Return samples in dict format
- get_log_prob(): Return log probabilities
- get_n_samples(): Return number of samples
- get_sampler_output(): Return standardized SamplerOutput
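The sketch below shows the subclass contract from the Notes using Python's `abc` module. `SamplerContract` is a hypothetical stand-in for the real base class and does not reproduce it. `ToySampler` "samples" a fixed grid instead of running a BlackJAX kernel, and its `get_sampler_output` returns a plain dict in place of the real `SamplerOutput` type. The sketch only shows which methods a subclass has to provide.

```python
from abc import ABC, abstractmethod

class SamplerContract(ABC):
    """Hypothetical stand-in for the abstract interface listed in the Notes."""

    @abstractmethod
    def sample(self, key): ...

    @abstractmethod
    def get_samples(self): ...

    @abstractmethod
    def get_log_prob(self): ...

    @abstractmethod
    def get_n_samples(self): ...

    @abstractmethod
    def get_sampler_output(self): ...


class ToySampler(SamplerContract):
    """Toy subclass that fills in every required method with trivial logic."""

    def __init__(self, n):
        self.n = n
        self._samples = None
        self._log_prob = None

    def sample(self, key):
        # A real subclass would run a BlackJAX algorithm here; `key` is unused.
        xs = [i / max(self.n - 1, 1) for i in range(self.n)]
        self._samples = {"x": xs}
        self._log_prob = [-0.5 * x * x for x in xs]

    def get_samples(self):
        return self._samples

    def get_log_prob(self):
        return self._log_prob

    def get_n_samples(self):
        return len(self._log_prob)

    def get_sampler_output(self):
        # Plain dict standing in for the real SamplerOutput type (assumption).
        return {
            "samples": self.get_samples(),
            "log_prob": self.get_log_prob(),
            "n_samples": self.get_n_samples(),
        }
```

Declaring the methods abstract means that a subclass missing any of them fails at instantiation time, not partway through a sampling run.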
- __init__(likelihood, prior, sample_transforms=None, likelihood_transforms=None)
Initialize BlackJAX sampler base class.
Methods
- __init__(likelihood, prior[, ...]) – Initialize BlackJAX sampler base class.
- add_name(x) – Turn an array into a dictionary.
- get_log_prob() – Get log probabilities for production samples.
- get_n_samples() – Get number of production/final samples.
- get_sampler_output() – Get standardized sampler output with samples, log probabilities, and metadata.
- get_samples() – Get production samples from the sampler.
- posterior(params, data) – Evaluate posterior log probability from flat array.
- posterior_from_dict(named_params, data) – Evaluate posterior log probability from parameter dict.
- print_summary([transform]) – Print summary of sampling run.
- sample(key) – Run sampling.
Attributes
- likelihood
- prior
- sample_transforms
- likelihood_transforms
- parameter_names
- sampler