Labels: module: activation checkpointing (Related to activation checkpointing), module: autograd (Related to torch.autograd, and the autograd engine in general)
Description
🐛 Bug
When a function that returns the result of einsum is checkpointed, the backward pass is not computed correctly.
To Reproduce
import torch
from torch.utils.checkpoint import checkpoint

def f(x, w):
    return torch.einsum('ab,ab->ab', [x, w])

def g(x, w):
    return torch.einsum('ab,ab->a', [x, w])

def h(x, w):
    return torch.einsum('ab,ab->a', [x, w]).clone()

# Function f, works
x = torch.ones(1, 1)
w = torch.ones(1, 1).requires_grad_()
y = checkpoint(f, x, w)
z = y.sum()
z.backward()
print(y, z, w.grad)

# Function g, fails
x = torch.ones(1, 1)
w = torch.ones(1, 1).requires_grad_()
y = checkpoint(g, x, w)
z = y.sum()
z.backward()
print(y, z, w.grad)

# Function h, works
x = torch.ones(1, 1)
w = torch.ones(1, 1).requires_grad_()
y = checkpoint(h, x, w)
z = y.sum()
z.backward()
print(y, z, w.grad)

Returns:
tensor([[1.]], grad_fn=<CheckpointFunctionBackward>) tensor(1., grad_fn=<SumBackward0>) tensor([[1.]])
tensor([1.], grad_fn=<AsStridedBackward>) tensor(1., grad_fn=<SumBackward0>) None
tensor([1.], grad_fn=<CheckpointFunctionBackward>) tensor(1., grad_fn=<SumBackward0>) tensor([[1.]])
Expected behavior
The three functions should produce the same gradient, but function g produces no gradient (w.grad is None). When einsum is followed by clone, as in function h, the gradient reappears.
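For comparison, here is a minimal sketch (added for illustration, not part of the original report) of the same computation for g without checkpointing; the gradient then comes out as expected:

import torch

def g(x, w):
    return torch.einsum('ab,ab->a', [x, w])

# Baseline without checkpointing: the gradient flows normally.
x = torch.ones(1, 1)
w = torch.ones(1, 1).requires_grad_()
y = g(x, w)            # direct call, no checkpoint
y.sum().backward()
print(w.grad)          # tensor([[1.]]) -- the gradient g should also produce under checkpoint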
Environment
PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130
OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: Could not collect
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce RTX 2080 Ti
Nvidia driver version: 410.104
cuDNN version: Could not collect
Versions of relevant libraries:
[pip] numpy==1.16.3
[pip] pytorch-ignite==0.2.0
[pip] pytorch-memlab==0.0.3
[pip] torch==1.1.0
[pip] torchvision==0.2.2
[conda] blas 1.0 mkl
[conda] ignite 0.2.0 py37_0 pytorch
[conda] mkl 2019.3 199
[conda] mkl_fft 1.0.12 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch 1.1.0 py3.7_cuda10.0.130_cudnn7.5.1_0 pytorch
[conda] pytorch-memlab 0.0.3 pypi_0 pypi
[conda] torchvision 0.2.2 py_3 pytorch
Additional context
When I attach a debugger to CheckpointFunction.backward, it is not called when I call backward on the result of function g, while it is triggered on the backward of functions f and h.
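One way to see this from user code (a sketch added for illustration, based on the output shown above) is to inspect the grad_fn of each checkpointed output: for g it is an AsStridedBackward node rather than CheckpointFunctionBackward, so CheckpointFunction.backward never appears in the graph that backward() traverses:

import torch
from torch.utils.checkpoint import checkpoint

def f(x, w):
    return torch.einsum('ab,ab->ab', [x, w])

def g(x, w):
    return torch.einsum('ab,ab->a', [x, w])

def h(x, w):
    return torch.einsum('ab,ab->a', [x, w]).clone()

x = torch.ones(1, 1)
w = torch.ones(1, 1).requires_grad_()
for fn in (f, g, h):
    y = checkpoint(fn, x, w)
    # On 1.1.0 this prints AsStridedBackward for g instead of
    # CheckpointFunctionBackward, i.e. the checkpoint's backward is bypassed.
    print(fn.__name__, type(y.grad_fn).__name__)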