No way to mark "NotImplemented" for comparison operations on subclasses #4709

@mdboom


In astropy, we have a Quantity class, a subclass of ndarray that represents physical quantities. One of its features is that comparing two quantities whose units aren't equivalent evaluates to False.

In Numpy master (but not Numpy 1.8.x), as of 9b8f6c7 (@seberg), there may no longer be a way to implement this. Raising an exception in __array_prepare__ now bubbles up through ndarray.richcompare. It currently emits a DeprecationWarning, but I understand it will eventually pass the original exception through. Since __array_prepare__ can only return an instance of ndarray (or a subclass), there doesn't seem to be a way to signal to the comparison operator that what we really want is NotImplemented.

Overriding __eq__ and the other comparison operators doesn't seem to be sufficient either, since that can't handle the case where a (regular) array is on the left-hand side.
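The following is a minimal pure-Python sketch of the rich-comparison protocol, with no NumPy involved. It shows what returning NotImplemented would buy us. Tagged is a hypothetical stand-in for a unit-carrying value, not astropy's actual Quantity. When an __eq__ returns NotImplemented, Python tries the reflected __eq__ of the other operand. If that also returns NotImplemented, == falls back to an identity check, which gives False for distinct objects.

```python
class Tagged:
    """Hypothetical unit-carrying value (not astropy's Quantity)."""

    def __init__(self, value, unit):
        self.value = value
        self.unit = unit

    def __eq__(self, other):
        # Only compare values when the other side carries the same unit;
        # otherwise tell Python "I can't handle this" instead of raising.
        if not isinstance(other, Tagged) or other.unit != self.unit:
            return NotImplemented
        return self.value == other.value


q = Tagged(10.0, 'm')
print(q == Tagged(10.0, 'm'))  # True: same unit, values compared
print(q == Tagged(10.0, 's'))  # False: both sides defer, == falls back to identity
print(10 == q)                 # False: int.__eq__ and Tagged.__eq__ both defer
```

NotImplemented lets each operand decline without raising. __array_prepare__ has no equivalent return value.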

Here's a minimal script to reproduce the issue (the actual code is much more complex, but this boils it down to the essence of the problem):

import warnings
warnings.filterwarnings("error", ".*", DeprecationWarning)

import numpy as np


class Quantity(np.ndarray):
    def __repr__(self):
        return '<{0} {1}>'.format(
            np.ndarray.__repr__(self), self._unit)

    def __new__(cls, value, unit, dtype=None):
        value = np.array(value, dtype=dtype)
        value = value.view(cls)
        value._unit = unit
        return value

    def __array_prepare__(self, obj, context=None):
        function = context[0]
        args = context[1][:function.nin]

        if ((not isinstance(args[0], Quantity) or args[0]._unit != self._unit) or
            (not isinstance(args[1], Quantity) or args[1]._unit != self._unit)):
            raise ValueError("Units don't match")

        if not isinstance(obj, np.ndarray):
            obj = np.array(obj)
        view = obj.view(Quantity)
        view._unit = self._unit
        return view

    # def __eq__(self, other):
    #     try:
    #         return super(Quantity, self).__eq__(other)
    #     except (DeprecationWarning, ValueError):
    #         return NotImplemented


q = Quantity([10.0, 20.0], 'm')
print 10 == q
print np.int64(10) == q

The output on Numpy 1.8.x is:

False
False

The output on Numpy master is:

Traceback (most recent call last):
  File "quantity_compare.py", line 40, in <module>
    print 10 == q
DeprecationWarning: elementwise comparison failed; this will raise the error in the future.

If you uncomment the __eq__ method above, the first comparison (10 == q) passes, but the second (np.int64(10) == q) still fails.
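The override only helps in the first case, and the reason can be sketched in plain Python under stated assumptions. Python consults the right operand's __eq__ only if the left operand's __eq__ returns NotImplemented, or if the right operand's type is a subclass of the left's. A Python int returns NotImplemented for an unknown type, so Quantity.__eq__ gets its chance. np.int64, judging by the reported failure, appears to run the comparison itself instead, so the error from __array_prepare__ escapes before Quantity.__eq__ is ever called. EagerScalar and Tagged below are hypothetical classes that mimic those two roles.

```python
class EagerScalar:
    """Hypothetical left operand that, like np.int64 in the report, handles
    the comparison itself and never returns NotImplemented."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Stands in for NumPy's elementwise-comparison machinery: the error
        # raised here (cf. __array_prepare__) propagates straight to the caller.
        raise ValueError("Units don't match")


class Tagged:
    """Hypothetical unit-carrying value whose __eq__ defers on mismatched units."""

    def __init__(self, value, unit):
        self.value = value
        self.unit = unit
        self.eq_calls = 0  # counts how often Python consults Tagged.__eq__

    def __eq__(self, other):
        self.eq_calls += 1
        if not isinstance(other, Tagged) or other.unit != self.unit:
            return NotImplemented
        return self.value == other.value


def compare(left, right):
    """Evaluate left == right, reporting an escaped error as a string."""
    try:
        return left == right
    except ValueError as exc:
        return f"error: {exc}"


q = Tagged(10.0, 'm')
print(compare(10, q), q.eq_calls)               # -> False 1  (int defers, Tagged.__eq__ runs once)
print(compare(EagerScalar(10), q), q.eq_calls)  # -> error: Units don't match 1  (Tagged.__eq__ never ran)
```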

Is there a workaround here, or something I'm missing?

@mhvk, @astrofrog
