F.Take Backwards - Incorrect Gradient #19817
Description
The backward implementation of F.take computes an incorrect gradient when it is used after a sequence of transpose -> convolution -> transpose: any trainable parameters that receive gradients through the F.take operator get incorrect values. Equivalent implementations using slice operators produce correct results.
Other Details
I have been unable to find any other scenario in which this happens (for example, replacing the convolution layers in the example below with a linear layer gives correct gradients).
I also encounter the bug on MXNet 1.5 and 1.6 (I have not tested earlier versions).
To Reproduce
Below I provide an example of a simple model with two implementations -- one that uses F.take (Model A) and one that uses F.slice_axis instead (Model B).
```python
from mxnet import nd
from mxnet.gluon import HybridBlock
from mxnet.gluon.nn import Conv1D, Dense, HybridLambda, HybridSequential

def conv_layer(atrous_rates, num_channels):
    convs = HybridSequential()
    # Conv1D expects N x C x T, so transpose in and out of the conv stack
    convs.add(HybridLambda(lambda F, x: F.transpose(x, (0, 2, 1))))
    for rate in atrous_rates:
        convs.add(Conv1D(num_channels, 3, padding=rate, dilation=rate, activation='tanh'))
    convs.add(HybridLambda(lambda F, x: F.transpose(x, (0, 2, 1))))
    return convs
```
```python
class Model(HybridBlock):
    """
    Takes tensors of shape N x T x C and produces predictions of shape N x T.
    """
    def __init__(self, conv_units, atrous_rates, use_take=False, **kwargs):
        super().__init__(prefix=kwargs.get('prefix', None), params=kwargs.get('params', None))
        self.use_take = use_take
        with self.name_scope():
            self.convs = conv_layer(atrous_rates, conv_units)
            self.dense_out = Dense(1, flatten=False, activation='tanh')

    def hybrid_forward(self, F, X):
        X1 = X
        X2 = self.convs(X1)
        if self.use_take:
            X3 = F.take(X2, nd.array([1, 2, 3]), axis=-1)
        else:
            X3 = F.slice_axis(X2, begin=1, end=4, axis=-1)
        X4 = self.dense_out(X3)
        X4 = F.squeeze(X4, axis=-1)
        return X4
```
The script linked below instantiates both implementations with the same initial weights, computes L2Loss, and prints the gradients from both models. A random seed is set, so the output should be deterministic (and it is for Model B).
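For reference, here is a minimal sketch of what that script does. It is illustrative only: the gist linked under "Steps to reproduce" is the actual reproducer, and the input shapes, hyperparameters, and the re-seeding trick used to give both models identical initial weights are my assumptions.

```python
import mxnet as mx
from mxnet import autograd, nd
from mxnet.gluon.loss import L2Loss

N, T, C = 4, 16, 8                      # batch, sequence length, channels (assumed)
X = nd.random.normal(shape=(N, T, C))
Y = nd.random.normal(shape=(N, T))

model_a = Model(conv_units=C, atrous_rates=[1, 2, 4], use_take=True)
model_b = Model(conv_units=C, atrous_rates=[1, 2, 4], use_take=False)

mx.random.seed(0)
model_a.initialize()
mx.random.seed(0)                       # re-seed so both models draw identical weights
model_b.initialize()

loss_fn = L2Loss()
for model in (model_a, model_b):
    with autograd.record():
        loss = loss_fn(model(X), Y)
    loss.backward()
    for pname, param in model.collect_params().items():
        g = param.grad()
        print('||g_param||_2: %.2E | Param: %s' % (g.norm().asscalar(), pname))
```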
Steps to reproduce
- Download this script: https://gist.github.com/ceisenach/9ffed8343e5576748ec7d5623ffe6c46
- Run the script: `python take_bug.py`
Result
- As expected, the output of the forward pass is the same for both models
- Gradients (Model A): parameters in Model A that receive gradients through F.take are on the order of 1e28 (or in some cases are infinite). The results are non-deterministic.
- Gradients (Model B): gradient values seem reasonable and are deterministic (same results each time).
Example output from the script I provided:
```
||g_param||_2: INF | Param: model0_conv0_weight
||g_param||_2: 7.21E+18 | Param: model0_conv0_bias
||g_param||_2: INF | Param: model0_conv1_weight
||g_param||_2: INF | Param: model0_conv1_bias
||g_param||_2: INF | Param: model0_conv2_weight
||g_param||_2: INF | Param: model0_conv2_bias
||g_param||_2: 1.38E-04 | Param: model0_dense0_weight
||g_param||_2: 1.06E-02 | Param: model0_dense0_bias
-------------------------------------------
------- Grad Info
* ||g||_2: INF
* ||g||_1: 1.77E+21
* ||g||_inf: 5.79E+20
||g_param||_2: 2.37E-04 | Param: model1_conv0_weight
||g_param||_2: 2.29E-05 | Param: model1_conv0_bias
||g_param||_2: 2.23E-04 | Param: model1_conv1_weight
||g_param||_2: 1.50E-04 | Param: model1_conv1_bias
||g_param||_2: 4.26E-04 | Param: model1_conv2_weight
||g_param||_2: 7.02E-04 | Param: model1_conv2_bias
||g_param||_2: 1.38E-04 | Param: model1_dense0_weight
||g_param||_2: 1.06E-02 | Param: model1_dense0_bias
-------------------------------------------
------- Grad Info
* ||g||_2: 1.06E-02
* ||g||_1: 1.75E-02
* ||g||_inf: 1.06E-02
==== Same outputs?
Y_hat1 - Yhat2 = 0.0000
```
It appears that there is either an out-of-bounds memory access or that some values involved in the calculation are used before they are initialized. I haven't attempted to track down the root cause.
What have you tried to solve it?
In many cases, one can work around the bug by using one of the slice operators (plus concatenation) instead; these do not appear to have any issues. A sketch follows.
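As an illustration of the workaround (assuming the indices are known statically rather than data-dependent): a contiguous take reduces to a single slice_axis, and a non-contiguous one can be assembled from slices plus concat.

```python
from mxnet import nd

X = nd.arange(24).reshape((2, 3, 4))

# F.take(X, nd.array([1, 2, 3]), axis=-1) -> contiguous indices: one slice suffices
contiguous = nd.slice_axis(X, axis=-1, begin=1, end=4)

# F.take(X, nd.array([0, 2]), axis=-1) -> non-contiguous: slice each index and concat
parts = [nd.slice_axis(X, axis=-1, begin=i, end=i + 1) for i in (0, 2)]
non_contiguous = nd.concat(*parts, dim=2)
```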
Environment
- OS: Ubuntu 18.04
- Python: 3.8.5
- pip: 20.2.3
- mxnet: 1.7.0 (commit hash: 64f737c)
- numpy: 1.19.2