2015-05-15 23:09:15 -04:00
|
|
|
#
|
|
|
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
|
|
# contributor license agreements. See the NOTICE file distributed with
|
|
|
|
# this work for additional information regarding copyright ownership.
|
|
|
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
|
|
# (the "License"); you may not use this file except in compliance with
|
|
|
|
# the License. You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
#
|
|
|
|
|
2018-03-08 06:38:34 -05:00
|
|
|
import sys
|
|
|
|
|
2015-09-08 23:56:22 -04:00
|
|
|
from pyspark import since
|
2017-11-17 10:43:08 -05:00
|
|
|
from pyspark.rdd import ignore_unicode_prefix, PythonEvalType
|
2018-03-25 23:42:32 -04:00
|
|
|
from pyspark.sql.column import Column, _to_seq
|
2015-05-15 23:09:15 -04:00
|
|
|
from pyspark.sql.dataframe import DataFrame
|
|
|
|
from pyspark.sql.types import *
|
|
|
|
|
|
|
|
# Explicit public API of this module: only GroupedData is exported via `import *`.
__all__ = ["GroupedData"]
|
|
|
|
|
|
|
|
|
|
|
|
def dfapi(f):
    """Decorator for no-argument GroupedData methods.

    Replaces the decorated stub with a wrapper that invokes the
    same-named method on the JVM-side grouped data object and wraps
    the resulting Java DataFrame in a Python :class:`DataFrame`.
    """
    def _api(self):
        # Dispatch to the JVM method that shares the stub's name.
        jresult = getattr(self._jgd, f.__name__)()
        return DataFrame(jresult, self.sql_ctx)
    # Carry the stub's metadata over so help() and Sphinx docs still work.
    _api.__name__, _api.__doc__ = f.__name__, f.__doc__
    return _api
|
|
|
|
|
|
|
|
|
|
|
|
def df_varargs_api(f):
    """Decorator for varargs GroupedData methods.

    Replaces the decorated stub with a wrapper that forwards the given
    column names, packed into a Scala Seq, to the same-named method on
    the JVM-side grouped data object.
    """
    def _api(self, *cols):
        # Resolve the JVM method matching the stub's name, then call it
        # with the columns converted to a JVM sequence.
        jvm_method = getattr(self._jgd, f.__name__)
        jresult = jvm_method(_to_seq(self.sql_ctx._sc, cols))
        return DataFrame(jresult, self.sql_ctx)
    # Preserve the stub's metadata for introspection and docs.
    _api.__name__ = f.__name__
    _api.__doc__ = f.__doc__
    return _api
