[SPARK-9674][SQL] Remove GeneratedAggregate.

The new aggregate code path (`org.apache.spark.sql.execution.aggregate`) replaces the old GeneratedAggregate operator.
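
As a quick illustration of what the replacement means for plan inspection: after this change an executed plan should contain the new `aggregate.Aggregate` operator rather than `GeneratedAggregate`. A minimal sketch, assuming a `SQLContext` named `sqlContext` and a registered `testData` table (both names are illustrative, mirroring the test diff below):

import org.apache.spark.sql.execution.aggregate

// Plan a grouped aggregate and check which physical operator it uses.
val df = sqlContext.sql("SELECT key, count(*) FROM testData GROUP BY key")
val usesNewAggregate = df.queryExecution.executedPlan
  .collect { case agg: aggregate.Aggregate => agg }
  .nonEmpty
assert(usesNewAggregate, "expected the new aggregate code path")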

Author: Reynold Xin <rxin@databricks.com>

Closes #7983 from rxin/remove-generated-agg and squashes the following commits:

8334aae [Reynold Xin] [SPARK-9674][SQL] Remove GeneratedAggregate.
Reynold Xin 2015-08-05 21:50:14 -07:00
parent 119b590538
commit 9270bd06fd
4 changed files with 2 additions and 437 deletions

@@ -1,352 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution
import java.io.IOException
import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.trees._
import org.apache.spark.sql.types._
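// schema: the attributes backing this aggregate's slice of the buffer.
// initialValues: expressions that seed those buffer slots.
// update: expressions computing new buffer values from (buffer ++ input row).
// result: the expression that reads the final value out of the buffer.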
case class AggregateEvaluation(
schema: Seq[Attribute],
initialValues: Seq[Expression],
update: Seq[Expression],
result: Expression)
/**
* :: DeveloperApi ::
* Alternate version of aggregation that leverages projection and thus code generation.
* Aggregations are converted into a set of projections from an aggregation buffer tuple back onto
* itself. Currently only simple aggregations like SUM, COUNT, or AVERAGE are supported.
*
* @param partial if true then aggregation is done partially on local data without shuffling to
* ensure all values where `groupingExpressions` are equal are present.
* @param groupingExpressions expressions that are evaluated to determine grouping.
* @param aggregateExpressions expressions that are computed for each group.
* @param unsafeEnabled whether to allow Unsafe-based aggregation buffers to be used.
* @param child the input data source.
*/
@DeveloperApi
case class GeneratedAggregate(
partial: Boolean,
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
unsafeEnabled: Boolean,
child: SparkPlan)
extends UnaryNode {
override def requiredChildDistribution: Seq[Distribution] =
if (partial) {
UnspecifiedDistribution :: Nil
} else {
if (groupingExpressions == Nil) {
AllTuples :: Nil
} else {
ClusteredDistribution(groupingExpressions) :: Nil
}
}
override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
protected override def doExecute(): RDD[InternalRow] = {
val aggregatesToCompute = aggregateExpressions.flatMap { a =>
a.collect { case agg: AggregateExpression1 => agg }
}
// If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite
// (in test "aggregation with codegen").
val computeFunctions = aggregatesToCompute.map {
case c @ Count(expr) =>
// If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
// UnscaledValue will be null if and only if x is null; helps with Average on decimals
val toCount = expr match {
case UnscaledValue(e) => e
case _ => expr
}
val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
val initialValue = Literal(0L)
val updateFunction = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount)
val result = currentCount
AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
case s @ Sum(expr) =>
val calcType =
expr.dataType match {
case DecimalType.Fixed(p, s) =>
DecimalType.bounded(p + 10, s)
case _ =>
expr.dataType
}
val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
val initialValue = Literal.create(null, calcType)
// Coalesce avoids double calculation...
// but really, common sub expression elimination would be better....
val zero = Cast(Literal(0), calcType)
val updateFunction = Coalesce(
Add(
Coalesce(currentSum :: zero :: Nil),
Cast(expr, calcType)
) :: currentSum :: Nil)
val result =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
Cast(currentSum, s.dataType)
case _ => currentSum
}
AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
case m @ Max(expr) =>
val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)()
val initialValue = Literal.create(null, expr.dataType)
val updateMax = MaxOf(currentMax, expr)
AggregateEvaluation(
currentMax :: Nil,
initialValue :: Nil,
updateMax :: Nil,
currentMax)
case m @ Min(expr) =>
val currentMin = AttributeReference("currentMin", expr.dataType, nullable = true)()
val initialValue = Literal.create(null, expr.dataType)
val updateMin = MinOf(currentMin, expr)
AggregateEvaluation(
currentMin :: Nil,
initialValue :: Nil,
updateMin :: Nil,
currentMin)
case CollectHashSet(Seq(expr)) =>
val set =
AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)()
val initialValue = NewSet(expr.dataType)
val addToSet = AddItemToSet(expr, set)
AggregateEvaluation(
set :: Nil,
initialValue :: Nil,
addToSet :: Nil,
set)
case CombineSetsAndCount(inputSet) =>
val elementType = inputSet.dataType.asInstanceOf[OpenHashSetUDT].elementType
val set =
AttributeReference("hashSet", new OpenHashSetUDT(elementType), nullable = false)()
val initialValue = NewSet(elementType)
val collectSets = CombineSets(set, inputSet)
AggregateEvaluation(
set :: Nil,
initialValue :: Nil,
collectSets :: Nil,
CountSet(set))
case o => sys.error(s"$o can't be codegened.")
}
val computationSchema = computeFunctions.flatMap(_.schema)
val resultMap: Map[TreeNodeRef, Expression] =
aggregatesToCompute.zip(computeFunctions).map {
case (agg, func) => new TreeNodeRef(agg) -> func.result
}.toMap
val namedGroups = groupingExpressions.zipWithIndex.map {
case (ne: NamedExpression, _) => (ne, ne.toAttribute)
case (e, i) => (e, Alias(e, s"GroupingExpr$i")().toAttribute)
}
// The set of expressions that produce the final output given the aggregation buffer and the
// grouping expressions.
val resultExpressions = aggregateExpressions.map(_.transform {
case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e))
case e: Expression =>
namedGroups.collectFirst {
case (expr, attr) if expr semanticEquals e => attr
}.getOrElse(e)
})
val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema)
val groupKeySchema: StructType = {
val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) =>
// This is a dummy field name
StructField(idx.toString, expr.dataType, expr.nullable)
}
StructType(fields)
}
val schemaSupportsUnsafe: Boolean = {
UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
UnsafeProjection.canSupport(groupKeySchema)
}
child.execute().mapPartitions { iter =>
// Builds a new custom class for holding the results of aggregation for a group.
val initialValues = computeFunctions.flatMap(_.initialValues)
val newAggregationBuffer = newProjection(initialValues, child.output)
log.info(s"Initial values: ${initialValues.mkString(",")}")
// A projection that computes the group given an input tuple.
val groupProjection = newProjection(groupingExpressions, child.output)
log.info(s"Grouping Projection: ${groupingExpressions.mkString(",")}")
// A projection that is used to update the aggregate values for a group given a new tuple.
// This projection should be targeted at the current values for the group and then applied
// to a joined row of the current values with the new input row.
val updateExpressions = computeFunctions.flatMap(_.update)
val updateSchema = computeFunctions.flatMap(_.schema) ++ child.output
val updateProjection = newMutableProjection(updateExpressions, updateSchema)()
log.info(s"Update Expressions: ${updateExpressions.mkString(",")}")
// A projection that produces the final result, given a computation.
val resultProjectionBuilder =
newMutableProjection(
resultExpressions,
namedGroups.map(_._2) ++ computationSchema)
log.info(s"Result Projection: ${resultExpressions.mkString(",")}")
val joinedRow = new JoinedRow
if (!iter.hasNext) {
// This is an empty input, so return early so that we do not allocate data structures
// that won't be cleaned up (see SPARK-8357).
if (groupingExpressions.isEmpty) {
// This is a global aggregate, so return an empty aggregation buffer.
val resultProjection = resultProjectionBuilder()
Iterator(resultProjection(newAggregationBuffer(EmptyRow)))
} else {
// This is a grouped aggregate, so return an empty iterator.
Iterator[InternalRow]()
}
} else if (groupingExpressions.isEmpty) {
// TODO: Code-generating anything other than the updateProjection is probably overkill.
val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
var currentRow: InternalRow = null
updateProjection.target(buffer)
while (iter.hasNext) {
currentRow = iter.next()
updateProjection(joinedRow(buffer, currentRow))
}
val resultProjection = resultProjectionBuilder()
Iterator(resultProjection(buffer))
} else if (unsafeEnabled && schemaSupportsUnsafe) {
assert(iter.hasNext, "There should be at least one row for this path")
log.info("Using Unsafe-based aggregator")
val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m")
val taskContext = TaskContext.get()
val aggregationMap = new UnsafeFixedWidthAggregationMap(
newAggregationBuffer(EmptyRow),
aggregationBufferSchema,
groupKeySchema,
taskContext.taskMemoryManager(),
SparkEnv.get.shuffleMemoryManager,
1024 * 16, // initial capacity
pageSizeBytes,
false // disable tracking of performance metrics
)
while (iter.hasNext) {
val currentRow: InternalRow = iter.next()
val groupKey: InternalRow = groupProjection(currentRow)
val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey)
if (aggregationBuffer == null) {
throw new IOException("Could not allocate memory to grow aggregation buffer")
}
updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
}
// Record memory used in the process
taskContext.internalMetricsToAccumulators(
InternalAccumulator.PEAK_EXECUTION_MEMORY).add(aggregationMap.getMemoryUsage)
new Iterator[InternalRow] {
private[this] val mapIterator = aggregationMap.iterator()
private[this] val resultProjection = resultProjectionBuilder()
private[this] var _hasNext = mapIterator.next()
def hasNext: Boolean = _hasNext
def next(): InternalRow = {
if (_hasNext) {
val result = resultProjection(joinedRow(mapIterator.getKey, mapIterator.getValue))
_hasNext = mapIterator.next()
if (_hasNext) {
result
} else {
// This is the last element in the iterator, so let's free the buffer. Before we do,
// though, we need to make a defensive copy of the result so that we don't return an
// object that might contain dangling pointers to the freed memory.
val resultCopy = result.copy()
aggregationMap.free()
resultCopy
}
} else {
throw new java.util.NoSuchElementException
}
}
}
} else {
if (unsafeEnabled) {
log.info("Not using Unsafe-based aggregator because it is not supported for this schema")
}
val buffers = new java.util.HashMap[InternalRow, MutableRow]()
var currentRow: InternalRow = null
while (iter.hasNext) {
currentRow = iter.next()
val currentGroup = groupProjection(currentRow)
var currentBuffer = buffers.get(currentGroup)
if (currentBuffer == null) {
currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
buffers.put(currentGroup, currentBuffer)
}
// Target the projection at the current aggregation buffer and then project the updated
// values.
updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow))
}
new Iterator[InternalRow] {
private[this] val resultIterator = buffers.entrySet.iterator()
private[this] val resultProjection = resultProjectionBuilder()
def hasNext: Boolean = resultIterator.hasNext
def next(): InternalRow = {
val currentGroup = resultIterator.next()
resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue))
}
}
}
}
}
}
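
For readers skimming the removed file: a minimal, self-contained sketch (plain Scala, not Spark code) of the buffer/update/result contract that `AggregateEvaluation` captured, using COUNT as in the `Count` case above:

// An aggregate is an initial buffer value, an update over (buffer, input),
// and a result read out of the buffer, which is the shape GeneratedAggregate
// compiled into projections.
case class SimpleAggregate[B, I, R](
  initial: B,
  update: (B, I) => B,
  result: B => R)

// COUNT over nullable input: count only the non-null values.
val count = SimpleAggregate[Long, Option[Int], Long](
  initial = 0L,
  update = (buf, in) => if (in.isDefined) buf + 1L else buf,
  result = identity)

val rows = Seq(Some(1), None, Some(3))
assert(count.result(rows.foldLeft(count.initial)(count.update)) == 2L)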

@@ -136,32 +136,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object HashAggregation extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// Aggregations that can be performed in two phases, before and after the shuffle.
// Cases where all aggregates can be codegened.
case PartialAggregation(
namedGroupingAttributes,
rewrittenAggregateExpressions,
groupingExpressions,
partialComputation,
child)
if canBeCodeGened(
allAggregates(partialComputation) ++
allAggregates(rewrittenAggregateExpressions)) &&
codegenEnabled &&
!canBeConvertedToNewAggregation(plan) =>
execution.GeneratedAggregate(
partial = false,
namedGroupingAttributes,
rewrittenAggregateExpressions,
unsafeEnabled,
execution.GeneratedAggregate(
partial = true,
groupingExpressions,
partialComputation,
unsafeEnabled,
planLater(child))) :: Nil
// Cases where some aggregate can not be codegened
case PartialAggregation(
namedGroupingAttributes,
rewrittenAggregateExpressions,
@@ -192,14 +166,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case _ => false
}
def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall {
case _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true
// The generated set implementation is pretty limited ATM.
case CollectHashSet(exprs) if exprs.size == 1 &&
Seq(IntegerType, LongType).contains(exprs.head.dataType) => true
case _ => false
}
def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] =
exprs.flatMap(_.collect { case a: AggregateExpression1 => a })
}
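
The removed strategy planned an eligible aggregate as two nested GeneratedAggregate operators: a `partial = true` one that aggregates locally before the shuffle, and a `partial = false` one that merges the partial results after it. A minimal sketch of that two-phase shape, with plain Scala collections standing in for partitions and the shuffle:

// Two-phase SUM: aggregate within each partition, then merge by key.
val partitions: Seq[Seq[(String, Long)]] =
  Seq(Seq("a" -> 1L, "b" -> 2L), Seq("a" -> 3L))

// partial = true: local aggregation, no shuffle needed.
val partial: Seq[(String, Long)] = partitions.flatMap { part =>
  part.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).sum }
}

// partial = false: after grouping by key (the "shuffle"), merge partial sums.
val merged: Map[String, Long] =
  partial.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).sum }

assert(merged == Map("a" -> 4L, "b" -> 2L))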

@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.DefaultParserDialect
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.test.SQLTestUtils
@@ -263,7 +262,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
val df = sql(sqlText)
// First, check if we have GeneratedAggregate.
val hasGeneratedAgg = df.queryExecution.executedPlan
- .collect { case _: GeneratedAggregate | _: aggregate.Aggregate => true }
+ .collect { case _: aggregate.Aggregate => true }
.nonEmpty
if (!hasGeneratedAgg) {
fail(
@@ -1603,7 +1602,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123))))
}
test("aggregation with codegen updates peak execution memory") {
ignore("aggregation with codegen updates peak execution memory") {
withSQLConf(
(SQLConf.CODEGEN_ENABLED.key, "true"),
(SQLConf.USE_SQL_AGGREGATE2.key, "false")) {

@@ -1,48 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.test.TestSQLContext
class AggregateSuite extends SparkPlanTest {
test("SPARK-8357 unsafe aggregation path should not leak memory with empty input") {
val codegenDefault = TestSQLContext.getConf(SQLConf.CODEGEN_ENABLED)
val unsafeDefault = TestSQLContext.getConf(SQLConf.UNSAFE_ENABLED)
try {
TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, true)
TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, true)
val df = Seq.empty[(Int, Int)].toDF("a", "b")
checkAnswer(
df,
GeneratedAggregate(
partial = true,
Seq(df.col("b").expr),
Seq(Alias(Count(df.col("a").expr), "cnt")()),
unsafeEnabled = true,
_: SparkPlan),
Seq.empty
)
} finally {
TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault)
TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, unsafeDefault)
}
}
}
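
The deleted suite pinned SPARK-8357 (the unsafe path must not leak memory on empty input) directly to the GeneratedAggregate constructor, which is why it goes away with the operator. A hedged sketch of checking the same observable behavior end-to-end through the DataFrame API only, so it holds for whichever operator the planner now picks:

import org.apache.spark.sql.functions.count
import org.apache.spark.sql.test.TestSQLContext.implicits._

// Empty input, grouped aggregate: must return an empty result rather than
// allocating per-group buffers that never get freed.
val df = Seq.empty[(Int, Int)].toDF("a", "b")
val result = df.groupBy($"b").agg(count($"a").as("cnt")).collect()
assert(result.isEmpty)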