SympNets with GeometricMachineLearning

This page serves as a short introduction to using SympNets with GeometricMachineLearning.jl. For the general theory see the theory section.

With GeometricMachineLearning.jl one can easily implement SympNets. The steps are the following:

  • Specify the architecture with the functions GSympNet and LASympNet,
  • Specify the type and the backend with NeuralNetwork,
  • Pick an optimizer for training the network,
  • Train the neural networks!

We discuss these points in some detail:

Specifying the architecture

LA-SympNet

To call an $LA$-SympNet, one needs to write

lasympnet = LASympNet(dim; depth=5, nhidden=1, activation=tanh, init_upper_linear=true, init_upper_act=true) 

LASympNet takes one obligatory argument:

  • dim: the dimension of the phase space (i.e. an integer) or optionally an instance of DataLoader. This latter option will be used below.

and several keyword arguments (a short usage sketch follows the list):

  • depth: the depth of all the linear layers. The default value is set to 5 (if a depth greater than 5 is supplied, it is set to 5). See the theory section for more details; there the depth was called $n$.
  • nhidden: the number of pairs of linear and activation layers, with the default value set to 1 (i.e. the $LA$-SympNet is a composition of a linear layer, an activation layer and then again a linear layer).
  • activation: the activation function for all the activation layers, with the default set to tanh.
  • init_upper_linear: a boolean that indicates whether the first linear layer changes $q$ first. By default this is true.
  • init_upper_act: a boolean that indicates whether the first activation layer changes $q$ first. By default this is true.
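For instance, an $LA$-SympNet can also be constructed directly from an integer dim instead of a DataLoader (the keyword values below are purely illustrative, not recommendations):

# an LA-SympNet for a 4-dimensional phase space with explicit keywords
lasympnet4 = LASympNet(4; depth=3, nhidden=2, activation=tanh, init_upper_linear=false)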

G-SympNet

To call a G-SympNet, one needs to write

gsympnet = GSympNet(dim; upscaling_dimension=2*dim, n_layers=2, activation=tanh, init_upper=true) 

GSympNet takes one obligatory argument:

  • dim: the dimension of the phase space (i.e. an integer) or optionally an instance of DataLoader. This latter option will be used below.

and several keyword arguments (see the sketch after the list):

  • upscaling_dimension: the first dimension of the matrix with which the input is multiplied. In the theory section this matrix is called $K$ and the upscaling dimension is called $m$.
  • n_layers: the number of gradient layers, with the default value set to 2.
  • activation: the activation function for all the activation layers, with the default set to tanh.
  • init_upper: a boolean that indicates whether the first gradient layer changes $q$ first. By default this is true.
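For instance (again with illustrative values; a larger upscaling_dimension gives each gradient layer more parameters):

# a G-SympNet for a 4-dimensional phase space with a wider upscaling
gsympnet4 = GSympNet(4; upscaling_dimension=16, n_layers=2, activation=tanh)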

Loss function

The loss function described in the theory section is the default choice used in GeometricMachineLearning.jl for training SympNets.
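As a minimal sketch of this quantity (assuming the relative-error form discussed in the theory section; this is not the package-internal implementation):

using LinearAlgebra: norm

# sketch (assumption): relative distance between prediction and target
relative_loss(z_pred, z_target) = norm(z_pred .- z_target) / norm(z_target)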

Examples

Let us see how to use SympNets on a concrete example.

Example: the pendulum

Let us begin with a simple example, the pendulum system, the Hamiltonian of which is

\[H:(q,p)\in\mathbb{R}^2 \mapsto \frac{1}{2}p^2-\cos(q) \in \mathbb{R}.\]
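The data below are obtained by integrating the corresponding Hamiltonian vector field,

\[\dot{q} = \frac{\partial H}{\partial p} = p, \qquad \dot{p} = -\frac{\partial H}{\partial q} = -\sin(q).\]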

Here we generate pendulum data with the script GeometricMachineLearning/scripts/pendulum.jl:

using GeometricMachineLearning
using Random

Random.seed!(1234)

# load script
include("../../../scripts/pendulum.jl")
# specify the data type
type = Float16
# get data
qp_data = GeometricMachineLearning.apply_toNT(a -> type.(a), pendulum_data((q=[0.], p=[1.]); tspan=(0.,100.)))
# call the DataLoader
dl = DataLoader(qp_data)
[ Info: You have provided a NamedTuple with keys q and p; the data are matrices. This is interpreted as *symplectic data*.

Next we specify the architectures. GeometricMachineLearning.jl provides useful defaults for all parameters although they can be specified manually (which is done in the following):

# layer dimension for gradient module
const upscaling_dimension = 2
# hidden layers
const nhidden = 1
# activation function
const activation = tanh

# calling G-SympNet architecture
gsympnet = GSympNet(dl, upscaling_dimension=upscaling_dimension, n_layers=4, activation=activation)

# calling LA-SympNet architecture
lasympnet = LASympNet(dl, nhidden=nhidden, activation=activation)

# specify the backend
const backend = CPU()

# initialize the networks
la_nn = NeuralNetwork(lasympnet, backend, type)
g_nn = NeuralNetwork(gsympnet, backend, type)

If we want to obtain information on the number of parameters in a neural network, we can do that very simply with the function parameterlength. For the LASympNet:

parameterlength(la_nn.model)
14

And for the GSympNet:

parameterlength(g_nn.model)
12

Remark: We can also specify whether we would like to start with a layer that changes the $q$-component or one that changes the $p$-component. This can be done via the keywords init_upper for GSympNet, and init_upper_linear and init_upper_act for LASympNet.
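For instance, hypothetical variants of the networks above that change the $p$-component first would read:

# gradient layers change p first
gsympnet_lower = GSympNet(dl, upscaling_dimension=upscaling_dimension, n_layers=4, activation=activation, init_upper=false)
# linear and activation layers change p first
lasympnet_lower = LASympNet(dl, nhidden=nhidden, activation=activation, init_upper_linear=false, init_upper_act=false)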

We have to define an optimizer which will be used in the training of the SympNet. For more details on optimizers, please see the corresponding documentation. In this example we use Adam:

# set up the optimizers; for this we first need to specify the optimization method
opt_method = AdamOptimizer(type)
la_opt = Optimizer(opt_method, la_nn)
g_opt = Optimizer(opt_method, g_nn)

We can now perform the training of the neural networks. The syntax is the following:

# number of training epochs
const nepochs = 300
# batch size used to compute the gradient of the loss function with respect to the parameters of the neural networks
const batch_size = 100

batch = Batch(batch_size)

# perform training (returns an array that contains the total loss for each training step)
g_loss_array = g_opt(g_nn, dl, batch, nepochs)
la_loss_array = la_opt(la_nn, dl, batch, nepochs)

Progress: 100%|█████████████████████████████████████████| Time: 0:00:28
  TrainingLoss:  0.0008719476867063491

Progress: 100%|█████████████████████████████████████████| Time: 0:00:19
  TrainingLoss:  0.18294885931747412

We can also plot the training errors against the epoch (here the $y$-axis is in log-scale):

using Plots
p1 = plot(g_loss_array, xlabel="Epoch", ylabel="Training error", label="G-SympNet", color=3, yaxis=:log)
plot!(p1, la_loss_array, label="LA-SympNet", color=2)
[Plot: training error per epoch for the G-SympNet and the LA-SympNet, with the $y$-axis in log-scale.]

The training data data_q and data_p must be matrices in $\mathbb{R}^{d\times n}$, where $d$ is half the dimension of the system and $n$ is the number of time steps, i.e. data_q[i,j] is $q_i(t_j)$, where $(t_1,\ldots,t_n)$ are the times of the training data.
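As a small sketch of this layout (with hypothetical toy values), such matrices can be passed to DataLoader as a NamedTuple, exactly as was done above:

# one degree of freedom (d = 1) and five time steps (n = 5)
data_q = rand(1, 5)   # data_q[i, j] is q_i(t_j)
data_p = rand(1, 5)
dl_manual = DataLoader((q = data_q, p = data_p))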

Now we can make a prediction. Let's compare the initial data with a prediction starting from the same phase space point using the function iterate:

ics = (q=qp_data.q[:,1], p=qp_data.p[:,1])

steps_to_plot = 200

# predictions
la_trajectory = iterate(la_nn, ics; n_points = steps_to_plot)
g_trajectory =  iterate(g_nn, ics; n_points = steps_to_plot)

using Plots
p2 = plot(qp_data.q'[1:steps_to_plot], qp_data.p'[1:steps_to_plot], label="training data")
plot!(p2, la_trajectory.q', la_trajectory.p', label="LA Sympnet")
plot!(p2, g_trajectory.q', g_trajectory.p', label="G Sympnet")
[Plot: training data in phase space together with the LA-SympNet and G-SympNet predictions.]

We see that the G-SympNet outperforms the LA-SympNet on this problem.