Skip to content

Which backprop method is correct for RNN? #635

@csarofeen

Description

@csarofeen

Accumulating loss incrementally with timestep as in the tutorial, and sending all tiemsteps to RNN seem to produce the same output/hidden/loss but loss.backwards is calculating different parameter gradients. Is there a correct and incorrect method to do this? Which is right?

import torch
from torch import nn
from torch.autograd import Variable

torch.backends.cudnn.enabled=False

model  = nn.LSTM(5, 5, 1).cuda()
model2 = nn.LSTM(5, 5, 1).cuda()

for i in range(len(model2.all_weights)):
    for j in range(len(model2.all_weights[i])):
        model2.all_weights[i][j].data.copy_(model.all_weights[i][j].data)

crit = nn.MSELoss().cuda()
crit2 = nn.MSELoss().cuda()

input = Variable(torch.randn(2,1,5).cuda())
target = Variable(torch.ones(2,1,5).cuda(), requires_grad=False)
hidden = [ Variable(torch.randn(1,1,5).cuda().fill_(0.0)),
            Variable(torch.randn(1,1,5).cuda().fill_(0.0))]

output, hidden = model(input, hidden)
loss = crit(output, target)
loss.backward(retain_variables=True)

hidden2 = [ Variable(torch.randn(1,1,5).cuda().fill_(0.0)),
            Variable(torch.randn(1,1,5).cuda().fill_(0.0))]

loss2 = 0
for i in range(input.size(0)):
    output2, hidden2 = model(input[i].view(1,1,-1), hidden2)
    loss2 += crit2(output2[0], target[i])

loss2 = loss2/2
loss2.backward(retain_variables=True)

diff = 0
max_w = 0
for i in range(len(model2.all_weights)):
    for j in range(len(model2.all_weights[i])):
        diff = max(diff, (model2.all_weights[i][j].grad - model.all_weights[i][j].grad).abs().max().data[0])
        
        max_w = max(model2.all_weights[i][j].grad.max().data[0], max_w)
        max_w = max(model.all_weights[i][j].grad.max().data[0], max_w)

dh = (hidden[0]-hidden2[0]).abs().max().data[0]
dc = (hidden[1]-hidden2[1]).abs().max().data[0]
do = (output[1]-output2).abs().max().data[0]
dl = (loss-loss2).abs().max().data[0]

print("Diff in output : " + str(do))
print("Diff in hidden states : "+str(dh) +", "+str(dc))
print("Diff in loss : " + str(dl))

print("Max weight grad found : " +str(max_w))
print("Diff in weight grad : " + str(diff))

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions