Adding a New TOV Solver#
This guide walks through adding a new Tolman-Oppenheimer-Volkoff (TOV) equation solver to JESTER. TOV solvers integrate the structure equations for neutron stars, computing mass-radius-tidal deformability (M-R-Λ) relationships from an equation of state.
Architecture Overview#
TOV solvers are independent of specific EOS models and work with any EOS through the unified EOSData interface. All solvers inherit from TOVSolverBase and are automatically compatible with all EOS models through JesterTransform.
Key files:
TOV solver implementation:
jesterTOV/tov/your_solver.pyConfiguration schema:
jesterTOV/inference/config/schemas/tov.pyTransform factory:
jesterTOV/inference/transforms/transform.pyTests:
tests/test_tov/test_your_solver.py
Step 1: Create TOV Solver Class#
Create your solver in jesterTOV/tov/ inheriting from TOVSolverBase:
import jax.numpy as jnp
from jaxtyping import Array, Float
from jesterTOV.tov.base import TOVSolverBase
from jesterTOV.tov.data_classes import EOSData, TOVSolution, FamilyData
import diffrax
class MyNewTOVSolver(TOVSolverBase):
"""One-line description of your TOV solver.
Detailed description including:
- Gravity theory or modification being implemented
- Key equations
- References
References:
Author et al. (Year). Title. Journal Volume, Page.
"""
def __init__(self, coupling_constant: float = 0.0):
"""Initialize TOV solver with theory-specific parameters.
Args:
coupling_constant: Description of coupling parameter
"""
self.coupling_constant = coupling_constant
def solve(
self,
eos_data: EOSData,
pc: float,
**kwargs: Any
) -> TOVSolution:
"""Solve TOV equations for a single central pressure.
Args:
eos_data: EOS data from any EOS model
pc: Central pressure (geometric units)
**kwargs: Additional parameters from prior (if needed)
Returns:
TOVSolution: NamedTuple with:
- M: Total mass (solar masses)
- R: Radius (km)
- k2: Tidal Love number (dimensionless)
"""
# Extract additional parameters if needed
theory_param = kwargs.get("theory_param", self.coupling_constant)
# Define your modified TOV equations
def ode_system(t, y, args):
"""ODE system: dy/dt = f(t, y).
Args:
t: Independent variable (typically radius)
y: State vector [mass, pressure, ...]
args: Additional arguments (eos_data, etc.)
Returns:
Derivatives dy/dt
"""
# Your equations here
# ...
return derivatives
# Solve using Diffrax
solver = diffrax.Dopri5()
term = diffrax.ODETerm(ode_system)
# Set up initial conditions
y0 = jnp.array([...]) # Initial state
# Solve with adaptive stepping
sol = diffrax.diffeqsolve(
term,
solver,
t0=0.0,
t1=t_final,
dt0=0.01,
y0=y0,
args=(eos_data,),
max_steps=10000,
throw=False # Graceful failure handling
)
# Extract solution
M = sol.ys[0][-1] # type: ignore[index]
R = sol.ys[1][-1] # type: ignore[index]
# Compute Love number (theory-specific)
k2 = self._compute_love_number(sol, eos_data)
# Convert to physical units
from jesterTOV.utils import Msun_km
M_solar = M / Msun_km
R_km = R
return TOVSolution(M=M_solar, R=R_km, k2=k2)
def construct_family(
self,
eos_data: EOSData,
ndat: int,
min_nsat: float,
**kwargs: Any
) -> FamilyData:
"""Build M-R-Λ family curves.
Args:
eos_data: EOS data from any EOS model
ndat: Number of central pressures to sample
min_nsat: Minimum density as fraction of saturation density
**kwargs: Additional theory parameters from prior
Returns:
FamilyData: NamedTuple with:
- log10pcs: Log10 central pressures
- masses: Gravitational masses (M☉)
- radii: Radii (km)
- lambdas: Tidal deformabilities (dimensionless)
"""
# Generate central pressure grid
log10_pc_min = ... # Based on min_nsat
log10_pc_max = ...
log10pcs = jnp.linspace(log10_pc_min, log10_pc_max, ndat)
# Vectorize solve() over central pressures
solve_vec = jax.vmap(
lambda pc: self.solve(eos_data, 10**pc, **kwargs)
)
solutions = solve_vec(log10pcs)
# Extract arrays (vmap batches scalar fields → arrays)
masses: Float[Array, "ndat"] = solutions.M # type: ignore[assignment]
radii: Float[Array, "ndat"] = solutions.R # type: ignore[assignment]
k2s: Float[Array, "ndat"] = solutions.k2 # type: ignore[assignment]
# Compute dimensionless tidal deformability
lambdas = self._compute_lambda(masses, radii, k2s)
return FamilyData(
log10pcs=log10pcs,
masses=masses,
radii=radii,
lambdas=lambdas
)
def get_required_parameters(self) -> list[str]:
"""Return list of additional parameters beyond EOS.
These are theory-specific coupling constants or other parameters
that must be present in the prior file.
Returns:
List of parameter names required from prior
"""
return ["theory_param"] # Or empty list if no additional params
Critical considerations:
JAX compatibility: Use
jax.numpy,jax.vmap, avoid Python control flow on traced valuesDiffrax for ODE solving: Use
Dopri5orTsit5with adaptive steppingGraceful failure: Set
throw=Falseindiffeqsolveto handle divergent solutionsType ignores for vmap:
vmapbatches scalar NamedTuple fields into arrays, requiring type ignoresUnit conventions: Return masses in M☉, radii in km, Λ dimensionless
Step 2: Update Configuration Schema#
Add a concrete config class for your solver to jesterTOV/inference/config/schemas/tov.py and include it in the TOVConfig discriminated union. BaseTOVConfig already provides min_nsat_TOV, ndat_TOV, and nb_masses; only add solver-specific fields to your subclass:
# In jesterTOV/inference/config/schemas/tov.py
class MyNewTOVConfig(BaseTOVConfig):
"""Configuration for MyNewTOVSolver."""
type: Literal["my_new_solver"] = "my_new_solver" # type: ignore[override]
# Solver-specific fields
my_solver_coupling: float = Field(
default=0.0,
description="Coupling constant for my new solver"
)
# Switch TOVConfig to a discriminated union
TOVConfig = Annotated[
Union[GRTOVConfig, MyNewTOVConfig],
Discriminator("type"),
]
Also export MyNewTOVConfig from schema.py.
Regenerate YAML documentation:
uv run python -m jesterTOV.inference.config.generate_yaml_reference
Step 3: Register in Transform Factory#
Add your solver to jesterTOV/inference/transforms/transform.py using an isinstance check:
from jesterTOV.inference.config.schema import BaseTOVConfig, GRTOVConfig, MyNewTOVConfig
def _create_tov_solver(config: BaseTOVConfig) -> TOVSolverBase:
"""Create TOV solver from configuration."""
if isinstance(config, GRTOVConfig):
return GRTOVSolver()
elif isinstance(config, MyNewTOVConfig):
from jesterTOV.tov.my_new import MyNewTOVSolver
return MyNewTOVSolver(coupling_constant=config.my_solver_coupling)
else:
raise ValueError(f"Unknown TOV config type: {type(config).__name__}")
Step 4: Create Prior File#
If your solver requires additional parameters beyond the EOS:
# examples/inference/my_new_solver/prior.prior
# EOS parameters (e.g., for MetaModel)
E_sat = UniformPrior(-16.1, -15.9, parameter_names=["E_sat"])
K_sat = UniformPrior(150.0, 300.0, parameter_names=["K_sat"])
# ... other EOS params ...
# Theory-specific parameters
theory_param = UniformPrior(-1.0, 1.0, parameter_names=["theory_param"])
Parameter validation: The inference system checks that all parameters in get_required_parameters() are present. Missing parameters cause errors before sampling.
Step 5: Write Tests#
Create comprehensive tests in tests/test_tov/test_my_new_solver.py:
import jax.numpy as jnp
import pytest
from jesterTOV.tov.my_new import MyNewTOVSolver
from jesterTOV.eos.metamodel import MetaModel_EOS_model
class TestMyNewTOVSolver:
"""Test suite for MyNewTOVSolver."""
@pytest.fixture
def eos_data(self):
"""Valid EOS data for testing."""
model = MetaModel_EOS_model()
params = {
"E_sat": -16.0,
"K_sat": 240.0,
"Q_sat": 400.0,
"Z_sat": 0.0,
"E_sym": 31.7,
"L_sym": 58.7,
"K_sym": -100.0,
"Q_sym": 0.0,
"Z_sym": 0.0,
}
return model.construct_eos(params)
def test_solve_single_star(self, eos_data):
"""Test single star solution."""
solver = MyNewTOVSolver()
pc = 1e34 # Example central pressure
solution = solver.solve(eos_data, pc)
# Physical mass and radius
assert solution.M > 0.0
assert solution.M < 3.0 # Reasonable mass range
assert solution.R > 5.0
assert solution.R < 20.0 # Reasonable radius range
# Love number range
assert solution.k2 >= 0.0
assert solution.k2 <= 1.0
def test_construct_family(self, eos_data):
"""Test M-R-Λ family construction."""
solver = MyNewTOVSolver()
family_data = solver.construct_family(
eos_data,
ndat=50,
min_nsat=0.75
)
# Check shapes
assert len(family_data.log10pcs) == 50
assert len(family_data.masses) == 50
assert len(family_data.radii) == 50
assert len(family_data.lambdas) == 50
# Physical ranges
assert jnp.all(family_data.masses > 0.0)
assert jnp.all(family_data.radii > 0.0)
assert jnp.all(family_data.lambdas >= 0.0)
# Find maximum mass
max_mass_idx = jnp.argmax(family_data.masses)
M_max = family_data.masses[max_mass_idx]
assert M_max > 1.0 # Reasonable maximum mass
def test_gr_limit(self, eos_data):
"""Test that solver reduces to GR when coupling = 0."""
from jesterTOV.tov.gr import GRTOVSolver
solver_new = MyNewTOVSolver(coupling_constant=0.0)
solver_gr = GRTOVSolver()
pc = 1e34
sol_new = solver_new.solve(eos_data, pc)
sol_gr = solver_gr.solve(eos_data, pc)
# Should match GR to numerical precision
assert jnp.allclose(sol_new.M, sol_gr.M, rtol=1e-6)
assert jnp.allclose(sol_new.R, sol_gr.R, rtol=1e-6)
def test_additional_parameters(self, eos_data):
"""Test that additional parameters from prior are used."""
solver = MyNewTOVSolver()
pc = 1e34
sol1 = solver.solve(eos_data, pc, theory_param=0.0)
sol2 = solver.solve(eos_data, pc, theory_param=1.0)
# Different parameters should give different results
assert not jnp.allclose(sol1.M, sol2.M)
def test_required_parameters(self):
"""Test that required parameters list is correct."""
solver = MyNewTOVSolver()
required = solver.get_required_parameters()
if required: # If solver needs additional params
assert "theory_param" in required
@pytest.mark.slow
def test_convergence_tolerance(self, eos_data):
"""Test numerical convergence with different tolerances."""
solver = MyNewTOVSolver()
pc = 1e34
# Run with different max_steps (if configurable)
# Verify solution converges
Test categories:
Single star: Test
solve()produces physical M, R, k2Family curves: Test
construct_family()produces valid M-R-ΛLimit checks: Verify reduction to GR when appropriate
Parameter sensitivity: Additional parameters affect results
Convergence: Numerical accuracy tests
Step 6: Create Example Configuration#
Create an example in examples/inference/my_new_solver/:
# config.yaml
seed: 42
eos:
type: "metamodel"
ndat_metamodel: 100
nmin_MM_nsat: 0.75
tov:
type: "my_new_solver"
my_solver_coupling: 0.1
min_nsat_TOV: 0.75
ndat_TOV: 100
nb_masses: 100
prior: "prior.prior"
likelihoods:
- type: "nicer"
enabled: true
sources: ["J0030"]
sampler:
type: "smc-rw"
n_particles: 1000
n_mcmc_steps: 10
target_ess: 0.9
outdir: "outdir"
Test the example runs successfully:
cd examples/inference/my_new_solver
uv run run_jester_inference config.yaml
Step 7: Documentation#
Add documentation to docs/overview/tov/:
# My New TOV Solver
One-sentence description of the modified gravity theory.
## Theory Background
Explain the gravity theory modifications...
## Equations
Key equations (use LaTeX):
$$
\frac{dM}{dr} = 4\pi r^2 \epsilon(r) + f(\alpha, r)
$$
Where $f(\alpha, r)$ represents the modification.
## Parameters
| Parameter | Description | Typical Range |
|-----------|-------------|---------------|
| theory_param | Coupling constant | -1 to 1 |
## GR Limit
Describe when the solver reduces to General Relativity...
## References
- Author et al. (Year). Title. *Journal* Volume, Page.
## Usage
\```python
from jesterTOV.tov.my_new import MyNewTOVSolver
from jesterTOV.eos.metamodel import MetaModel_EOS_model
# Build EOS
model = MetaModel_EOS_model()
params = {...}
eos_data = model.construct_eos(params)
# Solve TOV
solver = MyNewTOVSolver(coupling_constant=0.1)
family = solver.construct_family(eos_data, ndat=100, min_nsat=0.75)
\```
Update docs/overview/tov_solvers.rst to include your new solver.
Checklist#
Before submitting a PR:
[ ] Solver inherits from
TOVSolverBase[ ]
solve()returns validTOVSolution(M, R, k2)[ ]
construct_family()returns validFamilyData[ ]
get_required_parameters()lists additional parameters[ ] Added
MyNewTOVConfigtoschemas/tov.pyand included inTOVConfigunion[ ] Regenerated YAML reference
[ ] Registered in
_create_tov_solver()factory[ ] Comprehensive tests written and passing
[ ] Example configuration runs successfully
[ ] Documentation added with equations and references
[ ] Uses Diffrax for ODE integration
[ ] JAX-compatible (no Python
ifon traced values)[ ] Type hints for all functions
[ ] Graceful failure handling (
throw=False)[ ] Reduces to GR when appropriate (if applicable)
[ ] Type ignores documented for vmap patterns
[ ] Code passes
uv run pyright jesterTOV/tov/[ ] Tests pass:
uv run pytest tests/test_tov/test_my_new_solver.py -v
Common Pitfalls#
Type ignores for vmap: When using vmap to batch solve(), scalar fields in TOVSolution become arrays:
# This is correct and required
masses: Float[Array, "ndat"] = solutions.M # type: ignore[assignment]
Diffrax throw parameter: Always use throw=False to handle divergent solutions gracefully:
sol = diffrax.diffeqsolve(..., throw=False)
# Check if solution succeeded
if sol.result != 0:
# Handle failure (return NaN or default values)
Chemical potential access: If your solver needs mu from EOSData, check if it’s populated:
if eos_data.mu is None:
raise ValueError("This solver requires mu from EOS")
mu: Float[Array, "n"] = eos_data.mu # type: ignore[assignment]
Unit conversions: Remember JESTER uses geometric units internally. Convert to physical units only at the end:
from jesterTOV.utils import Msun_km, c_km, G_km
M_solar = M_geometric / Msun_km # Convert to solar masses
R_km = R_geometric # Already in km if using correct units
Need Help?#
See existing implementations:
jesterTOV/tov/gr.py,jesterTOV/tov/scalar_tensor.pyReview BlackJAX source code and documentation for sampler best practices
Check Diffrax documentation: https://docs.kidger.site/diffrax/
Review
jesterTOV/inference/CLAUDE.mdfor inference system architecture