Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion test/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ def test_sequential(self):
def test_sequential_batch(self):
    """Batched sequential iteration must visit the dataset in order."""
    loader = DataLoader(self.dataset, batch_size=2)
    self._test_sequential(loader)

def test_sequential_pin_memory(self):
    """With pin_memory=True, every tensor a loader yields must be pinned."""
    loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
    for sample, label in loader:
        self.assertTrue(sample.is_pinned())
        self.assertTrue(label.is_pinned())

def test_shuffle(self):
    """Shuffled iteration must still yield each sample exactly once."""
    loader = DataLoader(self.dataset, shuffle=True)
    self._test_shuffle(loader)

Expand All @@ -110,6 +116,12 @@ def test_shuffle_workers(self):
def test_shuffle_batch_workers(self):
    """Shuffling must hold up with batching and multiple worker processes."""
    loader = DataLoader(self.dataset, batch_size=2, shuffle=True,
                        num_workers=4)
    self._test_shuffle(loader)

def test_shuffle_pin_memory(self):
    """Pinning must also work on the multi-worker, shuffled code path."""
    loader = DataLoader(self.dataset, batch_size=2, shuffle=True,
                        num_workers=4, pin_memory=True)
    for sample, label in loader:
        self.assertTrue(sample.is_pinned())
        self.assertTrue(label.is_pinned())

def test_error(self):
    """Exceptions raised inside the dataset must surface to the caller."""
    failing_loader = DataLoader(ErrorDataset(100), batch_size=2, shuffle=True)
    self._test_error(failing_loader)

Expand All @@ -118,8 +130,9 @@ def test_error_workers(self):

def test_partial_workers(self):
"check that workers exit even if the iterator is not exhausted"
loader = iter(DataLoader(self.dataset, batch_size=2, num_workers=4))
loader = iter(DataLoader(self.dataset, batch_size=2, num_workers=4, pin_memory=True))
workers = loader.workers
pin_thread = loader.pin_thread
for i, sample in enumerate(loader):
if i == 3:
break
Expand All @@ -128,6 +141,8 @@ def test_partial_workers(self):
w.join(1.0) # timeout of one second
self.assertFalse(w.is_alive(), 'subprocess not terminated')
self.assertEqual(w.exitcode, 0)
pin_thread.join(1.0)
self.assertFalse(pin_thread.is_alive())


if __name__ == '__main__':
Expand Down
50 changes: 48 additions & 2 deletions torch/utils/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
import collections
import sys
import traceback
import threading
if sys.version_info[0] == 2:
import Queue as queue
else:
import queue


class ExceptionWrapper(object):
Expand All @@ -19,6 +24,7 @@ def _worker_loop(dataset, index_queue, data_queue, collate_fn):
while True:
r = index_queue.get()
if r is None:
data_queue.put(None)
break
idx, batch_indices = r
try:
Expand All @@ -29,6 +35,23 @@ def _worker_loop(dataset, index_queue, data_queue, collate_fn):
data_queue.put((idx, samples))


def _pin_memory_loop(in_queue, out_queue):
    """Drain ``in_queue``, pin each batch's memory, and forward to ``out_queue``.

    Runs until a ``None`` sentinel arrives.  Batches that already carry an
    ``ExceptionWrapper`` (a worker failed upstream) are passed through
    untouched; errors raised while pinning are wrapped the same way so the
    consumer can re-raise them.
    """
    while True:
        item = in_queue.get()
        if item is None:
            # Sentinel from the producer: shut this loop down.
            break
        idx, batch = item
        if isinstance(batch, ExceptionWrapper):
            # Upstream failure — forward the wrapped exception as-is.
            out_queue.put((idx, batch))
            continue
        try:
            pinned = pin_memory_batch(batch)
        except Exception:
            out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            out_queue.put((idx, pinned))


def default_collate(batch):
"Puts each data field into a tensor with outer dimension batch size"
if torch.is_tensor(batch[0]):
Expand All @@ -49,6 +72,15 @@ def default_collate(batch):
.format(type(batch[0]))))


def pin_memory_batch(batch):
    """Recursively copy the tensors of a (possibly nested) batch to pinned memory.

    Tensors are replaced by pinned copies, iterable containers are rebuilt as
    lists of recursively pinned elements, and any other value is returned
    unchanged.

    Note: containers come back as plain lists regardless of their original
    type (tuple, generator, ...), matching the original behavior.
    """
    if torch.is_tensor(batch):
        return batch.pin_memory()
    elif isinstance(batch, str):
        # Strings are iterable and iterating them yields strings again, so
        # without this guard the branch below would recurse forever.
        return batch
    elif isinstance(batch, getattr(collections, 'abc', collections).Iterable):
        # getattr keeps this working on both Python 2 (collections.Iterable)
        # and Python 3.10+, where the ABCs live only in collections.abc.
        return [pin_memory_batch(sample) for sample in batch]
    else:
        return batch


class DataLoaderIter(object):
"Iterates once over the DataLoader's dataset, as specified by the sampler"

Expand All @@ -58,6 +90,7 @@ def __init__(self, loader):
self.collate_fn = loader.collate_fn
self.sampler = loader.sampler
self.num_workers = loader.num_workers
self.pin_memory = loader.pin_memory

self.samples_remaining = len(self.sampler)
self.sample_iter = iter(self.sampler)
Expand All @@ -81,6 +114,15 @@ def __init__(self, loader):
w.daemon = True # ensure that the worker exits on process exit
w.start()

if self.pin_memory:
in_data = self.data_queue
self.data_queue = queue.Queue()
self.pin_thread = threading.Thread(
target=_pin_memory_loop,
args=(in_data, self.data_queue))
self.pin_thread.daemon = True
self.pin_thread.start()

# prime the prefetch loop
for _ in range(2 * self.num_workers):
self._put_indices()
Expand All @@ -94,7 +136,10 @@ def __next__(self):
if self.samples_remaining == 0:
raise StopIteration
indices = self._next_indices()
return self.collate_fn([self.dataset[i] for i in indices])
batch = self.collate_fn([self.dataset[i] for i in indices])
if self.pin_memory:
batch = pin_memory_batch(batch)
return batch

# check if the next sample has already been generated
if self.rcvd_idx in self.reorder_dict:
Expand Down Expand Up @@ -166,11 +211,12 @@ class DataLoader(object):
"""

def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
num_workers=0, collate_fn=default_collate):
num_workers=0, collate_fn=default_collate, pin_memory=False):
self.dataset = dataset
self.batch_size = batch_size
self.num_workers = num_workers
self.collate_fn = collate_fn
self.pin_memory = pin_memory

if sampler is not None:
self.sampler = sampler
Expand Down