# Source code for jesterTOV.inference.run_inference

#!/usr/bin/env python
r"""
Modular inference script for jesterTOV
"""

import os
import time
import warnings
import json
from pathlib import Path
import numpy as np
import jax
import jax.numpy as jnp

# Enable 64-bit precision by default.
# NOTE: this runs before the package-relative imports below, presumably so the
# setting is already in effect while those modules load — confirm before moving.
jax.config.update("jax_enable_x64", True)

from .config.parser import load_config
from .config.schema import (
    InferenceConfig,
    MetamodelCSEEOSConfig,
    BaseMetamodelEOSConfig,
    GWLikelihoodConfig,
    GWEventConfig,
)
from .priors.parser import parse_prior_file
from .base.prior import CombinePrior
from .base.likelihood import LikelihoodBase
from .transforms import JesterTransform
from .likelihoods.factory import create_combined_likelihood
from .samplers import create_sampler, JesterSampler
from .result import InferenceResult
from jesterTOV.logging_config import get_logger

# Set up the module-level logger shared by the functions in this module.
logger = get_logger("jester")


def determine_keep_names(
    config: InferenceConfig,
    prior: CombinePrior,
    fixed_params: dict[str, float] | None = None,
) -> list[str] | None:
    """
    Decide which prior parameters the transform must pass through unchanged.

    Some likelihoods read raw prior parameters in addition to the transformed
    EOS/TOV quantities. This function inspects the enabled likelihoods and
    returns the names that must survive the transform. Currently the only such
    case is the ``chieft`` likelihood combined with a ``metamodel_cse`` EOS,
    which needs ``nbreak`` for CSE grid stitching.

    Parameters
    ----------
    config : InferenceConfig
        Configuration object with likelihood settings.
    prior : CombinePrior
        Prior over the sampled parameters only.
    fixed_params : dict[str, float] | None
        Parameters pinned via ``Fixed(...)`` in the prior file. They are not in
        ``prior.parameter_names`` but still satisfy requirements.

    Returns
    -------
    list[str] | None
        Names to keep, or ``None`` when no special handling is needed.

    Raises
    ------
    ValueError
        If ``nbreak`` is required but neither sampled nor fixed.
    """
    pinned = fixed_params or {}

    uses_chieft = any(
        lk.enabled and lk.type == "chieft" for lk in config.likelihoods
    )
    # nbreak only matters for ChiEFT together with the metamodel_cse EOS.
    if not (uses_chieft and isinstance(config.eos, MetamodelCSEEOSConfig)):
        return None

    nbreak_sampled = "nbreak" in prior.parameter_names
    if not nbreak_sampled and "nbreak" not in pinned:
        raise ValueError(
            "ChiEFT likelihood is enabled with metamodel_cse but 'nbreak' parameter is not in the prior. "
            "Please add 'nbreak' to your prior specification file. "
            f"Current prior parameters: {prior.parameter_names}"
        )

    if not nbreak_sampled:
        # Fixed parameters are injected into the transform output already.
        logger.info(
            "ChiEFT likelihood enabled: 'nbreak' is fixed, will appear in transform output via fixed_params"
        )
        return None

    logger.info(
        "ChiEFT likelihood enabled: 'nbreak' parameter will be preserved in transform output"
    )
    return ["nbreak"]
def setup_prior(config: InferenceConfig) -> tuple[CombinePrior, dict[str, float]]:
    """
    Build the sampling prior described by the configuration.

    The prior specification file is parsed (with ``nb_CSE`` taken from a
    ``metamodel_cse`` EOS config, otherwise 0). When a likelihood that needs
    per-evaluation randomness (``gw_resampled`` or ``nicer_kde``) is enabled,
    an extra ``_random_key`` uniform prior on ``[0, 2**32 - 1]`` is appended.

    Parameters
    ----------
    config : InferenceConfig
        Configuration object.

    Returns
    -------
    prior : CombinePrior
        Combined prior over sampled parameters only.
    fixed_params : dict[str, float]
        Parameters pinned to constant values via ``Fixed(...)`` in the prior
        file; these are excluded from the sampling space.
    """
    from .base.prior import UniformPrior, CombinePrior

    eos_config = config.eos
    if isinstance(eos_config, MetamodelCSEEOSConfig):
        nb_CSE = eos_config.nb_CSE
    else:
        nb_CSE = 0

    # The default `gw` and `nicer` likelihoods pre-generate the masses they
    # evaluate on, so only the resampling variants need a random key.
    needs_random_key = any(
        lk.enabled and lk.type in ["gw_resampled", "nicer_kde"]
        for lk in config.likelihoods
    )

    parsed = parse_prior_file(config.prior.specification_file, nb_CSE=nb_CSE)
    prior, fixed_params = parsed.prior, parsed.fixed_params

    if fixed_params:
        logger.info(f"Fixed parameters found in prior file: {fixed_params}")

    if not needs_random_key:
        return prior, fixed_params

    logger.info("Adding _random_key prior for likelihood sampling")
    key_prior = UniformPrior(
        0.0, float(2**32 - 1), parameter_names=["_random_key"]
    )
    # Append to the flat component list instead of nesting CombinePriors.
    prior = CombinePrior(prior.base_prior + [key_prior])
    return prior, fixed_params
