[SQL] DataFrame API improvements

1. Added Dsl.column in case Dsl.col is shadowed.
2. Allow using String to specify the target data type in cast.
3. Support sorting on multiple columns using column names.
4. Added Java API test file.
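
Taken together, the new surface looks like this from Scala. A minimal usage sketch against the API in this patch; the DataFrame `df` and the column names `colA` and `colB` are illustrative, not part of the change:

    import org.apache.spark.sql.Dsl.column

    // 1. Dsl.column is an alias for Dsl.col, usable when a local `col` shadows it.
    val c = column("colA")

    // 2. cast accepts the type's canonical string name as well as a DataType.
    df.select(df("colA").cast("int"))

    // 3. sort (and its alias orderBy) accepts multiple column names.
    df.sort("colA", "colB")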

Author: Reynold Xin <rxin@databricks.com>

Closes #4280 from rxin/dsl1 and squashes the following commits:

33ecb7a [Reynold Xin] Add the Java test.
d06540a [Reynold Xin] [SQL] DataFrame API improvements.
Reynold Xin 2015-01-29 17:24:00 -08:00
parent d2071e8f45
commit ce9c43ba8c
6 changed files with 209 additions and 16 deletions


@@ -56,7 +56,7 @@ object Column {
 class Column(
     sqlContext: Option[SQLContext],
     plan: Option[LogicalPlan],
-    val expr: Expression)
+    protected[sql] val expr: Expression)
   extends DataFrame(sqlContext, plan) with ExpressionApi {

   /** Turns a Catalyst expression into a `Column`. */
@@ -437,9 +437,7 @@ class Column(
   override def rlike(literal: String): Column = RLike(expr, lit(literal).expr)

   /**
-   * An expression that gets an
-   * @param ordinal
-   * @return
+   * An expression that gets an item at position `ordinal` out of an array.
    */
   override def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal))
@@ -490,11 +488,38 @@ class Column(
    * {{{
    *   // Casts colA to IntegerType.
    *   import org.apache.spark.sql.types.IntegerType
-   *   df.select(df("colA").as(IntegerType))
+   *   df.select(df("colA").cast(IntegerType))
+   *
+   *   // equivalent to
+   *   df.select(df("colA").cast("int"))
    * }}}
    */
   override def cast(to: DataType): Column = Cast(expr, to)

+  /**
+   * Casts the column to a different data type, using the canonical string representation
+   * of the type. The supported types are: `string`, `boolean`, `byte`, `short`, `int`, `long`,
+   * `float`, `double`, `decimal`, `date`, `timestamp`.
+   * {{{
+   *   // Casts colA to integer.
+   *   df.select(df("colA").cast("int"))
+   * }}}
+   */
+  override def cast(to: String): Column = Cast(expr, to.toLowerCase match {
+    case "string" => StringType
+    case "boolean" => BooleanType
+    case "byte" => ByteType
+    case "short" => ShortType
+    case "int" => IntegerType
+    case "long" => LongType
+    case "float" => FloatType
+    case "double" => DoubleType
+    case "decimal" => DecimalType.Unlimited
+    case "date" => DateType
+    case "timestamp" => TimestampType
+    case _ => throw new RuntimeException(s"""Unsupported cast type: "$to"""")
+  })
+
   override def desc: Column = SortOrder(expr, Descending)
   override def asc: Column = SortOrder(expr, Ascending)
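
A usage sketch of the two `cast` overloads above; `df` and `colA` are illustrative:

    import org.apache.spark.sql.types.IntegerType

    // Equivalent under the new String overload:
    df.select(df("colA").cast(IntegerType))
    df.select(df("colA").cast("int"))

    // The match is case-insensitive, so cast("INT") also works. An unsupported
    // name such as "uint32" throws a RuntimeException: Unsupported cast type: "uint32"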


@@ -208,7 +208,7 @@ class DataFrame protected[sql](
   }

   /**
-   * Returns a new [[DataFrame]] sorted by the specified column, in ascending column.
+   * Returns a new [[DataFrame]] sorted by the specified column, all in ascending order.
    * {{{
    *   // The following 3 are equivalent
    *   df.sort("sortcol")
@@ -216,8 +216,9 @@ class DataFrame protected[sql](
    *   df.sort($"sortcol".asc)
    * }}}
    */
-  override def sort(colName: String): DataFrame = {
-    Sort(Seq(SortOrder(apply(colName).expr, Ascending)), global = true, logicalPlan)
+  @scala.annotation.varargs
+  override def sort(sortCol: String, sortCols: String*): DataFrame = {
+    orderBy(apply(sortCol), sortCols.map(apply) :_*)
   }

   /**
@@ -239,6 +240,15 @@ class DataFrame protected[sql](
     Sort(sortOrder, global = true, logicalPlan)
   }

+  /**
+   * Returns a new [[DataFrame]] sorted by the given expressions.
+   * This is an alias of the `sort` function.
+   */
+  @scala.annotation.varargs
+  override def orderBy(sortCol: String, sortCols: String*): DataFrame = {
+    sort(sortCol, sortCols :_*)
+  }
+
   /**
    * Returns a new [[DataFrame]] sorted by the given expressions.
    * This is an alias of the `sort` function.
@@ -401,6 +411,16 @@ class DataFrame protected[sql](
    */
   override def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs)

+  /**
+   * Aggregates on the entire [[DataFrame]] without groups.
+   * {{
+   *   // df.agg(...) is a shorthand for df.groupBy().agg(...)
+   *   df.agg(Map("age" -> "max", "salary" -> "avg"))
+   *   df.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
+   * }}
+   */
+  override def agg(exprs: java.util.Map[String, String]): DataFrame = agg(exprs.toMap)
+
   /**
    * Aggregates on the entire [[DataFrame]] without groups.
    * {{
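
In use (a sketch, assuming a DataFrame `df` with columns `age` and `salary`):

    // Multi-column sort by name; orderBy(String...) is an alias of sort(String...).
    df.sort("age", "salary")
    df.orderBy("age", "salary")

    // Java-friendly aggregation over the whole DataFrame; it delegates to the
    // existing Scala Map overload via exprs.toMap.
    val exprs = new java.util.HashMap[String, String]()
    exprs.put("age", "max")
    exprs.put("salary", "avg")
    df.agg(exprs)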


@@ -62,6 +62,11 @@ object Dsl {
    */
   def col(colName: String): Column = new Column(colName)

+  /**
+   * Returns a [[Column]] based on the given column name. Alias of [[col]].
+   */
+  def column(colName: String): Column = new Column(colName)
+
   /**
    * Creates a [[Column]] of literal value.
    */
@@ -96,6 +101,7 @@ object Dsl {
   def sumDistinct(e: Column): Column = SumDistinct(e.expr)
   def count(e: Column): Column = Count(e.expr)

+  @scala.annotation.varargs
   def countDistinct(expr: Column, exprs: Column*): Column =
     CountDistinct((expr +: exprs).map(_.expr))
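
And in use (a sketch; `df` and the column names are illustrative):

    import org.apache.spark.sql.Dsl._

    val col = "some local string"   // shadows Dsl.col
    val c = column("colA")          // the new alias still resolves

    // countDistinct takes one or more columns; @scala.annotation.varargs also
    // exposes the variable arity to Java callers.
    df.agg(countDistinct(column("colA"), column("colB")))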


@@ -58,7 +58,9 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
   }

   /**
-   * Compute aggregates by specifying a map from column name to aggregate methods.
+   * Compute aggregates by specifying a map from column name to aggregate methods. The resulting
+   * [[DataFrame]] will also contain the grouping columns.
+   *
    * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
    * {{{
    *   // Selects the age of the oldest employee and the aggregate expense for each department
@@ -76,7 +78,9 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
   }

   /**
-   * Compute aggregates by specifying a map from column name to aggregate methods.
+   * Compute aggregates by specifying a map from column name to aggregate methods. The resulting
+   * [[DataFrame]] will also contain the grouping columns.
+   *
    * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
    * {{{
    *   // Selects the age of the oldest employee and the aggregate expense for each department
@@ -91,12 +95,15 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
   }

   /**
-   * Compute aggregates by specifying a series of aggregate columns.
-   * The available aggregate methods are defined in [[org.apache.spark.sql.dsl]].
+   * Compute aggregates by specifying a series of aggregate columns. Unlike other methods in this
+   * class, the resulting [[DataFrame]] won't automatically include the grouping columns.
+   *
+   * The available aggregate methods are defined in [[org.apache.spark.sql.Dsl]].
+   *
    * {{{
    *   // Selects the age of the oldest employee and the aggregate expense for each department
-   *   import org.apache.spark.sql.dsl._
-   *   df.groupBy("department").agg(max($"age"), sum($"expense"))
+   *   import org.apache.spark.sql.Dsl._
+   *   df.groupBy("department").agg($"department", max($"age"), sum($"expense"))
    * }}}
    */
   @scala.annotation.varargs
@@ -109,31 +116,39 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
     new DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
   }

-  /** Count the number of rows for each group. */
+  /**
+   * Count the number of rows for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
+   */
   override def count(): DataFrame = Seq(Alias(Count(LiteralExpr(1)), "count")())

   /**
    * Compute the average value for each numeric columns for each group. This is an alias for `avg`.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
    */
   override def mean(): DataFrame = aggregateNumericColumns(Average)

   /**
    * Compute the max value for each numeric columns for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
    */
   override def max(): DataFrame = aggregateNumericColumns(Max)

   /**
    * Compute the mean value for each numeric columns for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
    */
   override def avg(): DataFrame = aggregateNumericColumns(Average)

   /**
    * Compute the min value for each numeric column for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
    */
   override def min(): DataFrame = aggregateNumericColumns(Min)

   /**
    * Compute the sum for each numeric columns for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
    */
   override def sum(): DataFrame = aggregateNumericColumns(Sum)
 }
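
The grouping-column distinction called out in the docs above, sketched with illustrative column names:

    import org.apache.spark.sql.Dsl._

    // Map-based agg automatically includes the grouping column in the result.
    df.groupBy("department").agg(Map("age" -> "max", "expense" -> "sum"))

    // Column-based agg does not; list the grouping column explicitly to keep it.
    df.groupBy("department").agg($"department", max($"age"), sum($"expense"))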


@@ -113,16 +113,22 @@ private[sql] trait DataFrameSpecificApi {
   def agg(exprs: Map[String, String]): DataFrame
+  def agg(exprs: java.util.Map[String, String]): DataFrame
+
   @scala.annotation.varargs
   def agg(expr: Column, exprs: Column*): DataFrame

-  def sort(colName: String): DataFrame
+  @scala.annotation.varargs
+  def sort(sortCol: String, sortCols: String*): DataFrame
+
   @scala.annotation.varargs
   def sort(sortExpr: Column, sortExprs: Column*): DataFrame

+  @scala.annotation.varargs
+  def orderBy(sortCol: String, sortCols: String*): DataFrame
+
   @scala.annotation.varargs
   def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame

   def join(right: DataFrame): DataFrame
@@ -257,6 +263,7 @@ private[sql] trait ExpressionApi {
   def getField(fieldName: String): Column

   def cast(to: DataType): Column
+  def cast(to: String): Column

   def asc: Column
   def desc: Column


@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java;
+
+import com.google.common.collect.ImmutableMap;
+
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.types.DataTypes;
+
+import static org.apache.spark.sql.Dsl.*;
+
+/**
+ * This test doesn't actually run anything. It is here to check the API compatibility for Java.
+ */
+public class JavaDsl {
+
+  public static void testDataFrame(final DataFrame df) {
+    DataFrame df1 = df.select("colA");
+    df1 = df.select("colA", "colB");
+
+    df1 = df.select(col("colA"), col("colB"), lit("literal value").$plus(1));
+
+    df1 = df.filter(col("colA"));
+
+    java.util.Map<String, String> aggExprs = ImmutableMap.<String, String>builder()
+      .put("colA", "sum")
+      .put("colB", "avg")
+      .build();
+
+    df1 = df.agg(aggExprs);
+
+    df1 = df.groupBy("groupCol").agg(aggExprs);
+
+    df1 = df.join(df1, col("key1").$eq$eq$eq(col("key2")), "outer");
+
+    df.orderBy("colA");
+    df.orderBy("colA", "colB", "colC");
+    df.orderBy(col("colA").desc());
+    df.orderBy(col("colA").desc(), col("colB").asc());
+
+    df.sort("colA");
+    df.sort("colA", "colB", "colC");
+    df.sort(col("colA").desc());
+    df.sort(col("colA").desc(), col("colB").asc());
+
+    df.as("b");
+
+    df.limit(5);
+
+    df.unionAll(df1);
+    df.intersect(df1);
+    df.except(df1);
+
+    df.sample(true, 0.1, 234);
+
+    df.head();
+    df.head(5);
+    df.first();
+    df.count();
+  }
+
+  public static void testColumn(final Column c) {
+    c.asc();
+    c.desc();
+
+    c.endsWith("abcd");
+    c.startsWith("afgasdf");
+
+    c.like("asdf%");
+    c.rlike("wef%asdf");
+
+    c.as("newcol");
+
+    c.cast("int");
+    c.cast(DataTypes.IntegerType);
+  }
+
+  public static void testDsl() {
+    // Creating a column.
+    Column c = col("abcd");
+    Column c1 = column("abcd");
+
+    // Literals
+    Column l1 = lit(1);
+    Column l2 = lit(1.0);
+    Column l3 = lit("abcd");
+
+    // Functions
+    Column a = upper(c);
+    a = lower(c);
+    a = sqrt(c);
+    a = abs(c);
+
+    // Aggregates
+    a = min(c);
+    a = max(c);
+    a = sum(c);
+    a = sumDistinct(c);
+    a = countDistinct(c, a);
+    a = avg(c);
+    a = first(c);
+    a = last(c);
+  }
+}