Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit e0a7dda

Browse files
committed
Fix sample vs. pop variance issue with test_numpy_op.py::test_npx_batch_norm
1 parent 6f599a3 commit e0a7dda

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

tests/python/unittest/test_numpy_op.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1848,15 +1848,18 @@ def _test_batchnorm_impl(axis,
18481848

18491849
running_mean = running_mean * momentum + \
18501850
data_mean_flat * (1 - momentum)
1851+
1852+
m = _np.prod(shape) / shape[axis]
1853+
# cudnn uses m-1 in the denominator of its sample variance calculation, not m
1854+
sample_var_adjust = 1.0 if cudnn_off or fix_gamma else m / (m-1)
18511855
running_var = running_var * momentum + \
1852-
data_var_flat * (1 - momentum)
1856+
data_var_flat * sample_var_adjust * (1 - momentum)
18531857

18541858
W = bn_gamma.reshape(expand_shape)
18551859
dnx = ograd * W
18561860
xsm = data - data_mean
18571861
nd = 1.0 / np.sqrt(data_var + epsilon)
18581862
nx = xsm * nd
1859-
m = _np.prod(shape) / shape[axis]
18601863
dvar = (dnx * xsm).sum(axis=reduce_axis, keepdims=True,
18611864
) * (-0.5) * np.power(nd, 3)
18621865
dmean = -nd * dnx.sum(axis=reduce_axis, keepdims=True) - \

0 commit comments

Comments
 (0)