-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Description
When I use a custom collate_fn together with multiprocessing (num_workers > 0) and move the batch to CUDA inside the collate_fn, I get a RuntimeError: CUDA error: initialization error. Below is a minimal working example (MWE). Tested on Google Colab with PyTorch 1.0.1.post2.
import torch
import torch.utils.data
from collections import defaultdict
class Dataset(torch.utils.data.Dataset):
    """Toy map-style dataset in which every sample is the same integer sequence.

    Each sample is a dict with a single ``'sequence'`` key holding
    ``list(range(seq_len))``.

    Args:
        num_samples: Number of samples to generate (default 32).
        seq_len: Length of every generated sequence (default 10).
    """

    def __init__(self, num_samples=32, seq_len=10):
        # Stored in a plain list rather than a defaultdict: with a defaultdict
        # an out-of-range index silently inserts an empty sample (growing the
        # dataset) instead of raising IndexError, so plain iteration over the
        # dataset would never terminate.
        self.data = [
            {'sequence': list(range(seq_len))} for _ in range(num_samples)
        ]

    def __getitem__(self, idx):
        """Return the sample dict at ``idx``.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        return self.data[idx]

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)
def collate_fn(data, device=None):
    """Collate a list of sample dicts into a dict of batched ``long`` tensors.

    The keys of the first sample define the keys of the batch; for each key
    the per-sample values are stacked into one integer tensor and moved to
    ``device``.

    Args:
        data: Non-empty list of sample dicts sharing the same keys, whose
            values are equally sized lists of integers.
        device: Target device for the batched tensors. ``None`` (the default)
            means ``torch.device('cuda')``, which is the original behaviour.

    Returns:
        A ``defaultdict(list)`` mapping each key to a ``torch.long`` tensor
        on ``device``.

    Raises:
        IndexError: If ``data`` is empty.
        RuntimeError: If the transfer to ``device`` fails. With the default
            CUDA device and ``num_workers > 0`` this function runs inside a
            DataLoader worker process, which is where the reported
            "CUDA error: initialization error" is raised.
    """
    if device is None:
        device = torch.device('cuda')

    batch = defaultdict(list)
    for key in data[0]:
        values = [sample[key] for sample in data]
        # Build the integer tensor directly. torch.Tensor(values).long() goes
        # through float32 first, which rounds integers above 2**24.
        batch[key] = torch.tensor(values, dtype=torch.long).to(device)
    return batch
# Baseline: with the default collate_fn, loading through two worker
# processes runs to completion.
try:
    dataloader = torch.utils.data.DataLoader(
        Dataset(),
        batch_size=4,
        num_workers=2,
    )
    for epoch in range(2):
        # Batches are only drained; their contents are not inspected.
        for iteration, batch in enumerate(dataloader):
            continue
    print("Default collate_fn works!")
except Exception as err:
    print("Default collate_fn failed!")
    print(err)
# Same loop, but with the custom collate_fn, which moves every batch to CUDA
# inside the worker process. This is the case that fails.
try:
    dataloader = torch.utils.data.DataLoader(
        Dataset(), 4, num_workers=2, collate_fn=collate_fn
    )
    for epoch in range(2):
        for iteration, batch in enumerate(dataloader):
            print(epoch, iteration)
    print("Custom collate_fn works!")
except Exception as e:
    # Kept deliberately broad: the point of the example is to show whatever
    # error the worker processes surface.
    print("Custom collate_fn failed!")
    print(e)

# Output of the script (continues on the following lines):
# Default collate_fn works!
0 0
0 1
0 2
0 3
0 4
0 5
0 6
0 7
Custom collate_fn failed!
Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 138, in _worker_loop
samples = collate_fn([dataset[i] for i in batch_indices])
File "", line 29, in collate_fn
batch[key] = batch[key].to(torch.device('cuda'))
RuntimeError: CUDA error: initialization error
Edit: Obviously, a custom collate_fn is not necessary for this MWE. However, in my use case I need one because I have variable-length data, so I do the padding in the collate_fn. As the traceback shows, the failure occurs when the data is moved to CUDA inside the collate_fn.