[SPARK-11578][SQL] User API for Typed Aggregation

This PR adds a new interface for user-defined aggregations that can be used in `DataFrame` and `Dataset` operations to take all of the elements of a group and reduce them to a single value.

For example, the following aggregator extracts an `int` from each element of a specific class and adds them up:

```scala
  case class Data(i: Int)

  val customSummer = new Aggregator[Data, Int, Int] {
    def zero = 0
    def reduce(b: Int, a: Data) = b + a.i
    def present(r: Int) = r
  }.toColumn()

  val ds: Dataset[Data] = ...
  val aggregated = ds.select(customSummer)
```

By using helper functions, users can make a generic `Aggregator` that works on any input type:

```scala
/** An `Aggregator` that adds up any numeric type returned by the given function. */
class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
  val numeric = implicitly[Numeric[N]]
  override def zero: N = numeric.zero
  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
  override def present(reduction: N): N = reduction
}

def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
```

These aggregators can then be used alongside other built-in SQL aggregations.

```scala
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
ds
  .groupBy(_._1)
  .agg(
    sum(_._2),                // The aggregator defined above.
    expr("sum(_2)").as[Int],  // A built-in dynatically typed aggregation.
    count("*"))               // A built-in statically typed aggregation.
  .collect()

res0: ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L)
```

The current implementation focuses on integrating this into the typed API, but for now it only supports aggregations that return a single long value, as explained in `TypedAggregateExpression`. This will be improved in a follow-up PR.
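
To make that limitation concrete, here is a hedged sketch of an aggregator with a non-flat result that is *not* expected to work until the follow-up lands (the `MinMax` type and `minMax` value are hypothetical, and an implicit encoder for the case class is assumed to be in scope):

```scala
// Hypothetical illustration of the current limitation; not part of this PR's tests.
case class MinMax(min: Int, max: Int)

val minMax = new Aggregator[Data, MinMax, MinMax] {
  def zero = MinMax(Int.MaxValue, Int.MinValue)
  def reduce(b: MinMax, a: Data) = MinMax(math.min(b.min, a.i), math.max(b.max, a.i))
  def present(r: MinMax) = r
}.toColumn  // requires an implicit Encoder[MinMax]; non-flat results are not supported yet
```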

Author: Michael Armbrust <michael@databricks.com>

Closes #9555 from marmbrus/dataset-useragg.
Michael Armbrust 2015-11-09 16:11:00 -08:00
parent 2f38378856
commit 9c740a9ddf
9 changed files with 360 additions and 42 deletions

View file

@ -23,7 +23,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.encoders.Encoder
import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DataTypeParser
import org.apache.spark.sql.types._
@ -39,10 +39,13 @@ private[sql] object Column {
}
/**
* A [[Column]] where an [[Encoder]] has been given for the expected return type.
* A [[Column]] where an [[Encoder]] has been given for the expected input and return type.
* @since 1.6.0
* @tparam T The input type expected for this expression. Can be `Any` if the expression is type
* checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`).
* @tparam U The output type of this column.
*/
class TypedColumn[T](expr: Expression)(implicit val encoder: Encoder[T]) extends Column(expr)
class TypedColumn[-T, U](expr: Expression, val encoder: Encoder[U]) extends Column(expr)
/**
* :: Experimental ::
@ -85,7 +88,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
* results into the correct JVM types.
* @since 1.6.0
*/
def as[T : Encoder]: TypedColumn[T] = new TypedColumn[T](expr)
def as[U : Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](expr, encoderFor[U])
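
For illustration, a sketch of how the two type parameters get filled in under the new signature (this reuses the `sum` helper defined in the new test suite and assumes the standard encoder implicits are imported):

```scala
// Input type deferred to the analyzer: T = Any
val dynamic: TypedColumn[Any, Int] = expr("sum(_2)").as[Int]

// Input type checked by the compiler: T = (String, Int)
val typed: TypedColumn[(String, Int), Int] = sum((p: (String, Int)) => p._2)
```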
/**
* Extracts a value or values from a complex type.

View file

@ -358,7 +358,7 @@ class Dataset[T] private[sql](
* }}}
* @since 1.6.0
*/
def select[U1: Encoder](c1: TypedColumn[U1]): Dataset[U1] = {
def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan))
}
@ -367,7 +367,7 @@ class Dataset[T] private[sql](
* code reuse, we do this without the help of the type system and then use helper functions
* that cast appropriately for the user facing interface.
*/
protected def selectUntyped(columns: TypedColumn[_]*): Dataset[_] = {
protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() }
val unresolvedPlan = Project(aliases, logicalPlan)
val execution = new QueryExecution(sqlContext, unresolvedPlan)
@ -385,7 +385,7 @@ class Dataset[T] private[sql](
* Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
* @since 1.6.0
*/
def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] =
def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] =
selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]]
/**
@ -393,9 +393,9 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def select[U1, U2, U3](
c1: TypedColumn[U1],
c2: TypedColumn[U2],
c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] =
c1: TypedColumn[T, U1],
c2: TypedColumn[T, U2],
c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] =
selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]]
/**
@ -403,10 +403,10 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def select[U1, U2, U3, U4](
c1: TypedColumn[U1],
c2: TypedColumn[U2],
c3: TypedColumn[U3],
c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] =
c1: TypedColumn[T, U1],
c2: TypedColumn[T, U2],
c3: TypedColumn[T, U3],
c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] =
selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]]
/**
@ -414,11 +414,11 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def select[U1, U2, U3, U4, U5](
c1: TypedColumn[U1],
c2: TypedColumn[U2],
c3: TypedColumn[U3],
c4: TypedColumn[U4],
c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] =
c1: TypedColumn[T, U1],
c2: TypedColumn[T, U2],
c3: TypedColumn[T, U3],
c4: TypedColumn[T, U4],
c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] =
selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]]
/* **************** *

View file

@ -18,6 +18,7 @@
package org.apache.spark.sql
import java.util.{Iterator => JIterator}
import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
@ -26,8 +27,10 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttrib
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder}
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.execution.QueryExecution
/**
* :: Experimental ::
* A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not
@ -143,7 +146,7 @@ class GroupedDataset[K, T] private[sql](
* that cast appropriately for the user facing interface.
* TODO: does not handle aggregations that return non-flat results.
*/
protected def aggUntyped(columns: TypedColumn[_]*): Dataset[_] = {
protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
val aliases = (groupingAttributes ++ columns.map(_.expr)).map {
case u: UnresolvedAttribute => UnresolvedAlias(u)
case expr: NamedExpression => expr
@ -151,7 +154,15 @@ class GroupedDataset[K, T] private[sql](
}
val unresolvedPlan = Aggregate(groupingAttributes, aliases, logicalPlan)
val execution = new QueryExecution(sqlContext, unresolvedPlan)
// Fill in the input encoders for any aggregators in the plan.
val withEncoders = unresolvedPlan transformAllExpressions {
case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
ta.copy(
aEncoder = Some(tEnc.asInstanceOf[ExpressionEncoder[Any]]),
children = dataAttributes)
}
val execution = new QueryExecution(sqlContext, withEncoders)
val columnEncoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]])
@ -162,43 +173,47 @@ class GroupedDataset[K, T] private[sql](
case (e, a) =>
e.unbind(a :: Nil).resolve(execution.analyzed.output)
}
new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
new Dataset(
sqlContext,
execution,
ExpressionEncoder.tuple(encoders))
}
/**
* Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key
* and the result of computing this aggregation over all elements in the group.
*/
def agg[A1](col1: TypedColumn[A1]): Dataset[(K, A1)] =
aggUntyped(col1).asInstanceOf[Dataset[(K, A1)]]
def agg[U1](col1: TypedColumn[T, U1]): Dataset[(K, U1)] =
aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]]
/**
* Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
* and the result of computing these aggregations over all elements in the group.
*/
def agg[A1, A2](col1: TypedColumn[A1], col2: TypedColumn[A2]): Dataset[(K, A1, A2)] =
aggUntyped(col1, col2).asInstanceOf[Dataset[(K, A1, A2)]]
def agg[U1, U2](col1: TypedColumn[T, U1], col2: TypedColumn[T, U2]): Dataset[(K, U1, U2)] =
aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]]
/**
* Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
* and the result of computing these aggregations over all elements in the group.
*/
def agg[A1, A2, A3](
col1: TypedColumn[A1],
col2: TypedColumn[A2],
col3: TypedColumn[A3]): Dataset[(K, A1, A2, A3)] =
aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, A1, A2, A3)]]
def agg[U1, U2, U3](
col1: TypedColumn[T, U1],
col2: TypedColumn[T, U2],
col3: TypedColumn[T, U3]): Dataset[(K, U1, U2, U3)] =
aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]]
/**
* Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
* and the result of computing these aggregations over all elements in the group.
*/
def agg[A1, A2, A3, A4](
col1: TypedColumn[A1],
col2: TypedColumn[A2],
col3: TypedColumn[A3],
col4: TypedColumn[A4]): Dataset[(K, A1, A2, A3, A4)] =
aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, A1, A2, A3, A4)]]
def agg[U1, U2, U3, U4](
col1: TypedColumn[T, U1],
col2: TypedColumn[T, U2],
col3: TypedColumn[T, U3],
col4: TypedColumn[T, U4]): Dataset[(K, U1, U2, U3, U4)] =
aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]]
/**
* Returns a [[Dataset]] that contains a tuple with each key and the number of items present

View file

@ -21,7 +21,6 @@ import java.beans.{BeanInfo, Introspector}
import java.util.Properties
import java.util.concurrent.atomic.AtomicReference
import scala.collection.JavaConverters._
import scala.collection.immutable
import scala.reflect.runtime.universe.TypeTag

View file

@ -0,0 +1,129 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.aggregate
import scala.language.existentials
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{StructType, DataType}
object TypedAggregateExpression {
def apply[A, B : Encoder, C : Encoder](
aggregator: Aggregator[A, B, C]): TypedAggregateExpression = {
new TypedAggregateExpression(
aggregator.asInstanceOf[Aggregator[Any, Any, Any]],
None,
encoderFor[B].asInstanceOf[ExpressionEncoder[Any]],
encoderFor[C].asInstanceOf[ExpressionEncoder[Any]],
Nil,
0,
0)
}
}
/**
* This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has
* the following limitations:
* - It assumes the aggregator reduces and returns a single column of type `long`.
* - It might only work when there is a single aggregator in the first column.
* - It assumes the aggregator has a zero, `0`.
*/
case class TypedAggregateExpression(
aggregator: Aggregator[Any, Any, Any],
aEncoder: Option[ExpressionEncoder[Any]],
bEncoder: ExpressionEncoder[Any],
cEncoder: ExpressionEncoder[Any],
children: Seq[Expression],
mutableAggBufferOffset: Int,
inputAggBufferOffset: Int)
extends ImperativeAggregate with Logging {
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)
override def nullable: Boolean = true
// TODO: this assumes flat results...
override def dataType: DataType = cEncoder.schema.head.dataType
override def deterministic: Boolean = true
override lazy val resolved: Boolean = aEncoder.isDefined
override lazy val inputTypes: Seq[DataType] =
aEncoder.map(_.schema.map(_.dataType)).getOrElse(Nil)
override val aggBufferSchema: StructType = bEncoder.schema
override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes
// Note: although this simply copies aggBufferAttributes, this common code can not be placed
// in the superclass because that will lead to initialization ordering issues.
override val inputAggBufferAttributes: Seq[AttributeReference] =
aggBufferAttributes.map(_.newInstance())
lazy val inputAttributes = aEncoder.get.schema.toAttributes
lazy val inputMapping = AttributeMap(inputAttributes.zip(children))
lazy val boundA =
aEncoder.get.copy(constructExpression = aEncoder.get.constructExpression transform {
case a: AttributeReference => inputMapping(a)
})
// TODO: this probably only works when we are in the first column.
val bAttributes = bEncoder.schema.toAttributes
lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes)
override def initialize(buffer: MutableRow): Unit = {
// TODO: We need to either force Aggregator to have a zero or we need to eliminate the need for
// this in execution.
buffer.setInt(mutableAggBufferOffset, aggregator.zero.asInstanceOf[Int])
}
override def update(buffer: MutableRow, input: InternalRow): Unit = {
val inputA = boundA.fromRow(input)
val currentB = boundB.fromRow(buffer)
val merged = aggregator.reduce(currentB, inputA)
val returned = boundB.toRow(merged)
buffer.setInt(mutableAggBufferOffset, returned.getInt(0))
}
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
buffer1.setLong(
mutableAggBufferOffset,
buffer1.getLong(mutableAggBufferOffset) + buffer2.getLong(inputAggBufferOffset))
}
override def eval(buffer: InternalRow): Any = {
buffer.getInt(mutableAggBufferOffset)
}
override def toString: String = {
s"""${aggregator.getClass.getSimpleName}(${children.mkString(",")})"""
}
override def nodeName: String = aggregator.getClass.getSimpleName
}

View file

@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.expressions
import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2}
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn}
/**
* A base class for user-defined aggregations, which can be used in [[DataFrame]] and [[Dataset]]
* operations to take all of the elements of a group and reduce them to a single value.
*
* For example, the following aggregator extracts an `int` from a specific class and adds them up:
* {{{
* case class Data(i: Int)
*
* val customSummer = new Aggregator[Data, Int, Int] {
* def zero = 0
* def reduce(b: Int, a: Data) = b + a.i
* def present(r: Int) = r
* }.toColumn()
*
* val ds: Dataset[Data]
* val aggregated = ds.select(customSummer)
* }}}
*
* Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird
*
* @tparam A The input type for the aggregation.
* @tparam B The type of the intermediate value of the reduction.
* @tparam C The type of the final result.
*/
abstract class Aggregator[-A, B, C] {
/** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
def zero: B
/**
* Combine two values to produce a new value. For performance, the function may modify `b` and
* return it instead of constructing a new object for b.
*/
def reduce(b: B, a: A): B
/**
* Transform the output of the reduction.
*/
def present(reduction: B): C
/**
* Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]] or [[DataFrame]]
* operations.
*/
def toColumn(
implicit bEncoder: Encoder[B],
cEncoder: Encoder[C]): TypedColumn[A, C] = {
val expr =
new AggregateExpression2(
TypedAggregateExpression(this),
Complete,
false)
new TypedColumn[A, C](expr, encoderFor[C])
}
}

View file

@ -17,6 +17,8 @@
package org.apache.spark.sql
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}
import scala.util.Try
@ -24,11 +26,32 @@ import scala.util.Try
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, Encoder}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
/**
* Ensures that Java function signatures for methods that now return a [[TypedColumn]] still have
* legacy equivalents in bytecode. This compatibility is done by forcing the compiler to generate
* "bridge" methods due to the use of covariant return types.
*
* {{{
* In LegacyFunctions:
* public abstract org.apache.spark.sql.Column avg(java.lang.String);
*
* In functions:
* public static org.apache.spark.sql.TypedColumn<java.lang.Object, java.lang.Object> avg(...);
* }}}
*
* This allows us to use the same functions both in typed [[Dataset]] operations and untyped
* [[DataFrame]] operations when the return type for a given function is statically known.
*/
private[sql] abstract class LegacyFunctions {
def count(columnName: String): Column
}
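
A minimal, Spark-independent sketch of the bridge-method mechanism described above (the names here are illustrative only):

```scala
abstract class Legacy { def answer(): Object }

class Modern extends Legacy {
  // Narrowing the return type is a covariant override; the compiler also emits
  // a synthetic bridge method `answer(): Object` so callers compiled against
  // the old signature still link.
  override def answer(): String = "42"
}
```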
/**
* :: Experimental ::
* Functions available for [[DataFrame]].
@ -48,11 +71,14 @@ import org.apache.spark.util.Utils
*/
@Experimental
// scalastyle:off
object functions {
object functions extends LegacyFunctions {
// scalastyle:on
private def withExpr(expr: Expression): Column = Column(expr)
private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
/**
* Returns a [[Column]] based on the given column name.
*
@ -234,7 +260,7 @@ object functions {
* @group agg_funcs
* @since 1.3.0
*/
def count(columnName: String): Column = count(Column(columnName))
def count(columnName: String): TypedColumn[Any, Long] = count(Column(columnName)).as[Long]
/**
* Aggregate function: returns the number of distinct items in a group.

View file

@ -258,8 +258,8 @@ public class JavaDatasetSuite implements Serializable {
Dataset<Integer> ds = context.createDataset(data, e.INT());
Dataset<Tuple2<Integer, String>> selected = ds.select(
expr("value + 1").as(e.INT()),
col("value").cast("string").as(e.STRING()));
expr("value + 1"),
col("value").cast("string")).as(e.tuple(e.INT(), e.STRING()));
Assert.assertEquals(
Arrays.asList(tuple2(3, "2"), tuple2(7, "6")),

View file

@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.sql.catalyst.encoders.Encoder
import org.apache.spark.sql.functions._
import scala.language.postfixOps
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.expressions.Aggregator
/** An `Aggregator` that adds up any numeric type returned by the given function. */
class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
val numeric = implicitly[Numeric[N]]
override def zero: N = numeric.zero
override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
override def present(reduction: N): N = reduction
}
class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
import testImplicits._
def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] =
new SumOf(f).toColumn
test("typed aggregation: TypedAggregator") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
checkAnswer(
ds.groupBy(_._1).agg(sum(_._2)),
("a", 30), ("b", 3), ("c", 1))
}
test("typed aggregation: TypedAggregator, expr, expr") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
checkAnswer(
ds.groupBy(_._1).agg(
sum(_._2),
expr("sum(_2)").as[Int],
count("*")),
("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L))
}
}