Source code for fiesta.train.LightcurveTrainer

"""Method to train the surrogate models"""

import dill
import os
import pickle
from typing import Callable, Dict

import numpy as np
import matplotlib.pyplot as plt

import jax
from jaxtyping import Array, Float, Int

from fiesta.filters import Filter
from fiesta.train.DataManager import DataManager
from fiesta.scalers import MinMaxScalerJax
import fiesta.train.neuralnets as fiesta_nn
from fiesta.logging import logger

################
# TRAINING API #
################

[docs] class LightcurveTrainer: """Abstract class for training a collection of surrogate models per filter""" name: str outdir: str filters: list[Filter] parameter_names: list[str] train_X: Float[Array, "n_train"] train_y: Dict[str, Float[Array, "n"]] val_X: Float[Array, "n_val"] val_y: Dict[str, Float[Array, "n"]] def __init__(self, name: str, outdir: str, plots_dir: str = None, save_preprocessed_data: bool = False) -> None: self.name = name # Check if directories exists, otherwise, create: self.outdir = outdir if not os.path.exists(self.outdir): os.makedirs(self.outdir) self.plots_dir = plots_dir if self.plots_dir is not None and not os.path.exists(self.plots_dir): os.makedirs(self.plots_dir) self.save_preprocessed_data = save_preprocessed_data # To be loaded by child classes self.filters = None self.parameter_names = None self.train_X = None self.train_y = None self.val_X = None self.val_y = None def __repr__(self) -> str: return f"LightcurveTrainer(name={self.name})"
[docs] def preprocess(self): logger.info("Preprocessing data by minmax scaling . . .") self.X_scaler = MinMaxScalerJax() self.X = self.X_scaler.fit_transform(self.X_raw) self.y_scaler: dict[str, MinMaxScalerJax] = {} self.y = {} for filt in self.filters: y_scaler = MinMaxScalerJax() self.y[filt.name] = y_scaler.fit_transform(self.y_raw[filt.name]) self.y_scaler[filt.name] = y_scaler logger.info("Preprocessing data . . . done")
[docs] def fit(self, config: fiesta_nn.NeuralnetConfig, key: jax.random.PRNGKey = jax.random.PRNGKey(0), verbose: bool = True) -> None: """ The config controls which architecture is built and therefore should not be specified here. Args: config (nn.NeuralnetConfig, optional): _description_. Defaults to None. """ self.preprocess() if self.save_preprocessed_data: self._save_preprocessed_data() self.config = config self.models = {} input_ndim = len(self.parameter_names) for filt in self.filters: logger.info("\n \n") logger.info(f"Training {filt.name}...") logger.info(f"----------------------------------\n") # Create neural network and initialize the state net = fiesta_nn.MLP(config = config, input_ndim = input_ndim, key = key) # Perform training loop state, train_losses, val_losses = net.train_loop(self.train_X, self.train_y[filt.name], self.val_X, self.val_y[filt.name], verbose=verbose) self.models[filt.name] = net # Plot and save the plot if so desired if self.plots_dir is not None: plt.figure(figsize=(10, 5)) ls = "-o" ms = 3 plt.plot([i+1 for i in range(len(train_losses))], train_losses, ls, markersize=ms, label="Train", color="red") plt.plot([i+1 for i in range(len(val_losses))], val_losses, ls, markersize=ms, label="Validation", color="blue") plt.legend() plt.xlabel("Epoch") plt.ylabel("MSE loss") plt.yscale('log') plt.title("Learning curves") plt.savefig(os.path.join(self.plots_dir, f"learning_curves_{filt.name}.png")) plt.close()
[docs] def plot_example_lc(self, lc_model): _, _, X, y = self.data_manager.load_raw_data_from_file(0,1) # loads validation data y = y.reshape(len(self.data_manager.nus), len(self.data_manager.times)) mJys_val = np.exp(y) params = dict(zip(self.parameter_names, X.flatten() )) _, mag_predict = lc_model.predict_abs_mag(params) mag_val = {Filt.name: Filt.get_mag(mJys_val, self.data_manager.nus) for Filt in lc_model.Filters} for filt in lc_model.Filters: plt.plot(lc_model.times, mag_val[filt.name], color = "red", label="Base model") plt.plot(lc_model.times, mag_predict[filt.name], color = "blue", label="Surrogate prediction") upper_bound = mag_predict[filt.name] + 1 lower_bound = mag_predict[filt.name] - 1 plt.fill_between(lc_model.times, lower_bound, upper_bound, color='blue', alpha=0.2) plt.ylabel(f"mag for {filt.name}") plt.xlabel("$t$ in days") plt.legend() plt.gca().invert_yaxis() plt.xscale('log') plt.xlim(lc_model.times[0], lc_model.times[-1]) if self.plots_dir is None: self.plots_dir = "." plt.savefig(os.path.join(self.plots_dir, f"{self.name}_{filt.name}_example.png")) plt.close()
[docs] def save(self): """ Save the trained model and all the used metadata to the outdir. """ # Save the metadata meta_filename = os.path.join(self.outdir, f"{self.name}_metadata.pkl") save = {} save["times"] = self.times save["parameter_names"] = self.parameter_names save["parameter_distributions"] = self.parameter_distributions save["X_scaler"] = self.X_scaler save["y_scaler"] = self.y_scaler save["model_type"] = "MLP" with open(meta_filename, "wb") as meta_file: dill.dump(save, meta_file) # Save the NN for filt in self.filters: model = self.models[filt.name] model.save_model(outfile = os.path.join(self.outdir, f"{self.name}_{filt.name}.pkl"))
def _save_preprocessed_data(self) -> None: logger.info("Saving preprocessed data . . .") np.savez(os.path.join(self.outdir, f"{self.name}_preprocessed_data.npz"), train_X=self.train_X, train_y = self.train_y, val_X = self.val_X, val_y = self.val_y) logger.info("Saving preprocessed data . . . done")
[docs] class SVDTrainer(LightcurveTrainer): def __init__(self, name: str, outdir: str, filters: list[str], data_manager_args: dict, svd_ncoeff: Int = 50, conversion: str = None, plots_dir: str = None, save_preprocessed_data: bool = False) -> None: """ Initialize the surrogate model trainer that decomposes the training data into its SVD coefficients. The initialization also takes care of reading data and preprocessing it, but does not automatically fit the model. Users may want to inspect the data before fitting the model. Args: name (str): Name of the surrogate model. Will be used outdir (str): Directory where the trained surrogate model is to be saved. filters (list[str]): List of the filters for which the surrogate has to be trained. These have to be either bandpasses from sncosmo or specifiy the frequency through endign with GHz or keV. data_manager_args (dict): data_manager_args (dict): Arguments for the DataManager class instance that will be used to read the data from the .h5 file in outdir and preprocess it. svd_ncoeff (int, optional) : Number of SVD coefficients to use in data reduction during training. Defaults to 50. conversion (str): references how to convert the parameters for the training. Defaults to None, in which case it's the identity. plots_dir (str, optional): Directory where the plots of the training process will be saved. Defaults to None, which means no plots will be generated. save_preprocessed_data (bool, optional): If True, the preprocessed data (reduced, rescaled) will be saved in the outdir. Defaults to False. """ super().__init__(name = name, outdir = outdir, plots_dir = plots_dir, save_preprocessed_data = save_preprocessed_data) self.svd_ncoeff = svd_ncoeff self.conversion = conversion self.data_manager = DataManager(**data_manager_args) self.data_manager.print_file_info() self.data_manager.pass_meta_data(self) self.load_filters(filters)
[docs] def load_filters(self, filters): self.filters = [] for filt in filters: Filt = Filter(filt) if Filt.nus[0] < self.data_manager.nus[0] or Filt.nus[-1] > self.data_manager.nus[-1]: raise ValueError(f"Filter {filt} exceeds the frequency range of the training data.") self.filters.append(Filt)
[docs] def preprocess(self): """ Preprocessing method to get the SVD coefficients of the training and validation data. This includes scaling the inputs and outputs, as well as performing SVD decomposition. """ logger.info(f"Preprocessing data by decomposing training data into SVD coefficients.") self.train_X, self.train_y, self.val_X, self.val_y, self.X_scaler, self.y_scaler = self.data_manager.preprocess_svd(self.svd_ncoeff, self.filters, self.conversion) self.parameter_names += ["redshift"] self.parameter_distributions = self.parameter_distributions[:-1] + ", 'redshift': (0, 0.5, 'uniform')}" # TODO make adding redshift more flexible (i.e. whether to add redshift at all and its range) nan_filters = [] for key in self.train_y.keys(): if np.any(np.isnan(self.train_y[key])) or np.any(np.isnan(self.val_y[key])): logger.warning(f"Data preprocessing for {key} introduced nans. Check raw data for nans of infs or vanishing variance in a specific entry. Removing {key} from training.") nan_filters.append(key) self.filters = [filt for filt in self.filters if filt.name not in nan_filters] for key in nan_filters: del self.train_y[key] del self.val_y[key] del self.y_scaler[key] logger.info(f"Preprocessing data . . . done")