[SPARK-14296][SQL] whole stage codegen support for Dataset.map
## What changes were proposed in this pull request? This PR adds a new operator `MapElements` for `Dataset.map`. It is a 1-1 mapping and is therefore easier to adapt to the whole-stage codegen framework. ## How was this patch tested? New test in `WholeStageCodegenSuite`. Author: Wenchen Fan <wenchen@databricks.com> Closes #12087 from cloud-fan/map.
This commit is contained in:
parent
8e5c1cbf2c
commit
f6456fa80b
|
@ -345,7 +345,7 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
|
|||
* @param inputAttributes The input attributes used to resolve deserializer expression, can be empty
|
||||
* if we want to resolve deserializer by children output.
|
||||
*/
|
||||
case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute])
|
||||
case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute] = Nil)
|
||||
extends UnaryExpression with Unevaluable with NonSQLExpression {
|
||||
// The input attributes used to resolve deserializer expression must be all resolved.
|
||||
require(inputAttributes.forall(_.resolved), "Input attributes must all be resolved.")
|
||||
|
|
|
@ -119,18 +119,18 @@ case class Invoke(
|
|||
override def eval(input: InternalRow): Any =
|
||||
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
|
||||
|
||||
lazy val method = targetObject.dataType match {
|
||||
@transient lazy val method = targetObject.dataType match {
|
||||
case ObjectType(cls) =>
|
||||
cls
|
||||
.getMethods
|
||||
.find(_.getName == functionName)
|
||||
.getOrElse(sys.error(s"Couldn't find $functionName on $cls"))
|
||||
.getReturnType
|
||||
.getName
|
||||
case _ => ""
|
||||
val m = cls.getMethods.find(_.getName == functionName)
|
||||
if (m.isEmpty) {
|
||||
sys.error(s"Couldn't find $functionName on $cls")
|
||||
} else {
|
||||
m
|
||||
}
|
||||
case _ => None
|
||||
}
|
||||
|
||||
lazy val unboxer = (dataType, method) match {
|
||||
lazy val unboxer = (dataType, method.map(_.getReturnType.getName).getOrElse("")) match {
|
||||
case (IntegerType, "java.lang.Object") => (s: String) =>
|
||||
s"((java.lang.Integer)$s).intValue()"
|
||||
case (LongType, "java.lang.Object") => (s: String) =>
|
||||
|
@ -157,21 +157,31 @@ case class Invoke(
|
|||
// If the function can return null, we do an extra check to make sure our null bit is still set
|
||||
// correctly.
|
||||
val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
|
||||
s"${ev.isNull} = ${ev.value} == null;"
|
||||
s"boolean ${ev.isNull} = ${ev.value} == null;"
|
||||
} else {
|
||||
ev.isNull = obj.isNull
|
||||
""
|
||||
}
|
||||
|
||||
val value = unboxer(s"${obj.value}.$functionName($argString)")
|
||||
|
||||
val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) {
|
||||
s"$javaType ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value;"
|
||||
} else {
|
||||
s"""
|
||||
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
|
||||
try {
|
||||
${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) $value;
|
||||
} catch (Exception e) {
|
||||
org.apache.spark.unsafe.Platform.throwException(e);
|
||||
}
|
||||
"""
|
||||
}
|
||||
|
||||
s"""
|
||||
${obj.code}
|
||||
${argGen.map(_.code).mkString("\n")}
|
||||
|
||||
boolean ${ev.isNull} = ${obj.isNull};
|
||||
$javaType ${ev.value} =
|
||||
${ev.isNull} ?
|
||||
${ctx.defaultValue(dataType)} : ($javaType) $value;
|
||||
$evaluate
|
||||
$objNullCheck
|
||||
"""
|
||||
}
|
||||
|
|
|
@ -136,6 +136,7 @@ object SamplePushDown extends Rule[LogicalPlan] {
|
|||
* representation of data item. For example back to back map operations.
|
||||
*/
|
||||
object EliminateSerialization extends Rule[LogicalPlan] {
|
||||
// TODO: find a more general way to do this optimization.
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case m @ MapPartitions(_, deserializer, _, child: ObjectOperator)
|
||||
if !deserializer.isInstanceOf[Attribute] &&
|
||||
|
@ -144,6 +145,14 @@ object EliminateSerialization extends Rule[LogicalPlan] {
|
|||
m.copy(
|
||||
deserializer = childWithoutSerialization.output.head,
|
||||
child = childWithoutSerialization)
|
||||
|
||||
case m @ MapElements(_, deserializer, _, child: ObjectOperator)
|
||||
if !deserializer.isInstanceOf[Attribute] &&
|
||||
deserializer.dataType == child.outputObject.dataType =>
|
||||
val childWithoutSerialization = child.withObjectOutput
|
||||
m.copy(
|
||||
deserializer = childWithoutSerialization.output.head,
|
||||
child = childWithoutSerialization)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -65,7 +65,7 @@ object MapPartitions {
|
|||
child: LogicalPlan): MapPartitions = {
|
||||
MapPartitions(
|
||||
func.asInstanceOf[Iterator[Any] => Iterator[Any]],
|
||||
UnresolvedDeserializer(encoderFor[T].deserializer, Nil),
|
||||
UnresolvedDeserializer(encoderFor[T].deserializer),
|
||||
encoderFor[U].namedExpressions,
|
||||
child)
|
||||
}
|
||||
|
@ -83,6 +83,30 @@ case class MapPartitions(
|
|||
serializer: Seq[NamedExpression],
|
||||
child: LogicalPlan) extends UnaryNode with ObjectOperator
|
||||
|
||||
/** Factory for constructing new `MapElements` nodes. */
object MapElements {
  def apply[T : Encoder, U : Encoder](
      func: AnyRef,
      child: LogicalPlan): MapElements = {
    // No input attributes are supplied, so the deserializer is resolved against the
    // child's output.
    val inputDeserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
    val outputSerializer = encoderFor[U].namedExpressions
    new MapElements(func, inputDeserializer, outputSerializer, child)
  }
}

/**
 * A relation produced by applying `func` to each element of the `child`.
 *
 * @param func the user function applied to each element (a Scala function or a Java
 *             `MapFunction`, as passed in by `Dataset.map`).
 * @param deserializer used to extract the input to `func` from an input row.
 * @param serializer used to serialize the output of `func`.
 * @param child the plan whose elements are mapped.
 */
case class MapElements(
    func: AnyRef,
    deserializer: Expression,
    serializer: Seq[NamedExpression],
    child: LogicalPlan) extends UnaryNode with ObjectOperator
|
||||
|
||||
/** Factory for constructing new `AppendColumns` nodes. */
|
||||
object AppendColumns {
|
||||
def apply[T : Encoder, U : Encoder](
|
||||
|
@ -90,7 +114,7 @@ object AppendColumns {
|
|||
child: LogicalPlan): AppendColumns = {
|
||||
new AppendColumns(
|
||||
func.asInstanceOf[Any => Any],
|
||||
UnresolvedDeserializer(encoderFor[T].deserializer, Nil),
|
||||
UnresolvedDeserializer(encoderFor[T].deserializer),
|
||||
encoderFor[U].namedExpressions,
|
||||
child)
|
||||
}
|
||||
|
|
|
@ -766,7 +766,8 @@ class Dataset[T] private[sql](
|
|||
|
||||
implicit val tuple2Encoder: Encoder[(T, U)] =
|
||||
ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
|
||||
withTypedPlan[(T, U)](other, encoderFor[(T, U)]) { (left, right) =>
|
||||
|
||||
withTypedPlan {
|
||||
Project(
|
||||
leftData :: rightData :: Nil,
|
||||
joined.analyzed)
|
||||
|
@ -1900,7 +1901,9 @@ class Dataset[T] private[sql](
|
|||
* @since 1.6.0
|
||||
*/
|
||||
@Experimental
|
||||
def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))
|
||||
def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan {
|
||||
MapElements[T, U](func, logicalPlan)
|
||||
}
|
||||
|
||||
/**
|
||||
* :: Experimental ::
|
||||
|
@ -1911,8 +1914,10 @@ class Dataset[T] private[sql](
|
|||
* @since 1.6.0
|
||||
*/
|
||||
@Experimental
|
||||
def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
|
||||
map(t => func.call(t))(encoder)
|
||||
def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
|
||||
implicit val uEnc = encoder
|
||||
withTypedPlan(MapElements[T, U](func, logicalPlan))
|
||||
}
|
||||
|
||||
/**
|
||||
* :: Experimental ::
|
||||
|
@ -2412,12 +2417,7 @@ class Dataset[T] private[sql](
|
|||
}
|
||||
|
||||
/** A convenient function to wrap a logical plan and produce a Dataset. */
|
||||
@inline private def withTypedPlan(logicalPlan: => LogicalPlan): Dataset[T] = {
|
||||
new Dataset[T](sqlContext, logicalPlan, encoder)
|
||||
@inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = {
|
||||
Dataset(sqlContext, logicalPlan)
|
||||
}
|
||||
|
||||
private[sql] def withTypedPlan[R](
|
||||
other: Dataset[_], encoder: Encoder[R])(
|
||||
f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
|
||||
new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan), encoder)
|
||||
}
|
||||
|
|
|
@ -341,6 +341,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
|
||||
case logical.MapPartitions(f, in, out, child) =>
|
||||
execution.MapPartitions(f, in, out, planLater(child)) :: Nil
|
||||
case logical.MapElements(f, in, out, child) =>
|
||||
execution.MapElements(f, in, out, planLater(child)) :: Nil
|
||||
case logical.AppendColumns(f, in, out, child) =>
|
||||
execution.AppendColumns(f, in, out, planLater(child)) :: Nil
|
||||
case logical.MapGroups(f, key, in, out, grouping, data, child) =>
|
||||
|
|
|
@ -152,7 +152,7 @@ trait CodegenSupport extends SparkPlan {
|
|||
s"""
|
||||
|
|
||||
|/*** CONSUME: ${toCommentSafeString(parent.simpleString)} */
|
||||
|${evaluated}
|
||||
|$evaluated
|
||||
|${parent.doConsume(ctx, inputVars, rowVar)}
|
||||
""".stripMargin
|
||||
}
|
||||
|
@ -169,20 +169,20 @@ trait CodegenSupport extends SparkPlan {
|
|||
|
||||
/**
|
||||
* Returns source code to evaluate the variables for required attributes, and clear the code
|
||||
* of evaluated variables, to prevent them to be evaluated twice..
|
||||
* of evaluated variables, to prevent them to be evaluated twice.
|
||||
*/
|
||||
protected def evaluateRequiredVariables(
|
||||
attributes: Seq[Attribute],
|
||||
variables: Seq[ExprCode],
|
||||
required: AttributeSet): String = {
|
||||
var evaluateVars = ""
|
||||
val evaluateVars = new StringBuilder
|
||||
variables.zipWithIndex.foreach { case (ev, i) =>
|
||||
if (ev.code != "" && required.contains(attributes(i))) {
|
||||
evaluateVars += ev.code.trim + "\n"
|
||||
evaluateVars.append(ev.code.trim + "\n")
|
||||
ev.code = ""
|
||||
}
|
||||
}
|
||||
evaluateVars
|
||||
evaluateVars.toString()
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -305,7 +305,6 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
|
|||
def doCodeGen(): (CodegenContext, String) = {
|
||||
val ctx = new CodegenContext
|
||||
val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
|
||||
val references = ctx.references.toArray
|
||||
val source = s"""
|
||||
public Object generate(Object[] references) {
|
||||
return new GeneratedIterator(references);
|
||||
|
|
|
@ -17,10 +17,13 @@
|
|||
|
||||
package org.apache.spark.sql.execution
|
||||
|
||||
import scala.language.existentials
|
||||
|
||||
import org.apache.spark.api.java.function.MapFunction
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.plans.physical._
|
||||
import org.apache.spark.sql.types.ObjectType
|
||||
|
||||
|
@ -67,6 +70,70 @@ case class MapPartitions(
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Applies the given function to each input row and encodes the result.
 *
 * Every serializer expression takes the object returned by the function as its input. To make
 * sure that object is computed only once, the generated code evaluates the function call first
 * and then rewrites the serializer expressions to read the already-computed result. [[Project]]
 * is not used directly because subexpression elimination doesn't work with whole stage codegen,
 * and showing the un-common-subexpression-eliminated version of a project in the explain output
 * would be confusing.
 */
case class MapElements(
    func: AnyRef,
    deserializer: Expression,
    serializer: Seq[NamedExpression],
    child: SparkPlan) extends UnaryNode with ObjectOperator with CodegenSupport {

  override def output: Seq[Attribute] = serializer.map(_.toAttribute)

  override def outputOrdering: Seq[SortOrder] = child.outputOrdering

  override def upstreams(): Seq[RDD[InternalRow]] =
    child.asInstanceOf[CodegenSupport].upstreams()

  protected override def doProduce(ctx: CodegenContext): String =
    child.asInstanceOf[CodegenSupport].produce(ctx, this)

  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
    // Pick the interface and method name through which the user function is invoked.
    val (funcClass, methodName) = func match {
      case _: MapFunction[_, _] => (classOf[MapFunction[_, _]], "call")
      case _ => (classOf[Any => Any], "apply")
    }
    val funcLiteral = Literal.create(func, ObjectType(funcClass))
    // The serializer's BoundReference stands for the function result, so its type is the
    // result object type.
    val resultType = serializer.head.collect { case ref: BoundReference => ref }.head.dataType
    val invocation = Invoke(funcLiteral, methodName, resultType, Seq(deserializer))
    val boundInvocation = ExpressionCanonicalizer.execute(
      BindReferences.bindReference(invocation, child.output))

    // currentVars must be set before generating code for the bound invocation.
    ctx.currentVars = input
    val resultCode = boundInvocation.gen(ctx)

    // Point every serializer expression at the variable that already holds the result object.
    val resultVar = LambdaVariable(resultCode.value, resultCode.isNull, resultType)
    val serializedVars = serializer.map { expr =>
      expr.transform { case _: BoundReference => resultVar }.gen(ctx)
    }
    s"""
      ${resultCode.code}
      ${consume(ctx, serializedVars)}
     """
  }

  override protected def doExecute(): RDD[InternalRow] = {
    // Normalize both supported function flavours to a plain Scala function.
    val callFunc: Any => Any = func match {
      case m: MapFunction[_, _] =>
        val javaFunc = m.asInstanceOf[MapFunction[Any, Any]]
        (value: Any) => javaFunc.call(value)
      case _ =>
        func.asInstanceOf[Any => Any]
    }
    child.execute().mapPartitionsInternal { rows =>
      val getObject = generateToObject(deserializer, child.output)
      val outputObject = generateToRow(serializer)
      rows.map { row =>
        outputObject(callFunc(getObject(row)))
      }
    }
  }
}
|
||||
|
||||
/**
|
||||
* Applies the given function to each input row, appending the encoded result at the end of the row.
|
||||
*/
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.sql.types.StringType
|
||||
import org.apache.spark.util.Benchmark
|
||||
|
||||
/**
 * Benchmark for Dataset typed operations comparing with DataFrame and RDD versions.
 *
 * Each case applies the same "add one to `l`" step `numChains` times back to back and then
 * forces evaluation with a no-op `foreach`.
 */
object DatasetBenchmark {

  case class Data(l: Long, s: String)

  def main(args: Array[String]): Unit = {
    val sparkContext = new SparkContext("local[*]", "Dataset benchmark")
    val sqlContext = new SQLContext(sparkContext)

    import sqlContext.implicits._

    val numRows = 10000000
    val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
    val numChains = 10

    val benchmark = new Benchmark("back-to-back map", numRows)

    val func = (d: Data) => Data(d.l + 1, d.s)
    benchmark.addCase("Dataset") { _ =>
      var res = df.as[Data]
      var i = 0
      while (i < numChains) {
        res = res.map(func)
        i += 1
      }
      res.queryExecution.toRdd.foreach(_ => ())
    }

    benchmark.addCase("DataFrame") { _ =>
      var res = df
      var i = 0
      while (i < numChains) {
        res = res.select($"l" + 1 as "l")
        i += 1
      }
      res.queryExecution.toRdd.foreach(_ => ())
    }

    val rdd = sparkContext.range(1, numRows).map(l => Data(l, l.toString))
    benchmark.addCase("RDD") { _ =>
      var res = rdd
      var i = 0
      while (i < numChains) {
        // Chain on the previous result; mapping `rdd` again would only ever measure one map.
        res = res.map(func)
        i += 1
      }
      res.foreach(_ => ())
    }

    /*
    NOTE(review): the figures below presumably predate the RDD case chaining its maps
    (it used to re-map the base RDD on every iteration) -- re-run to refresh them.

    Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4
    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
    back-to-back map:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
    -------------------------------------------------------------------------------------------
    Dataset                                   902 /  995         11.1          90.2       1.0X
    DataFrame                                 132 /  167         75.5          13.2       6.8X
    RDD                                       216 /  237         46.3          21.6       4.2X
    */
    benchmark.run()
  }
}
|
|
@ -198,10 +198,7 @@ abstract class QueryTest extends PlanTest {
|
|||
val logicalPlan = df.queryExecution.analyzed
|
||||
// bypass some cases that we can't handle currently.
|
||||
logicalPlan.transform {
|
||||
case _: MapPartitions => return
|
||||
case _: MapGroups => return
|
||||
case _: AppendColumns => return
|
||||
case _: CoGroup => return
|
||||
case _: ObjectOperator => return
|
||||
case _: LogicalRelation => return
|
||||
}.transformAllExpressions {
|
||||
case a: ImperativeAggregate => return
|
||||
|
|
|
@ -17,7 +17,8 @@
|
|||
|
||||
package org.apache.spark.sql.execution
|
||||
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.api.java.function.MapFunction
|
||||
import org.apache.spark.sql.{Encoders, Row}
|
||||
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
|
||||
import org.apache.spark.sql.execution.joins.BroadcastHashJoin
|
||||
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
|
||||
|
@ -70,4 +71,15 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
|
|||
p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Sort]).isDefined)
|
||||
assert(df.collect() === Array(Row(1), Row(2), Row(3)))
|
||||
}
|
||||
|
||||
test("MapElements should be included in WholeStageCodegen") {
  import testImplicits._

  val ds = sqlContext.range(10).map(_.toString)
  val executed = ds.queryExecution.executedPlan
  // MapElements must appear as the direct child of a whole-stage codegen node.
  val fusedNode = executed.find {
    case WholeStageCodegen(_: MapElements) => true
    case _ => false
  }
  assert(fusedNode.isDefined)
  assert(ds.collect() === (0 until 10).map(_.toString).toArray)
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue