Skip to content

Commit 979c759

Browse files
committed
Address all comments and regenerate unit test plan files
1 parent 8eb55c3 commit 979c759

File tree

45 files changed

+1747
-1607
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+1747
-1607
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala

Lines changed: 123 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,18 @@ case class SortMergeJoinExec(
105105
sqlContext.conf.sortMergeJoinExecBufferSpillThreshold
106106
}
107107

108+
// Flag to only buffer first matched row, to avoid buffering unnecessary rows.
109+
private val onlyBufferFirstMatchedRow = (joinType, condition) match {
110+
case (LeftExistence(_), None) => true
111+
case _ => false
112+
}
113+
108114
private def getInMemoryThreshold: Int = {
109-
sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold
115+
if (onlyBufferFirstMatchedRow) {
116+
1
117+
} else {
118+
sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold
119+
}
110120
}
111121

112122
protected override def doExecute(): RDD[InternalRow] = {
@@ -236,7 +246,7 @@ case class SortMergeJoinExec(
236246
inMemoryThreshold,
237247
spillThreshold,
238248
cleanupResources,
239-
condition.isEmpty
249+
onlyBufferFirstMatchedRow
240250
)
241251
private[this] val joinRow = new JoinedRow
242252

@@ -273,7 +283,7 @@ case class SortMergeJoinExec(
273283
inMemoryThreshold,
274284
spillThreshold,
275285
cleanupResources,
276-
condition.isEmpty
286+
onlyBufferFirstMatchedRow
277287
)
278288
private[this] val joinRow = new JoinedRow
279289

@@ -317,7 +327,7 @@ case class SortMergeJoinExec(
317327
inMemoryThreshold,
318328
spillThreshold,
319329
cleanupResources,
320-
condition.isEmpty
330+
onlyBufferFirstMatchedRow
321331
)
322332
private[this] val joinRow = new JoinedRow
323333

@@ -424,18 +434,8 @@ case class SortMergeJoinExec(
424434
// A list to hold all matched rows from buffered side.
425435
val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName
426436

427-
// Flag to only buffer first matched row, to avoid buffering unnecessary rows.
428-
val onlyBufferFirstMatchedRow = (joinType, condition) match {
429-
case (LeftSemi, None) => true
430-
case _ => false
431-
}
432-
val inMemoryThreshold =
433-
if (onlyBufferFirstMatchedRow) {
434-
1
435-
} else {
436-
getInMemoryThreshold
437-
}
438437
val spillThreshold = getSpillThreshold
438+
val inMemoryThreshold = getInMemoryThreshold
439439

440440
// Inline mutable state since not many join operations in a task
441441
val matches = ctx.addMutableState(clsName, "matches",
@@ -668,7 +668,7 @@ case class SortMergeJoinExec(
668668
s"SortMergeJoin.doProduce should not take $x as the JoinType")
669669
}
670670

671-
val (beforeLoop, condCheck) = if (condition.isDefined) {
671+
val (streamedBeforeLoop, condCheck) = if (condition.isDefined) {
672672
// Split the code of creating variables based on whether it's used by condition or not.
673673
val loaded = ctx.freshName("loaded")
674674
val (streamedBefore, streamedAfter) = splitVarsByCondition(streamedOutput, streamedVars)
@@ -704,83 +704,124 @@ case class SortMergeJoinExec(
704704
(evaluateVariables(streamedVars), "")
705705
}
706706

707-
val thisPlan = ctx.addReferenceObj("plan", this)
708-
val eagerCleanup = s"$thisPlan.cleanupResources();"
709-
710-
lazy val innerJoin =
707+
val beforeLoop =
711708
s"""
712-
|while (findNextJoinRows($streamedInput, $bufferedInput)) {
713-
| ${streamedVarDecl.mkString("\n")}
714-
| ${beforeLoop.trim}
715-
| scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
716-
| while ($iterator.hasNext()) {
717-
| InternalRow $bufferedRow = (InternalRow) $iterator.next();
718-
| ${condCheck.trim}
719-
| $numOutput.add(1);
720-
| ${consume(ctx, resultVars)}
721-
| }
722-
| if (shouldStop()) return;
723-
|}
724-
|$eagerCleanup
725-
""".stripMargin
726-
727-
lazy val outerJoin = {
728-
val hasOutputRow = ctx.freshName("hasOutputRow")
729-
s"""
730-
|while ($streamedInput.hasNext()) {
731-
| findNextJoinRows($streamedInput, $bufferedInput);
732-
| ${streamedVarDecl.mkString("\n")}
733-
| ${beforeLoop.trim}
734-
| scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
735-
| boolean $hasOutputRow = false;
736-
|
737-
| // the last iteration of this loop is to emit an empty row if there is no matched rows.
738-
| while ($iterator.hasNext() || !$hasOutputRow) {
739-
| InternalRow $bufferedRow = $iterator.hasNext() ?
740-
| (InternalRow) $iterator.next() : null;
741-
| ${condCheck.trim}
742-
| $hasOutputRow = true;
743-
| $numOutput.add(1);
744-
| ${consume(ctx, resultVars)}
745-
| }
746-
| if (shouldStop()) return;
747-
|}
748-
|$eagerCleanup
709+
|${streamedVarDecl.mkString("\n")}
710+
|${streamedBeforeLoop.trim}
711+
|scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
749712
""".stripMargin
750-
}
751-
752-
lazy val semiJoin = {
753-
val hasOutputRow = ctx.freshName("hasOutputRow")
713+
val outputRow =
754714
s"""
755-
|while (findNextJoinRows($streamedInput, $bufferedInput)) {
756-
| ${streamedVarDecl.mkString("\n")}
757-
| ${beforeLoop.trim}
758-
| scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
759-
| boolean $hasOutputRow = false;
760-
|
761-
| while (!$hasOutputRow && $iterator.hasNext()) {
762-
| InternalRow $bufferedRow = (InternalRow) $iterator.next();
763-
| ${condCheck.trim}
764-
| $hasOutputRow = true;
765-
| $numOutput.add(1);
766-
| ${consume(ctx, resultVars)}
767-
| }
768-
| if (shouldStop()) return;
769-
|}
770-
|$eagerCleanup
715+
|$numOutput.add(1);
716+
|${consume(ctx, resultVars)}
771717
""".stripMargin
772-
}
718+
val findNextJoinRows = s"findNextJoinRows($streamedInput, $bufferedInput)"
719+
val thisPlan = ctx.addReferenceObj("plan", this)
720+
val eagerCleanup = s"$thisPlan.cleanupResources();"
773721

774722
joinType match {
775-
case _: InnerLike => innerJoin
776-
case LeftOuter | RightOuter => outerJoin
777-
case LeftSemi => semiJoin
723+
case _: InnerLike =>
724+
codegenInner(findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck, outputRow,
725+
eagerCleanup)
726+
case LeftOuter | RightOuter =>
727+
codegenOuter(streamedInput, findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck,
728+
ctx.freshName("hasOutputRow"), outputRow, eagerCleanup)
729+
case LeftSemi =>
730+
codegenSemi(findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck,
731+
ctx.freshName("hasOutputRow"), outputRow, eagerCleanup)
778732
case x =>
779733
throw new IllegalArgumentException(
780734
s"SortMergeJoin.doProduce should not take $x as the JoinType")
781735
}
782736
}
783737

738+
/**
739+
* Generates the code for Inner join.
740+
*/
741+
private def codegenInner(
742+
findNextJoinRows: String,
743+
beforeLoop: String,
744+
matchIterator: String,
745+
bufferedRow: String,
746+
conditionCheck: String,
747+
outputRow: String,
748+
eagerCleanup: String): String = {
749+
s"""
750+
|while ($findNextJoinRows) {
751+
| ${beforeLoop.trim}
752+
| while ($matchIterator.hasNext()) {
753+
| InternalRow $bufferedRow = (InternalRow) $matchIterator.next();
754+
| ${conditionCheck.trim}
755+
| $outputRow
756+
| }
757+
| if (shouldStop()) return;
758+
|}
759+
|$eagerCleanup
760+
""".stripMargin
761+
}
762+
763+
/**
764+
* Generates the code for Left or Right Outer join.
765+
*/
766+
private def codegenOuter(
767+
streamedInput: String,
768+
findNextJoinRows: String,
769+
beforeLoop: String,
770+
matchIterator: String,
771+
bufferedRow: String,
772+
conditionCheck: String,
773+
hasOutputRow: String,
774+
outputRow: String,
775+
eagerCleanup: String): String = {
776+
s"""
777+
|while ($streamedInput.hasNext()) {
778+
| $findNextJoinRows;
779+
| ${beforeLoop.trim}
780+
| boolean $hasOutputRow = false;
781+
|
782+
| // the last iteration of this loop is to emit an empty row if there is no matched rows.
783+
| while ($matchIterator.hasNext() || !$hasOutputRow) {
784+
| InternalRow $bufferedRow = $matchIterator.hasNext() ?
785+
| (InternalRow) $matchIterator.next() : null;
786+
| ${conditionCheck.trim}
787+
| $hasOutputRow = true;
788+
| $outputRow
789+
| }
790+
| if (shouldStop()) return;
791+
|}
792+
|$eagerCleanup
793+
""".stripMargin
794+
}
795+
796+
/**
797+
* Generates the code for Left Semi join.
798+
*/
799+
private def codegenSemi(
800+
findNextJoinRows: String,
801+
beforeLoop: String,
802+
matchIterator: String,
803+
bufferedRow: String,
804+
conditionCheck: String,
805+
hasOutputRow: String,
806+
outputRow: String,
807+
eagerCleanup: String): String = {
808+
s"""
809+
|while ($findNextJoinRows) {
810+
| ${beforeLoop.trim}
811+
| boolean $hasOutputRow = false;
812+
|
813+
| while (!$hasOutputRow && $matchIterator.hasNext()) {
814+
| InternalRow $bufferedRow = (InternalRow) $matchIterator.next();
815+
| ${conditionCheck.trim}
816+
| $hasOutputRow = true;
817+
| $outputRow
818+
| }
819+
| if (shouldStop()) return;
820+
|}
821+
|$eagerCleanup
822+
""".stripMargin
823+
}
824+
784825
override protected def withNewChildrenInternal(
785826
newLeft: SparkPlan, newRight: SparkPlan): SortMergeJoinExec =
786827
copy(left = newLeft, right = newRight)
@@ -831,8 +872,7 @@ private[joins] class SortMergeJoinScanner(
831872
private[this] var matchJoinKey: InternalRow = _
832873
/** Buffered rows from the buffered side of the join. This is empty if there are no matches. */
833874
private[this] val bufferedMatches: ExternalAppendOnlyUnsafeRowArray =
834-
new ExternalAppendOnlyUnsafeRowArray(if (onlyBufferFirstMatch) 1 else inMemoryThreshold,
835-
spillThreshold)
875+
new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold)
836876

837877
// Initialization (note: do _not_ want to advance streamed here).
838878
advancedBufferedToRowWithNullFreeJoinKey()

0 commit comments

Comments
 (0)