Fokko Driesprong 9fcf0ea718 [SPARK-32319][PYSPARK] Disallow the use of unused imports
Disallow the use of unused imports:

- Unused imports unnecessarily increase the memory footprint of the application
- Moves the imports that are only required for the examples in the docstrings from file scope into the examples themselves. This keeps the files themselves clean, and gives a more complete example as it also includes the imports :)

fokkodriesprongFan spark % flake8 python | grep -i "imported but unused"
python/pyspark/ F401 'functools.partial' imported but unused
python/pyspark/ F401 'traceback' imported but unused
python/pyspark/ F401 '_heapq.*' imported but unused
python/pyspark/ F401 'pyspark.version.__version__' imported but unused
python/pyspark/ F401 'pyspark._globals._NoValue' imported but unused
python/pyspark/ F401 'pyspark.sql.SQLContext' imported but unused
python/pyspark/ F401 'pyspark.sql.HiveContext' imported but unused
python/pyspark/ F401 'pyspark.sql.Row' imported but unused
python/pyspark/ F401 're' imported but unused
python/pyspark/ F401 'tempfile.NamedTemporaryFile' imported but unused
python/pyspark/mllib/ F401 'pyspark.mllib.linalg.SparseVector' imported but unused
python/pyspark/mllib/ F401 'pyspark.mllib.linalg.SparseVector' imported but unused
python/pyspark/mllib/ F401 'pyspark.mllib.linalg.DenseVector' imported but unused
python/pyspark/mllib/ F401 'pyspark.mllib.linalg.SparseVector' imported but unused
python/pyspark/mllib/ F401 'pyspark.mllib.linalg.DenseVector' imported but unused
python/pyspark/mllib/ F401 'pyspark.mllib.linalg.SparseVector' imported but unused
python/pyspark/mllib/ F401 'pyspark.mllib.regression.LabeledPoint' imported but unused
python/pyspark/mllib/tests/ F401 'sys' imported but unused
python/pyspark/mllib/tests/ F401 'pyspark.mllib.tests.test_linalg.*' imported but unused
python/pyspark/mllib/tests/ F401 'numpy.random' imported but unused
python/pyspark/mllib/tests/ F401 'numpy.exp' imported but unused
python/pyspark/mllib/tests/ F401 'pyspark.mllib.linalg.Vector' imported but unused
python/pyspark/mllib/tests/ F401 'pyspark.mllib.linalg.VectorUDT' imported but unused
python/pyspark/mllib/tests/ F401 'pyspark.mllib.tests.test_feature.*' imported but unused
python/pyspark/mllib/tests/ F401 'pyspark.mllib.tests.test_util.*' imported but unused
python/pyspark/mllib/tests/ F401 'pyspark.mllib.linalg.Vector' imported but unused
python/pyspark/mllib/tests/ F401 'pyspark.mllib.linalg.SparseVector' imported but unused
python/pyspark/mllib/tests/ F401 'pyspark.mllib.linalg.DenseVector' imported but unused
python/pyspark/mllib/tests/ F401 'pyspark.mllib.linalg.VectorUDT' imported but unused
python/pyspark/mllib/tests/ F401 'pyspark.mllib.linalg._convert_to_vector' imported but unused
python/pyspark/mllib/tests/ F401 'pyspark.mllib.linalg.DenseMatrix' imported but unused
python/pyspark/mllib/tests/ F401 'pyspark.mllib.linalg.SparseMatrix' imported but unused
python/pyspark/mllib/tests/ F401 'pyspark.mllib.linalg.MatrixUDT' imported but unused
python/pyspark/mllib/tests/ F401 'pyspark.mllib.tests.test_stat.*' imported but unused
python/pyspark/mllib/tests/ F401 'time.time' imported but unused
python/pyspark/mllib/tests/ F401 'time.sleep' imported but unused
python/pyspark/mllib/tests/ F401 'pyspark.mllib.tests.test_streaming_algorithms.*' imported but unused
python/pyspark/mllib/tests/ F401 'pyspark.mllib.tests.test_algorithms.*' imported but unused
python/pyspark/tests/ F401 'xmlrunner' imported but unused
python/pyspark/tests/ F401 'sys' imported but unused
python/pyspark/tests/ F401 'pyspark.resource.ResourceProfile' imported but unused
python/pyspark/tests/ F401 'pyspark.tests.test_rdd.*' imported but unused
python/pyspark/tests/ F401 'sys' imported but unused
python/pyspark/tests/ F401 'array.array' imported but unused
python/pyspark/tests/ F401 'pyspark.tests.test_readwrite.*' imported but unused
python/pyspark/tests/ F401 'pyspark.tests.test_join.*' imported but unused
python/pyspark/tests/ F401 'shutil' imported but unused
python/pyspark/tests/ F401 'pyspark.tests.test_taskcontext.*' imported but unused
python/pyspark/tests/ F401 'pyspark.tests.test_conf.*' imported but unused
python/pyspark/tests/ F401 'pyspark.tests.test_broadcast.*' imported but unused
python/pyspark/tests/ F401 'pyspark.tests.test_daemon.*' imported but unused
python/pyspark/tests/ F401 'pyspark.tests.test_util.*' imported but unused
python/pyspark/tests/ F401 'random' imported but unused
python/pyspark/tests/ F401 'pyspark.tests.test_pin_thread.*' imported but unused
python/pyspark/tests/ F401 'sys' imported but unused
python/pyspark/tests/ F401 'resource' imported but unused
python/pyspark/tests/ F401 'pyspark.tests.test_worker.*' imported but unused
python/pyspark/tests/ F401 'pyspark.tests.test_profiler.*' imported but unused
python/pyspark/tests/ F401 'sys' imported but unused
python/pyspark/tests/ F401 'pyspark.tests.test_shuffle.*' imported but unused
python/pyspark/tests/ F401 'pyspark.tests.test_rddbarrier.*' imported but unused
python/pyspark/tests/ F401 'userlibrary.UserClass' imported but unused
python/pyspark/tests/ F401 'userlib.UserClass' imported but unused
python/pyspark/tests/ F401 'pyspark.tests.test_context.*' imported but unused
python/pyspark/tests/ F401 'pyspark.tests.test_appsubmit.*' imported but unused
python/pyspark/streaming/ F401 'sys' imported but unused
python/pyspark/streaming/tests/ F401 'pyspark.RDD' imported but unused
python/pyspark/streaming/tests/ F401 'pyspark.streaming.tests.test_dstream.*' imported but unused
python/pyspark/streaming/tests/ F401 'pyspark.streaming.tests.test_kinesis.*' imported but unused
python/pyspark/streaming/tests/ F401 'pyspark.streaming.tests.test_listener.*' imported but unused
python/pyspark/streaming/tests/ F401 'pyspark.streaming.tests.test_context.*' imported but unused
python/pyspark/testing/ F401 'scipy.sparse' imported but unused
python/pyspark/testing/ F401 'numpy as np' imported but unused
python/pyspark/ml/ F401 '' imported but unused
python/pyspark/ml/ F401 '' imported but unused
python/pyspark/ml/ F401 '' imported but unused
python/pyspark/ml/ F401 'sys' imported but unused
python/pyspark/ml/ F401 '' imported but unused
python/pyspark/ml/ F401 'sys' imported but unused
python/pyspark/ml/ F401 '' imported but unused
python/pyspark/ml/ F401 '' imported but unused
python/pyspark/ml/tests/ F401 'sys' imported but unused
python/pyspark/ml/tests/ F401 '*' imported but unused
python/pyspark/ml/tests/ F401 '*' imported but unused
python/pyspark/ml/tests/ F401 'pyspark.sql.functions as F' imported but unused
python/pyspark/ml/tests/ F401 '*' imported but unused
python/pyspark/ml/tests/ F401 '*' imported but unused
python/pyspark/ml/tests/ F401 'sys' imported but unused
python/pyspark/ml/tests/ F401 '*' imported but unused
python/pyspark/ml/tests/ F401 'py4j' imported but unused
python/pyspark/ml/tests/ F401 'pyspark.testing.mlutils.PySparkTestCase' imported but unused
python/pyspark/ml/tests/ F401 '*' imported but unused
python/pyspark/ml/tests/ F401 '*' imported but unused
python/pyspark/ml/tests/ F401 '*' imported but unused
python/pyspark/ml/tests/ F401 '*' imported but unused
python/pyspark/ml/tests/ F401 '*' imported but unused
python/pyspark/ml/tests/ F401 'sys' imported but unused
python/pyspark/ml/tests/ F401 '*' imported but unused
python/pyspark/ml/tests/ F401 '*' imported but unused
python/pyspark/ml/tests/ F401 '*' imported but unused
python/pyspark/ml/param/ F401 'sys' imported but unused
python/pyspark/resource/tests/ F401 'random' imported but unused
python/pyspark/resource/tests/ F401 'pyspark.resource.ResourceProfile' imported but unused
python/pyspark/resource/tests/ F401 'pyspark.resource.tests.test_resources.*' imported but unused
python/pyspark/sql/ F401 'pyspark.sql.udf.UserDefinedFunction' imported but unused
python/pyspark/sql/ F401 'pyspark.sql.pandas.functions.pandas_udf' imported but unused
python/pyspark/sql/ F401 'pyspark.sql.types.Row' imported but unused
python/pyspark/sql/ F401 'pyspark.sql.types.StringType' imported but unused
python/pyspark/sql/ F401 'pyspark.sql.Row' imported but unused
python/pyspark/sql/ F401 'pyspark.sql.types.IntegerType' imported but unused
python/pyspark/sql/ F401 'pyspark.sql.types.Row' imported but unused
python/pyspark/sql/ F401 'pyspark.sql.types.StringType' imported but unused
python/pyspark/sql/ F401 'pyspark.sql.udf.UDFRegistration' imported but unused
python/pyspark/sql/ F401 'pyspark.sql.Row' imported but unused
python/pyspark/sql/tests/ F401 'pyspark.sql.tests.test_utils.*' imported but unused
python/pyspark/sql/tests/ F401 'sys' imported but unused
python/pyspark/sql/tests/ F401 'pyspark.sql.functions.pandas_udf' imported but unused
python/pyspark/sql/tests/ F401 'pyspark.sql.functions.PandasUDFType' imported but unused
python/pyspark/sql/tests/ F401 'pyspark.sql.tests.test_pandas_map.*' imported but unused
python/pyspark/sql/tests/ F401 'pyspark.sql.tests.test_catalog.*' imported but unused
python/pyspark/sql/tests/ F401 'pyspark.sql.tests.test_group.*' imported but unused
python/pyspark/sql/tests/ F401 'pyspark.sql.tests.test_session.*' imported but unused
python/pyspark/sql/tests/ F401 'pyspark.sql.tests.test_conf.*' imported but unused
python/pyspark/sql/tests/ F401 'sys' imported but unused
python/pyspark/sql/tests/ F401 'pyspark.sql.functions.sum' imported but unused
python/pyspark/sql/tests/ F401 'pyspark.sql.functions.PandasUDFType' imported but unused
python/pyspark/sql/tests/ F401 'pandas.util.testing.assert_series_equal' imported but unused
python/pyspark/sql/tests/ F401 'pyarrow as pa' imported but unused
python/pyspark/sql/tests/ F401 'pyspark.sql.tests.test_pandas_cogrouped_map.*' imported but unused
python/pyspark/sql/tests/ F401 'py4j' imported but unused
python/pyspark/sql/tests/ F401 'pyspark.sql.tests.test_pandas_udf_typehints.*' imported but unused
python/pyspark/sql/tests/ F401 'sys' imported but unused
python/pyspark/sql/tests/ F401 'pyspark.sql.functions.exists' imported but unused
python/pyspark/sql/tests/ F401 'pyspark.sql.tests.test_functions.*' imported but unused
python/pyspark/sql/tests/ F401 'sys' imported but unused
python/pyspark/sql/tests/ F401 'pyarrow as pa' imported but unused
python/pyspark/sql/tests/ F401 'pyspark.sql.tests.test_pandas_udf_window.*' imported but unused
python/pyspark/sql/tests/ F401 'pyarrow as pa' imported but unused
python/pyspark/sql/tests/ F401 'sys' imported but unused
python/pyspark/sql/tests/ F401 'pyarrow as pa' imported but unused
python/pyspark/sql/tests/ F401 'pyspark.sql.DataFrame' imported but unused
python/pyspark/sql/avro/ F401 'pyspark.sql.Row' imported but unused
python/pyspark/sql/pandas/ F401 'sys' imported but unused

fokkodriesprongFan spark % flake8 python | grep -i "imported but unused"
fokkodriesprongFan spark %

### What changes were proposed in this pull request?

Removing unused imports from the Python files to keep everything nice and tidy.

### Why are the changes needed?

Cleaning up of the imports that aren't used, and suppressing the imports that are used as references to other modules, preserving backward compatibility.

### Does this PR introduce _any_ user-facing change?


### How was this patch tested?

Adding the rule to the existing Flake8 checks.

Closes #29121 from Fokko/SPARK-32319.

Authored-by: Fokko Driesprong <>
Signed-off-by: Dongjoon Hyun <>
2020-08-08 08:51:57 -07:00

617 lines
24 KiB

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import unittest
from collections import OrderedDict
from decimal import Decimal
from pyspark.sql import Row
from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType, \
from pyspark.sql.types import *
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest
if have_pandas:
import pandas as pd
from pandas.util.testing import assert_frame_equal
if have_pyarrow:
import pyarrow as pa # noqa: F401
not have_pandas or not have_pyarrow,
pandas_requirement_message or pyarrow_requirement_message)
class GroupedMapInPandasTests(ReusedSQLTestCase):
def data(self):
    """Return a DataFrame with an ``id`` column and an exploded ``v`` column.

    Starts from ``self.spark.range(10)`` renamed to ``id``, attaches an array
    column ``vs`` built from the literals 20..29, explodes that array into the
    column ``v`` and finally drops ``vs``.
    """
    ids = self.spark.range(10).toDF('id')
    literals = [lit(value) for value in range(20, 30)]
    with_array = ids.withColumn("vs", array(literals))
    exploded = with_array.withColumn("v", explode(col('vs')))
    return exploded.drop('vs')
