
Rewrite fused RNN kernels to avoid clones in backward() #1532

@aosokin

Description


Hi, in my model I am applying a linear layer on top of a GRUCell.
When sequentially calling backward on two different outputs, the gradients w.r.t. the GRUCell parameters are substantially different on CPU vs GPU. The gradients w.r.t. the parameters of the linear layer are identical.
The GRUCell layer alone seems to have correct gradients on both CPU and GPU.

The code below produces

weight_ih 0.40665897727012634
weight_hh 0.06703907996416092
bias_ih 0.0623239129781723
bias_hh 0.13187159597873688
weight 0.3885607421398163
bias 1.3700947761535645

on GPU and

weight_ih 0.18657396119047137
weight_hh 0.05437005038803549
bias_ih 0.053355072396595864
bias_hh 0.03761203796902013
weight 0.3885607195130627
bias 1.3700947601746332

on CPU. Tested on the master branch, the latest version (commit 5bb1348).
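
For reference, and before the full reproduction script below, the mismatch can be quantified directly from the two outputs above: the relative gaps in the GRUCell gradient norms range from roughly 0.17 to 2.5, while the Linear gradient norms agree to float32 precision. A minimal check (not part of the original report; the numbers are copied from the outputs above):

# Gradient norms copied from the GPU and CPU runs above.
gpu_norms = {'weight_ih': 0.40665897727012634, 'weight_hh': 0.06703907996416092,
             'bias_ih': 0.0623239129781723, 'bias_hh': 0.13187159597873688,
             'weight': 0.3885607421398163, 'bias': 1.3700947761535645}
cpu_norms = {'weight_ih': 0.18657396119047137, 'weight_hh': 0.05437005038803549,
             'bias_ih': 0.053355072396595864, 'bias_hh': 0.03761203796902013,
             'weight': 0.3885607195130627, 'bias': 1.3700947601746332}
for name in gpu_norms:
    rel = abs(gpu_norms[name] - cpu_norms[name]) / abs(cpu_norms[name])
    print(name, rel)  # GRUCell gaps are between ~0.17 and ~2.5; Linear gaps are below 1e-7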

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from collections import OrderedDict

# GPU vs CPU
use_cuda = False
th = torch.cuda if use_cuda else torch

# Create layers
rnn = nn.GRUCell(2, 2)
lin = nn.Linear(2, 27)
if use_cuda:
    rnn.cuda()
    lin.cuda()

# dumped state to reproduce the bug
rnn_state = OrderedDict({'weight_ih': th.FloatTensor([
    [-0.6940, 0.7056], [0.4358, 0.6890], [-0.4579, 0.5398],
    [0.1502, -0.1490], [-0.2306, 0.2639], [0.1209, -0.1180]]),
    'weight_hh': th.FloatTensor([
        [-0.2607, 0.2100], [-0.5842, -0.2695], [-0.3918, 0.1502],
        [-0.4412, 0.0336], [-0.1018, -0.6725], [0.2238, 0.4111]]),
    'bias_ih': th.FloatTensor([
        -0.3340, -0.6903, -0.6900, -0.2925, 0.0268, -0.6987]),
    'bias_hh': th.FloatTensor([
        0.1496, -0.1751, 0.6546, -0.6178, 0.1265, 0.5848])})

lin_state = OrderedDict({'weight': th.FloatTensor([
    [-0.0408, -0.1192], [0.1038, -0.0573], [0.6051, -0.2123],
    [0.4644, 0.1803], [-0.2503, 0.4022], [-0.3716, 0.2095],
    [-0.3357, 0.3510], [0.0063, 0.5741], [0.3451, 0.5141],
    [0.0046, 0.2462], [-0.6112, -0.4036], [0.3162, 0.1947],
    [0.4327, 0.0086], [-0.2977, 0.3072], [0.2321, 0.4708],
    [-0.1038, 0.3433], [0.5503, 0.4901], [0.4882, -0.6168],
    [0.0360, -0.6624], [0.2708, -0.3415], [0.5421, 0.5666],
    [-0.0061, 0.4577], [-0.2641, 0.2584], [0.3587, 0.2839],
    [-0.3959, 0.3206], [-0.1576, 0.6259], [0.5172, 0.4698]]),
    'bias': th.FloatTensor([
        0.3174, 0.1745, -0.1360, 0.2344, -0.4034, -0.4048, -0.2037,
        0.2225, -0.3053, -0.1065, -0.4581, -0.4908, 0.3441, 0.4333,
        -0.3506, 0.0145, -0.0648, 0.6752, -0.1614, -0.3088, 0.6818,
        -0.0273, -0.1540, -0.0735, -0.5940, -0.5363, -0.5801])})

rnn.load_state_dict(rnn_state)
lin.load_state_dict(lin_state)

rnn.zero_grad()
lin.zero_grad()

# inputs
hidden0 = Variable(th.FloatTensor(1, 2))
hidden0[0, 0] = -0.1997
hidden0[0, 1] = 0.4675

input1 = th.FloatTensor(1, 2).fill_(0.0)
input1 = Variable(input1)

input2 = Variable(th.FloatTensor(1, 2))
input2[0, 0] = 0.8316
input2[0, 1] = 1.0540

target = Variable(th.LongTensor([24, 12]).unsqueeze(1))

# process the layers
hidden1 = rnn(input1, hidden0)
hidden2 = rnn(input2, hidden1)
output1 = lin(hidden1)
output2 = lin(hidden2)

output = torch.stack([output1, output2], 0)
output = output.squeeze(1)

# cross entropy loss, but separated w.r.t. outputs
logits = F.log_softmax(output)
objs = -torch.gather(logits.contiguous(), dim=1, index=target)

# backward through each element
for i in range(objs.size(0)):
    objs[i, 0].backward([th.FloatTensor([1.0])], retain_graph=True)

# print gradient norms (both the GRUCell and the Linear parameters,
# matching the six entries listed above)
for name, param in rnn.named_parameters():
    print(name, torch.norm(param.grad.data))
for name, param in lin.named_parameters():
    print(name, torch.norm(param.grad.data))
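
To automate the CPU/GPU comparison, the script above can be wrapped in a function that rebuilds the layers on either device from the same dumped state and inputs, runs the same forward/backward sequence, and returns the gradient norms. This is only a sketch, not part of the original report; it reuses rnn_state, lin_state, hidden0, input1, input2, and target defined above and assumes a CUDA device is available for the GPU half.

def run_repro(on_gpu):
    t = torch.cuda if on_gpu else torch
    rnn = nn.GRUCell(2, 2)
    lin = nn.Linear(2, 27)
    if on_gpu:
        rnn.cuda()
        lin.cuda()
    # move the dumped state and the inputs to the requested device
    to_dev = lambda x: x.cuda() if on_gpu else x.cpu()
    rnn.load_state_dict(OrderedDict((k, to_dev(v)) for k, v in rnn_state.items()))
    lin.load_state_dict(OrderedDict((k, to_dev(v)) for k, v in lin_state.items()))
    rnn.zero_grad()
    lin.zero_grad()
    h0 = Variable(to_dev(hidden0.data))
    x1 = Variable(to_dev(input1.data))
    x2 = Variable(to_dev(input2.data))
    tgt = Variable(to_dev(target.data))
    # same forward/backward sequence as in the script above
    h1 = rnn(x1, h0)
    h2 = rnn(x2, h1)
    out = torch.stack([lin(h1), lin(h2)], 0).squeeze(1)
    objs = -torch.gather(F.log_softmax(out).contiguous(), dim=1, index=tgt)
    for i in range(objs.size(0)):
        objs[i, 0].backward([t.FloatTensor([1.0])], retain_graph=True)
    grads = OrderedDict()
    for name, param in list(rnn.named_parameters()) + list(lin.named_parameters()):
        grads[name] = torch.norm(param.grad.data)
    return grads

cpu_grads = run_repro(False)
if torch.cuda.is_available():
    gpu_grads = run_repro(True)
    for name in cpu_grads:
        print(name, cpu_grads[name], gpu_grads[name])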

Labels

todo: Not as important as medium or high priority tasks, but we will work on these.
