[SPARK-21870][SQL] Split aggregation code into small functions
## What changes were proposed in this pull request? This pr proposed to split aggregation code into small functions in `HashAggregateExec`. In #18810, we got performance regression if JVMs didn't compile too long functions. I checked and I found the codegen of `HashAggregateExec` frequently goes over the limit when a query has too many aggregate functions (e.g., q66 in TPCDS). The current master places all the generated aggregation code in a single function. In this pr, I modified the code to assign an individual function for each aggregate function (e.g., `SUM` and `AVG`). For example, in a query `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, the proposed code defines two functions for `SUM(a)` and `AVG(a)` as follows; - generated code with this pr (https://gist.github.com/maropu/812990012bc967a78364be0fa793f559): ``` /* 173 */ private void agg_doConsume_0(InternalRow inputadapter_row_0, long agg_expr_0_0, boolean agg_exprIsNull_0_0, double agg_expr_1_0, boolean agg_exprIsNull_1_0, long agg_expr_2_0, boolean agg_exprIsNull_2_0) throws java.io.IOException { /* 174 */ // do aggregate /* 175 */ // common sub-expressions /* 176 */ /* 177 */ // evaluate aggregate functions and update aggregation buffers /* 178 */ agg_doAggregate_sum_0(agg_exprIsNull_0_0, agg_expr_0_0); /* 179 */ agg_doAggregate_avg_0(agg_expr_1_0, agg_exprIsNull_1_0, agg_exprIsNull_2_0, agg_expr_2_0); /* 180 */ /* 181 */ } ... /* 071 */ private void agg_doAggregate_avg_0(double agg_expr_1_0, boolean agg_exprIsNull_1_0, boolean agg_exprIsNull_2_0, long agg_expr_2_0) throws java.io.IOException { /* 072 */ // do aggregate for avg /* 073 */ // evaluate aggregate function /* 074 */ boolean agg_isNull_19 = true; /* 075 */ double agg_value_19 = -1.0; ... 
/* 114 */ private void agg_doAggregate_sum_0(boolean agg_exprIsNull_0_0, long agg_expr_0_0) throws java.io.IOException { /* 115 */ // do aggregate for sum /* 116 */ // evaluate aggregate function /* 117 */ agg_agg_isNull_11_0 = true; /* 118 */ long agg_value_11 = -1L; ``` - generated code in the current master (https://gist.github.com/maropu/e9d772af2c98d8991a6a5f0af7841760) ``` /* 059 */ private void agg_doConsume_0(InternalRow localtablescan_row_0, int agg_expr_0_0) throws java.io.IOException { /* 060 */ // do aggregate /* 061 */ // common sub-expressions /* 062 */ boolean agg_isNull_4 = false; /* 063 */ long agg_value_4 = -1L; /* 064 */ if (!false) { /* 065 */ agg_value_4 = (long) agg_expr_0_0; /* 066 */ } /* 067 */ // evaluate aggregate function /* 068 */ agg_agg_isNull_7_0 = true; /* 069 */ long agg_value_7 = -1L; /* 070 */ do { /* 071 */ if (!agg_bufIsNull_0) { /* 072 */ agg_agg_isNull_7_0 = false; /* 073 */ agg_value_7 = agg_bufValue_0; /* 074 */ continue; /* 075 */ } /* 076 */ /* 077 */ boolean agg_isNull_9 = false; /* 078 */ long agg_value_9 = -1L; /* 079 */ if (!false) { /* 080 */ agg_value_9 = (long) 0; /* 081 */ } /* 082 */ if (!agg_isNull_9) { /* 083 */ agg_agg_isNull_7_0 = false; /* 084 */ agg_value_7 = agg_value_9; /* 085 */ continue; /* 086 */ } /* 087 */ /* 088 */ } while (false); /* 089 */ /* 090 */ long agg_value_6 = -1L; /* 091 */ /* 092 */ agg_value_6 = agg_value_7 + agg_value_4; /* 093 */ boolean agg_isNull_11 = true; /* 094 */ double agg_value_11 = -1.0; /* 095 */ /* 096 */ if (!agg_bufIsNull_1) { /* 097 */ agg_agg_isNull_13_0 = true; /* 098 */ double agg_value_13 = -1.0; /* 099 */ do { /* 100 */ boolean agg_isNull_14 = agg_isNull_4; /* 101 */ double agg_value_14 = -1.0; /* 102 */ if (!agg_isNull_4) { /* 103 */ agg_value_14 = (double) agg_value_4; /* 104 */ } /* 105 */ if (!agg_isNull_14) { /* 106 */ agg_agg_isNull_13_0 = false; /* 107 */ agg_value_13 = agg_value_14; /* 108 */ continue; /* 109 */ } /* 110 */ /* 111 */ boolean agg_isNull_15 = 
false; /* 112 */ double agg_value_15 = -1.0; /* 113 */ if (!false) { /* 114 */ agg_value_15 = (double) 0; /* 115 */ } /* 116 */ if (!agg_isNull_15) { /* 117 */ agg_agg_isNull_13_0 = false; /* 118 */ agg_value_13 = agg_value_15; /* 119 */ continue; /* 120 */ } /* 121 */ /* 122 */ } while (false); /* 123 */ /* 124 */ agg_isNull_11 = false; // resultCode could change nullability. /* 125 */ /* 126 */ agg_value_11 = agg_bufValue_1 + agg_value_13; /* 127 */ /* 128 */ } /* 129 */ boolean agg_isNull_17 = false; /* 130 */ long agg_value_17 = -1L; /* 131 */ if (!false && agg_isNull_4) { /* 132 */ agg_isNull_17 = agg_bufIsNull_2; /* 133 */ agg_value_17 = agg_bufValue_2; /* 134 */ } else { /* 135 */ boolean agg_isNull_20 = true; /* 136 */ long agg_value_20 = -1L; /* 137 */ /* 138 */ if (!agg_bufIsNull_2) { /* 139 */ agg_isNull_20 = false; // resultCode could change nullability. /* 140 */ /* 141 */ agg_value_20 = agg_bufValue_2 + 1L; /* 142 */ /* 143 */ } /* 144 */ agg_isNull_17 = agg_isNull_20; /* 145 */ agg_value_17 = agg_value_20; /* 146 */ } /* 147 */ // update aggregation buffer /* 148 */ agg_bufIsNull_0 = false; /* 149 */ agg_bufValue_0 = agg_value_6; /* 150 */ /* 151 */ agg_bufIsNull_1 = agg_isNull_11; /* 152 */ agg_bufValue_1 = agg_value_11; /* 153 */ /* 154 */ agg_bufIsNull_2 = agg_isNull_17; /* 155 */ agg_bufValue_2 = agg_value_17; /* 156 */ /* 157 */ } ``` You can check the previous discussion in https://github.com/apache/spark/pull/19082 ## How was this patch tested? Existing tests Closes #20965 from maropu/SPARK-21870-2. Authored-by: Takeshi Yamamuro <yamamuro@apache.org> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
36f8e53cfa
commit
cb0cddffe9
|
@ -115,7 +115,13 @@ package object dsl {
|
||||||
def getField(fieldName: String): UnresolvedExtractValue =
|
def getField(fieldName: String): UnresolvedExtractValue =
|
||||||
UnresolvedExtractValue(expr, Literal(fieldName))
|
UnresolvedExtractValue(expr, Literal(fieldName))
|
||||||
|
|
||||||
def cast(to: DataType): Expression = Cast(expr, to)
|
def cast(to: DataType): Expression = {
|
||||||
|
if (expr.resolved && expr.dataType.sameType(to)) {
|
||||||
|
expr
|
||||||
|
} else {
|
||||||
|
Cast(expr, to)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
def asc: SortOrder = SortOrder(expr, Ascending)
|
def asc: SortOrder = SortOrder(expr, Ascending)
|
||||||
def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Set.empty)
|
def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Set.empty)
|
||||||
|
|
|
@ -1612,6 +1612,48 @@ object CodeGenerator extends Logging {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts all the input variables from references and subexpression elimination states
|
||||||
|
* for a given `expr`. This result will be used to split the generated code of
|
||||||
|
* expressions into multiple functions.
|
||||||
|
*/
|
||||||
|
def getLocalInputVariableValues(
|
||||||
|
ctx: CodegenContext,
|
||||||
|
expr: Expression,
|
||||||
|
subExprs: Map[Expression, SubExprEliminationState]): Set[VariableValue] = {
|
||||||
|
val argSet = mutable.Set[VariableValue]()
|
||||||
|
if (ctx.INPUT_ROW != null) {
|
||||||
|
argSet += JavaCode.variable(ctx.INPUT_ROW, classOf[InternalRow])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collects local variables from a given `expr` tree
|
||||||
|
val collectLocalVariable = (ev: ExprValue) => ev match {
|
||||||
|
case vv: VariableValue => argSet += vv
|
||||||
|
case _ =>
|
||||||
|
}
|
||||||
|
|
||||||
|
val stack = mutable.Stack[Expression](expr)
|
||||||
|
while (stack.nonEmpty) {
|
||||||
|
stack.pop() match {
|
||||||
|
case e if subExprs.contains(e) =>
|
||||||
|
val SubExprEliminationState(isNull, value) = subExprs(e)
|
||||||
|
collectLocalVariable(value)
|
||||||
|
collectLocalVariable(isNull)
|
||||||
|
|
||||||
|
case ref: BoundReference if ctx.currentVars != null &&
|
||||||
|
ctx.currentVars(ref.ordinal) != null =>
|
||||||
|
val ExprCode(_, isNull, value) = ctx.currentVars(ref.ordinal)
|
||||||
|
collectLocalVariable(value)
|
||||||
|
collectLocalVariable(isNull)
|
||||||
|
|
||||||
|
case e =>
|
||||||
|
stack.pushAll(e.children)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
argSet.toSet
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the name used in accessor and setter for a Java primitive type.
|
* Returns the name used in accessor and setter for a Java primitive type.
|
||||||
*/
|
*/
|
||||||
|
@ -1719,6 +1761,15 @@ object CodeGenerator extends Logging {
|
||||||
1 + params.map(paramLengthForExpr).sum
|
1 + params.map(paramLengthForExpr).sum
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def calculateParamLengthFromExprValues(params: Seq[ExprValue]): Int = {
|
||||||
|
def paramLengthForExpr(input: ExprValue): Int = input.javaType match {
|
||||||
|
case java.lang.Long.TYPE | java.lang.Double.TYPE => 2
|
||||||
|
case _ => 1
|
||||||
|
}
|
||||||
|
// Initial value is 1 for `this`.
|
||||||
|
1 + params.map(paramLengthForExpr).sum
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* In Java, a method descriptor is valid only if it represents method parameters with a total
|
* In Java, a method descriptor is valid only if it represents method parameters with a total
|
||||||
* length less than a pre-defined constant.
|
* length less than a pre-defined constant.
|
||||||
|
|
|
@ -143,7 +143,10 @@ trait Block extends TreeNode[Block] with JavaCode {
|
||||||
case _ => code.trim
|
case _ => code.trim
|
||||||
}
|
}
|
||||||
|
|
||||||
def length: Int = toString.length
|
def length: Int = {
|
||||||
|
// Returns a code length without comments
|
||||||
|
CodeFormatter.stripExtraNewLinesAndComments(toString).length
|
||||||
|
}
|
||||||
|
|
||||||
def isEmpty: Boolean = toString.isEmpty
|
def isEmpty: Boolean = toString.isEmpty
|
||||||
|
|
||||||
|
|
|
@ -354,12 +354,14 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
|
||||||
|
|
||||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||||
val eval = child.genCode(ctx)
|
val eval = child.genCode(ctx)
|
||||||
val value = eval.isNull match {
|
val (value, newCode) = eval.isNull match {
|
||||||
case TrueLiteral => FalseLiteral
|
case TrueLiteral => (FalseLiteral, EmptyBlock)
|
||||||
case FalseLiteral => TrueLiteral
|
case FalseLiteral => (TrueLiteral, EmptyBlock)
|
||||||
case v => JavaCode.isNullExpression(s"!$v")
|
case v =>
|
||||||
|
val value = ctx.freshName("value")
|
||||||
|
(JavaCode.variable(value, BooleanType), code"boolean $value = !$v;")
|
||||||
}
|
}
|
||||||
ExprCode(code = eval.code, isNull = FalseLiteral, value = value)
|
ExprCode(code = eval.code + newCode, isNull = FalseLiteral, value = value)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def sql: String = s"(${child.sql} IS NOT NULL)"
|
override def sql: String = s"(${child.sql} IS NOT NULL)"
|
||||||
|
|
|
@ -1080,6 +1080,15 @@ object SQLConf {
|
||||||
.booleanConf
|
.booleanConf
|
||||||
.createWithDefault(false)
|
.createWithDefault(false)
|
||||||
|
|
||||||
|
val CODEGEN_SPLIT_AGGREGATE_FUNC =
|
||||||
|
buildConf("spark.sql.codegen.aggregate.splitAggregateFunc.enabled")
|
||||||
|
.internal()
|
||||||
|
.doc("When true, the code generator would split aggregate code into individual methods " +
|
||||||
|
"instead of a single big method. This can be used to avoid oversized function that " +
|
||||||
|
"can miss the opportunity of JIT optimization.")
|
||||||
|
.booleanConf
|
||||||
|
.createWithDefault(true)
|
||||||
|
|
||||||
val MAX_NESTED_VIEW_DEPTH =
|
val MAX_NESTED_VIEW_DEPTH =
|
||||||
buildConf("spark.sql.view.maxNestedViewDepth")
|
buildConf("spark.sql.view.maxNestedViewDepth")
|
||||||
.internal()
|
.internal()
|
||||||
|
@ -2353,6 +2362,8 @@ class SQLConf extends Serializable with Logging {
|
||||||
def cartesianProductExecBufferSpillThreshold: Int =
|
def cartesianProductExecBufferSpillThreshold: Int =
|
||||||
getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD)
|
getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD)
|
||||||
|
|
||||||
|
def codegenSplitAggregateFunc: Boolean = getConf(SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC)
|
||||||
|
|
||||||
def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH)
|
def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH)
|
||||||
|
|
||||||
def starSchemaDetection: Boolean = getConf(STARSCHEMA_DETECTION)
|
def starSchemaDetection: Boolean = getConf(STARSCHEMA_DETECTION)
|
||||||
|
|
|
@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.aggregate
|
||||||
|
|
||||||
import java.util.concurrent.TimeUnit._
|
import java.util.concurrent.TimeUnit._
|
||||||
|
|
||||||
|
import scala.collection.mutable
|
||||||
|
|
||||||
import org.apache.spark.TaskContext
|
import org.apache.spark.TaskContext
|
||||||
import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager}
|
import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
|
@ -174,8 +176,9 @@ case class HashAggregateExec(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// The variables used as aggregation buffer. Only used for aggregation without keys.
|
// The variables are used as aggregation buffers and each aggregate function has one or more
|
||||||
private var bufVars: Seq[ExprCode] = _
|
// ExprCode to initialize its buffer slots. Only used for aggregation without keys.
|
||||||
|
private var bufVars: Seq[Seq[ExprCode]] = _
|
||||||
|
|
||||||
private def doProduceWithoutKeys(ctx: CodegenContext): String = {
|
private def doProduceWithoutKeys(ctx: CodegenContext): String = {
|
||||||
val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
|
val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
|
||||||
|
@ -184,27 +187,30 @@ case class HashAggregateExec(
|
||||||
|
|
||||||
// generate variables for aggregation buffer
|
// generate variables for aggregation buffer
|
||||||
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
|
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
|
||||||
val initExpr = functions.flatMap(f => f.initialValues)
|
val initExpr = functions.map(f => f.initialValues)
|
||||||
bufVars = initExpr.map { e =>
|
bufVars = initExpr.map { exprs =>
|
||||||
val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull")
|
exprs.map { e =>
|
||||||
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
|
val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull")
|
||||||
// The initial expression should not access any column
|
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
|
||||||
val ev = e.genCode(ctx)
|
// The initial expression should not access any column
|
||||||
val initVars = code"""
|
val ev = e.genCode(ctx)
|
||||||
| $isNull = ${ev.isNull};
|
val initVars = code"""
|
||||||
| $value = ${ev.value};
|
|$isNull = ${ev.isNull};
|
||||||
""".stripMargin
|
|$value = ${ev.value};
|
||||||
ExprCode(
|
""".stripMargin
|
||||||
ev.code + initVars,
|
ExprCode(
|
||||||
JavaCode.isNullGlobal(isNull),
|
ev.code + initVars,
|
||||||
JavaCode.global(value, e.dataType))
|
JavaCode.isNullGlobal(isNull),
|
||||||
|
JavaCode.global(value, e.dataType))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
val initBufVar = evaluateVariables(bufVars)
|
val flatBufVars = bufVars.flatten
|
||||||
|
val initBufVar = evaluateVariables(flatBufVars)
|
||||||
|
|
||||||
// generate variables for output
|
// generate variables for output
|
||||||
val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
|
val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
|
||||||
// evaluate aggregate results
|
// evaluate aggregate results
|
||||||
ctx.currentVars = bufVars
|
ctx.currentVars = flatBufVars
|
||||||
val aggResults = bindReferences(
|
val aggResults = bindReferences(
|
||||||
functions.map(_.evaluateExpression),
|
functions.map(_.evaluateExpression),
|
||||||
aggregateBufferAttributes).map(_.genCode(ctx))
|
aggregateBufferAttributes).map(_.genCode(ctx))
|
||||||
|
@ -218,7 +224,7 @@ case class HashAggregateExec(
|
||||||
""".stripMargin)
|
""".stripMargin)
|
||||||
} else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
|
} else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
|
||||||
// output the aggregate buffer directly
|
// output the aggregate buffer directly
|
||||||
(bufVars, "")
|
(flatBufVars, "")
|
||||||
} else {
|
} else {
|
||||||
// no aggregate function, the result should be literals
|
// no aggregate function, the result should be literals
|
||||||
val resultVars = resultExpressions.map(_.genCode(ctx))
|
val resultVars = resultExpressions.map(_.genCode(ctx))
|
||||||
|
@ -255,11 +261,85 @@ case class HashAggregateExec(
|
||||||
""".stripMargin
|
""".stripMargin
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private def isValidParamLength(paramLength: Int): Boolean = {
|
||||||
|
// This config is only for testing
|
||||||
|
sqlContext.getConf("spark.sql.HashAggregateExec.validParamLength", null) match {
|
||||||
|
case null | "" => CodeGenerator.isValidParamLength(paramLength)
|
||||||
|
case validLength => paramLength <= validLength.toInt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Splits aggregate code into small functions because most JVM implementations
|
||||||
|
// cannot compile overly long functions. Returns None if we are not able to split the given code.
|
||||||
|
//
|
||||||
|
// Note: The difference from `CodeGenerator.splitExpressions` is that we define an individual
|
||||||
|
// function for each aggregation function (e.g., SUM and AVG). For example, in a query
|
||||||
|
// `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, we define two functions
|
||||||
|
// for `SUM(a)` and `AVG(a)`.
|
||||||
|
private def splitAggregateExpressions(
|
||||||
|
ctx: CodegenContext,
|
||||||
|
aggNames: Seq[String],
|
||||||
|
aggBufferUpdatingExprs: Seq[Seq[Expression]],
|
||||||
|
aggCodeBlocks: Seq[Block],
|
||||||
|
subExprs: Map[Expression, SubExprEliminationState]): Option[String] = {
|
||||||
|
val exprValsInSubExprs = subExprs.flatMap { case (_, s) => s.value :: s.isNull :: Nil }
|
||||||
|
if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) {
|
||||||
|
// `SimpleExprValue`s cannot be used as an input variable for split functions, so
|
||||||
|
// we give up splitting functions if it exists in `subExprs`.
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc =>
|
||||||
|
val inputVarsForOneFunc = aggExprsForOneFunc.map(
|
||||||
|
CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)).reduce(_ ++ _).toSeq
|
||||||
|
val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc)
|
||||||
|
|
||||||
|
// Checks if a parameter length for the `aggExprsForOneFunc` does not go over the JVM limit
|
||||||
|
if (isValidParamLength(paramLength)) {
|
||||||
|
Some(inputVarsForOneFunc)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks if all the aggregate code can be split into pieces.
|
||||||
|
// If the parameter length of at least one `aggExprsForOneFunc` goes over the limit,
|
||||||
|
// we totally give up splitting aggregate code.
|
||||||
|
if (inputVars.forall(_.isDefined)) {
|
||||||
|
val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) =>
|
||||||
|
val doAggFunc = ctx.freshName(s"doAggregate_${aggNames(i)}")
|
||||||
|
val argList = args.map(v => s"${v.javaType.getName} ${v.variableName}").mkString(", ")
|
||||||
|
val doAggFuncName = ctx.addNewFunction(doAggFunc,
|
||||||
|
s"""
|
||||||
|
|private void $doAggFunc($argList) throws java.io.IOException {
|
||||||
|
| ${aggCodeBlocks(i)}
|
||||||
|
|}
|
||||||
|
""".stripMargin)
|
||||||
|
|
||||||
|
val inputVariables = args.map(_.variableName).mkString(", ")
|
||||||
|
s"$doAggFuncName($inputVariables);"
|
||||||
|
}
|
||||||
|
Some(splitCodes.mkString("\n").trim)
|
||||||
|
} else {
|
||||||
|
val errMsg = "Failed to split aggregate code into small functions because the parameter " +
|
||||||
|
"length of at least one split function went over the JVM limit: " +
|
||||||
|
CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH
|
||||||
|
if (Utils.isTesting) {
|
||||||
|
throw new IllegalStateException(errMsg)
|
||||||
|
} else {
|
||||||
|
logInfo(errMsg)
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
|
private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
|
||||||
// only have DeclarativeAggregate
|
// only have DeclarativeAggregate
|
||||||
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
|
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
|
||||||
val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
|
val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
|
||||||
val updateExpr = aggregateExpressions.flatMap { e =>
|
// To individually generate code for each aggregate function, an element in `updateExprs` holds
|
||||||
|
// all the expressions for the buffer of an aggregation function.
|
||||||
|
val updateExprs = aggregateExpressions.map { e =>
|
||||||
e.mode match {
|
e.mode match {
|
||||||
case Partial | Complete =>
|
case Partial | Complete =>
|
||||||
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
|
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
|
||||||
|
@ -267,28 +347,56 @@ case class HashAggregateExec(
|
||||||
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
|
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ctx.currentVars = bufVars ++ input
|
ctx.currentVars = bufVars.flatten ++ input
|
||||||
val boundUpdateExpr = bindReferences(updateExpr, inputAttrs)
|
val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
|
||||||
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
|
bindReferences(updateExprsForOneFunc, inputAttrs)
|
||||||
val effectiveCodes = subExprs.codes.mkString("\n")
|
|
||||||
val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
|
|
||||||
boundUpdateExpr.map(_.genCode(ctx))
|
|
||||||
}
|
}
|
||||||
// aggregate buffer should be updated atomic
|
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
|
||||||
val updates = aggVals.zipWithIndex.map { case (ev, i) =>
|
val effectiveCodes = subExprs.codes.mkString("\n")
|
||||||
s"""
|
val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
|
||||||
| ${bufVars(i).isNull} = ${ev.isNull};
|
ctx.withSubExprEliminationExprs(subExprs.states) {
|
||||||
| ${bufVars(i).value} = ${ev.value};
|
boundUpdateExprsForOneFunc.map(_.genCode(ctx))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val aggNames = functions.map(_.prettyName)
|
||||||
|
val aggCodeBlocks = bufferEvals.zipWithIndex.map { case (bufferEvalsForOneFunc, i) =>
|
||||||
|
val bufVarsForOneFunc = bufVars(i)
|
||||||
|
// All the update code for aggregation buffers should be placed in the end
|
||||||
|
// of each aggregation function code.
|
||||||
|
val updates = bufferEvalsForOneFunc.zip(bufVarsForOneFunc).map { case (ev, bufVar) =>
|
||||||
|
s"""
|
||||||
|
|${bufVar.isNull} = ${ev.isNull};
|
||||||
|
|${bufVar.value} = ${ev.value};
|
||||||
|
""".stripMargin
|
||||||
|
}
|
||||||
|
code"""
|
||||||
|
|// do aggregate for ${aggNames(i)}
|
||||||
|
|// evaluate aggregate function
|
||||||
|
|${evaluateVariables(bufferEvalsForOneFunc)}
|
||||||
|
|// update aggregation buffers
|
||||||
|
|${updates.mkString("\n").trim}
|
||||||
""".stripMargin
|
""".stripMargin
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
|
||||||
|
aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
|
||||||
|
val maybeSplitCode = splitAggregateExpressions(
|
||||||
|
ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
|
||||||
|
|
||||||
|
maybeSplitCode.getOrElse {
|
||||||
|
aggCodeBlocks.fold(EmptyBlock)(_ + _).code
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
aggCodeBlocks.fold(EmptyBlock)(_ + _).code
|
||||||
|
}
|
||||||
|
|
||||||
s"""
|
s"""
|
||||||
| // do aggregate
|
|// do aggregate
|
||||||
| // common sub-expressions
|
|// common sub-expressions
|
||||||
| $effectiveCodes
|
|$effectiveCodes
|
||||||
| // evaluate aggregate function
|
|// evaluate aggregate functions and update aggregation buffers
|
||||||
| ${evaluateVariables(aggVals)}
|
|$codeToEvalAggFunc
|
||||||
| // update aggregation buffer
|
|
||||||
| ${updates.mkString("\n").trim}
|
|
||||||
""".stripMargin
|
""".stripMargin
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -745,8 +853,10 @@ case class HashAggregateExec(
|
||||||
val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
|
val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
|
||||||
val fastRowBuffer = ctx.freshName("fastAggBuffer")
|
val fastRowBuffer = ctx.freshName("fastAggBuffer")
|
||||||
|
|
||||||
// only have DeclarativeAggregate
|
// To individually generate code for each aggregate function, an element in `updateExprs` holds
|
||||||
val updateExpr = aggregateExpressions.flatMap { e =>
|
// all the expressions for the buffer of an aggregation function.
|
||||||
|
val updateExprs = aggregateExpressions.map { e =>
|
||||||
|
// only have DeclarativeAggregate
|
||||||
e.mode match {
|
e.mode match {
|
||||||
case Partial | Complete =>
|
case Partial | Complete =>
|
||||||
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
|
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
|
||||||
|
@ -824,25 +934,70 @@ case class HashAggregateExec(
|
||||||
// generating input columns, we use `currentVars`.
|
// generating input columns, we use `currentVars`.
|
||||||
ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input
|
ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input
|
||||||
|
|
||||||
|
val aggNames = aggregateExpressions.map(_.aggregateFunction.prettyName)
|
||||||
|
// Computes start offsets for each aggregation function code
|
||||||
|
// in the underlying buffer row.
|
||||||
|
val bufferStartOffsets = {
|
||||||
|
val offsets = mutable.ArrayBuffer[Int]()
|
||||||
|
var curOffset = 0
|
||||||
|
updateExprs.foreach { exprsForOneFunc =>
|
||||||
|
offsets += curOffset
|
||||||
|
curOffset += exprsForOneFunc.length
|
||||||
|
}
|
||||||
|
offsets.toArray
|
||||||
|
}
|
||||||
|
|
||||||
val updateRowInRegularHashMap: String = {
|
val updateRowInRegularHashMap: String = {
|
||||||
ctx.INPUT_ROW = unsafeRowBuffer
|
ctx.INPUT_ROW = unsafeRowBuffer
|
||||||
val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
|
val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
|
||||||
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
|
bindReferences(updateExprsForOneFunc, inputAttr)
|
||||||
|
}
|
||||||
|
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
|
||||||
val effectiveCodes = subExprs.codes.mkString("\n")
|
val effectiveCodes = subExprs.codes.mkString("\n")
|
||||||
val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
|
val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
|
||||||
boundUpdateExpr.map(_.genCode(ctx))
|
ctx.withSubExprEliminationExprs(subExprs.states) {
|
||||||
|
boundUpdateExprsForOneFunc.map(_.genCode(ctx))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
|
|
||||||
val dt = updateExpr(i).dataType
|
val aggCodeBlocks = updateExprs.indices.map { i =>
|
||||||
CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
|
val rowBufferEvalsForOneFunc = unsafeRowBufferEvals(i)
|
||||||
|
val boundUpdateExprsForOneFunc = boundUpdateExprs(i)
|
||||||
|
val bufferOffset = bufferStartOffsets(i)
|
||||||
|
|
||||||
|
// All the update code for aggregation buffers should be placed in the end
|
||||||
|
// of each aggregation function code.
|
||||||
|
val updateRowBuffers = rowBufferEvalsForOneFunc.zipWithIndex.map { case (ev, j) =>
|
||||||
|
val updateExpr = boundUpdateExprsForOneFunc(j)
|
||||||
|
val dt = updateExpr.dataType
|
||||||
|
val nullable = updateExpr.nullable
|
||||||
|
CodeGenerator.updateColumn(unsafeRowBuffer, dt, bufferOffset + j, ev, nullable)
|
||||||
|
}
|
||||||
|
code"""
|
||||||
|
|// evaluate aggregate function for ${aggNames(i)}
|
||||||
|
|${evaluateVariables(rowBufferEvalsForOneFunc)}
|
||||||
|
|// update unsafe row buffer
|
||||||
|
|${updateRowBuffers.mkString("\n").trim}
|
||||||
|
""".stripMargin
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
|
||||||
|
aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
|
||||||
|
val maybeSplitCode = splitAggregateExpressions(
|
||||||
|
ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
|
||||||
|
|
||||||
|
maybeSplitCode.getOrElse {
|
||||||
|
aggCodeBlocks.fold(EmptyBlock)(_ + _).code
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
aggCodeBlocks.fold(EmptyBlock)(_ + _).code
|
||||||
|
}
|
||||||
|
|
||||||
s"""
|
s"""
|
||||||
|// common sub-expressions
|
|// common sub-expressions
|
||||||
|$effectiveCodes
|
|$effectiveCodes
|
||||||
|// evaluate aggregate function
|
|// evaluate aggregate functions and update aggregation buffers
|
||||||
|${evaluateVariables(unsafeRowBufferEvals)}
|
|$codeToEvalAggFunc
|
||||||
|// update unsafe row buffer
|
|
||||||
|${updateUnsafeRowBuffer.mkString("\n").trim}
|
|
||||||
""".stripMargin
|
""".stripMargin
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -850,16 +1005,48 @@ case class HashAggregateExec(
|
||||||
if (isFastHashMapEnabled) {
|
if (isFastHashMapEnabled) {
|
||||||
if (isVectorizedHashMapEnabled) {
|
if (isVectorizedHashMapEnabled) {
|
||||||
ctx.INPUT_ROW = fastRowBuffer
|
ctx.INPUT_ROW = fastRowBuffer
|
||||||
val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
|
val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
|
||||||
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
|
bindReferences(updateExprsForOneFunc, inputAttr)
|
||||||
val effectiveCodes = subExprs.codes.mkString("\n")
|
|
||||||
val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
|
|
||||||
boundUpdateExpr.map(_.genCode(ctx))
|
|
||||||
}
|
}
|
||||||
val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) =>
|
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
|
||||||
val dt = updateExpr(i).dataType
|
val effectiveCodes = subExprs.codes.mkString("\n")
|
||||||
CodeGenerator.updateColumn(
|
val fastRowEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
|
||||||
fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorized = true)
|
ctx.withSubExprEliminationExprs(subExprs.states) {
|
||||||
|
boundUpdateExprsForOneFunc.map(_.genCode(ctx))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val aggCodeBlocks = fastRowEvals.zipWithIndex.map { case (fastRowEvalsForOneFunc, i) =>
|
||||||
|
val boundUpdateExprsForOneFunc = boundUpdateExprs(i)
|
||||||
|
val bufferOffset = bufferStartOffsets(i)
|
||||||
|
// All the update code for aggregation buffers should be placed in the end
|
||||||
|
// of each aggregation function code.
|
||||||
|
val updateRowBuffer = fastRowEvalsForOneFunc.zipWithIndex.map { case (ev, j) =>
|
||||||
|
val updateExpr = boundUpdateExprsForOneFunc(j)
|
||||||
|
val dt = updateExpr.dataType
|
||||||
|
val nullable = updateExpr.nullable
|
||||||
|
CodeGenerator.updateColumn(fastRowBuffer, dt, bufferOffset + j, ev, nullable,
|
||||||
|
isVectorized = true)
|
||||||
|
}
|
||||||
|
code"""
|
||||||
|
|// evaluate aggregate function for ${aggNames(i)}
|
||||||
|
|${evaluateVariables(fastRowEvalsForOneFunc)}
|
||||||
|
|// update fast row
|
||||||
|
|${updateRowBuffer.mkString("\n").trim}
|
||||||
|
""".stripMargin
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
|
||||||
|
aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
|
||||||
|
val maybeSplitCode = splitAggregateExpressions(
|
||||||
|
ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
|
||||||
|
|
||||||
|
maybeSplitCode.getOrElse {
|
||||||
|
aggCodeBlocks.fold(EmptyBlock)(_ + _).code
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
aggCodeBlocks.fold(EmptyBlock)(_ + _).code
|
||||||
}
|
}
|
||||||
|
|
||||||
// If vectorized fast hash map is on, we first generate code to update row
|
// If vectorized fast hash map is on, we first generate code to update row
|
||||||
|
@ -869,10 +1056,8 @@ case class HashAggregateExec(
|
||||||
|if ($fastRowBuffer != null) {
|
|if ($fastRowBuffer != null) {
|
||||||
| // common sub-expressions
|
| // common sub-expressions
|
||||||
| $effectiveCodes
|
| $effectiveCodes
|
||||||
| // evaluate aggregate function
|
| // evaluate aggregate functions and update aggregation buffers
|
||||||
| ${evaluateVariables(fastRowEvals)}
|
| $codeToEvalAggFunc
|
||||||
| // update fast row
|
|
||||||
| ${updateFastRow.mkString("\n").trim}
|
|
||||||
|} else {
|
|} else {
|
||||||
| $updateRowInRegularHashMap
|
| $updateRowInRegularHashMap
|
||||||
|}
|
|}
|
||||||
|
|
|
@ -398,4 +398,25 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession {
|
||||||
}.isDefined,
|
}.isDefined,
|
||||||
"LocalTableScanExec should be within a WholeStageCodegen domain.")
|
"LocalTableScanExec should be within a WholeStageCodegen domain.")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("Give up splitting aggregate code if a parameter length goes over the limit") {
|
||||||
|
withSQLConf(
|
||||||
|
SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "true",
|
||||||
|
SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1",
|
||||||
|
"spark.sql.HashAggregateExec.validParamLength" -> "0") {
|
||||||
|
withTable("t") {
|
||||||
|
val expectedErrMsg = "Failed to split aggregate code into small functions"
|
||||||
|
Seq(
|
||||||
|
// Test case without keys
|
||||||
|
"SELECT AVG(v) FROM VALUES(1) t(v)",
|
||||||
|
// Test case with keys
|
||||||
|
"SELECT k, AVG(v) FROM VALUES((1, 1)) t(k, v) GROUP BY k").foreach { query =>
|
||||||
|
val errMsg = intercept[IllegalStateException] {
|
||||||
|
sql(query).collect
|
||||||
|
}.getMessage
|
||||||
|
assert(errMsg.contains(expectedErrMsg))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue