Skip to content
Closed
1 change: 1 addition & 0 deletions monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
LMDBDataset,
NPZDictItemDataset,
PersistentDataset,
SharedCacheDataset,
SmartCacheDataset,
ZipDataset,
)
Expand Down
128 changes: 128 additions & 0 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import time
import warnings
from copy import copy, deepcopy
from multiprocessing.managers import ListProxy # type:ignore
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union
Expand Down Expand Up @@ -886,6 +887,133 @@ def _transform(self, index: int):
return data


import torch.distributed as dist


class SharedCacheDataset(Dataset):
    """
    Dataset with a cache shared among processes. Particularly useful in a DistributedDataParallel
    (DDP) multi-GPU run, where every process can read/write the same shared cache list and
    collectively cache the whole dataset in RAM.

    The leading subset of non-random transforms is cached to accelerate the training data pipeline.
    Transforms that are supposed to be cached must implement the `monai.transforms.Transform`
    interface and must not be `Randomizable`. This dataset caches the outcome produced just before
    the first `Randomizable` `Transform` within a `Compose` instance. To improve caching
    efficiency, always put as many non-random transforms as possible before the randomized ones
    when composing the chain of transforms.

    For example, if the transform is a `Compose` of::

        transforms = Compose([
            LoadImaged(),
            EnsureChannelFirstd(),
            Spacingd(),
            Orientationd(),
            ScaleIntensityRanged(),
            RandCropByPosNegLabeld(),
            ToTensord()
        ])

    when `transforms` is used in a multi-epoch training pipeline, this dataset caches the results
    up to ``ScaleIntensityRanged``, as all of `LoadImaged`, `EnsureChannelFirstd`, `Spacingd`,
    `Orientationd` and `ScaleIntensityRanged` are non-random and therefore cacheable. During
    training, the dataset loads the cached result and runs only ``RandCropByPosNegLabeld`` and
    ``ToTensord``, as ``RandCropByPosNegLabeld`` is a randomized transform and its outcome is
    not cached.

    Note:
        Unlike the `CacheDataset` class, `SharedCacheDataset` caches the full dataset in shared
        memory (via a ``ListProxy``), so each process within DistributedDataParallel can access
        items previously cached by other processes. Data is cached on the fly, so there is no
        need to wait for the whole dataset to be cached beforehand.
    """

    def __init__(
        self,
        data: Sequence,
        transform: Optional[Union[Sequence[Callable], Callable]] = None,
        copy_cache: bool = False,
        as_contiguous: bool = True,
        cache_list: Optional[ListProxy] = None,
        use_cache: bool = True,
    ) -> None:
        """
        Args:
            data: input data to load and transform to generate dataset for model.
            transform: transforms to execute operations on input data.
            copy_cache: whether to `deepcopy` the cache content before applying the random
                transforms, defaults to `False`.
            as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be
                contiguous. It may help improve the performance of the following logic.
            cache_list: ``ListProxy`` instance created by a master process, to hold the shared
                memory list. It must be created in the master process, e.g. as
                ``cache_list = multiprocessing.Manager().list()``, before spawning/forking the
                subprocesses. If no ``cache_list`` is provided, one is created automatically:
                in a DDP run it is created on rank 0 and broadcast to the other ranks,
                otherwise a local (per process) non-shared list is created (which can be
                sufficient in a single-GPU environment).
            use_cache: whether to use the cache mechanism. Defaults to ``True``. When ``False``,
                the logic becomes equivalent to the `Dataset` super class (no caching), even if
                a ``cache_list`` was provided.
        """
        if not isinstance(transform, Compose):
            transform = Compose(transform)
        super().__init__(data=data, transform=transform)
        self.copy_cache = copy_cache
        self.as_contiguous = as_contiguous

        if not use_cache:
            # honor the documented contract: disable caching entirely,
            # even when a cache_list was explicitly provided
            cache_list = None
        elif cache_list is None:
            if dist.is_initialized():
                # create the manager only on rank 0 and broadcast its proxy to the other
                # ranks, instead of creating (and leaking) a manager process on every rank
                warnings.warn(
                    "SharedCacheDataset: sharing the cache via broadcasting in DDP; "
                    "torch.cuda.set_device(device) must be set before this point."
                )
                cache_list = torch.multiprocessing.Manager().list() if dist.get_rank() == 0 else None
                object_list = [cache_list]
                dist.broadcast_object_list(object_list, src=0)
                cache_list = object_list[0]
            else:
                cache_list = torch.multiprocessing.Manager().list()

        if cache_list is not None:
            # pre-fill one empty slot per item. NOTE(review): constructing another
            # SharedCacheDataset over the same cache_list resets previously cached entries;
            # all datasets sharing a list are expected to be constructed before training starts.
            cache_list[:] = [None] * len(data)

        self._cache = cache_list

    def _load_cache_item(self, idx: int):
        """
        Run the leading chain of deterministic transforms on item ``idx`` and return the result.

        Args:
            idx: the index of the input data sequence.
        """
        item = self.data[idx]
        for _transform in self.transform.transforms:  # type:ignore
            # stop at the first random (or non-Transform) callable: everything from there on
            # must be re-executed at every epoch and is therefore not cacheable
            if isinstance(_transform, Randomizable) or not isinstance(_transform, Transform):
                break
            # thread-unsafe transforms carry internal state, so execute a private copy
            _xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
            item = apply_transform(_xform, item)
        if self.as_contiguous:
            item = convert_to_contiguous(item, memory_format=torch.contiguous_format)
        return item

    def _transform(self, index: int):
        """
        Fetch (computing and caching on first access) the deterministic part of item ``index``,
        then apply the remaining (random) transforms.

        Args:
            index: the index of the input data sequence.
        """
        if self._cache is None:
            return super()._transform(index=index)

        data = self._cache[index]

        # if data is not in the cache yet, apply the non-random transforms and cache the result
        if data is None:
            data = self._load_cache_item(index)
            self._cache[index] = data

        # proceed with the randomizable transforms
        start_run = False
        comp_transform: Compose = self.transform  # type:ignore
        for _transform in comp_transform.transforms:
            if start_run or isinstance(_transform, Randomizable) or not isinstance(_transform, Transform):
                if not start_run:
                    start_run = True
                    # only need to deep copy the data once, at the first non-deterministic
                    # transform, to protect the shared cache entry from in-place mutation
                    if self.copy_cache:
                        data = deepcopy(data)
                data = apply_transform(_transform, data)
        return data


