# Source code for jesterTOV.inference.base.transform

r"""Transform base classes for JESTER inference system.

These transforms encode the behavior to transform sets of parameters for the EOS, e.g. sampled from priors, to their TOV solutions.

This module contains transform classes that were originally from Jim (jimgw v0.2.0).
They are copied here to remove the dependency on jimgw.

Note: These classes follow the Jim/jimgw architecture and provide parameter
transformations with Jacobian corrections for Bayesian inference.
"""

from abc import ABC
from typing import Callable, TypeAlias

import jax
import jax.numpy as jnp
from beartype import beartype as typechecker
from jaxtyping import Array, Float, jaxtyped

# Readability aliases used in the annotations of the transform classes.
# ParamDict: maps each parameter name to its value.
ParamDict: TypeAlias = dict[str, Float]
# NameMapping: the pair (input_names, output_names) of a transform.
NameMapping: TypeAlias = tuple[list[str], list[str]]


class Transform(ABC):
    """
    Base class for transforms.

    Note: This class follows the Jim/jimgw architecture. The purpose of this class
    is purely for keeping track of parameter name mappings.
    """

    # (input_names, output_names) handled by this transform
    name_mapping: NameMapping

    def __init__(
        self,
        name_mapping: NameMapping,
    ) -> None:
        """
        Parameters
        ----------
        name_mapping : tuple[list[str], list[str]]
            Tuple of (input_names, output_names) for the transform.
        """
        self.name_mapping = name_mapping

    def propagate_name(self, x: list[str]) -> list[str]:
        """
        Propagate parameter names through the transform.

        The result contains every name of ``x`` that is not an input name of
        this transform, plus every output name of this transform, without
        duplicates. The order is deterministic: surviving names keep their
        order of first appearance in ``x``, followed by the output names (in
        ``name_mapping[1]`` order) that are not already present.

        Parameters
        ----------
        x : list[str]
            Input parameter names.

        Returns
        -------
        list[str]
            Output parameter names after applying the transform.
        """
        consumed = set(self.name_mapping[0])
        # A dict is used as an ordered set: insertion order is preserved and
        # repeated names collapse to a single entry. Going through ``set`` and
        # ``list`` instead would make the order depend on string hashing.
        ordered = dict.fromkeys(name for name in x if name not in consumed)
        ordered.update(dict.fromkeys(self.name_mapping[1]))
        return list(ordered)


class NtoMTransform(Transform):
    """
    N-to-M parameter transform (not necessarily invertible).

    Note: This class follows the Jim/jimgw architecture. Used for likelihood
    transforms where you map N parameters to M different parameters without
    requiring invertibility or Jacobian corrections.

    This class only declares ``transform_func``; it has to be assigned (for
    example in a subclass ``__init__``) before ``forward`` is called.
    """

    # Maps parameter dict to transformed parameter dict
    transform_func: Callable[[ParamDict], ParamDict]

    def forward(self, x: ParamDict) -> ParamDict:
        """
        Push forward the input x to transformed coordinate y.

        ``transform_func`` is called with a shallow copy of the full input
        dictionary. The input names (``name_mapping[0]``) are then removed
        from that copy and every entry returned by ``transform_func`` is
        written into it. ``x`` itself is not modified.

        Parameters
        ----------
        x : dict[str, Float]
            The input dictionary.

        Returns
        -------
        y : dict[str, Float]
            The transformed dictionary.

        Raises
        ------
        KeyError
            If one of the input names in ``name_mapping[0]`` is missing
            from ``x``.
        """
        result = x.copy()
        output_params = self.transform_func(result)
        # Drop the consumed inputs first, so an output that reuses an input
        # name is re-inserted by the update below.
        for name in self.name_mapping[0]:
            result.pop(name)
        result.update(output_params)
        return result
