# Source code for jesterTOV.inference.config.schemas.samplers

"""Pydantic models for sampler configuration."""

from typing import Literal, Union, Annotated
from pydantic import Field, field_validator, ConfigDict, Discriminator

from ._base import JesterBaseModel


class BaseSamplerConfig(JesterBaseModel):
    """Settings shared by every sampler configuration.

    Concrete sampler configurations derive from this class and add their own
    ``type`` field holding a distinct literal value; that field serves as the
    discriminator of the ``SamplerConfig`` union.

    Attributes
    ----------
    output_dir : str
        Directory to save results (default: "./outdir/")
    n_eos_samples : int
        Number of EOS samples to generate after inference (default: 10000)
    log_prob_batch_size : int
        Batch size for computing log probabilities and generating EOS samples
        (default: 1000)
    """

    # Unknown keys in a sampler configuration are rejected rather than ignored.
    model_config = ConfigDict(extra="forbid")

    output_dir: str = "./outdir/"
    n_eos_samples: int = 10_000
    log_prob_batch_size: int = 1000

    # NOTE(review): the name differs from the subclasses' ``validate_positive``;
    # presumably so that a subclass validator does not replace this one.
    @field_validator("n_eos_samples", "log_prob_batch_size")
    @classmethod
    def validate_base_positive(cls, v: int) -> int:
        """Return ``v`` unchanged if it is strictly positive.

        Raises
        ------
        ValueError
            If ``v`` is zero or negative.
        """
        if v > 0:
            return v
        raise ValueError(f"Value must be positive, got: {v}")
class FlowMCSamplerConfig(BaseSamplerConfig):
    """Configuration for FlowMC sampler (normalizing flow-enhanced MCMC).

    Attributes
    ----------
    type : Literal["flowmc"]
        Sampler type identifier
    n_chains : int
        Number of parallel chains (default: 20)
    n_loop_training : int
        Number of training loops (default: 3)
    n_loop_production : int
        Number of production loops (default: 3)
    n_local_steps : int
        Number of local MCMC steps per loop (default: 100)
    n_global_steps : int
        Number of global steps per loop (default: 100)
    n_epochs : int
        Number of training epochs for normalizing flow (default: 30)
    learning_rate : float
        Learning rate for flow training (default: 0.001)
    train_thinning : int
        Thinning factor for training samples (default: 1)
    output_thinning : int
        Thinning factor for output samples (default: 5)
    output_dir : str
        Directory to save results
    n_eos_samples : int
        Number of EOS samples to generate after inference (default: 10000)
    log_prob_batch_size : int
        Batch size for computing log probabilities and generating EOS samples
        (default: 1000)
    """

    type: Literal["flowmc"] = "flowmc"
    n_chains: int = 20
    n_loop_training: int = 3
    n_loop_production: int = 3
    n_local_steps: int = 100
    n_global_steps: int = 100
    n_epochs: int = 30
    learning_rate: float = 0.001
    train_thinning: int = 1
    output_thinning: int = 5

    @field_validator(
        "n_chains",
        "n_loop_training",
        "n_loop_production",
        "n_local_steps",
        "n_global_steps",
        "n_epochs",
        "train_thinning",
        "output_thinning",
    )
    @classmethod
    def validate_positive(cls, v: int) -> int:
        """Validate that an integer setting is strictly positive.

        Raises
        ------
        ValueError
            If ``v`` is zero or negative.
        """
        if v <= 0:
            raise ValueError(f"Value must be positive, got: {v}")
        return v

    @field_validator("learning_rate")
    @classmethod
    def validate_positive_float(cls, v: float) -> float:
        """Validate that learning rate is strictly positive.

        Raises
        ------
        ValueError
            If ``v`` is zero, negative, or NaN.
        """
        # Written as ``not v > 0`` instead of ``v <= 0``: every comparison
        # with NaN is False, so the latter would silently accept NaN.
        if not v > 0:
            raise ValueError(f"Value must be positive, got: {v}")
        return v
