Out-of-memory on GPU due to the "weak_script" decorators #20588

@zhangguanheng66

Description

🐛 Bug

The issue has been resolved by a recently merged PR (#20563); this report is kept for the record and as a future benchmark. The problem is related to the local scope of a weak-scripted function, which causes a memory leak.

We hit an out-of-memory error when running the nn.MultiheadAttention module on CUDA. This started after we split the forward function of the nn.MultiheadAttention module and moved the bulk of the computation into torch/nn/functional.py.

To fix the issue in the merged PR, we had to remove the "weak_script" decorator from the multi_head_attention_forward() function.

To Reproduce

Steps to reproduce the behavior:

  1. Make sure you are on commit "6e82b1c77d36386ba738af3287693105b4bbafe2"
  2. Use the following script on GPU to reproduce the OOM error message.

import torch
import torch.nn as nn

d_model = 512
nhead = 16
bptt = 10
batch_size = 15
device = torch.device("cuda")

norm = nn.LayerNorm(d_model).to(device)
self_attn = nn.MultiheadAttention(d_model, nhead).to(device)
src_seq = torch.rand((bptt, batch_size, d_model)).to(device)

for _ in range(200000):
    src = norm(src_seq)
    output = self_attn(src, src, src)

Expected behavior

RuntimeError: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 15.90 GiB total capacity; 13.49 GiB already allocated; 1.56 MiB free; 1.87 GiB cached)

Environment

Collecting environment information...
PyTorch version: 1.1.0a0+1d33ab8
Is debug build: No
CUDA used to build PyTorch: 9.2.88

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.12.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 9.2.88
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100

Nvidia driver version: 410.79
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.15.4
[pip] numpydoc==0.8.0
[pip] torch==1.1.0a0+1d33ab8
[conda] blas 1.0 mkl
[conda] magma-cuda90 2.5.0 1 pytorch
[conda] mkl 2019.1 144
[conda] mkl-include 2019.3 199
[conda] mkl-service 1.1.2 py37he904b0f_5
[conda] mkl_fft 1.0.6 py37hd81dba3_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] torch 1.1.0a0+1d33ab8 dev_0
