[SPARK-5985][SQL] DataFrame sortBy -> orderBy in Python.

Also added desc/asc functions for constructing sorting expressions more conveniently, and added a small fix to lift the alias out of a cast expression.
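A quick sketch of how the new sort functions are meant to be used (the DataFrame and its "dept"/"age" columns are illustrative assumptions, not part of this patch):

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.{asc, desc}

    // Sort by dept in ascending order, then by age in descending order,
    // using the asc/desc helpers added to org.apache.spark.sql.functions.
    def sortByDeptThenAge(df: DataFrame): DataFrame =
      df.sort(asc("dept"), desc("age"))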

Author: Reynold Xin <rxin@databricks.com>

Closes #4752 from rxin/SPARK-5985 and squashes the following commits:

aeda5ae [Reynold Xin] Added Experimental flag to ColumnName.
047ad03 [Reynold Xin] Lift alias out of cast.
c9cf17c [Reynold Xin] [SPARK-5985][SQL] DataFrame sortBy -> orderBy in Python.
Reynold Xin 2015-02-24 18:59:23 -08:00 committed by Michael Armbrust
parent 53a1ebf33b
commit fba11c2f55
6 changed files with 59 additions and 5 deletions


@@ -504,13 +504,18 @@ class DataFrame(object):
         return DataFrame(jdf, self.sql_ctx)
 
     def sort(self, *cols):
-        """ Return a new :class:`DataFrame` sorted by the specified column.
+        """ Return a new :class:`DataFrame` sorted by the specified column(s).
 
         :param cols: The columns or expressions used for sorting
 
         >>> df.sort(df.age.desc()).collect()
         [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
-        >>> df.sortBy(df.age.desc()).collect()
+        >>> df.orderBy(df.age.desc()).collect()
+        [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
+        >>> from pyspark.sql.functions import *
+        >>> df.sort(asc("age")).collect()
+        [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
+        >>> df.orderBy(desc("age"), "name").collect()
         [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
         """
         if not cols:
@@ -520,7 +525,7 @@ class DataFrame(object):
         jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols))
         return DataFrame(jdf, self.sql_ctx)
 
-    sortBy = sort
+    orderBy = sort
 
     def head(self, n=None):
         """ Return the first `n` rows or the first row if n is None.


@@ -48,6 +48,9 @@ _functions = {
     'lit': 'Creates a :class:`Column` of literal value.',
     'col': 'Returns a :class:`Column` based on the given column name.',
     'column': 'Returns a :class:`Column` based on the given column name.',
+    'asc': 'Returns a sort expression based on the ascending order of the given column name.',
+    'desc': 'Returns a sort expression based on the descending order of the given column name.',
+
     'upper': 'Converts a string expression to upper case.',
     'lower': 'Converts a string expression to upper case.',
     'sqrt': 'Computes the square root of the specified float value.',


@@ -600,7 +600,11 @@ class Column(protected[sql] val expr: Expression) {
    *
    * @group expr_ops
    */
-  def cast(to: DataType): Column = Cast(expr, to)
+  def cast(to: DataType): Column = expr match {
+    // Lift alias out of cast so we can support col.as("name").cast(IntegerType)
+    case Alias(childExpr, name) => Alias(Cast(childExpr, to), name)()
+    case _ => Cast(expr, to)
+  }
 
   /**
    * Casts the column to a different data type, using the canonical string representation
@@ -613,7 +617,7 @@ class Column(protected[sql] val expr: Expression) {
    *
    * @group expr_ops
    */
-  def cast(to: String): Column = Cast(expr, to.toLowerCase match {
+  def cast(to: String): Column = cast(to.toLowerCase match {
     case "string" | "str" => StringType
     case "boolean" => BooleanType
     case "byte" => ByteType
@@ -671,6 +675,11 @@ class Column(protected[sql] val expr: Expression) {
 }
 
+/**
+ * :: Experimental ::
+ * A convenient class used for constructing schema.
+ */
+@Experimental
 class ColumnName(name: String) extends Column(name) {
 
   /** Creates a new AttributeReference of type boolean */
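A short sketch of what the alias-lifting change above enables (the column name "age" and alias "age_int" are hypothetical, chosen for illustration):

    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.IntegerType

    // After this change the alias is lifted above the Cast, so aliasing before
    // or after the cast yields the same expression (see the new test below).
    val castedThenNamed = col("age").cast(IntegerType).as("age_int")
    val namedThenCasted = col("age").as("age_int").cast(IntegerType)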


@@ -33,6 +33,7 @@ import org.apache.spark.sql.types._
  *
  * @groupname udf_funcs UDF functions
  * @groupname agg_funcs Aggregate functions
+ * @groupname sort_funcs Sorting functions
  * @groupname normal_funcs Non-aggregate functions
  * @groupname Ungrouped Support functions for DataFrames.
  */
@@ -96,6 +97,33 @@ object functions {
   }
 
   //////////////////////////////////////////////////////////////////////////////////////////////
+  // Sort functions
+  //////////////////////////////////////////////////////////////////////////////////////////////
+
+  /**
+   * Returns a sort expression based on ascending order of the column.
+   * {{
+   *   // Sort by dept in ascending order, and then age in descending order.
+   *   df.sort(asc("dept"), desc("age"))
+   * }}
+   *
+   * @group sort_funcs
+   */
+  def asc(columnName: String): Column = Column(columnName).asc
+
+  /**
+   * Returns a sort expression based on the descending order of the column.
+   * {{
+   *   // Sort by dept in ascending order, and then age in descending order.
+   *   df.sort(asc("dept"), desc("age"))
+   * }}
+   *
+   * @group sort_funcs
+   */
+  def desc(columnName: String): Column = Column(columnName).desc
+
+  //////////////////////////////////////////////////////////////////////////////////////////////
+  // Aggregate functions
   //////////////////////////////////////////////////////////////////////////////////////////////
 
   /**
@@ -263,6 +291,7 @@ object functions {
   def max(columnName: String): Column = max(Column(columnName))
 
   //////////////////////////////////////////////////////////////////////////////////////////////
+  // Non-aggregate functions
   //////////////////////////////////////////////////////////////////////////////////////////////
 
   /**


@@ -309,4 +309,8 @@ class ColumnExpressionSuite extends QueryTest {
       (1 to 100).map(n => Row(null))
     )
   }
+
+  test("lift alias out of cast") {
+    assert(col("1234").as("name").cast("int").expr === col("1234").cast("int").as("name").expr)
+  }
 }


@@ -239,6 +239,10 @@ class DataFrameSuite extends QueryTest {
       testData2.orderBy('a.asc, 'b.asc),
       Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2)))
 
+    checkAnswer(
+      testData2.orderBy(asc("a"), desc("b")),
+      Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1)))
+
     checkAnswer(
       testData2.orderBy('a.asc, 'b.desc),
       Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1)))