[SPARK-10605][SQL] Create native collect_list/collect_set aggregates

## What changes were proposed in this pull request?
We currently use the Hive implementations of the collect_list/collect_set aggregate functions. This has a few major drawbacks: the use of HiveUDAF (which has quite a bit of overhead) and the lack of support for struct data types. This PR adds native implementations of these functions to Spark.
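
For example, the native implementations can collect struct values directly, which the Hive-backed versions could not. A minimal sketch (assuming a spark-shell session; the column names are illustrative):

```scala
// Assumes a spark-shell / SparkSession context; column names are illustrative.
import org.apache.spark.sql.functions.{collect_list, struct}
import spark.implicits._

val events = Seq((1, "a", 10), (1, "b", 20), (2, "c", 30)).toDF("user", "action", "ts")
events
  .groupBy($"user")
  .agg(collect_list(struct($"action", $"ts")).as("history"))
  .show(false)
```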

The size of the collected list/set may vary, which means we cannot use the fast Tungsten aggregation path and have to fall back to the slower sort-based path. Another big issue with these operators is that when the collected list/set grows too large, we can start experiencing long GC pauses and eventually OutOfMemoryErrors.

The `collect*` aggregates implemented in this PR rely on the sort-based aggregation path for correctness. They maintain their own internal buffer which holds the rows for one group at a time. The sort-based aggregation path is triggered by disabling partial aggregation for these aggregates (which is kinda funny); the same technique is employed in `org.apache.spark.sql.hive.HiveUDAFFunction`.
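
Condensed from the new `collect.scala` in the diff below, the core of that mechanism looks roughly like this (a sketch, not the complete class):

```scala
// Sketch (see the full file in the diff below). The aggregate keeps a plain
// JVM collection as its per-group buffer and opts out of partial aggregation,
// which pushes the planner onto the sort-based aggregation path.
abstract class Collect extends ImperativeAggregate {
  val child: Expression
  protected[this] val buffer: Growable[Any] with Iterable[Any]

  // No partial aggregation: the buffer lives on the heap and is never
  // serialized into an aggregation buffer row, so it cannot be merged.
  override def supportsPartial: Boolean = false
  override def aggBufferAttributes: Seq[AttributeReference] = Nil

  override def update(b: MutableRow, input: InternalRow): Unit = buffer += child.eval(input)
  override def eval(input: InternalRow): Any = new GenericArrayData(buffer.toArray)
}
```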

I have done some performance testing:
```scala
import org.apache.spark.sql.{Dataset, Row}

sql("create function collect_list2 as 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList'")

val df = range(0, 10000000).select($"id", (rand(213123L) * 100000).cast("int").as("grp"))
df.select(countDistinct($"grp")).show

def benchmark(name: String, plan: Dataset[Row], maxItr: Int = 5): Unit = {
  // Trigger planning first so we do not measure it.
  plan.queryExecution.executedPlan

  // Execute the plan a number of times and average the result.
  val start = System.nanoTime
  var i = 0
  while (i < maxItr) {
    plan.rdd.foreach(_ => ())
    i += 1
  }
  val time = (System.nanoTime - start) / (maxItr * 1000000L)
  println(s"[$name] $maxItr iterations completed in an average time of $time ms.")
}

val plan1 = df.groupBy($"grp").agg(collect_list($"id"))
val plan2 = df.groupBy($"grp").agg(callUDF("collect_list2", $"id"))

benchmark("Spark collect_list", plan1)
...
> [Spark collect_list] 5 iterations completed in an average time of 3371 ms.

benchmark("Hive collect_list", plan2)
...
> [Hive collect_list] 5 iterations completed in an average time of 9109 ms.
```
Performance is improved by a factor of roughly 2-3.

## How was this patch tested?
Added tests to `DataFrameAggregateSuite`.

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #12874 from hvanhovell/implode.
Authored by Herman van Hovell on 2016-05-12 13:56:00 -07:00; committed by Reynold Xin.
Commit: bb1362eb3b (parent: a57aadae84)
6 changed files with 149 additions and 37 deletions


@@ -252,6 +252,8 @@ object FunctionRegistry {
expression[VarianceSamp]("variance"),
expression[VariancePop]("var_pop"),
expression[VarianceSamp]("var_samp"),
expression[CollectList]("collect_list"),
expression[CollectSet]("collect_set"),
// string functions
expression[Ascii]("ascii"),
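
With these registry entries in place, the functions also resolve from plain SQL without going through the Hive function catalog. A quick sanity check (assuming a `SparkSession` named `spark`; table and column names are illustrative):

```scala
// Assumes a SparkSession named `spark`; table and column names are illustrative.
spark.range(0, 100)
  .selectExpr("id % 10 AS grp", "id")
  .createOrReplaceTempView("t")

spark.sql("SELECT grp, collect_set(id) AS ids FROM t GROUP BY grp").show(false)
```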


@@ -0,0 +1,119 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.aggregate

import scala.collection.generic.Growable
import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._

/**
* The Collect aggregate function collects all seen expression values into a list of values.
*
* The operator is bound to the slower sort based aggregation path because the number of
* elements (and their memory usage) can not be determined in advance. This also means that the
* collected elements are stored on heap, and that too many elements can cause GC pauses and
* eventually Out of Memory Errors.
*/
abstract class Collect extends ImperativeAggregate {

  val child: Expression

  override def children: Seq[Expression] = child :: Nil

  override def nullable: Boolean = true

  override def dataType: DataType = ArrayType(child.dataType)

  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)

  override def supportsPartial: Boolean = false

  override def aggBufferAttributes: Seq[AttributeReference] = Nil

  override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)

  override def inputAggBufferAttributes: Seq[AttributeReference] = Nil

  protected[this] val buffer: Growable[Any] with Iterable[Any]

  override def initialize(b: MutableRow): Unit = {
    buffer.clear()
  }

  override def update(b: MutableRow, input: InternalRow): Unit = {
    buffer += child.eval(input)
  }

  override def merge(buffer: MutableRow, input: InternalRow): Unit = {
    sys.error("Collect cannot be used in partial aggregations.")
  }

  override def eval(input: InternalRow): Any = {
    new GenericArrayData(buffer.toArray)
  }
}
/**
 * Collect a list of elements.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr) - Collects and returns a list of non-unique elements.")
case class CollectList(
    child: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect {

  def this(child: Expression) = this(child, 0, 0)

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def prettyName: String = "collect_list"

  override protected[this] val buffer: mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
}
/**
 * Collect a list of unique elements.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr) - Collects and returns a set of unique elements.")
case class CollectSet(
    child: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect {

  def this(child: Expression) = this(child, 0, 0)

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def prettyName: String = "collect_set"

  override protected[this] val buffer: mutable.HashSet[Any] = mutable.HashSet.empty
}


@@ -195,18 +195,14 @@ object functions {
/**
* Aggregate function: returns a list of objects with duplicates.
*
* For now this is an alias for the collect_list Hive UDAF.
*
* @group agg_funcs
* @since 1.6.0
*/
def collect_list(e: Column): Column = callUDF("collect_list", e)
def collect_list(e: Column): Column = withAggregateFunction { CollectList(e.expr) }
/**
* Aggregate function: returns a list of objects with duplicates.
*
* For now this is an alias for the collect_list Hive UDAF.
*
* @group agg_funcs
* @since 1.6.0
*/
@@ -215,18 +211,14 @@ object functions {
/**
* Aggregate function: returns a set of objects with duplicate elements eliminated.
*
* For now this is an alias for the collect_set Hive UDAF.
*
* @group agg_funcs
* @since 1.6.0
*/
def collect_set(e: Column): Column = callUDF("collect_set", e)
def collect_set(e: Column): Column = withAggregateFunction { CollectSet(e.expr) }
/**
* Aggregate function: returns a set of objects with duplicate elements eliminated.
*
* For now this is an alias for the collect_set Hive UDAF.
*
* @group agg_funcs
* @since 1.6.0
*/


@@ -431,6 +431,32 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
Row(null, null, null, null, null))
}
test("collect functions") {
val df = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b")
checkAnswer(
df.select(collect_list($"a"), collect_list($"b")),
Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4)))
)
checkAnswer(
df.select(collect_set($"a"), collect_set($"b")),
Seq(Row(Seq(1, 2, 3), Seq(2, 4)))
)
}
test("collect functions structs") {
val df = Seq((1, 2, 2), (2, 2, 2), (3, 4, 1))
.toDF("a", "x", "y")
.select($"a", struct($"x", $"y").as("b"))
checkAnswer(
df.select(collect_list($"a"), sort_array(collect_list($"b"))),
Seq(Row(Seq(1, 2, 3), Seq(Row(2, 2), Row(2, 2), Row(4, 1))))
)
checkAnswer(
df.select(collect_set($"a"), sort_array(collect_set($"b"))),
Seq(Row(Seq(1, 2, 3), Seq(Row(2, 2), Row(4, 1))))
)
}
test("SPARK-14664: Decimal sum/avg over window should work.") {
checkAnswer(
spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),


@@ -222,20 +222,4 @@ private[sql] class HiveSessionCatalog(
}
}
}
// Pre-load a few commonly used Hive built-in functions.
HiveSessionCatalog.preloadedHiveBuiltinFunctions.foreach {
case (functionName, clazz) =>
val builder = makeFunctionBuilder(functionName, clazz)
val info = new ExpressionInfo(clazz.getCanonicalName, functionName)
createTempFunction(functionName, info, builder, ignoreIfExists = false)
}
}
private[sql] object HiveSessionCatalog {
// This is the list of Hive's built-in functions that are commonly used and we want to
// pre-load when we create the FunctionRegistry.
val preloadedHiveBuiltinFunctions =
("collect_set", classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectSet]) ::
("collect_list", classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList]) :: Nil
}


@@ -58,17 +58,6 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with
)
}
test("collect functions") {
checkAnswer(
testData.select(collect_list($"a"), collect_list($"b")),
Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4)))
)
checkAnswer(
testData.select(collect_set($"a"), collect_set($"b")),
Seq(Row(Seq(1, 2, 3), Seq(2, 4)))
)
}
test("cube") {
checkAnswer(
testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),