Symbolic Neural Networks

When using a symbolic neural network, we can use architectures from GeometricMachineLearning or simpler building blocks.

We first construct a symbolic neural network from a chain of three dense layers:

using SymbolicNeuralNetworks
using AbstractNeuralNetworks: Chain, Dense, params

# Layer widths: two inputs, three hidden neurons, one output.
input_dim = 2
output_dim = 1
hidden_dim = 3
# Three dense layers: input_dim → hidden_dim → hidden_dim → output_dim.
c = Chain(Dense(input_dim, hidden_dim), Dense(hidden_dim, hidden_dim), Dense(hidden_dim, output_dim))
# Wrap the chain; its parameters `params(nn)` appear as the symbols W_1, …, W_6 below.
nn = SymbolicNeuralNetwork(c)

We can now build symbolic expressions based on this neural network. Here we do so by applying the network's model to a vector of symbolic inputs and to the network's (symbolic) parameters:

using Symbolics
using Latexify: latexify

# Symbolic input vector with `input_dim` entries.
@variables sinput[1:input_dim]
# Evaluate the model on the symbolic input with the network's symbolic parameters.
soutput = nn.model(sinput, params(nn))

soutput

\[ \begin{equation} \mathrm{broadcast}\left( \tanh, \mathrm{broadcast}\left( +, \mathtt{W_{5}} \mathrm{broadcast}\left( \tanh, \mathrm{broadcast}\left( +, \mathtt{W_{3}} \mathrm{broadcast}\left( \tanh, \mathrm{broadcast}\left( +, \mathtt{W_{1}} \mathtt{sinput}, \mathtt{W_{2}} \right) \right), \mathtt{W_{4}} \right) \right), \mathtt{W_{6}} \right) \right) \end{equation} \]

We can also use Symbolics.scalarize to get a more readable, element-wise version of this expression:

# Expand the broadcasted array expression into explicit scalar entries.
soutput |> Symbolics.scalarize

\[ \begin{equation} \left[ \begin{array}{c} \tanh\left( \mathtt{W\_6}_{1} + \mathtt{W\_5}_{1,1} \tanh\left( \mathtt{W\_4}_{1} + \mathtt{W\_3}_{1,1} \tanh\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,2} \tanh\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,3} \tanh\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_5}_{1,2} \tanh\left( \mathtt{W\_4}_{2} + \mathtt{W\_3}_{2,1} \tanh\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,2} \tanh\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,3} \tanh\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_5}_{1,3} \tanh\left( \mathtt{W\_4}_{3} + \mathtt{W\_3}_{3,1} \tanh\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,2} \tanh\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,3} \tanh\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) \\ \end{array} \right] \end{equation} \]

We can compute the symbolic gradient with SymbolicNeuralNetworks.Gradient. Here we show the gradient of the output with respect to the bias of the first layer (L1.b):

using SymbolicNeuralNetworks: derivative
# Symbolic gradient of `soutput` w.r.t. the network parameters, restricted to the bias `b`
# of the first layer `L1`. NOTE(review): `[1]` presumably selects the first (here the only)
# output entry — confirm.
derivative(SymbolicNeuralNetworks.Gradient(soutput, nn))[1].L1.b

\[ \begin{equation} \left[ \begin{array}{c} \left( \mathtt{W\_3}_{1,1} \mathtt{W\_5}_{1,1} \left( 1 - \tanh^{2}\left( \mathtt{W\_4}_{1} + \mathtt{W\_3}_{1,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_3}_{2,1} \mathtt{W\_5}_{1,2} \left( 1 - \tanh^{2}\left( \mathtt{W\_4}_{2} + \mathtt{W\_3}_{2,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_3}_{3,1} \mathtt{W\_5}_{1,3} \left( 1 - \tanh^{2}\left( \mathtt{W\_4}_{3} + \mathtt{W\_3}_{3,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) \left( 1 - 
\tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_6}_{1} + \mathtt{W\_5}_{1,1} \tanh^{2}\left( \mathtt{W\_4}_{1} + \mathtt{W\_3}_{1,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_5}_{1,2} \tanh^{2}\left( \mathtt{W\_4}_{2} + \mathtt{W\_3}_{2,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_5}_{1,3} \tanh^{2}\left( \mathtt{W\_4}_{3} + \mathtt{W\_3}_{3,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) \right) \\ \left( \mathtt{W\_3}_{1,2} \mathtt{W\_5}_{1,1} \left( 1 - \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_4}_{1} + \mathtt{W\_3}_{1,1} 
\tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) + \mathtt{W\_3}_{2,2} \mathtt{W\_5}_{1,2} \left( 1 - \tanh^{2}\left( \mathtt{W\_4}_{2} + \mathtt{W\_3}_{2,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_3}_{3,2} \mathtt{W\_5}_{1,3} \left( 1 - \tanh^{2}\left( \mathtt{W\_4}_{3} + \mathtt{W\_3}_{3,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_6}_{1} + \mathtt{W\_5}_{1,1} \tanh^{2}\left( \mathtt{W\_4}_{1} + \mathtt{W\_3}_{1,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + 
\mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_5}_{1,2} \tanh^{2}\left( \mathtt{W\_4}_{2} + \mathtt{W\_3}_{2,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_5}_{1,3} \tanh^{2}\left( \mathtt{W\_4}_{3} + \mathtt{W\_3}_{3,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) \right) \\ \left( \mathtt{W\_3}_{1,3} \mathtt{W\_5}_{1,1} \left( 1 - \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_4}_{1} + \mathtt{W\_3}_{1,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,3} \tanh^{2}\left( 
\mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) + \mathtt{W\_3}_{2,3} \mathtt{W\_5}_{1,2} \left( 1 - \tanh^{2}\left( \mathtt{W\_4}_{2} + \mathtt{W\_3}_{2,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_3}_{3,3} \mathtt{W\_5}_{1,3} \left( 1 - \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_4}_{3} + \mathtt{W\_3}_{3,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) \right) \left( 1 - \tanh^{2}\left( \mathtt{W\_6}_{1} + \mathtt{W\_5}_{1,1} \tanh^{2}\left( \mathtt{W\_4}_{1} + \mathtt{W\_3}_{1,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{1,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} 
\mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_5}_{1,2} \tanh^{2}\left( \mathtt{W\_4}_{2} + \mathtt{W\_3}_{2,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{2,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) + \mathtt{W\_5}_{1,3} \tanh^{2}\left( \mathtt{W\_4}_{3} + \mathtt{W\_3}_{3,1} \tanh^{2}\left( \mathtt{W\_2}_{1} + \mathtt{W\_1}_{1,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{1,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,2} \tanh^{2}\left( \mathtt{W\_2}_{2} + \mathtt{W\_1}_{2,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{2,2} \mathtt{sinput}_{2} \right) + \mathtt{W\_3}_{3,3} \tanh^{2}\left( \mathtt{W\_2}_{3} + \mathtt{W\_1}_{3,1} \mathtt{sinput}_{1} + \mathtt{W\_1}_{3,2} \mathtt{sinput}_{2} \right) \right) \right) \right) \\ \end{array} \right] \end{equation} \]

Info

SymbolicNeuralNetworks.Gradient can also be called as SymbolicNeuralNetworks.Gradient(nn), i.e. without providing a specific output. In this case the symbolic output of the network is used as soutput, which is equivalent to the construction presented here. Also note that we additionally called SymbolicNeuralNetworks.derivative here in order to obtain the symbolic gradient (as opposed to the symbolic output of the neural network).

In order to train a SymbolicNeuralNetwork we can use:

# Symbolic pullback of `nn`; its `loss` and the pullback itself are passed to the optimizer below.
pb = SymbolicPullback(nn)
Info

SymbolicNeuralNetworks.Gradient and SymbolicPullback both use the function SymbolicNeuralNetworks.symbolic_pullback internally and are therefore computationally equivalent. However, SymbolicPullback should be used in connection with a NetworkLoss, whereas SymbolicNeuralNetworks.Gradient can be used more generally to compute the derivative of array-valued expressions.

We want to use our neural network to approximate a Gaussian on the square $[-1, 1]\times[-1, 1]$. We first generate the data for this task:

using GeometricMachineLearning

# Grid of 21 × 21 points on [-1, 1] × [-1, 1] with spacing 0.1.
x_vec = -1.:.1:1.
y_vec = -1.:.1:1.
# Grid points as the columns of a 2 × N matrix (x varies fastest). `reduce(hcat, …)`
# avoids splatting all N points into `hcat` as varargs.
xy_data = reduce(hcat, [[x, y] for y in y_vec for x in x_vec])
# Gaussian target exp(-‖x‖²); accepts any AbstractVector (e.g. column views), not only `Vector`.
f(x::AbstractVector) = exp(-sum(abs2, x))
# Targets as a 1 × N matrix, built in one pass instead of repeatedly re-concatenating with `hcat`.
z_data = reshape(map(f, eachcol(xy_data)), 1, :)

# Wrap the 2 × N inputs and 1 × N targets in a DataLoader; it is consumed in batches during training below.
dl = DataLoader(xy_data, z_data)
[ Info: You have provided an input and an output.

Note that we use GeometricMachineLearning.DataLoader to process the data. We also visualize the data:

using CairoMakie

fig = Figure()
ax = Axis3(fig[1, 1])
# Plot the target function f on the same grid that was used to generate the training data.
surface!(x_vec, y_vec, [f([x, y]) for x in x_vec, y in y_vec]; alpha = .8, transparency = true)
fig
Example block output

We now train the network:

# Concrete network with the same architecture `c` on the `CPU()` backend; this is what gets trained.
nn_cpu = NeuralNetwork(c, CPU())
o = Optimizer(AdamOptimizer(), nn_cpu)
n_epochs = 1000
batch = Batch(10)
# Train using the symbolic pullback `pb`. `@time` reports wall time and allocations;
# a first call in a fresh session also includes compilation time.
@time o(nn_cpu, dl, batch, n_epochs, pb.loss, pb; show_progress = false);
  8.910982 seconds (36.09 M allocations: 1.389 GiB, 1.92% gc time)

We now compare the surface approximated by the neural network to the original one:

fig = Figure()
ax = Axis3(fig[1, 1])

# Evaluate the chain with the trained parameters of `nn_cpu` on the grid; `[1]` extracts the
# single output entry (output_dim = 1).
surface!(x_vec, y_vec, [c([x, y], params(nn_cpu))[1] for x in x_vec, y in y_vec]; alpha = .8, colormap = :darkterrain, transparency = true)
fig
Example block output

We can also compare the time it takes to train the SymbolicNeuralNetwork to the time it takes to train a standard neural network:

# Reference run: the same training call, but with a Zygote-based pullback of a FeedForwardLoss.
loss = FeedForwardLoss()
pb2 = GeometricMachineLearning.ZygotePullback(loss)
# NOTE(review): this continues training the already-trained `nn_cpu` with the same optimizer `o`
# (and its accumulated state); create a fresh network/optimizer if identical starting points matter.
@time o(nn_cpu, dl, batch, n_epochs, pb2.loss, pb2; show_progress = false);
  1.582680 seconds (25.01 M allocations: 1.213 GiB, 6.88% gc time)
Info

For the case presented here, the symbolic neural network is not faster than the standard neural network. This is different, however, for other cases, especially Hamiltonian neural networks.