def test_supported_types(self):
values = [
1, 2, 3,
4, 5, 1.1,
2.2, Decimal(1.123),
[1, 2, 2], True, 'hello',
bytearray([0x01, 0x02])
output_fields = [
('id', IntegerType()), ('byte', ByteType()), ('short', ShortType()),
('int', IntegerType()), ('long', LongType()), ('float', FloatType()),
('double', DoubleType()), ('decim', DecimalType(10, 3)),
('array', ArrayType(IntegerType())), ('bool', BooleanType()), ('str', StringType()),
('bin', BinaryType())
output_schema = StructType([StructField(*x) for x in output_fields])
df = self.spark.createDataFrame([values], schema=output_schema)
# Different forms of group map pandas UDF, results of these are the same
udf1 = pandas_udf(
lambda pdf: pdf.assign(
byte=pdf.byte * 2,
short=pdf.short * 2, * 2,
long=pdf.long * 2,
float=pdf.float * 2,
double=pdf.double * 2,
decim=pdf.decim * 2,
bool=False if pdf.bool else True,
str=pdf.str + 'there',
udf2 = pandas_udf(
lambda _, pdf: pdf.assign(
byte=pdf.byte * 2,
short=pdf.short * 2, * 2,
long=pdf.long * 2,
float=pdf.float * 2,
double=pdf.double * 2,
decim=pdf.decim * 2,
bool=False if pdf.bool else True,
str=pdf.str + 'there',
udf3 = pandas_udf(
lambda key, pdf: pdf.assign(
byte=pdf.byte * 2,
short=pdf.short * 2, * 2,
long=pdf.long * 2,
float=pdf.float * 2,
double=pdf.double * 2,
decim=pdf.decim * 2,
bool=False if pdf.bool else True,
str=pdf.str + 'there',
result1 = df.groupby('id').apply(udf1).sort('id').toPandas()
expected1 = df.toPandas().groupby('id').apply(udf1.func).reset_index(drop=True)
result2 = df.groupby('id').apply(udf2).sort('id').toPandas()
expected2 = expected1
result3 = df.groupby('id').apply(udf3).sort('id').toPandas()
expected3 = expected1
assert_frame_equal(expected1, result1)
assert_frame_equal(expected2, result2)
assert_frame_equal(expected3, result3)
def test_array_type_correct(self):
df ="arr", array(col("id"))).repartition(1, "id")
output_schema = StructType(
[StructField('id', LongType()),
StructField('v', IntegerType()),
StructField('arr', ArrayType(LongType()))])
udf = pandas_udf(
lambda pdf: pdf,
result = df.groupby('id').apply(udf).sort('id').toPandas()
expected = df.toPandas().groupby('id').apply(udf.func).reset_index(drop=True)
assert_frame_equal(expected, result)
def test_register_grouped_map_udf(self):
foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP)
with QuietTest(
with self.assertRaisesRegexp(
self.spark.catalog.registerFunction("foo_udf", foo_udf)
def test_decorator(self):
df =
'id long, v int, v1 double, v2 long',
def foo(pdf):
return pdf.assign(v1=pdf.v * * 1.0, v2=pdf.v +
result = df.groupby('id').apply(foo).sort('id').toPandas()
expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True)
assert_frame_equal(expected, result)
def test_coerce(self):
df =
foo = pandas_udf(
lambda pdf: pdf,
'id long, v double',
result = df.groupby('id').apply(foo).sort('id').toPandas()
expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True)
expected = expected.assign(v=expected.v.astype('float64'))
assert_frame_equal(expected, result)
def test_complex_groupby(self):
df =
'id long, v int, norm double',
def normalize(pdf):
v = pdf.v
return pdf.assign(norm=(v - v.mean()) / v.std())
result = df.groupby(col('id') % 2 == 0).apply(normalize).sort('id', 'v').toPandas()
pdf = df.toPandas()
expected = pdf.groupby(pdf['id'] % 2 == 0, as_index=False).apply(normalize.func)
expected = expected.sort_values(['id', 'v']).reset_index(drop=True)
expected = expected.assign(norm=expected.norm.astype('float64'))
assert_frame_equal(expected, result)
def test_empty_groupby(self):
df =
'id long, v int, norm double',
def normalize(pdf):
v = pdf.v
return pdf.assign(norm=(v - v.mean()) / v.std())
result = df.groupby().apply(normalize).sort('id', 'v').toPandas()
pdf = df.toPandas()
expected = normalize.func(pdf)
expected = expected.sort_values(['id', 'v']).reset_index(drop=True)
expected = expected.assign(norm=expected.norm.astype('float64'))
assert_frame_equal(expected, result)
def test_datatype_string(self):
df =
foo_udf = pandas_udf(
lambda pdf: pdf.assign(v1=pdf.v * * 1.0, v2=pdf.v +,
'id long, v int, v1 double, v2 long',
result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
assert_frame_equal(expected, result)
def test_wrong_return_type(self):
with QuietTest(
with self.assertRaisesRegexp(
'Invalid return type.*grouped map Pandas UDF.*MapType'):
lambda pdf: pdf,
'id long, v map<int, int>',
def test_wrong_args(self):
df =
with QuietTest(
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
df.groupby('id').apply(lambda x: x)
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
df.groupby('id').apply(udf(lambda x: x, DoubleType()))
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
df.groupby('id').apply(df.v + 1)
with self.assertRaisesRegexp(ValueError, 'Invalid function'):
pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())])))
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
df.groupby('id').apply(pandas_udf(lambda x, y: x, DoubleType()))
with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUPED_MAP'):
pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR))
def test_unsupported_types(self):
common_err_msg = 'Invalid return type.*grouped map Pandas UDF.*'
unsupported_types = [
StructField('map', MapType(StringType(), IntegerType())),
StructField('arr_ts', ArrayType(TimestampType())),
StructField('null', NullType()),
StructField('struct', StructType([StructField('l', LongType())])),
for unsupported_type in unsupported_types:
schema = StructType([StructField('id', LongType(), True), unsupported_type])
with QuietTest(
with self.assertRaisesRegexp(NotImplementedError, common_err_msg):
pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)
# Regression test for SPARK-23314
def test_timestamp_dst(self):
    """Check that timestamps around a DST change survive an identity grouped-map UDF.

    Three timestamps on 2015-11-01 (00:30, 01:30, 02:30) are grouped by their
    own value, passed through a UDF that returns its input unchanged, sorted,
    and compared against the untouched input frame.
    """
    # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am
    dst_times = [datetime.datetime(2015, 11, 1, hour, 30) for hour in range(3)]
    source_df = self.spark.createDataFrame(dst_times, 'timestamp').toDF('time')
    identity_udf = pandas_udf(lambda pdf: pdf, 'time timestamp', PandasUDFType.GROUPED_MAP)
    round_tripped = source_df.groupby('time').apply(identity_udf).sort('time')
    assert_frame_equal(source_df.toPandas(), round_tripped.toPandas())
def test_udf_with_key(self):
import numpy as np
df =
pdf = df.toPandas()
def foo1(key, pdf):
assert type(key) == tuple
assert type(key[0]) == np.int64
return pdf.assign(v1=key[0],
v2=pdf.v * key[0],
v3=pdf.v *,
v4=pdf.v *
def foo2(key, pdf):
assert type(key) == tuple
assert type(key[0]) == np.int64
assert type(key[1]) == np.int32
return pdf.assign(v1=key[0],
v3=pdf.v * key[0],
v4=pdf.v + key[1])
def foo3(key, pdf):
assert type(key) == tuple
assert len(key) == 0
return pdf.assign(v1=pdf.v *
# v2 is int because numpy.int64 * pd.Series<int32> results in pd.Series<int32>
# v3 is long because pd.Series<int64> * pd.Series<int32> results in pd.Series<int64>
udf1 = pandas_udf(
'id long, v int, v1 long, v2 int, v3 long, v4 double',
udf2 = pandas_udf(
'id long, v int, v1 long, v2 int, v3 int, v4 int',
udf3 = pandas_udf(
'id long, v int, v1 long',
# Test groupby column
result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas()
expected1 = pdf.groupby('id', as_index=False)\
.apply(lambda x: udf1.func(([0],), x))\
.sort_values(['id', 'v']).reset_index(drop=True)
assert_frame_equal(expected1, result1)
# Test groupby expression
result2 = df.groupby( % 2).apply(udf1).sort('id', 'v').toPandas()
expected2 = pdf.groupby( % 2, as_index=False)\
.apply(lambda x: udf1.func(([0] % 2,), x))\
.sort_values(['id', 'v']).reset_index(drop=True)
assert_frame_equal(expected2, result2)
# Test complex groupby
result3 = df.groupby(, df.v % 2).apply(udf2).sort('id', 'v').toPandas()
expected3 = pdf.groupby([, pdf.v % 2], as_index=False)\
.apply(lambda x: udf2.func(([0], (x.v % 2).iloc[0],), x))\
.sort_values(['id', 'v']).reset_index(drop=True)
assert_frame_equal(expected3, result3)
# Test empty groupby
result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas()
expected4 = udf3.func((), pdf)
assert_frame_equal(expected4, result4)
def test_column_order(self):
# Helper function to set column names from a list
def rename_pdf(pdf, names):
pdf.rename(columns={old: new for old, new in
zip(pd_result.columns, names)}, inplace=True)
df =
grouped_df = df.groupby('id')
grouped_pdf = df.toPandas().groupby('id', as_index=False)
# Function returns a pdf with required column names, but order could be arbitrary using dict
def change_col_order(pdf):
# Constructing a DataFrame from a dict should result in the same order,
# but use OrderedDict to ensure the pdf column order is different than schema
return pd.DataFrame.from_dict(OrderedDict([
('u', pdf.v * 2),
('v', pdf.v)]))
ordered_udf = pandas_udf(
'id long, v int, u int',
# The UDF result should assign columns by name from the pdf
result = grouped_df.apply(ordered_udf).sort('id', 'v')\
.select('id', 'u', 'v').toPandas()
pd_result = grouped_pdf.apply(change_col_order)
expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True)
assert_frame_equal(expected, result)
# Function returns a pdf with positional columns, indexed by range
def range_col_order(pdf):
# Create a DataFrame with positional columns, fix types to long
return pd.DataFrame(list(zip(, pdf.v * 3, pdf.v)), dtype='int64')
range_udf = pandas_udf(
'id long, u long, v long',
# The UDF result uses positional columns from the pdf
result = grouped_df.apply(range_udf).sort('id', 'v') \
.select('id', 'u', 'v').toPandas()
pd_result = grouped_pdf.apply(range_col_order)
rename_pdf(pd_result, ['id', 'u', 'v'])
expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True)
assert_frame_equal(expected, result)
# Function returns a pdf with columns indexed with integers
def int_index(pdf):
return pd.DataFrame(OrderedDict([(0,, (1, pdf.v * 4), (2, pdf.v)]))
int_index_udf = pandas_udf(
'id long, u int, v int',
# The UDF result should assign columns by position of integer index
result = grouped_df.apply(int_index_udf).sort('id', 'v') \
.select('id', 'u', 'v').toPandas()
pd_result = grouped_pdf.apply(int_index)
rename_pdf(pd_result, ['id', 'u', 'v'])
expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True)
assert_frame_equal(expected, result)
@pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP)
def column_name_typo(pdf):
return pd.DataFrame({'iid':, 'v': pdf.v})
@pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP)
def invalid_positional_types(pdf):
return pd.DataFrame([(u'a', 1.2)])
with QuietTest(
with self.assertRaisesRegexp(Exception, "KeyError: 'id'"):
with self.assertRaisesRegexp(Exception, "an integer is required"):
def test_positional_assignment_conf(self):
with self.sql_conf({
"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}):
@pandas_udf("a string, b float", PandasUDFType.GROUPED_MAP)
def foo(_):
return pd.DataFrame([('hi', 1)], columns=['x', 'y'])
df =
result = df.groupBy('id').apply(foo).select('a', 'b').collect()
for r in result:
self.assertEqual(r.a, 'hi')
self.assertEqual(r.b, 1)
def test_self_join_with_pandas(self):
    """Self-join the output of a grouped-map pandas UDF and check the row count.

    The UDF only projects the ``key`` and ``col`` columns, so the grouped
    result carries the same three rows as the input. Joining that result with
    itself on ``key`` must succeed and yield five rows.
    """
    @pandas_udf('key long, col string', PandasUDFType.GROUPED_MAP)
    def dummy_pandas_udf(df):
        return df[['key', 'col']]

    df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'),
                                     Row(key=2, col='C')])
    df_with_pandas = df.groupBy('key').apply(dummy_pandas_udf)

    # this was throwing an AnalysisException before SPARK-24208
    res = df_with_pandas.alias('temp0').join(df_with_pandas.alias('temp1'),
                                             col('temp0.key') == col('temp1.key'))
    # key=1 contributes 2 x 2 matches and key=2 contributes 1 x 1, hence 5 rows.
    # assertEqual replaces the deprecated assertEquals alias (removed in Python 3.12).
    self.assertEqual(res.count(), 5)
def test_mixed_scalar_udfs_followed_by_grouby_apply(self):
df = self.spark.range(0, 10).toDF('v1')
df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
.withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1']))
result = df.groupby() \
.apply(pandas_udf(lambda x: pd.DataFrame([x.sum().sum()]),
'sum int',
self.assertEquals(result.collect()[0]['sum'], 165)
def test_grouped_with_empty_partition(self):
data = [Row(id=1, x=2), Row(id=1, x=3), Row(id=2, x=4)]
expected = [Row(id=1, x=5), Row(id=1, x=5), Row(id=2, x=4)]
num_parts = len(data) + 1
df = self.spark.createDataFrame(, numSlices=num_parts))
f = pandas_udf(lambda pdf: pdf.assign(x=pdf['x'].sum()),
'id long, x int', PandasUDFType.GROUPED_MAP)
result = df.groupBy('id').apply(f).collect()
self.assertEqual(result, expected)
def test_grouped_over_window(self):
data = [(0, 1, "2018-03-10T00:00:00+00:00", [0]),
(1, 2, "2018-03-11T00:00:00+00:00", [0]),
(2, 2, "2018-03-12T00:00:00+00:00", [0]),
(3, 3, "2018-03-15T00:00:00+00:00", [0]),
(4, 3, "2018-03-16T00:00:00+00:00", [0]),
(5, 3, "2018-03-17T00:00:00+00:00", [0]),
(6, 3, "2018-03-21T00:00:00+00:00", [0])]
expected = {0: [0],
1: [1, 2],
2: [1, 2],
3: [3, 4, 5],
4: [3, 4, 5],
5: [3, 4, 5],
6: [6]}
df = self.spark.createDataFrame(data, ['id', 'group', 'ts', 'result'])
df ='id'), col('group'), col('ts').cast('timestamp'), col('result'))
def f(pdf):
# Assign each result element the ids of the windowed group
pdf['result'] = [pdf['id']] * len(pdf)
return pdf
result = df.groupby('group', window('ts', '5 days')).applyInPandas(f, df.schema)\
.select('id', 'result').collect()
for r in result:
self.assertListEqual(expected[r[0]], r[1])
def test_grouped_over_window_with_key(self):
data = [(0, 1, "2018-03-10T00:00:00+00:00", [0]),
(1, 2, "2018-03-11T00:00:00+00:00", [0]),
(2, 2, "2018-03-12T00:00:00+00:00", [0]),
(3, 3, "2018-03-15T00:00:00+00:00", [0]),
(4, 3, "2018-03-16T00:00:00+00:00", [0]),
(5, 3, "2018-03-17T00:00:00+00:00", [0]),
(6, 3, "2018-03-21T00:00:00+00:00", [0])]
expected_window = [
{'start': datetime.datetime(2018, 3, 10, 0, 0),
'end': datetime.datetime(2018, 3, 15, 0, 0)},
{'start': datetime.datetime(2018, 3, 15, 0, 0),
'end': datetime.datetime(2018, 3, 20, 0, 0)},
{'start': datetime.datetime(2018, 3, 20, 0, 0),
'end': datetime.datetime(2018, 3, 25, 0, 0)},
expected_key = {0: (1, expected_window[0]),
1: (2, expected_window[0]),
2: (2, expected_window[0]),
3: (3, expected_window[1]),
4: (3, expected_window[1]),
5: (3, expected_window[1]),
6: (3, expected_window[2])}
# id -> array of group with len of num records in window
expected = {0: [1],
1: [2, 2],
2: [2, 2],
3: [3, 3, 3],
4: [3, 3, 3],
5: [3, 3, 3],
6: [3]}
df = self.spark.createDataFrame(data, ['id', 'group', 'ts', 'result'])
df ='id'), col('group'), col('ts').cast('timestamp'), col('result'))
def f(key, pdf):
group = key[0]
window_range = key[1]
# Make sure the key with group and window values are correct
for _, i in
assert expected_key[i][0] == group, "{} != {}".format(expected_key[i][0], group)
assert expected_key[i][1] == window_range, \
"{} != {}".format(expected_key[i][1], window_range)
return pdf.assign(result=[[group] * len(pdf)] * len(pdf))
result = df.groupby('group', window('ts', '5 days')).applyInPandas(f, df.schema)\
.select('id', 'result').collect()
for r in result:
self.assertListEqual(expected[r[0]], r[1])
def test_case_insensitive_grouping_column(self):
    """Group by a column name given in a different case and apply a pandas function.

    The DataFrame column is named ``column`` but the grouping uses ``COLUMN``;
    the applied function overwrites ``score`` with 0.5, and the first returned
    row is compared with the expected values.
    """
    # SPARK-31915: case-insensitive grouping column should work.
    def my_pandas_udf(pdf):
        return pdf.assign(score=0.5)

    df = self.spark.createDataFrame([[1, 1]], ["column", "score"])
    row = df.groupby('COLUMN').applyInPandas(
        my_pandas_udf, schema="column integer, score float").first()
    # assertEqual replaces the deprecated assertEquals alias (removed in Python 3.12).
    self.assertEqual(row.asDict(), Row(column=1, score=0.5).asDict())
if __name__ == "__main__":
    from pyspark.sql.tests.test_pandas_grouped_map import *

    # xmlrunner is optional: when it is importable, write XML reports to
    # target/test-reports; otherwise pass testRunner=None to unittest.main.
    # The `try:` is required for the `except ImportError:` clause to be valid.
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)