# Adding a New Likelihood
This guide walks through adding a new likelihood constraint to JESTER’s inference system. Likelihoods encode observational constraints from multi-messenger astronomy (gravitational waves, X-ray timing, radio observations, nuclear experiments).
## Architecture Overview
Likelihoods evaluate the probability of observing data given EOS parameters. All likelihoods inherit from LikelihoodBase and are created through a factory pattern. The inference system automatically combines multiple likelihoods with equal weighting.
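The pattern described above — each likelihood exposing an `evaluate` method that returns a log probability, and the framework summing the individual log probabilities with equal weight — can be sketched in simplified form. This is an illustrative stand-in, not the real jesterTOV API: `SketchLikelihoodBase`, `GaussianRadiusLikelihood`, and `combined_log_prob` are hypothetical names.

```python
# Simplified sketch of the likelihood pattern (hypothetical names, not jesterTOV's API)
class SketchLikelihoodBase:
    """Stand-in for LikelihoodBase: subclasses implement evaluate()."""

    def evaluate(self, params: dict, data) -> float:
        raise NotImplementedError


class GaussianRadiusLikelihood(SketchLikelihoodBase):
    """Toy Gaussian constraint on a single radius."""

    def __init__(self, obs_radius: float, sigma: float):
        self.obs_radius = obs_radius
        self.sigma = sigma

    def evaluate(self, params, data=None):
        residual = (params["radius"] - self.obs_radius) / self.sigma
        return -0.5 * residual**2


def combined_log_prob(likelihoods, params):
    # Equal weighting: the total log probability is the plain sum
    return sum(lk.evaluate(params, None) for lk in likelihoods)


lks = [GaussianRadiusLikelihood(12.0, 0.5), GaussianRadiusLikelihood(11.5, 0.7)]
total = combined_log_prob(lks, {"radius": 12.0})
```

Because the combination is a simple sum of log probabilities, each likelihood only needs to worry about its own constraint.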
Key files:
Likelihood implementation:
jesterTOV/inference/likelihoods/your_likelihood.pyConfiguration schema:
jesterTOV/inference/config/schema.pyLikelihood factory:
jesterTOV/inference/likelihoods/factory.pyData loading:
jesterTOV/inference/data/__init__.pyTests:
tests/test_inference/test_likelihoods/test_your_likelihood.py
## Step 1: Create Likelihood Class

Create your likelihood in `jesterTOV/inference/likelihoods/`, inheriting from `LikelihoodBase`:
```python
from typing import Any

import jax.numpy as jnp
from jaxtyping import Array, Float

from jesterTOV.inference.base.likelihood import LikelihoodBase


class MyNewLikelihood(LikelihoodBase):
    """One-line description of observational constraint.

    Detailed description including:
    - What observation/experiment this represents
    - Data sources and references
    - Statistical model used

    References:
        Author et al. (Year). Title. Journal Volume, Page.
    """

    def __init__(self, data: Any):
        """Initialize likelihood with observational data.

        Args:
            data: Observational data structure (format depends on likelihood).
                Common formats: dict, JAX arrays, KDE objects.
        """
        self.data = data

    def evaluate(self, params: dict[str, float], data: Any) -> float:
        """Compute log probability for parameters.

        This method is called by the sampler during inference. It receives
        the transformed parameters (including M-R-Λ curves from
        JesterTransform) and returns the log likelihood.

        Args:
            params: Dictionary containing:
                - EOS parameters (e.g., K_sat, L_sym)
                - TOV parameters (if any)
                - Derived quantities from the transform:
                    - masses_EOS: Stellar masses (M☉)
                    - radii_EOS: Stellar radii (km)
                    - Lambdas_EOS: Tidal deformabilities
                    - logpc_EOS: Log10 central pressures
            data: Observational data (passed through from __init__).

        Returns:
            Log probability: log P(data | params)
        """
        # Extract relevant quantities from params
        masses = params.get("masses_EOS")  # From JesterTransform
        radii = params.get("radii_EOS")    # From JesterTransform

        # Your likelihood calculation here.
        # Example: Gaussian likelihood for a mass-radius constraint.
        log_likelihood = 0.0

        # Add contributions from each observation
        for obs in self.data["observations"]:
            # Interpolate the model prediction at the observed mass
            predicted = self._interpolate_model(masses, radii, obs["mass"])

            # Gaussian log likelihood
            residual = (predicted - obs["radius"]) / obs["error"]
            log_likelihood += -0.5 * residual**2

        return log_likelihood

    def _interpolate_model(
        self,
        masses: Float[Array, "n"],
        radii: Float[Array, "n"],
        target_mass: float,
    ) -> float:
        """Helper to interpolate the model prediction.

        Args:
            masses: Model mass grid
            radii: Model radius grid
            target_mass: Mass at which to evaluate

        Returns:
            Interpolated radius
        """
        return jnp.interp(target_mass, masses, radii)
```
Critical considerations:

- JAX compatibility: use `jax.numpy` for all computations
- Vectorization: design for efficient evaluation over posterior samples
- Numerical stability: handle edge cases (division by zero, log of a negative number)
- Transform integration: know what quantities `JesterTransform` provides
- Return log probabilities: never return raw probabilities (numerical underflow)
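To make the numerical-stability and log-probability points concrete, here is a hedged sketch (the function `kde_log_prob` is illustrative, not part of jesterTOV) of evaluating a Gaussian-mixture (KDE-like) constraint entirely in log space. `logsumexp` avoids the underflow you would hit by summing raw kernel probabilities:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def kde_log_prob(x, centers, bandwidth):
    """Log of (1/N) * sum_i Normal(x; center_i, bandwidth), computed in log space."""
    z = (x - centers) / bandwidth
    # Log of each Gaussian kernel; no exp() is ever taken on tiny values
    log_kernels = -0.5 * z**2 - jnp.log(bandwidth * jnp.sqrt(2.0 * jnp.pi))
    # logsumexp subtracts the max internally, so the sum never underflows
    return logsumexp(log_kernels) - jnp.log(centers.shape[0])


centers = jnp.array([11.8, 12.1, 12.4])
# Far out in the tail, the raw probability would underflow to exactly 0.0;
# the log probability stays finite and informative.
lp = kde_log_prob(50.0, centers, 0.3)
```

Summing raw probabilities here would return `-inf` after taking the log; the log-space version keeps a finite gradient for the sampler.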
## Step 2: Add Data Loading Function

If your likelihood requires external data, add a loading function to `jesterTOV/inference/data/__init__.py`:
```python
import json
from pathlib import Path
from typing import Any


def load_my_observation_data() -> dict[str, Any]:
    """Load observational data for MyNewLikelihood.

    Returns:
        Dictionary with observational data:
        - observations: List of measurements
        - metadata: Additional information
    """
    # Option 1: Load from package data
    data_path = Path(__file__).parent / "data_files" / "my_observation.json"
    with open(data_path) as f:
        data = json.load(f)

    # Option 2: Download from Zenodo and cache
    from jesterTOV.inference.data.paths import get_data_path

    data_path = get_data_path(
        "my_observation.h5",
        zenodo_url="https://zenodo.org/record/XXXXX/files/my_observation.h5",
    )
    data = load_hdf5(data_path)

    return {
        "observations": data["measurements"],
        "metadata": data["metadata"],
    }
```
Data management best practices:

- Small data (<1 MB): include in `jesterTOV/inference/data/data_files/`
- Large data (>1 MB): host on Zenodo, cache locally with `get_data_path()`
- Validate the data format in the loading function
- Document data provenance and references
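The "validate the data format" advice can be sketched as a small guard in the loading function. This is a hypothetical helper (`validate_observation_data` is not a jesterTOV function), shown only to illustrate the kind of checks worth doing before the data reaches the likelihood:

```python
def validate_observation_data(data: dict) -> dict:
    """Illustrative sketch: fail fast on malformed observational data."""
    if "observations" not in data:
        raise ValueError("missing 'observations' key")
    for obs in data["observations"]:
        # Every measurement needs a mass, a radius, and a positive uncertainty
        missing = {"mass", "radius", "error"} - obs.keys()
        if missing:
            raise ValueError(f"observation missing fields: {missing}")
        if obs["error"] <= 0:
            raise ValueError("'error' must be positive")
    return data
```

Failing loudly at load time is much easier to debug than a silent NaN deep inside a jitted likelihood evaluation.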
## Step 3: Update Configuration Schema

Add your likelihood config class to `jesterTOV/inference/config/schemas/likelihoods.py` and extend the `LikelihoodConfig` discriminated union there:
```python
# In jesterTOV/inference/config/schemas/likelihoods.py
class MyNewLikelihoodConfig(BaseLikelihoodConfig):
    """Configuration for MyNewLikelihood."""

    type: Literal["my_new_likelihood"] = "my_new_likelihood"

    # Add likelihood-specific parameters
    observation_set: str = Field(
        default="default",
        description="Which observation dataset to use",
    )
    systematic_error: float = Field(
        default=0.0,
        description="Additional systematic uncertainty (km)",
    )


# Extend the discriminated union at the bottom of the file
LikelihoodConfig = Annotated[
    Union[
        GWLikelihoodConfig,
        NICERLikelihoodConfig,
        # ... existing configs ...
        MyNewLikelihoodConfig,
    ],
    Discriminator("type"),
]
```
Then re-export `MyNewLikelihoodConfig` from both `config/schema.py` (in the import block and `__all__`) and `config/__init__.py` so it is accessible as `jesterTOV.inference.config.MyNewLikelihoodConfig`.
Regenerate the YAML documentation:

```bash
uv run python -m jesterTOV.inference.config.generate_yaml_reference
```
## Step 4: Register in Likelihood Factory

Add creation logic to `jesterTOV/inference/likelihoods/factory.py`:
```python
def create_likelihood(
    config: LikelihoodConfig,
    transform: Optional[NtoMTransform] = None,
) -> LikelihoodBase:
    """Create likelihood from configuration."""
    if config.type == "my_new_likelihood":
        from jesterTOV.inference.data import load_my_observation_data
        from jesterTOV.inference.likelihoods.my_new import MyNewLikelihood

        data = load_my_observation_data()

        # Apply configuration options
        if config.systematic_error > 0:
            # Add the systematic uncertainty in quadrature
            for obs in data["observations"]:
                obs["error"] = jnp.sqrt(
                    obs["error"]**2 + config.systematic_error**2
                )

        return MyNewLikelihood(data)

    # ... other likelihood types ...
```
## Step 5: Write Tests

Create comprehensive tests in `tests/test_inference/test_likelihoods/test_my_new_likelihood.py`:
```python
import jax.numpy as jnp
import pytest

from jesterTOV.inference.likelihoods.my_new import MyNewLikelihood
from jesterTOV.eos.metamodel import MetaModel_EOS_model
from jesterTOV.tov.gr import GRTOVSolver


class TestMyNewLikelihood:
    """Test suite for MyNewLikelihood."""

    @pytest.fixture
    def mock_data(self):
        """Create mock observational data."""
        return {
            "observations": [
                {"mass": 1.4, "radius": 12.0, "error": 0.5},
                {"mass": 2.0, "radius": 11.5, "error": 0.7},
            ]
        }

    @pytest.fixture
    def mock_params(self):
        """Create mock parameter dict with M-R curves."""
        # Build a realistic M-R curve
        model = MetaModel_EOS_model()
        eos_params = {
            "E_sat": -16.0,
            "K_sat": 240.0,
            "Q_sat": 400.0,
            "Z_sat": 0.0,
            "E_sym": 31.7,
            "L_sym": 58.7,
            "K_sym": -100.0,
            "Q_sym": 0.0,
            "Z_sym": 0.0,
        }
        eos_data = model.construct_eos(eos_params)
        solver = GRTOVSolver()
        family = solver.construct_family(eos_data, ndat=100, min_nsat=0.75)
        # Use the _EOS-suffixed keys that JesterTransform provides
        return {
            **eos_params,
            "masses_EOS": family.masses,
            "radii_EOS": family.radii,
            "Lambdas_EOS": family.lambdas,
        }

    def test_likelihood_shape(self, mock_data, mock_params):
        """Test that the likelihood returns a scalar."""
        likelihood = MyNewLikelihood(mock_data)
        log_prob = likelihood.evaluate(mock_params, mock_data)
        assert isinstance(log_prob, (float, jnp.ndarray))
        if isinstance(log_prob, jnp.ndarray):
            assert log_prob.shape == ()

    def test_likelihood_is_finite(self, mock_data, mock_params):
        """Test that the likelihood returns a finite value."""
        likelihood = MyNewLikelihood(mock_data)
        log_prob = likelihood.evaluate(mock_params, mock_data)
        assert jnp.isfinite(log_prob)

    def test_likelihood_sensitivity(self, mock_data, mock_params):
        """Test that the likelihood responds to parameter changes."""
        likelihood = MyNewLikelihood(mock_data)
        log_prob1 = likelihood.evaluate(mock_params, mock_data)

        # Shift the radii by 2 km (worse fit)
        mock_params_bad = mock_params.copy()
        mock_params_bad["radii_EOS"] = mock_params["radii_EOS"] + 2.0
        log_prob2 = likelihood.evaluate(mock_params_bad, mock_data)

        # The worse fit should have lower (more negative) log probability
        assert log_prob2 < log_prob1

    def test_configuration_options(self):
        """Test that configuration options are applied."""
        from jesterTOV.inference.config.schema import MyNewLikelihoodConfig
        from jesterTOV.inference.likelihoods.factory import create_likelihood

        config = MyNewLikelihoodConfig(
            type="my_new_likelihood",
            systematic_error=0.5,
        )
        likelihood = create_likelihood(config)
        assert isinstance(likelihood, MyNewLikelihood)

    def test_integration_with_inference(self, tmp_path):
        """Test end-to-end inference with this likelihood."""
        config_yaml = """
seed: 42
eos:
  type: "metamodel"
  ndat_metamodel: 50
tov:
  type: "gr"
prior: "prior.prior"
likelihoods:
  - type: "my_new_likelihood"
    enabled: true
sampler:
  type: "smc-rw"
  n_particles: 100
  n_mcmc_steps: 5
outdir: "outdir"
"""
        config_path = tmp_path / "config.yaml"
        config_path.write_text(config_yaml)

        prior_file = tmp_path / "prior.prior"
        prior_file.write_text("""
E_sat = UniformPrior(-16.1, -15.9, parameter_names=["E_sat"])
K_sat = UniformPrior(150.0, 300.0, parameter_names=["K_sat"])
# ... other priors ...
""")

        # Run inference (lightweight settings for testing)
        from jesterTOV.inference.run_inference import main
        main(config_path)

        # Check that the output exists
        assert (tmp_path / "outdir").exists()
```
Test categories:

- Shape/type validation: returns a scalar log probability
- Numerical stability: finite values, no NaN/Inf
- Sensitivity: responds correctly to parameter changes
- Configuration: factory correctly applies options
- Integration: works in the full inference pipeline
## Step 6: Create Example Configuration

Create an example in `examples/inference/my_new_likelihood/`:
```yaml
# config.yaml
seed: 42

eos:
  type: metamodel
  ndat_metamodel: 100

tov:
  type: gr

prior:
  specification_file: prior.prior

likelihoods:
  - type: my_new_likelihood
    enabled: true
    observation_set: default
    systematic_error: 0.1

  # Can combine with other likelihoods
  - type: constraints_eos
    enabled: true

sampler:
  type: smc-rw
  n_particles: 2000
  n_mcmc_steps: 20
  target_ess: 0.9

outdir: outdir
```
Test that the example runs successfully:

```bash
cd examples/inference/my_new_likelihood
uv run run_jester_inference config.yaml
```
## Step 7: Documentation

Add documentation to `docs/overview/likelihoods/`:
````markdown
# My New Likelihood

One-sentence description of the observational constraint.

## Observational Data

Describe the observation or experiment:
- What was measured
- Instruments/telescopes used
- Key results and uncertainties

## Statistical Model

Explain the likelihood function:

$$
\mathcal{L}(R | M, \theta) = \prod_i \frac{1}{\sqrt{2\pi\sigma_i^2}} \exp\left[-\frac{(R_i - R_{\mathrm{model}}(M_i, \theta))^2}{2\sigma_i^2}\right]
$$

Where:
- $R_i$: Observed radius
- $M_i$: Measured mass
- $\theta$: EOS parameters
- $\sigma_i$: Measurement uncertainty

## Configuration

| Parameter | Description | Default |
|-----------|-------------|---------|
| observation_set | Dataset to use | "default" |
| systematic_error | Additional systematic (km) | 0.0 |

## References

- Observer et al. (Year). Title. *Journal* Volume, Page.

## Usage

```yaml
likelihoods:
  - type: "my_new_likelihood"
    enabled: true
    observation_set: "latest"
    systematic_error: 0.2
```
````
Update `docs/overview/likelihoods.rst` to include your new likelihood.
## Checklist

Before submitting a PR:

- [ ] Likelihood inherits from `LikelihoodBase`
- [ ] `evaluate()` returns a finite, scalar log probability
- [ ] Data loading function added if needed
- [ ] Configuration model added to `schema.py`
- [ ] Regenerated the YAML reference
- [ ] Registered in `factory.py`
- [ ] Comprehensive tests written and passing
- [ ] Integration test with the full inference pipeline
- [ ] Example configuration runs successfully
- [ ] Documentation with equations and references
- [ ] JAX-compatible (uses `jax.numpy`)
- [ ] Handles edge cases (e.g., extrapolation beyond the M-R range)
- [ ] Type hints for all functions
- [ ] Data provenance documented
- [ ] Code passes `uv run pyright jesterTOV/inference/likelihoods/`
- [ ] Tests pass: `uv run pytest tests/test_inference/test_likelihoods/test_my_new_likelihood.py -v`
## Common Pitfalls

**Not returning log probabilities.** Always return log P, never P. This avoids numerical underflow:
```python
# BAD: the probability underflows to zero
prob = 1e-300
return prob

# GOOD: the log probability stays in range
log_prob = -690.7755  # log(1e-300)
return log_prob
```
**Transform output not used.** The `params` dict contains outputs from `JesterTransform`. Know what's available:

```python
# Available from JesterTransform (note the _EOS suffix)
masses = params["masses_EOS"]    # Stellar masses (M☉)
radii = params["radii_EOS"]      # Stellar radii (km)
lambdas = params["Lambdas_EOS"]  # Tidal deformabilities
logpc = params["logpc_EOS"]      # Log10 central pressures

# EOS quantities (on the density grid)
n = params["n"]      # Number density
p = params["p"]      # Pressure
cs2 = params["cs2"]  # Sound speed squared

# Constraint violation counts (for physical validity checks)
n_tov_failures = params["n_tov_failures"]
n_causality_violations = params["n_causality_violations"]

# Original EOS parameters are also present
K_sat = params["K_sat"]
L_sym = params["L_sym"]
```
**Interpolation beyond the M-R range.** Handle cases where an observation lies outside the model range. Note that a Python `if` on traced values breaks under `jit`; use `jnp.where` instead:

```python
# Penalize observations outside the valid mass range (JAX-traceable)
out_of_range = (target_mass < jnp.min(masses)) | (target_mass > jnp.max(masses))
log_likelihood = jnp.where(out_of_range, -1e10, log_likelihood)

# Or: use fill values in jnp.interp so out-of-range evaluations propagate -inf
radius = jnp.interp(target_mass, masses, radii, left=-jnp.inf, right=-jnp.inf)
```
**Combining likelihoods incorrectly.** The `CombinedLikelihood` wrapper handles summation. Just return your log probability:

```python
# In your likelihood: just return the single log probability
return log_prob

# The framework handles the combination:
# total_log_prob = sum(likelihood.evaluate(params, data) for likelihood in likelihoods)
```
**Not handling disabled likelihoods.** The factory respects the `enabled` flag:

```python
# In factory.py: check the enabled flag
if not config.enabled:
    return None  # Or skip entirely
```
## Advanced: Resampling Likelihoods

Some likelihoods (like GW posteriors) require resampling during MCMC. See `GWLikelihoodResampled` for an example:
```python
from typing import Any

import jax
from jaxtyping import Array

from jesterTOV.inference.base.likelihood import LikelihoodBase


class ResamplingLikelihood(LikelihoodBase):
    """Likelihood that resamples from a posterior at each evaluation."""

    def __init__(self, posterior_samples: Array, n_resample: int = 1000):
        """Initialize with posterior samples from other analyses."""
        self.posterior_samples = posterior_samples
        self.n_resample = n_resample

    def evaluate(self, params: dict, data: Any, rng_key: Any = None) -> float:
        """Evaluate by resampling the posterior."""
        if rng_key is None:
            rng_key = jax.random.PRNGKey(0)

        # Resample from the posterior
        indices = jax.random.choice(
            rng_key, len(self.posterior_samples), shape=(self.n_resample,)
        )
        samples = self.posterior_samples[indices]

        # Compute the likelihood for each sample
        # ... (compare with model predictions)
        return log_likelihood
```
Resampling likelihoods need special handling in the sampler to pass a `PRNGKey`.
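A hedged sketch of that key handling (not the actual jesterTOV sampler code; `evaluate_with_fresh_key` is a hypothetical helper): the sampler splits its key before each evaluation so that successive resampling draws are independent and reproducible.

```python
import jax


def evaluate_with_fresh_key(evaluate_fn, params, data, key):
    """Split the key, pass the subkey to the likelihood, carry the key forward."""
    key, subkey = jax.random.split(key)
    log_prob = evaluate_fn(params, data, rng_key=subkey)
    return log_prob, key
```

Reusing the same key (as the `PRNGKey(0)` fallback above would) makes every evaluation draw identical resamples, which silently biases the estimate.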
## Need Help?

- See existing implementations: `jesterTOV/inference/likelihoods/nicer.py`, `gw.py`
- Review `jesterTOV/inference/CLAUDE.md` for the inference architecture
- Check data loading examples in `jesterTOV/inference/data/__init__.py`
- For statistical model questions, consult the relevant papers in `jesterTOV/references.bib`