Fix to preserve dtype of input array in cupy.linalg.norm#875
Fix to preserve dtype of input array in cupy.linalg.norm#875okuta merged 3 commits intocupy:masterfrom
Conversation
|
I think that CuPy aims at the same behavior as the latest NumPy |
| absx **= ord | ||
| ret = absx.sum(axis=axis, keepdims=keepdims) | ||
| ret **= (1.0 / ord) | ||
| ret **= cupy.reciprocal(ord, dtype=ret.dtype) |
There was a problem hiding this comment.
I think you need to update this - np.reciprocal fails if ord is an integer, so yours might too
There was a problem hiding this comment.
Thank you for the comment!
I guess both NumPy and CuPy work correctly if dtype is given to reciprocal?
>>> import numpy as np, cupy as cp
>>> np.reciprocal(10, dtype=np.float32)
0.1
>>> cp.reciprocal(10, dtype=cp.float32)
array(0.1, dtype=float32)There was a problem hiding this comment.
I raised a PR regarding this, could you please confirm? numpy/numpy#10667
cupy/linalg/norms.py
Outdated
| # Zero norm | ||
| # Convert to Python float in accordance with NumPy | ||
| return (x != 0).sum(axis=axis, keepdims=keepdims, dtype='d') | ||
| return (x != 0).astype(x.real.dtype).sum(axis=axis, keepdims=keepdims) |
There was a problem hiding this comment.
Cloud you fix this line.
./cupy/linalg/norms.py:72:80: E501 line too long (82 > 79 characters)
Codecov Report
@@ Coverage Diff @@
## master #875 +/- ##
==========================================
- Coverage 93.23% 89.27% -3.96%
==========================================
Files 107 107
Lines 5984 5848 -136
==========================================
- Hits 5579 5221 -358
- Misses 405 627 +222
Continue to review full report at Codecov.
|
|
I got error on NumPy 1.14.2 environment. |
b74eccb to
6fd3106
Compare
|
I fixed the test code to use (The root cause of this issue is that NumPy started to use |
1fe93f3 to
4cf6c8d
Compare
|
Resolved conflicts. |
|
LGTM! |
Fix to preserve dtype of input array in cupy.linalg.norm
In NumPy 1.14, np.linalg.norm has been changed to preserve dtype of input array.
However, as fix made in NumPy 1.14 was incomplete, there are some cases (e.g., when arbitrary order is specified) that dtype is not preserved.
The issue is recognized by the team and the additional fix is ongoing.
np.linalg.normnot preserving input dtype? numpy/numpy#10364This PR makes CuPy behavior consistent with NumPy 1.14.1 (assuming the above PR got merged and released).
We need to discuss if we should separate tests for
<1.14,==1.14, '>1.14'.(Maybe we can ignore
==1.14... but in that case, can we say that we support NumPy 1.14?)