
torch.cdist raises CUDA error on backward with too large a batch #27209

@arthurdouillard

🐛 Bug

Hello,

torch.cdist raises a CUDA error during the backward pass if the batch is too large (even though it still fits in VRAM).

To Reproduce

I've done my tests with:

  • TITAN Xp (12 GB of VRAM)
  • CUDA 10.1
  • Driver version 418.74.

Steps to reproduce the behavior:

import torch

n = 102  # does not fail for n < 102

x = torch.randn(n, 1).to("cuda:0")
x.requires_grad = True

dist = torch.cdist(x, x, p=2)
dist.sum().backward()

Raises:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-20-3b0cf57d2b1a> in <module>()
----> 1 dist.sum().backward()

/home/douillard/.local/lib/python3.7/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    116                 products. Defaults to ``False``.
    117         """
--> 118         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    119 
    120     def register_hook(self, hook):

/home/douillard/.local/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     91     Variable._execution_engine.run_backward(
     92         tensors, grad_tensors, retain_graph, create_graph,
---> 93         allow_unreachable=True)  # allow_unreachable flag
     94 
     95 

RuntimeError: CUDA error: invalid configuration argument
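
For reference, the "does not fail for n < 102" comment in the snippet above can be checked with a short scan. This is my own sketch, not part of the original report, and it assumes the CUDA context remains usable after the failed kernel launch:

import torch

# Sketch (not from the report): find the smallest n for which the backward
# pass of torch.cdist fails on this setup. Assumes the launch error does not
# poison the CUDA context for later iterations.
for n in range(95, 110):
    x = torch.randn(n, 1, device="cuda:0", requires_grad=True)
    try:
        torch.cdist(x, x, p=2).sum().backward()
        print(f"n={n}: ok")
    except RuntimeError as e:
        print(f"n={n}: RuntimeError: {e}")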

Expected behavior

The backward pass should work as expected, without CUDA crashing.

Environment

PyTorch version: 1.2.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Debian GNU/Linux 10 (buster)
GCC version: (Debian 8.3.0-6) 8.3.0
CMake version: version 3.13.4

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 8.0.44
GPU models and configuration:
GPU 0: TITAN Xp
GPU 1: TITAN Xp
GPU 2: TITAN Xp
GPU 3: TITAN Xp

Nvidia driver version: 418.74
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy==1.17.2
[pip3] numpydoc==0.7.0
[pip3] torch==1.2.0
[pip3] torchvision==0.4.0
[conda] Could not collect

Additional context

The backward pass works, however, if I add an extra batch dimension:

dist = torch.cdist(x.unsqueeze(0), x.unsqueeze(0), p=2)
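
Here is a self-contained version of that workaround (my expansion of the one-liner above, reusing the same n = 102 as in the repro): the leading dimension of size 1 presumably sends cdist down its batched code path, and the backward pass then completes.

import torch

n = 102
x = torch.randn(n, 1).to("cuda:0")
x.requires_grad = True

# Same call as the failing repro, but with a leading batch dimension of size 1.
dist = torch.cdist(x.unsqueeze(0), x.unsqueeze(0), p=2)  # shape (1, n, n)
dist.sum().backward()                                     # completes without a CUDA error
print(x.grad.shape)  # torch.Size([102, 1])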

cc @ezyang @ssnl @albanD @zou3519 @gqchen

Labels: module: autograd, module: cuda, triaged