def _nbreak_prior_upper_bound_nsat(prior: CombinePrior | None) -> float | None:
    """Return the upper bound of the ``nbreak`` prior in units of n_sat, if usable.

    Only a ``UniformPrior`` over ``nbreak`` provides a bound. ``None`` is
    returned when there is no prior, when ``nbreak`` is not sampled, or when
    its prior is of another type (a warning is logged in the last case).
    """
    if prior is None:
        return None

    from .base.prior import UniformPrior

    nsat = 0.16  # saturation density in fm^-3
    for param_prior in prior.base_prior:
        if "nbreak" not in param_prior.parameter_names:
            continue
        if isinstance(param_prior, UniformPrior):
            return param_prior.xmax / nsat
        # Make it visible why max_nbreak_nsat is not derived from the prior.
        logger.warning(
            f"nbreak prior is a {type(param_prior).__name__}, not a UniformPrior; "
            "max_nbreak_nsat cannot be derived from its upper bound."
        )
        return None
    return None


def setup_transform(
    config: InferenceConfig,
    prior: CombinePrior | None = None,
    keep_names: list[str] | None = None,
    fixed_params: dict[str, float] | None = None,
) -> JesterTransform:
    """
    Setup transform from configuration.

    For a ``metamodel_cse`` EOS, ``max_nbreak_nsat`` is resolved from the EOS
    config and/or the upper bound of a uniform ``nbreak`` prior. If both are
    given they must agree to ``rtol=1e-3``.

    Parameters
    ----------
    config : InferenceConfig
        Configuration object.
    prior : CombinePrior, optional
        Prior used to derive parameter bounds (e.g. max nbreak) and to check
        that every parameter required by the transform is sampled.
    keep_names : list[str], optional
        Parameter names to keep in the transformed output.
    fixed_params : dict[str, float], optional
        Parameters pinned to constant values, excluded from the sampling space.

    Returns
    -------
    JesterTransform
        Transform instance.

    Raises
    ------
    ValueError
        If ``eos.max_nbreak_nsat`` disagrees with the nbreak prior bound, or if
        the prior lacks parameters required by the transform.
    """
    _fixed_params = fixed_params or {}

    # TODO: in a future version, need to improve this...
    max_nbreak_nsat = None
    if isinstance(config.eos, MetamodelCSEEOSConfig):
        config_value = config.eos.max_nbreak_nsat
        prior_value = _nbreak_prior_upper_bound_nsat(prior)

        if config_value is not None and prior_value is not None:
            if not np.isclose(config_value, prior_value, rtol=1e-3):
                raise ValueError(
                    f"eos.max_nbreak_nsat in config ({config_value:.4f} n_sat) does not "
                    f"match the upper bound of the nbreak prior ({prior_value:.4f} n_sat). "
                    "Either remove max_nbreak_nsat from the eos config to derive it "
                    "automatically from the nbreak prior, or update the prior bound to match."
                )
            max_nbreak_nsat = config_value
        elif prior_value is not None:
            max_nbreak_nsat = prior_value
            logger.info(
                f"Derived max_nbreak from nbreak prior: {prior_value:.4f} n_sat"
            )
        elif config_value is not None:
            max_nbreak_nsat = config_value
            logger.info(
                f"Using max_nbreak_nsat from eos config: {config_value:.4f} n_sat"
            )

    transform = JesterTransform.from_config(
        eos_config=config.eos,
        tov_config=config.tov,
        keep_names=keep_names,
        max_nbreak_nsat=max_nbreak_nsat,
        fixed_params=_fixed_params if _fixed_params else None,
    )

    if prior is None:
        return transform

    # get_parameter_names() already excludes fixed params, so only the sampled
    # prior needs to cover the remaining required parameters.
    required_params = set(transform.get_parameter_names())
    prior_params = set(prior.parameter_names)

    missing_params = required_params - prior_params
    if missing_params:
        eos_name = transform.get_eos_type()
        # TODO: add repr to TOV solver for get_tov_type() so we can make this similar to EOS
        tov_name = repr(transform.tov_solver)
        raise ValueError(
            f"Transform with EOS = {eos_name} and TOV = {tov_name} is missing "
            f"params = {sorted(missing_params)} from the prior file"
        )

    # Parameters that are intentionally consumed outside the transform are not
    # "unused": `_random_key` is added by setup_prior for stochastic likelihoods,
    # and keep_names are passed through to likelihoods on purpose.
    consumed_elsewhere = {"_random_key", *(keep_names or [])}
    unused_params = prior_params - required_params - consumed_elsewhere
    if unused_params:
        # Informational only, not an error.
        logger.warning(
            f"Prior contains unused parameters: {sorted(unused_params)}. "
            f"These will be preserved but not used by the transform."
        )

    return transform
