MNIST tutorial

This short tutorial shows how to use GeometricMachineLearning to build a vision transformer and apply it to MNIST, while also putting some of the weights on a manifold. The result is also presented in [48].

First, we need to import the relevant packages:

using GeometricMachineLearning, CUDA, Plots
import Zygote, MLDatasets, KernelAbstractions

For automatic differentiation (AD) we use the GeometricMachineLearning default, and we get the dataset from MLDatasets. We first load the dataset and move it to the GPU (if one is available):

train_x, train_y = MLDatasets.MNIST(split=:train)[:]
test_x, test_y = MLDatasets.MNIST(split=:test)[:]
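# move the data to the GPU only if a functional CUDA device is available
to_device = CUDA.functional() ? cu : identity
train_x = train_x |> to_device
test_x = test_x |> to_device
train_y = train_y |> to_device
test_y = test_y |> to_device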

GeometricMachineLearning has built-in data loaders that make it particularly easy to handle data:

patch_length = 7
dl = DataLoader(train_x, train_y, patch_length=patch_length)
dl_test = DataLoader(test_x, test_y, patch_length=patch_length)

Here patch_length is the side length of one patch. An MNIST image has dimension $28\times28$, so with patch_length = 7 each image is decomposed into 16 patches of size $7\times7$ (also see [48]).
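To make the decomposition concrete, the following is a minimal plain-Julia sketch of splitting one image into flattened patches. It does not use the GeometricMachineLearning DataLoader; the helper name split_into_patches and the ordering of the patches are assumptions made for illustration and may differ from what the library does internally.

```julia
# Plain-Julia sketch (no GeometricMachineLearning calls): split one image into patches.
# `split_into_patches` is a hypothetical helper; the patch ordering is an assumption.
function split_into_patches(image::AbstractMatrix, patch_length::Integer)
    height, width = size(image)
    @assert height % patch_length == 0 && width % patch_length == 0
    rows, cols = height ÷ patch_length, width ÷ patch_length
    # one flattened patch per column
    patches = similar(image, patch_length^2, rows * cols)
    for j in 1:cols, i in 1:rows
        row_range = ((i - 1) * patch_length + 1):(i * patch_length)
        col_range = ((j - 1) * patch_length + 1):(j * patch_length)
        patches[:, (j - 1) * rows + i] = vec(image[row_range, col_range])
    end
    return patches
end

image = rand(Float32, 28, 28)          # stand-in for one MNIST image
patches = split_into_patches(image, 7)
size(patches)                          # (49, 16): 16 patches with 7 * 7 = 49 entries each
```

Each column of patches then plays the role of one token that is fed into the transformer.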

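Next we define the model that we want to train. The variables n_heads and n_layers have to be set beforehand; the values below are only illustrative:

# illustrative values (not specified in this tutorial); adjust as needed
n_heads = 7
n_layers = 16
model = ClassificationTransformer(dl, n_heads=n_heads, n_layers=n_layers, Stiefel=true)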

Here we have chosen a ClassificationTransformer, i.e. n_layers transformer layers followed by a classification layer. We also set the Stiefel option to true, i.e. we are optimizing on the Stiefel manifold.
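As a reminder of what this constraint means: the Stiefel manifold consists of all matrices with orthonormal columns, i.e. matrices $Y$ with $Y^TY = I$. The following plain-Julia sketch builds such a matrix and checks the defining property. It only uses LinearAlgebra, not GeometricMachineLearning, and the sizes are illustrative assumptions rather than the dimensions used by the model.

```julia
using LinearAlgebra

# Plain-Julia sketch: build a random N×n matrix with orthonormal columns
# (an element of the Stiefel manifold) from a QR decomposition.
N, n = 49, 7                           # illustrative sizes, not taken from the tutorial
A = randn(N, n)
Y = Matrix(qr(A).Q)                    # thin Q factor of size N×n

# defining property of the Stiefel manifold: YᵀY is the n×n identity
on_stiefel = isapprox(Y' * Y, Matrix{Float64}(I, n, n); atol = 1e-10)
```

Optimizing on the manifold means that the weights in question keep this property throughout training, instead of being updated as unconstrained matrices.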

We now have to initialize the neural network weights. This is done with the constructor for NeuralNetwork:

backend = KernelAbstractions.get_backend(dl)
T = eltype(dl)
nn = NeuralNetwork(model, backend, T)

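With this we can finally perform the training. The variables batch_size and n_epochs have to be set beforehand; the values below are only illustrative:

# illustrative values (not specified in this tutorial); adjust as needed
batch_size = 2048
n_epochs = 500

# an instance of Batch is needed for the optimizer
batch = Batch(batch_size)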

optimizer_instance = Optimizer(AdamOptimizer(), nn)

# this prints the accuracy and is optional
println("initial test accuracy: ", accuracy(nn, dl_test), "\n")

loss_array = optimizer_instance(nn, dl, batch, n_epochs)

println("final test accuracy: ", accuracy(nn, dl_test), "\n")
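For orientation, classification accuracy is the fraction of images whose predicted digit matches the label. The sketch below shows this in plain Julia on toy data. The function classification_accuracy is a hypothetical helper written for this illustration; it is not the library's accuracy function, whose implementation may differ in detail.

```julia
# Plain-Julia sketch of classification accuracy (not the library's `accuracy`).
# `scores` holds one column of class scores per image; `labels` holds the true classes,
# counted from 0 as for the MNIST digits.
function classification_accuracy(scores::AbstractMatrix, labels::AbstractVector{<:Integer})
    # the predicted class is the row with the largest score, shifted to start at 0
    predicted = [argmax(column) - 1 for column in eachcol(scores)]
    return count(predicted .== labels) / length(labels)
end

# toy example: 3 classes (rows) and 3 images (columns)
scores = Float32[0.1 0.7 0.2;
                 0.8 0.2 0.1;
                 0.1 0.1 0.7]
labels = [1, 0, 1]                     # the third image is misclassified as class 2
acc = classification_accuracy(scores, labels)   # 2 of 3 predictions are correct
```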

It is instructive to experiment with n_layers, n_epochs and the Stiefel option.

[48] B. Brantner. Generalizing Adam To Manifolds For Efficiently Training Transformers, arXiv preprint arXiv:2305.16901 (2023).