Adding a New EOS Model
This guide walks through adding a new equation of state (EOS) model to JESTER. EOS models define the thermodynamic relationship between pressure, energy density, and other state variables in neutron star matter.
Architecture Overview
JESTER’s modular design separates EOS models from TOV solvers and inference. All EOS models inherit from Interpolate_EOS_model and are automatically compatible with all TOV solvers through the unified JesterTransform system.
Key files:
EOS implementation: jesterTOV/eos/your_eos.py
Configuration schema: jesterTOV/inference/config/schemas/eos.py
Transform factory: jesterTOV/inference/transforms/transform.py
Tests: tests/test_eos/test_your_eos.py
Step 1: Create EOS Class
Create your EOS model in jesterTOV/eos/ inheriting from Interpolate_EOS_model:
from jaxtyping import Array, Float

from jesterTOV.eos.base import Interpolate_EOS_model
from jesterTOV.tov.data_classes import EOSData


class MyNewEOS(Interpolate_EOS_model):
    """One-line description of your EOS model.

    Detailed description including:
    - Physical motivation
    - Parameter definitions
    - Key references

    References:
        Author et al. (Year). Title. Journal Volume, Page.
    """

    def __init__(self, config_param: float = 1.0):
        """Initialize EOS model with configuration parameters.

        Args:
            config_param: Description of configuration parameter
        """
        self.config_param = config_param

    def construct_eos(self, params: dict[str, float]) -> EOSData:
        """Build EOS from parameters.

        Args:
            params: Dictionary with keys matching get_required_parameters()

        Returns:
            EOSData: JAX NamedTuple with fields:
                - ns: Number densities (fm^-3)
                - ps: Pressures (geometric units)
                - hs: Specific enthalpies (dimensionless)
                - es: Energy densities (geometric units)
                - dloge_dlogps: Derivative d(log e)/d(log p)
                - cs2: Sound speed squared (c_s^2/c^2)
                - mu: Chemical potential (optional)
                - extra_constraints: Additional constraints (optional)
        """
        # Extract parameters (one per name in get_required_parameters())
        param1 = params["param1"]
        param2 = params["param2"]
        param3 = params["param3"]

        # Build EOS (your physics here); this must define
        # ns, ps, hs, es, dloge_dlogps and cs2 as arrays of equal length
        # ...

        return EOSData(
            ns=ns,
            ps=ps,
            hs=hs,
            es=es,
            dloge_dlogps=dloge_dlogps,
            cs2=cs2,
            mu=None,  # Optional
            extra_constraints=None,  # Optional
        )

    def get_required_parameters(self) -> list[str]:
        """Return list of required parameter names.

        These parameter names must match those in the prior file.
        Missing parameters will cause validation errors before sampling.
        """
        return ["param1", "param2", "param3"]
Critical considerations:
JAX compatibility: Use jax.numpy operations, avoid Python if statements on traced values
Geometric units: Return pressures/energies in geometric units (c=G=1)
Array shapes: All returned arrays must have same length (number of density points)
Sound speed: Ensure \(0 \leq c_s^2 \leq 1\) (causality constraint)
Monotonicity: Pressure and energy must be monotonically increasing with density
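As a rough illustration of these constraints, the sketch below builds a toy polytrope with jax.numpy and checks array lengths, causality, and monotonicity. The function name toy_polytrope_eos, its parameters, and the density grid are hypothetical, and the quantities are in arbitrary units rather than JESTER's unit system; it returns a plain dict instead of EOSData so that it runs without JESTER installed.

```python
import jax.numpy as jnp


def toy_polytrope_eos(K: float, Gamma: float, m: float = 1.0, n_points: int = 100):
    """Toy polytrope p = K * n**Gamma (hypothetical, arbitrary units)."""
    ns = jnp.linspace(0.1, 1.0, n_points)
    ps = K * ns**Gamma
    # Energy density: rest-mass term plus the polytropic internal energy.
    es = m * ns + ps / (Gamma - 1.0)
    # Sound speed squared dp/de; for this polytrope it equals Gamma * p / (e + p).
    cs2 = Gamma * ps / (es + ps)
    # d(log e)/d(log p) = (p / e) * de/dp = p / (e * cs2).
    dloge_dlogps = ps / (es * cs2)
    # Pseudo-enthalpy: integral of dp / (e + p), which is log((e + p) / (m * n)) here.
    hs = jnp.log((es + ps) / (m * ns))
    return {"ns": ns, "ps": ps, "hs": hs, "es": es,
            "dloge_dlogps": dloge_dlogps, "cs2": cs2}


eos = toy_polytrope_eos(K=0.1, Gamma=2.0)

# The same checks the list above asks for.
same_length = len({int(v.shape[0]) for v in eos.values()}) == 1
causal = bool(jnp.all((eos["cs2"] >= 0.0) & (eos["cs2"] <= 1.0)))
monotonic = bool(jnp.all(jnp.diff(eos["ps"]) > 0) & jnp.all(jnp.diff(eos["es"]) > 0))
print(same_length, causal, monotonic)
```

Everything is written with array operations and no Python branching on array values, so the same function body could be traced by JAX.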
Step 2: Update Configuration Schema
Add a concrete config class for your EOS to jesterTOV/inference/config/schemas/eos.py and include it in the EOSConfig discriminated union.
Inherit from the appropriate base class:
BaseEOSConfig: for any EOS type (provides only crust_name)
BaseMetamodelEOSConfig: for metamodel-based EOS types (adds ndat_metamodel, nmax_nsat, nmin_MM_nsat)
# In jesterTOV/inference/config/schemas/eos.py
class MyNewEOSConfig(BaseEOSConfig):
    """Configuration for MyNewEOS."""

    type: Literal["my_new_eos"] = "my_new_eos"

    # EOS-specific fields
    my_eos_param: float = Field(
        default=1.0,
        description="Configuration parameter for my new EOS",
    )


# Extend the discriminated union
EOSConfig = Annotated[
    Union[MetamodelEOSConfig, MetamodelCSEEOSConfig, SpectralEOSConfig, MyNewEOSConfig],
    Discriminator("type"),
]
Also import and re-export MyNewEOSConfig in schema.py (both in the import block and in __all__), and in config/__init__.py, so it is accessible as jesterTOV.inference.config.MyNewEOSConfig.
Regenerate YAML documentation:
uv run python -m jesterTOV.inference.config.generate_yaml_reference
This updates docs/inference/yaml_reference.md with your new EOS type.
Step 3: Register in Transform Factory
Add your EOS to jesterTOV/inference/transforms/transform.py using an isinstance check (so pyright can narrow types correctly):
from jesterTOV.inference.config.schema import (
    ...,
    MyNewEOSConfig,
)


def _create_eos(config: BaseEOSConfig, ...) -> Interpolate_EOS_model:
    """Create EOS model from configuration."""
    if isinstance(config, MetamodelEOSConfig):
        # ... existing code ...
    elif isinstance(config, MyNewEOSConfig):
        from jesterTOV.eos.my_new import MyNewEOS

        # Field name must match the one declared on MyNewEOSConfig
        return MyNewEOS(config_param=config.my_eos_param)
    else:
        raise ValueError(f"Unknown EOS config type: {type(config).__name__}")
No need to create new transform classes: JesterTransform handles all EOS × TOV combinations automatically.
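The isinstance-based dispatch can be tried in isolation with stand-in classes. The sketch below is not JESTER code: every Stub* name and the default values are hypothetical placeholders for the real config and EOS classes, and only the dispatch pattern itself is taken from the factory above.

```python
from dataclasses import dataclass


@dataclass
class StubBaseConfig:
    """Hypothetical stand-in for BaseEOSConfig."""
    crust_name: str = "stub_crust"


@dataclass
class StubMetamodelConfig(StubBaseConfig):
    """Hypothetical stand-in for an existing config type."""
    nmax_nsat: float = 2.0


@dataclass
class StubMyNewConfig(StubBaseConfig):
    """Hypothetical stand-in for MyNewEOSConfig."""
    my_eos_param: float = 1.0


class StubMetamodelEOS:
    def __init__(self, nmax_nsat: float):
        self.nmax_nsat = nmax_nsat


class StubMyNewEOS:
    def __init__(self, config_param: float):
        self.config_param = config_param


def create_eos(config: StubBaseConfig):
    """Pick the EOS class from the concrete config type."""
    if isinstance(config, StubMetamodelConfig):
        return StubMetamodelEOS(nmax_nsat=config.nmax_nsat)
    elif isinstance(config, StubMyNewConfig):
        # Inside this branch a type checker knows config has my_eos_param.
        return StubMyNewEOS(config_param=config.my_eos_param)
    raise ValueError(f"Unknown EOS config type: {type(config).__name__}")


model = create_eos(StubMyNewConfig(my_eos_param=2.5))
print(type(model).__name__, model.config_param)
```

Passing a config type that has no branch (here the bare StubBaseConfig) falls through to the ValueError, which is the behaviour to expect if a new config class is added to the union but not registered in the factory.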
Step 4: Create Prior File
Create a .prior file for your new EOS parameters:
# examples/inference/my_new_eos/prior.prior
# Required parameters (must match get_required_parameters())
param1 = UniformPrior(0.0, 10.0, parameter_names=["param1"])
param2 = UniformPrior(-5.0, 5.0, parameter_names=["param2"])
param3 = UniformPrior(100.0, 500.0, parameter_names=["param3"])
Parameter validation: The inference system validates that all parameters in get_required_parameters() are present in the prior file. Missing parameters cause an error before sampling starts.
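The idea behind this check can be sketched in a few lines of plain Python. The helper find_missing_parameters below is hypothetical and is not JESTER's actual validation code; it only illustrates comparing the names from get_required_parameters() with the names defined in a prior file.

```python
def find_missing_parameters(required: list[str], prior_names: list[str]) -> list[str]:
    """Return required parameter names that the prior does not define."""
    available = set(prior_names)
    return [name for name in required if name not in available]


required = ["param1", "param2", "param3"]

# Complete prior: nothing is missing.
print(find_missing_parameters(required, ["param1", "param2", "param3"]))

# Typo in the prior file ("param_3" instead of "param3"): the mismatch is reported.
print(find_missing_parameters(required, ["param1", "param2", "param_3"]))
```

Names are compared exactly, so a difference in case or an extra underscore counts as a missing parameter.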
Step 5: Write Tests
Create comprehensive tests in tests/test_eos/test_my_new_eos.py:
import jax.numpy as jnp
import pytest

from jesterTOV.eos.my_new import MyNewEOS
from jesterTOV.tov import GRTOVSolver


class TestMyNewEOS:
    """Test suite for MyNewEOS model."""

    @pytest.fixture
    def test_params(self):
        """Valid parameter set for testing."""
        return {
            "param1": 5.0,
            "param2": 0.0,
            "param3": 250.0,
        }

    def test_construct_eos_shape(self, test_params):
        """Test that construct_eos returns correct array shapes."""
        model = MyNewEOS()
        eos_data = model.construct_eos(test_params)

        # All arrays must have same length
        n = len(eos_data.ns)
        assert len(eos_data.ps) == n
        assert len(eos_data.hs) == n
        assert len(eos_data.es) == n

    def test_causality(self, test_params):
        """Test that sound speed satisfies causality (cs^2 <= 1)."""
        model = MyNewEOS()
        eos_data = model.construct_eos(test_params)

        assert jnp.all(eos_data.cs2 >= 0.0)
        assert jnp.all(eos_data.cs2 <= 1.0)

    def test_monotonicity(self, test_params):
        """Test that pressure and energy increase with density."""
        model = MyNewEOS()
        eos_data = model.construct_eos(test_params)

        # Pressure monotonically increasing
        assert jnp.all(jnp.diff(eos_data.ps) > 0)
        # Energy monotonically increasing
        assert jnp.all(jnp.diff(eos_data.es) > 0)

    def test_integration_with_tov(self, test_params):
        """Test that EOS works with TOV solver."""
        model = MyNewEOS()
        eos_data = model.construct_eos(test_params)

        solver = GRTOVSolver()
        family_data = solver.construct_family(
            eos_data,
            ndat=50,
            min_nsat=0.75,
        )

        # Should produce physical mass-radius curves
        assert jnp.all(family_data.masses > 0.0)
        assert jnp.all(family_data.radii > 0.0)

    def test_required_parameters(self):
        """Test that required parameters list is correct."""
        model = MyNewEOS()
        required = model.get_required_parameters()

        assert "param1" in required
        assert "param2" in required
        assert "param3" in required

    def test_missing_parameter_raises(self):
        """Test that missing parameters cause clear errors."""
        model = MyNewEOS()

        with pytest.raises(KeyError):
            model.construct_eos({"param1": 1.0})  # Missing param2, param3
Test categories:
Shape validation: Array dimensions match expected
Physical constraints: Causality, positivity, monotonicity
Integration tests: Works with TOV solvers and inference
Error handling: Missing/invalid parameters handled gracefully
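Since the EOS has to be JAX-compatible, it may also be worth checking that the construction survives jax.jit and jax.grad. The sketch below shows the shape of such a check on a toy function; toy_construct and its parameters K and Gamma are hypothetical stand-ins for a real construct_eos, and this is a suggestion rather than a test that JESTER is known to require.

```python
import jax
import jax.numpy as jnp


def toy_construct(params: dict) -> jnp.ndarray:
    """Hypothetical stand-in for construct_eos: returns only a pressure array."""
    ns = jnp.linspace(0.1, 1.0, 50)
    return params["K"] * ns ** params["Gamma"]


params = {"K": 0.1, "Gamma": 2.0}

# The jitted result should agree with the eager one.
eager = toy_construct(params)
jitted = jax.jit(toy_construct)(params)
jit_matches = bool(jnp.allclose(eager, jitted))

# Gradients with respect to every parameter should be finite.
grads = jax.grad(lambda p: jnp.sum(toy_construct(p)))(params)
grads_finite = all(bool(jnp.isfinite(g)) for g in grads.values())

print(jit_matches, grads_finite)
```

A failure under jax.jit usually points to Python control flow on traced values, and a non-finite gradient often points to a division by zero or a log of a non-positive number somewhere in the construction.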
Step 6: Create Example Configuration
Create an example in examples/inference/my_new_eos/:
# config.yaml
seed: 42

eos:
  type: "my_new_eos"
  my_eos_param: 1.0

tov:
  type: "gr"
  min_nsat_TOV: 0.75
  ndat_TOV: 100
  nb_masses: 100

prior:
  specification_file: "prior.prior"

likelihoods:
  - type: "eos_constraints"
    enabled: true

sampler:
  type: "smc-rw"
  n_particles: 1000
  n_mcmc_steps: 10
  target_ess: 0.9
outdir: "outdir"
Test the example runs successfully:
cd examples/inference/my_new_eos
uv run run_jester_inference config.yaml
Step 7: Documentation
Add documentation to docs/overview/eos/:
# My New EOS
One-sentence description of the EOS model.
## Physical Motivation
Explain the physics behind the model...
## Parameters
| Parameter | Description | Typical Range |
|-----------|-------------|---------------|
| param1 | Description | 0-10 |
| param2 | Description | -5 to 5 |
| param3 | Description | 100-500 |
## References
- Author et al. (Year). Title. *Journal* Volume, Page.
## Usage
\```python
from jesterTOV.eos.my_new import MyNewEOS
model = MyNewEOS()
params = {"param1": 5.0, "param2": 0.0, "param3": 250.0}
eos_data = model.construct_eos(params)
\```
Update docs/overview/eos.rst to include your new EOS.
Checklist
Before submitting a PR:
[ ] EOS class inherits from Interpolate_EOS_model
[ ] construct_eos() returns valid EOSData NamedTuple
[ ] get_required_parameters() lists all required parameters
[ ] Added MyNewEOSConfig to schemas/eos.py and included in EOSConfig union
[ ] Regenerated YAML reference
[ ] Registered in _create_eos() factory
[ ] Comprehensive tests written and passing
[ ] Example configuration runs successfully
[ ] Documentation added
[ ] All arrays are JAX-compatible (no Python if on traced values)
[ ] Causality constraint satisfied (\(c_s^2 \leq 1\))
[ ] Monotonicity verified (pressure/energy increase with density)
[ ] Type hints added for all functions
[ ] Code passes uv run pyright jesterTOV/eos/
[ ] Tests pass: uv run pytest tests/test_eos/test_my_new_eos.py -v
Common Pitfalls
Python control flow on JAX arrays: Avoid if statements on traced values. Use jnp.where() instead.
# BAD: Fails during JAX tracing
if param > 0:
    result = compute_positive(param)
else:
    result = compute_negative(param)

# GOOD: JAX-compatible
result = jnp.where(
    param > 0,
    compute_positive(param),
    compute_negative(param),
)
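A self-contained version of this comparison is sketched below. The two branch functions are hypothetical toy stand-ins for compute_positive and compute_negative. Note that jnp.where evaluates both branches for every element and then selects, so each branch must be safe to compute for all inputs.

```python
import jax
import jax.numpy as jnp


def python_branch(param):
    # Python `if` needs a concrete bool, which a traced value cannot provide.
    if param > 0:
        return param**2
    return -param


def where_branch(param):
    # Both expressions are computed; jnp.where picks one per element.
    return jnp.where(param > 0, param**2, -param)


try:
    jax.jit(python_branch)(jnp.array(3.0))
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    python_if_failed = True

positive_result = float(jax.jit(where_branch)(jnp.array(3.0)))
negative_result = float(jax.jit(where_branch)(jnp.array(-2.0)))
print(python_if_failed, positive_result, negative_result)
```

When only one branch should actually run, for example because the other is expensive, jax.lax.cond is the usual alternative for scalar predicates.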
Unit inconsistencies: JESTER uses geometric units internally (c=G=1). Ensure your EOS returns pressures and energies in geometric units. Conversions to physical units (M☉, km) happen only in output.
EOSData field requirements: While mu and extra_constraints are optional, ns, ps, hs, es, dloge_dlogps, and cs2 are required. All must have the same array length.
Parameter naming: Parameter names in get_required_parameters() must exactly match those used in prior files. Mismatches cause validation errors.
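For the unit pitfall above, the arithmetic behind a conversion from MeV fm^-3 to geometric units can be checked by hand. The sketch below derives the factor from SI constants, assuming geometric pressures and energy densities are expressed in km^-2; that length unit is an assumption here, so confirm it against the conversion constants that JESTER itself provides and prefer those in real code.

```python
# SI constants (CODATA values).
MEV_IN_JOULE = 1.602176634e-13  # 1 MeV in J
FM3_IN_M3 = 1.0e-45             # 1 fm^3 in m^3
G_SI = 6.67430e-11              # m^3 kg^-1 s^-2
C_SI = 299792458.0              # m s^-1

# 1 MeV fm^-3 expressed in J m^-3 (equivalently Pa).
mev_fm3_to_pascal = MEV_IN_JOULE / FM3_IN_M3

# Multiplying an energy density by G / c^4 gives a quantity in m^-2.
pascal_to_inv_m2 = G_SI / C_SI**4

# Assumed convention: lengths in km, so 1 m^-2 = 1e6 km^-2.
mev_fm3_to_inv_km2 = mev_fm3_to_pascal * pascal_to_inv_m2 * 1.0e6

print(f"1 MeV fm^-3 = {mev_fm3_to_inv_km2:.4e} km^-2")
```

The result is of order 1e-6 km^-2 per MeV fm^-3, so an EOS that returns pressures in MeV fm^-3 by mistake would be off by roughly six orders of magnitude.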
Need Help?
See existing implementations: jesterTOV/eos/metamodel/base.py, jesterTOV/eos/spectral/
Review jesterTOV/inference/CLAUDE.md for inference system architecture
Check JAX documentation for array operations: https://jax.readthedocs.io/