Linear Symplectic Transformer
In this section we compare the linear symplectic transformer to the standard transformer. The example we treat here is the coupled harmonic oscillator, a Hamiltonian system with
\[H(q_1, q_2, p_1, p_2) = \frac{p_1^2}{2m_1} + \frac{p_2^2}{2m_2} + k_1\frac{q_1^2}{2} + k_2\frac{q_2^2}{2} + k\sigma(q_1)\frac{(q_2 - q_1)^2}{2},\]
where $\sigma(x) = 1 / (1 + e^{-x})$ is the sigmoid activation function. The system parameters are:
- $k_1$: spring constant belonging to $m_1$,
- $k_2$: spring constant belonging to $m_2$,
- $m_1$: mass 1,
- $m_2$: mass 2,
- $k$: coupling strength between the two masses.
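To make the dynamics concrete, the Hamiltonian above and the vector field it induces via Hamilton's equations can be sketched in plain Julia. The parameter values below are illustrative placeholders, not the ones from `default_parameters`:

```julia
# sigmoid activation entering the coupling term
σ(x) = 1 / (1 + exp(-x))
σ′(x) = σ(x) * (1 - σ(x))

# illustrative parameter values (not necessarily `default_parameters`)
const m₁, m₂, k₁, k₂, k = 2.0, 1.0, 1.5, 0.3, 1.0

# the Hamiltonian of the coupled harmonic oscillator
H(q₁, q₂, p₁, p₂) = p₁^2 / (2m₁) + p₂^2 / (2m₂) +
                    k₁ * q₁^2 / 2 + k₂ * q₂^2 / 2 +
                    k * σ(q₁) * (q₂ - q₁)^2 / 2

# Hamilton's equations: q̇ᵢ = ∂H/∂pᵢ, ṗᵢ = -∂H/∂qᵢ
function vector_field(q₁, q₂, p₁, p₂)
    q̇₁ = p₁ / m₁
    q̇₂ = p₂ / m₂
    ṗ₁ = -k₁ * q₁ - k * (σ′(q₁) * (q₂ - q₁)^2 / 2 - σ(q₁) * (q₂ - q₁))
    ṗ₂ = -k₂ * q₂ - k * σ(q₁) * (q₂ - q₁)
    (q̇₁, q̇₂, ṗ₁, ṗ₂)
end
```

Note that the coupling term $k\sigma(q_1)(q_2 - q_1)^2/2$ makes the system nonlinear: its $q_1$-derivative picks up both a $\sigma'(q_1)$ term and the usual restoring force.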
To demonstrate the efficacy of the linear symplectic transformer, we keep the system parameters fixed and vary only the initial conditions[1]:
```julia
using GeometricProblems.CoupledHarmonicOscillator: hodeensemble, default_parameters

const tstep = 0.3
const n_init_con = 5

# ensemble problem with random initial conditions
ep = hodeensemble([rand(2) for _ in 1:n_init_con], [rand(2) for _ in 1:n_init_con]; tstep = tstep)

dl = DataLoader(integrate(ep, ImplicitMidpoint()); suppress_info = true)
```
We now define the architectures and train them:
```julia
const seq_length = 4
const batch_size = 1024
const n_epochs = 2000

arch_standard = StandardTransformerIntegrator(dl.input_dim; n_heads = 2,
                                                            L = 1,
                                                            n_blocks = 2)

arch_symplectic = LinearSymplecticTransformer(dl.input_dim,
                                              seq_length;
                                              n_sympnet = 2,
                                              L = 1,
                                              upscaling_dimension = 2 * dl.input_dim)

arch_sympnet = GSympNet(dl.input_dim; n_layers = 4,
                                      upscaling_dimension = 2 * dl.input_dim)

nn_standard = NeuralNetwork(arch_standard)
nn_symplectic = NeuralNetwork(arch_symplectic)
nn_sympnet = NeuralNetwork(arch_sympnet)

o_method = AdamOptimizerWithDecay(n_epochs, Float64)

o_standard = Optimizer(o_method, nn_standard)
o_symplectic = Optimizer(o_method, nn_symplectic)
o_sympnet = Optimizer(o_method, nn_sympnet)

# the transformers train on sequences of length `seq_length`;
# the SympNet trains on single time steps
batch = Batch(batch_size, seq_length)
batch2 = Batch(batch_size)

loss_array_standard = o_standard(nn_standard, dl, batch, n_epochs; show_progress = false)
loss_array_symplectic = o_symplectic(nn_symplectic, dl, batch, n_epochs; show_progress = false)
loss_array_sympnet = o_sympnet(nn_sympnet, dl, batch2, n_epochs; show_progress = false)
```
And the corresponding training losses look as follows:
We further evaluate a trajectory with the trained networks for thirty time steps:
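The evaluation proceeds autoregressively: the network receives a window of past states, predicts the next one, and the prediction is appended to the window. A minimal sketch of this rollout logic, with a hypothetical one-step predictor standing in for the trained networks:

```julia
# autoregressive rollout: repeatedly feed the last `seq_length` states to a
# one-step predictor and append its output. `predict` is a hypothetical
# stand-in for a trained network.
function rollout(predict, init_window::Matrix, n_steps::Int)
    traj = copy(init_window)                  # dim × seq_length
    for _ in 1:n_steps
        window = traj[:, end - size(init_window, 2) + 1:end]
        next = predict(window)                # dim-vector: the predicted next state
        traj = hcat(traj, next)
    end
    traj
end

# toy predictor: returns the last column unchanged (a constant continuation)
predict_const(window) = window[:, end]

traj = rollout(predict_const, rand(4, 4), 30) # 4 initial states + 30 predicted
```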
We can see that the standard transformer fails to stay close to the reference trajectory computed with implicit midpoint. The linear symplectic transformer outperforms both the standard transformer and the SympNet while requiring fewer parameters than the standard transformer:
```julia
parameterlength(nn_standard), parameterlength(nn_symplectic), parameterlength(nn_sympnet)
```

```
(108, 84, 64)
```
It is also interesting to note that the training error of the SympNet drops below that of the linear symplectic transformer, yet the SympNet does not outperform it on the validation trajectory.
[1] We here use the implementation of the coupled harmonic oscillator from GeometricProblems.