[SPARK-21870][SQL] Split aggregation code into small functions
## What changes were proposed in this pull request? This pr proposed to split aggregation code into small functions in `HashAggregateExec`. In #18810, we got performance regression if JVMs didn't compile too long functions. I checked and I found the codegen of `HashAggregateExec` frequently goes over the limit when a query has too many aggregate functions (e.g., q66 in TPCDS). The current master places all the generated aggregation code in a single function. In this pr, I modified the code to assign an individual function for each aggregate function (e.g., `SUM` and `AVG`). For example, in a query `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, the proposed code defines two functions for `SUM(a)` and `AVG(a)` as follows; - generated code with this pr (https://gist.github.com/maropu/812990012bc967a78364be0fa793f559): ``` /* 173 */ private void agg_doConsume_0(InternalRow inputadapter_row_0, long agg_expr_0_0, boolean agg_exprIsNull_0_0, double agg_expr_1_0, boolean agg_exprIsNull_1_0, long agg_expr_2_0, boolean agg_exprIsNull_2_0) throws java.io.IOException { /* 174 */ // do aggregate /* 175 */ // common sub-expressions /* 176 */ /* 177 */ // evaluate aggregate functions and update aggregation buffers /* 178 */ agg_doAggregate_sum_0(agg_exprIsNull_0_0, agg_expr_0_0); /* 179 */ agg_doAggregate_avg_0(agg_expr_1_0, agg_exprIsNull_1_0, agg_exprIsNull_2_0, agg_expr_2_0); /* 180 */ /* 181 */ } ... /* 071 */ private void agg_doAggregate_avg_0(double agg_expr_1_0, boolean agg_exprIsNull_1_0, boolean agg_exprIsNull_2_0, long agg_expr_2_0) throws java.io.IOException { /* 072 */ // do aggregate for avg /* 073 */ // evaluate aggregate function /* 074 */ boolean agg_isNull_19 = true; /* 075 */ double agg_value_19 = -1.0; ... /* 114 */ private void agg_doAggregate_sum_0(boolean agg_exprIsNull_0_0, long agg_expr_0_0) throws java.io.IOException { /* 115 */ // do aggregate for sum /* 116 */ // evaluate aggregate function /* 117 */ agg_agg_isNull_11_0 = true; /* 118 */ long agg_value_11 = -1L; ``` - generated code in the current master (https://gist.github.com/maropu/e9d772af2c98d8991a6a5f0af7841760) ``` /* 059 */ private void agg_doConsume_0(InternalRow localtablescan_row_0, int agg_expr_0_0) throws java.io.IOException { /* 060 */ // do aggregate /* 061 */ // common sub-expressions /* 062 */ boolean agg_isNull_4 = false; /* 063 */ long agg_value_4 = -1L; /* 064 */ if (!false) { /* 065 */ agg_value_4 = (long) agg_expr_0_0; /* 066 */ } /* 067 */ // evaluate aggregate function /* 068 */ agg_agg_isNull_7_0 = true; /* 069 */ long agg_value_7 = -1L; /* 070 */ do { /* 071 */ if (!agg_bufIsNull_0) { /* 072 */ agg_agg_isNull_7_0 = false; /* 073 */ agg_value_7 = agg_bufValue_0; /* 074 */ continue; /* 075 */ } /* 076 */ /* 077 */ boolean agg_isNull_9 = false; /* 078 */ long agg_value_9 = -1L; /* 079 */ if (!false) { /* 080 */ agg_value_9 = (long) 0; /* 081 */ } /* 082 */ if (!agg_isNull_9) { /* 083 */ agg_agg_isNull_7_0 = false; /* 084 */ agg_value_7 = agg_value_9; /* 085 */ continue; /* 086 */ } /* 087 */ /* 088 */ } while (false); /* 089 */ /* 090 */ long agg_value_6 = -1L; /* 091 */ /* 092 */ agg_value_6 = agg_value_7 + agg_value_4; /* 093 */ boolean agg_isNull_11 = true; /* 094 */ double agg_value_11 = -1.0; /* 095 */ /* 096 */ if (!agg_bufIsNull_1) { /* 097 */ agg_agg_isNull_13_0 = true; /* 098 */ double agg_value_13 = -1.0; /* 099 */ do { /* 100 */ boolean agg_isNull_14 = agg_isNull_4; /* 101 */ double agg_value_14 = -1.0; /* 102 */ if (!agg_isNull_4) { /* 103 */ agg_value_14 = (double) agg_value_4; /* 104 */ } /* 105 */ if (!agg_isNull_14) { /* 106 */ agg_agg_isNull_13_0 = false; /* 107 */ agg_value_13 = agg_value_14; /* 108 */ continue; /* 109 */ } /* 110 */ /* 111 */ boolean agg_isNull_15 = false; /* 112 */ double agg_value_15 = -1.0; /* 113 */ if (!false) { /* 114 */ agg_value_15 = (double) 0; /* 115 */ } /* 116 */ if (!agg_isNull_15) { /* 117 */ agg_agg_isNull_13_0 = false; /* 118 */ agg_value_13 = agg_value_15; /* 119 */ continue; /* 120 */ } /* 121 */ /* 122 */ } while (false); /* 123 */ /* 124 */ agg_isNull_11 = false; // resultCode could change nullability. /* 125 */ /* 126 */ agg_value_11 = agg_bufValue_1 + agg_value_13; /* 127 */ /* 128 */ } /* 129 */ boolean agg_isNull_17 = false; /* 130 */ long agg_value_17 = -1L; /* 131 */ if (!false && agg_isNull_4) { /* 132 */ agg_isNull_17 = agg_bufIsNull_2; /* 133 */ agg_value_17 = agg_bufValue_2; /* 134 */ } else { /* 135 */ boolean agg_isNull_20 = true; /* 136 */ long agg_value_20 = -1L; /* 137 */ /* 138 */ if (!agg_bufIsNull_2) { /* 139 */ agg_isNull_20 = false; // resultCode could change nullability. /* 140 */ /* 141 */ agg_value_20 = agg_bufValue_2 + 1L; /* 142 */ /* 143 */ } /* 144 */ agg_isNull_17 = agg_isNull_20; /* 145 */ agg_value_17 = agg_value_20; /* 146 */ } /* 147 */ // update aggregation buffer /* 148 */ agg_bufIsNull_0 = false; /* 149 */ agg_bufValue_0 = agg_value_6; /* 150 */ /* 151 */ agg_bufIsNull_1 = agg_isNull_11; /* 152 */ agg_bufValue_1 = agg_value_11; /* 153 */ /* 154 */ agg_bufIsNull_2 = agg_isNull_17; /* 155 */ agg_bufValue_2 = agg_value_17; /* 156 */ /* 157 */ } ``` You can check the previous discussion in https://github.com/apache/spark/pull/19082 ## How was this patch tested? Existing tests Closes #20965 from maropu/SPARK-21870-2. Authored-by: Takeshi Yamamuro <yamamuro@apache.org> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
36f8e53cfa
commit
cb0cddffe9
|
@ -115,7 +115,13 @@ package object dsl {
|
|||
def getField(fieldName: String): UnresolvedExtractValue =
|
||||
UnresolvedExtractValue(expr, Literal(fieldName))
|
||||
|
||||
def cast(to: DataType): Expression = Cast(expr, to)
|
||||
def cast(to: DataType): Expression = {
|
||||
if (expr.resolved && expr.dataType.sameType(to)) {
|
||||
expr
|
||||
} else {
|
||||
Cast(expr, to)
|
||||
}
|
||||
}
|
||||
|
||||
def asc: SortOrder = SortOrder(expr, Ascending)
|
||||
def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Set.empty)
|
||||
|
|
|
@ -1612,6 +1612,48 @@ object CodeGenerator extends Logging {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts all the input variables from references and subexpression elimination states
|
||||
* for a given `expr`. This result will be used to split the generated code of
|
||||
* expressions into multiple functions.
|
||||
*/
|
||||
def getLocalInputVariableValues(
|
||||
ctx: CodegenContext,
|
||||
expr: Expression,
|
||||
subExprs: Map[Expression, SubExprEliminationState]): Set[VariableValue] = {
|
||||
val argSet = mutable.Set[VariableValue]()
|
||||
if (ctx.INPUT_ROW != null) {
|
||||
argSet += JavaCode.variable(ctx.INPUT_ROW, classOf[InternalRow])
|
||||
}
|
||||
|
||||
// Collects local variables from a given `expr` tree
|
||||
val collectLocalVariable = (ev: ExprValue) => ev match {
|
||||
case vv: VariableValue => argSet += vv
|
||||
case _ =>
|
||||
}
|
||||
|
||||
val stack = mutable.Stack[Expression](expr)
|
||||
while (stack.nonEmpty) {
|
||||
stack.pop() match {
|
||||
case e if subExprs.contains(e) =>
|
||||
val SubExprEliminationState(isNull, value) = subExprs(e)
|
||||
collectLocalVariable(value)
|
||||
collectLocalVariable(isNull)
|
||||
|
||||
case ref: BoundReference if ctx.currentVars != null &&
|
||||
ctx.currentVars(ref.ordinal) != null =>
|
||||
val ExprCode(_, isNull, value) = ctx.currentVars(ref.ordinal)
|
||||
collectLocalVariable(value)
|
||||
collectLocalVariable(isNull)
|
||||
|
||||
case e =>
|
||||
stack.pushAll(e.children)
|
||||
}
|
||||
}
|
||||
|
||||
argSet.toSet
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the name used in accessor and setter for a Java primitive type.
|
||||
*/
|
||||
|
@ -1719,6 +1761,15 @@ object CodeGenerator extends Logging {
|
|||
1 + params.map(paramLengthForExpr).sum
|
||||
}
|
||||
|
||||
def calculateParamLengthFromExprValues(params: Seq[ExprValue]): Int = {
|
||||
def paramLengthForExpr(input: ExprValue): Int = input.javaType match {
|
||||
case java.lang.Long.TYPE | java.lang.Double.TYPE => 2
|
||||
case _ => 1
|
||||
}
|
||||
// Initial value is 1 for `this`.
|
||||
1 + params.map(paramLengthForExpr).sum
|
||||
}
|
||||
|
||||
/**
|
||||
* In Java, a method descriptor is valid only if it represents method parameters with a total
|
||||
* length less than a pre-defined constant.
|
||||
|
|
|
@ -143,7 +143,10 @@ trait Block extends TreeNode[Block] with JavaCode {
|
|||
case _ => code.trim
|
||||
}
|
||||
|
||||
def length: Int = toString.length
|
||||
def length: Int = {
|
||||
// Returns a code length without comments
|
||||
CodeFormatter.stripExtraNewLinesAndComments(toString).length
|
||||
}
|
||||
|
||||
def isEmpty: Boolean = toString.isEmpty
|
||||
|
||||
|
|
|
@ -354,12 +354,14 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
|
|||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val eval = child.genCode(ctx)
|
||||
val value = eval.isNull match {
|
||||
case TrueLiteral => FalseLiteral
|
||||
case FalseLiteral => TrueLiteral
|
||||
case v => JavaCode.isNullExpression(s"!$v")
|
||||
val (value, newCode) = eval.isNull match {
|
||||
case TrueLiteral => (FalseLiteral, EmptyBlock)
|
||||
case FalseLiteral => (TrueLiteral, EmptyBlock)
|
||||
case v =>
|
||||
val value = ctx.freshName("value")
|
||||
(JavaCode.variable(value, BooleanType), code"boolean $value = !$v;")
|
||||
}
|
||||
ExprCode(code = eval.code, isNull = FalseLiteral, value = value)
|
||||
ExprCode(code = eval.code + newCode, isNull = FalseLiteral, value = value)
|
||||
}
|
||||
|
||||
override def sql: String = s"(${child.sql} IS NOT NULL)"
|
||||
|
|
|
@ -1080,6 +1080,15 @@ object SQLConf {
|
|||
.booleanConf
|
||||
.createWithDefault(false)
|
||||
|
||||
val CODEGEN_SPLIT_AGGREGATE_FUNC =
|
||||
buildConf("spark.sql.codegen.aggregate.splitAggregateFunc.enabled")
|
||||
.internal()
|
||||
.doc("When true, the code generator would split aggregate code into individual methods " +
|
||||
"instead of a single big method. This can be used to avoid oversized function that " +
|
||||
"can miss the opportunity of JIT optimization.")
|
||||
.booleanConf
|
||||
.createWithDefault(true)
|
||||
|
||||
val MAX_NESTED_VIEW_DEPTH =
|
||||
buildConf("spark.sql.view.maxNestedViewDepth")
|
||||
.internal()
|
||||
|
@ -2353,6 +2362,8 @@ class SQLConf extends Serializable with Logging {
|
|||
def cartesianProductExecBufferSpillThreshold: Int =
|
||||
getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD)
|
||||
|
||||
def codegenSplitAggregateFunc: Boolean = getConf(SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC)
|
||||
|
||||
def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH)
|
||||
|
||||
def starSchemaDetection: Boolean = getConf(STARSCHEMA_DETECTION)
|
||||
|
|
|
@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.aggregate
|
|||
|
||||
import java.util.concurrent.TimeUnit._
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import org.apache.spark.TaskContext
|
||||
import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager}
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
@ -174,8 +176,9 @@ case class HashAggregateExec(
|
|||
}
|
||||
}
|
||||
|
||||
// The variables used as aggregation buffer. Only used for aggregation without keys.
|
||||
private var bufVars: Seq[ExprCode] = _
|
||||
// The variables are used as aggregation buffers and each aggregate function has one or more
|
||||
// ExprCode to initialize its buffer slots. Only used for aggregation without keys.
|
||||
private var bufVars: Seq[Seq[ExprCode]] = _
|
||||
|
||||
private def doProduceWithoutKeys(ctx: CodegenContext): String = {
|
||||
val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
|
||||
|
@ -184,27 +187,30 @@ case class HashAggregateExec(
|
|||
|
||||
// generate variables for aggregation buffer
|
||||
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
|
||||
val initExpr = functions.flatMap(f => f.initialValues)
|
||||
bufVars = initExpr.map { e =>
|
||||
val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull")
|
||||
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
|
||||
// The initial expression should not access any column
|
||||
val ev = e.genCode(ctx)
|
||||
val initVars = code"""
|
||||
| $isNull = ${ev.isNull};
|
||||
| $value = ${ev.value};
|
||||
""".stripMargin
|
||||
ExprCode(
|
||||
ev.code + initVars,
|
||||
JavaCode.isNullGlobal(isNull),
|
||||
JavaCode.global(value, e.dataType))
|
||||
val initExpr = functions.map(f => f.initialValues)
|
||||
bufVars = initExpr.map { exprs =>
|
||||
exprs.map { e =>
|
||||
val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull")
|
||||
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
|
||||
// The initial expression should not access any column
|
||||
val ev = e.genCode(ctx)
|
||||
val initVars = code"""
|
||||
|$isNull = ${ev.isNull};
|
||||
|$value = ${ev.value};
|
||||
""".stripMargin
|
||||
ExprCode(
|
||||
ev.code + initVars,
|
||||
JavaCode.isNullGlobal(isNull),
|
||||
JavaCode.global(value, e.dataType))
|
||||
}
|
||||
}
|
||||
val initBufVar = evaluateVariables(bufVars)
|
||||
val flatBufVars = bufVars.flatten
|
||||
val initBufVar = evaluateVariables(flatBufVars)
|
||||
|
||||
// generate variables for output
|
||||
val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
|
||||
// evaluate aggregate results
|
||||
ctx.currentVars = bufVars
|
||||
ctx.currentVars = flatBufVars
|
||||
val aggResults = bindReferences(
|
||||
functions.map(_.evaluateExpression),
|
||||
aggregateBufferAttributes).map(_.genCode(ctx))
|
||||
|
@ -218,7 +224,7 @@ case class HashAggregateExec(
|
|||
""".stripMargin)
|
||||
} else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
|
||||
// output the aggregate buffer directly
|
||||
(bufVars, "")
|
||||
(flatBufVars, "")
|
||||
} else {
|
||||
// no aggregate function, the result should be literals
|
||||
val resultVars = resultExpressions.map(_.genCode(ctx))
|
||||
|
@ -255,11 +261,85 @@ case class HashAggregateExec(
|
|||
""".stripMargin
|
||||
}
|
||||
|
||||
private def isValidParamLength(paramLength: Int): Boolean = {
|
||||
// This config is only for testing
|
||||
sqlContext.getConf("spark.sql.HashAggregateExec.validParamLength", null) match {
|
||||
case null | "" => CodeGenerator.isValidParamLength(paramLength)
|
||||
case validLength => paramLength <= validLength.toInt
|
||||
}
|
||||
}
|
||||
|
||||
// Splits aggregate code into small functions because the most of JVM implementations
|
||||
// can not compile too long functions. Returns None if we are not able to split the given code.
|
||||
//
|
||||
// Note: The difference from `CodeGenerator.splitExpressions` is that we define an individual
|
||||
// function for each aggregation function (e.g., SUM and AVG). For example, in a query
|
||||
// `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, we define two functions
|
||||
// for `SUM(a)` and `AVG(a)`.
|
||||
private def splitAggregateExpressions(
|
||||
ctx: CodegenContext,
|
||||
aggNames: Seq[String],
|
||||
aggBufferUpdatingExprs: Seq[Seq[Expression]],
|
||||
aggCodeBlocks: Seq[Block],
|
||||
subExprs: Map[Expression, SubExprEliminationState]): Option[String] = {
|
||||
val exprValsInSubExprs = subExprs.flatMap { case (_, s) => s.value :: s.isNull :: Nil }
|
||||
if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) {
|
||||
// `SimpleExprValue`s cannot be used as an input variable for split functions, so
|
||||
// we give up splitting functions if it exists in `subExprs`.
|
||||
None
|
||||
} else {
|
||||
val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc =>
|
||||
val inputVarsForOneFunc = aggExprsForOneFunc.map(
|
||||
CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)).reduce(_ ++ _).toSeq
|
||||
val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc)
|
||||
|
||||
// Checks if a parameter length for the `aggExprsForOneFunc` does not go over the JVM limit
|
||||
if (isValidParamLength(paramLength)) {
|
||||
Some(inputVarsForOneFunc)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
// Checks if all the aggregate code can be split into pieces.
|
||||
// If the parameter length of at lease one `aggExprsForOneFunc` goes over the limit,
|
||||
// we totally give up splitting aggregate code.
|
||||
if (inputVars.forall(_.isDefined)) {
|
||||
val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) =>
|
||||
val doAggFunc = ctx.freshName(s"doAggregate_${aggNames(i)}")
|
||||
val argList = args.map(v => s"${v.javaType.getName} ${v.variableName}").mkString(", ")
|
||||
val doAggFuncName = ctx.addNewFunction(doAggFunc,
|
||||
s"""
|
||||
|private void $doAggFunc($argList) throws java.io.IOException {
|
||||
| ${aggCodeBlocks(i)}
|
||||
|}
|
||||
""".stripMargin)
|
||||
|
||||
val inputVariables = args.map(_.variableName).mkString(", ")
|
||||
s"$doAggFuncName($inputVariables);"
|
||||
}
|
||||
Some(splitCodes.mkString("\n").trim)
|
||||
} else {
|
||||
val errMsg = "Failed to split aggregate code into small functions because the parameter " +
|
||||
"length of at least one split function went over the JVM limit: " +
|
||||
CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH
|
||||
if (Utils.isTesting) {
|
||||
throw new IllegalStateException(errMsg)
|
||||
} else {
|
||||
logInfo(errMsg)
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
|
||||
// only have DeclarativeAggregate
|
||||
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
|
||||
val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
|
||||
val updateExpr = aggregateExpressions.flatMap { e =>
|
||||
// To individually generate code for each aggregate function, an element in `updateExprs` holds
|
||||
// all the expressions for the buffer of an aggregation function.
|
||||
val updateExprs = aggregateExpressions.map { e =>
|
||||
e.mode match {
|
||||
case Partial | Complete =>
|
||||
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
|
||||
|
@ -267,28 +347,56 @@ case class HashAggregateExec(
|
|||
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
|
||||
}
|
||||
}
|
||||
ctx.currentVars = bufVars ++ input
|
||||
val boundUpdateExpr = bindReferences(updateExpr, inputAttrs)
|
||||
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
|
||||
val effectiveCodes = subExprs.codes.mkString("\n")
|
||||
val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
|
||||
boundUpdateExpr.map(_.genCode(ctx))
|
||||
ctx.currentVars = bufVars.flatten ++ input
|
||||
val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
|
||||
bindReferences(updateExprsForOneFunc, inputAttrs)
|
||||
}
|
||||
// aggregate buffer should be updated atomic
|
||||
val updates = aggVals.zipWithIndex.map { case (ev, i) =>
|
||||
s"""
|
||||
| ${bufVars(i).isNull} = ${ev.isNull};
|
||||
| ${bufVars(i).value} = ${ev.value};
|
||||
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
|
||||
val effectiveCodes = subExprs.codes.mkString("\n")
|
||||
val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
|
||||
ctx.withSubExprEliminationExprs(subExprs.states) {
|
||||
boundUpdateExprsForOneFunc.map(_.genCode(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
val aggNames = functions.map(_.prettyName)
|
||||
val aggCodeBlocks = bufferEvals.zipWithIndex.map { case (bufferEvalsForOneFunc, i) =>
|
||||
val bufVarsForOneFunc = bufVars(i)
|
||||
// All the update code for aggregation buffers should be placed in the end
|
||||
// of each aggregation function code.
|
||||
val updates = bufferEvalsForOneFunc.zip(bufVarsForOneFunc).map { case (ev, bufVar) =>
|
||||
s"""
|
||||
|${bufVar.isNull} = ${ev.isNull};
|
||||
|${bufVar.value} = ${ev.value};
|
||||
""".stripMargin
|
||||
}
|
||||
code"""
|
||||
|// do aggregate for ${aggNames(i)}
|
||||
|// evaluate aggregate function
|
||||
|${evaluateVariables(bufferEvalsForOneFunc)}
|
||||
|// update aggregation buffers
|
||||
|${updates.mkString("\n").trim}
|
||||
""".stripMargin
|
||||
}
|
||||
|
||||
val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
|
||||
aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
|
||||
val maybeSplitCode = splitAggregateExpressions(
|
||||
ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
|
||||
|
||||
maybeSplitCode.getOrElse {
|
||||
aggCodeBlocks.fold(EmptyBlock)(_ + _).code
|
||||
}
|
||||
} else {
|
||||
aggCodeBlocks.fold(EmptyBlock)(_ + _).code
|
||||
}
|
||||
|
||||
s"""
|
||||
| // do aggregate
|
||||
| // common sub-expressions
|
||||
| $effectiveCodes
|
||||
| // evaluate aggregate function
|
||||
| ${evaluateVariables(aggVals)}
|
||||
| // update aggregation buffer
|
||||
| ${updates.mkString("\n").trim}
|
||||
|// do aggregate
|
||||
|// common sub-expressions
|
||||
|$effectiveCodes
|
||||
|// evaluate aggregate functions and update aggregation buffers
|
||||
|$codeToEvalAggFunc
|
||||
""".stripMargin
|
||||
}
|
||||
|
||||
|
@ -745,8 +853,10 @@ case class HashAggregateExec(
|
|||
val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
|
||||
val fastRowBuffer = ctx.freshName("fastAggBuffer")
|
||||
|
||||
// only have DeclarativeAggregate
|
||||
val updateExpr = aggregateExpressions.flatMap { e =>
|
||||
// To individually generate code for each aggregate function, an element in `updateExprs` holds
|
||||
// all the expressions for the buffer of an aggregation function.
|
||||
val updateExprs = aggregateExpressions.map { e =>
|
||||
// only have DeclarativeAggregate
|
||||
e.mode match {
|
||||
case Partial | Complete =>
|
||||
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
|
||||
|
@ -824,25 +934,70 @@ case class HashAggregateExec(
|
|||
// generating input columns, we use `currentVars`.
|
||||
ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input
|
||||
|
||||
val aggNames = aggregateExpressions.map(_.aggregateFunction.prettyName)
|
||||
// Computes start offsets for each aggregation function code
|
||||
// in the underlying buffer row.
|
||||
val bufferStartOffsets = {
|
||||
val offsets = mutable.ArrayBuffer[Int]()
|
||||
var curOffset = 0
|
||||
updateExprs.foreach { exprsForOneFunc =>
|
||||
offsets += curOffset
|
||||
curOffset += exprsForOneFunc.length
|
||||
}
|
||||
offsets.toArray
|
||||
}
|
||||
|
||||
val updateRowInRegularHashMap: String = {
|
||||
ctx.INPUT_ROW = unsafeRowBuffer
|
||||
val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
|
||||
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
|
||||
val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
|
||||
bindReferences(updateExprsForOneFunc, inputAttr)
|
||||
}
|
||||
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
|
||||
val effectiveCodes = subExprs.codes.mkString("\n")
|
||||
val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
|
||||
boundUpdateExpr.map(_.genCode(ctx))
|
||||
val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
|
||||
ctx.withSubExprEliminationExprs(subExprs.states) {
|
||||
boundUpdateExprsForOneFunc.map(_.genCode(ctx))
|
||||
}
|
||||
}
|
||||
val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
|
||||
val dt = updateExpr(i).dataType
|
||||
CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
|
||||
|
||||
val aggCodeBlocks = updateExprs.indices.map { i =>
|
||||
val rowBufferEvalsForOneFunc = unsafeRowBufferEvals(i)
|
||||
val boundUpdateExprsForOneFunc = boundUpdateExprs(i)
|
||||
val bufferOffset = bufferStartOffsets(i)
|
||||
|
||||
// All the update code for aggregation buffers should be placed in the end
|
||||
// of each aggregation function code.
|
||||
val updateRowBuffers = rowBufferEvalsForOneFunc.zipWithIndex.map { case (ev, j) =>
|
||||
val updateExpr = boundUpdateExprsForOneFunc(j)
|
||||
val dt = updateExpr.dataType
|
||||
val nullable = updateExpr.nullable
|
||||
CodeGenerator.updateColumn(unsafeRowBuffer, dt, bufferOffset + j, ev, nullable)
|
||||
}
|
||||
code"""
|
||||
|// evaluate aggregate function for ${aggNames(i)}
|
||||
|${evaluateVariables(rowBufferEvalsForOneFunc)}
|
||||
|// update unsafe row buffer
|
||||
|${updateRowBuffers.mkString("\n").trim}
|
||||
""".stripMargin
|
||||
}
|
||||
|
||||
val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
|
||||
aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
|
||||
val maybeSplitCode = splitAggregateExpressions(
|
||||
ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
|
||||
|
||||
maybeSplitCode.getOrElse {
|
||||
aggCodeBlocks.fold(EmptyBlock)(_ + _).code
|
||||
}
|
||||
} else {
|
||||
aggCodeBlocks.fold(EmptyBlock)(_ + _).code
|
||||
}
|
||||
|
||||
s"""
|
||||
|// common sub-expressions
|
||||
|$effectiveCodes
|
||||
|// evaluate aggregate function
|
||||
|${evaluateVariables(unsafeRowBufferEvals)}
|
||||
|// update unsafe row buffer
|
||||
|${updateUnsafeRowBuffer.mkString("\n").trim}
|
||||
|// evaluate aggregate functions and update aggregation buffers
|
||||
|$codeToEvalAggFunc
|
||||
""".stripMargin
|
||||
}
|
||||
|
||||
|
@ -850,16 +1005,48 @@ case class HashAggregateExec(
|
|||
if (isFastHashMapEnabled) {
|
||||
if (isVectorizedHashMapEnabled) {
|
||||
ctx.INPUT_ROW = fastRowBuffer
|
||||
val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
|
||||
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
|
||||
val effectiveCodes = subExprs.codes.mkString("\n")
|
||||
val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
|
||||
boundUpdateExpr.map(_.genCode(ctx))
|
||||
val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
|
||||
bindReferences(updateExprsForOneFunc, inputAttr)
|
||||
}
|
||||
val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) =>
|
||||
val dt = updateExpr(i).dataType
|
||||
CodeGenerator.updateColumn(
|
||||
fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorized = true)
|
||||
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
|
||||
val effectiveCodes = subExprs.codes.mkString("\n")
|
||||
val fastRowEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
|
||||
ctx.withSubExprEliminationExprs(subExprs.states) {
|
||||
boundUpdateExprsForOneFunc.map(_.genCode(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
val aggCodeBlocks = fastRowEvals.zipWithIndex.map { case (fastRowEvalsForOneFunc, i) =>
|
||||
val boundUpdateExprsForOneFunc = boundUpdateExprs(i)
|
||||
val bufferOffset = bufferStartOffsets(i)
|
||||
// All the update code for aggregation buffers should be placed in the end
|
||||
// of each aggregation function code.
|
||||
val updateRowBuffer = fastRowEvalsForOneFunc.zipWithIndex.map { case (ev, j) =>
|
||||
val updateExpr = boundUpdateExprsForOneFunc(j)
|
||||
val dt = updateExpr.dataType
|
||||
val nullable = updateExpr.nullable
|
||||
CodeGenerator.updateColumn(fastRowBuffer, dt, bufferOffset + j, ev, nullable,
|
||||
isVectorized = true)
|
||||
}
|
||||
code"""
|
||||
|// evaluate aggregate function for ${aggNames(i)}
|
||||
|${evaluateVariables(fastRowEvalsForOneFunc)}
|
||||
|// update fast row
|
||||
|${updateRowBuffer.mkString("\n").trim}
|
||||
""".stripMargin
|
||||
}
|
||||
|
||||
|
||||
val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
|
||||
aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
|
||||
val maybeSplitCode = splitAggregateExpressions(
|
||||
ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
|
||||
|
||||
maybeSplitCode.getOrElse {
|
||||
aggCodeBlocks.fold(EmptyBlock)(_ + _).code
|
||||
}
|
||||
} else {
|
||||
aggCodeBlocks.fold(EmptyBlock)(_ + _).code
|
||||
}
|
||||
|
||||
// If vectorized fast hash map is on, we first generate code to update row
|
||||
|
@ -869,10 +1056,8 @@ case class HashAggregateExec(
|
|||
|if ($fastRowBuffer != null) {
|
||||
| // common sub-expressions
|
||||
| $effectiveCodes
|
||||
| // evaluate aggregate function
|
||||
| ${evaluateVariables(fastRowEvals)}
|
||||
| // update fast row
|
||||
| ${updateFastRow.mkString("\n").trim}
|
||||
| // evaluate aggregate functions and update aggregation buffers
|
||||
| $codeToEvalAggFunc
|
||||
|} else {
|
||||
| $updateRowInRegularHashMap
|
||||
|}
|
||||
|
|
|
@ -398,4 +398,25 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession {
|
|||
}.isDefined,
|
||||
"LocalTableScanExec should be within a WholeStageCodegen domain.")
|
||||
}
|
||||
|
||||
test("Give up splitting aggregate code if a parameter length goes over the limit") {
|
||||
withSQLConf(
|
||||
SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "true",
|
||||
SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1",
|
||||
"spark.sql.HashAggregateExec.validParamLength" -> "0") {
|
||||
withTable("t") {
|
||||
val expectedErrMsg = "Failed to split aggregate code into small functions"
|
||||
Seq(
|
||||
// Test case without keys
|
||||
"SELECT AVG(v) FROM VALUES(1) t(v)",
|
||||
// Tet case with keys
|
||||
"SELECT k, AVG(v) FROM VALUES((1, 1)) t(k, v) GROUP BY k").foreach { query =>
|
||||
val errMsg = intercept[IllegalStateException] {
|
||||
sql(query).collect
|
||||
}.getMessage
|
||||
assert(errMsg.contains(expectedErrMsg))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue