jesterTOV.inference.flows.train_flow.train_flow

train_flow(
    flow,
    data,
    key,
    learning_rate=0.001,
    max_epochs=600,
    max_patience=50,
    val_prop=0.2,
    batch_size=128,
)

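Train a normalizing flow on data. A fraction val_prop of the samples is held out for validation. Training runs for at most max_epochs epochs and stops early once the validation loss has not improved for max_patience epochs.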

Parameters:
  • flow (Any) – Untrained flowjax flow

  • data (ndarray) – Training data of shape (n_samples, n_dims)

  • key (Array) – JAX random key

  • learning_rate (float) – Learning rate for the optimizer

  • max_epochs (int) – Maximum number of training epochs

  • max_patience (int) – Early-stopping patience: number of epochs without improvement in the validation loss before training stops

  • val_prop (float) – Proportion of data to use for validation

  • batch_size (int) – Batch size for training
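To illustrate how these parameters typically interact, here is a minimal, self-contained sketch. It is not the jesterTOV implementation. It assumes the flow is fitted by minimizing the mean negative log-likelihood (the usual objective for fitting a flow to samples), and it makes several stand-in choices: a diagonal Gaussian (an elementwise affine map of a standard normal, the simplest possible "flow") replaces the flowjax flow, plain minibatch gradient descent replaces whatever optimizer train_flow uses, and a NumPy `seed` replaces the JAX `key`. The function name `train_gaussian_sketch`, the exact patience rule, and returning the best-validation parameters are all choices made for this illustration.

```python
import numpy as np


def train_gaussian_sketch(data, seed=0, learning_rate=0.001, max_epochs=600,
                          max_patience=50, val_prop=0.2, batch_size=128):
    """Hypothetical toy analogue of train_flow: fit x = mu + exp(log_sigma) * z,
    z ~ N(0, I), by minibatch gradient descent on the mean negative
    log-likelihood, with a validation split and early stopping."""
    rng = np.random.default_rng(seed)  # stands in for the JAX random key
    data = np.asarray(data, dtype=float)
    n, d = data.shape

    # Shuffle once, then hold out a fraction val_prop for validation.
    perm = rng.permutation(n)
    n_val = int(round(val_prop * n))
    val, train = data[perm[:n_val]], data[perm[n_val:]]

    mu = np.zeros(d)
    log_sigma = np.zeros(d)

    def nll(x, mu, log_sigma):
        # Mean negative log-likelihood of a diagonal Gaussian.
        z = (x - mu) * np.exp(-log_sigma)
        per_sample = (0.5 * np.sum(z**2, axis=1) + np.sum(log_sigma)
                      + 0.5 * d * np.log(2 * np.pi))
        return per_sample.mean()

    losses = {"train": [], "val": []}
    best_val = np.inf
    best = (mu.copy(), log_sigma.copy())
    fruitless = 0  # consecutive epochs without a new best validation loss

    for epoch in range(max_epochs):
        order = rng.permutation(len(train))  # reshuffle every epoch
        batch_losses = []
        for start in range(0, len(train), batch_size):
            x = train[order[start:start + batch_size]]
            z = (x - mu) * np.exp(-log_sigma)
            # Analytic gradients of the mean NLL w.r.t. mu and log_sigma.
            grad_mu = np.mean(-z * np.exp(-log_sigma), axis=0)
            grad_log_sigma = np.mean(1.0 - z**2, axis=0)
            batch_losses.append(nll(x, mu, log_sigma))
            mu = mu - learning_rate * grad_mu
            log_sigma = log_sigma - learning_rate * grad_log_sigma

        # One train and one validation loss value per epoch.
        losses["train"].append(float(np.mean(batch_losses)))
        val_loss = nll(val, mu, log_sigma)
        losses["val"].append(float(val_loss))

        # Early stopping: reset on improvement, stop once the number of
        # fruitless epochs exceeds max_patience (this sketch's convention).
        if val_loss < best_val:
            best_val = val_loss
            best = (mu.copy(), log_sigma.copy())
            fruitless = 0
        else:
            fruitless += 1
            if fruitless > max_patience:
                break

    # Return the parameters with the lowest validation loss.
    return best, losses
```

Because plain gradient descent converges slowly at the default learning rate of 0.001, a larger value such as `learning_rate=0.05` is a better fit for this toy model; train_flow's own optimizer may behave quite differently.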

Return type:

Tuple[Any, Dict[str, list]]

Returns:

  • trained_flow – Trained flow model

  • losses – Dictionary with ‘train’ and ‘val’ keys, each holding a list of loss values
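The returned losses dictionary can be used to check how training went. The helper below is a hypothetical sketch, not part of jesterTOV; it assumes each list holds one loss value per epoch and uses the max_epochs value passed to train_flow to infer whether training ended before the epoch cap. The history in the usage lines is made up for illustration.

```python
def summarize_losses(losses, max_epochs):
    """Summarize a history of the form {"train": [...], "val": [...]},
    assuming one loss value per epoch in each list (hypothetical helper)."""
    val = losses["val"]
    # 0-based epoch with the lowest validation loss.
    best_epoch = min(range(len(val)), key=val.__getitem__)
    return {
        "epochs_run": len(val),
        "best_epoch": best_epoch,
        "best_val": val[best_epoch],
        # Fewer epochs than the cap suggests the patience criterion ended training.
        "stopped_early": len(val) < max_epochs,
        # Epochs trained after the best validation loss was reached.
        "epochs_since_best": len(val) - 1 - best_epoch,
    }


# Made-up history, for illustration only.
history = {"train": [3.0, 2.5, 2.2, 2.1, 2.05], "val": [3.1, 2.6, 2.4, 2.45, 2.5]}
summary = summarize_losses(history, max_epochs=600)
print(summary["best_epoch"], summary["stopped_early"])  # 2 True
```

A training loss that keeps falling while the validation loss rises after best_epoch, as in the made-up history, is the usual sign of overfitting that early stopping is meant to catch.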