Skip to content

Commit e47c569

Browse files
authored
Merge pull request #719 from DeepRegNet/690-nan-inf-loss
690 nan inf loss
2 parents c48b78b + 6ad3312 commit e47c569

File tree

11 files changed

+73
-67
lines changed

11 files changed

+73
-67
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ compatible with the updates.
2424

2525
### Changed
2626

27+
- Increased all EPS to 1e-5.
2728
- Clarify the suggestion in doc to use all-zero masks for missing labels.
2829
- Moved contributor list to a separate page.
2930
- Changed `no-test` flag to `full` for demo scripts.
@@ -37,6 +38,7 @@ compatible with the updates.
3738

3839
### Fixed
3940

41+
- Fixed LNCC loss regarding INF values.
4042
- Removed loss weight checks to be more robust.
4143
- Fixed import error under python 3.6.
4244
- Fixed the residual module in local net architecture, compatible for previous

deepreg/constant.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
"""Module defining global constants."""
2+
3+
EPS = 1.0e-5

deepreg/dataset/loader/interface.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -379,10 +379,10 @@ def validate_images_and_labels(
379379
for arr, name in zip(
380380
[moving_image, fixed_image], ["moving_image", "fixed_image"]
381381
):
382-
if len(arr.shape) != 3:
382+
if len(arr.shape) != 3 or min(arr.shape) <= 0:
383383
raise ValueError(
384-
f"Sample {image_indices}'s {name}' shape should be 3D. "
385-
f"Got {arr.shape}."
384+
f"Sample {image_indices}'s {name}' shape should be 3D"
385+
f" and non-empty, got {arr.shape}."
386386
)
387387
# when data are labeled
388388
if moving_label is not None and fixed_label is not None:

deepreg/loss/image.py

Lines changed: 32 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Provide different loss or metrics classes for images."""
22
import tensorflow as tf
33

4+
from deepreg.constant import EPS
45
from deepreg.loss.util import NegativeLossMixin
56
from deepreg.loss.util import gaussian_kernel1d_size as gaussian_kernel1d
67
from deepreg.loss.util import (
@@ -10,8 +11,6 @@
1011
)
1112
from deepreg.registry import REGISTRY
1213

13-
EPS = tf.keras.backend.epsilon()
14-
1514

1615
@REGISTRY.register_loss(name="ssd")
1716
class SumSquaredDifference(tf.keras.losses.Loss):
@@ -156,27 +155,33 @@ class LocalNormalizedCrossCorrelation(tf.keras.losses.Loss):
156155
157156
E[t] = sum_i(w_i * t_i) / sum_i(w_i)
158157
158+
Here, we assume sum_i(w_i) == 1, means the weights have been normalized.
159+
159160
Similarly, the discrete variance in the window V[t] is
160161
161-
V[t] = E[t**2] - E[t] ** 2
162+
V[t] = E[(t - E[t])**2]
162163
163164
The local squared zero-normalized cross-correlation is therefore
164165
165166
E[ (t-E[t]) * (p-E[p]) ] ** 2 / V[t] / V[p]
166167
167-
where the expectation in numerator is
168-
169-
E[ (t-E[t]) * (p-E[p]) ] = E[t * p] - E[t] * E[p]
170-
171-
Different kernel corresponds to different weights.
168+
When calculating variance, we choose to subtract the mean first then calculate
169+
variance instead of calculating E[t**2] - E[t] ** 2, the reason is that when
170+
E[t**2] and E[t] ** 2 are both very large or very small, the subtraction may
171+
have large rounding error and make the result inaccurate. Also, it is not
172+
guaranteed that the result >= 0. For more details, please read "Algorithms for
173+
computing the sample variance: Analysis and recommendations." page 1.
172174
173175
For now, y_true and y_pred have to be at least 4d tensor, including batch axis.
174176
175177
Reference:
176178
177179
- Zero-normalized cross-correlation (ZNCC):
178180
https://en.wikipedia.org/wiki/Cross-correlation
179-
- Code: https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py
181+
- https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights
182+
- Chan, Tony F., Gene H. Golub, and Randall J. LeVeque.
183+
"Algorithms for computing the sample variance: Analysis and recommendations."
184+
The American Statistician 37.3 (1983): 242-247.
180185
"""
181186

182187
kernel_fn_dict = dict(
@@ -212,13 +217,8 @@ def __init__(
212217
self.kernel_size = kernel_size
213218

214219
# (kernel_size, )
220+
# sum of the kernel weights would be one
215221
self.kernel = self.kernel_fn(kernel_size=self.kernel_size)
216-
# E[1] = sum_i(w_i), ()
217-
self.kernel_vol = tf.reduce_sum(
218-
self.kernel[:, None, None]
219-
* self.kernel[None, :, None]
220-
* self.kernel[None, None, :]
221-
)
222222

223223
def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
224224
"""
@@ -230,38 +230,29 @@ def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
230230
or (batch, dim1, dim2, dim3, ch)
231231
:return: shape = (batch,)
232232
"""
233-
# adjust
233+
# adjust shape to be (batch, dim1, dim2, dim3, ch)
234234
if len(y_true.shape) == 4:
235235
y_true = tf.expand_dims(y_true, axis=4)
236236
y_pred = tf.expand_dims(y_pred, axis=4)
237237
assert len(y_true.shape) == len(y_pred.shape) == 5
238238

239239
# t = y_true, p = y_pred
240-
# (batch, dim1, dim2, dim3, ch)
241-
t2 = y_true * y_true
242-
p2 = y_pred * y_pred
243-
tp = y_true * y_pred
244-
245-
# sum over kernel
246-
# (batch, dim1, dim2, dim3, 1)
247-
t_sum = separable_filter(y_true, kernel=self.kernel) # E[t] * E[1]
248-
p_sum = separable_filter(y_pred, kernel=self.kernel) # E[p] * E[1]
249-
t2_sum = separable_filter(t2, kernel=self.kernel) # E[tt] * E[1]
250-
p2_sum = separable_filter(p2, kernel=self.kernel) # E[pp] * E[1]
251-
tp_sum = separable_filter(tp, kernel=self.kernel) # E[tp] * E[1]
252-
253-
# average over kernel
254-
# (batch, dim1, dim2, dim3, 1)
255-
t_avg = t_sum / self.kernel_vol # E[t]
256-
p_avg = p_sum / self.kernel_vol # E[p]
257-
258-
# shape = (batch, dim1, dim2, dim3, 1)
259-
cross = tp_sum - p_avg * t_sum # E[tp] * E[1] - E[p] * E[t] * E[1]
260-
t_var = t2_sum - t_avg * t_sum # V[t] * E[1]
261-
p_var = p2_sum - p_avg * p_sum # V[p] * E[1]
262-
263-
# (E[tp] - E[p] * E[t]) ** 2 / V[t] / V[p]
264-
ncc = (cross * cross + EPS) / (t_var * p_var + EPS)
240+
t_mean = separable_filter(y_true, kernel=self.kernel)
241+
p_mean = separable_filter(y_pred, kernel=self.kernel)
242+
243+
t = y_true - t_mean
244+
p = y_pred - p_mean
245+
246+
# the variance can be biased but as both num and denom are biased
247+
# it gets cancelled out
248+
# https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights
249+
cross = separable_filter(t * p, kernel=self.kernel)
250+
t_var = separable_filter(t * t, kernel=self.kernel)
251+
p_var = separable_filter(p * p, kernel=self.kernel)
252+
253+
num = cross * cross
254+
denom = t_var * p_var
255+
ncc = (num + EPS) / (denom + EPS)
265256

266257
return tf.reduce_mean(ncc, axis=[1, 2, 3, 4])
267258

deepreg/loss/label.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@
44

55
import tensorflow as tf
66

7+
from deepreg.constant import EPS
78
from deepreg.loss.util import NegativeLossMixin, cauchy_kernel1d
89
from deepreg.loss.util import gaussian_kernel1d_sigma as gaussian_kernel1d
910
from deepreg.loss.util import separable_filter
1011
from deepreg.registry import REGISTRY
1112

12-
EPS = tf.keras.backend.epsilon()
13-
1413

1514
class MultiScaleLoss(tf.keras.losses.Loss):
1615
"""

deepreg/loss/util.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,25 +27,25 @@ def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
2727
return -super().call(y_true=y_true, y_pred=y_pred)
2828

2929

30-
EPS = tf.keras.backend.epsilon()
31-
32-
3330
def rectangular_kernel1d(kernel_size: int) -> tf.Tensor:
3431
"""
35-
Return a the 1D filter for separable convolution equivalent to a 3-D rectangular
36-
kernel for LocalNormalizedCrossCorrelation.
32+
Return the 1D rectangular kernel for LocalNormalizedCrossCorrelation.
33+
34+
Sum of the weights is 1.
3735
3836
:param kernel_size: scalar, size of the 1-D kernel
3937
:return: kernel_weights, of shape (kernel_size, )
4038
"""
4139

42-
kernel = tf.ones(shape=(kernel_size,), dtype=tf.float32)
40+
kernel = tf.ones(shape=(kernel_size,), dtype=tf.float32) / float(kernel_size)
4341
return kernel
4442

4543

4644
def triangular_kernel1d(kernel_size: int) -> tf.Tensor:
4745
"""
48-
1D triangular kernel.
46+
Return the 1D triangular kernel for LocalNormalizedCrossCorrelation.
47+
48+
Sum of the weights is 1.
4949
5050
Assume kernel_size is odd, it will be a smoothed from
5151
a kernel which center part is zero.
@@ -73,13 +73,17 @@ def triangular_kernel1d(kernel_size: int) -> tf.Tensor:
7373
kernel = tf.nn.conv1d(
7474
kernel[None, :, None], filters=filters, stride=[1, 1, 1], padding="SAME"
7575
)
76+
kernel = kernel / tf.reduce_sum(kernel)
77+
7678
return kernel[0, :, 0]
7779

7880

7981
def gaussian_kernel1d_size(kernel_size: int) -> tf.Tensor:
8082
"""
81-
Return a the 1D filter for separable convolution equivalent to a 3-D Gaussian
82-
kernel for LocalNormalizedCrossCorrelation.
83+
Return the 1D Gaussian kernel for LocalNormalizedCrossCorrelation.
84+
85+
Sum of the weights is 1.
86+
8387
:param kernel_size: scalar, size of the 1-D kernel
8488
:return: filters, of shape (kernel_size, )
8589
"""
@@ -88,6 +92,7 @@ def gaussian_kernel1d_size(kernel_size: int) -> tf.Tensor:
8892

8993
grid = tf.range(0, kernel_size, dtype=tf.float32)
9094
filters = tf.exp(-tf.square(grid - mean) / (2 * sigma ** 2))
95+
filters = filters / tf.reduce_sum(filters)
9196

9297
return filters
9398

deepreg/model/network.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ class RegistrationModel(tf.keras.Model):
3030

3131
def __init__(
3232
self,
33-
moving_image_size: tuple,
34-
fixed_image_size: tuple,
33+
moving_image_size: Tuple,
34+
fixed_image_size: Tuple,
3535
index_size: int,
3636
labeled: bool,
3737
batch_size: int,
@@ -61,6 +61,7 @@ def __init__(
6161
self.config = config
6262
self.num_devices = num_devices
6363
self.global_batch_size = num_devices * batch_size
64+
assert self.global_batch_size > 0
6465

6566
self._inputs = None # save inputs of self._model as dict
6667
self._outputs = None # save outputs of self._model as dict
@@ -222,7 +223,6 @@ def _build_loss(self, name: str, inputs_dict: dict):
222223

223224
# add loss
224225
self._model.add_loss(weighted_loss)
225-
226226
# add metric
227227
self._model.add_metric(
228228
loss_value, name=f"loss/{name}_{loss_layer.name}", aggregation="mean"

test/unit/test_interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ def mock_sample_index_generator():
356356
fixed_label=None,
357357
image_indices=[1],
358358
)
359-
assert "Sample [1]'s moving_image' shape should be 3D. " in str(err_info.value)
359+
assert "Sample [1]'s moving_image' shape should be 3D" in str(err_info.value)
360360
with pytest.raises(ValueError) as err_info:
361361
generator.validate_images_and_labels(
362362
fixed_image=dummy_array,

test/unit/test_loss_label.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import tensorflow as tf
1313

1414
import deepreg.loss.label as label
15+
from deepreg.constant import EPS
1516

1617

1718
class TestMultiScaleLoss:
@@ -91,11 +92,11 @@ def y_pred(self):
9192
@pytest.mark.parametrize(
9293
"binary,background_weight,scales,expected",
9394
[
94-
(True, 0.0, None, -np.log(1.0e-7)),
95-
(False, 0.0, None, -0.6 * np.log(0.3)),
96-
(False, 0.2, None, -0.48 * np.log(0.3) - 0.08 * np.log(0.7)),
97-
(False, 0.2, [0, 0], -0.48 * np.log(0.3) - 0.08 * np.log(0.7)),
98-
(False, 0.2, [0, 1], 0.5239637),
95+
(True, 0.0, None, -np.log(EPS)),
96+
(False, 0.0, None, -0.6 * np.log(0.3 + EPS)),
97+
(False, 0.2, None, -0.48 * np.log(0.3 + EPS) - 0.08 * np.log(0.7 + EPS)),
98+
(False, 0.2, [0, 0], -0.48 * np.log(0.3 + EPS) - 0.08 * np.log(0.7 + EPS)),
99+
(False, 0.2, [0, 1], 0.5239465),
99100
],
100101
)
101102
def test_call(self, y_true, y_pred, binary, background_weight, scales, expected):
@@ -135,7 +136,7 @@ def y_pred(self):
135136
(True, None, 0),
136137
(False, None, 0.25),
137138
(False, [0, 0], 0.25),
138-
(False, [0, 1], 0.17484076),
139+
(False, [0, 1], 0.17485845),
139140
],
140141
)
141142
def test_call(self, y_true, y_pred, binary, scales, expected):

test/unit/test_loss_util.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def test_gaussian_kernel1d_size(kernel_size):
6262

6363
grid = tf.range(0, kernel_size, dtype=tf.float32)
6464
expected = tf.exp(-tf.square(grid - mean) / (2 * sigma ** 2))
65+
expected = expected / tf.reduce_sum(expected)
6566

6667
got = gaussian_kernel1d_size(kernel_size)
6768
assert is_equal_tf(got, expected)
@@ -75,6 +76,7 @@ def test_rectangular_kernel1d(kernel_size):
7576
:return:
7677
"""
7778
expected = tf.ones(shape=(kernel_size,), dtype=tf.float32)
79+
expected = expected / tf.reduce_sum(expected)
7880
got = rectangular_kernel1d(kernel_size)
7981
assert is_equal_tf(got, expected)
8082

@@ -91,6 +93,7 @@ def test_triangular_kernel1d(kernel_size):
9193
for it_k in range(kernel_size // 2):
9294
expected[it_k] = it_k + 1
9395
expected[-it_k - 1] = it_k + 1
96+
expected = expected / tf.reduce_sum(expected)
9497

9598
got = triangular_kernel1d(kernel_size)
9699
assert is_equal_tf(got, expected)

0 commit comments

Comments
 (0)