[SPARK-35363][SQL] Refactor sort merge join code-gen be agnostic to join type
### What changes were proposed in this pull request? This is a pre-requisite of https://github.com/apache/spark/pull/32476, in discussion of https://github.com/apache/spark/pull/32476#issuecomment-836469779 . This is to refactor sort merge join code-gen to depend on streamed/buffered terminology, which makes the code-gen agnostic to different join types and can be extended to support other join types than inner join. ### Why are the changes needed? Pre-requisite of https://github.com/apache/spark/pull/32476. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing unit test in `InnerJoinSuite.scala` for inner join code-gen. Closes #32495 from c21/smj-refactor. Authored-by: Cheng Su <chengsu@fb.com> Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
This commit is contained in:
parent
44bd0a8bd3
commit
c4ca23207b
|
@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClustered
|
||||||
* Holds common logic for join operators by shuffling two child relations
|
* Holds common logic for join operators by shuffling two child relations
|
||||||
* using the join keys.
|
* using the join keys.
|
||||||
*/
|
*/
|
||||||
trait ShuffledJoin extends BaseJoinExec {
|
trait ShuffledJoin extends JoinCodegenSupport {
|
||||||
def isSkewJoin: Boolean
|
def isSkewJoin: Boolean
|
||||||
|
|
||||||
override def nodeName: String = {
|
override def nodeName: String = {
|
||||||
|
|
|
@ -40,7 +40,7 @@ case class SortMergeJoinExec(
|
||||||
condition: Option[Expression],
|
condition: Option[Expression],
|
||||||
left: SparkPlan,
|
left: SparkPlan,
|
||||||
right: SparkPlan,
|
right: SparkPlan,
|
||||||
isSkewJoin: Boolean = false) extends ShuffledJoin with CodegenSupport {
|
isSkewJoin: Boolean = false) extends ShuffledJoin {
|
||||||
|
|
||||||
override lazy val metrics = Map(
|
override lazy val metrics = Map(
|
||||||
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
|
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
|
||||||
|
@ -353,12 +353,22 @@ case class SortMergeJoinExec(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = joinType match {
|
||||||
|
case _: InnerLike => ((left, leftKeys), (right, rightKeys))
|
||||||
|
case x =>
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
s"SortMergeJoin.streamedPlan/bufferedPlan should not take $x as the JoinType")
|
||||||
|
}
|
||||||
|
|
||||||
|
private lazy val streamedOutput = streamedPlan.output
|
||||||
|
private lazy val bufferedOutput = bufferedPlan.output
|
||||||
|
|
||||||
override def supportCodegen: Boolean = {
|
override def supportCodegen: Boolean = {
|
||||||
joinType.isInstanceOf[InnerLike]
|
joinType.isInstanceOf[InnerLike]
|
||||||
}
|
}
|
||||||
|
|
||||||
override def inputRDDs(): Seq[RDD[InternalRow]] = {
|
override def inputRDDs(): Seq[RDD[InternalRow]] = {
|
||||||
left.execute() :: right.execute() :: Nil
|
streamedPlan.execute() :: bufferedPlan.execute() :: Nil
|
||||||
}
|
}
|
||||||
|
|
||||||
private def createJoinKey(
|
private def createJoinKey(
|
||||||
|
@ -392,24 +402,24 @@ case class SortMergeJoinExec(
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate a function to scan both left and right to find a match, returns the term for
|
* Generate a function to scan both sides to find a match, returns the term for
|
||||||
* matched one row from left side and buffered rows from right side.
|
* matched one row from streamed side and buffered rows from buffered side.
|
||||||
*/
|
*/
|
||||||
private def genScanner(ctx: CodegenContext): (String, String) = {
|
private def genScanner(ctx: CodegenContext): (String, String) = {
|
||||||
// Create class member for next row from both sides.
|
// Create class member for next row from both sides.
|
||||||
// Inline mutable state since not many join operations in a task
|
// Inline mutable state since not many join operations in a task
|
||||||
val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true)
|
val streamedRow = ctx.addMutableState("InternalRow", "streamedRow", forceInline = true)
|
||||||
val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true)
|
val bufferedRow = ctx.addMutableState("InternalRow", "bufferedRow", forceInline = true)
|
||||||
|
|
||||||
// Create variables for join keys from both sides.
|
// Create variables for join keys from both sides.
|
||||||
val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output)
|
val streamedKeyVars = createJoinKey(ctx, streamedRow, streamedKeys, streamedOutput)
|
||||||
val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ")
|
val streamedAnyNull = streamedKeyVars.map(_.isNull).mkString(" || ")
|
||||||
val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output)
|
val bufferedKeyTmpVars = createJoinKey(ctx, bufferedRow, bufferedKeys, bufferedOutput)
|
||||||
val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ")
|
val bufferedAnyNull = bufferedKeyTmpVars.map(_.isNull).mkString(" || ")
|
||||||
// Copy the right key as class members so they could be used in next function call.
|
// Copy the buffered key as class members so they could be used in next function call.
|
||||||
val rightKeyVars = copyKeys(ctx, rightKeyTmpVars)
|
val bufferedKeyVars = copyKeys(ctx, bufferedKeyTmpVars)
|
||||||
|
|
||||||
// A list to hold all matched rows from right side.
|
// A list to hold all matched rows from buffered side.
|
||||||
val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName
|
val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName
|
||||||
|
|
||||||
val spillThreshold = getSpillThreshold
|
val spillThreshold = getSpillThreshold
|
||||||
|
@ -418,26 +428,26 @@ case class SortMergeJoinExec(
|
||||||
// Inline mutable state since not many join operations in a task
|
// Inline mutable state since not many join operations in a task
|
||||||
val matches = ctx.addMutableState(clsName, "matches",
|
val matches = ctx.addMutableState(clsName, "matches",
|
||||||
v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true)
|
v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true)
|
||||||
// Copy the left keys as class members so they could be used in next function call.
|
// Copy the streamed keys as class members so they could be used in next function call.
|
||||||
val matchedKeyVars = copyKeys(ctx, leftKeyVars)
|
val matchedKeyVars = copyKeys(ctx, streamedKeyVars)
|
||||||
|
|
||||||
ctx.addNewFunction("findNextInnerJoinRows",
|
ctx.addNewFunction("findNextJoinRows",
|
||||||
s"""
|
s"""
|
||||||
|private boolean findNextInnerJoinRows(
|
|private boolean findNextJoinRows(
|
||||||
| scala.collection.Iterator leftIter,
|
| scala.collection.Iterator streamedIter,
|
||||||
| scala.collection.Iterator rightIter) {
|
| scala.collection.Iterator bufferedIter) {
|
||||||
| $leftRow = null;
|
| $streamedRow = null;
|
||||||
| int comp = 0;
|
| int comp = 0;
|
||||||
| while ($leftRow == null) {
|
| while ($streamedRow == null) {
|
||||||
| if (!leftIter.hasNext()) return false;
|
| if (!streamedIter.hasNext()) return false;
|
||||||
| $leftRow = (InternalRow) leftIter.next();
|
| $streamedRow = (InternalRow) streamedIter.next();
|
||||||
| ${leftKeyVars.map(_.code).mkString("\n")}
|
| ${streamedKeyVars.map(_.code).mkString("\n")}
|
||||||
| if ($leftAnyNull) {
|
| if ($streamedAnyNull) {
|
||||||
| $leftRow = null;
|
| $streamedRow = null;
|
||||||
| continue;
|
| continue;
|
||||||
| }
|
| }
|
||||||
| if (!$matches.isEmpty()) {
|
| if (!$matches.isEmpty()) {
|
||||||
| ${genComparison(ctx, leftKeyVars, matchedKeyVars)}
|
| ${genComparison(ctx, streamedKeyVars, matchedKeyVars)}
|
||||||
| if (comp == 0) {
|
| if (comp == 0) {
|
||||||
| return true;
|
| return true;
|
||||||
| }
|
| }
|
||||||
|
@ -445,88 +455,79 @@ case class SortMergeJoinExec(
|
||||||
| }
|
| }
|
||||||
|
|
|
|
||||||
| do {
|
| do {
|
||||||
| if ($rightRow == null) {
|
| if ($bufferedRow == null) {
|
||||||
| if (!rightIter.hasNext()) {
|
| if (!bufferedIter.hasNext()) {
|
||||||
| ${matchedKeyVars.map(_.code).mkString("\n")}
|
| ${matchedKeyVars.map(_.code).mkString("\n")}
|
||||||
| return !$matches.isEmpty();
|
| return !$matches.isEmpty();
|
||||||
| }
|
| }
|
||||||
| $rightRow = (InternalRow) rightIter.next();
|
| $bufferedRow = (InternalRow) bufferedIter.next();
|
||||||
| ${rightKeyTmpVars.map(_.code).mkString("\n")}
|
| ${bufferedKeyTmpVars.map(_.code).mkString("\n")}
|
||||||
| if ($rightAnyNull) {
|
| if ($bufferedAnyNull) {
|
||||||
| $rightRow = null;
|
| $bufferedRow = null;
|
||||||
| continue;
|
| continue;
|
||||||
| }
|
| }
|
||||||
| ${rightKeyVars.map(_.code).mkString("\n")}
|
| ${bufferedKeyVars.map(_.code).mkString("\n")}
|
||||||
| }
|
| }
|
||||||
| ${genComparison(ctx, leftKeyVars, rightKeyVars)}
|
| ${genComparison(ctx, streamedKeyVars, bufferedKeyVars)}
|
||||||
| if (comp > 0) {
|
| if (comp > 0) {
|
||||||
| $rightRow = null;
|
| $bufferedRow = null;
|
||||||
| } else if (comp < 0) {
|
| } else if (comp < 0) {
|
||||||
| if (!$matches.isEmpty()) {
|
| if (!$matches.isEmpty()) {
|
||||||
| ${matchedKeyVars.map(_.code).mkString("\n")}
|
| ${matchedKeyVars.map(_.code).mkString("\n")}
|
||||||
| return true;
|
| return true;
|
||||||
| }
|
| }
|
||||||
| $leftRow = null;
|
| $streamedRow = null;
|
||||||
| } else {
|
| } else {
|
||||||
| $matches.add((UnsafeRow) $rightRow);
|
| $matches.add((UnsafeRow) $bufferedRow);
|
||||||
| $rightRow = null;
|
| $bufferedRow = null;
|
||||||
| }
|
| }
|
||||||
| } while ($leftRow != null);
|
| } while ($streamedRow != null);
|
||||||
| }
|
| }
|
||||||
| return false; // unreachable
|
| return false; // unreachable
|
||||||
|}
|
|}
|
||||||
""".stripMargin, inlineToOuterClass = true)
|
""".stripMargin, inlineToOuterClass = true)
|
||||||
|
|
||||||
(leftRow, matches)
|
(streamedRow, matches)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates variables and declarations for left part of result row.
|
* Creates variables and declarations for streamed part of result row.
|
||||||
*
|
*
|
||||||
* In order to defer the access after condition and also only access once in the loop,
|
* In order to defer the access after condition and also only access once in the loop,
|
||||||
* the variables should be declared separately from accessing the columns, we can't use the
|
* the variables should be declared separately from accessing the columns, we can't use the
|
||||||
* codegen of BoundReference here.
|
* codegen of BoundReference here.
|
||||||
*/
|
*/
|
||||||
private def createLeftVars(ctx: CodegenContext, leftRow: String): (Seq[ExprCode], Seq[String]) = {
|
private def createStreamedVars(
|
||||||
ctx.INPUT_ROW = leftRow
|
ctx: CodegenContext,
|
||||||
|
streamedRow: String): (Seq[ExprCode], Seq[String]) = {
|
||||||
|
ctx.INPUT_ROW = streamedRow
|
||||||
left.output.zipWithIndex.map { case (a, i) =>
|
left.output.zipWithIndex.map { case (a, i) =>
|
||||||
val value = ctx.freshName("value")
|
val value = ctx.freshName("value")
|
||||||
val valueCode = CodeGenerator.getValue(leftRow, a.dataType, i.toString)
|
val valueCode = CodeGenerator.getValue(streamedRow, a.dataType, i.toString)
|
||||||
val javaType = CodeGenerator.javaType(a.dataType)
|
val javaType = CodeGenerator.javaType(a.dataType)
|
||||||
val defaultValue = CodeGenerator.defaultValue(a.dataType)
|
val defaultValue = CodeGenerator.defaultValue(a.dataType)
|
||||||
if (a.nullable) {
|
if (a.nullable) {
|
||||||
val isNull = ctx.freshName("isNull")
|
val isNull = ctx.freshName("isNull")
|
||||||
val code =
|
val code =
|
||||||
code"""
|
code"""
|
||||||
|$isNull = $leftRow.isNullAt($i);
|
|$isNull = $streamedRow.isNullAt($i);
|
||||||
|$value = $isNull ? $defaultValue : ($valueCode);
|
|$value = $isNull ? $defaultValue : ($valueCode);
|
||||||
""".stripMargin
|
""".stripMargin
|
||||||
val leftVarsDecl =
|
val streamedVarsDecl =
|
||||||
s"""
|
s"""
|
||||||
|boolean $isNull = false;
|
|boolean $isNull = false;
|
||||||
|$javaType $value = $defaultValue;
|
|$javaType $value = $defaultValue;
|
||||||
""".stripMargin
|
""".stripMargin
|
||||||
(ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)),
|
(ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)),
|
||||||
leftVarsDecl)
|
streamedVarsDecl)
|
||||||
} else {
|
} else {
|
||||||
val code = code"$value = $valueCode;"
|
val code = code"$value = $valueCode;"
|
||||||
val leftVarsDecl = s"""$javaType $value = $defaultValue;"""
|
val streamedVarsDecl = s"""$javaType $value = $defaultValue;"""
|
||||||
(ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl)
|
(ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), streamedVarsDecl)
|
||||||
}
|
}
|
||||||
}.unzip
|
}.unzip
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates the variables for right part of result row, using BoundReference, since the right
|
|
||||||
* part are accessed inside the loop.
|
|
||||||
*/
|
|
||||||
private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = {
|
|
||||||
ctx.INPUT_ROW = rightRow
|
|
||||||
right.output.zipWithIndex.map { case (a, i) =>
|
|
||||||
BoundReference(i, a.dataType, a.nullable).genCode(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Splits variables based on whether it's used by condition or not, returns the code to create
|
* Splits variables based on whether it's used by condition or not, returns the code to create
|
||||||
* these variables before the condition and after the condition.
|
* these variables before the condition and after the condition.
|
||||||
|
@ -554,62 +555,64 @@ case class SortMergeJoinExec(
|
||||||
|
|
||||||
override def doProduce(ctx: CodegenContext): String = {
|
override def doProduce(ctx: CodegenContext): String = {
|
||||||
// Inline mutable state since not many join operations in a task
|
// Inline mutable state since not many join operations in a task
|
||||||
val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput",
|
val streamedInput = ctx.addMutableState("scala.collection.Iterator", "streamedInput",
|
||||||
v => s"$v = inputs[0];", forceInline = true)
|
v => s"$v = inputs[0];", forceInline = true)
|
||||||
val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput",
|
val bufferedInput = ctx.addMutableState("scala.collection.Iterator", "bufferedInput",
|
||||||
v => s"$v = inputs[1];", forceInline = true)
|
v => s"$v = inputs[1];", forceInline = true)
|
||||||
|
|
||||||
val (leftRow, matches) = genScanner(ctx)
|
val (streamedRow, matches) = genScanner(ctx)
|
||||||
|
|
||||||
// Create variables for row from both sides.
|
// Create variables for row from both sides.
|
||||||
val (leftVars, leftVarDecl) = createLeftVars(ctx, leftRow)
|
val (streamedVars, streamedVarDecl) = createStreamedVars(ctx, streamedRow)
|
||||||
val rightRow = ctx.freshName("rightRow")
|
val bufferedRow = ctx.freshName("bufferedRow")
|
||||||
val rightVars = createRightVar(ctx, rightRow)
|
val bufferedVars = genBuildSideVars(ctx, bufferedRow, bufferedPlan)
|
||||||
|
|
||||||
val iterator = ctx.freshName("iterator")
|
val iterator = ctx.freshName("iterator")
|
||||||
val numOutput = metricTerm(ctx, "numOutputRows")
|
val numOutput = metricTerm(ctx, "numOutputRows")
|
||||||
|
val resultVars = streamedVars ++ bufferedVars
|
||||||
|
|
||||||
val (beforeLoop, condCheck) = if (condition.isDefined) {
|
val (beforeLoop, condCheck) = if (condition.isDefined) {
|
||||||
// Split the code of creating variables based on whether it's used by condition or not.
|
// Split the code of creating variables based on whether it's used by condition or not.
|
||||||
val loaded = ctx.freshName("loaded")
|
val loaded = ctx.freshName("loaded")
|
||||||
val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
|
val (streamedBefore, streamedAfter) = splitVarsByCondition(streamedOutput, streamedVars)
|
||||||
val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
|
val (bufferedBefore, bufferedAfter) = splitVarsByCondition(bufferedOutput, bufferedVars)
|
||||||
// Generate code for condition
|
// Generate code for condition
|
||||||
ctx.currentVars = leftVars ++ rightVars
|
ctx.currentVars = resultVars
|
||||||
val cond = BindReferences.bindReference(condition.get, output).genCode(ctx)
|
val cond = BindReferences.bindReference(condition.get, output).genCode(ctx)
|
||||||
// evaluate the columns those used by condition before loop
|
// evaluate the columns those used by condition before loop
|
||||||
val before = s"""
|
val before = s"""
|
||||||
|boolean $loaded = false;
|
|boolean $loaded = false;
|
||||||
|$leftBefore
|
|$streamedBefore
|
||||||
""".stripMargin
|
""".stripMargin
|
||||||
|
|
||||||
val checking = s"""
|
val checking = s"""
|
||||||
|$rightBefore
|
|$bufferedBefore
|
||||||
|${cond.code}
|
|${cond.code}
|
||||||
|if (${cond.isNull} || !${cond.value}) continue;
|
|if (${cond.isNull} || !${cond.value}) continue;
|
||||||
|if (!$loaded) {
|
|if (!$loaded) {
|
||||||
| $loaded = true;
|
| $loaded = true;
|
||||||
| $leftAfter
|
| $streamedAfter
|
||||||
|}
|
|}
|
||||||
|$rightAfter
|
|$bufferedAfter
|
||||||
""".stripMargin
|
""".stripMargin
|
||||||
(before, checking)
|
(before, checking)
|
||||||
} else {
|
} else {
|
||||||
(evaluateVariables(leftVars), "")
|
(evaluateVariables(streamedVars), "")
|
||||||
}
|
}
|
||||||
|
|
||||||
val thisPlan = ctx.addReferenceObj("plan", this)
|
val thisPlan = ctx.addReferenceObj("plan", this)
|
||||||
val eagerCleanup = s"$thisPlan.cleanupResources();"
|
val eagerCleanup = s"$thisPlan.cleanupResources();"
|
||||||
|
|
||||||
s"""
|
s"""
|
||||||
|while (findNextInnerJoinRows($leftInput, $rightInput)) {
|
|while (findNextJoinRows($streamedInput, $bufferedInput)) {
|
||||||
| ${leftVarDecl.mkString("\n")}
|
| ${streamedVarDecl.mkString("\n")}
|
||||||
| ${beforeLoop.trim}
|
| ${beforeLoop.trim}
|
||||||
| scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
|
| scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
|
||||||
| while ($iterator.hasNext()) {
|
| while ($iterator.hasNext()) {
|
||||||
| InternalRow $rightRow = (InternalRow) $iterator.next();
|
| InternalRow $bufferedRow = (InternalRow) $iterator.next();
|
||||||
| ${condCheck.trim}
|
| ${condCheck.trim}
|
||||||
| $numOutput.add(1);
|
| $numOutput.add(1);
|
||||||
| ${consume(ctx, leftVars ++ rightVars)}
|
| ${consume(ctx, resultVars)}
|
||||||
| }
|
| }
|
||||||
| if (shouldStop()) return;
|
| if (shouldStop()) return;
|
||||||
|}
|
|}
|
||||||
|
|
Loading…
Reference in a new issue