9fcf0ea718
Disallow the use of unused imports: - Unnecessary increases the memory footprint of the application - Removes the imports that are required for the examples in the docstring from the file-scope to the example itself. This keeps the files itself clean, and gives a more complete example as it also includes the imports :) ``` fokkodriesprongFan spark % flake8 python | grep -i "imported but unused" python/pyspark/cloudpickle.py:46:1: F401 'functools.partial' imported but unused python/pyspark/cloudpickle.py:55:1: F401 'traceback' imported but unused python/pyspark/heapq3.py:868:5: F401 '_heapq.*' imported but unused python/pyspark/__init__.py:61:1: F401 'pyspark.version.__version__' imported but unused python/pyspark/__init__.py:62:1: F401 'pyspark._globals._NoValue' imported but unused python/pyspark/__init__.py:115:1: F401 'pyspark.sql.SQLContext' imported but unused python/pyspark/__init__.py:115:1: F401 'pyspark.sql.HiveContext' imported but unused python/pyspark/__init__.py:115:1: F401 'pyspark.sql.Row' imported but unused python/pyspark/rdd.py:21:1: F401 're' imported but unused python/pyspark/rdd.py:29:1: F401 'tempfile.NamedTemporaryFile' imported but unused python/pyspark/mllib/regression.py:26:1: F401 'pyspark.mllib.linalg.SparseVector' imported but unused python/pyspark/mllib/clustering.py:28:1: F401 'pyspark.mllib.linalg.SparseVector' imported but unused python/pyspark/mllib/clustering.py:28:1: F401 'pyspark.mllib.linalg.DenseVector' imported but unused python/pyspark/mllib/classification.py:26:1: F401 'pyspark.mllib.linalg.SparseVector' imported but unused python/pyspark/mllib/feature.py:28:1: F401 'pyspark.mllib.linalg.DenseVector' imported but unused python/pyspark/mllib/feature.py:28:1: F401 'pyspark.mllib.linalg.SparseVector' imported but unused python/pyspark/mllib/feature.py:30:1: F401 'pyspark.mllib.regression.LabeledPoint' imported but unused python/pyspark/mllib/tests/test_linalg.py:18:1: F401 'sys' imported but unused 
python/pyspark/mllib/tests/test_linalg.py:642:5: F401 'pyspark.mllib.tests.test_linalg.*' imported but unused python/pyspark/mllib/tests/test_feature.py:21:1: F401 'numpy.random' imported but unused python/pyspark/mllib/tests/test_feature.py:21:1: F401 'numpy.exp' imported but unused python/pyspark/mllib/tests/test_feature.py:23:1: F401 'pyspark.mllib.linalg.Vector' imported but unused python/pyspark/mllib/tests/test_feature.py:23:1: F401 'pyspark.mllib.linalg.VectorUDT' imported but unused python/pyspark/mllib/tests/test_feature.py:185:5: F401 'pyspark.mllib.tests.test_feature.*' imported but unused python/pyspark/mllib/tests/test_util.py:97:5: F401 'pyspark.mllib.tests.test_util.*' imported but unused python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.Vector' imported but unused python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.SparseVector' imported but unused python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.DenseVector' imported but unused python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.VectorUDT' imported but unused python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg._convert_to_vector' imported but unused python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.DenseMatrix' imported but unused python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.SparseMatrix' imported but unused python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.MatrixUDT' imported but unused python/pyspark/mllib/tests/test_stat.py:181:5: F401 'pyspark.mllib.tests.test_stat.*' imported but unused python/pyspark/mllib/tests/test_streaming_algorithms.py:18:1: F401 'time.time' imported but unused python/pyspark/mllib/tests/test_streaming_algorithms.py:18:1: F401 'time.sleep' imported but unused python/pyspark/mllib/tests/test_streaming_algorithms.py:470:5: F401 'pyspark.mllib.tests.test_streaming_algorithms.*' imported but 
unused python/pyspark/mllib/tests/test_algorithms.py:295:5: F401 'pyspark.mllib.tests.test_algorithms.*' imported but unused python/pyspark/tests/test_serializers.py:90:13: F401 'xmlrunner' imported but unused python/pyspark/tests/test_rdd.py:21:1: F401 'sys' imported but unused python/pyspark/tests/test_rdd.py:29:1: F401 'pyspark.resource.ResourceProfile' imported but unused python/pyspark/tests/test_rdd.py:885:5: F401 'pyspark.tests.test_rdd.*' imported but unused python/pyspark/tests/test_readwrite.py:19:1: F401 'sys' imported but unused python/pyspark/tests/test_readwrite.py:22:1: F401 'array.array' imported but unused python/pyspark/tests/test_readwrite.py:309:5: F401 'pyspark.tests.test_readwrite.*' imported but unused python/pyspark/tests/test_join.py:62:5: F401 'pyspark.tests.test_join.*' imported but unused python/pyspark/tests/test_taskcontext.py:19:1: F401 'shutil' imported but unused python/pyspark/tests/test_taskcontext.py:325:5: F401 'pyspark.tests.test_taskcontext.*' imported but unused python/pyspark/tests/test_conf.py:36:5: F401 'pyspark.tests.test_conf.*' imported but unused python/pyspark/tests/test_broadcast.py:148:5: F401 'pyspark.tests.test_broadcast.*' imported but unused python/pyspark/tests/test_daemon.py:76:5: F401 'pyspark.tests.test_daemon.*' imported but unused python/pyspark/tests/test_util.py:77:5: F401 'pyspark.tests.test_util.*' imported but unused python/pyspark/tests/test_pin_thread.py:19:1: F401 'random' imported but unused python/pyspark/tests/test_pin_thread.py:149:5: F401 'pyspark.tests.test_pin_thread.*' imported but unused python/pyspark/tests/test_worker.py:19:1: F401 'sys' imported but unused python/pyspark/tests/test_worker.py:26:5: F401 'resource' imported but unused python/pyspark/tests/test_worker.py:203:5: F401 'pyspark.tests.test_worker.*' imported but unused python/pyspark/tests/test_profiler.py:101:5: F401 'pyspark.tests.test_profiler.*' imported but unused python/pyspark/tests/test_shuffle.py:18:1: F401 'sys' 
imported but unused python/pyspark/tests/test_shuffle.py:171:5: F401 'pyspark.tests.test_shuffle.*' imported but unused python/pyspark/tests/test_rddbarrier.py:43:5: F401 'pyspark.tests.test_rddbarrier.*' imported but unused python/pyspark/tests/test_context.py:129:13: F401 'userlibrary.UserClass' imported but unused python/pyspark/tests/test_context.py:140:13: F401 'userlib.UserClass' imported but unused python/pyspark/tests/test_context.py:310:5: F401 'pyspark.tests.test_context.*' imported but unused python/pyspark/tests/test_appsubmit.py:241:5: F401 'pyspark.tests.test_appsubmit.*' imported but unused python/pyspark/streaming/dstream.py:18:1: F401 'sys' imported but unused python/pyspark/streaming/tests/test_dstream.py:27:1: F401 'pyspark.RDD' imported but unused python/pyspark/streaming/tests/test_dstream.py:647:5: F401 'pyspark.streaming.tests.test_dstream.*' imported but unused python/pyspark/streaming/tests/test_kinesis.py:83:5: F401 'pyspark.streaming.tests.test_kinesis.*' imported but unused python/pyspark/streaming/tests/test_listener.py:152:5: F401 'pyspark.streaming.tests.test_listener.*' imported but unused python/pyspark/streaming/tests/test_context.py:178:5: F401 'pyspark.streaming.tests.test_context.*' imported but unused python/pyspark/testing/utils.py:30:5: F401 'scipy.sparse' imported but unused python/pyspark/testing/utils.py:36:5: F401 'numpy as np' imported but unused python/pyspark/ml/regression.py:25:1: F401 'pyspark.ml.tree._TreeEnsembleParams' imported but unused python/pyspark/ml/regression.py:25:1: F401 'pyspark.ml.tree._HasVarianceImpurity' imported but unused python/pyspark/ml/regression.py:29:1: F401 'pyspark.ml.wrapper.JavaParams' imported but unused python/pyspark/ml/util.py:19:1: F401 'sys' imported but unused python/pyspark/ml/__init__.py:25:1: F401 'pyspark.ml.pipeline' imported but unused python/pyspark/ml/pipeline.py:18:1: F401 'sys' imported but unused python/pyspark/ml/stat.py:22:1: F401 'pyspark.ml.linalg.DenseMatrix' 
imported but unused python/pyspark/ml/stat.py:22:1: F401 'pyspark.ml.linalg.Vectors' imported but unused python/pyspark/ml/tests/test_training_summary.py:18:1: F401 'sys' imported but unused python/pyspark/ml/tests/test_training_summary.py:364:5: F401 'pyspark.ml.tests.test_training_summary.*' imported but unused python/pyspark/ml/tests/test_linalg.py:381:5: F401 'pyspark.ml.tests.test_linalg.*' imported but unused python/pyspark/ml/tests/test_tuning.py:427:9: F401 'pyspark.sql.functions as F' imported but unused python/pyspark/ml/tests/test_tuning.py:757:5: F401 'pyspark.ml.tests.test_tuning.*' imported but unused python/pyspark/ml/tests/test_wrapper.py:120:5: F401 'pyspark.ml.tests.test_wrapper.*' imported but unused python/pyspark/ml/tests/test_feature.py:19:1: F401 'sys' imported but unused python/pyspark/ml/tests/test_feature.py:304:5: F401 'pyspark.ml.tests.test_feature.*' imported but unused python/pyspark/ml/tests/test_image.py:19:1: F401 'py4j' imported but unused python/pyspark/ml/tests/test_image.py:22:1: F401 'pyspark.testing.mlutils.PySparkTestCase' imported but unused python/pyspark/ml/tests/test_image.py:71:5: F401 'pyspark.ml.tests.test_image.*' imported but unused python/pyspark/ml/tests/test_persistence.py:456:5: F401 'pyspark.ml.tests.test_persistence.*' imported but unused python/pyspark/ml/tests/test_evaluation.py:56:5: F401 'pyspark.ml.tests.test_evaluation.*' imported but unused python/pyspark/ml/tests/test_stat.py:43:5: F401 'pyspark.ml.tests.test_stat.*' imported but unused python/pyspark/ml/tests/test_base.py:70:5: F401 'pyspark.ml.tests.test_base.*' imported but unused python/pyspark/ml/tests/test_param.py:20:1: F401 'sys' imported but unused python/pyspark/ml/tests/test_param.py:375:5: F401 'pyspark.ml.tests.test_param.*' imported but unused python/pyspark/ml/tests/test_pipeline.py:62:5: F401 'pyspark.ml.tests.test_pipeline.*' imported but unused python/pyspark/ml/tests/test_algorithms.py:333:5: F401 'pyspark.ml.tests.test_algorithms.*' 
imported but unused python/pyspark/ml/param/__init__.py:18:1: F401 'sys' imported but unused python/pyspark/resource/tests/test_resources.py:17:1: F401 'random' imported but unused python/pyspark/resource/tests/test_resources.py:20:1: F401 'pyspark.resource.ResourceProfile' imported but unused python/pyspark/resource/tests/test_resources.py:75:5: F401 'pyspark.resource.tests.test_resources.*' imported but unused python/pyspark/sql/functions.py:32:1: F401 'pyspark.sql.udf.UserDefinedFunction' imported but unused python/pyspark/sql/functions.py:34:1: F401 'pyspark.sql.pandas.functions.pandas_udf' imported but unused python/pyspark/sql/session.py:30:1: F401 'pyspark.sql.types.Row' imported but unused python/pyspark/sql/session.py:30:1: F401 'pyspark.sql.types.StringType' imported but unused python/pyspark/sql/readwriter.py:1084:5: F401 'pyspark.sql.Row' imported but unused python/pyspark/sql/context.py:26:1: F401 'pyspark.sql.types.IntegerType' imported but unused python/pyspark/sql/context.py:26:1: F401 'pyspark.sql.types.Row' imported but unused python/pyspark/sql/context.py:26:1: F401 'pyspark.sql.types.StringType' imported but unused python/pyspark/sql/context.py:27:1: F401 'pyspark.sql.udf.UDFRegistration' imported but unused python/pyspark/sql/streaming.py:1212:5: F401 'pyspark.sql.Row' imported but unused python/pyspark/sql/tests/test_utils.py:55:5: F401 'pyspark.sql.tests.test_utils.*' imported but unused python/pyspark/sql/tests/test_pandas_map.py:18:1: F401 'sys' imported but unused python/pyspark/sql/tests/test_pandas_map.py:22:1: F401 'pyspark.sql.functions.pandas_udf' imported but unused python/pyspark/sql/tests/test_pandas_map.py:22:1: F401 'pyspark.sql.functions.PandasUDFType' imported but unused python/pyspark/sql/tests/test_pandas_map.py:119:5: F401 'pyspark.sql.tests.test_pandas_map.*' imported but unused python/pyspark/sql/tests/test_catalog.py:193:5: F401 'pyspark.sql.tests.test_catalog.*' imported but unused 
python/pyspark/sql/tests/test_group.py:39:5: F401 'pyspark.sql.tests.test_group.*' imported but unused python/pyspark/sql/tests/test_session.py:361:5: F401 'pyspark.sql.tests.test_session.*' imported but unused python/pyspark/sql/tests/test_conf.py:49:5: F401 'pyspark.sql.tests.test_conf.*' imported but unused python/pyspark/sql/tests/test_pandas_cogrouped_map.py:19:1: F401 'sys' imported but unused python/pyspark/sql/tests/test_pandas_cogrouped_map.py:21:1: F401 'pyspark.sql.functions.sum' imported but unused python/pyspark/sql/tests/test_pandas_cogrouped_map.py:21:1: F401 'pyspark.sql.functions.PandasUDFType' imported but unused python/pyspark/sql/tests/test_pandas_cogrouped_map.py:29:5: F401 'pandas.util.testing.assert_series_equal' imported but unused python/pyspark/sql/tests/test_pandas_cogrouped_map.py:32:5: F401 'pyarrow as pa' imported but unused python/pyspark/sql/tests/test_pandas_cogrouped_map.py:248:5: F401 'pyspark.sql.tests.test_pandas_cogrouped_map.*' imported but unused python/pyspark/sql/tests/test_udf.py:24:1: F401 'py4j' imported but unused python/pyspark/sql/tests/test_pandas_udf_typehints.py:246:5: F401 'pyspark.sql.tests.test_pandas_udf_typehints.*' imported but unused python/pyspark/sql/tests/test_functions.py:19:1: F401 'sys' imported but unused python/pyspark/sql/tests/test_functions.py:362:9: F401 'pyspark.sql.functions.exists' imported but unused python/pyspark/sql/tests/test_functions.py:387:5: F401 'pyspark.sql.tests.test_functions.*' imported but unused python/pyspark/sql/tests/test_pandas_udf_scalar.py:21:1: F401 'sys' imported but unused python/pyspark/sql/tests/test_pandas_udf_scalar.py:45:5: F401 'pyarrow as pa' imported but unused python/pyspark/sql/tests/test_pandas_udf_window.py:355:5: F401 'pyspark.sql.tests.test_pandas_udf_window.*' imported but unused python/pyspark/sql/tests/test_arrow.py:38:5: F401 'pyarrow as pa' imported but unused python/pyspark/sql/tests/test_pandas_grouped_map.py:20:1: F401 'sys' imported but unused 
python/pyspark/sql/tests/test_pandas_grouped_map.py:38:5: F401 'pyarrow as pa' imported but unused python/pyspark/sql/tests/test_dataframe.py:382:9: F401 'pyspark.sql.DataFrame' imported but unused python/pyspark/sql/avro/functions.py:125:5: F401 'pyspark.sql.Row' imported but unused python/pyspark/sql/pandas/functions.py:19:1: F401 'sys' imported but unused ``` After: ``` fokkodriesprongFan spark % flake8 python | grep -i "imported but unused" fokkodriesprongFan spark % ``` ### What changes were proposed in this pull request? Removing unused imports from the Python files to keep everything nice and tidy. ### Why are the changes needed? Cleaning up of the imports that aren't used, and suppressing the imports that are used as references to other modules, preserving backward compatibility. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Adding the rule to the existing Flake8 checks. Closes #29121 from Fokko/SPARK-32319. Authored-by: Fokko Driesprong <fokko@apache.org> Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
521 lines
20 KiB
Python
521 lines
20 KiB
Python
#
|
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
# contributor license agreements. See the NOTICE file distributed with
|
|
# this work for additional information regarding copyright ownership.
|
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
# (the "License"); you may not use this file except in compliance with
|
|
# the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
import unittest
|
|
|
|
from pyspark.rdd import PythonEvalType
|
|
from pyspark.sql import Row
|
|
from pyspark.sql.functions import array, explode, col, lit, mean, sum, \
|
|
udf, pandas_udf, PandasUDFType
|
|
from pyspark.sql.types import *
|
|
from pyspark.sql.utils import AnalysisException
|
|
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
|
|
pandas_requirement_message, pyarrow_requirement_message
|
|
from pyspark.testing.utils import QuietTest
|
|
|
|
if have_pandas:
|
|
import pandas as pd
|
|
from pandas.util.testing import assert_frame_equal
|
|
|
|
|
|
@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message)
class GroupedAggPandasUDFTests(ReusedSQLTestCase):
    """Tests for grouped-aggregate (GROUPED_AGG) pandas UDFs.

    The whole class is skipped when pandas or pyarrow is unavailable.
    """

    @property
    def data(self):
        """Return a 100-row fixture DataFrame with columns id, v and w.

        Each id in 0..9 gets ten 'v' values (id + 20.0 through id + 29.0,
        produced by exploding an array column) plus a constant weight
        w == 1.0 used by the weighted-mean tests.
        """
        return self.spark.range(10).toDF('id') \
            .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \
            .withColumn("v", explode(col('vs'))) \
            .drop('vs') \
            .withColumn('w', lit(1.0))
