@@ -40,7 +40,7 @@ case class SortMergeJoinExec(
4040 condition : Option [Expression ],
4141 left : SparkPlan ,
4242 right : SparkPlan ,
43- isSkewJoin : Boolean = false ) extends ShuffledJoin with CodegenSupport {
43+ isSkewJoin : Boolean = false ) extends ShuffledJoin {
4444
4545 override lazy val metrics = Map (
4646 " numOutputRows" -> SQLMetrics .createMetric(sparkContext, " number of output rows" ))
@@ -353,12 +353,22 @@ case class SortMergeJoinExec(
353353 }
354354 }
355355
356+ private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = joinType match {
357+ case _ : InnerLike => ((left, leftKeys), (right, rightKeys))
358+ case x =>
359+ throw new IllegalArgumentException (
360+ s " SortMergeJoin.streamedPlan/bufferedPlan should not take $x as the JoinType " )
361+ }
362+
363+ private lazy val streamedOutput = streamedPlan.output
364+ private lazy val bufferedOutput = bufferedPlan.output
365+
356366 override def supportCodegen : Boolean = {
357367 joinType.isInstanceOf [InnerLike ]
358368 }
359369
360370 override def inputRDDs (): Seq [RDD [InternalRow ]] = {
361- left .execute() :: right .execute() :: Nil
371+ streamedPlan .execute() :: bufferedPlan .execute() :: Nil
362372 }
363373
364374 private def createJoinKey (
@@ -392,24 +402,24 @@ case class SortMergeJoinExec(
392402 }
393403
394404 /**
395- * Generate a function to scan both left and right to find a match, returns the term for
396- * matched one row from left side and buffered rows from right side.
405+ * Generate a function to scan both sides to find a match, returns the term for
406+ * matched one row from streamed side and buffered rows from buffered side.
397407 */
398408 private def genScanner (ctx : CodegenContext ): (String , String ) = {
399409 // Create class member for next row from both sides.
400410 // Inline mutable state since not many join operations in a task
401- val leftRow = ctx.addMutableState(" InternalRow" , " leftRow " , forceInline = true )
402- val rightRow = ctx.addMutableState(" InternalRow" , " rightRow " , forceInline = true )
411+ val streamedRow = ctx.addMutableState(" InternalRow" , " streamedRow " , forceInline = true )
412+ val bufferedRow = ctx.addMutableState(" InternalRow" , " bufferedRow " , forceInline = true )
403413
404414 // Create variables for join keys from both sides.
405- val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output )
406- val leftAnyNull = leftKeyVars .map(_.isNull).mkString(" || " )
407- val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output )
408- val rightAnyNull = rightKeyTmpVars .map(_.isNull).mkString(" || " )
409- // Copy the right key as class members so they could be used in next function call.
410- val rightKeyVars = copyKeys(ctx, rightKeyTmpVars )
411-
412- // A list to hold all matched rows from right side.
415+ val streamedKeyVars = createJoinKey(ctx, streamedRow, streamedKeys, streamedOutput )
416+ val streamedAnyNull = streamedKeyVars .map(_.isNull).mkString(" || " )
417+ val bufferedKeyTmpVars = createJoinKey(ctx, bufferedRow, bufferedKeys, bufferedOutput )
418+ val bufferedAnyNull = bufferedKeyTmpVars .map(_.isNull).mkString(" || " )
419+ // Copy the buffered key as class members so they could be used in next function call.
420+ val bufferedKeyVars = copyKeys(ctx, bufferedKeyTmpVars )
421+
422+ // A list to hold all matched rows from buffered side.
413423 val clsName = classOf [ExternalAppendOnlyUnsafeRowArray ].getName
414424
415425 val spillThreshold = getSpillThreshold
@@ -418,115 +428,106 @@ case class SortMergeJoinExec(
418428 // Inline mutable state since not many join operations in a task
419429 val matches = ctx.addMutableState(clsName, " matches" ,
420430 v => s " $v = new $clsName( $inMemoryThreshold, $spillThreshold); " , forceInline = true )
421- // Copy the left keys as class members so they could be used in next function call.
422- val matchedKeyVars = copyKeys(ctx, leftKeyVars )
431+ // Copy the streamed keys as class members so they could be used in next function call.
432+ val matchedKeyVars = copyKeys(ctx, streamedKeyVars )
423433
424- ctx.addNewFunction(" findNextInnerJoinRows " ,
434+ ctx.addNewFunction(" findNextJoinRows " ,
425435 s """
426- |private boolean findNextInnerJoinRows (
427- | scala.collection.Iterator leftIter ,
428- | scala.collection.Iterator rightIter ) {
429- | $leftRow = null;
436+ |private boolean findNextJoinRows (
437+ | scala.collection.Iterator streamedIter ,
438+ | scala.collection.Iterator bufferedIter ) {
439+ | $streamedRow = null;
430440 | int comp = 0;
431- | while ( $leftRow == null) {
432- | if (!leftIter .hasNext()) return false;
433- | $leftRow = (InternalRow) leftIter .next();
434- | ${leftKeyVars .map(_.code).mkString(" \n " )}
435- | if ( $leftAnyNull ) {
436- | $leftRow = null;
441+ | while ( $streamedRow == null) {
442+ | if (!streamedIter .hasNext()) return false;
443+ | $streamedRow = (InternalRow) streamedIter .next();
444+ | ${streamedKeyVars .map(_.code).mkString(" \n " )}
445+ | if ( $streamedAnyNull ) {
446+ | $streamedRow = null;
437447 | continue;
438448 | }
439449 | if (! $matches.isEmpty()) {
440- | ${genComparison(ctx, leftKeyVars , matchedKeyVars)}
450+ | ${genComparison(ctx, streamedKeyVars , matchedKeyVars)}
441451 | if (comp == 0) {
442452 | return true;
443453 | }
444454 | $matches.clear();
445455 | }
446456 |
447457 | do {
448- | if ( $rightRow == null) {
449- | if (!rightIter .hasNext()) {
458+ | if ( $bufferedRow == null) {
459+ | if (!bufferedIter .hasNext()) {
450460 | ${matchedKeyVars.map(_.code).mkString(" \n " )}
451461 | return ! $matches.isEmpty();
452462 | }
453- | $rightRow = (InternalRow) rightIter .next();
454- | ${rightKeyTmpVars .map(_.code).mkString(" \n " )}
455- | if ( $rightAnyNull ) {
456- | $rightRow = null;
463+ | $bufferedRow = (InternalRow) bufferedIter .next();
464+ | ${bufferedKeyTmpVars .map(_.code).mkString(" \n " )}
465+ | if ( $bufferedAnyNull ) {
466+ | $bufferedRow = null;
457467 | continue;
458468 | }
459- | ${rightKeyVars .map(_.code).mkString(" \n " )}
469+ | ${bufferedKeyVars .map(_.code).mkString(" \n " )}
460470 | }
461- | ${genComparison(ctx, leftKeyVars, rightKeyVars )}
471+ | ${genComparison(ctx, streamedKeyVars, bufferedKeyVars )}
462472 | if (comp > 0) {
463- | $rightRow = null;
473+ | $bufferedRow = null;
464474 | } else if (comp < 0) {
465475 | if (! $matches.isEmpty()) {
466476 | ${matchedKeyVars.map(_.code).mkString(" \n " )}
467477 | return true;
468478 | }
469- | $leftRow = null;
479+ | $streamedRow = null;
470480 | } else {
471- | $matches.add((UnsafeRow) $rightRow );
472- | $rightRow = null;
481+ | $matches.add((UnsafeRow) $bufferedRow );
482+ | $bufferedRow = null;
473483 | }
474- | } while ( $leftRow != null);
484+ | } while ( $streamedRow != null);
475485 | }
476486 | return false; // unreachable
477487 |}
478488 """ .stripMargin, inlineToOuterClass = true )
479489
480- (leftRow , matches)
490+ (streamedRow , matches)
481491 }
482492
483493 /**
484- * Creates variables and declarations for left part of result row.
494+ * Creates variables and declarations for streamed part of result row.
485495 *
486496 * In order to defer the access after condition and also only access once in the loop,
487497 * the variables should be declared separately from accessing the columns, we can't use the
488498 * codegen of BoundReference here.
489499 */
490- private def createLeftVars (ctx : CodegenContext , leftRow : String ): (Seq [ExprCode ], Seq [String ]) = {
491- ctx.INPUT_ROW = leftRow
500+ private def createStreamedVars (
501+ ctx : CodegenContext ,
502+ streamedRow : String ): (Seq [ExprCode ], Seq [String ]) = {
503+ ctx.INPUT_ROW = streamedRow
492504 left.output.zipWithIndex.map { case (a, i) =>
493505 val value = ctx.freshName(" value" )
494- val valueCode = CodeGenerator .getValue(leftRow , a.dataType, i.toString)
506+ val valueCode = CodeGenerator .getValue(streamedRow , a.dataType, i.toString)
495507 val javaType = CodeGenerator .javaType(a.dataType)
496508 val defaultValue = CodeGenerator .defaultValue(a.dataType)
497509 if (a.nullable) {
498510 val isNull = ctx.freshName(" isNull" )
499511 val code =
500512 code """
501- | $isNull = $leftRow .isNullAt( $i);
513+ | $isNull = $streamedRow .isNullAt( $i);
502514 | $value = $isNull ? $defaultValue : ( $valueCode);
503515 """ .stripMargin
504- val leftVarsDecl =
516+ val streamedVarsDecl =
505517 s """
506518 |boolean $isNull = false;
507519 | $javaType $value = $defaultValue;
508520 """ .stripMargin
509521 (ExprCode (code, JavaCode .isNullVariable(isNull), JavaCode .variable(value, a.dataType)),
510- leftVarsDecl )
522+ streamedVarsDecl )
511523 } else {
512524 val code = code " $value = $valueCode; "
513- val leftVarsDecl = s """ $javaType $value = $defaultValue; """
514- (ExprCode (code, FalseLiteral , JavaCode .variable(value, a.dataType)), leftVarsDecl )
525+ val streamedVarsDecl = s """ $javaType $value = $defaultValue; """
526+ (ExprCode (code, FalseLiteral , JavaCode .variable(value, a.dataType)), streamedVarsDecl )
515527 }
516528 }.unzip
517529 }
518530
519- /**
520- * Creates the variables for right part of result row, using BoundReference, since the right
521- * part are accessed inside the loop.
522- */
523- private def createRightVar (ctx : CodegenContext , rightRow : String ): Seq [ExprCode ] = {
524- ctx.INPUT_ROW = rightRow
525- right.output.zipWithIndex.map { case (a, i) =>
526- BoundReference (i, a.dataType, a.nullable).genCode(ctx)
527- }
528- }
529-
530531 /**
531532 * Splits variables based on whether it's used by condition or not, returns the code to create
532533 * these variables before the condition and after the condition.
@@ -554,62 +555,64 @@ case class SortMergeJoinExec(
554555
555556 override def doProduce (ctx : CodegenContext ): String = {
556557 // Inline mutable state since not many join operations in a task
557- val leftInput = ctx.addMutableState(" scala.collection.Iterator" , " leftInput " ,
558+ val streamedInput = ctx.addMutableState(" scala.collection.Iterator" , " streamedInput " ,
558559 v => s " $v = inputs[0]; " , forceInline = true )
559- val rightInput = ctx.addMutableState(" scala.collection.Iterator" , " rightInput " ,
560+ val bufferedInput = ctx.addMutableState(" scala.collection.Iterator" , " bufferedInput " ,
560561 v => s " $v = inputs[1]; " , forceInline = true )
561562
562- val (leftRow , matches) = genScanner(ctx)
563+ val (streamedRow , matches) = genScanner(ctx)
563564
564565 // Create variables for row from both sides.
565- val (leftVars, leftVarDecl ) = createLeftVars (ctx, leftRow )
566- val rightRow = ctx.freshName(" rightRow " )
567- val rightVars = createRightVar (ctx, rightRow )
566+ val (streamedVars, streamedVarDecl ) = createStreamedVars (ctx, streamedRow )
567+ val bufferedRow = ctx.freshName(" bufferedRow " )
568+ val bufferedVars = genBuildSideVars (ctx, bufferedRow, bufferedPlan )
568569
569570 val iterator = ctx.freshName(" iterator" )
570571 val numOutput = metricTerm(ctx, " numOutputRows" )
572+ val resultVars = streamedVars ++ bufferedVars
573+
571574 val (beforeLoop, condCheck) = if (condition.isDefined) {
572575 // Split the code of creating variables based on whether it's used by condition or not.
573576 val loaded = ctx.freshName(" loaded" )
574- val (leftBefore, leftAfter ) = splitVarsByCondition(left.output, leftVars )
575- val (rightBefore, rightAfter ) = splitVarsByCondition(right.output, rightVars )
577+ val (streamedBefore, streamedAfter ) = splitVarsByCondition(streamedOutput, streamedVars )
578+ val (bufferedBefore, bufferedAfter ) = splitVarsByCondition(bufferedOutput, bufferedVars )
576579 // Generate code for condition
577- ctx.currentVars = leftVars ++ rightVars
580+ ctx.currentVars = resultVars
578581 val cond = BindReferences .bindReference(condition.get, output).genCode(ctx)
579582 // evaluate the columns those used by condition before loop
580583 val before = s """
581584 |boolean $loaded = false;
582- | $leftBefore
585+ | $streamedBefore
583586 """ .stripMargin
584587
585588 val checking = s """
586- | $rightBefore
589+ | $bufferedBefore
587590 | ${cond.code}
588591 |if ( ${cond.isNull} || ! ${cond.value}) continue;
589592 |if (! $loaded) {
590593 | $loaded = true;
591- | $leftAfter
594+ | $streamedAfter
592595 |}
593- | $rightAfter
596+ | $bufferedAfter
594597 """ .stripMargin
595598 (before, checking)
596599 } else {
597- (evaluateVariables(leftVars ), " " )
600+ (evaluateVariables(streamedVars ), " " )
598601 }
599602
600603 val thisPlan = ctx.addReferenceObj(" plan" , this )
601604 val eagerCleanup = s " $thisPlan.cleanupResources(); "
602605
603606 s """
604- |while (findNextInnerJoinRows( $leftInput , $rightInput )) {
605- | ${leftVarDecl .mkString(" \n " )}
607+ |while (findNextJoinRows( $streamedInput , $bufferedInput )) {
608+ | ${streamedVarDecl .mkString(" \n " )}
606609 | ${beforeLoop.trim}
607610 | scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
608611 | while ( $iterator.hasNext()) {
609- | InternalRow $rightRow = (InternalRow) $iterator.next();
612+ | InternalRow $bufferedRow = (InternalRow) $iterator.next();
610613 | ${condCheck.trim}
611614 | $numOutput.add(1);
612- | ${consume(ctx, leftVars ++ rightVars )}
615+ | ${consume(ctx, resultVars )}
613616 | }
614617 | if (shouldStop()) return;
615618 |}
0 commit comments