Cannot change device of JIT module after initial loading using C++ frontend #19039

@abonnet

Description

🐛 Bug

I exported a JIT model on GPU, then tried to load and run it on CPU using the C++ frontend.

To Reproduce

Steps to reproduce the behavior:

  1. Create a demo model and export it on GPU
import torch

class BatchNormModel(torch.jit.ScriptModule):
    def __init__(self):
        super(BatchNormModel, self).__init__(optimize=False)
        # BatchNorm1d keeps running_mean/running_var as buffers, not parameters
        self.bn = torch.nn.BatchNorm1d(10)

    @torch.jit.script_method
    def forward(self, x):
        return self.bn(x)


model = BatchNormModel()
model.eval()
model.cuda()  # parameters and buffers now live on the GPU
model.save("/tmp/bn_gpu.pt")

  2. Try to load the model and execute it on CPU
#include <torch/script.h>

#include <memory>

int main() {
    std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("/tmp/bn_gpu.pt");
    module->to(torch::kCPU);  // should move every tensor in the module to CPU

    std::vector<torch::jit::IValue> inputs;
    inputs.emplace_back(torch::randn({6, 10}));  // CPU input
    auto output = module->forward(inputs).toTensor();  // throws the error below
}

Error

terminate called after throwing an instance of 'std::runtime_error'
  what():  Expected tensor to have CPU Backend, but got tensor with CUDA Backend (while checking arguments for batch_norm_cpu) (checkBackend at ../aten/src/ATen/TensorUtils.cpp:202)

Expected behavior

Expected the ->to() function to recursively move all of the module's tensors to CPU, i.e. parameters as well as buffers such as BatchNorm's running statistics (which is what batch_norm_cpu is complaining about).
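
For reference, a minimal sketch of how one could confirm which tensors were left behind. It uses named_parameters()/named_buffers() from the newer value-based torch::jit::Module API (the shared_ptr<script::Module> API above predates it), so treat it as an approximation rather than code for this exact build:

#include <torch/script.h>

#include <iostream>

int main() {
    torch::jit::Module module = torch::jit::load("/tmp/bn_gpu.pt");
    module.to(torch::kCPU);

    // With this bug, parameters move to CPU but the BatchNorm buffers
    // (running_mean, running_var) would still report cuda:0.
    for (const auto& p : module.named_parameters()) {
        std::cout << "param  " << p.name << " -> " << p.value.device() << std::endl;
    }
    for (const auto& b : module.named_buffers()) {
        std::cout << "buffer " << b.name << " -> " << b.value.device() << std::endl;
    }
}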

Environment

PyTorch version: 1.1.0a0+9696f06
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 14.04.6 LTS
GCC version: (Ubuntu 4.8.4-2ubuntu1~14.04.4) 4.8.4
CMake version: version 3.11.1

Python version: 2.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: TITAN V
GPU 1: TITAN Xp
GPU 2: GeForce RTX 2080 Ti

Nvidia driver version: 410.78
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.5.1.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.6.0.21

Versions of relevant libraries:
[pip] numpy==1.13.3
[pip] pyarrow==0.10.0+numpy1.13.1
[conda] Could not collect

Additional context

I have seen similar behavior when trying to cast the weights to half precision (FP16). If I pass a torch::Device object to torch::jit::load, the model is correctly loaded on CPU, so the issue seems to be related to the ->to() function.
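
For completeness, a minimal sketch of that workaround, relying on the optional device argument of torch::jit::load (the C++ counterpart of map_location in Python):

#include <torch/script.h>

#include <memory>

int main() {
    // Remap all tensors to CPU at deserialization time instead of
    // calling ->to() after loading.
    std::shared_ptr<torch::jit::script::Module> module =
        torch::jit::load("/tmp/bn_gpu.pt", torch::kCPU);

    std::vector<torch::jit::IValue> inputs;
    inputs.emplace_back(torch::randn({6, 10}));
    auto output = module->forward(inputs).toTensor();  // runs on CPU
}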
