Skip to content

Commit 56c3001

Browse files
committed
fix tests
1 parent a68b977 commit 56c3001

File tree

2 files changed

+27
-19
lines changed

2 files changed

+27
-19
lines changed

sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
113113
case udaf: UserDefinedAggregator[_, _, _] =>
114114
def builder(children: Seq[Expression]) = udaf.scalaAggregator(children)
115115
functionRegistry.createOrReplaceTempFunction(name, builder)
116-
udf
116+
udaf
117117
case other =>
118118
def builder(children: Seq[Expression]) = other.apply(children.map(Column.apply) : _*).expr
119119
functionRegistry.createOrReplaceTempFunction(name, builder)

sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -321,26 +321,34 @@ object IntegratedUDFTestUtils extends SQLHelper {
321321
* casted_col.cast(df.schema("col").dataType)
322322
* }}}
323323
*/
324-
case class TestScalaUDF(name: String) extends TestUDF {
325-
private[IntegratedUDFTestUtils] lazy val udf = new SparkUserDefinedFunction(
326-
(input: Any) => if (input == null) {
327-
null
328-
} else {
329-
input.toString
330-
},
331-
StringType,
332-
inputEncoders = Seq.fill(1)(None),
333-
name = Some(name)) {
334-
335-
override def apply(exprs: Column*): Column = {
336-
assert(exprs.length == 1, "Defined UDF only has one column")
337-
val expr = exprs.head.expr
338-
assert(expr.resolved, "column should be resolved to use the same type " +
339-
"as input. Try df(name) or df.col(name)")
340-
Column(Cast(createScalaUDF(Cast(expr, StringType) :: Nil), expr.dataType))
341-
}
324+
class TestInternalScalaUDF(name: String) extends SparkUserDefinedFunction(
325+
(input: Any) => if (input == null) {
326+
null
327+
} else {
328+
input.toString
329+
},
330+
StringType,
331+
inputEncoders = Seq.fill(1)(None),
332+
name = Some(name)) {
333+
334+
override def apply(exprs: Column*): Column = {
335+
assert(exprs.length == 1, "Defined UDF only has one column")
336+
val expr = exprs.head.expr
337+
assert(expr.resolved, "column should be resolved to use the same type " +
338+
"as input. Try df(name) or df.col(name)")
339+
Column(Cast(createScalaUDF(Cast(expr, StringType) :: Nil), expr.dataType))
342340
}
343341

342+
override def withName(name: String): TestInternalScalaUDF = {
343+
// "withName" should overridden to return TestInternalScalaUDF. Otherwise, the current object
344+
// is sliced and the overridden "apply" is not invoked.
345+
new TestInternalScalaUDF(name)
346+
}
347+
}
348+
349+
case class TestScalaUDF(name: String) extends TestUDF {
350+
private[IntegratedUDFTestUtils] lazy val udf = new TestInternalScalaUDF(name)
351+
344352
def apply(exprs: Column*): Column = udf(exprs: _*)
345353

346354
val prettyName: String = "Scala UDF"

0 commit comments

Comments
 (0)