diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 4a1d341f5d..0adbf1d96e 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -97,13 +97,26 @@ class SparkEnv ( object SparkEnv extends Logging { private val env = new ThreadLocal[SparkEnv] + @volatile private var lastSetSparkEnv : SparkEnv = _ def set(e: SparkEnv) { + lastSetSparkEnv = e env.set(e) } + /** + * Returns the ThreadLocal SparkEnv, if non-null. Else returns the SparkEnv + * previously set in any thread. + */ def get: SparkEnv = { - env.get() + Option(env.get()).getOrElse(lastSetSparkEnv) + } + + /** + * Returns the ThreadLocal SparkEnv. + */ + def getThreadLocal : SparkEnv = { + env.get() } def createFromSystemProperties(