Checkpoint backward not called if function returns einsum result #21970

Description

@pimdh

🐛 Bug

When a function that returns the result of torch.einsum is checkpointed, the backward pass is not computed correctly.

To Reproduce

import torch
from torch.utils.checkpoint import checkpoint


def f(x, w):
    return torch.einsum('ab,ab->ab', [x, w])


def g(x, w):
    return torch.einsum('ab,ab->a', [x, w])


def h(x, w):
    return torch.einsum('ab,ab->a', [x, w]).clone()


# Function f, works
x = torch.ones(1, 1)
w = torch.ones(1, 1).requires_grad_()
y = checkpoint(f, x, w)
z = y.sum()
z.backward()
print(y, z, w.grad)

# Function g, fails
x = torch.ones(1, 1)
w = torch.ones(1, 1).requires_grad_()
y = checkpoint(g, x, w)
z = y.sum()
z.backward()
print(y, z, w.grad)

# Function h, works
x = torch.ones(1, 1)
w = torch.ones(1, 1).requires_grad_()
y = checkpoint(h, x, w)
z = y.sum()
z.backward()
print(y, z, w.grad)

Returns:

tensor([[1.]], grad_fn=<CheckpointFunctionBackward>) tensor(1., grad_fn=<SumBackward0>) tensor([[1.]])
tensor([1.], grad_fn=<AsStridedBackward>) tensor(1., grad_fn=<SumBackward0>) None
tensor([1.], grad_fn=<CheckpointFunctionBackward>) tensor(1., grad_fn=<SumBackward0>) tensor([[1.]])

Expected behavior

The three checkpointed calls should produce the same gradient, but for function g, w.grad remains None. When the einsum result is followed by clone(), as in function h, the gradient reappears.
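
For comparison, running the same three functions without checkpointing should give the same gradient, tensor([[1.]]), in every case. A minimal sketch (reusing the f, g, and h defined above) to confirm that baseline:

import torch

# Baseline without checkpointing: each function should leave w.grad == [[1.]].
# This is the gradient the checkpointed calls above are expected to reproduce.
for fn in (f, g, h):
    x = torch.ones(1, 1)
    w = torch.ones(1, 1).requires_grad_()
    y = fn(x, w)  # call directly, no checkpoint
    y.sum().backward()
    print(fn.__name__, w.grad)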

Environment

PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: Could not collect

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce RTX 2080 Ti
Nvidia driver version: 410.104
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.16.3
[pip] pytorch-ignite==0.2.0
[pip] pytorch-memlab==0.0.3
[pip] torch==1.1.0
[pip] torchvision==0.2.2
[conda] blas                      1.0                         mkl  
[conda] ignite                    0.2.0                    py37_0    pytorch
[conda] mkl                       2019.3                      199  
[conda] mkl_fft                   1.0.12           py37ha843d7b_0  
[conda] mkl_random                1.0.2            py37hd81dba3_0  
[conda] pytorch                   1.1.0           py3.7_cuda10.0.130_cudnn7.5.1_0    pytorch
[conda] pytorch-memlab            0.0.3                    pypi_0    pypi
[conda] torchvision               0.2.2                      py_3    pytorch

Additional context

When I attach a debugger to CheckpointFunction.backward, it is not called when I call backward on the result of function g, while it is triggered on the backward passes of functions f and h.
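
The same thing can be seen without a debugger by inspecting which autograd node the checkpointed output is attached to. A small sketch, again assuming the definitions of f, g, and h above:

import torch
from torch.utils.checkpoint import checkpoint

# Print the grad_fn attached to each checkpointed output. For f and h this is
# CheckpointFunctionBackward; for g it is AsStridedBackward, which is consistent
# with CheckpointFunction.backward never being reached when backpropagating g.
for fn in (f, g, h):
    x = torch.ones(1, 1)
    w = torch.ones(1, 1).requires_grad_()
    y = checkpoint(fn, x, w)
    print(fn.__name__, type(y.grad_fn).__name__)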

Labels

module: activation checkpointing, module: autograd
