Commit 488e051

Add more checks
1 parent dedce0c commit 488e051

File tree

2 files changed: +82 −14 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala

Lines changed: 75 additions & 7 deletions
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.joins
 
+import scala.collection.mutable
+
 import org.apache.spark.TaskContext
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
@@ -61,25 +63,91 @@ case class BroadcastHashJoinExec(
   }
 
   override def outputPartitioning: Partitioning = {
-    def buildKeys: Seq[Expression] = buildSide match {
-      case BuildLeft => leftKeys
-      case BuildRight => rightKeys
+    val (buildKeys, streamedKeys) = buildSide match {
+      case BuildLeft => (leftKeys, rightKeys)
+      case BuildRight => (rightKeys, leftKeys)
     }
 
     joinType match {
       case _: InnerLike =>
         streamedPlan.outputPartitioning match {
           case h: HashPartitioning =>
-            PartitioningCollection(Seq(h, HashPartitioning(buildKeys, h.numPartitions)))
-          case c: PartitioningCollection
-              if c.partitionings.forall(_.isInstanceOf[HashPartitioning]) =>
-            PartitioningCollection(c.partitionings :+ HashPartitioning(buildKeys, c.numPartitions))
+            getBuildSidePartitioning(h, streamedKeys, buildKeys) match {
+              case Some(p) => PartitioningCollection(Seq(h, p))
+              case None => h
+            }
+          case c: PartitioningCollection =>
+            c.partitionings.foreach {
+              case h: HashPartitioning =>
+                getBuildSidePartitioning(h, streamedKeys, buildKeys) match {
+                  case Some(p) => return PartitioningCollection(c.partitionings :+ p)
+                  case None => ()
+                }
+              case _ => ()
+            }
+            c
           case other => other
         }
       case _ => streamedPlan.outputPartitioning
     }
   }
 
+  /**
+   * Returns a partitioning for the build side if the following conditions are met:
+   *   - The streamed side's output partitioning expressions consist of all the keys
+   *     from the streamed side.
+   *   - There is a one-to-one mapping from streamed keys to build keys.
+   *
+   * The build side partitioning will have expressions in the same order as the expressions
+   * in the streamed side partitioning. For example, for the following setup:
+   *   - streamed partitioning expressions: Seq(s1, s2)
+   *   - streamed keys: Seq(c1, c2)
+   *   - build keys: Seq(b1, b2)
+   * the expressions in the build side partitioning will be Seq(b1, b2), not Seq(b2, b1).
+   */
+  private def getBuildSidePartitioning(
+      streamedPartitioning: HashPartitioning,
+      streamedKeys: Seq[Expression],
+      buildKeys: Seq[Expression]): Option[HashPartitioning] = {
+    if (!satisfiesPartitioning(streamedKeys, streamedPartitioning)) {
+      return None
+    }
+
+    val streamedKeyToBuildKeyMap = mutable.Map.empty[Expression, Expression]
+    streamedKeys.zip(buildKeys).foreach {
+      case (streamedKey, buildKey) =>
+        val inserted = streamedKeyToBuildKeyMap.getOrElseUpdate(
+          streamedKey.canonicalized,
+          buildKey)
+
+        if (!inserted.semanticEquals(buildKey)) {
+          // One-to-many mapping from streamed keys to build keys found.
+          return None
+        }
+    }
+
+    // Ensure the one-to-one mapping from streamed keys to build keys.
+    if (streamedKeyToBuildKeyMap.size != streamedKeyToBuildKeyMap.values.toSet.size) {
+      return None
+    }
+
+    // The final expressions are built by mapping stream partitioning expressions ->
+    // streamed keys -> build keys.
+    val buildPartitioningExpressions = streamedPartitioning.expressions.map { e =>
+      streamedKeyToBuildKeyMap(e.canonicalized)
+    }
+
+    Some(HashPartitioning(buildPartitioningExpressions, streamedPartitioning.numPartitions))
+  }
+
+  // Returns true if `keys` consist of all the expressions in `partitioning`.
+  private def satisfiesPartitioning(
+      keys: Seq[Expression],
+      partitioning: HashPartitioning): Boolean = {
+    partitioning.expressions.length == keys.length &&
+      partitioning.expressions.forall(e => keys.exists(_.semanticEquals(e)))
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
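To make the one-to-one mapping check concrete, here is a minimal, standalone sketch of the same idea with Catalyst expressions modeled as plain strings. The object KeyMappingSketch, the function buildSideKeysFor, and the string keys are illustrative stand-ins, not part of the commit:

import scala.collection.mutable

object KeyMappingSketch {
  // Derives build-side partitioning keys from the streamed side's
  // partitioning expressions, mirroring the logic of getBuildSidePartitioning.
  def buildSideKeysFor(
      streamedPartitioningExprs: Seq[String],
      streamedKeys: Seq[String],
      buildKeys: Seq[String]): Option[Seq[String]] = {
    // The partitioning must consist of exactly the streamed keys
    // (the satisfiesPartitioning check).
    val coversAllKeys =
      streamedPartitioningExprs.length == streamedKeys.length &&
        streamedPartitioningExprs.forall(streamedKeys.contains)
    if (!coversAllKeys) return None

    // Reject one-to-many mappings: the same streamed key paired with
    // two different build keys.
    val mapping = mutable.Map.empty[String, String]
    val oneToMany = streamedKeys.zip(buildKeys).exists {
      case (s, b) => mapping.getOrElseUpdate(s, b) != b
    }
    if (oneToMany) return None

    // Reject many-to-one mappings: two streamed keys sharing a build key.
    if (mapping.size != mapping.values.toSet.size) return None

    // Reorder the build keys to follow the streamed partitioning order.
    Some(streamedPartitioningExprs.map(mapping))
  }

  def main(args: Array[String]): Unit = {
    // Keys may appear in the partitioning in a different order than in the
    // join condition; the result follows the partitioning order.
    assert(buildSideKeysFor(Seq("j1", "i1"), Seq("i1", "j1"), Seq("i2", "j2"))
      .contains(Seq("j2", "i2")))
    // A condition like t1.i1 = t2.i2 AND t1.i1 = t2.j2 is one-to-many,
    // so no build-side partitioning is derived.
    assert(buildSideKeysFor(Seq("i1", "i1"), Seq("i1", "i1"), Seq("i2", "j2")).isEmpty)
  }
}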

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala

Lines changed: 7 additions & 7 deletions
@@ -469,14 +469,14 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
     val df1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1")
     val df2 = (0 until 20).map(i => (i % 7, i % 11)).toDF("i2", "j2")
     val df3 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i3", "j3")
-    df1.write.format("parquet").bucketBy(8, "i1").saveAsTable("t1")
-    df3.write.format("parquet").bucketBy(8, "i3").saveAsTable("t3")
+    df1.write.format("parquet").bucketBy(8, "i1", "j1").saveAsTable("t1")
+    df3.write.format("parquet").bucketBy(8, "i3", "j3").saveAsTable("t3")
     val t1 = spark.table("t1")
     val t3 = spark.table("t3")
 
     // join1 is a broadcast join where df2 is broadcasted. Note that output partitioning on the
     // streamed side (t1) is HashPartitioning (bucketed files).
-    val join1 = t1.join(df2, t1("i1") === df2("i2"))
+    val join1 = t1.join(df2, t1("i1") === df2("i2") && t1("j1") === df2("j2"))
     val plan1 = join1.queryExecution.executedPlan
     assert(collect(plan1) { case e: ShuffleExchangeExec => e }.isEmpty)
     val broadcastJoins = collect(plan1) { case b: BroadcastHashJoinExec => b }
@@ -489,17 +489,17 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
       case _ => fail()
     }
 
-    // Join on the column from the broadcasted side (i2) and make sure output partitioning
+    // Join on the columns from the broadcasted side (i2, j2) and make sure output partitioning
     // is maintained by checking no shuffle exchange is introduced.
-    val join2 = join1.join(t3, join1("i2") === t3("i3"))
+    val join2 = join1.join(t3, join1("i2") === t3("i3") && join1("j2") === t3("j3"))
     val plan2 = join2.queryExecution.executedPlan
     assert(collect(plan2) { case s: SortMergeJoinExec => s }.size == 1)
     assert(collect(plan2) { case b: BroadcastHashJoinExec => b }.size == 1)
     assert(collect(plan2) { case e: ShuffleExchangeExec => e }.isEmpty)
 
-    // Validate the data with boradcast join off.
+    // Validate the data with broadcast join off.
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
-      val df = join1.join(t3, join1("i2") === t3("i3"))
+      val df = join1.join(t3, join1("i2") === t3("i3") && join1("j2") === t3("j3"))
       QueryTest.sameRows(join2.collect().toSeq, df.collect().toSeq)
     }
   }
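The behavior this test asserts can also be observed interactively. The following is a hypothetical spark-shell session, not part of the commit; it mirrors the test's setup (table names t1/t3, bucketing layout) and assumes a Spark build containing this change with a catalog that supports bucketed tables:

import org.apache.spark.sql.functions.broadcast
import spark.implicits._

val df1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1")
val df2 = (0 until 20).map(i => (i % 7, i % 11)).toDF("i2", "j2")
val df3 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i3", "j3")
df1.write.format("parquet").bucketBy(8, "i1", "j1").saveAsTable("t1")
df3.write.format("parquet").bucketBy(8, "i3", "j3").saveAsTable("t3")

val t1 = spark.table("t1")
val t3 = spark.table("t3")

// Broadcast join: the streamed side t1 keeps its bucketed HashPartitioning.
val join1 = t1.join(broadcast(df2), t1("i1") === df2("i2") && t1("j1") === df2("j2"))

// Join downstream on the broadcast side's columns (i2, j2). With this change,
// the derived build-side partitioning lets the sort-merge join appear without
// a ShuffleExchange on join1's side; without it, this join would repartition.
val join2 = join1.join(t3, join1("i2") === t3("i3") && join1("j2") === t3("j3"))
join2.explain()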
