[SPARK-21870][SQL] Split aggregation code into small functions

## What changes were proposed in this pull request?
This pr proposes to split aggregation code into small functions in `HashAggregateExec`. In #18810, we observed a performance regression when the JVM refused to JIT-compile generated functions that were too long (HotSpot skips compilation of methods whose bytecode exceeds its huge-method limit, 8,000 bytes by default). I checked and found that the code generated by `HashAggregateExec` frequently goes over this limit when a query has many aggregate functions (e.g., q66 in TPCDS).

The current master places all the generated aggregation code in a single function. In this pr, I modified the code generation to emit an individual function for each aggregate function (e.g., `SUM` and `AVG`). For example, for the query `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, the proposed code defines two functions for `SUM(a)` and `AVG(a)` as follows:

- generated code with this pr (https://gist.github.com/maropu/812990012bc967a78364be0fa793f559):
```
/* 173 */   private void agg_doConsume_0(InternalRow inputadapter_row_0, long agg_expr_0_0, boolean agg_exprIsNull_0_0, double agg_expr_1_0, boolean agg_exprIsNull_1_0, long agg_expr_2_0, boolean agg_exprIsNull_2_0) throws java.io.IOException {
/* 174 */     // do aggregate
/* 175 */     // common sub-expressions
/* 176 */
/* 177 */     // evaluate aggregate functions and update aggregation buffers
/* 178 */     agg_doAggregate_sum_0(agg_exprIsNull_0_0, agg_expr_0_0);
/* 179 */     agg_doAggregate_avg_0(agg_expr_1_0, agg_exprIsNull_1_0, agg_exprIsNull_2_0, agg_expr_2_0);
/* 180 */
/* 181 */   }
...
/* 071 */   private void agg_doAggregate_avg_0(double agg_expr_1_0, boolean agg_exprIsNull_1_0, boolean agg_exprIsNull_2_0, long agg_expr_2_0) throws java.io.IOException {
/* 072 */     // do aggregate for avg
/* 073 */     // evaluate aggregate function
/* 074 */     boolean agg_isNull_19 = true;
/* 075 */     double agg_value_19 = -1.0;
...
/* 114 */   private void agg_doAggregate_sum_0(boolean agg_exprIsNull_0_0, long agg_expr_0_0) throws java.io.IOException {
/* 115 */     // do aggregate for sum
/* 116 */     // evaluate aggregate function
/* 117 */     agg_agg_isNull_11_0 = true;
/* 118 */     long agg_value_11 = -1L;
```

- generated code in the current master (https://gist.github.com/maropu/e9d772af2c98d8991a6a5f0af7841760):
```
/* 059 */   private void agg_doConsume_0(InternalRow localtablescan_row_0, int agg_expr_0_0) throws java.io.IOException {
/* 060 */     // do aggregate
/* 061 */     // common sub-expressions
/* 062 */     boolean agg_isNull_4 = false;
/* 063 */     long agg_value_4 = -1L;
/* 064 */     if (!false) {
/* 065 */       agg_value_4 = (long) agg_expr_0_0;
/* 066 */     }
/* 067 */     // evaluate aggregate function
/* 068 */     agg_agg_isNull_7_0 = true;
/* 069 */     long agg_value_7 = -1L;
/* 070 */     do {
/* 071 */       if (!agg_bufIsNull_0) {
/* 072 */         agg_agg_isNull_7_0 = false;
/* 073 */         agg_value_7 = agg_bufValue_0;
/* 074 */         continue;
/* 075 */       }
/* 076 */
/* 077 */       boolean agg_isNull_9 = false;
/* 078 */       long agg_value_9 = -1L;
/* 079 */       if (!false) {
/* 080 */         agg_value_9 = (long) 0;
/* 081 */       }
/* 082 */       if (!agg_isNull_9) {
/* 083 */         agg_agg_isNull_7_0 = false;
/* 084 */         agg_value_7 = agg_value_9;
/* 085 */         continue;
/* 086 */       }
/* 087 */
/* 088 */     } while (false);
/* 089 */
/* 090 */     long agg_value_6 = -1L;
/* 091 */
/* 092 */     agg_value_6 = agg_value_7 + agg_value_4;
/* 093 */     boolean agg_isNull_11 = true;
/* 094 */     double agg_value_11 = -1.0;
/* 095 */
/* 096 */     if (!agg_bufIsNull_1) {
/* 097 */       agg_agg_isNull_13_0 = true;
/* 098 */       double agg_value_13 = -1.0;
/* 099 */       do {
/* 100 */         boolean agg_isNull_14 = agg_isNull_4;
/* 101 */         double agg_value_14 = -1.0;
/* 102 */         if (!agg_isNull_4) {
/* 103 */           agg_value_14 = (double) agg_value_4;
/* 104 */         }
/* 105 */         if (!agg_isNull_14) {
/* 106 */           agg_agg_isNull_13_0 = false;
/* 107 */           agg_value_13 = agg_value_14;
/* 108 */           continue;
/* 109 */         }
/* 110 */
/* 111 */         boolean agg_isNull_15 = false;
/* 112 */         double agg_value_15 = -1.0;
/* 113 */         if (!false) {
/* 114 */           agg_value_15 = (double) 0;
/* 115 */         }
/* 116 */         if (!agg_isNull_15) {
/* 117 */           agg_agg_isNull_13_0 = false;
/* 118 */           agg_value_13 = agg_value_15;
/* 119 */           continue;
/* 120 */         }
/* 121 */
/* 122 */       } while (false);
/* 123 */
/* 124 */       agg_isNull_11 = false; // resultCode could change nullability.
/* 125 */
/* 126 */       agg_value_11 = agg_bufValue_1 + agg_value_13;
/* 127 */
/* 128 */     }
/* 129 */     boolean agg_isNull_17 = false;
/* 130 */     long agg_value_17 = -1L;
/* 131 */     if (!false && agg_isNull_4) {
/* 132 */       agg_isNull_17 = agg_bufIsNull_2;
/* 133 */       agg_value_17 = agg_bufValue_2;
/* 134 */     } else {
/* 135 */       boolean agg_isNull_20 = true;
/* 136 */       long agg_value_20 = -1L;
/* 137 */
/* 138 */       if (!agg_bufIsNull_2) {
/* 139 */         agg_isNull_20 = false; // resultCode could change nullability.
/* 140 */
/* 141 */         agg_value_20 = agg_bufValue_2 + 1L;
/* 142 */
/* 143 */       }
/* 144 */       agg_isNull_17 = agg_isNull_20;
/* 145 */       agg_value_17 = agg_value_20;
/* 146 */     }
/* 147 */     // update aggregation buffer
/* 148 */     agg_bufIsNull_0 = false;
/* 149 */     agg_bufValue_0 = agg_value_6;
/* 150 */
/* 151 */     agg_bufIsNull_1 = agg_isNull_11;
/* 152 */     agg_bufValue_1 = agg_value_11;
/* 153 */
/* 154 */     agg_bufIsNull_2 = agg_isNull_17;
/* 155 */     agg_bufValue_2 = agg_value_17;
/* 156 */
/* 157 */   }
```
You can check the previous discussion in https://github.com/apache/spark/pull/19082

## How was this patch tested?
Existing tests

Closes #20965 from maropu/SPARK-21870-2.

Authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
parent 36f8e53cfa
commit cb0cddffe9
7 changed files with 350 additions and 71 deletions

@@ -115,7 +115,13 @@ package object dsl {
def getField(fieldName: String): UnresolvedExtractValue =
UnresolvedExtractValue(expr, Literal(fieldName))
def cast(to: DataType): Expression = Cast(expr, to)
def cast(to: DataType): Expression = {
if (expr.resolved && expr.dataType.sameType(to)) {
expr
} else {
Cast(expr, to)
}
}
def asc: SortOrder = SortOrder(expr, Ascending)
def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Set.empty)
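
This dsl change elides no-op casts: casting a resolved expression to its own type now returns the expression unchanged. A minimal sketch of the behavior, assuming the Catalyst dsl implicits are in scope (`'a.int` builds a resolved integer attribute):

```
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.types.{IntegerType, StringType}

val a = 'a.int  // resolved AttributeReference of IntegerType

// Casting to the same type now returns the expression itself...
assert(a.cast(IntegerType) eq a)
// ...while casting to a different type still wraps it in a Cast.
assert(a.cast(StringType).isInstanceOf[Cast])
```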

@@ -1612,6 +1612,48 @@ object CodeGenerator extends Logging {
}
}
/**
* Extracts all the input variables from references and subexpression elimination states
* for a given `expr`. This result will be used to split the generated code of
* expressions into multiple functions.
*/
def getLocalInputVariableValues(
ctx: CodegenContext,
expr: Expression,
subExprs: Map[Expression, SubExprEliminationState]): Set[VariableValue] = {
val argSet = mutable.Set[VariableValue]()
if (ctx.INPUT_ROW != null) {
argSet += JavaCode.variable(ctx.INPUT_ROW, classOf[InternalRow])
}
// Collects local variables from a given `expr` tree
val collectLocalVariable = (ev: ExprValue) => ev match {
case vv: VariableValue => argSet += vv
case _ =>
}
val stack = mutable.Stack[Expression](expr)
while (stack.nonEmpty) {
stack.pop() match {
case e if subExprs.contains(e) =>
val SubExprEliminationState(isNull, value) = subExprs(e)
collectLocalVariable(value)
collectLocalVariable(isNull)
case ref: BoundReference if ctx.currentVars != null &&
ctx.currentVars(ref.ordinal) != null =>
val ExprCode(_, isNull, value) = ctx.currentVars(ref.ordinal)
collectLocalVariable(value)
collectLocalVariable(isNull)
case e =>
stack.pushAll(e.children)
}
}
argSet.toSet
}
/**
* Returns the name used in accessor and setter for a Java primitive type.
*/
@@ -1719,6 +1761,15 @@ object CodeGenerator extends Logging {
1 + params.map(paramLengthForExpr).sum
}
def calculateParamLengthFromExprValues(params: Seq[ExprValue]): Int = {
def paramLengthForExpr(input: ExprValue): Int = input.javaType match {
case java.lang.Long.TYPE | java.lang.Double.TYPE => 2
case _ => 1
}
// Initial value is 1 for `this`.
1 + params.map(paramLengthForExpr).sum
}
/**
* In Java, a method descriptor is valid only if it represents method parameters with a total
* length less than a pre-defined constant.
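
As a worked example of the slot counting above: `long` and `double` each occupy two JVM operand slots, all other parameters one, plus one slot for `this`, and the JVM caps a method descriptor at 255 slots in total. A standalone sketch of the rule (illustrative only, not Spark's internal API):

```
def paramLength(javaTypes: Seq[Class[_]]): Int = {
  val slots = javaTypes.map {
    case java.lang.Long.TYPE | java.lang.Double.TYPE => 2  // wide types take 2 slots
    case _ => 1
  }
  1 + slots.sum  // +1 for `this`
}

// A method taking (long, double, int) needs 1 + 2 + 2 + 1 = 6 slots.
assert(paramLength(Seq(java.lang.Long.TYPE, java.lang.Double.TYPE,
  java.lang.Integer.TYPE)) == 6)
```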

@@ -143,7 +143,10 @@ trait Block extends TreeNode[Block] with JavaCode {
case _ => code.trim
}
def length: Int = toString.length
def length: Int = {
// Returns a code length without comments
CodeFormatter.stripExtraNewLinesAndComments(toString).length
}
def isEmpty: Boolean = toString.isEmpty
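
The motivation for stripping comments here: `Block.length` feeds the split threshold (`spark.sql.codegen.methodSplitThreshold`, compared against `aggCodeBlocks.map(_.length).sum` below), and the generated aggregate code is heavily commented, so counting comment and blank lines would trigger splitting too eagerly. A small sketch of the effect, using the existing `CodeFormatter` helper:

```
import org.apache.spark.sql.catalyst.expressions.codegen.CodeFormatter

val code =
  """|// evaluate aggregate function
     |long agg_value_0 = -1L;
     |
     |""".stripMargin

// Comment-only and blank lines no longer count toward the measured length.
assert(CodeFormatter.stripExtraNewLinesAndComments(code).length < code.length)
```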

@@ -354,12 +354,14 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
val value = eval.isNull match {
case TrueLiteral => FalseLiteral
case FalseLiteral => TrueLiteral
case v => JavaCode.isNullExpression(s"!$v")
val (value, newCode) = eval.isNull match {
case TrueLiteral => (FalseLiteral, EmptyBlock)
case FalseLiteral => (TrueLiteral, EmptyBlock)
case v =>
val value = ctx.freshName("value")
(JavaCode.variable(value, BooleanType), code"boolean $value = !$v;")
}
ExprCode(code = eval.code, isNull = FalseLiteral, value = value)
ExprCode(code = eval.code + newCode, isNull = FalseLiteral, value = value)
}
override def sql: String = s"(${child.sql} IS NOT NULL)"
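
The reason for this change ties into the splitter added below: `splitAggregateExpressions` gives up whenever a `SimpleExprValue` appears among its inputs, because an inline expression fragment has no variable name to pass as a function argument. Emitting a named local boolean (a `VariableValue`) keeps `IS NOT NULL` results splittable. A hedged sketch of the two shapes:

```
import org.apache.spark.sql.catalyst.expressions.codegen.JavaCode
import org.apache.spark.sql.types.BooleanType

// before: the value was an inline fragment with no variable name of its own,
// so it could not become a split-function parameter.
val inlined = JavaCode.isNullExpression("!agg_isNull_3")  // SimpleExprValue

// after: a fresh local is declared in the generated code, so the value can
// be passed as an argument to a split function.
val named = JavaCode.variable("value_0", BooleanType)     // VariableValue
```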

@@ -1080,6 +1080,15 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
val CODEGEN_SPLIT_AGGREGATE_FUNC =
buildConf("spark.sql.codegen.aggregate.splitAggregateFunc.enabled")
.internal()
.doc("When true, the code generator would split aggregate code into individual methods " +
"instead of a single big method. This can be used to avoid oversized function that " +
"can miss the opportunity of JIT optimization.")
.booleanConf
.createWithDefault(true)
val MAX_NESTED_VIEW_DEPTH =
buildConf("spark.sql.view.maxNestedViewDepth")
.internal()
@@ -2353,6 +2362,8 @@ class SQLConf extends Serializable with Logging {
def cartesianProductExecBufferSpillThreshold: Int =
getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD)
def codegenSplitAggregateFunc: Boolean = getConf(SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC)
def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH)
def starSchemaDetection: Boolean = getConf(STARSCHEMA_DETECTION)

@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.aggregate
import java.util.concurrent.TimeUnit._
import scala.collection.mutable
import org.apache.spark.TaskContext
import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager}
import org.apache.spark.rdd.RDD
@@ -174,8 +176,9 @@ case class HashAggregateExec(
}
}
// The variables used as aggregation buffer. Only used for aggregation without keys.
private var bufVars: Seq[ExprCode] = _
// The variables are used as aggregation buffers and each aggregate function has one or more
// ExprCode to initialize its buffer slots. Only used for aggregation without keys.
private var bufVars: Seq[Seq[ExprCode]] = _
private def doProduceWithoutKeys(ctx: CodegenContext): String = {
val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
@@ -184,27 +187,30 @@ case class HashAggregateExec(
// generate variables for aggregation buffer
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
val initExpr = functions.flatMap(f => f.initialValues)
bufVars = initExpr.map { e =>
val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull")
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
// The initial expression should not access any column
val ev = e.genCode(ctx)
val initVars = code"""
| $isNull = ${ev.isNull};
| $value = ${ev.value};
""".stripMargin
ExprCode(
ev.code + initVars,
JavaCode.isNullGlobal(isNull),
JavaCode.global(value, e.dataType))
val initExpr = functions.map(f => f.initialValues)
bufVars = initExpr.map { exprs =>
exprs.map { e =>
val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull")
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
// The initial expression should not access any column
val ev = e.genCode(ctx)
val initVars = code"""
|$isNull = ${ev.isNull};
|$value = ${ev.value};
""".stripMargin
ExprCode(
ev.code + initVars,
JavaCode.isNullGlobal(isNull),
JavaCode.global(value, e.dataType))
}
}
val initBufVar = evaluateVariables(bufVars)
val flatBufVars = bufVars.flatten
val initBufVar = evaluateVariables(flatBufVars)
// generate variables for output
val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
// evaluate aggregate results
ctx.currentVars = bufVars
ctx.currentVars = flatBufVars
val aggResults = bindReferences(
functions.map(_.evaluateExpression),
aggregateBufferAttributes).map(_.genCode(ctx))
@@ -218,7 +224,7 @@ case class HashAggregateExec(
""".stripMargin)
} else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
// output the aggregate buffer directly
(bufVars, "")
(flatBufVars, "")
} else {
// no aggregate function, the result should be literals
val resultVars = resultExpressions.map(_.genCode(ctx))
@@ -255,11 +261,85 @@ case class HashAggregateExec(
""".stripMargin
}
private def isValidParamLength(paramLength: Int): Boolean = {
// This config is only for testing
sqlContext.getConf("spark.sql.HashAggregateExec.validParamLength", null) match {
case null | "" => CodeGenerator.isValidParamLength(paramLength)
case validLength => paramLength <= validLength.toInt
}
}
// Splits aggregate code into small functions, because most JVM implementations
// cannot compile functions that are too long. Returns None if we are not able to split the given code.
//
// Note: The difference from `CodeGenerator.splitExpressions` is that we define an individual
// function for each aggregation function (e.g., SUM and AVG). For example, in a query
// `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, we define two functions
// for `SUM(a)` and `AVG(a)`.
private def splitAggregateExpressions(
ctx: CodegenContext,
aggNames: Seq[String],
aggBufferUpdatingExprs: Seq[Seq[Expression]],
aggCodeBlocks: Seq[Block],
subExprs: Map[Expression, SubExprEliminationState]): Option[String] = {
val exprValsInSubExprs = subExprs.flatMap { case (_, s) => s.value :: s.isNull :: Nil }
if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) {
// `SimpleExprValue`s cannot be used as an input variable for split functions, so
// we give up splitting functions if it exists in `subExprs`.
None
} else {
val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc =>
val inputVarsForOneFunc = aggExprsForOneFunc.map(
CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)).reduce(_ ++ _).toSeq
val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc)
// Checks if a parameter length for the `aggExprsForOneFunc` does not go over the JVM limit
if (isValidParamLength(paramLength)) {
Some(inputVarsForOneFunc)
} else {
None
}
}
// Checks if all the aggregate code can be split into pieces.
// If the parameter length of at least one `aggExprsForOneFunc` goes over the limit,
// we totally give up splitting aggregate code.
if (inputVars.forall(_.isDefined)) {
val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) =>
val doAggFunc = ctx.freshName(s"doAggregate_${aggNames(i)}")
val argList = args.map(v => s"${v.javaType.getName} ${v.variableName}").mkString(", ")
val doAggFuncName = ctx.addNewFunction(doAggFunc,
s"""
|private void $doAggFunc($argList) throws java.io.IOException {
| ${aggCodeBlocks(i)}
|}
""".stripMargin)
val inputVariables = args.map(_.variableName).mkString(", ")
s"$doAggFuncName($inputVariables);"
}
Some(splitCodes.mkString("\n").trim)
} else {
val errMsg = "Failed to split aggregate code into small functions because the parameter " +
"length of at least one split function went over the JVM limit: " +
CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH
if (Utils.isTesting) {
throw new IllegalStateException(errMsg)
} else {
logInfo(errMsg)
None
}
}
}
}
private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
// only have DeclarativeAggregate
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
val updateExpr = aggregateExpressions.flatMap { e =>
// To individually generate code for each aggregate function, an element in `updateExprs` holds
// all the expressions for the buffer of an aggregation function.
val updateExprs = aggregateExpressions.map { e =>
e.mode match {
case Partial | Complete =>
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
@@ -267,28 +347,56 @@ case class HashAggregateExec(
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
}
}
ctx.currentVars = bufVars ++ input
val boundUpdateExpr = bindReferences(updateExpr, inputAttrs)
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExpr.map(_.genCode(ctx))
ctx.currentVars = bufVars.flatten ++ input
val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
bindReferences(updateExprsForOneFunc, inputAttrs)
}
// aggregate buffer should be updated atomic
val updates = aggVals.zipWithIndex.map { case (ev, i) =>
s"""
| ${bufVars(i).isNull} = ${ev.isNull};
| ${bufVars(i).value} = ${ev.value};
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
val effectiveCodes = subExprs.codes.mkString("\n")
val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExprsForOneFunc.map(_.genCode(ctx))
}
}
val aggNames = functions.map(_.prettyName)
val aggCodeBlocks = bufferEvals.zipWithIndex.map { case (bufferEvalsForOneFunc, i) =>
val bufVarsForOneFunc = bufVars(i)
// All the update code for aggregation buffers should be placed in the end
// of each aggregation function code.
val updates = bufferEvalsForOneFunc.zip(bufVarsForOneFunc).map { case (ev, bufVar) =>
s"""
|${bufVar.isNull} = ${ev.isNull};
|${bufVar.value} = ${ev.value};
""".stripMargin
}
code"""
|// do aggregate for ${aggNames(i)}
|// evaluate aggregate function
|${evaluateVariables(bufferEvalsForOneFunc)}
|// update aggregation buffers
|${updates.mkString("\n").trim}
""".stripMargin
}
val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
val maybeSplitCode = splitAggregateExpressions(
ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
maybeSplitCode.getOrElse {
aggCodeBlocks.fold(EmptyBlock)(_ + _).code
}
} else {
aggCodeBlocks.fold(EmptyBlock)(_ + _).code
}
s"""
| // do aggregate
| // common sub-expressions
| $effectiveCodes
| // evaluate aggregate function
| ${evaluateVariables(aggVals)}
| // update aggregation buffer
| ${updates.mkString("\n").trim}
|// do aggregate
|// common sub-expressions
|$effectiveCodes
|// evaluate aggregate functions and update aggregation buffers
|$codeToEvalAggFunc
""".stripMargin
}
@@ -745,8 +853,10 @@ case class HashAggregateExec(
val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
val fastRowBuffer = ctx.freshName("fastAggBuffer")
// only have DeclarativeAggregate
val updateExpr = aggregateExpressions.flatMap { e =>
// To individually generate code for each aggregate function, an element in `updateExprs` holds
// all the expressions for the buffer of an aggregation function.
val updateExprs = aggregateExpressions.map { e =>
// only have DeclarativeAggregate
e.mode match {
case Partial | Complete =>
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
@@ -824,25 +934,70 @@ case class HashAggregateExec(
// generating input columns, we use `currentVars`.
ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input
val aggNames = aggregateExpressions.map(_.aggregateFunction.prettyName)
// Computes start offsets for each aggregation function code
// in the underlying buffer row.
val bufferStartOffsets = {
val offsets = mutable.ArrayBuffer[Int]()
var curOffset = 0
updateExprs.foreach { exprsForOneFunc =>
offsets += curOffset
curOffset += exprsForOneFunc.length
}
offsets.toArray
}
val updateRowInRegularHashMap: String = {
ctx.INPUT_ROW = unsafeRowBuffer
val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
bindReferences(updateExprsForOneFunc, inputAttr)
}
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
val effectiveCodes = subExprs.codes.mkString("\n")
val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExpr.map(_.genCode(ctx))
val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExprsForOneFunc.map(_.genCode(ctx))
}
}
val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
val dt = updateExpr(i).dataType
CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
val aggCodeBlocks = updateExprs.indices.map { i =>
val rowBufferEvalsForOneFunc = unsafeRowBufferEvals(i)
val boundUpdateExprsForOneFunc = boundUpdateExprs(i)
val bufferOffset = bufferStartOffsets(i)
// All the update code for aggregation buffers should be placed in the end
// of each aggregation function code.
val updateRowBuffers = rowBufferEvalsForOneFunc.zipWithIndex.map { case (ev, j) =>
val updateExpr = boundUpdateExprsForOneFunc(j)
val dt = updateExpr.dataType
val nullable = updateExpr.nullable
CodeGenerator.updateColumn(unsafeRowBuffer, dt, bufferOffset + j, ev, nullable)
}
code"""
|// evaluate aggregate function for ${aggNames(i)}
|${evaluateVariables(rowBufferEvalsForOneFunc)}
|// update unsafe row buffer
|${updateRowBuffers.mkString("\n").trim}
""".stripMargin
}
val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
val maybeSplitCode = splitAggregateExpressions(
ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
maybeSplitCode.getOrElse {
aggCodeBlocks.fold(EmptyBlock)(_ + _).code
}
} else {
aggCodeBlocks.fold(EmptyBlock)(_ + _).code
}
s"""
|// common sub-expressions
|$effectiveCodes
|// evaluate aggregate function
|${evaluateVariables(unsafeRowBufferEvals)}
|// update unsafe row buffer
|${updateUnsafeRowBuffer.mkString("\n").trim}
|// evaluate aggregate functions and update aggregation buffers
|$codeToEvalAggFunc
""".stripMargin
}
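
To make the offset bookkeeping in the hunk above concrete: each aggregate function writes its slots at a fixed start offset into the shared buffer row. A hypothetical example for `SELECT SUM(a), AVG(a)`: `SUM` uses one buffer slot and `AVG` two (sum and count), so the computed start offsets are `0` and `1`:

```
// Illustrative recomputation of `bufferStartOffsets` for SUM(a), AVG(a):
val slotCounts = Seq(1, 2)                        // SUM: 1 slot, AVG: 2 slots
val offsets = slotCounts.scanLeft(0)(_ + _).init  // running sums -> Seq(0, 1)
assert(offsets == Seq(0, 1))
```
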
@@ -850,16 +1005,48 @@ case class HashAggregateExec(
if (isFastHashMapEnabled) {
if (isVectorizedHashMapEnabled) {
ctx.INPUT_ROW = fastRowBuffer
val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExpr.map(_.genCode(ctx))
val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
bindReferences(updateExprsForOneFunc, inputAttr)
}
val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) =>
val dt = updateExpr(i).dataType
CodeGenerator.updateColumn(
fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorized = true)
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
val effectiveCodes = subExprs.codes.mkString("\n")
val fastRowEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExprsForOneFunc.map(_.genCode(ctx))
}
}
val aggCodeBlocks = fastRowEvals.zipWithIndex.map { case (fastRowEvalsForOneFunc, i) =>
val boundUpdateExprsForOneFunc = boundUpdateExprs(i)
val bufferOffset = bufferStartOffsets(i)
// All the update code for aggregation buffers should be placed in the end
// of each aggregation function code.
val updateRowBuffer = fastRowEvalsForOneFunc.zipWithIndex.map { case (ev, j) =>
val updateExpr = boundUpdateExprsForOneFunc(j)
val dt = updateExpr.dataType
val nullable = updateExpr.nullable
CodeGenerator.updateColumn(fastRowBuffer, dt, bufferOffset + j, ev, nullable,
isVectorized = true)
}
code"""
|// evaluate aggregate function for ${aggNames(i)}
|${evaluateVariables(fastRowEvalsForOneFunc)}
|// update fast row
|${updateRowBuffer.mkString("\n").trim}
""".stripMargin
}
val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
val maybeSplitCode = splitAggregateExpressions(
ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
maybeSplitCode.getOrElse {
aggCodeBlocks.fold(EmptyBlock)(_ + _).code
}
} else {
aggCodeBlocks.fold(EmptyBlock)(_ + _).code
}
// If vectorized fast hash map is on, we first generate code to update row
@@ -869,10 +1056,8 @@ case class HashAggregateExec(
|if ($fastRowBuffer != null) {
| // common sub-expressions
| $effectiveCodes
| // evaluate aggregate function
| ${evaluateVariables(fastRowEvals)}
| // update fast row
| ${updateFastRow.mkString("\n").trim}
| // evaluate aggregate functions and update aggregation buffers
| $codeToEvalAggFunc
|} else {
| $updateRowInRegularHashMap
|}

@@ -398,4 +398,25 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession {
}.isDefined,
"LocalTableScanExec should be within a WholeStageCodegen domain.")
}
test("Give up splitting aggregate code if a parameter length goes over the limit") {
withSQLConf(
SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "true",
SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1",
"spark.sql.HashAggregateExec.validParamLength" -> "0") {
withTable("t") {
val expectedErrMsg = "Failed to split aggregate code into small functions"
Seq(
// Test case without keys
"SELECT AVG(v) FROM VALUES(1) t(v)",
// Test case with keys
"SELECT k, AVG(v) FROM VALUES((1, 1)) t(k, v) GROUP BY k").foreach { query =>
val errMsg = intercept[IllegalStateException] {
sql(query).collect
}.getMessage
assert(errMsg.contains(expectedErrMsg))
}
}
}
}
}