1717
1818package org .apache .spark .sql .catalyst .optimizer
1919
20+ import org .apache .spark .sql .catalyst .analysis .UnresolvedAttribute
21+ import org .apache .spark .sql .catalyst .dsl .expressions ._
2022import org .apache .spark .sql .catalyst .dsl .plans ._
2123import org .apache .spark .sql .catalyst .expressions ._
2224import org .apache .spark .sql .catalyst .expressions .Literal .{FalseLiteral , TrueLiteral }
@@ -29,7 +31,8 @@ import org.apache.spark.sql.types.{IntegerType, NullType}
2931class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
3032
3133 object Optimize extends RuleExecutor [LogicalPlan ] {
32- val batches = Batch (" SimplifyConditionals" , FixedPoint (50 ), SimplifyConditionals ) :: Nil
34+ val batches = Batch (" SimplifyConditionals" , FixedPoint (50 ),
35+ BooleanSimplification , ConstantFolding , SimplifyConditionals ) :: Nil
3336 }
3437
3538 protected def assertEquivalent (e1 : Expression , e2 : Expression ): Unit = {
@@ -43,6 +46,12 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
4346 private val unreachableBranch = (FalseLiteral , Literal (20 ))
4447 private val nullBranch = (Literal .create(null , NullType ), Literal (30 ))
4548
49+ private val testRelation = LocalRelation (' a .int, ' b .string, ' c .boolean)
50+
51+ val isNotNullCond = IsNotNull (UnresolvedAttribute (Seq (" a" )))
52+ val isNullCond = IsNull (UnresolvedAttribute (" b" ))
53+ val notCond = Not (UnresolvedAttribute (" c" ))
54+
4655 test(" simplify if" ) {
4756 assertEquivalent(
4857 If (TrueLiteral , Literal (10 ), Literal (20 )),
@@ -100,4 +109,25 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
100109 None ),
101110 CaseWhen (normalBranch :: trueBranch :: Nil , None ))
102111 }
112+
113+ test(" remove entire CaseWhen if all the outputs are semantic equivalence" ) {
114+ val originalQuery =
115+ testRelation
116+ .select(
117+ CaseWhen ((isNotNullCond, Subtract (Literal (3 ), Literal (2 ))) ::
118+ (isNullCond, Literal (1 )) ::
119+ (notCond, Add (Literal (6 ), Literal (- 5 ))) ::
120+ Nil ,
121+ Add (Literal (2 ), Literal (- 1 ))))
122+ .analyze
123+
124+ val optimized = Optimize .execute(originalQuery.analyze).canonicalized
125+ val correctAnswer =
126+ testRelation
127+ .select(Literal (1 ))
128+ .analyze
129+ .canonicalized
130+
131+ comparePlans(optimized, correctAnswer)
132+ }
103133}
0 commit comments