[SPARK-30856][SQL][PYSPARK] Fix SQLContext.getOrCreate() when SparkContext is restarted

### What changes were proposed in this pull request?

As discussed on the Jira ticket, this change clears the `SQLContext._instantiatedContext` class attribute when the SparkSession is stopped. That way, the attribute is repopulated with a new, usable SQLContext the next time a SparkSession is started.

### Why are the changes needed?

When the underlying SQLContext is instantiated for a SparkSession, the instance is saved as a class attribute and returned from subsequent calls to `SQLContext.getOrCreate()`. If the SparkContext is stopped and a new one started, that class attribute is never cleared, so any code that calls `SQLContext.getOrCreate()` receives a SQLContext holding a reference to the old, unusable SparkContext.
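
A minimal sketch of the failure mode (hypothetical app names, local mode; not part of the patch):

```python
from pyspark import SparkContext
from pyspark.sql import SQLContext

sc = SparkContext('local[4]', 'demo')  # hypothetical app name
SQLContext.getOrCreate(sc)             # caches SQLContext._instantiatedContext
sc.stop()

sc = SparkContext('local[4]', 'demo')  # restart the context
ctx = SQLContext.getOrCreate(sc)       # before this fix: returns the cached instance,
                                       # still bound to the stopped SparkContext
```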

A similar issue was identified and fixed for SparkSession in [SPARK-19055](https://issues.apache.org/jira/browse/SPARK-19055), but that fix did not cover SQLContext. I ran into this because MLlib still [uses](https://github.com/apache/spark/blob/master/python/pyspark/mllib/common.py#L105) `SQLContext.getOrCreate()` under the hood.

### Does this PR introduce any user-facing change?

No

### How was this patch tested?

A new test was added. I verified that the test fails without the included change.

Closes #27610 from afavaro/restart-sqlcontext.

Authored-by: Alex Favaro <alex.favaro@affirm.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
Alex Favaro authored on 2020-02-20 12:21:24 +09:00; committed by HyukjinKwon
parent e32411eb07
commit 96c1a4401d
4 changed files with 52 additions and 3 deletions

`python/pyspark/sql/context.py`

```diff
@@ -87,7 +87,8 @@ class SQLContext(object):
         self._jsqlContext = jsqlContext
         _monkey_patch_RDD(self.sparkSession)
         install_exception_handler()
-        if SQLContext._instantiatedContext is None:
+        if (SQLContext._instantiatedContext is None
+                or SQLContext._instantiatedContext._sc._jsc is None):
             SQLContext._instantiatedContext = self

     @property
@@ -118,7 +119,8 @@ class SQLContext(object):
             "Deprecated in 3.0.0. Use SparkSession.builder.getOrCreate() instead.",
             DeprecationWarning)

-        if cls._instantiatedContext is None:
+        if (cls._instantiatedContext is None
+                or SQLContext._instantiatedContext._sc._jsc is None):
             jsqlContext = sc._jvm.SparkSession.builder().sparkContext(
                 sc._jsc.sc()).getOrCreate().sqlContext()
             sparkSession = SparkSession(sc, jsqlContext.sparkSession())
```

`python/pyspark/sql/session.py`

```diff
@@ -699,12 +699,14 @@ class SparkSession(SparkConversionMixin):
     def stop(self):
         """Stop the underlying :class:`SparkContext`.
         """
+        from pyspark.sql.context import SQLContext
         self._sc.stop()
         # We should clean the default session up. See SPARK-23228.
         self._jvm.SparkSession.clearDefaultSession()
         self._jvm.SparkSession.clearActiveSession()
         SparkSession._instantiatedSession = None
         SparkSession._activeSession = None
+        SQLContext._instantiatedContext = None

     @since(2.0)
     def __enter__(self):
```
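
With `stop()` also resetting the cached SQLContext, a restart now works end to end. A hedged sketch of the resulting behavior (hypothetical app names; this mirrors the new tests below):

```python
from pyspark import SparkContext
from pyspark.sql import SparkSession, SQLContext

sc = SparkContext('local[4]', 'demo')  # hypothetical app name
spark = SparkSession(sc)
spark.stop()                           # now also clears SQLContext._instantiatedContext

sc = SparkContext('local[4]', 'demo')  # restart
ctx = SQLContext.getOrCreate(sc)       # fresh SQLContext bound to the new context
assert ctx._sc is sc
```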

`python/pyspark/sql/tests/test_context.py`

```diff
@@ -270,7 +270,6 @@ class SQLContextTests(unittest.TestCase):
             sql_context = SQLContext.getOrCreate(sc)
             assert(isinstance(sql_context, SQLContext))
         finally:
-            SQLContext._instantiatedContext = None
             if sql_context is not None:
                 sql_context.sparkSession.stop()
             if sc is not None:
```

`python/pyspark/sql/tests/test_session.py`

```diff
@@ -225,6 +225,52 @@ class SparkSessionTests4(ReusedSQLTestCase):
         session2.stop()


+class SparkSessionTests5(unittest.TestCase):
+    def setUp(self):
+        # These tests require restarting the Spark context so we set up a new one for each test
+        # rather than at the class level.
+        self.sc = SparkContext('local[4]', self.__class__.__name__, conf=SparkConf())
+        self.spark = SparkSession(self.sc)
+
+    def tearDown(self):
+        self.sc.stop()
+        self.spark.stop()
+
+    def test_sqlcontext_with_stopped_sparksession(self):
+        # SPARK-30856: test that SQLContext.getOrCreate() returns a usable instance after
+        # the SparkSession is restarted.
+        sql_context = self.spark._wrapped
+        self.spark.stop()
+
+        sc = SparkContext('local[4]', self.sc.appName)
+        spark = SparkSession(sc)  # Instantiate the underlying SQLContext
+        new_sql_context = spark._wrapped
+
+        self.assertIsNot(new_sql_context, sql_context)
+        self.assertIs(SQLContext.getOrCreate(sc).sparkSession, spark)
+        try:
+            df = spark.createDataFrame([(1, 2)], ['c', 'c'])
+            df.collect()
+        finally:
+            spark.stop()
+            self.assertIsNone(SQLContext._instantiatedContext)
+            sc.stop()
+
+    def test_sqlcontext_with_stopped_sparkcontext(self):
+        # SPARK-30856: test initialization via SparkSession when only the SparkContext is stopped
+        self.sc.stop()
+        self.sc = SparkContext('local[4]', self.sc.appName)
+        self.spark = SparkSession(self.sc)
+        self.assertIs(SQLContext.getOrCreate(self.sc).sparkSession, self.spark)
+
+    def test_get_sqlcontext_with_stopped_sparkcontext(self):
+        # SPARK-30856: test initialization via SQLContext.getOrCreate() when only the SparkContext
+        # is stopped
+        self.sc.stop()
+        self.sc = SparkContext('local[4]', self.sc.appName)
+        self.assertIs(SQLContext.getOrCreate(self.sc)._sc, self.sc)
+
+
 class SparkSessionBuilderTests(unittest.TestCase):

     def test_create_spark_context_first_then_spark_session(self):
```