Source code for jesterTOV.inference.transforms.transform

r"""Unified transform for EOS parameters to neutron star observables."""

from typing import Any

import jax.numpy as jnp
from jaxtyping import Array, Float

from jesterTOV.eos.base import Interpolate_EOS_model
from jesterTOV.eos.metamodel import (
    MetaModel_EOS_model,
    MetaModel_with_CSE_EOS_model,
)
from jesterTOV.eos.spectral import SpectralDecomposition_EOS_model
from jesterTOV.tov.base import TOVSolverBase
from jesterTOV.tov.gr import GRTOVSolver
from jesterTOV.inference.base import NtoMTransform
from jesterTOV.inference.config.schema import (
    BaseEOSConfig,
    MetamodelEOSConfig,
    MetamodelCSEEOSConfig,
    SpectralEOSConfig,
    BaseTOVConfig,
    GRTOVConfig,
    AnisotropyTOVConfig,
)
from jesterTOV.inference.likelihoods.constraints import check_all_constraints
from jesterTOV.logging_config import get_logger

# Module-level logger, obtained from the project's logging helper under the
# name "jester"; used by JesterTransform to report its configuration.
logger = get_logger("jester")


class JesterTransform(NtoMTransform):
    """Transform EOS parameters to neutron star observables (M, R, Λ).

    This is the main transform class that combines an equation of state (EOS)
    model with a TOV solver to produce neutron star observables from
    microscopic EOS parameters.

    The transform can be created either by:

    1. Passing EOS and TOV solver instances directly
    2. Using the ``from_config()`` classmethod with configuration objects

    Parameters
    ----------
    eos : Interpolate_EOS_model
        EOS model instance (MetaModel, MetaModelCSE, Spectral, etc.)
    tov_solver : TOVSolverBase
        TOV solver instance (e.g. GRTOVSolver, AnisotropyTOVSolver)
    name_mapping : tuple[list[str], list[str]] | None
        Tuple of (input_names, output_names). If None, constructed from the
        EOS and TOV required parameters, excluding any fixed parameters.
    keep_names : list[str] | None
        Parameter names to preserve in output. If None, keeps all inputs.
    ndat_TOV : int
        Number of central pressure points for M-R-Λ curves (default: 100)
    min_nsat_TOV : float
        Minimum density for TOV integration in units of nsat (default: 0.75)
    fixed_params : dict[str, float] | None
        Parameters pinned to constant values. They are excluded from the
        default input names, merged into the parameters on every
        ``forward()`` call and copied into its output.
    **kwargs
        Additional parameters (accepted for compatibility only). They are
        ignored; a warning listing their names is logged.

    Attributes
    ----------
    eos : Interpolate_EOS_model
        The equation of state model
    tov_solver : TOVSolverBase
        The TOV equation solver
    eos_params : list[str]
        Parameters required by the EOS
    tov_params : list[str]
        Parameters required by the TOV solver
    keep_names : list[str]
        Parameters to preserve in output
    fixed_params : dict[str, float]
        Private copy of the pinned parameters (empty dict if none were given)
    ndat_TOV : int
        Number of central pressure points passed to the TOV solver
    min_nsat_TOV : float
        Minimum density passed to the TOV solver

    Examples
    --------
    >>> # Direct instantiation
    >>> from jesterTOV.eos.metamodel import MetaModel_EOS_model
    >>> from jesterTOV.tov.gr import GRTOVSolver
    >>> eos = MetaModel_EOS_model(crust_name="DH")
    >>> solver = GRTOVSolver()
    >>> transform = JesterTransform(eos=eos, tov_solver=solver)

    >>> # From configuration
    >>> from jesterTOV.inference.config.schema import (
    ...     MetamodelCSEEOSConfig,
    ...     GRTOVConfig,
    ... )
    >>> eos_config = MetamodelCSEEOSConfig(type="metamodel_cse", nb_CSE=8)
    >>> tov_config = GRTOVConfig(type="gr")
    >>> transform = JesterTransform.from_config(eos_config, tov_config)

    >>> # Transform parameters to observables
    >>> params = {"E_sat": -16.0, "K_sat": 230.0, ...}
    >>> result = transform.forward(params)
    >>> print(result["masses_EOS"])  # Neutron star masses in M☉
    """

    def __init__(
        self,
        eos: Interpolate_EOS_model,
        tov_solver: TOVSolverBase,
        name_mapping: tuple[list[str], list[str]] | None = None,
        keep_names: list[str] | None = None,
        ndat_TOV: int = 100,
        min_nsat_TOV: float = 0.75,
        fixed_params: dict[str, float] | None = None,
        **kwargs: Any,  # FIXME: remove kwargs argument in a future release
    ) -> None:
        """Store the EOS/TOV pair and set up the parameter name mapping.

        See the class docstring for the meaning of each argument.
        """
        self.eos = eos
        self.tov_solver = tov_solver
        self.ndat_TOV = ndat_TOV
        self.min_nsat_TOV = min_nsat_TOV

        # Keep a private copy so later changes to the caller's dict do not
        # leak into the transform.
        if fixed_params is not None:
            self.fixed_params: dict[str, float] = fixed_params.copy()
        else:
            self.fixed_params = {}

        # Extra keyword arguments are accepted only for backward
        # compatibility. Report them instead of dropping them silently so a
        # misspelt option (e.g. ``ndat_tov``) does not go unnoticed.
        if kwargs:
            logger.warning(
                "JesterTransform received unsupported keyword arguments "
                f"that will be ignored: {sorted(kwargs)}"
            )

        # Get required parameters from EOS and TOV solver
        self.eos_params = eos.get_required_parameters()
        self.tov_params = tov_solver.get_required_parameters()

        # Construct name mapping if not provided.
        # Fixed parameters are not part of the sampled input space, so they
        # are excluded from input_names.
        if name_mapping is None:
            sampled_eos = [p for p in self.eos_params if p not in self.fixed_params]
            sampled_tov = [p for p in self.tov_params if p not in self.fixed_params]
            input_names = sampled_eos + sampled_tov
            output_names = ["logpc_EOS", "masses_EOS", "radii_EOS", "Lambdas_EOS"]
            name_mapping = (input_names, output_names)

        # Set keep_names (default: all input names). A copy is stored so that
        # mutating ``self.keep_names`` can never alter the input-name list
        # handed to the parent class through ``name_mapping``.
        if keep_names is None:
            keep_names = name_mapping[0]
        self.keep_names = list(keep_names)

        # Initialize parent NtoMTransform
        super().__init__(name_mapping)

        # Set transform_func for parent class compatibility
        self.transform_func = self.construct_eos_and_solve_tov

        logger.info(
            f"Initialized JesterTransform: EOS={repr(eos)}, TOV={repr(tov_solver)}"
        )
        logger.debug(f" EOS parameters ({len(self.eos_params)}): {self.eos_params}")
        logger.debug(f" TOV parameters ({len(self.tov_params)}): {self.tov_params}")
        if self.fixed_params:
            logger.info(f" Fixed parameters: {self.fixed_params}")

    @classmethod
    def from_config(
        cls,
        eos_config: BaseEOSConfig,
        tov_config: BaseTOVConfig,
        keep_names: list[str] | None = None,
        max_nbreak_nsat: float | None = None,
        fixed_params: dict[str, float] | None = None,
    ) -> "JesterTransform":
        """Create transform from configuration objects.

        This factory method instantiates the appropriate EOS and TOV solver
        based on the separate configurations, then creates the transform.

        Parameters
        ----------
        eos_config : BaseEOSConfig
            EOS configuration (MetamodelEOSConfig, MetamodelCSEEOSConfig, or
            SpectralEOSConfig)
        tov_config : BaseTOVConfig
            TOV solver configuration (GRTOVConfig or AnisotropyTOVConfig).
            Its ``ndat_TOV`` and ``min_nsat_TOV`` fields are forwarded to the
            transform.
        keep_names : list[str] | None
            Parameters to preserve in output
        max_nbreak_nsat : float | None
            Maximum nbreak value (for MetaModelCSE optimization). If None,
            the ``max_nbreak_nsat`` attribute of ``eos_config`` is used when
            it exists.
        fixed_params : dict[str, float] | None
            Parameters pinned to constant values, excluded from the sampling
            space but injected into every ``forward()`` call.

        Returns
        -------
        JesterTransform
            Configured transform instance

        Raises
        ------
        ValueError
            If the EOS or TOV configuration class is not recognised
        """
        # If max_nbreak_nsat is not passed, fall back to the value from the
        # config (None when the config has no such attribute).
        if max_nbreak_nsat is not None:
            effective_max = max_nbreak_nsat
        else:
            effective_max = getattr(eos_config, "max_nbreak_nsat", None)

        # Instantiate EOS and TOV solver from their respective configs
        eos = cls._create_eos(eos_config, effective_max)
        tov_solver = cls._create_tov_solver(tov_config)

        # Create transform
        return cls(
            eos=eos,
            tov_solver=tov_solver,
            keep_names=keep_names,
            ndat_TOV=tov_config.ndat_TOV,
            min_nsat_TOV=tov_config.min_nsat_TOV,
            fixed_params=fixed_params,
        )

    @staticmethod
    def _create_eos(
        config: BaseEOSConfig, max_nbreak_nsat: float | None = None
    ) -> Interpolate_EOS_model:
        """Create EOS instance from config.

        Parameters
        ----------
        config : BaseEOSConfig
            EOS configuration object; the concrete class selects the model.
        max_nbreak_nsat : float | None
            Maximum nbreak value; only used for MetamodelCSEEOSConfig.

        Returns
        -------
        Interpolate_EOS_model
            EOS instance

        Raises
        ------
        ValueError
            If the class of ``config`` is not recognised
        """
        # NOTE(review): nsat is hard-coded to 0.16 for both metamodel
        # variants; presumably the saturation density in fm^-3 -- confirm.
        if isinstance(config, MetamodelEOSConfig):
            return MetaModel_EOS_model(
                nsat=0.16,
                nmin_MM_nsat=config.nmin_MM_nsat,
                nmax_nsat=config.nmax_nsat,
                ndat=config.ndat_metamodel,
                crust_name=config.crust_name,
            )
        elif isinstance(config, MetamodelCSEEOSConfig):
            return MetaModel_with_CSE_EOS_model(
                nsat=0.16,
                nmin_MM_nsat=config.nmin_MM_nsat,
                nmax_nsat=config.nmax_nsat,
                max_nbreak_nsat=max_nbreak_nsat,
                ndat_metamodel=config.ndat_metamodel,
                ndat_CSE=config.ndat_CSE,
                nb_CSE=config.nb_CSE,
                crust_name=config.crust_name,
            )
        elif isinstance(config, SpectralEOSConfig):
            return SpectralDecomposition_EOS_model(
                crust_name=config.crust_name,
                n_points_high=config.n_points_high,
                reparametrized=config.reparametrized,
                sigma_scale=config.sigma_scale,
            )
        else:
            raise ValueError(f"Unknown EOS config type: {type(config).__name__}")

    @staticmethod
    def _create_tov_solver(config: BaseTOVConfig) -> TOVSolverBase:
        """Create TOV solver instance from config.

        Parameters
        ----------
        config : BaseTOVConfig
            TOV configuration object; the concrete class selects the solver.

        Returns
        -------
        TOVSolverBase
            TOV solver instance

        Raises
        ------
        ValueError
            If the class of ``config`` is not recognised
        """
        if isinstance(config, GRTOVConfig):
            return GRTOVSolver()
        elif isinstance(config, AnisotropyTOVConfig):
            # Imported lazily, only when the anisotropic solver is requested.
            from jesterTOV.tov.anisotropy import AnisotropyTOVSolver

            return AnisotropyTOVSolver()
        else:
            raise ValueError(f"Unknown TOV solver type: {type(config).__name__}")

    def get_eos_type(self) -> str:
        """Return EOS type identifier.

        Returns
        -------
        str
            ``repr(self.eos)``. The EOS classes are expected to render this
            as their class name (e.g., 'MetaModel_EOS_model'); that depends
            on their ``__repr__`` and cannot be verified from this module.
        """
        return repr(self.eos)

    def get_parameter_names(self) -> list[str]:
        """Return combined list of sampled EOS and TOV parameter names.

        Fixed parameters (those pinned via ``Fixed(...)`` in the prior file)
        are excluded — they are not part of the sampling space.

        Returns
        -------
        list[str]
            Sampled parameter names required by this transform, EOS
            parameters first, then TOV parameters
        """
        all_params = self.eos_params + self.tov_params
        return [p for p in all_params if p not in self.fixed_params]

    def construct_eos_and_solve_tov(
        self,
        params: dict[str, Float],
    ) -> dict[str, Float | Float[Array, " n"]]:
        """Construct EOS from parameters and solve TOV equations.

        This is the core transformation method that:

        1. Constructs EOS from parameters
        2. Solves TOV equations for M-R-Λ family
        3. Returns observables with constraint checking

        Parameters
        ----------
        params : dict[str, Float]
            Input parameters (EOS + TOV parameters)

        Returns
        -------
        dict[str, Float | Float[Array, " n"]]
            Dictionary containing:

            - masses_EOS : Neutron star masses [M☉]
            - radii_EOS : Neutron star radii [km]
            - Lambdas_EOS : Tidal deformabilities
            - logpc_EOS : Log10 central pressures
            - n, p, h, e, dloge_dlogp, cs2 : EOS quantities
            - n_tov_failures, n_causality_violations,
              n_stability_violations, n_pressure_violations :
              Constraint violation counts
            - any extra constraint entries reported by the EOS
        """
        # Construct EOS from parameters
        # EOS handles all parameter preprocessing (e.g., CSE conversion)
        eos_data = self.eos.construct_eos(params)

        # Extract TOV-specific parameters from the combined prior dict
        tov_params = self.tov_solver.fetch_params(params)

        # Solve TOV equations to get M-R-Λ family
        family_data = self.tov_solver.construct_family(
            eos_data,
            ndat=self.ndat_TOV,
            min_nsat=self.min_nsat_TOV,
            tov_params=tov_params,
        )

        # Create standardized return dictionary with constraint checking
        return self._create_return_dict(
            logpc_EOS=family_data.log10pcs,
            masses_EOS=family_data.masses,
            radii_EOS=family_data.radii,
            Lambdas_EOS=family_data.lambdas,
            ns=eos_data.ns,
            ps=eos_data.ps,
            hs=eos_data.hs,
            es=eos_data.es,
            dloge_dlogps=eos_data.dloge_dlogps,
            cs2=eos_data.cs2,
            extra_constraints=eos_data.extra_constraints,
        )

    def _create_return_dict(
        self,
        logpc_EOS: Float[Array, " n"],
        masses_EOS: Float[Array, " n"],
        radii_EOS: Float[Array, " n"],
        Lambdas_EOS: Float[Array, " n"],
        ns: Float[Array, " n"],
        ps: Float[Array, " n"],
        hs: Float[Array, " n"],
        es: Float[Array, " n"],
        dloge_dlogps: Float[Array, " n"],
        cs2: Float[Array, " n"],
        extra_constraints: dict[str, Float | Float[Array, " n"]] | None = None,
    ) -> dict[str, Float | Float[Array, " n"]]:
        """Create standardized return dictionary with constraint checking.

        This method checks for physical constraint violations (NaN,
        causality, etc.) and adds violation counts to the output. It also
        replaces NaN and ±inf entries of the TOV outputs with 0.0 to prevent
        propagation through the likelihood evaluation. The EOS quantities
        are passed through unchanged.

        Parameters
        ----------
        logpc_EOS : Float[Array, " n"]
            Log10 of central pressures
        masses_EOS : Float[Array, " n"]
            Neutron star masses
        radii_EOS : Float[Array, " n"]
            Neutron star radii
        Lambdas_EOS : Float[Array, " n"]
            Tidal deformabilities
        ns : Float[Array, " n"]
            Number densities
        ps : Float[Array, " n"]
            Pressures
        hs : Float[Array, " n"]
            Enthalpies
        es : Float[Array, " n"]
            Energy densities
        dloge_dlogps : Float[Array, " n"]
            Logarithmic derivative d(ln ε)/d(ln p)
        cs2 : Float[Array, " n"]
            Sound speeds squared
        extra_constraints : dict | None
            Additional constraint violations from EOS; merged into the
            result last, so its keys override any identically named entry

        Returns
        -------
        dict[str, Float | Float[Array, " n"]]
            Complete output dictionary with cleaned values and violation
            counts
        """
        # Check all constraints BEFORE cleaning, otherwise the NaN entries
        # the check is meant to see would already be gone.
        constraints = check_all_constraints(masses_EOS, radii_EOS, Lambdas_EOS, cs2, ps)

        # TOV solution with NaN / ±inf replaced by 0.0
        tov_outputs = (
            ("logpc_EOS", logpc_EOS),
            ("masses_EOS", masses_EOS),
            ("radii_EOS", radii_EOS),
            ("Lambdas_EOS", Lambdas_EOS),
        )
        result = {
            key: jnp.nan_to_num(values, nan=0.0, posinf=0.0, neginf=0.0)
            for key, values in tov_outputs
        }

        # EOS quantities
        result.update(
            {
                "n": ns,
                "p": ps,
                "h": hs,
                "e": es,
                "dloge_dlogp": dloge_dlogps,
                "cs2": cs2,
            }
        )

        # Constraint violation counts (scalars for JAX compatibility)
        for count_key in (
            "n_tov_failures",
            "n_causality_violations",
            "n_stability_violations",
            "n_pressure_violations",
        ):
            result[count_key] = constraints[count_key]

        # Add any extra constraint violations from EOS
        if extra_constraints is not None:
            result.update(extra_constraints)

        return result

    def forward(self, x: dict[str, Float]) -> dict[str, Float]:
        """Transform parameters and preserve keep_names.

        This overrides NtoMTransform.forward() to:

        1. Merge fixed parameters into ``x`` before the EOS/TOV pipeline runs.
        2. Preserve parameters specified in ``self.keep_names``.
        3. Add fixed parameters to the output so they appear in the result.

        Parameters
        ----------
        x : dict[str, Float]
            Input parameter dictionary (sampled parameters only). It is not
            mutated; if a key is also a fixed parameter, the fixed value
            wins.

        Returns
        -------
        dict[str, Float]
            Transformed parameters with keep_names and fixed_params included
        """
        # Inject fixed parameters so the EOS/TOV pipeline receives them.
        # Create a new dict to avoid mutating the caller's input.
        if self.fixed_params:
            x = {**x, **self.fixed_params}

        # Save parameters that should be kept (names missing from x are
        # skipped rather than raising).
        kept_params = {name: x[name] for name in self.keep_names if name in x}

        # Call parent forward() for standard transformation
        result = super().forward(x)

        # Add back the kept parameters
        result.update(kept_params)

        # Add fixed parameters to the output for traceability
        if self.fixed_params:
            result.update(self.fixed_params)

        return result