9fcf0ea718
Disallow the use of unused imports:

- Unnecessarily increases the memory footprint of the application
- Removes the imports that are required for the examples in the docstring from the file-scope to the example itself (see the sketch below). This keeps the files themselves clean, and gives a more complete example as it also includes the imports :)

```
fokkodriesprongFan spark % flake8 python | grep -i "imported but unused"
python/pyspark/cloudpickle.py:46:1: F401 'functools.partial' imported but unused
python/pyspark/cloudpickle.py:55:1: F401 'traceback' imported but unused
python/pyspark/heapq3.py:868:5: F401 '_heapq.*' imported but unused
python/pyspark/__init__.py:61:1: F401 'pyspark.version.__version__' imported but unused
python/pyspark/__init__.py:62:1: F401 'pyspark._globals._NoValue' imported but unused
python/pyspark/__init__.py:115:1: F401 'pyspark.sql.SQLContext' imported but unused
python/pyspark/__init__.py:115:1: F401 'pyspark.sql.HiveContext' imported but unused
python/pyspark/__init__.py:115:1: F401 'pyspark.sql.Row' imported but unused
python/pyspark/rdd.py:21:1: F401 're' imported but unused
python/pyspark/rdd.py:29:1: F401 'tempfile.NamedTemporaryFile' imported but unused
python/pyspark/mllib/regression.py:26:1: F401 'pyspark.mllib.linalg.SparseVector' imported but unused
python/pyspark/mllib/clustering.py:28:1: F401 'pyspark.mllib.linalg.SparseVector' imported but unused
python/pyspark/mllib/clustering.py:28:1: F401 'pyspark.mllib.linalg.DenseVector' imported but unused
python/pyspark/mllib/classification.py:26:1: F401 'pyspark.mllib.linalg.SparseVector' imported but unused
python/pyspark/mllib/feature.py:28:1: F401 'pyspark.mllib.linalg.DenseVector' imported but unused
python/pyspark/mllib/feature.py:28:1: F401 'pyspark.mllib.linalg.SparseVector' imported but unused
python/pyspark/mllib/feature.py:30:1: F401 'pyspark.mllib.regression.LabeledPoint' imported but unused
python/pyspark/mllib/tests/test_linalg.py:18:1: F401 'sys' imported but unused
python/pyspark/mllib/tests/test_linalg.py:642:5: F401 'pyspark.mllib.tests.test_linalg.*' imported but unused
python/pyspark/mllib/tests/test_feature.py:21:1: F401 'numpy.random' imported but unused
python/pyspark/mllib/tests/test_feature.py:21:1: F401 'numpy.exp' imported but unused
python/pyspark/mllib/tests/test_feature.py:23:1: F401 'pyspark.mllib.linalg.Vector' imported but unused
python/pyspark/mllib/tests/test_feature.py:23:1: F401 'pyspark.mllib.linalg.VectorUDT' imported but unused
python/pyspark/mllib/tests/test_feature.py:185:5: F401 'pyspark.mllib.tests.test_feature.*' imported but unused
python/pyspark/mllib/tests/test_util.py:97:5: F401 'pyspark.mllib.tests.test_util.*' imported but unused
python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.Vector' imported but unused
python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.SparseVector' imported but unused
python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.DenseVector' imported but unused
python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.VectorUDT' imported but unused
python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg._convert_to_vector' imported but unused
python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.DenseMatrix' imported but unused
python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.SparseMatrix' imported but unused
python/pyspark/mllib/tests/test_stat.py:23:1: F401 'pyspark.mllib.linalg.MatrixUDT' imported but unused
python/pyspark/mllib/tests/test_stat.py:181:5: F401 'pyspark.mllib.tests.test_stat.*' imported but unused
python/pyspark/mllib/tests/test_streaming_algorithms.py:18:1: F401 'time.time' imported but unused
python/pyspark/mllib/tests/test_streaming_algorithms.py:18:1: F401 'time.sleep' imported but unused
python/pyspark/mllib/tests/test_streaming_algorithms.py:470:5: F401 'pyspark.mllib.tests.test_streaming_algorithms.*' imported but unused
python/pyspark/mllib/tests/test_algorithms.py:295:5: F401 'pyspark.mllib.tests.test_algorithms.*' imported but unused
python/pyspark/tests/test_serializers.py:90:13: F401 'xmlrunner' imported but unused
python/pyspark/tests/test_rdd.py:21:1: F401 'sys' imported but unused
python/pyspark/tests/test_rdd.py:29:1: F401 'pyspark.resource.ResourceProfile' imported but unused
python/pyspark/tests/test_rdd.py:885:5: F401 'pyspark.tests.test_rdd.*' imported but unused
python/pyspark/tests/test_readwrite.py:19:1: F401 'sys' imported but unused
python/pyspark/tests/test_readwrite.py:22:1: F401 'array.array' imported but unused
python/pyspark/tests/test_readwrite.py:309:5: F401 'pyspark.tests.test_readwrite.*' imported but unused
python/pyspark/tests/test_join.py:62:5: F401 'pyspark.tests.test_join.*' imported but unused
python/pyspark/tests/test_taskcontext.py:19:1: F401 'shutil' imported but unused
python/pyspark/tests/test_taskcontext.py:325:5: F401 'pyspark.tests.test_taskcontext.*' imported but unused
python/pyspark/tests/test_conf.py:36:5: F401 'pyspark.tests.test_conf.*' imported but unused
python/pyspark/tests/test_broadcast.py:148:5: F401 'pyspark.tests.test_broadcast.*' imported but unused
python/pyspark/tests/test_daemon.py:76:5: F401 'pyspark.tests.test_daemon.*' imported but unused
python/pyspark/tests/test_util.py:77:5: F401 'pyspark.tests.test_util.*' imported but unused
python/pyspark/tests/test_pin_thread.py:19:1: F401 'random' imported but unused
python/pyspark/tests/test_pin_thread.py:149:5: F401 'pyspark.tests.test_pin_thread.*' imported but unused
python/pyspark/tests/test_worker.py:19:1: F401 'sys' imported but unused
python/pyspark/tests/test_worker.py:26:5: F401 'resource' imported but unused
python/pyspark/tests/test_worker.py:203:5: F401 'pyspark.tests.test_worker.*' imported but unused
python/pyspark/tests/test_profiler.py:101:5: F401 'pyspark.tests.test_profiler.*' imported but unused
python/pyspark/tests/test_shuffle.py:18:1: F401 'sys' imported but unused
python/pyspark/tests/test_shuffle.py:171:5: F401 'pyspark.tests.test_shuffle.*' imported but unused
python/pyspark/tests/test_rddbarrier.py:43:5: F401 'pyspark.tests.test_rddbarrier.*' imported but unused
python/pyspark/tests/test_context.py:129:13: F401 'userlibrary.UserClass' imported but unused
python/pyspark/tests/test_context.py:140:13: F401 'userlib.UserClass' imported but unused
python/pyspark/tests/test_context.py:310:5: F401 'pyspark.tests.test_context.*' imported but unused
python/pyspark/tests/test_appsubmit.py:241:5: F401 'pyspark.tests.test_appsubmit.*' imported but unused
python/pyspark/streaming/dstream.py:18:1: F401 'sys' imported but unused
python/pyspark/streaming/tests/test_dstream.py:27:1: F401 'pyspark.RDD' imported but unused
python/pyspark/streaming/tests/test_dstream.py:647:5: F401 'pyspark.streaming.tests.test_dstream.*' imported but unused
python/pyspark/streaming/tests/test_kinesis.py:83:5: F401 'pyspark.streaming.tests.test_kinesis.*' imported but unused
python/pyspark/streaming/tests/test_listener.py:152:5: F401 'pyspark.streaming.tests.test_listener.*' imported but unused
python/pyspark/streaming/tests/test_context.py:178:5: F401 'pyspark.streaming.tests.test_context.*' imported but unused
python/pyspark/testing/utils.py:30:5: F401 'scipy.sparse' imported but unused
python/pyspark/testing/utils.py:36:5: F401 'numpy as np' imported but unused
python/pyspark/ml/regression.py:25:1: F401 'pyspark.ml.tree._TreeEnsembleParams' imported but unused
python/pyspark/ml/regression.py:25:1: F401 'pyspark.ml.tree._HasVarianceImpurity' imported but unused
python/pyspark/ml/regression.py:29:1: F401 'pyspark.ml.wrapper.JavaParams' imported but unused
python/pyspark/ml/util.py:19:1: F401 'sys' imported but unused
python/pyspark/ml/__init__.py:25:1: F401 'pyspark.ml.pipeline' imported but unused
python/pyspark/ml/pipeline.py:18:1: F401 'sys' imported but unused
python/pyspark/ml/stat.py:22:1: F401 'pyspark.ml.linalg.DenseMatrix' imported but unused
python/pyspark/ml/stat.py:22:1: F401 'pyspark.ml.linalg.Vectors' imported but unused
python/pyspark/ml/tests/test_training_summary.py:18:1: F401 'sys' imported but unused
python/pyspark/ml/tests/test_training_summary.py:364:5: F401 'pyspark.ml.tests.test_training_summary.*' imported but unused
python/pyspark/ml/tests/test_linalg.py:381:5: F401 'pyspark.ml.tests.test_linalg.*' imported but unused
python/pyspark/ml/tests/test_tuning.py:427:9: F401 'pyspark.sql.functions as F' imported but unused
python/pyspark/ml/tests/test_tuning.py:757:5: F401 'pyspark.ml.tests.test_tuning.*' imported but unused
python/pyspark/ml/tests/test_wrapper.py:120:5: F401 'pyspark.ml.tests.test_wrapper.*' imported but unused
python/pyspark/ml/tests/test_feature.py:19:1: F401 'sys' imported but unused
python/pyspark/ml/tests/test_feature.py:304:5: F401 'pyspark.ml.tests.test_feature.*' imported but unused
python/pyspark/ml/tests/test_image.py:19:1: F401 'py4j' imported but unused
python/pyspark/ml/tests/test_image.py:22:1: F401 'pyspark.testing.mlutils.PySparkTestCase' imported but unused
python/pyspark/ml/tests/test_image.py:71:5: F401 'pyspark.ml.tests.test_image.*' imported but unused
python/pyspark/ml/tests/test_persistence.py:456:5: F401 'pyspark.ml.tests.test_persistence.*' imported but unused
python/pyspark/ml/tests/test_evaluation.py:56:5: F401 'pyspark.ml.tests.test_evaluation.*' imported but unused
python/pyspark/ml/tests/test_stat.py:43:5: F401 'pyspark.ml.tests.test_stat.*' imported but unused
python/pyspark/ml/tests/test_base.py:70:5: F401 'pyspark.ml.tests.test_base.*' imported but unused
python/pyspark/ml/tests/test_param.py:20:1: F401 'sys' imported but unused
python/pyspark/ml/tests/test_param.py:375:5: F401 'pyspark.ml.tests.test_param.*' imported but unused
python/pyspark/ml/tests/test_pipeline.py:62:5: F401 'pyspark.ml.tests.test_pipeline.*' imported but unused
python/pyspark/ml/tests/test_algorithms.py:333:5: F401 'pyspark.ml.tests.test_algorithms.*' imported but unused
python/pyspark/ml/param/__init__.py:18:1: F401 'sys' imported but unused
python/pyspark/resource/tests/test_resources.py:17:1: F401 'random' imported but unused
python/pyspark/resource/tests/test_resources.py:20:1: F401 'pyspark.resource.ResourceProfile' imported but unused
python/pyspark/resource/tests/test_resources.py:75:5: F401 'pyspark.resource.tests.test_resources.*' imported but unused
python/pyspark/sql/functions.py:32:1: F401 'pyspark.sql.udf.UserDefinedFunction' imported but unused
python/pyspark/sql/functions.py:34:1: F401 'pyspark.sql.pandas.functions.pandas_udf' imported but unused
python/pyspark/sql/session.py:30:1: F401 'pyspark.sql.types.Row' imported but unused
python/pyspark/sql/session.py:30:1: F401 'pyspark.sql.types.StringType' imported but unused
python/pyspark/sql/readwriter.py:1084:5: F401 'pyspark.sql.Row' imported but unused
python/pyspark/sql/context.py:26:1: F401 'pyspark.sql.types.IntegerType' imported but unused
python/pyspark/sql/context.py:26:1: F401 'pyspark.sql.types.Row' imported but unused
python/pyspark/sql/context.py:26:1: F401 'pyspark.sql.types.StringType' imported but unused
python/pyspark/sql/context.py:27:1: F401 'pyspark.sql.udf.UDFRegistration' imported but unused
python/pyspark/sql/streaming.py:1212:5: F401 'pyspark.sql.Row' imported but unused
python/pyspark/sql/tests/test_utils.py:55:5: F401 'pyspark.sql.tests.test_utils.*' imported but unused
python/pyspark/sql/tests/test_pandas_map.py:18:1: F401 'sys' imported but unused
python/pyspark/sql/tests/test_pandas_map.py:22:1: F401 'pyspark.sql.functions.pandas_udf' imported but unused
python/pyspark/sql/tests/test_pandas_map.py:22:1: F401 'pyspark.sql.functions.PandasUDFType' imported but unused
python/pyspark/sql/tests/test_pandas_map.py:119:5: F401 'pyspark.sql.tests.test_pandas_map.*' imported but unused
python/pyspark/sql/tests/test_catalog.py:193:5: F401 'pyspark.sql.tests.test_catalog.*' imported but unused
python/pyspark/sql/tests/test_group.py:39:5: F401 'pyspark.sql.tests.test_group.*' imported but unused
python/pyspark/sql/tests/test_session.py:361:5: F401 'pyspark.sql.tests.test_session.*' imported but unused
python/pyspark/sql/tests/test_conf.py:49:5: F401 'pyspark.sql.tests.test_conf.*' imported but unused
python/pyspark/sql/tests/test_pandas_cogrouped_map.py:19:1: F401 'sys' imported but unused
python/pyspark/sql/tests/test_pandas_cogrouped_map.py:21:1: F401 'pyspark.sql.functions.sum' imported but unused
python/pyspark/sql/tests/test_pandas_cogrouped_map.py:21:1: F401 'pyspark.sql.functions.PandasUDFType' imported but unused
python/pyspark/sql/tests/test_pandas_cogrouped_map.py:29:5: F401 'pandas.util.testing.assert_series_equal' imported but unused
python/pyspark/sql/tests/test_pandas_cogrouped_map.py:32:5: F401 'pyarrow as pa' imported but unused
python/pyspark/sql/tests/test_pandas_cogrouped_map.py:248:5: F401 'pyspark.sql.tests.test_pandas_cogrouped_map.*' imported but unused
python/pyspark/sql/tests/test_udf.py:24:1: F401 'py4j' imported but unused
python/pyspark/sql/tests/test_pandas_udf_typehints.py:246:5: F401 'pyspark.sql.tests.test_pandas_udf_typehints.*' imported but unused
python/pyspark/sql/tests/test_functions.py:19:1: F401 'sys' imported but unused
python/pyspark/sql/tests/test_functions.py:362:9: F401 'pyspark.sql.functions.exists' imported but unused
python/pyspark/sql/tests/test_functions.py:387:5: F401 'pyspark.sql.tests.test_functions.*' imported but unused
python/pyspark/sql/tests/test_pandas_udf_scalar.py:21:1: F401 'sys' imported but unused
python/pyspark/sql/tests/test_pandas_udf_scalar.py:45:5: F401 'pyarrow as pa' imported but unused
python/pyspark/sql/tests/test_pandas_udf_window.py:355:5: F401 'pyspark.sql.tests.test_pandas_udf_window.*' imported but unused
python/pyspark/sql/tests/test_arrow.py:38:5: F401 'pyarrow as pa' imported but unused
python/pyspark/sql/tests/test_pandas_grouped_map.py:20:1: F401 'sys' imported but unused
python/pyspark/sql/tests/test_pandas_grouped_map.py:38:5: F401 'pyarrow as pa' imported but unused
python/pyspark/sql/tests/test_dataframe.py:382:9: F401 'pyspark.sql.DataFrame' imported but unused
python/pyspark/sql/avro/functions.py:125:5: F401 'pyspark.sql.Row' imported but unused
python/pyspark/sql/pandas/functions.py:19:1: F401 'sys' imported but unused
```

After:

```
fokkodriesprongFan spark % flake8 python | grep -i "imported but unused"
fokkodriesprongFan spark %
```

### What changes were proposed in this pull request?

Removing unused imports from the Python files to keep everything nice and tidy.

### Why are the changes needed?

Cleaning up imports that aren't used, and suppressing the warning for imports that exist only as re-exports for other modules, preserving backward compatibility.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Adding the rule to the existing Flake8 checks.

Closes #29121 from Fokko/SPARK-32319.

Authored-by: Fokko Driesprong <fokko@apache.org>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
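To illustrate the docstring point from the commit message: once an import moves into the doctest itself, the example is self-contained and copy-paste runnable, and the module no longer needs the file-scope import. A minimal sketch of the pattern with a hypothetical function (not taken from this diff):

```python
def vector_norm(v):
    """Return the L2 norm of a dense vector.

    The import lives inside the doctest, so the example carries its own
    dependencies and the enclosing module stays free of F401 findings:

    >>> from math import isclose
    >>> isclose(vector_norm([3.0, 4.0]), 5.0)
    True
    """
    return sum(x * x for x in v) ** 0.5
```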
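For the re-exports that must stay (for example `pyspark.sql.Row` in `pyspark/__init__.py`), there are two standard ways to satisfy F401: pyflakes treats a module-level import as used when its name appears in `__all__`, and a per-line `# noqa: F401` comment suppresses the warning explicitly. A minimal sketch of both options with a hypothetical package layout; which mechanism each file uses is visible in the diff itself. Enabling the check then amounts to making sure F401 is no longer ignored in the repository's flake8 configuration (`dev/tox.ini` in Spark).

```python
# hypothetical mypackage/__init__.py -- imports kept purely as re-exports

from mypackage.core import Engine          # F401 silenced by __all__ below
from mypackage.legacy import OldEngine     # noqa: F401  (explicit per-line suppression)

# pyflakes counts names listed in __all__ as used, so Engine is not reported
__all__ = ["Engine"]
```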
python/pyspark/sql/tests/test_dataframe.py · 914 lines · 39 KiB · Python
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import pydoc
import shutil
import tempfile
import time
import unittest

from pyspark.sql import SparkSession, Row
from pyspark.sql.types import *
from pyspark.sql.utils import AnalysisException, IllegalArgumentException
from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils, have_pyarrow, have_pandas, \
    pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest

class DataFrameTests(ReusedSQLTestCase):

    def test_range(self):
        self.assertEqual(self.spark.range(1, 1).count(), 0)
        self.assertEqual(self.spark.range(1, 0, -1).count(), 1)
        self.assertEqual(self.spark.range(0, 1 << 40, 1 << 39).count(), 2)
        self.assertEqual(self.spark.range(-2).count(), 0)
        self.assertEqual(self.spark.range(3).count(), 3)

    def test_duplicated_column_names(self):
        df = self.spark.createDataFrame([(1, 2)], ["c", "c"])
        row = df.select('*').first()
        self.assertEqual(1, row[0])
        self.assertEqual(2, row[1])
        self.assertEqual("Row(c=1, c=2)", str(row))
        # Cannot access columns
        self.assertRaises(AnalysisException, lambda: df.select(df[0]).first())
        self.assertRaises(AnalysisException, lambda: df.select(df.c).first())
        self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first())

    def test_freqItems(self):
        vals = [Row(a=1, b=-2.0) if i % 2 == 0 else Row(a=i, b=i * 1.0) for i in range(100)]
        df = self.sc.parallelize(vals).toDF()
        items = df.stat.freqItems(("a", "b"), 0.4).collect()[0]
        self.assertTrue(1 in items[0])
        self.assertTrue(-2.0 in items[1])

    def test_help_command(self):
        # Regression test for SPARK-5464
        rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
        df = self.spark.read.json(rdd)
        # render_doc() reproduces the help() exception without printing output
        pydoc.render_doc(df)
        pydoc.render_doc(df.foo)
        pydoc.render_doc(df.take(1))

    def test_dropna(self):
        schema = StructType([
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("height", DoubleType(), True)])

        # shouldn't drop a non-null row
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', 50, 80.1)], schema).dropna().count(),
            1)

        # dropping rows with a single null value
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, 80.1)], schema).dropna().count(),
            0)
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, 80.1)], schema).dropna(how='any').count(),
            0)

        # if how = 'all', only drop rows if all values are null
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, 80.1)], schema).dropna(how='all').count(),
            1)
        self.assertEqual(self.spark.createDataFrame(
            [(None, None, None)], schema).dropna(how='all').count(),
            0)

        # how and subset
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
            1)
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
            0)

        # threshold
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(),
            1)
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, None)], schema).dropna(thresh=2).count(),
            0)

        # threshold and subset
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
            1)
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
            0)

        # thresh should take precedence over how
        self.assertEqual(self.spark.createDataFrame(
            [(u'Alice', 50, None)], schema).dropna(
                how='any', thresh=2, subset=['name', 'age']).count(),
            1)

    def test_fillna(self):
        schema = StructType([
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("height", DoubleType(), True),
            StructField("spy", BooleanType(), True)])

        # fillna shouldn't change non-null values
        row = self.spark.createDataFrame([(u'Alice', 10, 80.1, True)], schema).fillna(50).first()
        self.assertEqual(row.age, 10)

        # fillna with int
        row = self.spark.createDataFrame([(u'Alice', None, None, None)], schema).fillna(50).first()
        self.assertEqual(row.age, 50)
        self.assertEqual(row.height, 50.0)

        # fillna with double
        row = self.spark.createDataFrame(
            [(u'Alice', None, None, None)], schema).fillna(50.1).first()
        self.assertEqual(row.age, 50)
        self.assertEqual(row.height, 50.1)

        # fillna with bool
        row = self.spark.createDataFrame(
            [(u'Alice', None, None, None)], schema).fillna(True).first()
        self.assertEqual(row.age, None)
        self.assertEqual(row.spy, True)

        # fillna with string
        row = self.spark.createDataFrame([(None, None, None, None)], schema).fillna("hello").first()
        self.assertEqual(row.name, u"hello")
        self.assertEqual(row.age, None)

        # fillna with subset specified for numeric cols
        row = self.spark.createDataFrame(
            [(None, None, None, None)], schema).fillna(50, subset=['name', 'age']).first()
        self.assertEqual(row.name, None)
        self.assertEqual(row.age, 50)
        self.assertEqual(row.height, None)
        self.assertEqual(row.spy, None)

        # fillna with subset specified for string cols
        row = self.spark.createDataFrame(
            [(None, None, None, None)], schema).fillna("haha", subset=['name', 'age']).first()
        self.assertEqual(row.name, "haha")
        self.assertEqual(row.age, None)
        self.assertEqual(row.height, None)
        self.assertEqual(row.spy, None)

        # fillna with subset specified for bool cols
        row = self.spark.createDataFrame(
            [(None, None, None, None)], schema).fillna(True, subset=['name', 'spy']).first()
        self.assertEqual(row.name, None)
        self.assertEqual(row.age, None)
        self.assertEqual(row.height, None)
        self.assertEqual(row.spy, True)

        # fillna with dictionary for boolean types
        row = self.spark.createDataFrame([Row(a=None), Row(a=True)]).fillna({"a": True}).first()
        self.assertEqual(row.a, True)

    def test_repartitionByRange_dataframe(self):
        schema = StructType([
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("height", DoubleType(), True)])

        df1 = self.spark.createDataFrame(
            [(u'Bob', 27, 66.0), (u'Alice', 10, 10.0), (u'Bob', 10, 66.0)], schema)
        df2 = self.spark.createDataFrame(
            [(u'Alice', 10, 10.0), (u'Bob', 10, 66.0), (u'Bob', 27, 66.0)], schema)

        # test repartitionByRange(numPartitions, *cols)
        df3 = df1.repartitionByRange(2, "name", "age")
        self.assertEqual(df3.rdd.getNumPartitions(), 2)
        self.assertEqual(df3.rdd.first(), df2.rdd.first())
        self.assertEqual(df3.rdd.take(3), df2.rdd.take(3))

        # test repartitionByRange(numPartitions, *cols)
        df4 = df1.repartitionByRange(3, "name", "age")
        self.assertEqual(df4.rdd.getNumPartitions(), 3)
        self.assertEqual(df4.rdd.first(), df2.rdd.first())
        self.assertEqual(df4.rdd.take(3), df2.rdd.take(3))

        # test repartitionByRange(*cols)
        df5 = df1.repartitionByRange("name", "age")
        self.assertEqual(df5.rdd.first(), df2.rdd.first())
        self.assertEqual(df5.rdd.take(3), df2.rdd.take(3))

    def test_replace(self):
        schema = StructType([
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("height", DoubleType(), True)])

        # replace with int
        row = self.spark.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first()
        self.assertEqual(row.age, 20)
        self.assertEqual(row.height, 20.0)

        # replace with double
        row = self.spark.createDataFrame(
            [(u'Alice', 80, 80.0)], schema).replace(80.0, 82.1).first()
        self.assertEqual(row.age, 82)
        self.assertEqual(row.height, 82.1)

        # replace with string
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace(u'Alice', u'Ann').first()
        self.assertEqual(row.name, u"Ann")
        self.assertEqual(row.age, 10)

        # replace with subset specified by a string of a column name w/ actual change
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='age').first()
        self.assertEqual(row.age, 20)

        # replace with subset specified by a string of a column name w/o actual change
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='height').first()
        self.assertEqual(row.age, 10)

        # replace with subset specified with one column replaced, another column not in subset
        # stays unchanged.
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 10.0)], schema).replace(10, 20, subset=['name', 'age']).first()
        self.assertEqual(row.name, u'Alice')
        self.assertEqual(row.age, 20)
        self.assertEqual(row.height, 10.0)

        # replace with subset specified but no column will be replaced
        row = self.spark.createDataFrame(
            [(u'Alice', 10, None)], schema).replace(10, 20, subset=['name', 'height']).first()
        self.assertEqual(row.name, u'Alice')
        self.assertEqual(row.age, 10)
        self.assertEqual(row.height, None)

        # replace with lists
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace([u'Alice'], [u'Ann']).first()
        self.assertTupleEqual(row, (u'Ann', 10, 80.1))

        # replace with dict
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace({10: 11}).first()
        self.assertTupleEqual(row, (u'Alice', 11, 80.1))

        # test backward compatibility with dummy value
        dummy_value = 1
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace({'Alice': 'Bob'}, dummy_value).first()
        self.assertTupleEqual(row, (u'Bob', 10, 80.1))

        # test dict with mixed numerics
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace({10: -10, 80.1: 90.5}).first()
        self.assertTupleEqual(row, (u'Alice', -10, 90.5))

        # replace with tuples
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.1)], schema).replace((u'Alice', ), (u'Bob', )).first()
        self.assertTupleEqual(row, (u'Bob', 10, 80.1))

        # replace multiple columns
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.0)], schema).replace((10, 80.0), (20, 90)).first()
        self.assertTupleEqual(row, (u'Alice', 20, 90.0))

        # test for mixed numerics
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.0)], schema).replace((10, 80), (20, 90.5)).first()
        self.assertTupleEqual(row, (u'Alice', 20, 90.5))

        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.0)], schema).replace({10: 20, 80: 90.5}).first()
        self.assertTupleEqual(row, (u'Alice', 20, 90.5))

        # replace with boolean
        row = (self
               .spark.createDataFrame([(u'Alice', 10, 80.0)], schema)
               .selectExpr("name = 'Bob'", 'age <= 15')
               .replace(False, True).first())
        self.assertTupleEqual(row, (True, True))

        # replace string with None and then drop None rows
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.0)], schema).replace(u'Alice', None).dropna()
        self.assertEqual(row.count(), 0)

        # replace with number and None
        row = self.spark.createDataFrame(
            [(u'Alice', 10, 80.0)], schema).replace([10, 80], [20, None]).first()
        self.assertTupleEqual(row, (u'Alice', 20, None))

        # should fail if subset is not list, tuple or None
        with self.assertRaises(ValueError):
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.1)], schema).replace({10: 11}, subset=1).first()

        # should fail if to_replace and value have different length
        with self.assertRaises(ValueError):
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.1)], schema).replace(["Alice", "Bob"], ["Eve"]).first()

        # should fail when given an unexpected type
        with self.assertRaises(ValueError):
            from datetime import datetime
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.1)], schema).replace(datetime.now(), datetime.now()).first()

        # should fail if provided mixed type replacements
        with self.assertRaises(ValueError):
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.1)], schema).replace(["Alice", 10], ["Eve", 20]).first()

        with self.assertRaises(ValueError):
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first()

        with self.assertRaisesRegexp(
                TypeError,
                'value argument is required when to_replace is not a dictionary.'):
            self.spark.createDataFrame(
                [(u'Alice', 10, 80.0)], schema).replace(["Alice", "Bob"]).first()

    def test_with_column_with_existing_name(self):
        keys = self.df.withColumn("key", self.df.key).select("key").collect()
        self.assertEqual([r.key for r in keys], list(range(100)))

    # regression test for SPARK-10417
    def test_column_iterator(self):

        def foo():
            for x in self.df.key:
                break

        self.assertRaises(TypeError, foo)

    def test_generic_hints(self):
        from pyspark.sql import DataFrame

        df1 = self.spark.range(10e10).toDF("id")
        df2 = self.spark.range(10e10).toDF("id")

        self.assertIsInstance(df1.hint("broadcast"), DataFrame)
        self.assertIsInstance(df1.hint("broadcast", []), DataFrame)

        # Dummy rules
        self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), DataFrame)
        self.assertIsInstance(df1.hint("broadcast", ["foo", "bar"]), DataFrame)

        plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan()
        self.assertEqual(1, plan.toString().count("BroadcastHashJoin"))

    # add tests for SPARK-23647 (test more types for hint)
    def test_extended_hint_types(self):
        df = self.spark.range(10e10).toDF("id")
        such_a_nice_list = ["itworks1", "itworks2", "itworks3"]
        hinted_df = df.hint("my awesome hint", 1.2345, "what", such_a_nice_list)
        logical_plan = hinted_df._jdf.queryExecution().logical()

        self.assertEqual(1, logical_plan.toString().count("1.2345"))
        self.assertEqual(1, logical_plan.toString().count("what"))
        self.assertEqual(3, logical_plan.toString().count("itworks"))

    def test_sample(self):
        self.assertRaisesRegexp(
            TypeError,
            "should be a bool, float and number",
            lambda: self.spark.range(1).sample())

        self.assertRaises(
            TypeError,
            lambda: self.spark.range(1).sample("a"))

        self.assertRaises(
            TypeError,
            lambda: self.spark.range(1).sample(seed="abc"))

        self.assertRaises(
            IllegalArgumentException,
            lambda: self.spark.range(1).sample(-1.0))

    def test_toDF_with_schema_string(self):
        data = [Row(key=i, value=str(i)) for i in range(100)]
        rdd = self.sc.parallelize(data, 5)

        df = rdd.toDF("key: int, value: string")
        self.assertEqual(df.schema.simpleString(), "struct<key:int,value:string>")
        self.assertEqual(df.collect(), data)

        # different but compatible field types can be used.
        df = rdd.toDF("key: string, value: string")
        self.assertEqual(df.schema.simpleString(), "struct<key:string,value:string>")
        self.assertEqual(df.collect(), [Row(key=str(i), value=str(i)) for i in range(100)])

        # field names can differ.
        df = rdd.toDF(" a: int, b: string ")
        self.assertEqual(df.schema.simpleString(), "struct<a:int,b:string>")
        self.assertEqual(df.collect(), data)

        # number of fields must match.
        self.assertRaisesRegexp(Exception, "Length of object",
                                lambda: rdd.toDF("key: int").collect())

        # field types mismatch will cause exception at runtime.
        self.assertRaisesRegexp(Exception, "FloatType can not accept",
                                lambda: rdd.toDF("key: float, value: string").collect())

        # flat schema values will be wrapped into row.
        df = rdd.map(lambda row: row.key).toDF("int")
        self.assertEqual(df.schema.simpleString(), "struct<value:int>")
        self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])

        # users can use DataType directly instead of data type string.
        df = rdd.map(lambda row: row.key).toDF(IntegerType())
        self.assertEqual(df.schema.simpleString(), "struct<value:int>")
        self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])

    def test_join_without_on(self):
        df1 = self.spark.range(1).toDF("a")
        df2 = self.spark.range(1).toDF("b")

        with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
            self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect())

        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
            actual = df1.join(df2, how="inner").collect()
            expected = [Row(a=0, b=0)]
            self.assertEqual(actual, expected)

    # Regression test for invalid join methods when on is None, Spark-14761
    def test_invalid_join_method(self):
        df1 = self.spark.createDataFrame([("Alice", 5), ("Bob", 8)], ["name", "age"])
        df2 = self.spark.createDataFrame([("Alice", 80), ("Bob", 90)], ["name", "height"])
        self.assertRaises(IllegalArgumentException, lambda: df1.join(df2, how="invalid-join-type"))

    # Cartesian products require cross join syntax
    def test_require_cross(self):

        df1 = self.spark.createDataFrame([(1, "1")], ("key", "value"))
        df2 = self.spark.createDataFrame([(1, "1")], ("key", "value"))

        with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
            # joins without conditions require cross join syntax
            self.assertRaises(AnalysisException, lambda: df1.join(df2).collect())

            # works with crossJoin
            self.assertEqual(1, df1.crossJoin(df2).count())

    def test_cache(self):
        spark = self.spark
        with self.tempView("tab1", "tab2"):
            spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab1")
            spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab2")
            self.assertFalse(spark.catalog.isCached("tab1"))
            self.assertFalse(spark.catalog.isCached("tab2"))
            spark.catalog.cacheTable("tab1")
            self.assertTrue(spark.catalog.isCached("tab1"))
            self.assertFalse(spark.catalog.isCached("tab2"))
            spark.catalog.cacheTable("tab2")
            spark.catalog.uncacheTable("tab1")
            self.assertFalse(spark.catalog.isCached("tab1"))
            self.assertTrue(spark.catalog.isCached("tab2"))
            spark.catalog.clearCache()
            self.assertFalse(spark.catalog.isCached("tab1"))
            self.assertFalse(spark.catalog.isCached("tab2"))
            self.assertRaisesRegexp(
                AnalysisException,
                "does_not_exist",
                lambda: spark.catalog.isCached("does_not_exist"))
            self.assertRaisesRegexp(
                AnalysisException,
                "does_not_exist",
                lambda: spark.catalog.cacheTable("does_not_exist"))
            self.assertRaisesRegexp(
                AnalysisException,
                "does_not_exist",
                lambda: spark.catalog.uncacheTable("does_not_exist"))

    def _to_pandas(self):
        from datetime import datetime, date
        schema = StructType().add("a", IntegerType()).add("b", StringType())\
                             .add("c", BooleanType()).add("d", FloatType())\
                             .add("dt", DateType()).add("ts", TimestampType())
        data = [
            (1, "foo", True, 3.0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
            (2, "foo", True, 5.0, None, None),
            (3, "bar", False, -1.0, date(2012, 3, 3), datetime(2012, 3, 3, 3, 3, 3)),
            (4, "bar", False, 6.0, date(2100, 4, 4), datetime(2100, 4, 4, 4, 4, 4)),
        ]
        df = self.spark.createDataFrame(data, schema)
        return df.toPandas()

    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_to_pandas(self):
        import numpy as np
        pdf = self._to_pandas()
        types = pdf.dtypes
        self.assertEquals(types[0], np.int32)
        self.assertEquals(types[1], np.object)
        self.assertEquals(types[2], np.bool)
        self.assertEquals(types[3], np.float32)
        self.assertEquals(types[4], np.object)  # datetime.date
        self.assertEquals(types[5], 'datetime64[ns]')

    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_to_pandas_with_duplicated_column_names(self):
        import numpy as np

        sql = "select 1 v, 1 v"
        for arrowEnabled in [False, True]:
            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrowEnabled}):
                df = self.spark.sql(sql)
                pdf = df.toPandas()
                types = pdf.dtypes
                self.assertEquals(types.iloc[0], np.int32)
                self.assertEquals(types.iloc[1], np.int32)

    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_to_pandas_on_cross_join(self):
        import numpy as np

        sql = """
        select t1.*, t2.* from (
            select explode(sequence(1, 3)) v
        ) t1 left join (
            select explode(sequence(1, 3)) v
        ) t2
        """
        for arrowEnabled in [False, True]:
            with self.sql_conf({"spark.sql.crossJoin.enabled": True,
                                "spark.sql.execution.arrow.pyspark.enabled": arrowEnabled}):
                df = self.spark.sql(sql)
                pdf = df.toPandas()
                types = pdf.dtypes
                self.assertEquals(types.iloc[0], np.int32)
                self.assertEquals(types.iloc[1], np.int32)

    @unittest.skipIf(have_pandas, "Required Pandas was found.")
    def test_to_pandas_required_pandas_not_found(self):
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
                self._to_pandas()

    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_to_pandas_avoid_astype(self):
        import numpy as np
        schema = StructType().add("a", IntegerType()).add("b", StringType())\
                             .add("c", IntegerType())
        data = [(1, "foo", 16777220), (None, "bar", None)]
        df = self.spark.createDataFrame(data, schema)
        types = df.toPandas().dtypes
        self.assertEquals(types[0], np.float64)  # doesn't convert to np.int32 due to NaN value.
        self.assertEquals(types[1], np.object)
        self.assertEquals(types[2], np.float64)

    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_to_pandas_from_empty_dataframe(self):
        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
            # SPARK-29188 test that toPandas() on an empty dataframe has the correct dtypes
            import numpy as np
            sql = """
            SELECT CAST(1 AS TINYINT) AS tinyint,
            CAST(1 AS SMALLINT) AS smallint,
            CAST(1 AS INT) AS int,
            CAST(1 AS BIGINT) AS bigint,
            CAST(0 AS FLOAT) AS float,
            CAST(0 AS DOUBLE) AS double,
            CAST(1 AS BOOLEAN) AS boolean,
            CAST('foo' AS STRING) AS string,
            CAST('2019-01-01' AS TIMESTAMP) AS timestamp
            """
            dtypes_when_nonempty_df = self.spark.sql(sql).toPandas().dtypes
            dtypes_when_empty_df = self.spark.sql(sql).filter("False").toPandas().dtypes
            self.assertTrue(np.all(dtypes_when_empty_df == dtypes_when_nonempty_df))

    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_to_pandas_from_null_dataframe(self):
        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
            # SPARK-29188 test that toPandas() on a dataframe with only nulls has correct dtypes
            import numpy as np
            sql = """
            SELECT CAST(NULL AS TINYINT) AS tinyint,
            CAST(NULL AS SMALLINT) AS smallint,
            CAST(NULL AS INT) AS int,
            CAST(NULL AS BIGINT) AS bigint,
            CAST(NULL AS FLOAT) AS float,
            CAST(NULL AS DOUBLE) AS double,
            CAST(NULL AS BOOLEAN) AS boolean,
            CAST(NULL AS STRING) AS string,
            CAST(NULL AS TIMESTAMP) AS timestamp
            """
            pdf = self.spark.sql(sql).toPandas()
            types = pdf.dtypes
            self.assertEqual(types[0], np.float64)
            self.assertEqual(types[1], np.float64)
            self.assertEqual(types[2], np.float64)
            self.assertEqual(types[3], np.float64)
            self.assertEqual(types[4], np.float32)
            self.assertEqual(types[5], np.float64)
            self.assertEqual(types[6], np.object)
            self.assertEqual(types[7], np.object)
            self.assertTrue(np.can_cast(np.datetime64, types[8]))

    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_to_pandas_from_mixed_dataframe(self):
        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
            # SPARK-29188 test that toPandas() on a dataframe with some nulls has correct dtypes
            import numpy as np
            sql = """
            SELECT CAST(col1 AS TINYINT) AS tinyint,
            CAST(col2 AS SMALLINT) AS smallint,
            CAST(col3 AS INT) AS int,
            CAST(col4 AS BIGINT) AS bigint,
            CAST(col5 AS FLOAT) AS float,
            CAST(col6 AS DOUBLE) AS double,
            CAST(col7 AS BOOLEAN) AS boolean,
            CAST(col8 AS STRING) AS string,
            timestamp_seconds(col9) AS timestamp
            FROM VALUES (1, 1, 1, 1, 1, 1, 1, 1, 1),
                        (NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL)
            """
            pdf_with_some_nulls = self.spark.sql(sql).toPandas()
            pdf_with_only_nulls = self.spark.sql(sql).filter('tinyint is null').toPandas()
            self.assertTrue(np.all(pdf_with_only_nulls.dtypes == pdf_with_some_nulls.dtypes))

    def test_create_dataframe_from_array_of_long(self):
        import array
        data = [Row(longarray=array.array('l', [-9223372036854775808, 0, 9223372036854775807]))]
        df = self.spark.createDataFrame(data)
        self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807]))

    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_create_dataframe_from_pandas_with_timestamp(self):
        import pandas as pd
        from datetime import datetime
        pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
                            "d": [pd.Timestamp.now().date()]}, columns=["d", "ts"])
        # test types are inferred correctly without specifying schema
        df = self.spark.createDataFrame(pdf)
        self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
        self.assertTrue(isinstance(df.schema['d'].dataType, DateType))
        # test with schema will accept pdf as input
        df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp")
        self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
        self.assertTrue(isinstance(df.schema['d'].dataType, DateType))

    @unittest.skipIf(have_pandas, "Required Pandas was found.")
    def test_create_dataframe_required_pandas_not_found(self):
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(
                    ImportError,
                    "(Pandas >= .* must be installed|No module named '?pandas'?)"):
                import pandas as pd
                from datetime import datetime
                pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
                                    "d": [pd.Timestamp.now().date()]})
                self.spark.createDataFrame(pdf)

    # Regression test for SPARK-23360
    @unittest.skipIf(not have_pandas, pandas_requirement_message)
    def test_create_dataframe_from_pandas_with_dst(self):
        import pandas as pd
        from pandas.util.testing import assert_frame_equal
        from datetime import datetime

        pdf = pd.DataFrame({'time': [datetime(2015, 10, 31, 22, 30)]})

        df = self.spark.createDataFrame(pdf)
        assert_frame_equal(pdf, df.toPandas())

        orig_env_tz = os.environ.get('TZ', None)
        try:
            tz = 'America/Los_Angeles'
            os.environ['TZ'] = tz
            time.tzset()
            with self.sql_conf({'spark.sql.session.timeZone': tz}):
                df = self.spark.createDataFrame(pdf)
                assert_frame_equal(pdf, df.toPandas())
        finally:
            del os.environ['TZ']
            if orig_env_tz is not None:
                os.environ['TZ'] = orig_env_tz
            time.tzset()

    def test_repr_behaviors(self):
        import re
        pattern = re.compile(r'^ *\|', re.MULTILINE)
        df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value"))

        # test when eager evaluation is enabled and _repr_html_ will not be called
        with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
            expected1 = """+-----+-----+
                ||  key|value|
                |+-----+-----+
                ||    1|    1|
                ||22222|22222|
                |+-----+-----+
                |"""
            self.assertEquals(re.sub(pattern, '', expected1), df.__repr__())
            with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
                expected2 = """+---+-----+
                ||key|value|
                |+---+-----+
                ||  1|    1|
                ||222|  222|
                |+---+-----+
                |"""
                self.assertEquals(re.sub(pattern, '', expected2), df.__repr__())
                with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
                    expected3 = """+---+-----+
                    ||key|value|
                    |+---+-----+
                    ||  1|    1|
                    |+---+-----+
                    |only showing top 1 row
                    |"""
                    self.assertEquals(re.sub(pattern, '', expected3), df.__repr__())

        # test when eager evaluation is enabled and _repr_html_ will be called
        with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
            expected1 = """<table border='1'>
                |<tr><th>key</th><th>value</th></tr>
                |<tr><td>1</td><td>1</td></tr>
                |<tr><td>22222</td><td>22222</td></tr>
                |</table>
                |"""
            self.assertEquals(re.sub(pattern, '', expected1), df._repr_html_())
            with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
                expected2 = """<table border='1'>
                |<tr><th>key</th><th>value</th></tr>
                |<tr><td>1</td><td>1</td></tr>
                |<tr><td>222</td><td>222</td></tr>
                |</table>
                |"""
                self.assertEquals(re.sub(pattern, '', expected2), df._repr_html_())
                with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
                    expected3 = """<table border='1'>
                    |<tr><th>key</th><th>value</th></tr>
                    |<tr><td>1</td><td>1</td></tr>
                    |</table>
                    |only showing top 1 row
                    |"""
                    self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_())

        # test when eager evaluation is disabled and _repr_html_ will be called
        with self.sql_conf({"spark.sql.repl.eagerEval.enabled": False}):
            expected = "DataFrame[key: bigint, value: string]"
            self.assertEquals(None, df._repr_html_())
            self.assertEquals(expected, df.__repr__())
            with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
                self.assertEquals(None, df._repr_html_())
                self.assertEquals(expected, df.__repr__())
                with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
                    self.assertEquals(None, df._repr_html_())
                    self.assertEquals(expected, df.__repr__())

    def test_to_local_iterator(self):
        df = self.spark.range(8, numPartitions=4)
        expected = df.collect()
        it = df.toLocalIterator()
        self.assertEqual(expected, list(it))

        # Test DataFrame with empty partition
        df = self.spark.range(3, numPartitions=4)
        it = df.toLocalIterator()
        expected = df.collect()
        self.assertEqual(expected, list(it))

    def test_to_local_iterator_prefetch(self):
        df = self.spark.range(8, numPartitions=4)
        expected = df.collect()
        it = df.toLocalIterator(prefetchPartitions=True)
        self.assertEqual(expected, list(it))

    def test_to_local_iterator_not_fully_consumed(self):
        # SPARK-23961: toLocalIterator throws exception when not fully consumed
        # Create a DataFrame large enough so that write to socket will eventually block
        df = self.spark.range(1 << 20, numPartitions=2)
        it = df.toLocalIterator()
        self.assertEqual(df.take(1)[0], next(it))
        with QuietTest(self.sc):
            it = None  # remove iterator from scope, socket is closed when cleaned up
            # Make sure normal df operations still work
            result = []
            for i, row in enumerate(df.toLocalIterator()):
                result.append(row)
                if i == 7:
                    break
            self.assertEqual(df.take(8), result)

    def test_same_semantics_error(self):
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(ValueError, "should be of DataFrame.*int"):
                self.spark.range(10).sameSemantics(1)

    def test_input_files(self):
        tpath = tempfile.mkdtemp()
        shutil.rmtree(tpath)
        try:
            self.spark.range(1, 100, 1, 10).write.parquet(tpath)
            # read parquet file and get the input files list
            input_files_list = self.spark.read.parquet(tpath).inputFiles()

            # input files list should contain 10 entries
            self.assertEquals(len(input_files_list), 10)
            # all file paths in list must contain tpath
            for file_path in input_files_list:
                self.assertTrue(tpath in file_path)
        finally:
            shutil.rmtree(tpath)


class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
    # These tests are separate because they use 'spark.sql.queryExecutionListeners', which is
    # static and immutable. It can't be set or unset, for example, via `spark.conf`.

    @classmethod
    def setUpClass(cls):
        import glob
        from pyspark.find_spark_home import _find_spark_home

        SPARK_HOME = _find_spark_home()
        filename_pattern = (
            "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
            "TestQueryExecutionListener.class")
        cls.has_listener = bool(glob.glob(os.path.join(SPARK_HOME, filename_pattern)))

        if cls.has_listener:
            # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration.
            cls.spark = SparkSession.builder \
                .master("local[4]") \
                .appName(cls.__name__) \
                .config(
                    "spark.sql.queryExecutionListeners",
                    "org.apache.spark.sql.TestQueryExecutionListener") \
                .getOrCreate()

    def setUp(self):
        if not self.has_listener:
            raise self.skipTest(
                "'org.apache.spark.sql.TestQueryExecutionListener' is not "
                "available. Will skip the related tests.")

    @classmethod
    def tearDownClass(cls):
        if hasattr(cls, "spark"):
            cls.spark.stop()

    def tearDown(self):
        self.spark._jvm.OnSuccessCall.clear()

    def test_query_execution_listener_on_collect(self):
        self.assertFalse(
            self.spark._jvm.OnSuccessCall.isCalled(),
            "The callback from the query execution listener should not be called before 'collect'")
        self.spark.sql("SELECT * FROM range(1)").collect()
        self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty(10000)
        self.assertTrue(
            self.spark._jvm.OnSuccessCall.isCalled(),
            "The callback from the query execution listener should be called after 'collect'")

    @unittest.skipIf(
        not have_pandas or not have_pyarrow,
        pandas_requirement_message or pyarrow_requirement_message)
    def test_query_execution_listener_on_collect_with_arrow(self):
        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}):
            self.assertFalse(
                self.spark._jvm.OnSuccessCall.isCalled(),
                "The callback from the query execution listener should not be "
                "called before 'toPandas'")
            self.spark.sql("SELECT * FROM range(1)").toPandas()
            self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty(10000)
            self.assertTrue(
                self.spark._jvm.OnSuccessCall.isCalled(),
                "The callback from the query execution listener should be called after 'toPandas'")


if __name__ == "__main__":
    from pyspark.sql.tests.test_dataframe import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)