[SPARK-7543] [SQL] [PySpark] split dataframe.py into multiple files
dataframe.py is split into column.py, group.py and dataframe.py:

```
 360 column.py
1223 dataframe.py
 183 group.py
```

Author: Davies Liu <davies@databricks.com>

Closes #6201 from davies/split_df and squashes the following commits:

fc8f5ab [Davies Liu] split dataframe.py into multiple files
parent adfd366814
commit d7b69946cb
python/pyspark/sql/__init__.py
@@ -55,8 +55,9 @@ del modname, sys

 from pyspark.sql.types import Row
 from pyspark.sql.context import SQLContext, HiveContext
-from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions
-from pyspark.sql.dataframe import DataFrameStatFunctions
+from pyspark.sql.column import Column
+from pyspark.sql.dataframe import DataFrame, SchemaRDD, DataFrameNaFunctions, DataFrameStatFunctions
+from pyspark.sql.group import GroupedData

 __all__ = [
     'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
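As a quick sanity check (not part of the commit), the package keeps re-exporting every public name, so downstream imports are unaffected by the split; only the defining module of each class changes:

```python
# Not part of the commit: verify that the pyspark.sql import surface is
# unchanged after the split.
from pyspark.sql import SQLContext, HiveContext, DataFrame, GroupedData, Column, Row

print(Column.__module__)       # expected: pyspark.sql.column
print(GroupedData.__module__)  # expected: pyspark.sql.group
print(DataFrame.__module__)    # expected: pyspark.sql.dataframe
```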
python/pyspark/sql/column.py (new file, 360 lines)
@@ -0,0 +1,360 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys

if sys.version >= '3':
    basestring = str
    long = int

from pyspark.context import SparkContext
from pyspark.rdd import ignore_unicode_prefix
from pyspark.sql.types import *

__all__ = ["DataFrame", "Column", "SchemaRDD", "DataFrameNaFunctions",
           "DataFrameStatFunctions"]


def _create_column_from_literal(literal):
    sc = SparkContext._active_spark_context
    return sc._jvm.functions.lit(literal)


def _create_column_from_name(name):
    sc = SparkContext._active_spark_context
    return sc._jvm.functions.col(name)


def _to_java_column(col):
    if isinstance(col, Column):
        jcol = col._jc
    else:
        jcol = _create_column_from_name(col)
    return jcol


def _to_seq(sc, cols, converter=None):
    """
    Convert a list of Column (or names) into a JVM Seq of Column.

    An optional `converter` could be used to convert items in `cols`
    into JVM Column objects.
    """
    if converter:
        cols = [converter(c) for c in cols]
    return sc._jvm.PythonUtils.toSeq(cols)


def _unary_op(name, doc="unary operator"):
    """ Create a method for given unary operator """
    def _(self):
        jc = getattr(self._jc, name)()
        return Column(jc)
    _.__doc__ = doc
    return _


def _func_op(name, doc=''):
    def _(self):
        sc = SparkContext._active_spark_context
        jc = getattr(sc._jvm.functions, name)(self._jc)
        return Column(jc)
    _.__doc__ = doc
    return _


def _bin_op(name, doc="binary operator"):
    """ Create a method for given binary operator
    """
    def _(self, other):
        jc = other._jc if isinstance(other, Column) else other
        njc = getattr(self._jc, name)(jc)
        return Column(njc)
    _.__doc__ = doc
    return _


def _reverse_op(name, doc="binary operator"):
    """ Create a method for binary operator (this object is on right side)
    """
    def _(self, other):
        jother = _create_column_from_literal(other)
        jc = getattr(jother, name)(self._jc)
        return Column(jc)
    _.__doc__ = doc
    return _


class Column(object):

    """
    A column in a DataFrame.

    :class:`Column` instances can be created by::

        # 1. Select a column out of a DataFrame

        df.colName
        df["colName"]

        # 2. Create from an expression
        df.colName + 1
        1 / df.colName
    """

    def __init__(self, jc):
        self._jc = jc

    # arithmetic operators
    __neg__ = _func_op("negate")
    __add__ = _bin_op("plus")
    __sub__ = _bin_op("minus")
    __mul__ = _bin_op("multiply")
    __div__ = _bin_op("divide")
    __truediv__ = _bin_op("divide")
    __mod__ = _bin_op("mod")
    __radd__ = _bin_op("plus")
    __rsub__ = _reverse_op("minus")
    __rmul__ = _bin_op("multiply")
    __rdiv__ = _reverse_op("divide")
    __rtruediv__ = _reverse_op("divide")
    __rmod__ = _reverse_op("mod")

    # logistic operators
    __eq__ = _bin_op("equalTo")
    __ne__ = _bin_op("notEqual")
    __lt__ = _bin_op("lt")
    __le__ = _bin_op("leq")
    __ge__ = _bin_op("geq")
    __gt__ = _bin_op("gt")

    # `and`, `or`, `not` cannot be overloaded in Python,
    # so use bitwise operators as boolean operators
    __and__ = _bin_op('and')
    __or__ = _bin_op('or')
    __invert__ = _func_op('not')
    __rand__ = _bin_op("and")
    __ror__ = _bin_op("or")

    # container operators
    __contains__ = _bin_op("contains")
    __getitem__ = _bin_op("apply")

    # bitwise operators
    bitwiseOR = _bin_op("bitwiseOR")
    bitwiseAND = _bin_op("bitwiseAND")
    bitwiseXOR = _bin_op("bitwiseXOR")

    def getItem(self, key):
        """An expression that gets an item at position `ordinal` out of a list,
        or gets an item by key out of a dict.

        >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"])
        >>> df.select(df.l.getItem(0), df.d.getItem("key")).show()
        +----+------+
        |l[0]|d[key]|
        +----+------+
        |   1| value|
        +----+------+
        >>> df.select(df.l[0], df.d["key"]).show()
        +----+------+
        |l[0]|d[key]|
        +----+------+
        |   1| value|
        +----+------+
        """
        return self[key]

    def getField(self, name):
        """An expression that gets a field by name in a StructField.

        >>> from pyspark.sql import Row
        >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
        >>> df.select(df.r.getField("b")).show()
        +----+
        |r[b]|
        +----+
        |   b|
        +----+
        >>> df.select(df.r.a).show()
        +----+
        |r[a]|
        +----+
        |   1|
        +----+
        """
        return self[name]

    def __getattr__(self, item):
        if item.startswith("__"):
            raise AttributeError(item)
        return self.getField(item)

    # string methods
    rlike = _bin_op("rlike")
    like = _bin_op("like")
    startswith = _bin_op("startsWith")
    endswith = _bin_op("endsWith")

    @ignore_unicode_prefix
    def substr(self, startPos, length):
        """
        Return a :class:`Column` which is a substring of the column

        :param startPos: start position (int or Column)
        :param length: length of the substring (int or Column)

        >>> df.select(df.name.substr(1, 3).alias("col")).collect()
        [Row(col=u'Ali'), Row(col=u'Bob')]
        """
        if type(startPos) != type(length):
            raise TypeError("Can not mix the type")
        if isinstance(startPos, (int, long)):
            jc = self._jc.substr(startPos, length)
        elif isinstance(startPos, Column):
            jc = self._jc.substr(startPos._jc, length._jc)
        else:
            raise TypeError("Unexpected type: %s" % type(startPos))
        return Column(jc)

    __getslice__ = substr

    @ignore_unicode_prefix
    def inSet(self, *cols):
        """ A boolean expression that is evaluated to true if the value of this
        expression is contained by the evaluated values of the arguments.

        >>> df[df.name.inSet("Bob", "Mike")].collect()
        [Row(age=5, name=u'Bob')]
        >>> df[df.age.inSet([1, 2, 3])].collect()
        [Row(age=2, name=u'Alice')]
        """
        if len(cols) == 1 and isinstance(cols[0], (list, set)):
            cols = cols[0]
        cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
        sc = SparkContext._active_spark_context
        jc = getattr(self._jc, "in")(_to_seq(sc, cols))
        return Column(jc)

    # order
    asc = _unary_op("asc", "Returns a sort expression based on the"
                           " ascending order of the given column name.")
    desc = _unary_op("desc", "Returns a sort expression based on the"
                             " descending order of the given column name.")

    isNull = _unary_op("isNull", "True if the current expression is null.")
    isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")

    def alias(self, *alias):
        """Returns this column aliased with a new name or names (in the case of expressions that
        return more than one column, such as explode).

        >>> df.select(df.age.alias("age2")).collect()
        [Row(age2=2), Row(age2=5)]
        """

        if len(alias) == 1:
            return Column(getattr(self._jc, "as")(alias[0]))
        else:
            sc = SparkContext._active_spark_context
            return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias))))

    @ignore_unicode_prefix
    def cast(self, dataType):
        """ Convert the column into type `dataType`

        >>> df.select(df.age.cast("string").alias('ages')).collect()
        [Row(ages=u'2'), Row(ages=u'5')]
        >>> df.select(df.age.cast(StringType()).alias('ages')).collect()
        [Row(ages=u'2'), Row(ages=u'5')]
        """
        if isinstance(dataType, basestring):
            jc = self._jc.cast(dataType)
        elif isinstance(dataType, DataType):
            sc = SparkContext._active_spark_context
            ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
            jdt = ssql_ctx.parseDataType(dataType.json())
            jc = self._jc.cast(jdt)
        else:
            raise TypeError("unexpected type: %s" % type(dataType))
        return Column(jc)

    @ignore_unicode_prefix
    def between(self, lowerBound, upperBound):
        """ A boolean expression that is evaluated to true if the value of this
        expression is between the given columns.
        """
        return (self >= lowerBound) & (self <= upperBound)

    @ignore_unicode_prefix
    def when(self, condition, value):
        """Evaluates a list of conditions and returns one of multiple possible result expressions.
        If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.

        See :func:`pyspark.sql.functions.when` for example usage.

        :param condition: a boolean :class:`Column` expression.
        :param value: a literal value, or a :class:`Column` expression.

        """
        sc = SparkContext._active_spark_context
        if not isinstance(condition, Column):
            raise TypeError("condition should be a Column")
        v = value._jc if isinstance(value, Column) else value
        jc = sc._jvm.functions.when(condition._jc, v)
        return Column(jc)

    @ignore_unicode_prefix
    def otherwise(self, value):
        """Evaluates a list of conditions and returns one of multiple possible result expressions.
        If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.

        See :func:`pyspark.sql.functions.when` for example usage.

        :param value: a literal value, or a :class:`Column` expression.
        """
        v = value._jc if isinstance(value, Column) else value
        jc = self._jc.otherwise(value)
        return Column(jc)

    def __repr__(self):
        return 'Column<%s>' % self._jc.toString().encode('utf8')


def _test():
    import doctest
    from pyspark.context import SparkContext
    from pyspark.sql import SQLContext
    import pyspark.sql.column
    globs = pyspark.sql.column.__dict__.copy()
    sc = SparkContext('local[4]', 'PythonTest')
    globs['sc'] = sc
    globs['sqlContext'] = SQLContext(sc)
    globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
        .toDF(StructType([StructField('age', IntegerType()),
                          StructField('name', StringType())]))

    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.column, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()
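For reference, a minimal usage sketch (not part of the commit) exercising the Column API above against a local Spark 1.x installation; the column names `age` and `name` are just illustrative:

```python
# Not part of the commit: a small driver script exercising Column expressions.
from pyspark import SparkContext
from pyspark.sql import SQLContext

sc = SparkContext('local[2]', 'column-demo')
sqlContext = SQLContext(sc)  # needed so RDD.toDF is available
df = sc.parallelize([(2, 'Alice'), (5, 'Bob')]).toDF(['age', 'name'])

# Operators build Column expressions lazily; nothing runs until an action.
df.select((df.age + 1).alias('age1'), (df.age > 3).alias('gt3')).show()

# substr, cast and between all return new Column objects.
df.select(df.name.substr(1, 3).alias('abbr'),
          df.age.cast('string').alias('age_str'),
          df.age.between(2, 4).alias('in_range')).show()

sc.stop()
```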
python/pyspark/sql/dataframe.py
@@ -25,17 +25,15 @@ if sys.version >= '3':
 else:
     from itertools import imap as map

 from pyspark.context import SparkContext
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
 from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
 from pyspark.sql.types import *
 from pyspark.sql.types import _create_cls, _parse_datatype_json_string
+from pyspark.sql.column import Column, _to_seq, _to_java_column


-__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions",
-           "DataFrameStatFunctions"]
+__all__ = ["DataFrame", "SchemaRDD", "DataFrameNaFunctions", "DataFrameStatFunctions"]


 class DataFrame(object):
@@ -757,6 +755,7 @@ class DataFrame(object):
         [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)]
         """
         jdf = self._jdf.groupBy(self._jcols(*cols))
+        from pyspark.sql.group import GroupedData
         return GroupedData(jdf, self.sql_ctx)

     def agg(self, *exprs):
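Editor's note: the import of GroupedData is done inside groupBy rather than at the top of dataframe.py, presumably to avoid a circular import, since group.py imports DataFrame at module load time:

```python
# Sketch of the cycle the deferred import avoids (module paths as in this commit):
#
#   pyspark/sql/group.py       -- at module top level:
#       from pyspark.sql.dataframe import DataFrame
#
#   pyspark/sql/dataframe.py   -- inside DataFrame.groupBy(), at call time:
#       from pyspark.sql.group import GroupedData
#
# A top-level import of GroupedData in dataframe.py would make the two
# modules import each other during initialization.
```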
@@ -1141,169 +1140,6 @@ class SchemaRDD(DataFrame):
     """
-
-
-def dfapi(f):
-    def _api(self):
-        name = f.__name__
-        jdf = getattr(self._jdf, name)()
-        return DataFrame(jdf, self.sql_ctx)
-    _api.__name__ = f.__name__
-    _api.__doc__ = f.__doc__
-    return _api
-
-
-def df_varargs_api(f):
-    def _api(self, *args):
-        name = f.__name__
-        jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args))
-        return DataFrame(jdf, self.sql_ctx)
-    _api.__name__ = f.__name__
-    _api.__doc__ = f.__doc__
-    return _api
-
-
-class GroupedData(object):
-    """
-    A set of methods for aggregations on a :class:`DataFrame`,
-    created by :func:`DataFrame.groupBy`.
-    """
-
-    def __init__(self, jdf, sql_ctx):
-        self._jdf = jdf
-        self.sql_ctx = sql_ctx
-
-    @ignore_unicode_prefix
-    def agg(self, *exprs):
-        """Compute aggregates and returns the result as a :class:`DataFrame`.
-
-        The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`.
-
-        If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
-        is the column to perform aggregation on, and the value is the aggregate function.
-
-        Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.
-
-        :param exprs: a dict mapping from column name (string) to aggregate functions (string),
-            or a list of :class:`Column`.
-
-        >>> gdf = df.groupBy(df.name)
-        >>> gdf.agg({"*": "count"}).collect()
-        [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)]
-
-        >>> from pyspark.sql import functions as F
-        >>> gdf.agg(F.min(df.age)).collect()
-        [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)]
-        """
-        assert exprs, "exprs should not be empty"
-        if len(exprs) == 1 and isinstance(exprs[0], dict):
-            jdf = self._jdf.agg(exprs[0])
-        else:
-            # Columns
-            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
-            jdf = self._jdf.agg(exprs[0]._jc,
-                                _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
-        return DataFrame(jdf, self.sql_ctx)
-
-    @dfapi
-    def count(self):
-        """Counts the number of records for each group.
-
-        >>> df.groupBy(df.age).count().collect()
-        [Row(age=2, count=1), Row(age=5, count=1)]
-        """
-
-    @df_varargs_api
-    def mean(self, *cols):
-        """Computes average values for each numeric columns for each group.
-
-        :func:`mean` is an alias for :func:`avg`.
-
-        :param cols: list of column names (string). Non-numeric columns are ignored.
-
-        >>> df.groupBy().mean('age').collect()
-        [Row(AVG(age)=3.5)]
-        >>> df3.groupBy().mean('age', 'height').collect()
-        [Row(AVG(age)=3.5, AVG(height)=82.5)]
-        """
-
-    @df_varargs_api
-    def avg(self, *cols):
-        """Computes average values for each numeric columns for each group.
-
-        :func:`mean` is an alias for :func:`avg`.
-
-        :param cols: list of column names (string). Non-numeric columns are ignored.
-
-        >>> df.groupBy().avg('age').collect()
-        [Row(AVG(age)=3.5)]
-        >>> df3.groupBy().avg('age', 'height').collect()
-        [Row(AVG(age)=3.5, AVG(height)=82.5)]
-        """
-
-    @df_varargs_api
-    def max(self, *cols):
-        """Computes the max value for each numeric columns for each group.
-
-        >>> df.groupBy().max('age').collect()
-        [Row(MAX(age)=5)]
-        >>> df3.groupBy().max('age', 'height').collect()
-        [Row(MAX(age)=5, MAX(height)=85)]
-        """
-
-    @df_varargs_api
-    def min(self, *cols):
-        """Computes the min value for each numeric column for each group.
-
-        :param cols: list of column names (string). Non-numeric columns are ignored.
-
-        >>> df.groupBy().min('age').collect()
-        [Row(MIN(age)=2)]
-        >>> df3.groupBy().min('age', 'height').collect()
-        [Row(MIN(age)=2, MIN(height)=80)]
-        """
-
-    @df_varargs_api
-    def sum(self, *cols):
-        """Compute the sum for each numeric columns for each group.
-
-        :param cols: list of column names (string). Non-numeric columns are ignored.
-
-        >>> df.groupBy().sum('age').collect()
-        [Row(SUM(age)=7)]
-        >>> df3.groupBy().sum('age', 'height').collect()
-        [Row(SUM(age)=7, SUM(height)=165)]
-        """
-
-
-def _create_column_from_literal(literal):
-    sc = SparkContext._active_spark_context
-    return sc._jvm.functions.lit(literal)
-
-
-def _create_column_from_name(name):
-    sc = SparkContext._active_spark_context
-    return sc._jvm.functions.col(name)
-
-
-def _to_java_column(col):
-    if isinstance(col, Column):
-        jcol = col._jc
-    else:
-        jcol = _create_column_from_name(col)
-    return jcol
-
-
-def _to_seq(sc, cols, converter=None):
-    """
-    Convert a list of Column (or names) into a JVM Seq of Column.
-
-    An optional `converter` could be used to convert items in `cols`
-    into JVM Column objects.
-    """
-    if converter:
-        cols = [converter(c) for c in cols]
-    return sc._jvm.PythonUtils.toSeq(cols)


 def _to_scala_map(sc, jm):
     """
     Convert a dict into a JVM Map.
@@ -1311,282 +1147,6 @@ def _to_scala_map(sc, jm):
     return sc._jvm.PythonUtils.toScalaMap(jm)
-
-
-def _unary_op(name, doc="unary operator"):
-    """ Create a method for given unary operator """
-    def _(self):
-        jc = getattr(self._jc, name)()
-        return Column(jc)
-    _.__doc__ = doc
-    return _
-
-
-def _func_op(name, doc=''):
-    def _(self):
-        sc = SparkContext._active_spark_context
-        jc = getattr(sc._jvm.functions, name)(self._jc)
-        return Column(jc)
-    _.__doc__ = doc
-    return _
-
-
-def _bin_op(name, doc="binary operator"):
-    """ Create a method for given binary operator
-    """
-    def _(self, other):
-        jc = other._jc if isinstance(other, Column) else other
-        njc = getattr(self._jc, name)(jc)
-        return Column(njc)
-    _.__doc__ = doc
-    return _
-
-
-def _reverse_op(name, doc="binary operator"):
-    """ Create a method for binary operator (this object is on right side)
-    """
-    def _(self, other):
-        jother = _create_column_from_literal(other)
-        jc = getattr(jother, name)(self._jc)
-        return Column(jc)
-    _.__doc__ = doc
-    return _
-
-
-class Column(object):
-
-    """
-    A column in a DataFrame.
-
-    :class:`Column` instances can be created by::
-
-        # 1. Select a column out of a DataFrame
-
-        df.colName
-        df["colName"]
-
-        # 2. Create from an expression
-        df.colName + 1
-        1 / df.colName
-    """
-
-    def __init__(self, jc):
-        self._jc = jc
-
-    # arithmetic operators
-    __neg__ = _func_op("negate")
-    __add__ = _bin_op("plus")
-    __sub__ = _bin_op("minus")
-    __mul__ = _bin_op("multiply")
-    __div__ = _bin_op("divide")
-    __truediv__ = _bin_op("divide")
-    __mod__ = _bin_op("mod")
-    __radd__ = _bin_op("plus")
-    __rsub__ = _reverse_op("minus")
-    __rmul__ = _bin_op("multiply")
-    __rdiv__ = _reverse_op("divide")
-    __rtruediv__ = _reverse_op("divide")
-    __rmod__ = _reverse_op("mod")
-
-    # logistic operators
-    __eq__ = _bin_op("equalTo")
-    __ne__ = _bin_op("notEqual")
-    __lt__ = _bin_op("lt")
-    __le__ = _bin_op("leq")
-    __ge__ = _bin_op("geq")
-    __gt__ = _bin_op("gt")
-
-    # `and`, `or`, `not` cannot be overloaded in Python,
-    # so use bitwise operators as boolean operators
-    __and__ = _bin_op('and')
-    __or__ = _bin_op('or')
-    __invert__ = _func_op('not')
-    __rand__ = _bin_op("and")
-    __ror__ = _bin_op("or")
-
-    # container operators
-    __contains__ = _bin_op("contains")
-    __getitem__ = _bin_op("apply")
-
-    # bitwise operators
-    bitwiseOR = _bin_op("bitwiseOR")
-    bitwiseAND = _bin_op("bitwiseAND")
-    bitwiseXOR = _bin_op("bitwiseXOR")
-
-    def getItem(self, key):
-        """An expression that gets an item at position `ordinal` out of a list,
-        or gets an item by key out of a dict.
-
-        >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"])
-        >>> df.select(df.l.getItem(0), df.d.getItem("key")).show()
-        +----+------+
-        |l[0]|d[key]|
-        +----+------+
-        |   1| value|
-        +----+------+
-        >>> df.select(df.l[0], df.d["key"]).show()
-        +----+------+
-        |l[0]|d[key]|
-        +----+------+
-        |   1| value|
-        +----+------+
-        """
-        return self[key]
-
-    def getField(self, name):
-        """An expression that gets a field by name in a StructField.
-
-        >>> from pyspark.sql import Row
-        >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
-        >>> df.select(df.r.getField("b")).show()
-        +----+
-        |r[b]|
-        +----+
-        |   b|
-        +----+
-        >>> df.select(df.r.a).show()
-        +----+
-        |r[a]|
-        +----+
-        |   1|
-        +----+
-        """
-        return self[name]
-
-    def __getattr__(self, item):
-        if item.startswith("__"):
-            raise AttributeError(item)
-        return self.getField(item)
-
-    # string methods
-    rlike = _bin_op("rlike")
-    like = _bin_op("like")
-    startswith = _bin_op("startsWith")
-    endswith = _bin_op("endsWith")
-
-    @ignore_unicode_prefix
-    def substr(self, startPos, length):
-        """
-        Return a :class:`Column` which is a substring of the column
-
-        :param startPos: start position (int or Column)
-        :param length: length of the substring (int or Column)
-
-        >>> df.select(df.name.substr(1, 3).alias("col")).collect()
-        [Row(col=u'Ali'), Row(col=u'Bob')]
-        """
-        if type(startPos) != type(length):
-            raise TypeError("Can not mix the type")
-        if isinstance(startPos, (int, long)):
-            jc = self._jc.substr(startPos, length)
-        elif isinstance(startPos, Column):
-            jc = self._jc.substr(startPos._jc, length._jc)
-        else:
-            raise TypeError("Unexpected type: %s" % type(startPos))
-        return Column(jc)
-
-    __getslice__ = substr
-
-    @ignore_unicode_prefix
-    def inSet(self, *cols):
-        """ A boolean expression that is evaluated to true if the value of this
-        expression is contained by the evaluated values of the arguments.
-
-        >>> df[df.name.inSet("Bob", "Mike")].collect()
-        [Row(age=5, name=u'Bob')]
-        >>> df[df.age.inSet([1, 2, 3])].collect()
-        [Row(age=2, name=u'Alice')]
-        """
-        if len(cols) == 1 and isinstance(cols[0], (list, set)):
-            cols = cols[0]
-        cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
-        sc = SparkContext._active_spark_context
-        jc = getattr(self._jc, "in")(_to_seq(sc, cols))
-        return Column(jc)
-
-    # order
-    asc = _unary_op("asc", "Returns a sort expression based on the"
-                           " ascending order of the given column name.")
-    desc = _unary_op("desc", "Returns a sort expression based on the"
-                             " descending order of the given column name.")
-
-    isNull = _unary_op("isNull", "True if the current expression is null.")
-    isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
-
-    def alias(self, *alias):
-        """Returns this column aliased with a new name or names (in the case of expressions that
-        return more than one column, such as explode).
-
-        >>> df.select(df.age.alias("age2")).collect()
-        [Row(age2=2), Row(age2=5)]
-        """
-
-        if len(alias) == 1:
-            return Column(getattr(self._jc, "as")(alias[0]))
-        else:
-            sc = SparkContext._active_spark_context
-            return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias))))
-
-    @ignore_unicode_prefix
-    def cast(self, dataType):
-        """ Convert the column into type `dataType`
-
-        >>> df.select(df.age.cast("string").alias('ages')).collect()
-        [Row(ages=u'2'), Row(ages=u'5')]
-        >>> df.select(df.age.cast(StringType()).alias('ages')).collect()
-        [Row(ages=u'2'), Row(ages=u'5')]
-        """
-        if isinstance(dataType, basestring):
-            jc = self._jc.cast(dataType)
-        elif isinstance(dataType, DataType):
-            sc = SparkContext._active_spark_context
-            ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
-            jdt = ssql_ctx.parseDataType(dataType.json())
-            jc = self._jc.cast(jdt)
-        else:
-            raise TypeError("unexpected type: %s" % type(dataType))
-        return Column(jc)
-
-    @ignore_unicode_prefix
-    def between(self, lowerBound, upperBound):
-        """ A boolean expression that is evaluated to true if the value of this
-        expression is between the given columns.
-        """
-        return (self >= lowerBound) & (self <= upperBound)
-
-    @ignore_unicode_prefix
-    def when(self, condition, value):
-        """Evaluates a list of conditions and returns one of multiple possible result expressions.
-        If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
-
-        See :func:`pyspark.sql.functions.when` for example usage.
-
-        :param condition: a boolean :class:`Column` expression.
-        :param value: a literal value, or a :class:`Column` expression.
-
-        """
-        sc = SparkContext._active_spark_context
-        if not isinstance(condition, Column):
-            raise TypeError("condition should be a Column")
-        v = value._jc if isinstance(value, Column) else value
-        jc = sc._jvm.functions.when(condition._jc, v)
-        return Column(jc)
-
-    @ignore_unicode_prefix
-    def otherwise(self, value):
-        """Evaluates a list of conditions and returns one of multiple possible result expressions.
-        If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
-
-        See :func:`pyspark.sql.functions.when` for example usage.
-
-        :param value: a literal value, or a :class:`Column` expression.
-        """
-        v = value._jc if isinstance(value, Column) else value
-        jc = self._jc.otherwise(value)
-        return Column(jc)
-
-    def __repr__(self):
-        return 'Column<%s>' % self._jc.toString().encode('utf8')


 class DataFrameNaFunctions(object):
     """Functionality for working with missing data in :class:`DataFrame`.
     """
@@ -1646,9 +1206,6 @@ def _test():
         .toDF(StructType([StructField('age', IntegerType()),
                           StructField('name', StringType())]))
     globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
-    globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
-                                   Row(name='Bob', age=5, height=85)]).toDF()
-
     globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80),
                                    Row(name='Bob', age=5, height=None),
                                    Row(name='Tom', age=None, height=None),
python/pyspark/sql/functions.py
@@ -27,7 +27,7 @@ from pyspark import SparkContext
 from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
 from pyspark.sql.types import StringType
-from pyspark.sql.dataframe import Column, _to_java_column, _to_seq
+from pyspark.sql.column import Column, _to_java_column, _to_seq


 __all__ = [
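The import fix matters because the helpers in pyspark.sql.functions route their arguments through _to_java_column, which now lives in column.py. A minimal sketch (not part of the commit), assuming a local Spark 1.x installation:

```python
# Not part of the commit: public usage of pyspark.sql.functions is unchanged.
from pyspark import SparkContext
from pyspark.sql import SQLContext, functions as F

sc = SparkContext('local[2]', 'functions-demo')
sqlContext = SQLContext(sc)
df = sc.parallelize([(2, 'Alice'), (5, 'Bob')]).toDF(['age', 'name'])

# F.col and F.lit build Column objects via the helpers moved to column.py.
df.select(F.col('name'), F.lit(1).alias('one')).show()
sc.stop()
```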
python/pyspark/sql/group.py (new file, 183 lines)
@@ -0,0 +1,183 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.rdd import ignore_unicode_prefix
from pyspark.sql.column import Column, _to_seq
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import *

__all__ = ["GroupedData"]


def dfapi(f):
    def _api(self):
        name = f.__name__
        jdf = getattr(self._jdf, name)()
        return DataFrame(jdf, self.sql_ctx)
    _api.__name__ = f.__name__
    _api.__doc__ = f.__doc__
    return _api


def df_varargs_api(f):
    def _api(self, *args):
        name = f.__name__
        jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args))
        return DataFrame(jdf, self.sql_ctx)
    _api.__name__ = f.__name__
    _api.__doc__ = f.__doc__
    return _api


class GroupedData(object):
    """
    A set of methods for aggregations on a :class:`DataFrame`,
    created by :func:`DataFrame.groupBy`.
    """

    def __init__(self, jdf, sql_ctx):
        self._jdf = jdf
        self.sql_ctx = sql_ctx

    @ignore_unicode_prefix
    def agg(self, *exprs):
        """Compute aggregates and returns the result as a :class:`DataFrame`.

        The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`.

        If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
        is the column to perform aggregation on, and the value is the aggregate function.

        Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.

        :param exprs: a dict mapping from column name (string) to aggregate functions (string),
            or a list of :class:`Column`.

        >>> gdf = df.groupBy(df.name)
        >>> gdf.agg({"*": "count"}).collect()
        [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)]

        >>> from pyspark.sql import functions as F
        >>> gdf.agg(F.min(df.age)).collect()
        [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)]
        """
        assert exprs, "exprs should not be empty"
        if len(exprs) == 1 and isinstance(exprs[0], dict):
            jdf = self._jdf.agg(exprs[0])
        else:
            # Columns
            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
            jdf = self._jdf.agg(exprs[0]._jc,
                                _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
        return DataFrame(jdf, self.sql_ctx)

    @dfapi
    def count(self):
        """Counts the number of records for each group.

        >>> df.groupBy(df.age).count().collect()
        [Row(age=2, count=1), Row(age=5, count=1)]
        """

    @df_varargs_api
    def mean(self, *cols):
        """Computes average values for each numeric columns for each group.

        :func:`mean` is an alias for :func:`avg`.

        :param cols: list of column names (string). Non-numeric columns are ignored.

        >>> df.groupBy().mean('age').collect()
        [Row(AVG(age)=3.5)]
        >>> df3.groupBy().mean('age', 'height').collect()
        [Row(AVG(age)=3.5, AVG(height)=82.5)]
        """

    @df_varargs_api
    def avg(self, *cols):
        """Computes average values for each numeric columns for each group.

        :func:`mean` is an alias for :func:`avg`.

        :param cols: list of column names (string). Non-numeric columns are ignored.

        >>> df.groupBy().avg('age').collect()
        [Row(AVG(age)=3.5)]
        >>> df3.groupBy().avg('age', 'height').collect()
        [Row(AVG(age)=3.5, AVG(height)=82.5)]
        """

    @df_varargs_api
    def max(self, *cols):
        """Computes the max value for each numeric columns for each group.

        >>> df.groupBy().max('age').collect()
        [Row(MAX(age)=5)]
        >>> df3.groupBy().max('age', 'height').collect()
        [Row(MAX(age)=5, MAX(height)=85)]
        """

    @df_varargs_api
    def min(self, *cols):
        """Computes the min value for each numeric column for each group.

        :param cols: list of column names (string). Non-numeric columns are ignored.

        >>> df.groupBy().min('age').collect()
        [Row(MIN(age)=2)]
        >>> df3.groupBy().min('age', 'height').collect()
        [Row(MIN(age)=2, MIN(height)=80)]
        """

    @df_varargs_api
    def sum(self, *cols):
        """Compute the sum for each numeric columns for each group.

        :param cols: list of column names (string). Non-numeric columns are ignored.

        >>> df.groupBy().sum('age').collect()
        [Row(SUM(age)=7)]
        >>> df3.groupBy().sum('age', 'height').collect()
        [Row(SUM(age)=7, SUM(height)=165)]
        """


def _test():
    import doctest
    from pyspark.context import SparkContext
    from pyspark.sql import Row, SQLContext
    import pyspark.sql.group
    globs = pyspark.sql.group.__dict__.copy()
    sc = SparkContext('local[4]', 'PythonTest')
    globs['sc'] = sc
    globs['sqlContext'] = SQLContext(sc)
    globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
        .toDF(StructType([StructField('age', IntegerType()),
                          StructField('name', StringType())]))
    globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
                                   Row(name='Bob', age=5, height=85)]).toDF()

    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.group, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()
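A short usage sketch for the new module (not part of the commit), showing the three calling styles that GroupedData.agg and the df_varargs_api helpers support:

```python
# Not part of the commit: exercising GroupedData on a local Spark 1.x installation.
from pyspark import SparkContext
from pyspark.sql import SQLContext, functions as F

sc = SparkContext('local[2]', 'group-demo')
sqlContext = SQLContext(sc)
df = sc.parallelize([('Alice', 2), ('Bob', 5), ('Bob', 7)]).toDF(['name', 'age'])

# 1. dict form: column name -> aggregate function name
df.groupBy(df.name).agg({'age': 'max'}).show()

# 2. Column-expression form, via pyspark.sql.functions
df.groupBy('name').agg(F.min(df.age), F.sum(df.age)).show()

# 3. varargs helpers generated by df_varargs_api (mean/avg/max/min/sum)
df.groupBy().sum('age').show()

sc.stop()
```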
python/run-tests
@@ -72,7 +72,9 @@ function run_sql_tests() {
     echo "Run sql tests ..."
     run_test "pyspark/sql/_types.py"
     run_test "pyspark/sql/context.py"
+    run_test "pyspark/sql/column.py"
     run_test "pyspark/sql/dataframe.py"
+    run_test "pyspark/sql/group.py"
     run_test "pyspark/sql/functions.py"
     run_test "pyspark/sql/tests.py"
 }