This repository was archived by the owner on Nov 17, 2023. It is now read-only.
[v2.0] Type Infer Error for MXNet NumPy Array in-place Operation #20447
Open
Description
As stated in the Array API specification, an in-place operation must not change the dtype or shape of the array being updated. In the MXNet NumPy array library, however, some binary in-place operations raise a type-inconsistency error. For example:
>>> import mxnet as mx
>>> import numpy as onp
>>> a1 = mx.np.array([1.0], dtype=onp.float32)
>>> b1 = mx.np.array([2.0], dtype=onp.float64)
>>> a1 *= b1
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/zhenghuj/GitHub/incubator-mxnet/python/mxnet/numpy/multiarray.py", line 268, in _wrap_mxnp_np_ufunc
return func(x1, x2)
File "/Users/zhenghuj/GitHub/incubator-mxnet/python/mxnet/numpy/multiarray.py", line 1104, in __imul__
return multiply(self, other, out=self)
File "/Users/zhenghuj/GitHub/incubator-mxnet/python/mxnet/util.py", line 645, in _wrap_np_binary_func
return func(x1, x2, out=out)
File "/Users/zhenghuj/GitHub/incubator-mxnet/python/mxnet/numpy/multiarray.py", line 3308, in multiply
return _mx_nd_np.multiply(x1, x2, out)
File "/Users/zhenghuj/GitHub/incubator-mxnet/python/mxnet/util.py", line 645, in _wrap_np_binary_func
return func(x1, x2, out=out)
File "/Users/zhenghuj/GitHub/incubator-mxnet/python/mxnet/ndarray/numpy/_op.py", line 1085, in multiply
return _api_internal.multiply(x1, x2, out)
File "mxnet/_ffi/_cython/./function.pxi", line 186, in mxnet._ffi._cy3.core.FunctionBase.__call__
File "mxnet/_ffi/_cython/./function.pxi", line 120, in mxnet._ffi._cy3.core.FuncCall
File "mxnet/_ffi/_cython/./function.pxi", line 108, in mxnet._ffi._cy3.core.FuncCall3
File "mxnet/_ffi/_cython/./base.pxi", line 91, in mxnet._ffi._cy3.core.CALL
mxnet.base.MXNetError: MXNetError: Type inconsistent, Provided = float32, inferred type = float64
The inferred type for a1 should be float32: *= is an in-place operation, so it must not change the dtype of a1.
Official NumPy behaves as expected:
>>> import numpy as onp
>>> a = onp.array([1.0], dtype=onp.float32)
>>> b = onp.array([2.0], dtype=onp.float64)
>>> a += b
>>> a
array([3.], dtype=float32)
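To make the expected semantics concrete, here is a small sketch using official NumPy only (no MXNet). It contrasts an out-of-place multiply, whose result dtype follows type promotion, with the in-place form, which writes back into the left operand and keeps its dtype. It also shows the shape half of the rule: NumPy rejects an in-place op whose broadcast result would be larger than the array being updated, rather than resizing it. The variable names are illustrative.

```python
import numpy as onp

a = onp.array([1.0], dtype=onp.float32)
b = onp.array([2.0], dtype=onp.float64)

# Out-of-place: the result dtype follows promotion (float32, float64 -> float64).
c = a * b
print(c.dtype)

# In-place: the result is cast back into `a` (allowed under NumPy's default
# 'same_kind' casting), so `a` stays float32.
a *= b
print(a, a.dtype)

# Shape is preserved as well: broadcasting (1,) with (3,) would need a (3,)
# output, which does not fit in `x`, so NumPy raises instead of resizing `x`.
x = onp.zeros(1, dtype=onp.float32)
shape_rejected = False
try:
    x += onp.ones(3, dtype=onp.float32)
except ValueError as err:
    shape_rejected = True
    print("rejected:", err)
```

The MXNet `a1 *= b1` case above corresponds to the second step: the output dtype should be pinned to `a1.dtype` instead of the promoted float64.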
To Solve This
The InferType function needs to check whether the operation is in-place and, if so, keep the dtype of the array being updated: https://github.com/apache/incubator-mxnet/blob/8fd17cef2ee854239c6e66f2dd3b9467aa222f79/src/operator/numpy/np_elemwise_broadcast_op.cc#L47-L61
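The rule the fix needs can be sketched in Python, mirroring NumPy's behavior. This is not MXNet code: the real InferType function is C++, and the function name `infer_binary_out_dtype` and its `inplace` flag are hypothetical, standing in for however the operator detects that its output aliases the left input. As an assumption, it rejects only casts that NumPy's `same_kind` rule would also reject (for example, writing a float64 result into an int32 array).

```python
import numpy as onp


def infer_binary_out_dtype(lhs_dtype, rhs_dtype, inplace=False):
    """Hypothetical sketch of dtype inference for a broadcast binary op.

    inplace=True means the output aliases the left operand (e.g. a1 *= b1),
    so the output dtype is pinned to the left operand's dtype.
    """
    promoted = onp.result_type(lhs_dtype, rhs_dtype)
    if not inplace:
        # Out-of-place: ordinary type promotion decides the output dtype.
        return promoted
    lhs = onp.dtype(lhs_dtype)
    # In-place: never change the dtype of the array being updated; only
    # refuse casts that NumPy's 'same_kind' rule would refuse.
    if not onp.can_cast(promoted, lhs, casting="same_kind"):
        raise TypeError(f"cannot cast {promoted} result into {lhs} in place")
    return lhs


print(infer_binary_out_dtype(onp.float32, onp.float64))                # float64
print(infer_binary_out_dtype(onp.float32, onp.float64, inplace=True))  # float32
```

Pinning the dtype rather than raising matches the NumPy behavior shown earlier, where `a += b` keeps `a` as float32.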
szha