[SPARK-35363][SQL] Refactor sort merge join code-gen to be agnostic to join type

### What changes were proposed in this pull request?

This is a pre-requisite of https://github.com/apache/spark/pull/32476, as discussed in https://github.com/apache/spark/pull/32476#issuecomment-836469779. It refactors the sort merge join code-gen to use streamed/buffered terminology instead of left/right, which makes the code-gen agnostic to the join type so that it can be extended to support join types other than inner join.

### Why are the changes needed?

Pre-requisite of https://github.com/apache/spark/pull/32476.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing unit tests in `InnerJoinSuite.scala` cover the inner join code-gen.

Closes #32495 from c21/smj-refactor.

Authored-by: Cheng Su <chengsu@fb.com>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
This commit is contained in:
Cheng Su 2021-05-11 11:21:59 +09:00 committed by Takeshi Yamamuro
parent 44bd0a8bd3
commit c4ca23207b
2 changed files with 83 additions and 80 deletions

View file

@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClustered
* Holds common logic for join operators by shuffling two child relations * Holds common logic for join operators by shuffling two child relations
* using the join keys. * using the join keys.
*/ */
trait ShuffledJoin extends BaseJoinExec { trait ShuffledJoin extends JoinCodegenSupport {
def isSkewJoin: Boolean def isSkewJoin: Boolean
override def nodeName: String = { override def nodeName: String = {

View file

@ -40,7 +40,7 @@ case class SortMergeJoinExec(
condition: Option[Expression], condition: Option[Expression],
left: SparkPlan, left: SparkPlan,
right: SparkPlan, right: SparkPlan,
isSkewJoin: Boolean = false) extends ShuffledJoin with CodegenSupport { isSkewJoin: Boolean = false) extends ShuffledJoin {
override lazy val metrics = Map( override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@ -353,12 +353,22 @@ case class SortMergeJoinExec(
} }
} }
private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = joinType match {
case _: InnerLike => ((left, leftKeys), (right, rightKeys))
case x =>
throw new IllegalArgumentException(
s"SortMergeJoin.streamedPlan/bufferedPlan should not take $x as the JoinType")
}
private lazy val streamedOutput = streamedPlan.output
private lazy val bufferedOutput = bufferedPlan.output
override def supportCodegen: Boolean = { override def supportCodegen: Boolean = {
joinType.isInstanceOf[InnerLike] joinType.isInstanceOf[InnerLike]
} }
override def inputRDDs(): Seq[RDD[InternalRow]] = { override def inputRDDs(): Seq[RDD[InternalRow]] = {
left.execute() :: right.execute() :: Nil streamedPlan.execute() :: bufferedPlan.execute() :: Nil
} }
private def createJoinKey( private def createJoinKey(
@ -392,24 +402,24 @@ case class SortMergeJoinExec(
} }
/** /**
* Generate a function to scan both left and right to find a match, returns the term for * Generate a function to scan both sides to find a match, returns the term for
* matched one row from left side and buffered rows from right side. * matched one row from streamed side and buffered rows from buffered side.
*/ */
private def genScanner(ctx: CodegenContext): (String, String) = { private def genScanner(ctx: CodegenContext): (String, String) = {
// Create class member for next row from both sides. // Create class member for next row from both sides.
// Inline mutable state since not many join operations in a task // Inline mutable state since not many join operations in a task
val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true) val streamedRow = ctx.addMutableState("InternalRow", "streamedRow", forceInline = true)
val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true) val bufferedRow = ctx.addMutableState("InternalRow", "bufferedRow", forceInline = true)
// Create variables for join keys from both sides. // Create variables for join keys from both sides.
val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output) val streamedKeyVars = createJoinKey(ctx, streamedRow, streamedKeys, streamedOutput)
val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ") val streamedAnyNull = streamedKeyVars.map(_.isNull).mkString(" || ")
val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output) val bufferedKeyTmpVars = createJoinKey(ctx, bufferedRow, bufferedKeys, bufferedOutput)
val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ") val bufferedAnyNull = bufferedKeyTmpVars.map(_.isNull).mkString(" || ")
// Copy the right key as class members so they could be used in next function call. // Copy the buffered key as class members so they could be used in next function call.
val rightKeyVars = copyKeys(ctx, rightKeyTmpVars) val bufferedKeyVars = copyKeys(ctx, bufferedKeyTmpVars)
// A list to hold all matched rows from right side. // A list to hold all matched rows from buffered side.
val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName
val spillThreshold = getSpillThreshold val spillThreshold = getSpillThreshold
@ -418,26 +428,26 @@ case class SortMergeJoinExec(
// Inline mutable state since not many join operations in a task // Inline mutable state since not many join operations in a task
val matches = ctx.addMutableState(clsName, "matches", val matches = ctx.addMutableState(clsName, "matches",
v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true) v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true)
// Copy the left keys as class members so they could be used in next function call. // Copy the streamed keys as class members so they could be used in next function call.
val matchedKeyVars = copyKeys(ctx, leftKeyVars) val matchedKeyVars = copyKeys(ctx, streamedKeyVars)
ctx.addNewFunction("findNextInnerJoinRows", ctx.addNewFunction("findNextJoinRows",
s""" s"""
|private boolean findNextInnerJoinRows( |private boolean findNextJoinRows(
| scala.collection.Iterator leftIter, | scala.collection.Iterator streamedIter,
| scala.collection.Iterator rightIter) { | scala.collection.Iterator bufferedIter) {
| $leftRow = null; | $streamedRow = null;
| int comp = 0; | int comp = 0;
| while ($leftRow == null) { | while ($streamedRow == null) {
| if (!leftIter.hasNext()) return false; | if (!streamedIter.hasNext()) return false;
| $leftRow = (InternalRow) leftIter.next(); | $streamedRow = (InternalRow) streamedIter.next();
| ${leftKeyVars.map(_.code).mkString("\n")} | ${streamedKeyVars.map(_.code).mkString("\n")}
| if ($leftAnyNull) { | if ($streamedAnyNull) {
| $leftRow = null; | $streamedRow = null;
| continue; | continue;
| } | }
| if (!$matches.isEmpty()) { | if (!$matches.isEmpty()) {
| ${genComparison(ctx, leftKeyVars, matchedKeyVars)} | ${genComparison(ctx, streamedKeyVars, matchedKeyVars)}
| if (comp == 0) { | if (comp == 0) {
| return true; | return true;
| } | }
@ -445,88 +455,79 @@ case class SortMergeJoinExec(
| } | }
| |
| do { | do {
| if ($rightRow == null) { | if ($bufferedRow == null) {
| if (!rightIter.hasNext()) { | if (!bufferedIter.hasNext()) {
| ${matchedKeyVars.map(_.code).mkString("\n")} | ${matchedKeyVars.map(_.code).mkString("\n")}
| return !$matches.isEmpty(); | return !$matches.isEmpty();
| } | }
| $rightRow = (InternalRow) rightIter.next(); | $bufferedRow = (InternalRow) bufferedIter.next();
| ${rightKeyTmpVars.map(_.code).mkString("\n")} | ${bufferedKeyTmpVars.map(_.code).mkString("\n")}
| if ($rightAnyNull) { | if ($bufferedAnyNull) {
| $rightRow = null; | $bufferedRow = null;
| continue; | continue;
| } | }
| ${rightKeyVars.map(_.code).mkString("\n")} | ${bufferedKeyVars.map(_.code).mkString("\n")}
| } | }
| ${genComparison(ctx, leftKeyVars, rightKeyVars)} | ${genComparison(ctx, streamedKeyVars, bufferedKeyVars)}
| if (comp > 0) { | if (comp > 0) {
| $rightRow = null; | $bufferedRow = null;
| } else if (comp < 0) { | } else if (comp < 0) {
| if (!$matches.isEmpty()) { | if (!$matches.isEmpty()) {
| ${matchedKeyVars.map(_.code).mkString("\n")} | ${matchedKeyVars.map(_.code).mkString("\n")}
| return true; | return true;
| } | }
| $leftRow = null; | $streamedRow = null;
| } else { | } else {
| $matches.add((UnsafeRow) $rightRow); | $matches.add((UnsafeRow) $bufferedRow);
| $rightRow = null; | $bufferedRow = null;
| } | }
| } while ($leftRow != null); | } while ($streamedRow != null);
| } | }
| return false; // unreachable | return false; // unreachable
|} |}
""".stripMargin, inlineToOuterClass = true) """.stripMargin, inlineToOuterClass = true)
(leftRow, matches) (streamedRow, matches)
} }
/** /**
* Creates variables and declarations for left part of result row. * Creates variables and declarations for streamed part of result row.
* *
* In order to defer the access after condition and also only access once in the loop, * In order to defer the access after condition and also only access once in the loop,
* the variables should be declared separately from accessing the columns, we can't use the * the variables should be declared separately from accessing the columns, we can't use the
* codegen of BoundReference here. * codegen of BoundReference here.
*/ */
private def createLeftVars(ctx: CodegenContext, leftRow: String): (Seq[ExprCode], Seq[String]) = { private def createStreamedVars(
ctx.INPUT_ROW = leftRow ctx: CodegenContext,
streamedRow: String): (Seq[ExprCode], Seq[String]) = {
ctx.INPUT_ROW = streamedRow
left.output.zipWithIndex.map { case (a, i) => left.output.zipWithIndex.map { case (a, i) =>
val value = ctx.freshName("value") val value = ctx.freshName("value")
val valueCode = CodeGenerator.getValue(leftRow, a.dataType, i.toString) val valueCode = CodeGenerator.getValue(streamedRow, a.dataType, i.toString)
val javaType = CodeGenerator.javaType(a.dataType) val javaType = CodeGenerator.javaType(a.dataType)
val defaultValue = CodeGenerator.defaultValue(a.dataType) val defaultValue = CodeGenerator.defaultValue(a.dataType)
if (a.nullable) { if (a.nullable) {
val isNull = ctx.freshName("isNull") val isNull = ctx.freshName("isNull")
val code = val code =
code""" code"""
|$isNull = $leftRow.isNullAt($i); |$isNull = $streamedRow.isNullAt($i);
|$value = $isNull ? $defaultValue : ($valueCode); |$value = $isNull ? $defaultValue : ($valueCode);
""".stripMargin """.stripMargin
val leftVarsDecl = val streamedVarsDecl =
s""" s"""
|boolean $isNull = false; |boolean $isNull = false;
|$javaType $value = $defaultValue; |$javaType $value = $defaultValue;
""".stripMargin """.stripMargin
(ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)), (ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)),
leftVarsDecl) streamedVarsDecl)
} else { } else {
val code = code"$value = $valueCode;" val code = code"$value = $valueCode;"
val leftVarsDecl = s"""$javaType $value = $defaultValue;""" val streamedVarsDecl = s"""$javaType $value = $defaultValue;"""
(ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl) (ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), streamedVarsDecl)
} }
}.unzip }.unzip
} }
/**
* Creates the variables for right part of result row, using BoundReference, since the right
* part are accessed inside the loop.
*/
private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = {
ctx.INPUT_ROW = rightRow
right.output.zipWithIndex.map { case (a, i) =>
BoundReference(i, a.dataType, a.nullable).genCode(ctx)
}
}
/** /**
* Splits variables based on whether it's used by condition or not, returns the code to create * Splits variables based on whether it's used by condition or not, returns the code to create
* these variables before the condition and after the condition. * these variables before the condition and after the condition.
@ -554,62 +555,64 @@ case class SortMergeJoinExec(
override def doProduce(ctx: CodegenContext): String = { override def doProduce(ctx: CodegenContext): String = {
// Inline mutable state since not many join operations in a task // Inline mutable state since not many join operations in a task
val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput", val streamedInput = ctx.addMutableState("scala.collection.Iterator", "streamedInput",
v => s"$v = inputs[0];", forceInline = true) v => s"$v = inputs[0];", forceInline = true)
val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput", val bufferedInput = ctx.addMutableState("scala.collection.Iterator", "bufferedInput",
v => s"$v = inputs[1];", forceInline = true) v => s"$v = inputs[1];", forceInline = true)
val (leftRow, matches) = genScanner(ctx) val (streamedRow, matches) = genScanner(ctx)
// Create variables for row from both sides. // Create variables for row from both sides.
val (leftVars, leftVarDecl) = createLeftVars(ctx, leftRow) val (streamedVars, streamedVarDecl) = createStreamedVars(ctx, streamedRow)
val rightRow = ctx.freshName("rightRow") val bufferedRow = ctx.freshName("bufferedRow")
val rightVars = createRightVar(ctx, rightRow) val bufferedVars = genBuildSideVars(ctx, bufferedRow, bufferedPlan)
val iterator = ctx.freshName("iterator") val iterator = ctx.freshName("iterator")
val numOutput = metricTerm(ctx, "numOutputRows") val numOutput = metricTerm(ctx, "numOutputRows")
val resultVars = streamedVars ++ bufferedVars
val (beforeLoop, condCheck) = if (condition.isDefined) { val (beforeLoop, condCheck) = if (condition.isDefined) {
// Split the code of creating variables based on whether it's used by condition or not. // Split the code of creating variables based on whether it's used by condition or not.
val loaded = ctx.freshName("loaded") val loaded = ctx.freshName("loaded")
val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars) val (streamedBefore, streamedAfter) = splitVarsByCondition(streamedOutput, streamedVars)
val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars) val (bufferedBefore, bufferedAfter) = splitVarsByCondition(bufferedOutput, bufferedVars)
// Generate code for condition // Generate code for condition
ctx.currentVars = leftVars ++ rightVars ctx.currentVars = resultVars
val cond = BindReferences.bindReference(condition.get, output).genCode(ctx) val cond = BindReferences.bindReference(condition.get, output).genCode(ctx)
// evaluate the columns those used by condition before loop // evaluate the columns those used by condition before loop
val before = s""" val before = s"""
|boolean $loaded = false; |boolean $loaded = false;
|$leftBefore |$streamedBefore
""".stripMargin """.stripMargin
val checking = s""" val checking = s"""
|$rightBefore |$bufferedBefore
|${cond.code} |${cond.code}
|if (${cond.isNull} || !${cond.value}) continue; |if (${cond.isNull} || !${cond.value}) continue;
|if (!$loaded) { |if (!$loaded) {
| $loaded = true; | $loaded = true;
| $leftAfter | $streamedAfter
|} |}
|$rightAfter |$bufferedAfter
""".stripMargin """.stripMargin
(before, checking) (before, checking)
} else { } else {
(evaluateVariables(leftVars), "") (evaluateVariables(streamedVars), "")
} }
val thisPlan = ctx.addReferenceObj("plan", this) val thisPlan = ctx.addReferenceObj("plan", this)
val eagerCleanup = s"$thisPlan.cleanupResources();" val eagerCleanup = s"$thisPlan.cleanupResources();"
s""" s"""
|while (findNextInnerJoinRows($leftInput, $rightInput)) { |while (findNextJoinRows($streamedInput, $bufferedInput)) {
| ${leftVarDecl.mkString("\n")} | ${streamedVarDecl.mkString("\n")}
| ${beforeLoop.trim} | ${beforeLoop.trim}
| scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator(); | scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
| while ($iterator.hasNext()) { | while ($iterator.hasNext()) {
| InternalRow $rightRow = (InternalRow) $iterator.next(); | InternalRow $bufferedRow = (InternalRow) $iterator.next();
| ${condCheck.trim} | ${condCheck.trim}
| $numOutput.add(1); | $numOutput.add(1);
| ${consume(ctx, leftVars ++ rightVars)} | ${consume(ctx, resultVars)}
| } | }
| if (shouldStop()) return; | if (shouldStop()) return;
|} |}