|
|
|
|
    @property
    def python_plus_one(self):
        """Plain (row-at-a-time) Python UDF returning v + 1 as double."""
        @udf('double')
        def plus_one(v):
            # Row-at-a-time UDF: receives individual scalars, not Series.
            assert isinstance(v, (int, float))
            return v + 1
        return plus_one
|
|
|
|
    @property
    def pandas_scalar_plus_two(self):
        """Scalar pandas UDF returning v + 2 as double."""
        @pandas_udf('double', PandasUDFType.SCALAR)
        def plus_two(v):
            # Scalar pandas UDF: receives a pandas Series per batch.
            assert isinstance(v, pd.Series)
            return v + 2
        return plus_two
|
|
|
|
    @property
    def pandas_agg_mean_udf(self):
        """Grouped-aggregate pandas UDF computing the mean of a column."""
        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
        def avg(v):
            return v.mean()
        return avg
|
|
|
|
    @property
    def pandas_agg_sum_udf(self):
        """Grouped-aggregate pandas UDF computing the sum of a column.

        The inner function is deliberately named ``sum`` so the aggregate
        appears as 'sum(v)' in result column names; the shadowing is local
        to this property and does not affect the module-level ``sum``
        imported from pyspark.sql.functions.
        """
        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
        def sum(v):
            return v.sum()
        return sum
