jesterTOV.inference.flows.train_flow.train_flow
train_flow(flow, data, key, learning_rate=0.001, max_epochs=600, max_patience=50, val_prop=0.2, batch_size=128)
Train the normalizing flow on data.
- Parameters:
  - flow (Any) – Untrained flowjax flow
  - data (ndarray) – Training data of shape (n_samples, n_dims)
  - key (Array) – JAX random key
  - learning_rate (float) – Learning rate for optimizer
  - max_epochs (int) – Maximum number of epochs
  - max_patience (int) – Early stopping patience
  - val_prop (float) – Proportion of data to use for validation
  - batch_size (int) – Batch size for training
- Return type:
- Returns:
  - trained_flow – Trained flow model
  - losses – Dictionary with ‘train’ and ‘val’ loss arrays
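The sketch below illustrates how val_prop, batch_size, max_epochs, and max_patience typically interact in a training loop with early stopping. It is not the implementation of train_flow: the helper names (split_train_val, batches_per_epoch, early_stopping_epoch) are hypothetical, and the exact splitting and stopping conventions used by the library are assumptions here. It uses only the Python standard library so it can be run without jesterTOV or flowjax installed.

```python
import random


def split_train_val(n_samples, val_prop, seed=0):
    """Shuffle sample indices and hold out a val_prop fraction for validation.

    Hypothetical helper: the library may split differently.
    """
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_val = int(round(val_prop * n_samples))
    n_train = n_samples - n_val
    return indices[:n_train], indices[n_train:]


def batches_per_epoch(n_train, batch_size):
    """Mini-batches needed to cover the training set once (last one may be partial)."""
    return -(-n_train // batch_size)  # ceiling division


def early_stopping_epoch(val_losses, max_epochs, max_patience):
    """Replay a sequence of validation losses under a patience rule.

    Assumed convention: stop once max_patience epochs have passed since the
    best validation loss, or after max_epochs, whichever comes first.
    Returns (epochs_run, best_epoch_index).
    """
    best_loss = float("inf")
    best_epoch = -1
    epochs_run = 0
    for epoch, loss in enumerate(val_losses[:max_epochs]):
        epochs_run = epoch + 1
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= max_patience:
            break
    return epochs_run, best_epoch


# With the documented defaults val_prop=0.2 and batch_size=128 on 1000 samples:
train_idx, val_idx = split_train_val(1000, val_prop=0.2)
n_batches = batches_per_epoch(len(train_idx), batch_size=128)
print(len(train_idx), len(val_idx), n_batches)  # 800 200 7
```

Under these assumptions, a larger max_patience tolerates longer plateaus in the ‘val’ loss before stopping, while max_epochs remains a hard upper bound on training length.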