Skip to content

Commit 81c38f8

Browse files
committed
fix
1 parent d07344f commit 81c38f8

File tree

4 files changed

+40
-18
lines changed

4 files changed

+40
-18
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -526,8 +526,9 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
526526
e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue)))
527527
}
528528

529-
case e @ CaseWhen(_, elseValue) if elseValue.isEmpty =>
530-
e.copy(elseValue = Some(Literal.create(null, e.dataType)))
529+
case e @ CaseWhen(branches, elseOpt)
530+
if elseOpt.isEmpty && branches.forall(_._2.semanticEquals(Literal(null, e.dataType))) =>
531+
Literal(null, e.dataType)
531532
}
532533
}
533534
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,20 +122,19 @@ class PushFoldableIntoBranchesSuite
122122
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral)))
123123
assertEquivalent(
124124
EqualTo(CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), None), Literal(4)),
125-
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Literal.create(null, BooleanType)))
125+
CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None))
126126

127127
assertEquivalent(
128128
And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))),
129129
FalseLiteral)
130130

131131
// Push down at most one branch is not foldable expressions.
132132
assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, Literal(1))), None), Literal(1)),
133-
CaseWhen(Seq((a, EqualTo(b, Literal(1))), (c, TrueLiteral)),
134-
Literal.create(null, BooleanType)))
133+
CaseWhen(Seq((a, EqualTo(b, Literal(1))), (c, TrueLiteral)), None))
135134
assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1)),
136-
EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), Literal.create(null, IntegerType)), Literal(1)))
135+
EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1)))
137136
assertEquivalent(EqualTo(CaseWhen(Seq((a, b)), None), Literal(1)),
138-
CaseWhen(Seq((a, b === Literal(1))), Literal.create(null, BooleanType)))
137+
EqualTo(CaseWhen(Seq((a, b)), None), Literal(1)))
139138

140139
// Push down non-deterministic expressions.
141140
val nonDeterministic =

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
115115
val expectedBranches = Seq(
116116
(UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral,
117117
(UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral)
118-
val expectedCond = CaseWhen(expectedBranches, FalseLiteral)
118+
val expectedCond = CaseWhen(expectedBranches)
119119

120120
testFilter(originalCond, expectedCond)
121121
testJoin(originalCond, expectedCond)
@@ -241,13 +241,9 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
241241
Literal(2) === nestedCaseWhen,
242242
TrueLiteral,
243243
FalseLiteral)
244-
val branches = Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue)
245-
val condition = CaseWhen(branches)
244+
val condition = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue))
246245
val expectedCond = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) ->
247-
CaseWhen(
248-
Seq((UnresolvedAttribute("i") > Literal(20)) -> TrueLiteral),
249-
FalseLiteral)),
250-
FalseLiteral)
246+
CaseWhen(Seq((UnresolvedAttribute("i") > Literal(20)) -> TrueLiteral), FalseLiteral)))
251247
testFilter(originalCond = condition, expectedCond = expectedCond)
252248
testJoin(originalCond = condition, expectedCond = expectedCond)
253249
testDelete(originalCond = condition, expectedCond = expectedCond)
@@ -408,6 +404,18 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
408404
testUpdate(nonAllFalseCond, nonAllFalseCond)
409405
}
410406

407+
test("replace None of elseValue inside CaseWhen if all branches are null") {
408+
val allFalseBranches = Seq(
409+
(UnresolvedAttribute("i") < Literal(10)) -> Literal.create(null, BooleanType),
410+
(UnresolvedAttribute("i") > Literal(40)) -> Literal.create(null, BooleanType))
411+
val allFalseCond = CaseWhen(allFalseBranches)
412+
413+
testFilter(allFalseCond, FalseLiteral)
414+
testJoin(allFalseCond, FalseLiteral)
415+
testDelete(allFalseCond, FalseLiteral)
416+
testUpdate(allFalseCond, FalseLiteral)
417+
}
418+
411419
test("replace None of elseValue inside CaseWhen with PushFoldableIntoBranches") {
412420
val allFalseBranches = Seq(
413421
(UnresolvedAttribute("i") < Literal(10)) -> Literal("a"),

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
8383
// i.e. removing branches whose conditions are always false
8484
assertEquivalent(
8585
CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
86-
CaseWhen(normalBranch :: Nil, Literal.create(null, IntegerType)))
86+
CaseWhen(normalBranch :: Nil, None))
8787
}
8888

8989
test("remove entire CaseWhen if only the else branch is reachable") {
@@ -216,9 +216,23 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
216216
LessThanOrEqual(Rand(0), UnresolvedAttribute("a")))
217217
}
218218

219-
test("SPARK-33847: Replace None of elseValue inside CaseWhen to null literal") {
219+
test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") {
220220
assertEquivalent(
221-
CaseWhen(normalBranch :: Nil, None),
222-
CaseWhen(normalBranch :: Nil, Literal.create(null, IntegerType)))
221+
CaseWhen((GreaterThan('a, 1), Literal.create(null, IntegerType)) :: Nil,
222+
None),
223+
Literal.create(null, IntegerType))
224+
assertEquivalent(
225+
CaseWhen((GreaterThan(Rand(0), 1), Literal.create(null, IntegerType)) :: Nil,
226+
None),
227+
Literal.create(null, IntegerType))
228+
229+
assertEquivalent(
230+
CaseWhen((GreaterThan('a, 1), Literal.create(null, IntegerType)) :: Nil,
231+
Some(Literal.create(null, IntegerType))),
232+
Literal.create(null, IntegerType))
233+
assertEquivalent(
234+
CaseWhen((GreaterThan('a, 1), Literal(20)) :: (GreaterThan('b, 1), Literal(20)) :: Nil,
235+
Some(Literal(20))),
236+
Literal(20))
223237
}
224238
}

0 commit comments

Comments
 (0)