@@ -105,8 +105,18 @@ case class SortMergeJoinExec(
105105 sqlContext.conf.sortMergeJoinExecBufferSpillThreshold
106106 }
107107
108+ // Flag to only buffer first matched row, to avoid buffering unnecessary rows.
109+ private val onlyBufferFirstMatchedRow = (joinType, condition) match {
110+ case (LeftExistence (_), None ) => true
111+ case _ => false
112+ }
113+
108114 private def getInMemoryThreshold : Int = {
109- sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold
115+ if (onlyBufferFirstMatchedRow) {
116+ 1
117+ } else {
118+ sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold
119+ }
110120 }
111121
112122 protected override def doExecute (): RDD [InternalRow ] = {
@@ -236,7 +246,7 @@ case class SortMergeJoinExec(
236246 inMemoryThreshold,
237247 spillThreshold,
238248 cleanupResources,
239- condition.isEmpty
249+ onlyBufferFirstMatchedRow
240250 )
241251 private [this ] val joinRow = new JoinedRow
242252
@@ -273,7 +283,7 @@ case class SortMergeJoinExec(
273283 inMemoryThreshold,
274284 spillThreshold,
275285 cleanupResources,
276- condition.isEmpty
286+ onlyBufferFirstMatchedRow
277287 )
278288 private [this ] val joinRow = new JoinedRow
279289
@@ -317,7 +327,7 @@ case class SortMergeJoinExec(
317327 inMemoryThreshold,
318328 spillThreshold,
319329 cleanupResources,
320- condition.isEmpty
330+ onlyBufferFirstMatchedRow
321331 )
322332 private [this ] val joinRow = new JoinedRow
323333
@@ -424,18 +434,8 @@ case class SortMergeJoinExec(
424434 // A list to hold all matched rows from buffered side.
425435 val clsName = classOf [ExternalAppendOnlyUnsafeRowArray ].getName
426436
427- // Flag to only buffer first matched row, to avoid buffering unnecessary rows.
428- val onlyBufferFirstMatchedRow = (joinType, condition) match {
429- case (LeftSemi , None ) => true
430- case _ => false
431- }
432- val inMemoryThreshold =
433- if (onlyBufferFirstMatchedRow) {
434- 1
435- } else {
436- getInMemoryThreshold
437- }
438437 val spillThreshold = getSpillThreshold
438+ val inMemoryThreshold = getInMemoryThreshold
439439
440440 // Inline mutable state since not many join operations in a task
441441 val matches = ctx.addMutableState(clsName, " matches" ,
@@ -668,7 +668,7 @@ case class SortMergeJoinExec(
668668 s " SortMergeJoin.doProduce should not take $x as the JoinType " )
669669 }
670670
671- val (beforeLoop , condCheck) = if (condition.isDefined) {
671+ val (streamedBeforeLoop , condCheck) = if (condition.isDefined) {
672672 // Split the code of creating variables based on whether it's used by condition or not.
673673 val loaded = ctx.freshName(" loaded" )
674674 val (streamedBefore, streamedAfter) = splitVarsByCondition(streamedOutput, streamedVars)
@@ -704,83 +704,124 @@ case class SortMergeJoinExec(
704704 (evaluateVariables(streamedVars), " " )
705705 }
706706
707- val thisPlan = ctx.addReferenceObj(" plan" , this )
708- val eagerCleanup = s " $thisPlan.cleanupResources(); "
709-
710- lazy val innerJoin =
707+ val beforeLoop =
711708 s """
712- |while (findNextJoinRows( $streamedInput, $bufferedInput)) {
713- | ${streamedVarDecl.mkString(" \n " )}
714- | ${beforeLoop.trim}
715- | scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
716- | while ( $iterator.hasNext()) {
717- | InternalRow $bufferedRow = (InternalRow) $iterator.next();
718- | ${condCheck.trim}
719- | $numOutput.add(1);
720- | ${consume(ctx, resultVars)}
721- | }
722- | if (shouldStop()) return;
723- |}
724- | $eagerCleanup
725- """ .stripMargin
726-
727- lazy val outerJoin = {
728- val hasOutputRow = ctx.freshName(" hasOutputRow" )
729- s """
730- |while ( $streamedInput.hasNext()) {
731- | findNextJoinRows( $streamedInput, $bufferedInput);
732- | ${streamedVarDecl.mkString(" \n " )}
733- | ${beforeLoop.trim}
734- | scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
735- | boolean $hasOutputRow = false;
736- |
737- | // the last iteration of this loop is to emit an empty row if there is no matched rows.
738- | while ( $iterator.hasNext() || ! $hasOutputRow) {
739- | InternalRow $bufferedRow = $iterator.hasNext() ?
740- | (InternalRow) $iterator.next() : null;
741- | ${condCheck.trim}
742- | $hasOutputRow = true;
743- | $numOutput.add(1);
744- | ${consume(ctx, resultVars)}
745- | }
746- | if (shouldStop()) return;
747- |}
748- | $eagerCleanup
709+ | ${streamedVarDecl.mkString(" \n " )}
710+ | ${streamedBeforeLoop.trim}
711+ |scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
749712 """ .stripMargin
750- }
751-
752- lazy val semiJoin = {
753- val hasOutputRow = ctx.freshName(" hasOutputRow" )
713+ val outputRow =
754714 s """
755- |while (findNextJoinRows( $streamedInput, $bufferedInput)) {
756- | ${streamedVarDecl.mkString(" \n " )}
757- | ${beforeLoop.trim}
758- | scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
759- | boolean $hasOutputRow = false;
760- |
761- | while (! $hasOutputRow && $iterator.hasNext()) {
762- | InternalRow $bufferedRow = (InternalRow) $iterator.next();
763- | ${condCheck.trim}
764- | $hasOutputRow = true;
765- | $numOutput.add(1);
766- | ${consume(ctx, resultVars)}
767- | }
768- | if (shouldStop()) return;
769- |}
770- | $eagerCleanup
715+ | $numOutput.add(1);
716+ | ${consume(ctx, resultVars)}
771717 """ .stripMargin
772- }
718+ val findNextJoinRows = s " findNextJoinRows( $streamedInput, $bufferedInput) "
719+ val thisPlan = ctx.addReferenceObj(" plan" , this )
720+ val eagerCleanup = s " $thisPlan.cleanupResources(); "
773721
774722 joinType match {
775- case _ : InnerLike => innerJoin
776- case LeftOuter | RightOuter => outerJoin
777- case LeftSemi => semiJoin
723+ case _ : InnerLike =>
724+ codegenInner(findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck, outputRow,
725+ eagerCleanup)
726+ case LeftOuter | RightOuter =>
727+ codegenOuter(streamedInput, findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck,
728+ ctx.freshName(" hasOutputRow" ), outputRow, eagerCleanup)
729+ case LeftSemi =>
730+ codegenSemi(findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck,
731+ ctx.freshName(" hasOutputRow" ), outputRow, eagerCleanup)
778732 case x =>
779733 throw new IllegalArgumentException (
780734 s " SortMergeJoin.doProduce should not take $x as the JoinType " )
781735 }
782736 }
783737
738+ /**
739+ * Generates the code for Inner join.
740+ */
741+ private def codegenInner (
742+ findNextJoinRows : String ,
743+ beforeLoop : String ,
744+ matchIterator : String ,
745+ bufferedRow : String ,
746+ conditionCheck : String ,
747+ outputRow : String ,
748+ eagerCleanup : String ): String = {
749+ s """
750+ |while ( $findNextJoinRows) {
751+ | ${beforeLoop.trim}
752+ | while ( $matchIterator.hasNext()) {
753+ | InternalRow $bufferedRow = (InternalRow) $matchIterator.next();
754+ | ${conditionCheck.trim}
755+ | $outputRow
756+ | }
757+ | if (shouldStop()) return;
758+ |}
759+ | $eagerCleanup
760+ """ .stripMargin
761+ }
762+
763+ /**
764+ * Generates the code for Left or Right Outer join.
765+ */
766+ private def codegenOuter (
767+ streamedInput : String ,
768+ findNextJoinRows : String ,
769+ beforeLoop : String ,
770+ matchIterator : String ,
771+ bufferedRow : String ,
772+ conditionCheck : String ,
773+ hasOutputRow : String ,
774+ outputRow : String ,
775+ eagerCleanup : String ): String = {
776+ s """
777+ |while ( $streamedInput.hasNext()) {
778+ | $findNextJoinRows;
779+ | ${beforeLoop.trim}
780+ | boolean $hasOutputRow = false;
781+ |
782+ | // the last iteration of this loop is to emit an empty row if there is no matched rows.
783+ | while ( $matchIterator.hasNext() || ! $hasOutputRow) {
784+ | InternalRow $bufferedRow = $matchIterator.hasNext() ?
785+ | (InternalRow) $matchIterator.next() : null;
786+ | ${conditionCheck.trim}
787+ | $hasOutputRow = true;
788+ | $outputRow
789+ | }
790+ | if (shouldStop()) return;
791+ |}
792+ | $eagerCleanup
793+ """ .stripMargin
794+ }
795+
796+ /**
797+ * Generates the code for Left Semi join.
798+ */
799+ private def codegenSemi (
800+ findNextJoinRows : String ,
801+ beforeLoop : String ,
802+ matchIterator : String ,
803+ bufferedRow : String ,
804+ conditionCheck : String ,
805+ hasOutputRow : String ,
806+ outputRow : String ,
807+ eagerCleanup : String ): String = {
808+ s """
809+ |while ( $findNextJoinRows) {
810+ | ${beforeLoop.trim}
811+ | boolean $hasOutputRow = false;
812+ |
813+ | while (! $hasOutputRow && $matchIterator.hasNext()) {
814+ | InternalRow $bufferedRow = (InternalRow) $matchIterator.next();
815+ | ${conditionCheck.trim}
816+ | $hasOutputRow = true;
817+ | $outputRow
818+ | }
819+ | if (shouldStop()) return;
820+ |}
821+ | $eagerCleanup
822+ """ .stripMargin
823+ }
824+
784825 override protected def withNewChildrenInternal (
785826 newLeft : SparkPlan , newRight : SparkPlan ): SortMergeJoinExec =
786827 copy(left = newLeft, right = newRight)
@@ -831,8 +872,7 @@ private[joins] class SortMergeJoinScanner(
831872 private [this ] var matchJoinKey : InternalRow = _
832873 /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */
833874 private [this ] val bufferedMatches : ExternalAppendOnlyUnsafeRowArray =
834- new ExternalAppendOnlyUnsafeRowArray (if (onlyBufferFirstMatch) 1 else inMemoryThreshold,
835- spillThreshold)
875+ new ExternalAppendOnlyUnsafeRowArray (inMemoryThreshold, spillThreshold)
836876
837877 // Initialization (note: do _not_ want to advance streamed here).
838878 advancedBufferedToRowWithNullFreeJoinKey()
0 commit comments