[SPARK-33613][PYTHON][TESTS] Replace deprecated APIs in pyspark tests
### What changes were proposed in this pull request?

This replaces deprecated API usage in PySpark tests with the preferred APIs (e.g. `assertRaisesRegexp` → `assertRaisesRegex`, `assertEquals` → `assertEqual`, and `pandas.util.testing` → `pandas.testing`). These APIs have been deprecated for some time, and their usage is not consistent within the tests.

- https://docs.python.org/3/library/unittest.html#deprecated-aliases

### Why are the changes needed?

For consistency and eventual removal of deprecated APIs.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing tests

Closes #30557 from BryanCutler/replace-deprecated-apis-in-tests.

Authored-by: Bryan Cutler <cutlerb@gmail.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
This commit is contained in:
parent
596fbc1d29
commit
aeb3649fb9
|
@ -169,7 +169,7 @@ class FeatureTests(SparkSessionTestCase):
|
|||
|
||||
# Test an empty vocabulary
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"):
|
||||
with self.assertRaisesRegex(Exception, "vocabSize.*invalid.*0"):
|
||||
CountVectorizerModel.from_vocabulary([], inputCol="words")
|
||||
|
||||
# Test model with default settings can transform
|
||||
|
|
|
@ -47,19 +47,19 @@ class ImageFileFormatTest(SparkSessionTestCase):
|
|||
self.assertEqual(ImageSchema.undefinedImageType, "Undefined")
|
||||
|
||||
with QuietTest(self.sc):
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"image argument should be pyspark.sql.types.Row; however",
|
||||
lambda: ImageSchema.toNDArray("a"))
|
||||
|
||||
with QuietTest(self.sc):
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"image argument should have attributes specified in",
|
||||
lambda: ImageSchema.toNDArray(Row(a=1)))
|
||||
|
||||
with QuietTest(self.sc):
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"array argument should be numpy.ndarray; however, it got",
|
||||
lambda: ImageSchema.toImage("a"))
|
||||
|
|
|
@ -308,7 +308,7 @@ class ParamTests(SparkSessionTestCase):
|
|||
LogisticRegression
|
||||
)
|
||||
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Logistic Regression getThreshold found inconsistent.*$",
|
||||
LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5]
|
||||
|
|
|
@ -442,7 +442,7 @@ class PersistenceTest(SparkSessionTestCase):
|
|||
del metadata['defaultParamMap']
|
||||
metadataStr = json.dumps(metadata, separators=[',', ':'])
|
||||
loadedMetadata = reader._parseMetaData(metadataStr, )
|
||||
with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"):
|
||||
with self.assertRaisesRegex(AssertionError, "`defaultParamMap` section not found"):
|
||||
reader.getAndSetParams(lr, loadedMetadata)
|
||||
|
||||
# Prior to 2.4.0, metadata doesn't have `defaultParamMap`.
|
||||
|
|
|
@ -499,7 +499,7 @@ class CrossValidatorTests(SparkSessionTestCase):
|
|||
evaluator=evaluator,
|
||||
numFolds=2,
|
||||
foldCol="fold")
|
||||
with self.assertRaisesRegexp(Exception, "Fold number must be in range"):
|
||||
with self.assertRaisesRegex(Exception, "Fold number must be in range"):
|
||||
cv.fit(dataset_with_folds)
|
||||
|
||||
cv = CrossValidator(estimator=lr,
|
||||
|
@ -507,7 +507,7 @@ class CrossValidatorTests(SparkSessionTestCase):
|
|||
evaluator=evaluator,
|
||||
numFolds=4,
|
||||
foldCol="fold")
|
||||
with self.assertRaisesRegexp(Exception, "The validation data at fold 3 is empty"):
|
||||
with self.assertRaisesRegex(Exception, "The validation data at fold 3 is empty"):
|
||||
cv.fit(dataset_with_folds)
|
||||
|
||||
|
||||
|
|
|
@ -54,7 +54,7 @@ class JavaWrapperMemoryTests(SparkSessionTestCase):
|
|||
model.__del__()
|
||||
|
||||
def condition():
|
||||
with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
|
||||
with self.assertRaisesRegex(py4j.protocol.Py4JError, error_no_object):
|
||||
model._java_obj.toString()
|
||||
self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString())
|
||||
return True
|
||||
|
@ -67,9 +67,9 @@ class JavaWrapperMemoryTests(SparkSessionTestCase):
|
|||
pass
|
||||
|
||||
def condition():
|
||||
with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
|
||||
with self.assertRaisesRegex(py4j.protocol.Py4JError, error_no_object):
|
||||
model._java_obj.toString()
|
||||
with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
|
||||
with self.assertRaisesRegex(py4j.protocol.Py4JError, error_no_object):
|
||||
summary._java_obj.toString()
|
||||
return True
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ from pyspark.testing.utils import QuietTest
|
|||
|
||||
if have_pandas:
|
||||
import pandas as pd
|
||||
from pandas.util.testing import assert_frame_equal
|
||||
from pandas.testing import assert_frame_equal
|
||||
|
||||
if have_pyarrow:
|
||||
import pyarrow as pa # noqa: F401
|
||||
|
@ -137,7 +137,7 @@ class ArrowTests(ReusedSQLTestCase):
|
|||
df = self.spark.createDataFrame([(None,)], schema=schema)
|
||||
with QuietTest(self.sc):
|
||||
with self.warnings_lock:
|
||||
with self.assertRaisesRegexp(Exception, 'Unsupported type'):
|
||||
with self.assertRaisesRegex(Exception, 'Unsupported type'):
|
||||
df.toPandas()
|
||||
|
||||
def test_null_conversion(self):
|
||||
|
@ -214,7 +214,7 @@ class ArrowTests(ReusedSQLTestCase):
|
|||
exception_udf = udf(raise_exception, IntegerType())
|
||||
df = df.withColumn("error", exception_udf())
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(Exception, 'My error'):
|
||||
with self.assertRaisesRegex(Exception, 'My error'):
|
||||
df.toPandas()
|
||||
|
||||
def _createDataFrame_toggle(self, pdf, schema=None):
|
||||
|
@ -228,7 +228,7 @@ class ArrowTests(ReusedSQLTestCase):
|
|||
def test_createDataFrame_toggle(self):
|
||||
pdf = self.create_pandas_data_frame()
|
||||
df_no_arrow, df_arrow = self._createDataFrame_toggle(pdf, schema=self.schema)
|
||||
self.assertEquals(df_no_arrow.collect(), df_arrow.collect())
|
||||
self.assertEqual(df_no_arrow.collect(), df_arrow.collect())
|
||||
|
||||
def test_createDataFrame_respect_session_timezone(self):
|
||||
from datetime import timedelta
|
||||
|
@ -258,7 +258,7 @@ class ArrowTests(ReusedSQLTestCase):
|
|||
def test_createDataFrame_with_schema(self):
|
||||
pdf = self.create_pandas_data_frame()
|
||||
df = self.spark.createDataFrame(pdf, schema=self.schema)
|
||||
self.assertEquals(self.schema, df.schema)
|
||||
self.assertEqual(self.schema, df.schema)
|
||||
pdf_arrow = df.toPandas()
|
||||
assert_frame_equal(pdf_arrow, pdf)
|
||||
|
||||
|
@ -269,7 +269,7 @@ class ArrowTests(ReusedSQLTestCase):
|
|||
wrong_schema = StructType(fields)
|
||||
with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}):
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(Exception, "[D|d]ecimal.*got.*date"):
|
||||
with self.assertRaisesRegex(Exception, "[D|d]ecimal.*got.*date"):
|
||||
self.spark.createDataFrame(pdf, schema=wrong_schema)
|
||||
|
||||
def test_createDataFrame_with_names(self):
|
||||
|
@ -277,23 +277,23 @@ class ArrowTests(ReusedSQLTestCase):
|
|||
new_names = list(map(str, range(len(self.schema.fieldNames()))))
|
||||
# Test that schema as a list of column names gets applied
|
||||
df = self.spark.createDataFrame(pdf, schema=list(new_names))
|
||||
self.assertEquals(df.schema.fieldNames(), new_names)
|
||||
self.assertEqual(df.schema.fieldNames(), new_names)
|
||||
# Test that schema as tuple of column names gets applied
|
||||
df = self.spark.createDataFrame(pdf, schema=tuple(new_names))
|
||||
self.assertEquals(df.schema.fieldNames(), new_names)
|
||||
self.assertEqual(df.schema.fieldNames(), new_names)
|
||||
|
||||
def test_createDataFrame_column_name_encoding(self):
|
||||
pdf = pd.DataFrame({u'a': [1]})
|
||||
columns = self.spark.createDataFrame(pdf).columns
|
||||
self.assertTrue(isinstance(columns[0], str))
|
||||
self.assertEquals(columns[0], 'a')
|
||||
self.assertEqual(columns[0], 'a')
|
||||
columns = self.spark.createDataFrame(pdf, [u'b']).columns
|
||||
self.assertTrue(isinstance(columns[0], str))
|
||||
self.assertEquals(columns[0], 'b')
|
||||
self.assertEqual(columns[0], 'b')
|
||||
|
||||
def test_createDataFrame_with_single_data_type(self):
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(ValueError, ".*IntegerType.*not supported.*"):
|
||||
with self.assertRaisesRegex(ValueError, ".*IntegerType.*not supported.*"):
|
||||
self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")
|
||||
|
||||
def test_createDataFrame_does_not_modify_input(self):
|
||||
|
@ -311,7 +311,7 @@ class ArrowTests(ReusedSQLTestCase):
|
|||
from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema
|
||||
arrow_schema = to_arrow_schema(self.schema)
|
||||
schema_rt = from_arrow_schema(arrow_schema)
|
||||
self.assertEquals(self.schema, schema_rt)
|
||||
self.assertEqual(self.schema, schema_rt)
|
||||
|
||||
def test_createDataFrame_with_array_type(self):
|
||||
pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]})
|
||||
|
@ -420,7 +420,7 @@ class ArrowTests(ReusedSQLTestCase):
|
|||
|
||||
def test_createDataFrame_fallback_disabled(self):
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(TypeError, 'Unsupported type'):
|
||||
with self.assertRaisesRegex(TypeError, 'Unsupported type'):
|
||||
self.spark.createDataFrame(
|
||||
pd.DataFrame({"a": [[datetime.datetime(2015, 11, 1, 0, 30)]]}),
|
||||
"a: array<timestamp>")
|
||||
|
@ -545,7 +545,7 @@ class MaxResultArrowTests(unittest.TestCase):
|
|||
cls.spark.stop()
|
||||
|
||||
def test_exception_by_max_results(self):
|
||||
with self.assertRaisesRegexp(Exception, "is bigger than"):
|
||||
with self.assertRaisesRegex(Exception, "is bigger than"):
|
||||
self.spark.range(0, 10000, 1, 100).toPandas()
|
||||
|
||||
|
||||
|
|
|
@ -25,11 +25,11 @@ class CatalogTests(ReusedSQLTestCase):
|
|||
def test_current_database(self):
|
||||
spark = self.spark
|
||||
with self.database("some_db"):
|
||||
self.assertEquals(spark.catalog.currentDatabase(), "default")
|
||||
self.assertEqual(spark.catalog.currentDatabase(), "default")
|
||||
spark.sql("CREATE DATABASE some_db")
|
||||
spark.catalog.setCurrentDatabase("some_db")
|
||||
self.assertEquals(spark.catalog.currentDatabase(), "some_db")
|
||||
self.assertRaisesRegexp(
|
||||
self.assertEqual(spark.catalog.currentDatabase(), "some_db")
|
||||
self.assertRaisesRegex(
|
||||
AnalysisException,
|
||||
"does_not_exist",
|
||||
lambda: spark.catalog.setCurrentDatabase("does_not_exist"))
|
||||
|
@ -38,10 +38,10 @@ class CatalogTests(ReusedSQLTestCase):
|
|||
spark = self.spark
|
||||
with self.database("some_db"):
|
||||
databases = [db.name for db in spark.catalog.listDatabases()]
|
||||
self.assertEquals(databases, ["default"])
|
||||
self.assertEqual(databases, ["default"])
|
||||
spark.sql("CREATE DATABASE some_db")
|
||||
databases = [db.name for db in spark.catalog.listDatabases()]
|
||||
self.assertEquals(sorted(databases), ["default", "some_db"])
|
||||
self.assertEqual(sorted(databases), ["default", "some_db"])
|
||||
|
||||
def test_list_tables(self):
|
||||
from pyspark.sql.catalog import Table
|
||||
|
@ -50,8 +50,8 @@ class CatalogTests(ReusedSQLTestCase):
|
|||
spark.sql("CREATE DATABASE some_db")
|
||||
with self.table("tab1", "some_db.tab2", "tab3_via_catalog"):
|
||||
with self.tempView("temp_tab"):
|
||||
self.assertEquals(spark.catalog.listTables(), [])
|
||||
self.assertEquals(spark.catalog.listTables("some_db"), [])
|
||||
self.assertEqual(spark.catalog.listTables(), [])
|
||||
self.assertEqual(spark.catalog.listTables("some_db"), [])
|
||||
spark.createDataFrame([(1, 1)]).createOrReplaceTempView("temp_tab")
|
||||
spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
|
||||
spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet")
|
||||
|
@ -66,40 +66,40 @@ class CatalogTests(ReusedSQLTestCase):
|
|||
sorted(spark.catalog.listTables("default"), key=lambda t: t.name)
|
||||
tablesSomeDb = \
|
||||
sorted(spark.catalog.listTables("some_db"), key=lambda t: t.name)
|
||||
self.assertEquals(tables, tablesDefault)
|
||||
self.assertEquals(len(tables), 3)
|
||||
self.assertEquals(len(tablesSomeDb), 2)
|
||||
self.assertEquals(tables[0], Table(
|
||||
self.assertEqual(tables, tablesDefault)
|
||||
self.assertEqual(len(tables), 3)
|
||||
self.assertEqual(len(tablesSomeDb), 2)
|
||||
self.assertEqual(tables[0], Table(
|
||||
name="tab1",
|
||||
database="default",
|
||||
description=None,
|
||||
tableType="MANAGED",
|
||||
isTemporary=False))
|
||||
self.assertEquals(tables[1], Table(
|
||||
self.assertEqual(tables[1], Table(
|
||||
name="tab3_via_catalog",
|
||||
database="default",
|
||||
description=description,
|
||||
tableType="MANAGED",
|
||||
isTemporary=False))
|
||||
self.assertEquals(tables[2], Table(
|
||||
self.assertEqual(tables[2], Table(
|
||||
name="temp_tab",
|
||||
database=None,
|
||||
description=None,
|
||||
tableType="TEMPORARY",
|
||||
isTemporary=True))
|
||||
self.assertEquals(tablesSomeDb[0], Table(
|
||||
self.assertEqual(tablesSomeDb[0], Table(
|
||||
name="tab2",
|
||||
database="some_db",
|
||||
description=None,
|
||||
tableType="MANAGED",
|
||||
isTemporary=False))
|
||||
self.assertEquals(tablesSomeDb[1], Table(
|
||||
self.assertEqual(tablesSomeDb[1], Table(
|
||||
name="temp_tab",
|
||||
database=None,
|
||||
description=None,
|
||||
tableType="TEMPORARY",
|
||||
isTemporary=True))
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
AnalysisException,
|
||||
"does_not_exist",
|
||||
lambda: spark.catalog.listTables("does_not_exist"))
|
||||
|
@ -119,12 +119,12 @@ class CatalogTests(ReusedSQLTestCase):
|
|||
self.assertTrue("to_timestamp" in functions)
|
||||
self.assertTrue("to_unix_timestamp" in functions)
|
||||
self.assertTrue("current_database" in functions)
|
||||
self.assertEquals(functions["+"], Function(
|
||||
self.assertEqual(functions["+"], Function(
|
||||
name="+",
|
||||
description=None,
|
||||
className="org.apache.spark.sql.catalyst.expressions.Add",
|
||||
isTemporary=True))
|
||||
self.assertEquals(functions, functionsDefault)
|
||||
self.assertEqual(functions, functionsDefault)
|
||||
|
||||
with self.function("func1", "some_db.func2"):
|
||||
spark.catalog.registerFunction("temp_func", lambda x: str(x))
|
||||
|
@ -141,7 +141,7 @@ class CatalogTests(ReusedSQLTestCase):
|
|||
self.assertTrue("temp_func" in newFunctionsSomeDb)
|
||||
self.assertTrue("func1" not in newFunctionsSomeDb)
|
||||
self.assertTrue("func2" in newFunctionsSomeDb)
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
AnalysisException,
|
||||
"does_not_exist",
|
||||
lambda: spark.catalog.listFunctions("does_not_exist"))
|
||||
|
@ -158,16 +158,16 @@ class CatalogTests(ReusedSQLTestCase):
|
|||
columns = sorted(spark.catalog.listColumns("tab1"), key=lambda c: c.name)
|
||||
columnsDefault = \
|
||||
sorted(spark.catalog.listColumns("tab1", "default"), key=lambda c: c.name)
|
||||
self.assertEquals(columns, columnsDefault)
|
||||
self.assertEquals(len(columns), 2)
|
||||
self.assertEquals(columns[0], Column(
|
||||
self.assertEqual(columns, columnsDefault)
|
||||
self.assertEqual(len(columns), 2)
|
||||
self.assertEqual(columns[0], Column(
|
||||
name="age",
|
||||
description=None,
|
||||
dataType="int",
|
||||
nullable=True,
|
||||
isPartition=False,
|
||||
isBucket=False))
|
||||
self.assertEquals(columns[1], Column(
|
||||
self.assertEqual(columns[1], Column(
|
||||
name="name",
|
||||
description=None,
|
||||
dataType="string",
|
||||
|
@ -176,26 +176,26 @@ class CatalogTests(ReusedSQLTestCase):
|
|||
isBucket=False))
|
||||
columns2 = \
|
||||
sorted(spark.catalog.listColumns("tab2", "some_db"), key=lambda c: c.name)
|
||||
self.assertEquals(len(columns2), 2)
|
||||
self.assertEquals(columns2[0], Column(
|
||||
self.assertEqual(len(columns2), 2)
|
||||
self.assertEqual(columns2[0], Column(
|
||||
name="nickname",
|
||||
description=None,
|
||||
dataType="string",
|
||||
nullable=True,
|
||||
isPartition=False,
|
||||
isBucket=False))
|
||||
self.assertEquals(columns2[1], Column(
|
||||
self.assertEqual(columns2[1], Column(
|
||||
name="tolerance",
|
||||
description=None,
|
||||
dataType="float",
|
||||
nullable=True,
|
||||
isPartition=False,
|
||||
isBucket=False))
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
AnalysisException,
|
||||
"tab2",
|
||||
lambda: spark.catalog.listColumns("tab2"))
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
AnalysisException,
|
||||
"does_not_exist",
|
||||
lambda: spark.catalog.listColumns("does_not_exist"))
|
||||
|
|
|
@ -47,7 +47,7 @@ class ColumnTests(ReusedSQLTestCase):
|
|||
self.assertTrue("Column" in _to_java_column(u"a").getClass().toString())
|
||||
self.assertTrue("Column" in _to_java_column(self.spark.range(1).id).getClass().toString())
|
||||
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Invalid argument, not a string or column",
|
||||
lambda: _to_java_column(1))
|
||||
|
@ -58,7 +58,7 @@ class ColumnTests(ReusedSQLTestCase):
|
|||
self.assertRaises(TypeError, lambda: _to_java_column(A()))
|
||||
self.assertRaises(TypeError, lambda: _to_java_column([]))
|
||||
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Invalid argument, not a string or column",
|
||||
lambda: udf(lambda x: x)(None))
|
||||
|
@ -79,9 +79,9 @@ class ColumnTests(ReusedSQLTestCase):
|
|||
cs.startswith('a'), cs.endswith('a'), ci.eqNullSafe(cs)
|
||||
self.assertTrue(all(isinstance(c, Column) for c in css))
|
||||
self.assertTrue(isinstance(ci.cast(LongType()), Column))
|
||||
self.assertRaisesRegexp(ValueError,
|
||||
"Cannot apply 'in' operator against a column",
|
||||
lambda: 1 in cs)
|
||||
self.assertRaisesRegex(ValueError,
|
||||
"Cannot apply 'in' operator against a column",
|
||||
lambda: 1 in cs)
|
||||
|
||||
def test_column_accessor(self):
|
||||
from pyspark.sql.functions import col
|
||||
|
|
|
@ -28,7 +28,7 @@ class ConfTests(ReusedSQLTestCase):
|
|||
self.assertEqual(spark.conf.get("bogo"), "ta")
|
||||
self.assertEqual(spark.conf.get("bogo", "not.read"), "ta")
|
||||
self.assertEqual(spark.conf.get("not.set", "ta"), "ta")
|
||||
self.assertRaisesRegexp(Exception, "not.set", lambda: spark.conf.get("not.set"))
|
||||
self.assertRaisesRegex(Exception, "not.set", lambda: spark.conf.get("not.set"))
|
||||
spark.conf.unset("bogo")
|
||||
self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia")
|
||||
|
||||
|
|
|
@ -343,7 +343,7 @@ class DataFrameTests(ReusedSQLTestCase):
|
|||
self.spark.createDataFrame(
|
||||
[(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first()
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
'value argument is required when to_replace is not a dictionary.'):
|
||||
self.spark.createDataFrame(
|
||||
|
@ -390,7 +390,7 @@ class DataFrameTests(ReusedSQLTestCase):
|
|||
self.assertEqual(3, logical_plan.toString().count("itworks"))
|
||||
|
||||
def test_sample(self):
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"should be a bool, float and number",
|
||||
lambda: self.spark.range(1).sample())
|
||||
|
@ -426,12 +426,12 @@ class DataFrameTests(ReusedSQLTestCase):
|
|||
self.assertEqual(df.collect(), data)
|
||||
|
||||
# number of fields must match.
|
||||
self.assertRaisesRegexp(Exception, "Length of object",
|
||||
lambda: rdd.toDF("key: int").collect())
|
||||
self.assertRaisesRegex(Exception, "Length of object",
|
||||
lambda: rdd.toDF("key: int").collect())
|
||||
|
||||
# field types mismatch will cause exception at runtime.
|
||||
self.assertRaisesRegexp(Exception, "FloatType can not accept",
|
||||
lambda: rdd.toDF("key: float, value: string").collect())
|
||||
self.assertRaisesRegex(Exception, "FloatType can not accept",
|
||||
lambda: rdd.toDF("key: float, value: string").collect())
|
||||
|
||||
# flat schema values will be wrapped into row.
|
||||
df = rdd.map(lambda row: row.key).toDF("int")
|
||||
|
@ -491,15 +491,15 @@ class DataFrameTests(ReusedSQLTestCase):
|
|||
spark.catalog.clearCache()
|
||||
self.assertFalse(spark.catalog.isCached("tab1"))
|
||||
self.assertFalse(spark.catalog.isCached("tab2"))
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
AnalysisException,
|
||||
"does_not_exist",
|
||||
lambda: spark.catalog.isCached("does_not_exist"))
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
AnalysisException,
|
||||
"does_not_exist",
|
||||
lambda: spark.catalog.cacheTable("does_not_exist"))
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
AnalysisException,
|
||||
"does_not_exist",
|
||||
lambda: spark.catalog.uncacheTable("does_not_exist"))
|
||||
|
@ -523,12 +523,12 @@ class DataFrameTests(ReusedSQLTestCase):
|
|||
import numpy as np
|
||||
pdf = self._to_pandas()
|
||||
types = pdf.dtypes
|
||||
self.assertEquals(types[0], np.int32)
|
||||
self.assertEquals(types[1], np.object)
|
||||
self.assertEquals(types[2], np.bool)
|
||||
self.assertEquals(types[3], np.float32)
|
||||
self.assertEquals(types[4], np.object) # datetime.date
|
||||
self.assertEquals(types[5], 'datetime64[ns]')
|
||||
self.assertEqual(types[0], np.int32)
|
||||
self.assertEqual(types[1], np.object)
|
||||
self.assertEqual(types[2], np.bool)
|
||||
self.assertEqual(types[3], np.float32)
|
||||
self.assertEqual(types[4], np.object) # datetime.date
|
||||
self.assertEqual(types[5], 'datetime64[ns]')
|
||||
|
||||
@unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore
|
||||
def test_to_pandas_with_duplicated_column_names(self):
|
||||
|
@ -540,8 +540,8 @@ class DataFrameTests(ReusedSQLTestCase):
|
|||
df = self.spark.sql(sql)
|
||||
pdf = df.toPandas()
|
||||
types = pdf.dtypes
|
||||
self.assertEquals(types.iloc[0], np.int32)
|
||||
self.assertEquals(types.iloc[1], np.int32)
|
||||
self.assertEqual(types.iloc[0], np.int32)
|
||||
self.assertEqual(types.iloc[1], np.int32)
|
||||
|
||||
@unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore
|
||||
def test_to_pandas_on_cross_join(self):
|
||||
|
@ -560,13 +560,13 @@ class DataFrameTests(ReusedSQLTestCase):
|
|||
df = self.spark.sql(sql)
|
||||
pdf = df.toPandas()
|
||||
types = pdf.dtypes
|
||||
self.assertEquals(types.iloc[0], np.int32)
|
||||
self.assertEquals(types.iloc[1], np.int32)
|
||||
self.assertEqual(types.iloc[0], np.int32)
|
||||
self.assertEqual(types.iloc[1], np.int32)
|
||||
|
||||
@unittest.skipIf(have_pandas, "Required Pandas was found.")
|
||||
def test_to_pandas_required_pandas_not_found(self):
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
|
||||
with self.assertRaisesRegex(ImportError, 'Pandas >= .* must be installed'):
|
||||
self._to_pandas()
|
||||
|
||||
@unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore
|
||||
|
@ -577,9 +577,9 @@ class DataFrameTests(ReusedSQLTestCase):
|
|||
data = [(1, "foo", 16777220), (None, "bar", None)]
|
||||
df = self.spark.createDataFrame(data, schema)
|
||||
types = df.toPandas().dtypes
|
||||
self.assertEquals(types[0], np.float64) # doesn't convert to np.int32 due to NaN value.
|
||||
self.assertEquals(types[1], np.object)
|
||||
self.assertEquals(types[2], np.float64)
|
||||
self.assertEqual(types[0], np.float64) # doesn't convert to np.int32 due to NaN value.
|
||||
self.assertEqual(types[1], np.object)
|
||||
self.assertEqual(types[2], np.float64)
|
||||
|
||||
@unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore
|
||||
def test_to_pandas_from_empty_dataframe(self):
|
||||
|
@ -675,7 +675,7 @@ class DataFrameTests(ReusedSQLTestCase):
|
|||
@unittest.skipIf(have_pandas, "Required Pandas was found.")
|
||||
def test_create_dataframe_required_pandas_not_found(self):
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ImportError,
|
||||
"(Pandas >= .* must be installed|No module named '?pandas'?)"):
|
||||
import pandas as pd
|
||||
|
@ -688,7 +688,7 @@ class DataFrameTests(ReusedSQLTestCase):
|
|||
@unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore
|
||||
def test_create_dataframe_from_pandas_with_dst(self):
|
||||
import pandas as pd
|
||||
from pandas.util.testing import assert_frame_equal
|
||||
from pandas.testing import assert_frame_equal
|
||||
from datetime import datetime
|
||||
|
||||
pdf = pd.DataFrame({'time': [datetime(2015, 10, 31, 22, 30)]})
|
||||
|
@ -724,7 +724,7 @@ class DataFrameTests(ReusedSQLTestCase):
|
|||
||22222|22222|
|
||||
|+-----+-----+
|
||||
|"""
|
||||
self.assertEquals(re.sub(pattern, '', expected1), df.__repr__())
|
||||
self.assertEqual(re.sub(pattern, '', expected1), df.__repr__())
|
||||
with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
|
||||
expected2 = """+---+-----+
|
||||
||key|value|
|
||||
|
@ -733,7 +733,7 @@ class DataFrameTests(ReusedSQLTestCase):
|
|||
||222| 222|
|
||||
|+---+-----+
|
||||
|"""
|
||||
self.assertEquals(re.sub(pattern, '', expected2), df.__repr__())
|
||||
self.assertEqual(re.sub(pattern, '', expected2), df.__repr__())
|
||||
with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
|
||||
expected3 = """+---+-----+
|
||||
||key|value|
|
||||
|
@ -742,7 +742,7 @@ class DataFrameTests(ReusedSQLTestCase):
|
|||
|+---+-----+
|
||||
|only showing top 1 row
|
||||
|"""
|
||||
self.assertEquals(re.sub(pattern, '', expected3), df.__repr__())
|
||||
self.assertEqual(re.sub(pattern, '', expected3), df.__repr__())
|
||||
|
||||
# test when eager evaluation is enabled and _repr_html_ will be called
|
||||
with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
|
||||
|
@ -752,7 +752,7 @@ class DataFrameTests(ReusedSQLTestCase):
|
|||
|<tr><td>22222</td><td>22222</td></tr>
|
||||
|</table>
|
||||
|"""
|
||||
self.assertEquals(re.sub(pattern, '', expected1), df._repr_html_())
|
||||
self.assertEqual(re.sub(pattern, '', expected1), df._repr_html_())
|
||||
with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
|
||||
expected2 = """<table border='1'>
|
||||
|<tr><th>key</th><th>value</th></tr>
|
||||
|
@ -760,7 +760,7 @@ class DataFrameTests(ReusedSQLTestCase):
|
|||
|<tr><td>222</td><td>222</td></tr>
|
||||
|</table>
|
||||
|"""
|
||||
self.assertEquals(re.sub(pattern, '', expected2), df._repr_html_())
|
||||
self.assertEqual(re.sub(pattern, '', expected2), df._repr_html_())
|
||||
with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
|
||||
expected3 = """<table border='1'>
|
||||
|<tr><th>key</th><th>value</th></tr>
|
||||
|
@ -768,19 +768,19 @@ class DataFrameTests(ReusedSQLTestCase):
|
|||
|</table>
|
||||
|only showing top 1 row
|
||||
|"""
|
||||
self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_())
|
||||
self.assertEqual(re.sub(pattern, '', expected3), df._repr_html_())
|
||||
|
||||
# test when eager evaluation is disabled and _repr_html_ will be called
|
||||
with self.sql_conf({"spark.sql.repl.eagerEval.enabled": False}):
|
||||
expected = "DataFrame[key: bigint, value: string]"
|
||||
self.assertEquals(None, df._repr_html_())
|
||||
self.assertEquals(expected, df.__repr__())
|
||||
self.assertEqual(None, df._repr_html_())
|
||||
self.assertEqual(expected, df.__repr__())
|
||||
with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
|
||||
self.assertEquals(None, df._repr_html_())
|
||||
self.assertEquals(expected, df.__repr__())
|
||||
self.assertEqual(None, df._repr_html_())
|
||||
self.assertEqual(expected, df.__repr__())
|
||||
with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
|
||||
self.assertEquals(None, df._repr_html_())
|
||||
self.assertEquals(expected, df.__repr__())
|
||||
self.assertEqual(None, df._repr_html_())
|
||||
self.assertEqual(expected, df.__repr__())
|
||||
|
||||
def test_to_local_iterator(self):
|
||||
df = self.spark.range(8, numPartitions=4)
|
||||
|
@ -818,7 +818,7 @@ class DataFrameTests(ReusedSQLTestCase):
|
|||
|
||||
def test_same_semantics_error(self):
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(ValueError, "should be of DataFrame.*int"):
|
||||
with self.assertRaisesRegex(ValueError, "should be of DataFrame.*int"):
|
||||
self.spark.range(10).sameSemantics(1)
|
||||
|
||||
def test_input_files(self):
|
||||
|
@ -830,7 +830,7 @@ class DataFrameTests(ReusedSQLTestCase):
|
|||
input_files_list = self.spark.read.parquet(tpath).inputFiles()
|
||||
|
||||
# input files list should contain 10 entries
|
||||
self.assertEquals(len(input_files_list), 10)
|
||||
self.assertEqual(len(input_files_list), 10)
|
||||
# all file paths in list must contain tpath
|
||||
for file_path in input_files_list:
|
||||
self.assertTrue(tpath in file_path)
|
||||
|
|
|
@ -107,7 +107,7 @@ class DataSourcesTests(ReusedSQLTestCase):
|
|||
df = self.spark.read.text(['python/test_support/sql/text-test.txt',
|
||||
'python/test_support/sql/text-test.txt'])
|
||||
count = df.count()
|
||||
self.assertEquals(count, 4)
|
||||
self.assertEqual(count, 4)
|
||||
|
||||
def test_json_sampling_ratio(self):
|
||||
rdd = self.spark.sparkContext.range(0, 100, 1, 1) \
|
||||
|
@ -115,14 +115,14 @@ class DataSourcesTests(ReusedSQLTestCase):
|
|||
schema = self.spark.read.option('inferSchema', True) \
|
||||
.option('samplingRatio', 0.5) \
|
||||
.json(rdd).schema
|
||||
self.assertEquals(schema, StructType([StructField("a", LongType(), True)]))
|
||||
self.assertEqual(schema, StructType([StructField("a", LongType(), True)]))
|
||||
|
||||
def test_csv_sampling_ratio(self):
|
||||
rdd = self.spark.sparkContext.range(0, 100, 1, 1) \
|
||||
.map(lambda x: '0.1' if x == 1 else str(x))
|
||||
schema = self.spark.read.option('inferSchema', True)\
|
||||
.csv(rdd, samplingRatio=0.5).schema
|
||||
self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)]))
|
||||
self.assertEqual(schema, StructType([StructField("_c0", IntegerType(), True)]))
|
||||
|
||||
def test_checking_csv_header(self):
|
||||
path = tempfile.mkdtemp()
|
||||
|
@ -135,7 +135,7 @@ class DataSourcesTests(ReusedSQLTestCase):
|
|||
StructField('f1', IntegerType(), nullable=True)])
|
||||
df = self.spark.read.option('header', 'true').schema(schema)\
|
||||
.csv(path, enforceSchema=False)
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
Exception,
|
||||
"CSV header does not conform to the schema",
|
||||
lambda: df.collect())
|
||||
|
@ -154,7 +154,7 @@ class DataSourcesTests(ReusedSQLTestCase):
|
|||
StructField('b', LongType(), nullable=True),
|
||||
StructField('c', StringType(), nullable=True)])
|
||||
readback = self.spark.read.json(path, dropFieldIfAllNull=True)
|
||||
self.assertEquals(readback.schema, schema)
|
||||
self.assertEqual(readback.schema, schema)
|
||||
finally:
|
||||
shutil.rmtree(path)
|
||||
|
||||
|
|
|
@ -185,7 +185,7 @@ class FunctionsTests(ReusedSQLTestCase):
|
|||
]
|
||||
|
||||
df = self.spark.createDataFrame([['nick']], schema=['name'])
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"must be the same type",
|
||||
lambda: df.select(col('name').substr(0, lit(1))))
|
||||
|
@ -321,16 +321,16 @@ class FunctionsTests(ReusedSQLTestCase):
|
|||
|
||||
df = self.spark.createDataFrame(
|
||||
[('Tom', 80), (None, 60), ('Alice', 50)], ["name", "height"])
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect(),
|
||||
[Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')])
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect(),
|
||||
[Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)])
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect(),
|
||||
[Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')])
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(),
|
||||
[Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)])
|
||||
|
||||
|
@ -354,7 +354,7 @@ class FunctionsTests(ReusedSQLTestCase):
|
|||
|
||||
df = self.spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x'])
|
||||
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
df.select(slice(df.x, 2, 2).alias("sliced")).collect(),
|
||||
df.select(slice(df.x, lit(2), lit(2)).alias("sliced")).collect(),
|
||||
)
|
||||
|
@ -364,7 +364,7 @@ class FunctionsTests(ReusedSQLTestCase):
|
|||
|
||||
df = self.spark.range(1)
|
||||
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
df.select(array_repeat("id", 3)).toDF("val").collect(),
|
||||
df.select(array_repeat("id", lit(3))).toDF("val").collect(),
|
||||
)
|
||||
|
@ -580,14 +580,14 @@ class FunctionsTests(ReusedSQLTestCase):
|
|||
from datetime import date
|
||||
df = self.spark.range(1).selectExpr("'2017-01-22' as dateCol")
|
||||
parse_result = df.select(functions.to_date(functions.col("dateCol"))).first()
|
||||
self.assertEquals(date(2017, 1, 22), parse_result['to_date(dateCol)'])
|
||||
self.assertEqual(date(2017, 1, 22), parse_result['to_date(dateCol)'])
|
||||
|
||||
def test_assert_true(self):
|
||||
from pyspark.sql.functions import assert_true
|
||||
|
||||
df = self.spark.range(3)
|
||||
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
df.select(assert_true(df.id < 3)).toDF("val").collect(),
|
||||
[Row(val=None), Row(val=None), Row(val=None)],
|
||||
)
|
||||
|
@ -604,7 +604,7 @@ class FunctionsTests(ReusedSQLTestCase):
|
|||
|
||||
with self.assertRaises(TypeError) as cm:
|
||||
df.select(assert_true(df.id < 2, 5))
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
"errMsg should be a Column or a str, got <class 'int'>",
|
||||
str(cm.exception)
|
||||
)
|
||||
|
@ -626,7 +626,7 @@ class FunctionsTests(ReusedSQLTestCase):
|
|||
|
||||
with self.assertRaises(TypeError) as cm:
|
||||
df.select(raise_error(None))
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
"errMsg should be a Column or a str, got <class 'NoneType'>",
|
||||
str(cm.exception)
|
||||
)
|
||||
|
|
|
@ -25,7 +25,7 @@ from pyspark.testing.utils import QuietTest
|
|||
|
||||
if have_pandas:
|
||||
import pandas as pd
|
||||
from pandas.util.testing import assert_frame_equal
|
||||
from pandas.testing import assert_frame_equal
|
||||
|
||||
if have_pyarrow:
|
||||
import pyarrow as pa # noqa: F401
|
||||
|
@ -135,8 +135,8 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase):
|
|||
.applyInPandas(lambda x, y: pd.DataFrame([(x.sum().sum(), y.sum().sum())]),
|
||||
'sum1 int, sum2 int').collect()
|
||||
|
||||
self.assertEquals(result[0]['sum1'], 165)
|
||||
self.assertEquals(result[0]['sum2'], 165)
|
||||
self.assertEqual(result[0]['sum1'], 165)
|
||||
self.assertEqual(result[0]['sum2'], 165)
|
||||
|
||||
def test_with_key_left(self):
|
||||
self._test_with_key(self.data1, self.data1, isLeft=True)
|
||||
|
@ -174,7 +174,7 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase):
|
|||
left = self.data1
|
||||
right = self.data2
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
'Invalid return type.*ArrayType.*TimestampType'):
|
||||
left.groupby('id').cogroup(right.groupby('id')).applyInPandas(
|
||||
|
@ -183,7 +183,7 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase):
|
|||
def test_wrong_args(self):
|
||||
left = self.data1
|
||||
right = self.data2
|
||||
with self.assertRaisesRegexp(ValueError, 'Invalid function'):
|
||||
with self.assertRaisesRegex(ValueError, 'Invalid function'):
|
||||
left.groupby('id').cogroup(right.groupby('id')) \
|
||||
.applyInPandas(lambda: 1, StructType([StructField("d", DoubleType())]))
|
||||
|
||||
|
@ -194,14 +194,14 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase):
|
|||
row = df1.groupby("ColUmn").cogroup(
|
||||
df1.groupby("COLUMN")
|
||||
).applyInPandas(lambda r, l: r + l, "column long, value long").first()
|
||||
self.assertEquals(row.asDict(), Row(column=2, value=2).asDict())
|
||||
self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())
|
||||
|
||||
df2 = self.spark.createDataFrame([(1, 1)], ("column", "value"))
|
||||
|
||||
row = df1.groupby("ColUmn").cogroup(
|
||||
df2.groupby("COLUMN")
|
||||
).applyInPandas(lambda r, l: r + l, "column long, value long").first()
|
||||
self.assertEquals(row.asDict(), Row(column=2, value=2).asDict())
|
||||
self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())
|
||||
|
||||
@staticmethod
|
||||
def _test_with_key(left, right, isLeft):
|
||||
|
|
|
@ -33,7 +33,7 @@ from pyspark.testing.utils import QuietTest
|
|||
|
||||
if have_pandas:
|
||||
import pandas as pd
|
||||
from pandas.util.testing import assert_frame_equal
|
||||
from pandas.testing import assert_frame_equal
|
||||
|
||||
if have_pyarrow:
|
||||
import pyarrow as pa # noqa: F401
|
||||
|
@ -160,7 +160,7 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
|
|||
def test_register_grouped_map_udf(self):
|
||||
foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP)
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'f.*SQL_BATCHED_UDF.*SQL_SCALAR_PANDAS_UDF.*SQL_GROUPED_AGG_PANDAS_UDF.*'):
|
||||
self.spark.catalog.registerFunction("foo_udf", foo_udf)
|
||||
|
@ -244,7 +244,7 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
|
|||
|
||||
def test_wrong_return_type(self):
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
'Invalid return type.*grouped map Pandas UDF.*ArrayType.*TimestampType'):
|
||||
pandas_udf(
|
||||
|
@ -256,20 +256,20 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
|
|||
df = self.data
|
||||
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
|
||||
with self.assertRaisesRegex(ValueError, 'Invalid udf'):
|
||||
df.groupby('id').apply(lambda x: x)
|
||||
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
|
||||
with self.assertRaisesRegex(ValueError, 'Invalid udf'):
|
||||
df.groupby('id').apply(udf(lambda x: x, DoubleType()))
|
||||
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
|
||||
with self.assertRaisesRegex(ValueError, 'Invalid udf'):
|
||||
df.groupby('id').apply(sum(df.v))
|
||||
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
|
||||
with self.assertRaisesRegex(ValueError, 'Invalid udf'):
|
||||
df.groupby('id').apply(df.v + 1)
|
||||
with self.assertRaisesRegexp(ValueError, 'Invalid function'):
|
||||
with self.assertRaisesRegex(ValueError, 'Invalid function'):
|
||||
df.groupby('id').apply(
|
||||
pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())])))
|
||||
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
|
||||
with self.assertRaisesRegex(ValueError, 'Invalid udf'):
|
||||
df.groupby('id').apply(pandas_udf(lambda x, y: x, DoubleType()))
|
||||
with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUPED_MAP'):
|
||||
with self.assertRaisesRegex(ValueError, 'Invalid udf.*GROUPED_MAP'):
|
||||
df.groupby('id').apply(
|
||||
pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR))
|
||||
|
||||
|
@ -284,7 +284,7 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
|
|||
for unsupported_type in unsupported_types:
|
||||
schema = StructType([StructField('id', LongType(), True), unsupported_type])
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(NotImplementedError, common_err_msg):
|
||||
with self.assertRaisesRegex(NotImplementedError, common_err_msg):
|
||||
pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)
|
||||
|
||||
# Regression test for SPARK-23314
|
||||
|
@ -451,9 +451,9 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
|
|||
|
||||
with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}):
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(Exception, "KeyError: 'id'"):
|
||||
with self.assertRaisesRegex(Exception, "KeyError: 'id'"):
|
||||
grouped_df.apply(column_name_typo).collect()
|
||||
with self.assertRaisesRegexp(Exception, "[D|d]ecimal.*got.*date"):
|
||||
with self.assertRaisesRegex(Exception, "[D|d]ecimal.*got.*date"):
|
||||
grouped_df.apply(invalid_positional_types).collect()
|
||||
|
||||
def test_positional_assignment_conf(self):
|
||||
|
@ -482,7 +482,7 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
|
|||
# this was throwing an AnalysisException before SPARK-24208
|
||||
res = df_with_pandas.alias('temp0').join(df_with_pandas.alias('temp1'),
|
||||
col('temp0.key') == col('temp1.key'))
|
||||
self.assertEquals(res.count(), 5)
|
||||
self.assertEqual(res.count(), 5)
|
||||
|
||||
def test_mixed_scalar_udfs_followed_by_groupby_apply(self):
|
||||
df = self.spark.range(0, 10).toDF('v1')
|
||||
|
@ -494,7 +494,7 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
|
|||
'sum int',
|
||||
PandasUDFType.GROUPED_MAP))
|
||||
|
||||
self.assertEquals(result.collect()[0]['sum'], 165)
|
||||
self.assertEqual(result.collect()[0]['sum'], 165)
|
||||
|
||||
def test_grouped_with_empty_partition(self):
|
||||
data = [Row(id=1, x=2), Row(id=1, x=3), Row(id=2, x=4)]
|
||||
|
@ -604,7 +604,7 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
|
|||
df = self.spark.createDataFrame([[1, 1]], ["column", "score"])
|
||||
row = df.groupby('COLUMN').applyInPandas(
|
||||
my_pandas_udf, schema="column integer, score float").first()
|
||||
self.assertEquals(row.asDict(), Row(column=1, score=0.5).asDict())
|
||||
self.assertEqual(row.asDict(), Row(column=1, score=0.5).asDict())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -61,7 +61,7 @@ class MapInPandasTests(ReusedSQLTestCase):
|
|||
df = self.spark.range(10)
|
||||
actual = df.mapInPandas(func, 'id long').collect()
|
||||
expected = df.collect()
|
||||
self.assertEquals(actual, expected)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_multiple_columns(self):
|
||||
data = [(1, "foo"), (2, None), (3, "bar"), (4, "bar")]
|
||||
|
@ -75,7 +75,7 @@ class MapInPandasTests(ReusedSQLTestCase):
|
|||
|
||||
actual = df.mapInPandas(func, df.schema).collect()
|
||||
expected = df.collect()
|
||||
self.assertEquals(actual, expected)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_different_output_length(self):
|
||||
def func(iterator):
|
||||
|
@ -84,7 +84,7 @@ class MapInPandasTests(ReusedSQLTestCase):
|
|||
|
||||
df = self.spark.range(10)
|
||||
actual = df.repartition(1).mapInPandas(func, 'a long').collect()
|
||||
self.assertEquals(set((r.a for r in actual)), set(range(100)))
|
||||
self.assertEqual(set((r.a for r in actual)), set(range(100)))
|
||||
|
||||
def test_empty_iterator(self):
|
||||
def empty_iter(_):
|
||||
|
@ -110,7 +110,7 @@ class MapInPandasTests(ReusedSQLTestCase):
|
|||
df = self.spark.range(10)
|
||||
actual = df.mapInPandas(func, 'id long').mapInPandas(func, 'id long').collect()
|
||||
expected = df.collect()
|
||||
self.assertEquals(actual, expected)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -114,31 +114,31 @@ class PandasUDFTests(ReusedSQLTestCase):
|
|||
@pandas_udf('blah')
|
||||
def foo(x):
|
||||
return x
|
||||
with self.assertRaisesRegexp(ValueError, 'Invalid return type.*None'):
|
||||
with self.assertRaisesRegex(ValueError, 'Invalid return type.*None'):
|
||||
@pandas_udf(functionType=PandasUDFType.SCALAR)
|
||||
def foo(x):
|
||||
return x
|
||||
with self.assertRaisesRegexp(ValueError, 'Invalid function'):
|
||||
with self.assertRaisesRegex(ValueError, 'Invalid function'):
|
||||
@pandas_udf('double', 100)
|
||||
def foo(x):
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, '0-arg pandas_udfs.*not.*supported'):
|
||||
with self.assertRaisesRegex(ValueError, '0-arg pandas_udfs.*not.*supported'):
|
||||
pandas_udf(lambda: 1, LongType(), PandasUDFType.SCALAR)
|
||||
with self.assertRaisesRegexp(ValueError, '0-arg pandas_udfs.*not.*supported'):
|
||||
with self.assertRaisesRegex(ValueError, '0-arg pandas_udfs.*not.*supported'):
|
||||
@pandas_udf(LongType(), PandasUDFType.SCALAR)
|
||||
def zero_with_type():
|
||||
return 1
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, 'Invalid return type'):
|
||||
with self.assertRaisesRegex(TypeError, 'Invalid return type'):
|
||||
@pandas_udf(returnType=PandasUDFType.GROUPED_MAP)
|
||||
def foo(df):
|
||||
return df
|
||||
with self.assertRaisesRegexp(TypeError, 'Invalid return type'):
|
||||
with self.assertRaisesRegex(TypeError, 'Invalid return type'):
|
||||
@pandas_udf(returnType='double', functionType=PandasUDFType.GROUPED_MAP)
|
||||
def foo(df):
|
||||
return df
|
||||
with self.assertRaisesRegexp(ValueError, 'Invalid function'):
|
||||
with self.assertRaisesRegex(ValueError, 'Invalid function'):
|
||||
@pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP)
|
||||
def foo(k, v, w):
|
||||
return k
|
||||
|
@ -154,14 +154,14 @@ class PandasUDFTests(ReusedSQLTestCase):
|
|||
df = self.spark.range(0, 100)
|
||||
|
||||
# plain udf (test for SPARK-23754)
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
PythonException,
|
||||
exc_message,
|
||||
df.withColumn('v', udf(foo)('id')).collect
|
||||
)
|
||||
|
||||
# pandas scalar udf
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
PythonException,
|
||||
exc_message,
|
||||
df.withColumn(
|
||||
|
@ -170,7 +170,7 @@ class PandasUDFTests(ReusedSQLTestCase):
|
|||
)
|
||||
|
||||
# pandas grouped map
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
PythonException,
|
||||
exc_message,
|
||||
df.groupBy('id').apply(
|
||||
|
@ -178,7 +178,7 @@ class PandasUDFTests(ReusedSQLTestCase):
|
|||
).collect
|
||||
)
|
||||
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
PythonException,
|
||||
exc_message,
|
||||
df.groupBy('id').apply(
|
||||
|
@ -187,7 +187,7 @@ class PandasUDFTests(ReusedSQLTestCase):
|
|||
)
|
||||
|
||||
# pandas grouped agg
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
PythonException,
|
||||
exc_message,
|
||||
df.groupBy('id').agg(
|
||||
|
@ -210,8 +210,8 @@ class PandasUDFTests(ReusedSQLTestCase):
|
|||
# Since 0.11.0, PyArrow supports the feature to raise an error for unsafe cast.
|
||||
with self.sql_conf({
|
||||
"spark.sql.execution.pandas.convertToArrowArraySafely": True}):
|
||||
with self.assertRaisesRegexp(Exception,
|
||||
"Exception thrown when converting pandas.Series"):
|
||||
with self.assertRaisesRegex(Exception,
|
||||
"Exception thrown when converting pandas.Series"):
|
||||
df.select(['A']).withColumn('udf', udf('A')).collect()
|
||||
|
||||
# Disabling Arrow safe type check.
|
||||
|
@ -231,8 +231,8 @@ class PandasUDFTests(ReusedSQLTestCase):
|
|||
# When enabling safe type check, Arrow 0.11.0+ disallows overflow cast.
|
||||
with self.sql_conf({
|
||||
"spark.sql.execution.pandas.convertToArrowArraySafely": True}):
|
||||
with self.assertRaisesRegexp(Exception,
|
||||
"Exception thrown when converting pandas.Series"):
|
||||
with self.assertRaisesRegex(Exception,
|
||||
"Exception thrown when converting pandas.Series"):
|
||||
df.withColumn('udf', udf('id')).collect()
|
||||
|
||||
# Disabling safe type check, let Arrow do the cast anyway.
|
||||
|
|
|
@ -30,7 +30,7 @@ from pyspark.testing.utils import QuietTest
|
|||
|
||||
if have_pandas:
|
||||
import pandas as pd
|
||||
from pandas.util.testing import assert_frame_equal
|
||||
from pandas.testing import assert_frame_equal
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
|
@ -145,20 +145,20 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
|
|||
|
||||
def test_unsupported_types(self):
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
|
||||
with self.assertRaisesRegex(NotImplementedError, 'not supported'):
|
||||
pandas_udf(
|
||||
lambda x: x,
|
||||
ArrayType(ArrayType(TimestampType())),
|
||||
PandasUDFType.GROUPED_AGG)
|
||||
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
|
||||
with self.assertRaisesRegex(NotImplementedError, 'not supported'):
|
||||
@pandas_udf('mean double, std double', PandasUDFType.GROUPED_AGG)
|
||||
def mean_and_std_udf(v):
|
||||
return v.mean(), v.std()
|
||||
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
|
||||
with self.assertRaisesRegex(NotImplementedError, 'not supported'):
|
||||
@pandas_udf(ArrayType(TimestampType()), PandasUDFType.GROUPED_AGG)
|
||||
def mean_and_std_udf(v):
|
||||
return {v.mean(): v.std()}
|
||||
|
@ -428,7 +428,7 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
|
|||
|
||||
array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
|
||||
result1 = df.groupby('id').agg(array_udf(df['v']).alias('v2'))
|
||||
self.assertEquals(result1.first()['v2'], [1.0, 2.0])
|
||||
self.assertEqual(result1.first()['v2'], [1.0, 2.0])
|
||||
|
||||
def test_invalid_args(self):
|
||||
df = self.data
|
||||
|
@ -436,19 +436,19 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
|
|||
mean_udf = self.pandas_agg_mean_udf
|
||||
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
AnalysisException,
|
||||
'nor.*aggregate function'):
|
||||
df.groupby(df.id).agg(plus_one(df.v)).collect()
|
||||
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
AnalysisException,
|
||||
'aggregate function.*argument.*aggregate function'):
|
||||
df.groupby(df.id).agg(mean_udf(mean_udf(df.v))).collect()
|
||||
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
AnalysisException,
|
||||
'mixture.*aggregate function.*group aggregate pandas UDF'):
|
||||
df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
|
||||
|
|
|
@ -133,7 +133,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
long_f(col('long')), float_f(col('float')),
|
||||
double_f(col('double')), decimal_f('decimal'),
|
||||
bool_f(col('bool')), array_long_f('array_long'))
|
||||
self.assertEquals(df.collect(), res.collect())
|
||||
self.assertEqual(df.collect(), res.collect())
|
||||
|
||||
def test_register_nondeterministic_vectorized_udf_basic(self):
|
||||
random_pandas_udf = pandas_udf(
|
||||
|
@ -169,7 +169,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
|
||||
bool_f = pandas_udf(lambda x: x, BooleanType(), udf_type)
|
||||
res = df.select(bool_f(col('bool')))
|
||||
self.assertEquals(df.collect(), res.collect())
|
||||
self.assertEqual(df.collect(), res.collect())
|
||||
|
||||
def test_vectorized_udf_null_byte(self):
|
||||
data = [(None,), (2,), (3,), (4,)]
|
||||
|
@ -178,7 +178,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
|
||||
byte_f = pandas_udf(lambda x: x, ByteType(), udf_type)
|
||||
res = df.select(byte_f(col('byte')))
|
||||
self.assertEquals(df.collect(), res.collect())
|
||||
self.assertEqual(df.collect(), res.collect())
|
||||
|
||||
def test_vectorized_udf_null_short(self):
|
||||
data = [(None,), (2,), (3,), (4,)]
|
||||
|
@ -187,7 +187,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
|
||||
short_f = pandas_udf(lambda x: x, ShortType(), udf_type)
|
||||
res = df.select(short_f(col('short')))
|
||||
self.assertEquals(df.collect(), res.collect())
|
||||
self.assertEqual(df.collect(), res.collect())
|
||||
|
||||
def test_vectorized_udf_null_int(self):
|
||||
data = [(None,), (2,), (3,), (4,)]
|
||||
|
@ -196,7 +196,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
|
||||
int_f = pandas_udf(lambda x: x, IntegerType(), udf_type)
|
||||
res = df.select(int_f(col('int')))
|
||||
self.assertEquals(df.collect(), res.collect())
|
||||
self.assertEqual(df.collect(), res.collect())
|
||||
|
||||
def test_vectorized_udf_null_long(self):
|
||||
data = [(None,), (2,), (3,), (4,)]
|
||||
|
@ -205,7 +205,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
|
||||
long_f = pandas_udf(lambda x: x, LongType(), udf_type)
|
||||
res = df.select(long_f(col('long')))
|
||||
self.assertEquals(df.collect(), res.collect())
|
||||
self.assertEqual(df.collect(), res.collect())
|
||||
|
||||
def test_vectorized_udf_null_float(self):
|
||||
data = [(3.0,), (5.0,), (-1.0,), (None,)]
|
||||
|
@ -214,7 +214,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
|
||||
float_f = pandas_udf(lambda x: x, FloatType(), udf_type)
|
||||
res = df.select(float_f(col('float')))
|
||||
self.assertEquals(df.collect(), res.collect())
|
||||
self.assertEqual(df.collect(), res.collect())
|
||||
|
||||
def test_vectorized_udf_null_double(self):
|
||||
data = [(3.0,), (5.0,), (-1.0,), (None,)]
|
||||
|
@ -223,7 +223,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
|
||||
double_f = pandas_udf(lambda x: x, DoubleType(), udf_type)
|
||||
res = df.select(double_f(col('double')))
|
||||
self.assertEquals(df.collect(), res.collect())
|
||||
self.assertEqual(df.collect(), res.collect())
|
||||
|
||||
def test_vectorized_udf_null_decimal(self):
|
||||
data = [(Decimal(3.0),), (Decimal(5.0),), (Decimal(-1.0),), (None,)]
|
||||
|
@ -232,7 +232,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
|
||||
decimal_f = pandas_udf(lambda x: x, DecimalType(38, 18), udf_type)
|
||||
res = df.select(decimal_f(col('decimal')))
|
||||
self.assertEquals(df.collect(), res.collect())
|
||||
self.assertEqual(df.collect(), res.collect())
|
||||
|
||||
def test_vectorized_udf_null_string(self):
|
||||
data = [("foo",), (None,), ("bar",), ("bar",)]
|
||||
|
@ -241,7 +241,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
|
||||
str_f = pandas_udf(lambda x: x, StringType(), udf_type)
|
||||
res = df.select(str_f(col('str')))
|
||||
self.assertEquals(df.collect(), res.collect())
|
||||
self.assertEqual(df.collect(), res.collect())
|
||||
|
||||
def test_vectorized_udf_string_in_udf(self):
|
||||
df = self.spark.range(10)
|
||||
|
@ -255,7 +255,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
str_f = pandas_udf(f, StringType(), udf_type)
|
||||
actual = df.select(str_f(col('id')))
|
||||
expected = df.select(col('id').cast('string'))
|
||||
self.assertEquals(expected.collect(), actual.collect())
|
||||
self.assertEqual(expected.collect(), actual.collect())
|
||||
|
||||
def test_vectorized_udf_datatype_string(self):
|
||||
df = self.spark.range(10).select(
|
||||
|
@ -279,7 +279,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
long_f(col('long')), float_f(col('float')),
|
||||
double_f(col('double')), decimal_f('decimal'),
|
||||
bool_f(col('bool')))
|
||||
self.assertEquals(df.collect(), res.collect())
|
||||
self.assertEqual(df.collect(), res.collect())
|
||||
|
||||
def test_vectorized_udf_null_binary(self):
|
||||
data = [(bytearray(b"a"),), (None,), (bytearray(b"bb"),), (bytearray(b"ccc"),)]
|
||||
|
@ -288,7 +288,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
|
||||
str_f = pandas_udf(lambda x: x, BinaryType(), udf_type)
|
||||
res = df.select(str_f(col('binary')))
|
||||
self.assertEquals(df.collect(), res.collect())
|
||||
self.assertEqual(df.collect(), res.collect())
|
||||
|
||||
def test_vectorized_udf_array_type(self):
|
||||
data = [([1, 2],), ([3, 4],)]
|
||||
|
@ -297,7 +297,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
|
||||
array_f = pandas_udf(lambda x: x, ArrayType(IntegerType()), udf_type)
|
||||
result = df.select(array_f(col('array')))
|
||||
self.assertEquals(df.collect(), result.collect())
|
||||
self.assertEqual(df.collect(), result.collect())
|
||||
|
||||
def test_vectorized_udf_null_array(self):
|
||||
data = [([1, 2],), (None,), (None,), ([3, 4],), (None,)]
|
||||
|
@ -306,7 +306,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
|
||||
array_f = pandas_udf(lambda x: x, ArrayType(IntegerType()), udf_type)
|
||||
result = df.select(array_f(col('array')))
|
||||
self.assertEquals(df.collect(), result.collect())
|
||||
self.assertEqual(df.collect(), result.collect())
|
||||
|
||||
def test_vectorized_udf_struct_type(self):
|
||||
df = self.spark.range(10)
|
||||
|
@ -375,7 +375,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
|
||||
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
Exception,
|
||||
'Invalid return type with scalar Pandas UDFs'):
|
||||
pandas_udf(lambda x: x, returnType=nested_type, functionType=udf_type)
|
||||
|
@ -392,7 +392,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
else:
|
||||
map_f = pandas_udf(lambda x: x, MapType(StringType(), LongType()), udf_type)
|
||||
result = df.select(map_f(col('map')))
|
||||
self.assertEquals(df.collect(), result.collect())
|
||||
self.assertEqual(df.collect(), result.collect())
|
||||
|
||||
def test_vectorized_udf_complex(self):
|
||||
df = self.spark.range(10).select(
|
||||
|
@ -422,7 +422,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
(iter_add, iter_power2, iter_mul)]:
|
||||
res = df.select(add(col('a'), col('b')), power2(col('a')), mul(col('b'), col('c')))
|
||||
expected = df.select(expr('a + b'), expr('power(2, a)'), expr('b * c'))
|
||||
self.assertEquals(expected.collect(), res.collect())
|
||||
self.assertEqual(expected.collect(), res.collect())
|
||||
|
||||
def test_vectorized_udf_exception(self):
|
||||
df = self.spark.range(10)
|
||||
|
@ -435,14 +435,14 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
|
||||
for raise_exception in [scalar_raise_exception, iter_raise_exception]:
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(Exception, 'division( or modulo)? by zero'):
|
||||
with self.assertRaisesRegex(Exception, 'division( or modulo)? by zero'):
|
||||
df.select(raise_exception(col('id'))).collect()
|
||||
|
||||
def test_vectorized_udf_invalid_length(self):
|
||||
df = self.spark.range(10)
|
||||
raise_exception = pandas_udf(lambda _: pd.Series(1), LongType())
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
Exception,
|
||||
'Result vector from pandas_udf was not the required length'):
|
||||
df.select(raise_exception(col('id'))).collect()
|
||||
|
@ -453,7 +453,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
yield pd.Series(1)
|
||||
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
Exception,
|
||||
"The length of output in Scalar iterator.*"
|
||||
"the length of output was 1"):
|
||||
|
@ -469,7 +469,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}):
|
||||
df1 = self.spark.range(10).repartition(1)
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
Exception,
|
||||
"pandas iterator UDF should exhaust"):
|
||||
df1.select(iter_udf_not_reading_all_input(col('id'))).collect()
|
||||
|
@ -486,7 +486,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
|
||||
for f, g in [(scalar_f, scalar_g), (iter_f, iter_g)]:
|
||||
res = df.select(g(f(col('id'))))
|
||||
self.assertEquals(df.collect(), res.collect())
|
||||
self.assertEqual(df.collect(), res.collect())
|
||||
|
||||
def test_vectorized_udf_chained_struct_type(self):
|
||||
df = self.spark.range(10)
|
||||
|
@ -517,7 +517,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
def test_vectorized_udf_wrong_return_type(self):
|
||||
with QuietTest(self.sc):
|
||||
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
'Invalid return type.*scalar Pandas UDF.*ArrayType.*TimestampType'):
|
||||
pandas_udf(lambda x: x, ArrayType(TimestampType()), udf_type)
|
||||
|
@ -529,7 +529,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
PandasUDFType.SCALAR_ITER)
|
||||
for f in [scalar_f, iter_f]:
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(Exception, 'Return.*type.*Series'):
|
||||
with self.assertRaisesRegex(Exception, 'Return.*type.*Series'):
|
||||
df.select(f(col('id'))).collect()
|
||||
|
||||
def test_vectorized_udf_decorator(self):
|
||||
|
@ -545,14 +545,14 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
|
||||
for identity in [scalar_identity, iter_identity]:
|
||||
res = df.select(identity(col('id')))
|
||||
self.assertEquals(df.collect(), res.collect())
|
||||
self.assertEqual(df.collect(), res.collect())
|
||||
|
||||
def test_vectorized_udf_empty_partition(self):
|
||||
df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
|
||||
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
|
||||
f = pandas_udf(lambda x: x, LongType(), udf_type)
|
||||
res = df.select(f(col('id')))
|
||||
self.assertEquals(df.collect(), res.collect())
|
||||
self.assertEqual(df.collect(), res.collect())
|
||||
|
||||
def test_vectorized_udf_struct_with_empty_partition(self):
|
||||
df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))\
|
||||
|
@ -585,16 +585,16 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
|
||||
for f in [scalar_f, iter_f]:
|
||||
res = df.select(f(col('id'), col('id')))
|
||||
self.assertEquals(df.collect(), res.collect())
|
||||
self.assertEqual(df.collect(), res.collect())
|
||||
|
||||
def test_vectorized_udf_unsupported_types(self):
|
||||
with QuietTest(self.sc):
|
||||
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
'Invalid return type.*scalar Pandas UDF.*ArrayType.*TimestampType'):
|
||||
pandas_udf(lambda x: x, ArrayType(TimestampType()), udf_type)
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
'Invalid return type.*scalar Pandas UDF.*ArrayType.StructType'):
|
||||
pandas_udf(lambda x: x,
|
||||
|
@ -637,10 +637,10 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
result = df.withColumn("check_data",
|
||||
check_data(col("idx"), col("date"), col("date_copy"))).collect()
|
||||
|
||||
self.assertEquals(len(data), len(result))
|
||||
self.assertEqual(len(data), len(result))
|
||||
for i in range(len(result)):
|
||||
self.assertEquals(data[i][1], result[i][1]) # "date" col
|
||||
self.assertEquals(data[i][1], result[i][2]) # "date_copy" col
|
||||
self.assertEqual(data[i][1], result[i][1]) # "date" col
|
||||
self.assertEqual(data[i][1], result[i][2]) # "date_copy" col
|
||||
self.assertIsNone(result[i][3]) # "check_data" col
|
||||
|
||||
def test_vectorized_udf_timestamps(self):
|
||||
|
@ -686,10 +686,10 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
result = df.withColumn("check_data", check_data(col("idx"), col("timestamp"),
|
||||
col("timestamp_copy"))).collect()
|
||||
# Check that collection values are correct
|
||||
self.assertEquals(len(data), len(result))
|
||||
self.assertEqual(len(data), len(result))
|
||||
for i in range(len(result)):
|
||||
self.assertEquals(data[i][1], result[i][1]) # "timestamp" col
|
||||
self.assertEquals(data[i][1], result[i][2]) # "timestamp_copy" col
|
||||
self.assertEqual(data[i][1], result[i][1]) # "timestamp" col
|
||||
self.assertEqual(data[i][1], result[i][2]) # "timestamp_copy" col
|
||||
self.assertIsNone(result[i][3]) # "check_data" col
|
||||
|
||||
def test_vectorized_udf_return_timestamp_tz(self):
|
||||
|
@ -713,7 +713,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
i, ts = r
|
||||
ts_tz = pd.Timestamp(i, unit='D', tz='America/Los_Angeles').to_pydatetime()
|
||||
expected = spark_ts_t.fromInternal(spark_ts_t.toInternal(ts_tz))
|
||||
self.assertEquals(expected, ts)
|
||||
self.assertEqual(expected, ts)
|
||||
|
||||
def test_vectorized_udf_check_config(self):
|
||||
with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}):
|
||||
|
@ -799,9 +799,9 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
for random_udf in [self.nondeterministic_vectorized_udf,
|
||||
self.nondeterministic_vectorized_iter_udf]:
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
|
||||
with self.assertRaisesRegex(AnalysisException, 'nondeterministic'):
|
||||
df.groupby(df.id).agg(sum(random_udf(df.id))).collect()
|
||||
with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
|
||||
with self.assertRaisesRegex(AnalysisException, 'nondeterministic'):
|
||||
df.agg(sum(random_udf(df.id))).collect()
|
||||
|
||||
def test_register_vectorized_udf_basic(self):
|
||||
|
@ -825,8 +825,8 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
res2 = self.spark.sql(
|
||||
"SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t")
|
||||
expected = df.select(expr('a + b'))
|
||||
self.assertEquals(expected.collect(), res1.collect())
|
||||
self.assertEquals(expected.collect(), res2.collect())
|
||||
self.assertEqual(expected.collect(), res1.collect())
|
||||
self.assertEqual(expected.collect(), res2.collect())
|
||||
|
||||
def test_scalar_iter_udf_init(self):
|
||||
import numpy as np
|
||||
|
@ -854,7 +854,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
finally:
|
||||
raise RuntimeError("reached finally block")
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(Exception, "reached finally block"):
|
||||
with self.assertRaisesRegex(Exception, "reached finally block"):
|
||||
self.spark.range(1).select(test_close(col("id"))).collect()
|
||||
|
||||
def test_scalar_iter_udf_close_early(self):
|
||||
|
@ -905,7 +905,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
|
||||
foo_udf = pandas_udf(lambda x: x, 'timestamp', udf_type)
|
||||
result = df.withColumn('time', foo_udf(df.time))
|
||||
self.assertEquals(df.collect(), result.collect())
|
||||
self.assertEqual(df.collect(), result.collect())
|
||||
|
||||
def test_udf_category_type(self):
|
||||
|
||||
|
@ -1003,11 +1003,11 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v']))))
|
||||
df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v']))))
|
||||
|
||||
self.assertEquals(expected_chained_1, df_chained_1.collect())
|
||||
self.assertEquals(expected_chained_2, df_chained_2.collect())
|
||||
self.assertEquals(expected_chained_3, df_chained_3.collect())
|
||||
self.assertEquals(expected_chained_4, df_chained_4.collect())
|
||||
self.assertEquals(expected_chained_5, df_chained_5.collect())
|
||||
self.assertEqual(expected_chained_1, df_chained_1.collect())
|
||||
self.assertEqual(expected_chained_2, df_chained_2.collect())
|
||||
self.assertEqual(expected_chained_3, df_chained_3.collect())
|
||||
self.assertEqual(expected_chained_4, df_chained_4.collect())
|
||||
self.assertEqual(expected_chained_5, df_chained_5.collect())
|
||||
|
||||
# Test multiple mixed UDF expressions in a single projection
|
||||
df_multi_1 = df \
|
||||
|
@ -1045,8 +1045,8 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
.withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \
|
||||
.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v'))))))
|
||||
|
||||
self.assertEquals(expected_multi, df_multi_1.collect())
|
||||
self.assertEquals(expected_multi, df_multi_2.collect())
|
||||
self.assertEqual(expected_multi, df_multi_1.collect())
|
||||
self.assertEqual(expected_multi, df_multi_2.collect())
|
||||
|
||||
def test_mixed_udf_and_sql(self):
|
||||
df = self.spark.range(0, 1).toDF('v')
|
||||
|
@ -1107,7 +1107,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
.withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \
|
||||
.withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
|
||||
|
||||
self.assertEquals(expected, df1.collect())
|
||||
self.assertEqual(expected, df1.collect())
|
||||
|
||||
# SPARK-24721
|
||||
@unittest.skipIf(not test_compiled, test_not_compiled_message) # type: ignore
|
||||
|
@ -1138,17 +1138,17 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
|
|||
for df in [filesource_df, datasource_df, datasource_v2_df]:
|
||||
result = df.withColumn('c', c1)
|
||||
expected = df.withColumn('c', lit(2))
|
||||
self.assertEquals(expected.collect(), result.collect())
|
||||
self.assertEqual(expected.collect(), result.collect())
|
||||
|
||||
for df in [filesource_df, datasource_df, datasource_v2_df]:
|
||||
result = df.withColumn('c', c2)
|
||||
expected = df.withColumn('c', col('i') + 1)
|
||||
self.assertEquals(expected.collect(), result.collect())
|
||||
self.assertEqual(expected.collect(), result.collect())
|
||||
|
||||
for df in [filesource_df, datasource_df, datasource_v2_df]:
|
||||
for f in [f1, f2]:
|
||||
result = df.filter(f)
|
||||
self.assertEquals(0, result.count())
|
||||
self.assertEqual(0, result.count())
|
||||
finally:
|
||||
shutil.rmtree(path)
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ from pyspark.sql import Row
|
|||
if have_pandas:
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from pandas.util.testing import assert_frame_equal
|
||||
from pandas.testing import assert_frame_equal
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
|
|
|
@ -26,7 +26,7 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarro
|
|||
from pyspark.testing.utils import QuietTest
|
||||
|
||||
if have_pandas:
|
||||
from pandas.util.testing import assert_frame_equal
|
||||
from pandas.testing import assert_frame_equal
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
|
@ -241,14 +241,14 @@ class WindowPandasUDFTests(ReusedSQLTestCase):
|
|||
|
||||
array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
|
||||
result1 = df.withColumn('v2', array_udf(df['v']).over(w))
|
||||
self.assertEquals(result1.first()['v2'], [1.0, 2.0])
|
||||
self.assertEqual(result1.first()['v2'], [1.0, 2.0])
|
||||
|
||||
def test_invalid_args(self):
|
||||
df = self.data
|
||||
w = self.unbounded_window
|
||||
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
AnalysisException,
|
||||
'.*not supported within a window function'):
|
||||
foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP)
|
||||
|
|
|
@ -180,7 +180,7 @@ class TypesTests(ReusedSQLTestCase):
|
|||
self.assertEqual(df.columns, ['col1', '_2'])
|
||||
|
||||
def test_infer_schema_fails(self):
|
||||
with self.assertRaisesRegexp(TypeError, 'field a'):
|
||||
with self.assertRaisesRegex(TypeError, 'field a'):
|
||||
self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]),
|
||||
schema=["a", "b"], samplingRatio=0.99)
|
||||
|
||||
|
@ -578,18 +578,18 @@ class TypesTests(ReusedSQLTestCase):
|
|||
ArrayType(LongType()),
|
||||
ArrayType(LongType())
|
||||
), ArrayType(LongType()))
|
||||
with self.assertRaisesRegexp(TypeError, 'element in array'):
|
||||
with self.assertRaisesRegex(TypeError, 'element in array'):
|
||||
_merge_type(ArrayType(LongType()), ArrayType(DoubleType()))
|
||||
|
||||
self.assertEqual(_merge_type(
|
||||
MapType(StringType(), LongType()),
|
||||
MapType(StringType(), LongType())
|
||||
), MapType(StringType(), LongType()))
|
||||
with self.assertRaisesRegexp(TypeError, 'key of map'):
|
||||
with self.assertRaisesRegex(TypeError, 'key of map'):
|
||||
_merge_type(
|
||||
MapType(StringType(), LongType()),
|
||||
MapType(DoubleType(), LongType()))
|
||||
with self.assertRaisesRegexp(TypeError, 'value of map'):
|
||||
with self.assertRaisesRegex(TypeError, 'value of map'):
|
||||
_merge_type(
|
||||
MapType(StringType(), LongType()),
|
||||
MapType(StringType(), DoubleType()))
|
||||
|
@ -598,7 +598,7 @@ class TypesTests(ReusedSQLTestCase):
|
|||
StructType([StructField("f1", LongType()), StructField("f2", StringType())]),
|
||||
StructType([StructField("f1", LongType()), StructField("f2", StringType())])
|
||||
), StructType([StructField("f1", LongType()), StructField("f2", StringType())]))
|
||||
with self.assertRaisesRegexp(TypeError, 'field f1'):
|
||||
with self.assertRaisesRegex(TypeError, 'field f1'):
|
||||
_merge_type(
|
||||
StructType([StructField("f1", LongType()), StructField("f2", StringType())]),
|
||||
StructType([StructField("f1", DoubleType()), StructField("f2", StringType())]))
|
||||
|
@ -607,7 +607,7 @@ class TypesTests(ReusedSQLTestCase):
|
|||
StructType([StructField("f1", StructType([StructField("f2", LongType())]))]),
|
||||
StructType([StructField("f1", StructType([StructField("f2", LongType())]))])
|
||||
), StructType([StructField("f1", StructType([StructField("f2", LongType())]))]))
|
||||
with self.assertRaisesRegexp(TypeError, 'field f2 in field f1'):
|
||||
with self.assertRaisesRegex(TypeError, 'field f2 in field f1'):
|
||||
_merge_type(
|
||||
StructType([StructField("f1", StructType([StructField("f2", LongType())]))]),
|
||||
StructType([StructField("f1", StructType([StructField("f2", StringType())]))]))
|
||||
|
@ -616,7 +616,7 @@ class TypesTests(ReusedSQLTestCase):
|
|||
StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]),
|
||||
StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())])
|
||||
), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]))
|
||||
with self.assertRaisesRegexp(TypeError, 'element in array field f1'):
|
||||
with self.assertRaisesRegex(TypeError, 'element in array field f1'):
|
||||
_merge_type(
|
||||
StructType([
|
||||
StructField("f1", ArrayType(LongType())),
|
||||
|
@ -635,7 +635,7 @@ class TypesTests(ReusedSQLTestCase):
|
|||
), StructType([
|
||||
StructField("f1", MapType(StringType(), LongType())),
|
||||
StructField("f2", StringType())]))
|
||||
with self.assertRaisesRegexp(TypeError, 'value of map field f1'):
|
||||
with self.assertRaisesRegex(TypeError, 'value of map field f1'):
|
||||
_merge_type(
|
||||
StructType([
|
||||
StructField("f1", MapType(StringType(), LongType())),
|
||||
|
@ -648,7 +648,7 @@ class TypesTests(ReusedSQLTestCase):
|
|||
StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]),
|
||||
StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])
|
||||
), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]))
|
||||
with self.assertRaisesRegexp(TypeError, 'key of map element in array field f1'):
|
||||
with self.assertRaisesRegex(TypeError, 'key of map element in array field f1'):
|
||||
_merge_type(
|
||||
StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]),
|
||||
StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))])
|
||||
|
@ -734,7 +734,7 @@ class TypesTests(ReusedSQLTestCase):
|
|||
unsupported_types = all_types - set(supported_types)
|
||||
# test unsupported types
|
||||
for t in unsupported_types:
|
||||
with self.assertRaisesRegexp(TypeError, "infer the type of the field myarray"):
|
||||
with self.assertRaisesRegex(TypeError, "infer the type of the field myarray"):
|
||||
a = array.array(t)
|
||||
self.spark.createDataFrame([Row(myarray=a)]).collect()
|
||||
|
||||
|
@ -789,13 +789,13 @@ class DataTypeTests(unittest.TestCase):
|
|||
class DataTypeVerificationTests(unittest.TestCase):
|
||||
|
||||
def test_verify_type_exception_msg(self):
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"test_name",
|
||||
lambda: _make_type_verifier(StringType(), nullable=False, name="test_name")(None))
|
||||
|
||||
schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))])
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"field b in field a",
|
||||
lambda: _make_type_verifier(schema)([["data"]]))
|
||||
|
|
|
@ -98,7 +98,7 @@ class UDFTests(ReusedSQLTestCase):
|
|||
|
||||
def test_udf_registration_return_type_not_none(self):
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(TypeError, "Invalid return type"):
|
||||
with self.assertRaisesRegex(TypeError, "Invalid return type"):
|
||||
self.spark.catalog.registerFunction(
|
||||
"f", UserDefinedFunction(lambda x, y: len(x) + y, StringType()), StringType())
|
||||
|
||||
|
@ -149,9 +149,9 @@ class UDFTests(ReusedSQLTestCase):
|
|||
df = self.spark.range(10)
|
||||
|
||||
with QuietTest(self.sc):
|
||||
with self.assertRaisesRegexp(AnalysisException, "nondeterministic"):
|
||||
with self.assertRaisesRegex(AnalysisException, "nondeterministic"):
|
||||
df.groupby('id').agg(sum(udf_random_col())).collect()
|
||||
with self.assertRaisesRegexp(AnalysisException, "nondeterministic"):
|
||||
with self.assertRaisesRegex(AnalysisException, "nondeterministic"):
|
||||
df.agg(sum(udf_random_col())).collect()
|
||||
|
||||
def test_chained_udf(self):
|
||||
|
@ -203,7 +203,7 @@ class UDFTests(ReusedSQLTestCase):
|
|||
# Cross join.
|
||||
df = left.join(right, f("a", "b"))
|
||||
with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
|
||||
with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
|
||||
with self.assertRaisesRegex(AnalysisException, 'Detected implicit cartesian product'):
|
||||
df.collect()
|
||||
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
|
||||
self.assertEqual(df.collect(), [Row(a=1, b=1)])
|
||||
|
@ -238,7 +238,7 @@ class UDFTests(ReusedSQLTestCase):
|
|||
f = udf(lambda a, b: a == b, BooleanType())
|
||||
|
||||
def runWithJoinType(join_type, type_string):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
AnalysisException,
|
||||
'Using PythonUDF.*%s is not supported.' % type_string):
|
||||
left.join(right, [f("a", "b"), left.a1 == right.b1], join_type).collect()
|
||||
|
@ -385,18 +385,18 @@ class UDFTests(ReusedSQLTestCase):
|
|||
|
||||
def test_non_existed_udf(self):
|
||||
spark = self.spark
|
||||
self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
|
||||
lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"))
|
||||
self.assertRaisesRegex(AnalysisException, "Can not load class non_existed_udf",
|
||||
lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"))
|
||||
|
||||
# This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias.
|
||||
sqlContext = spark._wrapped
|
||||
self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
|
||||
lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf"))
|
||||
self.assertRaisesRegex(AnalysisException, "Can not load class non_existed_udf",
|
||||
lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf"))
|
||||
|
||||
def test_non_existed_udaf(self):
|
||||
spark = self.spark
|
||||
self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf",
|
||||
lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))
|
||||
self.assertRaisesRegex(AnalysisException, "Can not load class non_existed_udaf",
|
||||
lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))
|
||||
|
||||
def test_udf_with_input_file_name(self):
|
||||
from pyspark.sql.functions import input_file_name
|
||||
|
@ -587,17 +587,17 @@ class UDFTests(ReusedSQLTestCase):
|
|||
for df in [filesource_df, datasource_df, datasource_v2_df]:
|
||||
result = df.withColumn('c', c1)
|
||||
expected = df.withColumn('c', lit(2))
|
||||
self.assertEquals(expected.collect(), result.collect())
|
||||
self.assertEqual(expected.collect(), result.collect())
|
||||
|
||||
for df in [filesource_df, datasource_df, datasource_v2_df]:
|
||||
result = df.withColumn('c', c2)
|
||||
expected = df.withColumn('c', col('i') + 1)
|
||||
self.assertEquals(expected.collect(), result.collect())
|
||||
self.assertEqual(expected.collect(), result.collect())
|
||||
|
||||
for df in [filesource_df, datasource_df, datasource_v2_df]:
|
||||
for f in [f1, f2]:
|
||||
result = df.filter(f)
|
||||
self.assertEquals(0, result.count())
|
||||
self.assertEqual(0, result.count())
|
||||
finally:
|
||||
shutil.rmtree(path)
|
||||
|
||||
|
|
|
@ -31,23 +31,22 @@ class UtilsTests(ReusedSQLTestCase):
|
|||
try:
|
||||
self.spark.sql("select `中文字段`")
|
||||
except AnalysisException as e:
|
||||
self.assertRegexpMatches(str(e), "cannot resolve '`中文字段`'")
|
||||
self.assertRegex(str(e), "cannot resolve '`中文字段`'")
|
||||
|
||||
def test_capture_parse_exception(self):
|
||||
self.assertRaises(ParseException, lambda: self.spark.sql("abc"))
|
||||
|
||||
def test_capture_illegalargument_exception(self):
|
||||
self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks",
|
||||
lambda: self.spark.sql("SET mapred.reduce.tasks=-1"))
|
||||
self.assertRaisesRegex(IllegalArgumentException, "Setting negative mapred.reduce.tasks",
|
||||
lambda: self.spark.sql("SET mapred.reduce.tasks=-1"))
|
||||
df = self.spark.createDataFrame([(1, 2)], ["a", "b"])
|
||||
self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values",
|
||||
lambda: df.select(sha2(df.a, 1024)).collect())
|
||||
self.assertRaisesRegex(IllegalArgumentException, "1024 is not in the permitted values",
|
||||
lambda: df.select(sha2(df.a, 1024)).collect())
|
||||
try:
|
||||
df.select(sha2(df.a, 1024)).collect()
|
||||
except IllegalArgumentException as e:
|
||||
self.assertRegexpMatches(e.desc, "1024 is not in the permitted values")
|
||||
self.assertRegexpMatches(e.stackTrace,
|
||||
"org.apache.spark.sql.functions")
|
||||
self.assertRegex(e.desc, "1024 is not in the permitted values")
|
||||
self.assertRegex(e.stackTrace, "org.apache.spark.sql.functions")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -85,11 +85,11 @@ class ProfilerTests2(unittest.TestCase):
|
|||
def test_profiler_disabled(self):
|
||||
sc = SparkContext(conf=SparkConf().set("spark.python.profile", "false"))
|
||||
try:
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"'spark.python.profile' configuration must be set",
|
||||
lambda: sc.show_profiles())
|
||||
self.assertRaisesRegexp(
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"'spark.python.profile' configuration must be set",
|
||||
lambda: sc.dump_profiles("/tmp/abc"))
|
||||
|
|
|
@ -733,25 +733,25 @@ class RDDTests(ReusedPySparkTestCase):
|
|||
keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))
|
||||
msg = "Caught StopIteration thrown from user's code; failing the task"
|
||||
|
||||
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect)
|
||||
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect)
|
||||
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
|
||||
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit)
|
||||
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit)
|
||||
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
|
||||
self.assertRaisesRegexp(Py4JJavaError, msg,
|
||||
seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
|
||||
self.assertRaisesRegex(Py4JJavaError, msg, seq_rdd.map(stopit).collect)
|
||||
self.assertRaisesRegex(Py4JJavaError, msg, seq_rdd.filter(stopit).collect)
|
||||
self.assertRaisesRegex(Py4JJavaError, msg, seq_rdd.foreach, stopit)
|
||||
self.assertRaisesRegex(Py4JJavaError, msg, seq_rdd.reduce, stopit)
|
||||
self.assertRaisesRegex(Py4JJavaError, msg, seq_rdd.fold, 0, stopit)
|
||||
self.assertRaisesRegex(Py4JJavaError, msg, seq_rdd.foreach, stopit)
|
||||
self.assertRaisesRegex(Py4JJavaError, msg,
|
||||
seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
|
||||
|
||||
# these methods call the user function both in the driver and in the executor
|
||||
# the exception raised is different according to where the StopIteration happens
|
||||
# RuntimeError is raised if in the driver
|
||||
# Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker)
|
||||
self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
|
||||
keyed_rdd.reduceByKeyLocally, stopit)
|
||||
self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
|
||||
seq_rdd.aggregate, 0, stopit, lambda *x: 1)
|
||||
self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
|
||||
seq_rdd.aggregate, 0, lambda *x: 1, stopit)
|
||||
self.assertRaisesRegex((Py4JJavaError, RuntimeError), msg,
|
||||
keyed_rdd.reduceByKeyLocally, stopit)
|
||||
self.assertRaisesRegex((Py4JJavaError, RuntimeError), msg,
|
||||
seq_rdd.aggregate, 0, stopit, lambda *x: 1)
|
||||
self.assertRaisesRegex((Py4JJavaError, RuntimeError), msg,
|
||||
seq_rdd.aggregate, 0, lambda *x: 1, stopit)
|
||||
|
||||
def test_overwritten_global_func(self):
|
||||
# Regression test for SPARK-27000
|
||||
|
@ -768,7 +768,7 @@ class RDDTests(ReusedPySparkTestCase):
|
|||
|
||||
rdd = self.sc.range(10).map(fail)
|
||||
|
||||
with self.assertRaisesRegexp(Exception, "local iterator error"):
|
||||
with self.assertRaisesRegex(Exception, "local iterator error"):
|
||||
for _ in rdd.toLocalIterator():
|
||||
pass
|
||||
|
||||
|
|
|
@ -165,7 +165,7 @@ class WorkerTests(ReusedPySparkTestCase):
|
|||
|
||||
self.sc.parallelize([1]).map(lambda x: f()).count()
|
||||
except Py4JJavaError as e:
|
||||
self.assertRegexpMatches(str(e), "exception with 中")
|
||||
self.assertRegex(str(e), "exception with 中")
|
||||
|
||||
|
||||
class WorkerReuseTest(PySparkTestCase):
|
||||
|
|
Loading…
Reference in a new issue