[SPARK-11578][SQL] User API for Typed Aggregation
This PR adds a new interface for user-defined aggregations that can be used in `DataFrame` and `Dataset` operations to take all of the elements of a group and reduce them to a single value.

For example, the following aggregator extracts an `int` from a specific class and adds them up:

```scala
case class Data(i: Int)

val customSummer = new Aggregator[Data, Int, Int] {
  def zero = 0
  def reduce(b: Int, a: Data) = b + a.i
  def present(r: Int) = r
}.toColumn

val ds: Dataset[Data] = ...
val aggregated = ds.select(customSummer)
```

By using helper functions, users can make a generic `Aggregator` that works on any input type:

```scala
/** An `Aggregator` that adds up any numeric type returned by the given function. */
class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
  val numeric = implicitly[Numeric[N]]
  override def zero: N = numeric.zero
  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
  override def present(reduction: N): N = reduction
}

def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
```

These aggregators can then be used alongside other built-in SQL aggregations.

```scala
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
ds
  .groupBy(_._1)
  .agg(
    sum(_._2),                // The aggregator defined above.
    expr("sum(_2)").as[Int],  // A built-in dynamically typed aggregation.
    count("*"))               // A built-in statically typed aggregation.
  .collect()

res0: ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L)
```

The current implementation focuses on integrating this into the typed API, but currently only supports running aggregations that return a single long value as explained in `TypedAggregateExpression`. This will be improved in a followup PR.

Author: Michael Armbrust <michael@databricks.com>

