Right now, registered hooks have to be read-only: they cannot modify grad_input / grad_output and still get correct results in the whole graph. This is because we make some read-only assumptions and optimize buffer reuse.

Allow writeable hooks with an interface like this:

`register_backward_hook('name', hook, write=True)`

which would allow arbitrary, user-defined modification of grad_input and grad_output. This would be useful, for example, for meta-learning and in RL.
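As a purely hypothetical sketch, this is how the proposed interface might be used; `write=True` does not exist today, and `some_variable` / `rescale_grad` are placeholder names, not part of any real API:

```python
# Hypothetical usage of the proposed writeable hook. The write=True flag
# is the interface suggested in this issue, not an existing argument.
def rescale_grad(grad_output):
    # With write=True, this in-place edit would be sanctioned, and the
    # engine would be responsible for keeping the rest of the graph correct.
    grad_output.mul_(0.5)

some_variable.register_hook('rescale_grad', rescale_grad, write=True)
```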
Here's an example put together by @apaszke showing the bugs that appear if we try to modify the gradients with the current implementation of hooks:
```python
import torch
from torch.autograd import Variable

x = Variable(torch.randn(5, 5), requires_grad=True)
y = Variable(torch.randn(5, 5), requires_grad=True)

a = x * 2
b = y * 3

def hook_a(grad_output):
    # In-place modification of the gradient buffer, which the current
    # implementation may share with other parts of the graph.
    grad_output.mul_(2)

a.register_hook('test', hook_a)

c = a + b
c.sum().backward()

print(x.grad)  # should be 2, is 2
print(y.grad)  # should be 3, is 6
```

cc: @ludc
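For contrast, here is a minimal sketch of the non-mutating pattern that the current `torch.Tensor.register_hook` API supports: the hook returns a fresh tensor rather than writing into the shared buffer, so the returned value replaces the gradient flowing back without touching other buffers. The expected values in the comments assume those current-API semantics:

```python
import torch

x = torch.randn(5, 5, requires_grad=True)
y = torch.randn(5, 5, requires_grad=True)

a = x * 2
b = y * 3

def hook_a(grad_output):
    # Return a new tensor instead of mutating grad_output in place,
    # so any buffer shared with b's gradient is left untouched.
    return grad_output * 2

a.register_hook(hook_a)

c = a + b
c.sum().backward()

print(x.grad)  # 4 everywhere: the doubled gradient propagates through a = x * 2
print(y.grad)  # 3 everywhere: unaffected by the hook registered on a
```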