Matrix Softmax vs. Vector Softmax

In this section we compare the VectorSoftmax to the MatrixSoftmax. What is usually meant by softmax is the vector softmax, i.e. one that computes:

\[[\mathrm{softmax}(a)]_i = \frac{e^{a_i}}{\sum_{i'=1}^d e^{a_{i'}}}. \]

So each column of a matrix is normalized to sum to one. With this softmax, the linear recombination performed by the attention layer becomes a convex recombination. This is not the case for the MatrixSoftmax, where the normalization is computed over all matrix entries:

\[[\mathrm{softmax}(A)]_{ij} = \frac{e^{A_{ij}}}{\sum_{i'=1}^{d}\sum_{j'=1}^{\bar{d}} e^{A_{i'j'}}}. \]
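
To make the difference concrete, here is a minimal sketch of the two normalizations in plain Julia (the helper names are ours, not the package's implementation):

vector_softmax(A) = exp.(A) ./ sum(exp.(A); dims = 1)  # every column sums to one
matrix_softmax(A) = exp.(A) ./ sum(exp.(A))            # all entries together sum to one

A = randn(3, 4)
sum(vector_softmax(A); dims = 1)  # ≈ [1.0 1.0 1.0 1.0]
sum(matrix_softmax(A))            # ≈ 1.0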

We compare the two approaches on the example of the coupled harmonic oscillator. It is a Hamiltonian system with

\[H(q_1, q_2, p_1, p_2) = \frac{p_1^2}{2m_1} + \frac{p_2^2}{2m_2} + k_1\frac{q_1^2}{2} + k_2\frac{q_2^2}{2} + k\sigma(q_1)\frac{(q_2 - q_1)^2}{2},\]

where $\sigma(x) = 1 / (1 + e^{-x})$ is the sigmoid activation function. The system parameters are listed below (a code transcription of $H$ follows the list):

  • $k_1$: spring constant belonging to $m_1$,
  • $k_2$: spring constant belonging to $m_2$,
  • $m_1$: mass 1,
  • $m_2$: mass 2,
  • $k$: coupling strength between the two masses.
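
As a quick illustration, this is $H$ transcribed directly into Julia; a sketch for illustration only (the parameter defaults are placeholders), since below we use the ready-made implementation from GeometricProblems:

σ(x) = 1 / (1 + exp(-x))

# Hamiltonian of the coupled harmonic oscillator; the default parameter values are placeholders
function H(q₁, q₂, p₁, p₂; m₁ = 1.0, m₂ = 1.0, k₁ = 1.0, k₂ = 1.0, k = 1.0)
    p₁^2 / (2m₁) + p₂^2 / (2m₂) + k₁ * q₁^2 / 2 + k₂ * q₂^2 / 2 + k * σ(q₁) * (q₂ - q₁)^2 / 2
end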

Visualization of the coupled harmonic oscillator.

We will leave the parameters fixed but alter the initial conditions[1]:

using GeometricMachineLearning
using GeometricIntegrators: integrate, ImplicitMidpoint
using GeometricProblems.CoupledHarmonicOscillator: hodeensemble, default_parameters

const tstep = 0.3
const n_init_con = 5

# ensemble problem: n_init_con random initial conditions for q and p
ep = hodeensemble([rand(2) for _ in 1:n_init_con], [rand(2) for _ in 1:n_init_con]; tstep = tstep)
dl = DataLoader(integrate(ep, ImplicitMidpoint()); suppress_info = true)
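
Each snapshot stores the four phase-space coordinates $(q_1, q_2, p_1, p_2)$, which we can check via:

dl.input_dim  # 4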

We now use the same architecture, a StandardTransformerIntegrator, twice, altering only the attention activation function (the remaining keywords fix the transformer width, the number of attention heads and the depth of the network):

const seq_length = 4
const batch_size = 1024
const n_epochs = 1000

act1 = GeometricMachineLearning.VectorSoftmax()
act2 = GeometricMachineLearning.MatrixSoftmax()

arch1 = StandardTransformerIntegrator(dl.input_dim; transformer_dim = 20,
                                                    n_heads = 4,
                                                    L = 1,
                                                    n_blocks = 2,
                                                    attention_activation = act1)

arch2 = StandardTransformerIntegrator(dl.input_dim; transformer_dim = 20,
                                                    n_heads = 4,
                                                    L = 1,
                                                    n_blocks = 2,
                                                    attention_activation = act2)

nn1 = NeuralNetwork(arch1)
nn2 = NeuralNetwork(arch2)
Printing nn2 shows the chain the constructor builds: a Dense layer (4 → 20), a MultiHeadAttention layer with four heads, three ResNet layers of width 20 (two with tanh activation, one with identity) and a final Dense layer (20 → 4).

Training is done with the AdamOptimizer:

o_method = AdamOptimizer()

o1 = Optimizer(o_method, nn1)
o2 = Optimizer(o_method, nn2)

batch = Batch(batch_size, seq_length)

loss_array1 = o1(nn1, dl, batch, n_epochs; show_progress = false)
loss_array2 = o2(nn2, dl, batch, n_epochs; show_progress = false)
1000-element Vector{Float64}:
 3.3058544353285093
 3.2716342817452713
 2.915376993735294
 2.7860645148734005
 2.4956956820771525
 2.2723535892533517
 2.078811417973053
 1.9179707422966026
 1.7644973915207747
 1.5738499204543892
 ⋮
 0.012745961204577239
 0.01149839638917843
 0.009551757124576066
 0.009213955857990756
 0.009715837431882375
 0.011508647171753705
 0.01001547633452667
 0.009198001172487591
 0.011717177768242715
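
To compare the two networks we can plot both loss arrays, e.g. on a logarithmic axis (a minimal sketch assuming Plots.jl is available):

using Plots

plot(loss_array1; label = "VectorSoftmax", xlabel = "epoch", ylabel = "training loss", yaxis = :log)
plot!(loss_array2; label = "MatrixSoftmax")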

Training loss for the different networks.

Predicting trajectories with transformers based on the vector softmax and the matrix softmax.

A similar page can be found here.

  • [1] Here we use the implementation of the coupled harmonic oscillator from GeometricProblems.