Linear Symplectic Attention
The attention layer introduced here can be seen as an extension of the SympNet gradient layer to the setting of time series data. Before we introduce the LinearSymplecticAttention layer, we first define a notion of symplecticity for multi-step methods.
This definition is different from [58, 59], but similar to the definition of volume-preservation for product spaces in [4].
A multi-step method $\varphi: \times_T\mathbb{R}^{2n}\to\times_T\mathbb{R}^{2n}$ is called symplectic if it preserves the symplectic product structure, i.e. if the induced map $\hat{\varphi}$ (obtained via the isomorphism $\hat{\cdot}$ defined below) is symplectic.
The symplectic product structure is the following skew-symmetric non-degenerate bilinear form:
\[\hat{\mathbb{J}}([z^{(1)}, \ldots, z^{(T)}], [\tilde{z}^{(1)}, \ldots, \tilde{z}^{(T)}]) := \sum_{i=1}^T (z^{(i)})^T\mathbb{J}_{2n}\tilde{z}^{(i)}.\]
$\hat{\mathbb{J}}$ is defined through the isomorphism between the product space and the space of big vectors
\[\hat{\cdot}: \times_T\mathbb{R}^{d}\stackrel{\approx}{\longrightarrow}\mathbb{R}^{dT},\]
so we induce the symplectic structure on the product space as the pullback along this isomorphism.
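To make this concrete, here is a small plain-Julia sketch (illustration only, not part of the library; all names such as J2n, Jhat, z and w are chosen just for this example). After stacking $z^{(1)}, \ldots, z^{(T)}$ into one big vector, the matrix representation of $\hat{\mathbb{J}}$ is the Kronecker product $\mathbb{I}_T\otimes\mathbb{J}_{2n}$, and the sum-based definition above agrees with evaluating this matrix on the stacked vectors:

using LinearAlgebra

n, T = 2, 3                                     # half the system dimension and the sequence length
In = Matrix{Float64}(I, n, n)
J2n = [zeros(n, n) In; -In zeros(n, n)]         # canonical symplectic matrix J_{2n}
Jhat = kron(Matrix{Float64}(I, T, T), J2n)      # matrix of Ĵ on the stacked vectors

z = [rand(2n) for _ in 1:T]                     # two points in the product space,
w = [rand(2n) for _ in 1:T]                     # stored as vectors of vectors

sum(z[i]' * J2n * w[i] for i in 1:T) ≈ vcat(z...)' * Jhat * vcat(w...)   # true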
In order to construct a symplectic attention mechanism we extend the principle behind the SympNet gradient layer, i.e. we construct scalar functions that only depend on $[q^{(1)}, \ldots, q^{(T)}]$ or $[p^{(1)}, \ldots, p^{(T)}]$. The specific choice we make here is the following:
\[F(q^{(1)}, \ldots, q^{(T)}) = \frac{1}{2}\mathrm{Tr}(QAQ^T),\]
where $Q := [q^{(1)}, \ldots, q^{(T)}]$ is the concatenation of the vectors into a matrix. We therefore have for the gradient:
\[\nabla_QF = \frac{1}{2}Q(A + A^T) =: Q\bar{A},\]
where $\bar{A}\in\mathcal{S}_\mathrm{sym}(T)$ is a symmetric matrix. So the map performs:
\[[q^{(1)}, \ldots, q^{(T)}] \mapsto \left[ \sum_{i=1}^Ta_{1i}q^{(i)}, \ldots, \sum_{i=1}^Ta_{Ti}q^{(i)} \right] \text{ for } a_{ji} = [\bar{A}]_{ji}.\]
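As a quick sanity check (plain Julia, illustration only, not library code; the helper F, the step size ϵ and the matrix names are assumptions of this sketch), the gradient formula can be verified with central finite differences, and the columns of $Q\bar{A}$ are indeed the reweighted sums from the display above:

using LinearAlgebra

n, T = 2, 4
Q = rand(n, T)                                  # columns are q^(1), ..., q^(T)
A = rand(T, T)                                  # not necessarily symmetric
Abar = (A + A') / 2
F(Q) = tr(Q * A * Q') / 2

# central finite differences as a stand-in for the exact gradient
ϵ = 1e-6
gradFD = similar(Q)
for i in 1:n, j in 1:T
    E = zeros(n, T); E[i, j] = 1
    gradFD[i, j] = (F(Q + ϵ * E) - F(Q - ϵ * E)) / (2ϵ)
end

gradFD ≈ Q * Abar                               # true: ∇_Q F = Q Ā
all((Q * Abar)[:, j] ≈ sum(Abar[j, i] * Q[:, i] for i in 1:T) for j in 1:T)   # reweighted sums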
Note that this linear symplectic attention still performs a reweighting of the input vectors, as do standard attention and volume-preserving attention, but the crucial difference is that the coefficients $a_{ji}$ are fixed parameters and not computed as the result of a softmax or a Cayley transform. We hence call this attention mechanism linear symplectic attention, as it performs a linear reweighting of the input vectors; this distinguishes it from the standard attention mechanism, whose coefficients depend nonlinearly on the input.
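The symplecticity claim can also be checked numerically. The following is a minimal plain-Julia sketch (illustration only, not library code; the names S, M and J are assumptions of this sketch): reordering the stacked vector as $[\mathrm{vec}(Q); \mathrm{vec}(P)]$ turns the product structure $\hat{\mathbb{J}}$ into the canonical $\mathbb{J}_{2nT}$, the map behind the layer that updates $Q$ via a function of $P$ (see LinearSymplecticAttentionQ below) reads $(Q, P) \mapsto (Q + P\bar{A}, P)$, and its Jacobian $M$ satisfies $M^T\mathbb{J}_{2nT}M = \mathbb{J}_{2nT}$:

using LinearAlgebra

n, T = 2, 3
A = rand(T, T)
Abar = (A + A') / 2
InT = Matrix{Float64}(I, n * T, n * T)
S = kron(Abar, Matrix{Float64}(I, n, n))        # vec(P * Abar) = (Abar ⊗ I_n) vec(P), as Abar is symmetric

# Jacobian of (Q, P) ↦ (Q + P * Abar, P) in the coordinates [vec(Q); vec(P)]
M = [InT S; zeros(n * T, n * T) InT]
J = [zeros(n * T, n * T) InT; -InT zeros(n * T, n * T)]   # product symplectic structure in these coordinates
M' * J * M ≈ J                                  # true: the layer is symplectic in the product-space sense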
Library Functions
GeometricMachineLearning.LinearSymplecticAttention — Type

LinearSymplecticAttention

Implements the linear symplectic attention layers. Analogous to GradientLayer, it performs mappings that only change the $Q$ or the $P$ part. This layer preserves symplecticity in the product-space sense. For more information see LinearSymplecticAttentionQ and LinearSymplecticAttentionP.
Implementation
The coefficients of a LinearSymplecticAttention layer are stored in a SymmetricMatrix:
using GeometricMachineLearning
using GeometricMachineLearning: params
l = LinearSymplecticAttentionQ(3, 5)          # sys_dim = 3, seq_length = 5
ps = params(NeuralNetwork(Chain(l))).L1       # parameters of the first (and only) layer
typeof(ps.A) <: SymmetricMatrix
# output
true
GeometricMachineLearning.LinearSymplecticAttentionQ — Type

LinearSymplecticAttentionQ(sys_dim, seq_length)

Make an instance of LinearSymplecticAttentionQ for a specific dimension and sequence length.

Performs:
\[\begin{pmatrix} Q \\ P \end{pmatrix} \mapsto \begin{pmatrix} Q + \nabla_PF \\ P \end{pmatrix},\]
where $Q,\, P\in\mathbb{R}^{n\times{}T}$ and $F(P) = \frac{1}{2}\mathrm{Tr}(P A P^T)$.

The parameters of this layer are $\bar{A} = \frac{1}{2}(A + A^T).$
GeometricMachineLearning.LinearSymplecticAttentionP — Type

LinearSymplecticAttentionP(sys_dim, seq_length)

Make an instance of LinearSymplecticAttentionP for a specific dimension and sequence length.

Performs:
\[\begin{pmatrix} Q \\ P \end{pmatrix} \mapsto \begin{pmatrix} Q \\ P + \nabla_QF \end{pmatrix},\]
where $Q,\, P\in\mathbb{R}^{n\times{}T}$ and $F(Q) = \frac{1}{2}\mathrm{Tr}(Q A Q^T)$.

The parameters of this layer are $\bar{A} = \frac{1}{2}(A + A^T).$