Description
🐛 Bug
I exported a JIT model on GPU, then tried to load it and run it on CPU using the C++ frontend. The forward pass fails because some tensors are still on the CUDA device.
To Reproduce
Steps to reproduce the behavior:
- Create a demo model and export it on GPU
import torch

class BatchNormModel(torch.jit.ScriptModule):
    def __init__(self):
        super(BatchNormModel, self).__init__(optimize=False)
        self.bn = torch.nn.BatchNorm1d(10)

    @torch.jit.script_method
    def forward(self, x):
        return self.bn(x)

model = BatchNormModel()
model.eval()
model.cuda()
model.save("/tmp/bn_gpu.pt")
- Try to load the model and execute on CPU
#include <torch/script.h>
#include <memory>

int main() {
  std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("/tmp/bn_gpu.pt");
  module->to(torch::kCPU);
  std::vector<torch::jit::IValue> inputs;
  inputs.emplace_back(torch::randn({6, 10}));
  auto output = module->forward(inputs).toTensor();
}
Error
Terminate std::runtime_error( Expected tensor to have CPU Backend, but got tensor with CUDA Backend (while checking arguments for batch_norm_cpu) (checkBackend at ../aten/src/ATen/TensorUtils.cpp:202)
Expected behavior
The ->to function should recursively move all of the module's parameters to the CPU, so that the forward pass runs on CPU without a backend mismatch.
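One way to check this expectation is to list the device of every tensor the module owns after the move. The sketch below does that in Python on an eager module; the helper name `devices_by_tensor` is hypothetical, and looking at buffers as well as parameters is an assumption on my part (BatchNorm keeps its running statistics as buffers, but the report does not confirm which tensor stayed on CUDA).

```python
import torch


def devices_by_tensor(module):
    """Map each parameter and buffer name to the device type it lives on.

    Hypothetical helper for illustration; not part of PyTorch.
    """
    devices = {}
    for name, param in module.named_parameters():
        devices["param:" + name] = param.device.type
    for name, buf in module.named_buffers():
        devices["buffer:" + name] = buf.device.type
    return devices


# A small stand-in for the model in the report (assumption: eager mode,
# wrapped in Sequential so the recursion into submodules is exercised).
model = torch.nn.Sequential(torch.nn.BatchNorm1d(10)).eval()
model.to("cpu")

report = devices_by_tensor(model)
for name, device in sorted(report.items()):
    print(name, device)
```

If any entry reports a device other than the requested one after calling `to`, that tensor is a likely source of the backend mismatch.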
Environment
PyTorch version: 1.1.0a0+9696f06
Is debug build: No
CUDA used to build PyTorch: 10.0.130
OS: Ubuntu 14.04.6 LTS
GCC version: (Ubuntu 4.8.4-2ubuntu1~14.04.4) 4.8.4
CMake version: version 3.11.1
Python version: 2.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: TITAN V
GPU 1: TITAN Xp
GPU 2: GeForce RTX 2080 Ti
Nvidia driver version: 410.78
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.5.1.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.6.0.21
Versions of relevant libraries:
[pip] numpy==1.13.3
[pip] pyarrow==0.10.0+numpy1.13.1
[conda] Could not collect
Additional context
I have seen similar behavior when trying to cast the weights to half precision (FP16). If I pass a torch::Device object to torch::jit::load, the model is correctly loaded on CPU, so the issue seems to be related to the ->to function.
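The load-time workaround can be sketched from Python as well, where `torch.jit.load` takes a `map_location` argument that plays the role of the device passed to `torch::jit::load`. This is a minimal sketch under assumptions: it uses the newer `torch.jit.script` API rather than the `ScriptModule` subclass above, saves to an in-memory buffer instead of `/tmp/bn_gpu.pt`, and falls back to exporting from CPU when no GPU is present so that it still runs.

```python
import io

import torch


class BatchNormModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(10)

    def forward(self, x):
        return self.bn(x)


model = BatchNormModel().eval()
# Assumption: export from the GPU when one is available, as in the report;
# otherwise export from CPU so the sketch remains runnable.
if torch.cuda.is_available():
    model = model.cuda()

scripted = torch.jit.script(model)

# Save to an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)

# Remap every stored tensor to the CPU at load time.
loaded = torch.jit.load(buffer, map_location=torch.device("cpu"))

output = loaded(torch.randn(6, 10))
print(output.shape, output.device)
```

Remapping at load time avoids relying on a later `to` call, which is the step the report suspects.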