spark-instrumented-optimizer/python/pyspark/sql/tests/test_functions.py
Fokko Driesprong 9fcf0ea718 [SPARK-32319][PYSPARK] Disallow the use of unused imports
Disallow the use of unused imports:

- Unnecessarily increases the memory footprint of the application
- Moves imports that are only required by the docstring examples out of file scope and into the examples themselves. This keeps the files clean and makes each example more complete, since it now includes its own imports :)

```
fokkodriesprongFan spark % flake8 python | grep -i "imported but unused"
python/pyspark/cloudpickle.py:46:1: F401 'functools.partial' imported but unused
python/pyspark/cloudpickle.py:55:1: F401 'traceback' imported but unused
python/pyspark/heapq3.py:868:5: F401 '_heapq.*' imported but unused
python/pyspark/__init__.py:61:1: F401 'pyspark.version.__version__' imported but unused
python/pyspark/__init__.py:62:1: F401 'pyspark._globals._NoValue' imported but unused
python/pyspark/__init__.py:115:1: F401 'pyspark.sql.SQLContext' imported but unused
python/pyspark/__init__.py:115:1: F401 'pyspark.sql.HiveContext' imported but unused
python/pyspark/__init__.py:115:1: F401 'pyspark.sql.Row' imported but unused
python/pyspark/rdd.py:21:1: F401 're' imported but unused
python/pyspark/rdd.py:29:1: F401 'tempfile.NamedTemporaryFile' imported but unused
python/pyspark/mllib/regression.py:26:1: F401 'pyspark.mllib.linalg.SparseVector' imported but unused
python/pyspark/mllib/clustering.py:28:1: F401 'pyspark.mllib.linalg.SparseVector' imported but unused
python/pyspark/mllib/clustering.py:28:1: F401 'pyspark.mllib.linalg.DenseVector' imported but unused
python/pyspark/mllib/classification.py:26:1: F401 'pyspark.mllib.linalg.SparseVector' imported but unused
python/pyspark/mllib/feature.py:28:1: F401 'pyspark.mllib.linalg.DenseVector' imported but unused
python/pyspark/mllib/feature.py:28:1: F401 'pyspark.mllib.linalg.SparseVector' imported but unused
python/pyspark/mllib/feature.py:30:1: F401 'pyspark.mllib.regression.LabeledPoint' imported but unused
python/pyspark/mllib/tests/test_linalg.py:18:1: F401 'sys' imported but unused
python/pyspark/mllib/tests/test_linalg.py:642:5: F401 'pyspark.mllib.tests.test_linalg.*' imported but unused
python/pyspark/mllib/tests/test_feature.py:21:1: F401 'numpy.random' imported but unused
python/pyspark/mllib/tests/test_feature.py:21:1: F401 'numpy.exp' imported but unused
python/pyspark/mllib/tests/test_feature.py:23:1: F401 'pyspark.mllib.linalg.Vector' imported but unused
python/pyspark/mllib/tests/test_feature.py:23:1: F401 'pyspark.mllib.linalg.VectorUDT' imported but unused
python/pyspark/mllib/tests/test_feature.py:185:5: F401 'pyspark.mllib.tests.test_feature.*' imported but unused
python/pyspark/mllib/tests/test_util.py:97:5: F401 'pyspark.mllib.tests.test_util.*' imported but unused
python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.Vector' imported but unused
python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.SparseVector' imported but unused
python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.DenseVector' imported but unused
python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.VectorUDT' imported but unused
python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg._convert_to_vector' imported but unused
python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.DenseMatrix' imported but unused
python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.SparseMatrix' imported but unused
python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.MatrixUDT' imported but unused
python/pyspark/mllib/tests/test_stat.py:181:5: F401 'pyspark.mllib.tests.test_stat.*' imported but unused
python/pyspark/mllib/tests/test_streaming_algorithms.py:18:1: F401 'time.time' imported but unused
python/pyspark/mllib/tests/test_streaming_algorithms.py:18:1: F401 'time.sleep' imported but unused
python/pyspark/mllib/tests/test_streaming_algorithms.py:470:5: F401 'pyspark.mllib.tests.test_streaming_algorithms.*' imported but unused
python/pyspark/mllib/tests/test_algorithms.py:295:5: F401 'pyspark.mllib.tests.test_algorithms.*' imported but unused
python/pyspark/tests/test_serializers.py:90:13: F401 'xmlrunner' imported but unused
python/pyspark/tests/test_rdd.py:21:1: F401 'sys' imported but unused
python/pyspark/tests/test_rdd.py:29:1: F401 'pyspark.resource.ResourceProfile' imported but unused
python/pyspark/tests/test_rdd.py:885:5: F401 'pyspark.tests.test_rdd.*' imported but unused
python/pyspark/tests/test_readwrite.py:19:1: F401 'sys' imported but unused
python/pyspark/tests/test_readwrite.py:22:1: F401 'array.array' imported but unused
python/pyspark/tests/test_readwrite.py:309:5: F401 'pyspark.tests.test_readwrite.*' imported but unused
python/pyspark/tests/test_join.py:62:5: F401 'pyspark.tests.test_join.*' imported but unused
python/pyspark/tests/test_taskcontext.py:19:1: F401 'shutil' imported but unused
python/pyspark/tests/test_taskcontext.py:325:5: F401 'pyspark.tests.test_taskcontext.*' imported but unused
python/pyspark/tests/test_conf.py:36:5: F401 'pyspark.tests.test_conf.*' imported but unused
python/pyspark/tests/test_broadcast.py:148:5: F401 'pyspark.tests.test_broadcast.*' imported but unused
python/pyspark/tests/test_daemon.py:76:5: F401 'pyspark.tests.test_daemon.*' imported but unused
python/pyspark/tests/test_util.py:77:5: F401 'pyspark.tests.test_util.*' imported but unused
python/pyspark/tests/test_pin_thread.py:19:1: F401 'random' imported but unused
python/pyspark/tests/test_pin_thread.py:149:5: F401 'pyspark.tests.test_pin_thread.*' imported but unused
python/pyspark/tests/test_worker.py:19:1: F401 'sys' imported but unused
python/pyspark/tests/test_worker.py:26:5: F401 'resource' imported but unused
python/pyspark/tests/test_worker.py:203:5: F401 'pyspark.tests.test_worker.*' imported but unused
python/pyspark/tests/test_profiler.py:101:5: F401 'pyspark.tests.test_profiler.*' imported but unused
python/pyspark/tests/test_shuffle.py:18:1: F401 'sys' imported but unused
python/pyspark/tests/test_shuffle.py:171:5: F401 'pyspark.tests.test_shuffle.*' imported but unused
python/pyspark/tests/test_rddbarrier.py:43:5: F401 'pyspark.tests.test_rddbarrier.*' imported but unused
python/pyspark/tests/test_context.py:129:13: F401 'userlibrary.UserClass' imported but unused
python/pyspark/tests/test_context.py:140:13: F401 'userlib.UserClass' imported but unused
python/pyspark/tests/test_context.py:310:5: F401 'pyspark.tests.test_context.*' imported but unused
python/pyspark/tests/test_appsubmit.py:241:5: F401 'pyspark.tests.test_appsubmit.*' imported but unused
python/pyspark/streaming/dstream.py:18:1: F401 'sys' imported but unused
python/pyspark/streaming/tests/test_dstream.py:27:1: F401 'pyspark.RDD' imported but unused
python/pyspark/streaming/tests/test_dstream.py:647:5: F401 'pyspark.streaming.tests.test_dstream.*' imported but unused
python/pyspark/streaming/tests/test_kinesis.py:83:5: F401 'pyspark.streaming.tests.test_kinesis.*' imported but unused
python/pyspark/streaming/tests/test_listener.py:152:5: F401 'pyspark.streaming.tests.test_listener.*' imported but unused
python/pyspark/streaming/tests/test_context.py:178:5: F401 'pyspark.streaming.tests.test_context.*' imported but unused
python/pyspark/testing/utils.py:30:5: F401 'scipy.sparse' imported but unused
python/pyspark/testing/utils.py:36:5: F401 'numpy as np' imported but unused
python/pyspark/ml/regression.py:25:1: F401 'pyspark.ml.tree._TreeEnsembleParams' imported but unused
python/pyspark/ml/regression.py:25:1: F401 'pyspark.ml.tree._HasVarianceImpurity' imported but unused
python/pyspark/ml/regression.py:29:1: F401 'pyspark.ml.wrapper.JavaParams' imported but unused
python/pyspark/ml/util.py:19:1: F401 'sys' imported but unused
python/pyspark/ml/__init__.py:25:1: F401 'pyspark.ml.pipeline' imported but unused
python/pyspark/ml/pipeline.py:18:1: F401 'sys' imported but unused
python/pyspark/ml/stat.py:22:1: F401 'pyspark.ml.linalg.DenseMatrix' imported but unused
python/pyspark/ml/stat.py:22:1: F401 'pyspark.ml.linalg.Vectors' imported but unused
python/pyspark/ml/tests/test_training_summary.py:18:1: F401 'sys' imported but unused
python/pyspark/ml/tests/test_training_summary.py:364:5: F401 'pyspark.ml.tests.test_training_summary.*' imported but unused
python/pyspark/ml/tests/test_linalg.py:381:5: F401 'pyspark.ml.tests.test_linalg.*' imported but unused
python/pyspark/ml/tests/test_tuning.py:427:9: F401 'pyspark.sql.functions as F' imported but unused
python/pyspark/ml/tests/test_tuning.py:757:5: F401 'pyspark.ml.tests.test_tuning.*' imported but unused
python/pyspark/ml/tests/test_wrapper.py:120:5: F401 'pyspark.ml.tests.test_wrapper.*' imported but unused
python/pyspark/ml/tests/test_feature.py:19:1: F401 'sys' imported but unused
python/pyspark/ml/tests/test_feature.py:304:5: F401 'pyspark.ml.tests.test_feature.*' imported but unused
python/pyspark/ml/tests/test_image.py:19:1: F401 'py4j' imported but unused
python/pyspark/ml/tests/test_image.py:22:1: F401 'pyspark.testing.mlutils.PySparkTestCase' imported but unused
python/pyspark/ml/tests/test_image.py:71:5: F401 'pyspark.ml.tests.test_image.*' imported but unused
python/pyspark/ml/tests/test_persistence.py:456:5: F401 'pyspark.ml.tests.test_persistence.*' imported but unused
python/pyspark/ml/tests/test_evaluation.py:56:5: F401 'pyspark.ml.tests.test_evaluation.*' imported but unused
python/pyspark/ml/tests/test_stat.py:43:5: F401 'pyspark.ml.tests.test_stat.*' imported but unused
python/pyspark/ml/tests/test_base.py:70:5: F401 'pyspark.ml.tests.test_base.*' imported but unused
python/pyspark/ml/tests/test_param.py:20:1: F401 'sys' imported but unused
python/pyspark/ml/tests/test_param.py:375:5: F401 'pyspark.ml.tests.test_param.*' imported but unused
python/pyspark/ml/tests/test_pipeline.py:62:5: F401 'pyspark.ml.tests.test_pipeline.*' imported but unused
python/pyspark/ml/tests/test_algorithms.py:333:5: F401 'pyspark.ml.tests.test_algorithms.*' imported but unused
python/pyspark/ml/param/__init__.py:18:1: F401 'sys' imported but unused
python/pyspark/resource/tests/test_resources.py:17:1: F401 'random' imported but unused
python/pyspark/resource/tests/test_resources.py:20:1: F401 'pyspark.resource.ResourceProfile' imported but unused
python/pyspark/resource/tests/test_resources.py:75:5: F401 'pyspark.resource.tests.test_resources.*' imported but unused
python/pyspark/sql/functions.py:32:1: F401 'pyspark.sql.udf.UserDefinedFunction' imported but unused
python/pyspark/sql/functions.py:34:1: F401 'pyspark.sql.pandas.functions.pandas_udf' imported but unused
python/pyspark/sql/session.py:30:1: F401 'pyspark.sql.types.Row' imported but unused
python/pyspark/sql/session.py:30:1: F401 'pyspark.sql.types.StringType' imported but unused
python/pyspark/sql/readwriter.py:1084:5: F401 'pyspark.sql.Row' imported but unused
python/pyspark/sql/context.py:26:1: F401 'pyspark.sql.types.IntegerType' imported but unused
python/pyspark/sql/context.py:26:1: F401 'pyspark.sql.types.Row' imported but unused
python/pyspark/sql/context.py:26:1: F401 'pyspark.sql.types.StringType' imported but unused
python/pyspark/sql/context.py:27:1: F401 'pyspark.sql.udf.UDFRegistration' imported but unused
python/pyspark/sql/streaming.py:1212:5: F401 'pyspark.sql.Row' imported but unused
python/pyspark/sql/tests/test_utils.py:55:5: F401 'pyspark.sql.tests.test_utils.*' imported but unused
python/pyspark/sql/tests/test_pandas_map.py:18:1: F401 'sys' imported but unused
python/pyspark/sql/tests/test_pandas_map.py:22:1: F401 'pyspark.sql.functions.pandas_udf' imported but unused
python/pyspark/sql/tests/test_pandas_map.py:22:1: F401 'pyspark.sql.functions.PandasUDFType' imported but unused
python/pyspark/sql/tests/test_pandas_map.py:119:5: F401 'pyspark.sql.tests.test_pandas_map.*' imported but unused
python/pyspark/sql/tests/test_catalog.py:193:5: F401 'pyspark.sql.tests.test_catalog.*' imported but unused
python/pyspark/sql/tests/test_group.py:39:5: F401 'pyspark.sql.tests.test_group.*' imported but unused
python/pyspark/sql/tests/test_session.py:361:5: F401 'pyspark.sql.tests.test_session.*' imported but unused
python/pyspark/sql/tests/test_conf.py:49:5: F401 'pyspark.sql.tests.test_conf.*' imported but unused
python/pyspark/sql/tests/test_pandas_cogrouped_map.py:19:1: F401 'sys' imported but unused
python/pyspark/sql/tests/test_pandas_cogrouped_map.py:21:1: F401 'pyspark.sql.functions.sum' imported but unused
python/pyspark/sql/tests/test_pandas_cogrouped_map.py:21:1: F401 'pyspark.sql.functions.PandasUDFType' imported but unused
python/pyspark/sql/tests/test_pandas_cogrouped_map.py:29:5: F401 'pandas.util.testing.assert_series_equal' imported but unused
python/pyspark/sql/tests/test_pandas_cogrouped_map.py:32:5: F401 'pyarrow as pa' imported but unused
python/pyspark/sql/tests/test_pandas_cogrouped_map.py:248:5: F401 'pyspark.sql.tests.test_pandas_cogrouped_map.*' imported but unused
python/pyspark/sql/tests/test_udf.py:24:1: F401 'py4j' imported but unused
python/pyspark/sql/tests/test_pandas_udf_typehints.py:246:5: F401 'pyspark.sql.tests.test_pandas_udf_typehints.*' imported but unused
python/pyspark/sql/tests/test_functions.py:19:1: F401 'sys' imported but unused
python/pyspark/sql/tests/test_functions.py:362:9: F401 'pyspark.sql.functions.exists' imported but unused
python/pyspark/sql/tests/test_functions.py:387:5: F401 'pyspark.sql.tests.test_functions.*' imported but unused
python/pyspark/sql/tests/test_pandas_udf_scalar.py:21:1: F401 'sys' imported but unused
python/pyspark/sql/tests/test_pandas_udf_scalar.py:45:5: F401 'pyarrow as pa' imported but unused
python/pyspark/sql/tests/test_pandas_udf_window.py:355:5: F401 'pyspark.sql.tests.test_pandas_udf_window.*' imported but unused
python/pyspark/sql/tests/test_arrow.py:38:5: F401 'pyarrow as pa' imported but unused
python/pyspark/sql/tests/test_pandas_grouped_map.py:20:1: F401 'sys' imported but unused
python/pyspark/sql/tests/test_pandas_grouped_map.py:38:5: F401 'pyarrow as pa' imported but unused
python/pyspark/sql/tests/test_dataframe.py:382:9: F401 'pyspark.sql.DataFrame' imported but unused
python/pyspark/sql/avro/functions.py:125:5: F401 'pyspark.sql.Row' imported but unused
python/pyspark/sql/pandas/functions.py:19:1: F401 'sys' imported but unused
```

