# Source code for jesterTOV.inference.likelihoods.nicer

r"""
NICER X-ray timing likelihood implementations

This module provides two implementations:
1. NICERLikelihood - Flow-based (NEW DEFAULT, more efficient)
2. NICERKDELikelihood - KDE-based (legacy, for backward compatibility)
"""

import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import gaussian_kde
from jaxtyping import Array, Float
from jax.scipy.special import logsumexp

from jesterTOV.inference.base.likelihood import LikelihoodBase
from jesterTOV.logging_config import get_logger

# Logger obtained from jesterTOV.logging_config under the name "jester";
# the likelihood classes below use it for progress messages (logger.info).
logger = get_logger("jester")


class NICERLikelihood(LikelihoodBase):
    """
    NICER likelihood using normalizing flows (NEW DEFAULT).

    This is the recommended NICER likelihood implementation. It uses
    pre-trained normalizing flows on M-R posteriors for efficient and
    deterministic likelihood evaluation. For the legacy KDE-based version,
    see NICERKDELikelihood.

    The likelihood loads pre-trained flow models for the Amsterdam and
    Maryland groups and evaluates the likelihood by:

    1. Pre-sampling masses ONCE at initialization (deterministic with seed)
    2. During evaluation: interpolating radius from the EOS for the
       pre-sampled masses
    3. Evaluating the flow log probability at (mass, radius)
    4. Averaging over all samples of a group (log-mean-exp), then averaging
       the two groups with equal weights

    Parameters
    ----------
    psr_name : str
        Pulsar name (e.g., "J0030", "J0740")
    amsterdam_model_dir : str | None
        Path to directory containing the Amsterdam flow model
        (flow_weights.eqx, metadata.json, flow_kwargs.json). Despite the
        ``None`` default, this is currently required: preset model paths are
        not implemented yet, so passing ``None`` raises ``ValueError``.
    maryland_model_dir : str | None
        Path to directory containing the Maryland flow model. Same
        requirement as ``amsterdam_model_dir``.
    penalty_value : float, optional
        Value added to the log probability of mass samples that exceed Mtov
        (default: -99999.0)
    N_masses_evaluation : int, optional
        Number of mass samples drawn from EACH group's flow at
        initialization (default: 20). Must be >= 1.
    N_masses_batch_size : int, optional
        Batch size passed to ``jax.lax.map`` when processing mass samples
        (default: 10). Must be >= 1.
    seed : int, optional
        Random seed for pre-sampling masses (default: 42)

    Raises
    ------
    ValueError
        If either model directory is ``None``, or if ``N_masses_evaluation``
        or ``N_masses_batch_size`` is smaller than 1.

    Attributes
    ----------
    psr_name : str
        Pulsar name
    penalty_value : float
        Penalty value for samples where mass exceeds Mtov
    N_masses_evaluation : int
        Number of mass samples per group
    N_masses_batch_size : int
        Batch size for processing mass samples
    seed : int
        Random seed for deterministic pre-sampling
    amsterdam_flow : Flow
        Normalizing flow for Amsterdam M-R posterior
    maryland_flow : Flow
        Normalizing flow for Maryland M-R posterior
    amsterdam_fixed_mass_samples : Float[Array, "n_samples"]
        Pre-sampled mass values from Amsterdam flow (fixed at initialization)
    maryland_fixed_mass_samples : Float[Array, "n_samples"]
        Pre-sampled mass values from Maryland flow (fixed at initialization)
    """

    psr_name: str
    penalty_value: float
    N_masses_evaluation: int
    N_masses_batch_size: int
    seed: int

    def __init__(
        self,
        psr_name: str,
        amsterdam_model_dir: str | None = None,
        maryland_model_dir: str | None = None,
        penalty_value: float = -99999.0,
        N_masses_evaluation: int = 20,
        N_masses_batch_size: int = 10,
        seed: int = 42,
    ) -> None:
        super().__init__()
        self.psr_name = psr_name
        self.penalty_value = penalty_value
        self.N_masses_evaluation = N_masses_evaluation
        self.N_masses_batch_size = N_masses_batch_size
        self.seed = seed

        # Import Flow here to avoid circular imports
        from jesterTOV.inference.flows.flow import Flow

        # Validate that both model directories are provided
        if amsterdam_model_dir is None or maryland_model_dir is None:
            raise ValueError(
                f"Both amsterdam_model_dir and maryland_model_dir must be provided for {psr_name}. "
                "Preset model paths are not yet implemented. "
                "Please provide explicit paths to trained flow models "
                "(see TODO_FLOW_TRAINING.md Phase 3)."
            )

        # Fail early with a clear message: N < 1 would otherwise surface as an
        # empty-array reduction error, and a batch size < 1 as an error deep
        # inside jax.lax.map on the first evaluation.
        if N_masses_evaluation < 1:
            raise ValueError(
                f"N_masses_evaluation must be >= 1, got {N_masses_evaluation}"
            )
        if N_masses_batch_size < 1:
            raise ValueError(
                f"N_masses_batch_size must be >= 1, got {N_masses_batch_size}"
            )

        # Use provided model paths
        logger.info(f"Using Amsterdam model directory: {amsterdam_model_dir}")
        logger.info(f"Using Maryland model directory: {maryland_model_dir}")

        # Load flow models
        logger.info(f"Loading Amsterdam flow for {psr_name} from {amsterdam_model_dir}")
        self.amsterdam_flow = Flow.from_directory(amsterdam_model_dir)
        logger.info(f"Loading Maryland flow for {psr_name} from {maryland_model_dir}")
        self.maryland_flow = Flow.from_directory(maryland_model_dir)
        logger.info(f"Loaded normalizing flows for {psr_name}")

        # Pre-sample masses ONCE at initialization (deterministic with seed)
        logger.info(
            f"Pre-sampling {N_masses_evaluation} masses with seed={seed} for {psr_name}"
        )
        key = jax.random.key(seed)
        key_amsterdam, key_maryland = jax.random.split(key)

        # Sample (mass, radius) pairs from the flows
        amsterdam_samples = self.amsterdam_flow.sample(
            key_amsterdam, (N_masses_evaluation,)
        )
        maryland_samples = self.maryland_flow.sample(
            key_maryland, (N_masses_evaluation,)
        )

        # Keep only the masses (first column); radii come from the EOS at
        # evaluation time.
        self.amsterdam_fixed_mass_samples = amsterdam_samples[:, 0]  # Shape: (N,)
        self.maryland_fixed_mass_samples = maryland_samples[:, 0]  # Shape: (N,)

        logger.info(
            f"Pre-sampled Amsterdam mass range: "
            f"[{jnp.min(self.amsterdam_fixed_mass_samples):.3f}, "
            f"{jnp.max(self.amsterdam_fixed_mass_samples):.3f}] Msun"
        )
        logger.info(
            f"Pre-sampled Maryland mass range: "
            f"[{jnp.min(self.maryland_fixed_mass_samples):.3f}, "
            f"{jnp.max(self.maryland_fixed_mass_samples):.3f}] Msun"
        )

    def _get_preset_model_path(self, psr_name: str, group: str) -> str:
        """
        Get preset model path for a pulsar and analysis group.

        NOTE: not currently called by ``__init__``, which raises
        ``ValueError`` for missing model directories instead.

        Parameters
        ----------
        psr_name : str
            Pulsar name (e.g., "J0030", "J0740")
        group : str
            Analysis group ("amsterdam" or "maryland")

        Returns
        -------
        str
            Path to preset model directory

        Raises
        ------
        NotImplementedError
            Always, until preset NICER flow models are available.
        """
        # TODO: Define preset paths once NICER flow models are trained
        # For now, this is a placeholder that will be updated in Phase 3
        # Example preset structure (to be implemented):
        # base_dir = Path(__file__).parent.parent / "flows" / "models" / "nicer_maf"
        # model_dir = base_dir / psr_name / f"{psr_name}_{group}_NICER_model"
        raise NotImplementedError(
            f"Preset model paths for {psr_name} {group} not yet implemented. "
            "Please provide explicit model_dir paths or train NICER flows first "
            "(see TODO_FLOW_TRAINING.md Phase 3)."
        )

    def _group_log_likelihood(
        self,
        flow,
        mass_samples: Float[Array, " n_samples"],
        masses_EOS: Float[Array, " n_points"],
        radii_EOS: Float[Array, " n_points"],
        mtov: Float,
    ) -> Float:
        """
        Log-mean likelihood of one analysis group over its fixed mass samples.

        Parameters
        ----------
        flow : Flow
            The group's normalizing flow (must provide ``log_prob``).
        mass_samples : Float[Array, " n_samples"]
            Pre-sampled masses for this group.
        masses_EOS, radii_EOS : Float[Array, " n_points"]
            EOS mass-radius curve used to interpolate a radius per mass.
        mtov : Float
            Maximum EOS mass; samples above it receive ``penalty_value``.

        Returns
        -------
        Float
            ``log(mean(exp(logprob_i)))`` over the mass samples.
        """

        def process_sample(mass: Float) -> Float:
            # jnp.interp assumes masses_EOS is increasing and clamps masses
            # outside its range; presumably the caller passes the stable
            # branch only — TODO confirm.
            radius = jnp.interp(mass, masses_EOS, radii_EOS)
            # Shape: (1, 2) for (n_samples, n_features)
            mr_point = jnp.array([[mass, radius]])
            logpdf = flow.log_prob(mr_point)
            # Penalty for mass exceeding Mtov
            penalty = jnp.where(mass > mtov, self.penalty_value, 0.0)
            return logpdf + penalty

        # jax.lax.map with batching for memory-efficient processing
        logprobs = jax.lax.map(
            process_sample,
            mass_samples,
            batch_size=self.N_masses_batch_size,
        )
        # Log-mean = logsumexp - log(N); logsumexp reduces over every element,
        # so a scalar or shape-(1,) log_prob per sample both work.
        return logsumexp(logprobs) - jnp.log(logprobs.shape[0])

    def evaluate(self, params: dict[str, Float | Array]) -> Float:
        """
        Evaluate log likelihood for given EOS parameters.

        Uses the masses pre-sampled at initialization, so the result is
        deterministic for a given EOS.

        Parameters
        ----------
        params : dict[str, Float | Array]
            Must contain:
            - 'masses_EOS': Array of neutron star masses from EOS
            - 'radii_EOS': Array of neutron star radii from EOS

        Returns
        -------
        Float
            Log likelihood value for this NICER observation
        """
        masses_EOS: Float[Array, " n_points"] = params["masses_EOS"]
        radii_EOS: Float[Array, " n_points"] = params["radii_EOS"]
        mtov: Float = jnp.max(masses_EOS)

        logL_amsterdam = self._group_log_likelihood(
            self.amsterdam_flow,
            self.amsterdam_fixed_mass_samples,
            masses_EOS,
            radii_EOS,
            mtov,
        )
        logL_maryland = self._group_log_likelihood(
            self.maryland_flow,
            self.maryland_fixed_mass_samples,
            masses_EOS,
            radii_EOS,
            mtov,
        )

        # Average the two groups (equal weights, log-mean = logsumexp - log(2))
        log_likelihood = logsumexp(
            jnp.array([logL_amsterdam, logL_maryland])
        ) - jnp.log(2.0)
        return log_likelihood