Closes #9555 from marmbrus/dataset-useragg.
This commit is contained in:
parent
2f38378856
commit
9c740a9ddf
|
@ -23,7 +23,7 @@ import org.apache.spark.annotation.Experimental
|
|||
import org.apache.spark.Logging
|
||||
import org.apache.spark.sql.functions.lit
|
||||
import org.apache.spark.sql.catalyst.analysis._
|
||||
import org.apache.spark.sql.catalyst.encoders.Encoder
|
||||
import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.util.DataTypeParser
|
||||
import org.apache.spark.sql.types._
|
||||
|
@ -39,10 +39,13 @@ private[sql] object Column {
|
|||
}
|
||||
|
||||
/**
|
||||
* A [[Column]] where an [[Encoder]] has been given for the expected return type.
|
||||
* A [[Column]] where an [[Encoder]] has been given for the expected input and return type.
|
||||
* @since 1.6.0
|
||||
* @tparam T The input type expected for this expression. Can be `Any` if the expression is type
|
||||
* checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`).
|
||||
* @tparam U The output type of this column.
|
||||
*/
|
||||
class TypedColumn[T](expr: Expression)(implicit val encoder: Encoder[T]) extends Column(expr)
|
||||
class TypedColumn[-T, U](expr: Expression, val encoder: Encoder[U]) extends Column(expr)
|
||||
|
||||
/**
|
||||
* :: Experimental ::
|
||||
|
@ -85,7 +88,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
|
|||
* results into the correct JVM types.
|
||||
* @since 1.6.0
|
||||
*/
|
||||
def as[T : Encoder]: TypedColumn[T] = new TypedColumn[T](expr)
|
||||
def as[U : Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](expr, encoderFor[U])
|
||||
|
||||
/**
|
||||
* Extracts a value or values from a complex type.
|
||||
|
|
|
@ -358,7 +358,7 @@ class Dataset[T] private[sql](
|
|||
* }}}
|
||||
* @since 1.6.0
|
||||
*/
|
||||
def select[U1: Encoder](c1: TypedColumn[U1]): Dataset[U1] = {
|
||||
def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
|
||||
new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan))
|
||||
}
|
||||
|
||||
|
@ -367,7 +367,7 @@ class Dataset[T] private[sql](
|
|||
* code reuse, we do this without the help of the type system and then use helper functions
|
||||
* that cast appropriately for the user facing interface.
|
||||
*/
|
||||
protected def selectUntyped(columns: TypedColumn[_]*): Dataset[_] = {
|
||||
protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
|
||||
val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() }
|
||||
val unresolvedPlan = Project(aliases, logicalPlan)
|
||||
val execution = new QueryExecution(sqlContext, unresolvedPlan)
|
||||
|
@ -385,7 +385,7 @@ class Dataset[T] private[sql](
|
|||
* Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
|
||||
* @since 1.6.0
|
||||
*/
|
||||
def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] =
|
||||
def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] =
|
||||
selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]]
|
||||
|
||||
/**
|
||||
|
@ -393,9 +393,9 @@ class Dataset[T] private[sql](
|
|||
* @since 1.6.0
|
||||
*/
|
||||
def select[U1, U2, U3](
|
||||
c1: TypedColumn[U1],
|
||||
c2: TypedColumn[U2],
|
||||
c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] =
|
||||
c1: TypedColumn[T, U1],
|
||||
c2: TypedColumn[T, U2],
|
||||
c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] =
|
||||
selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]]
|
||||
|
||||
/**
|
||||
|
@ -403,10 +403,10 @@ class Dataset[T] private[sql](
|
|||
* @since 1.6.0
|
||||
*/
|
||||
def select[U1, U2, U3, U4](
|
||||
c1: TypedColumn[U1],
|
||||
c2: TypedColumn[U2],
|
||||
c3: TypedColumn[U3],
|
||||
c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] =
|
||||
c1: TypedColumn[T, U1],
|
||||
c2: TypedColumn[T, U2],
|
||||
c3: TypedColumn[T, U3],
|
||||
c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] =
|
||||
selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]]
|
||||
|
||||
/**
|
||||
|
@ -414,11 +414,11 @@ class Dataset[T] private[sql](
|
|||
* @since 1.6.0
|
||||
*/
|
||||
def select[U1, U2, U3, U4, U5](
|
||||
c1: TypedColumn[U1],
|
||||
c2: TypedColumn[U2],
|
||||
c3: TypedColumn[U3],
|
||||
c4: TypedColumn[U4],
|
||||
c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] =
|
||||
c1: TypedColumn[T, U1],
|
||||
c2: TypedColumn[T, U2],
|
||||
c3: TypedColumn[T, U3],
|
||||
c4: TypedColumn[T, U4],
|
||||
c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] =
|
||||
selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]]
|
||||
|
||||
/* **************** *
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.sql
|
||||
|
||||
import java.util.{Iterator => JIterator}
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import org.apache.spark.annotation.Experimental
|
||||
|
@ -26,8 +27,10 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttrib
|
|||
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder}
|
||||
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute}
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
|
||||
import org.apache.spark.sql.execution.QueryExecution
|
||||
|
||||
|
||||
/**
|
||||
* :: Experimental ::
|
||||
* A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not
|
||||
|
@ -143,7 +146,7 @@ class GroupedDataset[K, T] private[sql](
|
|||
* that cast appropriately for the user facing interface.
|
||||
* TODO: does not handle aggregations that return non-flat results.
|
||||
*/
|
||||
protected def aggUntyped(columns: TypedColumn[_]*): Dataset[_] = {
|
||||
protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
|
||||
val aliases = (groupingAttributes ++ columns.map(_.expr)).map {
|
||||
case u: UnresolvedAttribute => UnresolvedAlias(u)
|
||||
case expr: NamedExpression => expr
|
||||
|
@ -151,7 +154,15 @@ class GroupedDataset[K, T] private[sql](
|
|||
}
|
||||
|
||||
val unresolvedPlan = Aggregate(groupingAttributes, aliases, logicalPlan)
|
||||
val execution = new QueryExecution(sqlContext, unresolvedPlan)
|
||||
|
||||
// Fill in the input encoders for any aggregators in the plan.
|
||||
val withEncoders = unresolvedPlan transformAllExpressions {
|
||||
case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
|
||||
ta.copy(
|
||||
aEncoder = Some(tEnc.asInstanceOf[ExpressionEncoder[Any]]),
|
||||
children = dataAttributes)
|
||||
}
|
||||
val execution = new QueryExecution(sqlContext, withEncoders)
|
||||
|
||||
val columnEncoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]])
|
||||
|
||||
|
@ -162,43 +173,47 @@ class GroupedDataset[K, T] private[sql](
|
|||
case (e, a) =>
|
||||
e.unbind(a :: Nil).resolve(execution.analyzed.output)
|
||||
}
|
||||
new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
|
||||
|
||||
new Dataset(
|
||||
sqlContext,
|
||||
execution,
|
||||
ExpressionEncoder.tuple(encoders))
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key
|
||||
* and the result of computing this aggregation over all elements in the group.
|
||||
*/
|
||||
def agg[A1](col1: TypedColumn[A1]): Dataset[(K, A1)] =
|
||||
aggUntyped(col1).asInstanceOf[Dataset[(K, A1)]]
|
||||
def agg[U1](col1: TypedColumn[T, U1]): Dataset[(K, U1)] =
|
||||
aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]]
|
||||
|
||||
/**
|
||||
* Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
|
||||
* and the result of computing these aggregations over all elements in the group.
|
||||
*/
|
||||
def agg[A1, A2](col1: TypedColumn[A1], col2: TypedColumn[A2]): Dataset[(K, A1, A2)] =
|
||||
aggUntyped(col1, col2).asInstanceOf[Dataset[(K, A1, A2)]]
|
||||
def agg[U1, U2](col1: TypedColumn[T, U1], col2: TypedColumn[T, U2]): Dataset[(K, U1, U2)] =
|
||||
aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]]
|
||||
|
||||
/**
|
||||
* Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
|
||||
* and the result of computing these aggregations over all elements in the group.
|
||||
*/
|
||||
def agg[A1, A2, A3](
|
||||
col1: TypedColumn[A1],
|
||||
col2: TypedColumn[A2],
|
||||
col3: TypedColumn[A3]): Dataset[(K, A1, A2, A3)] =
|
||||
aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, A1, A2, A3)]]
|
||||
def agg[U1, U2, U3](
|
||||
col1: TypedColumn[T, U1],
|
||||
col2: TypedColumn[T, U2],
|
||||
col3: TypedColumn[T, U3]): Dataset[(K, U1, U2, U3)] =
|
||||
aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]]
|
||||
|
||||
/**
|
||||
* Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
|
||||
* and the result of computing these aggregations over all elements in the group.
|
||||
*/
|
||||
def agg[A1, A2, A3, A4](
|
||||
col1: TypedColumn[A1],
|
||||
col2: TypedColumn[A2],
|
||||
col3: TypedColumn[A3],
|
||||
col4: TypedColumn[A4]): Dataset[(K, A1, A2, A3, A4)] =
|
||||
aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, A1, A2, A3, A4)]]
|
||||
def agg[U1, U2, U3, U4](
|
||||
col1: TypedColumn[T, U1],
|
||||
col2: TypedColumn[T, U2],
|
||||
col3: TypedColumn[T, U3],
|
||||
col4: TypedColumn[T, U4]): Dataset[(K, U1, U2, U3, U4)] =
|
||||
aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]]
|
||||
|
||||
/**
|
||||
* Returns a [[Dataset]] that contains a tuple with each key and the number of items present
|
||||
|
|
|
@ -21,7 +21,6 @@ import java.beans.{BeanInfo, Introspector}
|
|||
import java.util.Properties
|
||||
import java.util.concurrent.atomic.AtomicReference
|
||||
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.immutable
|
||||
import scala.reflect.runtime.universe.TypeTag
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution.aggregate
|
||||
|
||||
import scala.language.existentials
|
||||
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
|
||||
import org.apache.spark.sql.expressions.Aggregator
|
||||
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.types.{StructType, DataType}
|
||||
|
||||
object TypedAggregateExpression {

  /**
   * Wraps a user-defined [[Aggregator]] in a [[TypedAggregateExpression]].
   *
   * The aggregator and both encoders are cast to their `Any`-typed forms. The input encoder
   * is left as `None`, the child list is empty and both buffer offsets start at `0`.
   *
   * @param aggregator the aggregation to wrap.
   * @tparam A the aggregator's input type.
   * @tparam B the aggregator's intermediate type; an [[Encoder]] must be in implicit scope.
   * @tparam C the aggregator's result type; an [[Encoder]] must be in implicit scope.
   */
  def apply[A, B : Encoder, C : Encoder](
      aggregator: Aggregator[A, B, C]): TypedAggregateExpression = {
    val erasedAggregator = aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
    val bufferEncoder = encoderFor[B].asInstanceOf[ExpressionEncoder[Any]]
    val resultEncoder = encoderFor[C].asInstanceOf[ExpressionEncoder[Any]]

    new TypedAggregateExpression(
      aggregator = erasedAggregator,
      aEncoder = None,
      bEncoder = bufferEncoder,
      cEncoder = resultEncoder,
      children = Nil,
      mutableAggBufferOffset = 0,
      inputAggBufferOffset = 0)
  }
}
|
||||
|
||||
/**
 * This class is a rough sketch of how to hook `Aggregator` into the Aggregation system by
 * exposing it as an `ImperativeAggregate`. It has the following limitations:
 *  - The reduction is written to, and read back from, a single `int` slot of the aggregation
 *    buffer (`setInt` / `getInt` at the buffer offset).
 *  - It might only work when there is a single aggregator in the first column.
 *  - `initialize` casts `aggregator.zero` to `Int`.
 *  - `merge` combines partial buffers by adding them rather than by asking the aggregator.
 *  - `eval` returns the raw buffer value; `aggregator.present` is not applied.
 *
 * @param aggregator the user-defined aggregation, typed as `Aggregator[Any, Any, Any]`.
 * @param aEncoder encoder for the aggregator's input type. While it is `None` this expression
 *                 reports itself as unresolved.
 * @param bEncoder encoder for the intermediate value; its schema is used as the buffer schema.
 * @param cEncoder encoder for the result; the first field of its schema gives `dataType`.
 * @param children expressions that supply the input, zipped positionally with the attributes
 *                 of `aEncoder`'s schema.
 * @param mutableAggBufferOffset offset of this function's slot in the mutable buffer.
 * @param inputAggBufferOffset offset of this function's slot in the input buffer used by
 *                             `merge`.
 */