|
|
|
|
    @property
    def pandas_agg_weighted_mean_udf(self):
        """Grouped-aggregate pandas UDF computing a weighted mean.

        Takes two columns (values and weights), exercising multi-argument
        grouped aggregates.
        """
        import numpy as np

        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
        def weighted_mean(v, w):
            return np.average(v, weights=w)
        return weighted_mean
|
|
|
|
    def test_manual(self):
        """Check sum/mean/array-mean grouped aggregates against
        manually-computed expected values for the fixture data."""
        df = self.data
        sum_udf = self.pandas_agg_sum_udf
        mean_udf = self.pandas_agg_mean_udf
        # Re-wrap the mean UDF so it returns array<double> instead of
        # double, reusing the same underlying function and eval type.
        mean_arr_udf = pandas_udf(
            self.pandas_agg_mean_udf.func,
            ArrayType(self.pandas_agg_mean_udf.returnType),
            self.pandas_agg_mean_udf.evalType)

        result1 = df.groupby('id').agg(
            sum_udf(df.v),
            mean_udf(df.v),
            mean_arr_udf(array(df.v))).sort('id')
        # For id=k: v ranges over k+20.0 .. k+29.0, so sum = 245 + 10k
        # and mean = 24.5 + k.
        expected1 = self.spark.createDataFrame(
            [[0, 245.0, 24.5, [24.5]],
             [1, 255.0, 25.5, [25.5]],
             [2, 265.0, 26.5, [26.5]],
             [3, 275.0, 27.5, [27.5]],
             [4, 285.0, 28.5, [28.5]],
             [5, 295.0, 29.5, [29.5]],
             [6, 305.0, 30.5, [30.5]],
             [7, 315.0, 31.5, [31.5]],
             [8, 325.0, 32.5, [32.5]],
             [9, 335.0, 33.5, [33.5]]],
            ['id', 'sum(v)', 'avg(v)', 'avg(array(v))'])

        assert_frame_equal(expected1.toPandas(), result1.toPandas())