class NtoNTransform(NtoMTransform):
    """
    N-to-N parameter transform with Jacobian calculation.

    Note: This class follows the Jim/jimgw architecture.
    """

    @property
    def n_dim(self) -> int:
        """Number of input parameters, i.e. ``len(name_mapping[0])``."""
        return len(self.name_mapping[0])

    def transform(self, x: ParamDict) -> tuple[ParamDict, Float]:
        """
        Transform the input x to transformed coordinate y and return the
        log Jacobian determinant.

        This only works if the transform is a N -> N transform. Only the
        entries named in ``name_mapping[0]`` are passed to
        ``transform_func``; all other entries of ``x`` are carried over
        unchanged. ``x`` itself is not modified.

        Parameters
        ----------
        x : ParamDict
            The input dictionary.

        Returns
        -------
        y : ParamDict
            The transformed dictionary.
        log_det : Float
            The log Jacobian determinant.

        Raises
        ------
        KeyError
            If one of the input names in ``name_mapping[0]`` is missing
            from ``x``.
        """
        result = x.copy()
        transform_params = {name: result[name] for name in self.name_mapping[0]}
        output_params = self.transform_func(transform_params)

        # jacfwd returns a nested pytree of partial derivatives; its leaves
        # are flattened and reshaped into an (n_dim, n_dim) matrix.
        # NOTE(review): the reshape presumably relies on every parameter
        # being a scalar, so that there are n_dim * n_dim leaves; confirm.
        jac_tree = jax.jacfwd(self.transform_func)(transform_params)
        jac_matrix = jnp.array(jax.tree.leaves(jac_tree)).reshape(
            self.n_dim, self.n_dim
        )
        log_det = jnp.log(jnp.absolute(jnp.linalg.det(jac_matrix)))

        # Replace the consumed inputs by the transformed outputs.
        for name in self.name_mapping[0]:
            result.pop(name)
        result.update(output_params)
        return result, log_det
class BijectiveTransform(NtoNTransform):
    """
    Bijective (invertible) N-to-N parameter transform with Jacobian corrections.

    Note: This class follows the Jim/jimgw architecture. Used for sample
    transforms where parameters are transformed during MCMC sampling and
    Jacobian corrections are applied to the prior.

    This class only declares ``inverse_transform_func``; it has to be
    assigned (for example in a subclass ``__init__``) before ``inverse`` or
    ``backward`` is called.
    """

    # Maps transformed dict back to original parameter dict
    inverse_transform_func: Callable[[ParamDict], ParamDict]

    def inverse(self, y: ParamDict) -> tuple[ParamDict, Float]:
        """
        Inverse transform the input y to original coordinate x.

        Only the entries named in ``name_mapping[1]`` are passed to
        ``inverse_transform_func``; all other entries of ``y`` are carried
        over unchanged. ``y`` itself is not modified.

        Parameters
        ----------
        y : ParamDict
            The transformed dictionary.

        Returns
        -------
        x : ParamDict
            The original dictionary.
        log_det : Float
            The log Jacobian determinant.

        Raises
        ------
        KeyError
            If one of the output names in ``name_mapping[1]`` is missing
            from ``y``.
        """
        result = y.copy()
        transform_params = {name: result[name] for name in self.name_mapping[1]}
        output_params = self.inverse_transform_func(transform_params)

        # Jacobian of the inverse map, flattened from the nested pytree
        # returned by jacfwd into an (n_dim, n_dim) matrix.
        jac_tree = jax.jacfwd(self.inverse_transform_func)(transform_params)
        jac_matrix = jnp.array(jax.tree.leaves(jac_tree)).reshape(
            self.n_dim, self.n_dim
        )
        log_det = jnp.log(jnp.absolute(jnp.linalg.det(jac_matrix)))

        # Replace the transformed names by the recovered original ones.
        for name in self.name_mapping[1]:
            result.pop(name)
        result.update(output_params)
        return result, log_det

    def backward(self, y: ParamDict) -> ParamDict:
        """
        Pull back the input y to original coordinate x (without Jacobian).

        ``inverse_transform_func`` is called with a shallow copy of the full
        dictionary. ``y`` itself is not modified.

        Parameters
        ----------
        y : ParamDict
            The transformed dictionary.

        Returns
        -------
        x : ParamDict
            The original dictionary.

        Raises
        ------
        KeyError
            If one of the output names in ``name_mapping[1]`` is missing
            from ``y``.
        """
        result = y.copy()
        output_params = self.inverse_transform_func(result)
        for name in self.name_mapping[1]:
            result.pop(name)
        result.update(output_params)
        return result
