Skip to content

Commit c3c819a

Browse files
committed
MAINT: add __array_priority__ special case to masked array binary ops
ndarray special methods like __add__ have a special case where if the right argument is not an ndarray or subclass, and it has higher __array_priority__ than the left argument, then we return NotImplemented and let the right argument handle the operation. ufuncs have traditionally had a similar but different special case, where if it's a 2 input - 1 output ufunc, and the right argument is not an ndarray (exactly, subclasses don't count), and when converted to an ndarray ends up as an object array (presumably b/c it doesn't have a meaningful coercion route, though who knows), and it has a higher __array_priority__ than the left argument AND it has a __r<operation>__ attribute, then they return NotImplemented. In practice this latter special case is not used by regular ndarrays, b/c anytime it would need to be triggered, the former special case triggers first and the ufunc is never called. However, numpy.ma did not have the former special case, and was thus relying on the ufunc special case. This commit adds the special case to the numpy.ma special methods directly, so that they no longer depend on the quirky ufunc behaviour. It also cleans up the relevant test to things that actually should be true in general, instead of just testing some implementation details.
1 parent 069f0ff commit c3c819a

File tree

3 files changed

+37
-9
lines changed

3 files changed

+37
-9
lines changed

numpy/core/src/multiarray/number.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,10 @@ PyArray_GenericBinaryFunction(PyArrayObject *m1, PyObject *m2, PyObject *op)
304304
* See also:
305305
* - https://github.com/numpy/numpy/issues/3502
306306
* - https://github.com/numpy/numpy/issues/3503
307+
*
308+
* NB: there's another copy of this code in
309+
* numpy.ma.core.MaskedArray._delegate_binop
310+
* which should possibly be updated when this is.
307311
*/
308312
double m1_prio = PyArray_GetPriority((PyObject *)m1,
309313
NPY_SCALAR_PRIORITY);

numpy/ma/core.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3678,6 +3678,16 @@ def __repr__(self):
36783678
return _print_templates['short_std'] % parameters
36793679
return _print_templates['long_std'] % parameters
36803680

3681+
def _delegate_binop(self, other):
3682+
# This emulates the logic in
3683+
# multiarray/number.c:PyArray_GenericBinaryFunction
3684+
if (not isinstance(other, np.ndarray)
3685+
and not hasattr(other, "__numpy_ufunc__")):
3686+
other_priority = getattr(other, "__array_priority__", -1000000)
3687+
if self.__array_priority__ < other_priority:
3688+
return True
3689+
return False
3690+
36813691
def __eq__(self, other):
36823692
"Check whether other equals self elementwise"
36833693
if self is masked:
@@ -3746,6 +3756,8 @@ def __ne__(self, other):
37463756
#
37473757
def __add__(self, other):
37483758
"Add other to self, and return a new masked array."
3759+
if self._delegate_binop(other):
3760+
return NotImplemented
37493761
return add(self, other)
37503762
#
37513763
def __radd__(self, other):
@@ -3754,6 +3766,8 @@ def __radd__(self, other):
37543766
#
37553767
def __sub__(self, other):
37563768
"Subtract other to self, and return a new masked array."
3769+
if self._delegate_binop(other):
3770+
return NotImplemented
37573771
return subtract(self, other)
37583772
#
37593773
def __rsub__(self, other):
@@ -3762,6 +3776,8 @@ def __rsub__(self, other):
37623776
#
37633777
def __mul__(self, other):
37643778
"Multiply other by self, and return a new masked array."
3779+
if self._delegate_binop(other):
3780+
return NotImplemented
37653781
return multiply(self, other)
37663782
#
37673783
def __rmul__(self, other):
@@ -3770,10 +3786,14 @@ def __rmul__(self, other):
37703786
#
37713787
def __div__(self, other):
37723788
"Divide other into self, and return a new masked array."
3789+
if self._delegate_binop(other):
3790+
return NotImplemented
37733791
return divide(self, other)
37743792
#
37753793
def __truediv__(self, other):
37763794
"Divide other into self, and return a new masked array."
3795+
if self._delegate_binop(other):
3796+
return NotImplemented
37773797
return true_divide(self, other)
37783798
#
37793799
def __rtruediv__(self, other):
@@ -3782,6 +3802,8 @@ def __rtruediv__(self, other):
37823802
#
37833803
def __floordiv__(self, other):
37843804
"Divide other into self, and return a new masked array."
3805+
if self._delegate_binop(other):
3806+
return NotImplemented
37853807
return floor_divide(self, other)
37863808
#
37873809
def __rfloordiv__(self, other):
@@ -3790,6 +3812,8 @@ def __rfloordiv__(self, other):
37903812
#
37913813
def __pow__(self, other):
37923814
"Raise self to the power other, masking the potential NaNs/Infs"
3815+
if self._delegate_binop(other):
3816+
return NotImplemented
37933817
return power(self, other)
37943818
#
37953819
def __rpow__(self, other):

numpy/ma/tests/test_core.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import sys
1313
import pickle
1414
from functools import reduce
15+
import operator
1516

1617
from nose.tools import assert_raises
1718

@@ -1691,16 +1692,15 @@ def test_ndarray_mask(self):
16911692
self.assertTrue(not isinstance(test.mask, MaskedArray))
16921693

16931694
def test_treatment_of_NotImplemented(self):
1694-
# Check any NotImplemented returned by umath.<ufunc> is passed on
1695+
# Check that NotImplemented is returned at appropriate places
1696+
16951697
a = masked_array([1., 2.], mask=[1, 0])
1696-
# basic tests for _MaskedBinaryOperation
1697-
assert_(a.__mul__('abc') is NotImplemented)
1698-
assert_(multiply.outer(a, 'abc') is NotImplemented)
1699-
# and for _DomainedBinaryOperation
1700-
assert_(a.__div__('abc') is NotImplemented)
1701-
1702-
# also check explicitly that rmul of another class can be accessed
1703-
class MyClass(str):
1698+
self.assertRaises(TypeError, operator.mul, a, "abc")
1699+
self.assertRaises(TypeError, operator.truediv, a, "abc")
1700+
1701+
class MyClass(object):
1702+
__array_priority__ = a.__array_priority__ + 1
1703+
17041704
def __mul__(self, other):
17051705
return "My mul"
17061706

0 commit comments

Comments
 (0)