class BlackJAXNSAWConfig(BaseSamplerConfig):
    """Configuration for BlackJAX Nested Sampling with Acceptance Walk.

    Attributes
    ----------
    type : Literal["blackjax-ns-aw"]
        Sampler type identifier
    n_live : int
        Number of live points (default: 1000)
    n_delete_frac : float
        Fraction of live points to delete per iteration, in (0, 1]
        (default: 0.5)
    n_target : int
        Target number of accepted MCMC steps (default: 60)
    max_mcmc : int
        Maximum MCMC steps per iteration (default: 5000)
    max_proposals : int
        Maximum proposal attempts per MCMC step (default: 1000)
    termination_dlogz : float
        Evidence convergence criterion (default: 0.1)
    output_dir : str
        Directory to save results
    n_eos_samples : int
        Number of EOS samples to generate after inference (default: 10000)
    log_prob_batch_size : int
        Batch size for computing log probabilities and generating EOS samples
        (default: 1000)
    """

    type: Literal["blackjax-ns-aw"] = "blackjax-ns-aw"
    n_live: int = 1000
    n_delete_frac: float = 0.5
    n_target: int = 60
    max_mcmc: int = 5000
    max_proposals: int = 1000
    # NOTE(review): no validator covers termination_dlogz; confirm with the
    # sampler implementation whether non-positive or NaN values are meaningful.
    termination_dlogz: float = 0.1

    @field_validator("n_delete_frac")
    @classmethod
    def validate_delete_frac(cls, v: float) -> float:
        """Validate that deletion fraction is in (0, 1].

        Raises
        ------
        ValueError
            If ``v`` lies outside (0, 1] or is NaN.
        """
        # A positive range test is used so that NaN, for which every
        # comparison is False, is rejected instead of slipping through.
        if not 0 < v <= 1:
            raise ValueError(f"n_delete_frac must be in (0, 1], got: {v}")
        return v

    @field_validator("n_live", "n_target", "max_mcmc", "max_proposals")
    @classmethod
    def validate_positive(cls, v: int) -> int:
        """Validate that an integer setting is strictly positive.

        Raises
        ------
        ValueError
            If ``v`` is zero or negative.
        """
        if v <= 0:
            raise ValueError(f"Value must be positive, got: {v}")
        return v
class SMCRandomWalkSamplerConfig(BaseSamplerConfig):
    """Configuration for Sequential Monte Carlo with Random Walk kernel.

    Fields inherited from ``BaseSamplerConfig`` (``output_dir``,
    ``n_eos_samples``, ``log_prob_batch_size``) are accepted as well.

    Attributes
    ----------
    type : Literal["smc-rw"]
        Sampler type identifier
    n_particles : int
        Number of particles (default: 10000)
    n_mcmc_steps : int
        Number of MCMC steps per tempering level (default: 1)
    target_ess : float
        Target effective sample size for adaptive tempering, in (0, 1]
        (default: 0.9)
    random_walk_sigma : float
        Fixed sigma scaling for Gaussian random walk kernel (default: 1.0).
        The proposal covariance is computed from particles and scaled by
        sigma^2. Default of 1.0 uses the empirical covariance directly.
    """

    type: Literal["smc-rw"] = "smc-rw"
    n_particles: int = 10000
    n_mcmc_steps: int = 1
    target_ess: float = 0.9
    random_walk_sigma: float = 1.0

    @field_validator("n_particles", "n_mcmc_steps")
    @classmethod
    def validate_positive(cls, v: int) -> int:
        """Validate that an integer setting is strictly positive.

        Raises
        ------
        ValueError
            If ``v`` is zero or negative.
        """
        if v <= 0:
            raise ValueError(f"Value must be positive, got: {v}")
        return v

    @field_validator("target_ess")
    @classmethod
    def validate_fraction(cls, v: float) -> float:
        """Validate that value is in (0, 1].

        Raises
        ------
        ValueError
            If ``v`` lies outside (0, 1] or is NaN.
        """
        # A positive range test is used so that NaN, for which every
        # comparison is False, is rejected instead of slipping through.
        if not 0 < v <= 1:
            raise ValueError(f"Value must be in (0, 1], got: {v}")
        return v

    @field_validator("random_walk_sigma")
    @classmethod
    def validate_positive_float(cls, v: float) -> float:
        """Validate that value is strictly positive.

        Raises
        ------
        ValueError
            If ``v`` is zero, negative, or NaN.
        """
        # ``not v > 0`` (rather than ``v <= 0``) also rejects NaN.
        if not v > 0:
            raise ValueError(f"Value must be positive, got: {v}")
        return v
