@@ -433,14 +433,7 @@ abstract class DeclarativeAggregate
  * calls method `eval(buffer: T)` to generate the final output for this group.
  * 5. The framework moves on to next group, until all groups have been processed.
  */
-abstract class TypedImperativeAggregate[T >: Null] extends ImperativeAggregate {
-
-  /**
-   * Spark Sql type of user-defined aggregation buffer object. It needs to be an `UserDefinedType`
-   * so that the framework knows how to serialize the aggregation buffer object to Spark sql
-   * internally supported storage format.
-   */
-  def aggregationBufferType: UserDefinedType[T]
+abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
 
   /**
    * Creates an empty aggregation buffer object. This is called before processing each key group
@@ -478,6 +471,43 @@ abstract class TypedImperativeAggregate[T >: Null] extends ImperativeAggregate {
    */
   def eval(buffer: T): Any
 
+  /** Returns the class of the aggregation buffer object. */
+  def aggregationBufferClass: Class[T]
+
+  /** Serializes the aggregation buffer object T to a Spark SQL internally supported storage format. */
+  def serialize(buffer: T): Any
+
+  /** De-serializes the storage format, and produces the aggregation buffer object T. */
+  def deserialize(storageFormat: Any): T
+
+  /**
+   * Returns the Sql type of the aggregation buffer object's storage format.
+   *
+   * Here is the list of supported storage formats and their corresponding Sql types:
+   *
+   * {{{
+   *   aggregation buffer object's storage format     | storage format's Sql type
+   *   -----------------------------------------------+--------------------------
+   *   Array[Byte] (*)                                | BinaryType (*)
+   *   Null                                           | NullType
+   *   Boolean                                        | BooleanType
+   *   Byte                                           | ByteType
+   *   Short                                          | ShortType
+   *   Int                                            | IntegerType
+   *   Long                                           | LongType
+   *   Float                                          | FloatType
+   *   Double                                         | DoubleType
+   *   org.apache.spark.sql.types.Decimal             | DecimalType
+   *   org.apache.spark.unsafe.types.UTF8String       | StringType
+   *   org.apache.spark.unsafe.types.CalendarInterval | CalendarIntervalType
+   *   org.apache.spark.sql.catalyst.util.MapData     | MapType
+   *   org.apache.spark.sql.catalyst.util.ArrayData   | ArrayType
+   *   org.apache.spark.sql.catalyst.InternalRow      | StructType
+   * }}}
+   *
+   */
+  def aggregationBufferStorageFormatSqlType: DataType
+
   final override def initialize(buffer: MutableRow): Unit = {
     val bufferObject = createAggregationBuffer()
     buffer.update(mutableAggBufferOffset, bufferObject)
@@ -496,7 +526,7 @@ abstract class TypedImperativeAggregate[T >: Null] extends ImperativeAggregate {
 
   final override def eval(buffer: InternalRow): Any = {
     val bufferObject = field(buffer, mutableAggBufferOffset)
-    if (bufferObject.getClass == aggregationBufferType.userClass) {
+    if (bufferObject.getClass == aggregationBufferClass) {
       // When used in Window frame aggregation, eval(buffer: InternalRow) is called directly
       // on the object aggregation buffer without intermediate serializing/de-serializing.
       eval(bufferObject.asInstanceOf[T])
@@ -505,17 +535,13 @@ abstract class TypedImperativeAggregate[T >: Null] extends ImperativeAggregate {
     }
   }
 
-  private def deserialize(input: AnyRef): T = {
-    aggregationBufferType.deserialize(input)
-  }
-
   private def field(input: InternalRow, offset: Int): AnyRef = {
     input.get(offset, null)
   }
 
-  final override val aggBufferAttributes: Seq[AttributeReference] = {
+  final override lazy val aggBufferAttributes: Seq[AttributeReference] = {
     // Underlying storage type for the aggregation buffer object
-    Seq(AttributeReference("buf", aggregationBufferType.sqlType)())
+    Seq(AttributeReference("buf", aggregationBufferStorageFormatSqlType)())
   }
 
   final override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
@@ -531,6 +557,6 @@ abstract class TypedImperativeAggregate[T >: Null] extends ImperativeAggregate {
    */
   final def serializeAggregateBufferInPlace(buffer: MutableRow): Unit = {
     val bufferObject = field(buffer, mutableAggBufferOffset).asInstanceOf[T]
-    buffer(mutableAggBufferOffset) = aggregationBufferType.serialize(bufferObject)
+    buffer(mutableAggBufferOffset) = serialize(bufferObject)
   }
 }