
Full-range random_() generation broken for cuda.IntTensor, cuda.LongTensor and LongTensor. #16944

@obilaniu


🐛 Bug

When generating random integers over the full range of IntTensor and LongTensor (-2^31 to 2^31-1 and -2^63 to 2^63-1, respectively), only IntTensor on the CPU works as expected; the other three combinations either crash or do not produce actual random numbers.

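This seems to be at least partly due to integer overflow: the number of values in an integer type's full range, max - min + 1, cannot by definition be represented in that same type (for int32 it is 2^32, for int64 it is 2^64). For LongTensor, even the exclusive upper bound 2^63 passed to random_() does not fit in a signed 64-bit integer, which is consistent with the "Overflow when unpacking long" error below.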
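The overflow argument can be made concrete with a small pure-Python sketch. The helpers `wrap_signed` and `fits_signed` are hypothetical (not part of PyTorch); they simulate same-width two's-complement arithmetic to show that the span max - min + 1 wraps to 0 in both int32 and int64, and that the exclusive upper bounds 2^31 and 2^63 are each one past the dtype maximum.

```python
def wrap_signed(value: int, bits: int) -> int:
    """Hypothetical helper: reduce a Python int to a two's-complement signed `bits`-wide value."""
    mask = (1 << bits) - 1
    value &= mask                      # keep only the low `bits` bits
    if value >= 1 << (bits - 1):       # top bit set -> negative in two's complement
        value -= 1 << bits
    return value


def fits_signed(value: int, bits: int) -> bool:
    """Hypothetical helper: True if `value` is representable as a signed `bits`-wide integer."""
    return -(1 << (bits - 1)) <= value < (1 << (bits - 1))


for bits in (32, 64):
    lo = -(1 << (bits - 1))            # dtype minimum, e.g. -2**31
    hi_incl = (1 << (bits - 1)) - 1    # dtype maximum, e.g. 2**31 - 1
    true_span = hi_incl - lo + 1       # exact number of values: 2**bits
    wrapped_span = wrap_signed(true_span, bits)  # what a same-width computation yields
    print(f"int{bits}: span={true_span}, wrapped={wrapped_span}, "
          f"exclusive bound fits: {fits_signed(hi_incl + 1, bits)}")
```

For both widths the wrapped span is 0, so any code that divides by, or takes a modulus with, a same-width span has no meaningful range left to sample from.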

To Reproduce

```python
>>> import torch

# NOT WORKING: CUDA, LongTensor
>>> torch.cuda.LongTensor(10).random_(-2**63, 2**63)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Overflow when unpacking long

# NOT WORKING: CUDA, IntTensor. All random numbers identical.
>>> torch.cuda.IntTensor(10).random_(-2**31, 2**31)
tensor([2147483647, 2147483647, 2147483647, 2147483647, 2147483647, 2147483647,
        2147483647, 2147483647, 2147483647, 2147483647], device='cuda:0',
       dtype=torch.int32)

# NOT WORKING: CPU, LongTensor
>>> torch.LongTensor(10).random_(-2**63, 2**63)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Overflow when unpacking long

# WORKING: CPU, IntTensor
>>> torch.IntTensor(10).random_(-2**31, 2**31)
tensor([ -388064048,  -707636165,   558090412, -1659029973,  2046043448,
         1383561155,  -896567077,  1357117392,   297908810,   601099022],
       dtype=torch.int32)
```
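A quick, torch-independent way to contrast the degenerate CUDA IntTensor result with the working CPU result is to run a few sanity checks on the printed values. The `summarize` helper below is hypothetical and only checks range, distinctness and sign coverage; it is not a statistical test of uniformity.

```python
def summarize(values, bits):
    """Hypothetical helper: basic checks on samples meant to cover a signed `bits`-wide range."""
    lo, hi = -(1 << (bits - 1)), (1 << (bits - 1)) - 1
    return {
        "in_range": all(lo <= v <= hi for v in values),
        "distinct": len(set(values)),             # 1 means every sample is identical
        "has_negative": any(v < 0 for v in values),
        "has_positive": any(v > 0 for v in values),
    }


# Values copied from the CUDA and CPU IntTensor outputs in the report.
cuda_int32 = [2147483647] * 10
cpu_int32 = [-388064048, -707636165, 558090412, -1659029973, 2046043448,
             1383561155, -896567077, 1357117392, 297908810, 601099022]

print(summarize(cuda_int32, 32))
print(summarize(cpu_int32, 32))
```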

Expected behavior

Each call should fill the tensor with random numbers drawn from the full range of its dtype, on both CPU and CUDA.
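One way to handle a full-range request without overflow is shown in the minimal pure-Python sketch below. It is an illustration under stated assumptions, not PyTorch's implementation: the helper name `sample_inclusive` is hypothetical, the bounds are taken as an inclusive range [lo, hi] (so both fit in the signed dtype), and the random input is an unsigned word of the same width, here supplied by Python's `random.getrandbits`. Only the count of values, hi - lo + 1, can overflow, and that happens exactly when the full range is requested, so the sketch detects that case and uses every bit of the word directly.

```python
import random


def sample_inclusive(word: int, lo: int, hi: int, bits: int) -> int:
    """Hypothetical helper: map an unsigned `bits`-wide random `word` onto [lo, hi].

    lo and hi are inclusive and must both fit in a signed `bits`-wide integer.
    hi - lo then always fits in an unsigned `bits`-wide integer; only
    hi - lo + 1 can overflow, and only for the full range.
    """
    modulus = 1 << bits
    smin, smax = -(modulus >> 1), (modulus >> 1) - 1
    if not (smin <= lo <= hi <= smax):
        raise ValueError("bounds must satisfy dtype_min <= lo <= hi <= dtype_max")
    if not 0 <= word < modulus:
        raise ValueError("word must be an unsigned bits-wide integer")

    span_minus_one = hi - lo              # fits in an unsigned bits-wide integer
    if span_minus_one == modulus - 1:
        offset = word                     # full range: every word is a valid offset
    else:
        # Plain modulo is slightly biased unless the span divides 2**bits;
        # a real generator would use rejection sampling instead.
        offset = word % (span_minus_one + 1)

    unsigned = (lo + offset) % modulus    # wrap-around add, like unsigned C arithmetic
    return unsigned - modulus if unsigned > smax else unsigned  # reinterpret as signed


rng = random.Random(0)
for bits in (32, 64):
    lo, hi = -(1 << (bits - 1)), (1 << (bits - 1)) - 1   # full range of the dtype
    samples = [sample_inclusive(rng.getrandbits(bits), lo, hi, bits) for _ in range(5)]
    print(f"int{bits}:", samples)
```

Taking an inclusive upper bound is what keeps both bounds representable; with an exclusive bound, the full int64 range would need 2^63 as the upper limit, which no int64 can hold.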

Environment

PyTorch version: 1.1.0a0+c865d46
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: openSUSE Leap 15.0
GCC version: (SUSE Linux) 7.3.1 20180323 [gcc-7-branch revision 258812]
CMake version: version 3.10.2

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce GTX 765M
Nvidia driver version: 418.30
cuDNN version: 7.4.1

Versions of relevant libraries:
[pip3] numpy (1.14.0)
[pip3] numpydoc (0.7.0)
[pip3] torch (1.1.0a0+c865d46)
[pip3] torchvision (0.2.0)
[conda] mkl 2018.0.1 h19d6760_4
[conda] mkl-service 1.1.2 py36h17a0993_4

Labels: high priority, module: 64-bit, module: error checking, module: random, triaged