11"""Provide different loss or metrics classes for images."""
22import tensorflow as tf
33
4+ from deepreg .constant import EPS
45from deepreg .loss .util import NegativeLossMixin
56from deepreg .loss .util import gaussian_kernel1d_size as gaussian_kernel1d
67from deepreg .loss .util import (
1011)
1112from deepreg .registry import REGISTRY
1213
13- EPS = tf .keras .backend .epsilon ()
14-
1514
1615@REGISTRY .register_loss (name = "ssd" )
1716class SumSquaredDifference (tf .keras .losses .Loss ):
@@ -156,27 +155,33 @@ class LocalNormalizedCrossCorrelation(tf.keras.losses.Loss):
156155
157156 E[t] = sum_i(w_i * t_i) / sum_i(w_i)
158157
158+ Here, we assume sum_i(w_i) == 1, i.e. the weights have been normalized.
159+
159160 Similarly, the discrete variance in the window V[t] is
160161
161- V[t] = E[t**2] - E[t] ** 2
162+ V[t] = E[(t - E[t])**2]
162163
163164 The local squared zero-normalized cross-correlation is therefore
164165
165166 E[ (t-E[t]) * (p-E[p]) ] ** 2 / V[t] / V[p]
166167
167- where the expectation in numerator is
168-
169- E[ (t-E[t]) * (p-E[p]) ] = E[t * p] - E[t] * E[p]
170-
171- Different kernel corresponds to different weights.
168+ When calculating the variance, we choose to subtract the mean first and then
169+ compute the variance, instead of calculating E[t**2] - E[t] ** 2. The reason is
170+ that when E[t**2] and E[t] ** 2 are both very large or very small, the
171+ subtraction may have a large rounding error and make the result inaccurate.
172+ Also, it is not guaranteed that the result >= 0. For more details, please read
173+ "Algorithms for computing the sample variance: Analysis and recommendations." page 1.
172174
173175 For now, y_true and y_pred have to be at least 4d tensor, including batch axis.
174176
175177 Reference:
176178
177179 - Zero-normalized cross-correlation (ZNCC):
178180 https://en.wikipedia.org/wiki/Cross-correlation
179- - Code: https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py
181+ - https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights
182+ - Chan, Tony F., Gene H. Golub, and Randall J. LeVeque.
183+ "Algorithms for computing the sample variance: Analysis and recommendations."
184+ The American Statistician 37.3 (1983): 242-247.
180185 """
181186
182187 kernel_fn_dict = dict (
@@ -212,13 +217,8 @@ def __init__(
212217 self .kernel_size = kernel_size
213218
214219 # (kernel_size, )
220+ # the kernel weights sum to one
215221 self .kernel = self .kernel_fn (kernel_size = self .kernel_size )
216- # E[1] = sum_i(w_i), ()
217- self .kernel_vol = tf .reduce_sum (
218- self .kernel [:, None , None ]
219- * self .kernel [None , :, None ]
220- * self .kernel [None , None , :]
221- )
222222
223223 def call (self , y_true : tf .Tensor , y_pred : tf .Tensor ) -> tf .Tensor :
224224 """
@@ -230,38 +230,29 @@ def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
230230 or (batch, dim1, dim2, dim3, ch)
231231 :return: shape = (batch,)
232232 """
233- # adjust
233+ # adjust shape to be (batch, dim1, dim2, dim3, ch)
234234 if len (y_true .shape ) == 4 :
235235 y_true = tf .expand_dims (y_true , axis = 4 )
236236 y_pred = tf .expand_dims (y_pred , axis = 4 )
237237 assert len (y_true .shape ) == len (y_pred .shape ) == 5
238238
239239 # t = y_true, p = y_pred
240- # (batch, dim1, dim2, dim3, ch)
241- t2 = y_true * y_true
242- p2 = y_pred * y_pred
243- tp = y_true * y_pred
244-
245- # sum over kernel
246- # (batch, dim1, dim2, dim3, 1)
247- t_sum = separable_filter (y_true , kernel = self .kernel ) # E[t] * E[1]
248- p_sum = separable_filter (y_pred , kernel = self .kernel ) # E[p] * E[1]
249- t2_sum = separable_filter (t2 , kernel = self .kernel ) # E[tt] * E[1]
250- p2_sum = separable_filter (p2 , kernel = self .kernel ) # E[pp] * E[1]
251- tp_sum = separable_filter (tp , kernel = self .kernel ) # E[tp] * E[1]
252-
253- # average over kernel
254- # (batch, dim1, dim2, dim3, 1)
255- t_avg = t_sum / self .kernel_vol # E[t]
256- p_avg = p_sum / self .kernel_vol # E[p]
257-
258- # shape = (batch, dim1, dim2, dim3, 1)
259- cross = tp_sum - p_avg * t_sum # E[tp] * E[1] - E[p] * E[t] * E[1]
260- t_var = t2_sum - t_avg * t_sum # V[t] * E[1]
261- p_var = p2_sum - p_avg * p_sum # V[p] * E[1]
262-
263- # (E[tp] - E[p] * E[t]) ** 2 / V[t] / V[p]
264- ncc = (cross * cross + EPS ) / (t_var * p_var + EPS )
240+ t_mean = separable_filter (y_true , kernel = self .kernel )
241+ p_mean = separable_filter (y_pred , kernel = self .kernel )
242+
243+ t = y_true - t_mean
244+ p = y_pred - p_mean
245+
246+ # the variance can be biased, but as both num and denom are biased,
247+ # the bias gets cancelled
248+ # https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights
249+ cross = separable_filter (t * p , kernel = self .kernel )
250+ t_var = separable_filter (t * t , kernel = self .kernel )
251+ p_var = separable_filter (p * p , kernel = self .kernel )
252+
253+ num = cross * cross
254+ denom = t_var * p_var
255+ ncc = (num + EPS ) / (denom + EPS )
265256
266257 return tf .reduce_mean (ncc , axis = [1 , 2 , 3 , 4 ])
267258
0 commit comments