jesterTOV.inference.samplers.blackjax.base.BlackjaxSampler


class BlackjaxSampler(likelihood, prior, sample_transforms=None, likelihood_transforms=None)

Bases: JesterSampler

Base class for BlackJAX samplers with shared transform logic.

This class provides common functionality for all BlackJAX-based samplers:
  • Creating dict-based log prior functions (applying the inverse sample transforms and adding their log-Jacobian)
  • Creating dict-based log likelihood functions (applying the inverse sample transforms, then the likelihood transforms)
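The two dict-based functions described above can be sketched as follows. This is a minimal illustration, not the jesterTOV implementation: it uses plain Python floats instead of JAX arrays, and LogTransform, make_log_prior, make_log_likelihood, the inverse() method returning a log-Jacobian, and the toy prior/likelihood are all hypothetical names. The real BijectiveTransform and NtoMTransform classes may expose different methods. The point is the bookkeeping: the prior picks up the log-Jacobian of each inverse transform, while the likelihood only needs the parameters mapped back to prior space and passed through the likelihood transforms.

```python
import math


class LogTransform:
    """Hypothetical bijection z = log(x) on one named parameter."""

    def __init__(self, name):
        self.name = name

    def forward(self, params):
        # Prior space -> sampling space: x -> z = log(x)
        out = dict(params)
        out[self.name] = math.log(params[self.name])
        return out

    def inverse(self, params):
        # Sampling space -> prior space: z -> x = exp(z).
        # Also returns log|dx/dz|, which for exp is simply z.
        out = dict(params)
        z = params[self.name]
        out[self.name] = math.exp(z)
        return out, z


def make_log_prior(prior_log_prob, sample_transforms):
    """Dict-based log prior evaluated in sampling space."""
    def log_prior(named_params):
        params, log_jac = dict(named_params), 0.0
        # Undo the transforms in reverse order, accumulating log-Jacobians.
        for t in reversed(sample_transforms):
            params, lj = t.inverse(params)
            log_jac += lj
        return prior_log_prob(params) + log_jac
    return log_prior


def make_log_likelihood(likelihood, sample_transforms, likelihood_transforms):
    """Dict-based log likelihood evaluated in sampling space."""
    def log_likelihood(named_params, data):
        params = dict(named_params)
        # Map back to prior space; the likelihood needs no Jacobian term.
        for t in reversed(sample_transforms):
            params, _ = t.inverse(params)
        # N-to-M transforms derive the quantities the likelihood consumes.
        for f in likelihood_transforms:
            params = f(params)
        return likelihood.evaluate(params, data)
    return log_likelihood


# --- Toy usage (all names hypothetical) ---
def exp_prior_log_prob(p):
    # Exponential(1) prior on "mass" (support mass > 0)
    return -p["mass"] if p["mass"] > 0 else -math.inf


def add_double_mass(p):
    # Toy N-to-M transform: adds a derived parameter
    out = dict(p)
    out["double_mass"] = 2.0 * p["mass"]
    return out


class GaussianLikelihood:
    def evaluate(self, params, data):
        return -0.5 * (params["double_mass"] - data) ** 2


transforms = [LogTransform("mass")]
log_prior = make_log_prior(exp_prior_log_prob, transforms)
log_like = make_log_likelihood(GaussianLikelihood(), transforms, [add_double_mass])
print(log_prior({"mass": 0.0}))      # z = 0 -> mass = 1: -1.0
print(log_like({"mass": 0.0}, 3.0))  # double_mass = 2: -0.5
```

Undoing the sample transforms in reverse order assumes they were applied in list order on the way into sampling space.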

Different BlackJAX algorithms have different API requirements:
  • SMC requires functions of flat arrays, so its subclass wraps these dict-based functions.
  • NS-AW requires dict-based functions, so its subclass uses them directly.
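One way to picture the wrapping step for an array-based algorithm such as SMC is a small adapter that zips a flat vector with the parameter names before calling the dict-based function. This is a sketch under assumptions: to_flat_fn and the standard-normal log_prior_dict are hypothetical, the code uses Python lists rather than JAX arrays, and the actual subclass may flatten parameters differently (for example with JAX pytree utilities).

```python
import math


def to_flat_fn(dict_fn, parameter_names):
    """Hypothetical adapter: make a dict-based log density accept a flat
    sequence of values ordered like parameter_names."""
    def flat_fn(x):
        if len(x) != len(parameter_names):
            raise ValueError(f"expected {len(parameter_names)} values, got {len(x)}")
        return dict_fn(dict(zip(parameter_names, x)))
    return flat_fn


# Hypothetical dict-based log prior: independent standard normals.
def log_prior_dict(p):
    return sum(-0.5 * v * v - 0.5 * math.log(2 * math.pi) for v in p.values())


names = ["mass", "radius"]
log_prior_flat = to_flat_fn(log_prior_dict, names)

# A dict-native caller (NS-AW style) and a flat-array caller (SMC style)
# get the same value for the same point.
print(log_prior_flat([0.3, -1.2]) == log_prior_dict({"mass": 0.3, "radius": -1.2}))  # True
```

Keeping the dict-based function as the single source of truth means the transform logic is written once, and each algorithm only supplies a thin adapter.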

This design maximizes code reuse while respecting each algorithm’s API.

Parameters:
  • likelihood (LikelihoodBase) – Likelihood object with evaluate(params, data) method

  • prior (Prior) – Prior object

  • sample_transforms (list[BijectiveTransform] | None, optional) – Bijective transforms applied during sampling (with Jacobians)

  • likelihood_transforms (list[NtoMTransform] | None, optional) – N-to-M transforms applied before likelihood evaluation

Notes

Subclasses must implement:
  • sample(): Run the sampling algorithm
  • get_samples(): Return samples in dict format
  • get_log_prob(): Return log probabilities
  • get_n_samples(): Return number of samples
  • get_sampler_output(): Return standardized SamplerOutput
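To make this contract concrete, here is a toy subclass written against a stand-in abstract base. Everything here is an assumption for illustration: BlackjaxSamplerSketch is not the real base class (which derives from JesterSampler), the SamplerOutput dataclass fields are guessed, and an integer seed replaces the JAX PRNG key that sample(key) would normally receive. PriorDrawSampler just draws standard-normal values instead of running a BlackJAX algorithm.

```python
import random
from abc import ABC, abstractmethod
from dataclasses import dataclass, field


@dataclass
class SamplerOutput:
    # Hypothetical stand-in; the library's SamplerOutput may differ.
    samples: dict
    log_prob: list
    metadata: dict = field(default_factory=dict)


class BlackjaxSamplerSketch(ABC):
    # Stand-in for the abstract methods listed in the Notes.
    @abstractmethod
    def sample(self, key): ...

    @abstractmethod
    def get_samples(self): ...

    @abstractmethod
    def get_log_prob(self): ...

    @abstractmethod
    def get_n_samples(self): ...

    @abstractmethod
    def get_sampler_output(self): ...


class PriorDrawSampler(BlackjaxSamplerSketch):
    """Toy subclass: i.i.d. standard-normal draws per parameter."""

    def __init__(self, parameter_names, n_draws):
        self.parameter_names = list(parameter_names)
        self.n_draws = n_draws
        self._samples, self._log_prob = None, None

    def sample(self, key):
        rng = random.Random(key)  # integer seed in place of a JAX PRNG key
        self._samples = {
            n: [rng.gauss(0.0, 1.0) for _ in range(self.n_draws)]
            for n in self.parameter_names
        }
        # Unnormalized log density of each draw
        self._log_prob = [
            sum(-0.5 * self._samples[n][i] ** 2 for n in self.parameter_names)
            for i in range(self.n_draws)
        ]

    def get_samples(self):
        return self._samples

    def get_log_prob(self):
        return self._log_prob

    def get_n_samples(self):
        return self.n_draws

    def get_sampler_output(self):
        return SamplerOutput(self.get_samples(), self.get_log_prob(),
                             {"sampler": type(self).__name__})


sampler = PriorDrawSampler(["mass", "radius"], n_draws=5)
sampler.sample(0)
output = sampler.get_sampler_output()
print(sampler.get_n_samples(), sorted(output.samples))  # 5 ['mass', 'radius']
```

Because the base declares these methods abstract, a subclass that forgets one of them fails at instantiation rather than midway through a run.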

__init__(likelihood, prior, sample_transforms=None, likelihood_transforms=None)

Initialize BlackJAX sampler base class.

Methods

__init__(likelihood, prior[, ...])

Initialize BlackJAX sampler base class.

add_name(x)

Turn an array into a dictionary.

get_log_prob()

Get log probabilities for production samples.

get_n_samples()

Get number of production/final samples.

get_sampler_output()

Get standardized sampler output with samples, log probabilities, and metadata.

get_samples()

Get production samples from the sampler.

posterior(params, data)

Evaluate posterior log probability from flat array.

posterior_from_dict(named_params, data)

Evaluate posterior log probability from parameter dict.

print_summary([transform])

Print summary of sampling run.

sample(key)

Run sampling.
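The relationship between add_name, posterior_from_dict, and posterior suggested by the method summaries can be sketched as below: the unnormalized log posterior is the log prior plus the log likelihood, and the flat-array entry point converts to a dict first. This is a hedged illustration only; make_posterior, the uniform prior, and the Gaussian likelihood are hypothetical, and the real methods work on JAX arrays and the transformed functions built by this class.

```python
import math


def make_posterior(parameter_names, log_prior, log_likelihood):
    """Hypothetical builder mirroring add_name / posterior_from_dict / posterior."""
    def add_name(x):
        # Turn a flat array into a dict keyed by parameter name.
        return dict(zip(parameter_names, x))

    def posterior_from_dict(named_params, data):
        # Unnormalized log posterior = log prior + log likelihood.
        return log_prior(named_params) + log_likelihood(named_params, data)

    def posterior(params, data):
        # Flat-array entry point: name the values, then reuse the dict version.
        return posterior_from_dict(add_name(params), data)

    return add_name, posterior_from_dict, posterior


# Toy model (hypothetical): uniform prior on [0, 1], Gaussian likelihood.
def log_prior(p):
    return 0.0 if 0.0 <= p["x"] <= 1.0 else -math.inf


def log_likelihood(p, data):
    return -0.5 * (p["x"] - data) ** 2


add_name, posterior_from_dict, posterior = make_posterior(["x"], log_prior, log_likelihood)
print(add_name([0.25]))       # {'x': 0.25}
print(posterior([0.5], 1.5))  # -0.5
```

Points outside the prior support evaluate to -inf, which samplers typically treat as zero posterior probability.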

Attributes

likelihood

prior

sample_transforms

likelihood_transforms

parameter_names

sampler