Revert "[SPARK-13031] [SQL] cleanup codegen and improve test coverage"

This reverts commit cc18a71992.
Davies Liu 2016-01-28 17:01:12 -08:00
parent 4637fc08a3
commit b9dfdcc63b
9 changed files with 199 additions and 331 deletions
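For orientation before the per-file hunks: the revert swaps the CodegenSupport contract back from the split upstream()/doProduce() shape introduced by SPARK-13031 to the combined form where produce() returns the input RDD together with the generated source, and doConsume() also receives the child plan. A minimal sketch of the two shapes, using only the signatures visible in the hunks below (the trait names here are illustrative, not Spark's):

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
    import org.apache.spark.sql.execution.SparkPlan

    // Shape introduced by SPARK-13031 and removed by this revert: the input RDD and
    // the generated Java source are obtained through separate calls.
    trait CodegenSupportBeforeRevert {
      def upstream(): RDD[InternalRow]
      protected def doProduce(ctx: CodegenContext): String
      protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String
    }

    // Shape restored by this revert: doProduce() hands back the RDD together with the
    // code, and doConsume() also gets the child plan it is consuming from.
    trait CodegenSupportAfterRevert {
      protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String)
      def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String
    }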


@@ -144,23 +144,14 @@ class CodegenContext {

   private val curId = new java.util.concurrent.atomic.AtomicInteger()

-  /**
-   * A prefix used to generate fresh name.
-   */
-  var freshNamePrefix = ""
-
   /**
    * Returns a term name that is unique within this instance of a `CodeGenerator`.
    *
    * (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
    * function.)
    */
-  def freshName(name: String): String = {
-    if (freshNamePrefix == "") {
-      s"$name${curId.getAndIncrement}"
-    } else {
-      s"${freshNamePrefix}_$name${curId.getAndIncrement}"
-    }
+  def freshName(prefix: String): String = {
+    s"$prefix${curId.getAndIncrement}"
   }

   /**
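The freshName() change is easiest to see in isolation. Below is a small standalone mock of both variants; the class and method names are illustrative, and only the two method bodies mirror the hunk above. Under the removed variant, setting freshNamePrefix = "filter" makes a call like freshNameWithPrefix("value") return something like filter_value0; under the restored variant, callers bake the prefix into the argument themselves, e.g. freshName("range_value") as in the Range operator further down.

    import java.util.concurrent.atomic.AtomicInteger

    class FreshNameDemo {
      private val curId = new AtomicInteger()

      // Variant removed by this revert: an operator sets freshNamePrefix once and every
      // generated name is qualified with it.
      var freshNamePrefix = ""
      def freshNameWithPrefix(name: String): String =
        if (freshNamePrefix == "") s"$name${curId.getAndIncrement}"
        else s"${freshNamePrefix}_$name${curId.getAndIncrement}"

      // Variant restored by this revert: the caller passes the full prefix directly.
      def freshName(prefix: String): String = s"$prefix${curId.getAndIncrement}"
    }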


@@ -93,7 +93,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
           // Can't call setNullAt on DecimalType, because we need to keep the offset
           s"""
             if (this.isNull_$i) {
-              ${ctx.setColumn("mutableRow", e.dataType, i, "null")};
+              ${ctx.setColumn("mutableRow", e.dataType, i, null)};
             } else {
               ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
             }


@@ -22,11 +22,9 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.util.Utils

 /**
  * An interface for those physical operators that support codegen.
@@ -44,16 +42,10 @@ trait CodegenSupport extends SparkPlan {
   private var parent: CodegenSupport = null

   /**
-   * Returns the RDD of InternalRow which generates the input rows.
+   * Returns an input RDD of InternalRow and Java source code to process them.
    */
-  def upstream(): RDD[InternalRow]
-
-  /**
-   * Returns Java source code to process the rows from upstream.
-   */
-  def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
+  def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) = {
     this.parent = parent
-    ctx.freshNamePrefix = nodeName
     doProduce(ctx)
   }
@@ -74,41 +66,16 @@ trait CodegenSupport extends SparkPlan {
    *   # call consume(), wich will call parent.doConsume()
    *  }
    */
-  protected def doProduce(ctx: CodegenContext): String
+  protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String)

   /**
-   * Consume the columns generated from current SparkPlan, call it's parent.
+   * Consume the columns generated from current SparkPlan, call it's parent or create an iterator.
    */
-  def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = {
-    if (input != null) {
-      assert(input.length == output.length)
-    }
-    parent.consumeChild(ctx, this, input, row)
+  protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = {
+    assert(columns.length == output.length)
+    parent.doConsume(ctx, this, columns)
   }

-  /**
-   * Consume the columns generated from it's child, call doConsume() or emit the rows.
-   */
-  def consumeChild(
-      ctx: CodegenContext,
-      child: SparkPlan,
-      input: Seq[ExprCode],
-      row: String = null): String = {
-    ctx.freshNamePrefix = nodeName
-    if (row != null) {
-      ctx.currentVars = null
-      ctx.INPUT_ROW = row
-      val evals = child.output.zipWithIndex.map { case (attr, i) =>
-        BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
-      }
-      s"""
-         | ${evals.map(_.code).mkString("\n")}
-         | ${doConsume(ctx, evals)}
-       """.stripMargin
-    } else {
-      doConsume(ctx, input)
-    }
-  }
-
   /**
    * Generate the Java source code to process the rows from child SparkPlan.
@@ -122,9 +89,7 @@ trait CodegenSupport extends SparkPlan {
    *   # call consume(), which will call parent.doConsume()
    *  }
    */
-  protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
-    throw new UnsupportedOperationException
-  }
+  def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String
 }
@@ -137,36 +102,31 @@ trait CodegenSupport extends SparkPlan {
 case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {

   override def output: Seq[Attribute] = child.output
-  override def outputPartitioning: Partitioning = child.outputPartitioning
-  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
-
-  override def doPrepare(): Unit = {
-    child.prepare()
-  }
-
-  override def doExecute(): RDD[InternalRow] = {
-    child.execute()
-  }
-
-  override def supportCodegen: Boolean = false

-  override def upstream(): RDD[InternalRow] = {
-    child.execute()
-  }
+  override def supportCodegen: Boolean = true

-  override def doProduce(ctx: CodegenContext): String = {
+  override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
     val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
     val row = ctx.freshName("row")
     ctx.INPUT_ROW = row
     ctx.currentVars = null
     val columns = exprs.map(_.gen(ctx))
-    s"""
+    val code = s"""
        | while (input.hasNext()) {
        |   InternalRow $row = (InternalRow) input.next();
        |   ${columns.map(_.code).mkString("\n")}
        |   ${consume(ctx, columns)}
        | }
      """.stripMargin
+    (child.execute(), code)
+  }
+
+  def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+    throw new UnsupportedOperationException
+  }
+
+  override def doExecute(): RDD[InternalRow] = {
+    throw new UnsupportedOperationException
   }

   override def simpleString: String = "INPUT"
@@ -183,20 +143,16 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
  *
  * -> execute()
  *     |
- *  doExecute() --------->  upstream() -------> upstream() ------> execute()
- *                                                  |
- *     -----------------> produce()
+ *  doExecute() -------->  produce()
  *                             |
  *                          doProduce() -------> produce()
  *                                                   |
- *                                                 doProduce()
+ *                                                doProduce() ---> execute()
  *                                                   |
  *                                                consume()
- *                          consumeChild() <-----------|
+ *                          doConsume() ------------|
  *                             |
- *  doConsume()
- *                             |
- *  consumeChild() <----- consume()
+ *  doConsume() <----- consume()
  *
  * SparkPlan A should override doProduce() and doConsume().
 *
@@ -206,19 +162,11 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
 case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
   extends SparkPlan with CodegenSupport {

-  override def supportCodegen: Boolean = false
-
   override def output: Seq[Attribute] = plan.output
-  override def outputPartitioning: Partitioning = plan.outputPartitioning
-  override def outputOrdering: Seq[SortOrder] = plan.outputOrdering
-
-  override def doPrepare(): Unit = {
-    plan.prepare()
-  }

   override def doExecute(): RDD[InternalRow] = {
     val ctx = new CodegenContext
-    val code = plan.produce(ctx, this)
+    val (rdd, code) = plan.produce(ctx, this)
     val references = ctx.references.toArray
     val source = s"""
       public Object generate(Object[] references) {
@@ -229,25 +177,22 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
        private Object[] references;
        ${ctx.declareMutableStates()}
-       ${ctx.declareAddedFunctions()}

        public GeneratedIterator(Object[] references) {
          this.references = references;
          ${ctx.initMutableStates()}
        }

-       protected void processNext() throws java.io.IOException {
+       protected void processNext() {
          $code
        }
      }
      """

     // try to compile, helpful for debug
     // println(s"${CodeFormatter.format(source)}")
     CodeGenerator.compile(source)

-    plan.upstream().mapPartitions { iter =>
+    rdd.mapPartitions { iter =>
       val clazz = CodeGenerator.compile(source)
       val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
       buffer.setInput(iter)
@@ -258,28 +203,11 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
     }
   }

-  override def upstream(): RDD[InternalRow] = {
+  override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
     throw new UnsupportedOperationException
   }

-  override def doProduce(ctx: CodegenContext): String = {
-    throw new UnsupportedOperationException
-  }
-
-  override def consumeChild(
-      ctx: CodegenContext,
-      child: SparkPlan,
-      input: Seq[ExprCode],
-      row: String = null): String = {
-    if (row != null) {
-      // There is an UnsafeRow already
-      s"""
-         | currentRow = $row;
-         | return;
-       """.stripMargin
-    } else {
-      assert(input != null)
+  override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
     if (input.nonEmpty) {
       val colExprs = output.zipWithIndex.map { case (attr, i) =>
         BoundReference(i, attr.dataType, attr.nullable)
@@ -300,7 +228,6 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
       """.stripMargin
     }
   }
-  }

   override def generateTreeString(
       depth: Int,
@@ -319,7 +246,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
     builder.append(simpleString)
     builder.append("\n")

-    plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
+    plan.generateTreeString(depth + 1, lastChildren :+ children.isEmpty :+ true, builder)
     if (children.nonEmpty) {
       children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
       children.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
@@ -359,14 +286,13 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
       case plan: CodegenSupport if supportCodegen(plan) &&
         // Whole stage codegen is only useful when there are at least two levels of operators that
         // support it (save at least one projection/iterator).
-        (Utils.isTesting || plan.children.exists(supportCodegen)) =>
+        plan.children.exists(supportCodegen) =>

         var inputs = ArrayBuffer[SparkPlan]()
         val combined = plan.transform {
           case p if !supportCodegen(p) =>
-            val input = apply(p) // collapse them recursively
-            inputs += input
-            InputAdapter(input)
+            inputs += p
+            InputAdapter(p)
         }.asInstanceOf[CodegenSupport]
         WholeStageCodegen(combined, inputs)
     }
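Putting the WholeStageCodegen hunks together, the restored doExecute() flow is roughly the sketch below. It is a fragment, not the full class: wrapAsGeneratedIterator is a hypothetical stand-in for the inlined GeneratedIterator template shown above, and the final iterator wrapper continues past the point where the hunk is cut off.

    // Sketch only; assumes the surrounding WholeStageCodegen class and imports.
    override def doExecute(): RDD[InternalRow] = {
      val ctx = new CodegenContext
      val (rdd, code) = plan.produce(ctx, this)        // child chain emits its loop into `code`
      val references = ctx.references.toArray
      val source = wrapAsGeneratedIterator(ctx, code)  // hypothetical helper: processNext() { $code }
      CodeGenerator.compile(source)                    // compile once on the driver to fail fast
      rdd.mapPartitions { iter =>
        val clazz = CodeGenerator.compile(source)      // compiled (from cache) on each executor
        val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
        buffer.setInput(iter)
        new Iterator[InternalRow] {
          override def hasNext: Boolean = buffer.hasNext
          override def next(): InternalRow = buffer.next()
        }
      }
    }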


@@ -117,7 +117,9 @@ case class TungstenAggregate(
   override def supportCodegen: Boolean = {
     groupingExpressions.isEmpty &&
       // ImperativeAggregate is not supported right now
-      !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
+      !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) &&
+      // final aggregation only have one row, do not need to codegen
+      !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete)
   }

   // The variables used as aggregation buffer
@@ -125,11 +127,7 @@ case class TungstenAggregate(
   private val modes = aggregateExpressions.map(_.mode).distinct

-  override def upstream(): RDD[InternalRow] = {
-    child.asInstanceOf[CodegenSupport].upstream()
-  }
-
-  protected override def doProduce(ctx: CodegenContext): String = {
+  protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
     val initAgg = ctx.freshName("initAgg")
     ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
@@ -139,80 +137,50 @@ case class TungstenAggregate(
     bufVars = initExpr.map { e =>
       val isNull = ctx.freshName("bufIsNull")
       val value = ctx.freshName("bufValue")
-      ctx.addMutableState("boolean", isNull, "")
-      ctx.addMutableState(ctx.javaType(e.dataType), value, "")
       // The initial expression should not access any column
       val ev = e.gen(ctx)
       val initVars = s"""
-         | $isNull = ${ev.isNull};
-         | $value = ${ev.value};
+         | boolean $isNull = ${ev.isNull};
+         | ${ctx.javaType(e.dataType)} $value = ${ev.value};
       """.stripMargin
       ExprCode(ev.code + initVars, isNull, value)
     }

-    // generate variables for output
-    val (resultVars, genResult) = if (modes.contains(Final) | modes.contains(Complete)) {
-      // evaluate aggregate results
-      ctx.currentVars = bufVars
-      val bufferAttrs = functions.flatMap(_.aggBufferAttributes)
-      val aggResults = functions.map(_.evaluateExpression).map { e =>
-        BindReferences.bindReference(e, bufferAttrs).gen(ctx)
-      }
-      // evaluate result expressions
-      ctx.currentVars = aggResults
-      val resultVars = resultExpressions.map { e =>
-        BindReferences.bindReference(e, aggregateAttributes).gen(ctx)
-      }
-      (resultVars, s"""
-         | ${aggResults.map(_.code).mkString("\n")}
-         | ${resultVars.map(_.code).mkString("\n")}
-       """.stripMargin)
-    } else {
-      // output the aggregate buffer directly
-      (bufVars, "")
-    }
-
-    val doAgg = ctx.freshName("doAgg")
-    ctx.addNewFunction(doAgg,
-      s"""
-         | private void $doAgg() {
-         |   // initialize aggregation buffer
-         |   ${bufVars.map(_.code).mkString("\n")}
-         |
-         |   ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
-         | }
-       """.stripMargin)
-
+    val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this)
+    val source =
     s"""
       | if (!$initAgg) {
       |   $initAgg = true;
-      |   $doAgg();
+      |
+      |   // initialize aggregation buffer
+      |   ${bufVars.map(_.code).mkString("\n")}
+      |
+      |   $childSource
       |
       |   // output the result
-      |   $genResult
-      |
-      |   ${consume(ctx, resultVars)}
+      |   ${consume(ctx, bufVars)}
       | }
     """.stripMargin
+    (rdd, source)
   }

-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+  override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
     // only have DeclarativeAggregate
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
-    val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
-    val updateExpr = aggregateExpressions.flatMap { e =>
-      e.mode match {
-        case Partial | Complete =>
-          e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
-        case PartialMerge | Final =>
-          e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
-      }
+    // the mode could be only Partial or PartialMerge
+    val updateExpr = if (modes.contains(Partial)) {
+      functions.flatMap(_.updateExpressions)
+    } else {
+      functions.flatMap(_.mergeExpressions)
     }
+    val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output
+    val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr))
     ctx.currentVars = bufVars ++ input
     // TODO: support subexpression elimination
-    val updates = updateExpr.zipWithIndex.map { case (e, i) =>
-      val ev = BindReferences.bindReference[Expression](e, inputAttrs).gen(ctx)
+    val codes = boundExpr.zipWithIndex.map { case (e, i) =>
+      val ev = e.gen(ctx)
       s"""
          | ${ev.code}
          | ${bufVars(i).isNull} = ${ev.isNull};
@@ -222,7 +190,7 @@ case class TungstenAggregate(
     s"""
        | // do aggregate and update aggregation buffer
-       | ${updates.mkString("")}
+       | ${codes.mkString("")}
     """.stripMargin
   }
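With Final and Complete excluded again in supportCodegen, the codegen path only ever sees Partial or PartialMerge, so the restored doConsume() picks its update expressions with a single branch. A condensed restatement of that selection from the hunk above, as a fragment with comments added:

    // functions: the DeclarativeAggregate instances extracted above
    val updateExpr =
      if (modes.contains(Partial)) {
        functions.flatMap(_.updateExpressions)   // fold raw input rows into the buffer
      } else {
        functions.flatMap(_.mergeExpressions)    // merge pre-aggregated buffers (PartialMerge)
      }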


@@ -37,15 +37,11 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
   override def output: Seq[Attribute] = projectList.map(_.toAttribute)

-  override def upstream(): RDD[InternalRow] = {
-    child.asInstanceOf[CodegenSupport].upstream()
-  }
-
-  protected override def doProduce(ctx: CodegenContext): String = {
+  protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }

-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+  override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
     val exprs = projectList.map(x =>
       ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
     ctx.currentVars = input
@@ -80,15 +76,11 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
     "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))

-  override def upstream(): RDD[InternalRow] = {
-    child.asInstanceOf[CodegenSupport].upstream()
-  }
-
-  protected override def doProduce(ctx: CodegenContext): String = {
+  protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }

-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+  override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
     val expr = ExpressionCanonicalizer.execute(
       BindReferences.bindReference(condition, child.output))
     ctx.currentVars = input
@@ -161,21 +153,17 @@ case class Range(
     output: Seq[Attribute])
   extends LeafNode with CodegenSupport {

-  override def upstream(): RDD[InternalRow] = {
-    sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i))
-  }
-
-  protected override def doProduce(ctx: CodegenContext): String = {
-    val initTerm = ctx.freshName("initRange")
+  protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+    val initTerm = ctx.freshName("range_initRange")
     ctx.addMutableState("boolean", initTerm, s"$initTerm = false;")
-    val partitionEnd = ctx.freshName("partitionEnd")
+    val partitionEnd = ctx.freshName("range_partitionEnd")
     ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;")
-    val number = ctx.freshName("number")
+    val number = ctx.freshName("range_number")
     ctx.addMutableState("long", number, s"$number = 0L;")
-    val overflow = ctx.freshName("overflow")
+    val overflow = ctx.freshName("range_overflow")
     ctx.addMutableState("boolean", overflow, s"$overflow = false;")
-    val value = ctx.freshName("value")
+    val value = ctx.freshName("range_value")
     val ev = ExprCode("", "false", value)
     val BigInt = classOf[java.math.BigInteger].getName
     val checkEnd = if (step > 0) {
@@ -184,10 +172,15 @@ case class Range(
       s"$number > $partitionEnd"
     }

-    ctx.addNewFunction("initRange",
-      s"""
-        | private void initRange(int idx) {
-        |   $BigInt index = $BigInt.valueOf(idx);
+    val rdd = sqlContext.sparkContext.parallelize(0 until numSlices, numSlices)
+      .map(i => InternalRow(i))
+
+    val code = s"""
+      | // initialize Range
+      | if (!$initTerm) {
+      |   $initTerm = true;
+      |   if (input.hasNext()) {
+      |     $BigInt index = $BigInt.valueOf(((InternalRow) input.next()).getInt(0));
       |     $BigInt numSlice = $BigInt.valueOf(${numSlices}L);
       |     $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
       |     $BigInt step = $BigInt.valueOf(${step}L);
@@ -211,15 +204,6 @@ case class Range(
       |   } else {
       |     $partitionEnd = end.longValue();
       |   }
-      | }
-      """.stripMargin)
-
-    s"""
-      | // initialize Range
-      | if (!$initTerm) {
-      |   $initTerm = true;
-      |   if (input.hasNext()) {
-      |     initRange(((InternalRow) input.next()).getInt(0));
       |   } else {
       |     return;
       |   }
@@ -234,6 +218,12 @@ case class Range(
       |   ${consume(ctx, Seq(ev))}
       | }
      """.stripMargin
+    (rdd, code)
+  }
+
+  def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+    throw new UnsupportedOperationException
   }

   protected override def doExecute(): RDD[InternalRow] = {
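Project and Filter above show the pattern a unary operator follows under the restored interface: doProduce() delegates to the child, doConsume() emits per-row code. A hypothetical minimal operator written against that contract, assuming it sits in org.apache.spark.sql.execution next to the operators above (it is not part of this commit):

    package org.apache.spark.sql.execution   // assumption: colocated with the operators above

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.Attribute
    import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}

    // Illustrative operator only: forwards its child's rows unchanged.
    case class PassThrough(child: SparkPlan) extends UnaryNode with CodegenSupport {
      override def output: Seq[Attribute] = child.output

      // ask the child to drive the loop; it calls back into doConsume() for every row
      protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) =
        child.asInstanceOf[CodegenSupport].produce(ctx, this)

      // forward the child's columns unchanged to this operator's parent
      override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String =
        consume(ctx, input)

      protected override def doExecute(): RDD[InternalRow] =
        throw new UnsupportedOperationException
    }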


@@ -1939,8 +1939,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }

   test("Common subexpression elimination") {
-    // TODO: support subexpression elimination in whole stage codegen
-    withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
     // select from a table to prevent constant folding.
     val df = sql("SELECT a, b from testData2 limit 1")
     checkAnswer(df, Row(1, 1))
@@ -1994,7 +1992,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true")
     verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
   }
-  }

   test("SPARK-10707: nullability should be correctly propagated through set operations (1)") {
     // This test produced an incorrect result of 1 before the SPARK-10707 fix because of the


@@ -335,7 +335,6 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
   test("save metrics") {
     withTempPath { file =>
-      withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
       val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
       // Assume the execution plan is
       // PhysicalRDD(nodeId = 0)
@@ -354,7 +353,6 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
       assert(metricValues.values.toSeq === Seq("2"))
     }
   }
-  }
 }


@@ -199,7 +199,7 @@ private[sql] trait SQLTestUtils
     val schema = df.schema
     val childRDD = df
       .queryExecution
-      .sparkPlan.asInstanceOf[org.apache.spark.sql.execution.Filter]
+      .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter]
       .child
       .execute()
       .map(row => Row.fromSeq(row.copy().toSeq(schema)))


@@ -97,12 +97,10 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
     }
     sqlContext.listenerManager.register(listener)

-    withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
     val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
     df.collect()
     df.collect()
     Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
-    }

     assert(metrics.length == 3)
     assert(metrics(0) == 1)