[SPARK-35035][PYTHON] Port Koalas internal implementation unit tests into PySpark
### What changes were proposed in this pull request?

Now that we have merged the Koalas main code into the PySpark code base (#32036), we should port the Koalas internal implementation unit tests to PySpark.

### Why are the changes needed?

Currently, the pandas-on-Spark modules are not fully tested. We should enable the internal implementation unit tests.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

By enabling the internal implementation unit tests.

Closes #32137 from xinrong-databricks/port.test_internal_impl.

Lead-authored-by: Xinrong Meng <xinrong.meng@databricks.com>
Co-authored-by: xinrong-databricks <47337188+xinrong-databricks@users.noreply.github.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
This commit is contained in:
parent
2974b70d1e
commit
47d62af2a9
|
@ -612,6 +612,13 @@ pyspark_pandas = Module(
|
|||
"pyspark.pandas.typedef.typehints",
|
||||
# unittests
|
||||
"pyspark.pandas.tests.test_dataframe",
|
||||
"pyspark.pandas.tests.test_config",
|
||||
"pyspark.pandas.tests.test_default_index",
|
||||
"pyspark.pandas.tests.test_extension",
|
||||
"pyspark.pandas.tests.test_internal",
|
||||
"pyspark.pandas.tests.test_numpy_compat",
|
||||
"pyspark.pandas.tests.test_typedef",
|
||||
"pyspark.pandas.tests.test_utils",
|
||||
"pyspark.pandas.tests.test_dataframe_conversion",
|
||||
"pyspark.pandas.tests.test_dataframe_spark_io",
|
||||
"pyspark.pandas.tests.test_frame_spark",
|
||||
|
|
155
python/pyspark/pandas/tests/test_config.py
Normal file
155
python/pyspark/pandas/tests/test_config.py
Normal file
|
@ -0,0 +1,155 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from pyspark import pandas as ps
|
||||
from pyspark.pandas import config
|
||||
from pyspark.pandas.config import Option, DictWrapper
|
||||
from pyspark.pandas.testing.utils import ReusedSQLTestCase
|
||||
|
||||
|
||||
class ConfigTest(ReusedSQLTestCase):
    """Tests for getting, setting, resetting and browsing pandas-on-Spark options.

    Every test runs against a small set of throwaway options that ``setUp``
    registers in ``config._options_dict`` and ``tearDown`` removes again.
    """

    # Keys of the temporary options registered for each test.
    _TEST_OPTION_KEYS = (
        "test.config",
        "test.config.list",
        "test.config.float",
        "test.config.int",
        "test.config.int.none",
    )

    def setUp(self):
        """Register the temporary options used by the tests."""
        config._options_dict["test.config"] = Option(key="test.config", doc="", default="default")

        config._options_dict["test.config.list"] = Option(
            key="test.config.list", doc="", default=[], types=list
        )
        config._options_dict["test.config.float"] = Option(
            key="test.config.float", doc="", default=1.2, types=float
        )

        config._options_dict["test.config.int"] = Option(
            key="test.config.int",
            doc="",
            default=1,
            types=int,
            check_func=(lambda v: v > 0, "bigger then 0"),
        )
        # The Option's own key must match the key it is registered under.
        config._options_dict["test.config.int.none"] = Option(
            key="test.config.int.none", doc="", default=None, types=(int, type(None))
        )

    def tearDown(self):
        """Reset and unregister every temporary option.

        Each option is reset while it is still registered (``reset_option``
        rejects unknown keys), so values set by one test do not stay set for
        the next one.
        """
        for key in self._TEST_OPTION_KEYS:
            ps.reset_option(key)
            del config._options_dict[key]

    def test_get_set_reset_option(self):
        """A string option round-trips through set/get and returns to its default on reset."""
        self.assertEqual(ps.get_option("test.config"), "default")

        ps.set_option("test.config", "value")
        self.assertEqual(ps.get_option("test.config"), "value")

        ps.reset_option("test.config")
        self.assertEqual(ps.get_option("test.config"), "default")

    def test_get_set_reset_option_different_types(self):
        """List, float, int and optional-int options keep the value that was set."""
        ps.set_option("test.config.list", [1, 2, 3, 4])
        self.assertEqual(ps.get_option("test.config.list"), [1, 2, 3, 4])

        ps.set_option("test.config.float", 5.0)
        self.assertEqual(ps.get_option("test.config.float"), 5.0)

        ps.set_option("test.config.int", 123)
        self.assertEqual(ps.get_option("test.config.int"), 123)

        self.assertEqual(ps.get_option("test.config.int.none"), None)  # default None
        ps.set_option("test.config.int.none", 123)
        self.assertEqual(ps.get_option("test.config.int.none"), 123)
        ps.set_option("test.config.int.none", None)
        self.assertEqual(ps.get_option("test.config.int.none"), None)

    def test_different_types(self):
        """Setting a value of the wrong type raises ValueError naming the expected types."""
        import re

        with self.assertRaisesRegex(ValueError, "was <class 'int'>"):
            ps.set_option("test.config.list", 1)

        with self.assertRaisesRegex(ValueError, "however, expected types are"):
            ps.set_option("test.config.float", "abc")

        # Escape the expected text: unescaped, "[...]" is a regex character class
        # and "(...)" a group, so the patterns would match almost any message.
        with self.assertRaisesRegex(ValueError, re.escape("[<class 'int'>]")):
            ps.set_option("test.config.int", "abc")

        with self.assertRaisesRegex(
            ValueError, re.escape("(<class 'int'>, <class 'NoneType'>)")
        ):
            ps.set_option("test.config.int.none", "abc")

    def test_check_func(self):
        """A value rejected by the option's check function raises ValueError."""
        with self.assertRaisesRegex(ValueError, "bigger then 0"):
            ps.set_option("test.config.int", -1)

    def test_unknown_option(self):
        """get/set/reset on an unregistered key raise OptionError."""
        with self.assertRaisesRegex(config.OptionError, "No such option"):
            ps.get_option("unknown")

        with self.assertRaisesRegex(config.OptionError, "Available options"):
            ps.set_option("unknown", "value")

        with self.assertRaisesRegex(config.OptionError, "test.config"):
            ps.reset_option("unknown")

    def test_namespace_access(self):
        """Options can be read and written as attributes of ``ps.options``."""
        try:
            self.assertEqual(ps.options.compute.max_rows, ps.get_option("compute.max_rows"))
            ps.options.compute.max_rows = 0
            self.assertEqual(ps.options.compute.max_rows, 0)
            self.assertIsInstance(ps.options.compute, DictWrapper)

            wrapper = ps.options.compute
            self.assertEqual(wrapper.max_rows, ps.get_option("compute.max_rows"))
            wrapper.max_rows = 1000
            self.assertEqual(ps.options.compute.max_rows, 1000)

            # Reading a partial or misspelled name must fail.
            self.assertRaisesRegex(config.OptionError, "No such option", lambda: ps.options.compu)
            self.assertRaisesRegex(
                config.OptionError, "No such option", lambda: ps.options.compute.max
            )
            self.assertRaisesRegex(
                config.OptionError, "No such option", lambda: ps.options.max_rows1
            )

            # Assigning to a partial name or to a whole namespace must fail too.
            with self.assertRaisesRegex(config.OptionError, "No such option"):
                ps.options.compute.max = 0
            with self.assertRaisesRegex(config.OptionError, "No such option"):
                ps.options.compute = 0
            with self.assertRaisesRegex(config.OptionError, "No such option"):
                ps.options.com = 0
        finally:
            # This test mutates a real option, so always restore it.
            ps.reset_option("compute.max_rows")

    def test_dir_options(self):
        """``dir()`` lists full keys on ``ps.options`` and only own keys on a namespace."""
        self.assertIn("compute.default_index_type", dir(ps.options))
        self.assertIn("plotting.sample_ratio", dir(ps.options))

        self.assertIn("default_index_type", dir(ps.options.compute))
        self.assertNotIn("sample_ratio", dir(ps.options.compute))

        self.assertNotIn("default_index_type", dir(ps.options.plotting))
        self.assertIn("sample_ratio", dir(ps.options.plotting))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import unittest
    from pyspark.pandas.tests.test_config import *  # noqa: F401

    # xmlrunner is optional: without it, unittest falls back to its default runner.
    testRunner = None
    try:
        import xmlrunner  # type: ignore[import]

        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)
|
51
python/pyspark/pandas/tests/test_default_index.py
Normal file
51
python/pyspark/pandas/tests/test_default_index.py
Normal file
|
@ -0,0 +1,51 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from pyspark import pandas as ps
|
||||
from pyspark.pandas.testing.utils import ReusedSQLTestCase
|
||||
|
||||
|
||||
class DefaultIndexTest(ReusedSQLTestCase):
    """Checks the index attached to a Spark DataFrame for each default index type."""

    def test_default_index_sequence(self):
        """With "sequence", the frame equals a pandas frame with its default index."""
        expected = pd.DataFrame({"id": list(range(1000))})
        with ps.option_context("compute.default_index_type", "sequence"):
            spark_df = self.spark.range(1000)
            self.assert_eq(ps.DataFrame(spark_df), expected)

    def test_default_index_distributed_sequence(self):
        """With "distributed-sequence", the frame equals the same pandas frame."""
        expected = pd.DataFrame({"id": list(range(1000))})
        with ps.option_context("compute.default_index_type", "distributed-sequence"):
            spark_df = self.spark.range(1000)
            self.assert_eq(ps.DataFrame(spark_df), expected)

    def test_default_index_distributed(self):
        """With "distributed", only uniqueness of the index values is checked."""
        with ps.option_context("compute.default_index_type", "distributed"):
            spark_df = self.spark.range(1000)
            collected = ps.DataFrame(spark_df).to_pandas()
            distinct_index_values = set(collected.index)
            self.assertEqual(len(distinct_index_values), len(collected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import unittest
    from pyspark.pandas.tests.test_default_index import *  # noqa: F401

    # xmlrunner is optional: without it, unittest falls back to its default runner.
    testRunner = None
    try:
        import xmlrunner  # type: ignore[import]

        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)
|
151
python/pyspark/pandas/tests/test_extension.py
Normal file
151
python/pyspark/pandas/tests/test_extension.py
Normal file
|
@ -0,0 +1,151 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import contextlib
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from pyspark import pandas as ps
|
||||
from pyspark.pandas.testing.utils import assert_produces_warning, ReusedSQLTestCase
|
||||
from pyspark.pandas.extensions import (
|
||||
register_dataframe_accessor,
|
||||
register_series_accessor,
|
||||
register_index_accessor,
|
||||
)
|
||||
|
||||
|
||||
@contextlib.contextmanager
def ensure_removed(obj, attr):
    """
    Context manager that deletes ``attr`` from ``obj`` when the block exits.

    The deletion happens whether or not the block raised; an attribute that
    is already absent is silently ignored.
    """
    try:
        yield
    finally:
        with contextlib.suppress(AttributeError):
            delattr(obj, attr)
|
||||
|
||||
|
||||
class CustomAccessor:
    """Small accessor used by the registration tests; it wraps the object it is attached to."""

    def __init__(self, obj):
        self.obj = obj
        self.item = "item"

    @property
    def prop(self):
        """Return the stored item through a property."""
        return self.item

    def method(self):
        """Return the stored item through a regular method."""
        return self.item

    def check_length(self, col=None):
        """
        Return the length of ``self.obj[col]`` or of ``self.obj`` itself.

        The column lookup is used when the wrapped object is exactly a
        ``ps.DataFrame`` or when ``col`` is given. Otherwise ``len`` is taken
        of the wrapped object, and any failure is re-raised as ValueError.
        """
        select_column = type(self.obj) == ps.DataFrame or col is not None
        if select_column:
            return len(self.obj[col])
        try:
            return len(self.obj)
        except Exception as err:
            raise ValueError(str(err))
|
||||
|
||||
|
||||
class ExtensionTest(ReusedSQLTestCase):
    """Tests for registering custom accessors on DataFrame, Series and Index.

    Every accessor registered by a test is removed again before the test
    returns, so registrations do not leak into other tests.
    """

    @property
    def pdf(self):
        """A fresh pandas DataFrame with a random float index on every access."""
        return pd.DataFrame(
            {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [4, 5, 6, 3, 2, 1, 0, 0, 0]},
            index=np.random.rand(9),
        )

    @property
    def kdf(self):
        """A pandas-on-Spark DataFrame built from a fresh ``pdf``."""
        return ps.from_pandas(self.pdf)

    @property
    def accessor(self):
        """A ``CustomAccessor`` wrapping a fresh ``kdf``."""
        return CustomAccessor(self.kdf)

    def test_setup(self):
        self.assertEqual("item", self.accessor.item)

    def test_dataframe_register(self):
        with ensure_removed(ps.DataFrame, "test"):
            register_dataframe_accessor("test")(CustomAccessor)
            self.assertEqual(self.kdf.test.prop, "item")
            self.assertEqual(self.kdf.test.method(), "item")
            self.assertEqual(len(self.kdf["a"]), self.kdf.test.check_length("a"))

    def test_series_register(self):
        with ensure_removed(ps.Series, "test"):
            register_series_accessor("test")(CustomAccessor)
            self.assertEqual(self.kdf.a.test.prop, "item")
            self.assertEqual(self.kdf.a.test.method(), "item")
            self.assertEqual(self.kdf.a.test.check_length(), len(self.kdf["a"]))

    def test_index_register(self):
        with ensure_removed(ps.Index, "test"):
            register_index_accessor("test")(CustomAccessor)
            self.assertEqual(self.kdf.index.test.prop, "item")
            self.assertEqual(self.kdf.index.test.method(), "item")
            self.assertEqual(self.kdf.index.test.check_length(), self.kdf.index.size)

    def test_accessor_works(self):
        # Remove the accessor afterwards so "test" does not stay on ps.Series.
        with ensure_removed(ps.Series, "test"):
            register_series_accessor("test")(CustomAccessor)

            s = ps.Series([1, 2])
            self.assertIs(s.test.obj, s)
            self.assertEqual(s.test.prop, "item")
            self.assertEqual(s.test.method(), "item")

    def test_overwrite_warns(self):
        # Keep the real implementation so it can be restored afterwards.
        mean = ps.Series.mean
        try:
            with assert_produces_warning(UserWarning, raise_on_extra_warnings=False) as w:
                register_series_accessor("mean")(CustomAccessor)
                s = ps.Series([1, 2])
                self.assertEqual(s.mean.prop, "item")
            msg = str(w[0].message)
            self.assertIn("mean", msg)
            self.assertIn("CustomAccessor", msg)
            self.assertIn("Series", msg)
        finally:
            ps.Series.mean = mean

    def test_raises_attr_error(self):
        with ensure_removed(ps.Series, "bad"):

            # Register the accessor: without registration, ``.bad`` fails only
            # because the attribute does not exist and the accessor never runs.
            @register_series_accessor("bad")
            class Bad:
                def __init__(self, data):
                    raise AttributeError("whoops")

            with self.assertRaises(AttributeError):
                ps.Series([1, 2], dtype=object).bad
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import unittest
    from pyspark.pandas.tests.test_extension import *  # noqa: F401

    # xmlrunner is optional: without it, unittest falls back to its default runner.
    testRunner = None
    try:
        import xmlrunner  # type: ignore[import]

        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)
|
103
python/pyspark/pandas/tests/test_internal.py
Normal file
103
python/pyspark/pandas/tests/test_internal.py
Normal file
|
@ -0,0 +1,103 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from pyspark.pandas.internal import (
|
||||
InternalFrame,
|
||||
SPARK_DEFAULT_INDEX_NAME,
|
||||
SPARK_INDEX_NAME_FORMAT,
|
||||
)
|
||||
from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils
|
||||
|
||||
|
||||
class InternalFrameTest(ReusedSQLTestCase, SQLTestUtils):
    """Checks the metadata ``InternalFrame.from_pandas`` derives from a pandas frame."""

    def test_from_pandas(self):
        """Cover string columns, non-string columns, a multi-index and multi-index columns."""

        def verify(frame, index_columns, index_names, columns):
            # ``columns`` pairs each expected column label with its Spark column name.
            internal = InternalFrame.from_pandas(frame)
            sdf = internal.spark_frame

            self.assert_eq(internal.index_spark_column_names, index_columns)
            self.assert_eq(internal.index_names, index_names)
            self.assert_eq(internal.column_labels, [label for label, _ in columns])
            self.assert_eq(internal.data_spark_column_names, [name for _, name in columns])
            for label, name in columns:
                self.assertTrue(internal.spark_column_for(label)._jc.equals(sdf[name]._jc))

            self.assert_eq(internal.to_pandas_frame, frame)

        pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
        verify(
            pdf,
            index_columns=[SPARK_DEFAULT_INDEX_NAME],
            index_names=[None],
            columns=[(("a",), "a"), (("b",), "b")],
        )

        # non-string column name
        pdf1 = pd.DataFrame({0: [1, 2, 3], 1: [4, 5, 6]})
        verify(
            pdf1,
            index_columns=[SPARK_DEFAULT_INDEX_NAME],
            index_names=[None],
            columns=[((0,), "0"), ((1,), "1")],
        )

        # multi-index
        pdf.set_index("a", append=True, inplace=True)
        verify(
            pdf,
            index_columns=[SPARK_INDEX_NAME_FORMAT(0), SPARK_INDEX_NAME_FORMAT(1)],
            index_names=[None, ("a",)],
            columns=[(("b",), "b")],
        )

        # multi-index columns
        pdf.columns = pd.MultiIndex.from_tuples([("x", "b")])
        verify(
            pdf,
            index_columns=[SPARK_INDEX_NAME_FORMAT(0), SPARK_INDEX_NAME_FORMAT(1)],
            index_names=[None, ("a",)],
            columns=[(("x", "b"), "(x, b)")],
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import unittest
    from pyspark.pandas.tests.test_internal import *  # noqa: F401

    # xmlrunner is optional: without it, unittest falls back to its default runner.
    testRunner = None
    try:
        import xmlrunner  # type: ignore[import]

        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)
|
210
python/pyspark/pandas/tests/test_numpy_compat.py
Normal file
210
python/pyspark/pandas/tests/test_numpy_compat.py
Normal file
|
@ -0,0 +1,210 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from pyspark import pandas as ps
|
||||
from pyspark.pandas import set_option, reset_option
|
||||
from pyspark.pandas.numpy_compat import unary_np_spark_mappings, binary_np_spark_mappings
|
||||
from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils
|
||||
|
||||
|
||||
class NumPyCompatTest(ReusedSQLTestCase, SQLTestUtils):
    """Compares NumPy ufuncs applied to pandas-on-Spark objects with the pandas results."""

    # ufunc names that the compat tests skip.
    blacklist = [
        # Koalas does not currently support
        "conj",
        "conjugate",
        "isnat",
        "matmul",
        "frexp",
        # Values are close enough but tests failed.
        "arccos",
        "exp",
        "expm1",
        "log",  # flaky
        "log10",  # flaky
        "log1p",  # flaky
        "modf",
        "floor_divide",  # flaky
        # Results seem inconsistent in a different version of, I (Hyukjin) suspect, PyArrow.
        # From PyArrow 0.15, seems it returns the correct results via PySpark. Probably we
        # can enable it later when Koalas switches to PyArrow 0.15 completely.
        "left_shift",
    ]

    @property
    def pdf(self):
        """A fixed pandas frame whose index contains duplicate labels."""
        return pd.DataFrame(
            {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [4, 5, 6, 3, 2, 1, 0, 0, 0]},
            index=[0, 1, 3, 5, 6, 8, 9, 9, 9],
        )

    @property
    def kdf(self):
        """The pandas-on-Spark counterpart of ``pdf``."""
        return ps.from_pandas(self.pdf)

    def test_np_add_series(self):
        kdf = self.kdf
        pdf = self.pdf

        expected = np.add(pdf.a, pdf.b)
        if LooseVersion(pd.__version__) < LooseVersion("0.25"):
            # On these pandas versions the expected result is compared unnamed.
            expected = expected.rename()
        self.assert_eq(np.add(kdf.a, kdf.b), expected)

        kdf = self.kdf
        pdf = self.pdf
        self.assert_eq(np.add(kdf.a, 1), np.add(pdf.a, 1))

    def test_np_add_index(self):
        k_index = self.kdf.index
        p_index = self.pdf.index
        self.assert_eq(np.add(k_index, k_index), np.add(p_index, p_index))

    def test_np_unsupported_series(self):
        kdf = self.kdf
        with self.assertRaisesRegex(NotImplementedError, "Koalas.*not.*support.*sqrt.*"):
            np.sqrt(kdf.a, kdf.b)

    def test_np_unsupported_frame(self):
        kdf = self.kdf
        with self.assertRaisesRegex(NotImplementedError, "Koalas.*not.*support.*sqrt.*"):
            np.sqrt(kdf, kdf)

    def test_np_spark_compat_series(self):
        # Use randomly generated dataFrame
        pdf = pd.DataFrame(
            np.random.randint(-100, 100, size=(np.random.randint(100), 2)), columns=["a", "b"]
        )
        pdf2 = pd.DataFrame(
            np.random.randint(-100, 100, size=(len(pdf), len(pdf.columns))), columns=["a", "b"]
        )
        kdf = ps.from_pandas(pdf)
        kdf2 = ps.from_pandas(pdf2)
        old_pandas = LooseVersion(pd.__version__) < LooseVersion("0.25")

        for np_name in unary_np_spark_mappings:
            np_func = getattr(np, np_name)
            if np_name in self.blacklist:
                continue
            try:
                # unary ufunc
                self.assert_eq(np_func(pdf.a), np_func(kdf.a), almost=True)
            except Exception as e:
                raise AssertionError("Test in '%s' function was failed." % np_name) from e

        for np_name in binary_np_spark_mappings:
            np_func = getattr(np, np_name)
            if np_name in self.blacklist:
                continue
            try:
                # binary ufunc
                expected = np_func(pdf.a, pdf.b)
                if old_pandas:
                    expected = expected.rename()
                self.assert_eq(expected, np_func(kdf.a, kdf.b), almost=True)
                self.assert_eq(np_func(pdf.a, 1), np_func(kdf.a, 1), almost=True)
            except Exception as e:
                raise AssertionError("Test in '%s' function was failed." % np_name) from e

        # Test only top 5 for now. 'compute.ops_on_diff_frames' option increases too much time.
        try:
            set_option("compute.ops_on_diff_frames", True)
            for np_name in list(binary_np_spark_mappings)[:5]:
                np_func = getattr(np, np_name)
                if np_name in self.blacklist:
                    continue
                try:
                    # binary ufunc
                    expected = np_func(pdf.a, pdf2.b).sort_index()
                    if old_pandas:
                        expected = expected.rename()
                    self.assert_eq(
                        expected,
                        np_func(kdf.a, kdf2.b).sort_index(),
                        almost=True,
                    )
                except Exception as e:
                    raise AssertionError("Test in '%s' function was failed." % np_name) from e
        finally:
            reset_option("compute.ops_on_diff_frames")

    def test_np_spark_compat_frame(self):
        # Use randomly generated dataFrame
        pdf = pd.DataFrame(
            np.random.randint(-100, 100, size=(np.random.randint(100), 2)), columns=["a", "b"]
        )
        pdf2 = pd.DataFrame(
            np.random.randint(-100, 100, size=(len(pdf), len(pdf.columns))), columns=["a", "b"]
        )
        kdf = ps.from_pandas(pdf)
        kdf2 = ps.from_pandas(pdf2)

        for np_name in unary_np_spark_mappings:
            np_func = getattr(np, np_name)
            if np_name in self.blacklist:
                continue
            try:
                # unary ufunc
                self.assert_eq(np_func(pdf), np_func(kdf), almost=True)
            except Exception as e:
                raise AssertionError("Test in '%s' function was failed." % np_name) from e

        for np_name in binary_np_spark_mappings:
            np_func = getattr(np, np_name)
            if np_name in self.blacklist:
                continue
            try:
                # binary ufunc
                self.assert_eq(np_func(pdf, pdf), np_func(kdf, kdf), almost=True)
                self.assert_eq(np_func(pdf, 1), np_func(kdf, 1), almost=True)
            except Exception as e:
                raise AssertionError("Test in '%s' function was failed." % np_name) from e

        # Test only top 5 for now. 'compute.ops_on_diff_frames' option increases too much time.
        try:
            set_option("compute.ops_on_diff_frames", True)
            for np_name in list(binary_np_spark_mappings)[:5]:
                np_func = getattr(np, np_name)
                if np_name in self.blacklist:
                    continue
                try:
                    # binary ufunc
                    self.assert_eq(
                        np_func(pdf, pdf2).sort_index(),
                        np_func(kdf, kdf2).sort_index(),
                        almost=True,
                    )
                except Exception as e:
                    raise AssertionError("Test in '%s' function was failed." % np_name) from e
        finally:
            reset_option("compute.ops_on_diff_frames")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import unittest
    from pyspark.pandas.tests.test_numpy_compat import *  # noqa: F401

    # xmlrunner is optional: without it, unittest falls back to its default runner.
    testRunner = None
    try:
        import xmlrunner  # type: ignore[import]

        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)
