-
Notifications
You must be signed in to change notification settings - Fork 85
Expand file tree
/
Copy pathlocal_net.py
More file actions
237 lines (208 loc) · 8.03 KB
/
local_net.py
File metadata and controls
237 lines (208 loc) · 8.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
# coding=utf-8
from typing import List, Optional, Tuple, Union
import tensorflow as tf
import tensorflow.keras.layers as tfkl
from deepreg.model import layer
from deepreg.model.backbone.u_net import UNet
from deepreg.model.layer import Extraction
from deepreg.registry import REGISTRY
class AdditiveUpsampling(tfkl.Layer):
    """
    Additive up-sampling layer.

    The output is the sum of two branches computed from the same input:

    - a learnable transposed-convolution branch (``layer.Deconv3dBlock``);
    - a parameter-free branch that resizes the input to ``output_shape``
      and then adds the two halves of its axis 4 together.
    """

    def __init__(
        self,
        filters: int,
        output_padding: Union[int, Tuple, List],
        kernel_size: Union[int, Tuple, List],
        padding: str,
        strides: Union[int, Tuple, List],
        output_shape: Tuple,
        name: str = "AdditiveUpsampling",
        **kwargs,
    ):
        """
        Init.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor
        :param name: name of the layer.
        :param kwargs: additional arguments forwarded to ``tfkl.Layer``
            (e.g. ``trainable``, ``dtype``). Accepting them lets the dict
            returned by ``get_config`` (which starts from the base Layer
            config) be passed back to ``from_config`` without a TypeError.
        """
        super().__init__(name=name, **kwargs)
        self.deconv3d = layer.Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )
        self.resize = layer.Resize3d(shape=output_shape)

    def call(self, inputs, **kwargs):
        """
        Up-sample ``inputs`` with both branches and return their sum.

        :param inputs: input tensor. The resize branch splits it into two
            equal parts along axis 4, so that axis must have an even size.
            Presumably the tensor is (batch, dim1, dim2, dim3, channels) with
            ``2 * filters`` channels so both branches can be added — TODO confirm.
        :param kwargs: additional arguments, unused.
        :return: deconvolution output plus the resized-and-halved input
        """
        deconved = self.deconv3d(inputs)
        resized = self.resize(inputs)
        # Parameter-free channel reduction: split axis 4 into two equal
        # halves and add them element-wise so the shape matches ``deconved``.
        resized = tf.add_n(tf.split(resized, num_or_size_splits=2, axis=4))
        return deconved + resized

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        # The constructor arguments are stored on the sub-layers, so read
        # them back from there rather than duplicating them on this layer.
        deconv_config = self.deconv3d.get_config()
        config.update(
            filters=deconv_config["filters"],
            output_padding=deconv_config["output_padding"],
            kernel_size=deconv_config["kernel_size"],
            strides=deconv_config["strides"],
            padding=deconv_config["padding"],
        )
        config.update(output_shape=self.resize._shape)
        return config
