Returning the result of einsum from autograd.Function forward causes backward to never be called #22072

@pimdh

Description

🐛 Bug

This is a more general version of the bug in #21970, which I'll close.

When an autograd.Function returns the result of einsum directly, the custom backward pass is silently ignored.
When the result of einsum is cloned before being returned, the backward pass works as expected.

To Reproduce

import torch
from torch.autograd import Function


class TestFunction(Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        y = torch.einsum('ab,ab->a', [x, w])
        return y  # returned directly; the backward below is never invoked

    @staticmethod
    def backward(ctx, grad):
        x, w = ctx.saved_tensors
        # y_a = sum_b x[a,b] * w[a,b], so dy_a/dw[a,b] = x[a,b] and dy_a/dx[a,b] = w[a,b]
        w_grad = torch.einsum('a,ab->ab', [grad, x])
        x_grad = torch.einsum('a,ab->ab', [grad, w])
        return x_grad, w_grad


class TestFunction2(Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        y = torch.einsum('ab,ab->a', [x, w])
        return y.clone()  # cloning the result makes backward run as expected

    @staticmethod
    def backward(ctx, grad):
        x, w = ctx.saved_tensors
        w_grad = torch.einsum('a,ab->ab', [grad, x])
        x_grad = torch.einsum('a,ab->ab', [grad, w])
        return x_grad, w_grad


x = torch.ones(1, 1)
w = torch.ones(1, 1).requires_grad_()
y = TestFunction.apply(x, w)
z = y.sum()
z.backward()
print(y, z, w.grad)

x = torch.ones(1, 1)
w = torch.ones(1, 1).requires_grad_()
y = TestFunction2.apply(x, w)
z = y.sum()
z.backward()
print(y, z, w.grad)
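
As a side check, torch.autograd.gradcheck should confirm the discrepancy numerically (untested sketch; gradcheck needs double-precision inputs). On the affected build it should pass for TestFunction2 but raise for TestFunction, whose analytical gradient never reaches w:

from torch.autograd import gradcheck

x = torch.ones(1, 1, dtype=torch.double)
w = torch.ones(1, 1, dtype=torch.double).requires_grad_()

# Passes: the cloned output keeps the custom backward node in the graph.
print(gradcheck(TestFunction2.apply, (x, w)))

# Expected to fail here: the custom backward is skipped, so the analytical
# gradient w.r.t. w does not match the numerical one.
try:
    print(gradcheck(TestFunction.apply, (x, w)))
except RuntimeError as e:
    print('gradcheck failed:', e)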

Expected behavior

The first output should equal the second, but instead we get the following. Note that in the first case the output's grad_fn is AsStridedBackward rather than TestFunctionBackward, i.e. autograd has replaced the custom Function's backward node, and w.grad stays None:

tensor([1.], grad_fn=<AsStridedBackward>) tensor(1., grad_fn=<SumBackward0>) None
tensor([1.], grad_fn=<TestFunction2Backward>) tensor(1., grad_fn=<SumBackward0>) tensor([[1.]])

Environment

PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: Could not collect

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce RTX 2080 Ti
Nvidia driver version: 410.104
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.16.3
[pip] pytorch-ignite==0.2.0
[pip] pytorch-memlab==0.0.3
[pip] torch==1.1.0
[pip] torchvision==0.2.2
[conda] blas                      1.0                         mkl  
[conda] ignite                    0.2.0                    py37_0    pytorch
[conda] mkl                       2019.3                      199  
[conda] mkl_fft                   1.0.12           py37ha843d7b_0  
[conda] mkl_random                1.0.2            py37hd81dba3_0  
[conda] pytorch                   1.1.0           py3.7_cuda10.0.130_cudnn7.5.1_0    pytorch
[conda] pytorch-memlab            0.0.3                    pypi_0    pypi
[conda] torchvision               0.2.2                      py_3    pytorch

Additional context

When I attach a breakpoint in TestFunction.backward, it is never triggered, but a breakpoint in TestFunction2.backward is.
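
My guess at what is going on (not verified against the PyTorch source): on this build einsum seems to return a view of an internal result, which would explain the AsStridedBackward node, and when a custom Function returns a view, autograd appears to rebase it and drop the custom backward. A quick probe of that assumption, using the undocumented Tensor._base attribute (None unless the tensor is a view):

y = torch.einsum('ab,ab->a', [torch.ones(1, 1), torch.ones(1, 1)])
print(y._base is not None)          # True would mean einsum returned a view
print(y.clone()._base is not None)  # cloning yields a standalone tensor (False)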
