Skip to content

Commit dc8de5f

Browse files
committed
remove casewhen if possible
1 parent 2edf17e commit dc8de5f

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,15 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
414414
// these branches can be pruned away
415415
val (h, t) = branches.span(_._1 != TrueLiteral)
416416
CaseWhen( h :+ t.head, None)
417+
418+
case CaseWhen(branches, Some(elseValue)) if {
419+
// With previous rules, it's guaranteed that `branches.length >= 2`
420+
val list = branches.map(_._2) :+ elseValue
421+
list.tail.forall(list.head.semanticEquals)
422+
} =>
423+
// If all the values in the branches and elseValue are the same,
424+
// `CaseWhen` condition can be removed.
425+
elseValue
417426
}
418427
}
419428
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.optimizer
1919

20+
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
21+
import org.apache.spark.sql.catalyst.dsl.expressions._
2022
import org.apache.spark.sql.catalyst.dsl.plans._
2123
import org.apache.spark.sql.catalyst.expressions._
2224
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
@@ -29,7 +31,8 @@ import org.apache.spark.sql.types.{IntegerType, NullType}
2931
class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
3032

3133
object Optimize extends RuleExecutor[LogicalPlan] {
32-
val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil
34+
val batches = Batch("SimplifyConditionals", FixedPoint(50),
35+
BooleanSimplification, ConstantFolding, SimplifyConditionals) :: Nil
3336
}
3437

3538
protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
@@ -43,6 +46,12 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
4346
private val unreachableBranch = (FalseLiteral, Literal(20))
4447
private val nullBranch = (Literal.create(null, NullType), Literal(30))
4548

49+
private val testRelation = LocalRelation('a.int, 'b.string, 'c.boolean)
50+
51+
val isNotNullCond = IsNotNull(UnresolvedAttribute(Seq("a")))
52+
val isNullCond = IsNull(UnresolvedAttribute("b"))
53+
val notCond = Not(UnresolvedAttribute("c"))
54+
4655
test("simplify if") {
4756
assertEquivalent(
4857
If(TrueLiteral, Literal(10), Literal(20)),
@@ -100,4 +109,25 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
100109
None),
101110
CaseWhen(normalBranch :: trueBranch :: Nil, None))
102111
}
112+
113+
test("remove entire CaseWhen if all the outputs are semantic equivalence") {
114+
val originalQuery =
115+
testRelation
116+
.select(
117+
CaseWhen((isNotNullCond, Subtract(Literal(3), Literal(2))) ::
118+
(isNullCond, Literal(1)) ::
119+
(notCond, Add(Literal(6), Literal(-5))) ::
120+
Nil,
121+
Add(Literal(2), Literal(-1))))
122+
.analyze
123+
124+
val optimized = Optimize.execute(originalQuery.analyze).canonicalized
125+
val correctAnswer =
126+
testRelation
127+
.select(Literal(1))
128+
.analyze
129+
.canonicalized
130+
131+
comparePlans(optimized, correctAnswer)
132+
}
103133
}

0 commit comments

Comments
 (0)