11"""Provide different loss or metrics classes for images."""
22import tensorflow as tf
33
4- from deepreg .constant import EPS
54from deepreg .loss .util import NegativeLossMixin
65from deepreg .loss .util import gaussian_kernel1d_size as gaussian_kernel1d
76from deepreg .loss .util import (
1110)
1211from deepreg .registry import REGISTRY
1312
13+ EPS = tf .keras .backend .epsilon ()
14+
1415
1516@REGISTRY .register_loss (name = "ssd" )
1617class SumSquaredDifference (tf .keras .losses .Loss ):
@@ -155,33 +156,27 @@ class LocalNormalizedCrossCorrelation(tf.keras.losses.Loss):
155156
156157 E[t] = sum_i(w_i * t_i) / sum_i(w_i)
157158
158- Here, we assume sum_i(w_i) == 1, means the weights have been normalized.
159-
160159 Similarly, the discrete variance in the window V[t] is
161160
162- V[t] = E[(t - E[t])**2]
161+ V[t] = E[t**2] - E[t] ** 2
163162
164163 The local squared zero-normalized cross-correlation is therefore
165164
166165 E[ (t-E[t]) * (p-E[p]) ] ** 2 / V[t] / V[p]
167166
168- When calculating variance, we choose to subtract the mean first then calculte
169- variance instead of calculating E[t**2] - E[t] ** 2, the reason is that when
170- E[t**2] and E[t] ** 2 are both very large or very small, the subtraction may
171- have large rounding error and makes the result inaccurate. Also, it is not
172- guaranteed that the result >= 0. For more details, please read "Algorithms for
173- computing the sample variance: Analysis and recommendations." page 1.
167+ where the expectation in the numerator is
168+
169+ E[ (t-E[t]) * (p-E[p]) ] = E[t * p] - E[t] * E[p]
170+
171+ Different kernels correspond to different weights.
174172
175173 For now, y_true and y_pred have to be at least 4d tensor, including batch axis.
176174
177175 Reference:
178176
179177 - Zero-normalized cross-correlation (ZNCC):
180178 https://en.wikipedia.org/wiki/Cross-correlation
181- - https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights
182- - Chan, Tony F., Gene H. Golub, and Randall J. LeVeque.
183- "Algorithms for computing the sample variance: Analysis and recommendations."
184- The American Statistician 37.3 (1983): 242-247.
179+ - Code: https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py
185180 """
186181
187182 kernel_fn_dict = dict (
@@ -217,8 +212,13 @@ def __init__(
217212 self .kernel_size = kernel_size
218213
219214 # (kernel_size, )
220- # sum of the kernel weights would be one
221215 self .kernel = self .kernel_fn (kernel_size = self .kernel_size )
216+ # E[1] = sum_i(w_i), ()
217+ self .kernel_vol = tf .reduce_sum (
218+ self .kernel [:, None , None ]
219+ * self .kernel [None , :, None ]
220+ * self .kernel [None , None , :]
221+ )
222222
223223 def call (self , y_true : tf .Tensor , y_pred : tf .Tensor ) -> tf .Tensor :
224224 """
@@ -230,29 +230,38 @@ def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
230230 or (batch, dim1, dim2, dim3, ch)
231231 :return: shape = (batch,)
232232 """
233- # adjust shape to be (batch, dim1, dim2, dim3, ch)
233+ # adjust shape to be (batch, dim1, dim2, dim3, ch)
234234 if len (y_true .shape ) == 4 :
235235 y_true = tf .expand_dims (y_true , axis = 4 )
236236 y_pred = tf .expand_dims (y_pred , axis = 4 )
237237 assert len (y_true .shape ) == len (y_pred .shape ) == 5
238238
239239 # t = y_true, p = y_pred
240- t_mean = separable_filter (y_true , kernel = self .kernel )
241- p_mean = separable_filter (y_pred , kernel = self .kernel )
242-
243- t = y_true - t_mean
244- p = y_pred - p_mean
245-
246- # the variance can be biased but as both num and denom are biased
247- # it got cancelled
248- # https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights
249- cross = separable_filter (t * p , kernel = self .kernel )
250- t_var = separable_filter (t * t , kernel = self .kernel )
251- p_var = separable_filter (p * p , kernel = self .kernel )
252-
253- num = cross * cross
254- denom = t_var * p_var
255- ncc = (num + EPS ) / (denom + EPS )
240+ # (batch, dim1, dim2, dim3, ch)
241+ t2 = y_true * y_true
242+ p2 = y_pred * y_pred
243+ tp = y_true * y_pred
244+
245+ # sum over kernel
246+ # (batch, dim1, dim2, dim3, 1)
247+ t_sum = separable_filter (y_true , kernel = self .kernel ) # E[t] * E[1]
248+ p_sum = separable_filter (y_pred , kernel = self .kernel ) # E[p] * E[1]
249+ t2_sum = separable_filter (t2 , kernel = self .kernel ) # E[tt] * E[1]
250+ p2_sum = separable_filter (p2 , kernel = self .kernel ) # E[pp] * E[1]
251+ tp_sum = separable_filter (tp , kernel = self .kernel ) # E[tp] * E[1]
252+
253+ # average over kernel
254+ # (batch, dim1, dim2, dim3, 1)
255+ t_avg = t_sum / self .kernel_vol # E[t]
256+ p_avg = p_sum / self .kernel_vol # E[p]
257+
258+ # shape = (batch, dim1, dim2, dim3, 1)
259+ cross = tp_sum - p_avg * t_sum # E[tp] * E[1] - E[p] * E[t] * E[1]
260+ t_var = t2_sum - t_avg * t_sum # V[t] * E[1]
261+ p_var = p2_sum - p_avg * p_sum # V[p] * E[1]
262+
263+ # (E[tp] - E[p] * E[t]) ** 2 / V[t] / V[p]
264+ ncc = (cross * cross + EPS ) / (t_var * p_var + EPS )
256265
257266 return tf .reduce_mean (ncc , axis = [1 , 2 , 3 , 4 ])
258267
0 commit comments