
CS7015 (Deep Learning) : Lecture 15

Long Short Term Memory Cells (LSTMs), Gated Recurrent Units (GRUs)

Mitesh M. Khapra

Department of Computer Science and Engineering


Indian Institute of Technology Madras

Module 15.1: Selective Read, Selective Write, Selective Forget - The Whiteboard Analogy

[Figure: an unrolled RNN with inputs x1, ..., xt, states s1, ..., st (parameters U, W, V) and outputs y1, ..., yt]

The state (si) of an RNN records information from all previous time steps.
At each new time step the old information gets morphed by the current input.
One could imagine that after t steps the information stored at time step t − k (for some k < t) gets morphed so much that it would be impossible to extract the original information stored at time step t − k.

[Figure: the same unrolled RNN, now viewed during backpropagation]

A similar problem occurs when the information flows backwards (backpropagation).
It is very hard to assign the responsibility of the error caused at time step t to the events that occurred at time step t − k.
This responsibility is of course in the form of gradients, and we studied the problem in the backward flow of gradients.
We saw a formal argument for this while discussing vanishing gradients.

Let us see an analogy for this
We can think of the state as a fixed
size memory
Compare this to a fixed size white
board that you use to record
information
At each time step (periodic intervals)
we keep writing something to the
board
Effectively at each time step we
morph the information recorded till
that time point
After many timesteps it would be
impossible to see how the information
at time step t − k contributed to the
state at time step t

Continuing our whiteboard analogy,
suppose we are interested in deriving
an expression on the whiteboard
We follow the following strategy at
each time step
Selectively write on the board
Selectively read the already written
content
Selectively forget (erase) some
content
Let us look at each of these in detail

Selective write

a = 1, b = 3, c = 5, d = 11
Compute ac(bd + a) + ad
Say the "board" can have only 3 statements at a time.

Steps in the derivation:
1 ac
2 bd
3 bd + a
4 ac(bd + a)
5 ad
6 ac(bd + a) + ad

Board so far: ac = 5, bd = 33

There may be many steps in the derivation but we may just skip a few.
In other words we select what to write.

Selective read

a = 1, b = 3, c = 5, d = 11
Compute ac(bd + a) + ad
Say the "board" can have only 3 statements at a time.

Steps in the derivation:
1 ac
2 bd
3 bd + a
4 ac(bd + a)
5 ad
6 ac(bd + a) + ad

Board so far: ac = 5, bd = 33, bd + a = 34

While writing one step we typically read some of the previous steps we have already written and then decide what to write next.
For example, at Step 3, information from Step 2 is important.
In other words we select what to read.

Selective forget

a = 1, b = 3, c = 5, d = 11
Compute ac(bd + a) + ad
Say the "board" can have only 3 statements at a time.

Steps in the derivation:
1 ac
2 bd
3 bd + a
4 ac(bd + a)
5 ad
6 ac(bd + a) + ad

Board so far: ac = 5, bd = 33, bd + a = 34

Once the board is full, we need to delete some obsolete information.
But how do we decide what to delete? We will typically delete the least useful information.
In other words we select what to forget.

a = 1, b = 3, c = 5, d = 11
Compute ac(bd + a) + ad
Say the "board" can have only 3 statements at a time.

Steps in the derivation:
1 ac
2 bd
3 bd + a
4 ac(bd + a)
5 ad
6 ac(bd + a) + ad

Board so far: ad + ac(bd + a) = 181, ac(bd + a) = 170, ad = 11

There are various other scenarios where we can motivate the need for selective write, read and forget.
For example, you could think of our brain as something which can store only a finite number of facts.
At different time steps we selectively read, write and forget some of these facts.
Since the RNN also has a finite state size, we need to figure out a way to allow it to selectively read, write and forget.

Module 15.2: Long Short Term Memory (LSTM) and Gated Recurrent Units (GRUs)

Questions
Can we give a concrete example where RNNs also need to selectively read, write and forget?
How do we convert this intuition into mathematical equations? We will see this over the next few slides.

[Figure: an RNN reading the review word by word ("The first ... performance") and predicting the sentiment +/−]

Consider the task of predicting the sentiment (positive/negative) of a review.
The RNN reads the document from left to right and after every word updates the state.
By the time we reach the end of the document, the information obtained from the first few words is completely lost.

Review: The first half of the movie was dry but the second half really picked up pace. The lead actor delivered an amazing performance.

Ideally we want to
forget the information added by stop words (a, the, etc.)
selectively read the information added by previous sentiment bearing words (awesome, amazing, etc.)
selectively write new information from the current word to the state

Questions
Can we give a concrete example where RNNs also need to selectively read, write and forget?
How do we convert this intuition into mathematical equations?

[Figure: the same sentiment RNN, with the state vector st highlighted in blue]

Recall that the blue colored vector (st) is called the state of the RNN.
It has a finite size (st ∈ R^n) and is used to store all the information up to time step t.
This state is analogous to the whiteboard, and sooner or later it will get overloaded and the information from the initial states will get morphed beyond recognition.

Review: The first half of the movie was dry but the second half really picked up pace. The lead actor delivered an amazing performance.

Wishlist: selective write, selective read and selective forget to ensure that this finite sized state vector is used effectively.

[Figure: the state st−1 being combined with the new input xt to produce st, with selective read, selective write and selective forget marked]

Just to be clear, we have computed a state st−1 at time step t − 1 and now we want to overload it with new information (xt) and compute a new state (st).
While doing so we want to make sure that we use selective write, selective read and selective forget so that only important information is retained in st.
We will now see how to implement these items from our wishlist.

Selective Write

[Figure: st−1 and xt feeding into st through the weights W and U]

Recall that in RNNs we use st−1 to compute st:

st = σ(W st−1 + U xt)    (ignoring the bias)

But now instead of passing st−1 as it is to st, we want to pass (write) only some portions of it to the next state.
In the strictest case our decisions could be binary (for example, retain the 1st and 3rd entries and delete the rest of the entries).
But a more sensible way of doing this would be to assign a value between 0 and 1 which determines what fraction of the current state to pass on to the next state.

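As a concrete reference point, here is a minimal numpy sketch of this vanilla RNN update, st = σ(W st−1 + U xt); the state size, input size and random parameter values are illustrative assumptions, not values from the lecture.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

n, m = 4, 3                                # assumed state and input sizes
rng = np.random.default_rng(0)
W = rng.standard_normal((n, n))            # state-to-state weights
U = rng.standard_normal((n, m))            # input-to-state weights

def rnn_step(s_prev, x):
    # Vanilla RNN: the whole previous state is blended with the input;
    # there is no control over what is written, read or forgotten.
    return sigmoid(W @ s_prev + U @ x)

s = np.zeros(n)
s = rnn_step(s, rng.standard_normal(m))    # one update with a dummy input
```
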
Selective Write

[Figure: ot−1 gating st−1 elementwise to produce ht−1, which is then used to compute st]

We introduce a vector ot−1 which decides what fraction of each element of st−1 should be passed to the next state.
Each element of ot−1 gets multiplied with the corresponding element of st−1.
Each element of ot−1 is restricted to be between 0 and 1.
But how do we compute ot−1? How does the RNN know what fraction of the state to pass on?

Selective Write

[Figure: the same gating diagram, now showing how ot−1 and ht−1 are computed]

Well, the RNN has to learn ot−1 along with the other parameters (W, U, V).
We compute ot−1 and ht−1 as

ot−1 = σ(Wo ht−2 + Uo xt−1 + bo)
ht−1 = ot−1 ⊙ σ(st−1)

The parameters Wo, Uo, bo need to be learned along with the existing parameters W, U, V.
The sigmoid (logistic) function ensures that the values are between 0 and 1.
ot is called the output gate as it decides how much to pass (write) to the next time step.

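The selective-write step can be sketched in a few lines of numpy. This is only an illustration: the sizes are assumed and Wo, Uo, bo are randomly initialised stand-ins for parameters that would actually be learned.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

n, m = 4, 3                                    # assumed state and input sizes
rng = np.random.default_rng(1)
Wo = rng.standard_normal((n, n))
Uo = rng.standard_normal((n, m))
bo = np.zeros(n)

def selective_write(s_prev, h_prev2, x_prev):
    # Output gate: each entry lies in (0, 1) and decides what fraction of
    # the corresponding entry of the state gets written out.
    o_prev = sigmoid(Wo @ h_prev2 + Uo @ x_prev + bo)
    h_prev = o_prev * sigmoid(s_prev)          # elementwise (Hadamard) product
    return o_prev, h_prev
```
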
Selective Read

[Figure: ht−1 and xt being combined through W and U to produce the candidate state s̃t]

We will now use ht−1 to compute the new state at the next time step.
We will also use xt, which is the new input at time step t:

s̃t = σ(W ht−1 + U xt + b)

Note that W, U and b are similar to the parameters that we used in the RNN (for simplicity we have not shown the bias b in the figure).

Selective Read

[Figure: the input gate it gating the candidate state s̃t]

s̃t thus captures all the information from the previous state (ht−1) and the current input xt.
However, we may not want to use all this new information and only selectively read from it before constructing the new cell state st.
To do this we introduce another gate called the input gate:

it = σ(Wi ht−1 + Ui xt + bi)

and use it ⊙ s̃t as the selectively read state information.

[Figure: the computations so far - selective write of st−1 (via ot−1) followed by selective read of s̃t (via it)]

So far we have the following:

Previous state: st−1
Output gate: ot−1 = σ(Wo ht−2 + Uo xt−1 + bo)
Selectively write: ht−1 = ot−1 ⊙ σ(st−1)
Current (temporary) state: s̃t = σ(W ht−1 + U xt + b)
Input gate: it = σ(Wi ht−1 + Ui xt + bi)
Selectively read: it ⊙ s̃t

Selective Forget

[Figure: the forget gate ft gating st−1 before it is combined with it ⊙ s̃t]

How do we combine st−1 and s̃t to get the new state?
Here is one simple (but effective) way of doing this:

st = st−1 + it ⊙ s̃t

But we may not want to use the whole of st−1 but forget some parts of it.
To do this we introduce the forget gate:

ft = σ(Wf ht−1 + Uf xt + bf)
st = ft ⊙ st−1 + it ⊙ s̃t

[Figure: one full LSTM time step - selective write, selective read, selective forget, then selective write of the new state]

We now have the full set of equations for LSTMs.
The green box in the figure, together with the selective write operations following it, shows all the computations which happen at time step t.

Gates:
ot = σ(Wo ht−1 + Uo xt + bo)
it = σ(Wi ht−1 + Ui xt + bi)
ft = σ(Wf ht−1 + Uf xt + bf)

States:
s̃t = σ(W ht−1 + U xt + b)
st = ft ⊙ st−1 + it ⊙ s̃t
ht = ot ⊙ σ(st)    and    rnnout = ht

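To tie the equations together, here is a minimal numpy sketch of one LSTM time step, written exactly as above (following the slides, σ is applied to s̃t and st). The sizes and random parameter values are illustrative assumptions; in practice all of these parameters would be learned.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

n, m = 4, 3                                    # assumed state and input sizes
rng = np.random.default_rng(2)

def make_params():
    # One (W, U, b) triple of randomly initialised stand-in parameters
    return rng.standard_normal((n, n)), rng.standard_normal((n, m)), np.zeros(n)

Wo, Uo, bo = make_params()                     # output gate
Wi, Ui, bi = make_params()                     # input gate
Wf, Uf, bf = make_params()                     # forget gate
W,  U,  b  = make_params()                     # candidate state

def lstm_step(s_prev, h_prev, x):
    o = sigmoid(Wo @ h_prev + Uo @ x + bo)     # output gate (selective write)
    i = sigmoid(Wi @ h_prev + Ui @ x + bi)     # input gate  (selective read)
    f = sigmoid(Wf @ h_prev + Uf @ x + bf)     # forget gate (selective forget)
    s_tilde = sigmoid(W @ h_prev + U @ x + b)  # candidate state
    s = f * s_prev + i * s_tilde               # selectively forget + selectively read
    h = o * sigmoid(s)                         # selectively write; h is the cell output
    return s, h

s, h = np.zeros(n), np.zeros(n)
s, h = lstm_step(s, h, rng.standard_normal(m))  # one step on a dummy input
```
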
Note
LSTMs have many variants, which include different numbers of gates and also different arrangements of gates.
The one which we just saw is one of the most popular variants of LSTM.
Another equally popular variant is the Gated Recurrent Unit (GRU), which we will see next.

[Figure: one GRU time step - the gate ot is applied to st−1, the candidate s̃t is computed, and st mixes st−1 and s̃t using 1 − it and it]

The full set of equations for GRUs:

Gates:
ot = σ(Wo st−1 + Uo xt + bo)
it = σ(Wi st−1 + Ui xt + bi)

States:
s̃t = σ(W (ot ⊙ st−1) + U xt + b)
st = (1 − it) ⊙ st−1 + it ⊙ s̃t

There is no explicit forget gate (the forget gate and input gate are tied).
The gates depend directly on st−1 and not on the intermediate ht−1 as in the case of LSTMs.

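A corresponding sketch of one GRU time step using the equations above: the gates are computed from st−1 directly, and the forget gate is tied to 1 − it. The sizes and random parameters are again illustrative assumptions.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

n, m = 4, 3                                       # assumed state and input sizes
rng = np.random.default_rng(3)

def make_params():
    # One (W, U, b) triple of randomly initialised stand-in parameters
    return rng.standard_normal((n, n)), rng.standard_normal((n, m)), np.zeros(n)

Wo, Uo, bo = make_params()                        # gate applied to the previous state
Wi, Ui, bi = make_params()                        # input gate (1 - i plays the forget role)
W,  U,  b  = make_params()                        # candidate state

def gru_step(s_prev, x):
    o = sigmoid(Wo @ s_prev + Uo @ x + bo)
    i = sigmoid(Wi @ s_prev + Ui @ x + bi)
    s_tilde = sigmoid(W @ (o * s_prev) + U @ x + b)
    return (1 - i) * s_prev + i * s_tilde         # forget gate tied to 1 - i

s = np.zeros(n)
s = gru_step(s, rng.standard_normal(m))           # one step on a dummy input
```
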
Module 15.3: How LSTMs avoid the problem of vanishing gradients

[Figure: one full LSTM time step, as before]

Intuition
During forward propagation the gates control the flow of information.
They prevent any irrelevant information from being written to the state.
Similarly, during backward propagation they control the flow of gradients.
It is easy to see that during the backward pass the gradients will get multiplied by the gate.

[Figure: one full LSTM time step, as before]

If the state at time t − 1 did not contribute much to the state at time t (i.e., if ‖ft‖ → 0 and ‖ot−1‖ → 0), then during backpropagation the gradients flowing into st−1 will vanish.
But this kind of a vanishing gradient is fine (since st−1 did not contribute to st we don't want to hold it responsible for the crimes of st).
The key difference from vanilla RNNs is that the flow of information and gradients is controlled by the gates, which ensure that the gradients vanish only when they should (i.e., when st−1 didn't contribute much to st).

We will now see an illustrative proof of how the gates control the flow of gradients

[Figure: an unrolled RNN with losses L1(θ), ..., L4(θ) at each time step, and a compact view showing the long path from L4(θ) back to W through s1, ..., s4]

Recall that RNNs had this multiplicative term which caused the gradients to vanish:

∂Lt(θ)/∂W = ∂Lt(θ)/∂st · Σ_{k=1}^{t} ( Π_{j=k}^{t−1} ∂sj+1/∂sj ) · ∂⁺sk/∂W

In particular, if the loss L4(θ) was high because W was not good enough to compute s1 correctly, then this information will not be propagated back to W, as the gradient ∂Lt(θ)/∂W along this long path will vanish.

[Figure: the same unrolled RNN with losses at each time step]

In general, the gradient of Lt(θ) w.r.t. θi vanishes when the gradients flowing through each and every path from Lt(θ) to θi vanish.
On the other hand, the gradient of Lt(θ) w.r.t. θi explodes when the gradient flowing through at least one path explodes.
We will first argue that in the case of LSTMs there exists at least one path through which the gradients can flow effectively (and hence no vanishing gradients).

[Figure: the LSTM dependency graph at time step k - sk−1 and hk−1 feed into s̃k, fk, ik, ok, which in turn produce sk and hk]

We will start with the dependency graph involving different variables in LSTMs, starting with the states at time step k − 1.
For simplicity we will omit the parameters for now and return back to them later.

ok = σ(Wo hk−1 + Uo xk + bo)
ik = σ(Wi hk−1 + Ui xk + bi)
fk = σ(Wf hk−1 + Uf xk + bf)
s̃k = σ(W hk−1 + U xk + b)
sk = fk ⊙ sk−1 + ik ⊙ s̃k
hk = ok ⊙ σ(sk)

[Figure: the dependency graph unrolled from time step k to time step t, ending at ht and the loss Lt(θ)]

Starting from hk−1 and sk−1 we have reached hk and sk.
And the recursion will now continue till the last time step.
For simplicity and ease of illustration, instead of considering the parameters (W, Wo, Wi, Wf, U, Uo, Ui, Uf) as separate nodes in the graph, we will just put them on the appropriate edges. (We show only a few parameters and not all.)

[Figure: the same dependency graph, with the parameters W, Wo, Wi, Wf shown on the edges into s̃k, ok, ik, fk]

For example, we are interested in knowing if the gradient flows to Wf through sk.
In other words, if Lt(θ) was high because Wf failed to compute an appropriate value for sk, then this information should flow back to Wf through the gradients.
We can ask a similar question about the other parameters (for example, Wi, Wo, W, etc.).
How does the LSTM ensure that this gradient does not vanish even at arbitrary time steps? Let us see.

[Figure: the same dependency graph, showing that several paths lead from Lt(θ) back to sk]

It is sufficient to show that ∂Lt(θ)/∂sk does not vanish (because if this does not vanish we can reach Wf through sk).
First, we observe that there are multiple paths from Lt(θ) to sk (you just need to reverse the direction of the arrows for backpropagation).
For example, there is one path through sk+1, another through hk.
Further, there are multiple paths to reach hk itself (as should be obvious from the number of outgoing arrows from hk).
So at this point just convince yourself that there are many paths from Lt(θ) to sk.

[Figure: the same dependency graph, with one path from Lt(θ) through ht, st, st−1, ..., sk+1 to sk highlighted]

Consider one such path (highlighted) which will contribute to the gradient.
Let us denote the gradient along this path as t0:

t0 = ∂Lt(θ)/∂ht · ∂ht/∂st · ∂st/∂st−1 · ... · ∂sk+1/∂sk

The first term ∂Lt(θ)/∂ht is fine and it doesn't vanish (ht is directly connected to Lt(θ) and there are no intermediate nodes which can cause the gradient to vanish).
We will now look at the other terms ∂ht/∂st and ∂st/∂st−1 (∀t).

[Figure: the same dependency graph]

Let us first look at ∂ht/∂st. Recall that

ht = ot ⊙ σ(st)

Note that hti only depends on oti and sti and not on any other elements of ot and st.
∂ht/∂st will thus be a square diagonal matrix ∈ R^{d×d} whose diagonal will be ot ⊙ σ′(st) ∈ R^d (see slide 35 of Lecture 14).
We will represent this diagonal matrix by D(ot ⊙ σ′(st)).

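This diagonal structure can be checked numerically. The sketch below compares the claimed Jacobian D(ot ⊙ σ′(st)) against a finite-difference Jacobian of ht = ot ⊙ σ(st), using random values for ot and st (an assumption made purely for illustration).

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(4)
d = 4                                            # assumed state size
o_t = sigmoid(rng.standard_normal(d))            # some gate values in (0, 1)
s_t = rng.standard_normal(d)

# Analytical Jacobian: the diagonal matrix D(o_t ⊙ σ'(s_t))
sigma_prime = sigmoid(s_t) * (1 - sigmoid(s_t))
J_analytic = np.diag(o_t * sigma_prime)

# Finite-difference Jacobian of h(s) = o_t ⊙ σ(s)
eps = 1e-6
J_numeric = np.zeros((d, d))
for j in range(d):
    e = np.zeros(d)
    e[j] = eps
    J_numeric[:, j] = (o_t * sigmoid(s_t + e) - o_t * sigmoid(s_t - e)) / (2 * eps)

# The two agree, and all off-diagonal entries are (numerically) zero
assert np.allclose(J_analytic, J_numeric, atol=1e-6)
```
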
[Figure: the same dependency graph]

Now let us consider ∂st/∂st−1. Recall that

st = ft ⊙ st−1 + it ⊙ s̃t

Notice that s̃t also depends on st−1, so we cannot treat it as a constant.
So once again we are dealing with an ordered network, and thus ∂st/∂st−1 will be a sum of an explicit term and an implicit term (see slide 37 from Lecture 14).
For simplicity, let us assume that the gradient from the implicit term vanishes (we are assuming a worst case scenario).
And the gradient from the explicit term (treating s̃t as a constant) is given by D(ft).

[Figure: the same dependency graph, with the highlighted path]

We now return back to our full expression for t0:

t0 = ∂Lt(θ)/∂ht · ∂ht/∂st · ∂st/∂st−1 · ... · ∂sk+1/∂sk
   = L′t(ht) · D(ot ⊙ σ′(st)) · D(ft) · ... · D(fk+1)
   = L′t(ht) · D(ot ⊙ σ′(st)) · D(ft ⊙ ... ⊙ fk+1)
   = L′t(ht) · D(ot ⊙ σ′(st)) · D(⊙_{i=k+1}^{t} fi)

The first two terms, shown in red on the slide, don't vanish; the last term, shown in blue, contains a multiplication of the forget gates.
The forget gates thus regulate the gradient flow depending on the explicit contribution of a state (st) to the next state st+1.

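The blue term can be sanity-checked numerically: the product of the diagonal matrices D(fk+1), ..., D(ft) is itself diagonal, with entries given by the elementwise product of the forget gates, so the gradient along this path shrinks only in those coordinates where the forget gates were close to 0. A small sketch with random (assumed) gate values:

```python
import numpy as np

rng = np.random.default_rng(5)
d, t, k = 4, 10, 3                                      # assumed sizes and time steps
forget_gates = rng.uniform(0.0, 1.0, size=(t - k, d))   # f_{k+1}, ..., f_t

# Multiply the diagonal Jacobians D(f_{k+1}) ... D(f_t) along the path
prod = np.eye(d)
for f in forget_gates:
    prod = prod @ np.diag(f)

# The result is the diagonal of the elementwise product of the forget gates
assert np.allclose(prod, np.diag(forget_gates.prod(axis=0)))

# Coordinates where every f_i stayed near 1 keep a sizeable gradient;
# coordinates where some f_i was near 0 are exactly the ones that should
# not receive gradient, since they did not contribute in the forward pass.
print(np.diag(prod))
```
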
[Figure: the same dependency graph]

If during the forward pass st did not contribute much to st+1 (because ft+1 → 0), then during backpropagation also the gradient will not reach st.
This is fine because if st did not contribute much to st+1 then there is no reason to hold it responsible during backpropagation (ft+1 does the same regulation during the forward pass and the backward pass, which is fair).
Thus there exists this one path along which the gradient doesn't vanish when it shouldn't.
And as argued, as long as the gradient flows back to Wf through one of the paths (t0) through sk, we are fine!
Of course the gradient flows back only when required, as regulated by the fi's (but let me just say it one last time that this is fair).

[Figure: the same dependency graph, highlighting the path through ht, ot, ht−1, ..., ok, hk−1]

Now we will see why LSTMs do not solve the problem of exploding gradients.
We will show a path through which the gradient can explode.
Let us compute one term (say t1) of ∂Lt(θ)/∂hk−1 corresponding to the highlighted path:

t1 = ∂Lt(θ)/∂ht · ∂ht/∂ot · ∂ot/∂ht−1 · ... · ∂hk/∂ok · ∂ok/∂hk−1
   = L′t(ht) · (D(σ(st) ⊙ o′t) · Wo) · ... · (D(σ(sk) ⊙ o′k) · Wo)

‖t1‖ ≤ ‖L′t(ht)‖ · (‖K‖ ‖Wo‖)^{t−k+1}

Depending on the norm of the matrix Wo, the gradient ∂Lt(θ)/∂hk−1 may explode.
Similarly, Wi, Wf and W can also cause the gradients to explode.

[Figure: the same dependency graph]

So how do we deal with the problem of exploding gradients?
One popular trick is to use gradient clipping.
While backpropagating, if the norm of the gradient exceeds a certain value, it is scaled to keep its norm within an acceptable threshold.*
Essentially we retain the direction of the gradient but scale down the norm.

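A minimal sketch of clipping by norm, in the spirit of the trick described above; the threshold and the example gradient are arbitrary assumptions.

```python
import numpy as np

def clip_by_norm(grad, threshold):
    # If the gradient norm exceeds the threshold, rescale the gradient:
    # its direction is preserved but its norm is reduced to the threshold.
    norm = np.linalg.norm(grad)
    if norm > threshold:
        grad = grad * (threshold / norm)
    return grad

g = np.array([3.0, 4.0])               # example gradient with norm 5
print(clip_by_norm(g, 1.0))            # -> [0.6 0.8], norm scaled down to 1
```
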
* Pascanu, Razvan, Tomas Mikolov, and Yoshua Bengio. "On the difficulty of training recurrent neural networks." ICML (3) 28 (2013): 1310-1318.