[SPARK-13373] [SQL] generate sort merge join

## What changes were proposed in this pull request?

Generates code for SortMergeJoin.

## How was this patch tested?

Unit tests, plus manual testing with TPCDS Q72, which showed a 70% performance improvement (from 42s to 25s). Micro benchmarks only show minor improvements; the gain likely depends on the data distribution and the number of columns.
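The micro benchmark itself is part of this diff (see BenchmarkWholeStageCodegen below). As a quick sanity check outside the benchmark harness, a sketch like the following can confirm that the planner picks SortMergeJoin and wraps it in WholeStageCodegen; the DataFrame names and sizes here are illustrative, not taken from the patch.

```scala
import org.apache.spark.sql.functions.col

// Disable broadcast joins so the planner chooses sort merge join.
sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "0")

val df1 = sqlContext.range(1 << 20).selectExpr("id * 2 as k1")
val df2 = sqlContext.range(1 << 20).selectExpr("id * 3 as k2")

// The physical plan should show WholeStageCodegen wrapping SortMergeJoin
// (whose generated variables are prefixed "smj" after this patch).
df1.join(df2, col("k1") === col("k2")).explain()
```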

Author: Davies Liu <davies@databricks.com>

Closes #11248 from davies/gen_smj.
Davies Liu 2016-02-23 15:00:10 -08:00 committed by Davies Liu
parent c481bdf512
commit 9cdd867da9
11 changed files with 360 additions and 52 deletions

View file

@ -203,6 +203,7 @@ private[spark] class DiskBlockObjectWriter(
numRecordsWritten += 1
writeMetrics.incRecordsWritten(1)
// TODO: call updateBytesWritten() less frequently.
if (numRecordsWritten % 32 == 0) {
updateBytesWritten()
}

View file

@ -29,12 +29,9 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
/**
* An iterator interface used to pull the output from generated function for multiple operators
* (whole stage codegen).
*
* TODO: replaced it by batched columnar format.
*/
public class BufferedRowIterator {
public abstract class BufferedRowIterator {
protected LinkedList<InternalRow> currentRows = new LinkedList<>();
protected Iterator<InternalRow> input;
// used when there is no column in output
protected UnsafeRow unsafeRow = new UnsafeRow(0);
@ -49,8 +46,16 @@ public class BufferedRowIterator {
return currentRows.remove();
}
public void setInput(Iterator<InternalRow> iter) {
input = iter;
/**
* Initializes from array of iterators of InternalRow.
*/
public abstract void init(Iterator<InternalRow> iters[]);
/**
* Append a row to currentRows.
*/
protected void append(InternalRow row) {
currentRows.add(row);
}
/**
@ -74,9 +79,5 @@ public class BufferedRowIterator {
*
* After it's called, if currentRow is still null, it means no more rows left.
*/
protected void processNext() throws IOException {
if (input.hasNext()) {
currentRows.add(input.next());
}
}
protected abstract void processNext() throws IOException;
}
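To make the new contract concrete, here is a minimal hand-written sketch of a subclass that reproduces the behaviour of the removed default processNext(): pull one row from the single input iterator and append it. It assumes the class lives at org.apache.spark.sql.execution.BufferedRowIterator and, as the generated code above suggests, that the inputs are scala.collection.Iterator instances; real subclasses are generated by WholeStageCodegen, implement init() and processNext(), and call append() to emit rows.

```scala
import java.io.IOException

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.BufferedRowIterator

// Hypothetical subclass for illustration only; real subclasses are generated
// and compiled at runtime.
class PassThroughIterator extends BufferedRowIterator {
  private var input: Iterator[InternalRow] = _

  // init() receives one iterator per upstream RDD (at most two after this patch).
  override def init(iters: Array[Iterator[InternalRow]]): Unit = {
    input = iters(0)
  }

  // processNext() appends zero or more rows to currentRows via append().
  @throws[IOException]
  override def processNext(): Unit = {
    if (input.hasNext) {
      append(input.next())
    }
  }
}
```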

View file

@ -85,8 +85,8 @@ case class Expand(
}
}
override def upstream(): RDD[InternalRow] = {
child.asInstanceOf[CodegenSupport].upstream()
override def upstreams(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].upstreams()
}
protected override def doProduce(ctx: CodegenContext): String = {

View file

@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.toCommentSafeString
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight, SortMergeJoin}
import org.apache.spark.sql.execution.metric.LongSQLMetricValue
/**
@ -40,7 +40,8 @@ trait CodegenSupport extends SparkPlan {
/** Prefix used in the current operator's variable names. */
private def variablePrefix: String = this match {
case _: TungstenAggregate => "agg"
case _: BroadcastHashJoin => "join"
case _: BroadcastHashJoin => "bhj"
case _: SortMergeJoin => "smj"
case _ => nodeName.toLowerCase
}
@ -68,9 +69,11 @@ trait CodegenSupport extends SparkPlan {
private var parent: CodegenSupport = null
/**
* Returns the RDD of InternalRow which generates the input rows.
* Returns all the RDDs of InternalRow which generate the input rows.
*
* Note: right now we support up to two RDDs.
*/
def upstream(): RDD[InternalRow]
def upstreams(): Seq[RDD[InternalRow]]
/**
* Returns Java source code to process the rows from upstream.
@ -179,19 +182,23 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
override def supportCodegen: Boolean = false
override def upstream(): RDD[InternalRow] = {
child.execute()
override def upstreams(): Seq[RDD[InternalRow]] = {
child.execute() :: Nil
}
override def doProduce(ctx: CodegenContext): String = {
val input = ctx.freshName("input")
// Right now, InputAdapter is only used when there is one upstream.
ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
val row = ctx.freshName("row")
ctx.INPUT_ROW = row
ctx.currentVars = null
val columns = exprs.map(_.gen(ctx))
s"""
| while (input.hasNext()) {
| InternalRow $row = (InternalRow) input.next();
| while ($input.hasNext()) {
| InternalRow $row = (InternalRow) $input.next();
| ${columns.map(_.code).mkString("\n").trim}
| ${consume(ctx, columns).trim}
| if (shouldStop()) {
@ -215,7 +222,7 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
*
* -> execute()
* |
* doExecute() ---------> upstream() -------> upstream() ------> execute()
* doExecute() ---------> upstreams() -------> upstreams() ------> execute()
* |
* -----------------> produce()
* |
@ -267,6 +274,9 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
public GeneratedIterator(Object[] references) {
this.references = references;
}
public void init(scala.collection.Iterator inputs[]) {
${ctx.initMutableStates()}
}
@ -283,19 +293,33 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
// println(s"${CodeFormatter.format(cleanedSource)}")
CodeGenerator.compile(cleanedSource)
plan.upstream().mapPartitions { iter =>
val clazz = CodeGenerator.compile(source)
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
buffer.setInput(iter)
new Iterator[InternalRow] {
override def hasNext: Boolean = buffer.hasNext
override def next: InternalRow = buffer.next()
val rdds = plan.upstreams()
assert(rdds.size <= 2, "Up to two upstream RDDs can be supported")
if (rdds.length == 1) {
rdds.head.mapPartitions { iter =>
val clazz = CodeGenerator.compile(cleanedSource)
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
buffer.init(Array(iter))
new Iterator[InternalRow] {
override def hasNext: Boolean = buffer.hasNext
override def next: InternalRow = buffer.next()
}
}
} else {
// Right now, we support up to two upstreams.
rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
val clazz = CodeGenerator.compile(cleanedSource)
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
buffer.init(Array(leftIter, rightIter))
new Iterator[InternalRow] {
override def hasNext: Boolean = buffer.hasNext
override def next: InternalRow = buffer.next()
}
}
}
}
override def upstream(): RDD[InternalRow] = {
override def upstreams(): Seq[RDD[InternalRow]] = {
throw new UnsupportedOperationException
}
@ -312,7 +336,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
if (row != null) {
// There is an UnsafeRow already
s"""
| currentRows.add($row.copy());
|append($row.copy());
""".stripMargin
} else {
assert(input != null)
@ -324,13 +348,13 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
ctx.currentVars = input
val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
s"""
| ${code.code.trim}
| currentRows.add(${code.value}.copy());
|${code.code.trim}
|append(${code.value}.copy());
""".stripMargin
} else {
// There is no columns
s"""
| currentRows.add(unsafeRow);
|append(unsafeRow);
""".stripMargin
}
}
@ -402,6 +426,9 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
b.copy(left = apply(left))
case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) =>
b.copy(right = apply(right))
case j @ SortMergeJoin(_, _, _, left, right) =>
// The children of SortMergeJoin should do codegen separately.
j.copy(left = apply(left), right = apply(right))
case p if !supportCodegen(p) =>
val input = apply(p) // collapse them recursively
inputs += input
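For the two-upstream branch above, the key building block is RDD.zipPartitions, which pairs the i-th partition of each RDD so a single generated iterator can consume both sides. A standalone sketch on plain RDDs (the names are illustrative, not part of the patch):

```scala
import org.apache.spark.rdd.RDD

// Both RDDs must have the same number of partitions, just like the sorted and
// co-partitioned children of SortMergeJoin.
def zipTwoUpstreams(left: RDD[Int], right: RDD[Int]): RDD[(Int, Int)] = {
  left.zipPartitions(right) { (leftIter, rightIter) =>
    // In WholeStageCodegen.doExecute this is where the generated class is
    // instantiated and buffer.init(Array(leftIter, rightIter)) is called;
    // here we simply zip the elements for illustration.
    leftIter.zip(rightIter)
  }
}
```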

View file

@ -121,8 +121,8 @@ case class TungstenAggregate(
!aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
}
override def upstream(): RDD[InternalRow] = {
child.asInstanceOf[CodegenSupport].upstream()
override def upstreams(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].upstreams()
}
protected override def doProduce(ctx: CodegenContext): String = {

View file

@ -31,8 +31,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
override def upstream(): RDD[InternalRow] = {
child.asInstanceOf[CodegenSupport].upstream()
override def upstreams(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].upstreams()
}
protected override def doProduce(ctx: CodegenContext): String = {
@ -69,8 +69,8 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
private[sql] override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
override def upstream(): RDD[InternalRow] = {
child.asInstanceOf[CodegenSupport].upstream()
override def upstreams(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].upstreams()
}
protected override def doProduce(ctx: CodegenContext): String = {
@ -156,8 +156,9 @@ case class Range(
private[sql] override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
override def upstream(): RDD[InternalRow] = {
sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i))
override def upstreams(): Seq[RDD[InternalRow]] = {
sqlContext.sparkContext.parallelize(0 until numSlices, numSlices)
.map(i => InternalRow(i)) :: Nil
}
protected override def doProduce(ctx: CodegenContext): String = {
@ -213,12 +214,15 @@ case class Range(
| }
""".stripMargin)
val input = ctx.freshName("input")
// Right now, Range is only used when there is one upstream.
ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
s"""
| // initialize Range
| if (!$initTerm) {
| $initTerm = true;
| if (input.hasNext()) {
| initRange(((InternalRow) input.next()).getInt(0));
| if ($input.hasNext()) {
| initRange(((InternalRow) $input.next()).getInt(0));
| } else {
| return;
| }

View file

@ -99,8 +99,8 @@ case class BroadcastHashJoin(
}
}
override def upstream(): RDD[InternalRow] = {
streamedPlan.asInstanceOf[CodegenSupport].upstream()
override def upstreams(): Seq[RDD[InternalRow]] = {
streamedPlan.asInstanceOf[CodegenSupport].upstreams()
}
override def doProduce(ctx: CodegenContext): String = {

View file

@ -27,7 +27,6 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
/**
* An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD,
* will be much faster than building the right partition for every row in left RDD, it also

View file

@ -22,9 +22,10 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan}
import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, RowIterator, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
/**
* Performs a sort merge join of two child relations.
@ -34,7 +35,7 @@ case class SortMergeJoin(
rightKeys: Seq[Expression],
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryNode {
right: SparkPlan) extends BinaryNode with CodegenSupport {
override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
@ -125,6 +126,246 @@ case class SortMergeJoin(
}.toScala
}
}
override def upstreams(): Seq[RDD[InternalRow]] = {
left.execute() :: right.execute() :: Nil
}
private def createJoinKey(
ctx: CodegenContext,
row: String,
keys: Seq[Expression],
input: Seq[Attribute]): Seq[ExprCode] = {
ctx.INPUT_ROW = row
keys.map(BindReferences.bindReference(_, input).gen(ctx))
}
private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = {
vars.zipWithIndex.map { case (ev, i) =>
val value = ctx.freshName("value")
ctx.addMutableState(ctx.javaType(leftKeys(i).dataType), value, "")
val code =
s"""
|$value = ${ev.value};
""".stripMargin
ExprCode(code, "false", value)
}
}
private def genComparision(ctx: CodegenContext, a: Seq[ExprCode], b: Seq[ExprCode]): String = {
val comparisons = a.zip(b).zipWithIndex.map { case ((l, r), i) =>
s"""
|if (comp == 0) {
| comp = ${ctx.genComp(leftKeys(i).dataType, l.value, r.value)};
|}
""".stripMargin.trim
}
s"""
|comp = 0;
|${comparisons.mkString("\n")}
""".stripMargin
}
/**
* Generates a function to scan both the left and right sides to find matching rows; returns
* the terms for the matched row from the left side and the buffered matching rows from the
* right side.
*/
private def genScanner(ctx: CodegenContext): (String, String) = {
// Create class member for next row from both sides.
val leftRow = ctx.freshName("leftRow")
ctx.addMutableState("InternalRow", leftRow, "")
val rightRow = ctx.freshName("rightRow")
ctx.addMutableState("InternalRow", rightRow, s"$rightRow = null;")
// Create variables for join keys from both sides.
val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output)
val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ")
val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output)
val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ")
// Copy the right key as class members so they could be used in next function call.
val rightKeyVars = copyKeys(ctx, rightKeyTmpVars)
// A list to hold all matched rows from right side.
val matches = ctx.freshName("matches")
val clsName = classOf[java.util.ArrayList[InternalRow]].getName
ctx.addMutableState(clsName, matches, s"$matches = new $clsName();")
// Copy the left keys as class members so they could be used in next function call.
val matchedKeyVars = copyKeys(ctx, leftKeyVars)
ctx.addNewFunction("findNextInnerJoinRows",
s"""
|private boolean findNextInnerJoinRows(
| scala.collection.Iterator leftIter,
| scala.collection.Iterator rightIter) {
| $leftRow = null;
| int comp = 0;
| while ($leftRow == null) {
| if (!leftIter.hasNext()) return false;
| $leftRow = (InternalRow) leftIter.next();
| ${leftKeyVars.map(_.code).mkString("\n")}
| if ($leftAnyNull) {
| $leftRow = null;
| continue;
| }
| if (!$matches.isEmpty()) {
| ${genComparision(ctx, leftKeyVars, matchedKeyVars)}
| if (comp == 0) {
| return true;
| }
| $matches.clear();
| }
|
| do {
| if ($rightRow == null) {
| if (!rightIter.hasNext()) {
| ${matchedKeyVars.map(_.code).mkString("\n")}
| return !$matches.isEmpty();
| }
| $rightRow = (InternalRow) rightIter.next();
| ${rightKeyTmpVars.map(_.code).mkString("\n")}
| if ($rightAnyNull) {
| $rightRow = null;
| continue;
| }
| ${rightKeyVars.map(_.code).mkString("\n")}
| }
| ${genComparision(ctx, leftKeyVars, rightKeyVars)}
| if (comp > 0) {
| $rightRow = null;
| } else if (comp < 0) {
| if (!$matches.isEmpty()) {
| ${matchedKeyVars.map(_.code).mkString("\n")}
| return true;
| }
| $leftRow = null;
| } else {
| $matches.add($rightRow.copy());
| $rightRow = null;
| }
| } while ($leftRow != null);
| }
| return false; // unreachable
|}
""".stripMargin)
(leftRow, matches)
}
/**
* Creates variables for the left part of the result row.
*
* In order to defer the access until after the condition and to access each column only once
* in the loop, the variables are declared separately from the code that reads the columns,
* so we can't reuse the codegen of BoundReference here.
*/
private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = {
ctx.INPUT_ROW = leftRow
left.output.zipWithIndex.map { case (a, i) =>
val value = ctx.freshName("value")
val valueCode = ctx.getValue(leftRow, a.dataType, i.toString)
// declare it as class member, so we can access the column before or in the loop.
ctx.addMutableState(ctx.javaType(a.dataType), value, "")
if (a.nullable) {
val isNull = ctx.freshName("isNull")
ctx.addMutableState("boolean", isNull, "")
val code =
s"""
|$isNull = $leftRow.isNullAt($i);
|$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode);
""".stripMargin
ExprCode(code, isNull, value)
} else {
ExprCode(s"$value = $valueCode;", "false", value)
}
}
}
/**
* Creates the variables for the right part of the result row using BoundReference, since the
* right part is only accessed inside the loop.
*/
private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = {
ctx.INPUT_ROW = rightRow
right.output.zipWithIndex.map { case (a, i) =>
BoundReference(i, a.dataType, a.nullable).gen(ctx)
}
}
/**
* Splits the variables based on whether they are used by the condition, and returns the code
* that creates these variables before the condition and after the condition.
*
* If only a few columns are used by the condition, we can skip reading the columns that the
* condition does not need for rows that the condition filters out.
*/
private def splitVarsByCondition(
attributes: Seq[Attribute],
variables: Seq[ExprCode]): (String, String) = {
if (condition.isDefined) {
val condRefs = condition.get.references
val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) =>
condRefs.contains(a)
}
val beforeCond = used.map(_._2.code).mkString("\n")
val afterCond = notUsed.map(_._2.code).mkString("\n")
(beforeCond, afterCond)
} else {
(variables.map(_.code).mkString("\n"), "")
}
}
override def doProduce(ctx: CodegenContext): String = {
val leftInput = ctx.freshName("leftInput")
ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];")
val rightInput = ctx.freshName("rightInput")
ctx.addMutableState("scala.collection.Iterator", rightInput, s"$rightInput = inputs[1];")
val (leftRow, matches) = genScanner(ctx)
// Create variables for row from both sides.
val leftVars = createLeftVars(ctx, leftRow)
val rightRow = ctx.freshName("rightRow")
val rightVars = createRightVar(ctx, rightRow)
val resultVars = leftVars ++ rightVars
// Check condition
ctx.currentVars = resultVars
val cond = if (condition.isDefined) {
BindReferences.bindReference(condition.get, output).gen(ctx)
} else {
ExprCode("", "false", "true")
}
// Split the code of creating variables based on whether it's used by condition or not.
val loaded = ctx.freshName("loaded")
val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
val size = ctx.freshName("size")
val i = ctx.freshName("i")
val numOutput = metricTerm(ctx, "numOutputRows")
s"""
|while (findNextInnerJoinRows($leftInput, $rightInput)) {
| int $size = $matches.size();
| boolean $loaded = false;
| $leftBefore
| for (int $i = 0; $i < $size; $i ++) {
| InternalRow $rightRow = (InternalRow) $matches.get($i);
| $rightBefore
| ${cond.code}
| if (${cond.isNull} || !${cond.value}) continue;
| if (!$loaded) {
| $loaded = true;
| $leftAfter
| }
| $rightAfter
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
| }
| if (shouldStop()) return;
|}
""".stripMargin
}
}
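The control flow that genScanner emits as Java is easier to read as ordinary Scala. The sketch below mirrors findNextInnerJoinRows over two iterators sorted on the join key, buffering all right rows that share the current left key; keys are simplified to non-null Ints, whereas the generated code also handles null keys and arbitrary key types, and keeps its state in class members across calls.

```scala
import scala.collection.mutable.ArrayBuffer

// Returns the next left row that has at least one match; `matches` then holds
// all right rows with an equal key. Returns None when the left side is exhausted.
def findNextInnerJoinRows(
    leftIter: BufferedIterator[(Int, String)],
    rightIter: BufferedIterator[(Int, String)],
    matches: ArrayBuffer[(Int, String)]): Option[(Int, String)] = {
  while (leftIter.hasNext) {
    val left = leftIter.next()
    // If the new left key equals the previously matched key, reuse the buffer.
    if (matches.nonEmpty && matches.head._1 == left._1) {
      return Some(left)
    }
    matches.clear()
    // Skip right rows with smaller keys, then buffer every row with an equal key.
    while (rightIter.hasNext && rightIter.head._1 < left._1) rightIter.next()
    while (rightIter.hasNext && rightIter.head._1 == left._1) matches += rightIter.next()
    if (matches.nonEmpty) return Some(left)
    // No match for this left row; advance to the next one.
  }
  None
}
```

The code emitted by doProduce then loops over the buffered matches for every returned left row, evaluates the optional join condition, and defers loading the left columns the condition does not need until the first surviving match, which is what splitVarsByCondition and the `loaded` flag are for.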
/**

View file

@ -38,6 +38,7 @@ import org.apache.spark.util.Benchmark
class BenchmarkWholeStageCodegen extends SparkFunSuite {
lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
.set("spark.sql.shuffle.partitions", "1")
.set("spark.sql.autoBroadcastJoinThreshold", "0")
lazy val sc = SparkContext.getOrCreate(conf)
lazy val sqlContext = SQLContext.getOrCreate(sc)
@ -187,6 +188,39 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
*/
}
ignore("sort merge join") {
val N = 2 << 20
runBenchmark("merge join", N) {
val df1 = sqlContext.range(N).selectExpr(s"id * 2 as k1")
val df2 = sqlContext.range(N).selectExpr(s"id * 3 as k2")
df1.join(df2, col("k1") === col("k2")).count()
}
/**
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
merge join codegen=false 1588 / 1880 1.3 757.1 1.0X
merge join codegen=true 1477 / 1531 1.4 704.2 1.1X
*/
runBenchmark("sort merge join", N) {
val df1 = sqlContext.range(N)
.selectExpr(s"(id * 15485863) % ${N*10} as k1")
val df2 = sqlContext.range(N)
.selectExpr(s"(id * 15485867) % ${N*10} as k2")
df1.join(df2, col("k1") === col("k2")).count()
}
/**
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
sort merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
sort merge join codegen=false 3626 / 3667 0.6 1728.9 1.0X
sort merge join codegen=true 3405 / 3438 0.6 1623.8 1.1X
*/
}
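The codegen=false / codegen=true rows in the output above come from running the same query with whole-stage codegen disabled and then enabled. A rough hand-rolled sketch of that toggle, assuming the benchmark's sqlContext, df1/df2, and col import are in scope and that SQLConf.WHOLESTAGE_CODEGEN_ENABLED maps to the key used below:

```scala
// Measure the join with whole-stage codegen off, then on.
sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
df1.join(df2, col("k1") === col("k2")).count()

sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
df1.join(df2, col("k1") === col("k2")).count()
```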
ignore("rube") {
val N = 5 << 20

View file

@ -240,7 +240,8 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
withBucket(df1.write.format("parquet"), bucketSpecLeft).saveAsTable("bucketed_table1")
withBucket(df2.write.format("parquet"), bucketSpecRight).saveAsTable("bucketed_table2")
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0",
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
val t1 = hiveContext.table("bucketed_table1")
val t2 = hiveContext.table("bucketed_table2")
val joined = t1.join(t2, joinCondition(t1, t2, joinColumns))