"""Method to train the surrogate models"""
import dill
import os
import pickle
import jax
from jaxtyping import Array, Float, Int
import numpy as np
import matplotlib.pyplot as plt
import fiesta.train.neuralnets as fiesta_nn
from fiesta.train.DataManager import DataManager
from fiesta.logging import logger
################
# TRAINING API #
################
[docs]
class FluxTrainer:
    """Abstract class for training a surrogate model that predicts a spectral flux density array.

    Child classes implement ``preprocess`` and ``fit``. ``save`` additionally reads the
    attributes ``times``, ``nus``, ``parameter_distributions``, ``X_scaler``, ``y_scaler``,
    ``model_type`` and ``network``, and ``plot_example_lc`` reads ``data_manager``; these
    must be set by the child class before those methods are called.
    """

    name: str
    outdir: str
    parameter_names: list[str]
    # Inputs are indexed as (sample, parameter) by the training code; targets keep the
    # sample axis first.
    train_X: Float[Array, "n_train n_params"]
    train_y: Float[Array, "n_train ..."]
    val_X: Float[Array, "n_val n_params"]
    val_y: Float[Array, "n_val ..."]

    def __init__(self,
                 name: str,
                 outdir: str,
                 plots_dir: str = None,
                 save_preprocessed_data: bool = False) -> None:
        """
        Args:
            name (str): Name of the model; used in every file name written by the trainer.
            outdir (str): Directory for the trained network and its metadata. Created if missing.
            plots_dir (str, optional): Directory for diagnostic plots. Created if given and missing.
                Defaults to None.
            save_preprocessed_data (bool, optional): Whether child classes should dump the
                preprocessed train/validation arrays to ``outdir``. Defaults to False.
        """
        self.name = name

        # Create the output directories if they do not exist yet (exist_ok avoids the
        # check-then-create race of os.path.exists + os.makedirs).
        self.outdir = outdir
        os.makedirs(self.outdir, exist_ok=True)
        self.plots_dir = plots_dir
        if self.plots_dir is not None:
            os.makedirs(self.plots_dir, exist_ok=True)

        self.save_preprocessed_data = save_preprocessed_data

        # To be loaded by child classes
        self.parameter_names = None
        self.train_X = None
        self.train_y = None
        self.val_X = None
        self.val_y = None

    def __repr__(self) -> str:
        # Use the concrete class name so subclasses are reported correctly.
        return f"{type(self).__name__}(name={self.name})"

    def preprocess(self):
        """Prepare train/validation arrays; must be implemented by child classes."""
        raise NotImplementedError

    def fit(self,
            config: fiesta_nn.NeuralnetConfig = None,
            key: jax.random.PRNGKey = jax.random.PRNGKey(0),
            verbose: bool = True) -> None:
        """Build and train the network; must be implemented by child classes."""
        raise NotImplementedError

    def plot_learning_curve(self, train_losses, val_losses) -> None:
        """Plot train and validation loss per epoch (log y-axis) and save it as a PNG.

        The figure is written to ``<plots_dir>/learning_curves_<name>.png``. If
        ``plots_dir`` is None, the current working directory is used instead. This is
        the same fallback as ``plot_example_lc``.

        Args:
            train_losses: Per-epoch training losses.
            val_losses: Per-epoch validation losses, same length as ``train_losses``.
        """
        plots_dir = self.plots_dir if self.plots_dir is not None else "."

        fig, ax = plt.subplots(figsize=(10, 5))
        epochs = np.arange(1, len(train_losses) + 1)
        ax.plot(epochs, train_losses, "-", lw=1.0, label="Train", color="red")
        ax.plot(epochs, val_losses, "-", lw=1.0, label="Validation", color="blue")

        # Mark the epoch with the lowest validation loss. Skipped for empty input,
        # where np.argmin would raise.
        if len(val_losses) > 0:
            best_idx = int(np.argmin(val_losses))
            best_epoch = best_idx + 1  # epochs are 1-based on the x-axis
            ax.axvline(best_epoch, color="blue", ls="--", alpha=0.4, lw=0.8)
            ax.annotate(f"best val @ {best_epoch}", xy=(best_epoch, val_losses[best_idx]),
                        fontsize=8, color="blue", alpha=0.7,
                        xytext=(10, 10), textcoords="offset points")

        ax.legend(fontsize=11)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.set_yscale("log")
        ax.set_title("Learning curves")
        ax.grid(True, alpha=0.3)
        fig.savefig(os.path.join(plots_dir, f"learning_curves_{self.name}.png"),
                    bbox_inches="tight", dpi=150)
        plt.close(fig)

    def plot_example_lc(self, lc_model) -> None:
        """Compare base-model and surrogate light curves for one stored sample.

        One PNG per filter in ``lc_model.Filters`` is written to
        ``<plots_dir>/<name>_<filter>_example.png``. If ``plots_dir`` is None, the
        current working directory is used. The surrogate curve is drawn with a shaded
        +/-1 mag band.

        Args:
            lc_model: Light curve model exposing ``predict_abs_mag(params)`` (returning a
                pair whose second item maps filter name -> magnitudes), ``Filters`` and ``times``.
        """
        # Resolve the fallback locally. The previous code assigned "." to
        # self.plots_dir, which silently changed later checks of
        # `self.plots_dir is not None` (e.g. learning-curve plotting in fit()).
        plots_dir = self.plots_dir if self.plots_dir is not None else "."

        _, _, X, y = self.data_manager.load_raw_data_from_file(0, 1)  # loads validation data
        # Undo the log10 on the stored fluxes (naming suggests mJy) and arrange them as
        # a (n_nus, n_times) grid.
        n_nus = len(self.data_manager.nus)
        n_times = len(self.data_manager.times)
        mJys_val = np.power(10, y.reshape(n_nus, n_times))
        params = dict(zip(self.parameter_names, X.flatten()))

        _, mag_predict = lc_model.predict_abs_mag(params)
        mag_val = {filt.name: filt.get_mag(mJys_val, self.data_manager.nus) for filt in lc_model.Filters}

        for filt in lc_model.Filters:
            predicted = mag_predict[filt.name]

            fig, ax = plt.subplots()
            ax.plot(lc_model.times, mag_val[filt.name], color="red", label="Base model")
            ax.plot(lc_model.times, predicted, color="blue", label="Surrogate prediction")
            # Shade a +/-1 mag band around the surrogate prediction.
            ax.fill_between(lc_model.times, predicted - 1, predicted + 1, color="blue", alpha=0.2)
            ax.set_ylabel(f"mag for {filt.name}")
            ax.set_xlabel("$t$ in days")
            ax.legend()
            ax.invert_yaxis()  # magnitudes: smaller values are brighter
            ax.set_xscale("log")
            ax.set_xlim(lc_model.times[0], lc_model.times[-1])
            fig.savefig(os.path.join(plots_dir, f"{self.name}_{filt.name}_example.png"), bbox_inches="tight")
            plt.close(fig)

    def save(self) -> None:
        """
        Save the trained model and all the metadata to the outdir.
        The metadata is saved as a dill-pickled dict (``<name>_metadata.pkl``) to be read
        by fiesta.inference.lightcurve_model.SurrogateLightcurveModel.
        The NN is saved to ``<name>.pkl`` using the network's save_model method.
        """
        metadata = {
            "times": self.times,
            "nus": self.nus,
            "parameter_names": self.parameter_names,
            "parameter_distributions": self.parameter_distributions,
            "X_scaler": self.X_scaler,
            "y_scaler": self.y_scaler,
            "model_type": self.model_type,
        }
        meta_filename = os.path.join(self.outdir, f"{self.name}_metadata.pkl")
        with open(meta_filename, "wb") as meta_file:
            dill.dump(metadata, meta_file)

        # Save the NN
        self.network.save_model(outfile=os.path.join(self.outdir, f"{self.name}.pkl"))

    def _save_preprocessed_data(self) -> None:
        """Write train/validation inputs and targets to ``<outdir>/<name>_preprocessed_data.npz``."""
        logger.info("Saving preprocessed data . . .")
        np.savez(os.path.join(self.outdir, f"{self.name}_preprocessed_data.npz"),
                 train_X=self.train_X, train_y=self.train_y, val_X=self.val_X, val_y=self.val_y)
        logger.info("Saving preprocessed data . . . done")