def _needs_flow_training(event: GWEventConfig) -> bool:
    """Return True if *event* provides raw posterior samples (bilby or NPZ mode)."""
    return event.from_bilby_result is not None or event.from_npz_file is not None


def _resolve_npz_file(event: GWEventConfig) -> Path:
    """Validate and return the user-provided NPZ path of an NPZ-mode event.

    Raises ValueError if the file does not exist or lacks a ``.npz`` suffix.
    """
    npz_path = Path(event.from_npz_file)  # type: ignore[arg-type]
    if not npz_path.exists():
        logger.error(
            f"GW event '{event.name}': from_npz_file path does not exist: {npz_path}"
        )
        raise ValueError(
            f"GW event '{event.name}': from_npz_file '{npz_path}' does not exist."
        )
    if npz_path.suffix.lower() != ".npz":
        logger.error(
            f"GW event '{event.name}': from_npz_file does not have a .npz extension: {npz_path}"
        )
        raise ValueError(
            f"GW event '{event.name}': from_npz_file '{npz_path}' does not have a .npz extension."
        )
    logger.info(
        f"Using provided NPZ posterior samples for '{event.name}' " f"at {npz_path}"
    )
    return npz_path


def _ensure_bilby_npz(event: GWEventConfig, npz_path: Path) -> None:
    """Extract the bilby posterior of *event* into *npz_path* unless already present."""
    from .flows.bilby_extract import extract_gw_posterior_from_bilby

    if npz_path.exists():
        logger.info(
            f"Using cached posterior samples for '{event.name}' at {npz_path}"
        )
        return
    logger.info(
        f"Extracting GW posterior for '{event.name}' from "
        f"{event.from_bilby_result}"
    )
    extract_gw_posterior_from_bilby(
        bilby_result_file=event.from_bilby_result,
        output_file=str(npz_path),
    )


def _build_flow_training_config(event: GWEventConfig, npz_path: Path, nf_model_dir: Path):
    """Build the FlowTrainingConfig for *event*.

    Uses the event's optional ``flow_config`` YAML as a base (e.g. custom
    n_epochs or batch_size), then overrides the input file, parameter names,
    and output directory.
    """
    from .flows.config import FlowTrainingConfig

    fixed_fields = {
        "posterior_file": str(npz_path),
        "parameter_names": [
            "mass_1_source",
            "mass_2_source",
            "lambda_1",
            "lambda_2",
        ],
        "output_dir": str(nf_model_dir),
    }
    if event.flow_config:
        base = FlowTrainingConfig.from_yaml(event.flow_config)
        return base.model_copy(update=fixed_fields)
    return FlowTrainingConfig(**fixed_fields)


def _flow_cache_key(ft_config, source_file: Path) -> str:
    """Return the cache key for a trained flow.

    The key covers the full training config and the resolved path of the
    posterior *source* file. In bilby mode ``posterior_file`` always points to
    the internal cache copy, so without the source path a changed
    ``from_bilby_result`` would go unnoticed. Keys written by older versions
    (config-only hash) will not match, so those flows are retrained once.
    """
    import hashlib

    payload = json.dumps(
        {
            "flow_config": ft_config.model_dump_json(),
            "source_file": str(Path(source_file).resolve()),
        },
        sort_keys=True,
    )
    return hashlib.sha256(payload.encode()).hexdigest()


def _read_stored_cache_key(config_hash_file: Path) -> str | None:
    """Return the stored cache key, or None if missing or unreadable."""
    if not config_hash_file.exists():
        return None
    try:
        return json.loads(config_hash_file.read_text()).get("hash")
    except (OSError, ValueError, AttributeError):
        # A corrupt or unexpected hash file just forces a retrain.
        return None


def _reset_flow_cache(nf_model_dir: Path) -> None:
    """Delete and recreate the flow cache directory of one event."""
    import shutil

    shutil.rmtree(nf_model_dir)
    nf_model_dir.mkdir(parents=True, exist_ok=True)


