r"""
Utility functions for neutron star physics calculations.
This module provides essential utility functions for equation of state
interpolation, unit conversions, numerical integration, and auxiliary
calculations needed for TOV equation solving.
**Units:** The module defines conversion factors between different unit
systems commonly used in neutron star physics.
"""
from functools import partial

import jax
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Tsit5, SaveAt, PIDController
from interpax._spline import interp1d as interpax_interp1d
from jax import vmap
from jaxtyping import Array, Float
###############################################
### PHYSICAL CONSTANTS AND UNIT CONVERSIONS ###
###############################################

# Fundamental constants (SI units)
eV = 1.602176634e-19  # Elementary charge [C]; numerically also 1 eV in joules
c = 299792458.0  # Speed of light [m/s]
G = 6.6743e-11  # Gravitational constant [m³/kg/s²]
Msun = 1.988409870698051e30  # Solar mass [kg]
hbarc = 197.3269804593025  # Reduced Planck constant × c [MeV⋅fm]

# Particle masses [MeV]
m_p = 938.2720881604904  # Proton mass
m_n = 939.5654205203889  # Neutron mass
m_e = 0.510998  # Electron mass
m = (m_p + m_n) / 2.0  # Average nucleon mass (Margueron et al.)

# Derived constants
# NOTE: ``hbar`` is an alias of ``hbarc``. It holds ħc in MeV⋅fm, not ħ
# itself, and is kept under this name only for backward compatibility.
hbar = hbarc
solar_mass_in_meter = Msun * G / c / c  # G⋅Msun/c²: solar mass in geometric units [m]

# Basic unit conversions
fm_to_m = 1e-15  # Femtometer to meter
MeV_to_J = 1e6 * eV  # MeV to Joule
m_to_fm = 1.0 / fm_to_m  # Meter to femtometer
J_to_MeV = 1.0 / MeV_to_J  # Joule to MeV

# Number density conversions
fm_inv3_to_SI = 1.0 / fm_to_m**3  # fm⁻³ to m⁻³
number_density_to_geometric = 1  # Number density scaling factor
fm_inv3_to_geometric = fm_inv3_to_SI * number_density_to_geometric
SI_to_fm_inv3 = 1.0 / fm_inv3_to_SI
geometric_to_fm_inv3 = 1.0 / fm_inv3_to_geometric

# Pressure and energy density conversions
MeV_fm_inv3_to_SI = MeV_to_J * fm_inv3_to_SI  # MeV/fm³ to Pa
SI_to_MeV_fm_inv3 = 1.0 / MeV_fm_inv3_to_SI  # Pa to MeV/fm³ (single definition)
pressure_SI_to_geometric = G / c**4  # SI pressure [Pa] to geometric units [m⁻²]
MeV_fm_inv3_to_geometric = MeV_fm_inv3_to_SI * pressure_SI_to_geometric

# Additional useful conversions
# 1 dyn/cm² = 0.1 Pa = 0.1 J/m³; 1 m³ = m_to_fm**3 fm³
dyn_cm2_to_MeV_fm_inv3 = 1e-1 * J_to_MeV / m_to_fm**3  # dyn/cm² to MeV/fm³
# 1 g/cm³ = 1e3 kg/m³, converted to an energy density via E = m c²
g_cm_inv3_to_MeV_fm_inv3 = 1e3 * c**2 * J_to_MeV / m_to_fm**3  # g/cm³ to MeV/fm³

