
Commit 9420d0e

restore original test
1 parent 54c2fa5 commit 9420d0e

5 files changed: +42 −127 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 0 additions & 9 deletions
@@ -340,12 +340,3 @@ case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning {
     case _ => false
   }
 }
-
-/**
- * A test-only partitioning that just output the "given key / base" as partition id.
- */
-case class PassThroughPartitioning(key: Attribute, base: Int, numPartitions: Int)
-  extends Partitioning {
-  assert(key.dataType == IntegerType)
-  override def satisfies0(required: Distribution): Boolean = true
-}
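For context, the removed helper mapped each row's integer key to a partition id by integer division. A minimal sketch of that mapping (plain Scala, no Spark dependency; names here are illustrative, not Spark's):

object PassThroughSketch {
  // With base = 100, keys 0-99 land in partition 0, 100-199 in partition 1, etc.
  def getPartition(key: Int, base: Int): Int = key / base

  def main(args: Array[String]): Unit = {
    val base = 100
    Seq(5, 150, 299, 301).foreach { k =>
      println(s"key=$k -> partition ${getPartition(k, base)}")
    }
    // key=5 -> partition 0, key=150 -> partition 1,
    // key=299 -> partition 2, key=301 -> partition 3
  }
}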

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala

Lines changed: 3 additions & 2 deletions
@@ -248,12 +248,13 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
       val startIndices = ShufflePartitionsCoalescer.coalescePartitions(
         Array(leftStats, rightStats),
         firstPartitionIndex = nonSkewPartitionIndices.head,
-        lastPartitionIndex = nonSkewPartitionIndices.last,
+        // `lastPartitionIndex` is exclusive.
+        lastPartitionIndex = nonSkewPartitionIndices.last + 1,
         advisoryTargetSize = conf.targetPostShuffleInputSize)
       startIndices.indices.map { i =>
         val startIndex = startIndices(i)
         val endIndex = if (i == startIndices.length - 1) {
-          // the `endIndex` is exclusive.
+          // `endIndex` is exclusive.
           nonSkewPartitionIndices.last + 1
         } else {
           startIndices(i + 1)
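The fix above is a classic off-by-one: coalescePartitions treats lastPartitionIndex as exclusive, so passing nonSkewPartitionIndices.last unchanged silently dropped the last non-skewed partition. A small illustration (plain Scala; coveredIndices is a stand-in for the coalescer's range handling, not the real API):

object ExclusiveEndSketch {
  // The coalescer walks the half-open range [first, last): `last` itself is skipped.
  def coveredIndices(first: Int, last: Int): Seq[Int] = first until last

  def main(args: Array[String]): Unit = {
    val nonSkewPartitionIndices = Seq(1, 2, 3)
    // Before the fix: partition 3 is lost.
    println(coveredIndices(nonSkewPartitionIndices.head,
      nonSkewPartitionIndices.last).mkString(", "))     // 1, 2
    // After the fix: all three non-skewed partitions are covered.
    println(coveredIndices(nonSkewPartitionIndices.head,
      nonSkewPartitionIndices.last + 1).mkString(", ")) // 1, 2, 3
  }
}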

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsCoalescer.scala

Lines changed: 4 additions & 3 deletions
@@ -25,9 +25,10 @@ import org.apache.spark.internal.Logging
 object ShufflePartitionsCoalescer extends Logging {

   /**
-   * Coalesce the same range of partitions (firstPartitionIndex to lastPartitionIndex, inclusive)
-   * from multiple shuffles. This method assumes that all the shuffles have the same number of
-   * partitions, and the partitions of same index will be read together by one task.
+   * Coalesce the same range of partitions (`firstPartitionIndex` to `lastPartitionIndex`, the
+   * start is inclusive and the end is exclusive) from multiple shuffles. This method assumes that
+   * all the shuffles have the same number of partitions, and the partitions of the same index
+   * will be read together by one task.
    *
    * The strategy used to determine the number of coalesced partitions is described as follows.
    * To determine the number of coalesced partitions, we have a target size for a coalesced
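To make the documented strategy concrete, here is a hedged sketch of size-based coalescing (plain Scala; the real implementation in ShufflePartitionsCoalescer differs in detail): walk the indices of the half-open range in order, sum the sizes of same-index partitions across all shuffles, and start a new coalesced partition whenever the running total would exceed the advisory target size.

object CoalesceSketch {
  def coalescePartitions(
      sizes: Array[Array[Long]],   // one partition-size array per shuffle
      firstPartitionIndex: Int,
      lastPartitionIndex: Int,     // exclusive, matching the fixed contract
      advisoryTargetSize: Long): Array[Int] = {
    val startIndices = scala.collection.mutable.ArrayBuffer(firstPartitionIndex)
    var currentSize = 0L
    for (i <- firstPartitionIndex until lastPartitionIndex) {
      // Same-index partitions are read together by one task, so their sizes add up.
      val totalSizeOfIndex = sizes.map(_(i)).sum
      if (currentSize + totalSizeOfIndex > advisoryTargetSize && currentSize > 0) {
        startIndices += i          // close the current group and open a new one
        currentSize = 0L
      }
      currentSize += totalSizeOfIndex
    }
    startIndices.toArray
  }

  def main(args: Array[String]): Unit = {
    val left = Array[Long](10, 5, 5, 40)
    val right = Array[Long](10, 5, 5, 0)
    // Target 30: indices 0 and 1 fit in one group (combined size 30),
    // index 2 starts a new group, and the large index 3 gets its own.
    println(coalescePartitions(Array(left, right), 0, 4, 30).mkString(", ")) // 0, 2, 3
  }
}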

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala

Lines changed: 0 additions & 9 deletions
@@ -216,11 +216,6 @@ object ShuffleExchangeExec {
         override def numPartitions: Int = 1
         override def getPartition(key: Any): Int = 0
       }
-      case PassThroughPartitioning(_, _, n) =>
-        new Partitioner {
-          override def numPartitions: Int = n
-          override def getPartition(key: Any): Int = key.asInstanceOf[Int]
-        }
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
       // TODO: Handle BroadcastPartitioning.
     }
@@ -240,10 +235,6 @@ object ShuffleExchangeExec {
         val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
         row => projection(row)
       case SinglePartition => identity
-      case p: PassThroughPartitioning =>
-        val projection = UnsafeProjection.create(
-          Divide(p.key, Literal(p.base)) :: Nil, outputAttributes)
-        row => projection(row).getInt(0)
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
     }
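The two removed pieces worked as a pair: the key extractor computed key / base for each row, and the Partitioner then used that value verbatim as the partition id. A hedged sketch of the same wiring against Spark's public org.apache.spark.Partitioner API (the class names here are illustrative):

import org.apache.spark.Partitioner

// The extractor has already turned the row's key into a partition id,
// so the partitioner just passes it through.
class PassThroughPartitioner(n: Int) extends Partitioner {
  override def numPartitions: Int = n
  override def getPartition(key: Any): Int = key.asInstanceOf[Int]
}

object PassThroughPartitionerDemo {
  def main(args: Array[String]): Unit = {
    val partitioner = new PassThroughPartitioner(7)
    println(partitioner.getPartition(3)) // 3: the key is used as the partition id
  }
}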

sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala

Lines changed: 35 additions & 104 deletions
@@ -20,24 +20,14 @@ package org.apache.spark.sql.execution.adaptive
 import java.io.File
 import java.net.URI

-import scala.util.Random
-
-import org.apache.spark.TaskContext
-import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
-import org.apache.spark.sql.{DataFrame, Dataset, QueryTest, Strategy}
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, UnsafeProjection}
-import org.apache.spark.sql.catalyst.planning.ScanOperation
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
-import org.apache.spark.sql.catalyst.plans.physical.PassThroughPartitioning
-import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, ReusedSubqueryExec, SparkPlan}
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.execution.{ReusedSubqueryExec, SparkPlan}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildRight, SortMergeJoinExec}
 import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils

 class AdaptiveQueryExecSuite
@@ -613,23 +603,23 @@ class AdaptiveQueryExecSuite
     }
   }

+  // TODO: we need a way to customize data distribution after shuffle, to improve test coverage
+  // of this case.
   test("SPARK-29544: adaptive skew join with different join types") {
-    // Unfortunately, we can't remove the injected extension. The `SkewJoinTestStrategy` is
-    // harmless and only affects this test suite.
-    spark.extensions.injectPlannerStrategy(_ => SkewJoinTestStrategy)
-    def createRelation(partitionRowCount: Int*): DataFrame = {
-      val output = new StructType().add("key", "int").toAttributes
-      Dataset.ofRows(spark, SkewJoinTestSource(output, partitionRowCount))
-    }
-
     withSQLConf(
       SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
       SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_SIZE_THRESHOLD.key -> "100",
-      SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_FACTOR.key -> "2") {
-      withTempView("t1", "t2") {
-        createRelation(3100, 100, 3200, 300, 3300, 400, 500).createTempView("t1")
-        createRelation(3400, 200, 300, 2900, 3200, 100, 600).createTempView("t2")
+      SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key -> "700") {
+      withTempView("skewData1", "skewData2") {
+        spark
+          .range(0, 1000, 1, 10)
+          .selectExpr("id % 2 as key1", "id as value1")
+          .createOrReplaceTempView("skewData1")
+        spark
+          .range(0, 1000, 1, 10)
+          .selectExpr("id % 1 as key2", "id as value2")
+          .createOrReplaceTempView("skewData2")

         def checkSkewJoin(joins: Seq[SortMergeJoinExec], expectedNumPartitions: Int): Unit = {
           assert(joins.size == 1 && joins.head.isSkewJoin)
@@ -643,55 +633,45 @@

         // skewed inner join optimization
         val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
-          "SELECT * FROM t1 join t2 ON t1.key = t2.key")
+          "SELECT * FROM skewData1 join skewData2 ON key1 = key2")
+        // left stats: [3496, 0, 0, 0, 4014]
+        // right stats: [6292, 0, 0, 0, 0]
         // Partition 0: both left and right sides are skewed, and divide into 5 splits, so
         // 5 x 5 sub-partitions.
-        // Partition 1: not skewed, just 1 partition.
-        // Partition 2: only left side is skewed, and divide into 5 splits, so
-        // 5 sub-partitions.
-        // Partition 3: only right side is skewed, and divide into 5 splits, so
+        // Partition 1, 2, 3: not skewed, and coalesced into 1 partition.
+        // Partition 4: only left side is skewed, and divide into 5 splits, so
         // 5 sub-partitions.
-        // Partition 4: both left and right sides are skewed, and divide into 5 splits, so
-        // 5 x 5 sub-partitions.
-        // Partition 5, 6: not skewed, and coalesced into 1 partition.
-        // So total (25 + 1 + 5 + 5 + 25 + 1) partitions.
+        // So total (25 + 1 + 5) partitions.
         val innerSmj = findTopLevelSortMergeJoin(innerAdaptivePlan)
-        checkSkewJoin(innerSmj, 25 + 1 + 5 + 5 + 25 + 1)
+        checkSkewJoin(innerSmj, 25 + 1 + 5)

         // skewed left outer join optimization
         val (_, leftAdaptivePlan) = runAdaptiveAndVerifyResult(
-          "SELECT * FROM t1 left outer join t2 ON t1.key = t2.key")
+          "SELECT * FROM skewData1 left outer join skewData2 ON key1 = key2")
+        // left stats: [3496, 0, 0, 0, 4014]
+        // right stats: [6292, 0, 0, 0, 0]
         // Partition 0: both left and right sides are skewed, but left join can't split right side,
         // so only left side is divided into 5 splits, and thus 5 sub-partitions.
-        // Partition 1: not skewed, just 1 partition.
-        // Partition 2: only left side is skewed, and divide into 5 splits, so
+        // Partition 1, 2, 3: not skewed, and coalesced into 1 partition.
+        // Partition 4: only left side is skewed, and divide into 5 splits, so
         // 5 sub-partitions.
-        // Partition 3: only right side is skewed, but left join can't split right side, so just
-        // 1 partition.
-        // Partition 4: both left and right sides are skewed, but left join can't split right side,
-        // so only left side is divided into 5 splits, and thus 5 sub-partitions.
-        // Partition 5, 6: not skewed, and coalesced into 1 partition.
-        // So total (5 + 1 + 5 + 1 + 5 + 1) partitions.
+        // So total (5 + 1 + 5) partitions.
         val leftSmj = findTopLevelSortMergeJoin(leftAdaptivePlan)
-        checkSkewJoin(leftSmj, 5 + 1 + 5 + 1 + 5 + 1)
+        checkSkewJoin(leftSmj, 5 + 1 + 5)

         // skewed right outer join optimization
         val (_, rightAdaptivePlan) = runAdaptiveAndVerifyResult(
-          "SELECT * FROM t1 right outer join t2 ON t1.key = t2.key")
+          "SELECT * FROM skewData1 right outer join skewData2 ON key1 = key2")
+        // left stats: [3496, 0, 0, 0, 4014]
+        // right stats: [6292, 0, 0, 0, 0]
         // Partition 0: both left and right sides are skewed, but right join can't split left side,
         // so only right side is divided into 5 splits, and thus 5 sub-partitions.
-        // Partition 1: not skewed, just 1 partition.
-        // Partition 2: only left side is skewed, but right join can't split left side, so just
+        // Partition 1, 2, 3: not skewed, and coalesced into 1 partition.
+        // Partition 4: only left side is skewed, but right join can't split left side, so just
         // 1 partition.
-        // Partition 1 and 2 get coalesced.
-        // Partition 3: only right side is skewed, and divide into 5 splits, so
-        // 5 sub-partitions.
-        // Partition 4: both left and right sides are skewed, but right join can't split left side,
-        // so only right side is divided into 5 splits, and thus 5 sub-partitions.
-        // Partition 5, 6: not skewed, and coalesced into 1 partition.
-        // So total (5 + 1 + 5 + 5 + 1) partitions.
+        // So total (5 + 1 + 1) partitions.
         val rightSmj = findTopLevelSortMergeJoin(rightAdaptivePlan)
-        checkSkewJoin(rightSmj, 5 + 1 + 5 + 5 + 1)
+        checkSkewJoin(rightSmj, 5 + 1 + 1)
       }
     }
   }
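As a quick cross-check of the updated counts, a tiny worked computation (plain Scala; the factor of 5 splits per skewed side is taken from the comments above):

object SkewJoinCountCheck {
  def main(args: Array[String]): Unit = {
    val splits = 5
    // Inner join: partition 0 splits on both sides (5 x 5), partitions 1-3
    // coalesce into 1, and partition 4 splits only on the left (5).
    println(splits * splits + 1 + splits) // 31
    // Left outer join: the right side can't be split, so partition 0 yields only 5.
    println(splits + 1 + splits)          // 11
    // Right outer join: the left side can't be split, so partition 4 stays at 1.
    println(splits + 1 + 1)               // 7
  }
}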
@@ -735,52 +715,3 @@
   }
 }

-case class SkewJoinTestSource(output: Seq[Attribute], partitionRowCount: Seq[Int])
-  extends LeafNode {
-  override def computeStats(): Statistics = Statistics(Long.MaxValue)
-}
-
-case class SkewJoinTestSourceExec(output: Seq[Attribute], partitionRowCount: Seq[Int])
-  extends LeafExecNode {
-
-  override protected def doExecute(): RDD[InternalRow] = {
-    val sum = partitionRowCount.sum
-    sparkContext.makeRDD(Seq.empty[Byte], 10).mapPartitions { _ =>
-      val proj = UnsafeProjection.create(output, output)
-      val rand = new Random(TaskContext.getPartitionId())
-
-      // Each RDD partition generates different partition IDs, but overall the partition ID
-      // distribution respects the ratio specified in `partitionRowCount`.
-      Seq.fill(sum / 10) {
-        val value = rand.nextInt(sum)
-        var partId = -1
-        var currentSum = 0
-        var i = 0
-        while (partId == -1 && i < partitionRowCount.length) {
-          currentSum += partitionRowCount(i)
-          if (value < currentSum) partId = i
-          i += 1
-        }
-        // Increase the partition ID diversity to avoid the join outputing too many results.
-        InternalRow(rand.nextInt(50) + partId * 100)
-      }.iterator.map(proj)
-    }
-  }
-}
-
-object SkewJoinTestStrategy extends Strategy {
-  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-    case ScanOperation(projectList, filters, s: SkewJoinTestSource) =>
-      assert(projectList == s.output)
-      val sourceExec = SkewJoinTestSourceExec(s.output, s.partitionRowCount)
-      val withFilter = if (filters.isEmpty) {
-        sourceExec
-      } else {
-        FilterExec(filters.reduce(And), sourceExec)
-      }
-      ShuffleExchangeExec(
-        PassThroughPartitioning(s.output.head, 100, s.partitionRowCount.length),
-        withFilter) :: Nil
-    case _ => Nil
-  }
-}
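To see why the restored test data is skewed, a hedged, locally runnable sketch (the master URL and app name are assumptions for illustration, not taken from the suite): id % 2 funnels all 1000 rows into keys 0 and 1, and id % 1 funnels everything into key 0, so after the hash shuffle only a few shuffle partitions carry data, consistent with the quoted stats such as left [3496, 0, 0, 0, 4014].

import org.apache.spark.sql.SparkSession

object SkewDataSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]") // assumption: local run for illustration
      .appName("skew-data-sketch")
      .getOrCreate()

    // Two groups of 500 rows each: every row lands on key 0 or key 1.
    spark.range(0, 1000, 1, 10)
      .selectExpr("id % 2 as key1", "id as value1")
      .groupBy("key1").count().show()

    // A single group of 1000 rows: id % 1 is always 0, maximal skew.
    spark.range(0, 1000, 1, 10)
      .selectExpr("id % 1 as key2", "id as value2")
      .groupBy("key2").count().show()

    spark.stop()
  }
}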
