The Thing about LSTM and Exploding Gradients

Why does LSTM address vanishing gradients but not exploding ones?
Published April 11, 2022

An important challenge in training recurrent neural networks (RNNs) is the vanishing / exploding gradient problem. Here is an (over-simplified) illustration of the problem. Suppose you have an RNN with $T$ time steps, an initial input $x$ (i.e., $h_0 = x$), and a shared weight $w$. Assuming a linear activation function (again, for simplicity), the hidden state at time $t$ is

$$h_t = w\,h_{t-1} = w^t h_0 = w^t x.$$

Therefore, the derivative / gradient with respect to the parameter $w$ is $\frac{d h_t}{d w} = t\,w^{t-1} x$. The longer the sequence, the higher the exponent in $w^{t-1}$. As a result, for long sequences the gradient vanishes even if $w$ is only slightly smaller than 1, and it explodes even if $w$ is only slightly greater than 1. This makes training RNNs unstable.
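
To make the arithmetic concrete, here is a small numerical sketch of the toy gradient $t\,w^{t-1} x$ (not from the original derivation; the weight values and sequence lengths are arbitrary):

```python
# Toy illustration of d h_t / d w = t * w**(t-1) * x for the simplified
# linear RNN h_t = w * h_{t-1} with h_0 = x.  Values of w and t are arbitrary.
def toy_gradient(w: float, t: int, x: float = 1.0) -> float:
    """Gradient of h_t = w**t * x with respect to the shared weight w."""
    return t * w ** (t - 1) * x

for w in (0.9, 1.1):
    for t in (10, 100, 1000):
        print(f"w={w}, t={t:4d}: gradient = {toy_gradient(w, t):.3e}")

# w=0.9: the gradient shrinks toward 0 as t grows (vanishing).
# w=1.1: the gradient blows up as t grows (exploding).
```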

At the root of this problem is the repeated multiplication of the same weight across many time steps. The parameter sharing that enables RNNs to handle variable-length sequences is also the culprit behind the vanishing / exploding gradient problem.

The LSTM architecture offers robustness against the vanishing gradient problem. To understand how, let's first lay out the key pieces of an LSTM cell:

$$
\begin{aligned}
\text{forget}_t &= \sigma\!\left(\Theta_{\text{forget}} \cdot [h_{t-1}, x_t]\right) \\
\text{input}_t &= \sigma\!\left(\Theta_{\text{input}} \cdot [h_{t-1}, x_t]\right) \\
\text{output}_t &= \sigma\!\left(\Theta_{\text{output}} \cdot [h_{t-1}, x_t]\right) \\
\tilde{C}_t &= \tanh\!\left(\Theta_{\text{cell}} \cdot [h_{t-1}, x_t]\right) \\
C_t &= \text{forget}_t \odot C_{t-1} + \text{input}_t \odot \tilde{C}_t \\
h_t &= \text{output}_t \odot \tanh(C_t)
\end{aligned}
$$

where the $\Theta_{(\cdot)}$ are all parameters that the network learns from data. The three "gates" can be conceptually thought of as "weights", and the real "magic" of LSTM lies in the internal cell state. Notice that $C_t$ is "auto-regressive", in the sense that it depends on $C_{t-1}$ through the time-varying forget gate weights. Keeping the forget gate weights close to 1 allows $C_t$ to "memorize" information from previous states. This is what mitigates the vanishing gradient problem.
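
As a rough sketch of these equations (my own NumPy simplification, with hypothetical parameter shapes and no bias terms), a single forward step makes the additive, gated cell-state update explicit:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, C_prev, params):
    """One LSTM step; `params` holds the Theta matrices, here without bias terms."""
    z = np.concatenate([h_prev, x_t])            # [h_{t-1}, x_t]
    forget_t = sigmoid(params["forget"] @ z)     # forget gate
    input_t  = sigmoid(params["input"] @ z)      # input gate
    output_t = sigmoid(params["output"] @ z)     # output gate
    C_tilde  = np.tanh(params["cell"] @ z)       # candidate cell state
    # Additive update: information in C_{t-1} is carried forward through an
    # element-wise product with forget_t, not another multiplication by Theta.
    C_t = forget_t * C_prev + input_t * C_tilde
    h_t = output_t * np.tanh(C_t)
    return h_t, C_t

# Example usage with arbitrary sizes (hidden=4, input=3).
rng = np.random.default_rng(0)
params = {k: rng.normal(size=(4, 7)) * 0.1 for k in ("forget", "input", "output", "cell")}
h, C = np.zeros(4), np.zeros(4)
for _ in range(5):
    h, C = lstm_step(rng.normal(size=3), h, C, params)
print(h)
```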

However, the LSTM architecture does not address the exploding gradient problem. This is because the self-multiplication problem still exists through other variables, such as $\text{output}_t$. If we remove the internal cell state for a moment, the output $h_t = \text{output}_t$ would be exactly the same as what you get in a regular RNN architecture, where repeated multiplication by $\Theta_{\text{output}}$ is again a problem.
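
To illustrate, here is a minimal sketch (my own, keeping the earlier linear simplification and using an arbitrary random matrix as a stand-in for $\Theta_{\text{output}}$) of how a Jacobian built from $T$ repeated multiplications of the same matrix still vanishes or explodes depending on its spectral radius:

```python
import numpy as np

# With the cell state removed, the recursion is back to a plain RNN in which
# h_T depends on h_0 through T multiplications by the same shared matrix.
# Under the linear simplification, d h_T / d h_0 = Theta_output ** T.
rng = np.random.default_rng(0)

def random_matrix_with_spectral_radius(n: int, radius: float) -> np.ndarray:
    """Random n x n matrix rescaled so its largest |eigenvalue| equals `radius`."""
    A = rng.normal(size=(n, n))
    return A * (radius / np.max(np.abs(np.linalg.eigvals(A))))

for radius in (0.9, 1.1):
    Theta_output = random_matrix_with_spectral_radius(8, radius)
    for T in (10, 100, 200):
        jac = np.linalg.matrix_power(Theta_output, T)   # d h_T / d h_0
        print(f"spectral radius={radius}, T={T:3d}: ||Jacobian|| = {np.linalg.norm(jac):.3e}")

# spectral radius 0.9: the Jacobian norm decays over long sequences (vanishing);
# spectral radius 1.1: it grows without bound (exploding).
```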

For more technical / mathematical discussions of this issue, I recommend this StackExchange Q&A and this blog post.