@@ -46,60 +46,13 @@ def _yield_comparisons(self, actual):
4646 raise NotImplementedError
4747
4848
49- class ApproxNumpyBase (ApproxBase ):
49+ class ApproxNumpy (ApproxBase ):
5050 """
5151 Perform approximate comparisons for numpy arrays.
52-
53- This class should not be used directly. Instead, the `inherit_ndarray()`
54- class method should be used to make a subclass that also inherits from
55- `np.ndarray`. This indirection is necessary because the object doing the
56- approximate comparison must inherit from `np.ndarray`, or it will only work
57- on the left side of the `==` operator. But importing numpy is relatively
58- expensive, so we also want to avoid that unless we actually have a numpy
59- array to compare.
60-
61- The reason why the approx object needs to inherit from `np.ndarray` has to
62- do with how python decides whether to call `a.__eq__()` or `b.__eq__()`
63- when it parses `a == b`. If `a` and `b` are not related by inheritance,
64- `a` gets priority. So as long as `a.__eq__` is defined, it will be called.
65- Because most implementations of `a.__eq__` end up calling `b.__eq__`, this
66- detail usually doesn't matter. However, `np.ndarray.__eq__` treats the
67- approx object as a scalar and builds a new array by comparing it to each
68- item in the original array. `b.__eq__` is called to compare against each
69- individual element in the array, but it has no way (that I can see) to
70- prevent the return value from being an boolean array, and boolean arrays
71- can't be used with assert because "the truth value of an array with more
72- than one element is ambiguous."
73-
74- The trick is that the priority rules change if `a` and `b` are related
75- by inheritance. Specifically, `b.__eq__` gets priority if `b` is a
76- subclass of `a`. So by inheriting from `np.ndarray`, we can guarantee that
77- `ApproxNumpy.__eq__` gets called no matter which side of the `==` operator
78- it appears on.
7952 """
8053
81- subclass = None
82-
83- @classmethod
84- def inherit_ndarray (cls ):
85- import numpy as np
86- assert not isinstance (cls , np .ndarray )
87-
88- if cls .subclass is None :
89- cls .subclass = type ('ApproxNumpy' , (cls , np .ndarray ), {})
90-
91- return cls .subclass
92-
93- def __new__ (cls , expected , rel = None , abs = None , nan_ok = False ):
94- """
95- Numpy uses __new__ (rather than __init__) to initialize objects.
96-
97- The `expected` argument must be a numpy array. This should be
98- ensured by the approx() delegator function.
99- """
100- obj = super (ApproxNumpyBase , cls ).__new__ (cls , ())
101- obj .__init__ (expected , rel , abs , nan_ok )
102- return obj
54+ # Tell numpy to use our `__eq__` operator instead of its.
55+ __array_priority__ = 100
10356
10457 def __repr__ (self ):
10558 # It might be nice to rewrite this function to account for the
@@ -113,7 +66,7 @@ def __eq__(self, actual):
11366 try :
11467 actual = np .asarray (actual )
11568 except :
116- raise ValueError ("cannot cast '{0}' to numpy.ndarray" .format (actual ))
69+ raise TypeError ("cannot compare '{0}' to numpy.ndarray" .format (actual ))
11770
11871 if actual .shape != self .expected .shape :
11972 return False
@@ -157,6 +110,9 @@ class ApproxSequence(ApproxBase):
157110 Perform approximate comparisons for sequences of numbers.
158111 """
159112
113+ # Tell numpy to use our `__eq__` operator instead of its.
114+ __array_priority__ = 100
115+
160116 def __repr__ (self ):
161117 seq_type = type (self .expected )
162118 if seq_type not in (tuple , list , set ):
@@ -422,9 +378,7 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
422378 # their keys, which is probably not what most people would expect.
423379
424380 if _is_numpy_array (expected ):
425- # Create the delegate class on the fly. This allow us to inherit from
426- # ``np.ndarray`` while still not importing numpy unless we need to.
427- cls = ApproxNumpyBase .inherit_ndarray ()
381+ cls = ApproxNumpy
428382 elif isinstance (expected , Mapping ):
429383 cls = ApproxMapping
430384 elif isinstance (expected , Sequence ) and not isinstance (expected , String ):
0 commit comments