Source code for fiesta.train.nn_architectures
from typing import Sequence, Callable
import jax
import jax.numpy as jnp
from jaxtyping import Array, Int
from flax import linen as nn # Linen API
#####################
### ARCHITECTURES ###
#####################
[docs]
class BaseNeuralnet(nn.Module):
    """Abstract parent class for network architectures.

    Declares the two fields inherited by subclasses: the layer sizes and the
    activation function. ``setup`` and ``__call__`` are placeholders that only
    raise ``NotImplementedError``; subclasses must override both.
    """

    # Sizes of the layers, in order.
    layer_sizes: Sequence[int]
    # Activation function; ReLU unless overridden.
    act_func: Callable = nn.relu

    def setup(self):
        """Create the submodules. Not implemented on the base class."""
        raise NotImplementedError()

    def __call__(self, x):
        """Run the forward pass on ``x``. Not implemented on the base class."""
        raise NotImplementedError()
[docs]
class MLP(BaseNeuralnet):
    """Multi-layer perceptron: a feedforward stack of ``nn.Dense`` layers.

    One Dense layer is created per entry of ``layer_sizes``. Every layer except
    the last is followed by ``act_func`` and, if ``dropout_rate`` is positive,
    by a dropout layer. The last layer is applied without activation.
    """

    # Rate handed to nn.Dropout; values <= 0.0 skip dropout entirely.
    dropout_rate: float = 0.0

    def setup(self):
        """Instantiate one Dense layer for every entry of ``layer_sizes``."""
        self.layers = [nn.Dense(width) for width in self.layer_sizes]

    @nn.compact
    def __call__(self, x: Array, train: bool = False):
        """Apply the network to ``x``.

        Args:
            x: Input array passed to the first Dense layer.
            train: When True the dropout layers are stochastic
                (``deterministic=False``); when False they are deterministic.

        Returns:
            The output of the final Dense layer.
        """
        out = x
        for hidden in self.layers[:-1]:
            out = self.act_func(hidden(out))
            if self.dropout_rate > 0.0:
                # The Dropout module is constructed inline, which is why this
                # method carries the @nn.compact decorator.
                dropout = nn.Dropout(
                    rate=self.dropout_rate,
                    deterministic=not train,
                )
                out = dropout(out)
        # Output layer: linear part only, no activation and no dropout.
        return self.layers[-1](out)
[docs]
class Encoder(nn.Module):
    """Two parallel Dense stacks producing ``mu`` and ``logvar`` from one input.

    Both stacks are built from the same ``layer_sizes`` but hold separate
    layers. Within each stack every layer except the last is followed by
    ``act_func``; the last layer is applied without activation.
    """

    # Sizes of the Dense layers; used for both the mu and the logvar stack.
    layer_sizes: Sequence[int]
    # Activation applied after every layer except the last of each stack.
    act_func: Callable = nn.relu

    def setup(self):
        """Create the two independent Dense stacks."""
        self.mu_layers = [nn.Dense(n) for n in self.layer_sizes]
        self.logvar_layers = [nn.Dense(n) for n in self.layer_sizes]

    def __call__(self, y: Array):
        """Pass ``y`` through both stacks.

        Args:
            y: Input array; the same array is fed to both stacks.

        Returns:
            Tuple ``(mu, logvar)`` holding the outputs of the two stacks.
        """
        # All submodules are created in setup(), so no @nn.compact is needed.
        # `y` is only ever rebound, never mutated, so both stacks can start
        # from the same array without copying it first.
        outputs = []
        for stack in (self.mu_layers, self.logvar_layers):
            h = y
            for layer in stack[:-1]:
                h = self.act_func(layer(h))
            # Final layer of the stack: linear part only.
            outputs.append(stack[-1](h))
        mu, logvar = outputs
        return mu, logvar
[docs]
class Decoder(MLP):
    """MLP variant whose forward pass takes only the input ``z``.

    The Dense layers come from the inherited ``setup``. This forward pass has
    no ``train`` argument and never applies dropout, so the inherited
    ``dropout_rate`` field has no effect here.
    """

    @nn.compact
    def __call__(self, z: Array):
        """Apply the Dense stack to ``z`` and return the final layer's output."""
        out = z
        for dense in self.layers[:-1]:
            # Linear part of the layer, then the configured activation.
            out = self.act_func(dense(out))
        # The output layer contributes its linear part only.
        output_layer = self.layers[-1]
        return output_layer(out)
