r"""
NICER X-ray timing likelihood implementations
This module provides two implementations:
1. NICERLikelihood - Flow-based (NEW DEFAULT, more efficient)
2. NICERKDELikelihood - KDE-based (legacy, for backward compatibility)
"""
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import gaussian_kde
from jaxtyping import Array, Float
from jax.scipy.special import logsumexp
from jesterTOV.inference.base.likelihood import LikelihoodBase
from jesterTOV.logging_config import get_logger
from jesterTOV.inference.flows.flow import Flow
# Shared project logger; the stray Sphinx "[docs]" marker that followed this
# line (an expression statement raising NameError at import) is removed.
logger = get_logger("jester")
class NICERLikelihood(LikelihoodBase):
    """
    NICER likelihood using normalizing flows (NEW DEFAULT).

    This is the recommended NICER likelihood implementation that uses
    pre-trained normalizing flows on M-R posteriors for efficient and
    deterministic likelihood evaluation.
    For the legacy KDE-based version, see NICERKDELikelihood.

    The likelihood loads pre-trained flow models for one or both of the Amsterdam
    and Maryland analysis groups, and evaluates the likelihood by:

    1. Pre-sampling masses ONCE at initialization (deterministic with seed)
    2. During evaluation: interpolating radius from the EOS for pre-sampled masses
    3. Evaluating the flow log probability at (mass, radius)
    4. Averaging over all samples, then averaging over available groups

    At least one of ``amsterdam_model_dir`` or ``maryland_model_dir`` must be provided.
    If only one group is provided, the likelihood uses only that group.

    Parameters
    ----------
    psr_name : str
        Pulsar name (e.g., "J0030", "J0740", "J0437", "J0614")
    amsterdam_model_dir : str | None
        Path to directory containing Amsterdam flow model
        (flow_weights.eqx, metadata.json, flow_kwargs.json).
        If None, Amsterdam group is omitted.
    maryland_model_dir : str | None
        Path to directory containing Maryland flow model.
        If None, Maryland group is omitted.
    penalty_value : float, optional
        Penalty added to the log-probability of samples whose mass exceeds
        Mtov (default: -99999.0)
    N_masses_evaluation : int, optional
        Number of mass samples per group and likelihood evaluation (default: 20).
        Must be at least 1.
    N_masses_batch_size : int, optional
        Batch size for processing mass samples with ``jax.lax.map`` (default: 10)
    seed : int, optional
        Random seed for pre-sampling masses (default: 42)

    Attributes
    ----------
    psr_name : str
        Pulsar name
    penalty_value : float
        Penalty value for samples where mass exceeds Mtov
    N_masses_evaluation : int
        Number of mass samples per group and likelihood evaluation
    N_masses_batch_size : int
        Batch size for processing mass samples
    seed : int
        Random seed for deterministic pre-sampling
    amsterdam_flow : Flow | None
        Normalizing flow for Amsterdam M-R posterior, or None if not provided
    maryland_flow : Flow | None
        Normalizing flow for Maryland M-R posterior, or None if not provided
    amsterdam_fixed_mass_samples : Float[Array, "n_samples"] | None
        Pre-sampled mass values from Amsterdam flow (fixed at initialization), or None
    maryland_fixed_mass_samples : Float[Array, "n_samples"] | None
        Pre-sampled mass values from Maryland flow (fixed at initialization), or None
    active_groups : list[tuple[Flow, Float[Array, "n_samples"]]]
        (flow, fixed mass samples) pairs for every group that was loaded

    Raises
    ------
    ValueError
        If neither ``amsterdam_model_dir`` nor ``maryland_model_dir`` is given,
        or if ``N_masses_evaluation`` is smaller than 1.
    """

    psr_name: str
    penalty_value: float
    N_masses_evaluation: int
    N_masses_batch_size: int
    seed: int
    amsterdam_flow: Flow | None
    maryland_flow: Flow | None
    amsterdam_fixed_mass_samples: Float[Array, " n_samples"] | None
    maryland_fixed_mass_samples: Float[Array, " n_samples"] | None
    active_groups: list[tuple[Flow, Float[Array, " n_samples"]]]

    def __init__(
        self,
        psr_name: str,
        amsterdam_model_dir: str | None = None,
        maryland_model_dir: str | None = None,
        penalty_value: float = -99999.0,
        N_masses_evaluation: int = 20,
        N_masses_batch_size: int = 10,
        seed: int = 42,
    ) -> None:
        super().__init__()
        self.psr_name = psr_name
        self.penalty_value = penalty_value
        self.N_masses_evaluation = N_masses_evaluation
        self.N_masses_batch_size = N_masses_batch_size
        self.seed = seed

        if amsterdam_model_dir is None and maryland_model_dir is None:
            raise ValueError(
                "At least one of amsterdam_model_dir or maryland_model_dir must be "
                f"provided for {psr_name}."
            )
        if N_masses_evaluation < 1:
            # With zero samples each group's average would be
            # logsumexp([]) - log(0) = -inf - (-inf) = NaN.
            raise ValueError(
                f"N_masses_evaluation must be at least 1, got {N_masses_evaluation}."
            )

        # Split the seed key up front so each group's pre-sampled masses depend
        # only on the seed, not on whether the other group is loaded.
        key_amsterdam, key_maryland = jax.random.split(jax.random.key(seed))

        if amsterdam_model_dir is not None:
            self.amsterdam_flow, self.amsterdam_fixed_mass_samples = (
                self._load_flow_and_presample(
                    amsterdam_model_dir, key_amsterdam, "Amsterdam"
                )
            )
        else:
            self.amsterdam_flow = None
            self.amsterdam_fixed_mass_samples = None

        if maryland_model_dir is not None:
            self.maryland_flow, self.maryland_fixed_mass_samples = (
                self._load_flow_and_presample(
                    maryland_model_dir, key_maryland, "Maryland"
                )
            )
        else:
            self.maryland_flow = None
            self.maryland_fixed_mass_samples = None

        # Only groups whose flow was loaded contribute to the likelihood.
        self.active_groups = [
            (flow, samples)
            for flow, samples in (
                (self.amsterdam_flow, self.amsterdam_fixed_mass_samples),
                (self.maryland_flow, self.maryland_fixed_mass_samples),
            )
            if flow is not None and samples is not None
        ]
        logger.info(
            f"Loaded {len(self.active_groups)} normalizing flow(s) for {psr_name}"
        )

    def _load_flow_and_presample(
        self,
        model_dir: str,
        key: Array,
        group_name: str,
    ) -> tuple[Flow, Float[Array, " n_samples"]]:
        """
        Load one analysis group's flow and draw its fixed mass samples.

        Parameters
        ----------
        model_dir : str
            Directory passed to ``Flow.from_directory``.
        key : Array
            JAX PRNG key used to sample from the flow.
        group_name : str
            Human-readable group name, used only in log messages.

        Returns
        -------
        tuple[Flow, Float[Array, "n_samples"]]
            The loaded flow and ``N_masses_evaluation`` mass values taken from
            column 0 of the flow samples.
        """
        logger.info(f"Loading {group_name} flow for {self.psr_name} from {model_dir}")
        flow = Flow.from_directory(model_dir)
        # Column 0 is the mass, matching the (mass, radius) ordering of the
        # points passed to ``flow.log_prob`` in ``evaluate``.
        mass_samples: Float[Array, " n_samples"] = flow.sample(
            key, (self.N_masses_evaluation,)
        )[:, 0]
        logger.info(
            f"{group_name} flow loaded for {self.psr_name}. Pre-sampled mass range: "
            f"[{jnp.min(mass_samples):.3f}, {jnp.max(mass_samples):.3f}] Msun"
        )
        return flow, mass_samples

    def _get_preset_model_path(self, psr_name: str, group: str) -> str:
        """
        Get preset model path for a pulsar and analysis group.

        Not implemented yet: this placeholder always raises.

        Parameters
        ----------
        psr_name : str
            Pulsar name (e.g., "J0030", "J0740")
        group : str
            Analysis group ("amsterdam" or "maryland")

        Returns
        -------
        str
            Path to preset model directory (once implemented)

        Raises
        ------
        NotImplementedError
            Always, until preset NICER flow models are available.
        """
        # TODO: Define preset paths once NICER flow models are trained
        # For now, this is a placeholder that will be updated in Phase 3
        # Example preset structure (to be implemented):
        # base_dir = Path(__file__).parent.parent / "flows" / "models" / "nicer_maf"
        # model_dir = base_dir / psr_name / f"{psr_name}_{group}_NICER_model"
        raise NotImplementedError(
            f"Preset model paths for {psr_name} {group} not yet implemented. "
            "Please provide explicit model_dir paths or train NICER flows first "
            "(see TODO_FLOW_TRAINING.md Phase 3)."
        )

    def evaluate(self, params: dict[str, Float | Array]) -> Float:
        """
        Evaluate log likelihood for given EOS parameters.

        Uses pre-sampled masses from initialization (deterministic evaluation).

        Parameters
        ----------
        params : dict[str, Float | Array]
            Must contain:
            - 'masses_EOS': Array of neutron star masses from EOS
            - 'radii_EOS': Array of neutron star radii from EOS

        Returns
        -------
        Float
            Log likelihood value for this NICER observation
        """
        masses_EOS: Float[Array, " n_points"] = params["masses_EOS"]
        radii_EOS: Float[Array, " n_points"] = params["radii_EOS"]
        mtov: Float = jnp.max(masses_EOS)

        def compute_group_logL(
            flow: Flow, mass_samples: Float[Array, " n_samples"]
        ) -> Float:
            """Log of the mean flow likelihood over one group's mass samples."""

            def process_sample(mass: Float) -> Float:
                # jnp.interp requires masses_EOS to be increasing (assumed here);
                # masses beyond the last grid point get radius 0.0 and are
                # additionally penalized below.
                radius = jnp.interp(mass, masses_EOS, radii_EOS, right=0.0)
                mr_point = jnp.array([[mass, radius]])  # Shape: (1, 2)
                # A single point yields a single log-prob; flatten it to a scalar
                # so the mapped output has shape (n_samples,).
                logpdf = jnp.reshape(flow.log_prob(mr_point), ())
                return logpdf + jnp.where(mass > mtov, self.penalty_value, 0.0)

            logprobs = jax.lax.map(
                process_sample, mass_samples, batch_size=self.N_masses_batch_size
            )
            # log(mean(exp(logprobs))), computed stably.
            return logsumexp(logprobs) - jnp.log(logprobs.shape[0])

        group_logLs = jnp.stack(
            [compute_group_logL(flow, samples) for flow, samples in self.active_groups]
        )
        # Equal-weight average of the group likelihoods in linear space.
        return logsumexp(group_logLs) - jnp.log(float(group_logLs.shape[0]))
