[SPARK-15214][SQL] Code-generation for Generate
## What changes were proposed in this pull request?

This PR adds code generation to `Generate`. It supports two code paths:

- General `TraversableOnce`-based iteration. This is used for regular (code-generation supporting) `Generator` expressions. This code path expects the expression to return a `TraversableOnce[InternalRow]`, and it iterates over the returned collection. This PR adds code generation for the `stack` generator.
- Specialized `ArrayData`/`MapData`-based iteration. This is used for the `explode`, `posexplode` & `inline` functions and operates directly on the `ArrayData`/`MapData` result that the child of the generator returns.

### Benchmarks
I have added some benchmarks, and it seems we can create a nice speedup for explode:

#### Environment
```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6
Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz
```

#### Explode Array
##### Before
```
generate explode array:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate explode array wholestage off         7377 / 7607          2.3         439.7       1.0X
generate explode array wholestage on          6055 / 6086          2.8         360.9       1.2X
```
##### After
```
generate explode array:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate explode array wholestage off         7432 / 7696          2.3         443.0       1.0X
generate explode array wholestage on           631 /  646         26.6          37.6      11.8X
```

#### Explode Map
##### Before
```
generate explode map:                    Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate explode map wholestage off          12792 / 12848          1.3         762.5       1.0X
generate explode map wholestage on           11181 / 11237          1.5         666.5       1.1X
```
##### After
```
generate explode map:                    Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate explode map wholestage off          10949 / 10972          1.5         652.6       1.0X
generate explode map wholestage on             870 /  913         19.3          51.9      12.6X
```

#### Posexplode
##### Before
```
generate posexplode array:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate posexplode array wholestage off      7547 / 7580          2.2         449.8       1.0X
generate posexplode array wholestage on       5786 / 5838          2.9         344.9       1.3X
```
##### After
```
generate posexplode array:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate posexplode array wholestage off      7535 / 7548          2.2         449.1       1.0X
generate posexplode array wholestage on        620 /  624         27.1          37.0      12.1X
```

#### Inline
##### Before
```
generate inline array:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate inline array wholestage off          6935 / 6978          2.4         413.3       1.0X
generate inline array wholestage on           6360 / 6400          2.6         379.1       1.1X
```
##### After
```
generate inline array:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate inline array wholestage off          6940 / 6966          2.4         413.6       1.0X
generate inline array wholestage on           1002 / 1012         16.7          59.7       6.9X
```

#### Stack
##### Before
```
generate stack:                          Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate stack wholestage off                12980 / 13104          1.3         773.7       1.0X
generate stack wholestage on                 11566 / 11580          1.5         689.4       1.1X
```
##### After
```
generate stack:                          Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
generate stack wholestage off                12875 / 12949          1.3         767.4       1.0X
generate stack wholestage on                   840 /  845         20.0          50.0      15.3X
```

## How was this patch tested?
Existing tests.

Author: Herman van Hovell <hvanhovell@databricks.com>
Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #13065 from hvanhovell/SPARK-15214.
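For context, the query shapes that exercise the two new code paths are the ones used by the benchmarks added in this patch. A minimal sketch (assuming `spark` is an active `SparkSession`; the column names are illustrative):

```scala
// Specialized ArrayData/MapData path: explode / posexplode over a collection column.
val df = spark.range(1 << 24).selectExpr(
  "id as key",
  "array(rand(), rand(), rand(), rand(), rand()) as values")
df.selectExpr("key", "explode(values) value").count()
df.selectExpr("key", "posexplode(values) as (idx, value)").count()

// General TraversableOnce path: the stack generator.
spark.range(1 << 24)
  .selectExpr("id as key", "id % 2 as t1", "id % 3 as t2", "id % 5 as t3", "id % 7 as t4", "id % 13 as t5")
  .selectExpr("key", "stack(4, t1, t2, t3, t4, t5)")
  .count()
```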
Commit: 7ca7a63524
Parent: a64f25d8b4
Changes to the generator expressions (`org.apache.spark.sql.catalyst.expressions`):

```diff
@@ -17,10 +17,12 @@
 package org.apache.spark.sql.catalyst.expressions
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
 import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.types._
@@ -60,6 +62,26 @@ trait Generator extends Expression {
    * rows can be made here.
    */
   def terminate(): TraversableOnce[InternalRow] = Nil
+
+  /**
+   * Check if this generator supports code generation.
+   */
+  def supportCodegen: Boolean = !isInstanceOf[CodegenFallback]
 }
+
+/**
+ * A collection producing [[Generator]]. This trait provides a different path for code generation,
+ * by allowing code generation to return either an [[ArrayData]] or a [[MapData]] object.
+ */
+trait CollectionGenerator extends Generator {
+  /** The position of an element within the collection should also be returned. */
+  def position: Boolean
+
+  /** Rows will be inlined during generation. */
+  def inline: Boolean
+
+  /** The type of the returned collection object. */
+  def collectionType: DataType = dataType
+}
 
 /**
@@ -77,7 +99,9 @@ case class UserDefinedGenerator(
   private def initializeConverters(): Unit = {
     inputRow = new InterpretedProjection(children)
     convertToScala = {
-      val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true)))
+      val inputSchema = StructType(children.map { e =>
+        StructField(e.simpleString, e.dataType, nullable = true)
+      })
       CatalystTypeConverters.createToScalaConverter(inputSchema)
     }.asInstanceOf[InternalRow => Row]
   }
@@ -109,8 +133,7 @@ case class UserDefinedGenerator(
       1  2
       3  NULL
  """)
-case class Stack(children: Seq[Expression])
-  extends Expression with Generator with CodegenFallback {
+case class Stack(children: Seq[Expression]) extends Generator {
 
   private lazy val numRows = children.head.eval().asInstanceOf[Int]
   private lazy val numFields = Math.ceil((children.length - 1.0) / numRows).toInt
@@ -149,29 +172,58 @@ case class Stack(children: Seq[Expression])
       InternalRow(fields: _*)
     }
   }
+
+  /**
+   * Only support code generation when stack produces 50 rows or less.
+   */
+  override def supportCodegen: Boolean = numRows <= 50
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    // Rows - we write these into an array.
+    val rowData = ctx.freshName("rows")
+    ctx.addMutableState("InternalRow[]", rowData, s"this.$rowData = new InternalRow[$numRows];")
+    val values = children.tail
+    val dataTypes = values.take(numFields).map(_.dataType)
+    val code = ctx.splitExpressions(ctx.INPUT_ROW, Seq.tabulate(numRows) { row =>
+      val fields = Seq.tabulate(numFields) { col =>
+        val index = row * numFields + col
+        if (index < values.length) values(index) else Literal(null, dataTypes(col))
+      }
+      val eval = CreateStruct(fields).genCode(ctx)
+      s"${eval.code}\nthis.$rowData[$row] = ${eval.value};"
+    })
+
+    // Create the collection.
+    val wrapperClass = classOf[mutable.WrappedArray[_]].getName
+    ctx.addMutableState(
+      s"$wrapperClass<InternalRow>",
+      ev.value,
+      s"this.${ev.value} = $wrapperClass$$.MODULE$$.make(this.$rowData);")
+    ev.copy(code = code, isNull = "false")
+  }
 }
 
 /**
- * A base class for Explode and PosExplode
+ * A base class for [[Explode]] and [[PosExplode]].
  */
-abstract class ExplodeBase(child: Expression, position: Boolean)
-  extends UnaryExpression with Generator with CodegenFallback with Serializable {
+abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with Serializable {
+  override val inline: Boolean = false
 
-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) {
+  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
+    case _: ArrayType | _: MapType =>
       TypeCheckResult.TypeCheckSuccess
-    } else {
+    case _ =>
       TypeCheckResult.TypeCheckFailure(
         s"input to function explode should be array or map type, not ${child.dataType}")
-    }
   }
 
   // hive-compatible default alias for explode function ("col" for array, "key", "value" for map)
   override def elementSchema: StructType = child.dataType match {
     case ArrayType(et, containsNull) =>
       if (position) {
         new StructType()
-          .add("pos", IntegerType, false)
+          .add("pos", IntegerType, nullable = false)
           .add("col", et, containsNull)
       } else {
         new StructType()
@@ -180,12 +232,12 @@ abstract class ExplodeBase
     case MapType(kt, vt, valueContainsNull) =>
       if (position) {
         new StructType()
-          .add("pos", IntegerType, false)
-          .add("key", kt, false)
+          .add("pos", IntegerType, nullable = false)
+          .add("key", kt, nullable = false)
           .add("value", vt, valueContainsNull)
       } else {
         new StructType()
-          .add("key", kt, false)
+          .add("key", kt, nullable = false)
           .add("value", vt, valueContainsNull)
       }
   }
@@ -218,6 +270,12 @@ abstract class ExplodeBase
       }
     }
   }
+
+  override def collectionType: DataType = child.dataType
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    child.genCode(ctx)
+  }
 }
 
 /**
@@ -239,7 +297,9 @@ abstract class ExplodeBase
       20
  """)
 // scalastyle:on line.size.limit
-case class Explode(child: Expression) extends ExplodeBase(child, position = false)
+case class Explode(child: Expression) extends ExplodeBase {
+  override val position: Boolean = false
+}
 
 /**
  * Given an input array produces a sequence of rows for each position and value in the array.
@@ -260,7 +320,9 @@ case class Explode(child: Expression)
       1  20
  """)
 // scalastyle:on line.size.limit
-case class PosExplode(child: Expression) extends ExplodeBase(child, position = true)
+case class PosExplode(child: Expression) extends ExplodeBase {
+  override val position = true
+}
 
 /**
  * Explodes an array of structs into a table.
@@ -273,10 +335,12 @@ case class PosExplode(child: Expression)
       1  a
       2  b
  """)
-case class Inline(child: Expression) extends UnaryExpression with Generator with CodegenFallback {
+case class Inline(child: Expression) extends UnaryExpression with CollectionGenerator {
+  override val inline: Boolean = true
+  override val position: Boolean = false
 
   override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
-    case ArrayType(et, _) if et.isInstanceOf[StructType] =>
+    case ArrayType(st: StructType, _) =>
       TypeCheckResult.TypeCheckSuccess
     case _ =>
       TypeCheckResult.TypeCheckFailure(
@@ -284,9 +348,11 @@ case class Inline(child: Expression)
   }
 
   override def elementSchema: StructType = child.dataType match {
-    case ArrayType(et : StructType, _) => et
+    case ArrayType(st: StructType, _) => st
   }
+
+  override def collectionType: DataType = child.dataType
 
   private lazy val numFields = elementSchema.fields.length
 
   override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
@@ -298,4 +364,8 @@ case class Inline(child: Expression)
       yield inputArray.getStruct(i, numFields)
     }
   }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    child.genCode(ctx)
+  }
 }
```
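The `ExpressionDescription` usage strings visible above double as a quick functional check of the four generators that now implement `doGenCode`. A small sketch (assuming an active `SparkSession` named `spark`):

```scala
// stack: the usage string above shows the resulting rows (1, 2) and (3, NULL).
spark.sql("SELECT stack(2, 1, 2, 3)").show()

// explode / posexplode over an array, inline over an array of structs.
spark.sql("SELECT explode(array(10, 20))").show()
spark.sql("SELECT posexplode(array(10, 20))").show()
spark.sql("SELECT inline(array(struct(1, 'a'), struct(2, 'b')))").show()
```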
Changes to `SubexpressionEliminationSuite`:

```diff
@@ -17,7 +17,8 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.types.{DataType, IntegerType}
 
 class SubexpressionEliminationSuite extends SparkFunSuite {
   test("Semantic equals and hash") {
@@ -162,13 +163,18 @@ class SubexpressionEliminationSuite extends SparkFunSuite {
   test("Children of CodegenFallback") {
     val one = Literal(1)
     val two = Add(one, one)
-    val explode = Explode(two)
-    val add = Add(two, explode)
+    val fallback = CodegenFallbackExpression(two)
+    val add = Add(two, fallback)
 
-    var equivalence = new EquivalentExpressions
+    val equivalence = new EquivalentExpressions
     equivalence.addExprTree(add, true)
-    // the `two` inside `explode` should not be added
+    // the `two` inside `fallback` should not be added
     assert(equivalence.getAllEquivalentExprs.count(_.size > 1) == 0)
     assert(equivalence.getAllEquivalentExprs.count(_.size == 1) == 3)  // add, two, explode
   }
 }
+
+case class CodegenFallbackExpression(child: Expression)
+  extends UnaryExpression with CodegenFallback {
+  override def dataType: DataType = child.dataType
+}
```
Changes to `GenerateExec` (`org.apache.spark.sql.execution`):

```diff
@@ -20,8 +20,10 @@ package org.apache.spark.sql.execution
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
 
 /**
  * For lazy computing, be sure the generator.terminate() called in the very last
@@ -40,6 +42,10 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[InternalRow])
  * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
  * programming with one important additional feature, which allows the input rows to be joined with
  * their output.
+ *
+ * This operator supports whole stage code generation for generators that do not implement
+ * terminate().
+ *
  * @param generator the generator expression
  * @param join  when true, each output row is implicitly joined with the input tuple that produced
  *              it.
@@ -54,7 +60,7 @@ case class GenerateExec(
     outer: Boolean,
     output: Seq[Attribute],
     child: SparkPlan)
-  extends UnaryExecNode {
+  extends UnaryExecNode with CodegenSupport {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -103,5 +109,197 @@ case class GenerateExec(
       }
     }
   }
-}
+
+  override def supportCodegen: Boolean = generator.supportCodegen
+
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].inputRDDs()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+    ctx.currentVars = input
+    ctx.copyResult = true
+
+    // Add input rows to the values when we are joining
+    val values = if (join) {
+      input
+    } else {
+      Seq.empty
+    }
+
+    boundGenerator match {
+      case e: CollectionGenerator => codeGenCollection(ctx, e, values, row)
+      case g => codeGenTraversableOnce(ctx, g, values, row)
+    }
+  }
+
+  /**
+   * Generate code for [[CollectionGenerator]] expressions.
+   */
+  private def codeGenCollection(
+      ctx: CodegenContext,
+      e: CollectionGenerator,
+      input: Seq[ExprCode],
+      row: ExprCode): String = {
+
+    // Generate code for the generator.
+    val data = e.genCode(ctx)
+
+    // Generate looping variables.
+    val index = ctx.freshName("index")
+
+    // Add a check if the generate outer flag is true.
+    val checks = optionalCode(outer, data.isNull)
+
+    // Add position
+    val position = if (e.position) {
+      Seq(ExprCode("", "false", index))
+    } else {
+      Seq.empty
+    }
+
+    // Generate code for either ArrayData or MapData
+    val (initMapData, updateRowData, values) = e.collectionType match {
+      case ArrayType(st: StructType, nullable) if e.inline =>
+        val row = codeGenAccessor(ctx, data.value, "col", index, st, nullable, checks)
+        val fieldChecks = checks ++ optionalCode(nullable, row.isNull)
+        val columns = st.fields.toSeq.zipWithIndex.map { case (f, i) =>
+          codeGenAccessor(ctx, row.value, f.name, i.toString, f.dataType, f.nullable, fieldChecks)
+        }
+        ("", row.code, columns)
+
+      case ArrayType(dataType, nullable) =>
+        ("", "", Seq(codeGenAccessor(ctx, data.value, "col", index, dataType, nullable, checks)))
+
+      case MapType(keyType, valueType, valueContainsNull) =>
+        // Materialize the key and the value arrays before we enter the loop.
+        val keyArray = ctx.freshName("keyArray")
+        val valueArray = ctx.freshName("valueArray")
+        val initArrayData =
+          s"""
+             |ArrayData $keyArray = ${data.isNull} ? null : ${data.value}.keyArray();
+             |ArrayData $valueArray = ${data.isNull} ? null : ${data.value}.valueArray();
+           """.stripMargin
+        val values = Seq(
+          codeGenAccessor(ctx, keyArray, "key", index, keyType, nullable = false, checks),
+          codeGenAccessor(ctx, valueArray, "value", index, valueType, valueContainsNull, checks))
+        (initArrayData, "", values)
+    }
+
+    // In case of outer=true we need to make sure the loop is executed at-least once when the
+    // array/map contains no input. We do this by setting the looping index to -1 if there is no
+    // input, evaluation of the array is prevented by a check in the accessor code.
+    val numElements = ctx.freshName("numElements")
+    val init = if (outer) {
+      s"$numElements == 0 ? -1 : 0"
+    } else {
+      "0"
+    }
+    val numOutput = metricTerm(ctx, "numOutputRows")
+    s"""
+       |${data.code}
+       |$initMapData
+       |int $numElements = ${data.isNull} ? 0 : ${data.value}.numElements();
+       |for (int $index = $init; $index < $numElements; $index++) {
+       |  $numOutput.add(1);
+       |  $updateRowData
+       |  ${consume(ctx, input ++ position ++ values)}
+       |}
+     """.stripMargin
+  }
+
+  /**
+   * Generate code for a regular [[TraversableOnce]] returning [[Generator]].
+   */
+  private def codeGenTraversableOnce(
+      ctx: CodegenContext,
+      e: Expression,
+      input: Seq[ExprCode],
+      row: ExprCode): String = {
+
+    // Generate the code for the generator
+    val data = e.genCode(ctx)
+
+    // Generate looping variables.
+    val iterator = ctx.freshName("iterator")
+    val hasNext = ctx.freshName("hasNext")
+    val current = ctx.freshName("row")
+
+    // Add a check if the generate outer flag is true.
+    val checks = optionalCode(outer, s"!$hasNext")
+    val values = e.dataType match {
+      case ArrayType(st: StructType, nullable) =>
+        st.fields.toSeq.zipWithIndex.map { case (f, i) =>
+          codeGenAccessor(ctx, current, f.name, s"$i", f.dataType, f.nullable, checks)
+        }
+    }
+
+    // In case of outer=true we need to make sure the loop is executed at-least-once when the
+    // iterator contains no input. We do this by adding an 'outer' variable which guarantees
+    // execution of the first iteration even if there is no input. Evaluation of the iterator is
+    // prevented by checks in the next() and accessor code.
+    val numOutput = metricTerm(ctx, "numOutputRows")
+    if (outer) {
+      val outerVal = ctx.freshName("outer")
+      s"""
+         |${data.code}
+         |scala.collection.Iterator<InternalRow> $iterator = ${data.value}.toIterator();
+         |boolean $outerVal = true;
+         |while ($iterator.hasNext() || $outerVal) {
+         |  $numOutput.add(1);
+         |  boolean $hasNext = $iterator.hasNext();
+         |  InternalRow $current = (InternalRow)($hasNext? $iterator.next() : null);
+         |  $outerVal = false;
+         |  ${consume(ctx, input ++ values)}
+         |}
+       """.stripMargin
+    } else {
+      s"""
+         |${data.code}
+         |scala.collection.Iterator<InternalRow> $iterator = ${data.value}.toIterator();
+         |while ($iterator.hasNext()) {
+         |  $numOutput.add(1);
+         |  InternalRow $current = (InternalRow)($iterator.next());
+         |  ${consume(ctx, input ++ values)}
+         |}
       """.stripMargin
+    }
+  }
+
+  /**
+   * Generate accessor code for ArrayData and InternalRows.
+   */
+  private def codeGenAccessor(
+      ctx: CodegenContext,
+      source: String,
+      name: String,
+      index: String,
+      dt: DataType,
+      nullable: Boolean,
+      initialChecks: Seq[String]): ExprCode = {
+    val value = ctx.freshName(name)
+    val javaType = ctx.javaType(dt)
+    val getter = ctx.getValue(source, dt, index)
+    val checks = initialChecks ++ optionalCode(nullable, s"$source.isNullAt($index)")
+    if (checks.nonEmpty) {
+      val isNull = ctx.freshName("isNull")
+      val code =
+        s"""
+           |boolean $isNull = ${checks.mkString(" || ")};
+           |$javaType $value = $isNull ? ${ctx.defaultValue(dt)} : $getter;
+         """.stripMargin
+      ExprCode(code, isNull, value)
+    } else {
+      ExprCode(s"$javaType $value = $getter;", "false", value)
+    }
+  }
+
+  private def optionalCode(condition: Boolean, code: => String): Seq[String] = {
+    if (condition) Seq(code)
+    else Seq.empty
+  }
+}
```
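To check interactively whether a `Generate` ends up inside a whole-stage-generated stage, the executed plan can be inspected the same way the regression test added later in this patch does. A condensed sketch (assuming an active `SparkSession` named `spark`):

```scala
import org.apache.spark.sql.execution.{GenerateExec, WholeStageCodegenExec}
import org.apache.spark.sql.functions._

val ds = spark.range(2).select(
  col("id"),
  explode(array(col("id") + 1, col("id") + 2)).as("value"))

// GenerateExec should be wrapped by a WholeStageCodegenExec node.
val fused = ds.queryExecution.executedPlan.find {
  case WholeStageCodegenExec(_: GenerateExec) => true
  case _ => false
}
assert(fused.isDefined)
```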
Changes to `GeneratorFunctionSuite`:

```diff
@@ -17,8 +17,12 @@
 package org.apache.spark.sql
 
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, Generator}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StructType}
 
 class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -202,4 +206,34 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
       df.selectExpr("array(struct(a), named_struct('a', b))").selectExpr("inline(*)"),
       Row(1) :: Row(2) :: Nil)
   }
+
+  test("SPARK-14986: Outer lateral view with empty generate expression") {
+    checkAnswer(
+      sql("select nil from values 1 lateral view outer explode(array()) n as nil"),
+      Row(null) :: Nil
+    )
+  }
+
+  test("outer explode()") {
+    checkAnswer(
+      sql("select * from values 1, 2 lateral view outer explode(array()) a as b"),
+      Row(1, null) :: Row(2, null) :: Nil)
+  }
+
+  test("outer generator()") {
+    spark.sessionState.functionRegistry.registerFunction("empty_gen", _ => EmptyGenerator())
+    checkAnswer(
+      sql("select * from values 1, 2 lateral view outer empty_gen() a as b"),
+      Row(1, null) :: Row(2, null) :: Nil)
+  }
 }
+
+case class EmptyGenerator() extends Generator {
+  override def children: Seq[Expression] = Nil
+  override def elementSchema: StructType = new StructType().add("id", IntegerType)
+  override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val iteratorClass = classOf[Iterator[_]].getName
+    ev.copy(code = s"$iteratorClass<InternalRow> ${ev.value} = $iteratorClass$$.MODULE$$.empty();")
+  }
+}
```
Changes to `SQLQuerySuite` (the SPARK-14986 test moves to `GeneratorFunctionSuite`):

```diff
@@ -2086,13 +2086,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     }
   }
 
-  test("SPARK-14986: Outer lateral view with empty generate expression") {
-    checkAnswer(
-      sql("select nil from (select 1 as x ) x lateral view outer explode(array()) n as nil"),
-      Row(null) :: Nil
-    )
-  }
-
   test("data source table created in InMemoryCatalog should be able to read/write") {
     withTable("tbl") {
       sql("CREATE TABLE tbl(i INT, j STRING) USING parquet")
```
Changes to `WholeStageCodegenSuite`:

```diff
@@ -17,7 +17,9 @@
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{Column, Dataset, Row}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.expressions.scalalang.typed
@@ -113,4 +115,32 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
         p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
     assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
   }
+
+  test("generate should be included in WholeStageCodegen") {
+    import org.apache.spark.sql.functions._
+    val ds = spark.range(2).select(
+      col("id"),
+      explode(array(col("id") + 1, col("id") + 2)).as("value"))
+    val plan = ds.queryExecution.executedPlan
+    assert(plan.find(p =>
+      p.isInstanceOf[WholeStageCodegenExec] &&
+        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[GenerateExec]).isDefined)
+    assert(ds.collect() === Array(Row(0, 1), Row(0, 2), Row(1, 2), Row(1, 3)))
+  }
+
+  test("large stack generator should not use WholeStageCodegen") {
+    def createStackGenerator(rows: Int): SparkPlan = {
+      val id = UnresolvedAttribute("id")
+      val stack = Stack(Literal(rows) +: Seq.tabulate(rows)(i => Add(id, Literal(i))))
+      spark.range(500).select(Column(stack)).queryExecution.executedPlan
+    }
+    val isCodeGenerated: SparkPlan => Boolean = {
+      case WholeStageCodegenExec(_: GenerateExec) => true
+      case _ => false
+    }
+
+    // Only 'stack' generators that produce 50 rows or less are code generated.
+    assert(createStackGenerator(50).find(isCodeGenerated).isDefined)
+    assert(createStackGenerator(100).find(isCodeGenerated).isEmpty)
+  }
 }
```
Changes to `MiscBenchmark`:

```diff
@@ -102,7 +102,7 @@ class MiscBenchmark extends BenchmarkBase {
     }
     benchmark.run()
 
-    /**
+    /*
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
     collect:                            Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
@@ -124,7 +124,7 @@ class MiscBenchmark extends BenchmarkBase {
     }
     benchmark.run()
 
-    /**
+    /*
     model name : Westmere E56xx/L56xx/X56xx (Nehalem-C)
     collect limit:                      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
@@ -132,4 +132,99 @@ class MiscBenchmark extends BenchmarkBase {
     collect limit 2 millions            3348 / 4005          0.3        3193.3       0.2X
     */
   }
+
+  ignore("generate explode") {
+    val N = 1 << 24
+    runBenchmark("generate explode array", N) {
+      val df = sparkSession.range(N).selectExpr(
+        "id as key",
+        "array(rand(), rand(), rand(), rand(), rand()) as values")
+      df.selectExpr("key", "explode(values) value").count()
+    }
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6
+    Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz
+
+    generate explode array:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    generate explode array wholestage off         6920 / 7129          2.4         412.5       1.0X
+    generate explode array wholestage on           623 /  646         26.9          37.1      11.1X
+    */
+
+    runBenchmark("generate explode map", N) {
+      val df = sparkSession.range(N).selectExpr(
+        "id as key",
+        "map('a', rand(), 'b', rand(), 'c', rand(), 'd', rand(), 'e', rand()) pairs")
+      df.selectExpr("key", "explode(pairs) as (k, v)").count()
+    }
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6
+    Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz
+
+    generate explode map:                    Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    generate explode map wholestage off          11978 / 11993          1.4         714.0       1.0X
+    generate explode map wholestage on             866 /  919         19.4          51.6      13.8X
+    */
+
+    runBenchmark("generate posexplode array", N) {
+      val df = sparkSession.range(N).selectExpr(
+        "id as key",
+        "array(rand(), rand(), rand(), rand(), rand()) as values")
+      df.selectExpr("key", "posexplode(values) as (idx, value)").count()
+    }
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6
+    Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz
+
+    generate posexplode array:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    generate posexplode array wholestage off      7502 / 7513          2.2         447.1       1.0X
+    generate posexplode array wholestage on        617 /  623         27.2          36.8      12.2X
+    */
+
+    runBenchmark("generate inline array", N) {
+      val df = sparkSession.range(N).selectExpr(
+        "id as key",
+        "array((rand(), rand()), (rand(), rand()), (rand(), 0.0d)) as values")
+      df.selectExpr("key", "inline(values) as (r1, r2)").count()
+    }
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6
+    Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz
+
+    generate inline array:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    generate inline array wholestage off          6901 / 6928          2.4         411.3       1.0X
+    generate inline array wholestage on           1001 / 1010         16.8          59.7       6.9X
+    */
+  }
+
+  ignore("generate regular generator") {
+    val N = 1 << 24
+    runBenchmark("generate stack", N) {
+      val df = sparkSession.range(N).selectExpr(
+        "id as key",
+        "id % 2 as t1",
+        "id % 3 as t2",
+        "id % 5 as t3",
+        "id % 7 as t4",
+        "id % 13 as t5")
+      df.selectExpr("key", "stack(4, t1, t2, t3, t4, t5)").count()
+    }
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6
+    Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz
+
+    generate stack:                          Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    generate stack wholestage off                12953 / 13070          1.3         772.1       1.0X
+    generate stack wholestage on                   836 /  847         20.1          49.8      15.5X
+    */
+  }
 }
```