spark-instrumented-optimizer/python/pyspark/pandas/tests/plot/test_frame_plot.py
Xinrong Meng 4d2b559d92 [SPARK-34999][PYTHON] Consolidate PySpark testing utils
### What changes were proposed in this pull request?
Consolidate PySpark testing utils by removing `python/pyspark/pandas/testing`, and then creating a file `pandasutils` under `python/pyspark/testing` for test utilities used in `pyspark/pandas`.

### Why are the changes needed?

`python/pyspark/pandas/testing` hold test utilites for pandas-on-spark, and `python/pyspark/testing` contain test utilities for pyspark. Consolidating them makes code cleaner and easier to maintain.

Updated import statements are as shown below:
- from pyspark.testing.sqlutils import SQLTestUtils
- from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
(PandasOnSparkTestCase is the original ReusedSQLTestCase in `python/pyspark/pandas/testing/utils.py`)

Minor improvements include:
- Usage of missing library's requirement_message
- `except ImportError` rather than `except`
- import pyspark.pandas alias as `ps` rather than `pp`

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests under python/pyspark/pandas/tests.

Closes #32177 from xinrong-databricks/port.merge_utils.

Authored-by: Xinrong Meng <xinrong.meng@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
2021-04-22 13:07:35 -07:00

124 lines
4.6 KiB
Python

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pandas as pd
import numpy as np
from pyspark import pandas as ps
from pyspark.pandas.config import set_option, reset_option, option_context
from pyspark.pandas.plot import TopNPlotBase, SampledPlotBase, HistogramPlotBase
from pyspark.pandas.exceptions import PandasNotImplementedError
from pyspark.testing.pandasutils import PandasOnSparkTestCase
class DataFramePlotTest(PandasOnSparkTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
set_option("plotting.max_rows", 2000)
set_option("plotting.sample_ratio", None)
@classmethod
def tearDownClass(cls):
reset_option("plotting.max_rows")
reset_option("plotting.sample_ratio")
super().tearDownClass()
def test_missing(self):
kdf = ps.DataFrame(np.random.rand(2500, 4), columns=["a", "b", "c", "d"])
unsupported_functions = ["box", "hexbin"]
for name in unsupported_functions:
with self.assertRaisesRegex(
PandasNotImplementedError, "method.*DataFrame.*{}.*not implemented".format(name)
):
getattr(kdf.plot, name)()
def test_topn_max_rows(self):
pdf = pd.DataFrame(np.random.rand(2500, 4), columns=["a", "b", "c", "d"])
kdf = ps.from_pandas(pdf)
data = TopNPlotBase().get_top_n(kdf)
self.assertEqual(len(data), 2000)
def test_sampled_plot_with_ratio(self):
with option_context("plotting.sample_ratio", 0.5):
pdf = pd.DataFrame(np.random.rand(2500, 4), columns=["a", "b", "c", "d"])
kdf = ps.from_pandas(pdf)
data = SampledPlotBase().get_sampled(kdf)
self.assertEqual(round(len(data) / 2500, 1), 0.5)
def test_sampled_plot_with_max_rows(self):
# 'plotting.max_rows' is 2000
pdf = pd.DataFrame(np.random.rand(2000, 4), columns=["a", "b", "c", "d"])
kdf = ps.from_pandas(pdf)
data = SampledPlotBase().get_sampled(kdf)
self.assertEqual(round(len(data) / 2000, 1), 1)
def test_compute_hist_single_column(self):
kdf = ps.DataFrame(
{"a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 50]}, index=[0, 1, 3, 5, 6, 8, 9, 9, 9, 10, 10]
)
expected_bins = np.linspace(1, 50, 11)
bins = HistogramPlotBase.get_bins(kdf[["a"]].to_spark(), 10)
expected_histogram = np.array([5, 4, 1, 0, 0, 0, 0, 0, 0, 1])
histogram = HistogramPlotBase.compute_hist(kdf[["a"]], bins)[0]
self.assert_eq(pd.Series(expected_bins), pd.Series(bins))
self.assert_eq(pd.Series(expected_histogram, name="a"), histogram, almost=True)
def test_compute_hist_multi_columns(self):
expected_bins = np.linspace(1, 50, 11)
kdf = ps.DataFrame(
{
"a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 50],
"b": [50, 50, 30, 30, 30, 24, 10, 5, 4, 3, 1],
}
)
bins = HistogramPlotBase.get_bins(kdf.to_spark(), 10)
self.assert_eq(pd.Series(expected_bins), pd.Series(bins))
expected_histograms = [
np.array([5, 4, 1, 0, 0, 0, 0, 0, 0, 1]),
np.array([4, 1, 0, 0, 1, 3, 0, 0, 0, 2]),
]
histograms = HistogramPlotBase.compute_hist(kdf, bins)
expected_names = ["a", "b"]
for histogram, expected_histogram, expected_name in zip(
histograms, expected_histograms, expected_names
):
self.assert_eq(
pd.Series(expected_histogram, name=expected_name), histogram, almost=True
)
if __name__ == "__main__":
import unittest
from pyspark.pandas.tests.plot.test_frame_plot import * # noqa: F401
try:
import xmlrunner # type: ignore[import]
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)