Description
Currently the numpy dot operation accepts an out parameter. The documentation warns that this is a performance feature, so the call throws an exception if the out argument does not have the right type. It does not, however, throw an exception if out is one of the input matrices, in which case the input is silently overwritten.
The ideal fix would be to do something smart: perform the matrix multiplication in a memory-efficient way that keeps only a column or row of scratch space [1]. Short of that, it would be better to throw an exception when the out matrix is the same object as either of the two matrices being multiplied, instead of returning an all-zero matrix and setting all the values in the input to zero.
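As a rough illustration of the row/column scratch idea, here is a minimal sketch rather than a proposed implementation. The helper names `dot_into_left` and `dot_into_right` are hypothetical, and the sketch assumes 2-D floating-point arrays of the same dtype. It relies on the fact that row i of a·b depends only on row i of a, and column j of a·b depends only on column j of b, so the temporary returned by each small `np.dot` call serves as the one row or column of scratch space.

```python
import numpy as np

def dot_into_left(a, b):
    """Overwrite a (m x k) with a @ b, where b is square (k x k).

    Hypothetical helper; assumes 2-D float arrays of the same dtype.
    """
    if a.ndim != 2 or b.ndim != 2 or b.shape != (a.shape[1], a.shape[1]):
        raise ValueError("need a of shape (m, k) and b of shape (k, k)")
    if np.shares_memory(a, b):
        # For a @ a the right operand must survive unchanged, so the
        # row trick is not enough and a full copy is needed.
        b = b.copy()
    for i in range(a.shape[0]):
        # np.dot returns a fresh length-k row (the scratch space),
        # which is then copied over the row that is no longer needed.
        a[i, :] = np.dot(a[i, :], b)
    return a

def dot_into_right(a, b):
    """Overwrite b (k x n) with a @ b, where a is square (k x k).

    Hypothetical helper; same assumptions as dot_into_left.
    """
    if a.ndim != 2 or b.ndim != 2 or a.shape != (b.shape[0], b.shape[0]):
        raise ValueError("need a of shape (k, k) and b of shape (k, n)")
    if np.shares_memory(a, b):
        a = a.copy()
    for j in range(b.shape[1]):
        # One fresh length-k column per iteration.
        b[:, j] = np.dot(a, b[:, j])
    return b
```

The Python-level loop makes this much slower than a single BLAS call, so it only shows that the memory footprint can be kept to one row or column; a real fix would need to do the same thing in compiled code.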
The patch below suggests one possible error message, and the Python session after it illustrates the current wrong behavior.
diff --git a/numpy/ma/core.py b/numpy/ma/core.py
index 4466dc0..6d9c53b 100755
--- a/numpy/ma/core.py
+++ b/numpy/ma/core.py
@@ -7307,6 +7307,10 @@ def dot(a, b, strict=False, out=None):
     am = ~getmaskarray(a)
     bm = ~getmaskarray(b)
+    if out is a or out is b:
+        raise ValueError("The out matrix is the same as the input. "
+                         "The multiplication output will be zero "
+                         "and this is definitely not what you want to do.")
     if out is None:
         d = np.dot(filled(a, 0), filled(b, 0))
         m = ~np.dot(am, bm)

>>> import numpy.random
>>> import numpy
>>> def f():
    a = numpy.random.randn(3,3)
    b = numpy.random.randn(3,3)
    c = numpy.dot(a,b)
    return a,b,c
>>> a,b,c = f()
>>> numpy.dot(a,b,out=b)
array([[ 0.,  0.,  0.],
       [ 0.,  0.,  0.],
       [ 0.,  0.,  0.]])
>>> a,b,c = f()
>>> numpy.dot(a,b,out=a)
array([[ 0.,  0.,  0.],
       [ 0.,  0.,  0.],
       [ 0.,  0.,  0.]])
>>> a,b,c = f()
>>> numpy.dot(a,a,out=a)
array([[ 0.,  0.,  0.],
       [ 0.,  0.,  0.],
       [ 0.,  0.,  0.]])
>>>
>>> print c
[[ 0.89307654  0.55849275 -0.57240046]
 [-1.93567811 -0.75110132  1.60961766]
 [ 0.85899293  1.16581478 -0.8796278 ]]

[1] I looked into scipy.linalg.blas. It exposes the *trmm routines, which allow in-place modification of the output when the input matrix is triangular, but there is nothing for a general rectangular matrix times a square matrix.
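
One limitation of the `out is a or out is b` check in the patch above is that it only catches the case where out is the very same Python object as an input. A view such as `a[:]` shares the same buffer but is a different object, so it would pass the check. The sketch below shows one possible alternative using `np.shares_memory`. The wrapper name `checked_dot` and the wording of the error message are hypothetical, and this is a plain-ndarray illustration, not the masked-array code path touched by the patch.

```python
import numpy as np

def checked_dot(a, b, out=None):
    """Hypothetical wrapper around np.dot that rejects an `out`
    array whose memory overlaps either input."""
    if out is not None:
        for name, arr in (("a", a), ("b", b)):
            # shares_memory also catches views of the same buffer,
            # which an identity test (`out is a`) would miss.
            if isinstance(arr, np.ndarray) and np.shares_memory(out, arr):
                raise ValueError(
                    "out overlaps input %r; pass a separate array" % name
                )
    return np.dot(a, b, out=out)
```

`np.shares_memory` performs an exact overlap check, which can be expensive for arrays with unusual strides. `np.may_share_memory` is a cheaper bounds-only test that could be substituted at the cost of occasional false positives.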