[SPARK-13373] [SQL] generate sort merge join
## What changes were proposed in this pull request? Generates code for SortMergeJoin. ## How was this patch tested? Unit tests and manually tested with TPCDS Q72, which showed 70% performance improvements (from 42s to 25s), but the micro benchmark only shows minor improvements; this may depend on the distribution of data and the number of columns. Author: Davies Liu <davies@databricks.com> Closes #11248 from davies/gen_smj.
This commit is contained in:
parent
c481bdf512
commit
9cdd867da9
|
@ -203,6 +203,7 @@ private[spark] class DiskBlockObjectWriter(
|
|||
numRecordsWritten += 1
|
||||
writeMetrics.incRecordsWritten(1)
|
||||
|
||||
// TODO: call updateBytesWritten() less frequently.
|
||||
if (numRecordsWritten % 32 == 0) {
|
||||
updateBytesWritten()
|
||||
}
|
||||
|
|
|
@ -29,12 +29,9 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
|
|||
/**
|
||||
* An iterator interface used to pull the output from generated function for multiple operators
|
||||
* (whole stage codegen).
|
||||
*
|
||||
* TODO: replace it with a batched columnar format.
|
||||
*/
|
||||
public class BufferedRowIterator {
|
||||
public abstract class BufferedRowIterator {
|
||||
protected LinkedList<InternalRow> currentRows = new LinkedList<>();
|
||||
protected Iterator<InternalRow> input;
|
||||
// used when there is no column in output
|
||||
protected UnsafeRow unsafeRow = new UnsafeRow(0);
|
||||
|
||||
|
@ -49,8 +46,16 @@ public class BufferedRowIterator {
|
|||
return currentRows.remove();
|
||||
}
|
||||
|
||||
public void setInput(Iterator<InternalRow> iter) {
|
||||
input = iter;
|
||||
/**
|
||||
* Initializes from array of iterators of InternalRow.
|
||||
*/
|
||||
public abstract void init(Iterator<InternalRow> iters[]);
|
||||
|
||||
/**
|
||||
* Append a row to currentRows.
|
||||
*/
|
||||
protected void append(InternalRow row) {
|
||||
currentRows.add(row);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -74,9 +79,5 @@ public class BufferedRowIterator {
|
|||
*
|
||||
* After it's called, if currentRow is still null, it means there are no more rows left.
|
||||
*/
|
||||
protected void processNext() throws IOException {
|
||||
if (input.hasNext()) {
|
||||
currentRows.add(input.next());
|
||||
}
|
||||
}
|
||||
protected abstract void processNext() throws IOException;
|
||||
}
|
||||
|
|
|
@ -85,8 +85,8 @@ case class Expand(
|
|||
}
|
||||
}
|
||||
|
||||
override def upstream(): RDD[InternalRow] = {
|
||||
child.asInstanceOf[CodegenSupport].upstream()
|
||||
override def upstreams(): Seq[RDD[InternalRow]] = {
|
||||
child.asInstanceOf[CodegenSupport].upstreams()
|
||||
}
|
||||
|
||||
protected override def doProduce(ctx: CodegenContext): String = {
|
||||
|
|
|
@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
|
|||
import org.apache.spark.sql.catalyst.rules.Rule
|
||||
import org.apache.spark.sql.catalyst.util.toCommentSafeString
|
||||
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
|
||||
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight}
|
||||
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight, SortMergeJoin}
|
||||
import org.apache.spark.sql.execution.metric.LongSQLMetricValue
|
||||
|
||||
/**
|
||||
|
@ -40,7 +40,8 @@ trait CodegenSupport extends SparkPlan {
|
|||
/** Prefix used in the current operator's variable names. */
|
||||
private def variablePrefix: String = this match {
|
||||
case _: TungstenAggregate => "agg"
|
||||
case _: BroadcastHashJoin => "join"
|
||||
case _: BroadcastHashJoin => "bhj"
|
||||
case _: SortMergeJoin => "smj"
|
||||
case _ => nodeName.toLowerCase
|
||||
}
|
||||
|
||||
|
@ -68,9 +69,11 @@ trait CodegenSupport extends SparkPlan {
|
|||
private var parent: CodegenSupport = null
|
||||
|
||||
/**
|
||||
* Returns the RDD of InternalRow which generates the input rows.
|
||||
* Returns all the RDDs of InternalRow which generate the input rows.
|
||||
*
|
||||
* Note: right now we support up to two RDDs.
|
||||
*/
|
||||
def upstream(): RDD[InternalRow]
|
||||
def upstreams(): Seq[RDD[InternalRow]]
|
||||
|
||||
/**
|
||||
* Returns Java source code to process the rows from upstream.
|
||||
|
@ -179,19 +182,23 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
|
|||
|
||||
override def supportCodegen: Boolean = false
|
||||
|
||||
override def upstream(): RDD[InternalRow] = {
|
||||
child.execute()
|
||||
override def upstreams(): Seq[RDD[InternalRow]] = {
|
||||
child.execute() :: Nil
|
||||
}
|
||||
|
||||
override def doProduce(ctx: CodegenContext): String = {
|
||||
val input = ctx.freshName("input")
|
||||
// Right now, InputAdapter is only used when there is one upstream.
|
||||
ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
|
||||
|
||||
val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
|
||||
val row = ctx.freshName("row")
|
||||
ctx.INPUT_ROW = row
|
||||
ctx.currentVars = null
|
||||
val columns = exprs.map(_.gen(ctx))
|
||||
s"""
|
||||
| while (input.hasNext()) {
|
||||
| InternalRow $row = (InternalRow) input.next();
|
||||
| while ($input.hasNext()) {
|
||||
| InternalRow $row = (InternalRow) $input.next();
|
||||
| ${columns.map(_.code).mkString("\n").trim}
|
||||
| ${consume(ctx, columns).trim}
|
||||
| if (shouldStop()) {
|
||||
|
@ -215,7 +222,7 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
|
|||
*
|
||||
* -> execute()
|
||||
* |
|
||||
* doExecute() ---------> upstream() -------> upstream() ------> execute()
|
||||
* doExecute() ---------> upstreams() -------> upstreams() ------> execute()
|
||||
* |
|
||||
* -----------------> produce()
|
||||
* |
|
||||
|
@ -267,6 +274,9 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
|
|||
|
||||
public GeneratedIterator(Object[] references) {
|
||||
this.references = references;
|
||||
}
|
||||
|
||||
public void init(scala.collection.Iterator inputs[]) {
|
||||
${ctx.initMutableStates()}
|
||||
}
|
||||
|
||||
|
@ -283,19 +293,33 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
|
|||
// println(s"${CodeFormatter.format(cleanedSource)}")
|
||||
CodeGenerator.compile(cleanedSource)
|
||||
|
||||
plan.upstream().mapPartitions { iter =>
|
||||
|
||||
val clazz = CodeGenerator.compile(source)
|
||||
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
|
||||
buffer.setInput(iter)
|
||||
new Iterator[InternalRow] {
|
||||
override def hasNext: Boolean = buffer.hasNext
|
||||
override def next: InternalRow = buffer.next()
|
||||
val rdds = plan.upstreams()
|
||||
assert(rdds.size <= 2, "Up to two upstream RDDs can be supported")
|
||||
if (rdds.length == 1) {
|
||||
rdds.head.mapPartitions { iter =>
|
||||
val clazz = CodeGenerator.compile(cleanedSource)
|
||||
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
|
||||
buffer.init(Array(iter))
|
||||
new Iterator[InternalRow] {
|
||||
override def hasNext: Boolean = buffer.hasNext
|
||||
override def next: InternalRow = buffer.next()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Right now, we support up to two upstreams.
|
||||
rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
|
||||
val clazz = CodeGenerator.compile(cleanedSource)
|
||||
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
|
||||
buffer.init(Array(leftIter, rightIter))
|
||||
new Iterator[InternalRow] {
|
||||
override def hasNext: Boolean = buffer.hasNext
|
||||
override def next: InternalRow = buffer.next()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def upstream(): RDD[InternalRow] = {
|
||||
override def upstreams(): Seq[RDD[InternalRow]] = {
|
||||
throw new UnsupportedOperationException
|
||||
}
|
||||
|
||||
|
@ -312,7 +336,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
|
|||
if (row != null) {
|
||||
// There is an UnsafeRow already
|
||||
s"""
|
||||
| currentRows.add($row.copy());
|
||||
|append($row.copy());
|
||||
""".stripMargin
|
||||
} else {
|
||||
assert(input != null)
|
||||
|
@ -324,13 +348,13 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
|
|||
ctx.currentVars = input
|
||||
val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
|
||||
s"""
|
||||
| ${code.code.trim}
|
||||
| currentRows.add(${code.value}.copy());
|
||||
|${code.code.trim}
|
||||
|append(${code.value}.copy());
|
||||
""".stripMargin
|
||||
} else {
|
||||
// There are no columns
|
||||
s"""
|
||||
| currentRows.add(unsafeRow);
|
||||
|append(unsafeRow);
|
||||
""".stripMargin
|
||||
}
|
||||
}
|
||||
|
@ -402,6 +426,9 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
|
|||
b.copy(left = apply(left))
|
||||
case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) =>
|
||||
b.copy(right = apply(right))
|
||||
case j @ SortMergeJoin(_, _, _, left, right) =>
|
||||
// The children of SortMergeJoin should do codegen separately.
|
||||
j.copy(left = apply(left), right = apply(right))
|
||||
case p if !supportCodegen(p) =>
|
||||
val input = apply(p) // collapse them recursively
|
||||
inputs += input
|
||||
|
|
|
@ -121,8 +121,8 @@ case class TungstenAggregate(
|
|||
!aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
|
||||
}
|
||||
|
||||
override def upstream(): RDD[InternalRow] = {
|
||||
child.asInstanceOf[CodegenSupport].upstream()
|
||||
override def upstreams(): Seq[RDD[InternalRow]] = {
|
||||
child.asInstanceOf[CodegenSupport].upstreams()
|
||||
}
|
||||
|
||||
protected override def doProduce(ctx: CodegenContext): String = {
|
||||
|
|
|
@ -31,8 +31,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
|
|||
|
||||
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
|
||||
|
||||
override def upstream(): RDD[InternalRow] = {
|
||||
child.asInstanceOf[CodegenSupport].upstream()
|
||||
override def upstreams(): Seq[RDD[InternalRow]] = {
|
||||
child.asInstanceOf[CodegenSupport].upstreams()
|
||||
}
|
||||
|
||||
protected override def doProduce(ctx: CodegenContext): String = {
|
||||
|
@ -69,8 +69,8 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
|
|||
private[sql] override lazy val metrics = Map(
|
||||
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
|
||||
|
||||
override def upstream(): RDD[InternalRow] = {
|
||||
child.asInstanceOf[CodegenSupport].upstream()
|
||||
override def upstreams(): Seq[RDD[InternalRow]] = {
|
||||
child.asInstanceOf[CodegenSupport].upstreams()
|
||||
}
|
||||
|
||||
protected override def doProduce(ctx: CodegenContext): String = {
|
||||
|
@ -156,8 +156,9 @@ case class Range(
|
|||
private[sql] override lazy val metrics = Map(
|
||||
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
|
||||
|
||||
override def upstream(): RDD[InternalRow] = {
|
||||
sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i))
|
||||
override def upstreams(): Seq[RDD[InternalRow]] = {
|
||||
sqlContext.sparkContext.parallelize(0 until numSlices, numSlices)
|
||||
.map(i => InternalRow(i)) :: Nil
|
||||
}
|
||||
|
||||
protected override def doProduce(ctx: CodegenContext): String = {
|
||||
|
@ -213,12 +214,15 @@ case class Range(
|
|||
| }
|
||||
""".stripMargin)
|
||||
|
||||
val input = ctx.freshName("input")
|
||||
// Right now, Range is only used when there is one upstream.
|
||||
ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
|
||||
s"""
|
||||
| // initialize Range
|
||||
| if (!$initTerm) {
|
||||
| $initTerm = true;
|
||||
| if (input.hasNext()) {
|
||||
| initRange(((InternalRow) input.next()).getInt(0));
|
||||
| if ($input.hasNext()) {
|
||||
| initRange(((InternalRow) $input.next()).getInt(0));
|
||||
| } else {
|
||||
| return;
|
||||
| }
|
||||
|
|
|
@ -99,8 +99,8 @@ case class BroadcastHashJoin(
|
|||
}
|
||||
}
|
||||
|
||||
override def upstream(): RDD[InternalRow] = {
|
||||
streamedPlan.asInstanceOf[CodegenSupport].upstream()
|
||||
override def upstreams(): Seq[RDD[InternalRow]] = {
|
||||
streamedPlan.asInstanceOf[CodegenSupport].upstreams()
|
||||
}
|
||||
|
||||
override def doProduce(ctx: CodegenContext): String = {
|
||||
|
|
|
@ -27,7 +27,6 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
|
|||
import org.apache.spark.util.CompletionIterator
|
||||
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
|
||||
|
||||
|
||||
/**
|
||||
* An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD,
|
||||
* will be much faster than building the right partition for every row in left RDD, it also
|
||||
|
|
|
@ -22,9 +22,10 @@ import scala.collection.mutable.ArrayBuffer
|
|||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.plans.physical._
|
||||
import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan}
|
||||
import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
|
||||
import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, RowIterator, SparkPlan}
|
||||
import org.apache.spark.sql.execution.metric.SQLMetrics
|
||||
|
||||
/**
|
||||
* Performs a sort merge join of two child relations.
|
||||
|
@ -34,7 +35,7 @@ case class SortMergeJoin(
|
|||
rightKeys: Seq[Expression],
|
||||
condition: Option[Expression],
|
||||
left: SparkPlan,
|
||||
right: SparkPlan) extends BinaryNode {
|
||||
right: SparkPlan) extends BinaryNode with CodegenSupport {
|
||||
|
||||
override private[sql] lazy val metrics = Map(
|
||||
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
|
||||
|
@ -125,6 +126,246 @@ case class SortMergeJoin(
|
|||
}.toScala
|
||||
}
|
||||
}
|
||||
|
||||
override def upstreams(): Seq[RDD[InternalRow]] = {
|
||||
left.execute() :: right.execute() :: Nil
|
||||
}
|
||||
|
||||
private def createJoinKey(
|
||||
ctx: CodegenContext,
|
||||
row: String,
|
||||
keys: Seq[Expression],
|
||||
input: Seq[Attribute]): Seq[ExprCode] = {
|
||||
ctx.INPUT_ROW = row
|
||||
keys.map(BindReferences.bindReference(_, input).gen(ctx))
|
||||
}
|
||||
|
||||
private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = {
|
||||
vars.zipWithIndex.map { case (ev, i) =>
|
||||
val value = ctx.freshName("value")
|
||||
ctx.addMutableState(ctx.javaType(leftKeys(i).dataType), value, "")
|
||||
val code =
|
||||
s"""
|
||||
|$value = ${ev.value};
|
||||
""".stripMargin
|
||||
ExprCode(code, "false", value)
|
||||
}
|
||||
}
|
||||
|
||||
private def genComparision(ctx: CodegenContext, a: Seq[ExprCode], b: Seq[ExprCode]): String = {
|
||||
val comparisons = a.zip(b).zipWithIndex.map { case ((l, r), i) =>
|
||||
s"""
|
||||
|if (comp == 0) {
|
||||
| comp = ${ctx.genComp(leftKeys(i).dataType, l.value, r.value)};
|
||||
|}
|
||||
""".stripMargin.trim
|
||||
}
|
||||
s"""
|
||||
|comp = 0;
|
||||
|${comparisons.mkString("\n")}
|
||||
""".stripMargin
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a function that scans both sides to find a match. Returns the terms for the
|
||||
* matched row from the left side and the buffered matching rows from the right side.
|
||||
*/
|
||||
private def genScanner(ctx: CodegenContext): (String, String) = {
|
||||
// Create class member for next row from both sides.
|
||||
val leftRow = ctx.freshName("leftRow")
|
||||
ctx.addMutableState("InternalRow", leftRow, "")
|
||||
val rightRow = ctx.freshName("rightRow")
|
||||
ctx.addMutableState("InternalRow", rightRow, s"$rightRow = null;")
|
||||
|
||||
// Create variables for join keys from both sides.
|
||||
val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output)
|
||||
val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ")
|
||||
val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output)
|
||||
val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ")
|
||||
// Copy the right key as class members so they could be used in next function call.
|
||||
val rightKeyVars = copyKeys(ctx, rightKeyTmpVars)
|
||||
|
||||
// A list to hold all matched rows from right side.
|
||||
val matches = ctx.freshName("matches")
|
||||
val clsName = classOf[java.util.ArrayList[InternalRow]].getName
|
||||
ctx.addMutableState(clsName, matches, s"$matches = new $clsName();")
|
||||
// Copy the left keys as class members so they could be used in next function call.
|
||||
val matchedKeyVars = copyKeys(ctx, leftKeyVars)
|
||||
|
||||
ctx.addNewFunction("findNextInnerJoinRows",
|
||||
s"""
|
||||
|private boolean findNextInnerJoinRows(
|
||||
| scala.collection.Iterator leftIter,
|
||||
| scala.collection.Iterator rightIter) {
|
||||
| $leftRow = null;
|
||||
| int comp = 0;
|
||||
| while ($leftRow == null) {
|
||||
| if (!leftIter.hasNext()) return false;
|
||||
| $leftRow = (InternalRow) leftIter.next();
|
||||
| ${leftKeyVars.map(_.code).mkString("\n")}
|
||||
| if ($leftAnyNull) {
|
||||
| $leftRow = null;
|
||||
| continue;
|
||||
| }
|
||||
| if (!$matches.isEmpty()) {
|
||||
| ${genComparision(ctx, leftKeyVars, matchedKeyVars)}
|
||||
| if (comp == 0) {
|
||||
| return true;
|
||||
| }
|
||||
| $matches.clear();
|
||||
| }
|
||||
|
|
||||
| do {
|
||||
| if ($rightRow == null) {
|
||||
| if (!rightIter.hasNext()) {
|
||||
| ${matchedKeyVars.map(_.code).mkString("\n")}
|
||||
| return !$matches.isEmpty();
|
||||
| }
|
||||
| $rightRow = (InternalRow) rightIter.next();
|
||||
| ${rightKeyTmpVars.map(_.code).mkString("\n")}
|
||||
| if ($rightAnyNull) {
|
||||
| $rightRow = null;
|
||||
| continue;
|
||||
| }
|
||||
| ${rightKeyVars.map(_.code).mkString("\n")}
|
||||
| }
|
||||
| ${genComparision(ctx, leftKeyVars, rightKeyVars)}
|
||||
| if (comp > 0) {
|
||||
| $rightRow = null;
|
||||
| } else if (comp < 0) {
|
||||
| if (!$matches.isEmpty()) {
|
||||
| ${matchedKeyVars.map(_.code).mkString("\n")}
|
||||
| return true;
|
||||
| }
|
||||
| $leftRow = null;
|
||||
| } else {
|
||||
| $matches.add($rightRow.copy());
|
||||
| $rightRow = null;;
|
||||
| }
|
||||
| } while ($leftRow != null);
|
||||
| }
|
||||
| return false; // unreachable
|
||||
|}
|
||||
""".stripMargin)
|
||||
|
||||
(leftRow, matches)
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates variables for left part of result row.
|
||||
*
|
||||
* To defer column access until after the condition is checked, and to access each column
|
||||
* only once in the loop, the variables must be declared separately from the code that
|
||||
* reads the columns, so we can't use the codegen of BoundReference here.
|
||||
*/
|
||||
private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = {
|
||||
ctx.INPUT_ROW = leftRow
|
||||
left.output.zipWithIndex.map { case (a, i) =>
|
||||
val value = ctx.freshName("value")
|
||||
val valueCode = ctx.getValue(leftRow, a.dataType, i.toString)
|
||||
// declare it as class member, so we can access the column before or in the loop.
|
||||
ctx.addMutableState(ctx.javaType(a.dataType), value, "")
|
||||
if (a.nullable) {
|
||||
val isNull = ctx.freshName("isNull")
|
||||
ctx.addMutableState("boolean", isNull, "")
|
||||
val code =
|
||||
s"""
|
||||
|$isNull = $leftRow.isNullAt($i);
|
||||
|$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode);
|
||||
""".stripMargin
|
||||
ExprCode(code, isNull, value)
|
||||
} else {
|
||||
ExprCode(s"$value = $valueCode;", "false", value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates the variables for right part of result row, using BoundReference, since the right
|
||||
* part is accessed inside the loop.
|
||||
*/
|
||||
private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = {
|
||||
ctx.INPUT_ROW = rightRow
|
||||
right.output.zipWithIndex.map { case (a, i) =>
|
||||
BoundReference(i, a.dataType, a.nullable).gen(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Splits variables based on whether it's used by condition or not, returns the code to create
|
||||
* these variables before the condition and after the condition.
|
||||
*
|
||||
* When only a few columns are used by the condition, we can skip accessing the columns
|
||||
* that the condition does not use for rows that the condition filters out.
|
||||
*/
|
||||
private def splitVarsByCondition(
|
||||
attributes: Seq[Attribute],
|
||||
variables: Seq[ExprCode]): (String, String) = {
|
||||
if (condition.isDefined) {
|
||||
val condRefs = condition.get.references
|
||||
val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) =>
|
||||
condRefs.contains(a)
|
||||
}
|
||||
val beforeCond = used.map(_._2.code).mkString("\n")
|
||||
val afterCond = notUsed.map(_._2.code).mkString("\n")
|
||||
(beforeCond, afterCond)
|
||||
} else {
|
||||
(variables.map(_.code).mkString("\n"), "")
|
||||
}
|
||||
}
|
||||
|
||||
override def doProduce(ctx: CodegenContext): String = {
|
||||
val leftInput = ctx.freshName("leftInput")
|
||||
ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];")
|
||||
val rightInput = ctx.freshName("rightInput")
|
||||
ctx.addMutableState("scala.collection.Iterator", rightInput, s"$rightInput = inputs[1];")
|
||||
|
||||
val (leftRow, matches) = genScanner(ctx)
|
||||
|
||||
// Create variables for row from both sides.
|
||||
val leftVars = createLeftVars(ctx, leftRow)
|
||||
val rightRow = ctx.freshName("rightRow")
|
||||
val rightVars = createRightVar(ctx, rightRow)
|
||||
val resultVars = leftVars ++ rightVars
|
||||
|
||||
// Check condition
|
||||
ctx.currentVars = resultVars
|
||||
val cond = if (condition.isDefined) {
|
||||
BindReferences.bindReference(condition.get, output).gen(ctx)
|
||||
} else {
|
||||
ExprCode("", "false", "true")
|
||||
}
|
||||
// Split the code of creating variables based on whether it's used by condition or not.
|
||||
val loaded = ctx.freshName("loaded")
|
||||
val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
|
||||
val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
|
||||
|
||||
|
||||
val size = ctx.freshName("size")
|
||||
val i = ctx.freshName("i")
|
||||
val numOutput = metricTerm(ctx, "numOutputRows")
|
||||
s"""
|
||||
|while (findNextInnerJoinRows($leftInput, $rightInput)) {
|
||||
| int $size = $matches.size();
|
||||
| boolean $loaded = false;
|
||||
| $leftBefore
|
||||
| for (int $i = 0; $i < $size; $i ++) {
|
||||
| InternalRow $rightRow = (InternalRow) $matches.get($i);
|
||||
| $rightBefore
|
||||
| ${cond.code}
|
||||
| if (${cond.isNull} || !${cond.value}) continue;
|
||||
| if (!$loaded) {
|
||||
| $loaded = true;
|
||||
| $leftAfter
|
||||
| }
|
||||
| $rightAfter
|
||||
| $numOutput.add(1);
|
||||
| ${consume(ctx, resultVars)}
|
||||
| }
|
||||
| if (shouldStop()) return;
|
||||
|}
|
||||
""".stripMargin
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -38,6 +38,7 @@ import org.apache.spark.util.Benchmark
|
|||
class BenchmarkWholeStageCodegen extends SparkFunSuite {
|
||||
lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
|
||||
.set("spark.sql.shuffle.partitions", "1")
|
||||
.set("spark.sql.autoBroadcastJoinThreshold", "0")
|
||||
lazy val sc = SparkContext.getOrCreate(conf)
|
||||
lazy val sqlContext = SQLContext.getOrCreate(sc)
|
||||
|
||||
|
@ -187,6 +188,39 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
|
|||
*/
|
||||
}
|
||||
|
||||
ignore("sort merge join") {
|
||||
val N = 2 << 20
|
||||
runBenchmark("merge join", N) {
|
||||
val df1 = sqlContext.range(N).selectExpr(s"id * 2 as k1")
|
||||
val df2 = sqlContext.range(N).selectExpr(s"id * 3 as k2")
|
||||
df1.join(df2, col("k1") === col("k2")).count()
|
||||
}
|
||||
|
||||
/**
|
||||
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
|
||||
merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
|
||||
-------------------------------------------------------------------------------------------
|
||||
merge join codegen=false 1588 / 1880 1.3 757.1 1.0X
|
||||
merge join codegen=true 1477 / 1531 1.4 704.2 1.1X
|
||||
*/
|
||||
|
||||
runBenchmark("sort merge join", N) {
|
||||
val df1 = sqlContext.range(N)
|
||||
.selectExpr(s"(id * 15485863) % ${N*10} as k1")
|
||||
val df2 = sqlContext.range(N)
|
||||
.selectExpr(s"(id * 15485867) % ${N*10} as k2")
|
||||
df1.join(df2, col("k1") === col("k2")).count()
|
||||
}
|
||||
|
||||
/**
|
||||
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
|
||||
sort merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
|
||||
-------------------------------------------------------------------------------------------
|
||||
sort merge join codegen=false 3626 / 3667 0.6 1728.9 1.0X
|
||||
sort merge join codegen=true 3405 / 3438 0.6 1623.8 1.1X
|
||||
*/
|
||||
}
|
||||
|
||||
ignore("rube") {
|
||||
val N = 5 << 20
|
||||
|
||||
|
|
|
@ -240,7 +240,8 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
|
|||
withBucket(df1.write.format("parquet"), bucketSpecLeft).saveAsTable("bucketed_table1")
|
||||
withBucket(df2.write.format("parquet"), bucketSpecRight).saveAsTable("bucketed_table2")
|
||||
|
||||
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
|
||||
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0",
|
||||
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
|
||||
val t1 = hiveContext.table("bucketed_table1")
|
||||
val t2 = hiveContext.table("bucketed_table2")
|
||||
val joined = t1.join(t2, joinCondition(t1, t2, joinColumns))
|
||||
|
|
Loading…
Reference in a new issue