Skip to content

Commit a68b977

Browse files
committed
Add tests
1 parent f2fd94c commit a68b977

File tree

2 files changed

+52
-6
lines changed

2 files changed

+52
-6
lines changed

sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
114114
def builder(children: Seq[Expression]) = udaf.scalaAggregator(children)
115115
functionRegistry.createOrReplaceTempFunction(name, builder)
116116
udf
117-
case _ =>
118-
def builder(children: Seq[Expression]) = udf.apply(children.map(Column.apply) : _*).expr
117+
case other =>
118+
def builder(children: Seq[Expression]) = other.apply(children.map(Column.apply) : _*).expr
119119
functionRegistry.createOrReplaceTempFunction(name, builder)
120-
udf
120+
other
121121
}
122122
}
123123

sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,18 @@ import scala.collection.mutable.{ArrayBuffer, WrappedArray}
2626

2727
import org.apache.spark.SparkException
2828
import org.apache.spark.sql.api.java._
29-
import org.apache.spark.sql.catalyst.encoders.OuterScopes
29+
import org.apache.spark.sql.catalyst.FunctionIdentifier
30+
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, OuterScopes}
31+
import org.apache.spark.sql.catalyst.expressions.{Literal, ScalaUDF}
3032
import org.apache.spark.sql.catalyst.plans.logical.Project
3133
import org.apache.spark.sql.catalyst.util.DateTimeUtils
3234
import org.apache.spark.sql.execution.{QueryExecution, SimpleMode}
35+
import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, ScalaUDAF}
3336
import org.apache.spark.sql.execution.columnar.InMemoryRelation
3437
import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, ExplainCommand}
3538
import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
36-
import org.apache.spark.sql.expressions.SparkUserDefinedFunction
37-
import org.apache.spark.sql.functions.{lit, struct, udf}
39+
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, SparkUserDefinedFunction, UserDefinedAggregateFunction}
40+
import org.apache.spark.sql.functions.{lit, struct, udaf, udf}
3841
import org.apache.spark.sql.internal.SQLConf
3942
import org.apache.spark.sql.test.SharedSparkSession
4043
import org.apache.spark.sql.test.SQLTestData._
@@ -798,4 +801,47 @@ class UDFSuite extends QueryTest with SharedSparkSession {
798801
.select(myUdf(Column("col"))),
799802
Row(ArrayBuffer(100)))
800803
}
804+
805+
test("SPARK-34388: UDF name is propagated with registration for ScalaUDF") {
806+
spark.udf.register("udf34388", udf((value: Int) => value > 2))
807+
spark.sessionState.catalog.lookupFunction(
808+
FunctionIdentifier("udf34388"), Seq(Literal(1))) match {
809+
case udf: ScalaUDF => assert(udf.name === "udf34388")
810+
}
811+
}
812+
813+
test("SPARK-34388: UDF name is propagated with registration for ScalaAggregator") {
814+
val agg = new Aggregator[Long, Long, Long] {
815+
override def zero: Long = 0L
816+
override def reduce(b: Long, a: Long): Long = a + b
817+
override def merge(b1: Long, b2: Long): Long = b1 + b2
818+
override def finish(reduction: Long): Long = reduction
819+
override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]()
820+
override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
821+
}
822+
823+
spark.udf.register("agg34388", udaf(agg))
824+
spark.sessionState.catalog.lookupFunction(
825+
FunctionIdentifier("agg34388"), Seq(Literal(1))) match {
826+
case agg: ScalaAggregator[_, _, _] => assert(agg.name === "agg34388")
827+
}
828+
}
829+
830+
test("SPARK-34388: UDF name is propagated with registration for ScalaUDAF") {
831+
val udaf = new UserDefinedAggregateFunction {
832+
def inputSchema: StructType = new StructType().add("a", LongType)
833+
def bufferSchema: StructType = new StructType().add("product", LongType)
834+
def dataType: DataType = LongType
835+
def deterministic: Boolean = true
836+
def initialize(buffer: MutableAggregationBuffer): Unit = {}
837+
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {}
838+
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {}
839+
def evaluate(buffer: Row): Any = buffer.getLong(0)
840+
}
841+
spark.udf.register("udaf34388", udaf)
842+
spark.sessionState.catalog.lookupFunction(
843+
FunctionIdentifier("udaf34388"), Seq(Literal(1))) match {
844+
case udaf: ScalaUDAF => assert(udaf.name === "udaf34388")
845+
}
846+
}
801847
}

0 commit comments

Comments (0)