Description
In astropy, we have a Quantity class, a subclass of ndarray that represents physical quantities. One of its features is that if you compare two of them and their units aren't equivalent, the comparison returns False.
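As a rough illustration of the intended semantics (not astropy's actual implementation), the sketch below compares two value/unit pairs elementwise when the units match and returns a plain False otherwise. The function name `equal_with_units` is hypothetical, and unit equivalence is simplified to string equality; real unit handling would also accept convertible units.

```python
import numpy as np


def equal_with_units(value_a, unit_a, value_b, unit_b):
    """Hypothetical helper: elementwise equality only for matching units.

    Unit "equivalence" is simplified to string equality here; a real
    implementation would also accept convertible units (e.g. 'm' vs 'cm').
    """
    if unit_a != unit_b:
        # Non-equivalent units: the whole comparison is simply False.
        return False
    # Matching units: fall back to ordinary elementwise comparison.
    return np.asarray(value_a) == np.asarray(value_b)


print(equal_with_units([10.0, 20.0], 'm', [10.0, 25.0], 'm'))  # elementwise True, False
print(equal_with_units([10.0, 20.0], 'm', [10.0, 20.0], 's'))  # False
```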
In NumPy master (but not NumPy 1.8.x), as of 9b8f6c7 (@seberg), there seems to be no longer any way to implement this, because an exception raised in __array_prepare__ now bubbles up through ndarray.richcompare (for now it emits a DeprecationWarning, but I understand it will eventually pass the original exception through). Since __array_prepare__ can only return an ndarray instance (or an instance of a subclass), there doesn't seem to be a way to signal to the comparison operator that what we really want is NotImplemented. Overriding __eq__ and the other comparison methods doesn't seem to be sufficient either, since that can't handle the case where a (regular) array is on the left-hand side.
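NotImplemented is the desired signal because of Python's own comparison protocol: if the left operand's `__eq__` returns NotImplemented, Python tries the right operand's reflected `__eq__`, and if that also returns NotImplemented, `==` falls back to an identity check, which gives False for distinct objects. The pure-Python class `Tagged` below is a hypothetical stand-in (no NumPy involved) that only demonstrates this fallback.

```python
class Tagged(object):
    """Hypothetical stand-in for a unit-carrying value (not an ndarray)."""

    def __init__(self, value, unit):
        self.value = value
        self.unit = unit

    def __eq__(self, other):
        # Only compare against another Tagged with the same unit;
        # otherwise tell Python we don't know how to compare.
        if not isinstance(other, Tagged) or other.unit != self.unit:
            return NotImplemented
        return self.value == other.value

    def __ne__(self, other):
        # Needed on Python 2; Python 3 derives != from __eq__ automatically.
        result = self.__eq__(other)
        return result if result is NotImplemented else not result


t = Tagged(10, 'm')
print(10 == t)               # int and Tagged both return NotImplemented -> False
print(t == Tagged(10, 'm'))  # True
print(t == Tagged(10, 's'))  # both sides decline, identity fallback -> False
```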
Here's a minimal script to reproduce the issue (the actual code is much more complex, but this boils it down to the essence of the problem):
```python
import warnings
# Escalate DeprecationWarnings to errors so the failure is visible.
warnings.filterwarnings("error", ".*", DeprecationWarning)

import numpy as np


class Quantity(np.ndarray):

    def __repr__(self):
        return '<{0} {1}>'.format(
            np.ndarray.__repr__(self), self._unit)

    def __new__(cls, value, unit, dtype=None):
        value = np.array(value, dtype=dtype)
        value = value.view(cls)
        value._unit = unit
        return value

    def __array_prepare__(self, obj, context=None):
        # context[0] is the ufunc, context[1] its arguments; keep only the inputs.
        function = context[0]
        args = context[1][:function.nin]
        # Refuse to operate unless both inputs are Quantities with our unit.
        if ((not isinstance(args[0], Quantity) or args[0]._unit != self._unit) or
                (not isinstance(args[1], Quantity) or args[1]._unit != self._unit)):
            raise ValueError("Units don't match")
        if not isinstance(obj, np.ndarray):
            obj = np.array(obj)
        view = obj.view(Quantity)
        view._unit = self._unit
        return view

    # def __eq__(self, other):
    #     try:
    #         return super(Quantity, self).__eq__(other)
    #     except (DeprecationWarning, ValueError):
    #         return NotImplemented


q = Quantity([10.0, 20.0], 'm')
print 10 == q
print np.int64(10) == q
```
The output on NumPy 1.8.x is:
```
False
False
```
The output on NumPy master is:
```
Traceback (most recent call last):
  File "quantity_compare.py", line 40, in <module>
    print 10 == q
DeprecationWarning: elementwise comparison failed; this will raise the error in the future.
```
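The traceback appears because the script's `warnings.filterwarnings("error", ".*", DeprecationWarning)` call escalates every DeprecationWarning into an exception. The sketch below reproduces that mechanism in isolation; `emit_deprecation` and `raises_with_filter` are hypothetical helpers standing in for NumPy's internal warning, and `warnings.catch_warnings()` keeps the filter change local.

```python
import warnings


def emit_deprecation():
    # Hypothetical stand-in for the warning NumPy emits internally.
    warnings.warn("elementwise comparison failed; this will raise the "
                  "error in the future.", DeprecationWarning)


def raises_with_filter(action):
    """Return True if emit_deprecation() raises under the given filter action."""
    with warnings.catch_warnings():
        # Same call shape as in the script: (action, message regex, category).
        warnings.filterwarnings(action, ".*", DeprecationWarning)
        try:
            emit_deprecation()
        except DeprecationWarning:
            return True
    return False


print(raises_with_filter("error"))   # True: the warning becomes an exception
print(raises_with_filter("ignore"))  # False: the warning is silently dropped
```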
If you uncomment the __eq__ method above, the first comparison passes, but the second still fails.
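One plausible reading of why the two cases differ (an assumption, not verified against NumPy internals here): for `a == b`, Python gives the right operand's reflected `__eq__` priority only when `type(b)` is a subclass of `type(a)`; otherwise `a.__eq__` runs first. For `10 == q`, `int.__eq__` returns NotImplemented, so the overridden `Quantity.__eq__` gets its turn. For `np.int64(10) == q`, `Quantity` is not a subclass of `np.int64`, so the scalar's own comparison runs first and presumably goes straight into the ufunc machinery, where `__array_prepare__` raises. The classes `Base`, `Sub` and `Other` below are hypothetical and only record which `__eq__` Python calls first.

```python
calls = []


class Base(object):
    def __eq__(self, other):
        calls.append(type(self).__name__)
        return NotImplemented


class Sub(Base):
    # Overrides __eq__, like Quantity overriding ndarray.__eq__.
    def __eq__(self, other):
        calls.append(type(self).__name__)
        return NotImplemented


class Other(object):
    # Unrelated left-hand type, like np.int64 relative to Quantity.
    def __eq__(self, other):
        calls.append(type(self).__name__)
        return NotImplemented


def first_called(left, right):
    """Return the name of the class whose __eq__ Python tries first."""
    del calls[:]
    left == right
    return calls[0]


print(first_called(Base(), Sub()))   # 'Sub': right-hand subclass gets priority
print(first_called(Other(), Sub()))  # 'Other': unrelated left operand goes first
```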
Is there a workaround here, or something I'm missing?