torch.pow() in a script module produces an error #19253

@nict-wisdom

🐛 Bug

The backward pass of torch.pow() in a traced module raises an error when the input is on a CUDA device.

RuntimeError:
Expected tensor to have CUDA Backend, but got tensor with CPU Backend (while checking arguments for CUDA_tensor_apply4) (checkBackend at ../aten/src/ATen/TensorUtils.cpp:202)

The backward is defined in torch/csrc/jit/symbolic_script.cpp.
I think the cause is the first argument of torch.where(): torch.tensor(exponent == 0.0) is always created on the CPU, regardless of the device of self.

def pow_0(self,
          exponent: float):
    def backward(grad_output):
        grad_self = torch.where(torch.tensor(exponent == 0.0), torch.zeros_like(self), grad_output * exponent * torch.pow(self, exponent - 1))
        return grad_self, None
    return torch.pow(self, exponent), backward
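
If that diagnosis is right, one possible workaround is to build the condition tensor on the same device as self. The following is only a sketch in eager-mode Python, not the change that was necessarily adopted upstream; the function name pow_backward_sketch is hypothetical, and it assumes a recent PyTorch where torch.where() broadcasts a 0-dim boolean condition.

```python
import torch

def pow_backward_sketch(self_t, exponent, grad_output):
    # Hypothetical helper: mirrors the backward closure above, but creates
    # the condition on self_t's device instead of the default (CPU).
    cond = torch.tensor(exponent == 0.0, device=self_t.device)
    return torch.where(
        cond,
        torch.zeros_like(self_t),
        grad_output * exponent * torch.pow(self_t, exponent - 1),
    )

# Runs on CUDA when available, otherwise falls back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.rand(4, 3, device=device) + 0.5  # strictly positive inputs
grad_out = torch.ones_like(x)
grad = pow_backward_sketch(x, 1.1, grad_out)
grad_zero_exp = pow_backward_sketch(x, 0.0, grad_out)
```

With this change every operand of torch.where() lives on the same device as self, so the backend check in the error message should no longer see a CPU tensor.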

To Reproduce

The following model is taken from
https://github.com/pytorch/examples/blob/master/mnist/main.py
with a single torch.pow() call inserted before the log_softmax.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(64, 4*4*50)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = torch.pow(x, 1.1)
        x = F.log_softmax(x, dim=1)
        return x

m = Net().cuda()
x = torch.randn(64, 1, 28, 28, requires_grad=True).cuda()
traced_net = torch.jit.trace(m, x)
traced_output = traced_net(x)
tgt = torch.randn(traced_output.size()).cuda()
traced_output.backward(tgt)

Environment

PyTorch version: 1.1.0a0+7e73783
Is debug build: No
CUDA used to build PyTorch: 9.2.88

OS: CentOS Linux release 7.5.1804 (Core)
GCC version: (GCC) 4.8.5 20150623 (Red Hat 4.8.5-28)
CMake version: version 3.12.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 9.2.88
GPU models and configuration:
GPU 0: Tesla P100-PCIE-16GB
GPU 1: Tesla P100-PCIE-16GB

Nvidia driver version: 396.26
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.15.4
[pip] numpydoc==0.8.0
[pip] torch==1.1.0a0+7e73783
[pip] torchvision==0.2.2.post3
[conda] blas 1.0 mkl
[conda] magma-cuda90 2.5.0 1 pytorch
[conda] magma-cuda92 2.5.0 1 pytorch
[conda] mkl 2019.1 144
[conda] mkl-include 2019.3 199
[conda] mkl-service 1.1.2 py37he904b0f_5
[conda] mkl_fft 1.0.6 py37hd81dba3_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] torch 1.1.0a0+7e73783 dev_0
[conda] torchvision 0.2.2.post3 pypi_0 pypi

Labels: oncall: jit
