
allow forward / backward hooks to rewrite outputs and gradients #262


Description

@soumith

Right now, registered hooks need to be read-only: they cannot be used to modify grad_input / grad_output and still get correct results in the rest of the graph.
This is because we make read-only assumptions and optimize buffer reuse.

Allow writeable hooks with an interface like this:

register_backward_hook('name', hook, write=True)

which would allow arbitrary, user-defined modification of grad_input and the output. This will be useful, for example, for meta-learning and for RL.
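As a rough sketch of how such a writeable module hook might be used (the write=True flag and the convention that the hook returns replacement gradients are part of the proposal, not an existing API, so this does not run on the current implementation):

import torch
import torch.nn as nn
from torch.autograd import Variable

linear = nn.Linear(5, 5)

def rescale_grad_input(module, grad_input, grad_output):
    # Proposed semantics (assumption): return replacement gradients
    # instead of mutating the buffers in place.
    return tuple(g * 2 if g is not None else g for g in grad_input)

# write=True would mark the hook as one that may rewrite grad_input.
linear.register_backward_hook('rescale', rescale_grad_input, write=True)

out = linear(Variable(torch.randn(3, 5), requires_grad=True))
out.sum().backward()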

Here's an example put together by @apaszke showcasing the bugs that appear if we try to modify gradients with the current implementation of hooks:

import torch
from torch.autograd import Variable

x = Variable(torch.randn(5, 5), requires_grad=True)
y = Variable(torch.randn(5, 5), requires_grad=True)

a = x * 2
b = y * 3

def hook_a(grad_output):
    # In-place modification of the gradient buffer; current hooks are
    # assumed to be read-only, so this is exactly the unsupported case.
    grad_output.mul_(2)

a.register_hook('test', hook_a)

c = a + b
c.sum().backward()

print(x.grad) # should be 2, and is 2
print(y.grad) # should be 3, but is 6: the in-place mul_ leaks into b's gradient through a reused buffer
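
For comparison, here's a sketch of what the example above could look like under writeable-hook semantics, assuming a hook is allowed to return a replacement gradient instead of mutating the shared buffer (this return-value convention is an assumption about the proposal, not current behaviour):

import torch
from torch.autograd import Variable

x = Variable(torch.randn(5, 5), requires_grad=True)
y = Variable(torch.randn(5, 5), requires_grad=True)

a = x * 2
b = y * 3

def hook_a(grad_output):
    # Return a fresh tensor; no shared buffer is modified in place.
    return grad_output * 2

a.register_hook('test', hook_a)

c = a + b
c.sum().backward()

print(x.grad) # under the proposed semantics: 4, since the doubling applies only to a's path
print(y.grad) # under the proposed semantics: 3, since b's path is untouched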

cc: @ludc