# ============================================================================
# Specific Transform Implementations
# ============================================================================
# These are used internally by UniformPrior and other prior distributions.


@jaxtyped(typechecker=typechecker)
class ScaleTransform(BijectiveTransform):
    """
    Scale transform: y = x * scale.

    The inverse divides by ``scale``. Input and output names are paired by
    position in ``name_mapping``.

    Note: This class follows the Jim/jimgw architecture.
    """

    scale: Float

    def __init__(
        self,
        name_mapping: NameMapping,
        scale: Float,
    ) -> None:
        """
        Parameters
        ----------
        name_mapping : NameMapping
            Tuple of (input_names, output_names).
        scale : Float
            The scaling factor.
        """
        super().__init__(name_mapping)
        self.scale = scale
        # The lambdas read ``self.scale`` at call time, not at construction.
        self.transform_func = lambda x: {
            name_mapping[1][i]: x[in_name] * self.scale
            for i, in_name in enumerate(name_mapping[0])
        }
        self.inverse_transform_func = lambda x: {
            name_mapping[0][i]: x[out_name] / self.scale
            for i, out_name in enumerate(name_mapping[1])
        }


@jaxtyped(typechecker=typechecker)
class OffsetTransform(BijectiveTransform):
    """
    Offset transform: y = x + offset.

    The inverse subtracts ``offset``. Input and output names are paired by
    position in ``name_mapping``.

    Note: This class follows the Jim/jimgw architecture.
    """

    offset: Float

    def __init__(
        self,
        name_mapping: NameMapping,
        offset: Float,
    ) -> None:
        """
        Parameters
        ----------
        name_mapping : NameMapping
            Tuple of (input_names, output_names).
        offset : Float
            The offset value.
        """
        super().__init__(name_mapping)
        self.offset = offset
        # The lambdas read ``self.offset`` at call time, not at construction.
        self.transform_func = lambda x: {
            name_mapping[1][i]: x[in_name] + self.offset
            for i, in_name in enumerate(name_mapping[0])
        }
        self.inverse_transform_func = lambda x: {
            name_mapping[0][i]: x[out_name] - self.offset
            for i, out_name in enumerate(name_mapping[1])
        }


@jaxtyped(typechecker=typechecker)
class LogitTransform(BijectiveTransform):
    """
    Logit transform: y = 1 / (1 + exp(-x)).

    The forward direction applies the logistic sigmoid shown above; the
    inverse direction applies the logit, x = log(y / (1 - y)). Input and
    output names are paired by position in ``name_mapping``.

    Note: This class follows the Jim/jimgw architecture.
    """

    def __init__(
        self,
        name_mapping: NameMapping,
    ) -> None:
        """
        Parameters
        ----------
        name_mapping : NameMapping
            Tuple of (input_names, output_names).
        """
        super().__init__(name_mapping)
        self.transform_func = lambda x: {
            name_mapping[1][i]: 1 / (1 + jnp.exp(-x[in_name]))
            for i, in_name in enumerate(name_mapping[0])
        }
        self.inverse_transform_func = lambda x: {
            name_mapping[0][i]: jnp.log(x[out_name] / (1 - x[out_name]))
            for i, out_name in enumerate(name_mapping[1])
        }


