[SPARK-23691][PYTHON] Use sql_conf util in PySpark tests where possible
## What changes were proposed in this pull request?
d6632d185e added a useful util:
```python
@contextmanager
def sql_conf(self, pairs):
...
```
to allow setting and unsetting SQL configurations within a block:
```python
with self.sql_conf({"spark.blah.blah.blah": "blah"}):
    # test codes
```
This PR proposes to use this util where possible in PySpark tests.
Note that there already appear to be a few places that affect other tests by not restoring the original configuration values in the unittest classes.
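For context, the util saves the current values, applies the new ones, and restores (or unsets) the originals on exit. A condensed sketch of the helper added in d6632d185e (see that commit for the exact code):
```python
from contextlib import contextmanager

@contextmanager
def sql_conf(self, pairs):
    """Modifies SQL configs temporarily and restores the originals on exit."""
    assert isinstance(pairs, dict), "pairs should be a dictionary."

    keys = pairs.keys()
    new_values = pairs.values()
    # Remember the current values (None if a key was not set).
    old_values = [self.spark.conf.get(key, None) for key in keys]
    for key, new_value in zip(keys, new_values):
        self.spark.conf.set(key, new_value)
    try:
        yield
    finally:
        # Restore previous values, or unset keys that had no value before.
        for key, old_value in zip(keys, old_values):
            if old_value is None:
                self.spark.conf.unset(key)
            else:
                self.spark.conf.set(key, old_value)
```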
## How was this patch tested?
Manually tested via:
```
./run-tests --modules=pyspark-sql --python-executables=python2
./run-tests --modules=pyspark-sql --python-executables=python3
```
Author: hyukjinkwon <gurwls223@gmail.com>
Closes #20830 from HyukjinKwon/cleanup-sql-conf.
```diff
@@ -2461,17 +2461,13 @@ class SQLTests(ReusedSQLTestCase):
         df1 = self.spark.range(1).toDF("a")
         df2 = self.spark.range(1).toDF("b")
 
-        try:
-            self.spark.conf.set("spark.sql.crossJoin.enabled", "false")
+        with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
             self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect())
 
-            self.spark.conf.set("spark.sql.crossJoin.enabled", "true")
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
             actual = df1.join(df2, how="inner").collect()
             expected = [Row(a=0, b=0)]
             self.assertEqual(actual, expected)
-        finally:
-            # We should unset this. Otherwise, other tests are affected.
-            self.spark.conf.unset("spark.sql.crossJoin.enabled")
 
     # Regression test for invalid join methods when on is None, Spark-14761
     def test_invalid_join_method(self):
```
```diff
@@ -2943,21 +2939,18 @@ class SQLTests(ReusedSQLTestCase):
         self.assertPandasEqual(pdf, df.toPandas())
 
         orig_env_tz = os.environ.get('TZ', None)
-        orig_session_tz = self.spark.conf.get('spark.sql.session.timeZone')
         try:
             tz = 'America/Los_Angeles'
             os.environ['TZ'] = tz
             time.tzset()
-            self.spark.conf.set('spark.sql.session.timeZone', tz)
-
-            df = self.spark.createDataFrame(pdf)
-            self.assertPandasEqual(pdf, df.toPandas())
+            with self.sql_conf({'spark.sql.session.timeZone': tz}):
+                df = self.spark.createDataFrame(pdf)
+                self.assertPandasEqual(pdf, df.toPandas())
         finally:
             del os.environ['TZ']
             if orig_env_tz is not None:
                 os.environ['TZ'] = orig_env_tz
             time.tzset()
-            self.spark.conf.set('spark.sql.session.timeZone', orig_session_tz)
 
 
 class HiveSparkSubmitTests(SparkSubmitTests):
```
```diff
@@ -3562,12 +3555,11 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertTrue(all([c == 1 for c in null_counts]))
 
     def _toPandas_arrow_toggle(self, df):
-        self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
-        try:
+        with self.sql_conf({"spark.sql.execution.arrow.enabled": False}):
             pdf = df.toPandas()
-        finally:
-            self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+
         pdf_arrow = df.toPandas()
+
         return pdf, pdf_arrow
 
     def test_toPandas_arrow_toggle(self):
```
```diff
@@ -3579,16 +3571,17 @@ class ArrowTests(ReusedSQLTestCase):
 
     def test_toPandas_respect_session_timezone(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
-        orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
-        try:
-            timezone = "America/New_York"
-            self.spark.conf.set("spark.sql.session.timeZone", timezone)
-            self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
-            try:
-                pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
-                self.assertPandasEqual(pdf_arrow_la, pdf_la)
-            finally:
-                self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
+
+        timezone = "America/New_York"
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": False,
+                "spark.sql.session.timeZone": timezone}):
+            pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
+            self.assertPandasEqual(pdf_arrow_la, pdf_la)
+
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": True,
+                "spark.sql.session.timeZone": timezone}):
             pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df)
             self.assertPandasEqual(pdf_arrow_ny, pdf_ny)
 
```
```diff
@@ -3601,8 +3594,6 @@ class ArrowTests(ReusedSQLTestCase):
                     pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz(
                         pdf_la_corrected[field.name], timezone)
             self.assertPandasEqual(pdf_ny, pdf_la_corrected)
-        finally:
-            self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
 
     def test_pandas_round_trip(self):
         pdf = self.create_pandas_data_frame()
```
```diff
@@ -3618,12 +3609,11 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertTrue(pdf.empty)
 
     def _createDataFrame_toggle(self, pdf, schema=None):
-        self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
-        try:
+        with self.sql_conf({"spark.sql.execution.arrow.enabled": False}):
             df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)
-        finally:
-            self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+
         df_arrow = self.spark.createDataFrame(pdf, schema=schema)
+
         return df_no_arrow, df_arrow
 
     def test_createDataFrame_toggle(self):
```
```diff
@@ -3634,18 +3624,18 @@ class ArrowTests(ReusedSQLTestCase):
     def test_createDataFrame_respect_session_timezone(self):
         from datetime import timedelta
         pdf = self.create_pandas_data_frame()
-        orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
-        try:
-            timezone = "America/New_York"
-            self.spark.conf.set("spark.sql.session.timeZone", timezone)
-            self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
-            try:
-                df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema)
-                result_la = df_no_arrow_la.collect()
-                result_arrow_la = df_arrow_la.collect()
-                self.assertEqual(result_la, result_arrow_la)
-            finally:
-                self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
+        timezone = "America/New_York"
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": False,
+                "spark.sql.session.timeZone": timezone}):
+            df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema)
+            result_la = df_no_arrow_la.collect()
+            result_arrow_la = df_arrow_la.collect()
+            self.assertEqual(result_la, result_arrow_la)
+
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": True,
+                "spark.sql.session.timeZone": timezone}):
             df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema)
             result_ny = df_no_arrow_ny.collect()
             result_arrow_ny = df_arrow_ny.collect()
```
```diff
@@ -3658,8 +3648,6 @@ class ArrowTests(ReusedSQLTestCase):
                                 for k, v in row.asDict().items()})
                         for row in result_la]
             self.assertEqual(result_ny, result_la_corrected)
-        finally:
-            self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
 
     def test_createDataFrame_with_schema(self):
         pdf = self.create_pandas_data_frame()
```
```diff
@@ -4336,9 +4324,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
     def test_vectorized_udf_check_config(self):
         from pyspark.sql.functions import pandas_udf, col
         import pandas as pd
-        orig_value = self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None)
-        self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3)
-        try:
+        with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}):
             df = self.spark.range(10, numPartitions=1)
 
             @pandas_udf(returnType=LongType())
```
```diff
@@ -4348,11 +4334,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
             result = df.select(check_records_per_batch(col("id"))).collect()
             for (r,) in result:
                 self.assertTrue(r <= 3)
-        finally:
-            if orig_value is None:
-                self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch")
-            else:
-                self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value)
 
     def test_vectorized_udf_timestamps_respect_session_timezone(self):
         from pyspark.sql.functions import pandas_udf, col
```
```diff
@@ -4371,30 +4352,27 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
         internal_value = pandas_udf(
             lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType())
 
-        orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
-        try:
-            timezone = "America/New_York"
-            self.spark.conf.set("spark.sql.session.timeZone", timezone)
-            self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
-            try:
-                df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
-                    .withColumn("internal_value", internal_value(col("timestamp")))
-                result_la = df_la.select(col("idx"), col("internal_value")).collect()
-                # Correct result_la by adjusting 3 hours difference between Los Angeles and New York
-                diff = 3 * 60 * 60 * 1000 * 1000 * 1000
-                result_la_corrected = \
-                    df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()
-            finally:
-                self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
+        timezone = "America/New_York"
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": False,
+                "spark.sql.session.timeZone": timezone}):
+            df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
+                .withColumn("internal_value", internal_value(col("timestamp")))
+            result_la = df_la.select(col("idx"), col("internal_value")).collect()
+            # Correct result_la by adjusting 3 hours difference between Los Angeles and New York
+            diff = 3 * 60 * 60 * 1000 * 1000 * 1000
+            result_la_corrected = \
+                df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()
 
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": True,
+                "spark.sql.session.timeZone": timezone}):
             df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
                 .withColumn("internal_value", internal_value(col("timestamp")))
             result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect()
 
             self.assertNotEqual(result_ny, result_la)
             self.assertEqual(result_ny, result_la_corrected)
-        finally:
-            self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
 
     def test_nondeterministic_vectorized_udf(self):
         # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
```
```diff
@@ -5170,9 +5148,7 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
 
     def test_retain_group_columns(self):
         from pyspark.sql.functions import sum, lit, col
-        orig_value = self.spark.conf.get("spark.sql.retainGroupColumns", None)
-        self.spark.conf.set("spark.sql.retainGroupColumns", False)
-        try:
+        with self.sql_conf({"spark.sql.retainGroupColumns": False}):
             df = self.data
             sum_udf = self.pandas_agg_sum_udf
```
```diff
@@ -5180,12 +5156,6 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
             expected1 = df.groupby(df.id).agg(sum(df.v))
             self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
-
-        finally:
-            if orig_value is None:
-                self.spark.conf.unset("spark.sql.retainGroupColumns")
-            else:
-                self.spark.conf.set("spark.sql.retainGroupColumns", orig_value)
 
     def test_invalid_args(self):
         from pyspark.sql.functions import mean
```