|
|
|
|
    def test_basic(self):
        """Compare the weighted-mean UDF with the built-in mean.

        Since w == 1.0 on every row, weighted_mean(v, w) equals mean(v);
        the built-in mean is aliased to the UDF's output column name so
        the frames compare equal.
        """
        df = self.data
        weighted_mean_udf = self.pandas_agg_weighted_mean_udf

        # Groupby one column and aggregate one UDF with literal
        result1 = df.groupby('id').agg(weighted_mean_udf(df.v, lit(1.0))).sort('id')
        expected1 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort('id')
        assert_frame_equal(expected1.toPandas(), result1.toPandas())

        # Groupby one expression and aggregate one UDF with literal
        result2 = df.groupby((col('id') + 1)).agg(weighted_mean_udf(df.v, lit(1.0)))\
            .sort(df.id + 1)
        expected2 = df.groupby((col('id') + 1))\
            .agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort(df.id + 1)
        assert_frame_equal(expected2.toPandas(), result2.toPandas())

        # Groupby one column and aggregate one UDF without literal
        result3 = df.groupby('id').agg(weighted_mean_udf(df.v, df.w)).sort('id')
        expected3 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, w)')).sort('id')
        assert_frame_equal(expected3.toPandas(), result3.toPandas())

        # Groupby one expression and aggregate one UDF without literal
        result4 = df.groupby((col('id') + 1).alias('id'))\
            .agg(weighted_mean_udf(df.v, df.w))\
            .sort('id')
        expected4 = df.groupby((col('id') + 1).alias('id'))\
            .agg(mean(df.v).alias('weighted_mean(v, w)'))\
            .sort('id')
        assert_frame_equal(expected4.toPandas(), result4.toPandas())
