7. Recurrent Neural Networks

🧱 The encoder-decoder RNN is quite useful for understanding the transformer.

Recurrent Neural Networks

Many tasks have input or output spaces that consist of sequences. For instance, translation programs often encode words as sequences of one-hot vectors: index vectors with a one at the position of the integer that maps to a word in a fixed vocabulary. A simple recurrent neural network (RNN) processes a sequence of vectors \(\{x_1, ..., x_T\}\) with a recurrence formula \(h_t = f_\theta(h_{t-1},x_t)\). The function \(f\), which we describe in more detail below, uses the same parameters \(\theta\) at every time step and can therefore process an arbitrary number of vectors (cf. parameter sharing and inductive bias). The hidden vector \(h_t\) can be interpreted as a summary of all previous \(x\) vectors. A common initialization is \(h_0=\vec{0}\). We can define \(f\) according to three criteria. The first criterion is the input and output space: we would model \(f\) differently for one-to-many, many-to-one, or many-to-many mappings between vectors. The second criterion is the order in which we process input vectors and predict output vectors, which depends on the nature of the data. For example, we may want to predict \(y_t\) directly after processing \(x_t\), or predict the complete output sequence \(\{y_1, ..., y_T\}\) only after we have processed the complete input sequence \(\{x_i\}_{i=1}^T\). Another option is to compute not only a summary of the past but also a summary of the future and use both vectors in our predictions. The third criterion is how we want to address the drawbacks of the model formulation for longer sequences; we will come back to this point later in this section. To illustrate how RNNs work, we will look at two examples.
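To make the recurrence and the parameter sharing concrete, here is a minimal sketch in NumPy. It treats \(f_\theta\) as a black box that is applied with the same parameters at every time step; the tanh non-linearity, the dimensions, and the random placeholder parameters are assumptions for illustration only.

```python
import numpy as np

def unroll(f, xs, h0):
    """Apply h_t = f(h_{t-1}, x_t) along a sequence; f carries the shared parameters."""
    h, hs = h0, []
    for x in xs:
        h = f(h, x)          # the same f (same theta) is applied at every time step
        hs.append(h)         # h_t summarizes x_1, ..., x_t
    return hs

# Example: a placeholder f with randomly initialized parameters (illustrative dimensions).
D, H = 4, 3
rng = np.random.default_rng(0)
W_xh, W_hh = rng.normal(size=(H, D)), rng.normal(size=(H, H))
f = lambda h, x: np.tanh(W_xh @ x + W_hh @ h)

hidden = unroll(f, xs=[rng.normal(size=D) for _ in range(5)], h0=np.zeros(H))
```

Because the loop only reuses `f`, the same code handles input sequences of any length.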

Vanilla Recurrent Neural Networks

Vanilla Recurrent Neural Networks use a recurrence defined by

\(\)\begin{align} h_t &= \phi \big ( W\begin{bmatrix} x_{t} \\ h_{t-1} \end{bmatrix} \big ). \label{eq:vanilla-rnn} \end{align}\(\)

Here, we concatenate the current input and the previous hidden state, transform the result linearly, and pass it through a non-linear activation function. This vector notation is equivalent to \(h_t = \phi (W_{xh}x_t + W_{hh}h_{t-1})\): the two matrices \(W_{xh}\) and \(W_{hh}\) are concatenated horizontally into \(W\). If the input vectors \(x_t\) have dimension \(D \times 1\) and the hidden vectors dimension \(H \times 1\), then \(W_{xh} \in \mathbb{R}^{H \times D}\), \(W_{hh} \in \mathbb{R}^{H \times H}\), and the weight matrix \(W\) has dimensions \(H \times (D+H)\). The vanilla RNN thus models the current hidden state \(h_t\) at each time step as a linear function of the previous hidden state \(h_{t-1}\) and the current input \(x_t\), transformed by a non-linearity. In a classification task, e.g. where we want to predict the next letter in a prompt from the previously typed letters, we would apply the softmax function to a linear transformation of the hidden state at each time step, \(o_t = W_{ho} h_t\), in order to predict the next character’s one-hot encoding. This is illustrated by Figure 5.

Figure 5. Vanilla RNN as character-level language model. The left side shows the unrolled RNN. The vocabulary has four characters and the training sequence is “hello”. Each letter is represented by a one-hot encoding (yellow) and the RNN predicts the encoding of the next letter (green) at each time step. The RNN has a hidden state with three dimensions (red). The output has four dimensions: the logits for the next character, i.e. a linear transformation of the hidden state that is passed through a softmax to obtain probabilities. During supervised learning, the model is trained to increase (decrease) the logits of the correct (incorrect) characters. The right side shows the rolled-up RNN. The graph has a cycle, which shows that the same weights are shared across time and that the architecture is the same at each step.
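The following is a minimal NumPy sketch of the forward pass of such a character-level model, assuming a tanh non-linearity for \(\phi\), the four-character vocabulary from Figure 5 with an assumed index order, and random placeholder weights instead of trained values.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

V, H = 4, 3                                # vocabulary size and hidden size, as in Figure 5
rng = np.random.default_rng(0)
W_xh = rng.normal(0, 0.1, size=(H, V))     # input-to-hidden weights
W_hh = rng.normal(0, 0.1, size=(H, H))     # hidden-to-hidden weights
W_ho = rng.normal(0, 0.1, size=(V, H))     # hidden-to-output weights

def step(x, h_prev):
    """One vanilla RNN step: h_t = phi(W_xh x_t + W_hh h_{t-1}), o_t = W_ho h_t."""
    h = np.tanh(W_xh @ x + W_hh @ h_prev)
    o = W_ho @ h                           # logits for the next character
    return h, softmax(o)

h = np.zeros(H)                            # h_0 = 0
for idx in [0, 1, 2, 2]:                   # "hell" under an assumed index order h, e, l, o
    x = np.eye(V)[idx]                     # one-hot encoding of the current character
    h, probs = step(x, h)                  # probs is the prediction for the next character
```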

Encoder-Decoder RNNs

Encoder-Decoder RNNs use the complete input history \(\{x_i\}_1^{T_x}\) to predict the first output \(y_1\). Then, they additionally use the complete prediction history \(\{\hat{y}_i\}_1^{t-1}\) to predict the next \(y_t\) for \(t=2,...,T_y\). The model is therefore able to generate sequences of arbitrary length, which can differ from the length of the input sequence. An example task is translating sentences from English to German. Here, we work with one-hot encodings of words from a fixed vocabulary instead of letters from an alphabet. To build this RNN, we use the same recursion from the previous example as an encoder RNN. However, we do not classify the output at each time step directly. Instead, we use the last hidden state of the encoder, \(h_{T_x}\), as a context vector \(c_0\). Intuitively, the context vector is an abstract representation of the entire input sentence. Then, we use another RNN, the decoder RNN, to process the information from the context and the prediction from the previous period to generate the hidden state \(s_t\) for the current period:

\(\)\begin{align} s_t = \phi (W_{ys}\hat{y}_{t-1} + W_{ss}s_{t-1}) \end{align}\(\)

with \(s_0 = c_0\) and, for instance, \(\hat{y}_0 = \vec{0}\). We then predict the next word’s one-hot encoding from the hidden state as in the previous example, with the softmax of \(o_t=W_{so}s_t\). To model the fact that input and output sequences can have different lengths, we require special start-of-sentence, \(<\)sos\(>\), and end-of-sentence, \(<\)eos\(>\), tokens.

The encoder receives \(<\)eos\(>\) as its last input vector \(x_{T_x}\). The decoder takes \(<\)sos\(>\) as its first input \(\hat{y}_0\) and stops the recursion when it predicts \(<\)eos\(>\).
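A minimal sketch of generation with this encoder-decoder pair is shown below, assuming NumPy, a tanh non-linearity for \(\phi\), weight matrices named as in the equations above, and integer ids `sos_id` and `eos_id` for the special tokens; the greedy argmax choice of the next word is just one possible decoding strategy.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def encode(xs, W_xh, W_hh):
    """Run the encoder over the one-hot input vectors and return the context c_0 = h_{T_x}."""
    h = np.zeros(W_hh.shape[0])
    for x in xs:
        h = np.tanh(W_xh @ x + W_hh @ h)
    return h

def decode(c0, W_ys, W_ss, W_so, sos_id, eos_id, max_len=20):
    """Greedy decoding: feed the previous prediction back in until <eos> is produced."""
    V = W_so.shape[0]                          # output vocabulary size
    s = c0                                     # s_0 = c_0
    y_prev = np.eye(V)[sos_id]                 # <sos> as the first decoder input
    outputs = []
    for _ in range(max_len):
        s = np.tanh(W_ys @ y_prev + W_ss @ s)  # s_t from previous prediction and state
        probs = softmax(W_so @ s)              # softmax of o_t = W_so s_t
        idx = int(np.argmax(probs))            # greedy choice of the next word
        if idx == eos_id:
            break                              # stop once <eos> is predicted
        outputs.append(idx)
        y_prev = np.eye(V)[idx]                # feed the prediction back as y_hat_t
    return outputs
```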

Figure 6. Encoder-Decoder RNN as word-level language model. The input vocabulary has two words and the output vocabulary has three. Every word, as well as the start and the end of a sentence, is represented by a one-hot encoding (yellow), and the RNN predicts the encodings of the translation of the input sentence. The encoder RNN has a hidden state (green) that is updated word by word to produce the final context vector (purple). The decoder RNN takes linear functions of the final encoding and the start embedding to compute the hidden states for the next output embedding (blue). The one-hot encodings of the output words (red) are a softmax of a linear function of these hidden states. The decoder RNN iterates the prediction until it returns the end token. With this architecture, we can predict sequences of arbitrary length using the embedded information of a whole input sequence.

The simplicity of the RNN’s formulation has two drawbacks. First, the coupling of inputs and hidden states through linear layers and element-wise non-linearities is not flexible enough for some tasks. Second, the recurrent graph structure leads to problematic dynamics during the backward pass.

Exploding and Vanishing Gradient Problem

We now explore the second problem more formally. The derivative of the vanilla RNN’s loss with respect to the weight matrix \(W\) from Equation 1 is given by:

\(\)\begin{align}\frac{\partial L}{\partial W} = \sum_{t=1}^T \sum_{k=1}^{t+1} \frac{\partial L_{t+1}}{\partial o_{t+1}} \frac{\partial o_{t+1}}{\partial h_{t+1}} \frac{\partial h_{t+1}}{\partial h_{k}} \frac{\partial h_{k}}{\partial W}. \label{eq:rnn-derivative}\end{align}\(\)

The crucial part is the derivative of the hidden layer in the next period with respect to some past layer \(k\). It is given by the recursive product

\(\)\begin{align}\label{eq:rnn-derivative-recursion} \frac{\partial h_{t+1}}{\partial h_k} = \prod_{j=k}^t\frac{\partial h_{j+1}}{\partial h_{j}} = \frac{\partial h_{t+1}}{\partial h_{t}} \frac{\partial h_{t}}{\partial h_{t-1}} \cdots \frac{\partial h_{k+1}}{\partial h_{k}} = \prod_{j=k}^t \text{diag} \big(\phi' (W [x_{j+1};h_j])\big) W_{hh}.\end{align}\(\)

Equation 4 shows that, in order to compute the gradients for \(W_{xh}\) and \(W_{hh}\), we have to multiply a large number of Jacobians, depending on the length of the input sequence. The reason is that the derivative of a hidden layer \(h_{t+1}\) with respect to some previous layer \(h_k\) equals the product of the derivatives of the layers from \(k+1\) to \(t+1\) with respect to their respective previous layers. For simplicity, let us assume that these Jacobians are constant. Further, let us compute the eigendecomposition of the fixed Jacobian matrix \(\frac{\partial h_{t+1}}{\partial h_t}\) to analyze the problem more formally. We obtain the eigenvalues \(\lambda_1, \lambda_2,...,\lambda_n\), with \(|\lambda_1|> |\lambda_2|>...>|\lambda_n|\), and the respective eigenvectors \(v_1, v_2, ..., v_n\). With these components, a change \(\Delta h\) of the hidden state in the direction of an eigenvector \(v_i\) becomes \(\lambda_i \Delta h\) after one backward step. During the backward pass from \(t+1\) to \(k\), this change becomes \(\lambda_i^{t+1-k} \Delta h\). As a result, if the largest eigenvalue \(\lambda_1\) is smaller than one, the gradient vanishes, and if it is larger than one, the gradient explodes. In practice, the gradient more often vanishes, because the derivative of the activation function contributes factors that cannot exceed one. In this case, the earlier hidden states contribute almost nothing to the weight updates, so the front parts of the input barely impact the prediction. In other words, our model has a weak long-term memory.
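A small numerical sketch of this argument in NumPy: the Jacobian is held fixed as in the simplification above, and the weight scales and dimensions are illustrative assumptions only.

```python
import numpy as np

rng = np.random.default_rng(0)
H, steps = 3, 50                                   # hidden size and number of backward steps

def gradient_norms(scale):
    """Norms of the growing product of `steps` identical Jacobians diag(phi') W_hh."""
    W_hh = scale * rng.normal(size=(H, H)) / np.sqrt(H)
    phi_prime = rng.uniform(0.1, 1.0, size=H)      # tanh' lies in (0, 1]
    J = np.diag(phi_prime) @ W_hh                  # fixed Jacobian dh_{j+1}/dh_j
    P = np.eye(H)
    norms = []
    for _ in range(steps):
        P = J @ P                                  # accumulate dh_{t+1}/dh_k
        norms.append(np.linalg.norm(P))
    return norms

print(gradient_norms(scale=0.9)[-1])   # small weights: the product typically shrinks towards zero
print(gradient_norms(scale=5.0)[-1])   # large weights: the product typically blows up
```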

There are many approaches to alleviate this problem. Examples are regularization, careful weight initializations, using only ReLU activations, and gated recurrent networks like the Long Short-Term Memory (LSTM) or the Gated Recurrent Unit (GRU). We can view these gating mechanisms as much more sophisticated extensions of the skip connections from Chapter 5. Today, the most common approach is the transformer architecture. It is not only less vulnerable to vanishing gradients but also provides a more expressive coupling of current inputs and previous states than vanilla RNNs.

Citation

In case you like this series, cite it with:

@misc{stenzel2023deeplearning,
  title   = "Deep Learning Series",
  author  = "Stenzel, Tobias",
  year    = "2023",
  url     = "https://www.tobiasstenzel.com/blog/2023/dl-overview/"
}