[SPARK-19634][ML] Multivariate summarizer - dataframes API

## What changes were proposed in this pull request?

This patch adds the DataFrames API to the multivariate summarizer (mean, variance, etc.). In addition to all the features of MultivariateOnlineSummarizer, it also allows the user to select a subset of the metrics.
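A minimal usage sketch in Scala, mirroring the examples in the new `Summarizer` scaladoc (the DataFrame `df` and its `features` column are assumptions):

```scala
import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.sql.Row

// Several metrics at once: the result is a single struct-typed column.
val allStats = df.select(Summarizer.metrics("mean", "variance").summary(df("features")))
val Row(Row(mean, variance)) = allStats.first()

// Shortcuts exist for single metrics.
val Row(meanVec) = df.select(Summarizer.mean(df("features"))).first()
```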

## How was this patch tested?

Test cases added.

## Performance
Resolves several performance issues found in #17419; further optimization is pending on the SQL team's work. One of the SQL-layer performance issues related to this feature has been resolved in #18712, thanks liancheng and cloud-fan

### Performance data

(Tested on my laptop, using 2 partitions; tries = 20, warm-up = 10.)

The unit of the test results is records/millisecond (higher is better).

Vector size / record count | 1/10000000 | 10/1000000 | 100/1000000 | 1000/100000 | 10000/10000
----|------|----|---|----|----
DataFrame | 15149 | 7441 | 2118 | 224 | 21
RDD from DataFrame | 4992 | 4440 | 2328 | 320 | 33
Raw RDD | 53931 | 20683 | 3966 | 528 | 53
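For reference, a rough sketch of the three code paths measured above (the full benchmark is the ignored "performance test" in SummarizerSuite; `df`, `rawRdd`, and the column name are assumptions):

```scala
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.stat.Summarizer.metrics
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.sql.Row

// 1. DataFrame path: the new summarizer aggregate evaluated through SQL.
df.select(metrics("mean", "variance", "count", "numNonZeros", "max", "min",
  "normL1", "normL2").summary(df("features"))).head()

// 2. RDD from DataFrame: convert back to an RDD of mllib vectors, then colStats.
Statistics.colStats(df.rdd.map { case Row(v: Vector) => OldVectors.fromML(v) })

// 3. Raw RDD: colStats directly on an RDD[mllib.linalg.Vector].
Statistics.colStats(rawRdd)
```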

Author: WeichenXu <WeichenXu123@outlook.com>

Closes #18798 from WeichenXu123/SPARK-19634-dataframe-summarizer.
WeichenXu 2017-08-16 10:41:05 +08:00 committed by Yanbo Liang
parent 9660831050
commit 07549b20a3
5 changed files with 1203 additions and 11 deletions


@@ -27,17 +27,7 @@ import org.apache.spark.sql.types._
  */
 private[spark] class VectorUDT extends UserDefinedType[Vector] {
-  override def sqlType: StructType = {
-    // type: 0 = sparse, 1 = dense
-    // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse
-    // vectors. The "values" field is nullable because we might want to add binary vectors later,
-    // which uses "size" and "indices", but not "values".
-    StructType(Seq(
-      StructField("type", ByteType, nullable = false),
-      StructField("size", IntegerType, nullable = true),
-      StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true),
-      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
-  }
+  override final def sqlType: StructType = _sqlType

   override def serialize(obj: Vector): InternalRow = {
     obj match {
@@ -94,4 +84,16 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
   override def typeName: String = "vector"

   private[spark] override def asNullable: VectorUDT = this
+
+  private[this] val _sqlType = {
+    // type: 0 = sparse, 1 = dense
+    // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse
+    // vectors. The "values" field is nullable because we might want to add binary vectors later,
+    // which uses "size" and "indices", but not "values".
+    StructType(Seq(
+      StructField("type", ByteType, nullable = false),
+      StructField("size", IntegerType, nullable = true),
+      StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true),
+      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
+  }
 }
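The hunk above replaces a `sqlType` method that rebuilt the schema on every call with a memoized `private[this] val` exposed through a `final def`. A minimal sketch of the same pattern, with hypothetical names:

```scala
import org.apache.spark.sql.types._

class CachedSchemaHolder {
  // Built once at construction time instead of on every access.
  private[this] val _schema: StructType =
    StructType(Seq(StructField("type", ByteType, nullable = false)))

  // Cheap, non-overridable accessor on the hot path.
  final def schema: StructType = _schema
}
```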


@@ -0,0 +1,596 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.stat
import java.io._
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeArrayData}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, TypedImperativeAggregate}
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types._
/**
* A builder object that provides summary statistics about a given column.
*
* Users should not directly create such builders, but instead use one of the methods in
* [[Summarizer]].
*/
@Experimental
@Since("2.3.0")
sealed abstract class SummaryBuilder {
/**
* Returns an aggregate object that contains the summary of the column with the requested metrics.
* @param featuresCol a column that contains features Vector object.
* @param weightCol a column that contains weight value.
* @return an aggregate column that contains the statistics. The exact content of this
* structure is determined during the creation of the builder.
*/
@Since("2.3.0")
def summary(featuresCol: Column, weightCol: Column): Column
@Since("2.3.0")
def summary(featuresCol: Column): Column = summary(featuresCol, lit(1.0))
}
/**
* Tools for vectorized statistics on MLlib Vectors.
*
* The methods in this package provide various statistics for Vectors contained inside DataFrames.
*
* This class lets users pick the statistics they would like to extract for a given column. Here is
* an example in Scala:
* {{{
* val dataframe = ... // Some dataframe containing a feature column
* val allStats = dataframe.select(Summarizer.metrics("min", "max").summary($"features"))
* val Row(Row(min_, max_)) = allStats.first()
* }}}
*
* If one wants to get a single metric, shortcuts are also available:
* {{{
* val meanDF = dataframe.select(Summarizer.mean($"features"))
* val Row(mean_) = meanDF.first()
* }}}
*
* Note: Currently, the performance of this interface is about 2x~3x slower than using the RDD
* interface.
*/
@Experimental
@Since("2.3.0")
object Summarizer extends Logging {
import SummaryBuilderImpl._
/**
* Given a list of metrics, provides a builder that in turn computes the metrics from a column.
*
* See the documentation of [[Summarizer]] for an example.
*
* The following metrics are accepted (case sensitive):
* - mean: a vector that contains the coefficient-wise mean.
* - variance: a vector that contains the coefficient-wise variance.
* - count: the count of all vectors seen.
* - numNonzeros: a vector with the number of non-zeros for each coefficient.
* - max: the maximum for each coefficient.
* - min: the minimum for each coefficient.
* - normL2: the Euclidean norm for each coefficient.
* - normL1: the L1 norm of each coefficient (sum of the absolute values).
* @param firstMetric the metric being provided
* @param metrics additional metrics that can be provided.
* @return a builder.
* @throws IllegalArgumentException if one of the metric names is not understood.
*
* Note: Currently, the performance of this interface is about 2x~3x slower than using the RDD
* interface.
*/
@Since("2.3.0")
def metrics(firstMetric: String, metrics: String*): SummaryBuilder = {
val (typedMetrics, computeMetrics) = getRelevantMetrics(Seq(firstMetric) ++ metrics)
new SummaryBuilderImpl(typedMetrics, computeMetrics)
}
@Since("2.3.0")
def mean(col: Column): Column = getSingleMetric(col, "mean")
@Since("2.3.0")
def variance(col: Column): Column = getSingleMetric(col, "variance")
@Since("2.3.0")
def count(col: Column): Column = getSingleMetric(col, "count")
@Since("2.3.0")
def numNonZeros(col: Column): Column = getSingleMetric(col, "numNonZeros")
@Since("2.3.0")
def max(col: Column): Column = getSingleMetric(col, "max")
@Since("2.3.0")
def min(col: Column): Column = getSingleMetric(col, "min")
@Since("2.3.0")
def normL1(col: Column): Column = getSingleMetric(col, "normL1")
@Since("2.3.0")
def normL2(col: Column): Column = getSingleMetric(col, "normL2")
private def getSingleMetric(col: Column, metric: String): Column = {
val c1 = metrics(metric).summary(col)
c1.getField(metric).as(s"$metric($col)")
}
}
private[ml] class SummaryBuilderImpl(
requestedMetrics: Seq[SummaryBuilderImpl.Metric],
requestedCompMetrics: Seq[SummaryBuilderImpl.ComputeMetric]
) extends SummaryBuilder {
override def summary(featuresCol: Column, weightCol: Column): Column = {
val agg = SummaryBuilderImpl.MetricsAggregate(
requestedMetrics,
requestedCompMetrics,
featuresCol.expr,
weightCol.expr,
mutableAggBufferOffset = 0,
inputAggBufferOffset = 0)
new Column(AggregateExpression(agg, mode = Complete, isDistinct = false))
}
}
private[ml] object SummaryBuilderImpl extends Logging {
def implementedMetrics: Seq[String] = allMetrics.map(_._1).sorted
@throws[IllegalArgumentException]("When the list is empty or not a subset of known metrics")
def getRelevantMetrics(requested: Seq[String]): (Seq[Metric], Seq[ComputeMetric]) = {
val all = requested.map { req =>
val (_, metric, _, deps) = allMetrics.find(_._1 == req).getOrElse {
throw new IllegalArgumentException(s"Metric $req cannot be found." +
s" Valid metrics are $implementedMetrics")
}
metric -> deps
}
// Do not sort, otherwise the user has to look at the schema to see the order in which
// the metrics are going to be returned.
val metrics = all.map(_._1)
val computeMetrics = all.flatMap(_._2).distinct.sortBy(_.toString)
metrics -> computeMetrics
}
def structureForMetrics(metrics: Seq[Metric]): StructType = {
val dict = allMetrics.map { case (name, metric, dataType, _) =>
(metric, (name, dataType))
}.toMap
val fields = metrics.map(dict.apply).map { case (name, dataType) =>
StructField(name, dataType, nullable = false)
}
StructType(fields)
}
private val arrayDType = ArrayType(DoubleType, containsNull = false)
private val arrayLType = ArrayType(LongType, containsNull = false)
/**
* All the metrics that can be currently computed by Spark for vectors.
*
* This list associates the user-facing name, the internal (typed) name, and the list of computation
* metrics that need to be computed internally to get the final result.
*/
private val allMetrics: Seq[(String, Metric, DataType, Seq[ComputeMetric])] = Seq(
("mean", Mean, arrayDType, Seq(ComputeMean, ComputeWeightSum)),
("variance", Variance, arrayDType, Seq(ComputeWeightSum, ComputeMean, ComputeM2n)),
("count", Count, LongType, Seq()),
("numNonZeros", NumNonZeros, arrayLType, Seq(ComputeNNZ)),
("max", Max, arrayDType, Seq(ComputeMax, ComputeNNZ)),
("min", Min, arrayDType, Seq(ComputeMin, ComputeNNZ)),
("normL2", NormL2, arrayDType, Seq(ComputeM2)),
("normL1", NormL1, arrayDType, Seq(ComputeL1))
)
/**
* The metrics that are currently implemented.
*/
sealed trait Metric extends Serializable
private[stat] case object Mean extends Metric
private[stat] case object Variance extends Metric
private[stat] case object Count extends Metric
private[stat] case object NumNonZeros extends Metric
private[stat] case object Max extends Metric
private[stat] case object Min extends Metric
private[stat] case object NormL2 extends Metric
private[stat] case object NormL1 extends Metric
/**
* The running metrics that are going to be computed.
*
* There is a bipartite graph between the metrics and the computed metrics.
*/
sealed trait ComputeMetric extends Serializable
private[stat] case object ComputeMean extends ComputeMetric
private[stat] case object ComputeM2n extends ComputeMetric
private[stat] case object ComputeM2 extends ComputeMetric
private[stat] case object ComputeL1 extends ComputeMetric
private[stat] case object ComputeWeightSum extends ComputeMetric
private[stat] case object ComputeNNZ extends ComputeMetric
private[stat] case object ComputeMax extends ComputeMetric
private[stat] case object ComputeMin extends ComputeMetric
private[stat] class SummarizerBuffer(
requestedMetrics: Seq[Metric],
requestedCompMetrics: Seq[ComputeMetric]
) extends Serializable {
private var n = 0
private var currMean: Array[Double] = null
private var currM2n: Array[Double] = null
private var currM2: Array[Double] = null
private var currL1: Array[Double] = null
private var totalCnt: Long = 0
private var totalWeightSum: Double = 0.0
private var weightSquareSum: Double = 0.0
private var weightSum: Array[Double] = null
private var nnz: Array[Long] = null
private var currMax: Array[Double] = null
private var currMin: Array[Double] = null
def this() {
this(
Seq(Mean, Variance, Count, NumNonZeros, Max, Min, NormL2, NormL1),
Seq(ComputeMean, ComputeM2n, ComputeM2, ComputeL1,
ComputeWeightSum, ComputeNNZ, ComputeMax, ComputeMin)
)
}
/**
* Add a new sample to this summarizer, and update the statistical summary.
*/
def add(instance: Vector, weight: Double): this.type = {
require(weight >= 0.0, s"sample weight, $weight has to be >= 0.0")
if (weight == 0.0) return this
if (n == 0) {
require(instance.size > 0, s"Vector should have dimension larger than zero.")
n = instance.size
if (requestedCompMetrics.contains(ComputeMean)) { currMean = Array.ofDim[Double](n) }
if (requestedCompMetrics.contains(ComputeM2n)) { currM2n = Array.ofDim[Double](n) }
if (requestedCompMetrics.contains(ComputeM2)) { currM2 = Array.ofDim[Double](n) }
if (requestedCompMetrics.contains(ComputeL1)) { currL1 = Array.ofDim[Double](n) }
if (requestedCompMetrics.contains(ComputeWeightSum)) { weightSum = Array.ofDim[Double](n) }
if (requestedCompMetrics.contains(ComputeNNZ)) { nnz = Array.ofDim[Long](n) }
if (requestedCompMetrics.contains(ComputeMax)) {
currMax = Array.fill[Double](n)(Double.MinValue)
}
if (requestedCompMetrics.contains(ComputeMin)) {
currMin = Array.fill[Double](n)(Double.MaxValue)
}
}
require(n == instance.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $n but got ${instance.size}.")
val localCurrMean = currMean
val localCurrM2n = currM2n
val localCurrM2 = currM2
val localCurrL1 = currL1
val localWeightSum = weightSum
val localNumNonzeros = nnz
val localCurrMax = currMax
val localCurrMin = currMin
instance.foreachActive { (index, value) =>
if (value != 0.0) {
if (localCurrMax != null && localCurrMax(index) < value) {
localCurrMax(index) = value
}
if (localCurrMin != null && localCurrMin(index) > value) {
localCurrMin(index) = value
}
if (localWeightSum != null) {
if (localCurrMean != null) {
val prevMean = localCurrMean(index)
val diff = value - prevMean
localCurrMean(index) = prevMean + weight * diff / (localWeightSum(index) + weight)
if (localCurrM2n != null) {
localCurrM2n(index) += weight * (value - localCurrMean(index)) * diff
}
}
localWeightSum(index) += weight
}
if (localCurrM2 != null) {
localCurrM2(index) += weight * value * value
}
if (localCurrL1 != null) {
localCurrL1(index) += weight * math.abs(value)
}
if (localNumNonzeros != null) {
localNumNonzeros(index) += 1
}
}
}
totalWeightSum += weight
weightSquareSum += weight * weight
totalCnt += 1
this
}
def add(instance: Vector): this.type = add(instance, 1.0)
/**
* Merge another SummarizerBuffer, and update the statistical summary.
* (Note that it's in place merging; as a result, `this` object will be modified.)
*
* @param other The other SummarizerBuffer to be merged.
*/
def merge(other: SummarizerBuffer): this.type = {
if (this.totalWeightSum != 0.0 && other.totalWeightSum != 0.0) {
require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " +
s"Expecting $n but got ${other.n}.")
totalCnt += other.totalCnt
totalWeightSum += other.totalWeightSum
weightSquareSum += other.weightSquareSum
var i = 0
while (i < n) {
if (weightSum != null) {
val thisWeightSum = weightSum(i)
val otherWeightSum = other.weightSum(i)
val totalWeightSum = thisWeightSum + otherWeightSum
if (totalWeightSum != 0.0) {
if (currMean != null) {
val deltaMean = other.currMean(i) - currMean(i)
// merge mean together
currMean(i) += deltaMean * otherWeightSum / totalWeightSum
if (currM2n != null) {
// merge m2n together
currM2n(i) += other.currM2n(i) +
deltaMean * deltaMean * thisWeightSum * otherWeightSum / totalWeightSum
}
}
}
weightSum(i) = totalWeightSum
}
// merge m2 together
if (currM2 != null) { currM2(i) += other.currM2(i) }
// merge l1 together
if (currL1 != null) { currL1(i) += other.currL1(i) }
// merge max and min
if (currMax != null) { currMax(i) = math.max(currMax(i), other.currMax(i)) }
if (currMin != null) { currMin(i) = math.min(currMin(i), other.currMin(i)) }
if (nnz != null) { nnz(i) = nnz(i) + other.nnz(i) }
i += 1
}
} else if (totalWeightSum == 0.0 && other.totalWeightSum != 0.0) {
this.n = other.n
if (other.currMean != null) { this.currMean = other.currMean.clone() }
if (other.currM2n != null) { this.currM2n = other.currM2n.clone() }
if (other.currM2 != null) { this.currM2 = other.currM2.clone() }
if (other.currL1 != null) { this.currL1 = other.currL1.clone() }
this.totalCnt = other.totalCnt
this.totalWeightSum = other.totalWeightSum
this.weightSquareSum = other.weightSquareSum
if (other.weightSum != null) { this.weightSum = other.weightSum.clone() }
if (other.nnz != null) { this.nnz = other.nnz.clone() }
if (other.currMax != null) { this.currMax = other.currMax.clone() }
if (other.currMin != null) { this.currMin = other.currMin.clone() }
}
this
}
/**
* Sample mean of each dimension.
*/
def mean: Vector = {
require(requestedMetrics.contains(Mean))
require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
val realMean = Array.ofDim[Double](n)
var i = 0
while (i < n) {
realMean(i) = currMean(i) * (weightSum(i) / totalWeightSum)
i += 1
}
Vectors.dense(realMean)
}
/**
* Unbiased estimate of sample variance of each dimension.
*/
def variance: Vector = {
require(requestedMetrics.contains(Variance))
require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
val realVariance = Array.ofDim[Double](n)
val denominator = totalWeightSum - (weightSquareSum / totalWeightSum)
// Sample variance is computed; if the denominator is not greater than 0, the variance is just 0.
if (denominator > 0.0) {
val deltaMean = currMean
var i = 0
val len = currM2n.length
while (i < len) {
realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
i += 1
}
}
Vectors.dense(realVariance)
}
/**
* Sample size.
*/
def count: Long = totalCnt
/**
* Number of nonzero elements in each dimension.
*
*/
def numNonzeros: Vector = {
require(requestedMetrics.contains(NumNonZeros))
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
Vectors.dense(nnz.map(_.toDouble))
}
/**
* Maximum value of each dimension.
*/
def max: Vector = {
require(requestedMetrics.contains(Max))
require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
var i = 0
while (i < n) {
if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
i += 1
}
Vectors.dense(currMax)
}
/**
* Minimum value of each dimension.
*/
def min: Vector = {
require(requestedMetrics.contains(Min))
require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
var i = 0
while (i < n) {
if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
i += 1
}
Vectors.dense(currMin)
}
/**
* L2 (Euclidean) norm of each dimension.
*/
def normL2: Vector = {
require(requestedMetrics.contains(NormL2))
require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
val realMagnitude = Array.ofDim[Double](n)
var i = 0
val len = currM2.length
while (i < len) {
realMagnitude(i) = math.sqrt(currM2(i))
i += 1
}
Vectors.dense(realMagnitude)
}
/**
* L1 norm of each dimension.
*/
def normL1: Vector = {
require(requestedMetrics.contains(NormL1))
require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
Vectors.dense(currL1)
}
}
private case class MetricsAggregate(
requestedMetrics: Seq[Metric],
requestedComputeMetrics: Seq[ComputeMetric],
featuresExpr: Expression,
weightExpr: Expression,
mutableAggBufferOffset: Int,
inputAggBufferOffset: Int)
extends TypedImperativeAggregate[SummarizerBuffer] {
override def eval(state: SummarizerBuffer): InternalRow = {
val metrics = requestedMetrics.map {
case Mean => UnsafeArrayData.fromPrimitiveArray(state.mean.toArray)
case Variance => UnsafeArrayData.fromPrimitiveArray(state.variance.toArray)
case Count => state.count
case NumNonZeros => UnsafeArrayData.fromPrimitiveArray(
state.numNonzeros.toArray.map(_.toLong))
case Max => UnsafeArrayData.fromPrimitiveArray(state.max.toArray)
case Min => UnsafeArrayData.fromPrimitiveArray(state.min.toArray)
case NormL2 => UnsafeArrayData.fromPrimitiveArray(state.normL2.toArray)
case NormL1 => UnsafeArrayData.fromPrimitiveArray(state.normL1.toArray)
}
InternalRow.apply(metrics: _*)
}
override def children: Seq[Expression] = featuresExpr :: weightExpr :: Nil
override def update(state: SummarizerBuffer, row: InternalRow): SummarizerBuffer = {
val features = udt.deserialize(featuresExpr.eval(row))
val weight = weightExpr.eval(row).asInstanceOf[Double]
state.add(features, weight)
state
}
override def merge(state: SummarizerBuffer,
other: SummarizerBuffer): SummarizerBuffer = {
state.merge(other)
}
override def nullable: Boolean = false
override def createAggregationBuffer(): SummarizerBuffer
= new SummarizerBuffer(requestedMetrics, requestedComputeMetrics)
override def serialize(state: SummarizerBuffer): Array[Byte] = {
// TODO: Use ByteBuffer to optimize
val bos = new ByteArrayOutputStream()
val oos = new ObjectOutputStream(bos)
oos.writeObject(state)
bos.toByteArray
}
override def deserialize(bytes: Array[Byte]): SummarizerBuffer = {
// TODO: Use ByteBuffer to optimize
val bis = new ByteArrayInputStream(bytes)
val ois = new ObjectInputStream(bis)
ois.readObject().asInstanceOf[SummarizerBuffer]
}
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): MetricsAggregate = {
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
}
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): MetricsAggregate = {
copy(inputAggBufferOffset = newInputAggBufferOffset)
}
override lazy val dataType: DataType = structureForMetrics(requestedMetrics)
override def prettyName: String = "aggregate_metrics"
}
private[this] val udt = new VectorUDT
}


@@ -0,0 +1,582 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.stat
import org.scalatest.exceptions.TestFailedException
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, Statistics}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
import testImplicits._
import Summarizer._
import SummaryBuilderImpl._
private case class ExpectedMetrics(
mean: Seq[Double],
variance: Seq[Double],
count: Long,
numNonZeros: Seq[Long],
max: Seq[Double],
min: Seq[Double],
normL2: Seq[Double],
normL1: Seq[Double])
/**
* The input is expected to be either a sparse vector, a dense vector, or an array of doubles
* (which will be converted to a dense vector).
* The expected value is the list of all the known metrics.
*
* The tests take a list of input vectors and a list of all the summary values that
* are expected for this input. They currently test against some fixed subset of the
* metrics, but should be made fuzzy in the future.
*/
private def testExample(name: String, input: Seq[Any], exp: ExpectedMetrics): Unit = {
def inputVec: Seq[Vector] = input.map {
case x: Array[Double @unchecked] => Vectors.dense(x)
case x: Seq[Double @unchecked] => Vectors.dense(x.toArray)
case x: Vector => x
case x => throw new Exception(x.toString)
}
val summarizer = {
val _summarizer = new MultivariateOnlineSummarizer
inputVec.foreach(v => _summarizer.add(OldVectors.fromML(v)))
_summarizer
}
// Because the Spark context is reset between tests, we cannot hold a reference onto it.
def wrappedInit() = {
val df = inputVec.map(Tuple1.apply).toDF("features")
val col = df.col("features")
(df, col)
}
registerTest(s"$name - mean only") {
val (df, c) = wrappedInit()
compare(df.select(metrics("mean").summary(c), mean(c)), Seq(Row(exp.mean), summarizer.mean))
}
registerTest(s"$name - mean only (direct)") {
val (df, c) = wrappedInit()
compare(df.select(mean(c)), Seq(exp.mean))
}
registerTest(s"$name - variance only") {
val (df, c) = wrappedInit()
compare(df.select(metrics("variance").summary(c), variance(c)),
Seq(Row(exp.variance), summarizer.variance))
}
registerTest(s"$name - variance only (direct)") {
val (df, c) = wrappedInit()
compare(df.select(variance(c)), Seq(summarizer.variance))
}
registerTest(s"$name - count only") {
val (df, c) = wrappedInit()
compare(df.select(metrics("count").summary(c), count(c)),
Seq(Row(exp.count), exp.count))
}
registerTest(s"$name - count only (direct)") {
val (df, c) = wrappedInit()
compare(df.select(count(c)),
Seq(exp.count))
}
registerTest(s"$name - numNonZeros only") {
val (df, c) = wrappedInit()
compare(df.select(metrics("numNonZeros").summary(c), numNonZeros(c)),
Seq(Row(exp.numNonZeros), exp.numNonZeros))
}
registerTest(s"$name - numNonZeros only (direct)") {
val (df, c) = wrappedInit()
compare(df.select(numNonZeros(c)),
Seq(exp.numNonZeros))
}
registerTest(s"$name - min only") {
val (df, c) = wrappedInit()
compare(df.select(metrics("min").summary(c), min(c)),
Seq(Row(exp.min), exp.min))
}
registerTest(s"$name - max only") {
val (df, c) = wrappedInit()
compare(df.select(metrics("max").summary(c), max(c)),
Seq(Row(exp.max), exp.max))
}
registerTest(s"$name - normL1 only") {
val (df, c) = wrappedInit()
compare(df.select(metrics("normL1").summary(c), normL1(c)),
Seq(Row(exp.normL1), exp.normL1))
}
registerTest(s"$name - normL2 only") {
val (df, c) = wrappedInit()
compare(df.select(metrics("normL2").summary(c), normL2(c)),
Seq(Row(exp.normL2), exp.normL2))
}
registerTest(s"$name - all metrics at once") {
val (df, c) = wrappedInit()
compare(df.select(
metrics("mean", "variance", "count", "numNonZeros").summary(c),
mean(c), variance(c), count(c), numNonZeros(c)),
Seq(Row(exp.mean, exp.variance, exp.count, exp.numNonZeros),
exp.mean, exp.variance, exp.count, exp.numNonZeros))
}
}
private def denseData(input: Seq[Seq[Double]]): DataFrame = {
input.map(_.toArray).map(Vectors.dense).map(Tuple1.apply).toDF("features")
}
private def compare(df: DataFrame, exp: Seq[Any]): Unit = {
val coll = df.collect().toSeq
val Seq(row) = coll
val res = row.toSeq
val names = df.schema.fieldNames.zipWithIndex.map { case (n, idx) => s"$n ($idx)" }
assert(res.size === exp.size, (res.size, exp.size))
for (((x1, x2), name) <- res.zip(exp).zip(names)) {
compareStructures(x1, x2, name)
}
}
// Compares structured content.
private def compareStructures(x1: Any, x2: Any, name: String): Unit = (x1, x2) match {
case (y1: Seq[Double @unchecked], v1: OldVector) =>
compareStructures(y1, v1.toArray.toSeq, name)
case (d1: Double, d2: Double) =>
assert2(Vectors.dense(d1) ~== Vectors.dense(d2) absTol 1e-4, name)
case (r1: GenericRowWithSchema, r2: Row) =>
assert(r1.size === r2.size, (r1, r2))
for (((fname, x1), x2) <- r1.schema.fieldNames.zip(r1.toSeq).zip(r2.toSeq)) {
compareStructures(x1, x2, s"$name.$fname")
}
case (r1: Row, r2: Row) =>
assert(r1.size === r2.size, (r1, r2))
for ((x1, x2) <- r1.toSeq.zip(r2.toSeq)) { compareStructures(x1, x2, name) }
case (v1: Vector, v2: Vector) =>
assert2(v1 ~== v2 absTol 1e-4, name)
case (l1: Long, l2: Long) => assert(l1 === l2)
case (s1: Seq[_], s2: Seq[_]) =>
assert(s1.size === s2.size, s"$name ${(s1, s2)}")
for (((x1, idx), x2) <- s1.zipWithIndex.zip(s2)) {
compareStructures(x1, x2, s"$name.$idx")
}
case (arr1: Array[_], arr2: Array[_]) =>
assert(arr1.toSeq === arr2.toSeq)
case _ => throw new Exception(s"$name: ${x1.getClass} ${x2.getClass} $x1 $x2")
}
private def assert2(x: => Boolean, hint: String): Unit = {
try {
assert(x, hint)
} catch {
case tfe: TestFailedException =>
throw new TestFailedException(Some(s"Failure with hint $hint"), Some(tfe), 1)
}
}
test("debugging test") {
val df = denseData(Nil)
val c = df.col("features")
val c1 = metrics("mean").summary(c)
val res = df.select(c1)
intercept[SparkException] {
compare(res, Seq.empty)
}
}
test("basic error handling") {
val df = denseData(Nil)
val c = df.col("features")
val res = df.select(metrics("mean").summary(c), mean(c))
intercept[SparkException] {
compare(res, Seq.empty)
}
}
test("no element, working metrics") {
val df = denseData(Nil)
val c = df.col("features")
val res = df.select(metrics("count").summary(c), count(c))
compare(res, Seq(Row(0L), 0L))
}
val singleElem = Seq(0.0, 1.0, 2.0)
testExample("single element", Seq(singleElem), ExpectedMetrics(
mean = singleElem,
variance = Seq(0.0, 0.0, 0.0),
count = 1,
numNonZeros = Seq(0, 1, 1),
max = singleElem,
min = singleElem,
normL1 = singleElem,
normL2 = singleElem
))
testExample("two elements", Seq(Seq(0.0, 1.0, 2.0), Seq(0.0, -1.0, -2.0)), ExpectedMetrics(
mean = Seq(0.0, 0.0, 0.0),
// TODO: I have doubts about these values; they are not normalized.
variance = Seq(0.0, 2.0, 8.0),
count = 2,
numNonZeros = Seq(0, 2, 2),
max = Seq(0.0, 1.0, 2.0),
min = Seq(0.0, -1.0, -2.0),
normL1 = Seq(0.0, 2.0, 4.0),
normL2 = Seq(0.0, math.sqrt(2.0), math.sqrt(2.0) * 2.0)
))
testExample("dense vector input",
Seq(Seq(-1.0, 0.0, 6.0), Seq(3.0, -3.0, 0.0)),
ExpectedMetrics(
mean = Seq(1.0, -1.5, 3.0),
variance = Seq(8.0, 4.5, 18.0),
count = 2,
numNonZeros = Seq(2, 1, 1),
max = Seq(3.0, 0.0, 6.0),
min = Seq(-1.0, -3, 0.0),
normL1 = Seq(4.0, 3.0, 6.0),
normL2 = Seq(math.sqrt(10), 3, 6.0)
)
)
test("summarizer buffer basic error handing") {
val summarizer = new SummarizerBuffer
assert(summarizer.count === 0, "should be zero since nothing is added.")
withClue("Getting numNonzeros from empty summarizer should throw exception.") {
intercept[IllegalArgumentException] {
summarizer.numNonzeros
}
}
withClue("Getting variance from empty summarizer should throw exception.") {
intercept[IllegalArgumentException] {
summarizer.variance
}
}
withClue("Getting mean from empty summarizer should throw exception.") {
intercept[IllegalArgumentException] {
summarizer.mean
}
}
withClue("Getting max from empty summarizer should throw exception.") {
intercept[IllegalArgumentException] {
summarizer.max
}
}
withClue("Getting min from empty summarizer should throw exception.") {
intercept[IllegalArgumentException] {
summarizer.min
}
}
summarizer.add(Vectors.dense(-1.0, 2.0, 6.0)).add(Vectors.sparse(3, Seq((0, -2.0), (1, 6.0))))
withClue("Adding a new dense sample with different array size should throw exception.") {
intercept[IllegalArgumentException] {
summarizer.add(Vectors.dense(3.0, 1.0))
}
}
withClue("Adding a new sparse sample with different array size should throw exception.") {
intercept[IllegalArgumentException] {
summarizer.add(Vectors.sparse(5, Seq((0, -2.0), (1, 6.0))))
}
}
val summarizer2 = (new SummarizerBuffer).add(Vectors.dense(1.0, -2.0, 0.0, 4.0))
withClue("Merging a new summarizer with different dimensions should throw exception.") {
intercept[IllegalArgumentException] {
summarizer.merge(summarizer2)
}
}
}
test("summarizer buffer dense vector input") {
// For column 2, the maximum will be 0.0, and it's not explicitly added since we ignore all
// the zeros; it's a case we need to test. For column 3, the minimum will be 0.0 which we
// need to test as well.
val summarizer = (new SummarizerBuffer)
.add(Vectors.dense(-1.0, 0.0, 6.0))
.add(Vectors.dense(3.0, -3.0, 0.0))
assert(summarizer.mean ~== Vectors.dense(1.0, -1.5, 3.0) absTol 1E-5, "mean mismatch")
assert(summarizer.min ~== Vectors.dense(-1.0, -3, 0.0) absTol 1E-5, "min mismatch")
assert(summarizer.max ~== Vectors.dense(3.0, 0.0, 6.0) absTol 1E-5, "max mismatch")
assert(summarizer.numNonzeros ~== Vectors.dense(2, 1, 1) absTol 1E-5, "numNonzeros mismatch")
assert(summarizer.variance ~== Vectors.dense(8.0, 4.5, 18.0) absTol 1E-5, "variance mismatch")
assert(summarizer.count === 2)
}
test("summarizer buffer sparse vector input") {
val summarizer = (new SummarizerBuffer)
.add(Vectors.sparse(3, Seq((0, -1.0), (2, 6.0))))
.add(Vectors.sparse(3, Seq((0, 3.0), (1, -3.0))))
assert(summarizer.mean ~== Vectors.dense(1.0, -1.5, 3.0) absTol 1E-5, "mean mismatch")
assert(summarizer.min ~== Vectors.dense(-1.0, -3, 0.0) absTol 1E-5, "min mismatch")
assert(summarizer.max ~== Vectors.dense(3.0, 0.0, 6.0) absTol 1E-5, "max mismatch")
assert(summarizer.numNonzeros ~== Vectors.dense(2, 1, 1) absTol 1E-5, "numNonzeros mismatch")
assert(summarizer.variance ~== Vectors.dense(8.0, 4.5, 18.0) absTol 1E-5, "variance mismatch")
assert(summarizer.count === 2)
}
test("summarizer buffer mixing dense and sparse vector input") {
val summarizer = (new SummarizerBuffer)
.add(Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))))
.add(Vectors.dense(0.0, -1.0, -3.0))
.add(Vectors.sparse(3, Seq((1, -5.1))))
.add(Vectors.dense(3.8, 0.0, 1.9))
.add(Vectors.dense(1.7, -0.6, 0.0))
.add(Vectors.sparse(3, Seq((1, 1.9), (2, 0.0))))
assert(summarizer.mean ~==
Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333) absTol 1E-5, "mean mismatch")
assert(summarizer.min ~== Vectors.dense(-2.0, -5.1, -3) absTol 1E-5, "min mismatch")
assert(summarizer.max ~== Vectors.dense(3.8, 2.3, 1.9) absTol 1E-5, "max mismatch")
assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch")
assert(summarizer.variance ~==
Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5,
"variance mismatch")
assert(summarizer.count === 6)
}
test("summarizer buffer merging two summarizers") {
val summarizer1 = (new SummarizerBuffer)
.add(Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))))
.add(Vectors.dense(0.0, -1.0, -3.0))
val summarizer2 = (new SummarizerBuffer)
.add(Vectors.sparse(3, Seq((1, -5.1))))
.add(Vectors.dense(3.8, 0.0, 1.9))
.add(Vectors.dense(1.7, -0.6, 0.0))
.add(Vectors.sparse(3, Seq((1, 1.9), (2, 0.0))))
val summarizer = summarizer1.merge(summarizer2)
assert(summarizer.mean ~==
Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333) absTol 1E-5, "mean mismatch")
assert(summarizer.min ~== Vectors.dense(-2.0, -5.1, -3) absTol 1E-5, "min mismatch")
assert(summarizer.max ~== Vectors.dense(3.8, 2.3, 1.9) absTol 1E-5, "max mismatch")
assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch")
assert(summarizer.variance ~==
Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5,
"variance mismatch")
assert(summarizer.count === 6)
}
test("summarizer buffer merging summarizer with empty summarizer") {
// If one of two is non-empty, this should return the non-empty summarizer.
// If both of them are empty, then just return the empty summarizer.
val summarizer1 = (new SummarizerBuffer)
.add(Vectors.dense(0.0, -1.0, -3.0)).merge(new SummarizerBuffer)
assert(summarizer1.count === 1)
val summarizer2 = (new SummarizerBuffer)
.merge((new SummarizerBuffer).add(Vectors.dense(0.0, -1.0, -3.0)))
assert(summarizer2.count === 1)
val summarizer3 = (new SummarizerBuffer).merge(new SummarizerBuffer)
assert(summarizer3.count === 0)
assert(summarizer1.mean ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "mean mismatch")
assert(summarizer2.mean ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "mean mismatch")
assert(summarizer1.min ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "min mismatch")
assert(summarizer2.min ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "min mismatch")
assert(summarizer1.max ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "max mismatch")
assert(summarizer2.max ~== Vectors.dense(0.0, -1.0, -3.0) absTol 1E-5, "max mismatch")
assert(summarizer1.numNonzeros ~== Vectors.dense(0, 1, 1) absTol 1E-5, "numNonzeros mismatch")
assert(summarizer2.numNonzeros ~== Vectors.dense(0, 1, 1) absTol 1E-5, "numNonzeros mismatch")
assert(summarizer1.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch")
assert(summarizer2.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch")
}
test("summarizer buffer merging summarizer when one side has zero mean (SPARK-4355)") {
val s0 = new SummarizerBuffer()
.add(Vectors.dense(2.0))
.add(Vectors.dense(2.0))
val s1 = new SummarizerBuffer()
.add(Vectors.dense(1.0))
.add(Vectors.dense(-1.0))
s0.merge(s1)
assert(s0.mean(0) ~== 1.0 absTol 1e-14)
}
test("summarizer buffer merging summarizer with weighted samples") {
val summarizer = (new SummarizerBuffer)
.add(Vectors.sparse(3, Seq((0, -0.8), (1, 1.7))), weight = 0.1)
.add(Vectors.dense(0.0, -1.2, -1.7), 0.2).merge(
(new SummarizerBuffer)
.add(Vectors.sparse(3, Seq((0, -0.7), (1, 0.01), (2, 1.3))), 0.15)
.add(Vectors.dense(-0.5, 0.3, -1.5), 0.05))
assert(summarizer.count === 4)
// The following values are hand calculated using the formula:
// [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]]
// which defines the reliability weight used for computing the unbiased estimation of variance
// for weighted instances.
assert(summarizer.mean ~== Vectors.dense(Array(-0.42, -0.107, -0.44))
absTol 1E-10, "mean mismatch")
assert(summarizer.variance ~== Vectors.dense(Array(0.17657142857, 1.645115714, 2.42057142857))
absTol 1E-8, "variance mismatch")
assert(summarizer.numNonzeros ~== Vectors.dense(Array(3.0, 4.0, 3.0))
absTol 1E-10, "numNonzeros mismatch")
assert(summarizer.max ~== Vectors.dense(Array(0.0, 1.7, 1.3)) absTol 1E-10, "max mismatch")
assert(summarizer.min ~== Vectors.dense(Array(-0.8, -1.2, -1.7)) absTol 1E-10, "min mismatch")
assert(summarizer.normL2 ~== Vectors.dense(0.387298335, 0.762571308141, 0.9715966241192)
absTol 1E-8, "normL2 mismatch")
assert(summarizer.normL1 ~== Vectors.dense(0.21, 0.4265, 0.61) absTol 1E-10, "normL1 mismatch")
}
test("summarizer buffer test min/max with weighted samples") {
val summarizer1 = new SummarizerBuffer()
.add(Vectors.dense(10.0, -10.0), 1e10)
.add(Vectors.dense(0.0, 0.0), 1e-7)
val summarizer2 = new SummarizerBuffer()
summarizer2.add(Vectors.dense(10.0, -10.0), 1e10)
for (i <- 1 to 100) {
summarizer2.add(Vectors.dense(0.0, 0.0), 1e-7)
}
val summarizer3 = new SummarizerBuffer()
for (i <- 1 to 100) {
summarizer3.add(Vectors.dense(0.0, 0.0), 1e-7)
}
summarizer3.add(Vectors.dense(10.0, -10.0), 1e10)
assert(summarizer1.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
assert(summarizer1.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
assert(summarizer2.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
assert(summarizer2.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
}
ignore("performance test") {
/*
Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.12
MacBook Pro (15-inch, 2016) CPU 2.9 GHz Intel Core i7
Use 2 partitions. tries out times= 20, warm up times = 10
The unit of test results is records/milliseconds (higher is better)
Vector size/records number: 1/1E7 10/1E6 100/1E6 1E3/1E5 1E4/1E4
-----------------------------------------------------------------------------
DataFrame 15149 7441 2118 224 21
RDD from DataFrame 4992 4440 2328 320 33
Raw RDD 53931 20683 3966 528 53
*/
import scala.util.Random
val rand = new Random()
val genArr = (dim: Int) => {
Array.fill(dim)(rand.nextDouble())
}
val numPartitions = 2
for ( (n, dim) <- Seq(
(10000000, 1), (1000000, 10), (1000000, 100), (100000, 1000), (10000, 10000))
) {
val rdd1 = sc.parallelize(1 to n, numPartitions).map { idx =>
OldVectors.dense(genArr(dim))
}
// scalastyle:off println
println(s"records number = $n, vector size = $dim, partition = ${rdd1.getNumPartitions}")
// scalastyle:on println
val numOfTry = 20
val numOfWarmUp = 10
rdd1.cache()
rdd1.count()
val rdd2 = sc.parallelize(1 to n, numPartitions).map { idx =>
Vectors.dense(genArr(dim))
}
rdd2.cache()
rdd2.count()
val df = rdd2.map(Tuple1.apply).toDF("features")
df.cache()
df.count()
def print(name: String, l: List[Long]): Unit = {
def f(z: Long) = (1e6 * n.toDouble) / z
val min = f(l.max)
val max = f(l.min)
val med = f(l.sorted.drop(l.size / 2).head)
// scalastyle:off println
println(s"$name = [$min ~ $med ~ $max] records / milli")
// scalastyle:on println
}
var timeDF: List[Long] = Nil
val x = df.select(
metrics("mean", "variance", "count", "numNonZeros", "max", "min", "normL1",
"normL2").summary($"features"))
for (i <- 1 to numOfTry) {
val start = System.nanoTime()
x.head()
val end = System.nanoTime()
if (i > numOfWarmUp) timeDF ::= (end - start)
}
var timeRDD: List[Long] = Nil
for (i <- 1 to numOfTry) {
val start = System.nanoTime()
Statistics.colStats(rdd1)
val end = System.nanoTime()
if (i > numOfWarmUp) timeRDD ::= (end - start)
}
var timeRDDFromDF: List[Long] = Nil
val rddFromDf = df.rdd.map { case Row(v: Vector) => OldVectors.fromML(v) }
for (i <- 1 to numOfTry) {
val start = System.nanoTime()
Statistics.colStats(rddFromDf)
val end = System.nanoTime()
if (i > numOfWarmUp) timeRDDFromDF ::= (end - start)
}
print("DataFrame : ", timeDF)
print("RDD :", timeRDD)
print("RDD from DataFrame : ", timeRDDFromDF)
}
}
}


@@ -101,6 +101,8 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu

 /**
  * A projection that returns UnsafeRow.
+ *
+ * CAUTION: the returned projection object should *not* be assumed to be thread-safe.
  */
 abstract class UnsafeProjection extends Projection {
   override def apply(row: InternalRow): UnsafeRow
@@ -110,11 +112,15 @@ object UnsafeProjection {

   /**
    * Returns an UnsafeProjection for given StructType.
+   *
+   * CAUTION: the returned projection object is *not* thread-safe.
    */
   def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType))

   /**
    * Returns an UnsafeProjection for given Array of DataTypes.
+   *
+   * CAUTION: the returned projection object is *not* thread-safe.
    */
   def create(fields: Array[DataType]): UnsafeProjection = {
     create(fields.zipWithIndex.map(x => BoundReference(x._2, x._1, true)))
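A minimal sketch of the usage pattern the new caution implies: because the returned projection is not thread-safe, each thread (or task) should create its own instance rather than share a cached one (the schema and names here are assumptions):

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types._

val schema = StructType(Seq(StructField("value", DoubleType, nullable = false)))

// Create one projection per thread/task; do not reuse a single instance concurrently.
def newProjection(): UnsafeProjection = UnsafeProjection.create(schema)
```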


@@ -511,6 +511,12 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
   * Generates the final aggregation result value for current key group with the aggregation buffer
   * object.
   *
+  * Developer note: the only return types accepted by Spark are:
+  *   - primitive types
+  *   - InternalRow and subclasses
+  *   - ArrayData
+  *   - MapData
+  *
   * @param buffer aggregation buffer object.
   * @return The aggregation result of current key group
   */
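For illustration, a sketch of an `eval` that respects this note by returning Catalyst-internal types, in the spirit of `MetricsAggregate.eval` above (only the relevant method is shown; the buffer type and helper are hypothetical, and a real `TypedImperativeAggregate` needs the remaining overrides as well):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData

// Hypothetical buffer type holding running per-column sums and a row count.
case class MeanBuffer(sums: Array[Double], count: Long)

def eval(buffer: MeanBuffer): InternalRow = {
  // Return an InternalRow wrapping ArrayData and a primitive, not external Rows or Seqs.
  InternalRow(UnsafeArrayData.fromPrimitiveArray(buffer.sums.map(_ / buffer.count)), buffer.count)
}
```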