|
|
|
|
|
|
|
|
|
|
|
|
class GroupedData(object):
    """
    A set of methods for aggregations on a :class:`DataFrame`,
    created by :func:`DataFrame.groupBy`.

    .. versionadded:: 1.3
    """

    def __init__(self, jgd, df):
        # JVM-side RelationalGroupedDataset that performs the real work.
        self._jgd = jgd
        # The Python DataFrame this grouping was derived from.
        self._df = df
        self.sql_ctx = df.sql_ctx

    @ignore_unicode_prefix
    @since(1.3)
    def agg(self, *exprs):
        """Compute aggregates and returns the result as a :class:`DataFrame`.

        The available aggregate functions can be:

        1. built-in aggregation functions, such as `avg`, `max`, `min`, `sum`, `count`

        2. group aggregate pandas UDFs, created with :func:`pyspark.sql.functions.pandas_udf`

           .. note:: There is no partial aggregation with group aggregate UDFs, i.e.,
               a full shuffle is required. Also, all the data of a group will be loaded into
               memory, so the user should be aware of the potential OOM risk if data is skewed
               and certain groups are too large to fit in memory.

           .. seealso:: :func:`pyspark.sql.functions.pandas_udf`

        If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
        is the column to perform aggregation on, and the value is the aggregate function.

        Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.

        .. note:: Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed
            in a single call to this function.

        :param exprs: a dict mapping from column name (string) to aggregate functions (string),
            or a list of :class:`Column`.

        >>> gdf = df.groupBy(df.name)
        >>> sorted(gdf.agg({"*": "count"}).collect())
        [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)]

        >>> from pyspark.sql import functions as F
        >>> sorted(gdf.agg(F.min(df.age)).collect())
        [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)]

        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
        >>> @pandas_udf('int', PandasUDFType.GROUPED_AGG)  # doctest: +SKIP
        ... def min_udf(v):
        ...     return v.min()
        >>> sorted(gdf.agg(min_udf(df.age)).collect())  # doctest: +SKIP
        [Row(name=u'Alice', min_udf(age)=2), Row(name=u'Bob', min_udf(age)=5)]
        """
        assert exprs, "exprs should not be empty"
        if len(exprs) == 1 and isinstance(exprs[0], dict):
            # Single-dict form: {column name: aggregate function name}.
            jdf = self._jgd.agg(exprs[0])
        else:
            # Column-expression form: every argument must be a Column.
            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
            head, tail = exprs[0], exprs[1:]
            jdf = self._jgd.agg(head._jc,
                                _to_seq(self.sql_ctx._sc, [c._jc for c in tail]))
        return DataFrame(jdf, self.sql_ctx)

    @dfapi
    @since(1.3)
    def count(self):
        """Counts the number of records for each group.

        >>> sorted(df.groupBy(df.age).count().collect())
        [Row(age=2, count=1), Row(age=5, count=1)]
        """

    @df_varargs_api
    @since(1.3)
    def mean(self, *cols):
        """Computes average values for each numeric columns for each group.

        :func:`mean` is an alias for :func:`avg`.

        :param cols: list of column names (string). Non-numeric columns are ignored.

        >>> df.groupBy().mean('age').collect()
        [Row(avg(age)=3.5)]
        >>> df3.groupBy().mean('age', 'height').collect()
        [Row(avg(age)=3.5, avg(height)=82.5)]
        """

    @df_varargs_api
    @since(1.3)
    def avg(self, *cols):
        """Computes average values for each numeric columns for each group.

        :func:`mean` is an alias for :func:`avg`.

        :param cols: list of column names (string). Non-numeric columns are ignored.

        >>> df.groupBy().avg('age').collect()
        [Row(avg(age)=3.5)]
        >>> df3.groupBy().avg('age', 'height').collect()
        [Row(avg(age)=3.5, avg(height)=82.5)]
        """

    @df_varargs_api
    @since(1.3)
    def max(self, *cols):
        """Computes the max value for each numeric columns for each group.

        >>> df.groupBy().max('age').collect()
        [Row(max(age)=5)]
        >>> df3.groupBy().max('age', 'height').collect()
        [Row(max(age)=5, max(height)=85)]
        """

    @df_varargs_api
    @since(1.3)
    def min(self, *cols):
        """Computes the min value for each numeric column for each group.

        :param cols: list of column names (string). Non-numeric columns are ignored.

        >>> df.groupBy().min('age').collect()
        [Row(min(age)=2)]
        >>> df3.groupBy().min('age', 'height').collect()
        [Row(min(age)=2, min(height)=80)]
        """

    @df_varargs_api
    @since(1.3)
    def sum(self, *cols):
        """Compute the sum for each numeric columns for each group.

        :param cols: list of column names (string). Non-numeric columns are ignored.

        >>> df.groupBy().sum('age').collect()
        [Row(sum(age)=7)]
        >>> df3.groupBy().sum('age', 'height').collect()
        [Row(sum(age)=7, sum(height)=165)]
        """

    @since(1.6)
    def pivot(self, pivot_col, values=None):
        """
        Pivots a column of the current :class:`DataFrame` and perform the specified aggregation.
        There are two versions of pivot function: one that requires the caller to specify the list
        of distinct values to pivot on, and one that does not. The latter is more concise but less
        efficient, because Spark needs to first compute the list of distinct values internally.

        :param pivot_col: Name of the column to pivot.
        :param values: List of values that will be translated to columns in the output DataFrame.

        # Compute the sum of earnings for each year by course with each course as a separate column

        >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect()
        [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)]

        # Or without specifying column values (less efficient)

        >>> df4.groupBy("year").pivot("course").sum("earnings").collect()
        [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
        >>> df5.groupBy("sales.year").pivot("sales.course").sum("sales.earnings").collect()
        [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
        """
        # Without an explicit value list, Spark computes the distinct
        # pivot values itself (extra job, less efficient).
        if values is None:
            jpivoted = self._jgd.pivot(pivot_col)
        else:
            jpivoted = self._jgd.pivot(pivot_col, values)
        return GroupedData(jpivoted, self._df)

    @since(2.3)
    def apply(self, udf):
        """
        Maps each group of the current :class:`DataFrame` using a pandas udf and returns the result
        as a `DataFrame`.

        The user-defined function should take a `pandas.DataFrame` and return another
        `pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame`
        to the user-function and the returned `pandas.DataFrame` are combined as a
        :class:`DataFrame`.

        The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
        returnType of the pandas udf.

        .. note:: This function requires a full shuffle. all the data of a group will be loaded
            into memory, so the user should be aware of the potential OOM risk if data is skewed
            and certain groups are too large to fit in memory.

        :param udf: a grouped map user-defined function returned by
            :func:`pyspark.sql.functions.pandas_udf`.

        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
        >>> df = spark.createDataFrame(
        ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
        ...     ("id", "v"))
        >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)  # doctest: +SKIP
        ... def normalize(pdf):
        ...     v = pdf.v
        ...     return pdf.assign(v=(v - v.mean()) / v.std())
        >>> df.groupby("id").apply(normalize).show()  # doctest: +SKIP
        +---+-------------------+
        | id|                  v|
        +---+-------------------+
        |  1|-0.7071067811865475|
        |  1| 0.7071067811865475|
        |  2|-0.8320502943378437|
        |  2|-0.2773500981126146|
        |  2| 1.1094003924504583|
        +---+-------------------+

        .. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
        """
        # Columns are special because hasattr always return True
        is_grouped_map = (not isinstance(udf, Column)
                          and hasattr(udf, 'func')
                          and udf.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
        if not is_grouped_map:
            raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type "
                             "GROUPED_MAP.")
        df = self._df
        # Bind every column of the grouped DataFrame as an argument to the UDF.
        applied = udf(*[df[col] for col in df.columns])
        jdf = self._jgd.flatMapGroupsInPandas(applied._jc.expr())
        return DataFrame(jdf, self.sql_ctx)