case class TypedAggregateExpression(
    aggregator: Aggregator[Any, Any, Any],
    aEncoder: Option[ExpressionEncoder[Any]],
    bEncoder: ExpressionEncoder[Any],
    cEncoder: ExpressionEncoder[Any],
    children: Seq[Expression],
    mutableAggBufferOffset: Int,
    inputAggBufferOffset: Int)
  extends ImperativeAggregate with Logging {

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def nullable: Boolean = true

  // TODO: this assumes flat results...
  override def dataType: DataType = cEncoder.schema.head.dataType

  override def deterministic: Boolean = true

  // The expression only counts as resolved once an input encoder has been supplied.
  override lazy val resolved: Boolean = aEncoder.isDefined

  override lazy val inputTypes: Seq[DataType] =
    aEncoder.map(_.schema.map(_.dataType)).getOrElse(Nil)

  override val aggBufferSchema: StructType = bEncoder.schema

  override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes

  // Note: although this simply copies aggBufferAttributes, this common code can not be placed
  // in the superclass because that will lead to initialization ordering issues.
  override val inputAggBufferAttributes: Seq[AttributeReference] =
    aggBufferAttributes.map(_.newInstance())

  // The three members below call `aEncoder.get`, so they must stay lazy: they may only be
  // forced after the input encoder has been filled in.
  lazy val inputAttributes = aEncoder.get.schema.toAttributes
  lazy val inputMapping = AttributeMap(inputAttributes.zip(children))
  // Input encoder whose construct expression has every attribute reference replaced by the
  // child expression mapped to it.
  lazy val boundA =
    aEncoder.get.copy(constructExpression = aEncoder.get.constructExpression transform {
      case a: AttributeReference => inputMapping(a)
    })

  // TODO: this probably only works when we are in the first column.
  val bAttributes = bEncoder.schema.toAttributes
  lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes)

  /** Stores the aggregator's zero value, cast to `Int`, in this function's buffer slot. */
  override def initialize(buffer: MutableRow): Unit = {
    // TODO: We need to either force Aggregator to have a zero or we need to eliminate the need for
    // this in execution.
    buffer.setInt(mutableAggBufferOffset, aggregator.zero.asInstanceOf[Int])
  }

  /**
   * Decodes the input row and the current buffer, applies `aggregator.reduce`, and writes the
   * first (int) field of the re-encoded result back into the buffer slot.
   */
  override def update(buffer: MutableRow, input: InternalRow): Unit = {
    val inputA = boundA.fromRow(input)
    val currentB = boundB.fromRow(buffer)
    val merged = aggregator.reduce(currentB, inputA)
    val returned = boundB.toRow(merged)
    buffer.setInt(mutableAggBufferOffset, returned.getInt(0))
  }

  /**
   * Adds the partial result in `buffer2` to the one in `buffer1`.
   *
   * The slot is accessed as an `int`, matching how `initialize`, `update` and `eval` write and
   * read it; long accessors on the same slot would disagree with them.
   */
  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
    buffer1.setInt(
      mutableAggBufferOffset,
      buffer1.getInt(mutableAggBufferOffset) + buffer2.getInt(inputAggBufferOffset))
  }

  /** Returns the int currently stored in this function's buffer slot. */
  override def eval(buffer: InternalRow): Any = {
    buffer.getInt(mutableAggBufferOffset)
  }

  override def toString: String = {
    s"""${aggregator.getClass.getSimpleName}(${children.mkString(",")})"""
  }

  override def nodeName: String = aggregator.getClass.getSimpleName
}
|
|
@ -0,0 +1,81 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.expressions
|
||||
|
||||
import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2}
|
||||
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
|
||||
import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn}
|
||||
|
||||
/**
 * A base class for user-defined aggregations, which can be used in [[DataFrame]] and [[Dataset]]
 * operations to take all of the elements of a group and reduce them to a single value.
 *
 * The class is `Serializable` so that user-defined aggregators, including anonymous ones such
 * as the example below, can be serialized together with the expression that wraps them.
 *
 * For example, the following aggregator extracts an `int` from a specific class and adds them up:
 * {{{
 *   case class Data(i: Int)
 *
 *   val customSummer = new Aggregator[Data, Int, Int] {
 *     def zero = 0
 *     def reduce(b: Int, a: Data) = b + a.i
 *     def present(r: Int) = r
 *   }.toColumn
 *
 *   val ds: Dataset[Data] = ...
 *   val aggregated = ds.select(customSummer)
 * }}}
 *
 * Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird
 *
 * @tparam A The input type for the aggregation.
 * @tparam B The type of the intermediate value of the reduction.
 * @tparam C The type of the final result.
 */
