[SPARK-14951] [SQL] Support subexpression elimination in TungstenAggregate

## What changes were proposed in this pull request?

We can support subexpression elimination in TungstenAggregate by reusing the existing `EquivalentExpressions` class, which is already used for subexpression elimination in expression codegen.

However, in whole-stage codegen we can't wrap the code for common expressions in functions as before, so we simply generate code snippets for them. These snippets are inserted into the generated Java code before the point where the common expressions are first used.

When multiple `TypedAggregateExpression`s are used in an aggregation operator, their input types should be the same, so their `inputDeserializer`s will be the same too. This patch can therefore also eliminate redundant input deserialization.

## How was this patch tested?
Existing tests.

Author: Liang-Chi Hsieh <simonh@tw.ibm.com>

Closes #12729 from viirya/subexpr-elimination-tungstenaggregate.
This commit is contained in:
Liang-Chi Hsieh 2016-05-04 10:54:51 -07:00 committed by Davies Liu
parent d864c55cf8
commit b85d21fb9d
4 changed files with 109 additions and 41 deletions

View file

@ -68,7 +68,10 @@ class EquivalentExpressions {
* is found. That is, if `expr` has already been added, its children are not added. * is found. That is, if `expr` has already been added, its children are not added.
* If ignoreLeaf is true, leaf nodes are ignored. * If ignoreLeaf is true, leaf nodes are ignored.
*/ */
def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = { def addExprTree(
root: Expression,
ignoreLeaf: Boolean = true,
skipReferenceToExpressions: Boolean = true): Unit = {
val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf
// There are some special expressions that we should not recurse into children. // There are some special expressions that we should not recurse into children.
// 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead)
@ -77,7 +80,7 @@ class EquivalentExpressions {
// TODO: some expressions implements `CodegenFallback` but can still do codegen, // TODO: some expressions implements `CodegenFallback` but can still do codegen,
// e.g. `CaseWhen`, we should support them. // e.g. `CaseWhen`, we should support them.
case _: CodegenFallback => false case _: CodegenFallback => false
case _: ReferenceToExpressions => false case _: ReferenceToExpressions if skipReferenceToExpressions => false
case _ => true case _ => true
} }
if (!skip && !addExpr(root) && shouldRecurse) { if (!skip && !addExpr(root) && shouldRecurse) {

View file

@ -46,6 +46,25 @@ import org.apache.spark.util.Utils
*/ */
case class ExprCode(var code: String, var isNull: String, var value: String) case class ExprCode(var code: String, var isNull: String, var value: String)
/**
 * State used for subexpression elimination.
 *
 * Both fields are terms (strings), not evaluated values.
 *
 * @param isNull A term that holds a boolean value representing whether the expression evaluated
 *               to null.
 * @param value A term for a value of a common sub-expression. Not valid if `isNull`
 *              is set to `true`.
 */
case class SubExprEliminationState(isNull: String, value: String)
/**
 * Codes and common subexpressions mapping used for subexpression elimination.
 *
 * @param codes Strings representing the code snippets that evaluate common subexpressions.
 * @param states For each expression that is participating in subexpression elimination,
 *               the state to use.
 */
case class SubExprCodes(codes: Seq[String], states: Map[Expression, SubExprEliminationState])
/** /**
* A context for codegen, tracking a list of objects that could be passed into generated Java * A context for codegen, tracking a list of objects that could be passed into generated Java
* function. * function.
@ -148,9 +167,6 @@ class CodegenContext {
*/ */
val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
// State used for subexpression elimination.
case class SubExprEliminationState(isNull: String, value: String)
// Foreach expression that is participating in subexpression elimination, the state to use. // Foreach expression that is participating in subexpression elimination, the state to use.
val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
@ -571,6 +587,58 @@ class CodegenContext {
} }
} }
/**
 * Runs `f`, which generates a sequence of ExprCodes, with the given mapping between
 * expressions and common subexpression states installed in place of the mapping currently
 * held by this context. The previous mapping is restored once `f` has finished, whether it
 * returns normally or throws.
 *
 * @param newSubExprEliminationExprs the mapping that `f` should see while it runs.
 * @param f by-name block that generates the ExprCodes; evaluated exactly once.
 * @return the ExprCodes produced by `f`.
 */
def withSubExprEliminationExprs(
    newSubExprEliminationExprs: Map[Expression, SubExprEliminationState])(
    f: => Seq[ExprCode]): Seq[ExprCode] = {
  // `subExprEliminationExprs` is a mutable map that is cleared in place below, so keeping a
  // plain reference to it would not preserve its contents. Take an immutable snapshot instead.
  val savedSubExprEliminationExprs = subExprEliminationExprs.toMap
  subExprEliminationExprs.clear()
  subExprEliminationExprs ++= newSubExprEliminationExprs
  try {
    f
  } finally {
    // Restore previous subExprEliminationExprs
    subExprEliminationExprs.clear()
    subExprEliminationExprs ++= savedSubExprEliminationExprs
  }
}
/**
 * Checks and sets up the state and codegen for subexpression elimination. This finds the
 * common subexpressions, generates the code snippets that evaluate those expressions and
 * builds the mapping of common subexpressions to the terms holding their results. The
 * generated code snippets are returned and should be inserted into the generated code before
 * the point where these common subexpressions are first used.
 *
 * This method works on its own fresh `EquivalentExpressions` and state mapping; it does not
 * read or modify the `equivalentExpressions` / `subExprEliminationExprs` members of this
 * context.
 *
 * @param expressions the expression trees to search for common subexpressions.
 * @return the code snippets for the common subexpressions, together with the state to use for
 *         every expression that participates in subexpression elimination.
 */
def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = {
  // Local names deliberately differ from the class members so that nothing is shadowed.
  val localEquivalentExprs = new EquivalentExpressions
  val localSubExprStates = mutable.HashMap.empty[Expression, SubExprEliminationState]

  // Add each expression tree and compute the common subexpressions. Leaf nodes are ignored,
  // and the tree walk is allowed to recurse into `ReferenceToExpressions`.
  expressions.foreach { expr =>
    localEquivalentExprs.addExprTree(
      expr, ignoreLeaf = true, skipReferenceToExpressions = false)
  }

  // Get all the expressions that appear at least twice and set up the state for subexpression
  // elimination.
  val commonExprs = localEquivalentExprs.getAllEquivalentExprs.filter(_.size > 1)
  val codes = commonExprs.map { equivalentGroup =>
    // Generate the code once, from the first expression of the group, and let every
    // expression in the group share the resulting isNull/value terms.
    val eval = equivalentGroup.head.genCode(this)
    val state = SubExprEliminationState(eval.isNull, eval.value)
    equivalentGroup.foreach(localSubExprStates.put(_, state))
    eval.code.trim
  }
  SubExprCodes(codes, localSubExprStates.toMap)
}
/** /**
* Checks and sets up the state and codegen for subexpression elimination. This finds the * Checks and sets up the state and codegen for subexpression elimination. This finds the
* common subexpressions, generates the functions that evaluate those expressions and populates * common subexpressions, generates the functions that evaluate those expressions and populates

View file

@ -244,8 +244,12 @@ case class TungstenAggregate(
} }
} }
ctx.currentVars = bufVars ++ input ctx.currentVars = bufVars ++ input
// TODO: support subexpression elimination val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs))
val aggVals = updateExpr.map(BindReferences.bindReference(_, inputAttrs).genCode(ctx)) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExpr.map(_.genCode(ctx))
}
// aggregate buffer should be updated atomic // aggregate buffer should be updated atomic
val updates = aggVals.zipWithIndex.map { case (ev, i) => val updates = aggVals.zipWithIndex.map { case (ev, i) =>
s""" s"""
@ -255,6 +259,9 @@ case class TungstenAggregate(
} }
s""" s"""
| // do aggregate | // do aggregate
| // common sub-expressions
| $effectiveCodes
| // evaluate aggregate function
| ${evaluateVariables(aggVals)} | ${evaluateVariables(aggVals)}
| // update aggregation buffer | // update aggregation buffer
| ${updates.mkString("\n").trim} | ${updates.mkString("\n").trim}
@ -650,8 +657,12 @@ case class TungstenAggregate(
val updateRowInVectorizedHashMap: Option[String] = { val updateRowInVectorizedHashMap: Option[String] = {
if (isVectorizedHashMapEnabled) { if (isVectorizedHashMapEnabled) {
ctx.INPUT_ROW = vectorizedRowBuffer ctx.INPUT_ROW = vectorizedRowBuffer
val vectorizedRowEvals = val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
updateExpr.map(BindReferences.bindReference(_, inputAttr).genCode(ctx)) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val vectorizedRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExpr.map(_.genCode(ctx))
}
val updateVectorizedRow = vectorizedRowEvals.zipWithIndex.map { case (ev, i) => val updateVectorizedRow = vectorizedRowEvals.zipWithIndex.map { case (ev, i) =>
val dt = updateExpr(i).dataType val dt = updateExpr(i).dataType
ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable, ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable,
@ -659,6 +670,8 @@ case class TungstenAggregate(
} }
Option( Option(
s""" s"""
|// common sub-expressions
|$effectiveCodes
|// evaluate aggregate function |// evaluate aggregate function
|${evaluateVariables(vectorizedRowEvals)} |${evaluateVariables(vectorizedRowEvals)}
|// update vectorized row |// update vectorized row
@ -701,13 +714,19 @@ case class TungstenAggregate(
val updateRowInUnsafeRowMap: String = { val updateRowInUnsafeRowMap: String = {
ctx.INPUT_ROW = unsafeRowBuffer ctx.INPUT_ROW = unsafeRowBuffer
val unsafeRowBufferEvals = val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
updateExpr.map(BindReferences.bindReference(_, inputAttr).genCode(ctx)) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExpr.map(_.genCode(ctx))
}
val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
val dt = updateExpr(i).dataType val dt = updateExpr(i).dataType
ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
} }
s""" s"""
|// common sub-expressions
|$effectiveCodes
|// evaluate aggregate function |// evaluate aggregate function
|${evaluateVariables(unsafeRowBufferEvals)} |${evaluateVariables(unsafeRowBufferEvals)}
|// update unsafe row buffer |// update unsafe row buffer

View file

@ -31,31 +31,9 @@ object TypedAggregateExpression {
def apply[BUF : Encoder, OUT : Encoder]( def apply[BUF : Encoder, OUT : Encoder](
aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = { aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = {
val bufferEncoder = encoderFor[BUF] val bufferEncoder = encoderFor[BUF]
// We will insert the deserializer and function call expression at the bottom of each serializer val bufferSerializer = bufferEncoder.namedExpressions
// expression while executing `TypedAggregateExpression`, which means multiply serializer val bufferDeserializer = bufferEncoder.deserializer.transform {
// expressions will all evaluate the same sub-expression at bottom. To avoid the re-evaluating, case b: BoundReference => bufferSerializer(b.ordinal).toAttribute
// here we always use one single serializer expression to serialize the buffer object into a
// single-field row, no matter whether the encoder is flat or not. We also need to update the
// deserializer to read in all fields from that single-field row.
// TODO: remove this trick after we have better integration of subexpression elimination and
// whole stage codegen.
val bufferSerializer = if (bufferEncoder.flat) {
bufferEncoder.namedExpressions.head
} else {
Alias(CreateStruct(bufferEncoder.serializer), "buffer")()
}
val bufferDeserializer = if (bufferEncoder.flat) {
bufferEncoder.deserializer transformUp {
case b: BoundReference => bufferSerializer.toAttribute
}
} else {
bufferEncoder.deserializer transformUp {
case UnresolvedAttribute(nameParts) =>
assert(nameParts.length == 1)
UnresolvedExtractValue(bufferSerializer.toAttribute, Literal(nameParts.head))
case BoundReference(ordinal, dt, _) => GetStructField(bufferSerializer.toAttribute, ordinal)
}
} }
val outputEncoder = encoderFor[OUT] val outputEncoder = encoderFor[OUT]
@ -82,7 +60,7 @@ object TypedAggregateExpression {
case class TypedAggregateExpression( case class TypedAggregateExpression(
aggregator: Aggregator[Any, Any, Any], aggregator: Aggregator[Any, Any, Any],
inputDeserializer: Option[Expression], inputDeserializer: Option[Expression],
bufferSerializer: NamedExpression, bufferSerializer: Seq[NamedExpression],
bufferDeserializer: Expression, bufferDeserializer: Expression,
outputSerializer: Seq[Expression], outputSerializer: Seq[Expression],
outputExternalType: DataType, outputExternalType: DataType,
@ -106,11 +84,11 @@ case class TypedAggregateExpression(
private def bufferExternalType = bufferDeserializer.dataType private def bufferExternalType = bufferDeserializer.dataType
override lazy val aggBufferAttributes: Seq[AttributeReference] = override lazy val aggBufferAttributes: Seq[AttributeReference] =
bufferSerializer.toAttribute.asInstanceOf[AttributeReference] :: Nil bufferSerializer.map(_.toAttribute.asInstanceOf[AttributeReference])
override lazy val initialValues: Seq[Expression] = { override lazy val initialValues: Seq[Expression] = {
val zero = Literal.fromObject(aggregator.zero, bufferExternalType) val zero = Literal.fromObject(aggregator.zero, bufferExternalType)
ReferenceToExpressions(bufferSerializer, zero :: Nil) :: Nil bufferSerializer.map(ReferenceToExpressions(_, zero :: Nil))
} }
override lazy val updateExpressions: Seq[Expression] = { override lazy val updateExpressions: Seq[Expression] = {
@ -120,7 +98,7 @@ case class TypedAggregateExpression(
bufferExternalType, bufferExternalType,
bufferDeserializer :: inputDeserializer.get :: Nil) bufferDeserializer :: inputDeserializer.get :: Nil)
ReferenceToExpressions(bufferSerializer, reduced :: Nil) :: Nil bufferSerializer.map(ReferenceToExpressions(_, reduced :: Nil))
} }
override lazy val mergeExpressions: Seq[Expression] = { override lazy val mergeExpressions: Seq[Expression] = {
@ -136,7 +114,7 @@ case class TypedAggregateExpression(
bufferExternalType, bufferExternalType,
leftBuffer :: rightBuffer :: Nil) leftBuffer :: rightBuffer :: Nil)
ReferenceToExpressions(bufferSerializer, merged :: Nil) :: Nil bufferSerializer.map(ReferenceToExpressions(_, merged :: Nil))
} }
override lazy val evaluateExpression: Expression = { override lazy val evaluateExpression: Expression = {