"""Pydantic configuration schema for normalizing flow training.
This module provides type-safe configuration for training normalizing flows
on posterior samples (GW, NICER, EOS parameters, etc.), replacing the argparse
interface in train_flow.py.
"""
from pathlib import Path
from typing import Literal
import yaml
from pydantic import BaseModel, field_validator
class FlowTrainingConfig(BaseModel):
"""Configuration for training normalizing flows on posterior samples.
Attributes
----------
posterior_file : str
Path to .npz file with posterior samples
output_dir : str
Directory to save model weights, kwargs, and plots
parameter_names : list[str]
List of parameter names to extract from the posterior file.
Examples: GW parameters ["mass_1_source", "mass_2_source",
"lambda_1", "lambda_2"]; NICER parameters ["mass", "radius"]
num_epochs : int
Number of training epochs (default: 600)
learning_rate : float
Learning rate for training (default: 1e-3)
max_patience : int
Early stopping patience (default: 50)
nn_depth : int
Depth of neural network blocks (default: 5)
nn_block_dim : int
Dimension of neural network blocks (default: 8)
flow_layers : int
Number of flow layers (default: 1)
invert : bool
Whether to invert the flow (default: True)
cond_dim : int | None
Conditional dimension for conditional flows (default: None)
max_samples : int
Maximum number of samples to use for training (default: 50,000)
seed : int
Random seed for reproducibility (default: 0)
plot_corner : bool
Generate corner plot comparison (default: True)
plot_losses : bool
Plot training and validation losses (default: True)
flow_type : Literal["block_neural_autoregressive_flow", "masked_autoregressive_flow", "coupling_flow"]
Type of normalizing flow to use (default: masked_autoregressive_flow)
nn_width : int
Width of neural network hidden layers (default: 50)
standardize : bool
Whether to standardize input data (default: True, changed from False)
standardization_method : Literal["zscore", "minmax"]
Method for standardizing input data (default: zscore).
- "zscore": Standardize to mean=0, std=1 (recommended for most cases)
- "minmax": Standardize to [0, 1] range (legacy, for backward compatibility)
Only used if standardize=True.
transformer : Literal["affine", "rational_quadratic_spline"]
Transformer type for masked_autoregressive_flow and coupling_flow
(default: rational_quadratic_spline, changed from affine)
transformer_knots : int
Number of knots for RationalQuadraticSpline transformer (default: 10, changed from 8)
transformer_interval : float
Interval for RationalQuadraticSpline transformer (default: 5.0, changed from 4.0)
val_prop : float
Proportion of data to use for validation (default: 0.2)
batch_size : int
Batch size for training (default: 128)
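
Examples
--------
A minimal programmatic construction; the file paths below are
placeholders, and every omitted field falls back to the defaults
listed above:

>>> config = FlowTrainingConfig(
...     posterior_file="posterior_samples.npz",
...     output_dir="flow_output",
...     parameter_names=["mass", "radius"],
... )
>>> config.flow_type
'masked_autoregressive_flow'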
"""
posterior_file: str
output_dir: str
parameter_names: list[str]
num_epochs: int = 600
learning_rate: float = 1e-3
max_patience: int = 50
nn_depth: int = 5
nn_block_dim: int = 8
flow_layers: int = 1
invert: bool = True
cond_dim: int | None = None
max_samples: int = 50_000
seed: int = 0
plot_corner: bool = True
plot_losses: bool = True
flow_type: Literal[
"block_neural_autoregressive_flow",
"masked_autoregressive_flow",
"coupling_flow",
] = "masked_autoregressive_flow"
nn_width: int = 50
standardize: bool = True
standardization_method: Literal["zscore", "minmax"] = "zscore"
transformer: Literal["affine", "rational_quadratic_spline"] = (
"rational_quadratic_spline"
)
transformer_knots: int = 10
transformer_interval: float = 5.0
val_prop: float = 0.2
batch_size: int = 128
@field_validator(
"num_epochs",
"max_patience",
"nn_depth",
"nn_block_dim",
"flow_layers",
"max_samples",
"nn_width",
"transformer_knots",
"batch_size",
)
@classmethod
def validate_positive_int(cls, v: int) -> int:
"""Validate that integer value is positive."""
if v <= 0:
raise ValueError(f"Value must be positive, got: {v}")
return v
@field_validator("learning_rate", "val_prop", "transformer_interval")
@classmethod
def validate_positive_float(cls, v: float) -> float:
"""Validate that float value is positive."""
if v <= 0:
raise ValueError(f"Value must be positive, got: {v}")
return v
@field_validator("val_prop")
@classmethod
def validate_val_prop_range(cls, v: float) -> float:
"""Validate that validation proportion is in (0, 1)."""
if v <= 0 or v >= 1:
raise ValueError(f"val_prop must be in (0, 1), got: {v}")
return v
@field_validator("parameter_names")
@classmethod
def validate_parameter_names(cls, v: list[str]) -> list[str]:
"""Validate that parameter_names is a non-empty list."""
if len(v) == 0:
raise ValueError("parameter_names cannot be an empty list.")
return v
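
# Note: the validators above run at construction time, so invalid values
# fail fast with a pydantic.ValidationError. A sketch (field values are
# placeholders):
#
#     FlowTrainingConfig(
#         posterior_file="p.npz",
#         output_dir="out",
#         parameter_names=["mass"],
#         val_prop=1.5,  # rejected: val_prop must be in (0, 1)
#     )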
@classmethod
def from_yaml(cls, filepath: str | Path) -> "FlowTrainingConfig":
"""
Load configuration from a YAML file.
Args:
filepath: Path to YAML configuration file
Returns:
FlowTrainingConfig instance with loaded configuration
Example:
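With a YAML file such as the following (paths and values are
illustrative; the keys mirror the attribute names documented on the
class)::

    posterior_file: posterior_samples.npz
    output_dir: flow_output
    parameter_names: [mass, radius]
    num_epochs: 600
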
>>> config = FlowTrainingConfig.from_yaml("config.yaml")
"""
with open(filepath, "r") as f:
config_dict = yaml.safe_load(f)
return cls(**config_dict)