Here are some additional details on how the proportional limit enables relating discrete covariance updates to the continuous neural SDE:
- In the proportional limit, the number of layers $L$ and the width $n$ satisfy $L/n \rightarrow d/w$ for some constant ratio $d/w$ as $n \rightarrow \infty$.
- This constant ratio allows defining a continuous time $t = l/n \in [0, T]$, where $T = d/w$ and $l$ is the layer index.
- As $n \rightarrow \infty$, there are infinitesimally small gaps of size $1/n$ between discrete layers $l$ and $l+1$ in this pseudo-continuous time.
- The attention operation gives covariance updates of size $O(1/n)$.
- This means the discrete covariance updates $V_{l+1} - V_l$ are small, of order $O(1/n)$, for large $n$.
- These small updates can be embedded into the continuous time $t \in [0, T]$ as $n \rightarrow \infty$.
- The evolution of $V_l$ across discrete layers can be approximated by discretizing the SDE $dV_t$ with step size $1/n$.
- This gives updates of the form $V_{l+1} \approx V_l + \text{(SDE drift and diffusion terms)}$ that converge to the continuous SDE (see the sketch after this list).
- So the constant $d/w$ ratio provides a continuous notion of time in which to embed the discrete steps.
- The $O(1/n)$ update size allows approximating the SDE via these discrete updates.
Together this connects the discrete covariance evolution to the solution of the continuous SDE.
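To make the embedding concrete, here is a minimal numerical sketch, not the actual covariance dynamics: it discretizes a generic SDE $dV_t = b(V_t)\,dt + \sigma(V_t)\,dB_t$ with an Euler–Maruyama step of size $1/n$, so that layer $l$ plays the role of time $t = l/n$. The drift `drift`, diffusion `diffusion`, and the use of a scalar $V$ are placeholder assumptions chosen only for illustration.

```python
import numpy as np

def drift(V):
    # Placeholder drift b(V); in the actual model this comes from the attention update.
    return -0.5 * V

def diffusion(V):
    # Placeholder diffusion sigma(V); likewise an illustrative stand-in.
    return 0.3 * np.sqrt(np.abs(V))

def discretized_covariance_path(V0, n, T):
    """Euler--Maruyama discretization of dV_t = b(V_t) dt + sigma(V_t) dB_t
    with step size dt = 1/n, mimicking the V_{l+1} = V_l + O(1/n) layer updates."""
    dt = 1.0 / n
    num_layers = int(T * n)          # L = T * n, so L / n -> T = d / w
    V = np.empty(num_layers + 1)
    V[0] = V0
    for l in range(num_layers):
        dB = np.sqrt(dt) * np.random.randn()   # Brownian increment over one layer
        V[l + 1] = V[l] + drift(V[l]) * dt + diffusion(V[l]) * dB
    return V

# As n grows (at a fixed depth-to-width ratio T = d/w), the discrete path
# converges to the solution of the continuous SDE.
path = discretized_covariance_path(V0=1.0, n=512, T=2.0)
print(path[-1])
```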
In summary, the key is that the temperature scaling leads to covariance updates of size $O(1/n)$, which, together with the constant depth-to-width ratio $d/w$, lets the discrete layer-by-layer evolution be embedded into continuous time and converge to the neural SDE.
The following is the mathematical derivation of the Taylor expansion approximation for the centered softmax attention matrix. We start with the definition of the softmax function.
The softmax function is used in machine learning to convert a vector of arbitrary values to a probability distribution. It is defined as follows:
For a vector $z = (z_1, \ldots, z_K)$,

$$\text{softmax}(z_i) = \frac{e^{z_i}}{\sum_{k=1}^{K} e^{z_k}}.$$

Differentiating the softmax with respect to its inputs gives two cases:

- Case $i = j$:
$$\frac{\partial \,\text{softmax}(z_i)}{\partial z_i} = \text{softmax}(z_i) \cdot \bigl(1 - \text{softmax}(z_i)\bigr)$$
- Case $i \neq j$:
$$\frac{\partial \,\text{softmax}(z_i)}{\partial z_j} = -\,\text{softmax}(z_i) \cdot \text{softmax}(z_j)$$
So, the derivative of the softmax function with respect to its inputs can be compactly represented as

$$\frac{\partial \,\text{softmax}(z_i)}{\partial z_j} = \text{softmax}(z_i) \cdot \bigl(\delta_{ij} - \text{softmax}(z_j)\bigr),$$

where $\delta_{ij}$ is the Kronecker delta.
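As a quick sanity check (not part of the original derivation), the sketch below compares this analytic Jacobian against a finite-difference approximation for a random input vector; `softmax_jacobian` and `finite_difference_jacobian` are illustrative helper names.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax: subtract the max before exponentiating.
    e = np.exp(z - np.max(z))
    return e / e.sum()

def softmax_jacobian(z):
    # J_ij = softmax(z_i) * (delta_ij - softmax(z_j))
    s = softmax(z)
    return np.diag(s) - np.outer(s, s)

def finite_difference_jacobian(z, eps=1e-6):
    # Central differences; column j approximates d softmax / d z_j.
    K = len(z)
    J = np.zeros((K, K))
    for j in range(K):
        dz = np.zeros(K)
        dz[j] = eps
        J[:, j] = (softmax(z + dz) - softmax(z - dz)) / (2 * eps)
    return J

z = np.random.randn(5)
print(np.max(np.abs(softmax_jacobian(z) - finite_difference_jacobian(z))))  # ~1e-10
```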
The second derivative,
In matrix form:
The first term:
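For context on where this expansion is headed, here is a small numerical sketch based on an illustrative assumption about the setup, not the original derivation: expanding the softmax around zero logits, the zeroth-order term is the uniform distribution $1/K$ and the first-order term follows from the Jacobian above evaluated at zero, giving $\tfrac{1}{K}\bigl(z_i - \bar z\bigr)$ with $\bar z = \tfrac{1}{K}\sum_k z_k$. The code compares this first-order approximation to the exact softmax for small logits.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

def first_order_softmax(z):
    # Taylor expansion around z = 0:
    #   softmax(z)_i ~= 1/K + (1/K) * (z_i - mean(z)),
    # using the Jacobian at zero, J_ij = (1/K) * (delta_ij - 1/K).
    K = len(z)
    return 1.0 / K + (z - z.mean()) / K

z = 0.01 * np.random.randn(6)   # small logits, where the expansion is accurate
err = np.max(np.abs(softmax(z) - first_order_softmax(z)))
print(err)                      # error is O(|z|^2), i.e. tiny for small logits
```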