Source code for jesterTOV.inference.flows.train_flow

"""Training script for normalizing flows on gravitational wave posterior samples.

Trains normalizing flow models to approximate GW posteriors in (m1, m2, λ1, λ2) space.
The trained flows serve as efficient proposal distributions for EOS inference.

Training Pipeline
-----------------
1. Load configuration from YAML file
2. Load posterior samples from npz file
3. Apply optional standardization
4. Create flow architecture (coupling, autoregressive, or block neural autoregressive)
5. Fit flow using maximum likelihood with early stopping
6. Save trained weights, config, and metadata
7. Generate validation plots

Supported Architectures
-----------------------
- coupling_flow: Balanced speed and expressiveness
- block_neural_autoregressive_flow: Good expressiveness
- masked_autoregressive_flow: Flexible but slower
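
For example, one of these can be instantiated by name through create_flow
(a minimal sketch; dim=4 matches the (m1, m2, λ1, λ2) parameter space):

>>> import jax
>>> from jesterTOV.inference.flows.flow import create_flow
>>> flow = create_flow(key=jax.random.key(0), dim=4, flow_type="coupling_flow")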

Configuration-Driven Usage
---------------------------
Create a YAML config file (e.g., config.yaml):

    posterior_file: data/gw170817_posterior.npz
    output_dir: models/gw170817/
    flow_type: masked_autoregressive_flow
    num_epochs: 1000
    learning_rate: 1.0e-3
    standardize: true
    plot_corner: true
    plot_losses: true

Then run:

    uv run python -m jesterTOV.inference.flows.train_flow config.yaml

Or use the bash scripts for batch training:

    bash train_all_flows.sh

Programmatic Usage
------------------
For custom training workflows, use the provided functions:

>>> from jesterTOV.inference.flows.train_flow import train_flow_from_config
>>> from jesterTOV.inference.flows.config import FlowTrainingConfig
>>> config = FlowTrainingConfig.from_yaml("config.yaml")
>>> train_flow_from_config(config)

Or use the lower-level functions directly:

>>> from jesterTOV.inference.flows.flow import create_flow
>>> from jesterTOV.inference.flows.train_flow import load_posterior, train_flow, save_model
>>> data, metadata = load_posterior("gw170817.npz", parameter_names=[...], max_samples=50000)
>>> flow = create_flow(jax.random.key(0), flow_type="masked_autoregressive_flow")
>>> trained_flow, losses = train_flow(flow, data, jax.random.key(1))
>>> save_model(trained_flow, "models/gw170817/", flow_kwargs, metadata)

Output Files
------------
The training script saves:
- flow_weights.eqx: Trained model parameters (Equinox serialization)
- flow_kwargs.json: Architecture configuration for reproducibility
- metadata.json: Training metadata (epochs, losses, data bounds, etc.)
- figures/losses.png: Training and validation loss curves
- figures/corner.png: Corner plot comparing data and flow samples
- figures/transformed_training_data.png: Visualization of transformed data
  (if physics constraints are enabled)
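
These are exactly the files Flow.from_directory consumes when reloading a
trained model (a minimal sketch, mirroring the usage hint printed at the end
of training):

>>> from jesterTOV.inference.flows.flow import Flow
>>> flow = Flow.from_directory("models/gw170817/")
>>> samples = flow.sample(jax.random.key(0), (1000,))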

See Also
--------
jesterTOV.inference.flows.flow.Flow : High-level interface for loading trained flows
jesterTOV.inference.flows.config.FlowTrainingConfig : Configuration schema

Notes
-----
Training requires:
- JAX (GPU support recommended for large datasets)
- flowjax for flow architectures
- equinox for model serialization
- PyYAML for configuration loading
- Optional: matplotlib and corner for plotting
"""

import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, Tuple, Mapping

import equinox as eqx
import jax
import numpy as np
from jax import Array
from flowjax.train import fit_to_data

from jesterTOV.logging_config import get_logger
from .config import FlowTrainingConfig
from .flow import create_flow

logger = get_logger("jester")


def load_posterior(
    filepath: str,
    parameter_names: list[str],
    max_samples: int = 20_000,
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """
    Load posterior samples from an npz file with flexible parameter selection.

    Args:
        filepath: Path to .npz file
        parameter_names: List of parameter names to extract from the file
        max_samples: Maximum number of samples to use (downsampling if needed)

    Returns:
        data: Array of shape (n_samples, n_params) with the selected parameters
        metadata: Dictionary with loading information

    Raises:
        FileNotFoundError: If the file doesn't exist
        KeyError: If required parameter names are missing from the file
    """
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"Posterior file not found: {filepath}")

    # Load data
    posterior = np.load(filepath)

    # Validate required keys
    missing_keys = [key for key in parameter_names if key not in posterior]
    if missing_keys:
        available_keys = list(posterior.keys())
        raise KeyError(
            f"Missing required parameter names: {missing_keys}\n"
            f"Available keys in file: {available_keys}\n"
            f"Requested parameters: {parameter_names}"
        )

    # Extract samples for each parameter
    columns = [posterior[param].flatten() for param in parameter_names]

    # Combine into a single (n_samples, n_params) array
    data = np.column_stack(columns)
    n_samples_total = data.shape[0]

    # Downsample if needed
    if n_samples_total > max_samples:
        downsample_factor = int(np.ceil(n_samples_total / max_samples))
        data = data[::downsample_factor]
        logger.info(
            f"Downsampled from {n_samples_total} to {data.shape[0]} samples "
            f"(factor: {downsample_factor})"
        )
    else:
        logger.info(f"Using all {n_samples_total} samples")

    metadata = {
        "n_samples_total": n_samples_total,
        "n_samples_used": data.shape[0],
        "parameter_names": parameter_names,
        "filepath": filepath,
    }

    return data, metadata
def standardize_data_zscore(
    data: np.ndarray,
) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
    """
    Standardize data to mean=0, std=1 using z-score normalization.

    Args:
        data: Array of shape (n_samples, n_features)

    Returns:
        standardized_data: Data with mean=0, std=1 per feature
        statistics: Dictionary with 'mean' and 'std' arrays for each feature
    """
    data_mean = data.mean(axis=0)
    data_std = data.std(axis=0)

    # Avoid division by zero (if a feature is constant)
    data_std = np.where(data_std == 0, 1.0, data_std)

    standardized_data = (data - data_mean) / data_std
    statistics = {"mean": data_mean, "std": data_std}

    return standardized_data, statistics


def inverse_standardize_data_zscore(
    standardized_data: np.ndarray, statistics: Dict[str, np.ndarray]
) -> np.ndarray:
    """
    Inverse transform z-score standardized data back to the original scale.

    Args:
        standardized_data: Data with mean=0, std=1
        statistics: Dictionary with 'mean' and 'std' arrays for each feature

    Returns:
        data: Data in the original scale
    """
    data_mean = statistics["mean"]
    data_std = statistics["std"]
    data_std = np.where(data_std == 0, 1.0, data_std)

    data = standardized_data * data_std + data_mean
    return data


def standardize_data_minmax(
    data: np.ndarray,
) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
    """
    Standardize data to the [0, 1] domain using min-max scaling.

    Args:
        data: Array of shape (n_samples, n_features)

    Returns:
        standardized_data: Data scaled to [0, 1]
        bounds: Dictionary with 'min' and 'max' arrays for each feature
    """
    data_min = data.min(axis=0)
    data_max = data.max(axis=0)

    # Avoid division by zero (if a feature is constant)
    data_range = data_max - data_min
    data_range = np.where(data_range == 0, 1.0, data_range)

    standardized_data = (data - data_min) / data_range
    bounds = {"min": data_min, "max": data_max}

    return standardized_data, bounds


def inverse_standardize_data_minmax(
    standardized_data: np.ndarray, bounds: Dict[str, np.ndarray]
) -> np.ndarray:
    """
    Inverse transform min-max standardized data back to the original scale.

    Args:
        standardized_data: Data in the [0, 1] domain
        bounds: Dictionary with 'min' and 'max' arrays for each feature

    Returns:
        data: Data in the original scale
    """
    data_min = bounds["min"]
    data_max = bounds["max"]
    data_range = data_max - data_min
    data_range = np.where(data_range == 0, 1.0, data_range)

    data = standardized_data * data_range + data_min
    return data
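
# Illustrative round-trip check for the standardization helpers above (a
# sketch; kept as a comment so nothing runs at import time):
#
#     x = np.random.default_rng(0).normal(loc=1.4, scale=0.1, size=(1000, 2))
#     z, stats = standardize_data_zscore(x)
#     assert np.allclose(inverse_standardize_data_zscore(z, stats), x)
#     u, bounds = standardize_data_minmax(x)
#     assert np.allclose(inverse_standardize_data_minmax(u, bounds), x)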
def train_flow(
    flow: Any,
    data: np.ndarray,
    key: Array,
    learning_rate: float = 1e-3,
    max_epochs: int = 600,
    max_patience: int = 50,
    val_prop: float = 0.2,
    batch_size: int = 128,
) -> Tuple[Any, Dict[str, list]]:
    """
    Train the normalizing flow on data.

    Args:
        flow: Untrained flowjax flow
        data: Training data of shape (n_samples, n_dims)
        key: JAX random key
        learning_rate: Learning rate for the optimizer
        max_epochs: Maximum number of epochs
        max_patience: Early stopping patience
        val_prop: Proportion of data to use for validation
        batch_size: Batch size for training

    Returns:
        trained_flow: Trained flow model
        losses: Dictionary with 'train' and 'val' loss arrays
    """
    logger.info(f"Training flow for up to {max_epochs} epochs...")
    logger.info(f"Using {val_prop:.1%} of data for validation")
    logger.info(f"Batch size: {batch_size}")

    trained_flow, losses = fit_to_data(
        key=key,
        dist=flow,
        data=data,
        learning_rate=learning_rate,
        max_epochs=max_epochs,
        max_patience=max_patience,
        val_prop=val_prop,
        batch_size=batch_size,
    )

    logger.info(f"Training completed after {len(losses['train'])} epochs")

    return trained_flow, losses
def save_model(
    flow: Any,
    output_dir: str,
    flow_kwargs: Dict[str, Any],
    metadata: Dict[str, Any],
) -> None:
    """
    Save the trained flow model, architecture kwargs, and metadata.

    Args:
        flow: Trained flowjax flow
        output_dir: Directory to save files
        flow_kwargs: Dictionary of kwargs needed to recreate the flow architecture
        metadata: Dictionary with training metadata
    """
    os.makedirs(output_dir, exist_ok=True)

    # Save model weights
    weights_path = os.path.join(output_dir, "flow_weights.eqx")
    logger.info(f"Saving model weights to {weights_path}")
    eqx.tree_serialise_leaves(weights_path, flow)

    # Save architecture kwargs
    kwargs_path = os.path.join(output_dir, "flow_kwargs.json")
    logger.info(f"Saving flow kwargs to {kwargs_path}")
    with open(kwargs_path, "w") as f:
        json.dump(flow_kwargs, f, indent=2)

    # Save metadata
    metadata_path = os.path.join(output_dir, "metadata.json")
    logger.info(f"Saving metadata to {metadata_path}")
    with open(metadata_path, "w") as f:
        json.dump(metadata, f, indent=2)
def plot_losses(losses: Mapping[str, np.ndarray | list], output_path: str) -> None:
    """Plot training and validation losses (accepts dict or list values)."""
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        logger.warning("matplotlib not available, skipping loss plot")
        return

    plt.figure(figsize=(10, 6))
    plt.plot(losses["train"], label="Train", color="red", alpha=0.7)
    plt.plot(losses["val"], label="Validation", color="blue", alpha=0.7)
    plt.xlabel("Epoch")
    plt.ylabel("Negative Log Likelihood")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close()

    logger.info(f"Saved loss plot to {output_path}")


def plot_corner(
    data: np.ndarray,
    flow_samples: np.ndarray,
    output_path: str,
    labels: list[str],
) -> None:
    """Create a corner plot comparing data and flow samples.

    Args:
        data: Original data samples
        flow_samples: Samples from the trained flow
        output_path: Path to save the plot
        labels: Parameter labels for the plot
    """
    try:
        import corner
        import matplotlib.pyplot as plt
    except ImportError:
        logger.warning("corner package not available, skipping corner plot")
        return

    hist_kwargs = {"color": "blue", "density": True}
    fig = corner.corner(
        data,
        labels=labels,
        color="blue",
        bins=40,
        smooth=1.0,
        plot_datapoints=False,
        plot_density=False,
        fill_contours=False,
        levels=[0.68, 0.95],
        alpha=0.6,
        hist_kwargs=hist_kwargs,
    )

    hist_kwargs["color"] = "red"
    corner.corner(
        flow_samples,
        fig=fig,
        color="red",
        bins=40,
        smooth=1.0,
        plot_datapoints=True,  # DO plot them for the flow, to check if it violates bounds
        plot_density=False,
        fill_contours=False,
        levels=[0.68, 0.95],
        alpha=0.6,
        hist_kwargs=hist_kwargs,
    )

    # Add legend
    from matplotlib.lines import Line2D

    legend_elements = [
        Line2D([0], [0], color="blue", lw=2, label="Data"),
        Line2D([0], [0], color="red", lw=2, label="Flow"),
    ]
    fig.legend(handles=legend_elements, loc="upper right", fontsize=12)

    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()

    logger.info(f"Saved corner plot to {output_path}")
def train_flow_from_config(config: FlowTrainingConfig) -> None:
    """
    Train a normalizing flow using a configuration object.

    Args:
        config: FlowTrainingConfig with all training parameters
    """
    # Log configuration
    logger.info("=" * 60)
    logger.info("Normalizing Flow Training")
    logger.info("=" * 60)
    logger.info(f"Posterior file: {config.posterior_file}")
    logger.info(f"Output directory: {config.output_dir}")
    logger.info(f"Parameter names: {config.parameter_names}")
    logger.info(f"Max samples: {config.max_samples}")
    logger.info(f"Flow type: {config.flow_type}")
    logger.info(f"NN depth: {config.nn_depth}")
    logger.info(f"NN block dim: {config.nn_block_dim}")
    logger.info(f"NN width: {config.nn_width}")
    logger.info(f"Flow layers: {config.flow_layers}")
    logger.info(f"Invert: {config.invert}")
    logger.info(f"Cond dim: {config.cond_dim}")
    logger.info(f"Transformer: {config.transformer}")
    logger.info(f"Transformer knots: {config.transformer_knots}")
    logger.info(f"Transformer interval: {config.transformer_interval}")
    logger.info(f"Standardize: {config.standardize}")
    logger.info(f"Standardization method: {config.standardization_method}")
    logger.info(f"Max epochs: {config.num_epochs}")
    logger.info(f"Learning rate: {config.learning_rate}")
    logger.info(f"Patience: {config.max_patience}")
    logger.info(f"Val proportion: {config.val_prop}")
    logger.info(f"Seed: {config.seed}")
    logger.info("=" * 60)

    # Check for GPU
    logger.info(f"JAX devices: {jax.devices()}")

    # Load data
    logger.info("[1/5] Loading posterior samples...")
    data, load_metadata = load_posterior(
        config.posterior_file,
        parameter_names=config.parameter_names,
        max_samples=config.max_samples,
    )
    parameter_names = load_metadata["parameter_names"]
    logger.info(f"Data shape: {data.shape}")
    logger.info(f"Parameters: {parameter_names}")
    logger.info("Original data ranges:")
    for i, name in enumerate(parameter_names):
        logger.info(f"  {name}: [{data[:, i].min():.3f}, {data[:, i].max():.3f}]")

    # Keep a copy of the original data for the corner plot
    original_data = data.copy()

    # Standardize data if requested
    data_statistics = None
    if config.standardize:
        if config.standardization_method == "zscore":
            logger.info("Standardizing data using z-score (mean=0, std=1)...")
            data, data_statistics = standardize_data_zscore(data)
            logger.info("Standardized data statistics:")
            for i, name in enumerate(parameter_names):
                logger.info(
                    f"  {name}: mean={data[:, i].mean():.3f}, std={data[:, i].std():.3f}"
                )
            logger.info("Data mean and std saved for inverse transformation")
        else:  # minmax
            logger.info("Standardizing data using min-max [0, 1] scaling...")
            data, data_statistics = standardize_data_minmax(data)
            logger.info("Standardized data ranges:")
            for i, name in enumerate(parameter_names):
                logger.info(
                    f"  {name}: [{data[:, i].min():.3f}, {data[:, i].max():.3f}]"
                )
            logger.info("Data bounds saved for inverse transformation")

    # Create flow
    logger.info("[2/5] Creating flow architecture...")
    flow_key, train_key, sample_key = jax.random.split(jax.random.key(config.seed), 3)

    dim = data.shape[1]  # Infer dimensionality from data
    logger.info(f"Flow dimensionality: {dim}D")

    flow = create_flow(
        key=flow_key,
        dim=dim,
        flow_type=config.flow_type,
        nn_depth=config.nn_depth,
        nn_block_dim=config.nn_block_dim,
        nn_width=config.nn_width,
        flow_layers=config.flow_layers,
        invert=config.invert,
        cond_dim=config.cond_dim,
        transformer_type=config.transformer,
        transformer_knots=config.transformer_knots,
        transformer_interval=config.transformer_interval,
    )

    # Train flow
    logger.info("[3/5] Training flow...")
    logger.info(f"Training dataset shape: {data.shape}")
    trained_flow, losses = train_flow(
        flow,
        data,
        train_key,
        learning_rate=config.learning_rate,
        max_epochs=config.num_epochs,
        max_patience=config.max_patience,
        val_prop=config.val_prop,
        batch_size=config.batch_size,
    )

    logger.info(f"Final train loss: {losses['train'][-1]:.4f}")
    logger.info(f"Final val loss: {losses['val'][-1]:.4f}")

    # Save model
    logger.info("[4/5] Saving model...")
    flow_kwargs = {
        "flow_type": config.flow_type,
        "nn_depth": config.nn_depth,
        "nn_block_dim": config.nn_block_dim,
        "nn_width": config.nn_width,
        "flow_layers": config.flow_layers,
        "invert": config.invert,
        "cond_dim": config.cond_dim,
        "seed": config.seed,
        "standardize": config.standardize,
        "standardization_method": config.standardization_method,
        "transformer_type": config.transformer,
        "transformer_knots": config.transformer_knots,
        "transformer_interval": config.transformer_interval,
    }

    # Add data statistics if standardization was used
    if config.standardize and data_statistics is not None:
        if config.standardization_method == "zscore":
            flow_kwargs["data_mean"] = data_statistics["mean"].tolist()
            flow_kwargs["data_std"] = data_statistics["std"].tolist()
        else:  # minmax
            flow_kwargs["data_bounds_min"] = data_statistics["min"].tolist()
            flow_kwargs["data_bounds_max"] = data_statistics["max"].tolist()

    metadata = {
        **load_metadata,
        "flow_type": config.flow_type,
        "num_epochs": len(losses["train"]),
        "learning_rate": config.learning_rate,
        "max_patience": config.max_patience,
        "val_prop": config.val_prop,
        "standardize": config.standardize,
        "standardization_method": config.standardization_method,
    }

    # Add data statistics to metadata if standardization was used
    if config.standardize and data_statistics is not None:
        if config.standardization_method == "zscore":
            metadata["data_mean"] = data_statistics["mean"].tolist()
            metadata["data_std"] = data_statistics["std"].tolist()
        else:  # minmax
            metadata["data_bounds_min"] = data_statistics["min"].tolist()
            metadata["data_bounds_max"] = data_statistics["max"].tolist()

    save_model(trained_flow, config.output_dir, flow_kwargs, metadata)

    # Generate plots
    logger.info("[5/5] Generating plots...")

    # Create figures subdirectory
    figures_dir = os.path.join(config.output_dir, "figures")
    os.makedirs(figures_dir, exist_ok=True)

    if config.plot_losses:
        loss_path = os.path.join(figures_dir, "losses.png")
        plot_losses(losses, loss_path)

    if config.plot_corner:
        try:
            # Sample from the trained flow
            n_plot_samples = min(10_000, data.shape[0])
            flow_samples = trained_flow.sample(sample_key, (n_plot_samples,))
            flow_samples_np = np.array(flow_samples)

            # Inverse transform samples if the data was standardized
            if config.standardize and data_statistics is not None:
                if config.standardization_method == "zscore":
                    flow_samples_np = inverse_standardize_data_zscore(
                        flow_samples_np, data_statistics
                    )
                else:  # minmax
                    flow_samples_np = inverse_standardize_data_minmax(
                        flow_samples_np, data_statistics
                    )

            corner_path = os.path.join(figures_dir, "corner.png")
            # Use original_data for the corner plot comparison
            # Update labels based on parameter names
            plot_corner(
                original_data, flow_samples_np, corner_path, labels=parameter_names
            )
        except Exception as e:
            logger.warning(
                f"Corner plot generation failed, skipping. Error: {type(e).__name__}"
            )

    logger.info("=" * 60)
    logger.info("Training complete!")
    logger.info(f"Model saved to: {config.output_dir}")
    logger.info(f"Figures saved to: {os.path.join(config.output_dir, 'figures')}")
    logger.info("=" * 60)
    logger.info("To use the trained flow:")
    logger.info(">>> from jesterTOV.inference.flows.flow import Flow")
    logger.info(f">>> flow = Flow.from_directory('{config.output_dir}')")
    logger.info(">>> samples = flow.sample(jax.random.key(0), (1000,))")
    if config.standardize:
        logger.info(">>> # Samples are automatically rescaled to original domain")
    logger.info("=" * 60)
def main():
    """Main entry point for the training script."""
    if len(sys.argv) < 2:
        logger.error(
            "Usage: python -m jesterTOV.inference.flows.train_flow <config.yaml>"
        )
        sys.exit(1)

    config_path = Path(sys.argv[1])

    # Load config from YAML
    config = FlowTrainingConfig.from_yaml(config_path)

    # Train flow
    train_flow_from_config(config)


if __name__ == "__main__":
    main()