Skip to content

Commit 9ead596

Browse files
authored
ENH: Improve performance for np.result_type (#28710)
Improve performance of np.result_type * Use NPY_ALLOC_WORKSPACE in array_result_type * Avoid increfs/decrefs in PyArray_ResultType * Fast path for identical arguments to np.result_type
1 parent f249607 commit 9ead596

File tree

2 files changed

+12
-13
lines changed

2 files changed

+12
-13
lines changed

numpy/_core/src/multiarray/convert_datatype.c

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1597,13 +1597,12 @@ PyArray_ResultType(
15971597
return NULL;
15981598
}
15991599

1600-
PyArray_DTypeMeta **all_DTypes = (PyArray_DTypeMeta **)workspace;
1600+
PyArray_DTypeMeta **all_DTypes = (PyArray_DTypeMeta **)workspace; // borrowed references
16011601
PyArray_Descr **all_descriptors = (PyArray_Descr **)(&all_DTypes[narrs+ndtypes]);
16021602

16031603
/* Copy all dtypes into a single array defining non-value-based behaviour */
16041604
for (npy_intp i=0; i < ndtypes; i++) {
16051605
all_DTypes[i] = NPY_DTYPE(descrs[i]);
1606-
Py_INCREF(all_DTypes[i]);
16071606
all_descriptors[i] = descrs[i];
16081607
}
16091608

@@ -1628,14 +1627,10 @@ PyArray_ResultType(
16281627
all_descriptors[i_all] = PyArray_DTYPE(arrs[i]);
16291628
all_DTypes[i_all] = NPY_DTYPE(all_descriptors[i_all]);
16301629
}
1631-
Py_INCREF(all_DTypes[i_all]);
16321630
}
16331631

16341632
PyArray_DTypeMeta *common_dtype = PyArray_PromoteDTypeSequence(
16351633
narrs+ndtypes, all_DTypes);
1636-
for (npy_intp i=0; i < narrs+ndtypes; i++) {
1637-
Py_DECREF(all_DTypes[i]);
1638-
}
16391634
if (common_dtype == NULL) {
16401635
goto error;
16411636
}

numpy/_core/src/multiarray/multiarraymodule.c

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3583,24 +3583,28 @@ static PyObject *
35833583
array_result_type(PyObject *NPY_UNUSED(dummy), PyObject *const *args, Py_ssize_t len)
35843584
{
35853585
npy_intp i, narr = 0, ndtypes = 0;
3586-
PyArrayObject **arr = NULL;
3587-
PyArray_Descr **dtypes = NULL;
35883586
PyObject *ret = NULL;
35893587

35903588
if (len == 0) {
35913589
PyErr_SetString(PyExc_ValueError,
35923590
"at least one array or dtype is required");
3593-
goto finish;
3591+
return NULL;
35943592
}
35953593

3596-
arr = PyArray_malloc(2 * len * sizeof(void *));
3594+
NPY_ALLOC_WORKSPACE(arr, PyArrayObject *, 2 * 3, 2 * len);
35973595
if (arr == NULL) {
3598-
return PyErr_NoMemory();
3596+
return NULL;
35993597
}
3600-
dtypes = (PyArray_Descr**)&arr[len];
3598+
PyArray_Descr **dtypes = (PyArray_Descr**)&arr[len];
3599+
3600+
PyObject *previous_obj = NULL;
36013601

36023602
for (i = 0; i < len; ++i) {
36033603
PyObject *obj = args[i];
3604+
if (obj == previous_obj) {
3605+
continue;
3606+
}
3607+
36043608
if (PyArray_Check(obj)) {
36053609
Py_INCREF(obj);
36063610
arr[narr] = (PyArrayObject *)obj;
@@ -3636,7 +3640,7 @@ array_result_type(PyObject *NPY_UNUSED(dummy), PyObject *const *args, Py_ssize_t
36363640
for (i = 0; i < ndtypes; ++i) {
36373641
Py_DECREF(dtypes[i]);
36383642
}
3639-
PyArray_free(arr);
3643+
npy_free_workspace(arr);
36403644
return ret;
36413645
}
36423646

0 commit comments

Comments
 (0)