[SPARK-6957] [SPARK-6958] [SQL] improve API compatibility to pandas

```
select(['cola', 'colb'])

groupby(['colA', 'colB'])
groupby([df.colA, df.colB])

df.sort('A', ascending=True)
df.sort(['A', 'B'], ascending=True)
df.sort(['A', 'B'], ascending=[1, 0])
```

cc rxin

Author: Davies Liu <davies@databricks.com>

Closes #5544 from davies/compatibility and squashes the following commits:

4944058 [Davies Liu] add docstrings
adb2816 [Davies Liu] Merge branch 'master' of github.com:apache/spark into compatibility
bcbbcab [Davies Liu] support ascending as list
8dabdf0 [Davies Liu] improve API compatibility to pandas
This commit is contained in:
Davies Liu 2015-04-17 11:29:27 -05:00 committed by Reynold Xin
parent dc48ba9f9f
commit c84d91692a
3 changed files with 70 additions and 39 deletions

View file

@ -485,13 +485,17 @@ class DataFrame(object):
return DataFrame(jdf, self.sql_ctx)
@ignore_unicode_prefix
def sort(self, *cols):
def sort(self, *cols, **kwargs):
"""Returns a new :class:`DataFrame` sorted by the specified column(s).
:param cols: list of :class:`Column` to sort by.
:param cols: list of :class:`Column` or column names to sort by.
:param ascending: sort by ascending order or not, could be bool, int
or list of bool, int (default: True).
>>> df.sort(df.age.desc()).collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
>>> df.sort("age", ascending=False).collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
>>> df.orderBy(df.age.desc()).collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
>>> from pyspark.sql.functions import *
@ -499,16 +503,42 @@ class DataFrame(object):
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
>>> df.orderBy(desc("age"), "name").collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
>>> df.orderBy(["age", "name"], ascending=[0, 1]).collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
"""
if not cols:
raise ValueError("should sort by at least one column")
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols))
if len(cols) == 1 and isinstance(cols[0], list):
cols = cols[0]
jcols = [_to_java_column(c) for c in cols]
ascending = kwargs.get('ascending', True)
if isinstance(ascending, (bool, int)):
if not ascending:
jcols = [jc.desc() for jc in jcols]
elif isinstance(ascending, list):
jcols = [jc if asc else jc.desc()
for asc, jc in zip(ascending, jcols)]
else:
raise TypeError("ascending can only be bool or list, but got %s" % type(ascending))
jdf = self._jdf.sort(self._jseq(jcols))
return DataFrame(jdf, self.sql_ctx)
orderBy = sort
def _jseq(self, cols, converter=None):
"""Return a JVM Seq of Columns from a list of Column or names"""
return _to_seq(self.sql_ctx._sc, cols, converter)
def _jcols(self, *cols):
"""Return a JVM Seq of Columns from a list of Column or column names
If `cols` has only one list in it, cols[0] will be used as the list.
"""
if len(cols) == 1 and isinstance(cols[0], list):
cols = cols[0]
return self._jseq(cols, _to_java_column)
def describe(self, *cols):
"""Computes statistics for numeric columns.
@ -523,9 +553,7 @@ class DataFrame(object):
min 2
max 5
"""
cols = ListConverter().convert(cols,
self.sql_ctx._sc._gateway._gateway_client)
jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols))
jdf = self._jdf.describe(self._jseq(cols))
return DataFrame(jdf, self.sql_ctx)
@ignore_unicode_prefix
@ -607,9 +635,7 @@ class DataFrame(object):
>>> df.select(df.name, (df.age + 10).alias('age')).collect()
[Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
"""
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
jdf = self._jdf.select(self._jcols(*cols))
return DataFrame(jdf, self.sql_ctx)
def selectExpr(self, *expr):
@ -620,8 +646,9 @@ class DataFrame(object):
>>> df.selectExpr("age * 2", "abs(age)").collect()
[Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)]
"""
jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client)
jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr))
if len(expr) == 1 and isinstance(expr[0], list):
expr = expr[0]
jdf = self._jdf.selectExpr(self._jseq(expr))
return DataFrame(jdf, self.sql_ctx)
@ignore_unicode_prefix
@ -659,6 +686,8 @@ class DataFrame(object):
so we can run aggregation on them. See :class:`GroupedData`
for all the available aggregate functions.
:func:`groupby` is an alias for :func:`groupBy`.
:param cols: list of columns to group by.
Each element should be a column name (string) or an expression (:class:`Column`).
@ -668,12 +697,14 @@ class DataFrame(object):
[Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
>>> df.groupBy(df.name).avg().collect()
[Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
>>> df.groupBy(['name', df.age]).count().collect()
[Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)]
"""
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
jdf = self._jdf.groupBy(self._jcols(*cols))
return GroupedData(jdf, self.sql_ctx)
groupby = groupBy
def agg(self, *exprs):
""" Aggregate on the entire :class:`DataFrame` without groups
(shorthand for ``df.groupBy.agg()``).
@ -744,9 +775,7 @@ class DataFrame(object):
if thresh is None:
thresh = len(subset) if how == 'any' else 1
cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client)
cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)
return DataFrame(self._jdf.na().drop(thresh, cols), self.sql_ctx)
return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx)
def fillna(self, value, subset=None):
"""Replace null values, alias for ``na.fill()``.
@ -799,9 +828,7 @@ class DataFrame(object):
elif not isinstance(subset, (list, tuple)):
raise ValueError("subset should be a list or tuple of column names")
cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client)
cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)
return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx)
return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
@ignore_unicode_prefix
def withColumn(self, colName, col):
@ -862,10 +889,8 @@ def dfapi(f):
def df_varargs_api(f):
def _api(self, *args):
jargs = ListConverter().convert(args,
self.sql_ctx._sc._gateway._gateway_client)
name = f.__name__
jdf = getattr(self._jdf, name)(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jargs))
jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args))
return DataFrame(jdf, self.sql_ctx)
_api.__name__ = f.__name__
_api.__doc__ = f.__doc__
@ -912,9 +937,8 @@ class GroupedData(object):
else:
# Columns
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
jcols = ListConverter().convert([c._jc for c in exprs[1:]],
self.sql_ctx._sc._gateway._gateway_client)
jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
jdf = self._jdf.agg(exprs[0]._jc,
_to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
return DataFrame(jdf, self.sql_ctx)
@dfapi
@ -1006,6 +1030,19 @@ def _to_java_column(col):
return jcol
def _to_seq(sc, cols, converter=None):
"""
Convert a list of Column (or names) into a JVM Seq of Column.
An optional `converter` could be used to convert items in `cols`
into JVM Column objects.
"""
if converter:
cols = [converter(c) for c in cols]
jcols = ListConverter().convert(cols, sc._gateway._gateway_client)
return sc._jvm.PythonUtils.toSeq(jcols)
def _unary_op(name, doc="unary operator"):
""" Create a method for given unary operator """
def _(self):
@ -1177,8 +1214,7 @@ class Column(object):
cols = cols[0]
cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
sc = SparkContext._active_spark_context
jcols = ListConverter().convert(cols, sc._gateway._gateway_client)
jc = getattr(self._jc, "in")(sc._jvm.PythonUtils.toSeq(jcols))
jc = getattr(self._jc, "in")(_to_seq(sc, cols))
return Column(jc)
# order

View file

@ -23,13 +23,11 @@ import sys
if sys.version < "3":
from itertools import imap as map
from py4j.java_collections import ListConverter
from pyspark import SparkContext
from pyspark.rdd import _prepare_for_python_RDD
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
from pyspark.sql.types import StringType
from pyspark.sql.dataframe import Column, _to_java_column
from pyspark.sql.dataframe import Column, _to_java_column, _to_seq
__all__ = ['countDistinct', 'approxCountDistinct', 'udf']
@ -87,8 +85,7 @@ def countDistinct(col, *cols):
[Row(c=2)]
"""
sc = SparkContext._active_spark_context
jcols = ListConverter().convert([_to_java_column(c) for c in cols], sc._gateway._gateway_client)
jc = sc._jvm.functions.countDistinct(_to_java_column(col), sc._jvm.PythonUtils.toSeq(jcols))
jc = sc._jvm.functions.countDistinct(_to_java_column(col), _to_seq(sc, cols, _to_java_column))
return Column(jc)
@ -138,9 +135,7 @@ class UserDefinedFunction(object):
def __call__(self, *cols):
sc = SparkContext._active_spark_context
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
sc._gateway._gateway_client)
jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
jc = self._judf.apply(_to_seq(sc, cols, _to_java_column))
return Column(jc)

View file

@ -282,7 +282,7 @@ class SQLTests(ReusedPySparkTestCase):
StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
StructField("list1", ArrayType(ByteType(), False), False),
StructField("null1", DoubleType(), True)])
df = self.sqlCtx.applySchema(rdd, schema)
df = self.sqlCtx.createDataFrame(rdd, schema)
results = df.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1,
x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),