
Commit 0fdc1ea

fix comments
1 parent 10861b2 commit 0fdc1ea

2 files changed (+53, -31)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala

Lines changed: 42 additions & 16 deletions
@@ -433,14 +433,7 @@ abstract class DeclarativeAggregate
  * calls method `eval(buffer: T)` to generate the final output for this group.
  * 5. The framework moves on to next group, until all groups have been processed.
  */
-abstract class TypedImperativeAggregate[T >: Null] extends ImperativeAggregate {
-
-  /**
-   * Spark Sql type of user-defined aggregation buffer object. It needs to be an `UserDefinedType`
-   * so that the framework knows how to serialize the aggregation buffer object to Spark sql
-   * internally supported storage format.
-   */
-  def aggregationBufferType: UserDefinedType[T]
+abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
 
   /**
    * Creates an empty aggregation buffer object. This is called before processing each key group
@@ -478,6 +471,43 @@ abstract class TypedImperativeAggregate[T >: Null] extends ImperativeAggregate {
    */
   def eval(buffer: T): Any
 
+  /** Returns the class of aggregation buffer object */
+  def aggregationBufferClass: Class[T]
+
+  /** Serializes the aggregation buffer object T to Spark-sql internally supported storage format */
+  def serialize(buffer: T): Any
+
+  /** De-serializes the storage format, and produces aggregation buffer object T */
+  def deserialize(storageFormat: Any): T
+
+  /**
+   * Returns the aggregation-buffer-object storage format's Sql type.
+   *
+   * Here is a list of supported storage format and corresponding Sql type:
+   *
+   * {{{
+   *   aggregation buffer object's storage format     | storage format's Sql type
+   *   ------------------------------------------------------------------------
+   *   Array[Byte] (*)                                | BinaryType (*)
+   *   Null                                           | NullType
+   *   Boolean                                        | BooleanType
+   *   Byte                                           | ByteType
+   *   Short                                          | ShortType
+   *   Int                                            | IntegerType
+   *   Long                                           | LongType
+   *   Float                                          | FloatType
+   *   Double                                         | DoubleType
+   *   org.apache.spark.sql.types.Decimal             | DecimalType
+   *   org.apache.spark.unsafe.types.UTF8String       | StringType
+   *   org.apache.spark.unsafe.types.CalendarInterval | CalendarIntervalType
+   *   org.apache.spark.sql.catalyst.util.MapData     | MapType
+   *   org.apache.spark.sql.catalyst.util.ArrayData   | ArrayType
+   *   org.apache.spark.sql.catalyst.InternalRow      | StructType
+   * }}}
+   */
+  def aggregationBufferStorageFormatSqlType: DataType
+
   final override def initialize(buffer: MutableRow): Unit = {
     val bufferObject = createAggregationBuffer()
     buffer.update(mutableAggBufferOffset, bufferObject)
@@ -496,7 +526,7 @@ abstract class TypedImperativeAggregate[T >: Null] extends ImperativeAggregate {
 
   final override def eval(buffer: InternalRow): Any = {
     val bufferObject = field(buffer, mutableAggBufferOffset)
-    if (bufferObject.getClass == aggregationBufferType.userClass) {
+    if (bufferObject.getClass == aggregationBufferClass) {
       // When used in Window frame aggregation, eval(buffer: InternalRow) is called directly
       // on the object aggregation buffer without intermediate serializing/de-serializing.
       eval(bufferObject.asInstanceOf[T])
@@ -505,17 +535,13 @@ abstract class TypedImperativeAggregate[T >: Null] extends ImperativeAggregate {
     }
   }
 
-  private def deserialize(input: AnyRef): T = {
-    aggregationBufferType.deserialize(input)
-  }
-
   private def field(input: InternalRow, offset: Int): AnyRef = {
     input.get(offset, null)
   }
 
-  final override val aggBufferAttributes: Seq[AttributeReference] = {
+  final override lazy val aggBufferAttributes: Seq[AttributeReference] = {
     // Underlying storage type for the aggregation buffer object
-    Seq(AttributeReference("buf", aggregationBufferType.sqlType)())
+    Seq(AttributeReference("buf", aggregationBufferStorageFormatSqlType)())
   }
 
   final override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
@@ -531,6 +557,6 @@ abstract class TypedImperativeAggregate[T >: Null] extends ImperativeAggregate {
    */
   final def serializeAggregateBufferInPlace(buffer: MutableRow): Unit = {
     val bufferObject = field(buffer, mutableAggBufferOffset).asInstanceOf[T]
-    buffer(mutableAggBufferOffset) = aggregationBufferType.serialize(bufferObject)
+    buffer(mutableAggBufferOffset) = serialize(bufferObject)
   }
 }
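
For context on what the new contract looks like from a subclass: instead of routing serialization through a UserDefinedType, a concrete aggregate now supplies the buffer class, a serialize/deserialize pair, and the storage format's Sql type itself. The fragment below is a minimal sketch of just those four overrides for a hypothetical buffer that packs into Array[Byte] (hence BinaryType storage, per the table above); the CountSumBuffer type and its byte layout are illustrative assumptions, not part of this commit.

    import java.nio.ByteBuffer
    import org.apache.spark.sql.types.{BinaryType, DataType}

    // Hypothetical aggregation buffer object: a running count and sum.
    class CountSumBuffer(var count: Long, var sum: Long)

    // Inside some TypedImperativeAggregate[CountSumBuffer] subclass:
    override def aggregationBufferClass: Class[CountSumBuffer] = classOf[CountSumBuffer]

    // Pack the two longs into 16 bytes; Array[Byte] is stored as BinaryType.
    override def serialize(buffer: CountSumBuffer): Any = {
      val bytes = ByteBuffer.allocate(16)
      bytes.putLong(buffer.count).putLong(buffer.sum)
      bytes.array()
    }

    // Invert serialize(): read the two longs back out of the storage format.
    override def deserialize(storageFormat: Any): CountSumBuffer = {
      val bytes = ByteBuffer.wrap(storageFormat.asInstanceOf[Array[Byte]])
      new CountSumBuffer(bytes.getLong(), bytes.getLong())
    }

    override def aggregationBufferStorageFormatSqlType: DataType = BinaryType

With this shape, the framework can keep the live buffer object in the mutable row during per-group processing and only call serialize when the buffer has to cross a serialization boundary (see serializeAggregateBufferInPlace above).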

sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala

Lines changed: 11 additions & 15 deletions
@@ -20,11 +20,11 @@ package org.apache.spark.sql
 import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, UnsafeRow}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{TypedImperativeAggregate}
 import org.apache.spark.sql.execution.aggregate.SortAggregateExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType, UserDefinedType}
+import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType}
 
 class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
 
@@ -119,7 +119,6 @@ object TypedImperativeAggregateSuite {
       mutableAggBufferOffset: Int = 0,
       inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[MaxValue] {
 
-    override lazy val aggregationBufferType: UserDefinedType[MaxValue] = new MaxValueUDT()
 
     override def createAggregationBuffer(): MaxValue = {
       new MaxValue(Int.MinValue)
@@ -152,27 +151,24 @@ object TypedImperativeAggregateSuite {
 
     override def dataType: DataType = IntegerType
 
-    override def withNewMutableAggBufferOffset(newOffset: Int): ImperativeAggregate =
+    override def withNewMutableAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] =
       copy(mutableAggBufferOffset = newOffset)
 
-    override def withNewInputAggBufferOffset(newOffset: Int): ImperativeAggregate =
+    override def withNewInputAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] =
       copy(inputAggBufferOffset = newOffset)
 
-  }
-
-  private class MaxValue(var value: Int)
+    override def aggregationBufferClass: Class[MaxValue] = classOf[MaxValue]
 
-  private class MaxValueUDT extends UserDefinedType[MaxValue] {
-    override def sqlType: DataType = IntegerType
+    override def serialize(buffer: MaxValue): Any = buffer.value
 
-    override def serialize(obj: MaxValue): Any = obj.value
+    override def aggregationBufferStorageFormatSqlType: DataType = IntegerType
 
-    override def userClass: Class[MaxValue] = classOf[MaxValue]
-
-    override def deserialize(datum: Any): MaxValue = {
-      datum match {
+    override def deserialize(storageFormat: Any): MaxValue = {
+      storageFormat match {
         case i: Int => new MaxValue(i)
       }
     }
   }
+
+  private class MaxValue(var value: Int)
 }
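
For reference, an aggregate like TypedMax is driven from the DataFrame API by wrapping the function in a Column. A brief usage sketch, not taken from this diff (df and the "key"/"value" column names are assumptions):

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.functions.col

    // Wrap the function in an AggregateExpression, then in a Column,
    // and aggregate with it like any built-in function.
    val typedMax = new Column(TypedMax(col("value").expr).toAggregateExpression())
    df.groupBy("key").agg(typedMax)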
