Calling the backward function does not necessarily raise an error when in-place operations have been performed, and this can lead to incorrect gradients. A minimal sketch of the behaviour is given first, followed by an example with an LSTM that computes the gates with and without in-place operations.
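Here is a minimal sketch of what appears to be the same silent failure, stripped of the LSTM (it is not part of the LSTM example below and assumes that slicing a Variable yields one that shares storage with its parent):

import torch
from torch.autograd import Variable

x = Variable(torch.randn(2, 4), requires_grad=True)
y = x * 1.0          # any differentiable op, so that y has its own graph node
y[:, :2].sigmoid_()  # in-place activation applied to a slice of y
y.sum().backward()   # runs without raising
# The gradient flowing into the first two columns of x should be scaled by the
# sigmoid derivative, but it comes out as if no sigmoid had been applied.
print(x.grad)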
import torch
import torch.nn as nn
from torch.autograd import Variable


class LSTM(nn.Container):
    def __init__(self, input_dim, hidden_dim):
        super(LSTM, self).__init__(
            i2h=nn.Linear(input_dim, 4 * hidden_dim),
            h2h=nn.Linear(hidden_dim, 4 * hidden_dim)
        )
        self.c0 = Variable(torch.zeros(hidden_dim), requires_grad=True)
        self.h0 = Variable(torch.zeros(hidden_dim), requires_grad=True)
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.reset()

    def forward(self, x):
        last_c, last_h = self.temp
        if self.step == 0:
            # broadcast the initial states over the batch dimension
            last_c = last_c.view(1, self.hidden_dim).expand(x.size(0), self.hidden_dim)
            last_h = last_h.view(1, self.hidden_dim).expand(x.size(0), self.hidden_dim)
        gates = self.i2h(x) + self.h2h(last_h)

        ## works
        # i_t = nn.Sigmoid()(gates[:, :self.hidden_dim])
        # f_t = nn.Sigmoid()(gates[:, self.hidden_dim:2 * self.hidden_dim])
        # o_t = nn.Sigmoid()(gates[:, 2 * self.hidden_dim:3 * self.hidden_dim])
        # cell_input = nn.Tanh()(gates[:, 3 * self.hidden_dim:])

        ## works, same results as above, as expected
        # i_t, f_t, o_t = gates[:, :3 * self.hidden_dim].sigmoid().chunk(3, 1)
        # cell_input = gates[:, 3 * self.hidden_dim:].tanh()

        ## fails, although backward runs: the activations are applied in place to slices of gates
        gates[:, :3 * self.hidden_dim].sigmoid_()
        gates[:, 3 * self.hidden_dim:].tanh_()
        i_t, f_t, o_t, cell_input = gates.chunk(4, 1)

        next_c = f_t * last_c + i_t * cell_input
        next_h = o_t * next_c.tanh()
        self.temp = [next_c, next_h]
        self.step += 1
        return next_h

    def reset(self):
        self.step = 0
        self.temp = [self.c0, self.h0]

torch.manual_seed(0)
rnn = LSTM(3, 4)
inputs = [Variable(torch.FloatTensor(2, 3).normal_()) for _ in range(3)]

# the code below does not always output the same thing when the LSTM uses the in-place operations
for k, p in rnn.parameter_dict().items():
    p.grad.fill_(0)  # zero the accumulated gradients
for x in inputs:
    print(x.sum().data[0], rnn(x).sum().data[0])
rnn.temp[-1].sum().backward()
for k, p in rnn.parameter_dict().items():
    print(k, p.grad.sum())
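
My guess at the mechanism (an assumption on my part, not verified): gates[:, :3 * self.hidden_dim] returns a new Variable whose data shares storage with gates, so sigmoid_() rewrites the values that gates.chunk(4, 1) reads afterwards, while the sigmoid itself never becomes part of the graph that backward traverses from next_h; the backward pass therefore differentiates straight through the raw pre-activations without complaining. The storage sharing is easy to check:

import torch
from torch.autograd import Variable

g = Variable(torch.randn(2, 12))
s = g[:, :9]
s.sigmoid_()
# If the slice shares storage with g, the first nine columns of g now hold the
# sigmoid-ed values even though no sigmoid appears in g's history.
print((g.data[:, :9] - s.data).abs().max())  # expected: 0.0

If that is right, the commented-out out-of-place variants avoid the problem because they build new Variables from gates instead of mutating it.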