@@ -19,9 +19,10 @@ package org.apache.spark.sql.execution
1919
2020import org .apache .spark .sql .{DataFrame , QueryTest }
2121import org .apache .spark .sql .internal .SQLConf
22- import org .apache .spark .sql .test .SharedSparkSession
22+ import org .apache .spark .sql .test .{SharedSparkSession , SQLTestUtils }
23+ import org .apache .spark .util .Utils
2324
24- class RemoveRedundantProjectsSuite extends QueryTest with SharedSparkSession {
25+ class RemoveRedundantProjectsSuite extends QueryTest with SharedSparkSession with SQLTestUtils {
2526
2627 private def assertProjectExecCount (df : DataFrame , expected : Integer ): Unit = {
2728 withClue(df.queryExecution) {
@@ -34,110 +35,94 @@ class RemoveRedundantProjectsSuite extends QueryTest with SharedSparkSession {
3435 private def assertProjectExec (query : String , enabled : Integer , disabled : Integer ): Unit = {
3536 val df = sql(query)
3637 assertProjectExecCount(df, enabled)
37- val result1 = df.collect()
38+ val result = df.collect()
3839 withSQLConf(SQLConf .REMOVE_REDUNDANT_PROJECTS_ENABLED .key -> " false" ) {
3940 val df2 = sql(query)
4041 assertProjectExecCount(df2, disabled)
41- QueryTest .sameRows(result1.toSeq, df2.collect().toSeq )
42+ checkAnswer(df2, result )
4243 }
4344 }
4445
45- private def withTestView (f : => Unit ): Unit = {
46- withTempPath { p =>
47- val path = p.getAbsolutePath
48- spark.range(100 ).selectExpr(" id % 10 as key" , " id * 2 as a" ,
49- " id * 3 as b" , " cast(id as string) as c" , " array(id, id + 1, id + 3) as d" )
50- .write.partitionBy(" key" ).parquet(path)
51- spark.read.parquet(path).createOrReplaceTempView(" testView" )
52- f
53- }
46+ private val tmpPath = Utils .createTempDir()
47+
48+ override def beforeAll (): Unit = {
49+ super .beforeAll()
50+ tmpPath.delete()
51+ val path = tmpPath.getAbsolutePath
52+ spark.range(100 ).selectExpr(" id % 10 as key" , " cast(id * 2 as int) as a" ,
53+ " cast(id * 3 as int) as b" , " cast(id as string) as c" , " array(id, id + 1, id + 3) as d" )
54+ .write.partitionBy(" key" ).parquet(path)
55+ spark.read.parquet(path).createOrReplaceTempView(" testView" )
56+ }
57+
58+ override def afterAll (): Unit = {
59+ Utils .deleteRecursively(tmpPath)
60+ super .afterAll()
5461 }
5562
5663 test(" project" ) {
57- withTestView {
58- val query = " select * from testView"
59- assertProjectExec(query, 0 , 0 )
60- }
64+ val query = " select * from testView"
65+ assertProjectExec(query, 0 , 0 )
6166 }
6267
6368 test(" project with filter" ) {
64- withTestView {
65- val query = " select * from testView where a > 5"
66- assertProjectExec(query, 0 , 1 )
67- }
69+ val query = " select * from testView where a > 5"
70+ assertProjectExec(query, 0 , 1 )
6871 }
6972
7073 test(" project with specific column ordering" ) {
71- withTestView {
72- val query = " select key, a, b, c from testView"
73- assertProjectExec(query, 1 , 1 )
74- }
74+ val query = " select key, a, b, c from testView"
75+ assertProjectExec(query, 1 , 1 )
7576 }
7677
7778 test(" project with extra columns" ) {
78- withTestView {
79- val query = " select a, b, c, key, a from testView"
80- assertProjectExec(query, 1 , 1 )
81- }
79+ val query = " select a, b, c, key, a from testView"
80+ assertProjectExec(query, 1 , 1 )
8281 }
8382
8483 test(" project with fewer columns" ) {
85- withTestView {
86- val query = " select a from testView where a > 3"
87- assertProjectExec(query, 1 , 1 )
88- }
84+ val query = " select a from testView where a > 3"
85+ assertProjectExec(query, 1 , 1 )
8986 }
9087
9188 test(" aggregate without ordering requirement" ) {
92- withTestView {
93- val query = " select sum(a) as sum_a, key, last(b) as last_b " +
94- " from (select key, a, b from testView where a > 100) group by key"
95- assertProjectExec(query, 0 , 1 )
96- }
89+ val query = " select sum(a) as sum_a, key, last(b) as last_b " +
90+ " from (select key, a, b from testView where a > 100) group by key"
91+ assertProjectExec(query, 0 , 1 )
9792 }
9893
9994 test(" aggregate with ordering requirement" ) {
100- withTestView {
101- val query = " select a, sum(b) as sum_b from testView group by a"
102- assertProjectExec(query, 1 , 1 )
103- }
95+ val query = " select a, sum(b) as sum_b from testView group by a"
96+ assertProjectExec(query, 1 , 1 )
10497 }
10598
10699 test(" join without ordering requirement" ) {
107- withTestView {
108- val query = " select t1.key, t2.key, t1.a, t2.b from (select key, a, b, c from testView)" +
109- " as t1 join (select key, a, b, c from testView) as t2 on t1.c > t2.c and t1.key > 10"
110- assertProjectExec(query, 1 , 3 )
111- }
100+ val query = " select t1.key, t2.key, t1.a, t2.b from (select key, a, b, c from testView)" +
101+ " as t1 join (select key, a, b, c from testView) as t2 on t1.c > t2.c and t1.key > 10"
102+ assertProjectExec(query, 1 , 3 )
112103 }
113104
114105 test(" join with ordering requirement" ) {
115- withTestView {
116- val query = " select * from (select key, a, c, b from testView) as t1 join " +
117- " (select key, a, b, c from testView) as t2 on t1.key = t2.key where t2.a > 50"
118- assertProjectExec(query, 2 , 2 )
119- }
106+ val query = " select * from (select key, a, c, b from testView) as t1 join " +
107+ " (select key, a, b, c from testView) as t2 on t1.key = t2.key where t2.a > 50"
108+ assertProjectExec(query, 2 , 2 )
120109 }
121110
122111 test(" window function" ) {
123- withTestView {
124- val query = " select key, avg(a) over (partition by key order by a " +
125- " rows between 1 preceding and 1 following) as avg, b from testView"
126- assertProjectExec(query, 1 , 2 )
127- }
112+ val query = " select key, b, avg(a) over (partition by key order by a " +
113+ " rows between 1 preceding and 1 following) as avg from testView"
114+ assertProjectExec(query, 1 , 2 )
128115 }
129116
130117 test(" generate" ) {
131- withTestView {
132- val query = " select a, key, explode(d) from testView where a > 10"
133- assertProjectExec(query, 0 , 1 )
134- }
118+ val query = " select a, key, explode(d) from testView where a > 10"
119+ assertProjectExec(query, 0 , 1 )
135120 }
136121
137122 test(" subquery" ) {
138- withTestView {
139- val query = " select * from testView where a in (select b from testView where key > 5) "
140- assertProjectExec(query, 1 , 1 )
141- }
123+ testData
124+ val query = " select key, value from testData where key in " +
125+ " (select sum(a) from testView where a > 5 group by key) "
126+ assertProjectExec(query, 0 , 1 )
142127 }
143128}
0 commit comments