MNIST Tutorial
In this tutorial we show how GeometricMachineLearning can be used to build a vision transformer and apply it to the MNIST dataset [90], while putting some of the weights on a manifold. This is also the result presented in [7].
We get the dataset from MLDatasets. Before we use it we move it to the GPU with cu from CUDA.jl [11]:
using MLDatasets
using CUDA
train_x, train_y = MLDatasets.MNIST(split=:train)[:]
test_x, test_y = MLDatasets.MNIST(split=:test)[:]
train_x = train_x |> cu
train_y = train_y |> cu
test_x = test_x |> cu
test_y = test_y |> cu
Next we call DataLoader on these data. For this we first need to specify a patch length[1]. In order to apply the transformer to a data set we typically have to cast the data into a time-series format. MNIST images are pictures with $28\times28$ pixels. Here we cast each image into a time series of length 16, so one image is represented by a matrix $\in\mathbb{R}^{49\times{}16}$.
const patch_length = 7
dl = DataLoader(train_x, train_y, patch_length = patch_length; suppress_info = true)
Here we called DataLoader on a tensor and a vector of integers (the targets) as input. DataLoader automatically converts the data to the correct input format for easy handling. This is illustrated in the split_and_flatten example in the Library Functions section below.
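We can quickly sanity-check the patch arithmetic from above (a package-independent sketch; image_size and patch_dim are local names introduced here for illustration):
image_size = 28
number_of_patches = (image_size ÷ patch_length)^2   # (28 ÷ 7)^2 = 16
patch_dim = patch_length^2                          # 7^2 = 49
(patch_dim, number_of_patches)

# output

(49, 16)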
Internally DataLoader calls split_and_flatten, which splits each image into a number of patches according to the keyword arguments patch_length and number_of_patches. We also load the test data with DataLoader:
dl_test = DataLoader(test_x, test_y, patch_length=patch_length)
[ Info: You provided a tensor and a vector as input. This will be treated as a classification problem (MNIST). Tensor axes: (i) & (ii) image axes and (iii) parameter dimension.
We next define the models that we want to train:
const n_heads = 7
const L = 16
const add_connection = false
model1 = ClassificationTransformer(dl;
    n_heads = n_heads,
    L = L,
    add_connection = add_connection,
    Stiefel = false)
model2 = ClassificationTransformer(dl;
    n_heads = n_heads,
    L = L,
    add_connection = add_connection,
    Stiefel = true)
Here we have chosen a ClassificationTransformer, i.e. a fixed number of transformer layers followed by a classification layer. For the second model we set the Stiefel option to true, i.e. its weights are optimized on the Stiefel manifold.
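For reference, the Stiefel manifold is the compact space of matrices with orthonormal columns:
\[ St(n, N) := \{ Y \in \mathbb{R}^{N\times{}n} : Y^TY = \mathbb{I}_n \}.\]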
We now have to initialize the neural network weights. This is done with the constructor for NeuralNetwork:
backend = GeometricMachineLearning.networkbackend(dl)
T = eltype(dl)
nn1 = NeuralNetwork(model1, backend, T)
nn2 = NeuralNetwork(model2, backend, T)
We still have to initialize the optimizers:
const batch_size = 2048
const n_epochs = 500
# an instance of Batch is needed for the optimizer
batch = Batch(batch_size, dl)
opt1 = Optimizer(AdamOptimizer(T), nn1)
opt2 = Optimizer(AdamOptimizer(T), nn2)
And with this we can finally perform the training:
loss_array1 = opt1(nn1, dl, batch, n_epochs, FeedForwardLoss())
loss_array2 = opt2(nn2, dl, batch, n_epochs, FeedForwardLoss())
We furthermore optimize the second neural network (with weights on the manifold) with the GradientOptimizer and the MomentumOptimizer:
nn3 = NeuralNetwork(model2, backend, T)
nn4 = NeuralNetwork(model2, backend, T)
opt3 = Optimizer(GradientOptimizer(T(0.001)), nn3)
opt4 = Optimizer(MomentumOptimizer(T(0.001), T(0.5)), nn4)
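For reference, in the flat (Euclidean) case these two optimizers reduce to the familiar update rules
\[ \Delta{}W = -\eta\nabla_W{}L \quad\text{and}\quad \Delta{}W^{(t)} = \alpha\Delta{}W^{(t-1)} - \eta\nabla_W{}L,\]
with learning rate $\eta = 0.001$ as chosen above and momentum parameter $\alpha = 0.5$; the manifold versions generalize these updates with retractions (see [7]).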
For training we use the same data, the same batch and the same number of epochs:
loss_array3 = opt3(nn3, dl, batch, n_epochs, FeedForwardLoss())
loss_array4 = opt4(nn4, dl, batch, n_epochs, FeedForwardLoss())
And we get the following result: the loss for the Adam optimizer without parameters on the Stiefel manifold is stuck at around 1.34, which means that the network always predicts the same value. So in one out of ten cases the error is 0 and in nine out of ten cases it is $\sqrt{2}$ (the distance between two distinct one-hot vectors), giving a root mean square error of
\[ \sqrt{\frac{9}{10}\cdot{}2} = 1.342,\]
which is what we see in the error plot.
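A quick numerical check of this value:
sqrt(9/10 * 2)

# output

1.3416407864998738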
We can also call GeometricMachineLearning.accuracy to obtain the test accuracy instead of the training error:
(accuracy(nn1, dl_test), accuracy(nn2, dl_test), accuracy(nn3, dl_test), accuracy(nn4, dl_test))
(0.0974, 0.8613, 0.5518, 0.6351)
We note here that conventional convolutional neural networks and other vision transformers achieve much better accuracy on MNIST, often with shorter training times than what we presented here. Our aim is not to outperform existing neural networks in terms of accuracy on image classification problems, but to demonstrate two things: (i) in many cases putting weights on the Stiefel manifold (which is a compact space) can enable training that would otherwise not be possible, and (ii) the manifold version of Adam seems to achieve a performance gain over the gradient and momentum optimizers similar to the one standard Adam achieves in the Euclidean case. Both of these observations are demonstrated in the figure above.
Library Functions
GeometricMachineLearning.split_and_flatten — Function
split_and_flatten(input::AbstractArray)::AbstractArray
Perform a preprocessing of an image into flattened patches.
This rearranges the input data so that it can easily be processed with a transformer.
Examples
Consider a matrix of size $6\times6$ which we want to divide into patches of size $3\times3$.
using GeometricMachineLearning
input = [ 1  2  3  4  5  6;
          7  8  9 10 11 12;
         13 14 15 16 17 18;
         19 20 21 22 23 24;
         25 26 27 28 29 30;
         31 32 33 34 35 36]
split_and_flatten(input; patch_length = 3, number_of_patches = 4)
# output
9×4 Matrix{Int64}:
  1  19   4  22
  7  25  10  28
 13  31  16  34
  2  20   5  23
  8  26  11  29
 14  32  17  35
  3  21   6  24
  9  27  12  30
 15  33  18  36
Here we see that split_and_flatten:
- splits the original matrix into four $3\times3$ matrices and then
- flattens each matrix into a column vector of size $9.$
After this all the vectors are put together again to yield a $9\times4$ matrix.
Arguments
The optional keyword arguments are:
- patch_length: by default this is 7.
- number_of_patches: by default this is 16.
The sizes of the first and second axis of the output of split_and_flatten are
- $\mathtt{patch\_length}^2$ and
- number_of_patches.
GeometricMachineLearning.accuracy — Function
accuracy(model, ps, dl)
Compute the accuracy of a neural network classifier.
This needs an instance of DataLoader that stores the test data.
accuracy(nn, dl)
Compute the accuracy of a neural network classifier.
This is like accuracy(::Chain, ::Tuple, ::DataLoader), but for a NeuralNetwork.
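A minimal usage sketch, reusing the nn1 and dl_test objects from the tutorial above:
accuracy(nn1, dl_test)

# output

0.0974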
GeometricMachineLearning.onehotbatch — Function
onehotbatch(target)
Performs a one-hot-batch encoding of a vector of integers: $input\in\{0,1,\ldots,9\}^\ell$.
The output is a tensor of shape $10\times1\times\ell$.
If the input is $0$, this function produces:
\[0 \mapsto \begin{bmatrix} 1 & 0 & \ldots & 0 \end{bmatrix}^T.\]
In more abstract terms: $i \mapsto e_{i+1}$.
Examples
using GeometricMachineLearning: onehotbatch
target = [0]
onehotbatch(target)
# output
10×1×1 Array{Int64, 3}:
[:, :, 1] =
1
0
0
0
0
0
0
0
0
0
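To also illustrate the batch dimension $\ell$, here is a sketch with a two-element target vector (assuming the same encoding convention as in the example above):
using GeometricMachineLearning: onehotbatch

target = [0, 2]
onehotbatch(target)

# output

10×1×2 Array{Int64, 3}:
[:, :, 1] =
1
0
0
0
0
0
0
0
0
0

[:, :, 2] =
0
0
1
0
0
0
0
0
0
0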
GeometricMachineLearning.ClassificationLayer — Type
ClassificationLayer(input_dim, output_dim, activation)
Make an instance of ClassificationLayer.
ClassificationLayer takes a matrix as an input and returns a vector that is used for classification.
It does:
\[ X \mapsto \sigma(\mathtt{compute\_vector}(AX)),\]
where $X$ is a matrix and $\mathtt{compute\_vector}$ specifies how this matrix is turned into a vector. $\mathtt{compute\_vector}$ can be specified with the keyword average.
Arguments
ClassificationLayer has the following optional keyword argument:
- average::Bool=false.
If this keyword argument is set to true, then the output is computed as
\[ input \mapsto \frac{1}{N}\sum_{i=1}^N[\mathcal{NN}(input)]_{\bullet{}i}.\]
If set to false (the default) it picks the last column of the input.
Examples
using GeometricMachineLearning
l = ClassificationLayer(2, 2, identity; average = true)
ps = (weight = [1 0; 0 1], )
input = [1 2 3; 1 1 1]
l(input, ps)
# output
2×1 Matrix{Float64}:
2.0
1.0
using GeometricMachineLearning
l = ClassificationLayer(2, 2, identity; average = false)
ps = (weight = [1 0; 0 1], )
input = [1 2 3; 1 1 1]
l(input, ps)
# output
2×1 Matrix{Int64}:
3
1
References
- [7] B. Brantner. Generalizing Adam To Manifolds For Efficiently Training Transformers. arXiv preprint arXiv:2305.16901 (2023).
- [1] When DataLoader is called this way it uses split_and_flatten internally.