def _prepare_event_flow(event: GWEventConfig, outdir: Path) -> Path:
    """Make sure a trained flow for *event* exists and return its directory."""
    from .flows.train_flow import train_flow_from_config

    # 1. Cache directory for this event
    nf_model_dir = outdir / "gw_flow_cache" / event.name
    nf_model_dir.mkdir(parents=True, exist_ok=True)

    # 2. Where the training samples live, and what they were derived from
    bilby_mode = event.from_bilby_result is not None
    if bilby_mode:
        npz_path = nf_model_dir / "posterior_samples.npz"
        source_file = Path(event.from_bilby_result)
    else:
        npz_path = _resolve_npz_file(event)
        source_file = npz_path

    # 3. Training config and its cache key
    ft_config = _build_flow_training_config(event, npz_path, nf_model_dir)
    flow_weights = nf_model_dir / "flow_weights.eqx"
    config_hash_file = nf_model_dir / "flow_config_hash.json"
    current_key = _flow_cache_key(ft_config, source_file)

    # 4. Decide whether the cached flow can be reused
    should_train = True
    if flow_weights.exists() and not event.retrain_flow:
        if _read_stored_cache_key(config_hash_file) == current_key:
            logger.info(f"Reusing cached flow for '{event.name}' at {nf_model_dir}")
            should_train = False
        else:
            logger.warning(f"Flow config changed for '{event.name}' → retraining")
            _reset_flow_cache(nf_model_dir)
    elif event.retrain_flow and flow_weights.exists():
        logger.info(f"retrain_flow=True for '{event.name}' → retraining")
        _reset_flow_cache(nf_model_dir)

    # 5. Bilby mode: (re-)extract the samples. An NPZ left over without a
    #    matching trained flow cannot be tied to the current source, so it is
    #    re-extracted whenever training is needed.
    if bilby_mode:
        if should_train and npz_path.exists():
            npz_path.unlink()
        _ensure_bilby_npz(event, npz_path)

    # 6. Train and record the cache key only after successful training
    if should_train:
        logger.info(f"Training normalizing flow for '{event.name}'...")
        train_flow_from_config(ft_config)
        config_hash_file.write_text(json.dumps({"hash": current_key}))

    return nf_model_dir


def prepare_gw_flows(config: InferenceConfig, outdir: Path) -> InferenceConfig:
    """Prepare normalizing flows for GW events that need flow training.

    For each enabled :class:`~jesterTOV.inference.config.schema.GWLikelihoodConfig`
    with events in bilby or NPZ mode, this function:

    1. Resolves the ``nf_model_dir`` to ``{outdir}/gw_flow_cache/{event.name}``.
    2. Obtains posterior samples — either by extracting from a bilby HDF5
       (``from_bilby_result``) or by using an existing NPZ directly
       (``from_npz_file``).
    3. Trains a normalizing flow, reusing a cached one only when both the
       training config and the posterior source file path are unchanged.
    4. Replaces the training-mode event with a pre-trained-flow-mode event.

    This is a no-op if no ``GWLikelihoodConfig`` events use
    ``from_bilby_result`` or ``from_npz_file``.

    Parameters
    ----------
    config : InferenceConfig
        Current inference configuration.
    outdir : Path
        Run output directory (flows are cached under ``{outdir}/gw_flow_cache/``).

    Returns
    -------
    InferenceConfig
        Updated config where all GW events point to a pre-trained flow directory.

    Raises
    ------
    ValueError
        If an NPZ-mode event's file is missing or lacks a ``.npz`` extension.
    """
    logger.info("Checking to prepare GW normalizing flows...")

    updated_likelihoods = list(config.likelihoods)
    any_changes = False

    for lk_idx, lk_config in enumerate(config.likelihoods):
        if not isinstance(lk_config, GWLikelihoodConfig) or not lk_config.enabled:
            continue
        if not any(_needs_flow_training(e) for e in lk_config.events):
            continue

        updated_events = list(lk_config.events)
        for ev_idx, event in enumerate(lk_config.events):
            if not _needs_flow_training(event):
                continue
            nf_model_dir = _prepare_event_flow(event, outdir)
            # NOTE(review): only `name` and `nf_model_dir` are carried over to
            # the pre-trained event; confirm no other event fields must survive.
            updated_events[ev_idx] = GWEventConfig(
                name=event.name,
                nf_model_dir=str(nf_model_dir),
            )

        updated_likelihoods[lk_idx] = lk_config.model_copy(
            update={"events": updated_events}
        )
        any_changes = True

    if not any_changes:
        return config
    return config.model_copy(update={"likelihoods": updated_likelihoods})


