jesterTOV.inference.samplers.blackjax.nested_sampling.acceptance_walk.BlackJAXNSAWSampler

class BlackJAXNSAWSampler(likelihood, prior, sample_transforms, likelihood_transforms, config, seed=0)

Bases: BlackjaxSampler

BlackJAX Nested Sampling with acceptance walk kernel.

This sampler implements nested sampling for Bayesian evidence calculation and posterior sampling. It uses unit cube transforms (all parameters mapped to [0, 1]) and the acceptance walk kernel for MCMC proposals.
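For reference, the quantity this sampler estimates can be written with the standard nested sampling identity below. It is general nested sampling, not something specific to this class. It assumes n_live live points and the usual expected prior-volume shrinkage of e^{-1/n_live} per iteration. The p_i are the posterior (importance) weights of the dead points in this simple constant-n_live form:

```latex
\mathcal{Z} = \int \mathcal{L}(\theta)\,\pi(\theta)\,d\theta
            = \int_0^1 \mathcal{L}(X)\,dX
            \approx \sum_{i} \mathcal{L}_i\,(X_{i-1} - X_i),
\qquad X_i \approx e^{-i/n_{\mathrm{live}}},
\qquad p_i = \frac{\mathcal{L}_i\,(X_{i-1} - X_i)}{\mathcal{Z}}
```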

Unlike the SMC samplers, NS-AW works directly with the dict-based functions inherited from the BlackjaxSampler parent class, so parameters do not need to be flattened into arrays.
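To make "flattening" concrete, here is a minimal sketch of the difference between a dict-based function and an array-based one. The parameter names, the toy likelihood and the helpers flatten and as_flat_function are hypothetical. add_name only mirrors the documented idea of turning an array into a dictionary; it is not the library's implementation.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical parameter names, chosen only for illustration.
PARAMETER_NAMES = ["param_a", "param_b"]


def add_name(x: Sequence[float], names: Sequence[str] = PARAMETER_NAMES) -> Dict[str, float]:
    """Turn a flat array into a dict keyed by parameter name (illustrative re-implementation)."""
    if len(x) != len(names):
        raise ValueError("array length must match the number of parameter names")
    return {name: float(value) for name, value in zip(names, x)}


def flatten(params: Dict[str, float], names: Sequence[str] = PARAMETER_NAMES) -> List[float]:
    """Inverse of add_name: order the dict values by the fixed parameter order."""
    return [params[name] for name in names]


def log_likelihood_dict(params: Dict[str, float]) -> float:
    """Toy dict-based log likelihood (a stand-in, not a jesterTOV likelihood)."""
    return -0.5 * (params["param_a"] ** 2 + params["param_b"] ** 2)


def as_flat_function(fn: Callable[[Dict[str, float]], float]) -> Callable[[Sequence[float]], float]:
    """Wrap a dict-based function so an array-based sampler can call it."""
    return lambda x: fn(add_name(x))
```

A dict-based sampler can call log_likelihood_dict({"param_a": 1.0, "param_b": 2.0}) directly, while an array-based one would need the wrapper, as_flat_function(log_likelihood_dict)([1.0, 2.0]). Both give -2.5.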

Parameters:
  • likelihood (LikelihoodBase) – Likelihood object with an evaluate(params, data) method

  • prior (Prior) – Prior object (CombinePrior of UniformPrior and/or MultivariateGaussianPrior)

  • sample_transforms (list[BijectiveTransform]) – Unit cube transforms (created by transform_factory)

  • likelihood_transforms (list[NtoMTransform]) – N-to-M transforms applied before likelihood evaluation

  • config (BlackJAXNSAWConfig) – Nested sampling configuration

  • seed (int, optional) – Random seed (default: 0)

Variables:
  • config (BlackJAXNSAWConfig) – Sampler configuration

  • final_state (Any | None) – Final nested sampling state (after sampling)

  • metadata (dict) – Sampling metadata (evidence, time, etc.)

  • _logprior_fn (callable) – Pre-compiled log prior function (unit cube → prior space)

  • _loglikelihood_fn (callable) – Pre-compiled log likelihood function (unit cube → likelihood)

Notes

Requires BoundToBound transforms onto [0, 1] for all parameters; transform_factory creates these automatically for nested sampling.
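The sketch below illustrates why a bound-to-[0, 1] map suits nested sampling. For an independent uniform prior on [lo, hi], the affine map x = lo + u (hi - lo) has log-Jacobian log(hi - lo), which exactly cancels the log prior density -log(hi - lo), so the prior is flat in the unit cube. The bounds and function names are hypothetical. Only uniform priors are covered; how MultivariateGaussianPrior is mapped is not shown here.

```python
import math
from typing import Dict, Tuple

# Hypothetical uniform-prior bounds, used only for illustration.
BOUNDS: Dict[str, Tuple[float, float]] = {"param_a": (0.0, 10.0), "param_b": (-1.0, 1.0)}


def to_unit_cube(x: float, lo: float, hi: float) -> float:
    """Affine map from [lo, hi] onto [0, 1]."""
    return (x - lo) / (hi - lo)


def from_unit_cube(u: float, lo: float, hi: float) -> float:
    """Inverse map from [0, 1] back onto [lo, hi]."""
    return lo + u * (hi - lo)


def to_physical(u: Dict[str, float]) -> Dict[str, float]:
    """Map a unit-cube point to physical parameter values."""
    return {name: from_unit_cube(u[name], lo, hi) for name, (lo, hi) in BOUNDS.items()}


def log_prior_unit_cube(u: Dict[str, float]) -> float:
    """Log prior density in unit-cube coordinates for independent uniform priors.

    Inside the cube, log p(x) + log|dx/du| = -log(hi - lo) + log(hi - lo) = 0.
    """
    total = 0.0
    for name, (lo, hi) in BOUNDS.items():
        if not 0.0 <= u[name] <= 1.0:
            return -math.inf  # outside the unit cube
        log_p_x = -math.log(hi - lo)  # uniform density in physical space
        log_jac = math.log(hi - lo)   # |dx/du| of the affine map
        total += log_p_x + log_jac
    return total
```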

__init__(likelihood, prior, sample_transforms, likelihood_transforms, config, seed=0)

Initialize BlackJAX nested sampling sampler.

Methods

__init__(likelihood, prior, ...[, seed])

Initialize BlackJAX nested sampling sampler.

add_name(x)

Turn an array into a dictionary.

get_log_prob()

Get log likelihoods from nested sampling.

get_n_samples()

Get number of posterior samples from nested sampling.

get_sampler_output()

Get standardized sampler output.

get_samples()

Return unweighted posterior samples from nested sampling.

posterior(params, data)

Evaluate posterior log probability from flat array.

posterior_from_dict(named_params, data)

Evaluate posterior log probability from parameter dict.

print_summary([transform])

Print summary of nested sampling run.

sample(key)

Run nested sampling until termination criterion.

Attributes

config

final_state

metadata

likelihood

prior

sample_transforms

likelihood_transforms

parameter_names

sampler

config: BlackJAXNSAWConfig

final_state: Any | None

get_log_prob()

Get log likelihoods from nested sampling.

Return type:

Array

Returns:

Array – Log likelihood values (1D array). Note: for NS these are log likelihoods, not log posteriors; use the weights separately for posterior inference.

Notes

This method returns filtered log likelihoods (matching get_samples()). If anesthetic has dropped invalid samples, the returned array will be shorter than the raw NSInfo.loglikelihood array.
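The key property of this filtering is that parameters and log likelihoods stay aligned, so get_log_prob(), get_n_samples() and get_samples() agree on length. Here is a minimal sketch of that idea. It assumes, purely for illustration, that "invalid" means a non-finite log likelihood; anesthetic's actual criterion may differ, and filter_invalid is a hypothetical helper.

```python
import math
from typing import Dict, List


def filter_invalid(samples: Dict[str, List[float]], logl_key: str = "logL") -> Dict[str, List[float]]:
    """Drop every sample whose log likelihood is non-finite (hypothetical criterion).

    The same index mask is applied to all keys, so the arrays stay aligned.
    """
    keep = [i for i, value in enumerate(samples[logl_key]) if math.isfinite(value)]
    return {key: [values[i] for i in keep] for key, values in samples.items()}
```

With this scheme the filtered sample count is simply len(filtered["logL"]), which is never larger than the raw count.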

get_n_samples()

Get number of posterior samples from nested sampling.

Return type:

int

Returns:

int – Number of posterior samples

Notes

This method returns the number of filtered samples (matching get_samples()). If anesthetic has dropped invalid samples, the count will be less than the raw NSInfo particle count.

get_sampler_output()

Get standardized sampler output.

Return type:

SamplerOutput

Returns:

SamplerOutput

  • samples: Unweighted parameter samples (dict of arrays, no metadata fields)

  • log_prob: Log likelihood, not log posterior (NS works in likelihood space)

  • metadata: {“logL”: Array, “logL_birth”: Array}

Raises:

RuntimeError – If sampling has not been run yet.

Notes

The weighted nested sampling samples are resampled according to their importance weights to produce unweighted posterior samples, so downstream analysis can treat all samples equally. As is standard for nested sampling, log_prob contains the log likelihood, not the log posterior.
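The three documented fields can be pictured as a split of the resampled dict returned by get_samples(). Parameter keys become samples, logL and logL_birth become metadata, and logL is reused as log_prob. A minimal sketch follows. SamplerOutputSketch and split_output are hypothetical stand-ins, not the real SamplerOutput class or its construction code.

```python
from dataclasses import dataclass
from typing import Dict, List

# Keys that get_samples() documents as metadata rather than parameters.
METADATA_KEYS = ("logL", "logL_birth")


@dataclass
class SamplerOutputSketch:
    """Hypothetical stand-in mirroring the three documented SamplerOutput fields."""
    samples: Dict[str, List[float]]
    log_prob: List[float]
    metadata: Dict[str, List[float]]


def split_output(resampled: Dict[str, List[float]]) -> SamplerOutputSketch:
    """Separate parameter samples from the metadata fields."""
    samples = {k: v for k, v in resampled.items() if k not in METADATA_KEYS}
    metadata = {k: resampled[k] for k in METADATA_KEYS}
    # log_prob holds the log likelihood, not the log posterior.
    return SamplerOutputSketch(samples=samples, log_prob=list(resampled["logL"]), metadata=metadata)
```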

get_samples()

Return unweighted posterior samples from nested sampling.

This method computes importance weights using anesthetic, then resamples to produce approximately ESS (effective sample size) unweighted posterior samples. Downstream analysis (plotting, postprocessing) can therefore treat all samples as equally weighted, which is what it expects.
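The weighting-then-resampling step can be sketched without anesthetic. The sketch computes dead-point weights L_i (X_{i-1} - X_i), takes the Kish effective sample size (sum w)^2 / sum w^2, and draws that many samples with multinomial resampling. It assumes a constant number of live points and ignores logL_birth, which anesthetic can use to track a varying live-point count. It also uses a plain multinomial draw, so treat it as an illustration of the idea rather than what anesthetic does internally.

```python
import math
import random
from typing import List, Sequence


def ns_log_weights(logl: Sequence[float], n_live: int) -> List[float]:
    """Unnormalised log weights log(L_i * (X_{i-1} - X_i)) for dead points.

    Assumes logl is sorted in increasing order and a constant number of live
    points, so that X_i = exp(-i / n_live).
    """
    log_shrink = math.log1p(-math.exp(-1.0 / n_live))  # log(1 - e^{-1/n_live})
    return [value - j / n_live + log_shrink for j, value in enumerate(logl)]


def normalise(log_w: Sequence[float]) -> List[float]:
    """Exponentiate stably and normalise so the weights sum to 1."""
    m = max(log_w)
    w = [math.exp(v - m) for v in log_w]
    total = sum(w)
    return [v / total for v in w]


def kish_ess(weights: Sequence[float]) -> float:
    """Kish effective sample size: (sum w)^2 / sum w^2."""
    return sum(weights) ** 2 / sum(v * v for v in weights)


def resample_equal_weight(values: Sequence[float], weights: Sequence[float], seed: int = 0) -> List[float]:
    """Multinomial resampling to roughly ESS equally weighted samples."""
    rng = random.Random(seed)
    n_out = max(1, round(kish_ess(weights)))
    idx = rng.choices(range(len(values)), weights=weights, k=n_out)
    return [values[i] for i in idx]
```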

Return type:

dict

Returns:

dict – Dictionary with:

  • Parameter samples (resampled, unweighted)

  • ‘logL’: log likelihood values (resampled)

  • ‘logL_birth’: birth log likelihoods (resampled)

Notes

The original weighted samples are cached in _filtered_samples_cache for advanced users who need access to the full weighted set.

metadata: dict

print_summary(transform=True)

Print summary of nested sampling run.

Parameters:

transform (bool, optional) – Ignored for nested sampling; physical parameters are always reported.

Return type:

None

sample(key)

Run nested sampling until termination criterion.

Parameters:

key (PRNGKeyArray) – JAX random key

Return type:

None

Notes

Initial live points are sampled from the prior and transformed to unit cube space internally.
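The sketch below shows how the pieces of a run fit together in a toy one-dimensional nested sampling loop. Live points start uniformly in the unit cube, the worst point is repeatedly replaced under a rising likelihood threshold, and the loop stops on a dlogz-style criterion. It is not BlackJAX's implementation. The replacement step is a plain constrained random walk standing in for the acceptance walk kernel, and the likelihood is a hypothetical Gaussian. n_live, dlogz and the step count are illustrative choices, not documented configuration fields.

```python
import math
import random
from typing import Tuple


def log_likelihood(u: float) -> float:
    """Hypothetical Gaussian likelihood centred at 0.5 with width 0.1."""
    sigma = 0.1
    return -0.5 * ((u - 0.5) / sigma) ** 2 - math.log(sigma * math.sqrt(2.0 * math.pi))


def logaddexp(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b))."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def constrained_walk(rng: random.Random, start: float, logl_min: float, scale: float, n_steps: int = 20) -> float:
    """Random walk in the unit cube that only accepts moves with L > L_min.

    With a uniform prior in the unit cube this is Metropolis sampling of the
    prior restricted to the likelihood contour (a stand-in for acceptance walk).
    """
    u = start
    for _ in range(n_steps):
        proposal = u + rng.gauss(0.0, scale)
        if 0.0 <= proposal <= 1.0 and log_likelihood(proposal) > logl_min:
            u = proposal
    return u


def nested_sampling(n_live: int = 100, dlogz: float = 0.1, max_iter: int = 10_000, seed: int = 0) -> Tuple[float, int]:
    rng = random.Random(seed)
    # Initial live points drawn from the prior, i.e. uniformly in the unit cube.
    live = [rng.random() for _ in range(n_live)]
    live_logl = [log_likelihood(u) for u in live]
    log_z = -math.inf
    log_x = 0.0  # log of the remaining prior volume X
    log_shrink = math.log1p(-math.exp(-1.0 / n_live))
    iteration = 0
    for iteration in range(1, max_iter + 1):
        worst = min(range(n_live), key=live_logl.__getitem__)
        logl_min = live_logl[worst]
        # Dead point weight L_i * (X_{i-1} - X_i) with X_i = exp(-i / n_live).
        log_z = logaddexp(log_z, logl_min + log_x + log_shrink)
        log_x -= 1.0 / n_live
        # Replace the worst point by walking from a different, randomly chosen live point.
        other = rng.randrange(n_live - 1)
        if other >= worst:
            other += 1
        scale = max(1e-6, (max(live) - min(live)) / 2.0)
        new_u = constrained_walk(rng, live[other], logl_min, scale)
        live[worst], live_logl[worst] = new_u, log_likelihood(new_u)
        # Stop once the live points can raise log Z by less than dlogz.
        if logaddexp(log_z, max(live_logl) + log_x) - log_z < dlogz:
            break
    # Add the remaining live points, each holding X / n_live of prior volume.
    for logl in live_logl:
        log_z = logaddexp(log_z, logl + log_x - math.log(n_live))
    return log_z, iteration
```

For this toy problem almost all of the Gaussian's mass lies inside [0, 1], so the true evidence is essentially 1. The first value returned by nested_sampling() should therefore land close to 0, up to the usual nested sampling scatter.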