[docs]
class PCATrainer(FluxTrainer):
    """FluxTrainer that fits a feed-forward network (MLP) to PCA coefficients of the training fluxes."""

    def __init__(self,
                 name: str,
                 outdir: str,
                 data_manager_args: dict,
                 n_pca: int = 100,
                 conversion: str = None,
                 plots_dir: str = None,
                 save_preprocessed_data: bool = False) -> None:
        """
        FluxTrainer for training a feed-forward neural network on the PCA coefficients of the training data to predict the full 2D spectral flux density array.
        Initializing will read the data and preprocess it with the DataManager class. It can then be fit with the fit() method.
        To write the surrogate model to file, the save() method is to be used, which will create two pickle files (one for the metadata, one for the neural network).

        Args:
            name (str): Name of the model to be trained. Will be used when saving metadata and model to file.
            outdir (str): Directory where the NN and its metadata will be written to file.
            data_manager_args (dict): Keyword arguments for the DataManager instance that reads and preprocesses the training data.
            n_pca (int): Number of PCA components that will be kept when performing data preprocessing. Defaults to 100.
            conversion (str): References how to convert the parameters for the training. Defaults to None, in which case it's the identity.
            plots_dir (str): Directory where the loss curves will be plotted. If None, the plot will not be created. Defaults to None.
            save_preprocessed_data (bool): Whether the preprocessed (i.e. PCA decomposed) training and validation data will be written to file. Defaults to False.
        """
        super().__init__(name=name,
                         outdir=outdir,
                         plots_dir=plots_dir,
                         save_preprocessed_data=save_preprocessed_data)

        self.model_type = "MLP"
        self.n_pca = n_pca
        self.conversion = conversion

        self.data_manager = DataManager(**data_manager_args)
        self.data_manager.print_file_info()
        # Presumably copies the metadata needed by save() (times, nus, parameter names, ...)
        # onto this trainer — TODO confirm against DataManager.pass_meta_data.
        self.data_manager.pass_meta_data(self)

    def preprocess(self) -> None:
        """
        Compute the PCA coefficients of the standardized training data.

        Assigns self.train_X, self.train_y, self.val_X, self.val_y (passed to the fitting
        method) as well as self.X_scaler and self.y_scaler.

        Raises:
            ValueError: If the preprocessed targets contain NaNs.
        """
        logger.info(f"Preprocessing data by decomposing data into {self.n_pca} components.")
        (self.train_X, self.train_y, self.val_X, self.val_y,
         self.X_scaler, self.y_scaler) = self.data_manager.preprocess_pca(self.n_pca, self.conversion)

        if np.any(np.isnan(self.train_y)) or np.any(np.isnan(self.val_y)):
            raise ValueError("Data preprocessing introduced nans. Check raw data for nans or infs, or for vanishing variance in a specific entry.")

        explained_share = np.sum(self.y_scaler.scalers[0].explained_variance_ratio_)
        logger.info(f"PCA decomposition accounts for a share {explained_share} of the total variance in the training data. This value is hopefully close to 1.")
        logger.info("Preprocessing data . . . done")

    def fit(self,
            config: fiesta_nn.NeuralnetConfig,
            key: jax.random.PRNGKey = jax.random.PRNGKey(0),
            verbose: bool = True) -> None:
        """
        Initialize an MLP from config and train it with the learning rate and epoch number specified in config.
        The config controls which architecture is built through config.hidden_layers.

        Args:
            config (fiesta.train.neuralnets.NeuralnetConfig): Config that specifies at least hidden_layers, learning rate, and learning epochs. Its output_size is overwritten with n_pca.
            key (jax.random.PRNGKey, optional): PRNG key used to initialize the parameters of the network. Defaults to jax.random.PRNGKey(0).
            verbose (bool, optional): Whether the train and validation loss is printed to terminal in certain intervals. Defaults to True.
        """
        self.preprocess()
        if self.save_preprocessed_data:
            self._save_preprocessed_data()

        # The network predicts PCA coefficients, so the output size must equal n_pca.
        # Note that this mutates the caller's config object.
        self.config = config
        self.config.output_size = self.n_pca
        input_ndim = self.train_X.shape[1]

        self.network = fiesta_nn.MLP(config=config, input_ndim=input_ndim, key=key)
        # NOTE(review): the training state returned by train_loop is discarded; this
        # assumes the network keeps its trained parameters internally — confirm.
        _, train_losses, val_losses = self.network.train_loop(self.train_X, self.train_y, self.val_X, self.val_y, verbose=verbose)

        if self.plots_dir is not None:
            self.plot_learning_curve(train_losses, val_losses)
