Symbolic Neural Networks
When using a symbolic neural network we can either use architectures from GeometricMachineLearning or build networks from simpler blocks. We first construct a symbolic neural network based on a small chain of three dense layers:
using SymbolicNeuralNetworks
using AbstractNeuralNetworks: Chain, Dense, params
input_dim = 2
output_dim = 1
hidden_dim = 3
c = Chain(Dense(input_dim, hidden_dim), Dense(hidden_dim, hidden_dim), Dense(hidden_dim, output_dim))
nn = SymbolicNeuralNetwork(c)
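The parameters of nn are themselves symbolic arrays. A minimal sketch for inspecting them, assuming they support the same L1.W-style field access as the gradient object further below:
ps = params(nn)  # symbolic parameters of the network
ps.L1.W          # symbolic weight matrix of the first layer
ps.L1.b          # symbolic bias vector of the first layer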
We can now build symbolic expressions based on this neural network. Here we do so by evaluating the model on a symbolic input:
using Symbolics
using Latexify: latexify
@variables sinput[1:input_dim]
soutput = nn.model(sinput, params(nn))
soutput
\[ \begin{equation} \mathrm{broadcast}\left( \tanh, \mathrm{broadcast}\left( +, \mathtt{W_{5}} \mathrm{broadcast}\left( \tanh, \mathrm{broadcast}\left( +, \mathtt{W_{3}} \mathrm{broadcast}\left( \tanh, \mathrm{broadcast}\left( +, \mathtt{W_{1}} \mathtt{sinput}, \mathtt{W_{2}} \right) \right), \mathtt{W_{4}} \right) \right), \mathtt{W_{6}} \right) \right) \end{equation} \]
or use Symbolics.scalarize to get a more readable version of the equation:
soutput |> Symbolics.scalarize
\[ \begin{equation} \left[ \begin{array}{c} \tanh\left( \mathtt{W\_6}_{1} + \mathtt{W\_5}_{1,1} \tanh\left( \mathtt{W\_4}_{1} + \mathtt{W\_3}_{1,1} \tanh\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,2} \tanh\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,3} \tanh\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_5}_{1,2} \tanh\left( \mathtt{W\_4}_{2} + \mathtt{W\_3}_{2,1} \tanh\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,2} \tanh\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,3} \tanh\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_5}_{1,3} \tanh\left( \mathtt{W\_4}_{3} + \mathtt{W\_3}_{3,1} \tanh\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,2} \tanh\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,3} \tanh\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) \\ \end{array} \right] \end{equation} \]
We can compute the symbolic gradient with SymbolicNeuralNetworks.Gradient:
using SymbolicNeuralNetworks: derivative
derivative(SymbolicNeuralNetworks.Gradient(soutput, nn))[1].L1.b
\[ \begin{equation} \left[ \begin{array}{c} \left( \mathtt{W\_3}_{1,1} \mathtt{W\_5}_{1,1} \left( 1 - \tanh^{2}\left( \mathtt{W\_4}_{1} + \mathtt{W\_3}_{1,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_3}_{2,1} \mathtt{W\_5}_{1,2} \left( 1 - \tanh^{2}\left( \mathtt{W\_4}_{2} + \mathtt{W\_3}_{2,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_3}_{3,1} \mathtt{W\_5}_{1,3} \left( 1 - \tanh^{2}\left( \mathtt{W\_4}_{3} + \mathtt{W\_3}_{3,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_6}_{1} + \mathtt{W\_5}_{1,1} \tanh^{2}\left( \mathtt{W\_4}_{1} + \mathtt{W\_3}_{1,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_5}_{1,2} \tanh^{2}\left( \mathtt{W\_4}_{2} + \mathtt{W\_3}_{2,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_5}_{1,3} \tanh^{2}\left( \mathtt{W\_4}_{3} + \mathtt{W\_3}_{3,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + 
\mathtt{W\_3}_{3,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) \right) \\ \left( \mathtt{W\_3}_{1,2} \mathtt{W\_5}_{1,1} \left( 1 - \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_4}_{1} + \mathtt{W\_3}_{1,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) + \mathtt{W\_3}_{2,2} \mathtt{W\_5}_{1,2} \left( 1 - \tanh^{2}\left( \mathtt{W\_4}_{2} + \mathtt{W\_3}_{2,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_3}_{3,2} \mathtt{W\_5}_{1,3} \left( 1 - \tanh^{2}\left( \mathtt{W\_4}_{3} + \mathtt{W\_3}_{3,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_6}_{1} + \mathtt{W\_5}_{1,1} \tanh^{2}\left( \mathtt{W\_4}_{1} + \mathtt{W\_3}_{1,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_5}_{1,2} \tanh^{2}\left( \mathtt{W\_4}_{2} + \mathtt{W\_3}_{2,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_5}_{1,3} \tanh^{2}\left( \mathtt{W\_4}_{3} + \mathtt{W\_3}_{3,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,2} 
\tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) \right) \\ \left( \mathtt{W\_3}_{1,3} \mathtt{W\_5}_{1,1} \left( 1 - \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_4}_{1} + \mathtt{W\_3}_{1,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) + \mathtt{W\_3}_{2,3} \mathtt{W\_5}_{1,2} \left( 1 - \tanh^{2}\left( \mathtt{W\_4}_{2} + \mathtt{W\_3}_{2,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_3}_{3,3} \mathtt{W\_5}_{1,3} \left( 1 - \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_4}_{3} + \mathtt{W\_3}_{3,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_6}_{1} + \mathtt{W\_5}_{1,1} \tanh^{2}\left( \mathtt{W\_4}_{1} + \mathtt{W\_3}_{1,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_5}_{1,2} \tanh^{2}\left( \mathtt{W\_4}_{2} + \mathtt{W\_3}_{2,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_5}_{1,3} \tanh^{2}\left( \mathtt{W\_4}_{3} + \mathtt{W\_3}_{3,1} \tanh^{2}\left( 
\mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) \right) \\ \end{array} \right] \end{equation} \]
SymbolicNeuralNetworks.Gradient can also be called without providing a specific output, as SymbolicNeuralNetworks.Gradient(nn). In this case soutput is taken to be the symbolic output of the network, i.e. the construction is equivalent to the one presented here. Also note that we additionally called SymbolicNeuralNetworks.derivative above in order to get the symbolic gradient (as opposed to the symbolic output of the neural network).
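A minimal sketch of this default construction (the access pattern [1].L1.W mirrors the gradient access above):
g = SymbolicNeuralNetworks.Gradient(nn)  # soutput defaults to the symbolic network output
derivative(g)[1].L1.W                    # symbolic gradient with respect to the first weight matrix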
In order to train a SymbolicNeuralNetwork
we can use:
pb = SymbolicPullback(nn)
SymbolicNeuralNetworks.Gradient and SymbolicPullback both use the function SymbolicNeuralNetworks.symbolic_pullback internally, so they are computationally equivalent. SymbolicPullback should however be used in connection with a NetworkLoss, whereas SymbolicNeuralNetworks.Gradient can be used more generally to compute the derivative of array-valued expressions.
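For example, Gradient can differentiate an arbitrary array-valued expression built from the network output. A minimal sketch (squaring the output is our own illustrative choice, not part of the original example):
squared_output = soutput .^ 2                        # any array-valued symbolic expression
g2 = SymbolicNeuralNetworks.Gradient(squared_output, nn)
derivative(g2)[1].L1.b                               # its symbolic gradient w.r.t. the first bias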
We want to use our neural network to approximate a Gaussian on the domain $[-1, 1]\times[-1, 1]$. We first generate the data for this task:
using GeometricMachineLearning
x_vec = -1.:.1:1.
y_vec = -1.:.1:1.
xy_data = hcat([[x, y] for x in x_vec, y in y_vec]...)
f(x::Vector) = exp.(-sum(x.^2))
z_data = mapreduce(i -> f(xy_data[:, i]), hcat, axes(xy_data, 2))
dl = DataLoader(xy_data, z_data)
[ Info: You have provided an input and an output.
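As a quick sanity check of the generated data (the sizes follow from the 21 × 21 grid above; this check is our own addition):
size(xy_data)  # (2, 441): each column is one grid point
size(z_data)   # (1, 441): the corresponding function values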
Note that we use GeometricMachineLearning.DataLoader to process the data. We also visualize the data:
using CairoMakie
fig = Figure()
ax = Axis3(fig[1, 1])
surface!(x_vec, y_vec, [f([x, y]) for x in x_vec, y in y_vec]; alpha = .8, transparency = true)
fig

We now train the network:
nn_cpu = NeuralNetwork(c, CPU())
o = Optimizer(AdamOptimizer(), nn_cpu)
n_epochs = 1000
batch = Batch(10)
@time o(nn_cpu, dl, batch, n_epochs, pb.loss, pb; show_progress = false);
8.794173 seconds (36.09 M allocations: 1.389 GiB, 1.97% gc time)
We now compare the surface learned by the neural network to the original one:
fig = Figure()
ax = Axis3(fig[1, 1])
surface!(x_vec, y_vec, [c([x, y], params(nn_cpu))[1] for x in x_vec, y in y_vec]; alpha = .8, colormap = :darkterrain, transparency = true)
fig

We can also compare the time it takes to train the SymbolicNeuralNetwork
to the time it takes to train a standard neural network:
loss = FeedForwardLoss()
pb2 = GeometricMachineLearning.ZygotePullback(loss)
@time o(nn_cpu, dl, batch, n_epochs, pb2.loss, pb2; show_progress = false);
1.564433 seconds (25.01 M allocations: 1.213 GiB, 7.00% gc time)
For the case presented here we do not observe a speed-up of the symbolic neural network over the standard neural network. For other cases this is however different, especially for Hamiltonian neural networks, whose loss function itself contains derivatives of the network; there, the symbolic approach can precompute these derivatives.