Skip to content

Commit bf0b2d9

Browse files
committed
address feedback
1 parent 30ca8cf commit bf0b2d9

File tree

5 files changed

+40
-3
lines changed

5 files changed

+40
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ case class PrintToStderr(child: Expression) extends UnaryExpression {
6464
""")
6565
case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
6666

67+
override lazy val deterministic: Boolean = false
68+
6769
override def nullable: Boolean = true
6870

6971
override def inputTypes: Seq[DataType] = Seq(BooleanType)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1627,6 +1627,8 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
16271627
case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)
16281628
extends UnaryExpression with NonSQLExpression {
16291629

1630+
override lazy val deterministic: Boolean = false
1631+
16301632
override def dataType: DataType = child.dataType
16311633
override def foldable: Boolean = false
16321634
override def nullable: Boolean = false

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
390390
case If(TrueLiteral, trueValue, _) => trueValue
391391
case If(FalseLiteral, _, falseValue) => falseValue
392392
case If(Literal(null, _), _, falseValue) => falseValue
393-
case If(_, trueValue, falseValue) if trueValue.semanticEquals(falseValue) => trueValue
393+
case If(cond, trueValue, falseValue)
394+
if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue
394395

395396
case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
396397
// If there are branches that are always false, remove them.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.optimizer
1919

2020
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
21+
import org.apache.spark.sql.catalyst.dsl.expressions._
2122
import org.apache.spark.sql.catalyst.dsl.plans._
2223
import org.apache.spark.sql.catalyst.expressions._
2324
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
@@ -30,7 +31,8 @@ import org.apache.spark.sql.types.{IntegerType, NullType}
3031
class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
3132

3233
object Optimize extends RuleExecutor[LogicalPlan] {
33-
val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil
34+
val batches = Batch("SimplifyConditionals", FixedPoint(50),
35+
BooleanSimplification, ConstantFolding, SimplifyConditionals) :: Nil
3436
}
3537

3638
protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
@@ -44,6 +46,8 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
4446
private val unreachableBranch = (FalseLiteral, Literal(20))
4547
private val nullBranch = (Literal.create(null, NullType), Literal(30))
4648

49+
private val testRelation = LocalRelation('a.int)
50+
4751
test("simplify if") {
4852
assertEquivalent(
4953
If(TrueLiteral, Literal(10), Literal(20)),
@@ -64,6 +68,34 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
6468
Subtract(Literal(10), Literal(1)),
6569
Add(Literal(6), Literal(3))),
6670
Literal(9))
71+
72+
// For non-deterministic condition, we don't remove the `If` statement.
73+
assertEquivalent(
74+
If(GreaterThan(Rand(0), Literal(0.5)),
75+
Subtract(Literal(10), Literal(1)),
76+
Add(Literal(6), Literal(3))),
77+
If(GreaterThan(Rand(0), Literal(0.5)),
78+
Literal(9),
79+
Literal(9)))
80+
81+
// For non-deterministic condition, we don't remove the `If` statement.
82+
val originalQuery =
83+
testRelation
84+
.select(If(AssertTrue(IsNull(UnresolvedAttribute("a"))),
85+
Subtract(Literal(10), Literal(1)),
86+
Add(Literal(6), Literal(3)))).analyze
87+
88+
val optimized = Optimize.execute(originalQuery.analyze).canonicalized
89+
90+
val correctAnswer =
91+
testRelation
92+
.select(If(AssertTrue(IsNull(UnresolvedAttribute("a"))),
93+
Literal(9),
94+
Literal(9)))
95+
.analyze
96+
.canonicalized
97+
98+
comparePlans(optimized, correctAnswer)
6799
}
68100

69101
test("remove unreachable branches") {

sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ private[sql] trait SQLTestUtilsBase
393393
}
394394

395395
/**
396-
* Returns full path to the given file in the resouce folder
396+
* Returns full path to the given file in the resource folder
397397
*/
398398
protected def testFile(fileName: String): String = {
399399
Thread.currentThread().getContextClassLoader.getResource(fileName).toString

0 commit comments

Comments
 (0)