Comparing Different VolumePreservingAttention Mechanisms

In the section on volume-preserving attention we mentioned two ways of computing volume-preserving attention: one where we compute the correlations with a skew-symmetric matrix and one where we compute the correlations with an arbitrary matrix. Here we compare the two approaches. When calling the VolumePreservingAttention layer we can choose between the skew-symmetric and the arbitrary weighting by setting the keyword argument skew_sym = true or skew_sym = false, respectively.
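The difference between the two weightings can be illustrated in plain Julia. The sketch below assumes that the correlations of an input sequence $Z$ (one column per time step) are computed as $C = Z^TAZ$; the sizes `d` and `len` and the use of the Cayley transform are illustrative assumptions, not a description of the layer's internals. With a skew-symmetric $A$ the matrix $C$ is skew-symmetric as well, and the Cayley transform of a skew-symmetric matrix is orthogonal with determinant one, i.e. volume-preserving. With an arbitrary $A$ this structure is lost, so a skew-symmetric (or otherwise volume-preserving) matrix has to be obtained some other way; that step is not shown here.

```julia
using LinearAlgebra

# Toy sizes (assumptions): embedding dimension d and sequence length len
d, len = 3, 4
Z = randn(d, len)                 # one input sequence: column j is time step j

B = randn(d, d)
A_skew = B - B'                   # skew-symmetric weighting: A_skew' == -A_skew
A_arb  = randn(d, d)              # arbitrary weighting: no structure imposed

C_skew = Z' * A_skew * Z          # len × len correlation matrices
C_arb  = Z' * A_arb * Z

# (Z' A Z)' = Z' A' Z, so C inherits skew-symmetry from A
skew_ok     = C_skew ≈ -C_skew'   # true
arb_is_skew = C_arb ≈ -C_arb'     # false for a generic A_arb

# Cayley transform: maps a skew-symmetric matrix to an orthogonal one with det = 1
cayley_transform(C) = (I - C) \ (I + C)
Λ = cayley_transform(C_skew)
orthogonal_ok = Λ' * Λ ≈ Matrix(I, len, len)   # true
volume_ok     = det(Λ) ≈ 1                      # true
```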

Here we demonstrate the differences between the two approaches for computing correlations. For this we first generate a training set consisting of two curves: (i) a sine curve and (ii) a cosine curve.

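using GeometricMachineLearning

# 1000 time steps (t = 0.0, 0.1, ..., 99.9); the third axis indexes the two curves
sine_cosine = zeros(1, 1000, 2)
sine_cosine[1, :, 1] .= sin.(0.:.1:99.9)
sine_cosine[1, :, 2] .= cos.(0.:.1:99.9)

# all computations are done in half precision
const T = Float16
const dl = DataLoader(T.(sine_cosine); suppress_info = true)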

The third axis (the parameter axis) has length two, meaning we have two different curves. The data look like this:

Figure: The data we treat here contain two different curves.
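The layout of the raw array can be checked directly, before it is handed to the DataLoader. The snippet rebuilds `sine_cosine` exactly as above; reading the first axis as the dimension of the system and the second as the time axis is our interpretation of the layout, consistent with the statement that the third axis is the parameter axis.

```julia
ts = 0.0:0.1:99.9                    # same time grid as above
sine_cosine = zeros(1, 1000, 2)
sine_cosine[1, :, 1] .= sin.(ts)
sine_cosine[1, :, 2] .= cos.(ts)

# axis 1: dimension of the system (both curves are scalar)
# axis 2: time steps
# axis 3: parameters, i.e. which of the two curves
dims = size(sine_cosine)                   # (1, 1000, 2)
first_cos_steps = sine_cosine[1, 1:3, 2]   # cos(0.0), cos(0.1), cos(0.2)
```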

We want to train a single neural network on both these curves. We noted earlier that a simple feedforward neural network cannot do this. Here we compare three networks of the following form:

\[\mathtt{network} = \mathcal{NN}_d\circ\Psi\circ\mathcal{NN}_u,\]

where $\mathcal{NN}_u$ refers to a neural network that scales up and $\mathcal{NN}_d$ refers to a neural network that scales down. The up and down scaling is done with simple dense layers:

\[\mathcal{NN}_u(x) = \mathrm{tanh}(a_ux + b_u) \text{ and } \mathcal{NN}_d(x) = a_d^Tx + b_d,\]

where $a_u, b_u, a_d\in\mathbb{R}^\mathrm{ud}$ and $b_d\in\mathbb{R}$ is a scalar; $\mathrm{ud}$ denotes the upscaling dimension. For $\Psi$ we consider three different choices:

  1. a volume-preserving attention with skew-symmetric weighting,
  2. a volume-preserving attention with arbitrary weighting,
  3. an identity layer.
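The composition $\mathcal{NN}_d\circ\Psi\circ\mathcal{NN}_u$ can be sketched in plain Julia. The parameters below are random and untrained, and the names `NN_u`, `NN_d` and `toy_network` are hypothetical; they mirror the formulas above but are not the library code. The dense maps act on each time step separately, so only $\Psi$ can mix information across the sequence: with $\Psi$ the identity (choice 3), the output at the first time step does not depend on any later input.

```julia
# Hypothetical, untrained parameters for upscaling dimension ud
ud = 2
a_u, b_u, a_d = randn(ud), randn(ud), randn(ud)
b_d = randn()

NN_u(x::Real) = tanh.(a_u .* x .+ b_u)        # ℝ → ℝ^ud
NN_d(y::AbstractVector) = a_d' * y + b_d       # ℝ^ud → ℝ

# Apply the network to a whole sequence; the dense parts act column by column
function toy_network(xs::AbstractVector, Ψ)
    Y = reduce(hcat, NN_u.(xs))    # ud × length(xs), one column per time step
    Y = Ψ(Y)                       # Ψ may mix columns (attention) or not (identity)
    [NN_d(col) for col in eachcol(Y)]
end

xs = sin.(0.0:0.1:0.2)             # a sequence of length 3
out = toy_network(xs, identity)

# With Ψ = identity, changing a later input leaves the first output unchanged
xs_changed = [xs[1], xs[2], 5.0]
unchanged = toy_network(xs_changed, identity)[1] ≈ out[1]   # true
```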

We further choose a sequence length of 3 (i.e. the network always sees the last 3 time steps) and always predict one step into the future (i.e. the prediction window is set to 1):

const seq_length = 3
const prediction_window = 1
const upscale_dimension_1 = 2

function set_up_networks(upscale_dimension::Int = upscale_dimension_1)
    # Ψ = volume-preserving attention with skew-symmetric weighting
    model_skew = Chain( Dense(1, upscale_dimension, tanh),
                        VolumePreservingAttention(upscale_dimension, seq_length; skew_sym = true),
                        Dense(upscale_dimension, 1, identity; use_bias = true)
                        )

    # Ψ = volume-preserving attention with arbitrary weighting
    model_arb  = Chain( Dense(1, upscale_dimension, tanh),
                        VolumePreservingAttention(upscale_dimension, seq_length; skew_sym = false),
                        Dense(upscale_dimension, 1, identity; use_bias = true)
                        )

    # Ψ = identity: no attention layer, i.e. a plain feedforward network
    model_comp = Chain( Dense(1, upscale_dimension, tanh),
                        Dense(upscale_dimension, 1, identity; use_bias = true)
                        )

    # wrap the models as neural networks on the CPU with element type T
    nn_skew = NeuralNetwork(model_skew, CPU(), T)
    nn_arb  = NeuralNetwork(model_arb,  CPU(), T)
    nn_comp = NeuralNetwork(model_comp, CPU(), T)

    nn_skew, nn_arb, nn_comp
end

nn_skew, nn_arb, nn_comp = set_up_networks()
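The sequence length and the prediction window determine how input/target pairs are cut out of each curve. The helper `make_windows` below is hypothetical and assumes that the target window starts directly after the input window; the library's batching may index the data differently. With 1000 time steps, a sequence length of 3 and a prediction window of 1, this scheme yields 997 pairs per curve.

```julia
# Hypothetical helper: cut (input, target) pairs out of one curve.
# Input: seq_length consecutive steps; target: the prediction_window steps that follow.
function make_windows(x::AbstractVector, seq_length::Int, prediction_window::Int)
    n_pairs = length(x) - seq_length - prediction_window + 1
    inputs  = [x[i:i+seq_length-1] for i in 1:n_pairs]
    targets = [x[i+seq_length:i+seq_length+prediction_window-1] for i in 1:n_pairs]
    inputs, targets
end

x = collect(1:6)                        # stand-in for a curve with 6 time steps
inputs, targets = make_windows(x, 3, 1)
# inputs  == [[1, 2, 3], [2, 3, 4], [3, 4, 5]]
# targets == [[4], [5], [6]]
```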

We expect the third network not to learn anything useful, since it cannot resolve time series data: a regular feedforward network only ever sees one time step at a time.
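A short calculation shows why seeing a single time step is not enough. At $t = \pi/4$ the sine and the cosine curve take the same value, yet their next values (one step of size $0.1$ later) differ, so no map $x_t\mapsto x_{t+1}$ can fit both curves. The previous values differ as well, which is the kind of history a network that sees several time steps can exploit.

```julia
h = 0.1                                 # time step of the training data
t = π / 4                               # a time where both curves take the same value

same_now = sin(t) ≈ cos(t)              # true: both equal √2/2
gap_next = sin(t + h) - cos(t + h)      # ≈ 0.141: the next values differ
gap_prev = sin(t - h) - cos(t - h)      # ≈ -0.141: so do the previous ones
```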

Next we train the networks (here we pick a batch size of 30 and train for 1000 epochs):

Figure: The training losses for the three networks.

Looking at the training errors, we see that the network with the skew-symmetric weighting is stuck at a relatively high error, whereas the loss of the network with the arbitrary weighting decreases to a significantly lower level. The feedforward network without the attention mechanism is not able to learn anything useful (as expected).

Before we can use the trained neural networks for prediction, we have to turn them into TransformerIntegrators or NeuralNetworkIntegrators[1]:

# the first seq_length time steps of the second (cosine) curve
initial_condition = dl.input[:, 1:seq_length, 2]

function make_networks_neural_network_integrators(nn_skew, nn_arb, nn_comp)
    nn_skew = NeuralNetwork(GeometricMachineLearning.DummyTransformer(seq_length),
                            nn_skew.model,
                            params(nn_skew),
                            CPU())
    nn_arb  = NeuralNetwork(GeometricMachineLearning.DummyTransformer(seq_length),
                            nn_arb.model,
                            params(nn_arb),
                            CPU())
    nn_comp = NeuralNetwork(GeometricMachineLearning.DummyNNIntegrator(),
                            nn_comp.model,
                            params(nn_comp),
                            CPU())

    nn_skew, nn_arb, nn_comp
end

nn_skew, nn_arb, nn_comp = make_networks_neural_network_integrators(nn_skew, nn_arb, nn_comp)
nothing
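Once wrapped as integrators, the networks are used autoregressively: each prediction is appended to the trajectory and the window of the last `seq_length` values moves forward by one step. The sketch below shows this loop in plain Julia. `toy_rollout` and `toy_predict` are hypothetical stand-ins, not library functions; in place of a trained network, `toy_predict` is the exact linear recurrence $x_{k+1} = 2\cos(h)\,x_k - x_{k-1}$ satisfied by sampled sine and cosine curves (it only uses the last two entries of the window).

```julia
# Hypothetical rollout: `predict` maps the most recent window to the next value
function toy_rollout(predict, initial::AbstractVector, n_steps::Int)
    traj = collect(initial)
    k = length(initial)                   # window length (plays the role of seq_length)
    while length(traj) < n_steps
        window = traj[end-k+1:end]        # the most recent k values
        push!(traj, predict(window))      # predict one step (prediction window 1)
    end
    traj
end

h = 0.1
# Stand-in for a trained network: exact recurrence for x_k = cos(k h)
toy_predict(w) = 2cos(h) * w[end] - w[end-1]

initial = cos.(h .* (0:2))                # first three steps, like initial_condition
traj = toy_rollout(toy_predict, initial, 40)
```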

Figure: Comparing the two volume-preserving attention mechanisms for 40 points.

In the plot above we can see that the network with the arbitrary weighting performs much better; even though the red line does not fit the purple line perfectly, it at least qualitatively reflects the training data. We can also plot the predictions for longer time intervals:

Figure: Comparing the two volume-preserving attention mechanisms for 400 points.

This advantage of the volume-preserving attention with arbitrary weighting may, however, be due to the fact that the skew-symmetric weighting has far fewer learnable parameters: an $n\times n$ skew-symmetric matrix has only $n(n-1)/2$ free entries, whereas an arbitrary one has $n^2$ (e.g. 3 as opposed to 9 for $n = 3$). We can increase the upscaling dimension and see how this affects the result:

const upscale_dimension_2 = 10

nn_skew, nn_arb, nn_comp = set_up_networks(upscale_dimension_2)

o_skew, o_arb, o_comp = set_up_optimizers(nn_skew, nn_arb, nn_comp)
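To make the gap in parameter counts concrete, the functions below count the free entries of an $n\times n$ weighting matrix. Taking $n$ to be the upscaling dimension is our assumption about the shape of the weighting; the dense layers contribute the same number of parameters to both attention networks and are not counted. The names `n_skew` and `n_arb` are hypothetical.

```julia
# Free entries of an n×n weighting matrix
n_skew(n) = n * (n - 1) ÷ 2    # skew-symmetric: the entries below the diagonal determine it
n_arb(n)  = n^2                # arbitrary: every entry is free

for n in (2, 3, 10)
    println("n = $n: skew-symmetric $(n_skew(n)), arbitrary $(n_arb(n))")
end
# prints
# n = 2: skew-symmetric 1, arbitrary 4
# n = 3: skew-symmetric 3, arbitrary 9
# n = 10: skew-symmetric 45, arbitrary 100
```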

Figure: Comparison for 40 points, but with an upscaling dimension of ten.

initial_condition = dl.input[:, 1:seq_length, 2]

nn_skew, nn_arb, nn_comp = make_networks_neural_network_integrators(nn_skew, nn_arb, nn_comp)

fig_dark, fig_light, ax_dark, ax_light = produce_validation_plot(40, nn_skew, nn_arb, nn_comp)

And for a longer time interval:

fig_dark, fig_light, ax_dark, ax_light = produce_validation_plot(200, nn_skew, nn_arb, nn_comp)


nothing

With an upscaling dimension of ten, the arbitrary weighting quickly fails, and the skew-symmetric weighting performs better on longer time scales.

Library Functions