[docs]
class CVAE(nn.Module):
    """Conditional Variational Autoencoder consisting of an Encoder and a Decoder.

    The encoder uses ``hidden_layer_sizes`` followed by ``latent_dim``; the
    decoder uses the reversed ``hidden_layer_sizes`` followed by
    ``output_size``.
    """

    # Plain Python ints: jaxtyping's `Int` is an array-dtype annotation and
    # does not describe an integer hyperparameter.
    hidden_layer_sizes: Sequence[int]  # used for both the encoder and decoder
    output_size: int
    latent_dim: int = 20

    def setup(self):
        """Build the encoder and the mirrored decoder."""
        self.encoder = Encoder([*self.hidden_layer_sizes, self.latent_dim])
        self.decoder = Decoder(
            layer_sizes=[*self.hidden_layer_sizes[::-1], self.output_size],
            act_func=nn.relu,
        )

    def __call__(self, y: Array, x: Array, z_rng: Array):
        """Encode ``y`` together with ``x``, sample a latent vector, and decode it.

        ``x`` is appended along the last axis to both the encoder input and the
        sampled latent vector, i.e. it acts as the conditioning input.

        Args:
            y: Array that is encoded (after concatenation with ``x``).
            x: Array concatenated to ``y`` and to the latent sample.
            z_rng: JAX PRNG key used to draw the reparametrization noise.

        Returns:
            Tuple ``(reconstructed_y, mu, logvar)``: the decoder output and the
            two encoder outputs.
        """
        # JAX arrays are immutable, so `x` can be reused without copying.
        encoder_input = jnp.concatenate([y, x], axis=-1)
        mu, logvar = self.encoder(encoder_input)

        # Reparametrize: z = mu + eps * std, with eps ~ N(0, 1).
        std = jnp.exp(0.5 * logvar)
        eps = jax.random.normal(z_rng, logvar.shape)
        z = mu + eps * std

        decoder_input = jnp.concatenate([z, x], axis=-1)
        reconstructed_y = self.decoder(decoder_input)
        return reconstructed_y, mu, logvar
[docs]
class CNN(nn.Module):
    """Convolutional Neural Network.

    A stack of Dense layers whose output is reshaped to
    ``(-1, spatial, spatial, dense_layer_sizes[-1])``, followed by a stack of
    2D convolutions. The single remaining channel is dropped and the result is
    bilinearly resized to ``output_shape``.
    """

    # Plain Python ints: jaxtyping's `Int` is an array-dtype annotation and
    # does not describe an integer hyperparameter.
    dense_layer_sizes: Sequence[int]
    kernel_sizes: Sequence[int]
    conv_layer_sizes: Sequence[int]
    output_shape: tuple[int, int]
    spatial: int = 32
    act_func: Callable = nn.relu

    def setup(self):
        """Validate the configuration and create the Dense and Conv layers.

        Raises:
            ValueError: If the last dense size differs from the first conv
                size, if the last conv size is not 1, or if ``kernel_sizes``
                has fewer entries than ``conv_layer_sizes``.
        """
        if self.dense_layer_sizes[-1] != self.conv_layer_sizes[0]:
            raise ValueError("Final dense layer must be equally large as first convolutional layer.")
        if self.conv_layer_sizes[-1] != 1:
            raise ValueError("Last convolutional layer must be of size 1 to predict 2D array.")
        # zip() below stops at the shorter sequence. Too few kernel sizes would
        # silently drop trailing conv layers, including the size-1 output layer
        # checked above, so reject that configuration explicitly.
        if len(self.kernel_sizes) < len(self.conv_layer_sizes):
            raise ValueError("kernel_sizes must provide a kernel size for every entry of conv_layer_sizes.")

        # Build the complete list first and assign the attribute once.
        dense_layers = [nn.Dense(n) for n in self.dense_layer_sizes[:-1]]
        # The last dense layer should create an array that can be reshaped into
        # spatial and channel parts.
        dense_layers.append(nn.Dense(self.dense_layer_sizes[-1] * self.spatial**2))
        self.dense_layers = dense_layers

        self.conv_layers = [
            nn.Conv(features=f, kernel_size=(k, k))
            for f, k in zip(self.conv_layer_sizes, self.kernel_sizes)
        ]

    def __call__(self, x: Array):
        """Map ``x`` to an array of shape ``(batch, *output_shape)``.

        Args:
            x: Input array passed to the first Dense layer.

        Returns:
            The network output, resized to ``(batch, *output_shape)``.
        """
        # Dense part: every dense layer, including the last, is activated.
        for layer in self.dense_layers:
            x = layer(x)
            x = self.act_func(x)

        # Split the flat dense output into spatial and channel axes.
        x = x.reshape((-1, self.spatial, self.spatial, self.dense_layer_sizes[-1]))

        for layer in self.conv_layers[:-1]:
            x = layer(x)
            x = self.act_func(x)
        # Only apply the convolution part of the last convolutional layer.
        x = self.conv_layers[-1](x)

        # setup() enforces a single output channel; drop that axis.
        x = x[:, :, :, 0]
        # Resize the NN output to the desired output shape.
        x = jax.image.resize(x, shape=(x.shape[0], *self.output_shape), method="bilinear")
        return x