Source code for jesterTOV.inference.samplers.blackjax.base

"""Base class for BlackJAX samplers with shared transform logic.

This module provides BlackjaxSampler, which handles parameter space transformations
in a way that can be shared across different BlackJAX sampling algorithms (SMC, NS, etc.).
"""

from typing import Any, Callable

import jax
from jesterTOV.inference.base import (
    LikelihoodBase,
    Prior,
    BijectiveTransform,
    NtoMTransform,
)
from jesterTOV.inference.samplers.jester_sampler import JesterSampler
from jesterTOV.logging_config import get_logger

logger = get_logger("jester")


[docs] class BlackjaxSampler(JesterSampler): """Base class for BlackJAX samplers with shared transform logic. This class provides common functionality for all BlackJAX-based samplers: - Creating dict-based log prior functions (with inverse transforms + Jacobian) - Creating dict-based log likelihood functions (with inverse + likelihood transforms) Different BlackJAX algorithms have different API requirements: - SMC requires flat array functions → subclass wraps these dict functions - NS-AW requires dict functions → subclass uses these directly This design maximizes code reuse while respecting each algorithm's API. Parameters ---------- likelihood : LikelihoodBase Likelihood object with evaluate(params, data) method prior : Prior Prior object sample_transforms : list[BijectiveTransform] | None, optional Bijective transforms applied during sampling (with Jacobians) likelihood_transforms : list[NtoMTransform] | None, optional N-to-M transforms applied before likelihood evaluation Notes ----- Subclasses must implement: - sample(): Run the sampling algorithm - get_samples(): Return samples in dict format - get_log_prob(): Return log probabilities - get_n_samples(): Return number of samples - get_sampler_output(): Return standardized SamplerOutput """
[docs] def __init__( self, likelihood: LikelihoodBase, prior: Prior, sample_transforms: list[BijectiveTransform] | None = None, likelihood_transforms: list[NtoMTransform] | None = None, ) -> None: """Initialize BlackJAX sampler base class.""" super().__init__(likelihood, prior, sample_transforms, likelihood_transforms)
def _create_logprior_fn_from_dict(self) -> Callable[[dict[str, Any]], float]: """Create log prior function that works with parameter dicts. This function: 1. Applies inverse sample transforms (sampling space → prior space) 2. Adds Jacobian corrections from transforms 3. Evaluates prior log probability Both SMC and NS can use this - SMC will wrap it for flat arrays. Returns ------- Callable[[dict[str, Any]], float] JIT-compiled log prior function for single sample dict Examples -------- >>> logprior_fn = self._create_logprior_fn_from_dict() >>> params = {"K_sat": 0.5, "L_sym": 0.3} # In sampling space (e.g., unit cube) >>> log_p = logprior_fn(params) # Returns log prior in prior space + Jacobian """ def logprior_fn(params_dict: dict[str, Any]) -> float: """Evaluate log prior with transforms and Jacobian corrections.""" transform_jacobian = 0.0 named_params = params_dict.copy() # Apply inverse sample transforms (sampling space → prior space) for transform in reversed(self.sample_transforms): named_params, jacobian = transform.inverse(named_params) transform_jacobian += jacobian # Evaluate prior + Jacobian return self.prior.log_prob(named_params) + transform_jacobian # JIT compile for performance return jax.jit(logprior_fn) def _create_loglikelihood_fn_from_dict(self) -> Callable[[dict[str, Any]], float]: """Create log likelihood function that works with parameter dicts. This function: 1. Applies inverse sample transforms (sampling space → prior space) 2. Applies forward likelihood transforms (prior → likelihood params) 3. Evaluates likelihood Both SMC and NS can use this - SMC will wrap it for flat arrays. Returns ------- Callable[[dict[str, Any]], float] JIT-compiled log likelihood function for single sample dict Examples -------- >>> loglikelihood_fn = self._create_loglikelihood_fn_from_dict() >>> params = {"K_sat": 0.5, "L_sym": 0.3} # In sampling space (e.g., unit cube) >>> log_l = loglikelihood_fn(params) # Returns log likelihood """ def loglikelihood_fn(params_dict: dict[str, Any]) -> float: """Evaluate log likelihood with transforms.""" named_params = params_dict.copy() # Apply inverse sample transforms (sampling space → prior space) for transform in reversed(self.sample_transforms): named_params, _ = transform.inverse(named_params) # Apply likelihood transforms (prior → likelihood params) for transform in self.likelihood_transforms: named_params = transform.forward(named_params) # Evaluate likelihood return self.likelihood.evaluate(named_params) # JIT compile for performance return jax.jit(loglikelihood_fn)