1+ /*
2+ * Licensed to the Apache Software Foundation (ASF) under one or more
3+ * contributor license agreements. See the NOTICE file distributed with
4+ * this work for additional information regarding copyright ownership.
5+ * The ASF licenses this file to You under the Apache License, Version 2.0
6+ * (the "License"); you may not use this file except in compliance with
7+ * the License. You may obtain a copy of the License at
8+ *
9+ * http://www.apache.org/licenses/LICENSE-2.0
10+ *
11+ * Unless required by applicable law or agreed to in writing, software
12+ * distributed under the License is distributed on an "AS IS" BASIS,
13+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+ * See the License for the specific language governing permissions and
15+ * limitations under the License.
16+ */
17+
18+ package org .apache .spark .util .collection .unsafe .sort ;
19+
20+ import java .io .File ;
21+ import java .io .InputStream ;
22+ import java .io .OutputStream ;
23+ import java .util .UUID ;
24+
25+ import scala .Tuple2 ;
26+ import scala .Tuple2$ ;
27+ import scala .runtime .AbstractFunction1 ;
28+
29+ import org .junit .Before ;
30+ import org .junit .Test ;
31+ import org .mockito .Mock ;
32+ import org .mockito .MockitoAnnotations ;
33+ import org .mockito .invocation .InvocationOnMock ;
34+ import org .mockito .stubbing .Answer ;
35+ import static org .junit .Assert .*;
36+ import static org .mockito .AdditionalAnswers .returnsFirstArg ;
37+ import static org .mockito .AdditionalAnswers .returnsSecondArg ;
38+ import static org .mockito .Answers .RETURNS_SMART_NULLS ;
39+ import static org .mockito .Mockito .*;
40+
41+ import org .apache .spark .HashPartitioner ;
42+ import org .apache .spark .SparkConf ;
43+ import org .apache .spark .TaskContext ;
44+ import org .apache .spark .executor .ShuffleWriteMetrics ;
45+ import org .apache .spark .executor .TaskMetrics ;
46+ import org .apache .spark .serializer .SerializerInstance ;
47+ import org .apache .spark .shuffle .ShuffleMemoryManager ;
48+ import org .apache .spark .storage .*;
49+ import org .apache .spark .unsafe .PlatformDependent ;
50+ import org .apache .spark .unsafe .memory .ExecutorMemoryManager ;
51+ import org .apache .spark .unsafe .memory .MemoryAllocator ;
52+ import org .apache .spark .unsafe .memory .TaskMemoryManager ;
53+ import org .apache .spark .util .Utils ;
54+
55+ public class UnsafeExternalSorterSuite {
56+
57+ final TaskMemoryManager memoryManager =
58+ new TaskMemoryManager (new ExecutorMemoryManager (MemoryAllocator .HEAP ));
59+ // Compute key prefixes based on the records' partition ids
60+ final HashPartitioner hashPartitioner = new HashPartitioner (4 );
61+ // Use integer comparison for comparing prefixes (which are partition ids, in this case)
62+ final PrefixComparator prefixComparator = new PrefixComparator () {
63+ @ Override
64+ public int compare (long prefix1 , long prefix2 ) {
65+ return (int ) prefix1 - (int ) prefix2 ;
66+ }
67+ };
68+ // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
69+ // use a dummy comparator
70+ final RecordComparator recordComparator = new RecordComparator () {
71+ @ Override
72+ public int compare (
73+ Object leftBaseObject ,
74+ long leftBaseOffset ,
75+ Object rightBaseObject ,
76+ long rightBaseOffset ) {
77+ return 0 ;
78+ }
79+ };
80+
81+ @ Mock (answer = RETURNS_SMART_NULLS ) ShuffleMemoryManager shuffleMemoryManager ;
82+ @ Mock (answer = RETURNS_SMART_NULLS ) BlockManager blockManager ;
83+ @ Mock (answer = RETURNS_SMART_NULLS ) DiskBlockManager diskBlockManager ;
84+ @ Mock (answer = RETURNS_SMART_NULLS ) TaskContext taskContext ;
85+
86+ File tempDir ;
87+
88+ private static final class CompressStream extends AbstractFunction1 <OutputStream , OutputStream > {
89+ @ Override
90+ public OutputStream apply (OutputStream stream ) {
91+ return stream ;
92+ }
93+ }
94+
95+ @ Before
96+ public void setUp () {
97+ MockitoAnnotations .initMocks (this );
98+ tempDir = new File (Utils .createTempDir$default$1 ());
99+ taskContext = mock (TaskContext .class );
100+ when (taskContext .taskMetrics ()).thenReturn (new TaskMetrics ());
101+ when (shuffleMemoryManager .tryToAcquire (anyLong ())).then (returnsFirstArg ());
102+ when (blockManager .diskBlockManager ()).thenReturn (diskBlockManager );
103+ when (diskBlockManager .createTempLocalBlock ()).thenAnswer (new Answer <Tuple2 <TempLocalBlockId , File >>() {
104+ @ Override
105+ public Tuple2 <TempLocalBlockId , File > answer (InvocationOnMock invocationOnMock ) throws Throwable {
106+ TempLocalBlockId blockId = new TempLocalBlockId (UUID .randomUUID ());
107+ File file = File .createTempFile ("spillFile" , ".spill" , tempDir );
108+ return Tuple2$ .MODULE$ .apply (blockId , file );
109+ }
110+ });
111+ when (blockManager .getDiskWriter (
112+ any (BlockId .class ),
113+ any (File .class ),
114+ any (SerializerInstance .class ),
115+ anyInt (),
116+ any (ShuffleWriteMetrics .class ))).thenAnswer (new Answer <DiskBlockObjectWriter >() {
117+ @ Override
118+ public DiskBlockObjectWriter answer (InvocationOnMock invocationOnMock ) throws Throwable {
119+ Object [] args = invocationOnMock .getArguments ();
120+
121+ return new DiskBlockObjectWriter (
122+ (BlockId ) args [0 ],
123+ (File ) args [1 ],
124+ (SerializerInstance ) args [2 ],
125+ (Integer ) args [3 ],
126+ new CompressStream (),
127+ false ,
128+ (ShuffleWriteMetrics ) args [4 ]
129+ );
130+ }
131+ });
132+ when (blockManager .wrapForCompression (any (BlockId .class ), any (InputStream .class )))
133+ .then (returnsSecondArg ());
134+ }
135+
136+ private static void insertNumber (UnsafeExternalSorter sorter , int value ) throws Exception {
137+ final int [] arr = new int [] { value };
138+ sorter .insertRecord (arr , PlatformDependent .INT_ARRAY_OFFSET , 4 , value );
139+ }
140+
141+ /**
142+ * Tests the type of sorting that's used in the non-combiner path of sort-based shuffle.
143+ */
144+ @ Test
145+ public void testSortingOnlyByPartitionId () throws Exception {
146+
147+ final UnsafeExternalSorter sorter = new UnsafeExternalSorter (
148+ memoryManager ,
149+ shuffleMemoryManager ,
150+ blockManager ,
151+ taskContext ,
152+ recordComparator ,
153+ prefixComparator ,
154+ 1024 ,
155+ new SparkConf ());
156+
157+ insertNumber (sorter , 5 );
158+ insertNumber (sorter , 1 );
159+ insertNumber (sorter , 3 );
160+ sorter .spill ();
161+ insertNumber (sorter , 4 );
162+ insertNumber (sorter , 2 );
163+
164+ UnsafeSorterIterator iter = sorter .getSortedIterator ();
165+
166+ iter .loadNext ();
167+ assertEquals (1 , iter .getKeyPrefix ());
168+ iter .loadNext ();
169+ assertEquals (2 , iter .getKeyPrefix ());
170+ iter .loadNext ();
171+ assertEquals (3 , iter .getKeyPrefix ());
172+ iter .loadNext ();
173+ assertEquals (4 , iter .getKeyPrefix ());
174+ iter .loadNext ();
175+ assertEquals (5 , iter .getKeyPrefix ());
176+ assertFalse (iter .hasNext ());
177+ // TODO: check that the values are also read back properly.
178+
179+ // TODO: test for cleanup:
180+ // assert(tempDir.isEmpty)
181+ }
182+
183+ }