SympNets with GeometricMachineLearning
This page serves as a short introduction to using SympNets with GeometricMachineLearning.jl. For the general theory see the theory section.
With GeometricMachineLearning.jl one can easily implement SympNets. The steps are the following:
- Specify the architecture with the functions GSympNet and LASympNet,
- Specify the type and the backend with NeuralNetwork,
- Pick an optimizer for training the network,
- Train the neural networks!
We discuss these points in some detail:
Specifying the architecture
To call an $LA$-SympNet, one needs to write
lasympnet = LASympNet(dim; depth=5, nhidden=1, activation=tanh, init_upper_linear=true, init_upper_act=true)
LASympNet takes one obligatory argument:
- dim: the dimension of the phase space (i.e. an integer) or, optionally, an instance of DataLoader. This latter option will be used below.
and several keyword arguments:
- depth: the depth for all the linear layers. The default value is 5 (if depth > 5, depth is set to 5). See the theory section for more details; there depth was called $n$.
- nhidden: the number of pairs of linear and activation layers, with default value 1 (i.e. the $LA$-SympNet is a composition of a linear layer, an activation layer and then again a linear layer).
- activation: the activation function for all the activation layers, with default tanh.
- init_upper_linear: a boolean that indicates whether the first linear layer changes $q$ first. By default this is true.
- init_upper_act: a boolean that indicates whether the first activation layer changes $q$ first. By default this is true.
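The layer types in this list can be illustrated without the library. The following is a minimal plain-Julia sketch, assuming the layer forms from the SympNet literature with biases omitted: an upper linear sublayer maps $(q,p)\mapsto(q+Sp,\,p)$ with $S$ symmetric, an upper activation layer maps $(q,p)\mapsto(q+\mathrm{diag}(a)\sigma(p),\,p)$, and the lower versions change $p$ instead. The helper names (q_of, p_of, linear_up, linear_low, act_up, act_low, la_sympnet) are hypothetical and not part of GeometricMachineLearning.jl, and the library may order the sublayers differently. The script checks numerically that the Jacobian $J$ of the composed map satisfies $J^T\Omega J=\Omega$ with $\Omega=\begin{pmatrix}0&I\\-I&0\end{pmatrix}$, and that a non-symmetric $S$ breaks this property.

```julia
using LinearAlgebra, Random

# Hypothetical helpers (not the GeometricMachineLearning.jl API).
# A phase-space point is stored as x = [q; p].
q_of(x) = x[1:length(x) ÷ 2]
p_of(x) = x[length(x) ÷ 2 + 1:end]

# linear sublayers; S must be symmetric for the map to be symplectic
linear_up(x, S)  = vcat(q_of(x) + S * p_of(x), p_of(x))   # changes q
linear_low(x, S) = vcat(q_of(x), p_of(x) + S * q_of(x))   # changes p

# activation layers with scaling vector a and σ = tanh
act_up(x, a)  = vcat(q_of(x) + a .* tanh.(p_of(x)), p_of(x))   # changes q
act_low(x, a) = vcat(q_of(x), p_of(x) + a .* tanh.(q_of(x)))   # changes p

# central finite-difference Jacobian of f at x
function jacobian_fd(f, x; h = 1e-6)
    n = length(x)
    J = zeros(n, n)
    for j in 1:n
        e = zeros(n); e[j] = h
        J[:, j] = (f(x + e) - f(x - e)) / (2h)
    end
    return J
end

# canonical symplectic matrix Ω = [0 I; -I 0]
function symplectic_matrix(d)
    Id = Matrix{Float64}(I, d, d)
    Z = zeros(d, d)
    return [Z Id; -Id Z]
end

Random.seed!(0)
d = 2
sym(A) = (A + A') / 2                 # symmetric part of a matrix
S = [sym(randn(d, d)) for _ in 1:6]   # parameters of two linear blocks of depth 3
a = randn(d)                          # scaling of the activation layer

# nhidden = 1: linear block (depth 3), activation layer, linear block (depth 3)
function la_sympnet(x)
    x = linear_up(x, S[1]); x = linear_low(x, S[2]); x = linear_up(x, S[3])
    x = act_up(x, a)
    x = linear_up(x, S[4]); x = linear_low(x, S[5]); x = linear_up(x, S[6])
    return x
end

Ωd = symplectic_matrix(d)
J = jacobian_fd(la_sympnet, randn(2d))
err_sym = norm(J' * Ωd * J - Ωd) / max(1, norm(J)^2)
println("symmetric S:     ", err_sym)      # close to zero

# a non-symmetric S breaks symplecticity
S_bad = [0.0 1.0; 0.0 0.0]
J_bad = jacobian_fd(x -> linear_up(x, S_bad), randn(2d))
err_nonsym = norm(J_bad' * Ωd * J_bad - Ωd)
println("non-symmetric S: ", err_nonsym)   # clearly nonzero
```

The depth of 3 (instead of the default 5) only keeps the sketch short; the symplecticity argument is the same for any depth.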
G-SympNet
To call a G-SympNet, one needs to write
gsympnet = GSympNet(dim; upscaling_dimension=2*dim, n_layers=2, activation=tanh, init_upper=true)
GSympNet takes one obligatory argument:
- dim: the dimension of the phase space (i.e. an integer) or, optionally, an instance of DataLoader. This latter option will be used below.
and several keyword arguments:
- upscaling_dimension: the first dimension of the matrix with which the input is multiplied, with default 2*dim. In the theory section this matrix is called $K$ and the upscaling dimension is called $m$.
- n_layers: the number of gradient layers, with default value 2.
- activation: the activation function for all the activation layers, with default tanh.
- init_upper: a boolean that indicates whether the first gradient layer changes $q$ first. By default this is true.
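Similarly, the gradient layers of a G-SympNet can be sketched in plain Julia. This sketch assumes the form $(q,p)\mapsto(q+K^T\mathrm{diag}(a)\sigma(Kp+b),\,p)$ for an upper layer and the analogous update of $p$ for a lower layer, with $K\in\mathbb{R}^{m\times d}$ and $m$ the upscaling dimension. For $\sigma=\tanh$ the update is the gradient of the scalar potential $\sum_i a_i\log\cosh((Kp+b)_i)$, which is why each layer is symplectic. The helper names (grad_up, grad_low, g_sympnet) are hypothetical, not the library's API.

```julia
using LinearAlgebra, Random

# Hypothetical helpers (not the GeometricMachineLearning.jl API).
# K is an m×d matrix (m = upscaling dimension); a and b are m-vectors.
q_of(x) = x[1:length(x) ÷ 2]
p_of(x) = x[length(x) ÷ 2 + 1:end]

# upper gradient layer: changes q by the p-gradient of a scalar potential
grad_up(x, K, a, b)  = vcat(q_of(x) + K' * (a .* tanh.(K * p_of(x) .+ b)), p_of(x))
# lower gradient layer: changes p by the q-gradient of a scalar potential
grad_low(x, K, a, b) = vcat(q_of(x), p_of(x) + K' * (a .* tanh.(K * q_of(x) .+ b)))

# central finite-difference Jacobian of f at x
function jacobian_fd(f, x; h = 1e-6)
    n = length(x)
    J = zeros(n, n)
    for j in 1:n
        e = zeros(n); e[j] = h
        J[:, j] = (f(x + e) - f(x - e)) / (2h)
    end
    return J
end

# canonical symplectic matrix Ω = [0 I; -I 0]
function symplectic_matrix(d)
    Id = Matrix{Float64}(I, d, d)
    Z = zeros(d, d)
    return [Z Id; -Id Z]
end

Random.seed!(1)
d, m, n_layers = 1, 2, 4   # d = 1 as for the pendulum, m = upscaling dimension
params = [(randn(m, d), randn(m), randn(m)) for _ in 1:n_layers]

# alternate upper and lower layers, starting with an upper one (cf. init_upper = true)
function g_sympnet(x)
    for (i, (K, a, b)) in enumerate(params)
        x = isodd(i) ? grad_up(x, K, a, b) : grad_low(x, K, a, b)
    end
    return x
end

Ωd = symplectic_matrix(d)
J = jacobian_fd(g_sympnet, randn(2d))
err_g = norm(J' * Ωd * J - Ωd) / max(1, norm(J)^2)
println(err_g)   # close to zero: the composition is symplectic
```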
Loss function
The loss function described in the theory section is the default choice used in GeometricMachineLearning.jl for training SympNets.
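As an illustration only, the sketch below assumes a one-step prediction loss: the network should map each snapshot to the next one, and the error is normalized by the size of the targets. The function name one_step_loss is hypothetical, and the exact normalization used by the library is the one given in the theory section. The toy data come from the exact flow of the harmonic oscillator, so the exact flow map has (numerically) zero loss, while the identity map does not.

```julia
using LinearAlgebra

# Hypothetical one-step prediction loss (not a library function).
# Columns of Q and P are consecutive snapshots in time.
function one_step_loss(model, Q::AbstractMatrix, P::AbstractMatrix)
    X = vcat(Q, P)                                   # each column is a point [q; p]
    inputs, targets = X[:, 1:end-1], X[:, 2:end]
    predictions = reduce(hcat, [model(inputs[:, i]) for i in 1:size(inputs, 2)])
    return norm(predictions - targets) / norm(targets)   # relative error
end

# toy data from the exact time-h flow of H = (q² + p²)/2
h = 0.1
R = [cos(h) sin(h); -sin(h) cos(h)]      # flow map acting on [q; p]
X = reduce(hcat, [R^k * [1.0, 0.0] for k in 0:50])
Q, P = X[1:1, :], X[2:2, :]

loss_exact = one_step_loss(x -> R * x, Q, P)    # ≈ 0: the model is the exact flow
loss_identity = one_step_loss(identity, Q, P)   # ≈ 2 sin(h/2): the model does not move
println((loss_exact, loss_identity))
```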
Examples
Let us see how to use it on several examples.
Example of a pendulum with G-SympNet
Let us begin with a simple example, the pendulum system, the Hamiltonian of which is
\[H:(q,p)\in\mathbb{R}^2 \mapsto \frac{1}{2}p^2-\cos(q) \in \mathbb{R}.\]
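To connect this Hamiltonian with the kind of data used below, here is a sketch that generates a pendulum trajectory with the symplectic Euler method. This is an assumption for illustration: the script pendulum.jl used below may rely on a different integrator, and the step size h = 0.1 and the number of steps are arbitrary choices. Here q and p are stored as $1\times n$ matrices with one column per time step. Because the method is symplectic, the energy oscillates slightly but does not drift.

```julia
# Pendulum Hamiltonian H(q, p) = p²/2 - cos(q)
H(q, p) = p^2 / 2 - cos(q)

# Symplectic Euler: update p with the old q, then q with the new p.
# Step size h and number of steps n are illustrative choices.
function pendulum_symplectic_euler(q0, p0; h = 0.1, n = 1000)
    q = zeros(1, n)                      # one column per time step
    p = zeros(1, n)
    q[1], p[1] = q0, p0
    for i in 1:n-1
        p[i+1] = p[i] - h * sin(q[i])    # dp/dt = -∂H/∂q = -sin(q)
        q[i+1] = q[i] + h * p[i+1]       # dq/dt =  ∂H/∂p = p
    end
    return (q = q, p = p)
end

data = pendulum_symplectic_euler(0.0, 1.0)
energies = H.(data.q, data.p)
println(maximum(energies) - minimum(energies))   # small and bounded: no energy drift
```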
Here we generate pendulum data with the script GeometricMachineLearning/scripts/pendulum.jl:
using GeometricMachineLearning, Random
# fix the random seed for reproducibility
Random.seed!(1234)
# load script
include("../../../scripts/pendulum.jl")
# specify the data type
type = Float16
# get data
qp_data = GeometricMachineLearning.apply_toNT(a -> type.(a), pendulum_data((q=[0.], p=[1.]); tspan=(0.,100.)))
# call the DataLoader
dl = DataLoader(qp_data)
[ Info: You have provided a NamedTuple with keys q and p; the data are matrices. This is interpreted as *symplectic data*.
Next we specify the architectures. GeometricMachineLearning.jl provides useful defaults for all parameters, although they can be specified manually (which is done in the following):
# upscaling dimension of the gradient layers (G-SympNet)
const upscaling_dimension = 2
# number of pairs of linear and activation layers (LA-SympNet)
const nhidden = 1
# activation function
const activation = tanh
# calling G-SympNet architecture
gsympnet = GSympNet(dl, upscaling_dimension=upscaling_dimension, n_layers=4, activation=activation)
# calling LA-SympNet architecture
lasympnet = LASympNet(dl, nhidden=nhidden, activation=activation)
# specify the backend
const backend = CPU()
# initialize the networks
la_nn = NeuralNetwork(lasympnet, backend, type)
g_nn = NeuralNetwork(gsympnet, backend, type)
If we want to obtain the number of parameters in a neural network, we can do that very simply with the function parameterlength. For the LASympNet:
parameterlength(la_nn.model)
14
And for the GSympNet:
parameterlength(g_nn.model)
12
Remark: We can also specify whether we would like to start with a layer that changes the $q$-component or one that changes the $p$-component. This can be done via the keyword init_upper for GSympNet, and init_upper_linear and init_upper_act for LASympNet.
We have to define an optimizer which will be used in the training of the SympNet. For more details on optimizers, please see the corresponding documentation. In this example we use Adam:
# set up the optimizer; for this we first need to specify the optimization method
opt_method = AdamOptimizer(type)
la_opt = Optimizer(opt_method, la_nn)
g_opt = Optimizer(opt_method, g_nn)
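For readers unfamiliar with Adam, the following sketch implements the textbook update rule of Kingma and Ba on a toy quadratic. It is not the implementation behind AdamOptimizer. The names AdamState and adam_step! are hypothetical, and the hyperparameters (η = 1e-3, β1 = 0.9, β2 = 0.999, ε = 1e-8) are the commonly quoted defaults, which may differ from the library's.

```julia
using LinearAlgebra

# Hypothetical optimizer state for the textbook Adam update
mutable struct AdamState
    m::Vector{Float64}   # first-moment estimate
    v::Vector{Float64}   # second-moment estimate
    t::Int               # step counter
end
AdamState(n::Int) = AdamState(zeros(n), zeros(n), 0)

function adam_step!(θ, g, s::AdamState; η = 1e-3, β1 = 0.9, β2 = 0.999, ε = 1e-8)
    s.t += 1
    s.m .= β1 .* s.m .+ (1 - β1) .* g          # running mean of gradients
    s.v .= β2 .* s.v .+ (1 - β2) .* g .^ 2     # running mean of squared gradients
    mhat = s.m ./ (1 - β1^s.t)                 # bias corrections
    vhat = s.v ./ (1 - β2^s.t)
    θ .-= η .* mhat ./ (sqrt.(vhat) .+ ε)      # parameter update (in place)
    return θ
end

# minimize f(θ) = ‖θ - c‖²/2, whose gradient is θ - c
c = [1.0, -2.0]
θ = zeros(2)
s = AdamState(2)
for _ in 1:5000
    adam_step!(θ, θ - c, s; η = 1e-2)
end
println(norm(θ - c))   # small: θ has converged close to c
```

The per-coordinate division by the square root of the second-moment estimate is what makes Adam's effective step size roughly independent of the gradient's scale.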
We can now perform the training of the neural networks. The syntax is the following:
# number of training epochs
const nepochs = 300
# Batchsize used to compute the gradient of the loss function with respect to the parameters of the neural networks.
const batch_size = 100
batch = Batch(batch_size)
# perform training (returns array that contains the total loss for each training step)
g_loss_array = g_opt(g_nn, dl, batch, nepochs)
la_loss_array = la_opt(la_nn, dl, batch, nepochs)
During training, a progress bar reports the current training loss of each network. The final output of the two runs (first the G-SympNet, then the LA-SympNet) is:
Progress: 100%|█████████████████████████████████████████| Time: 0:00:28
TrainingLoss: 0.0008719476867063491
Progress: 100%|█████████████████████████████████████████| Time: 0:00:19
TrainingLoss: 0.18294885931747412
We can also plot the training errors against the epoch (here the $y$-axis is in log-scale):
using Plots
p1 = plot(g_loss_array, xlabel="Epoch", ylabel="Training error", label="G-SympNet", color=3, yaxis=:log)
plot!(p1, la_loss_array, label="LA-SympNet", color=2)
The training data qp_data.q and qp_data.p are matrices in $\mathbb{R}^{d\times n}$, where $d$ is half the dimension of the phase space and $n$ is the number of time steps, i.e. qp_data.q[j,i] is $q_j(t_i)$, where $(t_1,\ldots,t_n)$ are the times at which the training data were sampled.
Now we can make a prediction. Let's compare the initial data with a prediction starting from the same phase space point using the function iterate:
ics = (q=qp_data.q[:,1], p=qp_data.p[:,1])
steps_to_plot = 200
# predictions
la_trajectory = iterate(la_nn, ics; n_points = steps_to_plot)
g_trajectory = iterate(g_nn, ics; n_points = steps_to_plot)
using Plots
p2 = plot(qp_data.q'[1:steps_to_plot], qp_data.p'[1:steps_to_plot], label="training data")
plot!(p2, la_trajectory.q', la_trajectory.p', label="LA Sympnet")
plot!(p2, g_trajectory.q', g_trajectory.p', label="G Sympnet")
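The following plain-Julia sketch shows what iterating a learned one-step map amounts to. The helper iterate_map is hypothetical and is not the library's iterate; in particular, storing the initial condition as the first column is an assumption. A symplectic Euler step of the pendulum stands in for a trained network.

```julia
# Hypothetical helper, not the library's `iterate`: repeatedly applies a
# one-step map and stores the trajectory with one column per time step.
function iterate_map(step, ics::NamedTuple; n_points::Int)
    d = length(ics.q)
    q = zeros(d, n_points)
    p = zeros(d, n_points)
    q[:, 1] = ics.q                      # assumed: initial condition is the first column
    p[:, 1] = ics.p
    for i in 1:n_points-1
        x = step(vcat(q[:, i], p[:, i])) # apply the (learned) map to [q; p]
        q[:, i+1] = x[1:d]
        p[:, i+1] = x[d+1:end]
    end
    return (q = q, p = p)
end

# stand-in for a trained network: one symplectic Euler step of the pendulum (d = 1)
function pendulum_step(x; h = 0.1)
    p_new = x[2] - h * sin(x[1])
    return [x[1] + h * p_new, p_new]
end

traj = iterate_map(pendulum_step, (q = [0.0], p = [1.0]); n_points = 200)
println(size(traj.q))   # (1, 200)
```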
We see that the GSympNet outperforms the LASympNet on this problem, which is also reflected in its much lower final training loss.