package spark

import java.io._
import java.net.{InetAddress, URL, URI}
import java.util.{Locale, Random, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}

import scala.collection.mutable.ArrayBuffer
import scala.io.Source

/**
 * Various utility methods used by Spark.
 */
object Utils extends Logging {

  /** Serialize an object using Java serialization */
  def serialize[T](o: T): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(o)
    oos.close()
    return bos.toByteArray
  }

  /** Deserialize an object using Java serialization */
  def deserialize[T](bytes: Array[Byte]): T = {
    val bis = new ByteArrayInputStream(bytes)
    val ois = new ObjectInputStream(bis)
    return ois.readObject.asInstanceOf[T]
  }

  /** Deserialize an object using Java serialization and the given ClassLoader */
  def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
    val bis = new ByteArrayInputStream(bytes)
    val ois = new ObjectInputStream(bis) {
      override def resolveClass(desc: ObjectStreamClass) =
        Class.forName(desc.getName, false, loader)
    }
    return ois.readObject.asInstanceOf[T]
  }

  def isAlpha(c: Char): Boolean = {
    (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')
  }

  /** Split a string into words at non-alphabetic characters */
  def splitWords(s: String): Seq[String] = {
    val buf = new ArrayBuffer[String]
    var i = 0
    while (i < s.length) {
      var j = i
      while (j < s.length && isAlpha(s.charAt(j))) {
        j += 1
      }
      if (j > i) {
        buf += s.substring(i, j)
      }
      i = j
      while (i < s.length && !isAlpha(s.charAt(i))) {
        i += 1
      }
    }
    return buf
  }

  /** Create a temporary directory inside the given parent directory */
  def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = {
    var attempts = 0
    val maxAttempts = 10
    var dir: File = null
    while (dir == null) {
      attempts += 1
      if (attempts > maxAttempts) {
        throw new IOException(
          "Failed to create a temp directory after " + maxAttempts + " attempts!")
      }
      try {
        dir = new File(root, "spark-" + UUID.randomUUID.toString)
        if (dir.exists() || !dir.mkdirs()) {
          dir = null
        }
      } catch { case e: IOException => }
    }
    // Add a shutdown hook to delete the temp dir when the JVM exits
    Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) {
      override def run() {
        Utils.deleteRecursively(dir)
      }
    })
    return dir
  }

  /** Copy all data from an InputStream to an OutputStream */
  def copyStream(in: InputStream, out: OutputStream, closeStreams: Boolean = false) {
    val buf = new Array[Byte](8192)
    var n = 0
    while (n != -1) {
      n = in.read(buf)
      if (n != -1) {
        out.write(buf, 0, n)
      }
    }
    if (closeStreams) {
      in.close()
      out.close()
    }
  }

  /** Copy a file on the local file system */
  def copyFile(source: File, dest: File) {
    val in = new FileInputStream(source)
    val out = new FileOutputStream(dest)
    copyStream(in, out, true)
  }

  /** Download a file from a given URL to the local filesystem */
  def downloadFile(url: URL, localPath: String) {
    val in = url.openStream()
    val out = new FileOutputStream(localPath)
    Utils.copyStream(in, out, true)
  }
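  // Usage sketch (illustrative, not from the original source; the paths and URL
  // below are hypothetical): copying a local file and fetching one over HTTP
  // with the stream helpers above. copyStream closes both streams when its
  // third argument is true.
  //
  //   Utils.copyFile(new File("/tmp/source.txt"), new File("/tmp/copy.txt"))
  //   Utils.downloadFile(new URL("http://example.com/data.bin"), "/tmp/data.bin")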
  /**
   * Download a file requested by the executor. Supports fetching the file in a variety of ways,
   * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
   */
  def fetchFile(url: String, targetDir: File) {
    val filename = url.split("/").last
    val targetFile = new File(targetDir, filename)
    val uri = new URI(url)
    uri.getScheme match {
      case "http" | "https" | "ftp" =>
        logInfo("Fetching " + url + " to " + targetFile)
        val in = new URL(url).openStream()
        val out = new FileOutputStream(targetFile)
        Utils.copyStream(in, out, true)
      case "file" | null =>
        // Remove the file if it already exists
        targetFile.delete()
        // Symlink the file locally
        logInfo("Symlinking " + url + " to " + targetFile)
        FileUtil.symLink(url, targetFile.toString)
      case _ =>
        // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
        val conf = new Configuration()
        val fs = FileSystem.get(uri, conf)
        val in = fs.open(new Path(uri))
        val out = new FileOutputStream(targetFile)
        Utils.copyStream(in, out, true)
    }
    // Decompress the file if it's a .tar or .tar.gz
    if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
      logInfo("Untarring " + filename)
      Utils.execute(Seq("tar", "-xzf", filename), targetDir)
    } else if (filename.endsWith(".tar")) {
      logInfo("Untarring " + filename)
      Utils.execute(Seq("tar", "-xf", filename), targetDir)
    }
    // Make the file executable - that's necessary for scripts. Note that we must
    // chmod the file inside targetDir, not a path relative to the working directory.
    FileUtil.chmod(targetFile.toString, "a+x")
  }

  /**
   * Shuffle the elements of a collection into a random order, returning the
   * result in a new collection. Unlike scala.util.Random.shuffle, this method
   * uses a local random number generator, avoiding inter-thread contention.
   */
  def randomize[T: ClassManifest](seq: TraversableOnce[T]): Seq[T] = {
    randomizeInPlace(seq.toArray)
  }

  /**
   * Shuffle the elements of an array into a random order, modifying the
   * original array. Returns the original array.
   */
  def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = {
    for (i <- (arr.length - 1) to 1 by -1) {
      // Fisher-Yates: j must be drawn from [0, i] inclusive. Using nextInt(i)
      // would never leave arr(i) in place at this step and would bias the shuffle.
      val j = rand.nextInt(i + 1)
      val tmp = arr(j)
      arr(j) = arr(i)
      arr(i) = tmp
    }
    arr
  }

  /**
   * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4).
   */
  def localIpAddress(): String = InetAddress.getLocalHost.getHostAddress

  private var customHostname: Option[String] = None

  /**
   * Allow setting a custom host name because when we run on Mesos we need to use the same
   * hostname it reports to the master.
   */
  def setCustomHostname(hostname: String) {
    customHostname = Some(hostname)
  }

  /**
   * Get the local machine's hostname.
   */
  def localHostName(): String = {
    customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
  }

  /**
   * Returns a standard ThreadFactory except all threads are daemons.
   */
  private def newDaemonThreadFactory: ThreadFactory = {
    new ThreadFactory {
      def newThread(r: Runnable): Thread = {
        val t = Executors.defaultThreadFactory.newThread(r)
        t.setDaemon(true)
        return t
      }
    }
  }

  /**
   * Wrapper over newCachedThreadPool.
   */
  def newDaemonCachedThreadPool(): ThreadPoolExecutor = {
    val threadPool = Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
    threadPool.setThreadFactory(newDaemonThreadFactory)
    return threadPool
  }

  /**
   * Return a string reporting how much time has passed, in milliseconds, since
   * the given start time. The parameter should be a timestamp in milliseconds.
   */
  def getUsedTimeMs(startTimeMs: Long): String = {
    return " " + (System.currentTimeMillis - startTimeMs) + " ms "
  }
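  // Usage sketch (illustrative, not from the original source): the daemon pools
  // above let the JVM exit even while worker threads are still running.
  //
  //   val pool = Utils.newDaemonCachedThreadPool()
  //   pool.execute(new Runnable {
  //     def run() { println("running on " + Thread.currentThread.getName) }
  //   })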
  /**
   * Wrapper over newFixedThreadPool.
   */
  def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
    val threadPool = Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
    threadPool.setThreadFactory(newDaemonThreadFactory)
    return threadPool
  }

  /**
   * Delete a file or directory and its contents recursively.
   */
  def deleteRecursively(file: File) {
    if (file.isDirectory) {
      for (child <- file.listFiles()) {
        deleteRecursively(child)
      }
    }
    if (!file.delete()) {
      throw new IOException("Failed to delete: " + file)
    }
  }

  /**
   * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes.
   * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM
   * environment variable.
   */
  def memoryStringToMb(str: String): Int = {
    val lower = str.toLowerCase
    if (lower.endsWith("k")) {
      (lower.substring(0, lower.length - 1).toLong / 1024).toInt
    } else if (lower.endsWith("m")) {
      lower.substring(0, lower.length - 1).toInt
    } else if (lower.endsWith("g")) {
      lower.substring(0, lower.length - 1).toInt * 1024
    } else if (lower.endsWith("t")) {
      lower.substring(0, lower.length - 1).toInt * 1024 * 1024
    } else {
      // No suffix, so it's just a number in bytes
      (lower.toLong / 1024 / 1024).toInt
    }
  }

  /**
   * Convert a memory quantity in bytes to a human-readable string such as "4.0 MB".
   */
  def memoryBytesToString(size: Long): String = {
    val TB = 1L << 40
    val GB = 1L << 30
    val MB = 1L << 20
    val KB = 1L << 10

    val (value, unit) = {
      if (size >= 2 * TB) {
        (size.asInstanceOf[Double] / TB, "TB")
      } else if (size >= 2 * GB) {
        (size.asInstanceOf[Double] / GB, "GB")
      } else if (size >= 2 * MB) {
        (size.asInstanceOf[Double] / MB, "MB")
      } else if (size >= 2 * KB) {
        (size.asInstanceOf[Double] / KB, "KB")
      } else {
        (size.asInstanceOf[Double], "B")
      }
    }
    "%.1f %s".formatLocal(Locale.US, value, unit)
  }

  /**
   * Convert a memory quantity in megabytes to a human-readable string such as "4.0 MB".
   */
  def memoryMegabytesToString(megabytes: Long): String = {
    memoryBytesToString(megabytes * 1024L * 1024L)
  }

  /**
   * Execute a command in the given working directory, throwing an exception if it completes
   * with an exit code other than 0.
   */
  def execute(command: Seq[String], workingDir: File) {
    val process = new ProcessBuilder(command: _*)
      .directory(workingDir)
      .redirectErrorStream(true)
      .start()
    new Thread("read stdout for " + command(0)) {
      override def run() {
        for (line <- Source.fromInputStream(process.getInputStream).getLines) {
          System.err.println(line)
        }
      }
    }.start()
    val exitCode = process.waitFor()
    if (exitCode != 0) {
      throw new SparkException("Process " + command + " exited with code " + exitCode)
    }
  }

  /**
   * Execute a command in the current working directory, throwing an exception if it completes
   * with an exit code other than 0.
   */
  def execute(command: Seq[String]) {
    execute(command, new File("."))
  }
}
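// A minimal usage sketch, not part of the original file: exercises a few of the
// pure helpers in Utils. The object name and all values below are illustrative.
private[spark] object UtilsExample {
  def main(args: Array[String]) {
    // Round-trip a value through Java serialization.
    val bytes = Utils.serialize(List(1, 2, 3))
    val restored = Utils.deserialize[List[Int]](bytes)
    assert(restored == List(1, 2, 3))

    // Parse JVM-style memory strings (as passed to -Xmx) into megabytes.
    assert(Utils.memoryStringToMb("300m") == 300)
    assert(Utils.memoryStringToMb("1g") == 1024)

    // Format a byte count for logging.
    println(Utils.memoryBytesToString(5L * 1024 * 1024))  // prints "5.0 MB"
  }
}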