jesterTOV.inference.samplers.jester_sampler.JesterSampler#
- class JesterSampler(likelihood, prior, sample_transforms=None, likelihood_transforms=None)[source]#
Bases: object
Lightweight base class for JESTER samplers.
This class provides a modular interface for Bayesian inference with different sampling backends (flowMC, Jim, NumPyro, etc.). It handles:
- Parameter transforms (sample and likelihood transforms)
- Posterior evaluation with Jacobian corrections
- Parameter name propagation
- A generic sampling interface
Backend-specific implementations should inherit from this class and:
1. Call super().__init__() to set up common attributes
2. Create self.sampler (the backend sampler instance)
3. Optionally override methods for backend-specific behavior
Critical features:
- Uses jnp.inf instead of jnp.nan for initialization
- Preserves parameter ordering when converting a dict to an array
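The subclass contract above can be sketched in plain Python. SamplerBase below is a simplified stand-in for JesterSampler, and MyBackendSampler is a hypothetical backend; neither is part of the JESTER API.

```python
class SamplerBase:
    # Simplified stand-in for JesterSampler's common setup (illustrative only).
    def __init__(self, likelihood, prior, sample_transforms=None,
                 likelihood_transforms=None):
        self.likelihood = likelihood
        self.prior = prior
        self.sample_transforms = sample_transforms or []
        self.likelihood_transforms = likelihood_transforms or []
        self.parameter_names = []
        self.sampler = None  # backends must replace this


class MyBackendSampler(SamplerBase):
    # Hypothetical backend following the three-step contract above.
    def __init__(self, likelihood, prior, **kwargs):
        # 1. Call super().__init__() to set up common attributes.
        super().__init__(likelihood, prior)
        # 2. Create self.sampler (a placeholder stands in for the real
        #    backend sampler instance here).
        self.sampler = object()

    # 3. Optionally override methods for backend-specific behavior.
    def sample(self, key):
        raise NotImplementedError("backend sampling goes here")


backend = MyBackendSampler(likelihood=None, prior=None)
```

The key point is that self.sampler is guaranteed to exist after construction, which the base class's generic methods rely on.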
- Parameters:
likelihood (LikelihoodBase) – Likelihood object with evaluate(params, data) method
prior (Prior) – Prior object with sample() and log_prob() methods
sample_transforms (list[BijectiveTransform] | None, optional) – Bijective transforms applied during sampling (with Jacobians)
likelihood_transforms (list[NtoMTransform] | None, optional) – N-to-M transforms applied before likelihood evaluation
- Variables:
likelihood (LikelihoodBase) – Likelihood object
prior (Prior) – Prior object
sample_transforms (list[BijectiveTransform]) – Transforms applied during sampling
likelihood_transforms (list[NtoMTransform]) – Transforms applied before likelihood evaluation
parameter_names (list[str]) – Names of parameters (propagated through sample transforms)
sampler (Any | None) – Backend sampler instance (created by subclasses)
Notes
Subclasses must create self.sampler in their __init__ method. The sampler should have a .sample() method and support get_sampler_state().
Methods
- __init__(likelihood, prior[, ...])
- add_name(x): Turn an array into a dictionary.
- get_log_prob(): Get log probabilities for production samples.
- get_n_samples(): Get the number of production/final samples.
- get_sampler_output(): Get standardized sampler output with samples, log probabilities, and metadata.
- get_samples(): Get production samples from the sampler.
- posterior(params, data): Evaluate posterior log probability from a flat array.
- posterior_from_dict(named_params, data): Evaluate posterior log probability from a parameter dict.
- print_summary([transform]): Print a summary of the sampling run.
- sample(key): Run sampling.
Attributes
- likelihood
- likelihood_transforms
- sample_transforms
- get_log_prob()[source]#
Get log probabilities for production samples.
This method must be implemented by backend-specific subclasses. Always returns production/final samples (not training samples).
- Return type:
Array
- Returns:
Array – Log probability values (1D array)
- Raises:
NotImplementedError – This is an abstract method that must be implemented by subclasses
Notes
FlowMC: Returns log posterior from production sampler state
Nested Sampling: Returns log likelihood (use weights separately)
SMC: Returns log posterior (uniform weights at λ=1)
- get_n_samples()[source]#
Get number of production/final samples.
This method must be implemented by backend-specific subclasses. Always returns the number of production/final samples (not training samples).
- Return type:
int
- Returns:
int – Number of production/final samples
- Raises:
NotImplementedError – This is an abstract method that must be implemented by subclasses
Notes
For samplers with train/production splits (e.g., FlowMC), this returns only production sample count. Training sample count should be accessed via sampler-specific methods if needed.
- get_sampler_output()[source]#
Get standardized sampler output with samples, log probabilities, and metadata.
This is the preferred method for accessing sampler results. It returns a SamplerOutput dataclass containing all samples, log probabilities, and sampler-specific metadata in a standardized format.
Always returns production/final samples (not training samples).
- Return type:
SamplerOutput
- Returns:
SamplerOutput – Standardized output containing:
- samples: Dict of parameter arrays (no metadata fields)
- log_prob: Log probability array (posterior or likelihood)
- metadata: Sampler-specific fields (weights, ESS, etc.)
- Raises:
NotImplementedError – If called on base class (must be implemented by backend-specific subclass)
RuntimeError – If sampling has not been run yet (no results available)
Notes
This method should be used instead of the older get_samples() and get_log_prob() methods, which are now considered legacy.
For NS-AW, log_prob contains log likelihood (not log posterior), as nested sampling works in likelihood space.
For samplers with train/production splits (e.g., FlowMC), this returns only production samples. Training samples should be accessed via sampler-specific methods if needed for diagnostics.
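As a sketch of the access pattern, the dataclass below is an illustrative stand-in for SamplerOutput; the actual field types in JESTER may differ.

```python
from dataclasses import dataclass, field


@dataclass
class SamplerOutputSketch:
    # Illustrative stand-in for the SamplerOutput dataclass described above.
    samples: dict       # parameter name -> array of samples (no metadata fields)
    log_prob: list      # log posterior (or log likelihood for nested sampling)
    metadata: dict = field(default_factory=dict)  # weights, ESS, etc.


# What a backend might hand back after a run (toy values):
output = SamplerOutputSketch(
    samples={"m": [1.0, 1.2, 0.9]},
    log_prob=[-1.4, -1.6, -1.5],
    metadata={"ess": 3.0},
)

# Downstream code reads everything from the one standardized container
# instead of calling get_samples() and get_log_prob() separately:
n_samples = len(output.log_prob)
mean_m = sum(output.samples["m"]) / n_samples
```

Bundling samples, log probabilities, and metadata in one container keeps backend differences (e.g. nested-sampling weights) out of the common interface.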
- get_samples()[source]#
Get production samples from the sampler.
This method must be implemented by backend-specific subclasses. Always returns production/final samples (not training samples).
- Return type:
dict
- Returns:
dict – Dictionary of samples with parameter names as keys
- Raises:
NotImplementedError – This is an abstract method that must be implemented by subclasses
- likelihood: LikelihoodBase#
- likelihood_transforms: list[NtoMTransform]#
- posterior(params, data)[source]#
Evaluate posterior log probability from flat array.
- Parameters:
params (Array) – Parameter array in sampling space
data (dict) – Data dictionary (unused in JESTER, pass {})
- Return type:
Float
- Returns:
Float – Log posterior probability
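The evaluation path described above (flat array to named dict, then a Jacobian-corrected log posterior) can be sketched with a toy one-parameter model. All helper and transform names below are illustrative, not the JESTER API.

```python
import jax.numpy as jnp

# Toy setup: one parameter sampled in log space.
parameter_names = ["log_m"]


def add_name(x):
    # Turn a flat array into a dict, preserving parameter order.
    return dict(zip(parameter_names, x))


def sample_transform_backward(params):
    # Bijective transform back to physical space: m = exp(log_m).
    # The log-Jacobian of this map is log|dm/dlog_m| = log_m.
    return {"m": jnp.exp(params["log_m"])}, params["log_m"]


def prior_log_prob(params):
    # Toy standard-normal prior on m.
    return -0.5 * params["m"] ** 2 - 0.5 * jnp.log(2.0 * jnp.pi)


def likelihood_evaluate(params, data):
    # Toy Gaussian likelihood centered at m = 1.
    return -0.5 * (params["m"] - 1.0) ** 2


def posterior(x, data):
    # Flat array in sampling space -> named dict -> physical space,
    # accumulating the Jacobian correction along the way.
    named = add_name(x)
    named, log_jacobian = sample_transform_backward(named)
    return prior_log_prob(named) + likelihood_evaluate(named, data) + log_jacobian


log_post = posterior(jnp.array([0.0]), {})  # log_m = 0, so m = 1
```

At log_m = 0 the Jacobian term vanishes and the likelihood peaks, so the result is just the standard-normal log density at m = 1.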
- posterior_from_dict(named_params, data)[source]#
Evaluate posterior log probability from parameter dict.
- print_summary(transform=True)[source]#
Print summary of sampling run.
This method can be implemented by backend-specific subclasses.
- sample(key)[source]#
Run sampling.
This method must be implemented by backend-specific subclasses. Initial positions are sampled from the prior internally by each sampler implementation.
- Parameters:
key (PRNGKeyArray) – JAX random key
- Raises:
NotImplementedError – This is an abstract method that must be implemented by subclasses
- Return type:
None
- sample_transforms: list[BijectiveTransform]#