# Inverse conversions
# (The Pa -> MeV/fm³ inverse, SI_to_MeV_fm_inv3, is defined once above.)
geometric_to_SI = 1.0 / pressure_SI_to_geometric
geometric_to_MeV_fm_inv3 = 1.0 / MeV_fm_inv3_to_geometric
#########################
### UTILITY FUNCTIONS ###
#########################
# Batched polynomial root finder: applies ``jnp.roots`` row by row over the
# leading axis of a 2-D coefficient array (vmap's default in/out axis is 0).
# ``strip_zeros=False`` keeps every row's output the same length (number of
# coefficients minus one) regardless of leading zeros, since batching needs
# static shapes.
roots_vmap = vmap(partial(jnp.roots, strip_zeros=False))
[docs]
@vmap
def cubic_root_for_proton_fraction(coefficients):
    r"""
    Solve cubic equation for proton fraction in beta-equilibrium.

    Returns the three roots of

    .. math::

        ax^3 + bx^2 + cx + d = 0

    for one coefficient set ``[a, b, c, d]``. The cubic is first reduced to
    the depressed form :math:`t^3 + f t + g = 0` via :math:`x = t - b/(3a)`,
    and the sign of :math:`h = g^2/4 + f^3/27` selects the closed form:

    * :math:`h \geq 0`: Cardano's formula (one real root plus a complex
      conjugate pair, or repeated real roots). This is exactly the original
      computation.
    * :math:`h < 0`: the trigonometric (Viète) form, giving three distinct
      real roots. Cardano's formula would need ``sqrt`` of a negative real
      number here, which evaluates to NaN for real floating-point inputs.

    Args:
        coefficients (Array): Cubic polynomial coefficients [a, b, c, d];
            ``a`` must be non-zero (it appears in denominators).

    Returns:
        Array: Three roots as a complex array. For :math:`h \geq 0` the first
        root has zero imaginary part and the other two are the conjugate
        pair. For :math:`h < 0` all three have zero imaginary part.

    Note:
        The ``@vmap`` decorator maps this function over the leading axis, so
        callers pass an array of shape ``(N, 4)`` (one coefficient set per
        density) and receive an array of shape ``(N, 3)``.
    """
    a, b, c, d = coefficients
    shift = b / (3.0 * a)

    # Depressed-cubic coefficients: t^3 + f t + g = 0 with x = t - shift
    f = ((3.0 * c / a) - ((b**2) / (a**2))) / 3.0
    g = (((2.0 * (b**3)) / (a**3)) - ((9.0 * b * c) / (a**2)) + (27.0 * d / a)) / 27.0
    h = g**2 / 4.0 + f**3 / 27.0
    three_real = h < 0.0

    # --- Cardano branch (h >= 0) ---
    # When h < 0, substitute a harmless positive value so that sqrt and its
    # gradient stay finite in the branch that jnp.where discards below.
    sqrt_h = jnp.sqrt(jnp.where(three_real, 1.0, h))
    S = jnp.cbrt(-(g / 2.0) + sqrt_h)
    U = jnp.cbrt(-(g / 2.0) - sqrt_h)
    x1_cardano = (S + U) - shift
    x2_cardano = -(S + U) / 2 - shift + (S - U) * jnp.sqrt(3.0) * 0.5j
    x3_cardano = -(S + U) / 2 - shift - (S - U) * jnp.sqrt(3.0) * 0.5j

    # --- Trigonometric branch (h < 0, which forces f < 0) ---
    # Substitute f = -1 outside this branch so that no division by zero or
    # sqrt of a negative number occurs in the discarded values.
    f_neg = jnp.where(three_real, f, -1.0)
    amplitude = 2.0 * jnp.sqrt(-f_neg / 3.0)
    # Clip guards against round-off pushing the argument just outside [-1, 1]
    cos_arg = jnp.clip((3.0 * g / (2.0 * f_neg)) * jnp.sqrt(-3.0 / f_neg), -1.0, 1.0)
    phi = jnp.arccos(cos_arg) / 3.0
    two_pi_over_3 = 2.0 * jnp.pi / 3.0
    x1_trig = amplitude * jnp.cos(phi) - shift
    x2_trig = amplitude * jnp.cos(phi - two_pi_over_3) - shift
    x3_trig = amplitude * jnp.cos(phi - 2.0 * two_pi_over_3) - shift

    x1 = jnp.where(three_real, x1_trig, x1_cardano)
    x2 = jnp.where(three_real, x2_trig, x2_cardano)
    x3 = jnp.where(three_real, x3_trig, x3_cardano)
    return jnp.array([x1, x2, x3])
