torch.bernoulli inconsistent gpu/cpu results #10357

@davidmascharka

Issue description

`sample` in the distributions module behaves inconsistently between CPU and CUDA tensors. On the CPU, there appears to be an argument validity check that rejects probabilities outside [0, 1]; on the GPU, this check does not appear to be implemented, so out-of-range probabilities are silently accepted. This issue is somewhat related.

Code example

>>> torch.distributions.Bernoulli(torch.FloatTensor([1.5])).sample()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-150-d17a827b2cc3> in <module>()
----> 1 torch.distributions.Bernoulli(torch.FloatTensor([1.5])).sample()

python3.6/site-packages/torch/distributions/bernoulli.py in sample(self, sample_shape)
     72         shape = self._extended_shape(sample_shape)
     73         with torch.no_grad():
---> 74             return torch.bernoulli(self.probs.expand(shape))
     75 
     76     def log_prob(self, value):

RuntimeError: invalid argument 1: must be >= 0 and <= 1 at /pytorch/aten/src/TH/THRandom.cpp:314

>>> torch.distributions.Bernoulli(torch.cuda.FloatTensor([1.5])).sample()
tensor([1.], device='cuda:0')

>>> torch.distributions.Bernoulli(torch.cuda.FloatTensor([-2])).sample()
tensor([0.], device='cuda:0')
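
Until both backends agree, one possible stopgap is to validate the probabilities before sampling. The sketch below is only an illustration: `checked_bernoulli` is a hypothetical helper name, not part of PyTorch, and the error message merely mirrors the CPU one shown above. It assumes that raising on out-of-range input is the desired behavior on every device.

```python
import torch


def checked_bernoulli(probs: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: range-check probs on any device, then sample."""
    # NaN fails both comparisons, so NaN entries are rejected as well.
    in_range = (probs >= 0) & (probs <= 1)
    # bool(...) copies the result to the host; for CUDA tensors this
    # forces a device synchronization.
    if not bool(in_range.all()):
        raise ValueError("probs must be >= 0 and <= 1")
    return torch.bernoulli(probs)
```

Because the check runs on the tensor itself, the same code path applies to CPU and CUDA inputs, at the cost of one extra reduction and a host synchronization per call.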

System Info

PyTorch version: 0.5.0a0+a24163a
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 16.04.4 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
CMake version: version 3.9.1

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: 
GPU 0: GeForce GTX 1080 Ti
GPU 1: GeForce GTX 1080 Ti
GPU 2: GeForce GTX 1080 Ti
GPU 3: GeForce GTX 1080 Ti

Nvidia driver version: 390.67
cuDNN version: 7

Versions of relevant libraries:
[pip] numpy (1.13.3)
[pip] numpydoc (0.7.0)
[pip] torch (0.5.0a0+a24163a)
[pip] torchvision (0.1.9)
