diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index d0ad7a05a0..b8e2b67b2f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -68,7 +68,10 @@ class EquivalentExpressions {
    * is found. That is, if `expr` has already been added, its children are not added.
    * If ignoreLeaf is true, leaf nodes are ignored.
    */
-  def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = {
+  def addExprTree(
+      root: Expression,
+      ignoreLeaf: Boolean = true,
+      skipReferenceToExpressions: Boolean = true): Unit = {
     val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf
     // There are some special expressions that we should not recurse into children.
     // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead)
@@ -77,7 +80,7 @@ class EquivalentExpressions {
       // TODO: some expressions implements `CodegenFallback` but can still do codegen,
       // e.g. `CaseWhen`, we should support them.
       case _: CodegenFallback => false
-      case _: ReferenceToExpressions => false
+      case _: ReferenceToExpressions if skipReferenceToExpressions => false
       case _ => true
     }
     if (!skip && !addExpr(root) && shouldRecurse) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index e4fa429b37..67f6719265 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -46,6 +46,25 @@ import org.apache.spark.util.Utils
  */
 case class ExprCode(var code: String, var isNull: String, var value: String)
 
+/**
+ * State used for subexpression elimination.
+ *
+ * @param isNull A term that holds a boolean value representing whether the expression evaluated
+ *               to null.
+ * @param value A term for the value of a common subexpression. Not valid if `isNull`
+ *              is set to `true`.
+ */
+case class SubExprEliminationState(isNull: String, value: String)
+
+/**
+ * Code snippets and the common-subexpression mapping used for subexpression elimination.
+ *
+ * @param codes Strings holding the code snippets that evaluate the common subexpressions.
+ * @param states For each expression participating in subexpression elimination,
+ *               the state to use.
+ */
+case class SubExprCodes(codes: Seq[String], states: Map[Expression, SubExprEliminationState])
+
 /**
  * A context for codegen, tracking a list of objects that could be passed into generated Java
  * function.
@@ -148,9 +167,6 @@ class CodegenContext {
    */
   val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
 
-  // State used for subexpression elimination.
-  case class SubExprEliminationState(isNull: String, value: String)
-
   // Foreach expression that is participating in subexpression elimination, the state to use.
   val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
 
@@ -571,6 +587,58 @@ class CodegenContext {
     }
   }
 
+  /**
+   * Runs a function that generates a sequence of ExprCodes using the given mapping between
+   * expressions and common subexpressions, instead of the mapping in the current context.
+   */
+  def withSubExprEliminationExprs(
+      newSubExprEliminationExprs: Map[Expression, SubExprEliminationState])(
+      f: => Seq[ExprCode]): Seq[ExprCode] = {
+    val oldSubExprEliminationExprs = subExprEliminationExprs.clone()
+    subExprEliminationExprs.clear()
+    newSubExprEliminationExprs.foreach(subExprEliminationExprs += _)
+
+    val genCodes = f
+
+    // Restore previous subExprEliminationExprs
+    subExprEliminationExprs.clear()
+    oldSubExprEliminationExprs.foreach(subExprEliminationExprs += _)
+    genCodes
+  }
+
+  /**
+   * Checks and sets up the state and codegen for subexpression elimination. This finds the
+   * common subexpressions, generates the code snippets that evaluate those expressions and
+   * populates the mapping of common subexpressions to the generated code snippets. The generated
+   * code snippets will be returned and should be inserted into the generated code before the
+   * common subexpressions are actually used for the first time.
+   */
+  def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = {
+    // Create a fresh EquivalentExpressions and SubExprEliminationState mapping.
+    val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
+    val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
+
+    // Add each expression tree and compute the common subexpressions.
+    expressions.foreach(equivalentExpressions.addExprTree(_, true, false))
+
+    // Get all the expressions that appear at least twice and set up the state for subexpression
+    // elimination.
+    val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
+    val codes = commonExprs.map { e =>
+      val expr = e.head
+      val fnName = freshName("evalExpr")
+      val isNull = s"${fnName}IsNull"
+      val value = s"${fnName}Value"
+
+      // Generate the code for this expression tree and reuse its result for every occurrence.
+      val code = expr.genCode(this)
+      val state = SubExprEliminationState(code.isNull, code.value)
+      e.foreach(subExprEliminationExprs.put(_, state))
+      code.code.trim
+    }
+    SubExprCodes(codes, subExprEliminationExprs.toMap)
+  }
+
   /**
    * Checks and sets up the state and codegen for subexpression elimination. This finds the
    * common subexpressions, generates the functions that evaluate those expressions and populates
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index d0ba37ee13..d2dc80a7e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -244,8 +244,12 @@ case class TungstenAggregate(
       }
     }
     ctx.currentVars = bufVars ++ input
-    // TODO: support subexpression elimination
-    val aggVals = updateExpr.map(BindReferences.bindReference(_, inputAttrs).genCode(ctx))
+    val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs))
+    val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
+    val effectiveCodes = subExprs.codes.mkString("\n")
+    val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
+      boundUpdateExpr.map(_.genCode(ctx))
+    }
     // aggregate buffer should be updated atomic
     val updates = aggVals.zipWithIndex.map { case (ev, i) =>
       s"""
@@ -255,6 +259,9 @@ case class TungstenAggregate(
     }
     s"""
        | // do aggregate
+       | // common sub-expressions
+       | $effectiveCodes
+       | // evaluate aggregate function
        | ${evaluateVariables(aggVals)}
        | // update aggregation buffer
        | ${updates.mkString("\n").trim}
@@ -650,8 +657,12 @@ case class TungstenAggregate(
     val updateRowInVectorizedHashMap: Option[String] = {
       if (isVectorizedHashMapEnabled) {
         ctx.INPUT_ROW = vectorizedRowBuffer
-        val vectorizedRowEvals =
-          updateExpr.map(BindReferences.bindReference(_, inputAttr).genCode(ctx))
+        val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+        val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
+        val effectiveCodes = subExprs.codes.mkString("\n")
+        val vectorizedRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
+          boundUpdateExpr.map(_.genCode(ctx))
+        }
         val updateVectorizedRow = vectorizedRowEvals.zipWithIndex.map { case (ev, i) =>
           val dt = updateExpr(i).dataType
           ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable,
@@ -659,6 +670,8 @@ case class TungstenAggregate(
         }
         Option(
           s"""
+             |// common sub-expressions
+             |$effectiveCodes
              |// evaluate aggregate function
              |${evaluateVariables(vectorizedRowEvals)}
              |// update vectorized row
@@ -701,13 +714,19 @@ case class TungstenAggregate(
 
     val updateRowInUnsafeRowMap: String = {
       ctx.INPUT_ROW = unsafeRowBuffer
-      val unsafeRowBufferEvals =
-        updateExpr.map(BindReferences.bindReference(_, inputAttr).genCode(ctx))
+      val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+      val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
+      val effectiveCodes = subExprs.codes.mkString("\n")
+      val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
+        boundUpdateExpr.map(_.genCode(ctx))
+      }
       val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
         val dt = updateExpr(i).dataType
         ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
       }
       s"""
+         |// common sub-expressions
+         |$effectiveCodes
          |// evaluate aggregate function
          |${evaluateVariables(unsafeRowBufferEvals)}
          |// update unsafe row buffer
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 535e64cb34..edca816cb1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -31,31 +31,9 @@ object TypedAggregateExpression {
   def apply[BUF : Encoder, OUT : Encoder](
       aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = {
     val bufferEncoder = encoderFor[BUF]
-    // We will insert the deserializer and function call expression at the bottom of each serializer
-    // expression while executing `TypedAggregateExpression`, which means multiply serializer
-    // expressions will all evaluate the same sub-expression at bottom. To avoid the re-evaluating,
-    // here we always use one single serializer expression to serialize the buffer object into a
-    // single-field row, no matter whether the encoder is flat or not. We also need to update the
-    // deserializer to read in all fields from that single-field row.
-    // TODO: remove this trick after we have better integration of subexpression elimination and
-    // whole stage codegen.
-    val bufferSerializer = if (bufferEncoder.flat) {
-      bufferEncoder.namedExpressions.head
-    } else {
-      Alias(CreateStruct(bufferEncoder.serializer), "buffer")()
-    }
-
-    val bufferDeserializer = if (bufferEncoder.flat) {
-      bufferEncoder.deserializer transformUp {
-        case b: BoundReference => bufferSerializer.toAttribute
-      }
-    } else {
-      bufferEncoder.deserializer transformUp {
-        case UnresolvedAttribute(nameParts) =>
-          assert(nameParts.length == 1)
-          UnresolvedExtractValue(bufferSerializer.toAttribute, Literal(nameParts.head))
-        case BoundReference(ordinal, dt, _) => GetStructField(bufferSerializer.toAttribute, ordinal)
-      }
+    val bufferSerializer = bufferEncoder.namedExpressions
+    val bufferDeserializer = bufferEncoder.deserializer.transform {
+      case b: BoundReference => bufferSerializer(b.ordinal).toAttribute
     }
 
     val outputEncoder = encoderFor[OUT]
@@ -82,7 +60,7 @@ case class TypedAggregateExpression(
     aggregator: Aggregator[Any, Any, Any],
     inputDeserializer: Option[Expression],
-    bufferSerializer: NamedExpression,
+    bufferSerializer: Seq[NamedExpression],
     bufferDeserializer: Expression,
     outputSerializer: Seq[Expression],
     outputExternalType: DataType,
@@ -106,11 +84,11 @@ case class TypedAggregateExpression(
   private def bufferExternalType = bufferDeserializer.dataType
 
   override lazy val aggBufferAttributes: Seq[AttributeReference] =
-    bufferSerializer.toAttribute.asInstanceOf[AttributeReference] :: Nil
+    bufferSerializer.map(_.toAttribute.asInstanceOf[AttributeReference])
 
   override lazy val initialValues: Seq[Expression] = {
     val zero = Literal.fromObject(aggregator.zero, bufferExternalType)
-    ReferenceToExpressions(bufferSerializer, zero :: Nil) :: Nil
+    bufferSerializer.map(ReferenceToExpressions(_, zero :: Nil))
   }
 
   override lazy val updateExpressions: Seq[Expression] = {
@@ -120,7 +98,7 @@ case class TypedAggregateExpression(
       bufferExternalType,
       bufferDeserializer :: inputDeserializer.get :: Nil)
 
-    ReferenceToExpressions(bufferSerializer, reduced :: Nil) :: Nil
+    bufferSerializer.map(ReferenceToExpressions(_, reduced :: Nil))
   }
 
   override lazy val mergeExpressions: Seq[Expression] = {
@@ -136,7 +114,7 @@ case class TypedAggregateExpression(
       bufferExternalType,
       leftBuffer :: rightBuffer :: Nil)
 
-    ReferenceToExpressions(bufferSerializer, merged :: Nil) :: Nil
+    bufferSerializer.map(ReferenceToExpressions(_, merged :: Nil))
   }
 
   override lazy val evaluateExpression: Expression = {
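Note on the mechanism this patch wires into TungstenAggregate: the update expressions are bound, the common subexpressions are computed once (subexpressionEliminationForWholeStageCodegen), the returned snippets ($effectiveCodes) are spliced into the generated code ahead of the per-expression evaluation, and the expression code is then generated under withSubExprEliminationExprs so repeated subtrees reuse the precomputed isNull/value terms. Below is a minimal, self-contained Scala sketch of that evaluate-once / reuse-everywhere flow. It uses hypothetical toy types (Expr, Col, Add, State, genCode), not Spark's Expression, CodegenContext, or ExprCode, so every name in it is an illustrative assumption rather than the actual API.

object SubExprEliminationSketch {

  // A tiny expression language standing in for Catalyst expressions.
  sealed trait Expr
  case class Col(name: String) extends Expr
  case class Lit(value: Int) extends Expr
  case class Add(left: Expr, right: Expr) extends Expr

  // All subtrees of an expression, root first.
  def subtrees(e: Expr): Seq[Expr] = e match {
    case Add(l, r) => e +: (subtrees(l) ++ subtrees(r))
    case _ => Seq(e)
  }

  // Stand-in for SubExprEliminationState: the variable holding a precomputed value.
  case class State(variable: String)

  // Naive "codegen" that consults the common-subexpression mapping before recursing.
  def genCode(e: Expr, states: Map[Expr, State]): String = states.get(e) match {
    case Some(s) => s.variable
    case None => e match {
      case Col(n) => s"row.$n"
      case Lit(v) => v.toString
      case Add(l, r) => s"(${genCode(l, states)} + ${genCode(r, states)})"
    }
  }

  def main(args: Array[String]): Unit = {
    val common = Add(Col("a"), Col("b")) // shared by both update expressions
    val updateExprs = Seq(Add(common, Lit(1)), Add(common, Col("c")))

    // Analogue of subexpressionEliminationForWholeStageCodegen: find non-leaf subtrees that
    // occur at least twice, emit each once, and remember the variable holding its result.
    val repeated: Seq[Expr] = updateExprs.flatMap(subtrees)
      .filter { case _: Add => true; case _ => false }
      .groupBy(identity)
      .collect { case (expr, occurrences) if occurrences.size > 1 => expr }
      .toSeq
    val states = repeated.zipWithIndex.map { case (expr, i) => expr -> State(s"subExpr$i") }.toMap
    val effectiveCodes = states.map { case (expr, st) =>
      s"int ${st.variable} = ${genCode(expr, Map.empty)};"
    }.mkString("\n")

    // Analogue of withSubExprEliminationExprs: generate the expressions with the mapping in scope.
    val aggVals = updateExprs.map(genCode(_, states))

    println(effectiveCodes)  // int subExpr0 = (row.a + row.b);
    aggVals.foreach(println) // (subExpr0 + 1) and (subExpr0 + row.c)
  }
}

In the patch itself this shape appears three times in TungstenAggregate.scala: effectiveCodes plays the role of the precomputed snippets emitted under "// common sub-expressions", and subExprs.states plays the role of the reuse mapping passed to ctx.withSubExprEliminationExprs.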