Skip to content

Commit 10b4d5a

Browse files
committed
Address PR comments
1 parent e2f7e44 commit 10b4d5a

File tree

2 files changed

+27
-42
lines changed

2 files changed

+27
-42
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -211,12 +211,9 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
211211
.orElse(reorderJoinKeysRecursively(
212212
leftKeys, rightKeys, leftPartitioning, None))
213213
case (Some(PartitioningCollection(partitionings)), _) =>
214-
partitionings.foreach { p =>
215-
reorderJoinKeysRecursively(leftKeys, rightKeys, Some(p), rightPartitioning).map { k =>
216-
return Some(k)
217-
}
218-
}
219-
reorderJoinKeysRecursively(leftKeys, rightKeys, None, rightPartitioning)
214+
partitionings.foldLeft(Option.empty[(Seq[Expression], Seq[Expression])]) { (res, p) =>
215+
res.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, Some(p), rightPartitioning))
216+
}.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, None, rightPartitioning))
220217
case (_, Some(PartitioningCollection(partitionings))) =>
221218
partitionings.foreach { p =>
222219
reorderJoinKeysRecursively(leftKeys, rightKeys, leftPartitioning, Some(p)).map { k =>

sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala

Lines changed: 24 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,12 @@ class EnsureRequirementsSuite extends SharedSparkSession {
4141
exprB :: exprA :: Nil, exprA :: exprB :: Nil, Inner, None, plan1, plan2)
4242
EnsureRequirements(spark.sessionState.conf).apply(smjExec1) match {
4343
case SortMergeJoinExec(leftKeys, rightKeys, _, _,
44-
SortExec(_, _,
45-
DummySparkPlan(_, _, PartitioningCollection(leftPartitionings), _, _), _),
46-
SortExec(_, _,
47-
ShuffleExchangeExec(HashPartitioning(rightPartitioningExpressions, _), _, _), _), _) =>
44+
SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _),
45+
SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), _) =>
4846
assert(leftKeys !== smjExec1.leftKeys)
4947
assert(rightKeys !== smjExec1.rightKeys)
50-
assert(leftKeys === leftPartitionings.head.asInstanceOf[HashPartitioning].expressions)
51-
assert(rightKeys === rightPartitioningExpressions)
48+
assert(leftKeys === Seq(exprA, exprB))
49+
assert(rightKeys === Seq(exprB, exprA))
5250
case other => fail(other.toString)
5351
}
5452

@@ -57,14 +55,12 @@ class EnsureRequirementsSuite extends SharedSparkSession {
5755
exprA :: exprB :: Nil, exprB :: exprA :: Nil, Inner, None, plan2, plan1)
5856
EnsureRequirements(spark.sessionState.conf).apply(smjExec2) match {
5957
case SortMergeJoinExec(leftKeys, rightKeys, _, _,
60-
SortExec(_, _,
61-
ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _),
62-
SortExec(_, _,
63-
DummySparkPlan(_, _, PartitioningCollection(rightPartitionings), _, _), _), _) =>
58+
SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _),
59+
SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) =>
6460
assert(leftKeys !== smjExec2.leftKeys)
6561
assert(rightKeys !== smjExec2.rightKeys)
66-
assert(leftKeys === leftPartitioningExpressions)
67-
assert(rightKeys === rightPartitionings.head.asInstanceOf[HashPartitioning].expressions)
62+
assert(leftKeys === Seq(exprB, exprA))
63+
assert(rightKeys === Seq(exprA, exprB))
6864
case other => fail(other.toString)
6965
}
7066

