Skip to content

Commit c425024

Browse files
wangyumcloud-fan
authored andcommitted
[SPARK-33847][SQL][FOLLOWUP] Remove the CaseWhen should consider deterministic
### What changes were proposed in this pull request? This pr fix remove the `CaseWhen` if elseValue is empty and other outputs are null because of we should consider deterministic. ### Why are the changes needed? Fix bug. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit test. Closes #30960 from wangyum/SPARK-33847-2. Authored-by: Yuming Wang <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 16c594d commit c425024

File tree

5 files changed

+23
-28
lines changed

5 files changed

+23
-28
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,8 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
9898
val newBranches = cw.branches.map { case (cond, value) =>
9999
replaceNullWithFalse(cond) -> replaceNullWithFalse(value)
100100
}
101-
if (newBranches.forall(_._2 == FalseLiteral) && cw.elseValue.isEmpty) {
102-
FalseLiteral
103-
} else {
104-
val newElseValue = cw.elseValue.map(replaceNullWithFalse)
105-
CaseWhen(newBranches, newElseValue)
106-
}
101+
val newElseValue = cw.elseValue.map(replaceNullWithFalse).getOrElse(FalseLiteral)
102+
CaseWhen(newBranches, newElseValue)
107103
case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType =>
108104
If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal))
109105
case e if e.dataType == BooleanType =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -515,8 +515,9 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
515515
val (h, t) = branches.span(_._1 != TrueLiteral)
516516
CaseWhen( h :+ t.head, None)
517517

518-
case e @ CaseWhen(branches, Some(elseValue))
519-
if branches.forall(_._2.semanticEquals(elseValue)) =>
518+
case e @ CaseWhen(branches, elseOpt)
519+
if branches.forall(_._2.semanticEquals(elseOpt.getOrElse(Literal(null, e.dataType)))) =>
520+
val elseValue = elseOpt.getOrElse(Literal(null, e.dataType))
520521
// For non-deterministic conditions with side effect, we can not remove it, or change
521522
// the ordering. As a result, we try to remove the deterministic conditions from the tail.
522523
var hitNonDeterministicCond = false
@@ -532,10 +533,6 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
532533
} else {
533534
e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue)))
534535
}
535-
536-
case e @ CaseWhen(branches, None)
537-
if branches.forall(_._2.semanticEquals(Literal(null, e.dataType))) =>
538-
Literal(null, e.dataType)
539536
}
540537
}
541538
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -260,14 +260,13 @@ class PushFoldableIntoBranchesSuite
260260
}
261261

262262
test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") {
263-
Seq(a, LessThan(Rand(1), Literal(0.5))).foreach { condition =>
264-
assertEquivalent(
265-
EqualTo(CaseWhen(Seq((condition, Literal.create(null, IntegerType)))), Literal(2)),
266-
Literal.create(null, BooleanType))
267-
assertEquivalent(
268-
EqualTo(CaseWhen(Seq((condition, Literal("str")))).cast(IntegerType), Literal(2)),
269-
Literal.create(null, BooleanType))
270-
}
263+
assertEquivalent(
264+
EqualTo(CaseWhen(Seq((a, Literal.create(null, IntegerType)))), Literal(2)),
265+
Literal.create(null, BooleanType))
266+
assertEquivalent(
267+
EqualTo(CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal("str")))).cast(IntegerType),
268+
Literal(2)),
269+
CaseWhen(Seq((LessThan(Rand(1), Literal(0.5)), Literal.create(null, BooleanType)))))
271270
}
272271

273272
test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
114114
val expectedBranches = Seq(
115115
(UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral,
116116
(UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral)
117-
val expectedCond = CaseWhen(expectedBranches)
117+
val expectedCond = CaseWhen(expectedBranches, FalseLiteral)
118118

119119
testFilter(originalCond, expectedCond)
120120
testJoin(originalCond, expectedCond)
@@ -135,7 +135,7 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
135135
(UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral,
136136
(UnresolvedAttribute("i") > Literal(10)) -> FalseLiteral,
137137
TrueLiteral -> TrueLiteral)
138-
val expectedCond = CaseWhen(expectedBranches)
138+
val expectedCond = CaseWhen(expectedBranches, FalseLiteral)
139139

140140
testFilter(originalCond, expectedCond)
141141
testJoin(originalCond, expectedCond)
@@ -238,7 +238,8 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
238238
FalseLiteral)
239239
val condition = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue))
240240
val expectedCond = CaseWhen(Seq(
241-
(UnresolvedAttribute("i") > Literal(10), (Literal(2) === nestedCaseWhen) <=> TrueLiteral)))
241+
(UnresolvedAttribute("i") > Literal(10), (Literal(2) === nestedCaseWhen) <=> TrueLiteral)),
242+
FalseLiteral)
242243
testFilter(originalCond = condition, expectedCond = expectedCond)
243244
testJoin(originalCond = condition, expectedCond = expectedCond)
244245
testDelete(originalCond = condition, expectedCond = expectedCond)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,13 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
237237
}
238238

239239
test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") {
240-
Seq(GreaterThan('a, 1), GreaterThan(Rand(0), 1)).foreach { condition =>
241-
assertEquivalent(
242-
CaseWhen((condition, Literal.create(null, IntegerType)) :: Nil, None),
243-
Literal.create(null, IntegerType))
244-
}
240+
assertEquivalent(
241+
CaseWhen((GreaterThan('a, 1), Literal.create(null, IntegerType)) :: Nil, None),
242+
Literal.create(null, IntegerType))
243+
244+
assertEquivalent(
245+
CaseWhen((GreaterThan(Rand(0), 0.5), Literal.create(null, IntegerType)) :: Nil, None),
246+
CaseWhen((GreaterThan(Rand(0), 0.5), Literal.create(null, IntegerType)) :: Nil, None))
245247
}
246248

247249
test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") {

0 commit comments

Comments
 (0)