Skip to content

Commit 2613c30

Browse files
address comments
1 parent a24d93f commit 2613c30

File tree

3 files changed

+51
-67
lines changed

3 files changed

+51
-67
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1229,7 +1229,7 @@ object SQLConf {
12291229
.internal()
12301230
.doc("Whether to remove redundant project exec node based on children's output and " +
12311231
"ordering requirement.")
1232-
.version("3.0.0")
1232+
.version("3.1.0")
12331233
.booleanConf
12341234
.createWithDefault(true)
12351235

sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantProjects.scala

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,6 @@ case class RemoveRedundantProjects(conf: SQLConf) extends Rule[SparkPlan] {
6161
val keepOrdering = a.aggregateExpressions
6262
.exists(ae => ae.mode.equals(Final) || ae.mode.equals(PartialMerge))
6363
a.mapChildren(removeProject(_, keepOrdering))
64-
case w: WindowExec => w.mapChildren(removeProject(_, false))
6564
case g: GenerateExec => g.mapChildren(removeProject(_, false))
6665
// JoinExec ordering requirement will inherit from its parent. If there is no ProjectExec in
6766
// its ancestors, JoinExec should require output columns to be ordered.

sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala

Lines changed: 50 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -19,9 +19,10 @@ package org.apache.spark.sql.execution
1919

2020
import org.apache.spark.sql.{DataFrame, QueryTest}
2121
import org.apache.spark.sql.internal.SQLConf
22-
import org.apache.spark.sql.test.SharedSparkSession
22+
import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
23+
import org.apache.spark.util.Utils
2324

24-
class RemoveRedundantProjectsSuite extends QueryTest with SharedSparkSession {
25+
class RemoveRedundantProjectsSuite extends QueryTest with SharedSparkSession with SQLTestUtils {
2526

2627
private def assertProjectExecCount(df: DataFrame, expected: Integer): Unit = {
2728
withClue(df.queryExecution) {
@@ -34,110 +35,94 @@ class RemoveRedundantProjectsSuite extends QueryTest with SharedSparkSession {
3435
private def assertProjectExec(query: String, enabled: Integer, disabled: Integer): Unit = {
3536
val df = sql(query)
3637
assertProjectExecCount(df, enabled)
37-
val result1 = df.collect()
38+
val result = df.collect()
3839
withSQLConf(SQLConf.REMOVE_REDUNDANT_PROJECTS_ENABLED.key -> "false") {
3940
val df2 = sql(query)
4041
assertProjectExecCount(df2, disabled)
41-
QueryTest.sameRows(result1.toSeq, df2.collect().toSeq)
42+
checkAnswer(df2, result)
4243
}
4344
}
4445

45-
private def withTestView(f: => Unit): Unit = {
46-
withTempPath { p =>
47-
val path = p.getAbsolutePath
48-
spark.range(100).selectExpr("id % 10 as key", "id * 2 as a",
49-
"id * 3 as b", "cast(id as string) as c", "array(id, id + 1, id + 3) as d")
50-
.write.partitionBy("key").parquet(path)
51-
spark.read.parquet(path).createOrReplaceTempView("testView")
52-
f
53-
}
46+
private val tmpPath = Utils.createTempDir()
47+
48+
override def beforeAll(): Unit = {
49+
super.beforeAll()
50+
tmpPath.delete()
51+
val path = tmpPath.getAbsolutePath
52+
spark.range(100).selectExpr("id % 10 as key", "cast(id * 2 as int) as a",
53+
"cast(id * 3 as int) as b", "cast(id as string) as c", "array(id, id + 1, id + 3) as d")
54+
.write.partitionBy("key").parquet(path)
55+
spark.read.parquet(path).createOrReplaceTempView("testView")
56+
}
57+
58+
override def afterAll(): Unit = {
59+
Utils.deleteRecursively(tmpPath)
60+
super.afterAll()
5461
}
5562

5663
test("project") {
57-
withTestView {
58-
val query = "select * from testView"
59-
assertProjectExec(query, 0, 0)
60-
}
64+
val query = "select * from testView"
65+
assertProjectExec(query, 0, 0)
6166
}
6267

6368
test("project with filter") {
64-
withTestView {
65-
val query = "select * from testView where a > 5"
66-
assertProjectExec(query, 0, 1)
67-
}
69+
val query = "select * from testView where a > 5"
70+
assertProjectExec(query, 0, 1)
6871
}
6972

7073
test("project with specific column ordering") {
71-
withTestView {
72-
val query = "select key, a, b, c from testView"
73-
assertProjectExec(query, 1, 1)
74-
}
74+
val query = "select key, a, b, c from testView"
75+
assertProjectExec(query, 1, 1)
7576
}
7677

7778
test("project with extra columns") {
78-
withTestView {
79-
val query = "select a, b, c, key, a from testView"
80-
assertProjectExec(query, 1, 1)
81-
}
79+
val query = "select a, b, c, key, a from testView"
80+
assertProjectExec(query, 1, 1)
8281
}
8382

8483
test("project with fewer columns") {
85-
withTestView {
86-
val query = "select a from testView where a > 3"
87-
assertProjectExec(query, 1, 1)
88-
}
84+
val query = "select a from testView where a > 3"
85+
assertProjectExec(query, 1, 1)
8986
}
9087

9188
test("aggregate without ordering requirement") {
92-
withTestView {
93-
val query = "select sum(a) as sum_a, key, last(b) as last_b " +
94-
"from (select key, a, b from testView where a > 100) group by key"
95-
assertProjectExec(query, 0, 1)
96-
}
89+
val query = "select sum(a) as sum_a, key, last(b) as last_b " +
90+
"from (select key, a, b from testView where a > 100) group by key"
91+
assertProjectExec(query, 0, 1)
9792
}
9893

9994
test("aggregate with ordering requirement") {
100-
withTestView {
101-
val query = "select a, sum(b) as sum_b from testView group by a"
102-
assertProjectExec(query, 1, 1)
103-
}
95+
val query = "select a, sum(b) as sum_b from testView group by a"
96+
assertProjectExec(query, 1, 1)
10497
}
10598

10699
test("join without ordering requirement") {
107-
withTestView {
108-
val query = "select t1.key, t2.key, t1.a, t2.b from (select key, a, b, c from testView)" +
109-
" as t1 join (select key, a, b, c from testView) as t2 on t1.c > t2.c and t1.key > 10"
110-
assertProjectExec(query, 1, 3)
111-
}
100+
val query = "select t1.key, t2.key, t1.a, t2.b from (select key, a, b, c from testView)" +
101+
" as t1 join (select key, a, b, c from testView) as t2 on t1.c > t2.c and t1.key > 10"
102+
assertProjectExec(query, 1, 3)
112103
}
113104

114105
test("join with ordering requirement") {
115-
withTestView {
116-
val query = "select * from (select key, a, c, b from testView) as t1 join " +
117-
"(select key, a, b, c from testView) as t2 on t1.key = t2.key where t2.a > 50"
118-
assertProjectExec(query, 2, 2)
119-
}
106+
val query = "select * from (select key, a, c, b from testView) as t1 join " +
107+
"(select key, a, b, c from testView) as t2 on t1.key = t2.key where t2.a > 50"
108+
assertProjectExec(query, 2, 2)
120109
}
121110

122111
test("window function") {
123-
withTestView {
124-
val query = "select key, avg(a) over (partition by key order by a " +
125-
"rows between 1 preceding and 1 following) as avg, b from testView"
126-
assertProjectExec(query, 1, 2)
127-
}
112+
val query = "select key, b, avg(a) over (partition by key order by a " +
113+
"rows between 1 preceding and 1 following) as avg from testView"
114+
assertProjectExec(query, 1, 2)
128115
}
129116

130117
test("generate") {
131-
withTestView {
132-
val query = "select a, key, explode(d) from testView where a > 10"
133-
assertProjectExec(query, 0, 1)
134-
}
118+
val query = "select a, key, explode(d) from testView where a > 10"
119+
assertProjectExec(query, 0, 1)
135120
}
136121

137122
test("subquery") {
138-
withTestView {
139-
val query = "select * from testView where a in (select b from testView where key > 5)"
140-
assertProjectExec(query, 1, 1)
141-
}
123+
testData
124+
val query = "select key, value from testData where key in " +
125+
"(select sum(a) from testView where a > 5 group by key)"
126+
assertProjectExec(query, 0, 1)
142127
}
143128
}

0 commit comments

Comments
 (0)