Skip to content

Refactor MaskedInferenceWSIDataset #4014

@bhashemian

Description

@bhashemian

Refactor and enhance MaskedInferenceWSIDataset in pathology, to prepare it for integration into core MONAI as laid out in #4005.

class MaskedInferenceWSIDataset(Dataset):
    """
    This dataset loads the provided foreground masks at an arbitrary resolution level,
    and extracts patches based on that mask from the associated whole slide image.

    Args:
        data: a list of samples, each including the path to the whole slide image and
            the path to the mask.
            Like this: `[{"image": "path/to/image1.tiff", "mask": "path/to/mask1.npy"}, ...]`.
        patch_size: the size of patches to be extracted from the whole slide image for inference.
        transform: transforms to be executed on extracted patches.
        image_reader_name: the name of the library to be used for loading whole slide images,
            either CuCIM or OpenSlide. Defaults to CuCIM.

    Note:
        The resulting output (probability maps) after performing inference using this dataset is
        supposed to be the same size as the foreground mask and not the original wsi image size.
    """

    def __init__(
        self,
        data: List[Dict[str, str]],
        patch_size: Union[int, Tuple[int, int]],
        transform: Optional[Callable] = None,
        image_reader_name: str = "cuCIM",
    ) -> None:
        super().__init__(data, transform)
        self.patch_size = ensure_tuple_rep(patch_size, 2)

        # set up whole slide image reader
        self.image_reader_name = image_reader_name.lower()
        self.image_reader = WSIReader(image_reader_name)

        # process data and create a list of dictionaries containing all required data and metadata
        self.data = self._prepare_data(data)

        # cumulative patch counts let __getitem__ map a flat index to (sample, patch);
        # cum_num_patches[i] is the number of patches in all samples before sample i
        self.num_patches_per_sample = [len(d["image_locations"]) for d in self.data]
        self.num_patches = sum(self.num_patches_per_sample)
        self.cum_num_patches = np.cumsum([0] + self.num_patches_per_sample[:-1])

    def _prepare_data(self, input_data: List[Dict[str, str]]) -> List[Dict]:
        """Preprocess every input sample (see `_prepare_a_sample`)."""
        return [self._prepare_a_sample(sample) for sample in input_data]

    def _prepare_a_sample(self, sample: Dict[str, str]) -> Dict:
        """
        Preprocess input data to load WSIReader object and the foreground mask,
        and define the locations where patches need to be extracted.

        Args:
            sample: one sample, a dictionary containing path to the whole slide image
                and the foreground mask.
                For example: `{"image": "path/to/image1.tiff", "mask": "path/to/mask1.npy"}`

        Return:
            A dictionary containing:
                "name": the base name of the whole slide image,
                "image": the WSIReader image object,
                "mask_shape": the size of the foreground mask,
                "mask_locations": the list of non-zero pixel locations (x, y) on the foreground mask,
                "image_locations": the list of pixel locations (x, y) on the whole slide image
                    where patches are extracted, and
                "level": the resolution level of the mask with respect to the whole slide image.

        Raises:
            ValueError: if the mask is not aligned with the image at a power-of-two level
                (re-raised from `_calculate_mask_level` with the mask path prepended).
        """
        image = self.image_reader.read(sample["image"])
        mask = np.load(sample["mask"])
        try:
            level, ratio = self._calculate_mask_level(image, mask)
        except ValueError as err:
            # prepend the offending mask path so the failure is traceable to its input file
            err.args = (sample["mask"],) + err.args
            raise

        # get all indices for non-zero pixels of the foreground mask
        mask_locations = np.vstack(mask.nonzero()).T

        # convert mask locations to image locations: map each mask pixel center
        # onto the image, then shift back by half a patch to get the patch corner
        image_locations = (mask_locations + 0.5) * ratio - np.array(self.patch_size) // 2

        return {
            "name": os.path.splitext(os.path.basename(sample["image"]))[0],
            "image": image,
            "mask_shape": mask.shape,
            "mask_locations": mask_locations.astype(int).tolist(),
            "image_locations": image_locations.astype(int).tolist(),
            "level": level,
        }

    def _calculate_mask_level(self, image: np.ndarray, mask: np.ndarray) -> Tuple[int, float]:
        """
        Calculate level of the mask and its ratio with respect to the whole slide image.

        Args:
            image: the original whole slide image
            mask: a mask, that can be down-sampled at an arbitrary level.
                Note that down-sampling ratio should be 2^N and equal in all dimensions.

        Return:
            tuple: (level, ratio) where ratio is 2^level

        Raises:
            ValueError: if the image/mask ratio differs between the first two
                dimensions, or if it is not a power of two.
        """
        image_shape = image.shape
        mask_shape = mask.shape
        # only the first two (spatial) dimensions are compared
        ratios = [image_shape[i] / mask_shape[i] for i in range(2)]
        level = np.log2(ratios[0])
        if ratios[0] != ratios[1]:
            raise ValueError(
                "Image/Mask ratio across dimensions does not match! "
                f"ratio 0: {ratios[0]} ({image_shape[0]} / {mask_shape[0]}), "
                f"ratio 1: {ratios[1]} ({image_shape[1]} / {mask_shape[1]}),"
            )
        if not level.is_integer():
            raise ValueError(f"Mask is not at a regular level (ratio not power of 2), image / mask ratio: {ratios[0]}")
        return int(level), ratios[0]

    def _load_a_patch(self, index):
        """
        Load sample given the index.

        Since index is sequential and the patches are coming in a stream from different images,
        this method first finds the whole slide image and the patch that should be extracted,
        then it loads the patch and provides it with its image name and the corresponding
        mask location.
        """
        # locate the sample whose cumulative-count interval contains `index`
        sample_num = np.searchsorted(self.cum_num_patches, index, side="right") - 1
        sample = self.data[sample_num]
        patch_num = index - self.cum_num_patches[sample_num]
        location_on_image = sample["image_locations"][patch_num]
        location_on_mask = sample["mask_locations"][patch_num]
        image, _ = self.image_reader.get_data(img=sample["image"], location=location_on_image, size=self.patch_size)
        return {"image": image, "name": sample["name"], "mask_location": location_on_mask}

    def __len__(self):
        return self.num_patches

    def __getitem__(self, index):
        # wrap in a list so list-of-dict transforms (e.g. dictionary transforms
        # applied over a batch of patches) receive the expected structure
        patch = [self._load_a_patch(index)]
        if self.transform:
            patch = self.transform(patch)
        return patch

Metadata

Metadata

Assignees

Labels

Labels: enhancement (New feature or request), refactor (Non-breaking feature enhancements)

Type

No type

Projects

Status

💯 Complete

Relationships

None yet

Development

No branches or pull requests

Issue actions