Skip to content

Commit 7315145

Browse files
authored
ENH: __array_function__ support for np.lib, part 2/2 (numpy#12119)
* ENH: __array_function__ support for np.lib, part 2 xref GH12028 np.lib.npyio through np.lib.ufunclike * Fix failures in numpy/core/tests/test_overrides.py * CLN: handle depreaction in dispatchers for np.lib.ufunclike * CLN: fewer dispatchers in lib.twodim_base * CLN: fewer dispatchers in lib.shape_base * CLN: more dispatcher consolidation * BUG: fix test failure * Use all method instead of function in assert_equal * DOC: indicate n is array_like in scimath.logn * MAINT: updates per review * MAINT: more conservative changes in assert_array_equal * MAINT: add back in comment * MAINT: casting tweaks in assert_array_equal * MAINT: fixes and tests for assert_array_equal on subclasses
1 parent 2bdb732 commit 7315145

File tree

12 files changed

+466
-31
lines changed

12 files changed

+466
-31
lines changed

numpy/core/tests/test_overrides.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ def _get_overloaded_args(relevant_args):
1616
return args
1717

1818

19-
def _return_self(self, *args, **kwargs):
20-
return self
19+
def _return_not_implemented(self, *args, **kwargs):
20+
return NotImplemented
2121

2222

2323
class TestGetOverloadedTypesAndArgs(object):
@@ -45,7 +45,7 @@ def test_ndarray(self):
4545
def test_ndarray_subclasses(self):
4646

4747
class OverrideSub(np.ndarray):
48-
__array_function__ = _return_self
48+
__array_function__ = _return_not_implemented
4949

5050
class NoOverrideSub(np.ndarray):
5151
pass
@@ -70,7 +70,7 @@ class NoOverrideSub(np.ndarray):
7070
def test_ndarray_and_duck_array(self):
7171

7272
class Other(object):
73-
__array_function__ = _return_self
73+
__array_function__ = _return_not_implemented
7474

7575
array = np.array(1)
7676
other = Other()
@@ -86,10 +86,10 @@ class Other(object):
8686
def test_ndarray_subclass_and_duck_array(self):
8787

8888
class OverrideSub(np.ndarray):
89-
__array_function__ = _return_self
89+
__array_function__ = _return_not_implemented
9090

9191
class Other(object):
92-
__array_function__ = _return_self
92+
__array_function__ = _return_not_implemented
9393

9494
array = np.array(1)
9595
subarray = np.array(1).view(OverrideSub)
@@ -103,16 +103,16 @@ class Other(object):
103103
def test_many_duck_arrays(self):
104104

105105
class A(object):
106-
__array_function__ = _return_self
106+
__array_function__ = _return_not_implemented
107107

108108
class B(A):
109-
__array_function__ = _return_self
109+
__array_function__ = _return_not_implemented
110110

111111
class C(A):
112-
__array_function__ = _return_self
112+
__array_function__ = _return_not_implemented
113113

114114
class D(object):
115-
__array_function__ = _return_self
115+
__array_function__ = _return_not_implemented
116116

117117
a = A()
118118
b = B()
@@ -135,7 +135,7 @@ class TestNDArrayArrayFunction(object):
135135
def test_method(self):
136136

137137
class SubOverride(np.ndarray):
138-
__array_function__ = _return_self
138+
__array_function__ = _return_not_implemented
139139

140140
class NoOverrideSub(np.ndarray):
141141
pass
@@ -189,7 +189,8 @@ def __array_function__(self, func, types, args, kwargs):
189189
assert_(obj is original)
190190
assert_(func is dispatched_one_arg)
191191
assert_equal(set(types), {MyArray})
192-
assert_equal(args, (original,))
192+
# assert_equal uses the overloaded np.iscomplexobj() internally
193+
assert_(args == (original,))
193194
assert_equal(kwargs, {})
194195

195196
def test_not_implemented(self):

numpy/lib/npyio.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from . import format
1313
from ._datasource import DataSource
1414
from numpy.core.multiarray import packbits, unpackbits
15+
from numpy.core.overrides import array_function_dispatch
1516
from numpy.core._internal import recursive
1617
from ._iotools import (
1718
LineSplitter, NameValidator, StringConverter, ConverterError,
@@ -447,6 +448,11 @@ def load(file, mmap_mode=None, allow_pickle=True, fix_imports=True,
447448
fid.close()
448449

449450

451+
def _save_dispatcher(file, arr, allow_pickle=None, fix_imports=None):
452+
return (arr,)
453+
454+
455+
@array_function_dispatch(_save_dispatcher)
450456
def save(file, arr, allow_pickle=True, fix_imports=True):
451457
"""
452458
Save an array to a binary file in NumPy ``.npy`` format.
@@ -525,6 +531,14 @@ def save(file, arr, allow_pickle=True, fix_imports=True):
525531
fid.close()
526532

527533

534+
def _savez_dispatcher(file, *args, **kwds):
535+
for a in args:
536+
yield a
537+
for v in kwds.values():
538+
yield v
539+
540+
541+
@array_function_dispatch(_savez_dispatcher)
528542
def savez(file, *args, **kwds):
529543
"""
530544
Save several arrays into a single file in uncompressed ``.npz`` format.
@@ -604,6 +618,14 @@ def savez(file, *args, **kwds):
604618
_savez(file, args, kwds, False)
605619

606620

621+
def _savez_compressed_dispatcher(file, *args, **kwds):
622+
for a in args:
623+
yield a
624+
for v in kwds.values():
625+
yield v
626+
627+
628+
@array_function_dispatch(_savez_compressed_dispatcher)
607629
def savez_compressed(file, *args, **kwds):
608630
"""
609631
Save several arrays into a single file in compressed ``.npz`` format.
@@ -1154,6 +1176,13 @@ def tobytes_first(x, conv):
11541176
return X
11551177

11561178

1179+
def _savetxt_dispatcher(fname, X, fmt=None, delimiter=None, newline=None,
1180+
header=None, footer=None, comments=None,
1181+
encoding=None):
1182+
return (X,)
1183+
1184+
1185+
@array_function_dispatch(_savetxt_dispatcher)
11571186
def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='',
11581187
footer='', comments='# ', encoding=None):
11591188
"""

numpy/lib/polynomial.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414

1515
from numpy.core import (isscalar, abs, finfo, atleast_1d, hstack, dot, array,
1616
ones)
17+
from numpy.core.overrides import array_function_dispatch
1718
from numpy.lib.twodim_base import diag, vander
1819
from numpy.lib.function_base import trim_zeros
1920
from numpy.lib.type_check import iscomplex, real, imag, mintypecode
2021
from numpy.linalg import eigvals, lstsq, inv
2122

23+
2224
class RankWarning(UserWarning):
2325
"""
2426
Issued by `polyfit` when the Vandermonde matrix is rank deficient.
@@ -29,6 +31,12 @@ class RankWarning(UserWarning):
2931
"""
3032
pass
3133

34+
35+
def _poly_dispatcher(seq_of_zeros):
36+
return seq_of_zeros
37+
38+
39+
@array_function_dispatch(_poly_dispatcher)
3240
def poly(seq_of_zeros):
3341
"""
3442
Find the coefficients of a polynomial with the given sequence of roots.
@@ -145,6 +153,12 @@ def poly(seq_of_zeros):
145153

146154
return a
147155

156+
157+
def _roots_dispatcher(p):
158+
return p
159+
160+
161+
@array_function_dispatch(_roots_dispatcher)
148162
def roots(p):
149163
"""
150164
Return the roots of a polynomial with coefficients given in p.
@@ -229,6 +243,12 @@ def roots(p):
229243
roots = hstack((roots, NX.zeros(trailing_zeros, roots.dtype)))
230244
return roots
231245

246+
247+
def _polyint_dispatcher(p, m=None, k=None):
248+
return (p,)
249+
250+
251+
@array_function_dispatch(_polyint_dispatcher)
232252
def polyint(p, m=1, k=None):
233253
"""
234254
Return an antiderivative (indefinite integral) of a polynomial.
@@ -322,6 +342,12 @@ def polyint(p, m=1, k=None):
322342
return poly1d(val)
323343
return val
324344

345+
346+
def _polyder_dispatcher(p, m=None):
347+
return (p,)
348+
349+
350+
@array_function_dispatch(_polyder_dispatcher)
325351
def polyder(p, m=1):
326352
"""
327353
Return the derivative of the specified order of a polynomial.
@@ -390,6 +416,12 @@ def polyder(p, m=1):
390416
val = poly1d(val)
391417
return val
392418

419+
420+
def _polyfit_dispatcher(x, y, deg, rcond=None, full=None, w=None, cov=None):
421+
return (x, y, w)
422+
423+
424+
@array_function_dispatch(_polyfit_dispatcher)
393425
def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
394426
"""
395427
Least squares polynomial fit.
@@ -610,6 +642,11 @@ def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
610642
return c
611643

612644

645+
def _polyval_dispatcher(p, x):
646+
return (p, x)
647+
648+
649+
@array_function_dispatch(_polyval_dispatcher)
613650
def polyval(p, x):
614651
"""
615652
Evaluate a polynomial at specific values.
@@ -679,6 +716,12 @@ def polyval(p, x):
679716
y = y * x + p[i]
680717
return y
681718

719+
720+
def _binary_op_dispatcher(a1, a2):
721+
return (a1, a2)
722+
723+
724+
@array_function_dispatch(_binary_op_dispatcher)
682725
def polyadd(a1, a2):
683726
"""
684727
Find the sum of two polynomials.
@@ -739,6 +782,8 @@ def polyadd(a1, a2):
739782
val = poly1d(val)
740783
return val
741784

785+
786+
@array_function_dispatch(_binary_op_dispatcher)
742787
def polysub(a1, a2):
743788
"""
744789
Difference (subtraction) of two polynomials.
@@ -786,6 +831,7 @@ def polysub(a1, a2):
786831
return val
787832

788833

834+
@array_function_dispatch(_binary_op_dispatcher)
789835
def polymul(a1, a2):
790836
"""
791837
Find the product of two polynomials.
@@ -842,6 +888,12 @@ def polymul(a1, a2):
842888
val = poly1d(val)
843889
return val
844890

891+
892+
def _polydiv_dispatcher(u, v):
893+
return (u, v)
894+
895+
896+
@array_function_dispatch(_polydiv_dispatcher)
845897
def polydiv(u, v):
846898
"""
847899
Returns the quotient and remainder of polynomial division.

0 commit comments

Comments
 (0)