Skip to content

Commit 574fe60

Browse files
enhance dynunet based on brats21 1st solution
Signed-off-by: Yiheng Wang <[email protected]>
1 parent 27c9b0a commit 574fe60

File tree

4 files changed

+76
-38
lines changed

4 files changed

+76
-38
lines changed

monai/networks/blocks/dynunet_block.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class UnetResBlock(nn.Module):
3333
kernel_size: convolution kernel size.
3434
stride: convolution stride.
3535
norm_name: feature normalization type and arguments.
36-
dropout: dropout probability
36+
dropout: dropout probability.
3737
3838
"""
3939

@@ -100,7 +100,7 @@ class UnetBasicBlock(nn.Module):
100100
kernel_size: convolution kernel size.
101101
stride: convolution stride.
102102
norm_name: feature normalization type and arguments.
103-
dropout: dropout probability
103+
dropout: dropout probability.
104104
105105
"""
106106

@@ -155,7 +155,8 @@ class UnetUpBlock(nn.Module):
155155
stride: convolution stride.
156156
upsample_kernel_size: convolution kernel size for transposed convolution layers.
157157
norm_name: feature normalization type and arguments.
158-
dropout: dropout probability
158+
dropout: dropout probability.
159+
trans_bias: transposed convolution bias.
159160
160161
"""
161162

@@ -169,6 +170,7 @@ def __init__(
169170
upsample_kernel_size: Union[Sequence[int], int],
170171
norm_name: Union[Tuple, str],
171172
dropout: Optional[Union[Tuple, str, float]] = None,
173+
trans_bias: bool = False,
172174
):
173175
super().__init__()
174176
upsample_stride = upsample_kernel_size
@@ -179,6 +181,7 @@ def __init__(
179181
kernel_size=upsample_kernel_size,
180182
stride=upsample_stride,
181183
dropout=dropout,
184+
bias=trans_bias,
182185
conv_only=True,
183186
is_transposed=True,
184187
)

monai/networks/nets/dynunet.py

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ class DynUNetSkipLayer(nn.Module):
3333

3434
heads: List[torch.Tensor]
3535

36-
def __init__(self, index, heads, downsample, upsample, super_head, next_layer):
36+
def __init__(self, index, heads, downsample, upsample, next_layer, super_head=None):
3737
super().__init__()
3838
self.downsample = downsample
39-
self.upsample = upsample
4039
self.next_layer = next_layer
40+
self.upsample = upsample
4141
self.super_head = super_head
4242
self.heads = heads
4343
self.index = index
@@ -46,8 +46,8 @@ def forward(self, x):
4646
downout = self.downsample(x)
4747
nextout = self.next_layer(downout)
4848
upout = self.upsample(nextout, downout)
49-
50-
self.heads[self.index] = self.super_head(upout)
49+
if self.super_head is not None and self.index > 0:
50+
self.heads[self.index - 1] = self.super_head(upout)
5151

5252
return upout
5353

@@ -57,6 +57,7 @@ class DynUNet(nn.Module):
5757
This reimplementation of a dynamic UNet (DynUNet) is based on:
5858
`Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.
5959
`nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation <https://arxiv.org/abs/1809.10486>`_.
60+
`Optimized U-Net for Brain Tumor Segmentation <https://arxiv.org/pdf/2110.03352.pdf>`_.
6061
6162
This model is more flexible compared with ``monai.networks.nets.UNet`` in three
6263
places:
@@ -89,6 +90,12 @@ class DynUNet(nn.Module):
8990
strides: convolution strides for each block.
9091
upsample_kernel_size: convolution kernel size for transposed convolution layers. The values should
9192
be equal to strides[1:].
93+
filters: number of output channels for each block. Different from nnU-Net, in this implementation we add
94+
this argument to make the network more flexible. As shown in the third reference, one way to determine
95+
this argument is:
96+
``[64, 96, 128, 192, 256, 384, 512, 768, 1024][: len(strides)]``.
97+
This scheme is used in the network that won task 1 of the BraTS21 Challenge.
98+
If not specified, the scheme used by nnU-Net is employed. Defaults to ``None``.
9299
dropout: dropout ratio. Defaults to no dropout.
93100
norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``.
94101
deep_supervision: whether to add deep supervision head before output. Defaults to ``False``.
@@ -109,6 +116,7 @@ class DynUNet(nn.Module):
109116
Defaults to 1.
110117
res_block: whether to use residual connection based convolution blocks throughout the network.
111118
Defaults to ``False``.
119+
trans_bias: whether to set the bias parameter in transposed convolution layers. Defaults to ``False``.
112120
"""
113121