abstract class Aggregator[-A, B, C] extends Serializable {

  /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
  def zero: B

  /**
   * Combine two values to produce a new value. For performance, the function may modify `b` and
   * return it instead of constructing a new object for b.
   */
  def reduce(b: B, a: A): B

  /**
   * Transform the output of the reduction.
   */
  def present(reduction: B): C

  /**
   * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]] or [[DataFrame]]
   * operations.
   *
   * Both encoders are taken implicitly, so call this without an argument list
   * (`aggregator.toColumn`).
   *
   * @param bEncoder encoder for the intermediate value type `B`.
   * @param cEncoder encoder for the result type `C`; also attached to the returned column.
   */
  def toColumn(
      implicit bEncoder: Encoder[B],
      cEncoder: Encoder[C]): TypedColumn[A, C] = {
    val expr =
      new AggregateExpression2(
        TypedAggregateExpression(this),
        Complete,
        false) // NOTE(review): presumably the "is distinct" flag; confirm against AggregateExpression2

    new TypedColumn[A, C](expr, encoderFor[C])
  }
}
|
||||
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
package org.apache.spark.sql
|
||||
|
||||
|
||||
|
||||
import scala.language.implicitConversions
|
||||
import scala.reflect.runtime.universe.{TypeTag, typeTag}
|
||||
import scala.util.Try
|
||||
|
@ -24,11 +26,32 @@ import scala.util.Try
|
|||
import org.apache.spark.annotation.Experimental
|
||||
import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
|
||||
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
|
||||
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, Encoder}
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
/**
|
||||
* Ensures that Java function signatures for methods that now return a [[TypedColumn]] still have
|
||||
* legacy equivalents in bytecode. This compatibility is done by forcing the compiler to generate
|
||||
* "bridge" methods due to the use of covariant return types.
|
||||
*
|
||||
* {{{
|
||||
* In LegacyFunctions:
|
||||
* public abstract org.apache.spark.sql.Column avg(java.lang.String);
|
||||
*
|
||||
* In functions:
|
||||
* public static org.apache.spark.sql.TypedColumn<java.lang.Object, java.lang.Object> avg(...);
|
||||
* }}}
|
||||
*
|
||||
* This allows us to use the same functions both in typed [[Dataset]] operations and untyped
|
||||
* [[DataFrame]] operations when the return type for a given function is statically known.
|
||||
*/
|
||||
private[sql] abstract class LegacyFunctions {
|
||||
// Declared with the plain `Column` return type. The `functions` object extends this class and
// declares `count(columnName)` with the narrower `TypedColumn[Any, Long]` return type.
def count(columnName: String): Column
|
||||
}
|
||||
|
||||
/**
|
||||
* :: Experimental ::
|
||||
* Functions available for [[DataFrame]].
|
||||
|
@ -48,11 +71,14 @@ import org.apache.spark.util.Utils
|
|||
*/
|
||||
@Experimental
|
||||
// scalastyle:off
|
||||
object functions {
|
||||
object functions extends LegacyFunctions {
|
||||
// scalastyle:on
|
||||
|
||||
private def withExpr(expr: Expression): Column = Column(expr)
|
||||
|
||||
private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
|
||||
|
||||
|
||||
/**
|
||||
* Returns a [[Column]] based on the given column name.
|
||||
*
|
||||
|
@ -234,7 +260,7 @@ object functions {
|
|||
* @group agg_funcs
|
||||
* @since 1.3.0
|
||||
*/
|
||||
def count(columnName: String): Column = count(Column(columnName))
|
||||
def count(columnName: String): TypedColumn[Any, Long] = count(Column(columnName)).as[Long]
|
||||
|
||||
/**
|
||||
* Aggregate function: returns the number of distinct items in a group.
|
||||
|
|
|
@ -258,8 +258,8 @@ public class JavaDatasetSuite implements Serializable {
|
|||
Dataset<Integer> ds = context.createDataset(data, e.INT());
|
||||
|
||||
Dataset<Tuple2<Integer, String>> selected = ds.select(
|
||||
expr("value + 1").as(e.INT()),
|
||||
col("value").cast("string").as(e.STRING()));
|
||||
expr("value + 1"),
|
||||
col("value").cast("string")).as(e.tuple(e.INT(), e.STRING()));
|
||||
|
||||
Assert.assertEquals(
|
||||
Arrays.asList(tuple2(3, "2"), tuple2(7, "6")),
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.spark.sql.catalyst.encoders.Encoder
|
||||
import org.apache.spark.sql.functions._
|
||||
|
||||
import scala.language.postfixOps
|
||||
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
import org.apache.spark.sql.expressions.Aggregator
|
||||
|
||||
/**
 * An `Aggregator` that sums the numeric value extracted from each input by `f`.
 *
 * @param f extracts the number to add from an input element.
 * @tparam I the input element type.
 * @tparam N the numeric type being summed; also the intermediate and result type.
 */
class SumOf[I, N](f: I => N)(implicit ev: Numeric[N])
  extends Aggregator[I, N, N] with Serializable {

  /** The arithmetic used for `N`, taken from the implicit scope at construction time. */
  val numeric = ev

  /** The additive identity of `N`. */
  override def zero: N = numeric.zero

  /** Adds the value extracted from `a` to the running total `b`. */
  override def reduce(b: N, a: I): N = {
    val extracted = f(a)
    numeric.plus(b, extracted)
  }

  /** The running total is already the final result. */
  override def present(reduction: N): N = reduction
}
|
||||
|
||||
/** Tests user-defined `Aggregator`s in grouped `Dataset` aggregations. */
class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {

  import testImplicits._

  /** Builds a typed column that sums the numeric value `f` extracts from each input. */
  def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = {
    val aggregator = new SumOf(f)
    aggregator.toColumn
  }

  test("typed aggregation: TypedAggregator") {
    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

    // A `def` keeps the aggregation expression unevaluated until `checkAnswer` uses it.
    def summedByKey = ds.groupBy(_._1).agg(sum(_._2))
    val expected = Seq(("a", 30), ("b", 3), ("c", 1))

    checkAnswer(summedByKey, expected: _*)
  }

  test("typed aggregation: TypedAggregator, expr, expr") {
    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

    // Mixes the user-defined aggregator with an untyped SQL expression and a typed built-in.
    def aggregatedByKey =
      ds.groupBy(_._1).agg(
        sum(_._2),
        expr("sum(_2)").as[Int],
        count("*"))
    val expected = Seq(("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L))

    checkAnswer(aggregatedByKey, expected: _*)
  }
}
|
Loading…
Reference in a new issue