[cuda] randn on a non-default stream doesn't work #19508

Description

@ssnl

🐛 Bug

In [1]: import torch
   ...:
   ...: s = torch.cuda.Stream()
   ...:
   ...: with torch.cuda.stream(s):
   ...:     x = torch.randn(30000, device='cuda')
   ...:
   ...: torch.cuda.synchronize()


THCudaCheck FAIL file=../torch/csrc/cuda/Module.cpp line=210 error=77 : an illegal memory access was encountered
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-1-0f0394c25b35> in <module>
      6     x = torch.randn(30000, device='cuda')
      7
----> 8 torch.cuda.synchronize()

/data/packages/pytorch/torch/cuda/__init__.py in synchronize()
    356     r"""Waits for all kernels in all streams on current device to complete."""
    357     _lazy_init()
--> 358     return torch._C._cuda_synchronize()
    359
    360

RuntimeError: cuda runtime error (77) : an illegal memory access was encountered at ../torch/csrc/cuda/Module.cpp:210
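Note that the error only surfaces at torch.cuda.synchronize() because CUDA kernel launches are asynchronous. A sketch for attributing the failure to the offending kernel with CUDA's blocking-launch mode, reusing the repro above:

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # must be set before the CUDA context is created

import torch

s = torch.cuda.Stream()
with torch.cuda.stream(s):
    # with blocking launches, the illegal access should be reported here
    # rather than at the later synchronize() call
    x = torch.randn(30000, device='cuda')
torch.cuda.synchronize()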

On the other hand, non-sampling factory methods work:

In [1]: import torch
   ...:
   ...: s = torch.cuda.Stream()
   ...:
   ...: with torch.cuda.stream(s):
   ...:     x = torch.empty(30000, device='cuda')
   ...:
   ...: torch.cuda.synchronize()

In [2]: x
Out[2]: tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')
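For reference, checking which factory functions are affected requires a fresh process per candidate, since the illegal memory access poisons the CUDA context for the rest of the process. A quick sketch along these lines (the op list is an illustrative guess):

import subprocess
import sys

SNIPPET = """\
import torch
s = torch.cuda.Stream()
with torch.cuda.stream(s):
    x = torch.{op}(30000, device='cuda')
torch.cuda.synchronize()
"""

# illustrative candidates: randn/rand sample from the RNG, empty/zeros do not
for op in ['randn', 'rand', 'empty', 'zeros']:
    result = subprocess.run([sys.executable, '-c', SNIPPET.format(op=op)],
                            capture_output=True)
    print(op, 'ok' if result.returncode == 0 else 'FAILED')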

Upon further investigation, this seems to be a problem with random state initialization, because everything works if I initialize the RNG state on the default stream first:

In [1]: import torch
   ...:
   ...: s = torch.cuda.Stream()
   ...:
   ...: torch.randn(1, device='cuda')
   ...:
   ...: with torch.cuda.stream(s):
   ...:     x = torch.randn(30000, device='cuda')
   ...:
   ...: torch.cuda.synchronize()

In [2]: x
Out[2]:
tensor([ 0.8381, -0.3533, -0.9305,  ..., -1.0518,  1.4129,  0.4790],
       device='cuda:0')
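So a workaround for now seems to be forcing the lazy RNG initialization on the default stream before doing any sampling on a side stream. A minimal sketch (the helper name is mine):

import torch

def warm_up_cuda_rng(device='cuda'):
    # any sampling op on the default stream initializes the CUDA RNG state;
    # afterwards, sampling on non-default streams appears to work
    torch.randn(1, device=device)

warm_up_cuda_rng()

s = torch.cuda.Stream()
with torch.cuda.stream(s):
    x = torch.randn(30000, device='cuda')
torch.cuda.synchronize()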

Environment

PyTorch version: 1.1.0a0+746065b
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 18.04.2 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.12.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration:
GPU 0: TITAN Xp
GPU 1: TITAN Xp

Nvidia driver version: 415.27
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy==1.16.2
[pip3] numpydoc==0.8.0
[pip3] pytorch-sphinx-theme==0.0.24
[pip3] torch==1.1.0a0+746065b
[pip3] torchfile==0.1.0
[pip3] torchreparam==0.0.1
[pip3] torchvision==0.2.1
[conda] blas                      1.0                         mkl    defaults
[conda] magma-cuda100             2.5.0                         1    pytorch
[conda] mkl                       2019.1                      144    defaults
[conda] mkl-include               2019.1                      144    defaults
[conda] mkl-service               1.1.2            py37he904b0f_5    defaults
[conda] mkl_fft                   1.0.10           py37ha843d7b_0    defaults
[conda] mkl_random                1.0.2            py37hd81dba3_0    defaults
[conda] pytorch-sphinx-theme      0.0.24                   pypi_0    pypi
[conda] torch                     1.1.0a0+746065b           dev_0    <develop>
[conda] torchfile                 0.1.0                    pypi_0    pypi
[conda] torchreparam              0.0.1                     dev_0    <develop>

Labels

module: cuda (Related to torch.cuda, and CUDA support in general)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
