Source code for jesterTOV.inference.config.schema

r"""Pydantic models for inference configuration validation.

IMPORTANT: When you modify these schemas, regenerate the YAML reference documentation:

    uv run python -m jesterTOV.inference.config.generate_yaml_reference

TODO: make this automatic in CI/CD, so this note can be removed and user is not burdened with it

This ensures the user documentation stays in sync with the actual validation rules.

Schema organisation
-------------------
Config models are split into domain-specific sub-modules under ``schemas/``:

- ``schemas/eos.py``          EOS configuration (BaseEOSConfig + concrete types)
- ``schemas/tov.py``          TOV solver configuration (BaseTOVConfig + concrete types)
- ``schemas/likelihoods.py``  Likelihood configurations
- ``schemas/samplers.py``     Sampler configurations

This module assembles them into the top-level :class:`InferenceConfig` and re-exports
every name so that existing imports (``from .schema import ...``) continue to work.
"""

from pydantic import Field, field_validator, ConfigDict

from .schemas._base import JesterBaseModel

# EOS schemas
from .schemas.eos import (
    BaseEOSConfig,
    BaseMetamodelEOSConfig,
    MetamodelEOSConfig,
    MetamodelCSEEOSConfig,
    SpectralEOSConfig,
    EOSConfig,
)

# TOV schemas
from .schemas.tov import (
    BaseTOVConfig,
    GRTOVConfig,
    AnisotropyTOVConfig,
    TOVConfig,
)

# Likelihood schemas
from .schemas.likelihoods import (
    BaseLikelihoodConfig,
    GWEventConfig,
    GWLikelihoodConfig,
    GWResampledLikelihoodConfig,
    NICERLikelihoodConfig,
    NICERKDELikelihoodConfig,
    RadioLikelihoodConfig,
    ChiEFTLikelihoodConfig,
    EOSConstraintsLikelihoodConfig,
    TOVConstraintsLikelihoodConfig,
    GammaConstraintsLikelihoodConfig,
    DeprecatedConstraintsLikelihoodConfig,
    REXLikelihoodConfig,
    ZeroLikelihoodConfig,
    LikelihoodConfig,
)

# Sampler schemas
from .schemas.samplers import (
    BaseSamplerConfig,
    FlowMCSamplerConfig,
    BlackJAXNSAWConfig,
    SMCRandomWalkSamplerConfig,
    SMCNUTSSamplerConfig,
    SamplerConfig,
)


# ============================================================================
# Prior Configuration
# ============================================================================


[docs] class PriorConfig(JesterBaseModel): """Configuration for priors. Attributes ---------- specification_file : str Path to .prior file specifying prior distributions """ model_config = ConfigDict(extra="forbid") specification_file: str
[docs] @field_validator("specification_file") @classmethod def validate_file_extension(cls, v: str) -> str: """Validate that specification file has .prior extension.""" if not v.endswith(".prior"): raise ValueError( f"Prior specification file must have .prior extension, got: {v}" ) return v
# ============================================================================ # Postprocessing Configuration # ============================================================================
[docs] class PostprocessingConfig(JesterBaseModel): r"""Configuration for postprocessing plots. Attributes ---------- enabled : bool Whether to run postprocessing after inference (default: True) make_cornerplot : bool Generate cornerplot of EOS parameters (default: True) make_massradius : bool Generate mass-radius plot (default: True) make_masslambda : bool Generate mass-Lambda plot (default: True) make_pressuredensity : bool Generate pressure-density plot (default: True) make_histograms : bool Generate parameter histograms (default: True) make_cs2 : bool Generate cs2-density plot (default: True) prior_dir : str | None Directory containing prior samples for comparison (default: None) injection_eos_path : str | None Path to NPZ file containing injection EOS data for plotting (default: None). The NPZ file should contain arrays in geometric units: - masses_EOS: Solar masses :math:`M_{\odot}` - radii_EOS: :math:`\mathrm{km}` - Lambda_EOS: dimensionless tidal deformability - n: geometric units :math:`m^{-2}` - p: geometric units :math:`m^{-2}` - e: geometric units :math:`m^{-2}` - cs2: dimensionless This matches LALSuite EOS format and JESTER HDF5 output. Missing keys handled gracefully. """ model_config = ConfigDict(extra="forbid") enabled: bool = True make_cornerplot: bool = True make_massradius: bool = True make_masslambda: bool = True make_pressuredensity: bool = True make_histograms: bool = True make_cs2: bool = True prior_dir: str | None = None injection_eos_path: str | None = None
# ============================================================================ # Top-level InferenceConfig # ============================================================================
[docs] class InferenceConfig(JesterBaseModel): """Top-level inference configuration. Attributes ---------- seed : int Random seed for reproducibility eos : EOSConfig EOS configuration (discriminated union by type) tov : TOVConfig TOV solver configuration (discriminated union by type) prior : PriorConfig Prior configuration likelihoods : list[LikelihoodConfig] List of likelihood configurations sampler : SamplerConfig Sampler configuration postprocessing : PostprocessingConfig Postprocessing configuration data_paths : dict[str, str] Override default data paths dry_run : bool Setup everything but don't run sampler (default: False) validate_only : bool Only validate configuration, don't run inference (default: False) debug_nans : bool Enable JAX NaN debugging for catching numerical issues (default: False) """ model_config = ConfigDict(extra="forbid") seed: int = 43 eos: EOSConfig tov: TOVConfig prior: PriorConfig likelihoods: list[LikelihoodConfig] sampler: SamplerConfig postprocessing: PostprocessingConfig = Field(default_factory=PostprocessingConfig) data_paths: dict[str, str] = Field(default_factory=dict) dry_run: bool = False validate_only: bool = False debug_nans: bool = Field( default=False, description="Enable JAX NaN debugging for catching numerical issues during inference", )
[docs] @field_validator("likelihoods") @classmethod def validate_likelihoods(cls, v: list[LikelihoodConfig]) -> list[LikelihoodConfig]: """Validate that at least one likelihood is enabled.""" if not any(lk.enabled for lk in v): raise ValueError("At least one likelihood must be enabled") return v
[docs] @field_validator("seed") @classmethod def validate_seed(cls, v: int) -> int: """Validate that seed is non-negative.""" if v < 0: raise ValueError(f"Seed must be non-negative, got: {v}") return v
__all__ = [ # EOS "BaseEOSConfig", "BaseMetamodelEOSConfig", "MetamodelEOSConfig", "MetamodelCSEEOSConfig", "SpectralEOSConfig", "EOSConfig", # TOV "BaseTOVConfig", "GRTOVConfig", "AnisotropyTOVConfig", "TOVConfig", # Likelihoods "BaseLikelihoodConfig", "GWEventConfig", "GWLikelihoodConfig", "GWResampledLikelihoodConfig", "NICERLikelihoodConfig", "NICERKDELikelihoodConfig", "RadioLikelihoodConfig", "ChiEFTLikelihoodConfig", "EOSConstraintsLikelihoodConfig", "TOVConstraintsLikelihoodConfig", "GammaConstraintsLikelihoodConfig", "DeprecatedConstraintsLikelihoodConfig", "REXLikelihoodConfig", "ZeroLikelihoodConfig", "LikelihoodConfig", # Samplers "BaseSamplerConfig", "FlowMCSamplerConfig", "BlackJAXNSAWConfig", "SMCRandomWalkSamplerConfig", "SMCNUTSSamplerConfig", "SamplerConfig", # Other "PriorConfig", "PostprocessingConfig", "InferenceConfig", ]