[docs]
def cumtrapz(y, x):
    r"""
    Running integral of ``y`` with respect to ``x`` (composite trapezoidal rule).

    Entry ``i`` of the result approximates

    .. math::
        \int_{x_0}^{x_i} y(x)\,dx \approx \sum_{j=1}^{i} \frac{x_j - x_{j-1}}{2}\,(y_{j-1} + y_j)

    This is used for quantities such as enthalpy and chemical potential that
    are built up from tabulated EOS data.

    Args:
        y (Array): Integrand samples (1-D).
        x (Array): Sample coordinates, same shape as ``y``.

    Returns:
        Array: Cumulative integral with the same length as the inputs. The
        first entry is the placeholder 1e-30 rather than 0, so that taking a
        logarithm of the result downstream does not hit log(0).

    Note:
        Shape checks use ``assert`` and raise ``AssertionError``; they are
        skipped when Python runs with ``-O``.
    """
    assert y.shape == x.shape, "Input arrays must have matching shapes"
    for arr in (y, x):
        assert arr.ndim == 1, "Input arrays must be one-dimensional"

    spacing = jnp.diff(x)
    panel_areas = spacing * (y[1:] + y[:-1]) / 2.0
    running_total = jnp.cumsum(panel_areas)

    # Tiny positive seed in place of the exact 0 at x_0 (see Returns)
    seed = jnp.array([1e-30])
    return jnp.concatenate((seed, running_total))
[docs]
def interp_in_logspace(x, xs, ys):
    r"""
    Piecewise power-law (log-log linear) interpolation.

    Linear interpolation is carried out on the logarithms of the abscissae
    and ordinates, and the result is exponentiated:

    .. math::
        y(x) = \exp\left[\text{interp}(\log x,\ \log x_s,\ \log y_s)\right]

    This suits quantities spanning many orders of magnitude, such as pressure
    and density inside neutron stars.

    Args:
        x (float): Evaluation point(s).
        xs (Array): Tabulated x values; strictly positive.
        ys (Array): Tabulated y values; strictly positive.

    Returns:
        float: Interpolated value(s) at ``x``.

    Note:
        Non-positive inputs produce non-finite logarithms. Outside the
        tabulated range, ``jnp.interp`` clamps to the end values, so the
        result equals the first or last ``ys`` entry.
    """
    log_y = jnp.interp(jnp.log(x), jnp.log(xs), jnp.log(ys))
    return jnp.exp(log_y)
[docs]
def limit_by_MTOV(
    pc: Array, m: Array, r: Array, l: Array
) -> tuple[Array, Array, Array, Array]:
    r"""
    Clip a neutron-star sequence at its maximum (TOV) mass.

    The maximum mass satisfies

    .. math::
        \frac{dM}{dp_c} = 0

    Configurations past it are treated as unstable. Array lengths are left
    unchanged so that shapes stay static under JAX JIT. Discarded entries are
    overwritten with copies of the values at the maximum-mass point rather
    than being removed.

    Selection rule: an entry at index ``i`` survives only if ``i`` does not
    exceed the index of the maximum mass and either ``m[i + 1] > m[i]``
    (forward difference) or ``i`` is the maximum itself. Any entry before the
    maximum whose next mass is not larger is therefore also replaced.

    Args:
        pc (Array): Central pressure array.
        m (Array): Gravitational mass array.
        r (Array): Radius array.
        l (Array): Tidal deformability array.

    Returns:
        tuple: ``(pc, m, r, l)`` with discarded entries replaced by the
        maximum-mass values, all reordered by increasing mass.
    """
    peak = jnp.argmax(m)
    peak_mass = jnp.max(m)

    # Flag for each entry: is the next mass larger? The peak itself is
    # forced True, and everything after the peak is dropped.
    rising = jnp.insert(jnp.diff(m) > 0, peak, True)
    keep = rising & (jnp.arange(len(m)) <= peak)

    clipped_pc = jnp.where(keep, pc, pc[peak])
    clipped_m = jnp.where(keep, m, peak_mass)
    clipped_r = jnp.where(keep, r, r[peak])
    clipped_l = jnp.where(keep, l, l[peak])

    # Present the curve in order of increasing mass
    order = jnp.argsort(clipped_m)
    return clipped_pc[order], clipped_m[order], clipped_r[order], clipped_l[order]
