[SPARK-11404] [SQL] Support for groupBy using column expressions
This PR adds a new method `groupBy(cols: Column*)` to `Dataset` that allows users to group using column expressions instead of a lambda function. Since the return type of these expressions is not known at compile time, the key type is simply set to a generic `Row`. Users who would like to work with the key in a type-safe way can call `grouped.asKey[Type]`, which is also added in this PR.

```scala
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
val grouped = ds.groupBy($"_1").asKey[String]
val agged = grouped.mapGroups { case (g, iter) =>
  Iterator((g, iter.map(_._2).sum))
}

agged.collect()

res0: Array(("a", 30), ("b", 3), ("c", 1))
```

Author: Michael Armbrust <michael@databricks.com>

Closes #9359 from marmbrus/columnGroupBy and squashes the following commits:

bbcb03b [Michael Armbrust] Update DatasetSuite.scala
8fd2908 [Michael Armbrust] Update DatasetSuite.scala
0b0e2f8 [Michael Armbrust] [SPARK-11404] [SQL] Support for groupBy using column expressions
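For reference, without `asKey` the grouping key stays a generic `Row` and must be read positionally. A minimal sketch of that untyped path, mirroring the test added in this PR (assumes the usual `sqlContext.implicits._` import):

```scala
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
val grouped = ds.groupBy($"_1")  // GroupedDataset[Row, (String, Int)]
val agged = grouped.mapGroups { case (g, iter) =>
  Iterator((g.getString(0), iter.map(_._2).sum))  // g is an untyped Row
}
```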
parent 425ff03f5a · commit b86f2cab67
Dataset.scala:

```diff
@@ -19,6 +19,7 @@ package org.apache.spark.sql

 import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.Inner
```
```diff
@@ -78,9 +79,17 @@ class Dataset[T] private(
    * ************* */

   /**
-   * Returns a new `Dataset` where each record has been mapped on to the specified type.
-   * TODO: should bind here...
-   * TODO: document binding rules
+   * Returns a new `Dataset` where each record has been mapped on to the specified type. The
+   * method used to map columns depends on the type of `U`:
+   *  - When `U` is a class, fields for the class will be mapped to columns of the same name
+   *    (case sensitivity is determined by `spark.sql.caseSensitive`)
+   *  - When `U` is a tuple, the columns will be mapped by ordinal (i.e. the first column will
+   *    be assigned to `_1`).
+   *  - When `U` is a primitive type (i.e. String, Int, etc.), then the first column of the
+   *    [[DataFrame]] will be used.
+   *
+   * If the schema of the [[DataFrame]] does not match the desired `U` type, you can use `select`
+   * along with `alias` or `as` to rearrange or rename as required.
+   * @since 1.6.0
    */
   def as[U : Encoder]: Dataset[U] = {
```
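To make the rules above concrete, a minimal sketch (assumes `sqlContext.implicits._` is in scope; `Person` is a hypothetical case class, not part of this change):

```scala
case class Person(name: String, age: Int)  // hypothetical, for illustration only

val df = Seq(("Alice", 30), ("Bob", 25)).toDF("name", "age")

df.as[Person]                  // class: fields matched to columns by name
df.as[(String, Int)]           // tuple: columns matched by ordinal (_1, _2)
df.select($"name").as[String]  // primitive: the first column is used
```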
```diff
@@ -225,6 +234,27 @@ class Dataset[T] private(
       withGroupingKey.newColumns)
   }

+  /**
+   * Returns a [[GroupedDataset]] where the data is grouped by the given [[Column]] expressions.
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def groupBy(cols: Column*): GroupedDataset[Row, T] = {
+    val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias)
+    val withKey = Project(withKeyColumns, logicalPlan)
+    val executed = sqlContext.executePlan(withKey)
+
+    val dataAttributes = executed.analyzed.output.dropRight(cols.size)
+    val keyAttributes = executed.analyzed.output.takeRight(cols.size)
+
+    new GroupedDataset(
+      RowEncoder(keyAttributes.toStructType),
+      encoderFor[T],
+      executed,
+      dataAttributes,
+      keyAttributes)
+  }
+
   /* ****************** *
    *  Typed Relational  *
    * ****************** */
```
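Because the grouping expressions are appended after `logicalPlan.output`, the analyzed output can be split back apart positionally. A toy illustration of that bookkeeping (plain strings stand in for `Attribute` references):

```scala
// With output [data1, data2] and one grouping column appended as key1:
val output = Seq("data1", "data2", "key1")
val dataAttributes = output.dropRight(1)  // Seq("data1", "data2")
val keyAttributes  = output.takeRight(1)  // Seq("key1")
```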
GroupedDataset.scala:

```diff
@@ -17,7 +17,7 @@

 package org.apache.spark.sql

-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder}
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.QueryExecution
```
```diff
@@ -34,11 +34,33 @@ class GroupedDataset[K, T] private[sql](
     private val dataAttributes: Seq[Attribute],
     private val groupingAttributes: Seq[Attribute]) extends Serializable {

-  private implicit def kEnc = kEncoder
-  private implicit def tEnc = tEncoder
+  private implicit val kEnc = kEncoder match {
+    case e: ExpressionEncoder[K] => e.resolve(groupingAttributes)
+    case other =>
+      throw new UnsupportedOperationException("Only expression encoders are currently supported")
+  }
+
+  private implicit val tEnc = tEncoder match {
+    case e: ExpressionEncoder[T] => e.resolve(dataAttributes)
+    case other =>
+      throw new UnsupportedOperationException("Only expression encoders are currently supported")
+  }

   private def logicalPlan = queryExecution.analyzed
   private def sqlContext = queryExecution.sqlContext

+  /**
+   * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified
+   * type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]].
+   */
+  def asKey[L : Encoder]: GroupedDataset[L, T] =
+    new GroupedDataset(
+      encoderFor[L],
+      tEncoder,
+      queryExecution,
+      dataAttributes,
+      groupingAttributes)
+
   /**
    * Returns a [[Dataset]] that contains each unique key.
    */
```
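Note that `asKey` only swaps the key encoder while reusing the same `groupingAttributes`, so the new encoder resolves against the key columns produced by `groupBy`. A usage sketch drawn from the tests below (`ClassData` is the suite's case class, assumed here to be `ClassData(a: String, b: Int)`):

```scala
val ds = Seq(("a", 10), ("a", 20), ("b", 1)).toDS()

// Single key column retyped to a primitive:
val byString: GroupedDataset[String, (String, Int)] =
  ds.groupBy($"_1").asKey[String]

// Aliased key columns retyped to a class, matched by field name:
val byClass: GroupedDataset[ClassData, (String, Int)] =
  ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData]
```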
DatasetSuite.scala:

```diff
@@ -203,6 +203,54 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       ("a", 30), ("b", 3), ("c", 1))
   }

+  test("groupBy columns, mapGroups") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val grouped = ds.groupBy($"_1")
+    val agged = grouped.mapGroups { case (g, iter) =>
+      Iterator((g.getString(0), iter.map(_._2).sum))
+    }
+
+    checkAnswer(
+      agged,
+      ("a", 30), ("b", 3), ("c", 1))
+  }
+
+  test("groupBy columns asKey, mapGroups") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val grouped = ds.groupBy($"_1").asKey[String]
+    val agged = grouped.mapGroups { case (g, iter) =>
+      Iterator((g, iter.map(_._2).sum))
+    }
+
+    checkAnswer(
+      agged,
+      ("a", 30), ("b", 3), ("c", 1))
+  }
+
+  test("groupBy columns asKey tuple, mapGroups") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val grouped = ds.groupBy($"_1", lit(1)).asKey[(String, Int)]
+    val agged = grouped.mapGroups { case (g, iter) =>
+      Iterator((g, iter.map(_._2).sum))
+    }
+
+    checkAnswer(
+      agged,
+      (("a", 1), 30), (("b", 1), 3), (("c", 1), 1))
+  }
+
+  test("groupBy columns asKey class, mapGroups") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData]
+    val agged = grouped.mapGroups { case (g, iter) =>
+      Iterator((g, iter.map(_._2).sum))
+    }
+
+    checkAnswer(
+      agged,
+      (ClassData("a", 1), 30), (ClassData("b", 1), 3), (ClassData("c", 1), 1))
+  }
+
   test("cogroup") {
     val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS()
     val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS()
```