This repository was archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy][Bug] The gradient of np.pad is wrong! #19043

@sxjscience

Description


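The gradient of np.pad is wrong. In the example below, every input element that reaches the padded output receives a gradient of 1 instead of the corresponding upstream gradient. Reproducible example: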

MXNet Implementation:

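import mxnet as mx
import numpy as np  # needed for np.random.normal below
mx.npx.set_np()

ctx = mx.gpu()
a = mx.np.ones((3, 3, 3), ctx=ctx)
mult = np.random.normal(0, 1, (3, 3, 3))
a.attach_grad()
with mx.autograd.record():
    # drop a[:, 0], append one zero row along axis 1, then weight by mult
    b = mx.np.pad(a[:, 1:], ((0, 0), (0, 1), (0, 0))) * mx.np.array(mult, ctx=ctx)
    b = b.sum()
b.backward()
# expected: zeros in a.grad[:, 0] and mult[:, :2] in a.grad[:, 1:]
print(a.grad)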

Output:

[[[0. 0. 0.]
  [1. 1. 1.]
  [1. 1. 1.]]

 [[0. 0. 0.]
  [1. 1. 1.]
  [1. 1. 1.]]

 [[0. 0. 0.]
  [1. 1. 1.]
  [1. 1. 1.]]] @gpu(0)

JAX Implementation (reference):

from jax import grad
import jax.numpy as jnp
import numpy as np
mult = np.random.normal(0, 1, (3, 3, 3))

a = jnp.ones((3, 3, 3))

def f(x):
    b = jnp.pad(x[:, 1:], ((0, 0), (0, 1), (0, 0))) * jnp.array(mult)
    return b.sum()
print(grad(f)(a))

Output:

[[[ 0.          0.          0.        ]
  [ 0.3545383  -0.84326786 -0.31482664]
  [ 1.0994871  -1.230104    2.8007567 ]]

 [[ 0.          0.          0.        ]
  [ 1.0447861  -0.16119051 -0.39860427]
  [-0.7756538   0.5314936   1.4601654 ]]

 [[ 0.          0.          0.        ]
  [ 0.37878916 -2.0777514   0.96676654]
  [ 0.45230922  0.3094176  -0.43687683]]]
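The expected gradient can also be derived by hand and checked numerically, which agrees with the JAX output. The appended zero row does not depend on the input, so only mult[:, :2] flows back, and the slice a[:, 1:] shifts it onto rows 1 and 2. A minimal NumPy sketch; the helper names f, expected_grad and numerical_grad are illustrative and not part of MXNet or JAX:

```python
import numpy as np

def f(x, mult):
    # Same forward computation as the repro: slice, zero-pad, weight, sum.
    return (np.pad(x[:, 1:], ((0, 0), (0, 1), (0, 0))) * mult).sum()

def expected_grad(mult):
    # Hypothetical helper: hand-derived gradient of f with respect to x.
    grad = np.zeros_like(mult)
    # The padded row (index 2 on axis 1) is a constant, so only mult[:, :2]
    # flows back; the slice x[:, 1:] places it on rows 1 and 2 of x.
    grad[:, 1:, :] = mult[:, :2, :]
    return grad

def numerical_grad(x, mult, eps=1e-6):
    # Central finite differences, one element of x at a time.
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        xp = x.copy()
        xp[idx] += eps
        xm = x.copy()
        xm[idx] -= eps
        g[idx] = (f(xp, mult) - f(xm, mult)) / (2 * eps)
    return g

mult = np.random.normal(0, 1, (3, 3, 3))
x = np.ones((3, 3, 3))
print(np.allclose(numerical_grad(x, mult), expected_grad(mult), atol=1e-6))  # True
```

Note that MXNet's output above equals expected_grad(np.ones_like(mult)): it behaves as if the upstream gradient were all ones, ignoring mult.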

The following lines are not correct:

https://github.com/apache/incubator-mxnet/blob/b0c39f7ea983639093c63d7d2486bbef083a55d6/src/operator/numpy/np_pad_op-inl.h#L544-L551

They should be changed to:

 KERNEL_ASSIGN(out[i], req, a[i]);
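For a constant (zero) pad, each input element appears exactly once in the output and the padded cells are constants, so the backward pass only needs to crop the upstream gradient back to the input's shape. The NumPy sketch below illustrates that rule; pad_constant_backward is a hypothetical helper, not MXNet's kernel. Assuming a in the proposed KERNEL_ASSIGN line refers to the upstream gradient at the position corresponding to out[i] (this depends on how the linked kernel indexes its arguments), the fix amounts to the same element-wise copy.

```python
import numpy as np

def pad_constant_backward(out_grad, pad_width):
    """Gradient of np.pad(x, pad_width) (mode="constant") with respect to x.

    pad_width is a sequence of (before, after) pairs, one per axis.
    Each element of x lands at exactly one output position, so its
    gradient is the upstream gradient at that position; padded cells
    hold constants and contribute nothing.
    """
    crop = tuple(
        slice(before, dim - after)
        for (before, after), dim in zip(pad_width, out_grad.shape)
    )
    return out_grad[crop].copy()

# Usage: for f(x) = sum(np.pad(x, pad_width) * w), the gradient is a crop of w.
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 4))
pad_width = ((1, 0), (0, 2), (1, 1))
w = rng.normal(size=np.pad(x, pad_width).shape)
grad_x = pad_constant_backward(w, pad_width)
print(grad_x.shape)  # (2, 3, 4)
```

Because f is linear in x, the cropped gradient is exact; no constant such as 1 should ever be written in place of the upstream gradient.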

In addition, I am not sure why the using namespace mxnet_op; statement is needed there. @cassinixu

Metadata


Assignees: No one assigned
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