class SMCNUTSSamplerConfig(BaseSamplerConfig):
    """Configuration for Sequential Monte Carlo with NUTS kernel (EXPERIMENTAL).

    WARNING: This sampler is experimental and should be used with caution.

    Fields inherited from ``BaseSamplerConfig`` (``output_dir``,
    ``n_eos_samples``, ``log_prob_batch_size``) are accepted as well.

    Attributes
    ----------
    type : Literal["smc-nuts"]
        Sampler type identifier
    n_particles : int
        Number of particles (default: 10000)
    n_mcmc_steps : int
        Number of MCMC steps per tempering level (default: 1)
    target_ess : float
        Target effective sample size for adaptive tempering, in (0, 1]
        (default: 0.9)
    init_step_size : float
        Initial NUTS step size (default: 1e-2)
    mass_matrix_base : float
        Base value for diagonal mass matrix (default: 2e-1)
    mass_matrix_param_scales : dict[str, float]
        Per-parameter scaling for mass matrix (default: {})
    target_acceptance : float
        Target acceptance rate, in (0, 1] (default: 0.7)
    adaptation_rate : float
        Adaptation rate for step size tuning, in (0, 1] (default: 0.3)
    """

    type: Literal["smc-nuts"] = "smc-nuts"
    n_particles: int = 10000
    n_mcmc_steps: int = 1
    target_ess: float = 0.9
    init_step_size: float = 1e-2
    mass_matrix_base: float = 2e-1
    # default_factory gives every instance its own empty dict.
    mass_matrix_param_scales: dict[str, float] = Field(default_factory=dict)
    target_acceptance: float = 0.7
    adaptation_rate: float = 0.3

    @field_validator("n_particles", "n_mcmc_steps")
    @classmethod
    def validate_positive(cls, v: int) -> int:
        """Validate that an integer setting is strictly positive.

        Raises
        ------
        ValueError
            If ``v`` is zero or negative.
        """
        if v <= 0:
            raise ValueError(f"Value must be positive, got: {v}")
        return v

    @field_validator("target_ess", "target_acceptance", "adaptation_rate")
    @classmethod
    def validate_fraction(cls, v: float) -> float:
        """Validate that value is in (0, 1].

        Raises
        ------
        ValueError
            If ``v`` lies outside (0, 1] or is NaN.
        """
        # A positive range test is used so that NaN, for which every
        # comparison is False, is rejected instead of slipping through.
        if not 0 < v <= 1:
            raise ValueError(f"Value must be in (0, 1], got: {v}")
        return v

    @field_validator("init_step_size", "mass_matrix_base")
    @classmethod
    def validate_positive_float(cls, v: float) -> float:
        """Validate that value is strictly positive.

        Raises
        ------
        ValueError
            If ``v`` is zero, negative, or NaN.
        """
        # ``not v > 0`` (rather than ``v <= 0``) also rejects NaN.
        if not v > 0:
            raise ValueError(f"Value must be positive, got: {v}")
        return v
# Tagged union over every concrete sampler configuration. The value of each
# member's ``type`` field is the discriminator used to pick the model.
SamplerConfig = Annotated[
    Union[
        FlowMCSamplerConfig,
        BlackJAXNSAWConfig,
        SMCRandomWalkSamplerConfig,
        SMCNUTSSamplerConfig,
    ],
    Discriminator("type"),
]