
RNN CUDNN backend OOM issue #1737

@imisra

Hi,

I think I have stumbled upon something weird with the cuDNN backend for RNNs. I am using cuDNN v5 on CentOS 7.3.1.

torch.version.__version__ = e1d257bc6d472ee297df1719bf344bae359dbeaa

I have discussed this with @soumith as well.
A code snippet that reproduces the problem is below. With the cudnn backend enabled, memory use grows linearly with the number of forward passes (and eventually goes OOM). With the backend disabled, the behavior is as expected.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
torch.backends.cudnn.enabled = False  # set to True to reproduce the growth
import torch.cuda
import torch.nn as nn
from torch.autograd import Variable
import gc

print(torch.version.__version__)


def get_num_tensors():
    ctr = 0
    for obj in gc.get_objects():
        if torch.is_tensor(obj):
            ctr += 1
    return ctr


wordvec_dim = 300
hidden_dim = 256
rnn_num_layers = 1
batch_size = 10
vocab_size = 100
rnn_dropout = 0.5

model = nn.LSTM(wordvec_dim, hidden_dim, rnn_num_layers,
                dropout=rnn_dropout, batch_first=True)
model.cuda()
# set training mode
model.train()

encoded = Variable(torch.FloatTensor(batch_size, 1, wordvec_dim))
encoded = encoded.cuda()

h0 = Variable(torch.zeros(rnn_num_layers, batch_size, hidden_dim))
c0 = Variable(torch.zeros(rnn_num_layers, batch_size, hidden_dim))
h = h0.cuda()
c = c0.cuda()

print('Start:', get_num_tensors())
num_forward_passes = 10

for _i in range(num_forward_passes):
    output, (h, c) = model(encoded, (h, c))
    print(_i, get_num_tensors())

print('End:', get_num_tensors())

Output with cudnn enabled

e1d257bc6d472ee297df1719bf344bae359dbeaa
Start: 9
0 16
1 22
2 28
3 34
4 40
5 46
6 52
7 58
8 64
9 70
End: 70

Output without cudnn

e1d257bc6d472ee297df1719bf344bae359dbeaa
Start: 9
0 10
1 10
2 10
3 10
4 10
5 10
6 10
7 10
8 10
9 10
End: 10
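
To make the difference between the two runs explicit, here is a small sketch that computes the per-pass change in tensor count from the numbers reported above. The helper names `per_pass_growth` and `is_steady` are hypothetical and only for illustration; they are not part of the original script or of PyTorch. The lists are copied from the two outputs (passes 0 through 9).

```python
# Tensor counts copied from the two runs reported above (passes 0..9).
with_cudnn = [16, 22, 28, 34, 40, 46, 52, 58, 64, 70]
without_cudnn = [10] * 10


def per_pass_growth(counts):
    """Return the change in live tensor count between consecutive passes."""
    return [b - a for a, b in zip(counts, counts[1:])]


def is_steady(counts):
    """True when the count does not change from one pass to the next."""
    return all(delta == 0 for delta in per_pass_growth(counts))


print(per_pass_growth(with_cudnn))   # nine deltas, one per consecutive pair
print(is_steady(with_cudnn))
print(is_steady(without_cudnn))
```

With cudnn enabled, every pass after the first adds 6 tensors that the garbage collector still sees; with cudnn disabled, the count stays at 10.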
