[SPARK-11563][CORE][REPL] Use RpcEnv to transfer REPL-generated classes.

This avoids bringing up yet another HTTP server on the driver, and
instead reuses the file server already managed by the driver's
RpcEnv. As a bonus, the repl now inherits the security features of
the network library.

There's also a small change to create the directory for storing classes
under the root temp dir for the application (instead of directly
under java.io.tmpdir).

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #9923 from vanzin/SPARK-11563.
Marcelo Vanzin 2015-12-10 13:26:30 -08:00
parent 2ecbe02d5b
commit 4a46b8859d
15 changed files with 183 additions and 98 deletions
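Editor's note: before the per-file hunks, here is a minimal sketch of the new end-to-end flow, assembled only from code that appears in this commit; `conf`, `env`, `outputDir`, and `classUri` stand in for the real fields on the driver and executor sides.

```scala
// Driver: the REPL writes compiled classes to a local directory and advertises
// it via SparkConf instead of starting a dedicated Jetty class server.
conf.set("spark.repl.class.outputDir", outputDir.getAbsolutePath())

// SparkContext, during initialization, registers that directory with the
// RpcEnv's file server and publishes the resulting spark:// URI.
conf.getOption("spark.repl.class.outputDir").foreach { path =>
  val replUri = env.rpcEnv.fileServer.addDirectory("/classes", new java.io.File(path))
  conf.set("spark.repl.class.uri", replUri) // e.g. spark://driver:port/classes
}

// Executor: ExecutorClassLoader sees the "spark" scheme and streams each
// class file over the existing RPC channel, inheriting its security settings.
val channel = env.rpcEnv.openChannel(s"$classUri/com/example/Foo.class")
val in = java.nio.channels.Channels.newInputStream(channel)
```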


@@ -71,6 +71,11 @@ private[spark] class HttpFileServer(
     serverUri + "/jars/" + file.getName
   }

+  def addDirectory(path: String, resourceBase: String): String = {
+    httpServer.addDirectory(path, resourceBase)
+    serverUri + path
+  }
+
   def addFileToDir(file: File, dir: File) : String = {
     // Check whether the file is a directory. If it is, throw a more meaningful exception.
     // If we don't catch this, Guava throws a very confusing error message:


@@ -23,10 +23,9 @@ import org.eclipse.jetty.server.ssl.SslSocketConnector
 import org.eclipse.jetty.util.security.{Constraint, Password}
 import org.eclipse.jetty.security.authentication.DigestAuthenticator
 import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService}
 import org.eclipse.jetty.server.Server
 import org.eclipse.jetty.server.bio.SocketConnector
-import org.eclipse.jetty.server.handler.{DefaultHandler, HandlerList, ResourceHandler}
+import org.eclipse.jetty.servlet.{DefaultServlet, ServletContextHandler, ServletHolder}
 import org.eclipse.jetty.util.thread.QueuedThreadPool

 import org.apache.spark.util.Utils
@@ -52,6 +51,11 @@ private[spark] class HttpServer(
   private var server: Server = null
   private var port: Int = requestedPort
+  private val servlets = {
+    val handler = new ServletContextHandler()
+    handler.setContextPath("/")
+    handler
+  }

   def start() {
     if (server != null) {
@@ -65,6 +69,14 @@ private[spark] class HttpServer(
     }
   }

+  def addDirectory(contextPath: String, resourceBase: String): Unit = {
+    val holder = new ServletHolder()
+    holder.setInitParameter("resourceBase", resourceBase)
+    holder.setInitParameter("pathInfoOnly", "true")
+    holder.setServlet(new DefaultServlet())
+    servlets.addServlet(holder, contextPath.stripSuffix("/") + "/*")
+  }
+
   /**
    * Actually start the HTTP server on the given port.
    *
@@ -85,21 +97,17 @@ private[spark] class HttpServer(
     val threadPool = new QueuedThreadPool
     threadPool.setDaemon(true)
     server.setThreadPool(threadPool)
-    val resHandler = new ResourceHandler
-    resHandler.setResourceBase(resourceBase.getAbsolutePath)
-    val handlerList = new HandlerList
-    handlerList.setHandlers(Array(resHandler, new DefaultHandler))
+    addDirectory("/", resourceBase.getAbsolutePath)

     if (securityManager.isAuthenticationEnabled()) {
       logDebug("HttpServer is using security")
       val sh = setupSecurityHandler(securityManager)
       // make sure we go through security handler to get resources
-      sh.setHandler(handlerList)
+      sh.setHandler(servlets)
       server.setHandler(sh)
     } else {
       logDebug("HttpServer is not using security")
-      server.setHandler(handlerList)
+      server.setHandler(servlets)
     }

     server.start()

@@ -457,6 +457,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
     _env = createSparkEnv(_conf, isLocal, listenerBus)
     SparkEnv.set(_env)

+    // If running the REPL, register the repl's output dir with the file server.
+    _conf.getOption("spark.repl.class.outputDir").foreach { path =>
+      val replUri = _env.rpcEnv.fileServer.addDirectory("/classes", new File(path))
+      _conf.set("spark.repl.class.uri", replUri)
+    }
+
     _metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, _conf)
     _statusTracker = new SparkStatusTracker(this)


@@ -364,9 +364,9 @@ private[spark] class Executor(
         val _userClassPathFirst: java.lang.Boolean = userClassPathFirst
         val klass = Utils.classForName("org.apache.spark.repl.ExecutorClassLoader")
           .asInstanceOf[Class[_ <: ClassLoader]]
-        val constructor = klass.getConstructor(classOf[SparkConf], classOf[String],
-          classOf[ClassLoader], classOf[Boolean])
-        constructor.newInstance(conf, classUri, parent, _userClassPathFirst)
+        val constructor = klass.getConstructor(classOf[SparkConf], classOf[SparkEnv],
+          classOf[String], classOf[ClassLoader], classOf[Boolean])
+        constructor.newInstance(conf, env, classUri, parent, _userClassPathFirst)
       } catch {
         case _: ClassNotFoundException =>
           logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!")


@@ -179,6 +179,24 @@ private[spark] trait RpcEnvFileServer {
    */
   def addJar(file: File): String

+  /**
+   * Adds a local directory to be served via this file server.
+   *
+   * @param baseUri Leading URI path (files can be retrieved by appending their relative
+   *                path to this base URI). This cannot be "files" nor "jars".
+   * @param path Path to the local directory.
+   * @return URI for the root of the directory in the file server.
+   */
+  def addDirectory(baseUri: String, path: File): String
+
+  /** Validates and normalizes the base URI for directories. */
+  protected def validateDirectoryUri(baseUri: String): String = {
+    val fixedBaseUri = "/" + baseUri.stripPrefix("/").stripSuffix("/")
+    require(fixedBaseUri != "/files" && fixedBaseUri != "/jars",
+      "Directory URI cannot be /files nor /jars.")
+    fixedBaseUri
+  }
+
 }

 private[spark] case class RpcEnvConfig(

@@ -273,6 +273,11 @@ private[akka] class AkkaFileServer(
     getFileServer().addJar(file)
   }

+  override def addDirectory(baseUri: String, path: File): String = {
+    val fixedBaseUri = validateDirectoryUri(baseUri)
+    getFileServer().addDirectory(fixedBaseUri, path.getAbsolutePath())
+  }
+
   def shutdown(): Unit = {
     if (httpFileServer != null) {
       httpFileServer.stop()


@@ -25,12 +25,22 @@ import org.apache.spark.rpc.RpcEnvFileServer

 /**
  * StreamManager implementation for serving files from a NettyRpcEnv.
+ *
+ * Three kinds of resources can be registered in this manager, all backed by actual files:
+ *
+ *  - "/files": a flat list of files; used as the backend for [[SparkContext.addFile]].
+ *  - "/jars": a flat list of files; used as the backend for [[SparkContext.addJar]].
+ *  - arbitrary directories; all files under the directory become available through the
+ *    manager, respecting the directory's hierarchy.
+ *
+ * Only streaming (openStream) is supported.
  */
 private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv)
   extends StreamManager with RpcEnvFileServer {

   private val files = new ConcurrentHashMap[String, File]()
   private val jars = new ConcurrentHashMap[String, File]()
+  private val dirs = new ConcurrentHashMap[String, File]()

   override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = {
     throw new UnsupportedOperationException()
@@ -41,7 +51,10 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv)
     val file = ftype match {
       case "files" => files.get(fname)
       case "jars" => jars.get(fname)
-      case _ => throw new IllegalArgumentException(s"Invalid file type: $ftype")
+      case other =>
+        val dir = dirs.get(ftype)
+        require(dir != null, s"Invalid stream URI: $ftype not found.")
+        new File(dir, fname)
     }

     require(file != null && file.isFile(), s"File not found: $streamId")
@@ -60,4 +73,11 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv)
     s"${rpcEnv.address.toSparkURL}/jars/${file.getName()}"
   }

+  override def addDirectory(baseUri: String, path: File): String = {
+    val fixedBaseUri = validateDirectoryUri(baseUri)
+    require(dirs.putIfAbsent(fixedBaseUri.stripPrefix("/"), path) == null,
+      s"URI '$fixedBaseUri' already registered.")
+    s"${rpcEnv.address.toSparkURL}$fixedBaseUri"
+  }
+
 }
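Editor's note: the openStream hunk above shows only the lookup; the split of the stream URI into `ftype` and `fname` happens earlier in the method and is not part of this diff, so the parsing below is an assumption about its behavior, shown only to make the mapping concrete:

```scala
// Hypothetical illustration: "classes/com/example/Foo.class" splits into the
// first path component (the registered root) and the remaining relative path.
def parseStreamId(streamId: String): (String, String) = {
  val Array(ftype, fname) = streamId.stripPrefix("/").split("/", 2)
  (ftype, fname)
}

// "files" and "jars" hit their flat maps; anything else is resolved against a
// registered directory, preserving the hierarchy:
//   spark://host:port/files/app.conf        -> files.get("app.conf")
//   spark://host:port/classes/com/Foo.class -> new File(dirs.get("classes"), "com/Foo.class")
```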


@@ -734,9 +734,28 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
     val jar = new File(tempDir, "jar")
     Files.write(UUID.randomUUID().toString(), jar, UTF_8)

+    val dir1 = new File(tempDir, "dir1")
+    assert(dir1.mkdir())
+    val subFile1 = new File(dir1, "file1")
+    Files.write(UUID.randomUUID().toString(), subFile1, UTF_8)
+
+    val dir2 = new File(tempDir, "dir2")
+    assert(dir2.mkdir())
+    val subFile2 = new File(dir2, "file2")
+    Files.write(UUID.randomUUID().toString(), subFile2, UTF_8)
+
     val fileUri = env.fileServer.addFile(file)
     val emptyUri = env.fileServer.addFile(empty)
     val jarUri = env.fileServer.addJar(jar)
+    val dir1Uri = env.fileServer.addDirectory("/dir1", dir1)
+    val dir2Uri = env.fileServer.addDirectory("/dir2", dir2)
+
+    // Try registering directories with invalid names.
+    Seq("/files", "/jars").foreach { uri =>
+      intercept[IllegalArgumentException] {
+        env.fileServer.addDirectory(uri, dir1)
+      }
+    }

     val destDir = Utils.createTempDir()
     val sm = new SecurityManager(conf)
@@ -745,7 +764,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
     val files = Seq(
       (file, fileUri),
       (empty, emptyUri),
-      (jar, jarUri))
+      (jar, jarUri),
+      (subFile1, dir1Uri + "/file1"),
+      (subFile2, dir2Uri + "/file2"))
     files.foreach { case (f, uri) =>
       val destFile = new File(destDir, f.getName())
       Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false)
@@ -753,7 +774,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
     }

     // Try to download files that do not exist.
-    Seq("files", "jars").foreach { root =>
+    Seq("files", "jars", "dir1").foreach { root =>
       intercept[Exception] {
         val uri = env.address.toSparkURL + s"/$root/doesNotExist"
         Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false)


@@ -1053,14 +1053,6 @@ Apart from these, the following properties are also available, and may be useful
     to port + maxRetries.
   </td>
 </tr>
-<tr>
-  <td><code>spark.replClassServer.port</code></td>
-  <td>(random)</td>
-  <td>
-    Port for the driver's HTTP class server to listen on.
-    This is only relevant for the Spark shell.
-  </td>
-</tr>
 <tr>
   <td><code>spark.rpc.numRetries</code></td>
   <td>3</td>


@@ -169,14 +169,6 @@ configure those ports.
     <td>Jetty-based. Not used by TorrentBroadcast, which sends data through the block manager
       instead.</td>
   </tr>
-  <tr>
-    <td>Executor</td>
-    <td>Driver</td>
-    <td>(random)</td>
-    <td>Class file server</td>
-    <td><code>spark.replClassServer.port</code></td>
-    <td>Jetty-based. Only used in Spark shells.</td>
-  </tr>
   <tr>
     <td>Executor / Driver</td>
     <td>Executor / Driver</td>


@@ -253,7 +253,7 @@ class SparkILoop(
       case xs => xs find (_.name == cmd)
     }
   }

   private var fallbackMode = false

   private def toggleFallbackMode() {
     val old = fallbackMode
@@ -261,9 +261,9 @@ class SparkILoop(
     System.setProperty("spark.repl.fallback", fallbackMode.toString)
     echo(s"""
       |Switched ${if (old) "off" else "on"} fallback mode without restarting.
       | If you have defined classes in the repl, it would
       |be good to redefine them incase you plan to use them. If you still run
       |into issues it would be good to restart the repl and turn on `:fallback`
       |mode as first command.
     """.stripMargin)
   }
@@ -350,7 +350,7 @@ class SparkILoop(
     shCommand,
     nullary("silent", "disable/enable automatic printing of results", verbosity),
     nullary("fallback", """
       |disable/enable advanced repl changes, these fix some issues but may introduce others.
       |This mode will be removed once these fixes stablize""".stripMargin, toggleFallbackMode),
     cmd("type", "[-v] <expr>", "display the type of an expression without evaluating it", typeCommand),
     nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand)
@@ -1009,8 +1009,13 @@ class SparkILoop(
     val conf = new SparkConf()
       .setMaster(getMaster())
       .setJars(jars)
-      .set("spark.repl.class.uri", intp.classServerUri)
       .setIfMissing("spark.app.name", "Spark shell")
+      // SparkContext will detect this configuration and register it with the RpcEnv's
+      // file server, setting spark.repl.class.uri to the actual URI for executors to
+      // use. This is sort of ugly but since executors are started as part of SparkContext
+      // initialization in certain cases, there's an initialization order issue that prevents
+      // this from being set after SparkContext is instantiated.
+      .set("spark.repl.class.outputDir", intp.outputDir.getAbsolutePath())
     if (execUri != null) {
       conf.set("spark.executor.uri", execUri)
     }
@@ -1025,7 +1030,7 @@ class SparkILoop(
     val loader = Utils.getContextOrSparkClassLoader
     try {
       sqlContext = loader.loadClass(name).getConstructor(classOf[SparkContext])
         .newInstance(sparkContext).asInstanceOf[SQLContext]
       logInfo("Created sql context (with Hive support)..")
     }
     catch {


@@ -37,7 +37,7 @@ import scala.reflect.{ ClassTag, classTag }
 import scala.tools.reflect.StdRuntimeTags._
 import scala.util.control.ControlThrowable

-import org.apache.spark.{Logging, HttpServer, SecurityManager, SparkConf}
+import org.apache.spark.{Logging, SparkConf, SparkContext}
 import org.apache.spark.util.Utils
 import org.apache.spark.annotation.DeveloperApi
@@ -96,10 +96,9 @@ import org.apache.spark.annotation.DeveloperApi
   private val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1")

   /** Local directory to save .class files too */
-  private lazy val outputDir = {
-    val tmp = System.getProperty("java.io.tmpdir")
-    val rootDir = conf.get("spark.repl.classdir", tmp)
-    Utils.createTempDir(rootDir)
+  private[repl] val outputDir = {
+    val rootDir = conf.getOption("spark.repl.classdir").getOrElse(Utils.getLocalDir(conf))
+    Utils.createTempDir(root = rootDir, namePrefix = "repl")
   }
   if (SPARK_DEBUG_REPL) {
     echo("Output directory: " + outputDir)
@@ -114,8 +113,6 @@ import org.apache.spark.annotation.DeveloperApi
   private val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles

   /** Jetty server that will serve our classes to worker nodes */
-  private val classServerPort = conf.getInt("spark.replClassServer.port", 0)
-  private val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf), classServerPort, "HTTP class server")
   private var currentSettings: Settings = initialSettings
   private var printResults = true // whether to print result lines
   private var totalSilence = false // whether to print anything
@@ -124,22 +121,6 @@ import org.apache.spark.annotation.DeveloperApi
   private var bindExceptions = true // whether to bind the lastException variable
   private var _executionWrapper = "" // code to be wrapped around all lines

-  // Start the classServer and store its URI in a spark system property
-  // (which will be passed to executors so that they can connect to it)
-  classServer.start()
-  if (SPARK_DEBUG_REPL) {
-    echo("Class server started, URI = " + classServer.uri)
-  }
-
-  /**
-   * URI of the class server used to feed REPL compiled classes.
-   *
-   * @return The string representing the class server uri
-   */
-  @DeveloperApi
-  def classServerUri = classServer.uri
-
   /** We're going to go to some trouble to initialize the compiler asynchronously.
    * It's critical that nothing call into it until it's been initialized or we will
    * run into unrecoverable issues, but the perceived repl startup time goes
@@ -994,7 +975,6 @@ import org.apache.spark.annotation.DeveloperApi
   @DeveloperApi
   def close() {
     reporter.flush()
-    classServer.stop()
   }

   /**
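Editor's note: the outputDir hunk above also implements the commit message's "small change": the REPL output directory now lives under Spark's per-application local directory rather than directly under java.io.tmpdir. A sketch of just that change, with `conf` standing in for the interpreter's SparkConf:

```scala
// Old: rooted at java.io.tmpdir unless spark.repl.classdir was set.
// New: rooted at the application's local dir, with a recognizable prefix.
val rootDir = conf.getOption("spark.repl.classdir").getOrElse(Utils.getLocalDir(conf))
val outputDir = Utils.createTempDir(root = rootDir, namePrefix = "repl")
// e.g. /tmp/spark-<app-uuid>/repl-<uuid> instead of /tmp/<uuid>
```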


@@ -28,11 +28,13 @@ import org.apache.spark.sql.SQLContext
 object Main extends Logging {

   val conf = new SparkConf()
-  val tmp = System.getProperty("java.io.tmpdir")
-  val rootDir = conf.get("spark.repl.classdir", tmp)
-  val outputDir = Utils.createTempDir(rootDir)
+  val rootDir = conf.getOption("spark.repl.classdir").getOrElse(Utils.getLocalDir(conf))
+  val outputDir = Utils.createTempDir(root = rootDir, namePrefix = "repl")
+  val s = new Settings()
+  s.processArguments(List("-Yrepl-class-based",
+    "-Yrepl-outdir", s"${outputDir.getAbsolutePath}",
+    "-classpath", getAddedJars.mkString(File.pathSeparator)), true)
   // the creation of SecurityManager has to be lazy so SPARK_YARN_MODE is set if needed
-  lazy val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf))
   var sparkContext: SparkContext = _
   var sqlContext: SQLContext = _
   var interp = new SparkILoop // this is a public var because tests reset it.
@@ -45,7 +47,6 @@ object Main extends Logging {
   }

   def main(args: Array[String]) {
     val interpArguments = List(
       "-Yrepl-class-based",
       "-Yrepl-outdir", s"${outputDir.getAbsolutePath}",
@@ -57,11 +58,7 @@ object Main extends Logging {
     if (!hasErrors) {
       if (getMaster == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true")
-      // Start the classServer and store its URI in a spark system property
-      // (which will be passed to executors so that they can connect to it)
-      classServer.start()
       interp.process(settings) // Repl starts and goes in loop of R.E.P.L
-      classServer.stop()
       Option(sparkContext).map(_.stop)
     }
   }
@@ -82,9 +79,13 @@ object Main extends Logging {
     val conf = new SparkConf()
       .setMaster(getMaster)
       .setJars(jars)
-      .set("spark.repl.class.uri", classServer.uri)
       .setIfMissing("spark.app.name", "Spark shell")
-    logInfo("Spark class server started at " + classServer.uri)
+      // SparkContext will detect this configuration and register it with the RpcEnv's
+      // file server, setting spark.repl.class.uri to the actual URI for executors to
+      // use. This is sort of ugly but since executors are started as part of SparkContext
+      // initialization in certain cases, there's an initialization order issue that prevents
+      // this from being set after SparkContext is instantiated.
+      .set("spark.repl.class.outputDir", outputDir.getAbsolutePath())
     if (execUri != null) {
       conf.set("spark.executor.uri", execUri)
     }


@@ -19,6 +19,7 @@ package org.apache.spark.repl

 import java.io.{IOException, ByteArrayOutputStream, InputStream}
 import java.net.{HttpURLConnection, URI, URL, URLEncoder}
+import java.nio.channels.Channels

 import scala.util.control.NonFatal
@@ -38,7 +39,11 @@ import org.apache.spark.util.ParentClassLoader
  * This class loader delegates getting/finding resources to parent loader,
  * which makes sense until REPL never provide resource dynamically.
  */
-class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader,
+class ExecutorClassLoader(
+    conf: SparkConf,
+    env: SparkEnv,
+    classUri: String,
+    parent: ClassLoader,
     userClassPathFirst: Boolean) extends ClassLoader with Logging {
   val uri = new URI(classUri)
   val directory = uri.getPath
@@ -48,13 +53,12 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader
   // Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes
   private[repl] var httpUrlConnectionTimeoutMillis: Int = -1

-  // Hadoop FileSystem object for our URI, if it isn't using HTTP
-  var fileSystem: FileSystem = {
-    if (Set("http", "https", "ftp").contains(uri.getScheme)) {
-      null
-    } else {
-      FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf))
-    }
+  private val fetchFn: (String) => InputStream = uri.getScheme() match {
+    case "spark" => getClassFileInputStreamFromSparkRPC
+    case "http" | "https" | "ftp" => getClassFileInputStreamFromHttpServer
+    case _ =>
+      val fileSystem = FileSystem.get(uri, SparkHadoopUtil.get.newConfiguration(conf))
+      getClassFileInputStreamFromFileSystem(fileSystem)
   }

   override def getResource(name: String): URL = {
@@ -90,6 +94,11 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader
     }
   }

+  private def getClassFileInputStreamFromSparkRPC(path: String): InputStream = {
+    val channel = env.rpcEnv.openChannel(s"$classUri/$path")
+    Channels.newInputStream(channel)
+  }
+
   private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = {
     val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) {
       val uri = new URI(classUri + "/" + urlEncode(pathInDirectory))
@@ -126,7 +135,8 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader
     }
   }

-  private def getClassFileInputStreamFromFileSystem(pathInDirectory: String): InputStream = {
+  private def getClassFileInputStreamFromFileSystem(fileSystem: FileSystem)(
+      pathInDirectory: String): InputStream = {
     val path = new Path(directory, pathInDirectory)
     if (fileSystem.exists(path)) {
       fileSystem.open(path)
@@ -139,13 +149,7 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader
     val pathInDirectory = name.replace('.', '/') + ".class"
     var inputStream: InputStream = null
     try {
-      inputStream = {
-        if (fileSystem != null) {
-          getClassFileInputStreamFromFileSystem(pathInDirectory)
-        } else {
-          getClassFileInputStreamFromHttpServer(pathInDirectory)
-        }
-      }
+      inputStream = fetchFn(pathInDirectory)
       val bytes = readAndTransformClass(name, inputStream)
       Some(defineClass(name, bytes, 0, bytes.length))
     } catch {
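Editor's note: the fetchFn rewrite above picks the fetch strategy once, at construction time, instead of re-checking the scheme on every class load; note how the Hadoop case partially applies getClassFileInputStreamFromFileSystem so the FileSystem handle is fixed up front. A self-contained sketch of the same pattern (all names below are illustrative stand-ins, not Spark APIs):

```scala
import java.io.{ByteArrayInputStream, InputStream}

object FetchDispatchExample {
  def fromRpc(path: String): InputStream = new ByteArrayInputStream(s"rpc:$path".getBytes)
  def fromHttp(path: String): InputStream = new ByteArrayInputStream(s"http:$path".getBytes)
  // Two parameter lists: applying the first yields a String => InputStream.
  def fromFs(root: String)(path: String): InputStream =
    new ByteArrayInputStream(s"$root/$path".getBytes)

  // The scheme is inspected once; the returned function is reused for every class.
  def fetcherFor(scheme: String): String => InputStream = scheme match {
    case "spark"                  => fromRpc
    case "http" | "https" | "ftp" => fromHttp
    case _                        => fromFs("/mnt/classes") // FileSystem-style root fixed here
  }

  def main(args: Array[String]): Unit = {
    val fetch = fetcherFor("spark")
    println(scala.io.Source.fromInputStream(fetch("com/Foo.class")).mkString)
  }
}
```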


@@ -18,24 +18,29 @@
 package org.apache.spark.repl

 import java.io.File
-import java.net.{URL, URLClassLoader}
+import java.net.{URI, URL, URLClassLoader}
+import java.nio.channels.{FileChannel, ReadableByteChannel}
 import java.nio.charset.StandardCharsets
+import java.nio.file.{Paths, StandardOpenOption}
 import java.util

-import com.google.common.io.Files
-
 import scala.concurrent.duration._
 import scala.io.Source
 import scala.language.implicitConversions
 import scala.language.postfixOps

+import com.google.common.io.Files
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.concurrent.Interruptor
 import org.scalatest.concurrent.Timeouts._
 import org.scalatest.mock.MockitoSugar
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.mockito.Matchers.anyString
 import org.mockito.Mockito._

 import org.apache.spark._
+import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.util.Utils

 class ExecutorClassLoaderSuite
@@ -78,7 +83,7 @@ class ExecutorClassLoaderSuite
   test("child first") {
     val parentLoader = new URLClassLoader(urls2, null)
-    val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true)
+    val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true)
     val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance()
     val fakeClassVersion = fakeClass.toString
     assert(fakeClassVersion === "1")

@@ -86,7 +91,7 @@ class ExecutorClassLoaderSuite
   test("parent first") {
     val parentLoader = new URLClassLoader(urls2, null)
-    val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, false)
+    val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, false)
     val fakeClass = classLoader.loadClass("ReplFakeClass1").newInstance()
     val fakeClassVersion = fakeClass.toString
     assert(fakeClassVersion === "2")

@@ -94,7 +99,7 @@ class ExecutorClassLoaderSuite
   test("child first can fall back") {
     val parentLoader = new URLClassLoader(urls2, null)
-    val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true)
+    val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true)
     val fakeClass = classLoader.loadClass("ReplFakeClass3").newInstance()
     val fakeClassVersion = fakeClass.toString
     assert(fakeClassVersion === "2")

@@ -102,7 +107,7 @@ class ExecutorClassLoaderSuite
   test("child first can fail") {
     val parentLoader = new URLClassLoader(urls2, null)
-    val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true)
+    val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true)
     intercept[java.lang.ClassNotFoundException] {
       classLoader.loadClass("ReplFakeClassDoesNotExist").newInstance()
     }

@@ -110,7 +115,7 @@ class ExecutorClassLoaderSuite
   test("resource from parent") {
     val parentLoader = new URLClassLoader(urls2, null)
-    val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true)
+    val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true)
     val resourceName: String = parentResourceNames.head
     val is = classLoader.getResourceAsStream(resourceName)
     assert(is != null, s"Resource $resourceName not found")

@@ -120,7 +125,7 @@ class ExecutorClassLoaderSuite
   test("resources from parent") {
     val parentLoader = new URLClassLoader(urls2, null)
-    val classLoader = new ExecutorClassLoader(new SparkConf(), url1, parentLoader, true)
+    val classLoader = new ExecutorClassLoader(new SparkConf(), null, url1, parentLoader, true)
     val resourceName: String = parentResourceNames.head
     val resources: util.Enumeration[URL] = classLoader.getResources(resourceName)
     assert(resources.hasMoreElements, s"Resource $resourceName not found")

@@ -142,7 +147,7 @@ class ExecutorClassLoaderSuite
     SparkEnv.set(mockEnv)
     // Create an ExecutorClassLoader that's configured to load classes from the HTTP server
     val parentLoader = new URLClassLoader(Array.empty, null)
-    val classLoader = new ExecutorClassLoader(conf, classServer.uri, parentLoader, false)
+    val classLoader = new ExecutorClassLoader(conf, null, classServer.uri, parentLoader, false)
     classLoader.httpUrlConnectionTimeoutMillis = 500
     // Check that this class loader can actually load classes that exist
     val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance()
@@ -177,4 +182,27 @@ class ExecutorClassLoaderSuite
     failAfter(10 seconds)(tryAndFailToLoadABunchOfClasses())(interruptor)
   }

+  test("fetch classes using Spark's RpcEnv") {
+    val env = mock[SparkEnv]
+    val rpcEnv = mock[RpcEnv]
+    when(env.rpcEnv).thenReturn(rpcEnv)
+    when(rpcEnv.openChannel(anyString())).thenAnswer(new Answer[ReadableByteChannel]() {
+      override def answer(invocation: InvocationOnMock): ReadableByteChannel = {
+        val uri = new URI(invocation.getArguments()(0).asInstanceOf[String])
+        val path = Paths.get(tempDir1.getAbsolutePath(), uri.getPath().stripPrefix("/"))
+        FileChannel.open(path, StandardOpenOption.READ)
+      }
+    })
+
+    val classLoader = new ExecutorClassLoader(new SparkConf(), env, "spark://localhost:1234",
+      getClass().getClassLoader(), false)
+    val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance()
+    val fakeClassVersion = fakeClass.toString
+    assert(fakeClassVersion === "1")
+    intercept[java.lang.ClassNotFoundException] {
+      classLoader.loadClass("ReplFakeClassDoesNotExist").newInstance()
+    }
+  }
+
 }