[SPARK-24722][SQL] pivot() with Column type argument

## What changes were proposed in this pull request?

In this PR, I propose a column-based API for the `pivot()` function. It allows any column expression to be used as the pivot column, and makes `pivot()` consistent with how `groupBy()` works.
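
A minimal sketch of the proposed usage, based on the new tests added below (`trainingSales` is the test dataset added to `SQLTestData`, with a nested `sales` struct):

```scala
// Pivot on an arbitrary column expression (a nested field wrapped in lower()),
// which the String-based API cannot express:
trainingSales
  .groupBy($"sales.year")
  .pivot(lower($"sales.course"), Seq("dotnet", "java"))
  .agg(sum($"sales.earnings"))

// The String-based overloads are kept and now simply delegate to the Column-based ones:
df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
```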

## How was this patch tested?

I added new tests to `DataFramePivotSuite` and updated PySpark examples for the `pivot()` function.

Author: Maxim Gekk <maxim.gekk@databricks.com>

Closes #21699 from MaxGekk/pivot-column.
Commit 41c2227a23 (parent 4c27663cb2), authored by Maxim Gekk on 2018-08-04 14:17:32 +08:00, committed by hyukjinkwon.
4 changed files with 167 additions and 41 deletions


@@ -211,6 +211,8 @@ class GroupedData(object):
>>> df4.groupBy("year").pivot("course").sum("earnings").collect()
[Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
>>> df5.groupBy("sales.year").pivot("sales.course").sum("sales.earnings").collect()
[Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
"""
if values is None:
jgd = self._jgd.pivot(pivot_col)
@@ -296,6 +298,12 @@ def _test():
Row(course="dotNET", year=2012, earnings=5000),
Row(course="dotNET", year=2013, earnings=48000),
Row(course="Java", year=2013, earnings=30000)]).toDF()
globs['df5'] = sc.parallelize([
Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=10000)),
Row(training="junior", sales=Row(course="Java", year=2012, earnings=20000)),
Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=5000)),
Row(training="junior", sales=Row(course="dotNET", year=2013, earnings=48000)),
Row(training="expert", sales=Row(course="Java", year=2013, earnings=30000))]).toDF()
(failure_count, test_count) = doctest.testmod(
pyspark.sql.group, globs=globs,


@@ -314,7 +314,67 @@ class RelationalGroupedDataset protected[sql](
* @param pivotColumn Name of the column to pivot.
* @since 1.6.0
*/
def pivot(pivotColumn: String): RelationalGroupedDataset = {
def pivot(pivotColumn: String): RelationalGroupedDataset = pivot(Column(pivotColumn))
/**
* Pivots a column of the current `DataFrame` and performs the specified aggregation.
* There are two versions of pivot function: one that requires the caller to specify the list
* of distinct values to pivot on, and one that does not. The latter is more concise but less
* efficient, because Spark needs to first compute the list of distinct values internally.
*
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
*
* // Or without specifying column values (less efficient)
* df.groupBy("year").pivot("course").sum("earnings")
* }}}
*
* @param pivotColumn Name of the column to pivot.
* @param values List of values that will be translated to columns in the output DataFrame.
* @since 1.6.0
*/
def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = {
pivot(Column(pivotColumn), values)
}
/**
* (Java-specific) Pivots a column of the current `DataFrame` and performs the specified
* aggregation.
*
* There are two versions of pivot function: one that requires the caller to specify the list
* of distinct values to pivot on, and one that does not. The latter is more concise but less
* efficient, because Spark needs to first compute the list of distinct values internally.
*
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy("year").pivot("course", Arrays.<Object>asList("dotNET", "Java")).sum("earnings");
*
* // Or without specifying column values (less efficient)
* df.groupBy("year").pivot("course").sum("earnings");
* }}}
*
* @param pivotColumn Name of the column to pivot.
* @param values List of values that will be translated to columns in the output DataFrame.
* @since 1.6.0
*/
def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = {
pivot(Column(pivotColumn), values)
}
/**
* Pivots a column of the current `DataFrame` and performs the specified aggregation.
* This is an overloaded version of the `pivot` method with `pivotColumn` of the `String` type.
*
* {{{
* // Or without specifying column values (less efficient)
* df.groupBy($"year").pivot($"course").sum($"earnings");
* }}}
*
* @param pivotColumn the column to pivot.
* @since 2.4.0
*/
def pivot(pivotColumn: Column): RelationalGroupedDataset = {
// This is to prevent unintended OOM errors when the number of distinct values is large
val maxValues = df.sparkSession.sessionState.conf.dataFramePivotMaxValues
// Get the distinct values of the column and sort them so it's consistent
@@ -339,29 +399,24 @@ class RelationalGroupedDataset protected[sql](
/**
* Pivots a column of the current `DataFrame` and performs the specified aggregation.
* There are two versions of pivot function: one that requires the caller to specify the list
* of distinct values to pivot on, and one that does not. The latter is more concise but less
* efficient, because Spark needs to first compute the list of distinct values internally.
* This is an overloaded version of the `pivot` method with `pivotColumn` of the `String` type.
*
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
*
* // Or without specifying column values (less efficient)
* df.groupBy("year").pivot("course").sum("earnings")
* df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings")
* }}}
*
* @param pivotColumn Name of the column to pivot.
* @param pivotColumn the column to pivot.
* @param values List of values that will be translated to columns in the output DataFrame.
* @since 1.6.0
* @since 2.4.0
*/
def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = {
def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = {
groupType match {
case RelationalGroupedDataset.GroupByType =>
new RelationalGroupedDataset(
df,
groupingExprs,
RelationalGroupedDataset.PivotType(df.resolve(pivotColumn), values.map(Literal.apply)))
RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply)))
case _: RelationalGroupedDataset.PivotType =>
throw new UnsupportedOperationException("repeated pivots are not supported")
case _ =>
@@ -371,25 +426,14 @@ class RelationalGroupedDataset protected[sql](
/**
* (Java-specific) Pivots a column of the current `DataFrame` and performs the specified
* aggregation.
* aggregation. This is an overloaded version of the `pivot` method with `pivotColumn` of
* the `String` type.
*
* There are two versions of pivot function: one that requires the caller to specify the list
* of distinct values to pivot on, and one that does not. The latter is more concise but less
* efficient, because Spark needs to first compute the list of distinct values internally.
*
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy("year").pivot("course", Arrays.<Object>asList("dotNET", "Java")).sum("earnings");
*
* // Or without specifying column values (less efficient)
* df.groupBy("year").pivot("course").sum("earnings");
* }}}
*
* @param pivotColumn Name of the column to pivot.
* @param pivotColumn the column to pivot.
* @param values List of values that will be translated to columns in the output DataFrame.
* @since 1.6.0
* @since 2.4.0
*/
def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = {
def pivot(pivotColumn: Column, values: java.util.List[Any]): RelationalGroupedDataset = {
pivot(pivotColumn, values.asScala)
}
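
For context (the method body is truncated by the hunk above), here is a minimal sketch of how the value-less `pivot(Column)` overload can compute the pivot values itself. It assumes the pre-existing distinct/sort logic is reused and bounded by `spark.sql.pivotMaxValues`; `df`, `pivotColumn` and `maxValues` are as in the hunk above, and `org.apache.spark.sql.AnalysisException` is in scope:

```scala
// Sketch only: collect at most maxValues + 1 distinct values of the pivot column,
// sorted so the generated output columns appear in a consistent order.
val values = df.select(pivotColumn)
  .distinct()
  .limit(maxValues + 1)
  .sort(pivotColumn)
  .collect()
  .map(_.get(0))
  .toSeq

if (values.length > maxValues) {
  // Illustrative message only; the exact wording lives in the real implementation.
  throw new AnalysisException(
    s"The pivot column has more than $maxValues distinct values; " +
      "raise spark.sql.pivotMaxValues or pass the values explicitly.")
}
// Delegate to the overload that takes explicit values.
pivot(pivotColumn, values)
```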


@@ -27,28 +27,40 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("pivot courses") {
val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
checkAnswer(
courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java"))
.agg(sum($"earnings")),
Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
)
expected)
checkAnswer(
courseSales.groupBy($"year").pivot($"course", Seq("dotNET", "Java"))
.agg(sum($"earnings")),
expected)
}
test("pivot year") {
val expected = Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
checkAnswer(
courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
expected)
checkAnswer(
courseSales.groupBy('course).pivot('year, Seq(2012, 2013)).agg(sum('earnings)),
expected)
}
test("pivot courses with multiple aggregations") {
val expected = Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
checkAnswer(
courseSales.groupBy($"year")
.pivot("course", Seq("dotNET", "Java"))
.agg(sum($"earnings"), avg($"earnings")),
Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
)
expected)
checkAnswer(
courseSales.groupBy($"year")
.pivot($"course", Seq("dotNET", "Java"))
.agg(sum($"earnings"), avg($"earnings")),
expected)
}
test("pivot year with string values (cast)") {
@@ -67,17 +79,23 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
test("pivot courses with no values") {
// Note Java comes before dotNet in sorted order
val expected = Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
checkAnswer(
courseSales.groupBy("year").pivot("course").agg(sum($"earnings")),
Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
)
expected)
checkAnswer(
courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")),
expected)
}
test("pivot year with no values") {
val expected = Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
checkAnswer(
courseSales.groupBy("course").pivot("year").agg(sum($"earnings")),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
expected)
checkAnswer(
courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")),
expected)
}
test("pivot max values enforced") {
@@ -181,10 +199,13 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
}
test("pivot with datatype not supported by PivotFirst") {
val expected = Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil
checkAnswer(
complexData.groupBy().pivot("b", Seq(true, false)).agg(max("a")),
Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil
)
expected)
checkAnswer(
complexData.groupBy().pivot('b, Seq(true, false)).agg(max('a)),
expected)
}
test("pivot with datatype not supported by PivotFirst 2") {
@@ -246,4 +267,45 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
checkAnswer(df.select($"a".cast(StringType)), Row(tsWithZone))
}
}
test("SPARK-24722: pivoting nested columns") {
val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
val df = trainingSales
.groupBy($"sales.year")
.pivot(lower($"sales.course"), Seq("dotNet", "Java").map(_.toLowerCase))
.agg(sum($"sales.earnings"))
checkAnswer(df, expected)
}
test("SPARK-24722: references to multiple columns in the pivot column") {
val expected = Row(2012, 10000.0) :: Row(2013, 48000.0) :: Nil
val df = trainingSales
.groupBy($"sales.year")
.pivot(concat_ws("-", $"training", $"sales.course"), Seq("Experts-dotNET"))
.agg(sum($"sales.earnings"))
checkAnswer(df, expected)
}
test("SPARK-24722: pivoting by a constant") {
val expected = Row(2012, 35000.0) :: Row(2013, 78000.0) :: Nil
val df1 = trainingSales
.groupBy($"sales.year")
.pivot(lit(123), Seq(123))
.agg(sum($"sales.earnings"))
checkAnswer(df1, expected)
}
test("SPARK-24722: aggregate as the pivot column") {
val exception = intercept[AnalysisException] {
trainingSales
.groupBy($"sales.year")
.pivot(min($"training"), Seq("Experts"))
.agg(sum($"sales.earnings"))
}
assert(exception.getMessage.contains("aggregate functions are not allowed"))
}
}


@@ -268,6 +268,17 @@ private[sql] trait SQLTestData { self =>
df
}
protected lazy val trainingSales: DataFrame = {
val df = spark.sparkContext.parallelize(
TrainingSales("Experts", CourseSales("dotNET", 2012, 10000)) ::
TrainingSales("Experts", CourseSales("JAVA", 2012, 20000)) ::
TrainingSales("Dummies", CourseSales("dotNet", 2012, 5000)) ::
TrainingSales("Experts", CourseSales("dotNET", 2013, 48000)) ::
TrainingSales("Dummies", CourseSales("Java", 2013, 30000)) :: Nil).toDF()
df.createOrReplaceTempView("trainingSales")
df
}
/**
* Initialize all test data such that all temp tables are properly registered.
*/
@@ -323,4 +334,5 @@ private[sql] object SQLTestData {
case class Salary(personId: Int, salary: Double)
case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
case class CourseSales(course: String, year: Int, earnings: Double)
case class TrainingSales(training: String, sales: CourseSales)
}