SPARK-1480: Clean up use of classloaders

The Spark codebase is a bit fast-and-loose when accessing classloaders and this has caused a few bugs to surface in master.

This patch defines some utility methods for accessing classloaders. This makes the intention when accessing a classloader much more explicit in the code and fixes a few cases where the wrong one was chosen.

case (a) -> We want the classloader that loaded Spark
case (b) -> We want the context class loader, or if not present, we want (a)

This patch provides a better fix for SPARK-1403 (https://issues.apache.org/jira/browse/SPARK-1403) than the current work around, which it reverts. It also fixes a previously unreported bug that the `./spark-submit` script did not work for running with `local` master. It didn't work because the executor classloader did not properly delegate to the context class loader (if it is defined) and in local mode the context class loader is set by the `./spark-submit` script. A unit test is added for that case.

Author: Patrick Wendell <pwendell@gmail.com>

Closes #398 from pwendell/class-loaders and squashes the following commits:

b4a1a58 [Patrick Wendell] Minor clean up
14f1272 [Patrick Wendell] SPARK-1480: Clean up use of classloaders
This commit is contained in:
Patrick Wendell 2014-04-13 08:58:37 -07:00
parent ca11919e6e
commit 4bc07eebbf
15 changed files with 78 additions and 35 deletions

View file

@ -22,6 +22,7 @@ import org.slf4j.{Logger, LoggerFactory}
import org.slf4j.impl.StaticLoggerBinder
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils
/**
* :: DeveloperApi ::
@ -115,8 +116,7 @@ trait Logging {
val log4jInitialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
if (!log4jInitialized && usingLog4j) {
val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
val classLoader = this.getClass.getClassLoader
Option(classLoader.getResource(defaultLogProps)) match {
Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
case Some(url) =>
PropertyConfigurator.configure(url)
log.info(s"Using Spark's default log4j profile: $defaultLogProps")

View file

@ -292,7 +292,7 @@ private[spark] class Executor(
* created by the interpreter to the search path
*/
private def createClassLoader(): MutableURLClassLoader = {
val loader = this.getClass.getClassLoader
val currentLoader = Utils.getContextOrSparkClassLoader
// For each of the jars in the jarSet, add them to the class loader.
// We assume each of the files has already been fetched.
@ -301,8 +301,8 @@ private[spark] class Executor(
}.toArray
val userClassPathFirst = conf.getBoolean("spark.files.userClassPathFirst", false)
userClassPathFirst match {
case true => new ChildExecutorURLClassLoader(urls, loader)
case false => new ExecutorURLClassLoader(urls, loader)
case true => new ChildExecutorURLClassLoader(urls, currentLoader)
case false => new ExecutorURLClassLoader(urls, currentLoader)
}
}

View file

@ -50,21 +50,13 @@ private[spark] class MesosExecutorBackend
executorInfo: ExecutorInfo,
frameworkInfo: FrameworkInfo,
slaveInfo: SlaveInfo) {
val cl = Thread.currentThread.getContextClassLoader
try {
// Work around for SPARK-1480
Thread.currentThread.setContextClassLoader(getClass.getClassLoader)
logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue)
this.driver = driver
val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
executor = new Executor(
executorInfo.getExecutorId.getValue,
slaveInfo.getHostname,
properties)
} finally {
// Work around for SPARK-1480
Thread.currentThread.setContextClassLoader(cl)
}
logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue)
this.driver = driver
val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
executor = new Executor(
executorInfo.getExecutorId.getValue,
slaveInfo.getHostname,
properties)
}
override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) {

View file

@ -24,6 +24,7 @@ import scala.collection.mutable
import scala.util.matching.Regex
import org.apache.spark.Logging
import org.apache.spark.util.Utils
private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging {
@ -50,7 +51,7 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi
try {
is = configFile match {
case Some(f) => new FileInputStream(f)
case None => getClass.getClassLoader.getResourceAsStream(METRICS_CONF)
case None => Utils.getSparkClassLoader.getResourceAsStream(METRICS_CONF)
}
if (is != null) {

View file

@ -54,7 +54,6 @@ private[spark] object ResultTask {
def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) =
{
val loader = Thread.currentThread.getContextClassLoader
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in)

View file

@ -23,6 +23,7 @@ import java.util.{NoSuchElementException, Properties}
import scala.xml.XML
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.util.Utils
/**
* An interface to build Schedulable tree
@ -72,7 +73,7 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf)
schedulerAllocFile.map { f =>
new FileInputStream(f)
}.getOrElse {
getClass.getClassLoader.getResourceAsStream(DEFAULT_SCHEDULER_FILE)
Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_SCHEDULER_FILE)
}
}

View file

@ -85,13 +85,13 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
try {
if (serializedData != null && serializedData.limit() > 0) {
reason = serializer.get().deserialize[TaskEndReason](
serializedData, getClass.getClassLoader)
serializedData, Utils.getSparkClassLoader)
}
} catch {
case cnd: ClassNotFoundException =>
// Log an error but keep going here -- the task failed, so not catastropic if we can't
// deserialize the reason.
val loader = Thread.currentThread.getContextClassLoader
val loader = Utils.getContextOrSparkClassLoader
logError(
"Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
case ex: Throwable => {}

View file

@ -23,6 +23,7 @@ import java.nio.ByteBuffer
import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.ByteBufferInputStream
import org.apache.spark.util.Utils
private[spark] class JavaSerializationStream(out: OutputStream, counterReset: Int)
extends SerializationStream {
@ -86,7 +87,7 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize
}
def deserializeStream(s: InputStream): DeserializationStream = {
new JavaDeserializationStream(s, Thread.currentThread.getContextClassLoader)
new JavaDeserializationStream(s, Utils.getContextOrSparkClassLoader)
}
def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = {

View file

@ -33,6 +33,7 @@ import org.json4s.JValue
import org.json4s.jackson.JsonMethods.{pretty, render}
import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.util.Utils
/**
* Utilities for launching a web server using Jetty's HTTP Server class
@ -124,7 +125,7 @@ private[spark] object JettyUtils extends Logging {
contextHandler.setInitParameter("org.eclipse.jetty.servlet.Default.gzip", "false")
val staticHandler = new DefaultServlet
val holder = new ServletHolder(staticHandler)
Option(getClass.getClassLoader.getResource(resourceBase)) match {
Option(Utils.getSparkClassLoader.getResource(resourceBase)) match {
case Some(res) =>
holder.setInitParameter("resourceBase", res.toString)
case None =>

View file

@ -116,6 +116,21 @@ private[spark] object Utils extends Logging {
}
}
/**
* Get the ClassLoader which loaded Spark.
*/
def getSparkClassLoader = getClass.getClassLoader
/**
* Get the Context ClassLoader on this thread or, if not present, the ClassLoader that
* loaded Spark.
*
* This should be used whenever passing a ClassLoader to Class.ForName or finding the currently
* active loader when setting up ClassLoader delegation chains.
*/
def getContextOrSparkClassLoader =
Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader)
/**
* Primitive often used when writing {@link java.nio.ByteBuffer} to {@link java.io.DataOutput}.
*/

View file

@ -17,12 +17,12 @@
package org.apache.spark.executor
import java.io.File
import java.net.URLClassLoader
import org.scalatest.FunSuite
import org.apache.spark.TestUtils
import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, TestUtils}
import org.apache.spark.util.Utils
class ExecutorURLClassLoaderSuite extends FunSuite {
@ -63,5 +63,33 @@ class ExecutorURLClassLoaderSuite extends FunSuite {
}
}
test("driver sets context class loader in local mode") {
// Test the case where the driver program sets a context classloader and then runs a job
// in local mode. This is what happens when ./spark-submit is called with "local" as the
// master.
val original = Thread.currentThread().getContextClassLoader
val className = "ClassForDriverTest"
val jar = TestUtils.createJarWithClasses(Seq(className))
val contextLoader = new URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader)
Thread.currentThread().setContextClassLoader(contextLoader)
val sc = new SparkContext("local", "driverLoaderTest")
try {
sc.makeRDD(1 to 5, 2).mapPartitions { x =>
val loader = Thread.currentThread().getContextClassLoader
Class.forName(className, true, loader).newInstance()
Seq().iterator
}.count()
}
catch {
case e: SparkException if e.getMessage.contains("ClassNotFoundException") =>
fail("Local executor could not find class", e)
case t: Throwable => fail("Unexpected exception ", t)
}
sc.stop()
Thread.currentThread().setContextClassLoader(original)
}
}

