Skip to content

Commit 87f410b

Browse files
committed
ENH: Allocate lock only once in StringDType quicksort
1 parent 1654bd1 commit 87f410b

File tree

3 files changed

+49
-3
lines changed

3 files changed

+49
-3
lines changed

numpy/_core/src/multiarray/stringdtype/dtype.c

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,33 @@ compare(void *a, void *b, void *arr)
459459
return ret;
460460
}
461461

462+
/*
 * Sort comparison callback that compares elements `a` and `b` by delegating
 * to _compare(), passing the descriptor of `arr` for both operands.
 *
 * This wrapper does not acquire or release the string allocator itself.
 * _compare() is documented in dtype.h as assuming the caller already holds
 * the allocator locks, so the caller must have acquired the allocator for
 * arr's descriptor before invoking this function.
 *
 * Returns whatever _compare() returns.
 */
int
463+
_compare_no_mutex(const void *a, const void *b, void *arr)
464+
{
465+
PyArray_StringDTypeObject *descr = (PyArray_StringDTypeObject *)PyArray_DESCR(arr);
466+
/* const is cast away because _compare() takes non-const void pointers. */
return _compare((void*) a, (void*) b, descr, descr);
467+
}
468+
469+
/*
 * Select the comparison function a sort should use for `descr` and store it
 * in `*out_cmp`.
 *
 * For NPY_VSTRING descriptors the string allocator is acquired here and
 * `*out_cmp` is set to _compare_no_mutex, which does not acquire it again.
 * For every other dtype `*out_cmp` is the `compare` entry of the dtype's
 * ArrFuncs and nothing is acquired.
 *
 * Because the NPY_VSTRING branch acquires the allocator, every call must be
 * matched by a call to _end_sort_cmp() with the same `descr` on every exit
 * path of the caller, including early returns.
 */
void
470+
_init_sort_cmp(PyArray_Descr *descr, PyArray_CompareFunc **out_cmp)
471+
{
472+
if (descr->type_num == NPY_VSTRING) {
473+
/* Return value is intentionally unused; _end_sort_cmp() releases via descr->allocator. */
NpyString_acquire_allocator((PyArray_StringDTypeObject *)descr);
474+
*out_cmp = _compare_no_mutex;
475+
}
476+
else {
477+
*out_cmp = PyDataType_GetArrFuncs(descr)->compare;
478+
}
479+
}
480+
481+
/*
 * Counterpart of _init_sort_cmp(): for NPY_VSTRING descriptors, release the
 * descriptor's string allocator. Does nothing for any other dtype.
 *
 * Must only be called after a matching _init_sort_cmp() on the same `descr`.
 */
void
482+
_end_sort_cmp(PyArray_Descr *descr)
483+
{
484+
if (descr->type_num == NPY_VSTRING) {
485+
NpyString_release_allocator(((PyArray_StringDTypeObject *)descr)->allocator);
486+
}
487+
}
488+
462489
int
463490
_compare(void *a, void *b, PyArray_StringDTypeObject *descr_a,
464491
PyArray_StringDTypeObject *descr_b)

numpy/_core/src/multiarray/stringdtype/dtype.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,21 @@ new_stringdtype_instance(PyObject *na_object, int coerce);
1919
NPY_NO_EXPORT int
2020
init_string_dtype(void);
2121

22+
// Comparison callback that forwards to _compare() using the descriptor of
// `arr` for both operands. It does not acquire the allocator itself, so the
// caller must already hold it.
NPY_NO_EXPORT int
23+
_compare_no_mutex(const void *a, const void *b, void *arr);
24+
2225
// Assumes that the caller has already acquired the allocator locks for both
2326
// descriptors
2427
NPY_NO_EXPORT int
2528
_compare(void *a, void *b, PyArray_StringDTypeObject *descr_a,
2629
PyArray_StringDTypeObject *descr_b);
2730