[docs]
class CVAETrainer(FluxTrainer):
    """FluxTrainer that fits a conditional variational autoencoder to down-sampled, standardized log fluxes."""

    def __init__(self,
                 name: str,
                 outdir: str,
                 data_manager_args: dict,
                 image_size: tuple[int, ...],
                 conversion: str = None,
                 plots_dir: str = None,
                 save_preprocessed_data: bool = False) -> None:
        """
        FluxTrainer for training a conditional variational autoencoder on the log fluxes of the training data to predict the full 2D spectral flux density array.
        Initializing will read the data and preprocess it with the DataManager class. It can then be fit with the fit() method.
        To write the surrogate model to file, the save() method is to be used, which will create two pickle files (one for the metadata, one for the neural network).

        Args:
            name (str): Name of the model to be trained. Will be used when saving metadata and model to file.
            outdir (str): Directory where the NN and its metadata will be written to file.
            data_manager_args (dict): Keyword arguments for the DataManager instance that reads and preprocesses the training data.
            image_size (tuple[int, ...]): Size the 2D flux array will be down-sampled to when performing data preprocessing.
            conversion (str): References how to convert the parameters for the training. Defaults to None, in which case it's the identity.
            plots_dir (str): Directory where the loss curves will be plotted. If None, the plot will not be created. Defaults to None.
            save_preprocessed_data (bool): Whether the preprocessed (i.e. down-sampled and standardized) training and validation data will be written to file. Defaults to False.
        """
        super().__init__(name=name,
                         outdir=outdir,
                         plots_dir=plots_dir,
                         save_preprocessed_data=save_preprocessed_data)

        self.model_type = "CVAE"

        self.data_manager = DataManager(**data_manager_args)
        self.data_manager.print_file_info()
        # Presumably copies the metadata needed by save() (times, nus, parameter names, ...)
        # onto this trainer — TODO confirm against DataManager.pass_meta_data.
        self.data_manager.pass_meta_data(self)

        self.image_size = image_size
        self.conversion = conversion

    def preprocess(self) -> None:
        """
        Down-sample and standardize the training data.

        Assigns self.train_X, self.train_y, self.val_X, self.val_y (passed to the fitting
        method) as well as self.X_scaler and self.y_scaler.

        Raises:
            ValueError: If the preprocessed targets contain NaNs.
        """
        logger.info(f"Preprocessing data by resampling flux array to {self.image_size} and standardizing.")
        (self.train_X, self.train_y, self.val_X, self.val_y,
         self.X_scaler, self.y_scaler) = self.data_manager.preprocess_cVAE(self.image_size, self.conversion)

        if np.any(np.isnan(self.train_y)) or np.any(np.isnan(self.val_y)):
            raise ValueError("Data preprocessing introduced nans. Check raw data for nans or infs, or for vanishing variance in a specific entry.")
        logger.info("Preprocessing data . . . done")

    def fit(self,
            config: fiesta_nn.NeuralnetConfig,
            key: jax.random.PRNGKey = jax.random.PRNGKey(0),
            verbose: bool = True) -> None:
        """
        Initialize the autoencoder from config and train it with the learning rate and epoch number specified in config.
        The config controls which architecture is built through config.hidden_layers. The encoder and decoder share the hidden_layers argument, though the layers for the decoder are implemented in reverse order.

        Args:
            config (fiesta.train.neuralnets.NeuralnetConfig): Config that specifies at least hidden_layers, learning rate, and learning epochs. Its output_size is overwritten with the product of self.image_size.
            key (jax.random.PRNGKey, optional): PRNG key used to initialize the parameters of the network. Defaults to jax.random.PRNGKey(0).
            verbose (bool, optional): Whether the train and validation loss is printed to terminal in certain intervals. Defaults to True.
        """
        self.preprocess()
        if self.save_preprocessed_data:
            self._save_preprocessed_data()

        # The network outputs the flattened down-sampled image, so output_size must be the
        # number of pixels. Note that this mutates the caller's config object.
        self.config = config
        self.config.output_size = int(np.prod(self.image_size))

        self.network = fiesta_nn.CVAE(config=self.config, conditional_dim=self.train_X.shape[1], key=key)
        # NOTE(review): the training state returned by train_loop is discarded; this
        # assumes the network keeps its trained parameters internally — confirm.
        _, train_losses, val_losses = self.network.train_loop(self.train_X, self.train_y, self.val_X, self.val_y, verbose=verbose)

        if self.plots_dir is not None:
            self.plot_learning_curve(train_losses, val_losses)