|
|
|
|
def test_unsupported_types(self):
|
|
with QuietTest(self.sc):
|
|
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
|
|
pandas_udf(
|
|
lambda x: x,
|
|
ArrayType(ArrayType(TimestampType())),
|
|
PandasUDFType.GROUPED_AGG)
|
|
|
|
with QuietTest(self.sc):
|
|
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
|
|
@pandas_udf('mean double, std double', PandasUDFType.GROUPED_AGG)
|
|
def mean_and_std_udf(v):
|
|
return v.mean(), v.std()
|
|
|
|
with QuietTest(self.sc):
|
|
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
|
|
@pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUPED_AGG)
|
|
def mean_and_std_udf(v):
|
|
return {v.mean(): v.std()}
|
|
|
|
    def test_alias(self):
        """.alias() on a grouped-aggregate UDF column renames the output."""
        df = self.data
        mean_udf = self.pandas_agg_mean_udf

        result1 = df.groupby('id').agg(mean_udf(df.v).alias('mean_alias'))
        expected1 = df.groupby('id').agg(mean(df.v).alias('mean_alias'))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())
|
|
|
|
    def test_mixed_sql(self):
        """
        Test mixing group aggregate pandas UDF with sql expression.

        The UDF result may feed into SQL arithmetic, take a SQL
        expression as input, or both; each variant must match the
        equivalent built-in sum expression.
        """
        df = self.data
        sum_udf = self.pandas_agg_sum_udf

        # Mix group aggregate pandas UDF with sql expression
        result1 = (df.groupby('id')
                   .agg(sum_udf(df.v) + 1)
                   .sort('id'))
        expected1 = (df.groupby('id')
                     .agg(sum(df.v) + 1)
                     .sort('id'))

        # Mix group aggregate pandas UDF with sql expression (order swapped)
        result2 = (df.groupby('id')
                   .agg(sum_udf(df.v + 1))
                   .sort('id'))

        expected2 = (df.groupby('id')
                     .agg(sum(df.v + 1))
                     .sort('id'))

        # Wrap group aggregate pandas UDF with two sql expressions
        result3 = (df.groupby('id')
                   .agg(sum_udf(df.v + 1) + 2)
                   .sort('id'))
        expected3 = (df.groupby('id')
                     .agg(sum(df.v + 1) + 2)
                     .sort('id'))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())
        assert_frame_equal(expected2.toPandas(), result2.toPandas())
        assert_frame_equal(expected3.toPandas(), result3.toPandas())
