fiesta.train#
Components for training surrogate models.
Trainers#
Methods to train the surrogate models.
- class fiesta.train.FluxTrainer.CVAETrainer(name, outdir, data_manager_args, image_size, conversion=None, plots_dir=None, save_preprocessed_data=False)[source]#
Bases:
FluxTrainer
- fit(config, key=jax.random.PRNGKey(0), verbose=True)[source]#
Initializes the autoencoder based on the architecture specified in config and then fits it using the learning rate and number of epochs given in config. The config controls which architecture is built through config.hidden_layers. The encoder and decoder share the hidden_layers argument, though the decoder's layers are implemented in reverse order.
- Parameters:
config (fiesta.train.neuralnets.NeuralnetConfig) – Config that must specify at least the network output size, hidden_layers, learning rate, and number of epochs. Its output_size must equal the product of self.image_size.
key (jax.random.PRNGKey, optional) – jax.random.PRNGKey used to initialize the parameters of the network. Defaults to jax.random.PRNGKey(0).
verbose (bool, optional) – Whether the training and validation loss are printed to the terminal at regular intervals. Defaults to True.
- Return type:
None
- class fiesta.train.FluxTrainer.FluxTrainer(name, outdir, plots_dir=None, save_preprocessed_data=False)[source]#
Bases:
object
Abstract class for training a surrogate model that predicts a spectral flux density array.
- class fiesta.train.FluxTrainer.PCATrainer(name, outdir, data_manager_args, n_pca=100, conversion=None, plots_dir=None, save_preprocessed_data=False)[source]#
Bases:
FluxTrainer
- fit(config, key=jax.random.PRNGKey(0), verbose=True)[source]#
Initializes a neural network based on the architecture specified in config and then fits it using the learning rate and number of epochs given in config. The config controls which architecture is built through config.hidden_layers.
- Parameters:
config (fiesta.train.neuralnets.NeuralnetConfig) – Config that must specify at least the network output size, hidden_layers, learning rate, and number of epochs. Its output_size must equal n_pca.
key (jax.random.PRNGKey, optional) – jax.random.PRNGKey used to initialize the parameters of the network. Defaults to jax.random.PRNGKey(0).
verbose (bool, optional) – Whether the training and validation loss are printed to the terminal at regular intervals. Defaults to True.
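Pulling the pieces above together, a training run might look like the following sketch. This is an illustrative script fragment, not a shipped example: the surrogate name, file path, and all numeric values are placeholders; only the signatures come from the entries in this reference.

```python
from fiesta.train.FluxTrainer import PCATrainer
from fiesta.train.neuralnets import NeuralnetConfig

# Config fields follow the NeuralnetConfig signature documented under
# "Neural Networks" below; the values here are example choices.
config = NeuralnetConfig(
    output_size=100,                    # must equal n_pca
    hidden_layer_sizes=[64, 128, 64],
    learning_rate=1e-3,
    nb_epochs=1000,
)

trainer = PCATrainer(
    name="gaussian_jet",                # placeholder surrogate name
    outdir="./surrogates",
    data_manager_args=dict(
        file="data/afterglow_data.h5",  # hypothetical raw-data file
        tmin=1.0, tmax=1000.0,
        n_training=10_000, n_val=1_000,
    ),
    n_pca=100,
)
trainer.fit(config, verbose=True)
```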
Methods to train the surrogate models.
- class fiesta.train.LightcurveTrainer.LightcurveTrainer(name, outdir, plots_dir=None, save_preprocessed_data=False)[source]#
Bases:
object
Abstract class for training a collection of surrogate models, one per filter.
- class fiesta.train.LightcurveTrainer.SVDTrainer(name, outdir, filters, data_manager_args, svd_ncoeff=50, conversion=None, plots_dir=None, save_preprocessed_data=False)[source]#
Bases:
LightcurveTrainer
Data#
- class fiesta.train.DataManager.DataManager(file, tmin, tmax, numin=1000000000.0, numax=2.5e+18, n_training=None, n_val=None, special_training=[])[source]#
Bases:
object
- load_raw_data_from_file(n_training=1, n_val=0)[source]#
Loads the raw training and validation data from file and returns them as arrays.
- Return type:
tuple[Array, Array, Array, Array]
- pass_meta_data(object)[source]#
Passes the training data metadata on to another object. Used for the FluxTrainers.
- Return type:
None
- preprocess_cVAE(image_size, conversion=None)[source]#
Loads the training and validation data and performs preprocessing for the CVAE using fiesta.utils.ImageScaler. To limit memory usage, the training data set is loaded in chunks. The X arrays (parameter values) are standardized with fiesta.utils.StandardScalerJax.
- Parameters:
image_size (Array[Int]) – Image size to which the 2D flux arrays are down-sampled with jax.image.resize.
conversion (str) – Specifies how to convert the parameters for training. Defaults to None, in which case the identity is used.
- Returns:
train_X (Array): Standardized training parameters.
train_y (Array): PCA coefficients of the training data.
val_X (Array): Standardized validation parameters.
val_y (Array): PCA coefficients of the validation data.
Xscaler (StandardScalerJax): Standardizer object fitted to the mean and sigma of the raw training data. Can be used to transform and inverse transform parameter points.
yscaler (ImageScaler): ImageScaler object fitted to part of the raw training data. Can be used to transform and inverse transform log spectral flux densities.
- Return type:
tuple[Array, Array, Array, Array, StandardScalerJax, ImageScaler]
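The X standardization above is the usual zero-mean, unit-variance transform. As a minimal pure-Python sketch of what StandardScalerJax presumably does (the class name comes from the docs above; this implementation is illustrative, not the fiesta source):

```python
class StandardScaler:
    """Illustrative stand-in for fiesta.utils.StandardScalerJax:
    standardize each feature to zero mean and unit variance."""

    def fit(self, X):
        # X: list of samples, each a list of parameter values
        n = len(X)
        dim = len(X[0])
        self.mean = [sum(x[j] for x in X) / n for j in range(dim)]
        self.sigma = [
            (sum((x[j] - self.mean[j]) ** 2 for x in X) / n) ** 0.5
            for j in range(dim)
        ]
        return self

    def transform(self, X):
        return [[(x[j] - self.mean[j]) / self.sigma[j] for j in range(len(x))]
                for x in X]

    def inverse_transform(self, X):
        return [[x[j] * self.sigma[j] + self.mean[j] for j in range(len(x))]
                for x in X]
```

A fitted scaler can round-trip parameter points through transform and inverse_transform, which is what the returned Xscaler is used for at inference time.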
- preprocess_pca(n_components, conversion=None)[source]#
Loads the training and validation data and performs a PCA decomposition using fiesta.utils.PCADecomposer. To limit memory usage, the training data set is loaded in chunks. The X arrays (parameter values) are standardized with fiesta.utils.StandardScalerJax.
- Parameters:
n_components (int) – Number of PCA components kept in the decomposition.
conversion (str) – Specifies how to convert the parameters for training. Defaults to None, in which case the identity is used.
- Returns:
train_X (Array): Standardized training parameters.
train_y (Array): PCA coefficients of the training data.
val_X (Array): Standardized validation parameters.
val_y (Array): PCA coefficients of the validation data.
Xscaler (StandardScalerJax): Standardizer object fitted to the mean and sigma of the raw training data. Can be used to transform and inverse transform parameter points.
yscaler (PCADecomposer): PCADecomposer object fitted to part of the raw training data. Can be used to transform and inverse transform log spectral flux densities.
- Return type:
tuple[Array, Array, Array, Array, StandardScalerJax, PCADecomposer]
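The PCA step can be pictured with plain NumPy (an illustrative sketch, not fiesta.utils.PCADecomposer itself): center the data, take the truncated SVD, and keep n_components coefficients per sample.

```python
import numpy as np

def pca_fit_transform(Y, n_components):
    """Project each row of Y onto its leading principal components.
    Returns (coefficients, mean, components) so the data can be reconstructed."""
    mean = Y.mean(axis=0)
    Yc = Y - mean
    # Rows of Vt are the principal directions, sorted by singular value.
    U, S, Vt = np.linalg.svd(Yc, full_matrices=False)
    components = Vt[:n_components]
    coeffs = Yc @ components.T          # shape (n_samples, n_components)
    return coeffs, mean, components

def pca_inverse(coeffs, mean, components):
    """Reconstruct (an approximation of) the original rows."""
    return coeffs @ components + mean
```

With n_components equal to the data dimension the round trip is exact; truncating trades reconstruction error for a much smaller network output, which is the point of training on PCA coefficients instead of raw flux arrays.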
- preprocess_svd(svd_ncoeff, filters, conversion=None)[source]#
Loads the training and validation data and performs preprocessing for the SVD decomposition using fiesta.utils.SVDDecomposer. This is done per filter supplied in the filters argument, which is equivalent to the old NMMA procedure. The X arrays (parameter values) are scaled to [0, 1] with MinMaxScalerJax.
- Parameters:
svd_ncoeff (int) – Number of SVD coefficients kept in the decomposition.
filters (list[str]) – Filters for which the magnitude lightcurves are decomposed.
conversion (str) – Specifies how to convert the parameters for training. Defaults to None, in which case the identity is used.
- Returns:
train_X (Array): Scaled training parameters.
train_y (dict[str, Array]): Dictionary of the SVD coefficients of the training magnitude lightcurves, with the filter names as keys.
val_X (Array): Scaled validation parameters.
val_y (dict[str, Array]): Dictionary of the SVD coefficients of the validation magnitude lightcurves, with the filter names as keys.
Xscaler (ParameterScaler): MinMaxScaler object fitted to the minimum and maximum of the training data parameters. Can be used to transform and inverse transform parameter points.
yscaler (dict[str, SVDDecomposer]): Dictionary of SVDDecomposer objects with the filter names as keys. The SVDDecomposer objects are fitted to the magnitude training data and can be used to transform and inverse transform magnitudes in the respective filter.
- Return type:
tuple[Array, dict[str, Array], Array, dict[str, Array], ParameterScaler, dict[str, SVDDecomposer]]
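For the SVD branch, the parameter scaling maps each parameter to [0, 1]. A minimal sketch of such a min-max transform (illustrative of what MinMaxScalerJax presumably does, not the fiesta source):

```python
class MinMaxScaler:
    """Scale each feature to [0, 1] based on the training minimum/maximum."""

    def fit(self, X):
        dim = len(X[0])
        self.min = [min(x[j] for x in X) for j in range(dim)]
        self.max = [max(x[j] for x in X) for j in range(dim)]
        return self

    def transform(self, X):
        return [
            [(x[j] - self.min[j]) / (self.max[j] - self.min[j])
             for j in range(len(x))]
            for x in X
        ]

    def inverse_transform(self, X):
        return [
            [x[j] * (self.max[j] - self.min[j]) + self.min[j]
             for j in range(len(x))]
            for x in X
        ]
```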
- print_file_info()[source]#
Prints the metadata of the raw data, i.e., times, frequencies, and parameter names, to the terminal. Also prints how many training, validation, and test data points are available.
- Return type:
None
- fiesta.train.DataManager.redshifted_magnitude(filt, mJys, nus, redshifts)[source]#
A slow and inefficient implementation that computes the redshifted magnitudes used as training data.
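For orientation, two standard relations are involved here: converting a flux density in mJy to an AB magnitude, and shifting a source-frame frequency to the observer frame. The sketch below shows only these textbook relations; the exact steps inside redshifted_magnitude are not reproduced.

```python
import math

def ab_magnitude(flux_mJy):
    """AB magnitude from a flux density in mJy (3631 Jy corresponds to m_AB = 0)."""
    flux_Jy = flux_mJy * 1e-3
    return -2.5 * math.log10(flux_Jy / 3631.0)

def observed_frequency(nu_source, z):
    """A photon emitted at nu_source arrives redshifted to nu_source / (1 + z)."""
    return nu_source / (1.0 + z)
```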
Classes and methods to generate afterglow training data.
- class fiesta.train.AfterglowData.AfterglowData(outfile, n_training, n_val, n_test, parameter_distributions=None, jet_type=-1, tmin=1.0, tmax=1000.0, n_times=100, use_log_spacing=True, numin=1000000000.0, numax=2.5e+18, n_nu=256, fixed_parameters=None)[source]#
Bases:
object
- create_raw_data(n, training=True)[source]#
Creates n draws X in the parameter space and runs the afterglow model on them.
- class fiesta.train.AfterglowData.AfterglowpyData(n_pool, *args, **kwargs)[source]#
Bases:
AfterglowData
- class fiesta.train.AfterglowData.BlastwaveData(*args, n_pool=None, **kwargs)[source]#
Bases:
AfterglowData
- class fiesta.train.AfterglowData.BlastwaveRSData(*args, n_pool=None, **kwargs)[source]#
Bases:
BlastwaveData
BlastwaveData variant with reverse shock enabled.
Following Japelj+ 2014 (1402.3701), the RS microphysics are tied to the FS values via a magnetization ratio RB:
eps_e_rs = eps_e_f
eps_b_rs = RB * eps_b_f
p_rs = p_f
Extra sampled parameters: log10_RB and log10_duration. sigma is kept fixed at 0.0 (unmagnetized ejecta).
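The mapping above is a one-liner per parameter. As an illustrative sketch (parameter names follow the docstring; RB = 10**log10_RB since the sampled parameter is log10_RB):

```python
def rs_microphysics(eps_e_f, eps_b_f, p_f, log10_RB):
    """Derive reverse-shock microphysics from forward-shock values
    following Japelj+ 2014: only eps_b is rescaled by the magnetization ratio."""
    RB = 10.0 ** log10_RB
    eps_e_rs = eps_e_f           # shared with the forward shock
    eps_b_rs = RB * eps_b_f      # rescaled by the magnetization ratio
    p_rs = p_f                   # shared electron slope
    return eps_e_rs, eps_b_rs, p_rs
```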
- fiesta.train.AfterglowData.JetsimpyData#
alias of
BlastwaveData
- class fiesta.train.AfterglowData.PyblastafterglowData(path_to_exec, pbag_kwargs=None, rank=0, *args, **kwargs)[source]#
Bases:
AfterglowData
- class fiesta.train.AfterglowData.RunAfterglowpy(jet_type, times, nus, X, parameter_names, fixed_parameters=None)[source]#
Bases:
object
- class fiesta.train.AfterglowData.RunBlastwave(times, nus, X, parameter_names, fixed_parameters=None, ncells=33, spread_mode='ode')[source]#
Bases:
object
- class fiesta.train.AfterglowData.RunBlastwaveRS(times, nus, X, parameter_names, fixed_parameters=None, ncells=33, spread_mode='ode')[source]#
Bases:
object
Like RunBlastwave, but with reverse shock enabled.
Following Japelj+ 2014: RS microphysics derived from FS values via magnetization ratio RB = eps_b_rs / eps_b_f, with eps_e_rs = eps_e_f and p_rs = p_f.
- fiesta.train.AfterglowData.RunJetsimpy#
alias of
RunBlastwave
Neural Networks#
- class fiesta.train.neuralnets.CVAE(config, conditional_dim, key=jax.random.key(21))[source]#
Bases:
object
- static load_full_model(filename)[source]#
- Return type:
tuple[TrainState, NeuralnetConfig]
- static load_model(filename)[source]#
Load a model from a file. TODO: this is very cumbersome and must be massively improved in the future.
- Parameters:
filename (str) – Filename of the model to be loaded.
- Raises:
ValueError – If anything goes wrong during loading; many steps can fail here.
- Returns:
The TrainState object loaded from the file and the NeuralnetConfig object.
- Return type:
tuple[TrainState, NeuralnetConfig]
- save_model(outfile='my_flax_model.pkl')[source]#
Serialize and save the model to a file.
- Raises:
ValueError – If the provided file extension is not .pkl or .pickle.
- Parameters:
outfile (str, optional) – The pickle file to which we save the serialized model. Defaults to “my_flax_model.pkl”.
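The save/load pair amounts to pickle serialization with a file-extension check. A self-contained sketch of the pattern (illustrative of the documented behavior, not the fiesta source):

```python
import pickle

def save_model(model_dict, outfile="my_flax_model.pkl"):
    """Pickle a serialized model, refusing file names without a pickle extension."""
    if not outfile.endswith((".pkl", ".pickle")):
        raise ValueError(f"Expected a .pkl or .pickle file, got {outfile!r}")
    with open(outfile, "wb") as f:
        pickle.dump(model_dict, f)

def load_model(filename):
    """Inverse of save_model: unpickle the stored dictionary."""
    with open(filename, "rb") as f:
        return pickle.load(f)
```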
- class fiesta.train.neuralnets.MLP(config, input_ndim, key=jax.random.key(21))[source]#
Bases:
object
- static load_model(filename)[source]#
Load a model from a file. TODO: this is very cumbersome and must be massively improved in the future.
- Parameters:
filename (str) – Filename of the model to be loaded.
- Raises:
ValueError – If anything goes wrong during loading; many steps can fail here.
- Returns:
The TrainState object loaded from the file and the NeuralnetConfig object.
- Return type:
tuple[TrainState, NeuralnetConfig]
- save_model(outfile='my_flax_model.pkl')[source]#
Serialize and save the model to a file.
- Raises:
ValueError – If the provided file extension is not .pkl or .pickle.
- Parameters:
outfile (str, optional) – The pickle file to which we save the serialized model. Defaults to “my_flax_model.pkl”.
- class fiesta.train.neuralnets.NeuralnetConfig(name='MLP', output_size=10, hidden_layer_sizes=[64, 128, 64], latent_dim=20, learning_rate=0.001, weight_decay=0.0, batch_size=128, nb_epochs=1000, nb_report=None, dropout_rate=0.0, use_cosine_schedule=False, cosine_alpha=0.01, max_grad_norm=0.0, pca_smoothness_weight=0.0, pca_smoothness_start=0)[source]#
Bases:
ConfigDict
Configuration for a neural network model, with typed attributes for type hinting.
- batch_size: Int#
- latent_dim: Int#
- learning_rate: Float#
- nb_epochs: Int#
- nb_report: Int#
- output_size: Int#
- fiesta.train.neuralnets.bce(y, pred)[source]#
Binary cross entropy between y and the predicted array pred.
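For a single value, this is the standard formula; the scalar sketch below is illustrative (the fiesta version operates elementwise on JAX arrays), with clipping added here to avoid log(0):

```python
import math

def bce(y, pred, eps=1e-12):
    """Binary cross entropy between a target y in [0, 1] and a prediction pred."""
    pred = min(max(pred, eps), 1.0 - eps)  # clip to avoid log(0)
    return -(y * math.log(pred) + (1.0 - y) * math.log(1.0 - pred))
```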
- fiesta.train.neuralnets.kld(mean, logvar)[source]#
Kullback-Leibler divergence of a normal distribution with arbitrary mean and log variance to the standard normal distribution with mean 0 and unit variance.
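This divergence has the familiar closed form −½ (1 + logvar − mean² − exp(logvar)) per dimension. A scalar sketch (the fiesta version sums over JAX arrays):

```python
import math

def kld(mean, logvar):
    """KL divergence of N(mean, exp(logvar)) from the standard normal N(0, 1)."""
    return -0.5 * (1.0 + logvar - mean**2 - math.exp(logvar))
```

It vanishes exactly when mean = 0 and logvar = 0, i.e., when the latent distribution matches the standard normal prior.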
- fiesta.train.neuralnets.serialize(state, config=None)[source]#
Serialize function to save the model and its configuration.
- Parameters:
state (TrainState) – The TrainState object to be serialized.
config (NeuralnetConfig, optional) – The config to be serialized. Defaults to None.
- Returns:
The serialized representation of the TrainState and, if provided, of the config.
- class fiesta.train.nn_architectures.BaseNeuralnet(layer_sizes, act_func=<jax._src.custom_derivatives.custom_jvp object>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases:
Module
Abstract base class; requires the layer sizes and the activation function to use.
- setup()[source]#
Initializes a Module lazily (similar to a lazy __init__). setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed. This can happen in three cases:
1. Immediately when invoking apply(), init() or init_and_output().
2. Once the module is given a name by being assigned to an attribute of another module inside the other module's setup method (see __setattr__()):
>>> class MyModule(nn.Module):
...     def setup(self):
...         submodule = nn.Conv(...)
...         # Accessing `submodule` attributes does not yet work here.
...         # The following line invokes `self.__setattr__`, which gives
...         # `submodule` the name "conv1".
...         self.conv1 = submodule
...         # Accessing `submodule` attributes or methods is now safe and
...         # either causes setup() to be called once.
3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.
- class fiesta.train.nn_architectures.CNN(dense_layer_sizes, kernel_sizes, conv_layer_sizes, output_shape, spatial=32, act_func=<jax._src.custom_derivatives.custom_jvp object>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases:
Module
Convolutional Neural Network.
- setup()[source]#
Initializes a Module lazily; inherited from flax.linen.Module (see the setup() docstring under BaseNeuralnet).
- spatial: Int = 32#
- class fiesta.train.nn_architectures.CVAE(hidden_layer_sizes, output_size, latent_dim=20, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases:
Module
Conditional Variational Autoencoder consisting of an Encoder and a Decoder.
- latent_dim: Int = 20#
- output_size: Int#
- setup()[source]#
Initializes a Module lazily; inherited from flax.linen.Module (see the setup() docstring under BaseNeuralnet).
- class fiesta.train.nn_architectures.Decoder(layer_sizes, act_func=<jax._src.custom_derivatives.custom_jvp object>, dropout_rate=0.0, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases:
MLP
- class fiesta.train.nn_architectures.Encoder(layer_sizes, act_func=<jax._src.custom_derivatives.custom_jvp object>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases:
Module
- setup()[source]#
Initializes a Module lazily; inherited from flax.linen.Module (see the setup() docstring under BaseNeuralnet).
- class fiesta.train.nn_architectures.MLP(layer_sizes, act_func=<jax._src.custom_derivatives.custom_jvp object>, dropout_rate=0.0, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Bases:
BaseNeuralnet
Basic multi-layer perceptron: a feedforward neural network with multiple Dense layers.
- setup()[source]#
Initializes a Module lazily; inherited from flax.linen.Module (see the setup() docstring under BaseNeuralnet).