Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions monai/config/type_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,16 @@

DtypeLike = Union[np.dtype, type, None]
"""Type of datatypes
adapted from https://github.com/numpy/numpy/blob/master/numpy/typing/_dtype_like.py

Adapted from https://github.com/numpy/numpy/blob/master/numpy/typing/_dtype_like.py
"""

# Generic type which can represent either a numpy.ndarray or a torch.Tensor
# Unlike Union can create a dependence between parameter(s) / return(s)
NdarrayTensor = TypeVar("NdarrayTensor", np.ndarray, torch.Tensor)
"""NdarrayTensor

Generic type which can represent either a numpy.ndarray or a torch.Tensor
Unlike Union can create a dependence between parameter(s) / return(s)
"""

TensorOrList = Union[torch.Tensor, Sequence[torch.Tensor]]
"""TensorOrList
Expand Down
91 changes: 83 additions & 8 deletions monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@

import numpy as np
import torch
from torch.nn.functional import pad as pad_pt

from monai.config import IndexSelection
from monai.config.type_definitions import NdarrayTensor
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.transforms.transform import Randomizable, Transform
from monai.transforms.utils import (
Expand All @@ -34,6 +36,8 @@
weighted_patch_samples,
)
from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple, look_up_option
from monai.utils.enums import TransformBackends
from monai.utils.type_conversion import convert_data_type

__all__ = [
"SpatialPad",
Expand All @@ -54,9 +58,72 @@
]


class Pad(Transform):
"""
Perform padding for a given an amount of padding in each dimension.
If input is `torch.Tensor` and mode is `constant`, `torch.nn.functional.pad` will be used.
Otherwise, `np.pad` will be used (input converted to `np.ndarray` if necessary).
Uses np.pad so in practice, a mode needs to be provided. See numpy.lib.arraypad.pad
for additional details.
Args:
to_pad: the amount to be padded in each dimension [(low_H, high_H), (low_W, high_W), ...].
mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
One of the listed string values or a user supplied function. Defaults to ``"constant"``.
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
to_pad: List[Tuple[int, int]],
mode: Union[NumpyPadMode, str, None] = NumpyPadMode.CONSTANT,
**np_kwargs,
) -> None:
self.to_pad = to_pad
self.mode = mode or NumpyPadMode.CONSTANT
self.np_kwargs = np_kwargs

@staticmethod
def _np_pad(img: np.ndarray, all_pad_width, mode, **np_kwargs) -> np.ndarray:
img_np, *_ = convert_data_type(img, np.ndarray)
return np.pad(img_np, all_pad_width, mode=mode, **np_kwargs) # type: ignore

@staticmethod
def _pt_pad(img: torch.Tensor, all_pad_width, mode, **np_kwargs) -> torch.Tensor:
pt_pad_width = [val for sublist in all_pad_width for val in sublist[::-1]][::-1]
return pad_pt(img, pt_pad_width, mode=mode, **np_kwargs)

def __call__(self, img: NdarrayTensor, mode: Optional[Union[NumpyPadMode, str]] = None) -> NdarrayTensor:
"""
Args:
img: data to be transformed, assuming `img` is channel-first and
padding doesn't apply to the channel dim.
mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``,
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
One of the listed string values or a user supplied function. Defaults to ``self.mode``.
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
"""
if not np.asarray(self.to_pad).any():
# all zeros, skip padding
return img
mode = mode or self.mode
mode = mode.value if isinstance(mode, NumpyPadMode) else mode
if isinstance(img, torch.Tensor) and mode == "constant" and not self.np_kwargs:
pad = self._pt_pad
else:
pad = self._np_pad # type: ignore
return pad(img, self.to_pad, mode, **self.np_kwargs)


class SpatialPad(Transform):
"""
Performs padding to the data, symmetric for all sides or all on one side for each dimension.

If input is `torch.Tensor` and mode is `constant`, `torch.nn.functional.pad` will be used.
Otherwise, `np.pad` will be used (input converted to `np.ndarray` if necessary).

Uses np.pad so in practice, a mode needs to be provided. See numpy.lib.arraypad.pad
for additional details.

Expand All @@ -77,6 +144,8 @@ class SpatialPad(Transform):

"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
spatial_size: Union[Sequence[int], int],
Expand All @@ -99,7 +168,7 @@ def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int
return pad_width
return [(0, max(sp_i - data_shape[i], 0)) for i, sp_i in enumerate(spatial_size)]

def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None) -> np.ndarray:
def __call__(self, img: NdarrayTensor, mode: Optional[Union[NumpyPadMode, str]] = None) -> NdarrayTensor:
"""
Args:
img: data to be transformed, assuming `img` is channel-first and
Expand All @@ -115,9 +184,9 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N
# all zeros, skip padding
return img

mode = look_up_option(self.mode if mode is None else mode, NumpyPadMode).value
img = np.pad(img, all_pad_width, mode=mode, **self.np_kwargs)
return img
mode = look_up_option(mode or self.mode, NumpyPadMode)
padder = Pad(all_pad_width, mode, **self.np_kwargs)
return padder(img)


class BorderPad(Transform):
Expand Down Expand Up @@ -145,6 +214,8 @@ class BorderPad(Transform):

"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
spatial_border: Union[Sequence[int], int],
Expand All @@ -155,7 +226,7 @@ def __init__(
self.mode: NumpyPadMode = look_up_option(mode, NumpyPadMode)
self.np_kwargs = np_kwargs

def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None):
def __call__(self, img: NdarrayTensor, mode: Optional[Union[NumpyPadMode, str]] = None) -> NdarrayTensor:
"""
Args:
img: data to be transformed, assuming `img` is channel-first and
Expand Down Expand Up @@ -189,15 +260,19 @@ def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = N
f"[1, len(spatial_shape)={len(spatial_shape)}, 2*len(spatial_shape)={2*len(spatial_shape)}]."
)

mode = look_up_option(self.mode if mode is None else mode, NumpyPadMode).value
return np.pad(img, [(0, 0)] + data_pad_width, mode=mode, **self.np_kwargs)
all_pad_width = [(0, 0)] + data_pad_width
mode = look_up_option(mode or self.mode, NumpyPadMode)
padder = Pad(all_pad_width, mode, **self.np_kwargs)
return padder(img)


class DivisiblePad(Transform):
"""
Pad the input data, so that the spatial sizes are divisible by `k`.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
k: Union[Sequence[int], int],
Expand Down Expand Up @@ -226,7 +301,7 @@ def __init__(
self.method: Method = Method(method)
self.np_kwargs = np_kwargs

def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None) -> np.ndarray:
def __call__(self, img: NdarrayTensor, mode: Optional[Union[NumpyPadMode, str]] = None) -> NdarrayTensor:
"""
Args:
img: data to be transformed, assuming `img` is channel-first
Expand Down
15 changes: 11 additions & 4 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import numpy as np

from monai.config import IndexSelection, KeysCollection
from monai.config.type_definitions import NdarrayTensor
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.transforms.croppad.array import (
BorderPad,
Expand All @@ -49,7 +50,7 @@
)
from monai.utils import ImageMetaKey as Key
from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple
from monai.utils.enums import InverseKeys
from monai.utils.enums import InverseKeys, TransformBackends

__all__ = [
"NumpyPadModeSequence",
Expand Down Expand Up @@ -106,6 +107,8 @@ class SpatialPadd(MapTransform, InvertibleTransform):
Performs padding to the data, symmetric for all sides or all on one side for each dimension.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
keys: KeysCollection,
Expand Down Expand Up @@ -140,7 +143,7 @@ def __init__(
self.mode = ensure_tuple_rep(mode, len(self.keys))
self.padder = SpatialPad(spatial_size, method, **np_kwargs)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]:
d = dict(data)
for key, m in self.key_iterator(d, self.mode):
self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m})
Expand Down Expand Up @@ -174,6 +177,8 @@ class BorderPadd(MapTransform, InvertibleTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.BorderPad`.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
keys: KeysCollection,
Expand Down Expand Up @@ -211,7 +216,7 @@ def __init__(
self.mode = ensure_tuple_rep(mode, len(self.keys))
self.padder = BorderPad(spatial_border=spatial_border, **np_kwargs)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]:
d = dict(data)
for key, m in self.key_iterator(d, self.mode):
self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m})
Expand Down Expand Up @@ -249,6 +254,8 @@ class DivisiblePadd(MapTransform, InvertibleTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.DivisiblePad`.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
keys: KeysCollection,
Expand Down Expand Up @@ -283,7 +290,7 @@ def __init__(
self.mode = ensure_tuple_rep(mode, len(self.keys))
self.padder = DivisiblePad(k=k, method=method, **np_kwargs)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]:
d = dict(data)
for key, m in self.key_iterator(d, self.mode):
self.push_transform(d, key, extra_info={"mode": m.value if isinstance(m, Enum) else m})
Expand Down
17 changes: 9 additions & 8 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
min_version,
optional_import,
)
from monai.utils.enums import TransformBackends
from monai.utils.type_conversion import convert_data_type

PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray")
Expand Down Expand Up @@ -288,16 +290,16 @@ class CastToType(Transform):
specified PyTorch data type.
"""