# TODO: remove transform second argument, as it is unused
# TODO: this is a bit redundant and we can just use the factory directly
def setup_likelihood(
    config: InferenceConfig, transform: JesterTransform
) -> LikelihoodBase:
    """
    Build the combined likelihood described by ``config.likelihoods``.

    Parameters
    ----------
    config : InferenceConfig
        Configuration object; only its ``likelihoods`` list is used.
    transform : JesterTransform
        Accepted for signature compatibility; not used.

    Returns
    -------
    LikelihoodBase
        Combined likelihood instance from the likelihood factory.
    """
    likelihood_configs = config.likelihoods
    combined = create_combined_likelihood(likelihood_configs)
    return combined
def run_sampling(
    sampler: JesterSampler,
    seed: int,
    config: InferenceConfig,
    outdir: str | Path,
    fixed_params: dict[str, float] | None = None,
) -> InferenceResult:
    """
    Run MCMC sampling and create InferenceResult.

    Parameters
    ----------
    sampler : JesterSampler
        JesterSampler instance (FlowMC, BlackJAX NS, or BlackJAX SMC).
    seed : int
        Random seed for sampling.
    config : InferenceConfig
        Configuration object.
    outdir : str or Path
        Output directory (used for SMC diagnostic plots).
    fixed_params : dict[str, float] | None, optional
        Parameters pinned to constant values during inference.

    Returns
    -------
    InferenceResult
        Result object containing samples, metadata, and histories. Its
        ``runtime`` covers ``sampler.sample`` plus ``print_summary``.
    """
    logger.info(f"Starting sampling with seed {seed}...")
    # perf_counter is monotonic, so the runtime of long runs cannot be skewed
    # by system clock adjustments the way time.time() can.
    start = time.perf_counter()
    sampler.sample(jax.random.PRNGKey(seed))
    sampler.print_summary()
    runtime = time.perf_counter() - start

    minutes, seconds = divmod(runtime, 60)
    logger.info(
        f"Sampling complete! Runtime: {int(minutes)} min {int(seconds)} sec"
    )

    # Generate diagnostic plots for SMC samplers
    from .samplers.blackjax.smc.base import BlackjaxSMCSampler

    # TODO: plot_diagnostics should be in the base class, do not fail if not implemented but just pass
    # Then, samplers can implement their own diagnostics as needed (e.g., FlowMC could have training diagnostics, acceptance rates, etc.) -- for now we only have this for SMC, but that is fine
    if isinstance(sampler, BlackjaxSMCSampler):
        logger.info("Generating SMC diagnostic plots...")
        sampler.plot_diagnostics(outdir=outdir, filename="smc_diagnostics.png")

    logger.info("Creating InferenceResult from sampler output...")
    return InferenceResult.from_sampler(
        sampler=sampler,
        config=config,
        runtime=runtime,
        fixed_params=fixed_params,
    )
