@@ -33,11 +33,11 @@ class DynUNetSkipLayer(nn.Module):
3333
3434 heads : List [torch .Tensor ]
3535
36- def __init__ (self , index , heads , downsample , upsample , super_head , next_layer ):
36+ def __init__ (self , index , heads , downsample , upsample , next_layer , super_head = None ):
3737 super ().__init__ ()
3838 self .downsample = downsample
39- self .upsample = upsample
4039 self .next_layer = next_layer
40+ self .upsample = upsample
4141 self .super_head = super_head
4242 self .heads = heads
4343 self .index = index
@@ -46,8 +46,8 @@ def forward(self, x):
4646 downout = self .downsample (x )
4747 nextout = self .next_layer (downout )
4848 upout = self .upsample (nextout , downout )
49-
50- self .heads [self .index ] = self .super_head (upout )
49+ if self . super_head is not None and self . index > 0 :
50+ self .heads [self .index - 1 ] = self .super_head (upout )
5151
5252 return upout
5353
@@ -57,6 +57,7 @@ class DynUNet(nn.Module):
5757 This reimplementation of a dynamic UNet (DynUNet) is based on:
5858 `Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.
5959 `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation <https://arxiv.org/abs/1809.10486>`_.
60+ `Optimized U-Net for Brain Tumor Segmentation <https://arxiv.org/pdf/2110.03352.pdf>`_.
6061
6162 This model is more flexible compared with ``monai.networks.nets.UNet`` in three
6263 places:
@@ -89,6 +90,12 @@ class DynUNet(nn.Module):
8990 strides: convolution strides for each blocks.
9091 upsample_kernel_size: convolution kernel size for transposed convolution layers. The values should
9192 equal to strides[1:].
93+ filters: number of output channels for each blocks. Different from nnU-Net, in this implementation we add
94+ this argument to make the network more flexible. As shown in the third reference, one way to determine
95+ this argument is like:
96+ ``[64, 96, 128, 192, 256, 384, 512, 768, 1024][: len(strides)]``.
97+ The above way is used in the network that wins task 1 in the BraTS21 Challenge.
98+ If not specified, the way which nnUNet used will be employed. Defaults to ``None``.
9299 dropout: dropout ratio. Defaults to no dropout.
93100 norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``.
94101 deep_supervision: whether to add deep supervision head before output. Defaults to ``False``.
@@ -109,6 +116,7 @@ class DynUNet(nn.Module):
109116 Defaults to 1.
110117 res_block: whether to use residual connection based convolution blocks during the network.
111118 Defaults to ``False``.
119+ trans_bias: whether to set the bias parameter in transposed convolution layers. Defaults to ``False``.
112120 """
113121
114122 def __init__ (
@@ -119,11 +127,13 @@ def __init__(
119127 kernel_size : Sequence [Union [Sequence [int ], int ]],
120128 strides : Sequence [Union [Sequence [int ], int ]],
121129 upsample_kernel_size : Sequence [Union [Sequence [int ], int ]],
130+ filters : Optional [Sequence [int ]] = None ,
122131 dropout : Optional [Union [Tuple , str , float ]] = None ,
123132 norm_name : Union [Tuple , str ] = ("INSTANCE" , {"affine" : True }),
124133 deep_supervision : bool = False ,
125134 deep_supr_num : int = 1 ,
126135 res_block : bool = False ,
136+ trans_bias : bool = False ,
127137 ):
128138 super ().__init__ ()
129139 self .spatial_dims = spatial_dims
@@ -135,21 +145,26 @@ def __init__(
135145 self .norm_name = norm_name
136146 self .dropout = dropout
137147 self .conv_block = UnetResBlock if res_block else UnetBasicBlock
138- self .filters = [min (2 ** (5 + i ), 320 if spatial_dims == 3 else 512 ) for i in range (len (strides ))]
148+ self .trans_bias = trans_bias
149+ if filters is not None :
150+ self .filters = filters
151+ self .check_filters ()
152+ else :
153+ self .filters = [min (2 ** (5 + i ), 320 if spatial_dims == 3 else 512 ) for i in range (len (strides ))]
139154 self .input_block = self .get_input_block ()
140155 self .downsamples = self .get_downsamples ()
141156 self .bottleneck = self .get_bottleneck ()
142157 self .upsamples = self .get_upsamples ()
143158 self .output_block = self .get_output_block (0 )
144159 self .deep_supervision = deep_supervision
145- self .deep_supervision_heads = self .get_deep_supervision_heads ()
146160 self .deep_supr_num = deep_supr_num
161+ self .deep_supervision_heads = self .get_deep_supervision_heads ()
147162 self .apply (self .initialize_weights )
148163 self .check_kernel_stride ()
149164 self .check_deep_supr_num ()
150165
151166 # initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on
152- self .heads : List [torch .Tensor ] = [torch .rand (1 )] * ( len ( self .deep_supervision_heads ) + 1 )
167+ self .heads : List [torch .Tensor ] = [torch .rand (1 )] * self .deep_supr_num
153168
154169 def create_skips (index , downsamples , upsamples , superheads , bottleneck ):
155170 """
@@ -162,22 +177,27 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck):
162177
163178 if len (downsamples ) != len (upsamples ):
164179 raise AssertionError (f"{ len (downsamples )} != { len (upsamples )} " )
165- if (len (downsamples ) - len (superheads )) not in (1 , 0 ):
166- raise AssertionError (f"{ len (downsamples )} -(0,1) != { len (superheads )} " )
167180
168181 if len (downsamples ) == 0 : # bottom of the network, pass the bottleneck block
169182 return bottleneck
183+ super_head_flag = False
170184 if index == 0 : # don't associate a supervision head with self.input_block
171- current_head , rest_heads = nn . Identity (), superheads
185+ rest_heads = superheads
172186 elif not self .deep_supervision : # bypass supervision heads by passing nn.Identity in place of a real one
173- current_head , rest_heads = nn .Identity (), superheads [ 1 :]
187+ rest_heads = nn .ModuleList ()
174188 else :
175- current_head , rest_heads = superheads [0 ], superheads [1 :]
189+ if len (superheads ) > 0 :
190+ super_head_flag = True
191+ rest_heads = superheads [1 :]
192+ else :
193+ rest_heads = nn .ModuleList ()
176194
177195 # create the next layer down, this will stop at the bottleneck layer
178196 next_layer = create_skips (1 + index , downsamples [1 :], upsamples [1 :], rest_heads , bottleneck )
179-
180- return DynUNetSkipLayer (index , self .heads , downsamples [0 ], upsamples [0 ], current_head , next_layer )
197+ if super_head_flag :
198+ return DynUNetSkipLayer (index , self .heads , downsamples [0 ], upsamples [0 ], next_layer , superheads [0 ])
199+ else :
200+ return DynUNetSkipLayer (index , self .heads , downsamples [0 ], upsamples [0 ], next_layer )
181201
182202 self .skip_layers = create_skips (
183203 0 ,
@@ -212,13 +232,19 @@ def check_deep_supr_num(self):
212232 if deep_supr_num < 1 :
213233 raise AssertionError ("deep_supr_num should be larger than 0." )
214234
235+ def check_filters (self ):
236+ filters = self .filters
237+ if len (filters ) < len (self .strides ):
238+ raise AssertionError ("length of filters should be no less than the length of strides." )
239+ else :
240+ self .filters = filters [: len (self .strides )]
241+
215242 def forward (self , x ):
216243 out = self .skip_layers (x )
217244 out = self .output_block (out )
218245 if self .training and self .deep_supervision :
219246 out_all = [out ]
220- feature_maps = self .heads [1 : self .deep_supr_num + 1 ]
221- for feature_map in feature_maps :
247+ for feature_map in self .heads :
222248 out_all .append (interpolate (feature_map , out .shape [2 :]))
223249 return torch .stack (out_all , dim = 1 )
224250 return out
@@ -257,7 +283,9 @@ def get_upsamples(self):
257283 inp , out = self .filters [1 :][::- 1 ], self .filters [:- 1 ][::- 1 ]
258284 strides , kernel_size = self .strides [1 :][::- 1 ], self .kernel_size [1 :][::- 1 ]
259285 upsample_kernel_size = self .upsample_kernel_size [::- 1 ]
260- return self .get_module_list (inp , out , kernel_size , strides , UnetUpBlock , upsample_kernel_size )
286+ return self .get_module_list (
287+ inp , out , kernel_size , strides , UnetUpBlock , upsample_kernel_size , trans_bias = self .trans_bias
288+ )
261289
262290 def get_module_list (
263291 self ,
@@ -267,6 +295,7 @@ def get_module_list(
267295 strides : Sequence [Union [Sequence [int ], int ]],
268296 conv_block : nn .Module ,
269297 upsample_kernel_size : Optional [Sequence [Union [Sequence [int ], int ]]] = None ,
298+ trans_bias : bool = False ,
270299 ):
271300 layers = []
272301 if upsample_kernel_size is not None :
@@ -282,6 +311,7 @@ def get_module_list(
282311 "norm_name" : self .norm_name ,
283312 "dropout" : self .dropout ,
284313 "upsample_kernel_size" : up_kernel ,
314+ "trans_bias" : trans_bias ,
285315 }
286316 layer = conv_block (** params )
287317 layers .append (layer )
@@ -301,7 +331,9 @@ def get_module_list(
301331 return nn .ModuleList (layers )
302332
303333 def get_deep_supervision_heads (self ):
304- return nn .ModuleList ([self .get_output_block (i + 1 ) for i in range (len (self .upsamples ) - 1 )])
334+ if not self .deep_supervision :
335+ return nn .ModuleList ()
336+ return nn .ModuleList ([self .get_output_block (i + 1 ) for i in range (self .deep_supr_num )])
305337
306338 @staticmethod
307339 def initialize_weights (module ):
0 commit comments