[SPARK-7320] [SQL] Add Cube / Rollup for dataframe
This is a follow-up to #6257, which broke the Maven test.
Add cube & rollup for DataFrame
For example:
```scala
testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b"))
testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b"))
```
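For context, `rollup` aggregates over each prefix of the grouping columns while `cube` aggregates over every subset of them. A hand-annotated sketch of the calls above (assuming a `testData` DataFrame with integer columns `a` and `b`, as in the test suite added below):

```scala
import org.apache.spark.sql.functions._

// rollup aggregates over the prefix hierarchy of the grouping list:
//   (a + b, b), (a + b), ()   -- the empty set is the grand total
val rolled = testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b"))

// cube aggregates over every subset of the grouping list:
//   (a + b, b), (a + b), (b), ()
val cubed = testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b"))

// A grouping column that is absent from a given grouping set comes back as
// NULL in its result rows, matching HiveQL's GROUP BY ... WITH ROLLUP / CUBE.
```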
Author: Cheng Hao <hao.cheng@intel.com>
Closes #6304 from chenghao-intel/rollup and squashes the following commits:
04bb1de [Cheng Hao] move the table register/unregister into beforeAll/afterAll
a6069f1 [Cheng Hao] cancel the implicit keyword
ced4b8f [Cheng Hao] remove the unnecessary code changes
9959dfa [Cheng Hao] update the code as comments
e1d88aa [Cheng Hao] update the code as suggested
03bc3d9 [Cheng Hao] Remove the CubedData & RollupedData
5fd62d0 [Cheng Hao] hide the CubedData & RollupedData
5ffb196 [Cheng Hao] Add Cube / Rollup for dataframe
(cherry picked from commit 42c592adb3)
Signed-off-by: Yin Huai <yhuai@databricks.com>
parent b6182ce891
commit 4fd674336c
In class `DataFrame`, `groupBy` now goes through the new `GroupedData` factory, and `rollup` / `cube` are added alongside it:

```diff
@@ -685,7 +685,53 @@ class DataFrame private[sql](
   * @since 1.3.0
   */
  @scala.annotation.varargs
- def groupBy(cols: Column*): GroupedData = new GroupedData(this, cols.map(_.expr))
+ def groupBy(cols: Column*): GroupedData = {
+   GroupedData(this, cols.map(_.expr), GroupedData.GroupByType)
+ }
+
+ /**
+  * Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns,
+  * so we can run aggregation on them.
+  * See [[GroupedData]] for all the available aggregate functions.
+  *
+  * {{{
+  *   // Compute the average for all numeric columns rolled up by department and group.
+  *   df.rollup($"department", $"group").avg()
+  *
+  *   // Compute the max age and average salary, rolled up by department and gender.
+  *   df.rollup($"department", $"gender").agg(Map(
+  *     "salary" -> "avg",
+  *     "age" -> "max"
+  *   ))
+  * }}}
+  * @group dfops
+  * @since 1.4.0
+  */
+ @scala.annotation.varargs
+ def rollup(cols: Column*): GroupedData = {
+   GroupedData(this, cols.map(_.expr), GroupedData.RollupType)
+ }
+
+ /**
+  * Create a multi-dimensional cube for the current [[DataFrame]] using the specified columns,
+  * so we can run aggregation on them.
+  * See [[GroupedData]] for all the available aggregate functions.
+  *
+  * {{{
+  *   // Compute the average for all numeric columns cubed by department and group.
+  *   df.cube($"department", $"group").avg()
+  *
+  *   // Compute the max age and average salary, cubed by department and gender.
+  *   df.cube($"department", $"gender").agg(Map(
+  *     "salary" -> "avg",
+  *     "age" -> "max"
+  *   ))
+  * }}}
+  * @group dfops
+  * @since 1.4.0
+  */
+ @scala.annotation.varargs
+ def cube(cols: Column*): GroupedData = GroupedData(this, cols.map(_.expr), GroupedData.CubeType)
 
  /**
   * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
```
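Taken together, the scaladoc examples above correspond to user code like the following (a sketch; a `df` with `department`, `group`, `gender`, `salary`, and `age` columns is assumed, as in the docs):

```scala
// Average of every numeric column, at each rollup level of (department, group):
df.rollup($"department", $"group").avg()

// Mixed aggregates over the full cube of (department, gender):
df.cube($"department", $"gender").agg(Map(
  "salary" -> "avg",
  "age" -> "max"
))
```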
The string-name overloads get the same treatment, with name-based `rollup` / `cube` variants added:

```diff
@@ -710,7 +756,61 @@ class DataFrame private[sql](
  @scala.annotation.varargs
  def groupBy(col1: String, cols: String*): GroupedData = {
    val colNames: Seq[String] = col1 +: cols
-   new GroupedData(this, colNames.map(colName => resolve(colName)))
+   GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.GroupByType)
  }
+
+ /**
+  * Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns,
+  * so we can run aggregation on them.
+  * See [[GroupedData]] for all the available aggregate functions.
+  *
+  * This is a variant of rollup that can only group by existing columns using column names
+  * (i.e. cannot construct expressions).
+  *
+  * {{{
+  *   // Compute the average for all numeric columns rolled up by department and group.
+  *   df.rollup("department", "group").avg()
+  *
+  *   // Compute the max age and average salary, rolled up by department and gender.
+  *   df.rollup("department", "gender").agg(Map(
+  *     "salary" -> "avg",
+  *     "age" -> "max"
+  *   ))
+  * }}}
+  * @group dfops
+  * @since 1.4.0
+  */
+ @scala.annotation.varargs
+ def rollup(col1: String, cols: String*): GroupedData = {
+   val colNames: Seq[String] = col1 +: cols
+   GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.RollupType)
+ }
+
+ /**
+  * Create a multi-dimensional cube for the current [[DataFrame]] using the specified columns,
+  * so we can run aggregation on them.
+  * See [[GroupedData]] for all the available aggregate functions.
+  *
+  * This is a variant of cube that can only group by existing columns using column names
+  * (i.e. cannot construct expressions).
+  *
+  * {{{
+  *   // Compute the average for all numeric columns cubed by department and group.
+  *   df.cube("department", "group").avg()
+  *
+  *   // Compute the max age and average salary, cubed by department and gender.
+  *   df.cube("department", "gender").agg(Map(
+  *     "salary" -> "avg",
+  *     "age" -> "max"
+  *   ))
+  * }}}
+  * @group dfops
+  * @since 1.4.0
+  */
+ @scala.annotation.varargs
+ def cube(col1: String, cols: String*): GroupedData = {
+   val colNames: Seq[String] = col1 +: cols
+   GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.CubeType)
+ }
 
  /**
```
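The string-based variants resolve plain column names against the DataFrame and therefore cannot build expressions; when an expression is needed, the `Column` overloads remain available. A short sketch (same hypothetical `df` as above):

```scala
// Name-based: only existing columns may be referenced.
df.rollup("department", "group").avg()
df.cube("department", "group").agg(Map("salary" -> "avg", "age" -> "max"))

// Expression-based: use the Column overloads instead.
df.cube($"department", $"salary" / 1000).count()
```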
In `GroupedData`, a companion object is introduced that tags each instance with the grouping type it was built from:

```diff
@@ -23,9 +23,40 @@ import scala.language.implicitConversions
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.analysis.Star
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
 import org.apache.spark.sql.types.NumericType
 
+/**
+ * Companion object for GroupedData
+ */
+private[sql] object GroupedData {
+  def apply(
+      df: DataFrame,
+      groupingExprs: Seq[Expression],
+      groupType: GroupType): GroupedData = {
+    new GroupedData(df, groupingExprs, groupType)
+  }
+
+  /**
+   * The Grouping Type
+   */
+  trait GroupType
+
+  /**
+   * To indicate it's the GroupBy
+   */
+  object GroupByType extends GroupType
+
+  /**
+   * To indicate it's the CUBE
+   */
+  object CubeType extends GroupType
+
+  /**
+   * To indicate it's the ROLLUP
+   */
+  object RollupType extends GroupType
+}
+
 /**
  * :: Experimental ::
```
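The factory-plus-tag shape used here can be illustrated in isolation; a minimal, self-contained sketch of the same pattern (simplified names, not the Spark source):

```scala
// A tag per construction site, so downstream code can branch on provenance.
sealed trait GroupType
case object GroupByType extends GroupType
case object RollupType extends GroupType
case object CubeType extends GroupType

final class Grouped private (exprs: Seq[String], groupType: GroupType) {
  // Result construction branches on the tag in exactly one place:
  def describe: String = groupType match {
    case GroupByType => s"GROUP BY ${exprs.mkString(", ")}"
    case RollupType  => s"GROUP BY ${exprs.mkString(", ")} WITH ROLLUP"
    case CubeType    => s"GROUP BY ${exprs.mkString(", ")} WITH CUBE"
  }
}

object Grouped {
  // The factory keeps the constructor private; callers pick only the tag.
  def apply(exprs: Seq[String], groupType: GroupType): Grouped =
    new Grouped(exprs, groupType)
}

val g = Grouped(Seq("department", "group"), RollupType)
// g.describe == "GROUP BY department, group WITH ROLLUP"
```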
The class itself carries the tag, and `toDF` becomes the single funnel that builds the result plan:

```diff
@@ -34,19 +65,37 @@ import org.apache.spark.sql.types.NumericType
  * @since 1.3.0
  */
 @Experimental
-class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) {
+class GroupedData protected[sql](
+    df: DataFrame,
+    groupingExprs: Seq[Expression],
+    private val groupType: GroupedData.GroupType) {
 
-  private[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
-    val namedGroupingExprs = groupingExprs.map {
-      case expr: NamedExpression => expr
-      case expr: Expression => Alias(expr, expr.prettyString)()
-    }
-    DataFrame(
-      df.sqlContext, Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan))
+  private[this] def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
+    val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
+      val retainedExprs = groupingExprs.map {
+        case expr: NamedExpression => expr
+        case expr: Expression => Alias(expr, expr.prettyString)()
+      }
+      retainedExprs ++ aggExprs
+    } else {
+      aggExprs
+    }
+
+    groupType match {
+      case GroupedData.GroupByType =>
+        DataFrame(
+          df.sqlContext, Aggregate(groupingExprs, aggregates, df.logicalPlan))
+      case GroupedData.RollupType =>
+        DataFrame(
+          df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregates))
+      case GroupedData.CubeType =>
+        DataFrame(
+          df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregates))
+    }
   }
 
   private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression)
-    : Seq[NamedExpression] = {
+    : DataFrame = {
+
     val columnExprs = if (colNames.isEmpty) {
       // No columns specified. Use all numeric columns.
```
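Two decisions are now centralized in `toDF`: whether the grouping columns are retained in the output (the `dataFrameRetainGroupColumns` conf) and which logical plan node (`Aggregate`, `Rollup`, or `Cube`) wraps the result. The retain logic reduces to a one-liner; a standalone sketch with simplified types:

```scala
// Standalone sketch of the retain-group-columns decision (expressions simplified to strings).
def buildOutput(groupingExprs: Seq[String], aggExprs: Seq[String], retain: Boolean): Seq[String] =
  if (retain) groupingExprs ++ aggExprs else aggExprs

// With retention (the default) the grouping column appears next to the aggregate;
// without it, only the aggregate survives.
assert(buildOutput(Seq("dept"), Seq("sum(salary)"), retain = true) == Seq("dept", "sum(salary)"))
assert(buildOutput(Seq("dept"), Seq("sum(salary)"), retain = false) == Seq("sum(salary)"))
```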
```diff
@@ -63,10 +112,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
         namedExpr
       }
     }
-    columnExprs.map { c =>
+    toDF(columnExprs.map { c =>
       val a = f(c)
       Alias(a, a.prettyString)()
-    }
+    })
   }
 
   private[this] def strToExpr(expr: String): (Expression => Expression) = {
```
```diff
@@ -119,10 +168,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
    * @since 1.3.0
    */
   def agg(exprs: Map[String, String]): DataFrame = {
-    exprs.map { case (colName, expr) =>
+    toDF(exprs.map { case (colName, expr) =>
       val a = strToExpr(expr)(df(colName).expr)
       Alias(a, a.prettyString)()
-    }.toSeq
+    }.toSeq)
   }
 
   /**
```
The Column-based `agg` loses its inline retain-columns branch, which now lives in `toDF`:

```diff
@@ -175,19 +224,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
    */
   @scala.annotation.varargs
   def agg(expr: Column, exprs: Column*): DataFrame = {
-    val aggExprs = (expr +: exprs).map(_.expr).map {
+    toDF((expr +: exprs).map(_.expr).map {
       case expr: NamedExpression => expr
       case expr: Expression => Alias(expr, expr.prettyString)()
-    }
-    if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
-      val retainedExprs = groupingExprs.map {
-        case expr: NamedExpression => expr
-        case expr: Expression => Alias(expr, expr.prettyString)()
-      }
-      DataFrame(df.sqlContext, Aggregate(groupingExprs, retainedExprs ++ aggExprs, df.logicalPlan))
-    } else {
-      DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
-    }
+    })
   }
 
   /**
```
```diff
@@ -196,7 +236,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
    *
    * @since 1.3.0
    */
-  def count(): DataFrame = Seq(Alias(Count(Literal(1)), "count")())
+  def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)), "count")()))
 
   /**
    * Compute the average value for each numeric columns for each group. This is an alias for `avg`.
```
A new Hive-backed test suite, `HiveDataFrameAnalyticsSuite`, checks the DataFrame operators against the equivalent HiveQL:

```diff
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
+import org.scalatest.BeforeAndAfterAll
+
+case class TestData2Int(a: Int, b: Int)
+
+// TODO: ideally we should put the test suite into the package `sql`, as the
+// `hive` package is optional when compiling; however, `SQLContext.sql` doesn't
+// support `cube` or `rollup` yet.
+class HiveDataFrameAnalyticsSuite extends QueryTest with BeforeAndAfterAll {
+  val testData =
+    TestHive.sparkContext.parallelize(
+      TestData2Int(1, 2) ::
+      TestData2Int(2, 4) :: Nil).toDF()
+
+  override def beforeAll() {
+    TestHive.registerDataFrameAsTable(testData, "mytable")
+  }
+
+  override def afterAll(): Unit = {
+    TestHive.dropTempTable("mytable")
+  }
+
+  test("rollup") {
+    checkAnswer(
+      testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")),
+      sql("select a + b, b, sum(a - b) from mytable group by a + b, b with rollup").collect()
+    )
+
+    checkAnswer(
+      testData.rollup("a", "b").agg(sum("b")),
+      sql("select a, b, sum(b) from mytable group by a, b with rollup").collect()
+    )
+  }
+
+  test("cube") {
+    checkAnswer(
+      testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),
+      sql("select a + b, b, sum(a - b) from mytable group by a + b, b with cube").collect()
+    )
+
+    checkAnswer(
+      testData.cube("a", "b").agg(sum("b")),
+      sql("select a, b, sum(b) from mytable group by a, b with cube").collect()
+    )
+  }
+}
```
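For the two rows in `testData`, the rollup expectation can be worked out by hand; a sketch of what `checkAnswer` should see (NULLs mark grouping columns rolled away at that level):

```scala
import org.apache.spark.sql.Row

checkAnswer(
  testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")),
  Row(3, 2, -1) ::          // (a + b, b) fully grouped: 1 + 2, 2, 1 - 2
  Row(6, 4, -2) ::          // 2 + 4, 4, 2 - 4
  Row(3, null, -1) ::       // b rolled up
  Row(6, null, -2) ::
  Row(null, null, -3) ::    // grand total: (1 - 2) + (2 - 4)
  Nil
)
// cube additionally produces the b-only level: Row(null, 2, -1) and Row(null, 4, -2).
```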