Source code for jesterTOV.tov.base

r"""
Base class for TOV equation solvers.

This module defines the abstract interface that all TOV solvers must implement,
whether for General Relativity, modified gravity, or scalar-tensor theories.
"""

from abc import ABC, abstractmethod
import jax
import jax.numpy as jnp
from jaxtyping import Float, Array

from jesterTOV.tov.data_classes import EOSData, TOVSolution, FamilyData
from jesterTOV import utils
from jesterTOV.logging_config import get_logger

logger = get_logger("jester")


[docs] class TOVSolverBase(ABC): """ Abstract base class for TOV equation solvers. All TOV solvers must implement: - solve(): Solve TOV equations for a given central pressure - get_required_parameters(): Return list of additional parameters beyond EOS """
[docs] def fetch_params(self, params: dict[str, float]) -> dict[str, float]: """Extract solver-specific parameters from the combined EOS+TOV parameter dict. Uses :meth:`get_required_parameters` as the key list. For GR this returns an empty dict (no additional parameters needed). Args: params: Full parameter dictionary from the prior (EOS + TOV combined). Returns: dict[str, float]: Solver-specific subset of ``params``. """ required = self.get_required_parameters() missing = [k for k in required if k not in params] if missing: raise ValueError( f"{type(self).__name__}.fetch_params: missing required parameter(s): {missing}. " f"Available keys: {sorted(params.keys())}" ) return {k: params[k] for k in required}
[docs] @abstractmethod def solve( self, eos_data: EOSData, pc: float, tov_params: dict[str, float] ) -> TOVSolution: r""" Solve TOV equations for given central pressure. Args: eos_data: EOS quantities (type-safe dataclass) pc: Central pressure [geometric units] tov_params: Solver-specific parameters returned by :meth:`fetch_params`. Empty dict ``{}`` for GR; populated dict for modified-gravity solvers. Returns: TOVSolution: Mass, radius, and Love number k2 [geometric units] """ pass
[docs] def get_required_parameters(self) -> list[str]: """ Return additional parameters needed beyond EOS. Examples: - GR TOV: [] (no extra params) - Anisotropic TOV: ["gamma"] (anisotropy parameter) - Scalar-tensor: ["beta_ST", "phi_c", "nu_c"] Returns: list[str]: Parameter names """ return []
[docs] def construct_family( self, eos_data: EOSData, ndat: int, min_nsat: float, tov_params: dict[str, float], ) -> FamilyData: r""" Construct M-R-Λ curves by solving for multiple central pressures. The central pressure grid spans from a minimum based on min_nsat to the maximum pressure where the EOS remains causal (:math:`c_s^2 < 1`). Args: eos_data: EOS quantities in geometric units ndat: Number of points in central pressure grid min_nsat: Minimum central density in units of saturation density (assumed to be 0.16 :math:`\mathrm{fm}^{-3}`) tov_params: Solver-specific parameters returned by :meth:`fetch_params`. Empty dict ``{}`` for GR; populated dict for modified-gravity solvers. Returns: FamilyData: Mass-radius-tidal curves in physical units """ # Create central pressure grid pc_min = self._get_pc_min(eos_data, min_nsat) pc_max = self._get_pc_max(eos_data) pcs = jnp.logspace(jnp.log10(pc_min), jnp.log10(pc_max), num=ndat) # Solve TOV for each pc using vmap # NOTE: vmap batches TOVSolution fields into arrays def solve_single_pc(pc): return self.solve(eos_data, pc, tov_params) solutions = jax.vmap(solve_single_pc)(pcs) # Extract batched results (vmap converts scalar fields to arrays) masses: Float[Array, "ndat"] = solutions.M # type: ignore[assignment] radii: Float[Array, "ndat"] = solutions.R # type: ignore[assignment] k2s: Float[Array, "ndat"] = solutions.k2 # type: ignore[assignment] # Convert to physical units and compute tidal deformability return self._create_family_data(pcs, masses, radii, k2s, ndat)
def _get_pc_min(self, eos_data: EOSData, min_nsat: float) -> Float[Array, ""]: """ Calculate minimum central pressure from minimum density. Args: eos_data: EOS quantities min_nsat: Minimum density in units of saturation density Returns: Scalar Array: Minimum central pressure [geometric units] """ min_n_geometric = min_nsat * 0.16 * utils.fm_inv3_to_geometric pc_min = utils.interp_in_logspace(min_n_geometric, eos_data.ns, eos_data.ps) return pc_min def _get_pc_max(self, eos_data: EOSData) -> Float[Array, ""]: """ Calculate maximum causal central pressure. The maximum pressure is where the EOS becomes non-causal (cs2 >= 1). If the EOS is everywhere causal, use the maximum tabulated pressure. Args: eos_data: EOS quantities Returns: Scalar Array: Maximum central pressure [geometric units] """ # Find first non-causal point mask = eos_data.cs2 >= 1.0 any_noncausal = jnp.any(mask) indices = jnp.arange(len(eos_data.cs2)) masked_indices = jnp.where(mask, indices, len(eos_data.cs2)) first_noncausal_idx = jnp.min(masked_indices) # Use first non-causal point or last point if all causal idx = jnp.where(any_noncausal, first_noncausal_idx, len(eos_data.ps) - 1) pc_max = eos_data.ps[idx] return pc_max def _create_family_data( self, pcs: Float[Array, "ndat"], masses: Float[Array, "ndat"], radii: Float[Array, "ndat"], k2s: Float[Array, "ndat"], ndat: int, ) -> FamilyData: """ Shared post-processing: unit conversion, compactness limits, interpolation. Args: pcs: Central pressures [geometric units] masses: Masses [geometric units] radii: Radii [geometric units] k2s: Love numbers [dimensionless] ndat: Number of points for output grid Returns: FamilyData: Processed family curves in physical units """ # Calculate compactness compactness = masses / radii # Convert to physical units masses_solar = masses / utils.solar_mass_in_meter radii_km = radii / 1e3 # Calculate tidal deformability lambdas = 2.0 / 3.0 * k2s * jnp.power(compactness, -5.0) # Limit masses to be below MTOV (removes unstable branch) pcs_lim, masses_lim, radii_lim, lambdas_lim = utils.limit_by_MTOV( pcs, masses_solar, radii_km, lambdas ) # Get a mass grid and interpolate, since we might have some duplicate points mass_grid = jnp.linspace(jnp.min(masses_lim), jnp.max(masses_lim), ndat) radii_interp = jnp.interp(mass_grid, masses_lim, radii_lim) lambdas_interp = jnp.interp(mass_grid, masses_lim, lambdas_lim) pcs_interp = jnp.interp(mass_grid, masses_lim, pcs_lim) log10pcs = jnp.log10(pcs_interp) return FamilyData( log10pcs=log10pcs, masses=mass_grid, radii=radii_interp, lambdas=lambdas_interp, )