57 changes: 57 additions & 0 deletions monai/data/dataset.py
@@ -130,6 +130,8 @@ class PersistentDataset(Dataset):
Subsequent uses of a dataset directly read pre-processed results from `cache_dir`,
followed by applying the random-dependent parts of the transform processing.

During training, call `set_data()` to update the input data and recompute the cache content.

Note:
The input data must be a list of file paths, which are hashed to generate the cache keys.
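
As a minimal illustration of the workflow this docstring describes (the file paths and transform below are hypothetical placeholders, not part of the PR; `LMDBDataset` overrides `set_data()` with the same contract):

```python
from monai.data import PersistentDataset
from monai.transforms import Compose, LoadImaged

# hypothetical datalist; PersistentDataset hashes these paths as cache keys
datalist = [{"image": "img1.nii.gz"}, {"image": "img2.nii.gz"}]
dataset = PersistentDataset(
    data=datalist,
    transform=Compose([LoadImaged(keys=["image"])]),
    cache_dir="./cache",
)
_ = dataset[0]  # deterministic transform results are written to ./cache

# swap in a new datalist: the stale cache_dir content is deleted
# and recomputed lazily on the next access
dataset.set_data(data=[{"image": "img3.nii.gz"}, {"image": "img4.nii.gz"}])
_ = dataset[0]
```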

@@ -173,6 +175,16 @@ def __init__(
if not self.cache_dir.is_dir():
raise ValueError("cache_dir must be a directory.")

def set_data(self, data: Sequence):
"""
Set the input data and delete all the outdated cache content.

"""
self.data = data
if self.cache_dir is not None and self.cache_dir.exists():
shutil.rmtree(self.cache_dir, ignore_errors=True)
self.cache_dir.mkdir(parents=True, exist_ok=True)

def _pre_transform(self, item_transformed):
"""
Process the data from original state up to the first random element.
@@ -404,6 +416,14 @@ def __init__(
self._read_env = None
print(f"Accessing lmdb file: {self.db_file.absolute()}.")

def set_data(self, data: Sequence):
"""
Set the input data and delete all the outdated cache content.

"""
super().set_data(data=data)
self._read_env = None

def _fill_cache_start_reader(self):
# create cache
self.lmdb_kwargs["readonly"] = False
@@ -515,6 +535,9 @@ class CacheDataset(Dataset):
``RandCropByPosNegLabeld`` and ``ToTensord``, as ``RandCropByPosNegLabeld`` is a randomized transform
and its outcome is not cached.

During training, call `set_data()` to update the input data and recompute the cache content. Note that this
requires `persistent_workers=False` in the PyTorch DataLoader.

Note:
`CacheDataset` executes non-random transforms and prepares cache content in the main process before
the first epoch, then all the DataLoader subprocesses read the same cache content from the main process
@@ -555,6 +578,18 @@ def __init__(
self.num_workers = max(int(self.num_workers), 1)
self._cache: List = self._fill_cache()

def set_data(self, data: Sequence):
"""
Set the input data and run deterministic transforms to generate cache content.

Note: this function should be called after an entire epoch, and `persistent_workers=False` must be set
in the PyTorch DataLoader, because it needs to create new worker processes based on the newly
generated cache content.

"""
self.data = data
self._cache = self._fill_cache()

def _fill_cache(self) -> List:
if self.cache_num <= 0:
return []
@@ -639,6 +674,9 @@ class SmartCacheDataset(Randomizable, CacheDataset):
3. Call `update_cache()` before every epoch to replace training items.
4. Call `shutdown()` when training ends.

During training, call `set_data()` to update the input data and recompute the cache content. Note: call
`shutdown()` first to stop the item replacement, then update the data and call `start()` to restart.

Note:
This replacement will not work in the following cases:
1. Set the `multiprocessing_context` of DataLoader to `spawn`.
@@ -683,6 +721,7 @@ def __init__(
self.set_random_state(seed=seed)
data = copy(data)
self.randomize(data)
self.shuffle = shuffle

super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress)
if self._cache is None:
@@ -711,6 +750,22 @@ def __init__(

self._compute_data_idx()

def set_data(self, data: Sequence):
"""
Set the input data and run deterministic transforms to generate cache content.

Note: `shutdown()` should be called before calling this function.

"""
if self.is_started():
warnings.warn("SmartCacheDataset is not shut down yet; shutting it down directly.")
self.shutdown()

if self.shuffle:
data = copy(data)
self.randomize(data)
super().set_data(data)

def randomize(self, data: Sequence) -> None:
try:
self.R.shuffle(data)
@@ -798,6 +853,8 @@ def _try_shutdown(self):
with self._update_lock:
if self._replace_done:
self._round = 0
self._start_pos = 0
self._compute_data_idx()
self._replace_done = False
return True
return False
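
To make the `CacheDataset.set_data()` contract above concrete, here is a minimal sketch mirroring the new test in `tests/test_cachedataset.py` (the datalist and transform are illustrative only):

```python
import numpy as np
from monai.data import CacheDataset, DataLoader
from monai.transforms import Lambda

dataset = CacheDataset(
    data=list(range(10)),
    transform=Lambda(func=lambda x: np.array([x * 10])),
    cache_rate=1.0,
)
# persistent_workers=False so that workers are recreated after set_data()
loader = DataLoader(dataset, batch_size=1, persistent_workers=False)

for _ in loader:  # first epoch reads the initial cache
    pass
dataset.set_data(data=list(range(-10, 0)))  # recompute the cache in place
for _ in loader:  # later epochs read the updated cache
    pass
```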
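Similarly, a minimal sketch of the `SmartCacheDataset` flow with the new `set_data()` step; the values mirror the new test in `tests/test_smartcachedataset.py`:

```python
import numpy as np
from monai.data import DataLoader, SmartCacheDataset
from monai.transforms import Lambda

dataset = SmartCacheDataset(
    data=list(range(10)),
    transform=Lambda(func=lambda x: np.array([x * 10])),
    cache_rate=0.5,
    replace_rate=0.4,
    shuffle=False,
)
loader = DataLoader(dataset, batch_size=1)

dataset.start()         # start the item replacement
for _ in loader:
    pass
dataset.update_cache()  # replace 2 (= 5 * 0.4) cached items for the next epoch
dataset.shutdown()      # must stop before updating the datalist
dataset.set_data(data=list(range(-10, 0)))
dataset.start()         # restart with the recomputed cache
for _ in loader:
    pass
dataset.shutdown()
```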
27 changes: 26 additions & 1 deletion tests/test_cachedataset.py
@@ -19,7 +19,7 @@
from parameterized import parameterized

from monai.data import CacheDataset, DataLoader, PersistentDataset, SmartCacheDataset
from monai.transforms import Compose, LoadImaged, ThreadUnsafe, Transform
from monai.transforms import Compose, Lambda, LoadImaged, ThreadUnsafe, Transform
from monai.utils import get_torch_version_tuple

TEST_CASE_1 = [Compose([LoadImaged(keys=["image", "label", "extra"])]), (128, 128, 128)]
@@ -81,6 +81,31 @@ def test_shape(self, transform, expected_shape):
for d in data3:
self.assertTupleEqual(d["image"].shape, expected_shape)

def test_set_data(self):
data_list1 = list(range(10))

transform = Lambda(func=lambda x: np.array([x * 10]))

dataset = CacheDataset(
data=data_list1,
transform=transform,
cache_rate=1.0,
num_workers=4,
progress=True,
)

num_workers = 2 if sys.platform == "linux" else 0
dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=1)
for i, d in enumerate(dataloader):
np.testing.assert_allclose([[data_list1[i] * 10]], d)

# update the datalist and fill the cache content
data_list2 = list(range(-10, 0))
dataset.set_data(data=data_list2)
# rerun with updated cache content
for i, d in enumerate(dataloader):
np.testing.assert_allclose([[data_list2[i] * 10]], d)


class _StatefulTransform(Transform, ThreadUnsafe):
"""
58 changes: 39 additions & 19 deletions tests/test_lmdbdataset.py
@@ -158,25 +158,45 @@ def test_shape(self, transform, expected_shape, kwargs=None):
data1_postcached = dataset_postcached[0]
data2_postcached = dataset_postcached[1]

if transform is None:
self.assertEqual(data1_precached["image"], os.path.join(tempdir, "test_image1.nii.gz"))
self.assertEqual(data2_precached["label"], os.path.join(tempdir, "test_label2.nii.gz"))
self.assertEqual(data1_postcached["image"], os.path.join(tempdir, "test_image1.nii.gz"))
self.assertEqual(data2_postcached["extra"], os.path.join(tempdir, "test_extra2.nii.gz"))
else:
self.assertTupleEqual(data1_precached["image"].shape, expected_shape)
self.assertTupleEqual(data1_precached["label"].shape, expected_shape)
self.assertTupleEqual(data1_precached["extra"].shape, expected_shape)
self.assertTupleEqual(data2_precached["image"].shape, expected_shape)
self.assertTupleEqual(data2_precached["label"].shape, expected_shape)
self.assertTupleEqual(data2_precached["extra"].shape, expected_shape)

self.assertTupleEqual(data1_postcached["image"].shape, expected_shape)
self.assertTupleEqual(data1_postcached["label"].shape, expected_shape)
self.assertTupleEqual(data1_postcached["extra"].shape, expected_shape)
self.assertTupleEqual(data2_postcached["image"].shape, expected_shape)
self.assertTupleEqual(data2_postcached["label"].shape, expected_shape)
self.assertTupleEqual(data2_postcached["extra"].shape, expected_shape)

# update the data to be cached
test_data_new = [
{
"image": os.path.join(tempdir, "test_image1_new.nii.gz"),
"label": os.path.join(tempdir, "test_label1_new.nii.gz"),
"extra": os.path.join(tempdir, "test_extra1_new.nii.gz"),
},
{
"image": os.path.join(tempdir, "test_image2_new.nii.gz"),
"label": os.path.join(tempdir, "test_label2_new.nii.gz"),
"extra": os.path.join(tempdir, "test_extra2_new.nii.gz"),
},
]
dataset_postcached.set_data(data=test_data_new)
# test the newly exchanged cache content
if transform is None:
self.assertEqual(dataset_postcached[0]["image"], os.path.join(tempdir, "test_image1_new.nii.gz"))
self.assertEqual(dataset_postcached[0]["label"], os.path.join(tempdir, "test_label1_new.nii.gz"))
self.assertEqual(dataset_postcached[1]["extra"], os.path.join(tempdir, "test_extra2_new.nii.gz"))


@skip_if_windows
20 changes: 20 additions & 0 deletions tests/test_persistentdataset.py
@@ -123,6 +123,26 @@ def test_shape(self, transform, expected_shape):
for d in data3_postcached:
self.assertTupleEqual(d["image"].shape, expected_shape)

# update the data to be cached
test_data_new = [
{
"image": os.path.join(tempdir, "test_image1_new.nii.gz"),
"label": os.path.join(tempdir, "test_label1_new.nii.gz"),
"extra": os.path.join(tempdir, "test_extra1_new.nii.gz"),
},
{
"image": os.path.join(tempdir, "test_image2_new.nii.gz"),
"label": os.path.join(tempdir, "test_label2_new.nii.gz"),
"extra": os.path.join(tempdir, "test_extra2_new.nii.gz"),
},
]
dataset_postcached.set_data(data=test_data_new)
# test the newly exchanged cache content
if transform is None:
self.assertEqual(dataset_postcached[0]["image"], os.path.join(tempdir, "test_image1_new.nii.gz"))
self.assertEqual(dataset_postcached[0]["label"], os.path.join(tempdir, "test_label1_new.nii.gz"))
self.assertEqual(dataset_postcached[1]["extra"], os.path.join(tempdir, "test_extra2_new.nii.gz"))


if __name__ == "__main__":
unittest.main()
48 changes: 46 additions & 2 deletions tests/test_smartcachedataset.py
@@ -11,15 +11,16 @@

import copy
import os
import sys
import tempfile
import unittest

import nibabel as nib
import numpy as np
from parameterized import parameterized

from monai.data import SmartCacheDataset
from monai.transforms import Compose, LoadImaged
from monai.data import DataLoader, SmartCacheDataset
from monai.transforms import Compose, Lambda, LoadImaged

TEST_CASE_1 = [0.1, 0, Compose([LoadImaged(keys=["image", "label", "extra"])])]

@@ -126,6 +127,49 @@ def test_shuffle(self):

dataset.shutdown()

def test_set_data(self):
data_list1 = list(range(10))

transform = Lambda(func=lambda x: np.array([x * 10]))

dataset = SmartCacheDataset(
data=data_list1,
transform=transform,
cache_rate=0.5,
replace_rate=0.4,
num_init_workers=4,
num_replace_workers=2,
shuffle=False,
progress=True,
)

num_workers = 2 if sys.platform == "linux" else 0
dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=1)

dataset.start()
for i, d in enumerate(dataloader):
np.testing.assert_allclose([[data_list1[i] * 10]], d)
# replace cache content, moving forward 2 (= 5 * 0.4) items
dataset.update_cache()
for i, d in enumerate(dataloader):
np.testing.assert_allclose([[data_list1[i + 2] * 10]], d)
# shutdown to update data
dataset.shutdown()
# update the datalist and fill the cache content
data_list2 = list(range(-10, 0))
dataset.set_data(data=data_list2)
# restart the dataset
dataset.start()
# rerun with updated cache content
for i, d in enumerate(dataloader):
np.testing.assert_allclose([[data_list2[i] * 10]], d)
# replace cache content, moving forward 2 (= 5 * 0.4) items
dataset.update_cache()
for i, d in enumerate(dataloader):
np.testing.assert_allclose([[data_list2[i + 2] * 10]], d)
# finally shutdown the dataset
dataset.shutdown()

def test_datalist(self):
data_list = [np.array([i]) for i in range(5)]
data_list_backup = copy.copy(data_list)