diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala index a8d8050149..794807fd3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClustered * Holds common logic for join operators by shuffling two child relations * using the join keys. */ -trait ShuffledJoin extends BaseJoinExec { +trait ShuffledJoin extends JoinCodegenSupport { def isSkewJoin: Boolean override def nodeName: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index ad4dee7c9a..74371f2005 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -40,7 +40,7 @@ case class SortMergeJoinExec( condition: Option[Expression], left: SparkPlan, right: SparkPlan, - isSkewJoin: Boolean = false) extends ShuffledJoin with CodegenSupport { + isSkewJoin: Boolean = false) extends ShuffledJoin { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -353,12 +353,22 @@ case class SortMergeJoinExec( } } + private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = joinType match { + case _: InnerLike => ((left, leftKeys), (right, rightKeys)) + case x => + throw new IllegalArgumentException( + s"SortMergeJoin.streamedPlan/bufferedPlan should not take $x as the JoinType") + } + + private lazy val streamedOutput = streamedPlan.output + private lazy val bufferedOutput = bufferedPlan.output + override def supportCodegen: Boolean = { joinType.isInstanceOf[InnerLike] } override def inputRDDs(): Seq[RDD[InternalRow]] = { - left.execute() :: right.execute() :: Nil + streamedPlan.execute() :: bufferedPlan.execute() :: Nil } private def createJoinKey( @@ -392,24 +402,24 @@ case class SortMergeJoinExec( } /** - * Generate a function to scan both left and right to find a match, returns the term for - * matched one row from left side and buffered rows from right side. + * Generate a function to scan both sides to find a match, returns the term for + * matched one row from streamed side and buffered rows from buffered side. */ private def genScanner(ctx: CodegenContext): (String, String) = { // Create class member for next row from both sides. // Inline mutable state since not many join operations in a task - val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true) - val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true) + val streamedRow = ctx.addMutableState("InternalRow", "streamedRow", forceInline = true) + val bufferedRow = ctx.addMutableState("InternalRow", "bufferedRow", forceInline = true) // Create variables for join keys from both sides. - val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output) - val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ") - val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output) - val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ") - // Copy the right key as class members so they could be used in next function call. - val rightKeyVars = copyKeys(ctx, rightKeyTmpVars) + val streamedKeyVars = createJoinKey(ctx, streamedRow, streamedKeys, streamedOutput) + val streamedAnyNull = streamedKeyVars.map(_.isNull).mkString(" || ") + val bufferedKeyTmpVars = createJoinKey(ctx, bufferedRow, bufferedKeys, bufferedOutput) + val bufferedAnyNull = bufferedKeyTmpVars.map(_.isNull).mkString(" || ") + // Copy the buffered key as class members so they could be used in next function call. + val bufferedKeyVars = copyKeys(ctx, bufferedKeyTmpVars) - // A list to hold all matched rows from right side. + // A list to hold all matched rows from buffered side. val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName val spillThreshold = getSpillThreshold @@ -418,26 +428,26 @@ case class SortMergeJoinExec( // Inline mutable state since not many join operations in a task val matches = ctx.addMutableState(clsName, "matches", v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true) - // Copy the left keys as class members so they could be used in next function call. - val matchedKeyVars = copyKeys(ctx, leftKeyVars) + // Copy the streamed keys as class members so they could be used in next function call. + val matchedKeyVars = copyKeys(ctx, streamedKeyVars) - ctx.addNewFunction("findNextInnerJoinRows", + ctx.addNewFunction("findNextJoinRows", s""" - |private boolean findNextInnerJoinRows( - | scala.collection.Iterator leftIter, - | scala.collection.Iterator rightIter) { - | $leftRow = null; + |private boolean findNextJoinRows( + | scala.collection.Iterator streamedIter, + | scala.collection.Iterator bufferedIter) { + | $streamedRow = null; | int comp = 0; - | while ($leftRow == null) { - | if (!leftIter.hasNext()) return false; - | $leftRow = (InternalRow) leftIter.next(); - | ${leftKeyVars.map(_.code).mkString("\n")} - | if ($leftAnyNull) { - | $leftRow = null; + | while ($streamedRow == null) { + | if (!streamedIter.hasNext()) return false; + | $streamedRow = (InternalRow) streamedIter.next(); + | ${streamedKeyVars.map(_.code).mkString("\n")} + | if ($streamedAnyNull) { + | $streamedRow = null; | continue; | } | if (!$matches.isEmpty()) { - | ${genComparison(ctx, leftKeyVars, matchedKeyVars)} + | ${genComparison(ctx, streamedKeyVars, matchedKeyVars)} | if (comp == 0) { | return true; | } @@ -445,88 +455,79 @@ case class SortMergeJoinExec( | } | | do { - | if ($rightRow == null) { - | if (!rightIter.hasNext()) { + | if ($bufferedRow == null) { + | if (!bufferedIter.hasNext()) { | ${matchedKeyVars.map(_.code).mkString("\n")} | return !$matches.isEmpty(); | } - | $rightRow = (InternalRow) rightIter.next(); - | ${rightKeyTmpVars.map(_.code).mkString("\n")} - | if ($rightAnyNull) { - | $rightRow = null; + | $bufferedRow = (InternalRow) bufferedIter.next(); + | ${bufferedKeyTmpVars.map(_.code).mkString("\n")} + | if ($bufferedAnyNull) { + | $bufferedRow = null; | continue; | } - | ${rightKeyVars.map(_.code).mkString("\n")} + | ${bufferedKeyVars.map(_.code).mkString("\n")} | } - | ${genComparison(ctx, leftKeyVars, rightKeyVars)} + | ${genComparison(ctx, streamedKeyVars, bufferedKeyVars)} | if (comp > 0) { - | $rightRow = null; + | $bufferedRow = null; | } else if (comp < 0) { | if (!$matches.isEmpty()) { | ${matchedKeyVars.map(_.code).mkString("\n")} | return true; | } - | $leftRow = null; + | $streamedRow = null; | } else { - | $matches.add((UnsafeRow) $rightRow); - | $rightRow = null; + | $matches.add((UnsafeRow) $bufferedRow); + | $bufferedRow = null; | } - | } while ($leftRow != null); + | } while ($streamedRow != null); | } | return false; // unreachable |} """.stripMargin, inlineToOuterClass = true) - (leftRow, matches) + (streamedRow, matches) } /** - * Creates variables and declarations for left part of result row. + * Creates variables and declarations for streamed part of result row. * * In order to defer the access after condition and also only access once in the loop, * the variables should be declared separately from accessing the columns, we can't use the * codegen of BoundReference here. */ - private def createLeftVars(ctx: CodegenContext, leftRow: String): (Seq[ExprCode], Seq[String]) = { - ctx.INPUT_ROW = leftRow + private def createStreamedVars( + ctx: CodegenContext, + streamedRow: String): (Seq[ExprCode], Seq[String]) = { + ctx.INPUT_ROW = streamedRow left.output.zipWithIndex.map { case (a, i) => val value = ctx.freshName("value") - val valueCode = CodeGenerator.getValue(leftRow, a.dataType, i.toString) + val valueCode = CodeGenerator.getValue(streamedRow, a.dataType, i.toString) val javaType = CodeGenerator.javaType(a.dataType) val defaultValue = CodeGenerator.defaultValue(a.dataType) if (a.nullable) { val isNull = ctx.freshName("isNull") val code = code""" - |$isNull = $leftRow.isNullAt($i); + |$isNull = $streamedRow.isNullAt($i); |$value = $isNull ? $defaultValue : ($valueCode); """.stripMargin - val leftVarsDecl = + val streamedVarsDecl = s""" |boolean $isNull = false; |$javaType $value = $defaultValue; """.stripMargin (ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)), - leftVarsDecl) + streamedVarsDecl) } else { val code = code"$value = $valueCode;" - val leftVarsDecl = s"""$javaType $value = $defaultValue;""" - (ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl) + val streamedVarsDecl = s"""$javaType $value = $defaultValue;""" + (ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), streamedVarsDecl) } }.unzip } - /** - * Creates the variables for right part of result row, using BoundReference, since the right - * part are accessed inside the loop. - */ - private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = { - ctx.INPUT_ROW = rightRow - right.output.zipWithIndex.map { case (a, i) => - BoundReference(i, a.dataType, a.nullable).genCode(ctx) - } - } - /** * Splits variables based on whether it's used by condition or not, returns the code to create * these variables before the condition and after the condition. @@ -554,62 +555,64 @@ case class SortMergeJoinExec( override def doProduce(ctx: CodegenContext): String = { // Inline mutable state since not many join operations in a task - val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput", + val streamedInput = ctx.addMutableState("scala.collection.Iterator", "streamedInput", v => s"$v = inputs[0];", forceInline = true) - val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput", + val bufferedInput = ctx.addMutableState("scala.collection.Iterator", "bufferedInput", v => s"$v = inputs[1];", forceInline = true) - val (leftRow, matches) = genScanner(ctx) + val (streamedRow, matches) = genScanner(ctx) // Create variables for row from both sides. - val (leftVars, leftVarDecl) = createLeftVars(ctx, leftRow) - val rightRow = ctx.freshName("rightRow") - val rightVars = createRightVar(ctx, rightRow) + val (streamedVars, streamedVarDecl) = createStreamedVars(ctx, streamedRow) + val bufferedRow = ctx.freshName("bufferedRow") + val bufferedVars = genBuildSideVars(ctx, bufferedRow, bufferedPlan) val iterator = ctx.freshName("iterator") val numOutput = metricTerm(ctx, "numOutputRows") + val resultVars = streamedVars ++ bufferedVars + val (beforeLoop, condCheck) = if (condition.isDefined) { // Split the code of creating variables based on whether it's used by condition or not. val loaded = ctx.freshName("loaded") - val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars) - val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars) + val (streamedBefore, streamedAfter) = splitVarsByCondition(streamedOutput, streamedVars) + val (bufferedBefore, bufferedAfter) = splitVarsByCondition(bufferedOutput, bufferedVars) // Generate code for condition - ctx.currentVars = leftVars ++ rightVars + ctx.currentVars = resultVars val cond = BindReferences.bindReference(condition.get, output).genCode(ctx) // evaluate the columns those used by condition before loop val before = s""" |boolean $loaded = false; - |$leftBefore + |$streamedBefore """.stripMargin val checking = s""" - |$rightBefore + |$bufferedBefore |${cond.code} |if (${cond.isNull} || !${cond.value}) continue; |if (!$loaded) { | $loaded = true; - | $leftAfter + | $streamedAfter |} - |$rightAfter + |$bufferedAfter """.stripMargin (before, checking) } else { - (evaluateVariables(leftVars), "") + (evaluateVariables(streamedVars), "") } val thisPlan = ctx.addReferenceObj("plan", this) val eagerCleanup = s"$thisPlan.cleanupResources();" s""" - |while (findNextInnerJoinRows($leftInput, $rightInput)) { - | ${leftVarDecl.mkString("\n")} + |while (findNextJoinRows($streamedInput, $bufferedInput)) { + | ${streamedVarDecl.mkString("\n")} | ${beforeLoop.trim} | scala.collection.Iterator $iterator = $matches.generateIterator(); | while ($iterator.hasNext()) { - | InternalRow $rightRow = (InternalRow) $iterator.next(); + | InternalRow $bufferedRow = (InternalRow) $iterator.next(); | ${condCheck.trim} | $numOutput.add(1); - | ${consume(ctx, leftVars ++ rightVars)} + | ${consume(ctx, resultVars)} | } | if (shouldStop()) return; |}