|
437
python/pyspark/pandas/tests/test_typedef.py
Normal file
437
python/pyspark/pandas/tests/test_typedef.py
Normal file
|
@ -0,0 +1,437 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
import datetime
|
||||
import decimal
|
||||
from typing import List
|
||||
|
||||
import pandas
|
||||
import pandas as pd
|
||||
from pandas.api.types import CategoricalDtype
|
||||
import numpy as np
|
||||
from pyspark.sql.types import (
|
||||
ArrayType,
|
||||
BinaryType,
|
||||
BooleanType,
|
||||
FloatType,
|
||||
IntegerType,
|
||||
LongType,
|
||||
StringType,
|
||||
StructField,
|
||||
StructType,
|
||||
ByteType,
|
||||
ShortType,
|
||||
DateType,
|
||||
DecimalType,
|
||||
DoubleType,
|
||||
TimestampType,
|
||||
)
|
||||
|
||||
from pyspark.pandas.typedef import (
|
||||
as_spark_type,
|
||||
extension_dtypes_available,
|
||||
extension_float_dtypes_available,
|
||||
extension_object_dtypes_available,
|
||||
infer_return_type,
|
||||
koalas_dtype,
|
||||
)
|
||||
from pyspark import pandas as ps
|
||||
|
||||
|
||||
class TypeHintTests(unittest.TestCase):
|
||||
@unittest.skipIf(
    sys.version_info < (3, 7),
    "Type inference from pandas instances is supported with Python 3.7+",
)
def test_infer_schema_from_pandas_instances(self):
    """Return annotations written with pandas classes yield the expected dtypes and Spark types.

    Covers ``pd.Series[...]`` and ``pd.DataFrame[...]`` hints, given either
    as real objects or as strings, and hints built from existing dtypes.
    The builtin ``float`` and ``np.str_`` are used instead of the aliases
    ``np.float`` and ``np.unicode_``, which newer NumPy releases no longer
    provide.
    """

    def func() -> pd.Series[int]:
        pass

    inferred = infer_return_type(func)
    self.assertEqual(inferred.dtype, np.int64)
    self.assertEqual(inferred.spark_type, LongType())

    def func() -> pd.Series[float]:
        pass

    inferred = infer_return_type(func)
    self.assertEqual(inferred.dtype, np.float64)
    self.assertEqual(inferred.spark_type, DoubleType())

    def func() -> "pd.DataFrame[float, str]":
        pass

    expected = StructType([StructField("c0", DoubleType()), StructField("c1", StringType())])
    inferred = infer_return_type(func)
    self.assertEqual(inferred.dtypes, [np.float64, np.str_])
    self.assertEqual(inferred.spark_type, expected)

    def func() -> "pandas.DataFrame[float]":
        pass

    expected = StructType([StructField("c0", DoubleType())])
    inferred = infer_return_type(func)
    self.assertEqual(inferred.dtypes, [np.float64])
    self.assertEqual(inferred.spark_type, expected)

    def func() -> "pd.Series[int]":
        pass

    inferred = infer_return_type(func)
    self.assertEqual(inferred.dtype, np.int64)
    self.assertEqual(inferred.spark_type, LongType())

    def func() -> pd.DataFrame[float, str]:
        pass

    expected = StructType([StructField("c0", DoubleType()), StructField("c1", StringType())])
    inferred = infer_return_type(func)
    self.assertEqual(inferred.dtypes, [np.float64, np.str_])
    self.assertEqual(inferred.spark_type, expected)

    def func() -> pd.DataFrame[float]:
        pass

    expected = StructType([StructField("c0", DoubleType())])
    inferred = infer_return_type(func)
    self.assertEqual(inferred.dtypes, [np.float64])
    self.assertEqual(inferred.spark_type, expected)

    # Hints built from the dtypes of an existing pandas frame.
    pdf = pd.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]})

    def func() -> pd.DataFrame[pdf.dtypes]:  # type: ignore
        pass

    expected = StructType([StructField("c0", LongType()), StructField("c1", LongType())])
    inferred = infer_return_type(func)
    self.assertEqual(inferred.dtypes, [np.int64, np.int64])
    self.assertEqual(inferred.spark_type, expected)

    # Same, with a categorical column.
    pdf = pd.DataFrame({"a": [1, 2, 3], "b": pd.Categorical(["a", "b", "c"])})

    def func() -> pd.Series[pdf.b.dtype]:  # type: ignore
        pass

    inferred = infer_return_type(func)
    self.assertEqual(inferred.dtype, CategoricalDtype(categories=["a", "b", "c"]))
    self.assertEqual(inferred.spark_type, LongType())

    def func() -> pd.DataFrame[pdf.dtypes]:  # type: ignore
        pass

    expected = StructType([StructField("c0", LongType()), StructField("c1", LongType())])
    inferred = infer_return_type(func)
    self.assertEqual(inferred.dtypes, [np.int64, CategoricalDtype(categories=["a", "b", "c"])])
    self.assertEqual(inferred.spark_type, expected)