# TODO: fully deprecate this: remove this entirely (if no other dependency on it)
def generate_eos_samples(
    config: InferenceConfig,
    result: InferenceResult,
    transform_eos: JesterTransform,
    outdir: str | Path,
    n_eos_samples: int = 10_000,
) -> None:
    """
    .. deprecated::
        This function is deprecated and will be removed in a future version.
        Use :meth:`InferenceResult.add_eos_from_transform` instead.

    Generate EOS curves from sampled parameters and add to InferenceResult.

    Parameters
    ----------
    config : InferenceConfig
        Configuration object (``sampler.log_prob_batch_size`` is read).
    result : InferenceResult
        Result object with posterior samples; modified in place.
    transform_eos : JesterTransform
        Transform for generating full EOS quantities.
    outdir : str or Path
        Output directory (not used by this function).
    n_eos_samples : int, optional
        Number of EOS samples to generate (capped at the number available).

    Notes
    -----
    ``log_prob`` and the sampler fields present in the posterior are subsampled
    in place, with the originals stored under ``<field>_full``. Parameter
    arrays are NOT subsampled.
    """
    warnings.warn(
        "generate_eos_samples() is deprecated and will be removed in a future version. "
        "Use InferenceResult.add_eos_from_transform() instead.",
        DeprecationWarning,
        stacklevel=2,
    )

    log_prob = result.posterior["log_prob"]

    # Cap n_eos_samples at available sample size
    n_available = len(log_prob)
    if n_eos_samples > n_available:
        logger.warning(
            f"Requested {n_eos_samples} EOS samples but only {n_available} available."
        )
        logger.warning(f"Using all {n_available} samples instead.")
        n_eos_samples = n_available

    logger.info(f"Generating {n_eos_samples} EOS samples...")

    # NOTE(review): uses NumPy's global RNG, so the selection is not tied to
    # config.seed.
    idx = np.random.choice(np.arange(len(log_prob)), size=n_eos_samples, replace=False)

    # Keep only transform inputs (NEP/CSE parameters): drop sampler metadata
    # and previously derived EOS quantities.
    exclude_keys = {
        "weights",
        "ess",
        "logL",
        "logL_birth",
        "log_prob",
        "_sampler_specific",
        "masses_EOS",
        "radii_EOS",
        "Lambdas_EOS",
        "n",
        "p",
        "e",
        "cs2",
    }
    param_samples = {k: v for k, v in result.posterior.items() if k not in exclude_keys}
    chosen_samples = {k: jnp.array(v[idx]) for k, v in param_samples.items()}

    # NOTE: This function is currently unused - see InferenceResult.add_eos_from_transform() instead
    # CRITICAL: If this function is ever re-enabled, remember to filter ALL arrays:
    # - log_prob and sampler fields (weights, ess, logL, logL_birth)
    # - NEP/CSE parameter arrays

    # Store the original full log_prob for reference, then update with filtered version
    result.posterior["log_prob_full"] = result.posterior["log_prob"].copy()
    result.posterior["log_prob"] = result.posterior["log_prob"][idx]

    # Filter other sampler-specific fields if present
    sampler_fields_to_filter = ["weights", "ess", "logL", "logL_birth"]
    for field in sampler_fields_to_filter:
        if field in result.posterior:
            result.posterior[f"{field}_full"] = result.posterior[field].copy()
            result.posterior[field] = result.posterior[field][idx]

    logger.info(
        f"Filtered log_prob and sampler fields from {len(log_prob)} to {len(result.posterior['log_prob'])} samples"
    )

    # Generate EOS curves with batched processing
    logger.info("Running TOV solver with batched processing...")
    my_forward = jax.jit(transform_eos.forward)

    batch_size = config.sampler.log_prob_batch_size
    if batch_size > n_eos_samples:
        logger.warning(
            f"Requested batch size ({batch_size}) is larger than the number of samples "
            f"({n_eos_samples}). Adjusting batch size to {n_eos_samples}."
        )
        batch_size = n_eos_samples
    logger.info(f"Using batch size: {batch_size}")

    # Run with batched processing (JIT compilation happens on first batch)
    TOV_start = time.perf_counter()
    transformed_samples = jax.lax.map(my_forward, chosen_samples, batch_size=batch_size)
    # JAX dispatches work asynchronously; wait for the results so the timing
    # below measures the TOV solves rather than just their dispatch.
    transformed_samples = jax.block_until_ready(transformed_samples)
    TOV_end = time.perf_counter()
    logger.info(
        f"TOV solve time: {TOV_end - TOV_start:.2f} s ({n_eos_samples} samples)"
    )

    # Add derived EOS quantities to result
    result.add_derived_eos(transformed_samples)
    logger.info("Derived EOS quantities added to InferenceResult")
# TODO: there are some calls that check specific types of samplers/configs/...
# Ideally we should remove this and just have a small loop that prints over all available
# attributes of the config/sampler/likelihood/transform/prior objects, so that we don't have to update this function every time we add a new sampler type or likelihood type or EOS type, etc. We can still have some special handling for certain fields (e.g., if chieft enabled then print nbreak bounds, etc.) but in general we should just loop over all fields and print them in a structured way (e.g., using Pydantic's model_dump() with some formatting for logging)
# This is already done a bit with the likelihoods, so follow that approach in the future


def _flatten_priors(priors_list: list) -> list:
    """Recursively expand nested CombinePrior-like objects into leaf priors."""
    flat = []
    for p in priors_list:
        if hasattr(p, "base_prior") and isinstance(p.base_prior, list):
            flat.extend(_flatten_priors(p.base_prior))
        else:
            flat.append(p)
    return flat


def _log_prior_summary(prior: CombinePrior) -> None:
    """Log the prior's dimensionality and one line per sampled parameter."""
    logger.info(f"Prior has {prior.n_dim} dimensions")
    logger.info(f"Prior parameter names: {prior.parameter_names}")

    # CombinePrior stores its components in `base_prior`; wrap single priors.
    if hasattr(prior, "base_prior") and isinstance(prior.base_prior, list):
        individual_priors = prior.base_prior
    else:
        individual_priors = [prior]

    idx = 0
    for param_prior in _flatten_priors(individual_priors):
        for name in param_prior.parameter_names:
            if hasattr(param_prior, "xmin") and hasattr(param_prior, "xmax"):
                logger.info(
                    f" [{idx}] {name}: Uniform({param_prior.xmin}, {param_prior.xmax})"
                )
            else:
                logger.info(f" [{idx}] {name}: {type(param_prior).__name__}")
            idx += 1


