Skip to content

Commit 265abd3

Browse files
committed
revert
1 parent 1bdff95 commit 265abd3

File tree

2 files changed

+32
-16
lines changed

2 files changed

+32
-16
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -309,11 +309,17 @@ trait DivModLike extends BinaryArithmetic {
309309

310310
override def nullable: Boolean = true
311311

312-
final override def nullSafeEval(input1: Any, input2: Any): Any = {
313-
if (input2 == 0) {
312+
final override def eval(input: InternalRow): Any = {
313+
val input2 = right.eval(input)
314+
if (input2 == null || input2 == 0) {
314315
null
315316
} else {
316-
evalOperation(input1, input2)
317+
val input1 = left.eval(input)
318+
if (input1 == null) {
319+
null
320+
} else {
321+
evalOperation(input1, input2)
322+
}
317323
}
318324
}
319325

@@ -510,18 +516,24 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
510516

511517
override def nullable: Boolean = true
512518

513-
override def nullSafeEval(input1: Any, input2: Any): Any = {
514-
if (input2 == 0) {
519+
override def eval(input: InternalRow): Any = {
520+
val input2 = right.eval(input)
521+
if (input2 == null || input2 == 0) {
515522
null
516523
} else {
517-
input1 match {
518-
case i: Integer => pmod(i, input2.asInstanceOf[java.lang.Integer])
519-
case l: Long => pmod(l, input2.asInstanceOf[java.lang.Long])
520-
case s: Short => pmod(s, input2.asInstanceOf[java.lang.Short])
521-
case b: Byte => pmod(b, input2.asInstanceOf[java.lang.Byte])
522-
case f: Float => pmod(f, input2.asInstanceOf[java.lang.Float])
523-
case d: Double => pmod(d, input2.asInstanceOf[java.lang.Double])
524-
case d: Decimal => pmod(d, input2.asInstanceOf[Decimal])
524+
val input1 = left.eval(input)
525+
if (input1 == null) {
526+
null
527+
} else {
528+
input1 match {
529+
case i: Integer => pmod(i, input2.asInstanceOf[java.lang.Integer])
530+
case l: Long => pmod(l, input2.asInstanceOf[java.lang.Long])
531+
case s: Short => pmod(s, input2.asInstanceOf[java.lang.Short])
532+
case b: Byte => pmod(b, input2.asInstanceOf[java.lang.Byte])
533+
case f: Float => pmod(f, input2.asInstanceOf[java.lang.Float])
534+
case d: Double => pmod(d, input2.asInstanceOf[java.lang.Double])
535+
case d: Decimal => pmod(d, input2.asInstanceOf[Decimal])
536+
}
525537
}
526538
}
527539
}

sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,13 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
164164
val exprTypesToCheck = Seq(classOf[UnaryExpression], classOf[BinaryExpression],
165165
classOf[TernaryExpression], classOf[QuaternaryExpression], classOf[SeptenaryExpression])
166166

167+
// Do not check these expressions, because these expressions extend NullIntolerant
168+
// and override the eval method to avoid evaluating input1 if input2 is 0.
169+
val ignoreSet = Set(classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod])
170+
167171
val candidateExprsToCheck = spark.sessionState.functionRegistry.listFunction()
168172
.map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName)
173+
.filterNot(c => ignoreSet.exists(_.getName.equals(c)))
169174
.map(name => Utils.classForName(name))
170175
.filterNot(classOf[NonSQLExpression].isAssignableFrom)
171176

@@ -175,9 +180,8 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
175180
superClass.getMethod("eval", classOf[InternalRow])
176181
val isNullIntolerantMixedIn = classOf[NullIntolerant].isAssignableFrom(clazz)
177182
if (isEvalOverrode && isNullIntolerantMixedIn) {
178-
fail(s"${clazz.getName} overrode the eval method and extended " +
179-
s"${classOf[NullIntolerant].getSimpleName}, which may be incorrect. " +
180-
s"You may need to override the nullSafeEval method.")
183+
fail(s"${clazz.getName} should not extend ${classOf[NullIntolerant].getSimpleName}, " +
184+
s"or add ${clazz.getName} in the ignoreSet of this test.")
181185
} else if (!isEvalOverrode && !isNullIntolerantMixedIn) {
182186
fail(s"${clazz.getName} should extend ${classOf[NullIntolerant].getSimpleName}.")
183187
} else {

0 commit comments

Comments
 (0)