|
||||
|
||||
def test_if_pandas_implements_class_getitem(self):
|
||||
# the current type hint implementation of pandas DataFrame assumes pandas doesn't
|
||||
# implement '__class_getitem__'. This test case is to make sure pandas
|
||||
# doesn't implement them.
|
||||
assert not ps._frame_has_class_getitem
|
||||
assert not ps._series_has_class_getitem
|
||||
|
||||
@unittest.skipIf(
|
||||
sys.version_info < (3, 7),
|
||||
"Type inference from pandas instances is supported with Python 3.7+",
|
||||
)
|
||||
def test_infer_schema_with_names_pandas_instances(self):
|
||||
def func() -> 'pd.DataFrame["a" : np.float, "b":str]': # noqa: F405
|
||||
pass
|
||||
|
||||
expected = StructType([StructField("a", DoubleType()), StructField("b", StringType())])
|
||||
inferred = infer_return_type(func)
|
||||
self.assertEqual(inferred.dtypes, [np.float64, np.unicode_])
|
||||
self.assertEqual(inferred.spark_type, expected)
|
||||
|
||||
def func() -> "pd.DataFrame['a': np.float, 'b': int]": # noqa: F405
|
||||
pass
|
||||
|
||||
expected = StructType([StructField("a", DoubleType()), StructField("b", LongType())])
|
||||
inferred = infer_return_type(func)
|
||||
self.assertEqual(inferred.dtypes, [np.float64, np.int64])
|
||||
self.assertEqual(inferred.spark_type, expected)
|
||||
|
||||
pdf = pd.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]})
|
||||
|
||||
def func() -> pd.DataFrame[zip(pdf.columns, pdf.dtypes)]:
|
||||
pass
|
||||
|
||||
expected = StructType([StructField("a", LongType()), StructField("b", LongType())])
|
||||
inferred = infer_return_type(func)
|
||||
self.assertEqual(inferred.dtypes, [np.int64, np.int64])
|
||||
self.assertEqual(inferred.spark_type, expected)
|
||||
|
||||
pdf = pd.DataFrame({("x", "a"): [1, 2, 3], ("y", "b"): [3, 4, 5]})
|
||||
|
||||
def func() -> pd.DataFrame[zip(pdf.columns, pdf.dtypes)]:
|
||||
pass
|
||||
|
||||
expected = StructType(
|
||||
[StructField("(x, a)", LongType()), StructField("(y, b)", LongType())]
|
||||
)
|
||||
inferred = infer_return_type(func)
|
||||
self.assertEqual(inferred.dtypes, [np.int64, np.int64])
|
||||
self.assertEqual(inferred.spark_type, expected)
|
||||
|
||||
pdf = pd.DataFrame({"a": [1, 2, 3], "b": pd.Categorical(["a", "b", "c"])})
|
||||
|
||||
def func() -> pd.DataFrame[zip(pdf.columns, pdf.dtypes)]:
|
||||
pass
|
||||
|
||||
expected = StructType([StructField("a", LongType()), StructField("b", LongType())])
|
||||
inferred = infer_return_type(func)
|
||||
self.assertEqual(inferred.dtypes, [np.int64, CategoricalDtype(categories=["a", "b", "c"])])
|
||||
self.assertEqual(inferred.spark_type, expected)
|
||||
|
||||
@unittest.skipIf(
|
||||
sys.version_info < (3, 7),
|
||||
"Type inference from pandas instances is supported with Python 3.7+",
|
||||
)
|
||||
def test_infer_schema_with_names_pandas_instances_negative(self):
|
||||
def try_infer_return_type():
|
||||
def f() -> 'pd.DataFrame["a" : np.float : 1, "b":str:2]': # noqa: F405
|
||||
pass
|
||||
|
||||
infer_return_type(f)
|
||||
|
||||
self.assertRaisesRegex(TypeError, "Type hints should be specified", try_infer_return_type)
|
||||
|
||||
class A:
|
||||
pass
|
||||
|
||||
def try_infer_return_type():
|
||||
def f() -> pd.DataFrame[A]:
|
||||
pass
|
||||
|
||||
infer_return_type(f)
|
||||
|
||||
self.assertRaisesRegex(TypeError, "not understood", try_infer_return_type)
|
||||
|
||||
def try_infer_return_type():
|
||||
def f() -> 'pd.DataFrame["a" : np.float : 1, "b":str:2]': # noqa: F405
|
||||
pass
|
||||
|
||||
infer_return_type(f)
|
||||
|
||||
self.assertRaisesRegex(TypeError, "Type hints should be specified", try_infer_return_type)
|
||||
|
||||
# object type
|
||||
pdf = pd.DataFrame({"a": ["a", 2, None]})
|
||||
|
||||
def try_infer_return_type():
|
||||
def f() -> pd.DataFrame[pdf.dtypes]: # type: ignore
|
||||
pass
|
||||
|
||||
infer_return_type(f)
|
||||
|
||||
self.assertRaisesRegex(TypeError, "object.*not understood", try_infer_return_type)
|
||||
|
||||
def try_infer_return_type():
|
||||
def f() -> pd.Series[pdf.a.dtype]: # type: ignore
|
||||
pass
|
||||
|
||||
infer_return_type(f)
|
||||
|
||||
self.assertRaisesRegex(TypeError, "object.*not understood", try_infer_return_type)
|
||||
|
||||
def test_infer_schema_with_names_negative(self):
|
||||
def try_infer_return_type():
|
||||
def f() -> 'ps.DataFrame["a" : np.float : 1, "b":str:2]': # noqa: F405
|
||||
pass
|
||||
|
||||
infer_return_type(f)
|
||||
|
||||
self.assertRaisesRegex(TypeError, "Type hints should be specified", try_infer_return_type)
|
||||
|
||||
class A:
|
||||
pass
|
||||
|
||||
def try_infer_return_type():
|
||||
def f() -> ps.DataFrame[A]:
|
||||
pass
|
||||
|
||||
infer_return_type(f)
|
||||
|
||||
self.assertRaisesRegex(TypeError, "not understood", try_infer_return_type)
|
||||
|
||||
def try_infer_return_type():
|
||||
def f() -> 'ps.DataFrame["a" : np.float : 1, "b":str:2]': # noqa: F405
|
||||
pass
|
||||
|
||||
infer_return_type(f)
|
||||
|
||||
self.assertRaisesRegex(TypeError, "Type hints should be specified", try_infer_return_type)
|
||||
|
||||
# object type
|
||||
pdf = pd.DataFrame({"a": ["a", 2, None]})
|
||||
|
||||
def try_infer_return_type():
|
||||
def f() -> ps.DataFrame[pdf.dtypes]: # type: ignore
|
||||
pass
|
||||
|
||||
infer_return_type(f)
|
||||
|
||||
self.assertRaisesRegex(TypeError, "object.*not understood", try_infer_return_type)
|
||||
|
||||
def try_infer_return_type():
|
||||
def f() -> ps.Series[pdf.a.dtype]: # type: ignore
|
||||
pass
|
||||
|
||||
infer_return_type(f)
|
||||
|
||||
self.assertRaisesRegex(TypeError, "object.*not understood", try_infer_return_type)
|
||||
|
||||
def test_as_spark_type_koalas_dtype(self):
|
||||
type_mapper = {
|
||||
# binary
|
||||
np.character: (np.character, BinaryType()),
|
||||
np.bytes_: (np.bytes_, BinaryType()),
|
||||
np.string_: (np.bytes_, BinaryType()),
|
||||
bytes: (np.bytes_, BinaryType()),
|
||||
# integer
|
||||
np.int8: (np.int8, ByteType()),
|
||||
np.byte: (np.int8, ByteType()),
|
||||
np.int16: (np.int16, ShortType()),
|
||||
np.int32: (np.int32, IntegerType()),
|
||||
np.int64: (np.int64, LongType()),
|
||||
np.int: (np.int64, LongType()),
|
||||
int: (np.int64, LongType()),
|
||||
# floating
|
||||
np.float32: (np.float32, FloatType()),
|
||||
np.float: (np.float64, DoubleType()),
|
||||
np.float64: (np.float64, DoubleType()),
|
||||
float: (np.float64, DoubleType()),
|
||||
# string
|
||||
np.str: (np.unicode_, StringType()),
|
||||
np.unicode_: (np.unicode_, StringType()),
|
||||
str: (np.unicode_, StringType()),
|
||||
# bool
|
||||
np.bool: (np.bool, BooleanType()),
|
||||
bool: (np.bool, BooleanType()),
|
||||
# datetime
|
||||
np.datetime64: (np.datetime64, TimestampType()),
|
||||
datetime.datetime: (np.dtype("datetime64[ns]"), TimestampType()),
|
||||
# DateType
|
||||
datetime.date: (np.dtype("object"), DateType()),
|
||||
# DecimalType
|
||||
decimal.Decimal: (np.dtype("object"), DecimalType(38, 18)),
|
||||
# ArrayType
|
||||
np.ndarray: (np.dtype("object"), ArrayType(StringType())),
|
||||
List[bytes]: (np.dtype("object"), ArrayType(BinaryType())),
|
||||
List[np.character]: (np.dtype("object"), ArrayType(BinaryType())),
|
||||
List[np.bytes_]: (np.dtype("object"), ArrayType(BinaryType())),
|
||||
List[np.string_]: (np.dtype("object"), ArrayType(BinaryType())),
|
||||
List[bool]: (np.dtype("object"), ArrayType(BooleanType())),
|
||||
List[np.bool]: (np.dtype("object"), ArrayType(BooleanType())),
|
||||
List[datetime.date]: (np.dtype("object"), ArrayType(DateType())),
|
||||
List[np.int8]: (np.dtype("object"), ArrayType(ByteType())),
|
||||
List[np.byte]: (np.dtype("object"), ArrayType(ByteType())),
|
||||
List[decimal.Decimal]: (np.dtype("object"), ArrayType(DecimalType(38, 18))),
|
||||
List[float]: (np.dtype("object"), ArrayType(DoubleType())),
|
||||
List[np.float]: (np.dtype("object"), ArrayType(DoubleType())),
|
||||
List[np.float64]: (np.dtype("object"), ArrayType(DoubleType())),
|
||||
List[np.float32]: (np.dtype("object"), ArrayType(FloatType())),
|
||||
List[np.int32]: (np.dtype("object"), ArrayType(IntegerType())),
|
||||
List[int]: (np.dtype("object"), ArrayType(LongType())),
|
||||
List[np.int]: (np.dtype("object"), ArrayType(LongType())),
|
||||
List[np.int64]: (np.dtype("object"), ArrayType(LongType())),
|
||||
List[np.int16]: (np.dtype("object"), ArrayType(ShortType())),
|
||||
List[str]: (np.dtype("object"), ArrayType(StringType())),
|
||||
List[np.unicode_]: (np.dtype("object"), ArrayType(StringType())),
|
||||
List[datetime.datetime]: (np.dtype("object"), ArrayType(TimestampType())),
|
||||
List[np.datetime64]: (np.dtype("object"), ArrayType(TimestampType())),
|
||||
# CategoricalDtype
|
||||
CategoricalDtype(categories=["a", "b", "c"]): (
|
||||
CategoricalDtype(categories=["a", "b", "c"]),
|
||||
LongType(),
|
||||
),
|
||||
}
|
||||
|
||||
for numpy_or_python_type, (dtype, spark_type) in type_mapper.items():
|
||||
self.assertEqual(as_spark_type(numpy_or_python_type), spark_type)
|
||||
self.assertEqual(koalas_dtype(numpy_or_python_type), (dtype, spark_type))
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "Type uint64 was not understood."):
|
||||
as_spark_type(np.dtype("uint64"))
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "Type object was not understood."):
|
||||
as_spark_type(np.dtype("object"))
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "Type uint64 was not understood."):
|
||||
koalas_dtype(np.dtype("uint64"))
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "Type object was not understood."):
|
||||
koalas_dtype(np.dtype("object"))
|
||||
|
||||
@unittest.skipIf(not extension_dtypes_available, "The pandas extension types are not available")
|
||||
def test_as_spark_type_extension_dtypes(self):
|
||||
from pandas import Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype
|
||||
|
||||
type_mapper = {
|
||||
Int8Dtype(): ByteType(),
|
||||
Int16Dtype(): ShortType(),
|
||||
Int32Dtype(): IntegerType(),
|
||||
Int64Dtype(): LongType(),
|
||||
}
|
||||
|
||||
for extension_dtype, spark_type in type_mapper.items():
|
||||
self.assertEqual(as_spark_type(extension_dtype), spark_type)
|
||||
self.assertEqual(koalas_dtype(extension_dtype), (extension_dtype, spark_type))
|
||||
|
||||
@unittest.skipIf(
|
||||
not extension_object_dtypes_available, "The pandas extension object types are not available"
|
||||
)
|
||||
def test_as_spark_type_extension_object_dtypes(self):
|
||||
from pandas import BooleanDtype, StringDtype
|
||||
|
||||
type_mapper = {
|
||||
BooleanDtype(): BooleanType(),
|
||||
StringDtype(): StringType(),
|
||||
}
|
||||
|
||||
for extension_dtype, spark_type in type_mapper.items():
|
||||
self.assertEqual(as_spark_type(extension_dtype), spark_type)
|
||||
self.assertEqual(koalas_dtype(extension_dtype), (extension_dtype, spark_type))
|
||||
|
||||
@unittest.skipIf(
|
||||
not extension_float_dtypes_available, "The pandas extension float types are not available"
|
||||
)
|
||||
def test_as_spark_type_extension_float_dtypes(self):
|
||||
from pandas import Float32Dtype, Float64Dtype
|
||||
|
||||
type_mapper = {
|
||||
Float32Dtype(): FloatType(),
|
||||
Float64Dtype(): DoubleType(),
|
||||
}
|
||||
|
||||
for extension_dtype, spark_type in type_mapper.items():
|
||||
self.assertEqual(as_spark_type(extension_dtype), spark_type)
|
||||
self.assertEqual(koalas_dtype(extension_dtype), (extension_dtype, spark_type))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    from pyspark.pandas.tests.test_typedef import *  # noqa: F401

    # Fall back to unittest's default runner unless xmlrunner can be set up.
    testRunner = None
    try:
        import xmlrunner  # type: ignore[import]

        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        # xmlrunner is optional; keep the default runner.
        pass
    unittest.main(testRunner=testRunner, verbosity=2)
