@@ -26,15 +26,18 @@ import scala.collection.mutable.{ArrayBuffer, WrappedArray}
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.api.java._
-import org.apache.spark.sql.catalyst.encoders.OuterScopes
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, OuterScopes}
+import org.apache.spark.sql.catalyst.expressions.{Literal, ScalaUDF}
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.{QueryExecution, SimpleMode}
+import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, ScalaUDAF}
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, ExplainCommand}
 import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
-import org.apache.spark.sql.expressions.SparkUserDefinedFunction
-import org.apache.spark.sql.functions.{lit, struct, udf}
+import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, SparkUserDefinedFunction, UserDefinedAggregateFunction}
+import org.apache.spark.sql.functions.{lit, struct, udaf, udf}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.test.SQLTestData._
@@ -798,4 +801,47 @@ class UDFSuite extends QueryTest with SharedSparkSession {
         .select(myUdf(Column("col"))),
       Row(ArrayBuffer(100)))
   }
+
+  test("SPARK-34388: UDF name is propagated with registration for ScalaUDF") {
+    spark.udf.register("udf34388", udf((value: Int) => value > 2))
+    spark.sessionState.catalog.lookupFunction(
+      FunctionIdentifier("udf34388"), Seq(Literal(1))) match {
+      case udf: ScalaUDF => assert(udf.name === "udf34388")
+    }
+  }
+
+  test("SPARK-34388: UDF name is propagated with registration for ScalaAggregator") {
+    val agg = new Aggregator[Long, Long, Long] {
+      override def zero: Long = 0L
+      override def reduce(b: Long, a: Long): Long = a + b
+      override def merge(b1: Long, b2: Long): Long = b1 + b2
+      override def finish(reduction: Long): Long = reduction
+      override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+      override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+    }
+
+    spark.udf.register("agg34388", udaf(agg))
+    spark.sessionState.catalog.lookupFunction(
+      FunctionIdentifier("agg34388"), Seq(Literal(1))) match {
+      case agg: ScalaAggregator[_, _, _] => assert(agg.name === "agg34388")
+    }
+  }
+
+  test("SPARK-34388: UDF name is propagated with registration for ScalaUDAF") {
+    val udaf = new UserDefinedAggregateFunction {
+      def inputSchema: StructType = new StructType().add("a", LongType)
+      def bufferSchema: StructType = new StructType().add("product", LongType)
+      def dataType: DataType = LongType
+      def deterministic: Boolean = true
+      def initialize(buffer: MutableAggregationBuffer): Unit = {}
+      def update(buffer: MutableAggregationBuffer, input: Row): Unit = {}
+      def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {}
+      def evaluate(buffer: Row): Any = buffer.getLong(0)
+    }
+    spark.udf.register("udaf34388", udaf)
+    spark.sessionState.catalog.lookupFunction(
+      FunctionIdentifier("udaf34388"), Seq(Literal(1))) match {
+      case udaf: ScalaUDAF => assert(udaf.name === "udaf34388")
+    }
+  }
 }
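For context, a minimal standalone sketch (not part of the patch) of how the propagated name surfaces outside the test suite. It assumes a local SparkSession and a hypothetical UDF registered as "gt2"; the exact plan text printed by explain() may vary between Spark versions.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

object UdfNameDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("udf-name-demo")
      .getOrCreate()
    import spark.implicits._

    // Register a named UDF; with the name propagation from this change,
    // the registered name is attached to the ScalaUDF expression.
    spark.udf.register("gt2", udf((value: Int) => value > 2))

    val df = Seq(1, 3, 5).toDF("v").selectExpr("gt2(v) AS flag")
    // The plans should render the call with the registered name (gt2)
    // rather than an anonymous UDF.
    df.explain(true)
    df.show()

    spark.stop()
  }
}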