Comparing Matrix and Vector Softmax as Activation Functions in a Transformer

Transformers are usually built with VectorSoftmax as the activation function, meaning that the activation function treats each column of a matrix (or tensor) independently:

\[\mathrm{softmax}([v^1, v^2, \ldots, v^n]) \equiv \left[ \mathrm{softmax}(v^1), \mathrm{softmax}(v^2), \ldots, \mathrm{softmax}(v^n) \right], \quad \text{where } \mathrm{softmax}(v)_i = \frac{e^{v_i}}{\sum_{i'=1}^d e^{v_{i'}}}.\]

One can however also use a MatrixSoftmax:

\[\mathrm{Msoftmax}(V)_{ij} \equiv \frac{e^{V_{ij}}}{\sum_{i',j'}e^{V_{i'j'}}}.\]
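
Written out in plain Julia (a minimal sketch, independent of the library implementations used below), the two activations differ only in the set of entries over which the exponentials are normalized:

# Hand-rolled sketches of the two activations (not the library code):
# the vector softmax normalizes each column independently,
# the matrix softmax normalizes over all entries of the matrix.
vector_softmax(V::AbstractMatrix) = exp.(V) ./ sum(exp.(V); dims = 1)
matrix_softmax(V::AbstractMatrix) = exp.(V) ./ sum(exp.(V))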

using GeometricMachineLearning

act1 = GeometricMachineLearning.VectorSoftmax()
act2 = GeometricMachineLearning.MatrixSoftmax()

A = [1 2 3; 1 2 3; 1 2 3]
3×3 Matrix{Int64}:
 1  2  3
 1  2  3
 1  2  3
act1(A)
3×3 Matrix{Float64}:
 0.333333  0.333333  0.333333
 0.333333  0.333333  0.333333
 0.333333  0.333333  0.333333
act2(A)
3×3 Matrix{Float64}:
 0.0300102  0.0815762  0.221747
 0.0300102  0.0815762  0.221747
 0.0300102  0.0815762  0.221747
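
The two outputs are normalized differently: every column of the VectorSoftmax output sums to one, whereas the entries of the MatrixSoftmax output sum to one only when taken over the whole matrix:

sum(act1(A); dims = 1) # every column sums to one
sum(act2(A))           # all entries together sum to one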

We can now train transformers with these different activation functions in the MultiHeadAttention layers:

using GeometricProblems.CoupledHarmonicOscillator: hodeensemble, default_parameters
using GeometricIntegrators: integrate, ImplicitMidpoint

const tstep = 0.3
const n_init_con = 5

# ensemble problem
ep = hodeensemble([rand(2) for _ in 1:n_init_con], [rand(2) for _ in 1:n_init_con]; tstep = tstep)
dl = DataLoader(integrate(ep, ImplicitMidpoint()); suppress_info = true)
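
The DataLoader infers the system dimension from the data; dl.input_dim is used below to size the transformer. Since both q and p of the coupled harmonic oscillator are two-dimensional here, we expect (this check is an addition, not part of the original example):

dl.input_dim # presumably 4: two position and two momentum components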

We now define the architectures and train them:

const seq_length = 4
const batch_size = 1024
const n_epochs = 1000

arch1 = StandardTransformerIntegrator(dl.input_dim; transformer_dim = 20,
                                                    n_heads = 4,
                                                    L = 1,
                                                    n_blocks = 2,
                                                    attention_activation = act1)

arch2 = StandardTransformerIntegrator(dl.input_dim; transformer_dim = 20,
                                                    n_heads = 4,
                                                    L = 1,
                                                    n_blocks = 2,
                                                    attention_activation = act2)

nn1 = NeuralNetwork(arch1)
nn2 = NeuralNetwork(arch2)

o_method = AdamOptimizer()

o1 = Optimizer(o_method, nn1)
o2 = Optimizer(o_method, nn2)

batch = Batch(batch_size, seq_length)

loss_array1 = o1(nn1, dl, batch, n_epochs; show_progress = false)
loss_array2 = o2(nn2, dl, batch, n_epochs; show_progress = false)

Training loss for the different networks.
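
The curves above can be reproduced with any plotting package; a minimal sketch, assuming Plots.jl (which is not loaded in the example above):

using Plots # assumed plotting package, not part of the setup above

plot(loss_array1; xlabel = "epoch", ylabel = "training loss", label = "VectorSoftmax", yaxis = :log)
plot!(loss_array2; label = "MatrixSoftmax")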

We further use the trained networks to predict a trajectory over 300 time steps:

Validation of the different networks.