🐛 Bug
The issue has been resolved by a recently merged PR (#20563); this report is kept for the record and as a future benchmark. The issue is related to the local scope of a weak-scripted function, which causes a memory leak.
We hit an out-of-memory error when running the nn.MultiheadAttention module on CUDA. This started after we split the forward function of nn.MultiheadAttention and moved the major computation into torch.nn.functional.
To fix the issue in the merged PR, we had to remove the "weak_script" decorator from the multi_head_attention_forward() function.
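For the record, the shape of the fix was simply dropping the decorator. The sketch below is illustrative rather than the actual diff: the argument list is abbreviated with *args/**kwargs, and it assumes weak_script is importable from torch._jit_internal, as it was at the commit referenced below.

from torch._jit_internal import weak_script

# Before the fix: the weak_script decorator kept the function's local
# scope alive across calls, so intermediate CUDA tensors were never freed.
@weak_script
def multi_head_attention_forward(query, key, value, *args, **kwargs):
    ...  # attention computation elided

# After the fix (#20563): the plain function's locals are released
# normally as soon as each call returns.
def multi_head_attention_forward(query, key, value, *args, **kwargs):
    ...  # attention computation elided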
To Reproduce
Steps to reproduce the behavior:
- Make sure you are on commit "6e82b1c77d36386ba738af3287693105b4bbafe2"
- Run the following script on a GPU to reproduce the OOM error below.
import torch
import torch.nn as nn

d_model = 512
nhead = 16
bptt = 10
batch_size = 15
device = torch.device("cuda")

norm = nn.LayerNorm(d_model).to(device)
self_attn = nn.MultiheadAttention(d_model, nhead).to(device)
src_seq = torch.rand((bptt, batch_size, d_model)).to(device)

# Repeated forward passes; with the weak_script leak, allocated CUDA
# memory grows every iteration until the allocator runs out.
for _ in range(200000):
    src = norm(src_seq)
    output = self_attn(src, src, src)
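Since this report is also meant to serve as a future benchmark, the variant below makes the leak visible long before the hard OOM by logging allocator statistics. It is a sketch built on the script above; torch.cuda.memory_allocated() is a standard PyTorch API, and the logging interval of 1000 steps is an arbitrary choice.

import torch
import torch.nn as nn

d_model, nhead, bptt, batch_size = 512, 16, 10, 15
device = torch.device("cuda")

norm = nn.LayerNorm(d_model).to(device)
self_attn = nn.MultiheadAttention(d_model, nhead).to(device)
src_seq = torch.rand((bptt, batch_size, d_model)).to(device)

for step in range(200000):
    src = norm(src_seq)
    output = self_attn(src, src, src)
    if step % 1000 == 0:
        # On a leaking build this number climbs monotonically; on a
        # fixed build it plateaus after the first few iterations.
        mib = torch.cuda.memory_allocated() / 2**20
        print(f"step {step}: {mib:.1f} MiB allocated")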
Expected behavior
The loop should run with flat memory usage. Instead, on the commit above it fails with:
RuntimeError: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 15.90 GiB total capacity; 13.49 GiB already allocated; 1.56 MiB free; 1.87 GiB cached)
Environment
Collecting environment information...
PyTorch version: 1.1.0a0+1d33ab8
Is debug build: No
CUDA used to build PyTorch: 9.2.88
OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.12.2
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 9.2.88
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100
Nvidia driver version: 410.79
cuDNN version: Could not collect
Versions of relevant libraries:
[pip] numpy==1.15.4
[pip] numpydoc==0.8.0
[pip] torch==1.1.0a0+1d33ab8
[conda] blas          1.0              mkl
[conda] magma-cuda90  2.5.0            1                 pytorch
[conda] mkl           2019.1           144
[conda] mkl-include   2019.3           199
[conda] mkl-service   1.1.2            py37he904b0f_5
[conda] mkl_fft       1.0.6            py37hd81dba3_0
[conda] mkl_random    1.0.2            py37hd81dba3_0
[conda] torch         1.1.0a0+1d33ab8  dev_0