This repository was archived by the owner on Nov 17, 2023. It is now read-only.
elemwise_mul results in corrupt response #13193
Open
Description
When two CSR matrices are multiplied element-wise, either with the * operator or with nd.sparse.elemwise_mul, the resulting CSRNDArray has corrupt data and indices arrays.
Minimal reproducible example
from mxnet import nd
a = nd.array([[1, 2, 3, 0, 0, 0, 0]]).tostype('csr')
print(a, a.data, a.indices, a.indptr)
(
<CSRNDArray 1x7 @cpu(0)>,
[ 1. 2. 3.]
<NDArray 3 @cpu(0)>,
[0 1 2]
<NDArray 3 @cpu(0)>,
[0 3]
<NDArray 2 @cpu(0)>)
b = nd.array([[0, 1, 2, 3, 4, 0, 0]]).tostype('csr')
print(b, b.data, b.indices, b.indptr)
(
<CSRNDArray 1x7 @cpu(0)>,
[ 1. 2. 3. 4.]
<NDArray 4 @cpu(0)>,
[1 2 3 4]
<NDArray 4 @cpu(0)>,
[0 4]
<NDArray 2 @cpu(0)>)
c = a * b
print(c, c.data, c.indices, c.indptr)
(
<CSRNDArray 1x7 @cpu(0)>,
[ 6.00000000e+00 2.00000000e+00 0.00000000e+00 -8.58993459e+09
-1.25413071e+12 4.58308676e-41 -4.48397312e+11]
<NDArray 7 @cpu(0)>,
[ 2 1 140475495284740
4769646468283318577 3910306675754944051 3618421518977090361
3907208325270680133]
<NDArray 7 @cpu(0)>,
[0 2]
<NDArray 2 @cpu(0)>)
d = nd.sparse.elemwise_mul(a, b)
print(d, d.data, d.indices, d.indptr)
(
<CSRNDArray 1x7 @cpu(0)>,
[ 6.00000000e+00 2.00000000e+00 9.97764545e-12 1.58494935e+29
9.94760004e-12 1.08446635e-19 0.00000000e+00]
<NDArray 7 @cpu(0)>,
[ 2 1 4661575685
4661795648 4661787752 4661783920
7594316269682434092]
<NDArray 7 @cpu(0)>,
[0 2]
<NDArray 2 @cpu(0)>)
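A quick way to confirm that c and d above are malformed is to check the invariants of a canonical CSR matrix: indptr has one entry per row plus one and starts at 0, data and indices both have exactly indptr[-1] entries, and each row's column indices are in range and strictly increasing. The sketch below is a plain-Python checker under those assumptions; check_csr is a hypothetical helper, not part of MXNet, and it assumes MXNet expects sorted, duplicate-free column indices within each row.

```python
def check_csr(data, indices, indptr, shape):
    """Return a list of CSR invariant violations (empty if well-formed).

    data, indices and indptr are sequences; shape is (rows, cols).
    Assumes canonical CSR: sorted, duplicate-free column indices per row.
    """
    rows, cols = shape
    problems = []
    if len(indptr) != rows + 1:
        problems.append(f"len(indptr) is {len(indptr)}, expected {rows + 1}")
        return problems  # the per-row checks below need a valid indptr
    if indptr[0] != 0:
        problems.append(f"indptr[0] is {indptr[0]}, expected 0")
    nnz = indptr[-1]
    if len(data) != nnz or len(indices) != nnz:
        problems.append(
            f"len(data)={len(data)}, len(indices)={len(indices)}, but indptr[-1]={nnz}")
    for r in range(rows):
        start, end = indptr[r], indptr[r + 1]
        if start > end:
            problems.append(f"indptr decreases at row {r}")
            continue
        row_idx = list(indices[start:end])
        if any(not 0 <= c < cols for c in row_idx):
            problems.append(f"row {r} has a column index outside [0, {cols})")
        if any(x >= y for x, y in zip(row_idx, row_idx[1:])):
            problems.append(f"row {r} column indices are not strictly increasing")
    return problems


# Arrays copied from the printed output of `a` and `c = a * b` above.
print(check_csr([1., 2., 3.], [0, 1, 2], [0, 3], (1, 7)))  # []
print(check_csr(
    [6.0, 2.0, 0.0, -8.58993459e+09, -1.25413071e+12, 4.58308676e-41, -4.48397312e+11],
    [2, 1, 140475495284740, 4769646468283318577, 3910306675754944051,
     3618421518977090361, 3907208325270680133],
    [0, 2], (1, 7)))  # reports a length mismatch and unsorted indices in row 0
```

With MXNet, the arguments could be taken from c.data.asnumpy(), c.indices.asnumpy(), c.indptr.asnumpy() and c.shape.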
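It is not clear from the documentation whether c = a * b is supported when both inputs are CSR matrices. Either way, it runs without raising an exception and returns a CSRNDArray, but the contents are corrupted. The expected result is [[0, 2, 6, 0, 0, 0, 0]], i.e. data [2, 6], indices [1, 2] and indptr [0, 2]. Instead, data and indices each have 7 entries: the first two hold the correct values in reverse column order, and the rest are garbage.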
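For reference, the element-wise product of two CSR matrices can only be non-zero where both inputs are non-zero, so each output row can be built by merging the two sorted column-index lists of that row. This minimal pure-Python sketch uses a hypothetical helper, csr_elemwise_mul (not MXNet's kernel), and assumes sorted column indices within each row and a CSR matrix represented as a (data, indices, indptr, shape) tuple.

```python
def csr_elemwise_mul(a, b):
    """Element-wise (Hadamard) product of two CSR matrices.

    Each argument is a (data, indices, indptr, shape) tuple whose column
    indices are sorted within each row; the result uses the same layout.
    """
    a_data, a_idx, a_ptr, a_shape = a
    b_data, b_idx, b_ptr, b_shape = b
    if a_shape != b_shape:
        raise ValueError(f"shape mismatch: {a_shape} vs {b_shape}")
    data, indices, indptr = [], [], [0]
    for r in range(a_shape[0]):
        i, i_end = a_ptr[r], a_ptr[r + 1]
        j, j_end = b_ptr[r], b_ptr[r + 1]
        # Two-pointer merge: only columns present in both rows survive.
        while i < i_end and j < j_end:
            ca, cb = a_idx[i], b_idx[j]
            if ca == cb:
                data.append(a_data[i] * b_data[j])
                indices.append(ca)
                i += 1
                j += 1
            elif ca < cb:
                i += 1
            else:
                j += 1
        indptr.append(len(data))  # row r ends where the output currently ends
    return data, indices, indptr, a_shape


# The same matrices as `a` and `b` in the example above.
a = ([1., 2., 3.], [0, 1, 2], [0, 3], (1, 7))
b = ([1., 2., 3., 4.], [1, 2, 3, 4], [0, 4], (1, 7))
print(csr_elemwise_mul(a, b))  # ([2.0, 6.0], [1, 2], [0, 2], (1, 7))
```

Because the merge walks both rows in increasing column order, the output indices come out sorted and data/indices have exactly indptr[-1] entries, which is what the MXNet result fails to satisfy.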
I haven't checked the other elemwise_* operations in the nd.sparse package.
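As an aside, a 1-D array can be converted with tostype('csr') without any error, but operating on the result (even just reading f.data) raises an exception: the traceback shows the 2-D check in cast_storage failing only when the result is read. It would be good to validate the shape at conversion time. For example: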
f = nd.array([1, 2]).tostype('csr')
f
<CSRNDArray 2 @cpu(0)>
f.data
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python2.7/site-packages/mxnet/ndarray/sparse.py", line 476, in data
return self._data()
File "/usr/local/lib/python2.7/site-packages/mxnet/ndarray/sparse.py", line 266, in _data
self.wait_to_read()
File "/usr/local/lib/python2.7/site-packages/mxnet/ndarray/ndarray.py", line 1798, in wait_to_read
check_call(_LIB.MXNDArrayWaitToRead(self.handle))
File "/usr/local/lib/python2.7/site-packages/mxnet/base.py", line 252, in check_call
raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [16:17:07] src/operator/tensor/./cast_storage-inl.h:238: Check failed: dns.shape_.ndim() == 2 (1 vs. 2)
Stack trace returned 10 entries:
[bt] (0) 0 libmxnet.so 0x000000010564fb90 libmxnet.so + 15248
[bt] (1) 1 libmxnet.so 0x000000010564f93f libmxnet.so + 14655
[bt] (2) 2 libmxnet.so 0x0000000105b4039d libmxnet.so + 5194653
[bt] (3) 3 libmxnet.so 0x0000000105b37b79 libmxnet.so + 5159801
[bt] (4) 4 libmxnet.so 0x0000000106b911f7 MXNDListFree + 555447
[bt] (5) 5 libmxnet.so 0x0000000106b197f4 MXNDListFree + 65460
[bt] (6) 6 libmxnet.so 0x0000000106b1bec8 MXNDListFree + 75400
[bt] (7) 7 libmxnet.so 0x0000000106b1f261 MXNDListFree + 88609
[bt] (8) 8 libmxnet.so 0x0000000106b1f17f MXNDListFree + 88383
[bt] (9) 9 libmxnet.so 0x0000000106b1cc55 MXNDListFree + 78869
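The suggested check could be made eager, rejecting non-2-D input when converting to CSR rather than failing later on read. The sketch below illustrates the idea on plain Python lists with a hypothetical dense_to_csr helper; it is not MXNet's tostype implementation. In MXNet itself, a caller-side guard that checks arr.ndim == 2 before calling arr.tostype('csr') would give a similar early error.

```python
def dense_to_csr(rows):
    """Convert a dense 2-D matrix (a list of equal-length lists) to CSR.

    Raises ValueError immediately for input that is not 2-D, instead of
    deferring the failure to a later operation on the result.
    """
    if not isinstance(rows, (list, tuple)) or not all(
            isinstance(r, (list, tuple)) for r in rows):
        raise ValueError("CSR storage requires a 2-D input (a list of rows)")
    ncols = len(rows[0]) if rows else 0
    if any(len(r) != ncols for r in rows):
        raise ValueError("all rows must have the same length")
    data, indices, indptr = [], [], [0]
    for row in rows:
        for col, value in enumerate(row):
            if value != 0:  # store only the non-zero entries
                data.append(value)
                indices.append(col)
        indptr.append(len(data))
    return data, indices, indptr, (len(rows), ncols)


print(dense_to_csr([[1, 2, 3, 0, 0, 0, 0]]))  # ([1, 2, 3], [0, 1, 2], [0, 3], (1, 7))
try:
    dense_to_csr([1, 2])  # 1-D input, like `f` above
except ValueError as err:
    print("rejected:", err)
```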