114122
def __init__(
@@ -119,11 +127,13 @@ def __init__(
119127
kernel_size: Sequence[Union[Sequence[int], int]],
120128
strides: Sequence[Union[Sequence[int], int]],
121129
upsample_kernel_size: Sequence[Union[Sequence[int], int]],
130+
filters: Optional[Sequence[int]] = None,
122131
dropout: Optional[Union[Tuple, str, float]] = None,
123132
norm_name: Union[Tuple, str] = ("INSTANCE", {"affine": True}),
124133
deep_supervision: bool = False,
125134
deep_supr_num: int = 1,
126135
res_block: bool = False,
136+
trans_bias: bool = False,
127137
):
128138
super().__init__()
129139
self.spatial_dims = spatial_dims
@@ -135,21 +145,26 @@ def __init__(
135145
self.norm_name = norm_name
136146
self.dropout = dropout
137147
self.conv_block = UnetResBlock if res_block else UnetBasicBlock
138-
self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))]
148+
self.trans_bias = trans_bias
149+
if filters is not None:
150+
self.filters = filters
151+
self.check_filters()
152+
else:
153+
self.filters = [min(2 ** (5 + i), 320 if spatial_dims == 3 else 512) for i in range(len(strides))]
139154
self.input_block = self.get_input_block()
140155
self.downsamples = self.get_downsamples()
141156
self.bottleneck = self.get_bottleneck()
142157
self.upsamples = self.get_upsamples()
143158
self.output_block = self.get_output_block(0)
144159
self.deep_supervision = deep_supervision
145-
self.deep_supervision_heads = self.get_deep_supervision_heads()
146160
self.deep_supr_num = deep_supr_num
161+
self.deep_supervision_heads = self.get_deep_supervision_heads()
147162
self.apply(self.initialize_weights)
148163
self.check_kernel_stride()
149164
self.check_deep_supr_num()
150165

151166
# initialize the typed list of supervision head outputs so that Torchscript can recognize what's going on
152-
self.heads: List[torch.Tensor] = [torch.rand(1)] * (len(self.deep_supervision_heads) + 1)
167+
self.heads: List[torch.Tensor] = [torch.rand(1)] * self.deep_supr_num
153168

