Source code for jesterTOV.inference.samplers.jester_sampler

r"""
Base sampler implementation for jesterTOV.

This module provides a lightweight, modular base class for sampling.
Backend-specific implementations (e.g., flowMC, Jim, NumPyro) should
inherit from JesterSampler and implement the sampler initialization.
"""

from dataclasses import dataclass, field
from typing import Any

from jaxtyping import Array, Float, PRNGKeyArray

from jesterTOV.inference.base import (
    LikelihoodBase,
    Prior,
    BijectiveTransform,
    NtoMTransform,
)
from jesterTOV.logging_config import get_logger

logger = get_logger("jester")


@dataclass
class SamplerOutput:
    """
    Standardized output from JESTER samplers.

    This dataclass provides a uniform interface for accessing samples,
    log probabilities, and sampler-specific metadata across different
    sampling backends (FlowMC, SMC, NS-AW).

    Attributes
    ----------
    samples : dict[str, Array]
        Dictionary of parameter samples. Keys are parameter names, values are
        JAX arrays of shape (n_samples,) or (n_samples, n_dim). Only contains
        actual parameters, not metadata fields.
    log_prob : Array
        Log probability for each sample. Interpretation depends on the sampler:

        - FlowMC/SMC: log posterior probability
        - NS-AW: log likelihood (nested sampling uses likelihood)

        Shape: (n_samples,)
    metadata : dict[str, Any]
        Sampler-specific metadata. Common fields:

        - FlowMC: {} (empty, MCMC has equal weights)
        - SMC: {"weights": Array, "ess": float}
        - NS-AW: {"weights": Array, "logL": Array, "logL_birth": Array}

    Notes
    -----
    The log_prob field has different semantics for NS-AW (log likelihood)
    versus FlowMC/SMC (log posterior). Consumers should check the sampler
    type when interpreting this field.
    """

    samples: dict[str, Array]
    log_prob: Array
    metadata: dict[str, Any] = field(default_factory=dict)
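

# Illustrative sketch (not part of the original module): how a backend might
# package its results into a SamplerOutput. The parameter names and toy values
# below are hypothetical; only the SamplerOutput fields themselves come from
# the dataclass above.
def _example_sampler_output() -> SamplerOutput:
    import jax.numpy as jnp  # local import; used only for this sketch

    n_samples = 4
    samples = {
        "param_a": jnp.linspace(1.2, 1.6, n_samples),  # hypothetical parameter
        "param_b": jnp.linspace(1.1, 1.4, n_samples),  # hypothetical parameter
    }
    log_prob = jnp.zeros(n_samples)  # e.g. log posterior for an MCMC-style backend
    # An SMC-style backend would additionally report weights and an ESS estimate:
    metadata = {"weights": jnp.full(n_samples, 1.0 / n_samples), "ess": float(n_samples)}
    return SamplerOutput(samples=samples, log_prob=log_prob, metadata=metadata)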


class JesterSampler:
    """
    Lightweight base class for JESTER samplers.

    This class provides a modular interface for Bayesian inference with
    different sampling backends (flowMC, Jim, NumPyro, etc.). It handles:

    - Parameter transforms (sample and likelihood transforms)
    - Posterior evaluation with Jacobian corrections
    - Parameter name propagation
    - Generic sampling interface

    Backend-specific implementations should inherit from this class and:

    1. Call super().__init__() to set up common attributes
    2. Create self.sampler (the backend sampler instance)
    3. Optionally override methods for backend-specific behavior

    Critical features:

    - Uses jnp.inf instead of jnp.nan for initialization
    - Preserves parameter ordering when converting dict to array

    Parameters
    ----------
    likelihood : LikelihoodBase
        Likelihood object with evaluate(params, data) method
    prior : Prior
        Prior object with sample() and log_prob() methods
    sample_transforms : list[BijectiveTransform] | None, optional
        Bijective transforms applied during sampling (with Jacobians)
    likelihood_transforms : list[NtoMTransform] | None, optional
        N-to-M transforms applied before likelihood evaluation

    Attributes
    ----------
    likelihood : LikelihoodBase
        Likelihood object
    prior : Prior
        Prior object
    sample_transforms : list[BijectiveTransform]
        Transforms applied during sampling
    likelihood_transforms : list[NtoMTransform]
        Transforms applied before likelihood evaluation
    parameter_names : list[str]
        Names of parameters (propagated through sample transforms)
    sampler : Any | None
        Backend sampler instance (created by subclasses)

    Notes
    -----
    Subclasses must create self.sampler in their __init__ method.
    The sampler should have a .sample() method and support
    get_sampler_state().
    """

    likelihood: LikelihoodBase
    prior: Prior
    sample_transforms: list[BijectiveTransform]
    likelihood_transforms: list[NtoMTransform]
    parameter_names: list[str]
    sampler: Any | None

    def __init__(
        self,
        likelihood: LikelihoodBase,
        prior: Prior,
        sample_transforms: list[BijectiveTransform] | None = None,
        likelihood_transforms: list[NtoMTransform] | None = None,
    ) -> None:
        # Handle None defaults
        if sample_transforms is None:
            sample_transforms = []
        if likelihood_transforms is None:
            likelihood_transforms = []

        self.likelihood = likelihood
        self.prior = prior
        self.sample_transforms = sample_transforms
        self.likelihood_transforms = likelihood_transforms
        self.parameter_names = prior.parameter_names

        if len(sample_transforms) == 0:
            logger.debug(
                "No sample transforms provided. Using prior parameters as sampling parameters"
            )
        else:
            logger.debug("Using sample transforms")
            for transform in sample_transforms:
                self.parameter_names = transform.propagate_name(self.parameter_names)

        if len(likelihood_transforms) == 0:
            logger.debug(
                "No likelihood transforms provided. Using prior parameters as likelihood parameters"
            )

        # Backend sampler instance - must be created by subclasses
        self.sampler = None

    def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]:
        """
        Convert a flat parameter array into a named parameter dictionary.

        Parameters
        ----------
        x : Array
            An array of parameters. Shape (n_dim,).

        Returns
        -------
        dict[str, Float]
            Dictionary mapping parameter names to values
        """
        return dict(zip(self.parameter_names, x))

    def posterior_from_dict(
        self, named_params: dict[str, Float], data: dict[str, Any]
    ) -> Float:
        """
        Evaluate posterior log probability from a parameter dict.

        Parameters
        ----------
        named_params : dict
            Parameter dictionary
        data : dict
            Data dictionary (unused in JESTER, pass {})

        Returns
        -------
        Float
            Log posterior probability
        """
        transform_jacobian = 0.0
        for transform in reversed(self.sample_transforms):
            named_params, jacobian = transform.inverse(named_params)
            transform_jacobian += jacobian
        prior = self.prior.log_prob(named_params) + transform_jacobian

        # Apply likelihood transforms
        for transform in self.likelihood_transforms:
            named_params = transform.forward(named_params)

        return self.likelihood.evaluate(named_params) + prior

    def posterior(self, params: Float[Array, " n_dim"], data: dict[str, Any]) -> Float:
        """
        Evaluate posterior log probability from a flat array.

        Parameters
        ----------
        params : Array
            Parameter array in sampling space
        data : dict
            Data dictionary (unused in JESTER, pass {})

        Returns
        -------
        Float
            Log posterior probability
        """
        named_params = self.add_name(params)
        return self.posterior_from_dict(named_params, data)

    def sample(self, key: PRNGKeyArray) -> None:
        """
        Run sampling.

        This method must be implemented by backend-specific subclasses.
        Initial positions are sampled from the prior internally by each
        sampler implementation.

        Parameters
        ----------
        key : PRNGKeyArray
            JAX random key

        Raises
        ------
        NotImplementedError
            This is an abstract method that must be implemented by subclasses
        """
        raise NotImplementedError(
            "sample() must be implemented by backend-specific subclass"
        )

    def print_summary(self, transform: bool = True) -> None:
        """
        Print a summary of the sampling run.

        This method can be implemented by backend-specific subclasses.

        Parameters
        ----------
        transform : bool, optional
            Whether to apply inverse sample transforms to results
            (default: True)
        """
        # Optionally overridden by subclasses; the default is a no-op
        pass

    def get_samples(self) -> dict:
        """
        Get production samples from the sampler.

        This method must be implemented by backend-specific subclasses.
        Always returns production/final samples (not training samples).

        Returns
        -------
        dict
            Dictionary of samples with parameter names as keys

        Raises
        ------
        NotImplementedError
            This is an abstract method that must be implemented by subclasses
        """
        raise NotImplementedError(
            "get_samples() must be implemented by backend-specific subclass"
        )

    def get_log_prob(self) -> Array:
        """
        Get log probabilities for production samples.

        This method must be implemented by backend-specific subclasses.
        Always returns production/final samples (not training samples).

        Returns
        -------
        Array
            Log probability values (1D array)

        Raises
        ------
        NotImplementedError
            This is an abstract method that must be implemented by subclasses

        Notes
        -----
        - FlowMC: Returns log posterior from production sampler state
        - Nested Sampling: Returns log likelihood (use weights separately)
        - SMC: Returns log posterior (uniform weights at λ=1)
        """
        raise NotImplementedError(
            "get_log_prob() must be implemented by backend-specific subclass"
        )

    def get_n_samples(self) -> int:
        """
        Get the number of production/final samples.

        This method must be implemented by backend-specific subclasses.
        Always returns the number of production/final samples (not training
        samples).

        Returns
        -------
        int
            Number of production/final samples

        Raises
        ------
        NotImplementedError
            This is an abstract method that must be implemented by subclasses

        Notes
        -----
        For samplers with train/production splits (e.g., FlowMC), this returns
        only the production sample count. The training sample count should be
        accessed via sampler-specific methods if needed.
        """
        raise NotImplementedError(
            "get_n_samples() must be implemented by backend-specific subclass"
        )

    def get_sampler_output(self) -> SamplerOutput:
        """
        Get standardized sampler output with samples, log probabilities,
        and metadata.

        This is the preferred method for accessing sampler results. It returns
        a SamplerOutput dataclass containing all samples, log probabilities,
        and sampler-specific metadata in a standardized format.

        Always returns production/final samples (not training samples).

        Returns
        -------
        SamplerOutput
            Standardized output containing:

            - samples: Dict of parameter arrays (no metadata fields)
            - log_prob: Log probability array (posterior or likelihood)
            - metadata: Sampler-specific fields (weights, ESS, etc.)

        Raises
        ------
        NotImplementedError
            If called on the base class (must be implemented by
            backend-specific subclass)
        RuntimeError
            If sampling has not been run yet (no results available)

        Notes
        -----
        This method should be used instead of the older get_samples() and
        get_log_prob() methods, which are now considered legacy.

        For NS-AW, log_prob contains log likelihood (not log posterior), as
        nested sampling works in likelihood space.

        For samplers with train/production splits (e.g., FlowMC), this returns
        only production samples. Training samples should be accessed via
        sampler-specific methods if needed for diagnostics.
        """
        raise NotImplementedError(
            "get_sampler_output() must be implemented by backend-specific subclass"
        )
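

# Illustrative sketch (not part of the original module): a minimal backend
# subclass following the contract described in the JesterSampler docstring
# (1: call super().__init__(), 2: create self.sampler, 3: implement the
# abstract methods). The placeholder backend and every value it produces are
# hypothetical, not a real flowMC/Jim/NumPyro integration.
class _ExampleSampler(JesterSampler):
    def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs) -> None:
        super().__init__(likelihood, prior, **kwargs)
        # Step 2: a real subclass would construct its backend sampler here.
        self.sampler = object()  # hypothetical placeholder backend
        self._samples: dict[str, Array] | None = None
        self._log_prob: Array | None = None

    def sample(self, key: PRNGKeyArray) -> None:
        import jax.numpy as jnp  # local import; used only for this sketch

        # A real backend would run its MCMC/SMC/NS machinery here. This toy
        # version fills every parameter with zeros and evaluates the posterior
        # via the base-class helper, just to show the data flow.
        n_samples = 8
        flat = jnp.zeros((n_samples, len(self.parameter_names)))
        self._samples = {
            name: flat[:, i] for i, name in enumerate(self.parameter_names)
        }
        self._log_prob = jnp.stack([self.posterior(x, {}) for x in flat])

    def get_samples(self) -> dict:
        if self._samples is None:
            raise RuntimeError("Run sample() before requesting results")
        return self._samples

    def get_log_prob(self) -> Array:
        if self._log_prob is None:
            raise RuntimeError("Run sample() before requesting results")
        return self._log_prob

    def get_n_samples(self) -> int:
        return int(self.get_log_prob().shape[0])

    def get_sampler_output(self) -> SamplerOutput:
        # Equal-weight (MCMC-like) output, so metadata stays empty.
        return SamplerOutput(
            samples=self.get_samples(),
            log_prob=self.get_log_prob(),
            metadata={},
        )


# Usage sketch (with real jesterTOV likelihood/prior objects):
#   sampler = _ExampleSampler(likelihood, prior)
#   sampler.sample(jax.random.PRNGKey(0))
#   output = sampler.get_sampler_output()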