
fiesta.train#

Components for training surrogate models.

Trainers#

Methods to train the surrogate models.

class fiesta.train.FluxTrainer.CVAETrainer(name, outdir, data_manager_args, image_size, conversion=None, plots_dir=None, save_preprocessed_data=False)[source]#

Bases: FluxTrainer

fit(config, key=jax.random.PRNGKey(0), verbose=True)[source]#

Method used to initialize the autoencoder based on the architecture specified in config and then fit it with the learning rate and number of epochs given there. The config controls which architecture is built through config.hidden_layers. The encoder and decoder share the hidden_layers argument, with the decoder layers applied in reverse order.

Parameters:
  • config (fiesta.train.neuralnets.NeuralnetConfig) – config that needs to specify at least the network output, hidden_layers, learning rate, and learning epochs. Its output_size must be equal to the product of self.image_size.

  • key (jax.random.PRNGKey, optional) – jax.random.PRNGKey used to initialize the parameters of the network. Defaults to jax.random.PRNGKey(0).

  • verbose (bool, optional) – Whether the training and validation losses are printed to the terminal at regular intervals. Defaults to True.

Return type:

None
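The constraints above can be sketched in a few lines. This is an illustrative stand-in, not fiesta internals: the decoder mirrors the encoder's hidden_layers in reverse, and the network output must hold the flattened down-sampled flux image, i.e. config.output_size must equal the product of image_size.

```python
import numpy as np

# Hypothetical sketch of the CVAETrainer.fit constraints; names are
# illustrative, not taken from the fiesta source.
image_size = (64, 128)            # (n_times, n_nus) after down-sampling
hidden_layers = [256, 128, 64]    # as would be set in config.hidden_layers

output_size = int(np.prod(image_size))          # required config.output_size
encoder_layers = list(hidden_layers)            # encoder as specified
decoder_layers = list(reversed(hidden_layers))  # decoder in reverse order
```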

preprocess()[source]#

Preprocessing method to get the down-sampled arrays of the standardized training data. It assigns the attributes self.train_X, self.train_y, self.val_X, and self.val_y that are passed to the fitting method.

Return type:

None

class fiesta.train.FluxTrainer.FluxTrainer(name, outdir, plots_dir=None, save_preprocessed_data=False)[source]#

Bases: object

Abstract class for training a surrogate model that predicts a spectral flux density array.

fit(config=None, key=jax.random.PRNGKey(0), verbose=True)[source]#
Return type:

None

name: str#
outdir: str#
parameter_names: list[str]#
plot_example_lc(lc_model)[source]#
plot_learning_curve(train_losses, val_losses)[source]#
preprocess()[source]#
save()[source]#

Save the trained model and all the metadata to the outdir. The metadata is saved as a pickled dict to be read by fiesta.inference.lightcurve_model.SurrogateLightcurveModel. The NN is saved as a pickled serialized dict using the NN.save_model method.

Return type:

None

class fiesta.train.FluxTrainer.PCATrainer(name, outdir, data_manager_args, n_pca=100, conversion=None, plots_dir=None, save_preprocessed_data=False)[source]#

Bases: FluxTrainer

fit(config, key=jax.random.PRNGKey(0), verbose=True)[source]#

Method used to initialize a NN based on the architecture specified in config and then fit it with the learning rate and number of epochs given there. The config controls which architecture is built through config.hidden_layers.

Parameters:
  • config (fiesta.train.neuralnets.NeuralnetConfig) – config that needs to specify at least the network output, hidden_layers, learning rate, and learning epochs. Its output_size must be equal to n_pca.

  • key (jax.random.PRNGKey, optional) – jax.random.PRNGKey used to initialize the parameters of the network. Defaults to jax.random.PRNGKey(0).

  • verbose (bool, optional) – Whether the training and validation losses are printed to the terminal at regular intervals. Defaults to True.
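The PCA surrogate idea described above can be sketched with plain numpy (this illustrates the decomposition, not fiesta's PCADecomposer): the standardized flux outputs are projected onto n_pca principal components, and the network learns to predict those coefficients.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_out, n_pca = 200, 50, 5

# Low-rank synthetic "training outputs": exactly 5 latent factors.
latent = rng.normal(size=(n_samples, n_pca))
basis = rng.normal(size=(n_pca, n_out))
train_y = latent @ basis

# PCA via SVD of the mean-centered data.
mean = train_y.mean(axis=0)
U, S, Vt = np.linalg.svd(train_y - mean, full_matrices=False)
components = Vt[:n_pca]                   # principal directions
coeffs = (train_y - mean) @ components.T  # what the NN would predict
reconstruction = coeffs @ components + mean

max_err = np.abs(reconstruction - train_y).max()
```

Because the synthetic data are exactly rank 5, five components reconstruct them to floating-point precision; for real flux grids the truncation error decreases with n_pca.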

preprocess()[source]#

Preprocessing method to get the PCA coefficients of the standardized training data. It assigns the attributes self.train_X, self.train_y, self.val_X, self.val_y that are passed to the fitting method.

Method to train the surrogate models

class fiesta.train.LightcurveTrainer.LightcurveTrainer(name, outdir, plots_dir=None, save_preprocessed_data=False)[source]#

Bases: object

Abstract class for training a collection of surrogate models, one per filter.

filters: list[Filter]#
fit(config, key=jax.random.PRNGKey(0), verbose=True)[source]#

Method used to initialize a NN based on the architecture specified in config and then fit it. The config controls which architecture is built through config.hidden_layers.

Parameters:

config (nn.NeuralnetConfig, optional) – config that specifies the network architecture and training hyperparameters. Defaults to None.

Return type:

None

name: str#
outdir: str#
parameter_names: list[str]#
plot_example_lc(lc_model)[source]#
preprocess()[source]#
save()[source]#

Save the trained model and all the used metadata to the outdir.

class fiesta.train.LightcurveTrainer.SVDTrainer(name, outdir, filters, data_manager_args, svd_ncoeff=50, conversion=None, plots_dir=None, save_preprocessed_data=False)[source]#

Bases: LightcurveTrainer

load_filters(filters)[source]#
preprocess()[source]#

Preprocessing method to get the SVD coefficients of the training and validation data. This includes scaling the inputs and outputs, as well as performing SVD decomposition.
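A minimal numpy sketch of the per-filter preprocessing described above (illustrative only, not fiesta's SVDDecomposer): parameters are min-max scaled to [0, 1] and each filter's magnitude lightcurves are reduced to svd_ncoeff SVD coefficients. The filter names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
train_X = rng.uniform(-3.0, 2.0, size=(100, 4))  # raw parameter values

# Min-max scaling to [0, 1], mirroring MinMaxScalerJax behavior.
xmin, xmax = train_X.min(axis=0), train_X.max(axis=0)
scaled_X = (train_X - xmin) / (xmax - xmin)

svd_ncoeff = 10
train_y = {}
for filt in ["ztfg", "ztfr"]:                 # hypothetical filter names
    mags = rng.normal(size=(100, 50))         # magnitude lightcurves per filter
    U, S, Vt = np.linalg.svd(mags, full_matrices=False)
    train_y[filt] = mags @ Vt[:svd_ncoeff].T  # SVD coefficients per lightcurve
```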

Data#

class fiesta.train.DataManager.DataManager(file, tmin, tmax, numin=1000000000.0, numax=2.5e+18, n_training=None, n_val=None, special_training=[])[source]#

Bases: object

load_raw_data_from_file(n_training=1, n_val=0)[source]#

Loads the raw training and validation data and returns them as arrays.

Return type:

tuple[Array, Array, Array, Array]

pass_meta_data(object)[source]#

Pass the training data metadata to another object. Used for the FluxTrainers.

Return type:

None

preprocess_cVAE(image_size, conversion=None)[source]#

Loads in the training and validation data and performs data preprocessing for the CVAE using fiesta.utils.ImageScaler. Because of memory issues, the training data set is loaded in chunks. The X arrays (parameter values) are standardized with fiesta.utils.StandardScalerJax.

Parameters:
  • image_size (Array[Int]) – Image size the 2D flux arrays are down-sampled to with jax.image.resize.

  • conversion (str) – references how to convert the parameters for the training. Defaults to None, in which case it’s the identity.

Returns:
  • train_X (Array): Standardized training parameters.

  • train_y (Array): Scaled, down-sampled training flux arrays.

  • val_X (Array): Standardized validation parameters.

  • val_y (Array): Scaled, down-sampled validation flux arrays.

  • Xscaler (StandardScalerJax): Standardizer object fitted to the mean and sigma of the raw training data. Can be used to transform and inverse transform parameter points.

  • yscaler (ImageScaler): ImageScaler object fitted to part of the raw training data. Can be used to transform and inverse transform log spectral flux densities.

Return type:

tuple
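The parameter standardization mentioned above can be sketched in numpy (this illustrates the behavior of StandardScalerJax, not its implementation): subtract the per-column mean and divide by the standard deviation, with an exact inverse transform.

```python
import numpy as np

rng = np.random.default_rng(2)
X_raw = rng.uniform(0.0, 10.0, size=(500, 3))  # raw parameter values

mu, sigma = X_raw.mean(axis=0), X_raw.std(axis=0)
X = (X_raw - mu) / sigma          # transform: zero mean, unit variance
X_back = X * sigma + mu           # inverse transform

roundtrip_err = np.abs(X_back - X_raw).max()
```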

preprocess_pca(n_components, conversion=None)[source]#

Loads in the training and validation data and performs PCA decomposition using fiesta.utils.PCADecomposer. Because of memory issues, the training data set is loaded in chunks. The X arrays (parameter values) are standardized with fiesta.utils.StandardScalerJax.

Parameters:
  • n_components (int) – Number of PCA components to keep.

  • conversion (str) – references how to convert the parameters for the training. Defaults to None, in which case it’s the identity.

Returns:
  • train_X (Array): Standardized training parameters.

  • train_y (Array): PCA coefficients of the training data.

  • val_X (Array): Standardized validation parameters.

  • val_y (Array): PCA coefficients of the validation data.

  • Xscaler (StandardScalerJax): Standardizer object fitted to the mean and sigma of the raw training data. Can be used to transform and inverse transform parameter points.

  • yscaler (PCADecomposer): PCADecomposer object fitted to part of the raw training data. Can be used to transform and inverse transform log spectral flux densities.

Return type:

tuple

preprocess_svd(svd_ncoeff, filters, conversion=None)[source]#

Loads in the training and validation data and performs data preprocessing for the SVD decomposition using fiesta.utils.SVDDecomposer. This is done per filter supplied in the filters argument, equivalent to the old NMMA procedure. The X arrays (parameter values) are scaled to [0, 1] with MinMaxScalerJax.

Parameters:
  • svd_ncoeff (Int) – Number of SVD coefficients to keep

  • filters (list[Filter]) – List of fiesta.utils.Filter instances that are used to convert the fluxes to magnitudes.

  • conversion (str) – references how to convert the parameters for the training. Defaults to None, in which case it’s the identity.

Returns:
  • train_X (Array): Scaled training parameters.

  • train_y (dict[str, Array]): Dictionary of the SVD coefficients of the training magnitude lightcurves, with the filter names as keys.

  • val_X (Array): Scaled validation parameters.

  • val_y (dict[str, Array]): Dictionary of the SVD coefficients of the validation magnitude lightcurves, with the filter names as keys.

  • Xscaler (ParameterScaler): MinMaxScaler object fitted to the minimum and maximum of the training data parameters. Can be used to transform and inverse transform parameter points.

  • yscaler (dict[str, SVDDecomposer]): Dictionary of SVDDecomposer objects with the filter names as keys. The SVDDecomposer objects are fitted to the magnitude training data and can be used to transform and inverse transform magnitudes in the corresponding filter.

Return type:

tuple

print_file_info()[source]#

Prints the metadata of the raw data, i.e., times, frequencies, and parameter names to the terminal. Also prints how many training, validation, and test data points are available.

Return type:

None

read_metadata_from_file()[source]#

Reads in the metadata of the raw data, i.e., times, frequencies, and parameter names. Also determines how many training and validation data points are available.

Return type:

None

set_up_domain_mask()[source]#

Trims the stored data down to the time and frequency range desired for training. It sets the mask attribute which is a boolean mask used when loading the data arrays.

Return type:

None

fiesta.train.DataManager.array_mask_from_interval(sorted_array, amin, amax)[source]#
fiesta.train.DataManager.concatenate_redshift(X_raw, max_z=0.5)[source]#
fiesta.train.DataManager.redshifted_magnitude(filt, mJys, nus, redshifts)[source]#

This is a slow and inefficient implementation to get the redshifted magnitudes as training data.
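A plausible sketch of what array_mask_from_interval does (an assumption based on its name and on set_up_domain_mask above, not the fiesta source): build a boolean mask selecting the entries of a sorted array that lie inside [amin, amax], as used to trim the stored times and frequencies.

```python
import numpy as np

def array_mask_from_interval(sorted_array, amin, amax):
    # Boolean mask over a sorted array: True where amin <= value <= amax.
    sorted_array = np.asarray(sorted_array)
    return (sorted_array >= amin) & (sorted_array <= amax)

times = np.array([0.1, 1.0, 10.0, 100.0, 1000.0])
mask = array_mask_from_interval(times, 1.0, 100.0)
trimmed = times[mask]
```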

Methods to train the surrogate models.

class fiesta.train.AfterglowData.AfterglowData(outfile, n_training, n_val, n_test, parameter_distributions=None, jet_type=-1, tmin=1.0, tmax=1000.0, n_times=100, use_log_spacing=True, numin=1000000000.0, numax=2.5e+18, n_nu=256, fixed_parameters=None)[source]#

Bases: object

create_raw_data(n, training=True)[source]#

Create draws X in the parameter space and run the afterglow model on them.

create_special_data(X_raw, label, comment=None)[source]#

Create special training data with pre-specified parameters X. These will be stored in the ‘special_train’ hdf5 group.

fix_nans(X, y)[source]#
get_raw_data(n, group)[source]#
initialize_nus(numin, numax, n_nu)[source]#
initialize_times(tmin, tmax, n_times, use_log_spacing=True)[source]#
run_afterglow_model(X)[source]#
class fiesta.train.AfterglowData.AfterglowpyData(n_pool, *args, **kwargs)[source]#

Bases: AfterglowData

run_afterglow_model(X)[source]#

Uses multiprocessing to run afterglowpy on the supplied parameters in X.

class fiesta.train.AfterglowData.BlastwaveData(*args, n_pool=None, **kwargs)[source]#

Bases: AfterglowData

create_raw_data(n, training=True)[source]#

Create draws X in the parameter space and run the blastwave model on them.

run_afterglow_model(X)[source]#

Run blastwave model on the supplied parameters in X.

No multiprocessing.Pool needed — the Rust extension uses rayon for automatic parallelism within each FluxDensity call (releases the GIL via py.allow_threads).

class fiesta.train.AfterglowData.BlastwaveRSData(*args, n_pool=None, **kwargs)[source]#

Bases: BlastwaveData

BlastwaveData variant with reverse shock enabled.

Following Japelj+ 2014 (1402.3701), the RS microphysics are tied to the FS values via a magnetization ratio RB:

eps_e_rs = eps_e_f
eps_b_rs = RB * eps_b_f
p_rs     = p_f

Extra sampled parameters: log10_RB and log10_duration. sigma is kept fixed at 0.0 (unmagnetized ejecta).
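The parameter tying above (Japelj+ 2014) amounts to a few assignments; the values below are illustrative:

```python
# Forward-shock (FS) microphysics, illustrative values.
eps_e_f, eps_b_f, p_f = 0.1, 0.01, 2.3
log10_RB = 0.5                 # extra sampled parameter

# Reverse-shock (RS) values tied to the FS via the magnetization ratio RB.
RB = 10.0 ** log10_RB
eps_e_rs = eps_e_f             # eps_e_rs = eps_e_f
eps_b_rs = RB * eps_b_f        # eps_b_rs = RB * eps_b_f
p_rs = p_f                     # p_rs = p_f
sigma = 0.0                    # kept fixed: unmagnetized ejecta
```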

create_raw_data(n, training=True)[source]#

Sample parameters, enforce FS and RS energy constraints, then run model.

run_afterglow_model(X)[source]#

Run blastwave model on the supplied parameters in X.

No multiprocessing.Pool needed — the Rust extension uses rayon for automatic parallelism within each FluxDensity call (releases the GIL via py.allow_threads).

fiesta.train.AfterglowData.JetsimpyData#

alias of BlastwaveData

class fiesta.train.AfterglowData.PyblastafterglowData(path_to_exec, pbag_kwargs=None, rank=0, *args, **kwargs)[source]#

Bases: AfterglowData

run_afterglow_model(X)[source]#

Should be run in parallel with different MPI processes to run pyblastafterglow on the parameters in the array X.

supplement_time(t_supp)[source]#

WARNING: NOT READY TO BE USED

class fiesta.train.AfterglowData.RunAfterglowpy(jet_type, times, nus, X, parameter_names, fixed_parameters=None)[source]#

Bases: object

class fiesta.train.AfterglowData.RunBlastwave(times, nus, X, parameter_names, fixed_parameters=None, ncells=33, spread_mode='ode')[source]#

Bases: object

class fiesta.train.AfterglowData.RunBlastwaveRS(times, nus, X, parameter_names, fixed_parameters=None, ncells=33, spread_mode='ode')[source]#

Bases: object

Like RunBlastwave but with reverse shock enabled.

Following Japelj+ 2014: RS microphysics derived from FS values via magnetization ratio RB = eps_b_rs / eps_b_f, with eps_e_rs = eps_e_f and p_rs = p_f.

fiesta.train.AfterglowData.RunJetsimpy#

alias of RunBlastwave

class fiesta.train.AfterglowData.RunPyblastafterglow(jet_type, times, nus, X, parameter_names, fixed_parameters=None, rank=0, path_to_exec='./pba.out', grb_resolution=12, ntb=1000, tb0=10.0, tb1=100000000000.0, rtol=0.1, loglevel='err')[source]#

Bases: object

Neural Networks#

class fiesta.train.neuralnets.CVAE(config, conditional_dim, key=jax.random.PRNGKey(21))[source]#

Bases: object

static load_full_model(filename)[source]#
Return type:

tuple[TrainState, NeuralnetConfig]

static load_model(filename)[source]#

Load a model from a file. TODO: this is very cumbersome now and must be massively improved in the future

Parameters:

filename (str) – Filename of the model to be loaded.

Raises:

ValueError – If there is something wrong with loading, since lots of things can go wrong here.

Returns:

The TrainState object loaded from the file and the NeuralnetConfig object.

Return type:

tuple[TrainState, NeuralnetConfig]

save_model(outfile='my_flax_model.pkl')[source]#

Serialize and save the model to a file.

Raises:

ValueError – If the provided file extension is not .pkl or .pickle.

Parameters:

outfile (str, optional) – The pickle file to which we save the serialized model. Defaults to “my_flax_model.pkl”.
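The save/load convention can be sketched with the standard library (an assumption modeled on the docstrings above, not fiesta's code): the serialized model is written as a pickled dict, and a ValueError is raised for extensions other than .pkl or .pickle.

```python
import os
import pickle
import tempfile

def save_model(serialized: dict, outfile: str = "my_flax_model.pkl") -> None:
    # Refuse anything that is not a pickle file, as the docstring describes.
    if not outfile.endswith((".pkl", ".pickle")):
        raise ValueError("outfile must end in .pkl or .pickle")
    with open(outfile, "wb") as f:
        pickle.dump(serialized, f)

def load_model(filename: str) -> dict:
    with open(filename, "rb") as f:
        return pickle.load(f)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")
    save_model({"params": [1.0, 2.0], "config": {"name": "MLP"}}, path)
    restored = load_model(path)
```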

train_loop(train_X, train_y, val_X=None, val_y=None, verbose=True)[source]#
static train_step(state, train_X, train_y, rng, val_X=None, val_y=None)[source]#
Return type:

tuple[TrainState, Float[Array, "n_batch_train"], Float[Array, "n_batch_val"]]

class fiesta.train.neuralnets.MLP(config, input_ndim, key=jax.random.PRNGKey(21))[source]#

Bases: object

static eval_step(state, X, y, component_weights)[source]#
static load_model(filename)[source]#

Load a model from a file. TODO: this is very cumbersome now and must be massively improved in the future

Parameters:

filename (str) – Filename of the model to be loaded.

Raises:

ValueError – If there is something wrong with loading, since lots of things can go wrong here.

Returns:

The TrainState object loaded from the file and the NeuralnetConfig object.

Return type:

tuple[TrainState, NeuralnetConfig]

save_model(outfile='my_flax_model.pkl')[source]#

Serialize and save the model to a file.

Raises:

ValueError – If the provided file extension is not .pkl or .pickle.

Parameters:

outfile (str, optional) – The pickle file to which we save the serialized model. Defaults to “my_flax_model.pkl”.

train_loop(train_X, train_y, val_X=None, val_y=None, verbose=True)[source]#
static train_step(state, batch_X, batch_y, dropout_rng, component_weights)[source]#
class fiesta.train.neuralnets.NeuralnetConfig(name='MLP', output_size=10, hidden_layer_sizes=[64, 128, 64], latent_dim=20, learning_rate=0.001, weight_decay=0.0, batch_size=128, nb_epochs=1000, nb_report=None, dropout_rate=0.0, use_cosine_schedule=False, cosine_alpha=0.01, max_grad_norm=0.0, pca_smoothness_weight=0.0, pca_smoothness_start=0)[source]#

Bases: ConfigDict

Configuration for a neural network model; the attributes below are annotated for type hinting.

batch_size: Int#
hidden_layer_sizes: list[int]#
latent_dim: Int#
layer_sizes: list[int]#
learning_rate: Float#
name: str#
nb_epochs: Int#
nb_report: Int#
output_size: Int#
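The field layout can be illustrated with a stand-in dataclass mirroring the NeuralnetConfig signature above (the real class is a ConfigDict subclass; only the fields and defaults shown in the signature are reproduced here):

```python
from dataclasses import dataclass, field

@dataclass
class NeuralnetConfigSketch:
    # Mirrors the defaults in the NeuralnetConfig signature above.
    name: str = "MLP"
    output_size: int = 10
    hidden_layer_sizes: list = field(default_factory=lambda: [64, 128, 64])
    latent_dim: int = 20
    learning_rate: float = 1e-3
    batch_size: int = 128
    nb_epochs: int = 1000

# Typical override: match output_size to the preprocessing (e.g. n_pca).
config = NeuralnetConfigSketch(output_size=100,
                               hidden_layer_sizes=[128, 256, 128])
```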
fiesta.train.neuralnets.bce(y, pred)[source]#

Binary cross entropy between y and the predicted array pred.

fiesta.train.neuralnets.kld(mean, logvar)[source]#

Kullback-Leibler divergence of a normal distribution with arbitrary mean and log variance to the standard normal distribution with mean 0 and unit variance.

fiesta.train.neuralnets.mse(y, pred)[source]#

Squared error between y and the predicted array pred.
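The three losses above follow standard formulas; here are plausible numpy versions for reference (the fiesta functions operate on jax arrays, and their exact reductions may differ):

```python
import numpy as np

def mse(y, pred):
    # Mean squared error between y and pred.
    return np.mean((y - pred) ** 2)

def bce(y, pred):
    # Binary cross entropy; pred must lie strictly in (0, 1).
    return -np.mean(y * np.log(pred) + (1.0 - y) * np.log(1.0 - pred))

def kld(mean, logvar):
    # KL divergence of N(mean, exp(logvar)) to the standard normal:
    # -0.5 * sum(1 + logvar - mean^2 - exp(logvar)).
    return -0.5 * np.sum(1.0 + logvar - mean**2 - np.exp(logvar))
```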

fiesta.train.neuralnets.serialize(state, config=None)[source]#

Serialize function to save the model and its configuration.

Parameters:
  • state (TrainState) – The TrainState object to be serialized.

  • config (NeuralnetConfig, optional) – The config to be serialized. Defaults to None.

Returns:

The serialized model state (and config, if provided) as a dict that can be pickled.

Return type:

dict

class fiesta.train.nn_architectures.BaseNeuralnet(layer_sizes, act_func=<jax._src.custom_derivatives.custom_jvp object>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Bases: Module

Abstract base class. Requires the layer sizes and the activation function to be used.

layer_sizes: Sequence[int]#
name: str | None = None#
parent: Module | Scope | _Sentinel | None = None#
scope: Scope | None = None#
setup()[source]#

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(...)
    
    ...     # Accessing `submodule` attributes does not yet work here.
    
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # either causes setup() to be called once.
    
  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or setup defined attribute is accessed.

class fiesta.train.nn_architectures.CNN(dense_layer_sizes, kernel_sizes, conv_layer_sizes, output_shape, spatial=32, act_func=<jax._src.custom_derivatives.custom_jvp object>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Bases: Module

Convolutional Neural Network

conv_layer_sizes: Sequence[Int]#
dense_layer_sizes: Sequence[Int]#
kernel_sizes: Sequence[Int]#
name: str | None = None#
output_shape: tuple[Int, Int]#
parent: Module | Scope | _Sentinel | None = None#
scope: Scope | None = None#
setup()[source]#


spatial: Int = 32#
class fiesta.train.nn_architectures.CVAE(hidden_layer_sizes, output_size, latent_dim=20, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Bases: Module

Conditional Variational Autoencoder consisting of an Encoder and a Decoder.

hidden_layer_sizes: Sequence[Int]#
latent_dim: Int = 20#
name: str | None = None#
output_size: Int#
parent: Module | Scope | _Sentinel | None = None#
scope: Scope | None = None#
setup()[source]#


class fiesta.train.nn_architectures.Decoder(layer_sizes, act_func=<jax._src.custom_derivatives.custom_jvp object>, dropout_rate=0.0, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Bases: MLP

name: str | None = None#
parent: Module | Scope | _Sentinel | None = None#
scope: Scope | None = None#
class fiesta.train.nn_architectures.Encoder(layer_sizes, act_func=<jax._src.custom_derivatives.custom_jvp object>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Bases: Module

layer_sizes: Sequence[int]#
name: str | None = None#
parent: Module | Scope | _Sentinel | None = None#
scope: Scope | None = None#
setup()[source]#


class fiesta.train.nn_architectures.MLP(layer_sizes, act_func=<jax._src.custom_derivatives.custom_jvp object>, dropout_rate=0.0, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Bases: BaseNeuralnet

Basic multi-layer perceptron: a feedforward neural network with multiple Dense layers.

dropout_rate: float = 0.0#
name: str | None = None#
parent: Module | Scope | _Sentinel | None = None#
scope: Scope | None = None#
setup()[source]#


Benchmarking#

class fiesta.train.Benchmarker.Benchmarker(model, data_file, filters=None, outdir='./benchmarks', metric_name='Linf')[source]#

Bases: object

benchmark()[source]#
calculate_error()[source]#
get_data()[source]#
get_error_distribution()[source]#
plot_error_distribution()[source]#
plot_error_over_time()[source]#
plot_lightcurves_mismatch()[source]#
plot_worst_lightcurves()[source]#
print_correlations()[source]#
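The default metric_name "Linf" above suggests a per-lightcurve maximum absolute deviation; here is a sketch of that metric under this assumption (not the Benchmarker source), used to rank the worst lightcurves:

```python
import numpy as np

rng = np.random.default_rng(3)
truth = rng.normal(size=(20, 100))     # 20 test lightcurves, 100 time points
prediction = truth + rng.normal(scale=0.05, size=truth.shape)  # surrogate

# L-infinity error per lightcurve: max absolute deviation over time.
linf_error = np.abs(prediction - truth).max(axis=1)
worst = int(np.argmax(linf_error))     # index of the worst lightcurve
```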