[SPARK-14296][SQL] whole stage codegen support for Dataset.map

## What changes were proposed in this pull request?

This PR adds a new operator, `MapElements`, for `Dataset.map`. It performs a 1-to-1 mapping and is easier to adapt to the whole-stage codegen framework.
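
As a minimal sketch (assuming a `SQLContext` named `sqlContext` with its implicits imported), a typed map like the following is now planned as `MapElements` inside `WholeStageCodegen` rather than falling back to `MapPartitions`:

```scala
import sqlContext.implicits._

// A 1-to-1 typed map; with this PR it participates in whole-stage codegen.
val ds = sqlContext.range(10).map(_.toString)
ds.explain()   // the plan should show WholeStageCodegen wrapping MapElements
ds.collect()
```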

## How was this patch tested?

A new test in `WholeStageCodegenSuite`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12087 from cloud-fan/map.
Commit f6456fa80b (parent 8e5c1cbf2c), authored by Wenchen Fan on 2016-04-06 12:09:10 +08:00.
11 changed files with 247 additions and 41 deletions.


@@ -345,7 +345,7 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
* @param inputAttributes The input attributes used to resolve deserializer expression, can be empty
* if we want to resolve deserializer by children output.
*/
case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute])
case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute] = Nil)
extends UnaryExpression with Unevaluable with NonSQLExpression {
// The input attributes used to resolve deserializer expression must be all resolved.
require(inputAttributes.forall(_.resolved), "Input attributes must all be resolved.")


@@ -119,18 +119,18 @@ case class Invoke(
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
lazy val method = targetObject.dataType match {
@transient lazy val method = targetObject.dataType match {
case ObjectType(cls) =>
cls
.getMethods
.find(_.getName == functionName)
.getOrElse(sys.error(s"Couldn't find $functionName on $cls"))
.getReturnType
.getName
case _ => ""
val m = cls.getMethods.find(_.getName == functionName)
if (m.isEmpty) {
sys.error(s"Couldn't find $functionName on $cls")
} else {
m
}
case _ => None
}
lazy val unboxer = (dataType, method) match {
lazy val unboxer = (dataType, method.map(_.getReturnType.getName).getOrElse("")) match {
case (IntegerType, "java.lang.Object") => (s: String) =>
s"((java.lang.Integer)$s).intValue()"
case (LongType, "java.lang.Object") => (s: String) =>
@@ -157,21 +157,31 @@ case class Invoke(
// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
s"${ev.isNull} = ${ev.value} == null;"
s"boolean ${ev.isNull} = ${ev.value} == null;"
} else {
ev.isNull = obj.isNull
""
}
val value = unboxer(s"${obj.value}.$functionName($argString)")
val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) {
s"$javaType ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value;"
} else {
s"""
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
try {
${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) $value;
} catch (Exception e) {
org.apache.spark.unsafe.Platform.throwException(e);
}
"""
}
s"""
${obj.code}
${argGen.map(_.code).mkString("\n")}
boolean ${ev.isNull} = ${obj.isNull};
$javaType ${ev.value} =
${ev.isNull} ?
${ctx.defaultValue(dataType)} : ($javaType) $value;
$evaluate
$objNullCheck
"""
}
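
Not part of the diff: a standalone sketch of the reflection logic used above, in which the target method is looked up once as an `Option` and a try/catch wrapper is only generated when the method declares checked exceptions (helper names here are illustrative).

```scala
import java.lang.reflect.Method

// Look up the target method by name, mirroring Invoke's lazy `method` field.
def findMethod(cls: Class[_], functionName: String): Option[Method] =
  cls.getMethods.find(_.getName == functionName)

// Option.forall is true when the method is unknown or declares no checked exceptions;
// only in the remaining case does the generated call need a try/catch block.
def needsTryCatch(method: Option[Method]): Boolean =
  !method.forall(_.getExceptionTypes.isEmpty)
```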


@@ -136,6 +136,7 @@ object SamplePushDown extends Rule[LogicalPlan] {
* representation of data item. For example back to back map operations.
*/
object EliminateSerialization extends Rule[LogicalPlan] {
// TODO: find a more general way to do this optimization.
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case m @ MapPartitions(_, deserializer, _, child: ObjectOperator)
if !deserializer.isInstanceOf[Attribute] &&
@@ -144,6 +145,14 @@ object EliminateSerialization extends Rule[LogicalPlan] {
m.copy(
deserializer = childWithoutSerialization.output.head,
child = childWithoutSerialization)
case m @ MapElements(_, deserializer, _, child: ObjectOperator)
if !deserializer.isInstanceOf[Attribute] &&
deserializer.dataType == child.outputObject.dataType =>
val childWithoutSerialization = child.withObjectOutput
m.copy(
deserializer = childWithoutSerialization.output.head,
child = childWithoutSerialization)
}
}
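
A hedged illustration of the pattern the new `MapElements` case targets (the `Data` class and `sqlContext` are assumptions, not part of the patch): back-to-back typed maps over the same object type, where the intermediate encode/decode pair can be elided.

```scala
import sqlContext.implicits._

case class Data(l: Long, s: String)

val ds = sqlContext.range(10)
  .map(l => Data(l, l.toString))   // first MapElements: produces Data objects
  .map(d => Data(d.l + 1, d.s))    // second MapElements: consumes Data objects of the same type
// When the second map's deserializer expects exactly the object type produced by the first
// map, the rule can wire the object output through directly instead of round-tripping rows.
```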


@@ -65,7 +65,7 @@ object MapPartitions {
child: LogicalPlan): MapPartitions = {
MapPartitions(
func.asInstanceOf[Iterator[Any] => Iterator[Any]],
UnresolvedDeserializer(encoderFor[T].deserializer, Nil),
UnresolvedDeserializer(encoderFor[T].deserializer),
encoderFor[U].namedExpressions,
child)
}
@@ -83,6 +83,30 @@ case class MapPartitions(
serializer: Seq[NamedExpression],
child: LogicalPlan) extends UnaryNode with ObjectOperator
object MapElements {
def apply[T : Encoder, U : Encoder](
func: AnyRef,
child: LogicalPlan): MapElements = {
MapElements(
func,
UnresolvedDeserializer(encoderFor[T].deserializer),
encoderFor[U].namedExpressions,
child)
}
}
/**
* A relation produced by applying `func` to each element of the `child`.
*
* @param deserializer used to extract the input to `func` from an input row.
* @param serializer use to serialize the output of `func`.
*/
case class MapElements(
func: AnyRef,
deserializer: Expression,
serializer: Seq[NamedExpression],
child: LogicalPlan) extends UnaryNode with ObjectOperator
/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumns {
def apply[T : Encoder, U : Encoder](
@@ -90,7 +114,7 @@ object AppendColumns {
child: LogicalPlan): AppendColumns = {
new AppendColumns(
func.asInstanceOf[Any => Any],
UnresolvedDeserializer(encoderFor[T].deserializer, Nil),
UnresolvedDeserializer(encoderFor[T].deserializer),
encoderFor[U].namedExpressions,
child)
}


@@ -766,7 +766,8 @@ class Dataset[T] private[sql](
implicit val tuple2Encoder: Encoder[(T, U)] =
ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
withTypedPlan[(T, U)](other, encoderFor[(T, U)]) { (left, right) =>
withTypedPlan {
Project(
leftData :: rightData :: Nil,
joined.analyzed)
@@ -1900,7 +1901,9 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))
def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan {
MapElements[T, U](func, logicalPlan)
}
/**
* :: Experimental ::
@@ -1911,8 +1914,10 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
map(t => func.call(t))(encoder)
def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
implicit val uEnc = encoder
withTypedPlan(MapElements[T, U](func, logicalPlan))
}
/**
* :: Experimental ::
@@ -2412,12 +2417,7 @@ class Dataset[T] private[sql](
}
/** A convenient function to wrap a logical plan and produce a Dataset. */
@inline private def withTypedPlan(logicalPlan: => LogicalPlan): Dataset[T] = {
new Dataset[T](sqlContext, logicalPlan, encoder)
@inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = {
Dataset(sqlContext, logicalPlan)
}
private[sql] def withTypedPlan[R](
other: Dataset[_], encoder: Encoder[R])(
f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan), encoder)
}
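
A hedged usage sketch of the two `map` overloads above, assuming a `Dataset[String]` named `ds` and `sqlContext.implicits._` in scope: the Scala overload picks up the result encoder implicitly, while the Java-friendly overload takes the function and the encoder explicitly.

```scala
import org.apache.spark.api.java.function.MapFunction
import org.apache.spark.sql.Encoders

// Scala overload: Encoder[String] comes from sqlContext.implicits._
val upper1 = ds.map(_.toUpperCase)

// Java-friendly overload: MapFunction plus an explicit Encoder
val upper2 = ds.map(
  new MapFunction[String, String] { override def call(s: String): String = s.toUpperCase },
  Encoders.STRING)
```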


@@ -341,6 +341,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.MapPartitions(f, in, out, child) =>
execution.MapPartitions(f, in, out, planLater(child)) :: Nil
case logical.MapElements(f, in, out, child) =>
execution.MapElements(f, in, out, planLater(child)) :: Nil
case logical.AppendColumns(f, in, out, child) =>
execution.AppendColumns(f, in, out, planLater(child)) :: Nil
case logical.MapGroups(f, key, in, out, grouping, data, child) =>


@@ -152,7 +152,7 @@ trait CodegenSupport extends SparkPlan {
s"""
|
|/*** CONSUME: ${toCommentSafeString(parent.simpleString)} */
|${evaluated}
|$evaluated
|${parent.doConsume(ctx, inputVars, rowVar)}
""".stripMargin
}
@@ -169,20 +169,20 @@ trait CodegenSupport extends SparkPlan {
/**
* Returns source code to evaluate the variables for required attributes, and clear the code
* of evaluated variables, to prevent them to be evaluated twice..
* of evaluated variables, to prevent them to be evaluated twice.
*/
protected def evaluateRequiredVariables(
attributes: Seq[Attribute],
variables: Seq[ExprCode],
required: AttributeSet): String = {
var evaluateVars = ""
val evaluateVars = new StringBuilder
variables.zipWithIndex.foreach { case (ev, i) =>
if (ev.code != "" && required.contains(attributes(i))) {
evaluateVars += ev.code.trim + "\n"
evaluateVars.append(ev.code.trim + "\n")
ev.code = ""
}
}
evaluateVars
evaluateVars.toString()
}
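
A minimal sketch (not taken from the Spark sources; `ExprCodeSketch` is an illustrative stand-in for `ExprCode`) of the accumulation pattern above: snippets are appended to a `StringBuilder` rather than concatenated as `String`s, and each snippet is cleared once emitted so the same expression is not evaluated twice.

```scala
case class ExprCodeSketch(var code: String)

def evaluateOnce(snippets: Seq[ExprCodeSketch]): String = {
  val evaluateVars = new StringBuilder
  snippets.foreach { ev =>
    if (ev.code.nonEmpty) {
      evaluateVars.append(ev.code.trim + "\n")
      ev.code = ""   // mark as already evaluated so it is not emitted again
    }
  }
  evaluateVars.toString()
}
```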
/**
@@ -305,7 +305,6 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
def doCodeGen(): (CodegenContext, String) = {
val ctx = new CodegenContext
val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
val references = ctx.references.toArray
val source = s"""
public Object generate(Object[] references) {
return new GeneratedIterator(references);


@@ -17,10 +17,13 @@
package org.apache.spark.sql.execution
import scala.language.existentials
import org.apache.spark.api.java.function.MapFunction
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.ObjectType
@@ -67,6 +70,70 @@ case class MapPartitions(
}
}
/**
* Applies the given function to each input row and encodes the result.
*
* Note that, each serializer expression needs the result object which is returned by the given
* function, as input. This operator uses some tricks to make sure we only calculate the result
* object once. We don't use [[Project]] directly as subexpression elimination doesn't work with
* whole stage codegen and it's confusing to show the un-common-subexpression-eliminated version of
* a project while explain.
*/
case class MapElements(
func: AnyRef,
deserializer: Expression,
serializer: Seq[NamedExpression],
child: SparkPlan) extends UnaryNode with ObjectOperator with CodegenSupport {
override def output: Seq[Attribute] = serializer.map(_.toAttribute)
override def upstreams(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].upstreams()
}
protected override def doProduce(ctx: CodegenContext): String = {
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val (funcClass, methodName) = func match {
case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call"
case _ => classOf[Any => Any] -> "apply"
}
val funcObj = Literal.create(func, ObjectType(funcClass))
val resultObjType = serializer.head.collect { case b: BoundReference => b }.head.dataType
val callFunc = Invoke(funcObj, methodName, resultObjType, Seq(deserializer))
val bound = ExpressionCanonicalizer.execute(
BindReferences.bindReference(callFunc, child.output))
ctx.currentVars = input
val evaluated = bound.gen(ctx)
val resultObj = LambdaVariable(evaluated.value, evaluated.isNull, resultObjType)
val outputFields = serializer.map(_ transform {
case _: BoundReference => resultObj
})
val resultVars = outputFields.map(_.gen(ctx))
s"""
${evaluated.code}
${consume(ctx, resultVars)}
"""
}
override protected def doExecute(): RDD[InternalRow] = {
val callFunc: Any => Any = func match {
case m: MapFunction[_, _] => i => m.asInstanceOf[MapFunction[Any, Any]].call(i)
case _ => func.asInstanceOf[Any => Any]
}
child.execute().mapPartitionsInternal { iter =>
val getObject = generateToObject(deserializer, child.output)
val outputObject = generateToRow(serializer)
iter.map(row => outputObject(callFunc(getObject(row))))
}
}
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
}
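
Not part of the diff: a simplified sketch of what the non-codegen fallback in `doExecute` above does per partition, with hypothetical stand-ins for the `generateToObject`/`generateToRow` helpers.

```scala
// `fromRow` plays the role of the generated deserializer, `toRow` that of the serializer.
def mapElementsFallback[Row, T, U](
    rows: Iterator[Row],
    fromRow: Row => T,
    func: T => U,
    toRow: U => Row): Iterator[Row] =
  rows.map(row => toRow(func(fromRow(row))))
```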
/**
* Applies the given function to each input row, appending the encoded result at the end of the row.
*/


@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.SparkContext
import org.apache.spark.sql.types.StringType
import org.apache.spark.util.Benchmark
/**
* Benchmark for Dataset typed operations comparing with DataFrame and RDD versions.
*/
object DatasetBenchmark {
case class Data(l: Long, s: String)
def main(args: Array[String]): Unit = {
val sparkContext = new SparkContext("local[*]", "Dataset benchmark")
val sqlContext = new SQLContext(sparkContext)
import sqlContext.implicits._
val numRows = 10000000
val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
val numChains = 10
val benchmark = new Benchmark("back-to-back map", numRows)
val func = (d: Data) => Data(d.l + 1, d.s)
benchmark.addCase("Dataset") { iter =>
var res = df.as[Data]
var i = 0
while (i < numChains) {
res = res.map(func)
i += 1
}
res.queryExecution.toRdd.foreach(_ => Unit)
}
benchmark.addCase("DataFrame") { iter =>
var res = df
var i = 0
while (i < numChains) {
res = res.select($"l" + 1 as "l")
i += 1
}
res.queryExecution.toRdd.foreach(_ => Unit)
}
val rdd = sparkContext.range(1, numRows).map(l => Data(l, l.toString))
benchmark.addCase("RDD") { iter =>
var res = rdd
var i = 0
while (i < numChains) {
res = rdd.map(func)
i += 1
}
res.foreach(_ => Unit)
}
/*
Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4
Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
Dataset 902 / 995 11.1 90.2 1.0X
DataFrame 132 / 167 75.5 13.2 6.8X
RDD 216 / 237 46.3 21.6 4.2X
*/
benchmark.run()
}
}


@@ -198,10 +198,7 @@ abstract class QueryTest extends PlanTest {
val logicalPlan = df.queryExecution.analyzed
// bypass some cases that we can't handle currently.
logicalPlan.transform {
case _: MapPartitions => return
case _: MapGroups => return
case _: AppendColumns => return
case _: CoGroup => return
case _: ObjectOperator => return
case _: LogicalRelation => return
}.transformAllExpressions {
case a: ImperativeAggregate => return


@@ -17,7 +17,8 @@
package org.apache.spark.sql.execution
import org.apache.spark.sql.Row
import org.apache.spark.api.java.function.MapFunction
import org.apache.spark.sql.{Encoders, Row}
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.BroadcastHashJoin
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
@@ -70,4 +71,15 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Sort]).isDefined)
assert(df.collect() === Array(Row(1), Row(2), Row(3)))
}
test("MapElements should be included in WholeStageCodegen") {
import testImplicits._
val ds = sqlContext.range(10).map(_.toString)
val plan = ds.queryExecution.executedPlan
assert(plan.find(p =>
p.isInstanceOf[WholeStageCodegen] &&
p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[MapElements]).isDefined)
assert(ds.collect() === 0.until(10).map(_.toString).toArray)
}
}
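
The `MapFunction` and `Encoders` imports added above suggest the Java-friendly overload is exercised as well; a hypothetical companion test (a sketch, not part of this diff) inside the same suite might look like:

```scala
test("MapElements via Java MapFunction should be included in WholeStageCodegen") {
  // Hypothetical: exercises the map(MapFunction, Encoder) overload added in Dataset.
  val ds = sqlContext.range(10).map(
    new MapFunction[java.lang.Long, String] {
      override def call(l: java.lang.Long): String = l.toString
    },
    Encoders.STRING)
  val plan = ds.queryExecution.executedPlan
  assert(plan.find(p =>
    p.isInstanceOf[WholeStageCodegen] &&
    p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[MapElements]).isDefined)
  assert(ds.collect() === 0.until(10).map(_.toString).toArray)
}
```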