[SPARK-12951] [SQL] support spilling in generated aggregate

This PR add spilling support for generated TungstenAggregate.

If spilling happened, it's not that bad to do the iterator based sort-merge-aggregate (not generated).

The changes will be covered by TungstenAggregationQueryWithControlledFallbackSuite

Author: Davies Liu <davies@databricks.com>

Closes #10998 from davies/gen_spilling.
This commit is contained in:
Davies Liu 2016-02-02 19:47:44 -08:00 committed by Davies Liu
parent ff71261b65
commit 99a6e3c1e8

View file

@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.{DecimalType, StructType}
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.KVIterator
case class TungstenAggregate(
@ -258,6 +258,7 @@ case class TungstenAggregate(
// The name for HashMap
private var hashMapTerm: String = _
private var sorterTerm: String = _
/**
* This is called by generated Java class, should be public.
@ -286,39 +287,98 @@ case class TungstenAggregate(
GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
}
/**
* Update peak execution memory, called in generated Java class.
* Called by generated Java class to finish the aggregate and return a KVIterator.
*/
def updatePeakMemory(hashMap: UnsafeFixedWidthAggregationMap): Unit = {
def finishAggregate(
hashMap: UnsafeFixedWidthAggregationMap,
sorter: UnsafeKVExternalSorter): KVIterator[UnsafeRow, UnsafeRow] = {
// update peak execution memory
val mapMemory = hashMap.getPeakMemoryUsedBytes
val sorterMemory = Option(sorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L)
val peakMemory = Math.max(mapMemory, sorterMemory)
val metrics = TaskContext.get().taskMetrics()
metrics.incPeakExecutionMemory(mapMemory)
metrics.incPeakExecutionMemory(peakMemory)
if (sorter == null) {
// not spilled
return hashMap.iterator()
}
// merge the final hashMap into sorter
sorter.merge(hashMap.destructAndCreateExternalSorter())
hashMap.free()
val sortedIter = sorter.sortedIterator()
// Create a KVIterator based on the sorted iterator.
new KVIterator[UnsafeRow, UnsafeRow] {
// Create a MutableProjection to merge the rows of same key together
val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
val mergeProjection = newMutableProjection(
mergeExpr,
bufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
subexpressionEliminationEnabled)()
val joinedRow = new JoinedRow()
var currentKey: UnsafeRow = null
var currentRow: UnsafeRow = null
var nextKey: UnsafeRow = if (sortedIter.next()) {
sortedIter.getKey
} else {
null
}
override def next(): Boolean = {
if (nextKey != null) {
currentKey = nextKey.copy()
currentRow = sortedIter.getValue.copy()
nextKey = null
// use the first row as aggregate buffer
mergeProjection.target(currentRow)
// merge the following rows with same key together
var findNextGroup = false
while (!findNextGroup && sortedIter.next()) {
val key = sortedIter.getKey
if (currentKey.equals(key)) {
mergeProjection(joinedRow(currentRow, sortedIter.getValue))
} else {
// We find a new group.
findNextGroup = true
nextKey = key
}
}
true
} else {
false
}
}
override def getKey: UnsafeRow = currentKey
override def getValue: UnsafeRow = currentRow
override def close(): Unit = {
sortedIter.close()
}
}
}
private def doProduceWithKeys(ctx: CodegenContext): String = {
val initAgg = ctx.freshName("initAgg")
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
// create hashMap
val thisPlan = ctx.addReferenceObj("plan", this)
hashMapTerm = ctx.freshName("hashMap")
val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
// Create a name for iterator from HashMap
val iterTerm = ctx.freshName("mapIter")
ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "")
// generate code for output
val keyTerm = ctx.freshName("aggKey")
val bufferTerm = ctx.freshName("aggBuffer")
val outputCode = if (modes.contains(Final) || modes.contains(Complete)) {
/**
* Generate the code for output.
*/
private def generateResultCode(
ctx: CodegenContext,
keyTerm: String,
bufferTerm: String,
plan: String): String = {
if (modes.contains(Final) || modes.contains(Complete)) {
// generate output using resultExpressions
ctx.currentVars = null
ctx.INPUT_ROW = keyTerm
val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
BoundReference(i, e.dataType, e.nullable).gen(ctx)
BoundReference(i, e.dataType, e.nullable).gen(ctx)
}
ctx.INPUT_ROW = bufferTerm
val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) =>
@ -348,7 +408,7 @@ case class TungstenAggregate(
// This should be the last operator in a stage, we should output UnsafeRow directly
val joinerTerm = ctx.freshName("unsafeRowJoiner")
ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm,
s"$joinerTerm = $thisPlan.createUnsafeJoiner();")
s"$joinerTerm = $plan.createUnsafeJoiner();")
val resultRow = ctx.freshName("resultRow")
s"""
UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm);
@ -367,6 +427,23 @@ case class TungstenAggregate(
${consume(ctx, eval)}
"""
}
}
private def doProduceWithKeys(ctx: CodegenContext): String = {
val initAgg = ctx.freshName("initAgg")
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
// create hashMap
val thisPlan = ctx.addReferenceObj("plan", this)
hashMapTerm = ctx.freshName("hashMap")
val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
sorterTerm = ctx.freshName("sorter")
ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "")
// Create a name for iterator from HashMap
val iterTerm = ctx.freshName("mapIter")
ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "")
val doAgg = ctx.freshName("doAggregateWithKeys")
ctx.addNewFunction(doAgg,
@ -374,10 +451,15 @@ case class TungstenAggregate(
private void $doAgg() throws java.io.IOException {
${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
$iterTerm = $hashMapTerm.iterator();
$iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm);
}
""")
// generate code for output
val keyTerm = ctx.freshName("aggKey")
val bufferTerm = ctx.freshName("aggBuffer")
val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan)
s"""
if (!$initAgg) {
$initAgg = true;
@ -391,8 +473,10 @@ case class TungstenAggregate(
$outputCode
}
$thisPlan.updatePeakMemory($hashMapTerm);
$hashMapTerm.free();
$iterTerm.close();
if ($sorterTerm == null) {
$hashMapTerm.free();
}
"""
}
@ -425,14 +509,42 @@ case class TungstenAggregate(
ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable)
}
val (checkFallback, resetCoulter, incCounter) = if (testFallbackStartsAt.isDefined) {
val countTerm = ctx.freshName("fallbackCounter")
ctx.addMutableState("int", countTerm, s"$countTerm = 0;")
(s"$countTerm < ${testFallbackStartsAt.get}", s"$countTerm = 0;", s"$countTerm += 1;")
} else {
("true", "", "")
}
// We try to do hash map based in-memory aggregation first. If there is not enough memory (the
// hash map will return null for new key), we spill the hash map to disk to free memory, then
// continue to do in-memory aggregation and spilling until all the rows had been processed.
// Finally, sort the spilled aggregate buffers by key, and merge them together for same key.
s"""
// generate grouping key
${keyCode.code}
UnsafeRow $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
if ($buffer == null) {
// failed to allocate the first page
throw new OutOfMemoryError("No enough memory for aggregation");
UnsafeRow $buffer = null;
if ($checkFallback) {
// try to get the buffer from hash map
$buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
}
if ($buffer == null) {
if ($sorterTerm == null) {
$sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
} else {
$sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
}
$resetCoulter
// the hash map had be spilled, it should have enough memory now,
// try to allocate buffer again.
$buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
if ($buffer == null) {
// failed to allocate the first page
throw new OutOfMemoryError("No enough memory for aggregation");
}
}
$incCounter
// evaluate aggregate function
${evals.map(_.code).mkString("\n")}