Understanding the architecture of an LSTM cell from scratch, with code.

Written by maniksoni653 | Published 2018/06/18
Tech Story Tags: machine-learning | lstm | deep-learning | architecture-of-lstm-cell | lstm-cell



Ordinary neural networks don’t perform well in cases where the sequence of the data is important, for example language translation, sentiment analysis, time series, and more. To overcome this failure, RNNs were invented. RNN stands for “Recurrent Neural Network”. An RNN cell considers not only its present input but also the output of the RNN cells preceding it when producing its present output.

The present state of a simple vanilla RNN can be represented as:

Representation of a simple RNN cell, source: Stanford
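In code, that single-step update is a one-liner. Here is a minimal NumPy sketch (the names rnn_step, Wxh, Whh, and b are illustrative, not taken from the article’s code):

```python
import numpy as np

def rnn_step(x_t, h_prev, Wxh, Whh, b):
    """Vanilla RNN update: h_t = tanh(Wxh @ x_t + Whh @ h_prev + b)."""
    return np.tanh(Wxh @ x_t + Whh @ h_prev + b)
```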

RNNs perform well on sequential data and on tasks where the order of the inputs matters.

But ordinary RNNs suffer from several problems:

Vanishing gradients problem:

Vanishing gradient problem: 1. tanh 2. derivative of tanh

The hyperbolic tangent (tanh) is the most common activation function in RNNs. Its output lies in [-1, 1] and its derivative lies in [0, 1]. During backpropagation, the gradient is calculated by the chain rule, which has the effect of multiplying these small numbers together n times (n being the number of times tanh is applied in the RNN architecture). This squeezes the final gradient towards zero, so subtracting the gradient from the weights barely changes them, and training of the model stalls.

Exploding gradients problem:

Opposite to the vanishing gradient problem: while following the chain rule we also multiply by the weight matrix (the transposed W) at each step, and if its values are larger than 1, multiplying a large number by itself many times produces a very large number, i.e. the gradient explodes.

Exploding and vanishing gradients, source: Stanford CS231n
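A toy illustration of both failure modes (the factors 0.25 and 1.5 and the 50 steps are arbitrary, chosen only for demonstration): repeatedly multiplying by values below 1 drives the gradient towards zero, while values above 1 blow it up.

```python
import numpy as np

steps = 50
vanishing = np.prod(np.full(steps, 0.25))  # tanh derivatives are small, <= 1
exploding = np.prod(np.full(steps, 1.5))   # repeated multiplication by a W with norm > 1
print(vanishing)  # ~7.9e-31, effectively zero
print(exploding)  # ~6.4e+08, blows up
```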

Long-Term Dependencies problem

Long-term dependency problem, each node represents an RNN cell. Source: Google

RNNs are good at handling sequential data, but they run into problems when the relevant context is far away. Example: “I live in France and I know ____.” The answer here must be ‘French’, but if many more words sit between ‘I live in France’ and ‘I know ____’, it becomes difficult for an RNN to predict ‘French’. This is the problem of long-term dependencies, and it brings us to LSTMs.

Long Short-Term Memory Networks

LSTMs are a special kind of RNN capable of handling long-term dependencies. LSTMs also provide a solution to the vanishing/exploding gradient problem, which we’ll discuss later in this article.

A simple LSTM cell looks like this:

RNN vs LSTM cell representation, source: Stanford

At the start, we need to initialize the weight matrices and bias terms, as shown below.
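A minimal NumPy sketch of such an initialization, assuming illustrative sizes hidden_dim and input_dim and one weight matrix plus one bias term per gate (f, i, g, o); the positive forget-gate bias is a common trick discussed again in the backprop section below. The snippet also sets up a running example (h_prev, c_prev, x_t) that the later snippets reuse:

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_dim, input_dim = 64, 32        # illustrative sizes
concat_dim = hidden_dim + input_dim   # the gates act on [h_prev, x_t]

# one weight matrix and one bias vector per gate: forget, input, gate ("candidate"), output
params = {}
for gate in ("f", "i", "g", "o"):
    params["W" + gate] = rng.normal(0.0, 0.1, size=(hidden_dim, concat_dim))
    params["b" + gate] = np.zeros(hidden_dim)
params["bf"] += 1.0                   # positive forget-gate bias: start close to "keep everything"

# running example: previous hidden/cell state and a current input
h_prev = np.zeros(hidden_dim)
c_prev = np.zeros(hidden_dim)
x_t = rng.normal(size=input_dim)
```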

Some information about an LSTM cell

A simple LSTM cell consists of 4 gates: the forget gate, the input gate, the “gate” gate (which produces the candidate values), and the output gate.

3 LSTM cells connected to each other, source: Google

LSTM cell visual representation, source: Google

Handy information about the gates, source: Stanford CS231n

Let’s discuss the gates:

• Forget Gate: Given the output of the previous state, h(t-1), the forget gate decides what should be removed from the h(t-1) state, keeping only the relevant information. It is wrapped in a sigmoid function, which squashes its input into [0, 1]. It is represented as:

Forget gate, source: Google
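As a sketch in NumPy, continuing with the params, h_prev, and x_t defined in the initialization snippet above (sigmoid is defined inline):

```python
def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# f = sigmoid(Wf . [h_prev, x_t] + bf), squashed into [0, 1]
concat = np.concatenate([h_prev, x_t])
f = sigmoid(params["Wf"] @ concat + params["bf"])
```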

We multiply the forget gate with the previous cell state to forget the stuff from the previous state that is no longer needed, as shown below:
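Continuing the sketch:

```python
# element-wise multiplication: entries of c_prev where f is close to 0 are forgotten
c_prev_filtered = f * c_prev
```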

• Input Gate: In the input gate, we decide which new information from the present input to add to our present cell state, scaled by how much we wish to add.

Input gate + gate gate, photo credits: Christopher Olah

In the above picture, the sigmoid layer decides which values will be updated and the tanh layer creates a vector of new candidate values to be added to the present cell state. The code is shown below.
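A sketch of those two layers, reusing the concat, params, and sigmoid from the snippets above:

```python
i = sigmoid(params["Wi"] @ concat + params["bi"])   # input gate: which entries to update, in [0, 1]
g = np.tanh(params["Wg"] @ concat + params["bg"])   # "gate gate": new candidate values, in [-1, 1]
```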

To calculate the present cell state, we add the output of (input_gate * gate_gate) to the output of (forget_gate * previous_cell_state), as shown below.
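In code, continuing the sketch:

```python
# new cell state: drop what the forget gate filters out, add the scaled new candidates
c_t = f * c_prev + i * g
```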

• Output Gate: Finally, we decide what to output from our cell state, which is done by the sigmoid function.

We pass the cell state through tanh to squash its values into (-1, 1) and then multiply it by the output of the sigmoid gate so that we only output what we want to.

Output gate, source: Google
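A sketch of the output step, continuing with the same names:

```python
o = sigmoid(params["Wo"] @ concat + params["bo"])   # output gate, in [0, 1]
h_t = o * np.tanh(c_t)                              # hidden state / output of the cell
```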

An overall view of what we did is shown below.
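As a recap, here is the whole forward pass wrapped into one self-contained function, followed by a tiny usage example that runs the cell over a short random sequence. This is a sketch with illustrative names (lstm_step, params, hidden_dim), not the article’s exact code:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, params):
    """One LSTM forward step: returns the new hidden state h_t and cell state c_t."""
    concat = np.concatenate([h_prev, x_t])
    f = sigmoid(params["Wf"] @ concat + params["bf"])   # forget gate
    i = sigmoid(params["Wi"] @ concat + params["bi"])   # input gate
    g = np.tanh(params["Wg"] @ concat + params["bg"])   # candidate values ("gate gate")
    o = sigmoid(params["Wo"] @ concat + params["bo"])   # output gate
    c_t = f * c_prev + i * g
    h_t = o * np.tanh(c_t)
    return h_t, c_t

# usage: run the cell over a short random sequence
rng = np.random.default_rng(0)
hidden_dim, input_dim = 64, 32
params = {}
for gate in ("f", "i", "g", "o"):
    params["W" + gate] = rng.normal(0.0, 0.1, size=(hidden_dim, hidden_dim + input_dim))
    params["b" + gate] = np.zeros(hidden_dim)
params["bf"] += 1.0

h, c = np.zeros(hidden_dim), np.zeros(hidden_dim)
for x_t in rng.normal(size=(10, input_dim)):
    h, c = lstm_step(x_t, h, c, params)
print(h.shape, c.shape)   # (64,) (64,)
```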

LSTMs respond to the vanishing and exploding gradient problem in the following way: an LSTM has a much cleaner backprop path compared to a vanilla RNN.

Gradient flows smoothly during backprop, source: Stanford CS231n

• First, there is no multiplication with the matrix W during backprop through the cell state; it is an element-wise multiplication with f (the forget gate), which is also cheaper to compute.

• Second, during backprop through each LSTM cell, the gradient is multiplied by a different value of the forget gate at each step, which makes it less prone to vanishing/exploding gradients. If the values of all the forget gates are less than 1, it may still suffer from vanishing gradients, but in practice people tend to initialise the forget-gate bias with some positive number, so at the beginning of training f (the forget gate) is very close to 1, and over time the model can learn these bias terms.

• Still, the model may suffer from the vanishing gradient problem, but the chances are much smaller.

This article was limited to the architecture of an LSTM cell, but you can see the complete code HERE. The code also implements an example of generating a simple sequence from random inputs using LSTMs.

I tried the program using Deep Learning Studio:

Deep Learning Studio comes with built-in Jupyter notebooks and pre-installed deep learning frameworks such as TensorFlow, Caffe, etc. So you just need to click on Notebooks (in the left pane) to open a Jupyter notebook in Deep Learning Studio and you’re ready to go!

A special thanks to Christopher Olah and the Stanford CS231n team.

If you liked the article, do share and clap 😄. For more articles about deep learning, follow me on Medium and LinkedIn.

Thanks for reading.

Happy LSTMs.

— — — — — — — — — — — — — — — — — — — — — — — — — — — — — — —

More learning stuff and References:

Understanding LSTM Networks -- colah's blog (colah.github.io)

