Skip to content
66 changes: 42 additions & 24 deletions numpy/core/src/multiarray/buffer.c
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ static PyObject *_buffer_info_cache = NULL;

/* Fill in the info structure */
static _buffer_info_t*
_buffer_info_new(PyObject *obj, npy_bool f_contiguous)
_buffer_info_new(PyObject *obj, int flags)
{
/*
* Note that the buffer info is cached as PyLongObjects making them appear
Expand Down Expand Up @@ -514,6 +514,7 @@ _buffer_info_new(PyObject *obj, npy_bool f_contiguous)
* (This is unnecessary, but has no effect in the case where
* NPY_RELAXED_STRIDES CHECKING is disabled.)
*/
int f_contiguous = (flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS;
if (PyArray_IS_C_CONTIGUOUS(arr) && !(
f_contiguous && PyArray_IS_F_CONTIGUOUS(arr))) {
Py_ssize_t sd = PyArray_ITEMSIZE(arr);
Expand Down Expand Up @@ -547,16 +548,20 @@ _buffer_info_new(PyObject *obj, npy_bool f_contiguous)
}

/* Fill in format */
err = _buffer_format_string(descr, &fmt, obj, NULL, NULL);
Py_DECREF(descr);
if (err != 0) {
goto fail;
if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
err = _buffer_format_string(descr, &fmt, obj, NULL, NULL);
Py_DECREF(descr);
if (err != 0) {
goto fail;
}
if (_append_char(&fmt, '\0') < 0) {
goto fail;
}
info->format = fmt.s;
}
if (_append_char(&fmt, '\0') < 0) {
goto fail;
else {
info->format = NULL;
}
info->format = fmt.s;

return info;

fail:
Expand All @@ -572,9 +577,10 @@ _buffer_info_cmp(_buffer_info_t *a, _buffer_info_t *b)
Py_ssize_t c;
int k;

c = strcmp(a->format, b->format);
if (c != 0) return c;

if (a->format != NULL && b->format != NULL) {
c = strcmp(a->format, b->format);
if (c != 0) return c;
}
Comment on lines 580 to 583
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, this looks risky - are you sure it shouldn't be:

Suggested change
if (a->format != NULL && b->format != NULL) {
c = strcmp(a->format, b->format);
if (c != 0) return c;
}
/* null format sorts before empty string */
c = (a->format != NULL) - (b->format != NULL);
if (c != 0) return c;
if (a->format != NULL && b->format != NULL) {
c = strcmp(a->format, b->format);
if (c != 0) return c;
}

Otherwise NULL is considered equal to arbitrary strings, and equality is not transitive

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this should be correct. If the format is NULL, it seems OK to replace it with any other format. If the old (first) format is NULL, we replace that NULL with the actual (second) format. If the second format is NULL, the format is ignored completely, so that is fine as well.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An empty string would indicate that the itemsize is 0 in the exported buffer, I guess. A size change could seem problematic, but I do not think so.

In theory, changing the itemsize might be dangerous, but it is explicitly stored in the exported buffer info, so while it could be dangerous from a buffer user perspective, I do not think it is dangerous for having incorrect information inside the exported buffer information.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where does this replacing of NULL happen?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

c = a->ndim - b->ndim;
if (c != 0) return c;

Expand All @@ -599,7 +605,7 @@ _buffer_info_free(_buffer_info_t *info)

/* Get buffer info from the global dictionary */
static _buffer_info_t*
_buffer_get_info(PyObject *obj, npy_bool f_contiguous)
_buffer_get_info(PyObject *obj, int flags)
{
PyObject *key = NULL, *item_list = NULL, *item = NULL;
_buffer_info_t *info = NULL, *old_info = NULL;
Expand All @@ -612,7 +618,7 @@ _buffer_get_info(PyObject *obj, npy_bool f_contiguous)
}

/* Compute information */
info = _buffer_info_new(obj, f_contiguous);
info = _buffer_info_new(obj, flags);
if (info == NULL) {
return NULL;
}
Expand All @@ -630,11 +636,9 @@ _buffer_get_info(PyObject *obj, npy_bool f_contiguous)
if (item_list_length > 0) {
item = PyList_GetItem(item_list, item_list_length - 1);
old_info = (_buffer_info_t*)PyLong_AsVoidPtr(item);
if (_buffer_info_cmp(info, old_info) == 0) {
_buffer_info_free(info);
info = old_info;
}
else {
if (_buffer_info_cmp(info, old_info) != 0) {
old_info = NULL; /* Can't use this one, but possibly next */

if (item_list_length > 1 && info->ndim > 1) {
/*
* Some arrays are C- and F-contiguous and if they have more
Expand All @@ -648,12 +652,26 @@ _buffer_get_info(PyObject *obj, npy_bool f_contiguous)
*/
item = PyList_GetItem(item_list, item_list_length - 2);
old_info = (_buffer_info_t*)PyLong_AsVoidPtr(item);
if (_buffer_info_cmp(info, old_info) == 0) {
_buffer_info_free(info);
info = old_info;
if (_buffer_info_cmp(info, old_info) != 0) {
old_info = NULL;
}
}
}

if (old_info != NULL) {
/*
* The two info->format are considered equal if one of them
* has no format set (meaning the format is arbitrary and can
* be modified). If the new info has a format, but we reuse
* the old one, this transfers the ownership to the old one.
*/
if (old_info->format == NULL) {
old_info->format = info->format;
info->format = NULL;
}
_buffer_info_free(info);
info = old_info;
}
}
}
else {
Expand Down Expand Up @@ -760,7 +778,7 @@ array_getbuffer(PyObject *obj, Py_buffer *view, int flags)
}

/* Fill in information */
info = _buffer_get_info(obj, (flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS);
info = _buffer_get_info(obj, flags);
if (info == NULL) {
goto fail;
}
Expand Down Expand Up @@ -825,7 +843,7 @@ void_getbuffer(PyObject *self, Py_buffer *view, int flags)
}

/* Fill in information */
info = _buffer_get_info(self, 0);
info = _buffer_get_info(self, flags);
if (info == NULL) {
goto fail;
}
Expand Down
4 changes: 2 additions & 2 deletions numpy/core/src/multiarray/scalarapi.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ scalar_value(PyObject *scalar, PyArray_Descr *descr)
{
int type_num;
int align;
npy_intp memloc;
uintptr_t memloc;
if (descr == NULL) {
descr = PyArray_DescrFromScalar(scalar);
type_num = descr->type_num;
Expand Down Expand Up @@ -168,7 +168,7 @@ scalar_value(PyObject *scalar, PyArray_Descr *descr)
* Use the alignment flag to figure out where the data begins
* after a PyObject_HEAD
*/
memloc = (npy_intp)scalar;
memloc = (uintptr_t)scalar;
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like this cast to a signed type was the problem (at first I thought it couldn't be, because the denominator is stored as denom - 1 here, making the 2 a 1...).

If the value is too large, it becomes negative, which breaks the rounding (and also means that a previously aligned value looks unaligned).

memloc += sizeof(PyObject);
/* now round-up to the nearest alignment value */
align = descr->alignment;
Expand Down
45 changes: 45 additions & 0 deletions numpy/core/src/multiarray/scalartypes.c.src
Original file line number Diff line number Diff line change
Expand Up @@ -2383,6 +2383,50 @@ static PySequenceMethods voidtype_as_sequence = {
};


/*
 * Simple buffer export (the `bf_getbuffer` slot) for user defined subclasses
 * of `np.generic`. All other scalar types override the buffer export.
 *
 * Only a read-only, zero-dimensional view without a format string is
 * offered: a request that includes PyBUF_FORMAT fails with TypeError and a
 * request that includes PyBUF_WRITABLE fails with BufferError.
 *
 * On success 0 is returned and `view->obj` holds a new reference to `self`.
 * On failure -1 is returned with an exception set and `view` is not written.
 */
static int
gentype_arrtype_getbuffer(PyObject *self, Py_buffer *view, int flags)
{
    if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
        PyErr_Format(PyExc_TypeError,
                "NumPy scalar %R can only exported as a buffer without format.",
                self);
        return -1;
    }
    if ((flags & PyBUF_WRITABLE) == PyBUF_WRITABLE) {
        /* The view filled in below is always marked read-only. */
        PyErr_SetString(PyExc_BufferError, "scalar buffer is readonly");
        return -1;
    }

    PyArray_Descr *descr = PyArray_DescrFromScalar(self);
    if (descr == NULL) {
        return -1;
    }
    if (!PyDataType_ISUSERDEF(descr)) {
        /* This path would also reject the (hopefully) impossible "object" */
        PyErr_Format(PyExc_TypeError,
                "user-defined scalar %R registered for built-in dtype %S? "
                "This should be impossible.",
                self, descr);
        /* The descriptor reference must be released on this path as well. */
        Py_DECREF(descr);
        return -1;
    }

    /* Fill in every Py_buffer field so consumers never see stale values. */
    view->ndim = 0;
    view->len = descr->elsize;
    view->itemsize = descr->elsize;
    view->shape = NULL;
    view->strides = NULL;
    view->suboffsets = NULL;
    view->internal = NULL;
    view->readonly = 1;
    view->format = NULL;
    view->buf = scalar_value(self, descr);
    Py_DECREF(descr);

    /* The exported view keeps the scalar alive until it is released. */
    Py_INCREF(self);
    view->obj = self;
    return 0;
}


/*
 * Buffer protocol slots for the generic scalar type: only export is
 * implemented, no release hook is provided.
 */
static PyBufferProcs gentype_arrtype_as_buffer = {
    .bf_getbuffer = (getbufferproc)gentype_arrtype_getbuffer,
    .bf_releasebuffer = NULL,
};


/**begin repeat
* #name = bool, byte, short, int, long, longlong, ubyte, ushort, uint, ulong,
Expand Down Expand Up @@ -3794,6 +3838,7 @@ initialize_numeric_types(void)
PyGenericArrType_Type.tp_alloc = gentype_alloc;
PyGenericArrType_Type.tp_free = (freefunc)gentype_free;
PyGenericArrType_Type.tp_richcompare = gentype_richcompare;
PyGenericArrType_Type.tp_as_buffer = &gentype_arrtype_as_buffer;

PyBoolArrType_Type.tp_as_number = &bool_arrtype_as_number;
/*
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/src/umath/_rational_tests.c.src
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ static PyGetSetDef pyrational_getset[] = {

static PyTypeObject PyRational_Type = {
PyVarObject_HEAD_INIT(NULL, 0)
"rational", /* tp_name */
"numpy.core._rational_tests.rational", /* tp_name */
sizeof(PyRational), /* tp_basicsize */
0, /* tp_itemsize */
0, /* tp_dealloc */
Expand Down
16 changes: 16 additions & 0 deletions numpy/core/tests/test_multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import numpy as np
import numpy.core._multiarray_tests as _multiarray_tests
from numpy.core._rational_tests import rational
from numpy.testing import (
assert_, assert_raises, assert_warns, assert_equal, assert_almost_equal,
assert_array_equal, assert_raises_regex, assert_array_almost_equal,
Expand Down Expand Up @@ -7143,6 +7144,21 @@ def test_export_flags(self):
_multiarray_tests.get_buffer_info,
np.arange(5)[::2], ('SIMPLE',))

@pytest.mark.parametrize(
    ["obj", "error"],
    [
        pytest.param(
            np.array([1, 2], dtype=rational), ValueError, id="array"),
        pytest.param(rational(1, 2), TypeError, id="scalar"),
    ],
)
def test_export_and_pickle_user_dtype(self, obj, error):
    # Requesting FORMAT for a user dtype must fail with the parametrized
    # error, while the same export without FORMAT must succeed.
    with_format = ("STRIDED", "FORMAT")
    without_format = ("STRIDED",)

    with pytest.raises(error):
        _multiarray_tests.get_buffer_info(obj, with_format)

    _multiarray_tests.get_buffer_info(obj, without_format)

    # Buffer export is currently also needed for pickling, so check that a
    # pickle round trip reproduces the object.
    roundtripped = pickle.loads(pickle.dumps(obj))
    assert_array_equal(roundtripped, obj)

def test_padding(self):
for j in range(8):
x = np.array([(1,), (2,)], dtype={'f0': (int, j)})
Expand Down