Skip to content

Commit eb642f1

Browse files
committed
ENH: For compatibility, use an exception type that subclasses both original types
1 parent 370b650 commit eb642f1

File tree

3 files changed

+20
-3
lines changed

3 files changed

+20
-3
lines changed

numpy/core/_internal.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,3 +630,6 @@ def _gcd(a, b):
630630
# Exception used in shares_memory()
631631
class TooHardError(RuntimeError):
632632
pass
633+
634+
class AxisError(ValueError, IndexError):
635+
pass

numpy/core/numeric.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
ERR_DEFAULT, PINF, NAN)
2828
from . import numerictypes
2929
from .numerictypes import longlong, intc, int_, float_, complex_, bool_
30-
from ._internal import TooHardError
30+
from ._internal import TooHardError, AxisError
3131

3232
bitwise_not = invert
3333
ufunc = type(sin)
@@ -65,7 +65,7 @@
6565
'True_', 'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE',
6666
'ALLOW_THREADS', 'ComplexWarning', 'full', 'full_like', 'matmul',
6767
'shares_memory', 'may_share_memory', 'MAY_SHARE_BOUNDS', 'MAY_SHARE_EXACT',
68-
'TooHardError',
68+
'TooHardError', 'AxisError'
6969
]
7070

7171

numpy/core/src/multiarray/common.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,21 @@ check_and_adjust_axis(int *axis, int ndim)
144144
{
145145
/* Check that index is valid, taking into account negative indices */
146146
if (NPY_UNLIKELY((*axis < -ndim) || (*axis >= ndim))) {
147-
PyErr_Format(PyExc_IndexError,
147+
/*
148+
* Load the exception type, if we don't already have it. Unfortunately
149+
* we don't have access to npy_cache_import here
150+
*/
151+
static PyObject *AxisError_cls = NULL;
152+
if (AxisError_cls == NULL) {
153+
PyObject *mod = PyImport_ImportModule("numpy.core._internal");
154+
155+
if (mod != NULL) {
156+
AxisError_cls = PyObject_GetAttrString(mod, "AxisError");
157+
Py_DECREF(mod);
158+
}
159+
}
160+
161+
PyErr_Format(AxisError_cls,
148162
"axis %d is out of bounds for array of dimension %d",
149163
*axis, ndim);
150164
return -1;

0 commit comments

Comments
 (0)