class SmartCacheDataset(Randomizable, CacheDataset):
"""
Re-implementation of the SmartCache mechanism in NVIDIA Clara-train SDK.
Expand Down
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def run_testsuit():
"test_auto3dseg_hpo",
"test_auto3dseg",
"test_cachedataset",
"test_sharedcachedataset",
"test_cachedataset_parallel",
"test_cachedataset_persistent_workers",
"test_cachentransdataset",
Expand Down
151 changes: 151 additions & 0 deletions tests/test_sharedcachedataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest

import nibabel as nib
import numpy as np
import torch
import torch.distributed as dist
from parameterized import parameterized
from torch.utils.data.distributed import DistributedSampler

from monai.data import DataLoader, SharedCacheDataset
from monai.transforms import Compose, LoadImaged, RandomizableTransform, Transform

TEST_CASE_1 = [Compose([LoadImaged(keys=["image", "label", "extra"])]), (128, 128, 128)]
TEST_CASE_2 = [None, (128, 128, 128)]


class TestCacheDataset(unittest.TestCase):
    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
    def test_shape(self, transform, expected_shape):
        """Check indexing/slicing of SharedCacheDataset with and without a transform chain."""
        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4), dtype=np.uint8)
        keys = ["image", "label", "extra"]
        with tempfile.TemporaryDirectory() as tempdir:
            test_data = []
            for suffix in ["1", "2"]:
                for key in keys:
                    nib.save(test_image, os.path.join(tempdir, f"{key}{suffix}.nii.gz"))
                test_data.append({key: os.path.join(tempdir, f"{key}{suffix}.nii.gz") for key in keys})

            dataset = SharedCacheDataset(data=test_data, transform=transform)
            data1 = dataset[0]
            data2 = dataset[1]
            data3 = dataset[0:-1]
            data4 = dataset[-1]
            self.assertEqual(len(data3), 1)

            if transform is None:
                # omitting the transform argument must behave like passing transform=None
                dataset2 = SharedCacheDataset(data=test_data)
                for key in keys:
                    self.assertEqual(dataset[0][key], dataset2[0][key])

            if transform is None:
                # without transforms, items are the raw path dictionaries
                self.assertEqual(data1["image"], os.path.join(tempdir, "image1.nii.gz"))
                self.assertEqual(data2["label"], os.path.join(tempdir, "label2.nii.gz"))
                self.assertEqual(data4["image"], os.path.join(tempdir, "image2.nii.gz"))
            else:
                # with LoadImaged, every key holds a loaded volume of the expected shape
                for item in (data1, data2):
                    for key in keys:
                        self.assertTupleEqual(item[key].shape, expected_shape)
                for item in data3:
                    self.assertTupleEqual(item["image"].shape, expected_shape)


class TransformNonrandom(Transform):
    # deterministic transform: scales the input by 10 and wraps it in a length-1 numpy array
    def __call__(self, x):
        scaled = 10 * x
        return np.array([scaled])


class TransformRandom(RandomizableTransform):
    # "random" only by type (marks the cache boundary): always shifts the input by one
    def __call__(self, x):
        return 1 + x


def main_worker(rank: int, nprocs: int, cache_list):
    """
    Per-process DDP worker exercising SharedCacheDataset against a shared cache_list.

    Each rank iterates a disjoint subset of the data (via DistributedSampler), verifies the
    transform outputs, then inspects and mutates the shared cache to prove every rank sees
    the entries cached by the others.

    Args:
        rank: index of this process within the process group.
        nprocs: total number of spawned processes (world size).
        cache_list: manager-backed shared list created by the parent before spawning.
    """

    # run one GPU per process only when at least two GPUs exist; otherwise fall back to CPU
    has_cuda = torch.cuda.is_available() and torch.cuda.device_count() > 1

    device = torch.device(rank) if has_cuda else torch.device("cpu")
    device_ids = [rank] if has_cuda else None
    output_device = device if has_cuda else None
    backend = "nccl" if has_cuda else "gloo"  # nccl requires CUDA; gloo works on CPU

    dist.init_process_group(backend=backend, init_method="tcp://127.0.0.1:12345", world_size=nprocs, rank=rank)
    model = torch.nn.Conv3d(in_channels=1, out_channels=32, kernel_size=3, bias=True).to(device=device)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=device_ids, output_device=output_device, find_unused_parameters=False
    )

    # 4 items per process, so each rank caches a distinct subset of the dataset
    data_list1 = list(range(4 * nprocs))
    transform = Compose([TransformNonrandom(), TransformRandom()])

    dataset = SharedCacheDataset(data=data_list1, transform=transform, copy_cache=False, cache_list=cache_list)
    sampler = DistributedSampler(dataset, shuffle=False)
    dataloader = DataLoader(dataset, num_workers=2, sampler=sampler)
    ids = list(range(rank, len(data_list1), nprocs))  # ids for given process that DistributedSampler will use

    # each process goes only over a small subset of the data (and caches in shared cache)
    p = 0
    for i, d in enumerate(dataloader):
        # print(rank, i, d)
        # expected pipeline output: TransformNonrandom (*10) then TransformRandom (+1)
        expected_data = data_list1[ids[i]] * 10 + 1
        np.testing.assert_allclose([[expected_data]], d)
        p = p + 1
    assert p == len(dataset) // nprocs, f"each process processed {p} out of {len(dataset)}"
    torch.distributed.barrier()

    # at this point the full dataset is cached, and every process has access to it
    # lets inspect cache
    for i in range(len(dataset)):
        expected_data = data_list1[i] * 10  # cached part was only the first transform
        cache = dataset._cache[i]
        np.testing.assert_allclose(expected_data, cache)
    torch.distributed.barrier()

    # lets update cache directly by +1 (each rank touches only its own ids, avoiding races)
    for i in ids:
        dataset._cache[i] += 1
    torch.distributed.barrier()

    # inspect results, must have output +1 (since cache will be used instead of the first transform)
    for i, d in enumerate(dataloader):
        expected_data = data_list1[rank : len(data_list1) : nprocs][i] * 10 + 1 + 1  # expecting +1 in output
        # print(rank, i, d, expected_data)
        np.testing.assert_allclose([[expected_data]], d)
    torch.distributed.barrier()

    # print('processed rank', rank)
    # clear the shared cache so the manager list does not pin memory after the test
    cache_list[:] = []
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()


class TestDDP(unittest.TestCase):
    def test_ddp_ops(self):
        """Spawn one DDP worker per device (or two CPU workers) sharing a manager-backed list."""
        multi_gpu = torch.cuda.is_available() and torch.cuda.device_count() > 1
        nprocs = torch.cuda.device_count() if multi_gpu else 2

        manager = torch.multiprocessing.Manager()
        shared_cache = manager.list()
        torch.multiprocessing.spawn(main_worker, nprocs=nprocs, args=(nprocs, shared_cache))


if __name__ == "__main__":
    # support running this test module directly, e.g. `python test_sharedcachedataset.py`
    unittest.main()