inplace operations not raising error during backward #144

@glample

Description

Calling the backward function does not necessarily raise an error when in-place operations have been performed, which silently leads to incorrect gradients. Below is an example with an LSTM that computes its gates with and without in-place operations.

import torch
import torch.nn as nn
from torch.autograd import Variable

class LSTM(nn.Container):

    def __init__(self, input_dim, hidden_dim):
        super(LSTM, self).__init__(
            i2h=nn.Linear(input_dim, 4 * hidden_dim),
            h2h=nn.Linear(hidden_dim, 4 * hidden_dim)
        )
        self.c0 = Variable(torch.zeros(hidden_dim), requires_grad=True)
        self.h0 = Variable(torch.zeros(hidden_dim), requires_grad=True)
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.reset()

    def forward(self, x):
        last_c, last_h = self.temp
        if self.step == 0:
            last_c = last_c.view(1, self.hidden_dim).expand(x.size(0), self.hidden_dim)
            last_h = last_h.view(1, self.hidden_dim).expand(x.size(0), self.hidden_dim)
        gates = self.i2h(x) + self.h2h(last_h)

        ## works
#         i_t = nn.Sigmoid()(gates[:, :self.hidden_dim])
#         f_t = nn.Sigmoid()(gates[:, self.hidden_dim:2 * self.hidden_dim])
#         o_t = nn.Sigmoid()(gates[:, 2 * self.hidden_dim:3 * self.hidden_dim])
#         cell_input = nn.Tanh()(gates[:, 3 * self.hidden_dim:])

        ## works, same results as above, as expected
#         i_t, f_t, o_t = gates[:, :3 * self.hidden_dim].sigmoid().chunk(3, 1)
#         cell_input = gates[:, 3 * self.hidden_dim:].tanh()

        ## fails (silently wrong gradients), although backward runs without error
        gates[:, :3 * self.hidden_dim].sigmoid_()
        gates[:, 3 * self.hidden_dim:].tanh_()
        i_t, f_t, o_t, cell_input = gates.chunk(4, 1)

        next_c = f_t * last_c + i_t * cell_input
        next_h = o_t * next_c.tanh()
        self.temp = [next_c, next_h]
        self.step += 1
        return next_h

    def reset(self):
        self.step = 0
        self.temp = [self.c0, self.h0]

torch.manual_seed(0)
rnn = LSTM(3, 4)
inputs = [Variable(torch.FloatTensor(2, 3).normal_()) for _ in range(3)]

# code below doesn't always output the same gradients when the LSTM uses in-place operations
for k, p in rnn.parameter_dict().items():
    p.grad.fill_(0)
for x in inputs:
    print(x.sum().data[0], rnn(x).sum().data[0])
rnn.temp[-1].sum().backward()
for k, p in rnn.parameter_dict().items():
    print(k, p.grad.sum())
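The same failure mode can be reproduced without the LSTM. Below is a minimal sketch (hypothetical, assuming the same legacy Variable API as above): sigmoid saves its output for the backward pass, so mutating that output in place corrupts the saved value, yet backward completes without complaint.

import torch
from torch.autograd import Variable

x = Variable(torch.ones(3), requires_grad=True)
y = x.sigmoid()      # sigmoid saves its output y to compute its backward
y.tanh_()            # mutate y in place; the saved output is now stale
y.sum().backward()   # runs without error
print(x.grad)        # gradient is computed from the mutated y, hence wrong

The expected behavior would be for backward to raise an error whenever a value it needs for gradient computation has been modified in place.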
