Description
🐛 Bug
This is a more general version of the bug in #21970, which I'll close.
When an autograd.Function returns the result of einsum directly, the custom backward pass is ignored.
When the result of einsum is cloned before returning, the backward pass works fine.
To Reproduce
import torch
from torch.autograd import Function


class TestFunction(Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        y = torch.einsum('ab,ab->a', [x, w])
        return y

    @staticmethod
    def backward(ctx, grad):
        x, w = ctx.saved_tensors
        w_grad = torch.einsum('a,ab->ab', [grad, x])
        x_grad = torch.einsum('a,ab->ab', [grad, w])
        return x_grad, w_grad


class TestFunction2(Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        y = torch.einsum('ab,ab->a', [x, w])
        return y.clone()

    @staticmethod
    def backward(ctx, grad):
        x, w = ctx.saved_tensors
        w_grad = torch.einsum('a,ab->ab', [grad, x])
        x_grad = torch.einsum('a,ab->ab', [grad, w])
        return x_grad, w_grad


x = torch.ones(1, 1)
w = torch.ones(1, 1).requires_grad_()
y = TestFunction.apply(x, w)
z = y.sum()
z.backward()
print(y, z, w.grad)

x = torch.ones(1, 1)
w = torch.ones(1, 1).requires_grad_()
y = TestFunction2.apply(x, w)
z = y.sum()
z.backward()
print(y, z, w.grad)
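For reference, here is what the gradient should be, computed with plain einsum and no custom Function (a minimal sketch based on the repro above; since y_a = sum_b x[a,b] * w[a,b], the gradient of z with respect to w is x, i.e. a tensor of ones):

import torch

x = torch.ones(1, 1)
w = torch.ones(1, 1).requires_grad_()

# Plain einsum, handled entirely by autograd: w.grad should equal x.
y = torch.einsum('ab,ab->a', [x, w])
y.sum().backward()
print(w.grad)  # expected: tensor([[1.]])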
Expected behavior
The first output should equal the second output, but instead, we get:
tensor([1.], grad_fn=<AsStridedBackward>) tensor(1., grad_fn=<SumBackward0>) None
tensor([1.], grad_fn=<TestFunction2Backward>) tensor(1., grad_fn=<SumBackward0>) tensor([[1.]])
Environment
PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130
OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: Could not collect
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce RTX 2080 Ti
Nvidia driver version: 410.104
cuDNN version: Could not collect
Versions of relevant libraries:
[pip] numpy==1.16.3
[pip] pytorch-ignite==0.2.0
[pip] pytorch-memlab==0.0.3
[pip] torch==1.1.0
[pip] torchvision==0.2.2
[conda] blas 1.0 mkl
[conda] ignite 0.2.0 py37_0 pytorch
[conda] mkl 2019.3 199
[conda] mkl_fft 1.0.12 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch 1.1.0 py3.7_cuda10.0.130_cudnn7.5.1_0 pytorch
[conda] pytorch-memlab 0.0.3 pypi_0 pypi
[conda] torchvision 0.2.2 py_3 pytorch
Additional context
When I attach a breakpoint at TestFunction.backward it is never hit, but a breakpoint at TestFunction2.backward is.
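One way to see this without a debugger (a minimal sketch building on the repro above) is to inspect the grad_fn recorded on each output; it matches the output printed under "Expected behavior":

y1 = TestFunction.apply(torch.ones(1, 1), torch.ones(1, 1).requires_grad_())
y2 = TestFunction2.apply(torch.ones(1, 1), torch.ones(1, 1).requires_grad_())

# For TestFunction the graph records the internal einsum ops, not the custom backward.
print(type(y1.grad_fn).__name__)  # AsStridedBackward
# For TestFunction2 the custom backward is recorded and will be called.
print(type(y2.grad_fn).__name__)  # TestFunction2Backward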