|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql.execution.joins |
19 | 19 |
|
| 20 | +import scala.collection.mutable |
| 21 | + |
20 | 22 | import org.apache.spark.TaskContext |
21 | 23 | import org.apache.spark.broadcast.Broadcast |
22 | 24 | import org.apache.spark.rdd.RDD |
@@ -61,25 +63,91 @@ case class BroadcastHashJoinExec( |
61 | 63 | } |
62 | 64 |
|
/**
 * Output partitioning of this broadcast hash join.
 *
 * For inner-like joins, when the streamed side is hash partitioned on expressions that map
 * one-to-one onto the build-side join keys, the build side is effectively co-partitioned
 * with the streamed side. We expose that as an additional [[HashPartitioning]] (wrapped in
 * a [[PartitioningCollection]]) so downstream operators keyed on the build-side expressions
 * can avoid a shuffle. For all other join types the streamed side's partitioning is
 * returned unchanged.
 */
override def outputPartitioning: Partitioning = {
  // Resolve which join keys belong to the build side vs. the streamed side.
  val (buildKeys, streamedKeys) = buildSide match {
    case BuildLeft => (leftKeys, rightKeys)
    case BuildRight => (rightKeys, leftKeys)
  }

  joinType match {
    case _: InnerLike =>
      streamedPlan.outputPartitioning match {
        case h: HashPartitioning =>
          // Pair the streamed hash partitioning with its build-side counterpart, if one
          // can be derived; otherwise keep just the streamed partitioning.
          getBuildSidePartitioning(h, streamedKeys, buildKeys)
            .map(p => PartitioningCollection(Seq(h, p)))
            .getOrElse(h)
        case c: PartitioningCollection =>
          // Append a build-side partitioning derived from the first HashPartitioning in
          // the collection that maps cleanly onto the build keys. Using collectFirst
          // avoids a nonlocal `return` from inside a lambda (which is implemented by
          // throwing NonLocalReturnControl and is deprecated in Scala 3).
          val maybeBuildPartitioning = c.partitionings.iterator
            .collect { case h: HashPartitioning =>
              getBuildSidePartitioning(h, streamedKeys, buildKeys)
            }
            .collectFirst { case Some(p) => p }
          maybeBuildPartitioning
            .map(p => PartitioningCollection(c.partitionings :+ p))
            .getOrElse(c)
        case other => other
      }
    case _ => streamedPlan.outputPartitioning
  }
}
82 | 94 |
|
/**
 * Returns a partitioning for the build side if both of the following conditions hold:
 *  - The streamed side's output partitioning expressions consist of exactly the streamed
 *    keys (same count, each expression semantically matched by a key), so a build-side
 *    partitioning can be added.
 *  - There is a one-to-one mapping from streamed keys to build keys.
 *
 * The build side partitioning has its expressions in the same order as the expressions
 * in the streamed side partitioning. For example, for the following setup:
 *  - streamed partitioning expressions: Seq(s1, s2)
 *  - streamed keys: Seq(c1, c2)
 *  - build keys: Seq(b1, b2)
 * the expressions in the build side partitioning will be Seq(b1, b2), not Seq(b2, b1).
 *
 * @param streamedPartitioning the hash partitioning reported by the streamed side
 * @param streamedKeys join keys of the streamed side
 * @param buildKeys join keys of the build side (positionally aligned with streamedKeys)
 * @return the build-side [[HashPartitioning]], or None when the conditions are not met
 */
private def getBuildSidePartitioning(
    streamedPartitioning: HashPartitioning,
    streamedKeys: Seq[Expression],
    buildKeys: Seq[Expression]): Option[HashPartitioning] = {
  if (!satisfiesPartitioning(streamedKeys, streamedPartitioning)) {
    None
  } else {
    // Map each streamed key (canonicalized, so semantically equal keys collide) to its
    // positionally corresponding build key. `forall` replaces the nonlocal `return`
    // (throws NonLocalReturnControl; deprecated in Scala 3): it stops at the first
    // streamed key that maps to two semantically different build keys.
    val streamedKeyToBuildKeyMap = mutable.Map.empty[Expression, Expression]
    val noOneToMany = streamedKeys.zip(buildKeys).forall { case (streamedKey, buildKey) =>
      val existing =
        streamedKeyToBuildKeyMap.getOrElseUpdate(streamedKey.canonicalized, buildKey)
      existing.semanticEquals(buildKey)
    }

    // Ensure the mapping is also injective (no two streamed keys share a build key),
    // i.e. a true one-to-one mapping from streamed keys to build keys.
    val oneToOne = noOneToMany &&
      streamedKeyToBuildKeyMap.size == streamedKeyToBuildKeyMap.values.toSet.size

    if (!oneToOne) {
      None
    } else {
      // The final expressions are built by mapping streamed partitioning expressions ->
      // streamed keys -> build keys, preserving the streamed partitioning's order.
      val buildPartitioningExpressions = streamedPartitioning.expressions.map { e =>
        streamedKeyToBuildKeyMap(e.canonicalized)
      }
      Some(HashPartitioning(buildPartitioningExpressions, streamedPartitioning.numPartitions))
    }
  }
}
| 142 | + |
/**
 * Returns true when `keys` consist of all the expressions in `partitioning`: the two have
 * the same number of expressions and every partitioning expression is semantically equal
 * to some key.
 */
private def satisfiesPartitioning(
    keys: Seq[Expression],
    partitioning: HashPartitioning): Boolean = {
  val partitionExprs = partitioning.expressions
  if (partitionExprs.length != keys.length) {
    false
  } else {
    partitionExprs.forall(expr => keys.exists(key => key.semanticEquals(expr)))
  }
}
| 150 | + |
83 | 151 | protected override def doExecute(): RDD[InternalRow] = { |
84 | 152 | val numOutputRows = longMetric("numOutputRows") |
85 | 153 |
|
|
0 commit comments