class NICERKDELikelihood(LikelihoodBase):
    """
    NICER likelihood using KDE (Kernel Density Estimation) approach.

    This is the original NICER likelihood implementation that uses KDE
    on M-R posterior samples. For the flow-based version, see NICERLikelihood.

    TODO: Generalize to e.g. only one group, weights between different hotspot models,...

    This likelihood loads posterior samples from Amsterdam and Maryland groups,
    constructs KDEs, and evaluates the likelihood by:

    1. Sampling masses from the NICER posterior samples
    2. Interpolating radius from the EOS for those masses
    3. Evaluating the KDE log probability at (mass, radius)
    4. Averaging over all samples, then averaging over the two groups

    Parameters
    ----------
    psr_name : str
        Pulsar name (e.g., "J0030", "J0740")
    amsterdam_samples_file : str
        Path to npz file with Amsterdam group posterior samples
        Expected to contain 'mass' (Msun) and 'radius' (km) arrays
    maryland_samples_file : str
        Path to npz file with Maryland group posterior samples
        Expected to contain 'mass' (Msun) and 'radius' (km) arrays
    penalty_value : float, optional
        Penalty added to the log-probability of samples whose mass exceeds
        Mtov (default: -99999.0)
    N_masses_evaluation : int, optional
        Total number of mass samples per likelihood evaluation (default: 20).
        Each group draws ``max(1, N_masses_evaluation // 2)`` samples.
        Must be at least 1.
    N_masses_batch_size : int, optional
        Batch size for processing mass samples with ``jax.lax.map`` (default: 10)

    Attributes
    ----------
    psr_name : str
        Pulsar name
    penalty_value : float
        Penalty value for samples where mass exceeds Mtov
    N_masses_evaluation : int
        Total number of mass samples per likelihood evaluation
    N_masses_batch_size : int
        Batch size for processing mass samples
    amsterdam_masses : Float[Array, " n_amsterdam"]
        Mass samples from Amsterdam group
    maryland_masses : Float[Array, " n_maryland"]
        Mass samples from Maryland group
    amsterdam_posterior : gaussian_kde
        KDE of Amsterdam (mass, radius) posterior
    maryland_posterior : gaussian_kde
        KDE of Maryland (mass, radius) posterior

    Raises
    ------
    ValueError
        If ``N_masses_evaluation`` is smaller than 1, or if a samples file
        contains 'mass' and 'radius' arrays of different shapes.
    """

    psr_name: str
    penalty_value: float
    N_masses_evaluation: int
    N_masses_batch_size: int
    amsterdam_masses: Float[Array, " n_amsterdam"]
    maryland_masses: Float[Array, " n_maryland"]
    amsterdam_posterior: gaussian_kde
    maryland_posterior: gaussian_kde

    def __init__(
        self,
        psr_name: str,
        amsterdam_samples_file: str,
        maryland_samples_file: str,
        penalty_value: float = -99999.0,
        N_masses_evaluation: int = 20,
        N_masses_batch_size: int = 10,
    ) -> None:
        super().__init__()
        if N_masses_evaluation < 1:
            raise ValueError(
                f"N_masses_evaluation must be at least 1, got {N_masses_evaluation}."
            )
        self.psr_name = psr_name
        self.penalty_value = penalty_value
        self.N_masses_evaluation = N_masses_evaluation
        self.N_masses_batch_size = N_masses_batch_size

        # Load samples from npz files
        # File format: mass (Msun), radius (km)
        logger.info(
            f"Loading Amsterdam samples for {psr_name} from {amsterdam_samples_file}"
        )
        amsterdam_mass, amsterdam_radius = self._load_mass_radius_samples(
            amsterdam_samples_file
        )
        logger.info(
            f"Loading Maryland samples for {psr_name} from {maryland_samples_file}"
        )
        maryland_mass, maryland_radius = self._load_mass_radius_samples(
            maryland_samples_file
        )

        # Store mass samples as JAX arrays for random sampling in `evaluate`.
        self.amsterdam_masses = jnp.array(amsterdam_mass)
        self.maryland_masses = jnp.array(maryland_mass)

        # gaussian_kde expects data of shape (n_dims, n_samples): rows are
        # [mass, radius].
        logger.info(f"Constructing JAX KDEs for {psr_name}")
        self.amsterdam_posterior = gaussian_kde(
            jnp.vstack([amsterdam_mass, amsterdam_radius])
        )
        self.maryland_posterior = gaussian_kde(
            jnp.vstack([maryland_mass, maryland_radius])
        )
        logger.info(f"Loaded JAX KDEs for {psr_name}")

    @staticmethod
    def _load_mass_radius_samples(samples_file: str) -> tuple[np.ndarray, np.ndarray]:
        """
        Read the 'mass' and 'radius' arrays from a samples file.

        Parameters
        ----------
        samples_file : str
            Path to an npz file containing 'mass' and 'radius' entries.

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            The mass and radius arrays, fully loaded into memory.

        Raises
        ------
        ValueError
            If the two arrays do not have the same shape.
        """
        # NOTE(review): allow_pickle=True lets np.load unpickle objects stored in
        # the file, which can execute arbitrary code -- only load trusted files.
        # Kept for compatibility with existing sample files.
        data = np.load(samples_file, allow_pickle=True)
        try:
            mass = np.asarray(data["mass"])
            radius = np.asarray(data["radius"])
        finally:
            # For .npz archives np.load returns an NpzFile holding an open file
            # handle; release it once the arrays have been read.
            close = getattr(data, "close", None)
            if close is not None:
                close()
        if mass.shape != radius.shape:
            raise ValueError(
                f"'mass' and 'radius' in {samples_file} have different shapes: "
                f"{mass.shape} vs {radius.shape}."
            )
        return mass, radius

    def evaluate(self, params: dict[str, Float | Array]) -> Float:
        """
        Evaluate log likelihood for given EOS parameters

        Parameters
        ----------
        params : dict[str, Float | Array]
            Must contain:
            - '_random_key': Random seed for mass sampling (cast to int64);
              may be a JAX/NumPy array or a plain Python number
            - 'masses_EOS': Array of neutron star masses from EOS
            - 'radii_EOS': Array of neutron star radii from EOS

        Returns
        -------
        Float
            Log likelihood value for this NICER observation
        """
        # jnp.asarray also accepts plain Python scalars, which lack `.astype`.
        sampled_key = jnp.asarray(params["_random_key"]).astype("int64")
        key = jax.random.key(sampled_key)
        masses_EOS: Float[Array, " n_points"] = params["masses_EOS"]
        radii_EOS: Float[Array, " n_points"] = params["radii_EOS"]
        mtov: Float = jnp.max(masses_EOS)

        # Split key for Amsterdam and Maryland sampling
        key_amsterdam, key_maryland = jax.random.split(key)

        # Each group gets half of N_masses_evaluation samples, but at least one:
        # an empty sample set would make the group average NaN.
        n_samples_per_group: int = max(1, self.N_masses_evaluation // 2)

        # Resample (with replacement) from the stored posterior mass samples.
        amsterdam_indices = jax.random.choice(
            key_amsterdam,
            len(self.amsterdam_masses),
            shape=(n_samples_per_group,),
            replace=True,
        )
        maryland_indices = jax.random.choice(
            key_maryland,
            len(self.maryland_masses),
            shape=(n_samples_per_group,),
            replace=True,
        )
        amsterdam_mass_samples: Float[Array, " n_amsterdam_samples"] = (
            self.amsterdam_masses[amsterdam_indices]
        )
        maryland_mass_samples: Float[Array, " n_maryland_samples"] = (
            self.maryland_masses[maryland_indices]
        )

        def compute_group_logL(
            posterior_kde: gaussian_kde, mass_samples: Float[Array, " n_samples"]
        ) -> Float:
            """Log of the mean KDE likelihood over one group's mass samples."""

            def process_sample(mass: Float) -> Float:
                # jnp.interp requires masses_EOS to be increasing (assumed here);
                # masses beyond the last grid point get radius 0.0 and are
                # additionally penalized below.
                radius = jnp.interp(mass, masses_EOS, radii_EOS, right=0.0)
                mr_point = jnp.array([[mass], [radius]])  # Shape: (2, 1)
                # One evaluation point yields one log-pdf; flatten it to a scalar
                # so the mapped output has shape (n_samples,).
                logpdf = jnp.reshape(posterior_kde.logpdf(mr_point), ())
                return logpdf + jnp.where(mass > mtov, self.penalty_value, 0.0)

            logprobs = jax.lax.map(
                process_sample, mass_samples, batch_size=self.N_masses_batch_size
            )
            # log(mean(exp(logprobs))), computed stably.
            return logsumexp(logprobs) - jnp.log(logprobs.shape[0])

        logL_amsterdam = compute_group_logL(
            self.amsterdam_posterior, amsterdam_mass_samples
        )
        logL_maryland = compute_group_logL(
            self.maryland_posterior, maryland_mass_samples
        )
        # Equal-weight average of the two group likelihoods in linear space.
        return logsumexp(jnp.array([logL_amsterdam, logL_maryland])) - jnp.log(2.0)