###################
### SPLINES etc ###
###################
[docs]
def cubic_spline(xq: Float[Array, "m"], xp: Float[Array, "n"], fp: Float[Array, "n"]):
    r"""
    Cubic spline interpolation using interpax.

    Builds a cubic spline through the data points ``(xp, fp)`` and evaluates
    it at the query points ``xq``.

    Args:
        xq (Float[Array, "m"]): Query points for evaluation. Their number
            ``m`` is independent of the number of data points ``n``.
        xp (Float[Array, "n"]): Known x-coordinates of the data points
            (presumably required to be increasing, per interpax conventions;
            confirm against the interpax docs).
        fp (Float[Array, "n"]): Known y-coordinates, i.e., fp = f(xp).

    Returns:
        Array: Interpolated values at the query points xq.

    Note:
        Uses the interpax library for JAX-compatible spline interpolation.
        No extrapolation option is passed, so out-of-range behaviour follows
        interpax's default. See: https://github.com/f0uriest/interpax
    """
    # Shape annotation fix: xq previously shared the "n" dimension with xp/fp,
    # which wrongly required len(xq) == len(xp) under runtime type checking.
    return interpax_interp1d(xq, xp, fp, method="cubic")
[docs]
def sigmoid(x: Array) -> Array:
r"""
Sigmoid activation function.
Computes the sigmoid function:
.. math::
\sigma(x) = \frac{1}{1 + e^{-x}}
Args:
x (Array): Input values.
Returns:
Array: Sigmoid function values in range (0, 1).
"""
return 1.0 / (1.0 + jnp.exp(-x))
[docs]
def calculate_rest_mass_density(e: Float[Array, "n"], p: Float[Array, "n"]):
    r"""
    Rest-mass (baryon) density from tabulated energy density and pressure.

    Integrates the first-law relation

    .. math::
        \frac{d\rho}{d\varepsilon} = \frac{\rho}{p + \varepsilon}

    from ``e[0]`` to ``e[-1]``. Here :math:`\rho` is the rest-mass density,
    :math:`\varepsilon` the energy density and :math:`p` the pressure.
    Between grid points, :math:`p(\varepsilon)` is linearly interpolated
    from the table.

    Args:
        e (Float[Array, "n"]): Energy density array [geometric units],
            increasing (it is used both as the integration span and as the
            save grid).
        p (Float[Array, "n"]): Pressure array [geometric units].

    Returns:
        Array: Rest-mass density at every entry of ``e`` [geometric units].

    Note:
        Integration uses diffrax (Tsit5 with a PID step controller,
        rtol=1e-5, atol=1e-6, initial step 1e-8), which may have
        compatibility issues with some diffrax versions. The initial
        condition assumes :math:`\rho(\varepsilon_0) = \varepsilon_0`.
        NOTE(review): atol is absolute. If the inputs are very small numbers
        (as geometric-unit densities can be), this tolerance may be loose
        relative to the solution. Confirm accuracy for the intended units.
    """

    def drho_de(energy, rho, args):
        # Right-hand side rho / (p(eps) + eps), with p(eps) from the table
        pressure = jnp.interp(energy, e, p)
        return rho / (pressure + energy)

    e_start = e[0]
    e_end = e[-1]

    solution = diffeqsolve(
        ODETerm(drho_de),
        Tsit5(),
        t0=e_start,
        t1=e_end,
        dt0=1e-8,
        y0=e_start,  # rho(eps_0) = eps_0
        saveat=SaveAt(ts=e),
        stepsize_controller=PIDController(rtol=1e-5, atol=1e-6),
    )
    return solution.ys