Deep Learning: LSTM & GRU Explained
Long Short Term Memory Cells (LSTMs), Gated Recurrent Units (GRUs)
Mitesh M. Khapra
CS7015 (Deep Learning): Lecture 15
Module 15.1: Selective Read, Selective Write, Selective Forget - The Whiteboard Analogy
(Figure: an unrolled RNN with inputs x1, x2, ..., xt, states s1, s2, ..., st and outputs y1, y2, ..., yt, connected through the parameters U, W and V.)

The state (si) of an RNN records information from all previous time steps.
At each new timestep the old information gets morphed by the current input.
One could imagine that after t steps the information stored at time step t − k (for some k < t) gets morphed so much that it would be impossible to extract the original information stored at time step t − k.
A similar problem occurs when the information flows backwards during backpropagation.
It is very hard to assign the responsibility of the error caused at time step t to the events that occurred at time step t − k.
This responsibility is of course in the form of gradients, and we studied this problem of backward gradient flow earlier.
We saw a formal argument for this while discussing vanishing gradients.
Let us see an analogy for this.
We can think of the state as a fixed size memory.
Compare this to a fixed size whiteboard that you use to record information.
At each time step (periodic intervals) we keep writing something to the board.
Effectively, at each time step we morph the information recorded till that time point.
After many timesteps it would be impossible to see how the information at time step t − k contributed to the state at timestep t.
Continuing our whiteboard analogy, suppose we are interested in deriving an expression on the whiteboard.
We follow this strategy at each time step:
Selectively write on the board.
Selectively read the already written content.
Selectively forget (erase) some content.
Let us look at each of these in detail.
Selective write
Given a = 1, b = 3, c = 5, d = 11, compute ac(bd + a) + ad.
Say the "board" can hold only 3 statements at a time.
Steps in the derivation:
1. ac
2. bd
3. bd + a
4. ac(bd + a)
5. ad
6. ac(bd + a) + ad
Board so far: ac = 5, bd = 33
There may be many steps in the derivation but we may just skip a few.
In other words, we select what to write.
Selective read
While writing one step we typically read some of the previous steps we have already written and then decide what to write next.
For example, at step 3, information from step 2 is important.
In other words, we select what to read.
Board so far: ac = 5, bd = 33, bd + a = 34
Selective forget
Once the board is full, we need to delete some obsolete information.
But how do we decide what to delete? We typically delete the least useful information.
In other words, we select what to forget.
Board so far: ac = 5, bd = 33, bd + a = 34
There are various other scenarios where we can motivate the need for selective write, read and forget.
For example, you could think of our brain as something which can store only a finite number of facts.
At different time steps we selectively read, write and forget some of these facts.
Since the RNN also has a finite state size, we need to figure out a way to allow it to selectively read, write and forget.
Board at the end: ad + ac(bd + a) = 181, ac(bd + a) = 170, ad = 11
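To make the analogy concrete, here is a small illustrative Python sketch (my own, not from the lecture) that evaluates ac(bd + a) + ad while keeping at most 3 statements on the board. The write helper and the "forget what will not be read again" rule are assumptions chosen to mirror the board states shown above.

```python
# Illustrative sketch of the whiteboard analogy (not from the lecture slides).
# Evaluate ac(bd + a) + ad with a "board" that can hold only 3 statements,
# selectively writing, reading and forgetting intermediate results.

a, b, c, d = 1, 3, 5, 11
board = {}  # statement -> value, capacity 3


def write(name, value, forget=()):
    """Selective write: before writing, selectively forget (erase)
    statements that will not be read again, so the board never
    holds more than 3 statements."""
    for obsolete in forget:
        board.pop(obsolete)
    assert len(board) < 3, "board is full and nothing was forgotten"
    board[name] = value


write("ac", a * c)                                      # step 1
write("bd", b * d)                                      # step 2
write("bd+a", board["bd"] + a)                          # step 3: selectively read bd
write("ac(bd+a)", board["ac"] * board["bd+a"],          # step 4: read ac and bd+a,
      forget=["bd"])                                    #         forget bd
write("ad", a * d, forget=["ac", "bd+a"])               # step 5: forget ac and bd+a
write("ac(bd+a)+ad", board["ac(bd+a)"] + board["ad"])   # step 6
print(board)  # {'ac(bd+a)': 170, 'ad': 11, 'ac(bd+a)+ad': 181}
```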
Module 15.2: Long Short Term Memory (LSTM) and Gated Recurrent Units (GRUs)
Questions
Can we give a concrete example where RNNs also need to selectively read, write and forget?
How do we convert this intuition into mathematical equations? We will see this over the next few slides.
(Figure: an RNN reading the review word by word and producing a +/− sentiment prediction at the end.)

Consider the task of predicting the sentiment (positive/negative) of a review.
The RNN reads the document from left to right and after every word updates the state.
By the time we reach the end of the document, the information obtained from the first few words is completely lost.

Review: "The first half of the movie was dry but the second half really picked up pace. The lead actor delivered an amazing performance."

Ideally we want to:
forget the information added by stop words (a, the, etc.)
selectively read the information added by previous sentiment bearing words (awesome, amazing, etc.)
selectively write new information from the current word to the state
We have now seen a concrete example; the remaining question is how to convert this intuition into mathematical equations.
Recall that the vector st is called the state of the RNN.
It has a finite size (st ∈ Rn) and is used to store all the information up to timestep t.
This state is analogous to the whiteboard: sooner or later it will get overloaded and the information from the initial states will get morphed beyond recognition.
Wishlist: selective write, selective read and selective forget, to ensure that this finite sized state vector is used effectively.
(Figure: the previous state st−1 and the current input xt are combined, through selective write, selective read and selective forget, to produce the new state st.)

Just to be clear, we have computed a state st−1 at timestep t − 1 and now we want to overload it with new information (xt) and compute a new state (st).
While doing so we want to make sure that we use selective write, selective read and selective forget so that only important information is retained in st.
We will now see how to implement these items from our wishlist.
Selective Write
Recall that in RNNs we use st−1 to compute st:

st = σ(W st−1 + U xt)    (ignoring bias)

But now, instead of passing st−1 as it is to st, we want to pass (write) only some portions of it to the next state.
In the strictest case our decisions could be binary (for example, retain the 1st and 3rd entries and delete the rest of the entries).
But a more sensible way of doing this would be to assign a value between 0 and 1 which determines what fraction of the current state to pass on to the next state.
We introduce a vector ot−1 which decides what fraction of each element of st−1 should be passed to the next state.
Each element of ot−1 (a value between 0 and 1) gets multiplied with the corresponding element of st−1.
But how does the RNN know ot−1? Well, the RNN has to learn ot−1 along with the other parameters (W, U, V). It does so through the parameters Wo, Uo and bo:

ot−1 = σ(Wo ht−2 + Uo xt−1 + bo)
ht−1 = ot−1 ⊙ σ(st−1)    (selective write)

Here ot−1 is called the output gate. Using the selectively written ht−1 and the current input xt, we then compute a candidate new state:

s̃t = σ(W ht−1 + U xt + b)
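As a rough illustration (my own sketch, not code from the lecture), the selective-write step translates into a few lines of NumPy; the dimensions and randomly initialised parameters below are placeholders.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

d, n = 10, 20                               # input and state sizes (placeholders)
rng = np.random.default_rng(0)
Wo = rng.standard_normal((n, n))
Uo = rng.standard_normal((n, d))
bo = np.zeros(n)
W = rng.standard_normal((n, n))
U = rng.standard_normal((n, d))
b = np.zeros(n)

s_prev = rng.standard_normal(n)             # s_{t-1}
h_prev2 = rng.standard_normal(n)            # h_{t-2}
x_prev = rng.standard_normal(d)             # x_{t-1}
x_t = rng.standard_normal(d)                # x_t

# Selective write: the output gate o_{t-1} (entries in (0, 1)) decides what
# fraction of each element of s_{t-1} is passed on.
o_prev = sigmoid(Wo @ h_prev2 + Uo @ x_prev + bo)   # o_{t-1}
h_prev = o_prev * sigmoid(s_prev)                   # h_{t-1} = o_{t-1} * sigma(s_{t-1})

# Candidate new state computed from the selectively written h_{t-1} and x_t
s_tilde = sigmoid(W @ h_prev + U @ x_t + b)         # s~_t
```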
Selective Read
s̃t thus captures all the information from the previous state (ht−1) and the current input xt.
However, we may not want to use all this new information; we want to selectively read from it before constructing the new state. To do this we introduce another gate, the input gate:

it = σ(Wi ht−1 + Ui xt + bi)

and we selectively read it ⊙ s̃t.
So far we have the following:

Previous state: st−1
Selective write: ht−1 = ot−1 ⊙ σ(st−1)
Candidate state: s̃t = σ(W ht−1 + U xt + b)
Selective read: it ⊙ s̃t

How do we now combine st−1 and it ⊙ s̃t to compute the new state st?
Selective Forget
We may not want to carry forward all of st−1 either; we may want to forget some parts of it. To do this we introduce the forget gate:

ft = σ(Wf ht−1 + Uf xt + bf)

and the new state and output are

st = ft ⊙ st−1 + it ⊙ s̃t
ht = ot ⊙ σ(st)

Putting it all together, the full set of LSTM equations is:

Gates:
ot = σ(Wo ht−1 + Uo xt + bo)
it = σ(Wi ht−1 + Ui xt + bi)
ft = σ(Wf ht−1 + Uf xt + bf)

States:
s̃t = σ(W ht−1 + U xt + b)
st = ft ⊙ st−1 + it ⊙ s̃t
ht = ot ⊙ σ(st)
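The equations above translate directly into code. Below is a minimal NumPy sketch of a single LSTM step (my own illustration, with placeholder dimensions and randomly initialised parameters), following the lecture's notation.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, s_prev, h_prev, params):
    """One LSTM step following the lecture's equations:
    gates o, i, f and states s~_t, s_t, h_t."""
    Wo, Uo, bo, Wi, Ui, bi, Wf, Uf, bf, W, U, b = params
    o_t = sigmoid(Wo @ h_prev + Uo @ x_t + bo)      # output gate
    i_t = sigmoid(Wi @ h_prev + Ui @ x_t + bi)      # input gate
    f_t = sigmoid(Wf @ h_prev + Uf @ x_t + bf)      # forget gate
    s_tilde = sigmoid(W @ h_prev + U @ x_t + b)     # candidate state s~_t
    s_t = f_t * s_prev + i_t * s_tilde              # selectively forget + read
    h_t = o_t * sigmoid(s_t)                        # selectively write
    return s_t, h_t

def make_params(n, d, rng):
    """Randomly initialised placeholders: (Wo,Uo,bo, Wi,Ui,bi, Wf,Uf,bf, W,U,b)."""
    ps = []
    for _ in range(4):                              # one triple per gate / candidate
        ps += [0.1 * rng.standard_normal((n, n)),
               0.1 * rng.standard_normal((n, d)),
               np.zeros(n)]
    return ps

# Toy usage with placeholder sizes
d, n = 8, 16
rng = np.random.default_rng(0)
params = make_params(n, d, rng)
s, h = np.zeros(n), np.zeros(n)
for x in rng.standard_normal((5, d)):               # run 5 toy timesteps
    s, h = lstm_step(x, s, h, params)
print(s.shape, h.shape)                             # (16,) (16,)
```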
Module 15.3: How LSTMs avoid the problem of vanishing gradients
Intuition
During forward propagation the gates control the flow of information. They prevent any irrelevant information from being written to the state.
Similarly, during backward propagation they control the flow of gradients. It is easy to see that during the backward pass the gradients will get multiplied by the gate.
If the state at time t − 1 did not contribute much to the state at time t (i.e., if ‖ft‖ → 0 and ‖ot−1‖ → 0), then during backpropagation the gradients flowing into st−1 will vanish.
But this kind of a vanishing gradient is fine (since st−1 did not contribute to st, we don't want to hold it responsible for the crimes of st).
The key difference from vanilla RNNs is that the flow of information and gradients is controlled by the gates, which ensure that the gradients vanish only when they should (i.e., when st−1 didn't contribute much to st).
We will now see an illustrative proof of how the gates control the flow of gradients
(Figure: an unrolled RNN with losses L1(θ), ..., L4(θ) at each timestep; the gradient of L4(θ) with respect to W has to flow back through s4, s3, s2, s1.)

Recall that RNNs had this multiplicative term which caused the gradients to vanish:

$$\frac{\partial L_t(\theta)}{\partial W} = \frac{\partial L_t(\theta)}{\partial s_t} \sum_{k=1}^{t} \left( \prod_{j=k}^{t-1} \frac{\partial s_{j+1}}{\partial s_j} \right) \frac{\partial^{+} s_k}{\partial W}$$

In particular, if the loss L4(θ) was high because W was not good enough to compute s1 correctly, then this information will not be propagated back to W, as the gradient of L4(θ) with respect to W along this long path will vanish.
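To see the multiplicative effect numerically, here is a small illustrative NumPy sketch (my own, with made-up dimensions and initialisation) that multiplies together the Jacobians ∂s_{j+1}/∂s_j of a vanilla sigmoid RNN, which have the form D(σ'(a)) W, and watches the norm of the product shrink as the path gets longer.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

n, d, T = 20, 10, 50
rng = np.random.default_rng(0)
W = rng.standard_normal((n, n)) / np.sqrt(n)   # recurrent weights (placeholder scale)
U = rng.standard_normal((n, d)) / np.sqrt(d)

s = np.zeros(n)
prod = np.eye(n)                               # running product of Jacobians
for t in range(T):
    a = W @ s + U @ rng.standard_normal(d)     # pre-activation at the next step
    s = sigmoid(a)
    J = np.diag(s * (1.0 - s)) @ W             # ds_{t+1}/ds_t = D(sigma'(a)) W
    prod = J @ prod
    if (t + 1) % 10 == 0:
        print(t + 1, np.linalg.norm(prod))     # the norm shrinks as the path grows
```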
We will start with the dependency graph involving the different variables in an LSTM, starting with the states at timestep k − 1 (sk−1, hk−1).
For simplicity we will omit the parameters (Wo, Uo, bo, etc.) for now and return to them later.

ok = σ(Wo hk−1 + Uo xk + bo)
ik = σ(Wi hk−1 + Ui xk + bi)
fk = σ(Wf hk−1 + Uf xk + bf)
s̃k = σ(W hk−1 + U xk + b)
sk = fk ⊙ sk−1 + ik ⊙ s̃k
hk = ok ⊙ σ(sk)

(Figure: the dependency graph: hk−1 and xk feed into the gates ok, ik, fk and the candidate s̃k; these together with sk−1 produce sk, which with ok produces hk.)
(Figure: the same dependency graph unrolled from timestep k up to timestep t, where ht feeds into the loss Lt(θ).)
It is sufficient to show that ∂Lt(θ)/∂sk does not vanish (because if this does not vanish we can reach Wf through sk).
First, we observe that there are multiple paths from Lt(θ) to sk (you just need to reverse the direction of the arrows for backpropagation).
For example, there is one path through sk+1 and another through hk.
Further, there are multiple paths to reach hk itself (as should be obvious from the number of outgoing arrows from hk).
So at this point just convince yourself that there are many paths from Lt(θ) to sk.
Consider one such path, the one that goes from Lt(θ) through ht and then along the cell states st, st−1, ..., sk; it will contribute to the gradient. Let us denote the gradient along this path as t0:

$$t_0 = \frac{\partial L_t(\theta)}{\partial h_t}\,\frac{\partial h_t}{\partial s_t}\,\frac{\partial s_t}{\partial s_{t-1}}\cdots\frac{\partial s_{k+1}}{\partial s_k}$$
Let us first look at ∂ht/∂st. Recall that

ht = ot ⊙ σ(st)

Note that the i-th element hti depends only on oti and sti and not on any other elements of ot and st.
∂ht/∂st will thus be a square diagonal matrix in Rd×d whose diagonal is ot ⊙ σ′(st) ∈ Rd (see slide 35 of Lecture 14).
We will represent this diagonal matrix by D(ot ⊙ σ′(st)).
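As a quick sanity check (my own sketch, not part of the lecture), one can verify numerically that the Jacobian of ht = ot ⊙ σ(st) with respect to st, holding ot fixed, is the diagonal matrix D(ot ⊙ σ′(st)), for example with finite differences:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

d = 5
rng = np.random.default_rng(0)
o_t = sigmoid(rng.standard_normal(d))          # output gate (held fixed)
s_t = rng.standard_normal(d)

def h(s):
    return o_t * sigmoid(s)                    # h_t = o_t * sigma(s_t)

# Finite-difference Jacobian dh_t/ds_t
eps = 1e-6
J = np.zeros((d, d))
for j in range(d):
    e = np.zeros(d)
    e[j] = eps
    J[:, j] = (h(s_t + e) - h(s_t)) / eps

# Analytical claim: J = D(o_t * sigma'(s_t)), a diagonal matrix
D = np.diag(o_t * sigmoid(s_t) * (1.0 - sigmoid(s_t)))
print(np.allclose(J, D, atol=1e-4))            # True
```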
Now let us consider ∂st/∂st−1. Recall that

st = ft ⊙ st−1 + it ⊙ s̃t

Notice that s̃t also depends on st−1 (through ht−1), so we cannot treat it as a constant.
So once again we are dealing with an ordered network, and thus ∂st/∂st−1 will be a sum of an explicit term and an implicit term (see slide 37 from Lecture 14).
For simplicity, let us assume that the gradient from the implicit term vanishes (we are assuming a worst case scenario).
And the gradient from the explicit term (treating s̃t as a constant) is given by D(ft).
We now return to our full expression for t0:

$$\begin{aligned}
t_0 &= \frac{\partial L_t(\theta)}{\partial h_t}\,\frac{\partial h_t}{\partial s_t}\,\frac{\partial s_t}{\partial s_{t-1}}\cdots\frac{\partial s_{k+1}}{\partial s_k} \\
    &= L_t'(h_t)\, D(o_t \odot \sigma'(s_t))\, D(f_t) \cdots D(f_{k+1}) \\
    &= L_t'(h_t)\, D(o_t \odot \sigma'(s_t))\, D(f_t \odot \cdots \odot f_{k+1}) \\
    &= L_t'(h_t)\, D(o_t \odot \sigma'(s_t))\, D\!\left(\odot_{i=k+1}^{t} f_i\right)
\end{aligned}$$

The terms L′t(ht) and D(ot ⊙ σ′(st)) don't vanish, while the remaining factor contains a multiplication of the forget gates.
The forget gates thus regulate the gradient flow, depending on the explicit contribution of a state st to the next state st+1.
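A small illustrative NumPy sketch (my own, with arbitrary numbers) of this product of D(fi) factors: when the forget gates stay close to 1 the gradient along this path is preserved, and it shrinks only when the gates themselves are close to 0.

```python
import numpy as np

n, steps = 8, 30
rng = np.random.default_rng(0)
grad = rng.standard_normal(n)            # stand-in for L_t'(h_t) D(o_t * sigma'(s_t))

for gate_level in [0.95, 0.5, 0.05]:     # forget gates close to 1 ... close to 0
    g = grad.copy()
    for _ in range(steps):
        f_i = np.full(n, gate_level)     # a constant forget gate, for illustration
        g = f_i * g                      # multiplying by D(f_i) is elementwise
    print(gate_level, np.linalg.norm(g))
# gates near 1 preserve the gradient; gates near 0 make it vanish (which is "fair")
```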
If during the forward pass st did not contribute much to st+1 (because ft → 0), then during backpropagation the gradient will also not reach st.
This is fine, because if st did not contribute much to st+1 then there is no reason to hold it responsible during backpropagation (ft does the same regulation during the forward pass and the backward pass, which is fair).
Thus there exists this one path along which the gradient doesn't vanish when it shouldn't.
And, as argued, as long as the gradient flows back to Wf through one of the paths (t0) through sk, we are fine!
Of course the gradient flows back only when required, as regulated by the fi's (but let me just say it one last time: this is fair).
Now we will see why LSTMs do not solve the problem of exploding gradients. We will show a path through which the gradient can explode.
Let us compute one term (say t1) of ∂Lt(θ)/∂hk−1 corresponding to the path that alternates between the hidden states and the output gates (ht → ot → ht−1 → ... → hk → ok → hk−1):

$$\begin{aligned}
t_1 &= \frac{\partial L_t(\theta)}{\partial h_t}\,\frac{\partial h_t}{\partial o_t}\,\frac{\partial o_t}{\partial h_{t-1}}\cdots\frac{\partial h_k}{\partial o_k}\,\frac{\partial o_k}{\partial h_{k-1}} \\
    &= L_t'(h_t)\,\bigl(D(\sigma(s_t) \odot o_t')\,W_o\bigr)\cdots\bigl(D(\sigma(s_k) \odot o_k')\,W_o\bigr)
\end{aligned}$$

$$\|t_1\| \le \|L_t'(h_t)\|\,\bigl(\|K\|\,\|W_o\|\bigr)^{t-k+1}$$

(where ‖K‖ bounds the norms of the diagonal matrices in the product).
Depending on the norm of the matrix Wo, the gradient ∂Lt(θ)/∂hk−1 may explode.
Similarly, Wi, Wf and W can also cause the gradients to explode.
So how do we deal with the problem of exploding gradients?
One popular trick is to use gradient clipping.
While backpropagating, if the norm of the gradient exceeds a certain value, it is scaled down to keep it within an acceptable threshold.*
Essentially we retain the direction of the gradient but scale down its norm.

* Pascanu, Razvan, Tomas Mikolov, and Yoshua Bengio. "On the difficulty of training recurrent neural networks." ICML (3) 28 (2013): 1310-1318.
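A minimal NumPy sketch of gradient clipping by norm (my own illustration of the trick described above; the threshold value is arbitrary):

```python
import numpy as np

def clip_by_norm(grad, threshold):
    """Rescale grad so its norm is at most `threshold`,
    keeping its direction unchanged."""
    norm = np.linalg.norm(grad)
    if norm > threshold:
        grad = grad * (threshold / norm)
    return grad

rng = np.random.default_rng(0)
g = 100.0 * rng.standard_normal(10)                    # an "exploded" gradient
g_clipped = clip_by_norm(g, threshold=5.0)
print(np.linalg.norm(g), np.linalg.norm(g_clipped))    # large value, then 5.0
```

Deep learning libraries ship the same operation as a built-in utility, but the idea is exactly this rescaling: keep the direction, cap the norm.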