
[multiprocessing] does not play well with distributions on GPU #16141

@fehiepsi

🐛 Bug

Continued from discussions in #14736 and pyro-ppl/pyro#1694.

Running the script under "To Reproduce" below fails with RuntimeError: Assertion `self->allocator() != nullptr' failed:

Traceback (most recent call last):
  File "testt.py", line 31, in <module>
    sample = q.get()
  File "/home/fehiepsi/miniconda3/envs/pydata/lib/python3.6/multiprocessing/queues.py", line 113, in get
    return _ForkingPickler.loads(res)
  File "/home/fehiepsi/miniconda3/envs/pydata/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 104, in rebuild_cuda_tensor
    t = torch._utils._rebuild_tensor(storage, tensor_offset, tensor_size, tensor_stride)
  File "/home/fehiepsi/miniconda3/envs/pydata/lib/python3.6/site-packages/torch/_utils.py", line 131, in _rebuild_tensor
    return tensor_class().set_(storage, storage_offset, size, stride)
RuntimeError: Assertion `self->allocator() != nullptr' failed.  at /opt/conda/conda-bld/pytorch-nightly_1547631207535/work/aten/src/THC/THCStorage.cpp:16

To Reproduce

Running the following script:

import torch
import torch.distributions as dist
import torch.multiprocessing as mp

torch.set_default_tensor_type(torch.cuda.FloatTensor)
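# From here on, newly created tensors (including the distributions'
# parameters below) default to CUDA.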
n = 10

def model():
    d1 = dist.Normal(torch.zeros(3), 1)
    v1 = d1.rsample()  # no error when `.rsample()` is replaced by `.sample()`
    d2 = dist.Normal(v1, 2)
    v2 = d2.rsample()
    return [(d1, v1), (d2, v2)]

def worker(q, e):
    for i in range(n):
        sample = [torch.zeros(1), torch.ones(1)]  # no error when this line is removed ???
        sample = model()
        q.put(sample)
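        # Handshake: block until the parent has consumed the sample; the
        # producer must keep shared CUDA tensors alive while the consumer
        # still uses them.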
        e.wait()
        e.clear()

if __name__ == "__main__":
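    # CUDA does not support the "fork" start method once initialized,
    # so a "spawn" context is used.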
    ctx = mp.get_context("spawn")
    q = ctx.Queue()
    e = ctx.Event()
    p = ctx.Process(target=worker, args=(q, e))
    p.start()
    for i in range(n):
        print("=== ITER {} ===".format(i))
        sample = q.get()
        print(sample)
        e.set()
    p.join()

Expected behavior

No error, as when running on CPU.

Environment

  • PyTorch Version (e.g., 1.0): 1.0
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source):
  • Python version: 3.6
  • CUDA/cuDNN version: 9
  • GPU models and configuration: GTX 1080

Additional context

The error does not happen with PyTorch version 1.0.dev.20181202.
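
As a temporary workaround until this is fixed, the failing CUDA rebuild on the consumer side can be avoided by shipping CPU copies of the sampled values through the queue and moving them back to the GPU in the parent. A minimal sketch, untested against this exact setup (`worker_cpu` is a hypothetical variant of `worker` above; the distribution objects, whose parameters are also CUDA tensors, are dropped for brevity):

def worker_cpu(q, e):
    for i in range(n):
        pairs = model()  # [(d1, v1), (d2, v2)]
        # Detach and copy each sampled value to host memory so that
        # unpickling in the parent never touches CUDA storage.
        q.put([v.detach().cpu() for _, v in pairs])
        e.wait()
        e.clear()

# Parent side: move the values back to the GPU after receiving them.
# values = [v.cuda() for v in q.get()]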

cc @neerajprad @ailzhang @ezyang
