r"""
Wrapper for trained normalizing flows with automatic data preprocessing.
This module provides a high-level interface for loading and using pre-trained
normalizing flow models for gravitational wave inference. The Flow class handles
the complexities of data standardization and model loading, allowing users to
sample from or evaluate trained flows with a simple API.
Normalizing flows trained on gravitational wave posterior samples can be used
for importance sampling in EOS inference, providing efficient proposals that
capture the correlations between binary component masses and tidal deformabilities.
Key Features
------------
- Automatic min-max standardization and inverse transformation
- Simple save/load interface compatible with flowjax models
- JAX-accelerated sampling and probability evaluation
Typical Workflow
----------------
1. Train a flow on GW posterior samples using train_flow.py
2. Load the trained flow: flow = Flow.from_directory("path/to/model/")
3. Sample or evaluate: samples = flow.sample(key, (1000,))
See Also
--------
train_flow : Module for training normalizing flows on GW posteriors
Examples
--------
Load a trained flow and generate samples:
>>> from jesterTOV.inference.flows import Flow
>>> import jax
>>> flow = Flow.from_directory("./models/gw170817/")
>>> samples = flow.sample(jax.random.key(0), (1000,))
>>> print(samples.shape) # (1000, 4) for (m1, m2, λ1, λ2)
Evaluate log-probability of data points:
>>> data = jnp.array([[1.4, 1.3, 100, 200]])
>>> log_prob = flow.log_prob(data)
"""
import json
import os
from typing import Any, Dict, Tuple
import equinox as eqx
import jax
import jax.numpy as jnp
from jax import Array
from flowjax.distributions import AbstractDistribution, Normal
from flowjax.flows import (
block_neural_autoregressive_flow,
coupling_flow,
masked_autoregressive_flow,
)
from flowjax.bijections import (
RationalQuadraticSpline,
Affine,
)
class Flow:
    """
    Wrapper class for flowjax normalizing flows with automatic standardization handling.

    This class encapsulates a trained normalizing flow and handles data
    standardization transparently. When sampling, samples are automatically
    converted back to the original scale if standardization was used during
    training; when evaluating log-probabilities, inputs are standardized first
    and the change-of-variables Jacobian correction is applied.

    Attributes:
        flow: The underlying flowjax flow model
        metadata: Training metadata dictionary
        flow_kwargs: Flow architecture kwargs
        standardize: Whether standardization was used during training
        standardization_method: One of "zscore", "minmax", or "none"

    Example:
        >>> # Load a trained flow
        >>> flow = Flow.from_directory("./models/gw170817/")
        >>>
        >>> # Sample in original scale (standardization handled automatically)
        >>> samples = flow.sample(jax.random.key(0), (1000,))
        >>>
        >>> # Access metadata
        >>> print(f"Flow type: {flow.metadata['flow_type']}")
        >>> print(f"Standardized: {flow.standardize}")
    """

    def __init__(
        self,
        flow: "AbstractDistribution",
        metadata: Dict[str, Any],
        flow_kwargs: Dict[str, Any],
    ):
        """
        Initialize Flow wrapper.

        Args:
            flow: Trained flowjax flow model
            metadata: Training metadata (may carry standardization statistics)
            flow_kwargs: Flow architecture kwargs

        Raises:
            ValueError: If standardization is enabled but metadata contains
                neither (data_mean, data_std) nor
                (data_bounds_min, data_bounds_max).
        """
        self.flow = flow
        self.metadata = metadata
        self.flow_kwargs = flow_kwargs
        self.standardize = metadata.get("standardize", False)

        # Detect standardization method from metadata
        has_mean_std = "data_mean" in metadata and "data_std" in metadata
        has_bounds = "data_bounds_min" in metadata and "data_bounds_max" in metadata

        if self.standardize:
            if has_mean_std:
                # Z-score standardization (new default)
                self.standardization_method = "zscore"
                self.data_mean = jnp.array(metadata["data_mean"])
                self.data_std = jnp.array(metadata["data_std"])
                # Avoid division by zero for constant features
                self.data_std = jnp.where(self.data_std == 0, 1.0, self.data_std)
            elif has_bounds:
                # Min-max standardization (legacy)
                self.standardization_method = "minmax"
                self.data_min = jnp.array(metadata["data_bounds_min"])
                self.data_max = jnp.array(metadata["data_bounds_max"])
                self.data_range = self.data_max - self.data_min
                # Avoid division by zero for constant features
                self.data_range = jnp.where(self.data_range == 0, 1.0, self.data_range)
            else:
                raise ValueError(
                    "Standardization enabled but metadata missing both "
                    "(data_mean, data_std) and (data_bounds_min, data_bounds_max)"
                )
        else:
            # No standardization - create identity transform.
            # Infer dimensionality from flow; min=0, range=1 makes the
            # min-max code paths below behave as a no-op.
            n_features = self.flow.shape[0]
            self.standardization_method = "none"
            self.data_min = jnp.zeros(n_features)
            self.data_max = jnp.ones(n_features)
            self.data_range = jnp.ones(n_features)

    @classmethod
    def from_directory(cls, output_dir: str) -> "Flow":
        """
        Load a trained flow from a directory.

        Args:
            output_dir: Directory containing flow_weights.eqx, flow_kwargs.json,
                metadata.json

        Returns:
            Flow instance with loaded model and metadata

        Example:
            >>> flow = Flow.from_directory("./models/gw170817/")
        """
        # Load the flow model and metadata
        flow_model, metadata = load_model(output_dir)
        # Load kwargs
        kwargs_path = os.path.join(output_dir, "flow_kwargs.json")
        with open(kwargs_path, "r") as f:
            flow_kwargs = json.load(f)
        return cls(flow_model, metadata, flow_kwargs)

    def sample(self, key: Array, shape: Tuple[int, ...]) -> Array:
        """
        Sample from the flow and return in original scale.

        If standardization was used during training, samples are automatically
        converted back to the original scale using the inverse transformation
        (z-score or min-max). If not, the transformation is identity (no-op).

        Args:
            key: JAX random key (jax.Array)
            shape: Shape of samples to generate (e.g., (1000,) for 1000 samples)

        Returns:
            Samples in original scale as JAX array of shape
            (``*shape``, n_features)

        Example:
            >>> samples = flow.sample(jax.random.key(0), (1000,))
            >>> print(samples.shape)  # (1000, 4) for 4D flow
        """
        # Sample in standardized space
        samples = self.flow.sample(key, shape)
        # Inverse transformation to original scale (method-dependent)
        samples = self.destandardize_output(samples)
        return samples

    def standardize_input(self, data: Array) -> Array:
        """
        Convert data from original scale to the flow's standardized space.

        This is the exact inverse of :meth:`destandardize_output`:
        - Z-score: (x - mean) / std
        - Min-max: (x - min) / (max - min)
        - None: identity (no-op, since min=0 and range=1)

        NOTE(review): this method was absent from the rendered source while
        ``log_prob`` calls it; it is restored here as the inverse of
        ``destandardize_output``.

        Args:
            data: Data in original scale

        Returns:
            Data in standardized space (or unchanged if standardization not used)
        """
        if self.standardization_method == "zscore":
            # Forward z-score: (x - mean) / std
            return (data - self.data_mean) / self.data_std
        else:
            # Forward min-max or identity: (x - min) / range
            return (data - self.data_min) / self.data_range

    def destandardize_output(self, data: Array) -> Array:
        """
        Convert standardized data back to original scale.

        Applies the inverse of the standardization method:
        - Z-score: x * std + mean
        - Min-max: x * (max - min) + min
        - None: identity (no-op)

        Args:
            data: Data in standardized space (z-score or [0, 1])

        Returns:
            Data in original scale (or unchanged if standardization not used)

        Example:
            >>> standardized_data = jnp.array([[0.5, 0.5, 0.5, 0.5]])
            >>> original = flow.destandardize_output(standardized_data)
        """
        if self.standardization_method == "zscore":
            # Inverse z-score: x * std + mean
            return data * self.data_std + self.data_mean
        else:
            # Inverse min-max or identity: x * range + min
            # If standardization disabled, this is identity (min=0, range=1)
            return data * self.data_range + self.data_min

    def log_prob(self, x: Array) -> Array:
        """
        Evaluate log probability of data under the flow.

        If standardization was used, input data is automatically standardized
        before evaluation and a Jacobian correction is applied. If not,
        operations are identity (no-op).

        The Jacobian correction accounts for the change of variables:
        - Z-score: log p(x) = log p(x_std) - sum(log(std))
        - Min-max: log p(x) = log p(x_std) - sum(log(max - min))
        - None: log p(x) = log p(x_std) (no correction)

        Args:
            x: Data in original scale, shape (n_samples, n_features).
                JAX array.

        Returns:
            Log probabilities as JAX array, shape (n_samples,)

        Example:
            >>> data = jnp.array([[1.4, 1.3, 100, 200]])
            >>> log_prob = flow.log_prob(data)
        """
        # Standardize input (method-dependent or identity)
        x_std = self.standardize_input(x)
        # Evaluate log probability in standardized space
        log_p = self.flow.log_prob(x_std)
        # Account for Jacobian of the standardizing transformation
        if self.standardization_method == "zscore":
            # Z-score: correction is -sum(log(std))
            log_det_jacobian = -jnp.sum(jnp.log(self.data_std))
        else:
            # Min-max or none: correction is -sum(log(range))
            # If standardization disabled (range=1), log_det_jacobian = 0
            log_det_jacobian = -jnp.sum(jnp.log(self.data_range))
        log_p = log_p + log_det_jacobian
        return log_p
