[SPARK-30856][SQL][PYSPARK] Fix SQLContext.getOrCreate() when SparkContext is restarted
### What changes were proposed in this pull request?

As discussed on the Jira ticket, this change clears the `SQLContext._instantiatedContext` class attribute when the SparkSession is stopped. That way, the attribute will be reset with a new, usable SQLContext when a new SparkSession is started.

### Why are the changes needed?

When the underlying SQLContext is instantiated for a SparkSession, the instance is saved as a class attribute and returned from subsequent calls to `SQLContext.getOrCreate()`. If the SparkContext is stopped and a new one started, the SQLContext class attribute is never cleared, so any code which calls `SQLContext.getOrCreate()` will get a SQLContext with a reference to the old, unusable SparkContext.

A similar issue was identified and fixed for SparkSession in [SPARK-19055](https://issues.apache.org/jira/browse/SPARK-19055), but the fix did not change SQLContext as well. I ran into this because mllib still [uses](https://github.com/apache/spark/blob/master/python/pyspark/mllib/common.py#L105) `SQLContext.getOrCreate()` under the hood.

### Does this PR introduce any user-facing change?

No

### How was this patch tested?

A new test was added. I verified that the test fails without the included change.

Closes #27610 from afavaro/restart-sqlcontext.

Authored-by: Alex Favaro <alex.favaro@affirm.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
This commit is contained in:
parent
e32411eb07
commit
96c1a4401d
|
@ -87,7 +87,8 @@ class SQLContext(object):
|
|||
self._jsqlContext = jsqlContext
|
||||
_monkey_patch_RDD(self.sparkSession)
|
||||
install_exception_handler()
|
||||
if SQLContext._instantiatedContext is None:
|
||||
if (SQLContext._instantiatedContext is None
|
||||
or SQLContext._instantiatedContext._sc._jsc is None):
|
||||
SQLContext._instantiatedContext = self
|
||||
|
||||
@property
|
||||
|
@ -118,7 +119,8 @@ class SQLContext(object):
|
|||
"Deprecated in 3.0.0. Use SparkSession.builder.getOrCreate() instead.",
|
||||
DeprecationWarning)
|
||||
|
||||
if cls._instantiatedContext is None:
|
||||
if (cls._instantiatedContext is None
|
||||
or SQLContext._instantiatedContext._sc._jsc is None):
|
||||
jsqlContext = sc._jvm.SparkSession.builder().sparkContext(
|
||||
sc._jsc.sc()).getOrCreate().sqlContext()
|
||||
sparkSession = SparkSession(sc, jsqlContext.sparkSession())
|
||||
|
|
|
@ -699,12 +699,14 @@ class SparkSession(SparkConversionMixin):
|
|||
def stop(self):
|
||||
"""Stop the underlying :class:`SparkContext`.
|
||||
"""
|
||||
from pyspark.sql.context import SQLContext
|
||||
self._sc.stop()
|
||||
# We should clean the default session up. See SPARK-23228.
|
||||
self._jvm.SparkSession.clearDefaultSession()
|
||||
self._jvm.SparkSession.clearActiveSession()
|
||||
SparkSession._instantiatedSession = None
|
||||
SparkSession._activeSession = None
|
||||
SQLContext._instantiatedContext = None
|
||||
|
||||
@since(2.0)
|
||||
def __enter__(self):
|
||||
|
|
|
@ -270,7 +270,6 @@ class SQLContextTests(unittest.TestCase):
|
|||
sql_context = SQLContext.getOrCreate(sc)
|
||||
assert(isinstance(sql_context, SQLContext))
|
||||
finally:
|
||||
SQLContext._instantiatedContext = None
|
||||
if sql_context is not None:
|
||||
sql_context.sparkSession.stop()
|
||||
if sc is not None:
|
||||
|
|
|
@ -225,6 +225,52 @@ class SparkSessionTests4(ReusedSQLTestCase):
|
|||
session2.stop()
|
||||
|
||||
|
||||
class SparkSessionTests5(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# These tests require restarting the Spark context so we set up a new one for each test
|
||||
# rather than at the class level.
|
||||
self.sc = SparkContext('local[4]', self.__class__.__name__, conf=SparkConf())
|
||||
self.spark = SparkSession(self.sc)
|
||||
|
||||
def tearDown(self):
|
||||
self.sc.stop()
|
||||
self.spark.stop()
|
||||
|
||||
def test_sqlcontext_with_stopped_sparksession(self):
|
||||
# SPARK-30856: test that SQLContext.getOrCreate() returns a usable instance after
|
||||
# the SparkSession is restarted.
|
||||
sql_context = self.spark._wrapped
|
||||
self.spark.stop()
|
||||
sc = SparkContext('local[4]', self.sc.appName)
|
||||
spark = SparkSession(sc) # Instantiate the underlying SQLContext
|
||||
new_sql_context = spark._wrapped
|
||||
|
||||
self.assertIsNot(new_sql_context, sql_context)
|
||||
self.assertIs(SQLContext.getOrCreate(sc).sparkSession, spark)
|
||||
try:
|
||||
df = spark.createDataFrame([(1, 2)], ['c', 'c'])
|
||||
df.collect()
|
||||
finally:
|
||||
spark.stop()
|
||||
self.assertIsNone(SQLContext._instantiatedContext)
|
||||
sc.stop()
|
||||
|
||||
def test_sqlcontext_with_stopped_sparkcontext(self):
|
||||
# SPARK-30856: test initialization via SparkSession when only the SparkContext is stopped
|
||||
self.sc.stop()
|
||||
self.sc = SparkContext('local[4]', self.sc.appName)
|
||||
self.spark = SparkSession(self.sc)
|
||||
self.assertIs(SQLContext.getOrCreate(self.sc).sparkSession, self.spark)
|
||||
|
||||
def test_get_sqlcontext_with_stopped_sparkcontext(self):
|
||||
# SPARK-30856: test initialization via SQLContext.getOrCreate() when only the SparkContext
|
||||
# is stopped
|
||||
self.sc.stop()
|
||||
self.sc = SparkContext('local[4]', self.sc.appName)
|
||||
self.assertIs(SQLContext.getOrCreate(self.sc)._sc, self.sc)
|
||||
|
||||
|
||||
class SparkSessionBuilderTests(unittest.TestCase):
|
||||
|
||||
def test_create_spark_context_first_then_spark_session(self):
|
||||
|
|
Loading…
Reference in a new issue