Source code for fiesta.train.neuralnets

import time

import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, Int

import flax
from flax import linen as nn  # Linen API
from flax.training.train_state import TrainState
from ml_collections import ConfigDict
import optax
import pickle

import fiesta.train.nn_architectures as nn  # re-binds nn: nn.MLP, nn.CVAE and nn.Decoder below refer to fiesta's architectures, not flax.linen
from fiesta.logging import logger

###############
### CONFIGS ###
###############

class NeuralnetConfig(ConfigDict):
    """Configuration for a neural network model. For type hinting."""

    name: str
    output_size: Int
    hidden_layer_sizes: list[int]
    layer_sizes: list[int]
    latent_dim: Int
    learning_rate: Float
    batch_size: Int
    nb_epochs: Int
    nb_report: Int

    def __init__(self,
                 name: str = "MLP",
                 output_size: int = 10,
                 hidden_layer_sizes: list[int] = [64, 128, 64],
                 latent_dim: int = 20,
                 learning_rate: Float = 1e-3,
                 weight_decay: Float = 0.0,
                 batch_size: int = 128,
                 nb_epochs: Int = 1_000,
                 nb_report: Int = None,
                 dropout_rate: float = 0.0,
                 use_cosine_schedule: bool = False,
                 cosine_alpha: float = 0.01,
                 max_grad_norm: float = 0.0,
                 pca_smoothness_weight: float = 0.0,
                 pca_smoothness_start: int = 0):

        super().__init__()
        self.name = name
        self.output_size = output_size
        self.hidden_layer_sizes = hidden_layer_sizes
        self.layer_sizes = [*hidden_layer_sizes, output_size]
        self.latent_dim = latent_dim
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.batch_size = batch_size
        self.nb_epochs = nb_epochs
        if nb_report is None:
            nb_report = max(1, self.nb_epochs // 10)
        self.nb_report = nb_report
        self.dropout_rate = dropout_rate
        self.use_cosine_schedule = use_cosine_schedule
        self.cosine_alpha = cosine_alpha
        self.max_grad_norm = max_grad_norm
        self.pca_smoothness_weight = pca_smoothness_weight
        self.pca_smoothness_start = pca_smoothness_start
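# Usage sketch (illustrative; the values below are placeholders, not
# recommended settings):
#
#     config = NeuralnetConfig(name="MLP",
#                              output_size=10,
#                              hidden_layer_sizes=[64, 128, 64],
#                              nb_epochs=2_000)
#     config.layer_sizes   # -> [64, 128, 64, 10]
#     config.nb_report     # -> 200 (nb_epochs // 10 when left as None)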
#############
### UTILS ###
#############
def kld(mean, logvar):
    """
    Kullback-Leibler divergence of a normal distribution with arbitrary mean and
    log variance to the standard normal distribution with mean 0 and unit variance.
    """
    return 0.5 * jnp.sum(mean**2 + jnp.exp(logvar) - logvar - 1)
def bce(y, pred):
    """
    Binary cross entropy between y and the predicted array pred.
    """
    return -jnp.sum(y * jnp.log(pred) + (1 - y) * jnp.log(1 - pred))
def mse(y, pred):
    """
    Sum of squared errors between y and the predicted array pred.
    """
    return jnp.sum((y - pred)**2)
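# Quick numerical sanity check for the helpers above (illustrative values):
#
#     y    = jnp.array([0.0, 1.0])
#     pred = jnp.array([0.5, 0.5])
#     mse(y, pred)                        # -> 0.5
#     bce(y, pred)                        # -> 2 * log(2) ≈ 1.386
#     kld(jnp.zeros(2), jnp.zeros(2))     # -> 0.0 (standard normal against itself)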
def serialize(state: TrainState, config: NeuralnetConfig = None) -> dict:
    """
    Serialize function to save the model and its configuration.

    Args:
        state (TrainState): The TrainState object to be serialized.
        config (NeuralnetConfig, optional): The config to be serialized. Defaults to None.

    Returns:
        dict: Dictionary with the model parameters under "params" and the configuration under "config".
    """
    # Get state dict, which has params
    params = flax.serialization.to_state_dict(state)["params"]

    serialized_dict = {"params": params,
                       "config": config}

    return serialized_dict
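# Sketch of the resulting layout (assuming a trained TrainState `state` and a
# NeuralnetConfig `config` already exist; both names are placeholders):
#
#     serialized = serialize(state, config)
#     # serialized["params"]: nested dict of jnp arrays, one entry per layer
#     # serialized["config"]: the NeuralnetConfig, or None if not given
#     with open("model.pkl", "wb") as handle:
#         pickle.dump(serialized, handle, protocol=pickle.HIGHEST_PROTOCOL)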
################
### TRAINING ###
################
class CVAE:

    def __init__(self,
                 config: NeuralnetConfig,
                 conditional_dim: Int,
                 key: jax.random.PRNGKey = jax.random.key(21)):

        self.config = config

        net = nn.CVAE(hidden_layer_sizes=config.hidden_layer_sizes,
                      latent_dim=config.latent_dim,
                      output_size=config.output_size)

        key, subkey, subkey2 = jax.random.split(key, 3)
        params = net.init(subkey,
                          jnp.ones(config.output_size),
                          jnp.ones(conditional_dim),
                          subkey2)['params']

        if getattr(config, 'weight_decay', 0.0) > 0:
            tx = optax.adamw(config.learning_rate, weight_decay=config.weight_decay)
        else:
            tx = optax.adam(config.learning_rate)

        self.state = TrainState.create(apply_fn=net.apply, params=params, tx=tx)  # initialize the training state
    @staticmethod
    @jax.jit
    def train_step(state: TrainState,
                   train_X: Float[Array, "n_batch_train ndim_input"],
                   train_y: Float[Array, "n_batch_train ndim_output"],
                   rng: jax.random.PRNGKey,
                   val_X: Float[Array, "n_batch_val ndim_input"] = None,
                   val_y: Float[Array, "n_batch_val ndim_output"] = None,
                   ) -> tuple[TrainState, Float[Array, "n_batch_train"], Float[Array, "n_batch_val"], jax.random.PRNGKey]:

        def apply_model(state, X, y, z_rng):
            def loss_fn(params):
                reconstructed_y, mean, logvar = state.apply_fn({'params': params}, y, X, z_rng)
                mse_loss = jnp.mean(jax.vmap(mse)(y, reconstructed_y))  # mean squared error loss
                kld_loss = jnp.mean(jax.vmap(kld)(mean, logvar))        # KLD loss
                return mse_loss + kld_loss

            grad_fn = jax.value_and_grad(loss_fn)
            loss, grads = grad_fn(state.params)
            return loss, grads

        rng, z_rng = jax.random.split(rng)
        train_loss, grads = apply_model(state, train_X, train_y, z_rng)

        if val_X is not None:
            rng, z_rng = jax.random.split(rng)
            val_loss, _ = apply_model(state, val_X, val_y, z_rng)
        else:
            val_loss = jnp.zeros_like(train_loss)

        # Update parameters
        state = state.apply_gradients(grads=grads)

        return state, train_loss, val_loss, rng
    def train_loop(self,
                   train_X: Float[Array, "n_batch_train ndim_input"],
                   train_y: Float[Array, "n_batch_train ndim_output"],
                   val_X: Float[Array, "n_batch_val ndim_input"] = None,
                   val_y: Float[Array, "n_batch_val ndim_output"] = None,
                   verbose: bool = True):

        train_losses, val_losses = [], []
        rng = jax.random.key(2025)

        state = self.state
        best_state = state
        best_val_loss = jnp.inf

        start = time.time()
        for i in range(self.config.nb_epochs):
            # Do a single step
            rng, subkey = jax.random.split(rng)
            state, train_loss, val_loss, rng = self.train_step(state, train_X, train_y, subkey, val_X, val_y)

            # Save the losses
            train_losses.append(train_loss)
            val_losses.append(val_loss)

            # Track the best model by validation loss
            if val_X is not None and val_loss < best_val_loss:
                best_val_loss = val_loss
                best_state = state

            # Report once in a while
            if i % self.config.nb_report == 0 and verbose:
                logger.info(f"Train loss at step {i+1}: {train_loss}")
                logger.info(f"Valid loss at step {i+1}: {val_loss}")
                logger.info(f"Best valid loss so far: {best_val_loss}")
                logger.info(f"Learning rate: {self.config.learning_rate}")
                logger.info("---")

        end = time.time()
        if verbose:
            logger.info(f"Training for {self.config.nb_epochs} epochs took {end - start} seconds.")
            if val_X is not None:
                logger.info(f"Best validation loss: {best_val_loss}")

        self.trained_state = best_state if val_X is not None else state

        return self.trained_state, train_losses, val_losses
    def save_model(self, outfile: str = "my_flax_model.pkl"):
        """
        Serialize and save the model to a file.

        Raises:
            ValueError: If the provided file extension is not .pkl or .pickle.

        Args:
            outfile (str, optional): The pickle file to which we save the serialized model. Defaults to "my_flax_model.pkl".
        """
        if not outfile.endswith(".pkl") and not outfile.endswith(".pickle"):
            raise ValueError("For now, only .pkl or .pickle extensions are supported.")

        serialized_dict = serialize(self.trained_state, self.config)
        with open(outfile, 'wb') as handle:
            pickle.dump(serialized_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    @staticmethod
    def load_model(filename: str) -> tuple[TrainState, NeuralnetConfig]:
        """
        Load a model from a file.
        TODO: this is very cumbersome now and must be massively improved in the future.

        Args:
            filename (str): Filename of the model to be loaded.

        Raises:
            ValueError: If there is something wrong with loading, since lots of things can go wrong here.

        Returns:
            tuple[TrainState, NeuralnetConfig]: The TrainState object loaded from the file and the NeuralnetConfig object.
        """
        with open(filename, 'rb') as handle:
            loaded_dict = pickle.load(handle)

        config: NeuralnetConfig = loaded_dict["config"]
        params = loaded_dict["params"]

        net = nn.Decoder(layer_sizes=[*config.hidden_layer_sizes[::-1], config.output_size])

        # Create train state without optimizer
        state = TrainState.create(apply_fn=net.apply, params=params["decoder"], tx=optax.adam(config.learning_rate))

        return state, config
    @staticmethod
    def load_full_model(filename: str) -> tuple[TrainState, NeuralnetConfig]:
        with open(filename, "rb") as handle:
            loaded_dict = pickle.load(handle)

        config: NeuralnetConfig = loaded_dict["config"]
        params = loaded_dict["params"]

        net = nn.CVAE(hidden_layer_sizes=config.hidden_layer_sizes,
                      output_size=config.output_size)

        # Create train state without optimizer
        state = TrainState.create(apply_fn=net.apply, params=params, tx=optax.adam(config.learning_rate))

        return state, config
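# Illustrative CVAE training sketch. The arrays below are random noise with
# placeholder shapes, shown only to illustrate the expected call pattern:
#
#     config = NeuralnetConfig(name="CVAE", output_size=50,
#                              hidden_layer_sizes=[128, 64], latent_dim=10,
#                              nb_epochs=500)
#     cvae = CVAE(config, conditional_dim=4)
#     train_X = jax.random.normal(jax.random.key(0), (256, 4))    # conditionals
#     train_y = jax.random.normal(jax.random.key(1), (256, 50))   # targets
#     state, train_losses, val_losses = cvae.train_loop(train_X, train_y)
#     cvae.save_model("my_cvae.pkl")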
class MLP:

    def __init__(self,
                 config: NeuralnetConfig,
                 input_ndim: Int,
                 key: jax.random.PRNGKey = jax.random.key(21)):

        self.config = config
        dropout_rate = getattr(config, 'dropout_rate', 0.0)

        net = nn.MLP(layer_sizes=config.layer_sizes, dropout_rate=dropout_rate)

        key, subkey = jax.random.split(key)
        params = net.init(subkey, jnp.ones(input_ndim), train=False)['params']

        if getattr(config, 'weight_decay', 0.0) > 0:
            tx = optax.adamw(config.learning_rate, weight_decay=config.weight_decay)
        else:
            tx = optax.adam(config.learning_rate)

        self.state = TrainState.create(apply_fn=net.apply, params=params, tx=tx)
    @staticmethod
    @jax.jit
    def train_step(state, batch_X, batch_y, dropout_rng, component_weights):
        def loss_fn(params):
            pred_y = state.apply_fn({'params': params}, batch_X, train=True, rngs={'dropout': dropout_rng})
            per_sample = jax.vmap(
                lambda y, p: jnp.sum(component_weights * (y - p) ** 2)
            )(batch_y, pred_y)
            return jnp.mean(per_sample)

        loss, grads = jax.value_and_grad(loss_fn)(state.params)
        state = state.apply_gradients(grads=grads)
        return state, loss
    @staticmethod
    @jax.jit
    def eval_step(state, X, y, component_weights):
        pred_y = state.apply_fn({'params': state.params}, X, train=False)
        per_sample = jax.vmap(
            lambda y, p: jnp.sum(component_weights * (y - p) ** 2)
        )(y, pred_y)
        return jnp.mean(per_sample)
    def train_loop(self,
                   train_X: Float[Array, "n_batch_train ndim_input"],
                   train_y: Float[Array, "n_batch_train ndim_output"],
                   val_X: Float[Array, "n_batch_val ndim_input"] = None,
                   val_y: Float[Array, "n_batch_val ndim_output"] = None,
                   verbose: bool = True):

        total_steps = self.config.nb_epochs

        # Component weights for smoothness regularization
        n_pca = train_y.shape[1]
        sw = getattr(self.config, 'pca_smoothness_weight', 0.0)
        ss = getattr(self.config, 'pca_smoothness_start', 0)
        if sw > 0:
            decay = jnp.exp(sw * jnp.clip(jnp.arange(n_pca) - ss, 0, None) / n_pca)
            component_weights = jnp.ones(n_pca).at[ss:].set(decay[ss:])
        else:
            component_weights = jnp.ones(n_pca)

        # Optionally rebuild optimizer with cosine LR schedule
        if getattr(self.config, 'use_cosine_schedule', False):
            schedule_fn = optax.cosine_decay_schedule(
                init_value=self.config.learning_rate,
                decay_steps=total_steps,
                alpha=getattr(self.config, 'cosine_alpha', 0.01))
            parts = []
            if getattr(self.config, 'max_grad_norm', 0.0) > 0:
                parts.append(optax.clip_by_global_norm(self.config.max_grad_norm))
            wd = getattr(self.config, 'weight_decay', 0.0)
            if wd > 0:
                parts.append(optax.adamw(schedule_fn, weight_decay=wd))
            else:
                parts.append(optax.adam(schedule_fn))
            tx = optax.chain(*parts) if len(parts) > 1 else parts[0]
            self.state = TrainState.create(
                apply_fn=self.state.apply_fn,
                params=self.state.params,
                tx=tx)

        train_losses, val_losses = [], []

        state = self.state
        best_state = state
        best_val_loss = jnp.inf

        rng = jax.random.key(2025)
        start = time.time()
        for i in range(self.config.nb_epochs):
            rng, dropout_rng = jax.random.split(rng)
            state, epoch_loss = self.train_step(
                state, train_X, train_y, dropout_rng, component_weights)

            # Evaluate on full validation set
            if val_X is not None:
                val_loss = self.eval_step(state, val_X, val_y, component_weights)
            else:
                val_loss = jnp.zeros_like(epoch_loss)

            train_losses.append(epoch_loss)
            val_losses.append(val_loss)

            # Track the best model by validation loss
            if val_X is not None and val_loss < best_val_loss:
                best_val_loss = val_loss
                best_state = state

            # Report once in a while
            if i % self.config.nb_report == 0 and verbose:
                logger.info(f"Train loss at step {i+1}: {epoch_loss}")
                logger.info(f"Valid loss at step {i+1}: {val_loss}")
                logger.info(f"Best valid loss so far: {best_val_loss}")
                logger.info(f"Learning rate: {self.config.learning_rate}")
                logger.info("---")

        end = time.time()
        if verbose:
            logger.info(f"Training for {self.config.nb_epochs} epochs took {end - start} seconds.")
            if val_X is not None:
                logger.info(f"Best validation loss: {best_val_loss}")

        self.trained_state = best_state if val_X is not None else state

        return self.trained_state, train_losses, val_losses
    def save_model(self, outfile: str = "my_flax_model.pkl"):
        """
        Serialize and save the model to a file.

        Raises:
            ValueError: If the provided file extension is not .pkl or .pickle.

        Args:
            outfile (str, optional): The pickle file to which we save the serialized model. Defaults to "my_flax_model.pkl".
        """
        if not outfile.endswith(".pkl") and not outfile.endswith(".pickle"):
            raise ValueError("For now, only .pkl or .pickle extensions are supported.")

        serialized_dict = serialize(self.trained_state, self.config)
        with open(outfile, 'wb') as handle:
            pickle.dump(serialized_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    @staticmethod
    def load_model(filename: str) -> tuple[TrainState, NeuralnetConfig]:
        """
        Load a model from a file.
        TODO: this is very cumbersome now and must be massively improved in the future.

        Args:
            filename (str): Filename of the model to be loaded.

        Raises:
            ValueError: If there is something wrong with loading, since lots of things can go wrong here.

        Returns:
            tuple[TrainState, NeuralnetConfig]: The TrainState object loaded from the file and the NeuralnetConfig object.
        """
        with open(filename, 'rb') as handle:
            loaded_dict = pickle.load(handle)

        config: NeuralnetConfig = loaded_dict["config"]
        params = loaded_dict["params"]

        dropout_rate = getattr(config, 'dropout_rate', 0.0)
        net = nn.MLP(config.layer_sizes, dropout_rate=dropout_rate)

        # Create train state without optimizer
        state = TrainState.create(apply_fn=net.apply, params=params, tx=optax.adam(config.learning_rate))

        return state, config
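# Minimal end-to-end sketch for the MLP surrogate: fit on random placeholder
# data, save, and reload. The shapes, epoch count, and file name below are
# illustrative assumptions, not part of the library API.
if __name__ == "__main__":
    example_config = NeuralnetConfig(name="MLP",
                                     output_size=20,
                                     hidden_layer_sizes=[64, 64],
                                     nb_epochs=200,
                                     use_cosine_schedule=True)
    mlp = MLP(example_config, input_ndim=5)

    key_X, key_y = jax.random.split(jax.random.key(0))
    train_X = jax.random.normal(key_X, (512, 5))   # placeholder inputs (e.g. physical parameters)
    train_y = jax.random.normal(key_y, (512, 20))  # placeholder targets (e.g. PCA coefficients)

    trained_state, train_losses, _ = mlp.train_loop(train_X, train_y, verbose=False)
    mlp.save_model("example_mlp.pkl")

    # Reload the saved network and run a single forward pass
    state, loaded_config = MLP.load_model("example_mlp.pkl")
    prediction = state.apply_fn({'params': state.params}, train_X[0], train=False)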