|
|
|
|
    def test_mixed_udfs(self):
        """
        Test mixing group aggregate pandas UDF with python UDF and scalar pandas UDF.

        Exercises composition in both directions (aggregate wrapping the
        other UDF kind and vice versa) and the other UDF kinds inside
        groupby expressions.
        """
        df = self.data
        plus_one = self.python_plus_one
        plus_two = self.pandas_scalar_plus_two
        sum_udf = self.pandas_agg_sum_udf

        # Mix group aggregate pandas UDF and python UDF
        result1 = (df.groupby('id')
                   .agg(plus_one(sum_udf(df.v)))
                   .sort('id'))
        expected1 = (df.groupby('id')
                     .agg(plus_one(sum(df.v)))
                     .sort('id'))

        # Mix group aggregate pandas UDF and python UDF (order swapped)
        result2 = (df.groupby('id')
                   .agg(sum_udf(plus_one(df.v)))
                   .sort('id'))
        expected2 = (df.groupby('id')
                     .agg(sum(plus_one(df.v)))
                     .sort('id'))

        # Mix group aggregate pandas UDF and scalar pandas UDF
        result3 = (df.groupby('id')
                   .agg(sum_udf(plus_two(df.v)))
                   .sort('id'))
        expected3 = (df.groupby('id')
                     .agg(sum(plus_two(df.v)))
                     .sort('id'))

        # Mix group aggregate pandas UDF and scalar pandas UDF (order swapped)
        result4 = (df.groupby('id')
                   .agg(plus_two(sum_udf(df.v)))
                   .sort('id'))
        expected4 = (df.groupby('id')
                     .agg(plus_two(sum(df.v)))
                     .sort('id'))

        # Wrap group aggregate pandas UDF with two python UDFs and use python UDF in groupby
        result5 = (df.groupby(plus_one(df.id))
                   .agg(plus_one(sum_udf(plus_one(df.v))))
                   .sort('plus_one(id)'))
        expected5 = (df.groupby(plus_one(df.id))
                     .agg(plus_one(sum(plus_one(df.v))))
                     .sort('plus_one(id)'))

        # Wrap group aggregate pandas UDF with two scalar pandas UDFs and use
        # a scalar pandas UDF in groupby
        result6 = (df.groupby(plus_two(df.id))
                   .agg(plus_two(sum_udf(plus_two(df.v))))
                   .sort('plus_two(id)'))
        expected6 = (df.groupby(plus_two(df.id))
                     .agg(plus_two(sum(plus_two(df.v))))
                     .sort('plus_two(id)'))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())
        assert_frame_equal(expected2.toPandas(), result2.toPandas())
        assert_frame_equal(expected3.toPandas(), result3.toPandas())
        assert_frame_equal(expected4.toPandas(), result4.toPandas())
        assert_frame_equal(expected5.toPandas(), result5.toPandas())
        assert_frame_equal(expected6.toPandas(), result6.toPandas())
|
|
|
|
    def test_multiple_udfs(self):
        """
        Test multiple group aggregate pandas UDFs in one agg function.
        """
        df = self.data
        mean_udf = self.pandas_agg_mean_udf
        sum_udf = self.pandas_agg_sum_udf
        weighted_mean_udf = self.pandas_agg_weighted_mean_udf

        result1 = (df.groupBy('id')
                   .agg(mean_udf(df.v),
                        sum_udf(df.v),
                        weighted_mean_udf(df.v, df.w))
                   .sort('id')
                   .toPandas())
        # w == 1.0 on every row, so the weighted mean reduces to the
        # plain mean; alias it to match the UDF's output column name.
        expected1 = (df.groupBy('id')
                     .agg(mean(df.v),
                          sum(df.v),
                          mean(df.v).alias('weighted_mean(v, w)'))
                     .sort('id')
                     .toPandas())

        assert_frame_equal(expected1, result1)