class NICERKDELikelihood(LikelihoodBase):
    """
    NICER likelihood using KDE (Kernel Density Estimation) approach.

    This is the original NICER likelihood implementation. It uses a KDE on
    M-R posterior samples. For the flow-based version, see NICERLikelihood.

    TODO: Generalize to e.g. only one group, weights between different hotspot models,...

    This likelihood loads posterior samples from the Amsterdam and Maryland
    groups, constructs KDEs, and evaluates the likelihood by:

    1. Sampling masses (with replacement) from the NICER posterior samples,
       using the per-evaluation random key
    2. Interpolating radius from the EOS for those masses
    3. Evaluating the KDE log probability at (mass, radius)
    4. Averaging over all samples of a group (log-mean-exp), then averaging
       the two groups with equal weights

    Parameters
    ----------
    psr_name : str
        Pulsar name (e.g., "J0030", "J0740")
    amsterdam_samples_file : str
        Path to npz file with Amsterdam group posterior samples.
        Expected to contain 'mass' (Msun) and 'radius' (km) arrays.
    maryland_samples_file : str
        Path to npz file with Maryland group posterior samples.
        Expected to contain 'mass' (Msun) and 'radius' (km) arrays.
    penalty_value : float, optional
        Value added to the log probability of mass samples that exceed Mtov
        (default: -99999.0)
    N_masses_evaluation : int, optional
        Total number of mass samples per likelihood evaluation (default: 20).
        Each group gets ``N_masses_evaluation // 2`` samples, but never fewer
        than one. Must be >= 1.
    N_masses_batch_size : int, optional
        Batch size passed to ``jax.lax.map`` when processing mass samples
        (default: 10). Must be >= 1.

    Raises
    ------
    ValueError
        If ``N_masses_evaluation`` or ``N_masses_batch_size`` is smaller than 1.

    Attributes
    ----------
    psr_name : str
        Pulsar name
    penalty_value : float
        Penalty value for samples where mass exceeds Mtov
    N_masses_evaluation : int
        Total number of mass samples per likelihood evaluation
    N_masses_batch_size : int
        Batch size for processing mass samples
    amsterdam_masses : Float[Array, " n_amsterdam"]
        Mass samples from Amsterdam group
    maryland_masses : Float[Array, " n_maryland"]
        Mass samples from Maryland group
    amsterdam_posterior : gaussian_kde
        KDE of Amsterdam (mass, radius) posterior
    maryland_posterior : gaussian_kde
        KDE of Maryland (mass, radius) posterior
    """

    psr_name: str
    penalty_value: float
    N_masses_evaluation: int
    N_masses_batch_size: int
    amsterdam_masses: Float[Array, " n_amsterdam"]
    maryland_masses: Float[Array, " n_maryland"]
    amsterdam_posterior: gaussian_kde
    maryland_posterior: gaussian_kde

    def __init__(
        self,
        psr_name: str,
        amsterdam_samples_file: str,
        maryland_samples_file: str,
        penalty_value: float = -99999.0,
        N_masses_evaluation: int = 20,
        N_masses_batch_size: int = 10,
    ) -> None:
        super().__init__()
        self.psr_name = psr_name
        self.penalty_value = penalty_value
        self.N_masses_evaluation = N_masses_evaluation
        self.N_masses_batch_size = N_masses_batch_size

        # Fail early with a clear message instead of an error deep inside
        # jax.lax.map (batch size) or a meaningless sample count.
        if N_masses_evaluation < 1:
            raise ValueError(
                f"N_masses_evaluation must be >= 1, got {N_masses_evaluation}"
            )
        if N_masses_batch_size < 1:
            raise ValueError(
                f"N_masses_batch_size must be >= 1, got {N_masses_batch_size}"
            )

        # Load samples from npz files; file format: mass (Msun), radius (km)
        logger.info(
            f"Loading Amsterdam samples for {psr_name} from {amsterdam_samples_file}"
        )
        amsterdam_mass, amsterdam_radius = self._load_mass_radius(
            amsterdam_samples_file
        )
        logger.info(
            f"Loading Maryland samples for {psr_name} from {maryland_samples_file}"
        )
        maryland_mass, maryland_radius = self._load_mass_radius(
            maryland_samples_file
        )

        # Store mass samples as JAX arrays for random sampling
        self.amsterdam_masses = jnp.array(amsterdam_mass)
        self.maryland_masses = jnp.array(maryland_mass)

        # Stack into (2, n_samples) [mass, radius] arrays, the layout the
        # JAX gaussian_kde expects (rows = dimensions).
        amsterdam_mr = jnp.vstack([amsterdam_mass, amsterdam_radius])
        maryland_mr = jnp.vstack([maryland_mass, maryland_radius])

        # Construct KDEs using JAX implementation
        logger.info(f"Constructing JAX KDEs for {psr_name}")
        self.amsterdam_posterior = gaussian_kde(amsterdam_mr)
        self.maryland_posterior = gaussian_kde(maryland_mr)
        logger.info(f"Loaded JAX KDEs for {psr_name}")

    @staticmethod
    def _load_mass_radius(samples_file: str) -> tuple[np.ndarray, np.ndarray]:
        """
        Read the 'mass' and 'radius' arrays from a samples file.

        The file handle opened by ``np.load`` for ``.npz`` archives is closed
        before returning.

        Parameters
        ----------
        samples_file : str
            Path to the samples file (normally ``.npz``).

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            ``(mass, radius)`` arrays.
        """
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code from an untrusted file. It is kept
        # for backward compatibility; only load sample files you trust.
        data = np.load(samples_file, allow_pickle=True)
        try:
            mass = np.asarray(data["mass"])
            radius = np.asarray(data["radius"])
        finally:
            # NpzFile keeps the archive open until closed; a plain ndarray
            # (e.g. a structured .npy) has no close() and needs no cleanup.
            close = getattr(data, "close", None)
            if close is not None:
                close()
        return mass, radius

    def _group_log_likelihood(
        self,
        posterior: gaussian_kde,
        mass_samples: Float[Array, " n_samples"],
        masses_EOS: Float[Array, " n_points"],
        radii_EOS: Float[Array, " n_points"],
        mtov: Float,
    ) -> Float:
        """
        Log-mean likelihood of one analysis group over the given mass samples.

        Parameters
        ----------
        posterior : gaussian_kde
            The group's (mass, radius) KDE.
        mass_samples : Float[Array, " n_samples"]
            Masses drawn for this evaluation.
        masses_EOS, radii_EOS : Float[Array, " n_points"]
            EOS mass-radius curve used to interpolate a radius per mass.
        mtov : Float
            Maximum EOS mass; samples above it receive ``penalty_value``.

        Returns
        -------
        Float
            ``log(mean(exp(logprob_i)))`` over the mass samples.
        """

        def process_sample(mass: Float) -> Float:
            # jnp.interp assumes masses_EOS is increasing and clamps masses
            # outside its range; presumably the caller passes the stable
            # branch only — TODO confirm.
            radius = jnp.interp(mass, masses_EOS, radii_EOS)
            mr_point = jnp.array([[mass], [radius]])  # Shape: (2, 1)
            logpdf = posterior.logpdf(mr_point)
            # Penalty for mass exceeding Mtov
            penalty = jnp.where(mass > mtov, self.penalty_value, 0.0)
            return logpdf + penalty

        # jax.lax.map with batching for memory-efficient processing
        logprobs = jax.lax.map(
            process_sample,
            mass_samples,
            batch_size=self.N_masses_batch_size,
        )
        # Log-mean = logsumexp - log(N); logsumexp reduces over every element.
        return logsumexp(logprobs) - jnp.log(logprobs.shape[0])

    def evaluate(self, params: dict[str, Float | Array]) -> Float:
        """
        Evaluate log likelihood for given EOS parameters.

        Parameters
        ----------
        params : dict[str, Float | Array]
            Must contain:
            - '_random_key': Random seed for mass sampling (cast to int64)
            - 'masses_EOS': Array of neutron star masses from EOS
            - 'radii_EOS': Array of neutron star radii from EOS

        Returns
        -------
        Float
            Log likelihood value for this NICER observation
        """
        # Extract parameters
        sampled_key = params["_random_key"].astype("int64")
        key = jax.random.key(sampled_key)
        masses_EOS: Float[Array, " n_points"] = params["masses_EOS"]
        radii_EOS: Float[Array, " n_points"] = params["radii_EOS"]
        mtov: Float = jnp.max(masses_EOS)

        # Split key for Amsterdam and Maryland sampling
        key_amsterdam, key_maryland = jax.random.split(key)

        # Each group gets half of N_masses_evaluation samples, but at least one:
        # an empty sample would make the log-mean logsumexp([]) - log(0) = NaN.
        n_samples_per_group: int = max(self.N_masses_evaluation // 2, 1)

        # Sample indices (with replacement) and gather the mass samples
        amsterdam_indices = jax.random.choice(
            key_amsterdam,
            len(self.amsterdam_masses),
            shape=(n_samples_per_group,),
            replace=True,
        )
        maryland_indices = jax.random.choice(
            key_maryland,
            len(self.maryland_masses),
            shape=(n_samples_per_group,),
            replace=True,
        )
        amsterdam_mass_samples: Float[Array, " n_amsterdam_samples"] = (
            self.amsterdam_masses[amsterdam_indices]
        )
        maryland_mass_samples: Float[Array, " n_maryland_samples"] = (
            self.maryland_masses[maryland_indices]
        )

        logL_amsterdam = self._group_log_likelihood(
            self.amsterdam_posterior,
            amsterdam_mass_samples,
            masses_EOS,
            radii_EOS,
            mtov,
        )
        logL_maryland = self._group_log_likelihood(
            self.maryland_posterior,
            maryland_mass_samples,
            masses_EOS,
            radii_EOS,
            mtov,
        )

        # Average the two groups (equal weights, log-mean = logsumexp - log(2))
        log_likelihood = logsumexp(
            jnp.array([logL_amsterdam, logL_maryland])
        ) - jnp.log(2.0)
        return log_likelihood