@@ -20,9 +20,8 @@ package org.apache.spark.sql.expressions
2020import scala .collection .parallel .immutable .ParVector
2121
2222import org .apache .spark .SparkFunSuite
23- import org .apache .spark .sql .catalyst .FunctionIdentifier
24- import org .apache .spark .sql .catalyst .expressions .{NonSQLExpression , _ }
25- import org .apache .spark .sql .catalyst .optimizer .NormalizeNaNAndZero
23+ import org .apache .spark .sql .catalyst .{FunctionIdentifier , InternalRow }
24+ import org .apache .spark .sql .catalyst .expressions ._
2625import org .apache .spark .sql .execution .HiveResult .hiveResultString
2726import org .apache .spark .sql .internal .SQLConf
2827import org .apache .spark .sql .test .SharedSparkSession
@@ -159,73 +158,37 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
159158 }
160159 }
161160
162- test(" Check whether should extend NullIntolerant" ) {
163- // Only check expressions extended from these expressions
164- val parentExpressionNames = Seq (classOf [UnaryExpression ], classOf [BinaryExpression ],
165- classOf [TernaryExpression ], classOf [QuaternaryExpression ],
166- classOf [SeptenaryExpression ]).map(_.getName)
167- // Do not check these expressions
168- val whiteList = Seq (
169- classOf [IntegralDivide ], classOf [Divide ], classOf [Remainder ], classOf [Pmod ],
170- classOf [CheckOverflow ], classOf [NormalizeNaNAndZero ], classOf [InSet ],
171- classOf [PrintToStderr ], classOf [CodegenFallbackExpression ]).map(_.getName)
172-
173- spark.sessionState.functionRegistry.listFunction()
174- .map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName)
175- .filterNot(c => whiteList.exists(_.equals(c))).foreach { className =>
176- if (needToCheckNullIntolerant(className)) {
177- val evalExist = checkIfEvalOverrode(className)
178- val nullIntolerantExist = checkIfNullIntolerantMixedIn(className)
179- if (evalExist && nullIntolerantExist) {
180- fail(s " $className should not extend ${classOf [NullIntolerant ].getSimpleName}" )
181- } else if (! evalExist && ! nullIntolerantExist) {
182- fail(s " $className should extend ${classOf [NullIntolerant ].getSimpleName}" )
183- } else {
184- assert((! evalExist && nullIntolerantExist) || (evalExist && ! nullIntolerantExist))
185- }
186- }
187- }
161+ test(" Check whether SQL expressions should extend NullIntolerant" ) {
162+ // Only check expressions extended from these expressions because these expressions are
163+ // NullIntolerant by default.
164+ val exprTypesToCheck = Seq (classOf [UnaryExpression ], classOf [BinaryExpression ],
165+ classOf [TernaryExpression ], classOf [QuaternaryExpression ], classOf [SeptenaryExpression ])
188166
189- def needToCheckNullIntolerant (className : String ): Boolean = {
190- var clazz : Class [_] = Utils .classForName(className)
191- val isNonSQLExpr =
192- clazz.getInterfaces.exists(_.getName.equals(classOf [NonSQLExpression ].getName))
193- var checkNullIntolerant : Boolean = false
194- while (! checkNullIntolerant && clazz.getSuperclass != null ) {
195- checkNullIntolerant = parentExpressionNames.exists(_.equals(clazz.getSuperclass.getName))
196- if (! checkNullIntolerant) {
197- clazz = clazz.getSuperclass
198- }
199- }
200- checkNullIntolerant && ! isNonSQLExpr
201- }
167+ // Do not check these expressions, because these expressions extend NullIntolerant
168+ // and override the eval function.
169+ val ignoreSet = Set (classOf [IntegralDivide ], classOf [Divide ], classOf [Remainder ], classOf [Pmod ])
202170
203- def checkIfNullIntolerantMixedIn (className : String ): Boolean = {
204- val nullIntolerantName = classOf [NullIntolerant ].getName
205- var clazz : Class [_] = Utils .classForName(className)
206- var nullIntolerantMixedIn = false
207- while (! nullIntolerantMixedIn && ! parentExpressionNames.exists(_.equals(clazz.getName))) {
208- nullIntolerantMixedIn = clazz.getInterfaces.exists(_.getName.equals(nullIntolerantName)) ||
209- clazz.getInterfaces.exists { i =>
210- Utils .classForName(i.getName).getInterfaces.exists(_.getName.equals(nullIntolerantName))
211- }
212- if (! nullIntolerantMixedIn) {
213- clazz = clazz.getSuperclass
214- }
215- }
216- nullIntolerantMixedIn
217- }
218-
219- def checkIfEvalOverrode (className : String ): Boolean = {
220- var clazz : Class [_] = Utils .classForName(className)
221- var evalOverrode : Boolean = false
222- while (! evalOverrode && ! parentExpressionNames.exists(_.equals(clazz.getName))) {
223- evalOverrode = clazz.getDeclaredMethods.exists(_.getName.equals(" eval" ))
224- if (! evalOverrode) {
225- clazz = clazz.getSuperclass
171+ val candidateExprsToCheck = spark.sessionState.functionRegistry.listFunction()
172+ .map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName)
173+ .filterNot(c => ignoreSet.exists(_.getName.equals(c)))
174+ .map(name => Utils .classForName(name))
175+ .filterNot(classOf [NonSQLExpression ].isAssignableFrom)
176+
177+ exprTypesToCheck.foreach { superClass =>
178+ candidateExprsToCheck.filter(superClass.isAssignableFrom).foreach { clazz =>
179+ val isEvalOverrode = clazz.getMethod(" eval" , classOf [InternalRow ]) !=
180+ superClass.getMethod(" eval" , classOf [InternalRow ])
181+ val isNullIntolerantMixedIn = classOf [NullIntolerant ].isAssignableFrom(clazz)
182+ if (isEvalOverrode && isNullIntolerantMixedIn) {
183+ fail(s " ${clazz.getName} should not extend ${classOf [NullIntolerant ].getSimpleName}, " +
184+ s " or add ${clazz.getName} in the ignoreSet of this test. " )
185+ } else if (! isEvalOverrode && ! isNullIntolerantMixedIn) {
186+ fail(s " ${clazz.getName} should extend ${classOf [NullIntolerant ].getSimpleName}. " )
187+ } else {
188+ assert((! isEvalOverrode && isNullIntolerantMixedIn) ||
189+ (isEvalOverrode && ! isNullIntolerantMixedIn))
226190 }
227191 }
228- evalOverrode
229192 }
230193 }
231194}
0 commit comments