[SPARK-35035][PYTHON] Port Koalas internal implementation unit tests into PySpark

### What changes were proposed in this pull request?
Now that the Koalas main code has been merged into the PySpark code base (#32036), we should port the Koalas internal implementation unit tests to PySpark as well.

### Why are the changes needed?
Currently, the pandas-on-Spark modules are not fully tested. Enabling the internal implementation unit tests closes part of that gap.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Enabled the ported internal implementation unit tests by registering them in the PySpark test module list.

Closes #32137 from xinrong-databricks/port.test_internal_impl.

Lead-authored-by: Xinrong Meng <xinrong.meng@databricks.com>
Co-authored-by: xinrong-databricks <47337188+xinrong-databricks@users.noreply.github.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
Commit 47d62af2a9 (parent 2974b70d1e)
Authored by Xinrong Meng on 2021-04-14 13:59:33 +09:00, committed by HyukjinKwon
8 changed files with 1218 additions and 0 deletions

dev/sparktestsupport/modules.py

@@ -612,6 +612,13 @@ pyspark_pandas = Module(
"pyspark.pandas.typedef.typehints",
# unittests
"pyspark.pandas.tests.test_dataframe",
"pyspark.pandas.tests.test_config",
"pyspark.pandas.tests.test_default_index",
"pyspark.pandas.tests.test_extension",
"pyspark.pandas.tests.test_internal",
"pyspark.pandas.tests.test_numpy_compat",
"pyspark.pandas.tests.test_typedef",
"pyspark.pandas.tests.test_utils",
"pyspark.pandas.tests.test_dataframe_conversion",
"pyspark.pandas.tests.test_dataframe_spark_io",
"pyspark.pandas.tests.test_frame_spark",

python/pyspark/pandas/tests/test_config.py

@@ -0,0 +1,155 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark import pandas as ps
from pyspark.pandas import config
from pyspark.pandas.config import Option, DictWrapper
from pyspark.pandas.testing.utils import ReusedSQLTestCase
class ConfigTest(ReusedSQLTestCase):
def setUp(self):
config._options_dict["test.config"] = Option(key="test.config", doc="", default="default")
config._options_dict["test.config.list"] = Option(
key="test.config.list", doc="", default=[], types=list
)
config._options_dict["test.config.float"] = Option(
key="test.config.float", doc="", default=1.2, types=float
)
config._options_dict["test.config.int"] = Option(
key="test.config.int",
doc="",
default=1,
types=int,
check_func=(lambda v: v > 0, "bigger then 0"),
)
config._options_dict["test.config.int.none"] = Option(
key="test.config.int", doc="", default=None, types=(int, type(None))
)
def tearDown(self):
ps.reset_option("test.config")
del config._options_dict["test.config"]
del config._options_dict["test.config.list"]
del config._options_dict["test.config.float"]
del config._options_dict["test.config.int"]
del config._options_dict["test.config.int.none"]
def test_get_set_reset_option(self):
self.assertEqual(ps.get_option("test.config"), "default")
ps.set_option("test.config", "value")
self.assertEqual(ps.get_option("test.config"), "value")
ps.reset_option("test.config")
self.assertEqual(ps.get_option("test.config"), "default")
def test_get_set_reset_option_different_types(self):
ps.set_option("test.config.list", [1, 2, 3, 4])
self.assertEqual(ps.get_option("test.config.list"), [1, 2, 3, 4])
ps.set_option("test.config.float", 5.0)
self.assertEqual(ps.get_option("test.config.float"), 5.0)
ps.set_option("test.config.int", 123)
self.assertEqual(ps.get_option("test.config.int"), 123)
self.assertEqual(ps.get_option("test.config.int.none"), None) # default None
ps.set_option("test.config.int.none", 123)
self.assertEqual(ps.get_option("test.config.int.none"), 123)
ps.set_option("test.config.int.none", None)
self.assertEqual(ps.get_option("test.config.int.none"), None)
def test_different_types(self):
with self.assertRaisesRegex(ValueError, "was <class 'int'>"):
ps.set_option("test.config.list", 1)
with self.assertRaisesRegex(ValueError, "however, expected types are"):
ps.set_option("test.config.float", "abc")
with self.assertRaisesRegex(ValueError, "[<class 'int'>]"):
ps.set_option("test.config.int", "abc")
with self.assertRaisesRegex(ValueError, "(<class 'int'>, <class 'NoneType'>)"):
ps.set_option("test.config.int.none", "abc")
def test_check_func(self):
with self.assertRaisesRegex(ValueError, "bigger then 0"):
ps.set_option("test.config.int", -1)
def test_unknown_option(self):
with self.assertRaisesRegex(config.OptionError, "No such option"):
ps.get_option("unknown")
with self.assertRaisesRegex(config.OptionError, "Available options"):
ps.set_option("unknown", "value")
with self.assertRaisesRegex(config.OptionError, "test.config"):
ps.reset_option("unknown")
def test_namespace_access(self):
try:
self.assertEqual(ps.options.compute.max_rows, ps.get_option("compute.max_rows"))
ps.options.compute.max_rows = 0
self.assertEqual(ps.options.compute.max_rows, 0)
self.assertTrue(isinstance(ps.options.compute, DictWrapper))
wrapper = ps.options.compute
self.assertEqual(wrapper.max_rows, ps.get_option("compute.max_rows"))
wrapper.max_rows = 1000
self.assertEqual(ps.options.compute.max_rows, 1000)
self.assertRaisesRegex(config.OptionError, "No such option", lambda: ps.options.compu)
self.assertRaisesRegex(
config.OptionError, "No such option", lambda: ps.options.compute.max
)
self.assertRaisesRegex(
config.OptionError, "No such option", lambda: ps.options.max_rows1
)
with self.assertRaisesRegex(config.OptionError, "No such option"):
ps.options.compute.max = 0
with self.assertRaisesRegex(config.OptionError, "No such option"):
ps.options.compute = 0
with self.assertRaisesRegex(config.OptionError, "No such option"):
ps.options.com = 0
finally:
ps.reset_option("compute.max_rows")
def test_dir_options(self):
self.assertTrue("compute.default_index_type" in dir(ps.options))
self.assertTrue("plotting.sample_ratio" in dir(ps.options))
self.assertTrue("default_index_type" in dir(ps.options.compute))
self.assertTrue("sample_ratio" not in dir(ps.options.compute))
self.assertTrue("default_index_type" not in dir(ps.options.plotting))
self.assertTrue("sample_ratio" in dir(ps.options.plotting))
if __name__ == "__main__":
import unittest
from pyspark.pandas.tests.test_config import * # noqa: F401
try:
import xmlrunner # type: ignore[import]
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
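
For context, `test_config.py` exercises the public options API of pandas-on-Spark. A minimal usage sketch (option values are illustrative):

```python
# Illustrative use of the option API covered by test_config.py.
from pyspark import pandas as ps

ps.set_option("compute.max_rows", 2000)           # set a registered option
print(ps.get_option("compute.max_rows"))          # 2000
print(ps.options.compute.max_rows)                # attribute-style access via DictWrapper
ps.reset_option("compute.max_rows")               # restore the default

with ps.option_context("compute.max_rows", 10):   # temporary override inside the block
    print(ps.get_option("compute.max_rows"))      # 10
```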

python/pyspark/pandas/tests/test_default_index.py

@@ -0,0 +1,51 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pandas as pd
from pyspark import pandas as ps
from pyspark.pandas.testing.utils import ReusedSQLTestCase
class DefaultIndexTest(ReusedSQLTestCase):
def test_default_index_sequence(self):
with ps.option_context("compute.default_index_type", "sequence"):
sdf = self.spark.range(1000)
self.assert_eq(ps.DataFrame(sdf), pd.DataFrame({"id": list(range(1000))}))
def test_default_index_distributed_sequence(self):
with ps.option_context("compute.default_index_type", "distributed-sequence"):
sdf = self.spark.range(1000)
self.assert_eq(ps.DataFrame(sdf), pd.DataFrame({"id": list(range(1000))}))
def test_default_index_distributed(self):
with ps.option_context("compute.default_index_type", "distributed"):
sdf = self.spark.range(1000)
pdf = ps.DataFrame(sdf).to_pandas()
self.assertEqual(len(set(pdf.index)), len(pdf))
if __name__ == "__main__":
import unittest
from pyspark.pandas.tests.test_default_index import * # noqa: F401
try:
import xmlrunner # type: ignore[import]
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
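
For context, the three `compute.default_index_type` values under test control how an index is attached when a Spark DataFrame without one is converted. A minimal sketch (requires an active SparkSession):

```python
# Illustrative: converting a Spark DataFrame under each default index type
# covered by test_default_index.py.
from pyspark.sql import SparkSession
from pyspark import pandas as ps

spark = SparkSession.builder.getOrCreate()
sdf = spark.range(5)

with ps.option_context("compute.default_index_type", "sequence"):
    print(ps.DataFrame(sdf).index)   # contiguous 0..4; may gather data onto one node
with ps.option_context("compute.default_index_type", "distributed-sequence"):
    print(ps.DataFrame(sdf).index)   # contiguous 0..4, computed in a distributed manner
with ps.option_context("compute.default_index_type", "distributed"):
    print(ps.DataFrame(sdf).index)   # monotonically increasing, not necessarily contiguous
```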

python/pyspark/pandas/tests/test_extension.py

@@ -0,0 +1,151 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import contextlib
import numpy as np
import pandas as pd
from pyspark import pandas as ps
from pyspark.pandas.testing.utils import assert_produces_warning, ReusedSQLTestCase
from pyspark.pandas.extensions import (
register_dataframe_accessor,
register_series_accessor,
register_index_accessor,
)
@contextlib.contextmanager
def ensure_removed(obj, attr):
"""
Ensure attribute attached to 'obj' during testing is removed in the end
"""
try:
yield
finally:
try:
delattr(obj, attr)
except AttributeError:
pass
class CustomAccessor:
def __init__(self, obj):
self.obj = obj
self.item = "item"
@property
def prop(self):
return self.item
def method(self):
return self.item
def check_length(self, col=None):
if type(self.obj) == ps.DataFrame or col is not None:
return len(self.obj[col])
else:
try:
return len(self.obj)
except Exception as e:
raise ValueError(str(e))
class ExtensionTest(ReusedSQLTestCase):
@property
def pdf(self):
return pd.DataFrame(
{"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [4, 5, 6, 3, 2, 1, 0, 0, 0]},
index=np.random.rand(9),
)
@property
def kdf(self):
return ps.from_pandas(self.pdf)
@property
def accessor(self):
return CustomAccessor(self.kdf)
def test_setup(self):
self.assertEqual("item", self.accessor.item)
def test_dataframe_register(self):
with ensure_removed(ps.DataFrame, "test"):
register_dataframe_accessor("test")(CustomAccessor)
assert self.kdf.test.prop == "item"
assert self.kdf.test.method() == "item"
assert len(self.kdf["a"]) == self.kdf.test.check_length("a")
def test_series_register(self):
with ensure_removed(ps.Series, "test"):
register_series_accessor("test")(CustomAccessor)
assert self.kdf.a.test.prop == "item"
assert self.kdf.a.test.method() == "item"
assert self.kdf.a.test.check_length() == len(self.kdf["a"])
def test_index_register(self):
with ensure_removed(ps.Index, "test"):
register_index_accessor("test")(CustomAccessor)
assert self.kdf.index.test.prop == "item"
assert self.kdf.index.test.method() == "item"
assert self.kdf.index.test.check_length() == self.kdf.index.size
def test_accessor_works(self):
register_series_accessor("test")(CustomAccessor)
s = ps.Series([1, 2])
assert s.test.obj is s
assert s.test.prop == "item"
assert s.test.method() == "item"
def test_overwrite_warns(self):
mean = ps.Series.mean
try:
with assert_produces_warning(UserWarning, raise_on_extra_warnings=False) as w:
register_series_accessor("mean")(CustomAccessor)
s = ps.Series([1, 2])
assert s.mean.prop == "item"
msg = str(w[0].message)
assert "mean" in msg
assert "CustomAccessor" in msg
assert "Series" in msg
finally:
ps.Series.mean = mean
def test_raises_attr_error(self):
with ensure_removed(ps.Series, "bad"):
class Bad:
def __init__(self, data):
raise AttributeError("whoops")
with self.assertRaises(AttributeError):
ps.Series([1, 2], dtype=object).bad
if __name__ == "__main__":
import unittest
from pyspark.pandas.tests.test_extension import * # noqa: F401
try:
import xmlrunner # type: ignore[import]
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
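
For context, `test_extension.py` covers the accessor registration decorators. A minimal sketch of a custom DataFrame accessor (the accessor name and the "lat"/"lon" columns are made up for illustration):

```python
# Illustrative custom accessor, mirroring what test_extension.py registers.
from pyspark import pandas as ps
from pyspark.pandas.extensions import register_dataframe_accessor


@register_dataframe_accessor("geo")      # "geo" is an arbitrary example name
class GeoAccessor:
    def __init__(self, psdf):
        self._psdf = psdf

    def center(self):
        # Toy computation over the assumed "lat"/"lon" columns.
        return (self._psdf["lat"].mean(), self._psdf["lon"].mean())


psdf = ps.DataFrame({"lat": [0.0, 10.0], "lon": [20.0, 40.0]})
print(psdf.geo.center())                 # (5.0, 30.0)
```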

python/pyspark/pandas/tests/test_internal.py

@@ -0,0 +1,103 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pandas as pd
from pyspark.pandas.internal import (
InternalFrame,
SPARK_DEFAULT_INDEX_NAME,
SPARK_INDEX_NAME_FORMAT,
)
from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils
class InternalFrameTest(ReusedSQLTestCase, SQLTestUtils):
def test_from_pandas(self):
pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
internal = InternalFrame.from_pandas(pdf)
sdf = internal.spark_frame
self.assert_eq(internal.index_spark_column_names, [SPARK_DEFAULT_INDEX_NAME])
self.assert_eq(internal.index_names, [None])
self.assert_eq(internal.column_labels, [("a",), ("b",)])
self.assert_eq(internal.data_spark_column_names, ["a", "b"])
self.assertTrue(internal.spark_column_for(("a",))._jc.equals(sdf["a"]._jc))
self.assertTrue(internal.spark_column_for(("b",))._jc.equals(sdf["b"]._jc))
self.assert_eq(internal.to_pandas_frame, pdf)
# non-string column name
pdf1 = pd.DataFrame({0: [1, 2, 3], 1: [4, 5, 6]})
internal = InternalFrame.from_pandas(pdf1)
sdf = internal.spark_frame
self.assert_eq(internal.index_spark_column_names, [SPARK_DEFAULT_INDEX_NAME])
self.assert_eq(internal.index_names, [None])
self.assert_eq(internal.column_labels, [(0,), (1,)])
self.assert_eq(internal.data_spark_column_names, ["0", "1"])
self.assertTrue(internal.spark_column_for((0,))._jc.equals(sdf["0"]._jc))
self.assertTrue(internal.spark_column_for((1,))._jc.equals(sdf["1"]._jc))
self.assert_eq(internal.to_pandas_frame, pdf1)
# multi-index
pdf.set_index("a", append=True, inplace=True)
internal = InternalFrame.from_pandas(pdf)
sdf = internal.spark_frame
self.assert_eq(
internal.index_spark_column_names,
[SPARK_INDEX_NAME_FORMAT(0), SPARK_INDEX_NAME_FORMAT(1)],
)
self.assert_eq(internal.index_names, [None, ("a",)])
self.assert_eq(internal.column_labels, [("b",)])
self.assert_eq(internal.data_spark_column_names, ["b"])
self.assertTrue(internal.spark_column_for(("b",))._jc.equals(sdf["b"]._jc))
self.assert_eq(internal.to_pandas_frame, pdf)
# multi-index columns
pdf.columns = pd.MultiIndex.from_tuples([("x", "b")])
internal = InternalFrame.from_pandas(pdf)
sdf = internal.spark_frame
self.assert_eq(
internal.index_spark_column_names,
[SPARK_INDEX_NAME_FORMAT(0), SPARK_INDEX_NAME_FORMAT(1)],
)
self.assert_eq(internal.index_names, [None, ("a",)])
self.assert_eq(internal.column_labels, [("x", "b")])
self.assert_eq(internal.data_spark_column_names, ["(x, b)"])
self.assertTrue(internal.spark_column_for(("x", "b"))._jc.equals(sdf["(x, b)"]._jc))
self.assert_eq(internal.to_pandas_frame, pdf)
if __name__ == "__main__":
import unittest
from pyspark.pandas.tests.test_internal import * # noqa: F401
try:
import xmlrunner # type: ignore[import]
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
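
For context, `InternalFrame` is the internal immutable structure that maps pandas index and column labels onto Spark columns; it is not a user-facing API. A minimal inspection sketch based only on the properties the test touches:

```python
# Illustrative peek at InternalFrame, the structure covered by test_internal.py.
import pandas as pd
from pyspark.pandas.internal import InternalFrame, SPARK_DEFAULT_INDEX_NAME

pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
internal = InternalFrame.from_pandas(pdf)

print(internal.index_spark_column_names == [SPARK_DEFAULT_INDEX_NAME])  # True
print(internal.column_labels)             # [('a',), ('b',)]
print(internal.data_spark_column_names)   # ['a', 'b']
internal.spark_frame.show()               # underlying Spark DataFrame, index column included
print(internal.to_pandas_frame)           # round-trips back to the original pandas frame
```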

python/pyspark/pandas/tests/test_numpy_compat.py

@@ -0,0 +1,210 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from distutils.version import LooseVersion
import numpy as np
import pandas as pd
from pyspark import pandas as ps
from pyspark.pandas import set_option, reset_option
from pyspark.pandas.numpy_compat import unary_np_spark_mappings, binary_np_spark_mappings
from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils
class NumPyCompatTest(ReusedSQLTestCase, SQLTestUtils):
blacklist = [
# Koalas does not currently support
"conj",
"conjugate",
"isnat",
"matmul",
"frexp",
# Values are close enough but tests failed.
"arccos",
"exp",
"expm1",
"log", # flaky
"log10", # flaky
"log1p", # flaky
"modf",
"floor_divide", # flaky
# Results seem inconsistent in a different version of, I (Hyukjin) suspect, PyArrow.
# From PyArrow 0.15, seems it returns the correct results via PySpark. Probably we
# can enable it later when Koalas switches to PyArrow 0.15 completely.
"left_shift",
]
@property
def pdf(self):
return pd.DataFrame(
{"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [4, 5, 6, 3, 2, 1, 0, 0, 0]},
index=[0, 1, 3, 5, 6, 8, 9, 9, 9],
)
@property
def kdf(self):
return ps.from_pandas(self.pdf)
def test_np_add_series(self):
kdf = self.kdf
pdf = self.pdf
if LooseVersion(pd.__version__) < LooseVersion("0.25"):
self.assert_eq(np.add(kdf.a, kdf.b), np.add(pdf.a, pdf.b).rename())
else:
self.assert_eq(np.add(kdf.a, kdf.b), np.add(pdf.a, pdf.b))
kdf = self.kdf
pdf = self.pdf
self.assert_eq(np.add(kdf.a, 1), np.add(pdf.a, 1))
def test_np_add_index(self):
k_index = self.kdf.index
p_index = self.pdf.index
self.assert_eq(np.add(k_index, k_index), np.add(p_index, p_index))
def test_np_unsupported_series(self):
kdf = self.kdf
with self.assertRaisesRegex(NotImplementedError, "Koalas.*not.*support.*sqrt.*"):
np.sqrt(kdf.a, kdf.b)
def test_np_unsupported_frame(self):
kdf = self.kdf
with self.assertRaisesRegex(NotImplementedError, "Koalas.*not.*support.*sqrt.*"):
np.sqrt(kdf, kdf)
def test_np_spark_compat_series(self):
# Use randomly generated dataFrame
pdf = pd.DataFrame(
np.random.randint(-100, 100, size=(np.random.randint(100), 2)), columns=["a", "b"]
)
pdf2 = pd.DataFrame(
np.random.randint(-100, 100, size=(len(pdf), len(pdf.columns))), columns=["a", "b"]
)
kdf = ps.from_pandas(pdf)
kdf2 = ps.from_pandas(pdf2)
for np_name, spark_func in unary_np_spark_mappings.items():
np_func = getattr(np, np_name)
if np_name not in self.blacklist:
try:
# unary ufunc
self.assert_eq(np_func(pdf.a), np_func(kdf.a), almost=True)
except Exception as e:
raise AssertionError("Test in '%s' function was failed." % np_name) from e
for np_name, spark_func in binary_np_spark_mappings.items():
np_func = getattr(np, np_name)
if np_name not in self.blacklist:
try:
# binary ufunc
if LooseVersion(pd.__version__) < LooseVersion("0.25"):
self.assert_eq(
np_func(pdf.a, pdf.b).rename(), np_func(kdf.a, kdf.b), almost=True
)
else:
self.assert_eq(np_func(pdf.a, pdf.b), np_func(kdf.a, kdf.b), almost=True)
self.assert_eq(np_func(pdf.a, 1), np_func(kdf.a, 1), almost=True)
except Exception as e:
raise AssertionError("Test in '%s' function was failed." % np_name) from e
# Test only top 5 for now. 'compute.ops_on_diff_frames' option increases too much time.
try:
set_option("compute.ops_on_diff_frames", True)
for np_name, spark_func in list(binary_np_spark_mappings.items())[:5]:
np_func = getattr(np, np_name)
if np_name not in self.blacklist:
try:
# binary ufunc
if LooseVersion(pd.__version__) < LooseVersion("0.25"):
self.assert_eq(
np_func(pdf.a, pdf2.b).sort_index().rename(),
np_func(kdf.a, kdf2.b).sort_index(),
almost=True,
)
else:
self.assert_eq(
np_func(pdf.a, pdf2.b).sort_index(),
np_func(kdf.a, kdf2.b).sort_index(),
almost=True,
)
except Exception as e:
raise AssertionError("Test in '%s' function was failed." % np_name) from e
finally:
reset_option("compute.ops_on_diff_frames")
def test_np_spark_compat_frame(self):
# Use randomly generated dataFrame
pdf = pd.DataFrame(
np.random.randint(-100, 100, size=(np.random.randint(100), 2)), columns=["a", "b"]
)
pdf2 = pd.DataFrame(
np.random.randint(-100, 100, size=(len(pdf), len(pdf.columns))), columns=["a", "b"]
)
kdf = ps.from_pandas(pdf)
kdf2 = ps.from_pandas(pdf2)
for np_name, spark_func in unary_np_spark_mappings.items():
np_func = getattr(np, np_name)
if np_name not in self.blacklist:
try:
# unary ufunc
self.assert_eq(np_func(pdf), np_func(kdf), almost=True)
except Exception as e:
raise AssertionError("Test in '%s' function was failed." % np_name) from e
for np_name, spark_func in binary_np_spark_mappings.items():
np_func = getattr(np, np_name)
if np_name not in self.blacklist:
try:
# binary ufunc
self.assert_eq(np_func(pdf, pdf), np_func(kdf, kdf), almost=True)
self.assert_eq(np_func(pdf, 1), np_func(kdf, 1), almost=True)
except Exception as e:
raise AssertionError("Test in '%s' function was failed." % np_name) from e
# Test only top 5 for now. 'compute.ops_on_diff_frames' option increases too much time.
try:
set_option("compute.ops_on_diff_frames", True)
for np_name, spark_func in list(binary_np_spark_mappings.items())[:5]:
np_func = getattr(np, np_name)
if np_name not in self.blacklist:
try:
# binary ufunc
self.assert_eq(
np_func(pdf, pdf2).sort_index(),
np_func(kdf, kdf2).sort_index(),
almost=True,
)
except Exception as e:
raise AssertionError("Test in '%s' function was failed." % np_name) from e
finally:
reset_option("compute.ops_on_diff_frames")
if __name__ == "__main__":
import unittest
from pyspark.pandas.tests.test_numpy_compat import * # noqa: F401
try:
import xmlrunner # type: ignore[import]
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
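
For context, `test_numpy_compat.py` checks that NumPy universal functions dispatched on pandas-on-Spark objects match the pandas results. A minimal sketch:

```python
# Illustrative: NumPy ufuncs applied to a pandas-on-Spark Series,
# the dispatching covered by test_numpy_compat.py.
import numpy as np
from pyspark import pandas as ps

psser = ps.Series([1, 2, 3])
print(np.add(psser, psser))   # elementwise 2, 4, 6, evaluated as a Spark expression
print(np.add(psser, 1))       # scalar broadcast, matching pandas semantics
```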

python/pyspark/pandas/tests/test_typedef.py

@@ -0,0 +1,437 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
import unittest
import datetime
import decimal
from typing import List
import pandas
import pandas as pd
from pandas.api.types import CategoricalDtype
import numpy as np
from pyspark.sql.types import (
ArrayType,
BinaryType,
BooleanType,
FloatType,
IntegerType,
LongType,
StringType,
StructField,
StructType,
ByteType,
ShortType,
DateType,
DecimalType,
DoubleType,
TimestampType,
)
from pyspark.pandas.typedef import (
as_spark_type,
extension_dtypes_available,
extension_float_dtypes_available,
extension_object_dtypes_available,
infer_return_type,
koalas_dtype,
)
from pyspark import pandas as ps
class TypeHintTests(unittest.TestCase):
@unittest.skipIf(
sys.version_info < (3, 7),
"Type inference from pandas instances is supported with Python 3.7+",
)
def test_infer_schema_from_pandas_instances(self):
def func() -> pd.Series[int]:
pass
inferred = infer_return_type(func)
self.assertEqual(inferred.dtype, np.int64)
self.assertEqual(inferred.spark_type, LongType())
def func() -> pd.Series[np.float]:
pass
inferred = infer_return_type(func)
self.assertEqual(inferred.dtype, np.float64)
self.assertEqual(inferred.spark_type, DoubleType())
def func() -> "pd.DataFrame[np.float, str]":
pass
expected = StructType([StructField("c0", DoubleType()), StructField("c1", StringType())])
inferred = infer_return_type(func)
self.assertEqual(inferred.dtypes, [np.float64, np.unicode_])
self.assertEqual(inferred.spark_type, expected)
def func() -> "pandas.DataFrame[np.float]":
pass
expected = StructType([StructField("c0", DoubleType())])
inferred = infer_return_type(func)
self.assertEqual(inferred.dtypes, [np.float64])
self.assertEqual(inferred.spark_type, expected)
def func() -> "pd.Series[int]":
pass
inferred = infer_return_type(func)
self.assertEqual(inferred.dtype, np.int64)
self.assertEqual(inferred.spark_type, LongType())
def func() -> pd.DataFrame[np.float, str]:
pass
expected = StructType([StructField("c0", DoubleType()), StructField("c1", StringType())])
inferred = infer_return_type(func)
self.assertEqual(inferred.dtypes, [np.float64, np.unicode_])
self.assertEqual(inferred.spark_type, expected)
def func() -> pd.DataFrame[np.float]:
pass
expected = StructType([StructField("c0", DoubleType())])
inferred = infer_return_type(func)
self.assertEqual(inferred.dtypes, [np.float64])
self.assertEqual(inferred.spark_type, expected)
pdf = pd.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]})
def func() -> pd.DataFrame[pdf.dtypes]: # type: ignore
pass
expected = StructType([StructField("c0", LongType()), StructField("c1", LongType())])
inferred = infer_return_type(func)
self.assertEqual(inferred.dtypes, [np.int64, np.int64])
self.assertEqual(inferred.spark_type, expected)
pdf = pd.DataFrame({"a": [1, 2, 3], "b": pd.Categorical(["a", "b", "c"])})
def func() -> pd.Series[pdf.b.dtype]: # type: ignore
pass
inferred = infer_return_type(func)
self.assertEqual(inferred.dtype, CategoricalDtype(categories=["a", "b", "c"]))
self.assertEqual(inferred.spark_type, LongType())
def func() -> pd.DataFrame[pdf.dtypes]: # type: ignore
pass
expected = StructType([StructField("c0", LongType()), StructField("c1", LongType())])
inferred = infer_return_type(func)
self.assertEqual(inferred.dtypes, [np.int64, CategoricalDtype(categories=["a", "b", "c"])])
self.assertEqual(inferred.spark_type, expected)
def test_if_pandas_implements_class_getitem(self):
# the current type hint implementation of pandas DataFrame assumes pandas doesn't
# implement '__class_getitem__'. This test case is to make sure pandas
# doesn't implement them.
assert not ps._frame_has_class_getitem
assert not ps._series_has_class_getitem
@unittest.skipIf(
sys.version_info < (3, 7),
"Type inference from pandas instances is supported with Python 3.7+",
)
def test_infer_schema_with_names_pandas_instances(self):
def func() -> 'pd.DataFrame["a" : np.float, "b":str]': # noqa: F405
pass
expected = StructType([StructField("a", DoubleType()), StructField("b", StringType())])
inferred = infer_return_type(func)
self.assertEqual(inferred.dtypes, [np.float64, np.unicode_])
self.assertEqual(inferred.spark_type, expected)
def func() -> "pd.DataFrame['a': np.float, 'b': int]": # noqa: F405
pass
expected = StructType([StructField("a", DoubleType()), StructField("b", LongType())])
inferred = infer_return_type(func)
self.assertEqual(inferred.dtypes, [np.float64, np.int64])
self.assertEqual(inferred.spark_type, expected)
pdf = pd.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]})
def func() -> pd.DataFrame[zip(pdf.columns, pdf.dtypes)]:
pass
expected = StructType([StructField("a", LongType()), StructField("b", LongType())])
inferred = infer_return_type(func)
self.assertEqual(inferred.dtypes, [np.int64, np.int64])
self.assertEqual(inferred.spark_type, expected)
pdf = pd.DataFrame({("x", "a"): [1, 2, 3], ("y", "b"): [3, 4, 5]})
def func() -> pd.DataFrame[zip(pdf.columns, pdf.dtypes)]:
pass
expected = StructType(
[StructField("(x, a)", LongType()), StructField("(y, b)", LongType())]
)
inferred = infer_return_type(func)
self.assertEqual(inferred.dtypes, [np.int64, np.int64])
self.assertEqual(inferred.spark_type, expected)
pdf = pd.DataFrame({"a": [1, 2, 3], "b": pd.Categorical(["a", "b", "c"])})
def func() -> pd.DataFrame[zip(pdf.columns, pdf.dtypes)]:
pass
expected = StructType([StructField("a", LongType()), StructField("b", LongType())])
inferred = infer_return_type(func)
self.assertEqual(inferred.dtypes, [np.int64, CategoricalDtype(categories=["a", "b", "c"])])
self.assertEqual(inferred.spark_type, expected)
@unittest.skipIf(
sys.version_info < (3, 7),
"Type inference from pandas instances is supported with Python 3.7+",
)
def test_infer_schema_with_names_pandas_instances_negative(self):
def try_infer_return_type():
def f() -> 'pd.DataFrame["a" : np.float : 1, "b":str:2]': # noqa: F405
pass
infer_return_type(f)
self.assertRaisesRegex(TypeError, "Type hints should be specified", try_infer_return_type)
class A:
pass
def try_infer_return_type():
def f() -> pd.DataFrame[A]:
pass
infer_return_type(f)
self.assertRaisesRegex(TypeError, "not understood", try_infer_return_type)
def try_infer_return_type():
def f() -> 'pd.DataFrame["a" : np.float : 1, "b":str:2]': # noqa: F405
pass
infer_return_type(f)
self.assertRaisesRegex(TypeError, "Type hints should be specified", try_infer_return_type)
# object type
pdf = pd.DataFrame({"a": ["a", 2, None]})
def try_infer_return_type():
def f() -> pd.DataFrame[pdf.dtypes]: # type: ignore
pass
infer_return_type(f)
self.assertRaisesRegex(TypeError, "object.*not understood", try_infer_return_type)
def try_infer_return_type():
def f() -> pd.Series[pdf.a.dtype]: # type: ignore
pass
infer_return_type(f)
self.assertRaisesRegex(TypeError, "object.*not understood", try_infer_return_type)
def test_infer_schema_with_names_negative(self):
def try_infer_return_type():
def f() -> 'ps.DataFrame["a" : np.float : 1, "b":str:2]': # noqa: F405
pass
infer_return_type(f)
self.assertRaisesRegex(TypeError, "Type hints should be specified", try_infer_return_type)
class A:
pass
def try_infer_return_type():
def f() -> ps.DataFrame[A]:
pass
infer_return_type(f)
self.assertRaisesRegex(TypeError, "not understood", try_infer_return_type)
def try_infer_return_type():
def f() -> 'ps.DataFrame["a" : np.float : 1, "b":str:2]': # noqa: F405
pass
infer_return_type(f)
self.assertRaisesRegex(TypeError, "Type hints should be specified", try_infer_return_type)
# object type
pdf = pd.DataFrame({"a": ["a", 2, None]})
def try_infer_return_type():
def f() -> ps.DataFrame[pdf.dtypes]: # type: ignore
pass
infer_return_type(f)
self.assertRaisesRegex(TypeError, "object.*not understood", try_infer_return_type)
def try_infer_return_type():
def f() -> ps.Series[pdf.a.dtype]: # type: ignore
pass
infer_return_type(f)
self.assertRaisesRegex(TypeError, "object.*not understood", try_infer_return_type)
def test_as_spark_type_koalas_dtype(self):
type_mapper = {
# binary
np.character: (np.character, BinaryType()),
np.bytes_: (np.bytes_, BinaryType()),
np.string_: (np.bytes_, BinaryType()),
bytes: (np.bytes_, BinaryType()),
# integer
np.int8: (np.int8, ByteType()),
np.byte: (np.int8, ByteType()),
np.int16: (np.int16, ShortType()),
np.int32: (np.int32, IntegerType()),
np.int64: (np.int64, LongType()),
np.int: (np.int64, LongType()),
int: (np.int64, LongType()),
# floating
np.float32: (np.float32, FloatType()),
np.float: (np.float64, DoubleType()),
np.float64: (np.float64, DoubleType()),
float: (np.float64, DoubleType()),
# string
np.str: (np.unicode_, StringType()),
np.unicode_: (np.unicode_, StringType()),
str: (np.unicode_, StringType()),
# bool
np.bool: (np.bool, BooleanType()),
bool: (np.bool, BooleanType()),
# datetime
np.datetime64: (np.datetime64, TimestampType()),
datetime.datetime: (np.dtype("datetime64[ns]"), TimestampType()),
# DateType
datetime.date: (np.dtype("object"), DateType()),
# DecimalType
decimal.Decimal: (np.dtype("object"), DecimalType(38, 18)),
# ArrayType
np.ndarray: (np.dtype("object"), ArrayType(StringType())),
List[bytes]: (np.dtype("object"), ArrayType(BinaryType())),
List[np.character]: (np.dtype("object"), ArrayType(BinaryType())),
List[np.bytes_]: (np.dtype("object"), ArrayType(BinaryType())),
List[np.string_]: (np.dtype("object"), ArrayType(BinaryType())),
List[bool]: (np.dtype("object"), ArrayType(BooleanType())),
List[np.bool]: (np.dtype("object"), ArrayType(BooleanType())),
List[datetime.date]: (np.dtype("object"), ArrayType(DateType())),
List[np.int8]: (np.dtype("object"), ArrayType(ByteType())),
List[np.byte]: (np.dtype("object"), ArrayType(ByteType())),
List[decimal.Decimal]: (np.dtype("object"), ArrayType(DecimalType(38, 18))),
List[float]: (np.dtype("object"), ArrayType(DoubleType())),
List[np.float]: (np.dtype("object"), ArrayType(DoubleType())),
List[np.float64]: (np.dtype("object"), ArrayType(DoubleType())),
List[np.float32]: (np.dtype("object"), ArrayType(FloatType())),
List[np.int32]: (np.dtype("object"), ArrayType(IntegerType())),
List[int]: (np.dtype("object"), ArrayType(LongType())),
List[np.int]: (np.dtype("object"), ArrayType(LongType())),
List[np.int64]: (np.dtype("object"), ArrayType(LongType())),
List[np.int16]: (np.dtype("object"), ArrayType(ShortType())),
List[str]: (np.dtype("object"), ArrayType(StringType())),
List[np.unicode_]: (np.dtype("object"), ArrayType(StringType())),
List[datetime.datetime]: (np.dtype("object"), ArrayType(TimestampType())),
List[np.datetime64]: (np.dtype("object"), ArrayType(TimestampType())),
# CategoricalDtype
CategoricalDtype(categories=["a", "b", "c"]): (
CategoricalDtype(categories=["a", "b", "c"]),
LongType(),
),
}
for numpy_or_python_type, (dtype, spark_type) in type_mapper.items():
self.assertEqual(as_spark_type(numpy_or_python_type), spark_type)
self.assertEqual(koalas_dtype(numpy_or_python_type), (dtype, spark_type))
with self.assertRaisesRegex(TypeError, "Type uint64 was not understood."):
as_spark_type(np.dtype("uint64"))
with self.assertRaisesRegex(TypeError, "Type object was not understood."):
as_spark_type(np.dtype("object"))
with self.assertRaisesRegex(TypeError, "Type uint64 was not understood."):
koalas_dtype(np.dtype("uint64"))
with self.assertRaisesRegex(TypeError, "Type object was not understood."):
koalas_dtype(np.dtype("object"))
@unittest.skipIf(not extension_dtypes_available, "The pandas extension types are not available")
def test_as_spark_type_extension_dtypes(self):
from pandas import Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype
type_mapper = {
Int8Dtype(): ByteType(),
Int16Dtype(): ShortType(),
Int32Dtype(): IntegerType(),
Int64Dtype(): LongType(),
}
for extension_dtype, spark_type in type_mapper.items():
self.assertEqual(as_spark_type(extension_dtype), spark_type)
self.assertEqual(koalas_dtype(extension_dtype), (extension_dtype, spark_type))
@unittest.skipIf(
not extension_object_dtypes_available, "The pandas extension object types are not available"
)
def test_as_spark_type_extension_object_dtypes(self):
from pandas import BooleanDtype, StringDtype
type_mapper = {
BooleanDtype(): BooleanType(),
StringDtype(): StringType(),
}
for extension_dtype, spark_type in type_mapper.items():
self.assertEqual(as_spark_type(extension_dtype), spark_type)
self.assertEqual(koalas_dtype(extension_dtype), (extension_dtype, spark_type))
@unittest.skipIf(
not extension_float_dtypes_available, "The pandas extension float types are not available"
)
def test_as_spark_type_extension_float_dtypes(self):
from pandas import Float32Dtype, Float64Dtype
type_mapper = {
Float32Dtype(): FloatType(),
Float64Dtype(): DoubleType(),
}
for extension_dtype, spark_type in type_mapper.items():
self.assertEqual(as_spark_type(extension_dtype), spark_type)
self.assertEqual(koalas_dtype(extension_dtype), (extension_dtype, spark_type))
if __name__ == "__main__":
from pyspark.pandas.tests.test_typedef import * # noqa: F401
try:
import xmlrunner # type: ignore[import]
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
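
For context, `test_typedef.py` covers the translation from Python, NumPy, and pandas types to Spark SQL types. A minimal sketch using only mappings asserted by the test:

```python
# Illustrative: mapping Python/NumPy types to Spark SQL types, as covered by test_typedef.py.
import datetime
import numpy as np
from pyspark.pandas.typedef import as_spark_type, koalas_dtype

print(as_spark_type(int))            # LongType
print(as_spark_type(np.float32))     # FloatType
print(as_spark_type(str))            # StringType
print(as_spark_type(datetime.date))  # DateType
print(koalas_dtype(float))           # (dtype('float64'), DoubleType)
```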

python/pyspark/pandas/tests/test_utils.py

@@ -0,0 +1,104 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pandas as pd
from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils
from pyspark.pandas.utils import (
lazy_property,
validate_arguments_and_invoke_function,
validate_bool_kwarg,
)
some_global_variable = 0
class UtilsTest(ReusedSQLTestCase, SQLTestUtils):
# a dummy to_html version with an extra parameter that pandas does not support
# used in test_validate_arguments_and_invoke_function
def to_html(self, max_rows=None, unsupported_param=None):
args = locals()
pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=[0, 1, 3])
validate_arguments_and_invoke_function(pdf, self.to_html, pd.DataFrame.to_html, args)
def to_clipboard(self, sep=",", **kwargs):
args = locals()
pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=[0, 1, 3])
validate_arguments_and_invoke_function(
pdf, self.to_clipboard, pd.DataFrame.to_clipboard, args
)
# Support for **kwargs
self.to_clipboard(sep=",", index=False)
def test_validate_arguments_and_invoke_function(self):
# This should pass and run fine
self.to_html()
self.to_html(unsupported_param=None)
self.to_html(max_rows=5)
# This should fail because we are explicitly setting an unsupported param
# to a non-default value
with self.assertRaises(TypeError):
self.to_html(unsupported_param=1)
def test_lazy_property(self):
obj = TestClassForLazyProp()
# If lazy prop is not working, the second test would fail (because it'd be 2)
self.assert_eq(obj.lazy_prop, 1)
self.assert_eq(obj.lazy_prop, 1)
def test_validate_bool_kwarg(self):
# This should pass and run fine
koalas = True
self.assert_eq(validate_bool_kwarg(koalas, "koalas"), True)
koalas = False
self.assert_eq(validate_bool_kwarg(koalas, "koalas"), False)
koalas = None
self.assert_eq(validate_bool_kwarg(koalas, "koalas"), None)
# This should fail because we are explicitly setting a non-boolean value
koalas = "true"
with self.assertRaisesRegex(
ValueError, 'For argument "koalas" expected type bool, received type str.'
):
validate_bool_kwarg(koalas, "koalas")
class TestClassForLazyProp:
def __init__(self):
self.some_variable = 0
@lazy_property
def lazy_prop(self):
self.some_variable += 1
return self.some_variable
if __name__ == "__main__":
import unittest
from pyspark.pandas.tests.test_utils import * # noqa: F401
try:
import xmlrunner # type: ignore[import]
testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
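
For context, `test_utils.py` covers small internal helpers. A minimal sketch of `lazy_property` and `validate_bool_kwarg` (the class and argument names are made up for illustration):

```python
# Illustrative use of the internal helpers covered by test_utils.py.
from pyspark.pandas.utils import lazy_property, validate_bool_kwarg


class Expensive:
    @lazy_property
    def value(self):
        print("computed once")
        return 42


obj = Expensive()
print(obj.value)   # computes and caches the result
print(obj.value)   # served from the cache; "computed once" is not printed again

print(validate_bool_kwarg(True, "inplace"))   # True
print(validate_bool_kwarg(None, "inplace"))   # None is accepted as "not specified"
try:
    validate_bool_kwarg("yes", "inplace")     # non-bool values are rejected
except ValueError as error:
    print(error)
```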