Skip to content

Commit 1912436

Browse files
committed
Refactor sort merge join code-gen to be agnostic to join type
1 parent 8b94eff commit 1912436

File tree

2 files changed

+84
-81
lines changed

2 files changed

+84
-81
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClustered
2525
* Holds common logic for join operators by shuffling two child relations
2626
* using the join keys.
2727
*/
28-
trait ShuffledJoin extends BaseJoinExec {
28+
trait ShuffledJoin extends JoinCodegenSupport {
2929
def isSkewJoin: Boolean
3030

3131
override def nodeName: String = {

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala

Lines changed: 83 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ case class SortMergeJoinExec(
4040
condition: Option[Expression],
4141
left: SparkPlan,
4242
right: SparkPlan,
43-
isSkewJoin: Boolean = false) extends ShuffledJoin with CodegenSupport {
43+
isSkewJoin: Boolean = false) extends ShuffledJoin {
4444

4545
override lazy val metrics = Map(
4646
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -353,12 +353,22 @@ case class SortMergeJoinExec(
353353
}
354354
}
355355

356+
private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = joinType match {
357+
case _: InnerLike => ((left, leftKeys), (right, rightKeys))
358+
case x =>
359+
throw new IllegalArgumentException(
360+
s"SortMergeJoin.streamedPlan/bufferedPlan should not take $x as the JoinType")
361+
}
362+
363+
private lazy val streamedOutput = streamedPlan.output
364+
private lazy val bufferedOutput = bufferedPlan.output
365+
356366
override def supportCodegen: Boolean = {
357367
joinType.isInstanceOf[InnerLike]
358368
}
359369

360370
override def inputRDDs(): Seq[RDD[InternalRow]] = {
361-
left.execute() :: right.execute() :: Nil
371+
streamedPlan.execute() :: bufferedPlan.execute() :: Nil
362372
}
363373

364374
private def createJoinKey(
@@ -392,24 +402,24 @@ case class SortMergeJoinExec(
392402
}
393403

394404
/**
395-
* Generate a function to scan both left and right to find a match, returns the term for
396-
* matched one row from left side and buffered rows from right side.
405+
* Generate a function that scans both sides to find a match and returns the terms for
406+
* the matched row from the streamed side and the buffered rows from the buffered side.
397407
*/
398408
private def genScanner(ctx: CodegenContext): (String, String) = {
399409
// Create class member for next row from both sides.
400410
// Inline mutable state since not many join operations in a task
401-
val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true)
402-
val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true)
411+
val streamedRow = ctx.addMutableState("InternalRow", "streamedRow", forceInline = true)
412+
val bufferedRow = ctx.addMutableState("InternalRow", "bufferedRow", forceInline = true)
403413

404414
// Create variables for join keys from both sides.
405-
val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output)
406-
val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ")
407-
val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output)
408-
val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ")
409-
// Copy the right key as class members so they could be used in next function call.
410-
val rightKeyVars = copyKeys(ctx, rightKeyTmpVars)
411-
412-
// A list to hold all matched rows from right side.
415+
val streamedKeyVars = createJoinKey(ctx, streamedRow, streamedKeys, streamedOutput)
416+
val streamedAnyNull = streamedKeyVars.map(_.isNull).mkString(" || ")
417+
val bufferedKeyTmpVars = createJoinKey(ctx, bufferedRow, bufferedKeys, bufferedOutput)
418+
val bufferedAnyNull = bufferedKeyTmpVars.map(_.isNull).mkString(" || ")
419+
// Copy the buffered key as class members so they could be used in next function call.
420+
val bufferedKeyVars = copyKeys(ctx, bufferedKeyTmpVars)
421+
422+
// A list to hold all matched rows from buffered side.
413423
val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName
414424

415425
val spillThreshold = getSpillThreshold
@@ -418,115 +428,106 @@ case class SortMergeJoinExec(
418428
// Inline mutable state since not many join operations in a task
419429
val matches = ctx.addMutableState(clsName, "matches",
420430
v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true)
421-
// Copy the left keys as class members so they could be used in next function call.
422-
val matchedKeyVars = copyKeys(ctx, leftKeyVars)
431+
// Copy the streamed keys as class members so they could be used in next function call.
432+
val matchedKeyVars = copyKeys(ctx, streamedKeyVars)
423433

424-
ctx.addNewFunction("findNextInnerJoinRows",
434+
ctx.addNewFunction("findNextJoinRows",
425435
s"""
426-
|private boolean findNextInnerJoinRows(
427-
| scala.collection.Iterator leftIter,
428-
| scala.collection.Iterator rightIter) {
429-
| $leftRow = null;
436+
|private boolean findNextJoinRows(
437+
| scala.collection.Iterator streamedIter,
438+
| scala.collection.Iterator bufferedIter) {
439+
| $streamedRow = null;
430440
| int comp = 0;
431-
| while ($leftRow == null) {
432-
| if (!leftIter.hasNext()) return false;
433-
| $leftRow = (InternalRow) leftIter.next();
434-
| ${leftKeyVars.map(_.code).mkString("\n")}
435-
| if ($leftAnyNull) {
436-
| $leftRow = null;
441+
| while ($streamedRow == null) {
442+
| if (!streamedIter.hasNext()) return false;
443+
| $streamedRow = (InternalRow) streamedIter.next();
444+
| ${streamedKeyVars.map(_.code).mkString("\n")}
445+
| if ($streamedAnyNull) {
446+
| $streamedRow = null;
437447
| continue;
438448
| }
439449
| if (!$matches.isEmpty()) {
440-
| ${genComparison(ctx, leftKeyVars, matchedKeyVars)}
450+
| ${genComparison(ctx, streamedKeyVars, matchedKeyVars)}
441451
| if (comp == 0) {
442452
| return true;
443453
| }
444454
| $matches.clear();
445455
| }
446456
|
447457
| do {
448-
| if ($rightRow == null) {
449-
| if (!rightIter.hasNext()) {
458+
| if ($bufferedRow == null) {
459+
| if (!bufferedIter.hasNext()) {
450460
| ${matchedKeyVars.map(_.code).mkString("\n")}
451461
| return !$matches.isEmpty();
452462
| }
453-
| $rightRow = (InternalRow) rightIter.next();
454-
| ${rightKeyTmpVars.map(_.code).mkString("\n")}
455-
| if ($rightAnyNull) {
456-
| $rightRow = null;
463+
| $bufferedRow = (InternalRow) bufferedIter.next();
464+
| ${bufferedKeyTmpVars.map(_.code).mkString("\n")}
465+
| if ($bufferedAnyNull) {
466+
| $bufferedRow = null;
457467
| continue;
458468
| }
459-
| ${rightKeyVars.map(_.code).mkString("\n")}
469+
| ${bufferedKeyVars.map(_.code).mkString("\n")}
460470
| }
461-
| ${genComparison(ctx, leftKeyVars, rightKeyVars)}
471+
| ${genComparison(ctx, streamedKeyVars, bufferedKeyVars)}
462472
| if (comp > 0) {
463-
| $rightRow = null;
473+
| $bufferedRow = null;
464474
| } else if (comp < 0) {
465475
| if (!$matches.isEmpty()) {
466476
| ${matchedKeyVars.map(_.code).mkString("\n")}
467477
| return true;
468478
| }
469-
| $leftRow = null;
479+
| $streamedRow = null;
470480
| } else {
471-
| $matches.add((UnsafeRow) $rightRow);
472-
| $rightRow = null;
481+
| $matches.add((UnsafeRow) $bufferedRow);
482+
| $bufferedRow = null;
473483
| }
474-
| } while ($leftRow != null);
484+
| } while ($streamedRow != null);
475485
| }
476486
| return false; // unreachable
477487
|}
478488
""".stripMargin, inlineToOuterClass = true)
479489

480-
(leftRow, matches)
490+
(streamedRow, matches)
481491
}
482492

483493
/**
484-
* Creates variables and declarations for left part of result row.
494+
* Creates variables and declarations for streamed part of result row.
485495
*
486496
* In order to defer the access after condition and also only access once in the loop,
487497
* the variables should be declared separately from accessing the columns, we can't use the
488498
* codegen of BoundReference here.
489499
*/
490-
private def createLeftVars(ctx: CodegenContext, leftRow: String): (Seq[ExprCode], Seq[String]) = {
491-
ctx.INPUT_ROW = leftRow
500+
private def createStreamedVars(
501+
ctx: CodegenContext,
502+
streamedRow: String): (Seq[ExprCode], Seq[String]) = {
503+
ctx.INPUT_ROW = streamedRow
492504
left.output.zipWithIndex.map { case (a, i) =>
493505
val value = ctx.freshName("value")
494-
val valueCode = CodeGenerator.getValue(leftRow, a.dataType, i.toString)
506+
val valueCode = CodeGenerator.getValue(streamedRow, a.dataType, i.toString)
495507
val javaType = CodeGenerator.javaType(a.dataType)
496508
val defaultValue = CodeGenerator.defaultValue(a.dataType)
497509
if (a.nullable) {
498510
val isNull = ctx.freshName("isNull")
499511
val code =
500512
code"""
501-
|$isNull = $leftRow.isNullAt($i);
513+
|$isNull = $streamedRow.isNullAt($i);
502514
|$value = $isNull ? $defaultValue : ($valueCode);
503515
""".stripMargin
504-
val leftVarsDecl =
516+
val streamedVarsDecl =
505517
s"""
506518
|boolean $isNull = false;
507519
|$javaType $value = $defaultValue;
508520
""".stripMargin
509521
(ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)),
510-
leftVarsDecl)
522+
streamedVarsDecl)
511523
} else {
512524
val code = code"$value = $valueCode;"
513-
val leftVarsDecl = s"""$javaType $value = $defaultValue;"""
514-
(ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl)
525+
val streamedVarsDecl = s"""$javaType $value = $defaultValue;"""
526+
(ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), streamedVarsDecl)
515527
}
516528
}.unzip
517529
}
518530

519-
/**
520-
* Creates the variables for right part of result row, using BoundReference, since the right
521-
* part are accessed inside the loop.
522-
*/
523-
private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = {
524-
ctx.INPUT_ROW = rightRow
525-
right.output.zipWithIndex.map { case (a, i) =>
526-
BoundReference(i, a.dataType, a.nullable).genCode(ctx)
527-
}
528-
}
529-
530531
/**
531532
* Splits variables based on whether it's used by condition or not, returns the code to create
532533
* these variables before the condition and after the condition.
@@ -554,62 +555,64 @@ case class SortMergeJoinExec(
554555

555556
override def doProduce(ctx: CodegenContext): String = {
556557
// Inline mutable state since not many join operations in a task
557-
val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput",
558+
val streamedInput = ctx.addMutableState("scala.collection.Iterator", "streamedInput",
558559
v => s"$v = inputs[0];", forceInline = true)
559-
val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput",
560+
val bufferedInput = ctx.addMutableState("scala.collection.Iterator", "bufferedInput",
560561
v => s"$v = inputs[1];", forceInline = true)
561562

562-
val (leftRow, matches) = genScanner(ctx)
563+
val (streamedRow, matches) = genScanner(ctx)
563564

564565
// Create variables for row from both sides.
565-
val (leftVars, leftVarDecl) = createLeftVars(ctx, leftRow)
566-
val rightRow = ctx.freshName("rightRow")
567-
val rightVars = createRightVar(ctx, rightRow)
566+
val (streamedVars, streamedVarDecl) = createStreamedVars(ctx, streamedRow)
567+
val bufferedRow = ctx.freshName("bufferedRow")
568+
val bufferedVars = genBuildSideVars(ctx, bufferedRow, bufferedPlan)
568569

569570
val iterator = ctx.freshName("iterator")
570571
val numOutput = metricTerm(ctx, "numOutputRows")
572+
val resultVars = streamedVars ++ bufferedVars
573+
571574
val (beforeLoop, condCheck) = if (condition.isDefined) {
572575
// Split the code of creating variables based on whether it's used by condition or not.
573576
val loaded = ctx.freshName("loaded")
574-
val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
575-
val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
577+
val (streamedBefore, streamedAfter) = splitVarsByCondition(streamedOutput, streamedVars)
578+
val (bufferedBefore, bufferedAfter) = splitVarsByCondition(bufferedOutput, bufferedVars)
576579
// Generate code for condition
577-
ctx.currentVars = leftVars ++ rightVars
580+
ctx.currentVars = resultVars
578581
val cond = BindReferences.bindReference(condition.get, output).genCode(ctx)
579582
// evaluate the columns those used by condition before loop
580583
val before = s"""
581584
|boolean $loaded = false;
582-
|$leftBefore
585+
|$streamedBefore
583586
""".stripMargin
584587

585588
val checking = s"""
586-
|$rightBefore
589+
|$bufferedBefore
587590
|${cond.code}
588591
|if (${cond.isNull} || !${cond.value}) continue;
589592
|if (!$loaded) {
590593
| $loaded = true;
591-
| $leftAfter
594+
| $streamedAfter
592595
|}
593-
|$rightAfter
596+
|$bufferedAfter
594597
""".stripMargin
595598
(before, checking)
596599
} else {
597-
(evaluateVariables(leftVars), "")
600+
(evaluateVariables(streamedVars), "")
598601
}
599602

600603
val thisPlan = ctx.addReferenceObj("plan", this)
601604
val eagerCleanup = s"$thisPlan.cleanupResources();"
602605

603606
s"""
604-
|while (findNextInnerJoinRows($leftInput, $rightInput)) {
605-
| ${leftVarDecl.mkString("\n")}
607+
|while (findNextJoinRows($streamedInput, $bufferedInput)) {
608+
| ${streamedVarDecl.mkString("\n")}
606609
| ${beforeLoop.trim}
607610
| scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
608611
| while ($iterator.hasNext()) {
609-
| InternalRow $rightRow = (InternalRow) $iterator.next();
612+
| InternalRow $bufferedRow = (InternalRow) $iterator.next();
610613
| ${condCheck.trim}
611614
| $numOutput.add(1);
612-
| ${consume(ctx, leftVars ++ rightVars)}
615+
| ${consume(ctx, resultVars)}
613616
| }
614617
| if (shouldStop()) return;
615618
|}

0 commit comments

Comments
 (0)