-
Notifications
You must be signed in to change notification settings - Fork 85
Expand file tree
/
Copy pathglobal_net.py
More file actions
141 lines (121 loc) · 4.74 KB
/
global_net.py
File metadata and controls
141 lines (121 loc) · 4.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# coding=utf-8
from typing import List, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as tfkl
from deepreg.model import layer_util
from deepreg.model.backbone.u_net import UNet
from deepreg.registry import REGISTRY
class AffineHead(tfkl.Layer):
    """
    Output head predicting affine parameters and the corresponding DDF.

    The input features are flattened, mapped by a dense layer to 12 values,
    reshaped to (batch, 4, 3) and used to warp a fixed reference grid.
    """

    def __init__(
        self,
        image_size: tuple,
        name: str = "AffineHead",
    ):
        """
        Create the reference grid and the sub-layers of the head.

        :param image_size: such as (dim1, dim2, dim3)
        :param name: name of the layer
        """
        super().__init__(name=name)
        self.reference_grid = layer_util.get_reference_grid(image_size)

        # The dense bias starts from the 12 entries of a flattened eye(4, 3),
        # presumably the identity transform under warp_grid's convention.
        initial_theta = np.eye(4, 3).ravel()
        self.transform_initial = tf.constant_initializer(value=list(initial_theta))

        self._flatten = tfkl.Flatten()
        self._dense = tfkl.Dense(units=12, bias_initializer=self.transform_initial)

    def call(
        self, inputs: Union[tf.Tensor, List], **kwargs
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Predict the affine parameters and derive the DDF from them.

        :param inputs: a tensor or a list of tensor with length 1
        :param kwargs: additional args
        :return: ddf and theta

            - ddf has shape (batch, dim1, dim2, dim3, 3)
            - theta has shape (batch, 4, 3)
        """
        # A list input is unwrapped to its first element.
        features = inputs[0] if isinstance(inputs, list) else inputs
        flat_features = self._flatten(features)
        theta = tf.reshape(self._dense(flat_features), shape=(-1, 4, 3))

        # The DDF is the displacement between warped and reference grid.
        warped_grid = layer_util.warp_grid(self.reference_grid, theta)
        return warped_grid - self.reference_grid, theta

    def get_config(self):
        """Return the config dictionary for recreating this class."""
        config = super().get_config()
        # The image size is recovered from the first three grid dimensions.
        config["image_size"] = self.reference_grid.shape[:3]
        return config
@REGISTRY.register_backbone(name="global")
class GlobalNet(UNet):
    """
    Build GlobalNet for image registration.

    GlobalNet is a special UNet where the decoder for up-sampling is skipped.
    The network's outputs come from the bottom layer from the encoder directly.

    Reference:

    - Hu, Yipeng, et al.
      "Label-driven weakly-supervised learning
      for multimodal deformable image registration,"
      https://arxiv.org/abs/1711.01666
    """

    def __init__(
        self,
        image_size: tuple,
        num_channel_initial: int,
        extract_levels: Optional[Tuple[int, ...]] = None,
        depth: Optional[int] = None,
        name: str = "GlobalNet",
        **kwargs,
    ):
        """
        Image is encoded gradually, i from level 0 to E.
        Then, a densely-connected layer outputs an affine
        transformation.

        :param image_size: tuple, such as (dim1, dim2, dim3)
        :param num_channel_initial: int, number of initial channels
        :param extract_levels: list, which levels from net to extract, deprecated.
            If depth is not given, depth = max(extract_levels) will be used.
        :param depth: depth of the encoder. If given, extract_levels is not used.
        :param name: name of the backbone.
        :param kwargs: additional arguments.
        :raises ValueError: if `depth` is None and `extract_levels`
            is None or empty.
        """
        if depth is None:
            # `default=None` makes an empty `extract_levels` fall through to
            # the descriptive error below instead of the bare
            # "max() arg is an empty sequence" raised by `max`.
            if extract_levels is not None:
                depth = max(extract_levels, default=None)
            if depth is None:
                raise ValueError(
                    "GlobalNet requires `depth` or `extract_levels` "
                    "to define the depth of encoder. "
                    "If `depth` is not given, "
                    "the maximum value of `extract_levels` will be used. "
                    "However the argument `extract_levels` is deprecated "
                    "and will be removed in future release."
                )
        # Only the deepest level is passed on as the level to extract.
        super().__init__(
            image_size=image_size,
            num_channel_initial=num_channel_initial,
            depth=depth,
            extract_levels=(depth,),
            name=name,
            **kwargs,
        )

    def build_output_block(
        self,
        image_size: Tuple[int, ...],
        extract_levels: Tuple[int, ...],
        out_channels: int,
        out_kernel_initializer: str,
        out_activation: str,
    ) -> Union[tf.keras.Model, tfkl.Layer]:
        """
        Build a block for output.

        The input to this block is a list of length 1.
        The output has two tensors.

        :param image_size: such as (dim1, dim2, dim3)
        :param extract_levels: not used
        :param out_channels: not used
        :param out_kernel_initializer: not used
        :param out_activation: not used
        :return: a block consists of one or multiple layers
        """
        # Only image_size is needed; the other arguments are accepted but
        # ignored.
        return AffineHead(image_size=image_size)