After:
```
fokkodriesprongFan spark % flake8 python | grep -i "imported but unused"
fokkodriesprongFan spark %
```

### What changes were proposed in this pull request?

Removing unused imports from the Python files to keep everything nice and tidy.

### Why are the changes needed?

Cleaning up imports that aren't used, and suppressing the unused-import warning (`# noqa: F401`) for imports that are kept so other modules can reference them, preserving backward compatibility.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Adding the rule to the existing Flake8 checks.

Closes #29121 from Fokko/SPARK-32319.

Authored-by: Fokko Driesprong <fokko@apache.org>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
2020-08-08 08:51:57 -07:00

404 lines
17 KiB
Python

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
from itertools import chain
import re
from pyspark.sql import Row
from pyspark.sql.functions import udf, input_file_name, col, percentile_approx, lit
from pyspark.testing.sqlutils import ReusedSQLTestCase
class FunctionsTests(ReusedSQLTestCase):
    """Tests for ``pyspark.sql.functions`` and the ``DataFrame.stat`` helpers.

    ``self.spark``, ``self.sc`` and ``self.df`` are provided by ``ReusedSQLTestCase``;
    ``test_rand_functions`` selects a ``key`` column from ``self.df``.
    """

    def test_explode(self):
        from pyspark.sql.functions import explode, explode_outer, posexplode_outer
        d = [
            Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"}),
            Row(a=1, intlist=[], mapfield={}),
            Row(a=1, intlist=None, mapfield=None),
        ]
        rdd = self.sc.parallelize(d)
        data = self.spark.createDataFrame(rdd)

        result = data.select(explode(data.intlist).alias("a")).select("a").collect()
        self.assertEqual(result[0][0], 1)
        self.assertEqual(result[1][0], 2)
        self.assertEqual(result[2][0], 3)

        result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect()
        self.assertEqual(result[0][0], "a")
        self.assertEqual(result[0][1], "b")

        # The *_outer variants keep empty / null collections as a row of nulls.
        result = [tuple(x) for x in data.select(posexplode_outer("intlist")).collect()]
        self.assertEqual(result, [(0, 1), (1, 2), (2, 3), (None, None), (None, None)])

        result = [tuple(x) for x in data.select(posexplode_outer("mapfield")).collect()]
        self.assertEqual(result, [(0, 'a', 'b'), (None, None, None), (None, None, None)])

        result = [x[0] for x in data.select(explode_outer("intlist")).collect()]
        self.assertEqual(result, [1, 2, 3, None, None])

        result = [tuple(x) for x in data.select(explode_outer("mapfield")).collect()]
        self.assertEqual(result, [('a', 'b'), (None, None), (None, None)])

    def test_basic_functions(self):
        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
        df = self.spark.read.json(rdd)
        df.count()
        df.collect()
        df.schema  # evaluated only to exercise the property; the value is discarded

        # persist / unpersist / cache and the is_cached flag
        self.assertFalse(df.is_cached)
        df.persist()
        df.unpersist(True)
        df.cache()
        self.assertTrue(df.is_cached)
        self.assertEqual(2, df.count())

        with self.tempView("temp"):
            df.createOrReplaceTempView("temp")
            df = self.spark.sql("select foo from temp")
            df.count()
            df.collect()

    def test_corr(self):
        import math
        df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
        corr = df.stat.corr(u"a", "b")
        self.assertTrue(abs(corr - 0.95734012) < 1e-6)

    def test_sampleby(self):
        df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(100)]).toDF()
        sampled = df.stat.sampleBy(u"b", fractions={0: 0.5, 1: 0.5}, seed=0)
        # The count is deterministic because the seed is fixed.
        self.assertEqual(sampled.count(), 35)

    def test_cov(self):
        df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
        cov = df.stat.cov(u"a", "b")
        self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)

    def test_crosstab(self):
        # For i in 1..6 the pairs (i % 3, i % 2) are (1,1), (2,0), (0,1), (1,0), (2,1), (0,0):
        # every combination occurs exactly once, so every cell of the table must be 1.
        df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
        ct = df.stat.crosstab(u"a", "b").collect()
        ct = sorted(ct, key=lambda x: x[0])
        self.assertEqual(len(ct), 3)
        for i, row in enumerate(ct):
            self.assertEqual(row[0], str(i))
            # assertTrue(row[1], 1) only checked truthiness (1 was the failure message);
            # compare the counts for real.
            self.assertEqual(row[1], 1)
            self.assertEqual(row[2], 1)

    def test_math_functions(self):
        df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
        from pyspark.sql import functions
        import math

        def get_values(rows):
            # First column of each collected Row.
            return [row[0] for row in rows]

        def assert_close(a, b):
            # Compare the expected Python values `a` element-wise against the first column of
            # the collected rows `b`. This helper must assert by itself: callers ignore its
            # return value, so returning a bool (as before) let every mismatch pass silently.
            c = get_values(b)
            self.assertEqual(len(a), len(c), "%s != %s" % (a, c))
            diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)]
            self.assertEqual(sum(diff), len(a), "%s != %s" % (a, c))

        assert_close([math.cos(i) for i in range(10)],
                     df.select(functions.cos(df.a)).collect())
        assert_close([math.cos(i) for i in range(10)],
                     df.select(functions.cos("a")).collect())
        assert_close([math.sin(i) for i in range(10)],
                     df.select(functions.sin(df.a)).collect())
        assert_close([math.sin(i) for i in range(10)],
                     df.select(functions.sin(df['a'])).collect())
        assert_close([math.pow(i, 2 * i) for i in range(10)],
                     df.select(functions.pow(df.a, df.b)).collect())
        assert_close([math.pow(i, 2) for i in range(10)],
                     df.select(functions.pow(df.a, 2)).collect())
        assert_close([math.pow(i, 2) for i in range(10)],
                     df.select(functions.pow(df.a, 2.0)).collect())
        assert_close([math.hypot(i, 2 * i) for i in range(10)],
                     df.select(functions.hypot(df.a, df.b)).collect())
        assert_close([math.hypot(i, 2 * i) for i in range(10)],
                     df.select(functions.hypot("a", u"b")).collect())
        assert_close([math.hypot(i, 2) for i in range(10)],
                     df.select(functions.hypot("a", 2)).collect())
        assert_close([math.hypot(i, 2) for i in range(10)],
                     df.select(functions.hypot(df.a, 2)).collect())

    def test_rand_functions(self):
        df = self.df
        from pyspark.sql import functions
        rnd = df.select('key', functions.rand()).collect()
        for row in rnd:
            assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1]
        rndn = df.select('key', functions.randn(5)).collect()
        for row in rndn:
            assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]

        # If the specified seed is 0, we should use it.
        # https://issues.apache.org/jira/browse/SPARK-9691
        rnd1 = df.select('key', functions.rand(0)).collect()
        rnd2 = df.select('key', functions.rand(0)).collect()
        self.assertEqual(sorted(rnd1), sorted(rnd2))

        rndn1 = df.select('key', functions.randn(0)).collect()
        rndn2 = df.select('key', functions.randn(0)).collect()
        self.assertEqual(sorted(rndn1), sorted(rndn2))

    def test_string_functions(self):
        from pyspark.sql import functions
        from pyspark.sql.functions import _string_functions
        df = self.spark.createDataFrame([['nick']], schema=['name'])
        # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12.
        self.assertRaisesRegex(
            TypeError,
            "must be the same type",
            lambda: df.select(col('name').substr(0, lit(1))))

        # Each string function must accept a column name and a Column interchangeably.
        for name in _string_functions.keys():
            self.assertEqual(
                df.select(getattr(functions, name)("name")).first()[0],
                df.select(getattr(functions, name)(col("name"))).first()[0])

    def test_array_contains_function(self):
        from pyspark.sql.functions import array_contains

        df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ['data'])
        actual = df.select(array_contains(df.data, "1").alias('b')).collect()
        self.assertEqual([Row(b=True), Row(b=False)], actual)

    def test_between_function(self):
        df = self.sc.parallelize([
            Row(a=1, b=2, c=3),
            Row(a=2, b=1, c=3),
            Row(a=4, b=1, c=4)]).toDF()
        self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)],
                         df.filter(df.a.between(df.b, df.c)).collect())

    def test_dayofweek(self):
        from pyspark.sql.functions import dayofweek
        dt = datetime.datetime(2017, 11, 6)
        df = self.spark.createDataFrame([Row(date=dt)])
        row = df.select(dayofweek(df.date)).first()
        self.assertEqual(row[0], 2)

    def test_expr(self):
        from pyspark.sql import functions
        row = Row(a="length string", b=75)
        df = self.spark.createDataFrame([row])
        result = df.select(functions.expr("length(a)")).collect()[0].asDict()
        self.assertEqual(13, result["length(a)"])

    # add test for SPARK-10577 (test broadcast join hint)
    def test_functions_broadcast(self):
        from pyspark.sql.functions import broadcast

        df1 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
        df2 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))

        # equijoin - should be converted into broadcast join
        plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan()
        self.assertEqual(1, plan1.toString().count("BroadcastHashJoin"))

        # no join key -- should not be a broadcast join
        plan2 = df1.crossJoin(broadcast(df2))._jdf.queryExecution().executedPlan()
        self.assertEqual(0, plan2.toString().count("BroadcastHashJoin"))

        # planner should not crash without a join
        broadcast(df1)._jdf.queryExecution().executedPlan()

    def test_first_last_ignorenulls(self):
        from pyspark.sql import functions
        df = self.spark.range(0, 100)
        # Null out every id divisible by 3 (including the first, 0, and the last, 99).
        df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id"))
        df3 = df2.select(functions.first(df2.id, False).alias('a'),
                         functions.first(df2.id, True).alias('b'),
                         functions.last(df2.id, False).alias('c'),
                         functions.last(df2.id, True).alias('d'))
        self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect())

    def test_approxQuantile(self):
        df = self.sc.parallelize([Row(a=i, b=i + 10) for i in range(10)]).toDF()
        # A single column name yields one list of quantiles.
        for f in ["a", u"a"]:
            aq = df.stat.approxQuantile(f, [0.1, 0.5, 0.9], 0.1)
            self.assertTrue(isinstance(aq, list))
            self.assertEqual(len(aq), 3)
            self.assertTrue(all(isinstance(q, float) for q in aq))

        # A list of column names yields one list of quantiles per column.
        aqs = df.stat.approxQuantile(["a", u"b"], [0.1, 0.5, 0.9], 0.1)
        self.assertTrue(isinstance(aqs, list))
        self.assertEqual(len(aqs), 2)
        self.assertTrue(isinstance(aqs[0], list))
        self.assertEqual(len(aqs[0]), 3)
        self.assertTrue(all(isinstance(q, float) for q in aqs[0]))
        self.assertTrue(isinstance(aqs[1], list))
        self.assertEqual(len(aqs[1]), 3)
        self.assertTrue(all(isinstance(q, float) for q in aqs[1]))

        # A tuple of column names behaves like a list.
        aqt = df.stat.approxQuantile((u"a", "b"), [0.1, 0.5, 0.9], 0.1)
        self.assertTrue(isinstance(aqt, list))
        self.assertEqual(len(aqt), 2)
        self.assertTrue(isinstance(aqt[0], list))
        self.assertEqual(len(aqt[0]), 3)
        self.assertTrue(all(isinstance(q, float) for q in aqt[0]))
        self.assertTrue(isinstance(aqt[1], list))
        self.assertEqual(len(aqt[1]), 3)
        self.assertTrue(all(isinstance(q, float) for q in aqt[1]))

        # Non-string column identifiers are rejected.
        self.assertRaises(ValueError, lambda: df.stat.approxQuantile(123, [0.1, 0.9], 0.1))
        self.assertRaises(ValueError, lambda: df.stat.approxQuantile(("a", 123), [0.1, 0.9], 0.1))
        self.assertRaises(ValueError, lambda: df.stat.approxQuantile(["a", 123], [0.1, 0.9], 0.1))

    def test_sort_with_nulls_order(self):
        from pyspark.sql import functions

        df = self.spark.createDataFrame(
            [('Tom', 80), (None, 60), ('Alice', 50)], ["name", "height"])
        # assertEquals is a deprecated alias that was removed in Python 3.12.
        self.assertEqual(
            df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect(),
            [Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')])
        self.assertEqual(
            df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect(),
            [Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)])
        self.assertEqual(
            df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect(),
            [Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')])
        self.assertEqual(
            df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(),
            [Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)])

    def test_input_file_name_reset_for_rdd(self):
        rdd = self.sc.textFile('python/test_support/hello/hello.txt').map(lambda x: {'data': x})
        df = self.spark.createDataFrame(rdd, "data STRING")
        df.select(input_file_name().alias('file')).collect()

        non_file_df = self.spark.range(100).select(input_file_name())

        results = non_file_df.collect()
        self.assertEqual(len(results), 100)

        # [SPARK-24605]: if everything was properly reset after the last job, this should return
        # empty string rather than the file read in the last job.
        for result in results:
            self.assertEqual(result[0], '')

    def test_slice(self):
        from pyspark.sql.functions import slice

        # Literal ints and lit() Columns must produce the same slices.
        df = self.spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x'])

        self.assertEqual(
            df.select(slice(df.x, 2, 2).alias("sliced")).collect(),
            df.select(slice(df.x, lit(2), lit(2)).alias("sliced")).collect(),
        )

    def test_array_repeat(self):
        from pyspark.sql.functions import array_repeat

        # A literal int count and a lit() Column count must produce the same arrays.
        df = self.spark.range(1)

        self.assertEqual(
            df.select(array_repeat("id", 3)).toDF("val").collect(),
            df.select(array_repeat("id", lit(3))).toDF("val").collect(),
        )

    def test_input_file_name_udf(self):
        df = self.spark.read.text('python/test_support/hello/hello.txt')
        df = df.select(udf(lambda x: x)("value"), input_file_name().alias('file'))
        file_name = df.collect()[0].file
        self.assertIn("python/test_support/hello/hello.txt", file_name)

    def test_overlay(self):
        from pyspark.sql.functions import overlay

        # Pull the "overlay(...)" expression text out of each Column's string form and
        # check that omitted lengths default to -1.
        actual = list(chain.from_iterable([
            re.findall("(overlay\\(.*\\))", str(x)) for x in [
                overlay(col("foo"), col("bar"), 1),
                overlay("x", "y", 3),
                overlay(col("x"), col("y"), 1, 3),
                overlay("x", "y", 2, 5),
                overlay("x", "y", lit(11)),
                overlay("x", "y", lit(2), lit(5)),
            ]
        ]))

        expected = [
            "overlay(foo, bar, 1, -1)",
            "overlay(x, y, 3, -1)",
            "overlay(x, y, 1, 3)",
            "overlay(x, y, 2, 5)",
            "overlay(x, y, 11, -1)",
            "overlay(x, y, 2, 5)",
        ]

        self.assertListEqual(actual, expected)

    def test_percentile_approx(self):
        # Pull the "percentile_approx(...)" expression text out of each Column's string form;
        # an omitted accuracy shows up as 10000 and list/tuple percentages as array(...).
        actual = list(chain.from_iterable([
            re.findall("(percentile_approx\\(.*\\))", str(x)) for x in [
                percentile_approx(col("foo"), lit(0.5)),
                percentile_approx(col("bar"), 0.25, 42),
                percentile_approx(col("bar"), [0.25, 0.5, 0.75]),
                percentile_approx(col("foo"), (0.05, 0.95), 100),
                percentile_approx("foo", 0.5),
                percentile_approx("bar", [0.1, 0.9], lit(10)),
            ]
        ]))

        expected = [
            "percentile_approx(foo, 0.5, 10000)",
            "percentile_approx(bar, 0.25, 42)",
            "percentile_approx(bar, array(0.25, 0.5, 0.75), 10000)",
            "percentile_approx(foo, array(0.05, 0.95), 100)",
            "percentile_approx(foo, 0.5, 10000)",
            "percentile_approx(bar, array(0.1, 0.9), 10)"
        ]

        self.assertListEqual(actual, expected)

    def test_higher_order_function_failures(self):
        from pyspark.sql.functions import transform

        # Should fail with varargs
        with self.assertRaises(ValueError):
            transform(col("foo"), lambda *x: lit(1))

        # Should fail with kwargs
        with self.assertRaises(ValueError):
            transform(col("foo"), lambda **x: lit(1))

        # Should fail with nullary function
        with self.assertRaises(ValueError):
            transform(col("foo"), lambda: lit(1))

        # Should fail with quaternary function
        with self.assertRaises(ValueError):
            transform(col("foo"), lambda x1, x2, x3, x4: lit(1))

        # Should fail if function doesn't return Column
        with self.assertRaises(ValueError):
            transform(col("foo"), lambda x: 1)
if __name__ == "__main__":
    import unittest
    # Re-bind the test classes from the package module path; presumably so their
    # __module__ is importable (not "__main__") when Spark workers need it — confirm.
    from pyspark.sql.tests.test_functions import *  # noqa: F401

    # Write XML reports when the optional xmlrunner package is installed;
    # otherwise fall back to unittest's default text runner.
    testRunner = None
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass

    unittest.main(testRunner=testRunner, verbosity=2)