-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Refactor MaskedInferenceWSIDataset #4014
Copy link
Copy link
Closed
Labels
enhancement (New feature or request), refactor (Non-breaking feature enhancements)
Description
Refactor and enhance MaskedInferenceWSIDataset in pathology, for blending it into core MONAI as laid out in #4005.
MONAI/monai/apps/pathology/data/datasets.py
Lines 175 to 319 in a676e38
class MaskedInferenceWSIDataset(Dataset):
    """
    This dataset loads the provided foreground masks at an arbitrary resolution level,
    and extracts patches based on those masks from the associated whole slide images.

    Args:
        data: a list of samples, each including the path to the whole slide image
            and the path to the mask.
            Like this: `[{"image": "path/to/image1.tiff", "mask": "path/to/mask1.npy"}, ...]`.
        patch_size: the size of patches to be extracted from the whole slide image for inference.
        transform: transforms to be executed on extracted patches.
        image_reader_name: the name of the library to be used for loading whole slide imaging,
            either "cuCIM" or "OpenSlide". Defaults to "cuCIM".

    Note:
        The resulting output (probability maps) after performing inference using this dataset is
        supposed to be the same size as the foreground mask and not the original wsi image size.
    """

    def __init__(
        self,
        data: List[Dict[str, str]],
        patch_size: Union[int, Tuple[int, int]],
        transform: Optional[Callable] = None,
        image_reader_name: str = "cuCIM",
    ) -> None:
        super().__init__(data, transform)
        self.patch_size = ensure_tuple_rep(patch_size, 2)

        # set up the whole slide image reader
        self.image_reader_name = image_reader_name.lower()
        self.image_reader = WSIReader(image_reader_name)

        # process data and create a list of dictionaries containing all required data and metadata
        self.data = self._prepare_data(data)

        # cumulative patch counts let __getitem__ map a flat index back to (sample, patch)
        self.num_patches_per_sample = [len(d["image_locations"]) for d in self.data]
        self.num_patches = sum(self.num_patches_per_sample)
        self.cum_num_patches = np.cumsum([0] + self.num_patches_per_sample[:-1])

    def _prepare_data(self, input_data: List[Dict[str, str]]) -> List[Dict]:
        # pre-process every sample once up front so per-patch loading stays cheap
        return [self._prepare_a_sample(sample) for sample in input_data]

    def _prepare_a_sample(self, sample: Dict[str, str]) -> Dict:
        """
        Preprocess input data to load WSIReader object and the foreground mask,
        and define the locations where patches need to be extracted.

        Args:
            sample: one sample, a dictionary containing path to the whole slide image
                and the foreground mask.
                For example: `{"image": "path/to/image1.tiff", "mask": "path/to/mask1.npy"}`

        Return:
            A dictionary containing:
                "name": the base name of the whole slide image,
                "image": the WSIReader image object,
                "mask_shape": the size of the foreground mask,
                "mask_locations": the list of non-zero pixel locations (x, y) on the foreground mask,
                "image_locations": the list of pixel locations (x, y) on the whole slide image
                    where patches are extracted, and
                "level": the resolution level of the mask with respect to the whole slide image.

        Raises:
            ValueError: if the mask is not at a power-of-two level of the image, or the
                image/mask ratio differs across dimensions (re-raised with the mask path
                prepended to the error arguments for easier debugging).
        """
        image = self.image_reader.read(sample["image"])
        # NOTE(review): masks are loaded with np.load, so they are expected to be .npy files
        mask = np.load(sample["mask"])
        try:
            level, ratio = self._calculate_mask_level(image, mask)
        except ValueError as err:
            # prepend the offending mask path so the failure is attributable
            err.args = (sample["mask"],) + err.args
            raise

        # all indices for non-zero pixels of the foreground mask
        mask_locations = np.vstack(mask.nonzero()).T

        # convert mask locations to image locations: map to the center of each mask pixel
        # at full resolution, then shift back by half a patch so patches are centered
        image_locations = (mask_locations + 0.5) * ratio - np.array(self.patch_size) // 2

        return {
            "name": os.path.splitext(os.path.basename(sample["image"]))[0],
            "image": image,
            "mask_shape": mask.shape,
            "mask_locations": mask_locations.astype(int).tolist(),
            "image_locations": image_locations.astype(int).tolist(),
            "level": level,
        }

    def _calculate_mask_level(self, image: np.ndarray, mask: np.ndarray) -> Tuple[int, float]:
        """
        Calculate level of the mask and its ratio with respect to the whole slide image.

        Args:
            image: the original whole slide image
            mask: a mask, that can be down-sampled at an arbitrary level.
                Note that down-sampling ratio should be 2^N and equal in all dimensions.

        Return:
            tuple: (level, ratio) where ratio is 2^level

        Raises:
            ValueError: if the image/mask ratio differs between the first two dimensions,
                or the ratio is not a power of two.
        """
        image_shape = image.shape
        mask_shape = mask.shape
        # only the first two (spatial) dimensions are compared
        ratios = [image_shape[i] / mask_shape[i] for i in range(2)]
        level = np.log2(ratios[0])
        if ratios[0] != ratios[1]:
            raise ValueError(
                "Image/Mask ratio across dimensions does not match!"
                f"ratio 0: {ratios[0]} ({image_shape[0]} / {mask_shape[0]}),"
                f"ratio 1: {ratios[1]} ({image_shape[1]} / {mask_shape[1]}),"
            )
        if not level.is_integer():
            raise ValueError(f"Mask is not at a regular level (ratio not power of 2), image / mask ratio: {ratios[0]}")
        return int(level), ratios[0]

    def _load_a_patch(self, index):
        """
        Load a patch given its flat index across all samples.

        Since the index is sequential and the patches come in a stream from different
        images, this method first finds the whole slide image and the patch that should
        be extracted, then loads the patch and provides it with its image name and the
        corresponding mask location.
        """
        # cum_num_patches[i] is the first flat index belonging to sample i, so the
        # owning sample is the last entry <= index (explicit, unlike the previous
        # argmax trick that relied on -1 wrapping around for the final sample)
        sample_num = int(np.searchsorted(self.cum_num_patches, index, side="right")) - 1
        sample = self.data[sample_num]
        patch_num = index - self.cum_num_patches[sample_num]
        location_on_image = sample["image_locations"][patch_num]
        location_on_mask = sample["mask_locations"][patch_num]

        image, _ = self.image_reader.get_data(img=sample["image"], location=location_on_image, size=self.patch_size)
        processed_sample = {"image": image, "name": sample["name"], "mask_location": location_on_mask}
        return processed_sample

    def __len__(self):
        # total number of patches across all whole slide images
        return self.num_patches

    def __getitem__(self, index):
        # wrap in a list so list-of-dict transforms (e.g. dictionary transforms over
        # a batch of one) apply uniformly
        patch = [self._load_a_patch(index)]
        if self.transform:
            patch = self.transform(patch)
        return patch
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
enhancement (New feature or request), refactor (Non-breaking feature enhancements)
Type
Projects
Status
💯 Complete