Revert "[SPARK-13031] [SQL] cleanup codegen and improve test coverage"
This reverts commit cc18a71992
.
This commit is contained in:
parent
4637fc08a3
commit
b9dfdcc63b
|
@ -144,23 +144,14 @@ class CodegenContext {
|
||||||
|
|
||||||
private val curId = new java.util.concurrent.atomic.AtomicInteger()
|
private val curId = new java.util.concurrent.atomic.AtomicInteger()
|
||||||
|
|
||||||
/**
|
|
||||||
* A prefix used to generate fresh name.
|
|
||||||
*/
|
|
||||||
var freshNamePrefix = ""
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a term name that is unique within this instance of a `CodeGenerator`.
|
* Returns a term name that is unique within this instance of a `CodeGenerator`.
|
||||||
*
|
*
|
||||||
* (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
|
* (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
|
||||||
* function.)
|
* function.)
|
||||||
*/
|
*/
|
||||||
def freshName(name: String): String = {
|
def freshName(prefix: String): String = {
|
||||||
if (freshNamePrefix == "") {
|
s"$prefix${curId.getAndIncrement}"
|
||||||
s"$name${curId.getAndIncrement}"
|
|
||||||
} else {
|
|
||||||
s"${freshNamePrefix}_$name${curId.getAndIncrement}"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -93,7 +93,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
|
||||||
// Can't call setNullAt on DecimalType, because we need to keep the offset
|
// Can't call setNullAt on DecimalType, because we need to keep the offset
|
||||||
s"""
|
s"""
|
||||||
if (this.isNull_$i) {
|
if (this.isNull_$i) {
|
||||||
${ctx.setColumn("mutableRow", e.dataType, i, "null")};
|
${ctx.setColumn("mutableRow", e.dataType, i, null)};
|
||||||
} else {
|
} else {
|
||||||
${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
|
${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,11 +22,9 @@ import scala.collection.mutable.ArrayBuffer
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.SQLContext
|
import org.apache.spark.sql.SQLContext
|
||||||
import org.apache.spark.sql.catalyst.InternalRow
|
import org.apache.spark.sql.catalyst.InternalRow
|
||||||
import org.apache.spark.sql.catalyst.expressions._
|
import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression}
|
||||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||||
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
|
|
||||||
import org.apache.spark.sql.catalyst.rules.Rule
|
import org.apache.spark.sql.catalyst.rules.Rule
|
||||||
import org.apache.spark.util.Utils
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An interface for those physical operators that support codegen.
|
* An interface for those physical operators that support codegen.
|
||||||
|
@ -44,16 +42,10 @@ trait CodegenSupport extends SparkPlan {
|
||||||
private var parent: CodegenSupport = null
|
private var parent: CodegenSupport = null
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the RDD of InternalRow which generates the input rows.
|
* Returns an input RDD of InternalRow and Java source code to process them.
|
||||||
*/
|
*/
|
||||||
def upstream(): RDD[InternalRow]
|
def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) = {
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns Java source code to process the rows from upstream.
|
|
||||||
*/
|
|
||||||
def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
|
|
||||||
this.parent = parent
|
this.parent = parent
|
||||||
ctx.freshNamePrefix = nodeName
|
|
||||||
doProduce(ctx)
|
doProduce(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -74,41 +66,16 @@ trait CodegenSupport extends SparkPlan {
|
||||||
* # call consume(), wich will call parent.doConsume()
|
* # call consume(), wich will call parent.doConsume()
|
||||||
* }
|
* }
|
||||||
*/
|
*/
|
||||||
protected def doProduce(ctx: CodegenContext): String
|
protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Consume the columns generated from current SparkPlan, call it's parent.
|
* Consume the columns generated from current SparkPlan, call it's parent or create an iterator.
|
||||||
*/
|
*/
|
||||||
def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = {
|
protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = {
|
||||||
if (input != null) {
|
assert(columns.length == output.length)
|
||||||
assert(input.length == output.length)
|
parent.doConsume(ctx, this, columns)
|
||||||
}
|
|
||||||
parent.consumeChild(ctx, this, input, row)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Consume the columns generated from it's child, call doConsume() or emit the rows.
|
|
||||||
*/
|
|
||||||
def consumeChild(
|
|
||||||
ctx: CodegenContext,
|
|
||||||
child: SparkPlan,
|
|
||||||
input: Seq[ExprCode],
|
|
||||||
row: String = null): String = {
|
|
||||||
ctx.freshNamePrefix = nodeName
|
|
||||||
if (row != null) {
|
|
||||||
ctx.currentVars = null
|
|
||||||
ctx.INPUT_ROW = row
|
|
||||||
val evals = child.output.zipWithIndex.map { case (attr, i) =>
|
|
||||||
BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
|
|
||||||
}
|
|
||||||
s"""
|
|
||||||
| ${evals.map(_.code).mkString("\n")}
|
|
||||||
| ${doConsume(ctx, evals)}
|
|
||||||
""".stripMargin
|
|
||||||
} else {
|
|
||||||
doConsume(ctx, input)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate the Java source code to process the rows from child SparkPlan.
|
* Generate the Java source code to process the rows from child SparkPlan.
|
||||||
|
@ -122,9 +89,7 @@ trait CodegenSupport extends SparkPlan {
|
||||||
* # call consume(), which will call parent.doConsume()
|
* # call consume(), which will call parent.doConsume()
|
||||||
* }
|
* }
|
||||||
*/
|
*/
|
||||||
protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
|
def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String
|
||||||
throw new UnsupportedOperationException
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -137,36 +102,31 @@ trait CodegenSupport extends SparkPlan {
|
||||||
case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
|
case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
|
||||||
|
|
||||||
override def output: Seq[Attribute] = child.output
|
override def output: Seq[Attribute] = child.output
|
||||||
override def outputPartitioning: Partitioning = child.outputPartitioning
|
|
||||||
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
|
|
||||||
|
|
||||||
override def doPrepare(): Unit = {
|
override def supportCodegen: Boolean = true
|
||||||
child.prepare()
|
|
||||||
}
|
|
||||||
|
|
||||||
override def doExecute(): RDD[InternalRow] = {
|
override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
|
||||||
child.execute()
|
|
||||||
}
|
|
||||||
|
|
||||||
override def supportCodegen: Boolean = false
|
|
||||||
|
|
||||||
override def upstream(): RDD[InternalRow] = {
|
|
||||||
child.execute()
|
|
||||||
}
|
|
||||||
|
|
||||||
override def doProduce(ctx: CodegenContext): String = {
|
|
||||||
val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
|
val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
|
||||||
val row = ctx.freshName("row")
|
val row = ctx.freshName("row")
|
||||||
ctx.INPUT_ROW = row
|
ctx.INPUT_ROW = row
|
||||||
ctx.currentVars = null
|
ctx.currentVars = null
|
||||||
val columns = exprs.map(_.gen(ctx))
|
val columns = exprs.map(_.gen(ctx))
|
||||||
s"""
|
val code = s"""
|
||||||
| while (input.hasNext()) {
|
| while (input.hasNext()) {
|
||||||
| InternalRow $row = (InternalRow) input.next();
|
| InternalRow $row = (InternalRow) input.next();
|
||||||
| ${columns.map(_.code).mkString("\n")}
|
| ${columns.map(_.code).mkString("\n")}
|
||||||
| ${consume(ctx, columns)}
|
| ${consume(ctx, columns)}
|
||||||
| }
|
| }
|
||||||
""".stripMargin
|
""".stripMargin
|
||||||
|
(child.execute(), code)
|
||||||
|
}
|
||||||
|
|
||||||
|
def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
|
||||||
|
throw new UnsupportedOperationException
|
||||||
|
}
|
||||||
|
|
||||||
|
override def doExecute(): RDD[InternalRow] = {
|
||||||
|
throw new UnsupportedOperationException
|
||||||
}
|
}
|
||||||
|
|
||||||
override def simpleString: String = "INPUT"
|
override def simpleString: String = "INPUT"
|
||||||
|
@ -183,20 +143,16 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
|
||||||
*
|
*
|
||||||
* -> execute()
|
* -> execute()
|
||||||
* |
|
* |
|
||||||
* doExecute() ---------> upstream() -------> upstream() ------> execute()
|
* doExecute() --------> produce()
|
||||||
* |
|
|
||||||
* -----------------> produce()
|
|
||||||
* |
|
* |
|
||||||
* doProduce() -------> produce()
|
* doProduce() -------> produce()
|
||||||
* |
|
* |
|
||||||
* doProduce()
|
* doProduce() ---> execute()
|
||||||
* |
|
* |
|
||||||
* consume()
|
* consume()
|
||||||
* consumeChild() <-----------|
|
* doConsume() ------------|
|
||||||
* |
|
* |
|
||||||
* doConsume()
|
* doConsume() <----- consume()
|
||||||
* |
|
|
||||||
* consumeChild() <----- consume()
|
|
||||||
*
|
*
|
||||||
* SparkPlan A should override doProduce() and doConsume().
|
* SparkPlan A should override doProduce() and doConsume().
|
||||||
*
|
*
|
||||||
|
@ -206,19 +162,11 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
|
||||||
case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
|
case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
|
||||||
extends SparkPlan with CodegenSupport {
|
extends SparkPlan with CodegenSupport {
|
||||||
|
|
||||||
override def supportCodegen: Boolean = false
|
|
||||||
|
|
||||||
override def output: Seq[Attribute] = plan.output
|
override def output: Seq[Attribute] = plan.output
|
||||||
override def outputPartitioning: Partitioning = plan.outputPartitioning
|
|
||||||
override def outputOrdering: Seq[SortOrder] = plan.outputOrdering
|
|
||||||
|
|
||||||
override def doPrepare(): Unit = {
|
|
||||||
plan.prepare()
|
|
||||||
}
|
|
||||||
|
|
||||||
override def doExecute(): RDD[InternalRow] = {
|
override def doExecute(): RDD[InternalRow] = {
|
||||||
val ctx = new CodegenContext
|
val ctx = new CodegenContext
|
||||||
val code = plan.produce(ctx, this)
|
val (rdd, code) = plan.produce(ctx, this)
|
||||||
val references = ctx.references.toArray
|
val references = ctx.references.toArray
|
||||||
val source = s"""
|
val source = s"""
|
||||||
public Object generate(Object[] references) {
|
public Object generate(Object[] references) {
|
||||||
|
@ -229,25 +177,22 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
|
||||||
|
|
||||||
private Object[] references;
|
private Object[] references;
|
||||||
${ctx.declareMutableStates()}
|
${ctx.declareMutableStates()}
|
||||||
${ctx.declareAddedFunctions()}
|
|
||||||
|
|
||||||
public GeneratedIterator(Object[] references) {
|
public GeneratedIterator(Object[] references) {
|
||||||
this.references = references;
|
this.references = references;
|
||||||
${ctx.initMutableStates()}
|
${ctx.initMutableStates()}
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void processNext() throws java.io.IOException {
|
protected void processNext() {
|
||||||
$code
|
$code
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
// try to compile, helpful for debug
|
// try to compile, helpful for debug
|
||||||
// println(s"${CodeFormatter.format(source)}")
|
// println(s"${CodeFormatter.format(source)}")
|
||||||
CodeGenerator.compile(source)
|
CodeGenerator.compile(source)
|
||||||
|
|
||||||
plan.upstream().mapPartitions { iter =>
|
rdd.mapPartitions { iter =>
|
||||||
|
|
||||||
val clazz = CodeGenerator.compile(source)
|
val clazz = CodeGenerator.compile(source)
|
||||||
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
|
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
|
||||||
buffer.setInput(iter)
|
buffer.setInput(iter)
|
||||||
|
@ -258,28 +203,11 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override def upstream(): RDD[InternalRow] = {
|
override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
|
||||||
throw new UnsupportedOperationException
|
throw new UnsupportedOperationException
|
||||||
}
|
}
|
||||||
|
|
||||||
override def doProduce(ctx: CodegenContext): String = {
|
override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
|
||||||
throw new UnsupportedOperationException
|
|
||||||
}
|
|
||||||
|
|
||||||
override def consumeChild(
|
|
||||||
ctx: CodegenContext,
|
|
||||||
child: SparkPlan,
|
|
||||||
input: Seq[ExprCode],
|
|
||||||
row: String = null): String = {
|
|
||||||
|
|
||||||
if (row != null) {
|
|
||||||
// There is an UnsafeRow already
|
|
||||||
s"""
|
|
||||||
| currentRow = $row;
|
|
||||||
| return;
|
|
||||||
""".stripMargin
|
|
||||||
} else {
|
|
||||||
assert(input != null)
|
|
||||||
if (input.nonEmpty) {
|
if (input.nonEmpty) {
|
||||||
val colExprs = output.zipWithIndex.map { case (attr, i) =>
|
val colExprs = output.zipWithIndex.map { case (attr, i) =>
|
||||||
BoundReference(i, attr.dataType, attr.nullable)
|
BoundReference(i, attr.dataType, attr.nullable)
|
||||||
|
@ -300,7 +228,6 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
|
||||||
""".stripMargin
|
""".stripMargin
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
override def generateTreeString(
|
override def generateTreeString(
|
||||||
depth: Int,
|
depth: Int,
|
||||||
|
@ -319,7 +246,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
|
||||||
builder.append(simpleString)
|
builder.append(simpleString)
|
||||||
builder.append("\n")
|
builder.append("\n")
|
||||||
|
|
||||||
plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
|
plan.generateTreeString(depth + 1, lastChildren :+children.isEmpty :+ true, builder)
|
||||||
if (children.nonEmpty) {
|
if (children.nonEmpty) {
|
||||||
children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
|
children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
|
||||||
children.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
|
children.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
|
||||||
|
@ -359,14 +286,13 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
|
||||||
case plan: CodegenSupport if supportCodegen(plan) &&
|
case plan: CodegenSupport if supportCodegen(plan) &&
|
||||||
// Whole stage codegen is only useful when there are at least two levels of operators that
|
// Whole stage codegen is only useful when there are at least two levels of operators that
|
||||||
// support it (save at least one projection/iterator).
|
// support it (save at least one projection/iterator).
|
||||||
(Utils.isTesting || plan.children.exists(supportCodegen)) =>
|
plan.children.exists(supportCodegen) =>
|
||||||
|
|
||||||
var inputs = ArrayBuffer[SparkPlan]()
|
var inputs = ArrayBuffer[SparkPlan]()
|
||||||
val combined = plan.transform {
|
val combined = plan.transform {
|
||||||
case p if !supportCodegen(p) =>
|
case p if !supportCodegen(p) =>
|
||||||
val input = apply(p) // collapse them recursively
|
inputs += p
|
||||||
inputs += input
|
InputAdapter(p)
|
||||||
InputAdapter(input)
|
|
||||||
}.asInstanceOf[CodegenSupport]
|
}.asInstanceOf[CodegenSupport]
|
||||||
WholeStageCodegen(combined, inputs)
|
WholeStageCodegen(combined, inputs)
|
||||||
}
|
}
|
||||||
|
|
|
@ -117,7 +117,9 @@ case class TungstenAggregate(
|
||||||
override def supportCodegen: Boolean = {
|
override def supportCodegen: Boolean = {
|
||||||
groupingExpressions.isEmpty &&
|
groupingExpressions.isEmpty &&
|
||||||
// ImperativeAggregate is not supported right now
|
// ImperativeAggregate is not supported right now
|
||||||
!aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
|
!aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) &&
|
||||||
|
// final aggregation only have one row, do not need to codegen
|
||||||
|
!aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete)
|
||||||
}
|
}
|
||||||
|
|
||||||
// The variables used as aggregation buffer
|
// The variables used as aggregation buffer
|
||||||
|
@ -125,11 +127,7 @@ case class TungstenAggregate(
|
||||||
|
|
||||||
private val modes = aggregateExpressions.map(_.mode).distinct
|
private val modes = aggregateExpressions.map(_.mode).distinct
|
||||||
|
|
||||||
override def upstream(): RDD[InternalRow] = {
|
protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
|
||||||
child.asInstanceOf[CodegenSupport].upstream()
|
|
||||||
}
|
|
||||||
|
|
||||||
protected override def doProduce(ctx: CodegenContext): String = {
|
|
||||||
val initAgg = ctx.freshName("initAgg")
|
val initAgg = ctx.freshName("initAgg")
|
||||||
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
|
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
|
||||||
|
|
||||||
|
@ -139,80 +137,50 @@ case class TungstenAggregate(
|
||||||
bufVars = initExpr.map { e =>
|
bufVars = initExpr.map { e =>
|
||||||
val isNull = ctx.freshName("bufIsNull")
|
val isNull = ctx.freshName("bufIsNull")
|
||||||
val value = ctx.freshName("bufValue")
|
val value = ctx.freshName("bufValue")
|
||||||
ctx.addMutableState("boolean", isNull, "")
|
|
||||||
ctx.addMutableState(ctx.javaType(e.dataType), value, "")
|
|
||||||
// The initial expression should not access any column
|
// The initial expression should not access any column
|
||||||
val ev = e.gen(ctx)
|
val ev = e.gen(ctx)
|
||||||
val initVars = s"""
|
val initVars = s"""
|
||||||
| $isNull = ${ev.isNull};
|
| boolean $isNull = ${ev.isNull};
|
||||||
| $value = ${ev.value};
|
| ${ctx.javaType(e.dataType)} $value = ${ev.value};
|
||||||
""".stripMargin
|
""".stripMargin
|
||||||
ExprCode(ev.code + initVars, isNull, value)
|
ExprCode(ev.code + initVars, isNull, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
// generate variables for output
|
val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this)
|
||||||
val (resultVars, genResult) = if (modes.contains(Final) | modes.contains(Complete)) {
|
val source =
|
||||||
// evaluate aggregate results
|
|
||||||
ctx.currentVars = bufVars
|
|
||||||
val bufferAttrs = functions.flatMap(_.aggBufferAttributes)
|
|
||||||
val aggResults = functions.map(_.evaluateExpression).map { e =>
|
|
||||||
BindReferences.bindReference(e, bufferAttrs).gen(ctx)
|
|
||||||
}
|
|
||||||
// evaluate result expressions
|
|
||||||
ctx.currentVars = aggResults
|
|
||||||
val resultVars = resultExpressions.map { e =>
|
|
||||||
BindReferences.bindReference(e, aggregateAttributes).gen(ctx)
|
|
||||||
}
|
|
||||||
(resultVars, s"""
|
|
||||||
| ${aggResults.map(_.code).mkString("\n")}
|
|
||||||
| ${resultVars.map(_.code).mkString("\n")}
|
|
||||||
""".stripMargin)
|
|
||||||
} else {
|
|
||||||
// output the aggregate buffer directly
|
|
||||||
(bufVars, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
val doAgg = ctx.freshName("doAgg")
|
|
||||||
ctx.addNewFunction(doAgg,
|
|
||||||
s"""
|
|
||||||
| private void $doAgg() {
|
|
||||||
| // initialize aggregation buffer
|
|
||||||
| ${bufVars.map(_.code).mkString("\n")}
|
|
||||||
|
|
|
||||||
| ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
|
|
||||||
| }
|
|
||||||
""".stripMargin)
|
|
||||||
|
|
||||||
s"""
|
s"""
|
||||||
| if (!$initAgg) {
|
| if (!$initAgg) {
|
||||||
| $initAgg = true;
|
| $initAgg = true;
|
||||||
| $doAgg();
|
|
|
||||||
|
| // initialize aggregation buffer
|
||||||
|
| ${bufVars.map(_.code).mkString("\n")}
|
||||||
|
|
|
||||||
|
| $childSource
|
||||||
|
|
|
|
||||||
| // output the result
|
| // output the result
|
||||||
| $genResult
|
| ${consume(ctx, bufVars)}
|
||||||
|
|
|
||||||
| ${consume(ctx, resultVars)}
|
|
||||||
| }
|
| }
|
||||||
""".stripMargin
|
""".stripMargin
|
||||||
|
|
||||||
|
(rdd, source)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
|
override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
|
||||||
// only have DeclarativeAggregate
|
// only have DeclarativeAggregate
|
||||||
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
|
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
|
||||||
val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
|
// the mode could be only Partial or PartialMerge
|
||||||
val updateExpr = aggregateExpressions.flatMap { e =>
|
val updateExpr = if (modes.contains(Partial)) {
|
||||||
e.mode match {
|
functions.flatMap(_.updateExpressions)
|
||||||
case Partial | Complete =>
|
} else {
|
||||||
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
|
functions.flatMap(_.mergeExpressions)
|
||||||
case PartialMerge | Final =>
|
|
||||||
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output
|
||||||
|
val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr))
|
||||||
ctx.currentVars = bufVars ++ input
|
ctx.currentVars = bufVars ++ input
|
||||||
// TODO: support subexpression elimination
|
// TODO: support subexpression elimination
|
||||||
val updates = updateExpr.zipWithIndex.map { case (e, i) =>
|
val codes = boundExpr.zipWithIndex.map { case (e, i) =>
|
||||||
val ev = BindReferences.bindReference[Expression](e, inputAttrs).gen(ctx)
|
val ev = e.gen(ctx)
|
||||||
s"""
|
s"""
|
||||||
| ${ev.code}
|
| ${ev.code}
|
||||||
| ${bufVars(i).isNull} = ${ev.isNull};
|
| ${bufVars(i).isNull} = ${ev.isNull};
|
||||||
|
@ -222,7 +190,7 @@ case class TungstenAggregate(
|
||||||
|
|
||||||
s"""
|
s"""
|
||||||
| // do aggregate and update aggregation buffer
|
| // do aggregate and update aggregation buffer
|
||||||
| ${updates.mkString("")}
|
| ${codes.mkString("")}
|
||||||
""".stripMargin
|
""".stripMargin
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -37,15 +37,11 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
|
||||||
|
|
||||||
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
|
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
|
||||||
|
|
||||||
override def upstream(): RDD[InternalRow] = {
|
protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
|
||||||
child.asInstanceOf[CodegenSupport].upstream()
|
|
||||||
}
|
|
||||||
|
|
||||||
protected override def doProduce(ctx: CodegenContext): String = {
|
|
||||||
child.asInstanceOf[CodegenSupport].produce(ctx, this)
|
child.asInstanceOf[CodegenSupport].produce(ctx, this)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
|
override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
|
||||||
val exprs = projectList.map(x =>
|
val exprs = projectList.map(x =>
|
||||||
ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
|
ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
|
||||||
ctx.currentVars = input
|
ctx.currentVars = input
|
||||||
|
@ -80,15 +76,11 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
|
||||||
"numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
|
"numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
|
||||||
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
|
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
|
||||||
|
|
||||||
override def upstream(): RDD[InternalRow] = {
|
protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
|
||||||
child.asInstanceOf[CodegenSupport].upstream()
|
|
||||||
}
|
|
||||||
|
|
||||||
protected override def doProduce(ctx: CodegenContext): String = {
|
|
||||||
child.asInstanceOf[CodegenSupport].produce(ctx, this)
|
child.asInstanceOf[CodegenSupport].produce(ctx, this)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
|
override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
|
||||||
val expr = ExpressionCanonicalizer.execute(
|
val expr = ExpressionCanonicalizer.execute(
|
||||||
BindReferences.bindReference(condition, child.output))
|
BindReferences.bindReference(condition, child.output))
|
||||||
ctx.currentVars = input
|
ctx.currentVars = input
|
||||||
|
@ -161,21 +153,17 @@ case class Range(
|
||||||
output: Seq[Attribute])
|
output: Seq[Attribute])
|
||||||
extends LeafNode with CodegenSupport {
|
extends LeafNode with CodegenSupport {
|
||||||
|
|
||||||
override def upstream(): RDD[InternalRow] = {
|
protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
|
||||||
sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i))
|
val initTerm = ctx.freshName("range_initRange")
|
||||||
}
|
|
||||||
|
|
||||||
protected override def doProduce(ctx: CodegenContext): String = {
|
|
||||||
val initTerm = ctx.freshName("initRange")
|
|
||||||
ctx.addMutableState("boolean", initTerm, s"$initTerm = false;")
|
ctx.addMutableState("boolean", initTerm, s"$initTerm = false;")
|
||||||
val partitionEnd = ctx.freshName("partitionEnd")
|
val partitionEnd = ctx.freshName("range_partitionEnd")
|
||||||
ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;")
|
ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;")
|
||||||
val number = ctx.freshName("number")
|
val number = ctx.freshName("range_number")
|
||||||
ctx.addMutableState("long", number, s"$number = 0L;")
|
ctx.addMutableState("long", number, s"$number = 0L;")
|
||||||
val overflow = ctx.freshName("overflow")
|
val overflow = ctx.freshName("range_overflow")
|
||||||
ctx.addMutableState("boolean", overflow, s"$overflow = false;")
|
ctx.addMutableState("boolean", overflow, s"$overflow = false;")
|
||||||
|
|
||||||
val value = ctx.freshName("value")
|
val value = ctx.freshName("range_value")
|
||||||
val ev = ExprCode("", "false", value)
|
val ev = ExprCode("", "false", value)
|
||||||
val BigInt = classOf[java.math.BigInteger].getName
|
val BigInt = classOf[java.math.BigInteger].getName
|
||||||
val checkEnd = if (step > 0) {
|
val checkEnd = if (step > 0) {
|
||||||
|
@ -184,10 +172,15 @@ case class Range(
|
||||||
s"$number > $partitionEnd"
|
s"$number > $partitionEnd"
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.addNewFunction("initRange",
|
val rdd = sqlContext.sparkContext.parallelize(0 until numSlices, numSlices)
|
||||||
s"""
|
.map(i => InternalRow(i))
|
||||||
| private void initRange(int idx) {
|
|
||||||
| $BigInt index = $BigInt.valueOf(idx);
|
val code = s"""
|
||||||
|
| // initialize Range
|
||||||
|
| if (!$initTerm) {
|
||||||
|
| $initTerm = true;
|
||||||
|
| if (input.hasNext()) {
|
||||||
|
| $BigInt index = $BigInt.valueOf(((InternalRow) input.next()).getInt(0));
|
||||||
| $BigInt numSlice = $BigInt.valueOf(${numSlices}L);
|
| $BigInt numSlice = $BigInt.valueOf(${numSlices}L);
|
||||||
| $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
|
| $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
|
||||||
| $BigInt step = $BigInt.valueOf(${step}L);
|
| $BigInt step = $BigInt.valueOf(${step}L);
|
||||||
|
@ -211,15 +204,6 @@ case class Range(
|
||||||
| } else {
|
| } else {
|
||||||
| $partitionEnd = end.longValue();
|
| $partitionEnd = end.longValue();
|
||||||
| }
|
| }
|
||||||
| }
|
|
||||||
""".stripMargin)
|
|
||||||
|
|
||||||
s"""
|
|
||||||
| // initialize Range
|
|
||||||
| if (!$initTerm) {
|
|
||||||
| $initTerm = true;
|
|
||||||
| if (input.hasNext()) {
|
|
||||||
| initRange(((InternalRow) input.next()).getInt(0));
|
|
||||||
| } else {
|
| } else {
|
||||||
| return;
|
| return;
|
||||||
| }
|
| }
|
||||||
|
@ -234,6 +218,12 @@ case class Range(
|
||||||
| ${consume(ctx, Seq(ev))}
|
| ${consume(ctx, Seq(ev))}
|
||||||
| }
|
| }
|
||||||
""".stripMargin
|
""".stripMargin
|
||||||
|
|
||||||
|
(rdd, code)
|
||||||
|
}
|
||||||
|
|
||||||
|
def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
|
||||||
|
throw new UnsupportedOperationException
|
||||||
}
|
}
|
||||||
|
|
||||||
protected override def doExecute(): RDD[InternalRow] = {
|
protected override def doExecute(): RDD[InternalRow] = {
|
||||||
|
|
|
@ -1939,8 +1939,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Common subexpression elimination") {
|
test("Common subexpression elimination") {
|
||||||
// TODO: support subexpression elimination in whole stage codegen
|
|
||||||
withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
|
|
||||||
// select from a table to prevent constant folding.
|
// select from a table to prevent constant folding.
|
||||||
val df = sql("SELECT a, b from testData2 limit 1")
|
val df = sql("SELECT a, b from testData2 limit 1")
|
||||||
checkAnswer(df, Row(1, 1))
|
checkAnswer(df, Row(1, 1))
|
||||||
|
@ -1994,7 +1992,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
|
||||||
sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true")
|
sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true")
|
||||||
verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
|
verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
test("SPARK-10707: nullability should be correctly propagated through set operations (1)") {
|
test("SPARK-10707: nullability should be correctly propagated through set operations (1)") {
|
||||||
// This test produced an incorrect result of 1 before the SPARK-10707 fix because of the
|
// This test produced an incorrect result of 1 before the SPARK-10707 fix because of the
|
||||||
|
|
|
@ -335,7 +335,6 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
|
||||||
|
|
||||||
test("save metrics") {
|
test("save metrics") {
|
||||||
withTempPath { file =>
|
withTempPath { file =>
|
||||||
withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
|
|
||||||
val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
|
val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
|
||||||
// Assume the execution plan is
|
// Assume the execution plan is
|
||||||
// PhysicalRDD(nodeId = 0)
|
// PhysicalRDD(nodeId = 0)
|
||||||
|
@ -354,7 +353,6 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
|
||||||
assert(metricValues.values.toSeq === Seq("2"))
|
assert(metricValues.values.toSeq === Seq("2"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -199,7 +199,7 @@ private[sql] trait SQLTestUtils
|
||||||
val schema = df.schema
|
val schema = df.schema
|
||||||
val childRDD = df
|
val childRDD = df
|
||||||
.queryExecution
|
.queryExecution
|
||||||
.sparkPlan.asInstanceOf[org.apache.spark.sql.execution.Filter]
|
.executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter]
|
||||||
.child
|
.child
|
||||||
.execute()
|
.execute()
|
||||||
.map(row => Row.fromSeq(row.copy().toSeq(schema)))
|
.map(row => Row.fromSeq(row.copy().toSeq(schema)))
|
||||||
|
|
|
@ -97,12 +97,10 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
|
||||||
}
|
}
|
||||||
sqlContext.listenerManager.register(listener)
|
sqlContext.listenerManager.register(listener)
|
||||||
|
|
||||||
withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
|
|
||||||
val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
|
val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
|
||||||
df.collect()
|
df.collect()
|
||||||
df.collect()
|
df.collect()
|
||||||
Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
|
Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
|
||||||
}
|
|
||||||
|
|
||||||
assert(metrics.length == 3)
|
assert(metrics.length == 3)
|
||||||
assert(metrics(0) == 1)
|
assert(metrics(0) == 1)
|
||||||
|
|
Loading…
Reference in a new issue