[SPARK-33613][PYTHON][TESTS] Replace deprecated APIs in pyspark tests

### What changes were proposed in this pull request?

This replaces deprecated API usage in the PySpark tests with the preferred APIs, chiefly the deprecated `unittest` aliases (`assertRaisesRegexp` → `assertRaisesRegex`, `assertEquals` → `assertEqual`) and the `pandas.util.testing` import path, which has moved to `pandas.testing`. These APIs have been deprecated for some time and their usage across the tests is inconsistent; a short before/after sketch follows the reference below.

- https://docs.python.org/3/library/unittest.html#deprecated-aliases
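
As a rough illustration (a hypothetical test class, not code from this PR), the mechanical renames look like this; the diff also switches `from pandas.util.testing import assert_frame_equal` to the public `pandas.testing` module:

```python
import unittest


class DeprecatedAliasExample(unittest.TestCase):
    def test_preferred_assertions(self):
        # Deprecated alias: self.assertEquals(2, 1 + 1)
        self.assertEqual(2, 1 + 1)

        # Deprecated alias: with self.assertRaisesRegexp(ValueError, "invalid literal"):
        with self.assertRaisesRegex(ValueError, "invalid literal"):
            int("not a number")


if __name__ == "__main__":
    unittest.main()
```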

### Why are the changes needed?

For consistency across the test suite, and to prepare for the eventual removal of these deprecated APIs.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing tests

Closes #30557 from BryanCutler/replace-deprecated-apis-in-tests.

Authored-by: Bryan Cutler <cutlerb@gmail.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
Bryan Cutler authored on 2020-12-01 10:34:40 +09:00; committed by HyukjinKwon
parent 596fbc1d29
commit aeb3649fb9
27 changed files with 274 additions and 275 deletions


@ -169,7 +169,7 @@ class FeatureTests(SparkSessionTestCase):
# Test an empty vocabulary
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"):
with self.assertRaisesRegex(Exception, "vocabSize.*invalid.*0"):
CountVectorizerModel.from_vocabulary([], inputCol="words")
# Test model with default settings can transform


@ -47,19 +47,19 @@ class ImageFileFormatTest(SparkSessionTestCase):
self.assertEqual(ImageSchema.undefinedImageType, "Undefined")
with QuietTest(self.sc):
self.assertRaisesRegexp(
self.assertRaisesRegex(
TypeError,
"image argument should be pyspark.sql.types.Row; however",
lambda: ImageSchema.toNDArray("a"))
with QuietTest(self.sc):
self.assertRaisesRegexp(
self.assertRaisesRegex(
ValueError,
"image argument should have attributes specified in",
lambda: ImageSchema.toNDArray(Row(a=1)))
with QuietTest(self.sc):
self.assertRaisesRegexp(
self.assertRaisesRegex(
TypeError,
"array argument should be numpy.ndarray; however, it got",
lambda: ImageSchema.toImage("a"))


@ -308,7 +308,7 @@ class ParamTests(SparkSessionTestCase):
LogisticRegression
)
self.assertRaisesRegexp(
self.assertRaisesRegex(
ValueError,
"Logistic Regression getThreshold found inconsistent.*$",
LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5]


@ -442,7 +442,7 @@ class PersistenceTest(SparkSessionTestCase):
del metadata['defaultParamMap']
metadataStr = json.dumps(metadata, separators=[',', ':'])
loadedMetadata = reader._parseMetaData(metadataStr, )
with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"):
with self.assertRaisesRegex(AssertionError, "`defaultParamMap` section not found"):
reader.getAndSetParams(lr, loadedMetadata)
# Prior to 2.4.0, metadata doesn't have `defaultParamMap`.


@ -499,7 +499,7 @@ class CrossValidatorTests(SparkSessionTestCase):
evaluator=evaluator,
numFolds=2,
foldCol="fold")
with self.assertRaisesRegexp(Exception, "Fold number must be in range"):
with self.assertRaisesRegex(Exception, "Fold number must be in range"):
cv.fit(dataset_with_folds)
cv = CrossValidator(estimator=lr,
@ -507,7 +507,7 @@ class CrossValidatorTests(SparkSessionTestCase):
evaluator=evaluator,
numFolds=4,
foldCol="fold")
with self.assertRaisesRegexp(Exception, "The validation data at fold 3 is empty"):
with self.assertRaisesRegex(Exception, "The validation data at fold 3 is empty"):
cv.fit(dataset_with_folds)


@ -54,7 +54,7 @@ class JavaWrapperMemoryTests(SparkSessionTestCase):
model.__del__()
def condition():
with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
with self.assertRaisesRegex(py4j.protocol.Py4JError, error_no_object):
model._java_obj.toString()
self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString())
return True
@ -67,9 +67,9 @@ class JavaWrapperMemoryTests(SparkSessionTestCase):
pass
def condition():
with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
with self.assertRaisesRegex(py4j.protocol.Py4JError, error_no_object):
model._java_obj.toString()
with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
with self.assertRaisesRegex(py4j.protocol.Py4JError, error_no_object):
summary._java_obj.toString()
return True


@ -34,7 +34,7 @@ from pyspark.testing.utils import QuietTest
if have_pandas:
import pandas as pd
from pandas.util.testing import assert_frame_equal
from pandas.testing import assert_frame_equal
if have_pyarrow:
import pyarrow as pa # noqa: F401
@ -137,7 +137,7 @@ class ArrowTests(ReusedSQLTestCase):
df = self.spark.createDataFrame([(None,)], schema=schema)
with QuietTest(self.sc):
with self.warnings_lock:
with self.assertRaisesRegexp(Exception, 'Unsupported type'):
with self.assertRaisesRegex(Exception, 'Unsupported type'):
df.toPandas()
def test_null_conversion(self):
@ -214,7 +214,7 @@ class ArrowTests(ReusedSQLTestCase):
exception_udf = udf(raise_exception, IntegerType())
df = df.withColumn("error", exception_udf())
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'My error'):
with self.assertRaisesRegex(Exception, 'My error'):
df.toPandas()
def _createDataFrame_toggle(self, pdf, schema=None):
@ -228,7 +228,7 @@ class ArrowTests(ReusedSQLTestCase):
def test_createDataFrame_toggle(self):
pdf = self.create_pandas_data_frame()
df_no_arrow, df_arrow = self._createDataFrame_toggle(pdf, schema=self.schema)
self.assertEquals(df_no_arrow.collect(), df_arrow.collect())
self.assertEqual(df_no_arrow.collect(), df_arrow.collect())
def test_createDataFrame_respect_session_timezone(self):
from datetime import timedelta
@ -258,7 +258,7 @@ class ArrowTests(ReusedSQLTestCase):
def test_createDataFrame_with_schema(self):
pdf = self.create_pandas_data_frame()
df = self.spark.createDataFrame(pdf, schema=self.schema)
self.assertEquals(self.schema, df.schema)
self.assertEqual(self.schema, df.schema)
pdf_arrow = df.toPandas()
assert_frame_equal(pdf_arrow, pdf)
@ -269,7 +269,7 @@ class ArrowTests(ReusedSQLTestCase):
wrong_schema = StructType(fields)
with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}):
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, "[D|d]ecimal.*got.*date"):
with self.assertRaisesRegex(Exception, "[D|d]ecimal.*got.*date"):
self.spark.createDataFrame(pdf, schema=wrong_schema)
def test_createDataFrame_with_names(self):
@ -277,23 +277,23 @@ class ArrowTests(ReusedSQLTestCase):
new_names = list(map(str, range(len(self.schema.fieldNames()))))
# Test that schema as a list of column names gets applied
df = self.spark.createDataFrame(pdf, schema=list(new_names))
self.assertEquals(df.schema.fieldNames(), new_names)
self.assertEqual(df.schema.fieldNames(), new_names)
# Test that schema as tuple of column names gets applied
df = self.spark.createDataFrame(pdf, schema=tuple(new_names))
self.assertEquals(df.schema.fieldNames(), new_names)
self.assertEqual(df.schema.fieldNames(), new_names)
def test_createDataFrame_column_name_encoding(self):
pdf = pd.DataFrame({u'a': [1]})
columns = self.spark.createDataFrame(pdf).columns
self.assertTrue(isinstance(columns[0], str))
self.assertEquals(columns[0], 'a')
self.assertEqual(columns[0], 'a')
columns = self.spark.createDataFrame(pdf, [u'b']).columns
self.assertTrue(isinstance(columns[0], str))
self.assertEquals(columns[0], 'b')
self.assertEqual(columns[0], 'b')
def test_createDataFrame_with_single_data_type(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(ValueError, ".*IntegerType.*not supported.*"):
with self.assertRaisesRegex(ValueError, ".*IntegerType.*not supported.*"):
self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")
def test_createDataFrame_does_not_modify_input(self):
@ -311,7 +311,7 @@ class ArrowTests(ReusedSQLTestCase):
from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema
arrow_schema = to_arrow_schema(self.schema)
schema_rt = from_arrow_schema(arrow_schema)
self.assertEquals(self.schema, schema_rt)
self.assertEqual(self.schema, schema_rt)
def test_createDataFrame_with_array_type(self):
pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]})
@ -420,7 +420,7 @@ class ArrowTests(ReusedSQLTestCase):
def test_createDataFrame_fallback_disabled(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(TypeError, 'Unsupported type'):
with self.assertRaisesRegex(TypeError, 'Unsupported type'):
self.spark.createDataFrame(
pd.DataFrame({"a": [[datetime.datetime(2015, 11, 1, 0, 30)]]}),
"a: array<timestamp>")
@ -545,7 +545,7 @@ class MaxResultArrowTests(unittest.TestCase):
cls.spark.stop()
def test_exception_by_max_results(self):
with self.assertRaisesRegexp(Exception, "is bigger than"):
with self.assertRaisesRegex(Exception, "is bigger than"):
self.spark.range(0, 10000, 1, 100).toPandas()


@ -25,11 +25,11 @@ class CatalogTests(ReusedSQLTestCase):
def test_current_database(self):
spark = self.spark
with self.database("some_db"):
self.assertEquals(spark.catalog.currentDatabase(), "default")
self.assertEqual(spark.catalog.currentDatabase(), "default")
spark.sql("CREATE DATABASE some_db")
spark.catalog.setCurrentDatabase("some_db")
self.assertEquals(spark.catalog.currentDatabase(), "some_db")
self.assertRaisesRegexp(
self.assertEqual(spark.catalog.currentDatabase(), "some_db")
self.assertRaisesRegex(
AnalysisException,
"does_not_exist",
lambda: spark.catalog.setCurrentDatabase("does_not_exist"))
@ -38,10 +38,10 @@ class CatalogTests(ReusedSQLTestCase):
spark = self.spark
with self.database("some_db"):
databases = [db.name for db in spark.catalog.listDatabases()]
self.assertEquals(databases, ["default"])
self.assertEqual(databases, ["default"])
spark.sql("CREATE DATABASE some_db")
databases = [db.name for db in spark.catalog.listDatabases()]
self.assertEquals(sorted(databases), ["default", "some_db"])
self.assertEqual(sorted(databases), ["default", "some_db"])
def test_list_tables(self):
from pyspark.sql.catalog import Table
@ -50,8 +50,8 @@ class CatalogTests(ReusedSQLTestCase):
spark.sql("CREATE DATABASE some_db")
with self.table("tab1", "some_db.tab2", "tab3_via_catalog"):
with self.tempView("temp_tab"):
self.assertEquals(spark.catalog.listTables(), [])
self.assertEquals(spark.catalog.listTables("some_db"), [])
self.assertEqual(spark.catalog.listTables(), [])
self.assertEqual(spark.catalog.listTables("some_db"), [])
spark.createDataFrame([(1, 1)]).createOrReplaceTempView("temp_tab")
spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet")
@ -66,40 +66,40 @@ class CatalogTests(ReusedSQLTestCase):
sorted(spark.catalog.listTables("default"), key=lambda t: t.name)
tablesSomeDb = \
sorted(spark.catalog.listTables("some_db"), key=lambda t: t.name)
self.assertEquals(tables, tablesDefault)
self.assertEquals(len(tables), 3)
self.assertEquals(len(tablesSomeDb), 2)
self.assertEquals(tables[0], Table(
self.assertEqual(tables, tablesDefault)
self.assertEqual(len(tables), 3)
self.assertEqual(len(tablesSomeDb), 2)
self.assertEqual(tables[0], Table(
name="tab1",
database="default",
description=None,
tableType="MANAGED",
isTemporary=False))
self.assertEquals(tables[1], Table(
self.assertEqual(tables[1], Table(
name="tab3_via_catalog",
database="default",
description=description,
tableType="MANAGED",
isTemporary=False))
self.assertEquals(tables[2], Table(
self.assertEqual(tables[2], Table(
name="temp_tab",
database=None,
description=None,
tableType="TEMPORARY",
isTemporary=True))
self.assertEquals(tablesSomeDb[0], Table(
self.assertEqual(tablesSomeDb[0], Table(
name="tab2",
database="some_db",
description=None,
tableType="MANAGED",
isTemporary=False))
self.assertEquals(tablesSomeDb[1], Table(
self.assertEqual(tablesSomeDb[1], Table(
name="temp_tab",
database=None,
description=None,
tableType="TEMPORARY",
isTemporary=True))
self.assertRaisesRegexp(
self.assertRaisesRegex(
AnalysisException,
"does_not_exist",
lambda: spark.catalog.listTables("does_not_exist"))
@ -119,12 +119,12 @@ class CatalogTests(ReusedSQLTestCase):
self.assertTrue("to_timestamp" in functions)
self.assertTrue("to_unix_timestamp" in functions)
self.assertTrue("current_database" in functions)
self.assertEquals(functions["+"], Function(
self.assertEqual(functions["+"], Function(
name="+",
description=None,
className="org.apache.spark.sql.catalyst.expressions.Add",
isTemporary=True))
self.assertEquals(functions, functionsDefault)
self.assertEqual(functions, functionsDefault)
with self.function("func1", "some_db.func2"):
spark.catalog.registerFunction("temp_func", lambda x: str(x))
@ -141,7 +141,7 @@ class CatalogTests(ReusedSQLTestCase):
self.assertTrue("temp_func" in newFunctionsSomeDb)
self.assertTrue("func1" not in newFunctionsSomeDb)
self.assertTrue("func2" in newFunctionsSomeDb)
self.assertRaisesRegexp(
self.assertRaisesRegex(
AnalysisException,
"does_not_exist",
lambda: spark.catalog.listFunctions("does_not_exist"))
@ -158,16 +158,16 @@ class CatalogTests(ReusedSQLTestCase):
columns = sorted(spark.catalog.listColumns("tab1"), key=lambda c: c.name)
columnsDefault = \
sorted(spark.catalog.listColumns("tab1", "default"), key=lambda c: c.name)
self.assertEquals(columns, columnsDefault)
self.assertEquals(len(columns), 2)
self.assertEquals(columns[0], Column(
self.assertEqual(columns, columnsDefault)
self.assertEqual(len(columns), 2)
self.assertEqual(columns[0], Column(
name="age",
description=None,
dataType="int",
nullable=True,
isPartition=False,
isBucket=False))
self.assertEquals(columns[1], Column(
self.assertEqual(columns[1], Column(
name="name",
description=None,
dataType="string",
@ -176,26 +176,26 @@ class CatalogTests(ReusedSQLTestCase):
isBucket=False))
columns2 = \
sorted(spark.catalog.listColumns("tab2", "some_db"), key=lambda c: c.name)
self.assertEquals(len(columns2), 2)
self.assertEquals(columns2[0], Column(
self.assertEqual(len(columns2), 2)
self.assertEqual(columns2[0], Column(
name="nickname",
description=None,
dataType="string",
nullable=True,
isPartition=False,
isBucket=False))
self.assertEquals(columns2[1], Column(
self.assertEqual(columns2[1], Column(
name="tolerance",
description=None,
dataType="float",
nullable=True,
isPartition=False,
isBucket=False))
self.assertRaisesRegexp(
self.assertRaisesRegex(
AnalysisException,
"tab2",
lambda: spark.catalog.listColumns("tab2"))
self.assertRaisesRegexp(
self.assertRaisesRegex(
AnalysisException,
"does_not_exist",
lambda: spark.catalog.listColumns("does_not_exist"))


@ -47,7 +47,7 @@ class ColumnTests(ReusedSQLTestCase):
self.assertTrue("Column" in _to_java_column(u"a").getClass().toString())
self.assertTrue("Column" in _to_java_column(self.spark.range(1).id).getClass().toString())
self.assertRaisesRegexp(
self.assertRaisesRegex(
TypeError,
"Invalid argument, not a string or column",
lambda: _to_java_column(1))
@ -58,7 +58,7 @@ class ColumnTests(ReusedSQLTestCase):
self.assertRaises(TypeError, lambda: _to_java_column(A()))
self.assertRaises(TypeError, lambda: _to_java_column([]))
self.assertRaisesRegexp(
self.assertRaisesRegex(
TypeError,
"Invalid argument, not a string or column",
lambda: udf(lambda x: x)(None))
@ -79,9 +79,9 @@ class ColumnTests(ReusedSQLTestCase):
cs.startswith('a'), cs.endswith('a'), ci.eqNullSafe(cs)
self.assertTrue(all(isinstance(c, Column) for c in css))
self.assertTrue(isinstance(ci.cast(LongType()), Column))
self.assertRaisesRegexp(ValueError,
"Cannot apply 'in' operator against a column",
lambda: 1 in cs)
self.assertRaisesRegex(ValueError,
"Cannot apply 'in' operator against a column",
lambda: 1 in cs)
def test_column_accessor(self):
from pyspark.sql.functions import col


@ -28,7 +28,7 @@ class ConfTests(ReusedSQLTestCase):
self.assertEqual(spark.conf.get("bogo"), "ta")
self.assertEqual(spark.conf.get("bogo", "not.read"), "ta")
self.assertEqual(spark.conf.get("not.set", "ta"), "ta")
self.assertRaisesRegexp(Exception, "not.set", lambda: spark.conf.get("not.set"))
self.assertRaisesRegex(Exception, "not.set", lambda: spark.conf.get("not.set"))
spark.conf.unset("bogo")
self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia")


@ -343,7 +343,7 @@ class DataFrameTests(ReusedSQLTestCase):
self.spark.createDataFrame(
[(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first()
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
TypeError,
'value argument is required when to_replace is not a dictionary.'):
self.spark.createDataFrame(
@ -390,7 +390,7 @@ class DataFrameTests(ReusedSQLTestCase):
self.assertEqual(3, logical_plan.toString().count("itworks"))
def test_sample(self):
self.assertRaisesRegexp(
self.assertRaisesRegex(
TypeError,
"should be a bool, float and number",
lambda: self.spark.range(1).sample())
@ -426,12 +426,12 @@ class DataFrameTests(ReusedSQLTestCase):
self.assertEqual(df.collect(), data)
# number of fields must match.
self.assertRaisesRegexp(Exception, "Length of object",
lambda: rdd.toDF("key: int").collect())
self.assertRaisesRegex(Exception, "Length of object",
lambda: rdd.toDF("key: int").collect())
# field types mismatch will cause exception at runtime.
self.assertRaisesRegexp(Exception, "FloatType can not accept",
lambda: rdd.toDF("key: float, value: string").collect())
self.assertRaisesRegex(Exception, "FloatType can not accept",
lambda: rdd.toDF("key: float, value: string").collect())
# flat schema values will be wrapped into row.
df = rdd.map(lambda row: row.key).toDF("int")
@ -491,15 +491,15 @@ class DataFrameTests(ReusedSQLTestCase):
spark.catalog.clearCache()
self.assertFalse(spark.catalog.isCached("tab1"))
self.assertFalse(spark.catalog.isCached("tab2"))
self.assertRaisesRegexp(
self.assertRaisesRegex(
AnalysisException,
"does_not_exist",
lambda: spark.catalog.isCached("does_not_exist"))
self.assertRaisesRegexp(
self.assertRaisesRegex(
AnalysisException,
"does_not_exist",
lambda: spark.catalog.cacheTable("does_not_exist"))
self.assertRaisesRegexp(
self.assertRaisesRegex(
AnalysisException,
"does_not_exist",
lambda: spark.catalog.uncacheTable("does_not_exist"))
@ -523,12 +523,12 @@ class DataFrameTests(ReusedSQLTestCase):
import numpy as np
pdf = self._to_pandas()
types = pdf.dtypes
self.assertEquals(types[0], np.int32)
self.assertEquals(types[1], np.object)
self.assertEquals(types[2], np.bool)
self.assertEquals(types[3], np.float32)
self.assertEquals(types[4], np.object) # datetime.date
self.assertEquals(types[5], 'datetime64[ns]')
self.assertEqual(types[0], np.int32)
self.assertEqual(types[1], np.object)
self.assertEqual(types[2], np.bool)
self.assertEqual(types[3], np.float32)
self.assertEqual(types[4], np.object) # datetime.date
self.assertEqual(types[5], 'datetime64[ns]')
@unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore
def test_to_pandas_with_duplicated_column_names(self):
@ -540,8 +540,8 @@ class DataFrameTests(ReusedSQLTestCase):
df = self.spark.sql(sql)
pdf = df.toPandas()
types = pdf.dtypes
self.assertEquals(types.iloc[0], np.int32)
self.assertEquals(types.iloc[1], np.int32)
self.assertEqual(types.iloc[0], np.int32)
self.assertEqual(types.iloc[1], np.int32)
@unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore
def test_to_pandas_on_cross_join(self):
@ -560,13 +560,13 @@ class DataFrameTests(ReusedSQLTestCase):
df = self.spark.sql(sql)
pdf = df.toPandas()
types = pdf.dtypes
self.assertEquals(types.iloc[0], np.int32)
self.assertEquals(types.iloc[1], np.int32)
self.assertEqual(types.iloc[0], np.int32)
self.assertEqual(types.iloc[1], np.int32)
@unittest.skipIf(have_pandas, "Required Pandas was found.")
def test_to_pandas_required_pandas_not_found(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
with self.assertRaisesRegex(ImportError, 'Pandas >= .* must be installed'):
self._to_pandas()
@unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore
@ -577,9 +577,9 @@ class DataFrameTests(ReusedSQLTestCase):
data = [(1, "foo", 16777220), (None, "bar", None)]
df = self.spark.createDataFrame(data, schema)
types = df.toPandas().dtypes
self.assertEquals(types[0], np.float64) # doesn't convert to np.int32 due to NaN value.
self.assertEquals(types[1], np.object)
self.assertEquals(types[2], np.float64)
self.assertEqual(types[0], np.float64) # doesn't convert to np.int32 due to NaN value.
self.assertEqual(types[1], np.object)
self.assertEqual(types[2], np.float64)
@unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore
def test_to_pandas_from_empty_dataframe(self):
@ -675,7 +675,7 @@ class DataFrameTests(ReusedSQLTestCase):
@unittest.skipIf(have_pandas, "Required Pandas was found.")
def test_create_dataframe_required_pandas_not_found(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
ImportError,
"(Pandas >= .* must be installed|No module named '?pandas'?)"):
import pandas as pd
@ -688,7 +688,7 @@ class DataFrameTests(ReusedSQLTestCase):
@unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore
def test_create_dataframe_from_pandas_with_dst(self):
import pandas as pd
from pandas.util.testing import assert_frame_equal
from pandas.testing import assert_frame_equal
from datetime import datetime
pdf = pd.DataFrame({'time': [datetime(2015, 10, 31, 22, 30)]})
@ -724,7 +724,7 @@ class DataFrameTests(ReusedSQLTestCase):
||22222|22222|
|+-----+-----+
|"""
self.assertEquals(re.sub(pattern, '', expected1), df.__repr__())
self.assertEqual(re.sub(pattern, '', expected1), df.__repr__())
with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
expected2 = """+---+-----+
||key|value|
@ -733,7 +733,7 @@ class DataFrameTests(ReusedSQLTestCase):
||222| 222|
|+---+-----+
|"""
self.assertEquals(re.sub(pattern, '', expected2), df.__repr__())
self.assertEqual(re.sub(pattern, '', expected2), df.__repr__())
with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
expected3 = """+---+-----+
||key|value|
@ -742,7 +742,7 @@ class DataFrameTests(ReusedSQLTestCase):
|+---+-----+
|only showing top 1 row
|"""
self.assertEquals(re.sub(pattern, '', expected3), df.__repr__())
self.assertEqual(re.sub(pattern, '', expected3), df.__repr__())
# test when eager evaluation is enabled and _repr_html_ will be called
with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
@ -752,7 +752,7 @@ class DataFrameTests(ReusedSQLTestCase):
|<tr><td>22222</td><td>22222</td></tr>
|</table>
|"""
self.assertEquals(re.sub(pattern, '', expected1), df._repr_html_())
self.assertEqual(re.sub(pattern, '', expected1), df._repr_html_())
with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
expected2 = """<table border='1'>
|<tr><th>key</th><th>value</th></tr>
@ -760,7 +760,7 @@ class DataFrameTests(ReusedSQLTestCase):
|<tr><td>222</td><td>222</td></tr>
|</table>
|"""
self.assertEquals(re.sub(pattern, '', expected2), df._repr_html_())
self.assertEqual(re.sub(pattern, '', expected2), df._repr_html_())
with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
expected3 = """<table border='1'>
|<tr><th>key</th><th>value</th></tr>
@ -768,19 +768,19 @@ class DataFrameTests(ReusedSQLTestCase):
|</table>
|only showing top 1 row
|"""
self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_())
self.assertEqual(re.sub(pattern, '', expected3), df._repr_html_())
# test when eager evaluation is disabled and _repr_html_ will be called
with self.sql_conf({"spark.sql.repl.eagerEval.enabled": False}):
expected = "DataFrame[key: bigint, value: string]"
self.assertEquals(None, df._repr_html_())
self.assertEquals(expected, df.__repr__())
self.assertEqual(None, df._repr_html_())
self.assertEqual(expected, df.__repr__())
with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
self.assertEquals(None, df._repr_html_())
self.assertEquals(expected, df.__repr__())
self.assertEqual(None, df._repr_html_())
self.assertEqual(expected, df.__repr__())
with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
self.assertEquals(None, df._repr_html_())
self.assertEquals(expected, df.__repr__())
self.assertEqual(None, df._repr_html_())
self.assertEqual(expected, df.__repr__())
def test_to_local_iterator(self):
df = self.spark.range(8, numPartitions=4)
@ -818,7 +818,7 @@ class DataFrameTests(ReusedSQLTestCase):
def test_same_semantics_error(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(ValueError, "should be of DataFrame.*int"):
with self.assertRaisesRegex(ValueError, "should be of DataFrame.*int"):
self.spark.range(10).sameSemantics(1)
def test_input_files(self):
@ -830,7 +830,7 @@ class DataFrameTests(ReusedSQLTestCase):
input_files_list = self.spark.read.parquet(tpath).inputFiles()
# input files list should contain 10 entries
self.assertEquals(len(input_files_list), 10)
self.assertEqual(len(input_files_list), 10)
# all file paths in list must contain tpath
for file_path in input_files_list:
self.assertTrue(tpath in file_path)


@ -107,7 +107,7 @@ class DataSourcesTests(ReusedSQLTestCase):
df = self.spark.read.text(['python/test_support/sql/text-test.txt',
'python/test_support/sql/text-test.txt'])
count = df.count()
self.assertEquals(count, 4)
self.assertEqual(count, 4)
def test_json_sampling_ratio(self):
rdd = self.spark.sparkContext.range(0, 100, 1, 1) \
@ -115,14 +115,14 @@ class DataSourcesTests(ReusedSQLTestCase):
schema = self.spark.read.option('inferSchema', True) \
.option('samplingRatio', 0.5) \
.json(rdd).schema
self.assertEquals(schema, StructType([StructField("a", LongType(), True)]))
self.assertEqual(schema, StructType([StructField("a", LongType(), True)]))
def test_csv_sampling_ratio(self):
rdd = self.spark.sparkContext.range(0, 100, 1, 1) \
.map(lambda x: '0.1' if x == 1 else str(x))
schema = self.spark.read.option('inferSchema', True)\
.csv(rdd, samplingRatio=0.5).schema
self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)]))
self.assertEqual(schema, StructType([StructField("_c0", IntegerType(), True)]))
def test_checking_csv_header(self):
path = tempfile.mkdtemp()
@ -135,7 +135,7 @@ class DataSourcesTests(ReusedSQLTestCase):
StructField('f1', IntegerType(), nullable=True)])
df = self.spark.read.option('header', 'true').schema(schema)\
.csv(path, enforceSchema=False)
self.assertRaisesRegexp(
self.assertRaisesRegex(
Exception,
"CSV header does not conform to the schema",
lambda: df.collect())
@ -154,7 +154,7 @@ class DataSourcesTests(ReusedSQLTestCase):
StructField('b', LongType(), nullable=True),
StructField('c', StringType(), nullable=True)])
readback = self.spark.read.json(path, dropFieldIfAllNull=True)
self.assertEquals(readback.schema, schema)
self.assertEqual(readback.schema, schema)
finally:
shutil.rmtree(path)


@ -185,7 +185,7 @@ class FunctionsTests(ReusedSQLTestCase):
]
df = self.spark.createDataFrame([['nick']], schema=['name'])
self.assertRaisesRegexp(
self.assertRaisesRegex(
TypeError,
"must be the same type",
lambda: df.select(col('name').substr(0, lit(1))))
@ -321,16 +321,16 @@ class FunctionsTests(ReusedSQLTestCase):
df = self.spark.createDataFrame(
[('Tom', 80), (None, 60), ('Alice', 50)], ["name", "height"])
self.assertEquals(
self.assertEqual(
df.select(df.name).orderBy(functions.asc_nulls_first('name')).collect(),
[Row(name=None), Row(name=u'Alice'), Row(name=u'Tom')])
self.assertEquals(
self.assertEqual(
df.select(df.name).orderBy(functions.asc_nulls_last('name')).collect(),
[Row(name=u'Alice'), Row(name=u'Tom'), Row(name=None)])
self.assertEquals(
self.assertEqual(
df.select(df.name).orderBy(functions.desc_nulls_first('name')).collect(),
[Row(name=None), Row(name=u'Tom'), Row(name=u'Alice')])
self.assertEquals(
self.assertEqual(
df.select(df.name).orderBy(functions.desc_nulls_last('name')).collect(),
[Row(name=u'Tom'), Row(name=u'Alice'), Row(name=None)])
@ -354,7 +354,7 @@ class FunctionsTests(ReusedSQLTestCase):
df = self.spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x'])
self.assertEquals(
self.assertEqual(
df.select(slice(df.x, 2, 2).alias("sliced")).collect(),
df.select(slice(df.x, lit(2), lit(2)).alias("sliced")).collect(),
)
@ -364,7 +364,7 @@ class FunctionsTests(ReusedSQLTestCase):
df = self.spark.range(1)
self.assertEquals(
self.assertEqual(
df.select(array_repeat("id", 3)).toDF("val").collect(),
df.select(array_repeat("id", lit(3))).toDF("val").collect(),
)
@ -580,14 +580,14 @@ class FunctionsTests(ReusedSQLTestCase):
from datetime import date
df = self.spark.range(1).selectExpr("'2017-01-22' as dateCol")
parse_result = df.select(functions.to_date(functions.col("dateCol"))).first()
self.assertEquals(date(2017, 1, 22), parse_result['to_date(dateCol)'])
self.assertEqual(date(2017, 1, 22), parse_result['to_date(dateCol)'])
def test_assert_true(self):
from pyspark.sql.functions import assert_true
df = self.spark.range(3)
self.assertEquals(
self.assertEqual(
df.select(assert_true(df.id < 3)).toDF("val").collect(),
[Row(val=None), Row(val=None), Row(val=None)],
)
@ -604,7 +604,7 @@ class FunctionsTests(ReusedSQLTestCase):
with self.assertRaises(TypeError) as cm:
df.select(assert_true(df.id < 2, 5))
self.assertEquals(
self.assertEqual(
"errMsg should be a Column or a str, got <class 'int'>",
str(cm.exception)
)
@ -626,7 +626,7 @@ class FunctionsTests(ReusedSQLTestCase):
with self.assertRaises(TypeError) as cm:
df.select(raise_error(None))
self.assertEquals(
self.assertEqual(
"errMsg should be a Column or a str, got <class 'NoneType'>",
str(cm.exception)
)


@ -25,7 +25,7 @@ from pyspark.testing.utils import QuietTest
if have_pandas:
import pandas as pd
from pandas.util.testing import assert_frame_equal
from pandas.testing import assert_frame_equal
if have_pyarrow:
import pyarrow as pa # noqa: F401
@ -135,8 +135,8 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase):
.applyInPandas(lambda x, y: pd.DataFrame([(x.sum().sum(), y.sum().sum())]),
'sum1 int, sum2 int').collect()
self.assertEquals(result[0]['sum1'], 165)
self.assertEquals(result[0]['sum2'], 165)
self.assertEqual(result[0]['sum1'], 165)
self.assertEqual(result[0]['sum2'], 165)
def test_with_key_left(self):
self._test_with_key(self.data1, self.data1, isLeft=True)
@ -174,7 +174,7 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase):
left = self.data1
right = self.data2
with QuietTest(self.sc):
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
NotImplementedError,
'Invalid return type.*ArrayType.*TimestampType'):
left.groupby('id').cogroup(right.groupby('id')).applyInPandas(
@ -183,7 +183,7 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase):
def test_wrong_args(self):
left = self.data1
right = self.data2
with self.assertRaisesRegexp(ValueError, 'Invalid function'):
with self.assertRaisesRegex(ValueError, 'Invalid function'):
left.groupby('id').cogroup(right.groupby('id')) \
.applyInPandas(lambda: 1, StructType([StructField("d", DoubleType())]))
@ -194,14 +194,14 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase):
row = df1.groupby("ColUmn").cogroup(
df1.groupby("COLUMN")
).applyInPandas(lambda r, l: r + l, "column long, value long").first()
self.assertEquals(row.asDict(), Row(column=2, value=2).asDict())
self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())
df2 = self.spark.createDataFrame([(1, 1)], ("column", "value"))
row = df1.groupby("ColUmn").cogroup(
df2.groupby("COLUMN")
).applyInPandas(lambda r, l: r + l, "column long, value long").first()
self.assertEquals(row.asDict(), Row(column=2, value=2).asDict())
self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())
@staticmethod
def _test_with_key(left, right, isLeft):


@ -33,7 +33,7 @@ from pyspark.testing.utils import QuietTest
if have_pandas:
import pandas as pd
from pandas.util.testing import assert_frame_equal
from pandas.testing import assert_frame_equal
if have_pyarrow:
import pyarrow as pa # noqa: F401
@ -160,7 +160,7 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
def test_register_grouped_map_udf(self):
foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP)
with QuietTest(self.sc):
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
ValueError,
'f.*SQL_BATCHED_UDF.*SQL_SCALAR_PANDAS_UDF.*SQL_GROUPED_AGG_PANDAS_UDF.*'):
self.spark.catalog.registerFunction("foo_udf", foo_udf)
@ -244,7 +244,7 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
def test_wrong_return_type(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
NotImplementedError,
'Invalid return type.*grouped map Pandas UDF.*ArrayType.*TimestampType'):
pandas_udf(
@ -256,20 +256,20 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
df = self.data
with QuietTest(self.sc):
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
with self.assertRaisesRegex(ValueError, 'Invalid udf'):
df.groupby('id').apply(lambda x: x)
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
with self.assertRaisesRegex(ValueError, 'Invalid udf'):
df.groupby('id').apply(udf(lambda x: x, DoubleType()))
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
with self.assertRaisesRegex(ValueError, 'Invalid udf'):
df.groupby('id').apply(sum(df.v))
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
with self.assertRaisesRegex(ValueError, 'Invalid udf'):
df.groupby('id').apply(df.v + 1)
with self.assertRaisesRegexp(ValueError, 'Invalid function'):
with self.assertRaisesRegex(ValueError, 'Invalid function'):
df.groupby('id').apply(
pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())])))
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
with self.assertRaisesRegex(ValueError, 'Invalid udf'):
df.groupby('id').apply(pandas_udf(lambda x, y: x, DoubleType()))
with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUPED_MAP'):
with self.assertRaisesRegex(ValueError, 'Invalid udf.*GROUPED_MAP'):
df.groupby('id').apply(
pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR))
@ -284,7 +284,7 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
for unsupported_type in unsupported_types:
schema = StructType([StructField('id', LongType(), True), unsupported_type])
with QuietTest(self.sc):
with self.assertRaisesRegexp(NotImplementedError, common_err_msg):
with self.assertRaisesRegex(NotImplementedError, common_err_msg):
pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)
# Regression test for SPARK-23314
@ -451,9 +451,9 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}):
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, "KeyError: 'id'"):
with self.assertRaisesRegex(Exception, "KeyError: 'id'"):
grouped_df.apply(column_name_typo).collect()
with self.assertRaisesRegexp(Exception, "[D|d]ecimal.*got.*date"):
with self.assertRaisesRegex(Exception, "[D|d]ecimal.*got.*date"):
grouped_df.apply(invalid_positional_types).collect()
def test_positional_assignment_conf(self):
@ -482,7 +482,7 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
# this was throwing an AnalysisException before SPARK-24208
res = df_with_pandas.alias('temp0').join(df_with_pandas.alias('temp1'),
col('temp0.key') == col('temp1.key'))
self.assertEquals(res.count(), 5)
self.assertEqual(res.count(), 5)
def test_mixed_scalar_udfs_followed_by_groupby_apply(self):
df = self.spark.range(0, 10).toDF('v1')
@ -494,7 +494,7 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
'sum int',
PandasUDFType.GROUPED_MAP))
self.assertEquals(result.collect()[0]['sum'], 165)
self.assertEqual(result.collect()[0]['sum'], 165)
def test_grouped_with_empty_partition(self):
data = [Row(id=1, x=2), Row(id=1, x=3), Row(id=2, x=4)]
@ -604,7 +604,7 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
df = self.spark.createDataFrame([[1, 1]], ["column", "score"])
row = df.groupby('COLUMN').applyInPandas(
my_pandas_udf, schema="column integer, score float").first()
self.assertEquals(row.asDict(), Row(column=1, score=0.5).asDict())
self.assertEqual(row.asDict(), Row(column=1, score=0.5).asDict())
if __name__ == "__main__":


@ -61,7 +61,7 @@ class MapInPandasTests(ReusedSQLTestCase):
df = self.spark.range(10)
actual = df.mapInPandas(func, 'id long').collect()
expected = df.collect()
self.assertEquals(actual, expected)
self.assertEqual(actual, expected)
def test_multiple_columns(self):
data = [(1, "foo"), (2, None), (3, "bar"), (4, "bar")]
@ -75,7 +75,7 @@ class MapInPandasTests(ReusedSQLTestCase):
actual = df.mapInPandas(func, df.schema).collect()
expected = df.collect()
self.assertEquals(actual, expected)
self.assertEqual(actual, expected)
def test_different_output_length(self):
def func(iterator):
@ -84,7 +84,7 @@ class MapInPandasTests(ReusedSQLTestCase):
df = self.spark.range(10)
actual = df.repartition(1).mapInPandas(func, 'a long').collect()
self.assertEquals(set((r.a for r in actual)), set(range(100)))
self.assertEqual(set((r.a for r in actual)), set(range(100)))
def test_empty_iterator(self):
def empty_iter(_):
@ -110,7 +110,7 @@ class MapInPandasTests(ReusedSQLTestCase):
df = self.spark.range(10)
actual = df.mapInPandas(func, 'id long').mapInPandas(func, 'id long').collect()
expected = df.collect()
self.assertEquals(actual, expected)
self.assertEqual(actual, expected)
if __name__ == "__main__":


@ -114,31 +114,31 @@ class PandasUDFTests(ReusedSQLTestCase):
@pandas_udf('blah')
def foo(x):
return x
with self.assertRaisesRegexp(ValueError, 'Invalid return type.*None'):
with self.assertRaisesRegex(ValueError, 'Invalid return type.*None'):
@pandas_udf(functionType=PandasUDFType.SCALAR)
def foo(x):
return x
with self.assertRaisesRegexp(ValueError, 'Invalid function'):
with self.assertRaisesRegex(ValueError, 'Invalid function'):
@pandas_udf('double', 100)
def foo(x):
return x
with self.assertRaisesRegexp(ValueError, '0-arg pandas_udfs.*not.*supported'):
with self.assertRaisesRegex(ValueError, '0-arg pandas_udfs.*not.*supported'):
pandas_udf(lambda: 1, LongType(), PandasUDFType.SCALAR)
with self.assertRaisesRegexp(ValueError, '0-arg pandas_udfs.*not.*supported'):
with self.assertRaisesRegex(ValueError, '0-arg pandas_udfs.*not.*supported'):
@pandas_udf(LongType(), PandasUDFType.SCALAR)
def zero_with_type():
return 1
with self.assertRaisesRegexp(TypeError, 'Invalid return type'):
with self.assertRaisesRegex(TypeError, 'Invalid return type'):
@pandas_udf(returnType=PandasUDFType.GROUPED_MAP)
def foo(df):
return df
with self.assertRaisesRegexp(TypeError, 'Invalid return type'):
with self.assertRaisesRegex(TypeError, 'Invalid return type'):
@pandas_udf(returnType='double', functionType=PandasUDFType.GROUPED_MAP)
def foo(df):
return df
with self.assertRaisesRegexp(ValueError, 'Invalid function'):
with self.assertRaisesRegex(ValueError, 'Invalid function'):
@pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP)
def foo(k, v, w):
return k
@ -154,14 +154,14 @@ class PandasUDFTests(ReusedSQLTestCase):
df = self.spark.range(0, 100)
# plain udf (test for SPARK-23754)
self.assertRaisesRegexp(
self.assertRaisesRegex(
PythonException,
exc_message,
df.withColumn('v', udf(foo)('id')).collect
)
# pandas scalar udf
self.assertRaisesRegexp(
self.assertRaisesRegex(
PythonException,
exc_message,
df.withColumn(
@ -170,7 +170,7 @@ class PandasUDFTests(ReusedSQLTestCase):
)
# pandas grouped map
self.assertRaisesRegexp(
self.assertRaisesRegex(
PythonException,
exc_message,
df.groupBy('id').apply(
@ -178,7 +178,7 @@ class PandasUDFTests(ReusedSQLTestCase):
).collect
)
self.assertRaisesRegexp(
self.assertRaisesRegex(
PythonException,
exc_message,
df.groupBy('id').apply(
@ -187,7 +187,7 @@ class PandasUDFTests(ReusedSQLTestCase):
)
# pandas grouped agg
self.assertRaisesRegexp(
self.assertRaisesRegex(
PythonException,
exc_message,
df.groupBy('id').agg(
@ -210,8 +210,8 @@ class PandasUDFTests(ReusedSQLTestCase):
# Since 0.11.0, PyArrow supports the feature to raise an error for unsafe cast.
with self.sql_conf({
"spark.sql.execution.pandas.convertToArrowArraySafely": True}):
with self.assertRaisesRegexp(Exception,
"Exception thrown when converting pandas.Series"):
with self.assertRaisesRegex(Exception,
"Exception thrown when converting pandas.Series"):
df.select(['A']).withColumn('udf', udf('A')).collect()
# Disabling Arrow safe type check.
@ -231,8 +231,8 @@ class PandasUDFTests(ReusedSQLTestCase):
# When enabling safe type check, Arrow 0.11.0+ disallows overflow cast.
with self.sql_conf({
"spark.sql.execution.pandas.convertToArrowArraySafely": True}):
with self.assertRaisesRegexp(Exception,
"Exception thrown when converting pandas.Series"):
with self.assertRaisesRegex(Exception,
"Exception thrown when converting pandas.Series"):
df.withColumn('udf', udf('id')).collect()
# Disabling safe type check, let Arrow do the cast anyway.


@ -30,7 +30,7 @@ from pyspark.testing.utils import QuietTest
if have_pandas:
import pandas as pd
from pandas.util.testing import assert_frame_equal
from pandas.testing import assert_frame_equal
@unittest.skipIf(
@ -145,20 +145,20 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
def test_unsupported_types(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
with self.assertRaisesRegex(NotImplementedError, 'not supported'):
pandas_udf(
lambda x: x,
ArrayType(ArrayType(TimestampType())),
PandasUDFType.GROUPED_AGG)
with QuietTest(self.sc):
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
with self.assertRaisesRegex(NotImplementedError, 'not supported'):
@pandas_udf('mean double, std double', PandasUDFType.GROUPED_AGG)
def mean_and_std_udf(v):
return v.mean(), v.std()
with QuietTest(self.sc):
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
with self.assertRaisesRegex(NotImplementedError, 'not supported'):
@pandas_udf(ArrayType(TimestampType()), PandasUDFType.GROUPED_AGG)
def mean_and_std_udf(v):
return {v.mean(): v.std()}
@ -428,7 +428,7 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
result1 = df.groupby('id').agg(array_udf(df['v']).alias('v2'))
self.assertEquals(result1.first()['v2'], [1.0, 2.0])
self.assertEqual(result1.first()['v2'], [1.0, 2.0])
def test_invalid_args(self):
df = self.data
@ -436,19 +436,19 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
mean_udf = self.pandas_agg_mean_udf
with QuietTest(self.sc):
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
AnalysisException,
'nor.*aggregate function'):
df.groupby(df.id).agg(plus_one(df.v)).collect()
with QuietTest(self.sc):
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
AnalysisException,
'aggregate function.*argument.*aggregate function'):
df.groupby(df.id).agg(mean_udf(mean_udf(df.v))).collect()
with QuietTest(self.sc):
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
AnalysisException,
'mixture.*aggregate function.*group aggregate pandas UDF'):
df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()


@ -133,7 +133,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
long_f(col('long')), float_f(col('float')),
double_f(col('double')), decimal_f('decimal'),
bool_f(col('bool')), array_long_f('array_long'))
self.assertEquals(df.collect(), res.collect())
self.assertEqual(df.collect(), res.collect())
def test_register_nondeterministic_vectorized_udf_basic(self):
random_pandas_udf = pandas_udf(
@ -169,7 +169,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
bool_f = pandas_udf(lambda x: x, BooleanType(), udf_type)
res = df.select(bool_f(col('bool')))
self.assertEquals(df.collect(), res.collect())
self.assertEqual(df.collect(), res.collect())
def test_vectorized_udf_null_byte(self):
data = [(None,), (2,), (3,), (4,)]
@ -178,7 +178,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
byte_f = pandas_udf(lambda x: x, ByteType(), udf_type)
res = df.select(byte_f(col('byte')))
self.assertEquals(df.collect(), res.collect())
self.assertEqual(df.collect(), res.collect())
def test_vectorized_udf_null_short(self):
data = [(None,), (2,), (3,), (4,)]
@ -187,7 +187,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
short_f = pandas_udf(lambda x: x, ShortType(), udf_type)
res = df.select(short_f(col('short')))
self.assertEquals(df.collect(), res.collect())
self.assertEqual(df.collect(), res.collect())
def test_vectorized_udf_null_int(self):
data = [(None,), (2,), (3,), (4,)]
@ -196,7 +196,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
int_f = pandas_udf(lambda x: x, IntegerType(), udf_type)
res = df.select(int_f(col('int')))
self.assertEquals(df.collect(), res.collect())
self.assertEqual(df.collect(), res.collect())
def test_vectorized_udf_null_long(self):
data = [(None,), (2,), (3,), (4,)]
@ -205,7 +205,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
long_f = pandas_udf(lambda x: x, LongType(), udf_type)
res = df.select(long_f(col('long')))
self.assertEquals(df.collect(), res.collect())
self.assertEqual(df.collect(), res.collect())
def test_vectorized_udf_null_float(self):
data = [(3.0,), (5.0,), (-1.0,), (None,)]
@ -214,7 +214,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
float_f = pandas_udf(lambda x: x, FloatType(), udf_type)
res = df.select(float_f(col('float')))
self.assertEquals(df.collect(), res.collect())
self.assertEqual(df.collect(), res.collect())
def test_vectorized_udf_null_double(self):
data = [(3.0,), (5.0,), (-1.0,), (None,)]
@ -223,7 +223,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
double_f = pandas_udf(lambda x: x, DoubleType(), udf_type)
res = df.select(double_f(col('double')))
self.assertEquals(df.collect(), res.collect())
self.assertEqual(df.collect(), res.collect())
def test_vectorized_udf_null_decimal(self):
data = [(Decimal(3.0),), (Decimal(5.0),), (Decimal(-1.0),), (None,)]
@ -232,7 +232,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
decimal_f = pandas_udf(lambda x: x, DecimalType(38, 18), udf_type)
res = df.select(decimal_f(col('decimal')))
self.assertEquals(df.collect(), res.collect())
self.assertEqual(df.collect(), res.collect())
def test_vectorized_udf_null_string(self):
data = [("foo",), (None,), ("bar",), ("bar",)]
@ -241,7 +241,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
str_f = pandas_udf(lambda x: x, StringType(), udf_type)
res = df.select(str_f(col('str')))
self.assertEquals(df.collect(), res.collect())
self.assertEqual(df.collect(), res.collect())
def test_vectorized_udf_string_in_udf(self):
df = self.spark.range(10)
@ -255,7 +255,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
str_f = pandas_udf(f, StringType(), udf_type)
actual = df.select(str_f(col('id')))
expected = df.select(col('id').cast('string'))
self.assertEquals(expected.collect(), actual.collect())
self.assertEqual(expected.collect(), actual.collect())
def test_vectorized_udf_datatype_string(self):
df = self.spark.range(10).select(
@ -279,7 +279,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
long_f(col('long')), float_f(col('float')),
double_f(col('double')), decimal_f('decimal'),
bool_f(col('bool')))
self.assertEquals(df.collect(), res.collect())
self.assertEqual(df.collect(), res.collect())
def test_vectorized_udf_null_binary(self):
data = [(bytearray(b"a"),), (None,), (bytearray(b"bb"),), (bytearray(b"ccc"),)]
@ -288,7 +288,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
str_f = pandas_udf(lambda x: x, BinaryType(), udf_type)
res = df.select(str_f(col('binary')))
self.assertEquals(df.collect(), res.collect())
self.assertEqual(df.collect(), res.collect())
def test_vectorized_udf_array_type(self):
data = [([1, 2],), ([3, 4],)]
@ -297,7 +297,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
array_f = pandas_udf(lambda x: x, ArrayType(IntegerType()), udf_type)
result = df.select(array_f(col('array')))
self.assertEquals(df.collect(), result.collect())
self.assertEqual(df.collect(), result.collect())
def test_vectorized_udf_null_array(self):
data = [([1, 2],), (None,), (None,), ([3, 4],), (None,)]
@ -306,7 +306,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
array_f = pandas_udf(lambda x: x, ArrayType(IntegerType()), udf_type)
result = df.select(array_f(col('array')))
self.assertEquals(df.collect(), result.collect())
self.assertEqual(df.collect(), result.collect())
def test_vectorized_udf_struct_type(self):
df = self.spark.range(10)
@ -375,7 +375,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
with QuietTest(self.sc):
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
Exception,
'Invalid return type with scalar Pandas UDFs'):
pandas_udf(lambda x: x, returnType=nested_type, functionType=udf_type)
@ -392,7 +392,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
else:
map_f = pandas_udf(lambda x: x, MapType(StringType(), LongType()), udf_type)
result = df.select(map_f(col('map')))
self.assertEquals(df.collect(), result.collect())
self.assertEqual(df.collect(), result.collect())
def test_vectorized_udf_complex(self):
df = self.spark.range(10).select(
@ -422,7 +422,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
(iter_add, iter_power2, iter_mul)]:
res = df.select(add(col('a'), col('b')), power2(col('a')), mul(col('b'), col('c')))
expected = df.select(expr('a + b'), expr('power(2, a)'), expr('b * c'))
self.assertEquals(expected.collect(), res.collect())
self.assertEqual(expected.collect(), res.collect())
def test_vectorized_udf_exception(self):
df = self.spark.range(10)
@ -435,14 +435,14 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
for raise_exception in [scalar_raise_exception, iter_raise_exception]:
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'division( or modulo)? by zero'):
with self.assertRaisesRegex(Exception, 'division( or modulo)? by zero'):
df.select(raise_exception(col('id'))).collect()
def test_vectorized_udf_invalid_length(self):
df = self.spark.range(10)
raise_exception = pandas_udf(lambda _: pd.Series(1), LongType())
with QuietTest(self.sc):
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
Exception,
'Result vector from pandas_udf was not the required length'):
df.select(raise_exception(col('id'))).collect()
@ -453,7 +453,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
yield pd.Series(1)
with QuietTest(self.sc):
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
Exception,
"The length of output in Scalar iterator.*"
"the length of output was 1"):
@ -469,7 +469,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}):
df1 = self.spark.range(10).repartition(1)
with QuietTest(self.sc):
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
Exception,
"pandas iterator UDF should exhaust"):
df1.select(iter_udf_not_reading_all_input(col('id'))).collect()
@ -486,7 +486,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
for f, g in [(scalar_f, scalar_g), (iter_f, iter_g)]:
res = df.select(g(f(col('id'))))
self.assertEquals(df.collect(), res.collect())
self.assertEqual(df.collect(), res.collect())
def test_vectorized_udf_chained_struct_type(self):
df = self.spark.range(10)
@ -517,7 +517,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
def test_vectorized_udf_wrong_return_type(self):
with QuietTest(self.sc):
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
NotImplementedError,
'Invalid return type.*scalar Pandas UDF.*ArrayType.*TimestampType'):
pandas_udf(lambda x: x, ArrayType(TimestampType()), udf_type)
@ -529,7 +529,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
PandasUDFType.SCALAR_ITER)
for f in [scalar_f, iter_f]:
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'Return.*type.*Series'):
with self.assertRaisesRegex(Exception, 'Return.*type.*Series'):
df.select(f(col('id'))).collect()
def test_vectorized_udf_decorator(self):
@ -545,14 +545,14 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
for identity in [scalar_identity, iter_identity]:
res = df.select(identity(col('id')))
self.assertEquals(df.collect(), res.collect())
self.assertEqual(df.collect(), res.collect())
def test_vectorized_udf_empty_partition(self):
df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
f = pandas_udf(lambda x: x, LongType(), udf_type)
res = df.select(f(col('id')))
self.assertEquals(df.collect(), res.collect())
self.assertEqual(df.collect(), res.collect())
def test_vectorized_udf_struct_with_empty_partition(self):
df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))\
@ -585,16 +585,16 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
for f in [scalar_f, iter_f]:
res = df.select(f(col('id'), col('id')))
self.assertEquals(df.collect(), res.collect())
self.assertEqual(df.collect(), res.collect())
def test_vectorized_udf_unsupported_types(self):
with QuietTest(self.sc):
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
NotImplementedError,
'Invalid return type.*scalar Pandas UDF.*ArrayType.*TimestampType'):
pandas_udf(lambda x: x, ArrayType(TimestampType()), udf_type)
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
NotImplementedError,
'Invalid return type.*scalar Pandas UDF.*ArrayType.StructType'):
pandas_udf(lambda x: x,
@ -637,10 +637,10 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
result = df.withColumn("check_data",
check_data(col("idx"), col("date"), col("date_copy"))).collect()
self.assertEquals(len(data), len(result))
self.assertEqual(len(data), len(result))
for i in range(len(result)):
self.assertEquals(data[i][1], result[i][1]) # "date" col
self.assertEquals(data[i][1], result[i][2]) # "date_copy" col
self.assertEqual(data[i][1], result[i][1]) # "date" col
self.assertEqual(data[i][1], result[i][2]) # "date_copy" col
self.assertIsNone(result[i][3]) # "check_data" col
def test_vectorized_udf_timestamps(self):
@ -686,10 +686,10 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
result = df.withColumn("check_data", check_data(col("idx"), col("timestamp"),
col("timestamp_copy"))).collect()
# Check that collection values are correct
self.assertEquals(len(data), len(result))
self.assertEqual(len(data), len(result))
for i in range(len(result)):
self.assertEquals(data[i][1], result[i][1]) # "timestamp" col
self.assertEquals(data[i][1], result[i][2]) # "timestamp_copy" col
self.assertEqual(data[i][1], result[i][1]) # "timestamp" col
self.assertEqual(data[i][1], result[i][2]) # "timestamp_copy" col
self.assertIsNone(result[i][3]) # "check_data" col
def test_vectorized_udf_return_timestamp_tz(self):
@ -713,7 +713,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
i, ts = r
ts_tz = pd.Timestamp(i, unit='D', tz='America/Los_Angeles').to_pydatetime()
expected = spark_ts_t.fromInternal(spark_ts_t.toInternal(ts_tz))
self.assertEquals(expected, ts)
self.assertEqual(expected, ts)
def test_vectorized_udf_check_config(self):
with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}):
@ -799,9 +799,9 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
for random_udf in [self.nondeterministic_vectorized_udf,
self.nondeterministic_vectorized_iter_udf]:
with QuietTest(self.sc):
with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
with self.assertRaisesRegex(AnalysisException, 'nondeterministic'):
df.groupby(df.id).agg(sum(random_udf(df.id))).collect()
with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
with self.assertRaisesRegex(AnalysisException, 'nondeterministic'):
df.agg(sum(random_udf(df.id))).collect()
def test_register_vectorized_udf_basic(self):
@ -825,8 +825,8 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
res2 = self.spark.sql(
"SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t")
expected = df.select(expr('a + b'))
self.assertEquals(expected.collect(), res1.collect())
self.assertEquals(expected.collect(), res2.collect())
self.assertEqual(expected.collect(), res1.collect())
self.assertEqual(expected.collect(), res2.collect())
def test_scalar_iter_udf_init(self):
import numpy as np
@ -854,7 +854,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
finally:
raise RuntimeError("reached finally block")
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, "reached finally block"):
with self.assertRaisesRegex(Exception, "reached finally block"):
self.spark.range(1).select(test_close(col("id"))).collect()
def test_scalar_iter_udf_close_early(self):
@@ -905,7 +905,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
foo_udf = pandas_udf(lambda x: x, 'timestamp', udf_type)
result = df.withColumn('time', foo_udf(df.time))
self.assertEquals(df.collect(), result.collect())
self.assertEqual(df.collect(), result.collect())
def test_udf_category_type(self):
@@ -1003,11 +1003,11 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v']))))
df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v']))))
self.assertEquals(expected_chained_1, df_chained_1.collect())
self.assertEquals(expected_chained_2, df_chained_2.collect())
self.assertEquals(expected_chained_3, df_chained_3.collect())
self.assertEquals(expected_chained_4, df_chained_4.collect())
self.assertEquals(expected_chained_5, df_chained_5.collect())
self.assertEqual(expected_chained_1, df_chained_1.collect())
self.assertEqual(expected_chained_2, df_chained_2.collect())
self.assertEqual(expected_chained_3, df_chained_3.collect())
self.assertEqual(expected_chained_4, df_chained_4.collect())
self.assertEqual(expected_chained_5, df_chained_5.collect())
# Test multiple mixed UDF expressions in a single projection
df_multi_1 = df \
@@ -1045,8 +1045,8 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
.withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \
.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v'))))))
self.assertEquals(expected_multi, df_multi_1.collect())
self.assertEquals(expected_multi, df_multi_2.collect())
self.assertEqual(expected_multi, df_multi_1.collect())
self.assertEqual(expected_multi, df_multi_2.collect())
def test_mixed_udf_and_sql(self):
df = self.spark.range(0, 1).toDF('v')
@@ -1107,7 +1107,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
.withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \
.withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
self.assertEquals(expected, df1.collect())
self.assertEqual(expected, df1.collect())
# SPARK-24721
@unittest.skipIf(not test_compiled, test_not_compiled_message) # type: ignore
@@ -1138,17 +1138,17 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
for df in [filesource_df, datasource_df, datasource_v2_df]:
result = df.withColumn('c', c1)
expected = df.withColumn('c', lit(2))
self.assertEquals(expected.collect(), result.collect())
self.assertEqual(expected.collect(), result.collect())
for df in [filesource_df, datasource_df, datasource_v2_df]:
result = df.withColumn('c', c2)
expected = df.withColumn('c', col('i') + 1)
self.assertEquals(expected.collect(), result.collect())
self.assertEqual(expected.collect(), result.collect())
for df in [filesource_df, datasource_df, datasource_v2_df]:
for f in [f1, f2]:
result = df.filter(f)
self.assertEquals(0, result.count())
self.assertEqual(0, result.count())
finally:
shutil.rmtree(path)
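The hunks in this file all follow the same pattern: the deprecated unittest aliases `assertEquals` and `assertRaisesRegexp` are swapped for `assertEqual` and `assertRaisesRegex`. A minimal sketch (not taken from the patch, class and test names are made up) showing the preferred names; the old aliases still work on Python 3 but emit a DeprecationWarning:

```python
import unittest


class PreferredAssertionExample(unittest.TestCase):
    def test_preferred_names(self):
        # assertEqual replaces the deprecated assertEquals alias.
        self.assertEqual(1 + 1, 2)
        # assertRaisesRegex replaces assertRaisesRegexp and can be used as a
        # context manager, as in the UDF tests above.
        with self.assertRaisesRegex(ValueError, "invalid literal"):
            int("not a number")


if __name__ == "__main__":
    unittest.main()
```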

View file

@@ -29,7 +29,7 @@ from pyspark.sql import Row
if have_pandas:
import pandas as pd
import numpy as np
from pandas.util.testing import assert_frame_equal
from pandas.testing import assert_frame_equal
@unittest.skipIf(
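This hunk reflects a pandas change rather than a unittest one: `pandas.util.testing` was deprecated (around pandas 1.0, as I understand it) in favour of the public `pandas.testing` module. A small sketch of the non-deprecated import, using throwaway frames:

```python
import pandas as pd
from pandas.testing import assert_frame_equal

left = pd.DataFrame({"id": [1, 2], "v": [0.1, 0.2]})
right = pd.DataFrame({"id": [1, 2], "v": [0.1, 0.2]})

# Raises AssertionError with a column-by-column diff when the frames differ.
assert_frame_equal(left, right)
```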

View file

@@ -26,7 +26,7 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarro
from pyspark.testing.utils import QuietTest
if have_pandas:
from pandas.util.testing import assert_frame_equal
from pandas.testing import assert_frame_equal
@unittest.skipIf(
@@ -241,14 +241,14 @@ class WindowPandasUDFTests(ReusedSQLTestCase):
array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
result1 = df.withColumn('v2', array_udf(df['v']).over(w))
self.assertEquals(result1.first()['v2'], [1.0, 2.0])
self.assertEqual(result1.first()['v2'], [1.0, 2.0])
def test_invalid_args(self):
df = self.data
w = self.unbounded_window
with QuietTest(self.sc):
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
AnalysisException,
'.*not supported within a window function'):
foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP)

View file

@@ -180,7 +180,7 @@ class TypesTests(ReusedSQLTestCase):
self.assertEqual(df.columns, ['col1', '_2'])
def test_infer_schema_fails(self):
with self.assertRaisesRegexp(TypeError, 'field a'):
with self.assertRaisesRegex(TypeError, 'field a'):
self.spark.createDataFrame(self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]),
schema=["a", "b"], samplingRatio=0.99)
@@ -578,18 +578,18 @@ class TypesTests(ReusedSQLTestCase):
ArrayType(LongType()),
ArrayType(LongType())
), ArrayType(LongType()))
with self.assertRaisesRegexp(TypeError, 'element in array'):
with self.assertRaisesRegex(TypeError, 'element in array'):
_merge_type(ArrayType(LongType()), ArrayType(DoubleType()))
self.assertEqual(_merge_type(
MapType(StringType(), LongType()),
MapType(StringType(), LongType())
), MapType(StringType(), LongType()))
with self.assertRaisesRegexp(TypeError, 'key of map'):
with self.assertRaisesRegex(TypeError, 'key of map'):
_merge_type(
MapType(StringType(), LongType()),
MapType(DoubleType(), LongType()))
with self.assertRaisesRegexp(TypeError, 'value of map'):
with self.assertRaisesRegex(TypeError, 'value of map'):
_merge_type(
MapType(StringType(), LongType()),
MapType(StringType(), DoubleType()))
@@ -598,7 +598,7 @@ class TypesTests(ReusedSQLTestCase):
StructType([StructField("f1", LongType()), StructField("f2", StringType())]),
StructType([StructField("f1", LongType()), StructField("f2", StringType())])
), StructType([StructField("f1", LongType()), StructField("f2", StringType())]))
with self.assertRaisesRegexp(TypeError, 'field f1'):
with self.assertRaisesRegex(TypeError, 'field f1'):
_merge_type(
StructType([StructField("f1", LongType()), StructField("f2", StringType())]),
StructType([StructField("f1", DoubleType()), StructField("f2", StringType())]))
@@ -607,7 +607,7 @@ class TypesTests(ReusedSQLTestCase):
StructType([StructField("f1", StructType([StructField("f2", LongType())]))]),
StructType([StructField("f1", StructType([StructField("f2", LongType())]))])
), StructType([StructField("f1", StructType([StructField("f2", LongType())]))]))
with self.assertRaisesRegexp(TypeError, 'field f2 in field f1'):
with self.assertRaisesRegex(TypeError, 'field f2 in field f1'):
_merge_type(
StructType([StructField("f1", StructType([StructField("f2", LongType())]))]),
StructType([StructField("f1", StructType([StructField("f2", StringType())]))]))
@@ -616,7 +616,7 @@ class TypesTests(ReusedSQLTestCase):
StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]),
StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())])
), StructType([StructField("f1", ArrayType(LongType())), StructField("f2", StringType())]))
with self.assertRaisesRegexp(TypeError, 'element in array field f1'):
with self.assertRaisesRegex(TypeError, 'element in array field f1'):
_merge_type(
StructType([
StructField("f1", ArrayType(LongType())),
@@ -635,7 +635,7 @@ class TypesTests(ReusedSQLTestCase):
), StructType([
StructField("f1", MapType(StringType(), LongType())),
StructField("f2", StringType())]))
with self.assertRaisesRegexp(TypeError, 'value of map field f1'):
with self.assertRaisesRegex(TypeError, 'value of map field f1'):
_merge_type(
StructType([
StructField("f1", MapType(StringType(), LongType())),
@@ -648,7 +648,7 @@ class TypesTests(ReusedSQLTestCase):
StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]),
StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))])
), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]))
with self.assertRaisesRegexp(TypeError, 'key of map element in array field f1'):
with self.assertRaisesRegex(TypeError, 'key of map element in array field f1'):
_merge_type(
StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]),
StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))])
@@ -734,7 +734,7 @@ class TypesTests(ReusedSQLTestCase):
unsupported_types = all_types - set(supported_types)
# test unsupported types
for t in unsupported_types:
with self.assertRaisesRegexp(TypeError, "infer the type of the field myarray"):
with self.assertRaisesRegex(TypeError, "infer the type of the field myarray"):
a = array.array(t)
self.spark.createDataFrame([Row(myarray=a)]).collect()
@@ -789,13 +789,13 @@ class DataTypeTests(unittest.TestCase):
class DataTypeVerificationTests(unittest.TestCase):
def test_verify_type_exception_msg(self):
self.assertRaisesRegexp(
self.assertRaisesRegex(
ValueError,
"test_name",
lambda: _make_type_verifier(StringType(), nullable=False, name="test_name")(None))
schema = StructType([StructField('a', StructType([StructField('b', IntegerType())]))])
self.assertRaisesRegexp(
self.assertRaisesRegex(
TypeError,
"field b in field a",
lambda: _make_type_verifier(schema)([["data"]]))
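The TypesTests and DataTypeVerificationTests hunks above use `assertRaisesRegex` in two forms: as a context manager and with the callable passed directly. A short sketch of both forms, with hypothetical assertions that are not taken from the patch:

```python
import unittest


class RaisesRegexForms(unittest.TestCase):
    def test_context_manager_form(self):
        # The pattern is matched with re.search, so matching a fragment of
        # the message (as the '... in array field f1' checks do) is enough.
        with self.assertRaisesRegex(KeyError, "missing"):
            {}["missing-key"]

    def test_callable_form(self):
        # Same assertion, passing the callable and its arguments directly,
        # as the _make_type_verifier checks above do with a lambda.
        self.assertRaisesRegex(ValueError, "base 10", int, "forty-two")
```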

View file

@@ -98,7 +98,7 @@ class UDFTests(ReusedSQLTestCase):
def test_udf_registration_return_type_not_none(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(TypeError, "Invalid return type"):
with self.assertRaisesRegex(TypeError, "Invalid return type"):
self.spark.catalog.registerFunction(
"f", UserDefinedFunction(lambda x, y: len(x) + y, StringType()), StringType())
@@ -149,9 +149,9 @@ class UDFTests(ReusedSQLTestCase):
df = self.spark.range(10)
with QuietTest(self.sc):
with self.assertRaisesRegexp(AnalysisException, "nondeterministic"):
with self.assertRaisesRegex(AnalysisException, "nondeterministic"):
df.groupby('id').agg(sum(udf_random_col())).collect()
with self.assertRaisesRegexp(AnalysisException, "nondeterministic"):
with self.assertRaisesRegex(AnalysisException, "nondeterministic"):
df.agg(sum(udf_random_col())).collect()
def test_chained_udf(self):
@@ -203,7 +203,7 @@ class UDFTests(ReusedSQLTestCase):
# Cross join.
df = left.join(right, f("a", "b"))
with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
with self.assertRaisesRegex(AnalysisException, 'Detected implicit cartesian product'):
df.collect()
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
self.assertEqual(df.collect(), [Row(a=1, b=1)])
@@ -238,7 +238,7 @@ class UDFTests(ReusedSQLTestCase):
f = udf(lambda a, b: a == b, BooleanType())
def runWithJoinType(join_type, type_string):
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
AnalysisException,
'Using PythonUDF.*%s is not supported.' % type_string):
left.join(right, [f("a", "b"), left.a1 == right.b1], join_type).collect()
@@ -385,18 +385,18 @@ class UDFTests(ReusedSQLTestCase):
def test_non_existed_udf(self):
spark = self.spark
self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"))
self.assertRaisesRegex(AnalysisException, "Can not load class non_existed_udf",
lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"))
# This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias.
sqlContext = spark._wrapped
self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf"))
self.assertRaisesRegex(AnalysisException, "Can not load class non_existed_udf",
lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf"))
def test_non_existed_udaf(self):
spark = self.spark
self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf",
lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))
self.assertRaisesRegex(AnalysisException, "Can not load class non_existed_udaf",
lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))
def test_udf_with_input_file_name(self):
from pyspark.sql.functions import input_file_name
@@ -587,17 +587,17 @@ class UDFTests(ReusedSQLTestCase):
for df in [filesource_df, datasource_df, datasource_v2_df]:
result = df.withColumn('c', c1)
expected = df.withColumn('c', lit(2))
self.assertEquals(expected.collect(), result.collect())
self.assertEqual(expected.collect(), result.collect())
for df in [filesource_df, datasource_df, datasource_v2_df]:
result = df.withColumn('c', c2)
expected = df.withColumn('c', col('i') + 1)
self.assertEquals(expected.collect(), result.collect())
self.assertEqual(expected.collect(), result.collect())
for df in [filesource_df, datasource_df, datasource_v2_df]:
for f in [f1, f2]:
result = df.filter(f)
self.assertEquals(0, result.count())
self.assertEqual(0, result.count())
finally:
shutil.rmtree(path)

View file

@@ -31,23 +31,22 @@ class UtilsTests(ReusedSQLTestCase):
try:
self.spark.sql("select `中文字段`")
except AnalysisException as e:
self.assertRegexpMatches(str(e), "cannot resolve '`中文字段`'")
self.assertRegex(str(e), "cannot resolve '`中文字段`'")
def test_capture_parse_exception(self):
self.assertRaises(ParseException, lambda: self.spark.sql("abc"))
def test_capture_illegalargument_exception(self):
self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks",
lambda: self.spark.sql("SET mapred.reduce.tasks=-1"))
self.assertRaisesRegex(IllegalArgumentException, "Setting negative mapred.reduce.tasks",
lambda: self.spark.sql("SET mapred.reduce.tasks=-1"))
df = self.spark.createDataFrame([(1, 2)], ["a", "b"])
self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values",
lambda: df.select(sha2(df.a, 1024)).collect())
self.assertRaisesRegex(IllegalArgumentException, "1024 is not in the permitted values",
lambda: df.select(sha2(df.a, 1024)).collect())
try:
df.select(sha2(df.a, 1024)).collect()
except IllegalArgumentException as e:
self.assertRegexpMatches(e.desc, "1024 is not in the permitted values")
self.assertRegexpMatches(e.stackTrace,
"org.apache.spark.sql.functions")
self.assertRegex(e.desc, "1024 is not in the permitted values")
self.assertRegex(e.stackTrace, "org.apache.spark.sql.functions")
if __name__ == "__main__":
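The UtilsTests hunk also covers the `assertRegexpMatches` to `assertRegex` rename; both match the text with re.search, only the non-deprecated name changes. A tiny hedged sketch (test class and message are placeholders):

```python
import unittest


class RegexAssertionExample(unittest.TestCase):
    def test_assert_regex(self):
        # assertRegex replaces the deprecated assertRegexpMatches alias.
        message = "1024 is not in the permitted values"
        self.assertRegex(message, "permitted values")
```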

View file

@@ -85,11 +85,11 @@ class ProfilerTests2(unittest.TestCase):
def test_profiler_disabled(self):
sc = SparkContext(conf=SparkConf().set("spark.python.profile", "false"))
try:
self.assertRaisesRegexp(
self.assertRaisesRegex(
RuntimeError,
"'spark.python.profile' configuration must be set",
lambda: sc.show_profiles())
self.assertRaisesRegexp(
self.assertRaisesRegex(
RuntimeError,
"'spark.python.profile' configuration must be set",
lambda: sc.dump_profiles("/tmp/abc"))

View file

@@ -733,25 +733,25 @@ class RDDTests(ReusedPySparkTestCase):
keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))
msg = "Caught StopIteration thrown from user's code; failing the task"
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit)
self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
self.assertRaisesRegexp(Py4JJavaError, msg,
seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
self.assertRaisesRegex(Py4JJavaError, msg, seq_rdd.map(stopit).collect)
self.assertRaisesRegex(Py4JJavaError, msg, seq_rdd.filter(stopit).collect)
self.assertRaisesRegex(Py4JJavaError, msg, seq_rdd.foreach, stopit)
self.assertRaisesRegex(Py4JJavaError, msg, seq_rdd.reduce, stopit)
self.assertRaisesRegex(Py4JJavaError, msg, seq_rdd.fold, 0, stopit)
self.assertRaisesRegex(Py4JJavaError, msg, seq_rdd.foreach, stopit)
self.assertRaisesRegex(Py4JJavaError, msg,
seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
# these methods call the user function both in the driver and in the executor
# the exception raised is different according to where the StopIteration happens
# RuntimeError is raised if in the driver
# Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker)
self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
keyed_rdd.reduceByKeyLocally, stopit)
self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
seq_rdd.aggregate, 0, stopit, lambda *x: 1)
self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
seq_rdd.aggregate, 0, lambda *x: 1, stopit)
self.assertRaisesRegex((Py4JJavaError, RuntimeError), msg,
keyed_rdd.reduceByKeyLocally, stopit)
self.assertRaisesRegex((Py4JJavaError, RuntimeError), msg,
seq_rdd.aggregate, 0, stopit, lambda *x: 1)
self.assertRaisesRegex((Py4JJavaError, RuntimeError), msg,
seq_rdd.aggregate, 0, lambda *x: 1, stopit)
def test_overwritten_global_func(self):
# Regression test for SPARK-27000
@@ -768,7 +768,7 @@ class RDDTests(ReusedPySparkTestCase):
rdd = self.sc.range(10).map(fail)
with self.assertRaisesRegexp(Exception, "local iterator error"):
with self.assertRaisesRegex(Exception, "local iterator error"):
for _ in rdd.toLocalIterator():
pass
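One detail worth noting in the StopIteration hunk above: `assertRaisesRegex` accepts a tuple of exception classes, which is how the test tolerates either a driver-side RuntimeError or an executor-side Py4JJavaError. A sketch under plain-Python assumptions (the helper below is hypothetical, not from the patch):

```python
import unittest


def might_fail(on_driver):
    # Hypothetical stand-in for code that can fail in two different layers.
    if on_driver:
        raise RuntimeError("Caught StopIteration thrown from user's code")
    raise OSError("Caught StopIteration thrown from user's code")


class TupleOfExceptionsExample(unittest.TestCase):
    def test_either_exception_type_is_accepted(self):
        msg = "Caught StopIteration"
        self.assertRaisesRegex((RuntimeError, OSError), msg, might_fail, True)
        self.assertRaisesRegex((RuntimeError, OSError), msg, might_fail, False)
```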

View file

@@ -165,7 +165,7 @@ class WorkerTests(ReusedPySparkTestCase):
self.sc.parallelize([1]).map(lambda x: f()).count()
except Py4JJavaError as e:
self.assertRegexpMatches(str(e), "exception with 中")
self.assertRegex(str(e), "exception with 中")
class WorkerReuseTest(PySparkTestCase):