Skip to content

Commit bd5ae26

Browse files
committed
fix.
1 parent 12cefc2 commit bd5ae26

File tree

2 files changed

+66
-58
lines changed

2 files changed

+66
-58
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala

Lines changed: 21 additions & 18 deletions
Original file line number · Diff line number · Diff line change
@@ -1099,26 +1099,10 @@ class SessionCatalog(
10991099
* This performs reflection to decide what type of [[Expression]] to return in the builder.
11001100
*/
11011101
protected def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = {
1102-
makeFunctionBuilder(name, Utils.classForName(functionClassName))
1103-
}
1104-
1105-
/**
1106-
* Construct a [[FunctionBuilder]] based on the provided class that represents a function.
1107-
*/
1108-
private def makeFunctionBuilder(name: String, clazz: Class[_]): FunctionBuilder = {
1109-
// When we instantiate ScalaUDAF class, we may throw exception if the input
1110-
// expressions don't satisfy the UDAF, such as type mismatch, input number
1111-
// mismatch, etc. Here we catch the exception and throw AnalysisException instead.
1102+
val clazz = Utils.classForName(functionClassName)
11121103
(children: Seq[Expression]) => {
11131104
try {
1114-
val clsForUDAF =
1115-
Utils.classForName("org.apache.spark.sql.expressions.UserDefinedAggregateFunction")
1116-
if (clsForUDAF.isAssignableFrom(clazz)) {
1117-
val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaUDAF")
1118-
cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int])
1119-
.newInstance(children, clazz.newInstance().asInstanceOf[Object], Int.box(1), Int.box(1))
1120-
.asInstanceOf[Expression]
1121-
} else {
1105+
makeFunctionExpression(name, Utils.classForName(functionClassName), children).getOrElse {
11221106
throw new UnsupportedOperationException("Use sqlContext.udf.register(...) instead.")
11231107
}
11241108
} catch {
@@ -1137,6 +1121,25 @@ class SessionCatalog(
11371121
}
11381122
}
11391123

1124+
/**
1125+
* Construct a [[FunctionBuilder]] based on the provided class that represents a function.
1126+
*/
1127+
protected def makeFunctionExpression(
1128+
name: String,
1129+
clazz: Class[_],
1130+
children: Seq[Expression]): Option[Expression] = {
1131+
val clsForUDAF =
1132+
Utils.classForName("org.apache.spark.sql.expressions.UserDefinedAggregateFunction")
1133+
if (clsForUDAF.isAssignableFrom(clazz)) {
1134+
val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaUDAF")
1135+
Some(cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int])
1136+
.newInstance(children, clazz.newInstance().asInstanceOf[Object], Int.box(1), Int.box(1))
1137+
.asInstanceOf[Expression])
1138+
} else {
1139+
None
1140+
}
1141+
}
1142+
11401143
/**
11411144
* Loads resources such as JARs and Files for a function. Every resource is represented
11421145
* by a tuple (resource type, resource uri).

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala

Lines changed: 45 additions & 40 deletions
Original file line number · Diff line number · Diff line change
@@ -34,8 +34,6 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
3434
import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog}
3535
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression}
3636
import org.apache.spark.sql.catalyst.parser.ParserInterface
37-
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
38-
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
3937
import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
4038
import org.apache.spark.sql.internal.SQLConf
4139
import org.apache.spark.sql.types.{DecimalType, DoubleType}
@@ -60,46 +58,11 @@ private[sql] class HiveSessionCatalog(
6058
parser,
6159
functionResourceLoader) {
6260

63-
override def makeFunctionBuilder(funcName: String, className: String): FunctionBuilder = {
64-
makeFunctionBuilder(funcName, Utils.classForName(className))
65-
}
66-
67-
/**
68-
* Construct a [[FunctionBuilder]] based on the provided class that represents a function.
69-
*/
70-
private def makeFunctionBuilder(name: String, clazz: Class[_]): FunctionBuilder = {
71-
// When we instantiate hive UDF wrapper class, we may throw exception if the input
72-
// expressions don't satisfy the hive UDF, such as type mismatch, input number
73-
// mismatch, etc. Here we catch the exception and throw AnalysisException instead.
61+
override def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = {
62+
val clazz = Utils.classForName(functionClassName)
7463
(children: Seq[Expression]) => {
7564
try {
76-
if (classOf[UDF].isAssignableFrom(clazz)) {
77-
val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(clazz.getName), children)
78-
udf.dataType // Force it to check input data types.
79-
udf
80-
} else if (classOf[GenericUDF].isAssignableFrom(clazz)) {
81-
val udf = HiveGenericUDF(name, new HiveFunctionWrapper(clazz.getName), children)
82-
udf.dataType // Force it to check input data types.
83-
udf
84-
} else if (classOf[AbstractGenericUDAFResolver].isAssignableFrom(clazz)) {
85-
val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(clazz.getName), children)
86-
udaf.dataType // Force it to check input data types.
87-
udaf
88-
} else if (classOf[UDAF].isAssignableFrom(clazz)) {
89-
val udaf = HiveUDAFFunction(
90-
name,
91-
new HiveFunctionWrapper(clazz.getName),
92-
children,
93-
isUDAFBridgeRequired = true)
94-
udaf.dataType // Force it to check input data types.
95-
udaf
96-
} else if (classOf[GenericUDTF].isAssignableFrom(clazz)) {
97-
val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), children)
98-
udtf.elementSchema // Force it to check input data types.
99-
udtf
100-
} else if (classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) {
101-
ScalaUDAF(children, clazz.newInstance().asInstanceOf[UserDefinedAggregateFunction])
102-
} else {
65+
makeFunctionExpression(name, Utils.classForName(functionClassName), children).getOrElse {
10366
throw new AnalysisException(s"No handler for Hive UDF '${clazz.getCanonicalName}'")
10467
}
10568
} catch {
@@ -114,6 +77,48 @@ private[sql] class HiveSessionCatalog(
11477
}
11578
}
11679

