Error on cuda.LongTensor which is sent via multiprocessing.Queue #11422

@jianda-chen

Description

I was trying to send multiple tensors in a tuple via multiprocessing.Queue to another process. However, when I access a cuda LongTensor on the receiving side, I get this error:

RuntimeError: Attempt to access Storage having data type Float as data type Long

Code example

import torch
import torch.multiprocessing as mp
import time

def producer(queue):
    while True:
        a = torch.ones(2,2).float().cuda()
        idx = torch.LongTensor([[0, 0], [0, 1]]).cuda()
        queue.put((a, idx))

def consumer(queue):
    while True:
        a, idx = queue.get()
        print(idx.type())
        print(idx)

if __name__ == '__main__':
    mp.set_start_method('spawn')

    queue = mp.Queue()

    p = mp.Process(target=producer, args=(queue,))
    c = mp.Process(target=consumer, args=(queue,))
    p.start()
    c.start()
    
    time.sleep(10)

    p.join()
    c.join()

Output:

torch.cuda.LongTensor
Process Process-2:
Traceback (most recent call last):
  File "***/lib/python3.5/multiprocessing/process.py", line 252, in _bootstrap
    self.run()
  File "***/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "queue_test.py", line 15, in consumer
    print(idx)
  File "***/lib/python3.5/site-packages/torch/tensor.py", line 57, in __repr__
    return torch._tensor_str._str(self)
  File "***/lib/python3.5/site-packages/torch/_tensor_str.py", line 256, in _str
    formatter = _Formatter(get_summarized_data(self) if summarize else self)
  File "***/lib/python3.5/site-packages/torch/_tensor_str.py", line 76, in __init__
    copy = torch.empty(tensor.size(), dtype=torch.long).copy_(tensor).view(tensor.nelement())
RuntimeError: Attempt to access Storage having data type Float as data type Long
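A possible workaround, not a confirmed fix from the maintainers: CUDA tensors placed on a torch.multiprocessing.Queue are shared with the other process via CUDA IPC, whereas CPU tensors go through ordinary pickling. Sending CPU copies and moving them back to the GPU in the consumer sidesteps the error in this sketch; the function names mirror the repro script above, and the behavior on other setups is an assumption.

```python
import torch

def producer(queue):
    # Workaround sketch: keep the tensors on CPU before queueing, so the
    # transfer uses ordinary pickling rather than CUDA IPC storage sharing.
    a = torch.ones(2, 2).float()              # CPU float tensor
    idx = torch.LongTensor([[0, 0], [0, 1]])  # CPU long tensor
    queue.put((a, idx))

def consumer(queue):
    a, idx = queue.get()
    # Move back to the GPU on the receiving side if needed, e.g.:
    # a, idx = a.cuda(), idx.cuda()
    return a, idx
```

These drop in for the producer/consumer bodies in the repro script (with the `while True` loops and `mp.Process` wiring unchanged); the cost is an extra host-to-device copy per item in the consumer.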

System Info

PyTorch version: 0.4.1
Is debug build: No
CUDA used to build PyTorch: 8.0.61

OS: Ubuntu 14.04.5 LTS
GCC version: (Ubuntu 4.8.4-2ubuntu1~14.04.4) 4.8.4
CMake version: version 3.11.0

Python version: 3.5
Is CUDA available: Yes
CUDA runtime version: 8.0.61
GPU models and configuration:
GPU 0: Tesla P100-PCIE-16GB
GPU 1: Tesla P100-PCIE-16GB

Nvidia driver version: 375.66
cuDNN version: Probably one of the following:
/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudnn.so
/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudnn.so.5.1.10
/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudnn.so.6
/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudnn.so.6.0.21
/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudnn_static.a

Versions of relevant libraries:
[pip] numpy (1.13.3)
[pip] torch (0.4.1)
[pip] torchvision (0.2.1)
[conda] torch 0.4.1
[conda] torchvision 0.2.1