@jaxtyped(typechecker=typechecker)
class MVGaussianToUnitCube(BijectiveTransform):
    r"""Maps Multivariate Gaussian parameter space :math:`\leftrightarrow` unit hypercube :math:`[0,1]^n`.

    Uses the probability integral transform to convert between a multivariate
    Gaussian distribution and the unit hypercube, enabling nested sampling
    (which requires sampling in :math:`[0,1]^n`) with Gaussian priors.

    Forward :math:`\theta \to \tilde{u}`:

    .. math::

        z = L^{-1}(\theta - \mu), \quad \tilde{u}_i = \Phi(z_i)

    Backward :math:`\tilde{u} \to \theta`:

    .. math::

        z_i = \Phi^{-1}(\tilde{u}_i), \quad \theta = \mu + Lz

    where :math:`L` is the lower-triangular Cholesky factor of the covariance
    :math:`\Sigma`, and :math:`\Phi` is the standard normal CDF.

    By the probability integral transform, the prior density in unit-cube space
    is exactly flat (log density = 0), which is what nested sampling requires.

    Parameters
    ----------
    name_mapping : NameMapping
        Tuple of (input_names, output_names). Typically the same parameter
        names are used in both spaces.
    mean : Float[Array, " n_dim"]
        Mean vector of the Gaussian prior.
    cov : Float[Array, "n_dim n_dim"]
        Covariance matrix of the Gaussian prior (must be positive definite).
    """

    mean: Float[Array, " n_dim"]
    L: Float[Array, "n_dim n_dim"]  # Lower triangular Cholesky factor of cov

    def __init__(
        self,
        name_mapping: NameMapping,
        mean: Float[Array, " n_dim"],
        cov: Float[Array, "n_dim n_dim"],
    ) -> None:
        """
        Parameters
        ----------
        name_mapping : NameMapping
            Tuple of (input_names, output_names).
        mean : Float[Array, " n_dim"]
            Mean vector.
        cov : Float[Array, "n_dim n_dim"]
            Covariance matrix (positive definite).

        Raises
        ------
        ValueError
            If the input and output name lists differ in length, or if the
            shape of ``mean`` or ``cov`` does not match the number of names.
        """
        # name_mapping is (input_names, output_names): index 0 holds the
        # inputs and index 1 the outputs.
        n_in = len(name_mapping[0])
        n_out = len(name_mapping[1])
        mean_arr_raw = jnp.asarray(mean)
        cov_arr_raw = jnp.asarray(cov)
        if n_out != n_in:
            raise ValueError(
                f"MVGaussianToUnitCube: name_mapping output names ({n_out}) and "
                f"input names ({n_in}) must have the same length."
            )
        # From here on both name lists are known to have the same length.
        n = n_in
        if mean_arr_raw.shape != (n,):
            raise ValueError(
                f"MVGaussianToUnitCube: mean shape {mean_arr_raw.shape} does not match "
                f"expected ({n},) from name_mapping."
            )
        if cov_arr_raw.shape != (n, n):
            raise ValueError(
                f"MVGaussianToUnitCube: cov shape {cov_arr_raw.shape} does not match "
                f"expected ({n}, {n}) from name_mapping."
            )

        super().__init__(name_mapping)
        self.mean = mean_arr_raw
        # NOTE(review): jnp.linalg.cholesky is documented to return NaNs
        # rather than raise when the matrix is not positive definite; that
        # case is not checked here.
        self.L = jnp.linalg.cholesky(cov_arr_raw)

        # The closures below capture these locals, so rebinding ``self.mean``
        # or ``self.L`` afterwards does not change the transform.
        mean_arr = self.mean
        L_arr = self.L

        def _forward(x: ParamDict) -> ParamDict:
            # theta -> u: whiten with the Cholesky factor, then apply the
            # standard normal CDF component-wise.
            theta = jnp.array([x[k] for k in name_mapping[0]])
            z = jnp.linalg.solve(L_arr, theta - mean_arr)
            u = jax.scipy.special.ndtr(z)
            return {name_mapping[1][i]: u[i] for i in range(n)}

        def _inverse(y: ParamDict) -> ParamDict:
            u = jnp.array([y[k] for k in name_mapping[1]])
            # Clamp to open interval (0, 1) so ndtri never produces ±inf,
            # which happens when the unit-cube stepper lands exactly on 0.0 or 1.0.
            u_lo = jnp.nextafter(
                jnp.array(0.0, dtype=u.dtype), jnp.array(1.0, dtype=u.dtype)
            )
            u_hi = jnp.nextafter(
                jnp.array(1.0, dtype=u.dtype), jnp.array(0.0, dtype=u.dtype)
            )
            u_clamped = jnp.clip(u, u_lo, u_hi)
            z = jax.scipy.special.ndtri(u_clamped)
            theta = mean_arr + L_arr @ z
            return {name_mapping[0][i]: theta[i] for i in range(n)}

        self.transform_func = _forward
        self.inverse_transform_func = _inverse


