MNIST Tutorial

In this tutorial we show how to use GeometricMachineLearning to build a vision transformer and apply it to the MNIST dataset [90], while also putting some of the weights on a manifold. These are also the results presented in [7].

We get the dataset from MLDatasets. Before using it we move it to the GPU with cu from CUDA.jl [11]:

using MLDatasets
using CUDA
train_x, train_y = MLDatasets.MNIST(split=:train)[:]
test_x, test_y = MLDatasets.MNIST(split=:test)[:]

train_x = train_x |> cu
train_y = train_y |> cu
test_x = test_x |> cu
test_y = test_y |> cu

Next we call DataLoader on these data. For this we first need to specify a patch length.

Remark

In order to apply the transformer to a data set we typically have to cast the data into a time-series format. MNIST images have $28\times28$ pixels. With a patch length of 7, each image is split into $(28/7)^2 = 16$ patches of $7\times7 = 49$ pixels each; we thus cast each image into a time series of length 16, so one image is represented by a matrix $\in\mathbb{R}^{49\times{}16}$.

const patch_length = 7
dl = DataLoader(train_x, train_y; patch_length = patch_length, suppress_info = true)

Here we called DataLoader with a tensor (the images) and a vector of integers (the targets) as input. DataLoader automatically converts the data to the correct input format for easy handling.

Internally DataLoader calls split_and_flatten which splits each image into a number of patches according to the keyword arguments patch_length and number_of_patches. We also load the test data with DataLoader:

dl_test = DataLoader(test_x, test_y; patch_length = patch_length)
[ Info: You provided a tensor and a vector as input. This will be treated as a classification problem (MNIST). Tensor axes: (i) & (ii) image axes and (iii) parameter dimension.
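
We can also apply split_and_flatten (documented under the library functions below) to a single image to inspect this preprocessing directly. A minimal sketch, using a CPU copy of the first training image:

img = MLDatasets.MNIST(split=:train)[:][1][:, :, 1]  # the first 28×28 training image (on the CPU)
patches = split_and_flatten(img; patch_length = 7, number_of_patches = 16)
size(patches)
(49, 16)

Each of the 16 columns of patches is one flattened $7\times7$ patch of the original image.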

We next define the two models that we want to train:

const n_heads = 7
const L = 16
const add_connection = false

model1 = ClassificationTransformer(dl;
                                    n_heads = n_heads,
                                    L = L,
                                    add_connection = add_connection,
                                    Stiefel = false)
model2 = ClassificationTransformer(dl;
                                    n_heads = n_heads,
                                    L = L,
                                    add_connection = add_connection,
                                    Stiefel = true)

Here we have chosen a ClassificationTransformer, i.e. a composition of a specific number of transformer layers followed by a classification layer. For the second model we set the Stiefel option to true, i.e. its weights are optimized on the Stiefel manifold.

We now have to initialize the neural network weights. This is done with the constructor for NeuralNetwork:

backend = GeometricMachineLearning.networkbackend(dl)
T = eltype(dl)
nn1 = NeuralNetwork(model1, backend, T)
nn2 = NeuralNetwork(model2, backend, T)

We still have to initialize the optimizers:

const batch_size = 2048
const n_epochs = 500
# an instance of Batch is needed for the optimizer
batch = Batch(batch_size, dl)

opt1 = Optimizer(AdamOptimizer(T), nn1)
opt2 = Optimizer(AdamOptimizer(T), nn2)

And with this we can finally perform the training:

loss_array1 = opt1(nn1, dl, batch, n_epochs, FeedForwardLoss())
loss_array2 = opt2(nn2, dl, batch, n_epochs, FeedForwardLoss())

We furthermore train two more copies of the second neural network (with weights on the manifold), one with the GradientOptimizer and one with the MomentumOptimizer:

nn3 = NeuralNetwork(model2, backend, T)
nn4 = NeuralNetwork(model2, backend, T)

opt3 = Optimizer(GradientOptimizer(T(0.001)), nn3)
opt4 = Optimizer(MomentumOptimizer(T(0.001), T(0.5)), nn4)

For training we use the same data, the same batch and the same number of epochs:

loss_array3 = opt3(nn3, dl, batch, n_epochs, FeedForwardLoss())
loss_array4 = opt4(nn4, dl, batch, n_epochs, FeedForwardLoss())
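
To compare the four training runs we can plot the recorded losses. A minimal sketch with Plots.jl (an assumption; any plotting package works):

using Plots

# plot the loss of each training run on a logarithmic scale
p = plot(loss_array1; label = "Adam", xlabel = "iteration", ylabel = "training loss", yscale = :log10)
plot!(p, loss_array2; label = "Adam + Stiefel")
plot!(p, loss_array3; label = "Gradient + Stiefel")
plot!(p, loss_array4; label = "Momentum + Stiefel")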

This produces a plot of the training loss for all four networks.

Remark

We see that the loss for the Adam optimizer without parameters on the Stiefel manifold is stuck at around 1.34, which means that the network always predicts the same class. Since a constant one-hot prediction is correct for one out of ten (uniformly distributed) targets, the error is 0 in one out of ten cases and $\sqrt{2}$ (the distance between two distinct one-hot vectors) in nine out of ten cases, giving

\[ \sqrt{2\cdot\frac{9}{10}} \approx 1.342,\]

which is what we see in the error plot.
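
We can check this number directly. A short sketch, assuming one-hot encoded targets and a loss that behaves like a root-mean-square error (consistent with the formula above):

using LinearAlgebra

e(i) = [j == i ? 1.0 : 0.0 for j in 1:10]  # i-th standard basis vector of ℝ¹⁰

# a constant prediction e(1) compared against targets uniformly distributed over all ten classes
squared_errors = [norm(e(1) - e(i))^2 for i in 1:10]  # one 0 and nine 2s
sqrt(sum(squared_errors) / 10)
1.3416407864998738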

We can also call GeometricMachineLearning.accuracy to obtain the test accuracy instead of the training error:

(accuracy(nn1, dl_test), accuracy(nn2, dl_test), accuracy(nn3, dl_test), accuracy(nn4, dl_test))
(0.0974, 0.8613, 0.5518, 0.6351)
Remark

We note here that conventional convolutional neural networks and other vision transformers achieve much better accuracy on MNIST, often with a shorter training time than what we presented here. Our aim is not to outperform existing neural networks in terms of accuracy on image classification problems, but to demonstrate two things: (i) in many cases putting weights on the Stiefel manifold (which is a compact space) can enable training that would otherwise not be possible, and (ii) just as standard Adam outperforms the gradient and momentum optimizers, its manifold version achieves a similar performance gain over the manifold versions of the gradient and momentum optimizers. Both of these observations are demonstrated in the figure above.

Library Functions

GeometricMachineLearning.split_and_flatten - Function
split_and_flatten(input::AbstractArray)::AbstractArray

Perform a preprocessing of an image into flattened patches.

This rearranges the input data so that it can easily be processed with a transformer.

Examples

Consider a matrix of size $6\times6$ which we want to divide into patches of size $3\times3$.

using GeometricMachineLearning

input = [ 1  2  3  4  5  6; 
          7  8  9 10 11 12; 
         13 14 15 16 17 18;
         19 20 21 22 23 24; 
         25 26 27 28 29 30; 
         31 32 33 34 35 36]

split_and_flatten(input; patch_length = 3, number_of_patches = 4)

# output

9×4 Matrix{Int64}:
  1  19   4  22
  7  25  10  28
 13  31  16  34
  2  20   5  23
  8  26  11  29
 14  32  17  35
  3  21   6  24
  9  27  12  30
 15  33  18  36

Here we see that split_and_flatten:

  1. splits the original matrix into four $3\times3$ matrices and then
  2. flattens each matrix into a column vector of size $9.$

After this all the vectors are put together again to yield a $9\times4$ matrix.
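
The ordering of the columns can also be checked directly: for instance, the second column is the flattened lower-left patch of the input.

vec(input[4:6, 1:3]) == split_and_flatten(input; patch_length = 3, number_of_patches = 4)[:, 2]

# output

true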

Arguments

The optional keyword arguments are:

  • patch_length: by default this is 7.
  • number_of_patches: by default this is 16.

The sizes of the first and second axis of the output of split_and_flatten are

  1. $\mathtt{patch\_length}^2$ and
  2. number_of_patches.
GeometricMachineLearning.onehotbatch - Function
onehotbatch(target)

Performs a one-hot-batch encoding of a vector of integers: $input\in\{0,1,\ldots,9\}^\ell$.

The output is a tensor of shape $10\times1\times\ell$.

If the input is $0$, this function produces:

\[0 \mapsto \begin{bmatrix} 1 & 0 & \ldots & 0 \end{bmatrix}^T.\]

In more abstract terms: $i \mapsto e_i$.

Examples

using GeometricMachineLearning: onehotbatch

target = [0]
onehotbatch(target)

# output

10×1×1 Array{Int64, 3}:
[:, :, 1] =
 1
 0
 0
 0
 0
 0
 0
 0
 0
 0
GeometricMachineLearning.ClassificationLayer - Type
ClassificationLayer(input_dim, output_dim, activation)

Make an instance of ClassificationLayer.

ClassificationLayer takes a matrix as an input and returns a vector that is used for classification.

It does:

\[ X \mapsto \sigma(\mathtt{compute\_vector}(AX)),\]

where $X$ is a matrix and $\mathtt{compute\_vector}$ specifies how this matrix is turned into a vector.

$\mathtt{compute\_vector}$ can be specified with the keyword average.

Arguments

ClassificationLayer has the following optional keyword argument:

  • average::Bool = false.

If this keyword argument is set to true, then the output is computed as

\[ input \mapsto \frac{1}{N}\sum_{i=1}^N[\mathcal{NN}(input)]_{\bullet{}i}.\]

If set to false (the default) it picks the last column of the input.

Examples

using GeometricMachineLearning

l = ClassificationLayer(2, 2, identity; average = true)
ps = (weight = [1 0; 0 1], )

input = [1 2 3; 1 1 1]

l(input, ps)

# output

2×1 Matrix{Float64}:
 2.0
 1.0
If we instead set average = false (the default), the layer picks the last column:

using GeometricMachineLearning

l = ClassificationLayer(2, 2, identity; average = false)
ps = (weight = [1 0; 0 1], )

input = [1 2 3; 1 1 1]

l(input, ps)

# output

2×1 Matrix{Int64}:
 3
 1

References

[7]
B. Brantner. Generalizing Adam To Manifolds For Efficiently Training Transformers, arXiv preprint arXiv:2305.16901 (2023).