31+
// Stores the comparison function to use for sorting `descr` in `*out_cmp`.
// For NPY_VSTRING this also acquires the string allocator, so each call must
// be paired with _end_sort_cmp() on every exit path.
NPY_NO_EXPORT void
32+
_init_sort_cmp(PyArray_Descr *descr, PyArray_CompareFunc **out_cmp);
33+
34+
// Releases the string allocator acquired by _init_sort_cmp() when `descr` is
// NPY_VSTRING; does nothing for other dtypes.
NPY_NO_EXPORT void
35+
_end_sort_cmp(PyArray_Descr *descr);
36+
2837
NPY_NO_EXPORT int
2938
init_string_na_object(PyObject *mod);
3039

numpy/_core/src/npysort/quicksort.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
* the below code implements this converted to an iteration and as an
4545
* additional minor optimization skips the recursion depth checking on the
4646
* smaller partition as it is always less than half of the remaining data and
47-
* will thus terminate fast enough
47+
* will thus terminate fast enough
4848
*/
4949

5050
#define NPY_NO_DEPRECATED_API NPY_API_VERSION
@@ -56,6 +56,7 @@
5656
#include "numpy_tag.h"
5757
#include "x86_simd_qsort.hpp"
5858
#include "highway_qsort.hpp"
59+
#include "stringdtype/dtype.h"
5960

6061
#include <cstdlib>
6162
#include <utility>
@@ -510,7 +511,7 @@ npy_quicksort(void *start, npy_intp num, void *varr)
510511
{
511512
PyArrayObject *arr = (PyArrayObject *)varr;
512513
npy_intp elsize = PyArray_ITEMSIZE(arr);
513-
PyArray_CompareFunc *cmp = PyDataType_GetArrFuncs(PyArray_DESCR(arr))->compare;
514+
PyArray_CompareFunc *cmp;
514515
char *vp;
515516
char *pl = (char *)start;
516517
char *pr = pl + (num - 1) * elsize;
@@ -521,6 +522,8 @@ npy_quicksort(void *start, npy_intp num, void *varr)
521522
int *psdepth = depth;
522523
int cdepth = npy_get_msb(num) * 2;
523524

525+
_init_sort_cmp(PyArray_DESCR(arr), &cmp);
526+
524527
/* Items that have zero size don't make sense to sort */
525528
if (elsize == 0) {
526529
return 0;
@@ -606,6 +609,9 @@ npy_quicksort(void *start, npy_intp num, void *varr)
606609
}
607610

608611
free(vp);
612+
613+
_end_sort_cmp(PyArray_DESCR(arr));
614+
609615
return 0;
610616
}
611617

@@ -615,7 +621,7 @@ npy_aquicksort(void *vv, npy_intp *tosort, npy_intp num, void *varr)
615621
char *v = (char *)vv;
616622
PyArrayObject *arr = (PyArrayObject *)varr;
617623
npy_intp elsize = PyArray_ITEMSIZE(arr);
618-
PyArray_CompareFunc *cmp = PyDataType_GetArrFuncs(PyArray_DESCR(arr))->compare;
624+
PyArray_CompareFunc *cmp;
619625
char *vp;
620626
npy_intp *pl = tosort;
621627
npy_intp *pr = tosort + num - 1;
@@ -626,6 +632,8 @@ npy_aquicksort(void *vv, npy_intp *tosort, npy_intp num, void *varr)
626632
int *psdepth = depth;
627633
int cdepth = npy_get_msb(num) * 2;
628634

635+
_init_sort_cmp(PyArray_DESCR(arr), &cmp);
636+
629637
/* Items that have zero size don't make sense to sort */
630638
if (elsize == 0) {
631639
return 0;
@@ -700,6 +708,8 @@ npy_aquicksort(void *vv, npy_intp *tosort, npy_intp num, void *varr)
700708
cdepth = *(--psdepth);
701709
}
702710

711+
_end_sort_cmp(PyArray_DESCR(arr));
712+
703713
return 0;
704714
}
705715

0 commit comments

Comments
 (0)