Skip to content

Commit 2bd8c9a

Browse files
committed
Import my original tests and get them to pass.
1 parent d5d3106 commit 2bd8c9a

File tree

4 files changed

+337
-6
lines changed

4 files changed

+337
-6
lines changed

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,10 @@ public UnsafeSorterSpillWriter(
5959
}
6060

6161
public void write(
62-
Object baseObject,
63-
long baseOffset,
64-
int recordLength,
65-
long keyPrefix) throws IOException {
62+
Object baseObject,
63+
long baseOffset,
64+
int recordLength,
65+
long keyPrefix) throws IOException {
6666
dos.writeInt(recordLength);
6767
dos.writeLong(keyPrefix);
6868
PlatformDependent.copyMemory(
@@ -72,7 +72,6 @@ public void write(
7272
PlatformDependent.BYTE_ARRAY_OFFSET,
7373
recordLength);
7474
writer.write(arr, 0, recordLength);
75-
// TODO: add a test that detects whether we leave this call out:
7675
writer.recordWritten();
7776
}
7877

core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,14 @@ private[spark] class DiskBlockObjectWriter(
211211
recordWritten()
212212
}
213213

214-
override def write(b: Int): Unit = throw new UnsupportedOperationException()
214+
override def write(b: Int): Unit = {
215+
// TODO: re-enable the `throw new UnsupportedOperationException()` here
216+
if (!initialized) {
217+
open()
218+
}
219+
220+
bs.write(b)
221+
}
215222

216223
override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = {
217224
if (!initialized) {
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.util.collection.unsafe.sort;
19+
20+
import java.io.File;
21+
import java.io.InputStream;
22+
import java.io.OutputStream;
23+
import java.util.UUID;
24+
25+
import scala.Tuple2;
26+
import scala.Tuple2$;
27+
import scala.runtime.AbstractFunction1;
28+
29+
import org.junit.Before;
30+
import org.junit.Test;
31+
import org.mockito.Mock;
32+
import org.mockito.MockitoAnnotations;
33+
import org.mockito.invocation.InvocationOnMock;
34+
import org.mockito.stubbing.Answer;
35+
import static org.junit.Assert.*;
36+
import static org.mockito.AdditionalAnswers.returnsFirstArg;
37+
import static org.mockito.AdditionalAnswers.returnsSecondArg;
38+
import static org.mockito.Answers.RETURNS_SMART_NULLS;
39+
import static org.mockito.Mockito.*;
40+
41+
import org.apache.spark.HashPartitioner;
42+
import org.apache.spark.SparkConf;
43+
import org.apache.spark.TaskContext;
44+
import org.apache.spark.executor.ShuffleWriteMetrics;
45+
import org.apache.spark.executor.TaskMetrics;
46+
import org.apache.spark.serializer.SerializerInstance;
47+
import org.apache.spark.shuffle.ShuffleMemoryManager;
48+
import org.apache.spark.storage.*;
49+
import org.apache.spark.unsafe.PlatformDependent;
50+
import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
51+
import org.apache.spark.unsafe.memory.MemoryAllocator;
52+
import org.apache.spark.unsafe.memory.TaskMemoryManager;
53+
import org.apache.spark.util.Utils;
54+
55+
public class UnsafeExternalSorterSuite {
56+
57+
final TaskMemoryManager memoryManager =
58+
new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
59+
// Compute key prefixes based on the records' partition ids
60+
final HashPartitioner hashPartitioner = new HashPartitioner(4);
61+
// Use integer comparison for comparing prefixes (which are partition ids, in this case)
62+
final PrefixComparator prefixComparator = new PrefixComparator() {
63+
@Override
64+
public int compare(long prefix1, long prefix2) {
65+
return (int) prefix1 - (int) prefix2;
66+
}
67+
};
68+
// Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
69+
// use a dummy comparator
70+
final RecordComparator recordComparator = new RecordComparator() {
71+
@Override
72+
public int compare(
73+
Object leftBaseObject,
74+
long leftBaseOffset,
75+
Object rightBaseObject,
76+
long rightBaseOffset) {
77+
return 0;
78+
}
79+
};
80+
81+
@Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
82+
@Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
83+
@Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
84+
@Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
85+
86+
File tempDir;
87+
88+
private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
89+
@Override
90+
public OutputStream apply(OutputStream stream) {
91+
return stream;
92+
}
93+
}
94+
95+
@Before
96+
public void setUp() {
97+
MockitoAnnotations.initMocks(this);
98+
tempDir = new File(Utils.createTempDir$default$1());
99+
taskContext = mock(TaskContext.class);
100+
when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());
101+
when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
102+
when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
103+
when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer<Tuple2<TempLocalBlockId, File>>() {
104+
@Override
105+
public Tuple2<TempLocalBlockId, File> answer(InvocationOnMock invocationOnMock) throws Throwable {
106+
TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
107+
File file = File.createTempFile("spillFile", ".spill", tempDir);
108+
return Tuple2$.MODULE$.apply(blockId, file);
109+
}
110+
});
111+
when(blockManager.getDiskWriter(
112+
any(BlockId.class),
113+
any(File.class),
114+
any(SerializerInstance.class),
115+
anyInt(),
116+
any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
117+
@Override
118+
public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
119+
Object[] args = invocationOnMock.getArguments();
120+
121+
return new DiskBlockObjectWriter(
122+
(BlockId) args[0],
123+
(File) args[1],
124+
(SerializerInstance) args[2],
125+
(Integer) args[3],
126+
new CompressStream(),
127+
false,
128+
(ShuffleWriteMetrics) args[4]
129+
);
130+
}
131+
});
132+
when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class)))
133+
.then(returnsSecondArg());
134+
}
135+
136+
private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception {
137+
final int[] arr = new int[] { value };
138+
sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value);
139+
}
140+
141+
/**
142+
* Tests the type of sorting that's used in the non-combiner path of sort-based shuffle.
143+
*/
144+
@Test
145+
public void testSortingOnlyByPartitionId() throws Exception {
146+
147+
final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
148+
memoryManager,
149+
shuffleMemoryManager,
150+
blockManager,
151+
taskContext,
152+
recordComparator,
153+
prefixComparator,
154+
1024,
155+
new SparkConf());
156+
157+
insertNumber(sorter, 5);
158+
insertNumber(sorter, 1);
159+
insertNumber(sorter, 3);
160+
sorter.spill();
161+
insertNumber(sorter, 4);
162+
insertNumber(sorter, 2);
163+
164+
UnsafeSorterIterator iter = sorter.getSortedIterator();
165+
166+
iter.loadNext();
167+
assertEquals(1, iter.getKeyPrefix());
168+
iter.loadNext();
169+
assertEquals(2, iter.getKeyPrefix());
170+
iter.loadNext();
171+
assertEquals(3, iter.getKeyPrefix());
172+
iter.loadNext();
173+
assertEquals(4, iter.getKeyPrefix());
174+
iter.loadNext();
175+
assertEquals(5, iter.getKeyPrefix());
176+
assertFalse(iter.hasNext());
177+
// TODO: check that the values are also read back properly.
178+
179+
// TODO: test for cleanup:
180+
// assert(tempDir.isEmpty)
181+
}
182+
183+
}
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.util.collection.unsafe.sort;
19+
20+
import java.util.Arrays;
21+
22+
import org.junit.Test;
23+
import static org.hamcrest.MatcherAssert.assertThat;
24+
import static org.hamcrest.Matchers.*;
25+
import static org.junit.Assert.*;
26+
import static org.mockito.Mockito.mock;
27+
28+
import org.apache.spark.HashPartitioner;
29+
import org.apache.spark.unsafe.PlatformDependent;
30+
import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
31+
import org.apache.spark.unsafe.memory.MemoryAllocator;
32+
import org.apache.spark.unsafe.memory.MemoryBlock;
33+
import org.apache.spark.unsafe.memory.TaskMemoryManager;
34+
35+
public class UnsafeInMemorySorterSuite {
36+
37+
private static String getStringFromDataPage(Object baseObject, long baseOffset) {
38+
final int strLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset);
39+
final byte[] strBytes = new byte[strLength];
40+
PlatformDependent.copyMemory(
41+
baseObject,
42+
baseOffset + 8,
43+
strBytes,
44+
PlatformDependent.BYTE_ARRAY_OFFSET, strLength);
45+
return new String(strBytes);
46+
}
47+
48+
@Test
49+
public void testSortingEmptyInput() {
50+
final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
51+
new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)),
52+
mock(RecordComparator.class),
53+
mock(PrefixComparator.class),
54+
100);
55+
final UnsafeSorterIterator iter = sorter.getSortedIterator();
56+
assert(!iter.hasNext());
57+
}
58+
59+
/**
60+
* Tests the type of sorting that's used in the non-combiner path of sort-based shuffle.
61+
*/
62+
@Test
63+
public void testSortingOnlyByPartitionId() throws Exception {
64+
final String[] dataToSort = new String[] {
65+
"Boba",
66+
"Pearls",
67+
"Tapioca",
68+
"Taho",
69+
"Condensed Milk",
70+
"Jasmine",
71+
"Milk Tea",
72+
"Lychee",
73+
"Mango"
74+
};
75+
final TaskMemoryManager memoryManager =
76+
new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
77+
final MemoryBlock dataPage = memoryManager.allocatePage(2048);
78+
final Object baseObject = dataPage.getBaseObject();
79+
// Write the records into the data page:
80+
long position = dataPage.getBaseOffset();
81+
for (String str : dataToSort) {
82+
final byte[] strBytes = str.getBytes("utf-8");
83+
PlatformDependent.UNSAFE.putLong(baseObject, position, strBytes.length);
84+
position += 8;
85+
PlatformDependent.copyMemory(
86+
strBytes,
87+
PlatformDependent.BYTE_ARRAY_OFFSET,
88+
baseObject,
89+
position,
90+
strBytes.length);
91+
position += strBytes.length;
92+
}
93+
// Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
94+
// use a dummy comparator
95+
final RecordComparator recordComparator = new RecordComparator() {
96+
@Override
97+
public int compare(
98+
Object leftBaseObject,
99+
long leftBaseOffset,
100+
Object rightBaseObject,
101+
long rightBaseOffset) {
102+
return 0;
103+
}
104+
};
105+
// Compute key prefixes based on the records' partition ids
106+
final HashPartitioner hashPartitioner = new HashPartitioner(4);
107+
// Use integer comparison for comparing prefixes (which are partition ids, in this case)
108+
final PrefixComparator prefixComparator = new PrefixComparator() {
109+
@Override
110+
public int compare(long prefix1, long prefix2) {
111+
return (int) prefix1 - (int) prefix2;
112+
}
113+
};
114+
UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(memoryManager, recordComparator,
115+
prefixComparator, dataToSort.length);
116+
// Given a page of records, insert those records into the sorter one-by-one:
117+
position = dataPage.getBaseOffset();
118+
for (int i = 0; i < dataToSort.length; i++) {
119+
// position now points to the start of a record (which holds its length).
120+
final long recordLength = PlatformDependent.UNSAFE.getLong(baseObject, position);
121+
final long address = memoryManager.encodePageNumberAndOffset(dataPage, position);
122+
final String str = getStringFromDataPage(baseObject, position);
123+
final int partitionId = hashPartitioner.getPartition(str);
124+
sorter.insertRecord(address, partitionId);
125+
position += 8 + recordLength;
126+
}
127+
final UnsafeSorterIterator iter = sorter.getSortedIterator();
128+
int iterLength = 0;
129+
long prevPrefix = -1;
130+
Arrays.sort(dataToSort);
131+
while (iter.hasNext()) {
132+
iter.loadNext();
133+
final String str = getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset());
134+
final long keyPrefix = iter.getKeyPrefix();
135+
assertTrue(Arrays.asList(dataToSort).contains(str));
136+
assertThat(keyPrefix, greaterThanOrEqualTo(prevPrefix));
137+
prevPrefix = keyPrefix;
138+
iterLength++;
139+
}
140+
assertEquals(dataToSort.length, iterLength);
141+
}
142+
}

0 commit comments

Comments
 (0)