154169
def create_skips(index, downsamples, upsamples, superheads, bottleneck):
155170
"""
@@ -162,22 +177,27 @@ def create_skips(index, downsamples, upsamples, superheads, bottleneck):
162177

163178
if len(downsamples) != len(upsamples):
164179
raise AssertionError(f"{len(downsamples)} != {len(upsamples)}")
165-
if (len(downsamples) - len(superheads)) not in (1, 0):
166-
raise AssertionError(f"{len(downsamples)}-(0,1) != {len(superheads)}")
167180

168181
if len(downsamples) == 0: # bottom of the network, pass the bottleneck block
169182
return bottleneck
183+
super_head_flag = False
170184
if index == 0: # don't associate a supervision head with self.input_block
171-
current_head, rest_heads = nn.Identity(), superheads
185+
rest_heads = superheads
172186
elif not self.deep_supervision: # bypass supervision heads by passing nn.Identity in place of a real one
173-
current_head, rest_heads = nn.Identity(), superheads[1:]
187+
rest_heads = nn.ModuleList()
174188
else:
175-
current_head, rest_heads = superheads[0], superheads[1:]
189+
if len(superheads) > 0:
190+
super_head_flag = True
191+
rest_heads = superheads[1:]
192+
else:
193+
rest_heads = nn.ModuleList()
176194

177195
# create the next layer down, this will stop at the bottleneck layer
178196
next_layer = create_skips(1 + index, downsamples[1:], upsamples[1:], rest_heads, bottleneck)
179-
180-
return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], current_head, next_layer)
197+
if super_head_flag:
198+
return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], next_layer, superheads[0])
199+
else:
200+
return DynUNetSkipLayer(index, self.heads, downsamples[0], upsamples[0], next_layer)
181201

182202
self.skip_layers = create_skips(
183203
0,
@@ -212,13 +232,19 @@ def check_deep_supr_num(self):
212232
if deep_supr_num < 1:
213233
raise AssertionError("deep_supr_num should be larger than 0.")
214234

235+
def check_filters(self):
236+
filters = self.filters
237+
if len(filters) < len(self.strides):
238+
raise AssertionError("length of filters should be no less than the length of strides.")
239+
else:
240+
self.filters = filters[: len(self.strides)]
241+
215242
def forward(self, x):
216243
out = self.skip_layers(x)
217244
out = self.output_block(out)
218245
if self.training and self.deep_supervision:
219246
out_all = [out]
220-
feature_maps = self.heads[1 : self.deep_supr_num + 1]
221-
for feature_map in feature_maps:
247+
for feature_map in self.heads:
222248
out_all.append(interpolate(feature_map, out.shape[2:]))
223249
return torch.stack(out_all, dim=1)
224250
return out
@@ -257,7 +283,9 @@ def get_upsamples(self):
257283
inp, out = self.filters[1:][::-1], self.filters[:-1][::-1]
258284
strides, kernel_size = self.strides[1:][::-1], self.kernel_size[1:][::-1]
259285
upsample_kernel_size = self.upsample_kernel_size[::-1]
260-
return self.get_module_list(inp, out, kernel_size, strides, UnetUpBlock, upsample_kernel_size)
286+
return self.get_module_list(
287+
inp, out, kernel_size, strides, UnetUpBlock, upsample_kernel_size, trans_bias=self.trans_bias
288+
)
261289

262290
def get_module_list(
263291
self,
@@ -267,6 +295,7 @@ def get_module_list(
267295
strides: Sequence[Union[Sequence[int], int]],
268296
conv_block: nn.Module,
269297
upsample_kernel_size: Optional[Sequence[Union[Sequence[int], int]]] = None,
298+
trans_bias: bool = False,
270299
):
271300
layers = []
272301
if upsample_kernel_size is not None:
@@ -282,6 +311,7 @@ def get_module_list(
282311
"norm_name": self.norm_name,
283312
"dropout": self.dropout,
284313
"upsample_kernel_size": up_kernel,
314+
"trans_bias": trans_bias,
285315
}
286316
layer = conv_block(**params)
287317
layers.append(layer)
@@ -301,7 +331,9 @@ def get_module_list(
301331
return nn.ModuleList(layers)
302332

303333
def get_deep_supervision_heads(self):
304-
return nn.ModuleList([self.get_output_block(i + 1) for i in range(len(self.upsamples) - 1)])
334+
if not self.deep_supervision:
335+
return nn.ModuleList()
336+
return nn.ModuleList([self.get_output_block(i + 1) for i in range(self.deep_supr_num)])
305337

306338
@staticmethod
307339
def initialize_weights(module):

tests/test_dynunet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,9 @@
6666
"kernel_size": (3, (1, 1, 3), 3, 3),
6767
"strides": ((1, 2, 1), 2, 2, 1),
6868
"upsample_kernel_size": (2, 2, 1),
69+
"filters": (64, 96, 128, 192),
6970
"norm_name": ("INSTANCE", {"affine": True}),
70-
"deep_supervision": False,
71+
"deep_supervision": True,
7172
"res_block": res_block,
7273
"dropout": ("alphadropout", {"p": 0.25}),
7374
},

tests/test_dynunet_block.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,22 +49,24 @@
4949
for stride in [1, 2]:
5050
for norm_name in ["batch", "instance"]:
5151
for in_size in [15, 16]:
52-
out_size = in_size * stride
53-
test_case = [
54-
{
55-
"spatial_dims": spatial_dims,
56-
"in_channels": in_channels,
57-
"out_channels": out_channels,
58-
"kernel_size": kernel_size,
59-
"norm_name": norm_name,
60-
"stride": stride,
61-
"upsample_kernel_size": stride,
62-
},
63-
(1, in_channels, *([in_size] * spatial_dims)),
64-
(1, out_channels, *([out_size] * spatial_dims)),
65-
(1, out_channels, *([in_size * stride] * spatial_dims)),
66-
]
67-
TEST_UP_BLOCK.append(test_case)
52+
for trans_bias in [True, False]:
53+
out_size = in_size * stride
54+
test_case = [
55+
{
56+
"spatial_dims": spatial_dims,
57+
"in_channels": in_channels,
58+
"out_channels": out_channels,
59+
"kernel_size": kernel_size,
60+
"norm_name": norm_name,
61+
"stride": stride,
62+
"upsample_kernel_size": stride,
63+
"trans_bias": trans_bias,
64+
},
65+
(1, in_channels, *([in_size] * spatial_dims)),
66+
(1, out_channels, *([out_size] * spatial_dims)),
67+
(1, out_channels, *([in_size * stride] * spatial_dims)),
68+
]
69+
TEST_UP_BLOCK.append(test_case)
6870

6971

7072
class TestResBasicBlock(unittest.TestCase):

0 commit comments

Comments
 (0)