|
|
|
|
    def test_complex_groupby(self):
        """Grouped-aggregate UDFs work with varied groupby keys:
        expressions, empty groupby, columns mixed with expressions,
        python UDFs and scalar pandas UDFs."""
        df = self.data
        sum_udf = self.pandas_agg_sum_udf
        plus_one = self.python_plus_one
        plus_two = self.pandas_scalar_plus_two

        # groupby one expression
        result1 = df.groupby(df.v % 2).agg(sum_udf(df.v))
        expected1 = df.groupby(df.v % 2).agg(sum(df.v))

        # empty groupby
        result2 = df.groupby().agg(sum_udf(df.v))
        expected2 = df.groupby().agg(sum(df.v))

        # groupby one column and one sql expression
        result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)).orderBy(df.id, df.v % 2)
        expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)).orderBy(df.id, df.v % 2)

        # groupby one python UDF
        result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v))
        expected4 = df.groupby(plus_one(df.id)).agg(sum(df.v))

        # groupby one scalar pandas UDF
        result5 = df.groupby(plus_two(df.id)).agg(sum_udf(df.v)).sort('sum(v)')
        expected5 = df.groupby(plus_two(df.id)).agg(sum(df.v)).sort('sum(v)')

        # groupby one expression and one python UDF
        result6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum_udf(df.v))
        expected6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum(df.v))

        # groupby one expression and one scalar pandas UDF
        result7 = (df.groupby(df.v % 2, plus_two(df.id))
                   .agg(sum_udf(df.v)).sort(['sum(v)', 'plus_two(id)']))
        expected7 = (df.groupby(df.v % 2, plus_two(df.id))
                     .agg(sum(df.v)).sort(['sum(v)', 'plus_two(id)']))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())
        assert_frame_equal(expected2.toPandas(), result2.toPandas())
        assert_frame_equal(expected3.toPandas(), result3.toPandas())
        assert_frame_equal(expected4.toPandas(), result4.toPandas())
        assert_frame_equal(expected5.toPandas(), result5.toPandas())
        assert_frame_equal(expected6.toPandas(), result6.toPandas())
        assert_frame_equal(expected7.toPandas(), result7.toPandas())
|
|
|
|
    def test_complex_expressions(self):
        """Grouped-aggregate UDFs compose with SQL expressions, python
        UDFs and scalar pandas UDFs inside one agg(), and sequential
        groupby-aggregates chain correctly."""
        df = self.data
        plus_one = self.python_plus_one
        plus_two = self.pandas_scalar_plus_two
        sum_udf = self.pandas_agg_sum_udf

        # Test complex expressions with sql expression, python UDF and
        # group aggregate pandas UDF
        result1 = (df.withColumn('v1', plus_one(df.v))
                   .withColumn('v2', df.v + 2)
                   .groupby(df.id, df.v % 2)
                   .agg(sum_udf(col('v')),
                        sum_udf(col('v1') + 3),
                        sum_udf(col('v2')) + 5,
                        plus_one(sum_udf(col('v1'))),
                        sum_udf(plus_one(col('v2'))))
                   .sort(['id', '(v % 2)'])
                   .toPandas().sort_values(by=['id', '(v % 2)']))

        expected1 = (df.withColumn('v1', df.v + 1)
                     .withColumn('v2', df.v + 2)
                     .groupby(df.id, df.v % 2)
                     .agg(sum(col('v')),
                          sum(col('v1') + 3),
                          sum(col('v2')) + 5,
                          plus_one(sum(col('v1'))),
                          sum(plus_one(col('v2'))))
                     .sort(['id', '(v % 2)'])
                     .toPandas().sort_values(by=['id', '(v % 2)']))

        # Test complex expressions with sql expression, scalar pandas UDF
        # and group aggregate pandas UDF
        result2 = (df.withColumn('v1', plus_one(df.v))
                   .withColumn('v2', df.v + 2)
                   .groupby(df.id, df.v % 2)
                   .agg(sum_udf(col('v')),
                        sum_udf(col('v1') + 3),
                        sum_udf(col('v2')) + 5,
                        plus_two(sum_udf(col('v1'))),
                        sum_udf(plus_two(col('v2'))))
                   .sort(['id', '(v % 2)'])
                   .toPandas().sort_values(by=['id', '(v % 2)']))

        expected2 = (df.withColumn('v1', df.v + 1)
                     .withColumn('v2', df.v + 2)
                     .groupby(df.id, df.v % 2)
                     .agg(sum(col('v')),
                          sum(col('v1') + 3),
                          sum(col('v2')) + 5,
                          plus_two(sum(col('v1'))),
                          sum(plus_two(col('v2'))))
                     .sort(['id', '(v % 2)'])
                     .toPandas().sort_values(by=['id', '(v % 2)']))

        # Test sequential groupby aggregate
        result3 = (df.groupby('id')
                   .agg(sum_udf(df.v).alias('v'))
                   .groupby('id')
                   .agg(sum_udf(col('v')))
                   .sort('id')
                   .toPandas())

        expected3 = (df.groupby('id')
                     .agg(sum(df.v).alias('v'))
                     .groupby('id')
                     .agg(sum(col('v')))
                     .sort('id')
                     .toPandas())

        assert_frame_equal(expected1, result1)
        assert_frame_equal(expected2, result2)
        assert_frame_equal(expected3, result3)
