Skip to content

Commit 14ff219

Browse files
authored
Merge pull request #9087 from eric-wieser/fix-ufunc-resolution
BUG: __array_ufunc__ should always be looked up on the type, never the instance
2 parents 82e923f + 8f9eeef commit 14ff219

File tree

10 files changed

+175
-122
lines changed

10 files changed

+175
-122
lines changed

numpy/core/_internal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,7 @@ def __init__(self, axis, ndim=None, msg_prefix=None):
688688
super(AxisError, self).__init__(msg)
689689

690690

691-
def array_ufunc_errmsg_formatter(ufunc, method, *inputs, **kwargs):
691+
def array_ufunc_errmsg_formatter(dummy, ufunc, method, *inputs, **kwargs):
692692
""" Format the error message for when __array_ufunc__ gives up. """
693693
args_string = ', '.join(['{!r}'.format(arg) for arg in inputs] +
694694
['{}={!r}'.format(k, v)

numpy/core/src/multiarray/common.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ PyArray_DTypeFromObjectHelper(PyObject *obj, int maxdims,
329329
}
330330

331331
/* The array interface */
332-
ip = PyArray_GetAttrString_SuppressException(obj, "__array_interface__");
332+
ip = PyArray_LookupSpecial_OnInstance(obj, "__array_interface__");
333333
if (ip != NULL) {
334334
if (PyDict_Check(ip)) {
335335
PyObject *typestr;
@@ -362,7 +362,7 @@ PyArray_DTypeFromObjectHelper(PyObject *obj, int maxdims,
362362
}
363363

364364
/* The array struct interface */
365-
ip = PyArray_GetAttrString_SuppressException(obj, "__array_struct__");
365+
ip = PyArray_LookupSpecial_OnInstance(obj, "__array_struct__");
366366
if (ip != NULL) {
367367
PyArrayInterface *inter;
368368
char buf[40];
@@ -397,7 +397,7 @@ PyArray_DTypeFromObjectHelper(PyObject *obj, int maxdims,
397397
#endif
398398

399399
/* The __array__ attribute */
400-
ip = PyArray_GetAttrString_SuppressException(obj, "__array__");
400+
ip = PyArray_LookupSpecial_OnInstance(obj, "__array__");
401401
if (ip != NULL) {
402402
Py_DECREF(ip);
403403
ip = PyObject_CallMethod(obj, "__array__", NULL);

numpy/core/src/multiarray/ctors.c

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,7 @@ discover_dimensions(PyObject *obj, int *maxndim, npy_intp *d, int check_it,
753753
}
754754

755755
/* obj has the __array_struct__ interface */
756-
e = PyArray_GetAttrString_SuppressException(obj, "__array_struct__");
756+
e = PyArray_LookupSpecial_OnInstance(obj, "__array_struct__");
757757
if (e != NULL) {
758758
int nd = -1;
759759
if (NpyCapsule_Check(e)) {
@@ -778,7 +778,7 @@ discover_dimensions(PyObject *obj, int *maxndim, npy_intp *d, int check_it,
778778
}
779779

780780
/* obj has the __array_interface__ interface */
781-
e = PyArray_GetAttrString_SuppressException(obj, "__array_interface__");
781+
e = PyArray_LookupSpecial_OnInstance(obj, "__array_interface__");
782782
if (e != NULL) {
783783
int nd = -1;
784784
if (PyDict_Check(e)) {
@@ -2062,7 +2062,7 @@ PyArray_FromStructInterface(PyObject *input)
20622062
PyArrayObject *ret;
20632063
char endian = NPY_NATBYTE;
20642064

2065-
attr = PyArray_GetAttrString_SuppressException(input, "__array_struct__");
2065+
attr = PyArray_LookupSpecial_OnInstance(input, "__array_struct__");
20662066
if (attr == NULL) {
20672067
return Py_NotImplemented;
20682068
}
@@ -2176,7 +2176,7 @@ PyArray_FromInterface(PyObject *origin)
21762176
npy_intp dims[NPY_MAXDIMS], strides[NPY_MAXDIMS];
21772177
int dataflags = NPY_ARRAY_BEHAVED;
21782178

2179-
iface = PyArray_GetAttrString_SuppressException(origin,
2179+
iface = PyArray_LookupSpecial_OnInstance(origin,
21802180
"__array_interface__");
21812181
if (iface == NULL) {
21822182
return Py_NotImplemented;
@@ -2409,7 +2409,7 @@ PyArray_FromArrayAttr(PyObject *op, PyArray_Descr *typecode, PyObject *context)
24092409
PyObject *new;
24102410
PyObject *array_meth;
24112411

2412-
array_meth = PyArray_GetAttrString_SuppressException(op, "__array__");
2412+
array_meth = PyArray_LookupSpecial_OnInstance(op, "__array__");
24132413
if (array_meth == NULL) {
24142414
return Py_NotImplemented;
24152415
}

numpy/core/src/multiarray/methods.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1024,7 +1024,7 @@ array_ufunc(PyArrayObject *self, PyObject *args, PyObject *kwds)
10241024
return NULL;
10251025
}
10261026
/* ndarray cannot handle overrides itself */
1027-
num_override_args = PyUFunc_WithOverride(normal_args, kwds, NULL);
1027+
num_override_args = PyUFunc_WithOverride(normal_args, kwds, NULL, NULL);
10281028
if (num_override_args == -1) {
10291029
return NULL;
10301030
}

numpy/core/src/multiarray/multiarraymodule.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ PyArray_GetPriority(PyObject *obj, double default_)
8484
return NPY_SCALAR_PRIORITY;
8585
}
8686

87-
ret = PyArray_GetAttrString_SuppressException(obj, "__array_priority__");
87+
ret = PyArray_LookupSpecial_OnInstance(obj, "__array_priority__");
8888
if (ret == NULL) {
8989
return default_;
9090
}

numpy/core/src/private/binop_override.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,15 +121,14 @@ binop_should_defer(PyObject *self, PyObject *other, int inplace)
121121
self == NULL ||
122122
Py_TYPE(self) == Py_TYPE(other) ||
123123
PyArray_CheckExact(other) ||
124-
PyArray_CheckAnyScalarExact(other) ||
125-
_is_basic_python_type(other)) {
124+
PyArray_CheckAnyScalarExact(other)) {
126125
return 0;
127126
}
128127
/*
129128
* Classes with __array_ufunc__ are living in the future, and only need to
130129
* check whether __array_ufunc__ equals None.
131130
*/
132-
attr = PyArray_GetAttrString_SuppressException(other, "__array_ufunc__");
131+
attr = PyArray_LookupSpecial(other, "__array_ufunc__");
133132
if (attr) {
134133
defer = !inplace && (attr == Py_None);
135134
Py_DECREF(attr);
Lines changed: 77 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,45 @@
11
#ifndef __GET_ATTR_STRING_H
22
#define __GET_ATTR_STRING_H
33

4-
static NPY_INLINE int
5-
_is_basic_python_type(PyObject * obj)
4+
static NPY_INLINE npy_bool
5+
_is_basic_python_type(PyTypeObject *tp)
66
{
7-
if (obj == Py_None ||
8-
PyBool_Check(obj) ||
9-
/* Basic number types */
7+
return (
8+
/* Basic number types */
9+
tp == &PyBool_Type ||
1010
#if !defined(NPY_PY3K)
11-
PyInt_CheckExact(obj) ||
12-
PyString_CheckExact(obj) ||
11+
tp == &PyInt_Type ||
1312
#endif
14-
PyLong_CheckExact(obj) ||
15-
PyFloat_CheckExact(obj) ||
16-
PyComplex_CheckExact(obj) ||
17-
/* Basic sequence types */
18-
PyList_CheckExact(obj) ||
19-
PyTuple_CheckExact(obj) ||
20-
PyDict_CheckExact(obj) ||
21-
PyAnySet_CheckExact(obj) ||
22-
PyUnicode_CheckExact(obj) ||
23-
PyBytes_CheckExact(obj) ||
24-
PySlice_Check(obj)) {
13+
tp == &PyLong_Type ||
14+
tp == &PyFloat_Type ||
15+
tp == &PyComplex_Type ||
2516

26-
return 1;
27-
}
17+
/* Basic sequence types */
18+
tp == &PyList_Type ||
19+
tp == &PyTuple_Type ||
20+
tp == &PyDict_Type ||
21+
tp == &PySet_Type ||
22+
tp == &PyFrozenSet_Type ||
23+
tp == &PyUnicode_Type ||
24+
tp == &PyBytes_Type ||
25+
#if !defined(NPY_PY3K)
26+
tp == &PyString_Type ||
27+
#endif
28+
29+
/* other builtins */
30+
tp == &PySlice_Type ||
31+
tp == Py_TYPE(Py_None) ||
32+
tp == Py_TYPE(Py_Ellipsis) ||
33+
tp == Py_TYPE(Py_NotImplemented) ||
34+
35+
/* TODO: ndarray, but we can't see PyArray_Type here */
2836

29-
return 0;
37+
/* sentinel to swallow trailing || */
38+
NPY_FALSE
39+
);
3040
}
3141

3242
/*
33-
* PyArray_GetAttrString_SuppressException:
34-
*
3543
* Stripped down version of PyObject_GetAttrString,
3644
* avoids lookups for None, tuple, and List objects,
3745
* and doesn't create a PyErr since this code ignores it.
@@ -43,19 +51,14 @@ _is_basic_python_type(PyObject * obj)
4351
*
4452
* 'name' is the attribute to search for.
4553
*
46-
* Returns attribute value on success, 0 on failure.
54+
* Returns attribute value on success, NULL on failure.
4755
*/
48-
static PyObject *
49-
PyArray_GetAttrString_SuppressException(PyObject *obj, char *name)
56+
static NPY_INLINE PyObject *
57+
maybe_get_attr(PyObject *obj, char *name)
5058
{
5159
PyTypeObject *tp = Py_TYPE(obj);
5260
PyObject *res = (PyObject *)NULL;
5361

54-
/* We do not need to check for special attributes on trivial types */
55-
if (_is_basic_python_type(obj)) {
56-
return NULL;
57-
}
58-
5962
/* Attribute referenced by (char *)name */
6063
if (tp->tp_getattr != NULL) {
6164
res = (*tp->tp_getattr)(obj, name);
@@ -82,4 +85,47 @@ PyArray_GetAttrString_SuppressException(PyObject *obj, char *name)
8285
return res;
8386
}
8487

88+
/*
89+
* Lookup a special method, following the python approach of looking up
90+
* on the type object, rather than on the instance itself.
91+
*
92+
* Assumes that the special method is a numpy-specific one, so does not look
93+
* at builtin types, nor does it look at a base ndarray.
94+
*
95+
* In future, could be made more like _Py_LookupSpecial
96+
*/
97+
static NPY_INLINE PyObject *
98+
PyArray_LookupSpecial(PyObject *obj, char *name)
99+
{
100+
PyTypeObject *tp = Py_TYPE(obj);
101+
102+
/* We do not need to check for special attributes on trivial types */
103+
if (_is_basic_python_type(tp)) {
104+
return NULL;
105+
}
106+
107+
return maybe_get_attr((PyObject *)tp, name);
108+
}
109+
110+
/*
111+
* PyArray_LookupSpecial_OnInstance:
112+
*
113+
* Implements incorrect special method lookup rules, that break the python
114+
* convention, and looks on the instance, not the type.
115+
*
116+
* Kept for backwards compatibility. In future, we should deprecate this.
117+
*/
118+
static NPY_INLINE PyObject *
119+
PyArray_LookupSpecial_OnInstance(PyObject *obj, char *name)
120+
{
121+
PyTypeObject *tp = Py_TYPE(obj);
122+
123+
/* We do not need to check for special attributes on trivial types */
124+
if (_is_basic_python_type(tp)) {
125+
return NULL;
126+
}
127+
128+
return maybe_get_attr(obj, name);
129+
}
130+
85131
#endif

numpy/core/src/private/ufunc_override.c

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,22 @@
1212
* is not the default, i.e., the object is not an ndarray, and its
1313
* __array_ufunc__ is not the same as that of ndarray.
1414
*
15+
* Returns a new reference, the value of type(obj).__array_ufunc__
16+
*
17+
* If the __array_ufunc__ matches that of ndarray, or does not exist, return
18+
* NULL.
19+
*
1520
* Note that since this module is used with both multiarray and umath, we do
1621
* not have access to PyArray_Type and therewith neither to PyArray_CheckExact
1722
* nor to the default __array_ufunc__ method, so instead we import locally.
1823
* TODO: Can this really not be done more smartly?
1924
*/
20-
static int
21-
has_non_default_array_ufunc(PyObject *obj)
25+
static PyObject *
26+
get_non_default_array_ufunc(PyObject *obj)
2227
{
2328
static PyObject *ndarray = NULL;
2429
static PyObject *ndarray_array_ufunc = NULL;
2530
PyObject *cls_array_ufunc;
26-
int non_default;
2731

2832
/* on first entry, import and cache ndarray and its __array_ufunc__ */
2933
if (ndarray == NULL) {
@@ -34,47 +38,33 @@ has_non_default_array_ufunc(PyObject *obj)
3438

3539
/* Fast return for ndarray */
3640
if ((PyObject *)Py_TYPE(obj) == ndarray) {
37-
return 0;
41+
return NULL;
3842
}
3943
/* does the class define __array_ufunc__? */
40-
cls_array_ufunc = PyArray_GetAttrString_SuppressException(
41-
(PyObject *)Py_TYPE(obj), "__array_ufunc__");
44+
cls_array_ufunc = PyArray_LookupSpecial(obj, "__array_ufunc__");
4245
if (cls_array_ufunc == NULL) {
43-
return 0;
46+
return NULL;
4447
}
4548
/* is it different from ndarray.__array_ufunc__? */
46-
non_default = (cls_array_ufunc != ndarray_array_ufunc);
49+
if (cls_array_ufunc != ndarray_array_ufunc) {
50+
return cls_array_ufunc;
51+
}
4752
Py_DECREF(cls_array_ufunc);
48-
return non_default;
49-
}
50-
51-
/*
52-
* Check whether an object sets __array_ufunc__ = None. The __array_func__
53-
* attribute must already be known to exist.
54-
*/
55-
static int
56-
disables_array_ufunc(PyObject *obj)
57-
{
58-
PyObject *array_ufunc;
59-
int disables;
60-
61-
array_ufunc = PyObject_GetAttrString(obj, "__array_ufunc__");
62-
disables = (array_ufunc == Py_None);
63-
Py_XDECREF(array_ufunc);
64-
return disables;
53+
return NULL;
6554
}
6655

6756
/*
6857
* Check whether a set of input and output args have a non-default
6958
* `__array_ufunc__` method. Return the number of overrides, setting
7059
* corresponding objects in PyObject array with_override (if not NULL)
71-
* using borrowed references.
60+
* using borrowed references, and the corresponding __array_ufunc__ methods
61+
* in methods, using new references
7262
*
7363
* returns -1 on failure.
7464
*/
7565
NPY_NO_EXPORT int
7666
PyUFunc_WithOverride(PyObject *args, PyObject *kwds,
77-
PyObject **with_override)
67+
PyObject **with_override, PyObject **methods)
7868
{
7969
int i;
8070

@@ -116,6 +106,7 @@ PyUFunc_WithOverride(PyObject *args, PyObject *kwds,
116106
}
117107

118108
for (i = 0; i < nargs + nout_kwd; ++i) {
109+
PyObject *method;
119110
if (i < nargs) {
120111
obj = PyTuple_GET_ITEM(args, i);
121112
}
@@ -132,22 +123,32 @@ PyUFunc_WithOverride(PyObject *args, PyObject *kwds,
132123
* ignore the base ndarray.__ufunc__, so we skip any ndarray as well as
133124
* any ndarray subclass instances that did not override __array_ufunc__.
134125
*/
135-
if (has_non_default_array_ufunc(obj)) {
136-
if (disables_array_ufunc(obj)) {
126+
method = get_non_default_array_ufunc(obj);
127+
if (method != NULL) {
128+
if (method == Py_None) {
137129
PyErr_Format(PyExc_TypeError,
138130
"operand '%.200s' does not support ufuncs "
139131
"(__array_ufunc__=None)",
140132
obj->ob_type->tp_name);
133+
Py_DECREF(method);
141134
goto fail;
142135
}
143136
if (with_override != NULL) {
144137
with_override[num_override_args] = obj;
145138
}
139+
if (methods != NULL) {
140+
methods[num_override_args] = method;
141+
}
146142
++num_override_args;
147143
}
148144
}
149145
return num_override_args;
150146

151147
fail:
148+
if (methods != NULL) {
149+
for (i = 0; i < num_override_args; i++) {
150+
Py_XDECREF(methods[i]);
151+
}
152+
}
152153
return -1;
153154
}

numpy/core/src/private/ufunc_override.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@
1111
*/
1212
NPY_NO_EXPORT int
1313
PyUFunc_WithOverride(PyObject *args, PyObject *kwds,
14-
PyObject **with_override);
14+
PyObject **with_override, PyObject **methods);
1515
#endif

0 commit comments

Comments
 (0)