def _log_transform_summary(
    config: InferenceConfig, keep_names: list[str] | None
) -> None:
    """Log the EOS / TOV settings used to build the transform."""
    logger.info(f"EOS type: {config.eos.type}")
    if isinstance(config.eos, MetamodelCSEEOSConfig):
        logger.info(f" nb_CSE: {config.eos.nb_CSE}")
        if config.eos.max_nbreak_nsat is not None:
            logger.info(f" max_nbreak_nsat: {config.eos.max_nbreak_nsat:.4f} n_sat")
    if isinstance(config.eos, BaseMetamodelEOSConfig):
        logger.info(f" ndat_metamodel: {config.eos.ndat_metamodel}")
        logger.info(f" nmax_nsat: {config.eos.nmax_nsat}")
    logger.info(f"TOV solver: {config.tov.type}")
    logger.info(f" ndat_TOV: {config.tov.ndat_TOV}")
    if keep_names:
        logger.info(f" Preserving parameters in output: {keep_names}")


def _log_likelihood_summary(config: InferenceConfig) -> None:
    """Log every enabled likelihood config (serialized via Pydantic)."""
    enabled_likelihoods = [lk for lk in config.likelihoods if lk.enabled]
    logger.info(f"Number of enabled likelihoods: {len(enabled_likelihoods)}")
    for lk in enabled_likelihoods:
        # `enabled` is excluded since the list is already filtered on it.
        lk_dict = lk.model_dump(exclude={"enabled"})
        logger.info(f" - {lk.type.upper()}:")
        logger.info(f" {json.dumps(lk_dict, indent=6)}")


def _log_run_summary(config: InferenceConfig, outdir) -> None:
    """Log the configuration summary, including sampler-specific settings."""
    logger.info("=" * 60)
    logger.info("Configuration Summary")
    logger.info("=" * 60)
    logger.info(f"EOS type: {config.eos.type}")
    logger.info(f"TOV solver: {config.tov.type}")
    logger.info(f"Random seed: {config.seed}")
    logger.info(f"Sampler type: {config.sampler.type}")
    logger.info("Sampler Configuration:")

    if config.sampler.type == "flowmc":
        logger.info(f" Chains: {config.sampler.n_chains}")
        logger.info(f" Training loops: {config.sampler.n_loop_training}")
        logger.info(f" Production loops: {config.sampler.n_loop_production}")
        logger.info(f" Local steps per loop: {config.sampler.n_local_steps}")
        logger.info(f" Global steps per loop: {config.sampler.n_global_steps}")
        logger.info(f" Training epochs: {config.sampler.n_epochs}")
        logger.info(f" Learning rate: {config.sampler.learning_rate}")
        logger.info(f" Training thinning: {config.sampler.train_thinning}")
        logger.info(f" Output thinning: {config.sampler.output_thinning}")
    elif config.sampler.type == "blackjax-ns-aw":
        logger.info(f" Live points: {config.sampler.n_live}")
        logger.info(f" Delete fraction: {config.sampler.n_delete_frac}")
        logger.info(f" Target MCMC steps: {config.sampler.n_target}")
        logger.info(f" Termination dlogZ: {config.sampler.termination_dlogz}")
    elif config.sampler.type in ["smc-rw", "smc-nuts"]:
        logger.info(f" Particles: {config.sampler.n_particles}")
        logger.info(f" MCMC steps: {config.sampler.n_mcmc_steps}")
        logger.info(f" Target ESS: {config.sampler.target_ess}")

    # Shared sampler config fields
    logger.info(f" EOS samples to generate: {config.sampler.n_eos_samples}")
    logger.info(f" Output directory: {outdir}")
    logger.info("=" * 60 + "\n")


def _check_test_likelihood(prior, transform, likelihood) -> None:
    """Evaluate the likelihood on 3 prior draws and warn if any result is NaN."""
    logger.info("Testing likelihood evaluation...")
    test_samples = prior.sample(jax.random.PRNGKey(0), 3)
    test_samples_transformed = jax.vmap(transform.forward)(test_samples)
    test_log_prob = jax.vmap(likelihood.evaluate)(test_samples_transformed)
    logger.info(f"Test log probabilities: {test_log_prob}")

    # NOTE: NaNs are only warned about here (an earlier FIXME asked for an
    # error); consider raising once no valid setup is known to produce them.
    n_nan = int(jnp.sum(jnp.isnan(test_log_prob)))
    if n_nan:
        logger.warning(
            f"{n_nan} of {test_log_prob.shape[0]} test log probabilities are NaN; "
            "check the prior ranges, transform, and likelihood configuration "
            "(set debug_nans in the config to locate the source)."
        )


