Skip to content

Commit e4d929c

Browse files
committed
Fix bugs and merge into SimplifyConditionals
1 parent f9e87f0 commit e4d929c

File tree

4 files changed

+22
-64
lines changed

4 files changed

+22
-64
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ abstract class Optimizer(catalogManager: CatalogManager)
103103
ReplaceNullWithFalseInPredicate,
104104
PruneFilters,
105105
SimplifyCasts,
106-
SimplifyIf,
107106
SimplifyCaseConversionExpressions,
108107
RewriteCorrelatedScalarSubquery,
109108
EliminateSerialization,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.optimizer
1919

2020
import scala.collection.immutable.HashSet
2121
import scala.collection.mutable.{ArrayBuffer, Stack}
22-
2322
import org.apache.spark.sql.catalyst.analysis._
2423
import org.apache.spark.sql.catalyst.expressions._
2524
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
@@ -463,6 +462,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
463462
case If(Literal(null, _), _, falseValue) => falseValue
464463
case If(cond, trueValue, falseValue)
465464
if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue
465+
case If(p, l @ Literal(null, _), FalseLiteral) if !p.nullable => And(p, l)
466+
case If(p, l @ Literal(null, _), TrueLiteral) if !p.nullable => Or(Not(p), l)
466467

467468
case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
468469
// If there are branches that are always false, remove them.
@@ -716,18 +717,6 @@ object SimplifyCasts extends Rule[LogicalPlan] {
716717
}
717718
}
718719

719-
/**
720-
* Simplify if clauses of pattern `if(p, null, true|false)`, by replacing them with
721-
* AND or OR clauses which are simpler and can better be pushed down
722-
*/
723-
object SimplifyIf extends Rule[LogicalPlan] {
724-
val nullLiteral = Literal(null, BooleanType)
725-
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
726-
case If(p, Literal(null, _), FalseLiteral) => And(p, nullLiteral)
727-
case If(p, Literal(null, _), TrueLiteral) => Or(p, nullLiteral)
728-
}
729-
}
730-
731720
/**
732721
* Removes nodes that are not necessary.
733722
*/

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules._
2828
import org.apache.spark.sql.types.{BooleanType, IntegerType}
2929

3030

31-
class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
31+
class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with PredicateHelper {
3232

3333
object Optimize extends RuleExecutor[LogicalPlan] {
3434
val batches = Batch("SimplifyConditionals", FixedPoint(50),
@@ -165,4 +165,23 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
165165
Literal(1))
166166
)
167167
}
168+
169+
test("simplify if when then clause is null and else clause is boolean") {
170+
val p = IsNull('a)
171+
val nullLiteral = Literal(null, BooleanType)
172+
assertEquivalent(If(p, nullLiteral, FalseLiteral), And(p, nullLiteral))
173+
assertEquivalent(If(p, nullLiteral, TrueLiteral), Or(IsNotNull('a), nullLiteral))
174+
175+
// the rule should not apply to nullable predicate
176+
Seq(TrueLiteral, FalseLiteral).foreach { b =>
177+
assertEquivalent(If(GreaterThan('a, 42), nullLiteral, b),
178+
If(GreaterThan('a, 42), nullLiteral, b))
179+
}
180+
181+
// check evaluation also
182+
Seq(TrueLiteral, FalseLiteral).foreach { p =>
183+
checkEvaluation(If(p, nullLiteral, FalseLiteral), And(p, nullLiteral).eval(EmptyRow))
184+
checkEvaluation(If(p, nullLiteral, TrueLiteral), Or(Not(p), nullLiteral).eval(EmptyRow))
185+
}
186+
}
168187
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyIfSuite.scala

Lines changed: 0 additions & 49 deletions
This file was deleted.

0 commit comments

Comments
 (0)