Issue description
torch.distributions.Beta and torch.distributions.Dirichlet don't allow backward() calls on the GPU; the CPU versions work fine, though. The transcript below shows the Dirichlet case; a hedged sketch of the Beta case follows the traceback.
In [1]: import torch
In [2]: d = torch.distributions.Dirichlet(torch.sigmoid(torch.randn(3, 4, requires_grad=True))).rsample()
In [3]: torch.mean(d).backward()
In [4]: d = torch.distributions.Dirichlet(torch.sigmoid(torch.randn(3, 4, requires_grad=True).cuda())).rsample()
In [5]: torch.mean(d).backward()
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-5-d517ed89bac2> in <module>()
----> 1 torch.mean(d).backward()
~/.venv3/lib/python3.6/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
94 products. Defaults to ``False``.
95 """
---> 96 torch.autograd.backward(self, gradient, retain_graph, create_graph)
97
98 def register_hook(self, hook):
~/.venv3/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
88 Variable._execution_engine.run_backward(
89 tensors, grad_tensors, retain_graph, create_graph,
---> 90 allow_unreachable=True) # allow_unreachable flag
91
92
~/.venv3/lib/python3.6/site-packages/torch/autograd/function.py in apply(self, *args)
74
75 def apply(self, *args):
---> 76 return self._forward_cls.backward(self, *args)
77
78
~/.venv3/lib/python3.6/site-packages/torch/autograd/function.py in wrapper(ctx, *args)
186 def wrapper(ctx, *args):
187 with torch.no_grad():
--> 188 outputs = fn(ctx, *args)
189
190 if not torch.is_grad_enabled():
~/.venv3/lib/python3.6/site-packages/torch/distributions/dirichlet.py in backward(ctx, grad_output)
33 def backward(ctx, grad_output):
34 x, concentration = ctx.saved_tensors
---> 35 return _Dirichlet_backward(x, concentration, grad_output)
36
37
~/.venv3/lib/python3.6/site-packages/torch/distributions/dirichlet.py in _Dirichlet_backward(x, concentration, grad_output)
18 def _Dirichlet_backward(x, concentration, grad_output):
19 total = concentration.sum(-1, True).expand_as(concentration)
---> 20 grad = torch._dirichlet_grad(x, concentration, total)
21 return grad * (grad_output - (x * grad_output).sum(-1, True))
22
RuntimeError: _dirichlet_grad is not implemented for type torch.cuda.FloatTensor
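For completeness, torch.distributions.Beta fails the same way. That is expected if Beta.rsample() is implemented in terms of a Dirichlet draw, so the sketch below is hedged on that assumption rather than being a separately captured trace:

import torch

# Hedged sketch: Beta appears to reuse the Dirichlet backward path,
# so it should raise the same missing-kernel error on CUDA.
c1 = torch.tensor(0.5, requires_grad=True).cuda()
c0 = torch.tensor(0.5, requires_grad=True).cuda()
b = torch.distributions.Beta(c1, c0).rsample()
b.backward()  # RuntimeError: _dirichlet_grad is not implemented for type torch.cuda.FloatTensor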
Code example
import torch
d = torch.distributions.Dirichlet(torch.sigmoid(torch.randn(3, 4, requires_grad=True))).rsample()
torch.mean(d).backward() # works!
d = torch.distributions.Dirichlet(torch.sigmoid(torch.randn(3, 4, requires_grad=True).cuda())).rsample()
torch.mean(d).backward() # throws the exception above
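Until a CUDA kernel for _dirichlet_grad lands, a minimal workaround sketch (assuming it is acceptable to route the gradient through a CPU-side rsample()) is to sample on the CPU and move the result to the GPU afterwards:

import torch

# Workaround sketch: sample on the CPU, where _dirichlet_grad is
# implemented, then move the sample to the GPU. The .cuda() copy is
# differentiable, so backward() still reaches the CPU-resident parameters.
logits = torch.randn(3, 4, requires_grad=True)  # CPU leaf
d = torch.distributions.Dirichlet(torch.sigmoid(logits)).rsample().cuda()
torch.mean(d).backward()  # works; logits.grad is populated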
System Info
(base) ➜ /tmp python collect_env.py
Collecting environment information...
PyTorch version: 0.5.0a0+ddc37d7
Is debug build: No
CUDA used to build PyTorch: 9.2.148
OS: Manjaro Linux
GCC version: (GCC) 8.2.0
CMake version: version 3.11.1
Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.2.148
GPU models and configuration: GPU 0: GeForce GTX 1060
Nvidia driver version: 396.54
cuDNN version: Could not collect
Versions of relevant libraries:
[pip] Could not collect
[conda] magma-cuda92 2.3.0 1 pytorch
[conda] torch 0.5.0a0+ddc37d7
[conda] torchfile 0.1.0
[conda] torchvision 0.2.1
- PyTorch or Caffe2: pytorch
- How you installed PyTorch (conda, pip, source): source
- OS: arch
- Python version: 3.7