Skip to content

Commit b3eaccd

Browse files
committed
Extract aggregation map into its own class.
This makes the code much easier to understand and will allow me to implement unsafe versions of both GeneratedAggregate and the regular Aggregate operator.
1 parent d2bb986 commit b3eaccd

File tree

3 files changed

+229
-100
lines changed

3 files changed

+229
-100
lines changed
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions;
19+
20+
import java.util.Arrays;
21+
import java.util.Iterator;
22+
23+
import org.apache.spark.sql.Row;
24+
import org.apache.spark.sql.types.StructType;
25+
import org.apache.spark.unsafe.PlatformDependent;
26+
import org.apache.spark.unsafe.map.BytesToBytesMap;
27+
import org.apache.spark.unsafe.memory.MemoryAllocator;
28+
import org.apache.spark.unsafe.memory.MemoryLocation;
29+
30+
/**
31+
* Unsafe-based HashMap for performing aggregations in which the aggregated values are
32+
* fixed-width. This is NOT threadsafe.
33+
*/
34+
public final class UnsafeFixedWidthAggregationMap {
35+
36+
/**
37+
* An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the
38+
* map, we copy this buffer and use it as the value.
39+
*/
40+
private final long[] emptyAggregationBuffer;
41+
42+
private final StructType aggregationBufferSchema;
43+
44+
private final StructType groupingKeySchema;
45+
46+
/**
47+
* Encodes grouping keys as UnsafeRows.
48+
*/
49+
private final UnsafeRowConverter groupingKeyToUnsafeRowConverter;
50+
51+
/**
52+
* A hashmap which maps from opaque bytearray keys to bytearray values.
53+
*/
54+
private final BytesToBytesMap map;
55+
56+
/**
57+
* Re-used pointer to the current aggregation buffer
58+
*/
59+
private final UnsafeRow currentAggregationBuffer = new UnsafeRow();
60+
61+
/**
62+
* Scratch space that is used when encoding grouping keys into UnsafeRow format.
63+
*
64+
* By default, this is a 1MB array, but it will grow as necessary in case larger keys are
65+
* encountered.
66+
*/
67+
private long[] groupingKeyConversionScratchSpace = new long[1024 / 8];
68+
69+
/**
70+
* Create a new UnsafeFixedWidthAggregationMap.
71+
*
72+
* @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
73+
* @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
74+
* @param groupingKeySchema the schema of the grouping key, used for row conversion.
75+
* @param allocator the memory allocator used to allocate our Unsafe memory structures.
76+
* @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
77+
*/
78+
public UnsafeFixedWidthAggregationMap(
79+
Row emptyAggregationBuffer,
80+
StructType aggregationBufferSchema,
81+
StructType groupingKeySchema,
82+
MemoryAllocator allocator,
83+
long initialCapacity) {
84+
this.emptyAggregationBuffer =
85+
convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema);
86+
this.aggregationBufferSchema = aggregationBufferSchema;
87+
this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema);
88+
this.groupingKeySchema = groupingKeySchema;
89+
this.map = new BytesToBytesMap(allocator, initialCapacity);
90+
}
91+
92+
/**
93+
* Convert a Java object row into an UnsafeRow, allocating it into a new long array.
94+
*/
95+
private static long[] convertToUnsafeRow(Row javaRow, StructType schema) {
96+
final UnsafeRowConverter converter = new UnsafeRowConverter(schema);
97+
final long[] unsafeRow = new long[converter.getSizeRequirement(javaRow)];
98+
final long writtenLength =
99+
converter.writeRow(javaRow, unsafeRow, PlatformDependent.LONG_ARRAY_OFFSET);
100+
assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!";
101+
return unsafeRow;
102+
}
103+
104+
/**
105+
* Return the aggregation buffer for the current group. For efficiency, all calls to this method
106+
* return the same object.
107+
*/
108+
public UnsafeRow getAggregationBuffer(Row groupingKey) {
109+
// Zero out the buffer that's used to hold the current row. This is necessary in order
110+
// to ensure that rows hash properly, since garbage data from the previous row could
111+
// otherwise end up as padding in this row.
112+
Arrays.fill(groupingKeyConversionScratchSpace, 0);
113+
final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey);
114+
if (groupingKeySize > groupingKeyConversionScratchSpace.length) {
115+
groupingKeyConversionScratchSpace = new long[groupingKeySize];
116+
}
117+
final long actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow(
118+
groupingKey,
119+
groupingKeyConversionScratchSpace,
120+
PlatformDependent.LONG_ARRAY_OFFSET);
121+
assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!";
122+
123+
// Probe our map using the serialized key
124+
final BytesToBytesMap.Location loc = map.lookup(
125+
groupingKeyConversionScratchSpace,
126+
PlatformDependent.LONG_ARRAY_OFFSET,
127+
groupingKeySize);
128+
if (!loc.isDefined()) {
129+
// This is the first time that we've seen this grouping key, so we'll insert a copy of the
130+
// empty aggregation buffer into the map:
131+
loc.storeKeyAndValue(
132+
groupingKeyConversionScratchSpace,
133+
PlatformDependent.LONG_ARRAY_OFFSET,
134+
groupingKeySize,
135+
emptyAggregationBuffer,
136+
PlatformDependent.LONG_ARRAY_OFFSET,
137+
emptyAggregationBuffer.length
138+
);
139+
}
140+
141+
// Reset the pointer to point to the value that we just stored or looked up:
142+
final MemoryLocation address = loc.getValueAddress();
143+
currentAggregationBuffer.set(
144+
address.getBaseObject(),
145+
address.getBaseOffset(),
146+
aggregationBufferSchema.length(),
147+
aggregationBufferSchema
148+
);
149+
return currentAggregationBuffer;
150+
}
151+
152+
public static class MapEntry {
153+
public final UnsafeRow key = new UnsafeRow();
154+
public final UnsafeRow value = new UnsafeRow();
155+
}
156+
157+
/**
158+
* Returns an iterator over the keys and values in this map.
159+
*
160+
* For efficiency, each call returns the same object.
161+
*/
162+
public Iterator<MapEntry> iterator() {
163+
return new Iterator<MapEntry>() {
164+
165+
private final MapEntry entry = new MapEntry();
166+
private final Iterator<BytesToBytesMap.Location> mapLocationIterator = map.iterator();
167+
168+
@Override
169+
public boolean hasNext() {
170+
return mapLocationIterator.hasNext();
171+
}
172+
173+
@Override
174+
public MapEntry next() {
175+
final BytesToBytesMap.Location loc = mapLocationIterator.next();
176+
final MemoryLocation keyAddress = loc.getKeyAddress();
177+
final MemoryLocation valueAddress = loc.getValueAddress();
178+
entry.key.set(
179+
keyAddress.getBaseObject(),
180+
keyAddress.getBaseOffset(),
181+
groupingKeySchema.length(),
182+
groupingKeySchema
183+
);
184+
entry.value.set(
185+
valueAddress.getBaseObject(),
186+
valueAddress.getBaseOffset(),
187+
aggregationBufferSchema.length(),
188+
aggregationBufferSchema
189+
);
190+
return entry;
191+
}
192+
193+
@Override
194+
public void remove() {
195+
throw new UnsupportedOperationException();
196+
}
197+
};
198+
}
199+
200+
/**
201+
* Free the unsafe memory associated with this map.
202+
*/
203+
public void free() {
204+
map.free();
205+
}
206+
207+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@ private case object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
154154