|
|
|
|
    def test_retain_group_columns(self):
        """Grouped-aggregate UDFs honor spark.sql.retainGroupColumns=False
        (grouping columns are dropped from the result)."""
        with self.sql_conf({"spark.sql.retainGroupColumns": False}):
            df = self.data
            sum_udf = self.pandas_agg_sum_udf

            result1 = df.groupby(df.id).agg(sum_udf(df.v))
            expected1 = df.groupby(df.id).agg(sum(df.v))
            assert_frame_equal(expected1.toPandas(), result1.toPandas())
|
|
|
|
def test_array_type(self):
|
|
df = self.data
|
|
|
|
array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
|
|
result1 = df.groupby('id').agg(array_udf(df['v']).alias('v2'))
|
|
self.assertEquals(result1.first()['v2'], [1.0, 2.0])
|
|
|
|
def test_invalid_args(self):
|
|
df = self.data
|
|
plus_one = self.python_plus_one
|
|
mean_udf = self.pandas_agg_mean_udf
|
|
|
|
with QuietTest(self.sc):
|
|
with self.assertRaisesRegexp(
|
|
AnalysisException,
|
|
'nor.*aggregate function'):
|
|
df.groupby(df.id).agg(plus_one(df.v)).collect()
|
|
|
|
with QuietTest(self.sc):
|
|
with self.assertRaisesRegexp(
|
|
AnalysisException,
|
|
'aggregate function.*argument.*aggregate function'):
|
|
df.groupby(df.id).agg(mean_udf(mean_udf(df.v))).collect()
|
|
|
|
with QuietTest(self.sc):
|
|
with self.assertRaisesRegexp(
|
|
AnalysisException,
|
|
'mixture.*aggregate function.*group aggregate pandas UDF'):
|
|
df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
|
|
|
|
    def test_register_vectorized_udf_basic(self):
        """A GROUPED_AGG pandas UDF can be registered for SQL use and
        keeps its eval type through registration."""
        sum_pandas_udf = pandas_udf(
            lambda v: v.sum(), "integer", PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)

        self.assertEqual(sum_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
        group_agg_pandas_udf = self.spark.udf.register("sum_pandas_udf", sum_pandas_udf)
        self.assertEqual(group_agg_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
        # Group v2=0 -> 3 + 2 = 5; group v2=1 -> 1.
        q = "SELECT sum_pandas_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
        actual = sorted(map(lambda r: r[0], self.spark.sql(q).collect()))
        expected = [1, 5]
        self.assertEqual(actual, expected)
|
|
|
|
    def test_grouped_with_empty_partition(self):
        """Grouped aggregation works when some partitions are empty.

        num_parts exceeds the number of rows, guaranteeing at least one
        empty partition.
        """
        data = [Row(id=1, x=2), Row(id=1, x=3), Row(id=2, x=4)]
        # NOTE(review): the second expected Row uses field name 'x', not
        # 'sum'; Row equality compares values positionally so the test
        # still passes — confirm whether 'sum=4' was intended.
        expected = [Row(id=1, sum=5), Row(id=2, x=4)]
        num_parts = len(data) + 1
        df = self.spark.createDataFrame(self.sc.parallelize(data, numSlices=num_parts))

        f = pandas_udf(lambda x: x.sum(),
                       'int', PandasUDFType.GROUPED_AGG)

        result = df.groupBy('id').agg(f(df['x']).alias('sum')).collect()
        self.assertEqual(result, expected)
|
|
|
|
    def test_grouped_without_group_by_clause(self):
        """A grouped-aggregate UDF works as a global aggregate, both via
        the DataFrame API and in SQL without a GROUP BY clause."""
        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
        def max_udf(v):
            return v.max()

        df = self.spark.range(0, 100)
        self.spark.udf.register('max_udf', max_udf)

        with self.tempView("table"):
            df.createTempView('table')

            agg1 = df.agg(max_udf(df['id']))
            agg2 = self.spark.sql("select max_udf(id) from table")
            assert_frame_equal(agg1.toPandas(), agg2.toPandas())
|
|
|
|
def test_no_predicate_pushdown_through(self):
|
|
# SPARK-30921: We should not pushdown predicates of PythonUDFs through Aggregate.
|
|
import numpy as np
|
|
|
|
@pandas_udf('float', PandasUDFType.GROUPED_AGG)
|
|
def mean(x):
|
|
return np.mean(x)
|
|
|
|
df = self.spark.createDataFrame([
|
|
Row(id=1, foo=42), Row(id=2, foo=1), Row(id=2, foo=2)
|
|
])
|
|
|
|
agg = df.groupBy('id').agg(mean('foo').alias("mean"))
|
|
filtered = agg.filter(agg['mean'] > 40.0)
|
|
|
|
assert(filtered.collect()[0]["mean"] == 42.0)
|
|
|
|
|
|
if __name__ == "__main__":
    from pyspark.sql.tests.test_pandas_udf_grouped_agg import *  # noqa: F401

    # Keep only the import inside the try block so an unexpected
    # ImportError raised while constructing XMLTestRunner is not
    # silently swallowed; fall back to the default runner when the
    # xmlrunner package is absent.
    try:
        import xmlrunner
    except ImportError:
        testRunner = None
    else:
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    unittest.main(testRunner=testRunner, verbosity=2)
|