[SQL] SPARK-1371 Hash Aggregation Improvements
Given:

```scala
case class Data(a: Int, b: Int)
val rdd = sparkContext
  .parallelize(1 to 200)
  .flatMap(_ => (1 to 50000).map(i => Data(i % 100, i)))
rdd.registerAsTable("data")
cacheTable("data")
```

Before:

```
SELECT COUNT(*) FROM data:[10000000]    16795.567ms
SELECT a, SUM(b) FROM data GROUP BY a    7536.436ms
SELECT SUM(b) FROM data                 10954.1ms
```

After:

```
SELECT COUNT(*) FROM data:[10000000]     1372.175ms
SELECT a, SUM(b) FROM data GROUP BY a    2070.446ms
SELECT SUM(b) FROM data                   958.969ms
```

Author: Michael Armbrust <michael@databricks.com>

Closes #295 from marmbrus/hashAgg and squashes the following commits:

ec63575 [Michael Armbrust] Add comment.
d0495a9 [Michael Armbrust] Use scaladoc instead.
b4a6887 [Michael Armbrust] Address review comments.
a2d90ba [Michael Armbrust] Capture child output statically to avoid issues with generators and serialization.
7c13112 [Michael Armbrust] Rewrite Aggregate operator to stream input and use projections. Remove unused local RDD functions implicits.
5096f99 [Michael Armbrust] Make HiveUDAF fields transient since object inspectors are not serializable.
6a4b671 [Michael Armbrust] Add option to avoid binding operators expressions automatically.
92cca08 [Michael Armbrust] Always include serialization debug info when running tests.
1279df2 [Michael Armbrust] Increase default number of partitions.
parent 87d0928a33
commit accd0999f9
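At its core the patch replaces a materialize-then-fold aggregation (`groupByKeyLocally` building a collection of rows per group) with streaming hash aggregation: one array of mutable aggregate buffers per group key, updated as rows arrive. A minimal, self-contained sketch of that pattern — toy types, not the actual operator code:

```scala
import scala.collection.mutable

// Toy stand-in for an AggregateFunction buffer: a running sum.
final class SumBuffer {
  var sum: Long = 0L
  def update(value: Int): Unit = sum += value
}

// Stream the rows once, updating the buffer for each row's group key; results
// are emitted only after the input is exhausted. No per-group collection of
// rows is ever materialized.
def hashAggregate(rows: Iterator[(Int, Int)]): Iterator[(Int, Long)] = {
  val hashTable = mutable.HashMap.empty[Int, SumBuffer]
  for ((key, value) <- rows)
    hashTable.getOrElseUpdate(key, new SumBuffer).update(value)
  hashTable.iterator.map { case (key, buf) => (key, buf.sum) }
}

// hashAggregate(Iterator((1, 2), (1, 3), (2, 5))).toMap == Map(1 -> 5L, 2 -> 5L)
```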
In `SparkBuild` (test JVM options):

```diff
@@ -178,6 +178,7 @@ object SparkBuild extends Build {
     fork := true,
     javaOptions in Test += "-Dspark.home=" + sparkHome,
     javaOptions in Test += "-Dspark.testing=1",
+    javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true",
     javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark").map { case (k,v) => s"-D$k=$v" }.toSeq,
     javaOptions += "-Xmx3g",
     // Show full stack trace and duration in test cases.
```
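`sun.io.serialization.extendedDebugInfo` is a standard JVM flag: when Java serialization fails, the `NotSerializableException` message includes the chain of fields that led to the offending object instead of just its class. Turning it on for every test run pairs naturally with the serialization fixes in this patch (the `@transient` Hive UDAF fields below).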
In `BoundReference`/`BindReferences` (catalyst):

```diff
@@ -48,11 +48,17 @@ case class BoundReference(ordinal: Int, baseReference: Attribute)
   override def apply(input: Row): Any = input(ordinal)
 }
 
+/**
+ * Used to denote operators that do their own binding of attributes internally.
+ */
+trait NoBind { self: trees.TreeNode[_] => }
+
 class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] {
   import BindReferences._
 
   def apply(plan: TreeNode): TreeNode = {
     plan.transform {
+      case n: NoBind => n.asInstanceOf[TreeNode]
       case leafNode if leafNode.children.isEmpty => leafNode
       case unaryNode if unaryNode.children.size == 1 => unaryNode.transformExpressions { case e =>
         bindReference(e, unaryNode.children.head.output)
```
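For context, "binding" rewrites by-name attribute references into positional `BoundReference`s so that evaluating an expression is an array index rather than a name lookup; the new `NoBind` marker lets an operator opt out of the automatic rule and call `BindReferences.bindReference` itself against a schema of its own choosing, as the rewritten `Aggregate` does. A simplified, hypothetical model of the idea (not the catalyst API):

```scala
// Hypothetical mini expression language illustrating attribute binding.
sealed trait Expr { def eval(input: Array[Any]): Any }

// A by-name reference: cannot be evaluated until it is bound.
case class Named(name: String) extends Expr {
  def eval(input: Array[Any]) = sys.error(s"unbound reference: $name")
}

// A bound reference: evaluation is a plain array access by ordinal.
case class Bound(ordinal: Int) extends Expr {
  def eval(input: Array[Any]) = input(ordinal)
}

// Resolve each name to its position in the input schema.
def bind(expr: Expr, schema: Seq[String]): Expr = expr match {
  case Named(n) => Bound(schema.indexOf(n)) // assumes the name is in the schema
  case bound    => bound
}

// bind(Named("b"), Seq("a", "b")).eval(Array(1, 2)) == 2
```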
In `Projection` and `MutableProjection`:

```diff
@@ -28,9 +28,9 @@ class Projection(expressions: Seq[Expression]) extends (Row => Row) {
 
   protected val exprArray = expressions.toArray
   def apply(input: Row): Row = {
-    val outputArray = new Array[Any](exprArray.size)
+    val outputArray = new Array[Any](exprArray.length)
     var i = 0
-    while (i < exprArray.size) {
+    while (i < exprArray.length) {
       outputArray(i) = exprArray(i).apply(input)
       i += 1
     }
@@ -57,7 +57,7 @@ case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row)
 
   def apply(input: Row): Row = {
     var i = 0
-    while (i < exprArray.size) {
+    while (i < exprArray.length) {
       mutableRow(i) = exprArray(i).apply(input)
       i += 1
     }
```
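Two small hot-loop tweaks here: `exprArray.size` becomes `exprArray.length` (on an `Array`, `.size` goes through the implicit `ArrayOps` enrichment while `.length` is the native field), and the loops stay `while` loops for the same reason. `MutableProjection` additionally writes into one reused row instead of allocating per input row; a hypothetical sketch of that trade-off:

```scala
// Sketch (toy types) of the mutable-projection trade-off: reuse one output
// buffer across rows instead of allocating a fresh one per row. Any caller
// that retains a result (e.g. as a hash-table key) must copy it first.
final class ReusingProjection(fns: Array[Int => Int]) extends (Int => Array[Int]) {
  private[this] val out = new Array[Int](fns.length) // shared across calls
  def apply(input: Int): Array[Int] = {
    var i = 0
    while (i < fns.length) { // while + .length: no iterator, no enrichment
      out(i) = fns(i)(input)
      i += 1
    }
    out // the same array every call; valid only until the next apply
  }
}

val project = new ReusingProjection(Array(_ + 1, _ * 2))
project(3).toList // List(4, 6) -- but project(4) would overwrite the buffer
```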
In the aggregate expressions, where `newInstance` becomes an empty-paren method on every aggregate:

```diff
@@ -27,7 +27,7 @@ abstract class AggregateExpression extends Expression {
    * Creates a new instance that can be used to compute this aggregate expression for a group
    * of input rows/
    */
-  def newInstance: AggregateFunction
+  def newInstance(): AggregateFunction
 }
 
 /**
@@ -75,7 +75,7 @@ abstract class AggregateFunction
   override def apply(input: Row): Any
 
   // Do we really need this?
-  def newInstance = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
+  def newInstance() = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
 }
 
 case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
@@ -89,7 +89,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
     SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil)
   }
 
-  override def newInstance = new CountFunction(child, this)
+  override def newInstance() = new CountFunction(child, this)
 }
 
 case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpression {
@@ -98,7 +98,7 @@ case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpressi
   def nullable = false
   def dataType = IntegerType
   override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")}})"
-  override def newInstance = new CountDistinctFunction(expressions, this)
+  override def newInstance() = new CountDistinctFunction(expressions, this)
 }
 
 case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
@@ -118,7 +118,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
       partialCount :: partialSum :: Nil)
   }
 
-  override def newInstance = new AverageFunction(child, this)
+  override def newInstance() = new AverageFunction(child, this)
 }
 
 case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
@@ -134,7 +134,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
       partialSum :: Nil)
   }
 
-  override def newInstance = new SumFunction(child, this)
+  override def newInstance() = new SumFunction(child, this)
 }
 
 case class SumDistinct(child: Expression)
@@ -145,7 +145,7 @@ case class SumDistinct(child: Expression)
   def dataType = child.dataType
   override def toString = s"SUM(DISTINCT $child)"
 
-  override def newInstance = new SumDistinctFunction(child, this)
+  override def newInstance() = new SumDistinctFunction(child, this)
 }
 
 case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
@@ -160,7 +160,7 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNod
       First(partialFirst.toAttribute),
       partialFirst :: Nil)
   }
-  override def newInstance = new FirstFunction(child, this)
+  override def newInstance() = new FirstFunction(child, this)
 }
 
 case class AverageFunction(expr: Expression, base: AggregateExpression)
```
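The mechanical change running through this file — `def newInstance` becoming `def newInstance()` — follows the Scala convention that a method with side effects takes empty parentheses: every call allocates a fresh mutable `AggregateFunction` buffer, and the parens make that explicit at call sites such as `computedAggregates(i).aggregate.newInstance()` in the new operator.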
Deleted in full — the partition-local RDD helpers, unused once `Aggregate` stops calling `groupByKeyLocally`:

```diff
@@ -1,100 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import scala.language.implicitConversions
-
-import scala.reflect._
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.{Aggregator, InterruptibleIterator, Logging}
-import org.apache.spark.util.collection.AppendOnlyMap
-
-/* Implicit conversions */
-import org.apache.spark.SparkContext._
-
-/**
- * Extra functions on RDDs that perform only local operations.  These can be used when data has
- * already been partitioned correctly.
- */
-private[spark] class PartitionLocalRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
-  extends Logging
-  with Serializable {
-
-  /**
-   * Cogroup corresponding partitions of `this` and `other`. These two RDDs should have
-   * the same number of partitions. Partitions of these two RDDs are cogrouped
-   * according to the indexes of partitions. If we have two RDDs and
-   * each of them has n partitions, we will cogroup the partition i from `this`
-   * with the partition i from `other`.
-   * This function will not introduce a shuffling operation.
-   */
-  def cogroupLocally[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = {
-    val cg = self.zipPartitions(other)((iter1:Iterator[(K, V)], iter2:Iterator[(K, W)]) => {
-      val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]]
-
-      val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => {
-        if (hadVal) oldVal else Array.fill(2)(new ArrayBuffer[Any])
-      }
-
-      val getSeq = (k: K) => {
-        map.changeValue(k, update)
-      }
-
-      iter1.foreach { kv => getSeq(kv._1)(0) += kv._2 }
-      iter2.foreach { kv => getSeq(kv._1)(1) += kv._2 }
-
-      map.iterator
-    }).mapValues { case Seq(vs, ws) => (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]])}
-
-    cg
-  }
-
-  /**
-   * Group the values for each key within a partition of the RDD into a single sequence.
-   * This function will not introduce a shuffling operation.
-   */
-  def groupByKeyLocally(): RDD[(K, Seq[V])] = {
-    def createCombiner(v: V) = ArrayBuffer(v)
-    def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
-    val aggregator = new Aggregator[K, V, ArrayBuffer[V]](createCombiner, mergeValue, _ ++ _)
-    val bufs = self.mapPartitionsWithContext((context, iter) => {
-      new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
-    }, preservesPartitioning = true)
-    bufs.asInstanceOf[RDD[(K, Seq[V])]]
-  }
-
-  /**
-   * Join corresponding partitions of `this` and `other`.
-   * If we have two RDDs and each of them has n partitions,
-   * we will join the partition i from `this` with the partition i from `other`.
-   * This function will not introduce a shuffling operation.
-   */
-  def joinLocally[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = {
-    cogroupLocally(other).flatMapValues {
-      case (vs, ws) => for (v <- vs.iterator; w <- ws.iterator) yield (v, w)
-    }
-  }
-}
-
-private[spark] object PartitionLocalRDDFunctions {
-  implicit def rddToPartitionLocalRDDFunctions[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) =
-    new PartitionLocalRDDFunctions(rdd)
-}
-
```
In `AddExchange`:

```diff
@@ -76,7 +76,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
  */
 object AddExchange extends Rule[SparkPlan] {
   // TODO: Determine the number of partitions.
-  val numPartitions = 8
+  val numPartitions = 150
 
   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
     case operator: SparkPlan =>
```
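The default of 8 partitions was presumably too few for inputs like the 10M-row benchmark above; 150 gives shuffles more parallelism. The retained TODO signals that the number should eventually be derived from the data rather than hard-coded.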
In the `Aggregate` operator, which is rewritten wholesale:

```diff
@@ -17,14 +17,13 @@
 
 package org.apache.spark.sql.execution
 
+import java.util.HashMap
+
 import org.apache.spark.SparkContext
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 
-/* Implicit conversions */
-import org.apache.spark.rdd.PartitionLocalRDDFunctions._
-
 /**
  * Groups input data by `groupingExpressions` and computes the `aggregateExpressions` for each
  * group.
@@ -40,7 +39,7 @@ case class Aggregate(
     groupingExpressions: Seq[Expression],
     aggregateExpressions: Seq[NamedExpression],
     child: SparkPlan)(@transient sc: SparkContext)
-  extends UnaryNode {
+  extends UnaryNode with NoBind {
 
   override def requiredChildDistribution =
     if (partial) {
```
The rewritten operator body:

```diff
@@ -55,61 +54,149 @@ case class Aggregate(
 
   override def otherCopyArgs = sc :: Nil
 
+  // HACK: Generators don't correctly preserve their output through serializations so we grab
+  // out child's output attributes statically here.
+  val childOutput = child.output
+
   def output = aggregateExpressions.map(_.toAttribute)
 
-  /* Replace all aggregate expressions with spark functions that will compute the result. */
-  def createAggregateImplementations() = aggregateExpressions.map { agg =>
-    val impl = agg transform {
-      case a: AggregateExpression => a.newInstance
-    }
-
-    val remainingAttributes = impl.collect { case a: Attribute => a }
-    // If any references exist that are not inside agg functions then the must be grouping exprs
-    // in this case we must rebind them to the grouping tuple.
-    if (remainingAttributes.nonEmpty) {
-      val unaliasedAggregateExpr = agg transform { case Alias(c, _) => c }
-
-      // An exact match with a grouping expression
-      val exactGroupingExpr = groupingExpressions.indexOf(unaliasedAggregateExpr) match {
-        case -1 => None
-        case ordinal => Some(BoundReference(ordinal, Alias(impl, "AGGEXPR")().toAttribute))
-      }
-
-      exactGroupingExpr.getOrElse(
-        sys.error(s"$agg is not in grouping expressions: $groupingExpressions"))
-    } else {
-      impl
-    }
-  }
+  /**
+   * An aggregate that needs to be computed for each row in a group.
+   *
+   * @param unbound Unbound version of this aggregate, used for result substitution.
+   * @param aggregate A bound copy of this aggregate used to create a new aggregation buffer.
+   * @param resultAttribute An attribute used to refer to the result of this aggregate in the final
+   *                        output.
+   */
+  case class ComputedAggregate(
+      unbound: AggregateExpression,
+      aggregate: AggregateExpression,
+      resultAttribute: AttributeReference)
+
+  /** A list of aggregates that need to be computed for each group. */
+  @transient
+  lazy val computedAggregates = aggregateExpressions.flatMap { agg =>
+    agg.collect {
+      case a: AggregateExpression =>
+        ComputedAggregate(
+          a,
+          BindReferences.bindReference(a, childOutput).asInstanceOf[AggregateExpression],
+          AttributeReference(s"aggResult:$a", a.dataType, nullable = true)())
+    }
+  }.toArray
+
+  /** The schema of the result of all aggregate evaluations */
+  @transient
+  lazy val computedSchema = computedAggregates.map(_.resultAttribute)
+
+  /** Creates a new aggregate buffer for a group. */
+  def newAggregateBuffer(): Array[AggregateFunction] = {
+    val buffer = new Array[AggregateFunction](computedAggregates.length)
+    var i = 0
+    while (i < computedAggregates.length) {
+      buffer(i) = computedAggregates(i).aggregate.newInstance()
+      i += 1
+    }
+    buffer
+  }
+
+  /** Named attributes used to substitute grouping attributes into the final result. */
+  @transient
+  lazy val namedGroups = groupingExpressions.map {
+    case ne: NamedExpression => ne -> ne.toAttribute
+    case e => e -> Alias(e, s"groupingExpr:$e")().toAttribute
+  }
+
+  /**
+   * A map of substitutions that are used to insert the aggregate expressions and grouping
+   * expression into the final result expression.
+   */
+  @transient
+  lazy val resultMap =
+    (computedAggregates.map { agg => agg.unbound -> agg.resultAttribute} ++ namedGroups).toMap
+
+  /**
+   * Substituted version of aggregateExpressions expressions which are used to compute final
+   * output rows given a group and the result of all aggregate computations.
+   */
+  @transient
+  lazy val resultExpressions = aggregateExpressions.map { agg =>
+    agg.transform {
+      case e: Expression if resultMap.contains(e) => resultMap(e)
+    }
+  }
 
   def execute() = attachTree(this, "execute") {
-    // TODO: If the child of it is an [[catalyst.execution.Exchange]],
-    // do not evaluate the groupingExpressions again since we have evaluated it
-    // in the [[catalyst.execution.Exchange]].
-    val grouped = child.execute().mapPartitions { iter =>
-      val buildGrouping = new Projection(groupingExpressions)
-      iter.map(row => (buildGrouping(row), row.copy()))
-    }.groupByKeyLocally()
-
-    val result = grouped.map { case (group, rows) =>
-      val aggImplementations = createAggregateImplementations()
-
-      // Pull out all the functions so we can feed each row into them.
-      val aggFunctions = aggImplementations.flatMap(_ collect { case f: AggregateFunction => f })
-
-      rows.foreach { row =>
-        aggFunctions.foreach(_.update(row))
-      }
-      buildRow(aggImplementations.map(_.apply(group)))
-    }
-
-    // TODO: THIS BREAKS PIPELINING, DOUBLE COMPUTES THE ANSWER, AND USES TOO MUCH MEMORY...
-    if (groupingExpressions.isEmpty && result.count == 0) {
-      // When there there is no output to the Aggregate operator, we still output an empty row.
-      val aggImplementations = createAggregateImplementations()
-      sc.makeRDD(buildRow(aggImplementations.map(_.apply(null))) :: Nil)
+    if (groupingExpressions.isEmpty) {
+      child.execute().mapPartitions { iter =>
+        val buffer = newAggregateBuffer()
+        var currentRow: Row = null
+        while (iter.hasNext) {
+          currentRow = iter.next()
+          var i = 0
+          while (i < buffer.length) {
+            buffer(i).update(currentRow)
+            i += 1
+          }
+        }
+        val resultProjection = new Projection(resultExpressions, computedSchema)
+        val aggregateResults = new GenericMutableRow(computedAggregates.length)
+
+        var i = 0
+        while (i < buffer.length) {
+          aggregateResults(i) = buffer(i).apply(EmptyRow)
+          i += 1
+        }
+
+        Iterator(resultProjection(aggregateResults))
+      }
     } else {
-      result
+      child.execute().mapPartitions { iter =>
+        val hashTable = new HashMap[Row, Array[AggregateFunction]]
+        val groupingProjection = new MutableProjection(groupingExpressions, childOutput)
+
+        var currentRow: Row = null
+        while (iter.hasNext) {
+          currentRow = iter.next()
+          val currentGroup = groupingProjection(currentRow)
+          var currentBuffer = hashTable.get(currentGroup)
+          if (currentBuffer == null) {
+            currentBuffer = newAggregateBuffer()
+            hashTable.put(currentGroup.copy(), currentBuffer)
+          }
+
+          var i = 0
+          while (i < currentBuffer.length) {
+            currentBuffer(i).update(currentRow)
+            i += 1
+          }
+        }
+
+        new Iterator[Row] {
+          private[this] val hashTableIter = hashTable.entrySet().iterator()
+          private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length)
+          private[this] val resultProjection =
+            new MutableProjection(resultExpressions, computedSchema ++ namedGroups.map(_._2))
+          private[this] val joinedRow = new JoinedRow
+
+          override final def hasNext: Boolean = hashTableIter.hasNext
+
+          override final def next(): Row = {
+            val currentEntry = hashTableIter.next()
+            val currentGroup = currentEntry.getKey
+            val currentBuffer = currentEntry.getValue
+
+            var i = 0
+            while (i < currentBuffer.length) {
+              // Evaluating an aggregate buffer returns the result.  No row is required since we
+              // already added all rows in the group using update.
+              aggregateResults(i) = currentBuffer(i).apply(EmptyRow)
+              i += 1
+            }
+            resultProjection(joinedRow(aggregateResults, currentGroup))
+          }
+        }
+      }
     }
   }
 }
```
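Two details of the new `execute()` are worth calling out. The no-grouping path never builds a table: each partition folds every row into a single buffer array and emits exactly one row. In the grouping path, the key stored in the table is `currentGroup.copy()`, not `currentGroup`, because `MutableProjection` reuses one row object; storing the live row would make every entry alias a key that the next iteration mutates. A hypothetical illustration of that failure mode:

```scala
import java.util.HashMap

// A mutable key whose hashCode depends on its contents, standing in for the
// reused row produced by MutableProjection.
case class Key(var value: Int)

val reused = Key(1)
val table = new HashMap[Key, String]
table.put(reused, "group 1") // stored WITHOUT copying -- the bug
reused.value = 2             // the "projection" overwrites the shared buffer

table.get(Key(1)) // null: the stored key no longer equals Key(1)
table.get(Key(2)) // null: the entry still sits in the bucket for Key(1)'s hash
// table.put(reused.copy(), "group 1") would have kept the entry reachable.
```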
In `HiveGenericUdaf`:

```diff
@@ -337,13 +337,16 @@ case class HiveGenericUdaf(
 
   type UDFType = AbstractGenericUDAFResolver
 
+  @transient
   protected lazy val resolver: AbstractGenericUDAFResolver = createFunction(name)
 
+  @transient
   protected lazy val objectInspector = {
     resolver.getEvaluator(children.map(_.dataType.toTypeInfo).toArray)
       .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
   }
 
+  @transient
   protected lazy val inspectors = children.map(_.dataType).map(toInspector)
 
   def dataType: DataType = inspectorToDataType(objectInspector)
```
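Hive object inspectors are not serializable, so without these annotations shipping a `HiveGenericUdaf` to executors would fail. An `@transient lazy val` is excluded from the serialized object and re-runs its initializer on first access after deserialization — a minimal sketch of the pattern, with hypothetical classes:

```scala
// A helper that cannot be serialized (no Serializable, like an ObjectInspector).
class HeavyHelper(val name: String)

class ShippableFunction(name: String) extends Serializable {
  // Excluded from the serialized form; rebuilt lazily in whichever JVM
  // (driver or executor) first touches it after deserialization.
  @transient protected lazy val helper = new HeavyHelper(name)

  def describe(): String = helper.name
}
```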