Custom collate_fn fails on second epoch with multiprocessing and cuda #17359

@timbmg

Description

When I use a custom collate_fn together with multiprocessing (num_workers > 0) and move the batch to CUDA inside the collate_fn, I get a RuntimeError: CUDA error: initialization error. Below is an MWE. Tested on Google Colab with PyTorch 1.0.1.post2.

import torch
import torch.utils.data
from collections import defaultdict

class Dataset(torch.utils.data.Dataset):
  
  def __init__(self):
    self.data = defaultdict(dict)
    for i in range(32):
      self.data[i]['sequence'] = list(range(10))
    
  def __getitem__(self, idx):
    return self.data[idx]
 
  def __len__(self):
    return len(self.data)
  

def collate_fn(data):
  batch = defaultdict(list)
  for key in data[0].keys():
    batch[key] = [s[key] for s in data]
    batch[key] = torch.Tensor(batch[key]).long()
    batch[key] = batch[key].to(torch.device('cuda'))
      
  return batch

# default collate works
try:
  dataloader = torch.utils.data.DataLoader(Dataset(), 4, num_workers=2)
  for epoch in range(2):
    for iteration, batch in enumerate(dataloader):
      pass
  print("Default collate_fn works!")
except Exception as e:
  print("Default collate_fn failed!")
  print(e)

# custom collate fails
try:
  dataloader = torch.utils.data.DataLoader(Dataset(), 4, num_workers=2, collate_fn=collate_fn)
  for epoch in range(2):
    for iteration, batch in enumerate(dataloader):
      print(epoch, iteration)
  print("Custom collate_fn works!")
except Exception as e:
  print("Custom collate_fn failed!")
  print(e)

Default collate_fn works!
0 0
0 1
0 2
0 3
0 4
0 5
0 6
0 7
Custom collate_fn failed!
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 138, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "", line 29, in collate_fn
    batch[key] = batch[key].to(torch.device('cuda'))
RuntimeError: CUDA error: initialization error

Edit: Obviously a custom collate_fn is not necessary for this MWE. In my actual use case, however, I need one because I have variable-length data, so I do the padding in the collate_fn. But as you can see from the trace, it fails when moving the data to CUDA.
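For the variable-length use case, one common workaround (a sketch, not code from this issue) is to do the padding on CPU inside the collate_fn, for example with torch.nn.utils.rnn.pad_sequence, and only move the batch to the GPU in the main process after the DataLoader has yielded it. The function name padding_collate_fn below is hypothetical:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def padding_collate_fn(data):
    # Pad the variable-length 'sequence' entries on CPU; the worker
    # processes never touch CUDA, so no CUDA context is created in them.
    seqs = [torch.tensor(d['sequence'], dtype=torch.long) for d in data]
    return {'sequence': pad_sequence(seqs, batch_first=True, padding_value=0)}

# In the main process, after the DataLoader yields the batch:
# batch = {k: v.to(torch.device('cuda')) for k, v in batch.items()}
```

Keeping the .to('cuda') call in the main process avoids initializing CUDA inside forked worker processes, which is what appears to trigger the initialization error here.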
