
implement dirichlet / beta GPU grad  #11030

@jramapuram

Issue description

torch.distributions.Beta and torch.distributions.Dirichlet don't support backward() through rsample() on the GPU: the call fails with a RuntimeError because _dirichlet_grad is not implemented for CUDA tensors. The CPU versions work fine.
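
The session below only exercises Dirichlet. A minimal sketch of the equivalent check for Beta follows; `beta_backward_status` is a hypothetical helper written for this report, not part of torch.distributions, and the concentrations are built the same way as in the Dirichlet example. On builds where the GPU gradient is missing, the "cuda" entry should hold the error message rather than "ok".

```python
import torch
from torch.distributions import Beta


def beta_backward_status(device):
    """Hypothetical helper: return "ok" if backward() through Beta.rsample()
    succeeds on `device`, otherwise the RuntimeError message."""
    a = torch.sigmoid(torch.randn(3, 4, requires_grad=True).to(device))
    b = torch.sigmoid(torch.randn(3, 4, requires_grad=True).to(device))
    sample = Beta(a, b).rsample()
    try:
        sample.mean().backward()
    except RuntimeError as err:
        return str(err)
    return "ok"


# Always check the CPU; check the GPU only when one is available.
devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
statuses = {name: beta_backward_status(torch.device(name)) for name in devices}
print(statuses)
```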

In [1]: import torch

In [2]: d = torch.distributions.Dirichlet(torch.sigmoid(torch.randn(3, 4, requires_grad=True))).rsample()

In [3]: torch.mean(d).backward()

In [4]: d = torch.distributions.Dirichlet(torch.sigmoid(torch.randn(3, 4, requires_grad=True).cuda())).rsample()

In [5]: torch.mean(d).backward()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-5-d517ed89bac2> in <module>()
----> 1 torch.mean(d).backward()

~/.venv3/lib/python3.6/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
     94                 products. Defaults to ``False``.
     95         """
---> 96         torch.autograd.backward(self, gradient, retain_graph, create_graph)
     97 
     98     def register_hook(self, hook):

~/.venv3/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     88     Variable._execution_engine.run_backward(
     89         tensors, grad_tensors, retain_graph, create_graph,
---> 90         allow_unreachable=True)  # allow_unreachable flag
     91 
     92 

~/.venv3/lib/python3.6/site-packages/torch/autograd/function.py in apply(self, *args)
     74 
     75     def apply(self, *args):
---> 76         return self._forward_cls.backward(self, *args)
     77 
     78 

~/.venv3/lib/python3.6/site-packages/torch/autograd/function.py in wrapper(ctx, *args)
    186     def wrapper(ctx, *args):
    187         with torch.no_grad():
--> 188             outputs = fn(ctx, *args)
    189 
    190         if not torch.is_grad_enabled():

~/.venv3/lib/python3.6/site-packages/torch/distributions/dirichlet.py in backward(ctx, grad_output)
     33     def backward(ctx, grad_output):
     34         x, concentration = ctx.saved_tensors
---> 35         return _Dirichlet_backward(x, concentration, grad_output)
     36 
     37 

~/.venv3/lib/python3.6/site-packages/torch/distributions/dirichlet.py in _Dirichlet_backward(x, concentration, grad_output)
     18 def _Dirichlet_backward(x, concentration, grad_output):
     19     total = concentration.sum(-1, True).expand_as(concentration)
---> 20     grad = torch._dirichlet_grad(x, concentration, total)
     21     return grad * (grad_output - (x * grad_output).sum(-1, True))
     22 

RuntimeError: _dirichlet_grad is not implemented for type torch.cuda.FloatTensor

Code example

import torch
d = torch.distributions.Dirichlet(torch.sigmoid(torch.randn(3, 4, requires_grad=True))).rsample()
torch.mean(d).backward() # works!
d = torch.distributions.Dirichlet(torch.sigmoid(torch.randn(3, 4, requires_grad=True).cuda())).rsample()
torch.mean(d).backward() # throws above exception
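
Until a GPU gradient is available, one possible workaround is to draw the sample on the CPU and move the result to the GPU, so that the backward pass of the sampling step runs on the CPU. The sketch below assumes the extra device transfers are acceptable; `rsample_dirichlet_via_cpu` is a hypothetical helper, not part of torch.distributions. It uses a squared-sum loss rather than torch.mean(d), because each Dirichlet sample sums to 1 along the last dimension, which makes the mean constant.

```python
import torch
from torch.distributions import Dirichlet


def rsample_dirichlet_via_cpu(concentration):
    """Hypothetical helper: draw a reparameterized Dirichlet sample on the CPU
    and return it on the device of `concentration`.

    The sampling step and its backward pass run on the CPU; the device
    transfers (.cpu() and .to()) are differentiable, so gradients still
    reach `concentration`.
    """
    sample = Dirichlet(concentration.cpu()).rsample()
    return sample.to(concentration.device)


# Fall back to the CPU when no GPU is present so the sketch still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logits = torch.randn(3, 4, requires_grad=True)     # leaf tensor on the CPU
concentration = torch.sigmoid(logits.to(device))   # positive values on `device`

d = rsample_dirichlet_via_cpu(concentration)
loss = (d ** 2).sum()
loss.backward()                                    # gradient lands in logits.grad
```

The cost is two host/device copies per sample, so this is only a stopgap for small concentration tensors.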

System Info

(base) ➜ /tmp python collect_env.py
Collecting environment information...
PyTorch version: 0.5.0a0+ddc37d7
Is debug build: No
CUDA used to build PyTorch: 9.2.148

OS: Manjaro Linux
GCC version: (GCC) 8.2.0
CMake version: version 3.11.1

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.2.148
GPU models and configuration: GPU 0: GeForce GTX 1060
Nvidia driver version: 396.54
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] Could not collect
[conda] magma-cuda92 2.3.0 1 pytorch
[conda] torch 0.5.0a0+ddc37d7
[conda] torchfile 0.1.0
[conda] torchvision 0.2.1

  • PyTorch or Caffe2: pytorch
  • How you installed PyTorch (conda, pip, source): source
  • OS: arch
  • Python version: 3.7

Labels

  • module: distributions (related to torch.distributions)
  • triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
