Skip to content

Commit 479a106

Browse files
committed
SPARK-5984: Fix TimSort bug that causes ArrayIndexOutOfBoundsException
1 parent 4ad5153 commit 479a106

File tree

3 files changed

+121
-5
lines changed

3 files changed

+121
-5
lines changed

core/src/main/java/org/apache/spark/util/collection/TimSort.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -425,15 +425,14 @@ private void pushRun(int runBase, int runLen) {
425425
private void mergeCollapse() {
426426
while (stackSize > 1) {
427427
int n = stackSize - 2;
428-
if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {
428+
if ( (n >= 1 && runLen[n-1] <= runLen[n] + runLen[n+1])
429+
|| (n >= 2 && runLen[n-2] <= runLen[n] + runLen[n-1])) {
429430
if (runLen[n - 1] < runLen[n + 1])
430431
n--;
431-
mergeAt(n);
432-
} else if (runLen[n] <= runLen[n + 1]) {
433-
mergeAt(n);
434-
} else {
432+
} else if (runLen[n] > runLen[n + 1]) {
435433
break; // Invariant is established
436434
}
435+
mergeAt(n);
437436
}
438437
}
439438

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
package org.apache.spark.util.collection;
2+
3+
import java.util.*;
4+
5+
/**
6+
* This codes generates a int array which fails the standard TimSort, Borrowed from
7+
* the reporter of this bug.
8+
*
9+
* http://www.envisage-project.eu/timsort-specification-and-verification/
10+
*/
11+
public class TestTimSort {
12+
13+
private static final int MIN_MERGE = 32;
14+
15+
/**
16+
* Returns an array of integers that demonstrate the bug in TimSort
17+
*/
18+
public static int[] getTimSortBugTestSet(int length) {
19+
int minRun = minRunLength(length);
20+
List<Long> runs = runsJDKWorstCase(minRun, length);
21+
return createArray(runs, length);
22+
}
23+
24+
private static int minRunLength(int n) {
25+
int r = 0; // Becomes 1 if any 1 bits are shifted off
26+
while (n >= MIN_MERGE) {
27+
r |= (n & 1);
28+
n >>= 1;
29+
}
30+
return n + r;
31+
}
32+
33+
private static int[] createArray(List<Long> runs, int length) {
34+
int[] a = new int[length];
35+
Arrays.fill(a, 0);
36+
int endRun = -1;
37+
for (long len : runs)
38+
a[endRun += len] = 1;
39+
a[length - 1] = 0;
40+
return a;
41+
}
42+
43+
/**
44+
* Fills <code>runs</code> with a sequence of run lengths of the form<br>
45+
* Y_n x_{n,1} x_{n,2} ... x_{n,l_n} <br>
46+
* Y_{n-1} x_{n-1,1} x_{n-1,2} ... x_{n-1,l_{n-1}} <br>
47+
* ... <br>
48+
* Y_1 x_{1,1} x_{1,2} ... x_{1,l_1}<br>
49+
* The Y_i's are chosen to satisfy the invariant throughout execution,
50+
* but the x_{i,j}'s are merged (by <code>TimSort.mergeCollapse</code>)
51+
* into an X_i that violates the invariant.
52+
*
53+
* @param length The sum of all run lengths that will be added to <code>runs</code>.
54+
*/
55+
private static List<Long> runsJDKWorstCase(int minRun, int length) {
56+
List<Long> runs = new ArrayList<Long>();
57+
58+
long runningTotal = 0, Y = minRun + 4, X = minRun;
59+
60+
while (runningTotal + Y + X <= length) {
61+
runningTotal += X + Y;
62+
generateJDKWrongElem(runs, minRun, X);
63+
runs.add(0, Y);
64+
// X_{i+1} = Y_i + x_{i,1} + 1, since runs.get(1) = x_{i,1}
65+
X = Y + runs.get(1) + 1;
66+
// Y_{i+1} = X_{i+1} + Y_i + 1
67+
Y += X + 1;
68+
}
69+
70+
if (runningTotal + X <= length) {
71+
runningTotal += X;
72+
generateJDKWrongElem(runs, minRun, X);
73+
}
74+
75+
runs.add(length - runningTotal);
76+
return runs;
77+
}
78+
79+
/**
80+
* Adds a sequence x_1, ..., x_n of run lengths to <code>runs</code> such that:<br>
81+
* 1. X = x_1 + ... + x_n <br>
82+
* 2. x_j >= minRun for all j <br>
83+
* 3. x_1 + ... + x_{j-2} < x_j < x_1 + ... + x_{j-1} for all j <br>
84+
* These conditions guarantee that TimSort merges all x_j's one by one
85+
* (resulting in X) using only merges on the second-to-last element.
86+
*
87+
* @param X The sum of the sequence that should be added to runs.
88+
*/
89+
private static void generateJDKWrongElem(List<Long> runs, int minRun, long X) {
90+
for (long newTotal; X >= 2 * minRun + 1; X = newTotal) {
91+
//Default strategy
92+
newTotal = X / 2 + 1;
93+
//Specialized strategies
94+
if (3 * minRun + 3 <= X && X <= 4 * minRun + 1) {
95+
// add x_1=MIN+1, x_2=MIN, x_3=X-newTotal to runs
96+
newTotal = 2 * minRun + 1;
97+
} else if (5 * minRun + 5 <= X && X <= 6 * minRun + 5) {
98+
// add x_1=MIN+1, x_2=MIN, x_3=MIN+2, x_4=X-newTotal to runs
99+
newTotal = 3 * minRun + 3;
100+
} else if (8 * minRun + 9 <= X && X <= 10 * minRun + 9) {
101+
// add x_1=MIN+1, x_2=MIN, x_3=MIN+2, x_4=2MIN+2, x_5=X-newTotal to runs
102+
newTotal = 5 * minRun + 5;
103+
} else if (13 * minRun + 15 <= X && X <= 16 * minRun + 17) {
104+
// add x_1=MIN+1, x_2=MIN, x_3=MIN+2, x_4=2MIN+2, x_5=3MIN+4, x_6=X-newTotal to runs
105+
newTotal = 8 * minRun + 9;
106+
}
107+
runs.add(0, X - newTotal);
108+
}
109+
runs.add(0, X);
110+
}
111+
}

core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ class SorterSuite extends FunSuite {
6565
}
6666
}
6767

68+
test("bug of TimSort") {
69+
val data = TestTimSort.getTimSortBugTestSet(67108864)
70+
new Sorter(new IntArraySortDataFormat).sort(data, 0, data.length, Ordering.Int)
71+
(0 to data.length - 2).foreach(i => assert(data(i) <= data(i+1)))
72+
}
73+
6874
/** Runs an experiment several times. */
6975
def runExperiment(name: String, skip: Boolean = false)(f: => Unit, prepare: () => Unit): Unit = {
7076
if (skip) {

0 commit comments

Comments
 (0)