@@ -74,14 +70,12 @@ class EnsureRequirementsSuite extends SharedSparkSession {
7470
exprA :: exprC :: Nil, exprB :: exprA :: Nil, Inner, None, plan1, plan1)
7571
EnsureRequirements(spark.sessionState.conf).apply(smjExec3) match {
7672
case SortMergeJoinExec(leftKeys, rightKeys, _, _,
77-
SortExec(_, _,
78-
ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _),
79-
SortExec(_, _,
80-
DummySparkPlan(_, _, PartitioningCollection(rightPartitionings), _, _), _), _) =>
73+
SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _),
74+
SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) =>
8175
assert(leftKeys !== smjExec3.leftKeys)
8276
assert(rightKeys !== smjExec3.rightKeys)
83-
assert(leftKeys === leftPartitioningExpressions)
84-
assert(rightKeys === rightPartitionings.head.asInstanceOf[HashPartitioning].expressions)
77+
assert(leftKeys === Seq(exprC, exprA))
78+
assert(rightKeys === Seq(exprA, exprB))
8579
case other => fail(other.toString)
8680
}
8781
}
@@ -97,14 +91,12 @@ class EnsureRequirementsSuite extends SharedSparkSession {
9791
exprA :: exprB :: Nil, exprC :: exprB :: Nil, Inner, None, plan1, plan2)
9892
EnsureRequirements(spark.sessionState.conf).apply(smjExec1) match {
9993
case SortMergeJoinExec(leftKeys, rightKeys, _, _,
100-
SortExec(_, _,
101-
ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _),
102-
SortExec(_, _,
103-
DummySparkPlan(_, _, HashPartitioning(rightPartitioningExpressions, _), _, _), _), _) =>
94+
SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _),
95+
SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), _) =>
10496
assert(leftKeys !== smjExec1.leftKeys)
10597
assert(rightKeys !== smjExec1.rightKeys)
106-
assert(leftKeys === leftPartitioningExpressions)
107-
assert(rightKeys === rightPartitioningExpressions)
98+
assert(leftKeys === Seq(exprB, exprA))
99+
assert(rightKeys === Seq(exprB, exprC))
108100
case other => fail(other.toString)
109101
}
110102

@@ -115,14 +107,12 @@ class EnsureRequirementsSuite extends SharedSparkSession {
115107
exprA :: exprB :: Nil, exprC :: exprB :: Nil, Inner, None, plan1, plan3)
116108
EnsureRequirements(spark.sessionState.conf).apply(smjExec2) match {
117109
case SortMergeJoinExec(leftKeys, rightKeys, _, _,
118-
SortExec(_, _,
119-
ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _),
120-
SortExec(_, _,
121-
DummySparkPlan(_, _, PartitioningCollection(rightPartitionings), _, _), _), _) =>
110+
SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _),
111+
SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) =>
122112
assert(leftKeys !== smjExec2.leftKeys)
123113
assert(rightKeys !== smjExec2.rightKeys)
124-
assert(leftKeys === leftPartitioningExpressions)
125-
assert(rightKeys === rightPartitionings.head.asInstanceOf[HashPartitioning].expressions)
114+
assert(leftKeys === Seq(exprB, exprA))
115+
assert(rightKeys === Seq(exprB, exprC))
126116
case other => fail(other.toString)
127117
}
128118

@@ -132,14 +122,12 @@ class EnsureRequirementsSuite extends SharedSparkSession {
132122
exprC :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None, plan3, plan1)
133123
EnsureRequirements(spark.sessionState.conf).apply(smjExec3) match {
134124
case SortMergeJoinExec(leftKeys, rightKeys, _, _,
135-
SortExec(_, _,
136-
DummySparkPlan(_, _, PartitioningCollection(leftPartitionings), _, _), _),
137-
SortExec(_, _,
138-
ShuffleExchangeExec(HashPartitioning(rightPartitioningExpressions, _), _, _), _), _) =>
125+
SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _),
126+
SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), _) =>
139127
assert(leftKeys !== smjExec3.leftKeys)
140128
assert(rightKeys !== smjExec3.rightKeys)
141-
assert(leftKeys === leftPartitionings.head.asInstanceOf[HashPartitioning].expressions)
142-
assert(rightKeys === rightPartitioningExpressions)
129+
assert(leftKeys === Seq(exprB, exprC))
130+
assert(rightKeys === Seq(exprB, exprA))
143131
case other => fail(other.toString)
144132
}
145133
}

0 commit comments

Comments
 (0)