Skip to content

Commit 717c7ac

Browse files
committed
Make nan types sort to the end.
Add test for the new nan sort order.
1 parent 3ad508a commit 717c7ac

File tree

2 files changed

+44
-3
lines changed

2 files changed

+44
-3
lines changed

numpy/core/src/_sortmodule.c.src

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
*****************************************************************************
4646
*/
4747

48-
4948
/**begin repeat
5049
*
5150
* #TYPE = BOOL, BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG,
@@ -60,6 +59,7 @@ NPY_INLINE static int
6059
}
6160
/**end repeat**/
6261

62+
6363
/**begin repeat
6464
*
6565
* #TYPE = FLOAT, DOUBLE, LONGDOUBLE#
@@ -68,10 +68,17 @@ NPY_INLINE static int
6868
NPY_INLINE static int
6969
@TYPE@_LT(@type@ a, @type@ b)
7070
{
71-
return a < b;
71+
return a < b || (b != b && a == a);
7272
}
7373
/**end repeat**/
7474

75+
76+
/*
77+
* For inline functions SUN recommends not using a return in the then part
78+
* of an if statement. It's a SUN compiler thing, so assign the return value
79+
* to a variable instead.
80+
*/
81+
7582
/**begin repeat
7683
*
7784
* #TYPE = CFLOAT, CDOUBLE, CLONGDOUBLE#
@@ -80,10 +87,26 @@ NPY_INLINE static int
8087
NPY_INLINE static int
8188
@TYPE@_LT(@type@ a, @type@ b)
8289
{
83-
return a.real < b.real || (a.real == b.real && a.imag < b.imag);
90+
int ret;
91+
92+
if (a.real < b.real) {
93+
ret = a.imag == a.imag || b.imag != b.imag;
94+
}
95+
else if (a.real > b.real) {
96+
ret = b.imag != b.imag && a.imag == a.imag;
97+
}
98+
else if (a.real == b.real || (a.real != a.real && b.real != b.real)) {
99+
ret = a.imag < b.imag || (b.imag != b.imag && a.imag == a.imag);
100+
}
101+
else {
102+
ret = b.real != b.real;
103+
}
104+
105+
return ret;
84106
}
85107
/**end repeat**/
86108

109+
87110
/* The PyObject functions are stubs for later use */
88111
NPY_INLINE static int
89112
PyObject_LT(PyObject *pa, PyObject *pb)

numpy/core/tests/test_multiarray.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,24 @@ def test_transpose(self):
280280
self.failUnlessRaises(ValueError, lambda: a.transpose(0,1,2))
281281

282282
def test_sort(self):
283+
# test ordering for floats and complex containing nans. It is only
284+
# necessary to check the lessthan comparison, so sorts that
285+
# only follow the insertion sort path are sufficient. We only
286+
# test doubles and complex doubles as the logic is the same.
287+
288+
# check doubles
289+
msg = "Test real sort order with nans"
290+
a = np.array([np.nan, 1, 0])
291+
b = sort(a)
292+
assert_equal(b, a[::-1], msg)
293+
# check complex
294+
msg = "Test complex sort order with nans"
295+
a = np.zeros(9, dtype=np.complex128)
296+
a.real += [np.nan, np.nan,np. nan, 1, 0, 1, 1, 0, 0]
297+
a.imag += [np.nan, 1, 0, np.nan, np.nan, 1, 0, 1, 0]
298+
b = sort(a)
299+
assert_equal(b, a[::-1], msg)
300+
283301
# all c scalar sorts use the same code with different types
284302
# so it suffices to run a quick check with one type. The number
285303
# of sorted items must be greater than ~50 to check the actual

0 commit comments

Comments
 (0)