Source code for jesterTOV.inference.likelihoods.gw

r"""Gravitational wave event likelihood implementations"""

import jax
import jax.numpy as jnp
from jaxtyping import Array, Float
from jax.scipy.special import logsumexp

from jesterTOV.inference.base.likelihood import LikelihoodBase
from jesterTOV.inference.flows.flow import Flow
from jesterTOV.logging_config import get_logger

logger = get_logger("jester")


[docs] class GWLikelihoodResampled(LikelihoodBase): """ Gravitational wave likelihood for a single GW event using normalizing flow posteriors This likelihood evaluates the GW posterior by: 1. Sampling masses (m1, m2) from the trained normalizing flow 2. Interpolating tidal deformabilities (Λ1, Λ2) from the EOS 3. Evaluating the NF log probability on (m1, m2, Λ1, Λ2) Parameters ---------- event_name : str Name of the GW event (e.g., "GW170817") model_dir : str Path to directory containing the trained normalizing flow model penalty_value : float, optional Penalty value for samples where masses exceed Mtov (default: 0.0, i.e. no penalty) N_masses_evaluation : int, optional Number of mass samples per likelihood evaluation (default: 20) N_masses_batch_size : int, optional Batch size for processing mass samples (default: 10) Attributes ---------- event_name : str Name of the GW event model_dir : str Path to directory containing the trained normalizing flow model penalty_value : float Penalty value for samples where masses exceed Mtov N_masses_evaluation : int Number of mass samples per likelihood evaluation N_masses_batch_size : int Batch size for processing mass samples flow : Flow Normalizing flow model for this GW event """ event_name: str model_dir: str penalty_value: float N_masses_evaluation: int N_masses_batch_size: int flow: Flow
[docs] def __init__( self, event_name: str, model_dir: str, penalty_value: float = 0.0, N_masses_evaluation: int = 20, N_masses_batch_size: int = 10, ) -> None: super().__init__() self.event_name = event_name self.model_dir = model_dir self.penalty_value = penalty_value self.N_masses_evaluation = N_masses_evaluation self.N_masses_batch_size = N_masses_batch_size # Load Flow model for this event logger.info(f"Loading NF model for {event_name} from {model_dir}") self.flow = Flow.from_directory(model_dir) logger.info(f"Loaded NF model for {event_name}")
[docs] def evaluate(self, params: dict[str, Float | Array]) -> Float: """ Evaluate log likelihood for given EOS parameters Parameters ---------- params : dict[str, Float | Array] Must contain: - '_random_key': Random seed for mass sampling (cast to int64) - 'masses_EOS': Array of neutron star masses from EOS - 'Lambdas_EOS': Array of tidal deformabilities from EOS Returns ------- Float Log likelihood value for this GW event """ # Extract parameters sampled_key = params["_random_key"].astype("int64") key = jax.random.key(sampled_key) masses_EOS: Float[Array, " n_points"] = params["masses_EOS"] Lambdas_EOS: Float[Array, " n_points"] = params["Lambdas_EOS"] mtov: Float = jnp.max(masses_EOS) # Sample all N_masses_evaluation samples from NF in one go all_nf_samples: Float[Array, "n_samples 2"] = self.flow.sample( key, (self.N_masses_evaluation,) ) def process_sample(sample: Float[Array, " 2"]) -> Float: """ Process a single NF sample Note: jax.lax.map with batch_size still applies the function to individual elements, not batches. The batch_size parameter is for compilation optimization. Parameters ---------- sample : Float[Array, " 2"] Single sample with [m1, m2] Returns ------- Float Log probability including penalties for this sample """ m1 = sample[0] m2 = sample[1] # Interpolate lambdas lambda_1 = jnp.interp(m1, masses_EOS, Lambdas_EOS, right=1.0) lambda_2 = jnp.interp(m2, masses_EOS, Lambdas_EOS, right=1.0) # Evaluate log_prob on single sample ml_sample = jnp.array([m1, m2, lambda_1, lambda_2]) logpdf = self.flow.log_prob(ml_sample) # Penalties for masses exceeding Mtov penalty_m1 = jnp.where(m1 > mtov, self.penalty_value, 0.0) penalty_m2 = jnp.where(m2 > mtov, self.penalty_value, 0.0) # Return log prob + penalties for this sample return logpdf + penalty_m1 + penalty_m2 # Use jax.lax.map with batching for memory-efficient processing # batch_size helps with compilation memory, not runtime batching all_logprobs = jax.lax.map( process_sample, all_nf_samples, batch_size=self.N_masses_batch_size ) # Average over all samples for this event log_likelihood = jnp.mean(all_logprobs) return log_likelihood
[docs] class GWLikelihood(LikelihoodBase): """ Gravitational wave likelihood using pre-sampled masses for deterministic evaluation This likelihood improves upon GWLikelihoodResampled by pre-sampling mass pairs once at initialization, eliminating the need for the _random_key parameter and providing deterministic likelihood evaluations critical for sampler convergence. Key improvements over GWLikelihoodResampled: 1. Deterministic: Same EOS parameters → same likelihood value 2. No _random_key hack: Uses fixed seed at initialization 3. Scalable: Can use N=10,000+ samples efficiently on GPU 4. Fair comparison: All EOS evaluated at identical mass points 5. Better convergence: Smooth likelihood surface for MCMC/SMC The likelihood works by: 1. Pre-sampling (m1, m2) pairs from the trained flow at initialization 2. For each EOS evaluation: interpolate Λ1, Λ2 from the candidate EOS at the fixed mass points, evaluate flow log_prob on (m1, m2, Λ1_EOS, Λ2_EOS), apply penalties for masses exceeding Mtov, and average over all pre-sampled mass pairs Parameters ---------- event_name : str Name of the GW event (e.g., "GW170817") model_dir : str Path to directory containing the trained normalizing flow model penalty_value : float, optional Penalty value for samples where masses exceed Mtov (default: 0.0, i.e. no penalty) N_masses_evaluation : int, optional Number of mass samples to pre-sample (default: 2000) Large values recommended - GPU parallelization makes this cheap! N_masses_batch_size : int, optional Batch size for jax.lax.map processing (default: 1000) seed : int, optional Random seed for mass pre-sampling (default: 42) Fixed seed ensures reproducibility across runs Attributes ---------- event_name : str Name of the GW event model_dir : str Path to directory containing the trained normalizing flow model penalty_value : float Penalty value for samples where masses exceed Mtov N_masses_evaluation : int Number of pre-sampled mass pairs N_masses_batch_size : int Batch size for processing seed : int Random seed used for pre-sampling flow : Flow Normalizing flow model for this GW event fixed_mass_samples : Float[Array, "n_samples 2"] Pre-sampled (m1, m2) pairs from the flow, shape [N, 2] Notes ----- This class does NOT require _random_key in the parameter dictionary, unlike GWLikelihoodResampled. The seed is only used once at initialization. GPU parallelization via jax.lax.map means N=10,000 samples costs nearly the same as N=20, so use large N for near-integration accuracy. Examples -------- Configure in YAML:: likelihoods: - type: "gw" enabled: true parameters: events: - name: "GW170817" N_masses_evaluation: 2000 # Default value N_masses_batch_size: 1000 seed: 42 """ event_name: str model_dir: str penalty_value: float N_masses_evaluation: int N_masses_batch_size: int seed: int flow: Flow fixed_mass_samples: Float[Array, "n_samples 2"]
[docs] def __init__( self, event_name: str, model_dir: str, penalty_value: float = 0.0, N_masses_evaluation: int = 2000, N_masses_batch_size: int = 1000, seed: int = 42, ) -> None: super().__init__() self.event_name = event_name self.model_dir = model_dir self.penalty_value = penalty_value self.N_masses_evaluation = N_masses_evaluation self.N_masses_batch_size = N_masses_batch_size self.seed = seed # Load Flow model for this event logger.info(f"Loading NF model for {event_name} from {model_dir}") self.flow = Flow.from_directory(model_dir) logger.info(f"Loaded NF model for {event_name}") # Pre-sample masses ONCE at initialization logger.info( f"Pre-sampling {N_masses_evaluation} mass pairs with seed={seed} for {event_name}" ) key = jax.random.key(seed) samples = self.flow.sample(key, (N_masses_evaluation,)) # Extract only (m1, m2), discard Lambda values from flow self.fixed_mass_samples = samples[:, :2] # Shape: [N, 2] logger.info( f"Pre-sampled mass range: m1=[{jnp.min(self.fixed_mass_samples[:, 0]):.3f}, " f"{jnp.max(self.fixed_mass_samples[:, 0]):.3f}] Msun, " f"m2=[{jnp.min(self.fixed_mass_samples[:, 1]):.3f}, " f"{jnp.max(self.fixed_mass_samples[:, 1]):.3f}] Msun" )
[docs] def evaluate(self, params: dict[str, Float | Array]) -> Float: """ Evaluate log likelihood for given EOS parameters Parameters ---------- params : dict[str, Float | Array] Must contain: - 'masses_EOS': Array of neutron star masses from EOS - 'Lambdas_EOS': Array of tidal deformabilities from EOS Note: Does NOT require '_random_key' (unlike GWLikelihood) Returns ------- Float Log likelihood value for this GW event """ # Extract EOS parameters (no _random_key needed!) masses_EOS: Float[Array, " n_points"] = params["masses_EOS"] Lambdas_EOS: Float[Array, " n_points"] = params["Lambdas_EOS"] mtov: Float = jnp.max(masses_EOS) def process_sample(sample: Float[Array, " 2"]) -> Float: """ Process a single pre-sampled mass pair Note: jax.lax.map with batch_size applies function to individual elements. The batch_size parameter is for compilation optimization. Parameters ---------- sample : Float[Array, " 2"] Pre-sampled mass pair [m1, m2] Returns ------- Float Log probability including penalties for this sample """ m1 = sample[0] m2 = sample[1] # Interpolate lambdas from candidate EOS lambda_1 = jnp.interp(m1, masses_EOS, Lambdas_EOS, right=1.0) lambda_2 = jnp.interp(m2, masses_EOS, Lambdas_EOS, right=1.0) # Evaluate log_prob on single sample ml_sample = jnp.array([m1, m2, lambda_1, lambda_2]) logpdf = self.flow.log_prob(ml_sample) # Penalties for masses exceeding Mtov penalty_m1 = jnp.where(m1 > mtov, self.penalty_value, 0.0) penalty_m2 = jnp.where(m2 > mtov, self.penalty_value, 0.0) # Return log prob + penalties for this sample return logpdf + penalty_m1 + penalty_m2 # Use jax.lax.map with batching for memory-efficient processing # Process all pre-sampled mass pairs all_logprobs = jax.lax.map( process_sample, self.fixed_mass_samples, batch_size=self.N_masses_batch_size ) # Take logsumexp over all pre-sampled mass pairs log_likelihood = logsumexp(all_logprobs) - jnp.log(self.N_masses_evaluation) return log_likelihood