Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions numpy/_core/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -1187,6 +1187,7 @@ src_multiarray = multiarray_gen_headers + [
'src/multiarray/shape.c',
'src/multiarray/strfuncs.c',
'src/multiarray/stringdtype/casts.cpp',
'src/multiarray/stringdtype/sorts.c',
'src/multiarray/stringdtype/dtype.c',
'src/multiarray/stringdtype/utf8_utils.c',
'src/multiarray/stringdtype/static_string.c',
Expand Down
14 changes: 7 additions & 7 deletions numpy/_core/src/multiarray/array_method.c
Original file line number Diff line number Diff line change
Expand Up @@ -342,13 +342,13 @@ fill_arraymethod_from_slots(
return -1;
}
}
if (i >= meth->nin && NPY_DT_is_parametric(res->dtypes[i])) {
PyErr_Format(PyExc_TypeError,
"must provide a `resolve_descriptors` function if any "
"output DType is parametric. (method: %s)",
spec->name);
return -1;
}
// if (i >= meth->nin && NPY_DT_is_parametric(res->dtypes[i])) {
// PyErr_Format(PyExc_TypeError,
// "must provide a `resolve_descriptors` function if any "
// "output DType is parametric. (method: %s)",
// spec->name);
// return -1;
// }
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be merged!!! I noticed that we actually cannot skip resolve_descriptors for any parametric dtype, at least for sorts, because the output is parametric. Should we special-case somehow?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't follow all the discussions that led to using an ArrayMethod here. I think you explained elsewhere but I'm not finding it: why don't the sort ArrayMethods care about resolve_descriptors? Because sorts have nin = nout = 1 and the input is the output, or something along those lines? If that's what it is, maybe this check needs to be conditional on the input and output descriptors so that trivial ArrayMethods can skip it.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's because the output is essentially fake for sorts (we enforce data[0] == data[1]. Do you mean we can check if that is the case here and only error if it's not? If so makes sense -- I'll do that, thank you!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I just realized this is initialization, so we don't actually have the descriptors. Not sure what the condition could be...

}
}
if (meth->get_strided_loop != &npy_default_get_strided_loop) {
Expand Down
16 changes: 8 additions & 8 deletions numpy/_core/src/multiarray/item_selection.c
Original file line number Diff line number Diff line change
Expand Up @@ -1227,7 +1227,7 @@ _new_sortlike(PyArrayObject *op, int axis, PyArray_SortFunc *sort,
return 0;
}

if (method_flags != NULL) {
if (strided_loop != NULL) {
needs_api = *method_flags & NPY_METH_REQUIRES_PYAPI;
}
else {
Expand Down Expand Up @@ -1441,7 +1441,7 @@ _new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFunc *argsort,
rstride = PyArray_STRIDE(rop, axis);
needidxbuffer = rstride != sizeof(npy_intp);

if (method_flags != NULL) {
if (strided_loop != NULL) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are the two changes above bugfixes?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, bug fixes. This adjusts for the fact (below) that method_flags should actually never point to NULL, so it can be set in get_strided_loop.

needs_api = *method_flags & NPY_METH_REQUIRES_PYAPI;
}
else {
Expand Down Expand Up @@ -3142,7 +3142,7 @@ PyArray_Sort(PyArrayObject *op, int axis, NPY_SORTKIND flags)
PyArrayMethod_Context context = {0};
PyArray_Descr *loop_descrs[2];
NpyAuxData *auxdata = NULL;
NPY_ARRAYMETHOD_FLAGS *method_flags = NULL;
NPY_ARRAYMETHOD_FLAGS method_flags = 0;

PyArray_SortFunc **sort_table = NULL;
PyArray_SortFunc *sort = NULL;
Expand Down Expand Up @@ -3184,7 +3184,7 @@ PyArray_Sort(PyArrayObject *op, int axis, NPY_SORTKIND flags)
npy_intp strides[2] = {loop_descrs[0]->elsize, loop_descrs[1]->elsize};

if (sort_method->get_strided_loop(
&context, 1, 0, strides, &strided_loop, &auxdata, method_flags) < 0) {
&context, 1, 0, strides, &strided_loop, &auxdata, &method_flags) < 0) {
ret = -1;
goto fail;
}
Expand Down Expand Up @@ -3229,7 +3229,7 @@ PyArray_Sort(PyArrayObject *op, int axis, NPY_SORTKIND flags)
}

ret = _new_sortlike(op, axis, sort, strided_loop,
&context, auxdata, method_flags, NULL, NULL, 0);
&context, auxdata, &method_flags, NULL, NULL, 0);

fail:
if (sort_method != NULL) {
Expand Down Expand Up @@ -3259,7 +3259,7 @@ PyArray_ArgSort(PyArrayObject *op, int axis, NPY_SORTKIND flags)
PyArrayMethod_Context context = {0};
PyArray_Descr *loop_descrs[2];
NpyAuxData *auxdata = NULL;
NPY_ARRAYMETHOD_FLAGS *method_flags = NULL;
NPY_ARRAYMETHOD_FLAGS method_flags = 0;

PyArray_ArgSortFunc **argsort_table = NULL;
PyArray_ArgSortFunc *argsort = NULL;
Expand Down Expand Up @@ -3295,7 +3295,7 @@ PyArray_ArgSort(PyArrayObject *op, int axis, NPY_SORTKIND flags)
npy_intp strides[2] = {loop_descrs[0]->elsize, loop_descrs[1]->elsize};

if (argsort_method->get_strided_loop(
&context, 1, 0, strides, &strided_loop, &auxdata, method_flags) < 0) {
&context, 1, 0, strides, &strided_loop, &auxdata, &method_flags) < 0) {
ret = NULL;
goto fail;
}
Expand Down Expand Up @@ -3346,7 +3346,7 @@ PyArray_ArgSort(PyArrayObject *op, int axis, NPY_SORTKIND flags)
}

ret = _new_argsortlike(op2, axis, argsort, strided_loop,
&context, auxdata, method_flags, NULL, NULL, 0);
&context, auxdata, &method_flags, NULL, NULL, 0);
Py_DECREF(op2);

fail:
Expand Down
Loading
Loading