Long Short-Term Memory (LSTM): Improving RNNs

In this article, we will introduce Long Short-Term Memory networks (LSTMs), a variant of the regular vanilla Recurrent Neural Network (RNN) that is better at dealing with long-term dependencies.
They use a set of "gates" that learn which information to remember and which to forget, depending on how useful it is for prediction.
LSTMs are the most refined and widely used version of RNNs. They have been used extensively within the industry and helped pave the way for the large language models (LLMs) we see today.
We will start by recapping RNNs, then explain the vanishing and exploding gradient problem, and finally dive into how LSTMs work and why they are naturally a better model.
RNN Overview
Recurrent Neural Networks are a spin on the regular feedforward neural network, allowing them to handle sequence data like natural language and time series much better.
They do this through a recurrent hidden state that carries information from previous time steps forward to the next one. Below is an example.

Notice how we have a vector h propagating through the network. This is the hidden state and the main feature behind RNNs, which is why they work well for sequence data.
Each hidden state combines the previously calculated hidden state with the new input at the corresponding time step. The hidden state is then used, with a sigmoid activation function applied, to compute the final output of that particular time step. Mathematically we have:
h_t = σ(U·X_t + W·h_{t-1})
Y_t = σ(V·h_t)
Where:
- Y: output vectors.
- X: input vectors for the features.
- h: hidden states.
- V: weight matrix for the output.
- U: weight matrix for the input.
- W: weight matrix for the hidden states.
The weight matrices V, U, and W are found through backpropagation through time, which is just a variant of the backpropagation algorithm.
I have omitted the bias terms for simplicity.
Visually, it looks like this:

For example, when predicting Y_1, the RNN uses the input X_1 together with the hidden state carried over from the previous time step, the same state used to produce Y_0. Because this carried-over information influences Y_1 and is passed along again, it also indirectly influences Y_2, demonstrating the recurrent nature.
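To make the recurrence concrete, here is a minimal NumPy sketch of the forward pass described above. The sizes, random weights, and the use of a sigmoid for the hidden state are my own illustrative assumptions rather than values from the article (many implementations use tanh for the hidden state instead).

```python
import numpy as np

# Hypothetical sizes, purely for illustration
n_features, n_hidden, n_outputs, n_steps = 3, 4, 2, 5
rng = np.random.default_rng(0)

U = rng.normal(size=(n_hidden, n_features))   # input -> hidden
W = rng.normal(size=(n_hidden, n_hidden))     # hidden -> hidden
V = rng.normal(size=(n_outputs, n_hidden))    # hidden -> output

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

X = rng.normal(size=(n_steps, n_features))    # one input vector per time step
h = np.zeros(n_hidden)                        # initial hidden state

for t in range(n_steps):
    h = sigmoid(U @ X[t] + W @ h)             # combine new input with previous hidden state
    y = sigmoid(V @ h)                        # output for this time step
    print(t, y)
```

Note how the same U, W, and V are reused at every time step; only the hidden state h changes as the sequence is processed.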
If you want a full intro to RNNs then check out my previous blog.
Recurrent Neural Networks – An Introduction To Sequence Modelling
Vanishing & Exploding Problem
A positive aspect of RNNs is that all time steps share the same weight matrices U, W, and V, whereas in a regular feedforward neural network each layer has its own weight matrix. This makes RNNs more memory efficient.
However, this sharing of weight matrices also leads to one of their significant flaws: the vanishing and exploding gradient problem.
RNNs learn using a variant of the backpropagation algorithm called [backpropagation through time](https://towardsdatascience.com/backpropagation-through-time-how-rnns-learn-e5bc03ad1f0a) (BPTT). This algorithm is similar to regular backpropagation, but with more 'nested' computations because the weight matrices are shared across time steps and must be updated with contributions from every time step.
The general formula for BPTT is:
∂E/∂J = Σ_{t=1}^{T} ∂E_t/∂J
Where J is an arbitrary weight matrix in the RNN (either U, W, or V), E is the total error, and E_t is the error at time step t.
As RNNs tend to be much deeper than regular neural networks (each time step acts as a layer), the gradient picks up one multiplicative term per time step. If those terms are consistently less than one, the gradient vanishes; if they are consistently greater than one, it explodes as it is propagated backwards.
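As a quick numerical illustration (a toy example of my own, not from the article), consider what happens when the backpropagated gradient picks up one multiplicative factor per time step:

```python
# Toy example: one multiplicative gradient factor per time step.
# Factors below one shrink the gradient towards zero (vanishing),
# factors above one blow it up (exploding).
for factor in (0.9, 1.1):
    for steps in (10, 50, 100):
        print(f"factor={factor}, steps={steps}, gradient scale ~ {factor ** steps:.3g}")
```

After 100 steps, a factor of 0.9 leaves essentially no gradient, while a factor of 1.1 produces a huge one, which is exactly the problem described above.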

For the interested reader, the complete mathematical breakdown of why this happens can be found here. It involves fun things like eigenvalues and Jacobian matrices!
If you want a full breakdown of the backpropagation through time algorithm (BPTT) and the exploding and vanishing gradient problem, then check out my previous post.
A great example to showcase the vanishing and exploding gradient problem is shown in Stanford's CS224D course. Imagine we have two sentences:
- "Jane walked into the room. John walked in, too. Jane said hi to ___."
- "Jane walked into the room. John walked in, too. It was late, and everyone was walking home after a long day at work. Jane said hi to ___."
In both cases, the blank space most likely refers to John. An RNN should learn this context to understand that John is the output for both sentences.
However, from experiments, sentence 1 was more likely to be predicted correctly than sentence 2. This is because the gradient vanishes over the longer sequence, so the network cannot use the distant context as effectively when making the prediction.
This is clearly a problem, as RNNs are designed to have "memory" for scenarios like these.
So, what do we do about it?
Long Short-Term Memory networks!
Long Short-Term Memory
Overview
LSTMs were introduced by Hochreiter & Schmidhuber in 1997, and their fundamental idea is that of "gates" inside the recurrent cell. These gates control what the recurrent cell remembers and what it forgets, allowing it to build better long-term memory.
In normal RNNs, the recurrent cells look like:

However, the LSTM cell is a lot more complicated:

I appreciate there is a lot going on here, but let's break it down step by step.
Cell State
One of the key differences is the introduction of the cell state, C. This cell state contains essential information regarding context and historical patterns; it is basically the network's memory. It runs through the cell and is adjusted only through a few linear interactions controlled by the so-called gates.
It's easy to get confused between the cell state and the hidden state, but in general, the cell state is meant to contain the whole memory of the network, while the hidden state captures short-term dependencies and mainly holds recent information. The hidden state is also what the cell outputs for prediction.
A regular RNN hasn't got a cell state, only a hidden state.
Forget Gate
The first step in an LSTM is the forget gate. This gate is responsible for deciding what old information to remove from the previous cell state, C_{t-1}.
f_t = σ(W_f·[h_{t-1}, X_t] + b_f)
Where:
- σ: sigmoid activation function.
- W_f: weight matrix for the forget gate.
- h_{t-1}: the hidden state from the previous time step.
- X_t: the input at time step t.
- b_f: bias for the forget gate.
- f_t: forget gate output, which has values between 0 and 1.
The output, f_t, is then multiplied element-wise by the previous cell state C_{t-1} to decide which elements should be forgotten. Due to the sigmoid, each value lies between 0 and 1, where values close to 0 lead to forgetting and values close to 1 keep the information in memory.
The gate learns this behaviour by finding the right values in the W_f matrix through backpropagation, which allows the network to decide which elements to remember or forget.
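Below is a minimal NumPy sketch of the forget gate. The sizes and random weights are hypothetical, chosen purely to illustrate the equation above; the concatenation of h_{t-1} and X_t mirrors the [h_{t-1}, X_t] term.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical sizes, purely for illustration
n_hidden, n_features = 4, 3
rng = np.random.default_rng(0)

W_f = rng.normal(size=(n_hidden, n_hidden + n_features))  # forget gate weights
b_f = np.zeros(n_hidden)                                  # forget gate bias

h_prev = rng.normal(size=n_hidden)     # previous hidden state h_{t-1}
x_t = rng.normal(size=n_features)      # current input X_t
c_prev = rng.normal(size=n_hidden)     # previous cell state C_{t-1}

f_t = sigmoid(W_f @ np.concatenate([h_prev, x_t]) + b_f)  # values in (0, 1)
c_kept = f_t * c_prev                  # element-wise: near 0 forgets, near 1 keeps
print(f_t, c_kept)
```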
Input Gate & Candidate Cell State
The input gate, i_t, is the next part and determines which new information to add to the cell state at the current time step. The candidate cell state, C̃_t, holds all the potential information we could add to the cell state.
i_t = σ(W_i·[h_{t-1}, X_t] + b_i)
C̃_t = tanh(W_c·[h_{t-1}, X_t] + b_c)
Where:
- σ: sigmoid activation function.
- tanh: the hyperbolic tangent activation function.
- W_i: weight matrix for the input gate.
- W_c: weight matrix for the candidate cell state.
- b_i: bias for the input gate.
- b_c: bias for the candidate cell state.
- C̃_t: candidate cell state, with output values between -1 and 1.
- i_t: input gate output, between 0 and 1.
- h_{t-1}: previous hidden state.
- X_t: current input.
Using tanh allows the cell state to be increased or decreased, as tanh squashes values between -1 and 1. The sigmoid is used for a similar reason as in the forget gate: it controls how much of this new information is added to the memory of the cell state.
Note that the input gate takes in the exact same inputs, X_t and h_{t-1}, as the forget gate.
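Continuing in the same illustrative NumPy style (again with made-up sizes and random weights), a sketch of the input gate and candidate cell state:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical sizes, purely for illustration
n_hidden, n_features = 4, 3
rng = np.random.default_rng(1)

W_i = rng.normal(size=(n_hidden, n_hidden + n_features))  # input gate weights
W_c = rng.normal(size=(n_hidden, n_hidden + n_features))  # candidate cell state weights
b_i = np.zeros(n_hidden)
b_c = np.zeros(n_hidden)

h_prev = rng.normal(size=n_hidden)
x_t = rng.normal(size=n_features)
z = np.concatenate([h_prev, x_t])      # same inputs as the forget gate

i_t = sigmoid(W_i @ z + b_i)           # how much of the candidate to let in, in (0, 1)
c_tilde = np.tanh(W_c @ z + b_c)       # candidate cell state, in (-1, 1)
print(i_t, c_tilde)
```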
Cell State Update
We only want to add the relevant information from the candidate cell state, C̃_t, to the new cell state, C_t. To do this, we multiply the candidate cell state by the input gate, i_t, and add the result to the product of the forget gate, f_t, and the previous cell state, C_{t-1}.
C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t
Where ⊙ denotes element-wise multiplication.
Overall, we have updated the cell state by forgetting irrelevant information and adding new, relevant information.
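Here is a tiny sketch of the update itself, using dummy gate outputs (made-up numbers, not real activations) just to show the element-wise arithmetic:

```python
import numpy as np

# Dummy gate outputs for illustration; in practice they come from the
# forget gate and input gate computations shown earlier.
f_t = np.array([0.9, 0.1, 0.5])       # forget gate output
i_t = np.array([0.2, 0.8, 0.5])       # input gate output
c_prev = np.array([1.0, -2.0, 0.3])   # previous cell state C_{t-1}
c_tilde = np.array([0.5, 0.9, -0.4])  # candidate cell state

c_t = f_t * c_prev + i_t * c_tilde    # keep some old memory, add some new memory
print(c_t)
```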
Output Gate
The last part is deciding what we will output from the cell as the prediction. We first calculate the output gate, o_t, which determines which parts of the cell state we will output; this computation is essentially the same as the hidden state update in a normal RNN.
This output is then multiplied by the tanh of the new cell state to output only the values we want.
o_t = σ(W_o·[h_{t-1}, X_t] + b_o)
h_t = o_t ⊙ tanh(C_t)
Where:
- σ: sigmoid function.
- tanh: hyperbolic tangent activation function.
- W_o: weight matrix for the output gate.
- b_o: bias for the output gate.
- o_t: output gate output, between 0 and 1.
- h_t: new hidden state.
- h_{t-1}: previous hidden state.
- X_t: current input.
- C_t: new cell state.
And that's really all there is to it! A key thing to mention is that all the weight matrices are trained using BPTT to learn which elements to forget and which to remember.
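Putting everything together, here is a minimal sketch of a single LSTM cell stepped over a short sequence, combining the forget, input, and output gates above. The function name, sizes, and random weights are my own illustrative choices; a real implementation would learn the weights with BPTT rather than sampling them randomly.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W_f, W_i, W_c, W_o, b_f, b_i, b_c, b_o):
    """One LSTM time step following the gate equations above."""
    z = np.concatenate([h_prev, x_t])
    f_t = sigmoid(W_f @ z + b_f)          # forget gate
    i_t = sigmoid(W_i @ z + b_i)          # input gate
    c_tilde = np.tanh(W_c @ z + b_c)      # candidate cell state
    c_t = f_t * c_prev + i_t * c_tilde    # new cell state
    o_t = sigmoid(W_o @ z + b_o)          # output gate
    h_t = o_t * np.tanh(c_t)              # new hidden state
    return h_t, c_t

# Hypothetical sizes and random weights, purely for illustration
n_hidden, n_features, n_steps = 4, 3, 5
rng = np.random.default_rng(2)
W = {k: rng.normal(size=(n_hidden, n_hidden + n_features)) for k in "fico"}
b = {k: np.zeros(n_hidden) for k in "fico"}

h, c = np.zeros(n_hidden), np.zeros(n_hidden)
for t in range(n_steps):
    x_t = rng.normal(size=n_features)
    h, c = lstm_step(x_t, h, c, W["f"], W["i"], W["c"], W["o"],
                     b["f"], b["i"], b["c"], b["o"])
print(h, c)
```

Notice that the cell state c is carried alongside the hidden state h at every step, which is exactly the extra memory track that a vanilla RNN lacks.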
Variants
This is just the standard LSTM, and there are other variants out there, the most common being the Gated Recurrent Unit (GRU) and LSTMs with peephole connections.
Going over all of these in this article is slightly out of scope, but I will cover the Gated Recurrent Unit in the next article.
Summary & Further Thoughts
Even though LSTMs may seem scary initially, I hope this article made them less intimidating for you! There are quite a few calculations, but they are all very similar. There are two fundamental components: the forget gate, which decides what to remove from memory, and the input gate, which determines which new information to add to memory. The benefit of LSTMs is that these gates give them better long-term memory.
Another Thing!
I have a free newsletter, Dishing the Data, where I share weekly tips for becoming a better Data Scientist, my general experience in the field, and some thoughts I have had over the past week.
Connect With Me!
- LinkedIn, X (Twitter), or Instagram.
- My YouTube Channel to learn technical data science and Machine Learning concepts!