[SPARK-9674][SQL] Remove GeneratedAggregate.
The new aggregate replaces the old GeneratedAggregate. Author: Reynold Xin <rxin@databricks.com> Closes #7983 from rxin/remove-generated-agg and squashes the following commits: 8334aae [Reynold Xin] [SPARK-9674][SQL] Remove GeneratedAggregate.
This commit is contained in:
parent
119b590538
commit
9270bd06fd
|
@ -1,352 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution
|
||||
|
||||
import java.io.IOException
|
||||
|
||||
import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext}
|
||||
import org.apache.spark.annotation.DeveloperApi
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.physical._
|
||||
import org.apache.spark.sql.catalyst.trees._
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
case class AggregateEvaluation(
|
||||
schema: Seq[Attribute],
|
||||
initialValues: Seq[Expression],
|
||||
update: Seq[Expression],
|
||||
result: Expression)
|
||||
|
||||
/**
|
||||
* :: DeveloperApi ::
|
||||
* Alternate version of aggregation that leverages projection and thus code generation.
|
||||
* Aggregations are converted into a set of projections from a aggregation buffer tuple back onto
|
||||
* itself. Currently only used for simple aggregations like SUM, COUNT, or AVERAGE are supported.
|
||||
*
|
||||
* @param partial if true then aggregation is done partially on local data without shuffling to
|
||||
* ensure all values where `groupingExpressions` are equal are present.
|
||||
* @param groupingExpressions expressions that are evaluated to determine grouping.
|
||||
* @param aggregateExpressions expressions that are computed for each group.
|
||||
* @param unsafeEnabled whether to allow Unsafe-based aggregation buffers to be used.
|
||||
* @param child the input data source.
|
||||
*/
|
||||
@DeveloperApi
|
||||
case class GeneratedAggregate(
|
||||
partial: Boolean,
|
||||
groupingExpressions: Seq[Expression],
|
||||
aggregateExpressions: Seq[NamedExpression],
|
||||
unsafeEnabled: Boolean,
|
||||
child: SparkPlan)
|
||||
extends UnaryNode {
|
||||
|
||||
override def requiredChildDistribution: Seq[Distribution] =
|
||||
if (partial) {
|
||||
UnspecifiedDistribution :: Nil
|
||||
} else {
|
||||
if (groupingExpressions == Nil) {
|
||||
AllTuples :: Nil
|
||||
} else {
|
||||
ClusteredDistribution(groupingExpressions) :: Nil
|
||||
}
|
||||
}
|
||||
|
||||
override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
|
||||
|
||||
protected override def doExecute(): RDD[InternalRow] = {
|
||||
val aggregatesToCompute = aggregateExpressions.flatMap { a =>
|
||||
a.collect { case agg: AggregateExpression1 => agg}
|
||||
}
|
||||
|
||||
// If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite
|
||||
// (in test "aggregation with codegen").
|
||||
val computeFunctions = aggregatesToCompute.map {
|
||||
case c @ Count(expr) =>
|
||||
// If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
|
||||
// UnscaledValue will be null if and only if x is null; helps with Average on decimals
|
||||
val toCount = expr match {
|
||||
case UnscaledValue(e) => e
|
||||
case _ => expr
|
||||
}
|
||||
val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
|
||||
val initialValue = Literal(0L)
|
||||
val updateFunction = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount)
|
||||
val result = currentCount
|
||||
|
||||
AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
|
||||
|
||||
case s @ Sum(expr) =>
|
||||
val calcType =
|
||||
expr.dataType match {
|
||||
case DecimalType.Fixed(p, s) =>
|
||||
DecimalType.bounded(p + 10, s)
|
||||
case _ =>
|
||||
expr.dataType
|
||||
}
|
||||
|
||||
val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
|
||||
val initialValue = Literal.create(null, calcType)
|
||||
|
||||
// Coalesce avoids double calculation...
|
||||
// but really, common sub expression elimination would be better....
|
||||
val zero = Cast(Literal(0), calcType)
|
||||
val updateFunction = Coalesce(
|
||||
Add(
|
||||
Coalesce(currentSum :: zero :: Nil),
|
||||
Cast(expr, calcType)
|
||||
) :: currentSum :: Nil)
|
||||
val result =
|
||||
expr.dataType match {
|
||||
case DecimalType.Fixed(_, _) =>
|
||||
Cast(currentSum, s.dataType)
|
||||
case _ => currentSum
|
||||
}
|
||||
|
||||
AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
|
||||
|
||||
case m @ Max(expr) =>
|
||||
val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)()
|
||||
val initialValue = Literal.create(null, expr.dataType)
|
||||
val updateMax = MaxOf(currentMax, expr)
|
||||
|
||||
AggregateEvaluation(
|
||||
currentMax :: Nil,
|
||||
initialValue :: Nil,
|
||||
updateMax :: Nil,
|
||||
currentMax)
|
||||
|
||||
case m @ Min(expr) =>
|
||||
val currentMin = AttributeReference("currentMin", expr.dataType, nullable = true)()
|
||||
val initialValue = Literal.create(null, expr.dataType)
|
||||
val updateMin = MinOf(currentMin, expr)
|
||||
|
||||
AggregateEvaluation(
|
||||
currentMin :: Nil,
|
||||
initialValue :: Nil,
|
||||
updateMin :: Nil,
|
||||
currentMin)
|
||||
|
||||
case CollectHashSet(Seq(expr)) =>
|
||||
val set =
|
||||
AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)()
|
||||
val initialValue = NewSet(expr.dataType)
|
||||
val addToSet = AddItemToSet(expr, set)
|
||||
|
||||
AggregateEvaluation(
|
||||
set :: Nil,
|
||||
initialValue :: Nil,
|
||||
addToSet :: Nil,
|
||||
set)
|
||||
|
||||
case CombineSetsAndCount(inputSet) =>
|
||||
val elementType = inputSet.dataType.asInstanceOf[OpenHashSetUDT].elementType
|
||||
val set =
|
||||
AttributeReference("hashSet", new OpenHashSetUDT(elementType), nullable = false)()
|
||||
val initialValue = NewSet(elementType)
|
||||
val collectSets = CombineSets(set, inputSet)
|
||||
|
||||
AggregateEvaluation(
|
||||
set :: Nil,
|
||||
initialValue :: Nil,
|
||||
collectSets :: Nil,
|
||||
CountSet(set))
|
||||
|
||||
case o => sys.error(s"$o can't be codegened.")
|
||||
}
|
||||
|
||||
val computationSchema = computeFunctions.flatMap(_.schema)
|
||||
|
||||
val resultMap: Map[TreeNodeRef, Expression] =
|
||||
aggregatesToCompute.zip(computeFunctions).map {
|
||||
case (agg, func) => new TreeNodeRef(agg) -> func.result
|
||||
}.toMap
|
||||
|
||||
val namedGroups = groupingExpressions.zipWithIndex.map {
|
||||
case (ne: NamedExpression, _) => (ne, ne.toAttribute)
|
||||
case (e, i) => (e, Alias(e, s"GroupingExpr$i")().toAttribute)
|
||||
}
|
||||
|
||||
// The set of expressions that produce the final output given the aggregation buffer and the
|
||||
// grouping expressions.
|
||||
val resultExpressions = aggregateExpressions.map(_.transform {
|
||||
case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e))
|
||||
case e: Expression =>
|
||||
namedGroups.collectFirst {
|
||||
case (expr, attr) if expr semanticEquals e => attr
|
||||
}.getOrElse(e)
|
||||
})
|
||||
|
||||
val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema)
|
||||
|
||||
val groupKeySchema: StructType = {
|
||||
val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) =>
|
||||
// This is a dummy field name
|
||||
StructField(idx.toString, expr.dataType, expr.nullable)
|
||||
}
|
||||
StructType(fields)
|
||||
}
|
||||
|
||||
val schemaSupportsUnsafe: Boolean = {
|
||||
UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
|
||||
UnsafeProjection.canSupport(groupKeySchema)
|
||||
}
|
||||
|
||||
child.execute().mapPartitions { iter =>
|
||||
// Builds a new custom class for holding the results of aggregation for a group.
|
||||
val initialValues = computeFunctions.flatMap(_.initialValues)
|
||||
val newAggregationBuffer = newProjection(initialValues, child.output)
|
||||
log.info(s"Initial values: ${initialValues.mkString(",")}")
|
||||
|
||||
// A projection that computes the group given an input tuple.
|
||||
val groupProjection = newProjection(groupingExpressions, child.output)
|
||||
log.info(s"Grouping Projection: ${groupingExpressions.mkString(",")}")
|
||||
|
||||
// A projection that is used to update the aggregate values for a group given a new tuple.
|
||||
// This projection should be targeted at the current values for the group and then applied
|
||||
// to a joined row of the current values with the new input row.
|
||||
val updateExpressions = computeFunctions.flatMap(_.update)
|
||||
val updateSchema = computeFunctions.flatMap(_.schema) ++ child.output
|
||||
val updateProjection = newMutableProjection(updateExpressions, updateSchema)()
|
||||
log.info(s"Update Expressions: ${updateExpressions.mkString(",")}")
|
||||
|
||||
// A projection that produces the final result, given a computation.
|
||||
val resultProjectionBuilder =
|
||||
newMutableProjection(
|
||||
resultExpressions,
|
||||
namedGroups.map(_._2) ++ computationSchema)
|
||||
log.info(s"Result Projection: ${resultExpressions.mkString(",")}")
|
||||
|
||||
val joinedRow = new JoinedRow
|
||||
|
||||
if (!iter.hasNext) {
|
||||
// This is an empty input, so return early so that we do not allocate data structures
|
||||
// that won't be cleaned up (see SPARK-8357).
|
||||
if (groupingExpressions.isEmpty) {
|
||||
// This is a global aggregate, so return an empty aggregation buffer.
|
||||
val resultProjection = resultProjectionBuilder()
|
||||
Iterator(resultProjection(newAggregationBuffer(EmptyRow)))
|
||||
} else {
|
||||
// This is a grouped aggregate, so return an empty iterator.
|
||||
Iterator[InternalRow]()
|
||||
}
|
||||
} else if (groupingExpressions.isEmpty) {
|
||||
// TODO: Codegening anything other than the updateProjection is probably over kill.
|
||||
val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
|
||||
var currentRow: InternalRow = null
|
||||
updateProjection.target(buffer)
|
||||
|
||||
while (iter.hasNext) {
|
||||
currentRow = iter.next()
|
||||
updateProjection(joinedRow(buffer, currentRow))
|
||||
}
|
||||
|
||||
val resultProjection = resultProjectionBuilder()
|
||||
Iterator(resultProjection(buffer))
|
||||
|
||||
} else if (unsafeEnabled && schemaSupportsUnsafe) {
|
||||
assert(iter.hasNext, "There should be at least one row for this path")
|
||||
log.info("Using Unsafe-based aggregator")
|
||||
val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m")
|
||||
val taskContext = TaskContext.get()
|
||||
val aggregationMap = new UnsafeFixedWidthAggregationMap(
|
||||
newAggregationBuffer(EmptyRow),
|
||||
aggregationBufferSchema,
|
||||
groupKeySchema,
|
||||
taskContext.taskMemoryManager(),
|
||||
SparkEnv.get.shuffleMemoryManager,
|
||||
1024 * 16, // initial capacity
|
||||
pageSizeBytes,
|
||||
false // disable tracking of performance metrics
|
||||
)
|
||||
|
||||
while (iter.hasNext) {
|
||||
val currentRow: InternalRow = iter.next()
|
||||
val groupKey: InternalRow = groupProjection(currentRow)
|
||||
val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey)
|
||||
if (aggregationBuffer == null) {
|
||||
throw new IOException("Could not allocate memory to grow aggregation buffer")
|
||||
}
|
||||
updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
|
||||
}
|
||||
|
||||
// Record memory used in the process
|
||||
taskContext.internalMetricsToAccumulators(
|
||||
InternalAccumulator.PEAK_EXECUTION_MEMORY).add(aggregationMap.getMemoryUsage)
|
||||
|
||||
new Iterator[InternalRow] {
|
||||
private[this] val mapIterator = aggregationMap.iterator()
|
||||
private[this] val resultProjection = resultProjectionBuilder()
|
||||
private[this] var _hasNext = mapIterator.next()
|
||||
|
||||
def hasNext: Boolean = _hasNext
|
||||
|
||||
def next(): InternalRow = {
|
||||
if (_hasNext) {
|
||||
val result = resultProjection(joinedRow(mapIterator.getKey, mapIterator.getValue))
|
||||
_hasNext = mapIterator.next()
|
||||
if (_hasNext) {
|
||||
result
|
||||
} else {
|
||||
// This is the last element in the iterator, so let's free the buffer. Before we do,
|
||||
// though, we need to make a defensive copy of the result so that we don't return an
|
||||
// object that might contain dangling pointers to the freed memory.
|
||||
val resultCopy = result.copy()
|
||||
aggregationMap.free()
|
||||
resultCopy
|
||||
}
|
||||
} else {
|
||||
throw new java.util.NoSuchElementException
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (unsafeEnabled) {
|
||||
log.info("Not using Unsafe-based aggregator because it is not supported for this schema")
|
||||
}
|
||||
val buffers = new java.util.HashMap[InternalRow, MutableRow]()
|
||||
|
||||
var currentRow: InternalRow = null
|
||||
while (iter.hasNext) {
|
||||
currentRow = iter.next()
|
||||
val currentGroup = groupProjection(currentRow)
|
||||
var currentBuffer = buffers.get(currentGroup)
|
||||
if (currentBuffer == null) {
|
||||
currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
|
||||
buffers.put(currentGroup, currentBuffer)
|
||||
}
|
||||
// Target the projection at the current aggregation buffer and then project the updated
|
||||
// values.
|
||||
updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow))
|
||||
}
|
||||
|
||||
new Iterator[InternalRow] {
|
||||
private[this] val resultIterator = buffers.entrySet.iterator()
|
||||
private[this] val resultProjection = resultProjectionBuilder()
|
||||
|
||||
def hasNext: Boolean = resultIterator.hasNext
|
||||
|
||||
def next(): InternalRow = {
|
||||
val currentGroup = resultIterator.next()
|
||||
resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -136,32 +136,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
object HashAggregation extends Strategy {
|
||||
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
|
||||
// Aggregations that can be performed in two phases, before and after the shuffle.
|
||||
|
||||
// Cases where all aggregates can be codegened.
|
||||
case PartialAggregation(
|
||||
namedGroupingAttributes,
|
||||
rewrittenAggregateExpressions,
|
||||
groupingExpressions,
|
||||
partialComputation,
|
||||
child)
|
||||
if canBeCodeGened(
|
||||
allAggregates(partialComputation) ++
|
||||
allAggregates(rewrittenAggregateExpressions)) &&
|
||||
codegenEnabled &&
|
||||
!canBeConvertedToNewAggregation(plan) =>
|
||||
execution.GeneratedAggregate(
|
||||
partial = false,
|
||||
namedGroupingAttributes,
|
||||
rewrittenAggregateExpressions,
|
||||
unsafeEnabled,
|
||||
execution.GeneratedAggregate(
|
||||
partial = true,
|
||||
groupingExpressions,
|
||||
partialComputation,
|
||||
unsafeEnabled,
|
||||
planLater(child))) :: Nil
|
||||
|
||||
// Cases where some aggregate can not be codegened
|
||||
case PartialAggregation(
|
||||
namedGroupingAttributes,
|
||||
rewrittenAggregateExpressions,
|
||||
|
@ -192,14 +166,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
case _ => false
|
||||
}
|
||||
|
||||
def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall {
|
||||
case _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true
|
||||
// The generated set implementation is pretty limited ATM.
|
||||
case CollectHashSet(exprs) if exprs.size == 1 &&
|
||||
Seq(IntegerType, LongType).contains(exprs.head.dataType) => true
|
||||
case _ => false
|
||||
}
|
||||
|
||||
def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] =
|
||||
exprs.flatMap(_.collect { case a: AggregateExpression1 => a })
|
||||
}
|
||||
|
|
|
@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
|
|||
import org.apache.spark.sql.catalyst.DefaultParserDialect
|
||||
import org.apache.spark.sql.catalyst.errors.DialectException
|
||||
import org.apache.spark.sql.execution.aggregate
|
||||
import org.apache.spark.sql.execution.GeneratedAggregate
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.TestData._
|
||||
import org.apache.spark.sql.test.SQLTestUtils
|
||||
|
@ -263,7 +262,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
|
|||
val df = sql(sqlText)
|
||||
// First, check if we have GeneratedAggregate.
|
||||
val hasGeneratedAgg = df.queryExecution.executedPlan
|
||||
.collect { case _: GeneratedAggregate | _: aggregate.Aggregate => true }
|
||||
.collect { case _: aggregate.Aggregate => true }
|
||||
.nonEmpty
|
||||
if (!hasGeneratedAgg) {
|
||||
fail(
|
||||
|
@ -1603,7 +1602,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
|
|||
Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123))))
|
||||
}
|
||||
|
||||
test("aggregation with codegen updates peak execution memory") {
|
||||
ignore("aggregation with codegen updates peak execution memory") {
|
||||
withSQLConf(
|
||||
(SQLConf.CODEGEN_ENABLED.key, "true"),
|
||||
(SQLConf.USE_SQL_AGGREGATE2.key, "false")) {
|
||||
|
|
|
@ -1,48 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution
|
||||
|
||||
import org.apache.spark.sql.SQLConf
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.test.TestSQLContext
|
||||
|
||||
class AggregateSuite extends SparkPlanTest {
|
||||
|
||||
test("SPARK-8357 unsafe aggregation path should not leak memory with empty input") {
|
||||
val codegenDefault = TestSQLContext.getConf(SQLConf.CODEGEN_ENABLED)
|
||||
val unsafeDefault = TestSQLContext.getConf(SQLConf.UNSAFE_ENABLED)
|
||||
try {
|
||||
TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, true)
|
||||
TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, true)
|
||||
val df = Seq.empty[(Int, Int)].toDF("a", "b")
|
||||
checkAnswer(
|
||||
df,
|
||||
GeneratedAggregate(
|
||||
partial = true,
|
||||
Seq(df.col("b").expr),
|
||||
Seq(Alias(Count(df.col("a").expr), "cnt")()),
|
||||
unsafeEnabled = true,
|
||||
_: SparkPlan),
|
||||
Seq.empty
|
||||
)
|
||||
} finally {
|
||||
TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault)
|
||||
TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, unsafeDefault)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue