SPARK-1480: Clean up use of classloaders
The Spark codebase is a bit fast-and-loose when accessing classloaders and this has caused a few bugs to surface in master. This patch defines some utility methods for accessing classloaders. This makes the intention when accessing a classloader much more explicit in the code and fixes a few cases where the wrong one was chosen. case (a) -> We want the classloader that loaded Spark case (b) -> We want the context class loader, or if not present, we want (a) This patch provides a better fix for SPARK-1403 (https://issues.apache.org/jira/browse/SPARK-1403) than the current work around, which it reverts. It also fixes a previously unreported bug that the `./spark-submit` script did not work for running with `local` master. It didn't work because the executor classloader did not properly delegate to the context class loader (if it is defined) and in local mode the context class loader is set by the `./spark-submit` script. A unit test is added for that case. Author: Patrick Wendell <pwendell@gmail.com> Closes #398 from pwendell/class-loaders and squashes the following commits: b4a1a58 [Patrick Wendell] Minor clean up 14f1272 [Patrick Wendell] SPARK-1480: Clean up use of classloaders
This commit is contained in:
parent
ca11919e6e
commit
4bc07eebbf
|
@ -22,6 +22,7 @@ import org.slf4j.{Logger, LoggerFactory}
|
|||
import org.slf4j.impl.StaticLoggerBinder
|
||||
|
||||
import org.apache.spark.annotation.DeveloperApi
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
/**
|
||||
* :: DeveloperApi ::
|
||||
|
@ -115,8 +116,7 @@ trait Logging {
|
|||
val log4jInitialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
|
||||
if (!log4jInitialized && usingLog4j) {
|
||||
val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
|
||||
val classLoader = this.getClass.getClassLoader
|
||||
Option(classLoader.getResource(defaultLogProps)) match {
|
||||
Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
|
||||
case Some(url) =>
|
||||
PropertyConfigurator.configure(url)
|
||||
log.info(s"Using Spark's default log4j profile: $defaultLogProps")
|
||||
|
|
|
@ -292,7 +292,7 @@ private[spark] class Executor(
|
|||
* created by the interpreter to the search path
|
||||
*/
|
||||
private def createClassLoader(): MutableURLClassLoader = {
|
||||
val loader = this.getClass.getClassLoader
|
||||
val currentLoader = Utils.getContextOrSparkClassLoader
|
||||
|
||||
// For each of the jars in the jarSet, add them to the class loader.
|
||||
// We assume each of the files has already been fetched.
|
||||
|
@ -301,8 +301,8 @@ private[spark] class Executor(
|
|||
}.toArray
|
||||
val userClassPathFirst = conf.getBoolean("spark.files.userClassPathFirst", false)
|
||||
userClassPathFirst match {
|
||||
case true => new ChildExecutorURLClassLoader(urls, loader)
|
||||
case false => new ExecutorURLClassLoader(urls, loader)
|
||||
case true => new ChildExecutorURLClassLoader(urls, currentLoader)
|
||||
case false => new ExecutorURLClassLoader(urls, currentLoader)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -50,21 +50,13 @@ private[spark] class MesosExecutorBackend
|
|||
executorInfo: ExecutorInfo,
|
||||
frameworkInfo: FrameworkInfo,
|
||||
slaveInfo: SlaveInfo) {
|
||||
val cl = Thread.currentThread.getContextClassLoader
|
||||
try {
|
||||
// Work around for SPARK-1480
|
||||
Thread.currentThread.setContextClassLoader(getClass.getClassLoader)
|
||||
logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue)
|
||||
this.driver = driver
|
||||
val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
|
||||
executor = new Executor(
|
||||
executorInfo.getExecutorId.getValue,
|
||||
slaveInfo.getHostname,
|
||||
properties)
|
||||
} finally {
|
||||
// Work around for SPARK-1480
|
||||
Thread.currentThread.setContextClassLoader(cl)
|
||||
}
|
||||
logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue)
|
||||
this.driver = driver
|
||||
val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
|
||||
executor = new Executor(
|
||||
executorInfo.getExecutorId.getValue,
|
||||
slaveInfo.getHostname,
|
||||
properties)
|
||||
}
|
||||
|
||||
override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) {
|
||||
|
|
|
@ -24,6 +24,7 @@ import scala.collection.mutable
|
|||
import scala.util.matching.Regex
|
||||
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging {
|
||||
|
||||
|
@ -50,7 +51,7 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi
|
|||
try {
|
||||
is = configFile match {
|
||||
case Some(f) => new FileInputStream(f)
|
||||
case None => getClass.getClassLoader.getResourceAsStream(METRICS_CONF)
|
||||
case None => Utils.getSparkClassLoader.getResourceAsStream(METRICS_CONF)
|
||||
}
|
||||
|
||||
if (is != null) {
|
||||
|
|
|
@ -54,7 +54,6 @@ private[spark] object ResultTask {
|
|||
|
||||
def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) =
|
||||
{
|
||||
val loader = Thread.currentThread.getContextClassLoader
|
||||
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
|
||||
val ser = SparkEnv.get.closureSerializer.newInstance()
|
||||
val objIn = ser.deserializeStream(in)
|
||||
|
|
|
@ -23,6 +23,7 @@ import java.util.{NoSuchElementException, Properties}
|
|||
import scala.xml.XML
|
||||
|
||||
import org.apache.spark.{Logging, SparkConf}
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
/**
|
||||
* An interface to build Schedulable tree
|
||||
|
@ -72,7 +73,7 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf)
|
|||
schedulerAllocFile.map { f =>
|
||||
new FileInputStream(f)
|
||||
}.getOrElse {
|
||||
getClass.getClassLoader.getResourceAsStream(DEFAULT_SCHEDULER_FILE)
|
||||
Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_SCHEDULER_FILE)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -85,13 +85,13 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
|
|||
try {
|
||||
if (serializedData != null && serializedData.limit() > 0) {
|
||||
reason = serializer.get().deserialize[TaskEndReason](
|
||||
serializedData, getClass.getClassLoader)
|
||||
serializedData, Utils.getSparkClassLoader)
|
||||
}
|
||||
} catch {
|
||||
case cnd: ClassNotFoundException =>
|
||||
// Log an error but keep going here -- the task failed, so not catastropic if we can't
|
||||
// deserialize the reason.
|
||||
val loader = Thread.currentThread.getContextClassLoader
|
||||
val loader = Utils.getContextOrSparkClassLoader
|
||||
logError(
|
||||
"Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
|
||||
case ex: Throwable => {}
|
||||
|
|
|
@ -23,6 +23,7 @@ import java.nio.ByteBuffer
|
|||
import org.apache.spark.SparkConf
|
||||
import org.apache.spark.annotation.DeveloperApi
|
||||
import org.apache.spark.util.ByteBufferInputStream
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
private[spark] class JavaSerializationStream(out: OutputStream, counterReset: Int)
|
||||
extends SerializationStream {
|
||||
|
@ -86,7 +87,7 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize
|
|||
}
|
||||
|
||||
def deserializeStream(s: InputStream): DeserializationStream = {
|
||||
new JavaDeserializationStream(s, Thread.currentThread.getContextClassLoader)
|
||||
new JavaDeserializationStream(s, Utils.getContextOrSparkClassLoader)
|
||||
}
|
||||
|
||||
def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = {
|
||||
|
|
|
@ -33,6 +33,7 @@ import org.json4s.JValue
|
|||
import org.json4s.jackson.JsonMethods.{pretty, render}
|
||||
|
||||
import org.apache.spark.{Logging, SecurityManager, SparkConf}
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
/**
|
||||
* Utilities for launching a web server using Jetty's HTTP Server class
|
||||
|
@ -124,7 +125,7 @@ private[spark] object JettyUtils extends Logging {
|
|||
contextHandler.setInitParameter("org.eclipse.jetty.servlet.Default.gzip", "false")
|
||||
val staticHandler = new DefaultServlet
|
||||
val holder = new ServletHolder(staticHandler)
|
||||
Option(getClass.getClassLoader.getResource(resourceBase)) match {
|
||||
Option(Utils.getSparkClassLoader.getResource(resourceBase)) match {
|
||||
case Some(res) =>
|
||||
holder.setInitParameter("resourceBase", res.toString)
|
||||
case None =>
|
||||
|
|
|
@ -116,6 +116,21 @@ private[spark] object Utils extends Logging {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the ClassLoader which loaded Spark.
|
||||
*/
|
||||
def getSparkClassLoader = getClass.getClassLoader
|
||||
|
||||
/**
|
||||
* Get the Context ClassLoader on this thread or, if not present, the ClassLoader that
|
||||
* loaded Spark.
|
||||
*
|
||||
* This should be used whenever passing a ClassLoader to Class.ForName or finding the currently
|
||||
* active loader when setting up ClassLoader delegation chains.
|
||||
*/
|
||||
def getContextOrSparkClassLoader =
|
||||
Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader)
|
||||
|
||||
/**
|
||||
* Primitive often used when writing {@link java.nio.ByteBuffer} to {@link java.io.DataOutput}.
|
||||
*/
|
||||
|
|
|
@ -17,12 +17,12 @@
|
|||
|
||||
package org.apache.spark.executor
|
||||
|
||||
import java.io.File
|
||||
import java.net.URLClassLoader
|
||||
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
import org.apache.spark.TestUtils
|
||||
import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, TestUtils}
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
class ExecutorURLClassLoaderSuite extends FunSuite {
|
||||
|
||||
|
@ -63,5 +63,33 @@ class ExecutorURLClassLoaderSuite extends FunSuite {
|
|||
}
|
||||
}
|
||||
|
||||
test("driver sets context class loader in local mode") {
|
||||
// Test the case where the driver program sets a context classloader and then runs a job
|
||||
// in local mode. This is what happens when ./spark-submit is called with "local" as the
|
||||
// master.
|
||||
val original = Thread.currentThread().getContextClassLoader
|
||||
|
||||
val className = "ClassForDriverTest"
|
||||
val jar = TestUtils.createJarWithClasses(Seq(className))
|
||||
val contextLoader = new URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader)
|
||||
Thread.currentThread().setContextClassLoader(contextLoader)
|
||||
|
||||
val sc = new SparkContext("local", "driverLoaderTest")
|
||||
|
||||
try {
|
||||
sc.makeRDD(1 to 5, 2).mapPartitions { x =>
|
||||
val loader = Thread.currentThread().getContextClassLoader
|
||||
Class.forName(className, true, loader).newInstance()
|
||||
Seq().iterator
|
||||
}.count()
|
||||
}
|
||||
catch {
|
||||
case e: SparkException if e.getMessage.contains("ClassNotFoundException") =>
|
||||
fail("Local executor could not find class", e)
|
||||
case t: Throwable => fail("Unexpected exception ", t)
|
||||
}
|
||||
|
||||
sc.stop()
|
||||
Thread.currentThread().setContextClassLoader(original)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -39,6 +39,7 @@ import scala.reflect.api.{Mirror, TypeCreator, Universe => ApiUniverse}
|
|||
import org.apache.spark.Logging
|
||||
import org.apache.spark.SparkConf
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
/** The Scala interactive shell. It provides a read-eval-print loop
|
||||
* around the Interpreter class.
|
||||
|
@ -130,7 +131,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
|
|||
def history = in.history
|
||||
|
||||
/** The context class loader at the time this object was created */
|
||||
protected val originalClassLoader = Thread.currentThread.getContextClassLoader
|
||||
protected val originalClassLoader = Utils.getContextOrSparkClassLoader
|
||||
|
||||
// classpath entries added via :cp
|
||||
var addedClasspath: String = ""
|
||||
|
@ -177,7 +178,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
|
|||
override lazy val formatting = new Formatting {
|
||||
def prompt = SparkILoop.this.prompt
|
||||
}
|
||||
override protected def parentClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse(classOf[SparkILoop].getClassLoader)
|
||||
override protected def parentClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse(classOf[SparkILoop].getClassLoader)
|
||||
}
|
||||
|
||||
/** Create a new interpreter. */
|
||||
|
@ -871,7 +872,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
|
|||
}
|
||||
|
||||
val u: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe
|
||||
val m = u.runtimeMirror(getClass.getClassLoader)
|
||||
val m = u.runtimeMirror(Utils.getSparkClassLoader)
|
||||
private def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] =
|
||||
u.TypeTag[T](
|
||||
m,
|
||||
|
|
|
@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst
|
|||
|
||||
import java.io.{PrintWriter, ByteArrayOutputStream, FileInputStream, File}
|
||||
|
||||
import org.apache.spark.util.{Utils => SparkUtils}
|
||||
|
||||
package object util {
|
||||
/**
|
||||
* Returns a path to a temporary file that probably does not exist.
|
||||
|
@ -54,7 +56,7 @@ package object util {
|
|||
def resourceToString(
|
||||
resource:String,
|
||||
encoding: String = "UTF-8",
|
||||
classLoader: ClassLoader = this.getClass.getClassLoader) = {
|
||||
classLoader: ClassLoader = SparkUtils.getSparkClassLoader) = {
|
||||
val inStream = classLoader.getResourceAsStream(resource)
|
||||
val outStream = new ByteArrayOutputStream
|
||||
try {
|
||||
|
|
|
@ -26,6 +26,7 @@ import scala.reflect.runtime.universe.runtimeMirror
|
|||
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
|
||||
import org.apache.spark.sql.catalyst.types._
|
||||
import org.apache.spark.sql.columnar._
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
private[sql] case object PassThrough extends CompressionScheme {
|
||||
override val typeId = 0
|
||||
|
@ -254,7 +255,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
|
|||
private val dictionary = {
|
||||
// TODO Can we clean up this mess? Maybe move this to `DataType`?
|
||||
implicit val classTag = {
|
||||
val mirror = runtimeMirror(getClass.getClassLoader)
|
||||
val mirror = runtimeMirror(Utils.getSparkClassLoader)
|
||||
ClassTag[T#JvmType](mirror.runtimeClass(columnType.scalaTag.tpe))
|
||||
}
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@ import com.esotericsoftware.kryo.{Serializer, Kryo}
|
|||
import org.apache.spark.{SparkEnv, SparkConf}
|
||||
import org.apache.spark.serializer.KryoSerializer
|
||||
import org.apache.spark.util.MutablePair
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
|
||||
override def newKryo(): Kryo = {
|
||||
|
@ -44,7 +45,7 @@ class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
|
|||
kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
|
||||
kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
|
||||
kryo.setReferences(false)
|
||||
kryo.setClassLoader(this.getClass.getClassLoader)
|
||||
kryo.setClassLoader(Utils.getSparkClassLoader)
|
||||
kryo
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue