99# See the License for the specific language governing permissions and
1010# limitations under the License.
1111
12- from typing import Sequence , Tuple , Type , Union
12+ from typing import Optional , Sequence , Tuple , Type , Union
1313
1414import numpy as np
1515import torch
2121from monai .networks .blocks import MLPBlock as Mlp
2222from monai .networks .blocks import PatchEmbed , UnetOutBlock , UnetrBasicBlock , UnetrUpBlock
2323from monai .networks .layers import DropPath , trunc_normal_
24- from monai .utils import ensure_tuple_rep , optional_import
24+ from monai .utils import ensure_tuple_rep , look_up_option , optional_import
2525
2626rearrange , _ = optional_import ("einops" , name = "rearrange" )
2727
28+ __all__ = [
29+ "SwinUNETR" ,
30+ "window_partition" ,
31+ "window_reverse" ,
32+ "WindowAttention" ,
33+ "SwinTransformerBlock" ,
34+ "PatchMerging" ,
35+ "PatchMergingV2" ,
36+ "MERGING_MODE" ,
37+ "BasicLayer" ,
38+ "SwinTransformer" ,
39+ ]
40+
2841
2942class SwinUNETR (nn .Module ):
3043 """
@@ -48,6 +61,7 @@ def __init__(
4861 normalize : bool = True ,
4962 use_checkpoint : bool = False ,
5063 spatial_dims : int = 3 ,
64+ downsample = "merging" ,
5165 ) -> None :
5266 """
5367 Args:
@@ -64,6 +78,9 @@ def __init__(
6478 normalize: normalize output intermediate features in each stage.
6579 use_checkpoint: use gradient checkpointing for reduced memory usage.
6680 spatial_dims: number of spatial dims.
81+ downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
82+ user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
83+ The default is currently `"merging"` (the original version defined in v0.9.0).
6784
6885 Examples::
6986
@@ -121,6 +138,7 @@ def __init__(
121138 norm_layer = nn .LayerNorm ,
122139 use_checkpoint = use_checkpoint ,
123140 spatial_dims = spatial_dims ,
141+ downsample = look_up_option (downsample , MERGING_MODE ) if isinstance (downsample , str ) else downsample ,
124142 )
125143
126144 self .encoder1 = UnetrBasicBlock (
@@ -657,7 +675,7 @@ def forward(self, x, mask_matrix):
657675 return x
658676
659677
660- class PatchMerging (nn .Module ):
678+ class PatchMergingV2 (nn .Module ):
661679 """
662680 Patch merging layer based on: "Liu et al.,
663681 Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
@@ -695,8 +713,8 @@ def forward(self, x):
695713 x2 = x [:, 0 ::2 , 1 ::2 , 0 ::2 , :]
696714 x3 = x [:, 0 ::2 , 0 ::2 , 1 ::2 , :]
697715 x4 = x [:, 1 ::2 , 0 ::2 , 1 ::2 , :]
698- x5 = x [:, 0 ::2 , 1 ::2 , 0 ::2 , :]
699- x6 = x [:, 0 ::2 , 0 ::2 , 1 ::2 , :]
716+ x5 = x [:, 1 ::2 , 1 ::2 , 0 ::2 , :]
717+ x6 = x [:, 0 ::2 , 1 ::2 , 1 ::2 , :]
700718 x7 = x [:, 1 ::2 , 1 ::2 , 1 ::2 , :]
701719 x = torch .cat ([x0 , x1 , x2 , x3 , x4 , x5 , x6 , x7 ], - 1 )
702720
@@ -716,6 +734,36 @@ def forward(self, x):
716734 return x
717735
718736
737+ class PatchMerging (PatchMergingV2 ):
738+ """The `PatchMerging` module previously defined in v0.9.0."""
739+
740+ def forward (self , x ):
741+ x_shape = x .size ()
742+ if len (x_shape ) == 4 :
743+ return super ().forward (x )
744+ if len (x_shape ) != 5 :
745+ raise ValueError (f"expecting 5D x, got { x .shape } ." )
746+ b , d , h , w , c = x_shape
747+ pad_input = (h % 2 == 1 ) or (w % 2 == 1 ) or (d % 2 == 1 )
748+ if pad_input :
749+ x = F .pad (x , (0 , 0 , 0 , w % 2 , 0 , h % 2 , 0 , d % 2 ))
750+ x0 = x [:, 0 ::2 , 0 ::2 , 0 ::2 , :]
751+ x1 = x [:, 1 ::2 , 0 ::2 , 0 ::2 , :]
752+ x2 = x [:, 0 ::2 , 1 ::2 , 0 ::2 , :]
753+ x3 = x [:, 0 ::2 , 0 ::2 , 1 ::2 , :]
754+ x4 = x [:, 1 ::2 , 0 ::2 , 1 ::2 , :]
755+ x5 = x [:, 0 ::2 , 1 ::2 , 0 ::2 , :]
756+ x6 = x [:, 0 ::2 , 0 ::2 , 1 ::2 , :]
757+ x7 = x [:, 1 ::2 , 1 ::2 , 1 ::2 , :]
758+ x = torch .cat ([x0 , x1 , x2 , x3 , x4 , x5 , x6 , x7 ], - 1 )
759+ x = self .norm (x )
760+ x = self .reduction (x )
761+ return x
762+
763+
764+ MERGING_MODE = {"merging" : PatchMerging , "mergingv2" : PatchMergingV2 }
765+
766+
719767def compute_mask (dims , window_size , shift_size , device ):
720768 """Computing region masks based on: "Liu et al.,
721769 Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
@@ -776,7 +824,7 @@ def __init__(
776824 drop : float = 0.0 ,
777825 attn_drop : float = 0.0 ,
778826 norm_layer : Type [LayerNorm ] = nn .LayerNorm ,
779- downsample : isinstance = None , # type: ignore
827+ downsample : Optional [ nn . Module ] = None ,
780828 use_checkpoint : bool = False ,
781829 ) -> None :
782830 """
@@ -791,7 +839,7 @@ def __init__(
791839 drop: dropout rate.
792840 attn_drop: attention dropout rate.
793841 norm_layer: normalization layer.
794- downsample: downsample layer at the end of the layer.
842+ downsample: an optional downsampling layer at the end of the layer.
795843 use_checkpoint: use gradient checkpointing for reduced memory usage.
796844 """
797845
@@ -820,7 +868,7 @@ def __init__(
820868 ]
821869 )
822870 self .downsample = downsample
823- if self .downsample is not None :
871+ if callable ( self .downsample ) :
824872 self .downsample = downsample (dim = dim , norm_layer = norm_layer , spatial_dims = len (self .window_size ))
825873
826874 def forward (self , x ):
@@ -881,6 +929,7 @@ def __init__(
881929 patch_norm : bool = False ,
882930 use_checkpoint : bool = False ,
883931 spatial_dims : int = 3 ,
932+ downsample = "merging" ,
884933 ) -> None :
885934 """
886935 Args:
@@ -899,6 +948,9 @@ def __init__(
899948 patch_norm: add normalization after patch embedding.
900949 use_checkpoint: use gradient checkpointing for reduced memory usage.
901950 spatial_dims: spatial dimension.
951+ downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
952+ user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
953+ The default is currently `"merging"` (the original version defined in v0.9.0).
902954 """
903955
904956 super ().__init__ ()
@@ -920,6 +972,7 @@ def __init__(
920972 self .layers2 = nn .ModuleList ()
921973 self .layers3 = nn .ModuleList ()
922974 self .layers4 = nn .ModuleList ()
975+ down_sample_mod = look_up_option (downsample , MERGING_MODE ) if isinstance (downsample , str ) else downsample
923976 for i_layer in range (self .num_layers ):
924977 layer = BasicLayer (
925978 dim = int (embed_dim * 2 ** i_layer ),
@@ -932,7 +985,7 @@ def __init__(
932985 drop = drop_rate ,
933986 attn_drop = attn_drop_rate ,
934987 norm_layer = norm_layer ,
935- downsample = PatchMerging ,
988+ downsample = down_sample_mod ,
936989 use_checkpoint = use_checkpoint ,
937990 )
938991 if i_layer == 0 :
0 commit comments