[SPARK-12951] [SQL] support spilling in generated aggregate
This PR add spilling support for generated TungstenAggregate. If spilling happened, it's not that bad to do the iterator based sort-merge-aggregate (not generated). The changes will be covered by TungstenAggregationQueryWithControlledFallbackSuite Author: Davies Liu <davies@databricks.com> Closes #10998 from davies/gen_spilling.
This commit is contained in:
parent
ff71261b65
commit
99a6e3c1e8
|
@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.expressions._
|
|||
import org.apache.spark.sql.catalyst.expressions.aggregate._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.plans.physical._
|
||||
import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
|
||||
import org.apache.spark.sql.execution._
|
||||
import org.apache.spark.sql.execution.metric.SQLMetrics
|
||||
import org.apache.spark.sql.types.{DecimalType, StructType}
|
||||
import org.apache.spark.sql.types.StructType
|
||||
import org.apache.spark.unsafe.KVIterator
|
||||
|
||||
case class TungstenAggregate(
|
||||
|
@ -258,6 +258,7 @@ case class TungstenAggregate(
|
|||
|
||||
// The name for HashMap
|
||||
private var hashMapTerm: String = _
|
||||
private var sorterTerm: String = _
|
||||
|
||||
/**
|
||||
* This is called by generated Java class, should be public.
|
||||
|
@ -286,39 +287,98 @@ case class TungstenAggregate(
|
|||
GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Update peak execution memory, called in generated Java class.
|
||||
* Called by generated Java class to finish the aggregate and return a KVIterator.
|
||||
*/
|
||||
def updatePeakMemory(hashMap: UnsafeFixedWidthAggregationMap): Unit = {
|
||||
def finishAggregate(
|
||||
hashMap: UnsafeFixedWidthAggregationMap,
|
||||
sorter: UnsafeKVExternalSorter): KVIterator[UnsafeRow, UnsafeRow] = {
|
||||
|
||||
// update peak execution memory
|
||||
val mapMemory = hashMap.getPeakMemoryUsedBytes
|
||||
val sorterMemory = Option(sorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L)
|
||||
val peakMemory = Math.max(mapMemory, sorterMemory)
|
||||
val metrics = TaskContext.get().taskMetrics()
|
||||
metrics.incPeakExecutionMemory(mapMemory)
|
||||
metrics.incPeakExecutionMemory(peakMemory)
|
||||
|
||||
if (sorter == null) {
|
||||
// not spilled
|
||||
return hashMap.iterator()
|
||||
}
|
||||
|
||||
// merge the final hashMap into sorter
|
||||
sorter.merge(hashMap.destructAndCreateExternalSorter())
|
||||
hashMap.free()
|
||||
val sortedIter = sorter.sortedIterator()
|
||||
|
||||
// Create a KVIterator based on the sorted iterator.
|
||||
new KVIterator[UnsafeRow, UnsafeRow] {
|
||||
|
||||
// Create a MutableProjection to merge the rows of same key together
|
||||
val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
|
||||
val mergeProjection = newMutableProjection(
|
||||
mergeExpr,
|
||||
bufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
|
||||
subexpressionEliminationEnabled)()
|
||||
val joinedRow = new JoinedRow()
|
||||
|
||||
var currentKey: UnsafeRow = null
|
||||
var currentRow: UnsafeRow = null
|
||||
var nextKey: UnsafeRow = if (sortedIter.next()) {
|
||||
sortedIter.getKey
|
||||
} else {
|
||||
null
|
||||
}
|
||||
|
||||
override def next(): Boolean = {
|
||||
if (nextKey != null) {
|
||||
currentKey = nextKey.copy()
|
||||
currentRow = sortedIter.getValue.copy()
|
||||
nextKey = null
|
||||
// use the first row as aggregate buffer
|
||||
mergeProjection.target(currentRow)
|
||||
|
||||
// merge the following rows with same key together
|
||||
var findNextGroup = false
|
||||
while (!findNextGroup && sortedIter.next()) {
|
||||
val key = sortedIter.getKey
|
||||
if (currentKey.equals(key)) {
|
||||
mergeProjection(joinedRow(currentRow, sortedIter.getValue))
|
||||
} else {
|
||||
// We find a new group.
|
||||
findNextGroup = true
|
||||
nextKey = key
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
override def getKey: UnsafeRow = currentKey
|
||||
override def getValue: UnsafeRow = currentRow
|
||||
override def close(): Unit = {
|
||||
sortedIter.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def doProduceWithKeys(ctx: CodegenContext): String = {
|
||||
val initAgg = ctx.freshName("initAgg")
|
||||
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
|
||||
|
||||
// create hashMap
|
||||
val thisPlan = ctx.addReferenceObj("plan", this)
|
||||
hashMapTerm = ctx.freshName("hashMap")
|
||||
val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
|
||||
ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
|
||||
|
||||
// Create a name for iterator from HashMap
|
||||
val iterTerm = ctx.freshName("mapIter")
|
||||
ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "")
|
||||
|
||||
// generate code for output
|
||||
val keyTerm = ctx.freshName("aggKey")
|
||||
val bufferTerm = ctx.freshName("aggBuffer")
|
||||
val outputCode = if (modes.contains(Final) || modes.contains(Complete)) {
|
||||
/**
|
||||
* Generate the code for output.
|
||||
*/
|
||||
private def generateResultCode(
|
||||
ctx: CodegenContext,
|
||||
keyTerm: String,
|
||||
bufferTerm: String,
|
||||
plan: String): String = {
|
||||
if (modes.contains(Final) || modes.contains(Complete)) {
|
||||
// generate output using resultExpressions
|
||||
ctx.currentVars = null
|
||||
ctx.INPUT_ROW = keyTerm
|
||||
val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
|
||||
BoundReference(i, e.dataType, e.nullable).gen(ctx)
|
||||
BoundReference(i, e.dataType, e.nullable).gen(ctx)
|
||||
}
|
||||
ctx.INPUT_ROW = bufferTerm
|
||||
val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) =>
|
||||
|
@ -348,7 +408,7 @@ case class TungstenAggregate(
|
|||
// This should be the last operator in a stage, we should output UnsafeRow directly
|
||||
val joinerTerm = ctx.freshName("unsafeRowJoiner")
|
||||
ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm,
|
||||
s"$joinerTerm = $thisPlan.createUnsafeJoiner();")
|
||||
s"$joinerTerm = $plan.createUnsafeJoiner();")
|
||||
val resultRow = ctx.freshName("resultRow")
|
||||
s"""
|
||||
UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm);
|
||||
|
@ -367,6 +427,23 @@ case class TungstenAggregate(
|
|||
${consume(ctx, eval)}
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
||||
private def doProduceWithKeys(ctx: CodegenContext): String = {
|
||||
val initAgg = ctx.freshName("initAgg")
|
||||
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
|
||||
|
||||
// create hashMap
|
||||
val thisPlan = ctx.addReferenceObj("plan", this)
|
||||
hashMapTerm = ctx.freshName("hashMap")
|
||||
val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
|
||||
ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
|
||||
sorterTerm = ctx.freshName("sorter")
|
||||
ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "")
|
||||
|
||||
// Create a name for iterator from HashMap
|
||||
val iterTerm = ctx.freshName("mapIter")
|
||||
ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "")
|
||||
|
||||
val doAgg = ctx.freshName("doAggregateWithKeys")
|
||||
ctx.addNewFunction(doAgg,
|
||||
|
@ -374,10 +451,15 @@ case class TungstenAggregate(
|
|||
private void $doAgg() throws java.io.IOException {
|
||||
${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
|
||||
|
||||
$iterTerm = $hashMapTerm.iterator();
|
||||
$iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm);
|
||||
}
|
||||
""")
|
||||
|
||||
// generate code for output
|
||||
val keyTerm = ctx.freshName("aggKey")
|
||||
val bufferTerm = ctx.freshName("aggBuffer")
|
||||
val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan)
|
||||
|
||||
s"""
|
||||
if (!$initAgg) {
|
||||
$initAgg = true;
|
||||
|
@ -391,8 +473,10 @@ case class TungstenAggregate(
|
|||
$outputCode
|
||||
}
|
||||
|
||||
$thisPlan.updatePeakMemory($hashMapTerm);
|
||||
$hashMapTerm.free();
|
||||
$iterTerm.close();
|
||||
if ($sorterTerm == null) {
|
||||
$hashMapTerm.free();
|
||||
}
|
||||
"""
|
||||
}
|
||||
|
||||
|
@ -425,14 +509,42 @@ case class TungstenAggregate(
|
|||
ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable)
|
||||
}
|
||||
|
||||
val (checkFallback, resetCoulter, incCounter) = if (testFallbackStartsAt.isDefined) {
|
||||
val countTerm = ctx.freshName("fallbackCounter")
|
||||
ctx.addMutableState("int", countTerm, s"$countTerm = 0;")
|
||||
(s"$countTerm < ${testFallbackStartsAt.get}", s"$countTerm = 0;", s"$countTerm += 1;")
|
||||
} else {
|
||||
("true", "", "")
|
||||
}
|
||||
|
||||
// We try to do hash map based in-memory aggregation first. If there is not enough memory (the
|
||||
// hash map will return null for new key), we spill the hash map to disk to free memory, then
|
||||
// continue to do in-memory aggregation and spilling until all the rows had been processed.
|
||||
// Finally, sort the spilled aggregate buffers by key, and merge them together for same key.
|
||||
s"""
|
||||
// generate grouping key
|
||||
${keyCode.code}
|
||||
UnsafeRow $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
|
||||
if ($buffer == null) {
|
||||
// failed to allocate the first page
|
||||
throw new OutOfMemoryError("No enough memory for aggregation");
|
||||
UnsafeRow $buffer = null;
|
||||
if ($checkFallback) {
|
||||
// try to get the buffer from hash map
|
||||
$buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
|
||||
}
|
||||
if ($buffer == null) {
|
||||
if ($sorterTerm == null) {
|
||||
$sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
|
||||
} else {
|
||||
$sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
|
||||
}
|
||||
$resetCoulter
|
||||
// the hash map had be spilled, it should have enough memory now,
|
||||
// try to allocate buffer again.
|
||||
$buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
|
||||
if ($buffer == null) {
|
||||
// failed to allocate the first page
|
||||
throw new OutOfMemoryError("No enough memory for aggregation");
|
||||
}
|
||||
}
|
||||
$incCounter
|
||||
|
||||
// evaluate aggregate function
|
||||
${evals.map(_.code).mkString("\n")}
|
||||
|
|
Loading…
Reference in a new issue