155155
class UnsafeRowConverter(fieldTypes: Array[DataType]) {
156156

157+
def this(schema: StructType) {
158+
this(schema.fields.map(_.dataType))
159+
}
160+
157161
private[this] val unsafeRow = new UnsafeRow()
158162

159163
private[this] val writers: Array[UnsafeColumnWriter[Any]] = {

sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeGeneratedAggregate.scala

Lines changed: 18 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,12 @@
1717

1818
package org.apache.spark.sql.execution
1919

20-
import java.util
21-
2220
import org.apache.spark.annotation.DeveloperApi
2321
import org.apache.spark.rdd.RDD
2422
import org.apache.spark.sql.catalyst.trees._
2523
import org.apache.spark.sql.catalyst.expressions._
2624
import org.apache.spark.sql.catalyst.plans.physical._
2725
import org.apache.spark.sql.types._
28-
import org.apache.spark.unsafe.PlatformDependent
29-
import org.apache.spark.unsafe.map.BytesToBytesMap
3026
import org.apache.spark.unsafe.memory.MemoryAllocator
3127

3228
// TODO: finish cleaning up documentation instead of just copying it
@@ -258,128 +254,50 @@ case class UnsafeGeneratedAggregate(
258254
val resultProjection = resultProjectionBuilder()
259255
Iterator(resultProjection(buffer))
260256
} else {
261-
// TODO: if we knew how many groups to expect, we could size this hashmap appropriately
262-
val buffers = new BytesToBytesMap(MemoryAllocator.UNSAFE, 128)
263-
264-
// Set up the mutable "pointers" that we'll re-use when pointing to key and value rows
265-
val currentBuffer: UnsafeRow = new UnsafeRow()
266-
267-
// We're going to need to allocate a lot of empty aggregation buffers, so let's do it
268-
// once and keep a copy of the serialized buffer and copy it into the hash map when we see
269-
// new keys:
270-
val emptyAggregationBuffer: Array[Long] = {
271-
val javaBuffer: MutableRow = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
272-
val fieldTypes = StructType.fromAttributes(computationSchema).map(_.dataType).toArray
273-
val converter = new UnsafeRowConverter(fieldTypes)
274-
val buffer = new Array[Long](converter.getSizeRequirement(javaBuffer))
275-
converter.writeRow(javaBuffer, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
276-
buffer
277-
}
278-
279-
val keyToUnsafeRowConverter: UnsafeRowConverter = {
280-
new UnsafeRowConverter(groupingExpressions.map(_.dataType).toArray)
281-
}
282-
283257
val aggregationBufferSchema = StructType.fromAttributes(computationSchema)
284-
val keySchema: StructType = {
258+
259+
val groupKeySchema: StructType = {
285260
val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) =>
286261
StructField(idx.toString, expr.dataType, expr.nullable)
287262
}
288263
StructType(fields)
289264
}
290265

291-
// Allocate some scratch space for holding the keys that we use to index into the hash map.
292-
// 16 MB ought to be enough for anyone (TODO)
293-
val unsafeRowBuffer: Array[Long] = new Array[Long](1024 * 16 / 8)
266+
val aggregationMap = new UnsafeFixedWidthAggregationMap(
267+
newAggregationBuffer(EmptyRow),
268+
aggregationBufferSchema,
269+
groupKeySchema,
270+
MemoryAllocator.UNSAFE,
271+
1024
272+
)
294273

295274
while (iter.hasNext) {
296-
// Zero out the buffer that's used to hold the current row. This is necessary in order
297-
// to ensure that rows hash properly, since garbage data from the previous row could
298-
// otherwise end up as padding in this row.
299-
util.Arrays.fill(unsafeRowBuffer, 0)
300-
// Grab the next row from our input iterator and compute its group projection.
301-
// In the long run, it might be nice to use Unsafe rows for this as well, but for now
302-
// we'll just rely on the existing code paths to compute the projection.
303-
val currentJavaRow = iter.next()
304-
val currentGroup: Row = groupProjection(currentJavaRow)
305-
// Convert the current group into an UnsafeRow so that we can use it as a key for our
306-
// aggregation hash map
307-
val groupProjectionSize = keyToUnsafeRowConverter.getSizeRequirement(currentGroup)
308-
if (groupProjectionSize > unsafeRowBuffer.length) {
309-
throw new IllegalStateException("Group projection does not fit into buffer")
310-
}
311-
val keyLengthInBytes: Int = keyToUnsafeRowConverter.writeRow(
312-
currentGroup, unsafeRowBuffer, PlatformDependent.LONG_ARRAY_OFFSET).toInt // TODO
313-
314-
val loc: BytesToBytesMap#Location =
315-
buffers.lookup(unsafeRowBuffer, PlatformDependent.LONG_ARRAY_OFFSET, keyLengthInBytes)
316-
if (!loc.isDefined) {
317-
// This is the first time that we've seen this key, so we'll copy the empty aggregation
318-
// buffer row that we created earlier. TODO: this doesn't work very well for aggregates
319-
// where the size of the aggregate buffer is different for different rows (even if the
320-
// size of buffers don't grow once created, as is the case for things like grabbing the
321-
// first row's value for a string-valued column (or the shortest string)).
322-
323-
loc.storeKeyAndValue(
324-
unsafeRowBuffer,
325-
PlatformDependent.LONG_ARRAY_OFFSET,
326-
keyLengthInBytes,
327-
emptyAggregationBuffer,
328-
PlatformDependent.LONG_ARRAY_OFFSET,
329-
emptyAggregationBuffer.length
330-
)
331-
}
332-
// Reset our pointer to point to the buffer stored in the hash map
333-
val address = loc.getValueAddress
334-
currentBuffer.set(
335-
address.getBaseObject,
336-
address.getBaseOffset,
337-
aggregationBufferSchema.length,
338-
aggregationBufferSchema
339-
)
340-
// Target the projection at the current aggregation buffer and then project the updated
341-
// values.
342-
updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentJavaRow))
275+
val currentRow: Row = iter.next()
276+
val groupKey: Row = groupProjection(currentRow)
277+
val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey)
278+
updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
343279
}
344280

