Skip to content

Commit b127c41

Browse files
committed
fix
1 parent 2c943cc commit b127c41

File tree

1 file changed

+29
-66
lines changed

1 file changed

+29
-66
lines changed

sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala

Lines changed: 29 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@ package org.apache.spark.sql.expressions
2020
import scala.collection.parallel.immutable.ParVector
2121

2222
import org.apache.spark.SparkFunSuite
23-
import org.apache.spark.sql.catalyst.FunctionIdentifier
24-
import org.apache.spark.sql.catalyst.expressions.{NonSQLExpression, _}
25-
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
23+
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
24+
import org.apache.spark.sql.catalyst.expressions._
2625
import org.apache.spark.sql.execution.HiveResult.hiveResultString
2726
import org.apache.spark.sql.internal.SQLConf
2827
import org.apache.spark.sql.test.SharedSparkSession
@@ -159,73 +158,37 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
159158
}
160159
}
161160

162-
test("Check whether should extend NullIntolerant") {
163-
// Only check expressions extended from these expressions
164-
val parentExpressionNames = Seq(classOf[UnaryExpression], classOf[BinaryExpression],
165-
classOf[TernaryExpression], classOf[QuaternaryExpression],
166-
classOf[SeptenaryExpression]).map(_.getName)
167-
// Do not check these expressions
168-
val whiteList = Seq(
169-
classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod],
170-
classOf[CheckOverflow], classOf[NormalizeNaNAndZero], classOf[InSet],
171-
classOf[PrintToStderr], classOf[CodegenFallbackExpression]).map(_.getName)
172-
173-
spark.sessionState.functionRegistry.listFunction()
174-
.map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName)
175-
.filterNot(c => whiteList.exists(_.equals(c))).foreach { className =>
176-
if (needToCheckNullIntolerant(className)) {
177-
val evalExist = checkIfEvalOverrode(className)
178-
val nullIntolerantExist = checkIfNullIntolerantMixedIn(className)
179-
if (evalExist && nullIntolerantExist) {
180-
fail(s"$className should not extend ${classOf[NullIntolerant].getSimpleName}")
181-
} else if (!evalExist && !nullIntolerantExist) {
182-
fail(s"$className should extend ${classOf[NullIntolerant].getSimpleName}")
183-
} else {
184-
assert((!evalExist && nullIntolerantExist) || (evalExist && !nullIntolerantExist))
185-
}
186-
}
187-
}
161+
test("Check whether SQL expressions should extend NullIntolerant") {
162+
// Only check expressions extended from these expressions because these expressions are
163+
// NullIntolerant by default.
164+
val exprTypesToCheck = Seq(classOf[UnaryExpression], classOf[BinaryExpression],
165+
classOf[TernaryExpression], classOf[QuaternaryExpression], classOf[SeptenaryExpression])
188166

189-
def needToCheckNullIntolerant(className: String): Boolean = {
190-
var clazz: Class[_] = Utils.classForName(className)
191-
val isNonSQLExpr =
192-
clazz.getInterfaces.exists(_.getName.equals(classOf[NonSQLExpression].getName))
193-
var checkNullIntolerant: Boolean = false
194-
while (!checkNullIntolerant && clazz.getSuperclass != null) {
195-
checkNullIntolerant = parentExpressionNames.exists(_.equals(clazz.getSuperclass.getName))
196-
if (!checkNullIntolerant) {
197-
clazz = clazz.getSuperclass
198-
}
199-
}
200-
checkNullIntolerant && !isNonSQLExpr
201-
}
167+
// Do not check these expressions, because these expressions extend NullIntolerant
168+
// and override the eval function.
169+
val ignoreSet = Set(classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod])
202170

203-
def checkIfNullIntolerantMixedIn(className: String): Boolean = {
204-
val nullIntolerantName = classOf[NullIntolerant].getName
205-
var clazz: Class[_] = Utils.classForName(className)
206-
var nullIntolerantMixedIn = false
207-
while (!nullIntolerantMixedIn && !parentExpressionNames.exists(_.equals(clazz.getName))) {
208-
nullIntolerantMixedIn = clazz.getInterfaces.exists(_.getName.equals(nullIntolerantName)) ||
209-
clazz.getInterfaces.exists { i =>
210-
Utils.classForName(i.getName).getInterfaces.exists(_.getName.equals(nullIntolerantName))
211-
}
212-
if (!nullIntolerantMixedIn) {
213-
clazz = clazz.getSuperclass
214-
}
215-
}
216-
nullIntolerantMixedIn
217-
}
218-
219-
def checkIfEvalOverrode(className: String): Boolean = {
220-
var clazz: Class[_] = Utils.classForName(className)
221-
var evalOverrode: Boolean = false
222-
while (!evalOverrode && !parentExpressionNames.exists(_.equals(clazz.getName))) {
223-
evalOverrode = clazz.getDeclaredMethods.exists(_.getName.equals("eval"))
224-
if (!evalOverrode) {
225-
clazz = clazz.getSuperclass
171+
val candidateExprsToCheck = spark.sessionState.functionRegistry.listFunction()
172+
.map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName)
173+
.filterNot(c => ignoreSet.exists(_.getName.equals(c)))
174+
.map(name => Utils.classForName(name))
175+
.filterNot(classOf[NonSQLExpression].isAssignableFrom)
176+
177+
exprTypesToCheck.foreach { superClass =>
178+
candidateExprsToCheck.filter(superClass.isAssignableFrom).foreach { clazz =>
179+
val isEvalOverrode = clazz.getMethod("eval", classOf[InternalRow]) !=
180+
superClass.getMethod("eval", classOf[InternalRow])
181+
val isNullIntolerantMixedIn = classOf[NullIntolerant].isAssignableFrom(clazz)
182+
if (isEvalOverrode && isNullIntolerantMixedIn) {
183+
fail(s"${clazz.getName} should not extend ${classOf[NullIntolerant].getSimpleName}, " +
184+
s"or add ${clazz.getName} in the ignoreSet of this test.")
185+
} else if (!isEvalOverrode && !isNullIntolerantMixedIn) {
186+
fail(s"${clazz.getName} should extend ${classOf[NullIntolerant].getSimpleName}.")
187+
} else {
188+
assert((!isEvalOverrode && isNullIntolerantMixedIn) ||
189+
(isEvalOverrode && !isNullIntolerantMixedIn))
226190
}
227191
}
228-
evalOverrode
229192
}
230193
}
231194
}

0 commit comments

Comments
 (0)