80+
/**
81+
* Construct a [[FunctionBuilder]] based on the provided class that represents a function.
82+
*/
83+
override def makeFunctionExpression(
84+
name: String,
85+
clazz: Class[_],
86+
children: Seq[Expression]): Option[Expression] = {
87+
88+
super.makeFunctionExpression(name, clazz, children).orElse {
89+
// When we instantiate hive UDF wrapper class, we may throw exception if the input
90+
// expressions don't satisfy the hive UDF, such as type mismatch, input number
91+
// mismatch, etc. Here we catch the exception and throw AnalysisException instead.
92+
if (classOf[UDF].isAssignableFrom(clazz)) {
93+
val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(clazz.getName), children)
94+
udf.dataType // Force it to check input data types.
95+
Some(udf)
96+
} else if (classOf[GenericUDF].isAssignableFrom(clazz)) {
97+
val udf = HiveGenericUDF(name, new HiveFunctionWrapper(clazz.getName), children)
98+
udf.dataType // Force it to check input data types.
99+
Some(udf)
100+
} else if (classOf[AbstractGenericUDAFResolver].isAssignableFrom(clazz)) {
101+
val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(clazz.getName), children)
102+
udaf.dataType // Force it to check input data types.
103+
Some(udaf)
104+
} else if (classOf[UDAF].isAssignableFrom(clazz)) {
105+
val udaf = HiveUDAFFunction(
106+
name,
107+
new HiveFunctionWrapper(clazz.getName),
108+
children,
109+
isUDAFBridgeRequired = true)
110+
udaf.dataType // Force it to check input data types.
111+
Some(udaf)
112+
} else if (classOf[GenericUDTF].isAssignableFrom(clazz)) {
113+
val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), children)
114+
udtf.elementSchema // Force it to check input data types.
115+
Some(udtf)
116+
} else {
117+
None
118+
}
119+
}
120+
}
121+
117122
override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
118123
try {
119124
lookupFunction0(name, children)

0 commit comments

Comments (0)