Problem with backward hook function #598

@ludc

Description

Hi,

There is something strange in the backward step (or maybe something I don't understand). If I define a Module that takes 3 inputs, grad_input should contain 3 elements, right? But that is not what the backward hook reports:

```python
import torch
from torch.autograd import Variable
import torch.nn as nn

# Backward hooks are called as hook(module, grad_input, grad_output).
def bh(module, grad_input, grad_output):
    print("Grad Input")
    print(grad_input)
    print("Grad Output")
    print(grad_output)

class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.register_backward_hook(bh)

    def forward(self, x, y, z):
        return x + y + z

x = Variable(torch.randn(1, 5), requires_grad=True)
y = Variable(torch.randn(1, 5), requires_grad=True)
z = Variable(torch.randn(1, 5), requires_grad=True)

criterion = nn.MSELoss()
mod = M()
out = mod(x, y, z)
loss = criterion(out, Variable(torch.randn(1, 5)))
loss.backward()
```

In that case, the grad_input printed through the hook contains only two elements... Could you tell me where I am wrong? Strangely, `x.grad`, `y.grad` and `z.grad` all seem to be computed correctly.
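
For anyone hitting this, a likely explanation (hedged, not confirmed in this thread): `register_backward_hook` attaches to the grad_fn of the module's output, i.e. the last operation executed in `forward()`. Since `x + y + z` is evaluated as `(x + y) + z`, the hook sees the two inputs of the final addition, not the module's three inputs. Below is a minimal sketch of a workaround, assuming a PyTorch version (>= 1.8) that provides `nn.Module.register_full_backward_hook`, which reports one gradient per module input:

```python
import torch
import torch.nn as nn

def hook(module, grad_input, grad_output):
    # With register_full_backward_hook, grad_input holds one gradient
    # per positional input of forward(), so three entries here.
    print("grad_input has", len(grad_input), "elements")
    print("grad_output has", len(grad_output), "elements")

class M(nn.Module):
    def forward(self, x, y, z):
        return x + y + z

mod = M()
# register_full_backward_hook (PyTorch >= 1.8) fires with gradients for
# all module inputs; the legacy register_backward_hook only saw the
# inputs of the last op in forward().
mod.register_full_backward_hook(hook)

x = torch.randn(1, 5, requires_grad=True)
y = torch.randn(1, 5, requires_grad=True)
z = torch.randn(1, 5, requires_grad=True)

loss = nn.MSELoss()(mod(x, y, z), torch.randn(1, 5))
loss.backward()  # prints 3 grad_input elements and 1 grad_output element
```

On older versions, attaching `tensor.register_hook` to each input individually is another way to observe per-input gradients directly.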

cc @ezyang @gchanan @zou3519 @SsnL @albanD @gqchen

Metadata

Labels

- high priority
- module: autograd (Related to torch.autograd, and the autograd engine in general)
- module: docs (Related to our documentation, both in docs/ and docblocks)
- triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
