Skip to content

Commit 0f95a6f

Browse files
committed
fix.
1 parent f9483cb commit 0f95a6f

File tree

4 files changed

+137
-76
lines changed

4 files changed

+137
-76
lines changed

R/pkg/inst/tests/testthat/test_sparkSQL.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2558,8 +2558,8 @@ test_that("coalesce, repartition, numPartitions", {
25582558

25592559
df2 <- repartition(df1, 10)
25602560
expect_equal(getNumPartitions(df2), 10)
2561-
expect_equal(getNumPartitions(coalesce(df2, 13)), 5)
2562-
expect_equal(getNumPartitions(coalesce(df2, 7)), 5)
2561+
expect_equal(getNumPartitions(coalesce(df2, 13)), 10)
2562+
expect_equal(getNumPartitions(coalesce(df2, 7)), 7)
25632563
expect_equal(getNumPartitions(coalesce(df2, 3)), 3)
25642564
})
25652565

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 32 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -562,46 +562,43 @@ object CollapseProject extends Rule[LogicalPlan] {
562562
}
563563

564564
/**
565-
* Combines adjacent [[Repartition]] and [[RepartitionByExpression]] operator combinations
566-
* by keeping only the one.
567-
* 1. For adjacent [[Repartition]]s, collapse into the last [[Repartition]] if their shuffle types
568-
* are the same or the parent's shuffle is true.
569-
* 2. For adjacent [[RepartitionByExpression]]s, collapse into the last [[RepartitionByExpression]].
570-
* 3. When a shuffle-enabled [[Repartition]] is above a [[RepartitionByExpression]], collapse as a
571-
* single [[RepartitionByExpression]] with the expression and the last number of partition.
572-
* 4. When a [[RepartitionByExpression]] is above a [[Repartition]], collapse as a single
573-
* [[RepartitionByExpression]] with the expression and the last number of partition.
565+
* Combines adjacent [[RepartitionOperation]] operators
574566
*/
575567
object CollapseRepartition extends Rule[LogicalPlan] {
576568
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
577-
// Case 1
578-
case r @ Repartition(numPartitions, shuffle, child @ Repartition(_, _, grandChild)) =>
579-
(shuffle, child.shuffle) match {
580-
case (true, true) | (true, false) | (false, false) =>
581-
Repartition(numPartitions, shuffle, grandChild)
582-
case (false, true) if numPartitions >= child.numPartitions =>
583-
child
584-
case _ =>
585-
r
569+
// Case 1: When a Repartition has a child of Repartition or RepartitionByExpression,
570+
// we can collapse it with the child based on the type of shuffle and the specified number
571+
// of partitions.
572+
case r @ Repartition(_, _, child: Repartition) =>
573+
collapseRepartition(r, child)
574+
case r @ Repartition(_, _, child: RepartitionByExpression) =>
575+
collapseRepartition(r, child)
576+
// Case 2: When a RepartitionByExpression has a child of Repartition or RepartitionByExpression
577+
// we can remove the child.
578+
case r @ RepartitionByExpression(_, child: RepartitionByExpression, _) =>
579+
r.copy(child = child.child)
580+
case r @ RepartitionByExpression(_, child: Repartition, _) =>
581+
r.copy(child = child.child)
582+
}
583+
584+
/**
585+
* Collapses the [[Repartition]] with its child [[RepartitionOperation]], if possible.
586+
* - Case 1 the top [[Repartition]] does not enable shuffle (i.e., coalesce API):
587+
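* If the child enables shuffle: returns the child node when the top numPartitions is greater than or equal to the child's; otherwise, keeps the plan unchanged. If the child does not enable shuffle either, collapses both into a single node with the top numPartitions.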
588+
* - Case 2 the top [[Repartition]] enables shuffle (i.e., repartition API):
589+
* returns the child node with the top node's numPartitions (and, for a [[Repartition]] child, the top node's shuffle flag).
590+
*/
591+
private def collapseRepartition(r: Repartition, child: RepartitionOperation): LogicalPlan = {
592+
(r.shuffle, child.shuffle) match {
593+
case (false, true) => child match {
594+
case c: Repartition => if (r.numPartitions >= c.numPartitions) c else r
595+
case c: RepartitionByExpression => if (r.numPartitions >= c.numPartitions) c else r
586596
}
587-
// Case 2
588-
case RepartitionByExpression(exprs, RepartitionByExpression(_, grandChild, _), numPartitions) =>
589-
RepartitionByExpression(exprs, grandChild, numPartitions)
590-
// Case 3
591-
case Repartition(numPartitions, _, r: RepartitionByExpression) =>
592-
r.copy(numPartitions = numPartitions)
593-
// Case 3
594-
case r @ Repartition(numPartitions, shuffle, child: RepartitionByExpression) =>
595-
if (shuffle) {
596-
child.copy(numPartitions = numPartitions)
597-
} else if (numPartitions >= child.numPartitions) {
598-
r
599-
} else {
600-
r
597+
case _ => child match {
598+
case child: Repartition => child.copy(numPartitions = r.numPartitions, shuffle = r.shuffle)
599+
case child: RepartitionByExpression => child.copy(numPartitions = r.numPartitions)
601600
}
602-
// Case 4
603-
case RepartitionByExpression(exprs, Repartition(_, _, grandChild), numPartitions) =>
604-
RepartitionByExpression(exprs, grandChild, numPartitions)
601+
}
605602
}
606603
}
607604

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -835,16 +835,23 @@ case class Distinct(child: LogicalPlan) extends UnaryNode {
835835
override def output: Seq[Attribute] = child.output
836836
}
837837

838+
/**
839+
* A base interface for [[RepartitionByExpression]] and [[Repartition]]
840+
*/
841+
abstract class RepartitionOperation(numPartitions: Int) extends UnaryNode {
842+
def shuffle: Boolean
843+
override def output: Seq[Attribute] = child.output
844+
}
845+
838846
/**
839847
* Returns a new RDD that has exactly `numPartitions` partitions. Differs from
840848
* [[RepartitionByExpression]] as this method is called directly by DataFrame's, because the user
841849
* asked for `coalesce` or `repartition`. [[RepartitionByExpression]] is used when the consumer
842850
* of the output requires some specific ordering or distribution of the data.
843851
*/
844852
case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
845-
extends UnaryNode {
853+
extends RepartitionOperation(numPartitions) {
846854
require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")
847-
override def output: Seq[Attribute] = child.output
848855
}
849856

850857
/**
@@ -856,14 +863,12 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
856863
case class RepartitionByExpression(
857864
partitionExpressions: Seq[Expression],
858865
child: LogicalPlan,
859-
numPartitions: Int) extends UnaryNode {
866+
numPartitions: Int) extends RepartitionOperation(numPartitions) {
860867

861868
require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")
862869

863-
override lazy val resolved: Boolean = super.resolved && numPartitions.nonEmpty
864-
865870
override def maxRows: Option[Long] = child.maxRows
866-
override def output: Seq[Attribute] = child.output
871+
override def shuffle: Boolean = true
867872
}
868873

869874
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala

Lines changed: 92 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -34,61 +34,82 @@ class CollapseRepartitionSuite extends PlanTest {
3434

3535

3636
test("collapse two adjacent coalesces into one") {
37-
val query = testRelation
37+
// Always respects the top coalesce and removes the useless coalesce below it
38+
val query1 = testRelation
3839
.coalesce(10)
3940
.coalesce(20)
41+
val query2 = testRelation
42+
.coalesce(30)
43+
.coalesce(20)
44+
45+
val optimized1 = Optimize.execute(query1.analyze)
46+
val optimized2 = Optimize.execute(query2.analyze)
4047

41-
val optimized = Optimize.execute(query.analyze)
4248
val correctAnswer = testRelation.coalesce(20).analyze
4349

44-
comparePlans(optimized, correctAnswer)
50+
comparePlans(optimized1, correctAnswer)
51+
comparePlans(optimized2, correctAnswer)
4552
}
4653

4754
test("collapse two adjacent repartitions into one") {
48-
val query = testRelation
55+
// Always respects the top repartition and removes the useless repartition below it
56+
val query1 = testRelation
4957
.repartition(10)
5058
.repartition(20)
59+
val query2 = testRelation
60+
.repartition(30)
61+
.repartition(20)
5162

52-
val optimized = Optimize.execute(query.analyze)
63+
val optimized1 = Optimize.execute(query1.analyze)
64+
val optimized2 = Optimize.execute(query2.analyze)
5365
val correctAnswer = testRelation.repartition(20).analyze
5466

55-
comparePlans(optimized, correctAnswer)
67+
comparePlans(optimized1, correctAnswer)
68+
comparePlans(optimized2, correctAnswer)
5669
}
5770

58-
test("collapse one coalesce and one repartition into one") {
59-
// Remove useless coalesce below repartition
71+
test("coalesce above repartition") {
72+
// Remove useless coalesce above repartition
6073
val query1 = testRelation
74+
.repartition(10)
6175
.coalesce(20)
62-
.repartition(5)
6376

6477
val optimized1 = Optimize.execute(query1.analyze)
65-
val correctAnswer1 = testRelation.repartition(5).analyze
78+
val correctAnswer1 = testRelation.repartition(10).analyze
6679

6780
comparePlans(optimized1, correctAnswer1)
6881

69-
// Remove useless coalesce above repartition when its numPartitions is larger than or equal to
70-
// the child's numPartitions
82+
// No change when the coalesce above has fewer partitions than the repartition below
7183
val query2 = testRelation
72-
.repartition(5)
84+
.repartition(30)
7385
.coalesce(20)
7486

7587
val optimized2 = Optimize.execute(query2.analyze)
76-
val correctAnswer2 = testRelation.repartition(5).analyze
88+
val correctAnswer2 = query2.analyze
7789

7890
comparePlans(optimized2, correctAnswer2)
91+
}
92+
93+
test("repartition above coalesce") {
94+
// Always respects the top repartition and removes the useless coalesce below it
95+
val query1 = testRelation
96+
.coalesce(10)
97+
.repartition(20)
98+
// Removes the useless coalesce below repartition even when its numPartitions is larger
99+
val query2 = testRelation
100+
.coalesce(30)
101+
.repartition(20)
79102

80-
// Keep coalesce above repartition unchanged when its numPartitions is smaller than the child
81-
val query3 = testRelation
82-
.repartition(5)
83-
.coalesce(3)
103+
val optimized1 = Optimize.execute(query1.analyze)
104+
val optimized2 = Optimize.execute(query2.analyze)
84105

85-
val optimized3 = Optimize.execute(query3.analyze)
86-
val correctAnswer3 = testRelation.repartition(5).coalesce(3).analyze
106+
val correctAnswer = testRelation.repartition(20).analyze
87107

88-
comparePlans(optimized3, correctAnswer3)
108+
comparePlans(optimized1, correctAnswer)
109+
comparePlans(optimized2, correctAnswer)
89110
}
90111

91-
test("collapse repartition and repartitionBy into one") {
112+
test("repartitionBy above repartition") {
92113
val query1 = testRelation
93114
.repartition(10)
94115
.distribute('a)(20)
@@ -99,7 +120,7 @@ class CollapseRepartitionSuite extends PlanTest {
99120
comparePlans(optimized1, correctAnswer1)
100121

101122
val query2 = testRelation
102-
.coalesce(10)
123+
.repartition(30)
103124
.distribute('a)(20)
104125

105126
val optimized2 = Optimize.execute(query2.analyze)
@@ -108,7 +129,27 @@ class CollapseRepartitionSuite extends PlanTest {
108129
comparePlans(optimized2, correctAnswer2)
109130
}
110131

111-
test("collapse repartitionBy and repartition into one") {
132+
test("repartitionBy above coalesce") {
133+
val query1 = testRelation
134+
.coalesce(10)
135+
.distribute('a)(20)
136+
137+
val optimized1 = Optimize.execute(query1.analyze)
138+
val correctAnswer1 = testRelation.distribute('a)(20).analyze
139+
140+
comparePlans(optimized1, correctAnswer1)
141+
142+
val query2 = testRelation
143+
.coalesce(20)
144+
.distribute('a)(30)
145+
146+
val optimized2 = Optimize.execute(query2.analyze)
147+
val correctAnswer2 = testRelation.distribute('a)(30).analyze
148+
149+
comparePlans(optimized2, correctAnswer2)
150+
}
151+
152+
test("repartition above repartitionBy") {
112153
val query1 = testRelation
113154
.distribute('a)(20)
114155
.repartition(10)
@@ -123,30 +164,48 @@ class CollapseRepartitionSuite extends PlanTest {
123164
.repartition(30)
124165

125166
val optimized2 = Optimize.execute(query2.analyze)
126-
val correctAnswer2 = testRelation.distribute('a)(20).analyze
167+
val correctAnswer2 = testRelation.distribute('a)(30).analyze
127168

128169
comparePlans(optimized2, correctAnswer2)
129170
}
130171

131172
test("coalesce above repartitionBy") {
132-
val query = testRelation
173+
val query1 = testRelation
133174
.distribute('a)(20)
134175
.coalesce(10)
135176

136-
val optimized = Optimize.execute(query.analyze)
137-
val correctAnswer = testRelation.distribute('a)(20).coalesce(10).analyze
177+
val optimized1 = Optimize.execute(query1.analyze)
178+
val correctAnswer1 = testRelation.distribute('a)(20).coalesce(10).analyze
179+
180+
comparePlans(optimized1, correctAnswer1)
181+
182+
val query2 = testRelation
183+
.distribute('a)(20)
184+
.coalesce(30)
138185

139-
comparePlans(optimized, correctAnswer)
186+
val optimized2 = Optimize.execute(query2.analyze)
187+
val correctAnswer2 = testRelation.distribute('a)(20).analyze
188+
189+
comparePlans(optimized2, correctAnswer2)
140190
}
141191

142192
test("collapse two adjacent repartitionBys into one") {
143-
val query = testRelation
193+
val query1 = testRelation
144194
.distribute('b)(10)
145195
.distribute('a)(20)
146196

147-
val optimized = Optimize.execute(query.analyze)
148-
val correctAnswer = testRelation.distribute('a)(20).analyze
197+
val optimized1 = Optimize.execute(query1.analyze)
198+
val correctAnswer1 = testRelation.distribute('a)(20).analyze
199+
200+
comparePlans(optimized1, correctAnswer1)
201+
202+
val query2 = testRelation
203+
.distribute('b)(30)
204+
.distribute('a)(20)
205+
206+
val optimized2 = Optimize.execute(query2.analyze)
207+
val correctAnswer2 = testRelation.distribute('a)(20).analyze
149208

150-
comparePlans(optimized, correctAnswer)
209+
comparePlans(optimized2, correctAnswer2)
151210
}
152211
}

0 commit comments

Comments
 (0)