|
104
python/pyspark/pandas/tests/test_utils.py
Normal file
104
python/pyspark/pandas/tests/test_utils.py
Normal file
|
@ -0,0 +1,104 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from pyspark.pandas.testing.utils import ReusedSQLTestCase, SQLTestUtils
|
||||
from pyspark.pandas.utils import (
|
||||
lazy_property,
|
||||
validate_arguments_and_invoke_function,
|
||||
validate_bool_kwarg,
|
||||
)
|
||||
|
||||
# NOTE(review): this module-level value is not referenced anywhere in the visible part
# of this module — confirm whether it is still needed.
some_global_variable = 0
|
||||
|
||||
|
||||
class UtilsTest(ReusedSQLTestCase, SQLTestUtils):
    """Tests for ``validate_arguments_and_invoke_function``, ``lazy_property`` and
    ``validate_bool_kwarg``.
    """

    # a dummy to_html version with an extra parameter that pandas does not support
    # used in test_validate_arguments_and_invoke_function
    def to_html(self, max_rows=None, unsupported_param=None):
        # ``locals()`` must be captured first so it holds only the call arguments.
        args = locals()

        pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=[0, 1, 3])
        validate_arguments_and_invoke_function(pdf, self.to_html, pd.DataFrame.to_html, args)

    # a dummy to_clipboard version that accepts **kwargs
    # NOTE(review): no visible test invokes this helper. It previously ended with an
    # unconditional call to itself, which would recurse without bound as soon as it
    # was used; that self-call has been removed.
    def to_clipboard(self, sep=",", **kwargs):
        # ``locals()`` must be captured first so it holds only the call arguments.
        args = locals()

        pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=[0, 1, 3])
        validate_arguments_and_invoke_function(
            pdf, self.to_clipboard, pd.DataFrame.to_clipboard, args
        )

    def test_validate_arguments_and_invoke_function(self):
        """Default-valued unsupported params pass; a non-default one raises ``TypeError``."""
        # This should pass and run fine
        self.to_html()
        self.to_html(unsupported_param=None)
        self.to_html(max_rows=5)

        # This should fail because we are explicitly setting an unsupported param
        # to a non-default value
        with self.assertRaises(TypeError):
            self.to_html(unsupported_param=1)

    def test_lazy_property(self):
        """Reading a ``lazy_property`` twice must yield the same value."""
        obj = TestClassForLazyProp()
        # If lazy prop is not working, the second test would fail (because it'd be 2)
        self.assert_eq(obj.lazy_prop, 1)
        self.assert_eq(obj.lazy_prop, 1)

    def test_validate_bool_kwarg(self):
        """``True``/``False``/``None`` are returned unchanged; a string raises ``ValueError``."""
        # This should pass and run fine
        for koalas in (True, False, None):
            self.assert_eq(validate_bool_kwarg(koalas, "koalas"), koalas)

        # This should fail because we are explicitly setting a non-boolean value
        koalas = "true"
        with self.assertRaisesRegex(
            ValueError, 'For argument "koalas" expected type bool, received type str.'
        ):
            validate_bool_kwarg(koalas, "koalas")
|
||||
|
||||
|
||||
class TestClassForLazyProp:
    """Fixture whose ``lazy_prop`` body bumps a counter every time it actually runs."""

    def __init__(self):
        # Number of times the body of ``lazy_prop`` has executed so far.
        self.some_variable = 0

    @lazy_property
    def lazy_prop(self):
        """Increment the counter and return its new value."""
        bumped = self.some_variable + 1
        self.some_variable = bumped
        return bumped
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import unittest
    from pyspark.pandas.tests.test_utils import *  # noqa: F401

    # Fall back to unittest's default runner unless xmlrunner can be set up.
    testRunner = None
    try:
        import xmlrunner  # type: ignore[import]

        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        # xmlrunner is optional; keep the default runner.
        pass
    unittest.main(testRunner=testRunner, verbosity=2)
|
Loading…
Reference in a new issue