345281
new Iterator[Row] {
346-
private[this] val resultIterator = buffers.iterator()
282+
private[this] val mapIterator = aggregationMap.iterator()
347283
private[this] val resultProjection = resultProjectionBuilder()
348-
private[this] val key: UnsafeRow = new UnsafeRow()
349-
private[this] val value: UnsafeRow = new UnsafeRow()
350284

351-
def hasNext: Boolean = resultIterator.hasNext
285+
def hasNext: Boolean = mapIterator.hasNext
352286

353287
def next(): Row = {
354-
val currentGroup: BytesToBytesMap#Location = resultIterator.next()
355-
val keyAddress = currentGroup.getKeyAddress
356-
key.set(
357-
keyAddress.getBaseObject,
358-
keyAddress.getBaseOffset,
359-
groupingExpressions.length,
360-
keySchema)
361-
val valueAddress = currentGroup.getValueAddress
362-
value.set(
363-
valueAddress.getBaseObject,
364-
valueAddress.getBaseOffset,
365-
aggregationBufferSchema.length,
366-
aggregationBufferSchema)
367-
val result = resultProjection(joinedRow(key, value))
288+
val entry = mapIterator.next()
289+
val result = resultProjection(joinedRow(entry.key, entry.value))
368290
if (hasNext) {
369291
result
370292
} else {
371293
// This is the last element in the iterator, so let's free the buffer. Before we do,
372294
// though, we need to make a defensive copy of the result so that we don't return an
373295
// object that might contain dangling pointers to the freed memory
374296
val resultCopy = result.copy()
375-
buffers.free()
297+
aggregationMap.free()
376298
resultCopy
377299
}
378300
}
379-
380-
override def finalize(): Unit = {
381-
buffers.free()
382-
}
383301
}
384302
}
385303
}

0 commit comments

Comments
 (0)