Skip to content

Commit 7ba1974

Browse files
address comments
1 parent c148cfe commit 7ba1974

File tree

2 files changed

+21
-22
lines changed

2 files changed

+21
-22
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -437,24 +437,19 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
437437

438438
/**
439439
* Returns a copy of this node where the given partial function has been recursively applied
440-
* first to this node's children, then this node's subqueries, and finally this node itself
441-
* (post-order). When the partial function does not apply to a given node, it is left unchanged.
440+
* first to the subqueries in this node's children, then this node's children, and finally
441+
* this node itself (post-order). When the partial function does not apply to a given node,
442+
* it is left unchanged.
442443
*/
443444
def transformUpWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = {
444-
val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] {
445-
override def isDefinedAt(x: PlanType): Boolean = true
446-
447-
override def apply(plan: PlanType): PlanType = {
448-
val transformed = plan transformExpressionsUp {
449-
case planExpression: PlanExpression[PlanType] =>
450-
val newPlan = planExpression.plan.transformUpWithSubqueries(f)
451-
planExpression.withNewPlan(newPlan)
452-
}
453-
f.applyOrElse[PlanType, PlanType](transformed, identity)
445+
transformUp { case plan =>
446+
val transformed = plan transformExpressionsUp {
447+
case planExpression: PlanExpression[PlanType] =>
448+
val newPlan = planExpression.plan.transformUpWithSubqueries(f)
449+
planExpression.withNewPlan(newPlan)
454450
}
451+
f.applyOrElse[PlanType, PlanType](transformed, identity)
455452
}
456-
457-
transformUp(g)
458453
}
459454

460455
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowRelationSubquerySuite.scala

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717

1818
package org.apache.spark.sql.catalyst.optimizer
1919

20+
import org.apache.spark.sql.catalyst.analysis.CleanupAliases
2021
import org.apache.spark.sql.catalyst.dsl.expressions._
2122
import org.apache.spark.sql.catalyst.dsl.plans._
22-
import org.apache.spark.sql.catalyst.expressions.{Alias, ScalarSubquery}
23+
import org.apache.spark.sql.catalyst.expressions.ScalarSubquery
2324
import org.apache.spark.sql.catalyst.plans._
24-
import org.apache.spark.sql.catalyst.plans.logical.{DomainJoin, LocalRelation, LogicalPlan, OneRowRelation, Project}
25+
import org.apache.spark.sql.catalyst.plans.logical.{DomainJoin, LocalRelation, LogicalPlan, OneRowRelation}
2526
import org.apache.spark.sql.catalyst.rules.RuleExecutor
2627
import org.apache.spark.sql.internal.SQLConf
2728

@@ -46,7 +47,9 @@ class OptimizeOneRowRelationSubquerySuite extends PlanTest {
4647
val batches =
4748
Batch("Subquery", Once,
4849
OptimizeOneRowRelationSubquery,
49-
PullupCorrelatedPredicates) :: Nil
50+
PullupCorrelatedPredicates) ::
51+
Batch("Cleanup", FixedPoint(10),
52+
CleanupAliases) :: Nil
5053
}
5154

5255
private def assertHasDomainJoin(plan: LogicalPlan): Unit = {
@@ -91,8 +94,8 @@ class OptimizeOneRowRelationSubquerySuite extends PlanTest {
9194
val inner = t0.select('a.as("a1"), 'b.as("b1")).select(('a1 + 'b1).as("c"))
9295
val query = t1.select(ScalarSubquery(inner).as("sub"))
9396
val optimized = Optimize.execute(query.analyze)
94-
val correctAnswer = Project(Alias(Alias(a + b, "c")(), "sub")() :: Nil, t1)
95-
comparePlans(optimized, correctAnswer)
97+
val correctAnswer = t1.select(('a + 'b).as("c").as("sub"))
98+
comparePlans(optimized, correctAnswer.analyze)
9699
}
97100

98101
test("Optimize lateral subquery with multiple projects") {
@@ -111,8 +114,8 @@ class OptimizeOneRowRelationSubquerySuite extends PlanTest {
111114
val inner = t0.select('a.as("b")).select(ScalarSubquery(t0.select('b)).as("s"))
112115
val query = t1.select(ScalarSubquery(inner).as("sub"))
113116
val optimized = Optimize.execute(query.analyze)
114-
val correctAnswer = Project(Alias(Alias(a, "s")(), "sub")() :: Nil, t1)
115-
comparePlans(optimized, correctAnswer)
117+
val correctAnswer = t1.select('a.as("s").as("sub"))
118+
comparePlans(optimized, correctAnswer.analyze)
116119
}
117120

118121
test("Batch should be idempotent") {
@@ -149,8 +152,9 @@ class OptimizeOneRowRelationSubquerySuite extends PlanTest {
149152
}
150153
}
151154

152-
test("Should not optimize subquery with nested subqueries") {
155+
test("Should not optimize subquery with nested subqueries that can't be optimized") {
153156
// SELECT (SELECT (SELECT a WHERE a = 1) FROM (SELECT a AS a)) FROM t1
157+
// Filter (a = 1) cannot be optimized.
154158
val inner = t0.select('a).where('a === 1)
155159
val subquery = t0.select('a.as("a"))
156160
.select(ScalarSubquery(inner).as("s")).select('s + 1)

0 commit comments

Comments
 (0)