[SPARK-7543] [SQL] [PySpark] split dataframe.py into multiple files

dataframe.py is splited into column.py, group.py and dataframe.py:
```
   360 column.py
  1223 dataframe.py
   183 group.py
```

Author: Davies Liu <davies@databricks.com>

Closes #6201 from davies/split_df and squashes the following commits:

fc8f5ab [Davies Liu] split dataframe.py into multiple files
This commit is contained in:
Davies Liu 2015-05-15 20:09:15 -07:00 committed by Reynold Xin
parent adfd366814
commit d7b69946cb
6 changed files with 552 additions and 449 deletions

View file

@ -55,8 +55,9 @@ del modname, sys
from pyspark.sql.types import Row
from pyspark.sql.context import SQLContext, HiveContext
from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions
from pyspark.sql.dataframe import DataFrameStatFunctions
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame, SchemaRDD, DataFrameNaFunctions, DataFrameStatFunctions
from pyspark.sql.group import GroupedData
__all__ = [
'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',

View file

@ -0,0 +1,360 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
if sys.version >= '3':
basestring = str
long = int
from pyspark.context import SparkContext
from pyspark.rdd import ignore_unicode_prefix
from pyspark.sql.types import *
__all__ = ["DataFrame", "Column", "SchemaRDD", "DataFrameNaFunctions",
"DataFrameStatFunctions"]
def _create_column_from_literal(literal):
sc = SparkContext._active_spark_context
return sc._jvm.functions.lit(literal)
def _create_column_from_name(name):
sc = SparkContext._active_spark_context
return sc._jvm.functions.col(name)
def _to_java_column(col):
if isinstance(col, Column):
jcol = col._jc
else:
jcol = _create_column_from_name(col)
return jcol
def _to_seq(sc, cols, converter=None):
"""
Convert a list of Column (or names) into a JVM Seq of Column.
An optional `converter` could be used to convert items in `cols`
into JVM Column objects.
"""
if converter:
cols = [converter(c) for c in cols]
return sc._jvm.PythonUtils.toSeq(cols)
def _unary_op(name, doc="unary operator"):
""" Create a method for given unary operator """
def _(self):
jc = getattr(self._jc, name)()
return Column(jc)
_.__doc__ = doc
return _
def _func_op(name, doc=''):
def _(self):
sc = SparkContext._active_spark_context
jc = getattr(sc._jvm.functions, name)(self._jc)
return Column(jc)
_.__doc__ = doc
return _
def _bin_op(name, doc="binary operator"):
""" Create a method for given binary operator
"""
def _(self, other):
jc = other._jc if isinstance(other, Column) else other
njc = getattr(self._jc, name)(jc)
return Column(njc)
_.__doc__ = doc
return _
def _reverse_op(name, doc="binary operator"):
""" Create a method for binary operator (this object is on right side)
"""
def _(self, other):
jother = _create_column_from_literal(other)
jc = getattr(jother, name)(self._jc)
return Column(jc)
_.__doc__ = doc
return _
class Column(object):
"""
A column in a DataFrame.
:class:`Column` instances can be created by::
# 1. Select a column out of a DataFrame
df.colName
df["colName"]
# 2. Create from an expression
df.colName + 1
1 / df.colName
"""
def __init__(self, jc):
self._jc = jc
# arithmetic operators
__neg__ = _func_op("negate")
__add__ = _bin_op("plus")
__sub__ = _bin_op("minus")
__mul__ = _bin_op("multiply")
__div__ = _bin_op("divide")
__truediv__ = _bin_op("divide")
__mod__ = _bin_op("mod")
__radd__ = _bin_op("plus")
__rsub__ = _reverse_op("minus")
__rmul__ = _bin_op("multiply")
__rdiv__ = _reverse_op("divide")
__rtruediv__ = _reverse_op("divide")
__rmod__ = _reverse_op("mod")
# logistic operators
__eq__ = _bin_op("equalTo")
__ne__ = _bin_op("notEqual")
__lt__ = _bin_op("lt")
__le__ = _bin_op("leq")
__ge__ = _bin_op("geq")
__gt__ = _bin_op("gt")
# `and`, `or`, `not` cannot be overloaded in Python,
# so use bitwise operators as boolean operators
__and__ = _bin_op('and')
__or__ = _bin_op('or')
__invert__ = _func_op('not')
__rand__ = _bin_op("and")
__ror__ = _bin_op("or")
# container operators
__contains__ = _bin_op("contains")
__getitem__ = _bin_op("apply")
# bitwise operators
bitwiseOR = _bin_op("bitwiseOR")
bitwiseAND = _bin_op("bitwiseAND")
bitwiseXOR = _bin_op("bitwiseXOR")
def getItem(self, key):
"""An expression that gets an item at position `ordinal` out of a list,
or gets an item by key out of a dict.
>>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"])
>>> df.select(df.l.getItem(0), df.d.getItem("key")).show()
+----+------+
|l[0]|d[key]|
+----+------+
| 1| value|
+----+------+
>>> df.select(df.l[0], df.d["key"]).show()
+----+------+
|l[0]|d[key]|
+----+------+
| 1| value|
+----+------+
"""
return self[key]
def getField(self, name):
"""An expression that gets a field by name in a StructField.
>>> from pyspark.sql import Row
>>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
>>> df.select(df.r.getField("b")).show()
+----+
|r[b]|
+----+
| b|
+----+
>>> df.select(df.r.a).show()
+----+
|r[a]|
+----+
| 1|
+----+
"""
return self[name]
def __getattr__(self, item):
if item.startswith("__"):
raise AttributeError(item)
return self.getField(item)
# string methods
rlike = _bin_op("rlike")
like = _bin_op("like")
startswith = _bin_op("startsWith")
endswith = _bin_op("endsWith")
@ignore_unicode_prefix
def substr(self, startPos, length):
"""
Return a :class:`Column` which is a substring of the column
:param startPos: start position (int or Column)
:param length: length of the substring (int or Column)
>>> df.select(df.name.substr(1, 3).alias("col")).collect()
[Row(col=u'Ali'), Row(col=u'Bob')]
"""
if type(startPos) != type(length):
raise TypeError("Can not mix the type")
if isinstance(startPos, (int, long)):
jc = self._jc.substr(startPos, length)
elif isinstance(startPos, Column):
jc = self._jc.substr(startPos._jc, length._jc)
else:
raise TypeError("Unexpected type: %s" % type(startPos))
return Column(jc)
__getslice__ = substr
@ignore_unicode_prefix
def inSet(self, *cols):
""" A boolean expression that is evaluated to true if the value of this
expression is contained by the evaluated values of the arguments.
>>> df[df.name.inSet("Bob", "Mike")].collect()
[Row(age=5, name=u'Bob')]
>>> df[df.age.inSet([1, 2, 3])].collect()
[Row(age=2, name=u'Alice')]
"""
if len(cols) == 1 and isinstance(cols[0], (list, set)):
cols = cols[0]
cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
sc = SparkContext._active_spark_context
jc = getattr(self._jc, "in")(_to_seq(sc, cols))
return Column(jc)
# order
asc = _unary_op("asc", "Returns a sort expression based on the"
" ascending order of the given column name.")
desc = _unary_op("desc", "Returns a sort expression based on the"
" descending order of the given column name.")
isNull = _unary_op("isNull", "True if the current expression is null.")
isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
def alias(self, *alias):
"""Returns this column aliased with a new name or names (in the case of expressions that
return more than one column, such as explode).
>>> df.select(df.age.alias("age2")).collect()
[Row(age2=2), Row(age2=5)]
"""
if len(alias) == 1:
return Column(getattr(self._jc, "as")(alias[0]))
else:
sc = SparkContext._active_spark_context
return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias))))
@ignore_unicode_prefix
def cast(self, dataType):
""" Convert the column into type `dataType`
>>> df.select(df.age.cast("string").alias('ages')).collect()
[Row(ages=u'2'), Row(ages=u'5')]
>>> df.select(df.age.cast(StringType()).alias('ages')).collect()
[Row(ages=u'2'), Row(ages=u'5')]
"""
if isinstance(dataType, basestring):
jc = self._jc.cast(dataType)
elif isinstance(dataType, DataType):
sc = SparkContext._active_spark_context
ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
jdt = ssql_ctx.parseDataType(dataType.json())
jc = self._jc.cast(jdt)
else:
raise TypeError("unexpected type: %s" % type(dataType))
return Column(jc)
@ignore_unicode_prefix
def between(self, lowerBound, upperBound):
""" A boolean expression that is evaluated to true if the value of this
expression is between the given columns.
"""
return (self >= lowerBound) & (self <= upperBound)
@ignore_unicode_prefix
def when(self, condition, value):
"""Evaluates a list of conditions and returns one of multiple possible result expressions.
If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
See :func:`pyspark.sql.functions.when` for example usage.
:param condition: a boolean :class:`Column` expression.
:param value: a literal value, or a :class:`Column` expression.
"""
sc = SparkContext._active_spark_context
if not isinstance(condition, Column):
raise TypeError("condition should be a Column")
v = value._jc if isinstance(value, Column) else value
jc = sc._jvm.functions.when(condition._jc, v)
return Column(jc)
@ignore_unicode_prefix
def otherwise(self, value):
"""Evaluates a list of conditions and returns one of multiple possible result expressions.
If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
See :func:`pyspark.sql.functions.when` for example usage.
:param value: a literal value, or a :class:`Column` expression.
"""
v = value._jc if isinstance(value, Column) else value
jc = self._jc.otherwise(value)
return Column(jc)
def __repr__(self):
return 'Column<%s>' % self._jc.toString().encode('utf8')
def _test():
import doctest
from pyspark.context import SparkContext
from pyspark.sql import SQLContext
import pyspark.sql.column
globs = pyspark.sql.column.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
globs['sqlContext'] = SQLContext(sc)
globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
.toDF(StructType([StructField('age', IntegerType()),
StructField('name', StringType())]))
(failure_count, test_count) = doctest.testmod(
pyspark.sql.column, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
globs['sc'].stop()
if failure_count:
exit(-1)
if __name__ == "__main__":
_test()

View file

@ -25,17 +25,15 @@ if sys.version >= '3':
else:
from itertools import imap as map
from pyspark.context import SparkContext
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.types import *
from pyspark.sql.types import _create_cls, _parse_datatype_json_string
from pyspark.sql.column import Column, _to_seq, _to_java_column
__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions",
"DataFrameStatFunctions"]
__all__ = ["DataFrame", "SchemaRDD", "DataFrameNaFunctions", "DataFrameStatFunctions"]
class DataFrame(object):
@ -757,6 +755,7 @@ class DataFrame(object):
[Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)]
"""
jdf = self._jdf.groupBy(self._jcols(*cols))
from pyspark.sql.group import GroupedData
return GroupedData(jdf, self.sql_ctx)
def agg(self, *exprs):
@ -1141,169 +1140,6 @@ class SchemaRDD(DataFrame):
"""
def dfapi(f):
def _api(self):
name = f.__name__
jdf = getattr(self._jdf, name)()
return DataFrame(jdf, self.sql_ctx)
_api.__name__ = f.__name__
_api.__doc__ = f.__doc__
return _api
def df_varargs_api(f):
def _api(self, *args):
name = f.__name__
jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args))
return DataFrame(jdf, self.sql_ctx)
_api.__name__ = f.__name__
_api.__doc__ = f.__doc__
return _api
class GroupedData(object):
"""
A set of methods for aggregations on a :class:`DataFrame`,
created by :func:`DataFrame.groupBy`.
"""
def __init__(self, jdf, sql_ctx):
self._jdf = jdf
self.sql_ctx = sql_ctx
@ignore_unicode_prefix
def agg(self, *exprs):
"""Compute aggregates and returns the result as a :class:`DataFrame`.
The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`.
If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
is the column to perform aggregation on, and the value is the aggregate function.
Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.
:param exprs: a dict mapping from column name (string) to aggregate functions (string),
or a list of :class:`Column`.
>>> gdf = df.groupBy(df.name)
>>> gdf.agg({"*": "count"}).collect()
[Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)]
>>> from pyspark.sql import functions as F
>>> gdf.agg(F.min(df.age)).collect()
[Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)]
"""
assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
jdf = self._jdf.agg(exprs[0])
else:
# Columns
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
jdf = self._jdf.agg(exprs[0]._jc,
_to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
return DataFrame(jdf, self.sql_ctx)
@dfapi
def count(self):
"""Counts the number of records for each group.
>>> df.groupBy(df.age).count().collect()
[Row(age=2, count=1), Row(age=5, count=1)]
"""
@df_varargs_api
def mean(self, *cols):
"""Computes average values for each numeric columns for each group.
:func:`mean` is an alias for :func:`avg`.
:param cols: list of column names (string). Non-numeric columns are ignored.
>>> df.groupBy().mean('age').collect()
[Row(AVG(age)=3.5)]
>>> df3.groupBy().mean('age', 'height').collect()
[Row(AVG(age)=3.5, AVG(height)=82.5)]
"""
@df_varargs_api
def avg(self, *cols):
"""Computes average values for each numeric columns for each group.
:func:`mean` is an alias for :func:`avg`.
:param cols: list of column names (string). Non-numeric columns are ignored.
>>> df.groupBy().avg('age').collect()
[Row(AVG(age)=3.5)]
>>> df3.groupBy().avg('age', 'height').collect()
[Row(AVG(age)=3.5, AVG(height)=82.5)]
"""
@df_varargs_api
def max(self, *cols):
"""Computes the max value for each numeric columns for each group.
>>> df.groupBy().max('age').collect()
[Row(MAX(age)=5)]
>>> df3.groupBy().max('age', 'height').collect()
[Row(MAX(age)=5, MAX(height)=85)]
"""
@df_varargs_api
def min(self, *cols):
"""Computes the min value for each numeric column for each group.
:param cols: list of column names (string). Non-numeric columns are ignored.
>>> df.groupBy().min('age').collect()
[Row(MIN(age)=2)]
>>> df3.groupBy().min('age', 'height').collect()
[Row(MIN(age)=2, MIN(height)=80)]
"""
@df_varargs_api
def sum(self, *cols):
"""Compute the sum for each numeric columns for each group.
:param cols: list of column names (string). Non-numeric columns are ignored.
>>> df.groupBy().sum('age').collect()
[Row(SUM(age)=7)]
>>> df3.groupBy().sum('age', 'height').collect()
[Row(SUM(age)=7, SUM(height)=165)]
"""
def _create_column_from_literal(literal):
sc = SparkContext._active_spark_context
return sc._jvm.functions.lit(literal)
def _create_column_from_name(name):
sc = SparkContext._active_spark_context
return sc._jvm.functions.col(name)
def _to_java_column(col):
if isinstance(col, Column):
jcol = col._jc
else:
jcol = _create_column_from_name(col)
return jcol
def _to_seq(sc, cols, converter=None):
"""
Convert a list of Column (or names) into a JVM Seq of Column.
An optional `converter` could be used to convert items in `cols`
into JVM Column objects.
"""
if converter:
cols = [converter(c) for c in cols]
return sc._jvm.PythonUtils.toSeq(cols)
def _to_scala_map(sc, jm):
"""
Convert a dict into a JVM Map.
@ -1311,282 +1147,6 @@ def _to_scala_map(sc, jm):
return sc._jvm.PythonUtils.toScalaMap(jm)
def _unary_op(name, doc="unary operator"):
""" Create a method for given unary operator """
def _(self):
jc = getattr(self._jc, name)()
return Column(jc)
_.__doc__ = doc
return _
def _func_op(name, doc=''):
def _(self):
sc = SparkContext._active_spark_context
jc = getattr(sc._jvm.functions, name)(self._jc)
return Column(jc)
_.__doc__ = doc
return _
def _bin_op(name, doc="binary operator"):
""" Create a method for given binary operator
"""
def _(self, other):
jc = other._jc if isinstance(other, Column) else other
njc = getattr(self._jc, name)(jc)
return Column(njc)
_.__doc__ = doc
return _
def _reverse_op(name, doc="binary operator"):
""" Create a method for binary operator (this object is on right side)
"""
def _(self, other):
jother = _create_column_from_literal(other)
jc = getattr(jother, name)(self._jc)
return Column(jc)
_.__doc__ = doc
return _
class Column(object):
"""
A column in a DataFrame.
:class:`Column` instances can be created by::
# 1. Select a column out of a DataFrame
df.colName
df["colName"]
# 2. Create from an expression
df.colName + 1
1 / df.colName
"""
def __init__(self, jc):
self._jc = jc
# arithmetic operators
__neg__ = _func_op("negate")
__add__ = _bin_op("plus")
__sub__ = _bin_op("minus")
__mul__ = _bin_op("multiply")
__div__ = _bin_op("divide")
__truediv__ = _bin_op("divide")
__mod__ = _bin_op("mod")
__radd__ = _bin_op("plus")
__rsub__ = _reverse_op("minus")
__rmul__ = _bin_op("multiply")
__rdiv__ = _reverse_op("divide")
__rtruediv__ = _reverse_op("divide")
__rmod__ = _reverse_op("mod")
# logistic operators
__eq__ = _bin_op("equalTo")
__ne__ = _bin_op("notEqual")
__lt__ = _bin_op("lt")
__le__ = _bin_op("leq")
__ge__ = _bin_op("geq")
__gt__ = _bin_op("gt")
# `and`, `or`, `not` cannot be overloaded in Python,
# so use bitwise operators as boolean operators
__and__ = _bin_op('and')
__or__ = _bin_op('or')
__invert__ = _func_op('not')
__rand__ = _bin_op("and")
__ror__ = _bin_op("or")
# container operators
__contains__ = _bin_op("contains")
__getitem__ = _bin_op("apply")
# bitwise operators
bitwiseOR = _bin_op("bitwiseOR")
bitwiseAND = _bin_op("bitwiseAND")
bitwiseXOR = _bin_op("bitwiseXOR")
def getItem(self, key):
"""An expression that gets an item at position `ordinal` out of a list,
or gets an item by key out of a dict.
>>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"])
>>> df.select(df.l.getItem(0), df.d.getItem("key")).show()
+----+------+
|l[0]|d[key]|
+----+------+
| 1| value|
+----+------+
>>> df.select(df.l[0], df.d["key"]).show()
+----+------+
|l[0]|d[key]|
+----+------+
| 1| value|
+----+------+
"""
return self[key]
def getField(self, name):
"""An expression that gets a field by name in a StructField.
>>> from pyspark.sql import Row
>>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
>>> df.select(df.r.getField("b")).show()
+----+
|r[b]|
+----+
| b|
+----+
>>> df.select(df.r.a).show()
+----+
|r[a]|
+----+
| 1|
+----+
"""
return self[name]
def __getattr__(self, item):
if item.startswith("__"):
raise AttributeError(item)
return self.getField(item)
# string methods
rlike = _bin_op("rlike")
like = _bin_op("like")
startswith = _bin_op("startsWith")
endswith = _bin_op("endsWith")
@ignore_unicode_prefix
def substr(self, startPos, length):
"""
Return a :class:`Column` which is a substring of the column
:param startPos: start position (int or Column)
:param length: length of the substring (int or Column)
>>> df.select(df.name.substr(1, 3).alias("col")).collect()
[Row(col=u'Ali'), Row(col=u'Bob')]
"""
if type(startPos) != type(length):
raise TypeError("Can not mix the type")
if isinstance(startPos, (int, long)):
jc = self._jc.substr(startPos, length)
elif isinstance(startPos, Column):
jc = self._jc.substr(startPos._jc, length._jc)
else:
raise TypeError("Unexpected type: %s" % type(startPos))
return Column(jc)
__getslice__ = substr
@ignore_unicode_prefix
def inSet(self, *cols):
""" A boolean expression that is evaluated to true if the value of this
expression is contained by the evaluated values of the arguments.
>>> df[df.name.inSet("Bob", "Mike")].collect()
[Row(age=5, name=u'Bob')]
>>> df[df.age.inSet([1, 2, 3])].collect()
[Row(age=2, name=u'Alice')]
"""
if len(cols) == 1 and isinstance(cols[0], (list, set)):
cols = cols[0]
cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
sc = SparkContext._active_spark_context
jc = getattr(self._jc, "in")(_to_seq(sc, cols))
return Column(jc)
# order
asc = _unary_op("asc", "Returns a sort expression based on the"
" ascending order of the given column name.")
desc = _unary_op("desc", "Returns a sort expression based on the"
" descending order of the given column name.")
isNull = _unary_op("isNull", "True if the current expression is null.")
isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
def alias(self, *alias):
"""Returns this column aliased with a new name or names (in the case of expressions that
return more than one column, such as explode).
>>> df.select(df.age.alias("age2")).collect()
[Row(age2=2), Row(age2=5)]
"""
if len(alias) == 1:
return Column(getattr(self._jc, "as")(alias[0]))
else:
sc = SparkContext._active_spark_context
return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias))))
@ignore_unicode_prefix
def cast(self, dataType):
""" Convert the column into type `dataType`
>>> df.select(df.age.cast("string").alias('ages')).collect()
[Row(ages=u'2'), Row(ages=u'5')]
>>> df.select(df.age.cast(StringType()).alias('ages')).collect()
[Row(ages=u'2'), Row(ages=u'5')]
"""
if isinstance(dataType, basestring):
jc = self._jc.cast(dataType)
elif isinstance(dataType, DataType):
sc = SparkContext._active_spark_context
ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
jdt = ssql_ctx.parseDataType(dataType.json())
jc = self._jc.cast(jdt)
else:
raise TypeError("unexpected type: %s" % type(dataType))
return Column(jc)
@ignore_unicode_prefix
def between(self, lowerBound, upperBound):
""" A boolean expression that is evaluated to true if the value of this
expression is between the given columns.
"""
return (self >= lowerBound) & (self <= upperBound)
@ignore_unicode_prefix
def when(self, condition, value):
"""Evaluates a list of conditions and returns one of multiple possible result expressions.
If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
See :func:`pyspark.sql.functions.when` for example usage.
:param condition: a boolean :class:`Column` expression.
:param value: a literal value, or a :class:`Column` expression.
"""
sc = SparkContext._active_spark_context
if not isinstance(condition, Column):
raise TypeError("condition should be a Column")
v = value._jc if isinstance(value, Column) else value
jc = sc._jvm.functions.when(condition._jc, v)
return Column(jc)
@ignore_unicode_prefix
def otherwise(self, value):
"""Evaluates a list of conditions and returns one of multiple possible result expressions.
If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
See :func:`pyspark.sql.functions.when` for example usage.
:param value: a literal value, or a :class:`Column` expression.
"""
v = value._jc if isinstance(value, Column) else value
jc = self._jc.otherwise(value)
return Column(jc)
def __repr__(self):
return 'Column<%s>' % self._jc.toString().encode('utf8')
class DataFrameNaFunctions(object):
"""Functionality for working with missing data in :class:`DataFrame`.
"""
@ -1646,9 +1206,6 @@ def _test():
.toDF(StructType([StructField('age', IntegerType()),
StructField('name', StringType())]))
globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
Row(name='Bob', age=5, height=85)]).toDF()
globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80),
Row(name='Bob', age=5, height=None),
Row(name='Tom', age=None, height=None),

View file

@ -27,7 +27,7 @@ from pyspark import SparkContext
from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
from pyspark.sql.types import StringType
from pyspark.sql.dataframe import Column, _to_java_column, _to_seq
from pyspark.sql.column import Column, _to_java_column, _to_seq
__all__ = [

183
python/pyspark/sql/group.py Normal file
View file

@ -0,0 +1,183 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.rdd import ignore_unicode_prefix
from pyspark.sql.column import Column, _to_seq
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import *
__all__ = ["GroupedData"]
def dfapi(f):
def _api(self):
name = f.__name__
jdf = getattr(self._jdf, name)()
return DataFrame(jdf, self.sql_ctx)
_api.__name__ = f.__name__
_api.__doc__ = f.__doc__
return _api
def df_varargs_api(f):
def _api(self, *args):
name = f.__name__
jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args))
return DataFrame(jdf, self.sql_ctx)
_api.__name__ = f.__name__
_api.__doc__ = f.__doc__
return _api
class GroupedData(object):
"""
A set of methods for aggregations on a :class:`DataFrame`,
created by :func:`DataFrame.groupBy`.
"""
def __init__(self, jdf, sql_ctx):
self._jdf = jdf
self.sql_ctx = sql_ctx
@ignore_unicode_prefix
def agg(self, *exprs):
"""Compute aggregates and returns the result as a :class:`DataFrame`.
The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`.
If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
is the column to perform aggregation on, and the value is the aggregate function.
Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.
:param exprs: a dict mapping from column name (string) to aggregate functions (string),
or a list of :class:`Column`.
>>> gdf = df.groupBy(df.name)
>>> gdf.agg({"*": "count"}).collect()
[Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)]
>>> from pyspark.sql import functions as F
>>> gdf.agg(F.min(df.age)).collect()
[Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)]
"""
assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
jdf = self._jdf.agg(exprs[0])
else:
# Columns
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
jdf = self._jdf.agg(exprs[0]._jc,
_to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
return DataFrame(jdf, self.sql_ctx)
@dfapi
def count(self):
"""Counts the number of records for each group.
>>> df.groupBy(df.age).count().collect()
[Row(age=2, count=1), Row(age=5, count=1)]
"""
@df_varargs_api
def mean(self, *cols):
"""Computes average values for each numeric columns for each group.
:func:`mean` is an alias for :func:`avg`.
:param cols: list of column names (string). Non-numeric columns are ignored.
>>> df.groupBy().mean('age').collect()
[Row(AVG(age)=3.5)]
>>> df3.groupBy().mean('age', 'height').collect()
[Row(AVG(age)=3.5, AVG(height)=82.5)]
"""
@df_varargs_api
def avg(self, *cols):
"""Computes average values for each numeric columns for each group.
:func:`mean` is an alias for :func:`avg`.
:param cols: list of column names (string). Non-numeric columns are ignored.
>>> df.groupBy().avg('age').collect()
[Row(AVG(age)=3.5)]
>>> df3.groupBy().avg('age', 'height').collect()
[Row(AVG(age)=3.5, AVG(height)=82.5)]
"""
@df_varargs_api
def max(self, *cols):
"""Computes the max value for each numeric columns for each group.
>>> df.groupBy().max('age').collect()
[Row(MAX(age)=5)]
>>> df3.groupBy().max('age', 'height').collect()
[Row(MAX(age)=5, MAX(height)=85)]
"""
@df_varargs_api
def min(self, *cols):
"""Computes the min value for each numeric column for each group.
:param cols: list of column names (string). Non-numeric columns are ignored.
>>> df.groupBy().min('age').collect()
[Row(MIN(age)=2)]
>>> df3.groupBy().min('age', 'height').collect()
[Row(MIN(age)=2, MIN(height)=80)]
"""
@df_varargs_api
def sum(self, *cols):
"""Compute the sum for each numeric columns for each group.
:param cols: list of column names (string). Non-numeric columns are ignored.
>>> df.groupBy().sum('age').collect()
[Row(SUM(age)=7)]
>>> df3.groupBy().sum('age', 'height').collect()
[Row(SUM(age)=7, SUM(height)=165)]
"""
def _test():
import doctest
from pyspark.context import SparkContext
from pyspark.sql import Row, SQLContext
import pyspark.sql.group
globs = pyspark.sql.group.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
globs['sqlContext'] = SQLContext(sc)
globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
.toDF(StructType([StructField('age', IntegerType()),
StructField('name', StringType())]))
globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
Row(name='Bob', age=5, height=85)]).toDF()
(failure_count, test_count) = doctest.testmod(
pyspark.sql.group, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
globs['sc'].stop()
if failure_count:
exit(-1)
if __name__ == "__main__":
_test()

View file

@ -72,7 +72,9 @@ function run_sql_tests() {
echo "Run sql tests ..."
run_test "pyspark/sql/_types.py"
run_test "pyspark/sql/context.py"
run_test "pyspark/sql/column.py"
run_test "pyspark/sql/dataframe.py"
run_test "pyspark/sql/group.py"
run_test "pyspark/sql/functions.py"
run_test "pyspark/sql/tests.py"
}