@jaxtyped(typechecker=typechecker)
class BoundToBound(BijectiveTransform):
    """
    Linear transform from [original_lower, original_upper] to [target_lower, target_upper].

    Used for nested sampling to map prior bounds to unit cube [0, 1].

    Note: This implementation handles per-parameter bounds, where each parameter
    can have different original and target bounds. The original bounds are
    looked up by input name and the target bounds by output name. No check is
    made that the lower and upper bound of a parameter differ.
    """

    original_lower_bound: ParamDict  # Lower bounds for original parameters
    original_upper_bound: ParamDict  # Upper bounds for original parameters
    target_lower_bound: ParamDict  # Lower bounds for target parameters
    target_upper_bound: ParamDict  # Upper bounds for target parameters

    def __init__(
        self,
        name_mapping: NameMapping,
        original_lower_bound: ParamDict,  # Lower bounds for original parameters
        original_upper_bound: ParamDict,  # Upper bounds for original parameters
        target_lower_bound: ParamDict,  # Lower bounds for target parameters
        target_upper_bound: ParamDict,  # Upper bounds for target parameters
    ) -> None:
        """
        Parameters
        ----------
        name_mapping : NameMapping
            Tuple of (input_names, output_names).
        original_lower_bound : dict[str, Float]
            Lower bounds in original space (per parameter).
        original_upper_bound : dict[str, Float]
            Upper bounds in original space (per parameter).
        target_lower_bound : dict[str, Float]
            Lower bounds in target space (per parameter).
        target_upper_bound : dict[str, Float]
            Upper bounds in target space (per parameter).
        """
        super().__init__(name_mapping)
        self.original_lower_bound = original_lower_bound
        self.original_upper_bound = original_upper_bound
        self.target_lower_bound = target_lower_bound
        self.target_upper_bound = target_upper_bound

        # Forward: original → target
        # y = (x - x_min) / (x_max - x_min) * (y_max - y_min) + y_min
        def _forward(x: ParamDict) -> ParamDict:
            result = {}
            for i, in_name in enumerate(name_mapping[0]):
                out_name = name_mapping[1][i]
                x_val = x[in_name]
                x_min = original_lower_bound[in_name]
                x_max = original_upper_bound[in_name]
                y_min = target_lower_bound[out_name]
                y_max = target_upper_bound[out_name]
                result[out_name] = (x_val - x_min) / (x_max - x_min) * (
                    y_max - y_min
                ) + y_min
            return result

        # Inverse: target → original
        # x = (y - y_min) / (y_max - y_min) * (x_max - x_min) + x_min
        def _inverse(y: ParamDict) -> ParamDict:
            result = {}
            for i, out_name in enumerate(name_mapping[1]):
                in_name = name_mapping[0][i]
                y_val = y[out_name]
                x_min = original_lower_bound[in_name]
                x_max = original_upper_bound[in_name]
                y_min = target_lower_bound[out_name]
                y_max = target_upper_bound[out_name]
                result[in_name] = (y_val - y_min) / (y_max - y_min) * (
                    x_max - x_min
                ) + x_min
            return result

        self.transform_func = _forward
        self.inverse_transform_func = _inverse