@REGISTRY.register_backbone(name="local")
class LocalNet(UNet):
    """
    Build LocalNet for image registration.

    Reference:

    - Hu, Yipeng, et al.
      "Weakly-supervised convolutional neural networks
      for multimodal image registration."
      Medical image analysis 49 (2018): 1-13.
      https://doi.org/10.1016/j.media.2018.07.002

    - Hu, Yipeng, et al.
      "Label-driven weakly-supervised learning
      for multimodal deformable image registration,"
      https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: int,
        extract_levels: Tuple[int, ...],
        out_kernel_initializer: str,
        out_activation: str,
        out_channels: int,
        depth: Optional[int] = None,
        use_additive_upsampling: bool = True,
        pooling: bool = True,
        concat_skip: bool = False,
        name: str = "LocalNet",
        **kwargs,
    ):
        """
        Init.

        Image is encoded gradually, i from level 0 to D,
        then it is decoded gradually, j from level D to 0.
        Some of the decoded levels are used for generating extractions.
        So, extract_levels are between [0, D].

        :param image_size: such as (dim1, dim2, dim3)
        :param num_channel_initial: number of initial channels.
        :param extract_levels: decoder levels whose outputs are used
            to build the extractions.
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :param out_channels: number of channels for the extractions
        :param depth: depth of the encoder.
            If depth is not given, depth = max(extract_levels) will be used.
        :param use_additive_upsampling: whether to use the additive
            up-sampling layer for decoding.
        :param pooling: for down-sampling, use non-parameterized
            pooling if true, otherwise use conv3d
        :param concat_skip: when up-sampling, concatenate skipped
            tensor if true, otherwise use addition
        :param name: name of the backbone.
        :param kwargs: additional arguments forwarded to UNet. Any
            ``encode_kernel_sizes`` entry is always overwritten.
        :raises ValueError: if depth is None and extract_levels is empty.
        """
        # Set before super().__init__: build_up_sampling_block reads it and is
        # presumably invoked while the parent constructor builds the decoder.
        self._use_additive_upsampling = use_additive_upsampling
        if depth is None:
            if not extract_levels:
                raise ValueError(
                    "extract_levels must be non-empty when depth is not given, "
                    f"got extract_levels={extract_levels}"
                )
            depth = max(extract_levels)
        # LocalNet uses kernel size 7 for the first encoding level and 3 for
        # the remaining `depth` levels; this overrides any user-provided value.
        kwargs["encode_kernel_sizes"] = [7] + [3] * depth
        super().__init__(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=extract_levels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
            out_channels=out_channels,
            pooling=pooling,
            concat_skip=concat_skip,
            name=name,
            **kwargs,
        )

    def build_bottom_block(
        self, filters: int, kernel_size: int, padding: str
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for bottom layer.

        This block does not change the tensor shape (width, height, depth);
        it only changes the number of channels.

        :param filters: number of channels for output
        :param kernel_size: arg for conv3d
        :param padding: arg for conv3d
        :return: a block consists of one or multiple layers
        """
        return layer.Conv3dBlock(
            filters=filters, kernel_size=kernel_size, padding=padding
        )

    def build_up_sampling_block(
        self,
        filters: int,
        output_padding: Union[Tuple[int, ...], int],
        kernel_size: Union[Tuple[int, ...], int],
        padding: str,
        strides: Union[Tuple[int, ...], int],
        output_shape: Tuple[int, ...],
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for up-sampling.

        This block changes the tensor shape (width, height, depth) and
        outputs ``filters`` channels.

        :param filters: number of channels for output
        :param output_padding: padding for output
        :param kernel_size: arg for deconv3d
        :param padding: arg for deconv3d
        :param strides: arg for deconv3d
        :param output_shape: shape of the output tensor; only used by the
            additive up-sampling layer.
        :return: a block consists of one or multiple layers
        """
        if self._use_additive_upsampling:
            return AdditiveUpsampling(
                filters=filters,
                output_padding=output_padding,
                kernel_size=kernel_size,
                strides=strides,
                padding=padding,
                output_shape=output_shape,
            )
        return layer.Deconv3dBlock(
            filters=filters,
            output_padding=output_padding,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
        )

    def build_output_block(
        self,
        image_size: Tuple[int, ...],
        extract_levels: Tuple[int, ...],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for output.

        The input to this block is a list of tensors.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: decoder levels used to build the extractions.
        :param out_channels: number of channels for the extractions
        :param out_kernel_initializer: initializer to use for kernels.
        :param out_activation: activation to use at end layer.
        :return: a block consists of one or multiple layers
        """
        return Extraction(
            image_size=image_size,
            extract_levels=extract_levels,
            out_channels=out_channels,
            out_kernel_initializer=out_kernel_initializer,
            out_activation=out_activation,
        )

    def get_config(self) -> dict:
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        # The only LocalNet-specific constructor argument not stored by UNet.
        config.update(use_additive_upsampling=self._use_additive_upsampling)
        return config