Matrix Softmax v Vector Softmax
In this section we compare the `VectorSoftmax` to the `MatrixSoftmax`. What is usually meant by softmax is the vector softmax, i.e. one that computes:
\[[\mathrm{softmax}(a)]_i = \frac{e^{a_i}}{\sum_{i'=1}^d e^{a_{i'}}}. \]
Applied to a matrix, this softmax acts on each column separately, so every column is normalized to sum to one. With this softmax, the linear recombination performed by the attention layer becomes a convex recombination. This is not the case for the `MatrixSoftmax`, where the normalization is computed over all matrix entries:
\[[\mathrm{softmax}(A)]_{ij} = \frac{e^{A_{ij}}}{\sum_{i'=1, j'=1}^{d,\bar{d}}e^{A_{i'j'}}}. \]
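The two normalizations can be checked numerically with a few lines of plain Julia. This is a minimal sketch: `vector_softmax` and `matrix_softmax` are hypothetical helpers written for illustration, not the `VectorSoftmax`/`MatrixSoftmax` implementations from GeometricMachineLearning, and the score matrix is made up.

```julia
# Hypothetical helper: the vector softmax applied column by column,
# [softmax(a)]_i = exp(a_i) / sum_{i'} exp(a_{i'}) for every column a of A.
function vector_softmax(A::AbstractMatrix)
    # subtracting each column's maximum avoids overflow and does not change the result
    E = exp.(A .- maximum(A; dims = 1))
    return E ./ sum(E; dims = 1)
end

# Hypothetical helper: the matrix softmax, one normalization over all entries of A.
function matrix_softmax(A::AbstractMatrix)
    E = exp.(A .- maximum(A))
    return E ./ sum(E)
end

# Made-up 3 × 3 matrix of attention scores.
A = [1.0  2.0  0.5;
     0.0 -1.0  3.0;
     0.3  0.7 -0.2]

S_vec = vector_softmax(A)
S_mat = matrix_softmax(A)

println(sum(S_vec; dims = 1))  # every column sums to one
println(sum(S_mat))            # only the sum over all entries is one
println(sum(S_mat; dims = 1))  # each individual column sums to less than one
```

Since every column of `S_vec` is a probability vector, multiplying a matrix of value vectors by it from the right forms convex combinations of those vectors; the columns of `S_mat` carry nonnegative weights that, for more than one column, sum to less than one.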
We compare the two approaches on the example of the coupled harmonic oscillator, a Hamiltonian system with
\[H(q_1, q_2, p_1, p_2) = \frac{p_1^2}{2m_1} + \frac{p_2^2}{2m_2} + k_1\frac{q_1^2}{2} + k_2\frac{q_2^2}{2} + k\sigma(q_1)\frac{(q_2 - q_1)^2}{2},\]
where $\sigma(x) = 1 / (1 + e^{-x})$ is the sigmoid activation function. The system parameters are:
- $k_1$: spring constant belonging to $m_1$,
- $k_2$: spring constant belonging to $m_2$,
- $m_1$: mass 1,
- $m_2$: mass 2,
- $k$: coupling strength between the two masses.
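To make the Hamiltonian concrete, here is a plain-Julia sketch that evaluates $H$ and Hamilton's equations $\dot{q} = \partial H/\partial p$, $\dot{p} = -\partial H/\partial q$, with the gradients derived by hand. The function names and the parameter values are illustrative assumptions; they are neither the implementation nor the default parameters of GeometricProblems.

```julia
# Sigmoid activation σ(x) = 1 / (1 + e^{-x}).
sigmoid(x) = 1 / (1 + exp(-x))

# Hamiltonian of the coupled harmonic oscillator, q = (q₁, q₂), p = (p₁, p₂).
function hamiltonian(q, p, params)
    (; m1, m2, k1, k2, k) = params
    kinetic = p[1]^2 / (2 * m1) + p[2]^2 / (2 * m2)
    potential = k1 * q[1]^2 / 2 + k2 * q[2]^2 / 2 + k * sigmoid(q[1]) * (q[2] - q[1])^2 / 2
    return kinetic + potential
end

# Hamilton's equations: dq/dt = ∂H/∂p and dp/dt = -∂H/∂q (using σ' = σ(1 - σ)).
function vector_field(q, p, params)
    (; m1, m2, k1, k2, k) = params
    s = sigmoid(q[1])
    Δ = q[2] - q[1]
    dq = [p[1] / m1, p[2] / m2]
    dp = [-(k1 * q[1] + k * s * (1 - s) * Δ^2 / 2 - k * s * Δ),
          -(k2 * q[2] + k * s * Δ)]
    return dq, dp
end

# Illustrative parameter values (assumed, not taken from GeometricProblems).
params = (m1 = 2.0, m2 = 1.0, k1 = 1.5, k2 = 0.3, k = 1.0)
dq, dp = vector_field([0.5, -0.2], [0.1, 0.3], params)
```

The sigmoid factor makes the coupling strength depend on the position of the first mass, which is what makes the system nonlinear.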
We will leave the parameters fixed but alter the initial conditions[1]:
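```julia
using GeometricMachineLearning
using GeometricIntegrators
using GeometricProblems.CoupledHarmonicOscillator: hodeensemble, default_parameters

const tstep = 0.3
const n_init_con = 5

# ensemble problem: n_init_con random initial conditions (q₀, p₀), same fixed system parameters
ep = hodeensemble([rand(2) for _ in 1:n_init_con], [rand(2) for _ in 1:n_init_con]; tstep = tstep)
# integrate all trajectories with the implicit midpoint rule and wrap the result for training
dl = DataLoader(integrate(ep, ImplicitMidpoint()); suppress_info = true)
```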
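`ImplicitMidpoint()` refers to the implicit midpoint rule $x_{n+1} = x_n + h\, f\big((x_n + x_{n+1})/2\big)$. The sketch below solves this implicit equation by fixed-point iteration and applies it to a single harmonic oscillator; `implicit_midpoint_step` is a hypothetical helper for illustration only, not the GeometricIntegrators implementation, which is far more general. The example checks one property that makes the rule attractive for generating training data: it is symplectic and conserves quadratic invariants, here the energy $(q^2 + p^2)/2$.

```julia
using LinearAlgebra: norm

# One step of the implicit midpoint rule x₁ = x₀ + h f((x₀ + x₁) / 2).
# The implicit equation is solved by fixed-point iteration, a simple choice that
# converges for small enough h (general-purpose integrators typically use Newton's method).
function implicit_midpoint_step(f, x₀, h; tol = 1e-14, maxiter = 100)
    x₁ = x₀ + h * f(x₀)  # explicit Euler predictor as the initial guess
    for _ in 1:maxiter
        x_new = x₀ + h * f((x₀ + x₁) / 2)
        if norm(x_new - x₁) < tol
            return x_new
        end
        x₁ = x_new
    end
    return x₁
end

# Harmonic oscillator x = (q, p) with H = (q² + p²) / 2, so dq/dt = p and dp/dt = -q.
oscillator(x) = [x[2], -x[1]]

x = [1.0, 0.0]
energies = Float64[]
for _ in 1:100
    global x = implicit_midpoint_step(oscillator, x, 0.3)
    push!(energies, sum(abs2, x) / 2)
end
```

With the same step size, an explicit Euler step would multiply the energy by $1 + h^2$ every step, so the difference is easy to observe.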
We now use the same architecture, a `TransformerIntegrator`, twice, changing only the activation function of its attention layer:
```julia
const seq_length = 4
const batch_size = 1024
const n_epochs = 1000

# the two networks differ only in the activation used inside the attention layer
act1 = GeometricMachineLearning.VectorSoftmax()
act2 = GeometricMachineLearning.MatrixSoftmax()

arch1 = StandardTransformerIntegrator(dl.input_dim; transformer_dim = 20,
                                                    n_heads = 4,
                                                    L = 1,
                                                    n_blocks = 2,
                                                    attention_activation = act1)

arch2 = StandardTransformerIntegrator(dl.input_dim; transformer_dim = 20,
                                                    n_heads = 4,
                                                    L = 1,
                                                    n_blocks = 2,
                                                    attention_activation = act2)

# initialize the parameters of both networks
nn1 = NeuralNetwork(arch1)
nn2 = NeuralNetwork(arch2)
```
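The `attention_activation` keyword decides which normalization the attention layer applies to its score matrix. As a hedged sketch, assuming a single head of the usual query/key/value form $Z = (W_V X)\,\sigma\big((W_K X)^T (W_Q X)\big)$ (the exact conventions of GeometricMachineLearning's `MultiHeadAttention` may differ, for instance in scaling), the snippet below swaps the two activations and checks the convexity property: with the vector softmax every output entry stays within the range of the corresponding row of the value vectors. All names, sizes and the random weights are illustrative assumptions.

```julia
using Random

# Hypothetical helpers (column-wise and all-entries softmax), defined here so the block runs on its own.
vector_softmax(A) = (E = exp.(A .- maximum(A; dims = 1)); E ./ sum(E; dims = 1))
matrix_softmax(A) = (E = exp.(A .- maximum(A)); E ./ sum(E))

# Single-head attention: the columns of the value matrix W_V * X are recombined with
# the weights that `activation` produces from the (seq_len × seq_len) score matrix.
function single_head_attention(X, W_Q, W_K, W_V, activation)
    scores = (W_K * X)' * (W_Q * X)
    return (W_V * X) * activation(scores)
end

rng = MersenneTwister(1)
dim, head_dim, seq_len = 20, 5, 4
X = randn(rng, dim, seq_len)  # one input sequence, one column per time step
W_Q, W_K, W_V = (randn(rng, head_dim, dim) ./ sqrt(dim) for _ in 1:3)

V = W_V * X  # value vectors
Z_vec = single_head_attention(X, W_Q, W_K, W_V, vector_softmax)
Z_mat = single_head_attention(X, W_Q, W_K, W_V, matrix_softmax)
```

With `matrix_softmax`, the weights behind each output column are nonnegative but sum to less than one, so `Z_mat` is not a convex recombination of the value vectors; this is the difference the two networks above are meant to probe.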
Training is done with the `AdamOptimizer`:
```julia
o_method = AdamOptimizer()
# one optimizer per network
o1 = Optimizer(o_method, nn1)
o2 = Optimizer(o_method, nn2)

# batches of batch_size sequences, each of length seq_length
batch = Batch(batch_size, seq_length)

loss_array1 = o1(nn1, dl, batch, n_epochs; show_progress = false)
loss_array2 = o2(nn2, dl, batch, n_epochs; show_progress = false)
```
Output of the last line (`loss_array2`, the training loss of the network with the `MatrixSoftmax`):

```
1000-element Vector{Float64}:
 3.3058544353285093
 3.2716342817452713
 2.915376993735294
 2.7860645148734005
 2.4956956820771525
 2.2723535892533517
 2.078811417973053
 1.9179707422966026
 1.7644973915207747
 1.5738499204543892
 ⋮
 0.012745961204577239
 0.01149839638917843
 0.009551757124576066
 0.009213955857990756
 0.009715837431882375
 0.011508647171753705
 0.01001547633452667
 0.009198001172487591
 0.011717177768242715
```
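`AdamOptimizer` presumably implements the Adam algorithm, which rescales a momentum-averaged gradient by a running estimate of its second moment. The sketch below implements the standard update rule with bias correction and the commonly used defaults $\eta = 10^{-3}$, $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\varepsilon = 10^{-8}$, applied to a toy quadratic loss. `AdamState` and `adam_step!` are hypothetical names, and whether `AdamOptimizer()` uses exactly these defaults is an assumption not confirmed here.

```julia
# Hypothetical Adam state for a parameter vector.
mutable struct AdamState
    m::Vector{Float64}  # running mean of the gradient (first moment)
    v::Vector{Float64}  # running mean of the squared gradient (second moment)
    t::Int              # number of steps taken so far
end
AdamState(n::Integer) = AdamState(zeros(n), zeros(n), 0)

# One Adam update of θ in place, given the gradient g.
function adam_step!(θ, g, s::AdamState; η = 1e-3, β1 = 0.9, β2 = 0.999, ε = 1e-8)
    s.t += 1
    s.m .= β1 .* s.m .+ (1 - β1) .* g
    s.v .= β2 .* s.v .+ (1 - β2) .* g .^ 2
    m̂ = s.m ./ (1 - β1^s.t)  # bias correction for the zero initialization
    v̂ = s.v ./ (1 - β2^s.t)
    θ .-= η .* m̂ ./ (sqrt.(v̂) .+ ε)
    return θ
end

# Illustrative toy problem: minimize L(θ) = ‖θ - target‖².
target = [1.0, -2.0]
loss(θ) = sum(abs2, θ .- target)
grad(θ) = 2 .* (θ .- target)

θ = zeros(2)
state = AdamState(length(θ))
losses = Float64[]
for _ in 1:5000
    adam_step!(θ, grad(θ), state; η = 1e-2)
    push!(losses, loss(θ))
end
```

Because each coordinate is divided by its own gradient scale, the step length is roughly bounded by $\eta$ regardless of how large the raw gradient is.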
A similar page can be found here.
[1] We use the implementation of the coupled harmonic oscillator from GeometricProblems.