|
2015-11-13 13:31:17 -05:00
|
|
|
|
2015-05-15 23:09:15 -04:00
|
|
|
|
|
|
|
def _test():
    """Run this module's doctests against a local SparkSession.

    Builds the fixture DataFrames referenced by the doctests (df, df3,
    df4, df5), executes them with doctest, stops the session, and exits
    non-zero on any failure.
    """
    import doctest
    from pyspark.sql import Row, SparkSession
    import pyspark.sql.group

    test_globals = pyspark.sql.group.__dict__.copy()
    spark = (SparkSession.builder
             .master("local[4]")
             .appName("sql.group tests")
             .getOrCreate())
    sc = spark.sparkContext
    test_globals['sc'] = sc
    test_globals['spark'] = spark
    # Fixtures used throughout the doctests in this module.
    test_globals['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]).toDF(
        StructType([StructField('age', IntegerType()),
                    StructField('name', StringType())]))
    test_globals['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
                                          Row(name='Bob', age=5, height=85)]).toDF()
    test_globals['df4'] = sc.parallelize([Row(course="dotNET", year=2012, earnings=10000),
                                          Row(course="Java", year=2012, earnings=20000),
                                          Row(course="dotNET", year=2012, earnings=5000),
                                          Row(course="dotNET", year=2013, earnings=48000),
                                          Row(course="Java", year=2013, earnings=30000)]).toDF()
    # df5 nests rows so doctests can pivot on dotted (struct) column names.
    test_globals['df5'] = sc.parallelize([
        Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=10000)),
        Row(training="junior", sales=Row(course="Java", year=2012, earnings=20000)),
        Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=5000)),
        Row(training="junior", sales=Row(course="dotNET", year=2013, earnings=48000)),
        Row(training="expert", sales=Row(course="Java", year=2013, earnings=30000))]).toDF()

    failure_count, test_count = doctest.testmod(
        pyspark.sql.group, globs=test_globals,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
    spark.stop()
    if failure_count:
        sys.exit(-1)
|
2015-05-15 23:09:15 -04:00
|
|
|
|
|
|
|
|
|
|
|
# Allow running this module directly to execute its doctest suite.
if __name__ == "__main__":
    _test()
|