backends = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, dtype=np.float32) -> None:
"""
Args:
dtype: convert image to this data type, default is `np.float32`.
"""
self.dtype = dtype

def __call__(
self, img: Union[np.ndarray, torch.Tensor], dtype: Optional[Union[DtypeLike, torch.dtype]] = None
) -> Union[np.ndarray, torch.Tensor]:
def __call__(self, img: NdarrayTensor, dtype: Optional[Union[DtypeLike, torch.dtype]] = None) -> NdarrayTensor:
"""
Apply the transform to `img`, assuming `img` is a numpy array or PyTorch Tensor.

Expand All @@ -308,11 +310,10 @@ def __call__(
TypeError: When ``img`` type is not in ``Union[numpy.ndarray, torch.Tensor]``.

"""
if isinstance(img, np.ndarray):
return img.astype(self.dtype if dtype is None else dtype) # type: ignore
if isinstance(img, torch.Tensor):
return torch.as_tensor(img, dtype=self.dtype if dtype is None else dtype)
raise TypeError(f"img must be one of (numpy.ndarray, torch.Tensor) but is {type(img).__name__}.")
if not isinstance(img, (torch.Tensor, np.ndarray)):
raise TypeError(f"img must be one of (numpy.ndarray, torch.Tensor) but is {type(img).__name__}.")
img_out, *_ = convert_data_type(img, output_type=type(img), dtype=dtype or self.dtype)
return img_out


class ToTensor(Transform):
Expand Down
8 changes: 4 additions & 4 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
)
from monai.transforms.utils import extreme_points_to_image, get_extreme_points
from monai.utils import convert_to_numpy, ensure_tuple, ensure_tuple_rep
from monai.utils.enums import InverseKeys
from monai.utils.enums import InverseKeys, TransformBackends

__all__ = [
"AddChannelD",
Expand Down Expand Up @@ -393,6 +393,8 @@ class CastToTyped(MapTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.CastToType`.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
keys: KeysCollection,
Expand All @@ -413,9 +415,7 @@ def __init__(
self.dtype = ensure_tuple_rep(dtype, len(self.keys))
self.converter = CastToType()

def __call__(
self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]
) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]:
def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]:
d = dict(data)
for key, dtype in self.key_iterator(d, self.dtype):
d[key] = self.converter(d[key], dtype=dtype)
Expand Down
16 changes: 8 additions & 8 deletions monai/utils/type_conversion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Any, Optional, Sequence, Tuple, Union
from typing import Any, Optional, Sequence, Tuple, Type, Union, cast

import numpy as np
import torch
Expand Down Expand Up @@ -147,7 +147,7 @@ def convert_to_numpy(data):

def convert_data_type(
data: Any,
output_type: Optional[type] = None,
output_type: Optional[Type[NdarrayTensor]] = None,
device: Optional[torch.device] = None,
dtype: Optional[Union[DtypeLike, torch.dtype]] = None,
) -> Tuple[NdarrayTensor, type, Optional[torch.device]]:
Expand Down Expand Up @@ -176,16 +176,16 @@ def convert_data_type(
data = convert_to_tensor(data)
if dtype != data.dtype:
data = data.to(dtype) # type: ignore
elif output_type is np.ndarray:
if device is not None:
data = data.to(device)
return cast(NdarrayTensor, data), orig_type, orig_device # pytype: disable=invalid-annotation
if output_type is np.ndarray:
if orig_type is not np.ndarray:
data = convert_to_numpy(data)
if data is not None and dtype != data.dtype:
data = data.astype(dtype) # type: ignore

if isinstance(data, torch.Tensor) and device is not None:
data = data.to(device)

return data, orig_type, orig_device
return cast(NdarrayTensor, data), orig_type, orig_device # pytype: disable=invalid-annotation
raise ValueError(f"Unsupported output type: {output_type}")


def convert_to_dst_type(src: Any, dst: NdarrayTensor) -> Tuple[NdarrayTensor, type, Optional[torch.device]]:
Expand Down
Loading