# Source code for jesterTOV.inference.priors.parser

r"""Parser for .prior specification files in bilby-style Python format."""

from dataclasses import dataclass
from pathlib import Path
from typing import Union, Any, Dict
from jesterTOV.inference.base import (
    CombinePrior,
    Prior,
    UniformPrior,
    MultivariateGaussianPrior,
)
from jesterTOV.inference.base.prior import Fixed


@dataclass
class ParsedPrior:
    """Outcome of reading a ``.prior`` specification file.

    Parameters
    ----------
    prior : CombinePrior
        Joint prior over every parameter that will actually be sampled.
        Parameters pinned with ``Fixed(...)`` are not part of it.
    fixed_params : dict[str, float]
        Pinned values for every parameter declared with ``Fixed(...)`` in the
        prior file, keyed by parameter name.
    """

    # Joint prior over the sampled (non-fixed) parameters only.
    prior: CombinePrior
    # Parameter name -> pinned value, one entry per ``Fixed(...)`` declaration.
    fixed_params: dict[str, float]


def parse_prior_file(
    prior_file: Union[str, Path],
    nb_CSE: int = 0,
) -> ParsedPrior:
    """Parse a ``.prior`` file (Python format) and return a :class:`ParsedPrior`.

    The prior file should contain Python variable assignments in bilby-style
    format:

    .. code-block:: python

        K_sat = UniformPrior(150.0, 300.0, parameter_names=["K_sat"])
        Q_sat = UniformPrior(-500.0, 1100.0, parameter_names=["Q_sat"])
        nbreak = UniformPrior(0.16, 0.32, parameter_names=["nbreak"])

        # Pin a parameter to a fixed value (not sampled):
        lambda_BL = Fixed(0.0, parameter_names=["lambda_BL"])

    Parameters declared with :class:`~jesterTOV.inference.base.prior.Fixed`
    are collected into :attr:`ParsedPrior.fixed_params` and excluded from the
    sampling prior.

    The parser will automatically:

    - Include every sampled prior defined in the file, in definition order
      (NEP ``_sat``/``_sym`` parameters as well as any other parameter)
    - Include ``nbreak`` only if ``nb_CSE > 0``. A prior counts as ``nbreak``
      when it is bound to a variable named ``nbreak`` or its
      ``parameter_names`` is exactly ``["nbreak"]``.
    - Append ``2 * nb_CSE + 1`` CSE grid parameters if ``nb_CSE > 0``:
      ``n_CSE_i_u`` and ``cs2_CSE_i`` for each ``i`` in ``0 .. nb_CSE - 1``,
      plus a final ``cs2_CSE_{nb_CSE}`` for the grid point at ``nmax``. Each
      is a ``UniformPrior(0.0, 1.0)``.

    .. warning::

        The prior file is executed with :func:`exec` and has access to the
        Python builtins. Only parse prior files from trusted sources.

    Parameters
    ----------
    prior_file : str or Path
        Path to the ``.prior`` specification file (Python format).
    nb_CSE : int, optional
        Number of CSE grid points (0 for MetaModel only).

    Returns
    -------
    ParsedPrior
        Parsed prior with the sampled :class:`CombinePrior` and the
        ``fixed_params`` dict.

    Raises
    ------
    FileNotFoundError
        If the prior file does not exist.
    ValueError
        If executing the prior file fails, or no sampled priors remain.

    Examples
    --------
    >>> result = parse_prior_file("nep_standard.prior", nb_CSE=8)
    >>> # Assuming 8 NEP priors + nbreak: 8 + 1 + (2 * 8 + 1) = 26
    >>> result.prior.n_dim
    26
    >>> result.fixed_params  # no Fixed(...) declarations in this file
    {}
    """
    prior_file = Path(prior_file)
    if not prior_file.exists():
        raise FileNotFoundError(f"Prior specification file not found: {prior_file}")

    namespace = _execute_prior_file(prior_file)

    # The injected classes are never Prior *instances*, but exclude their
    # names explicitly in case the prior file rebinds them.
    excluded_keys = {
        "__builtins__",
        "UniformPrior",
        "MultivariateGaussianPrior",
        "Fixed",
    }

    fixed_params: dict[str, float] = {}
    prior_list: list[Prior] = []
    for var_name, value in namespace.items():
        if var_name in excluded_keys or not isinstance(value, Prior):
            continue
        if isinstance(value, Fixed):
            # Fixed entries are keyed by their declared parameter name, not the
            # Python variable they were bound to.
            fixed_params[value.parameter_names[0]] = value.value
        elif nb_CSE > 0 or not _is_nbreak_prior(var_name, value):
            # nbreak only exists when a CSE extension is attached.
            prior_list.append(value)

    prior_list.extend(_cse_grid_priors(nb_CSE))

    if not prior_list:
        raise ValueError(
            f"No sampled priors found in {prior_file}. "
            "Check file format and ensure at least one Prior object is defined. "
            "Note: Fixed parameters do not count as sampled priors."
        )

    return ParsedPrior(prior=CombinePrior(prior_list), fixed_params=fixed_params)


def _execute_prior_file(prior_file: Path) -> dict[str, Any]:
    """Execute ``prior_file`` and return the namespace it populated.

    Only the prior classes are injected explicitly. Because the namespace has
    no ``__builtins__`` key, :func:`exec` inserts the standard builtins.
    """
    prior_code = prior_file.read_text(encoding="utf-8")
    namespace: dict[str, Any] = {
        "UniformPrior": UniformPrior,
        "MultivariateGaussianPrior": MultivariateGaussianPrior,
        "Fixed": Fixed,
    }
    try:
        # SECURITY: exec runs arbitrary code from the file; prior files must be
        # trusted. Compiling with the real path makes SyntaxErrors and
        # tracebacks point at the .prior file instead of "<string>".
        exec(compile(prior_code, str(prior_file), "exec"), namespace)
    except Exception as e:
        raise ValueError(f"Error executing prior file {prior_file}: {e}") from e
    return namespace


def _is_nbreak_prior(var_name: str, prior: Prior) -> bool:
    """Return True if ``prior`` is the (single-parameter) ``nbreak`` prior."""
    return var_name == "nbreak" or list(
        getattr(prior, "parameter_names", ())
    ) == ["nbreak"]


def _cse_grid_priors(nb_CSE: int) -> list[Prior]:
    """Build the ``2 * nb_CSE + 1`` uniform CSE grid priors (empty if ``nb_CSE <= 0``)."""
    if nb_CSE <= 0:
        return []
    grid: list[Prior] = []
    for i in range(nb_CSE):
        grid.append(UniformPrior(0.0, 1.0, parameter_names=[f"n_CSE_{i}_u"]))
        grid.append(UniformPrior(0.0, 1.0, parameter_names=[f"cs2_CSE_{i}"]))
    # Final cs2 value for the grid point at nmax: nb_CSE points + 1 at nmax.
    grid.append(UniformPrior(0.0, 1.0, parameter_names=[f"cs2_CSE_{nb_CSE}"]))
    return grid