def main(config_path: str) -> None:
    """Main inference script.

    Loads the configuration and builds the prior, transform, likelihood, and
    sampler. It then runs sampling, attaches EOS quantities, saves
    ``results.h5`` in the output directory, and optionally runs postprocessing.
    ``validate_only`` stops after loading the config, and ``dry_run`` stops
    after the sampler has been set up.

    Parameters
    ----------
    config_path : str
        Path to YAML configuration file.
    """
    logger.info(f"Loading configuration from {config_path}")
    config = load_config(config_path)

    if config.debug_nans:
        logger.info("Enabling JAX NaN debugging")
        jax.config.update("jax_debug_nans", True)

    outdir = config.sampler.output_dir

    # Print available JAX devices
    logger.info(f"JAX devices: {jax.devices()}")

    if config.validate_only:
        logger.info("Configuration valid!")
        return

    logger.info("Setting up prior...")
    prior, fixed_params = setup_prior(config)
    _log_prior_summary(prior)

    # Determine which parameters need to be preserved in transform output
    # based on enabled likelihoods (validates required parameters exist in prior or fixed_params)
    keep_names = determine_keep_names(config, prior, fixed_params)

    logger.info("Setting up transform...")
    transform = setup_transform(
        config, prior=prior, keep_names=keep_names, fixed_params=fixed_params
    )
    _log_transform_summary(config, keep_names)

    config = prepare_gw_flows(config, Path(outdir))

    logger.info("Setting up likelihood...")
    likelihood = setup_likelihood(config, transform)
    _log_likelihood_summary(config)

    logger.info(f"Setting up {config.sampler.type} sampler...")
    sampler = create_sampler(
        config=config.sampler,
        prior=prior,
        likelihood=likelihood,
        likelihood_transforms=[transform],
        seed=config.seed,
    )
    _log_run_summary(config, outdir)

    if config.dry_run:
        logger.info("Dry run complete!")
        return

    os.makedirs(outdir, exist_ok=True)

    _check_test_likelihood(prior, transform, likelihood)

    result = run_sampling(
        sampler, config.seed, config, outdir, fixed_params=fixed_params
    )

    # Generate EOS quantities from posterior samples
    # Note: This requires recomputing the TOV solver for selected samples.
    # Future optimization: implement transform caching outside JAX trace (see JesterSampler)
    result.add_eos_from_transform(
        transform=transform,  # Use the same transform from sampling
        n_eos_samples=config.sampler.n_eos_samples,
        batch_size=config.sampler.log_prob_batch_size,
    )

    result_path = os.path.join(outdir, "results.h5")
    result.save(result_path)
    logger.info(f"Results saved to {result_path}")

    if config.postprocessing.enabled:
        logger.info("\n" + "=" * 60)
        logger.info("Running postprocessing...")
        logger.info("=" * 60)

        from jesterTOV.inference.postprocessing.postprocessing import generate_all_plots

        generate_all_plots(
            outdir=outdir,
            prior_dir=config.postprocessing.prior_dir,
            make_cornerplot_flag=config.postprocessing.make_cornerplot,
            make_massradius_flag=config.postprocessing.make_massradius,
            make_pressuredensity_flag=config.postprocessing.make_pressuredensity,
            make_histograms_flag=config.postprocessing.make_histograms,
            make_cs2_flag=config.postprocessing.make_cs2,
            injection_eos_path=config.postprocessing.injection_eos_path,
        )
        logger.info(f"\nPostprocessing complete! Plots saved to {outdir}")

    logger.info(f"\nInference complete! Results saved to {outdir}")
def cli_entry_point() -> None:
    """
    Console-script entry point.

    Lets users run ``run_jester_inference config.yaml`` instead of
    ``python -m jesterTOV.inference.run_inference config.yaml``. Without
    exactly one argument (the config path), usage help is logged and the
    process exits with status 1.
    """
    import sys

    cli_args = sys.argv[1:]
    if len(cli_args) == 1:
        main(cli_args[0])
        return

    logger.error("Usage: run_jester_inference <config.yaml>")
    logger.info("\nExamples:")
    for example in (
        " run_jester_inference config.yaml",
        " run_jester_inference examples/inference/full_inference/config.yaml",
    ):
        logger.info(example)
    logger.info(
        "\nOptions like dry_run and validate_only should be set in the YAML config file."
    )
    sys.exit(1)
if __name__ == "__main__":
    import sys

    # Module invocation: python -m jesterTOV.inference.run_inference <config.yaml>
    _cli_args = sys.argv[1:]
    if len(_cli_args) == 1:
        main(_cli_args[0])
    else:
        logger.error("Usage: python -m jesterTOV.inference.run_inference <config.yaml>")
        sys.exit(1)