View file

@ -39,6 +39,7 @@ import scala.reflect.api.{Mirror, TypeCreator, Universe => ApiUniverse}
import org.apache.spark.Logging
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.util.Utils
/** The Scala interactive shell. It provides a read-eval-print loop
* around the Interpreter class.
@ -130,7 +131,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
def history = in.history
/** The context class loader at the time this object was created */
protected val originalClassLoader = Thread.currentThread.getContextClassLoader
protected val originalClassLoader = Utils.getContextOrSparkClassLoader
// classpath entries added via :cp
var addedClasspath: String = ""
@ -177,7 +178,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
override lazy val formatting = new Formatting {
def prompt = SparkILoop.this.prompt
}
override protected def parentClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse(classOf[SparkILoop].getClassLoader)
override protected def parentClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse(classOf[SparkILoop].getClassLoader)
}
/** Create a new interpreter. */
@ -871,7 +872,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
}
val u: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe
val m = u.runtimeMirror(getClass.getClassLoader)
val m = u.runtimeMirror(Utils.getSparkClassLoader)
private def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] =
u.TypeTag[T](
m,

View file

@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst
import java.io.{PrintWriter, ByteArrayOutputStream, FileInputStream, File}
import org.apache.spark.util.{Utils => SparkUtils}
package object util {
/**
* Returns a path to a temporary file that probably does not exist.
@ -54,7 +56,7 @@ package object util {
def resourceToString(
resource:String,
encoding: String = "UTF-8",
classLoader: ClassLoader = this.getClass.getClassLoader) = {
classLoader: ClassLoader = SparkUtils.getSparkClassLoader) = {
val inStream = classLoader.getResourceAsStream(resource)
val outStream = new ByteArrayOutputStream
try {

View file

@ -26,6 +26,7 @@ import scala.reflect.runtime.universe.runtimeMirror
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.columnar._
import org.apache.spark.util.Utils
private[sql] case object PassThrough extends CompressionScheme {
override val typeId = 0
@ -254,7 +255,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
private val dictionary = {
// TODO Can we clean up this mess? Maybe move this to `DataType`?
implicit val classTag = {
val mirror = runtimeMirror(getClass.getClassLoader)
val mirror = runtimeMirror(Utils.getSparkClassLoader)
ClassTag[T#JvmType](mirror.runtimeClass(columnType.scalaTag.tpe))
}

View file

@ -25,6 +25,7 @@ import com.esotericsoftware.kryo.{Serializer, Kryo}
import org.apache.spark.{SparkEnv, SparkConf}
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.util.MutablePair
import org.apache.spark.util.Utils
class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
override def newKryo(): Kryo = {
@ -44,7 +45,7 @@ class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
kryo.setReferences(false)
kryo.setClassLoader(this.getClass.getClassLoader)
kryo.setClassLoader(Utils.getSparkClassLoader)
kryo
}
}