[SPARK-24722][SQL] pivot() with Column type argument
## What changes were proposed in this pull request? In this PR, I propose a column-based API for the `pivot()` function. It allows using any column expression as the pivot column. This also makes it consistent with how groupBy() works. ## How was this patch tested? I added new tests to `DataFramePivotSuite` and updated PySpark examples for the `pivot()` function. Author: Maxim Gekk <maxim.gekk@databricks.com> Closes #21699 from MaxGekk/pivot-column.
This commit is contained in:
parent
4c27663cb2
commit
41c2227a23
|
@ -211,6 +211,8 @@ class GroupedData(object):
|
|||
|
||||
>>> df4.groupBy("year").pivot("course").sum("earnings").collect()
|
||||
[Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
|
||||
>>> df5.groupBy("sales.year").pivot("sales.course").sum("sales.earnings").collect()
|
||||
[Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
|
||||
"""
|
||||
if values is None:
|
||||
jgd = self._jgd.pivot(pivot_col)
|
||||
|
@ -296,6 +298,12 @@ def _test():
|
|||
Row(course="dotNET", year=2012, earnings=5000),
|
||||
Row(course="dotNET", year=2013, earnings=48000),
|
||||
Row(course="Java", year=2013, earnings=30000)]).toDF()
|
||||
globs['df5'] = sc.parallelize([
|
||||
Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=10000)),
|
||||
Row(training="junior", sales=Row(course="Java", year=2012, earnings=20000)),
|
||||
Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=5000)),
|
||||
Row(training="junior", sales=Row(course="dotNET", year=2013, earnings=48000)),
|
||||
Row(training="expert", sales=Row(course="Java", year=2013, earnings=30000))]).toDF()
|
||||
|
||||
(failure_count, test_count) = doctest.testmod(
|
||||
pyspark.sql.group, globs=globs,
|
||||
|
|
|
@ -314,7 +314,67 @@ class RelationalGroupedDataset protected[sql](
|
|||
* @param pivotColumn Name of the column to pivot.
|
||||
* @since 1.6.0
|
||||
*/
|
||||
def pivot(pivotColumn: String): RelationalGroupedDataset = {
|
||||
def pivot(pivotColumn: String): RelationalGroupedDataset = pivot(Column(pivotColumn))
|
||||
|
||||
/**
|
||||
* Pivots a column of the current `DataFrame` and performs the specified aggregation.
|
||||
* There are two versions of pivot function: one that requires the caller to specify the list
|
||||
* of distinct values to pivot on, and one that does not. The latter is more concise but less
|
||||
* efficient, because Spark needs to first compute the list of distinct values internally.
|
||||
*
|
||||
* {{{
|
||||
* // Compute the sum of earnings for each year by course with each course as a separate column
|
||||
* df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
|
||||
*
|
||||
* // Or without specifying column values (less efficient)
|
||||
* df.groupBy("year").pivot("course").sum("earnings")
|
||||
* }}}
|
||||
*
|
||||
* @param pivotColumn Name of the column to pivot.
|
||||
* @param values List of values that will be translated to columns in the output DataFrame.
|
||||
* @since 1.6.0
|
||||
*/
|
||||
// Delegates to the Column-based overload by resolving the column name first.
def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset =
  pivot(Column(pivotColumn), values)
|
||||
|
||||
/**
|
||||
* (Java-specific) Pivots a column of the current `DataFrame` and performs the specified
|
||||
* aggregation.
|
||||
*
|
||||
* There are two versions of pivot function: one that requires the caller to specify the list
|
||||
* of distinct values to pivot on, and one that does not. The latter is more concise but less
|
||||
* efficient, because Spark needs to first compute the list of distinct values internally.
|
||||
*
|
||||
* {{{
|
||||
* // Compute the sum of earnings for each year by course with each course as a separate column
|
||||
* df.groupBy("year").pivot("course", Arrays.<Object>asList("dotNET", "Java")).sum("earnings");
|
||||
*
|
||||
* // Or without specifying column values (less efficient)
|
||||
* df.groupBy("year").pivot("course").sum("earnings");
|
||||
* }}}
|
||||
*
|
||||
* @param pivotColumn Name of the column to pivot.
|
||||
* @param values List of values that will be translated to columns in the output DataFrame.
|
||||
* @since 1.6.0
|
||||
*/
|
||||
// Java-friendly variant: resolves the column name and forwards the Java list
// to the Column-based overload, which converts it to a Scala Seq.
def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset =
  pivot(Column(pivotColumn), values)
|
||||
|
||||
/**
|
||||
* Pivots a column of the current `DataFrame` and performs the specified aggregation.
|
||||
* This is an overloaded version of the `pivot` method with `pivotColumn` of the `String` type.
|
||||
*
|
||||
* {{{
|
||||
* // Or without specifying column values (less efficient)
|
||||
* df.groupBy($"year").pivot($"course").sum($"earnings");
|
||||
* }}}
|
||||
*
|
||||
* @param pivotColumn the column to pivot.
|
||||
* @since 2.4.0
|
||||
*/
|
||||
def pivot(pivotColumn: Column): RelationalGroupedDataset = {
|
||||
// This is to prevent unintended OOM errors when the number of distinct values is large
|
||||
val maxValues = df.sparkSession.sessionState.conf.dataFramePivotMaxValues
|
||||
// Get the distinct values of the column and sort them so its consistent
|
||||
|
@ -339,29 +399,24 @@ class RelationalGroupedDataset protected[sql](
|
|||
|
||||
/**
|
||||
* Pivots a column of the current `DataFrame` and performs the specified aggregation.
|
||||
* There are two versions of pivot function: one that requires the caller to specify the list
|
||||
* of distinct values to pivot on, and one that does not. The latter is more concise but less
|
||||
* efficient, because Spark needs to first compute the list of distinct values internally.
|
||||
* This is an overloaded version of the `pivot` method with `pivotColumn` of the `String` type.
|
||||
*
|
||||
* {{{
|
||||
* // Compute the sum of earnings for each year by course with each course as a separate column
|
||||
* df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
|
||||
*
|
||||
* // Or without specifying column values (less efficient)
|
||||
* df.groupBy("year").pivot("course").sum("earnings")
|
||||
* df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings")
|
||||
* }}}
|
||||
*
|
||||
* @param pivotColumn Name of the column to pivot.
|
||||
* @param pivotColumn the column to pivot.
|
||||
* @param values List of values that will be translated to columns in the output DataFrame.
|
||||
* @since 1.6.0
|
||||
* @since 2.4.0
|
||||
*/
|
||||
def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = {
|
||||
def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = {
|
||||
groupType match {
|
||||
case RelationalGroupedDataset.GroupByType =>
|
||||
new RelationalGroupedDataset(
|
||||
df,
|
||||
groupingExprs,
|
||||
RelationalGroupedDataset.PivotType(df.resolve(pivotColumn), values.map(Literal.apply)))
|
||||
RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply)))
|
||||
case _: RelationalGroupedDataset.PivotType =>
|
||||
throw new UnsupportedOperationException("repeated pivots are not supported")
|
||||
case _ =>
|
||||
|
@ -371,25 +426,14 @@ class RelationalGroupedDataset protected[sql](
|
|||
|
||||
/**
|
||||
* (Java-specific) Pivots a column of the current `DataFrame` and performs the specified
|
||||
* aggregation.
|
||||
* aggregation. This is an overloaded version of the `pivot` method with `pivotColumn` of
|
||||
* the `String` type.
|
||||
*
|
||||
* There are two versions of pivot function: one that requires the caller to specify the list
|
||||
* of distinct values to pivot on, and one that does not. The latter is more concise but less
|
||||
* efficient, because Spark needs to first compute the list of distinct values internally.
|
||||
*
|
||||
* {{{
|
||||
* // Compute the sum of earnings for each year by course with each course as a separate column
|
||||
* df.groupBy("year").pivot("course", Arrays.<Object>asList("dotNET", "Java")).sum("earnings");
|
||||
*
|
||||
* // Or without specifying column values (less efficient)
|
||||
* df.groupBy("year").pivot("course").sum("earnings");
|
||||
* }}}
|
||||
*
|
||||
* @param pivotColumn Name of the column to pivot.
|
||||
* @param pivotColumn the column to pivot.
|
||||
* @param values List of values that will be translated to columns in the output DataFrame.
|
||||
* @since 1.6.0
|
||||
* @since 2.4.0
|
||||
*/
|
||||
def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = {
|
||||
def pivot(pivotColumn: Column, values: java.util.List[Any]): RelationalGroupedDataset = {
|
||||
pivot(pivotColumn, values.asScala)
|
||||
}
|
||||
|
||||
|
|
|
@ -27,28 +27,40 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
|
|||
import testImplicits._
|
||||
|
||||
test("pivot courses") {
|
||||
val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
|
||||
checkAnswer(
|
||||
courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java"))
|
||||
.agg(sum($"earnings")),
|
||||
Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
|
||||
)
|
||||
expected)
|
||||
checkAnswer(
|
||||
courseSales.groupBy($"year").pivot($"course", Seq("dotNET", "Java"))
|
||||
.agg(sum($"earnings")),
|
||||
expected)
|
||||
}
|
||||
|
||||
test("pivot year") {
|
||||
val expected = Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
|
||||
checkAnswer(
|
||||
courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")),
|
||||
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
|
||||
)
|
||||
expected)
|
||||
checkAnswer(
|
||||
courseSales.groupBy('course).pivot('year, Seq(2012, 2013)).agg(sum('earnings)),
|
||||
expected)
|
||||
}
|
||||
|
||||
test("pivot courses with multiple aggregations") {
|
||||
val expected = Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
|
||||
Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
|
||||
checkAnswer(
|
||||
courseSales.groupBy($"year")
|
||||
.pivot("course", Seq("dotNET", "Java"))
|
||||
.agg(sum($"earnings"), avg($"earnings")),
|
||||
Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
|
||||
Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
|
||||
)
|
||||
expected)
|
||||
checkAnswer(
|
||||
courseSales.groupBy($"year")
|
||||
.pivot($"course", Seq("dotNET", "Java"))
|
||||
.agg(sum($"earnings"), avg($"earnings")),
|
||||
expected)
|
||||
}
|
||||
|
||||
test("pivot year with string values (cast)") {
|
||||
|
@ -67,17 +79,23 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
|
|||
|
||||
test("pivot courses with no values") {
|
||||
// Note Java comes before dotNet in sorted order
|
||||
val expected = Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
|
||||
checkAnswer(
|
||||
courseSales.groupBy("year").pivot("course").agg(sum($"earnings")),
|
||||
Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
|
||||
)
|
||||
expected)
|
||||
checkAnswer(
|
||||
courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")),
|
||||
expected)
|
||||
}
|
||||
|
||||
test("pivot year with no values") {
|
||||
val expected = Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
|
||||
checkAnswer(
|
||||
courseSales.groupBy("course").pivot("year").agg(sum($"earnings")),
|
||||
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
|
||||
)
|
||||
expected)
|
||||
checkAnswer(
|
||||
courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")),
|
||||
expected)
|
||||
}
|
||||
|
||||
test("pivot max values enforced") {
|
||||
|
@ -181,10 +199,13 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
|
|||
}
|
||||
|
||||
test("pivot with datatype not supported by PivotFirst") {
|
||||
val expected = Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil
|
||||
checkAnswer(
|
||||
complexData.groupBy().pivot("b", Seq(true, false)).agg(max("a")),
|
||||
Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil
|
||||
)
|
||||
expected)
|
||||
checkAnswer(
|
||||
complexData.groupBy().pivot('b, Seq(true, false)).agg(max('a)),
|
||||
expected)
|
||||
}
|
||||
|
||||
test("pivot with datatype not supported by PivotFirst 2") {
|
||||
|
@ -246,4 +267,45 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext {
|
|||
checkAnswer(df.select($"a".cast(StringType)), Row(tsWithZone))
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-24722: pivoting nested columns") {
  // Pivot on an arbitrary expression (lower-cased nested column) instead of a
  // plain column name; the pivot values are lower-cased to match.
  val expected = Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
  val pivoted = trainingSales
    .groupBy($"sales.year")
    .pivot(lower($"sales.course"), Seq("dotNet", "Java").map(_.toLowerCase))
    .agg(sum($"sales.earnings"))

  checkAnswer(pivoted, expected)
}
|
||||
|
||||
test("SPARK-24722: references to multiple columns in the pivot column") {
|
||||
val expected = Row(2012, 10000.0) :: Row(2013, 48000.0) :: Nil
|
||||
val df = trainingSales
|
||||
.groupBy($"sales.year")
|
||||
.pivot(concat_ws("-", $"training", $"sales.course"), Seq("Experts-dotNET"))
|
||||
.agg(sum($"sales.earnings"))
|
||||
|
||||
checkAnswer(df, expected)
|
||||
}
|
||||
|
||||
test("SPARK-24722: pivoting by a constant") {
|
||||
val expected = Row(2012, 35000.0) :: Row(2013, 78000.0) :: Nil
|
||||
val df1 = trainingSales
|
||||
.groupBy($"sales.year")
|
||||
.pivot(lit(123), Seq(123))
|
||||
.agg(sum($"sales.earnings"))
|
||||
|
||||
checkAnswer(df1, expected)
|
||||
}
|
||||
|
||||
test("SPARK-24722: aggregate as the pivot column") {
  // An aggregate expression is not a valid pivot column; the analyzer must reject it.
  val err = intercept[AnalysisException] {
    trainingSales
      .groupBy($"sales.year")
      .pivot(min($"training"), Seq("Experts"))
      .agg(sum($"sales.earnings"))
  }

  assert(err.getMessage.contains("aggregate functions are not allowed"))
}
|
||||
}
|
||||
|
|
|
@ -268,6 +268,17 @@ private[sql] trait SQLTestData { self =>
|
|||
df
|
||||
}
|
||||
|
||||
// Test data with a nested struct column (`sales`), used by the SPARK-24722
// pivot tests; also registered as the temp view "trainingSales".
protected lazy val trainingSales: DataFrame = {
  val records = Seq(
    TrainingSales("Experts", CourseSales("dotNET", 2012, 10000)),
    TrainingSales("Experts", CourseSales("JAVA", 2012, 20000)),
    TrainingSales("Dummies", CourseSales("dotNet", 2012, 5000)),
    TrainingSales("Experts", CourseSales("dotNET", 2013, 48000)),
    TrainingSales("Dummies", CourseSales("Java", 2013, 30000)))
  val df = spark.sparkContext.parallelize(records).toDF()
  df.createOrReplaceTempView("trainingSales")
  df
}
|
||||
|
||||
/**
|
||||
* Initialize all test data such that all temp tables are properly registered.
|
||||
*/
|
||||
|
@ -323,4 +334,5 @@ private[sql] object SQLTestData {
|
|||
case class Salary(personId: Int, salary: Double)
|
||||
case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
|
||||
case class CourseSales(course: String, year: Int, earnings: Double)
|
||||
// Wraps a CourseSales record with its training level; backs the nested-column pivot tests.
case class TrainingSales(training: String, sales: CourseSales)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue