diff --git a/.gitignore b/.gitignore index e1f64a1133..b3c4363af0 100644 --- a/.gitignore +++ b/.gitignore @@ -41,3 +41,4 @@ derby.log dist/ spark-*-bin.tar.gz unit-tests.log +lib/ diff --git a/README.md b/README.md index 456b8060ef..1550a8b551 100644 --- a/README.md +++ b/README.md @@ -12,9 +12,8 @@ This README file only contains basic setup instructions. ## Building -Spark requires Scala 2.9.3 (Scala 2.10 is not yet supported). The project is -built using Simple Build Tool (SBT), which is packaged with it. To build -Spark and its example programs, run: +Spark requires Scala 2.10. The project is built using Simple Build Tool (SBT), +which is packaged with it. To build Spark and its example programs, run: sbt/sbt assembly @@ -55,7 +54,7 @@ versions without YARN, use: # Cloudera CDH 4.2.0 with MapReduce v1 $ SPARK_HADOOP_VERSION=2.0.0-mr1-cdh4.2.0 sbt/sbt assembly -For Apache Hadoop 2.x, 0.23.x, Cloudera CDH MRv2, and other Hadoop versions +For Apache Hadoop 2.2.X, 2.1.X, 2.0.X, 0.23.x, Cloudera CDH MRv2, and other Hadoop versions with YARN, also set `SPARK_YARN=true`: # Apache Hadoop 2.0.5-alpha @@ -64,8 +63,8 @@ with YARN, also set `SPARK_YARN=true`: # Cloudera CDH 4.2.0 with MapReduce v2 $ SPARK_HADOOP_VERSION=2.0.0-cdh4.2.0 SPARK_YARN=true sbt/sbt assembly -For convenience, these variables may also be set through the `conf/spark-env.sh` file -described below. + # Apache Hadoop 2.2.X and newer + $ SPARK_HADOOP_VERSION=2.2.0 SPARK_YARN=true sbt/sbt assembly When developing a Spark application, specify the Hadoop version by adding the "hadoop-client" artifact to your project's dependencies. For example, if you're diff --git a/assembly/pom.xml b/assembly/pom.xml index 09df8c1fd7..fc2adc1fbb 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -26,7 +26,7 @@ org.apache.spark - spark-assembly_2.9.3 + spark-assembly_2.10 Spark Project Assembly http://spark.incubator.apache.org/ @@ -41,27 +41,27 @@ org.apache.spark - spark-core_2.9.3 + spark-core_${scala.binary.version} ${project.version} org.apache.spark - spark-bagel_2.9.3 + spark-bagel_${scala.binary.version} ${project.version} org.apache.spark - spark-mllib_2.9.3 + spark-mllib_${scala.binary.version} ${project.version} org.apache.spark - spark-repl_2.9.3 + spark-repl_${scala.binary.version} ${project.version} org.apache.spark - spark-streaming_2.9.3 + spark-streaming_${scala.binary.version} ${project.version} @@ -79,7 +79,7 @@ maven-shade-plugin false - ${project.build.directory}/scala-${scala.version}/${project.artifactId}-${project.version}-hadoop${hadoop.version}.jar + ${project.build.directory}/scala-${scala.binary.version}/${project.artifactId}-${project.version}-hadoop${hadoop.version}.jar *:* @@ -128,7 +128,7 @@ org.apache.spark - spark-yarn_2.9.3 + spark-yarn_${scala.binary.version} ${project.version} diff --git a/bagel/pom.xml b/bagel/pom.xml index 0e552c880f..cb8e79f225 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -26,7 +26,7 @@ org.apache.spark - spark-bagel_2.9.3 + spark-bagel_2.10 jar Spark Project Bagel http://spark.incubator.apache.org/ @@ -34,7 +34,7 @@ org.apache.spark - spark-core_2.9.3 + spark-core_${scala.binary.version} ${project.version} @@ -43,18 +43,18 @@ org.scalatest - scalatest_2.9.3 + scalatest_${scala.binary.version} test org.scalacheck - scalacheck_2.9.3 + scalacheck_${scala.binary.version} test - target/scala-${scala.version}/classes - target/scala-${scala.version}/test-classes + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes org.scalatest diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd index cf38188c4b..9e3e10ecaa 100644 --- a/bin/compute-classpath.cmd +++ b/bin/compute-classpath.cmd @@ -20,7 +20,7 @@ rem rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run" rem script and the ExecutorRunner in standalone cluster mode. -set SCALA_VERSION=2.9.3 +set SCALA_VERSION=2.10 rem Figure out where the Spark framework is installed set FWDIR=%~dp0..\ diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index c16afd6b36..40555089fc 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -20,7 +20,7 @@ # This script computes Spark's classpath and prints it to stdout; it's used by both the "run" # script and the ExecutorRunner in standalone cluster mode. -SCALA_VERSION=2.9.3 +SCALA_VERSION=2.10 # Figure out where Spark is installed FWDIR="$(cd `dirname $0`/..; pwd)" diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index ae10f615d1..1c3d94e1b0 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -80,6 +80,14 @@ # /metrics/aplications/json # App information # /metrics/master/json # Master information +# org.apache.spark.metrics.sink.GraphiteSink +# Name: Default: Description: +# host NONE Hostname of Graphite server +# port NONE Port of Graphite server +# period 10 Poll period +# unit seconds Units of poll period +# prefix EMPTY STRING Prefix to prepend to metric name + ## Examples # Enable JmxSink for all instances by class name #*.sink.jmx.class=org.apache.spark.metrics.sink.JmxSink diff --git a/core/pom.xml b/core/pom.xml index 8621d257e5..043f6cf68d 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -26,7 +26,7 @@ org.apache.spark - spark-core_2.9.3 + spark-core_2.10 jar Spark Project Core http://spark.incubator.apache.org/ @@ -80,13 +80,9 @@ org.ow2.asm asm - - com.google.protobuf - protobuf-java - com.twitter - chill_2.9.3 + chill_${scala.binary.version} 0.3.1 @@ -95,20 +91,12 @@ 0.3.1 - com.typesafe.akka - akka-actor + ${akka.group} + akka-remote_${scala.binary.version} - com.typesafe.akka - akka-remote - - - com.typesafe.akka - akka-slf4j - - - org.scala-lang - scalap + ${akka.group} + akka-slf4j_${scala.binary.version} org.scala-lang @@ -116,7 +104,7 @@ net.liftweb - lift-json_2.9.2 + lift-json_${scala.binary.version} it.unimi.dsi @@ -126,10 +114,6 @@ colt colt - - com.github.scala-incubator.io - scala-io-file_2.9.2 - org.apache.mesos mesos @@ -158,19 +142,28 @@ com.codahale.metrics metrics-ganglia + + com.codahale.metrics + metrics-graphite + org.apache.derby derby test + + commons-io + commons-io + test + org.scalatest - scalatest_2.9.3 + scalatest_${scala.binary.version} test org.scalacheck - scalacheck_2.9.3 + scalacheck_${scala.binary.version} test @@ -190,8 +183,8 @@ - target/scala-${scala.version}/classes - target/scala-${scala.version}/test-classes + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes org.apache.maven.plugins diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClient.java b/core/src/main/java/org/apache/spark/network/netty/FileClient.java index 20a7a3aa8c..edd0fc56f8 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileClient.java +++ b/core/src/main/java/org/apache/spark/network/netty/FileClient.java @@ -19,8 +19,6 @@ package org.apache.spark.network.netty; import io.netty.bootstrap.Bootstrap; import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelOption; import io.netty.channel.oio.OioEventLoopGroup; import io.netty.channel.socket.oio.OioSocketChannel; diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServer.java b/core/src/main/java/org/apache/spark/network/netty/FileServer.java index 666432474d..a99af348ce 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileServer.java +++ b/core/src/main/java/org/apache/spark/network/netty/FileServer.java @@ -20,7 +20,6 @@ package org.apache.spark.network.netty; import java.net.InetSocketAddress; import io.netty.bootstrap.ServerBootstrap; -import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelOption; import io.netty.channel.oio.OioEventLoopGroup; diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 1ad9240cfa..c6b4ac5192 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -99,7 +99,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = { if (!atMost.isFinite()) { awaitResult() - } else { + } else jobWaiter.synchronized { val finishTime = System.currentTimeMillis() + atMost.toMillis while (!isCompleted) { val time = System.currentTimeMillis() diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 5e465fa22c..10fae5af9f 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -21,12 +21,11 @@ import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.mutable.HashSet +import scala.concurrent.Await +import scala.concurrent.duration._ import akka.actor._ -import akka.dispatch._ import akka.pattern.ask -import akka.util.Duration - import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage.BlockManagerId @@ -55,9 +54,9 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster private[spark] class MapOutputTracker extends Logging { private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") - + // Set to the MapOutputTrackerActor living on the driver - var trackerActor: ActorRef = _ + var trackerActor: Either[ActorRef, ActorSelection] = _ protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] @@ -73,8 +72,18 @@ private[spark] class MapOutputTracker extends Logging { // throw a SparkException if this fails. private def askTracker(message: Any): Any = { try { - val future = trackerActor.ask(message)(timeout) - return Await.result(future, timeout) + /* + The difference between ActorRef and ActorSelection is well explained here: + http://doc.akka.io/docs/akka/2.2.3/project/migration-guide-2.1.x-2.2.x.html#Use_actorSelection_instead_of_actorFor + In spark a map output tracker can be either started on Driver where it is created which + is an ActorRef or it can be on executor from where it is looked up which is an + actorSelection. + */ + val future = trackerActor match { + case Left(a: ActorRef) => a.ask(message)(timeout) + case Right(b: ActorSelection) => b.ask(message)(timeout) + } + Await.result(future, timeout) } catch { case e: Exception => throw new SparkException("Error communicating with MapOutputTracker", e) @@ -117,7 +126,7 @@ private[spark] class MapOutputTracker extends Logging { fetching += shuffleId } } - + if (fetchedStatuses == null) { // We won the race to fetch the output locs; do so logInfo("Doing the fetch; tracker actor = " + trackerActor) @@ -144,7 +153,7 @@ private[spark] class MapOutputTracker extends Logging { else{ throw new FetchFailedException(null, shuffleId, -1, reduceId, new Exception("Missing all output locations for shuffle " + shuffleId)) - } + } } else { statuses.synchronized { return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) @@ -244,12 +253,12 @@ private[spark] class MapOutputTrackerMaster extends MapOutputTracker { case Some(bytes) => return bytes case None => - statuses = mapStatuses(shuffleId) + statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]()) epochGotten = epoch } } // If we got here, we failed to find the serialized locations in the cache, so we pulled - // out a snapshot of the locations as "locs"; let's serialize and return that + // out a snapshot of the locations as "statuses"; let's serialize and return that val bytes = MapOutputTracker.serializeMapStatuses(statuses) logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) // Add them into the table only if the epoch hasn't changed while we were working @@ -274,6 +283,10 @@ private[spark] class MapOutputTrackerMaster extends MapOutputTracker { override def updateEpoch(newEpoch: Long) { // This might be called on the MapOutputTrackerMaster if we're running in local mode. } + + def has(shuffleId: Int): Boolean = { + cachedSerializedStatuses.get(shuffleId).isDefined || mapStatuses.contains(shuffleId) + } } private[spark] object MapOutputTracker { @@ -308,7 +321,7 @@ private[spark] object MapOutputTracker { statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { assert (statuses != null) statuses.map { - status => + status => if (status == null) { throw new FetchFailedException(null, shuffleId, -1, reduceId, new Exception("Missing an output location for shuffle " + shuffleId)) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 0e2c987a59..bcec41c439 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -17,8 +17,10 @@ package org.apache.spark -import org.apache.spark.util.Utils +import scala.reflect.ClassTag + import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils /** * An object that defines how the elements in a key-value pair RDD are partitioned by key. @@ -72,7 +74,7 @@ class HashPartitioner(partitions: Int) extends Partitioner { case null => 0 case _ => Utils.nonNegativeMod(key.hashCode, numPartitions) } - + override def equals(other: Any): Boolean = other match { case h: HashPartitioner => h.numPartitions == numPartitions @@ -85,7 +87,7 @@ class HashPartitioner(partitions: Int) extends Partitioner { * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly equal ranges. * Determines the ranges by sampling the RDD passed in. */ -class RangePartitioner[K <% Ordered[K]: ClassManifest, V]( +class RangePartitioner[K <% Ordered[K]: ClassTag, V]( partitions: Int, @transient rdd: RDD[_ <: Product2[K,V]], private val ascending: Boolean = true) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 42b2985b50..a0f794edfd 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -26,6 +26,7 @@ import scala.collection.Map import scala.collection.generic.Growable import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap +import scala.reflect.{ClassTag, classTag} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -81,7 +82,7 @@ class SparkContext( val sparkHome: String = null, val jars: Seq[String] = Nil, val environment: Map[String, String] = Map(), - // This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc) + // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, etc) // too. This is typically generated from InputFormatInfo.computePreferredLocations .. host, set // of data-local splits on host val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = @@ -153,98 +154,11 @@ class SparkContext( executorEnvs("SPARK_USER") = sparkUser // Create and start the scheduler - private[spark] var taskScheduler: TaskScheduler = { - // Regular expression used for local[N] master format - val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r - // Regular expression for local[N, maxRetries], used in tests with failing tasks - val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r - // Regular expression for simulating a Spark cluster of [N, cores, memory] locally - val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r - // Regular expression for connecting to Spark deploy clusters - val SPARK_REGEX = """spark://(.*)""".r - // Regular expression for connection to Mesos cluster - val MESOS_REGEX = """mesos://(.*)""".r - // Regular expression for connection to Simr cluster - val SIMR_REGEX = """simr://(.*)""".r - - master match { - case "local" => - new LocalScheduler(1, 0, this) - - case LOCAL_N_REGEX(threads) => - new LocalScheduler(threads.toInt, 0, this) - - case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => - new LocalScheduler(threads.toInt, maxFailures.toInt, this) - - case SPARK_REGEX(sparkUrl) => - val scheduler = new ClusterScheduler(this) - val masterUrls = sparkUrl.split(",").map("spark://" + _) - val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName) - scheduler.initialize(backend) - scheduler - - case SIMR_REGEX(simrUrl) => - val scheduler = new ClusterScheduler(this) - val backend = new SimrSchedulerBackend(scheduler, this, simrUrl) - scheduler.initialize(backend) - scheduler - - case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => - // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang. - val memoryPerSlaveInt = memoryPerSlave.toInt - if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) { - throw new SparkException( - "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format( - memoryPerSlaveInt, SparkContext.executorMemoryRequested)) - } - - val scheduler = new ClusterScheduler(this) - val localCluster = new LocalSparkCluster( - numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) - val masterUrls = localCluster.start() - val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName) - scheduler.initialize(backend) - backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => { - localCluster.stop() - } - scheduler - - case "yarn-standalone" => - val scheduler = try { - val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") - val cons = clazz.getConstructor(classOf[SparkContext]) - cons.newInstance(this).asInstanceOf[ClusterScheduler] - } catch { - // TODO: Enumerate the exact reasons why it can fail - // But irrespective of it, it means we cannot proceed ! - case th: Throwable => { - throw new SparkException("YARN mode not available ?", th) - } - } - val backend = new CoarseGrainedSchedulerBackend(scheduler, this.env.actorSystem) - scheduler.initialize(backend) - scheduler - - case MESOS_REGEX(mesosUrl) => - MesosNativeLibrary.load() - val scheduler = new ClusterScheduler(this) - val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean - val backend = if (coarseGrained) { - new CoarseMesosSchedulerBackend(scheduler, this, mesosUrl, appName) - } else { - new MesosSchedulerBackend(scheduler, this, mesosUrl, appName) - } - scheduler.initialize(backend) - scheduler - - case _ => - throw new SparkException("Could not parse Master URL: '" + master + "'") - } - } + private[spark] var taskScheduler = SparkContext.createTaskScheduler(this, master, appName) taskScheduler.start() @volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler) + dagScheduler.start() ui.start() @@ -354,19 +268,19 @@ class SparkContext( // Methods for creating RDDs /** Distribute a local Scala collection to form an RDD. */ - def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { + def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]()) } /** Distribute a local Scala collection to form an RDD. */ - def makeRDD[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { + def makeRDD[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { parallelize(seq, numSlices) } /** Distribute a local Scala collection to form an RDD, with one or more * location preferences (hostnames of Spark nodes) for each object. * Create a new partition for each collection item. */ - def makeRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = { + def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = { val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs) } @@ -419,7 +333,7 @@ class SparkContext( } /** - * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys, + * Smarter version of hadoopFile() that uses class tags to figure out the classes of keys, * values and the InputFormat so that users don't need to pass them directly. Instead, callers * can just write, for example, * {{{ @@ -427,17 +341,17 @@ class SparkContext( * }}} */ def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, minSplits: Int) - (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]) + (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]) : RDD[(K, V)] = { hadoopFile(path, - fm.erasure.asInstanceOf[Class[F]], - km.erasure.asInstanceOf[Class[K]], - vm.erasure.asInstanceOf[Class[V]], + fm.runtimeClass.asInstanceOf[Class[F]], + km.runtimeClass.asInstanceOf[Class[K]], + vm.runtimeClass.asInstanceOf[Class[V]], minSplits) } /** - * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys, + * Smarter version of hadoopFile() that uses class tags to figure out the classes of keys, * values and the InputFormat so that users don't need to pass them directly. Instead, callers * can just write, for example, * {{{ @@ -445,17 +359,17 @@ class SparkContext( * }}} */ def hadoopFile[K, V, F <: InputFormat[K, V]](path: String) - (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] = + (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = hadoopFile[K, V, F](path, defaultMinSplits) /** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */ def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String) - (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] = { + (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = { newAPIHadoopFile( path, - fm.erasure.asInstanceOf[Class[F]], - km.erasure.asInstanceOf[Class[K]], - vm.erasure.asInstanceOf[Class[V]]) + fm.runtimeClass.asInstanceOf[Class[F]], + km.runtimeClass.asInstanceOf[Class[K]], + vm.runtimeClass.asInstanceOf[Class[V]]) } /** @@ -513,11 +427,11 @@ class SparkContext( * IntWritable). The most natural thing would've been to have implicit objects for the * converters, but then we couldn't have an object for every subclass of Writable (you can't * have a parameterized singleton object). We use functions instead to create a new converter - * for the appropriate type. In addition, we pass the converter a ClassManifest of its type to + * for the appropriate type. In addition, we pass the converter a ClassTag of its type to * allow it to figure out the Writable class to use in the subclass case. */ def sequenceFile[K, V](path: String, minSplits: Int = defaultMinSplits) - (implicit km: ClassManifest[K], vm: ClassManifest[V], + (implicit km: ClassTag[K], vm: ClassTag[V], kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]) : RDD[(K, V)] = { val kc = kcf() @@ -536,7 +450,7 @@ class SparkContext( * slow if you use the default serializer (Java serialization), though the nice thing about it is * that there's very little effort required to save arbitrary objects. */ - def objectFile[T: ClassManifest]( + def objectFile[T: ClassTag]( path: String, minSplits: Int = defaultMinSplits ): RDD[T] = { @@ -545,17 +459,17 @@ class SparkContext( } - protected[spark] def checkpointFile[T: ClassManifest]( + protected[spark] def checkpointFile[T: ClassTag]( path: String ): RDD[T] = { new CheckpointRDD[T](this, path) } /** Build the union of a list of RDDs. */ - def union[T: ClassManifest](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds) + def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds) /** Build the union of a list of RDDs passed as variable-length arguments. */ - def union[T: ClassManifest](first: RDD[T], rest: RDD[T]*): RDD[T] = + def union[T: ClassTag](first: RDD[T], rest: RDD[T]*): RDD[T] = new UnionRDD(this, Seq(first) ++ rest) // Methods for creating shared variables @@ -798,7 +712,7 @@ class SparkContext( * flag specifies whether the scheduler can run the computation on the driver rather than * shipping it out to the cluster, for short actions like first(). */ - def runJob[T, U: ClassManifest]( + def runJob[T, U: ClassTag]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], @@ -819,7 +733,7 @@ class SparkContext( * allowLocal flag specifies whether the scheduler can run the computation on the driver rather * than shipping it out to the cluster, for short actions like first(). */ - def runJob[T, U: ClassManifest]( + def runJob[T, U: ClassTag]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], @@ -834,7 +748,7 @@ class SparkContext( * Run a job on a given set of partitions of an RDD, but take a function of type * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`. */ - def runJob[T, U: ClassManifest]( + def runJob[T, U: ClassTag]( rdd: RDD[T], func: Iterator[T] => U, partitions: Seq[Int], @@ -846,21 +760,21 @@ class SparkContext( /** * Run a job on all partitions in an RDD and return the results in an array. */ - def runJob[T, U: ClassManifest](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = { + def runJob[T, U: ClassTag](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = { runJob(rdd, func, 0 until rdd.partitions.size, false) } /** * Run a job on all partitions in an RDD and return the results in an array. */ - def runJob[T, U: ClassManifest](rdd: RDD[T], func: Iterator[T] => U): Array[U] = { + def runJob[T, U: ClassTag](rdd: RDD[T], func: Iterator[T] => U): Array[U] = { runJob(rdd, func, 0 until rdd.partitions.size, false) } /** * Run a job on all partitions in an RDD and pass the results to a handler function. */ - def runJob[T, U: ClassManifest]( + def runJob[T, U: ClassTag]( rdd: RDD[T], processPartition: (TaskContext, Iterator[T]) => U, resultHandler: (Int, U) => Unit) @@ -871,7 +785,7 @@ class SparkContext( /** * Run a job on all partitions in an RDD and pass the results to a handler function. */ - def runJob[T, U: ClassManifest]( + def runJob[T, U: ClassTag]( rdd: RDD[T], processPartition: Iterator[T] => U, resultHandler: (Int, U) => Unit) @@ -1017,16 +931,16 @@ object SparkContext { // TODO: Add AccumulatorParams for other types, e.g. lists and strings - implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) = + implicit def rddToPairRDDFunctions[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) = new PairRDDFunctions(rdd) - implicit def rddToAsyncRDDActions[T: ClassManifest](rdd: RDD[T]) = new AsyncRDDActions(rdd) + implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd) - implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest]( + implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( rdd: RDD[(K, V)]) = new SequenceFileRDDFunctions(rdd) - implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest]( + implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassTag, V: ClassTag]( rdd: RDD[(K, V)]) = new OrderedRDDFunctions[K, V, (K, V)](rdd) @@ -1051,16 +965,16 @@ object SparkContext { implicit def stringToText(s: String) = new Text(s) - private implicit def arrayToArrayWritable[T <% Writable: ClassManifest](arr: Traversable[T]): ArrayWritable = { + private implicit def arrayToArrayWritable[T <% Writable: ClassTag](arr: Traversable[T]): ArrayWritable = { def anyToWritable[U <% Writable](u: U): Writable = u - new ArrayWritable(classManifest[T].erasure.asInstanceOf[Class[Writable]], + new ArrayWritable(classTag[T].runtimeClass.asInstanceOf[Class[Writable]], arr.map(x => anyToWritable(x)).toArray) } // Helper objects for converting common types to Writable - private def simpleWritableConverter[T, W <: Writable: ClassManifest](convert: W => T) = { - val wClass = classManifest[W].erasure.asInstanceOf[Class[W]] + private def simpleWritableConverter[T, W <: Writable: ClassTag](convert: W => T) = { + val wClass = classTag[W].runtimeClass.asInstanceOf[Class[W]] new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W])) } @@ -1079,7 +993,7 @@ object SparkContext { implicit def stringWritableConverter() = simpleWritableConverter[String, Text](_.toString) implicit def writableWritableConverter[T <: Writable]() = - new WritableConverter[T](_.erasure.asInstanceOf[Class[T]], _.asInstanceOf[T]) + new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T]) /** * Find the JAR from which a given class was loaded, to make it easy for users to pass @@ -1111,17 +1025,135 @@ object SparkContext { .map(Utils.memoryStringToMb) .getOrElse(512) } + + // Creates a task scheduler based on a given master URL. Extracted for testing. + private + def createTaskScheduler(sc: SparkContext, master: String, appName: String): TaskScheduler = { + // Regular expression used for local[N] master format + val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r + // Regular expression for local[N, maxRetries], used in tests with failing tasks + val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r + // Regular expression for simulating a Spark cluster of [N, cores, memory] locally + val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r + // Regular expression for connecting to Spark deploy clusters + val SPARK_REGEX = """spark://(.*)""".r + // Regular expression for connection to Mesos cluster by mesos:// or zk:// url + val MESOS_REGEX = """(mesos|zk)://.*""".r + // Regular expression for connection to Simr cluster + val SIMR_REGEX = """simr://(.*)""".r + + master match { + case "local" => + new LocalScheduler(1, 0, sc) + + case LOCAL_N_REGEX(threads) => + new LocalScheduler(threads.toInt, 0, sc) + + case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => + new LocalScheduler(threads.toInt, maxFailures.toInt, sc) + + case SPARK_REGEX(sparkUrl) => + val scheduler = new ClusterScheduler(sc) + val masterUrls = sparkUrl.split(",").map("spark://" + _) + val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls, appName) + scheduler.initialize(backend) + scheduler + + case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => + // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang. + val memoryPerSlaveInt = memoryPerSlave.toInt + if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) { + throw new SparkException( + "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format( + memoryPerSlaveInt, SparkContext.executorMemoryRequested)) + } + + val scheduler = new ClusterScheduler(sc) + val localCluster = new LocalSparkCluster( + numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) + val masterUrls = localCluster.start() + val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls, appName) + scheduler.initialize(backend) + backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => { + localCluster.stop() + } + scheduler + + case "yarn-standalone" => + val scheduler = try { + val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") + val cons = clazz.getConstructor(classOf[SparkContext]) + cons.newInstance(sc).asInstanceOf[ClusterScheduler] + } catch { + // TODO: Enumerate the exact reasons why it can fail + // But irrespective of it, it means we cannot proceed ! + case th: Throwable => { + throw new SparkException("YARN mode not available ?", th) + } + } + val backend = new CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + scheduler.initialize(backend) + scheduler + + case "yarn-client" => + val scheduler = try { + val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") + val cons = clazz.getConstructor(classOf[SparkContext]) + cons.newInstance(sc).asInstanceOf[ClusterScheduler] + + } catch { + case th: Throwable => { + throw new SparkException("YARN mode not available ?", th) + } + } + + val backend = try { + val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend") + val cons = clazz.getConstructor(classOf[ClusterScheduler], classOf[SparkContext]) + cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] + } catch { + case th: Throwable => { + throw new SparkException("YARN mode not available ?", th) + } + } + + scheduler.initialize(backend) + scheduler + + case mesosUrl @ MESOS_REGEX(_) => + MesosNativeLibrary.load() + val scheduler = new ClusterScheduler(sc) + val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean + val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs + val backend = if (coarseGrained) { + new CoarseMesosSchedulerBackend(scheduler, sc, url, appName) + } else { + new MesosSchedulerBackend(scheduler, sc, url, appName) + } + scheduler.initialize(backend) + scheduler + + case SIMR_REGEX(simrUrl) => + val scheduler = new ClusterScheduler(sc) + val backend = new SimrSchedulerBackend(scheduler, sc, simrUrl) + scheduler.initialize(backend) + scheduler + + case _ => + throw new SparkException("Could not parse Master URL: '" + master + "'") + } + } } /** * A class encapsulating how to convert some type T to Writable. It stores both the Writable class * corresponding to T (e.g. IntWritable for Int) and a function for doing the conversion. - * The getter for the writable class takes a ClassManifest[T] in case this is a generic object + * The getter for the writable class takes a ClassTag[T] in case this is a generic object * that doesn't know the type of T when it is created. This sounds strange but is necessary to * support converting subclasses of Writable to themselves (writableWritableConverter). */ private[spark] class WritableConverter[T]( - val writableClass: ClassManifest[T] => Class[_ <: Writable], + val writableClass: ClassTag[T] => Class[_ <: Writable], val convert: Writable => T) extends Serializable diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index ff2df8fb6a..826f5c2d8c 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -20,7 +20,7 @@ package org.apache.spark import collection.mutable import serializer.Serializer -import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem} +import akka.actor._ import akka.remote.RemoteActorRefProvider import org.apache.spark.broadcast.BroadcastManager @@ -74,7 +74,8 @@ class SparkEnv ( actorSystem.shutdown() // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut // down, but let's call it anyway in case it gets fixed in a later release - actorSystem.awaitTermination() + // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. + //actorSystem.awaitTermination() } def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = { @@ -151,17 +152,17 @@ object SparkEnv extends Logging { val closureSerializer = serializerManager.get( System.getProperty("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer")) - def registerOrLookup(name: String, newActor: => Actor): ActorRef = { + def registerOrLookup(name: String, newActor: => Actor): Either[ActorRef, ActorSelection] = { if (isDriver) { logInfo("Registering " + name) - actorSystem.actorOf(Props(newActor), name = name) + Left(actorSystem.actorOf(Props(newActor), name = name)) } else { val driverHost: String = System.getProperty("spark.driver.host", "localhost") val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt Utils.checkHost(driverHost, "Expected hostname") - val url = "akka://spark@%s:%s/user/%s".format(driverHost, driverPort, name) + val url = "akka.tcp://spark@%s:%s/user/%s".format(driverHost, driverPort, name) logInfo("Connecting to " + name + ": " + url) - actorSystem.actorFor(url) + Right(actorSystem.actorSelection(url)) } } diff --git a/core/src/main/scala/org/apache/spark/TaskState.scala b/core/src/main/scala/org/apache/spark/TaskState.scala index 19ce8369d9..0bf1e4a5e2 100644 --- a/core/src/main/scala/org/apache/spark/TaskState.scala +++ b/core/src/main/scala/org/apache/spark/TaskState.scala @@ -19,8 +19,7 @@ package org.apache.spark import org.apache.mesos.Protos.{TaskState => MesosTaskState} -private[spark] object TaskState - extends Enumeration("LAUNCHING", "RUNNING", "FINISHED", "FAILED", "KILLED", "LOST") { +private[spark] object TaskState extends Enumeration { val LAUNCHING, RUNNING, FINISHED, FAILED, KILLED, LOST = Value diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 043cb183ba..da30cf619a 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -17,18 +17,23 @@ package org.apache.spark.api.java +import scala.reflect.ClassTag + import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.util.StatCounter import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.storage.StorageLevel + import java.lang.Double import org.apache.spark.Partitioner +import scala.collection.JavaConverters._ + class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, JavaDoubleRDD] { - override val classManifest: ClassManifest[Double] = implicitly[ClassManifest[Double]] + override val classTag: ClassTag[Double] = implicitly[ClassTag[Double]] override val rdd: RDD[Double] = srdd.map(x => Double.valueOf(x)) @@ -42,7 +47,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ def cache(): JavaDoubleRDD = fromRDD(srdd.cache()) - /** + /** * Set this RDD's storage level to persist its values across operations after the first time * it is computed. Can only be called once on each RDD. */ @@ -106,7 +111,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav /** * Return an RDD with the elements from `this` that are not in `other`. - * + * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting * RDD will be <= us. */ @@ -182,6 +187,44 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav /** (Experimental) Approximate operation to return the sum within a timeout. */ def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout) + + /** + * Compute a histogram of the data using bucketCount number of buckets evenly + * spaced between the minimum and maximum of the RDD. For example if the min + * value is 0 and the max is 100 and there are two buckets the resulting + * buckets will be [0,50) [50,100]. bucketCount must be at least 1 + * If the RDD contains infinity, NaN throws an exception + * If the elements in RDD do not vary (max == min) always returns a single bucket. + */ + def histogram(bucketCount: Int): Pair[Array[scala.Double], Array[Long]] = { + val result = srdd.histogram(bucketCount) + (result._1, result._2) + } + + /** + * Compute a histogram using the provided buckets. The buckets are all open + * to the left except for the last which is closed + * e.g. for the array + * [1,10,20,50] the buckets are [1,10) [10,20) [20,50] + * e.g 1<=x<10 , 10<=x<20, 20<=x<50 + * And on the input of 1 and 50 we would have a histogram of 1,0,0 + * + * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched + * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets + * to true. + * buckets must be sorted and not contain any duplicates. + * buckets array must be at least two elements + * All NaN entries are treated the same. If you have a NaN bucket it must be + * the maximum value of the last position and all NaN entries will be counted + * in that bucket. + */ + def histogram(buckets: Array[scala.Double]): Array[Long] = { + srdd.histogram(buckets, false) + } + + def histogram(buckets: Array[Double], evenBuckets: Boolean): Array[Long] = { + srdd.histogram(buckets.map(_.toDouble), evenBuckets) + } } object JavaDoubleRDD { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 2142fd7327..363667fa86 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -22,6 +22,7 @@ import java.util.Comparator import scala.Tuple2 import scala.collection.JavaConversions._ +import scala.reflect.ClassTag import com.google.common.base.Optional import org.apache.hadoop.io.compress.CompressionCodec @@ -43,13 +44,13 @@ import org.apache.spark.rdd.OrderedRDDFunctions import org.apache.spark.storage.StorageLevel -class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManifest[K], - implicit val vManifest: ClassManifest[V]) extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] { +class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K], + implicit val vClassTag: ClassTag[V]) extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] { override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd) - override val classManifest: ClassManifest[(K, V)] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]] + override val classTag: ClassTag[(K, V)] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K, V]]] import JavaPairRDD._ @@ -58,7 +59,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ def cache(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.cache()) - /** + /** * Set this RDD's storage level to persist its values across operations after the first time * it is computed. Can only be called once on each RDD. */ @@ -138,14 +139,14 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif override def first(): (K, V) = rdd.first() // Pair RDD functions - + /** - * Generic function to combine the elements for each key using a custom set of aggregation - * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C * Note that V and C can be different -- for example, one might group an - * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three + * Generic function to combine the elements for each key using a custom set of aggregation + * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a + * "combined type" C * Note that V and C can be different -- for example, one might group an + * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three * functions: - * + * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) * - `mergeCombiners`, to combine two C's into a single one. @@ -157,8 +158,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif mergeValue: JFunction2[C, V, C], mergeCombiners: JFunction2[C, C, C], partitioner: Partitioner): JavaPairRDD[K, C] = { - implicit val cm: ClassManifest[C] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]] + implicit val cm: ClassTag[C] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[C]] fromRDD(rdd.combineByKey( createCombiner, mergeValue, @@ -195,14 +195,14 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif /** Count the number of elements for each key, and return the result to the master as a Map. */ def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey()) - /** + /** * (Experimental) Approximate version of countByKey that can return a partial result if it does * not finish within a timeout. */ def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] = rdd.countByKeyApprox(timeout).map(mapAsJavaMap) - /** + /** * (Experimental) Approximate version of countByKey that can return a partial result if it does * not finish within a timeout. */ @@ -258,7 +258,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif /** * Return an RDD with the elements from `this` that are not in `other`. - * + * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting * RDD will be <= us. */ @@ -315,15 +315,14 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)}) } - /** + /** * Simplified version of combineByKey that hash-partitions the resulting RDD using the existing * partitioner/parallelism level. */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], mergeCombiners: JFunction2[C, C, C]): JavaPairRDD[K, C] = { - implicit val cm: ClassManifest[C] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]] + implicit val cm: ClassTag[C] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[C]] fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(rdd))) } @@ -414,8 +413,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif * this also retains the original RDD's partitioning. */ def mapValues[U](f: JFunction[V, U]): JavaPairRDD[K, U] = { - implicit val cm: ClassManifest[U] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] + implicit val cm: ClassTag[U] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]] fromRDD(rdd.mapValues(f)) } @@ -426,8 +424,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = { import scala.collection.JavaConverters._ def fn = (x: V) => f.apply(x).asScala - implicit val cm: ClassManifest[U] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] + implicit val cm: ClassTag[U] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]] fromRDD(rdd.flatMapValues(fn)) } @@ -591,6 +588,20 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif fromRDD(new OrderedRDDFunctions[K, V, (K, V)](rdd).sortByKey(ascending)) } + /** + * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling + * `collect` or `save` on the resulting RDD will return or output an ordered list of records + * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in + * order of the keys). + */ + def sortByKey(comp: Comparator[K], ascending: Boolean, numPartitions: Int): JavaPairRDD[K, V] = { + class KeyOrdering(val a: K) extends Ordered[K] { + override def compare(b: K) = comp.compare(a, b) + } + implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x) + fromRDD(new OrderedRDDFunctions[K, V, (K, V)](rdd).sortByKey(ascending, numPartitions)) + } + /** * Return an RDD with the keys of each tuple. */ @@ -603,22 +614,22 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif } object JavaPairRDD { - def groupByResultToJava[K, T](rdd: RDD[(K, Seq[T])])(implicit kcm: ClassManifest[K], - vcm: ClassManifest[T]): RDD[(K, JList[T])] = + def groupByResultToJava[K, T](rdd: RDD[(K, Seq[T])])(implicit kcm: ClassTag[K], + vcm: ClassTag[T]): RDD[(K, JList[T])] = rddToPairRDDFunctions(rdd).mapValues(seqAsJavaList _) - def cogroupResultToJava[W, K, V](rdd: RDD[(K, (Seq[V], Seq[W]))])(implicit kcm: ClassManifest[K], - vcm: ClassManifest[V]): RDD[(K, (JList[V], JList[W]))] = rddToPairRDDFunctions(rdd).mapValues((x: (Seq[V], - Seq[W])) => (seqAsJavaList(x._1), seqAsJavaList(x._2))) + def cogroupResultToJava[W, K, V](rdd: RDD[(K, (Seq[V], Seq[W]))])(implicit kcm: ClassTag[K], + vcm: ClassTag[V]): RDD[(K, (JList[V], JList[W]))] = rddToPairRDDFunctions(rdd) + .mapValues((x: (Seq[V], Seq[W])) => (seqAsJavaList(x._1), seqAsJavaList(x._2))) def cogroupResult2ToJava[W1, W2, K, V](rdd: RDD[(K, (Seq[V], Seq[W1], - Seq[W2]))])(implicit kcm: ClassManifest[K]) : RDD[(K, (JList[V], JList[W1], + Seq[W2]))])(implicit kcm: ClassTag[K]) : RDD[(K, (JList[V], JList[W1], JList[W2]))] = rddToPairRDDFunctions(rdd).mapValues( (x: (Seq[V], Seq[W1], Seq[W2])) => (seqAsJavaList(x._1), seqAsJavaList(x._2), seqAsJavaList(x._3))) - def fromRDD[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = + def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd) implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd @@ -626,10 +637,8 @@ object JavaPairRDD { /** Convert a JavaRDD of key-value pairs to JavaPairRDD. */ def fromJavaRDD[K, V](rdd: JavaRDD[(K, V)]): JavaPairRDD[K, V] = { - implicit val cmk: ClassManifest[K] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] - implicit val cmv: ClassManifest[V] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]] + implicit val cmk: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] + implicit val cmv: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] new JavaPairRDD[K, V](rdd.rdd) } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 3b359a8fd6..c47657f512 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -17,12 +17,14 @@ package org.apache.spark.api.java +import scala.reflect.ClassTag + import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.storage.StorageLevel -class JavaRDD[T](val rdd: RDD[T])(implicit val classManifest: ClassManifest[T]) extends +class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) extends JavaRDDLike[T, JavaRDD[T]] { override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd) @@ -127,8 +129,7 @@ JavaRDDLike[T, JavaRDD[T]] { object JavaRDD { - implicit def fromRDD[T: ClassManifest](rdd: RDD[T]): JavaRDD[T] = new JavaRDD[T](rdd) + implicit def fromRDD[T: ClassTag](rdd: RDD[T]): JavaRDD[T] = new JavaRDD[T](rdd) implicit def toRDD[T](rdd: JavaRDD[T]): RDD[T] = rdd.rdd } - diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 7a3568c5ef..9e912d3adb 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -20,6 +20,7 @@ package org.apache.spark.api.java import java.util.{List => JList, Comparator} import scala.Tuple2 import scala.collection.JavaConversions._ +import scala.reflect.ClassTag import com.google.common.base.Optional import org.apache.hadoop.io.compress.CompressionCodec @@ -35,7 +36,7 @@ import org.apache.spark.storage.StorageLevel trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def wrapRDD(rdd: RDD[T]): This - implicit val classManifest: ClassManifest[T] + implicit val classTag: ClassTag[T] def rdd: RDD[T] @@ -71,7 +72,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to each partition of this RDD, while tracking the index * of the original partition. */ - def mapPartitionsWithIndex[R: ClassManifest]( + def mapPartitionsWithIndex[R: ClassTag]( f: JFunction2[Int, java.util.Iterator[T], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))), @@ -87,7 +88,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to all elements of this RDD. */ def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { - def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]] + def cm = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K2, V2]]] new JavaPairRDD(rdd.map(f)(cm))(f.keyType(), f.valueType()) } @@ -118,7 +119,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { import scala.collection.JavaConverters._ def fn = (x: T) => f.apply(x).asScala - def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]] + def cm = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K2, V2]]] JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType()) } @@ -158,18 +159,16 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * elements (a, b) where a is in `this` and b is in `other`. */ def cartesian[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] = - JavaPairRDD.fromRDD(rdd.cartesian(other.rdd)(other.classManifest))(classManifest, - other.classManifest) + JavaPairRDD.fromRDD(rdd.cartesian(other.rdd)(other.classTag))(classTag, other.classTag) /** * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements * mapping to that key. */ def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JList[T]] = { - implicit val kcm: ClassManifest[K] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] - implicit val vcm: ClassManifest[JList[T]] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]] + implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] + implicit val vcm: ClassTag[JList[T]] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[JList[T]]] JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(f.returnType)))(kcm, vcm) } @@ -178,10 +177,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * mapping to that key. */ def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JList[T]] = { - implicit val kcm: ClassManifest[K] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] - implicit val vcm: ClassManifest[JList[T]] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]] + implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] + implicit val vcm: ClassTag[JList[T]] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[JList[T]]] JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(f.returnType)))(kcm, vcm) } @@ -209,7 +207,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * a map on the other). */ def zip[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] = { - JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classManifest))(classManifest, other.classManifest) + JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classTag))(classTag, other.classTag) } /** @@ -224,7 +222,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator( f.apply(asJavaIterator(x), asJavaIterator(y)).iterator()) JavaRDD.fromRDD( - rdd.zipPartitions(other.rdd)(fn)(other.classManifest, f.elementType()))(f.elementType()) + rdd.zipPartitions(other.rdd)(fn)(other.classTag, f.elementType()))(f.elementType()) } // Actions (launch a job to return a value to the user program) @@ -356,7 +354,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Creates tuples of the elements in this RDD by applying `f`. */ def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = { - implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] + implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] JavaPairRDD.fromRDD(rdd.keyBy(f)) } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 8869e072bf..acf328aa6a 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -21,6 +21,7 @@ import java.util.{Map => JMap} import scala.collection.JavaConversions import scala.collection.JavaConversions._ +import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.InputFormat @@ -82,8 +83,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork /** Distribute a local Scala collection to form an RDD. */ def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices) } @@ -94,10 +94,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork /** Distribute a local Scala collection to form an RDD. */ def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]], numSlices: Int) : JavaPairRDD[K, V] = { - implicit val kcm: ClassManifest[K] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] - implicit val vcm: ClassManifest[V] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]] + implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] + implicit val vcm: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] JavaPairRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices)) } @@ -132,16 +130,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork valueClass: Class[V], minSplits: Int ): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(keyClass) - implicit val vcm = ClassManifest.fromClass(valueClass) + implicit val kcm: ClassTag[K] = ClassTag(keyClass) + implicit val vcm: ClassTag[V] = ClassTag(valueClass) new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass, minSplits)) } /**Get an RDD for a Hadoop SequenceFile. */ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(keyClass) - implicit val vcm = ClassManifest.fromClass(valueClass) + implicit val kcm: ClassTag[K] = ClassTag(keyClass) + implicit val vcm: ClassTag[V] = ClassTag(valueClass) new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass)) } @@ -153,8 +151,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork * that there's very little effort required to save arbitrary objects. */ def objectFile[T](path: String, minSplits: Int): JavaRDD[T] = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] sc.objectFile(path, minSplits)(cm) } @@ -166,8 +163,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork * that there's very little effort required to save arbitrary objects. */ def objectFile[T](path: String): JavaRDD[T] = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] sc.objectFile(path)(cm) } @@ -183,8 +179,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork valueClass: Class[V], minSplits: Int ): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(keyClass) - implicit val vcm = ClassManifest.fromClass(valueClass) + implicit val kcm: ClassTag[K] = ClassTag(keyClass) + implicit val vcm: ClassTag[V] = ClassTag(valueClass) new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minSplits)) } @@ -199,8 +195,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork keyClass: Class[K], valueClass: Class[V] ): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(keyClass) - implicit val vcm = ClassManifest.fromClass(valueClass) + implicit val kcm: ClassTag[K] = ClassTag(keyClass) + implicit val vcm: ClassTag[V] = ClassTag(valueClass) new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass)) } @@ -212,8 +208,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork valueClass: Class[V], minSplits: Int ): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(keyClass) - implicit val vcm = ClassManifest.fromClass(valueClass) + implicit val kcm: ClassTag[K] = ClassTag(keyClass) + implicit val vcm: ClassTag[V] = ClassTag(valueClass) new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits)) } @@ -224,8 +220,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork keyClass: Class[K], valueClass: Class[V] ): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(keyClass) - implicit val vcm = ClassManifest.fromClass(valueClass) + implicit val kcm: ClassTag[K] = ClassTag(keyClass) + implicit val vcm: ClassTag[V] = ClassTag(valueClass) new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass)) } @@ -240,8 +236,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork kClass: Class[K], vClass: Class[V], conf: Configuration): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(kClass) - implicit val vcm = ClassManifest.fromClass(vClass) + implicit val kcm: ClassTag[K] = ClassTag(kClass) + implicit val vcm: ClassTag[V] = ClassTag(vClass) new JavaPairRDD(sc.newAPIHadoopFile(path, fClass, kClass, vClass, conf)) } @@ -254,15 +250,15 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork fClass: Class[F], kClass: Class[K], vClass: Class[V]): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(kClass) - implicit val vcm = ClassManifest.fromClass(vClass) + implicit val kcm: ClassTag[K] = ClassTag(kClass) + implicit val vcm: ClassTag[V] = ClassTag(vClass) new JavaPairRDD(sc.newAPIHadoopRDD(conf, fClass, kClass, vClass)) } /** Build the union of two or more RDDs. */ override def union[T](first: JavaRDD[T], rest: java.util.List[JavaRDD[T]]): JavaRDD[T] = { val rdds: Seq[RDD[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd) - implicit val cm: ClassManifest[T] = first.classManifest + implicit val cm: ClassTag[T] = first.classTag sc.union(rdds)(cm) } @@ -270,9 +266,9 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork override def union[K, V](first: JavaPairRDD[K, V], rest: java.util.List[JavaPairRDD[K, V]]) : JavaPairRDD[K, V] = { val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd) - implicit val cm: ClassManifest[(K, V)] = first.classManifest - implicit val kcm: ClassManifest[K] = first.kManifest - implicit val vcm: ClassManifest[V] = first.vManifest + implicit val cm: ClassTag[(K, V)] = first.classTag + implicit val kcm: ClassTag[K] = first.kClassTag + implicit val vcm: ClassTag[V] = first.vClassTag new JavaPairRDD(sc.union(rdds)(cm))(kcm, vcm) } @@ -405,8 +401,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork } protected def checkpointFile[T](path: String): JavaRDD[T] = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + implicit val cm: ClassTag[T] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] new JavaRDD(sc.checkpointFile(path)) } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java index c9cbce5624..2090efd3b9 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java @@ -17,7 +17,6 @@ package org.apache.spark.api.java; -import java.util.Arrays; import java.util.ArrayList; import java.util.List; diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala index 2dfda8b09a..bdb01f7670 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala +++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala @@ -17,9 +17,11 @@ package org.apache.spark.api.java.function +import scala.reflect.ClassTag + /** * A function that returns zero or more output records from each input record. */ abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] { - def elementType() : ClassManifest[R] = ClassManifest.Any.asInstanceOf[ClassManifest[R]] + def elementType(): ClassTag[R] = ClassTag.Any.asInstanceOf[ClassTag[R]] } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala index 528e1c0a7c..aae1349c5e 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala +++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala @@ -17,9 +17,11 @@ package org.apache.spark.api.java.function +import scala.reflect.ClassTag + /** * A function that takes two inputs and returns zero or more output records. */ abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] { - def elementType() : ClassManifest[C] = ClassManifest.Any.asInstanceOf[ClassManifest[C]] + def elementType() : ClassTag[C] = ClassTag.Any.asInstanceOf[ClassTag[C]] } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function.java b/core/src/main/scala/org/apache/spark/api/java/function/Function.java index ce368ee01b..537439ef53 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/Function.java +++ b/core/src/main/scala/org/apache/spark/api/java/function/Function.java @@ -17,8 +17,8 @@ package org.apache.spark.api.java.function; -import scala.reflect.ClassManifest; -import scala.reflect.ClassManifest$; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; import java.io.Serializable; @@ -29,8 +29,8 @@ import java.io.Serializable; * when mapping RDDs of other types. */ public abstract class Function extends WrappedFunction1 implements Serializable { - public ClassManifest returnType() { - return (ClassManifest) ClassManifest$.MODULE$.fromClass(Object.class); + public ClassTag returnType() { + return ClassTag$.MODULE$.apply(Object.class); } } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function2.java b/core/src/main/scala/org/apache/spark/api/java/function/Function2.java index 44ad559d48..a2d1214fb4 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/Function2.java +++ b/core/src/main/scala/org/apache/spark/api/java/function/Function2.java @@ -17,8 +17,8 @@ package org.apache.spark.api.java.function; -import scala.reflect.ClassManifest; -import scala.reflect.ClassManifest$; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; import java.io.Serializable; @@ -28,8 +28,8 @@ import java.io.Serializable; public abstract class Function2 extends WrappedFunction2 implements Serializable { - public ClassManifest returnType() { - return (ClassManifest) ClassManifest$.MODULE$.fromClass(Object.class); + public ClassTag returnType() { + return (ClassTag) ClassTag$.MODULE$.apply(Object.class); } } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function3.java b/core/src/main/scala/org/apache/spark/api/java/function/Function3.java index ac6178924a..fb1deceab5 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/Function3.java +++ b/core/src/main/scala/org/apache/spark/api/java/function/Function3.java @@ -17,8 +17,8 @@ package org.apache.spark.api.java.function; -import scala.reflect.ClassManifest; -import scala.reflect.ClassManifest$; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; import scala.runtime.AbstractFunction2; import java.io.Serializable; @@ -29,8 +29,8 @@ import java.io.Serializable; public abstract class Function3 extends WrappedFunction3 implements Serializable { - public ClassManifest returnType() { - return (ClassManifest) ClassManifest$.MODULE$.fromClass(Object.class); + public ClassTag returnType() { + return (ClassTag) ClassTag$.MODULE$.apply(Object.class); } } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java index 6d76a8f970..ca485b3cc2 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java +++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java @@ -18,8 +18,8 @@ package org.apache.spark.api.java.function; import scala.Tuple2; -import scala.reflect.ClassManifest; -import scala.reflect.ClassManifest$; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; import java.io.Serializable; @@ -33,11 +33,11 @@ public abstract class PairFlatMapFunction extends WrappedFunction1>> implements Serializable { - public ClassManifest keyType() { - return (ClassManifest) ClassManifest$.MODULE$.fromClass(Object.class); + public ClassTag keyType() { + return (ClassTag) ClassTag$.MODULE$.apply(Object.class); } - public ClassManifest valueType() { - return (ClassManifest) ClassManifest$.MODULE$.fromClass(Object.class); + public ClassTag valueType() { + return (ClassTag) ClassTag$.MODULE$.apply(Object.class); } } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java index ede7ceefb5..cbe2306026 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java +++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java @@ -18,8 +18,8 @@ package org.apache.spark.api.java.function; import scala.Tuple2; -import scala.reflect.ClassManifest; -import scala.reflect.ClassManifest$; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; import java.io.Serializable; @@ -31,11 +31,11 @@ import java.io.Serializable; public abstract class PairFunction extends WrappedFunction1> implements Serializable { - public ClassManifest keyType() { - return (ClassManifest) ClassManifest$.MODULE$.fromClass(Object.class); + public ClassTag keyType() { + return (ClassTag) ClassTag$.MODULE$.apply(Object.class); } - public ClassManifest valueType() { - return (ClassManifest) ClassManifest$.MODULE$.fromClass(Object.class); + public ClassTag valueType() { + return (ClassTag) ClassTag$.MODULE$.apply(Object.class); } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 12b4d94a56..a659cc06c2 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -22,18 +22,17 @@ import java.net._ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ +import scala.reflect.ClassTag import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import org.apache.spark.broadcast.Broadcast import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.PipedRDD import org.apache.spark.util.Utils - -private[spark] class PythonRDD[T: ClassManifest]( +private[spark] class PythonRDD[T: ClassTag]( parent: RDD[T], - command: Seq[String], + command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], preservePartitoning: Boolean, @@ -44,21 +43,10 @@ private[spark] class PythonRDD[T: ClassManifest]( val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - // Similar to Runtime.exec(), if we are given a single string, split it into words - // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, envVars: JMap[String, String], - pythonIncludes: JList[String], - preservePartitoning: Boolean, pythonExec: String, - broadcastVars: JList[Broadcast[Array[Byte]]], - accumulator: Accumulator[JList[Array[Byte]]]) = - this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec, - broadcastVars, accumulator) - override def getPartitions = parent.partitions override val partitioner = if (preservePartitoning) parent.partitioner else None - override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis val env = SparkEnv.get @@ -71,11 +59,10 @@ private[spark] class PythonRDD[T: ClassManifest]( SparkEnv.set(env) val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) - val printOut = new PrintWriter(stream) // Partition index dataOut.writeInt(split.index) // sparkFilesDir - PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut) + dataOut.writeUTF(SparkFiles.getRootDirectory) // Broadcast variables dataOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { @@ -85,21 +72,16 @@ private[spark] class PythonRDD[T: ClassManifest]( } // Python includes (*.zip and *.egg files) dataOut.writeInt(pythonIncludes.length) - for (f <- pythonIncludes) { - PythonRDD.writeAsPickle(f, dataOut) - } + pythonIncludes.foreach(dataOut.writeUTF) dataOut.flush() - // Serialized user code - for (elem <- command) { - printOut.println(elem) - } - printOut.flush() + // Serialized command: + dataOut.writeInt(command.length) + dataOut.write(command) // Data values for (elem <- parent.iterator(split, context)) { - PythonRDD.writeAsPickle(elem, dataOut) + PythonRDD.writeToStream(elem, dataOut) } dataOut.flush() - printOut.flush() worker.shutdownOutput() } catch { case e: IOException => @@ -132,7 +114,7 @@ private[spark] class PythonRDD[T: ClassManifest]( val obj = new Array[Byte](length) stream.readFully(obj) obj - case -3 => + case SpecialLengths.TIMING_DATA => // Timing data from worker val bootTime = stream.readLong() val initTime = stream.readLong() @@ -143,30 +125,30 @@ private[spark] class PythonRDD[T: ClassManifest]( val total = finishTime - startTime logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish)) read - case -2 => + case SpecialLengths.PYTHON_EXCEPTION_THROWN => // Signals that an exception has been thrown in python val exLength = stream.readInt() val obj = new Array[Byte](exLength) stream.readFully(obj) throw new PythonException(new String(obj)) - case -1 => + case SpecialLengths.END_OF_DATA_SECTION => // We've finished the data section of the output, but we can still - // read some accumulator updates; let's do that, breaking when we - // get a negative length record. - var len2 = stream.readInt() - while (len2 >= 0) { - val update = new Array[Byte](len2) + // read some accumulator updates: + val numAccumulatorUpdates = stream.readInt() + (1 to numAccumulatorUpdates).foreach { _ => + val updateLen = stream.readInt() + val update = new Array[Byte](updateLen) stream.readFully(update) accumulator += Collections.singletonList(update) - len2 = stream.readInt() + } - new Array[Byte](0) + Array.empty[Byte] } } catch { case eof: EOFException => { throw new SparkException("Python worker exited unexpectedly (crashed)", eof) } - case e => throw e + case e: Throwable => throw e } } @@ -197,62 +179,15 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this) } +private object SpecialLengths { + val END_OF_DATA_SECTION = -1 + val PYTHON_EXCEPTION_THROWN = -2 + val TIMING_DATA = -3 +} + private[spark] object PythonRDD { - /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */ - def stripPickle(arr: Array[Byte]) : Array[Byte] = { - arr.slice(2, arr.length - 1) - } - - /** - * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream. - * The data format is a 32-bit integer representing the pickled object's length (in bytes), - * followed by the pickled data. - * - * Pickle module: - * - * http://docs.python.org/2/library/pickle.html - * - * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules: - * - * http://hg.python.org/cpython/file/2.6/Lib/pickle.py - * http://hg.python.org/cpython/file/2.6/Lib/pickletools.py - * - * @param elem the object to write - * @param dOut a data output stream - */ - def writeAsPickle(elem: Any, dOut: DataOutputStream) { - if (elem.isInstanceOf[Array[Byte]]) { - val arr = elem.asInstanceOf[Array[Byte]] - dOut.writeInt(arr.length) - dOut.write(arr) - } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) { - val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]] - val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.write(PythonRDD.stripPickle(t._1)) - dOut.write(PythonRDD.stripPickle(t._2)) - dOut.writeByte(Pickle.TUPLE2) - dOut.writeByte(Pickle.STOP) - } else if (elem.isInstanceOf[String]) { - // For uniformity, strings are wrapped into Pickles. - val s = elem.asInstanceOf[String].getBytes("UTF-8") - val length = 2 + 1 + 4 + s.length + 1 - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.write(Pickle.BINUNICODE) - dOut.writeInt(Integer.reverseBytes(s.length)) - dOut.write(s) - dOut.writeByte(Pickle.STOP) - } else { - throw new SparkException("Unexpected RDD type") - } - } - - def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) : + def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): JavaRDD[Array[Byte]] = { val file = new DataInputStream(new FileInputStream(filename)) val objs = new collection.mutable.ArrayBuffer[Array[Byte]] @@ -265,41 +200,47 @@ private[spark] object PythonRDD { } } catch { case eof: EOFException => {} - case e => throw e + case e: Throwable => throw e } JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } - def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) { - import scala.collection.JavaConverters._ - writeIteratorToPickleFile(items.asScala, filename) + def writeToStream(elem: Any, dataOut: DataOutputStream) { + elem match { + case bytes: Array[Byte] => + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + case pair: (Array[Byte], Array[Byte]) => + dataOut.writeInt(pair._1.length) + dataOut.write(pair._1) + dataOut.writeInt(pair._2.length) + dataOut.write(pair._2) + case str: String => + dataOut.writeUTF(str) + case other => + throw new SparkException("Unexpected element type " + other.getClass) + } } - def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) { + def writeToFile[T](items: java.util.Iterator[T], filename: String) { + import scala.collection.JavaConverters._ + writeToFile(items.asScala, filename) + } + + def writeToFile[T](items: Iterator[T], filename: String) { val file = new DataOutputStream(new FileOutputStream(filename)) for (item <- items) { - writeAsPickle(item, file) + writeToStream(item, file) } file.close() } def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = { - implicit val cm : ClassManifest[T] = rdd.elementClassManifest + implicit val cm : ClassTag[T] = rdd.elementClassTag rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator } } -private object Pickle { - val PROTO: Byte = 0x80.toByte - val TWO: Byte = 0x02.toByte - val BINUNICODE: Byte = 'X' - val STOP: Byte = '.' - val TUPLE2: Byte = 0x86.toByte - val EMPTY_LIST: Byte = ']' - val MARK: Byte = '(' - val APPENDS: Byte = 'e' -} - private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] { override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 67d45723ba..f291266fcf 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -64,7 +64,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String startDaemon() new Socket(daemonHost, daemonPort) } - case e => throw e + case e: Throwable => throw e } } } @@ -198,7 +198,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } }.start() } catch { - case e => { + case e: Throwable => { stopDaemon() throw e } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 609464e38d..47db720416 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -19,6 +19,7 @@ package org.apache.spark.broadcast import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream} import java.net.URL +import java.util.concurrent.TimeUnit import it.unimi.dsi.fastutil.io.FastBufferedInputStream import it.unimi.dsi.fastutil.io.FastBufferedOutputStream @@ -83,6 +84,8 @@ private object HttpBroadcast extends Logging { private val files = new TimeStampedHashSet[String] private val cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup) + private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5,TimeUnit.MINUTES).toInt + private lazy val compressionCodec = CompressionCodec.createCodec() def initialize(isDriver: Boolean) { @@ -138,10 +141,13 @@ private object HttpBroadcast extends Logging { def read[T](id: Long): T = { val url = serverUri + "/" + BroadcastBlockId(id).name val in = { + val httpConnection = new URL(url).openConnection() + httpConnection.setReadTimeout(httpReadTimeout) + val inputStream = httpConnection.getInputStream() if (compress) { - compressionCodec.compressedInputStream(new URL(url).openStream()) + compressionCodec.compressedInputStream(inputStream) } else { - new FastBufferedInputStream(new URL(url).openStream(), bufferSize) + new FastBufferedInputStream(inputStream, bufferSize) } } val ser = SparkEnv.get.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala index fcfea96ad6..37dfa7fec0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala @@ -17,8 +17,7 @@ package org.apache.spark.deploy -private[spark] object ExecutorState - extends Enumeration("LAUNCHING", "LOADING", "RUNNING", "KILLED", "FAILED", "LOST") { +private[spark] object ExecutorState extends Enumeration { val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST = Value diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index 668032a3a2..0aa8852649 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -1,19 +1,19 @@ /* * - * * Licensed to the Apache Software Foundation (ASF) under one or more - * * contributor license agreements. See the NOTICE file distributed with - * * this work for additional information regarding copyright ownership. - * * The ASF licenses this file to You under the Apache License, Version 2.0 - * * (the "License"); you may not use this file except in compliance with - * * the License. You may obtain a copy of the License at - * * - * * http://www.apache.org/licenses/LICENSE-2.0 - * * - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, - * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * * See the License for the specific language governing permissions and - * * limitations under the License. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. * */ diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index a724900943..59d12a3e6f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -34,11 +34,11 @@ import scala.collection.mutable.ArrayBuffer */ private[spark] class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging { - + private val localHostname = Utils.localHostName() private val masterActorSystems = ArrayBuffer[ActorSystem]() private val workerActorSystems = ArrayBuffer[ActorSystem]() - + def start(): Array[String] = { logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") @@ -61,10 +61,13 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I def stop() { logInfo("Shutting down local Spark cluster.") // Stop the workers before the master so they don't get upset that it disconnected + // TODO: In Akka 2.1.x, ActorSystem.awaitTermination hangs when you have remote actors! + // This is unfortunate, but for now we just comment it out. workerActorSystems.foreach(_.shutdown()) - workerActorSystems.foreach(_.awaitTermination()) - + //workerActorSystems.foreach(_.awaitTermination()) masterActorSystems.foreach(_.shutdown()) - masterActorSystems.foreach(_.awaitTermination()) + //masterActorSystems.foreach(_.awaitTermination()) + masterActorSystems.clear() + workerActorSystems.clear() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala index 77422f61ec..4d95efa73a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala @@ -19,17 +19,15 @@ package org.apache.spark.deploy.client import java.util.concurrent.TimeoutException -import akka.actor._ -import akka.actor.Terminated -import akka.pattern.ask -import akka.util.Duration -import akka.util.duration._ -import akka.remote.RemoteClientDisconnected -import akka.remote.RemoteClientLifeCycleEvent -import akka.remote.RemoteClientShutdown -import akka.dispatch.Await +import scala.concurrent.duration._ +import scala.concurrent.Await -import org.apache.spark.Logging +import akka.actor._ +import akka.pattern.AskTimeoutException +import akka.pattern.ask +import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent, AssociationErrorEvent} + +import org.apache.spark.{SparkException, Logging} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master @@ -51,18 +49,19 @@ private[spark] class Client( val REGISTRATION_TIMEOUT = 20.seconds val REGISTRATION_RETRIES = 3 + var masterAddress: Address = null var actor: ActorRef = null var appId: String = null var registered = false var activeMasterUrl: String = null class ClientActor extends Actor with Logging { - var master: ActorRef = null - var masterAddress: Address = null + var master: ActorSelection = null var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times var alreadyDead = false // To avoid calling listener.dead() multiple times override def preStart() { + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) try { registerWithMaster() } catch { @@ -76,7 +75,7 @@ private[spark] class Client( def tryRegisterAllMasters() { for (masterUrl <- masterUrls) { logInfo("Connecting to master " + masterUrl + "...") - val actor = context.actorFor(Master.toAkkaUrl(masterUrl)) + val actor = context.actorSelection(Master.toAkkaUrl(masterUrl)) actor ! RegisterApplication(appDescription) } } @@ -84,6 +83,7 @@ private[spark] class Client( def registerWithMaster() { tryRegisterAllMasters() + import context.dispatcher var retries = 0 lazy val retryTimer: Cancellable = context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) { @@ -102,10 +102,13 @@ private[spark] class Client( def changeMaster(url: String) { activeMasterUrl = url - master = context.actorFor(Master.toAkkaUrl(url)) - masterAddress = master.path.address - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - context.watch(master) // Doesn't work with remote actors, but useful for testing + master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl)) + masterAddress = activeMasterUrl match { + case Master.sparkUrlRegex(host, port) => + Address("akka.tcp", Master.systemName, host, port.toInt) + case x => + throw new SparkException("Invalid spark URL: " + x) + } } override def receive = { @@ -135,21 +138,12 @@ private[spark] class Client( case MasterChanged(masterUrl, masterWebUiUrl) => logInfo("Master has changed, new master is at " + masterUrl) - context.unwatch(master) changeMaster(masterUrl) alreadyDisconnected = false sender ! MasterChangeAcknowledged(appId) - case Terminated(actor_) if actor_ == master => - logWarning("Connection to master failed; waiting for master to reconnect...") - markDisconnected() - - case RemoteClientDisconnected(transport, address) if address == masterAddress => - logWarning("Connection to master failed; waiting for master to reconnect...") - markDisconnected() - - case RemoteClientShutdown(transport, address) if address == masterAddress => - logWarning("Connection to master failed; waiting for master to reconnect...") + case DisassociatedEvent(_, address, _) if address == masterAddress => + logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() case StopClient => diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala index fedf879eff..67e6c5d66a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala @@ -17,8 +17,7 @@ package org.apache.spark.deploy.master -private[spark] object ApplicationState - extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED", "UNKNOWN") { +private[spark] object ApplicationState extends Enumeration { type ApplicationState = Value diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index c0849ef324..043945a211 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -65,7 +65,7 @@ private[spark] class FileSystemPersistenceEngine( (apps, workers) } - private def serializeIntoFile(file: File, value: Serializable) { + private def serializeIntoFile(file: File, value: AnyRef) { val created = file.createNewFile() if (!created) { throw new IllegalStateException("Could not create file: " + file) } @@ -77,13 +77,13 @@ private[spark] class FileSystemPersistenceEngine( out.close() } - def deserializeFromFile[T <: Serializable](file: File)(implicit m: Manifest[T]): T = { + def deserializeFromFile[T](file: File)(implicit m: Manifest[T]): T = { val fileData = new Array[Byte](file.length().asInstanceOf[Int]) val dis = new DataInputStream(new FileInputStream(file)) dis.readFully(fileData) dis.close() - val clazz = m.erasure.asInstanceOf[Class[T]] + val clazz = m.runtimeClass.asInstanceOf[Class[T]] val serializer = serialization.serializerFor(clazz) serializer.fromBinary(fileData).asInstanceOf[T] } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index cd916672ac..c627dd3806 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -17,19 +17,20 @@ package org.apache.spark.deploy.master -import java.util.Date import java.text.SimpleDateFormat +import java.util.concurrent.TimeUnit +import java.util.Date import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.concurrent.duration.{Duration, FiniteDuration} import akka.actor._ -import akka.actor.Terminated -import akka.dispatch.Await import akka.pattern.ask -import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, RemoteClientShutdown} +import akka.remote._ import akka.serialization.SerializationExtension -import akka.util.duration._ -import akka.util.{Duration, Timeout} +import akka.util.Timeout import org.apache.spark.{Logging, SparkException} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} @@ -37,9 +38,11 @@ import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.MasterMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{Utils, AkkaUtils} private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging { + import context.dispatcher + val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000 val RETAINED_APPLICATIONS = System.getProperty("spark.deploy.retainedApplications", "200").toInt @@ -93,7 +96,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act override def preStart() { logInfo("Starting Spark master at " + masterUrl) // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) webUi.start() masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort.get context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut) @@ -113,13 +116,12 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act new BlackHolePersistenceEngine() } - leaderElectionAgent = context.actorOf(Props( - RECOVERY_MODE match { + leaderElectionAgent = RECOVERY_MODE match { case "ZOOKEEPER" => - new ZooKeeperLeaderElectionAgent(self, masterUrl) + context.actorOf(Props(classOf[ZooKeeperLeaderElectionAgent], self, masterUrl)) case _ => - new MonarchyLeaderAgent(self) - })) + context.actorOf(Props(classOf[MonarchyLeaderAgent], self)) + } } override def preRestart(reason: Throwable, message: Option[Any]) { @@ -142,9 +144,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act RecoveryState.ALIVE else RecoveryState.RECOVERING - logInfo("I have been elected leader! New state: " + state) - if (state == RecoveryState.RECOVERING) { beginRecovery(storedApps, storedWorkers) context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis) { completeRecovery() } @@ -156,7 +156,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act System.exit(0) } - case RegisterWorker(id, host, workerPort, cores, memory, webUiPort, publicAddress) => { + case RegisterWorker(id, workerHost, workerPort, cores, memory, workerWebUiPort, publicAddress) => { logInfo("Registering worker %s:%d with %d cores, %s RAM".format( host, workerPort, cores, Utils.megabytesToString(memory))) if (state == RecoveryState.STANDBY) { @@ -164,9 +164,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act } else if (idToWorker.contains(id)) { sender ! RegisterWorkerFailed("Duplicate worker ID") } else { - val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress) + val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, + sender, workerWebUiPort, publicAddress) registerWorker(worker) - context.watch(sender) // This doesn't work with remote actors but helps for testing persistenceEngine.addWorker(worker) sender ! RegisteredWorker(masterUrl, masterWebUiUrl) schedule() @@ -181,7 +181,6 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act val app = createApplication(description, sender) registerApplication(app) logInfo("Registered app " + description.name + " with ID " + app.id) - context.watch(sender) // This doesn't work with remote actors but helps for testing persistenceEngine.addApplication(app) sender ! RegisteredApplication(app.id, masterUrl) schedule() @@ -257,23 +256,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act if (canCompleteRecovery) { completeRecovery() } } - case Terminated(actor) => { - // The disconnected actor could've been either a worker or an app; remove whichever of - // those we have an entry for in the corresponding actor hashmap - actorToWorker.get(actor).foreach(removeWorker) - actorToApp.get(actor).foreach(finishApplication) - if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } - } - - case RemoteClientDisconnected(transport, address) => { - // The disconnected client could've been either a worker or an app; remove whichever it was - addressToWorker.get(address).foreach(removeWorker) - addressToApp.get(address).foreach(finishApplication) - if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } - } - - case RemoteClientShutdown(transport, address) => { + case DisassociatedEvent(_, address, _) => { // The disconnected client could've been either a worker or an app; remove whichever it was + logInfo(s"$address got disassociated, removing it.") addressToWorker.get(address).foreach(removeWorker) addressToApp.get(address).foreach(finishApplication) if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } @@ -530,9 +515,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act } private[spark] object Master { - private val systemName = "sparkMaster" + val systemName = "sparkMaster" private val actorName = "Master" - private val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r + val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r def main(argStrings: Array[String]) { val args = new MasterArguments(argStrings) @@ -540,11 +525,11 @@ private[spark] object Master { actorSystem.awaitTermination() } - /** Returns an `akka://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */ + /** Returns an `akka.tcp://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */ def toAkkaUrl(sparkUrl: String): String = { sparkUrl match { case sparkUrlRegex(host, port) => - "akka://%s@%s:%s/user/%s".format(systemName, host, port, actorName) + "akka.tcp://%s@%s:%s/user/%s".format(systemName, host, port, actorName) case _ => throw new SparkException("Invalid master URL: " + sparkUrl) } @@ -552,9 +537,9 @@ private[spark] object Master { def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int, Int) = { val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port) - val actor = actorSystem.actorOf(Props(new Master(host, boundPort, webUiPort)), name = actorName) - val timeoutDuration = Duration.create( - System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), name = actorName) + val timeoutDuration: FiniteDuration = Duration.create( + System.getProperty("spark.akka.askTimeout", "10").toLong, TimeUnit.SECONDS) implicit val timeout = Timeout(timeoutDuration) val respFuture = actor ? RequestWebUIPort // ask pattern val resp = Await.result(respFuture, timeoutDuration).asInstanceOf[WebUIPortResponse] diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala index b91be821f0..256a5a7c28 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala @@ -17,9 +17,7 @@ package org.apache.spark.deploy.master -private[spark] object RecoveryState - extends Enumeration("STANDBY", "ALIVE", "RECOVERING", "COMPLETING_RECOVERY") { - +private[spark] object RecoveryState extends Enumeration { type MasterState = Value val STANDBY, ALIVE, RECOVERING, COMPLETING_RECOVERY = Value diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala index c8d34f25e2..0b36ef6005 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala @@ -17,9 +17,7 @@ package org.apache.spark.deploy.master -private[spark] object WorkerState - extends Enumeration("ALIVE", "DEAD", "DECOMMISSIONED", "UNKNOWN") { - +private[spark] object WorkerState extends Enumeration { type WorkerState = Value val ALIVE, DEAD, DECOMMISSIONED, UNKNOWN = Value diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index a0233a7271..825344b3bb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -70,15 +70,15 @@ class ZooKeeperPersistenceEngine(serialization: Serialization) (apps, workers) } - private def serializeIntoFile(path: String, value: Serializable) { + private def serializeIntoFile(path: String, value: AnyRef) { val serializer = serialization.findSerializerFor(value) val serialized = serializer.toBinary(value) zk.create(path, serialized, CreateMode.PERSISTENT) } - def deserializeFromFile[T <: Serializable](filename: String)(implicit m: Manifest[T]): T = { + def deserializeFromFile[T](filename: String)(implicit m: Manifest[T]): T = { val fileData = zk.getData("/spark/master_status/" + filename) - val clazz = m.erasure.asInstanceOf[Class[T]] + val clazz = m.runtimeClass.asInstanceOf[Class[T]] val serializer = serialization.serializerFor(clazz) serializer.fromBinary(fileData).asInstanceOf[T] } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index f4e574d15d..3b983c19eb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -19,9 +19,10 @@ package org.apache.spark.deploy.master.ui import scala.xml.Node -import akka.dispatch.Await import akka.pattern.ask -import akka.util.duration._ + +import scala.concurrent.Await +import scala.concurrent.duration._ import javax.servlet.http.HttpServletRequest diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala index d7a57229b0..65e7a14e7a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala @@ -21,9 +21,9 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import akka.dispatch.Await +import scala.concurrent.Await import akka.pattern.ask -import akka.util.duration._ +import scala.concurrent.duration._ import net.liftweb.json.JsonAST.JValue diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index f4df729e87..a211ce2b42 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.master.ui -import akka.util.Duration +import scala.concurrent.duration._ import javax.servlet.http.HttpServletRequest diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 216d9d44ac..87531b6719 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -17,23 +17,31 @@ package org.apache.spark.deploy.worker +import java.io.File import java.text.SimpleDateFormat import java.util.Date -import java.io.File import scala.collection.mutable.HashMap +import scala.concurrent.duration._ import akka.actor._ -import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected} -import akka.util.duration._ +import akka.remote.{ DisassociatedEvent, RemotingLifecycleEvent} -import org.apache.spark.Logging +import org.apache.spark.{SparkException, Logging} import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.deploy.DeployMessages.WorkerStateResponse +import org.apache.spark.deploy.DeployMessages.RegisterWorkerFailed +import org.apache.spark.deploy.DeployMessages.KillExecutor +import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged +import org.apache.spark.deploy.DeployMessages.Heartbeat +import org.apache.spark.deploy.DeployMessages.RegisteredWorker +import org.apache.spark.deploy.DeployMessages.LaunchExecutor +import org.apache.spark.deploy.DeployMessages.RegisterWorker /** * @param masterUrls Each url should look like spark://host:port. @@ -47,6 +55,7 @@ private[spark] class Worker( masterUrls: Array[String], workDirPath: String = null) extends Actor with Logging { + import context.dispatcher Utils.checkHost(host, "Expected hostname") assert (port > 0) @@ -63,7 +72,8 @@ private[spark] class Worker( var masterIndex = 0 val masterLock: Object = new Object() - var master: ActorRef = null + var master: ActorSelection = null + var masterAddress: Address = null var activeMasterUrl: String = "" var activeMasterWebUiUrl : String = "" @volatile var registered = false @@ -114,7 +124,7 @@ private[spark] class Worker( logInfo("Spark home: " + sparkHome) createWorkDir() webUi = new WorkerWebUI(this, workDir, Some(webUiPort)) - + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) webUi.start() registerWithMaster() @@ -126,9 +136,13 @@ private[spark] class Worker( masterLock.synchronized { activeMasterUrl = url activeMasterWebUiUrl = uiUrl - master = context.actorFor(Master.toAkkaUrl(activeMasterUrl)) - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - context.watch(master) // Doesn't work with remote actors, but useful for testing + master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl)) + masterAddress = activeMasterUrl match { + case Master.sparkUrlRegex(_host, _port) => + Address("akka.tcp", Master.systemName, _host, _port.toInt) + case x => + throw new SparkException("Invalid spark URL: " + x) + } connected = true } } @@ -136,7 +150,7 @@ private[spark] class Worker( def tryRegisterAllMasters() { for (masterUrl <- masterUrls) { logInfo("Connecting to master " + masterUrl + "...") - val actor = context.actorFor(Master.toAkkaUrl(masterUrl)) + val actor = context.actorSelection(Master.toAkkaUrl(masterUrl)) actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort.get, publicAddress) } @@ -175,7 +189,6 @@ private[spark] class Worker( case MasterChanged(masterUrl, masterWebUiUrl) => logInfo("Master has changed, new master is at " + masterUrl) - context.unwatch(master) changeMaster(masterUrl, masterWebUiUrl) val execs = executors.values. @@ -234,13 +247,8 @@ private[spark] class Worker( } } - case Terminated(actor_) if actor_ == master => - masterDisconnected() - - case RemoteClientDisconnected(transport, address) if address == master.path.address => - masterDisconnected() - - case RemoteClientShutdown(transport, address) if address == master.path.address => + case x: DisassociatedEvent if x.remoteAddress == masterAddress => + logInfo(s"$x Disassociated !") masterDisconnected() case RequestWorkerState => { @@ -280,8 +288,8 @@ private[spark] object Worker { // The LocalSparkCluster runs multiple local sparkWorkerX actor systems val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port) - val actor = actorSystem.actorOf(Props(new Worker(host, boundPort, webUiPort, cores, memory, - masterUrls, workDir)), name = "Worker") + actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, + masterUrls, workDir), name = "Worker") (actorSystem, boundPort) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala index d2d3617498..1a768d501f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala @@ -21,9 +21,10 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import akka.dispatch.Await +import scala.concurrent.duration._ +import scala.concurrent.Await + import akka.pattern.ask -import akka.util.duration._ import net.liftweb.json.JsonAST.JValue diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 800f1cafcc..6c18a3c245 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -17,20 +17,19 @@ package org.apache.spark.deploy.worker.ui -import akka.util.{Duration, Timeout} +import java.io.File -import java.io.{FileInputStream, File} +import scala.concurrent.duration._ +import akka.util.Timeout import javax.servlet.http.HttpServletRequest -import org.eclipse.jetty.server.{Handler, Server} - +import org.apache.spark.Logging import org.apache.spark.deploy.worker.Worker -import org.apache.spark.{Logging} -import org.apache.spark.ui.JettyUtils +import org.apache.spark.ui.{JettyUtils, UIUtils} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.ui.UIUtils import org.apache.spark.util.Utils +import org.eclipse.jetty.server.{Handler, Server} /** * Web UI server for the standalone worker. diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 8332631838..debbdd4c44 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -19,15 +19,14 @@ package org.apache.spark.executor import java.nio.ByteBuffer -import akka.actor.{ActorRef, Actor, Props, Terminated} -import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected} +import akka.actor._ +import akka.remote._ import org.apache.spark.Logging import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{Utils, AkkaUtils} - private[spark] class CoarseGrainedExecutorBackend( driverUrl: String, executorId: String, @@ -40,14 +39,13 @@ private[spark] class CoarseGrainedExecutorBackend( Utils.checkHostPort(hostPort, "Expected hostport") var executor: Executor = null - var driver: ActorRef = null + var driver: ActorSelection = null override def preStart() { logInfo("Connecting to driver: " + driverUrl) - driver = context.actorFor(driverUrl) + driver = context.actorSelection(driverUrl) driver ! RegisterExecutor(executorId, hostPort, cores) - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - context.watch(driver) // Doesn't work with remote actors, but useful for testing + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) } override def receive = { @@ -77,8 +75,8 @@ private[spark] class CoarseGrainedExecutorBackend( executor.killTask(taskId) } - case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => - logError("Driver terminated or disconnected! Shutting down.") + case x: DisassociatedEvent => + logError(s"Driver $x disassociated! Shutting down.") System.exit(1) case StopExecutor => @@ -99,12 +97,13 @@ private[spark] object CoarseGrainedExecutorBackend { // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // before getting started with all our system properties, etc - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0, + indestructible = true) // set it val sparkHostPort = hostname + ":" + boundPort System.setProperty("spark.hostPort", sparkHostPort) - val actor = actorSystem.actorOf( - Props(new CoarseGrainedExecutorBackend(driverUrl, executorId, sparkHostPort, cores)), + actorSystem.actorOf( + Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId, sparkHostPort, cores), name = "Executor") actorSystem.awaitTermination() } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 5c9bb9db1c..0b0a60ee60 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -121,7 +121,7 @@ private[spark] class Executor( // Akka's message frame size. If task result is bigger than this, we use the block manager // to send the result back. private val akkaFrameSize = { - env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size") + env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size") } // Start worker thread pool diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 0b4892f98f..c0ce46e379 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -61,50 +61,53 @@ object TaskMetrics { class ShuffleReadMetrics extends Serializable { /** - * Time when shuffle finishs + * Absolute time when this task finished reading shuffle data */ var shuffleFinishTime: Long = _ /** - * Total number of blocks fetched in a shuffle (remote or local) + * Number of blocks fetched in this shuffle by this task (remote or local) */ var totalBlocksFetched: Int = _ /** - * Number of remote blocks fetched in a shuffle + * Number of remote blocks fetched in this shuffle by this task */ var remoteBlocksFetched: Int = _ /** - * Local blocks fetched in a shuffle + * Number of local blocks fetched in this shuffle by this task */ var localBlocksFetched: Int = _ /** - * Total time that is spent blocked waiting for shuffle to fetch data + * Time the task spent waiting for remote shuffle blocks. This only includes the time + * blocking on shuffle input data. For instance if block B is being fetched while the task is + * still not finished processing block A, it is not considered to be blocking on block B. */ var fetchWaitTime: Long = _ /** - * The total amount of time for all the shuffle fetches. This adds up time from overlapping - * shuffles, so can be longer than task time + * Total time spent fetching remote shuffle blocks. This aggregates the time spent fetching all + * input blocks. Since block fetches are both pipelined and parallelized, this can + * exceed fetchWaitTime and executorRunTime. */ var remoteFetchTime: Long = _ /** - * Total number of remote bytes read from a shuffle + * Total number of remote bytes read from the shuffle by this task */ var remoteBytesRead: Long = _ } class ShuffleWriteMetrics extends Serializable { /** - * Number of bytes written for a shuffle + * Number of bytes written for the shuffle by this task */ var shuffleBytesWritten: Long = _ /** - * Time spent blocking on writes to disk or buffer cache, in nanoseconds. + * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds */ var shuffleWriteTime: Long = _ } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala new file mode 100644 index 0000000000..cdcfec8ca7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics.sink + +import java.util.Properties +import java.util.concurrent.TimeUnit +import java.net.InetSocketAddress + +import com.codahale.metrics.MetricRegistry +import com.codahale.metrics.graphite.{GraphiteReporter, Graphite} + +import org.apache.spark.metrics.MetricsSystem + +class GraphiteSink(val property: Properties, val registry: MetricRegistry) extends Sink { + val GRAPHITE_DEFAULT_PERIOD = 10 + val GRAPHITE_DEFAULT_UNIT = "SECONDS" + val GRAPHITE_DEFAULT_PREFIX = "" + + val GRAPHITE_KEY_HOST = "host" + val GRAPHITE_KEY_PORT = "port" + val GRAPHITE_KEY_PERIOD = "period" + val GRAPHITE_KEY_UNIT = "unit" + val GRAPHITE_KEY_PREFIX = "prefix" + + def propertyToOption(prop: String) = Option(property.getProperty(prop)) + + if (!propertyToOption(GRAPHITE_KEY_HOST).isDefined) { + throw new Exception("Graphite sink requires 'host' property.") + } + + if (!propertyToOption(GRAPHITE_KEY_PORT).isDefined) { + throw new Exception("Graphite sink requires 'port' property.") + } + + val host = propertyToOption(GRAPHITE_KEY_HOST).get + val port = propertyToOption(GRAPHITE_KEY_PORT).get.toInt + + val pollPeriod = propertyToOption(GRAPHITE_KEY_PERIOD) match { + case Some(s) => s.toInt + case None => GRAPHITE_DEFAULT_PERIOD + } + + val pollUnit = propertyToOption(GRAPHITE_KEY_UNIT) match { + case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case None => TimeUnit.valueOf(GRAPHITE_DEFAULT_UNIT) + } + + val prefix = propertyToOption(GRAPHITE_KEY_PREFIX).getOrElse(GRAPHITE_DEFAULT_PREFIX) + + MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) + + val graphite: Graphite = new Graphite(new InetSocketAddress(host, port)) + + val reporter: GraphiteReporter = GraphiteReporter.forRegistry(registry) + .convertDurationsTo(TimeUnit.MILLISECONDS) + .convertRatesTo(TimeUnit.SECONDS) + .prefixedWith(prefix) + .build(graphite) + + override def start() { + reporter.start(pollPeriod, pollUnit) + } + + override def stop() { + reporter.stop() + } +} diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 9c2fee4023..703bc6a9ca 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -31,11 +31,11 @@ import scala.collection.mutable.SynchronizedMap import scala.collection.mutable.SynchronizedQueue import scala.collection.mutable.ArrayBuffer -import akka.dispatch.{Await, Promise, ExecutionContext, Future} -import akka.util.Duration -import akka.util.duration._ -import org.apache.spark.util.Utils +import scala.concurrent.{Await, Promise, ExecutionContext, Future} +import scala.concurrent.duration.Duration +import scala.concurrent.duration._ +import org.apache.spark.util.Utils private[spark] class ConnectionManager(port: Int) extends Logging { diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala index 8d9ad9604d..4f5742d29b 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala @@ -25,8 +25,8 @@ import scala.io.Source import java.nio.ByteBuffer import java.net.InetAddress -import akka.dispatch.Await -import akka.util.duration._ +import scala.concurrent.Await +import scala.concurrent.duration._ private[spark] object ConnectionManagerTest extends Logging{ def main(args: Array[String]) { diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala index 481ff8c3e0..b1e1576dad 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala @@ -76,7 +76,7 @@ private[spark] object ShuffleCopier extends Logging { extends FileClientHandler with Logging { override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { - logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)"); + logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)") resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index faaf837be0..d1c74a5063 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionContext.Implicits.global +import scala.reflect.ClassTag import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} @@ -28,7 +29,7 @@ import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} * A set of asynchronous RDD actions available through an implicit conversion. * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions. */ -class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with Logging { +class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging { /** * Returns a future for counting the number of elements in the RDD. diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index 44ea573a7c..424354ae16 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.rdd +import scala.reflect.ClassTag + import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext} import org.apache.spark.storage.{BlockId, BlockManager} @@ -25,7 +27,7 @@ private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends P } private[spark] -class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[BlockId]) +class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[BlockId]) extends RDD[T](sc, Nil) { @transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get) diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala index 9b0c882481..87b950ba43 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -18,6 +18,7 @@ package org.apache.spark.rdd import java.io.{ObjectOutputStream, IOException} +import scala.reflect.ClassTag import org.apache.spark._ @@ -43,7 +44,7 @@ class CartesianPartition( } private[spark] -class CartesianRDD[T: ClassManifest, U:ClassManifest]( +class CartesianRDD[T: ClassTag, U: ClassTag]( sc: SparkContext, var rdd1 : RDD[T], var rdd2 : RDD[U]) @@ -70,7 +71,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( override def compute(split: Partition, context: TaskContext) = { val currSplit = split.asInstanceOf[CartesianPartition] for (x <- rdd1.iterator(currSplit.s1, context); - y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) + y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) } override def getDependencies: Seq[Dependency[_]] = List( diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index d3033ea4a6..a712ef1c27 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -17,15 +17,13 @@ package org.apache.spark.rdd +import java.io.IOException + +import scala.reflect.ClassTag + +import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter} -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.io.{NullWritable, BytesWritable} -import org.apache.hadoop.util.ReflectionUtils -import org.apache.hadoop.fs.Path -import java.io.{File, IOException, EOFException} -import java.text.NumberFormat private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} @@ -33,7 +31,7 @@ private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} * This RDD represents a RDD checkpoint file (similar to HadoopRDD). */ private[spark] -class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: String) +class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) extends RDD[T](sc, Nil) { @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index c5de6362a9..98da35763b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -22,6 +22,7 @@ import java.io.{ObjectOutputStream, IOException} import scala.collection.mutable import scala.Some import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag /** * Class that captures a coalesced RDD by essentially keeping track of parent partitions @@ -68,7 +69,7 @@ case class CoalescedRDDPartition( * @param maxPartitions number of desired partitions in the coalesced RDD * @param balanceSlack used to trade-off balance and locality. 1.0 is all locality, 0 is all balance */ -class CoalescedRDD[T: ClassManifest]( +class CoalescedRDD[T: ClassTag]( @transient var prev: RDD[T], maxPartitions: Int, balanceSlack: Double = 0.10) diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index a4bec41752..688c310ee9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -24,6 +24,8 @@ import org.apache.spark.partial.SumEvaluator import org.apache.spark.util.StatCounter import org.apache.spark.{TaskContext, Logging} +import scala.collection.immutable.NumericRange + /** * Extra functions available on RDDs of Doubles through an implicit conversion. * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions. @@ -76,4 +78,129 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { val evaluator = new SumEvaluator(self.partitions.size, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } + + /** + * Compute a histogram of the data using bucketCount number of buckets evenly + * spaced between the minimum and maximum of the RDD. For example if the min + * value is 0 and the max is 100 and there are two buckets the resulting + * buckets will be [0, 50) [50, 100]. bucketCount must be at least 1 + * If the RDD contains infinity, NaN throws an exception + * If the elements in RDD do not vary (max == min) always returns a single bucket. + */ + def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = { + // Compute the minimum and the maxium + val (max: Double, min: Double) = self.mapPartitions { items => + Iterator(items.foldRight(Double.NegativeInfinity, + Double.PositiveInfinity)((e: Double, x: Pair[Double, Double]) => + (x._1.max(e), x._2.min(e)))) + }.reduce { (maxmin1, maxmin2) => + (maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2)) + } + if (min.isNaN || max.isNaN || max.isInfinity || min.isInfinity ) { + throw new UnsupportedOperationException( + "Histogram on either an empty RDD or RDD containing +/-infinity or NaN") + } + val increment = (max-min)/bucketCount.toDouble + val range = if (increment != 0) { + Range.Double.inclusive(min, max, increment) + } else { + List(min, min) + } + val buckets = range.toArray + (buckets, histogram(buckets, true)) + } + + /** + * Compute a histogram using the provided buckets. The buckets are all open + * to the left except for the last which is closed + * e.g. for the array + * [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50] + * e.g 1<=x<10 , 10<=x<20, 20<=x<50 + * And on the input of 1 and 50 we would have a histogram of 1, 0, 0 + * + * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched + * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets + * to true. + * buckets must be sorted and not contain any duplicates. + * buckets array must be at least two elements + * All NaN entries are treated the same. If you have a NaN bucket it must be + * the maximum value of the last position and all NaN entries will be counted + * in that bucket. + */ + def histogram(buckets: Array[Double], evenBuckets: Boolean = false): Array[Long] = { + if (buckets.length < 2) { + throw new IllegalArgumentException("buckets array must have at least two elements") + } + // The histogramPartition function computes the partail histogram for a given + // partition. The provided bucketFunction determines which bucket in the array + // to increment or returns None if there is no bucket. This is done so we can + // specialize for uniformly distributed buckets and save the O(log n) binary + // search cost. + def histogramPartition(bucketFunction: (Double) => Option[Int])(iter: Iterator[Double]): + Iterator[Array[Long]] = { + val counters = new Array[Long](buckets.length - 1) + while (iter.hasNext) { + bucketFunction(iter.next()) match { + case Some(x: Int) => {counters(x) += 1} + case _ => {} + } + } + Iterator(counters) + } + // Merge the counters. + def mergeCounters(a1: Array[Long], a2: Array[Long]): Array[Long] = { + a1.indices.foreach(i => a1(i) += a2(i)) + a1 + } + // Basic bucket function. This works using Java's built in Array + // binary search. Takes log(size(buckets)) + def basicBucketFunction(e: Double): Option[Int] = { + val location = java.util.Arrays.binarySearch(buckets, e) + if (location < 0) { + // If the location is less than 0 then the insertion point in the array + // to keep it sorted is -location-1 + val insertionPoint = -location-1 + // If we have to insert before the first element or after the last one + // its out of bounds. + // We do this rather than buckets.lengthCompare(insertionPoint) + // because Array[Double] fails to override it (for now). + if (insertionPoint > 0 && insertionPoint < buckets.length) { + Some(insertionPoint-1) + } else { + None + } + } else if (location < buckets.length - 1) { + // Exact match, just insert here + Some(location) + } else { + // Exact match to the last element + Some(location - 1) + } + } + // Determine the bucket function in constant time. Requires that buckets are evenly spaced + def fastBucketFunction(min: Double, increment: Double, count: Int)(e: Double): Option[Int] = { + // If our input is not a number unless the increment is also NaN then we fail fast + if (e.isNaN()) { + return None + } + val bucketNumber = (e - min)/(increment) + // We do this rather than buckets.lengthCompare(bucketNumber) + // because Array[Double] fails to override it (for now). + if (bucketNumber > count || bucketNumber < 0) { + None + } else { + Some(bucketNumber.toInt.min(count - 1)) + } + } + // Decide which bucket function to pass to histogramPartition. We decide here + // rather than having a general function so that the decission need only be made + // once rather than once per shard + val bucketFunction = if (evenBuckets) { + fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _ + } else { + basicBucketFunction _ + } + self.mapPartitions(histogramPartition(bucketFunction)).reduce(mergeCounters) + } + } diff --git a/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala b/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala index c8900d1a93..a84e5f9fd8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala @@ -17,13 +17,14 @@ package org.apache.spark.rdd -import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext} +import scala.reflect.ClassTag +import org.apache.spark.{Partition, SparkContext, TaskContext} /** * An RDD that is empty, i.e. has no element in it. */ -class EmptyRDD[T: ClassManifest](sc: SparkContext) extends RDD[T](sc, Nil) { +class EmptyRDD[T: ClassTag](sc: SparkContext) extends RDD[T](sc, Nil) { override def getPartitions: Array[Partition] = Array.empty diff --git a/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala b/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala index 5312dc0b59..e74c83b90b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala @@ -18,8 +18,9 @@ package org.apache.spark.rdd import org.apache.spark.{OneToOneDependency, Partition, TaskContext} +import scala.reflect.ClassTag -private[spark] class FilteredRDD[T: ClassManifest]( +private[spark] class FilteredRDD[T: ClassTag]( prev: RDD[T], f: T => Boolean) extends RDD[T](prev) { diff --git a/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala index cbdf6d84c0..4d1878fc14 100644 --- a/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala @@ -18,10 +18,11 @@ package org.apache.spark.rdd import org.apache.spark.{Partition, TaskContext} +import scala.reflect.ClassTag private[spark] -class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( +class FlatMappedRDD[U: ClassTag, T: ClassTag]( prev: RDD[T], f: T => TraversableOnce[U]) extends RDD[U](prev) { diff --git a/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala index 829545d7b0..1a694475f6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala @@ -18,8 +18,9 @@ package org.apache.spark.rdd import org.apache.spark.{Partition, TaskContext} +import scala.reflect.ClassTag -private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T]) +private[spark] class GlommedRDD[T: ClassTag](prev: RDD[T]) extends RDD[Array[T]](prev) { override def getPartitions: Array[Partition] = firstParent[T].partitions diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index aca0146884..8df8718f3b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -19,6 +19,8 @@ package org.apache.spark.rdd import java.sql.{Connection, ResultSet} +import scala.reflect.ClassTag + import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} import org.apache.spark.util.NextIterator @@ -45,7 +47,7 @@ private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) e * This should only call getInt, getString, etc; the RDD takes care of calling next. * The default maps a ResultSet to an array of Object. */ -class JdbcRDD[T: ClassManifest]( +class JdbcRDD[T: ClassTag]( sc: SparkContext, getConnection: () => Connection, sql: String, diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index 203179c4ea..db15baf503 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -18,20 +18,18 @@ package org.apache.spark.rdd import org.apache.spark.{Partition, TaskContext} +import scala.reflect.ClassTag - -private[spark] -class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( +private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( prev: RDD[T], - f: Iterator[T] => Iterator[U], + f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator) preservesPartitioning: Boolean = false) extends RDD[U](prev) { - override val partitioner = - if (preservesPartitioning) firstParent[T].partitioner else None + override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None override def getPartitions: Array[Partition] = firstParent[T].partitions override def compute(split: Partition, context: TaskContext) = - f(firstParent[T].iterator(split, context)) + f(context, split.index, firstParent[T].iterator(split, context)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala index e8be1c4816..8d7c288593 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala @@ -17,10 +17,12 @@ package org.apache.spark.rdd +import scala.reflect.ClassTag + import org.apache.spark.{Partition, TaskContext} private[spark] -class MappedRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: T => U) +class MappedRDD[U: ClassTag, T: ClassTag](prev: RDD[T], f: T => U) extends RDD[U](prev) { override def getPartitions: Array[Partition] = firstParent[T].partitions diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index 697be8b997..d5691f2267 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -17,7 +17,9 @@ package org.apache.spark.rdd -import org.apache.spark.{RangePartitioner, Logging} +import scala.reflect.ClassTag + +import org.apache.spark.{Logging, RangePartitioner} /** * Extra functions available on RDDs of (key, value) pairs where the key is sortable through @@ -25,9 +27,9 @@ import org.apache.spark.{RangePartitioner, Logging} * use these functions. They will work with any key type that has a `scala.math.Ordered` * implementation. */ -class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, - V: ClassManifest, - P <: Product2[K, V] : ClassManifest]( +class OrderedRDDFunctions[K <% Ordered[K]: ClassTag, + V: ClassTag, + P <: Product2[K, V] : ClassTag]( self: RDD[P]) extends Logging with Serializable { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 93b78e1232..48168e152e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -25,6 +25,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.{mutable, Map} import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ +import scala.reflect.{ClassTag, classTag} import org.apache.hadoop.mapred._ import org.apache.hadoop.io.compress.CompressionCodec @@ -50,7 +51,7 @@ import org.apache.spark.Partitioner.defaultPartitioner * Extra functions available on RDDs of (key, value) pairs through an implicit conversion. * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions. */ -class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) +class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) extends Logging with SparkHadoopMapReduceUtil with Serializable { @@ -415,7 +416,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) throw new SparkException("Default partitioner cannot partition array keys.") } val cg = new CoGroupedRDD[K](Seq(self, other), partitioner) - val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest) + val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classTag[K], ClassTags.seqSeqClassTag) prfs.mapValues { case Seq(vs, ws) => (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]]) } @@ -431,7 +432,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) throw new SparkException("Default partitioner cannot partition array keys.") } val cg = new CoGroupedRDD[K](Seq(self, other1, other2), partitioner) - val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest) + val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classTag[K], ClassTags.seqSeqClassTag) prfs.mapValues { case Seq(vs, w1s, w2s) => (vs.asInstanceOf[Seq[V]], w1s.asInstanceOf[Seq[W1]], w2s.asInstanceOf[Seq[W2]]) } @@ -488,15 +489,15 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting * RDD will be <= us. */ - def subtractByKey[W: ClassManifest](other: RDD[(K, W)]): RDD[(K, V)] = + def subtractByKey[W: ClassTag](other: RDD[(K, W)]): RDD[(K, V)] = subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size))) /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ - def subtractByKey[W: ClassManifest](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] = + def subtractByKey[W: ClassTag](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] = subtractByKey(other, new HashPartitioner(numPartitions)) /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ - def subtractByKey[W: ClassManifest](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] = + def subtractByKey[W: ClassTag](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] = new SubtractedRDD[K, V, W](self, other, p) /** @@ -525,8 +526,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class * supporting the key and value types K and V in this RDD. */ - def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) { - saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) + def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassTag[F]) { + saveAsHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]]) } /** @@ -535,16 +536,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) * supplied codec. */ def saveAsHadoopFile[F <: OutputFormat[K, V]]( - path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassManifest[F]) { - saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]], codec) + path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassTag[F]) { + saveAsHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]], codec) } /** * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. */ - def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) { - saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) + def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassTag[F]) { + saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]]) } /** @@ -698,11 +699,11 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) */ def values: RDD[V] = self.map(_._2) - private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure + private[spark] def getKeyClass() = implicitly[ClassTag[K]].runtimeClass - private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure + private[spark] def getValueClass() = implicitly[ClassTag[V]].runtimeClass } -private[spark] object Manifests { - val seqSeqManifest = classManifest[Seq[Seq[_]]] +private[spark] object ClassTags { + val seqSeqClassTag = classTag[Seq[Seq[_]]] } diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index cd96250389..09d0a8189d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -20,13 +20,15 @@ package org.apache.spark.rdd import scala.collection.immutable.NumericRange import scala.collection.mutable.ArrayBuffer import scala.collection.Map +import scala.reflect.ClassTag + import org.apache.spark._ import java.io._ import scala.Serializable import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.Utils -private[spark] class ParallelCollectionPartition[T: ClassManifest]( +private[spark] class ParallelCollectionPartition[T: ClassTag]( var rddId: Long, var slice: Int, var values: Seq[T]) @@ -78,7 +80,7 @@ private[spark] class ParallelCollectionPartition[T: ClassManifest]( } } -private[spark] class ParallelCollectionRDD[T: ClassManifest]( +private[spark] class ParallelCollectionRDD[T: ClassTag]( @transient sc: SparkContext, @transient data: Seq[T], numSlices: Int, @@ -109,7 +111,7 @@ private object ParallelCollectionRDD { * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes * it efficient to run Spark over RDDs representing large sets of numbers. */ - def slice[T: ClassManifest](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = { + def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = { if (numSlices < 1) { throw new IllegalArgumentException("Positive number of slices required") } diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index 165cd412fc..ea8885b36e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.rdd +import scala.reflect.ClassTag + import org.apache.spark.{NarrowDependency, SparkEnv, Partition, TaskContext} @@ -33,11 +35,13 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo extends NarrowDependency[T](rdd) { @transient - val partitions: Array[Partition] = rdd.partitions.zipWithIndex - .filter(s => partitionFilterFunc(s._2)) + val partitions: Array[Partition] = rdd.partitions + .filter(s => partitionFilterFunc(s.index)).zipWithIndex .map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition } - override def getParents(partitionId: Int) = List(partitions(partitionId).index) + override def getParents(partitionId: Int) = { + List(partitions(partitionId).asInstanceOf[PartitionPruningRDDPartition].parentSplit.index) + } } @@ -47,7 +51,7 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo * and the execution DAG has a filter on the key, we can avoid launching tasks * on partitions that don't have the range covering the key. */ -class PartitionPruningRDD[T: ClassManifest]( +class PartitionPruningRDD[T: ClassTag]( @transient prev: RDD[T], @transient partitionFilterFunc: Int => Boolean) extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { @@ -67,6 +71,6 @@ object PartitionPruningRDD { * when its type T is not known at compile time. */ def create[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean) = { - new PartitionPruningRDD[T](rdd, partitionFilterFunc)(rdd.elementClassManifest) + new PartitionPruningRDD[T](rdd, partitionFilterFunc)(rdd.elementClassTag) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index d5304ab0ae..1dbbe39898 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -24,6 +24,7 @@ import scala.collection.Map import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import scala.io.Source +import scala.reflect.ClassTag import org.apache.spark.{SparkEnv, Partition, TaskContext} import org.apache.spark.broadcast.Broadcast @@ -33,7 +34,7 @@ import org.apache.spark.broadcast.Broadcast * An RDD that pipes the contents of each parent partition through an external command * (printing them one per line) and returns the output as a collection of strings. */ -class PipedRDD[T: ClassManifest]( +class PipedRDD[T: ClassTag]( prev: RDD[T], command: Seq[String], envVars: Map[String, String], diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 6e88be6f6a..ea45566ad1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -23,6 +23,9 @@ import scala.collection.Map import scala.collection.JavaConversions.mapAsScalaMap import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.reflect.{classTag, ClassTag} + import org.apache.hadoop.io.BytesWritable import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.io.NullWritable @@ -69,7 +72,7 @@ import org.apache.spark._ * [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details * on RDD internals. */ -abstract class RDD[T: ClassManifest]( +abstract class RDD[T: ClassTag]( @transient private var sc: SparkContext, @transient private var deps: Seq[Dependency[_]] ) extends Serializable with Logging { @@ -101,7 +104,7 @@ abstract class RDD[T: ClassManifest]( protected def getPreferredLocations(split: Partition): Seq[String] = Nil /** Optionally overridden by subclasses to specify how they are partitioned. */ - val partitioner: Option[Partitioner] = None + @transient val partitioner: Option[Partitioner] = None // ======================================================================= // Methods and fields available on all RDDs @@ -114,7 +117,7 @@ abstract class RDD[T: ClassManifest]( val id: Int = sc.newRddId() /** A friendly name for this RDD */ - var name: String = null + @transient var name: String = null /** Assign a name to this RDD */ def setName(_name: String) = { @@ -123,7 +126,7 @@ abstract class RDD[T: ClassManifest]( } /** User-defined generator of this RDD*/ - var generator = Utils.getCallSiteInfo.firstUserClass + @transient var generator = Utils.getCallSiteInfo.firstUserClass /** Reset generator*/ def setGenerator(_generator: String) = { @@ -243,13 +246,13 @@ abstract class RDD[T: ClassManifest]( /** * Return a new RDD by applying a function to all elements of this RDD. */ - def map[U: ClassManifest](f: T => U): RDD[U] = new MappedRDD(this, sc.clean(f)) + def map[U: ClassTag](f: T => U): RDD[U] = new MappedRDD(this, sc.clean(f)) /** * Return a new RDD by first applying a function to all elements of this * RDD, and then flattening the results. */ - def flatMap[U: ClassManifest](f: T => TraversableOnce[U]): RDD[U] = + def flatMap[U: ClassTag](f: T => TraversableOnce[U]): RDD[U] = new FlatMappedRDD(this, sc.clean(f)) /** @@ -374,25 +377,25 @@ abstract class RDD[T: ClassManifest]( * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of * elements (a, b) where a is in `this` and b is in `other`. */ - def cartesian[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other) + def cartesian[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other) /** * Return an RDD of grouped items. */ - def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] = + def groupBy[K: ClassTag](f: T => K): RDD[(K, Seq[T])] = groupBy[K](f, defaultPartitioner(this)) /** * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements * mapping to that key. */ - def groupBy[K: ClassManifest](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] = + def groupBy[K: ClassTag](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] = groupBy(f, new HashPartitioner(numPartitions)) /** * Return an RDD of grouped items. */ - def groupBy[K: ClassManifest](f: T => K, p: Partitioner): RDD[(K, Seq[T])] = { + def groupBy[K: ClassTag](f: T => K, p: Partitioner): RDD[(K, Seq[T])] = { val cleanF = sc.clean(f) this.map(t => (cleanF(t), t)).groupByKey(p) } @@ -408,7 +411,6 @@ abstract class RDD[T: ClassManifest]( def pipe(command: String, env: Map[String, String]): RDD[String] = new PipedRDD(this, command, env) - /** * Return an RDD created by piping elements to a forked external process. * The print behavior can be customized by providing two functions. @@ -440,29 +442,31 @@ abstract class RDD[T: ClassManifest]( /** * Return a new RDD by applying a function to each partition of this RDD. */ - def mapPartitions[U: ClassManifest]( + def mapPartitions[U: ClassTag]( f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { - new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning) + val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter) + new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) } /** * Return a new RDD by applying a function to each partition of this RDD, while tracking the index * of the original partition. */ - def mapPartitionsWithIndex[U: ClassManifest]( + def mapPartitionsWithIndex[U: ClassTag]( f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { - val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter) - new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning) + val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter) + new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) } /** * Return a new RDD by applying a function to each partition of this RDD. This is a variant of * mapPartitions that also passes the TaskContext into the closure. */ - def mapPartitionsWithContext[U: ClassManifest]( + def mapPartitionsWithContext[U: ClassTag]( f: (TaskContext, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { - new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning) + val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(context, iter) + new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) } /** @@ -470,7 +474,7 @@ abstract class RDD[T: ClassManifest]( * of the original partition. */ @deprecated("use mapPartitionsWithIndex", "0.7.0") - def mapPartitionsWithSplit[U: ClassManifest]( + def mapPartitionsWithSplit[U: ClassTag]( f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { mapPartitionsWithIndex(f, preservesPartitioning) } @@ -480,14 +484,13 @@ abstract class RDD[T: ClassManifest]( * additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def mapWith[A: ClassManifest, U: ClassManifest] + def mapWith[A: ClassTag, U: ClassTag] (constructA: Int => A, preservesPartitioning: Boolean = false) (f: (T, A) => U): RDD[U] = { - def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = { - val a = constructA(context.partitionId) + mapPartitionsWithIndex((index, iter) => { + val a = constructA(index) iter.map(t => f(t, a)) - } - new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning) + }, preservesPartitioning) } /** @@ -495,14 +498,13 @@ abstract class RDD[T: ClassManifest]( * additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def flatMapWith[A: ClassManifest, U: ClassManifest] + def flatMapWith[A: ClassTag, U: ClassTag] (constructA: Int => A, preservesPartitioning: Boolean = false) (f: (T, A) => Seq[U]): RDD[U] = { - def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = { - val a = constructA(context.partitionId) + mapPartitionsWithIndex((index, iter) => { + val a = constructA(index) iter.flatMap(t => f(t, a)) - } - new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning) + }, preservesPartitioning) } /** @@ -510,12 +512,11 @@ abstract class RDD[T: ClassManifest]( * This additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def foreachWith[A: ClassManifest](constructA: Int => A)(f: (T, A) => Unit) { - def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = { - val a = constructA(context.partitionId) + def foreachWith[A: ClassTag](constructA: Int => A)(f: (T, A) => Unit) { + mapPartitionsWithIndex { (index, iter) => + val a = constructA(index) iter.map(t => {f(t, a); t}) - } - new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {}) + }.foreach(_ => {}) } /** @@ -523,12 +524,11 @@ abstract class RDD[T: ClassManifest]( * additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def filterWith[A: ClassManifest](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = { - def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = { - val a = constructA(context.partitionId) + def filterWith[A: ClassTag](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = { + mapPartitionsWithIndex((index, iter) => { + val a = constructA(index) iter.filter(t => p(t, a)) - } - new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true) + }, preservesPartitioning = true) } /** @@ -537,7 +537,7 @@ abstract class RDD[T: ClassManifest]( * partitions* and the *same number of elements in each partition* (e.g. one was made through * a map on the other). */ - def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other) + def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other) /** * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by @@ -545,20 +545,30 @@ abstract class RDD[T: ClassManifest]( * *same number of partitions*, but does *not* require them to have the same number * of elements in each partition. */ - def zipPartitions[B: ClassManifest, V: ClassManifest] + def zipPartitions[B: ClassTag, V: ClassTag] (rdd2: RDD[B]) (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2) + new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2, false) - def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest] + def zipPartitions[B: ClassTag, C: ClassTag, V: ClassTag] + (rdd2: RDD[B], rdd3: RDD[C], preservesPartitioning: Boolean) + (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] = + new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3, preservesPartitioning) + + def zipPartitions[B: ClassTag, C: ClassTag, V: ClassTag] (rdd2: RDD[B], rdd3: RDD[C]) (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3) + new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3, false) - def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest] + def zipPartitions[B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag] + (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D], preservesPartitioning: Boolean) + (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] = + new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4, preservesPartitioning) + + def zipPartitions[B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag] (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D]) (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4) + new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4, false) // Actions (launch a job to return a value to the user program) @@ -593,7 +603,7 @@ abstract class RDD[T: ClassManifest]( /** * Return an RDD that contains all matching values by applying `f`. */ - def collect[U: ClassManifest](f: PartialFunction[T, U]): RDD[U] = { + def collect[U: ClassTag](f: PartialFunction[T, U]): RDD[U] = { filter(f.isDefinedAt).map(f) } @@ -683,7 +693,7 @@ abstract class RDD[T: ClassManifest]( * allowed to modify and return their first argument instead of creating a new U to avoid memory * allocation. */ - def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = { + def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = { // Clone the zero value since we will also be serializing it as part of tasks var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) val cleanSeqOp = sc.clean(seqOp) @@ -732,7 +742,7 @@ abstract class RDD[T: ClassManifest]( * combine step happens locally on the master, equivalent to running a single reduce task. */ def countByValue(): Map[T, Long] = { - if (elementClassManifest.erasure.isArray) { + if (elementClassTag.runtimeClass.isArray) { throw new SparkException("countByValue() does not support arrays") } // TODO: This should perhaps be distributed by default. @@ -763,7 +773,7 @@ abstract class RDD[T: ClassManifest]( timeout: Long, confidence: Double = 0.95 ): PartialResult[Map[T, BoundedDouble]] = { - if (elementClassManifest.erasure.isArray) { + if (elementClassTag.runtimeClass.isArray) { throw new SparkException("countByValueApprox() does not support arrays") } val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) => @@ -928,14 +938,14 @@ abstract class RDD[T: ClassManifest]( private var storageLevel: StorageLevel = StorageLevel.NONE /** Record user function generating this RDD. */ - private[spark] val origin = Utils.formatSparkCallSite + @transient private[spark] val origin = Utils.formatSparkCallSite - private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] + private[spark] def elementClassTag: ClassTag[T] = classTag[T] private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None /** Returns the first parent RDD */ - protected[spark] def firstParent[U: ClassManifest] = { + protected[spark] def firstParent[U: ClassTag] = { dependencies.head.rdd.asInstanceOf[RDD[U]] } @@ -943,7 +953,7 @@ abstract class RDD[T: ClassManifest]( def context = sc // Avoid handling doCheckpoint multiple times to prevent excessive recursion - private var doCheckpointCalled = false + @transient private var doCheckpointCalled = false /** * Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler @@ -997,7 +1007,7 @@ abstract class RDD[T: ClassManifest]( origin) def toJavaRDD() : JavaRDD[T] = { - new JavaRDD(this)(elementClassManifest) + new JavaRDD(this)(elementClassTag) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 6009a41570..3b56e45aa9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -17,6 +17,8 @@ package org.apache.spark.rdd +import scala.reflect.ClassTag + import org.apache.hadoop.fs.Path import org.apache.hadoop.conf.Configuration @@ -38,7 +40,7 @@ private[spark] object CheckpointState extends Enumeration { * manages the post-checkpoint state by providing the updated partitions, iterator and preferred locations * of the checkpointed RDD. */ -private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) +private[spark] class RDDCheckpointData[T: ClassTag](rdd: RDD[T]) extends Logging with Serializable { import CheckpointState._ diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala index 2c5253ae30..d433670cc2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala @@ -17,6 +17,7 @@ package org.apache.spark.rdd +import scala.reflect.ClassTag import java.util.Random import cern.jet.random.Poisson @@ -29,9 +30,9 @@ class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition override val index: Int = prev.index } -class SampledRDD[T: ClassManifest]( +class SampledRDD[T: ClassTag]( prev: RDD[T], - withReplacement: Boolean, + withReplacement: Boolean, frac: Double, seed: Int) extends RDD[T](prev) { diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 5fe4676029..2d1bd5b481 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -14,9 +14,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.rdd +import scala.reflect.{ ClassTag, classTag} + import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapred.SequenceFileOutputFormat import org.apache.hadoop.io.compress.CompressionCodec @@ -32,15 +33,15 @@ import org.apache.spark.Logging * * Import `org.apache.spark.SparkContext._` at the top of their program to use these functions. */ -class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : ClassManifest]( +class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag]( self: RDD[(K, V)]) extends Logging with Serializable { - private def getWritableClass[T <% Writable: ClassManifest](): Class[_ <: Writable] = { + private def getWritableClass[T <% Writable: ClassTag](): Class[_ <: Writable] = { val c = { - if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) { - classManifest[T].erasure + if (classOf[Writable].isAssignableFrom(classTag[T].runtimeClass)) { + classTag[T].runtimeClass } else { // We get the type of the Writable class by looking at the apply method which converts // from T to Writable. Since we have two apply methods we filter out the one which diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index a5d751a7bd..3682c84598 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -17,8 +17,10 @@ package org.apache.spark.rdd -import org.apache.spark.{Dependency, Partitioner, SparkEnv, ShuffleDependency, Partition, TaskContext} +import scala.reflect.ClassTag +import org.apache.spark.{Dependency, Partition, Partitioner, ShuffleDependency, + SparkEnv, TaskContext} private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { override val index = idx @@ -32,7 +34,7 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { * @tparam K the key class. * @tparam V the value class. */ -class ShuffledRDD[K, V, P <: Product2[K, V] : ClassManifest]( +class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag]( @transient var prev: RDD[P], part: Partitioner) extends RDD[P](prev.context, Nil) { diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index 7af4d803e7..aab30b1bb4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -18,8 +18,11 @@ package org.apache.spark.rdd import java.util.{HashMap => JHashMap} + import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + import org.apache.spark.Partitioner import org.apache.spark.Dependency import org.apache.spark.TaskContext @@ -45,7 +48,7 @@ import org.apache.spark.OneToOneDependency * you can use `rdd1`'s partitioner/partition size and not worry about running * out of memory because of the size of `rdd2`. */ -private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest]( +private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( @transient var rdd1: RDD[_ <: Product2[K, V]], @transient var rdd2: RDD[_ <: Product2[K, W]], part: Partitioner) diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index ae8a9f36a6..08a41ac558 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -18,10 +18,13 @@ package org.apache.spark.rdd import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + import org.apache.spark.{Dependency, RangeDependency, SparkContext, Partition, TaskContext} + import java.io.{ObjectOutputStream, IOException} -private[spark] class UnionPartition[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int) +private[spark] class UnionPartition[T: ClassTag](idx: Int, rdd: RDD[T], splitIndex: Int) extends Partition { var split: Partition = rdd.partitions(splitIndex) @@ -40,7 +43,7 @@ private[spark] class UnionPartition[T: ClassManifest](idx: Int, rdd: RDD[T], spl } } -class UnionRDD[T: ClassManifest]( +class UnionRDD[T: ClassTag]( sc: SparkContext, @transient var rdds: Seq[RDD[T]]) extends RDD[T](sc, Nil) { // Nil since we implement getDependencies diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index 31e6fd519d..83be3c6eb4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -19,10 +19,12 @@ package org.apache.spark.rdd import org.apache.spark.{OneToOneDependency, SparkContext, Partition, TaskContext} import java.io.{ObjectOutputStream, IOException} +import scala.reflect.ClassTag private[spark] class ZippedPartitionsPartition( idx: Int, - @transient rdds: Seq[RDD[_]]) + @transient rdds: Seq[RDD[_]], + @transient val preferredLocations: Seq[String]) extends Partition { override val index: Int = idx @@ -37,33 +39,31 @@ private[spark] class ZippedPartitionsPartition( } } -abstract class ZippedPartitionsBaseRDD[V: ClassManifest]( +abstract class ZippedPartitionsBaseRDD[V: ClassTag]( sc: SparkContext, - var rdds: Seq[RDD[_]]) + var rdds: Seq[RDD[_]], + preservesPartitioning: Boolean = false) extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) { + override val partitioner = + if (preservesPartitioning) firstParent[Any].partitioner else None + override def getPartitions: Array[Partition] = { - val sizes = rdds.map(x => x.partitions.size) - if (!sizes.forall(x => x == sizes(0))) { + val numParts = rdds.head.partitions.size + if (!rdds.forall(rdd => rdd.partitions.size == numParts)) { throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") } - val array = new Array[Partition](sizes(0)) - for (i <- 0 until sizes(0)) { - array(i) = new ZippedPartitionsPartition(i, rdds) + Array.tabulate[Partition](numParts) { i => + val prefs = rdds.map(rdd => rdd.preferredLocations(rdd.partitions(i))) + // Check whether there are any hosts that match all RDDs; otherwise return the union + val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y)) + val locs = if (!exactMatchLocations.isEmpty) exactMatchLocations else prefs.flatten.distinct + new ZippedPartitionsPartition(i, rdds, locs) } - array } override def getPreferredLocations(s: Partition): Seq[String] = { - val parts = s.asInstanceOf[ZippedPartitionsPartition].partitions - val prefs = rdds.zip(parts).map { case (rdd, p) => rdd.preferredLocations(p) } - // Check whether there are any hosts that match all RDDs; otherwise return the union - val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y)) - if (!exactMatchLocations.isEmpty) { - exactMatchLocations - } else { - prefs.flatten.distinct - } + s.asInstanceOf[ZippedPartitionsPartition].preferredLocations } override def clearDependencies() { @@ -72,12 +72,13 @@ abstract class ZippedPartitionsBaseRDD[V: ClassManifest]( } } -class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest]( +class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag]( sc: SparkContext, f: (Iterator[A], Iterator[B]) => Iterator[V], var rdd1: RDD[A], - var rdd2: RDD[B]) - extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) { + var rdd2: RDD[B], + preservesPartitioning: Boolean = false) + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions @@ -92,13 +93,14 @@ class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest] } class ZippedPartitionsRDD3 - [A: ClassManifest, B: ClassManifest, C: ClassManifest, V: ClassManifest]( + [A: ClassTag, B: ClassTag, C: ClassTag, V: ClassTag]( sc: SparkContext, f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V], var rdd1: RDD[A], var rdd2: RDD[B], - var rdd3: RDD[C]) - extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) { + var rdd3: RDD[C], + preservesPartitioning: Boolean = false) + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions @@ -116,14 +118,15 @@ class ZippedPartitionsRDD3 } class ZippedPartitionsRDD4 - [A: ClassManifest, B: ClassManifest, C: ClassManifest, D:ClassManifest, V: ClassManifest]( + [A: ClassTag, B: ClassTag, C: ClassTag, D:ClassTag, V: ClassTag]( sc: SparkContext, f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], var rdd1: RDD[A], var rdd2: RDD[B], var rdd3: RDD[C], - var rdd4: RDD[D]) - extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) { + var rdd4: RDD[D], + preservesPartitioning: Boolean = false) + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala index 567b67dfee..fb5b070c18 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala @@ -18,10 +18,12 @@ package org.apache.spark.rdd import org.apache.spark.{OneToOneDependency, SparkContext, Partition, TaskContext} + import java.io.{ObjectOutputStream, IOException} +import scala.reflect.ClassTag -private[spark] class ZippedPartition[T: ClassManifest, U: ClassManifest]( +private[spark] class ZippedPartition[T: ClassTag, U: ClassTag]( idx: Int, @transient rdd1: RDD[T], @transient rdd2: RDD[U] @@ -42,7 +44,7 @@ private[spark] class ZippedPartition[T: ClassManifest, U: ClassManifest]( } } -class ZippedRDD[T: ClassManifest, U: ClassManifest]( +class ZippedRDD[T: ClassTag, U: ClassTag]( sc: SparkContext, var rdd1: RDD[T], var rdd2: RDD[U]) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 42bb3884c8..963d15b76d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -21,9 +21,11 @@ import java.io.NotSerializableException import java.util.Properties import java.util.concurrent.atomic.AtomicInteger -import akka.actor._ -import akka.util.duration._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} +import scala.concurrent.duration._ +import scala.reflect.ClassTag + +import akka.actor._ import org.apache.spark._ import org.apache.spark.rdd.RDD @@ -104,36 +106,16 @@ class DAGScheduler( // The time, in millis, to wait for fetch failure events to stop coming in after one is detected; // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one // as more failure events come in - val RESUBMIT_TIMEOUT = 50L + val RESUBMIT_TIMEOUT = 50.milliseconds // The time, in millis, to wake up between polls of the completion queue in order to potentially // resubmit failed stages val POLL_TIMEOUT = 10L - private val eventProcessActor: ActorRef = env.actorSystem.actorOf(Props(new Actor { - override def preStart() { - context.system.scheduler.schedule(RESUBMIT_TIMEOUT milliseconds, RESUBMIT_TIMEOUT milliseconds) { - if (failed.size > 0) { - resubmitFailedStages() - } - } - } + // Warns the user if a stage contains a task with size greater than this value (in KB) + val TASK_SIZE_TO_WARN = 100 - /** - * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure - * events and responds by launching tasks. This runs in a dedicated thread and receives events - * via the eventQueue. - */ - def receive = { - case event: DAGSchedulerEvent => - logDebug("Got event of type " + event.getClass.getName) - - if (!processEvent(event)) - submitWaitingStages() - else - context.stop(self) - } - })) + private var eventProcessActor: ActorRef = _ private[scheduler] val nextJobId = new AtomicInteger(0) @@ -141,9 +123,13 @@ class DAGScheduler( private val nextStageId = new AtomicInteger(0) - private val stageIdToStage = new TimeStampedHashMap[Int, Stage] + private[scheduler] val jobIdToStageIds = new TimeStampedHashMap[Int, HashSet[Int]] - private val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] + private[scheduler] val stageIdToJobIds = new TimeStampedHashMap[Int, HashSet[Int]] + + private[scheduler] val stageIdToStage = new TimeStampedHashMap[Int, Stage] + + private[scheduler] val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo] @@ -174,6 +160,57 @@ class DAGScheduler( val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup) + /** + * Starts the event processing actor. The actor has two responsibilities: + * + * 1. Waits for events like job submission, task finished, task failure etc., and calls + * [[org.apache.spark.scheduler.DAGScheduler.processEvent()]] to process them. + * 2. Schedules a periodical task to resubmit failed stages. + * + * NOTE: the actor cannot be started in the constructor, because the periodical task references + * some internal states of the enclosing [[org.apache.spark.scheduler.DAGScheduler]] object, thus + * cannot be scheduled until the [[org.apache.spark.scheduler.DAGScheduler]] is fully constructed. + */ + def start() { + eventProcessActor = env.actorSystem.actorOf(Props(new Actor { + /** + * A handle to the periodical task, used to cancel the task when the actor is stopped. + */ + var resubmissionTask: Cancellable = _ + + override def preStart() { + import context.dispatcher + /** + * A message is sent to the actor itself periodically to remind the actor to resubmit failed + * stages. In this way, stage resubmission can be done within the same thread context of + * other event processing logic to avoid unnecessary synchronization overhead. + */ + resubmissionTask = context.system.scheduler.schedule( + RESUBMIT_TIMEOUT, RESUBMIT_TIMEOUT, self, ResubmitFailedStages) + } + + /** + * The main event loop of the DAG scheduler. + */ + def receive = { + case event: DAGSchedulerEvent => + logDebug("Got event of type " + event.getClass.getName) + + /** + * All events are forwarded to `processEvent()`, so that the event processing logic can + * easily tested without starting a dedicated actor. Please refer to `DAGSchedulerSuite` + * for details. + */ + if (!processEvent(event)) { + submitWaitingStages() + } else { + resubmissionTask.cancel() + context.stop(self) + } + } + })) + } + def addSparkListener(listener: SparkListener) { listenerBus.addListener(listener) } @@ -202,16 +239,16 @@ class DAGScheduler( shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => - val stage = newStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, Some(shuffleDep), jobId) + val stage = newOrUsedStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId) shuffleToMapStage(shuffleDep.shuffleId) = stage stage } } /** - * Create a Stage for the given RDD, either as a shuffle map stage (for a ShuffleDependency) or - * as a result stage for the final RDD used directly in an action. The stage will also be - * associated with the provided jobId. + * Create a Stage -- either directly for use as a result stage, or as part of the (re)-creation + * of a shuffle map stage in newOrUsedStage. The stage will be associated with the provided + * jobId. Production of shuffle map stages should always use newOrUsedStage, not newStage directly. */ private def newStage( rdd: RDD[_], @@ -221,20 +258,44 @@ class DAGScheduler( callSite: Option[String] = None) : Stage = { - if (shuffleDep != None) { - // Kind of ugly: need to register RDDs with the cache and map output tracker here - // since we can't do it in the RDD constructor because # of partitions is unknown - logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") - mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size) - } val id = nextStageId.getAndIncrement() val stage = new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite) stageIdToStage(id) = stage + updateJobIdStageIdMaps(jobId, stage) stageToInfos(stage) = new StageInfo(stage) stage } + /** + * Create a shuffle map Stage for the given RDD. The stage will also be associated with the + * provided jobId. If a stage for the shuffleId existed previously so that the shuffleId is + * present in the MapOutputTracker, then the number and location of available outputs are + * recovered from the MapOutputTracker + */ + private def newOrUsedStage( + rdd: RDD[_], + numTasks: Int, + shuffleDep: ShuffleDependency[_,_], + jobId: Int, + callSite: Option[String] = None) + : Stage = + { + val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite) + if (mapOutputTracker.has(shuffleDep.shuffleId)) { + val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) + val locs = MapOutputTracker.deserializeMapStatuses(serLocs) + for (i <- 0 until locs.size) stage.outputLocs(i) = List(locs(i)) + stage.numAvailableOutputs = locs.size + } else { + // Kind of ugly: need to register RDDs with the cache and map output tracker here + // since we can't do it in the RDD constructor because # of partitions is unknown + logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") + mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.size) + } + stage + } + /** * Get or create the list of parent stages for a given RDD. The stages will be assigned the * provided jobId if they haven't already been created with a lower jobId. @@ -286,6 +347,89 @@ class DAGScheduler( missing.toList } + /** + * Registers the given jobId among the jobs that need the given stage and + * all of that stage's ancestors. + */ + private def updateJobIdStageIdMaps(jobId: Int, stage: Stage) { + def updateJobIdStageIdMapsList(stages: List[Stage]) { + if (!stages.isEmpty) { + val s = stages.head + stageIdToJobIds.getOrElseUpdate(s.id, new HashSet[Int]()) += jobId + jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id + val parents = getParentStages(s.rdd, jobId) + val parentsWithoutThisJobId = parents.filter(p => !stageIdToJobIds.get(p.id).exists(_.contains(jobId))) + updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail) + } + } + updateJobIdStageIdMapsList(List(stage)) + } + + /** + * Removes job and any stages that are not needed by any other job. Returns the set of ids for stages that + * were removed. The associated tasks for those stages need to be cancelled if we got here via job cancellation. + */ + private def removeJobAndIndependentStages(jobId: Int): Set[Int] = { + val registeredStages = jobIdToStageIds(jobId) + val independentStages = new HashSet[Int]() + if (registeredStages.isEmpty) { + logError("No stages registered for job " + jobId) + } else { + stageIdToJobIds.filterKeys(stageId => registeredStages.contains(stageId)).foreach { + case (stageId, jobSet) => + if (!jobSet.contains(jobId)) { + logError("Job %d not registered for stage %d even though that stage was registered for the job" + .format(jobId, stageId)) + } else { + def removeStage(stageId: Int) { + // data structures based on Stage + stageIdToStage.get(stageId).foreach { s => + if (running.contains(s)) { + logDebug("Removing running stage %d".format(stageId)) + running -= s + } + stageToInfos -= s + shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleToMapStage.remove) + if (pendingTasks.contains(s) && !pendingTasks(s).isEmpty) { + logDebug("Removing pending status for stage %d".format(stageId)) + } + pendingTasks -= s + if (waiting.contains(s)) { + logDebug("Removing stage %d from waiting set.".format(stageId)) + waiting -= s + } + if (failed.contains(s)) { + logDebug("Removing stage %d from failed set.".format(stageId)) + failed -= s + } + } + // data structures based on StageId + stageIdToStage -= stageId + stageIdToJobIds -= stageId + + logDebug("After removal of stage %d, remaining stages = %d".format(stageId, stageIdToStage.size)) + } + + jobSet -= jobId + if (jobSet.isEmpty) { // no other job needs this stage + independentStages += stageId + removeStage(stageId) + } + } + } + } + independentStages.toSet + } + + private def jobIdToStageIdsRemove(jobId: Int) { + if (!jobIdToStageIds.contains(jobId)) { + logDebug("Trying to remove unregistered job " + jobId) + } else { + removeJobAndIndependentStages(jobId) + jobIdToStageIds -= jobId + } + } + /** * Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object * can be used to block until the the job finishes executing or can be used to cancel the job. @@ -319,7 +463,7 @@ class DAGScheduler( waiter } - def runJob[T, U: ClassManifest]( + def runJob[T, U: ClassTag]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], @@ -375,13 +519,25 @@ class DAGScheduler( } /** - * Process one event retrieved from the event queue. - * Returns true if we should stop the event loop. + * Process one event retrieved from the event processing actor. + * + * @param event The event to be processed. + * @return `true` if we should stop the event loop. */ private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = { event match { case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => - val finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite)) + var finalStage: Stage = null + try { + // New stage creation at times and if its not protected, the scheduler thread is killed. + // e.g. it can fail when jobs are run on HadoopRDD whose underlying hdfs files have been deleted + finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite)) + } catch { + case e: Exception => + logWarning("Creating new stage failed due to exception - job: " + jobId, e) + listener.jobFailed(e) + return false + } val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length + @@ -391,37 +547,31 @@ class DAGScheduler( logInfo("Missing parents: " + getMissingParentStages(finalStage)) if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) { // Compute very short actions like first() or take() with no parent stages locally. + listenerBus.post(SparkListenerJobStart(job, Array(), properties)) runLocally(job) } else { - listenerBus.post(SparkListenerJobStart(job, properties)) idToActiveJob(jobId) = job activeJobs += job resultStageToJob(finalStage) = job + listenerBus.post(SparkListenerJobStart(job, jobIdToStageIds(jobId).toArray, properties)) submitStage(finalStage) } case JobCancelled(jobId) => - // Cancel a job: find all the running stages that are linked to this job, and cancel them. - running.filter(_.jobId == jobId).foreach { stage => - taskSched.cancelTasks(stage.id) - } + handleJobCancellation(jobId) case JobGroupCancelled(groupId) => // Cancel all jobs belonging to this job group. // First finds all active jobs with this group id, and then kill stages for them. - val jobIds = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) - .map(_.jobId) - if (!jobIds.isEmpty) { - running.filter(stage => jobIds.contains(stage.jobId)).foreach { stage => - taskSched.cancelTasks(stage.id) - } - } + val activeInGroup = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) + val jobIds = activeInGroup.map(_.jobId) + jobIds.foreach { handleJobCancellation } case AllJobsCancelled => // Cancel all running jobs. - running.foreach { stage => - taskSched.cancelTasks(stage.id) - } + running.map(_.jobId).foreach { handleJobCancellation } + activeJobs.clear() // These should already be empty by this point, + idToActiveJob.clear() // but just in case we lost track of some jobs... case ExecutorGained(execId, host) => handleExecutorGained(execId, host) @@ -430,6 +580,18 @@ class DAGScheduler( handleExecutorLost(execId) case BeginEvent(task, taskInfo) => + for ( + job <- idToActiveJob.get(task.stageId); + stage <- stageIdToStage.get(task.stageId); + stageInfo <- stageToInfos.get(stage) + ) { + if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 && !stageInfo.emittedTaskSizeWarning) { + stageInfo.emittedTaskSizeWarning = true + logWarning(("Stage %d (%s) contains a task of very large " + + "size (%d KB). The maximum recommended task size is %d KB.").format( + task.stageId, stageInfo.name, taskInfo.serializedSize / 1024, TASK_SIZE_TO_WARN)) + } + } listenerBus.post(SparkListenerTaskStart(task, taskInfo)) case GettingResultEvent(task, taskInfo) => @@ -440,7 +602,12 @@ class DAGScheduler( handleTaskCompletion(completion) case TaskSetFailed(taskSet, reason) => - abortStage(stageIdToStage(taskSet.stageId), reason) + stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason) } + + case ResubmitFailedStages => + if (failed.size > 0) { + resubmitFailedStages() + } case StopDAGScheduler => // Cancel any active jobs @@ -502,6 +669,7 @@ class DAGScheduler( // Broken out for easier testing in DAGSchedulerSuite. protected def runLocallyWithinThread(job: ActiveJob) { + var jobResult: JobResult = JobSucceeded try { SparkEnv.set(env) val rdd = job.finalStage.rdd @@ -516,31 +684,59 @@ class DAGScheduler( } } catch { case e: Exception => + jobResult = JobFailed(e, Some(job.finalStage)) job.listener.jobFailed(e) + } finally { + val s = job.finalStage + stageIdToJobIds -= s.id // clean up data structures that were populated for a local job, + stageIdToStage -= s.id // but that won't get cleaned up via the normal paths through + stageToInfos -= s // completion events or stage abort + jobIdToStageIds -= job.jobId + listenerBus.post(SparkListenerJobEnd(job, jobResult)) + } + } + + /** Finds the earliest-created active job that needs the stage */ + // TODO: Probably should actually find among the active jobs that need this + // stage the one with the highest priority (highest-priority pool, earliest created). + // That should take care of at least part of the priority inversion problem with + // cross-job dependencies. + private def activeJobForStage(stage: Stage): Option[Int] = { + if (stageIdToJobIds.contains(stage.id)) { + val jobsThatUseStage: Array[Int] = stageIdToJobIds(stage.id).toArray.sorted + jobsThatUseStage.find(idToActiveJob.contains(_)) + } else { + None } } /** Submits stage, but first recursively submits any missing parents. */ private def submitStage(stage: Stage) { - logDebug("submitStage(" + stage + ")") - if (!waiting(stage) && !running(stage) && !failed(stage)) { - val missing = getMissingParentStages(stage).sortBy(_.id) - logDebug("missing: " + missing) - if (missing == Nil) { - logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents") - submitMissingTasks(stage) - running += stage - } else { - for (parent <- missing) { - submitStage(parent) + val jobId = activeJobForStage(stage) + if (jobId.isDefined) { + logDebug("submitStage(" + stage + ")") + if (!waiting(stage) && !running(stage) && !failed(stage)) { + val missing = getMissingParentStages(stage).sortBy(_.id) + logDebug("missing: " + missing) + if (missing == Nil) { + logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents") + submitMissingTasks(stage, jobId.get) + running += stage + } else { + for (parent <- missing) { + submitStage(parent) + } + waiting += stage } - waiting += stage } + } else { + abortStage(stage, "No active job for stage " + stage.id) } } + /** Called when stage's parents are available and we can now do its task. */ - private def submitMissingTasks(stage: Stage) { + private def submitMissingTasks(stage: Stage, jobId: Int) { logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet) @@ -561,7 +757,7 @@ class DAGScheduler( } } - val properties = if (idToActiveJob.contains(stage.jobId)) { + val properties = if (idToActiveJob.contains(jobId)) { idToActiveJob(stage.jobId).properties } else { //this stage will be assigned to "default" pool @@ -643,6 +839,7 @@ class DAGScheduler( activeJobs -= job resultStageToJob -= stage markStageAsFinished(stage) + jobIdToStageIdsRemove(job.jobId) listenerBus.post(SparkListenerJobEnd(job, JobSucceeded)) } job.listener.taskSucceeded(rt.outputId, event.result) @@ -679,7 +876,7 @@ class DAGScheduler( changeEpoch = true) } clearCacheLocs() - if (stage.outputLocs.count(_ == Nil) != 0) { + if (stage.outputLocs.exists(_ == Nil)) { // Some tasks had failed; let's resubmit this stage // TODO: Lower-level scheduler should also deal with this logInfo("Resubmitting " + stage + " (" + stage.name + @@ -696,9 +893,12 @@ class DAGScheduler( } waiting --= newlyRunnable running ++= newlyRunnable - for (stage <- newlyRunnable.sortBy(_.id)) { + for { + stage <- newlyRunnable.sortBy(_.id) + jobId <- activeJobForStage(stage) + } { logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable") - submitMissingTasks(stage) + submitMissingTasks(stage, jobId) } } } @@ -782,21 +982,42 @@ class DAGScheduler( } } + private def handleJobCancellation(jobId: Int) { + if (!jobIdToStageIds.contains(jobId)) { + logDebug("Trying to cancel unregistered job " + jobId) + } else { + val independentStages = removeJobAndIndependentStages(jobId) + independentStages.foreach { taskSched.cancelTasks } + val error = new SparkException("Job %d cancelled".format(jobId)) + val job = idToActiveJob(jobId) + job.listener.jobFailed(error) + jobIdToStageIds -= jobId + activeJobs -= job + idToActiveJob -= jobId + listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(job.finalStage)))) + } + } + /** * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. */ private def abortStage(failedStage: Stage, reason: String) { + if (!stageIdToStage.contains(failedStage.id)) { + // Skip all the actions if the stage has been removed. + return + } val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis()) for (resultStage <- dependentStages) { val job = resultStageToJob(resultStage) val error = new SparkException("Job aborted: " + reason) job.listener.jobFailed(error) - listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage)))) + jobIdToStageIdsRemove(job.jobId) idToActiveJob -= resultStage.jobId activeJobs -= job resultStageToJob -= resultStage + listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage)))) } if (dependentStages.isEmpty) { logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") @@ -867,25 +1088,24 @@ class DAGScheduler( } private def cleanup(cleanupTime: Long) { - var sizeBefore = stageIdToStage.size - stageIdToStage.clearOldValues(cleanupTime) - logInfo("stageIdToStage " + sizeBefore + " --> " + stageIdToStage.size) - - sizeBefore = shuffleToMapStage.size - shuffleToMapStage.clearOldValues(cleanupTime) - logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size) - - sizeBefore = pendingTasks.size - pendingTasks.clearOldValues(cleanupTime) - logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size) - - sizeBefore = stageToInfos.size - stageToInfos.clearOldValues(cleanupTime) - logInfo("stageToInfos " + sizeBefore + " --> " + stageToInfos.size) + Map( + "stageIdToStage" -> stageIdToStage, + "shuffleToMapStage" -> shuffleToMapStage, + "pendingTasks" -> pendingTasks, + "stageToInfos" -> stageToInfos, + "jobIdToStageIds" -> jobIdToStageIds, + "stageIdToJobIds" -> stageIdToJobIds). + foreach { case(s, t) => { + val sizeBefore = t.size + t.clearOldValues(cleanupTime) + logInfo("%s %d --> %d".format(s, sizeBefore, t.size)) + }} } def stop() { - eventProcessActor ! StopDAGScheduler + if (eventProcessActor != null) { + eventProcessActor ! StopDAGScheduler + } metadataCleaner.cancel() taskSched.stop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 708d221d60..add1187613 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -65,12 +65,13 @@ private[scheduler] case class CompletionEvent( taskMetrics: TaskMetrics) extends DAGSchedulerEvent -private[scheduler] -case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent +private[scheduler] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent private[scheduler] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent +private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent + private[scheduler] case object StopDAGScheduler extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 58f238d8cf..b026f860a8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -31,6 +31,7 @@ private[spark] class JobWaiter[T]( private var finishedTasks = 0 // Is the job as a whole finished (succeeded or failed)? + @volatile private var _jobFinished = totalTasks == 0 def jobFinished = _jobFinished diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala index 0a786deb16..3832ee7ff6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala @@ -22,7 +22,7 @@ package org.apache.spark.scheduler * to order tasks amongst a Schedulable's sub-queues * "NONE" is used when the a Schedulable has no sub-queues. */ -object SchedulingMode extends Enumeration("FAIR", "FIFO", "NONE") { +object SchedulingMode extends Enumeration { type SchedulingMode = Value val FAIR,FIFO,NONE = Value diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 1dc71a0428..0f2deb4bcb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -167,6 +167,7 @@ private[spark] class ShuffleMapTask( var totalTime = 0L val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter => writer.commit() + writer.close() val size = writer.fileSegment().length totalBytes += size totalTime += writer.timeWriting() @@ -184,14 +185,16 @@ private[spark] class ShuffleMapTask( } catch { case e: Exception => // If there is an exception from running the task, revert the partial writes // and throw the exception upstream to Spark. - if (shuffle != null) { - shuffle.writers.foreach(_.revertPartialWrites()) + if (shuffle != null && shuffle.writers != null) { + for (writer <- shuffle.writers) { + writer.revertPartialWrites() + writer.close() + } } throw e } finally { // Release the writers back to the shuffle block manager. if (shuffle != null && shuffle.writers != null) { - shuffle.writers.foreach(_.close()) shuffle.releaseWriters(success) } // Execute the callbacks on task completion. diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index a35081f7b1..ee63b3c4a1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -37,7 +37,7 @@ case class SparkListenerTaskGettingResult( case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo, taskMetrics: TaskMetrics) extends SparkListenerEvents -case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null) +case class SparkListenerJobStart(job: ActiveJob, stageIds: Array[Int], properties: Properties = null) extends SparkListenerEvents case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult) @@ -63,7 +63,7 @@ trait SparkListener { * Called when a task begins remotely fetching its result (will not be called for tasks that do * not need to fetch the result remotely). */ - def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { } + def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { } /** * Called when a task ends @@ -131,8 +131,8 @@ object StatsReportListener extends Logging { def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) { val stats = d.statCounter - logInfo(heading + stats) val quantiles = d.getQuantiles(probabilities).map{formatNumber} + logInfo(heading + stats) logInfo(percentilesHeader) logInfo("\t" + quantiles.mkString("\t")) } @@ -173,8 +173,6 @@ object StatsReportListener extends Logging { showMillisDistribution(heading, extractLongDistribution(stage, getMetric)) } - - val seconds = 1000L val minutes = seconds * 60 val hours = minutes * 60 @@ -198,7 +196,6 @@ object StatsReportListener extends Logging { } - case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double) object RuntimePercentage { def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index d5824e7954..85687ea330 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -91,4 +91,3 @@ private[spark] class SparkListenerBus() extends Logging { return true } } - diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 93599dfdc8..e9f2198a00 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -33,4 +33,5 @@ class StageInfo( val name = stage.name val numPartitions = stage.numPartitions val numTasks = stage.numTasks + var emittedTaskSizeWarning = false } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 4bae26f3a6..3c22edd524 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -46,6 +46,8 @@ class TaskInfo( var failed = false + var serializedSize: Int = 0 + def markGettingResult(time: Long = System.currentTimeMillis) { gettingResultTime = time } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala index 47b0f387aa..35de13c385 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala @@ -18,9 +18,7 @@ package org.apache.spark.scheduler -private[spark] object TaskLocality - extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") -{ +private[spark] object TaskLocality extends Enumeration { // process local is expected to be used ONLY within tasksetmanager for now. val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala index c1e65a3c48..66ab8ea4cd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala @@ -24,8 +24,7 @@ import java.util.{TimerTask, Timer} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet - -import akka.util.duration._ +import scala.concurrent.duration._ import org.apache.spark._ import org.apache.spark.TaskState.TaskState @@ -100,8 +99,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext) this.dagScheduler = dagScheduler } - def initialize(context: SchedulerBackend) { - backend = context + def initialize(backend: SchedulerBackend) { + this.backend = backend // temporarily set rootPool name to empty rootPool = new Pool("", schedulingMode, 0, 0) schedulableBuilder = { @@ -122,7 +121,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) if (System.getProperty("spark.speculation", "false").toBoolean) { logInfo("Starting speculative execution thread") - + import sc.env.actorSystem.dispatcher sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds, SPECULATION_INTERVAL milliseconds) { checkSpeculatableTasks() @@ -173,7 +172,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext) backend.killTask(tid, execId) } } - tsm.error("Stage %d was cancelled".format(stageId)) + logInfo("Stage %d was cancelled".format(stageId)) + tsm.removeAllRunningTasks() + taskSetFinished(tsm) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala index 4c5eca8537..bf494aa64d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -377,6 +377,7 @@ private[spark] class ClusterTaskSetManager( logInfo("Serialized task %s:%d as %d bytes in %d ms".format( taskSet.id, index, serializedTask.limit, timeTaken)) val taskName = "task %s:%d".format(taskSet.id, index) + info.serializedSize = serializedTask.limit if (taskAttempts(index).size == 1) taskStarted(task,info) return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask)) @@ -528,10 +529,10 @@ private[spark] class ClusterTaskSetManager( addPendingTask(index) if (state != TaskState.KILLED) { numFailures(index) += 1 - if (numFailures(index) > MAX_TASK_FAILURES) { - logError("Task %s:%d failed more than %d times; aborting job".format( + if (numFailures(index) >= MAX_TASK_FAILURES) { + logError("Task %s:%d failed %d times; aborting job".format( taskSet.id, index, MAX_TASK_FAILURES)) - abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) + abort("Task %s:%d failed %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) } } } else { @@ -573,7 +574,7 @@ private[spark] class ClusterTaskSetManager( runningTasks = runningTasksSet.size } - private def removeAllRunningTasks() { + private[cluster] def removeAllRunningTasks() { val numRunningTasks = runningTasksSet.size runningTasksSet.clear() if (parent != null) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index a45bee536c..f5e8766f6d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -20,13 +20,12 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.concurrent.Await +import scala.concurrent.duration._ import akka.actor._ -import akka.dispatch.Await import akka.pattern.ask -import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent} -import akka.util.Duration -import akka.util.duration._ +import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import org.apache.spark.{SparkException, Logging, TaskState} import org.apache.spark.scheduler.TaskDescription @@ -53,15 +52,15 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac private val executorAddress = new HashMap[String, Address] private val executorHost = new HashMap[String, String] private val freeCores = new HashMap[String, Int] - private val actorToExecutorId = new HashMap[ActorRef, String] private val addressToExecutorId = new HashMap[Address, String] override def preStart() { // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) // Periodically revive offers to allow delay scheduling to work val reviveInterval = System.getProperty("spark.scheduler.revive.interval", "1000").toLong + import context.dispatcher context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) } @@ -73,12 +72,10 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac } else { logInfo("Registered executor: " + sender + " with ID " + executorId) sender ! RegisteredExecutor(sparkProperties) - context.watch(sender) executorActor(executorId) = sender executorHost(executorId) = Utils.parseHostPort(hostPort)._1 freeCores(executorId) = cores executorAddress(executorId) = sender.path.address - actorToExecutorId(sender) = executorId addressToExecutorId(sender.path.address) = executorId totalCoreCount.addAndGet(cores) makeOffers() @@ -118,14 +115,9 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac removeExecutor(executorId, reason) sender ! true - case Terminated(actor) => - actorToExecutorId.get(actor).foreach(removeExecutor(_, "Akka actor terminated")) + case DisassociatedEvent(_, address, _) => + addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disassociated")) - case RemoteClientDisconnected(transport, address) => - addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disconnected")) - - case RemoteClientShutdown(transport, address) => - addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client shutdown")) } // Make fake resource offers on all executors @@ -153,7 +145,6 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac if (executorActor.contains(executorId)) { logInfo("Executor " + executorId + " disconnected, so removing it") val numCores = freeCores(executorId) - actorToExecutorId -= executorActor(executorId) addressToExecutorId -= executorAddress(executorId) executorActor -= executorId executorHost -= executorId @@ -199,6 +190,7 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac } override def stop() { + stopExecutors() try { if (driverActor != null) { val future = driverActor.ask(StopDriver)(timeout) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index 0ea35e2b7a..e8fecec4a6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -36,7 +36,7 @@ private[spark] class SimrSchedulerBackend( override def start() { super.start() - val driverUrl = "akka://spark@%s:%s/user/%s".format( + val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) @@ -62,7 +62,6 @@ private[spark] class SimrSchedulerBackend( val conf = new Configuration() val fs = FileSystem.get(conf) fs.delete(new Path(driverFilePath), false) - super.stopExecutors() super.stop() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index cefa970bb9..7127a72d6d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -42,7 +42,7 @@ private[spark] class SparkDeploySchedulerBackend( super.start() // The endpoint for executors to talk to us - val driverUrl = "akka://spark@%s:%s/user/%s".format( + val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}") diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala index 2064d97b49..e68c527713 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala @@ -71,7 +71,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche case cnf: ClassNotFoundException => val loader = Thread.currentThread.getContextClassLoader taskSetManager.abort("ClassNotFound with classloader: " + loader) - case ex => + case ex: Throwable => taskSetManager.abort("Exception while deserializing and fetching task: %s".format(ex)) } } @@ -95,7 +95,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche val loader = Thread.currentThread.getContextClassLoader logError( "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) - case ex => {} + case ex: Throwable => {} } scheduler.handleFailedTask(taskSetManager, tid, taskState, reason) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index cd521e0f2b..84fe3094cc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -120,7 +120,7 @@ private[spark] class CoarseMesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val driverUrl = "akka://spark@%s:%s/user/%s".format( + val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala index 2699f0b33e..01e95162c0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala @@ -74,7 +74,7 @@ class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int) } } -private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext) +private[spark] class LocalScheduler(val threads: Int, val maxFailures: Int, val sc: SparkContext) extends TaskScheduler with ExecutorBackend with Logging { @@ -144,7 +144,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: localActor ! KillTask(tid) } } - tsm.error("Stage %d was cancelled".format(stageId)) + logInfo("Stage %d was cancelled".format(stageId)) + taskSetFinished(tsm) } } @@ -192,17 +193,19 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: synchronized { taskIdToTaskSetId.get(taskId) match { case Some(taskSetId) => - val taskSetManager = activeTaskSets(taskSetId) - taskSetTaskIds(taskSetId) -= taskId + val taskSetManager = activeTaskSets.get(taskSetId) + taskSetManager.foreach { tsm => + taskSetTaskIds(taskSetId) -= taskId - state match { - case TaskState.FINISHED => - taskSetManager.taskEnded(taskId, state, serializedData) - case TaskState.FAILED => - taskSetManager.taskFailed(taskId, state, serializedData) - case TaskState.KILLED => - taskSetManager.error("Task %d was killed".format(taskId)) - case _ => {} + state match { + case TaskState.FINISHED => + tsm.taskEnded(taskId, state, serializedData) + case TaskState.FAILED => + tsm.taskFailed(taskId, state, serializedData) + case TaskState.KILLED => + tsm.error("Task %d was killed".format(taskId)) + case _ => {} + } } case None => logInfo("Ignoring update from TID " + taskId + " because its task set is gone") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 702aca8323..19a025a329 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -24,9 +24,9 @@ import scala.collection.mutable.{HashMap, ArrayBuffer} import scala.util.Random import akka.actor.{ActorSystem, Cancellable, Props} -import akka.dispatch.{Await, Future} -import akka.util.Duration -import akka.util.duration._ +import scala.concurrent.{Await, Future} +import scala.concurrent.duration.Duration +import scala.concurrent.duration._ import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream} @@ -924,4 +924,3 @@ private[spark] object BlockManager extends Logging { blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host)) } } - diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 94038649b3..e05b842476 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -17,16 +17,17 @@ package org.apache.spark.storage -import akka.actor.ActorRef -import akka.dispatch.{Await, Future} +import scala.concurrent.{Await, Future} +import scala.concurrent.duration._ +import scala.concurrent.ExecutionContext.Implicits.global + +import akka.actor._ import akka.pattern.ask -import akka.util.Duration import org.apache.spark.{Logging, SparkException} import org.apache.spark.storage.BlockManagerMessages._ - -private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Logging { +private[spark] class BlockManagerMaster(var driverActor : Either[ActorRef, ActorSelection]) extends Logging { val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt @@ -156,7 +157,10 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi while (attempts < AKKA_RETRY_ATTEMPTS) { attempts += 1 try { - val future = driverActor.ask(message)(timeout) + val future = driverActor match { + case Left(a: ActorRef) => a.ask(message)(timeout) + case Right(b: ActorSelection) => b.ask(message)(timeout) + } val result = Await.result(future, timeout) if (result == null) { throw new SparkException("BlockManagerMaster returned null") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index f8cf14b503..154a3980e9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -23,10 +23,10 @@ import scala.collection.mutable import scala.collection.JavaConversions._ import akka.actor.{Actor, ActorRef, Cancellable} -import akka.dispatch.Future import akka.pattern.ask -import akka.util.Duration -import akka.util.duration._ + +import scala.concurrent.duration._ +import scala.concurrent.Future import org.apache.spark.{Logging, SparkException} import org.apache.spark.storage.BlockManagerMessages._ @@ -65,6 +65,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { override def preStart() { if (!BlockManager.getDisableHeartBeatsForTesting) { + import context.dispatcher timeoutCheckingTask = context.system.scheduler.schedule( 0.seconds, checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 469e68fed7..b4451fc7b8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -93,6 +93,8 @@ class DiskBlockObjectWriter( def write(i: Int): Unit = callWithTiming(out.write(i)) override def write(b: Array[Byte]) = callWithTiming(out.write(b)) override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len)) + override def close() = out.close() + override def flush() = out.flush() } private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 2f1b049ce4..e828e1d1c5 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -62,7 +62,7 @@ class ShuffleBlockManager(blockManager: BlockManager) { // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. // TODO: Remove this once the shuffle file consolidation feature is stable. val consolidateShuffleFiles = - System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean + System.getProperty("spark.shuffle.consolidateFiles", "false").toBoolean private val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index 632ff047d1..b5596dffd3 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -101,7 +101,7 @@ class StorageLevel private( var result = "" result += (if (useDisk) "Disk " else "") result += (if (useMemory) "Memory " else "") - result += (if (deserialized) "Deserialized " else "Serialized") + result += (if (deserialized) "Deserialized " else "Serialized ") result += "%sx Replicated".format(replication) result } diff --git a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala index 1e4db4f66b..d52b3d8284 100644 --- a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala +++ b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.storage import java.util.concurrent.atomic.AtomicLong diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index 860e680576..a8db37ded1 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -93,7 +93,7 @@ private[spark] object ThreadingTest { val actorSystem = ActorSystem("test") val serializer = new KryoSerializer val blockManagerMaster = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true)))) + Left(actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))) val blockManager = new BlockManager( "", actorSystem, blockManagerMaster, serializer, 1024 * 1024) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala index 42e9be6e19..e596690bc3 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala @@ -76,7 +76,7 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { } - val execInfo = for (b <- 0 until storageStatusList.size) yield getExecInfo(b) + val execInfo = for (statusId <- 0 until storageStatusList.size) yield getExecInfo(statusId) val execTable = UIUtils.listingTable(execHead, execRow, execInfo) val content = @@ -99,16 +99,17 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { UIUtils.headerSparkPage(content, sc, "Executors (" + execInfo.size + ")", Executors) } - def getExecInfo(a: Int): Seq[String] = { - val execId = sc.getExecutorStorageStatus(a).blockManagerId.executorId - val hostPort = sc.getExecutorStorageStatus(a).blockManagerId.hostPort - val rddBlocks = sc.getExecutorStorageStatus(a).blocks.size.toString - val memUsed = sc.getExecutorStorageStatus(a).memUsed().toString - val maxMem = sc.getExecutorStorageStatus(a).maxMem.toString - val diskUsed = sc.getExecutorStorageStatus(a).diskUsed().toString - val activeTasks = listener.executorToTasksActive.get(a.toString).map(l => l.size).getOrElse(0) - val failedTasks = listener.executorToTasksFailed.getOrElse(a.toString, 0) - val completedTasks = listener.executorToTasksComplete.getOrElse(a.toString, 0) + def getExecInfo(statusId: Int): Seq[String] = { + val status = sc.getExecutorStorageStatus(statusId) + val execId = status.blockManagerId.executorId + val hostPort = status.blockManagerId.hostPort + val rddBlocks = status.blocks.size.toString + val memUsed = status.memUsed().toString + val maxMem = status.maxMem.toString + val diskUsed = status.diskUsed().toString + val activeTasks = listener.executorToTasksActive.getOrElse(execId, HashSet.empty[Long]).size + val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0) + val completedTasks = listener.executorToTasksComplete.getOrElse(execId, 0) val totalTasks = activeTasks + failedTasks + completedTasks Seq( diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala index e7eab374ad..c1ee2f3d00 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala @@ -17,7 +17,7 @@ package org.apache.spark.ui.jobs -import akka.util.Duration +import scala.concurrent.duration._ import java.text.SimpleDateFormat diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index c1c7aa70e6..69f9446bab 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -60,11 +60,13 @@ private[spark] class StagePage(parent: JobProgressUI) { var activeTime = 0L listener.stageIdToTasksActive(stageId).foreach(activeTime += _.timeRunning(now)) + val finishedTasks = listener.stageIdToTaskInfos(stageId).filter(_._1.finished) + val summary =
  • - CPU time: + Total duration across all tasks: {parent.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)}
  • {if (hasShuffleRead) @@ -104,6 +106,33 @@ private[spark] class StagePage(parent: JobProgressUI) { val serviceQuantiles = "Duration" +: Distribution(serviceTimes).get.getQuantiles().map( ms => parent.formatDuration(ms.toLong)) + val gettingResultTimes = validTasks.map{case (info, metrics, exception) => + if (info.gettingResultTime > 0) { + (info.finishTime - info.gettingResultTime).toDouble + } else { + 0.0 + } + } + val gettingResultQuantiles = ("Time spent fetching task results" +: + Distribution(gettingResultTimes).get.getQuantiles().map( + millis => parent.formatDuration(millis.toLong))) + // The scheduler delay includes the network delay to send the task to the worker + // machine and to send back the result (but not the time to fetch the task result, + // if it needed to be fetched from the block manager on the worker). + val schedulerDelays = validTasks.map{case (info, metrics, exception) => + val totalExecutionTime = { + if (info.gettingResultTime > 0) { + (info.gettingResultTime - info.launchTime).toDouble + } else { + (info.finishTime - info.launchTime).toDouble + } + } + totalExecutionTime - metrics.get.executorRunTime + } + val schedulerDelayQuantiles = ("Scheduler delay" +: + Distribution(schedulerDelays).get.getQuantiles().map( + millis => parent.formatDuration(millis.toLong))) + def getQuantileCols(data: Seq[Double]) = Distribution(data).get.getQuantiles().map(d => Utils.bytesToString(d.toLong)) @@ -119,7 +148,10 @@ private[spark] class StagePage(parent: JobProgressUI) { } val shuffleWriteQuantiles = "Shuffle Write" +: getQuantileCols(shuffleWriteSizes) - val listings: Seq[Seq[String]] = Seq(serviceQuantiles, + val listings: Seq[Seq[String]] = Seq( + serviceQuantiles, + gettingResultQuantiles, + schedulerDelayQuantiles, if (hasShuffleRead) shuffleReadQuantiles else Nil, if (hasShuffleWrite) shuffleWriteQuantiles else Nil) @@ -133,7 +165,7 @@ private[spark] class StagePage(parent: JobProgressUI) { summary ++

    Summary Metrics for {numCompleted} Completed Tasks

    ++
    {summaryTable.getOrElse("No tasks have reported metrics yet.")}
    ++ -

    Tasks

    ++ taskTable; +

    Tasks

    ++ taskTable headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages) } @@ -152,21 +184,18 @@ private[spark] class StagePage(parent: JobProgressUI) { else metrics.map(m => parent.formatDuration(m.executorRunTime)).getOrElse("") val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L) - var shuffleReadSortable: String = "" - var shuffleReadReadable: String = "" - if (shuffleRead) { - shuffleReadSortable = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => s.remoteBytesRead}.toString() - shuffleReadReadable = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => - Utils.bytesToString(s.remoteBytesRead)}.getOrElse("") - } + val maybeShuffleRead = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => s.remoteBytesRead} + val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("") + val shuffleReadReadable = maybeShuffleRead.map{Utils.bytesToString(_)}.getOrElse("") - var shuffleWriteSortable: String = "" - var shuffleWriteReadable: String = "" - if (shuffleWrite) { - shuffleWriteSortable = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleBytesWritten}.toString() - shuffleWriteReadable = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => - Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("") - } + val maybeShuffleWrite = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleBytesWritten} + val shuffleWriteSortable = maybeShuffleWrite.map(_.toString).getOrElse("") + val shuffleWriteReadable = maybeShuffleWrite.map{Utils.bytesToString(_)}.getOrElse("") + + val maybeWriteTime = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleWriteTime} + val writeTimeSortable = maybeWriteTime.map(_.toString).getOrElse("") + val writeTimeReadable = maybeWriteTime.map{ t => t / (1000 * 1000)}.map{ ms => + if (ms == 0) "" else parent.formatDuration(ms)}.getOrElse("") {info.index} @@ -187,8 +216,8 @@ private[spark] class StagePage(parent: JobProgressUI) { }} {if (shuffleWrite) { - {metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => - parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")} + + {writeTimeReadable} {shuffleWriteReadable} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala index 1d633d374a..a5446b3fc3 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala @@ -17,7 +17,7 @@ package org.apache.spark.ui.storage -import akka.util.Duration +import scala.concurrent.duration._ import javax.servlet.http.HttpServletRequest diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index d4c5065c3f..74133cef6c 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -17,11 +17,8 @@ package org.apache.spark.util -import akka.actor.{ActorSystem, ExtendedActorSystem} +import akka.actor.{ActorSystem, ExtendedActorSystem, IndestructibleActorSystem} import com.typesafe.config.ConfigFactory -import akka.util.duration._ -import akka.remote.RemoteActorRefProvider - /** * Various utility classes for working with Akka. @@ -34,39 +31,57 @@ private[spark] object AkkaUtils { * * Note: the `name` parameter is important, as even if a client sends a message to right * host + port, if the system name is incorrect, Akka will drop the message. + * + * If indestructible is set to true, the Actor System will continue running in the event + * of a fatal exception. This is used by [[org.apache.spark.executor.Executor]]. */ - def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = { - val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt + def createActorSystem(name: String, host: String, port: Int, indestructible: Boolean = false) + : (ActorSystem, Int) = { + + val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt - val akkaTimeout = System.getProperty("spark.akka.timeout", "60").toInt + + val akkaTimeout = System.getProperty("spark.akka.timeout", "100").toInt + val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt - val lifecycleEvents = if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off" - // 10 seconds is the default akka timeout, but in a cluster, we need higher by default. - val akkaWriteTimeout = System.getProperty("spark.akka.writeTimeout", "30").toInt - - val akkaConf = ConfigFactory.parseString(""" - akka.daemonic = on - akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"] - akka.stdout-loglevel = "ERROR" - akka.actor.provider = "akka.remote.RemoteActorRefProvider" - akka.remote.transport = "akka.remote.netty.NettyRemoteTransport" - akka.remote.netty.hostname = "%s" - akka.remote.netty.port = %d - akka.remote.netty.connection-timeout = %ds - akka.remote.netty.message-frame-size = %d MiB - akka.remote.netty.execution-pool-size = %d - akka.actor.default-dispatcher.throughput = %d - akka.remote.log-remote-lifecycle-events = %s - akka.remote.netty.write-timeout = %ds - """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize, - lifecycleEvents, akkaWriteTimeout)) + val lifecycleEvents = + if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off" - val actorSystem = ActorSystem(name, akkaConf) + val akkaHeartBeatPauses = System.getProperty("spark.akka.heartbeat.pauses", "600").toInt + val akkaFailureDetector = + System.getProperty("spark.akka.failure-detector.threshold", "300.0").toDouble + val akkaHeartBeatInterval = System.getProperty("spark.akka.heartbeat.interval", "1000").toInt + + val akkaConf = ConfigFactory.parseString( + s""" + |akka.daemonic = on + |akka.loggers = [""akka.event.slf4j.Slf4jLogger""] + |akka.stdout-loglevel = "ERROR" + |akka.jvm-exit-on-fatal-error = off + |akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatInterval s + |akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPauses s + |akka.remote.transport-failure-detector.threshold = $akkaFailureDetector + |akka.actor.provider = "akka.remote.RemoteActorRefProvider" + |akka.remote.netty.tcp.transport-class = "akka.remote.transport.netty.NettyTransport" + |akka.remote.netty.tcp.hostname = "$host" + |akka.remote.netty.tcp.port = $port + |akka.remote.netty.tcp.tcp-nodelay = on + |akka.remote.netty.tcp.connection-timeout = $akkaTimeout s + |akka.remote.netty.tcp.maximum-frame-size = ${akkaFrameSize}MiB + |akka.remote.netty.tcp.execution-pool-size = $akkaThreads + |akka.actor.default-dispatcher.throughput = $akkaBatchSize + |akka.remote.log-remote-lifecycle-events = $lifecycleEvents + """.stripMargin) + + val actorSystem = if (indestructible) { + IndestructibleActorSystem(name, akkaConf) + } else { + ActorSystem(name, akkaConf) + } - // Figure out the port number we bound to, in case port was passed as 0. This is a bit of a - // hack because Akka doesn't let you figure out the port through the public API yet. val provider = actorSystem.asInstanceOf[ExtendedActorSystem].provider - val boundPort = provider.asInstanceOf[RemoteActorRefProvider].transport.address.port.get - return (actorSystem, boundPort) + val boundPort = provider.getDefaultAddress.port.get + (actorSystem, boundPort) } + } diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala index f60deafc6f..8bb4ee3bfa 100644 --- a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala @@ -35,6 +35,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi private var capacity = nextPowerOf2(initialCapacity) private var mask = capacity - 1 private var curSize = 0 + private var growThreshold = LOAD_FACTOR * capacity // Holds keys and values in the same array for memory locality; specifically, the order of // elements is key0, value0, key1, value1, key2, value2, etc. @@ -56,7 +57,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi var i = 1 while (true) { val curKey = data(2 * pos) - if (k.eq(curKey) || k == curKey) { + if (k.eq(curKey) || k.equals(curKey)) { return data(2 * pos + 1).asInstanceOf[V] } else if (curKey.eq(null)) { return null.asInstanceOf[V] @@ -80,9 +81,23 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi haveNullValue = true return } - val isNewEntry = putInto(data, k, value.asInstanceOf[AnyRef]) - if (isNewEntry) { - incrementSize() + var pos = rehash(key.hashCode) & mask + var i = 1 + while (true) { + val curKey = data(2 * pos) + if (curKey.eq(null)) { + data(2 * pos) = k + data(2 * pos + 1) = value.asInstanceOf[AnyRef] + incrementSize() // Since we added a new key + return + } else if (k.eq(curKey) || k.equals(curKey)) { + data(2 * pos + 1) = value.asInstanceOf[AnyRef] + return + } else { + val delta = i + pos = (pos + delta) & mask + i += 1 + } } } @@ -104,7 +119,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi var i = 1 while (true) { val curKey = data(2 * pos) - if (k.eq(curKey) || k == curKey) { + if (k.eq(curKey) || k.equals(curKey)) { val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V]) data(2 * pos + 1) = newValue.asInstanceOf[AnyRef] return newValue @@ -161,45 +176,17 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi /** Increase table size by 1, rehashing if necessary */ private def incrementSize() { curSize += 1 - if (curSize > LOAD_FACTOR * capacity) { + if (curSize > growThreshold) { growTable() } } /** - * Re-hash a value to deal better with hash functions that don't differ - * in the lower bits, similar to java.util.HashMap + * Re-hash a value to deal better with hash functions that don't differ in the lower bits. + * We use the Murmur Hash 3 finalization step that's also used in fastutil. */ private def rehash(h: Int): Int = { - val r = h ^ (h >>> 20) ^ (h >>> 12) - r ^ (r >>> 7) ^ (r >>> 4) - } - - /** - * Put an entry into a table represented by data, returning true if - * this increases the size of the table or false otherwise. Assumes - * that "data" has at least one empty slot. - */ - private def putInto(data: Array[AnyRef], key: AnyRef, value: AnyRef): Boolean = { - val mask = (data.length / 2) - 1 - var pos = rehash(key.hashCode) & mask - var i = 1 - while (true) { - val curKey = data(2 * pos) - if (curKey.eq(null)) { - data(2 * pos) = key - data(2 * pos + 1) = value.asInstanceOf[AnyRef] - return true - } else if (curKey.eq(key) || curKey == key) { - data(2 * pos + 1) = value.asInstanceOf[AnyRef] - return false - } else { - val delta = i - pos = (pos + delta) & mask - i += 1 - } - } - return false // Never reached but needed to keep compiler happy + it.unimi.dsi.fastutil.HashCommon.murmurHash3(h) } /** Double the table's size and re-hash everything */ @@ -211,16 +198,36 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi throw new Exception("Can't make capacity bigger than 2^29 elements") } val newData = new Array[AnyRef](2 * newCapacity) - var pos = 0 - while (pos < capacity) { - if (!data(2 * pos).eq(null)) { - putInto(newData, data(2 * pos), data(2 * pos + 1)) + val newMask = newCapacity - 1 + // Insert all our old values into the new array. Note that because our old keys are + // unique, there's no need to check for equality here when we insert. + var oldPos = 0 + while (oldPos < capacity) { + if (!data(2 * oldPos).eq(null)) { + val key = data(2 * oldPos) + val value = data(2 * oldPos + 1) + var newPos = rehash(key.hashCode) & newMask + var i = 1 + var keepGoing = true + while (keepGoing) { + val curKey = newData(2 * newPos) + if (curKey.eq(null)) { + newData(2 * newPos) = key + newData(2 * newPos + 1) = value + keepGoing = false + } else { + val delta = i + newPos = (newPos + delta) & newMask + i += 1 + } + } } - pos += 1 + oldPos += 1 } data = newData capacity = newCapacity - mask = newCapacity - 1 + mask = newMask + growThreshold = LOAD_FACTOR * newCapacity } private def nextPowerOf2(n: Int): Int = { diff --git a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala index 0b51c23f7b..a38329df03 100644 --- a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala +++ b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala @@ -34,6 +34,8 @@ class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A]) override def iterator: Iterator[A] = underlying.iterator.asScala + override def size: Int = underlying.size + override def ++=(xs: TraversableOnce[A]): this.type = { xs.foreach { this += _ } this diff --git a/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala b/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala new file mode 100644 index 0000000000..bf71882ef7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Must be in akka.actor package as ActorSystemImpl is protected[akka]. +package akka.actor + +import scala.util.control.{ControlThrowable, NonFatal} + +import com.typesafe.config.Config + +/** + * An [[akka.actor.ActorSystem]] which refuses to shut down in the event of a fatal exception. + * This is necessary as Spark Executors are allowed to recover from fatal exceptions + * (see [[org.apache.spark.executor.Executor]]). + */ +object IndestructibleActorSystem { + def apply(name: String, config: Config): ActorSystem = + apply(name, config, ActorSystem.findClassLoader()) + + def apply(name: String, config: Config, classLoader: ClassLoader): ActorSystem = + new IndestructibleActorSystemImpl(name, config, classLoader).start() +} + +private[akka] class IndestructibleActorSystemImpl( + override val name: String, + applicationConfig: Config, + classLoader: ClassLoader) + extends ActorSystemImpl(name, applicationConfig, classLoader) { + + protected override def uncaughtExceptionHandler: Thread.UncaughtExceptionHandler = { + val fallbackHandler = super.uncaughtExceptionHandler + + new Thread.UncaughtExceptionHandler() { + def uncaughtException(thread: Thread, cause: Throwable): Unit = { + if (isFatalError(cause) && !settings.JvmExitOnFatalError) { + log.error(cause, "Uncaught fatal error from thread [{}] not shutting down " + + "ActorSystem [{}] tolerating and continuing.... ", thread.getName, name) + //shutdown() //TODO make it configurable + } else { + fallbackHandler.uncaughtException(thread, cause) + } + } + } + } + + def isFatalError(e: Throwable): Boolean = { + e match { + case NonFatal(_) | _: InterruptedException | _: NotImplementedError | _: ControlThrowable => + false + case _ => + true + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 67a7f87a5c..7b41ef89f1 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -55,8 +55,7 @@ class MetadataCleaner(cleanerType: MetadataCleanerType.MetadataCleanerType, clea } } -object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext", "HttpBroadcast", "DagScheduler", "ResultTask", - "ShuffleMapTask", "BlockManager", "DiskBlockManager", "BroadcastVars") { +object MetadataCleanerType extends Enumeration { val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK, SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index 277de2f8a6..dbff571de9 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -85,7 +85,7 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with Logging { } override def filter(p: ((A, B)) => Boolean): Map[A, B] = { - JavaConversions.asScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p) + JavaConversions.mapAsScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p) } override def empty: Map[A, B] = new TimeStampedHashMap[A, B]() diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index fe932d8ede..3f7858d2de 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -22,10 +22,11 @@ import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address} import java.util.{Locale, Random, UUID} import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor} +import scala.collection.JavaConversions._ import scala.collection.Map import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ import scala.io.Source +import scala.reflect.ClassTag import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder @@ -319,7 +320,7 @@ private[spark] object Utils extends Logging { * result in a new collection. Unlike scala.util.Random.shuffle, this method * uses a local random number generator, avoiding inter-thread contention. */ - def randomize[T: ClassManifest](seq: TraversableOnce[T]): Seq[T] = { + def randomize[T: ClassTag](seq: TraversableOnce[T]): Seq[T] = { randomizeInPlace(seq.toArray) } @@ -823,4 +824,28 @@ private[spark] object Utils extends Logging { return System.getProperties().clone() .asInstanceOf[java.util.Properties].toMap[String, String] } + + /** + * Method executed for repeating a task for side effects. + * Unlike a for comprehension, it permits JVM JIT optimization + */ + def times(numIters: Int)(f: => Unit): Unit = { + var i = 0 + while (i < numIters) { + f + i += 1 + } + } + + /** + * Timing method based on iterations that permit JVM JIT optimization. + * @param numIters number of iterations + * @param f function to be executed + */ + def timeIt(numIters: Int)(f: => Unit): Long = { + val start = System.currentTimeMillis + times(numIters)(f) + System.currentTimeMillis - start + } + } diff --git a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala new file mode 100644 index 0000000000..e9907e6c85 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.{Random => JavaRandom} +import org.apache.spark.util.Utils.timeIt + +/** + * This class implements a XORShift random number generator algorithm + * Source: + * Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software, Vol. 8, Issue 14. + * @see Paper + * This implementation is approximately 3.5 times faster than + * {@link java.util.Random java.util.Random}, partly because of the algorithm, but also due + * to renouncing thread safety. JDK's implementation uses an AtomicLong seed, this class + * uses a regular Long. We can forgo thread safety since we use a new instance of the RNG + * for each thread. + */ +private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) { + + def this() = this(System.nanoTime) + + private var seed = init + + // we need to just override next - this will be called by nextInt, nextDouble, + // nextGaussian, nextLong, etc. + override protected def next(bits: Int): Int = { + var nextSeed = seed ^ (seed << 21) + nextSeed ^= (nextSeed >>> 35) + nextSeed ^= (nextSeed << 4) + seed = nextSeed + (nextSeed & ((1L << bits) -1)).asInstanceOf[Int] + } +} + +/** Contains benchmark method and main method to run benchmark of the RNG */ +private[spark] object XORShiftRandom { + + /** + * Main method for running benchmark + * @param args takes one argument - the number of random numbers to generate + */ + def main(args: Array[String]): Unit = { + if (args.length != 1) { + println("Benchmark of XORShiftRandom vis-a-vis java.util.Random") + println("Usage: XORShiftRandom number_of_random_numbers_to_generate") + System.exit(1) + } + println(benchmark(args(0).toInt)) + } + + /** + * @param numIters Number of random numbers to generate while running the benchmark + * @return Map of execution times for {@link java.util.Random java.util.Random} + * and XORShift + */ + def benchmark(numIters: Int) = { + + val seed = 1L + val million = 1e6.toInt + val javaRand = new JavaRandom(seed) + val xorRand = new XORShiftRandom(seed) + + // this is just to warm up the JIT - we're not timing anything + timeIt(1e6.toInt) { + javaRand.nextInt() + xorRand.nextInt() + } + + val iters = timeIt(numIters)(_) + + /* Return results as a map instead of just printing to screen + in case the user wants to do something with them */ + Map("javaTime" -> iters {javaRand.nextInt()}, + "xorTime" -> iters {xorRand.nextInt()}) + + } + +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala index 80545c9688..c26f23d500 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala @@ -17,6 +17,7 @@ package org.apache.spark.util.collection +import scala.reflect.ClassTag /** * A fast hash map implementation for nullable keys. This hash map supports insertions and updates, @@ -26,7 +27,7 @@ package org.apache.spark.util.collection * Under the hood, it uses our OpenHashSet implementation. */ private[spark] -class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V: ClassManifest]( +class OpenHashMap[K >: Null : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( initialCapacity: Int) extends Iterable[(K, V)] with Serializable { diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 4592e4f939..87e009a4de 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -17,6 +17,7 @@ package org.apache.spark.util.collection +import scala.reflect._ /** * A simple, fast hash set optimized for non-null insertion-only use case, where keys are never @@ -36,7 +37,7 @@ package org.apache.spark.util.collection * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing). */ private[spark] -class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( +class OpenHashSet[@specialized(Long, Int) T: ClassTag]( initialCapacity: Int, loadFactor: Double) extends Serializable { @@ -62,14 +63,14 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( // throws: // scala.tools.nsc.symtab.Types$TypeError: type mismatch; // found : scala.reflect.AnyValManifest[Long] - // required: scala.reflect.ClassManifest[Int] + // required: scala.reflect.ClassTag[Int] // at scala.tools.nsc.typechecker.Contexts$Context.error(Contexts.scala:298) // at scala.tools.nsc.typechecker.Infer$Inferencer.error(Infer.scala:207) // ... - val mt = classManifest[T] - if (mt == ClassManifest.Long) { + val mt = classTag[T] + if (mt == ClassTag.Long) { (new LongHasher).asInstanceOf[Hasher[T]] - } else if (mt == ClassManifest.Int) { + } else if (mt == ClassTag.Int) { (new IntHasher).asInstanceOf[Hasher[T]] } else { new Hasher[T] @@ -79,6 +80,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( protected var _capacity = nextPowerOf2(initialCapacity) protected var _mask = _capacity - 1 protected var _size = 0 + protected var _growThreshold = (loadFactor * _capacity).toInt protected var _bitset = new BitSet(_capacity) @@ -115,7 +117,29 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( * @return The position where the key is placed, plus the highest order bit is set if the key * exists previously. */ - def addWithoutResize(k: T): Int = putInto(_bitset, _data, k) + def addWithoutResize(k: T): Int = { + var pos = hashcode(hasher.hash(k)) & _mask + var i = 1 + while (true) { + if (!_bitset.get(pos)) { + // This is a new key. + _data(pos) = k + _bitset.set(pos) + _size += 1 + return pos | NONEXISTENCE_MASK + } else if (_data(pos) == k) { + // Found an existing key. + return pos + } else { + val delta = i + pos = (pos + delta) & _mask + i += 1 + } + } + // Never reached here + assert(INVALID_POS != INVALID_POS) + INVALID_POS + } /** * Rehash the set if it is overloaded. @@ -126,7 +150,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( * to a new position (in the new data array). */ def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) { - if (_size > loadFactor * _capacity) { + if (_size > _growThreshold) { rehash(k, allocateFunc, moveFunc) } } @@ -160,37 +184,6 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( */ def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos) - /** - * Put an entry into the set. Return the position where the key is placed. In addition, the - * highest bit in the returned position is set if the key exists prior to this put. - * - * This function assumes the data array has at least one empty slot. - */ - private def putInto(bitset: BitSet, data: Array[T], k: T): Int = { - val mask = data.length - 1 - var pos = hashcode(hasher.hash(k)) & mask - var i = 1 - while (true) { - if (!bitset.get(pos)) { - // This is a new key. - data(pos) = k - bitset.set(pos) - _size += 1 - return pos | NONEXISTENCE_MASK - } else if (data(pos) == k) { - // Found an existing key. - return pos - } else { - val delta = i - pos = (pos + delta) & mask - i += 1 - } - } - // Never reached here - assert(INVALID_POS != INVALID_POS) - INVALID_POS - } - /** * Double the table's size and re-hash everything. We are not really using k, but it is declared * so Scala compiler can specialize this method (which leads to calling the specialized version @@ -204,34 +197,49 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( */ private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) { val newCapacity = _capacity * 2 - require(newCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") - allocateFunc(newCapacity) - val newData = new Array[T](newCapacity) val newBitset = new BitSet(newCapacity) - var pos = 0 - _size = 0 - while (pos < _capacity) { - if (_bitset.get(pos)) { - val newPos = putInto(newBitset, newData, _data(pos)) - moveFunc(pos, newPos & POSITION_MASK) + val newData = new Array[T](newCapacity) + val newMask = newCapacity - 1 + + var oldPos = 0 + while (oldPos < capacity) { + if (_bitset.get(oldPos)) { + val key = _data(oldPos) + var newPos = hashcode(hasher.hash(key)) & newMask + var i = 1 + var keepGoing = true + // No need to check for equality here when we insert so this has one less if branch than + // the similar code path in addWithoutResize. + while (keepGoing) { + if (!newBitset.get(newPos)) { + // Inserting the key at newPos + newData(newPos) = key + newBitset.set(newPos) + moveFunc(oldPos, newPos) + keepGoing = false + } else { + val delta = i + newPos = (newPos + delta) & newMask + i += 1 + } + } } - pos += 1 + oldPos += 1 } + _bitset = newBitset _data = newData _capacity = newCapacity - _mask = newCapacity - 1 + _mask = newMask + _growThreshold = (loadFactor * newCapacity).toInt } /** - * Re-hash a value to deal better with hash functions that don't differ - * in the lower bits, similar to java.util.HashMap + * Re-hash a value to deal better with hash functions that don't differ in the lower bits. + * We use the Murmur Hash 3 finalization step that's also used in fastutil. */ - private def hashcode(h: Int): Int = { - val r = h ^ (h >>> 20) ^ (h >>> 12) - r ^ (r >>> 7) ^ (r >>> 4) - } + private def hashcode(h: Int): Int = it.unimi.dsi.fastutil.HashCommon.murmurHash3(h) private def nextPowerOf2(n: Int): Int = { val highBit = Integer.highestOneBit(n) diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala index d76143e45a..2e1ef06cbc 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala @@ -17,6 +17,7 @@ package org.apache.spark.util.collection +import scala.reflect._ /** * A fast hash map implementation for primitive, non-null keys. This hash map supports @@ -26,15 +27,15 @@ package org.apache.spark.util.collection * Under the hood, it uses our OpenHashSet implementation. */ private[spark] -class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest, - @specialized(Long, Int, Double) V: ClassManifest]( +class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, + @specialized(Long, Int, Double) V: ClassTag]( initialCapacity: Int) extends Iterable[(K, V)] with Serializable { def this() = this(64) - require(classManifest[K] == classManifest[Long] || classManifest[K] == classManifest[Int]) + require(classTag[K] == classTag[Long] || classTag[K] == classTag[Int]) // Init in constructor (instead of in declaration) to work around a Scala compiler specialization // bug that would generate two arrays (one for Object and one for specialized T). diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala index 20554f0aab..b84eb65c62 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala @@ -17,11 +17,13 @@ package org.apache.spark.util.collection +import scala.reflect.ClassTag + /** * An append-only, non-threadsafe, array-backed vector that is optimized for primitive types. */ private[spark] -class PrimitiveVector[@specialized(Long, Int, Double) V: ClassManifest](initialSize: Int = 64) { +class PrimitiveVector[@specialized(Long, Int, Double) V: ClassTag](initialSize: Int = 64) { private var _numElements = 0 private var _array: Array[V] = _ diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 4434f3b87c..c443c5266e 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -27,6 +27,21 @@ import org.apache.spark.SparkContext._ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkContext { + + implicit def setAccum[A] = new AccumulableParam[mutable.Set[A], A] { + def addInPlace(t1: mutable.Set[A], t2: mutable.Set[A]) : mutable.Set[A] = { + t1 ++= t2 + t1 + } + def addAccumulator(t1: mutable.Set[A], t2: A) : mutable.Set[A] = { + t1 += t2 + t1 + } + def zero(t: mutable.Set[A]) : mutable.Set[A] = { + new mutable.HashSet[A]() + } + } + test ("basic accumulation"){ sc = new SparkContext("local", "test") val acc : Accumulator[Int] = sc.accumulator(0) @@ -51,7 +66,6 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte } test ("add value to collection accumulators") { - import SetAccum._ val maxI = 1000 for (nThreads <- List(1, 10)) { //test single & multi-threaded sc = new SparkContext("local[" + nThreads + "]", "test") @@ -68,22 +82,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte } } - implicit object SetAccum extends AccumulableParam[mutable.Set[Any], Any] { - def addInPlace(t1: mutable.Set[Any], t2: mutable.Set[Any]) : mutable.Set[Any] = { - t1 ++= t2 - t1 - } - def addAccumulator(t1: mutable.Set[Any], t2: Any) : mutable.Set[Any] = { - t1 += t2 - t1 - } - def zero(t: mutable.Set[Any]) : mutable.Set[Any] = { - new mutable.HashSet[Any]() - } - } - test ("value not readable in tasks") { - import SetAccum._ val maxI = 1000 for (nThreads <- List(1, 10)) { //test single & multi-threaded sc = new SparkContext("local[" + nThreads + "]", "test") @@ -125,7 +124,6 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte } test ("localValue readable in tasks") { - import SetAccum._ val maxI = 1000 for (nThreads <- List(1, 10)) { //test single & multi-threaded sc = new SparkContext("local[" + nThreads + "]", "test") diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index f26c44d3e7..f25d921d3f 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark +import scala.reflect.ClassTag import org.scalatest.FunSuite import java.io.File import org.apache.spark.rdd._ @@ -62,8 +63,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { testCheckpointing(_.sample(false, 0.5, 0)) testCheckpointing(_.glom()) testCheckpointing(_.mapPartitions(_.map(_.toString))) - testCheckpointing(r => new MapPartitionsWithContextRDD(r, - (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), false )) testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) testCheckpointing(_.pipe(Seq("cat"))) @@ -207,7 +206,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { * not, but this is not done by default as usually the partitions do not refer to any RDD and * therefore never store the lineage. */ - def testCheckpointing[U: ClassManifest]( + def testCheckpointing[U: ClassTag]( op: (RDD[Int]) => RDD[U], testRDDSize: Boolean = true, testRDDPartitionSize: Boolean = false @@ -276,7 +275,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed, * this RDD will remember the partitions and therefore potentially the whole lineage. */ - def testParentCheckpointing[U: ClassManifest]( + def testParentCheckpointing[U: ClassTag]( op: (RDD[Int]) => RDD[U], testRDDSize: Boolean, testRDDPartitionSize: Boolean diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 480bac84f3..d9cb7fead5 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -122,7 +122,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter sc.parallelize(1 to 10, 10).foreach(x => println(x / 0)) } assert(thrown.getClass === classOf[SparkException]) - assert(thrown.getMessage.contains("more than 4 times")) + assert(thrown.getMessage.contains("failed 4 times")) } test("caching") { @@ -303,12 +303,13 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter Thread.sleep(200) } } catch { - case _ => { Thread.sleep(10) } + case _: Throwable => { Thread.sleep(10) } // Do nothing. We might see exceptions because block manager // is racing this thread to remove entries from the driver. } } } + } object DistributedSuite { diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index 01a72d8401..6d1695eae7 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -34,7 +34,7 @@ class DriverSuite extends FunSuite with Timeouts { // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) forAll(masters) { (master: String) => - failAfter(30 seconds) { + failAfter(60 seconds) { Utils.execute(Seq("./spark-class", "org.apache.spark.DriverWithoutCleanup", master), new File(System.getenv("SPARK_HOME"))) } diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java index 352036f182..4234f6eac7 100644 --- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java @@ -364,6 +364,20 @@ public class JavaAPISuite implements Serializable { List take = rdd.take(5); } + @Test + public void javaDoubleRDDHistoGram() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + // Test using generated buckets + Tuple2 results = rdd.histogram(2); + double[] expected_buckets = {1.0, 2.5, 4.0}; + long[] expected_counts = {2, 2}; + Assert.assertArrayEquals(expected_buckets, results._1, 0.1); + Assert.assertArrayEquals(expected_counts, results._2); + // Test with provided buckets + long[] histogram = rdd.histogram(expected_buckets); + Assert.assertArrayEquals(expected_counts, histogram); + } + @Test public void map() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index d8a0e983b2..1121e06e2e 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -114,7 +114,7 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf // Once A is cancelled, job B should finish fairly quickly. assert(jobB.get() === 100) } - +/* test("two jobs sharing the same stage") { // sem1: make sure cancel is issued after some tasks are launched // sem2: make sure the first stage is not finished until cancel is issued @@ -148,7 +148,7 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf intercept[SparkException] { f1.get() } intercept[SparkException] { f2.get() } } - + */ def testCount() { // Cancel before launching any tasks { diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 459e257d79..8dd5786da6 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -30,7 +30,7 @@ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self @transient var sc: SparkContext = _ override def beforeAll() { - InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()); + InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()) super.beforeAll() } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index b7eb268bd5..271dc905bc 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -49,14 +49,14 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("master start and stop") { val actorSystem = ActorSystem("test") val tracker = new MapOutputTrackerMaster() - tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))) + tracker.trackerActor = Left(actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))) tracker.stop() } test("master register and fetch") { val actorSystem = ActorSystem("test") val tracker = new MapOutputTrackerMaster() - tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))) + tracker.trackerActor = Left(actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))) tracker.registerShuffle(10, 2) val compressedSize1000 = MapOutputTracker.compressSize(1000L) val compressedSize10000 = MapOutputTracker.compressSize(10000L) @@ -75,7 +75,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("master register and unregister and fetch") { val actorSystem = ActorSystem("test") val tracker = new MapOutputTrackerMaster() - tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))) + tracker.trackerActor = Left(actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))) tracker.registerShuffle(10, 2) val compressedSize1000 = MapOutputTracker.compressSize(1000L) val compressedSize10000 = MapOutputTracker.compressSize(10000L) @@ -101,13 +101,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { System.setProperty("spark.hostPort", hostname + ":" + boundPort) val masterTracker = new MapOutputTrackerMaster() - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") + masterTracker.trackerActor = Left(actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")) val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0) val slaveTracker = new MapOutputTracker() - slaveTracker.trackerActor = slaveSystem.actorFor( - "akka://spark@localhost:" + boundPort + "/user/MapOutputTracker") + slaveTracker.trackerActor = Right(slaveSystem.actorSelection( + "akka.tcp://spark@localhost:" + boundPort + "/user/MapOutputTracker")) masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() diff --git a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala deleted file mode 100644 index 21f16ef2c6..0000000000 --- a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import org.scalatest.FunSuite -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.{RDD, PartitionPruningRDD} - - -class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { - - test("Pruned Partitions inherit locality prefs correctly") { - class TestPartition(i: Int) extends Partition { - def index = i - } - val rdd = new RDD[Int](sc, Nil) { - override protected def getPartitions = { - Array[Partition]( - new TestPartition(1), - new TestPartition(2), - new TestPartition(3)) - } - def compute(split: Partition, context: TaskContext) = {Iterator()} - } - val prunedRDD = PartitionPruningRDD.create(rdd, {x => if (x==2) true else false}) - val p = prunedRDD.partitions(0) - assert(p.index == 2) - assert(prunedRDD.partitions.length == 1) - } -} diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 7d938917f2..1374d01774 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -142,11 +142,11 @@ class PartitioningSuite extends FunSuite with SharedSparkContext { .filter(_ >= 0.0) // Run the partitions, including the consecutive empty ones, through StatCounter - val stats: StatCounter = rdd.stats(); - assert(abs(6.0 - stats.sum) < 0.01); - assert(abs(6.0/2 - rdd.mean) < 0.01); - assert(abs(1.0 - rdd.variance) < 0.01); - assert(abs(1.0 - rdd.stdev) < 0.01); + val stats: StatCounter = rdd.stats() + assert(abs(6.0 - stats.sum) < 0.01) + assert(abs(6.0/2 - rdd.mean) < 0.01) + assert(abs(1.0 - rdd.variance) < 0.01) + assert(abs(1.0 - rdd.stdev) < 0.01) // Add other tests here for classes that should be able to handle empty partitions correctly } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala new file mode 100644 index 0000000000..151af0d213 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.{FunSuite, PrivateMethodTester} + +import org.apache.spark.scheduler.TaskScheduler +import org.apache.spark.scheduler.cluster.{ClusterScheduler, SimrSchedulerBackend, SparkDeploySchedulerBackend} +import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} +import org.apache.spark.scheduler.local.LocalScheduler + +class SparkContextSchedulerCreationSuite + extends FunSuite with PrivateMethodTester with LocalSparkContext with Logging { + + def createTaskScheduler(master: String): TaskScheduler = { + // Create local SparkContext to setup a SparkEnv. We don't actually want to start() the + // real schedulers, so we don't want to create a full SparkContext with the desired scheduler. + sc = new SparkContext("local", "test") + val createTaskSchedulerMethod = PrivateMethod[TaskScheduler]('createTaskScheduler) + SparkContext invokePrivate createTaskSchedulerMethod(sc, master, "test") + } + + test("bad-master") { + val e = intercept[SparkException] { + createTaskScheduler("localhost:1234") + } + assert(e.getMessage.contains("Could not parse Master URL")) + } + + test("local") { + createTaskScheduler("local") match { + case s: LocalScheduler => + assert(s.threads === 1) + assert(s.maxFailures === 0) + case _ => fail() + } + } + + test("local-n") { + createTaskScheduler("local[5]") match { + case s: LocalScheduler => + assert(s.threads === 5) + assert(s.maxFailures === 0) + case _ => fail() + } + } + + test("local-n-failures") { + createTaskScheduler("local[4, 2]") match { + case s: LocalScheduler => + assert(s.threads === 4) + assert(s.maxFailures === 2) + case _ => fail() + } + } + + test("simr") { + createTaskScheduler("simr://uri") match { + case s: ClusterScheduler => + assert(s.backend.isInstanceOf[SimrSchedulerBackend]) + case _ => fail() + } + } + + test("local-cluster") { + createTaskScheduler("local-cluster[3, 14, 512]") match { + case s: ClusterScheduler => + assert(s.backend.isInstanceOf[SparkDeploySchedulerBackend]) + case _ => fail() + } + } + + def testYarn(master: String, expectedClassName: String) { + try { + createTaskScheduler(master) match { + case s: ClusterScheduler => + assert(s.getClass === Class.forName(expectedClassName)) + case _ => fail() + } + } catch { + case e: SparkException => + assert(e.getMessage.contains("YARN mode not available")) + logWarning("YARN not available, could not test actual YARN scheduler creation") + case e: Throwable => fail(e) + } + } + + test("yarn-standalone") { + testYarn("yarn-standalone", "org.apache.spark.scheduler.cluster.YarnClusterScheduler") + } + + test("yarn-client") { + testYarn("yarn-client", "org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") + } + + def testMesos(master: String, expectedClass: Class[_]) { + try { + createTaskScheduler(master) match { + case s: ClusterScheduler => + assert(s.backend.getClass === expectedClass) + case _ => fail() + } + } catch { + case e: UnsatisfiedLinkError => + assert(e.getMessage.contains("no mesos in")) + logWarning("Mesos not available, could not test actual Mesos scheduler creation") + case e: Throwable => fail(e) + } + } + + test("mesos fine-grained") { + System.setProperty("spark.mesos.coarse", "false") + testMesos("mesos://localhost:1234", classOf[MesosSchedulerBackend]) + } + + test("mesos coarse-grained") { + System.setProperty("spark.mesos.coarse", "true") + testMesos("mesos://localhost:1234", classOf[CoarseMesosSchedulerBackend]) + } + + test("mesos with zookeeper") { + System.setProperty("spark.mesos.coarse", "false") + testMesos("zk://localhost:1234,localhost:2345", classOf[MesosSchedulerBackend]) + } +} diff --git a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala index 46a2da1724..768ca3850e 100644 --- a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala +++ b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala @@ -37,7 +37,7 @@ class UnpersistSuite extends FunSuite with LocalSparkContext { Thread.sleep(200) } } catch { - case _ => { Thread.sleep(10) } + case _: Throwable => { Thread.sleep(10) } // Do nothing. We might see exceptions because block manager // is racing this thread to remove entries from the driver. } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index 8f0954122b..4cb4ddc9cd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.deploy.worker import java.io.File diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index da032b17d9..0d4c10db8e 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.rdd import java.util.concurrent.Semaphore +import scala.concurrent.{Await, TimeoutException} +import scala.concurrent.duration.Duration import scala.concurrent.ExecutionContext.Implicits.global import org.scalatest.{BeforeAndAfterAll, FunSuite} @@ -173,4 +175,28 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll with Timeouts sem.acquire(2) } } + + /** + * Awaiting FutureAction results + */ + test("FutureAction result, infinite wait") { + val f = sc.parallelize(1 to 100, 4) + .countAsync() + assert(Await.result(f, Duration.Inf) === 100) + } + + test("FutureAction result, finite wait") { + val f = sc.parallelize(1 to 100, 4) + .countAsync() + assert(Await.result(f, Duration(30, "seconds")) === 100) + } + + test("FutureAction result, timeout") { + val f = sc.parallelize(1 to 100, 4) + .mapPartitions(itr => { Thread.sleep(20); itr }) + .countAsync() + intercept[TimeoutException] { + Await.result(f, Duration(20, "milliseconds")) + } + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala new file mode 100644 index 0000000000..7f50a5a47c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.math.abs +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd._ +import org.apache.spark._ + +class DoubleRDDSuite extends FunSuite with SharedSparkContext { + // Verify tests on the histogram functionality. We test with both evenly + // and non-evenly spaced buckets as the bucket lookup function changes. + test("WorksOnEmpty") { + // Make sure that it works on an empty input + val rdd: RDD[Double] = sc.parallelize(Seq()) + val buckets = Array(0.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(0) + assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) + } + + test("WorksWithOutOfRangeWithOneBucket") { + // Verify that if all of the elements are out of range the counts are zero + val rdd = sc.parallelize(Seq(10.01, -0.01)) + val buckets = Array(0.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(0) + assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) + } + + test("WorksInRangeWithOneBucket") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd = sc.parallelize(Seq(1, 2, 3, 4)) + val buckets = Array(0.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(4) + assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) + } + + test("WorksInRangeWithOneBucketExactMatch") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd = sc.parallelize(Seq(1, 2, 3, 4)) + val buckets = Array(1.0, 4.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(4) + assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) + } + + test("WorksWithOutOfRangeWithTwoBuckets") { + // Verify that out of range works with two buckets + val rdd = sc.parallelize(Seq(10.01, -0.01)) + val buckets = Array(0.0, 5.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(0, 0) + assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) + } + + test("WorksWithOutOfRangeWithTwoUnEvenBuckets") { + // Verify that out of range works with two un even buckets + val rdd = sc.parallelize(Seq(10.01, -0.01)) + val buckets = Array(0.0, 4.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(0, 0) + assert(histogramResults === expectedHistogramResults) + } + + test("WorksInRangeWithTwoBuckets") { + // Make sure that it works with two equally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6)) + val buckets = Array(0.0, 5.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(3, 2) + assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) + } + + test("WorksInRangeWithTwoBucketsAndNaN") { + // Make sure that it works with two equally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6, Double.NaN)) + val buckets = Array(0.0, 5.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(3, 2) + assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) + } + + test("WorksInRangeWithTwoUnevenBuckets") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6)) + val buckets = Array(0.0, 5.0, 11.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(3, 2) + assert(histogramResults === expectedHistogramResults) + } + + test("WorksMixedRangeWithTwoUnevenBuckets") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01)) + val buckets = Array(0.0, 5.0, 11.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(4, 3) + assert(histogramResults === expectedHistogramResults) + } + + test("WorksMixedRangeWithFourUnevenBuckets") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + 200.0, 200.1)) + val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(4, 2, 1, 3) + assert(histogramResults === expectedHistogramResults) + } + + test("WorksMixedRangeWithUnevenBucketsAndNaN") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + 200.0, 200.1, Double.NaN)) + val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(4, 2, 1, 3) + assert(histogramResults === expectedHistogramResults) + } + // Make sure this works with a NaN end bucket + test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRange") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + 200.0, 200.1, Double.NaN)) + val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(4, 2, 1, 2, 3) + assert(histogramResults === expectedHistogramResults) + } + // Make sure this works with a NaN end bucket and an inifity + test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRangeAndInfity") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + 200.0, 200.1, 1.0/0.0, -1.0/0.0, Double.NaN)) + val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(4, 2, 1, 2, 4) + assert(histogramResults === expectedHistogramResults) + } + + test("WorksWithOutOfRangeWithInfiniteBuckets") { + // Verify that out of range works with two buckets + val rdd = sc.parallelize(Seq(10.01, -0.01, Double.NaN)) + val buckets = Array(-1.0/0.0 , 0.0, 1.0/0.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(1, 1) + assert(histogramResults === expectedHistogramResults) + } + // Test the failure mode with an invalid bucket array + test("ThrowsExceptionOnInvalidBucketArray") { + val rdd = sc.parallelize(Seq(1.0)) + // Empty array + intercept[IllegalArgumentException] { + val buckets = Array.empty[Double] + val result = rdd.histogram(buckets) + } + // Single element array + intercept[IllegalArgumentException] { + val buckets = Array(1.0) + val result = rdd.histogram(buckets) + } + } + + // Test automatic histogram function + test("WorksWithoutBucketsBasic") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd = sc.parallelize(Seq(1, 2, 3, 4)) + val (histogramBuckets, histogramResults) = rdd.histogram(1) + val expectedHistogramResults = Array(4) + val expectedHistogramBuckets = Array(1.0, 4.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + // Test automatic histogram function with a single element + test("WorksWithoutBucketsBasicSingleElement") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd = sc.parallelize(Seq(1)) + val (histogramBuckets, histogramResults) = rdd.histogram(1) + val expectedHistogramResults = Array(1) + val expectedHistogramBuckets = Array(1.0, 1.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + // Test automatic histogram function with a single element + test("WorksWithoutBucketsBasicNoRange") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd = sc.parallelize(Seq(1, 1, 1, 1)) + val (histogramBuckets, histogramResults) = rdd.histogram(1) + val expectedHistogramResults = Array(4) + val expectedHistogramBuckets = Array(1.0, 1.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + + test("WorksWithoutBucketsBasicTwo") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd = sc.parallelize(Seq(1, 2, 3, 4)) + val (histogramBuckets, histogramResults) = rdd.histogram(2) + val expectedHistogramResults = Array(2, 2) + val expectedHistogramBuckets = Array(1.0, 2.5, 4.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + + test("WorksWithoutBucketsWithMoreRequestedThanElements") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd = sc.parallelize(Seq(1, 2)) + val (histogramBuckets, histogramResults) = rdd.histogram(10) + val expectedHistogramResults = + Array(1, 0, 0, 0, 0, 0, 0, 0, 0, 1) + val expectedHistogramBuckets = + Array(1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + + // Test the failure mode with an invalid RDD + test("ThrowsExceptionOnInvalidRDDs") { + // infinity + intercept[UnsupportedOperationException] { + val rdd = sc.parallelize(Seq(1, 1.0/0.0)) + val result = rdd.histogram(1) + } + // NaN + intercept[UnsupportedOperationException] { + val rdd = sc.parallelize(Seq(1, Double.NaN)) + val result = rdd.histogram(1) + } + // Empty + intercept[UnsupportedOperationException] { + val rdd: RDD[Double] = sc.parallelize(Seq()) + val result = rdd.histogram(1) + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala new file mode 100644 index 0000000000..53a7b7c44d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.scalatest.FunSuite +import org.apache.spark.{TaskContext, Partition, SharedSparkContext} + + +class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { + + + test("Pruned Partitions inherit locality prefs correctly") { + + val rdd = new RDD[Int](sc, Nil) { + override protected def getPartitions = { + Array[Partition]( + new TestPartition(0, 1), + new TestPartition(1, 1), + new TestPartition(2, 1)) + } + + def compute(split: Partition, context: TaskContext) = { + Iterator() + } + } + val prunedRDD = PartitionPruningRDD.create(rdd, { + x => if (x == 2) true else false + }) + assert(prunedRDD.partitions.length == 1) + val p = prunedRDD.partitions(0) + assert(p.index == 0) + assert(p.asInstanceOf[PartitionPruningRDDPartition].parentSplit.index == 2) + } + + + test("Pruned Partitions can be unioned ") { + + val rdd = new RDD[Int](sc, Nil) { + override protected def getPartitions = { + Array[Partition]( + new TestPartition(0, 4), + new TestPartition(1, 5), + new TestPartition(2, 6)) + } + + def compute(split: Partition, context: TaskContext) = { + List(split.asInstanceOf[TestPartition].testValue).iterator + } + } + val prunedRDD1 = PartitionPruningRDD.create(rdd, { + x => if (x == 0) true else false + }) + + val prunedRDD2 = PartitionPruningRDD.create(rdd, { + x => if (x == 2) true else false + }) + + val merged = prunedRDD1 ++ prunedRDD2 + assert(merged.count() == 2) + val take = merged.take(2) + assert(take.apply(0) == 4) + assert(take.apply(1) == 6) + } +} + +class TestPartition(i: Int, value: Int) extends Partition with Serializable { + def index = i + + def testValue = this.value + +} diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 88b36a6855..964ca9e91d 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -271,8 +271,8 @@ class RDDSuite extends FunSuite with SharedSparkContext { // test that you get over 90% locality in each group val minLocality = coalesced2.partitions .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) - .foldLeft(1.)((perc, loc) => math.min(perc,loc)) - assert(minLocality >= 0.90, "Expected 90% locality but got " + (minLocality*100.).toInt + "%") + .foldLeft(1.0)((perc, loc) => math.min(perc,loc)) + assert(minLocality >= 0.90, "Expected 90% locality but got " + (minLocality*100.0).toInt + "%") // test that the groups are load balanced with 100 +/- 20 elements in each val maxImbalance = coalesced2.partitions @@ -284,9 +284,9 @@ class RDDSuite extends FunSuite with SharedSparkContext { val coalesced3 = data3.coalesce(numMachines*2) val minLocality2 = coalesced3.partitions .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) - .foldLeft(1.)((perc, loc) => math.min(perc,loc)) + .foldLeft(1.0)((perc, loc) => math.min(perc,loc)) assert(minLocality2 >= 0.90, "Expected 90% locality for derived RDD but got " + - (minLocality2*100.).toInt + "%") + (minLocality2*100.0).toInt + "%") } test("zipped RDDs") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index a4d41ebbff..706d84a58b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -206,6 +206,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont submit(rdd, Array(0)) complete(taskSets(0), List((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("local job") { @@ -219,6 +220,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val jobId = scheduler.nextJobId.getAndIncrement() runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, listener)) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("run trivial job w/ dependency") { @@ -227,6 +229,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont submit(finalRdd, Array(0)) complete(taskSets(0), Seq((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("cache location preferences w/ dependency") { @@ -239,12 +242,14 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont assertLocations(taskSet, Seq(Seq("hostA", "hostB"))) complete(taskSet, Seq((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("trivial job failure") { submit(makeRdd(1, Nil), Array(0)) failed(taskSets(0), "some failure") assert(failure.getMessage === "Job aborted: some failure") + assertDataStructuresEmpty } test("run trivial shuffle") { @@ -260,6 +265,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("run trivial shuffle with fetch failure") { @@ -285,6 +291,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) + assertDataStructuresEmpty } test("ignore late map task completions") { @@ -313,6 +320,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) complete(taskSets(1), Seq((Success, 42), (Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) + assertDataStructuresEmpty } test("run trivial shuffle with out-of-band failure and retry") { @@ -329,15 +337,16 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) - // have hostC complete the resubmitted task - complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) - complete(taskSets(2), Seq((Success, 42))) - assert(results === Map(0 -> 42)) - } + // have hostC complete the resubmitted task + complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + complete(taskSets(2), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assertDataStructuresEmpty + } - test("recursive shuffle failures") { + test("recursive shuffle failures") { val shuffleOneRdd = makeRdd(2, Nil) val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne)) @@ -363,6 +372,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1)))) complete(taskSets(5), Seq((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("cached post-shuffle") { @@ -394,6 +404,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1)))) complete(taskSets(4), Seq((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } /** @@ -413,4 +424,18 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont private def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345, 0) + private def assertDataStructuresEmpty = { + assert(scheduler.pendingTasks.isEmpty) + assert(scheduler.activeJobs.isEmpty) + assert(scheduler.failed.isEmpty) + assert(scheduler.idToActiveJob.isEmpty) + assert(scheduler.jobIdToStageIds.isEmpty) + assert(scheduler.stageIdToJobIds.isEmpty) + assert(scheduler.stageIdToStage.isEmpty) + assert(scheduler.stageToInfos.isEmpty) + assert(scheduler.resultStageToJob.isEmpty) + assert(scheduler.running.isEmpty) + assert(scheduler.shuffleToMapStage.isEmpty) + assert(scheduler.waiting.isEmpty) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala index 984881861c..002368ff55 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.rdd.RDD class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers { + val WAIT_TIMEOUT_MILLIS = 10000 test("inner method") { sc = new SparkContext("local", "joblogger") @@ -92,6 +93,8 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } rdd.reduceByKey(_+_).collect() + assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + val user = System.getProperty("user.name", SparkContext.SPARK_UNKNOWN_USER) joblogger.getLogDir should be ("/tmp/spark-%s".format(user)) @@ -120,7 +123,9 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers sc.addSparkListener(joblogger) val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } rdd.reduceByKey(_+_).collect() - + + assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + joblogger.onJobStartCount should be (1) joblogger.onJobEndCount should be (1) joblogger.onTaskEndCount should be (8) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 1fd76420ea..2e41438a52 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -145,7 +145,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc // Make a task whose result is larger than the akka frame size System.setProperty("spark.akka.frameSize", "1") val akkaFrameSize = - sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt + sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x,y) => x) assert(result === 1.to(akkaFrameSize).toArray) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala index b97f2b19b5..29c4cc5d9c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala @@ -283,7 +283,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo // Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted // after the last failure. - (0 until manager.MAX_TASK_FAILURES).foreach { index => + (1 to manager.MAX_TASK_FAILURES).foreach { index => val offerResult = manager.resourceOffer("exec1", "host1", 1, ANY) assert(offerResult != None, "Expect resource offer on iteration %s to return a task".format(index)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala index ee150a3107..27c2d53361 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala @@ -82,7 +82,7 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA test("handling results larger than Akka frame size") { val akkaFrameSize = - sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt + sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) assert(result === 1.to(akkaFrameSize).toArray) @@ -103,7 +103,7 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA } scheduler.taskResultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler) val akkaFrameSize = - sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt + sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) assert(result === 1.to(akkaFrameSize).toArray) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala index cb76275e39..b647e8a672 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala @@ -39,7 +39,7 @@ class BlockIdSuite extends FunSuite { fail() } catch { case e: IllegalStateException => // OK - case _ => fail() + case _: Throwable => fail() } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 484a654108..5b4d63b954 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -56,7 +56,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT System.setProperty("spark.hostPort", "localhost:" + boundPort) master = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true)))) + Left(actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))) // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case oldArch = System.setProperty("os.arch", "amd64") diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 0b9056344c..070982e798 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.storage import java.io.{FileWriter, File} @@ -5,9 +22,9 @@ import java.io.{FileWriter, File} import scala.collection.mutable import com.google.common.io.Files -import org.scalatest.{BeforeAndAfterEach, FunSuite} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} -class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach { +class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll { val rootDir0 = Files.createTempDir() rootDir0.deleteOnExit() @@ -16,6 +33,12 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach { val rootDirs = rootDir0.getName + "," + rootDir1.getName println("Created root dirs: " + rootDirs) + // This suite focuses primarily on consolidation features, + // so we coerce consolidation if not already enabled. + val consolidateProp = "spark.shuffle.consolidateFiles" + val oldConsolidate = Option(System.getProperty(consolidateProp)) + System.setProperty(consolidateProp, "true") + val shuffleBlockManager = new ShuffleBlockManager(null) { var idToSegmentMap = mutable.Map[ShuffleBlockId, FileSegment]() override def getBlockLocation(id: ShuffleBlockId) = idToSegmentMap(id) @@ -23,6 +46,10 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach { var diskBlockManager: DiskBlockManager = _ + override def afterAll() { + oldConsolidate.map(c => System.setProperty(consolidateProp, c)) + } + override def beforeEach() { diskBlockManager = new DiskBlockManager(shuffleBlockManager, rootDirs) shuffleBlockManager.idToSegmentMap.clear() diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 8f0ec6683b..3764f4d1a0 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -34,7 +34,6 @@ class UISuite extends FunSuite { } val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("localhost", startPort, Seq()) val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("localhost", startPort, Seq()) - // Allow some wiggle room in case ports on the machine are under contention assert(boundPort1 > startPort && boundPort1 < startPort + 10) assert(boundPort2 > boundPort1 && boundPort2 < boundPort1 + 10) diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 4e40dcbdee..5aff26f9fc 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -63,54 +63,53 @@ class SizeEstimatorSuite } test("simple classes") { - assert(SizeEstimator.estimate(new DummyClass1) === 16) - assert(SizeEstimator.estimate(new DummyClass2) === 16) - assert(SizeEstimator.estimate(new DummyClass3) === 24) - assert(SizeEstimator.estimate(new DummyClass4(null)) === 24) - assert(SizeEstimator.estimate(new DummyClass4(new DummyClass3)) === 48) + expectResult(16)(SizeEstimator.estimate(new DummyClass1)) + expectResult(16)(SizeEstimator.estimate(new DummyClass2)) + expectResult(24)(SizeEstimator.estimate(new DummyClass3)) + expectResult(24)(SizeEstimator.estimate(new DummyClass4(null))) + expectResult(48)(SizeEstimator.estimate(new DummyClass4(new DummyClass3))) } // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors // (Sun vs IBM). Use a DummyString class to make tests deterministic. test("strings") { - assert(SizeEstimator.estimate(DummyString("")) === 40) - assert(SizeEstimator.estimate(DummyString("a")) === 48) - assert(SizeEstimator.estimate(DummyString("ab")) === 48) - assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56) + expectResult(40)(SizeEstimator.estimate(DummyString(""))) + expectResult(48)(SizeEstimator.estimate(DummyString("a"))) + expectResult(48)(SizeEstimator.estimate(DummyString("ab"))) + expectResult(56)(SizeEstimator.estimate(DummyString("abcdefgh"))) } test("primitive arrays") { - assert(SizeEstimator.estimate(new Array[Byte](10)) === 32) - assert(SizeEstimator.estimate(new Array[Char](10)) === 40) - assert(SizeEstimator.estimate(new Array[Short](10)) === 40) - assert(SizeEstimator.estimate(new Array[Int](10)) === 56) - assert(SizeEstimator.estimate(new Array[Long](10)) === 96) - assert(SizeEstimator.estimate(new Array[Float](10)) === 56) - assert(SizeEstimator.estimate(new Array[Double](10)) === 96) - assert(SizeEstimator.estimate(new Array[Int](1000)) === 4016) - assert(SizeEstimator.estimate(new Array[Long](1000)) === 8016) + expectResult(32)(SizeEstimator.estimate(new Array[Byte](10))) + expectResult(40)(SizeEstimator.estimate(new Array[Char](10))) + expectResult(40)(SizeEstimator.estimate(new Array[Short](10))) + expectResult(56)(SizeEstimator.estimate(new Array[Int](10))) + expectResult(96)(SizeEstimator.estimate(new Array[Long](10))) + expectResult(56)(SizeEstimator.estimate(new Array[Float](10))) + expectResult(96)(SizeEstimator.estimate(new Array[Double](10))) + expectResult(4016)(SizeEstimator.estimate(new Array[Int](1000))) + expectResult(8016)(SizeEstimator.estimate(new Array[Long](1000))) } test("object arrays") { // Arrays containing nulls should just have one pointer per element - assert(SizeEstimator.estimate(new Array[String](10)) === 56) - assert(SizeEstimator.estimate(new Array[AnyRef](10)) === 56) - + expectResult(56)(SizeEstimator.estimate(new Array[String](10))) + expectResult(56)(SizeEstimator.estimate(new Array[AnyRef](10))) // For object arrays with non-null elements, each object should take one pointer plus // however many bytes that class takes. (Note that Array.fill calls the code in its // second parameter separately for each object, so we get distinct objects.) - assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass1)) === 216) - assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass2)) === 216) - assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass3)) === 296) - assert(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)) === 56) + expectResult(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass1))) + expectResult(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass2))) + expectResult(296)(SizeEstimator.estimate(Array.fill(10)(new DummyClass3))) + expectResult(56)(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2))) // Past size 100, our samples 100 elements, but we should still get the right size. - assert(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)) === 28016) + expectResult(28016)(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3))) // If an array contains the *same* element many times, we should only count it once. val d1 = new DummyClass1 - assert(SizeEstimator.estimate(Array.fill(10)(d1)) === 72) // 10 pointers plus 8-byte object - assert(SizeEstimator.estimate(Array.fill(100)(d1)) === 432) // 100 pointers plus 8-byte object + expectResult(72)(SizeEstimator.estimate(Array.fill(10)(d1))) // 10 pointers plus 8-byte object + expectResult(432)(SizeEstimator.estimate(Array.fill(100)(d1))) // 100 pointers plus 8-byte object // Same thing with huge array containing the same element many times. Note that this won't // return exactly 4032 because it can't tell that *all* the elements will equal the first @@ -128,11 +127,10 @@ class SizeEstimatorSuite val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() - assert(SizeEstimator.estimate(DummyString("")) === 40) - assert(SizeEstimator.estimate(DummyString("a")) === 48) - assert(SizeEstimator.estimate(DummyString("ab")) === 48) - assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56) - + expectResult(40)(SizeEstimator.estimate(DummyString(""))) + expectResult(48)(SizeEstimator.estimate(DummyString("a"))) + expectResult(48)(SizeEstimator.estimate(DummyString("ab"))) + expectResult(56)(SizeEstimator.estimate(DummyString("abcdefgh"))) resetOrClear("os.arch", arch) } @@ -145,10 +143,10 @@ class SizeEstimatorSuite val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() - assert(SizeEstimator.estimate(DummyString("")) === 56) - assert(SizeEstimator.estimate(DummyString("a")) === 64) - assert(SizeEstimator.estimate(DummyString("ab")) === 64) - assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 72) + expectResult(56)(SizeEstimator.estimate(DummyString(""))) + expectResult(64)(SizeEstimator.estimate(DummyString("a"))) + expectResult(64)(SizeEstimator.estimate(DummyString("ab"))) + expectResult(72)(SizeEstimator.estimate(DummyString("abcdefgh"))) resetOrClear("os.arch", arch) resetOrClear("spark.test.useCompressedOops", oops) diff --git a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala new file mode 100644 index 0000000000..b78367b6ca --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.Random +import org.scalatest.FlatSpec +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import org.apache.spark.util.Utils.times + +class XORShiftRandomSuite extends FunSuite with ShouldMatchers { + + def fixture = new { + val seed = 1L + val xorRand = new XORShiftRandom(seed) + val hundMil = 1e8.toInt + } + + /* + * This test is based on a chi-squared test for randomness. The values are hard-coded + * so as not to create Spark's dependency on apache.commons.math3 just to call one + * method for calculating the exact p-value for a given number of random numbers + * and bins. In case one would want to move to a full-fledged test based on + * apache.commons.math3, the relevant class is here: + * org.apache.commons.math3.stat.inference.ChiSquareTest + */ + test ("XORShift generates valid random numbers") { + + val f = fixture + + val numBins = 10 + // create 10 bins + val bins = Array.fill(numBins)(0) + + // populate bins based on modulus of the random number + times(f.hundMil) {bins(math.abs(f.xorRand.nextInt) % 10) += 1} + + /* since the seed is deterministic, until the algorithm is changed, we know the result will be + * exactly this: Array(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, + * 10000790, 10002286, 9998699), so the test will never fail at the prespecified (5%) + * significance level. However, should the RNG implementation change, the test should still + * pass at the same significance level. The chi-squared test done in R gave the following + * results: + * > chisq.test(c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, + * 10000790, 10002286, 9998699)) + * Chi-squared test for given probabilities + * data: c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, 10000790, + * 10002286, 9998699) + * X-squared = 11.975, df = 9, p-value = 0.2147 + * Note that the p-value was ~0.22. The test will fail if alpha < 0.05, which for 100 million + * random numbers + * and 10 bins will happen at X-squared of ~16.9196. So, the test will fail if X-squared + * is greater than or equal to that number. + */ + val binSize = f.hundMil/numBins + val xSquared = bins.map(x => math.pow((binSize - x), 2)/binSize).sum + xSquared should be < (16.9196) + + } + +} \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index ca3f684668..e9b62ea70d 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -1,9 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.util.collection import scala.collection.mutable.HashSet import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import org.apache.spark.util.SizeEstimator -class OpenHashMapSuite extends FunSuite { +class OpenHashMapSuite extends FunSuite with ShouldMatchers { + + test("size for specialized, primitive value (int)") { + val capacity = 1024 + val map = new OpenHashMap[String, Int](capacity) + val actualSize = SizeEstimator.estimate(map) + // 64 bit for pointers, 32 bit for ints, and 1 bit for the bitset. + val expectedSize = capacity * (64 + 32 + 1) / 8 + // Make sure we are not allocating a significant amount of memory beyond our expected. + actualSize should be <= (expectedSize * 1.1).toLong + } test("initialization") { val goodMap1 = new OpenHashMap[String, Int](1) diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index 4e11e8a628..1b24f8f287 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -1,9 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.util.collection import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers + +import org.apache.spark.util.SizeEstimator -class OpenHashSetSuite extends FunSuite { +class OpenHashSetSuite extends FunSuite with ShouldMatchers { + + test("size for specialized, primitive int") { + val loadFactor = 0.7 + val set = new OpenHashSet[Int](64, loadFactor) + for (i <- 0 until 1024) { + set.add(i) + } + assert(set.size === 1024) + assert(set.capacity > 1024) + val actualSize = SizeEstimator.estimate(set) + // 32 bits for the ints + 1 bit for the bitset + val expectedSize = set.capacity * (32 + 1) / 8 + // Make sure we are not allocating a significant amount of memory beyond our expected. + actualSize should be <= (expectedSize * 1.1).toLong + } test("primitive int") { val set = new OpenHashSet[Int] diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala similarity index 63% rename from core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala rename to core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala index dfd6aed2c4..3b60decee9 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala @@ -1,9 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.util.collection import scala.collection.mutable.HashSet import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import org.apache.spark.util.SizeEstimator -class PrimitiveKeyOpenHashSetSuite extends FunSuite { +class PrimitiveKeyOpenHashMapSuite extends FunSuite with ShouldMatchers { + + test("size for specialized, primitive key, value (int, int)") { + val capacity = 1024 + val map = new PrimitiveKeyOpenHashMap[Int, Int](capacity) + val actualSize = SizeEstimator.estimate(map) + // 32 bit for keys, 32 bit for values, and 1 bit for the bitset. + val expectedSize = capacity * (32 + 32 + 1) / 8 + // Make sure we are not allocating a significant amount of memory beyond our expected. + actualSize should be <= (expectedSize * 1.1).toLong + } test("initialization") { val goodMap1 = new PrimitiveKeyOpenHashMap[Int, Int](1) diff --git a/docs/_config.yml b/docs/_config.yml index 48ecb8d0c9..02067f9750 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -5,6 +5,6 @@ markdown: kramdown # of Spark, Scala, and Mesos. SPARK_VERSION: 0.9.0-incubating-SNAPSHOT SPARK_VERSION_SHORT: 0.9.0-SNAPSHOT -SCALA_VERSION: 2.9.3 +SCALA_VERSION: 2.10 MESOS_VERSION: 0.13.0 SPARK_ISSUE_TRACKER_URL: https://spark-project.atlassian.net diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 0c1d657cde..ad7969d012 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -74,12 +74,12 @@ diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index c574ea7f5c..431de909cb 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -35,7 +35,7 @@ if not (ENV['SKIP_API'] == '1' or ENV['SKIP_SCALADOC'] == '1') # Copy over the scaladoc from each project into the docs directory. # This directory will be copied over to _site when `jekyll` command is run. projects.each do |project_name| - source = "../" + project_name + "/target/scala-2.9.3/api" + source = "../" + project_name + "/target/scala-2.10/api" dest = "api/" + project_name puts "echo making directory " + dest diff --git a/docs/bagel-programming-guide.md b/docs/bagel-programming-guide.md index 140190a38c..de001e6c52 100644 --- a/docs/bagel-programming-guide.md +++ b/docs/bagel-programming-guide.md @@ -106,7 +106,7 @@ _Example_ ## Operations -Here are the actions and types in the Bagel API. See [Bagel.scala](https://github.com/apache/incubator-spark/blob/master/bagel/src/main/scala/spark/bagel/Bagel.scala) for details. +Here are the actions and types in the Bagel API. See [Bagel.scala](https://github.com/apache/incubator-spark/blob/master/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala) for details. ### Actions diff --git a/docs/building-with-maven.md b/docs/building-with-maven.md index 19c01e179f..c709001632 100644 --- a/docs/building-with-maven.md +++ b/docs/building-with-maven.md @@ -45,6 +45,12 @@ For Apache Hadoop 2.x, 0.23.x, Cloudera CDH MRv2, and other Hadoop versions with # Cloudera CDH 4.2.0 with MapReduce v2 $ mvn -Phadoop2-yarn -Dhadoop.version=2.0.0-cdh4.2.0 -Dyarn.version=2.0.0-chd4.2.0 -DskipTests clean package +Hadoop versions 2.2.x and newer can be built by setting the ```new-yarn``` and the ```yarn.version``` as follows: + + # Apache Hadoop 2.2.X and newer + $ mvn -Dyarn.version=2.2.0 -Dhadoop.version=2.2.0 -Pnew-yarn + +The build process handles Hadoop 2.2.x as a special case that uses the directory ```new-yarn```, which supports the new YARN API. Furthermore, for this version, the build depends on artifacts published by the spark-project to enable Akka 2.0.5 to work with protobuf 2.5. ## Spark Tests in Maven ## diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index 5927f736f3..e16703292c 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -45,7 +45,7 @@ The system currently supports three cluster managers: easy to set up a cluster. * [Apache Mesos](running-on-mesos.html) -- a general cluster manager that can also run Hadoop MapReduce and service applications. -* [Hadoop YARN](running-on-yarn.html) -- the resource manager in Hadoop 2.0. +* [Hadoop YARN](running-on-yarn.html) -- the resource manager in Hadoop 2. In addition, Spark's [EC2 launch scripts](ec2-scripts.html) make it easy to launch a standalone cluster on Amazon EC2. diff --git a/docs/configuration.md b/docs/configuration.md index 97183bafdb..677d182e50 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -275,11 +275,32 @@ Apart from these, the following properties are also available, and may be useful spark.akka.timeout - 20 + 100 Communication timeout between Spark nodes, in seconds. + + spark.akka.heartbeat.pauses + 600 + + This is set to a larger value to disable failure detector that comes inbuilt akka. It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart beat pause in seconds for akka. This can be used to control sensitivity to gc pauses. Tune this in combination of `spark.akka.heartbeat.interval` and `spark.akka.failure-detector.threshold` if you need to. + + + + spark.akka.failure-detector.threshold + 300.0 + + This is set to a larger value to disable failure detector that comes inbuilt akka. It can be enabled again, if you plan to use this feature (Not recommended). This maps to akka's `akka.remote.transport-failure-detector.threshold`. Tune this in combination of `spark.akka.heartbeat.pauses` and `spark.akka.heartbeat.interval` if you need to. + + + + spark.akka.heartbeat.interval + 1000 + + This is set to a larger value to disable failure detector that comes inbuilt akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger interval value in seconds reduces network overhead and a smaller value ( ~ 1 s) might be more informative for akka's failure detector. Tune this in combination of `spark.akka.heartbeat.pauses` and `spark.akka.failure-detector.threshold` if you need to. Only positive use case for using failure detector can be, a sensistive failure detector can help evict rogue executors really quick. However this is usually not the case as gc pauses and network lags are expected in a real spark cluster. Apart from that enabling this leads to a lot of exchanges of heart beats between nodes leading to flooding the network with those. + + spark.driver.host (local hostname) @@ -327,7 +348,41 @@ Apart from these, the following properties are also available, and may be useful Too large a value decreases parallelism during broadcast (makes it slower); however, if it is too small, BlockManager might take a performance hit. - + + spark.shuffle.consolidateFiles + false + + If set to "true", consolidates intermediate files created during a shuffle. Creating fewer files can improve filesystem performance for shuffles with large numbers of reduce tasks. It is recommended to set this to "true" when using ext4 or xfs filesystems. On ext3, this option might degrade performance on machines with many (>8) cores due to filesystem limitations. + + + + spark.speculation + false + + If set to "true", performs speculative execution of tasks. This means if one or more tasks are running slowly in a stage, they will be re-launched. + + + + spark.speculation.interval + 100 + + How often Spark will check for tasks to speculate, in milliseconds. + + + + spark.speculation.quantile + 0.75 + + Percentage of tasks which must be complete before speculation is enabled for a particular stage. + + + + spark.speculation.multiplier + 1.5 + + How many times slower a task is than the median to be considered for speculation. + + # Environment Variables diff --git a/docs/hadoop-third-party-distributions.md b/docs/hadoop-third-party-distributions.md index b33af2cf24..de6a2b0a43 100644 --- a/docs/hadoop-third-party-distributions.md +++ b/docs/hadoop-third-party-distributions.md @@ -10,7 +10,7 @@ with these distributions: # Compile-time Hadoop Version When compiling Spark, you'll need to -[set the SPARK_HADOOP_VERSION flag](http://localhost:4000/index.html#a-note-about-hadoop-versions): +[set the SPARK_HADOOP_VERSION flag](index.html#a-note-about-hadoop-versions): SPARK_HADOOP_VERSION=1.0.4 sbt/sbt assembly @@ -40,6 +40,7 @@ the _exact_ Hadoop version you are running to avoid any compatibility errors. HDP 1.21.1.2 HDP 1.11.0.3 HDP 1.01.0.3 + HDP 2.02.2.0 diff --git a/docs/index.md b/docs/index.md index bd386a8a8f..d3ac696d1e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -56,14 +56,16 @@ Hadoop, you must build Spark against the same version that your cluster uses. By default, Spark links to Hadoop 1.0.4. You can change this by setting the `SPARK_HADOOP_VERSION` variable when compiling: - SPARK_HADOOP_VERSION=1.2.1 sbt/sbt assembly + SPARK_HADOOP_VERSION=2.2.0 sbt/sbt assembly -In addition, if you wish to run Spark on [YARN](running-on-yarn.md), set +In addition, if you wish to run Spark on [YARN](running-on-yarn.html), set `SPARK_YARN` to `true`: SPARK_HADOOP_VERSION=2.0.5-alpha SPARK_YARN=true sbt/sbt assembly -(Note that on Windows, you need to set the environment variables on separate lines, e.g., `set SPARK_HADOOP_VERSION=1.2.1`.) +Note that on Windows, you need to set the environment variables on separate lines, e.g., `set SPARK_HADOOP_VERSION=1.2.1`. + +For this version of Spark (0.8.1) Hadoop 2.2.x (or newer) users will have to build Spark and publish it locally. See [Launching Spark on YARN](running-on-yarn.html). This is needed because Hadoop 2.2 has non backwards compatible API changes. # Where to Go from Here diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index d304c5497b..dbcb9ae343 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -91,7 +91,7 @@ The fair scheduler also supports grouping jobs into _pools_, and setting differe (e.g. weight) for each pool. This can be useful to create a "high-priority" pool for more important jobs, for example, or to group the jobs of each user together and give _users_ equal shares regardless of how many concurrent jobs they have instead of giving _jobs_ equal shares. This approach is modeled after the -[Hadoop Fair Scheduler](http://hadoop.apache.org/docs/stable/fair_scheduler.html). +[Hadoop Fair Scheduler](http://hadoop.apache.org/docs/current/hadoop-yarn/hadoop-yarn-site/FairScheduler.html). Without any intervention, newly submitted jobs go into a _default pool_, but jobs' pools can be set by adding the `spark.scheduler.pool` "local property" to the SparkContext in the thread that's submitting them. diff --git a/docs/monitoring.md b/docs/monitoring.md index 5f456b999b..5ed0474477 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -50,6 +50,7 @@ Each instance can report to zero or more _sinks_. Sinks are contained in the * `GangliaSink`: Sends metrics to a Ganglia node or multicast group. * `JmxSink`: Registers metrics for viewing in a JXM console. * `MetricsServlet`: Adds a servlet within the existing Spark UI to serve metrics data as JSON data. +* `GraphiteSink`: Sends metrics to a Graphite node. The syntax of the metrics configuration file is defined in an example configuration file, `$SPARK_HOME/conf/metrics.conf.template`. diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 6fd1d0d150..aa75ca4324 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -17,6 +17,7 @@ This can be built by setting the Hadoop version and `SPARK_YARN` environment var The assembled JAR will be something like this: `./assembly/target/scala-{{site.SCALA_VERSION}}/spark-assembly_{{site.SPARK_VERSION}}-hadoop2.0.5.jar`. +The build process now also supports new YARN versions (2.2.x). See below. # Preparations @@ -31,18 +32,26 @@ If you want to test out the YARN deployment mode, you can use the current Spark Most of the configs are the same for Spark on YARN as other deploys. See the Configuration page for more information on those. These are configs that are specific to SPARK on YARN. Environment variables: + * `SPARK_YARN_USER_ENV`, to add environment variables to the Spark processes launched on YARN. This can be a comma separated list of environment variables, e.g. `SPARK_YARN_USER_ENV="JAVA_HOME=/jdk64,FOO=bar"`. System Properties: -* 'spark.yarn.applicationMaster.waitTries', property to set the number of times the ApplicationMaster waits for the the spark master and then also the number of tries it waits for the Spark Context to be intialized. Default is 10. -* 'spark.yarn.submit.file.replication', the HDFS replication level for the files uploaded into HDFS for the application. These include things like the spark jar, the app jar, and any distributed cache files/archives. -* 'spark.yarn.preserve.staging.files', set to true to preserve the staged files(spark jar, app jar, distributed cache files) at the end of the job rather then delete them. + +* `spark.yarn.applicationMaster.waitTries`, property to set the number of times the ApplicationMaster waits for the the spark master and then also the number of tries it waits for the Spark Context to be intialized. Default is 10. +* `spark.yarn.submit.file.replication`, the HDFS replication level for the files uploaded into HDFS for the application. These include things like the spark jar, the app jar, and any distributed cache files/archives. +* `spark.yarn.preserve.staging.files`, set to true to preserve the staged files(spark jar, app jar, distributed cache files) at the end of the job rather then delete them. +* `spark.yarn.scheduler.heartbeat.interval-ms`, the interval in ms in which the Spark application master heartbeats into the YARN ResourceManager. Default is 5 seconds. +* `spark.yarn.max.worker.failures`, the maximum number of worker failures before failing the application. Default is the number of workers requested times 2 with minimum of 3. # Launching Spark on YARN Ensure that HADOOP_CONF_DIR or YARN_CONF_DIR points to the directory which contains the (client side) configuration files for the hadoop cluster. This would be used to connect to the cluster, write to the dfs and submit jobs to the resource manager. +There are two scheduler mode that can be used to launch spark application on YARN. + +## Launch spark application by YARN Client with yarn-standalone mode. + The command to launch the YARN Client is as follows: SPARK_JAR= ./spark-class org.apache.spark.deploy.yarn.Client \ @@ -50,6 +59,7 @@ The command to launch the YARN Client is as follows: --class \ --args \ --num-workers \ + --master-class --master-memory \ --worker-memory \ --worker-cores \ @@ -83,12 +93,37 @@ For example: $ cat $YARN_APP_LOGS_DIR/$YARN_APP_ID/container*_000001/stdout Pi is roughly 3.13794 -The above starts a YARN Client programs which periodically polls the Application Master for status updates and displays them in the console. The client will exit once your application has finished running. +The above starts a YARN Client programs which start the default Application Master. Then SparkPi will be run as a child thread of Application Master, YARN Client will periodically polls the Application Master for status updates and displays them in the console. The client will exit once your application has finished running. + +With this mode, your application is actually run on the remote machine where the Application Master is run upon. Thus application that involve local interaction will not work well, e.g. spark-shell. + +## Launch spark application with yarn-client mode. + +With yarn-client mode, the application will be launched locally. Just like running application or spark-shell on Local / Mesos / Standalone mode. The launch method is also the similar with them, just make sure that when you need to specify a master url, use "yarn-client" instead. And you also need to export the env value for SPARK_JAR and SPARK_YARN_APP_JAR + +In order to tune worker core/number/memory etc. You need to export SPARK_WORKER_CORES, SPARK_WORKER_MEMORY, SPARK_WORKER_INSTANCES e.g. by ./conf/spark-env.sh + +For example: + + SPARK_JAR=./assembly/target/scala-{{site.SCALA_VERSION}}/spark-assembly-{{site.SPARK_VERSION}}-hadoop2.0.5-alpha.jar \ + SPARK_YARN_APP_JAR=examples/target/scala-{{site.SCALA_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \ + ./run-example org.apache.spark.examples.SparkPi yarn-client + + + SPARK_JAR=./assembly/target/scala-{{site.SCALA_VERSION}}/spark-assembly-{{site.SPARK_VERSION}}-hadoop2.0.5-alpha.jar \ + SPARK_YARN_APP_JAR=examples/target/scala-{{site.SCALA_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \ + MASTER=yarn-client ./spark-shell + +# Building Spark for Hadoop/YARN 2.2.x + +Hadoop 2.2.x users must build Spark and publish it locally. The SBT build process handles Hadoop 2.2.x as a special case. This version of Hadoop has new YARN API changes and depends on a Protobuf version (2.5) that is not compatible with the Akka version (2.0.5) that Spark uses. Therefore, if the Hadoop version (e.g. set through ```SPARK_HADOOP_VERSION```) starts with 2.2.0 or higher then the build process will depend on Akka artifacts distributed by the Spark project compatible with Protobuf 2.5. Furthermore, the build process then uses the directory ```new-yarn``` (instead of ```yarn```), which supports the new YARN API. The build process should seamlessly work out of the box. + +See [Building Spark with Maven](building-with-maven.html) for instructions on how to build Spark using the Maven process. # Important Notes -- When your application instantiates a Spark context it must use a special "yarn-standalone" master url. This starts the scheduler without forcing it to connect to a cluster. A good way to handle this is to pass "yarn-standalone" as an argument to your program, as shown in the example above. - We do not requesting container resources based on the number of cores. Thus the numbers of cores given via command line arguments cannot be guaranteed. - The local directories used for spark will be the local directories configured for YARN (Hadoop Yarn config yarn.nodemanager.local-dirs). If the user specifies spark.local.dir, it will be ignored. - The --files and --archives options support specifying file names with the # similar to Hadoop. For example you can specify: --files localtest.txt#appSees.txt and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name appSees.txt and your application should use the name as appSees.txt to reference it when running on YARN. - The --addJars option allows the SparkContext.addJar function to work if you are using it with local files. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. +- YARN 2.2.x users cannot simply depend on the Spark packages without building Spark, as the published Spark artifacts are compiled to work with the pre 2.2 API. Those users must build Spark and publish it locally. diff --git a/docs/scala-programming-guide.md b/docs/scala-programming-guide.md index 94e8563a8b..56d2a3a4a0 100644 --- a/docs/scala-programming-guide.md +++ b/docs/scala-programming-guide.md @@ -363,7 +363,7 @@ res2: Int = 10 # Where to Go from Here -You can see some [example Spark programs](http://www.spark-project.org/examples.html) on the Spark website. +You can see some [example Spark programs](http://spark.incubator.apache.org/examples.html) on the Spark website. In addition, Spark includes several samples in `examples/src/main/scala`. Some of them have both Spark versions and local (non-parallel) versions, allowing you to see what had to be changed to make the program run on a cluster. You can run them using by passing the class name to the `run-example` script included in Spark; for example: ./run-example org.apache.spark.examples.SparkPi diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 17066ef0dd..b822265b5a 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -51,11 +51,11 @@ Finally, the following configuration options can be passed to the master and wor -c CORES, --cores CORES - Total CPU cores to allow Spark applicatons to use on the machine (default: all available); only on worker + Total CPU cores to allow Spark applications to use on the machine (default: all available); only on worker -m MEM, --memory MEM - Total amount of memory to allow Spark applicatons to use on the machine, in a format like 1000M or 2G (default: your machine's total RAM minus 1 GB); only on worker + Total amount of memory to allow Spark applications to use on the machine, in a format like 1000M or 2G (default: your machine's total RAM minus 1 GB); only on worker -d DIR, --work-dir DIR diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 851e30fe76..82f42e0b8d 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -214,7 +214,7 @@ ssc.stop() {% endhighlight %} # Example -A simple example to start off is the [NetworkWordCount](https://github.com/apache/incubator-spark/tree/master/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala). This example counts the words received from a network server every second. Given below is the relevant sections of the source code. You can find the full source code in `/streaming/src/main/scala/spark/streaming/examples/NetworkWordCount.scala` . +A simple example to start off is the [NetworkWordCount](https://github.com/apache/incubator-spark/tree/master/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala). This example counts the words received from a network server every second. Given below is the relevant sections of the source code. You can find the full source code in `/streaming/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala` . {% highlight scala %} import org.apache.spark.streaming.{Seconds, StreamingContext} @@ -283,7 +283,7 @@ Time: 1357008430000 ms -You can find more examples in `/streaming/src/main/scala/spark/streaming/examples/`. They can be run in the similar manner using `./run-example org.apache.spark.streaming.examples....` . Executing without any parameter would give the required parameter list. Further explanation to run them can be found in comments in the files. +You can find more examples in `/streaming/src/main/scala/org/apache/spark/streaming/examples/`. They can be run in the similar manner using `./run-example org.apache.spark.streaming.examples....` . Executing without any parameter would give the required parameter list. Further explanation to run them can be found in comments in the files. # DStream Persistence Similar to RDDs, DStreams also allow developers to persist the stream's data in memory. That is, using `persist()` method on a DStream would automatically persist every RDD of that DStream in memory. This is useful if the data in the DStream will be computed multiple times (e.g., multiple operations on the same data). For window-based operations like `reduceByWindow` and `reduceByKeyAndWindow` and state-based operations like `updateStateByKey`, this is implicitly true. Hence, DStreams generated by window-based operations are automatically persisted in memory, without the developer calling `persist()`. @@ -483,7 +483,7 @@ Similar to [Spark's Java API](java-programming-guide.html), we also provide a Ja 1. Functions for transformations must be implemented as subclasses of [Function](api/core/index.html#org.apache.spark.api.java.function.Function) and [Function2](api/core/index.html#org.apache.spark.api.java.function.Function2) 1. Unlike the Scala API, the Java API handles DStreams for key-value pairs using a separate [JavaPairDStream](api/streaming/index.html#org.apache.spark.streaming.api.java.JavaPairDStream) class(similar to [JavaRDD and JavaPairRDD](java-programming-guide.html#rdd-classes). DStream functions like `map` and `filter` are implemented separately by JavaDStreams and JavaPairDStream to return DStreams of appropriate types. -Spark's [Java Programming Guide](java-programming-guide.html) gives more ideas about using the Java API. To extends the ideas presented for the RDDs to DStreams, we present parts of the Java version of the same NetworkWordCount example presented above. The full source code is given at `/examples/src/main/java/spark/streaming/examples/JavaNetworkWordCount.java` +Spark's [Java Programming Guide](java-programming-guide.html) gives more ideas about using the Java API. To extends the ideas presented for the RDDs to DStreams, we present parts of the Java version of the same NetworkWordCount example presented above. The full source code is given at `/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java` The streaming context and the socket stream from input source is started by using a `JavaStreamingContext`, that has the same parameters and provides the same input streams as its Scala counterpart. @@ -527,5 +527,5 @@ JavaPairDStream wordCounts = words.map( # Where to Go from Here * API docs - [Scala](api/streaming/index.html#org.apache.spark.streaming.package) and [Java](api/streaming/index.html#org.apache.spark.streaming.api.java.package) -* More examples - [Scala](https://github.com/apache/incubator-spark/tree/master/examples/src/main/scala/spark/streaming/examples) and [Java](https://github.com/apache/incubator-spark/tree/master/examples/src/main/java/spark/streaming/examples) +* More examples - [Scala](https://github.com/apache/incubator-spark/tree/master/examples/src/main/scala/org/apache/spark/streaming/examples) and [Java](https://github.com/apache/incubator-spark/tree/master/examples/src/main/java/org/apache/spark/streaming/examples) * [Paper describing Spark Streaming](http://www.eecs.berkeley.edu/Pubs/TechRpts/2012/EECS-2012-259.pdf) diff --git a/docs/tuning.md b/docs/tuning.md index f491ae9b95..a4be188169 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -39,7 +39,8 @@ in your operations) and performance. It provides two serialization libraries: for best performance. You can switch to using Kryo by calling `System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")` -*before* creating your SparkContext. The only reason it is not the default is because of the custom +*before* creating your SparkContext. This setting configures the serializer used for not only shuffling data between worker +nodes but also when serializing RDDs to disk. The only reason Kryo is not the default is because of the custom registration requirement, but we recommend trying it in any network-intensive application. Finally, to register your classes with Kryo, create a public class that extends @@ -67,7 +68,7 @@ The [Kryo documentation](http://code.google.com/p/kryo/) describes more advanced registration options, such as adding custom serialization code. If your objects are large, you may also need to increase the `spark.kryoserializer.buffer.mb` -system property. The default is 32, but this value needs to be large enough to hold the *largest* +system property. The default is 2, but this value needs to be large enough to hold the *largest* object you will serialize. Finally, if you don't register your classes, Kryo will still work, but it will have to store the diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 1189232428..a2b0e7e7f4 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -589,7 +589,7 @@ def ssh(host, opts, command): while True: try: return subprocess.check_call( - ssh_command(opts) + ['-t', '%s@%s' % (opts.user, host), stringify_command(command)]) + ssh_command(opts) + ['-t', '-t', '%s@%s' % (opts.user, host), stringify_command(command)]) except subprocess.CalledProcessError as e: if (tries > 2): # If this was an ssh failure, provide the user with hints. @@ -730,7 +730,7 @@ def real_main(): if opts.proxy_port != None: proxy_opt = ['-D', opts.proxy_port] subprocess.check_call( - ssh_command(opts) + proxy_opt + ['-t', "%s@%s" % (opts.user, master)]) + ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)]) elif action == "get-master": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) diff --git a/examples/pom.xml b/examples/pom.xml index aee371fbc7..7a7032c319 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -26,7 +26,7 @@ org.apache.spark - spark-examples_2.9.3 + spark-examples_2.10 jar Spark Project Examples http://spark.incubator.apache.org/ @@ -49,25 +49,25 @@ org.apache.spark - spark-core_2.9.3 + spark-core_${scala.binary.version} ${project.version} provided org.apache.spark - spark-streaming_2.9.3 + spark-streaming_${scala.binary.version} ${project.version} provided org.apache.spark - spark-mllib_2.9.3 + spark-mllib_${scala.binary.version} ${project.version} provided org.apache.spark - spark-bagel_2.9.3 + spark-bagel_${scala.binary.version} ${project.version} provided @@ -87,8 +87,8 @@ - org.apache.kafka - kafka_2.9.2 + com.sksamuel.kafka + kafka_${scala.binary.version} 0.8.0-beta1 @@ -107,23 +107,23 @@ com.twitter - algebird-core_2.9.2 + algebird-core_${scala.binary.version} 0.1.11 org.scalatest - scalatest_2.9.3 + scalatest_${scala.binary.version} test org.scalacheck - scalacheck_2.9.3 + scalacheck_${scala.binary.version} test org.apache.cassandra cassandra-all - 1.2.5 + 1.2.6 com.google.guava @@ -166,8 +166,8 @@ - target/scala-${scala.version}/classes - target/scala-${scala.version}/test-classes + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes org.apache.maven.plugins diff --git a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java index 152f029213..407cd7ccfa 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java @@ -123,7 +123,7 @@ public class JavaLogQuery { }); List, Stats>> output = counts.collect(); - for (Tuple2 t : output) { + for (Tuple2 t : output) { System.out.println(t._1 + "\t" + t._2); } System.exit(0); diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java index c5603a639b..89aed8f279 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java @@ -21,7 +21,6 @@ import scala.Tuple2; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFlatMapFunction; @@ -106,7 +105,7 @@ public class JavaPageRank { // Collects all URL ranks and dump them to console. List> output = ranks.collect(); - for (Tuple2 tuple : output) { + for (Tuple2 tuple : output) { System.out.println(tuple._1 + " has rank: " + tuple._2 + "."); } diff --git a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java index 07d32ad659..bd6383e13d 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java @@ -58,7 +58,7 @@ public class JavaWordCount { }); List> output = counts.collect(); - for (Tuple2 tuple : output) { + for (Tuple2 tuple : output) { System.out.println(tuple._1 + ": " + tuple._2); } System.exit(0); diff --git a/examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java b/examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java index 628cb892b6..45a0d237da 100644 --- a/examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java +++ b/examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java @@ -25,7 +25,6 @@ import org.apache.spark.mllib.recommendation.ALS; import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; import org.apache.spark.mllib.recommendation.Rating; -import java.io.Serializable; import java.util.Arrays; import java.util.StringTokenizer; diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index 529709c2f9..a119980992 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -32,13 +32,13 @@ object BroadcastTest { System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast." + bcName + "BroadcastFactory") System.setProperty("spark.broadcast.blockSize", blockSize) - val sc = new SparkContext(args(0), "Broadcast Test 2", + val sc = new SparkContext(args(0), "Broadcast Test", System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) val slices = if (args.length > 1) args(1).toInt else 2 val num = if (args.length > 2) args(2).toInt else 1000000 - var arr1 = new Array[Int](num) + val arr1 = new Array[Int](num) for (i <- 0 until arr1.length) { arr1(i) = i } @@ -48,9 +48,9 @@ object BroadcastTest { println("===========") val startTime = System.nanoTime val barr1 = sc.broadcast(arr1) - sc.parallelize(1 to 10, slices).foreach { - i => println(barr1.value.size) - } + val observedSizes = sc.parallelize(1 to 10, slices).map(_ => barr1.value.size) + // Collect the small RDD so we can print the observed sizes locally. + observedSizes.collect().foreach(i => println(i)) println("Iteration %d took %.0f milliseconds".format(i, (System.nanoTime - startTime) / 1E6)) } diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala index 4af45b2b4a..83db8b9e26 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -120,7 +120,7 @@ object LocalALS { System.exit(1) } } - printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS); + printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS) val R = generateR() diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala index f79f0142b8..e1afc29f9a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -18,35 +18,38 @@ package org.apache.spark.examples import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD object MultiBroadcastTest { def main(args: Array[String]) { if (args.length == 0) { - System.err.println("Usage: BroadcastTest [] [numElem]") + System.err.println("Usage: MultiBroadcastTest [] [numElem]") System.exit(1) } - val sc = new SparkContext(args(0), "Broadcast Test", + val sc = new SparkContext(args(0), "Multi-Broadcast Test", System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) val slices = if (args.length > 1) args(1).toInt else 2 val num = if (args.length > 2) args(2).toInt else 1000000 - var arr1 = new Array[Int](num) + val arr1 = new Array[Int](num) for (i <- 0 until arr1.length) { arr1(i) = i } - var arr2 = new Array[Int](num) + val arr2 = new Array[Int](num) for (i <- 0 until arr2.length) { arr2(i) = i } val barr1 = sc.broadcast(arr1) val barr2 = sc.broadcast(arr2) - sc.parallelize(1 to 10, slices).foreach { - i => println(barr1.value.size + barr2.value.size) + val observedSizes: RDD[(Int, Int)] = sc.parallelize(1 to 10, slices).map { _ => + (barr1.value.size, barr2.value.size) } + // Collect the small RDD so we can print the observed sizes locally. + observedSizes.collect().foreach(i => println(i)) System.exit(0) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index 5a7a9d1bd8..8543ce0e32 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -65,7 +65,7 @@ object SparkTC { oldCount = nextCount // Perform the join, obtaining an RDD of (y, (z, x)) pairs, // then project the result to obtain the new (x, z) paths. - tc = tc.union(tc.join(edges).map(x => (x._2._2, x._2._1))).distinct().cache(); + tc = tc.union(tc.join(edges).map(x => (x._2._2, x._2._1))).distinct().cache() nextCount = tc.count() } while (nextCount != oldCount) diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala index cd3423a07b..50e3f9639c 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala @@ -19,6 +19,7 @@ package org.apache.spark.streaming.examples import scala.collection.mutable.LinkedList import scala.util.Random +import scala.reflect.ClassTag import akka.actor.Actor import akka.actor.ActorRef @@ -82,10 +83,10 @@ class FeederActor extends Actor { * * @see [[org.apache.spark.streaming.examples.FeederActor]] */ -class SampleActorReceiver[T: ClassManifest](urlOfPublisher: String) +class SampleActorReceiver[T: ClassTag](urlOfPublisher: String) extends Actor with Receiver { - lazy private val remotePublisher = context.actorFor(urlOfPublisher) + lazy private val remotePublisher = context.actorSelection(urlOfPublisher) override def preStart = remotePublisher ! SubscribeReceiver(context.self) @@ -120,7 +121,7 @@ object FeederActor { println("Feeder started as:" + feeder) - actorSystem.awaitTermination(); + actorSystem.awaitTermination() } } @@ -164,7 +165,7 @@ object ActorWordCount { */ val lines = ssc.actorStream[String]( - Props(new SampleActorReceiver[String]("akka://test@%s:%s/user/FeederActor".format( + Props(new SampleActorReceiver[String]("akka.tcp://test@%s:%s/user/FeederActor".format( host, port.toInt))), "SampleReceiver") //compute wordcount diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala index af698a01d5..ff332a0282 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala @@ -54,12 +54,12 @@ object MQTTPublisher { client.connect() - val msgtopic: MqttTopic = client.getTopic(topic); + val msgtopic: MqttTopic = client.getTopic(topic) val msg: String = "hello mqtt demo for spark streaming" while (true) { val message: MqttMessage = new MqttMessage(String.valueOf(msg).getBytes()) - msgtopic.publish(message); + msgtopic.publish(message) println("Published data. topic: " + msgtopic.getName() + " Message: " + message) } client.disconnect() diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala index c8743b9e25..e83ce78aa5 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala @@ -23,6 +23,7 @@ import akka.zeromq._ import org.apache.spark.streaming.{ Seconds, StreamingContext } import org.apache.spark.streaming.StreamingContext._ import akka.zeromq.Subscribe +import akka.util.ByteString /** * A simple publisher for demonstration purposes, repeatedly publishes random Messages @@ -40,10 +41,11 @@ object SimpleZeroMQPublisher { val acs: ActorSystem = ActorSystem() val pubSocket = ZeroMQExtension(acs).newSocket(SocketType.Pub, Bind(url)) - val messages: Array[String] = Array("words ", "may ", "count ") + implicit def stringToByteString(x: String) = ByteString(x) + val messages: List[ByteString] = List("words ", "may ", "count ") while (true) { Thread.sleep(1000) - pubSocket ! ZMQMessage(Frame(topic) :: messages.map(x => Frame(x.getBytes)).toList) + pubSocket ! ZMQMessage(ByteString(topic) :: messages) } acs.awaitTermination() } @@ -78,7 +80,7 @@ object ZeroMQWordCount { val ssc = new StreamingContext(master, "ZeroMQWordCount", Seconds(2), System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - def bytesToStringIterator(x: Seq[Seq[Byte]]) = (x.map(x => new String(x.toArray))).iterator + def bytesToStringIterator(x: Seq[ByteString]) = (x.map(_.utf8String)).iterator //For this stream, a zeroMQ publisher should be running. val lines = ssc.zeroMQStream(url, Subscribe(topic), bytesToStringIterator) diff --git a/mllib/pom.xml b/mllib/pom.xml index f472082ad1..dda3900afe 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -26,7 +26,7 @@ org.apache.spark - spark-mllib_2.9.3 + spark-mllib_2.10 jar Spark Project ML Library http://spark.incubator.apache.org/ @@ -34,7 +34,7 @@ org.apache.spark - spark-core_2.9.3 + spark-core_${scala.binary.version} ${project.version} @@ -48,12 +48,12 @@ org.scalatest - scalatest_2.9.3 + scalatest_${scala.binary.version} test org.scalacheck - scalacheck_2.9.3 + scalacheck_${scala.binary.version} test @@ -63,8 +63,8 @@ - target/scala-${scala.version}/classes - target/scala-${scala.version}/test-classes + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes org.scalatest diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index edbf77dbcc..0dee9399a8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -18,15 +18,16 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable.ArrayBuffer -import scala.util.Random + +import org.jblas.DoubleMatrix import org.apache.spark.SparkContext import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.Logging import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.util.XORShiftRandom -import org.jblas.DoubleMatrix /** @@ -195,7 +196,7 @@ class KMeans private ( */ private def initRandom(data: RDD[Array[Double]]): Array[ClusterCenters] = { // Sample all the cluster centers in one pass to avoid repeated scans - val sample = data.takeSample(true, runs * k, new Random().nextInt()).toSeq + val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).toArray) } @@ -210,7 +211,7 @@ class KMeans private ( */ private def initKMeansParallel(data: RDD[Array[Double]]): Array[ClusterCenters] = { // Initialize each run's center to a random point - val seed = new Random().nextInt() + val seed = new XORShiftRandom().nextInt() val sample = data.takeSample(true, runs, seed).toSeq val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r))) @@ -222,7 +223,7 @@ class KMeans private ( for (r <- 0 until runs) yield (r, KMeans.pointCost(centerArrays(r), point)) }.reduceByKey(_ + _).collectAsMap() val chosen = data.mapPartitionsWithIndex { (index, points) => - val rand = new Random(seed ^ (step << 16) ^ index) + val rand = new XORShiftRandom(seed ^ (step << 16) ^ index) for { p <- points r <- 0 until runs diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala index 5aec867257..d5f3f6b8db 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -83,7 +83,7 @@ object MFDataGenerator{ scala.math.round(.99 * m * n)).toInt val rand = new Random() val mn = m * n - val shuffled = rand.shuffle(1 to mn toIterable) + val shuffled = rand.shuffle(1 to mn toList) val omega = shuffled.slice(0, sampSize) val ordered = omega.sortWith(_ < _).toArray diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java index 32d3934ac1..33b99f4bd3 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java @@ -77,7 +77,7 @@ public class JavaKMeansSuite implements Serializable { @Test public void runKMeansUsingStaticMethods() { - List points = new ArrayList(); + List points = new ArrayList(); points.add(new double[]{1.0, 2.0, 6.0}); points.add(new double[]{1.0, 3.0, 0.0}); points.add(new double[]{1.0, 4.0, 6.0}); @@ -94,7 +94,7 @@ public class JavaKMeansSuite implements Serializable { @Test public void runKMeansUsingConstructor() { - List points = new ArrayList(); + List points = new ArrayList(); points.add(new double[]{1.0, 2.0, 6.0}); points.add(new double[]{1.0, 3.0, 0.0}); points.add(new double[]{1.0, 4.0, 6.0}); diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index eafee060cd..b40f552e0d 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -21,8 +21,6 @@ import java.io.Serializable; import java.util.List; import java.lang.Math; -import scala.Tuple2; - import org.junit.After; import org.junit.Assert; import org.junit.Before; diff --git a/new-yarn/pom.xml b/new-yarn/pom.xml new file mode 100644 index 0000000000..4cd28f34e3 --- /dev/null +++ b/new-yarn/pom.xml @@ -0,0 +1,161 @@ + + + + 4.0.0 + + org.apache.spark + spark-parent + 0.9.0-incubating-SNAPSHOT + ../pom.xml + + + org.apache.spark + spark-yarn_2.10 + jar + Spark Project YARN Support + http://spark.incubator.apache.org/ + + + + org.apache.spark + spark-core_2.10 + ${project.version} + + + org.apache.hadoop + hadoop-yarn-api + + + org.apache.hadoop + hadoop-yarn-common + + + org.apache.hadoop + hadoop-yarn-client + + + org.apache.hadoop + hadoop-client + ${yarn.version} + + + org.apache.avro + avro + + + org.apache.avro + avro-ipc + + + org.scalatest + scalatest_2.10 + test + + + org.mockito + mockito-all + test + + + + + target/scala-${scala.version}/classes + target/scala-${scala.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + ${project.build.directory}/${project.artifactId}-${project.version}-shaded.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + + + + + + org.apache.maven.plugins + maven-antrun-plugin + + + test + + run + + + true + + + + + + + + + + + + + + + + + + + + org.scalatest + scalatest-maven-plugin + + + ${basedir}/.. + 1 + ${spark.classpath} + + + + + + diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala new file mode 100644 index 0000000000..eeeca3ea8a --- /dev/null +++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -0,0 +1,446 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.io.IOException +import java.net.Socket +import java.util.concurrent.CopyOnWriteArrayList +import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.net.NetUtils +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.util.ShutdownHookManager +import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.client.api.AMRMClient +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.ipc.YarnRPC +import org.apache.hadoop.yarn.util.{ConverterUtils, Records} + +import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.util.Utils + + +class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging { + + def this(args: ApplicationMasterArguments) = this(args, new Configuration()) + + private var rpc: YarnRPC = YarnRPC.create(conf) + private val yarnConf: YarnConfiguration = new YarnConfiguration(conf) + private var appAttemptId: ApplicationAttemptId = _ + private var userThread: Thread = _ + private val fs = FileSystem.get(yarnConf) + + private var yarnAllocator: YarnAllocationHandler = _ + private var isFinished: Boolean = false + private var uiAddress: String = _ + private val maxAppAttempts: Int = conf.getInt( + YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS) + private var isLastAMRetry: Boolean = true + private var amClient: AMRMClient[ContainerRequest] = _ + + // Default to numWorkers * 2, with minimum of 3 + private val maxNumWorkerFailures = System.getProperty("spark.yarn.max.worker.failures", + math.max(args.numWorkers * 2, 3).toString()).toInt + + def run() { + // Setup the directories so things go to YARN approved directories rather + // than user specified and /tmp. + System.setProperty("spark.local.dir", getLocalDirs()) + + // Use priority 30 as it's higher then HDFS. It's same priority as MapReduce is using. + ShutdownHookManager.get().addShutdownHook(new AppMasterShutdownHook(this), 30) + + appAttemptId = getApplicationAttemptId() + isLastAMRetry = appAttemptId.getAttemptId() >= maxAppAttempts + amClient = AMRMClient.createAMRMClient() + amClient.init(yarnConf) + amClient.start() + + // Workaround until hadoop moves to something which has + // https://issues.apache.org/jira/browse/HADOOP-8406 - fixed in (2.0.2-alpha but no 0.23 line) + // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf) + + ApplicationMaster.register(this) + + // Start the user's JAR + userThread = startUserClass() + + // This a bit hacky, but we need to wait until the spark.driver.port property has + // been set by the Thread executing the user class. + waitForSparkMaster() + + waitForSparkContextInitialized() + + // Do this after Spark master is up and SparkContext is created so that we can register UI Url. + val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster() + + // Allocate all containers + allocateWorkers() + + // Wait for the user class to Finish + userThread.join() + + System.exit(0) + } + + /** Get the Yarn approved local directories. */ + private def getLocalDirs(): String = { + // Hadoop 0.23 and 2.x have different Environment variable names for the + // local dirs, so lets check both. We assume one of the 2 is set. + // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X + val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) + .getOrElse(Option(System.getenv("LOCAL_DIRS")) + .getOrElse("")) + + if (localDirs.isEmpty()) { + throw new Exception("Yarn Local dirs can't be empty") + } + localDirs + } + + private def getApplicationAttemptId(): ApplicationAttemptId = { + val envs = System.getenv() + val containerIdString = envs.get(ApplicationConstants.Environment.CONTAINER_ID.name()) + val containerId = ConverterUtils.toContainerId(containerIdString) + val appAttemptId = containerId.getApplicationAttemptId() + logInfo("ApplicationAttemptId: " + appAttemptId) + appAttemptId + } + + private def registerApplicationMaster(): RegisterApplicationMasterResponse = { + logInfo("Registering the ApplicationMaster") + amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) + } + + private def waitForSparkMaster() { + logInfo("Waiting for Spark driver to be reachable.") + var driverUp = false + var tries = 0 + val numTries = System.getProperty("spark.yarn.applicationMaster.waitTries", "10").toInt + while (!driverUp && tries < numTries) { + val driverHost = System.getProperty("spark.driver.host") + val driverPort = System.getProperty("spark.driver.port") + try { + val socket = new Socket(driverHost, driverPort.toInt) + socket.close() + logInfo("Driver now available: %s:%s".format(driverHost, driverPort)) + driverUp = true + } catch { + case e: Exception => { + logWarning("Failed to connect to driver at %s:%s, retrying ...". + format(driverHost, driverPort)) + Thread.sleep(100) + tries = tries + 1 + } + } + } + } + + private def startUserClass(): Thread = { + logInfo("Starting the user JAR in a separate Thread") + val mainMethod = Class.forName( + args.userClass, + false /* initialize */, + Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]]) + val t = new Thread { + override def run() { + var successed = false + try { + // Copy + var mainArgs: Array[String] = new Array[String](args.userArgs.size) + args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size) + mainMethod.invoke(null, mainArgs) + // some job script has "System.exit(0)" at the end, for example SparkPi, SparkLR + // userThread will stop here unless it has uncaught exception thrown out + // It need shutdown hook to set SUCCEEDED + successed = true + } finally { + logDebug("finishing main") + isLastAMRetry = true + if (successed) { + ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED) + } else { + ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.FAILED) + } + } + } + } + t.start() + t + } + + // This need to happen before allocateWorkers() + private def waitForSparkContextInitialized() { + logInfo("Waiting for Spark context initialization") + try { + var sparkContext: SparkContext = null + ApplicationMaster.sparkContextRef.synchronized { + var numTries = 0 + val waitTime = 10000L + val maxNumTries = System.getProperty("spark.yarn.ApplicationMaster.waitTries", "10").toInt + while (ApplicationMaster.sparkContextRef.get() == null && numTries < maxNumTries) { + logInfo("Waiting for Spark context initialization ... " + numTries) + numTries = numTries + 1 + ApplicationMaster.sparkContextRef.wait(waitTime) + } + sparkContext = ApplicationMaster.sparkContextRef.get() + assert(sparkContext != null || numTries >= maxNumTries) + + if (sparkContext != null) { + uiAddress = sparkContext.ui.appUIAddress + this.yarnAllocator = YarnAllocationHandler.newAllocator( + yarnConf, + amClient, + appAttemptId, + args, + sparkContext.preferredNodeLocationData) + } else { + logWarning("Unable to retrieve SparkContext inspite of waiting for %d, maxNumTries = %d". + format(numTries * waitTime, maxNumTries)) + this.yarnAllocator = YarnAllocationHandler.newAllocator( + yarnConf, + amClient, + appAttemptId, + args) + } + } + } finally { + // In case of exceptions, etc - ensure that count is at least ALLOCATOR_LOOP_WAIT_COUNT : + // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks. + ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT) + } + } + + private def allocateWorkers() { + try { + logInfo("Allocating " + args.numWorkers + " workers.") + // Wait until all containers have finished + // TODO: This is a bit ugly. Can we make it nicer? + // TODO: Handle container failure + yarnAllocator.addResourceRequests(args.numWorkers) + // Exits the loop if the user thread exits. + while (yarnAllocator.getNumWorkersRunning < args.numWorkers && userThread.isAlive) { + if (yarnAllocator.getNumWorkersFailed >= maxNumWorkerFailures) { + finishApplicationMaster(FinalApplicationStatus.FAILED, + "max number of worker failures reached") + } + yarnAllocator.allocateResources() + ApplicationMaster.incrementAllocatorLoop(1) + Thread.sleep(100) + } + } finally { + // In case of exceptions, etc - ensure that count is at least ALLOCATOR_LOOP_WAIT_COUNT, + // so that the loop in ApplicationMaster#sparkContextInitialized() breaks. + ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT) + } + logInfo("All workers have launched.") + + // Launch a progress reporter thread, else the app will get killed after expiration + // (def: 10mins) timeout. + if (userThread.isAlive) { + // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses. + val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) + + // we want to be reasonably responsive without causing too many requests to RM. + val schedulerInterval = + System.getProperty("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong + + // must be <= timeoutInterval / 2. + val interval = math.min(timeoutInterval / 2, schedulerInterval) + + launchReporterThread(interval) + } + } + + private def launchReporterThread(_sleepTime: Long): Thread = { + val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime + + val t = new Thread { + override def run() { + while (userThread.isAlive) { + if (yarnAllocator.getNumWorkersFailed >= maxNumWorkerFailures) { + finishApplicationMaster(FinalApplicationStatus.FAILED, + "max number of worker failures reached") + } + val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning - + yarnAllocator.getNumPendingAllocate + if (missingWorkerCount > 0) { + logInfo("Allocating %d containers to make up for (potentially) lost containers". + format(missingWorkerCount)) + yarnAllocator.addResourceRequests(missingWorkerCount) + } + sendProgress() + Thread.sleep(sleepTime) + } + } + } + // Setting to daemon status, though this is usually not a good idea. + t.setDaemon(true) + t.start() + logInfo("Started progress reporter thread - sleep time : " + sleepTime) + t + } + + private def sendProgress() { + logDebug("Sending progress") + // Simulated with an allocate request with no nodes requested. + yarnAllocator.allocateResources() + } + + /* + def printContainers(containers: List[Container]) = { + for (container <- containers) { + logInfo("Launching shell command on a new container." + + ", containerId=" + container.getId() + + ", containerNode=" + container.getNodeId().getHost() + + ":" + container.getNodeId().getPort() + + ", containerNodeURI=" + container.getNodeHttpAddress() + + ", containerState" + container.getState() + + ", containerResourceMemory" + + container.getResource().getMemory()) + } + } + */ + + def finishApplicationMaster(status: FinalApplicationStatus, diagnostics: String = "") { + synchronized { + if (isFinished) { + return + } + isFinished = true + } + + logInfo("finishApplicationMaster with " + status) + // Set tracking URL to empty since we don't have a history server. + amClient.unregisterApplicationMaster(status, "" /* appMessage */, "" /* appTrackingUrl */) + } + + /** + * Clean up the staging directory. + */ + private def cleanupStagingDir() { + var stagingDirPath: Path = null + try { + val preserveFiles = System.getProperty("spark.yarn.preserve.staging.files", "false").toBoolean + if (!preserveFiles) { + stagingDirPath = new Path(System.getenv("SPARK_YARN_STAGING_DIR")) + if (stagingDirPath == null) { + logError("Staging directory is null") + return + } + logInfo("Deleting staging directory " + stagingDirPath) + fs.delete(stagingDirPath, true) + } + } catch { + case ioe: IOException => + logError("Failed to cleanup staging dir " + stagingDirPath, ioe) + } + } + + // The shutdown hook that runs when a signal is received AND during normal close of the JVM. + class AppMasterShutdownHook(appMaster: ApplicationMaster) extends Runnable { + + def run() { + logInfo("AppMaster received a signal.") + // we need to clean up staging dir before HDFS is shut down + // make sure we don't delete it until this is the last AM + if (appMaster.isLastAMRetry) appMaster.cleanupStagingDir() + } + } +} + +object ApplicationMaster { + // Number of times to wait for the allocator loop to complete. + // Each loop iteration waits for 100ms, so maximum of 3 seconds. + // This is to ensure that we have reasonable number of containers before we start + // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be + // optimal as more containers are available. Might need to handle this better. + private val ALLOCATOR_LOOP_WAIT_COUNT = 30 + + private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]() + + val sparkContextRef: AtomicReference[SparkContext] = + new AtomicReference[SparkContext](null /* initialValue */) + + val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0) + + def incrementAllocatorLoop(by: Int) { + val count = yarnAllocatorLoop.getAndAdd(by) + if (count >= ALLOCATOR_LOOP_WAIT_COUNT) { + yarnAllocatorLoop.synchronized { + // to wake threads off wait ... + yarnAllocatorLoop.notifyAll() + } + } + } + + def register(master: ApplicationMaster) { + applicationMasters.add(master) + } + + // TODO(harvey): See whether this should be discarded - it isn't used anywhere atm... + def sparkContextInitialized(sc: SparkContext): Boolean = { + var modified = false + sparkContextRef.synchronized { + modified = sparkContextRef.compareAndSet(null, sc) + sparkContextRef.notifyAll() + } + + // Add a shutdown hook - as a best case effort in case users do not call sc.stop or do + // System.exit. + // Should not really have to do this, but it helps YARN to evict resources earlier. + // Not to mention, prevent the Client from declaring failure even though we exited properly. + // Note that this will unfortunately not properly clean up the staging files because it gets + // called too late, after the filesystem is already shutdown. + if (modified) { + Runtime.getRuntime().addShutdownHook(new Thread with Logging { + // This is not only logs, but also ensures that log system is initialized for this instance + // when we are actually 'run'-ing. + logInfo("Adding shutdown hook for context " + sc) + override def run() { + logInfo("Invoking sc stop from shutdown hook") + sc.stop() + // Best case ... + for (master <- applicationMasters) { + master.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED) + } + } + } ) + } + + // Wait for initialization to complete and atleast 'some' nodes can get allocated. + yarnAllocatorLoop.synchronized { + while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT) { + yarnAllocatorLoop.wait(1000L) + } + } + modified + } + + def main(argStrings: Array[String]) { + val args = new ApplicationMasterArguments(argStrings) + new ApplicationMaster(args).run() + } +} diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala new file mode 100644 index 0000000000..f76a5ddd39 --- /dev/null +++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import org.apache.spark.util.IntParam +import collection.mutable.ArrayBuffer + +class ApplicationMasterArguments(val args: Array[String]) { + var userJar: String = null + var userClass: String = null + var userArgs: Seq[String] = Seq[String]() + var workerMemory = 1024 + var workerCores = 1 + var numWorkers = 2 + + parseArgs(args.toList) + + private def parseArgs(inputArgs: List[String]): Unit = { + val userArgsBuffer = new ArrayBuffer[String]() + + var args = inputArgs + + while (! args.isEmpty) { + + args match { + case ("--jar") :: value :: tail => + userJar = value + args = tail + + case ("--class") :: value :: tail => + userClass = value + args = tail + + case ("--args") :: value :: tail => + userArgsBuffer += value + args = tail + + case ("--num-workers") :: IntParam(value) :: tail => + numWorkers = value + args = tail + + case ("--worker-memory") :: IntParam(value) :: tail => + workerMemory = value + args = tail + + case ("--worker-cores") :: IntParam(value) :: tail => + workerCores = value + args = tail + + case Nil => + if (userJar == null || userClass == null) { + printUsageAndExit(1) + } + + case _ => + printUsageAndExit(1, args) + } + } + + userArgs = userArgsBuffer.readOnly + } + + def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { + if (unknownParam != null) { + System.err.println("Unknown/unsupported param " + unknownParam) + } + System.err.println( + "Usage: org.apache.spark.deploy.yarn.ApplicationMaster [options] \n" + + "Options:\n" + + " --jar JAR_PATH Path to your application's JAR file (required)\n" + + " --class CLASS_NAME Name of your application's main class (required)\n" + + " --args ARGS Arguments to be passed to your application's main class.\n" + + " Mutliple invocations are possible, each will be passed in order.\n" + + " --num-workers NUM Number of workers to start (Default: 2)\n" + + " --worker-cores NUM Number of cores for the workers (Default: 1)\n" + + " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n") + System.exit(exitCode) + } +} diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala new file mode 100644 index 0000000000..94678815e8 --- /dev/null +++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -0,0 +1,519 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.net.{InetAddress, UnknownHostException, URI} +import java.nio.ByteBuffer + +import scala.collection.JavaConversions._ +import scala.collection.mutable.HashMap +import scala.collection.mutable.Map + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileContext, FileStatus, FileSystem, Path, FileUtil} +import org.apache.hadoop.fs.permission.FsPermission; +import org.apache.hadoop.io.DataOutputBuffer +import org.apache.hadoop.mapred.Master +import org.apache.hadoop.net.NetUtils +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment +import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.client.api.impl.YarnClientImpl +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.ipc.YarnRPC +import org.apache.hadoop.yarn.util.{Apps, Records} + +import org.apache.spark.Logging +import org.apache.spark.util.Utils +import org.apache.spark.deploy.SparkHadoopUtil + + +/** + * The entry point (starting in Client#main() and Client#run()) for launching Spark on YARN. The + * Client submits an application to the global ResourceManager to launch Spark's ApplicationMaster, + * which will launch a Spark master process and negotiate resources throughout its duration. + */ +class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl with Logging { + + var rpc: YarnRPC = YarnRPC.create(conf) + val yarnConf: YarnConfiguration = new YarnConfiguration(conf) + val credentials = UserGroupInformation.getCurrentUser().getCredentials() + private val SPARK_STAGING: String = ".sparkStaging" + private val distCacheMgr = new ClientDistributedCacheManager() + + // Staging directory is private! -> rwx-------- + val STAGING_DIR_PERMISSION: FsPermission = FsPermission.createImmutable(0700: Short) + // App files are world-wide readable and owner writable -> rw-r--r-- + val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644: Short) + + def this(args: ClientArguments) = this(new Configuration(), args) + + def runApp(): ApplicationId = { + validateArgs() + // Initialize and start the client service. + init(yarnConf) + start() + + // Log details about this YARN cluster (e.g, the number of slave machines/NodeManagers). + logClusterResourceDetails() + + // Prepare to submit a request to the ResourcManager (specifically its ApplicationsManager (ASM) + // interface). + + // Get a new client application. + val newApp = super.createApplication() + val newAppResponse = newApp.getNewApplicationResponse() + val appId = newAppResponse.getApplicationId() + + verifyClusterResources(newAppResponse) + + // Set up resource and environment variables. + val appStagingDir = getAppStagingDir(appId) + val localResources = prepareLocalResources(appStagingDir) + val launchEnv = setupLaunchEnv(localResources, appStagingDir) + val amContainer = createContainerLaunchContext(newAppResponse, localResources, launchEnv) + + // Set up an application submission context. + val appContext = newApp.getApplicationSubmissionContext() + appContext.setApplicationName(args.appName) + appContext.setQueue(args.amQueue) + appContext.setAMContainerSpec(amContainer) + + // Memory for the ApplicationMaster. + val memoryResource = Records.newRecord(classOf[Resource]).asInstanceOf[Resource] + memoryResource.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + appContext.setResource(memoryResource) + + // Finally, submit and monitor the application. + submitApp(appContext) + appId + } + + def run() { + val appId = runApp() + monitorApplication(appId) + System.exit(0) + } + + // TODO(harvey): This could just go in ClientArguments. + def validateArgs() = { + Map( + (System.getenv("SPARK_JAR") == null) -> "Error: You must set SPARK_JAR environment variable!", + (args.userJar == null) -> "Error: You must specify a user jar!", + (args.userClass == null) -> "Error: You must specify a user class!", + (args.numWorkers <= 0) -> "Error: You must specify atleast 1 worker!", + (args.amMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> ("Error: AM memory size must be" + + "greater than: " + YarnAllocationHandler.MEMORY_OVERHEAD), + (args.workerMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> ("Error: Worker memory size" + + "must be greater than: " + YarnAllocationHandler.MEMORY_OVERHEAD.toString) + ).foreach { case(cond, errStr) => + if (cond) { + logError(errStr) + args.printUsageAndExit(1) + } + } + } + + def getAppStagingDir(appId: ApplicationId): String = { + SPARK_STAGING + Path.SEPARATOR + appId.toString() + Path.SEPARATOR + } + + def logClusterResourceDetails() { + val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics + logInfo("Got Cluster metric info from ApplicationsManager (ASM), number of NodeManagers: " + + clusterMetrics.getNumNodeManagers) + + val queueInfo: QueueInfo = super.getQueueInfo(args.amQueue) + logInfo("""Queue info ... queueName: %s, queueCurrentCapacity: %s, queueMaxCapacity: %s, + queueApplicationCount = %s, queueChildQueueCount = %s""".format( + queueInfo.getQueueName, + queueInfo.getCurrentCapacity, + queueInfo.getMaximumCapacity, + queueInfo.getApplications.size, + queueInfo.getChildQueues.size)) + } + + def verifyClusterResources(app: GetNewApplicationResponse) = { + val maxMem = app.getMaximumResourceCapability().getMemory() + logInfo("Max mem capabililty of a single resource in this cluster " + maxMem) + + // If we have requested more then the clusters max for a single resource then exit. + if (args.workerMemory > maxMem) { + logError("Required worker memory (%d MB), is above the max threshold (%d MB) of this cluster.". + format(args.workerMemory, maxMem)) + System.exit(1) + } + val amMem = args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD + if (amMem > maxMem) { + logError("Required AM memory (%d) is above the max threshold (%d) of this cluster". + format(args.amMemory, maxMem)) + System.exit(1) + } + + // We could add checks to make sure the entire cluster has enough resources but that involves + // getting all the node reports and computing ourselves. + } + + /** See if two file systems are the same or not. */ + private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = { + val srcUri = srcFs.getUri() + val dstUri = destFs.getUri() + if (srcUri.getScheme() == null) { + return false + } + if (!srcUri.getScheme().equals(dstUri.getScheme())) { + return false + } + var srcHost = srcUri.getHost() + var dstHost = dstUri.getHost() + if ((srcHost != null) && (dstHost != null)) { + try { + srcHost = InetAddress.getByName(srcHost).getCanonicalHostName() + dstHost = InetAddress.getByName(dstHost).getCanonicalHostName() + } catch { + case e: UnknownHostException => + return false + } + if (!srcHost.equals(dstHost)) { + return false + } + } else if (srcHost == null && dstHost != null) { + return false + } else if (srcHost != null && dstHost == null) { + return false + } + //check for ports + if (srcUri.getPort() != dstUri.getPort()) { + return false + } + return true + } + + /** Copy the file into HDFS if needed. */ + private def copyRemoteFile( + dstDir: Path, + originalPath: Path, + replication: Short, + setPerms: Boolean = false): Path = { + val fs = FileSystem.get(conf) + val remoteFs = originalPath.getFileSystem(conf) + var newPath = originalPath + if (! compareFs(remoteFs, fs)) { + newPath = new Path(dstDir, originalPath.getName()) + logInfo("Uploading " + originalPath + " to " + newPath) + FileUtil.copy(remoteFs, originalPath, fs, newPath, false, conf) + fs.setReplication(newPath, replication) + if (setPerms) fs.setPermission(newPath, new FsPermission(APP_FILE_PERMISSION)) + } + // Resolve any symlinks in the URI path so using a "current" symlink to point to a specific + // version shows the specific version in the distributed cache configuration + val qualPath = fs.makeQualified(newPath) + val fc = FileContext.getFileContext(qualPath.toUri(), conf) + val destPath = fc.resolvePath(qualPath) + destPath + } + + def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = { + logInfo("Preparing Local resources") + // Upload Spark and the application JAR to the remote file system if necessary. Add them as + // local resources to the application master. + val fs = FileSystem.get(conf) + + val delegTokenRenewer = Master.getMasterPrincipal(conf) + if (UserGroupInformation.isSecurityEnabled()) { + if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { + logError("Can't get Master Kerberos principal for use as renewer") + System.exit(1) + } + } + val dst = new Path(fs.getHomeDirectory(), appStagingDir) + val replication = System.getProperty("spark.yarn.submit.file.replication", "3").toShort + + if (UserGroupInformation.isSecurityEnabled()) { + val dstFs = dst.getFileSystem(conf) + dstFs.addDelegationTokens(delegTokenRenewer, credentials) + } + + val localResources = HashMap[String, LocalResource]() + FileSystem.mkdirs(fs, dst, new FsPermission(STAGING_DIR_PERMISSION)) + + val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() + + Map( + Client.SPARK_JAR -> System.getenv("SPARK_JAR"), Client.APP_JAR -> args.userJar, + Client.LOG4J_PROP -> System.getenv("SPARK_LOG4J_CONF") + ).foreach { case(destName, _localPath) => + val localPath: String = if (_localPath != null) _localPath.trim() else "" + if (! localPath.isEmpty()) { + var localURI = new URI(localPath) + // If not specified assume these are in the local filesystem to keep behavior like Hadoop + if (localURI.getScheme() == null) { + localURI = new URI(FileSystem.getLocal(conf).makeQualified(new Path(localPath)).toString) + } + val setPermissions = if (destName.equals(Client.APP_JAR)) true else false + val destPath = copyRemoteFile(dst, new Path(localURI), replication, setPermissions) + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, + destName, statCache) + } + } + + // Handle jars local to the ApplicationMaster. + if ((args.addJars != null) && (!args.addJars.isEmpty())){ + args.addJars.split(',').foreach { case file: String => + val localURI = new URI(file.trim()) + val localPath = new Path(localURI) + val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) + val destPath = copyRemoteFile(dst, localPath, replication) + // Only add the resource to the Spark ApplicationMaster. + val appMasterOnly = true + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, + linkname, statCache, appMasterOnly) + } + } + + // Handle any distributed cache files + if ((args.files != null) && (!args.files.isEmpty())){ + args.files.split(',').foreach { case file: String => + val localURI = new URI(file.trim()) + val localPath = new Path(localURI) + val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) + val destPath = copyRemoteFile(dst, localPath, replication) + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, + linkname, statCache) + } + } + + // Handle any distributed cache archives + if ((args.archives != null) && (!args.archives.isEmpty())) { + args.archives.split(',').foreach { case file:String => + val localURI = new URI(file.trim()) + val localPath = new Path(localURI) + val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) + val destPath = copyRemoteFile(dst, localPath, replication) + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, + linkname, statCache) + } + } + + UserGroupInformation.getCurrentUser().addCredentials(credentials) + localResources + } + + def setupLaunchEnv( + localResources: HashMap[String, LocalResource], + stagingDir: String): HashMap[String, String] = { + logInfo("Setting up the launch environment") + val log4jConfLocalRes = localResources.getOrElse(Client.LOG4J_PROP, null) + + val env = new HashMap[String, String]() + + Client.populateClasspath(yarnConf, log4jConfLocalRes != null, env) + env("SPARK_YARN_MODE") = "true" + env("SPARK_YARN_STAGING_DIR") = stagingDir + + // Set the environment variables to be passed on to the Workers. + distCacheMgr.setDistFilesEnv(env) + distCacheMgr.setDistArchivesEnv(env) + + // Allow users to specify some environment variables. + Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV")) + + // Add each SPARK_* key to the environment. + System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v } + + env + } + + def userArgsToString(clientArgs: ClientArguments): String = { + val prefix = " --args " + val args = clientArgs.userArgs + val retval = new StringBuilder() + for (arg <- args){ + retval.append(prefix).append(" '").append(arg).append("' ") + } + retval.toString + } + + def createContainerLaunchContext( + newApp: GetNewApplicationResponse, + localResources: HashMap[String, LocalResource], + env: HashMap[String, String]): ContainerLaunchContext = { + logInfo("Setting up container launch context") + val amContainer = Records.newRecord(classOf[ContainerLaunchContext]) + amContainer.setLocalResources(localResources) + amContainer.setEnvironment(env) + + // TODO: Need a replacement for the following code to fix -Xmx? + // val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory() + // var amMemory = ((args.amMemory / minResMemory) * minResMemory) + + // ((if ((args.amMemory % minResMemory) == 0) 0 else minResMemory) - + // YarnAllocationHandler.MEMORY_OVERHEAD) + + // Extra options for the JVM + var JAVA_OPTS = "" + + // Add Xmx for AM memory + JAVA_OPTS += "-Xmx" + args.amMemory + "m" + + val tmpDir = new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + JAVA_OPTS += " -Djava.io.tmpdir=" + tmpDir + + // TODO: Remove once cpuset version is pushed out. + // The context is, default gc for server class machines ends up using all cores to do gc - + // hence if there are multiple containers in same node, Spark GC affects all other containers' + // performance (which can be that of other Spark containers) + // Instead of using this, rely on cpusets by YARN to enforce "proper" Spark behavior in + // multi-tenant environments. Not sure how default Java GC behaves if it is limited to subset + // of cores on a node. + val useConcurrentAndIncrementalGC = env.isDefinedAt("SPARK_USE_CONC_INCR_GC") && + java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC")) + if (useConcurrentAndIncrementalGC) { + // In our expts, using (default) throughput collector has severe perf ramifications in + // multi-tenant machines + JAVA_OPTS += " -XX:+UseConcMarkSweepGC " + JAVA_OPTS += " -XX:+CMSIncrementalMode " + JAVA_OPTS += " -XX:+CMSIncrementalPacing " + JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 " + JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 " + } + + if (env.isDefinedAt("SPARK_JAVA_OPTS")) { + JAVA_OPTS += " " + env("SPARK_JAVA_OPTS") + } + + // Command for the ApplicationMaster + var javaCommand = "java" + val javaHome = System.getenv("JAVA_HOME") + if ((javaHome != null && !javaHome.isEmpty()) || env.isDefinedAt("JAVA_HOME")) { + javaCommand = Environment.JAVA_HOME.$() + "/bin/java" + } + + val commands = List[String]( + javaCommand + + " -server " + + JAVA_OPTS + + " " + args.amClass + + " --class " + args.userClass + + " --jar " + args.userJar + + userArgsToString(args) + + " --worker-memory " + args.workerMemory + + " --worker-cores " + args.workerCores + + " --num-workers " + args.numWorkers + + " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" + + " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") + + logInfo("Command for starting the Spark ApplicationMaster: " + commands(0)) + amContainer.setCommands(commands) + + // Setup security tokens. + val dob = new DataOutputBuffer() + credentials.writeTokenStorageToStream(dob) + amContainer.setTokens(ByteBuffer.wrap(dob.getData())) + + amContainer + } + + def submitApp(appContext: ApplicationSubmissionContext) = { + // Submit the application to the applications manager. + logInfo("Submitting application to ASM") + super.submitApplication(appContext) + } + + def monitorApplication(appId: ApplicationId): Boolean = { + while (true) { + Thread.sleep(1000) + val report = super.getApplicationReport(appId) + + logInfo("Application report from ASM: \n" + + "\t application identifier: " + appId.toString() + "\n" + + "\t appId: " + appId.getId() + "\n" + + "\t clientToAMToken: " + report.getClientToAMToken() + "\n" + + "\t appDiagnostics: " + report.getDiagnostics() + "\n" + + "\t appMasterHost: " + report.getHost() + "\n" + + "\t appQueue: " + report.getQueue() + "\n" + + "\t appMasterRpcPort: " + report.getRpcPort() + "\n" + + "\t appStartTime: " + report.getStartTime() + "\n" + + "\t yarnAppState: " + report.getYarnApplicationState() + "\n" + + "\t distributedFinalState: " + report.getFinalApplicationStatus() + "\n" + + "\t appTrackingUrl: " + report.getTrackingUrl() + "\n" + + "\t appUser: " + report.getUser() + ) + + val state = report.getYarnApplicationState() + val dsStatus = report.getFinalApplicationStatus() + if (state == YarnApplicationState.FINISHED || + state == YarnApplicationState.FAILED || + state == YarnApplicationState.KILLED) { + return true + } + } + true + } +} + +object Client { + val SPARK_JAR: String = "spark.jar" + val APP_JAR: String = "app.jar" + val LOG4J_PROP: String = "log4j.properties" + + def main(argStrings: Array[String]) { + // Set an env variable indicating we are running in YARN mode. + // Note: anything env variable with SPARK_ prefix gets propagated to all (remote) processes - + // see Client#setupLaunchEnv(). + System.setProperty("SPARK_YARN_MODE", "true") + + val args = new ClientArguments(argStrings) + + (new Client(args)).run() + } + + // Based on code from org.apache.hadoop.mapreduce.v2.util.MRApps + def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) { + for (c <- conf.getStrings(YarnConfiguration.YARN_APPLICATION_CLASSPATH)) { + Apps.addToEnvironment(env, Environment.CLASSPATH.name, c.trim) + } + } + + def populateClasspath(conf: Configuration, addLog4j: Boolean, env: HashMap[String, String]) { + Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$()) + // If log4j present, ensure ours overrides all others + if (addLog4j) { + Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + + Path.SEPARATOR + LOG4J_PROP) + } + // Normally the users app.jar is last in case conflicts with spark jars + val userClasspathFirst = System.getProperty("spark.yarn.user.classpath.first", "false") + .toBoolean + if (userClasspathFirst) { + Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + + Path.SEPARATOR + APP_JAR) + } + Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + + Path.SEPARATOR + SPARK_JAR) + Client.populateHadoopClasspath(conf, env) + + if (!userClasspathFirst) { + Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + + Path.SEPARATOR + APP_JAR) + } + Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + + Path.SEPARATOR + "*") + } +} diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala new file mode 100644 index 0000000000..70be15d0a3 --- /dev/null +++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import scala.collection.mutable.{ArrayBuffer, HashMap} + +import org.apache.spark.scheduler.{InputFormatInfo, SplitInfo} +import org.apache.spark.util.IntParam +import org.apache.spark.util.MemoryParam + + +// TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware ! +class ClientArguments(val args: Array[String]) { + var addJars: String = null + var files: String = null + var archives: String = null + var userJar: String = null + var userClass: String = null + var userArgs: Seq[String] = Seq[String]() + var workerMemory = 1024 // MB + var workerCores = 1 + var numWorkers = 2 + var amQueue = System.getProperty("QUEUE", "default") + var amMemory: Int = 512 // MB + var amClass: String = "org.apache.spark.deploy.yarn.ApplicationMaster" + var appName: String = "Spark" + // TODO + var inputFormatInfo: List[InputFormatInfo] = null + // TODO(harvey) + var priority = 0 + + parseArgs(args.toList) + + private def parseArgs(inputArgs: List[String]): Unit = { + val userArgsBuffer: ArrayBuffer[String] = new ArrayBuffer[String]() + val inputFormatMap: HashMap[String, InputFormatInfo] = new HashMap[String, InputFormatInfo]() + + var args = inputArgs + + while (!args.isEmpty) { + args match { + case ("--jar") :: value :: tail => + userJar = value + args = tail + + case ("--class") :: value :: tail => + userClass = value + args = tail + + case ("--args") :: value :: tail => + userArgsBuffer += value + args = tail + + case ("--master-class") :: value :: tail => + amClass = value + args = tail + + case ("--master-memory") :: MemoryParam(value) :: tail => + amMemory = value + args = tail + + case ("--num-workers") :: IntParam(value) :: tail => + numWorkers = value + args = tail + + case ("--worker-memory") :: MemoryParam(value) :: tail => + workerMemory = value + args = tail + + case ("--worker-cores") :: IntParam(value) :: tail => + workerCores = value + args = tail + + case ("--queue") :: value :: tail => + amQueue = value + args = tail + + case ("--name") :: value :: tail => + appName = value + args = tail + + case ("--addJars") :: value :: tail => + addJars = value + args = tail + + case ("--files") :: value :: tail => + files = value + args = tail + + case ("--archives") :: value :: tail => + archives = value + args = tail + + case Nil => + if (userJar == null || userClass == null) { + printUsageAndExit(1) + } + + case _ => + printUsageAndExit(1, args) + } + } + + userArgs = userArgsBuffer.readOnly + inputFormatInfo = inputFormatMap.values.toList + } + + + def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { + if (unknownParam != null) { + System.err.println("Unknown/unsupported param " + unknownParam) + } + System.err.println( + "Usage: org.apache.spark.deploy.yarn.Client [options] \n" + + "Options:\n" + + " --jar JAR_PATH Path to your application's JAR file (required)\n" + + " --class CLASS_NAME Name of your application's main class (required)\n" + + " --args ARGS Arguments to be passed to your application's main class.\n" + + " Mutliple invocations are possible, each will be passed in order.\n" + + " --num-workers NUM Number of workers to start (Default: 2)\n" + + " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" + + " --master-class CLASS_NAME Class Name for Master (Default: spark.deploy.yarn.ApplicationMaster)\n" + + " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" + + " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" + + " --name NAME The name of your application (Default: Spark)\n" + + " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" + + " --addJars jars Comma separated list of local jars that want SparkContext.addJar to work with.\n" + + " --files files Comma separated list of files to be distributed with the job.\n" + + " --archives archives Comma separated list of archives to be distributed with the job." + ) + System.exit(exitCode) + } + +} diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala new file mode 100644 index 0000000000..5f159b073f --- /dev/null +++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.net.URI + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.permission.FsAction +import org.apache.hadoop.yarn.api.records.LocalResource +import org.apache.hadoop.yarn.api.records.LocalResourceVisibility +import org.apache.hadoop.yarn.api.records.LocalResourceType +import org.apache.hadoop.yarn.util.{Records, ConverterUtils} + +import org.apache.spark.Logging + +import scala.collection.mutable.HashMap +import scala.collection.mutable.LinkedHashMap +import scala.collection.mutable.Map + + +/** Client side methods to setup the Hadoop distributed cache */ +class ClientDistributedCacheManager() extends Logging { + private val distCacheFiles: Map[String, Tuple3[String, String, String]] = + LinkedHashMap[String, Tuple3[String, String, String]]() + private val distCacheArchives: Map[String, Tuple3[String, String, String]] = + LinkedHashMap[String, Tuple3[String, String, String]]() + + + /** + * Add a resource to the list of distributed cache resources. This list can + * be sent to the ApplicationMaster and possibly the workers so that it can + * be downloaded into the Hadoop distributed cache for use by this application. + * Adds the LocalResource to the localResources HashMap passed in and saves + * the stats of the resources to they can be sent to the workers and verified. + * + * @param fs FileSystem + * @param conf Configuration + * @param destPath path to the resource + * @param localResources localResource hashMap to insert the resource into + * @param resourceType LocalResourceType + * @param link link presented in the distributed cache to the destination + * @param statCache cache to store the file/directory stats + * @param appMasterOnly Whether to only add the resource to the app master + */ + def addResource( + fs: FileSystem, + conf: Configuration, + destPath: Path, + localResources: HashMap[String, LocalResource], + resourceType: LocalResourceType, + link: String, + statCache: Map[URI, FileStatus], + appMasterOnly: Boolean = false) = { + val destStatus = fs.getFileStatus(destPath) + val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] + amJarRsrc.setType(resourceType) + val visibility = getVisibility(conf, destPath.toUri(), statCache) + amJarRsrc.setVisibility(visibility) + amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(destPath)) + amJarRsrc.setTimestamp(destStatus.getModificationTime()) + amJarRsrc.setSize(destStatus.getLen()) + if (link == null || link.isEmpty()) throw new Exception("You must specify a valid link name") + localResources(link) = amJarRsrc + + if (appMasterOnly == false) { + val uri = destPath.toUri() + val pathURI = new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), null, link) + if (resourceType == LocalResourceType.FILE) { + distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(), + destStatus.getModificationTime().toString(), visibility.name()) + } else { + distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(), + destStatus.getModificationTime().toString(), visibility.name()) + } + } + } + + /** + * Adds the necessary cache file env variables to the env passed in + * @param env + */ + def setDistFilesEnv(env: Map[String, String]) = { + val (keys, tupleValues) = distCacheFiles.unzip + val (sizes, timeStamps, visibilities) = tupleValues.unzip3 + + if (keys.size > 0) { + env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = + timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = + sizes.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_VISIBILITIES") = + visibilities.reduceLeft[String] { (acc,n) => acc + "," + n } + } + } + + /** + * Adds the necessary cache archive env variables to the env passed in + * @param env + */ + def setDistArchivesEnv(env: Map[String, String]) = { + val (keys, tupleValues) = distCacheArchives.unzip + val (sizes, timeStamps, visibilities) = tupleValues.unzip3 + + if (keys.size > 0) { + env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = + timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") = + sizes.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") = + visibilities.reduceLeft[String] { (acc,n) => acc + "," + n } + } + } + + /** + * Returns the local resource visibility depending on the cache file permissions + * @param conf + * @param uri + * @param statCache + * @return LocalResourceVisibility + */ + def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): + LocalResourceVisibility = { + if (isPublic(conf, uri, statCache)) { + return LocalResourceVisibility.PUBLIC + } + return LocalResourceVisibility.PRIVATE + } + + /** + * Returns a boolean to denote whether a cache file is visible to all(public) + * or not + * @param conf + * @param uri + * @param statCache + * @return true if the path in the uri is visible to all, false otherwise + */ + def isPublic(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): Boolean = { + val fs = FileSystem.get(uri, conf) + val current = new Path(uri.getPath()) + //the leaf level file should be readable by others + if (!checkPermissionOfOther(fs, current, FsAction.READ, statCache)) { + return false + } + return ancestorsHaveExecutePermissions(fs, current.getParent(), statCache) + } + + /** + * Returns true if all ancestors of the specified path have the 'execute' + * permission set for all users (i.e. that other users can traverse + * the directory heirarchy to the given path) + * @param fs + * @param path + * @param statCache + * @return true if all ancestors have the 'execute' permission set for all users + */ + def ancestorsHaveExecutePermissions(fs: FileSystem, path: Path, + statCache: Map[URI, FileStatus]): Boolean = { + var current = path + while (current != null) { + //the subdirs in the path should have execute permissions for others + if (!checkPermissionOfOther(fs, current, FsAction.EXECUTE, statCache)) { + return false + } + current = current.getParent() + } + return true + } + + /** + * Checks for a given path whether the Other permissions on it + * imply the permission in the passed FsAction + * @param fs + * @param path + * @param action + * @param statCache + * @return true if the path in the uri is visible to all, false otherwise + */ + def checkPermissionOfOther(fs: FileSystem, path: Path, + action: FsAction, statCache: Map[URI, FileStatus]): Boolean = { + val status = getFileStatus(fs, path.toUri(), statCache) + val perms = status.getPermission() + val otherAction = perms.getOtherAction() + if (otherAction.implies(action)) { + return true + } + return false + } + + /** + * Checks to see if the given uri exists in the cache, if it does it + * returns the existing FileStatus, otherwise it stats the uri, stores + * it in the cache, and returns the FileStatus. + * @param fs + * @param uri + * @param statCache + * @return FileStatus + */ + def getFileStatus(fs: FileSystem, uri: URI, statCache: Map[URI, FileStatus]): FileStatus = { + val stat = statCache.get(uri) match { + case Some(existstat) => existstat + case None => + val newStat = fs.getFileStatus(new Path(uri)) + statCache.put(uri, newStat) + newStat + } + return stat + } +} diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala new file mode 100644 index 0000000000..bc31bb2eb0 --- /dev/null +++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.net.Socket +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.net.NetUtils +import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.util.{ConverterUtils, Records} +import akka.actor._ +import akka.remote._ +import akka.actor.Terminated +import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.scheduler.SplitInfo +import org.apache.hadoop.yarn.client.api.AMRMClient +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest + +class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration) extends Logging { + + def this(args: ApplicationMasterArguments) = this(args, new Configuration()) + + private var appAttemptId: ApplicationAttemptId = _ + private var reporterThread: Thread = _ + private val yarnConf: YarnConfiguration = new YarnConfiguration(conf) + + private var yarnAllocator: YarnAllocationHandler = _ + private var driverClosed:Boolean = false + + private var amClient: AMRMClient[ContainerRequest] = _ + + val actorSystem : ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0)._1 + var actor: ActorRef = _ + + // This actor just working as a monitor to watch on Driver Actor. + class MonitorActor(driverUrl: String) extends Actor { + + var driver: ActorSelection = null + + override def preStart() { + logInfo("Listen to driver: " + driverUrl) + driver = context.actorSelection(driverUrl) + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + } + + override def receive = { + case x: DisassociatedEvent => + logInfo("Driver terminated or disconnected! Shutting down.") + driverClosed = true + } + } + + def run() { + + amClient = AMRMClient.createAMRMClient() + amClient.init(yarnConf) + amClient.start() + + appAttemptId = getApplicationAttemptId() + registerApplicationMaster() + + waitForSparkMaster() + + // Allocate all containers + allocateWorkers() + + // Launch a progress reporter thread, else app will get killed after expiration (def: 10mins) timeout + // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse. + + val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) + // must be <= timeoutInterval/ 2. + // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM. + // so atleast 1 minute or timeoutInterval / 10 - whichever is higher. + val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L)) + reporterThread = launchReporterThread(interval) + + // Wait for the reporter thread to Finish. + reporterThread.join() + + finishApplicationMaster(FinalApplicationStatus.SUCCEEDED) + actorSystem.shutdown() + + logInfo("Exited") + System.exit(0) + } + + private def getApplicationAttemptId(): ApplicationAttemptId = { + val envs = System.getenv() + val containerIdString = envs.get(ApplicationConstants.Environment.CONTAINER_ID.name()) + val containerId = ConverterUtils.toContainerId(containerIdString) + val appAttemptId = containerId.getApplicationAttemptId() + logInfo("ApplicationAttemptId: " + appAttemptId) + appAttemptId + } + + private def registerApplicationMaster(): RegisterApplicationMasterResponse = { + logInfo("Registering the ApplicationMaster") + // TODO:(Raymond) Find out Spark UI address and fill in here? + amClient.registerApplicationMaster(Utils.localHostName(), 0, "") + } + + private def waitForSparkMaster() { + logInfo("Waiting for Spark driver to be reachable.") + var driverUp = false + val hostport = args.userArgs(0) + val (driverHost, driverPort) = Utils.parseHostPort(hostport) + while(!driverUp) { + try { + val socket = new Socket(driverHost, driverPort) + socket.close() + logInfo("Driver now available: %s:%s".format(driverHost, driverPort)) + driverUp = true + } catch { + case e: Exception => + logError("Failed to connect to driver at %s:%s, retrying ...". + format(driverHost, driverPort)) + Thread.sleep(100) + } + } + System.setProperty("spark.driver.host", driverHost) + System.setProperty("spark.driver.port", driverPort.toString) + + val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( + driverHost, driverPort.toString, CoarseGrainedSchedulerBackend.ACTOR_NAME) + + actor = actorSystem.actorOf(Props(new MonitorActor(driverUrl)), name = "YarnAM") + } + + + private def allocateWorkers() { + + // Fixme: should get preferredNodeLocationData from SparkContext, just fake a empty one for now. + val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = + scala.collection.immutable.Map() + + yarnAllocator = YarnAllocationHandler.newAllocator( + yarnConf, + amClient, + appAttemptId, + args, + preferredNodeLocationData) + + logInfo("Allocating " + args.numWorkers + " workers.") + // Wait until all containers have finished + // TODO: This is a bit ugly. Can we make it nicer? + // TODO: Handle container failure + + yarnAllocator.addResourceRequests(args.numWorkers) + while(yarnAllocator.getNumWorkersRunning < args.numWorkers) { + yarnAllocator.allocateResources() + Thread.sleep(100) + } + + logInfo("All workers have launched.") + + } + + // TODO: We might want to extend this to allocate more containers in case they die ! + private def launchReporterThread(_sleepTime: Long): Thread = { + val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime + + val t = new Thread { + override def run() { + while (!driverClosed) { + val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning - + yarnAllocator.getNumPendingAllocate + if (missingWorkerCount > 0) { + logInfo("Allocating %d containers to make up for (potentially) lost containers". + format(missingWorkerCount)) + yarnAllocator.addResourceRequests(missingWorkerCount) + } + sendProgress() + Thread.sleep(sleepTime) + } + } + } + // setting to daemon status, though this is usually not a good idea. + t.setDaemon(true) + t.start() + logInfo("Started progress reporter thread - sleep time : " + sleepTime) + t + } + + private def sendProgress() { + logDebug("Sending progress") + // simulated with an allocate request with no nodes requested ... + yarnAllocator.allocateResources() + } + + def finishApplicationMaster(status: FinalApplicationStatus) { + logInfo("finish ApplicationMaster with " + status) + amClient.unregisterApplicationMaster(status, "" /* appMessage */, "" /* appTrackingUrl */) + } + +} + + +object WorkerLauncher { + def main(argStrings: Array[String]) { + val args = new ApplicationMasterArguments(argStrings) + new WorkerLauncher(args).run() + } +} diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala new file mode 100644 index 0000000000..9f5523c4b9 --- /dev/null +++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.net.URI +import java.nio.ByteBuffer +import java.security.PrivilegedExceptionAction + +import scala.collection.JavaConversions._ +import scala.collection.mutable.HashMap + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.DataOutputBuffer +import org.apache.hadoop.net.NetUtils +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.records.impl.pb.ProtoUtils +import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.client.api.NMClient +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.ipc.YarnRPC +import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records} + +import org.apache.spark.Logging + + +class WorkerRunnable( + container: Container, + conf: Configuration, + masterAddress: String, + slaveId: String, + hostname: String, + workerMemory: Int, + workerCores: Int) + extends Runnable with Logging { + + var rpc: YarnRPC = YarnRPC.create(conf) + var nmClient: NMClient = _ + val yarnConf: YarnConfiguration = new YarnConfiguration(conf) + + def run = { + logInfo("Starting Worker Container") + nmClient = NMClient.createNMClient() + nmClient.init(yarnConf) + nmClient.start() + startContainer + } + + def startContainer = { + logInfo("Setting up ContainerLaunchContext") + + val ctx = Records.newRecord(classOf[ContainerLaunchContext]) + .asInstanceOf[ContainerLaunchContext] + + val localResources = prepareLocalResources + ctx.setLocalResources(localResources) + + val env = prepareEnvironment + ctx.setEnvironment(env) + + // Extra options for the JVM + var JAVA_OPTS = "" + // Set the JVM memory + val workerMemoryString = workerMemory + "m" + JAVA_OPTS += "-Xms" + workerMemoryString + " -Xmx" + workerMemoryString + " " + if (env.isDefinedAt("SPARK_JAVA_OPTS")) { + JAVA_OPTS += env("SPARK_JAVA_OPTS") + " " + } + + JAVA_OPTS += " -Djava.io.tmpdir=" + + new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + " " + + // Commenting it out for now - so that people can refer to the properties if required. Remove + // it once cpuset version is pushed out. + // The context is, default gc for server class machines end up using all cores to do gc - hence + // if there are multiple containers in same node, spark gc effects all other containers + // performance (which can also be other spark containers) + // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in + // multi-tenant environments. Not sure how default java gc behaves if it is limited to subset + // of cores on a node. +/* + else { + // If no java_opts specified, default to using -XX:+CMSIncrementalMode + // It might be possible that other modes/config is being done in SPARK_JAVA_OPTS, so we dont + // want to mess with it. + // In our expts, using (default) throughput collector has severe perf ramnifications in + // multi-tennent machines + // The options are based on + // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use%20the%20Concurrent%20Low%20Pause%20Collector|outline + JAVA_OPTS += " -XX:+UseConcMarkSweepGC " + JAVA_OPTS += " -XX:+CMSIncrementalMode " + JAVA_OPTS += " -XX:+CMSIncrementalPacing " + JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 " + JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 " + } +*/ + + val credentials = UserGroupInformation.getCurrentUser().getCredentials() + val dob = new DataOutputBuffer() + credentials.writeTokenStorageToStream(dob) + ctx.setTokens(ByteBuffer.wrap(dob.getData())) + + var javaCommand = "java" + val javaHome = System.getenv("JAVA_HOME") + if ((javaHome != null && !javaHome.isEmpty()) || env.isDefinedAt("JAVA_HOME")) { + javaCommand = Environment.JAVA_HOME.$() + "/bin/java" + } + + val commands = List[String](javaCommand + + " -server " + + // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling. + // Not killing the task leaves various aspects of the worker and (to some extent) the jvm in + // an inconsistent state. + // TODO: If the OOM is not recoverable by rescheduling it on different node, then do + // 'something' to fail job ... akin to blacklisting trackers in mapred ? + " -XX:OnOutOfMemoryError='kill %p' " + + JAVA_OPTS + + " org.apache.spark.executor.CoarseGrainedExecutorBackend " + + masterAddress + " " + + slaveId + " " + + hostname + " " + + workerCores + + " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" + + " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") + logInfo("Setting up worker with commands: " + commands) + ctx.setCommands(commands) + + // Send the start request to the ContainerManager + nmClient.startContainer(container, ctx) + } + + private def setupDistributedCache( + file: String, + rtype: LocalResourceType, + localResources: HashMap[String, LocalResource], + timestamp: String, + size: String, + vis: String) = { + val uri = new URI(file) + val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] + amJarRsrc.setType(rtype) + amJarRsrc.setVisibility(LocalResourceVisibility.valueOf(vis)) + amJarRsrc.setResource(ConverterUtils.getYarnUrlFromURI(uri)) + amJarRsrc.setTimestamp(timestamp.toLong) + amJarRsrc.setSize(size.toLong) + localResources(uri.getFragment()) = amJarRsrc + } + + def prepareLocalResources: HashMap[String, LocalResource] = { + logInfo("Preparing Local resources") + val localResources = HashMap[String, LocalResource]() + + if (System.getenv("SPARK_YARN_CACHE_FILES") != null) { + val timeStamps = System.getenv("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',') + val fileSizes = System.getenv("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',') + val distFiles = System.getenv("SPARK_YARN_CACHE_FILES").split(',') + val visibilities = System.getenv("SPARK_YARN_CACHE_FILES_VISIBILITIES").split(',') + for( i <- 0 to distFiles.length - 1) { + setupDistributedCache(distFiles(i), LocalResourceType.FILE, localResources, timeStamps(i), + fileSizes(i), visibilities(i)) + } + } + + if (System.getenv("SPARK_YARN_CACHE_ARCHIVES") != null) { + val timeStamps = System.getenv("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS").split(',') + val fileSizes = System.getenv("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES").split(',') + val distArchives = System.getenv("SPARK_YARN_CACHE_ARCHIVES").split(',') + val visibilities = System.getenv("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES").split(',') + for( i <- 0 to distArchives.length - 1) { + setupDistributedCache(distArchives(i), LocalResourceType.ARCHIVE, localResources, + timeStamps(i), fileSizes(i), visibilities(i)) + } + } + + logInfo("Prepared Local resources " + localResources) + localResources + } + + def prepareEnvironment: HashMap[String, String] = { + val env = new HashMap[String, String]() + + Client.populateClasspath(yarnConf, System.getenv("SPARK_YARN_LOG4J_PATH") != null, env) + + // Allow users to specify some environment variables + Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV")) + + System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v } + env + } + +} diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala new file mode 100644 index 0000000000..c27257cda4 --- /dev/null +++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -0,0 +1,687 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.lang.{Boolean => JBoolean} +import java.util.{Collections, Set => JSet} +import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap} +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} + +import org.apache.spark.Logging +import org.apache.spark.scheduler.SplitInfo +import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend} +import org.apache.spark.util.Utils + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.yarn.api.ApplicationMasterProtocol +import org.apache.hadoop.yarn.api.records.ApplicationAttemptId +import org.apache.hadoop.yarn.api.records.{Container, ContainerId, ContainerStatus} +import org.apache.hadoop.yarn.api.records.{Priority, Resource, ResourceRequest} +import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse} +import org.apache.hadoop.yarn.client.api.AMRMClient +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest +import org.apache.hadoop.yarn.util.{RackResolver, Records} + + +object AllocationType extends Enumeration ("HOST", "RACK", "ANY") { + type AllocationType = Value + val HOST, RACK, ANY = Value +} + +// TODO: +// Too many params. +// Needs to be mt-safe +// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive - should +// make it more proactive and decoupled. + +// Note that right now, we assume all node asks as uniform in terms of capabilities and priority +// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for +// more info on how we are requesting for containers. +private[yarn] class YarnAllocationHandler( + val conf: Configuration, + val amClient: AMRMClient[ContainerRequest], + val appAttemptId: ApplicationAttemptId, + val maxWorkers: Int, + val workerMemory: Int, + val workerCores: Int, + val preferredHostToCount: Map[String, Int], + val preferredRackToCount: Map[String, Int]) + extends Logging { + // These three are locked on allocatedHostToContainersMap. Complementary data structures + // allocatedHostToContainersMap : containers which are running : host, Set + // allocatedContainerToHostMap: container to host mapping. + private val allocatedHostToContainersMap = + new HashMap[String, collection.mutable.Set[ContainerId]]() + + private val allocatedContainerToHostMap = new HashMap[ContainerId, String]() + + // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an + // allocated node) + // As with the two data structures above, tightly coupled with them, and to be locked on + // allocatedHostToContainersMap + private val allocatedRackCount = new HashMap[String, Int]() + + // Containers which have been released. + private val releasedContainerList = new CopyOnWriteArrayList[ContainerId]() + // Containers to be released in next request to RM + private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean] + + // Number of container requests that have been sent to, but not yet allocated by the + // ApplicationMaster. + private val numPendingAllocate = new AtomicInteger() + private val numWorkersRunning = new AtomicInteger() + // Used to generate a unique id per worker + private val workerIdCounter = new AtomicInteger() + private val lastResponseId = new AtomicInteger() + private val numWorkersFailed = new AtomicInteger() + + def getNumPendingAllocate: Int = numPendingAllocate.intValue + + def getNumWorkersRunning: Int = numWorkersRunning.intValue + + def getNumWorkersFailed: Int = numWorkersFailed.intValue + + def isResourceConstraintSatisfied(container: Container): Boolean = { + container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + } + + def releaseContainer(container: Container) { + val containerId = container.getId + pendingReleaseContainers.put(containerId, true) + amClient.releaseAssignedContainer(containerId) + } + + def allocateResources() { + // We have already set the container request. Poll the ResourceManager for a response. + // This doubles as a heartbeat if there are no pending container requests. + val progressIndicator = 0.1f + val allocateResponse = amClient.allocate(progressIndicator) + + val allocatedContainers = allocateResponse.getAllocatedContainers() + if (allocatedContainers.size > 0) { + var numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * allocatedContainers.size) + + if (numPendingAllocateNow < 0) { + numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * numPendingAllocateNow) + } + + logDebug(""" + Allocated containers: %d + Current worker count: %d + Containers released: %s + Containers to-be-released: %s + Cluster resources: %s + """.format( + allocatedContainers.size, + numWorkersRunning.get(), + releasedContainerList, + pendingReleaseContainers, + allocateResponse.getAvailableResources)) + + val hostToContainers = new HashMap[String, ArrayBuffer[Container]]() + + for (container <- allocatedContainers) { + if (isResourceConstraintSatisfied(container)) { + // Add the accepted `container` to the host's list of already accepted, + // allocated containers + val host = container.getNodeId.getHost + val containersForHost = hostToContainers.getOrElseUpdate(host, + new ArrayBuffer[Container]()) + containersForHost += container + } else { + // Release container, since it doesn't satisfy resource constraints. + releaseContainer(container) + } + } + + // Find the appropriate containers to use. + // TODO: Cleanup this group-by... + val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]() + val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]() + val offRackContainers = new HashMap[String, ArrayBuffer[Container]]() + + for (candidateHost <- hostToContainers.keySet) { + val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0) + val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost) + + val remainingContainersOpt = hostToContainers.get(candidateHost) + assert(remainingContainersOpt.isDefined) + var remainingContainers = remainingContainersOpt.get + + if (requiredHostCount >= remainingContainers.size) { + // Since we have <= required containers, add all remaining containers to + // `dataLocalContainers`. + dataLocalContainers.put(candidateHost, remainingContainers) + // There are no more free containers remaining. + remainingContainers = null + } else if (requiredHostCount > 0) { + // Container list has more containers than we need for data locality. + // Split the list into two: one based on the data local container count, + // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining + // containers. + val (dataLocal, remaining) = remainingContainers.splitAt( + remainingContainers.size - requiredHostCount) + dataLocalContainers.put(candidateHost, dataLocal) + + // Invariant: remainingContainers == remaining + + // YARN has a nasty habit of allocating a ton of containers on a host - discourage this. + // Add each container in `remaining` to list of containers to release. If we have an + // insufficient number of containers, then the next allocation cycle will reallocate + // (but won't treat it as data local). + // TODO(harvey): Rephrase this comment some more. + for (container <- remaining) releaseContainer(container) + remainingContainers = null + } + + // For rack local containers + if (remainingContainers != null) { + val rack = YarnAllocationHandler.lookupRack(conf, candidateHost) + if (rack != null) { + val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0) + val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) - + rackLocalContainers.getOrElse(rack, List()).size + + if (requiredRackCount >= remainingContainers.size) { + // Add all remaining containers to to `dataLocalContainers`. + dataLocalContainers.put(rack, remainingContainers) + remainingContainers = null + } else if (requiredRackCount > 0) { + // Container list has more containers that we need for data locality. + // Split the list into two: one based on the data local container count, + // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining + // containers. + val (rackLocal, remaining) = remainingContainers.splitAt( + remainingContainers.size - requiredRackCount) + val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, + new ArrayBuffer[Container]()) + + existingRackLocal ++= rackLocal + + remainingContainers = remaining + } + } + } + + if (remainingContainers != null) { + // Not all containers have been consumed - add them to the list of off-rack containers. + offRackContainers.put(candidateHost, remainingContainers) + } + } + + // Now that we have split the containers into various groups, go through them in order: + // first host-local, then rack-local, and finally off-rack. + // Note that the list we create below tries to ensure that not all containers end up within + // a host if there is a sufficiently large number of hosts/containers. + val allocatedContainersToProcess = new ArrayBuffer[Container](allocatedContainers.size) + allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(dataLocalContainers) + allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(rackLocalContainers) + allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(offRackContainers) + + // Run each of the allocated containers. + for (container <- allocatedContainersToProcess) { + val numWorkersRunningNow = numWorkersRunning.incrementAndGet() + val workerHostname = container.getNodeId.getHost + val containerId = container.getId + + val workerMemoryOverhead = (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + assert(container.getResource.getMemory >= workerMemoryOverhead) + + if (numWorkersRunningNow > maxWorkers) { + logInfo("""Ignoring container %s at host %s, since we already have the required number of + containers for it.""".format(containerId, workerHostname)) + releaseContainer(container) + numWorkersRunning.decrementAndGet() + } else { + val workerId = workerIdCounter.incrementAndGet().toString + val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( + System.getProperty("spark.driver.host"), + System.getProperty("spark.driver.port"), + CoarseGrainedSchedulerBackend.ACTOR_NAME) + + logInfo("Launching container %s for on host %s".format(containerId, workerHostname)) + + // To be safe, remove the container from `pendingReleaseContainers`. + pendingReleaseContainers.remove(containerId) + + val rack = YarnAllocationHandler.lookupRack(conf, workerHostname) + allocatedHostToContainersMap.synchronized { + val containerSet = allocatedHostToContainersMap.getOrElseUpdate(workerHostname, + new HashSet[ContainerId]()) + + containerSet += containerId + allocatedContainerToHostMap.put(containerId, workerHostname) + + if (rack != null) { + allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1) + } + } + logInfo("Launching WorkerRunnable. driverUrl: %s, workerHostname: %s".format(driverUrl, workerHostname)) + val workerRunnable = new WorkerRunnable( + container, + conf, + driverUrl, + workerId, + workerHostname, + workerMemory, + workerCores) + new Thread(workerRunnable).start() + } + } + logDebug(""" + Finished allocating %s containers (from %s originally). + Current number of workers running: %d, + releasedContainerList: %s, + pendingReleaseContainers: %s + """.format( + allocatedContainersToProcess, + allocatedContainers, + numWorkersRunning.get(), + releasedContainerList, + pendingReleaseContainers)) + } + + val completedContainers = allocateResponse.getCompletedContainersStatuses() + if (completedContainers.size > 0) { + logDebug("Completed %d containers".format(completedContainers.size)) + + for (completedContainer <- completedContainers) { + val containerId = completedContainer.getContainerId + + if (pendingReleaseContainers.containsKey(containerId)) { + // YarnAllocationHandler already marked the container for release, so remove it from + // `pendingReleaseContainers`. + pendingReleaseContainers.remove(containerId) + } else { + // Decrement the number of workers running. The next iteration of the ApplicationMaster's + // reporting thread will take care of allocating. + numWorkersRunning.decrementAndGet() + logInfo("Completed container %s (state: %s, exit status: %s)".format( + containerId, + completedContainer.getState, + completedContainer.getExitStatus())) + // Hadoop 2.2.X added a ContainerExitStatus we should switch to use + // there are some exit status' we shouldn't necessarily count against us, but for + // now I think its ok as none of the containers are expected to exit + if (completedContainer.getExitStatus() != 0) { + logInfo("Container marked as failed: " + containerId) + numWorkersFailed.incrementAndGet() + } + } + + allocatedHostToContainersMap.synchronized { + if (allocatedContainerToHostMap.containsKey(containerId)) { + val hostOpt = allocatedContainerToHostMap.get(containerId) + assert(hostOpt.isDefined) + val host = hostOpt.get + + val containerSetOpt = allocatedHostToContainersMap.get(host) + assert(containerSetOpt.isDefined) + val containerSet = containerSetOpt.get + + containerSet.remove(containerId) + if (containerSet.isEmpty) { + allocatedHostToContainersMap.remove(host) + } else { + allocatedHostToContainersMap.update(host, containerSet) + } + + allocatedContainerToHostMap.remove(containerId) + + // TODO: Move this part outside the synchronized block? + val rack = YarnAllocationHandler.lookupRack(conf, host) + if (rack != null) { + val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1 + if (rackCount > 0) { + allocatedRackCount.put(rack, rackCount) + } else { + allocatedRackCount.remove(rack) + } + } + } + } + } + logDebug(""" + Finished processing %d completed containers. + Current number of workers running: %d, + releasedContainerList: %s, + pendingReleaseContainers: %s + """.format( + completedContainers.size, + numWorkersRunning.get(), + releasedContainerList, + pendingReleaseContainers)) + } + } + + def createRackResourceRequests( + hostContainers: ArrayBuffer[ContainerRequest] + ): ArrayBuffer[ContainerRequest] = { + // Generate modified racks and new set of hosts under it before issuing requests. + val rackToCounts = new HashMap[String, Int]() + + for (container <- hostContainers) { + val candidateHost = container.getNodes.last + assert(YarnAllocationHandler.ANY_HOST != candidateHost) + + val rack = YarnAllocationHandler.lookupRack(conf, candidateHost) + if (rack != null) { + var count = rackToCounts.getOrElse(rack, 0) + count += 1 + rackToCounts.put(rack, count) + } + } + + val requestedContainers = new ArrayBuffer[ContainerRequest](rackToCounts.size) + for ((rack, count) <- rackToCounts) { + requestedContainers ++= createResourceRequests( + AllocationType.RACK, + rack, + count, + YarnAllocationHandler.PRIORITY) + } + + requestedContainers + } + + def allocatedContainersOnHost(host: String): Int = { + var retval = 0 + allocatedHostToContainersMap.synchronized { + retval = allocatedHostToContainersMap.getOrElse(host, Set()).size + } + retval + } + + def allocatedContainersOnRack(rack: String): Int = { + var retval = 0 + allocatedHostToContainersMap.synchronized { + retval = allocatedRackCount.getOrElse(rack, 0) + } + retval + } + + def addResourceRequests(numWorkers: Int) { + val containerRequests: List[ContainerRequest] = + if (numWorkers <= 0 || preferredHostToCount.isEmpty) { + logDebug("numWorkers: " + numWorkers + ", host preferences: " + + preferredHostToCount.isEmpty) + createResourceRequests( + AllocationType.ANY, + resource = null, + numWorkers, + YarnAllocationHandler.PRIORITY).toList + } else { + // Request for all hosts in preferred nodes and for numWorkers - + // candidates.size, request by default allocation policy. + val hostContainerRequests = new ArrayBuffer[ContainerRequest](preferredHostToCount.size) + for ((candidateHost, candidateCount) <- preferredHostToCount) { + val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost) + + if (requiredCount > 0) { + hostContainerRequests ++= createResourceRequests( + AllocationType.HOST, + candidateHost, + requiredCount, + YarnAllocationHandler.PRIORITY) + } + } + val rackContainerRequests: List[ContainerRequest] = createRackResourceRequests( + hostContainerRequests).toList + + val anyContainerRequests = createResourceRequests( + AllocationType.ANY, + resource = null, + numWorkers, + YarnAllocationHandler.PRIORITY) + + val containerRequestBuffer = new ArrayBuffer[ContainerRequest]( + hostContainerRequests.size + rackContainerRequests.size() + anyContainerRequests.size) + + containerRequestBuffer ++= hostContainerRequests + containerRequestBuffer ++= rackContainerRequests + containerRequestBuffer ++= anyContainerRequests + containerRequestBuffer.toList + } + + for (request <- containerRequests) { + amClient.addContainerRequest(request) + } + + if (numWorkers > 0) { + numPendingAllocate.addAndGet(numWorkers) + logInfo("Will Allocate %d worker containers, each with %d memory".format( + numWorkers, + (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD))) + } else { + logDebug("Empty allocation request ...") + } + + for (request <- containerRequests) { + val nodes = request.getNodes + var hostStr = if (nodes == null || nodes.isEmpty) { + "Any" + } else { + nodes.last + } + logInfo("Container request (host: %s, priority: %s, capability: %s".format( + hostStr, + request.getPriority().getPriority, + request.getCapability)) + } + } + + private def createResourceRequests( + requestType: AllocationType.AllocationType, + resource: String, + numWorkers: Int, + priority: Int + ): ArrayBuffer[ContainerRequest] = { + + // If hostname is specified, then we need at least two requests - node local and rack local. + // There must be a third request, which is ANY. That will be specially handled. + requestType match { + case AllocationType.HOST => { + assert(YarnAllocationHandler.ANY_HOST != resource) + val hostname = resource + val nodeLocal = constructContainerRequests( + Array(hostname), + racks = null, + numWorkers, + priority) + + // Add `hostname` to the global (singleton) host->rack mapping in YarnAllocationHandler. + YarnAllocationHandler.populateRackInfo(conf, hostname) + nodeLocal + } + case AllocationType.RACK => { + val rack = resource + constructContainerRequests(hosts = null, Array(rack), numWorkers, priority) + } + case AllocationType.ANY => constructContainerRequests( + hosts = null, racks = null, numWorkers, priority) + case _ => throw new IllegalArgumentException( + "Unexpected/unsupported request type: " + requestType) + } + } + + private def constructContainerRequests( + hosts: Array[String], + racks: Array[String], + numWorkers: Int, + priority: Int + ): ArrayBuffer[ContainerRequest] = { + + val memoryResource = Records.newRecord(classOf[Resource]) + memoryResource.setMemory(workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + + val prioritySetting = Records.newRecord(classOf[Priority]) + prioritySetting.setPriority(priority) + + val requests = new ArrayBuffer[ContainerRequest]() + for (i <- 0 until numWorkers) { + requests += new ContainerRequest(memoryResource, hosts, racks, prioritySetting) + } + requests + } +} + +object YarnAllocationHandler { + + val ANY_HOST = "*" + // All requests are issued with same priority : we do not (yet) have any distinction between + // request types (like map/reduce in hadoop for example) + val PRIORITY = 1 + + // Additional memory overhead - in mb. + val MEMORY_OVERHEAD = 384 + + // Host to rack map - saved from allocation requests. We are expecting this not to change. + // Note that it is possible for this to change : and ResurceManager will indicate that to us via + // update response to allocate. But we are punting on handling that for now. + private val hostToRack = new ConcurrentHashMap[String, String]() + private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]() + + + def newAllocator( + conf: Configuration, + amClient: AMRMClient[ContainerRequest], + appAttemptId: ApplicationAttemptId, + args: ApplicationMasterArguments + ): YarnAllocationHandler = { + new YarnAllocationHandler( + conf, + amClient, + appAttemptId, + args.numWorkers, + args.workerMemory, + args.workerCores, + Map[String, Int](), + Map[String, Int]()) + } + + def newAllocator( + conf: Configuration, + amClient: AMRMClient[ContainerRequest], + appAttemptId: ApplicationAttemptId, + args: ApplicationMasterArguments, + map: collection.Map[String, + collection.Set[SplitInfo]] + ): YarnAllocationHandler = { + val (hostToSplitCount, rackToSplitCount) = generateNodeToWeight(conf, map) + new YarnAllocationHandler( + conf, + amClient, + appAttemptId, + args.numWorkers, + args.workerMemory, + args.workerCores, + hostToSplitCount, + rackToSplitCount) + } + + def newAllocator( + conf: Configuration, + amClient: AMRMClient[ContainerRequest], + appAttemptId: ApplicationAttemptId, + maxWorkers: Int, + workerMemory: Int, + workerCores: Int, + map: collection.Map[String, collection.Set[SplitInfo]] + ): YarnAllocationHandler = { + val (hostToCount, rackToCount) = generateNodeToWeight(conf, map) + new YarnAllocationHandler( + conf, + amClient, + appAttemptId, + maxWorkers, + workerMemory, + workerCores, + hostToCount, + rackToCount) + } + + // A simple method to copy the split info map. + private def generateNodeToWeight( + conf: Configuration, + input: collection.Map[String, collection.Set[SplitInfo]] + ): (Map[String, Int], Map[String, Int]) = { + + if (input == null) { + return (Map[String, Int](), Map[String, Int]()) + } + + val hostToCount = new HashMap[String, Int] + val rackToCount = new HashMap[String, Int] + + for ((host, splits) <- input) { + val hostCount = hostToCount.getOrElse(host, 0) + hostToCount.put(host, hostCount + splits.size) + + val rack = lookupRack(conf, host) + if (rack != null){ + val rackCount = rackToCount.getOrElse(host, 0) + rackToCount.put(host, rackCount + splits.size) + } + } + + (hostToCount.toMap, rackToCount.toMap) + } + + def lookupRack(conf: Configuration, host: String): String = { + if (!hostToRack.contains(host)) { + populateRackInfo(conf, host) + } + hostToRack.get(host) + } + + def fetchCachedHostsForRack(rack: String): Option[Set[String]] = { + Option(rackToHostSet.get(rack)).map { set => + val convertedSet: collection.mutable.Set[String] = set + // TODO: Better way to get a Set[String] from JSet. + convertedSet.toSet + } + } + + def populateRackInfo(conf: Configuration, hostname: String) { + Utils.checkHost(hostname) + + if (!hostToRack.containsKey(hostname)) { + // If there are repeated failures to resolve, all to an ignore list. + val rackInfo = RackResolver.resolve(conf, hostname) + if (rackInfo != null && rackInfo.getNetworkLocation != null) { + val rack = rackInfo.getNetworkLocation + hostToRack.put(hostname, rack) + if (! rackToHostSet.containsKey(rack)) { + rackToHostSet.putIfAbsent(rack, + Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]())) + } + rackToHostSet.get(rack).add(hostname) + + // TODO(harvey): Figure out what this comment means... + // Since RackResolver caches, we are disabling this for now ... + } /* else { + // right ? Else we will keep calling rack resolver in case we cant resolve rack info ... + hostToRack.put(hostname, null) + } */ + } + } +} diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala new file mode 100644 index 0000000000..2ba2366ead --- /dev/null +++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.conf.Configuration + +/** + * Contains util methods to interact with Hadoop from spark. + */ +class YarnSparkHadoopUtil extends SparkHadoopUtil { + + // Note that all params which start with SPARK are propagated all the way through, so if in yarn mode, this MUST be set to true. + override def isYarnMode(): Boolean = { true } + + // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop subsystems + // Always create a new config, dont reuse yarnConf. + override def newConfiguration(): Configuration = new YarnConfiguration(new Configuration()) + + // add any user credentials to the job conf which are necessary for running on a secure Hadoop cluster + override def addCredentials(conf: JobConf) { + val jobCreds = conf.getCredentials() + jobCreds.mergeAll(UserGroupInformation.getCurrentUser().getCredentials()) + } +} diff --git a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala new file mode 100644 index 0000000000..63a0449e5a --- /dev/null +++ b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.spark._ +import org.apache.hadoop.conf.Configuration +import org.apache.spark.deploy.yarn.YarnAllocationHandler +import org.apache.spark.util.Utils + +/** + * + * This scheduler launch worker through Yarn - by call into Client to launch WorkerLauncher as AM. + */ +private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) { + + def this(sc: SparkContext) = this(sc, new Configuration()) + + // By default, rack is unknown + override def getRackForHost(hostPort: String): Option[String] = { + val host = Utils.parseHostPort(hostPort)._1 + val retval = YarnAllocationHandler.lookupRack(conf, host) + if (retval != null) Some(retval) else None + } + + override def postStartHook() { + + // The yarn application is running, but the worker might not yet ready + // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt + Thread.sleep(2000L) + logInfo("YarnClientClusterScheduler.postStartHook done") + } +} diff --git a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala new file mode 100644 index 0000000000..b206780c78 --- /dev/null +++ b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} +import org.apache.spark.{SparkException, Logging, SparkContext} +import org.apache.spark.deploy.yarn.{Client, ClientArguments} + +private[spark] class YarnClientSchedulerBackend( + scheduler: ClusterScheduler, + sc: SparkContext) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + with Logging { + + var client: Client = null + var appId: ApplicationId = null + + override def start() { + super.start() + + val defalutWorkerCores = "2" + val defalutWorkerMemory = "512m" + val defaultWorkerNumber = "1" + + val userJar = System.getenv("SPARK_YARN_APP_JAR") + var workerCores = System.getenv("SPARK_WORKER_CORES") + var workerMemory = System.getenv("SPARK_WORKER_MEMORY") + var workerNumber = System.getenv("SPARK_WORKER_INSTANCES") + + if (userJar == null) + throw new SparkException("env SPARK_YARN_APP_JAR is not set") + + if (workerCores == null) + workerCores = defalutWorkerCores + if (workerMemory == null) + workerMemory = defalutWorkerMemory + if (workerNumber == null) + workerNumber = defaultWorkerNumber + + val driverHost = System.getProperty("spark.driver.host") + val driverPort = System.getProperty("spark.driver.port") + val hostport = driverHost + ":" + driverPort + + val argsArray = Array[String]( + "--class", "notused", + "--jar", userJar, + "--args", hostport, + "--worker-memory", workerMemory, + "--worker-cores", workerCores, + "--num-workers", workerNumber, + "--master-class", "org.apache.spark.deploy.yarn.WorkerLauncher" + ) + + val args = new ClientArguments(argsArray) + client = new Client(args) + appId = client.runApp() + waitForApp() + } + + def waitForApp() { + + // TODO : need a better way to find out whether the workers are ready or not + // maybe by resource usage report? + while(true) { + val report = client.getApplicationReport(appId) + + logInfo("Application report from ASM: \n" + + "\t appMasterRpcPort: " + report.getRpcPort() + "\n" + + "\t appStartTime: " + report.getStartTime() + "\n" + + "\t yarnAppState: " + report.getYarnApplicationState() + "\n" + ) + + // Ready to go, or already gone. + val state = report.getYarnApplicationState() + if (state == YarnApplicationState.RUNNING) { + return + } else if (state == YarnApplicationState.FINISHED || + state == YarnApplicationState.FAILED || + state == YarnApplicationState.KILLED) { + throw new SparkException("Yarn application already ended," + + "might be killed or not able to launch application master.") + } + + Thread.sleep(1000) + } + } + + override def stop() { + super.stop() + client.stop() + logInfo("Stoped") + } + +} diff --git a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala new file mode 100644 index 0000000000..29b3f22e13 --- /dev/null +++ b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.spark._ +import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler} +import org.apache.spark.util.Utils +import org.apache.hadoop.conf.Configuration + +/** + * + * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of ApplicationMaster, etc is done + */ +private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) { + + logInfo("Created YarnClusterScheduler") + + def this(sc: SparkContext) = this(sc, new Configuration()) + + // Nothing else for now ... initialize application master : which needs sparkContext to determine how to allocate + // Note that only the first creation of SparkContext influences (and ideally, there must be only one SparkContext, right ?) + // Subsequent creations are ignored - since nodes are already allocated by then. + + + // By default, rack is unknown + override def getRackForHost(hostPort: String): Option[String] = { + val host = Utils.parseHostPort(hostPort)._1 + val retval = YarnAllocationHandler.lookupRack(conf, host) + if (retval != null) Some(retval) else None + } + + override def postStartHook() { + val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc) + if (sparkContextInitialized){ + // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt + Thread.sleep(3000L) + } + logInfo("YarnClusterScheduler.postStartHook done") + } +} diff --git a/new-yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/new-yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala new file mode 100644 index 0000000000..2941356bc5 --- /dev/null +++ b/new-yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.net.URI + +import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar +import org.mockito.Mockito.when + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.permission.FsAction +import org.apache.hadoop.yarn.api.records.LocalResource +import org.apache.hadoop.yarn.api.records.LocalResourceVisibility +import org.apache.hadoop.yarn.api.records.LocalResourceType +import org.apache.hadoop.yarn.util.{Records, ConverterUtils} + +import scala.collection.mutable.HashMap +import scala.collection.mutable.Map + + +class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { + + class MockClientDistributedCacheManager extends ClientDistributedCacheManager { + override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): + LocalResourceVisibility = { + return LocalResourceVisibility.PRIVATE + } + } + + test("test getFileStatus empty") { + val distMgr = new ClientDistributedCacheManager() + val fs = mock[FileSystem] + val uri = new URI("/tmp/testing") + when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus()) + val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() + val stat = distMgr.getFileStatus(fs, uri, statCache) + assert(stat.getPath() === null) + } + + test("test getFileStatus cached") { + val distMgr = new ClientDistributedCacheManager() + val fs = mock[FileSystem] + val uri = new URI("/tmp/testing") + val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner", + null, new Path("/tmp/testing")) + when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus()) + val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus](uri -> realFileStatus) + val stat = distMgr.getFileStatus(fs, uri, statCache) + assert(stat.getPath().toString() === "/tmp/testing") + } + + test("test addResource") { + val distMgr = new MockClientDistributedCacheManager() + val fs = mock[FileSystem] + val conf = new Configuration() + val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") + val localResources = HashMap[String, LocalResource]() + val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() + when(fs.getFileStatus(destPath)).thenReturn(new FileStatus()) + + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link", + statCache, false) + val resource = localResources("link") + assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) + assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath) + assert(resource.getTimestamp() === 0) + assert(resource.getSize() === 0) + assert(resource.getType() === LocalResourceType.FILE) + + val env = new HashMap[String, String]() + distMgr.setDistFilesEnv(env) + assert(env("SPARK_YARN_CACHE_FILES") === "file:/foo.invalid.com:8080/tmp/testing#link") + assert(env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === "0") + assert(env("SPARK_YARN_CACHE_FILES_FILE_SIZES") === "0") + assert(env("SPARK_YARN_CACHE_FILES_VISIBILITIES") === LocalResourceVisibility.PRIVATE.name()) + + distMgr.setDistArchivesEnv(env) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES") === None) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === None) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === None) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None) + + //add another one and verify both there and order correct + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + null, new Path("/tmp/testing2")) + val destPath2 = new Path("file:///foo.invalid.com:8080/tmp/testing2") + when(fs.getFileStatus(destPath2)).thenReturn(realFileStatus) + distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2", + statCache, false) + val resource2 = localResources("link2") + assert(resource2.getVisibility() === LocalResourceVisibility.PRIVATE) + assert(ConverterUtils.getPathFromYarnURL(resource2.getResource()) === destPath2) + assert(resource2.getTimestamp() === 10) + assert(resource2.getSize() === 20) + assert(resource2.getType() === LocalResourceType.FILE) + + val env2 = new HashMap[String, String]() + distMgr.setDistFilesEnv(env2) + val timestamps = env2("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',') + val files = env2("SPARK_YARN_CACHE_FILES").split(',') + val sizes = env2("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',') + val visibilities = env2("SPARK_YARN_CACHE_FILES_VISIBILITIES") .split(',') + assert(files(0) === "file:/foo.invalid.com:8080/tmp/testing#link") + assert(timestamps(0) === "0") + assert(sizes(0) === "0") + assert(visibilities(0) === LocalResourceVisibility.PRIVATE.name()) + + assert(files(1) === "file:/foo.invalid.com:8080/tmp/testing2#link2") + assert(timestamps(1) === "10") + assert(sizes(1) === "20") + assert(visibilities(1) === LocalResourceVisibility.PRIVATE.name()) + } + + test("test addResource link null") { + val distMgr = new MockClientDistributedCacheManager() + val fs = mock[FileSystem] + val conf = new Configuration() + val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") + val localResources = HashMap[String, LocalResource]() + val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() + when(fs.getFileStatus(destPath)).thenReturn(new FileStatus()) + + intercept[Exception] { + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null, + statCache, false) + } + assert(localResources.get("link") === None) + assert(localResources.size === 0) + } + + test("test addResource appmaster only") { + val distMgr = new MockClientDistributedCacheManager() + val fs = mock[FileSystem] + val conf = new Configuration() + val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") + val localResources = HashMap[String, LocalResource]() + val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + null, new Path("/tmp/testing")) + when(fs.getFileStatus(destPath)).thenReturn(realFileStatus) + + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", + statCache, true) + val resource = localResources("link") + assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) + assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath) + assert(resource.getTimestamp() === 10) + assert(resource.getSize() === 20) + assert(resource.getType() === LocalResourceType.ARCHIVE) + + val env = new HashMap[String, String]() + distMgr.setDistFilesEnv(env) + assert(env.get("SPARK_YARN_CACHE_FILES") === None) + assert(env.get("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === None) + assert(env.get("SPARK_YARN_CACHE_FILES_FILE_SIZES") === None) + assert(env.get("SPARK_YARN_CACHE_FILES_VISIBILITIES") === None) + + distMgr.setDistArchivesEnv(env) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES") === None) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === None) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === None) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None) + } + + test("test addResource archive") { + val distMgr = new MockClientDistributedCacheManager() + val fs = mock[FileSystem] + val conf = new Configuration() + val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") + val localResources = HashMap[String, LocalResource]() + val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + null, new Path("/tmp/testing")) + when(fs.getFileStatus(destPath)).thenReturn(realFileStatus) + + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", + statCache, false) + val resource = localResources("link") + assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) + assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath) + assert(resource.getTimestamp() === 10) + assert(resource.getSize() === 20) + assert(resource.getType() === LocalResourceType.ARCHIVE) + + val env = new HashMap[String, String]() + + distMgr.setDistArchivesEnv(env) + assert(env("SPARK_YARN_CACHE_ARCHIVES") === "file:/foo.invalid.com:8080/tmp/testing#link") + assert(env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === "10") + assert(env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === "20") + assert(env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === LocalResourceVisibility.PRIVATE.name()) + + distMgr.setDistFilesEnv(env) + assert(env.get("SPARK_YARN_CACHE_FILES") === None) + assert(env.get("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === None) + assert(env.get("SPARK_YARN_CACHE_FILES_FILE_SIZES") === None) + assert(env.get("SPARK_YARN_CACHE_FILES_VISIBILITIES") === None) + } + + +} diff --git a/pom.xml b/pom.xml index edcc3b35cd..57e843596f 100644 --- a/pom.xml +++ b/pom.xml @@ -99,13 +99,17 @@ UTF-8 UTF-8 - 1.5 - 2.9.3 + 1.6 + + 2.10.3 + 2.10 0.13.0 - 2.0.5 + org.spark-project.akka + 2.2.3-shaded-protobuf 1.7.2 1.2.17 1.0.4 + 2.4.1 0.23.7 0.94.6 @@ -114,32 +118,21 @@ + + maven-repo + Maven Repository + http://repo.maven.apache.org/maven2 + + true + + + false + + jboss-repo JBoss Repository - http://repository.jboss.org/nexus/content/repositories/releases/ - - true - - - false - - - - cloudera-repo - Cloudera Repository - https://repository.cloudera.com/artifactory/cloudera-repos/ - - true - - - false - - - - akka-repo - Akka Repository - http://repo.akka.io/releases/ + http://repository.jboss.org/nexus/content/repositories/releases true @@ -150,7 +143,7 @@ mqtt-repo MQTT Repository - https://repo.eclipse.org/content/repositories/paho-releases/ + https://repo.eclipse.org/content/repositories/paho-releases true @@ -159,41 +152,6 @@ - - - oss-sonatype-releases - OSS Sonatype - https://oss.sonatype.org/content/repositories/releases - - true - - - false - - - - oss-sonatype-snapshots - OSS Sonatype - https://oss.sonatype.org/content/repositories/snapshots - - false - - - true - - - - oss-sonatype - OSS Sonatype - https://oss.sonatype.org/content/groups/public - - true - - - true - - - @@ -242,14 +200,19 @@ asm 4.0 + com.google.protobuf protobuf-java - 2.4.1 + ${protobuf.version} com.twitter - chill_2.9.3 + chill_${scala.binary.version} 0.3.1 @@ -258,8 +221,8 @@ 0.3.1 - com.typesafe.akka - akka-actor + ${akka.group} + akka-actor_${scala.binary.version} ${akka.version} @@ -269,8 +232,8 @@ - com.typesafe.akka - akka-remote + ${akka.group} + akka-remote_${scala.binary.version} ${akka.version} @@ -280,8 +243,19 @@ - com.typesafe.akka - akka-slf4j + ${akka.group} + akka-slf4j_${scala.binary.version} + ${akka.version} + + + org.jboss.netty + netty + + + + + ${akka.group} + akka-zeromq_${scala.binary.version} ${akka.version} @@ -300,11 +274,6 @@ colt 1.2.0 - - com.github.scala-incubator.io - scala-io-file_2.9.2 - 0.4.1 - org.apache.mesos mesos @@ -313,7 +282,7 @@ io.netty netty-all - 4.0.0.Beta2 + 4.0.0.CR1 org.apache.derby @@ -323,8 +292,14 @@ net.liftweb - lift-json_2.9.2 - 2.5 + lift-json_${scala.binary.version} + 2.5.1 + + + org.scala-lang + scalap + + com.codahale.metrics @@ -346,6 +321,11 @@ metrics-ganglia 3.0.0 + + com.codahale.metrics + metrics-graphite + 3.0.0 + org.scala-lang scala-compiler @@ -361,24 +341,22 @@ scala-library ${scala.version} - - org.scala-lang - scalap - ${scala.version} - - log4j log4j ${log4j.version} - org.scalatest - scalatest_2.9.3 + scalatest_${scala.binary.version} 1.9.1 test + + commons-io + commons-io + 2.4 + org.easymock easymock @@ -393,7 +371,7 @@ org.scalacheck - scalacheck_2.9.3 + scalacheck_${scala.binary.version} 1.10.0 test @@ -488,6 +466,7 @@ + org.apache.hadoop hadoop-yarn-client @@ -526,6 +505,10 @@ org.jboss.netty netty + + io.netty + netty + @@ -739,6 +722,7 @@ 2 0.23.7 + 2.5.0 @@ -750,7 +734,7 @@ maven-root Maven root repository - http://repo1.maven.org/maven2/ + http://repo1.maven.org/maven2 true @@ -765,6 +749,39 @@ + + + new-yarn + + 2 + 2.2.0 + 2.5.0 + + + + new-yarn + + + + + maven-root + Maven root repository + http://repo1.maven.org/maven2 + + true + + + false + + + + + + + + + + repl-bin diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index bccf36c00f..29f4a4b9ff 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -28,14 +28,19 @@ object SparkBuild extends Build { // "2.0.0-mr1-cdh4.2.0" for Cloudera Hadoop. Note that these variables can be set // through the environment variables SPARK_HADOOP_VERSION and SPARK_YARN. val DEFAULT_HADOOP_VERSION = "1.0.4" + + // Whether the Hadoop version to build against is 2.2.x, or a variant of it. This can be set + // through the SPARK_IS_NEW_HADOOP environment variable. + val DEFAULT_IS_NEW_HADOOP = false + val DEFAULT_YARN = false // HBase version; set as appropriate. val HBASE_VERSION = "0.94.6" // Target JVM version - val SCALAC_JVM_VERSION = "jvm-1.5" - val JAVAC_JVM_VERSION = "1.5" + val SCALAC_JVM_VERSION = "jvm-1.6" + val JAVAC_JVM_VERSION = "1.6" lazy val root = Project("root", file("."), settings = rootSettings) aggregate(allProjects: _*) @@ -55,8 +60,6 @@ object SparkBuild extends Build { lazy val mllib = Project("mllib", file("mllib"), settings = mllibSettings) dependsOn(core) - lazy val yarn = Project("yarn", file("yarn"), settings = yarnSettings) dependsOn(core) - lazy val assemblyProj = Project("assembly", file("assembly"), settings = assemblyProjSettings) .dependsOn(core, bagel, mllib, repl, streaming) dependsOn(maybeYarn: _*) @@ -68,14 +71,26 @@ object SparkBuild extends Build { // Allows build configuration to be set through environment variables lazy val hadoopVersion = scala.util.Properties.envOrElse("SPARK_HADOOP_VERSION", DEFAULT_HADOOP_VERSION) + lazy val isNewHadoop = scala.util.Properties.envOrNone("SPARK_IS_NEW_HADOOP") match { + case None => { + val isNewHadoopVersion = "2.[2-9]+".r.findFirstIn(hadoopVersion).isDefined + (isNewHadoopVersion|| DEFAULT_IS_NEW_HADOOP) + } + case Some(v) => v.toBoolean + } + lazy val isYarnEnabled = scala.util.Properties.envOrNone("SPARK_YARN") match { case None => DEFAULT_YARN case Some(v) => v.toBoolean } // Conditionally include the yarn sub-project - lazy val maybeYarn = if(isYarnEnabled) Seq[ClasspathDependency](yarn) else Seq[ClasspathDependency]() - lazy val maybeYarnRef = if(isYarnEnabled) Seq[ProjectReference](yarn) else Seq[ProjectReference]() + lazy val yarn = Project("yarn", file(if (isNewHadoop) "new-yarn" else "yarn"), settings = yarnSettings) dependsOn(core) + + //lazy val yarn = Project("yarn", file("yarn"), settings = yarnSettings) dependsOn(core) + + lazy val maybeYarn = if (isYarnEnabled) Seq[ClasspathDependency](yarn) else Seq[ClasspathDependency]() + lazy val maybeYarnRef = if (isYarnEnabled) Seq[ProjectReference](yarn) else Seq[ProjectReference]() // Everything except assembly, tools and examples belong to packageProjects lazy val packageProjects = Seq[ProjectReference](core, repl, bagel, streaming, mllib) ++ maybeYarnRef @@ -83,9 +98,9 @@ object SparkBuild extends Build { lazy val allProjects = packageProjects ++ Seq[ProjectReference](examples, tools, assemblyProj) def sharedSettings = Defaults.defaultSettings ++ Seq( - organization := "org.apache.spark", - version := "0.9.0-incubating-SNAPSHOT", - scalaVersion := "2.9.3", + organization := "org.apache.spark", + version := "0.9.0-incubating-SNAPSHOT", + scalaVersion := "2.10.3", scalacOptions := Seq("-Xmax-classfile-name", "120", "-unchecked", "-deprecation", "-target:" + SCALAC_JVM_VERSION), javacOptions := Seq("-target", JAVAC_JVM_VERSION, "-source", JAVAC_JVM_VERSION), @@ -105,12 +120,6 @@ object SparkBuild extends Build { // also check the local Maven repository ~/.m2 resolvers ++= Seq(Resolver.file("Local Maven Repo", file(Path.userHome + "/.m2/repository"))), - // Shared between both core and streaming. - resolvers ++= Seq("Akka Repository" at "http://repo.akka.io/releases/"), - - // Shared between both examples and streaming. - resolvers ++= Seq("Mqtt Repository" at "https://repo.eclipse.org/content/repositories/paho-releases/"), - // For Sonatype publishing resolvers ++= Seq("sonatype-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots", "sonatype-staging" at "https://oss.sonatype.org/service/local/staging/deploy/maven2/"), @@ -166,13 +175,17 @@ object SparkBuild extends Build { libraryDependencies ++= Seq( - "org.eclipse.jetty" % "jetty-server" % "7.6.8.v20121106", - "org.scalatest" %% "scalatest" % "1.9.1" % "test", - "org.scalacheck" %% "scalacheck" % "1.10.0" % "test", - "com.novocode" % "junit-interface" % "0.9" % "test", - "org.easymock" % "easymock" % "3.1" % "test", - "org.mockito" % "mockito-all" % "1.8.5" % "test" + "io.netty" % "netty-all" % "4.0.0.CR1", + "org.eclipse.jetty" % "jetty-server" % "7.6.8.v20121106", + "org.scalatest" %% "scalatest" % "1.9.1" % "test", + "org.scalacheck" %% "scalacheck" % "1.10.0" % "test", + "com.novocode" % "junit-interface" % "0.9" % "test", + "org.easymock" % "easymock" % "3.1" % "test", + "org.mockito" % "mockito-all" % "1.8.5" % "test", + "commons-io" % "commons-io" % "2.4" % "test" ), + + parallelExecution := true, /* Workaround for issue #206 (fixed after SBT 0.11.0) */ watchTransitiveSources <<= Defaults.inDependencies[Task[Seq[File]]](watchSources.task, const(std.TaskExtra.constant(Nil)), aggregate = true, includeRoot = true) apply { _.join.map(_.flatten) }, @@ -197,61 +210,61 @@ object SparkBuild extends Build { def coreSettings = sharedSettings ++ Seq( name := "spark-core", resolvers ++= Seq( - "JBoss Repository" at "http://repository.jboss.org/nexus/content/repositories/releases/", - "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/" + "JBoss Repository" at "http://repository.jboss.org/nexus/content/repositories/releases/", + "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/" ), libraryDependencies ++= Seq( - "com.google.guava" % "guava" % "14.0.1", - "com.google.code.findbugs" % "jsr305" % "1.3.9", - "log4j" % "log4j" % "1.2.17", - "org.slf4j" % "slf4j-api" % slf4jVersion, - "org.slf4j" % "slf4j-log4j12" % slf4jVersion, - "commons-daemon" % "commons-daemon" % "1.0.10", // workaround for bug HADOOP-9407 - "com.ning" % "compress-lzf" % "0.8.4", - "org.xerial.snappy" % "snappy-java" % "1.0.5", - "org.ow2.asm" % "asm" % "4.0", - "com.google.protobuf" % "protobuf-java" % "2.4.1", - "com.typesafe.akka" % "akka-actor" % "2.0.5" excludeAll(excludeNetty), - "com.typesafe.akka" % "akka-remote" % "2.0.5" excludeAll(excludeNetty), - "com.typesafe.akka" % "akka-slf4j" % "2.0.5" excludeAll(excludeNetty), - "it.unimi.dsi" % "fastutil" % "6.4.4", - "colt" % "colt" % "1.2.0", - "net.liftweb" % "lift-json_2.9.2" % "2.5", - "org.apache.mesos" % "mesos" % "0.13.0", - "io.netty" % "netty-all" % "4.0.0.Beta2", - "org.apache.derby" % "derby" % "10.4.2.0" % "test", - "org.apache.hadoop" % "hadoop-client" % hadoopVersion excludeAll(excludeJackson, excludeNetty, excludeAsm, excludeCglib), - "net.java.dev.jets3t" % "jets3t" % "0.7.1", - "org.apache.avro" % "avro" % "1.7.4", - "org.apache.avro" % "avro-ipc" % "1.7.4" excludeAll(excludeNetty), - "org.apache.zookeeper" % "zookeeper" % "3.4.5" excludeAll(excludeNetty), - "com.codahale.metrics" % "metrics-core" % "3.0.0", - "com.codahale.metrics" % "metrics-jvm" % "3.0.0", - "com.codahale.metrics" % "metrics-json" % "3.0.0", - "com.codahale.metrics" % "metrics-ganglia" % "3.0.0", - "com.twitter" % "chill_2.9.3" % "0.3.1", - "com.twitter" % "chill-java" % "0.3.1" - ) + "com.google.guava" % "guava" % "14.0.1", + "com.google.code.findbugs" % "jsr305" % "1.3.9", + "log4j" % "log4j" % "1.2.17", + "org.slf4j" % "slf4j-api" % slf4jVersion, + "org.slf4j" % "slf4j-log4j12" % slf4jVersion, + "commons-daemon" % "commons-daemon" % "1.0.10", // workaround for bug HADOOP-9407 + "com.ning" % "compress-lzf" % "0.8.4", + "org.xerial.snappy" % "snappy-java" % "1.0.5", + "org.ow2.asm" % "asm" % "4.0", + "org.spark-project.akka" %% "akka-remote" % "2.2.3-shaded-protobuf" excludeAll(excludeNetty), + "org.spark-project.akka" %% "akka-slf4j" % "2.2.3-shaded-protobuf" excludeAll(excludeNetty), + "net.liftweb" %% "lift-json" % "2.5.1" excludeAll(excludeNetty), + "it.unimi.dsi" % "fastutil" % "6.4.4", + "colt" % "colt" % "1.2.0", + "org.apache.mesos" % "mesos" % "0.13.0", + "net.java.dev.jets3t" % "jets3t" % "0.7.1", + "org.apache.derby" % "derby" % "10.4.2.0" % "test", + "org.apache.hadoop" % "hadoop-client" % hadoopVersion excludeAll(excludeJackson, excludeNetty, excludeAsm, excludeCglib), + "org.apache.avro" % "avro" % "1.7.4", + "org.apache.avro" % "avro-ipc" % "1.7.4" excludeAll(excludeNetty), + "org.apache.zookeeper" % "zookeeper" % "3.4.5" excludeAll(excludeNetty), + "com.codahale.metrics" % "metrics-core" % "3.0.0", + "com.codahale.metrics" % "metrics-jvm" % "3.0.0", + "com.codahale.metrics" % "metrics-json" % "3.0.0", + "com.codahale.metrics" % "metrics-ganglia" % "3.0.0", + "com.codahale.metrics" % "metrics-graphite" % "3.0.0", + "com.twitter" %% "chill" % "0.3.1", + "com.twitter" % "chill-java" % "0.3.1" + ) ) def rootSettings = sharedSettings ++ Seq( publish := {} ) - def replSettings = sharedSettings ++ Seq( + def replSettings = sharedSettings ++ Seq( name := "spark-repl", - libraryDependencies <+= scalaVersion("org.scala-lang" % "scala-compiler" % _) + libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-compiler" % v ), + libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "jline" % v ), + libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-reflect" % v ) ) + def examplesSettings = sharedSettings ++ Seq( name := "spark-examples", libraryDependencies ++= Seq( - "com.twitter" % "algebird-core_2.9.2" % "0.1.11", - + "com.twitter" %% "algebird-core" % "0.1.11", + "org.apache.hbase" % "hbase" % "0.94.6" excludeAll(excludeNetty, excludeAsm), "org.apache.hbase" % "hbase" % HBASE_VERSION excludeAll(excludeNetty, excludeAsm), - - "org.apache.cassandra" % "cassandra-all" % "1.2.5" + "org.apache.cassandra" % "cassandra-all" % "1.2.6" exclude("com.google.guava", "guava") exclude("com.googlecode.concurrentlinkedhashmap", "concurrentlinkedhashmap-lru") exclude("com.ning","compress-lzf") @@ -282,18 +295,21 @@ object SparkBuild extends Build { def streamingSettings = sharedSettings ++ Seq( name := "spark-streaming", resolvers ++= Seq( - "Akka Repository" at "http://repo.akka.io/releases/", + "Eclipse Repository" at "https://repo.eclipse.org/content/repositories/paho-releases/", "Apache repo" at "https://repository.apache.org/content/repositories/releases" ), + libraryDependencies ++= Seq( - "org.eclipse.paho" % "mqtt-client" % "0.4.0", - "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty, excludeSnappy), - "org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty), - "com.typesafe.akka" % "akka-zeromq" % "2.0.5" excludeAll(excludeNetty), - "org.apache.kafka" % "kafka_2.9.2" % "0.8.0-beta1" + "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty, excludeSnappy), + "com.sksamuel.kafka" %% "kafka" % "0.8.0-beta1" exclude("com.sun.jdmk", "jmxtools") exclude("com.sun.jmx", "jmxri") exclude("net.sf.jopt-simple", "jopt-simple") + excludeAll(excludeNetty), + "org.eclipse.paho" % "mqtt-client" % "0.4.0", + "com.github.sgroschupf" % "zkclient" % "0.1" excludeAll(excludeNetty), + "org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty), + "org.spark-project.akka" %% "akka-zeromq" % "2.2.3-shaded-protobuf" excludeAll(excludeNetty) ) ) diff --git a/pyspark b/pyspark index 4941a36d0d..12cc926dda 100755 --- a/pyspark +++ b/pyspark @@ -23,7 +23,7 @@ FWDIR="$(cd `dirname $0`; pwd)" # Export this as SPARK_HOME export SPARK_HOME="$FWDIR" -SCALA_VERSION=2.9.3 +SCALA_VERSION=2.10 # Exit if the user hasn't compiled Spark if [ ! -f "$FWDIR/RELEASE" ]; then @@ -59,8 +59,12 @@ if [ -n "$IPYTHON_OPTS" ]; then fi if [[ "$IPYTHON" = "1" ]] ; then - IPYTHON_OPTS=${IPYTHON_OPTS:--i} - exec ipython "$IPYTHON_OPTS" -c "%run $PYTHONSTARTUP" + # IPython <1.0.0 doesn't honor PYTHONSTARTUP, while 1.0.0+ does. + # Hence we clear PYTHONSTARTUP and use the -c "%run $IPYTHONSTARTUP" command which works on all versions + # We also force interactive mode with "-i" + IPYTHONSTARTUP=$PYTHONSTARTUP + PYTHONSTARTUP= + exec ipython "$IPYTHON_OPTS" -i -c "%run $IPYTHONSTARTUP" else exec "$PYSPARK_PYTHON" "$@" fi diff --git a/pyspark2.cmd b/pyspark2.cmd index f58e349643..21f9a34388 100644 --- a/pyspark2.cmd +++ b/pyspark2.cmd @@ -17,7 +17,7 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -set SCALA_VERSION=2.9.3 +set SCALA_VERSION=2.10 rem Figure out where the Spark framework is installed set FWDIR=%~dp0 diff --git a/python/epydoc.conf b/python/epydoc.conf index 1d0d002d36..0b42e729f8 100644 --- a/python/epydoc.conf +++ b/python/epydoc.conf @@ -32,6 +32,6 @@ target: docs/ private: no -exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers +exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test pyspark.rddsampler pyspark.daemon diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index da3d96689a..2204e9c9ca 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -90,9 +90,11 @@ import struct import SocketServer import threading from pyspark.cloudpickle import CloudPickler -from pyspark.serializers import read_int, read_with_length, load_pickle +from pyspark.serializers import read_int, PickleSerializer +pickleSer = PickleSerializer() + # Holds accumulators registered on the current machine, keyed by ID. This is then used to send # the local accumulator updates back to the driver program at the end of a task. _accumulatorRegistry = {} @@ -211,7 +213,7 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler): from pyspark.accumulators import _accumulatorRegistry num_updates = read_int(self.rfile) for _ in range(num_updates): - (aid, update) = load_pickle(read_with_length(self.rfile)) + (aid, update) = pickleSer._read_with_length(self.rfile) _accumulatorRegistry[aid] += update # Write a byte in acknowledgement self.wfile.write(struct.pack("!b", 1)) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a7ca8bc888..cbd41e58c4 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -26,7 +26,7 @@ from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway -from pyspark.serializers import dump_pickle, write_with_length, batched +from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD @@ -42,7 +42,7 @@ class SparkContext(object): _gateway = None _jvm = None - _writeIteratorToPickleFile = None + _writeToFile = None _takePartition = None _next_accum_id = 0 _active_spark_context = None @@ -51,7 +51,7 @@ class SparkContext(object): def __init__(self, master, jobName, sparkHome=None, pyFiles=None, - environment=None, batchSize=1024): + environment=None, batchSize=1024, serializer=PickleSerializer()): """ Create a new SparkContext. @@ -67,6 +67,7 @@ class SparkContext(object): @param batchSize: The number of Python objects represented as a single Java object. Set 1 to disable batching or -1 to use an unlimited batch size. + @param serializer: The serializer for RDDs. >>> from pyspark.context import SparkContext @@ -83,7 +84,13 @@ class SparkContext(object): self.jobName = jobName self.sparkHome = sparkHome or None # None becomes null in Py4J self.environment = environment or {} - self.batchSize = batchSize # -1 represents a unlimited batch size + self._batchSize = batchSize # -1 represents an unlimited batch size + self._unbatched_serializer = serializer + if batchSize == 1: + self.serializer = self._unbatched_serializer + else: + self.serializer = BatchedSerializer(self._unbatched_serializer, + batchSize) # Create the Java SparkContext through Py4J empty_string_array = self._gateway.new_array(self._jvm.String, 0) @@ -125,8 +132,8 @@ class SparkContext(object): if not SparkContext._gateway: SparkContext._gateway = launch_gateway() SparkContext._jvm = SparkContext._gateway.jvm - SparkContext._writeIteratorToPickleFile = \ - SparkContext._jvm.PythonRDD.writeIteratorToPickleFile + SparkContext._writeToFile = \ + SparkContext._jvm.PythonRDD.writeToFile SparkContext._takePartition = \ SparkContext._jvm.PythonRDD.takePartition @@ -184,15 +191,17 @@ class SparkContext(object): # Make sure we distribute data evenly if it's smaller than self.batchSize if "__len__" not in dir(c): c = list(c) # Make it a list so we can compute its length - batchSize = min(len(c) // numSlices, self.batchSize) + batchSize = min(len(c) // numSlices, self._batchSize) if batchSize > 1: - c = batched(c, batchSize) - for x in c: - write_with_length(dump_pickle(x), tempFile) + serializer = BatchedSerializer(self._unbatched_serializer, + batchSize) + else: + serializer = self._unbatched_serializer + serializer.dump_stream(c, tempFile) tempFile.close() - readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile - jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices) - return RDD(jrdd, self) + readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile + jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices) + return RDD(jrdd, self, serializer) def textFile(self, name, minSplits=None): """ @@ -201,21 +210,39 @@ class SparkContext(object): RDD of Strings. """ minSplits = minSplits or min(self.defaultParallelism, 2) - jrdd = self._jsc.textFile(name, minSplits) - return RDD(jrdd, self) + return RDD(self._jsc.textFile(name, minSplits), self, + MUTF8Deserializer()) - def _checkpointFile(self, name): + def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) - return RDD(jrdd, self) + return RDD(jrdd, self, input_deserializer) def union(self, rdds): """ Build the union of a list of RDDs. + + This supports unions() of RDDs with different serialized formats, + although this forces them to be reserialized using the default + serializer: + + >>> path = os.path.join(tempdir, "union-text.txt") + >>> with open(path, "w") as testFile: + ... testFile.write("Hello") + >>> textFile = sc.textFile(path) + >>> textFile.collect() + [u'Hello'] + >>> parallelized = sc.parallelize(["World!"]) + >>> sorted(sc.union([textFile, parallelized]).collect()) + [u'Hello', 'World!'] """ + first_jrdd_deserializer = rdds[0]._jrdd_deserializer + if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds): + rdds = [x._reserialize() for x in rdds] first = rdds[0]._jrdd rest = [x._jrdd for x in rdds[1:]] - rest = ListConverter().convert(rest, self.gateway._gateway_client) - return RDD(self._jsc.union(first, rest), self) + rest = ListConverter().convert(rest, self._gateway._gateway_client) + return RDD(self._jsc.union(first, rest), self, + rdds[0]._jrdd_deserializer) def broadcast(self, value): """ @@ -223,7 +250,9 @@ class SparkContext(object): object for reading it in distributed functions. The variable will be sent to each cluster only once. """ - jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value))) + pickleSer = PickleSerializer() + pickled = pickleSer.dumps(value) + jbroadcast = self._jsc.broadcast(bytearray(pickled)) return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) @@ -235,7 +264,7 @@ class SparkContext(object): and floating-point numbers if you do not provide one. For other types, a custom AccumulatorParam can be used. """ - if accum_param == None: + if accum_param is None: if isinstance(value, int): accum_param = accumulators.INT_ACCUMULATOR_PARAM elif isinstance(value, float): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7019fb8bee..61720dcf1a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -18,7 +18,7 @@ from base64 import standard_b64encode as b64enc import copy from collections import defaultdict -from itertools import chain, ifilter, imap, product +from itertools import chain, ifilter, imap import operator import os import sys @@ -27,9 +27,8 @@ from subprocess import Popen, PIPE from tempfile import NamedTemporaryFile from threading import Thread -from pyspark import cloudpickle -from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \ - read_from_pickle_file, pack_long +from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ + BatchedSerializer, CloudPickleSerializer, pack_long from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -48,12 +47,12 @@ class RDD(object): operated on in parallel. """ - def __init__(self, jrdd, ctx): + def __init__(self, jrdd, ctx, jrdd_deserializer): self._jrdd = jrdd self.is_cached = False self.is_checkpointed = False self.ctx = ctx - self._partitionFunc = None + self._jrdd_deserializer = jrdd_deserializer @property def context(self): @@ -247,7 +246,23 @@ class RDD(object): >>> rdd.union(rdd).collect() [1, 1, 2, 3, 1, 1, 2, 3] """ - return RDD(self._jrdd.union(other._jrdd), self.ctx) + if self._jrdd_deserializer == other._jrdd_deserializer: + rdd = RDD(self._jrdd.union(other._jrdd), self.ctx, + self._jrdd_deserializer) + return rdd + else: + # These RDDs contain data in different serialized formats, so we + # must normalize them to the default serializer. + self_copy = self._reserialize() + other_copy = other._reserialize() + return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx, + self.ctx.serializer) + + def _reserialize(self): + if self._jrdd_deserializer == self.ctx.serializer: + return self + else: + return self.map(lambda x: x, preservesPartitioning=True) def __add__(self, other): """ @@ -334,17 +349,9 @@ class RDD(object): [(1, 1), (1, 2), (2, 1), (2, 2)] """ # Due to batching, we can't use the Java cartesian method. - java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx) - def unpack_batches(pair): - (x, y) = pair - if type(x) == Batch or type(y) == Batch: - xs = x.items if type(x) == Batch else [x] - ys = y.items if type(y) == Batch else [y] - for pair in product(xs, ys): - yield pair - else: - yield pair - return java_cartesian.flatMap(unpack_batches) + deserializer = CartesianDeserializer(self._jrdd_deserializer, + other._jrdd_deserializer) + return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer) def groupBy(self, f, numPartitions=None): """ @@ -391,8 +398,8 @@ class RDD(object): """ Return a list that contains all of the elements in this RDD. """ - picklesInJava = self._jrdd.collect().iterator() - return list(self._collect_iterator_through_file(picklesInJava)) + bytesInJava = self._jrdd.collect().iterator() + return list(self._collect_iterator_through_file(bytesInJava)) def _collect_iterator_through_file(self, iterator): # Transferring lots of data through Py4J can be slow because @@ -400,10 +407,10 @@ class RDD(object): # file and read it back. tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir) tempFile.close() - self.ctx._writeIteratorToPickleFile(iterator, tempFile.name) + self.ctx._writeToFile(iterator, tempFile.name) # Read the data into Python and deserialize it: with open(tempFile.name, 'rb') as tempFile: - for item in read_from_pickle_file(tempFile): + for item in self._jrdd_deserializer.load_stream(tempFile): yield item os.unlink(tempFile.name) @@ -571,7 +578,7 @@ class RDD(object): items = [] for partition in range(mapped._jrdd.splits().size()): iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition) - items.extend(self._collect_iterator_through_file(iterator)) + items.extend(mapped._collect_iterator_through_file(iterator)) if len(items) >= num: break return items[:num] @@ -598,7 +605,10 @@ class RDD(object): '0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n' """ def func(split, iterator): - return (str(x).encode("utf-8") for x in iterator) + for x in iterator: + if not isinstance(x, basestring): + x = unicode(x) + yield x.encode("utf-8") keyed = PipelinedRDD(self, func) keyed._bypass_serializer = True keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path) @@ -735,6 +745,7 @@ class RDD(object): # Transferring O(n) objects to Java is too expensive. Instead, we'll # form the hash buckets in Python, transferring O(numPartitions) objects # to Java. Each object is a (splitNumber, [objects]) pair. + outputSerializer = self.ctx._unbatched_serializer def add_shuffle_key(split, iterator): buckets = defaultdict(list) @@ -743,14 +754,14 @@ class RDD(object): buckets[partitionFunc(k) % numPartitions].append((k, v)) for (split, items) in buckets.iteritems(): yield pack_long(split) - yield dump_pickle(Batch(items)) + yield outputSerializer.dumps(items) keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, id(partitionFunc)) jrdd = pairRDD.partitionBy(partitioner).values() - rdd = RDD(jrdd, self.ctx) + rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer)) # This is required so that id(partitionFunc) remains unique, even if # partitionFunc is a lambda: rdd._partitionFunc = partitionFunc @@ -787,7 +798,8 @@ class RDD(object): numPartitions = self.ctx.defaultParallelism def combineLocally(iterator): combiners = {} - for (k, v) in iterator: + for x in iterator: + (k, v) = x if k not in combiners: combiners[k] = createCombiner(v) else: @@ -929,50 +941,52 @@ class PipelinedRDD(RDD): 20 """ def __init__(self, prev, func, preservesPartitioning=False): - if isinstance(prev, PipelinedRDD) and prev._is_pipelinable(): + if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable(): + # This transformation is the first in its stage: + self.func = func + self.preservesPartitioning = preservesPartitioning + self._prev_jrdd = prev._jrdd + self._prev_jrdd_deserializer = prev._jrdd_deserializer + else: prev_func = prev.func def pipeline_func(split, iterator): return func(split, prev_func(split, iterator)) self.func = pipeline_func self.preservesPartitioning = \ prev.preservesPartitioning and preservesPartitioning - self._prev_jrdd = prev._prev_jrdd - else: - self.func = func - self.preservesPartitioning = preservesPartitioning - self._prev_jrdd = prev._jrdd + self._prev_jrdd = prev._prev_jrdd # maintain the pipeline + self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer self.is_cached = False self.is_checkpointed = False self.ctx = prev.ctx self.prev = prev self._jrdd_val = None + self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False @property def _jrdd(self): if self._jrdd_val: return self._jrdd_val - func = self.func - if not self._bypass_serializer and self.ctx.batchSize != 1: - oldfunc = self.func - batchSize = self.ctx.batchSize - def batched_func(split, iterator): - return batched(oldfunc(split, iterator), batchSize) - func = batched_func - cmds = [func, self._bypass_serializer] - pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) + if self._bypass_serializer: + serializer = NoOpSerializer() + else: + serializer = self.ctx.serializer + command = (self.func, self._prev_jrdd_deserializer, serializer) + pickled_command = CloudPickleSerializer().dumps(command) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], self.ctx._gateway._gateway_client) self.ctx._pickled_broadcast_vars.clear() - class_manifest = self._prev_jrdd.classManifest() + class_tag = self._prev_jrdd.classTag() env = MapConverter().convert(self.ctx.environment, self.ctx._gateway._gateway_client) includes = ListConverter().convert(self.ctx._python_includes, self.ctx._gateway._gateway_client) python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), - pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, self.ctx._javaAccumulator, class_manifest) + bytearray(pickled_command), env, includes, self.preservesPartitioning, + self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator, + class_tag) self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 54fed1c9c7..811fa6f018 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -15,45 +15,269 @@ # limitations under the License. # -import struct +""" +PySpark supports custom serializers for transferring data; this can improve +performance. + +By default, PySpark uses L{PickleSerializer} to serialize objects using Python's +C{cPickle} serializer, which can serialize nearly any Python object. +Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be +faster. + +The serializer is chosen when creating L{SparkContext}: + +>>> from pyspark.context import SparkContext +>>> from pyspark.serializers import MarshalSerializer +>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer()) +>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10) +[0, 2, 4, 6, 8, 10, 12, 14, 16, 18] +>>> sc.stop() + +By default, PySpark serialize objects in batches; the batch size can be +controlled through SparkContext's C{batchSize} parameter +(the default size is 1024 objects): + +>>> sc = SparkContext('local', 'test', batchSize=2) +>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) + +Behind the scenes, this creates a JavaRDD with four partitions, each of +which contains two batches of two objects: + +>>> rdd.glom().collect() +[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] +>>> rdd._jrdd.count() +8L +>>> sc.stop() + +A batch size of -1 uses an unlimited batch size, and a size of 1 disables +batching: + +>>> sc = SparkContext('local', 'test', batchSize=1) +>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) +>>> rdd.glom().collect() +[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] +>>> rdd._jrdd.count() +16L +""" + import cPickle +from itertools import chain, izip, product +import marshal +import struct +from pyspark import cloudpickle -class Batch(object): +__all__ = ["PickleSerializer", "MarshalSerializer"] + + +class SpecialLengths(object): + END_OF_DATA_SECTION = -1 + PYTHON_EXCEPTION_THROWN = -2 + TIMING_DATA = -3 + + +class Serializer(object): + + def dump_stream(self, iterator, stream): + """ + Serialize an iterator of objects to the output stream. + """ + raise NotImplementedError + + def load_stream(self, stream): + """ + Return an iterator of deserialized objects from the input stream. + """ + raise NotImplementedError + + + def _load_stream_without_unbatching(self, stream): + return self.load_stream(stream) + + # Note: our notion of "equality" is that output generated by + # equal serializers can be deserialized using the same serializer. + + # This default implementation handles the simple cases; + # subclasses should override __eq__ as appropriate. + + def __eq__(self, other): + return isinstance(other, self.__class__) + + def __ne__(self, other): + return not self.__eq__(other) + + +class FramedSerializer(Serializer): """ - Used to store multiple RDD entries as a single Java object. - - This relieves us from having to explicitly track whether an RDD - is stored as batches of objects and avoids problems when processing - the union() of batched and unbatched RDDs (e.g. the union() of textFile() - with another RDD). + Serializer that writes objects as a stream of (length, data) pairs, + where C{length} is a 32-bit integer and data is C{length} bytes. """ - def __init__(self, items): - self.items = items + + def dump_stream(self, iterator, stream): + for obj in iterator: + self._write_with_length(obj, stream) + + def load_stream(self, stream): + while True: + try: + yield self._read_with_length(stream) + except EOFError: + return + + def _write_with_length(self, obj, stream): + serialized = self.dumps(obj) + write_int(len(serialized), stream) + stream.write(serialized) + + def _read_with_length(self, stream): + length = read_int(stream) + obj = stream.read(length) + if obj == "": + raise EOFError + return self.loads(obj) + + def dumps(self, obj): + """ + Serialize an object into a byte array. + When batching is used, this will be called with an array of objects. + """ + raise NotImplementedError + + def loads(self, obj): + """ + Deserialize an object from a byte array. + """ + raise NotImplementedError -def batched(iterator, batchSize): - if batchSize == -1: # unlimited batch size - yield Batch(list(iterator)) - else: - items = [] - count = 0 - for item in iterator: - items.append(item) - count += 1 - if count == batchSize: - yield Batch(items) - items = [] - count = 0 - if items: - yield Batch(items) +class BatchedSerializer(Serializer): + """ + Serializes a stream of objects in batches by calling its wrapped + Serializer with streams of objects. + """ + + UNLIMITED_BATCH_SIZE = -1 + + def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE): + self.serializer = serializer + self.batchSize = batchSize + + def _batched(self, iterator): + if self.batchSize == self.UNLIMITED_BATCH_SIZE: + yield list(iterator) + else: + items = [] + count = 0 + for item in iterator: + items.append(item) + count += 1 + if count == self.batchSize: + yield items + items = [] + count = 0 + if items: + yield items + + def dump_stream(self, iterator, stream): + self.serializer.dump_stream(self._batched(iterator), stream) + + def load_stream(self, stream): + return chain.from_iterable(self._load_stream_without_unbatching(stream)) + + def _load_stream_without_unbatching(self, stream): + return self.serializer.load_stream(stream) + + def __eq__(self, other): + return isinstance(other, BatchedSerializer) and \ + other.serializer == self.serializer + + def __str__(self): + return "BatchedSerializer<%s>" % str(self.serializer) -def dump_pickle(obj): - return cPickle.dumps(obj, 2) +class CartesianDeserializer(FramedSerializer): + """ + Deserializes the JavaRDD cartesian() of two PythonRDDs. + """ + + def __init__(self, key_ser, val_ser): + self.key_ser = key_ser + self.val_ser = val_ser + + def load_stream(self, stream): + key_stream = self.key_ser._load_stream_without_unbatching(stream) + val_stream = self.val_ser._load_stream_without_unbatching(stream) + key_is_batched = isinstance(self.key_ser, BatchedSerializer) + val_is_batched = isinstance(self.val_ser, BatchedSerializer) + for (keys, vals) in izip(key_stream, val_stream): + keys = keys if key_is_batched else [keys] + vals = vals if val_is_batched else [vals] + for pair in product(keys, vals): + yield pair + + def __eq__(self, other): + return isinstance(other, CartesianDeserializer) and \ + self.key_ser == other.key_ser and self.val_ser == other.val_ser + + def __str__(self): + return "CartesianDeserializer<%s, %s>" % \ + (str(self.key_ser), str(self.val_ser)) -load_pickle = cPickle.loads +class NoOpSerializer(FramedSerializer): + + def loads(self, obj): return obj + def dumps(self, obj): return obj + + +class PickleSerializer(FramedSerializer): + """ + Serializes objects using Python's cPickle serializer: + + http://docs.python.org/2/library/pickle.html + + This serializer supports nearly any Python object, but may + not be as fast as more specialized serializers. + """ + + def dumps(self, obj): return cPickle.dumps(obj, 2) + loads = cPickle.loads + +class CloudPickleSerializer(PickleSerializer): + + def dumps(self, obj): return cloudpickle.dumps(obj, 2) + + +class MarshalSerializer(FramedSerializer): + """ + Serializes objects using Python's Marshal serializer: + + http://docs.python.org/2/library/marshal.html + + This serializer is faster than PickleSerializer but supports fewer datatypes. + """ + + dumps = marshal.dumps + loads = marshal.loads + + +class MUTF8Deserializer(Serializer): + """ + Deserializes streams written by Java's DataOutputStream.writeUTF(). + """ + + def loads(self, stream): + length = struct.unpack('>H', stream.read(2))[0] + return stream.read(length).decode('utf8') + + def load_stream(self, stream): + while True: + try: + yield self.loads(stream) + except struct.error: + return + except EOFError: + return def read_long(stream): @@ -84,25 +308,4 @@ def write_int(value, stream): def write_with_length(obj, stream): write_int(len(obj), stream) - stream.write(obj) - - -def read_with_length(stream): - length = read_int(stream) - obj = stream.read(length) - if obj == "": - raise EOFError - return obj - - -def read_from_pickle_file(stream): - try: - while True: - obj = load_pickle(read_with_length(stream)) - if type(obj) == Batch: # We don't care about inheritance - for item in obj.items: - yield item - else: - yield obj - except EOFError: - return + stream.write(obj) \ No newline at end of file diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 29d6a128f6..3987642bf4 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -19,6 +19,8 @@ Unit tests for PySpark; additional tests are implemented as doctests in individual modules. """ +from fileinput import input +from glob import glob import os import shutil import sys @@ -86,7 +88,8 @@ class TestCheckpoint(PySparkTestCase): time.sleep(1) # 1 second self.assertTrue(flatMappedRDD.getCheckpointFile() is not None) - recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile()) + recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(), + flatMappedRDD._jrdd_deserializer) self.assertEquals([1, 2, 3, 4], recovered.collect()) @@ -137,6 +140,19 @@ class TestAddFile(PySparkTestCase): self.assertEqual("Hello World from inside a package!", UserClass().hello()) +class TestRDDFunctions(PySparkTestCase): + + def test_save_as_textfile_with_unicode(self): + # Regression test for SPARK-970 + x = u"\u00A1Hola, mundo!" + data = self.sc.parallelize([x]) + tempFile = NamedTemporaryFile(delete=True) + tempFile.close() + data.saveAsTextFile(tempFile.name) + raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*"))) + self.assertEqual(x, unicode(raw_contents.strip(), "utf-8")) + + class TestIO(PySparkTestCase): def test_stdout_redirection(self): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index d63c2aaef7..f2b3f3c142 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,23 +23,22 @@ import sys import time import socket import traceback -from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles -from pyspark.serializers import write_with_length, read_with_length, write_int, \ - read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file +from pyspark.serializers import write_with_length, write_int, read_long, \ + write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer -def load_obj(infile): - return load_pickle(standard_b64decode(infile.readline().strip())) +pickleSer = PickleSerializer() +mutf8_deserializer = MUTF8Deserializer() def report_times(outfile, boot, init, finish): - write_int(-3, outfile) + write_int(SpecialLengths.TIMING_DATA, outfile) write_long(1000 * boot, outfile) write_long(1000 * init, outfile) write_long(1000 * finish, outfile) @@ -52,7 +51,7 @@ def main(infile, outfile): return # fetch name of workdir - spark_files_dir = load_pickle(read_with_length(infile)) + spark_files_dir = mutf8_deserializer.loads(infile) SparkFiles._root_directory = spark_files_dir SparkFiles._is_running_on_worker = True @@ -60,38 +59,33 @@ def main(infile, outfile): num_broadcast_variables = read_int(infile) for _ in range(num_broadcast_variables): bid = read_long(infile) - value = read_with_length(infile) - _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value)) + value = pickleSer._read_with_length(infile) + _broadcastRegistry[bid] = Broadcast(bid, value) # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH sys.path.append(spark_files_dir) # *.py files that were added will be copied here num_python_includes = read_int(infile) for _ in range(num_python_includes): - sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile)))) + filename = mutf8_deserializer.loads(infile) + sys.path.append(os.path.join(spark_files_dir, filename)) - # now load function - func = load_obj(infile) - bypassSerializer = load_obj(infile) - if bypassSerializer: - dumps = lambda x: x - else: - dumps = dump_pickle + command = pickleSer._read_with_length(infile) + (func, deserializer, serializer) = command init_time = time.time() - iterator = read_from_pickle_file(infile) try: - for obj in func(split_index, iterator): - write_with_length(dumps(obj), outfile) + iterator = deserializer.load_stream(infile) + serializer.dump_stream(func(split_index, iterator), outfile) except Exception as e: - write_int(-2, outfile) + write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) write_with_length(traceback.format_exc(), outfile) sys.exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) # Mark the beginning of the accumulators section of the output - write_int(-1, outfile) - for aid, accum in _accumulatorRegistry.items(): - write_with_length(dump_pickle((aid, accum._value)), outfile) - write_int(-1, outfile) + write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) + write_int(len(_accumulatorRegistry), outfile) + for (aid, accum) in _accumulatorRegistry.items(): + pickleSer._write_with_length((aid, accum._value), outfile) if __name__ == '__main__': diff --git a/python/run-tests b/python/run-tests index cbc554ea9d..d4dad672d2 100755 --- a/python/run-tests +++ b/python/run-tests @@ -37,6 +37,7 @@ run_test "pyspark/rdd.py" run_test "pyspark/context.py" run_test "-m doctest pyspark/broadcast.py" run_test "-m doctest pyspark/accumulators.py" +run_test "-m doctest pyspark/serializers.py" run_test "pyspark/tests.py" if [[ $FAILED != 0 ]]; then diff --git a/python/test_support/userlibrary.py b/python/test_support/userlibrary.py index 5bb6f5009f..8e4a6292bc 100755 --- a/python/test_support/userlibrary.py +++ b/python/test_support/userlibrary.py @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + """ Used to test shipping of code depenencies with SparkContext.addPyFile(). """ diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index f6bf94be6b..869dbdb9b0 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -26,7 +26,7 @@ org.apache.spark - spark-repl-bin_2.9.3 + spark-repl-bin_2.10 pom Spark Project REPL binary packaging http://spark.incubator.apache.org/ @@ -40,18 +40,18 @@ org.apache.spark - spark-core_2.9.3 + spark-core_${scala.binary.version} ${project.version} org.apache.spark - spark-bagel_2.9.3 + spark-bagel_${scala.binary.version} ${project.version} runtime org.apache.spark - spark-repl_2.9.3 + spark-repl_${scala.binary.version} ${project.version} runtime diff --git a/repl-bin/src/deb/bin/run b/repl-bin/src/deb/bin/run index 8b5d8300f2..47bb654baf 100755 --- a/repl-bin/src/deb/bin/run +++ b/repl-bin/src/deb/bin/run @@ -17,7 +17,7 @@ # limitations under the License. # -SCALA_VERSION=2.9.3 +SCALA_VERSION=2.10 # Figure out where the Scala framework is installed FWDIR="$(cd `dirname $0`; pwd)" diff --git a/repl-bin/src/deb/bin/spark-executor b/repl-bin/src/deb/bin/spark-executor index bcfae22677..052d76fb8d 100755 --- a/repl-bin/src/deb/bin/spark-executor +++ b/repl-bin/src/deb/bin/spark-executor @@ -19,4 +19,4 @@ FWDIR="$(cd `dirname $0`; pwd)" echo "Running spark-executor with framework dir = $FWDIR" -exec $FWDIR/run spark.executor.MesosExecutorBackend +exec $FWDIR/run org.apache.spark.executor.MesosExecutorBackend diff --git a/repl-bin/src/deb/bin/spark-shell b/repl-bin/src/deb/bin/spark-shell index ec7e33e1e3..118349d7c3 100755 --- a/repl-bin/src/deb/bin/spark-shell +++ b/repl-bin/src/deb/bin/spark-shell @@ -18,4 +18,4 @@ # FWDIR="$(cd `dirname $0`; pwd)" -exec $FWDIR/run spark.repl.Main "$@" +exec $FWDIR/run org.apache.spark.repl.Main "$@" diff --git a/repl/lib/scala-jline.jar b/repl/lib/scala-jline.jar deleted file mode 100644 index 2f18c95cdd..0000000000 Binary files a/repl/lib/scala-jline.jar and /dev/null differ diff --git a/repl/pom.xml b/repl/pom.xml index 49d86621dd..b0e7877bbb 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -26,7 +26,7 @@ org.apache.spark - spark-repl_2.9.3 + spark-repl_2.10 jar Spark Project REPL http://spark.incubator.apache.org/ @@ -39,18 +39,18 @@ org.apache.spark - spark-core_2.9.3 + spark-core_${scala.binary.version} ${project.version} org.apache.spark - spark-bagel_2.9.3 + spark-bagel_${scala.binary.version} ${project.version} runtime org.apache.spark - spark-mllib_2.9.3 + spark-mllib_${scala.binary.version} ${project.version} runtime @@ -61,10 +61,12 @@ org.scala-lang scala-compiler + ${scala.version} org.scala-lang jline + ${scala.version} org.slf4j @@ -76,18 +78,18 @@ org.scalatest - scalatest_2.9.3 + scalatest_${scala.binary.version} test org.scalacheck - scalacheck_2.9.3 + scalacheck_${scala.binary.version} test - target/scala-${scala.version}/classes - target/scala-${scala.version}/test-classes + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes org.apache.maven.plugins diff --git a/repl/src/main/scala/org/apache/spark/repl/Main.scala b/repl/src/main/scala/org/apache/spark/repl/Main.scala index 17e149f8ab..14b448d076 100644 --- a/repl/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/src/main/scala/org/apache/spark/repl/Main.scala @@ -20,12 +20,12 @@ package org.apache.spark.repl import scala.collection.mutable.Set object Main { - private var _interp: SparkILoop = null - + private var _interp: SparkILoop = _ + def interp = _interp - + def interp_=(i: SparkILoop) { _interp = i } - + def main(args: Array[String]) { _interp = new SparkILoop _interp.process(args) diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala new file mode 100644 index 0000000000..b2e1df173e --- /dev/null +++ b/repl/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala @@ -0,0 +1,109 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2013 LAMP/EPFL + * @author Paul Phillips + */ + +package org.apache.spark.repl + +import scala.tools.nsc._ +import scala.tools.nsc.interpreter._ + +import scala.reflect.internal.util.BatchSourceFile +import scala.tools.nsc.ast.parser.Tokens.EOF + +import org.apache.spark.Logging + +trait SparkExprTyper extends Logging { + val repl: SparkIMain + + import repl._ + import global.{ reporter => _, Import => _, _ } + import definitions._ + import syntaxAnalyzer.{ UnitParser, UnitScanner, token2name } + import naming.freshInternalVarName + + object codeParser extends { val global: repl.global.type = repl.global } with CodeHandlers[Tree] { + def applyRule[T](code: String, rule: UnitParser => T): T = { + reporter.reset() + val scanner = newUnitParser(code) + val result = rule(scanner) + + if (!reporter.hasErrors) + scanner.accept(EOF) + + result + } + + def defns(code: String) = stmts(code) collect { case x: DefTree => x } + def expr(code: String) = applyRule(code, _.expr()) + def stmts(code: String) = applyRule(code, _.templateStats()) + def stmt(code: String) = stmts(code).last // guaranteed nonempty + } + + /** Parse a line into a sequence of trees. Returns None if the input is incomplete. */ + def parse(line: String): Option[List[Tree]] = debugging(s"""parse("$line")""") { + var isIncomplete = false + reporter.withIncompleteHandler((_, _) => isIncomplete = true) { + val trees = codeParser.stmts(line) + if (reporter.hasErrors) Some(Nil) + else if (isIncomplete) None + else Some(trees) + } + } + // def parsesAsExpr(line: String) = { + // import codeParser._ + // (opt expr line).isDefined + // } + + def symbolOfLine(code: String): Symbol = { + def asExpr(): Symbol = { + val name = freshInternalVarName() + // Typing it with a lazy val would give us the right type, but runs + // into compiler bugs with things like existentials, so we compile it + // behind a def and strip the NullaryMethodType which wraps the expr. + val line = "def " + name + " = {\n" + code + "\n}" + + interpretSynthetic(line) match { + case IR.Success => + val sym0 = symbolOfTerm(name) + // drop NullaryMethodType + val sym = sym0.cloneSymbol setInfo afterTyper(sym0.info.finalResultType) + if (sym.info.typeSymbol eq UnitClass) NoSymbol + else sym + case _ => NoSymbol + } + } + def asDefn(): Symbol = { + val old = repl.definedSymbolList.toSet + + interpretSynthetic(code) match { + case IR.Success => + repl.definedSymbolList filterNot old match { + case Nil => NoSymbol + case sym :: Nil => sym + case syms => NoSymbol.newOverloaded(NoPrefix, syms) + } + case _ => NoSymbol + } + } + beQuietDuring(asExpr()) orElse beQuietDuring(asDefn()) + } + + private var typeOfExpressionDepth = 0 + def typeOfExpression(expr: String, silent: Boolean = true): Type = { + if (typeOfExpressionDepth > 2) { + logDebug("Terminating typeOfExpression recursion for expression: " + expr) + return NoType + } + typeOfExpressionDepth += 1 + // Don't presently have a good way to suppress undesirable success output + // while letting errors through, so it is first trying it silently: if there + // is an error, and errors are desired, then it re-evaluates non-silently + // to induce the error message. + try beSilentDuring(symbolOfLine(expr).tpe) match { + case NoType if !silent => symbolOfLine(expr).tpe // generate error + case tpe => tpe + } + finally typeOfExpressionDepth -= 1 + } +} diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 0ced284da6..523fd1222d 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -1,26 +1,38 @@ /* NSC -- new Scala compiler - * Copyright 2005-2011 LAMP/EPFL + * Copyright 2005-2013 LAMP/EPFL * @author Alexander Spoon */ package org.apache.spark.repl + import scala.tools.nsc._ import scala.tools.nsc.interpreter._ -import Predef.{ println => _, _ } -import java.io.{ BufferedReader, FileReader, PrintWriter } -import scala.sys.process.Process -import session._ import scala.tools.nsc.interpreter.{ Results => IR } -import scala.tools.util.{ SignalManager, Signallable, Javap } +import Predef.{ println => _, _ } +import java.io.{ BufferedReader, FileReader } +import java.util.concurrent.locks.ReentrantLock +import scala.sys.process.Process +import scala.tools.nsc.interpreter.session._ +import scala.util.Properties.{ jdkHome, javaVersion } +import scala.tools.util.{ Javap } import scala.annotation.tailrec -import scala.util.control.Exception.{ ignoring } import scala.collection.mutable.ListBuffer import scala.concurrent.ops -import util.{ ClassPath, Exceptional, stringFromWriter, stringFromStream } -import interpreter._ -import io.{ File, Sources } +import scala.tools.nsc.util.{ ClassPath, Exceptional, stringFromWriter, stringFromStream } +import scala.tools.nsc.interpreter._ +import scala.tools.nsc.io.{ File, Directory } +import scala.reflect.NameTransformer._ +import scala.tools.nsc.util.ScalaClassLoader +import scala.tools.nsc.util.ScalaClassLoader._ +import scala.tools.util._ +import scala.language.{implicitConversions, existentials} +import scala.reflect.{ClassTag, classTag} +import scala.tools.reflect.StdRuntimeTags._ + +import java.lang.{Class => jClass} +import scala.reflect.api.{Mirror, TypeCreator, Universe => ApiUniverse} import org.apache.spark.Logging import org.apache.spark.SparkContext @@ -37,45 +49,86 @@ import org.apache.spark.SparkContext * @author Lex Spoon * @version 1.2 */ -class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: Option[String]) +class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, + val master: Option[String]) extends AnyRef with LoopCommands + with SparkILoopInit with Logging { - def this(in0: BufferedReader, out: PrintWriter, master: String) = this(Some(in0), out, Some(master)) - def this(in0: BufferedReader, out: PrintWriter) = this(Some(in0), out, None) - def this() = this(None, new PrintWriter(Console.out, true), None) - + def this(in0: BufferedReader, out: JPrintWriter, master: String) = this(Some(in0), out, Some(master)) + def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out, None) + def this() = this(None, new JPrintWriter(Console.out, true), None) + var in: InteractiveReader = _ // the input stream from which commands come var settings: Settings = _ var intp: SparkIMain = _ - /* - lazy val power = { - val g = intp.global - Power[g.type](this, g) + @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp + @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: SparkIMain): Unit = intp = i + + /** Having inherited the difficult "var-ness" of the repl instance, + * I'm trying to work around it by moving operations into a class from + * which it will appear a stable prefix. + */ + private def onIntp[T](f: SparkIMain => T): T = f(intp) + + class IMainOps[T <: SparkIMain](val intp: T) { + import intp._ + import global._ + + def printAfterTyper(msg: => String) = + intp.reporter printMessage afterTyper(msg) + + /** Strip NullaryMethodType artifacts. */ + private def replInfo(sym: Symbol) = { + sym.info match { + case NullaryMethodType(restpe) if sym.isAccessor => restpe + case info => info + } + } + def echoTypeStructure(sym: Symbol) = + printAfterTyper("" + deconstruct.show(replInfo(sym))) + + def echoTypeSignature(sym: Symbol, verbose: Boolean) = { + if (verbose) SparkILoop.this.echo("// Type signature") + printAfterTyper("" + replInfo(sym)) + + if (verbose) { + SparkILoop.this.echo("\n// Internal Type structure") + echoTypeStructure(sym) + } + } } - */ - - // TODO - // object opt extends AestheticSettings - // - @deprecated("Use `intp` instead.", "2.9.0") - def interpreter = intp - - @deprecated("Use `intp` instead.", "2.9.0") - def interpreter_= (i: SparkIMain): Unit = intp = i - + implicit def stabilizeIMain(intp: SparkIMain) = new IMainOps[intp.type](intp) + + /** TODO - + * -n normalize + * -l label with case class parameter names + * -c complete - leave nothing out + */ + private def typeCommandInternal(expr: String, verbose: Boolean): Result = { + onIntp { intp => + val sym = intp.symbolOfLine(expr) + if (sym.exists) intp.echoTypeSignature(sym, verbose) + else "" + } + } + + var sparkContext: SparkContext = _ + + override def echoCommandMessage(msg: String) { + intp.reporter printMessage msg + } + + // def isAsync = !settings.Yreplsync.value + def isAsync = false + // lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals]) def history = in.history /** The context class loader at the time this object was created */ protected val originalClassLoader = Thread.currentThread.getContextClassLoader - // Install a signal handler so we can be prodded. - private val signallable = - /*if (isReplDebug) Signallable("Dump repl state.")(dumpCommand()) - else*/ null - // classpath entries added via :cp var addedClasspath: String = "" @@ -87,74 +140,49 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: /** Record a command for replay should the user request a :replay */ def addReplay(cmd: String) = replayCommandStack ::= cmd - - /** Try to install sigint handler: ignore failure. Signal handler - * will interrupt current line execution if any is in progress. - * - * Attempting to protect the repl from accidental exit, we only honor - * a single ctrl-C if the current buffer is empty: otherwise we look - * for a second one within a short time. - */ - private def installSigIntHandler() { - def onExit() { - Console.println("") // avoiding "shell prompt in middle of line" syndrome - sys.exit(1) - } - ignoring(classOf[Exception]) { - SignalManager("INT") = { - if (intp == null) - onExit() - else if (intp.lineManager.running) - intp.lineManager.cancel() - else if (in.currentLine != "") { - // non-empty buffer, so make them hit ctrl-C a second time - SignalManager("INT") = onExit() - io.timer(5)(installSigIntHandler()) // and restore original handler if they don't - } - else onExit() - } - } + + def savingReplayStack[T](body: => T): T = { + val saved = replayCommandStack + try body + finally replayCommandStack = saved + } + def savingReader[T](body: => T): T = { + val saved = in + try body + finally in = saved } + + def sparkCleanUp(){ + echo("Stopping spark context.") + intp.beQuietDuring { + command("sc.stop()") + } + } /** Close the interpreter and set the var to null. */ def closeInterpreter() { if (intp ne null) { - intp.close + sparkCleanUp() + intp.close() intp = null - Thread.currentThread.setContextClassLoader(originalClassLoader) } } - + class SparkILoopInterpreter extends SparkIMain(settings, out) { + outer => + override lazy val formatting = new Formatting { def prompt = SparkILoop.this.prompt } - override protected def createLineManager() = new Line.Manager { - override def onRunaway(line: Line[_]): Unit = { - val template = """ - |// She's gone rogue, captain! Have to take her out! - |// Calling Thread.stop on runaway %s with offending code: - |// scala> %s""".stripMargin - - echo(template.format(line.thread, line.code)) - // XXX no way to suppress the deprecation warning - line.thread.stop() - in.redrawLine() - } - } - override protected def parentClassLoader = { - SparkHelper.explicitParentLoader(settings).getOrElse( classOf[SparkILoop].getClassLoader ) - } + override protected def parentClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse(classOf[SparkILoop].getClassLoader) } /** Create a new interpreter. */ def createInterpreter() { if (addedClasspath != "") settings.classpath append addedClasspath - + intp = new SparkILoopInterpreter - intp.setContextClassLoader() - installSigIntHandler() } /** print a friendly help message */ @@ -168,10 +196,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: private def helpSummary() = { val usageWidth = commands map (_.usageMsg.length) max val formatStr = "%-" + usageWidth + "s %s %s" - + echo("All commands can be abbreviated, e.g. :he instead of :help.") echo("Those marked with a * have more detailed help, e.g. :help imports.\n") - + commands foreach { cmd => val star = if (cmd.hasLongHelp) "*" else " " echo(formatStr.format(cmd.usageMsg, star, cmd.help)) @@ -182,7 +210,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: case Nil => echo(cmd + ": no such command. Type :help for help.") case xs => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?") } - Result(true, None) + Result(true, None) } private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd) private def uniqueCommand(cmd: String): Option[LoopCommand] = { @@ -193,31 +221,16 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: case xs => xs find (_.name == cmd) } } - - /** Print a welcome message */ - def printWelcome() { - echo("""Welcome to - ____ __ - / __/__ ___ _____/ /__ - _\ \/ _ \/ _ `/ __/ '_/ - /___/ .__/\_,_/_/ /_/\_\ version 0.9.0-SNAPSHOT - /_/ -""") - import Properties._ - val welcomeMsg = "Using Scala %s (%s, Java %s)".format( - versionString, javaVmName, javaVersion) - echo(welcomeMsg) - } - + /** Show the history */ lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { override def usage = "[num]" def defaultLines = 20 - + def apply(line: String): Result = { if (history eq NoHistory) return "No history available." - + val xs = words(line) val current = history.index val count = try xs.head.toInt catch { case _: Exception => defaultLines } @@ -229,32 +242,38 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: } } - private def echo(msg: String) = { + // When you know you are most likely breaking into the middle + // of a line being typed. This softens the blow. + protected def echoAndRefresh(msg: String) = { + echo("\n" + msg) + in.redrawLine() + } + protected def echo(msg: String) = { out println msg out.flush() } - private def echoNoNL(msg: String) = { + protected def echoNoNL(msg: String) = { out print msg out.flush() } - + /** Search the history */ def searchHistory(_cmdline: String) { val cmdline = _cmdline.toLowerCase val offset = history.index - history.size + 1 - + for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline) echo("%d %s".format(index + offset, line)) } - + private var currentPrompt = Properties.shellPromptString def setPrompt(prompt: String) = currentPrompt = prompt /** Prompt to print when awaiting input */ def prompt = currentPrompt - + import LoopCommand.{ cmd, nullary } - /** Standard commands **/ + /** Standard commands */ lazy val standardCommands = List( cmd("cp", "", "add a jar or directory to the classpath", addClasspath), cmd("help", "[command]", "print this summary or command-specific help", helpCommand), @@ -263,53 +282,30 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: cmd("imports", "[name name ...]", "show import history, identifying sources of names", importsCommand), cmd("implicits", "[-v]", "show the implicits in scope", implicitsCommand), cmd("javap", "", "disassemble a file or class name", javapCommand), - nullary("keybindings", "show how ctrl-[A-Z] and other keys are bound", keybindingsCommand), cmd("load", "", "load and interpret a Scala file", loadCommand), nullary("paste", "enter paste mode: all input up to ctrl-D compiled together", pasteCommand), - //nullary("power", "enable power user mode", powerCmd), - nullary("quit", "exit the interpreter", () => Result(false, None)), +// nullary("power", "enable power user mode", powerCmd), + nullary("quit", "exit the repl", () => Result(false, None)), nullary("replay", "reset execution and replay all previous commands", replay), + nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand), shCommand, nullary("silent", "disable/enable automatic printing of results", verbosity), - cmd("type", "", "display the type of an expression without evaluating it", typeCommand) + cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand), + nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand) ) - + /** Power user commands */ lazy val powerCommands: List[LoopCommand] = List( - //nullary("dump", "displays a view of the interpreter's internal state", dumpCommand), - //cmd("phase", "", "set the implicit phase for power commands", phaseCommand), - cmd("wrap", "", "name of method to wrap around each repl line", wrapCommand) withLongHelp (""" - |:wrap - |:wrap clear - |:wrap - | - |Installs a wrapper around each line entered into the repl. - |Currently it must be the simple name of an existing method - |with the specific signature shown in the following example. - | - |def timed[T](body: => T): T = { - | val start = System.nanoTime - | try body - | finally println((System.nanoTime - start) + " nanos elapsed.") - |} - |:wrap timed - | - |If given no argument, :wrap names the wrapper installed. - |An argument of clear will remove the wrapper if any is active. - |Note that wrappers do not compose (a new one replaces the old - |one) and also that the :phase command uses the same machinery, - |so setting :wrap will clear any :phase setting. - """.stripMargin.trim) + // cmd("phase", "", "set the implicit phase for power commands", phaseCommand) ) - - /* - private def dumpCommand(): Result = { - echo("" + power) - history.asStrings takeRight 30 foreach echo - in.redrawLine() - } - */ - + + // private def dumpCommand(): Result = { + // echo("" + power) + // history.asStrings takeRight 30 foreach echo + // in.redrawLine() + // } + // private def valsCommand(): Result = power.valsDescription + private val typeTransforms = List( "scala.collection.immutable." -> "immutable.", "scala.collection.mutable." -> "mutable.", @@ -317,7 +313,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: "java.lang." -> "jl.", "scala.runtime." -> "runtime." ) - + private def importsCommand(line: String): Result = { val tokens = words(line) val handlers = intp.languageWildcardHandlers ++ intp.importHandlers @@ -333,7 +329,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: val implicitMsg = if (imps.isEmpty) "" else imps.size + " are implicit" val foundMsg = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "") val statsMsg = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")") - + intp.reporter.printMessage("%2d) %-30s %s%s".format( idx + 1, handler.importString, @@ -342,12 +338,11 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: )) } } - - private def implicitsCommand(line: String): Result = { - val intp = SparkILoop.this.intp + + private def implicitsCommand(line: String): Result = onIntp { intp => import intp._ - import global.Symbol - + import global._ + def p(x: Any) = intp.reporter.printMessage("" + x) // If an argument is given, only show a source with that @@ -360,17 +355,17 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: else (args exists (source.name.toString contains _)) } } - + if (filtered.isEmpty) return "No implicits have been imported other than those in Predef." - + filtered foreach { case (source, syms) => p("/* " + syms.size + " implicit members imported from " + source.fullName + " */") - + // This groups the members by where the symbol is defined val byOwner = syms groupBy (_.owner) - val sortedOwners = byOwner.toList sortBy { case (owner, _) => intp.afterTyper(source.info.baseClasses indexOf owner) } + val sortedOwners = byOwner.toList sortBy { case (owner, _) => afterTyper(source.info.baseClasses indexOf owner) } sortedOwners foreach { case (owner, members) => @@ -388,10 +383,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: xss map (xs => xs sortBy (_.name.toString)) } - - val ownerMessage = if (owner == source) " defined in " else " inherited from " + + val ownerMessage = if (owner == source) " defined in " else " inherited from " p(" /* " + members.size + ownerMessage + owner.fullName + " */") - + memberGroups foreach { group => group foreach (s => p(" " + intp.symbolDefString(s))) p("") @@ -400,158 +395,182 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: p("") } } - - protected def newJavap() = new Javap(intp.classLoader, new SparkIMain.ReplStrippingWriter(intp)) { - override def tryClass(path: String): Array[Byte] = { - // Look for Foo first, then Foo$, but if Foo$ is given explicitly, - // we have to drop the $ to find object Foo, then tack it back onto - // the end of the flattened name. - def className = intp flatName path - def moduleName = (intp flatName path.stripSuffix("$")) + "$" - val bytes = super.tryClass(className) - if (bytes.nonEmpty) bytes - else super.tryClass(moduleName) + private def findToolsJar() = { + val jdkPath = Directory(jdkHome) + val jar = jdkPath / "lib" / "tools.jar" toFile; + + if (jar isFile) + Some(jar) + else if (jdkPath.isDirectory) + jdkPath.deepFiles find (_.name == "tools.jar") + else None + } + private def addToolsJarToLoader() = { + val cl = findToolsJar match { + case Some(tools) => ScalaClassLoader.fromURLs(Seq(tools.toURL), intp.classLoader) + case _ => intp.classLoader + } + if (Javap.isAvailable(cl)) { + logDebug(":javap available.") + cl + } + else { + logDebug(":javap unavailable: no tools.jar at " + jdkHome) + intp.classLoader } } + + protected def newJavap() = new JavapClass(addToolsJarToLoader(), new SparkIMain.ReplStrippingWriter(intp)) { + override def tryClass(path: String): Array[Byte] = { + val hd :: rest = path split '.' toList; + // If there are dots in the name, the first segment is the + // key to finding it. + if (rest.nonEmpty) { + intp optFlatName hd match { + case Some(flat) => + val clazz = flat :: rest mkString NAME_JOIN_STRING + val bytes = super.tryClass(clazz) + if (bytes.nonEmpty) bytes + else super.tryClass(clazz + MODULE_SUFFIX_STRING) + case _ => super.tryClass(path) + } + } + else { + // Look for Foo first, then Foo$, but if Foo$ is given explicitly, + // we have to drop the $ to find object Foo, then tack it back onto + // the end of the flattened name. + def className = intp flatName path + def moduleName = (intp flatName path.stripSuffix(MODULE_SUFFIX_STRING)) + MODULE_SUFFIX_STRING + + val bytes = super.tryClass(className) + if (bytes.nonEmpty) bytes + else super.tryClass(moduleName) + } + } + } + // private lazy val javap = substituteAndLog[Javap]("javap", NoJavap)(newJavap()) private lazy val javap = try newJavap() catch { case _: Exception => null } - - private def typeCommand(line: String): Result = { - intp.typeOfExpression(line) match { - case Some(tp) => tp.toString - case _ => "Failed to determine type." + + // Still todo: modules. + private def typeCommand(line0: String): Result = { + line0.trim match { + case "" => ":type [-v] " + case s if s startsWith "-v " => typeCommandInternal(s stripPrefix "-v " trim, true) + case s => typeCommandInternal(s, false) } } - + + private def warningsCommand(): Result = { + if (intp.lastWarnings.isEmpty) + "Can't find any cached warnings." + else + intp.lastWarnings foreach { case (pos, msg) => intp.reporter.warning(pos, msg) } + } + private def javapCommand(line: String): Result = { if (javap == null) - return ":javap unavailable on this platform." - if (line == "") - return ":javap [-lcsvp] [path1 path2 ...]" - - javap(words(line)) foreach { res => - if (res.isError) return "Failed: " + res.value - else res.show() - } - } - private def keybindingsCommand(): Result = { - if (in.keyBindings.isEmpty) "Key bindings unavailable." - else { - echo("Reading jline properties for default key bindings.") - echo("Accuracy not guaranteed: treat this as a guideline only.\n") - in.keyBindings foreach (x => echo ("" + x)) - } + ":javap unavailable, no tools.jar at %s. Set JDK_HOME.".format(jdkHome) + else if (javaVersion startsWith "1.7") + ":javap not yet working with java 1.7" + else if (line == "") + ":javap [-lcsvp] [path1 path2 ...]" + else + javap(words(line)) foreach { res => + if (res.isError) return "Failed: " + res.value + else res.show() + } } + private def wrapCommand(line: String): Result = { def failMsg = "Argument to :wrap must be the name of a method with signature [T](=> T): T" - val intp = SparkILoop.this.intp - val g: intp.global.type = intp.global - import g._ + onIntp { intp => + import intp._ + import global._ - words(line) match { - case Nil => - intp.executionWrapper match { - case "" => "No execution wrapper is set." - case s => "Current execution wrapper: " + s - } - case "clear" :: Nil => - intp.executionWrapper match { - case "" => "No execution wrapper is set." - case s => intp.clearExecutionWrapper() ; "Cleared execution wrapper." - } - case wrapper :: Nil => - intp.typeOfExpression(wrapper) match { - case Some(PolyType(List(targ), MethodType(List(arg), restpe))) => - intp setExecutionWrapper intp.pathToTerm(wrapper) - "Set wrapper to '" + wrapper + "'" - case Some(x) => - failMsg + "\nFound: " + x - case _ => - failMsg + "\nFound: " - } - case _ => failMsg + words(line) match { + case Nil => + intp.executionWrapper match { + case "" => "No execution wrapper is set." + case s => "Current execution wrapper: " + s + } + case "clear" :: Nil => + intp.executionWrapper match { + case "" => "No execution wrapper is set." + case s => intp.clearExecutionWrapper() ; "Cleared execution wrapper." + } + case wrapper :: Nil => + intp.typeOfExpression(wrapper) match { + case PolyType(List(targ), MethodType(List(arg), restpe)) => + intp setExecutionWrapper intp.pathToTerm(wrapper) + "Set wrapper to '" + wrapper + "'" + case tp => + failMsg + "\nFound: " + } + case _ => failMsg + } } } private def pathToPhaseWrapper = intp.pathToTerm("$r") + ".phased.atCurrent" - /* - private def phaseCommand(name: String): Result = { - // This line crashes us in TreeGen: - // - // if (intp.power.phased set name) "..." - // - // Exception in thread "main" java.lang.AssertionError: assertion failed: ._7.type - // at scala.Predef$.assert(Predef.scala:99) - // at scala.tools.nsc.ast.TreeGen.mkAttributedQualifier(TreeGen.scala:69) - // at scala.tools.nsc.ast.TreeGen.mkAttributedQualifier(TreeGen.scala:44) - // at scala.tools.nsc.ast.TreeGen.mkAttributedRef(TreeGen.scala:101) - // at scala.tools.nsc.ast.TreeGen.mkAttributedStableRef(TreeGen.scala:143) - // - // But it works like so, type annotated. - val phased: Phased = power.phased - import phased.NoPhaseName + // private def phaseCommand(name: String): Result = { + // val phased: Phased = power.phased + // import phased.NoPhaseName + + // if (name == "clear") { + // phased.set(NoPhaseName) + // intp.clearExecutionWrapper() + // "Cleared active phase." + // } + // else if (name == "") phased.get match { + // case NoPhaseName => "Usage: :phase (e.g. typer, erasure.next, erasure+3)" + // case ph => "Active phase is '%s'. (To clear, :phase clear)".format(phased.get) + // } + // else { + // val what = phased.parse(name) + // if (what.isEmpty || !phased.set(what)) + // "'" + name + "' does not appear to represent a valid phase." + // else { + // intp.setExecutionWrapper(pathToPhaseWrapper) + // val activeMessage = + // if (what.toString.length == name.length) "" + what + // else "%s (%s)".format(what, name) + + // "Active phase is now: " + activeMessage + // } + // } + // } - if (name == "clear") { - phased.set(NoPhaseName) - intp.clearExecutionWrapper() - "Cleared active phase." - } - else if (name == "") phased.get match { - case NoPhaseName => "Usage: :phase (e.g. typer, erasure.next, erasure+3)" - case ph => "Active phase is '%s'. (To clear, :phase clear)".format(phased.get) - } - else { - val what = phased.parse(name) - if (what.isEmpty || !phased.set(what)) - "'" + name + "' does not appear to represent a valid phase." - else { - intp.setExecutionWrapper(pathToPhaseWrapper) - val activeMessage = - if (what.toString.length == name.length) "" + what - else "%s (%s)".format(what, name) - - "Active phase is now: " + activeMessage - } - } - } - */ - /** Available commands */ - def commands: List[LoopCommand] = standardCommands /* ++ ( + def commands: List[LoopCommand] = standardCommands /*++ ( if (isReplPower) powerCommands else Nil )*/ - + val replayQuestionMessage = - """|The repl compiler has crashed spectacularly. Shall I replay your - |session? I can re-run all lines except the last one. + """|That entry seems to have slain the compiler. Shall I replay + |your session? I can re-run each line except the last one. |[y/n] """.trim.stripMargin - private val crashRecovery: PartialFunction[Throwable, Unit] = { + private val crashRecovery: PartialFunction[Throwable, Boolean] = { case ex: Throwable => - if (settings.YrichExes.value) { - val sources = implicitly[Sources] - echo("\n" + ex.getMessage) - echo( - if (isReplDebug) "[searching " + sources.path + " for exception contexts...]" - else "[searching for exception contexts...]" - ) - echo(Exceptional(ex).force().context()) - } - else { - echo(util.stackTraceString(ex)) - } + echo(intp.global.throwableAsString(ex)) + ex match { case _: NoSuchMethodError | _: NoClassDefFoundError => - echo("Unrecoverable error.") + echo("\nUnrecoverable error.") throw ex case _ => - def fn(): Boolean = in.readYesOrNo(replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() }) + def fn(): Boolean = + try in.readYesOrNo(replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() }) + catch { case _: RuntimeException => false } + if (fn()) replay() else echo("\nAbandoning crashed session.") } + true } /** The main read-eval-print loop for the repl. It calls @@ -564,66 +583,89 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: in readLine prompt } // return false if repl should exit - def processLine(line: String): Boolean = + def processLine(line: String): Boolean = { + if (isAsync) { + if (!awaitInitialized()) return false + runThunks() + } if (line eq null) false // assume null means EOF else command(line) match { case Result(false, _) => false case Result(_, Some(finalLine)) => addReplay(finalLine) ; true case _ => true } - - while (true) { - try if (!processLine(readOneLine)) return - catch crashRecovery } + def innerLoop() { + if ( try processLine(readOneLine()) catch crashRecovery ) + innerLoop() + } + innerLoop() } /** interpret all lines from a specified file */ - def interpretAllFrom(file: File) { - val oldIn = in - val oldReplay = replayCommandStack - - try file applyReader { reader => - in = SimpleReader(reader, out, false) - echo("Loading " + file + "...") - loop() - } - finally { - in = oldIn - replayCommandStack = oldReplay + def interpretAllFrom(file: File) { + savingReader { + savingReplayStack { + file applyReader { reader => + in = SimpleReader(reader, out, false) + echo("Loading " + file + "...") + loop() + } + } } } - /** create a new interpreter and replay all commands so far */ + /** create a new interpreter and replay the given commands */ def replay() { - closeInterpreter() - createInterpreter() - for (cmd <- replayCommands) { + reset() + if (replayCommandStack.isEmpty) + echo("Nothing to replay.") + else for (cmd <- replayCommands) { echo("Replaying: " + cmd) // flush because maybe cmd will have its own output command(cmd) echo("") } } - + def resetCommand() { + echo("Resetting repl state.") + if (replayCommandStack.nonEmpty) { + echo("Forgetting this session history:\n") + replayCommands foreach echo + echo("") + replayCommandStack = Nil + } + if (intp.namedDefinedTerms.nonEmpty) + echo("Forgetting all expression results and named terms: " + intp.namedDefinedTerms.mkString(", ")) + if (intp.definedTypes.nonEmpty) + echo("Forgetting defined types: " + intp.definedTypes.mkString(", ")) + + reset() + } + + def reset() { + intp.reset() + // unleashAndSetPhase() + } + /** fork a shell and run a command */ lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") { override def usage = "" def apply(line: String): Result = line match { case "" => showUsage() - case _ => + case _ => val toRun = classOf[ProcessResult].getName + "(" + string2codeQuoted(line) + ")" intp interpret toRun () } } - + def withFile(filename: String)(action: File => Unit) { val f = File(filename) - + if (f.exists) action(f) else echo("That file does not exist") } - + def loadCommand(arg: String) = { var shouldReplay: Option[String] = None withFile(arg)(f => { @@ -657,23 +699,36 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: } else echo("The path '" + f + "' doesn't seem to exist.") } - + def powerCmd(): Result = { if (isReplPower) "Already in power mode." - else enablePowerMode() + else enablePowerMode(false) } - def enablePowerMode() = { - //replProps.power setValue true - //power.unleash() - //echo(power.banner) + + def enablePowerMode(isDuringInit: Boolean) = { + // replProps.power setValue true + // unleashAndSetPhase() + // asyncEcho(isDuringInit, power.banner) } - + // private def unleashAndSetPhase() { +// if (isReplPower) { +// // power.unleash() +// // Set the phase to "typer" +// intp beSilentDuring phaseCommand("typer") +// } +// } + + def asyncEcho(async: Boolean, msg: => String) { + if (async) asyncMessage(msg) + else echo(msg) + } + def verbosity() = { - val old = intp.printResults - intp.printResults = !old - echo("Switched " + (if (old) "off" else "on") + " result printing.") + // val old = intp.printResults + // intp.printResults = !old + // echo("Switched " + (if (old) "off" else "on") + " result printing.") } - + /** Run one command submitted by the user. Two values are returned: * (1) whether to keep running, (2) the line to record for replay, * if any. */ @@ -688,11 +743,11 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: else if (intp.global == null) Result(false, None) // Notice failure to create compiler else Result(true, interpretStartingWith(line)) } - + private def readWhile(cond: String => Boolean) = { Iterator continually in.readLine("") takeWhile (x => x != null && cond(x)) } - + def pasteCommand(): Result = { echo("// Entering paste mode (ctrl-D to finish)\n") val code = readWhile(_ => true) mkString "\n" @@ -700,23 +755,19 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: intp interpret code () } - + private object paste extends Pasted { val ContinueString = " | " val PromptString = "scala> " - + def interpret(line: String): Unit = { echo(line.trim) intp interpret line echo("") } - + def transcript(start: String) = { - // Printing this message doesn't work very well because it's buried in the - // transcript they just pasted. Todo: a short timer goes off when - // lines stop coming which tells them to hit ctrl-D. - // - // echo("// Detected repl transcript paste: ctrl-D to finish.") + echo("\n// Detected repl transcript paste: ctrl-D to finish.\n") apply(Iterator(start) ++ readWhile(_.trim != PromptString.trim)) } } @@ -731,7 +782,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: def interpretStartingWith(code: String): Option[String] = { // signal completion non-completion input has been received in.completion.resetVerbosity() - + def reallyInterpret = { val reallyResult = intp.interpret(code) (reallyResult, reallyResult match { @@ -741,7 +792,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: if (in.interactive && code.endsWith("\n\n")) { echo("You typed two blank lines. Starting a new command.") None - } + } else in.readLine(ContinueString) match { case null => // we know compilation is going to fail since we're at EOF and the @@ -755,10 +806,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: } }) } - + /** Here we place ourselves between the user and the interpreter and examine * the input they are ostensibly submitting. We intervene in several cases: - * + * * 1) If the line starts with "scala> " it is assumed to be an interpreter paste. * 2) If the line starts with "." (but not ".." or "./") it is treated as an invocation * on the previous result. @@ -773,28 +824,12 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: else if (Completion.looksLikeInvocation(code) && intp.mostRecentVar != "") { interpretStartingWith(intp.mostRecentVar + code) } - else { - def runCompletion = in.completion execute code map (intp bindValue _) - /** Due to my accidentally letting file completion execution sneak ahead - * of actual parsing this now operates in such a way that the scala - * interpretation always wins. However to avoid losing useful file - * completion I let it fail and then check the others. So if you - * type /tmp it will echo a failure and then give you a Directory object. - * It's not pretty: maybe I'll implement the silence bits I need to avoid - * echoing the failure. - */ - if (intp isParseable code) { - val (code, result) = reallyInterpret - //if (power != null && code == IR.Error) - // runCompletion - - result - } - else runCompletion match { - case Some(_) => None // completion hit: avoid the latent error - case _ => reallyInterpret._2 // trigger the latent error - } + else if (code.trim startsWith "//") { + // line comment, do nothing + None } + else + reallyInterpret._2 } // runs :load `file` on any files passed via -i @@ -808,7 +843,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: } case _ => } - + /** Tries to create a JLineReader, falling back to SimpleReader: * unless settings or properties are such that it should start * with SimpleReader. @@ -816,7 +851,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: def chooseReader(settings: Settings): InteractiveReader = { if (settings.Xnojline.value || Properties.isEmacsShell) SimpleReader() - else try SparkJLineReader( + else try new SparkJLineReader( if (settings.noCompletion.value) NoCompletion else new SparkJLineCompletion(intp) ) @@ -827,22 +862,71 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: } } - def initializeSpark() { - intp.beQuietDuring { - command(""" - org.apache.spark.repl.Main.interp.out.println("Creating SparkContext..."); - org.apache.spark.repl.Main.interp.out.flush(); - @transient val sc = org.apache.spark.repl.Main.interp.createSparkContext(); - org.apache.spark.repl.Main.interp.out.println("Spark context available as sc."); - org.apache.spark.repl.Main.interp.out.flush(); - """) - command("import org.apache.spark.SparkContext._") - } - echo("Type in expressions to have them evaluated.") - echo("Type :help for more information.") - } + val u: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe + val m = u.runtimeMirror(getClass.getClassLoader) + private def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] = + u.TypeTag[T]( + m, + new TypeCreator { + def apply[U <: ApiUniverse with Singleton](m: Mirror[U]): U # Type = + m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type] + }) - var sparkContext: SparkContext = null + def process(settings: Settings): Boolean = savingContextLoader { + this.settings = settings + createInterpreter() + + // sets in to some kind of reader depending on environmental cues + in = in0 match { + case Some(reader) => SimpleReader(reader, out, true) + case None => + // some post-initialization + chooseReader(settings) match { + case x: SparkJLineReader => addThunk(x.consoleReader.postInit) ; x + case x => x + } + } + lazy val tagOfSparkIMain = tagOfStaticClass[org.apache.spark.repl.SparkIMain] + // Bind intp somewhere out of the regular namespace where + // we can get at it in generated code. + addThunk(intp.quietBind(NamedParam[SparkIMain]("$intp", intp)(tagOfSparkIMain, classTag[SparkIMain]))) + addThunk({ + import scala.tools.nsc.io._ + import Properties.userHome + import scala.compat.Platform.EOL + val autorun = replProps.replAutorunCode.option flatMap (f => io.File(f).safeSlurp()) + if (autorun.isDefined) intp.quietRun(autorun.get) + }) + + addThunk(printWelcome()) + addThunk(initializeSpark()) + + loadFiles(settings) + // it is broken on startup; go ahead and exit + if (intp.reporter.hasErrors) + return false + + // This is about the illusion of snappiness. We call initialize() + // which spins off a separate thread, then print the prompt and try + // our best to look ready. The interlocking lazy vals tend to + // inter-deadlock, so we break the cycle with a single asynchronous + // message to an actor. + if (isAsync) { + intp initialize initializedCallback() + createAsyncListener() // listens for signal to run postInitialization + } + else { + intp.initializeSynchronous() + postInitialization() + } + // printWelcome() + + try loop() + catch AbstractOrMissingHandler() + finally closeInterpreter() + + true + } def createSparkContext(): SparkContext = { val uri = System.getenv("SPARK_EXECUTOR_URI") @@ -856,85 +940,26 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: if (prop != null) prop else "local" } } - val jars = Option(System.getenv("ADD_JARS")).map(_.split(',')) - .getOrElse(new Array[String](0)) - .map(new java.io.File(_).getAbsolutePath) - try { - sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars) - } catch { - case e: Exception => - e.printStackTrace() - echo("Failed to create SparkContext, exiting...") - sys.exit(1) - } + val jars = SparkILoop.getAddedJars.map(new java.io.File(_).getAbsolutePath) + sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars) + echo("Created spark context..") sparkContext } - def process(settings: Settings): Boolean = { - // Ensure logging is initialized before any Spark threads try to use logs - // (because SLF4J initialization is not thread safe) - initLogging() - - printWelcome() - echo("Initializing interpreter...") - - // Add JARS specified in Spark's ADD_JARS variable to classpath - val jars = Option(System.getenv("ADD_JARS")).map(_.split(',')).getOrElse(new Array[String](0)) - jars.foreach(settings.classpath.append(_)) - - this.settings = settings - createInterpreter() - - // sets in to some kind of reader depending on environmental cues - in = in0 match { - case Some(reader) => SimpleReader(reader, out, true) - case None => chooseReader(settings) - } - - loadFiles(settings) - // it is broken on startup; go ahead and exit - if (intp.reporter.hasErrors) - return false - - try { - // this is about the illusion of snappiness. We call initialize() - // which spins off a separate thread, then print the prompt and try - // our best to look ready. Ideally the user will spend a - // couple seconds saying "wow, it starts so fast!" and by the time - // they type a command the compiler is ready to roll. - intp.initialize() - initializeSpark() - if (isReplPower) { - echo("Starting in power mode, one moment...\n") - enablePowerMode() - } - loop() - } - finally closeInterpreter() - true - } - /** process command-line arguments and do as they request */ def process(args: Array[String]): Boolean = { - val command = new CommandLine(args.toList, msg => echo("scala: " + msg)) + val command = new CommandLine(args.toList, echo) def neededHelp(): String = (if (command.settings.help.value) command.usageMsg + "\n" else "") + (if (command.settings.Xhelp.value) command.xusageMsg + "\n" else "") - + // if they asked for no help and command is valid, we call the real main neededHelp() match { case "" => command.ok && process(command.settings) case help => echoNoNL(help) ; true } } - - @deprecated("Use `process` instead", "2.9.0") - def main(args: Array[String]): Unit = { - if (isReplDebug) - System.out.println(new java.util.Date) - - process(args) - } + @deprecated("Use `process` instead", "2.9.0") def main(settings: Settings): Unit = process(settings) } @@ -943,15 +968,17 @@ object SparkILoop { implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp private def echo(msg: String) = Console println msg + def getAddedJars: Array[String] = Option(System.getenv("ADD_JARS")).map(_.split(',')).getOrElse(new Array[String](0)) + // Designed primarily for use by test code: take a String with a // bunch of code, and prints out a transcript of what it would look // like if you'd just typed it into the repl. def runForTranscript(code: String, settings: Settings): String = { import java.io.{ BufferedReader, StringReader, OutputStreamWriter } - + stringFromStream { ostream => Console.withOut(ostream) { - val output = new PrintWriter(new OutputStreamWriter(ostream), true) { + val output = new JPrintWriter(new OutputStreamWriter(ostream), true) { override def write(str: String) = { // completely skip continuation lines if (str forall (ch => ch.isWhitespace || ch == '|')) () @@ -970,26 +997,29 @@ object SparkILoop { } } val repl = new SparkILoop(input, output) + if (settings.classpath.isDefault) settings.classpath.value = sys.props("java.class.path") + getAddedJars.foreach(settings.classpath.append(_)) + repl process settings } } } - + /** Creates an interpreter loop with default settings and feeds * the given code to it as input. */ def run(code: String, sets: Settings = new Settings): String = { import java.io.{ BufferedReader, StringReader, OutputStreamWriter } - + stringFromStream { ostream => Console.withOut(ostream) { val input = new BufferedReader(new StringReader(code)) - val output = new PrintWriter(new OutputStreamWriter(ostream), true) - val repl = new SparkILoop(input, output) - + val output = new JPrintWriter(new OutputStreamWriter(ostream), true) + val repl = new ILoop(input, output) + if (sets.classpath.isDefault) sets.classpath.value = sys.props("java.class.path") @@ -998,32 +1028,4 @@ object SparkILoop { } } def run(lines: List[String]): String = run(lines map (_ + "\n") mkString) - - // provide the enclosing type T - // in order to set up the interpreter's classpath and parent class loader properly - def breakIf[T: Manifest](assertion: => Boolean, args: NamedParam*): Unit = - if (assertion) break[T](args.toList) - - // start a repl, binding supplied args - def break[T: Manifest](args: List[NamedParam]): Unit = { - val msg = if (args.isEmpty) "" else " Binding " + args.size + " value%s.".format( - if (args.size == 1) "" else "s" - ) - echo("Debug repl starting." + msg) - val repl = new SparkILoop { - override def prompt = "\ndebug> " - } - repl.settings = new Settings(echo) - repl.settings.embeddedDefaults[T] - repl.createInterpreter() - repl.in = SparkJLineReader(repl) - - // rebind exit so people don't accidentally call sys.exit by way of predef - repl.quietRun("""def exit = println("Type :quit to resume program execution.")""") - args foreach (p => repl.bind(p.name, p.tpe, p.value)) - repl.loop() - - echo("\nDebug repl exiting.") - repl.closeInterpreter() - } } diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala new file mode 100644 index 0000000000..21b1ba305d --- /dev/null +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -0,0 +1,143 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2013 LAMP/EPFL + * @author Paul Phillips + */ + +package org.apache.spark.repl + +import scala.tools.nsc._ +import scala.tools.nsc.interpreter._ + +import scala.reflect.internal.util.Position +import scala.util.control.Exception.ignoring +import scala.tools.nsc.util.stackTraceString + +/** + * Machinery for the asynchronous initialization of the repl. + */ +trait SparkILoopInit { + self: SparkILoop => + + /** Print a welcome message */ + def printWelcome() { + echo("""Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ version 0.9.0-SNAPSHOT + /_/ +""") + import Properties._ + val welcomeMsg = "Using Scala %s (%s, Java %s)".format( + versionString, javaVmName, javaVersion) + echo(welcomeMsg) + echo("Type in expressions to have them evaluated.") + echo("Type :help for more information.") + } + + protected def asyncMessage(msg: String) { + if (isReplInfo || isReplPower) + echoAndRefresh(msg) + } + + private val initLock = new java.util.concurrent.locks.ReentrantLock() + private val initCompilerCondition = initLock.newCondition() // signal the compiler is initialized + private val initLoopCondition = initLock.newCondition() // signal the whole repl is initialized + private val initStart = System.nanoTime + + private def withLock[T](body: => T): T = { + initLock.lock() + try body + finally initLock.unlock() + } + // a condition used to ensure serial access to the compiler. + @volatile private var initIsComplete = false + @volatile private var initError: String = null + private def elapsed() = "%.3f".format((System.nanoTime - initStart).toDouble / 1000000000L) + + // the method to be called when the interpreter is initialized. + // Very important this method does nothing synchronous (i.e. do + // not try to use the interpreter) because until it returns, the + // repl's lazy val `global` is still locked. + protected def initializedCallback() = withLock(initCompilerCondition.signal()) + + // Spins off a thread which awaits a single message once the interpreter + // has been initialized. + protected def createAsyncListener() = { + io.spawn { + withLock(initCompilerCondition.await()) + asyncMessage("[info] compiler init time: " + elapsed() + " s.") + postInitialization() + } + } + + // called from main repl loop + protected def awaitInitialized(): Boolean = { + if (!initIsComplete) + withLock { while (!initIsComplete) initLoopCondition.await() } + if (initError != null) { + println(""" + |Failed to initialize the REPL due to an unexpected error. + |This is a bug, please, report it along with the error diagnostics printed below. + |%s.""".stripMargin.format(initError) + ) + false + } else true + } + // private def warningsThunks = List( + // () => intp.bind("lastWarnings", "" + typeTag[List[(Position, String)]], intp.lastWarnings _), + // ) + + protected def postInitThunks = List[Option[() => Unit]]( + Some(intp.setContextClassLoader _), + if (isReplPower) Some(() => enablePowerMode(true)) else None + ).flatten + // ++ ( + // warningsThunks + // ) + // called once after init condition is signalled + protected def postInitialization() { + try { + postInitThunks foreach (f => addThunk(f())) + runThunks() + } catch { + case ex: Throwable => + initError = stackTraceString(ex) + throw ex + } finally { + initIsComplete = true + + if (isAsync) { + asyncMessage("[info] total init time: " + elapsed() + " s.") + withLock(initLoopCondition.signal()) + } + } + } + + def initializeSpark() { + intp.beQuietDuring { + command(""" + @transient val sc = org.apache.spark.repl.Main.interp.createSparkContext(); + """) + command("import org.apache.spark.SparkContext._") + } + echo("Spark context available as sc.") + } + + // code to be executed only after the interpreter is initialized + // and the lazy val `global` can be accessed without risk of deadlock. + private var pendingThunks: List[() => Unit] = Nil + protected def addThunk(body: => Unit) = synchronized { + pendingThunks :+= (() => body) + } + protected def runThunks(): Unit = synchronized { + if (pendingThunks.nonEmpty) + logDebug("Clearing " + pendingThunks.size + " thunks.") + + while (pendingThunks.nonEmpty) { + val thunk = pendingThunks.head + pendingThunks = pendingThunks.tail + thunk() + } + } +} diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 870e12de34..e1455ef8a1 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -1,5 +1,5 @@ /* NSC -- new Scala compiler - * Copyright 2005-2011 LAMP/EPFL + * Copyright 2005-2013 LAMP/EPFL * @author Martin Odersky */ @@ -9,304 +9,333 @@ import scala.tools.nsc._ import scala.tools.nsc.interpreter._ import Predef.{ println => _, _ } -import java.io.{ PrintWriter } -import java.lang.reflect +import util.stringFromWriter +import scala.reflect.internal.util._ import java.net.URL -import util.{ Set => _, _ } -import io.{ AbstractFile, PlainFile, VirtualDirectory } -import reporters.{ ConsoleReporter, Reporter } -import symtab.{ Flags, Names } -import scala.tools.nsc.interpreter.{ Results => IR } +import scala.sys.BooleanProp +import io.{AbstractFile, PlainFile, VirtualDirectory} + +import reporters._ +import symtab.Flags +import scala.reflect.internal.Names import scala.tools.util.PathResolver -import scala.tools.nsc.util.{ ScalaClassLoader, Exceptional } +import scala.tools.nsc.util.ScalaClassLoader import ScalaClassLoader.URLClassLoader -import Exceptional.unwrap +import scala.tools.nsc.util.Exceptional.unwrap import scala.collection.{ mutable, immutable } -import scala.PartialFunction.{ cond, condOpt } import scala.util.control.Exception.{ ultimately } -import scala.reflect.NameTransformer import SparkIMain._ +import java.util.concurrent.Future +import typechecker.Analyzer +import scala.language.implicitConversions +import scala.reflect.runtime.{ universe => ru } +import scala.reflect.{ ClassTag, classTag } +import scala.tools.reflect.StdRuntimeTags._ +import scala.util.control.ControlThrowable +import util.stackTraceString import org.apache.spark.HttpServer import org.apache.spark.util.Utils import org.apache.spark.SparkEnv +import org.apache.spark.Logging -/** An interpreter for Scala code. - * - * The main public entry points are compile(), interpret(), and bind(). - * The compile() method loads a complete Scala file. The interpret() method - * executes one line of Scala code at the request of the user. The bind() - * method binds an object to a variable that can then be used by later - * interpreted code. - * - * The overall approach is based on compiling the requested code and then - * using a Java classloader and Java reflection to run the code - * and access its results. - * - * In more detail, a single compiler instance is used - * to accumulate all successfully compiled or interpreted Scala code. To - * "interpret" a line of code, the compiler generates a fresh object that - * includes the line of code and which has public member(s) to export - * all variables defined by that code. To extract the result of an - * interpreted line to show the user, a second "result object" is created - * which imports the variables exported by the above object and then - * exports a single member named "$export". To accomodate user expressions - * that read from variables or methods defined in previous statements, "import" - * statements are used. - * - * This interpreter shares the strengths and weaknesses of using the - * full compiler-to-Java. The main strength is that interpreted code - * behaves exactly as does compiled code, including running at full speed. - * The main weakness is that redefining classes and methods is not handled - * properly, because rebinding at the Java level is technically difficult. - * - * @author Moez A. Abdel-Gawad - * @author Lex Spoon - */ -class SparkIMain(val settings: Settings, protected val out: PrintWriter) extends SparkImports { - imain => - - /** construct an interpreter that reports to Console */ - def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true)) - def this() = this(new Settings()) +// /** directory to save .class files to */ +// private class ReplVirtualDirectory(out: JPrintWriter) extends VirtualDirectory("((memory))", None) { +// private def pp(root: AbstractFile, indentLevel: Int) { +// val spaces = " " * indentLevel +// out.println(spaces + root.name) +// if (root.isDirectory) +// root.toList sortBy (_.name) foreach (x => pp(x, indentLevel + 1)) +// } +// // print the contents hierarchically +// def show() = pp(this, 0) +// } - /** whether to print out result lines */ - var printResults: Boolean = true - - /** whether to print errors */ - var totalSilence: Boolean = false - - private val RESULT_OBJECT_PREFIX = "RequestResult$" - - lazy val formatting: Formatting = new Formatting { - val prompt = Properties.shellPromptString - } - import formatting._ - - val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") - - /** Local directory to save .class files too */ - val outputDir = { - val tmp = System.getProperty("java.io.tmpdir") - val rootDir = System.getProperty("spark.repl.classdir", tmp) - Utils.createTempDir(rootDir) - } - if (SPARK_DEBUG_REPL) { - echo("Output directory: " + outputDir) - } - - /** Scala compiler virtual directory for outputDir */ - val virtualDirectory = new PlainFile(outputDir) - - /** Jetty server that will serve our classes to worker nodes */ - val classServer = new HttpServer(outputDir) - - // Start the classServer and store its URI in a spark system property - // (which will be passed to executors so that they can connect to it) - classServer.start() - System.setProperty("spark.repl.class.uri", classServer.uri) - if (SPARK_DEBUG_REPL) { - echo("Class server started, URI = " + classServer.uri) - } - - /* - // directory to save .class files to - val virtualDirectory = new VirtualDirectory("(memory)", None) { - private def pp(root: io.AbstractFile, indentLevel: Int) { - val spaces = " " * indentLevel - out.println(spaces + root.name) - if (root.isDirectory) - root.toList sortBy (_.name) foreach (x => pp(x, indentLevel + 1)) - } - // print the contents hierarchically - def show() = pp(this, 0) - } - */ - - /** reporter */ - lazy val reporter: ConsoleReporter = new SparkIMain.ReplReporter(this) - import reporter.{ printMessage, withoutTruncating } - - // not sure if we have some motivation to print directly to console - private def echo(msg: String) { Console println msg } - - // protected def defaultImports: List[String] = List("_root_.scala.sys.exit") - - /** We're going to go to some trouble to initialize the compiler asynchronously. - * It's critical that nothing call into it until it's been initialized or we will - * run into unrecoverable issues, but the perceived repl startup time goes - * through the roof if we wait for it. So we initialize it with a future and - * use a lazy val to ensure that any attempt to use the compiler object waits - * on the future. + /** An interpreter for Scala code. + * + * The main public entry points are compile(), interpret(), and bind(). + * The compile() method loads a complete Scala file. The interpret() method + * executes one line of Scala code at the request of the user. The bind() + * method binds an object to a variable that can then be used by later + * interpreted code. + * + * The overall approach is based on compiling the requested code and then + * using a Java classloader and Java reflection to run the code + * and access its results. + * + * In more detail, a single compiler instance is used + * to accumulate all successfully compiled or interpreted Scala code. To + * "interpret" a line of code, the compiler generates a fresh object that + * includes the line of code and which has public member(s) to export + * all variables defined by that code. To extract the result of an + * interpreted line to show the user, a second "result object" is created + * which imports the variables exported by the above object and then + * exports members called "$eval" and "$print". To accomodate user expressions + * that read from variables or methods defined in previous statements, "import" + * statements are used. + * + * This interpreter shares the strengths and weaknesses of using the + * full compiler-to-Java. The main strength is that interpreted code + * behaves exactly as does compiled code, including running at full speed. + * The main weakness is that redefining classes and methods is not handled + * properly, because rebinding at the Java level is technically difficult. + * + * @author Moez A. Abdel-Gawad + * @author Lex Spoon */ - private val _compiler: Global = newCompiler(settings, reporter) - private var _initializeComplete = false - def isInitializeComplete = _initializeComplete + class SparkIMain(initialSettings: Settings, val out: JPrintWriter) extends SparkImports with Logging { + imain => - private def _initialize(): Boolean = { - val source = """ - |class $repl_$init { - | List(1) map (_ + 1) - |} - |""".stripMargin - - val result = try { - new _compiler.Run() compileSources List(new BatchSourceFile("", source)) - if (isReplDebug || settings.debug.value) { - // Can't use printMessage here, it deadlocks - Console.println("Repl compiler initialized.") + val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") + + /** Local directory to save .class files too */ + val outputDir = { + val tmp = System.getProperty("java.io.tmpdir") + val rootDir = System.getProperty("spark.repl.classdir", tmp) + Utils.createTempDir(rootDir) } - // addImports(defaultImports: _*) - true - } - catch { - case x: AbstractMethodError => - printMessage(""" - |Failed to initialize compiler: abstract method error. - |This is most often remedied by a full clean and recompile. - |""".stripMargin - ) - x.printStackTrace() - false - case x: MissingRequirementError => printMessage(""" - |Failed to initialize compiler: %s not found. - |** Note that as of 2.8 scala does not assume use of the java classpath. - |** For the old behavior pass -usejavacp to scala, or if using a Settings - |** object programatically, settings.usejavacp.value = true.""".stripMargin.format(x.req) + if (SPARK_DEBUG_REPL) { + echo("Output directory: " + outputDir) + } + + val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles + val classServer = new HttpServer(outputDir) /** Jetty server that will serve our classes to worker nodes */ + private var currentSettings: Settings = initialSettings + var printResults = true // whether to print result lines + var totalSilence = false // whether to print anything + private var _initializeComplete = false // compiler is initialized + private var _isInitialized: Future[Boolean] = null // set up initialization future + private var bindExceptions = true // whether to bind the lastException variable + private var _executionWrapper = "" // code to be wrapped around all lines + + + // Start the classServer and store its URI in a spark system property + // (which will be passed to executors so that they can connect to it) + classServer.start() + System.setProperty("spark.repl.class.uri", classServer.uri) + if (SPARK_DEBUG_REPL) { + echo("Class server started, URI = " + classServer.uri) + } + + /** We're going to go to some trouble to initialize the compiler asynchronously. + * It's critical that nothing call into it until it's been initialized or we will + * run into unrecoverable issues, but the perceived repl startup time goes + * through the roof if we wait for it. So we initialize it with a future and + * use a lazy val to ensure that any attempt to use the compiler object waits + * on the future. + */ + private var _classLoader: AbstractFileClassLoader = null // active classloader + private val _compiler: Global = newCompiler(settings, reporter) // our private compiler + + private val nextReqId = { + var counter = 0 + () => { counter += 1 ; counter } + } + + def compilerClasspath: Seq[URL] = ( + if (isInitializeComplete) global.classPath.asURLs + else new PathResolver(settings).result.asURLs // the compiler's classpath ) - false + def settings = currentSettings + def mostRecentLine = prevRequestList match { + case Nil => "" + case req :: _ => req.originalLine } - - try result - finally _initializeComplete = result - } - - // set up initialization future - private var _isInitialized: () => Boolean = null - def initialize() = synchronized { - if (_isInitialized == null) - _isInitialized = scala.concurrent.ops future _initialize() - } + // Run the code body with the given boolean settings flipped to true. + def withoutWarnings[T](body: => T): T = beQuietDuring { + val saved = settings.nowarn.value + if (!saved) + settings.nowarn.value = true - /** the public, go through the future compiler */ - lazy val global: Global = { - initialize() - - // blocks until it is ; false means catastrophic failure - if (_isInitialized()) _compiler - else null - } - @deprecated("Use `global` for access to the compiler instance.", "2.9.0") - lazy val compiler: global.type = global - - import global._ - - object naming extends { - val global: imain.global.type = imain.global - } with Naming { - // make sure we don't overwrite their unwisely named res3 etc. - override def freshUserVarName(): String = { - val name = super.freshUserVarName() - if (definedNameMap contains name) freshUserVarName() - else name + try body + finally if (!saved) settings.nowarn.value = false } - } - import naming._ - // object dossiers extends { - // val intp: imain.type = imain - // } with Dossiers { } - // import dossiers._ - - lazy val memberHandlers = new { - val intp: imain.type = imain - } with SparkMemberHandlers - import memberHandlers._ - - def atPickler[T](op: => T): T = atPhase(currentRun.picklerPhase)(op) - def afterTyper[T](op: => T): T = atPhase(currentRun.typerPhase.next)(op) + /** construct an interpreter that reports to Console */ + def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true)) + def this() = this(new Settings()) - /** Temporarily be quiet */ - def beQuietDuring[T](operation: => T): T = { - val wasPrinting = printResults - ultimately(printResults = wasPrinting) { - if (isReplDebug) echo(">> beQuietDuring") - else printResults = false - - operation + lazy val repllog: Logger = new Logger { + val out: JPrintWriter = imain.out + val isInfo: Boolean = BooleanProp keyExists "scala.repl.info" + val isDebug: Boolean = BooleanProp keyExists "scala.repl.debug" + val isTrace: Boolean = BooleanProp keyExists "scala.repl.trace" } - } - def beSilentDuring[T](operation: => T): T = { - val saved = totalSilence - totalSilence = true - try operation - finally totalSilence = saved - } - - def quietRun[T](code: String) = beQuietDuring(interpret(code)) - - /** whether to bind the lastException variable */ - private var bindLastException = true - - /** A string representing code to be wrapped around all lines. */ - private var _executionWrapper: String = "" - def executionWrapper = _executionWrapper - def setExecutionWrapper(code: String) = _executionWrapper = code - def clearExecutionWrapper() = _executionWrapper = "" - - /** Temporarily stop binding lastException */ - def withoutBindingLastException[T](operation: => T): T = { - val wasBinding = bindLastException - ultimately(bindLastException = wasBinding) { - bindLastException = false - operation + lazy val formatting: Formatting = new Formatting { + val prompt = Properties.shellPromptString } - } - - protected def createLineManager(): Line.Manager = new Line.Manager - lazy val lineManager = createLineManager() + lazy val reporter: ConsoleReporter = new SparkIMain.ReplReporter(this) - /** interpreter settings */ - lazy val isettings = new SparkISettings(this) + import formatting._ + import reporter.{ printMessage, withoutTruncating } - /** Instantiate a compiler. Subclasses can override this to - * change the compiler class used by this interpreter. */ - protected def newCompiler(settings: Settings, reporter: Reporter) = { - settings.outputDirs setSingleOutput virtualDirectory - settings.exposeEmptyPackage.value = true - new Global(settings, reporter) - } - - /** the compiler's classpath, as URL's */ - lazy val compilerClasspath: List[URL] = new PathResolver(settings) asURLs + // This exists mostly because using the reporter too early leads to deadlock. + private def echo(msg: String) { Console println msg } + private def _initSources = List(new BatchSourceFile("", "class $repl_$init { }")) + private def _initialize() = { + try { + // todo. if this crashes, REPL will hang + new _compiler.Run() compileSources _initSources + _initializeComplete = true + true + } + catch AbstractOrMissingHandler() + } + private def tquoted(s: String) = "\"\"\"" + s + "\"\"\"" - /* A single class loader is used for all commands interpreted by this Interpreter. + // argument is a thunk to execute after init is done + def initialize(postInitSignal: => Unit) { + synchronized { + if (_isInitialized == null) { + _isInitialized = io.spawn { + try _initialize() + finally postInitSignal + } + } + } + } + def initializeSynchronous(): Unit = { + if (!isInitializeComplete) { + _initialize() + assert(global != null, global) + } + } + def isInitializeComplete = _initializeComplete + + /** the public, go through the future compiler */ + lazy val global: Global = { + if (isInitializeComplete) _compiler + else { + // If init hasn't been called yet you're on your own. + if (_isInitialized == null) { + logWarning("Warning: compiler accessed before init set up. Assuming no postInit code.") + initialize(()) + } + // // blocks until it is ; false means catastrophic failure + if (_isInitialized.get()) _compiler + else null + } + } + @deprecated("Use `global` for access to the compiler instance.", "2.9.0") + lazy val compiler: global.type = global + + import global._ + import definitions.{ScalaPackage, JavaLangPackage, termMember, typeMember} + import rootMirror.{RootClass, getClassIfDefined, getModuleIfDefined, getRequiredModule, getRequiredClass} + + implicit class ReplTypeOps(tp: Type) { + def orElse(other: => Type): Type = if (tp ne NoType) tp else other + def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp) + } + + // TODO: If we try to make naming a lazy val, we run into big time + // scalac unhappiness with what look like cycles. It has not been easy to + // reduce, but name resolution clearly takes different paths. + object naming extends { + val global: imain.global.type = imain.global + } with Naming { + // make sure we don't overwrite their unwisely named res3 etc. + def freshUserTermName(): TermName = { + val name = newTermName(freshUserVarName()) + if (definedNameMap contains name) freshUserTermName() + else name + } + def isUserTermName(name: Name) = isUserVarName("" + name) + def isInternalTermName(name: Name) = isInternalVarName("" + name) + } + import naming._ + + object deconstruct extends { + val global: imain.global.type = imain.global + } with StructuredTypeStrings + + lazy val memberHandlers = new { + val intp: imain.type = imain + } with SparkMemberHandlers + import memberHandlers._ + + /** Temporarily be quiet */ + def beQuietDuring[T](body: => T): T = { + val saved = printResults + printResults = false + try body + finally printResults = saved + } + def beSilentDuring[T](operation: => T): T = { + val saved = totalSilence + totalSilence = true + try operation + finally totalSilence = saved + } + + def quietRun[T](code: String) = beQuietDuring(interpret(code)) + + + private def logAndDiscard[T](label: String, alt: => T): PartialFunction[Throwable, T] = { + case t: ControlThrowable => throw t + case t: Throwable => + logDebug(label + ": " + unwrap(t)) + logDebug(stackTraceString(unwrap(t))) + alt + } + /** takes AnyRef because it may be binding a Throwable or an Exceptional */ + + private def withLastExceptionLock[T](body: => T, alt: => T): T = { + assert(bindExceptions, "withLastExceptionLock called incorrectly.") + bindExceptions = false + + try beQuietDuring(body) + catch logAndDiscard("withLastExceptionLock", alt) + finally bindExceptions = true + } + + def executionWrapper = _executionWrapper + def setExecutionWrapper(code: String) = _executionWrapper = code + def clearExecutionWrapper() = _executionWrapper = "" + + /** interpreter settings */ + lazy val isettings = new SparkISettings(this) + + /** Instantiate a compiler. Overridable. */ + protected def newCompiler(settings: Settings, reporter: Reporter): ReplGlobal = { + settings.outputDirs setSingleOutput virtualDirectory + settings.exposeEmptyPackage.value = true + new Global(settings, reporter) with ReplGlobal { + override def toString: String = "" + } + } + + /** Parent classloader. Overridable. */ + protected def parentClassLoader: ClassLoader = + SparkHelper.explicitParentLoader(settings).getOrElse( this.getClass.getClassLoader() ) + + /* A single class loader is used for all commands interpreted by this Interpreter. It would also be possible to create a new class loader for each command to interpret. The advantages of the current approach are: - - Expressions are only evaluated one time. This is especially - significant for I/O, e.g. "val x = Console.readLine" + - Expressions are only evaluated one time. This is especially + significant for I/O, e.g. "val x = Console.readLine" - The main disadvantage is: + The main disadvantage is: - - Objects, classes, and methods cannot be rebound. Instead, definitions - shadow the old ones, and old code objects refer to the old - definitions. - */ - private var _classLoader: AbstractFileClassLoader = null - def resetClassLoader() = _classLoader = makeClassLoader() - def classLoader: AbstractFileClassLoader = { - if (_classLoader == null) - resetClassLoader() - - _classLoader - } - private def makeClassLoader(): AbstractFileClassLoader = { - val parent = - if (parentClassLoader == null) ScalaClassLoader fromURLs compilerClasspath - else new URLClassLoader(compilerClasspath, parentClassLoader) - - new AbstractFileClassLoader(virtualDirectory, parent) { + - Objects, classes, and methods cannot be rebound. Instead, definitions + shadow the old ones, and old code objects refer to the old + definitions. + */ + def resetClassLoader() = { + logDebug("Setting new classloader: was " + _classLoader) + _classLoader = null + ensureClassLoader() + } + final def ensureClassLoader() { + if (_classLoader == null) + _classLoader = makeClassLoader() + } + def classLoader: AbstractFileClassLoader = { + ensureClassLoader() + _classLoader + } + private class TranslatingClassLoader(parent: ClassLoader) extends AbstractFileClassLoader(virtualDirectory, parent) { /** Overridden here to try translating a simple name to the generated * class name if the original attempt fails. This method is used by * getResourceAsStream as well as findClass. @@ -314,223 +343,300 @@ class SparkIMain(val settings: Settings, protected val out: PrintWriter) extends override protected def findAbstractFile(name: String): AbstractFile = { super.findAbstractFile(name) match { // deadlocks on startup if we try to translate names too early - case null if isInitializeComplete => generatedName(name) map (x => super.findAbstractFile(x)) orNull - case file => file + case null if isInitializeComplete => + generatedName(name) map (x => super.findAbstractFile(x)) orNull + case file => + file } } } - } - private def loadByName(s: String): JClass = - (classLoader tryToInitializeClass s) getOrElse sys.error("Failed to load expected class: '" + s + "'") - - protected def parentClassLoader: ClassLoader = - SparkHelper.explicitParentLoader(settings).getOrElse( this.getClass.getClassLoader() ) + private def makeClassLoader(): AbstractFileClassLoader = + new TranslatingClassLoader(parentClassLoader match { + case null => ScalaClassLoader fromURLs compilerClasspath + case p => new URLClassLoader(compilerClasspath, p) + }) - def getInterpreterClassLoader() = classLoader + def getInterpreterClassLoader() = classLoader - // Set the current Java "context" class loader to this interpreter's class loader - def setContextClassLoader() = classLoader.setAsContext() + // Set the current Java "context" class loader to this interpreter's class loader + def setContextClassLoader() = classLoader.setAsContext() - /** Given a simple repl-defined name, returns the real name of - * the class representing it, e.g. for "Bippy" it may return - * - * $line19.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Bippy - */ - def generatedName(simpleName: String): Option[String] = { - if (simpleName endsWith "$") optFlatName(simpleName.init) map (_ + "$") - else optFlatName(simpleName) - } - def flatName(id: String) = optFlatName(id) getOrElse id - def optFlatName(id: String) = requestForIdent(id) map (_ fullFlatName id) + /** Given a simple repl-defined name, returns the real name of + * the class representing it, e.g. for "Bippy" it may return + * {{{ + * $line19.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Bippy + * }}} + */ + def generatedName(simpleName: String): Option[String] = { + if (simpleName endsWith nme.MODULE_SUFFIX_STRING) optFlatName(simpleName.init) map (_ + nme.MODULE_SUFFIX_STRING) + else optFlatName(simpleName) + } + def flatName(id: String) = optFlatName(id) getOrElse id + def optFlatName(id: String) = requestForIdent(id) map (_ fullFlatName id) - def allDefinedNames = definedNameMap.keys.toList sortBy (_.toString) - def pathToType(id: String): String = pathToName(newTypeName(id)) - def pathToTerm(id: String): String = pathToName(newTermName(id)) - def pathToName(name: Name): String = { - if (definedNameMap contains name) - definedNameMap(name) fullPath name - else name.toString - } + def allDefinedNames = definedNameMap.keys.toList.sorted + def pathToType(id: String): String = pathToName(newTypeName(id)) + def pathToTerm(id: String): String = pathToName(newTermName(id)) + def pathToName(name: Name): String = { + if (definedNameMap contains name) + definedNameMap(name) fullPath name + else name.toString + } - /** Most recent tree handled which wasn't wholly synthetic. */ - private def mostRecentlyHandledTree: Option[Tree] = { - prevRequests.reverse foreach { req => - req.handlers.reverse foreach { - case x: MemberDefHandler if x.definesValue && !isInternalVarName(x.name.toString) => return Some(x.member) - case _ => () + /** Most recent tree handled which wasn't wholly synthetic. */ + private def mostRecentlyHandledTree: Option[Tree] = { + prevRequests.reverse foreach { req => + req.handlers.reverse foreach { + case x: MemberDefHandler if x.definesValue && !isInternalTermName(x.name) => return Some(x.member) + case _ => () + } + } + None + } + + /** Stubs for work in progress. */ + def handleTypeRedefinition(name: TypeName, old: Request, req: Request) = { + for (t1 <- old.simpleNameOfType(name) ; t2 <- req.simpleNameOfType(name)) { + logDebug("Redefining type '%s'\n %s -> %s".format(name, t1, t2)) } } - None - } - - /** Stubs for work in progress. */ - def handleTypeRedefinition(name: TypeName, old: Request, req: Request) = { - for (t1 <- old.simpleNameOfType(name) ; t2 <- req.simpleNameOfType(name)) { - DBG("Redefining type '%s'\n %s -> %s".format(name, t1, t2)) - } - } - def handleTermRedefinition(name: TermName, old: Request, req: Request) = { - for (t1 <- old.compilerTypeOf get name ; t2 <- req.compilerTypeOf get name) { - // Printing the types here has a tendency to cause assertion errors, like - // assertion failed: fatal: has owner value x, but a class owner is required - // so DBG is by-name now to keep it in the family. (It also traps the assertion error, - // but we don't want to unnecessarily risk hosing the compiler's internal state.) - DBG("Redefining term '%s'\n %s -> %s".format(name, t1, t2)) + def handleTermRedefinition(name: TermName, old: Request, req: Request) = { + for (t1 <- old.compilerTypeOf get name ; t2 <- req.compilerTypeOf get name) { + // Printing the types here has a tendency to cause assertion errors, like + // assertion failed: fatal: has owner value x, but a class owner is required + // so DBG is by-name now to keep it in the family. (It also traps the assertion error, + // but we don't want to unnecessarily risk hosing the compiler's internal state.) + logDebug("Redefining term '%s'\n %s -> %s".format(name, t1, t2)) + } } - } - def recordRequest(req: Request) { - if (req == null || referencedNameMap == null) - return - prevRequests += req - req.referencedNames foreach (x => referencedNameMap(x) = req) - - // warning about serially defining companions. It'd be easy - // enough to just redefine them together but that may not always - // be what people want so I'm waiting until I can do it better. - if (!settings.nowarnings.value) { + def recordRequest(req: Request) { + if (req == null || referencedNameMap == null) + return + + prevRequests += req + req.referencedNames foreach (x => referencedNameMap(x) = req) + + // warning about serially defining companions. It'd be easy + // enough to just redefine them together but that may not always + // be what people want so I'm waiting until I can do it better. for { name <- req.definedNames filterNot (x => req.definedNames contains x.companionName) oldReq <- definedNameMap get name.companionName newSym <- req.definedSymbols get name oldSym <- oldReq.definedSymbols get name.companionName + if Seq(oldSym, newSym).permutations exists { case Seq(s1, s2) => s1.isClass && s2.isModule } } { - printMessage("warning: previously defined %s is not a companion to %s.".format(oldSym, newSym)) - printMessage("Companions must be defined together; you may wish to use :paste mode for this.") + afterTyper(replwarn(s"warning: previously defined $oldSym is not a companion to $newSym.")) + replwarn("Companions must be defined together; you may wish to use :paste mode for this.") + } + + // Updating the defined name map + req.definedNames foreach { name => + if (definedNameMap contains name) { + if (name.isTypeName) handleTypeRedefinition(name.toTypeName, definedNameMap(name), req) + else handleTermRedefinition(name.toTermName, definedNameMap(name), req) + } + definedNameMap(name) = req } } - - // Updating the defined name map - req.definedNames foreach { name => - if (definedNameMap contains name) { - if (name.isTypeName) handleTypeRedefinition(name.toTypeName, definedNameMap(name), req) - else handleTermRedefinition(name.toTermName, definedNameMap(name), req) - } - definedNameMap(name) = req - } - } - /** Parse a line into a sequence of trees. Returns None if the input is incomplete. */ - def parse(line: String): Option[List[Tree]] = { - var justNeedsMore = false - reporter.withIncompleteHandler((pos,msg) => {justNeedsMore = true}) { - // simple parse: just parse it, nothing else - def simpleParse(code: String): List[Tree] = { - reporter.reset() - val unit = new CompilationUnit(new BatchSourceFile("", code)) - val scanner = new syntaxAnalyzer.UnitParser(unit) - - scanner.templateStatSeq(false)._2 - } - val trees = simpleParse(line) - - if (reporter.hasErrors) Some(Nil) // the result did not parse, so stop - else if (justNeedsMore) None - else Some(trees) + def replwarn(msg: => String) { + if (!settings.nowarnings.value) + printMessage(msg) } - } - - def isParseable(line: String): Boolean = { - beSilentDuring { - parse(line) match { - case Some(xs) => xs.nonEmpty // parses as-is - case None => true // incomplete + + def isParseable(line: String): Boolean = { + beSilentDuring { + try parse(line) match { + case Some(xs) => xs.nonEmpty // parses as-is + case None => true // incomplete + } + catch { case x: Exception => // crashed the compiler + replwarn("Exception in isParseable(\"" + line + "\"): " + x) + false + } } } + + def compileSourcesKeepingRun(sources: SourceFile*) = { + val run = new Run() + reporter.reset() + run compileSources sources.toList + (!reporter.hasErrors, run) + } + + /** Compile an nsc SourceFile. Returns true if there are + * no compilation errors, or false otherwise. + */ + def compileSources(sources: SourceFile*): Boolean = + compileSourcesKeepingRun(sources: _*)._1 + + /** Compile a string. Returns true if there are no + * compilation errors, or false otherwise. + */ + def compileString(code: String): Boolean = + compileSources(new BatchSourceFile("