JESTER Inference Workflow#
This guide explains how the JESTER inference system works from a file-level perspective, showing how different components connect and how data flows through the system.
Overview#
The JESTER inference system performs Bayesian inference on neutron star equation of state parameters using observational data from gravitational waves, X-ray timing, and other sources. The system is modular and config-driven, with clear separation between configuration, data loading, parameter transformation, likelihood evaluation, and sampling.
The workflow follows this general pattern:
User provides a YAML configuration file and a prior specification file
System validates configuration and sets up all components
Sampler draws parameter samples and evaluates posterior probability
For each sample: parameters are transformed into neutron star observables via TOV solver
Observables are compared against data to compute likelihood
Final samples are saved with derived quantities
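The per-sample loop above can be sketched in plain Python (function names here are illustrative stand-ins, not the actual JESTER API):

```python
import math

def log_posterior(params, log_prior, transform, log_likelihood):
    """Schematic posterior evaluation: check the prior first, then run the
    expensive transform (EOS construction + TOV solve), then the likelihood."""
    lp = log_prior(params)
    if lp == -math.inf:
        # Outside prior support: reject without running the TOV solver
        return -math.inf
    observables = transform(params)
    return lp + log_likelihood(observables)
```

In the real system the transform returns observables such as masses_EOS and radii_EOS, and the likelihood is a sum over all enabled constraints.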
File Structure and Connections#
Directory Layout#
The inference module is organized into logical subdirectories:
jesterTOV/inference/
├── base/ # Base classes (copied from Jim for independence)
│ ├── likelihood.py # LikelihoodBase abstract class
│ ├── prior.py # Prior, CombinePrior, UniformPrior
│ └── transform.py # NtoMTransform, BijectiveTransform
├── config/ # Configuration parsing and validation
│ ├── parser.py # YAML loading and path resolution
│ ├── schema.py # Pydantic models for validation
│ └── generate_yaml_reference.py # Auto-generate docs
├── priors/ # Prior specification system
│ ├── parser.py # Parse .prior files (bilby-style Python)
│ └── library.py # Prior library utilities
├── transforms/ # EOS parameter transformations
│ ├── base.py # JesterTransformBase (abstract)
│ ├── factory.py # create_transform() factory function
│ ├── metamodel.py # MetaModelTransform (8 NEP params)
│ └── metamodel_cse.py # MetaModelCSETransform (NEP + CSE grid)
├── likelihoods/ # Observational constraints
│ ├── factory.py # create_combined_likelihood()
│ ├── gw.py # Gravitational wave likelihood
│ ├── nicer.py # NICER X-ray timing likelihood
│ ├── radio.py # Radio pulsar timing likelihood
│ ├── chieft.py # Chiral EFT constraints
│ ├── rex.py # PREX/CREX radius measurements
│ ├── constraint_eos.py # Causality/stability penalties
│ └── constraint_tov.py # TOV solution validity
├── data/ # Data loading and caching
│ ├── __init__.py # Data loading functions
│ ├── paths.py # Default data paths
│ └── cache.py # Caching utilities
├── samplers/ # Sampling backends
│ ├── jester_sampler.py # JesterSampler base class
│ ├── flowmc.py # FlowMC sampler implementation
│ ├── blackjax_smc.py # Sequential Monte Carlo
│ └── blackjax_ns_aw.py # Nested Sampling with Acceptance Walk
├── flows/ # Normalizing flow training for GW
│ ├── train.py # Flow training utilities
│ └── models/ # Pre-trained flow models
├── postprocessing/ # Visualization and analysis
│ └── postprocessing.py # Corner plots, M-R diagrams, etc.
├── result.py # InferenceResult class (HDF5 storage)
└── run_inference.py # Main entry point and orchestration
File Connection Diagram#
This diagram shows how the main files connect during execution:
graph TB
subgraph "User Inputs"
CONFIG[config.yaml]
PRIOR[prior.prior]
DATA[Data Files]
end
subgraph "Entry Point"
MAIN[run_inference.py]
end
subgraph "Configuration System"
PARSER[config/parser.py]
SCHEMA[config/schema.py]
CONFIG_OBJ[InferenceConfig object]
end
subgraph "Prior System"
PRIOR_PARSER[priors/parser.py]
PRIOR_OBJ[CombinePrior object]
end
subgraph "Transform System"
TRANS_FACTORY[transforms/factory.py]
TRANS_BASE[transforms/base.py]
TRANS_MM[transforms/metamodel.py]
TRANS_CSE[transforms/metamodel_cse.py]
TRANS_OBJ[Transform object]
end
subgraph "Likelihood System"
LK_FACTORY[likelihoods/factory.py]
LK_GW[likelihoods/gw.py]
LK_NICER[likelihoods/nicer.py]
LK_RADIO[likelihoods/radio.py]
DATA_LOADER[data/__init__.py]
LK_OBJ[CombinedLikelihood object]
end
subgraph "Sampling System"
SAMPLER_BASE[samplers/jester_sampler.py]
SAMPLER_FLOW[samplers/flowmc.py]
SAMPLER_SMC[samplers/blackjax_smc.py]
SAMPLER_NS[samplers/blackjax_ns_aw.py]
SAMPLER_OBJ[Sampler object]
end
subgraph "Results System"
RESULT[result.py]
POSTPROC[postprocessing/postprocessing.py]
HDF5[results.h5]
end
CONFIG --> MAIN
PRIOR --> MAIN
DATA --> MAIN
MAIN --> PARSER
PARSER --> SCHEMA
SCHEMA --> CONFIG_OBJ
CONFIG_OBJ --> MAIN
MAIN --> PRIOR_PARSER
PRIOR --> PRIOR_PARSER
PRIOR_PARSER --> PRIOR_OBJ
PRIOR_OBJ --> MAIN
MAIN --> TRANS_FACTORY
CONFIG_OBJ --> TRANS_FACTORY
TRANS_FACTORY --> TRANS_BASE
TRANS_FACTORY --> TRANS_MM
TRANS_FACTORY --> TRANS_CSE
TRANS_MM --> TRANS_OBJ
TRANS_CSE --> TRANS_OBJ
TRANS_OBJ --> MAIN
MAIN --> LK_FACTORY
CONFIG_OBJ --> LK_FACTORY
DATA --> DATA_LOADER
LK_FACTORY --> DATA_LOADER
LK_FACTORY --> LK_GW
LK_FACTORY --> LK_NICER
LK_FACTORY --> LK_RADIO
LK_GW --> LK_OBJ
LK_NICER --> LK_OBJ
LK_RADIO --> LK_OBJ
LK_OBJ --> MAIN
MAIN --> SAMPLER_BASE
CONFIG_OBJ --> SAMPLER_FLOW
CONFIG_OBJ --> SAMPLER_SMC
CONFIG_OBJ --> SAMPLER_NS
PRIOR_OBJ --> SAMPLER_FLOW
LK_OBJ --> SAMPLER_FLOW
TRANS_OBJ --> SAMPLER_FLOW
SAMPLER_FLOW --> SAMPLER_BASE
SAMPLER_SMC --> SAMPLER_BASE
SAMPLER_NS --> SAMPLER_BASE
SAMPLER_BASE --> SAMPLER_OBJ
SAMPLER_OBJ --> MAIN
MAIN --> RESULT
SAMPLER_OBJ --> RESULT
TRANS_OBJ --> RESULT
RESULT --> HDF5
HDF5 --> POSTPROC
Data Flow Through the System#
This section shows how data transforms as it flows through the system.
High-Level Data Flow#
graph LR
A[YAML Config] --> B[Validated Config Object]
C[.prior File] --> D[CombinePrior Object]
E[Data Files] --> F[Loaded Data]
B --> G[Transform Object]
B --> H[Likelihood Object]
F --> H
D --> I[Sampler]
G --> I
H --> I
I --> J[Parameter Samples]
J --> K[Transform]
K --> L[Observable Samples]
L --> M[InferenceResult]
M --> N[HDF5 File]
Detailed Posterior Evaluation Flow#
This is what happens when the sampler evaluates the posterior probability for a proposed parameter set:
sequenceDiagram
participant S as Sampler
participant JS as JesterSampler
participant P as CombinePrior
participant T as Transform
participant L as Likelihood
participant TOV as TOV Solver
S->>JS: Propose parameters (array)
JS->>JS: Convert array to dict using parameter_names
JS->>P: log_prob(params)
P-->>JS: log_prior
alt log_prior is -inf
JS-->>S: Return -inf (reject)
else log_prior is finite
JS->>T: forward(params)
T->>TOV: Construct EOS and solve TOV
TOV-->>T: masses_EOS, radii_EOS, Lambdas_EOS
T-->>JS: transformed_params
JS->>L: evaluate(transformed_params, {})
L->>L: For each likelihood component:
L->>L: - Extract needed observables
L->>L: - Compare with data
L->>L: - Compute log_likelihood
L-->>JS: log_likelihood
JS->>JS: log_posterior = log_prior + log_likelihood
JS-->>S: log_posterior
end
Transform: Parameters to Observables#
The transform is the heart of the physics computation. Here is what happens inside:
graph TB
subgraph "Input: Microscopic Parameters"
P1[K_sat, Q_sat, Z_sat<br/>Nuclear incompressibility]
P2[E_sym, L_sym, K_sym, Q_sym, Z_sym<br/>Symmetry energy]
P3[nbreak, cs2grid<br/>CSE parameters if using MetaModel+CSE]
end
subgraph "MetaModel Construction"
M1[Build energy density expansion]
M2["Compute pressure p(n)"]
M3["Compute enthalpy h(n)"]
M4["Compute sound speed cs²(n)"]
M5[Attach crust model DH or BPS]
end
subgraph "TOV Integration"
T1[Define central pressure grid]
T2[For each pc in grid:]
T3[Integrate dm/dr, dP/dr from center]
T4[Stop at surface P = 0]
T5[Extract M, R, Λ]
end
subgraph "Output: Macroscopic Observables"
O1[masses_EOS: M vs pc]
O2[radii_EOS: R vs pc]
O3[Lambdas_EOS: Λ vs pc]
O4[EOS quantities: n, p, h, e, cs²]
end
P1 --> M1
P2 --> M1
P3 --> M1
M1 --> M2
M2 --> M3
M3 --> M4
M4 --> M5
M5 --> T1
T1 --> T2
T2 --> T3
T3 --> T4
T4 --> T5
T5 --> O1
T5 --> O2
T5 --> O3
M5 --> O4
Understanding the Main Components#
Configuration System#
Files: config/parser.py, config/schema.py
The configuration system is built on Pydantic for type-safe validation. When you call load_config("config.yaml"), here is what happens:
YAML file is loaded as a dictionary
Relative paths are resolved to absolute paths
Dictionary is validated against Pydantic schemas
A fully validated InferenceConfig object is returned
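Step 2 might look like the following sketch (hypothetical helper and config keys; the real resolver in config/parser.py may differ; PurePosixPath is used here only to keep the example platform-independent):

```python
from pathlib import PurePosixPath

def resolve_data_paths(raw_config: dict, config_dir: str) -> dict:
    """Resolve relative data_paths entries against the config file's directory."""
    resolved = dict(raw_config)
    resolved["data_paths"] = {
        key: value if PurePosixPath(value).is_absolute()
        else str(PurePosixPath(config_dir) / value)
        for key, value in raw_config.get("data_paths", {}).items()
    }
    return resolved
```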
The schemas define the structure and validation rules:
```python
class InferenceConfig(BaseModel):
    seed: int
    eos: EOSConfig  # discriminated union by type
    tov: TOVConfig  # discriminated union by type
    prior: PriorConfig
    likelihoods: list[LikelihoodConfig]
    sampler: SamplerConfig
    data_paths: dict[str, str] = {}
```
Each section has its own schema with validation. For example, MetamodelEOSConfig ensures that nb_CSE == 0 when using standard MetaModel, and MetamodelCSEEOSConfig requires nb_CSE > 0. Config classes are defined in config/schemas/ and re-exported from config/schema.py.
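The discriminated-union idea can be illustrated with a plain-Python sketch (dataclasses standing in for the actual Pydantic models; field names and defaults are assumptions):

```python
from dataclasses import dataclass
from typing import Literal

@dataclass
class MetamodelEOSConfig:
    type: Literal["metamodel"] = "metamodel"
    nb_CSE: int = 0
    def __post_init__(self):
        if self.nb_CSE != 0:
            raise ValueError("standard MetaModel requires nb_CSE == 0")

@dataclass
class MetamodelCSEEOSConfig:
    type: Literal["metamodel_cse"] = "metamodel_cse"
    nb_CSE: int = 8
    def __post_init__(self):
        if self.nb_CSE <= 0:
            raise ValueError("MetaModel+CSE requires nb_CSE > 0")

_EOS_SCHEMAS = {"metamodel": MetamodelEOSConfig, "metamodel_cse": MetamodelCSEEOSConfig}

def validate_eos_section(raw: dict):
    """Pick the schema class from the 'type' key, then validate its fields."""
    return _EOS_SCHEMAS[raw["type"]](**raw)
```

Pydantic performs this dispatch automatically via the discriminated-union annotation; the sketch only shows the shape of the behavior.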
Prior System#
Files: priors/parser.py, base/prior.py
The prior system uses bilby-style Python syntax. Users write a .prior file like:
```python
K_sat = UniformPrior(150.0, 300.0, parameter_names=["K_sat"])
L_sym = UniformPrior(10.0, 200.0, parameter_names=["L_sym"])
```
The parser executes this Python code in a controlled namespace, extracts the Prior objects, and combines them into a CombinePrior that the sampler can use.
The parser also handles automatic inclusion/exclusion based on configuration. For example:
CSE parameters are only included if nb_CSE > 0 in the transform config
CSE grid parameters are auto-generated programmatically
The CombinePrior provides two key methods:
sample(rng_key, n): Draw samples from the prior
log_prob(params): Evaluate the prior probability density
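A toy sketch of these two methods (the real classes live in base/prior.py and use JAX RNG keys; this plain-Python version only mirrors the interface):

```python
import math
import random

class ToyUniformPrior:
    def __init__(self, lo, hi, name):
        self.lo, self.hi, self.name = lo, hi, name
    def sample(self, rng):
        return {self.name: rng.uniform(self.lo, self.hi)}
    def log_prob(self, params):
        x = params[self.name]
        # Constant density inside the support, -inf outside
        return -math.log(self.hi - self.lo) if self.lo <= x <= self.hi else -math.inf

class ToyCombinePrior:
    def __init__(self, priors):
        self.priors = priors
        self.parameter_names = [p.name for p in priors]
    def sample(self, rng):
        draw = {}
        for p in self.priors:
            draw.update(p.sample(rng))
        return draw
    def log_prob(self, params):
        # Independent components: log densities add
        return sum(p.log_prob(params) for p in self.priors)
```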
Transform System#
Files: transforms/base.py, transforms/metamodel.py, transforms/metamodel_cse.py, transforms/factory.py
Transforms convert parameter dictionaries to observable dictionaries. The base class JesterTransformBase inherits from NtoMTransform and defines the interface:
```python
class JesterTransformBase(NtoMTransform):
    def forward(self, params: dict) -> dict:
        # Transform parameters to observables
        pass
```
The two concrete implementations are:
MetaModelTransform: Takes 8 nuclear empirical parameters and produces mass-radius-tidal deformability relations
MetaModelCSETransform: Extends the MetaModel with a high-density speed-of-sound extension (CSE) parameterized by additional grid parameters
Both implementations follow the same pattern:
Extract parameters from input dict
Construct EOS using the MetaModel formalism
Call TOV solver to get M-R-Λ relations
Return dict with observables
The factory function create_transform() instantiates the correct class based on configuration.
Likelihood System#
Files: likelihoods/factory.py, likelihoods/gw.py, likelihoods/nicer.py, etc.
Each likelihood class inherits from LikelihoodBase and implements:
```python
def evaluate(self, params: dict, data: dict) -> float:
    # Extract observables from params
    # Compare with data
    # Return log likelihood
    pass
```
The params dict contains the transformed observables (masses_EOS, radii_EOS, etc.), not the original parameters.
Different likelihood types:
GWLikelihood: Uses normalizing flow trained on GW posterior
NICERLikelihood: Uses KDE of M-R posterior from X-ray timing
RadioTimingLikelihood: Uses mass measurements from radio pulsars
ChiEFTLikelihood: Penalizes EOSs outside chiral EFT bands
ConstraintEOSLikelihood: Penalizes causality violations and instabilities
ConstraintTOVLikelihood: Penalizes invalid TOV solutions
The factory function create_combined_likelihood() creates instances of each enabled likelihood and wraps them in a CombinedLikelihood that sums the log probabilities.
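The summing behavior can be sketched as follows (illustrative stand-ins for the real classes in likelihoods/factory.py):

```python
class ToyCombinedLikelihood:
    """Sum of independent log likelihoods, mirroring CombinedLikelihood."""
    def __init__(self, likelihoods):
        self.likelihoods = likelihoods
    def evaluate(self, params: dict, data: dict) -> float:
        return sum(lk.evaluate(params, data) for lk in self.likelihoods)

class ConstantLikelihood:
    """Trivial component used here only to demonstrate the combination."""
    def __init__(self, value):
        self.value = value
    def evaluate(self, params, data):
        return self.value
```

Because the components are assumed independent, their log probabilities add; disabling a likelihood in the config simply leaves it out of the list.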
Data Loading#
Files: data/__init__.py, data/paths.py
Data loading is handled through a lazy-loading and caching mechanism. Data is only loaded when a likelihood that needs it is created, and once loaded, it is cached for reuse.
Each likelihood type knows what data it needs:
GW likelihoods need pre-trained normalizing flow models
NICER likelihoods need M-R posterior samples from Amsterdam/Maryland groups
ChiEFT likelihoods need constraint band files
The data loading functions handle downloading from Zenodo if needed, with automatic caching.
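The download-once pattern can be sketched like this (hypothetical helper; the actual functions in data/__init__.py and data/cache.py may differ, and `fetch` stands in for the Zenodo download):

```python
from pathlib import Path

def load_with_cache(filename: str, cache_dir: str, fetch) -> bytes:
    """Return cached bytes if present; otherwise call `fetch` once
    and store the result for future runs."""
    path = Path(cache_dir) / filename
    if not path.exists():
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_bytes(fetch())
    return path.read_bytes()
```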
Sampler System#
Files: samplers/jester_sampler.py, samplers/flowmc.py, samplers/blackjax_smc.py, samplers/blackjax_ns_aw.py
The sampler system has a two-level architecture:
JesterSampler (base class): Provides a common interface for all samplers
Constructs posterior from prior + likelihood + transforms
Handles parameter name propagation
Provides transform caching infrastructure (currently disabled)
Generic interface:
sample(), get_samples(), print_summary()
Backend-specific samplers: Wrap external libraries
FlowMCSampler: Wraps flowMC library for flow-enhanced MCMC
BlackJAXSMCSampler: Wraps BlackJAX SMC with NUTS or random walk kernels
BlackJAXNSAWSampler: Wraps BlackJAX nested sampling
Each backend has specific configuration and behavior, but they all present the same interface to the rest of the system.
Results Storage#
Files: result.py
The InferenceResult class provides unified HDF5 storage for all sampler types. The structure is:
results.h5
├── /posterior/
│ ├── K_sat, L_sym, etc. # Parameter samples
│ ├── log_prob # Log posterior values
│ ├── masses_EOS # Derived: neutron star masses
│ ├── radii_EOS # Derived: radii
│ ├── Lambdas_EOS # Derived: tidal deformabilities
│ └── [sampler-specific] # weights, ess, etc.
├── /metadata/
│ ├── sampler_type
│ ├── runtime
│ └── config (JSON)
└── /histories/ # Diagnostics (ESS evolution, etc.)
The class provides:
InferenceResult.from_sampler(): Create from sampler output
add_derived_eos(): Add M-R-Λ curves by applying transform
save()/load(): HDF5 I/O
summary(): Generate human-readable summary
Step-by-Step Execution#
Let’s walk through what happens when you run run_jester_inference config.yaml.
Phase 1: Initialization#
File: run_inference.py::main()
1. Load configuration (config/parser.py)

```python
config = load_config("config.yaml")
# Returns validated InferenceConfig object
```

2. Parse prior specification (priors/parser.py)

```python
prior = parse_prior_file(config.prior.specification_file)
# Returns CombinePrior with appropriate parameters
```

3. Create transform (transforms/factory.py)

```python
transform = JesterTransform.from_config(config.eos, config.tov)
# Instantiates the correct EOS model and TOV solver based on config
```

4. Create likelihood (likelihoods/factory.py)

```python
likelihood = create_combined_likelihood(
    config.likelihoods, data_paths=config.data_paths
)
# Returns CombinedLikelihood with all enabled likelihoods
```

5. Create sampler (samplers/flowmc.py or similar)

```python
if config.sampler.type == "flowmc":
    sampler = setup_flowmc_sampler(
        config.sampler, prior, likelihood, transform
    )
# Returns JesterSampler wrapping flowMC backend
```

6. Test likelihood evaluation

```python
# Draw sample from prior
test_sample = prior.sample(rng_key, 1)
# Transform to observables
test_obs = transform.forward(test_sample)
# Evaluate likelihood
test_logp = likelihood.evaluate(test_obs, {})
# Verify finite value
```
Phase 2: Sampling#
File: run_inference.py::run_sampling()
Initialize sampler
sampler.sample(jax.random.PRNGKey(config.seed))
Training phase (FlowMC only)
Draw samples with local MCMC
Train normalizing flow on samples
Repeat for n_loop_training iterations
Production phase
For FlowMC: Alternate between local MCMC and flow proposals
For SMC: Adaptive tempering with NUTS or random walk
For NS-AW: Nested sampling with acceptance walk
Extract samples
```python
samples = sampler.get_samples(training=False)
# Returns dict: {"K_sat": array, "L_sym": array, ...}
log_prob = sampler.get_log_prob(training=False)
# Returns array of log posterior values
```
Phase 3: Postprocessing#
File: run_inference.py::generate_eos_samples()
Select subset of samples
```python
# Randomly select n_eos_samples from full posterior
indices = np.random.choice(len(samples["K_sat"]), n_eos_samples)
selected_samples = {k: v[indices] for k, v in samples.items()}
```
Filter log_prob and sampler-specific fields
```python
selected_log_prob = log_prob[indices]
# Also filter: weights, ess, logL, logL_birth if present
```
Apply transform to get EOS samples
```python
# JIT compile transform for speed
transform_jit = jax.jit(transform.forward)
# Vectorize over samples
eos_samples = jax.vmap(transform_jit)(selected_samples)
# Returns: masses_EOS, radii_EOS, Lambdas_EOS for each sample
```
Create InferenceResult
```python
result = InferenceResult.from_sampler(
    sampler_type=config.sampler.type,
    posterior={**selected_samples, **eos_samples},
    metadata=metadata_dict,
    histories=histories_dict,
)
```
Save to HDF5
result.save(output_dir / "results.h5")
Optional: Generate plots
```python
if config.postprocessing.enabled:
    from jesterTOV.inference.postprocessing import (
        generate_corner_plot,
        generate_mr_diagram,
    )
    generate_corner_plot(result, output_dir)
    generate_mr_diagram(result, output_dir)
```
Developer Guide: Using and Extending#
Using the Inference System#
If you want to run inference with existing components:
1. Create a configuration file (see examples/inference/*/config.yaml)
   - Specify transform type and parameters
   - Choose which likelihoods to enable
   - Configure sampler settings
2. Create a prior specification file (see examples/inference/*/prior.prior)
   - Define prior distributions for nuclear parameters
   - Use UniformPrior for simple uniform distributions
   - Include CSE parameters if using MetaModel+CSE
3. Run inference

```
run_jester_inference config.yaml
```

4. Analyze results

```python
from jesterTOV.inference.result import InferenceResult

result = InferenceResult.load("outdir/results.h5")
print(result.summary())

# Access samples
masses = result.posterior["masses_EOS"]
radii = result.posterior["radii_EOS"]
```
Extending with a Custom Likelihood#
To add a new likelihood type:
Create a new likelihood class in likelihoods/

```python
# likelihoods/my_custom_likelihood.py
from jesterTOV.inference.base import LikelihoodBase
from jaxtyping import Array, Float


class MyCustomLikelihood(LikelihoodBase):
    def __init__(self, data_file: str):
        # Load your data
        self.data = load_data(data_file)

    def evaluate(self, params: dict[str, Array], data: dict) -> Float:
        # Extract observables
        masses = params["masses_EOS"]
        radii = params["radii_EOS"]
        # Compare with your data
        log_likelihood = compute_log_likelihood(masses, radii, self.data)
        return log_likelihood
```
Update the factory in likelihoods/factory.py

```python
from .my_custom_likelihood import MyCustomLikelihood


def create_combined_likelihood(configs, data_paths):
    likelihoods = []
    for config in configs:
        if config.type == "my_custom":
            lk = MyCustomLikelihood(
                data_file=config.parameters["data_file"]
            )
            likelihoods.append(lk)
        # ... other types ...
```
Add configuration schema in config/schema.py

```python
class MyCustomLikelihoodConfig(BaseModel):
    type: Literal["my_custom"]
    enabled: bool
    parameters: dict[str, Any]


LikelihoodConfig = (
    GWLikelihoodConfig
    | NICERLikelihoodConfig
    | MyCustomLikelihoodConfig
    | ...
)
```
Add tests in tests/test_inference/test_likelihoods.py

```python
def test_my_custom_likelihood():
    lk = MyCustomLikelihood(data_file="test_data.h5")
    params = {...}  # Mock parameters
    log_p = lk.evaluate(params, {})
    assert jnp.isfinite(log_p)
```
Extending with a Custom Transform#
To add a new transform (e.g., a different EOS model):
Create a new transform class in transforms/

```python
# transforms/my_eos_transform.py
from .base import JesterTransformBase


class MyEOSTransform(JesterTransformBase):
    def __init__(self, name_mapping, **kwargs):
        super().__init__(name_mapping, **kwargs)
        # Initialize your EOS model

    def forward(self, params: dict) -> dict:
        # Extract your parameters
        param1 = params["param1"]
        param2 = params["param2"]
        # Construct EOS
        eos = construct_my_eos(param1, param2)
        # Solve TOV (reuse parent class method)
        masses, radii, lambdas = self._solve_tov(eos)
        return {
            "masses_EOS": masses,
            "radii_EOS": radii,
            "Lambdas_EOS": lambdas,
            ...
        }

    def get_eos_type(self) -> str:
        return "my_eos"

    def get_parameter_names(self) -> list[str]:
        return ["param1", "param2", ...]
```
Update the factory in transforms/factory.py
Add configuration schema in config/schema.py
Add tests
Understanding Parameter Names#
Parameter names flow through the system as follows:
Prior specification defines the names:
K_sat = UniformPrior(150, 300, parameter_names=["K_sat"])
CombinePrior flattens them:
prior.parameter_names = ["K_sat", "Q_sat", ...]
Sampler uses them to convert arrays to dicts:
params_array = [200.0, 300.0, ...]  # From MCMC
params_dict = {"K_sat": 200.0, "Q_sat": 300.0, ...}  # Using parameter_names
Transform modifies them:
output_dict = { "masses_EOS": [...], "radii_EOS": [...], ... } # Different names!
Likelihood expects transform output names:
def evaluate(self, params, data):
    masses = params["masses_EOS"]  # Not "K_sat"!
Transform Caching (Currently Disabled)#
The system has infrastructure for caching transform results to avoid redundant TOV solver calls, but it is currently disabled due to JAX tracing issues.
Problem: Cannot hash/cache inside JAX-compiled functions.
Current workaround: Transform is recomputed once when generating EOS samples for the result file.
Implementing caching: Each sampler needs to implement caching outside the JAX trace context. The infrastructure is in samplers/jester_sampler.py but requires per-sampler implementation.
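The intended pattern is ordinary memoization applied outside the traced function, sketched below (illustrative only; the real hooks live in samplers/jester_sampler.py):

```python
def make_cached_transform(forward):
    """Memoize transform results keyed on concrete parameter values.
    This only works outside jax.jit: traced values are not hashable,
    which is exactly why the in-trace version is disabled."""
    cache = {}
    def cached_forward(params: dict) -> dict:
        key = tuple(sorted(params.items()))
        if key not in cache:
            cache[key] = forward(params)
        return cache[key]
    return cached_forward
```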
Testing Your Extensions#
Always add tests for new components:
Unit tests: Test your component in isolation
Integration tests: Test interaction with other components
Type checking: Ensure no pyright errors
Validation: Check physical validity of results
Run tests with:
uv run pytest tests/test_inference/
uv run pyright jesterTOV/inference/
Configuration Best Practices#
Use relative paths in config files - they will be resolved automatically
Keep prior files separate - easier to modify and reuse
Start with prior-only sampling - verify parameter ranges are reasonable
Validate before sampling - use validate_only: true to check configuration
Use dry run - use dry_run: true to test setup without expensive sampling
Common Pitfalls#
Parameter name mismatch: Ensure likelihood uses transform output names, not input names
Units: System uses geometric units internally, convert when interfacing with data
NaN handling: Use jnp.inf instead of jnp.nan for initialization
Empty data dict: Always pass {} to likelihood.evaluate(), not None
Prior boundaries: Avoid exact boundary values to prevent numerical issues
Transform output: Must include all names that likelihoods expect
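On the units pitfall: in geometric units (G = c = 1) a mass corresponds to the length G·M/c². A worked conversion (standard physics, not JESTER-specific; constant values are rounded):

```python
G = 6.674e-11      # gravitational constant, m^3 kg^-1 s^-2
C = 2.998e8        # speed of light, m/s
M_SUN = 1.989e30   # solar mass, kg

def solar_masses_to_km(m: float) -> float:
    """Geometric-unit length of a mass given in solar masses, in km."""
    return G * m * M_SUN / C**2 / 1e3

# A 1.4 Msun star is about 2.07 km in geometric units, so its
# compactness M/R (with R ~ 12 km) is a small dimensionless number.
```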
Performance Considerations#
TOV solver: Most expensive operation, called once per likelihood evaluation
JIT compilation: First transform call is slow (compilation), subsequent calls are fast
Vectorization: Use jax.vmap for batch processing
Memory: Large posterior samples can exhaust memory, use n_eos_samples to subsample
GPU: JAX automatically uses GPU if available, no code changes needed
Summary#
The JESTER inference system is designed for modularity and extensibility:
Config-driven: All settings in YAML, no code changes for different runs
Type-safe: Pydantic validation catches errors early
Modular: Easy to add new likelihoods, transforms, or samplers
Well-tested: Comprehensive test suite ensures reliability
Documented: Multiple docs at different detail levels
For end-to-end examples, see examples/inference/ directory. For specific component usage, see the other inference documentation files (quickstart, guide, workflow, yaml_reference).