def create_transformer(
    transformer_type: str = "affine",
    transformer_knots: int = 8,
    transformer_interval: float = 4.0,
) -> Any:
    """
    Build the per-dimension transformer used by masked_autoregressive_flow
    and coupling_flow.

    Args:
        transformer_type: Type of transformer ("affine", "rational_quadratic_spline")
        transformer_knots: Number of knots for RationalQuadraticSpline
        transformer_interval: Interval for RationalQuadraticSpline

    Returns:
        Transformer instance

    Raises:
        ValueError: If ``transformer_type`` is not recognized.
    """
    if transformer_type == "rational_quadratic_spline":
        return RationalQuadraticSpline(
            knots=transformer_knots, interval=transformer_interval
        )
    if transformer_type == "affine":
        return Affine()
    raise ValueError(
        f"Unknown transformer type: {transformer_type}. "
        "Must be one of: affine, rational_quadratic_spline"
    )
def create_flow(
    key: Array,
    dim: int = 4,
    flow_type: str = "masked_autoregressive_flow",
    nn_depth: int = 5,
    nn_block_dim: int = 8,
    nn_width: int = 50,
    flow_layers: int = 1,
    invert: bool = True,
    cond_dim: int | None = None,
    transformer_type: str = "affine",
    transformer_knots: int = 8,
    transformer_interval: float = 4.0,
) -> Any:
    """
    Create a normalizing flow of the specified type with flexible dimensionality.

    Args:
        key: JAX random key
        dim: Dimensionality of the data (default: 4 for GW [m1, m2, λ1, λ2],
            can be 2 for NICER [M, R], etc.)
        flow_type: Type of flow ("block_neural_autoregressive_flow",
            "masked_autoregressive_flow", "coupling_flow")
        nn_depth: Depth of neural network (used by all flow types)
        nn_block_dim: Block dimension (for block_neural_autoregressive_flow)
        nn_width: Width of hidden layers (for masked_autoregressive_flow,
            coupling_flow)
        flow_layers: Number of flow layers
        invert: Whether to invert the flow
        cond_dim: Conditional dimension (None for unconditional flows)
        transformer_type: Type of transformer for masked_autoregressive_flow
            and coupling_flow ("affine", "rational_quadratic_spline")
        transformer_knots: Number of knots for RationalQuadraticSpline
        transformer_interval: Interval for RationalQuadraticSpline

    Returns:
        Untrained flowjax flow model

    Raises:
        ValueError: If ``flow_type`` is not recognized.
    """
    # Standard-normal base distribution in `dim` dimensions.
    base_dist = Normal(jnp.zeros(dim))

    # Arguments shared by every flow constructor.
    common = dict(
        key=key,
        base_dist=base_dist,
        flow_layers=flow_layers,
        invert=invert,
        cond_dim=cond_dim,
    )

    if flow_type == "block_neural_autoregressive_flow":
        return block_neural_autoregressive_flow(
            nn_depth=nn_depth,
            nn_block_dim=nn_block_dim,
            **common,
        )

    if flow_type in ("masked_autoregressive_flow", "coupling_flow"):
        # MAF and coupling flows share the transformer-based interface.
        transformer = create_transformer(
            transformer_type, transformer_knots, transformer_interval
        )
        builder = (
            masked_autoregressive_flow
            if flow_type == "masked_autoregressive_flow"
            else coupling_flow
        )
        return builder(
            nn_width=nn_width,
            nn_depth=nn_depth,
            transformer=transformer,
            **common,
        )

    raise ValueError(
        f"Unknown flow type: {flow_type}. Must be one of: "
        "block_neural_autoregressive_flow, masked_autoregressive_flow, "
        "coupling_flow"
    )
