The gradient of np.pad is wrong. See the following reproducible example:
MXNet Implementation:
import mxnet as mx
import numpy as np
mx.npx.set_np()
ctx = mx.gpu()
a = mx.np.ones((3, 3, 3), ctx=ctx)
mult = np.random.normal(0, 1, (3, 3, 3))
a.attach_grad()
with mx.autograd.record():
    b = mx.np.pad(a[:, 1:], ((0, 0), (0, 1), (0, 0))) * mx.np.array(mult, ctx=ctx)
    b = b.sum()
b.backward()
print(a.grad)
Output:
[[[0. 0. 0.]
  [1. 1. 1.]
  [1. 1. 1.]]

 [[0. 0. 0.]
  [1. 1. 1.]
  [1. 1. 1.]]

 [[0. 0. 0.]
  [1. 1. 1.]
  [1. 1. 1.]]] @gpu(0)
Jax Implementation:
from jax import grad
import jax.numpy as jnp
import numpy as np
mult = np.random.normal(0, 1, (3, 3, 3))
a = jnp.ones((3, 3, 3))
def f(x):
    b = jnp.pad(x[:, 1:], ((0, 0), (0, 1), (0, 0))) * jnp.array(mult)
    return b.sum()
print(grad(f)(a))
Output:
[[[ 0.          0.          0.        ]
  [ 0.3545383  -0.84326786 -0.31482664]
  [ 1.0994871  -1.230104    2.8007567 ]]

 [[ 0.          0.          0.        ]
  [ 1.0447861  -0.16119051 -0.39860427]
  [-0.7756538   0.5314936   1.4601654 ]]

 [[ 0.          0.          0.        ]
  [ 0.37878916 -2.0777514   0.96676654]
  [ 0.45230922  0.3094176  -0.43687683]]]
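The JAX result matches the analytic expectation: the padded row of b does not depend on a, so only a[:, 1:, :] should receive gradient, and that gradient is the corresponding slice of mult. A quick NumPy check, run right after the JAX snippet above so that np and mult are in scope (this verification is a sketch added here, not part of the original report):
# b = pad(a[:, 1:]) * mult, so d(b.sum())/d(a[:, 1:, :]) = mult[:, :2, :]
# and d(b.sum())/d(a[:, 0, :]) = 0, since a[:, 0, :] never enters b.
expected = np.zeros((3, 3, 3))
expected[:, 1:, :] = mult[:, :2, :]
print(expected)  # agrees with the JAX gradient above, not with the MXNet one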
Basically, the following lines are not correct:
https://github.com/apache/incubator-mxnet/blob/b0c39f7ea983639093c63d7d2486bbef083a55d6/src/operator/numpy/np_pad_op-inl.h#L544-L551
We should change that to
KERNEL_ASSIGN(out[i], req, a[i]);
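For context, the backward of a constant pad should simply copy each element of the output gradient at an unpadded position back to the corresponding input-gradient position, and discard whatever falls on the padded border. In NumPy terms this is just a slice; the following is a minimal sketch of that idea (the function name and helper are hypothetical, not the actual kernel code):
import numpy as np

def constant_pad_backward(out_grad, pad_width):
    # Keep only the region of out_grad that maps back onto the input;
    # the padded border contributes nothing to the input gradient.
    slices = tuple(slice(before, dim - after)
                   for (before, after), dim in zip(pad_width, out_grad.shape))
    return out_grad[slices]

# Example with the pad used in the reproduction above:
out_grad = np.arange(27, dtype=float).reshape(3, 3, 3)
print(constant_pad_backward(out_grad, ((0, 0), (0, 1), (0, 0))).shape)  # (3, 2, 3)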
In addition, I do not know why the using namespace mxnet_op; statement is needed there. @cassinixu