diff --git a/src/scala/spark/HttpServer.scala b/src/scala/spark/HttpServer.scala index d5bdd245bb..d2a663ac1f 100644 --- a/src/scala/spark/HttpServer.scala +++ b/src/scala/spark/HttpServer.scala @@ -61,18 +61,7 @@ class HttpServer(resourceBase: File) extends Logging { if (server == null) { throw new ServerStateException("Server is not started") } else { - return "http://" + getLocalIpAddress + ":" + port + return "http://" + Utils.localIpAddress + ":" + port } } - - /** - * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4) - */ - private def getLocalIpAddress: String = { - // Get local IP as an array of four bytes - val bytes = InetAddress.getLocalHost().getAddress() - // Convert the bytes to ints (keeping in mind that they may be negative) - // and join them into a string - return bytes.map(b => (b.toInt + 256) % 256).mkString(".") - } } diff --git a/src/scala/spark/LocalFileShuffle.scala b/src/scala/spark/LocalFileShuffle.scala index b5b5e7267d..367599cfb4 100644 --- a/src/scala/spark/LocalFileShuffle.scala +++ b/src/scala/spark/LocalFileShuffle.scala @@ -97,19 +97,24 @@ object LocalFileShuffle extends Logging { private var nextShuffleId = new AtomicLong(0) // Variables initialized by initializeIfNeeded() - private var localDir: File = null + private var shuffleDir: File = null private var server: HttpServer = null private var serverUri: String = null private def initializeIfNeeded() = synchronized { if (!initialized) { + // TODO: localDir should be created by some mechanism common to Spark + // so that it can be shared among shuffle, broadcast, etc val localDirRoot = System.getProperty("spark.local.dir", "/tmp") var tries = 0 var foundLocalDir = false + var localDir: File = null + var localDirUuid: UUID = null while (!foundLocalDir && tries < 10) { tries += 1 try { - localDir = new File(localDirRoot, "spark-local-" + UUID.randomUUID()) + localDirUuid = UUID.randomUUID() + localDir = new File(localDirRoot, "spark-local-" + localDirUuid) if (!localDir.exists()) { localDir.mkdirs() foundLocalDir = true @@ -123,17 +128,33 @@ object LocalFileShuffle extends Logging { logError("Failed 10 attempts to create local dir in " + localDirRoot) System.exit(1) } - logInfo("Local dir: " + localDir) - server = new HttpServer(localDir) - server.start() - serverUri = server.uri + shuffleDir = new File(localDir, "shuffle") + shuffleDir.mkdirs() + logInfo("Shuffle dir: " + shuffleDir) + val extServerPort = System.getProperty( + "spark.localFileShuffle.external.server.port", "-1").toInt + if (extServerPort != -1) { + // We're using an external HTTP server; set URI relative to its root + var extServerPath = System.getProperty( + "spark.localFileShuffle.external.server.path", "") + if (extServerPath != "" && !extServerPath.endsWith("/")) { + extServerPath += "/" + } + serverUri = "http://%s:%d/%s/spark-local-%s".format( + Utils.localIpAddress, extServerPort, extServerPath, localDirUuid) + } else { + // Create our own server + server = new HttpServer(localDir) + server.start() + serverUri = server.uri + } initialized = true } } def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = { initializeIfNeeded() - val dir = new File(localDir, "shuffle/" + shuffleId + "/" + inputId) + val dir = new File(shuffleDir, shuffleId + "/" + inputId) dir.mkdirs() val file = new File(dir, "" + outputId) return file diff --git a/src/scala/spark/Utils.scala b/src/scala/spark/Utils.scala index 025472633b..e333dd9c91 100644 --- a/src/scala/spark/Utils.scala +++ b/src/scala/spark/Utils.scala @@ -1,6 +1,7 @@ package spark import java.io._ +import java.net.InetAddress import java.util.UUID import scala.collection.mutable.ArrayBuffer @@ -112,4 +113,15 @@ object Utils { } buf } + + /** + * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4) + */ + def localIpAddress(): String = { + // Get local IP as an array of four bytes + val bytes = InetAddress.getLocalHost().getAddress() + // Convert the bytes to ints (keeping in mind that they may be negative) + // and join them into a string + return bytes.map(b => (b.toInt + 256) % 256).mkString(".") + } }