jesterTOV.inference.samplers.jester_sampler.JesterSampler

class JesterSampler(likelihood, prior, sample_transforms=None, likelihood_transforms=None)

Bases: object

Lightweight base class for JESTER samplers.

This class provides a modular interface for Bayesian inference with different sampling backends (flowMC, Jim, NumPyro, etc.). It handles:

  • Parameter transforms (sample and likelihood transforms)

  • Posterior evaluation with Jacobian corrections

  • Parameter name propagation

  • Generic sampling interface

Backend-specific implementations should inherit from this class and:

  1. Call super().__init__() to set up common attributes

  2. Create self.sampler (the backend sampler instance)

  3. Optionally override methods for backend-specific behavior

Critical features:

  • Uses jnp.inf instead of jnp.nan for initialization

  • Preserves parameter ordering when converting dict to array

Parameters:
  • likelihood (LikelihoodBase) – Likelihood object with evaluate(params, data) method

  • prior (Prior) – Prior object with sample() and log_prob() methods

  • sample_transforms (list[BijectiveTransform] | None, optional) – Bijective transforms applied during sampling (with Jacobians)

  • likelihood_transforms (list[NtoMTransform] | None, optional) – N-to-M transforms applied before likelihood evaluation

Variables:
  • likelihood (LikelihoodBase) – Likelihood object

  • prior (Prior) – Prior object

  • sample_transforms (list[BijectiveTransform]) – Transforms applied during sampling

  • likelihood_transforms (list[NtoMTransform]) – Transforms applied before likelihood evaluation

  • parameter_names (list[str]) – Names of parameters (propagated through sample transforms)

  • sampler (Any | None) – Backend sampler instance (created by subclasses)

Notes

Subclasses must create self.sampler in their __init__ method. The sampler should have a .sample() method and support get_sampler_state().
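
The subclassing contract described above can be sketched as follows. This is a minimal illustration and not JESTER code: BaseSampler is a hypothetical stand-in that mirrors only the documented attributes of JesterSampler, ToyBackend and ToySampler are invented names, and an integer seed stands in for a JAX random key so that the sketch needs only the standard library.

```python
import random


class BaseSampler:
    """Hypothetical stand-in mirroring the documented JesterSampler attributes."""

    def __init__(self, likelihood, prior, sample_transforms=None, likelihood_transforms=None):
        self.likelihood = likelihood
        self.prior = prior
        self.sample_transforms = sample_transforms or []
        self.likelihood_transforms = likelihood_transforms or []
        self.sampler = None  # backend instance, created by subclasses

    def sample(self, key):
        raise NotImplementedError("sample() must be implemented by subclasses")


class ToyBackend:
    """Invented backend: draws n_samples values from a seeded uniform distribution."""

    def __init__(self, n_samples):
        self.n_samples = n_samples
        self.draws = []

    def sample(self, seed):
        rng = random.Random(seed)
        self.draws = [rng.random() for _ in range(self.n_samples)]


class ToySampler(BaseSampler):
    def __init__(self, likelihood, prior, n_samples=4):
        # 1. Set up the common attributes.
        super().__init__(likelihood, prior)
        # 2. Create the backend sampler instance.
        self.sampler = ToyBackend(n_samples)

    # 3. Override the abstract methods with backend-specific behavior.
    def sample(self, key):
        self.sampler.sample(key)

    def get_n_samples(self):
        return len(self.sampler.draws)


toy = ToySampler(likelihood=None, prior=None)
toy.sample(key=0)  # an integer seed stands in for a JAX random key here
```

Calling sample() on the stand-in base class raises NotImplementedError, which mirrors the behavior documented for the abstract methods below.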

__init__(likelihood, prior, sample_transforms=None, likelihood_transforms=None)

Methods

  • __init__(likelihood, prior[, ...])

  • add_name(x) – Turn an array into a dictionary.

  • get_log_prob() – Get log probabilities for production samples.

  • get_n_samples() – Get number of production/final samples.

  • get_sampler_output() – Get standardized sampler output with samples, log probabilities, and metadata.

  • get_samples() – Get production samples from the sampler.

  • posterior(params, data) – Evaluate posterior log probability from flat array.

  • posterior_from_dict(named_params, data) – Evaluate posterior log probability from parameter dict.

  • print_summary([transform]) – Print summary of sampling run.

  • sample(key) – Run sampling.

Attributes

add_name(x)

Turn an array into a dictionary.

Parameters:

x (Array) – An array of parameters. Shape (n_dim,).

Return type:

dict[str, Float]

Returns:

dict[str, Float] – Dictionary mapping parameter names to values
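
The conversion can be pictured as pairing each array entry with the parameter name at the same position. The sketch below is a standalone function for illustration: the documented method takes only x and presumably reads the names from self.parameter_names, so the explicit parameter_names argument, the length check, and the example names are all assumptions.

```python
def add_name(x, parameter_names):
    """Pair each value in x with its name, preserving the order of parameter_names."""
    if len(x) != len(parameter_names):
        raise ValueError("x must have one entry per parameter name")
    return dict(zip(parameter_names, x))


# Hypothetical parameter names, chosen for illustration only.
names = ["param_a", "param_b", "param_c"]
named = add_name([0.5, -1.0, 2.0], names)
```

Python dictionaries keep insertion order, so iterating over the result visits the parameters in the same order as the input array.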

get_log_prob()

Get log probabilities for production samples.

This method must be implemented by backend-specific subclasses. It always returns log probabilities for production/final samples (not training samples).

Return type:

Array

Returns:

Array – Log probability values (1D array)

Raises:

NotImplementedError – This is an abstract method that must be implemented by subclasses

Notes

  • FlowMC: Returns log posterior from production sampler state

  • Nested Sampling: Returns log likelihood (use weights separately)

  • SMC: Returns log posterior (uniform weights at λ=1)
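
Because a nested sampling backend returns log likelihoods and keeps the weights separate, posterior summaries have to combine the two explicitly. The sketch below assumes the log weights are already available as a plain list; where a given backend stores them is not specified here, so log_weights is a hypothetical input and the numbers are made up. It computes a weighted mean and the Kish effective sample size using only the standard library.

```python
import math


def normalize_log_weights(log_weights):
    """Return weights that sum to 1, using log-sum-exp for numerical stability."""
    peak = max(log_weights)
    shifted = [math.exp(lw - peak) for lw in log_weights]
    total = sum(shifted)
    return [s / total for s in shifted]


def weighted_mean(values, weights):
    """Mean of values under normalized weights."""
    return sum(v * w for v, w in zip(values, weights))


def effective_sample_size(weights):
    """Kish effective sample size for normalized weights."""
    return 1.0 / sum(w * w for w in weights)


# Made-up numbers for illustration only.
samples = [1.0, 2.0, 3.0, 4.0]
log_weights = [math.log(0.1), math.log(0.2), math.log(0.3), math.log(0.4)]

weights = normalize_log_weights(log_weights)
mean = weighted_mean(samples, weights)
ess = effective_sample_size(weights)
```

With uniform weights the effective sample size equals the number of samples, which is the situation the SMC note describes.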

get_n_samples()

Get number of production/final samples.

This method must be implemented by backend-specific subclasses. Always returns the number of production/final samples (not training samples).

Return type:

int

Returns:

int – Number of production/final samples

Raises:

NotImplementedError – This is an abstract method that must be implemented by subclasses

Notes

For samplers with train/production splits (e.g., FlowMC), this returns only production sample count. Training sample count should be accessed via sampler-specific methods if needed.

get_sampler_output()

Get standardized sampler output with samples, log probabilities, and metadata.

This is the preferred method for accessing sampler results. It returns a SamplerOutput dataclass containing all samples, log probabilities, and sampler-specific metadata in a standardized format.

Always returns production/final samples (not training samples).

Return type:

SamplerOutput

Returns:

SamplerOutput – Standardized output containing:

  • samples: Dict of parameter arrays (no metadata fields)

  • log_prob: Log probability array (posterior or likelihood)

  • metadata: Sampler-specific fields (weights, ESS, etc.)

Raises:
  • NotImplementedError – If called on base class (must be implemented by backend-specific subclass)

  • RuntimeError – If sampling has not been run yet (no results available)

Notes

This method should be used instead of the older get_samples() and get_log_prob() methods, which are now considered legacy.

For NS-AW, log_prob contains log likelihood (not log posterior), as nested sampling works in likelihood space.

For samplers with train/production splits (e.g., FlowMC), this returns only production samples. Training samples should be accessed via sampler-specific methods if needed for diagnostics.
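
To illustrate how the three documented fields fit together, the sketch below defines a hypothetical stand-in dataclass and a small helper that picks the sample with the highest log probability. SamplerOutputSketch and best_sample are invented for this example and are not part of JESTER; plain lists replace arrays, and the values are made up.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class SamplerOutputSketch:
    """Hypothetical stand-in with the three documented fields."""

    samples: dict[str, list[float]]
    log_prob: list[float]
    metadata: dict[str, Any] = field(default_factory=dict)


def best_sample(output):
    """Return the parameter values at the highest log probability."""
    idx = max(range(len(output.log_prob)), key=output.log_prob.__getitem__)
    return {name: values[idx] for name, values in output.samples.items()}


# Made-up values for illustration only.
output = SamplerOutputSketch(
    samples={"param_a": [0.1, 0.2, 0.3], "param_b": [1.0, 2.0, 3.0]},
    log_prob=[-3.0, -1.5, -2.0],
    metadata={"sampler": "toy"},
)
best = best_sample(output)
```

When log_prob holds log likelihoods, as noted for nested sampling, the selected point is the maximum-likelihood sample rather than the maximum-posterior one.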

get_samples()

Get production samples from the sampler.

This method must be implemented by backend-specific subclasses. Always returns production/final samples (not training samples).

Return type:

dict

Returns:

dict – Dictionary of samples with parameter names as keys

Raises:

NotImplementedError – This is an abstract method that must be implemented by subclasses

likelihood: LikelihoodBase

likelihood_transforms: list[NtoMTransform]

parameter_names: list[str]

posterior(params, data)

Evaluate posterior log probability from flat array.

Parameters:
  • params (Array) – Parameter array in sampling space

  • data (dict) – Data dictionary (unused in JESTER, pass {})

Return type:

Float

Returns:

Float – Log posterior probability
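
Evaluating a posterior in sampling space means adding the log Jacobian of the transform back to the prior space to the log prior and log likelihood. The one-dimensional sketch below illustrates that bookkeeping under stated assumptions: the sigmoid transform, the uniform prior, and the Gaussian likelihood are all chosen for the example, and the function takes only the parameter (no data argument). It is not JESTER's implementation.

```python
import math


def sigmoid(y):
    """Numerically stable logistic function mapping the real line to (0, 1)."""
    if y >= 0:
        return 1.0 / (1.0 + math.exp(-y))
    z = math.exp(y)
    return z / (1.0 + z)


def log_prior(x):
    """Uniform prior on (0, 1): log density is 0 inside the support."""
    return 0.0 if 0.0 < x < 1.0 else -math.inf


def log_likelihood(x):
    """Toy Gaussian likelihood centered at 0.5 with width 0.1 (unnormalized)."""
    return -0.5 * ((x - 0.5) / 0.1) ** 2


def log_posterior(y):
    """Log posterior density in the unbounded sampling space y."""
    x = sigmoid(y)  # map from sampling space back to the prior space
    if not 0.0 < x < 1.0:  # sigmoid saturated in floating point
        return -math.inf
    log_jacobian = math.log(x) + math.log(1.0 - x)  # log |dx/dy| for the sigmoid
    return log_prior(x) + log_jacobian + log_likelihood(x)


value_at_zero = log_posterior(0.0)
```

At y = 0 the transform gives x = 0.5, where the prior and likelihood terms both vanish, so the value is the Jacobian term alone, log(0.25). Returning negative infinity in the saturated case keeps the result comparable instead of producing NaN.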

posterior_from_dict(named_params, data)

Evaluate posterior log probability from parameter dict.

Parameters:
  • named_params (dict) – Parameter dictionary

  • data (dict) – Data dictionary (unused in JESTER, pass {})

Return type:

Float

Returns:

Float – Log posterior probability

print_summary(transform=True)

Print summary of sampling run.

This method can be implemented by backend-specific subclasses.

Parameters:

transform (bool, optional) – Whether to apply inverse sample transforms to results (default: True)

Return type:

None

prior: Prior

sample(key)

Run sampling.

This method must be implemented by backend-specific subclasses. Initial positions are sampled from the prior internally by each sampler implementation.

Parameters:

key (PRNGKeyArray) – JAX random key

Raises:

NotImplementedError – This is an abstract method that must be implemented by subclasses

Return type:

None

sample_transforms: list[BijectiveTransform]

sampler: Any | None