def load_model(output_dir: str) -> Tuple[Any, Dict[str, Any]]:
    """
    Load a trained flow model from saved files.

    Args:
        output_dir: Directory containing saved model files
            (metadata.json, flow_kwargs.json, flow_weights.eqx)

    Returns:
        flow: Loaded flow model
        metadata: Training metadata (includes data statistics if
            standardization was used)

    Example:
        >>> flow, metadata = load_model("./models/gw170817/")
    """

    def _read_json(filename: str) -> Dict[str, Any]:
        # Small helper: parse a JSON file that lives inside output_dir.
        with open(os.path.join(output_dir, filename), "r") as fh:
            return json.load(fh)

    # Metadata is needed first so we can infer dimensionality.
    metadata = _read_json("metadata.json")
    flow_kwargs = _read_json("flow_kwargs.json")

    # Infer dimensionality: data_mean (new format) takes precedence over
    # data_bounds_min (legacy); fall back to 4 for old models trained
    # without standardization.
    if "data_mean" in metadata:
        dim = len(metadata["data_mean"])
    elif "data_bounds_min" in metadata:
        dim = len(metadata["data_bounds_min"])
    else:
        dim = 4

    # Rebuild the architecture exactly as it was at training time, then
    # load the serialized weights into it.
    rng = jax.random.key(flow_kwargs["seed"])
    skeleton = create_flow(
        key=rng,
        dim=dim,
        flow_type=flow_kwargs["flow_type"],
        nn_depth=flow_kwargs["nn_depth"],
        nn_block_dim=flow_kwargs["nn_block_dim"],
        nn_width=flow_kwargs["nn_width"],
        flow_layers=flow_kwargs["flow_layers"],
        invert=flow_kwargs["invert"],
        cond_dim=flow_kwargs["cond_dim"],
        transformer_type=flow_kwargs.get("transformer_type", "affine"),
        transformer_knots=flow_kwargs.get("transformer_knots", 8),
        transformer_interval=flow_kwargs.get("transformer_interval", 4.0),
    )
    flow = eqx.tree_deserialise_leaves(
        os.path.join(output_dir, "flow_weights.eqx"), skeleton
    )
    return flow, metadata