From 886b39de557b4d5f54f5ca11559fca9799534280 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 10 Aug 2012 01:10:02 -0700 Subject: [PATCH 001/291] Add Python API. --- .../scala/spark/api/python/PythonRDD.scala | 147 +++++ pyspark/pyspark/__init__.py | 0 pyspark/pyspark/context.py | 69 +++ pyspark/pyspark/examples/__init__.py | 0 pyspark/pyspark/examples/kmeans.py | 56 ++ pyspark/pyspark/examples/pi.py | 20 + pyspark/pyspark/examples/tc.py | 49 ++ pyspark/pyspark/java_gateway.py | 20 + pyspark/pyspark/join.py | 104 ++++ pyspark/pyspark/rdd.py | 517 ++++++++++++++++++ pyspark/pyspark/serializers.py | 229 ++++++++ pyspark/pyspark/worker.py | 97 ++++ pyspark/requirements.txt | 9 + python/tc.py | 22 + 14 files changed, 1339 insertions(+) create mode 100644 core/src/main/scala/spark/api/python/PythonRDD.scala create mode 100644 pyspark/pyspark/__init__.py create mode 100644 pyspark/pyspark/context.py create mode 100644 pyspark/pyspark/examples/__init__.py create mode 100644 pyspark/pyspark/examples/kmeans.py create mode 100644 pyspark/pyspark/examples/pi.py create mode 100644 pyspark/pyspark/examples/tc.py create mode 100644 pyspark/pyspark/java_gateway.py create mode 100644 pyspark/pyspark/join.py create mode 100644 pyspark/pyspark/rdd.py create mode 100644 pyspark/pyspark/serializers.py create mode 100644 pyspark/pyspark/worker.py create mode 100644 pyspark/requirements.txt create mode 100644 python/tc.py diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala new file mode 100644 index 0000000000..660ad48afe --- /dev/null +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -0,0 +1,147 @@ +package spark.api.python + +import java.io.PrintWriter + +import scala.collection.Map +import scala.collection.JavaConversions._ +import scala.io.Source +import spark._ +import api.java.{JavaPairRDD, JavaRDD} +import scala.Some + +trait PythonRDDBase { + def compute[T](split: Split, envVars: Map[String, String], + command: Seq[String], parent: RDD[T], pythonExec: String): Iterator[String]= { + val currentEnvVars = new ProcessBuilder().environment() + val SPARK_HOME = currentEnvVars.get("SPARK_HOME") + + val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py")) + // Add the environmental variables to the process. 
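+    // The launched worker (pyspark/pyspark/worker.py) reads a command and any
+    // base64-pickled functions from stdin, then one serialized element per
+    // line, and writes its results back one line at a time on stdout; the
+    // threads below feed its stdin and drain its stderr.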
+ envVars.foreach { + case (variable, value) => currentEnvVars.put(variable, value) + } + + val proc = pb.start() + val env = SparkEnv.get + + // Start a thread to print the process's stderr to ours + new Thread("stderr reader for " + command) { + override def run() { + for (line <- Source.fromInputStream(proc.getErrorStream).getLines) { + System.err.println(line) + } + } + }.start() + + // Start a thread to feed the process input from our parent's iterator + new Thread("stdin writer for " + command) { + override def run() { + SparkEnv.set(env) + val out = new PrintWriter(proc.getOutputStream) + for (elem <- command) { + out.println(elem) + } + for (elem <- parent.iterator(split)) { + out.println(PythonRDD.pythonDump(elem)) + } + out.close() + } + }.start() + + // Return an iterator that read lines from the process's stdout + val lines: Iterator[String] = Source.fromInputStream(proc.getInputStream).getLines + wrapIterator(lines, proc) + } + + def wrapIterator[T](iter: Iterator[T], proc: Process): Iterator[T] = { + return new Iterator[T] { + def next() = iter.next() + + def hasNext = { + if (iter.hasNext) { + true + } else { + val exitStatus = proc.waitFor() + if (exitStatus != 0) { + throw new Exception("Subprocess exited with status " + exitStatus) + } + false + } + } + } + } +} + +class PythonRDD[T: ClassManifest]( + parent: RDD[T], command: Seq[String], envVars: Map[String, String], + preservePartitoning: Boolean, pythonExec: String) + extends RDD[String](parent.context) with PythonRDDBase { + + def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) = + this(parent, command, Map(), preservePartitoning, pythonExec) + + // Similar to Runtime.exec(), if we are given a single string, split it into words + // using a standard StringTokenizer (i.e. by spaces) + def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String) = + this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec) + + override def splits = parent.splits + + override val dependencies = List(new OneToOneDependency(parent)) + + override val partitioner = if (preservePartitoning) parent.partitioner else None + + override def compute(split: Split): Iterator[String] = + compute(split, envVars, command, parent, pythonExec) + + val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this) +} + +class PythonPairRDD[T: ClassManifest] ( + parent: RDD[T], command: Seq[String], envVars: Map[String, String], + preservePartitoning: Boolean, pythonExec: String) + extends RDD[(String, String)](parent.context) with PythonRDDBase { + + def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) = + this(parent, command, Map(), preservePartitoning, pythonExec) + + // Similar to Runtime.exec(), if we are given a single string, split it into words + // using a standard StringTokenizer (i.e. 
by spaces) + def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String) = + this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec) + + override def splits = parent.splits + + override val dependencies = List(new OneToOneDependency(parent)) + + override val partitioner = if (preservePartitoning) parent.partitioner else None + + override def compute(split: Split): Iterator[(String, String)] = { + compute(split, envVars, command, parent, pythonExec).grouped(2).map { + case Seq(a, b) => (a, b) + case x => throw new Exception("Unexpected value: " + x) + } + } + + val asJavaPairRDD : JavaPairRDD[String, String] = JavaPairRDD.fromRDD(this) +} + +object PythonRDD { + def pythonDump[T](x: T): String = { + if (x.isInstanceOf[scala.Option[_]]) { + val t = x.asInstanceOf[scala.Option[_]] + t match { + case None => "*" + case Some(z) => pythonDump(z) + } + } else if (x.isInstanceOf[scala.Tuple2[_, _]]) { + val t = x.asInstanceOf[scala.Tuple2[_, _]] + "(" + pythonDump(t._1) + "," + pythonDump(t._2) + ")" + } else if (x.isInstanceOf[java.util.List[_]]) { + val objs = asScalaBuffer(x.asInstanceOf[java.util.List[_]]).map(pythonDump) + "[" + objs.mkString("|") + "]" + } else { + x.toString + } + } +} diff --git a/pyspark/pyspark/__init__.py b/pyspark/pyspark/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py new file mode 100644 index 0000000000..587ab12b5f --- /dev/null +++ b/pyspark/pyspark/context.py @@ -0,0 +1,69 @@ +import os +import atexit +from tempfile import NamedTemporaryFile + +from pyspark.java_gateway import launch_gateway +from pyspark.serializers import JSONSerializer, NopSerializer +from pyspark.rdd import RDD, PairRDD + + +class SparkContext(object): + + gateway = launch_gateway() + jvm = gateway.jvm + python_dump = jvm.spark.api.python.PythonRDD.pythonDump + + def __init__(self, master, name, defaultSerializer=JSONSerializer, + defaultParallelism=None, pythonExec='python'): + self.master = master + self.name = name + self._jsc = self.jvm.JavaSparkContext(master, name) + self.defaultSerializer = defaultSerializer + self.defaultParallelism = \ + defaultParallelism or self._jsc.sc().defaultParallelism() + self.pythonExec = pythonExec + + def __del__(self): + if self._jsc: + self._jsc.stop() + + def stop(self): + self._jsc.stop() + self._jsc = None + + def parallelize(self, c, numSlices=None, serializer=None): + serializer = serializer or self.defaultSerializer + numSlices = numSlices or self.defaultParallelism + # Calling the Java parallelize() method with an ArrayList is too slow, + # because it sends O(n) Py4J commands. As an alternative, serialized + # objects are written to a file and loaded through textFile(). 
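+        # For example, parallelize([1, 2, 3]) with the default JSONSerializer
+        # writes three lines to a temporary file, one base64-encoded JSON
+        # value per line; the file is deleted at interpreter exit.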
+ tempFile = NamedTemporaryFile(delete=False) + tempFile.writelines(serializer.dumps(x) + '\n' for x in c) + tempFile.close() + atexit.register(lambda: os.unlink(tempFile.name)) + return self.textFile(tempFile.name, numSlices, serializer) + + def parallelizePairs(self, c, numSlices=None, keySerializer=None, + valSerializer=None): + """ + >>> sc = SparkContext("local", "test") + >>> rdd = sc.parallelizePairs([(1, 2), (3, 4)]) + >>> rdd.collect() + [(1, 2), (3, 4)] + """ + keySerializer = keySerializer or self.defaultSerializer + valSerializer = valSerializer or self.defaultSerializer + numSlices = numSlices or self.defaultParallelism + tempFile = NamedTemporaryFile(delete=False) + for (k, v) in c: + tempFile.write(keySerializer.dumps(k).rstrip('\r\n') + '\n') + tempFile.write(valSerializer.dumps(v).rstrip('\r\n') + '\n') + tempFile.close() + atexit.register(lambda: os.unlink(tempFile.name)) + jrdd = self.textFile(tempFile.name, numSlices)._pipePairs([], "echo") + return PairRDD(jrdd, self, keySerializer, valSerializer) + + def textFile(self, name, numSlices=None, serializer=NopSerializer): + numSlices = numSlices or self.defaultParallelism + jrdd = self._jsc.textFile(name, numSlices) + return RDD(jrdd, self, serializer) diff --git a/pyspark/pyspark/examples/__init__.py b/pyspark/pyspark/examples/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pyspark/pyspark/examples/kmeans.py b/pyspark/pyspark/examples/kmeans.py new file mode 100644 index 0000000000..0761d6e395 --- /dev/null +++ b/pyspark/pyspark/examples/kmeans.py @@ -0,0 +1,56 @@ +import sys + +from pyspark.context import SparkContext + + +def parseVector(line): + return [float(x) for x in line.split(' ')] + + +def addVec(x, y): + return [a + b for (a, b) in zip(x, y)] + + +def squaredDist(x, y): + return sum((a - b) ** 2 for (a, b) in zip(x, y)) + + +def closestPoint(p, centers): + bestIndex = 0 + closest = float("+inf") + for i in range(len(centers)): + tempDist = squaredDist(p, centers[i]) + if tempDist < closest: + closest = tempDist + bestIndex = i + return bestIndex + + +if __name__ == "__main__": + if len(sys.argv) < 5: + print >> sys.stderr, \ + "Usage: PythonKMeans " + exit(-1) + sc = SparkContext(sys.argv[1], "PythonKMeans") + lines = sc.textFile(sys.argv[2]) + data = lines.map(parseVector).cache() + K = int(sys.argv[3]) + convergeDist = float(sys.argv[4]) + + kPoints = data.takeSample(False, K, 34) + tempDist = 1.0 + + while tempDist > convergeDist: + closest = data.mapPairs( + lambda p : (closestPoint(p, kPoints), (p, 1))) + pointStats = closest.reduceByKey( + lambda (x1, y1), (x2, y2): (addVec(x1, x2), y1 + y2)) + newPoints = pointStats.mapPairs( + lambda (x, (y, z)): (x, [a / z for a in y])).collect() + + tempDist = sum(squaredDist(kPoints[x], y) for (x, y) in newPoints) + + for (x, y) in newPoints: + kPoints[x] = y + + print "Final centers: " + str(kPoints) diff --git a/pyspark/pyspark/examples/pi.py b/pyspark/pyspark/examples/pi.py new file mode 100644 index 0000000000..ad77694c41 --- /dev/null +++ b/pyspark/pyspark/examples/pi.py @@ -0,0 +1,20 @@ +import sys +from random import random +from operator import add +from pyspark.context import SparkContext + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonPi []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonKMeans") + slices = sys.argv[2] if len(sys.argv) > 2 else 2 + n = 100000 * slices + def f(_): + x = random() * 2 - 1 + y = random() * 2 - 1 + return 1 if x ** 2 + y ** 2 < 1 else 0 + count = 
sc.parallelize(xrange(1, n+1), slices).map(f).reduce(add) + print "Pi is roughly %f" % (4.0 * count / n) diff --git a/pyspark/pyspark/examples/tc.py b/pyspark/pyspark/examples/tc.py new file mode 100644 index 0000000000..2796fdc6ad --- /dev/null +++ b/pyspark/pyspark/examples/tc.py @@ -0,0 +1,49 @@ +import sys +from random import Random +from pyspark.context import SparkContext + +numEdges = 200 +numVertices = 100 +rand = Random(42) + + +def generateGraph(): + edges = set() + while len(edges) < numEdges: + src = rand.randrange(0, numEdges) + dst = rand.randrange(0, numEdges) + if src != dst: + edges.add((src, dst)) + return edges + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonTC []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonKMeans") + slices = sys.argv[2] if len(sys.argv) > 2 else 2 + tc = sc.parallelizePairs(generateGraph(), slices).cache() + + # Linear transitive closure: each round grows paths by one edge, + # by joining the graph's edges with the already-discovered paths. + # e.g. join the path (y, z) from the TC with the edge (x, y) from + # the graph to obtain the path (x, z). + + # Because join() joins on keys, the edges are stored in reversed order. + edges = tc.mapPairs(lambda (x, y): (y, x)) + + oldCount = 0L + nextCount = tc.count() + while True: + oldCount = nextCount + # Perform the join, obtaining an RDD of (y, (z, x)) pairs, + # then project the result to obtain the new (x, z) paths. + new_edges = tc.join(edges).mapPairs(lambda (_, (a, b)): (b, a)) + tc = tc.union(new_edges).distinct().cache() + nextCount = tc.count() + if nextCount == oldCount: + break + + print "TC has %i edges" % tc.count() diff --git a/pyspark/pyspark/java_gateway.py b/pyspark/pyspark/java_gateway.py new file mode 100644 index 0000000000..2df80aee85 --- /dev/null +++ b/pyspark/pyspark/java_gateway.py @@ -0,0 +1,20 @@ +import glob +import os +from py4j.java_gateway import java_import, JavaGateway + + +SPARK_HOME = os.environ["SPARK_HOME"] + + +assembly_jar = glob.glob(os.path.join(SPARK_HOME, "core/target") + \ + "/spark-core-assembly-*-SNAPSHOT.jar")[0] + + +def launch_gateway(): + gateway = JavaGateway.launch_gateway(classpath=assembly_jar, + javaopts=["-Xmx256m"], die_on_exit=True) + java_import(gateway.jvm, "spark.api.java.*") + java_import(gateway.jvm, "spark.api.python.*") + java_import(gateway.jvm, "scala.Tuple2") + java_import(gateway.jvm, "spark.api.python.PythonRDD.pythonDump") + return gateway diff --git a/pyspark/pyspark/join.py b/pyspark/pyspark/join.py new file mode 100644 index 0000000000..c67520fce8 --- /dev/null +++ b/pyspark/pyspark/join.py @@ -0,0 +1,104 @@ +""" +Copyright (c) 2011, Douban Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + + * Neither the name of the Douban Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" +from pyspark.serializers import PairSerializer, OptionSerializer, \ + ArraySerializer + + +def _do_python_join(rdd, other, numSplits, dispatch, valSerializer): + vs = rdd.mapPairs(lambda (k, v): (k, (1, v))) + ws = other.mapPairs(lambda (k, v): (k, (2, v))) + return vs.union(ws).groupByKey(numSplits) \ + .flatMapValues(dispatch, valSerializer) + + +def python_join(rdd, other, numSplits): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + return [(v, w) for v in vbuf for w in wbuf] + valSerializer = PairSerializer(rdd.valSerializer, other.valSerializer) + return _do_python_join(rdd, other, numSplits, dispatch, valSerializer) + + +def python_right_outer_join(rdd, other, numSplits): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + if not vbuf: + vbuf.append(None) + return [(v, w) for v in vbuf for w in wbuf] + valSerializer = PairSerializer(OptionSerializer(rdd.valSerializer), + other.valSerializer) + return _do_python_join(rdd, other, numSplits, dispatch, valSerializer) + + +def python_left_outer_join(rdd, other, numSplits): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + if not wbuf: + wbuf.append(None) + return [(v, w) for v in vbuf for w in wbuf] + valSerializer = PairSerializer(rdd.valSerializer, + OptionSerializer(other.valSerializer)) + return _do_python_join(rdd, other, numSplits, dispatch, valSerializer) + + +def python_cogroup(rdd, other, numSplits): + resultValSerializer = PairSerializer( + ArraySerializer(rdd.valSerializer), + ArraySerializer(other.valSerializer)) + vs = rdd.mapPairs(lambda (k, v): (k, (1, v))) + ws = other.mapPairs(lambda (k, v): (k, (2, v))) + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + return (vbuf, wbuf) + return vs.union(ws).groupByKey(numSplits) \ + .mapValues(dispatch, resultValSerializer) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py new file mode 100644 index 0000000000..c892e86b93 --- /dev/null +++ b/pyspark/pyspark/rdd.py @@ -0,0 +1,517 @@ +from base64 import standard_b64encode as b64enc +from cloud.serialization import cloudpickle +from itertools import chain + +from pyspark.serializers import PairSerializer, NopSerializer, \ + OptionSerializer, ArraySerializer +from pyspark.join import python_join, python_left_outer_join, \ + python_right_outer_join, python_cogroup + + +class RDD(object): + + def __init__(self, jrdd, ctx, serializer=None): + self._jrdd = jrdd + self.is_cached = False + self.ctx = ctx + self.serializer = serializer or ctx.defaultSerializer + + def 
_builder(self, jrdd, ctx): + return RDD(jrdd, ctx, self.serializer) + + @property + def id(self): + return self._jrdd.id() + + @property + def splits(self): + return self._jrdd.splits() + + @classmethod + def _get_pipe_command(cls, command, functions): + if functions and not isinstance(functions, (list, tuple)): + functions = [functions] + worker_args = [command] + for f in functions: + worker_args.append(b64enc(cloudpickle.dumps(f))) + return " ".join(worker_args) + + def cache(self): + self.is_cached = True + self._jrdd.cache() + return self + + def map(self, f, serializer=None, preservesPartitioning=False): + return MappedRDD(self, f, serializer, preservesPartitioning) + + def mapPairs(self, f, keySerializer=None, valSerializer=None, + preservesPartitioning=False): + return PairMappedRDD(self, f, keySerializer, valSerializer, + preservesPartitioning) + + def flatMap(self, f, serializer=None): + """ + >>> rdd = sc.parallelize([2, 3, 4]) + >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect()) + [1, 1, 1, 2, 2, 3] + """ + serializer = serializer or self.ctx.defaultSerializer + dumps = serializer.dumps + loads = self.serializer.loads + def func(x): + pickled_elems = (dumps(y) for y in f(loads(x))) + return "\n".join(pickled_elems) or None + pipe_command = RDD._get_pipe_command("map", [func]) + class_manifest = self._jrdd.classManifest() + jrdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command, + False, self.ctx.pythonExec, + class_manifest).asJavaRDD() + return RDD(jrdd, self.ctx, serializer) + + def flatMapPairs(self, f, keySerializer=None, valSerializer=None, + preservesPartitioning=False): + """ + >>> rdd = sc.parallelize([2, 3, 4]) + >>> sorted(rdd.flatMapPairs(lambda x: [(x, x), (x, x)]).collect()) + [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] + """ + keySerializer = keySerializer or self.ctx.defaultSerializer + valSerializer = valSerializer or self.ctx.defaultSerializer + dumpk = keySerializer.dumps + dumpv = valSerializer.dumps + loads = self.serializer.loads + def func(x): + pairs = f(loads(x)) + pickled_pairs = ((dumpk(k), dumpv(v)) for (k, v) in pairs) + return "\n".join(chain.from_iterable(pickled_pairs)) or None + pipe_command = RDD._get_pipe_command("map", [func]) + class_manifest = self._jrdd.classManifest() + python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), pipe_command, + preservesPartitioning, self.ctx.pythonExec, class_manifest) + return PairRDD(python_rdd.asJavaPairRDD(), self.ctx, keySerializer, + valSerializer) + + def filter(self, f): + """ + >>> rdd = sc.parallelize([1, 2, 3, 4, 5]) + >>> rdd.filter(lambda x: x % 2 == 0).collect() + [2, 4] + """ + loads = self.serializer.loads + def filter_func(x): return x if f(loads(x)) else None + return self._builder(self._pipe(filter_func), self.ctx) + + def _pipe(self, functions, command="map"): + class_manifest = self._jrdd.classManifest() + pipe_command = RDD._get_pipe_command(command, functions) + python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command, + False, self.ctx.pythonExec, class_manifest) + return python_rdd.asJavaRDD() + + def _pipePairs(self, functions, command="mapPairs", + preservesPartitioning=False): + class_manifest = self._jrdd.classManifest() + pipe_command = RDD._get_pipe_command(command, functions) + python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), pipe_command, + preservesPartitioning, self.ctx.pythonExec, class_manifest) + return python_rdd.asJavaPairRDD() + + def distinct(self): + """ + >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect()) + [1, 2, 3] 
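+
+        When the serializer is not comparable, this falls back to a
+        reduceByKey()-based implementation instead of the Java distinct().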
+ """ + if self.serializer.is_comparable: + return self._builder(self._jrdd.distinct(), self.ctx) + return self.mapPairs(lambda x: (x, "")) \ + .reduceByKey(lambda x, _: x) \ + .map(lambda (x, _): x) + + def sample(self, withReplacement, fraction, seed): + jrdd = self._jrdd.sample(withReplacement, fraction, seed) + return self._builder(jrdd, self.ctx) + + def takeSample(self, withReplacement, num, seed): + vals = self._jrdd.takeSample(withReplacement, num, seed) + return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals] + + def union(self, other): + """ + >>> rdd = sc.parallelize([1, 1, 2, 3]) + >>> rdd.union(rdd).collect() + [1, 1, 2, 3, 1, 1, 2, 3] + """ + return self._builder(self._jrdd.union(other._jrdd), self.ctx) + + # TODO: sort + + # TODO: Overload __add___? + + # TODO: glom + + def cartesian(self, other): + """ + >>> rdd = sc.parallelize([1, 2]) + >>> sorted(rdd.cartesian(rdd).collect()) + [(1, 1), (1, 2), (2, 1), (2, 2)] + """ + return PairRDD(self._jrdd.cartesian(other._jrdd), self.ctx) + + # numsplits + def groupBy(self, f, numSplits=None): + """ + >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8]) + >>> sorted(rdd.groupBy(lambda x: x % 2).collect()) + [(0, [2, 8]), (1, [1, 1, 3, 5])] + """ + return self.mapPairs(lambda x: (f(x), x)).groupByKey(numSplits) + + # TODO: pipe + + # TODO: mapPartitions + + def foreach(self, f): + """ + >>> def f(x): print x + >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) + """ + self.map(f).collect() # Force evaluation + + def collect(self): + vals = self._jrdd.collect() + return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals] + + def reduce(self, f, serializer=None): + """ + >>> import operator + >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(operator.add) + 15 + """ + serializer = serializer or self.ctx.defaultSerializer + loads = self.serializer.loads + dumps = serializer.dumps + def reduceFunction(x, acc): + if acc is None: + return loads(x) + else: + return f(loads(x), acc) + vals = self._pipe([reduceFunction, dumps], command="reduce").collect() + return reduce(f, (serializer.loads(x) for x in vals)) + + # TODO: fold + + # TODO: aggregate + + def count(self): + """ + >>> sc.parallelize([2, 3, 4]).count() + 3L + """ + return self._jrdd.count() + + # TODO: count approx methods + + def take(self, num): + """ + >>> sc.parallelize([2, 3, 4]).take(2) + [2, 3] + """ + vals = self._jrdd.take(num) + return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals] + + def first(self): + """ + >>> sc.parallelize([2, 3, 4]).first() + 2 + """ + return self.serializer.loads(self.ctx.python_dump(self._jrdd.first())) + + # TODO: saveAsTextFile + + # TODO: saveAsObjectFile + + +class PairRDD(RDD): + + def __init__(self, jrdd, ctx, keySerializer=None, valSerializer=None): + RDD.__init__(self, jrdd, ctx) + self.keySerializer = keySerializer or ctx.defaultSerializer + self.valSerializer = valSerializer or ctx.defaultSerializer + self.serializer = \ + PairSerializer(self.keySerializer, self.valSerializer) + + def _builder(self, jrdd, ctx): + return PairRDD(jrdd, ctx, self.keySerializer, self.valSerializer) + + def reduceByKey(self, func, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(x.reduceByKey(lambda a, b: a + b).collect()) + [('a', 2), ('b', 1)] + """ + return self.combineByKey(lambda x: x, func, func, numSplits) + + # TODO: reduceByKeyLocally() + + # TODO: countByKey() + + # TODO: partitionBy + + def join(self, other, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), 
("b", 4)]) + >>> y = sc.parallelizePairs([("a", 2), ("a", 3)]) + >>> x.join(y).collect() + [('a', (1, 2)), ('a', (1, 3))] + + Check that we get a PairRDD-like object back: + >>> assert x.join(y).join + """ + assert self.keySerializer.name == other.keySerializer.name + if self.keySerializer.is_comparable: + return PairRDD(self._jrdd.join(other._jrdd), + self.ctx, self.keySerializer, + PairSerializer(self.valSerializer, other.valSerializer)) + else: + return python_join(self, other, numSplits) + + def leftOuterJoin(self, other, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) + >>> y = sc.parallelizePairs([("a", 2)]) + >>> sorted(x.leftOuterJoin(y).collect()) + [('a', (1, 2)), ('b', (4, None))] + """ + assert self.keySerializer.name == other.keySerializer.name + if self.keySerializer.is_comparable: + return PairRDD(self._jrdd.leftOuterJoin(other._jrdd), + self.ctx, self.keySerializer, + PairSerializer(self.valSerializer, + OptionSerializer(other.valSerializer))) + else: + return python_left_outer_join(self, other, numSplits) + + def rightOuterJoin(self, other, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) + >>> y = sc.parallelizePairs([("a", 2)]) + >>> sorted(y.rightOuterJoin(x).collect()) + [('a', (2, 1)), ('b', (None, 4))] + """ + assert self.keySerializer.name == other.keySerializer.name + if self.keySerializer.is_comparable: + return PairRDD(self._jrdd.rightOuterJoin(other._jrdd), + self.ctx, self.keySerializer, + PairSerializer(OptionSerializer(self.valSerializer), + other.valSerializer)) + else: + return python_right_outer_join(self, other, numSplits) + + def combineByKey(self, createCombiner, mergeValue, mergeCombiners, + numSplits=None, serializer=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)]) + >>> def f(x): return x + >>> def add(a, b): return a + str(b) + >>> sorted(x.combineByKey(str, add, add).collect()) + [('a', '11'), ('b', '1')] + """ + serializer = serializer or self.ctx.defaultSerializer + if numSplits is None: + numSplits = self.ctx.defaultParallelism + # Use hash() to create keys that are comparable in Java. + loadkv = self.serializer.loads + def pairify(kv): + # TODO: add method to deserialize only the key or value from + # a PairSerializer? 
+ key = loadkv(kv)[0] + return (str(hash(key)), kv) + partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) + jrdd = self._pipePairs(pairify).partitionBy(partitioner) + pairified = PairRDD(jrdd, self.ctx, NopSerializer, self.serializer) + + loads = PairSerializer(NopSerializer, self.serializer).loads + dumpk = self.keySerializer.dumps + dumpc = serializer.dumps + + functions = [createCombiner, mergeValue, mergeCombiners, loads, dumpk, + dumpc] + jpairs = pairified._pipePairs(functions, "combine_by_key", + preservesPartitioning=True) + return PairRDD(jpairs, self.ctx, self.keySerializer, serializer) + + def groupByKey(self, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(x.groupByKey().collect()) + [('a', [1, 1]), ('b', [1])] + """ + + def createCombiner(x): + return [x] + + def mergeValue(xs, x): + xs.append(x) + return xs + + def mergeCombiners(a, b): + return a + b + + return self.combineByKey(createCombiner, mergeValue, mergeCombiners, + numSplits) + + def collectAsMap(self): + """ + >>> m = sc.parallelizePairs([(1, 2), (3, 4)]).collectAsMap() + >>> m[1] + 2 + >>> m[3] + 4 + """ + m = self._jrdd.collectAsMap() + def loads(x): + (k, v) = x + return (self.keySerializer.loads(k), self.valSerializer.loads(v)) + return dict(loads(x) for x in m.items()) + + def flatMapValues(self, f, valSerializer=None): + flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) + return self.flatMapPairs(flat_map_fn, self.keySerializer, + valSerializer, True) + + def mapValues(self, f, valSerializer=None): + map_values_fn = lambda (k, v): (k, f(v)) + return self.mapPairs(map_values_fn, self.keySerializer, valSerializer, + True) + + # TODO: support varargs cogroup of several RDDs. + def groupWith(self, other): + return self.cogroup(other) + + def cogroup(self, other, numSplits=None): + """ + >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) + >>> y = sc.parallelizePairs([("a", 2)]) + >>> x.cogroup(y).collect() + [('a', ([1], [2])), ('b', ([4], []))] + """ + assert self.keySerializer.name == other.keySerializer.name + resultValSerializer = PairSerializer( + ArraySerializer(self.valSerializer), + ArraySerializer(other.valSerializer)) + if self.keySerializer.is_comparable: + return PairRDD(self._jrdd.cogroup(other._jrdd), + self.ctx, self.keySerializer, resultValSerializer) + else: + return python_cogroup(self, other, numSplits) + + # TODO: `lookup` is disabled because we can't make direct comparisons based + # on the key; we need to compare the hash of the key to the hash of the + # keys in the pairs. This could be an expensive operation, since those + # hashes aren't retained. 
+ + # TODO: file saving + + +class MappedRDDBase(object): + def __init__(self, prev, func, serializer, preservesPartitioning=False): + if isinstance(prev, MappedRDDBase) and not prev.is_cached: + prev_func = prev.func + self.func = lambda x: func(prev_func(x)) + self.preservesPartitioning = \ + prev.preservesPartitioning and preservesPartitioning + self._prev_jrdd = prev._prev_jrdd + self._prev_serializer = prev._prev_serializer + else: + self.func = func + self.preservesPartitioning = preservesPartitioning + self._prev_jrdd = prev._jrdd + self._prev_serializer = prev.serializer + self.serializer = serializer or prev.ctx.defaultSerializer + self.is_cached = False + self.ctx = prev.ctx + self.prev = prev + self._jrdd_val = None + + +class MappedRDD(MappedRDDBase, RDD): + """ + >>> rdd = sc.parallelize([1, 2, 3, 4]) + >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect() + [4, 8, 12, 16] + >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect() + [4, 8, 12, 16] + """ + + @property + def _jrdd(self): + if not self._jrdd_val: + udf = self.func + loads = self._prev_serializer.loads + dumps = self.serializer.dumps + func = lambda x: dumps(udf(loads(x))) + pipe_command = RDD._get_pipe_command("map", [func]) + class_manifest = self._prev_jrdd.classManifest() + python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), + pipe_command, self.preservesPartitioning, self.ctx.pythonExec, + class_manifest) + self._jrdd_val = python_rdd.asJavaRDD() + return self._jrdd_val + + +class PairMappedRDD(MappedRDDBase, PairRDD): + """ + >>> rdd = sc.parallelize([1, 2, 3, 4]) + >>> rdd.mapPairs(lambda x: (x, x)) \\ + ... .mapPairs(lambda (x, y): (2*x, 2*y)) \\ + ... .collect() + [(2, 2), (4, 4), (6, 6), (8, 8)] + >>> rdd.mapPairs(lambda x: (x, x)) \\ + ... .mapPairs(lambda (x, y): (2*x, 2*y)) \\ + ... .map(lambda (x, _): x).collect() + [2, 4, 6, 8] + """ + + def __init__(self, prev, func, keySerializer=None, valSerializer=None, + preservesPartitioning=False): + self.keySerializer = keySerializer or prev.ctx.defaultSerializer + self.valSerializer = valSerializer or prev.ctx.defaultSerializer + serializer = PairSerializer(self.keySerializer, self.valSerializer) + MappedRDDBase.__init__(self, prev, func, serializer, + preservesPartitioning) + + @property + def _jrdd(self): + if not self._jrdd_val: + udf = self.func + loads = self._prev_serializer.loads + dumpk = self.keySerializer.dumps + dumpv = self.valSerializer.dumps + def func(x): + (k, v) = udf(loads(x)) + return (dumpk(k), dumpv(v)) + pipe_command = RDD._get_pipe_command("mapPairs", [func]) + class_manifest = self._prev_jrdd.classManifest() + self._jrdd_val = self.ctx.jvm.PythonPairRDD(self._prev_jrdd.rdd(), + pipe_command, self.preservesPartitioning, self.ctx.pythonExec, + class_manifest).asJavaPairRDD() + return self._jrdd_val + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.serializers import PickleSerializer, JSONSerializer + globs = globals().copy() + globs['sc'] = SparkContext('local', 'PythonTest', + defaultSerializer=JSONSerializer) + doctest.testmod(globs=globs) + globs['sc'].stop() + globs['sc'] = SparkContext('local', 'PythonTest', + defaultSerializer=PickleSerializer) + doctest.testmod(globs=globs) + globs['sc'].stop() + + +if __name__ == "__main__": + _test() diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py new file mode 100644 index 0000000000..b113f5656b --- /dev/null +++ b/pyspark/pyspark/serializers.py @@ -0,0 +1,229 @@ +""" +Data serialization methods. 
+ +The Spark Python API is built on top of the Spark Java API. RDDs created in +Python are stored in Java as RDDs of Strings. Python objects are automatically +serialized/deserialized, so this representation is transparent to the end-user. + +------------------ +Serializer objects +------------------ + +`Serializer` objects are used to customize how an RDD's values are serialized. + +Each `Serializer` is a named tuple with four fields: + + - A `dumps` function, for serializing a Python object to a string. + + - A `loads` function, for deserializing a Python object from a string. + + - An `is_comparable` field, True if equal Python objects are serialized to + equal strings, and False otherwise. + + - A `name` field, used to identify the Serializer. Serializers are + compared for equality by comparing their names. + +The serializer's output should be base64-encoded. + +------------------------------------------------------------------ +`is_comparable`: comparing serialized representations for equality +------------------------------------------------------------------ + +If `is_comparable` is False, the serializer's representations of equal objects +are not required to be equal: + +>>> import pickle +>>> a = {1: 0, 9: 0} +>>> b = {9: 0, 1: 0} +>>> a == b +True +>>> pickle.dumps(a) == pickle.dumps(b) +False + +RDDs with comparable serializers can use native Java implementations of +operations like join() and distinct(), which may lead to better performance by +eliminating deserialization and Python comparisons. + +The default JSONSerializer produces comparable representations of common Python +data structures. + +-------------------------------------- +Examples of serialized representations +-------------------------------------- + +The RDD transformations that use Python UDFs are implemented in terms of +a modified `PipedRDD.pipe()` function. For each record `x` in the RDD, the +`pipe()` function pipes `x.toString()` to a Python worker process, which +deserializes the string into a Python object, executes user-defined functions, +and outputs serialized Python objects. + +The regular `toString()` method returns an ambiguous representation, due to the +way that Scala `Option` instances are printed: + +>>> from context import SparkContext +>>> sc = SparkContext("local", "SerializerDocs") +>>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) +>>> y = sc.parallelizePairs([("a", 2)]) + +>>> print y.rightOuterJoin(x)._jrdd.first().toString() +(ImEi,(Some(Mg==),MQ==)) + +In Java, preprocessing is performed to handle Option instances, so the Python +process receives unambiguous input: + +>>> print sc.python_dump(y.rightOuterJoin(x)._jrdd.first()) +(ImEi,(Mg==,MQ==)) + +The base64-encoding eliminates the need to escape newlines, parentheses and +other special characters. + +---------------------- +Serializer composition +---------------------- + +In order to handle nested structures, which could contain object serialized +with different serializers, the RDD module composes serializers. 
For example, +the serializers in the previous example are: + +>>> print x.serializer.name +PairSerializer + +>>> print y.serializer.name +PairSerializer + +>>> print y.rightOuterJoin(x).serializer.name +PairSerializer, JSONSerializer>> +""" +from base64 import standard_b64encode, standard_b64decode +from collections import namedtuple +import cPickle +import simplejson + + +Serializer = namedtuple("Serializer", + ["dumps","loads", "is_comparable", "name"]) + + +NopSerializer = Serializer(str, str, True, "NopSerializer") + + +JSONSerializer = Serializer( + lambda obj: standard_b64encode(simplejson.dumps(obj, sort_keys=True, + separators=(',', ':'))), + lambda s: simplejson.loads(standard_b64decode(s)), + True, + "JSONSerializer" +) + + +PickleSerializer = Serializer( + lambda obj: standard_b64encode(cPickle.dumps(obj)), + lambda s: cPickle.loads(standard_b64decode(s)), + False, + "PickleSerializer" +) + + +def OptionSerializer(serializer): + """ + >>> ser = OptionSerializer(NopSerializer) + >>> ser.loads(ser.dumps("Hello, World!")) + 'Hello, World!' + >>> ser.loads(ser.dumps(None)) is None + True + """ + none_placeholder = '*' + + def dumps(x): + if x is None: + return none_placeholder + else: + return serializer.dumps(x) + + def loads(x): + if x == none_placeholder: + return None + else: + return serializer.loads(x) + + name = "OptionSerializer<%s>" % serializer.name + return Serializer(dumps, loads, serializer.is_comparable, name) + + +def PairSerializer(keySerializer, valSerializer): + """ + Returns a Serializer for a (key, value) pair. + + >>> ser = PairSerializer(JSONSerializer, JSONSerializer) + >>> ser.loads(ser.dumps((1, 2))) + (1, 2) + + >>> ser = PairSerializer(JSONSerializer, ser) + >>> ser.loads(ser.dumps((1, (2, 3)))) + (1, (2, 3)) + """ + def loads(kv): + try: + (key, val) = kv[1:-1].split(',', 1) + key = keySerializer.loads(key) + val = valSerializer.loads(val) + return (key, val) + except: + print "Error in deserializing pair from '%s'" % str(kv) + raise + + def dumps(kv): + (key, val) = kv + return"(%s,%s)" % (keySerializer.dumps(key), valSerializer.dumps(val)) + is_comparable = \ + keySerializer.is_comparable and valSerializer.is_comparable + name = "PairSerializer<%s, %s>" % (keySerializer.name, valSerializer.name) + return Serializer(dumps, loads, is_comparable, name) + + +def ArraySerializer(serializer): + """ + >>> ser = ArraySerializer(JSONSerializer) + >>> ser.loads(ser.dumps([1, 2, 3, 4])) + [1, 2, 3, 4] + >>> ser = ArraySerializer(PairSerializer(JSONSerializer, PickleSerializer)) + >>> ser.loads(ser.dumps([('a', 1), ('b', 2)])) + [('a', 1), ('b', 2)] + >>> ser.loads(ser.dumps([('a', 1)])) + [('a', 1)] + >>> ser.loads(ser.dumps([])) + [] + """ + def dumps(arr): + if arr == []: + return '[]' + else: + return '[' + '|'.join(serializer.dumps(x) for x in arr) + ']' + + def loads(s): + if s == '[]': + return [] + items = s[1:-1] + if '|' in items: + items = items.split('|') + else: + items = [items] + return [serializer.loads(x) for x in items] + + name = "ArraySerializer<%s>" % serializer.name + return Serializer(dumps, loads, serializer.is_comparable, name) + + +# TODO: IntegerSerializer + + +# TODO: DoubleSerializer + + +def _test(): + import doctest + doctest.testmod() + + +if __name__ == "__main__": + _test() diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py new file mode 100644 index 0000000000..4d4cc939c3 --- /dev/null +++ b/pyspark/pyspark/worker.py @@ -0,0 +1,97 @@ +""" +Worker that receives input from Piped RDD. 
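+
+The protocol, as implemented below: the first line on stdin names a command
+("map", "mapPairs", "combine_by_key", "reduce" or "echo"), followed by any
+base64-pickled functions it needs, followed by one serialized record per
+line. Results are written to the real stdout, one line per record, while
+user print output is redirected to stderr.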
+""" +import sys +from base64 import standard_b64decode +# CloudPickler needs to be imported so that depicklers are registered using the +# copy_reg module. +from cloud.serialization.cloudpickle import CloudPickler +import cPickle + + +# Redirect stdout to stderr so that users must return values from functions. +old_stdout = sys.stdout +sys.stdout = sys.stderr + + +def load_function(): + return cPickle.loads(standard_b64decode(sys.stdin.readline().strip())) + + +def output(x): + for line in x.split("\n"): + old_stdout.write(line.rstrip("\r\n") + "\n") + + +def read_input(): + for line in sys.stdin: + yield line.rstrip("\r\n") + + +def do_combine_by_key(): + create_combiner = load_function() + merge_value = load_function() + merge_combiners = load_function() # TODO: not used. + depickler = load_function() + key_pickler = load_function() + combiner_pickler = load_function() + combiners = {} + for line in read_input(): + # Discard the hashcode added in the Python combineByKey() method. + (key, value) = depickler(line)[1] + if key not in combiners: + combiners[key] = create_combiner(value) + else: + combiners[key] = merge_value(combiners[key], value) + for (key, combiner) in combiners.iteritems(): + output(key_pickler(key)) + output(combiner_pickler(combiner)) + + +def do_map(map_pairs=False): + f = load_function() + for line in read_input(): + try: + out = f(line) + if out is not None: + if map_pairs: + for x in out: + output(x) + else: + output(out) + except: + sys.stderr.write("Error processing line '%s'\n" % line) + raise + + +def do_reduce(): + f = load_function() + dumps = load_function() + acc = None + for line in read_input(): + acc = f(line, acc) + output(dumps(acc)) + + +def do_echo(): + old_stdout.writelines(sys.stdin.readlines()) + + +def main(): + command = sys.stdin.readline().strip() + if command == "map": + do_map(map_pairs=False) + elif command == "mapPairs": + do_map(map_pairs=True) + elif command == "combine_by_key": + do_combine_by_key() + elif command == "reduce": + do_reduce() + elif command == "echo": + do_echo() + else: + raise Exception("Unsupported command %s" % command) + + +if __name__ == '__main__': + main() diff --git a/pyspark/requirements.txt b/pyspark/requirements.txt new file mode 100644 index 0000000000..d9b3fe40bd --- /dev/null +++ b/pyspark/requirements.txt @@ -0,0 +1,9 @@ +# The Python API relies on some new features from the Py4J development branch. +# pip can't install Py4J from git because the setup.py file for the Python +# package is not at the root of the git repository. It may be possible to +# install Py4J from git once https://github.com/pypa/pip/pull/526 is merged. 
+ +# git+git://github.com/bartdag/py4j.git@3dbf380d3d2cdeb9aab394454ea74d80c4aba1ea + +simplejson==2.6.1 +cloud==2.5.5 diff --git a/python/tc.py b/python/tc.py new file mode 100644 index 0000000000..5dcc4317e0 --- /dev/null +++ b/python/tc.py @@ -0,0 +1,22 @@ +from rdd import SparkContext + +sc = SparkContext("local", "PythonWordCount") +e = [(1, 2), (2, 3), (4, 1)] + +tc = sc.parallelizePairs(e) + +edges = tc.mapPairs(lambda (x, y): (y, x)) + +oldCount = 0 +nextCount = tc.count() + +def project(x): + return (x[1][1], x[1][0]) + +while nextCount != oldCount: + oldCount = nextCount + tc = tc.union(tc.join(edges).mapPairs(project)).distinct() + nextCount = tc.count() + +print "TC has %i edges" % tc.count() +print tc.collect() From 13b9514966a423f80f672f23f42ec3f0113936fd Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 19 Aug 2012 17:12:51 -0700 Subject: [PATCH 002/291] Bundle cloudpickle with pyspark. --- pyspark/pyspark/cloudpickle.py | 974 +++++++++++++++++++++++++++++++++ pyspark/pyspark/rdd.py | 2 +- pyspark/pyspark/worker.py | 2 +- pyspark/requirements.txt | 3 - 4 files changed, 976 insertions(+), 5 deletions(-) create mode 100644 pyspark/pyspark/cloudpickle.py diff --git a/pyspark/pyspark/cloudpickle.py b/pyspark/pyspark/cloudpickle.py new file mode 100644 index 0000000000..6a7c23a069 --- /dev/null +++ b/pyspark/pyspark/cloudpickle.py @@ -0,0 +1,974 @@ +""" +This class is defined to override standard pickle functionality + +The goals of it follow: +-Serialize lambdas and nested functions to compiled byte code +-Deal with main module correctly +-Deal with other non-serializable objects + +It does not include an unpickler, as standard python unpickling suffices. + +This module was extracted from the `cloud` package, developed by `PiCloud, Inc. +`_. + +Copyright (c) 2012, Regents of the University of California. +Copyright (c) 2009 `PiCloud, Inc. `_. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the University of California, Berkeley nor the + names of its contributors may be used to endorse or promote + products derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+""" + + +import operator +import os +import pickle +import struct +import sys +import types +from functools import partial +import itertools +from copy_reg import _extension_registry, _inverted_registry, _extension_cache +import new +import dis +import traceback + +#relevant opcodes +STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL')) +DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL')) +LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL')) +GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL] + +HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT) +EXTENDED_ARG = chr(dis.EXTENDED_ARG) + +import logging +cloudLog = logging.getLogger("Cloud.Transport") + +try: + import ctypes +except (MemoryError, ImportError): + logging.warning('Exception raised on importing ctypes. Likely python bug.. some functionality will be disabled', exc_info = True) + ctypes = None + PyObject_HEAD = None +else: + + # for reading internal structures + PyObject_HEAD = [ + ('ob_refcnt', ctypes.c_size_t), + ('ob_type', ctypes.c_void_p), + ] + + +try: + from cStringIO import StringIO +except ImportError: + from StringIO import StringIO + +# These helper functions were copied from PiCloud's util module. +def islambda(func): + return getattr(func,'func_name') == '' + +def xrange_params(xrangeobj): + """Returns a 3 element tuple describing the xrange start, step, and len + respectively + + Note: Only guarentees that elements of xrange are the same. parameters may + be different. + e.g. xrange(1,1) is interpretted as xrange(0,0); both behave the same + though w/ iteration + """ + + xrange_len = len(xrangeobj) + if not xrange_len: #empty + return (0,1,0) + start = xrangeobj[0] + if xrange_len == 1: #one element + return start, 1, 1 + return (start, xrangeobj[1] - xrangeobj[0], xrange_len) + +#debug variables intended for developer use: +printSerialization = False +printMemoization = False + +useForcedImports = True #Should I use forced imports for tracking? + + + +class CloudPickler(pickle.Pickler): + + dispatch = pickle.Pickler.dispatch.copy() + savedForceImports = False + savedDjangoEnv = False #hack tro transport django environment + + def __init__(self, file, protocol=None, min_size_to_save= 0): + pickle.Pickler.__init__(self,file,protocol) + self.modules = set() #set of modules needed to depickle + self.globals_ref = {} # map ids to dictionary. used to ensure that functions can share global env + + def dump(self, obj): + # note: not thread safe + # minimal side-effects, so not fixing + recurse_limit = 3000 + base_recurse = sys.getrecursionlimit() + if base_recurse < recurse_limit: + sys.setrecursionlimit(recurse_limit) + self.inject_addons() + try: + return pickle.Pickler.dump(self, obj) + except RuntimeError, e: + if 'recursion' in e.args[0]: + msg = """Could not pickle object as excessively deep recursion required. + Try _fast_serialization=2 or contact PiCloud support""" + raise pickle.PicklingError(msg) + finally: + new_recurse = sys.getrecursionlimit() + if new_recurse == recurse_limit: + sys.setrecursionlimit(base_recurse) + + def save_buffer(self, obj): + """Fallback to save_string""" + pickle.Pickler.save_string(self,str(obj)) + dispatch[buffer] = save_buffer + + #block broken objects + def save_unsupported(self, obj, pack=None): + raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj)) + dispatch[types.GeneratorType] = save_unsupported + + #python2.6+ supports slice pickling. some py2.5 extensions might as well. 
We just test it + try: + slice(0,1).__reduce__() + except TypeError: #can't pickle - + dispatch[slice] = save_unsupported + + #itertools objects do not pickle! + for v in itertools.__dict__.values(): + if type(v) is type: + dispatch[v] = save_unsupported + + + def save_dict(self, obj): + """hack fix + If the dict is a global, deal with it in a special way + """ + #print 'saving', obj + if obj is __builtins__: + self.save_reduce(_get_module_builtins, (), obj=obj) + else: + pickle.Pickler.save_dict(self, obj) + dispatch[pickle.DictionaryType] = save_dict + + + def save_module(self, obj, pack=struct.pack): + """ + Save a module as an import + """ + #print 'try save import', obj.__name__ + self.modules.add(obj) + self.save_reduce(subimport,(obj.__name__,), obj=obj) + dispatch[types.ModuleType] = save_module #new type + + def save_codeobject(self, obj, pack=struct.pack): + """ + Save a code object + """ + #print 'try to save codeobj: ', obj + args = ( + obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code, + obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name, + obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars + ) + self.save_reduce(types.CodeType, args, obj=obj) + dispatch[types.CodeType] = save_codeobject #new type + + def save_function(self, obj, name=None, pack=struct.pack): + """ Registered with the dispatch to handle all function types. + + Determines what kind of function obj is (e.g. lambda, defined at + interactive prompt, etc) and handles the pickling appropriately. + """ + write = self.write + + name = obj.__name__ + modname = pickle.whichmodule(obj, name) + #print 'which gives %s %s %s' % (modname, obj, name) + try: + themodule = sys.modules[modname] + except KeyError: # eval'd items such as namedtuple give invalid items for their function __module__ + modname = '__main__' + + if modname == '__main__': + themodule = None + + if themodule: + self.modules.add(themodule) + + if not self.savedDjangoEnv: + #hack for django - if we detect the settings module, we transport it + django_settings = os.environ.get('DJANGO_SETTINGS_MODULE', '') + if django_settings: + django_mod = sys.modules.get(django_settings) + if django_mod: + cloudLog.debug('Transporting django settings %s during save of %s', django_mod, name) + self.savedDjangoEnv = True + self.modules.add(django_mod) + write(pickle.MARK) + self.save_reduce(django_settings_load, (django_mod.__name__,), obj=django_mod) + write(pickle.POP_MARK) + + + # if func is lambda, def'ed at prompt, is in main, or is nested, then + # we'll pickle the actual function object rather than simply saving a + # reference (as is done in default pickler), via save_function_tuple. 
+ if islambda(obj) or obj.func_code.co_filename == '' or themodule == None: + #Force server to import modules that have been imported in main + modList = None + if themodule == None and not self.savedForceImports: + mainmod = sys.modules['__main__'] + if useForcedImports and hasattr(mainmod,'___pyc_forcedImports__'): + modList = list(mainmod.___pyc_forcedImports__) + self.savedForceImports = True + self.save_function_tuple(obj, modList) + return + else: # func is nested + klass = getattr(themodule, name, None) + if klass is None or klass is not obj: + self.save_function_tuple(obj, [themodule]) + return + + if obj.__dict__: + # essentially save_reduce, but workaround needed to avoid recursion + self.save(_restore_attr) + write(pickle.MARK + pickle.GLOBAL + modname + '\n' + name + '\n') + self.memoize(obj) + self.save(obj.__dict__) + write(pickle.TUPLE + pickle.REDUCE) + else: + write(pickle.GLOBAL + modname + '\n' + name + '\n') + self.memoize(obj) + dispatch[types.FunctionType] = save_function + + def save_function_tuple(self, func, forced_imports): + """ Pickles an actual func object. + + A func comprises: code, globals, defaults, closure, and dict. We + extract and save these, injecting reducing functions at certain points + to recreate the func object. Keep in mind that some of these pieces + can contain a ref to the func itself. Thus, a naive save on these + pieces could trigger an infinite loop of save's. To get around that, + we first create a skeleton func object using just the code (this is + safe, since this won't contain a ref to the func), and memoize it as + soon as it's created. The other stuff can then be filled in later. + """ + save = self.save + write = self.write + + # save the modules (if any) + if forced_imports: + write(pickle.MARK) + save(_modules_to_main) + #print 'forced imports are', forced_imports + + forced_names = map(lambda m: m.__name__, forced_imports) + save((forced_names,)) + + #save((forced_imports,)) + write(pickle.REDUCE) + write(pickle.POP_MARK) + + code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func) + + save(_fill_function) # skeleton function updater + write(pickle.MARK) # beginning of tuple that _fill_function expects + + # create a skeleton function object and memoize it + save(_make_skel_func) + save((code, len(closure), base_globals)) + write(pickle.REDUCE) + self.memoize(func) + + # save the rest of the func data needed by _fill_function + save(f_globals) + save(defaults) + save(closure) + save(dct) + write(pickle.TUPLE) + write(pickle.REDUCE) # applies _fill_function on the tuple + + @staticmethod + def extract_code_globals(co): + """ + Find all globals names read or written to by codeblock co + """ + code = co.co_code + names = co.co_names + out_names = set() + + n = len(code) + i = 0 + extended_arg = 0 + while i < n: + op = code[i] + + i = i+1 + if op >= HAVE_ARGUMENT: + oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg + extended_arg = 0 + i = i+2 + if op == EXTENDED_ARG: + extended_arg = oparg*65536L + if op in GLOBAL_OPS: + out_names.add(names[oparg]) + #print 'extracted', out_names, ' from ', names + return out_names + + def extract_func_data(self, func): + """ + Turn the function into a tuple of data necessary to recreate it: + code, globals, defaults, closure, dict + """ + code = func.func_code + + # extract all global ref's + func_global_refs = CloudPickler.extract_code_globals(code) + if code.co_consts: # see if nested function have any global refs + for const in code.co_consts: + if type(const) 
is types.CodeType and const.co_names: + func_global_refs = func_global_refs.union( CloudPickler.extract_code_globals(const)) + # process all variables referenced by global environment + f_globals = {} + for var in func_global_refs: + #Some names, such as class functions are not global - we don't need them + if func.func_globals.has_key(var): + f_globals[var] = func.func_globals[var] + + # defaults requires no processing + defaults = func.func_defaults + + def get_contents(cell): + try: + return cell.cell_contents + except ValueError, e: #cell is empty error on not yet assigned + raise pickle.PicklingError('Function to be pickled has free variables that are referenced before assignment in enclosing scope') + + + # process closure + if func.func_closure: + closure = map(get_contents, func.func_closure) + else: + closure = [] + + # save the dict + dct = func.func_dict + + if printSerialization: + outvars = ['code: ' + str(code) ] + outvars.append('globals: ' + str(f_globals)) + outvars.append('defaults: ' + str(defaults)) + outvars.append('closure: ' + str(closure)) + print 'function ', func, 'is extracted to: ', ', '.join(outvars) + + base_globals = self.globals_ref.get(id(func.func_globals), {}) + self.globals_ref[id(func.func_globals)] = base_globals + + return (code, f_globals, defaults, closure, dct, base_globals) + + def save_global(self, obj, name=None, pack=struct.pack): + write = self.write + memo = self.memo + + if name is None: + name = obj.__name__ + + modname = getattr(obj, "__module__", None) + if modname is None: + modname = pickle.whichmodule(obj, name) + + try: + __import__(modname) + themodule = sys.modules[modname] + except (ImportError, KeyError, AttributeError): #should never occur + raise pickle.PicklingError( + "Can't pickle %r: Module %s cannot be found" % + (obj, modname)) + + if modname == '__main__': + themodule = None + + if themodule: + self.modules.add(themodule) + + sendRef = True + typ = type(obj) + #print 'saving', obj, typ + try: + try: #Deal with case when getattribute fails with exceptions + klass = getattr(themodule, name) + except (AttributeError): + if modname == '__builtin__': #new.* are misrepeported + modname = 'new' + __import__(modname) + themodule = sys.modules[modname] + try: + klass = getattr(themodule, name) + except AttributeError, a: + #print themodule, name, obj, type(obj) + raise pickle.PicklingError("Can't pickle builtin %s" % obj) + else: + raise + + except (ImportError, KeyError, AttributeError): + if typ == types.TypeType or typ == types.ClassType: + sendRef = False + else: #we can't deal with this + raise + else: + if klass is not obj and (typ == types.TypeType or typ == types.ClassType): + sendRef = False + if not sendRef: + #note: Third party types might crash this - add better checks! 
+ d = dict(obj.__dict__) #copy dict proxy to a dict + if not isinstance(d.get('__dict__', None), property): # don't extract dict that are properties + d.pop('__dict__',None) + d.pop('__weakref__',None) + + # hack as __new__ is stored differently in the __dict__ + new_override = d.get('__new__', None) + if new_override: + d['__new__'] = obj.__new__ + + self.save_reduce(type(obj),(obj.__name__,obj.__bases__, + d),obj=obj) + #print 'internal reduce dask %s %s' % (obj, d) + return + + if self.proto >= 2: + code = _extension_registry.get((modname, name)) + if code: + assert code > 0 + if code <= 0xff: + write(pickle.EXT1 + chr(code)) + elif code <= 0xffff: + write("%c%c%c" % (pickle.EXT2, code&0xff, code>>8)) + else: + write(pickle.EXT4 + pack("= 2 and getattr(func, "__name__", "") == "__newobj__": + #Added fix to allow transient + cls = args[0] + if not hasattr(cls, "__new__"): + raise pickle.PicklingError( + "args[0] from __newobj__ args has no __new__") + if obj is not None and cls is not obj.__class__: + raise pickle.PicklingError( + "args[0] from __newobj__ args has the wrong class") + args = args[1:] + save(cls) + + #Don't pickle transient entries + if hasattr(obj, '__transient__'): + transient = obj.__transient__ + state = state.copy() + + for k in list(state.keys()): + if k in transient: + del state[k] + + save(args) + write(pickle.NEWOBJ) + else: + save(func) + save(args) + write(pickle.REDUCE) + + if obj is not None: + self.memoize(obj) + + # More new special cases (that work with older protocols as + # well): when __reduce__ returns a tuple with 4 or 5 items, + # the 4th and 5th item should be iterators that provide list + # items and dict items (as (key, value) tuples), or None. + + if listitems is not None: + self._batch_appends(listitems) + + if dictitems is not None: + self._batch_setitems(dictitems) + + if state is not None: + #print 'obj %s has state %s' % (obj, state) + save(state) + write(pickle.BUILD) + + + def save_xrange(self, obj): + """Save an xrange object in python 2.5 + Python 2.6 supports this natively + """ + range_params = xrange_params(obj) + self.save_reduce(_build_xrange,range_params) + + #python2.6+ supports xrange pickling. some py2.5 extensions might as well. 
We just test it + try: + xrange(0).__reduce__() + except TypeError: #can't pickle -- use PiCloud pickler + dispatch[xrange] = save_xrange + + def save_partial(self, obj): + """Partial objects do not serialize correctly in python2.x -- this fixes the bugs""" + self.save_reduce(_genpartial, (obj.func, obj.args, obj.keywords)) + + if sys.version_info < (2,7): #2.7 supports partial pickling + dispatch[partial] = save_partial + + + def save_file(self, obj): + """Save a file""" + import StringIO as pystringIO #we can't use cStringIO as it lacks the name attribute + from ..transport.adapter import SerializingAdapter + + if not hasattr(obj, 'name') or not hasattr(obj, 'mode'): + raise pickle.PicklingError("Cannot pickle files that do not map to an actual file") + if obj.name == '': + return self.save_reduce(getattr, (sys,'stdout'), obj=obj) + if obj.name == '': + return self.save_reduce(getattr, (sys,'stderr'), obj=obj) + if obj.name == '': + raise pickle.PicklingError("Cannot pickle standard input") + if hasattr(obj, 'isatty') and obj.isatty(): + raise pickle.PicklingError("Cannot pickle files that map to tty objects") + if 'r' not in obj.mode: + raise pickle.PicklingError("Cannot pickle files that are not opened for reading") + name = obj.name + try: + fsize = os.stat(name).st_size + except OSError: + raise pickle.PicklingError("Cannot pickle file %s as it cannot be stat" % name) + + if obj.closed: + #create an empty closed string io + retval = pystringIO.StringIO("") + retval.close() + elif not fsize: #empty file + retval = pystringIO.StringIO("") + try: + tmpfile = file(name) + tst = tmpfile.read(1) + except IOError: + raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) + tmpfile.close() + if tst != '': + raise pickle.PicklingError("Cannot pickle file %s as it does not appear to map to a physical, real file" % name) + elif fsize > SerializingAdapter.max_transmit_data: + raise pickle.PicklingError("Cannot pickle file %s as it exceeds cloudconf.py's max_transmit_data of %d" % + (name,SerializingAdapter.max_transmit_data)) + else: + try: + tmpfile = file(name) + contents = tmpfile.read(SerializingAdapter.max_transmit_data) + tmpfile.close() + except IOError: + raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) + retval = pystringIO.StringIO(contents) + curloc = obj.tell() + retval.seek(curloc) + + retval.name = name + self.save(retval) #save stringIO + self.memoize(obj) + + dispatch[file] = save_file + """Special functions for Add-on libraries""" + + def inject_numpy(self): + numpy = sys.modules.get('numpy') + if not numpy or not hasattr(numpy, 'ufunc'): + return + self.dispatch[numpy.ufunc] = self.__class__.save_ufunc + + numpy_tst_mods = ['numpy', 'scipy.special'] + def save_ufunc(self, obj): + """Hack function for saving numpy ufunc objects""" + name = obj.__name__ + for tst_mod_name in self.numpy_tst_mods: + tst_mod = sys.modules.get(tst_mod_name, None) + if tst_mod: + if name in tst_mod.__dict__: + self.save_reduce(_getobject, (tst_mod_name, name)) + return + raise pickle.PicklingError('cannot save %s. 
Cannot resolve what module it is defined in' % str(obj)) + + def inject_timeseries(self): + """Handle bugs with pickling scikits timeseries""" + tseries = sys.modules.get('scikits.timeseries.tseries') + if not tseries or not hasattr(tseries, 'Timeseries'): + return + self.dispatch[tseries.Timeseries] = self.__class__.save_timeseries + + def save_timeseries(self, obj): + import scikits.timeseries.tseries as ts + + func, reduce_args, state = obj.__reduce__() + if func != ts._tsreconstruct: + raise pickle.PicklingError('timeseries using unexpected reconstruction function %s' % str(func)) + state = (1, + obj.shape, + obj.dtype, + obj.flags.fnc, + obj._data.tostring(), + ts.getmaskarray(obj).tostring(), + obj._fill_value, + obj._dates.shape, + obj._dates.__array__().tostring(), + obj._dates.dtype, #added -- preserve type + obj.freq, + obj._optinfo, + ) + return self.save_reduce(_genTimeSeries, (reduce_args, state)) + + def inject_email(self): + """Block email LazyImporters from being saved""" + email = sys.modules.get('email') + if not email: + return + self.dispatch[email.LazyImporter] = self.__class__.save_unsupported + + def inject_addons(self): + """Plug in system. Register additional pickling functions if modules already loaded""" + self.inject_numpy() + self.inject_timeseries() + self.inject_email() + + """Python Imaging Library""" + def save_image(self, obj): + if not obj.im and obj.fp and 'r' in obj.fp.mode and obj.fp.name \ + and not obj.fp.closed and (not hasattr(obj, 'isatty') or not obj.isatty()): + #if image not loaded yet -- lazy load + self.save_reduce(_lazyloadImage,(obj.fp,), obj=obj) + else: + #image is loaded - just transmit it over + self.save_reduce(_generateImage, (obj.size, obj.mode, obj.tostring()), obj=obj) + + """ + def memoize(self, obj): + pickle.Pickler.memoize(self, obj) + if printMemoization: + print 'memoizing ' + str(obj) + """ + + + +# Shorthands for legacy support + +def dump(obj, file, protocol=2): + CloudPickler(file, protocol).dump(obj) + +def dumps(obj, protocol=2): + file = StringIO() + + cp = CloudPickler(file,protocol) + cp.dump(obj) + + #print 'cloud dumped', str(obj), str(cp.modules) + + return file.getvalue() + + +#hack for __import__ not working as desired +def subimport(name): + __import__(name) + return sys.modules[name] + +#hack to load django settings: +def django_settings_load(name): + modified_env = False + + if 'DJANGO_SETTINGS_MODULE' not in os.environ: + os.environ['DJANGO_SETTINGS_MODULE'] = name # must set name first due to circular deps + modified_env = True + try: + module = subimport(name) + except Exception, i: + print >> sys.stderr, 'Cloud not import django settings %s:' % (name) + print_exec(sys.stderr) + if modified_env: + del os.environ['DJANGO_SETTINGS_MODULE'] + else: + #add project directory to sys,path: + if hasattr(module,'__file__'): + dirname = os.path.split(module.__file__)[0] + '/' + sys.path.append(dirname) + +# restores function attributes +def _restore_attr(obj, attr): + for key, val in attr.items(): + setattr(obj, key, val) + return obj + +def _get_module_builtins(): + return pickle.__builtins__ + +def print_exec(stream): + ei = sys.exc_info() + traceback.print_exception(ei[0], ei[1], ei[2], None, stream) + +def _modules_to_main(modList): + """Force every module in modList to be placed into main""" + if not modList: + return + + main = sys.modules['__main__'] + for modname in modList: + if type(modname) is str: + try: + mod = __import__(modname) + except Exception, i: #catch all... 
+ sys.stderr.write('warning: could not import %s\n. Your function may unexpectedly error due to this import failing; \ +A version mismatch is likely. Specific error was:\n' % modname) + print_exec(sys.stderr) + else: + setattr(main,mod.__name__, mod) + else: + #REVERSE COMPATIBILITY FOR CLOUD CLIENT 1.5 (WITH EPD) + #In old version actual module was sent + setattr(main,modname.__name__, modname) + +#object generators: +def _build_xrange(start, step, len): + """Built xrange explicitly""" + return xrange(start, start + step*len, step) + +def _genpartial(func, args, kwds): + if not args: + args = () + if not kwds: + kwds = {} + return partial(func, *args, **kwds) + + +def _fill_function(func, globals, defaults, closure, dict): + """ Fills in the rest of function data into the skeleton function object + that were created via _make_skel_func(). + """ + func.func_globals.update(globals) + func.func_defaults = defaults + func.func_dict = dict + + if len(closure) != len(func.func_closure): + raise pickle.UnpicklingError("closure lengths don't match up") + for i in range(len(closure)): + _change_cell_value(func.func_closure[i], closure[i]) + + return func + +def _make_skel_func(code, num_closures, base_globals = None): + """ Creates a skeleton function object that contains just the provided + code and the correct number of cells in func_closure. All other + func attributes (e.g. func_globals) are empty. + """ + #build closure (cells): + if not ctypes: + raise Exception('ctypes failed to import; cannot build function') + + cellnew = ctypes.pythonapi.PyCell_New + cellnew.restype = ctypes.py_object + cellnew.argtypes = (ctypes.py_object,) + dummy_closure = tuple(map(lambda i: cellnew(None), range(num_closures))) + + if base_globals is None: + base_globals = {} + base_globals['__builtins__'] = __builtins__ + + return types.FunctionType(code, base_globals, + None, None, dummy_closure) + +# this piece of opaque code is needed below to modify 'cell' contents +cell_changer_code = new.code( + 1, 1, 2, 0, + ''.join([ + chr(dis.opmap['LOAD_FAST']), '\x00\x00', + chr(dis.opmap['DUP_TOP']), + chr(dis.opmap['STORE_DEREF']), '\x00\x00', + chr(dis.opmap['RETURN_VALUE']) + ]), + (), (), ('newval',), '', 'cell_changer', 1, '', ('c',), () +) + +def _change_cell_value(cell, newval): + """ Changes the contents of 'cell' object to newval """ + return new.function(cell_changer_code, {}, None, (), (cell,))(newval) + +"""Constructors for 3rd party libraries +Note: These can never be renamed due to client compatibility issues""" + +def _getobject(modname, attribute): + mod = __import__(modname) + return mod.__dict__[attribute] + +def _generateImage(size, mode, str_rep): + """Generate image from string representation""" + import Image + i = Image.new(mode, size) + i.fromstring(str_rep) + return i + +def _lazyloadImage(fp): + import Image + fp.seek(0) #works in almost any case + return Image.open(fp) + +"""Timeseries""" +def _genTimeSeries(reduce_args, state): + import scikits.timeseries.tseries as ts + from numpy import ndarray + from numpy.ma import MaskedArray + + + time_series = ts._tsreconstruct(*reduce_args) + + #from setstate modified + (ver, shp, typ, isf, raw, msk, flv, dsh, dtm, dtyp, frq, infodict) = state + #print 'regenerating %s' % dtyp + + MaskedArray.__setstate__(time_series, (ver, shp, typ, isf, raw, msk, flv)) + _dates = time_series._dates + #_dates.__setstate__((ver, dsh, typ, isf, dtm, frq)) #use remote typ + ndarray.__setstate__(_dates,(dsh,dtyp, isf, dtm)) + _dates.freq = frq + 
_dates._cachedinfo.update(dict(full=None, hasdups=None, steps=None, + toobj=None, toord=None, tostr=None)) + # Update the _optinfo dictionary + time_series._optinfo.update(infodict) + return time_series + diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index c892e86b93..5579c56de3 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -1,5 +1,5 @@ from base64 import standard_b64encode as b64enc -from cloud.serialization import cloudpickle +from pyspark import cloudpickle from itertools import chain from pyspark.serializers import PairSerializer, NopSerializer, \ diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index 4d4cc939c3..4c4b02fce4 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -5,7 +5,7 @@ import sys from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. -from cloud.serialization.cloudpickle import CloudPickler +from pyspark.cloudpickle import CloudPickler import cPickle diff --git a/pyspark/requirements.txt b/pyspark/requirements.txt index d9b3fe40bd..71e2bc2b89 100644 --- a/pyspark/requirements.txt +++ b/pyspark/requirements.txt @@ -4,6 +4,3 @@ # install Py4J from git once https://github.com/pypa/pip/pull/526 is merged. # git+git://github.com/bartdag/py4j.git@3dbf380d3d2cdeb9aab394454ea74d80c4aba1ea - -simplejson==2.6.1 -cloud==2.5.5 From fd94e5443c99775bfad1928729f5075c900ad0f9 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 18 Aug 2012 16:07:10 -0700 Subject: [PATCH 003/291] Use only cPickle for serialization in Python API. Objects serialized with JSON can be compared for equality, but JSON can be slow to serialize and only supports a limited range of data types. --- .../scala/spark/api/python/PythonRDD.scala | 192 ++++++-- pyspark/pyspark/context.py | 53 +-- pyspark/pyspark/java_gateway.py | 1 - pyspark/pyspark/join.py | 32 +- pyspark/pyspark/rdd.py | 422 +++++++----------- pyspark/pyspark/serializers.py | 233 +--------- pyspark/pyspark/worker.py | 64 +-- 7 files changed, 387 insertions(+), 610 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 660ad48afe..b9a0168d18 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -1,22 +1,26 @@ package spark.api.python -import java.io.PrintWriter +import java.io._ import scala.collection.Map import scala.collection.JavaConversions._ import scala.io.Source import spark._ -import api.java.{JavaPairRDD, JavaRDD} +import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} +import scala.{collection, Some} +import collection.parallel.mutable +import scala.collection import scala.Some trait PythonRDDBase { def compute[T](split: Split, envVars: Map[String, String], - command: Seq[String], parent: RDD[T], pythonExec: String): Iterator[String]= { - val currentEnvVars = new ProcessBuilder().environment() - val SPARK_HOME = currentEnvVars.get("SPARK_HOME") + command: Seq[String], parent: RDD[T], pythonExec: String): Iterator[Array[Byte]] = { + val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME") val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py")) // Add the environmental variables to the process. 
+ val currentEnvVars = pb.environment() + envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) } @@ -41,33 +45,70 @@ trait PythonRDDBase { for (elem <- command) { out.println(elem) } + out.flush() + val dOut = new DataOutputStream(proc.getOutputStream) for (elem <- parent.iterator(split)) { - out.println(PythonRDD.pythonDump(elem)) + if (elem.isInstanceOf[Array[Byte]]) { + val arr = elem.asInstanceOf[Array[Byte]] + dOut.writeInt(arr.length) + dOut.write(arr) + } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) { + val t = elem.asInstanceOf[scala.Tuple2[_, _]] + val t1 = t._1.asInstanceOf[Array[Byte]] + val t2 = t._2.asInstanceOf[Array[Byte]] + val length = t1.length + t2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes + dOut.writeInt(length) + dOut.writeByte(Pickle.PROTO) + dOut.writeByte(Pickle.TWO) + dOut.write(PythonRDD.stripPickle(t1)) + dOut.write(PythonRDD.stripPickle(t2)) + dOut.writeByte(Pickle.TUPLE2) + dOut.writeByte(Pickle.STOP) + } else if (elem.isInstanceOf[String]) { + // For uniformity, strings are wrapped into Pickles. + val s = elem.asInstanceOf[String].getBytes("UTF-8") + val length = 2 + 1 + 4 + s.length + 1 + dOut.writeInt(length) + dOut.writeByte(Pickle.PROTO) + dOut.writeByte(Pickle.TWO) + dOut.writeByte(Pickle.BINUNICODE) + dOut.writeInt(Integer.reverseBytes(s.length)) + dOut.write(s) + dOut.writeByte(Pickle.STOP) + } else { + throw new Exception("Unexpected RDD type") + } } - out.close() + dOut.flush() + out.flush() + proc.getOutputStream.close() } }.start() // Return an iterator that read lines from the process's stdout - val lines: Iterator[String] = Source.fromInputStream(proc.getInputStream).getLines - wrapIterator(lines, proc) - } + val stream = new DataInputStream(proc.getInputStream) + return new Iterator[Array[Byte]] { + def next() = { + val obj = _nextObj + _nextObj = read() + obj + } - def wrapIterator[T](iter: Iterator[T], proc: Process): Iterator[T] = { - return new Iterator[T] { - def next() = iter.next() - - def hasNext = { - if (iter.hasNext) { - true - } else { - val exitStatus = proc.waitFor() - if (exitStatus != 0) { - throw new Exception("Subprocess exited with status " + exitStatus) - } - false + private def read() = { + try { + val length = stream.readInt() + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + } catch { + case eof: EOFException => { new Array[Byte](0) } + case e => throw e } } + + var _nextObj = read() + + def hasNext = _nextObj.length != 0 } } } @@ -75,7 +116,7 @@ trait PythonRDDBase { class PythonRDD[T: ClassManifest]( parent: RDD[T], command: Seq[String], envVars: Map[String, String], preservePartitoning: Boolean, pythonExec: String) - extends RDD[String](parent.context) with PythonRDDBase { + extends RDD[Array[Byte]](parent.context) with PythonRDDBase { def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) = this(parent, command, Map(), preservePartitoning, pythonExec) @@ -91,16 +132,16 @@ class PythonRDD[T: ClassManifest]( override val partitioner = if (preservePartitoning) parent.partitioner else None - override def compute(split: Split): Iterator[String] = + override def compute(split: Split): Iterator[Array[Byte]] = compute(split, envVars, command, parent, pythonExec) - val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this) + val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } class PythonPairRDD[T: ClassManifest] ( parent: RDD[T], command: Seq[String], envVars: Map[String, String], preservePartitoning: Boolean, pythonExec: 
String) - extends RDD[(String, String)](parent.context) with PythonRDDBase { + extends RDD[(Array[Byte], Array[Byte])](parent.context) with PythonRDDBase { def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) = this(parent, command, Map(), preservePartitoning, pythonExec) @@ -116,32 +157,95 @@ class PythonPairRDD[T: ClassManifest] ( override val partitioner = if (preservePartitoning) parent.partitioner else None - override def compute(split: Split): Iterator[(String, String)] = { + override def compute(split: Split): Iterator[(Array[Byte], Array[Byte])] = { compute(split, envVars, command, parent, pythonExec).grouped(2).map { case Seq(a, b) => (a, b) - case x => throw new Exception("Unexpected value: " + x) + case x => throw new Exception("PythonPairRDD: unexpected value: " + x) } } - val asJavaPairRDD : JavaPairRDD[String, String] = JavaPairRDD.fromRDD(this) + val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) } + object PythonRDD { - def pythonDump[T](x: T): String = { - if (x.isInstanceOf[scala.Option[_]]) { - val t = x.asInstanceOf[scala.Option[_]] - t match { - case None => "*" - case Some(z) => pythonDump(z) - } - } else if (x.isInstanceOf[scala.Tuple2[_, _]]) { - val t = x.asInstanceOf[scala.Tuple2[_, _]] - "(" + pythonDump(t._1) + "," + pythonDump(t._2) + ")" - } else if (x.isInstanceOf[java.util.List[_]]) { - val objs = asScalaBuffer(x.asInstanceOf[java.util.List[_]]).map(pythonDump) - "[" + objs.mkString("|") + "]" + + /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */ + def stripPickle(arr: Array[Byte]) : Array[Byte] = { + arr.slice(2, arr.length - 1) + } + + def asPickle(elem: Any) : Array[Byte] = { + val baos = new ByteArrayOutputStream(); + val dOut = new DataOutputStream(baos); + if (elem.isInstanceOf[Array[Byte]]) { + elem.asInstanceOf[Array[Byte]] + } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) { + val t = elem.asInstanceOf[scala.Tuple2[_, _]] + val t1 = t._1.asInstanceOf[Array[Byte]] + val t2 = t._2.asInstanceOf[Array[Byte]] + dOut.writeByte(Pickle.PROTO) + dOut.writeByte(Pickle.TWO) + dOut.write(PythonRDD.stripPickle(t1)) + dOut.write(PythonRDD.stripPickle(t2)) + dOut.writeByte(Pickle.TUPLE2) + dOut.writeByte(Pickle.STOP) + baos.toByteArray() + } else if (elem.isInstanceOf[String]) { + // For uniformity, strings are wrapped into Pickles. 
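The stripPickle and asPickle helpers above splice two already-pickled values into a single pickle by dropping each buffer's two-byte PROTO header and trailing STOP opcode, then appending TUPLE2 and STOP. The byte-level recipe can be checked in plain Python (standalone illustration, not part of the patch; names are arbitrary):

    import pickle

    PROTO, TWO, TUPLE2, STOP = b'\x80', b'\x02', b'\x86', b'.'

    def strip_pickle(p):
        # drop the 2-byte protocol-2 header and the trailing STOP,
        # mirroring PythonRDD.stripPickle
        return p[2:-1]

    key = pickle.dumps(u"a", 2)
    val = pickle.dumps(1, 2)
    pair = PROTO + TWO + strip_pickle(key) + strip_pickle(val) + TUPLE2 + STOP
    assert pickle.loads(pair) == (u"a", 1)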
+ val s = elem.asInstanceOf[String].getBytes("UTF-8") + dOut.writeByte(Pickle.PROTO) + dOut.writeByte(Pickle.TWO) + dOut.write(Pickle.BINUNICODE) + dOut.writeInt(Integer.reverseBytes(s.length)) + dOut.write(s) + dOut.writeByte(Pickle.STOP) + baos.toByteArray() } else { - x.toString + throw new Exception("Unexpected RDD type") } } + + def pickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) : + JavaRDD[Array[Byte]] = { + val file = new DataInputStream(new FileInputStream(filename)) + val objs = new collection.mutable.ArrayBuffer[Array[Byte]] + try { + while (true) { + val length = file.readInt() + val obj = new Array[Byte](length) + file.readFully(obj) + objs.append(obj) + } + } catch { + case eof: EOFException => {} + case e => throw e + } + JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) + } + + def arrayAsPickle(arr : Any) : Array[Byte] = { + val pickles : Array[Byte] = arr.asInstanceOf[Array[Any]].map(asPickle).map(stripPickle).flatten + + Array[Byte](Pickle.PROTO, Pickle.TWO, Pickle.EMPTY_LIST, Pickle.MARK) ++ pickles ++ + Array[Byte] (Pickle.APPENDS, Pickle.STOP) + } +} + +private object Pickle { + def b(x: Int): Byte = x.asInstanceOf[Byte] + val PROTO: Byte = b(0x80) + val TWO: Byte = b(0x02) + val BINUNICODE : Byte = 'X' + val STOP : Byte = '.' + val TUPLE2 : Byte = b(0x86) + val EMPTY_LIST : Byte = ']' + val MARK : Byte = '(' + val APPENDS : Byte = 'e' +} +class ExtractValue extends spark.api.java.function.Function[(Array[Byte], + Array[Byte]), Array[Byte]] { + + override def call(pair: (Array[Byte], Array[Byte])) : Array[Byte] = pair._2 + } diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index 587ab12b5f..ac7e4057e9 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -3,22 +3,24 @@ import atexit from tempfile import NamedTemporaryFile from pyspark.java_gateway import launch_gateway -from pyspark.serializers import JSONSerializer, NopSerializer -from pyspark.rdd import RDD, PairRDD +from pyspark.serializers import PickleSerializer, dumps +from pyspark.rdd import RDD class SparkContext(object): gateway = launch_gateway() jvm = gateway.jvm - python_dump = jvm.spark.api.python.PythonRDD.pythonDump + pickleFile = jvm.spark.api.python.PythonRDD.pickleFile + asPickle = jvm.spark.api.python.PythonRDD.asPickle + arrayAsPickle = jvm.spark.api.python.PythonRDD.arrayAsPickle - def __init__(self, master, name, defaultSerializer=JSONSerializer, - defaultParallelism=None, pythonExec='python'): + + def __init__(self, master, name, defaultParallelism=None, + pythonExec='python'): self.master = master self.name = name self._jsc = self.jvm.JavaSparkContext(master, name) - self.defaultSerializer = defaultSerializer self.defaultParallelism = \ defaultParallelism or self._jsc.sc().defaultParallelism() self.pythonExec = pythonExec @@ -31,39 +33,26 @@ class SparkContext(object): self._jsc.stop() self._jsc = None - def parallelize(self, c, numSlices=None, serializer=None): - serializer = serializer or self.defaultSerializer + def parallelize(self, c, numSlices=None): + """ + >>> sc = SparkContext("local", "test") + >>> rdd = sc.parallelize([(1, 2), (3, 4)]) + >>> rdd.collect() + [(1, 2), (3, 4)] + """ numSlices = numSlices or self.defaultParallelism # Calling the Java parallelize() method with an ArrayList is too slow, # because it sends O(n) Py4J commands. As an alternative, serialized # objects are written to a file and loaded through textFile(). 
tempFile = NamedTemporaryFile(delete=False) - tempFile.writelines(serializer.dumps(x) + '\n' for x in c) + for x in c: + dumps(PickleSerializer.dumps(x), tempFile) tempFile.close() atexit.register(lambda: os.unlink(tempFile.name)) - return self.textFile(tempFile.name, numSlices, serializer) + jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices) + return RDD(jrdd, self) - def parallelizePairs(self, c, numSlices=None, keySerializer=None, - valSerializer=None): - """ - >>> sc = SparkContext("local", "test") - >>> rdd = sc.parallelizePairs([(1, 2), (3, 4)]) - >>> rdd.collect() - [(1, 2), (3, 4)] - """ - keySerializer = keySerializer or self.defaultSerializer - valSerializer = valSerializer or self.defaultSerializer - numSlices = numSlices or self.defaultParallelism - tempFile = NamedTemporaryFile(delete=False) - for (k, v) in c: - tempFile.write(keySerializer.dumps(k).rstrip('\r\n') + '\n') - tempFile.write(valSerializer.dumps(v).rstrip('\r\n') + '\n') - tempFile.close() - atexit.register(lambda: os.unlink(tempFile.name)) - jrdd = self.textFile(tempFile.name, numSlices)._pipePairs([], "echo") - return PairRDD(jrdd, self, keySerializer, valSerializer) - - def textFile(self, name, numSlices=None, serializer=NopSerializer): + def textFile(self, name, numSlices=None): numSlices = numSlices or self.defaultParallelism jrdd = self._jsc.textFile(name, numSlices) - return RDD(jrdd, self, serializer) + return RDD(jrdd, self) diff --git a/pyspark/pyspark/java_gateway.py b/pyspark/pyspark/java_gateway.py index 2df80aee85..bcb405ba72 100644 --- a/pyspark/pyspark/java_gateway.py +++ b/pyspark/pyspark/java_gateway.py @@ -16,5 +16,4 @@ def launch_gateway(): java_import(gateway.jvm, "spark.api.java.*") java_import(gateway.jvm, "spark.api.python.*") java_import(gateway.jvm, "scala.Tuple2") - java_import(gateway.jvm, "spark.api.python.PythonRDD.pythonDump") return gateway diff --git a/pyspark/pyspark/join.py b/pyspark/pyspark/join.py index c67520fce8..7036c47980 100644 --- a/pyspark/pyspark/join.py +++ b/pyspark/pyspark/join.py @@ -30,15 +30,12 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
""" -from pyspark.serializers import PairSerializer, OptionSerializer, \ - ArraySerializer -def _do_python_join(rdd, other, numSplits, dispatch, valSerializer): - vs = rdd.mapPairs(lambda (k, v): (k, (1, v))) - ws = other.mapPairs(lambda (k, v): (k, (2, v))) - return vs.union(ws).groupByKey(numSplits) \ - .flatMapValues(dispatch, valSerializer) +def _do_python_join(rdd, other, numSplits, dispatch): + vs = rdd.map(lambda (k, v): (k, (1, v))) + ws = other.map(lambda (k, v): (k, (2, v))) + return vs.union(ws).groupByKey(numSplits).flatMapValues(dispatch) def python_join(rdd, other, numSplits): @@ -50,8 +47,7 @@ def python_join(rdd, other, numSplits): elif n == 2: wbuf.append(v) return [(v, w) for v in vbuf for w in wbuf] - valSerializer = PairSerializer(rdd.valSerializer, other.valSerializer) - return _do_python_join(rdd, other, numSplits, dispatch, valSerializer) + return _do_python_join(rdd, other, numSplits, dispatch) def python_right_outer_join(rdd, other, numSplits): @@ -65,9 +61,7 @@ def python_right_outer_join(rdd, other, numSplits): if not vbuf: vbuf.append(None) return [(v, w) for v in vbuf for w in wbuf] - valSerializer = PairSerializer(OptionSerializer(rdd.valSerializer), - other.valSerializer) - return _do_python_join(rdd, other, numSplits, dispatch, valSerializer) + return _do_python_join(rdd, other, numSplits, dispatch) def python_left_outer_join(rdd, other, numSplits): @@ -81,17 +75,12 @@ def python_left_outer_join(rdd, other, numSplits): if not wbuf: wbuf.append(None) return [(v, w) for v in vbuf for w in wbuf] - valSerializer = PairSerializer(rdd.valSerializer, - OptionSerializer(other.valSerializer)) - return _do_python_join(rdd, other, numSplits, dispatch, valSerializer) + return _do_python_join(rdd, other, numSplits, dispatch) def python_cogroup(rdd, other, numSplits): - resultValSerializer = PairSerializer( - ArraySerializer(rdd.valSerializer), - ArraySerializer(other.valSerializer)) - vs = rdd.mapPairs(lambda (k, v): (k, (1, v))) - ws = other.mapPairs(lambda (k, v): (k, (2, v))) + vs = rdd.map(lambda (k, v): (k, (1, v))) + ws = other.map(lambda (k, v): (k, (2, v))) def dispatch(seq): vbuf, wbuf = [], [] for (n, v) in seq: @@ -100,5 +89,4 @@ def python_cogroup(rdd, other, numSplits): elif n == 2: wbuf.append(v) return (vbuf, wbuf) - return vs.union(ws).groupByKey(numSplits) \ - .mapValues(dispatch, resultValSerializer) + return vs.union(ws).groupByKey(numSplits).mapValues(dispatch) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 5579c56de3..8eccddc0a2 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -1,31 +1,17 @@ from base64 import standard_b64encode as b64enc -from pyspark import cloudpickle -from itertools import chain -from pyspark.serializers import PairSerializer, NopSerializer, \ - OptionSerializer, ArraySerializer +from pyspark import cloudpickle +from pyspark.serializers import PickleSerializer from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup class RDD(object): - def __init__(self, jrdd, ctx, serializer=None): + def __init__(self, jrdd, ctx): self._jrdd = jrdd self.is_cached = False self.ctx = ctx - self.serializer = serializer or ctx.defaultSerializer - - def _builder(self, jrdd, ctx): - return RDD(jrdd, ctx, self.serializer) - - @property - def id(self): - return self._jrdd.id() - - @property - def splits(self): - return self._jrdd.splits() @classmethod def _get_pipe_command(cls, command, functions): @@ -41,55 +27,18 @@ class RDD(object): self._jrdd.cache() return self 
- def map(self, f, serializer=None, preservesPartitioning=False): - return MappedRDD(self, f, serializer, preservesPartitioning) + def map(self, f, preservesPartitioning=False): + return MappedRDD(self, f, preservesPartitioning) - def mapPairs(self, f, keySerializer=None, valSerializer=None, - preservesPartitioning=False): - return PairMappedRDD(self, f, keySerializer, valSerializer, - preservesPartitioning) - - def flatMap(self, f, serializer=None): + def flatMap(self, f): """ >>> rdd = sc.parallelize([2, 3, 4]) >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect()) [1, 1, 1, 2, 2, 3] - """ - serializer = serializer or self.ctx.defaultSerializer - dumps = serializer.dumps - loads = self.serializer.loads - def func(x): - pickled_elems = (dumps(y) for y in f(loads(x))) - return "\n".join(pickled_elems) or None - pipe_command = RDD._get_pipe_command("map", [func]) - class_manifest = self._jrdd.classManifest() - jrdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command, - False, self.ctx.pythonExec, - class_manifest).asJavaRDD() - return RDD(jrdd, self.ctx, serializer) - - def flatMapPairs(self, f, keySerializer=None, valSerializer=None, - preservesPartitioning=False): - """ - >>> rdd = sc.parallelize([2, 3, 4]) - >>> sorted(rdd.flatMapPairs(lambda x: [(x, x), (x, x)]).collect()) + >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect()) [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ - keySerializer = keySerializer or self.ctx.defaultSerializer - valSerializer = valSerializer or self.ctx.defaultSerializer - dumpk = keySerializer.dumps - dumpv = valSerializer.dumps - loads = self.serializer.loads - def func(x): - pairs = f(loads(x)) - pickled_pairs = ((dumpk(k), dumpv(v)) for (k, v) in pairs) - return "\n".join(chain.from_iterable(pickled_pairs)) or None - pipe_command = RDD._get_pipe_command("map", [func]) - class_manifest = self._jrdd.classManifest() - python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), pipe_command, - preservesPartitioning, self.ctx.pythonExec, class_manifest) - return PairRDD(python_rdd.asJavaPairRDD(), self.ctx, keySerializer, - valSerializer) + return MappedRDD(self, f, preservesPartitioning=False, command='flatmap') def filter(self, f): """ @@ -97,9 +46,8 @@ class RDD(object): >>> rdd.filter(lambda x: x % 2 == 0).collect() [2, 4] """ - loads = self.serializer.loads - def filter_func(x): return x if f(loads(x)) else None - return self._builder(self._pipe(filter_func), self.ctx) + def filter_func(x): return x if f(x) else None + return RDD(self._pipe(filter_func), self.ctx) def _pipe(self, functions, command="map"): class_manifest = self._jrdd.classManifest() @@ -108,32 +56,22 @@ class RDD(object): False, self.ctx.pythonExec, class_manifest) return python_rdd.asJavaRDD() - def _pipePairs(self, functions, command="mapPairs", - preservesPartitioning=False): - class_manifest = self._jrdd.classManifest() - pipe_command = RDD._get_pipe_command(command, functions) - python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), pipe_command, - preservesPartitioning, self.ctx.pythonExec, class_manifest) - return python_rdd.asJavaPairRDD() - def distinct(self): """ >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect()) [1, 2, 3] """ - if self.serializer.is_comparable: - return self._builder(self._jrdd.distinct(), self.ctx) - return self.mapPairs(lambda x: (x, "")) \ + return self.map(lambda x: (x, "")) \ .reduceByKey(lambda x, _: x) \ .map(lambda (x, _): x) def sample(self, withReplacement, fraction, seed): jrdd = self._jrdd.sample(withReplacement, 
fraction, seed) - return self._builder(jrdd, self.ctx) + return RDD(jrdd, self.ctx) def takeSample(self, withReplacement, num, seed): vals = self._jrdd.takeSample(withReplacement, num, seed) - return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals] + return [PickleSerializer.loads(x) for x in vals] def union(self, other): """ @@ -141,7 +79,7 @@ class RDD(object): >>> rdd.union(rdd).collect() [1, 1, 2, 3, 1, 1, 2, 3] """ - return self._builder(self._jrdd.union(other._jrdd), self.ctx) + return RDD(self._jrdd.union(other._jrdd), self.ctx) # TODO: sort @@ -155,16 +93,17 @@ class RDD(object): >>> sorted(rdd.cartesian(rdd).collect()) [(1, 1), (1, 2), (2, 1), (2, 2)] """ - return PairRDD(self._jrdd.cartesian(other._jrdd), self.ctx) + return RDD(self._jrdd.cartesian(other._jrdd), self.ctx) # numsplits def groupBy(self, f, numSplits=None): """ >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8]) - >>> sorted(rdd.groupBy(lambda x: x % 2).collect()) + >>> result = rdd.groupBy(lambda x: x % 2).collect() + >>> sorted([(x, sorted(y)) for (x, y) in result]) [(0, [2, 8]), (1, [1, 1, 3, 5])] """ - return self.mapPairs(lambda x: (f(x), x)).groupByKey(numSplits) + return self.map(lambda x: (f(x), x)).groupByKey(numSplits) # TODO: pipe @@ -178,25 +117,19 @@ class RDD(object): self.map(f).collect() # Force evaluation def collect(self): - vals = self._jrdd.collect() - return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals] + pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().collect()) + return PickleSerializer.loads(bytes(pickle)) - def reduce(self, f, serializer=None): + def reduce(self, f): """ - >>> import operator - >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(operator.add) + >>> from operator import add + >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add) 15 + >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add) + 10 """ - serializer = serializer or self.ctx.defaultSerializer - loads = self.serializer.loads - dumps = serializer.dumps - def reduceFunction(x, acc): - if acc is None: - return loads(x) - else: - return f(loads(x), acc) - vals = self._pipe([reduceFunction, dumps], command="reduce").collect() - return reduce(f, (serializer.loads(x) for x in vals)) + vals = MappedRDD(self, f, command="reduce", preservesPartitioning=False).collect() + return reduce(f, vals) # TODO: fold @@ -216,36 +149,35 @@ class RDD(object): >>> sc.parallelize([2, 3, 4]).take(2) [2, 3] """ - vals = self._jrdd.take(num) - return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals] + pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().take(num)) + return PickleSerializer.loads(bytes(pickle)) def first(self): """ >>> sc.parallelize([2, 3, 4]).first() 2 """ - return self.serializer.loads(self.ctx.python_dump(self._jrdd.first())) + return PickleSerializer.loads(bytes(self.ctx.asPickle(self._jrdd.first()))) # TODO: saveAsTextFile # TODO: saveAsObjectFile + # Pair functions -class PairRDD(RDD): - - def __init__(self, jrdd, ctx, keySerializer=None, valSerializer=None): - RDD.__init__(self, jrdd, ctx) - self.keySerializer = keySerializer or ctx.defaultSerializer - self.valSerializer = valSerializer or ctx.defaultSerializer - self.serializer = \ - PairSerializer(self.keySerializer, self.valSerializer) - - def _builder(self, jrdd, ctx): - return PairRDD(jrdd, ctx, self.keySerializer, self.valSerializer) + def collectAsMap(self): + """ + >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap() + >>> m[1] + 2 + >>> m[3] + 4 + """ + return dict(self.collect()) def reduceByKey(self, func, 
numSplits=None): """ - >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)]) + >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> sorted(x.reduceByKey(lambda a, b: a + b).collect()) [('a', 2), ('b', 1)] """ @@ -259,90 +191,67 @@ class PairRDD(RDD): def join(self, other, numSplits=None): """ - >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) - >>> y = sc.parallelizePairs([("a", 2), ("a", 3)]) - >>> x.join(y).collect() + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2), ("a", 3)]) + >>> sorted(x.join(y).collect()) [('a', (1, 2)), ('a', (1, 3))] - - Check that we get a PairRDD-like object back: - >>> assert x.join(y).join """ - assert self.keySerializer.name == other.keySerializer.name - if self.keySerializer.is_comparable: - return PairRDD(self._jrdd.join(other._jrdd), - self.ctx, self.keySerializer, - PairSerializer(self.valSerializer, other.valSerializer)) - else: - return python_join(self, other, numSplits) + return python_join(self, other, numSplits) def leftOuterJoin(self, other, numSplits=None): """ - >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) - >>> y = sc.parallelizePairs([("a", 2)]) + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2)]) >>> sorted(x.leftOuterJoin(y).collect()) [('a', (1, 2)), ('b', (4, None))] """ - assert self.keySerializer.name == other.keySerializer.name - if self.keySerializer.is_comparable: - return PairRDD(self._jrdd.leftOuterJoin(other._jrdd), - self.ctx, self.keySerializer, - PairSerializer(self.valSerializer, - OptionSerializer(other.valSerializer))) - else: - return python_left_outer_join(self, other, numSplits) + return python_left_outer_join(self, other, numSplits) def rightOuterJoin(self, other, numSplits=None): """ - >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) - >>> y = sc.parallelizePairs([("a", 2)]) + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2)]) >>> sorted(y.rightOuterJoin(x).collect()) [('a', (2, 1)), ('b', (None, 4))] """ - assert self.keySerializer.name == other.keySerializer.name - if self.keySerializer.is_comparable: - return PairRDD(self._jrdd.rightOuterJoin(other._jrdd), - self.ctx, self.keySerializer, - PairSerializer(OptionSerializer(self.valSerializer), - other.valSerializer)) - else: - return python_right_outer_join(self, other, numSplits) + return python_right_outer_join(self, other, numSplits) + + # TODO: pipelining + # TODO: optimizations + def shuffle(self, numSplits): + if numSplits is None: + numSplits = self.ctx.defaultParallelism + pipe_command = RDD._get_pipe_command('shuffle_map_step', []) + class_manifest = self._jrdd.classManifest() + python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), + pipe_command, False, self.ctx.pythonExec, class_manifest) + partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) + jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner) + jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) + # TODO: extract second value. 
+ return RDD(jrdd, self.ctx) + + def combineByKey(self, createCombiner, mergeValue, mergeCombiners, - numSplits=None, serializer=None): + numSplits=None): """ - >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)]) + >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> def f(x): return x >>> def add(a, b): return a + str(b) >>> sorted(x.combineByKey(str, add, add).collect()) [('a', '11'), ('b', '1')] """ - serializer = serializer or self.ctx.defaultSerializer if numSplits is None: numSplits = self.ctx.defaultParallelism - # Use hash() to create keys that are comparable in Java. - loadkv = self.serializer.loads - def pairify(kv): - # TODO: add method to deserialize only the key or value from - # a PairSerializer? - key = loadkv(kv)[0] - return (str(hash(key)), kv) - partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) - jrdd = self._pipePairs(pairify).partitionBy(partitioner) - pairified = PairRDD(jrdd, self.ctx, NopSerializer, self.serializer) - - loads = PairSerializer(NopSerializer, self.serializer).loads - dumpk = self.keySerializer.dumps - dumpc = serializer.dumps - - functions = [createCombiner, mergeValue, mergeCombiners, loads, dumpk, - dumpc] - jpairs = pairified._pipePairs(functions, "combine_by_key", - preservesPartitioning=True) - return PairRDD(jpairs, self.ctx, self.keySerializer, serializer) + shuffled = self.shuffle(numSplits) + functions = [createCombiner, mergeValue, mergeCombiners] + jpairs = shuffled._pipe(functions, "combine_by_key") + return RDD(jpairs, self.ctx) def groupByKey(self, numSplits=None): """ - >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)]) + >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> sorted(x.groupByKey().collect()) [('a', [1, 1]), ('b', [1])] """ @@ -360,29 +269,15 @@ class PairRDD(RDD): return self.combineByKey(createCombiner, mergeValue, mergeCombiners, numSplits) - def collectAsMap(self): - """ - >>> m = sc.parallelizePairs([(1, 2), (3, 4)]).collectAsMap() - >>> m[1] - 2 - >>> m[3] - 4 - """ - m = self._jrdd.collectAsMap() - def loads(x): - (k, v) = x - return (self.keySerializer.loads(k), self.valSerializer.loads(v)) - return dict(loads(x) for x in m.items()) - - def flatMapValues(self, f, valSerializer=None): + def flatMapValues(self, f): flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) - return self.flatMapPairs(flat_map_fn, self.keySerializer, - valSerializer, True) + return self.flatMap(flat_map_fn) - def mapValues(self, f, valSerializer=None): + def mapValues(self, f): map_values_fn = lambda (k, v): (k, f(v)) - return self.mapPairs(map_values_fn, self.keySerializer, valSerializer, - True) + return self.map(map_values_fn, preservesPartitioning=True) + + # TODO: implement shuffle. # TODO: support varargs cogroup of several RDDs. 
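The combineByKey path above shuffles by key and then pipes createCombiner/mergeValue/mergeCombiners to the worker's combine_by_key step, which keeps one combiner per key within a partition. A local model of that per-partition step (the helper name is illustrative; mergeCombiners is not exercised here):

    def combine_locally(pairs, createCombiner, mergeValue):
        # one combiner per key, as in the worker's do_combine_by_key loop
        combiners = {}
        for key, value in pairs:
            if key not in combiners:
                combiners[key] = createCombiner(value)
            else:
                combiners[key] = mergeValue(combiners[key], value)
        return combiners

    # combine_locally([("a", 1), ("b", 1), ("a", 1)], str, lambda a, b: a + str(b))
    # -> {'a': '11', 'b': '1'}, matching the combineByKey doctest above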
def groupWith(self, other): @@ -390,20 +285,12 @@ class PairRDD(RDD): def cogroup(self, other, numSplits=None): """ - >>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) - >>> y = sc.parallelizePairs([("a", 2)]) + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2)]) >>> x.cogroup(y).collect() [('a', ([1], [2])), ('b', ([4], []))] """ - assert self.keySerializer.name == other.keySerializer.name - resultValSerializer = PairSerializer( - ArraySerializer(self.valSerializer), - ArraySerializer(other.valSerializer)) - if self.keySerializer.is_comparable: - return PairRDD(self._jrdd.cogroup(other._jrdd), - self.ctx, self.keySerializer, resultValSerializer) - else: - return python_cogroup(self, other, numSplits) + return python_cogroup(self, other, numSplits) # TODO: `lookup` is disabled because we can't make direct comparisons based # on the key; we need to compare the hash of the key to the hash of the @@ -413,44 +300,84 @@ class PairRDD(RDD): # TODO: file saving -class MappedRDDBase(object): - def __init__(self, prev, func, serializer, preservesPartitioning=False): - if isinstance(prev, MappedRDDBase) and not prev.is_cached: - prev_func = prev.func - self.func = lambda x: func(prev_func(x)) - self.preservesPartitioning = \ - prev.preservesPartitioning and preservesPartitioning - self._prev_jrdd = prev._prev_jrdd - self._prev_serializer = prev._prev_serializer - else: - self.func = func - self.preservesPartitioning = preservesPartitioning - self._prev_jrdd = prev._jrdd - self._prev_serializer = prev.serializer - self.serializer = serializer or prev.ctx.defaultSerializer - self.is_cached = False - self.ctx = prev.ctx - self.prev = prev - self._jrdd_val = None - - -class MappedRDD(MappedRDDBase, RDD): +class MappedRDD(RDD): """ + Pipelined maps: >>> rdd = sc.parallelize([1, 2, 3, 4]) >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect() [4, 8, 12, 16] >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect() [4, 8, 12, 16] + + Pipelined reduces: + >>> from operator import add + >>> rdd.map(lambda x: 2 * x).reduce(add) + 20 + >>> rdd.flatMap(lambda x: [x, x]).reduce(add) + 20 """ + def __init__(self, prev, func, preservesPartitioning=False, command='map'): + if isinstance(prev, MappedRDD) and not prev.is_cached: + prev_func = prev.func + if command == 'reduce': + if prev.command == 'flatmap': + def flatmap_reduce_func(x, acc): + values = prev_func(x) + if values is None: + return acc + if not acc: + if len(values) == 1: + return values[0] + else: + return reduce(func, values[1:], values[0]) + else: + return reduce(func, values, acc) + self.func = flatmap_reduce_func + else: + def reduce_func(x, acc): + val = prev_func(x) + if not val: + return acc + if acc is None: + return val + else: + return func(val, acc) + self.func = reduce_func + else: + if prev.command == 'flatmap': + command = 'flatmap' + self.func = lambda x: (func(y) for y in prev_func(x)) + else: + self.func = lambda x: func(prev_func(x)) + + self.preservesPartitioning = \ + prev.preservesPartitioning and preservesPartitioning + self._prev_jrdd = prev._prev_jrdd + self.is_pipelined = True + else: + if command == 'reduce': + def reduce_func(val, acc): + if acc is None: + return val + else: + return func(val, acc) + self.func = reduce_func + else: + self.func = func + self.preservesPartitioning = preservesPartitioning + self._prev_jrdd = prev._jrdd + self.is_pipelined = False + self.is_cached = False + self.ctx = prev.ctx + self.prev = prev + self._jrdd_val = None + self.command = command 
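MappedRDD above pipelines chained transformations by composing the Python functions on the driver, so a map followed by a reduce runs as a single worker command; for command='reduce' it wraps the previous map into an accumulator-style function. A simplified sketch of that composition (it ignores the flatmap branch and the falsy-value guard handled above; names are illustrative):

    def pipeline_map_then_reduce(map_func, reduce_op):
        # compose a map step with a running reduce, roughly as MappedRDD.__init__ does
        def reduce_func(x, acc):
            val = map_func(x)
            if acc is None:
                return val
            return reduce_op(val, acc)
        return reduce_func

    step = pipeline_map_then_reduce(lambda x: 2 * x, lambda a, b: a + b)
    acc = None
    for record in [1, 2, 3, 4]:
        acc = step(record, acc)
    # acc == 20, matching the pipelined-reduce doctest above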
@property def _jrdd(self): if not self._jrdd_val: - udf = self.func - loads = self._prev_serializer.loads - dumps = self.serializer.dumps - func = lambda x: dumps(udf(loads(x))) - pipe_command = RDD._get_pipe_command("map", [func]) + funcs = [self.func] + pipe_command = RDD._get_pipe_command(self.command, funcs) class_manifest = self._prev_jrdd.classManifest() python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), pipe_command, self.preservesPartitioning, self.ctx.pythonExec, @@ -459,56 +386,11 @@ class MappedRDD(MappedRDDBase, RDD): return self._jrdd_val -class PairMappedRDD(MappedRDDBase, PairRDD): - """ - >>> rdd = sc.parallelize([1, 2, 3, 4]) - >>> rdd.mapPairs(lambda x: (x, x)) \\ - ... .mapPairs(lambda (x, y): (2*x, 2*y)) \\ - ... .collect() - [(2, 2), (4, 4), (6, 6), (8, 8)] - >>> rdd.mapPairs(lambda x: (x, x)) \\ - ... .mapPairs(lambda (x, y): (2*x, 2*y)) \\ - ... .map(lambda (x, _): x).collect() - [2, 4, 6, 8] - """ - - def __init__(self, prev, func, keySerializer=None, valSerializer=None, - preservesPartitioning=False): - self.keySerializer = keySerializer or prev.ctx.defaultSerializer - self.valSerializer = valSerializer or prev.ctx.defaultSerializer - serializer = PairSerializer(self.keySerializer, self.valSerializer) - MappedRDDBase.__init__(self, prev, func, serializer, - preservesPartitioning) - - @property - def _jrdd(self): - if not self._jrdd_val: - udf = self.func - loads = self._prev_serializer.loads - dumpk = self.keySerializer.dumps - dumpv = self.valSerializer.dumps - def func(x): - (k, v) = udf(loads(x)) - return (dumpk(k), dumpv(v)) - pipe_command = RDD._get_pipe_command("mapPairs", [func]) - class_manifest = self._prev_jrdd.classManifest() - self._jrdd_val = self.ctx.jvm.PythonPairRDD(self._prev_jrdd.rdd(), - pipe_command, self.preservesPartitioning, self.ctx.pythonExec, - class_manifest).asJavaPairRDD() - return self._jrdd_val - - def _test(): import doctest from pyspark.context import SparkContext - from pyspark.serializers import PickleSerializer, JSONSerializer globs = globals().copy() - globs['sc'] = SparkContext('local', 'PythonTest', - defaultSerializer=JSONSerializer) - doctest.testmod(globs=globs) - globs['sc'].stop() - globs['sc'] = SparkContext('local', 'PythonTest', - defaultSerializer=PickleSerializer) + globs['sc'] = SparkContext('local', 'PythonTest') doctest.testmod(globs=globs) globs['sc'].stop() diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py index b113f5656b..7b3e6966e1 100644 --- a/pyspark/pyspark/serializers.py +++ b/pyspark/pyspark/serializers.py @@ -2,228 +2,35 @@ Data serialization methods. The Spark Python API is built on top of the Spark Java API. RDDs created in -Python are stored in Java as RDDs of Strings. Python objects are automatically -serialized/deserialized, so this representation is transparent to the end-user. - ------------------- -Serializer objects ------------------- - -`Serializer` objects are used to customize how an RDD's values are serialized. - -Each `Serializer` is a named tuple with four fields: - - - A `dumps` function, for serializing a Python object to a string. - - - A `loads` function, for deserializing a Python object from a string. - - - An `is_comparable` field, True if equal Python objects are serialized to - equal strings, and False otherwise. - - - A `name` field, used to identify the Serializer. Serializers are - compared for equality by comparing their names. - -The serializer's output should be base64-encoded. 
- ------------------------------------------------------------------- -`is_comparable`: comparing serialized representations for equality ------------------------------------------------------------------- - -If `is_comparable` is False, the serializer's representations of equal objects -are not required to be equal: - ->>> import pickle ->>> a = {1: 0, 9: 0} ->>> b = {9: 0, 1: 0} ->>> a == b -True ->>> pickle.dumps(a) == pickle.dumps(b) -False - -RDDs with comparable serializers can use native Java implementations of -operations like join() and distinct(), which may lead to better performance by -eliminating deserialization and Python comparisons. - -The default JSONSerializer produces comparable representations of common Python -data structures. - --------------------------------------- -Examples of serialized representations --------------------------------------- - -The RDD transformations that use Python UDFs are implemented in terms of -a modified `PipedRDD.pipe()` function. For each record `x` in the RDD, the -`pipe()` function pipes `x.toString()` to a Python worker process, which -deserializes the string into a Python object, executes user-defined functions, -and outputs serialized Python objects. - -The regular `toString()` method returns an ambiguous representation, due to the -way that Scala `Option` instances are printed: - ->>> from context import SparkContext ->>> sc = SparkContext("local", "SerializerDocs") ->>> x = sc.parallelizePairs([("a", 1), ("b", 4)]) ->>> y = sc.parallelizePairs([("a", 2)]) - ->>> print y.rightOuterJoin(x)._jrdd.first().toString() -(ImEi,(Some(Mg==),MQ==)) - -In Java, preprocessing is performed to handle Option instances, so the Python -process receives unambiguous input: - ->>> print sc.python_dump(y.rightOuterJoin(x)._jrdd.first()) -(ImEi,(Mg==,MQ==)) - -The base64-encoding eliminates the need to escape newlines, parentheses and -other special characters. - ----------------------- -Serializer composition ----------------------- - -In order to handle nested structures, which could contain object serialized -with different serializers, the RDD module composes serializers. For example, -the serializers in the previous example are: - ->>> print x.serializer.name -PairSerializer - ->>> print y.serializer.name -PairSerializer - ->>> print y.rightOuterJoin(x).serializer.name -PairSerializer, JSONSerializer>> +Python are stored in Java as RDD[Array[Byte]]. Python objects are +automatically serialized/deserialized, so this representation is transparent to +the end-user. """ -from base64 import standard_b64encode, standard_b64decode from collections import namedtuple import cPickle -import simplejson +import struct -Serializer = namedtuple("Serializer", - ["dumps","loads", "is_comparable", "name"]) - - -NopSerializer = Serializer(str, str, True, "NopSerializer") - - -JSONSerializer = Serializer( - lambda obj: standard_b64encode(simplejson.dumps(obj, sort_keys=True, - separators=(',', ':'))), - lambda s: simplejson.loads(standard_b64decode(s)), - True, - "JSONSerializer" -) +Serializer = namedtuple("Serializer", ["dumps","loads"]) PickleSerializer = Serializer( - lambda obj: standard_b64encode(cPickle.dumps(obj)), - lambda s: cPickle.loads(standard_b64decode(s)), - False, - "PickleSerializer" -) + lambda obj: cPickle.dumps(obj, -1), + cPickle.loads) -def OptionSerializer(serializer): - """ - >>> ser = OptionSerializer(NopSerializer) - >>> ser.loads(ser.dumps("Hello, World!")) - 'Hello, World!' 
- >>> ser.loads(ser.dumps(None)) is None - True - """ - none_placeholder = '*' - - def dumps(x): - if x is None: - return none_placeholder - else: - return serializer.dumps(x) - - def loads(x): - if x == none_placeholder: - return None - else: - return serializer.loads(x) - - name = "OptionSerializer<%s>" % serializer.name - return Serializer(dumps, loads, serializer.is_comparable, name) +def dumps(obj, stream): + # TODO: determining the length of non-byte objects. + stream.write(struct.pack("!i", len(obj))) + stream.write(obj) -def PairSerializer(keySerializer, valSerializer): - """ - Returns a Serializer for a (key, value) pair. - - >>> ser = PairSerializer(JSONSerializer, JSONSerializer) - >>> ser.loads(ser.dumps((1, 2))) - (1, 2) - - >>> ser = PairSerializer(JSONSerializer, ser) - >>> ser.loads(ser.dumps((1, (2, 3)))) - (1, (2, 3)) - """ - def loads(kv): - try: - (key, val) = kv[1:-1].split(',', 1) - key = keySerializer.loads(key) - val = valSerializer.loads(val) - return (key, val) - except: - print "Error in deserializing pair from '%s'" % str(kv) - raise - - def dumps(kv): - (key, val) = kv - return"(%s,%s)" % (keySerializer.dumps(key), valSerializer.dumps(val)) - is_comparable = \ - keySerializer.is_comparable and valSerializer.is_comparable - name = "PairSerializer<%s, %s>" % (keySerializer.name, valSerializer.name) - return Serializer(dumps, loads, is_comparable, name) - - -def ArraySerializer(serializer): - """ - >>> ser = ArraySerializer(JSONSerializer) - >>> ser.loads(ser.dumps([1, 2, 3, 4])) - [1, 2, 3, 4] - >>> ser = ArraySerializer(PairSerializer(JSONSerializer, PickleSerializer)) - >>> ser.loads(ser.dumps([('a', 1), ('b', 2)])) - [('a', 1), ('b', 2)] - >>> ser.loads(ser.dumps([('a', 1)])) - [('a', 1)] - >>> ser.loads(ser.dumps([])) - [] - """ - def dumps(arr): - if arr == []: - return '[]' - else: - return '[' + '|'.join(serializer.dumps(x) for x in arr) + ']' - - def loads(s): - if s == '[]': - return [] - items = s[1:-1] - if '|' in items: - items = items.split('|') - else: - items = [items] - return [serializer.loads(x) for x in items] - - name = "ArraySerializer<%s>" % serializer.name - return Serializer(dumps, loads, serializer.is_comparable, name) - - -# TODO: IntegerSerializer - - -# TODO: DoubleSerializer - - -def _test(): - import doctest - doctest.testmod() - - -if __name__ == "__main__": - _test() +def loads(stream): + length = stream.read(4) + if length == "": + raise EOFError + length = struct.unpack("!i", length)[0] + obj = stream.read(length) + if obj == "": + raise EOFError + return obj diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index 4c4b02fce4..21ff84fb17 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -6,9 +6,9 @@ from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. from pyspark.cloudpickle import CloudPickler +from pyspark.serializers import dumps, loads, PickleSerializer import cPickle - # Redirect stdout to stderr so that users must return values from functions. 
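The rewritten serializers.py frames every record as a 4-byte big-endian length (struct.pack("!i", ...)) followed by the raw pickled bytes, matching the DataOutputStream.writeInt framing on the Scala side; worker.py below reads that stream from stdin. A self-contained round trip of the framing, with BytesIO standing in for the worker pipe (illustration only):

    import struct
    from io import BytesIO   # stand-in for the worker's stdin/stdout pipe

    def write_framed(payload, stream):
        # 4-byte big-endian length prefix, then the pickled bytes
        stream.write(struct.pack("!i", len(payload)))
        stream.write(payload)

    def read_framed(stream):
        header = stream.read(4)
        if len(header) < 4:
            raise EOFError
        (length,) = struct.unpack("!i", header)
        return stream.read(length)

    pipe = BytesIO()
    write_framed(b"pickled-bytes-go-here", pipe)
    pipe.seek(0)
    assert read_framed(pipe) == b"pickled-bytes-go-here"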
old_stdout = sys.stdout sys.stdout = sys.stderr @@ -19,58 +19,64 @@ def load_function(): def output(x): - for line in x.split("\n"): - old_stdout.write(line.rstrip("\r\n") + "\n") + dumps(x, old_stdout) def read_input(): - for line in sys.stdin: - yield line.rstrip("\r\n") - + try: + while True: + yield loads(sys.stdin) + except EOFError: + return def do_combine_by_key(): create_combiner = load_function() merge_value = load_function() merge_combiners = load_function() # TODO: not used. - depickler = load_function() - key_pickler = load_function() - combiner_pickler = load_function() combiners = {} - for line in read_input(): - # Discard the hashcode added in the Python combineByKey() method. - (key, value) = depickler(line)[1] + for obj in read_input(): + (key, value) = PickleSerializer.loads(obj) if key not in combiners: combiners[key] = create_combiner(value) else: combiners[key] = merge_value(combiners[key], value) for (key, combiner) in combiners.iteritems(): - output(key_pickler(key)) - output(combiner_pickler(combiner)) + output(PickleSerializer.dumps((key, combiner))) -def do_map(map_pairs=False): +def do_map(flat=False): f = load_function() - for line in read_input(): + for obj in read_input(): try: - out = f(line) + #from pickletools import dis + #print repr(obj) + #print dis(obj) + out = f(PickleSerializer.loads(obj)) if out is not None: - if map_pairs: + if flat: for x in out: - output(x) + output(PickleSerializer.dumps(x)) else: - output(out) + output(PickleSerializer.dumps(out)) except: - sys.stderr.write("Error processing line '%s'\n" % line) + sys.stderr.write("Error processing obj %s\n" % repr(obj)) raise +def do_shuffle_map_step(): + for obj in read_input(): + key = PickleSerializer.loads(obj)[1] + output(str(hash(key))) + output(obj) + + def do_reduce(): f = load_function() - dumps = load_function() acc = None - for line in read_input(): - acc = f(line, acc) - output(dumps(acc)) + for obj in read_input(): + acc = f(PickleSerializer.loads(obj), acc) + if acc is not None: + output(PickleSerializer.dumps(acc)) def do_echo(): @@ -80,13 +86,15 @@ def do_echo(): def main(): command = sys.stdin.readline().strip() if command == "map": - do_map(map_pairs=False) - elif command == "mapPairs": - do_map(map_pairs=True) + do_map(flat=False) + elif command == "flatmap": + do_map(flat=True) elif command == "combine_by_key": do_combine_by_key() elif command == "reduce": do_reduce() + elif command == "shuffle_map_step": + do_shuffle_map_step() elif command == "echo": do_echo() else: From 607b53abfca049e7d9139e2d29893a3bb252de19 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 22 Aug 2012 00:43:55 -0700 Subject: [PATCH 004/291] Use numpy in Python k-means example. 
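For reference, the serializers.py rewrite above replaces the base64 text framing with raw cPickle plus a 4-byte big-endian length prefix per record. A minimal round trip of that framing (error handling elided), using StringIO as a stand-in for the worker's stdin/stdout pipe; the buffer is only for illustration, the real code reads and writes the subprocess streams:

    import struct
    from StringIO import StringIO

    def dumps(obj, stream):
        # Length-prefix the bytes with a 4-byte big-endian integer.
        stream.write(struct.pack("!i", len(obj)))
        stream.write(obj)

    def loads(stream):
        length = struct.unpack("!i", stream.read(4))[0]
        return stream.read(length)

    buf = StringIO()                     # stand-in for the JVM <-> worker pipe
    dumps("pickled-bytes-go-here", buf)
    buf.seek(0)
    print loads(buf)                     # pickled-bytes-go-here
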
--- .../scala/spark/api/python/PythonRDD.scala | 8 ++++++- pyspark/pyspark/examples/kmeans.py | 23 +++++++------------ pyspark/pyspark/rdd.py | 9 +++----- pyspark/pyspark/worker.py | 8 +++---- 4 files changed, 21 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index b9a0168d18..93847e2f14 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -101,7 +101,13 @@ trait PythonRDDBase { stream.readFully(obj) obj } catch { - case eof: EOFException => { new Array[Byte](0) } + case eof: EOFException => { + val exitStatus = proc.waitFor() + if (exitStatus != 0) { + throw new Exception("Subprocess exited with status " + exitStatus) + } + new Array[Byte](0) + } case e => throw e } } diff --git a/pyspark/pyspark/examples/kmeans.py b/pyspark/pyspark/examples/kmeans.py index 0761d6e395..9cc366f03c 100644 --- a/pyspark/pyspark/examples/kmeans.py +++ b/pyspark/pyspark/examples/kmeans.py @@ -1,25 +1,18 @@ import sys from pyspark.context import SparkContext +from numpy import array, sum as np_sum def parseVector(line): - return [float(x) for x in line.split(' ')] - - -def addVec(x, y): - return [a + b for (a, b) in zip(x, y)] - - -def squaredDist(x, y): - return sum((a - b) ** 2 for (a, b) in zip(x, y)) + return array([float(x) for x in line.split(' ')]) def closestPoint(p, centers): bestIndex = 0 closest = float("+inf") for i in range(len(centers)): - tempDist = squaredDist(p, centers[i]) + tempDist = np_sum((p - centers[i]) ** 2) if tempDist < closest: closest = tempDist bestIndex = i @@ -41,14 +34,14 @@ if __name__ == "__main__": tempDist = 1.0 while tempDist > convergeDist: - closest = data.mapPairs( + closest = data.map( lambda p : (closestPoint(p, kPoints), (p, 1))) pointStats = closest.reduceByKey( - lambda (x1, y1), (x2, y2): (addVec(x1, x2), y1 + y2)) - newPoints = pointStats.mapPairs( - lambda (x, (y, z)): (x, [a / z for a in y])).collect() + lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2)) + newPoints = pointStats.map( + lambda (x, (y, z)): (x, y / z)).collect() - tempDist = sum(squaredDist(kPoints[x], y) for (x, y) in newPoints) + tempDist = sum(np_sum((kPoints[x] - y) ** 2) for (x, y) in newPoints) for (x, y) in newPoints: kPoints[x] = y diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 8eccddc0a2..ff9c483032 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -71,7 +71,7 @@ class RDD(object): def takeSample(self, withReplacement, num, seed): vals = self._jrdd.takeSample(withReplacement, num, seed) - return [PickleSerializer.loads(x) for x in vals] + return [PickleSerializer.loads(bytes(x)) for x in vals] def union(self, other): """ @@ -218,17 +218,16 @@ class RDD(object): # TODO: pipelining # TODO: optimizations - def shuffle(self, numSplits): + def shuffle(self, numSplits, hashFunc=hash): if numSplits is None: numSplits = self.ctx.defaultParallelism - pipe_command = RDD._get_pipe_command('shuffle_map_step', []) + pipe_command = RDD._get_pipe_command('shuffle_map_step', [hashFunc]) class_manifest = self._jrdd.classManifest() python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), pipe_command, False, self.ctx.pythonExec, class_manifest) partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner) jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) - # TODO: extract second value. 
return RDD(jrdd, self.ctx) @@ -277,8 +276,6 @@ class RDD(object): map_values_fn = lambda (k, v): (k, f(v)) return self.map(map_values_fn, preservesPartitioning=True) - # TODO: implement shuffle. - # TODO: support varargs cogroup of several RDDs. def groupWith(self, other): return self.cogroup(other) diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index 21ff84fb17..b13ed5699a 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -48,9 +48,6 @@ def do_map(flat=False): f = load_function() for obj in read_input(): try: - #from pickletools import dis - #print repr(obj) - #print dis(obj) out = f(PickleSerializer.loads(obj)) if out is not None: if flat: @@ -64,9 +61,10 @@ def do_map(flat=False): def do_shuffle_map_step(): + hashFunc = load_function() for obj in read_input(): - key = PickleSerializer.loads(obj)[1] - output(str(hash(key))) + key = PickleSerializer.loads(obj)[0] + output(str(hashFunc(key))) output(obj) From 4b523004877cf94152225484de7683e9d17cdb56 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 23 Aug 2012 15:54:15 -0700 Subject: [PATCH 005/291] Fix options parsing in Python pi example. --- pyspark/pyspark/examples/pi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyspark/pyspark/examples/pi.py b/pyspark/pyspark/examples/pi.py index ad77694c41..fe63d2c952 100644 --- a/pyspark/pyspark/examples/pi.py +++ b/pyspark/pyspark/examples/pi.py @@ -10,7 +10,7 @@ if __name__ == "__main__": "Usage: PythonPi []" exit(-1) sc = SparkContext(sys.argv[1], "PythonKMeans") - slices = sys.argv[2] if len(sys.argv) > 2 else 2 + slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 n = 100000 * slices def f(_): x = random() * 2 - 1 From f3b852ce66d193e3421eeecef71ea27bff73a94b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 24 Aug 2012 19:38:50 -0700 Subject: [PATCH 006/291] Refactor Python MappedRDD to use iterator pipelines. 
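The refactoring in this patch turns each transformation into a function from an iterator to an iterator, so consecutive map/filter/flatMap stages compose into a single function that is shipped to the worker in one step. A rough standalone illustration of that composition, outside Spark (the stage helpers here are invented for the example):

    from itertools import ifilter, imap

    def map_stage(f):
        return lambda iterator: imap(f, iterator)

    def filter_stage(f):
        return lambda iterator: ifilter(f, iterator)

    def compose(prev_func, func):
        # Same shape as the patch's pipeline_func: func(prev_func(iterator)),
        # i.e. one pass over the data with no intermediate lists.
        return lambda iterator: func(prev_func(iterator))

    pipeline = compose(map_stage(lambda x: x * 2), filter_stage(lambda x: x > 4))
    print list(pipeline(iter([1, 2, 3, 4])))   # [6, 8]
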
--- pyspark/pyspark/rdd.py | 83 ++++++++++++--------------------------- pyspark/pyspark/worker.py | 55 +++++++------------------- 2 files changed, 41 insertions(+), 97 deletions(-) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index ff9c483032..7d280d8844 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -1,4 +1,5 @@ from base64 import standard_b64encode as b64enc +from itertools import chain, ifilter, imap from pyspark import cloudpickle from pyspark.serializers import PickleSerializer @@ -15,8 +16,6 @@ class RDD(object): @classmethod def _get_pipe_command(cls, command, functions): - if functions and not isinstance(functions, (list, tuple)): - functions = [functions] worker_args = [command] for f in functions: worker_args.append(b64enc(cloudpickle.dumps(f))) @@ -28,7 +27,8 @@ class RDD(object): return self def map(self, f, preservesPartitioning=False): - return MappedRDD(self, f, preservesPartitioning) + def func(iterator): return imap(f, iterator) + return PipelinedRDD(self, func, preservesPartitioning) def flatMap(self, f): """ @@ -38,7 +38,8 @@ class RDD(object): >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect()) [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ - return MappedRDD(self, f, preservesPartitioning=False, command='flatmap') + def func(iterator): return chain.from_iterable(imap(f, iterator)) + return PipelinedRDD(self, func) def filter(self, f): """ @@ -46,10 +47,10 @@ class RDD(object): >>> rdd.filter(lambda x: x % 2 == 0).collect() [2, 4] """ - def filter_func(x): return x if f(x) else None - return RDD(self._pipe(filter_func), self.ctx) + def func(iterator): return ifilter(f, iterator) + return PipelinedRDD(self, func) - def _pipe(self, functions, command="map"): + def _pipe(self, functions, command): class_manifest = self._jrdd.classManifest() pipe_command = RDD._get_pipe_command(command, functions) python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command, @@ -128,7 +129,16 @@ class RDD(object): >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add) 10 """ - vals = MappedRDD(self, f, command="reduce", preservesPartitioning=False).collect() + def func(iterator): + acc = None + for obj in iterator: + if acc is None: + acc = obj + else: + acc = f(obj, acc) + if acc is not None: + yield acc + vals = PipelinedRDD(self, func).collect() return reduce(f, vals) # TODO: fold @@ -230,8 +240,6 @@ class RDD(object): jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) return RDD(jrdd, self.ctx) - - def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numSplits=None): """ @@ -297,7 +305,7 @@ class RDD(object): # TODO: file saving -class MappedRDD(RDD): +class PipelinedRDD(RDD): """ Pipelined maps: >>> rdd = sc.parallelize([1, 2, 3, 4]) @@ -313,68 +321,29 @@ class MappedRDD(RDD): >>> rdd.flatMap(lambda x: [x, x]).reduce(add) 20 """ - def __init__(self, prev, func, preservesPartitioning=False, command='map'): - if isinstance(prev, MappedRDD) and not prev.is_cached: + def __init__(self, prev, func, preservesPartitioning=False): + if isinstance(prev, PipelinedRDD) and not prev.is_cached: prev_func = prev.func - if command == 'reduce': - if prev.command == 'flatmap': - def flatmap_reduce_func(x, acc): - values = prev_func(x) - if values is None: - return acc - if not acc: - if len(values) == 1: - return values[0] - else: - return reduce(func, values[1:], values[0]) - else: - return reduce(func, values, acc) - self.func = flatmap_reduce_func - else: - def reduce_func(x, acc): - val = prev_func(x) - if 
not val: - return acc - if acc is None: - return val - else: - return func(val, acc) - self.func = reduce_func - else: - if prev.command == 'flatmap': - command = 'flatmap' - self.func = lambda x: (func(y) for y in prev_func(x)) - else: - self.func = lambda x: func(prev_func(x)) - + def pipeline_func(iterator): + return func(prev_func(iterator)) + self.func = pipeline_func self.preservesPartitioning = \ prev.preservesPartitioning and preservesPartitioning self._prev_jrdd = prev._prev_jrdd - self.is_pipelined = True else: - if command == 'reduce': - def reduce_func(val, acc): - if acc is None: - return val - else: - return func(val, acc) - self.func = reduce_func - else: - self.func = func + self.func = func self.preservesPartitioning = preservesPartitioning self._prev_jrdd = prev._jrdd - self.is_pipelined = False self.is_cached = False self.ctx = prev.ctx self.prev = prev self._jrdd_val = None - self.command = command @property def _jrdd(self): if not self._jrdd_val: funcs = [self.func] - pipe_command = RDD._get_pipe_command(self.command, funcs) + pipe_command = RDD._get_pipe_command("pipeline", funcs) class_manifest = self._prev_jrdd.classManifest() python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), pipe_command, self.preservesPartitioning, self.ctx.pythonExec, diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index b13ed5699a..76b09918e7 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -25,17 +25,17 @@ def output(x): def read_input(): try: while True: - yield loads(sys.stdin) + yield cPickle.loads(loads(sys.stdin)) except EOFError: return + def do_combine_by_key(): create_combiner = load_function() merge_value = load_function() merge_combiners = load_function() # TODO: not used. combiners = {} - for obj in read_input(): - (key, value) = PickleSerializer.loads(obj) + for (key, value) in read_input(): if key not in combiners: combiners[key] = create_combiner(value) else: @@ -44,57 +44,32 @@ def do_combine_by_key(): output(PickleSerializer.dumps((key, combiner))) -def do_map(flat=False): +def do_pipeline(): f = load_function() - for obj in read_input(): - try: - out = f(PickleSerializer.loads(obj)) - if out is not None: - if flat: - for x in out: - output(PickleSerializer.dumps(x)) - else: - output(PickleSerializer.dumps(out)) - except: - sys.stderr.write("Error processing obj %s\n" % repr(obj)) - raise + for obj in f(read_input()): + output(PickleSerializer.dumps(obj)) def do_shuffle_map_step(): hashFunc = load_function() - for obj in read_input(): - key = PickleSerializer.loads(obj)[0] + while True: + try: + pickled = loads(sys.stdin) + except EOFError: + return + key = cPickle.loads(pickled)[0] output(str(hashFunc(key))) - output(obj) - - -def do_reduce(): - f = load_function() - acc = None - for obj in read_input(): - acc = f(PickleSerializer.loads(obj), acc) - if acc is not None: - output(PickleSerializer.dumps(acc)) - - -def do_echo(): - old_stdout.writelines(sys.stdin.readlines()) + output(pickled) def main(): command = sys.stdin.readline().strip() - if command == "map": - do_map(flat=False) - elif command == "flatmap": - do_map(flat=True) + if command == "pipeline": + do_pipeline() elif command == "combine_by_key": do_combine_by_key() - elif command == "reduce": - do_reduce() elif command == "shuffle_map_step": do_shuffle_map_step() - elif command == "echo": - do_echo() else: raise Exception("Unsupported command %s" % command) From 65e8406029a0fe1e1c5c5d033d335b43f6743a04 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 24 Aug 
2012 21:07:26 -0700 Subject: [PATCH 007/291] Implement fold() in Python API. --- pyspark/pyspark/rdd.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 7d280d8844..af7703fdfc 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -141,7 +141,25 @@ class RDD(object): vals = PipelinedRDD(self, func).collect() return reduce(f, vals) - # TODO: fold + def fold(self, zeroValue, op): + """ + Aggregate the elements of each partition, and then the results for all + the partitions, using a given associative function and a neutral "zero + value." The function op(t1, t2) is allowed to modify t1 and return it + as its result value to avoid object allocation; however, it should not + modify t2. + + >>> from operator import add + >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) + 15 + """ + def func(iterator): + acc = zeroValue + for obj in iterator: + acc = op(obj, acc) + yield acc + vals = PipelinedRDD(self, func).collect() + return reduce(op, vals, zeroValue) # TODO: aggregate From f79a1e4d2a8643157136de69b8d7de84f0034712 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 25 Aug 2012 13:59:01 -0700 Subject: [PATCH 008/291] Add broadcast variables to Python API. --- .../scala/spark/api/python/PythonRDD.scala | 43 ++++++++++------- pyspark/pyspark/broadcast.py | 46 +++++++++++++++++++ pyspark/pyspark/context.py | 17 +++++-- pyspark/pyspark/rdd.py | 27 +++++++---- pyspark/pyspark/worker.py | 6 +++ 5 files changed, 110 insertions(+), 29 deletions(-) create mode 100644 pyspark/pyspark/broadcast.py diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 93847e2f14..5163812df4 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -7,14 +7,13 @@ import scala.collection.JavaConversions._ import scala.io.Source import spark._ import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} -import scala.{collection, Some} -import collection.parallel.mutable +import broadcast.Broadcast import scala.collection -import scala.Some trait PythonRDDBase { def compute[T](split: Split, envVars: Map[String, String], - command: Seq[String], parent: RDD[T], pythonExec: String): Iterator[Array[Byte]] = { + command: Seq[String], parent: RDD[T], pythonExec: String, + broadcastVars: java.util.List[Broadcast[Array[Byte]]]): Iterator[Array[Byte]] = { val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME") val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py")) @@ -42,11 +41,18 @@ trait PythonRDDBase { override def run() { SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) + val dOut = new DataOutputStream(proc.getOutputStream) + out.println(broadcastVars.length) + for (broadcast <- broadcastVars) { + out.print(broadcast.uuid.toString) + dOut.writeInt(broadcast.value.length) + dOut.write(broadcast.value) + dOut.flush() + } for (elem <- command) { out.println(elem) } out.flush() - val dOut = new DataOutputStream(proc.getOutputStream) for (elem <- parent.iterator(split)) { if (elem.isInstanceOf[Array[Byte]]) { val arr = elem.asInstanceOf[Array[Byte]] @@ -121,16 +127,17 @@ trait PythonRDDBase { class PythonRDD[T: ClassManifest]( parent: RDD[T], command: Seq[String], envVars: Map[String, String], - preservePartitoning: Boolean, pythonExec: String) + preservePartitoning: Boolean, pythonExec: String, broadcastVars: 
java.util.List[Broadcast[Array[Byte]]]) extends RDD[Array[Byte]](parent.context) with PythonRDDBase { - def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) = - this(parent, command, Map(), preservePartitoning, pythonExec) + def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, + pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = + this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars) // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String) = - this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec) + def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = + this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars) override def splits = parent.splits @@ -139,23 +146,25 @@ class PythonRDD[T: ClassManifest]( override val partitioner = if (preservePartitoning) parent.partitioner else None override def compute(split: Split): Iterator[Array[Byte]] = - compute(split, envVars, command, parent, pythonExec) + compute(split, envVars, command, parent, pythonExec, broadcastVars) val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } class PythonPairRDD[T: ClassManifest] ( parent: RDD[T], command: Seq[String], envVars: Map[String, String], - preservePartitoning: Boolean, pythonExec: String) + preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) extends RDD[(Array[Byte], Array[Byte])](parent.context) with PythonRDDBase { - def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) = - this(parent, command, Map(), preservePartitoning, pythonExec) + def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, + pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = + this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars) // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. 
by spaces) - def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String) = - this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec) + def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String, + broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = + this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars) override def splits = parent.splits @@ -164,7 +173,7 @@ class PythonPairRDD[T: ClassManifest] ( override val partitioner = if (preservePartitoning) parent.partitioner else None override def compute(split: Split): Iterator[(Array[Byte], Array[Byte])] = { - compute(split, envVars, command, parent, pythonExec).grouped(2).map { + compute(split, envVars, command, parent, pythonExec, broadcastVars).grouped(2).map { case Seq(a, b) => (a, b) case x => throw new Exception("PythonPairRDD: unexpected value: " + x) } diff --git a/pyspark/pyspark/broadcast.py b/pyspark/pyspark/broadcast.py new file mode 100644 index 0000000000..1ea17d59af --- /dev/null +++ b/pyspark/pyspark/broadcast.py @@ -0,0 +1,46 @@ +""" +>>> from pyspark.context import SparkContext +>>> sc = SparkContext('local', 'test') +>>> b = sc.broadcast([1, 2, 3, 4, 5]) +>>> b.value +[1, 2, 3, 4, 5] + +>>> from pyspark.broadcast import _broadcastRegistry +>>> _broadcastRegistry[b.uuid] = b +>>> from cPickle import dumps, loads +>>> loads(dumps(b)).value +[1, 2, 3, 4, 5] + +>>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() +[1, 2, 3, 4, 5, 1, 2, 3, 4, 5] +""" +# Holds broadcasted data received from Java, keyed by UUID. +_broadcastRegistry = {} + + +def _from_uuid(uuid): + from pyspark.broadcast import _broadcastRegistry + if uuid not in _broadcastRegistry: + raise Exception("Broadcast variable '%s' not loaded!" % uuid) + return _broadcastRegistry[uuid] + + +class Broadcast(object): + def __init__(self, uuid, value, java_broadcast=None, pickle_registry=None): + self.value = value + self.uuid = uuid + self._jbroadcast = java_broadcast + self._pickle_registry = pickle_registry + + def __reduce__(self): + self._pickle_registry.add(self) + return (_from_uuid, (self.uuid, )) + + +def _test(): + import doctest + doctest.testmod() + + +if __name__ == "__main__": + _test() diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index ac7e4057e9..6f87206665 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -2,6 +2,7 @@ import os import atexit from tempfile import NamedTemporaryFile +from pyspark.broadcast import Broadcast from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, dumps from pyspark.rdd import RDD @@ -24,6 +25,11 @@ class SparkContext(object): self.defaultParallelism = \ defaultParallelism or self._jsc.sc().defaultParallelism() self.pythonExec = pythonExec + # Broadcast's __reduce__ method stores Broadcast instances here. + # This allows other code to determine which Broadcast instances have + # been pickled, so it can determine which Java broadcast objects to + # send. 
+ self._pickled_broadcast_vars = set() def __del__(self): if self._jsc: @@ -52,7 +58,12 @@ class SparkContext(object): jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices) return RDD(jrdd, self) - def textFile(self, name, numSlices=None): - numSlices = numSlices or self.defaultParallelism - jrdd = self._jsc.textFile(name, numSlices) + def textFile(self, name, minSplits=None): + minSplits = minSplits or min(self.defaultParallelism, 2) + jrdd = self._jsc.textFile(name, minSplits) return RDD(jrdd, self) + + def broadcast(self, value): + jbroadcast = self._jsc.broadcast(bytearray(PickleSerializer.dumps(value))) + return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast, + self._pickled_broadcast_vars) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index af7703fdfc..4459095391 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -6,6 +6,8 @@ from pyspark.serializers import PickleSerializer from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup +from py4j.java_collections import ListConverter + class RDD(object): @@ -15,11 +17,15 @@ class RDD(object): self.ctx = ctx @classmethod - def _get_pipe_command(cls, command, functions): + def _get_pipe_command(cls, ctx, command, functions): worker_args = [command] for f in functions: worker_args.append(b64enc(cloudpickle.dumps(f))) - return " ".join(worker_args) + broadcast_vars = [x._jbroadcast for x in ctx._pickled_broadcast_vars] + broadcast_vars = ListConverter().convert(broadcast_vars, + ctx.gateway._gateway_client) + ctx._pickled_broadcast_vars.clear() + return (" ".join(worker_args), broadcast_vars) def cache(self): self.is_cached = True @@ -52,9 +58,10 @@ class RDD(object): def _pipe(self, functions, command): class_manifest = self._jrdd.classManifest() - pipe_command = RDD._get_pipe_command(command, functions) + (pipe_command, broadcast_vars) = \ + RDD._get_pipe_command(self.ctx, command, functions) python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command, - False, self.ctx.pythonExec, class_manifest) + False, self.ctx.pythonExec, broadcast_vars, class_manifest) return python_rdd.asJavaRDD() def distinct(self): @@ -249,10 +256,12 @@ class RDD(object): def shuffle(self, numSplits, hashFunc=hash): if numSplits is None: numSplits = self.ctx.defaultParallelism - pipe_command = RDD._get_pipe_command('shuffle_map_step', [hashFunc]) + (pipe_command, broadcast_vars) = \ + RDD._get_pipe_command(self.ctx, 'shuffle_map_step', [hashFunc]) class_manifest = self._jrdd.classManifest() python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), - pipe_command, False, self.ctx.pythonExec, class_manifest) + pipe_command, False, self.ctx.pythonExec, broadcast_vars, + class_manifest) partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner) jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) @@ -360,12 +369,12 @@ class PipelinedRDD(RDD): @property def _jrdd(self): if not self._jrdd_val: - funcs = [self.func] - pipe_command = RDD._get_pipe_command("pipeline", funcs) + (pipe_command, broadcast_vars) = \ + RDD._get_pipe_command(self.ctx, "pipeline", [self.func]) class_manifest = self._prev_jrdd.classManifest() python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), pipe_command, self.preservesPartitioning, self.ctx.pythonExec, - class_manifest) + broadcast_vars, class_manifest) self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index 
76b09918e7..7402897ac8 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -5,6 +5,7 @@ import sys from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. +from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.serializers import dumps, loads, PickleSerializer import cPickle @@ -63,6 +64,11 @@ def do_shuffle_map_step(): def main(): + num_broadcast_variables = int(sys.stdin.readline().strip()) + for _ in range(num_broadcast_variables): + uuid = sys.stdin.read(36) + value = loads(sys.stdin) + _broadcastRegistry[uuid] = Broadcast(uuid, cPickle.loads(value)) command = sys.stdin.readline().strip() if command == "pipeline": do_pipeline() From 08b201d810c0dc0933d00d78ec2c1d9135e100c3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 24 Aug 2012 22:51:45 -0700 Subject: [PATCH 009/291] Add mapPartitions(), glom(), countByValue() to Python API. --- pyspark/pyspark/rdd.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 4459095391..f0d665236a 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -1,4 +1,5 @@ from base64 import standard_b64encode as b64enc +from collections import Counter from itertools import chain, ifilter, imap from pyspark import cloudpickle @@ -47,6 +48,15 @@ class RDD(object): def func(iterator): return chain.from_iterable(imap(f, iterator)) return PipelinedRDD(self, func) + def mapPartitions(self, f): + """ + >>> rdd = sc.parallelize([1, 2, 3, 4], 2) + >>> def f(iterator): yield sum(iterator) + >>> rdd.mapPartitions(f).collect() + [3, 7] + """ + return PipelinedRDD(self, f) + def filter(self, f): """ >>> rdd = sc.parallelize([1, 2, 3, 4, 5]) @@ -93,7 +103,14 @@ class RDD(object): # TODO: Overload __add___? - # TODO: glom + def glom(self): + """ + >>> rdd = sc.parallelize([1, 2, 3, 4], 2) + >>> rdd.glom().first() + [1, 2] + """ + def func(iterator): yield list(iterator) + return PipelinedRDD(self, func) def cartesian(self, other): """ @@ -115,8 +132,6 @@ class RDD(object): # TODO: pipe - # TODO: mapPartitions - def foreach(self, f): """ >>> def f(x): print x @@ -177,7 +192,16 @@ class RDD(object): """ return self._jrdd.count() - # TODO: count approx methods + def countByValue(self): + """ + >>> sc.parallelize([1, 2, 1, 2, 2]).countByValue().most_common() + [(2, 3), (1, 2)] + """ + def countPartition(iterator): + yield Counter(iterator) + def mergeMaps(m1, m2): + return m1 + m2 + return self.mapPartitions(countPartition).reduce(mergeMaps) def take(self, num): """ From 8b64b7ecd80c52f2f09a517f1517c0ece7a3d57f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 24 Aug 2012 23:09:15 -0700 Subject: [PATCH 010/291] Add countByKey(), reduceByKeyLocally() to Python API --- pyspark/pyspark/rdd.py | 52 +++++++++++++++++++++++++++++++----------- 1 file changed, 39 insertions(+), 13 deletions(-) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index f0d665236a..fd41ea0b17 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -99,9 +99,17 @@ class RDD(object): """ return RDD(self._jrdd.union(other._jrdd), self.ctx) - # TODO: sort + def __add__(self, other): + """ + >>> rdd = sc.parallelize([1, 1, 2, 3]) + >>> (rdd + rdd).collect() + [1, 1, 2, 3, 1, 1, 2, 3] + """ + if not isinstance(other, RDD): + raise TypeError + return self.union(other) - # TODO: Overload __add___? 
+ # TODO: sort def glom(self): """ @@ -120,7 +128,6 @@ class RDD(object): """ return RDD(self._jrdd.cartesian(other._jrdd), self.ctx) - # numsplits def groupBy(self, f, numSplits=None): """ >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8]) @@ -236,17 +243,38 @@ class RDD(object): def reduceByKey(self, func, numSplits=None): """ - >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> sorted(x.reduceByKey(lambda a, b: a + b).collect()) + >>> from operator import add + >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(rdd.reduceByKey(add).collect()) [('a', 2), ('b', 1)] """ return self.combineByKey(lambda x: x, func, func, numSplits) - # TODO: reduceByKeyLocally() + def reduceByKeyLocally(self, func): + """ + >>> from operator import add + >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(rdd.reduceByKeyLocally(add).items()) + [('a', 2), ('b', 1)] + """ + def reducePartition(iterator): + m = {} + for (k, v) in iterator: + m[k] = v if k not in m else func(m[k], v) + yield m + def mergeMaps(m1, m2): + for (k, v) in m2.iteritems(): + m1[k] = v if k not in m1 else func(m1[k], v) + return m1 + return self.mapPartitions(reducePartition).reduce(mergeMaps) - # TODO: countByKey() - - # TODO: partitionBy + def countByKey(self): + """ + >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> rdd.countByKey().most_common() + [('a', 2), ('b', 1)] + """ + return self.map(lambda x: x[0]).countByValue() def join(self, other, numSplits=None): """ @@ -277,7 +305,7 @@ class RDD(object): # TODO: pipelining # TODO: optimizations - def shuffle(self, numSplits, hashFunc=hash): + def partitionBy(self, numSplits, hashFunc=hash): if numSplits is None: numSplits = self.ctx.defaultParallelism (pipe_command, broadcast_vars) = \ @@ -302,7 +330,7 @@ class RDD(object): """ if numSplits is None: numSplits = self.ctx.defaultParallelism - shuffled = self.shuffle(numSplits) + shuffled = self.partitionBy(numSplits) functions = [createCombiner, mergeValue, mergeCombiners] jpairs = shuffled._pipe(functions, "combine_by_key") return RDD(jpairs, self.ctx) @@ -353,8 +381,6 @@ class RDD(object): # keys in the pairs. This could be an expensive operation, since those # hashes aren't retained. - # TODO: file saving - class PipelinedRDD(RDD): """ From 6904cb77d4306a14891cc71338c8f9f966d009f1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 25 Aug 2012 14:19:07 -0700 Subject: [PATCH 011/291] Use local combiners in Python API combineByKey(). 
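The change in this patch pre-aggregates within each partition before the shuffle, so only one (key, combiner) pair per key per partition is serialized and transferred. A small plain-Python sketch of the local-combine step under that assumption (a list stands in for a partition's iterator):

    def combine_locally(iterator, create_combiner, merge_value):
        combiners = {}
        for (k, v) in iterator:
            if k not in combiners:
                combiners[k] = create_combiner(v)
            else:
                combiners[k] = merge_value(combiners[k], v)
        return combiners.iteritems()

    # Word-count style usage on a single "partition".
    pairs = [("a", 1), ("b", 1), ("a", 1)]
    print sorted(combine_locally(iter(pairs), lambda v: v, lambda c, v: c + v))
    # [('a', 2), ('b', 1)]
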
--- pyspark/pyspark/rdd.py | 33 ++++++++++++++++++++++++--------- pyspark/pyspark/worker.py | 16 ---------------- 2 files changed, 24 insertions(+), 25 deletions(-) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index fd41ea0b17..3528b8f308 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -46,7 +46,7 @@ class RDD(object): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(iterator): return chain.from_iterable(imap(f, iterator)) - return PipelinedRDD(self, func) + return self.mapPartitions(func) def mapPartitions(self, f): """ @@ -64,7 +64,7 @@ class RDD(object): [2, 4] """ def func(iterator): return ifilter(f, iterator) - return PipelinedRDD(self, func) + return self.mapPartitions(func) def _pipe(self, functions, command): class_manifest = self._jrdd.classManifest() @@ -118,7 +118,7 @@ class RDD(object): [1, 2] """ def func(iterator): yield list(iterator) - return PipelinedRDD(self, func) + return self.mapPartitions(func) def cartesian(self, other): """ @@ -167,7 +167,7 @@ class RDD(object): acc = f(obj, acc) if acc is not None: yield acc - vals = PipelinedRDD(self, func).collect() + vals = self.mapPartitions(func).collect() return reduce(f, vals) def fold(self, zeroValue, op): @@ -187,7 +187,7 @@ class RDD(object): for obj in iterator: acc = op(obj, acc) yield acc - vals = PipelinedRDD(self, func).collect() + vals = self.mapPartitions(func).collect() return reduce(op, vals, zeroValue) # TODO: aggregate @@ -330,10 +330,25 @@ class RDD(object): """ if numSplits is None: numSplits = self.ctx.defaultParallelism - shuffled = self.partitionBy(numSplits) - functions = [createCombiner, mergeValue, mergeCombiners] - jpairs = shuffled._pipe(functions, "combine_by_key") - return RDD(jpairs, self.ctx) + def combineLocally(iterator): + combiners = {} + for (k, v) in iterator: + if k not in combiners: + combiners[k] = createCombiner(v) + else: + combiners[k] = mergeValue(combiners[k], v) + return combiners.iteritems() + locally_combined = self.mapPartitions(combineLocally) + shuffled = locally_combined.partitionBy(numSplits) + def _mergeCombiners(iterator): + combiners = {} + for (k, v) in iterator: + if not k in combiners: + combiners[k] = v + else: + combiners[k] = mergeCombiners(combiners[k], v) + return combiners.iteritems() + return shuffled.mapPartitions(_mergeCombiners) def groupByKey(self, numSplits=None): """ diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index 7402897ac8..0f90c6ff46 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -31,20 +31,6 @@ def read_input(): return -def do_combine_by_key(): - create_combiner = load_function() - merge_value = load_function() - merge_combiners = load_function() # TODO: not used. - combiners = {} - for (key, value) in read_input(): - if key not in combiners: - combiners[key] = create_combiner(value) - else: - combiners[key] = merge_value(combiners[key], value) - for (key, combiner) in combiners.iteritems(): - output(PickleSerializer.dumps((key, combiner))) - - def do_pipeline(): f = load_function() for obj in f(read_input()): @@ -72,8 +58,6 @@ def main(): command = sys.stdin.readline().strip() if command == "pipeline": do_pipeline() - elif command == "combine_by_key": - do_combine_by_key() elif command == "shuffle_map_step": do_shuffle_map_step() else: From 200d248dcc5903295296bf897211cf543b37f8c1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 25 Aug 2012 16:46:07 -0700 Subject: [PATCH 012/291] Simplify Python worker; pipeline the map step of partitionBy(). 
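The partitionBy() pipeline introduced here has the Python worker emit two records per key/value pair (the stringified hash, then the pickled pair), and the JVM side re-groups the stream two at a time, via grouped(2), into a pair RDD it can hash-partition. An illustrative simulation of that alternating protocol in plain Python (the function names are invented for the example):

    import cPickle

    def add_shuffle_key(iterator, hash_func=hash):
        # Two outputs per record: the partitioning key, then the payload.
        for (k, v) in iterator:
            yield str(hash_func(k))
            yield cPickle.dumps((k, v), 2)

    def group_pairs(stream):
        # Mirrors the JVM side's grouped(2): consecutive items become tuples.
        it = iter(stream)
        return zip(it, it)

    for hashed, pickled in group_pairs(add_shuffle_key(iter([("a", 1), ("b", 2)]))):
        print hashed, cPickle.loads(pickled)
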
--- .../scala/spark/api/python/PythonRDD.scala | 34 ++------- pyspark/pyspark/context.py | 9 ++- pyspark/pyspark/rdd.py | 70 +++++++------------ pyspark/pyspark/serializers.py | 23 ++---- pyspark/pyspark/worker.py | 50 ++++--------- 5 files changed, 59 insertions(+), 127 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 5163812df4..b9091fd436 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -151,38 +151,18 @@ class PythonRDD[T: ClassManifest]( val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } -class PythonPairRDD[T: ClassManifest] ( - parent: RDD[T], command: Seq[String], envVars: Map[String, String], - preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) - extends RDD[(Array[Byte], Array[Byte])](parent.context) with PythonRDDBase { - - def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, - pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = - this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars) - - // Similar to Runtime.exec(), if we are given a single string, split it into words - // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String, - broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = - this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars) - - override def splits = parent.splits - - override val dependencies = List(new OneToOneDependency(parent)) - - override val partitioner = if (preservePartitoning) parent.partitioner else None - - override def compute(split: Split): Iterator[(Array[Byte], Array[Byte])] = { - compute(split, envVars, command, parent, pythonExec, broadcastVars).grouped(2).map { +private class PairwiseRDD(prev: RDD[Array[Byte]]) extends + RDD[(Array[Byte], Array[Byte])](prev.context) { + override def splits = prev.splits + override val dependencies = List(new OneToOneDependency(prev)) + override def compute(split: Split) = + prev.iterator(split).grouped(2).map { case Seq(a, b) => (a, b) - case x => throw new Exception("PythonPairRDD: unexpected value: " + x) + case x => throw new Exception("PairwiseRDD: unexpected value: " + x) } - } - val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) } - object PythonRDD { /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */ diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index 6f87206665..b8490019e3 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -4,7 +4,7 @@ from tempfile import NamedTemporaryFile from pyspark.broadcast import Broadcast from pyspark.java_gateway import launch_gateway -from pyspark.serializers import PickleSerializer, dumps +from pyspark.serializers import dump_pickle, write_with_length from pyspark.rdd import RDD @@ -16,9 +16,8 @@ class SparkContext(object): asPickle = jvm.spark.api.python.PythonRDD.asPickle arrayAsPickle = jvm.spark.api.python.PythonRDD.arrayAsPickle - def __init__(self, master, name, defaultParallelism=None, - pythonExec='python'): + pythonExec='python'): self.master = master self.name = name self._jsc = self.jvm.JavaSparkContext(master, name) @@ -52,7 +51,7 @@ class SparkContext(object): # objects are written to a file and loaded through textFile(). 
tempFile = NamedTemporaryFile(delete=False) for x in c: - dumps(PickleSerializer.dumps(x), tempFile) + write_with_length(dump_pickle(x), tempFile) tempFile.close() atexit.register(lambda: os.unlink(tempFile.name)) jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices) @@ -64,6 +63,6 @@ class SparkContext(object): return RDD(jrdd, self) def broadcast(self, value): - jbroadcast = self._jsc.broadcast(bytearray(PickleSerializer.dumps(value))) + jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value))) return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast, self._pickled_broadcast_vars) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 3528b8f308..21e822ba9f 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -3,7 +3,7 @@ from collections import Counter from itertools import chain, ifilter, imap from pyspark import cloudpickle -from pyspark.serializers import PickleSerializer +from pyspark.serializers import dump_pickle, load_pickle from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup @@ -17,17 +17,6 @@ class RDD(object): self.is_cached = False self.ctx = ctx - @classmethod - def _get_pipe_command(cls, ctx, command, functions): - worker_args = [command] - for f in functions: - worker_args.append(b64enc(cloudpickle.dumps(f))) - broadcast_vars = [x._jbroadcast for x in ctx._pickled_broadcast_vars] - broadcast_vars = ListConverter().convert(broadcast_vars, - ctx.gateway._gateway_client) - ctx._pickled_broadcast_vars.clear() - return (" ".join(worker_args), broadcast_vars) - def cache(self): self.is_cached = True self._jrdd.cache() @@ -66,14 +55,6 @@ class RDD(object): def func(iterator): return ifilter(f, iterator) return self.mapPartitions(func) - def _pipe(self, functions, command): - class_manifest = self._jrdd.classManifest() - (pipe_command, broadcast_vars) = \ - RDD._get_pipe_command(self.ctx, command, functions) - python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command, - False, self.ctx.pythonExec, broadcast_vars, class_manifest) - return python_rdd.asJavaRDD() - def distinct(self): """ >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect()) @@ -89,7 +70,7 @@ class RDD(object): def takeSample(self, withReplacement, num, seed): vals = self._jrdd.takeSample(withReplacement, num, seed) - return [PickleSerializer.loads(bytes(x)) for x in vals] + return [load_pickle(bytes(x)) for x in vals] def union(self, other): """ @@ -148,7 +129,7 @@ class RDD(object): def collect(self): pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().collect()) - return PickleSerializer.loads(bytes(pickle)) + return load_pickle(bytes(pickle)) def reduce(self, f): """ @@ -216,19 +197,17 @@ class RDD(object): [2, 3] """ pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().take(num)) - return PickleSerializer.loads(bytes(pickle)) + return load_pickle(bytes(pickle)) def first(self): """ >>> sc.parallelize([2, 3, 4]).first() 2 """ - return PickleSerializer.loads(bytes(self.ctx.asPickle(self._jrdd.first()))) + return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first()))) # TODO: saveAsTextFile - # TODO: saveAsObjectFile - # Pair functions def collectAsMap(self): @@ -303,19 +282,18 @@ class RDD(object): """ return python_right_outer_join(self, other, numSplits) - # TODO: pipelining - # TODO: optimizations def partitionBy(self, numSplits, hashFunc=hash): if numSplits is None: numSplits = self.ctx.defaultParallelism - (pipe_command, broadcast_vars) = \ - RDD._get_pipe_command(self.ctx, 'shuffle_map_step', 
[hashFunc]) - class_manifest = self._jrdd.classManifest() - python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), - pipe_command, False, self.ctx.pythonExec, broadcast_vars, - class_manifest) + def add_shuffle_key(iterator): + for (k, v) in iterator: + yield str(hashFunc(k)) + yield dump_pickle((k, v)) + keyed = PipelinedRDD(self, add_shuffle_key) + keyed._bypass_serializer = True + pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) - jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner) + jrdd = pairRDD.partitionBy(partitioner) jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) return RDD(jrdd, self.ctx) @@ -430,17 +408,23 @@ class PipelinedRDD(RDD): self.ctx = prev.ctx self.prev = prev self._jrdd_val = None + self._bypass_serializer = False @property def _jrdd(self): - if not self._jrdd_val: - (pipe_command, broadcast_vars) = \ - RDD._get_pipe_command(self.ctx, "pipeline", [self.func]) - class_manifest = self._prev_jrdd.classManifest() - python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), - pipe_command, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, class_manifest) - self._jrdd_val = python_rdd.asJavaRDD() + if self._jrdd_val: + return self._jrdd_val + funcs = [self.func, self._bypass_serializer] + pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in funcs) + broadcast_vars = ListConverter().convert( + [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], + self.ctx.gateway._gateway_client) + self.ctx._pickled_broadcast_vars.clear() + class_manifest = self._prev_jrdd.classManifest() + python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), + pipe_command, self.preservesPartitioning, self.ctx.pythonExec, + broadcast_vars, class_manifest) + self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py index 7b3e6966e1..faa1e683c7 100644 --- a/pyspark/pyspark/serializers.py +++ b/pyspark/pyspark/serializers.py @@ -1,31 +1,20 @@ -""" -Data serialization methods. - -The Spark Python API is built on top of the Spark Java API. RDDs created in -Python are stored in Java as RDD[Array[Byte]]. Python objects are -automatically serialized/deserialized, so this representation is transparent to -the end-user. -""" -from collections import namedtuple -import cPickle import struct +import cPickle -Serializer = namedtuple("Serializer", ["dumps","loads"]) +def dump_pickle(obj): + return cPickle.dumps(obj, 2) -PickleSerializer = Serializer( - lambda obj: cPickle.dumps(obj, -1), - cPickle.loads) +load_pickle = cPickle.loads -def dumps(obj, stream): - # TODO: determining the length of non-byte objects. +def write_with_length(obj, stream): stream.write(struct.pack("!i", len(obj))) stream.write(obj) -def loads(stream): +def read_with_length(stream): length = stream.read(4) if length == "": raise EOFError diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index 0f90c6ff46..a9ed71892f 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -7,61 +7,41 @@ from base64 import standard_b64decode # copy_reg module. from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler -from pyspark.serializers import dumps, loads, PickleSerializer -import cPickle +from pyspark.serializers import write_with_length, read_with_length, \ + dump_pickle, load_pickle + # Redirect stdout to stderr so that users must return values from functions. 
old_stdout = sys.stdout sys.stdout = sys.stderr -def load_function(): - return cPickle.loads(standard_b64decode(sys.stdin.readline().strip())) - - -def output(x): - dumps(x, old_stdout) +def load_obj(): + return load_pickle(standard_b64decode(sys.stdin.readline().strip())) def read_input(): try: while True: - yield cPickle.loads(loads(sys.stdin)) + yield load_pickle(read_with_length(sys.stdin)) except EOFError: return -def do_pipeline(): - f = load_function() - for obj in f(read_input()): - output(PickleSerializer.dumps(obj)) - - -def do_shuffle_map_step(): - hashFunc = load_function() - while True: - try: - pickled = loads(sys.stdin) - except EOFError: - return - key = cPickle.loads(pickled)[0] - output(str(hashFunc(key))) - output(pickled) - - def main(): num_broadcast_variables = int(sys.stdin.readline().strip()) for _ in range(num_broadcast_variables): uuid = sys.stdin.read(36) - value = loads(sys.stdin) - _broadcastRegistry[uuid] = Broadcast(uuid, cPickle.loads(value)) - command = sys.stdin.readline().strip() - if command == "pipeline": - do_pipeline() - elif command == "shuffle_map_step": - do_shuffle_map_step() + value = read_with_length(sys.stdin) + _broadcastRegistry[uuid] = Broadcast(uuid, load_pickle(value)) + func = load_obj() + bypassSerializer = load_obj() + if bypassSerializer: + dumps = lambda x: x else: - raise Exception("Unsupported command %s" % command) + dumps = dump_pickle + for obj in func(read_input()): + write_with_length(dumps(obj), old_stdout) if __name__ == '__main__': From bff6a46359131a8f9bc38b93149b22baa7c711cd Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 25 Aug 2012 18:00:25 -0700 Subject: [PATCH 013/291] Add pipe(), saveAsTextFile(), sc.union() to Python API. --- .../scala/spark/api/python/PythonRDD.scala | 10 +++++--- pyspark/pyspark/context.py | 14 ++++++----- pyspark/pyspark/rdd.py | 25 +++++++++++++++++-- 3 files changed, 38 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index b9091fd436..4d3bdb3963 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -9,6 +9,7 @@ import spark._ import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import broadcast.Broadcast import scala.collection +import java.nio.charset.Charset trait PythonRDDBase { def compute[T](split: Split, envVars: Map[String, String], @@ -238,9 +239,12 @@ private object Pickle { val MARK : Byte = '(' val APPENDS : Byte = 'e' } -class ExtractValue extends spark.api.java.function.Function[(Array[Byte], + +private class ExtractValue extends spark.api.java.function.Function[(Array[Byte], Array[Byte]), Array[Byte]] { - override def call(pair: (Array[Byte], Array[Byte])) : Array[Byte] = pair._2 - +} + +private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] { + override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") } diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index b8490019e3..04932c93f2 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -7,6 +7,8 @@ from pyspark.java_gateway import launch_gateway from pyspark.serializers import dump_pickle, write_with_length from pyspark.rdd import RDD +from py4j.java_collections import ListConverter + class SparkContext(object): @@ -39,12 +41,6 @@ class SparkContext(object): self._jsc = None def parallelize(self, c, numSlices=None): - """ - >>> sc = SparkContext("local", 
"test") - >>> rdd = sc.parallelize([(1, 2), (3, 4)]) - >>> rdd.collect() - [(1, 2), (3, 4)] - """ numSlices = numSlices or self.defaultParallelism # Calling the Java parallelize() method with an ArrayList is too slow, # because it sends O(n) Py4J commands. As an alternative, serialized @@ -62,6 +58,12 @@ class SparkContext(object): jrdd = self._jsc.textFile(name, minSplits) return RDD(jrdd, self) + def union(self, rdds): + first = rdds[0]._jrdd + rest = [x._jrdd for x in rdds[1:]] + rest = ListConverter().convert(rest, self.gateway._gateway_client) + return RDD(self._jsc.union(first, rest), self) + def broadcast(self, value): jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value))) return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast, diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 21e822ba9f..8477f6dd02 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -1,6 +1,9 @@ from base64 import standard_b64encode as b64enc from collections import Counter from itertools import chain, ifilter, imap +import shlex +from subprocess import Popen, PIPE +from threading import Thread from pyspark import cloudpickle from pyspark.serializers import dump_pickle, load_pickle @@ -118,7 +121,20 @@ class RDD(object): """ return self.map(lambda x: (f(x), x)).groupByKey(numSplits) - # TODO: pipe + def pipe(self, command, env={}): + """ + >>> sc.parallelize([1, 2, 3]).pipe('cat').collect() + ['1', '2', '3'] + """ + def func(iterator): + pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE) + def pipe_objs(out): + for obj in iterator: + out.write(str(obj).rstrip('\n') + '\n') + out.close() + Thread(target=pipe_objs, args=[pipe.stdin]).start() + return (x.rstrip('\n') for x in pipe.stdout) + return self.mapPartitions(func) def foreach(self, f): """ @@ -206,7 +222,12 @@ class RDD(object): """ return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first()))) - # TODO: saveAsTextFile + def saveAsTextFile(self, path): + def func(iterator): + return (str(x).encode("utf-8") for x in iterator) + keyed = PipelinedRDD(self, func) + keyed._bypass_serializer = True + keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path) # Pair functions From 414367850982c4f8fc5e63cc94caa422eb736db5 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 27 Aug 2012 00:13:19 -0700 Subject: [PATCH 014/291] Fix minor bugs in Python API examples. 
--- pyspark/pyspark/examples/pi.py | 2 +- pyspark/pyspark/examples/tc.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyspark/pyspark/examples/pi.py b/pyspark/pyspark/examples/pi.py index fe63d2c952..348bbc5dce 100644 --- a/pyspark/pyspark/examples/pi.py +++ b/pyspark/pyspark/examples/pi.py @@ -9,7 +9,7 @@ if __name__ == "__main__": print >> sys.stderr, \ "Usage: PythonPi []" exit(-1) - sc = SparkContext(sys.argv[1], "PythonKMeans") + sc = SparkContext(sys.argv[1], "PythonPi") slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 n = 100000 * slices def f(_): diff --git a/pyspark/pyspark/examples/tc.py b/pyspark/pyspark/examples/tc.py index 2796fdc6ad..9630e72b47 100644 --- a/pyspark/pyspark/examples/tc.py +++ b/pyspark/pyspark/examples/tc.py @@ -22,9 +22,9 @@ if __name__ == "__main__": print >> sys.stderr, \ "Usage: PythonTC []" exit(-1) - sc = SparkContext(sys.argv[1], "PythonKMeans") + sc = SparkContext(sys.argv[1], "PythonTC") slices = sys.argv[2] if len(sys.argv) > 2 else 2 - tc = sc.parallelizePairs(generateGraph(), slices).cache() + tc = sc.parallelize(generateGraph(), slices).cache() # Linear transitive closure: each round grows paths by one edge, # by joining the graph's edges with the already-discovered paths. @@ -32,7 +32,7 @@ if __name__ == "__main__": # the graph to obtain the path (x, z). # Because join() joins on keys, the edges are stored in reversed order. - edges = tc.mapPairs(lambda (x, y): (y, x)) + edges = tc.map(lambda (x, y): (y, x)) oldCount = 0L nextCount = tc.count() @@ -40,7 +40,7 @@ if __name__ == "__main__": oldCount = nextCount # Perform the join, obtaining an RDD of (y, (z, x)) pairs, # then project the result to obtain the new (x, z) paths. - new_edges = tc.join(edges).mapPairs(lambda (_, (a, b)): (b, a)) + new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a)) tc = tc.union(new_edges).distinct().cache() nextCount = tc.count() if nextCount == oldCount: From 9abdfa663360252d2edb346e6b3df4ff94ce78d7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 17 Sep 2012 00:08:50 -0700 Subject: [PATCH 015/291] Fix Python 2.6 compatibility in Python API. 
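The tc.py fix above switches the projection to map() and keeps a lambda with Python 2 tuple-parameter unpacking. As a quick sanity check of what that projection computes, with made-up values: joining reversed edges (y, x) with paths (y, z) yields (y, (z, x)), which projects to the new path (x, z).

    # Edge 1 -> 2 (stored reversed as (2, 1)) joined with path 2 -> 3 gives (2, (3, 1)).
    joined = [(2, (3, 1))]
    new_edges = map(lambda (_, (a, b)): (b, a), joined)
    print new_edges   # [(1, 3)], i.e. the new path 1 -> 3
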
--- pyspark/pyspark/rdd.py | 17 +++++++++++------ python/tc.py | 22 ---------------------- 2 files changed, 11 insertions(+), 28 deletions(-) delete mode 100644 python/tc.py diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 8477f6dd02..e2137fe06c 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -1,5 +1,5 @@ from base64 import standard_b64encode as b64enc -from collections import Counter +from collections import defaultdict from itertools import chain, ifilter, imap import shlex from subprocess import Popen, PIPE @@ -198,13 +198,18 @@ class RDD(object): def countByValue(self): """ - >>> sc.parallelize([1, 2, 1, 2, 2]).countByValue().most_common() - [(2, 3), (1, 2)] + >>> sorted(sc.parallelize([1, 2, 1, 2, 2], 2).countByValue().items()) + [(1, 2), (2, 3)] """ def countPartition(iterator): - yield Counter(iterator) + counts = defaultdict(int) + for obj in iterator: + counts[obj] += 1 + yield counts def mergeMaps(m1, m2): - return m1 + m2 + for (k, v) in m2.iteritems(): + m1[k] += v + return m1 return self.mapPartitions(countPartition).reduce(mergeMaps) def take(self, num): @@ -271,7 +276,7 @@ class RDD(object): def countByKey(self): """ >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> rdd.countByKey().most_common() + >>> sorted(rdd.countByKey().items()) [('a', 2), ('b', 1)] """ return self.map(lambda x: x[0]).countByValue() diff --git a/python/tc.py b/python/tc.py deleted file mode 100644 index 5dcc4317e0..0000000000 --- a/python/tc.py +++ /dev/null @@ -1,22 +0,0 @@ -from rdd import SparkContext - -sc = SparkContext("local", "PythonWordCount") -e = [(1, 2), (2, 3), (4, 1)] - -tc = sc.parallelizePairs(e) - -edges = tc.mapPairs(lambda (x, y): (y, x)) - -oldCount = 0 -nextCount = tc.count() - -def project(x): - return (x[1][1], x[1][0]) - -while nextCount != oldCount: - oldCount = nextCount - tc = tc.union(tc.join(edges).mapPairs(project)).distinct() - nextCount = tc.count() - -print "TC has %i edges" % tc.count() -print tc.collect() From 52989c8a2c8c10d7f5610c033f6782e58fd3abc2 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 19 Oct 2012 10:24:49 -0700 Subject: [PATCH 016/291] Update Python API for v0.6.0 compatibility. 
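The compatibility fix above drops collections.Counter (added in Python 2.7) in favor of defaultdict(int), which also exists on 2.6. The replacement counting logic, extracted into a standalone sketch with the same merge semantics:

    from collections import defaultdict

    def count_partition(iterator):
        counts = defaultdict(int)
        for obj in iterator:
            counts[obj] += 1
        return counts

    def merge_maps(m1, m2):
        for (k, v) in m2.iteritems():
            m1[k] += v
        return m1

    merged = merge_maps(count_partition(iter([1, 2, 1])), count_partition(iter([2, 2])))
    print sorted(merged.items())   # [(1, 2), (2, 3)]
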
--- .../scala/spark/api/python/PythonRDD.scala | 18 +++++++++++------- .../main/scala/spark/broadcast/Broadcast.scala | 2 +- pyspark/pyspark/broadcast.py | 18 +++++++++--------- pyspark/pyspark/context.py | 2 +- pyspark/pyspark/java_gateway.py | 3 ++- pyspark/pyspark/serializers.py | 18 ++++++++++++++---- pyspark/pyspark/worker.py | 8 ++++---- 7 files changed, 42 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 4d3bdb3963..528885fe5c 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -5,11 +5,15 @@ import java.io._ import scala.collection.Map import scala.collection.JavaConversions._ import scala.io.Source -import spark._ -import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} -import broadcast.Broadcast -import scala.collection -import java.nio.charset.Charset + +import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} +import spark.broadcast.Broadcast +import spark.SparkEnv +import spark.Split +import spark.RDD +import spark.OneToOneDependency +import spark.rdd.PipedRDD + trait PythonRDDBase { def compute[T](split: Split, envVars: Map[String, String], @@ -43,9 +47,9 @@ trait PythonRDDBase { SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) val dOut = new DataOutputStream(proc.getOutputStream) - out.println(broadcastVars.length) + dOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { - out.print(broadcast.uuid.toString) + dOut.writeLong(broadcast.id) dOut.writeInt(broadcast.value.length) dOut.write(broadcast.value) dOut.flush() diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala index 6055bfd045..2ffe7f741d 100644 --- a/core/src/main/scala/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/spark/broadcast/Broadcast.scala @@ -5,7 +5,7 @@ import java.util.concurrent.atomic.AtomicLong import spark._ -abstract class Broadcast[T](id: Long) extends Serializable { +abstract class Broadcast[T](private[spark] val id: Long) extends Serializable { def value: T // We cannot have an abstract readObject here due to some weird issues with diff --git a/pyspark/pyspark/broadcast.py b/pyspark/pyspark/broadcast.py index 1ea17d59af..4cff02b36d 100644 --- a/pyspark/pyspark/broadcast.py +++ b/pyspark/pyspark/broadcast.py @@ -6,7 +6,7 @@ [1, 2, 3, 4, 5] >>> from pyspark.broadcast import _broadcastRegistry ->>> _broadcastRegistry[b.uuid] = b +>>> _broadcastRegistry[b.bid] = b >>> from cPickle import dumps, loads >>> loads(dumps(b)).value [1, 2, 3, 4, 5] @@ -14,27 +14,27 @@ >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() [1, 2, 3, 4, 5, 1, 2, 3, 4, 5] """ -# Holds broadcasted data received from Java, keyed by UUID. +# Holds broadcasted data received from Java, keyed by its id. _broadcastRegistry = {} -def _from_uuid(uuid): +def _from_id(bid): from pyspark.broadcast import _broadcastRegistry - if uuid not in _broadcastRegistry: - raise Exception("Broadcast variable '%s' not loaded!" % uuid) - return _broadcastRegistry[uuid] + if bid not in _broadcastRegistry: + raise Exception("Broadcast variable '%s' not loaded!" 
% bid) + return _broadcastRegistry[bid] class Broadcast(object): - def __init__(self, uuid, value, java_broadcast=None, pickle_registry=None): + def __init__(self, bid, value, java_broadcast=None, pickle_registry=None): self.value = value - self.uuid = uuid + self.bid = bid self._jbroadcast = java_broadcast self._pickle_registry = pickle_registry def __reduce__(self): self._pickle_registry.add(self) - return (_from_uuid, (self.uuid, )) + return (_from_id, (self.bid, )) def _test(): diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index 04932c93f2..3f4db26644 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -66,5 +66,5 @@ class SparkContext(object): def broadcast(self, value): jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value))) - return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast, + return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) diff --git a/pyspark/pyspark/java_gateway.py b/pyspark/pyspark/java_gateway.py index bcb405ba72..3726bcbf17 100644 --- a/pyspark/pyspark/java_gateway.py +++ b/pyspark/pyspark/java_gateway.py @@ -7,7 +7,8 @@ SPARK_HOME = os.environ["SPARK_HOME"] assembly_jar = glob.glob(os.path.join(SPARK_HOME, "core/target") + \ - "/spark-core-assembly-*-SNAPSHOT.jar")[0] + "/spark-core-assembly-*.jar")[0] + # TODO: what if multiple assembly jars are found? def launch_gateway(): diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py index faa1e683c7..21ef8b106c 100644 --- a/pyspark/pyspark/serializers.py +++ b/pyspark/pyspark/serializers.py @@ -9,16 +9,26 @@ def dump_pickle(obj): load_pickle = cPickle.loads +def read_long(stream): + length = stream.read(8) + if length == "": + raise EOFError + return struct.unpack("!q", length)[0] + + +def read_int(stream): + length = stream.read(4) + if length == "": + raise EOFError + return struct.unpack("!i", length)[0] + def write_with_length(obj, stream): stream.write(struct.pack("!i", len(obj))) stream.write(obj) def read_with_length(stream): - length = stream.read(4) - if length == "": - raise EOFError - length = struct.unpack("!i", length)[0] + length = read_int(stream) obj = stream.read(length) if obj == "": raise EOFError diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index a9ed71892f..62824a1c9b 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -8,7 +8,7 @@ from base64 import standard_b64decode from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.serializers import write_with_length, read_with_length, \ - dump_pickle, load_pickle + read_long, read_int, dump_pickle, load_pickle # Redirect stdout to stderr so that users must return values from functions. @@ -29,11 +29,11 @@ def read_input(): def main(): - num_broadcast_variables = int(sys.stdin.readline().strip()) + num_broadcast_variables = read_int(sys.stdin) for _ in range(num_broadcast_variables): - uuid = sys.stdin.read(36) + bid = read_long(sys.stdin) value = read_with_length(sys.stdin) - _broadcastRegistry[uuid] = Broadcast(uuid, load_pickle(value)) + _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value)) func = load_obj() bypassSerializer = load_obj() if bypassSerializer: From c23bf1aff4b9a1faf9d32c7b64acad2213f9515c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 20 Oct 2012 00:16:41 +0000 Subject: [PATCH 017/291] Add PySpark README and run scripts. 
--- core/src/main/scala/spark/SparkContext.scala | 2 +- pyspark/README | 58 ++++++++++++++++++++ pyspark/pyspark-shell | 3 + pyspark/pyspark/context.py | 5 +- pyspark/pyspark/examples/wordcount.py | 17 ++++++ pyspark/pyspark/shell.py | 21 +++++++ pyspark/run-pyspark | 23 ++++++++ 7 files changed, 125 insertions(+), 4 deletions(-) create mode 100644 pyspark/README create mode 100755 pyspark/pyspark-shell create mode 100644 pyspark/pyspark/examples/wordcount.py create mode 100644 pyspark/pyspark/shell.py create mode 100755 pyspark/run-pyspark diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index becf737597..acb38ae33d 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -113,7 +113,7 @@ class SparkContext( // Environment variables to pass to our executors private[spark] val executorEnvs = HashMap[String, String]() for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", - "SPARK_TESTING")) { + "SPARK_TESTING", "PYTHONPATH")) { val value = System.getenv(key) if (value != null) { executorEnvs(key) = value diff --git a/pyspark/README b/pyspark/README new file mode 100644 index 0000000000..63a1def141 --- /dev/null +++ b/pyspark/README @@ -0,0 +1,58 @@ +# PySpark + +PySpark is a Python API for Spark. + +PySpark jobs are writen in Python and executed using a standard Python +interpreter; this supports modules that use Python C extensions. The +API is based on the Spark Scala API and uses regular Python functions +and lambdas to support user-defined functions. PySpark supports +interactive use through a standard Python interpreter; it can +automatically serialize closures and ship them to worker processes. + +PySpark is built on top of the Spark Java API. Data is uniformly +represented as serialized Python objects and stored in Spark Java +processes, which communicate with PySpark worker processes over pipes. + +## Features + +PySpark supports most of the Spark API, including broadcast variables. +RDDs are dynamically typed and can hold any Python object. + +PySpark does not support: + +- Special functions on RDDs of doubles +- Accumulators + +## Examples and Documentation + +The PySpark source contains docstrings and doctests that document its +API. The public classes are in `context.py` and `rdd.py`. + +The `pyspark/pyspark/examples` directory contains a few complete +examples. + +## Installing PySpark + +PySpark requires a development version of Py4J, a Python library for +interacting with Java processes. It can be installed from +https://github.com/bartdag/py4j; make sure to install a version that +contains at least the commits through 3dbf380d3d. + +PySpark uses the `PYTHONPATH` environment variable to search for Python +classes; Py4J should be on this path, along with any libraries used by +PySpark programs. `PYTHONPATH` will be automatically shipped to worker +machines, but the files that it points to must be present on each +machine. + +PySpark requires the Spark assembly JAR, which can be created by running +`sbt/sbt assembly` in the Spark directory. + +Additionally, `SPARK_HOME` should be set to the location of the Spark +package. + +## Running PySpark + +The easiest way to run PySpark is to use the `run-pyspark` and +`pyspark-shell` scripts, which are included in the `pyspark` directory. +These scripts automatically load the `spark-conf.sh` file, set +`SPARK_HOME`, and add the `pyspark` package to the `PYTHONPATH`. 
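As a quick illustration of the API the README describes, here is a small job that exercises RDD transformations and a broadcast variable. This is a hedged sketch, not taken from the patch; it assumes the setup above (Py4J and the pyspark package on PYTHONPATH, SPARK_HOME set, assembly JAR built) and uses only operations present in rdd.py at this point:

    from pyspark.context import SparkContext

    if __name__ == "__main__":
        sc = SparkContext("local", "ReadmeExample")
        # Ship a small lookup table to the workers once.
        lookup = sc.broadcast({"a": 1, "b": 2})
        data = sc.parallelize(["a", "b", "a"])
        counts = data.map(lambda k: (k, lookup.value[k])) \
                     .reduceByKey(lambda x, y: x + y) \
                     .collect()
        for (key, total) in sorted(counts):
            print "%s : %i" % (key, total)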
diff --git a/pyspark/pyspark-shell b/pyspark/pyspark-shell new file mode 100755 index 0000000000..4ed3e6010c --- /dev/null +++ b/pyspark/pyspark-shell @@ -0,0 +1,3 @@ +#!/bin/sh +FWDIR="`dirname $0`" +exec $FWDIR/run-pyspark $FWDIR/pyspark/shell.py "$@" diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index 3f4db26644..50d57e5317 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -18,14 +18,13 @@ class SparkContext(object): asPickle = jvm.spark.api.python.PythonRDD.asPickle arrayAsPickle = jvm.spark.api.python.PythonRDD.arrayAsPickle - def __init__(self, master, name, defaultParallelism=None, - pythonExec='python'): + def __init__(self, master, name, defaultParallelism=None): self.master = master self.name = name self._jsc = self.jvm.JavaSparkContext(master, name) self.defaultParallelism = \ defaultParallelism or self._jsc.sc().defaultParallelism() - self.pythonExec = pythonExec + self.pythonExec = os.environ.get("PYSPARK_PYTHON_EXEC", 'python') # Broadcast's __reduce__ method stores Broadcast instances here. # This allows other code to determine which Broadcast instances have # been pickled, so it can determine which Java broadcast objects to diff --git a/pyspark/pyspark/examples/wordcount.py b/pyspark/pyspark/examples/wordcount.py new file mode 100644 index 0000000000..8365c070e8 --- /dev/null +++ b/pyspark/pyspark/examples/wordcount.py @@ -0,0 +1,17 @@ +import sys +from operator import add +from pyspark.context import SparkContext + +if __name__ == "__main__": + if len(sys.argv) < 3: + print >> sys.stderr, \ + "Usage: PythonWordCount " + exit(-1) + sc = SparkContext(sys.argv[1], "PythonWordCount") + lines = sc.textFile(sys.argv[2], 1) + counts = lines.flatMap(lambda x: x.split(' ')) \ + .map(lambda x: (x, 1)) \ + .reduceByKey(add) + output = counts.collect() + for (word, count) in output: + print "%s : %i" % (word, count) diff --git a/pyspark/pyspark/shell.py b/pyspark/pyspark/shell.py new file mode 100644 index 0000000000..7ef30894cb --- /dev/null +++ b/pyspark/pyspark/shell.py @@ -0,0 +1,21 @@ +""" +An interactive shell. +""" +import code +import sys + +from pyspark.context import SparkContext + + +def main(master='local'): + sc = SparkContext(master, 'PySparkShell') + print "Spark context available as sc." + code.interact(local={'sc': sc}) + + +if __name__ == '__main__': + if len(sys.argv) > 1: + master = sys.argv[1] + else: + master = 'local' + main(master) diff --git a/pyspark/run-pyspark b/pyspark/run-pyspark new file mode 100755 index 0000000000..9c5e027962 --- /dev/null +++ b/pyspark/run-pyspark @@ -0,0 +1,23 @@ +#!/bin/bash + +# Figure out where the Scala framework is installed +FWDIR="$(cd `dirname $0`; cd ../; pwd)" + +# Export this as SPARK_HOME +export SPARK_HOME="$FWDIR" + +# Load environment variables from conf/spark-env.sh, if it exists +if [ -e $FWDIR/conf/spark-env.sh ] ; then + . $FWDIR/conf/spark-env.sh +fi + +# Figure out which Python executable to use +if [ -z "$PYSPARK_PYTHON" ] ; then + PYSPARK_PYTHON="python" +fi +export PYSPARK_PYTHON + +# Add the PySpark classes to the Python path: +export PYTHONPATH=$SPARK_HOME/pyspark/:$PYTHONPATH + +exec "$PYSPARK_PYTHON" "$@" From d4f2e5b0ef38db9d42bb0d5fbbbe6103ce047efe Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 22 Oct 2012 10:28:59 -0700 Subject: [PATCH 018/291] Remove PYTHONPATH from SparkContext's executorEnvs. It makes more sense to pass it in the dictionary of environment variables that is used to construct PythonRDD. 
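The diff below removes PYTHONPATH from the executor environment and instead ships it per-RDD in the envVars map handed to PythonRDD, which merges it into the worker subprocess's environment. A conceptual, Spark-free sketch of that forwarding, using a plain subprocess in place of the Py4J-backed PythonRDD (the script and variable names here are illustrative):

    import os
    import subprocess

    if __name__ == "__main__":
        # The driver-side RDD collects the environment it wants its workers to see
        # (here just PYTHONPATH, as in the rdd.py hunk below)...
        env_vars = {"PYTHONPATH": os.environ.get("PYTHONPATH", "")}
        # ...and the launcher merges it into the child's environment, which is what
        # PythonRDDBase.compute does on the Scala side via ProcessBuilder.
        child_env = dict(os.environ)
        child_env.update(env_vars)
        proc = subprocess.Popen(
            ["python", "-c", "import os; print os.environ.get('PYTHONPATH', '')"],
            stdout=subprocess.PIPE, env=child_env)
        print proc.communicate()[0].strip()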
--- core/src/main/scala/spark/SparkContext.scala | 2 +- .../main/scala/spark/api/python/PythonRDD.scala | 15 +++++++-------- pyspark/pyspark/rdd.py | 8 ++++++-- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index acb38ae33d..becf737597 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -113,7 +113,7 @@ class SparkContext( // Environment variables to pass to our executors private[spark] val executorEnvs = HashMap[String, String]() for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", - "SPARK_TESTING", "PYTHONPATH")) { + "SPARK_TESTING")) { val value = System.getenv(key) if (value != null) { executorEnvs(key) = value diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 528885fe5c..a593e53efd 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -131,18 +131,17 @@ trait PythonRDDBase { } class PythonRDD[T: ClassManifest]( - parent: RDD[T], command: Seq[String], envVars: Map[String, String], + parent: RDD[T], command: Seq[String], envVars: java.util.Map[String, String], preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) extends RDD[Array[Byte]](parent.context) with PythonRDDBase { - def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, - pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = - this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars) - // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. 
by spaces) - def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = - this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars) + def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String], + preservePartitoning: Boolean, pythonExec: String, + broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = + this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec, + broadcastVars) override def splits = parent.splits @@ -151,7 +150,7 @@ class PythonRDD[T: ClassManifest]( override val partitioner = if (preservePartitoning) parent.partitioner else None override def compute(split: Split): Iterator[Array[Byte]] = - compute(split, envVars, command, parent, pythonExec, broadcastVars) + compute(split, envVars.toMap, command, parent, pythonExec, broadcastVars) val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index e2137fe06c..e4878c08ba 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -1,6 +1,7 @@ from base64 import standard_b64encode as b64enc from collections import defaultdict from itertools import chain, ifilter, imap +import os import shlex from subprocess import Popen, PIPE from threading import Thread @@ -10,7 +11,7 @@ from pyspark.serializers import dump_pickle, load_pickle from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup -from py4j.java_collections import ListConverter +from py4j.java_collections import ListConverter, MapConverter class RDD(object): @@ -447,8 +448,11 @@ class PipelinedRDD(RDD): self.ctx.gateway._gateway_client) self.ctx._pickled_broadcast_vars.clear() class_manifest = self._prev_jrdd.classManifest() + env = MapConverter().convert( + {'PYTHONPATH' : os.environ.get("PYTHONPATH", "")}, + self.ctx.gateway._gateway_client) python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), - pipe_command, self.preservesPartitioning, self.ctx.pythonExec, + pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec, broadcast_vars, class_manifest) self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val From 7859879aaa1860ff6b383e32a18fd9a410a97416 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 28 Oct 2012 16:46:31 -0700 Subject: [PATCH 019/291] Bump required Py4J version and add test for large broadcast variables. --- pyspark/README | 2 +- pyspark/pyspark/broadcast.py | 2 ++ pyspark/requirements.txt | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pyspark/README b/pyspark/README index 63a1def141..55490e1a83 100644 --- a/pyspark/README +++ b/pyspark/README @@ -36,7 +36,7 @@ examples. PySpark requires a development version of Py4J, a Python library for interacting with Java processes. It can be installed from https://github.com/bartdag/py4j; make sure to install a version that -contains at least the commits through 3dbf380d3d. +contains at least the commits through b7924aabe9. 
PySpark uses the `PYTHONPATH` environment variable to search for Python classes; Py4J should be on this path, along with any libraries used by diff --git a/pyspark/pyspark/broadcast.py b/pyspark/pyspark/broadcast.py index 4cff02b36d..93876fa738 100644 --- a/pyspark/pyspark/broadcast.py +++ b/pyspark/pyspark/broadcast.py @@ -13,6 +13,8 @@ >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() [1, 2, 3, 4, 5, 1, 2, 3, 4, 5] + +>>> large_broadcast = sc.broadcast(list(range(10000))) """ # Holds broadcasted data received from Java, keyed by its id. _broadcastRegistry = {} diff --git a/pyspark/requirements.txt b/pyspark/requirements.txt index 71e2bc2b89..48fa2ab105 100644 --- a/pyspark/requirements.txt +++ b/pyspark/requirements.txt @@ -3,4 +3,4 @@ # package is not at the root of the git repository. It may be possible to # install Py4J from git once https://github.com/pypa/pip/pull/526 is merged. -# git+git://github.com/bartdag/py4j.git@3dbf380d3d2cdeb9aab394454ea74d80c4aba1ea +# git+git://github.com/bartdag/py4j.git@b7924aabe9c5e63f0a4d8bbd17019534c7ec014e From 2ccf3b665280bf5b0919e3801d028126cb070dbd Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 28 Oct 2012 22:30:28 -0700 Subject: [PATCH 020/291] Fix PySpark hash partitioning bug. A Java array's hashCode is based on its object identity, not its elements, so this was causing serialized keys to be hashed incorrectly. This commit adds a PySpark-specific workaround and adds more tests. --- .../spark/api/python/PythonPartitioner.scala | 41 +++++++++++++++++++ .../scala/spark/api/python/PythonRDD.scala | 10 ++--- pyspark/pyspark/rdd.py | 12 ++++-- 3 files changed, 54 insertions(+), 9 deletions(-) create mode 100644 core/src/main/scala/spark/api/python/PythonPartitioner.scala diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala new file mode 100644 index 0000000000..ef9f808fb2 --- /dev/null +++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala @@ -0,0 +1,41 @@ +package spark.api.python + +import spark.Partitioner + +import java.util.Arrays + +/** + * A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API. + */ +class PythonPartitioner(override val numPartitions: Int) extends Partitioner { + + override def getPartition(key: Any): Int = { + if (key == null) { + return 0 + } + else { + val hashCode = { + if (key.isInstanceOf[Array[Byte]]) { + System.err.println("Dumping a byte array!"
+ Arrays.hashCode(key.asInstanceOf[Array[Byte]]) + ) + Arrays.hashCode(key.asInstanceOf[Array[Byte]]) + } + else + key.hashCode() + } + val mod = hashCode % numPartitions + if (mod < 0) { + mod + numPartitions + } else { + mod // Guard against negative hash codes + } + } + } + + override def equals(other: Any): Boolean = other match { + case h: PythonPartitioner => + h.numPartitions == numPartitions + case _ => + false + } +} diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index a593e53efd..50094d6b0f 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -179,14 +179,12 @@ object PythonRDD { val dOut = new DataOutputStream(baos); if (elem.isInstanceOf[Array[Byte]]) { elem.asInstanceOf[Array[Byte]] - } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) { - val t = elem.asInstanceOf[scala.Tuple2[_, _]] - val t1 = t._1.asInstanceOf[Array[Byte]] - val t2 = t._2.asInstanceOf[Array[Byte]] + } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) { + val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]] dOut.writeByte(Pickle.PROTO) dOut.writeByte(Pickle.TWO) - dOut.write(PythonRDD.stripPickle(t1)) - dOut.write(PythonRDD.stripPickle(t2)) + dOut.write(PythonRDD.stripPickle(t._1)) + dOut.write(PythonRDD.stripPickle(t._2)) dOut.writeByte(Pickle.TUPLE2) dOut.writeByte(Pickle.STOP) baos.toByteArray() diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index e4878c08ba..85a24c6854 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -310,6 +310,12 @@ class RDD(object): return python_right_outer_join(self, other, numSplits) def partitionBy(self, numSplits, hashFunc=hash): + """ + >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x)) + >>> sets = pairs.partitionBy(2).glom().collect() + >>> set(sets[0]).intersection(set(sets[1])) + set([]) + """ if numSplits is None: numSplits = self.ctx.defaultParallelism def add_shuffle_key(iterator): @@ -319,7 +325,7 @@ class RDD(object): keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() - partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits) + partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits) jrdd = pairRDD.partitionBy(partitioner) jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) return RDD(jrdd, self.ctx) @@ -391,7 +397,7 @@ class RDD(object): """ >>> x = sc.parallelize([("a", 1), ("b", 4)]) >>> y = sc.parallelize([("a", 2)]) - >>> x.cogroup(y).collect() + >>> sorted(x.cogroup(y).collect()) [('a', ([1], [2])), ('b', ([4], []))] """ return python_cogroup(self, other, numSplits) @@ -462,7 +468,7 @@ def _test(): import doctest from pyspark.context import SparkContext globs = globals().copy() - globs['sc'] = SparkContext('local', 'PythonTest') + globs['sc'] = SparkContext('local[4]', 'PythonTest') doctest.testmod(globs=globs) globs['sc'].stop() From 531ac136bf4ed333cb906ac229d986605a8207a6 Mon Sep 17 00:00:00 2001 From: Denny Date: Mon, 29 Oct 2012 14:53:47 -0700 Subject: [PATCH 021/291] BlockManager UI. 
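Before the BlockManager UI changes that follow, a short illustration of the partitioning behaviour fixed above: because keys reach the JVM as pickled byte arrays, they must be hashed by contents rather than by object identity, so after partitionBy() every copy of a key lands in the same partition. A hedged usage sketch mirroring the doctest added in the partitioning patch (requires a running SparkContext):

    from pyspark.context import SparkContext

    if __name__ == "__main__":
        sc = SparkContext("local[4]", "PartitionByCheck")
        pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x))
        # glom() gathers each partition into a list so placement can be inspected.
        sets = pairs.partitionBy(2).glom().collect()
        # With contents-based hashing of the serialized keys, no key spans partitions.
        print set(sets[0]).intersection(set(sets[1]))   # set([])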
--- core/src/main/scala/spark/RDD.scala | 8 ++ core/src/main/scala/spark/SparkContext.scala | 10 ++ .../spark/storage/BlockManagerMaster.scala | 33 +++++- .../scala/spark/storage/BlockManagerUI.scala | 102 ++++++++++++++++++ .../src/main/scala/spark/util/AkkaUtils.scala | 5 +- .../{deploy => }/common/layout.scala.html | 0 .../spark/deploy/master/index.scala.html | 2 +- .../deploy/master/job_details.scala.html | 2 +- .../spark/deploy/worker/index.scala.html | 2 +- .../main/twirl/spark/storage/index.scala.html | 28 +++++ .../main/twirl/spark/storage/rdd.scala.html | 65 +++++++++++ .../twirl/spark/storage/rdd_row.scala.html | 18 ++++ .../twirl/spark/storage/rdd_table.scala.html | 18 ++++ 13 files changed, 283 insertions(+), 10 deletions(-) create mode 100644 core/src/main/scala/spark/storage/BlockManagerUI.scala rename core/src/main/twirl/spark/{deploy => }/common/layout.scala.html (100%) create mode 100644 core/src/main/twirl/spark/storage/index.scala.html create mode 100644 core/src/main/twirl/spark/storage/rdd.scala.html create mode 100644 core/src/main/twirl/spark/storage/rdd_row.scala.html create mode 100644 core/src/main/twirl/spark/storage/rdd_table.scala.html diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 338dff4061..dc757dc6aa 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -107,6 +107,12 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial // Variables relating to persistence private var storageLevel: StorageLevel = StorageLevel.NONE + /* Assign a name to this RDD */ + def name(name: String) = { + sc.rddNames(this.id) = name + this + } + /** * Set this RDD's storage level to persist its values across operations after the first time * it is computed. Can only be called once on each RDD. 
@@ -118,6 +124,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial "Cannot change storage level of an RDD after it was already assigned a level") } storageLevel = newLevel + // Register the RDD with the SparkContext + sc.persistentRdds(id) = this this } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index d26cccbfe1..71c9dcd017 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -1,6 +1,7 @@ package spark import java.io._ +import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicInteger import java.net.{URI, URLClassLoader} @@ -102,10 +103,19 @@ class SparkContext( isLocal) SparkEnv.set(env) + // Start the BlockManager UI + spark.storage.BlockManagerUI.start(SparkEnv.get.actorSystem, + SparkEnv.get.blockManager.master.masterActor, this) + // Used to store a URL for each static file/jar together with the file's local timestamp private[spark] val addedFiles = HashMap[String, Long]() private[spark] val addedJars = HashMap[String, Long]() + // Keeps track of all persisted RDDs + private[spark] val persistentRdds = new ConcurrentHashMap[Int, RDD[_]]() + // A HashMap for friendly RDD Names + private[spark] val rddNames = new ConcurrentHashMap[Int, String]() + // Add each JAR given through the constructor jars.foreach { addJar(_) } diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index ace27e758c..d12a16869a 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -3,7 +3,8 @@ package spark.storage import java.io._ import java.util.{HashMap => JHashMap} -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.JavaConverters._ +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.util.Random import akka.actor._ @@ -90,6 +91,15 @@ case object StopBlockManagerMaster extends ToBlockManagerMaster private[spark] case object GetMemoryStatus extends ToBlockManagerMaster +private[spark] +case class GetStorageStatus extends ToBlockManagerMaster + +private[spark] +case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) + +private[spark] +case class StorageStatus(maxMem: Long, remainingMem: Long, blocks: Map[String, BlockStatus]) + private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { @@ -99,7 +109,8 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor val maxMem: Long) { private var _lastSeenMs = timeMs private var _remainingMem = maxMem - private val _blocks = new JHashMap[String, StorageLevel] + + private val _blocks = new JHashMap[String, BlockStatus] logInfo("Registering block manager %s:%d with %s RAM".format( blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem))) @@ -115,7 +126,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor if (_blocks.containsKey(blockId)) { // The block exists on the slave already. 
- val originalLevel: StorageLevel = _blocks.get(blockId) + val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel if (originalLevel.useMemory) { _remainingMem += memSize @@ -124,7 +135,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor if (storageLevel.isValid) { // isValid means it is either stored in-memory or on-disk. - _blocks.put(blockId, storageLevel) + _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize)) if (storageLevel.useMemory) { _remainingMem -= memSize logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format( @@ -137,7 +148,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor } } else if (_blocks.containsKey(blockId)) { // If isValid is not true, drop the block. - val originalLevel: StorageLevel = _blocks.get(blockId) + val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel _blocks.remove(blockId) if (originalLevel.useMemory) { _remainingMem += memSize @@ -152,6 +163,8 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor } } + def blocks: JHashMap[String, BlockStatus] = _blocks + def remainingMem: Long = _remainingMem def lastSeenMs: Long = _lastSeenMs @@ -198,6 +211,9 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor case GetMemoryStatus => getMemoryStatus + case GetStorageStatus => + getStorageStatus + case RemoveHost(host) => removeHost(host) sender ! true @@ -219,6 +235,13 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor sender ! res } + private def getStorageStatus() { + val res = blockManagerInfo.map { case(blockManagerId, info) => + StorageStatus(info.maxMem, info.remainingMem, info.blocks.asScala) + } + sender ! res + } + private def register(blockManagerId: BlockManagerId, maxMemSize: Long) { val startTimeMs = System.currentTimeMillis() val tmp = " " + blockManagerId + " " diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala new file mode 100644 index 0000000000..c168f60c35 --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -0,0 +1,102 @@ +package spark.storage + +import akka.actor.{ActorRef, ActorSystem} +import akka.dispatch.Await +import akka.pattern.ask +import akka.util.Timeout +import akka.util.duration._ +import cc.spray.Directives +import cc.spray.directives._ +import cc.spray.typeconversion.TwirlSupport._ +import scala.collection.mutable.ArrayBuffer +import spark.{Logging, SparkContext, SparkEnv} +import spark.util.AkkaUtils + +private[spark] +object BlockManagerUI extends Logging { + + /* Starts the Web interface for the BlockManager */ + def start(actorSystem : ActorSystem, masterActor: ActorRef, sc: SparkContext) { + val webUIDirectives = new BlockManagerUIDirectives(actorSystem, masterActor, sc) + try { + logInfo("Starting BlockManager WebUI.") + val port = Option(System.getenv("BLOCKMANAGER_UI_PORT")).getOrElse("9080").toInt + AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, webUIDirectives.handler, "BlockManagerHTTPServer") + } catch { + case e: Exception => + logError("Failed to create BlockManager WebUI", e) + System.exit(1) + } + } + +} + +private[spark] +case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, numPartitions: Int, memSize: Long, diskSize: Long) + +private[spark] +class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, sc: SparkContext) extends Directives { + + val 
STATIC_RESOURCE_DIR = "spark/deploy/static" + implicit val timeout = Timeout(1 seconds) + + val handler = { + + get { path("") { completeWith { + // Request the current storage status from the Master + val future = master ? GetStorageStatus + future.map { status => + val storageStati = status.asInstanceOf[ArrayBuffer[StorageStatus]] + + // Calculate macro-level statistics + val maxMem = storageStati.map(_.maxMem).reduce(_+_) + val remainingMem = storageStati.map(_.remainingMem).reduce(_+_) + val diskSpaceUsed = storageStati.flatMap(_.blocks.values.map(_.diskSize)) + .reduceOption(_+_).getOrElse(0L) + + // Filter out everything that's not and rdd. + val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => k.startsWith("rdd") }.toMap + val rdds = rddInfoFromBlockStati(rddBlocks) + + spark.storage.html.index.render(maxMem, remainingMem, diskSpaceUsed, rdds.toList) + } + }}} ~ + get { path("rdd") { parameter("id") { id => { completeWith { + val future = master ? GetStorageStatus + future.map { status => + val prefix = "rdd_" + id.toString + + val storageStati = status.asInstanceOf[ArrayBuffer[StorageStatus]] + val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => k.startsWith(prefix) }.toMap + val rddInfo = rddInfoFromBlockStati(rddBlocks).first + + spark.storage.html.rdd.render(rddInfo, rddBlocks) + + } + }}}}} ~ + pathPrefix("static") { + getFromResourceDirectory(STATIC_RESOURCE_DIR) + } + + } + + private def rddInfoFromBlockStati(infos: Map[String, BlockStatus]) : Array[RDDInfo] = { + infos.groupBy { case(k,v) => + // Group by rdd name, ignore the partition name + k.substring(0,k.lastIndexOf('_')) + }.map { case(k,v) => + val blockStati = v.map(_._2).toArray + // Add up memory and disk sizes + val tmp = blockStati.map { x => (x.memSize, x.diskSize)}.reduce { (x,y) => + (x._1 + y._1, x._2 + y._2) + } + // Get the friendly name for the rdd, if available. + // This is pretty hacky, is there a better way? + val rddId = k.split("_").last.toInt + val rddName : String = Option(sc.rddNames.get(rddId)).getOrElse(k) + val rddStorageLevel = sc.persistentRdds.get(rddId).getStorageLevel + RDDInfo(rddId, rddName, rddStorageLevel, blockStati.length, tmp._1, tmp._2) + }.toArray + } + +} diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index b466b5239c..13bc0f8ccc 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -50,12 +50,13 @@ private[spark] object AkkaUtils { * Creates a Spray HTTP server bound to a given IP and port with a given Spray Route object to * handle requests. Throws a SparkException if this fails. 
*/ - def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route) { + def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route, + name: String = "HttpServer") { val ioWorker = new IoWorker(actorSystem).start() val httpService = actorSystem.actorOf(Props(new HttpService(route))) val rootService = actorSystem.actorOf(Props(new SprayCanRootService(httpService))) val server = actorSystem.actorOf( - Props(new HttpServer(ioWorker, SingletonHandler(rootService))), name = "HttpServer") + Props(new HttpServer(ioWorker, SingletonHandler(rootService))), name = name) actorSystem.registerOnTermination { ioWorker.stop() } val timeout = 3.seconds val future = server.ask(HttpServer.Bind(ip, port))(timeout) diff --git a/core/src/main/twirl/spark/deploy/common/layout.scala.html b/core/src/main/twirl/spark/common/layout.scala.html similarity index 100% rename from core/src/main/twirl/spark/deploy/common/layout.scala.html rename to core/src/main/twirl/spark/common/layout.scala.html diff --git a/core/src/main/twirl/spark/deploy/master/index.scala.html b/core/src/main/twirl/spark/deploy/master/index.scala.html index 7562076b00..2e15fe2200 100644 --- a/core/src/main/twirl/spark/deploy/master/index.scala.html +++ b/core/src/main/twirl/spark/deploy/master/index.scala.html @@ -1,7 +1,7 @@ @(state: spark.deploy.MasterState) @import spark.deploy.master._ -@spark.deploy.common.html.layout(title = "Spark Master on " + state.uri) { +@spark.common.html.layout(title = "Spark Master on " + state.uri) {
diff --git a/core/src/main/twirl/spark/deploy/master/job_details.scala.html b/core/src/main/twirl/spark/deploy/master/job_details.scala.html index dcf41c28f2..d02a51b214 100644 --- a/core/src/main/twirl/spark/deploy/master/job_details.scala.html +++ b/core/src/main/twirl/spark/deploy/master/job_details.scala.html @@ -1,6 +1,6 @@ @(job: spark.deploy.master.JobInfo) -@spark.deploy.common.html.layout(title = "Job Details") { +@spark.common.html.layout(title = "Job Details") {
diff --git a/core/src/main/twirl/spark/deploy/worker/index.scala.html b/core/src/main/twirl/spark/deploy/worker/index.scala.html index 69746ed02c..40c2d81d77 100644 --- a/core/src/main/twirl/spark/deploy/worker/index.scala.html +++ b/core/src/main/twirl/spark/deploy/worker/index.scala.html @@ -1,6 +1,6 @@ @(worker: spark.deploy.WorkerState) -@spark.deploy.common.html.layout(title = "Spark Worker on " + worker.uri) { +@spark.common.html.layout(title = "Spark Worker on " + worker.uri) {
diff --git a/core/src/main/twirl/spark/storage/index.scala.html b/core/src/main/twirl/spark/storage/index.scala.html new file mode 100644 index 0000000000..fa7dad51ee --- /dev/null +++ b/core/src/main/twirl/spark/storage/index.scala.html @@ -0,0 +1,28 @@ +@(maxMem: Long, remainingMem: Long, diskSpaceUsed: Long, rdds: List[spark.storage.RDDInfo]) + +@spark.common.html.layout(title = "Storage Dashboard") { + + +
+
+
    +
  • Memory: + @{spark.Utils.memoryBytesToString(maxMem - remainingMem)} Used + (@{spark.Utils.memoryBytesToString(remainingMem)} Available)
  • +
  • Disk: @{spark.Utils.memoryBytesToString(diskSpaceUsed)} Used
  • +
+
+
+ +
+ + +
+
+

RDD Summary

+
+ @rdd_table(rdds) +
+
+ +} \ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/rdd.scala.html b/core/src/main/twirl/spark/storage/rdd.scala.html new file mode 100644 index 0000000000..3a70326efe --- /dev/null +++ b/core/src/main/twirl/spark/storage/rdd.scala.html @@ -0,0 +1,65 @@ +@(rddInfo: spark.storage.RDDInfo, blocks: Map[String, spark.storage.BlockStatus]) + +@spark.common.html.layout(title = "RDD Info ") { + + +
+
+
    +
  • + Storage Level: + @(if (rddInfo.storageLevel.useDisk) "Disk" else "") + @(if (rddInfo.storageLevel.useMemory) "Memory" else "") + @(if (rddInfo.storageLevel.deserialized) "Deserialized" else "") + @(rddInfo.storageLevel.replication)x Replicated +
  • + Partitions: + @(rddInfo.numPartitions) +
  • +
  • + Memory Size: + @{spark.Utils.memoryBytesToString(rddInfo.memSize)} +
  • +
  • + Disk Size: + @{spark.Utils.memoryBytesToString(rddInfo.diskSize)} +
  • +
+
+
+ +
+ + +
+
+

RDD Summary

+
+ + + + + + + + + + + + + @blocks.map { case (k,v) => + + + + + + + } + +
Block Name Storage Level Size in Memory Size on Disk
@k @v.storageLevel @{spark.Utils.memoryBytesToString(v.memSize)} @{spark.Utils.memoryBytesToString(v.diskSize)}
+ + +
+
+ +} \ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/rdd_row.scala.html b/core/src/main/twirl/spark/storage/rdd_row.scala.html new file mode 100644 index 0000000000..3dd9944e3b --- /dev/null +++ b/core/src/main/twirl/spark/storage/rdd_row.scala.html @@ -0,0 +1,18 @@ +@(rdd: spark.storage.RDDInfo) + + + + + @rdd.name + + + + @(if (rdd.storageLevel.useDisk) "Disk" else "") + @(if (rdd.storageLevel.useMemory) "Memory" else "") + @(if (rdd.storageLevel.deserialized) "Deserialized" else "") + @(rdd.storageLevel.replication)x Replicated + + @rdd.numPartitions + @{spark.Utils.memoryBytesToString(rdd.memSize)} + @{spark.Utils.memoryBytesToString(rdd.diskSize)} + \ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/rdd_table.scala.html b/core/src/main/twirl/spark/storage/rdd_table.scala.html new file mode 100644 index 0000000000..24f55ccefb --- /dev/null +++ b/core/src/main/twirl/spark/storage/rdd_table.scala.html @@ -0,0 +1,18 @@ +@(rdds: List[spark.storage.RDDInfo]) + + + + + + + + + + + + + @for(rdd <- rdds) { + @rdd_row(rdd) + } + +
RDD Name Storage Level Partitions Size in Memory Size on Disk
\ No newline at end of file From eb95212f4d24dbcd734922f39d51e6fdeaeb4c8b Mon Sep 17 00:00:00 2001 From: Denny Date: Mon, 29 Oct 2012 14:57:32 -0700 Subject: [PATCH 022/291] code Formatting --- .../scala/spark/storage/BlockManagerUI.scala | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index c168f60c35..635c096c87 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -21,7 +21,8 @@ object BlockManagerUI extends Logging { try { logInfo("Starting BlockManager WebUI.") val port = Option(System.getenv("BLOCKMANAGER_UI_PORT")).getOrElse("9080").toInt - AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, webUIDirectives.handler, "BlockManagerHTTPServer") + AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, + webUIDirectives.handler, "BlockManagerHTTPServer") } catch { case e: Exception => logError("Failed to create BlockManager WebUI", e) @@ -32,10 +33,12 @@ object BlockManagerUI extends Logging { } private[spark] -case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, numPartitions: Int, memSize: Long, diskSize: Long) +case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, + numPartitions: Int, memSize: Long, diskSize: Long) private[spark] -class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, sc: SparkContext) extends Directives { +class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, + sc: SparkContext) extends Directives { val STATIC_RESOURCE_DIR = "spark/deploy/static" implicit val timeout = Timeout(1 seconds) @@ -55,7 +58,9 @@ class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, s .reduceOption(_+_).getOrElse(0L) // Filter out everything that's not and rdd. 
- val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => k.startsWith("rdd") }.toMap + val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => + k.startsWith("rdd") + }.toMap val rdds = rddInfoFromBlockStati(rddBlocks) spark.storage.html.index.render(maxMem, remainingMem, diskSpaceUsed, rdds.toList) @@ -67,7 +72,9 @@ class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, s val prefix = "rdd_" + id.toString val storageStati = status.asInstanceOf[ArrayBuffer[StorageStatus]] - val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => k.startsWith(prefix) }.toMap + val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => + k.startsWith(prefix) + }.toMap val rddInfo = rddInfoFromBlockStati(rddBlocks).first spark.storage.html.rdd.render(rddInfo, rddBlocks) From ceec1a1a6abb1fd03316e7fcc532d7e121d5bf65 Mon Sep 17 00:00:00 2001 From: Denny Date: Mon, 29 Oct 2012 15:03:01 -0700 Subject: [PATCH 023/291] Nicer storage level format on RDD page --- core/src/main/twirl/spark/storage/rdd.scala.html | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/twirl/spark/storage/rdd.scala.html b/core/src/main/twirl/spark/storage/rdd.scala.html index 3a70326efe..075289c826 100644 --- a/core/src/main/twirl/spark/storage/rdd.scala.html +++ b/core/src/main/twirl/spark/storage/rdd.scala.html @@ -50,7 +50,12 @@ @blocks.map { case (k,v) => @k - @v.storageLevel + + @(if (v.storageLevel.useDisk) "Disk" else "") + @(if (v.storageLevel.useMemory) "Memory" else "") + @(if (v.storageLevel.deserialized) "Deserialized" else "") + @(v.storageLevel.replication)x Replicated + @{spark.Utils.memoryBytesToString(v.memSize)} @{spark.Utils.memoryBytesToString(v.diskSize)} From 4a1be7e0dbf0031d85b91dc1132fe101d87ba097 Mon Sep 17 00:00:00 2001 From: Denny Date: Mon, 12 Nov 2012 10:56:35 -0800 Subject: [PATCH 024/291] Refactor BlockManager UI and adding worker details. --- core/src/main/scala/spark/RDD.scala | 7 +- core/src/main/scala/spark/SparkContext.scala | 2 - .../spark/storage/BlockManagerMaster.scala | 11 +-- .../scala/spark/storage/BlockManagerUI.scala | 51 ++++-------- .../scala/spark/storage/StorageLevel.scala | 9 +++ .../scala/spark/storage/StorageUtils.scala | 78 +++++++++++++++++++ .../main/twirl/spark/storage/index.scala.html | 22 ++++-- .../main/twirl/spark/storage/rdd.scala.html | 35 +++++---- .../twirl/spark/storage/rdd_row.scala.html | 18 ----- .../twirl/spark/storage/rdd_table.scala.html | 16 +++- .../spark/storage/worker_table.scala.html | 24 ++++++ 11 files changed, 186 insertions(+), 87 deletions(-) create mode 100644 core/src/main/scala/spark/storage/StorageUtils.scala delete mode 100644 core/src/main/twirl/spark/storage/rdd_row.scala.html create mode 100644 core/src/main/twirl/spark/storage/worker_table.scala.html diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index dc757dc6aa..3669bda2d2 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -86,6 +86,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial @transient val dependencies: List[Dependency[_]] // Methods available on all RDDs: + + // A friendly name for this RDD + var name: String = null /** Record user function generating this RDD. 
*/ private[spark] val origin = Utils.getSparkCallSite @@ -108,8 +111,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial private var storageLevel: StorageLevel = StorageLevel.NONE /* Assign a name to this RDD */ - def name(name: String) = { - sc.rddNames(this.id) = name + def setName(_name: String) = { + name = _name this } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 71c9dcd017..7ea0f6f9e0 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -113,8 +113,6 @@ class SparkContext( // Keeps track of all persisted RDDs private[spark] val persistentRdds = new ConcurrentHashMap[Int, RDD[_]]() - // A HashMap for friendly RDD Names - private[spark] val rddNames = new ConcurrentHashMap[Int, String]() // Add each JAR given through the constructor jars.foreach { addJar(_) } diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 3fc9b629c1..beafdda9d1 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -4,7 +4,7 @@ import java.io._ import java.util.{HashMap => JHashMap} import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.util.Random import akka.actor._ @@ -95,10 +95,7 @@ private[spark] case class GetStorageStatus extends ToBlockManagerMaster private[spark] -case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) - -private[spark] -case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, remainingMem: Long, blocks: Map[String, BlockStatus]) +case class BlockStatus(blockManagerId: BlockManagerId, storageLevel: StorageLevel, memSize: Long, diskSize: Long) private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { @@ -135,7 +132,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor if (storageLevel.isValid) { // isValid means it is either stored in-memory or on-disk. - _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize)) + _blocks.put(blockId, BlockStatus(blockManagerId, storageLevel, memSize, diskSize)) if (storageLevel.useMemory) { _remainingMem -= memSize logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format( @@ -237,7 +234,7 @@ private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor private def getStorageStatus() { val res = blockManagerInfo.map { case(blockManagerId, info) => - StorageStatus(blockManagerId, info.maxMem, info.remainingMem, info.blocks.asScala) + StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala.toMap) } sender ! 
res } diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index 635c096c87..35cbd59280 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -12,6 +12,7 @@ import scala.collection.mutable.ArrayBuffer import spark.{Logging, SparkContext, SparkEnv} import spark.util.AkkaUtils + private[spark] object BlockManagerUI extends Logging { @@ -32,9 +33,6 @@ object BlockManagerUI extends Logging { } -private[spark] -case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, - numPartitions: Int, memSize: Long, diskSize: Long) private[spark] class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, @@ -49,21 +47,17 @@ class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, // Request the current storage status from the Master val future = master ? GetStorageStatus future.map { status => - val storageStati = status.asInstanceOf[ArrayBuffer[StorageStatus]] + val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray // Calculate macro-level statistics - val maxMem = storageStati.map(_.maxMem).reduce(_+_) - val remainingMem = storageStati.map(_.remainingMem).reduce(_+_) - val diskSpaceUsed = storageStati.flatMap(_.blocks.values.map(_.diskSize)) + val maxMem = storageStatusList.map(_.maxMem).reduce(_+_) + val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_) + val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)) .reduceOption(_+_).getOrElse(0L) - // Filter out everything that's not and rdd. - val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => - k.startsWith("rdd") - }.toMap - val rdds = rddInfoFromBlockStati(rddBlocks) + val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc) - spark.storage.html.index.render(maxMem, remainingMem, diskSpaceUsed, rdds.toList) + spark.storage.html.index.render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList) } }}} ~ get { path("rdd") { parameter("id") { id => { completeWith { @@ -71,13 +65,13 @@ class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, future.map { status => val prefix = "rdd_" + id.toString - val storageStati = status.asInstanceOf[ArrayBuffer[StorageStatus]] - val rddBlocks = storageStati.flatMap(_.blocks).filter { case(k,v) => - k.startsWith(prefix) - }.toMap - val rddInfo = rddInfoFromBlockStati(rddBlocks).first - spark.storage.html.rdd.render(rddInfo, rddBlocks) + val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray + val filteredStorageStatusList = StorageUtils.filterStorageStatusByPrefix(storageStatusList, prefix) + + val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).first + + spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList) } }}}}} ~ @@ -87,23 +81,6 @@ class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, } - private def rddInfoFromBlockStati(infos: Map[String, BlockStatus]) : Array[RDDInfo] = { - infos.groupBy { case(k,v) => - // Group by rdd name, ignore the partition name - k.substring(0,k.lastIndexOf('_')) - }.map { case(k,v) => - val blockStati = v.map(_._2).toArray - // Add up memory and disk sizes - val tmp = blockStati.map { x => (x.memSize, x.diskSize)}.reduce { (x,y) => - (x._1 + y._1, x._2 + y._2) - } - // Get the friendly name for the rdd, if available. - // This is pretty hacky, is there a better way? 
- val rddId = k.split("_").last.toInt - val rddName : String = Option(sc.rddNames.get(rddId)).getOrElse(k) - val rddStorageLevel = sc.persistentRdds.get(rddId).getStorageLevel - RDDInfo(rddId, rddName, rddStorageLevel, blockStati.length, tmp._1, tmp._2) - }.toArray - } + } diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index c497f03e0c..97d8c7566d 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -68,6 +68,15 @@ class StorageLevel( override def toString: String = "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication) + + def description : String = { + var result = "" + result += (if (useDisk) "Disk " else "") + result += (if (useMemory) "Memory " else "") + result += (if (deserialized) "Deserialized " else "Serialized") + result += "%sx Replicated".format(replication) + result + } } object StorageLevel { diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala new file mode 100644 index 0000000000..ebc7390ee5 --- /dev/null +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -0,0 +1,78 @@ +package spark.storage + +import spark.SparkContext + +private[spark] +case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, + blocks: Map[String, BlockStatus]) { + + def memUsed(blockPrefix: String = "") = { + blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize). + reduceOption(_+_).getOrElse(0l) + } + + def diskUsed(blockPrefix: String = "") = { + blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.diskSize). + reduceOption(_+_).getOrElse(0l) + } + + def memRemaining : Long = maxMem - memUsed() + +} + +case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, + numPartitions: Int, memSize: Long, diskSize: Long, locations: Array[BlockManagerId]) + + +/* Helper methods for storage-related objects */ +private[spark] +object StorageUtils { + + /* Given the current storage status of the BlockManager, returns information for each RDD */ + def rddInfoFromStorageStatus(storageStatusList: Array[StorageStatus], + sc: SparkContext) : Array[RDDInfo] = { + rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc) + } + + /* Given a list of BlockStatus objets, returns information for each RDD */ + def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus], + sc: SparkContext) : Array[RDDInfo] = { + // Find all RDD Blocks (ignore broadcast variables) + val rddBlocks = infos.filterKeys(_.startsWith("rdd")) + + // Group by rddId, ignore the partition name + val groupedRddBlocks = infos.groupBy { case(k, v) => + k.substring(0,k.lastIndexOf('_')) + }.mapValues(_.values.toArray) + + // For each RDD, generate an RDDInfo object + groupedRddBlocks.map { case(rddKey, rddBlocks) => + + // Add up memory and disk sizes + val memSize = rddBlocks.map(_.memSize).reduce(_ + _) + val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _) + + // Find the id of the RDD, e.g. rdd_1 => 1 + val rddId = rddKey.split("_").last.toInt + // Get the friendly name for the rdd, if available. 
+ val rddName = Option(sc.persistentRdds.get(rddId).name).getOrElse(rddKey) + val rddStorageLevel = sc.persistentRdds.get(rddId).getStorageLevel + + RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, memSize, diskSize, + rddBlocks.map(_.blockManagerId)) + }.toArray + } + + /* Removes all BlockStatus object that are not part of a block prefix */ + def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus], + prefix: String) : Array[StorageStatus] = { + + storageStatusList.map { status => + val newBlocks = status.blocks.filterKeys(_.startsWith(prefix)) + //val newRemainingMem = status.maxMem - newBlocks.values.map(_.memSize).reduce(_ + _) + StorageStatus(status.blockManagerId, status.maxMem, newBlocks) + } + + } + +} \ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/index.scala.html b/core/src/main/twirl/spark/storage/index.scala.html index fa7dad51ee..2b337f6133 100644 --- a/core/src/main/twirl/spark/storage/index.scala.html +++ b/core/src/main/twirl/spark/storage/index.scala.html @@ -1,4 +1,5 @@ -@(maxMem: Long, remainingMem: Long, diskSpaceUsed: Long, rdds: List[spark.storage.RDDInfo]) +@(maxMem: Long, remainingMem: Long, diskSpaceUsed: Long, rdds: Array[spark.storage.RDDInfo], storageStatusList: Array[spark.storage.StorageStatus]) +@import spark.Utils @spark.common.html.layout(title = "Storage Dashboard") { @@ -7,16 +8,16 @@
  • Memory: - @{spark.Utils.memoryBytesToString(maxMem - remainingMem)} Used - (@{spark.Utils.memoryBytesToString(remainingMem)} Available)
  • -
  • Disk: @{spark.Utils.memoryBytesToString(diskSpaceUsed)} Used
  • + @{Utils.memoryBytesToString(maxMem - remainingMem)} Used + (@{Utils.memoryBytesToString(remainingMem)} Available) +
  • Disk: @{Utils.memoryBytesToString(diskSpaceUsed)} Used

- +

RDD Summary

@@ -25,4 +26,15 @@
+
+ + +
+
+

Worker Summary

+
+ @worker_table(storageStatusList) +
+
+ } \ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/rdd.scala.html b/core/src/main/twirl/spark/storage/rdd.scala.html index 075289c826..ac7f8c981f 100644 --- a/core/src/main/twirl/spark/storage/rdd.scala.html +++ b/core/src/main/twirl/spark/storage/rdd.scala.html @@ -1,4 +1,5 @@ -@(rddInfo: spark.storage.RDDInfo, blocks: Map[String, spark.storage.BlockStatus]) +@(rddInfo: spark.storage.RDDInfo, storageStatusList: Array[spark.storage.StorageStatus]) +@import spark.Utils @spark.common.html.layout(title = "RDD Info ") { @@ -8,21 +9,18 @@
  • Storage Level: - @(if (rddInfo.storageLevel.useDisk) "Disk" else "") - @(if (rddInfo.storageLevel.useMemory) "Memory" else "") - @(if (rddInfo.storageLevel.deserialized) "Deserialized" else "") - @(rddInfo.storageLevel.replication)x Replicated + @(rddInfo.storageLevel.description)
  • Partitions: @(rddInfo.numPartitions)
  • Memory Size: - @{spark.Utils.memoryBytesToString(rddInfo.memSize)} + @{Utils.memoryBytesToString(rddInfo.memSize)}
  • Disk Size: - @{spark.Utils.memoryBytesToString(rddInfo.diskSize)} + @{Utils.memoryBytesToString(rddInfo.diskSize)}
@@ -36,6 +34,7 @@

RDD Summary


+ @@ -47,17 +46,14 @@ - @blocks.map { case (k,v) => + @storageStatusList.flatMap(_.blocks).toArray.sortWith(_._1 < _._1).map { case (k,v) => - - + + } @@ -67,4 +63,15 @@ +
+ + +
+
+

Worker Summary

+
+ @worker_table(storageStatusList, "rdd_" + rddInfo.id ) +
+
+ } \ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/rdd_row.scala.html b/core/src/main/twirl/spark/storage/rdd_row.scala.html deleted file mode 100644 index 3dd9944e3b..0000000000 --- a/core/src/main/twirl/spark/storage/rdd_row.scala.html +++ /dev/null @@ -1,18 +0,0 @@ -@(rdd: spark.storage.RDDInfo) - - - - - - - - \ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/rdd_table.scala.html b/core/src/main/twirl/spark/storage/rdd_table.scala.html index 24f55ccefb..af801cf229 100644 --- a/core/src/main/twirl/spark/storage/rdd_table.scala.html +++ b/core/src/main/twirl/spark/storage/rdd_table.scala.html @@ -1,4 +1,5 @@ -@(rdds: List[spark.storage.RDDInfo]) +@(rdds: Array[spark.storage.RDDInfo]) +@import spark.Utils
@k - @(if (v.storageLevel.useDisk) "Disk" else "") - @(if (v.storageLevel.useMemory) "Memory" else "") - @(if (v.storageLevel.deserialized) "Deserialized" else "") - @(v.storageLevel.replication)x Replicated + @(v.storageLevel.description) @{spark.Utils.memoryBytesToString(v.memSize)}@{spark.Utils.memoryBytesToString(v.diskSize)}@{Utils.memoryBytesToString(v.memSize)}@{Utils.memoryBytesToString(v.diskSize)}
- - @rdd.name - - - @(if (rdd.storageLevel.useDisk) "Disk" else "") - @(if (rdd.storageLevel.useMemory) "Memory" else "") - @(if (rdd.storageLevel.deserialized) "Deserialized" else "") - @(rdd.storageLevel.replication)x Replicated - @rdd.numPartitions@{spark.Utils.memoryBytesToString(rdd.memSize)}@{spark.Utils.memoryBytesToString(rdd.diskSize)}
@@ -12,7 +13,18 @@ @for(rdd <- rdds) { - @rdd_row(rdd) + + + + + + + }
+ + @rdd.name + + @(rdd.storageLevel.description) + @rdd.numPartitions@{Utils.memoryBytesToString(rdd.memSize)}@{Utils.memoryBytesToString(rdd.diskSize)}
\ No newline at end of file diff --git a/core/src/main/twirl/spark/storage/worker_table.scala.html b/core/src/main/twirl/spark/storage/worker_table.scala.html new file mode 100644 index 0000000000..d54b8de4cc --- /dev/null +++ b/core/src/main/twirl/spark/storage/worker_table.scala.html @@ -0,0 +1,24 @@ +@(workersStatusList: Array[spark.storage.StorageStatus], prefix: String = "") +@import spark.Utils + + + + + + + + + + + @for(status <- workersStatusList) { + + + + + + } + +
 Host    Memory Usage    Disk Usage
@(status.blockManagerId.ip + ":" + status.blockManagerId.port) + @(Utils.memoryBytesToString(status.memUsed(prefix))) + (@(Utils.memoryBytesToString(status.memRemaining)) Total Available) + @(Utils.memoryBytesToString(status.diskUsed(prefix)))
\ No newline at end of file From ccd075cf960df6c6c449b709515cdd81499a52be Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 24 Dec 2012 15:01:13 -0800 Subject: [PATCH 025/291] Reduce object overhead in Pyspark shuffle and collect --- pyspark/pyspark/rdd.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 85a24c6854..708ea6eb55 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -145,8 +145,10 @@ class RDD(object): self.map(f).collect() # Force evaluation def collect(self): - pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().collect()) - return load_pickle(bytes(pickle)) + def asList(iterator): + yield list(iterator) + pickles = self.mapPartitions(asList)._jrdd.rdd().collect() + return list(chain.from_iterable(load_pickle(bytes(p)) for p in pickles)) def reduce(self, f): """ @@ -319,16 +321,23 @@ class RDD(object): if numSplits is None: numSplits = self.ctx.defaultParallelism def add_shuffle_key(iterator): + buckets = defaultdict(list) for (k, v) in iterator: - yield str(hashFunc(k)) - yield dump_pickle((k, v)) + buckets[hashFunc(k) % numSplits].append((k, v)) + for (split, items) in buckets.iteritems(): + yield str(split) + yield dump_pickle(items) keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits) + # Transferring O(n) objects to Java is too expensive. Instead, we'll + # form the hash buckets in Python, transferring O(numSplits) objects + # to Java. Each object is a (splitNumber, [objects]) pair. jrdd = pairRDD.partitionBy(partitioner) jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) - return RDD(jrdd, self.ctx) + # Flatten the resulting RDD: + return RDD(jrdd, self.ctx).flatMap(lambda items: items) def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numSplits=None): From 4608902fb87af64a15b97ab21fe6382cd6e5a644 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 24 Dec 2012 17:20:10 -0800 Subject: [PATCH 026/291] Use filesystem to collect RDDs in PySpark. Passing large volumes of data through Py4J seems to be slow. It appears to be faster to write the data to the local filesystem and read it back from Python. 
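The handoff sketched below mirrors what this patch sets up at the serialization layer: objects go into a local temp file as a 32-bit length header followed by a protocol-2 pickle, and the other side streams the frames back. This is an illustrative, self-contained sketch using plain struct/cPickle; it is not the actual pyspark.serializers code, just the same framing idea.

    import cPickle
    import os
    import struct
    from tempfile import NamedTemporaryFile

    def write_with_length(obj, stream):
        # 32-bit big-endian length header, then the protocol-2 pickle bytes
        data = cPickle.dumps(obj, 2)
        stream.write(struct.pack("!i", len(data)))
        stream.write(data)

    def read_pickle_frames(stream):
        # Yield objects until a length header can no longer be read (EOF)
        while True:
            header = stream.read(4)
            if len(header) < 4:
                return
            (length,) = struct.unpack("!i", header)
            yield cPickle.loads(stream.read(length))

    # Writer side: dump everything to a temp file in one pass...
    tempFile = NamedTemporaryFile(delete=False)
    for x in [1, "two", (3, 4)]:
        write_with_length(x, tempFile)
    tempFile.close()

    # ...reader side: stream it back without per-object Py4J round trips.
    with open(tempFile.name, "rb") as f:
        print list(read_pickle_frames(f))   # [1, 'two', (3, 4)]
    os.unlink(tempFile.name)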
--- .../scala/spark/api/python/PythonRDD.scala | 66 +++++++------------ pyspark/pyspark/context.py | 9 ++- pyspark/pyspark/rdd.py | 34 ++++++++-- pyspark/pyspark/serializers.py | 8 +++ pyspark/pyspark/worker.py | 12 +--- 5 files changed, 66 insertions(+), 63 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 50094d6b0f..4f870e837a 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -1,6 +1,7 @@ package spark.api.python import java.io._ +import java.util.{List => JList} import scala.collection.Map import scala.collection.JavaConversions._ @@ -59,36 +60,7 @@ trait PythonRDDBase { } out.flush() for (elem <- parent.iterator(split)) { - if (elem.isInstanceOf[Array[Byte]]) { - val arr = elem.asInstanceOf[Array[Byte]] - dOut.writeInt(arr.length) - dOut.write(arr) - } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) { - val t = elem.asInstanceOf[scala.Tuple2[_, _]] - val t1 = t._1.asInstanceOf[Array[Byte]] - val t2 = t._2.asInstanceOf[Array[Byte]] - val length = t1.length + t2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.write(PythonRDD.stripPickle(t1)) - dOut.write(PythonRDD.stripPickle(t2)) - dOut.writeByte(Pickle.TUPLE2) - dOut.writeByte(Pickle.STOP) - } else if (elem.isInstanceOf[String]) { - // For uniformity, strings are wrapped into Pickles. - val s = elem.asInstanceOf[String].getBytes("UTF-8") - val length = 2 + 1 + 4 + s.length + 1 - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.writeByte(Pickle.BINUNICODE) - dOut.writeInt(Integer.reverseBytes(s.length)) - dOut.write(s) - dOut.writeByte(Pickle.STOP) - } else { - throw new Exception("Unexpected RDD type") - } + PythonRDD.writeAsPickle(elem, dOut) } dOut.flush() out.flush() @@ -174,36 +146,45 @@ object PythonRDD { arr.slice(2, arr.length - 1) } - def asPickle(elem: Any) : Array[Byte] = { - val baos = new ByteArrayOutputStream(); - val dOut = new DataOutputStream(baos); + /** + * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream. + * The data format is a 32-bit integer representing the pickled object's length (in bytes), + * followed by the pickled data. + * @param elem the object to write + * @param dOut a data output stream + */ + def writeAsPickle(elem: Any, dOut: DataOutputStream) { if (elem.isInstanceOf[Array[Byte]]) { - elem.asInstanceOf[Array[Byte]] + val arr = elem.asInstanceOf[Array[Byte]] + dOut.writeInt(arr.length) + dOut.write(arr) } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) { val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]] + val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes + dOut.writeInt(length) dOut.writeByte(Pickle.PROTO) dOut.writeByte(Pickle.TWO) dOut.write(PythonRDD.stripPickle(t._1)) dOut.write(PythonRDD.stripPickle(t._2)) dOut.writeByte(Pickle.TUPLE2) dOut.writeByte(Pickle.STOP) - baos.toByteArray() } else if (elem.isInstanceOf[String]) { // For uniformity, strings are wrapped into Pickles. 
val s = elem.asInstanceOf[String].getBytes("UTF-8") + val length = 2 + 1 + 4 + s.length + 1 + dOut.writeInt(length) dOut.writeByte(Pickle.PROTO) dOut.writeByte(Pickle.TWO) dOut.write(Pickle.BINUNICODE) dOut.writeInt(Integer.reverseBytes(s.length)) dOut.write(s) dOut.writeByte(Pickle.STOP) - baos.toByteArray() } else { throw new Exception("Unexpected RDD type") } } - def pickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) : + def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) : JavaRDD[Array[Byte]] = { val file = new DataInputStream(new FileInputStream(filename)) val objs = new collection.mutable.ArrayBuffer[Array[Byte]] @@ -221,11 +202,12 @@ object PythonRDD { JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } - def arrayAsPickle(arr : Any) : Array[Byte] = { - val pickles : Array[Byte] = arr.asInstanceOf[Array[Any]].map(asPickle).map(stripPickle).flatten - - Array[Byte](Pickle.PROTO, Pickle.TWO, Pickle.EMPTY_LIST, Pickle.MARK) ++ pickles ++ - Array[Byte] (Pickle.APPENDS, Pickle.STOP) + def writeArrayToPickleFile[T](items: Array[T], filename: String) { + val file = new DataOutputStream(new FileOutputStream(filename)) + for (item <- items) { + writeAsPickle(item, file) + } + file.close() } } diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index 50d57e5317..19f9f9e133 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -14,9 +14,8 @@ class SparkContext(object): gateway = launch_gateway() jvm = gateway.jvm - pickleFile = jvm.spark.api.python.PythonRDD.pickleFile - asPickle = jvm.spark.api.python.PythonRDD.asPickle - arrayAsPickle = jvm.spark.api.python.PythonRDD.arrayAsPickle + readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile + writeArrayToPickleFile = jvm.PythonRDD.writeArrayToPickleFile def __init__(self, master, name, defaultParallelism=None): self.master = master @@ -45,11 +44,11 @@ class SparkContext(object): # because it sends O(n) Py4J commands. As an alternative, serialized # objects are written to a file and loaded through textFile(). tempFile = NamedTemporaryFile(delete=False) + atexit.register(lambda: os.unlink(tempFile.name)) for x in c: write_with_length(dump_pickle(x), tempFile) tempFile.close() - atexit.register(lambda: os.unlink(tempFile.name)) - jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices) + jrdd = self.readRDDFromPickleFile(self._jsc, tempFile.name, numSlices) return RDD(jrdd, self) def textFile(self, name, minSplits=None): diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 708ea6eb55..01908cff96 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -1,13 +1,15 @@ +import atexit from base64 import standard_b64encode as b64enc from collections import defaultdict from itertools import chain, ifilter, imap import os import shlex from subprocess import Popen, PIPE +from tempfile import NamedTemporaryFile from threading import Thread from pyspark import cloudpickle -from pyspark.serializers import dump_pickle, load_pickle +from pyspark.serializers import dump_pickle, load_pickle, read_from_pickle_file from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup @@ -145,10 +147,30 @@ class RDD(object): self.map(f).collect() # Force evaluation def collect(self): + # To minimize the number of transfers between Python and Java, we'll + # flatten each partition into a list before collecting it. Due to + # pipelining, this should add minimal overhead. 
def asList(iterator): yield list(iterator) - pickles = self.mapPartitions(asList)._jrdd.rdd().collect() - return list(chain.from_iterable(load_pickle(bytes(p)) for p in pickles)) + picklesInJava = self.mapPartitions(asList)._jrdd.rdd().collect() + return list(chain.from_iterable(self._collect_array_through_file(picklesInJava))) + + def _collect_array_through_file(self, array): + # Transferring lots of data through Py4J can be slow because + # socket.readline() is inefficient. Instead, we'll dump the data to a + # file and read it back. + tempFile = NamedTemporaryFile(delete=False) + tempFile.close() + def clean_up_file(): + try: os.unlink(tempFile.name) + except: pass + atexit.register(clean_up_file) + self.ctx.writeArrayToPickleFile(array, tempFile.name) + # Read the data into Python and deserialize it: + with open(tempFile.name, 'rb') as tempFile: + for item in read_from_pickle_file(tempFile): + yield item + os.unlink(tempFile.name) def reduce(self, f): """ @@ -220,15 +242,15 @@ class RDD(object): >>> sc.parallelize([2, 3, 4]).take(2) [2, 3] """ - pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().take(num)) - return load_pickle(bytes(pickle)) + picklesInJava = self._jrdd.rdd().take(num) + return list(self._collect_array_through_file(picklesInJava)) def first(self): """ >>> sc.parallelize([2, 3, 4]).first() 2 """ - return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first()))) + return self.take(1)[0] def saveAsTextFile(self, path): def func(iterator): diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py index 21ef8b106c..bfcdda8f12 100644 --- a/pyspark/pyspark/serializers.py +++ b/pyspark/pyspark/serializers.py @@ -33,3 +33,11 @@ def read_with_length(stream): if obj == "": raise EOFError return obj + + +def read_from_pickle_file(stream): + try: + while True: + yield load_pickle(read_with_length(stream)) + except EOFError: + return diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index 62824a1c9b..9f6b507dbd 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -8,7 +8,7 @@ from base64 import standard_b64decode from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.serializers import write_with_length, read_with_length, \ - read_long, read_int, dump_pickle, load_pickle + read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file # Redirect stdout to stderr so that users must return values from functions. @@ -20,14 +20,6 @@ def load_obj(): return load_pickle(standard_b64decode(sys.stdin.readline().strip())) -def read_input(): - try: - while True: - yield load_pickle(read_with_length(sys.stdin)) - except EOFError: - return - - def main(): num_broadcast_variables = read_int(sys.stdin) for _ in range(num_broadcast_variables): @@ -40,7 +32,7 @@ def main(): dumps = lambda x: x else: dumps = dump_pickle - for obj in func(read_input()): + for obj in func(read_from_pickle_file(sys.stdin)): write_with_length(dumps(obj), old_stdout) From e2dad15621f5dc15275b300df05483afde5025a0 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 26 Dec 2012 17:34:24 -0800 Subject: [PATCH 027/291] Add support for batched serialization of Python objects in PySpark. 
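The idea, sketched below, is to pickle fixed-size groups of elements rather than individual ones, so each partition crosses the Java boundary as a handful of large byte strings instead of one pickle per element. This is a rough standalone sketch of the grouping helper; the real version lands in pyspark/serializers.py in the diff that follows and also treats a batch size of -1 as "one batch per partition".

    class Batch(object):
        """Marks a group of elements that were pickled together."""
        def __init__(self, items):
            self.items = items

    def batched(iterator, batchSize):
        items = []
        for item in iterator:
            items.append(item)
            if len(items) == batchSize:
                yield Batch(items)
                items = []
        if items:
            yield Batch(items)   # final, possibly short, batch

    # 1000 elements with batchSize=100 become just 10 pickled payloads:
    print len(list(batched(xrange(1000), 100)))   # 10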
--- pyspark/pyspark/context.py | 3 +- pyspark/pyspark/rdd.py | 57 +++++++++++++++++++++++----------- pyspark/pyspark/serializers.py | 34 +++++++++++++++++++- 3 files changed, 74 insertions(+), 20 deletions(-) diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index 19f9f9e133..032619693a 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -17,13 +17,14 @@ class SparkContext(object): readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile writeArrayToPickleFile = jvm.PythonRDD.writeArrayToPickleFile - def __init__(self, master, name, defaultParallelism=None): + def __init__(self, master, name, defaultParallelism=None, batchSize=-1): self.master = master self.name = name self._jsc = self.jvm.JavaSparkContext(master, name) self.defaultParallelism = \ defaultParallelism or self._jsc.sc().defaultParallelism() self.pythonExec = os.environ.get("PYSPARK_PYTHON_EXEC", 'python') + self.batchSize = batchSize # -1 represents a unlimited batch size # Broadcast's __reduce__ method stores Broadcast instances here. # This allows other code to determine which Broadcast instances have # been pickled, so it can determine which Java broadcast objects to diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 01908cff96..d7081dffd2 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -2,6 +2,7 @@ import atexit from base64 import standard_b64encode as b64enc from collections import defaultdict from itertools import chain, ifilter, imap +import operator import os import shlex from subprocess import Popen, PIPE @@ -9,7 +10,8 @@ from tempfile import NamedTemporaryFile from threading import Thread from pyspark import cloudpickle -from pyspark.serializers import dump_pickle, load_pickle, read_from_pickle_file +from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \ + read_from_pickle_file from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup @@ -83,6 +85,11 @@ class RDD(object): >>> rdd = sc.parallelize([1, 1, 2, 3]) >>> rdd.union(rdd).collect() [1, 1, 2, 3, 1, 1, 2, 3] + + # Union of batched and unbatched RDDs: + >>> batchedRDD = sc.parallelize([Batch([1, 2, 3, 4, 5])]) + >>> rdd.union(batchedRDD).collect() + [1, 1, 2, 3, 1, 2, 3, 4, 5] """ return RDD(self._jrdd.union(other._jrdd), self.ctx) @@ -147,13 +154,8 @@ class RDD(object): self.map(f).collect() # Force evaluation def collect(self): - # To minimize the number of transfers between Python and Java, we'll - # flatten each partition into a list before collecting it. Due to - # pipelining, this should add minimal overhead. 
- def asList(iterator): - yield list(iterator) - picklesInJava = self.mapPartitions(asList)._jrdd.rdd().collect() - return list(chain.from_iterable(self._collect_array_through_file(picklesInJava))) + picklesInJava = self._jrdd.rdd().collect() + return list(self._collect_array_through_file(picklesInJava)) def _collect_array_through_file(self, array): # Transferring lots of data through Py4J can be slow because @@ -214,12 +216,21 @@ class RDD(object): # TODO: aggregate + def sum(self): + """ + >>> sc.parallelize([1.0, 2.0, 3.0]).sum() + 6.0 + """ + return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add) + def count(self): """ >>> sc.parallelize([2, 3, 4]).count() - 3L + 3 + >>> sc.parallelize([Batch([2, 3, 4])]).count() + 3 """ - return self._jrdd.count() + return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum() def countByValue(self): """ @@ -342,24 +353,23 @@ class RDD(object): """ if numSplits is None: numSplits = self.ctx.defaultParallelism + # Transferring O(n) objects to Java is too expensive. Instead, we'll + # form the hash buckets in Python, transferring O(numSplits) objects + # to Java. Each object is a (splitNumber, [objects]) pair. def add_shuffle_key(iterator): buckets = defaultdict(list) for (k, v) in iterator: buckets[hashFunc(k) % numSplits].append((k, v)) for (split, items) in buckets.iteritems(): yield str(split) - yield dump_pickle(items) + yield dump_pickle(Batch(items)) keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits) - # Transferring O(n) objects to Java is too expensive. Instead, we'll - # form the hash buckets in Python, transferring O(numSplits) objects - # to Java. Each object is a (splitNumber, [objects]) pair. jrdd = pairRDD.partitionBy(partitioner) jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) - # Flatten the resulting RDD: - return RDD(jrdd, self.ctx).flatMap(lambda items: items) + return RDD(jrdd, self.ctx) def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numSplits=None): @@ -478,8 +488,19 @@ class PipelinedRDD(RDD): def _jrdd(self): if self._jrdd_val: return self._jrdd_val - funcs = [self.func, self._bypass_serializer] - pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in funcs) + func = self.func + if not self._bypass_serializer and self.ctx.batchSize != 1: + oldfunc = self.func + batchSize = self.ctx.batchSize + if batchSize == -1: # unlimited batch size + def batched_func(iterator): + yield Batch(list(oldfunc(iterator))) + else: + def batched_func(iterator): + return batched(oldfunc(iterator), batchSize) + func = batched_func + cmds = [func, self._bypass_serializer] + pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], self.ctx.gateway._gateway_client) diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py index bfcdda8f12..4ed925697c 100644 --- a/pyspark/pyspark/serializers.py +++ b/pyspark/pyspark/serializers.py @@ -2,6 +2,33 @@ import struct import cPickle +class Batch(object): + """ + Used to store multiple RDD entries as a single Java object. + + This relieves us from having to explicitly track whether an RDD + is stored as batches of objects and avoids problems when processing + the union() of batched and unbatched RDDs (e.g. the union() of textFile() + with another RDD). 
+ """ + def __init__(self, items): + self.items = items + + +def batched(iterator, batchSize): + items = [] + count = 0 + for item in iterator: + items.append(item) + count += 1 + if count == batchSize: + yield Batch(items) + items = [] + count = [] + if items: + yield Batch(items) + + def dump_pickle(obj): return cPickle.dumps(obj, 2) @@ -38,6 +65,11 @@ def read_with_length(stream): def read_from_pickle_file(stream): try: while True: - yield load_pickle(read_with_length(stream)) + obj = load_pickle(read_with_length(stream)) + if type(obj) == Batch: # We don't care about inheritance + for item in obj.items: + yield item + else: + yield obj except EOFError: return From 1dca0c51804b9c94709ec9cc0544b8dfb7afe59f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 26 Dec 2012 18:23:06 -0800 Subject: [PATCH 028/291] Remove debug output from PythonPartitioner. --- core/src/main/scala/spark/api/python/PythonPartitioner.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala index ef9f808fb2..606a80d1eb 100644 --- a/core/src/main/scala/spark/api/python/PythonPartitioner.scala +++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala @@ -16,8 +16,6 @@ class PythonPartitioner(override val numPartitions: Int) extends Partitioner { else { val hashCode = { if (key.isInstanceOf[Array[Byte]]) { - System.err.println("Dumping a byte array!" + Arrays.hashCode(key.asInstanceOf[Array[Byte]]) - ) Arrays.hashCode(key.asInstanceOf[Array[Byte]]) } else From 2d98fff0651af4d527f41ba50c01f453fa049464 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 27 Dec 2012 10:13:29 -0800 Subject: [PATCH 029/291] Add IPython support to pyspark-shell. Suggested by / based on code from @MLnick --- pyspark/README | 3 +++ pyspark/pyspark/shell.py | 25 +++++++++++++++++-------- pyspark/requirements.txt | 1 + 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/pyspark/README b/pyspark/README index 55490e1a83..461176de7d 100644 --- a/pyspark/README +++ b/pyspark/README @@ -38,6 +38,9 @@ interacting with Java processes. It can be installed from https://github.com/bartdag/py4j; make sure to install a version that contains at least the commits through b7924aabe9. +PySpark requires the `argparse` module, which is included in Python 2.7 +and is is available for Python 2.6 through `pip` or `easy_install`. + PySpark uses the `PYTHONPATH` environment variable to search for Python classes; Py4J should be on this path, along with any libraries used by PySpark programs. `PYTHONPATH` will be automatically shipped to worker diff --git a/pyspark/pyspark/shell.py b/pyspark/pyspark/shell.py index 7ef30894cb..7012884abc 100644 --- a/pyspark/pyspark/shell.py +++ b/pyspark/pyspark/shell.py @@ -1,21 +1,30 @@ """ An interactive shell. """ +import argparse # argparse is avaiable for Python < 2.7 through easy_install. import code import sys from pyspark.context import SparkContext -def main(master='local'): +def main(master='local', ipython=False): sc = SparkContext(master, 'PySparkShell') - print "Spark context available as sc." - code.interact(local={'sc': sc}) + user_ns = {'sc' : sc} + banner = "Spark context avaiable as sc." 
+ if ipython: + import IPython + IPython.embed(user_ns=user_ns, banner2=banner) + else: + print banner + code.interact(local=user_ns) if __name__ == '__main__': - if len(sys.argv) > 1: - master = sys.argv[1] - else: - master = 'local' - main(master) + parser = argparse.ArgumentParser() + parser.add_argument("master", help="Spark master host (default='local')", + nargs='?', type=str, default="local") + parser.add_argument("-i", "--ipython", help="Run IPython shell", + action="store_true") + args = parser.parse_args() + main(args.master, args.ipython) diff --git a/pyspark/requirements.txt b/pyspark/requirements.txt index 48fa2ab105..2464ca0074 100644 --- a/pyspark/requirements.txt +++ b/pyspark/requirements.txt @@ -4,3 +4,4 @@ # install Py4J from git once https://github.com/pypa/pip/pull/526 is merged. # git+git://github.com/bartdag/py4j.git@b7924aabe9c5e63f0a4d8bbd17019534c7ec014e +argparse From 85b8f2c64f0fc4be5645d8736629fc082cb3587b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 27 Dec 2012 17:55:33 -0800 Subject: [PATCH 030/291] Add epydoc API documentation for PySpark. --- docs/README.md | 8 +- docs/_layouts/global.html | 10 +- docs/_plugins/copy_api_dirs.rb | 17 +++ pyspark/epydoc.conf | 19 ++++ pyspark/pyspark/context.py | 24 ++++ pyspark/pyspark/rdd.py | 195 ++++++++++++++++++++++++++++++--- 6 files changed, 254 insertions(+), 19 deletions(-) create mode 100644 pyspark/epydoc.conf diff --git a/docs/README.md b/docs/README.md index 092153070e..887f407f18 100644 --- a/docs/README.md +++ b/docs/README.md @@ -25,10 +25,12 @@ To mark a block of code in your markdown to be syntax highlighted by jekyll duri // supported languages too. {% endhighlight %} -## Scaladoc +## API Docs (Scaladoc and Epydoc) You can build just the Spark scaladoc by running `sbt/sbt doc` from the SPARK_PROJECT_ROOT directory. -When you run `jekyll` in the docs directory, it will also copy over the scala doc for the various Spark subprojects into the docs directory (and then also into the _site directory). We use a jekyll plugin to run `sbt/sbt doc` before building the site so if you haven't run it (recently) it may take some time as it generates all of the scaladoc. +Similarly, you can build just the PySpark epydoc by running `epydoc --config epydoc.conf` from the SPARK_PROJECT_ROOT/pyspark directory. -NOTE: To skip the step of building and copying over the scaladoc when you build the docs, run `SKIP_SCALADOC=1 jekyll`. +When you run `jekyll` in the docs directory, it will also copy over the scaladoc for the various Spark subprojects into the docs directory (and then also into the _site directory). We use a jekyll plugin to run `sbt/sbt doc` before building the site so if you haven't run it (recently) it may take some time as it generates all of the scaladoc. The jekyll plugin also generates the PySpark docs using [epydoc](http://epydoc.sourceforge.net/). + +NOTE: To skip the step of building and copying over the scaladoc when you build the docs, run `SKIP_SCALADOC=1 jekyll`. Similarly, `SKIP_EPYDOC=1 jekyll` will skip PySpark API doc generation. diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 41ad5242c9..43a5fa3e1c 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -49,8 +49,14 @@
  • Java
  • - -
  • API (Scaladoc)
  • + +
  • Quick Start
  • Scala
  • Java
  • +
  • Python
  • diff --git a/docs/api.md b/docs/api.md index 43548b223c..b9c93ac5e8 100644 --- a/docs/api.md +++ b/docs/api.md @@ -8,3 +8,4 @@ Here you can find links to the Scaladoc generated for the Spark sbt subprojects. - [Core](api/core/index.html) - [Examples](api/examples/index.html) - [Bagel](api/bagel/index.html) +- [PySpark](api/pyspark/index.html) diff --git a/docs/index.md b/docs/index.md index ed9953a590..33ab58a962 100644 --- a/docs/index.md +++ b/docs/index.md @@ -7,11 +7,11 @@ title: Spark Overview TODO(andyk): Rewrite to make the Java API a first class part of the story. {% endcomment %} -Spark is a MapReduce-like cluster computing framework designed for low-latency iterative jobs and interactive use from an -interpreter. It provides clean, language-integrated APIs in Scala and Java, with a rich array of parallel operators. Spark can -run on top of the [Apache Mesos](http://incubator.apache.org/mesos/) cluster manager, +Spark is a MapReduce-like cluster computing framework designed for low-latency iterative jobs and interactive use from an interpreter. +It provides clean, language-integrated APIs in Scala, Java, and Python, with a rich array of parallel operators. +Spark can run on top of the [Apache Mesos](http://incubator.apache.org/mesos/) cluster manager, [Hadoop YARN](http://hadoop.apache.org/docs/r2.0.1-alpha/hadoop-yarn/hadoop-yarn-site/YARN.html), -Amazon EC2, or without an independent resource manager ("standalone mode"). +Amazon EC2, or without an independent resource manager ("standalone mode"). # Downloading @@ -59,6 +59,7 @@ of `project/SparkBuild.scala`, then rebuilding Spark (`sbt/sbt clean compile`). * [Quick Start](quick-start.html): a quick introduction to the Spark API; start here! * [Spark Programming Guide](scala-programming-guide.html): an overview of Spark concepts, and details on the Scala API * [Java Programming Guide](java-programming-guide.html): using Spark from Java +* [Python Programming Guide](python-programming-guide.html): using Spark from Python **Deployment guides:** @@ -72,7 +73,7 @@ of `project/SparkBuild.scala`, then rebuilding Spark (`sbt/sbt clean compile`). * [Configuration](configuration.html): customize Spark via its configuration system * [Tuning Guide](tuning.html): best practices to optimize performance and memory use -* [API Docs (Scaladoc)](api/core/index.html) +* API Docs: [Java/Scala (Scaladoc)](api/core/index.html) and [Python (Epydoc)](api/pyspark/index.html) * [Bagel](bagel-programming-guide.html): an implementation of Google's Pregel on Spark * [Contributing to Spark](contributing-to-spark.html) diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md new file mode 100644 index 0000000000..b7c747f905 --- /dev/null +++ b/docs/python-programming-guide.md @@ -0,0 +1,74 @@ +--- +layout: global +title: Python Programming Guide +--- + + +The Spark Python API (PySpark) exposes most of the Spark features available in the Scala version to Python. +To learn the basics of Spark, we recommend reading through the +[Scala programming guide](scala-programming-guide.html) first; it should be +easy to follow even if you don't know Scala. +This guide will show how to use the Spark features described there in Python. + +# Key Differences in the Python API + +There are a few key differences between the Python and Scala APIs: + +* Python is dynamically typed, so RDDs can hold objects of different types. 
+* PySpark does not currently support the following Spark features: + - Accumulators + - Special functions on RRDs of doubles, such as `mean` and `stdev` + - Approximate jobs / functions, such as `countApprox` and `sumApprox`. + - `lookup` + - `mapPartitionsWithSplit` + - `persist` at storage levels other than `MEMORY_ONLY` + - `sample` + - `sort` + + +# Installing and Configuring PySpark + +PySpark requires Python 2.6 or higher. +PySpark jobs are executed using a standard cPython interpreter in order to support Python modules that use C extensions. +We have not tested PySpark with Python 3 or with alternative Python interpreters, such as [PyPy](http://pypy.org/) or [Jython](http://www.jython.org/). +By default, PySpark's scripts will run programs using `python`; an alternate Python executable may be specified by setting the `PYSPARK_PYTHON` environment variable in `conf/spark-env.sh`. + +All of PySpark's library dependencies, including [Py4J](http://py4j.sourceforge.net/), are bundled with PySpark and automatically imported. + +Standalone PySpark jobs should be run using the `run-pyspark` script, which automatically configures the Java and Python environmnt using the settings in `conf/spark-env.sh`. +The script automatically adds the `pyspark` package to the `PYTHONPATH`. + + +# Interactive Use + +PySpark's `pyspark-shell` script provides a simple way to learn the API: + +{% highlight python %} +>>> words = sc.textFile("/usr/share/dict/words") +>>> words.filter(lambda w: w.startswith("spar")).take(5) +[u'spar', u'sparable', u'sparada', u'sparadrap', u'sparagrass'] +{% endhighlight %} + +# Standalone Use + +PySpark can also be used from standalone Python scripts by creating a SparkContext in the script and running the script using the `run-pyspark` script in the `pyspark` directory. +The Quick Start guide includes a [complete example](quick-start.html#a-standalone-job-in-python) of a standalone Python job. + +Code dependencies can be deployed by listing them in the `pyFiles` option in the SparkContext constructor: + +{% highlight python %} +from pyspark import SparkContext +sc = SparkContext("local", "Job Name", pyFiles=['MyFile.py', 'lib.zip', 'app.egg']) +{% endhighlight %} + +Files listed here will be added to the `PYTHONPATH` and shipped to remote worker machines. +Code dependencies can be added to an existing SparkContext using its `addPyFile()` method. + +# Where to Go from Here + +PySpark includes several sample programs using the Python API in `pyspark/examples`. +You can run them by passing the files to the `pyspark-run` script included in PySpark -- for example `./pyspark-run examples/wordcount.py`. +Each example program prints usage help when run without any arguments. + +We currently provide [API documentation](api/pyspark/index.html) for the Python API as Epydoc. +Many of the RDD method descriptions contain [doctests](http://docs.python.org/2/library/doctest.html) that provide additional usage examples. diff --git a/docs/quick-start.md b/docs/quick-start.md index defdb34836..c859c31b09 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -6,7 +6,8 @@ title: Quick Start * This will become a table of contents (this text will be scraped). {:toc} -This tutorial provides a quick introduction to using Spark. We will first introduce the API through Spark's interactive Scala shell (don't worry if you don't know Scala -- you will need much for this), then show how to write standalone jobs in Scala and Java. 
See the [programming guide](scala-programming-guide.html) for a fuller reference. +This tutorial provides a quick introduction to using Spark. We will first introduce the API through Spark's interactive Scala shell (don't worry if you don't know Scala -- you will need much for this), then show how to write standalone jobs in Scala, Java, and Python. +See the [programming guide](scala-programming-guide.html) for a more complete reference. To follow along with this guide, you only need to have successfully built Spark on one machine. Simply go into your Spark directory and run: @@ -230,3 +231,40 @@ Lines with a: 8422, Lines with b: 1836 {% endhighlight %} This example only runs the job locally; for a tutorial on running jobs across several machines, see the [Standalone Mode](spark-standalone.html) documentation, and consider using a distributed input source, such as HDFS. + +# A Standalone Job In Python +Now we will show how to write a standalone job using the Python API (PySpark). + +As an example, we'll create a simple Spark job, `SimpleJob.py`: + +{% highlight python %} +"""SimpleJob.py""" +from pyspark import SparkContext + +logFile = "/var/log/syslog" # Should be some file on your system +sc = SparkContext("local", "Simple job") +logData = sc.textFile(logFile).cache() + +numAs = logData.filter(lambda s: 'a' in s).count() +numBs = logData.filter(lambda s: 'b' in s).count() + +print "Lines with a: %i, lines with b: %i" % (numAs, numBs) +{% endhighlight %} + + +This job simply counts the number of lines containing 'a' and the number containing 'b' in a system log file. +Like in the Scala and Java examples, we use a SparkContext to create RDDs. +We can pass Python functions to Spark, which are automatically serialized along with any variables that they reference. +For jobs that use custom classes or third-party libraries, we can add those code dependencies to SparkContext to ensure that they will be available on remote machines; this is described in more detail in the [Python programming guide](python-programming-guide). +`SimpleJob` is simple enough that we do not need to specify any code dependencies. + +We can run this job using the `run-pyspark` script in `$SPARK_HOME/pyspark`: + +{% highlight python %} +$ cd $SPARK_HOME +$ ./pyspark/run-pyspark SimpleJob.py +... +Lines with a: 8422, Lines with b: 1836 +{% endhighlight python %} + +This example only runs the job locally; for a tutorial on running jobs across several machines, see the [Standalone Mode](spark-standalone.html) documentation, and consider using a distributed input source, such as HDFS. diff --git a/pyspark/README b/pyspark/README deleted file mode 100644 index d8d521c72c..0000000000 --- a/pyspark/README +++ /dev/null @@ -1,42 +0,0 @@ -# PySpark - -PySpark is a Python API for Spark. - -PySpark jobs are writen in Python and executed using a standard Python -interpreter; this supports modules that use Python C extensions. The -API is based on the Spark Scala API and uses regular Python functions -and lambdas to support user-defined functions. PySpark supports -interactive use through a standard Python interpreter; it can -automatically serialize closures and ship them to worker processes. - -PySpark is built on top of the Spark Java API. Data is uniformly -represented as serialized Python objects and stored in Spark Java -processes, which communicate with PySpark worker processes over pipes. - -## Features - -PySpark supports most of the Spark API, including broadcast variables. 
-RDDs are dynamically typed and can hold any Python object. - -PySpark does not support: - -- Special functions on RDDs of doubles -- Accumulators - -## Examples and Documentation - -The PySpark source contains docstrings and doctests that document its -API. The public classes are in `context.py` and `rdd.py`. - -The `pyspark/pyspark/examples` directory contains a few complete -examples. - -## Installing PySpark -# -To use PySpark, `SPARK_HOME` should be set to the location of the Spark -package. - -## Running PySpark - -The easiest way to run PySpark is to use the `run-pyspark` and -`pyspark-shell` scripts, which are included in the `pyspark` directory. diff --git a/pyspark/pyspark/examples/kmeans.py b/pyspark/examples/kmeans.py similarity index 100% rename from pyspark/pyspark/examples/kmeans.py rename to pyspark/examples/kmeans.py diff --git a/pyspark/pyspark/examples/pi.py b/pyspark/examples/pi.py similarity index 100% rename from pyspark/pyspark/examples/pi.py rename to pyspark/examples/pi.py diff --git a/pyspark/pyspark/examples/tc.py b/pyspark/examples/tc.py similarity index 100% rename from pyspark/pyspark/examples/tc.py rename to pyspark/examples/tc.py diff --git a/pyspark/pyspark/examples/wordcount.py b/pyspark/examples/wordcount.py similarity index 100% rename from pyspark/pyspark/examples/wordcount.py rename to pyspark/examples/wordcount.py diff --git a/pyspark/pyspark/__init__.py b/pyspark/pyspark/__init__.py index 549c2d2711..8f8402b62b 100644 --- a/pyspark/pyspark/__init__.py +++ b/pyspark/pyspark/__init__.py @@ -1,3 +1,9 @@ import sys import os sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "pyspark/lib/py4j0.7.egg")) + + +from pyspark.context import SparkContext + + +__all__ = ["SparkContext"] diff --git a/pyspark/pyspark/examples/__init__.py b/pyspark/pyspark/examples/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 From 6ee1ff2663cf1f776dd33e448548a8ddcf974dc6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 29 Dec 2012 22:22:56 +0000 Subject: [PATCH 036/291] Fix bug in pyspark.serializers.batch; add .gitignore. 
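For context on the one-character fix below: the previous patch reset the batch counter to a list (count = []) instead of zero, so batching anything longer than a single batch failed with a TypeError as soon as count += 1 ran again. A minimal repro, independent of Spark:

    count = 0
    count += 1        # fine: counting with an int
    count = []        # the buggy reset
    try:
        count += 1    # list += int raises TypeError
    except TypeError as e:
        print "broken batch counter:", e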
--- pyspark/.gitignore | 2 ++ pyspark/pyspark/rdd.py | 4 +++- pyspark/pyspark/serializers.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) create mode 100644 pyspark/.gitignore diff --git a/pyspark/.gitignore b/pyspark/.gitignore new file mode 100644 index 0000000000..5c56e638f9 --- /dev/null +++ b/pyspark/.gitignore @@ -0,0 +1,2 @@ +*.pyc +docs/ diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 111476d274..20f84b2dd0 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -695,7 +695,9 @@ def _test(): import doctest from pyspark.context import SparkContext globs = globals().copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest') + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) doctest.testmod(globs=globs) globs['sc'].stop() diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py index 4ed925697c..8b08f7ef0f 100644 --- a/pyspark/pyspark/serializers.py +++ b/pyspark/pyspark/serializers.py @@ -24,7 +24,7 @@ def batched(iterator, batchSize): if count == batchSize: yield Batch(items) items = [] - count = [] + count = 0 if items: yield Batch(items) From 26186e2d259f3aa2db9c8594097fd342107ce147 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 29 Dec 2012 15:34:57 -0800 Subject: [PATCH 037/291] Use batching in pyspark parallelize(); fix cartesian() --- pyspark/pyspark/context.py | 4 +++- pyspark/pyspark/rdd.py | 31 +++++++++++++++---------------- pyspark/pyspark/serializers.py | 23 +++++++++++++---------- 3 files changed, 31 insertions(+), 27 deletions(-) diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index b90596ecc2..6172d69dcf 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -4,7 +4,7 @@ from tempfile import NamedTemporaryFile from pyspark.broadcast import Broadcast from pyspark.java_gateway import launch_gateway -from pyspark.serializers import dump_pickle, write_with_length +from pyspark.serializers import dump_pickle, write_with_length, batched from pyspark.rdd import RDD from py4j.java_collections import ListConverter @@ -91,6 +91,8 @@ class SparkContext(object): # objects are written to a file and loaded through textFile(). 
tempFile = NamedTemporaryFile(delete=False) atexit.register(lambda: os.unlink(tempFile.name)) + if self.batchSize != 1: + c = batched(c, self.batchSize) for x in c: write_with_length(dump_pickle(x), tempFile) tempFile.close() diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 20f84b2dd0..203f7377d2 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -2,7 +2,7 @@ import atexit from base64 import standard_b64encode as b64enc import copy from collections import defaultdict -from itertools import chain, ifilter, imap +from itertools import chain, ifilter, imap, product import operator import os import shlex @@ -123,12 +123,6 @@ class RDD(object): >>> rdd = sc.parallelize([1, 1, 2, 3]) >>> rdd.union(rdd).collect() [1, 1, 2, 3, 1, 1, 2, 3] - - Union of batched and unbatched RDDs (internal test): - - >>> batchedRDD = sc.parallelize([Batch([1, 2, 3, 4, 5])]) - >>> rdd.union(batchedRDD).collect() - [1, 1, 2, 3, 1, 2, 3, 4, 5] """ return RDD(self._jrdd.union(other._jrdd), self.ctx) @@ -168,7 +162,18 @@ class RDD(object): >>> sorted(rdd.cartesian(rdd).collect()) [(1, 1), (1, 2), (2, 1), (2, 2)] """ - return RDD(self._jrdd.cartesian(other._jrdd), self.ctx) + # Due to batching, we can't use the Java cartesian method. + java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx) + def unpack_batches(pair): + (x, y) = pair + if type(x) == Batch or type(y) == Batch: + xs = x.items if type(x) == Batch else [x] + ys = y.items if type(y) == Batch else [y] + for pair in product(xs, ys): + yield pair + else: + yield pair + return java_cartesian.flatMap(unpack_batches) def groupBy(self, f, numSplits=None): """ @@ -293,8 +298,6 @@ class RDD(object): >>> sc.parallelize([2, 3, 4]).count() 3 - >>> sc.parallelize([Batch([2, 3, 4])]).count() - 3 """ return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum() @@ -667,12 +670,8 @@ class PipelinedRDD(RDD): if not self._bypass_serializer and self.ctx.batchSize != 1: oldfunc = self.func batchSize = self.ctx.batchSize - if batchSize == -1: # unlimited batch size - def batched_func(iterator): - yield Batch(list(oldfunc(iterator))) - else: - def batched_func(iterator): - return batched(oldfunc(iterator), batchSize) + def batched_func(iterator): + return batched(oldfunc(iterator), batchSize) func = batched_func cmds = [func, self._bypass_serializer] pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py index 8b08f7ef0f..9a5151ea00 100644 --- a/pyspark/pyspark/serializers.py +++ b/pyspark/pyspark/serializers.py @@ -16,17 +16,20 @@ class Batch(object): def batched(iterator, batchSize): - items = [] - count = 0 - for item in iterator: - items.append(item) - count += 1 - if count == batchSize: + if batchSize == -1: # unlimited batch size + yield Batch(list(iterator)) + else: + items = [] + count = 0 + for item in iterator: + items.append(item) + count += 1 + if count == batchSize: + yield Batch(items) + items = [] + count = 0 + if items: yield Batch(items) - items = [] - count = 0 - if items: - yield Batch(items) def dump_pickle(obj): From 59195c68ec37acf20d527189ed757397b273a207 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 29 Dec 2012 16:01:03 -0800 Subject: [PATCH 038/291] Update PySpark for compatibility with TaskContext. 
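On the Python side the only fallout is in take(), which iterates partitions on the driver and now has to pass the JVM a placeholder TaskContext. Stripped of the Py4J plumbing, the loop it wraps is the usual "consume whole splits until enough items are collected" pattern; the sketch below is pure Python for illustration, not the actual RDD code.

    def take_from_partitions(partitions, num):
        # Consume whole splits, in order, until at least `num` items are in hand.
        items = []
        remaining = list(partitions)
        while len(items) < num and remaining:
            items.extend(remaining.pop(0))
        return items[:num]

    print take_from_partitions([[1, 2], [3, 4], [5]], 3)   # [1, 2, 3]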
--- .../src/main/scala/spark/api/python/PythonRDD.scala | 13 +++++-------- pyspark/pyspark/rdd.py | 3 ++- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index f76616a4c4..dc48378fdc 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -8,10 +8,7 @@ import scala.io.Source import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import spark.broadcast.Broadcast -import spark.SparkEnv -import spark.Split -import spark.RDD -import spark.OneToOneDependency +import spark._ import spark.rdd.PipedRDD @@ -34,7 +31,7 @@ private[spark] class PythonRDD[T: ClassManifest]( override val partitioner = if (preservePartitoning) parent.partitioner else None - override def compute(split: Split): Iterator[Array[Byte]] = { + override def compute(split: Split, context: TaskContext): Iterator[Array[Byte]] = { val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME") val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py")) @@ -74,7 +71,7 @@ private[spark] class PythonRDD[T: ClassManifest]( out.println(elem) } out.flush() - for (elem <- parent.iterator(split)) { + for (elem <- parent.iterator(split, context)) { PythonRDD.writeAsPickle(elem, dOut) } dOut.flush() @@ -123,8 +120,8 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Array[Byte], Array[Byte])](prev.context) { override def splits = prev.splits override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = - prev.iterator(split).grouped(2).map { + override def compute(split: Split, context: TaskContext) = + prev.iterator(split, context).grouped(2).map { case Seq(a, b) => (a, b) case x => throw new Exception("PairwiseRDD: unexpected value: " + x) } diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 203f7377d2..21dda31c4e 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -335,9 +335,10 @@ class RDD(object): """ items = [] splits = self._jrdd.splits() + taskContext = self.ctx.jvm.spark.TaskContext(0, 0, 0) while len(items) < num and splits: split = splits.pop(0) - iterator = self._jrdd.iterator(split) + iterator = self._jrdd.iterator(split, taskContext) items.extend(self._collect_iterator_through_file(iterator)) return items[:num] From 39dd953fd88e9aa7335603ab452d9c1bed4ba67a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 29 Dec 2012 17:06:50 -0800 Subject: [PATCH 039/291] Add test for pyspark.RDD.saveAsTextFile(). --- pyspark/pyspark/rdd.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py index 21dda31c4e..cbffb6cc1f 100644 --- a/pyspark/pyspark/rdd.py +++ b/pyspark/pyspark/rdd.py @@ -351,10 +351,17 @@ class RDD(object): """ return self.take(1)[0] - # TODO: add test and fix for use with Batch def saveAsTextFile(self, path): """ Save this RDD as a text file, using string representations of elements. 
+ + >>> tempFile = NamedTemporaryFile(delete=True) + >>> tempFile.close() + >>> sc.parallelize(range(10)).saveAsTextFile(tempFile.name) + >>> from fileinput import input + >>> from glob import glob + >>> ''.join(input(glob(tempFile.name + "/part-0000*"))) + '0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n' """ def func(iterator): return (str(x).encode("utf-8") for x in iterator) From 099898b43955d99351ec94d4a373de854bf7edf7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 29 Dec 2012 17:52:47 -0800 Subject: [PATCH 040/291] Port LR example to PySpark using numpy. This version of the example crashes after the first iteration with "OverflowError: math range error" because Python's math.exp() behaves differently than Scala's; see SPARK-646. --- pyspark/examples/lr.py | 57 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100755 pyspark/examples/lr.py diff --git a/pyspark/examples/lr.py b/pyspark/examples/lr.py new file mode 100755 index 0000000000..5fca0266b8 --- /dev/null +++ b/pyspark/examples/lr.py @@ -0,0 +1,57 @@ +""" +This example requires numpy (http://www.numpy.org/) +""" +from collections import namedtuple +from math import exp +from os.path import realpath +import sys + +import numpy as np +from pyspark.context import SparkContext + + +N = 100000 # Number of data points +D = 10 # Number of dimensions +R = 0.7 # Scaling factor +ITERATIONS = 5 +np.random.seed(42) + + +DataPoint = namedtuple("DataPoint", ['x', 'y']) +from lr import DataPoint # So that DataPoint is properly serialized + + +def generateData(): + def generatePoint(i): + y = -1 if i % 2 == 0 else 1 + x = np.random.normal(size=D) + (y * R) + return DataPoint(x, y) + return [generatePoint(i) for i in range(N)] + + +if __name__ == "__main__": + if len(sys.argv) == 1: + print >> sys.stderr, \ + "Usage: PythonLR []" + exit(-1) + sc = SparkContext(sys.argv[1], "PythonLR", pyFiles=[realpath(__file__)]) + slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 + points = sc.parallelize(generateData(), slices).cache() + + # Initialize w to a random value + w = 2 * np.random.ranf(size=D) - 1 + print "Initial w: " + str(w) + + def add(x, y): + x += y + return x + + for i in range(1, ITERATIONS + 1): + print "On iteration %i" % i + + gradient = points.map(lambda p: + (1.0 / (1.0 + exp(-p.y * np.dot(w, p.x)))) * p.y * p.x + ).reduce(add) + w -= gradient + + print "Final w: " + str(w) From 6f6a6b79c4c3f3555f8ff427c91e714d02afe8fa Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 31 Dec 2012 14:56:23 -0800 Subject: [PATCH 041/291] Launch with `scala` by default in run-pyspark --- pyspark/run-pyspark | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pyspark/run-pyspark b/pyspark/run-pyspark index f8039b8038..4d10fbea8b 100755 --- a/pyspark/run-pyspark +++ b/pyspark/run-pyspark @@ -20,4 +20,9 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH=$SPARK_HOME/pyspark/:$PYTHONPATH +# Launch with `scala` by default: +if [[ "$SPARK_LAUNCH_WITH_SCALA" != "0" ]] ; then + export SPARK_LAUNCH_WITH_SCALA=1 +fi + exec "$PYSPARK_PYTHON" "$@" From 170e451fbdd308ae77065bd9c0f2bd278abf0cb7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 1 Jan 2013 13:52:14 -0800 Subject: [PATCH 042/291] Minor documentation and style fixes for PySpark. 
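Among the cleanups below are the hand-rolled Pickle opcode constants used by writeAsPickle. As an illustrative sanity check (not part of the patch), the byte sequence that method emits for a string, after its outer length header, is an ordinary protocol-2 pickle that cPickle reads back directly:

    import cPickle
    import struct

    # PROTO 2, BINUNICODE, little-endian 4-byte length, UTF-8 bytes, STOP
    s = u"spark".encode("utf-8")
    payload = "\x80\x02" + "X" + struct.pack("<i", len(s)) + s + "."
    print repr(cPickle.loads(payload))   # u'spark'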
--- .../spark/api/python/PythonPartitioner.scala | 4 +- .../scala/spark/api/python/PythonRDD.scala | 43 +++++++++++++------ docs/index.md | 8 +++- docs/python-programming-guide.md | 3 +- pyspark/examples/kmeans.py | 13 +++--- .../{lr.py => logistic_regression.py} | 4 +- pyspark/examples/pi.py | 5 ++- .../examples/{tc.py => transitive_closure.py} | 5 ++- pyspark/examples/wordcount.py | 4 +- pyspark/pyspark/__init__.py | 13 +++++- 10 files changed, 70 insertions(+), 32 deletions(-) rename pyspark/examples/{lr.py => logistic_regression.py} (93%) rename pyspark/examples/{tc.py => transitive_closure.py} (94%) diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala index 2c829508e5..648d9402b0 100644 --- a/core/src/main/scala/spark/api/python/PythonPartitioner.scala +++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala @@ -17,9 +17,9 @@ private[spark] class PythonPartitioner(override val numPartitions: Int) extends val hashCode = { if (key.isInstanceOf[Array[Byte]]) { Arrays.hashCode(key.asInstanceOf[Array[Byte]]) - } - else + } else { key.hashCode() + } } val mod = hashCode % numPartitions if (mod < 0) { diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index dc48378fdc..19a039e330 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -13,8 +13,12 @@ import spark.rdd.PipedRDD private[spark] class PythonRDD[T: ClassManifest]( - parent: RDD[T], command: Seq[String], envVars: java.util.Map[String, String], - preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) + parent: RDD[T], + command: Seq[String], + envVars: java.util.Map[String, String], + preservePartitoning: Boolean, + pythonExec: String, + broadcastVars: java.util.List[Broadcast[Array[Byte]]]) extends RDD[Array[Byte]](parent.context) { // Similar to Runtime.exec(), if we are given a single string, split it into words @@ -38,8 +42,8 @@ private[spark] class PythonRDD[T: ClassManifest]( // Add the environmental variables to the process. val currentEnvVars = pb.environment() - envVars.foreach { - case (variable, value) => currentEnvVars.put(variable, value) + for ((variable, value) <- envVars) { + currentEnvVars.put(variable, value) } val proc = pb.start() @@ -116,6 +120,10 @@ private[spark] class PythonRDD[T: ClassManifest]( val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } +/** + * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. + * This is used by PySpark's shuffle operations. + */ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Array[Byte], Array[Byte])](prev.context) { override def splits = prev.splits @@ -139,6 +147,16 @@ private[spark] object PythonRDD { * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream. * The data format is a 32-bit integer representing the pickled object's length (in bytes), * followed by the pickled data. 
+ * + * Pickle module: + * + * http://docs.python.org/2/library/pickle.html + * + * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules: + * + * http://hg.python.org/cpython/file/2.6/Lib/pickle.py + * http://hg.python.org/cpython/file/2.6/Lib/pickletools.py + * * @param elem the object to write * @param dOut a data output stream */ @@ -201,15 +219,14 @@ private[spark] object PythonRDD { } private object Pickle { - def b(x: Int): Byte = x.asInstanceOf[Byte] - val PROTO: Byte = b(0x80) - val TWO: Byte = b(0x02) - val BINUNICODE : Byte = 'X' - val STOP : Byte = '.' - val TUPLE2 : Byte = b(0x86) - val EMPTY_LIST : Byte = ']' - val MARK : Byte = '(' - val APPENDS : Byte = 'e' + val PROTO: Byte = 0x80.toByte + val TWO: Byte = 0x02.toByte + val BINUNICODE: Byte = 'X' + val STOP: Byte = '.' + val TUPLE2: Byte = 0x86.toByte + val EMPTY_LIST: Byte = ']' + val MARK: Byte = '(' + val APPENDS: Byte = 'e' } private class ExtractValue extends spark.api.java.function.Function[(Array[Byte], diff --git a/docs/index.md b/docs/index.md index 33ab58a962..848b585333 100644 --- a/docs/index.md +++ b/docs/index.md @@ -8,7 +8,7 @@ TODO(andyk): Rewrite to make the Java API a first class part of the story. {% endcomment %} Spark is a MapReduce-like cluster computing framework designed for low-latency iterative jobs and interactive use from an interpreter. -It provides clean, language-integrated APIs in Scala, Java, and Python, with a rich array of parallel operators. +It provides clean, language-integrated APIs in [Scala](scala-programming-guide.html), [Java](java-programming-guide.html), and [Python](python-programming-guide.html), with a rich array of parallel operators. Spark can run on top of the [Apache Mesos](http://incubator.apache.org/mesos/) cluster manager, [Hadoop YARN](http://hadoop.apache.org/docs/r2.0.1-alpha/hadoop-yarn/hadoop-yarn-site/YARN.html), Amazon EC2, or without an independent resource manager ("standalone mode"). @@ -61,6 +61,11 @@ of `project/SparkBuild.scala`, then rebuilding Spark (`sbt/sbt clean compile`). * [Java Programming Guide](java-programming-guide.html): using Spark from Java * [Python Programming Guide](python-programming-guide.html): using Spark from Python +**API Docs:** + +* [Java/Scala (Scaladoc)](api/core/index.html) +* [Python (Epydoc)](api/pyspark/index.html) + **Deployment guides:** * [Running Spark on Amazon EC2](ec2-scripts.html): scripts that let you launch a cluster on EC2 in about 5 minutes @@ -73,7 +78,6 @@ of `project/SparkBuild.scala`, then rebuilding Spark (`sbt/sbt clean compile`). * [Configuration](configuration.html): customize Spark via its configuration system * [Tuning Guide](tuning.html): best practices to optimize performance and memory use -* API Docs: [Java/Scala (Scaladoc)](api/core/index.html) and [Python (Epydoc)](api/pyspark/index.html) * [Bagel](bagel-programming-guide.html): an implementation of Google's Pregel on Spark * [Contributing to Spark](contributing-to-spark.html) diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index b7c747f905..d88d4eb42d 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -17,8 +17,7 @@ There are a few key differences between the Python and Scala APIs: * Python is dynamically typed, so RDDs can hold objects of different types. 
* PySpark does not currently support the following Spark features: - Accumulators - - Special functions on RRDs of doubles, such as `mean` and `stdev` - - Approximate jobs / functions, such as `countApprox` and `sumApprox`. + - Special functions on RDDs of doubles, such as `mean` and `stdev` - `lookup` - `mapPartitionsWithSplit` - `persist` at storage levels other than `MEMORY_ONLY` diff --git a/pyspark/examples/kmeans.py b/pyspark/examples/kmeans.py index 9cc366f03c..ad2be21178 100644 --- a/pyspark/examples/kmeans.py +++ b/pyspark/examples/kmeans.py @@ -1,18 +1,21 @@ +""" +This example requires numpy (http://www.numpy.org/) +""" import sys -from pyspark.context import SparkContext -from numpy import array, sum as np_sum +import numpy as np +from pyspark import SparkContext def parseVector(line): - return array([float(x) for x in line.split(' ')]) + return np.array([float(x) for x in line.split(' ')]) def closestPoint(p, centers): bestIndex = 0 closest = float("+inf") for i in range(len(centers)): - tempDist = np_sum((p - centers[i]) ** 2) + tempDist = np.sum((p - centers[i]) ** 2) if tempDist < closest: closest = tempDist bestIndex = i @@ -41,7 +44,7 @@ if __name__ == "__main__": newPoints = pointStats.map( lambda (x, (y, z)): (x, y / z)).collect() - tempDist = sum(np_sum((kPoints[x] - y) ** 2) for (x, y) in newPoints) + tempDist = sum(np.sum((kPoints[x] - y) ** 2) for (x, y) in newPoints) for (x, y) in newPoints: kPoints[x] = y diff --git a/pyspark/examples/lr.py b/pyspark/examples/logistic_regression.py similarity index 93% rename from pyspark/examples/lr.py rename to pyspark/examples/logistic_regression.py index 5fca0266b8..f13698a86f 100755 --- a/pyspark/examples/lr.py +++ b/pyspark/examples/logistic_regression.py @@ -7,7 +7,7 @@ from os.path import realpath import sys import numpy as np -from pyspark.context import SparkContext +from pyspark import SparkContext N = 100000 # Number of data points @@ -32,7 +32,7 @@ def generateData(): if __name__ == "__main__": if len(sys.argv) == 1: print >> sys.stderr, \ - "Usage: PythonLR []" + "Usage: PythonLR []" exit(-1) sc = SparkContext(sys.argv[1], "PythonLR", pyFiles=[realpath(__file__)]) slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 diff --git a/pyspark/examples/pi.py b/pyspark/examples/pi.py index 348bbc5dce..127cba029b 100644 --- a/pyspark/examples/pi.py +++ b/pyspark/examples/pi.py @@ -1,13 +1,14 @@ import sys from random import random from operator import add -from pyspark.context import SparkContext + +from pyspark import SparkContext if __name__ == "__main__": if len(sys.argv) == 1: print >> sys.stderr, \ - "Usage: PythonPi []" + "Usage: PythonPi []" exit(-1) sc = SparkContext(sys.argv[1], "PythonPi") slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2 diff --git a/pyspark/examples/tc.py b/pyspark/examples/transitive_closure.py similarity index 94% rename from pyspark/examples/tc.py rename to pyspark/examples/transitive_closure.py index 9630e72b47..73f7f8fbaf 100644 --- a/pyspark/examples/tc.py +++ b/pyspark/examples/transitive_closure.py @@ -1,6 +1,7 @@ import sys from random import Random -from pyspark.context import SparkContext + +from pyspark import SparkContext numEdges = 200 numVertices = 100 @@ -20,7 +21,7 @@ def generateGraph(): if __name__ == "__main__": if len(sys.argv) == 1: print >> sys.stderr, \ - "Usage: PythonTC []" + "Usage: PythonTC []" exit(-1) sc = SparkContext(sys.argv[1], "PythonTC") slices = sys.argv[2] if len(sys.argv) > 2 else 2 diff --git a/pyspark/examples/wordcount.py 
b/pyspark/examples/wordcount.py index 8365c070e8..857160624b 100644 --- a/pyspark/examples/wordcount.py +++ b/pyspark/examples/wordcount.py @@ -1,6 +1,8 @@ import sys from operator import add -from pyspark.context import SparkContext + +from pyspark import SparkContext + if __name__ == "__main__": if len(sys.argv) < 3: diff --git a/pyspark/pyspark/__init__.py b/pyspark/pyspark/__init__.py index 8f8402b62b..1ab360a666 100644 --- a/pyspark/pyspark/__init__.py +++ b/pyspark/pyspark/__init__.py @@ -1,9 +1,20 @@ +""" +PySpark is a Python API for Spark. + +Public classes: + + - L{SparkContext} + Main entry point for Spark functionality. + - L{RDD} + A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. +""" import sys import os sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "pyspark/lib/py4j0.7.egg")) from pyspark.context import SparkContext +from pyspark.rdd import RDD -__all__ = ["SparkContext"] +__all__ = ["SparkContext", "RDD"] From b58340dbd9a741331fc4c3829b08c093560056c2 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 1 Jan 2013 14:48:45 -0800 Subject: [PATCH 043/291] Rename top-level 'pyspark' directory to 'python' --- .../src/main/scala/spark/api/python/PythonRDD.scala | 2 +- docs/_plugins/copy_api_dirs.rb | 8 ++++---- pyspark-shell | 3 +++ pyspark/pyspark-shell | 3 --- {pyspark => python}/.gitignore | 0 {pyspark => python}/epydoc.conf | 0 {pyspark => python}/examples/kmeans.py | 0 {pyspark => python}/examples/logistic_regression.py | 0 {pyspark => python}/examples/pi.py | 0 {pyspark => python}/examples/transitive_closure.py | 0 {pyspark => python}/examples/wordcount.py | 0 {pyspark => python}/lib/PY4J_LICENSE.txt | 0 {pyspark => python}/lib/PY4J_VERSION.txt | 0 {pyspark => python}/lib/py4j0.7.egg | Bin {pyspark => python}/lib/py4j0.7.jar | Bin {pyspark => python}/pyspark/__init__.py | 2 +- {pyspark => python}/pyspark/broadcast.py | 0 {pyspark => python}/pyspark/cloudpickle.py | 0 {pyspark => python}/pyspark/context.py | 0 {pyspark => python}/pyspark/java_gateway.py | 0 {pyspark => python}/pyspark/join.py | 0 {pyspark => python}/pyspark/rdd.py | 0 {pyspark => python}/pyspark/serializers.py | 0 {pyspark => python}/pyspark/shell.py | 0 {pyspark => python}/pyspark/worker.py | 0 run | 2 +- pyspark/run-pyspark => run-pyspark | 4 ++-- run2.cmd | 2 +- 28 files changed, 13 insertions(+), 13 deletions(-) create mode 100755 pyspark-shell delete mode 100755 pyspark/pyspark-shell rename {pyspark => python}/.gitignore (100%) rename {pyspark => python}/epydoc.conf (100%) rename {pyspark => python}/examples/kmeans.py (100%) rename {pyspark => python}/examples/logistic_regression.py (100%) rename {pyspark => python}/examples/pi.py (100%) rename {pyspark => python}/examples/transitive_closure.py (100%) rename {pyspark => python}/examples/wordcount.py (100%) rename {pyspark => python}/lib/PY4J_LICENSE.txt (100%) rename {pyspark => python}/lib/PY4J_VERSION.txt (100%) rename {pyspark => python}/lib/py4j0.7.egg (100%) rename {pyspark => python}/lib/py4j0.7.jar (100%) rename {pyspark => python}/pyspark/__init__.py (82%) rename {pyspark => python}/pyspark/broadcast.py (100%) rename {pyspark => python}/pyspark/cloudpickle.py (100%) rename {pyspark => python}/pyspark/context.py (100%) rename {pyspark => python}/pyspark/java_gateway.py (100%) rename {pyspark => python}/pyspark/join.py (100%) rename {pyspark => python}/pyspark/rdd.py (100%) rename {pyspark => python}/pyspark/serializers.py (100%) rename {pyspark => python}/pyspark/shell.py (100%) rename {pyspark => 
python}/pyspark/worker.py (100%) rename pyspark/run-pyspark => run-pyspark (86%) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 19a039e330..cf60d14f03 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -38,7 +38,7 @@ private[spark] class PythonRDD[T: ClassManifest]( override def compute(split: Split, context: TaskContext): Iterator[Array[Byte]] = { val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME") - val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py")) + val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py")) // Add the environmental variables to the process. val currentEnvVars = pb.environment() diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 577f3ebe70..c9ce589c1b 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -30,8 +30,8 @@ if ENV['SKIP_SCALADOC'] != '1' end if ENV['SKIP_EPYDOC'] != '1' - puts "Moving to pyspark directory and building epydoc." - cd("../pyspark") + puts "Moving to python directory and building epydoc." + cd("../python") puts `epydoc --config epydoc.conf` puts "Moving back into docs dir." @@ -40,8 +40,8 @@ if ENV['SKIP_EPYDOC'] != '1' puts "echo making directory pyspark" mkdir_p "pyspark" - puts "cp -r ../pyspark/docs/. api/pyspark" - cp_r("../pyspark/docs/.", "api/pyspark") + puts "cp -r ../python/docs/. api/pyspark" + cp_r("../python/docs/.", "api/pyspark") cd("..") end diff --git a/pyspark-shell b/pyspark-shell new file mode 100755 index 0000000000..27aaac3a26 --- /dev/null +++ b/pyspark-shell @@ -0,0 +1,3 @@ +#!/usr/bin/env bash +FWDIR="`dirname $0`" +exec $FWDIR/run-pyspark $FWDIR/python/pyspark/shell.py "$@" diff --git a/pyspark/pyspark-shell b/pyspark/pyspark-shell deleted file mode 100755 index e3736826e8..0000000000 --- a/pyspark/pyspark-shell +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/env bash -FWDIR="`dirname $0`" -exec $FWDIR/run-pyspark $FWDIR/pyspark/shell.py "$@" diff --git a/pyspark/.gitignore b/python/.gitignore similarity index 100% rename from pyspark/.gitignore rename to python/.gitignore diff --git a/pyspark/epydoc.conf b/python/epydoc.conf similarity index 100% rename from pyspark/epydoc.conf rename to python/epydoc.conf diff --git a/pyspark/examples/kmeans.py b/python/examples/kmeans.py similarity index 100% rename from pyspark/examples/kmeans.py rename to python/examples/kmeans.py diff --git a/pyspark/examples/logistic_regression.py b/python/examples/logistic_regression.py similarity index 100% rename from pyspark/examples/logistic_regression.py rename to python/examples/logistic_regression.py diff --git a/pyspark/examples/pi.py b/python/examples/pi.py similarity index 100% rename from pyspark/examples/pi.py rename to python/examples/pi.py diff --git a/pyspark/examples/transitive_closure.py b/python/examples/transitive_closure.py similarity index 100% rename from pyspark/examples/transitive_closure.py rename to python/examples/transitive_closure.py diff --git a/pyspark/examples/wordcount.py b/python/examples/wordcount.py similarity index 100% rename from pyspark/examples/wordcount.py rename to python/examples/wordcount.py diff --git a/pyspark/lib/PY4J_LICENSE.txt b/python/lib/PY4J_LICENSE.txt similarity index 100% rename from pyspark/lib/PY4J_LICENSE.txt rename to python/lib/PY4J_LICENSE.txt diff --git a/pyspark/lib/PY4J_VERSION.txt 
b/python/lib/PY4J_VERSION.txt similarity index 100% rename from pyspark/lib/PY4J_VERSION.txt rename to python/lib/PY4J_VERSION.txt diff --git a/pyspark/lib/py4j0.7.egg b/python/lib/py4j0.7.egg similarity index 100% rename from pyspark/lib/py4j0.7.egg rename to python/lib/py4j0.7.egg diff --git a/pyspark/lib/py4j0.7.jar b/python/lib/py4j0.7.jar similarity index 100% rename from pyspark/lib/py4j0.7.jar rename to python/lib/py4j0.7.jar diff --git a/pyspark/pyspark/__init__.py b/python/pyspark/__init__.py similarity index 82% rename from pyspark/pyspark/__init__.py rename to python/pyspark/__init__.py index 1ab360a666..c595ae0842 100644 --- a/pyspark/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -10,7 +10,7 @@ Public classes: """ import sys import os -sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "pyspark/lib/py4j0.7.egg")) +sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.egg")) from pyspark.context import SparkContext diff --git a/pyspark/pyspark/broadcast.py b/python/pyspark/broadcast.py similarity index 100% rename from pyspark/pyspark/broadcast.py rename to python/pyspark/broadcast.py diff --git a/pyspark/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py similarity index 100% rename from pyspark/pyspark/cloudpickle.py rename to python/pyspark/cloudpickle.py diff --git a/pyspark/pyspark/context.py b/python/pyspark/context.py similarity index 100% rename from pyspark/pyspark/context.py rename to python/pyspark/context.py diff --git a/pyspark/pyspark/java_gateway.py b/python/pyspark/java_gateway.py similarity index 100% rename from pyspark/pyspark/java_gateway.py rename to python/pyspark/java_gateway.py diff --git a/pyspark/pyspark/join.py b/python/pyspark/join.py similarity index 100% rename from pyspark/pyspark/join.py rename to python/pyspark/join.py diff --git a/pyspark/pyspark/rdd.py b/python/pyspark/rdd.py similarity index 100% rename from pyspark/pyspark/rdd.py rename to python/pyspark/rdd.py diff --git a/pyspark/pyspark/serializers.py b/python/pyspark/serializers.py similarity index 100% rename from pyspark/pyspark/serializers.py rename to python/pyspark/serializers.py diff --git a/pyspark/pyspark/shell.py b/python/pyspark/shell.py similarity index 100% rename from pyspark/pyspark/shell.py rename to python/pyspark/shell.py diff --git a/pyspark/pyspark/worker.py b/python/pyspark/worker.py similarity index 100% rename from pyspark/pyspark/worker.py rename to python/pyspark/worker.py diff --git a/run b/run index ed788c4db3..08e2b2434b 100755 --- a/run +++ b/run @@ -63,7 +63,7 @@ CORE_DIR="$FWDIR/core" REPL_DIR="$FWDIR/repl" EXAMPLES_DIR="$FWDIR/examples" BAGEL_DIR="$FWDIR/bagel" -PYSPARK_DIR="$FWDIR/pyspark" +PYSPARK_DIR="$FWDIR/python" # Build up classpath CLASSPATH="$SPARK_CLASSPATH" diff --git a/pyspark/run-pyspark b/run-pyspark similarity index 86% rename from pyspark/run-pyspark rename to run-pyspark index 4d10fbea8b..deb0d708b3 100755 --- a/pyspark/run-pyspark +++ b/run-pyspark @@ -1,7 +1,7 @@ #!/usr/bin/env bash # Figure out where the Scala framework is installed -FWDIR="$(cd `dirname $0`; cd ../; pwd)" +FWDIR="$(cd `dirname $0`; pwd)" # Export this as SPARK_HOME export SPARK_HOME="$FWDIR" @@ -18,7 +18,7 @@ fi export PYSPARK_PYTHON # Add the PySpark classes to the Python path: -export PYTHONPATH=$SPARK_HOME/pyspark/:$PYTHONPATH +export PYTHONPATH=$SPARK_HOME/python/:$PYTHONPATH # Launch with `scala` by default: if [[ "$SPARK_LAUNCH_WITH_SCALA" != "0" ]] ; then diff --git a/run2.cmd b/run2.cmd index 9c50804e69..83464b1166 
100644 --- a/run2.cmd +++ b/run2.cmd @@ -34,7 +34,7 @@ set CORE_DIR=%FWDIR%core set REPL_DIR=%FWDIR%repl set EXAMPLES_DIR=%FWDIR%examples set BAGEL_DIR=%FWDIR%bagel -set PYSPARK_DIR=%FWDIR%pyspark +set PYSPARK_DIR=%FWDIR%python rem Build up classpath set CLASSPATH=%SPARK_CLASSPATH%;%MESOS_CLASSPATH%;%FWDIR%conf;%CORE_DIR%\target\scala-%SCALA_VERSION%\classes From ce9f1bbe20eff794cd1d588dc88f109d32588cfe Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 1 Jan 2013 21:25:49 -0800 Subject: [PATCH 044/291] Add `pyspark` script to replace the other scripts. Expand the PySpark programming guide. --- docs/python-programming-guide.md | 49 ++++++++++++++++++++++++++++---- docs/quick-start.md | 4 +-- run-pyspark => pyspark | 4 +++ pyspark-shell | 3 -- python/pyspark/shell.py | 36 +++++++---------------- python/run-tests | 9 ++++++ 6 files changed, 69 insertions(+), 36 deletions(-) rename run-pyspark => pyspark (80%) delete mode 100755 pyspark-shell create mode 100755 python/run-tests diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index d88d4eb42d..d963551296 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -24,6 +24,35 @@ There are a few key differences between the Python and Scala APIs: - `sample` - `sort` +In PySpark, RDDs support the same methods as their Scala counterparts but take Python functions and return Python collection types. +Short functions can be passed to RDD methods using Python's [`lambda`](http://www.diveintopython.net/power_of_introspection/lambda_functions.html) syntax: + +{% highlight python %} +logData = sc.textFile(logFile).cache() +errors = logData.filter(lambda s: 'ERROR' in s.split()) +{% endhighlight %} + +You can also pass functions that are defined using the `def` keyword; this is useful for more complicated functions that cannot be expressed using `lambda`: + +{% highlight python %} +def is_error(line): + return 'ERROR' in line.split() +errors = logData.filter(is_error) +{% endhighlight %} + +Functions can access objects in enclosing scopes, although modifications to those objects within RDD methods will not be propagated to other tasks: + +{% highlight python %} +error_keywords = ["Exception", "Error"] +def is_error(line): + words = line.split() + return any(keyword in words for keyword in error_keywords) +errors = logData.filter(is_error) +{% endhighlight %} + +PySpark will automatically ship these functions to workers, along with any objects that they reference. +Instances of classes will be serialized and shipped to workers by PySpark, but classes themselves cannot be automatically distributed to workers. +The [Standalone Use](#standalone-use) section describes how to ship code dependencies to workers. # Installing and Configuring PySpark @@ -34,13 +63,14 @@ By default, PySpark's scripts will run programs using `python`; an alternate Pyt All of PySpark's library dependencies, including [Py4J](http://py4j.sourceforge.net/), are bundled with PySpark and automatically imported. -Standalone PySpark jobs should be run using the `run-pyspark` script, which automatically configures the Java and Python environmnt using the settings in `conf/spark-env.sh`. +Standalone PySpark jobs should be run using the `pyspark` script, which automatically configures the Java and Python environment using the settings in `conf/spark-env.sh`. The script automatically adds the `pyspark` package to the `PYTHONPATH`. 
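The closure-shipping rule described earlier in this guide addition (enclosing-scope objects are copied out to workers, and worker-side modifications are not propagated back to the driver) is easy to demonstrate. A hedged sketch, with `logs.txt` standing in for any input file:

    from pyspark import SparkContext

    sc = SparkContext("local", "ClosureDemo")
    counter = 0

    def is_error(line):
        global counter
        counter += 1                   # increments a copy shipped to the worker process
        return "ERROR" in line.split()

    errors = sc.textFile("logs.txt").filter(is_error)
    print errors.count()               # the filter itself works as expected
    print counter                      # still 0 on the driver; worker-side updates are lost
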
# Interactive Use -PySpark's `pyspark-shell` script provides a simple way to learn the API: +The `pyspark` script launches a Python interpreter that is configured to run PySpark jobs. +When run without any input files, `pyspark` launches a shell that can be used explore data interactively, which is a simple way to learn the API: {% highlight python %} >>> words = sc.textFile("/usr/share/dict/words") @@ -48,9 +78,18 @@ PySpark's `pyspark-shell` script provides a simple way to learn the API: [u'spar', u'sparable', u'sparada', u'sparadrap', u'sparagrass'] {% endhighlight %} +By default, the `pyspark` shell creates SparkContext that runs jobs locally. +To connect to a non-local cluster, set the `MASTER` environment variable. +For example, to use the `pyspark` shell with a [standalone Spark cluster](spark-standalone.html): + +{% highlight shell %} +$ MASTER=spark://IP:PORT ./pyspark +{% endhighlight %} + + # Standalone Use -PySpark can also be used from standalone Python scripts by creating a SparkContext in the script and running the script using the `run-pyspark` script in the `pyspark` directory. +PySpark can also be used from standalone Python scripts by creating a SparkContext in your script and running the script using `pyspark`. The Quick Start guide includes a [complete example](quick-start.html#a-standalone-job-in-python) of a standalone Python job. Code dependencies can be deployed by listing them in the `pyFiles` option in the SparkContext constructor: @@ -65,8 +104,8 @@ Code dependencies can be added to an existing SparkContext using its `addPyFile( # Where to Go from Here -PySpark includes several sample programs using the Python API in `pyspark/examples`. -You can run them by passing the files to the `pyspark-run` script included in PySpark -- for example `./pyspark-run examples/wordcount.py`. +PySpark includes several sample programs using the Python API in `python/examples`. +You can run them by passing the files to the `pyspark` script -- for example `./pyspark python/examples/wordcount.py`. Each example program prints usage help when run without any arguments. We currently provide [API documentation](api/pyspark/index.html) for the Python API as Epydoc. diff --git a/docs/quick-start.md b/docs/quick-start.md index 8c25df5486..2c7cfbed25 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -258,11 +258,11 @@ We can pass Python functions to Spark, which are automatically serialized along For jobs that use custom classes or third-party libraries, we can add those code dependencies to SparkContext to ensure that they will be available on remote machines; this is described in more detail in the [Python programming guide](python-programming-guide). `SimpleJob` is simple enough that we do not need to specify any code dependencies. -We can run this job using the `run-pyspark` script in `$SPARK_HOME/pyspark`: +We can run this job using the `pyspark` script: {% highlight python %} $ cd $SPARK_HOME -$ ./pyspark/run-pyspark SimpleJob.py +$ ./pyspark SimpleJob.py ... 
Lines with a: 8422, Lines with b: 1836 {% endhighlight python %} diff --git a/run-pyspark b/pyspark similarity index 80% rename from run-pyspark rename to pyspark index deb0d708b3..9e89d51ba2 100755 --- a/run-pyspark +++ b/pyspark @@ -20,6 +20,10 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH=$SPARK_HOME/python/:$PYTHONPATH +# Load the PySpark shell.py script when ./pyspark is used interactively: +export OLD_PYTHONSTARTUP=$PYTHONSTARTUP +export PYTHONSTARTUP=$FWDIR/python/pyspark/shell.py + # Launch with `scala` by default: if [[ "$SPARK_LAUNCH_WITH_SCALA" != "0" ]] ; then export SPARK_LAUNCH_WITH_SCALA=1 diff --git a/pyspark-shell b/pyspark-shell deleted file mode 100755 index 27aaac3a26..0000000000 --- a/pyspark-shell +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/env bash -FWDIR="`dirname $0`" -exec $FWDIR/run-pyspark $FWDIR/python/pyspark/shell.py "$@" diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index bd39b0283f..7e6ad3aa76 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -1,33 +1,17 @@ """ An interactive shell. -""" -import optparse # I prefer argparse, but it's not included with Python < 2.7 -import code -import sys +This fle is designed to be launched as a PYTHONSTARTUP script. +""" +import os from pyspark.context import SparkContext -def main(master='local', ipython=False): - sc = SparkContext(master, 'PySparkShell') - user_ns = {'sc' : sc} - banner = "Spark context avaiable as sc." - if ipython: - import IPython - IPython.embed(user_ns=user_ns, banner2=banner) - else: - print banner - code.interact(local=user_ns) +sc = SparkContext(os.environ.get("MASTER", "local"), "PySparkShell") +print "Spark context avaiable as sc." - -if __name__ == '__main__': - usage = "usage: %prog [options] master" - parser = optparse.OptionParser(usage=usage) - parser.add_option("-i", "--ipython", help="Run IPython shell", - action="store_true") - (options, args) = parser.parse_args() - if len(sys.argv) > 1: - master = args[0] - else: - master = 'local' - main(master, options.ipython) +# The ./pyspark script stores the old PYTHONSTARTUP value in OLD_PYTHONSTARTUP, +# which allows us to execute the user's PYTHONSTARTUP file: +_pythonstartup = os.environ.get('OLD_PYTHONSTARTUP') +if _pythonstartup and os.path.isfile(_pythonstartup): + execfile(_pythonstartup) diff --git a/python/run-tests b/python/run-tests new file mode 100755 index 0000000000..da9e24cb1f --- /dev/null +++ b/python/run-tests @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +# Figure out where the Scala framework is installed +FWDIR="$(cd `dirname $0`; cd ../; pwd)" + +$FWDIR/pyspark pyspark/rdd.py +$FWDIR/pyspark -m doctest pyspark/broadcast.py + +# TODO: in the long-run, it would be nice to use a test runner like `nose`. From 33beba39656fc64984db09a82fc69ca4edcc02d4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 3 Jan 2013 14:52:21 -0800 Subject: [PATCH 045/291] Change PySpark RDD.take() to not call iterator(). 
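The `run-tests` script above exercises `rdd.py` directly because its docstrings double as doctests, and those doctests (like the `take()` examples in the patch below) reference a live `sc`. A common way to wire that up is to pass a SparkContext into the doctest namespace; this is only a hedged sketch, since the exact test helper inside `rdd.py` is not reproduced in these patches:

    import doctest
    from pyspark import SparkContext

    def _test():
        # Give the doctests the `sc` they reference; 'local' keeps everything on one machine.
        globs = {'sc': SparkContext('local', 'PySparkDocTests')}
        (failure_count, test_count) = doctest.testmod(globs=globs)
        if failure_count:
            exit(-1)

    if __name__ == "__main__":
        _test()
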
--- core/src/main/scala/spark/api/python/PythonRDD.scala | 4 ++++ python/pyspark/context.py | 1 + python/pyspark/rdd.py | 11 +++++------ 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index cf60d14f03..79d824d494 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -10,6 +10,7 @@ import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import spark.broadcast.Broadcast import spark._ import spark.rdd.PipedRDD +import java.util private[spark] class PythonRDD[T: ClassManifest]( @@ -216,6 +217,9 @@ private[spark] object PythonRDD { } file.close() } + + def takePartition[T](rdd: RDD[T], partition: Int): java.util.Iterator[T] = + rdd.context.runJob(rdd, ((x: Iterator[T]) => x), Seq(partition), true).head } private object Pickle { diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6172d69dcf..4439356c1f 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -21,6 +21,7 @@ class SparkContext(object): jvm = gateway.jvm _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile + _takePartition = jvm.PythonRDD.takePartition def __init__(self, master, jobName, sparkHome=None, pyFiles=None, environment=None, batchSize=1024): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index cbffb6cc1f..4ba417b2a2 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -328,18 +328,17 @@ class RDD(object): a lot of partitions are required. In that case, use L{collect} to get the whole RDD instead. - >>> sc.parallelize([2, 3, 4, 5, 6]).take(2) + >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2) [2, 3] >>> sc.parallelize([2, 3, 4, 5, 6]).take(10) [2, 3, 4, 5, 6] """ items = [] - splits = self._jrdd.splits() - taskContext = self.ctx.jvm.spark.TaskContext(0, 0, 0) - while len(items) < num and splits: - split = splits.pop(0) - iterator = self._jrdd.iterator(split, taskContext) + for partition in range(self._jrdd.splits().size()): + iterator = self.ctx._takePartition(self._jrdd.rdd(), partition) items.extend(self._collect_iterator_through_file(iterator)) + if len(items) >= num: + break return items[:num] def first(self): From 8d57c78c83f74e45ce3c119e2e3915d5eac264e7 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Sat, 5 Jan 2013 10:54:05 -0600 Subject: [PATCH 046/291] Add PairRDDFunctions.keys and values. --- core/src/main/scala/spark/PairRDDFunctions.scala | 10 ++++++++++ core/src/test/scala/spark/ShuffleSuite.scala | 7 +++++++ 2 files changed, 17 insertions(+) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 413c944a66..ce48cea903 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -615,6 +615,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( writer.cleanup() } + /** + * Return an RDD with the keys of each tuple. + */ + def keys: RDD[K] = self.map(_._1) + + /** + * Return an RDD with the values of each tuple. 
+ */ + def values: RDD[V] = self.map(_._2) + private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index 8170100f1d..5a867016f2 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -216,6 +216,13 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter { // Test that a shuffle on the file works, because this used to be a bug assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) } + + test("kesy and values") { + sc = new SparkContext("local", "test") + val rdd = sc.parallelize(Array((1, "a"), (2, "b"))) + assert(rdd.keys.collect().toList === List(1, 2)) + assert(rdd.values.collect().toList === List("a", "b")) + } } object ShuffleSuite { From f4e6b9361ffeec1018d5834f09db9fd86f2ba7bd Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Fri, 4 Jan 2013 22:43:22 -0600 Subject: [PATCH 047/291] Add RDD.collect(PartialFunction). --- core/src/main/scala/spark/RDD.scala | 7 +++++++ core/src/test/scala/spark/RDDSuite.scala | 1 + 2 files changed, 8 insertions(+) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 7e38583391..5163c80134 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -329,6 +329,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial */ def toArray(): Array[T] = collect() + /** + * Return an RDD that contains all matching values by applying `f`. + */ + def collect[U: ClassManifest](f: PartialFunction[T, U]): RDD[U] = { + filter(f.isDefinedAt).map(f) + } + /** * Reduces the elements of this RDD using the specified associative binary operator. */ diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 45e6c5f840..872b06fd08 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -35,6 +35,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter { assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) + assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4")) val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) assert(partitionSums.collect().toList === List(3, 7)) From 6a0db3b449a829f3e5cdf7229f6ee564268be1df Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Sat, 5 Jan 2013 12:56:17 -0600 Subject: [PATCH 048/291] Fix typo. 
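For comparison with the Scala additions above, both `keys`/`values` and `collect(PartialFunction)` reduce to one-liners against the PySpark RDD API used throughout these patches. A sketch (the tuple-unpacking lambdas follow the same Python 2 style as the kmeans example earlier):

    from pyspark import SparkContext

    sc = SparkContext("local", "PairHelpersDemo")
    pairs = sc.parallelize([(1, "a"), (2, "b")])
    print pairs.map(lambda (k, v): k).collect()        # [1, 2]     -- analogue of keys
    print pairs.map(lambda (k, v): v).collect()        # ['a', 'b'] -- analogue of values

    nums = sc.parallelize([1, 2, 3, 4])
    # collect(PartialFunction) becomes filter-then-map on the Python side
    print nums.filter(lambda i: i >= 3).map(str).collect()   # ['3', '4']
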
--- core/src/test/scala/spark/ShuffleSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index 5a867016f2..bebb8ebe86 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -217,7 +217,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter { assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) } - test("kesy and values") { + test("keys and values") { sc = new SparkContext("local", "test") val rdd = sc.parallelize(Array((1, "a"), (2, "b"))) assert(rdd.keys.collect().toList === List(1, 2)) From 1fdb6946b5d076ed0f1b4d2bca2a20b6cd22cbc3 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Sat, 5 Jan 2013 13:07:59 -0600 Subject: [PATCH 049/291] Add RDD.tupleBy. --- core/src/main/scala/spark/RDD.scala | 7 +++++++ core/src/test/scala/spark/RDDSuite.scala | 1 + 2 files changed, 8 insertions(+) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 7e38583391..7aa4b0a173 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -510,6 +510,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial .saveAsSequenceFile(path) } + /** + * Tuples the elements of this RDD by applying `f`. + */ + def tupleBy[K](f: T => K): RDD[(K, T)] = { + map(x => (f(x), x)) + } + /** A private method for tests, to look at the contents of each partition */ private[spark] def collectPartitions(): Array[Array[T]] = { sc.runJob(this, (iter: Iterator[T]) => iter.toArray) diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 45e6c5f840..7832884224 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -35,6 +35,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter { assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) + assert(nums.tupleBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4))) val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) assert(partitionSums.collect().toList === List(3, 7)) From 86af64b0a6fde5a6418727a77b43bdfeda1b81cd Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 5 Jan 2013 20:54:08 -0500 Subject: [PATCH 050/291] Fix Accumulators in Java, and add a test for them --- core/src/main/scala/spark/Accumulators.scala | 18 +++++++- core/src/main/scala/spark/SparkContext.scala | 7 +-- .../spark/api/java/JavaSparkContext.scala | 23 ++++++---- core/src/test/scala/spark/JavaAPISuite.java | 44 +++++++++++++++++++ 4 files changed, 79 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala index bacd0ace37..6280f25391 100644 --- a/core/src/main/scala/spark/Accumulators.scala +++ b/core/src/main/scala/spark/Accumulators.scala @@ -38,14 +38,28 @@ class Accumulable[R, T] ( */ def += (term: T) { value_ = param.addAccumulator(value_, term) } + /** + * Add more data to this accumulator / accumulable + * @param term the data to add + */ + def add(term: T) { value_ = param.addAccumulator(value_, term) } + /** * Merge two accumulable objects together - * + * * 
Normally, a user will not want to use this version, but will instead call `+=`. - * @param term the other Accumulable that will get merged with this + * @param term the other `R` that will get merged with this */ def ++= (term: R) { value_ = param.addInPlace(value_, term)} + /** + * Merge two accumulable objects together + * + * Normally, a user will not want to use this version, but will instead call `add`. + * @param term the other `R` that will get merged with this + */ + def merge(term: R) { value_ = param.addInPlace(value_, term)} + /** * Access the accumulator's current value; only allowed on master. */ diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 4fd81bc63b..bbf8272eb3 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -382,11 +382,12 @@ class SparkContext( new Accumulator(initialValue, param) /** - * Create an [[spark.Accumulable]] shared variable, with a `+=` method + * Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`. + * Only the master can access the accumuable's `value`. * @tparam T accumulator type * @tparam R type that can be added to the accumulator */ - def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) = + def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) = new Accumulable(initialValue, param) /** @@ -404,7 +405,7 @@ class SparkContext( * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for * reading it in distributed functions. The variable will be sent to each cluster only once. */ - def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T] (value, isLocal) + def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) /** * Add a file to be downloaded into the working directory of this Spark job on every node. diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index b7725313c4..bf9ad7a200 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -10,7 +10,7 @@ import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import spark.{Accumulator, AccumulatorParam, RDD, SparkContext} +import spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, RDD, SparkContext} import spark.SparkContext.IntAccumulatorParam import spark.SparkContext.DoubleAccumulatorParam import spark.broadcast.Broadcast @@ -265,25 +265,32 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork /** * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `add` method. Only the master can access the accumulator's `value`. */ - def intAccumulator(initialValue: Int): Accumulator[Int] = - sc.accumulator(initialValue)(IntAccumulatorParam) + def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] = + sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]] /** * Create an [[spark.Accumulator]] double variable, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `add` method. 
Only the master can access the accumulator's `value`. */ - def doubleAccumulator(initialValue: Double): Accumulator[Double] = - sc.accumulator(initialValue)(DoubleAccumulatorParam) + def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] = + sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]] /** * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `add` method. Only the master can access the accumulator's `value`. */ def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] = sc.accumulator(initialValue)(accumulatorParam) + /** + * Create an [[spark.Accumulable]] shared variable of the given type, to which tasks can + * "add" values with `add`. Only the master can access the accumuable's `value`. + */ + def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] = + sc.accumulable(initialValue)(param) + /** * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for * reading it in distributed functions. The variable will be sent to each cluster only once. diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 33d5fc2d89..b99e790093 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -581,4 +581,48 @@ public class JavaAPISuite implements Serializable { JavaPairRDD zipped = rdd.zip(doubles); zipped.count(); } + + @Test + public void accumulators() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + + final Accumulator intAccum = sc.intAccumulator(10); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + intAccum.add(x); + } + }); + Assert.assertEquals((Integer) 25, intAccum.value()); + + final Accumulator doubleAccum = sc.doubleAccumulator(10.0); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + doubleAccum.add((double) x); + } + }); + Assert.assertEquals((Double) 25.0, doubleAccum.value()); + + // Try a custom accumulator type + AccumulatorParam floatAccumulatorParam = new AccumulatorParam() { + public Float addInPlace(Float r, Float t) { + return r + t; + } + + public Float addAccumulator(Float r, Float t) { + return r + t; + } + + public Float zero(Float initialValue) { + return 0.0f; + } + }; + + final Accumulator floatAccum = sc.accumulator((Float) 10.0f, floatAccumulatorParam); + rdd.foreach(new VoidFunction() { + public void call(Integer x) { + floatAccum.add((float) x); + } + }); + Assert.assertEquals((Float) 25.0f, floatAccum.value()); + } } From 0982572519655354b10987de4f68e29b8331bd2a Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 5 Jan 2013 22:11:28 -0500 Subject: [PATCH 051/291] Add methods called just 'accumulator' for int/double in Java API --- .../scala/spark/api/java/JavaSparkContext.scala | 13 +++++++++++++ core/src/test/scala/spark/JavaAPISuite.java | 4 ++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index bf9ad7a200..88ab2846be 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -277,6 +277,19 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def 
doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] = sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]] + /** + * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + */ + def accumulator(initialValue: Int): Accumulator[java.lang.Integer] = intAccumulator(initialValue) + + /** + * Create an [[spark.Accumulator]] double variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + */ + def accumulator(initialValue: Double): Accumulator[java.lang.Double] = + doubleAccumulator(initialValue) + /** * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values * to using the `add` method. Only the master can access the accumulator's `value`. diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index b99e790093..912f8de05d 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -586,7 +586,7 @@ public class JavaAPISuite implements Serializable { public void accumulators() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - final Accumulator intAccum = sc.intAccumulator(10); + final Accumulator intAccum = sc.accumulator(10); rdd.foreach(new VoidFunction() { public void call(Integer x) { intAccum.add(x); @@ -594,7 +594,7 @@ public class JavaAPISuite implements Serializable { }); Assert.assertEquals((Integer) 25, intAccum.value()); - final Accumulator doubleAccum = sc.doubleAccumulator(10.0); + final Accumulator doubleAccum = sc.accumulator(10.0); rdd.foreach(new VoidFunction() { public void call(Integer x) { doubleAccum.add((double) x); From 8fd3a70c188182105f81f5143ec65e74663582d5 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 5 Jan 2013 22:46:45 -0500 Subject: [PATCH 052/291] Add PairRDD.keys() and values() to Java API --- core/src/main/scala/spark/api/java/JavaPairRDD.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala index 5c2be534ff..8ce32e0e2f 100644 --- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala @@ -471,6 +471,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x) fromRDD(new OrderedRDDFunctions(rdd).sortByKey(ascending)) } + + /** + * Return an RDD with the keys of each tuple. + */ + def keys(): JavaRDD[K] = JavaRDD.fromRDD[K](rdd.map(_._1)) + + /** + * Return an RDD with the values of each tuple. + */ + def values(): JavaRDD[V] = JavaRDD.fromRDD[V](rdd.map(_._2)) } object JavaPairRDD { From 8dc06069fe2330c3ee0fcaaeb0ae6e627a5887c3 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Sun, 6 Jan 2013 15:21:45 -0600 Subject: [PATCH 053/291] Rename RDD.tupleBy to keyBy. 
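Stepping back to the accumulator changes in the commits above: the Java test builds its custom accumulator around three operations, and their division of labor is easier to see spelled out. An illustrative Python rendering of the same contract (PySpark itself still lists accumulators as unsupported at this point, so this is only a sketch of the semantics, not an API):

    class FloatAccumulatorParam(object):
        """Same contract as the floatAccumulatorParam defined in the Java test above."""
        def zero(self, initial_value):
            return 0.0                    # identity value each task starts from
        def addAccumulator(self, acc, term):
            return acc + term             # fold one task-side term into the running value
        def addInPlace(self, r1, r2):
            return r1 + r2                # merge two partial results back on the driver
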
--- core/src/main/scala/spark/RDD.scala | 4 ++-- core/src/test/scala/spark/RDDSuite.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 7aa4b0a173..5ce524c0e7 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -511,9 +511,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial } /** - * Tuples the elements of this RDD by applying `f`. + * Creates tuples of the elements in this RDD by applying `f`. */ - def tupleBy[K](f: T => K): RDD[(K, T)] = { + def keyBy[K](f: T => K): RDD[(K, T)] = { map(x => (f(x), x)) } diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 7832884224..77bff8aba1 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -35,7 +35,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter { assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) - assert(nums.tupleBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4))) + assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4))) val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) assert(partitionSums.collect().toList === List(3, 7)) From 9c32f300fb4151a2b563bf3d2e46469722e016e1 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 7 Jan 2013 16:50:23 -0500 Subject: [PATCH 054/291] Add Accumulable.setValue for easier use in Java --- core/src/main/scala/spark/Accumulators.scala | 20 +++++++++++++++----- core/src/test/scala/spark/JavaAPISuite.java | 4 ++++ 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala index 6280f25391..b644aba5f8 100644 --- a/core/src/main/scala/spark/Accumulators.scala +++ b/core/src/main/scala/spark/Accumulators.scala @@ -63,9 +63,12 @@ class Accumulable[R, T] ( /** * Access the accumulator's current value; only allowed on master. */ - def value = { - if (!deserialized) value_ - else throw new UnsupportedOperationException("Can't read accumulator value in task") + def value: R = { + if (!deserialized) { + value_ + } else { + throw new UnsupportedOperationException("Can't read accumulator value in task") + } } /** @@ -82,10 +85,17 @@ class Accumulable[R, T] ( /** * Set the accumulator's value; only allowed on master. 
*/ - def value_= (r: R) { - if (!deserialized) value_ = r + def value_= (newValue: R) { + if (!deserialized) value_ = newValue else throw new UnsupportedOperationException("Can't assign accumulator value in task") } + + /** + * Set the accumulator's value; only allowed on master + */ + def setValue(newValue: R) { + this.value = newValue + } // Called by Java when deserializing an object private def readObject(in: ObjectInputStream) { diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 912f8de05d..0817d1146c 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -624,5 +624,9 @@ public class JavaAPISuite implements Serializable { } }); Assert.assertEquals((Float) 25.0f, floatAccum.value()); + + // Test the setValue method + floatAccum.setValue(5.0f); + Assert.assertEquals((Float) 5.0f, floatAccum.value()); } } From f8d579a0c05b7d29b59e541b483ded471d14ec17 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Thu, 27 Dec 2012 13:30:07 -0800 Subject: [PATCH 055/291] Remove dependencies on sun jvm classes. Instead use reflection to infer HotSpot options and total physical memory size --- core/src/main/scala/spark/SizeEstimator.scala | 13 ++++++++--- .../spark/deploy/worker/WorkerArguments.scala | 22 ++++++++++++++++--- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/spark/SizeEstimator.scala b/core/src/main/scala/spark/SizeEstimator.scala index 7c3e8640e9..d4e1157250 100644 --- a/core/src/main/scala/spark/SizeEstimator.scala +++ b/core/src/main/scala/spark/SizeEstimator.scala @@ -9,7 +9,6 @@ import java.util.Random import javax.management.MBeanServer import java.lang.management.ManagementFactory -import com.sun.management.HotSpotDiagnosticMXBean import scala.collection.mutable.ArrayBuffer @@ -76,12 +75,20 @@ private[spark] object SizeEstimator extends Logging { if (System.getProperty("spark.test.useCompressedOops") != null) { return System.getProperty("spark.test.useCompressedOops").toBoolean } + try { val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic" val server = ManagementFactory.getPlatformMBeanServer() + + // NOTE: This should throw an exception in non-Sun JVMs + val hotSpotMBeanClass = Class.forName("com.sun.management.HotSpotDiagnosticMXBean") + val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption", + Class.forName("java.lang.String")) + val bean = ManagementFactory.newPlatformMXBeanProxy(server, - hotSpotMBeanName, classOf[HotSpotDiagnosticMXBean]) - return bean.getVMOption("UseCompressedOops").getValue.toBoolean + hotSpotMBeanName, hotSpotMBeanClass) + // TODO: We could use reflection on the VMOption returned ? 
+ return getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true") } catch { case e: Exception => { // Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala index 340920025b..37524a7c82 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala @@ -104,9 +104,25 @@ private[spark] class WorkerArguments(args: Array[String]) { } def inferDefaultMemory(): Int = { - val bean = ManagementFactory.getOperatingSystemMXBean - .asInstanceOf[com.sun.management.OperatingSystemMXBean] - val totalMb = (bean.getTotalPhysicalMemorySize / 1024 / 1024).toInt + val ibmVendor = System.getProperty("java.vendor").contains("IBM") + var totalMb = 0 + try { + val bean = ManagementFactory.getOperatingSystemMXBean() + if (ibmVendor) { + val beanClass = Class.forName("com.ibm.lang.management.OperatingSystemMXBean") + val method = beanClass.getDeclaredMethod("getTotalPhysicalMemory") + totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt + } else { + val beanClass = Class.forName("com.sun.management.OperatingSystemMXBean") + val method = beanClass.getDeclaredMethod("getTotalPhysicalMemorySize") + totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt + } + } catch { + case e: Exception => { + totalMb = 2*1024 + System.out.println("Failed to get total physical memory. Using " + totalMb + " MB") + } + } // Leave out 1 GB for the operating system, but don't return a negative memory size math.max(totalMb - 1024, 512) } From aed368a970bbaee4bdf297ba3f6f1b0fa131452c Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sat, 29 Dec 2012 16:23:43 -0800 Subject: [PATCH 056/291] Update Hadoop dependency to 1.0.3 as 0.20 has Sun specific dependencies. Also fix SequenceFileRDDFunctions to pick the right type conversion across Hadoop versions --- core/src/main/scala/spark/SequenceFileRDDFunctions.scala | 8 +++++++- project/SparkBuild.scala | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala index a34aee69c1..6b4a11d6d3 100644 --- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala @@ -42,7 +42,13 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) { classManifest[T].erasure } else { - implicitly[T => Writable].getClass.getMethods()(0).getReturnType + // We get the type of the Writable class by looking at the apply method which converts + // from T to Writable. Since we have two apply methods we filter out the one which + // is of the form "java.lang.Object apply(java.lang.Object)" + implicitly[T => Writable].getClass.getDeclaredMethods().filter( + m => m.getReturnType().toString != "java.lang.Object" && + m.getName() == "apply")(0).getReturnType + } // TODO: use something like WritableConverter to avoid reflection } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 842d0fa96b..7c7c33131a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -10,7 +10,7 @@ import twirl.sbt.TwirlPlugin._ object SparkBuild extends Build { // Hadoop version to build against. 
For example, "0.20.2", "0.20.205.0", or // "1.0.3" for Apache releases, or "0.20.2-cdh3u5" for Cloudera Hadoop. - val HADOOP_VERSION = "0.20.205.0" + val HADOOP_VERSION = "1.0.3" val HADOOP_MAJOR_VERSION = "1" // For Hadoop 2 versions such as "2.0.0-mr1-cdh4.1.1", set the HADOOP_MAJOR_VERSION to "2" From 77d751731ccd06e161e3ef10540f8165d964282f Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sat, 29 Dec 2012 18:28:00 -0800 Subject: [PATCH 057/291] Remove unused BoundedMemoryCache file and associated test case. --- .../main/scala/spark/BoundedMemoryCache.scala | 118 ------------------ .../scala/spark/BoundedMemoryCacheSuite.scala | 58 --------- 2 files changed, 176 deletions(-) delete mode 100644 core/src/main/scala/spark/BoundedMemoryCache.scala delete mode 100644 core/src/test/scala/spark/BoundedMemoryCacheSuite.scala diff --git a/core/src/main/scala/spark/BoundedMemoryCache.scala b/core/src/main/scala/spark/BoundedMemoryCache.scala deleted file mode 100644 index e8392a194f..0000000000 --- a/core/src/main/scala/spark/BoundedMemoryCache.scala +++ /dev/null @@ -1,118 +0,0 @@ -package spark - -import java.util.LinkedHashMap - -/** - * An implementation of Cache that estimates the sizes of its entries and attempts to limit its - * total memory usage to a fraction of the JVM heap. Objects' sizes are estimated using - * SizeEstimator, which has limitations; most notably, we will overestimate total memory used if - * some cache entries have pointers to a shared object. Nonetheless, this Cache should work well - * when most of the space is used by arrays of primitives or of simple classes. - */ -private[spark] class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging { - logInfo("BoundedMemoryCache.maxBytes = " + maxBytes) - - def this() { - this(BoundedMemoryCache.getMaxBytes) - } - - private var currentBytes = 0L - private val map = new LinkedHashMap[(Any, Int), Entry](32, 0.75f, true) - - override def get(datasetId: Any, partition: Int): Any = { - synchronized { - val entry = map.get((datasetId, partition)) - if (entry != null) { - entry.value - } else { - null - } - } - } - - override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = { - val key = (datasetId, partition) - logInfo("Asked to add key " + key) - val size = estimateValueSize(key, value) - synchronized { - if (size > getCapacity) { - return CachePutFailure() - } else if (ensureFreeSpace(datasetId, size)) { - logInfo("Adding key " + key) - map.put(key, new Entry(value, size)) - currentBytes += size - logInfo("Number of entries is now " + map.size) - return CachePutSuccess(size) - } else { - logInfo("Didn't add key " + key + " because we would have evicted part of same dataset") - return CachePutFailure() - } - } - } - - override def getCapacity: Long = maxBytes - - /** - * Estimate sizeOf 'value' - */ - private def estimateValueSize(key: (Any, Int), value: Any) = { - val startTime = System.currentTimeMillis - val size = SizeEstimator.estimate(value.asInstanceOf[AnyRef]) - val timeTaken = System.currentTimeMillis - startTime - logInfo("Estimated size for key %s is %d".format(key, size)) - logInfo("Size estimation for key %s took %d ms".format(key, timeTaken)) - size - } - - /** - * Remove least recently used entries from the map until at least space bytes are free, in order - * to make space for a partition from the given dataset ID. If this cannot be done without - * evicting other data from the same dataset, returns false; otherwise, returns true. 
Assumes - * that a lock is held on the BoundedMemoryCache. - */ - private def ensureFreeSpace(datasetId: Any, space: Long): Boolean = { - logInfo("ensureFreeSpace(%s, %d) called with curBytes=%d, maxBytes=%d".format( - datasetId, space, currentBytes, maxBytes)) - val iter = map.entrySet.iterator // Will give entries in LRU order - while (maxBytes - currentBytes < space && iter.hasNext) { - val mapEntry = iter.next() - val (entryDatasetId, entryPartition) = mapEntry.getKey - if (entryDatasetId == datasetId) { - // Cannot make space without removing part of the same dataset, or a more recently used one - return false - } - reportEntryDropped(entryDatasetId, entryPartition, mapEntry.getValue) - currentBytes -= mapEntry.getValue.size - iter.remove() - } - return true - } - - protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) { - logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size)) - // TODO: remove BoundedMemoryCache - - val (keySpaceId, innerDatasetId) = datasetId.asInstanceOf[(Any, Any)] - innerDatasetId match { - case rddId: Int => - SparkEnv.get.cacheTracker.dropEntry(rddId, partition) - case broadcastUUID: java.util.UUID => - // TODO: Maybe something should be done if the broadcasted variable falls out of cache - case _ => - } - } -} - -// An entry in our map; stores a cached object and its size in bytes -private[spark] case class Entry(value: Any, size: Long) - -private[spark] object BoundedMemoryCache { - /** - * Get maximum cache capacity from system configuration - */ - def getMaxBytes: Long = { - val memoryFractionToUse = System.getProperty("spark.boundedMemoryCache.memoryFraction", "0.66").toDouble - (Runtime.getRuntime.maxMemory * memoryFractionToUse).toLong - } -} - diff --git a/core/src/test/scala/spark/BoundedMemoryCacheSuite.scala b/core/src/test/scala/spark/BoundedMemoryCacheSuite.scala deleted file mode 100644 index 37cafd1e8e..0000000000 --- a/core/src/test/scala/spark/BoundedMemoryCacheSuite.scala +++ /dev/null @@ -1,58 +0,0 @@ -package spark - -import org.scalatest.FunSuite -import org.scalatest.PrivateMethodTester -import org.scalatest.matchers.ShouldMatchers - -// TODO: Replace this with a test of MemoryStore -class BoundedMemoryCacheSuite extends FunSuite with PrivateMethodTester with ShouldMatchers { - test("constructor test") { - val cache = new BoundedMemoryCache(60) - expect(60)(cache.getCapacity) - } - - test("caching") { - // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case - val oldArch = System.setProperty("os.arch", "amd64") - val oldOops = System.setProperty("spark.test.useCompressedOops", "true") - val initialize = PrivateMethod[Unit]('initialize) - SizeEstimator invokePrivate initialize() - - val cache = new BoundedMemoryCache(60) { - //TODO sorry about this, but there is not better way how to skip 'cacheTracker.dropEntry' - override protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) { - logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size)) - } - } - - // NOTE: The String class definition changed in JDK 7 to exclude the int fields count and length - // This means that the size of strings will be lesser by 8 bytes in JDK 7 compared to JDK 6. - // http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html - // Work around to check for either. 
- - //should be OK - cache.put("1", 0, "Meh") should (equal (CachePutSuccess(56)) or equal (CachePutSuccess(48))) - - //we cannot add this to cache (there is not enough space in cache) & we cannot evict the only value from - //cache because it's from the same dataset - expect(CachePutFailure())(cache.put("1", 1, "Meh")) - - //should be OK, dataset '1' can be evicted from cache - cache.put("2", 0, "Meh") should (equal (CachePutSuccess(56)) or equal (CachePutSuccess(48))) - - //should fail, cache should obey it's capacity - expect(CachePutFailure())(cache.put("3", 0, "Very_long_and_useless_string")) - - if (oldArch != null) { - System.setProperty("os.arch", oldArch) - } else { - System.clearProperty("os.arch") - } - - if (oldOops != null) { - System.setProperty("spark.test.useCompressedOops", oldOops) - } else { - System.clearProperty("spark.test.useCompressedOops") - } - } -} From 55c66d365f76f3e5ecc6b850ba81c84b320f6772 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 7 Jan 2013 15:19:33 -0800 Subject: [PATCH 058/291] Use a dummy string class in Size Estimator tests to make it resistant to jdk versions --- .../test/scala/spark/SizeEstimatorSuite.scala | 33 ++++++++++++------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/core/src/test/scala/spark/SizeEstimatorSuite.scala b/core/src/test/scala/spark/SizeEstimatorSuite.scala index 17f366212b..bf3b2e1eed 100644 --- a/core/src/test/scala/spark/SizeEstimatorSuite.scala +++ b/core/src/test/scala/spark/SizeEstimatorSuite.scala @@ -20,6 +20,15 @@ class DummyClass4(val d: DummyClass3) { val x: Int = 0 } +object DummyString { + def apply(str: String) : DummyString = new DummyString(str.toArray) +} +class DummyString(val arr: Array[Char]) { + override val hashCode: Int = 0 + // JDK-7 has an extra hash32 field http://hg.openjdk.java.net/jdk7u/jdk7u6/jdk/rev/11987e85555f + @transient val hash32: Int = 0 +} + class SizeEstimatorSuite extends FunSuite with BeforeAndAfterAll with PrivateMethodTester with ShouldMatchers { @@ -50,10 +59,10 @@ class SizeEstimatorSuite // http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html // Work around to check for either. 
test("strings") { - SizeEstimator.estimate("") should (equal (48) or equal (40)) - SizeEstimator.estimate("a") should (equal (56) or equal (48)) - SizeEstimator.estimate("ab") should (equal (56) or equal (48)) - SizeEstimator.estimate("abcdefgh") should (equal(64) or equal(56)) + SizeEstimator.estimate(DummyString("")) should (equal (48) or equal (40)) + SizeEstimator.estimate(DummyString("a")) should (equal (56) or equal (48)) + SizeEstimator.estimate(DummyString("ab")) should (equal (56) or equal (48)) + SizeEstimator.estimate(DummyString("abcdefgh")) should (equal(64) or equal(56)) } test("primitive arrays") { @@ -105,10 +114,10 @@ class SizeEstimatorSuite val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() - expect(40)(SizeEstimator.estimate("")) - expect(48)(SizeEstimator.estimate("a")) - expect(48)(SizeEstimator.estimate("ab")) - expect(56)(SizeEstimator.estimate("abcdefgh")) + expect(40)(SizeEstimator.estimate(DummyString(""))) + expect(48)(SizeEstimator.estimate(DummyString("a"))) + expect(48)(SizeEstimator.estimate(DummyString("ab"))) + expect(56)(SizeEstimator.estimate(DummyString("abcdefgh"))) resetOrClear("os.arch", arch) } @@ -124,10 +133,10 @@ class SizeEstimatorSuite val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() - SizeEstimator.estimate("") should (equal (64) or equal (56)) - SizeEstimator.estimate("a") should (equal (72) or equal (64)) - SizeEstimator.estimate("ab") should (equal (72) or equal (64)) - SizeEstimator.estimate("abcdefgh") should (equal (80) or equal (72)) + SizeEstimator.estimate(DummyString("")) should (equal (64) or equal (56)) + SizeEstimator.estimate(DummyString("a")) should (equal (72) or equal (64)) + SizeEstimator.estimate(DummyString("ab")) should (equal (72) or equal (64)) + SizeEstimator.estimate(DummyString("abcdefgh")) should (equal (80) or equal (72)) resetOrClear("os.arch", arch) resetOrClear("spark.test.useCompressedOops", oops) From fb3d4d5e85cd4b094411bb08a32ab50cc62dc151 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 7 Jan 2013 16:46:06 -0800 Subject: [PATCH 059/291] Make default hadoop version 1.0.3 in pom.xml --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index b33cee26b8..fe5b1d0ee4 100644 --- a/pom.xml +++ b/pom.xml @@ -489,7 +489,7 @@ org.apache.hadoop hadoop-core - 0.20.205.0 + 1.0.3 From b1336e2fe458b92dcf60dcd249c41c7bdcc8be6d Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 7 Jan 2013 17:00:32 -0800 Subject: [PATCH 060/291] Update expected size of strings to match our dummy string class --- .../test/scala/spark/SizeEstimatorSuite.scala | 31 ++++++++----------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/core/src/test/scala/spark/SizeEstimatorSuite.scala b/core/src/test/scala/spark/SizeEstimatorSuite.scala index bf3b2e1eed..e235ef2f67 100644 --- a/core/src/test/scala/spark/SizeEstimatorSuite.scala +++ b/core/src/test/scala/spark/SizeEstimatorSuite.scala @@ -3,7 +3,6 @@ package spark import org.scalatest.FunSuite import org.scalatest.BeforeAndAfterAll import org.scalatest.PrivateMethodTester -import org.scalatest.matchers.ShouldMatchers class DummyClass1 {} @@ -30,7 +29,7 @@ class DummyString(val arr: Array[Char]) { } class SizeEstimatorSuite - extends FunSuite with BeforeAndAfterAll with PrivateMethodTester with ShouldMatchers { + extends FunSuite with BeforeAndAfterAll with PrivateMethodTester { var oldArch: String = _ var oldOops: String = _ @@ -54,15 
+53,13 @@ class SizeEstimatorSuite expect(48)(SizeEstimator.estimate(new DummyClass4(new DummyClass3))) } - // NOTE: The String class definition changed in JDK 7 to exclude the int fields count and length. - // This means that the size of strings will be lesser by 8 bytes in JDK 7 compared to JDK 6. - // http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html - // Work around to check for either. + // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors + // (Sun vs IBM). Use a DummyString class to make tests deterministic. test("strings") { - SizeEstimator.estimate(DummyString("")) should (equal (48) or equal (40)) - SizeEstimator.estimate(DummyString("a")) should (equal (56) or equal (48)) - SizeEstimator.estimate(DummyString("ab")) should (equal (56) or equal (48)) - SizeEstimator.estimate(DummyString("abcdefgh")) should (equal(64) or equal(56)) + expect(40)(SizeEstimator.estimate(DummyString(""))) + expect(48)(SizeEstimator.estimate(DummyString("a"))) + expect(48)(SizeEstimator.estimate(DummyString("ab"))) + expect(56)(SizeEstimator.estimate(DummyString("abcdefgh"))) } test("primitive arrays") { @@ -122,10 +119,8 @@ class SizeEstimatorSuite resetOrClear("os.arch", arch) } - // NOTE: The String class definition changed in JDK 7 to exclude the int fields count and length. - // This means that the size of strings will be lesser by 8 bytes in JDK 7 compared to JDK 6. - // http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html - // Work around to check for either. + // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors + // (Sun vs IBM). Use a DummyString class to make tests deterministic. test("64-bit arch with no compressed oops") { val arch = System.setProperty("os.arch", "amd64") val oops = System.setProperty("spark.test.useCompressedOops", "false") @@ -133,10 +128,10 @@ class SizeEstimatorSuite val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() - SizeEstimator.estimate(DummyString("")) should (equal (64) or equal (56)) - SizeEstimator.estimate(DummyString("a")) should (equal (72) or equal (64)) - SizeEstimator.estimate(DummyString("ab")) should (equal (72) or equal (64)) - SizeEstimator.estimate(DummyString("abcdefgh")) should (equal (80) or equal (72)) + expect(56)(SizeEstimator.estimate(DummyString(""))) + expect(64)(SizeEstimator.estimate(DummyString("a"))) + expect(64)(SizeEstimator.estimate(DummyString("ab"))) + expect(72)(SizeEstimator.estimate(DummyString("abcdefgh"))) resetOrClear("os.arch", arch) resetOrClear("spark.test.useCompressedOops", oops) From 4bbe07e5ece81fa874d2412bcc165179313a7619 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 7 Jan 2013 17:46:22 -0800 Subject: [PATCH 061/291] Activate hadoop1 profile by default for maven builds --- bagel/pom.xml | 3 +++ core/pom.xml | 5 ++++- examples/pom.xml | 3 +++ pom.xml | 3 +++ repl-bin/pom.xml | 3 +++ repl/pom.xml | 3 +++ 6 files changed, 19 insertions(+), 1 deletion(-) diff --git a/bagel/pom.xml b/bagel/pom.xml index a8256a6e8b..4ca643bbb7 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -45,6 +45,9 @@ hadoop1 + + true + org.spark-project diff --git a/core/pom.xml b/core/pom.xml index ae52c20657..cd789a7db0 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -159,6 +159,9 @@ hadoop1 + + true + org.apache.hadoop @@ -267,4 +270,4 @@ - \ No newline at end of file + diff --git a/examples/pom.xml b/examples/pom.xml index 782c026d73..9e638c8284 100644 --- 
a/examples/pom.xml +++ b/examples/pom.xml @@ -45,6 +45,9 @@ hadoop1 + + true + org.spark-project diff --git a/pom.xml b/pom.xml index fe5b1d0ee4..0e2d93c170 100644 --- a/pom.xml +++ b/pom.xml @@ -481,6 +481,9 @@ hadoop1 + + true + 1 diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index 0667b71cc7..aa9895eda2 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -70,6 +70,9 @@ hadoop1 + + true + hadoop1 diff --git a/repl/pom.xml b/repl/pom.xml index 114e3e9932..ba7a051310 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -72,6 +72,9 @@ hadoop1 + + true + hadoop1 From c41042c816c2d6299aa7d93529b7c39db5d5c03a Mon Sep 17 00:00:00 2001 From: Mikhail Bautin Date: Wed, 26 Dec 2012 15:52:51 -0800 Subject: [PATCH 062/291] Log preferred hosts --- .../main/scala/spark/scheduler/cluster/TaskSetManager.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index cf4aae03a7..dda7a6c64a 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -201,7 +201,9 @@ private[spark] class TaskSetManager( val taskId = sched.newTaskId() // Figure out whether this should count as a preferred launch val preferred = isPreferredLocation(task, host) - val prefStr = if (preferred) "preferred" else "non-preferred" + val prefStr = if (preferred) "preferred" else + "non-preferred, not one of " + + task.preferredLocations.mkString(", ") logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( taskSet.id, index, taskId, slaveId, host, prefStr)) // Do various bookkeeping From 4725b0f6439337c7a0f5f6fc7034c6f6b9488ae9 Mon Sep 17 00:00:00 2001 From: Mikhail Bautin Date: Mon, 7 Jan 2013 20:07:08 -0800 Subject: [PATCH 063/291] Fixing if/else coding style for preferred hosts logging --- .../main/scala/spark/scheduler/cluster/TaskSetManager.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index dda7a6c64a..a842afcdeb 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -201,9 +201,8 @@ private[spark] class TaskSetManager( val taskId = sched.newTaskId() // Figure out whether this should count as a preferred launch val preferred = isPreferredLocation(task, host) - val prefStr = if (preferred) "preferred" else - "non-preferred, not one of " + - task.preferredLocations.mkString(", ") + val prefStr = if (preferred) "preferred" + else "non-preferred, not one of " + task.preferredLocations.mkString(", ") logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( taskSet.id, index, taskId, slaveId, host, prefStr)) // Do various bookkeeping From f7adb382ace7f54c5093bf90574b3f9dd0d35534 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Tue, 8 Jan 2013 03:19:43 -0800 Subject: [PATCH 064/291] Activate hadoop1 if property hadoop is missing. hadoop2 can be activated now by using -Dhadoop -Phadoop2. 
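The two TaskSetManager patches above (062 and 063) only change a log message, but the intent is easy to miss in the diff: when a task is launched somewhere other than one of its preferred hosts, the scheduler now also logs which hosts it would have preferred. The short Scala sketch below is not part of any patch; it simply reproduces how that message is built, using made-up stand-ins for task.preferredLocations and the chosen slave host.

    // Stand-in values; in TaskSetManager these come from the task and the offered slave.
    val preferredLocations = Seq("host1", "host2")
    val host = "host3"
    val preferred = preferredLocations.contains(host)
    val prefStr = if (preferred) {
      "preferred"
    } else {
      "non-preferred, not one of " + preferredLocations.mkString(", ")
    }
    println(prefStr)  // prints: non-preferred, not one of host1, host2
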
--- bagel/pom.xml | 4 +++- core/pom.xml | 4 +++- examples/pom.xml | 4 +++- pom.xml | 4 +++- repl-bin/pom.xml | 4 +++- repl/pom.xml | 4 +++- 6 files changed, 18 insertions(+), 6 deletions(-) diff --git a/bagel/pom.xml b/bagel/pom.xml index 4ca643bbb7..85b2077026 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -46,7 +46,9 @@ hadoop1 - true + + !hadoop + diff --git a/core/pom.xml b/core/pom.xml index cd789a7db0..005d8fe498 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -160,7 +160,9 @@ hadoop1 - true + + !hadoop + diff --git a/examples/pom.xml b/examples/pom.xml index 9e638c8284..3f738a3f8c 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -46,7 +46,9 @@ hadoop1 - true + + !hadoop + diff --git a/pom.xml b/pom.xml index 0e2d93c170..ea5b9c9d05 100644 --- a/pom.xml +++ b/pom.xml @@ -482,7 +482,9 @@ hadoop1 - true + + !hadoop + 1 diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index aa9895eda2..fecb01f3cd 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -71,7 +71,9 @@ hadoop1 - true + + !hadoop + hadoop1 diff --git a/repl/pom.xml b/repl/pom.xml index ba7a051310..04b2c35beb 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -73,7 +73,9 @@ hadoop1 - true + + !hadoop + hadoop1 From e4cb72da8a5428c6b9097e92ddbdf4ceee087b85 Mon Sep 17 00:00:00 2001 From: shane-huang Date: Tue, 8 Jan 2013 22:40:58 +0800 Subject: [PATCH 065/291] Fix an issue in ConnectionManager where sendingMessage may create too many unnecessary SendingConnections. --- .../main/scala/spark/network/Connection.scala | 7 +++++-- .../spark/network/ConnectionManager.scala | 17 +++++++++-------- .../spark/network/ConnectionManagerTest.scala | 18 +++++++++--------- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index 80262ab7b4..95096fd0ba 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -135,8 +135,11 @@ extends Connection(SocketChannel.open, selector_) { val chunk = message.getChunkForSending(defaultChunkSize) if (chunk.isDefined) { messages += message // this is probably incorrect, it wont work as fifo - if (!message.started) logDebug("Starting to send [" + message + "]") - message.started = true + if (!message.started) { + logDebug("Starting to send [" + message + "]") + message.started = true + message.startTime = System.currentTimeMillis + } return chunk } else { /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/ diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 642fa4b525..e7bd2d3bbd 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -43,12 +43,12 @@ private[spark] class ConnectionManager(port: Int) extends Logging { } val selector = SelectorProvider.provider.openSelector() - val handleMessageExecutor = Executors.newFixedThreadPool(4) + val handleMessageExecutor = Executors.newFixedThreadPool(20) val serverChannel = ServerSocketChannel.open() val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] val messageStatuses = new HashMap[Int, MessageStatus] - val connectionRequests = new SynchronizedQueue[SendingConnection] + val 
connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] val sendMessageRequests = new Queue[(Message, SendingConnection)] @@ -78,11 +78,12 @@ private[spark] class ConnectionManager(port: Int) extends Logging { def run() { try { - while(!selectorThread.isInterrupted) { - while(!connectionRequests.isEmpty) { - val sendingConnection = connectionRequests.dequeue + while(!selectorThread.isInterrupted) { + for( (connectionManagerId, sendingConnection) <- connectionRequests) { + //val sendingConnection = connectionRequests.dequeue sendingConnection.connect() addConnection(sendingConnection) + connectionRequests -= connectionManagerId } sendMessageRequests.synchronized { while(!sendMessageRequests.isEmpty) { @@ -300,8 +301,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { def startNewConnection(): SendingConnection = { val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) - val newConnection = new SendingConnection(inetSocketAddress, selector) - connectionRequests += newConnection + val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, new SendingConnection(inetSocketAddress, selector)) newConnection } val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress) @@ -465,7 +465,7 @@ private[spark] object ConnectionManager { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) }).foreach(f => { - val g = Await.result(f, 1 second) + val g = Await.result(f, 10 second) if (!g.isDefined) println("Failed") }) val finishTime = System.currentTimeMillis @@ -473,6 +473,7 @@ private[spark] object ConnectionManager { val mb = size * count / 1024.0 / 1024.0 val ms = finishTime - startTime val tput = mb * 1000.0 / ms + println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") println("--------------------------") println() } diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala index 47ceaf3c07..0e79c518e0 100644 --- a/core/src/main/scala/spark/network/ConnectionManagerTest.scala +++ b/core/src/main/scala/spark/network/ConnectionManagerTest.scala @@ -13,8 +13,8 @@ import akka.util.duration._ private[spark] object ConnectionManagerTest extends Logging{ def main(args: Array[String]) { - if (args.length < 2) { - println("Usage: ConnectionManagerTest ") + if (args.length < 5) { + println("Usage: ConnectionManagerTest ") System.exit(1) } @@ -29,16 +29,16 @@ private[spark] object ConnectionManagerTest extends Logging{ /*println("Slaves")*/ /*slaves.foreach(println)*/ - - val slaveConnManagerIds = sc.parallelize(0 until slaves.length, slaves.length).map( + val tasknum = args(2).toInt + val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map( i => SparkEnv.get.connectionManager.id).collect() println("\nSlave ConnectionManagerIds") slaveConnManagerIds.foreach(println) println - val count = 10 + val count = args(4).toInt (0 until count).foreach(i => { - val resultStrs = sc.parallelize(0 until slaves.length, slaves.length).map(i => { + val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => { val connManager = SparkEnv.get.connectionManager val thisConnManagerId = connManager.id 
connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { @@ -46,7 +46,7 @@ private[spark] object ConnectionManagerTest extends Logging{ None }) - val size = 100 * 1024 * 1024 + val size = (args(3).toInt) * 1024 * 1024 val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) buffer.flip @@ -56,13 +56,13 @@ private[spark] object ConnectionManagerTest extends Logging{ logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) }) - val results = futures.map(f => Await.result(f, 1.second)) + val results = futures.map(f => Await.result(f, 999.second)) val finishTime = System.currentTimeMillis Thread.sleep(5000) val mb = size * results.size / 1024.0 / 1024.0 val ms = finishTime - startTime - val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s" + val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s" logInfo(resultStr) resultStr }).collect() From 8ac0f35be42765fcd6f02dcf0f070f2ef2377a85 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 8 Jan 2013 09:57:45 -0600 Subject: [PATCH 066/291] Add JavaRDDLike.keyBy. --- core/src/main/scala/spark/api/java/JavaRDDLike.scala | 8 ++++++++ core/src/test/scala/spark/JavaAPISuite.java | 12 ++++++++++++ 2 files changed, 20 insertions(+) diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 81d3a94466..d15f6dd02f 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -298,4 +298,12 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Save this RDD as a SequenceFile of serialized objects. */ def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path) + + /** + * Creates tuples of the elements in this RDD by applying `f`. + */ + def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = { + implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] + JavaPairRDD.fromRDD(rdd.keyBy(f)) + } } diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 0817d1146c..c61913fc82 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -629,4 +629,16 @@ public class JavaAPISuite implements Serializable { floatAccum.setValue(5.0f); Assert.assertEquals((Float) 5.0f, floatAccum.value()); } + + @Test + public void keyBy() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2)); + List> s = rdd.keyBy(new Function() { + public String call(Integer t) throws Exception { + return t.toString(); + } + }).collect(); + Assert.assertEquals(new Tuple2("1", 1), s.get(0)); + Assert.assertEquals(new Tuple2("2", 2), s.get(1)); + } } From c3f1675f9c4a1be9eebf9512795abc968ac29ba2 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 8 Jan 2013 14:44:33 -0600 Subject: [PATCH 067/291] Retrieve jars to a flat directory so * can be used for the classpath. 
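Patch 065 above replaces the queue of pending SendingConnections with a map keyed by ConnectionManagerId, so repeated sends to the same destination reuse a single pending connection instead of allocating a new one each time. The following self-contained Scala sketch shows the getOrElseUpdate idiom that provides the de-duplication; Id and Conn are simplified stand-ins for the real ConnectionManagerId and SendingConnection, not Spark classes.

    import scala.collection.mutable

    // Simplified stand-ins for ConnectionManagerId and SendingConnection.
    case class Id(host: String, port: Int)
    class Conn(val id: Id)

    val connectionRequests = mutable.HashMap.empty[Id, Conn]

    // Mirrors startNewConnection() after patch 065: only the first request for a
    // destination allocates a connection; later requests return the pending one.
    def startNewConnection(id: Id): Conn =
      connectionRequests.getOrElseUpdate(id, new Conn(id))

    val a = startNewConnection(Id("host1", 7077))
    val b = startNewConnection(Id("host1", 7077))
    assert(a eq b)  // the second request did not create a duplicate connection
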
--- project/SparkBuild.scala | 1 + run | 12 +++--------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 7c7c33131a..518c4130f0 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -38,6 +38,7 @@ object SparkBuild extends Build { scalacOptions := Seq(/*"-deprecation",*/ "-unchecked", "-optimize"), // -deprecation is too noisy due to usage of old Hadoop API, enable it once that's no longer an issue unmanagedJars in Compile <<= baseDirectory map { base => (base / "lib" ** "*.jar").classpath }, retrieveManaged := true, + retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", transitiveClassifiers in Scope.GlobalScope := Seq("sources"), testListeners <<= target.map(t => Seq(new eu.henkelmann.sbt.JUnitXmlTestsListener(t.getAbsolutePath))), diff --git a/run b/run index 1528f83534..6cfe9631af 100755 --- a/run +++ b/run @@ -75,16 +75,10 @@ CLASSPATH+=":$CORE_DIR/src/main/resources" CLASSPATH+=":$REPL_DIR/target/scala-$SCALA_VERSION/classes" CLASSPATH+=":$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes" if [ -e "$FWDIR/lib_managed" ]; then - for jar in `find "$FWDIR/lib_managed/jars" -name '*jar'`; do - CLASSPATH+=":$jar" - done - for jar in `find "$FWDIR/lib_managed/bundles" -name '*jar'`; do - CLASSPATH+=":$jar" - done + CLASSPATH+=":$FWDIR/lib_managed/jars/*" + CLASSPATH+=":$FWDIR/lib_managed/bundles/*" fi -for jar in `find "$REPL_DIR/lib" -name '*jar'`; do - CLASSPATH+=":$jar" -done +CLASSPATH+=":$REPL_DIR/lib/*" for jar in `find "$REPL_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do CLASSPATH+=":$jar" done From b57dd0f16024a82dfc223e69528b9908b931f068 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 8 Jan 2013 16:04:41 -0800 Subject: [PATCH 068/291] Add mapPartitionsWithSplit() to PySpark. 
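The diff that follows threads the split index through the Scala-side PythonRDD and exposes mapPartitionsWithSplit on Python RDDs. Conceptually the operation applies a function to each partition together with that partition's index. The Scala sketch below is only a plain-collections model of that behaviour, not Spark API calls; it mirrors the doctest added in rdd.py, where yielding each split index over four partitions sums to 6.

    // `parts` stands in for an RDD's four splits.
    val parts = Vector(Seq(1), Seq(2), Seq(3), Seq(4))

    def mapPartitionsWithSplit[T, U](partitions: Vector[Seq[T]])
                                    (f: (Int, Iterator[T]) => Iterator[U]): Vector[Seq[U]] =
      partitions.zipWithIndex.map { case (part, idx) => f(idx, part.iterator).toList }

    // Yield only the split index from each partition, as the new doctest does.
    val indices = mapPartitionsWithSplit(parts)((idx, _) => Iterator(idx)).flatten
    assert(indices.sum == 6)  // 0 + 1 + 2 + 3, matching the doctest's expected result
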
--- .../scala/spark/api/python/PythonRDD.scala | 5 +++ docs/python-programming-guide.md | 1 - python/pyspark/rdd.py | 33 ++++++++++++------- python/pyspark/worker.py | 4 ++- 4 files changed, 30 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 79d824d494..f431ef28d3 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -65,6 +65,9 @@ private[spark] class PythonRDD[T: ClassManifest]( SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) val dOut = new DataOutputStream(proc.getOutputStream) + // Split index + dOut.writeInt(split.index) + // Broadcast variables dOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { dOut.writeLong(broadcast.id) @@ -72,10 +75,12 @@ private[spark] class PythonRDD[T: ClassManifest]( dOut.write(broadcast.value) dOut.flush() } + // Serialized user code for (elem <- command) { out.println(elem) } out.flush() + // Data values for (elem <- parent.iterator(split, context)) { PythonRDD.writeAsPickle(elem, dOut) } diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index d963551296..78ef310a00 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -19,7 +19,6 @@ There are a few key differences between the Python and Scala APIs: - Accumulators - Special functions on RDDs of doubles, such as `mean` and `stdev` - `lookup` - - `mapPartitionsWithSplit` - `persist` at storage levels other than `MEMORY_ONLY` - `sample` - `sort` diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4ba417b2a2..1d36da42b0 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -55,7 +55,7 @@ class RDD(object): """ Return a new RDD containing the distinct elements in this RDD. """ - def func(iterator): return imap(f, iterator) + def func(split, iterator): return imap(f, iterator) return PipelinedRDD(self, func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -69,8 +69,8 @@ class RDD(object): >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect()) [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ - def func(iterator): return chain.from_iterable(imap(f, iterator)) - return self.mapPartitions(func, preservesPartitioning) + def func(s, iterator): return chain.from_iterable(imap(f, iterator)) + return self.mapPartitionsWithSplit(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): """ @@ -81,9 +81,20 @@ class RDD(object): >>> rdd.mapPartitions(f).collect() [3, 7] """ - return PipelinedRDD(self, f, preservesPartitioning) + def func(s, iterator): return f(iterator) + return self.mapPartitionsWithSplit(func) - # TODO: mapPartitionsWithSplit + def mapPartitionsWithSplit(self, f, preservesPartitioning=False): + """ + Return a new RDD by applying a function to each partition of this RDD, + while tracking the index of the original partition. 
+ + >>> rdd = sc.parallelize([1, 2, 3, 4], 4) + >>> def f(splitIndex, iterator): yield splitIndex + >>> rdd.mapPartitionsWithSplit(f).sum() + 6 + """ + return PipelinedRDD(self, f, preservesPartitioning) def filter(self, f): """ @@ -362,7 +373,7 @@ class RDD(object): >>> ''.join(input(glob(tempFile.name + "/part-0000*"))) '0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n' """ - def func(iterator): + def func(split, iterator): return (str(x).encode("utf-8") for x in iterator) keyed = PipelinedRDD(self, func) keyed._bypass_serializer = True @@ -500,7 +511,7 @@ class RDD(object): # Transferring O(n) objects to Java is too expensive. Instead, we'll # form the hash buckets in Python, transferring O(numSplits) objects # to Java. Each object is a (splitNumber, [objects]) pair. - def add_shuffle_key(iterator): + def add_shuffle_key(split, iterator): buckets = defaultdict(list) for (k, v) in iterator: buckets[hashFunc(k) % numSplits].append((k, v)) @@ -653,8 +664,8 @@ class PipelinedRDD(RDD): def __init__(self, prev, func, preservesPartitioning=False): if isinstance(prev, PipelinedRDD) and not prev.is_cached: prev_func = prev.func - def pipeline_func(iterator): - return func(prev_func(iterator)) + def pipeline_func(split, iterator): + return func(split, prev_func(split, iterator)) self.func = pipeline_func self.preservesPartitioning = \ prev.preservesPartitioning and preservesPartitioning @@ -677,8 +688,8 @@ class PipelinedRDD(RDD): if not self._bypass_serializer and self.ctx.batchSize != 1: oldfunc = self.func batchSize = self.ctx.batchSize - def batched_func(iterator): - return batched(oldfunc(iterator), batchSize) + def batched_func(split, iterator): + return batched(oldfunc(split, iterator), batchSize) func = batched_func cmds = [func, self._bypass_serializer] pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 9f6b507dbd..3d792bbaa2 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -21,6 +21,7 @@ def load_obj(): def main(): + split_index = read_int(sys.stdin) num_broadcast_variables = read_int(sys.stdin) for _ in range(num_broadcast_variables): bid = read_long(sys.stdin) @@ -32,7 +33,8 @@ def main(): dumps = lambda x: x else: dumps = dump_pickle - for obj in func(read_from_pickle_file(sys.stdin)): + iterator = read_from_pickle_file(sys.stdin) + for obj in func(split_index, iterator): write_with_length(dumps(obj), old_stdout) From 9cc764f52323baa3a218ce9e301d3cc98f1e8b20 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Tue, 8 Jan 2013 22:29:57 -0800 Subject: [PATCH 069/291] Code style --- .../scala/spark/scheduler/cluster/TaskSetManager.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index a842afcdeb..a089b71644 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -201,8 +201,11 @@ private[spark] class TaskSetManager( val taskId = sched.newTaskId() // Figure out whether this should count as a preferred launch val preferred = isPreferredLocation(task, host) - val prefStr = if (preferred) "preferred" - else "non-preferred, not one of " + task.preferredLocations.mkString(", ") + val prefStr = if (preferred) { + "preferred" + } else { + "non-preferred, not one of " + task.preferredLocations.mkString(", ") + } logInfo("Starting task %s:%d as 
TID %s on slave %s: %s (%s)".format( taskSet.id, index, taskId, slaveId, host, prefStr)) // Do various bookkeeping From 6e8c8f61c478ec5829677a38a624f17ac9609f74 Mon Sep 17 00:00:00 2001 From: Tyson Date: Wed, 9 Jan 2013 10:35:23 -0500 Subject: [PATCH 070/291] Added the spray implicit marshaller library Added the io.spray JSON library --- project/SparkBuild.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 2f67bb9921..f2b79d9ed8 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -133,6 +133,8 @@ object SparkBuild extends Build { "colt" % "colt" % "1.2.0", "cc.spray" % "spray-can" % "1.0-M2.1", "cc.spray" % "spray-server" % "1.0-M2.1", + "cc.spray" %% "spray-json" % "1.1.1", + "io.spray" %% "spray-json" % "1.2.3", "org.apache.mesos" % "mesos" % "0.9.0-incubating" ) ++ (if (HADOOP_MAJOR_VERSION == "2") Some("org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION) else None).toSeq, unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / ("src/hadoop" + HADOOP_MAJOR_VERSION + "/scala") } From 269fe018c73a0d4e12a3c881dbd3bd807e504891 Mon Sep 17 00:00:00 2001 From: Tyson Date: Wed, 9 Jan 2013 10:35:59 -0500 Subject: [PATCH 071/291] JSON object definitions --- .../scala/spark/deploy/JsonProtocol.scala | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 core/src/main/scala/spark/deploy/JsonProtocol.scala diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala new file mode 100644 index 0000000000..dc7da85f9c --- /dev/null +++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala @@ -0,0 +1,59 @@ +package spark.deploy + +import master.{JobInfo, WorkerInfo} +import spray.json._ + +/** + * spray-json helper class containing implicit conversion to json for marshalling responses + */ +private[spark] object JsonProtocol extends DefaultJsonProtocol { + import cc.spray.json._ + + implicit object WorkerInfoJsonFormat extends RootJsonWriter[WorkerInfo] { + def write(obj: WorkerInfo) = JsObject( + "id" -> JsString(obj.id), + "host" -> JsString(obj.host), + "webuiaddress" -> JsString(obj.webUiAddress), + "cores" -> JsNumber(obj.cores), + "coresused" -> JsNumber(obj.coresUsed), + "memory" -> JsNumber(obj.memory), + "memoryused" -> JsNumber(obj.memoryUsed) + ) + } + + implicit object JobInfoJsonFormat extends RootJsonWriter[JobInfo] { + def write(obj: JobInfo) = JsObject( + "starttime" -> JsNumber(obj.startTime), + "id" -> JsString(obj.id), + "name" -> JsString(obj.desc.name), + "cores" -> JsNumber(obj.desc.cores), + "user" -> JsString(obj.desc.user), + "memoryperslave" -> JsNumber(obj.desc.memoryPerSlave), + "submitdate" -> JsString(obj.submitDate.toString)) + } + + implicit object MasterStateJsonFormat extends RootJsonWriter[MasterState] { + def write(obj: MasterState) = JsObject( + "url" -> JsString("spark://" + obj.uri), + "workers" -> JsArray(obj.workers.toList.map(_.toJson)), + "cores" -> JsNumber(obj.workers.map(_.cores).sum), + "coresused" -> JsNumber(obj.workers.map(_.coresUsed).sum), + "memory" -> JsNumber(obj.workers.map(_.memory).sum), + "memoryused" -> JsNumber(obj.workers.map(_.memoryUsed).sum), + "activejobs" -> JsArray(obj.activeJobs.toList.map(_.toJson)), + "completedjobs" -> JsArray(obj.completedJobs.toList.map(_.toJson)) + ) + } + + implicit object WorkerStateJsonFormat extends RootJsonWriter[WorkerState] { + def write(obj: WorkerState) = JsObject( + "id" -> JsString(obj.workerId), + "masterurl" -> 
JsString(obj.masterUrl), + "masterwebuiurl" -> JsString(obj.masterWebUiUrl), + "cores" -> JsNumber(obj.cores), + "coresused" -> JsNumber(obj.coresUsed), + "memory" -> JsNumber(obj.memory), + "memoryused" -> JsNumber(obj.memoryUsed) + ) + } +} From 0da2ff102e1e8ac50059252a153a1b9b3e74b6b8 Mon Sep 17 00:00:00 2001 From: Tyson Date: Wed, 9 Jan 2013 10:36:56 -0500 Subject: [PATCH 072/291] Added url query parameter json and handler --- .../spark/deploy/master/MasterWebUI.scala | 19 +++++++++++++----- .../spark/deploy/worker/WorkerWebUI.scala | 20 ++++++++++++++----- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index 3cdd3721f5..dfec1d1dc5 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -9,6 +9,9 @@ import cc.spray.Directives import cc.spray.directives._ import cc.spray.typeconversion.TwirlSupport._ import spark.deploy._ +import cc.spray.http.MediaTypes +import JsonProtocol._ +import cc.spray.typeconversion.SprayJsonSupport._ private[spark] class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives { @@ -19,13 +22,19 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct val handler = { get { - path("") { - completeWith { + (path("") & parameters('json ?)) { + case Some(js) => val future = master ? RequestMasterState - future.map { - masterState => spark.deploy.master.html.index.render(masterState.asInstanceOf[MasterState]) + respondWithMediaType(MediaTypes.`application/json`) { ctx => + ctx.complete(future.mapTo[MasterState]) + } + case None => + completeWith { + val future = master ? RequestMasterState + future.map { + masterState => spark.deploy.master.html.index.render(masterState.asInstanceOf[MasterState]) + } } - } } ~ path("job") { parameter("jobId") { jobId => diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala index d06f4884ee..a168f54ca0 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala @@ -7,7 +7,10 @@ import akka.util.Timeout import akka.util.duration._ import cc.spray.Directives import cc.spray.typeconversion.TwirlSupport._ -import spark.deploy.{WorkerState, RequestWorkerState} +import spark.deploy.{JsonProtocol, WorkerState, RequestWorkerState} +import cc.spray.http.MediaTypes +import JsonProtocol._ +import cc.spray.typeconversion.SprayJsonSupport._ private[spark] class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives { @@ -18,13 +21,20 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct val handler = { get { - path("") { - completeWith{ + (path("") & parameters('json ?)) { + case Some(js) => { val future = worker ? RequestWorkerState - future.map { workerState => - spark.deploy.worker.html.index(workerState.asInstanceOf[WorkerState]) + respondWithMediaType(MediaTypes.`application/json`) { ctx => + ctx.complete(future.mapTo[WorkerState]) } } + case None => + completeWith{ + val future = worker ? 
RequestWorkerState + future.map { workerState => + spark.deploy.worker.html.index(workerState.asInstanceOf[WorkerState]) + } + } } ~ path("log") { parameters("jobId", "executorId", "logType") { (jobId, executorId, logType) => From bf9d9946f97782c9212420123b4a042918d7df5e Mon Sep 17 00:00:00 2001 From: Tyson Date: Wed, 9 Jan 2013 11:29:22 -0500 Subject: [PATCH 073/291] Query parameter reformatted to be more extensible and routing more robust --- core/src/main/scala/spark/deploy/master/MasterWebUI.scala | 6 +++--- core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index dfec1d1dc5..a96b55d6f3 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -22,13 +22,13 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct val handler = { get { - (path("") & parameters('json ?)) { - case Some(js) => + (path("") & parameters('format ?)) { + case Some(js) if js.equalsIgnoreCase("json") => val future = master ? RequestMasterState respondWithMediaType(MediaTypes.`application/json`) { ctx => ctx.complete(future.mapTo[MasterState]) } - case None => + case _ => completeWith { val future = master ? RequestMasterState future.map { diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala index a168f54ca0..84b6c16bd6 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala @@ -21,14 +21,14 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct val handler = { get { - (path("") & parameters('json ?)) { - case Some(js) => { + (path("") & parameters('format ?)) { + case Some(js) if js.equalsIgnoreCase("json") => { val future = worker ? RequestWorkerState respondWithMediaType(MediaTypes.`application/json`) { ctx => ctx.complete(future.mapTo[WorkerState]) } } - case None => + case _ => completeWith{ val future = worker ? RequestWorkerState future.map { workerState => From 549ee388a125ac7014ae3dadfb16c582e250c654 Mon Sep 17 00:00:00 2001 From: Tyson Date: Wed, 9 Jan 2013 15:12:23 -0500 Subject: [PATCH 074/291] Removed io.spray spray-json dependency as it is not needed. 
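Patches 070 through 074 above add JSON views of master and worker state: a cc.spray.json protocol object supplies RootJsonWriter instances, and the web UIs answer ?format=json by completing the request with the mapped future. As a reminder of the writer pattern those patches rely on, here is a minimal, self-contained sketch. WorkerSummary is a made-up stand-in for the real WorkerInfo, and only the cc.spray.json constructs that already appear in JsonProtocol.scala are used.

    import cc.spray.json._

    // Made-up stand-in for spark.deploy.master.WorkerInfo.
    case class WorkerSummary(id: String, host: String, cores: Int, coresUsed: Int)

    object WorkerSummaryJson extends DefaultJsonProtocol {
      // The writer builds the JsObject field by field, exactly as JsonProtocol.scala does.
      implicit object WorkerSummaryWriter extends RootJsonWriter[WorkerSummary] {
        def write(w: WorkerSummary) = JsObject(
          "id"        -> JsString(w.id),
          "host"      -> JsString(w.host),
          "cores"     -> JsNumber(w.cores),
          "coresused" -> JsNumber(w.coresUsed)
        )
      }
    }

    import WorkerSummaryJson._
    println(WorkerSummary("worker-1", "host1", 8, 2).toJson)
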
--- core/src/main/scala/spark/deploy/JsonProtocol.scala | 4 +--- project/SparkBuild.scala | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala index dc7da85f9c..f14f804b3a 100644 --- a/core/src/main/scala/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala @@ -1,14 +1,12 @@ package spark.deploy import master.{JobInfo, WorkerInfo} -import spray.json._ +import cc.spray.json._ /** * spray-json helper class containing implicit conversion to json for marshalling responses */ private[spark] object JsonProtocol extends DefaultJsonProtocol { - import cc.spray.json._ - implicit object WorkerInfoJsonFormat extends RootJsonWriter[WorkerInfo] { def write(obj: WorkerInfo) = JsObject( "id" -> JsString(obj.id), diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f2b79d9ed8..c63efbdd2a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -134,7 +134,6 @@ object SparkBuild extends Build { "cc.spray" % "spray-can" % "1.0-M2.1", "cc.spray" % "spray-server" % "1.0-M2.1", "cc.spray" %% "spray-json" % "1.1.1", - "io.spray" %% "spray-json" % "1.2.3", "org.apache.mesos" % "mesos" % "0.9.0-incubating" ) ++ (if (HADOOP_MAJOR_VERSION == "2") Some("org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION) else None).toSeq, unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / ("src/hadoop" + HADOOP_MAJOR_VERSION + "/scala") } From e3861ae3953d7cab66160833688c8baf84e835ad Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Wed, 9 Jan 2013 17:03:25 -0600 Subject: [PATCH 075/291] Provide and expose a default Hadoop Configuration. Any "hadoop.*" system properties will be passed along into configuration. --- core/src/main/scala/spark/SparkContext.scala | 18 ++++++++++++++---- .../spark/api/java/JavaSparkContext.scala | 7 +++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index bbf8272eb3..36e0938854 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -187,6 +187,18 @@ class SparkContext( private var dagScheduler = new DAGScheduler(taskScheduler) + /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ + val hadoopConfiguration = { + val conf = new Configuration() + // Copy any "hadoop.foo=bar" system properties into conf as "foo=bar" + for (key <- System.getProperties.keys.asInstanceOf[Set[String]] if key.startsWith("hadoop.")) { + conf.set(key.substring("hadoop.".length), System.getProperty(key)) + } + val bufferSize = System.getProperty("spark.buffer.size", "65536") + conf.set("io.file.buffer.size", bufferSize) + conf + } + // Methods for creating RDDs /** Distribute a local Scala collection to form an RDD. 
*/ @@ -231,10 +243,8 @@ class SparkContext( valueClass: Class[V], minSplits: Int = defaultMinSplits ) : RDD[(K, V)] = { - val conf = new JobConf() + val conf = new JobConf(hadoopConfiguration) FileInputFormat.setInputPaths(conf, path) - val bufferSize = System.getProperty("spark.buffer.size", "65536") - conf.set("io.file.buffer.size", bufferSize) new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits) } @@ -276,7 +286,7 @@ class SparkContext( fm.erasure.asInstanceOf[Class[F]], km.erasure.asInstanceOf[Class[K]], vm.erasure.asInstanceOf[Class[V]], - new Configuration) + hadoopConfiguration) } /** diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index 88ab2846be..12e2a0bdac 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -355,6 +355,13 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def clearFiles() { sc.clearFiles() } + + /** + * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse. + */ + def hadoopConfiguration() { + sc.hadoopConfiguration + } } object JavaSparkContext { From 1a64432ba50904c3933d8a9539a619fc94b3b30b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 9 Jan 2013 20:30:36 -0800 Subject: [PATCH 076/291] Indicate success/failure in PySpark test script. --- python/run-tests | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/python/run-tests b/python/run-tests index da9e24cb1f..fcdd1e27a7 100755 --- a/python/run-tests +++ b/python/run-tests @@ -3,7 +3,24 @@ # Figure out where the Scala framework is installed FWDIR="$(cd `dirname $0`; cd ../; pwd)" +FAILED=0 + $FWDIR/pyspark pyspark/rdd.py +FAILED=$(($?||$FAILED)) + $FWDIR/pyspark -m doctest pyspark/broadcast.py +FAILED=$(($?||$FAILED)) + +if [[ $FAILED != 0 ]]; then + echo -en "\033[31m" # Red + echo "Had test failures; see logs." + echo -en "\033[0m" # No color + exit -1 +else + echo -en "\033[32m" # Green + echo "Tests passed." + echo -en "\033[0m" # No color +fi # TODO: in the long-run, it would be nice to use a test runner like `nose`. +# The doctest fixtures are the current barrier to doing this. From d55f2b98822faa7d71f5fce2bfa980f8265e0610 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 9 Jan 2013 21:21:23 -0800 Subject: [PATCH 077/291] Use take() instead of takeSample() in PySpark kmeans example. This is a temporary change until we port takeSample(). 
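Stepping back to patch 075 above: the point of the new SparkContext.hadoopConfiguration field is that Hadoop input methods now derive their JobConf from one shared, pre-configured Configuration instead of building a fresh one per call. The sketch below is a standalone illustration of that reuse, not Spark code; the buffer-size setting mirrors what the patch moves out of hadoopFile and into the shared configuration.

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.mapred.JobConf

    // A context-wide configuration, populated once.
    val hadoopConfiguration = new Configuration()
    hadoopConfiguration.set("io.file.buffer.size", System.getProperty("spark.buffer.size", "65536"))

    // Each input method clones the shared settings instead of starting from an empty JobConf.
    val jobConf = new JobConf(hadoopConfiguration)
    println("io.file.buffer.size = " + jobConf.get("io.file.buffer.size"))
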
--- python/examples/kmeans.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/examples/kmeans.py b/python/examples/kmeans.py index ad2be21178..72cf9f88c6 100644 --- a/python/examples/kmeans.py +++ b/python/examples/kmeans.py @@ -33,7 +33,9 @@ if __name__ == "__main__": K = int(sys.argv[3]) convergeDist = float(sys.argv[4]) - kPoints = data.takeSample(False, K, 34) + # TODO: change this after we port takeSample() + #kPoints = data.takeSample(False, K, 34) + kPoints = data.take(K) tempDist = 1.0 while tempDist > convergeDist: From 9930a95d217045c4c22c2575080a03e4b0fd2426 Mon Sep 17 00:00:00 2001 From: shane-huang Date: Thu, 10 Jan 2013 20:09:34 +0800 Subject: [PATCH 078/291] Modified Patch according to comments --- .../main/scala/spark/network/Connection.scala | 8 ++++---- .../spark/network/ConnectionManager.scala | 9 ++++----- .../spark/network/ConnectionManagerTest.scala | 20 +++++++++++++------ 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index 95096fd0ba..c193bf7c8d 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -136,10 +136,10 @@ extends Connection(SocketChannel.open, selector_) { if (chunk.isDefined) { messages += message // this is probably incorrect, it wont work as fifo if (!message.started) { - logDebug("Starting to send [" + message + "]") - message.started = true - message.startTime = System.currentTimeMillis - } + logDebug("Starting to send [" + message + "]") + message.started = true + message.startTime = System.currentTimeMillis + } return chunk } else { /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/ diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index e7bd2d3bbd..36c01ad629 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -43,12 +43,12 @@ private[spark] class ConnectionManager(port: Int) extends Logging { } val selector = SelectorProvider.provider.openSelector() - val handleMessageExecutor = Executors.newFixedThreadPool(20) + val handleMessageExecutor = Executors.newFixedThreadPool(System.getProperty("spark.core.connection.handler.threads","20").toInt) val serverChannel = ServerSocketChannel.open() val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] val messageStatuses = new HashMap[Int, MessageStatus] - val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] + val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] val sendMessageRequests = new Queue[(Message, SendingConnection)] @@ -78,9 +78,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging { def run() { try { - while(!selectorThread.isInterrupted) { + while(!selectorThread.isInterrupted) { for( (connectionManagerId, sendingConnection) <- connectionRequests) { - //val sendingConnection = connectionRequests.dequeue sendingConnection.connect() 
addConnection(sendingConnection) connectionRequests -= connectionManagerId @@ -465,7 +464,7 @@ private[spark] object ConnectionManager { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) }).foreach(f => { - val g = Await.result(f, 10 second) + val g = Await.result(f, 1 second) if (!g.isDefined) println("Failed") }) val finishTime = System.currentTimeMillis diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala index 0e79c518e0..533e4610f3 100644 --- a/core/src/main/scala/spark/network/ConnectionManagerTest.scala +++ b/core/src/main/scala/spark/network/ConnectionManagerTest.scala @@ -13,8 +13,14 @@ import akka.util.duration._ private[spark] object ConnectionManagerTest extends Logging{ def main(args: Array[String]) { - if (args.length < 5) { - println("Usage: ConnectionManagerTest ") + // - the master URL + // - a list slaves to run connectionTest on + //[num of tasks] - the number of parallel tasks to be initiated default is number of slave hosts + //[size of msg in MB (integer)] - the size of messages to be sent in each task, default is 10 + //[count] - how many times to run, default is 3 + //[await time in seconds] : await time (in seconds), default is 600 + if (args.length < 2) { + println("Usage: ConnectionManagerTest [num of tasks] [size of msg in MB (integer)] [count] [await time in seconds)] ") System.exit(1) } @@ -29,14 +35,17 @@ private[spark] object ConnectionManagerTest extends Logging{ /*println("Slaves")*/ /*slaves.foreach(println)*/ - val tasknum = args(2).toInt + val tasknum = if (args.length > 2) args(2).toInt else slaves.length + val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024 + val count = if (args.length > 4) args(4).toInt else 3 + val awaitTime = (if (args.length > 5) args(5).toInt else 600 ).second + println("Running "+count+" rounds of test: " + "parallel tasks = " + tasknum + ", msg size = " + size/1024/1024 + " MB, awaitTime = " + awaitTime) val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map( i => SparkEnv.get.connectionManager.id).collect() println("\nSlave ConnectionManagerIds") slaveConnManagerIds.foreach(println) println - val count = args(4).toInt (0 until count).foreach(i => { val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => { val connManager = SparkEnv.get.connectionManager @@ -46,7 +55,6 @@ private[spark] object ConnectionManagerTest extends Logging{ None }) - val size = (args(3).toInt) * 1024 * 1024 val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) buffer.flip @@ -56,7 +64,7 @@ private[spark] object ConnectionManagerTest extends Logging{ logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) }) - val results = futures.map(f => Await.result(f, 999.second)) + val results = futures.map(f => Await.result(f, awaitTime)) val finishTime = System.currentTimeMillis Thread.sleep(5000) From 49c74ba2af2ab6fe5eda16dbcd35b30b46072a3a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 10 Jan 2013 07:45:12 -0800 Subject: [PATCH 079/291] Change PYSPARK_PYTHON_EXEC to PYSPARK_PYTHON. 
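Besides parameterising ConnectionManagerTest, patch 078 above makes the ConnectionManager message-handler pool size configurable through a system property rather than the hard-coded 20 introduced in patch 065. The fragment below is a self-contained illustration of that property-with-default pattern; it is not the ConnectionManager itself.

    import java.util.concurrent.Executors

    // Pool size comes from -Dspark.core.connection.handler.threads=..., defaulting to 20.
    val handlerThreads = System.getProperty("spark.core.connection.handler.threads", "20").toInt
    val handleMessageExecutor = Executors.newFixedThreadPool(handlerThreads)

    println("message handler pool size = " + handlerThreads)
    handleMessageExecutor.shutdown()  // clean up the illustration's pool
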
--- python/pyspark/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 4439356c1f..e486f206b0 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -52,7 +52,7 @@ class SparkContext(object): self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome, empty_string_array) - self.pythonExec = os.environ.get("PYSPARK_PYTHON_EXEC", 'python') + self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') # Broadcast's __reduce__ method stores Broadcast instances here. # This allows other code to determine which Broadcast instances have # been pickled, so it can determine which Java broadcast objects to From b15e8512793475eaeda7225a259db8aacd600741 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Thu, 10 Jan 2013 10:55:41 -0600 Subject: [PATCH 080/291] Check for AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY environment variables. For custom properties, use "spark.hadoop.*" as a prefix instead of just "hadoop.*". --- core/src/main/scala/spark/SparkContext.scala | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 36e0938854..7b11955f1e 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -190,9 +190,16 @@ class SparkContext( /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration = { val conf = new Configuration() - // Copy any "hadoop.foo=bar" system properties into conf as "foo=bar" - for (key <- System.getProperties.keys.asInstanceOf[Set[String]] if key.startsWith("hadoop.")) { - conf.set(key.substring("hadoop.".length), System.getProperty(key)) + // Explicitly check for S3 environment variables + if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) { + conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) + conf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) + conf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) + conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) + } + // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" + for (key <- System.getProperties.keys.asInstanceOf[Set[String]] if key.startsWith("spark.hadoop.")) { + conf.set(key.substring("spark.hadoop.".length), System.getProperty(key)) } val bufferSize = System.getProperty("spark.buffer.size", "65536") conf.set("io.file.buffer.size", bufferSize) From d1864052c58ff1e58980729f7ccf00e630f815b9 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Thu, 10 Jan 2013 12:16:26 -0600 Subject: [PATCH 081/291] Fix invalid asInstanceOf cast. 
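Patch 080 above forwards AWS credentials and any spark.hadoop.*-prefixed system property into the shared Hadoop configuration, and patch 081, whose diff follows, repairs the property iteration: java.util.Properties cannot simply be cast to a Scala Set, so the loop goes through JavaConversions and toMap instead. The standalone sketch below exercises the corrected iteration with a hypothetical property value.

    import scala.collection.JavaConversions._

    System.setProperty("spark.hadoop.fs.s3n.awsAccessKeyId", "EXAMPLE_KEY_ID")  // hypothetical value

    // The corrected iteration from patch 081: convert Properties to a Scala map, then filter.
    val forwarded = System.getProperties.toMap[String, String].keys
      .filter(_.startsWith("spark.hadoop."))
      .map(key => key.substring("spark.hadoop.".length) -> System.getProperty(key))
      .toMap

    println(forwarded)  // includes fs.s3n.awsAccessKeyId -> EXAMPLE_KEY_ID
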
--- core/src/main/scala/spark/SparkContext.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 7b11955f1e..d2a5b4757a 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -7,6 +7,7 @@ import java.net.{URI, URLClassLoader} import scala.collection.Map import scala.collection.generic.Growable import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.JavaConversions._ import akka.actor.Actor import akka.actor.Actor._ @@ -198,7 +199,7 @@ class SparkContext( conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) } // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" - for (key <- System.getProperties.keys.asInstanceOf[Set[String]] if key.startsWith("spark.hadoop.")) { + for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) { conf.set(key.substring("spark.hadoop.".length), System.getProperty(key)) } val bufferSize = System.getProperty("spark.buffer.size", "65536") From bd336f5f406386c929f2d1f9aecd7d5190a1a087 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 10 Jan 2013 17:13:04 -0800 Subject: [PATCH 082/291] Changed CoGroupRDD's hash map from Scala to Java. --- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index de0d9fad88..2e051c81c8 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -1,7 +1,8 @@ package spark.rdd +import java.util.{HashMap => JHashMap} import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap +import scala.collection.JavaConversions._ import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext} import spark.{Dependency, OneToOneDependency, ShuffleDependency} @@ -71,7 +72,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = { val split = s.asInstanceOf[CoGroupSplit] val numRdds = split.deps.size - val map = new HashMap[K, Seq[ArrayBuffer[Any]]] + val map = new JHashMap[K, Seq[ArrayBuffer[Any]]] def getSeq(k: K): Seq[ArrayBuffer[Any]] = { map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any])) } From 2e914d99835487e867cac6add8be1dbd80dc693f Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 10 Jan 2013 19:13:08 -0800 Subject: [PATCH 083/291] Formatting --- core/src/main/scala/spark/deploy/master/MasterWebUI.scala | 5 +++-- core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index a96b55d6f3..580014ef3f 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -8,11 +8,12 @@ import akka.util.duration._ import cc.spray.Directives import cc.spray.directives._ import cc.spray.typeconversion.TwirlSupport._ -import spark.deploy._ import cc.spray.http.MediaTypes -import JsonProtocol._ import cc.spray.typeconversion.SprayJsonSupport._ +import spark.deploy._ +import spark.deploy.JsonProtocol._ + private[spark] class 
MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives { val RESOURCE_DIR = "spark/deploy/master/webui" diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala index 84b6c16bd6..f9489d99fc 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala @@ -7,11 +7,12 @@ import akka.util.Timeout import akka.util.duration._ import cc.spray.Directives import cc.spray.typeconversion.TwirlSupport._ -import spark.deploy.{JsonProtocol, WorkerState, RequestWorkerState} import cc.spray.http.MediaTypes -import JsonProtocol._ import cc.spray.typeconversion.SprayJsonSupport._ +import spark.deploy.{WorkerState, RequestWorkerState} +import spark.deploy.JsonProtocol._ + private[spark] class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives { val RESOURCE_DIR = "spark/deploy/worker/webui" From 92625223066a5c28553d7710c6b14af56f64b560 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Thu, 10 Jan 2013 22:07:34 -0800 Subject: [PATCH 084/291] Activate hadoop2 profile in pom.xml with -Dhadoop=2 --- bagel/pom.xml | 6 ++++++ core/pom.xml | 6 ++++++ examples/pom.xml | 6 ++++++ pom.xml | 6 ++++++ repl-bin/pom.xml | 6 ++++++ repl/pom.xml | 6 ++++++ 6 files changed, 36 insertions(+) diff --git a/bagel/pom.xml b/bagel/pom.xml index 85b2077026..c3461fb889 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -77,6 +77,12 @@ hadoop2 + + + hadoop + 2 + + org.spark-project diff --git a/core/pom.xml b/core/pom.xml index 005d8fe498..c8ff625774 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -216,6 +216,12 @@ hadoop2 + + + hadoop + 2 + + org.apache.hadoop diff --git a/examples/pom.xml b/examples/pom.xml index 3f738a3f8c..d0b1e97747 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -77,6 +77,12 @@ hadoop2 + + + hadoop + 2 + + org.spark-project diff --git a/pom.xml b/pom.xml index ea5b9c9d05..ae87813d4e 100644 --- a/pom.xml +++ b/pom.xml @@ -502,6 +502,12 @@ hadoop2 + + + hadoop + 2 + + 2 diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index fecb01f3cd..54ae20659e 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -115,6 +115,12 @@ hadoop2 + + + hadoop + 2 + + hadoop2 diff --git a/repl/pom.xml b/repl/pom.xml index 04b2c35beb..3e979b93a6 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -121,6 +121,12 @@ hadoop2 + + + hadoop + 2 + + hadoop2 From 3e6519a36e354f3623c5b968efe5217c7fcb242f Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Fri, 11 Jan 2013 11:24:20 -0600 Subject: [PATCH 085/291] Use hadoopConfiguration for default JobConf in PairRDDFunctions. 
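The PairRDDFunctions patch that follows makes the default JobConf wrap the SparkContext's hadoopConfiguration, so settings such as the S3 credentials configured earlier reach Hadoop output jobs without being passed explicitly. A hedged standalone illustration that a JobConf built from an existing Configuration inherits its entries (assumes the Hadoop client libraries are on the classpath; not Spark code):

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.JobConf

object JobConfInherits {
  def main(args: Array[String]) {
    val base = new Configuration()
    base.set("fs.s3n.awsAccessKeyId", "MY_KEY_ID")  // hypothetical value
    val conf = new JobConf(base)                    // copies base's entries
    println(conf.get("fs.s3n.awsAccessKeyId"))      // prints MY_KEY_ID
  }
}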
--- core/src/main/scala/spark/PairRDDFunctions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index ce48cea903..51c15837c4 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -557,7 +557,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: OutputFormat[_, _]], - conf: JobConf = new JobConf) { + conf: JobConf = new JobConf(self.context.hadoopConfiguration)) { conf.setOutputKeyClass(keyClass) conf.setOutputValueClass(valueClass) // conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug From 5c7a1272198c88a90a843bbda0c1424f92b7c12e Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Fri, 11 Jan 2013 11:25:11 -0600 Subject: [PATCH 086/291] Pass a new Configuration that wraps the default hadoopConfiguration. --- core/src/main/scala/spark/SparkContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index d2a5b4757a..f6b98c41bc 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -294,7 +294,7 @@ class SparkContext( fm.erasure.asInstanceOf[Class[F]], km.erasure.asInstanceOf[Class[K]], vm.erasure.asInstanceOf[Class[V]], - hadoopConfiguration) + new Configuration(hadoopConfiguration)) } /** From 480c4139bbd2711e99f3a819c9ef164d8b3dcac0 Mon Sep 17 00:00:00 2001 From: Michael Heuer Date: Fri, 11 Jan 2013 11:24:48 -0600 Subject: [PATCH 087/291] add repositories section to simple job pom.xml --- docs/quick-start.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/quick-start.md b/docs/quick-start.md index 177cb14551..d46dc2da3f 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -200,6 +200,16 @@ To build the job, we also write a Maven `pom.xml` file that lists Spark as a dep Simple Project jar 1.0 + + + Spray.cc repository + http://repo.spray.cc + + + Typesafe repository + http://repo.typesafe.com/typesafe/releases + + org.spark-project From c063e8777ebaeb04056889064e9264edc019edbd Mon Sep 17 00:00:00 2001 From: Tyson Date: Fri, 11 Jan 2013 14:57:38 -0500 Subject: [PATCH 088/291] Added implicit json writers for JobDescription and ExecutorRunner --- .../scala/spark/deploy/JsonProtocol.scala | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala index f14f804b3a..732fa08064 100644 --- a/core/src/main/scala/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala @@ -1,6 +1,7 @@ package spark.deploy import master.{JobInfo, WorkerInfo} +import worker.ExecutorRunner import cc.spray.json._ /** @@ -30,6 +31,24 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol { "submitdate" -> JsString(obj.submitDate.toString)) } + implicit object JobDescriptionJsonFormat extends RootJsonWriter[JobDescription] { + def write(obj: JobDescription) = JsObject( + "name" -> JsString(obj.name), + "cores" -> JsNumber(obj.cores), + "memoryperslave" -> JsNumber(obj.memoryPerSlave), + "user" -> JsString(obj.user) + ) + } + + implicit object ExecutorRunnerJsonFormat extends RootJsonWriter[ExecutorRunner] { + def write(obj: 
ExecutorRunner) = JsObject( + "id" -> JsNumber(obj.execId), + "memory" -> JsNumber(obj.memory), + "jobid" -> JsString(obj.jobId), + "jobdesc" -> obj.jobDesc.toJson.asJsObject + ) + } + implicit object MasterStateJsonFormat extends RootJsonWriter[MasterState] { def write(obj: MasterState) = JsObject( "url" -> JsString("spark://" + obj.uri), @@ -51,7 +70,9 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol { "cores" -> JsNumber(obj.cores), "coresused" -> JsNumber(obj.coresUsed), "memory" -> JsNumber(obj.memory), - "memoryused" -> JsNumber(obj.memoryUsed) + "memoryused" -> JsNumber(obj.memoryUsed), + "executors" -> JsArray(obj.executors.toList.map(_.toJson)), + "finishedexecutors" -> JsArray(obj.finishedExecutors.toList.map(_.toJson)) ) } } From 1731f1fed4f1369662b1a9fde850a3dcba738a59 Mon Sep 17 00:00:00 2001 From: Tyson Date: Fri, 11 Jan 2013 15:01:43 -0500 Subject: [PATCH 089/291] Added an optional format parameter for individual job queries and optimized the jobId query --- .../spark/deploy/master/MasterWebUI.scala | 38 +++++++++++++------ 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index 580014ef3f..458ee2d665 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -38,20 +38,36 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct } } ~ path("job") { - parameter("jobId") { jobId => - completeWith { + parameters("jobId", 'format ?) { + case (jobId, Some(js)) if (js.equalsIgnoreCase("json")) => val future = master ? RequestMasterState - future.map { state => - val masterState = state.asInstanceOf[MasterState] - - // A bit ugly an inefficient, but we won't have a number of jobs - // so large that it will make a significant difference. - (masterState.activeJobs ++ masterState.completedJobs).find(_.id == jobId) match { - case Some(job) => spark.deploy.master.html.job_details.render(job) - case _ => null + val jobInfo = for (masterState <- future.mapTo[MasterState]) yield { + masterState.activeJobs.find(_.id == jobId) match { + case Some(job) => job + case _ => masterState.completedJobs.find(_.id == jobId) match { + case Some(job) => job + case _ => null + } + } + } + respondWithMediaType(MediaTypes.`application/json`) { ctx => + ctx.complete(jobInfo.mapTo[JobInfo]) + } + case (jobId, _) => + completeWith { + val future = master ? 
RequestMasterState + future.map { state => + val masterState = state.asInstanceOf[MasterState] + + masterState.activeJobs.find(_.id == jobId) match { + case Some(job) => spark.deploy.master.html.job_details.render(job) + case _ => masterState.completedJobs.find(_.id == jobId) match { + case Some(job) => spark.deploy.master.html.job_details.render(job) + case _ => null + } + } } } - } } } ~ pathPrefix("static") { From bbc56d85ed4eb4c3a09b20d5457f704f4b8a70c4 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sat, 12 Jan 2013 15:24:13 -0800 Subject: [PATCH 090/291] Rename environment variable for hadoop profiles to hadoopVersion --- bagel/pom.xml | 4 ++-- core/pom.xml | 4 ++-- examples/pom.xml | 4 ++-- pom.xml | 5 +++-- repl-bin/pom.xml | 4 ++-- repl/pom.xml | 4 ++-- 6 files changed, 13 insertions(+), 12 deletions(-) diff --git a/bagel/pom.xml b/bagel/pom.xml index c3461fb889..5f58347204 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -47,7 +47,7 @@ hadoop1 - !hadoop + !hadoopVersion @@ -79,7 +79,7 @@ hadoop2 - hadoop + hadoopVersion 2 diff --git a/core/pom.xml b/core/pom.xml index c8ff625774..ad9fdcde2c 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -161,7 +161,7 @@ hadoop1 - !hadoop + !hadoopVersion @@ -218,7 +218,7 @@ hadoop2 - hadoop + hadoopVersion 2 diff --git a/examples/pom.xml b/examples/pom.xml index d0b1e97747..3355deb6b7 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -47,7 +47,7 @@ hadoop1 - !hadoop + !hadoopVersion @@ -79,7 +79,7 @@ hadoop2 - hadoop + hadoopVersion 2 diff --git a/pom.xml b/pom.xml index ae87813d4e..8f1af673a3 100644 --- a/pom.xml +++ b/pom.xml @@ -483,9 +483,10 @@ hadoop1 - !hadoop + !hadoopVersion + 1 @@ -504,7 +505,7 @@ hadoop2 - hadoop + hadoopVersion 2 diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index 54ae20659e..da91c0f3ab 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -72,7 +72,7 @@ hadoop1 - !hadoop + !hadoopVersion @@ -117,7 +117,7 @@ hadoop2 - hadoop + hadoopVersion 2 diff --git a/repl/pom.xml b/repl/pom.xml index 3e979b93a6..38e883c7f8 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -74,7 +74,7 @@ hadoop1 - !hadoop + !hadoopVersion @@ -123,7 +123,7 @@ hadoop2 - hadoop + hadoopVersion 2 From ba06e9c97cc3f8723ffdc3895182c529d3bb2fb3 Mon Sep 17 00:00:00 2001 From: Eric Zhang Date: Sun, 13 Jan 2013 15:33:11 +0800 Subject: [PATCH 091/291] Update examples/src/main/scala/spark/examples/LocalLR.scala MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix spelling mistake --- examples/src/main/scala/spark/examples/LocalLR.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/scala/spark/examples/LocalLR.scala b/examples/src/main/scala/spark/examples/LocalLR.scala index f2ac2b3e06..9553162004 100644 --- a/examples/src/main/scala/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/spark/examples/LocalLR.scala @@ -5,7 +5,7 @@ import spark.util.Vector object LocalLR { val N = 10000 // Number of data points - val D = 10 // Numer of dimensions + val D = 10 // Number of dimensions val R = 0.7 // Scaling factor val ITERATIONS = 5 val rand = new Random(42) From 88d8f11365db84d46ff456495c07f664c91d1896 Mon Sep 17 00:00:00 2001 From: Mikhail Bautin Date: Sun, 13 Jan 2013 00:45:52 -0800 Subject: [PATCH 092/291] Add missing dependency spray-json to Maven build --- core/pom.xml | 4 ++++ pom.xml | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/core/pom.xml b/core/pom.xml index ad9fdcde2c..862d3ec37a 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -71,6 
+71,10 @@ cc.spray spray-server + + cc.spray + spray-json_${scala.version} + org.tomdz.twirl twirl-api diff --git a/pom.xml b/pom.xml index 8f1af673a3..751189a9d8 100644 --- a/pom.xml +++ b/pom.xml @@ -54,6 +54,7 @@ 0.9.0-incubating 2.0.3 1.0-M2.1 + 1.1.1 1.6.1 4.1.2 @@ -222,6 +223,11 @@ spray-server ${spray.version} + + cc.spray + spray-json_${scala.version} + ${spray.json.version} + org.tomdz.twirl twirl-api From be7166146bf5692369272b85622d5316eccfd8e6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 13 Jan 2013 15:27:28 -0800 Subject: [PATCH 093/291] Removed the use of getOrElse to avoid Scala wrapper for every call. --- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 2e051c81c8..ce5f171911 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -1,8 +1,8 @@ package spark.rdd import java.util.{HashMap => JHashMap} +import scala.collection.JavaConversions import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext} import spark.{Dependency, OneToOneDependency, ShuffleDependency} @@ -74,7 +74,14 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) val numRdds = split.deps.size val map = new JHashMap[K, Seq[ArrayBuffer[Any]]] def getSeq(k: K): Seq[ArrayBuffer[Any]] = { - map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any])) + val seq = map.get(k) + if (seq != null) { + seq + } else { + val seq = Array.fill(numRdds)(new ArrayBuffer[Any]) + map.put(k, seq) + seq + } } for ((dep, depNum) <- split.deps.zipWithIndex) dep match { case NarrowCoGroupSplitDep(rdd, itsSplit) => { @@ -94,6 +101,6 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) fetcher.fetch[K, Seq[Any]](shuffleId, split.index).foreach(mergePair) } } - map.iterator + JavaConversions.mapAsScalaMap(map).iterator } } From 72408e8dfacc24652f376d1ee4dd6f04edb54804 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 13 Jan 2013 19:34:07 -0800 Subject: [PATCH 094/291] Make filter preserve partitioner info, since it can --- core/src/main/scala/spark/rdd/FilteredRDD.scala | 3 ++- core/src/test/scala/spark/PartitioningSuite.scala | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala index b148da28de..d46549b8b6 100644 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala @@ -7,5 +7,6 @@ private[spark] class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) { override def splits = prev.splits override val dependencies = List(new OneToOneDependency(prev)) + override val partitioner = prev.partitioner // Since filter cannot change a partition's keys override def compute(split: Split, context: TaskContext) = prev.iterator(split, context).filter(f) -} \ No newline at end of file +} diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala index f09b602a7b..eb3c8f238f 100644 --- a/core/src/test/scala/spark/PartitioningSuite.scala +++ b/core/src/test/scala/spark/PartitioningSuite.scala @@ -106,6 +106,11 @@ class PartitioningSuite extends FunSuite with 
BeforeAndAfter { assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner) assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner) assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner) + + assert(grouped2.map(_ => 1).partitioner === None) + assert(grouped2.mapValues(_ => 1).partitioner === grouped2.partitioner) + assert(grouped2.flatMapValues(_ => Seq(1)).partitioner === grouped2.partitioner) + assert(grouped2.filter(_._1 > 4).partitioner === grouped2.partitioner) } test("partitioning Java arrays should fail") { From 273fb5cc109ac0a032f84c1566ae908cd0eb27b6 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Thu, 3 Jan 2013 14:09:56 -0800 Subject: [PATCH 095/291] Throw FetchFailedException for cached missing locs --- .../main/scala/spark/MapOutputTracker.scala | 36 +++++++++++++------ 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 70eb9f702e..9f2aa76830 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -139,8 +139,8 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea case e: InterruptedException => } } - return mapStatuses.get(shuffleId).map(status => - (status.address, MapOutputTracker.decompressSize(status.compressedSizes(reduceId)))) + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, + mapStatuses.get(shuffleId)) } else { fetching += shuffleId } @@ -156,21 +156,15 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea fetchedStatuses = deserializeStatuses(fetchedBytes) logInfo("Got the output locations") mapStatuses.put(shuffleId, fetchedStatuses) - if (fetchedStatuses.contains(null)) { - throw new FetchFailedException(null, shuffleId, -1, reduceId, - new Exception("Missing an output location for shuffle " + shuffleId)) - } } finally { fetching.synchronized { fetching -= shuffleId fetching.notifyAll() } } - return fetchedStatuses.map(s => - (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId)))) + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) } else { - return statuses.map(s => - (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId)))) + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) } } @@ -258,6 +252,28 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea private[spark] object MapOutputTracker { private val LOG_BASE = 1.1 + // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If + // any of the statuses is null (indicating a missing location due to a failed mapper), + // throw a FetchFailedException. + def convertMapStatuses( + shuffleId: Int, + reduceId: Int, + statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { + if (statuses == null) { + throw new FetchFailedException(null, shuffleId, -1, reduceId, + new Exception("Missing all output locations for shuffle " + shuffleId)) + } + statuses.map { + status => + if (status == null) { + throw new FetchFailedException(null, shuffleId, -1, reduceId, + new Exception("Missing an output location for shuffle " + shuffleId)) + } else { + (status.address, decompressSize(status.compressedSizes(reduceId))) + } + } + } + /** * Compress a size in bytes to 8 bits for efficient reporting of map output sizes. 
* We do this by encoding the log base 1.1 of the size as an integer, which can support From 7ba34bc007ec10d12b2a871749f32232cdbc0d9c Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Mon, 14 Jan 2013 15:24:08 -0800 Subject: [PATCH 096/291] Additional tests for MapOutputTracker. --- .../scala/spark/MapOutputTrackerSuite.scala | 82 ++++++++++++++++++- 1 file changed, 80 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index 5b4b198960..6c6f82e274 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -1,12 +1,18 @@ package spark import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter import akka.actor._ import spark.scheduler.MapStatus import spark.storage.BlockManagerId +import spark.util.AkkaUtils -class MapOutputTrackerSuite extends FunSuite { +class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { + after { + System.clearProperty("spark.master.port") + } + test("compressSize") { assert(MapOutputTracker.compressSize(0L) === 0) assert(MapOutputTracker.compressSize(1L) === 1) @@ -71,6 +77,78 @@ class MapOutputTrackerSuite extends FunSuite { // The remaining reduce task might try to grab the output dispite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the // stage already being aborted. - intercept[Exception] { tracker.getServerStatuses(10, 1) } + intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } + } + + test("remote fetch") { + val (actorSystem, boundPort) = + AkkaUtils.createActorSystem("test", "localhost", 0) + System.setProperty("spark.master.port", boundPort.toString) + val masterTracker = new MapOutputTracker(actorSystem, true) + val slaveTracker = new MapOutputTracker(actorSystem, false) + masterTracker.registerShuffle(10, 1) + masterTracker.incrementGeneration() + slaveTracker.updateGeneration(masterTracker.getGeneration) + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + masterTracker.registerMapOutput(10, 0, new MapStatus( + new BlockManagerId("hostA", 1000), Array(compressedSize1000))) + masterTracker.incrementGeneration() + slaveTracker.updateGeneration(masterTracker.getGeneration) + assert(slaveTracker.getServerStatuses(10, 0).toSeq === + Seq((new BlockManagerId("hostA", 1000), size1000))) + + masterTracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000)) + masterTracker.incrementGeneration() + slaveTracker.updateGeneration(masterTracker.getGeneration) + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + } + + test("simulatenous fetch fails") { + val dummyActorSystem = ActorSystem("testDummy") + val dummyTracker = new MapOutputTracker(dummyActorSystem, true) + dummyTracker.registerShuffle(10, 1) + // val compressedSize1000 = MapOutputTracker.compressSize(1000L) + // val size100 = MapOutputTracker.decompressSize(compressedSize1000) + // dummyTracker.registerMapOutput(10, 0, new MapStatus( + // new BlockManagerId("hostA", 1000), Array(compressedSize1000))) + val serializedMessage = dummyTracker.getSerializedLocations(10) + + val (actorSystem, boundPort) = + AkkaUtils.createActorSystem("test", "localhost", 0) + System.setProperty("spark.master.port", boundPort.toString) + val delayResponseLock = new 
java.lang.Object + val delayResponseActor = actorSystem.actorOf(Props(new Actor { + override def receive = { + case GetMapOutputStatuses(shuffleId: Int, requester: String) => + delayResponseLock.synchronized { + sender ! serializedMessage + } + } + }), name = "MapOutputTracker") + val slaveTracker = new MapOutputTracker(actorSystem, false) + var firstFailed = false + var secondFailed = false + val firstFetch = new Thread { + override def run() { + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + firstFailed = true + } + } + val secondFetch = new Thread { + override def run() { + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + secondFailed = true + } + } + delayResponseLock.synchronized { + firstFetch.start + secondFetch.start + } + firstFetch.join + secondFetch.join + assert(firstFailed && secondFailed) } } From b61a4ec77300d6e7fb40f771a9054ae8bc4488de Mon Sep 17 00:00:00 2001 From: seanm Date: Mon, 14 Jan 2013 17:13:10 -0700 Subject: [PATCH 097/291] Removing offset management code that is non-existent in kafka 0.7.0+ --- .../scala/spark/streaming/dstream/KafkaInputDStream.scala | 7 ------- 1 file changed, 7 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index 2b4740bdf7..9605072382 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -173,13 +173,6 @@ class KafkaReceiver(host: String, port: Int, groupId: String, stream.takeWhile { msgAndMetadata => blockGenerator += msgAndMetadata.message - // Updating the offet. The key is (broker, topic, group, partition). 
- val key = KafkaPartitionKey(msgAndMetadata.topicInfo.brokerId, msgAndMetadata.topic, - groupId, msgAndMetadata.topicInfo.partition.partId) - val offset = msgAndMetadata.topicInfo.getConsumeOffset - offsets.put(key, offset) - // logInfo("Handled message: " + (key, offset).toString) - // Keep on handling messages true } From c203a292963a018bd9b84f02bb522fd191a110af Mon Sep 17 00:00:00 2001 From: seanm Date: Mon, 14 Jan 2013 17:22:03 -0700 Subject: [PATCH 098/291] StateDStream changes to give updateStateByKey consistent behavior --- .../scala/spark/streaming/dstream/StateDStream.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala index a1ec2f5454..4e57968eed 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala @@ -48,8 +48,16 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife //logDebug("Generating state RDD for time " + validTime) return Some(stateRDD) } - case None => { // If parent RDD does not exist, then return old state RDD - return Some(prevStateRDD) + case None => { // If parent RDD does not exist + + // Re-apply the update function to the old state RDD + val updateFuncLocal = updateFunc + val finalFunc = (iterator: Iterator[(K, S)]) => { + val i = iterator.map(t => (t._1, Seq[V](), Option(t._2))) + updateFuncLocal(i) + } + val stateRDD = prevStateRDD.mapPartitions(finalFunc, preservePartitioning) + return Some(stateRDD) } } } From b0389997972d383c3aaa87924b725dee70b18d8e Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Mon, 14 Jan 2013 17:04:44 -0800 Subject: [PATCH 099/291] Fix accidental spark.master.host reuse --- core/src/test/scala/spark/MapOutputTrackerSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index 6c6f82e274..aa1d8ac7e6 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -81,6 +81,7 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { } test("remote fetch") { + System.clearProperty("spark.master.host") val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0) System.setProperty("spark.master.port", boundPort.toString) @@ -107,6 +108,7 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { } test("simulatenous fetch fails") { + System.clearProperty("spark.master.host") val dummyActorSystem = ActorSystem("testDummy") val dummyTracker = new MapOutputTracker(dummyActorSystem, true) dummyTracker.registerShuffle(10, 1) From b77f7390a5a18c2b88fbc0c276c4dbc938560127 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Tue, 15 Jan 2013 09:04:32 +0200 Subject: [PATCH 100/291] Python ALS example --- python/examples/als.py | 71 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100755 python/examples/als.py diff --git a/python/examples/als.py b/python/examples/als.py new file mode 100755 index 0000000000..284cf0d3a2 --- /dev/null +++ b/python/examples/als.py @@ -0,0 +1,71 @@ +""" +This example requires numpy (http://www.numpy.org/) +""" +from os.path import realpath +import sys + +import numpy as np +from numpy.random import rand +from numpy import matrix +from pyspark import 
SparkContext + +LAMBDA = 0.01 # regularization +np.random.seed(42) + +def rmse(R, ms, us): + diff = R - ms * us.T + return np.sqrt(np.sum(np.power(diff, 2)) / M * U) + +def update(i, vec, mat, ratings): + uu = mat.shape[0] + ff = mat.shape[1] + XtX = matrix(np.zeros((ff, ff))) + Xty = np.zeros((ff, 1)) + + for j in range(uu): + v = mat[j, :] + XtX += v.T * v + Xty += v.T * ratings[i, j] + XtX += np.eye(ff, ff) * LAMBDA * uu + return np.linalg.solve(XtX, Xty) + +if __name__ == "__main__": + if len(sys.argv) < 2: + print >> sys.stderr, \ + "Usage: PythonALS " + exit(-1) + sc = SparkContext(sys.argv[1], "PythonALS", pyFiles=[realpath(__file__)]) + M = int(sys.argv[2]) if len(sys.argv) > 2 else 100 + U = int(sys.argv[3]) if len(sys.argv) > 3 else 500 + F = int(sys.argv[4]) if len(sys.argv) > 4 else 10 + ITERATIONS = int(sys.argv[5]) if len(sys.argv) > 5 else 5 + slices = int(sys.argv[6]) if len(sys.argv) > 6 else 2 + + print "Running ALS with M=%d, U=%d, F=%d, iters=%d, slices=%d\n" % \ + (M, U, F, ITERATIONS, slices) + + R = matrix(rand(M, F)) * matrix(rand(U, F).T) + ms = matrix(rand(M ,F)) + us = matrix(rand(U, F)) + + Rb = sc.broadcast(R) + msb = sc.broadcast(ms) + usb = sc.broadcast(us) + + for i in range(ITERATIONS): + ms = sc.parallelize(range(M), slices) \ + .map(lambda x: update(x, msb.value[x, :], usb.value, Rb.value)) \ + .collect() + ms = matrix(np.array(ms)[:, :, 0]) # collect() returns a list, so array ends up being + # a 3-d array, we take the first 2 dims for the matrix + msb = sc.broadcast(ms) + + us = sc.parallelize(range(U), slices) \ + .map(lambda x: update(x, usb.value[x, :], msb.value, Rb.value.T)) \ + .collect() + us = matrix(np.array(us)[:, :, 0]) + usb = sc.broadcast(us) + + error = rmse(R, ms, us) + print "Iteration %d:" % i + print "\nRMSE: %5.4f\n" % error \ No newline at end of file From dd583b7ebf0e6620ec8e35424b59db451febe3e8 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 15 Jan 2013 10:52:06 -0600 Subject: [PATCH 101/291] Call executeOnCompleteCallbacks in a finally block. --- core/src/main/scala/spark/scheduler/ResultTask.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala index e492279b4e..2aad7956b4 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -15,9 +15,11 @@ private[spark] class ResultTask[T, U]( override def run(attemptId: Long): U = { val context = new TaskContext(stageId, partition, attemptId) - val result = func(context, rdd.iterator(split, context)) - context.executeOnCompleteCallbacks() - result + try { + func(context, rdd.iterator(split, context)) + } finally { + context.executeOnCompleteCallbacks() + } } override def preferredLocations: Seq[String] = locs From d228bff440395e8e6b8d67483467dde65b08ab40 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 15 Jan 2013 11:48:50 -0600 Subject: [PATCH 102/291] Add a test. 
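The ResultTask change above runs the task body inside try and moves executeOnCompleteCallbacks into a finally block, so completion callbacks fire even when the body throws; the test added in the next patch exercises exactly this path. A minimal sketch of why the finally block matters, with hypothetical names rather than Spark's classes:

object CallbackOnFailure {
  def runTask(body: => String, onComplete: () => Unit): String = {
    try {
      body
    } finally {
      onComplete()  // runs whether body returns normally or throws
    }
  }

  def main(args: Array[String]) {
    var completed = false
    try {
      runTask(sys.error("failed"), () => completed = true)
    } catch {
      case e: RuntimeException => ()
    }
    println("completed = " + completed)  // prints completed = true
  }
}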
--- .../spark/scheduler/TaskContextSuite.scala | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 core/src/test/scala/spark/scheduler/TaskContextSuite.scala diff --git a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala new file mode 100644 index 0000000000..f937877340 --- /dev/null +++ b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala @@ -0,0 +1,43 @@ +package spark.scheduler + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter +import spark.TaskContext +import spark.RDD +import spark.SparkContext +import spark.Split + +class TaskContextSuite extends FunSuite with BeforeAndAfter { + + var sc: SparkContext = _ + + after { + if (sc != null) { + sc.stop() + sc = null + } + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") + } + + test("Calls executeOnCompleteCallbacks after failure") { + var completed = false + sc = new SparkContext("local", "test") + val rdd = new RDD[String](sc) { + override val splits = Array[Split](StubSplit(0)) + override val dependencies = List() + override def compute(split: Split, context: TaskContext) = { + context.addOnCompleteCallback(() => completed = true) + sys.error("failed") + } + } + val func = (c: TaskContext, i: Iterator[String]) => i.next + val task = new ResultTask[String, String](0, rdd, func, 0, Seq(), 0) + intercept[RuntimeException] { + task.run(0) + } + assert(completed === true) + } + + case class StubSplit(val index: Int) extends Split +} \ No newline at end of file From c7143e5507f1d5292e678315158d3863c9bb4242 Mon Sep 17 00:00:00 2001 From: Andrew Psaltis Date: Tue, 15 Jan 2013 12:45:42 -0700 Subject: [PATCH 103/291] Changed teh scala version to 2.9.2, so that the classes can be found when the classpath is expanded. --- run2.cmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run2.cmd b/run2.cmd index 83464b1166..67f1e465e4 100644 --- a/run2.cmd +++ b/run2.cmd @@ -1,6 +1,6 @@ @echo off -set SCALA_VERSION=2.9.1 +set SCALA_VERSION=2.9.2 rem Figure out where the Spark framework is installed set FWDIR=%~dp0 From 74d3b23929758328c2a7879381669d81bf899396 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 15 Jan 2013 14:03:28 -0600 Subject: [PATCH 104/291] Add spark.executor.memory to differentiate executor memory from spark-shell memory. 
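The patch below gives the spark.executor.memory system property precedence over the SPARK_MEM environment variable, falling back to 512 MB when neither is set. A rough standalone sketch of that precedence chain; unlike Spark's Utils.memoryStringToMb it only accepts a plain megabyte count, an assumption made to keep the example self-contained:

object ExecutorMemory {
  def main(args: Array[String]) {
    System.setProperty("spark.executor.memory", "1024")  // hypothetical setting
    val executorMemory = Option(System.getProperty("spark.executor.memory"))
      .orElse(Option(System.getenv("SPARK_MEM")))
      .map(_.toInt)
      .getOrElse(512)
    println(executorMemory + " MB per executor")  // prints 1024 MB per executor
  }
}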
--- core/src/main/scala/spark/SparkContext.scala | 4 ++-- .../scala/spark/deploy/worker/ExecutorRunner.scala | 3 +-- .../cluster/SparkDeploySchedulerBackend.scala | 11 +++++------ 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index bbf8272eb3..a5a1b75944 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -111,8 +111,8 @@ class SparkContext( // Environment variables to pass to our executors private[spark] val executorEnvs = HashMap[String, String]() - for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", - "SPARK_TESTING")) { + // Note: SPARK_MEM isn't included because it's set directly in ExecutorRunner + for (key <- Seq("SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", "SPARK_TESTING")) { val value = System.getenv(key) if (value != null) { executorEnvs(key) = value diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index beceb55ecd..2f2ea617ff 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -118,8 +118,7 @@ private[spark] class ExecutorRunner( for ((key, value) <- jobDesc.command.environment) { env.put(key, value) } - env.put("SPARK_CORES", cores.toString) - env.put("SPARK_MEMORY", memory.toString) + env.put("SPARK_MEM", memory.toString) // In case we are running this from within the Spark Shell, avoid creating a "scala" // parent process for the executor command env.put("SPARK_LAUNCH_WITH_SCALA", "0") diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index e2301347e5..f2fb244b24 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -23,12 +23,11 @@ private[spark] class SparkDeploySchedulerBackend( // Memory used by each executor (in megabytes) val executorMemory = { - if (System.getenv("SPARK_MEM") != null) { - Utils.memoryStringToMb(System.getenv("SPARK_MEM")) - // TODO: Might need to add some extra memory for the non-heap parts of the JVM - } else { - 512 - } + // TODO: Might need to add some extra memory for the non-heap parts of the JVM + Option(System.getProperty("spark.executor.memory")) + .orElse(Option(System.getenv("SPARK_MEM"))) + .map(Utils.memoryStringToMb) + .getOrElse(512) } override def start() { From 4078623b9f2a338d4992c3dfd3af3a5550615180 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 15 Jan 2013 12:05:54 -0800 Subject: [PATCH 105/291] Remove broken attempt to test fetching case. 
--- .../scala/spark/MapOutputTrackerSuite.scala | 48 +------------------ 1 file changed, 2 insertions(+), 46 deletions(-) diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index aa1d8ac7e6..d3dd3a8fa4 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -105,52 +105,8 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { masterTracker.incrementGeneration() slaveTracker.updateGeneration(masterTracker.getGeneration) intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - } - test("simulatenous fetch fails") { - System.clearProperty("spark.master.host") - val dummyActorSystem = ActorSystem("testDummy") - val dummyTracker = new MapOutputTracker(dummyActorSystem, true) - dummyTracker.registerShuffle(10, 1) - // val compressedSize1000 = MapOutputTracker.compressSize(1000L) - // val size100 = MapOutputTracker.decompressSize(compressedSize1000) - // dummyTracker.registerMapOutput(10, 0, new MapStatus( - // new BlockManagerId("hostA", 1000), Array(compressedSize1000))) - val serializedMessage = dummyTracker.getSerializedLocations(10) - - val (actorSystem, boundPort) = - AkkaUtils.createActorSystem("test", "localhost", 0) - System.setProperty("spark.master.port", boundPort.toString) - val delayResponseLock = new java.lang.Object - val delayResponseActor = actorSystem.actorOf(Props(new Actor { - override def receive = { - case GetMapOutputStatuses(shuffleId: Int, requester: String) => - delayResponseLock.synchronized { - sender ! serializedMessage - } - } - }), name = "MapOutputTracker") - val slaveTracker = new MapOutputTracker(actorSystem, false) - var firstFailed = false - var secondFailed = false - val firstFetch = new Thread { - override def run() { - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - firstFailed = true - } - } - val secondFetch = new Thread { - override def run() { - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - secondFailed = true - } - } - delayResponseLock.synchronized { - firstFetch.start - secondFetch.start - } - firstFetch.join - secondFetch.join - assert(firstFailed && secondFailed) + // failure should be cached + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } } } From a805ac4a7cdd520b6141dd885c780c526bb54ba6 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 16 Jan 2013 10:55:26 -0800 Subject: [PATCH 106/291] Disabled checkpoint for PairwiseRDD (pySpark). 
--- core/src/main/scala/spark/api/python/PythonRDD.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 276035a9ad..0138b22d38 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -138,6 +138,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends case Seq(a, b) => (a, b) case x => throw new Exception("PairwiseRDD: unexpected value: " + x) } + override def checkpoint() { } val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) } From eae698f755f41fd8bdff94c498df314ed74aa3c1 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 16 Jan 2013 12:21:37 -0800 Subject: [PATCH 107/291] remove unused thread pool --- .../main/scala/spark/executor/StandaloneExecutorBackend.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index 915f71ba9f..a29bf974d2 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -24,9 +24,6 @@ private[spark] class StandaloneExecutorBackend( with ExecutorBackend with Logging { - val threadPool = new ThreadPoolExecutor( - 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable]) - var master: ActorRef = null override def preStart() { From 42fbef3c2a6460bcd389bb86306be3ebc14c998b Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Thu, 17 Jan 2013 15:54:59 +0200 Subject: [PATCH 108/291] Adding default command line args to SparkALS --- .../main/scala/spark/examples/SparkALS.scala | 27 ++++++++++++------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/examples/src/main/scala/spark/examples/SparkALS.scala b/examples/src/main/scala/spark/examples/SparkALS.scala index fb28e2c932..cbd749666d 100644 --- a/examples/src/main/scala/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/spark/examples/SparkALS.scala @@ -7,6 +7,7 @@ import cern.jet.math._ import cern.colt.matrix._ import cern.colt.matrix.linalg._ import spark._ +import scala.Option object SparkALS { // Parameters set through command line arguments @@ -97,21 +98,27 @@ object SparkALS { def main(args: Array[String]) { var host = "" var slices = 0 - args match { - case Array(m, u, f, iters, slices_, host_) => { - M = m.toInt - U = u.toInt - F = f.toInt - ITERATIONS = iters.toInt - slices = slices_.toInt - host = host_ + + (1 to 6).map(i => { + i match { + case a if a < args.length => Option(args(a)) + case _ => Option(null) + } + }).toArray match { + case Array(host_, m, u, f, iters, slices_) => { + host = host_ getOrElse "local" + M = (m getOrElse "100").toInt + U = (u getOrElse "500").toInt + F = (f getOrElse "10").toInt + ITERATIONS = (iters getOrElse "5").toInt + slices = (slices_ getOrElse "2").toInt } case _ => { - System.err.println("Usage: SparkALS ") + System.err.println("Usage: SparkALS [ ]") System.exit(1) } } - printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS); + printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS) val spark = new SparkContext(host, "SparkALS") val R = generateR() From a512df551f85086a6ec363744542e74749c6b560 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Thu, 17 Jan 2013 16:05:27 +0200 Subject: [PATCH 109/291] Fixed index error missing first argument --- 
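The SparkALS patch above maps each positional argument to an Option and falls back to a default when it is absent; the one-line diff that follows corrects the indexing from (1 to 6) to (0 to 5) so the first argument is no longer skipped. A standalone sketch of the corrected pattern, with hypothetical defaults:

object AlsArgs {
  def main(args: Array[String]) {
    val opts = (0 to 5).map(i => if (i < args.length) Some(args(i)) else None)
    val host  = opts(0).getOrElse("local")
    val m     = opts(1).getOrElse("100").toInt
    val u     = opts(2).getOrElse("500").toInt
    val iters = opts(3).getOrElse("5").toInt
    println("host=" + host + ", M=" + m + ", U=" + u + ", iterations=" + iters)
  }
}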
examples/src/main/scala/spark/examples/SparkALS.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/scala/spark/examples/SparkALS.scala b/examples/src/main/scala/spark/examples/SparkALS.scala index cbd749666d..4672812565 100644 --- a/examples/src/main/scala/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/spark/examples/SparkALS.scala @@ -99,7 +99,7 @@ object SparkALS { var host = "" var slices = 0 - (1 to 6).map(i => { + (0 to 5).map(i => { i match { case a if a < args.length => Option(args(a)) case _ => Option(null) From a5ba7a9f322dce763350864bf89d94e6656d9984 Mon Sep 17 00:00:00 2001 From: Nick Pentreath Date: Thu, 17 Jan 2013 16:21:00 +0200 Subject: [PATCH 110/291] Use only one update function and pass in transpose of ratings matrix where appropriate --- .../main/scala/spark/examples/SparkALS.scala | 32 ++----------------- 1 file changed, 3 insertions(+), 29 deletions(-) diff --git a/examples/src/main/scala/spark/examples/SparkALS.scala b/examples/src/main/scala/spark/examples/SparkALS.scala index 4672812565..2766ad1702 100644 --- a/examples/src/main/scala/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/spark/examples/SparkALS.scala @@ -43,7 +43,7 @@ object SparkALS { return sqrt(sumSqs / (M * U)) } - def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], + def update(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], R: DoubleMatrix2D) : DoubleMatrix1D = { val U = us.size @@ -69,32 +69,6 @@ object SparkALS { return solved2D.viewColumn(0) } - def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D], - R: DoubleMatrix2D) : DoubleMatrix1D = - { - val M = ms.size - val F = ms(0).size - val XtX = factory2D.make(F, F) - val Xty = factory1D.make(F) - // For each movie that the user rated - for (i <- 0 until M) { - val m = ms(i) - // Add m * m^t to XtX - blas.dger(1, m, m, XtX) - // Add m * rating to Xty - blas.daxpy(R.get(i, j), m, Xty) - } - // Add regularization coefs to diagonal terms - for (d <- 0 until F) { - XtX.set(d, d, XtX.get(d, d) + LAMBDA * M) - } - // Solve it with Cholesky - val ch = new CholeskyDecomposition(XtX) - val Xty2D = factory2D.make(Xty.toArray, F) - val solved2D = ch.solve(Xty2D) - return solved2D.viewColumn(0) - } - def main(args: Array[String]) { var host = "" var slices = 0 @@ -134,11 +108,11 @@ object SparkALS { for (iter <- 1 to ITERATIONS) { println("Iteration " + iter + ":") ms = spark.parallelize(0 until M, slices) - .map(i => updateMovie(i, msc.value(i), usc.value, Rc.value)) + .map(i => update(i, msc.value(i), usc.value, Rc.value)) .toArray msc = spark.broadcast(ms) // Re-broadcast ms because it was updated us = spark.parallelize(0 until U, slices) - .map(i => updateUser(i, usc.value(i), msc.value, Rc.value)) + .map(i => update(i, usc.value(i), msc.value, algebra.transpose(Rc.value))) .toArray usc = spark.broadcast(us) // Re-broadcast us because it was updated println("RMSE = " + rmse(R, ms, us)) From 892c32a14b89139b7bd89e141fc90b148a67ce68 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 17 Jan 2013 11:14:47 -0800 Subject: [PATCH 111/291] Warn users if they run pyspark or spark-shell without compiling Spark --- pyspark | 7 +++++++ run | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/pyspark b/pyspark index 9e89d51ba2..ab7f4f50c0 100755 --- a/pyspark +++ b/pyspark @@ -6,6 +6,13 @@ FWDIR="$(cd `dirname $0`; pwd)" # Export this as SPARK_HOME export SPARK_HOME="$FWDIR" +# Exit if the user hasn't compiled Spark +if [ ! 
-e "$SPARK_HOME/repl/target" ]; then + echo "Failed to find Spark classes in $SPARK_HOME/repl/target" >&2 + echo "You need to compile Spark before running this program" >&2 + exit 1 +fi + # Load environment variables from conf/spark-env.sh, if it exists if [ -e $FWDIR/conf/spark-env.sh ] ; then . $FWDIR/conf/spark-env.sh diff --git a/run b/run index ca23455386..eb93db66db 100755 --- a/run +++ b/run @@ -65,6 +65,13 @@ EXAMPLES_DIR="$FWDIR/examples" BAGEL_DIR="$FWDIR/bagel" PYSPARK_DIR="$FWDIR/python" +# Exit if the user hasn't compiled Spark +if [ ! -e "$REPL_DIR/target" ]; then + echo "Failed to find Spark classes in $REPL_DIR/target" >&2 + echo "You need to compile Spark before running this program" >&2 + exit 1 +fi + # Build up classpath CLASSPATH="$SPARK_CLASSPATH" CLASSPATH+=":$FWDIR/conf" From 742bc841adb2a57b05e7a155681a162ab9dfa2c1 Mon Sep 17 00:00:00 2001 From: Fernand Pajot Date: Thu, 17 Jan 2013 16:56:11 -0800 Subject: [PATCH 112/291] changed HttpBroadcast server cache to be in spark.local.dir instead of java.io.tmpdir --- core/src/main/scala/spark/broadcast/HttpBroadcast.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala index 7eb4ddb74f..96dc28f12a 100644 --- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala @@ -89,7 +89,7 @@ private object HttpBroadcast extends Logging { } private def createServer() { - broadcastDir = Utils.createTempDir() + broadcastDir = Utils.createTempDir(System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) server = new HttpServer(broadcastDir) server.start() serverUri = server.uri From 54c0f9f185576e9b844fa8f81ca410f188daa51c Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 17 Jan 2013 17:40:55 -0800 Subject: [PATCH 113/291] Fix code that assumed spark.local.dir is only a single directory --- core/src/main/scala/spark/Utils.scala | 11 ++++++++++- .../main/scala/spark/broadcast/HttpBroadcast.scala | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 0e7007459d..aeed5d2f32 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -134,7 +134,7 @@ private object Utils extends Logging { */ def fetchFile(url: String, targetDir: File) { val filename = url.split("/").last - val tempDir = System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")) + val tempDir = getLocalDir val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir)) val targetFile = new File(targetDir, filename) val uri = new URI(url) @@ -204,6 +204,15 @@ private object Utils extends Logging { FileUtil.chmod(filename, "a+x") } + /** + * Get a temporary directory using Spark's spark.local.dir property, if set. This will always + * return a single directory, even though the spark.local.dir property might be a list of + * multiple paths. + */ + def getLocalDir: String = { + System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0) + } + /** * Shuffle the elements of a collection into a random order, returning the * result in a new collection. 
Unlike scala.util.Random.shuffle, this method diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala index 96dc28f12a..856a4683a9 100644 --- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala @@ -89,7 +89,7 @@ private object HttpBroadcast extends Logging { } private def createServer() { - broadcastDir = Utils.createTempDir(System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) + broadcastDir = Utils.createTempDir(Utils.getLocalDir) server = new HttpServer(broadcastDir) server.start() serverUri = server.uri From 1db119a08f07b8707b901e92b03138b27e887844 Mon Sep 17 00:00:00 2001 From: seanm Date: Fri, 18 Jan 2013 20:22:23 -0700 Subject: [PATCH 114/291] kafka jar wasn't being included by run script --- run | 3 +++ 1 file changed, 3 insertions(+) diff --git a/run b/run index 2f61cb2a87..494f04c3ac 100755 --- a/run +++ b/run @@ -76,6 +76,9 @@ CLASSPATH+=":$CORE_DIR/src/main/resources" CLASSPATH+=":$REPL_DIR/target/scala-$SCALA_VERSION/classes" CLASSPATH+=":$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes" CLASSPATH+=":$STREAMING_DIR/target/scala-$SCALA_VERSION/classes" +for jar in `find "$STREAMING_DIR/lib" -name '*jar'`; do + CLASSPATH+=":$jar" +done if [ -e "$FWDIR/lib_managed" ]; then for jar in `find "$FWDIR/lib_managed/jars" -name '*jar'`; do CLASSPATH+=":$jar" From 56b7fbafa2b7717896c613e39ecc134f2405b4c6 Mon Sep 17 00:00:00 2001 From: seanm Date: Fri, 18 Jan 2013 21:15:54 -0700 Subject: [PATCH 115/291] further KafkaInputDStream cleanup (removing unused and commented out code relating to offset management) --- .../streaming/dstream/KafkaInputDStream.scala | 72 +------------------ 1 file changed, 3 insertions(+), 69 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index 9605072382..533c91ee95 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -19,15 +19,6 @@ import scala.collection.JavaConversions._ // Key for a specific Kafka Partition: (broker, topic, group, part) case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int) -// NOT USED - Originally intended for fault-tolerance -// Metadata for a Kafka Stream that it sent to the Master -private[streaming] -case class KafkaInputDStreamMetadata(timestamp: Long, data: Map[KafkaPartitionKey, Long]) -// NOT USED - Originally intended for fault-tolerance -// Checkpoint data specific to a KafkaInputDstream -private[streaming] -case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any], - savedOffsets: Map[KafkaPartitionKey, Long]) extends DStreamCheckpointData(kafkaRdds) /** * Input stream that pulls messages from a Kafka Broker. @@ -52,49 +43,6 @@ class KafkaInputDStream[T: ClassManifest]( storageLevel: StorageLevel ) extends NetworkInputDStream[T](ssc_ ) with Logging { - // Metadata that keeps track of which messages have already been consumed. - var savedOffsets = HashMap[Long, Map[KafkaPartitionKey, Long]]() - - /* NOT USED - Originally intended for fault-tolerance - - // In case of a failure, the offets for a particular timestamp will be restored. 
- @transient var restoredOffsets : Map[KafkaPartitionKey, Long] = null - - - override protected[streaming] def addMetadata(metadata: Any) { - metadata match { - case x : KafkaInputDStreamMetadata => - savedOffsets(x.timestamp) = x.data - // TOOD: Remove logging - logInfo("New saved Offsets: " + savedOffsets) - case _ => logInfo("Received unknown metadata: " + metadata.toString) - } - } - - override protected[streaming] def updateCheckpointData(currentTime: Time) { - super.updateCheckpointData(currentTime) - if(savedOffsets.size > 0) { - // Find the offets that were stored before the checkpoint was initiated - val key = savedOffsets.keys.toList.sortWith(_ < _).filter(_ < currentTime.millis).last - val latestOffsets = savedOffsets(key) - logInfo("Updating KafkaDStream checkpoint data: " + latestOffsets.toString) - checkpointData = KafkaDStreamCheckpointData(checkpointData.rdds, latestOffsets) - // TODO: This may throw out offsets that are created after the checkpoint, - // but it's unlikely we'll need them. - savedOffsets.clear() - } - } - - override protected[streaming] def restoreCheckpointData() { - super.restoreCheckpointData() - logInfo("Restoring KafkaDStream checkpoint data.") - checkpointData match { - case x : KafkaDStreamCheckpointData => - restoredOffsets = x.savedOffsets - logInfo("Restored KafkaDStream offsets: " + savedOffsets) - } - } */ - def createReceiver(): NetworkReceiver[T] = { new KafkaReceiver(host, port, groupId, topics, initialOffsets, storageLevel) .asInstanceOf[NetworkReceiver[T]] @@ -111,8 +59,6 @@ class KafkaReceiver(host: String, port: Int, groupId: String, // Handles pushing data into the BlockManager lazy protected val blockGenerator = new BlockGenerator(storageLevel) - // Keeps track of the current offsets. Maps from (broker, topic, group, part) -> Offset - lazy val offsets = HashMap[KafkaPartitionKey, Long]() // Connection to Kafka var consumerConnector : ZookeeperConsumerConnector = null @@ -143,8 +89,8 @@ class KafkaReceiver(host: String, port: Int, groupId: String, consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector] logInfo("Connected to " + zooKeeperEndPoint) - // Reset the Kafka offsets in case we are recovering from a failure - resetOffsets(initialOffsets) + // If specified, set the topic offset + setOffsets(initialOffsets) // Create Threads for each Topic/Message Stream we are listening val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder()) @@ -157,7 +103,7 @@ class KafkaReceiver(host: String, port: Int, groupId: String, } // Overwrites the offets in Zookeper. 
- private def resetOffsets(offsets: Map[KafkaPartitionKey, Long]) { + private def setOffsets(offsets: Map[KafkaPartitionKey, Long]) { offsets.foreach { case(key, offset) => val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic) val partitionName = key.brokerId + "-" + key.partId @@ -178,16 +124,4 @@ class KafkaReceiver(host: String, port: Int, groupId: String, } } } - - // NOT USED - Originally intended for fault-tolerance - // class KafkaDataHandler(receiver: KafkaReceiver, storageLevel: StorageLevel) - // extends BufferingBlockCreator[Any](receiver, storageLevel) { - - // override def createBlock(blockId: String, iterator: Iterator[Any]) : Block = { - // // Creates a new Block with Kafka-specific Metadata - // new Block(blockId, iterator, KafkaInputDStreamMetadata(System.currentTimeMillis, offsets.toMap)) - // } - - // } - } From d3064fe70762cbfcb7dbd5e1fbd708539c3de5e9 Mon Sep 17 00:00:00 2001 From: seanm Date: Fri, 18 Jan 2013 21:34:29 -0700 Subject: [PATCH 116/291] kafkaStream API cleanup. A quorum of zookeepers can now be specified --- .../streaming/examples/KafkaWordCount.scala | 16 ++++++++-------- .../spark/streaming/StreamingContext.scala | 8 +++----- .../streaming/dstream/KafkaInputDStream.scala | 17 +++++++---------- 3 files changed, 18 insertions(+), 23 deletions(-) diff --git a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala index fe55db6e2c..65d5da82fc 100644 --- a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala +++ b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala @@ -13,19 +13,19 @@ import spark.streaming.util.RawTextHelper._ object KafkaWordCount { def main(args: Array[String]) { - if (args.length < 6) { - System.err.println("Usage: KafkaWordCount ") + if (args.length < 5) { + System.err.println("Usage: KafkaWordCount ") System.exit(1) } - val Array(master, hostname, port, group, topics, numThreads) = args + val Array(master, zkQuorum, group, topics, numThreads) = args val sc = new SparkContext(master, "KafkaWordCount") val ssc = new StreamingContext(sc, Seconds(2)) ssc.checkpoint("checkpoint") val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap - val lines = ssc.kafkaStream[String](hostname, port.toInt, group, topicpMap) + val lines = ssc.kafkaStream[String](zkQuorum, group, topicpMap) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2) wordCounts.print() @@ -38,16 +38,16 @@ object KafkaWordCount { object KafkaWordCountProducer { def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: KafkaWordCountProducer ") + if (args.length < 2) { + System.err.println("Usage: KafkaWordCountProducer ") System.exit(1) } - val Array(hostname, port, topic, messagesPerSec, wordsPerMessage) = args + val Array(zkQuorum, topic, messagesPerSec, wordsPerMessage) = args // Zookeper connection properties val props = new Properties() - props.put("zk.connect", hostname + ":" + port) + props.put("zk.connect", zkQuorum) props.put("serializer.class", "kafka.serializer.StringEncoder") val config = new ProducerConfig(props) diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 14500bdcb1..06cf7a06ed 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ 
b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -136,8 +136,7 @@ class StreamingContext private ( /** * Create an input stream that pulls messages form a Kafka Broker. - * @param hostname Zookeper hostname. - * @param port Zookeper port. + * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. @@ -146,14 +145,13 @@ class StreamingContext private ( * @param storageLevel RDD storage level. Defaults to memory-only. */ def kafkaStream[T: ClassManifest]( - hostname: String, - port: Int, + zkQuorum: String, groupId: String, topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long] = Map[KafkaPartitionKey, Long](), storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 ): DStream[T] = { - val inputStream = new KafkaInputDStream[T](this, hostname, port, groupId, topics, initialOffsets, storageLevel) + val inputStream = new KafkaInputDStream[T](this, zkQuorum, groupId, topics, initialOffsets, storageLevel) registerInputStream(inputStream) inputStream } diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index 533c91ee95..4f8c8b9d10 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -23,8 +23,7 @@ case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, part /** * Input stream that pulls messages from a Kafka Broker. * - * @param host Zookeper hostname. - * @param port Zookeper port. + * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. 
@@ -35,8 +34,7 @@ case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, part private[streaming] class KafkaInputDStream[T: ClassManifest]( @transient ssc_ : StreamingContext, - host: String, - port: Int, + zkQuorum: String, groupId: String, topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], @@ -44,13 +42,13 @@ class KafkaInputDStream[T: ClassManifest]( ) extends NetworkInputDStream[T](ssc_ ) with Logging { def createReceiver(): NetworkReceiver[T] = { - new KafkaReceiver(host, port, groupId, topics, initialOffsets, storageLevel) + new KafkaReceiver(zkQuorum, groupId, topics, initialOffsets, storageLevel) .asInstanceOf[NetworkReceiver[T]] } } private[streaming] -class KafkaReceiver(host: String, port: Int, groupId: String, +class KafkaReceiver(zkQuorum: String, groupId: String, topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], storageLevel: StorageLevel) extends NetworkReceiver[Any] { @@ -73,21 +71,20 @@ class KafkaReceiver(host: String, port: Int, groupId: String, // In case we are using multiple Threads to handle Kafka Messages val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) - val zooKeeperEndPoint = host + ":" + port logInfo("Starting Kafka Consumer Stream with group: " + groupId) logInfo("Initial offsets: " + initialOffsets.toString) // Zookeper connection properties val props = new Properties() - props.put("zk.connect", zooKeeperEndPoint) + props.put("zk.connect", zkQuorum) props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString) props.put("groupid", groupId) // Create the connection to the cluster - logInfo("Connecting to Zookeper: " + zooKeeperEndPoint) + logInfo("Connecting to Zookeper: " + zkQuorum) val consumerConfig = new ConsumerConfig(props) consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector] - logInfo("Connected to " + zooKeeperEndPoint) + logInfo("Connected to " + zkQuorum) // If specified, set the topic offset setOffsets(initialOffsets) From ecdff861f7993251163b82e737aba6bb1bb814d8 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sat, 19 Jan 2013 22:59:35 -0800 Subject: [PATCH 117/291] Clarifying log directory in EC2 guide --- docs/ec2-scripts.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md index 6e1f7fd3b1..8b069ca9ad 100644 --- a/docs/ec2-scripts.md +++ b/docs/ec2-scripts.md @@ -96,7 +96,8 @@ permissions on your private key file, you can run `launch` with the `spark-ec2` to attach a persistent EBS volume to each node for storing the persistent HDFS. - Finally, if you get errors while running your jobs, look at the slave's logs - for that job using the Mesos web UI (`http://:8080`). + for that job inside of the Mesos work directory (/mnt/mesos-work). Mesos errors + can be found using the Mesos web UI (`http://:8080`). # Configuration From 214345ceace634ec9cc83c4c85b233b699e0d219 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 19 Jan 2013 23:50:17 -0800 Subject: [PATCH 118/291] Fixed issue https://spark-project.atlassian.net/browse/STREAMING-29, along with updates to doc comments in SparkContext.checkpoint(). 
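The contract this patch introduces is easiest to see end to end: a checkpoint directory must be set on the SparkContext before checkpoint() is called, checkpoint() must be marked before the first job runs on the RDD, and the RDD should be persisted so writing the checkpoint does not force a recomputation. The sketch below is illustrative only: it uses the PySpark method names (setCheckpointDir, cache, checkpoint) added later in this series rather than the Scala API touched here, and the checkpoint path is an assumed placeholder.

    from pyspark.context import SparkContext

    sc = SparkContext("local", "checkpoint-demo")
    # Assumed placeholder path; it must not already exist unless useExisting=True is
    # passed. After this patch, calling checkpoint() without setting a directory raises
    # "Checkpoint directory has not been set in the SparkContext".
    sc.setCheckpointDir("/tmp/spark-checkpoints")

    rdd = sc.parallelize(range(100)).map(lambda x: x * x)
    rdd.cache()       # recommended: keep it in memory so the checkpoint write avoids recomputation
    rdd.checkpoint()  # must be called before any job has been executed on this RDD
    rdd.count()       # the first job materializes the RDD and saves it under the checkpoint directory

Making checkpointDir an Option[String] rather than a nullable String is what lets checkpoint() fail fast with the message above instead of failing later when the path is resolved.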
--- core/src/main/scala/spark/RDD.scala | 17 ++++++++--------- .../main/scala/spark/RDDCheckpointData.scala | 2 +- core/src/main/scala/spark/SparkContext.scala | 13 +++++++------ .../main/scala/spark/streaming/DStream.scala | 8 +++++++- 4 files changed, 23 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index a9f2e86455..e0d2eabb1d 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -549,17 +549,16 @@ abstract class RDD[T: ClassManifest]( } /** - * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir` - * (set using setCheckpointDir()) and all references to its parent RDDs will be removed. - * This is used to truncate very long lineages. In the current implementation, Spark will save - * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. - * Hence, it is strongly recommended to use checkpoint() on RDDs when - * (i) checkpoint() is called before the any job has been executed on this RDD. - * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will - * require recomputation. + * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint + * directory set with SparkContext.setCheckpointDir() and all references to its parent + * RDDs will be removed. This function must be called before any job has been + * executed on this RDD. It is strongly recommended that this RDD is persisted in + * memory, otherwise saving it on a file will require recomputation. */ def checkpoint() { - if (checkpointData.isEmpty) { + if (context.checkpointDir.isEmpty) { + throw new Exception("Checkpoint directory has not been set in the SparkContext") + } else if (checkpointData.isEmpty) { checkpointData = Some(new RDDCheckpointData(this)) checkpointData.get.markForCheckpoint() } diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index d845a522e4..18df530b7d 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -63,7 +63,7 @@ extends Logging with Serializable { } // Save to file, and reload it as an RDD - val path = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString + val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id).toString rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _) val newRDD = new CheckpointRDD[T](rdd.context, path) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 88cf357ebf..7f3259d982 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -184,7 +184,7 @@ class SparkContext( private var dagScheduler = new DAGScheduler(taskScheduler) - private[spark] var checkpointDir: String = null + private[spark] var checkpointDir: Option[String] = None // Methods for creating RDDs @@ -595,10 +595,11 @@ class SparkContext( } /** - * Set the directory under which RDDs are going to be checkpointed. This method will - * create this directory and will throw an exception of the path already exists (to avoid - * overwriting existing files may be overwritten). The directory will be deleted on exit - * if indicated. + * Set the directory under which RDDs are going to be checkpointed. The directory must + * be a HDFS path if running on a cluster. If the directory does not exist, it will + * be created. 
If the directory exists and useExisting is set to true, then the + * exisiting directory will be used. Otherwise an exception will be thrown to + * prevent accidental overriding of checkpoint files in the existing directory. */ def setCheckpointDir(dir: String, useExisting: Boolean = false) { val path = new Path(dir) @@ -610,7 +611,7 @@ class SparkContext( fs.mkdirs(path) } } - checkpointDir = dir + checkpointDir = Some(dir) } /** Default level of parallelism to use when not given by user (e.g. for reduce tasks) */ diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index fbe3cebd6d..c4442b6a0c 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -154,10 +154,16 @@ abstract class DStream[T: ClassManifest] ( assert( !mustCheckpoint || checkpointDuration != null, - "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set. " + + "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set." + " Please use DStream.checkpoint() to set the interval." ) + assert( + checkpointDuration == null || ssc.sc.checkpointDir.isDefined, + "The checkpoint directory has not been set. Please use StreamingContext.checkpoint()" + + " or SparkContext.checkpoint() to set the checkpoint directory." + ) + assert( checkpointDuration == null || checkpointDuration >= slideDuration, "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + From 8e7f098a2c9e5e85cb9435f28d53a3a5847c14aa Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 20 Jan 2013 01:57:44 -0800 Subject: [PATCH 119/291] Added accumulators to PySpark --- .../scala/spark/api/python/PythonRDD.scala | 83 +++++++-- python/pyspark/__init__.py | 4 + python/pyspark/accumulators.py | 166 ++++++++++++++++++ python/pyspark/context.py | 38 ++++ python/pyspark/rdd.py | 2 +- python/pyspark/serializers.py | 7 +- python/pyspark/shell.py | 4 +- python/pyspark/worker.py | 7 +- 8 files changed, 290 insertions(+), 21 deletions(-) create mode 100644 python/pyspark/accumulators.py diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index f431ef28d3..fb13e84658 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -1,7 +1,8 @@ package spark.api.python import java.io._ -import java.util.{List => JList} +import java.net._ +import java.util.{List => JList, ArrayList => JArrayList, Collections} import scala.collection.JavaConversions._ import scala.io.Source @@ -10,25 +11,26 @@ import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import spark.broadcast.Broadcast import spark._ import spark.rdd.PipedRDD -import java.util private[spark] class PythonRDD[T: ClassManifest]( - parent: RDD[T], - command: Seq[String], - envVars: java.util.Map[String, String], - preservePartitoning: Boolean, - pythonExec: String, - broadcastVars: java.util.List[Broadcast[Array[Byte]]]) + parent: RDD[T], + command: Seq[String], + envVars: java.util.Map[String, String], + preservePartitoning: Boolean, + pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], + accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent.context) { // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. 
by spaces) def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String], - preservePartitoning: Boolean, pythonExec: String, - broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = + preservePartitoning: Boolean, pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], + accumulator: Accumulator[JList[Array[Byte]]]) = this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec, - broadcastVars) + broadcastVars, accumulator) override def splits = parent.splits @@ -93,18 +95,30 @@ private[spark] class PythonRDD[T: ClassManifest]( // Return an iterator that read lines from the process's stdout val stream = new DataInputStream(proc.getInputStream) return new Iterator[Array[Byte]] { - def next() = { + def next(): Array[Byte] = { val obj = _nextObj _nextObj = read() obj } - private def read() = { + private def read(): Array[Byte] = { try { val length = stream.readInt() - val obj = new Array[Byte](length) - stream.readFully(obj) - obj + if (length != -1) { + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + } else { + // We've finished the data section of the output, but we can still read some + // accumulator updates; let's do that, breaking when we get EOFException + while (true) { + val len2 = stream.readInt() + val update = new Array[Byte](len2) + stream.readFully(update) + accumulator += Collections.singletonList(update) + } + new Array[Byte](0) + } } catch { case eof: EOFException => { val exitStatus = proc.waitFor() @@ -246,3 +260,40 @@ private class ExtractValue extends spark.api.java.function.Function[(Array[Byte] private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] { override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") } + +/** + * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it + * collects a list of pickled strings that we pass to Python through a socket. + */ +class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) + extends AccumulatorParam[JList[Array[Byte]]] { + + override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList + + override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]]) + : JList[Array[Byte]] = { + if (serverHost == null) { + // This happens on the worker node, where we just want to remember all the updates + val1.addAll(val2) + val1 + } else { + // This happens on the master, where we pass the updates to Python through a socket + val socket = new Socket(serverHost, serverPort) + val in = socket.getInputStream + val out = new DataOutputStream(socket.getOutputStream) + out.writeInt(val2.size) + for (array <- val2) { + out.writeInt(array.length) + out.write(array) + } + out.flush() + // Wait for a byte from the Python side as an acknowledgement + val byteRead = in.read() + if (byteRead == -1) { + throw new SparkException("EOF reached before Python server acknowledged") + } + socket.close() + null + } + } +} diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index c595ae0842..00666bc0a3 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -7,6 +7,10 @@ Public classes: Main entry point for Spark functionality. - L{RDD} A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. + - L{Broadcast} + A broadcast variable that gets reused across tasks. + - L{Accumulator} + An "add-only" shared variable that tasks can only add values to. 
""" import sys import os diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py new file mode 100644 index 0000000000..438af4cfc0 --- /dev/null +++ b/python/pyspark/accumulators.py @@ -0,0 +1,166 @@ +""" +>>> from pyspark.context import SparkContext +>>> sc = SparkContext('local', 'test') +>>> a = sc.accumulator(1) +>>> a.value +1 +>>> a.value = 2 +>>> a.value +2 +>>> a += 5 +>>> a.value +7 + +>>> rdd = sc.parallelize([1,2,3]) +>>> def f(x): +... global a +... a += x +>>> rdd.foreach(f) +>>> a.value +13 + +>>> class VectorAccumulatorParam(object): +... def zero(self, value): +... return [0.0] * len(value) +... def addInPlace(self, val1, val2): +... for i in xrange(len(val1)): +... val1[i] += val2[i] +... return val1 +>>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam()) +>>> va.value +[1.0, 2.0, 3.0] +>>> def g(x): +... global va +... va += [x] * 3 +>>> rdd.foreach(g) +>>> va.value +[7.0, 8.0, 9.0] + +>>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): + ... +Py4JJavaError:... + +>>> def h(x): +... global a +... a.value = 7 +>>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): + ... +Py4JJavaError:... + +>>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): + ... +Exception:... +""" + +import struct +import SocketServer +import threading +from pyspark.cloudpickle import CloudPickler +from pyspark.serializers import read_int, read_with_length, load_pickle + + +# Holds accumulators registered on the current machine, keyed by ID. This is then used to send +# the local accumulator updates back to the driver program at the end of a task. +_accumulatorRegistry = {} + + +def _deserialize_accumulator(aid, zero_value, accum_param): + from pyspark.accumulators import _accumulatorRegistry + accum = Accumulator(aid, zero_value, accum_param) + accum._deserialized = True + _accumulatorRegistry[aid] = accum + return accum + + +class Accumulator(object): + def __init__(self, aid, value, accum_param): + """Create a new Accumulator with a given initial value and AccumulatorParam object""" + from pyspark.accumulators import _accumulatorRegistry + self.aid = aid + self.accum_param = accum_param + self._value = value + self._deserialized = False + _accumulatorRegistry[aid] = self + + def __reduce__(self): + """Custom serialization; saves the zero value from our AccumulatorParam""" + param = self.accum_param + return (_deserialize_accumulator, (self.aid, param.zero(self._value), param)) + + @property + def value(self): + """Get the accumulator's value; only usable in driver program""" + if self._deserialized: + raise Exception("Accumulator.value cannot be accessed inside tasks") + return self._value + + @value.setter + def value(self, value): + """Sets the accumulator's value; only usable in driver program""" + if self._deserialized: + raise Exception("Accumulator.value cannot be accessed inside tasks") + self._value = value + + def __iadd__(self, term): + """The += operator; adds a term to this accumulator's value""" + self._value = self.accum_param.addInPlace(self._value, term) + return self + + def __str__(self): + return str(self._value) + + +class AddingAccumulatorParam(object): + """ + An AccumulatorParam that uses the + operators to add values. Designed for simple types + such as integers, floats, and lists. Requires the zero value for the underlying type + as a parameter. 
+ """ + + def __init__(self, zero_value): + self.zero_value = zero_value + + def zero(self, value): + return self.zero_value + + def addInPlace(self, value1, value2): + value1 += value2 + return value1 + + +# Singleton accumulator params for some standard types +INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0) +DOUBLE_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0) +COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) + + +class _UpdateRequestHandler(SocketServer.StreamRequestHandler): + def handle(self): + from pyspark.accumulators import _accumulatorRegistry + num_updates = read_int(self.rfile) + for _ in range(num_updates): + (aid, update) = load_pickle(read_with_length(self.rfile)) + _accumulatorRegistry[aid] += update + # Write a byte in acknowledgement + self.wfile.write(struct.pack("!b", 1)) + + +def _start_update_server(): + """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" + server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler) + thread = threading.Thread(target=server.serve_forever) + thread.daemon = True + thread.start() + return server + + +def _test(): + import doctest + doctest.testmod() + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/context.py b/python/pyspark/context.py index e486f206b0..1e2f845f9c 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -2,6 +2,8 @@ import os import atexit from tempfile import NamedTemporaryFile +from pyspark import accumulators +from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast from pyspark.java_gateway import launch_gateway from pyspark.serializers import dump_pickle, write_with_length, batched @@ -22,6 +24,7 @@ class SparkContext(object): _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile _takePartition = jvm.PythonRDD.takePartition + _next_accum_id = 0 def __init__(self, master, jobName, sparkHome=None, pyFiles=None, environment=None, batchSize=1024): @@ -52,6 +55,14 @@ class SparkContext(object): self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome, empty_string_array) + # Create a single Accumulator in Java that we'll send all our updates through; + # they will be passed back to us through a TCP server + self._accumulatorServer = accumulators._start_update_server() + (host, port) = self._accumulatorServer.server_address + self._javaAccumulator = self._jsc.accumulator( + self.jvm.java.util.ArrayList(), + self.jvm.PythonAccumulatorParam(host, port)) + self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') # Broadcast's __reduce__ method stores Broadcast instances here. # This allows other code to determine which Broadcast instances have @@ -74,6 +85,8 @@ class SparkContext(object): def __del__(self): if self._jsc: self._jsc.stop() + if self._accumulatorServer: + self._accumulatorServer.shutdown() def stop(self): """ @@ -129,6 +142,31 @@ class SparkContext(object): return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) + def accumulator(self, value, accum_param=None): + """ + Create an C{Accumulator} with the given initial value, using a given + AccumulatorParam helper object to define how to add values of the data + type if provided. Default AccumulatorParams are used for integers and + floating-point numbers if you do not provide one. 
For other types, the + AccumulatorParam must implement two methods: + - C{zero(value)}: provide a "zero value" for the type, compatible in + dimensions with the provided C{value} (e.g., a zero vector). + - C{addInPlace(val1, val2)}: add two values of the accumulator's data + type, returning a new value; for efficiency, can also update C{val1} + in place and return it. + """ + if accum_param == None: + if isinstance(value, int): + accum_param = accumulators.INT_ACCUMULATOR_PARAM + elif isinstance(value, float): + accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM + elif isinstance(value, complex): + accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM + else: + raise Exception("No default accumulator param for type %s" % type(value)) + SparkContext._next_accum_id += 1 + return Accumulator(SparkContext._next_accum_id - 1, value, accum_param) + def addFile(self, path): """ Add a file to be downloaded into the working directory of this Spark diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 1d36da42b0..d705f0f9e1 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -703,7 +703,7 @@ class PipelinedRDD(RDD): env = MapConverter().convert(env, self.ctx.gateway._gateway_client) python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, class_manifest) + broadcast_vars, self.ctx._javaAccumulator, class_manifest) self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 9a5151ea00..115cf28cc2 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -52,8 +52,13 @@ def read_int(stream): raise EOFError return struct.unpack("!i", length)[0] + +def write_int(value, stream): + stream.write(struct.pack("!i", value)) + + def write_with_length(obj, stream): - stream.write(struct.pack("!i", len(obj))) + write_int(len(obj), stream) stream.write(obj) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 7e6ad3aa76..f6328c561f 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -1,7 +1,7 @@ """ An interactive shell. -This fle is designed to be launched as a PYTHONSTARTUP script. +This file is designed to be launched as a PYTHONSTARTUP script. """ import os from pyspark.context import SparkContext @@ -14,4 +14,4 @@ print "Spark context avaiable as sc." # which allows us to execute the user's PYTHONSTARTUP file: _pythonstartup = os.environ.get('OLD_PYTHONSTARTUP') if _pythonstartup and os.path.isfile(_pythonstartup): - execfile(_pythonstartup) + execfile(_pythonstartup) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 3d792bbaa2..b2b9288089 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -5,9 +5,10 @@ import sys from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. 
+from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler -from pyspark.serializers import write_with_length, read_with_length, \ +from pyspark.serializers import write_with_length, read_with_length, write_int, \ read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file @@ -36,6 +37,10 @@ def main(): iterator = read_from_pickle_file(sys.stdin) for obj in func(split_index, iterator): write_with_length(dumps(obj), old_stdout) + # Mark the beginning of the accumulators section of the output + write_int(-1, old_stdout) + for aid, accum in _accumulatorRegistry.items(): + write_with_length(dump_pickle((aid, accum._value)), old_stdout) if __name__ == '__main__': From 61b6382a352f3e801643529198b867e13debf470 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 20 Jan 2013 01:59:07 -0800 Subject: [PATCH 120/291] Launch accumulator tests in run-tests --- python/run-tests | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/run-tests b/python/run-tests index fcdd1e27a7..32470911f9 100755 --- a/python/run-tests +++ b/python/run-tests @@ -11,6 +11,9 @@ FAILED=$(($?||$FAILED)) $FWDIR/pyspark -m doctest pyspark/broadcast.py FAILED=$(($?||$FAILED)) +$FWDIR/pyspark -m doctest pyspark/accumulators.py +FAILED=$(($?||$FAILED)) + if [[ $FAILED != 0 ]]; then echo -en "\033[31m" # Red echo "Had test failures; see logs." From a23ed25f3cd6e76784f831d0ab7de7d3e193b59f Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 20 Jan 2013 02:10:25 -0800 Subject: [PATCH 121/291] Add a class comment to Accumulator --- python/pyspark/accumulators.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 438af4cfc0..c00c3a37af 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -76,6 +76,18 @@ def _deserialize_accumulator(aid, zero_value, accum_param): class Accumulator(object): + """ + A shared variable that can be accumulated, i.e., has a commutative and associative "add" + operation. Worker tasks on a Spark cluster can add values to an Accumulator with the C{+=} + operator, but only the driver program is allowed to access its value, using C{value}. + Updates from the workers get propagated automatically to the driver program. + + While C{SparkContext} supports accumulators for primitive data types like C{int} and + C{float}, users can also define accumulators for custom types by providing a custom + C{AccumulatorParam} object with a C{zero} and C{addInPlace} method. Refer to the doctest + of this module for an example. + """ + def __init__(self, aid, value, accum_param): """Create a new Accumulator with a given initial value and AccumulatorParam object""" from pyspark.accumulators import _accumulatorRegistry From ee5a07955c222dce16d0ffb9bde7f61033763c16 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 20 Jan 2013 02:11:58 -0800 Subject: [PATCH 122/291] Fix Python guide to say accumulators are available --- docs/python-programming-guide.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index 78ef310a00..a840b9b34b 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -16,7 +16,6 @@ There are a few key differences between the Python and Scala APIs: * Python is dynamically typed, so RDDs can hold objects of different types. 
* PySpark does not currently support the following Spark features: - - Accumulators - Special functions on RDDs of doubles, such as `mean` and `stdev` - `lookup` - `persist` at storage levels other than `MEMORY_ONLY` From 33bad85bb9143d41bc5de2068f7e8a8c39928225 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 20 Jan 2013 03:51:11 -0800 Subject: [PATCH 123/291] Fixed streaming testsuite bugs --- streaming/src/test/java/JavaAPISuite.java | 2 ++ .../test/scala/spark/streaming/BasicOperationsSuite.scala | 5 +++++ .../src/test/scala/spark/streaming/CheckpointSuite.scala | 6 +++--- streaming/src/test/scala/spark/streaming/FailureSuite.scala | 3 +++ .../src/test/scala/spark/streaming/InputStreamsSuite.scala | 3 +++ .../src/test/scala/spark/streaming/TestSuiteBase.scala | 6 +++--- .../test/scala/spark/streaming/WindowOperationsSuite.scala | 5 +++++ 7 files changed, 24 insertions(+), 6 deletions(-) diff --git a/streaming/src/test/java/JavaAPISuite.java b/streaming/src/test/java/JavaAPISuite.java index 8c94e13e65..c84e7331c7 100644 --- a/streaming/src/test/java/JavaAPISuite.java +++ b/streaming/src/test/java/JavaAPISuite.java @@ -34,12 +34,14 @@ public class JavaAPISuite implements Serializable { @Before public void setUp() { ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + ssc.checkpoint("checkpoint", new Duration(1000)); } @After public void tearDown() { ssc.stop(); ssc = null; + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.master.port"); } diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala index f73f9b1823..bfdf32c73e 100644 --- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala @@ -8,6 +8,11 @@ class BasicOperationsSuite extends TestSuiteBase { override def framework() = "BasicOperationsSuite" + after { + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") + } + test("map") { val input = Seq(1 to 4, 5 to 8, 9 to 12) testOperation( diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index 920388bba9..d2f32c189b 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -15,9 +15,11 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { } after { - if (ssc != null) ssc.stop() FileUtils.deleteDirectory(new File(checkpointDir)) + + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") } var ssc: StreamingContext = null @@ -26,8 +28,6 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { override def batchDuration = Milliseconds(500) - override def checkpointDir = "checkpoint" - override def checkpointInterval = batchDuration override def actuallyWait = true diff --git a/streaming/src/test/scala/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/spark/streaming/FailureSuite.scala index 4aa428bf64..7493ac1207 100644 --- a/streaming/src/test/scala/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/spark/streaming/FailureSuite.scala @@ -22,6 +22,9 @@ class FailureSuite extends TestSuiteBase with BeforeAndAfter { 
after { FailureSuite.reset() FileUtils.deleteDirectory(new File(checkpointDir)) + + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") } override def framework = "CheckpointSuite" diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index e71ba6ddc1..d7ba7a5d17 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -40,6 +40,9 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { FileUtils.deleteDirectory(testDir) testDir = null } + + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") } test("network input stream") { diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index a76f61d4ad..49129f3964 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -10,7 +10,7 @@ import collection.mutable.SynchronizedBuffer import java.io.{ObjectInputStream, IOException} -import org.scalatest.FunSuite +import org.scalatest.{BeforeAndAfter, FunSuite} /** * This is a input stream just for the testsuites. This is equivalent to a checkpointable, @@ -56,7 +56,7 @@ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBu * This is the base trait for Spark Streaming testsuites. This provides basic functionality * to run user-defined set of input on user-defined stream operations, and verify the output. */ -trait TestSuiteBase extends FunSuite with Logging { +trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { def framework = "TestSuiteBase" @@ -64,7 +64,7 @@ trait TestSuiteBase extends FunSuite with Logging { def batchDuration = Seconds(1) - def checkpointDir = null.asInstanceOf[String] + def checkpointDir = "checkpoint" def checkpointInterval = batchDuration diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index f9ba1f20f0..0c6e928835 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -11,6 +11,11 @@ class WindowOperationsSuite extends TestSuiteBase { override def batchDuration = Seconds(1) + after { + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") + } + val largerSlideInput = Seq( Seq(("a", 1)), Seq(("a", 2)), // 1st window from here From 5f74ead63643df83b04646c08e9bfc6b4b4a9ca9 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sun, 20 Jan 2013 08:59:20 -0800 Subject: [PATCH 124/291] Changes based on Matei's comment --- docs/ec2-scripts.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md index 8b069ca9ad..931b7a66bd 100644 --- a/docs/ec2-scripts.md +++ b/docs/ec2-scripts.md @@ -96,8 +96,9 @@ permissions on your private key file, you can run `launch` with the `spark-ec2` to attach a persistent EBS volume to each node for storing the persistent HDFS. 
- Finally, if you get errors while running your jobs, look at the slave's logs - for that job inside of the Mesos work directory (/mnt/mesos-work). Mesos errors - can be found using the Mesos web UI (`http://:8080`). + for that job inside of the Mesos work directory (/mnt/mesos-work). You can + also view the status of the cluster using the Mesos web UI + (`http://:8080`). # Configuration From 2a8c2a67909c4878ea24ec94f203287e55dd3782 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 20 Jan 2013 10:24:53 -0800 Subject: [PATCH 125/291] Minor formatting fixes --- examples/src/main/scala/spark/examples/SparkALS.scala | 4 ++-- python/examples/als.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/src/main/scala/spark/examples/SparkALS.scala b/examples/src/main/scala/spark/examples/SparkALS.scala index 2766ad1702..5e01885dbb 100644 --- a/examples/src/main/scala/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/spark/examples/SparkALS.scala @@ -75,8 +75,8 @@ object SparkALS { (0 to 5).map(i => { i match { - case a if a < args.length => Option(args(a)) - case _ => Option(null) + case a if a < args.length => Some(args(a)) + case _ => None } }).toArray match { case Array(host_, m, u, f, iters, slices_) => { diff --git a/python/examples/als.py b/python/examples/als.py index 284cf0d3a2..010f80097f 100755 --- a/python/examples/als.py +++ b/python/examples/als.py @@ -68,4 +68,4 @@ if __name__ == "__main__": error = rmse(R, ms, us) print "Iteration %d:" % i - print "\nRMSE: %5.4f\n" % error \ No newline at end of file + print "\nRMSE: %5.4f\n" % error From ea739251eb763b756a282534268e765b8d4b70f0 Mon Sep 17 00:00:00 2001 From: seanm Date: Sun, 20 Jan 2013 11:29:21 -0700 Subject: [PATCH 126/291] adding updateStateByKey object lifecycle test --- .../streaming/BasicOperationsSuite.scala | 45 +++++++++++++++++++ .../scala/spark/streaming/TestSuiteBase.scala | 5 +++ 2 files changed, 50 insertions(+) diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala index f73f9b1823..2bc94463b1 100644 --- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala @@ -160,6 +160,51 @@ class BasicOperationsSuite extends TestSuiteBase { testOperation(inputData, updateStateOperation, outputData, true) } + test("updateStateByKey - object lifecycle") { + val inputData = + Seq( + Seq("a","b"), + null, + Seq("a","c","a"), + Seq("c"), + null, + null + ) + + val outputData = + Seq( + Seq(("a", 1), ("b", 1)), + Seq(("a", 1), ("b", 1)), + Seq(("a", 3), ("c", 1)), + Seq(("a", 3), ("c", 2)), + Seq(("c", 2)), + Seq() + ) + + val updateStateOperation = (s: DStream[String]) => { + class StateObject(var counter: Int = 0, var expireCounter: Int = 0) extends Serializable + + // updateFunc clears a state when a StateObject is seen without new values twice in a row + val updateFunc = (values: Seq[Int], state: Option[StateObject]) => { + val stateObj = state.getOrElse(new StateObject) + values.foldLeft(0)(_ + _) match { + case 0 => stateObj.expireCounter += 1 // no new values + case n => { // has new values, increment and reset expireCounter + stateObj.counter += n + stateObj.expireCounter = 0 + } + } + stateObj.expireCounter match { + case 2 => None // seen twice with no new values, give it the boot + case _ => Option(stateObj) + } + } + s.map(_ -> 1).updateStateByKey[StateObject](updateFunc).mapValues(_.counter) + } + + 
testOperation(inputData, updateStateOperation, outputData, true) + } + test("forgetting of RDDs - map and window operations") { assert(batchDuration === Seconds(1), "Batch duration has changed from 1 second") diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index a76f61d4ad..11cfcba827 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -28,6 +28,11 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[ logInfo("Computing RDD for time " + validTime) val index = ((validTime - zeroTime) / slideDuration - 1).toInt val selectedInput = if (index < input.size) input(index) else Seq[T]() + + // lets us test cases where RDDs are not created + if (selectedInput == null) + return None + val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) logInfo("Created RDD " + rdd.id + " with " + selectedInput) Some(rdd) From c0694291c81ad775918421941a80a00ca9593a38 Mon Sep 17 00:00:00 2001 From: seanm Date: Sun, 20 Jan 2013 12:09:45 -0700 Subject: [PATCH 127/291] Splitting StreamingContext.queueStream into two methods --- .../spark/streaming/StreamingContext.scala | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 14500bdcb1..3cec35cb37 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -283,17 +283,31 @@ class StreamingContext private ( } /** - * Creates a input stream from an queue of RDDs. In each batch, + * Creates an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval - * @param defaultRDD Default RDD is returned by the DStream when the queue is empty * @tparam T Type of objects in the RDD */ def queueStream[T: ClassManifest]( queue: Queue[RDD[T]], - oneAtATime: Boolean = true, - defaultRDD: RDD[T] = null + oneAtATime: Boolean = true + ): DStream[T] = { + queueStream(queue, oneAtATime, sc.makeRDD(Seq[T](), 1)) + } + + /** + * Creates an input stream from a queue of RDDs. In each batch, + * it will process either one or all of the RDDs returned by the queue. + * @param queue Queue of RDDs + * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval + * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. 
Set as null if no RDD should be returned when empty + * @tparam T Type of objects in the RDD + */ + def queueStream[T: ClassManifest]( + queue: Queue[RDD[T]], + oneAtATime: Boolean, + defaultRDD: RDD[T] ): DStream[T] = { val inputStream = new QueueInputDStream(this, queue, oneAtATime, defaultRDD) registerInputStream(inputStream) From 17035db159e191a11cd86882c97078581073deb2 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 20 Jan 2013 11:22:38 -0800 Subject: [PATCH 128/291] Add __repr__ to Accumulator; fix bug in sc.accumulator --- python/pyspark/accumulators.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index c00c3a37af..8011779ddc 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -11,6 +11,12 @@ >>> a.value 7 +>>> sc.accumulator(1.0).value +1.0 + +>>> sc.accumulator(1j).value +1j + >>> rdd = sc.parallelize([1,2,3]) >>> def f(x): ... global a @@ -124,6 +130,9 @@ class Accumulator(object): def __str__(self): return str(self._value) + def __repr__(self): + return "Accumulator" % (self.aid, self._value) + class AddingAccumulatorParam(object): """ @@ -145,7 +154,7 @@ class AddingAccumulatorParam(object): # Singleton accumulator params for some standard types INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0) -DOUBLE_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0) +FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0) COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) From 7ed1bf4b485131d58ea6728e7247b79320aca9e6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 16 Jan 2013 19:15:14 -0800 Subject: [PATCH 129/291] Add RDD checkpointing to Python API. --- .../scala/spark/api/python/PythonRDD.scala | 3 -- python/epydoc.conf | 2 +- python/pyspark/context.py | 9 ++++ python/pyspark/rdd.py | 34 ++++++++++++++ python/pyspark/tests.py | 46 +++++++++++++++++++ python/run-tests | 3 ++ 6 files changed, 93 insertions(+), 4 deletions(-) create mode 100644 python/pyspark/tests.py diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 89f7c316dc..8c38262dd8 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -135,8 +135,6 @@ private[spark] class PythonRDD[T: ClassManifest]( } } - override def checkpoint() { } - val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } @@ -152,7 +150,6 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends case Seq(a, b) => (a, b) case x => throw new Exception("PairwiseRDD: unexpected value: " + x) } - override def checkpoint() { } val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) } diff --git a/python/epydoc.conf b/python/epydoc.conf index 91ac984ba2..45102cd9fe 100644 --- a/python/epydoc.conf +++ b/python/epydoc.conf @@ -16,4 +16,4 @@ target: docs/ private: no exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers - pyspark.java_gateway pyspark.examples pyspark.shell + pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 1e2f845f9c..a438b43fdc 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -195,3 +195,12 @@ class SparkContext(object): filename = path.split("/")[-1] os.environ["PYTHONPATH"] = \ "%s:%s" % (filename, os.environ["PYTHONPATH"]) + + def setCheckpointDir(self, dirName, useExisting=False): + """ + Set 
the directory under which RDDs are going to be checkpointed. This + method will create this directory and will throw an exception of the + path already exists (to avoid overwriting existing files may be + overwritten). The directory will be deleted on exit if indicated. + """ + self._jsc.sc().setCheckpointDir(dirName, useExisting) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d705f0f9e1..9b676cae4a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -49,6 +49,40 @@ class RDD(object): self._jrdd.cache() return self + def checkpoint(self): + """ + Mark this RDD for checkpointing. The RDD will be saved to a file inside + `checkpointDir` (set using setCheckpointDir()) and all references to + its parent RDDs will be removed. This is used to truncate very long + lineages. In the current implementation, Spark will save this RDD to + a file (using saveAsObjectFile()) after the first job using this RDD is + done. Hence, it is strongly recommended to use checkpoint() on RDDs + when + + (i) checkpoint() is called before the any job has been executed on this + RDD. + + (ii) This RDD has been made to persist in memory. Otherwise saving it + on a file will require recomputation. + """ + self._jrdd.rdd().checkpoint() + + def isCheckpointed(self): + """ + Return whether this RDD has been checkpointed or not + """ + return self._jrdd.rdd().isCheckpointed() + + def getCheckpointFile(self): + """ + Gets the name of the file to which this RDD was checkpointed + """ + checkpointFile = self._jrdd.rdd().getCheckpointFile() + if checkpointFile.isDefined(): + return checkpointFile.get() + else: + return None + # TODO persist(self, storageLevel) def map(self, f, preservesPartitioning=False): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py new file mode 100644 index 0000000000..c959d5dec7 --- /dev/null +++ b/python/pyspark/tests.py @@ -0,0 +1,46 @@ +""" +Unit tests for PySpark; additional tests are implemented as doctests in +individual modules. +""" +import atexit +import os +import shutil +from tempfile import NamedTemporaryFile +import time +import unittest + +from pyspark.context import SparkContext + + +class TestCheckpoint(unittest.TestCase): + + def setUp(self): + self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2) + + def tearDown(self): + self.sc.stop() + + def test_basic_checkpointing(self): + checkpointDir = NamedTemporaryFile(delete=False) + os.unlink(checkpointDir.name) + self.sc.setCheckpointDir(checkpointDir.name) + + parCollection = self.sc.parallelize([1, 2, 3, 4]) + flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) + + self.assertFalse(flatMappedRDD.isCheckpointed()) + self.assertIsNone(flatMappedRDD.getCheckpointFile()) + + flatMappedRDD.checkpoint() + result = flatMappedRDD.collect() + time.sleep(1) # 1 second + self.assertTrue(flatMappedRDD.isCheckpointed()) + self.assertEqual(flatMappedRDD.collect(), result) + self.assertEqual(checkpointDir.name, + os.path.dirname(flatMappedRDD.getCheckpointFile())) + + atexit.register(lambda: shutil.rmtree(checkpointDir.name)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/run-tests b/python/run-tests index 32470911f9..ce214e98a8 100755 --- a/python/run-tests +++ b/python/run-tests @@ -14,6 +14,9 @@ FAILED=$(($?||$FAILED)) $FWDIR/pyspark -m doctest pyspark/accumulators.py FAILED=$(($?||$FAILED)) +$FWDIR/pyspark -m unittest pyspark.tests +FAILED=$(($?||$FAILED)) + if [[ $FAILED != 0 ]]; then echo -en "\033[31m" # Red echo "Had test failures; see logs." 
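This patch gives PySpark the same checkpointing surface as the Scala API: SparkContext.setCheckpointDir() plus RDD.checkpoint(), isCheckpointed() and getCheckpointFile(). The condensed sketch below mirrors the flow exercised by the new pyspark/tests.py above; it is illustrative only -- the temporary-directory handling is simplified, and the one-second sleep copies the test's wait for the checkpoint write rather than any documented guarantee.

    import os, time
    from tempfile import NamedTemporaryFile
    from pyspark.context import SparkContext

    sc = SparkContext("local[4]", "checkpoint-sketch", batchSize=2)
    checkpointDir = NamedTemporaryFile(delete=False)
    os.unlink(checkpointDir.name)        # hand setCheckpointDir a path that does not exist yet
    sc.setCheckpointDir(checkpointDir.name)

    rdd = sc.parallelize([1, 2, 3, 4]).flatMap(lambda x: range(1, x + 1))
    rdd.checkpoint()                     # mark before any job has run on this RDD
    print rdd.isCheckpointed()           # False: nothing has been materialized yet
    result = rdd.collect()               # the first job also writes the checkpoint
    time.sleep(1)                        # as in the test, give the write a moment to finish
    print rdd.isCheckpointed()           # True
    print rdd.getCheckpointFile()        # a file under checkpointDir.name

The next patch in the series adds SparkContext._checkpointFile() so an RDD written this way can be read back and compared against the original, which is what the new test_checkpoint_and_restore test does.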
From d0ba80dc727d00b2b7627dcefd2c77009af55f7d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 20 Jan 2013 13:59:45 -0800 Subject: [PATCH 130/291] Add checkpointFile() and more tests to PySpark. --- python/pyspark/context.py | 6 +++++- python/pyspark/rdd.py | 9 ++++++++- python/pyspark/tests.py | 24 ++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 2 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a438b43fdc..8beb8e2ae9 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -123,6 +123,10 @@ class SparkContext(object): jrdd = self._jsc.textFile(name, minSplits) return RDD(jrdd, self) + def _checkpointFile(self, name): + jrdd = self._jsc.checkpointFile(name) + return RDD(jrdd, self) + def union(self, rdds): """ Build the union of a list of RDDs. @@ -145,7 +149,7 @@ class SparkContext(object): def accumulator(self, value, accum_param=None): """ Create an C{Accumulator} with the given initial value, using a given - AccumulatorParam helper object to define how to add values of the data + AccumulatorParam helper object to define how to add values of the data type if provided. Default AccumulatorParams are used for integers and floating-point numbers if you do not provide one. For other types, the AccumulatorParam must implement two methods: diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 9b676cae4a..2a2ff9b271 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -32,6 +32,7 @@ class RDD(object): def __init__(self, jrdd, ctx): self._jrdd = jrdd self.is_cached = False + self.is_checkpointed = False self.ctx = ctx @property @@ -65,6 +66,7 @@ class RDD(object): (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will require recomputation. 
""" + self.is_checkpointed = True self._jrdd.rdd().checkpoint() def isCheckpointed(self): @@ -696,7 +698,7 @@ class PipelinedRDD(RDD): 20 """ def __init__(self, prev, func, preservesPartitioning=False): - if isinstance(prev, PipelinedRDD) and not prev.is_cached: + if isinstance(prev, PipelinedRDD) and prev._is_pipelinable: prev_func = prev.func def pipeline_func(split, iterator): return func(split, prev_func(split, iterator)) @@ -709,6 +711,7 @@ class PipelinedRDD(RDD): self.preservesPartitioning = preservesPartitioning self._prev_jrdd = prev._jrdd self.is_cached = False + self.is_checkpointed = False self.ctx = prev.ctx self.prev = prev self._jrdd_val = None @@ -741,6 +744,10 @@ class PipelinedRDD(RDD): self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val + @property + def _is_pipelinable(self): + return not (self.is_cached or self.is_checkpointed) + def _test(): import doctest diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index c959d5dec7..83283fca4f 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -19,6 +19,9 @@ class TestCheckpoint(unittest.TestCase): def tearDown(self): self.sc.stop() + # To avoid Akka rebinding to the same port, since it doesn't unbind + # immediately on shutdown + self.sc.jvm.System.clearProperty("spark.master.port") def test_basic_checkpointing(self): checkpointDir = NamedTemporaryFile(delete=False) @@ -41,6 +44,27 @@ class TestCheckpoint(unittest.TestCase): atexit.register(lambda: shutil.rmtree(checkpointDir.name)) + def test_checkpoint_and_restore(self): + checkpointDir = NamedTemporaryFile(delete=False) + os.unlink(checkpointDir.name) + self.sc.setCheckpointDir(checkpointDir.name) + + parCollection = self.sc.parallelize([1, 2, 3, 4]) + flatMappedRDD = parCollection.flatMap(lambda x: [x]) + + self.assertFalse(flatMappedRDD.isCheckpointed()) + self.assertIsNone(flatMappedRDD.getCheckpointFile()) + + flatMappedRDD.checkpoint() + flatMappedRDD.count() # forces a checkpoint to be computed + time.sleep(1) # 1 second + + self.assertIsNotNone(flatMappedRDD.getCheckpointFile()) + recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile()) + self.assertEquals([1, 2, 3, 4], recovered.collect()) + + atexit.register(lambda: shutil.rmtree(checkpointDir.name)) + if __name__ == "__main__": unittest.main() From 5b6ea9e9a04994553d0319c541ca356e2e3064a7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 20 Jan 2013 15:31:41 -0800 Subject: [PATCH 131/291] Update checkpointing API docs in Python/Java. --- .../main/scala/spark/api/java/JavaRDDLike.scala | 15 ++++++--------- .../scala/spark/api/java/JavaSparkContext.scala | 17 +++++++++-------- python/pyspark/context.py | 11 +++++++---- python/pyspark/rdd.py | 17 +++++------------ 4 files changed, 27 insertions(+), 33 deletions(-) diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 087270e46d..b3698ffa44 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -307,16 +307,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] JavaPairRDD.fromRDD(rdd.keyBy(f)) } - + /** - * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir` - * (set using setCheckpointDir()) and all references to its parent RDDs will be removed. - * This is used to truncate very long lineages. 
In the current implementation, Spark will save - * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done. - * Hence, it is strongly recommended to use checkpoint() on RDDs when - * (i) checkpoint() is called before the any job has been executed on this RDD. - * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will - * require recomputation. + * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint + * directory set with SparkContext.setCheckpointDir() and all references to its parent + * RDDs will be removed. This function must be called before any job has been + * executed on this RDD. It is strongly recommended that this RDD is persisted in + * memory, otherwise saving it on a file will require recomputation. */ def checkpoint() = rdd.checkpoint() diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index fa2f14113d..14699961ad 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -357,20 +357,21 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork } /** - * Set the directory under which RDDs are going to be checkpointed. This method will - * create this directory and will throw an exception of the path already exists (to avoid - * overwriting existing files may be overwritten). The directory will be deleted on exit - * if indicated. + * Set the directory under which RDDs are going to be checkpointed. The directory must + * be a HDFS path if running on a cluster. If the directory does not exist, it will + * be created. If the directory exists and useExisting is set to true, then the + * exisiting directory will be used. Otherwise an exception will be thrown to + * prevent accidental overriding of checkpoint files in the existing directory. */ def setCheckpointDir(dir: String, useExisting: Boolean) { sc.setCheckpointDir(dir, useExisting) } /** - * Set the directory under which RDDs are going to be checkpointed. This method will - * create this directory and will throw an exception of the path already exists (to avoid - * overwriting existing files may be overwritten). The directory will be deleted on exit - * if indicated. + * Set the directory under which RDDs are going to be checkpointed. The directory must + * be a HDFS path if running on a cluster. If the directory does not exist, it will + * be created. If the directory exists, an exception will be thrown to prevent accidental + * overriding of checkpoint files. */ def setCheckpointDir(dir: String) { sc.setCheckpointDir(dir) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 8beb8e2ae9..dcbed37270 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -202,9 +202,12 @@ class SparkContext(object): def setCheckpointDir(self, dirName, useExisting=False): """ - Set the directory under which RDDs are going to be checkpointed. This - method will create this directory and will throw an exception of the - path already exists (to avoid overwriting existing files may be - overwritten). The directory will be deleted on exit if indicated. + Set the directory under which RDDs are going to be checkpointed. The + directory must be a HDFS path if running on a cluster. + + If the directory does not exist, it will be created. If the directory + exists and C{useExisting} is set to true, then the exisiting directory + will be used. 
Otherwise an exception will be thrown to prevent + accidental overriding of checkpoint files in the existing directory. """ self._jsc.sc().setCheckpointDir(dirName, useExisting) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 2a2ff9b271..7b6ab956ee 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -52,18 +52,11 @@ class RDD(object): def checkpoint(self): """ - Mark this RDD for checkpointing. The RDD will be saved to a file inside - `checkpointDir` (set using setCheckpointDir()) and all references to - its parent RDDs will be removed. This is used to truncate very long - lineages. In the current implementation, Spark will save this RDD to - a file (using saveAsObjectFile()) after the first job using this RDD is - done. Hence, it is strongly recommended to use checkpoint() on RDDs - when - - (i) checkpoint() is called before the any job has been executed on this - RDD. - - (ii) This RDD has been made to persist in memory. Otherwise saving it + Mark this RDD for checkpointing. It will be saved to a file inside the + checkpoint directory set with L{SparkContext.setCheckpointDir()} and + all references to its parent RDDs will be removed. This function must + be called before any job has been executed on this RDD. It is strongly + recommended that this RDD is persisted in memory, otherwise saving it on a file will require recomputation. """ self.is_checkpointed = True From 00d70cd6602d5ff2718e319ec04defbdd486237e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 20 Jan 2013 15:38:11 -0800 Subject: [PATCH 132/291] Clean up setup code in PySpark checkpointing tests --- python/pyspark/rdd.py | 3 +-- python/pyspark/tests.py | 19 +++++-------------- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7b6ab956ee..097cdb13b4 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -691,7 +691,7 @@ class PipelinedRDD(RDD): 20 """ def __init__(self, prev, func, preservesPartitioning=False): - if isinstance(prev, PipelinedRDD) and prev._is_pipelinable: + if isinstance(prev, PipelinedRDD) and prev._is_pipelinable(): prev_func = prev.func def pipeline_func(split, iterator): return func(split, prev_func(split, iterator)) @@ -737,7 +737,6 @@ class PipelinedRDD(RDD): self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val - @property def _is_pipelinable(self): return not (self.is_cached or self.is_checkpointed) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 83283fca4f..b0a403b580 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -2,7 +2,6 @@ Unit tests for PySpark; additional tests are implemented as doctests in individual modules. 
""" -import atexit import os import shutil from tempfile import NamedTemporaryFile @@ -16,18 +15,18 @@ class TestCheckpoint(unittest.TestCase): def setUp(self): self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2) + self.checkpointDir = NamedTemporaryFile(delete=False) + os.unlink(self.checkpointDir.name) + self.sc.setCheckpointDir(self.checkpointDir.name) def tearDown(self): self.sc.stop() # To avoid Akka rebinding to the same port, since it doesn't unbind # immediately on shutdown self.sc.jvm.System.clearProperty("spark.master.port") + shutil.rmtree(self.checkpointDir.name) def test_basic_checkpointing(self): - checkpointDir = NamedTemporaryFile(delete=False) - os.unlink(checkpointDir.name) - self.sc.setCheckpointDir(checkpointDir.name) - parCollection = self.sc.parallelize([1, 2, 3, 4]) flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1)) @@ -39,16 +38,10 @@ class TestCheckpoint(unittest.TestCase): time.sleep(1) # 1 second self.assertTrue(flatMappedRDD.isCheckpointed()) self.assertEqual(flatMappedRDD.collect(), result) - self.assertEqual(checkpointDir.name, + self.assertEqual(self.checkpointDir.name, os.path.dirname(flatMappedRDD.getCheckpointFile())) - atexit.register(lambda: shutil.rmtree(checkpointDir.name)) - def test_checkpoint_and_restore(self): - checkpointDir = NamedTemporaryFile(delete=False) - os.unlink(checkpointDir.name) - self.sc.setCheckpointDir(checkpointDir.name) - parCollection = self.sc.parallelize([1, 2, 3, 4]) flatMappedRDD = parCollection.flatMap(lambda x: [x]) @@ -63,8 +56,6 @@ class TestCheckpoint(unittest.TestCase): recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile()) self.assertEquals([1, 2, 3, 4], recovered.collect()) - atexit.register(lambda: shutil.rmtree(checkpointDir.name)) - if __name__ == "__main__": unittest.main() From 9f211dd3f0132daf72fb39883fa4b28e4fd547ca Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 14 Jan 2013 15:30:42 -0800 Subject: [PATCH 133/291] Fix PythonPartitioner equality; see SPARK-654. PythonPartitioner did not take the Python-side partitioning function into account when checking for equality, which might cause problems in the future. --- .../spark/api/python/PythonPartitioner.scala | 13 +++++++++++-- .../main/scala/spark/api/python/PythonRDD.scala | 5 ----- python/pyspark/rdd.py | 17 +++++++++++------ 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala index 648d9402b0..519e310323 100644 --- a/core/src/main/scala/spark/api/python/PythonPartitioner.scala +++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala @@ -6,8 +6,17 @@ import java.util.Arrays /** * A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API. + * + * Stores the unique id() of the Python-side partitioning function so that it is incorporated into + * equality comparisons. Correctness requires that the id is a unique identifier for the + * lifetime of the job (i.e. that it is not re-used as the id of a different partitioning + * function). This can be ensured by using the Python id() function and maintaining a reference + * to the Python partitioning function so that its id() is not reused. 
*/ -private[spark] class PythonPartitioner(override val numPartitions: Int) extends Partitioner { +private[spark] class PythonPartitioner( + override val numPartitions: Int, + val pyPartitionFunctionId: Long) + extends Partitioner { override def getPartition(key: Any): Int = { if (key == null) { @@ -32,7 +41,7 @@ private[spark] class PythonPartitioner(override val numPartitions: Int) extends override def equals(other: Any): Boolean = other match { case h: PythonPartitioner => - h.numPartitions == numPartitions + h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId case _ => false } diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 89f7c316dc..e4c0530241 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -252,11 +252,6 @@ private object Pickle { val APPENDS: Byte = 'e' } -private class ExtractValue extends spark.api.java.function.Function[(Array[Byte], - Array[Byte]), Array[Byte]] { - override def call(pair: (Array[Byte], Array[Byte])) : Array[Byte] = pair._2 -} - private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] { override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d705f0f9e1..b58bf24e3e 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -33,6 +33,7 @@ class RDD(object): self._jrdd = jrdd self.is_cached = False self.ctx = ctx + self._partitionFunc = None @property def context(self): @@ -497,7 +498,7 @@ class RDD(object): return python_right_outer_join(self, other, numSplits) # TODO: add option to control map-side combining - def partitionBy(self, numSplits, hashFunc=hash): + def partitionBy(self, numSplits, partitionFunc=hash): """ Return a copy of the RDD partitioned using the specified partitioner. 
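        A short sketch of the new C{partitionFunc} keyword (assumes an existing
        SparkContext C{sc}); any function from key to int can be supplied, and
        the returned RDD keeps a reference to it so that C{id(partitionFunc)}
        remains a valid identifier for the partitioner's lifetime:

            pairs = sc.parallelize([("a", 1), ("b", 2), ("a", 3)])
            parted = pairs.partitionBy(2, partitionFunc=lambda k: ord(k[0]))
            parted.collect()   # same elements, now hash-partitioned into 2 splits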
@@ -514,17 +515,21 @@ class RDD(object): def add_shuffle_key(split, iterator): buckets = defaultdict(list) for (k, v) in iterator: - buckets[hashFunc(k) % numSplits].append((k, v)) + buckets[partitionFunc(k) % numSplits].append((k, v)) for (split, items) in buckets.iteritems(): yield str(split) yield dump_pickle(Batch(items)) keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() - partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits) - jrdd = pairRDD.partitionBy(partitioner) - jrdd = jrdd.map(self.ctx.jvm.ExtractValue()) - return RDD(jrdd, self.ctx) + partitioner = self.ctx.jvm.PythonPartitioner(numSplits, + id(partitionFunc)) + jrdd = pairRDD.partitionBy(partitioner).values() + rdd = RDD(jrdd, self.ctx) + # This is required so that id(partitionFunc) remains unique, even if + # partitionFunc is a lambda: + rdd._partitionFunc = partitionFunc + return rdd # TODO: add control over map-side aggregation def combineByKey(self, createCombiner, mergeValue, mergeCombiners, From 6e3754bf4759ab3e1e1be978b6b84e6f17742106 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 20 Jan 2013 19:22:24 -0800 Subject: [PATCH 134/291] Add Maven build file for streaming, and fix some issues in SBT file As part of this, changed our Scala 2.9.2 Kafka library to be available as a local Maven repository, following the example in (http://blog.dub.podval.org/2010/01/maven-in-project-repository.html) --- examples/pom.xml | 17 ++ pom.xml | 12 ++ project/SparkBuild.scala | 16 +- repl/pom.xml | 14 ++ .../kafka/0.7.2-spark/kafka-0.7.2-spark.jar} | Bin .../0.7.2-spark/kafka-0.7.2-spark.jar.md5 | 1 + .../0.7.2-spark/kafka-0.7.2-spark.jar.sha1 | 1 + .../kafka/0.7.2-spark/kafka-0.7.2-spark.pom | 9 + .../0.7.2-spark/kafka-0.7.2-spark.pom.md5 | 1 + .../0.7.2-spark/kafka-0.7.2-spark.pom.sha1 | 1 + .../kafka/kafka/maven-metadata-local.xml | 12 ++ .../kafka/kafka/maven-metadata-local.xml.md5 | 1 + .../kafka/kafka/maven-metadata-local.xml.sha1 | 1 + streaming/pom.xml | 155 ++++++++++++++++++ 14 files changed, 234 insertions(+), 7 deletions(-) rename streaming/lib/{kafka-0.7.2.jar => org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar} (100%) create mode 100644 streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5 create mode 100644 streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1 create mode 100644 streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom create mode 100644 streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5 create mode 100644 streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1 create mode 100644 streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml create mode 100644 streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5 create mode 100644 streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1 create mode 100644 streaming/pom.xml diff --git a/examples/pom.xml b/examples/pom.xml index 3355deb6b7..4d43103475 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -19,6 +19,11 @@ org.eclipse.jetty jetty-server + + org.twitter4j + twitter4j-stream + 3.0.3 + org.scalatest @@ -57,6 +62,12 @@ ${project.version} hadoop1 + + org.spark-project + spark-streaming + ${project.version} + hadoop1 + org.apache.hadoop hadoop-core @@ -90,6 +101,12 @@ ${project.version} hadoop2 + + org.spark-project + spark-streaming + ${project.version} + hadoop2 + org.apache.hadoop hadoop-core 
diff --git a/pom.xml b/pom.xml index 751189a9d8..483b0f9595 100644 --- a/pom.xml +++ b/pom.xml @@ -41,6 +41,7 @@ core bagel examples + streaming repl repl-bin @@ -104,6 +105,17 @@ false + + twitter4j-repo + Twitter4J Repository + http://twitter4j.org/maven2/ + + true + + + false + + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 3dbb993f9c..03b8094f7d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -21,7 +21,7 @@ object SparkBuild extends Build { lazy val core = Project("core", file("core"), settings = coreSettings) - lazy val repl = Project("repl", file("repl"), settings = replSettings) dependsOn (core) + lazy val repl = Project("repl", file("repl"), settings = replSettings) dependsOn (core) dependsOn (streaming) lazy val examples = Project("examples", file("examples"), settings = examplesSettings) dependsOn (core) dependsOn (streaming) @@ -92,8 +92,7 @@ object SparkBuild extends Build { "org.eclipse.jetty" % "jetty-server" % "7.5.3.v20111011", "org.scalatest" %% "scalatest" % "1.8" % "test", "org.scalacheck" %% "scalacheck" % "1.9" % "test", - "com.novocode" % "junit-interface" % "0.8" % "test", - "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" + "com.novocode" % "junit-interface" % "0.8" % "test" ), parallelExecution := false, /* Workaround for issue #206 (fixed after SBT 0.11.0) */ @@ -136,8 +135,6 @@ object SparkBuild extends Build { "com.typesafe.akka" % "akka-slf4j" % "2.0.3", "it.unimi.dsi" % "fastutil" % "6.4.4", "colt" % "colt" % "1.2.0", - "org.twitter4j" % "twitter4j-core" % "3.0.2", - "org.twitter4j" % "twitter4j-stream" % "3.0.2", "cc.spray" % "spray-can" % "1.0-M2.1", "cc.spray" % "spray-server" % "1.0-M2.1", "cc.spray" %% "spray-json" % "1.1.1", @@ -156,7 +153,10 @@ object SparkBuild extends Build { ) def examplesSettings = sharedSettings ++ Seq( - name := "spark-examples" + name := "spark-examples", + libraryDependencies ++= Seq( + "org.twitter4j" % "twitter4j-stream" % "3.0.3" + ) ) def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel") @@ -164,7 +164,9 @@ object SparkBuild extends Build { def streamingSettings = sharedSettings ++ Seq( name := "spark-streaming", libraryDependencies ++= Seq( - "com.github.sgroschupf" % "zkclient" % "0.1") + "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile", + "com.github.sgroschupf" % "zkclient" % "0.1" + ) ) ++ assemblySettings ++ extraAssemblySettings def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq( diff --git a/repl/pom.xml b/repl/pom.xml index 38e883c7f8..2fc9692969 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -101,6 +101,13 @@ hadoop1 runtime + + org.spark-project + spark-streaming + ${project.version} + hadoop1 + runtime + org.apache.hadoop hadoop-core @@ -151,6 +158,13 @@ hadoop2 runtime + + org.spark-project + spark-streaming + ${project.version} + hadoop2 + runtime + org.apache.hadoop hadoop-core diff --git a/streaming/lib/kafka-0.7.2.jar b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar similarity index 100% rename from streaming/lib/kafka-0.7.2.jar rename to streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5 new file mode 100644 index 0000000000..29f45f4adb --- /dev/null +++ b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5 @@ -0,0 +1 @@ +18876b8bc2e4cef28b6d191aa49d963f \ No newline at end of file diff 
--git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1 new file mode 100644 index 0000000000..e3bd62bac0 --- /dev/null +++ b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1 @@ -0,0 +1 @@ +06b27270ffa52250a2c08703b397c99127b72060 \ No newline at end of file diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom new file mode 100644 index 0000000000..082d35726a --- /dev/null +++ b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom @@ -0,0 +1,9 @@ + + + 4.0.0 + org.apache.kafka + kafka + 0.7.2-spark + POM was created from install:install-file + diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5 new file mode 100644 index 0000000000..92c4132b5b --- /dev/null +++ b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5 @@ -0,0 +1 @@ +7bc4322266e6032bdf9ef6eebdd8097d \ No newline at end of file diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1 new file mode 100644 index 0000000000..8a1d8a097a --- /dev/null +++ b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1 @@ -0,0 +1 @@ +d0f79e8eff0db43ca7bcf7dce2c8cd2972685c9d \ No newline at end of file diff --git a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml new file mode 100644 index 0000000000..720cd51c2f --- /dev/null +++ b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml @@ -0,0 +1,12 @@ + + + org.apache.kafka + kafka + + 0.7.2-spark + + 0.7.2-spark + + 20130121015225 + + diff --git a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5 b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5 new file mode 100644 index 0000000000..a4ce5dc9e8 --- /dev/null +++ b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5 @@ -0,0 +1 @@ +e2b9c7c5f6370dd1d21a0aae5e8dcd77 \ No newline at end of file diff --git a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1 b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1 new file mode 100644 index 0000000000..b869eaf2a6 --- /dev/null +++ b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1 @@ -0,0 +1 @@ +2a4341da936b6c07a09383d17ffb185ac558ee91 \ No newline at end of file diff --git a/streaming/pom.xml b/streaming/pom.xml new file mode 100644 index 0000000000..3dae815e1a --- /dev/null +++ b/streaming/pom.xml @@ -0,0 +1,155 @@ + + + 4.0.0 + + org.spark-project + parent + 0.7.0-SNAPSHOT + ../pom.xml + + + org.spark-project + spark-streaming + jar + Spark Project Streaming + http://spark-project.org/ + + + + + lib + file://${project.basedir}/lib + + + + + + org.eclipse.jetty + jetty-server + + + org.codehaus.jackson + jackson-mapper-asl + 1.9.11 + + + org.apache.kafka + kafka + 0.7.2-spark + + + org.apache.flume + flume-ng-sdk + 1.2.0 + + + com.github.sgroschupf + zkclient + 0.1 + + + + org.scalatest + scalatest_${scala.version} + test + + + org.scalacheck + scalacheck_${scala.version} + test + + + com.novocode + junit-interface + test + + + org.slf4j + slf4j-log4j12 + test + + + + target/scala-${scala.version}/classes + 
target/scala-${scala.version}/test-classes + + + org.scalatest + scalatest-maven-plugin + + + + + + + hadoop1 + + + !hadoopVersion + + + + + org.spark-project + spark-core + ${project.version} + hadoop1 + + + org.apache.hadoop + hadoop-core + provided + + + + + + org.apache.maven.plugins + maven-jar-plugin + + hadoop1 + + + + + + + hadoop2 + + + hadoopVersion + 2 + + + + + org.spark-project + spark-core + ${project.version} + hadoop2 + + + org.apache.hadoop + hadoop-core + provided + + + org.apache.hadoop + hadoop-client + provided + + + + + + org.apache.maven.plugins + maven-jar-plugin + + hadoop2 + + + + + + + From 4750907c3dad4c275d0f51937a098ba856098b96 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 20 Jan 2013 21:05:17 -0800 Subject: [PATCH 135/291] Update run script to deal with change to build of REPL shaded JAR --- run | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/run b/run index 060856007f..a094629449 100755 --- a/run +++ b/run @@ -89,9 +89,11 @@ if [ -e "$FWDIR/lib_managed" ]; then CLASSPATH+=":$FWDIR/lib_managed/bundles/*" fi CLASSPATH+=":$REPL_DIR/lib/*" -for jar in `find "$REPL_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do - CLASSPATH+=":$jar" -done +if [ -e repl-bin/target ]; then + for jar in `find "repl-bin/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do + CLASSPATH+=":$jar" + done +fi CLASSPATH+=":$BAGEL_DIR/target/scala-$SCALA_VERSION/classes" for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do CLASSPATH+=":$jar" From c0b9ceb8c3d56c6d6f6f6b5925c87abad06be646 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 21 Jan 2013 00:23:53 -0800 Subject: [PATCH 136/291] Log remote lifecycle events in Akka for easier debugging --- core/src/main/scala/spark/util/AkkaUtils.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index e67cb0336d..fbd0ff46bf 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -32,6 +32,7 @@ private[spark] object AkkaUtils { akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"] akka.actor.provider = "akka.remote.RemoteActorRefProvider" akka.remote.transport = "akka.remote.netty.NettyRemoteTransport" + akka.remote.log-remote-lifecycle-events = on akka.remote.netty.hostname = "%s" akka.remote.netty.port = %d akka.remote.netty.connection-timeout = %ds From 69a417858bf1627de5220d41afba64853d4bf64d Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Mon, 21 Jan 2013 12:42:11 -0600 Subject: [PATCH 137/291] Also use hadoopConfiguration in newAPI methods. 
--- core/src/main/scala/spark/PairRDDFunctions.scala | 4 ++-- core/src/main/scala/spark/SparkContext.scala | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 51c15837c4..1c18736805 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -494,7 +494,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) { - saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, new Configuration) + saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass) } /** @@ -506,7 +506,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: NewOutputFormat[_, _]], - conf: Configuration) { + conf: Configuration = self.context.hadoopConfiguration) { val job = new NewAPIHadoopJob(conf) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index f6b98c41bc..303e5081a4 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -293,8 +293,7 @@ class SparkContext( path, fm.erasure.asInstanceOf[Class[F]], km.erasure.asInstanceOf[Class[K]], - vm.erasure.asInstanceOf[Class[V]], - new Configuration(hadoopConfiguration)) + vm.erasure.asInstanceOf[Class[V]]) } /** @@ -306,7 +305,7 @@ class SparkContext( fClass: Class[F], kClass: Class[K], vClass: Class[V], - conf: Configuration): RDD[(K, V)] = { + conf: Configuration = hadoopConfiguration): RDD[(K, V)] = { val job = new NewHadoopJob(conf) NewFileInputFormat.addInputPath(job, new Path(path)) val updatedConf = job.getConfiguration @@ -318,7 +317,7 @@ class SparkContext( * and extra configuration options to pass to the input format. 
*/ def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]]( - conf: Configuration, + conf: Configuration = hadoopConfiguration, fClass: Class[F], kClass: Class[K], vClass: Class[V]): RDD[(K, V)] = { From f116d6b5c6029c2f96160bd84829a6fe8b73cccf Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 18 Jan 2013 13:24:37 -0800 Subject: [PATCH 138/291] executor can use a different sparkHome from Worker --- core/src/main/scala/spark/deploy/DeployMessage.scala | 4 +++- core/src/main/scala/spark/deploy/JobDescription.scala | 5 ++++- core/src/main/scala/spark/deploy/client/TestClient.scala | 3 ++- core/src/main/scala/spark/deploy/master/Master.scala | 9 +++++---- core/src/main/scala/spark/deploy/worker/Worker.scala | 4 ++-- .../scheduler/cluster/SparkDeploySchedulerBackend.scala | 3 ++- 6 files changed, 18 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala index 457122745b..7ee3e63429 100644 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/spark/deploy/DeployMessage.scala @@ -5,6 +5,7 @@ import spark.deploy.master.{WorkerInfo, JobInfo} import spark.deploy.worker.ExecutorRunner import scala.collection.immutable.List import scala.collection.mutable.HashMap +import java.io.File private[spark] sealed trait DeployMessage extends Serializable @@ -42,7 +43,8 @@ private[spark] case class LaunchExecutor( execId: Int, jobDesc: JobDescription, cores: Int, - memory: Int) + memory: Int, + sparkHome: File) extends DeployMessage diff --git a/core/src/main/scala/spark/deploy/JobDescription.scala b/core/src/main/scala/spark/deploy/JobDescription.scala index 20879c5f11..7f8f9af417 100644 --- a/core/src/main/scala/spark/deploy/JobDescription.scala +++ b/core/src/main/scala/spark/deploy/JobDescription.scala @@ -1,10 +1,13 @@ package spark.deploy +import java.io.File + private[spark] class JobDescription( val name: String, val cores: Int, val memoryPerSlave: Int, - val command: Command) + val command: Command, + val sparkHome: File) extends Serializable { val user = System.getProperty("user.name", "") diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala index 57a7e123b7..dc743b1fbf 100644 --- a/core/src/main/scala/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/spark/deploy/client/TestClient.scala @@ -3,6 +3,7 @@ package spark.deploy.client import spark.util.AkkaUtils import spark.{Logging, Utils} import spark.deploy.{Command, JobDescription} +import java.io.File private[spark] object TestClient { @@ -25,7 +26,7 @@ private[spark] object TestClient { val url = args(0) val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0) val desc = new JobDescription( - "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map())) + "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), new File("dummy-spark-home")) val listener = new TestListener val client = new Client(actorSystem, url, desc, listener) client.start() diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 6ecebe626a..f0bee67159 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -6,6 +6,7 @@ import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, Remote import java.text.SimpleDateFormat import 
java.util.Date +import java.io.File import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -173,7 +174,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor for (pos <- 0 until numUsable) { if (assigned(pos) > 0) { val exec = job.addExecutor(usableWorkers(pos), assigned(pos)) - launchExecutor(usableWorkers(pos), exec) + launchExecutor(usableWorkers(pos), exec, job.desc.sparkHome) job.state = JobState.RUNNING } } @@ -186,7 +187,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor val coresToUse = math.min(worker.coresFree, job.coresLeft) if (coresToUse > 0) { val exec = job.addExecutor(worker, coresToUse) - launchExecutor(worker, exec) + launchExecutor(worker, exec, job.desc.sparkHome) job.state = JobState.RUNNING } } @@ -195,10 +196,10 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor } } - def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo) { + def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: File) { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory) + worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome) exec.job.actor ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) } diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 7c9e588ea2..078b2d8037 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -119,10 +119,10 @@ private[spark] class Worker( logError("Worker registration failed: " + message) System.exit(1) - case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_) => + case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_, execSparkHome_) => logInfo("Asked to launch executor %s/%d for %s".format(jobId, execId, jobDesc.name)) val manager = new ExecutorRunner( - jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, sparkHome, workDir) + jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, execSparkHome_, workDir) executors(jobId + "/" + execId) = manager manager.start() coresUsed += cores_ diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index e2301347e5..0dcc2efaca 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -4,6 +4,7 @@ import spark.{Utils, Logging, SparkContext} import spark.deploy.client.{Client, ClientListener} import spark.deploy.{Command, JobDescription} import scala.collection.mutable.HashMap +import java.io.File private[spark] class SparkDeploySchedulerBackend( scheduler: ClusterScheduler, @@ -39,7 +40,7 @@ private[spark] class SparkDeploySchedulerBackend( StandaloneSchedulerBackend.ACTOR_NAME) val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) - val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command) + val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sc.sparkHome)) client = new Client(sc.env.actorSystem, master, jobDesc, this) client.start() From 
aae5a920a4db0c31918a65a03ce7d2087826fd65 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 18 Jan 2013 13:28:50 -0800 Subject: [PATCH 139/291] get sparkHome the correct way --- .../spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 0dcc2efaca..08b9d6ff47 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -40,7 +40,7 @@ private[spark] class SparkDeploySchedulerBackend( StandaloneSchedulerBackend.ACTOR_NAME) val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) - val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sc.sparkHome)) + val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sc.getSparkHome())) client = new Client(sc.env.actorSystem, master, jobDesc, this) client.start() From 5bf73df7f08b17719711a5f05f0b3390b4951272 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Sat, 19 Jan 2013 13:26:15 -0800 Subject: [PATCH 140/291] oops, fix stupid compile error --- .../spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 08b9d6ff47..94886d3941 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -40,7 +40,8 @@ private[spark] class SparkDeploySchedulerBackend( StandaloneSchedulerBackend.ACTOR_NAME) val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) - val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sc.getSparkHome())) + val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone")) + val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sparkHome)) client = new Client(sc.env.actorSystem, master, jobDesc, this) client.start() From c73107500e0a5b6c5f0b4aba8c4504ee4c2adbaf Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Sun, 20 Jan 2013 21:55:50 -0800 Subject: [PATCH 141/291] send sparkHome as String instead of File over network --- core/src/main/scala/spark/deploy/DeployMessage.scala | 2 +- core/src/main/scala/spark/deploy/master/Master.scala | 2 +- core/src/main/scala/spark/deploy/worker/Worker.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala index 7ee3e63429..a4081ef89c 100644 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/spark/deploy/DeployMessage.scala @@ -44,7 +44,7 @@ private[spark] case class LaunchExecutor( jobDesc: JobDescription, cores: Int, memory: Int, - sparkHome: File) + sparkHome: String) extends DeployMessage diff --git a/core/src/main/scala/spark/deploy/master/Master.scala 
b/core/src/main/scala/spark/deploy/master/Master.scala index f0bee67159..1b6f808a51 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -199,7 +199,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: File) { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome) + worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome.getAbsolutePath) exec.job.actor ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) } diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 078b2d8037..19bf2be118 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -122,7 +122,7 @@ private[spark] class Worker( case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_, execSparkHome_) => logInfo("Asked to launch executor %s/%d for %s".format(jobId, execId, jobDesc.name)) val manager = new ExecutorRunner( - jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, execSparkHome_, workDir) + jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir) executors(jobId + "/" + execId) = manager manager.start() coresUsed += cores_ From fe26acc482f358bf87700f5e80160f7ce558cea7 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Sun, 20 Jan 2013 21:57:44 -0800 Subject: [PATCH 142/291] remove unused imports --- core/src/main/scala/spark/deploy/DeployMessage.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala index a4081ef89c..35f40c6e91 100644 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/spark/deploy/DeployMessage.scala @@ -4,8 +4,6 @@ import spark.deploy.ExecutorState.ExecutorState import spark.deploy.master.{WorkerInfo, JobInfo} import spark.deploy.worker.ExecutorRunner import scala.collection.immutable.List -import scala.collection.mutable.HashMap -import java.io.File private[spark] sealed trait DeployMessage extends Serializable From a3f571b539ffd126e9f3bc3e9c7bedfcb6f4d2d4 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 21 Jan 2013 10:52:17 -0800 Subject: [PATCH 143/291] more File -> String changes --- core/src/main/scala/spark/deploy/JobDescription.scala | 4 +--- core/src/main/scala/spark/deploy/client/TestClient.scala | 3 +-- core/src/main/scala/spark/deploy/master/Master.scala | 5 ++--- .../scheduler/cluster/SparkDeploySchedulerBackend.scala | 4 +--- 4 files changed, 5 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/spark/deploy/JobDescription.scala b/core/src/main/scala/spark/deploy/JobDescription.scala index 7f8f9af417..7160fc05fc 100644 --- a/core/src/main/scala/spark/deploy/JobDescription.scala +++ b/core/src/main/scala/spark/deploy/JobDescription.scala @@ -1,13 +1,11 @@ package spark.deploy -import java.io.File - private[spark] class JobDescription( val name: String, val cores: Int, val memoryPerSlave: Int, val command: Command, - val sparkHome: File) + val sparkHome: String) extends Serializable { val user = System.getProperty("user.name", "") diff --git 
a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala index dc743b1fbf..8764c400e2 100644 --- a/core/src/main/scala/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/spark/deploy/client/TestClient.scala @@ -3,7 +3,6 @@ package spark.deploy.client import spark.util.AkkaUtils import spark.{Logging, Utils} import spark.deploy.{Command, JobDescription} -import java.io.File private[spark] object TestClient { @@ -26,7 +25,7 @@ private[spark] object TestClient { val url = args(0) val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0) val desc = new JobDescription( - "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), new File("dummy-spark-home")) + "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home") val listener = new TestListener val client = new Client(actorSystem, url, desc, listener) client.start() diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 1b6f808a51..2c2cd0231b 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -6,7 +6,6 @@ import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, Remote import java.text.SimpleDateFormat import java.util.Date -import java.io.File import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -196,10 +195,10 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor } } - def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: File) { + def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome.getAbsolutePath) + worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome) exec.job.actor ! 
ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) } diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 94886d3941..a21a5b2f3d 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -3,8 +3,6 @@ package spark.scheduler.cluster import spark.{Utils, Logging, SparkContext} import spark.deploy.client.{Client, ClientListener} import spark.deploy.{Command, JobDescription} -import scala.collection.mutable.HashMap -import java.io.File private[spark] class SparkDeploySchedulerBackend( scheduler: ClusterScheduler, @@ -41,7 +39,7 @@ private[spark] class SparkDeploySchedulerBackend( val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone")) - val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, new File(sparkHome)) + val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, sparkHome) client = new Client(sc.env.actorSystem, master, jobDesc, this) client.start() From 4d34c7fc3ecd7a4d035005f84c01e6990c0c345e Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 21 Jan 2013 11:33:48 -0800 Subject: [PATCH 144/291] Fix compile error caused by cherry-pick --- .../spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index a21a5b2f3d..4f82cd96dd 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -3,6 +3,7 @@ package spark.scheduler.cluster import spark.{Utils, Logging, SparkContext} import spark.deploy.client.{Client, ClientListener} import spark.deploy.{Command, JobDescription} +import scala.collection.mutable.HashMap private[spark] class SparkDeploySchedulerBackend( scheduler: ClusterScheduler, From a88b44ed3b670633549049e9ccf990ea455e9720 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 21 Jan 2013 11:59:21 -0800 Subject: [PATCH 145/291] Only bind to IPv4 addresses when trying to auto-detect external IP --- core/src/main/scala/spark/Utils.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index b3421df27c..692a3f4050 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -1,7 +1,7 @@ package spark import java.io._ -import java.net.{NetworkInterface, InetAddress, URL, URI} +import java.net.{NetworkInterface, InetAddress, Inet4Address, URL, URI} import java.util.{Locale, Random, UUID} import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} import org.apache.hadoop.conf.Configuration @@ -251,7 +251,8 @@ private object Utils extends Logging { // Address resolves to something like 127.0.1.1, which happens on Debian; try to find // a better address using the local network interfaces for (ni <- NetworkInterface.getNetworkInterfaces) { - for (addr <- ni.getInetAddresses if 
!addr.isLinkLocalAddress && !addr.isLoopbackAddress) { + for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && + !addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) { // We've found an address that looks reasonable! logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + " a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress + From 2173f6c7cac877a3b756d63aabf7bdd06a18e6d9 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 21 Jan 2013 13:02:40 -0800 Subject: [PATCH 146/291] Clarify the documentation on env variables for standalone mode --- docs/spark-standalone.md | 43 ++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index e0ba7c35cb..bf296221b8 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -51,11 +51,11 @@ Finally, the following configuration options can be passed to the master and wor -c CORES, --cores CORES - Number of CPU cores to use (default: all available); only on worker + Total CPU cores to allow Spark jobs to use on the machine (default: all available); only on worker -m MEM, --memory MEM - Amount of memory to use, in a format like 1000M or 2G (default: your machine's total RAM minus 1 GB); only on worker + Total amount of memory to allow Spark jobs to use on the machine, in a format like 1000M or 2G (default: your machine's total RAM minus 1 GB); only on worker -d DIR, --work-dir DIR @@ -66,9 +66,20 @@ Finally, the following configuration options can be passed to the master and wor # Cluster Launch Scripts -To launch a Spark standalone cluster with the deploy scripts, you need to set up two files, `conf/spark-env.sh` and `conf/slaves`. The `conf/spark-env.sh` file lets you specify global settings for the master and slave instances, such as memory, or port numbers to bind to, while `conf/slaves` is a list of slave nodes. The system requires that all the slave machines have the same configuration files, so *copy these files to each machine*. +To launch a Spark standalone cluster with the deploy scripts, you need to create a file called `conf/slaves` in your Spark directory, which should contain the hostnames of all the machines where you would like to start Spark workers, one per line. The master machine must be able to access each of the slave machines via password-less `ssh` (using a private key). For testing, you can just put `localhost` in this file. -In `conf/spark-env.sh`, you can set the following parameters, in addition to the [standard Spark configuration settings](configuration.html): +Once you've set up this fine, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/bin`: + +- `bin/start-master.sh` - Starts a master instance on the machine the script is executed on. +- `bin/start-slaves.sh` - Starts a slave instance on each machine specified in the `conf/slaves` file. +- `bin/start-all.sh` - Starts both a master and a number of slaves as described above. +- `bin/stop-master.sh` - Stops the master that was started via the `bin/start-master.sh` script. +- `bin/stop-slaves.sh` - Stops the slave instances that were started via `bin/start-slaves.sh`. +- `bin/stop-all.sh` - Stops both the master and the slaves as described above. + +Note that these scripts must be executed on the machine you want to run the Spark master on, not your local machine. 
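Once the cluster is running, a job connects to it by passing the master's `spark://IP:PORT` URL to a `SparkContext`, as described under "Connecting a Job to the Cluster" below. A minimal PySpark sketch (host, port and job name are placeholders):

    from pyspark.context import SparkContext
    # replace the URL below with your own master's spark:// URL
    sc = SparkContext("spark://127.0.0.1:7077", "SimpleJob")
    print sc.parallelize(range(100)).count()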
+ +You can optionally configure the cluster further by setting environment variables in `conf/spark-env.sh`. Create this file by starting with the `conf/spark-env.sh.template`, and _copy it to all your worker machines_ for the settings to take effect. The following settings are available: @@ -88,36 +99,24 @@ In `conf/spark-env.sh`, you can set the following parameters, in addition to the + + + + - + - + - - - -
   <tr><th>Environment Variable</th><th>Meaning</th></tr>
   <tr>
     <td><code>SPARK_WORKER_PORT</code></td>
     <td>Start the Spark worker on a specific port (default: random)</td>
   </tr>
-  <tr>
-    <td><code>SPARK_WORKER_DIR</code></td>
-    <td>Directory to run jobs in, which will include both logs and scratch space (default: SPARK_HOME/work)</td>
-  </tr>
   <tr>
     <td><code>SPARK_WORKER_CORES</code></td>
-    <td>Number of cores to use (default: all available cores)</td>
+    <td>Total number of cores to allow Spark jobs to use on the machine (default: all available cores)</td>
   </tr>
   <tr>
     <td><code>SPARK_WORKER_MEMORY</code></td>
-    <td>How much memory to use, e.g. 1000M, 2G (default: total memory minus 1 GB)</td>
+    <td>Total amount of memory to allow Spark jobs to use on the machine, e.g. 1000M, 2G (default: total memory minus 1 GB); note that each job's individual memory is configured using SPARK_MEM</td>
   </tr>
   <tr>
     <td><code>SPARK_WORKER_WEBUI_PORT</code></td>
     <td>Port for the worker web UI (default: 8081)</td>
   </tr>
+  <tr>
+    <td><code>SPARK_WORKER_DIR</code></td>
+    <td>Directory to run jobs in, which will include both logs and scratch space (default: SPARK_HOME/work)</td>
+  </tr>
    -In `conf/slaves`, include a list of all machines where you would like to start a Spark worker, one per line. The master machine must be able to access each of the slave machines via password-less `ssh` (using a private key). For testing purposes, you can have a single `localhost` entry in the slaves file. - -Once you've set up these configuration files, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/bin`: - -- `bin/start-master.sh` - Starts a master instance on the machine the script is executed on. -- `bin/start-slaves.sh` - Starts a slave instance on each machine specified in the `conf/slaves` file. -- `bin/start-all.sh` - Starts both a master and a number of slaves as described above. -- `bin/stop-master.sh` - Stops the master that was started via the `bin/start-master.sh` script. -- `bin/stop-slaves.sh` - Stops the slave instances that were started via `bin/start-slaves.sh`. -- `bin/stop-all.sh` - Stops both the master and the slaves as described above. - -Note that the scripts must be executed on the machine you want to run the Spark master on, not your local machine. # Connecting a Job to the Cluster From 76d7c0ce2bd9c4d5782fec320279e0a011230625 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 21 Jan 2013 13:10:02 -0800 Subject: [PATCH 147/291] Add more Akka settings to docs --- docs/configuration.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index 87cb4a6797..036a0df480 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -197,6 +197,15 @@ Apart from these, the following properties are also available, and may be useful poor data locality, but the default generally works well. + + spark.akka.frameSize + 10 + + Maximum message size to allow in "control plane" communication (for serialized tasks and task + results), in MB. Increase this if your tasks need to send back large results to the master + (e.g. using collect() on a large dataset). + + spark.akka.threads 4 @@ -205,6 +214,13 @@ Apart from these, the following properties are also available, and may be useful when the master has a lot of CPU cores. + + spark.akka.timeout + 20 + + Communication timeout between Spark nodes. + + spark.master.host (local hostname) From ffd1623595cdce4080ad1e4e676e65898ebdd6dd Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Mon, 21 Jan 2013 15:55:46 -0600 Subject: [PATCH 148/291] Minor cleanup. 
--- core/src/main/scala/spark/Accumulators.scala | 3 +-- core/src/main/scala/spark/Logging.scala | 3 +-- .../main/scala/spark/ParallelCollection.scala | 15 +++++---------- core/src/main/scala/spark/TaskContext.scala | 3 +-- core/src/main/scala/spark/rdd/BlockRDD.scala | 6 ++---- core/src/main/scala/spark/rdd/CartesianRDD.scala | 3 +-- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 6 ++---- core/src/main/scala/spark/rdd/NewHadoopRDD.scala | 6 ++---- core/src/main/scala/spark/rdd/SampledRDD.scala | 5 ++--- core/src/main/scala/spark/rdd/ShuffledRDD.scala | 3 +-- core/src/main/scala/spark/rdd/UnionRDD.scala | 3 +-- core/src/main/scala/spark/rdd/ZippedRDD.scala | 3 +-- .../spark/scheduler/local/LocalScheduler.scala | 4 ++-- .../mesos/CoarseMesosSchedulerBackend.scala | 16 ++++++---------- .../scheduler/mesos/MesosSchedulerBackend.scala | 10 +++------- core/src/test/scala/spark/FileServerSuite.scala | 4 ++-- 16 files changed, 33 insertions(+), 60 deletions(-) diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala index b644aba5f8..57c6df35be 100644 --- a/core/src/main/scala/spark/Accumulators.scala +++ b/core/src/main/scala/spark/Accumulators.scala @@ -25,8 +25,7 @@ class Accumulable[R, T] ( extends Serializable { val id = Accumulators.newId - @transient - private var value_ = initialValue // Current value on master + @transient private var value_ = initialValue // Current value on master val zero = param.zero(initialValue) // Zero value to be passed to workers var deserialized = false diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala index 90bae26202..7c1c1bb144 100644 --- a/core/src/main/scala/spark/Logging.scala +++ b/core/src/main/scala/spark/Logging.scala @@ -11,8 +11,7 @@ import org.slf4j.LoggerFactory trait Logging { // Make the log field transient so that objects with Logging can // be serialized and used on another machine - @transient - private var log_ : Logger = null + @transient private var log_ : Logger = null // Method to get or create the logger for this object protected def log: Logger = { diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index ede933c9e9..ad23e5bec8 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -23,32 +23,28 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest]( } private[spark] class ParallelCollection[T: ClassManifest]( - @transient sc : SparkContext, + @transient sc: SparkContext, @transient data: Seq[T], numSlices: Int, - locationPrefs : Map[Int,Seq[String]]) + locationPrefs: Map[Int,Seq[String]]) extends RDD[T](sc, Nil) { // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split // instead. // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal. 
- @transient - var splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { val slices = ParallelCollection.slice(data, numSlices).toArray slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray } - override def getSplits = splits_.asInstanceOf[Array[Split]] + override def getSplits = splits_ override def compute(s: Split, context: TaskContext) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator override def getPreferredLocations(s: Split): Seq[String] = { - locationPrefs.get(s.index) match { - case Some(s) => s - case _ => Nil - } + locationPrefs.get(s.index) getOrElse Nil } override def clearDependencies() { @@ -56,7 +52,6 @@ private[spark] class ParallelCollection[T: ClassManifest]( } } - private object ParallelCollection { /** * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala index d2746b26b3..eab85f85a2 100644 --- a/core/src/main/scala/spark/TaskContext.scala +++ b/core/src/main/scala/spark/TaskContext.scala @@ -5,8 +5,7 @@ import scala.collection.mutable.ArrayBuffer class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable { - @transient - val onCompleteCallbacks = new ArrayBuffer[() => Unit] + @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit] // Add a callback function to be executed on task completion. An example use // is for HadoopRDD to register a callback to close the input stream. diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index b1095a52b4..2c022f88e0 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -11,13 +11,11 @@ private[spark] class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String]) extends RDD[T](sc, Nil) { - @transient - var splits_ : Array[Split] = (0 until blockIds.size).map(i => { + @transient var splits_ : Array[Split] = (0 until blockIds.size).map(i => { new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split] }).toArray - @transient - lazy val locations_ = { + @transient lazy val locations_ = { val blockManager = SparkEnv.get.blockManager /*val locations = blockIds.map(id => blockManager.getLocations(id))*/ val locations = blockManager.getLocations(blockIds) diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 79e7c24e7c..453d410ad4 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -35,8 +35,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( val numSplitsInRdd2 = rdd2.splits.size - @transient - var splits_ = { + @transient var splits_ = { // create the cross product split val array = new Array[Split](rdd1.splits.size * rdd2.splits.size) for (s1 <- rdd1.splits; s2 <- rdd2.splits) { diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 1d528be2aa..8fafd27bb6 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -45,8 +45,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) val aggr = new CoGroupAggregator - @transient - var deps_ = { + @transient var deps_ = { val deps = new ArrayBuffer[Dependency[_]] for ((rdd, index) <- rdds.zipWithIndex) { if (rdd.partitioner == Some(part)) 
{ @@ -63,8 +62,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) override def getDependencies = deps_ - @transient - var splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { val array = new Array[Split](part.numPartitions) for (i <- 0 until array.size) { array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) => diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index bb22db073c..c3b155fcbd 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -37,11 +37,9 @@ class NewHadoopRDD[K, V]( formatter.format(new Date()) } - @transient - private val jobId = new JobID(jobtrackerId, id) + @transient private val jobId = new JobID(jobtrackerId, id) - @transient - private val splits_ : Array[Split] = { + @transient private val splits_ : Array[Split] = { val inputFormat = inputFormatClass.newInstance val jobContext = newJobContext(conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index 1bc9c96112..e24ad23b21 100644 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -19,13 +19,12 @@ class SampledRDD[T: ClassManifest]( seed: Int) extends RDD[T](prev) { - @transient - var splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { val rg = new Random(seed) firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt)) } - override def getSplits = splits_.asInstanceOf[Array[Split]] + override def getSplits = splits_ override def getPreferredLocations(split: Split) = firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev) diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 1b219473e0..28ff19876d 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -22,8 +22,7 @@ class ShuffledRDD[K, V]( override val partitioner = Some(part) - @transient - var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) + @transient var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) override def getSplits = splits_ diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index 24a085df02..82f0a44ecd 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -28,8 +28,7 @@ class UnionRDD[T: ClassManifest]( @transient var rdds: Seq[RDD[T]]) extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs - @transient - var splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { val array = new Array[Split](rdds.map(_.splits.size).sum) var pos = 0 for (rdd <- rdds; split <- rdd.splits) { diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala index 16e6cc0f1b..d950b06c85 100644 --- a/core/src/main/scala/spark/rdd/ZippedRDD.scala +++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala @@ -34,8 +34,7 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest]( // TODO: FIX THIS. 
- @transient - var splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { if (rdd1.splits.size != rdd2.splits.size) { throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") } diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index dff550036d..21d255debd 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -19,8 +19,8 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon extends TaskScheduler with Logging { - var attemptId = new AtomicInteger(0) - var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory) + val attemptId = new AtomicInteger(0) + val threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory) val env = SparkEnv.get var listener: TaskSchedulerListener = null diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala index c45c7df69c..014906b028 100644 --- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala @@ -64,13 +64,9 @@ private[spark] class CoarseMesosSchedulerBackend( val taskIdToSlaveId = new HashMap[Int, String] val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed - val sparkHome = sc.getSparkHome() match { - case Some(path) => - path - case None => - throw new SparkException("Spark home is not set; set it through the spark.home system " + - "property, the SPARK_HOME environment variable or the SparkContext constructor") - } + val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException( + "Spark home is not set; set it through the spark.home system " + + "property, the SPARK_HOME environment variable or the SparkContext constructor")) val extraCoresPerSlave = System.getProperty("spark.mesos.extra.cores", "0").toInt @@ -184,7 +180,7 @@ private[spark] class CoarseMesosSchedulerBackend( } /** Helper function to pull out a resource from a Mesos Resources protobuf */ - def getResource(res: JList[Resource], name: String): Double = { + private def getResource(res: JList[Resource], name: String): Double = { for (r <- res if r.getName == name) { return r.getScalar.getValue } @@ -193,7 +189,7 @@ private[spark] class CoarseMesosSchedulerBackend( } /** Build a Mesos resource protobuf object */ - def createResource(resourceName: String, quantity: Double): Protos.Resource = { + private def createResource(resourceName: String, quantity: Double): Protos.Resource = { Resource.newBuilder() .setName(resourceName) .setType(Value.Type.SCALAR) @@ -202,7 +198,7 @@ private[spark] class CoarseMesosSchedulerBackend( } /** Check whether a Mesos task state represents a finished task */ - def isFinished(state: MesosTaskState) = { + private def isFinished(state: MesosTaskState) = { state == MesosTaskState.TASK_FINISHED || state == MesosTaskState.TASK_FAILED || state == MesosTaskState.TASK_KILLED || diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala index 8c7a1dfbc0..2989e31f5e 100644 --- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala @@ -76,13 
+76,9 @@ private[spark] class MesosSchedulerBackend( } def createExecutorInfo(): ExecutorInfo = { - val sparkHome = sc.getSparkHome() match { - case Some(path) => - path - case None => - throw new SparkException("Spark home is not set; set it through the spark.home system " + - "property, the SPARK_HOME environment variable or the SparkContext constructor") - } + val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException( + "Spark home is not set; set it through the spark.home system " + + "property, the SPARK_HOME environment variable or the SparkContext constructor")) val execScript = new File(sparkHome, "spark-executor").getCanonicalPath val environment = Environment.newBuilder() sc.executorEnvs.foreach { case (key, value) => diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala index b4283d9604..fe964bd893 100644 --- a/core/src/test/scala/spark/FileServerSuite.scala +++ b/core/src/test/scala/spark/FileServerSuite.scala @@ -9,8 +9,8 @@ import SparkContext._ class FileServerSuite extends FunSuite with BeforeAndAfter { @transient var sc: SparkContext = _ - @transient var tmpFile : File = _ - @transient var testJarFile : File = _ + @transient var tmpFile: File = _ + @transient var testJarFile: File = _ before { // Create a sample text file From e5ca2413352510297092384eda73049ad601fd8a Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Mon, 21 Jan 2013 16:06:58 -0600 Subject: [PATCH 149/291] Move JavaAPISuite into spark.streaming. --- streaming/src/test/java/{ => spark/streaming}/JavaAPISuite.java | 0 streaming/src/test/java/{ => spark/streaming}/JavaTestUtils.scala | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename streaming/src/test/java/{ => spark/streaming}/JavaAPISuite.java (100%) rename streaming/src/test/java/{ => spark/streaming}/JavaTestUtils.scala (100%) diff --git a/streaming/src/test/java/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java similarity index 100% rename from streaming/src/test/java/JavaAPISuite.java rename to streaming/src/test/java/spark/streaming/JavaAPISuite.java diff --git a/streaming/src/test/java/JavaTestUtils.scala b/streaming/src/test/java/spark/streaming/JavaTestUtils.scala similarity index 100% rename from streaming/src/test/java/JavaTestUtils.scala rename to streaming/src/test/java/spark/streaming/JavaTestUtils.scala From ef711902c1f42db14c8ddd524195f0a9efb56e65 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 21 Jan 2013 16:42:24 -0800 Subject: [PATCH 150/291] Don't download files to master's working directory. This should avoid exceptions caused by existing files with different contents. I also removed some unused code. 
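The patch whose message appears above (its diff follows) moves addFile downloads out of the executor's working directory and adds a SparkFiles class for resolving their location. A minimal sketch of the usage pattern its updated docs describe, assuming a local file /tmp/lookup.txt exists; the path, master URL and job name are placeholders:

import scala.io.Source
import spark.{SparkContext, SparkFiles}

object AddFileExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "AddFileExample")
    sc.addFile("/tmp/lookup.txt")                // shipped to every node with the job
    val sizes = sc.parallelize(1 to 4, 2).map { i =>
      // Resolve the per-job download location instead of assuming the
      // current working directory, as the patch's updated docs describe.
      val path = SparkFiles.get("lookup.txt")
      Source.fromFile(path).getLines().size
    }.collect()
    println(sizes.mkString(","))
    sc.stop()
  }
}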
--- .../src/main/scala/spark/HttpFileServer.scala | 8 ++-- core/src/main/scala/spark/SparkContext.scala | 7 ++-- core/src/main/scala/spark/SparkEnv.scala | 20 ++++++---- core/src/main/scala/spark/SparkFiles.java | 25 ++++++++++++ core/src/main/scala/spark/Utils.scala | 16 +------- .../spark/api/java/JavaSparkContext.scala | 5 ++- .../scala/spark/api/python/PythonRDD.scala | 2 + .../spark/deploy/worker/ExecutorRunner.scala | 5 --- .../main/scala/spark/executor/Executor.scala | 6 +-- .../scheduler/local/LocalScheduler.scala | 6 +-- .../test/scala/spark/FileServerSuite.scala | 9 +++-- python/pyspark/__init__.py | 5 ++- python/pyspark/context.py | 40 +++++++++++++++++-- python/pyspark/files.py | 24 +++++++++++ python/pyspark/worker.py | 3 ++ python/run-tests | 3 ++ 16 files changed, 133 insertions(+), 51 deletions(-) create mode 100644 core/src/main/scala/spark/SparkFiles.java create mode 100644 python/pyspark/files.py diff --git a/core/src/main/scala/spark/HttpFileServer.scala b/core/src/main/scala/spark/HttpFileServer.scala index 659d17718f..00901d95e2 100644 --- a/core/src/main/scala/spark/HttpFileServer.scala +++ b/core/src/main/scala/spark/HttpFileServer.scala @@ -1,9 +1,7 @@ package spark -import java.io.{File, PrintWriter} -import java.net.URL -import scala.collection.mutable.HashMap -import org.apache.hadoop.fs.FileUtil +import java.io.{File} +import com.google.common.io.Files private[spark] class HttpFileServer extends Logging { @@ -40,7 +38,7 @@ private[spark] class HttpFileServer extends Logging { } def addFileToDir(file: File, dir: File) : String = { - Utils.copyFile(file, new File(dir, file.getName)) + Files.copy(file, new File(dir, file.getName)) return dir + "/" + file.getName } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 8b6f4b3b7d..2eeca66ed6 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -439,9 +439,10 @@ class SparkContext( def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) /** - * Add a file to be downloaded into the working directory of this Spark job on every node. + * Add a file to be downloaded with this Spark job on every node. * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), or an HTTP, HTTPS or FTP URI. + * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, + * use `SparkFiles.get(path)` to find its download location. */ def addFile(path: String) { val uri = new URI(path) @@ -454,7 +455,7 @@ class SparkContext( // Fetch the file locally in case a job is executed locally. // Jobs that run through LocalScheduler will already fetch the required dependencies, // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here. 
- Utils.fetchFile(path, new File(".")) + Utils.fetchFile(path, new File(SparkFiles.getRootDirectory)) logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) } diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 41441720a7..6b44e29f4c 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -28,14 +28,10 @@ class SparkEnv ( val broadcastManager: BroadcastManager, val blockManager: BlockManager, val connectionManager: ConnectionManager, - val httpFileServer: HttpFileServer + val httpFileServer: HttpFileServer, + val sparkFilesDir: String ) { - /** No-parameter constructor for unit tests. */ - def this() = { - this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null) - } - def stop() { httpFileServer.stop() mapOutputTracker.stop() @@ -112,6 +108,15 @@ object SparkEnv extends Logging { httpFileServer.initialize() System.setProperty("spark.fileserver.uri", httpFileServer.serverUri) + // Set the sparkFiles directory, used when downloading dependencies. In local mode, + // this is a temporary directory; in distributed mode, this is the executor's current working + // directory. + val sparkFilesDir: String = if (isMaster) { + Utils.createTempDir().getAbsolutePath + } else { + "." + } + // Warn about deprecated spark.cache.class property if (System.getProperty("spark.cache.class") != null) { logWarning("The spark.cache.class property is no longer being used! Specify storage " + @@ -128,6 +133,7 @@ object SparkEnv extends Logging { broadcastManager, blockManager, connectionManager, - httpFileServer) + httpFileServer, + sparkFilesDir) } } diff --git a/core/src/main/scala/spark/SparkFiles.java b/core/src/main/scala/spark/SparkFiles.java new file mode 100644 index 0000000000..b59d8ce93f --- /dev/null +++ b/core/src/main/scala/spark/SparkFiles.java @@ -0,0 +1,25 @@ +package spark; + +import java.io.File; + +/** + * Resolves paths to files added through `addFile(). + */ +public class SparkFiles { + + private SparkFiles() {} + + /** + * Get the absolute path of a file added through `addFile()`. + */ + public static String get(String filename) { + return new File(getRootDirectory(), filename).getAbsolutePath(); + } + + /** + * Get the root directory that contains files added through `addFile()`. + */ + public static String getRootDirectory() { + return SparkEnv.get().sparkFilesDir(); + } +} \ No newline at end of file diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 692a3f4050..827c8bd81e 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -111,20 +111,6 @@ private object Utils extends Logging { } } - /** Copy a file on the local file system */ - def copyFile(source: File, dest: File) { - val in = new FileInputStream(source) - val out = new FileOutputStream(dest) - copyStream(in, out, true) - } - - /** Download a file from a given URL to the local filesystem */ - def downloadFile(url: URL, localPath: String) { - val in = url.openStream() - val out = new FileOutputStream(localPath) - Utils.copyStream(in, out, true) - } - /** * Download a file requested by the executor. Supports fetching the file in a variety of ways, * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. 
@@ -201,7 +187,7 @@ private object Utils extends Logging { Utils.execute(Seq("tar", "-xf", filename), targetDir) } // Make the file executable - That's necessary for scripts - FileUtil.chmod(filename, "a+x") + FileUtil.chmod(targetFile.getAbsolutePath, "a+x") } /** diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index 16c122c584..50b8970cd8 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -323,9 +323,10 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def getSparkHome(): Option[String] = sc.getSparkHome() /** - * Add a file to be downloaded into the working directory of this Spark job on every node. + * Add a file to be downloaded with this Spark job on every node. * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), or an HTTP, HTTPS or FTP URI. + * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, + * use `SparkFiles.get(path)` to find its download location. */ def addFile(path: String) { sc.addFile(path) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 5526406a20..f43a152ca7 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -67,6 +67,8 @@ private[spark] class PythonRDD[T: ClassManifest]( val dOut = new DataOutputStream(proc.getOutputStream) // Split index dOut.writeInt(split.index) + // sparkFilesDir + PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut) // Broadcast variables dOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index beceb55ecd..0d1fe2a6b4 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -106,11 +106,6 @@ private[spark] class ExecutorRunner( throw new IOException("Failed to create directory " + executorDir) } - // Download the files it depends on into it (disabled for now) - //for (url <- jobDesc.fileUrls) { - // fetchFile(url, executorDir) - //} - // Launch the process val command = buildCommandSeq() val builder = new ProcessBuilder(command: _*).directory(executorDir) diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 2552958d27..70629f6003 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -162,16 +162,16 @@ private[spark] class Executor extends Logging { // Fetch missing dependencies for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) currentFiles(name) = timestamp } for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) currentJars(name) = timestamp // Add it to our class loader val localName = name.split("/").last - val url = new File(".", 
localName).toURI.toURL + val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL if (!urlClassLoader.getURLs.contains(url)) { logInfo("Adding " + url + " to class loader") urlClassLoader.addURL(url) diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index dff550036d..4451d314e6 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -116,16 +116,16 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon // Fetch missing dependencies for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) currentFiles(name) = timestamp } for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) currentJars(name) = timestamp // Add it to our class loader val localName = name.split("/").last - val url = new File(".", localName).toURI.toURL + val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL if (!classLoader.getURLs.contains(url)) { logInfo("Adding " + url + " to class loader") classLoader.addURL(url) diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala index b4283d9604..528c6b8424 100644 --- a/core/src/test/scala/spark/FileServerSuite.scala +++ b/core/src/test/scala/spark/FileServerSuite.scala @@ -40,7 +40,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter { sc.addFile(tmpFile.toString) val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) val result = sc.parallelize(testData).reduceByKey { - val in = new BufferedReader(new FileReader("FileServerSuite.txt")) + val path = SparkFiles.get("FileServerSuite.txt") + val in = new BufferedReader(new FileReader(path)) val fileVal = in.readLine().toInt in.close() _ * fileVal + _ * fileVal @@ -54,7 +55,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter { sc.addFile((new File(tmpFile.toString)).toURL.toString) val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) val result = sc.parallelize(testData).reduceByKey { - val in = new BufferedReader(new FileReader("FileServerSuite.txt")) + val path = SparkFiles.get("FileServerSuite.txt") + val in = new BufferedReader(new FileReader(path)) val fileVal = in.readLine().toInt in.close() _ * fileVal + _ * fileVal @@ -83,7 +85,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter { sc.addFile(tmpFile.toString) val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) val result = sc.parallelize(testData).reduceByKey { - val in = new BufferedReader(new FileReader("FileServerSuite.txt")) + val path = SparkFiles.get("FileServerSuite.txt") + val in = new BufferedReader(new FileReader(path)) val fileVal = in.readLine().toInt in.close() _ * fileVal + _ * fileVal diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 00666bc0a3..3e8bca62f0 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -11,6 +11,8 @@ Public classes: A broadcast variable that gets reused across tasks. - L{Accumulator} An "add-only" shared variable that tasks can only add values to. 
+ - L{SparkFiles} + Access files shipped with jobs. """ import sys import os @@ -19,6 +21,7 @@ sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.eg from pyspark.context import SparkContext from pyspark.rdd import RDD +from pyspark.files import SparkFiles -__all__ = ["SparkContext", "RDD"] +__all__ = ["SparkContext", "RDD", "SparkFiles"] diff --git a/python/pyspark/context.py b/python/pyspark/context.py index dcbed37270..ec0cc7c2f9 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -1,5 +1,7 @@ import os import atexit +import shutil +import tempfile from tempfile import NamedTemporaryFile from pyspark import accumulators @@ -173,10 +175,26 @@ class SparkContext(object): def addFile(self, path): """ - Add a file to be downloaded into the working directory of this Spark - job on every node. The C{path} passed can be either a local file, - a file in HDFS (or other Hadoop-supported filesystems), or an HTTP, - HTTPS or FTP URI. + Add a file to be downloaded with this Spark job on every node. + The C{path} passed can be either a local file, a file in HDFS + (or other Hadoop-supported filesystems), or an HTTP, HTTPS or + FTP URI. + + To access the file in Spark jobs, use + L{SparkFiles.get(path)} to find its + download location. + + >>> from pyspark import SparkFiles + >>> path = os.path.join(tempdir, "test.txt") + >>> with open(path, "w") as testFile: + ... testFile.write("100") + >>> sc.addFile(path) + >>> def func(iterator): + ... with open(SparkFiles.get("test.txt")) as testFile: + ... fileVal = int(testFile.readline()) + ... return [x * 100 for x in iterator] + >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect() + [100, 200, 300, 400] """ self._jsc.sc().addFile(path) @@ -211,3 +229,17 @@ class SparkContext(object): accidental overriding of checkpoint files in the existing directory. """ self._jsc.sc().setCheckpointDir(dirName, useExisting) + + +def _test(): + import doctest + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['tempdir'] = tempfile.mkdtemp() + atexit.register(lambda: shutil.rmtree(globs['tempdir'])) + doctest.testmod(globs=globs) + globs['sc'].stop() + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/files.py b/python/pyspark/files.py new file mode 100644 index 0000000000..de1334f046 --- /dev/null +++ b/python/pyspark/files.py @@ -0,0 +1,24 @@ +import os + + +class SparkFiles(object): + """ + Resolves paths to files added through + L{addFile()}. + + SparkFiles contains only classmethods; users should not create SparkFiles + instances. + """ + + _root_directory = None + + def __init__(self): + raise NotImplementedError("Do not construct SparkFiles objects") + + @classmethod + def get(cls, filename): + """ + Get the absolute path of a file added through C{addFile()}. 
+ """ + path = os.path.join(SparkFiles._root_directory, filename) + return os.path.abspath(path) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index b2b9288089..e7bdb7682b 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -8,6 +8,7 @@ from base64 import standard_b64decode from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler +from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, read_with_length, write_int, \ read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file @@ -23,6 +24,8 @@ def load_obj(): def main(): split_index = read_int(sys.stdin) + spark_files_dir = load_pickle(read_with_length(sys.stdin)) + SparkFiles._root_directory = spark_files_dir num_broadcast_variables = read_int(sys.stdin) for _ in range(num_broadcast_variables): bid = read_long(sys.stdin) diff --git a/python/run-tests b/python/run-tests index ce214e98a8..a3a9ff5dcb 100755 --- a/python/run-tests +++ b/python/run-tests @@ -8,6 +8,9 @@ FAILED=0 $FWDIR/pyspark pyspark/rdd.py FAILED=$(($?||$FAILED)) +$FWDIR/pyspark pyspark/context.py +FAILED=$(($?||$FAILED)) + $FWDIR/pyspark -m doctest pyspark/broadcast.py FAILED=$(($?||$FAILED)) From 7b9e96c99206c0679d9925e0161fde738a5c7c3a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 21 Jan 2013 16:45:00 -0800 Subject: [PATCH 151/291] Add synchronization to Executor.updateDependencies() (SPARK-662) --- .../main/scala/spark/executor/Executor.scala | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 70629f6003..28d9d40d43 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -159,22 +159,24 @@ private[spark] class Executor extends Logging { * SparkContext. Also adds any new JARs we fetched to the class loader. 
*/ private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { - // Fetch missing dependencies - for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) - currentFiles(name) = timestamp - } - for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) - currentJars(name) = timestamp - // Add it to our class loader - val localName = name.split("/").last - val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL - if (!urlClassLoader.getURLs.contains(url)) { - logInfo("Adding " + url + " to class loader") - urlClassLoader.addURL(url) + synchronized { + // Fetch missing dependencies + for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) + currentFiles(name) = timestamp + } + for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) + currentJars(name) = timestamp + // Add it to our class loader + val localName = name.split("/").last + val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL + if (!urlClassLoader.getURLs.contains(url)) { + logInfo("Adding " + url + " to class loader") + urlClassLoader.addURL(url) + } } } } From 2d8218b8717435a47d7cea399290b30bf5ef010b Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Mon, 21 Jan 2013 20:00:27 -0600 Subject: [PATCH 152/291] Remove unneeded/now-broken saveAsNewAPIHadoopFile overload. --- core/src/main/scala/spark/PairRDDFunctions.scala | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 656b820b8a..53b051f1c5 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -485,18 +485,6 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) } - /** - * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` - * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. - */ - def saveAsNewAPIHadoopFile( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) { - saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass) - } - /** * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. From a8baeb93272b03a98e44c7bf5c541611aec4a64b Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Mon, 21 Jan 2013 21:30:24 -0600 Subject: [PATCH 153/291] Further simplify getOrElse call. 
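This one-line change (diff below) swaps an Option#getOrElse on the result of Map#get for a single Map#getOrElse call. A tiny sketch of the equivalence, with a made-up preference map for illustration:

object GetOrElseExample extends App {
  val locationPrefs = Map(0 -> Seq("host1"), 2 -> Seq("host3"))
  val before = locationPrefs.get(1) getOrElse Nil  // lookup, then Option#getOrElse
  val after  = locationPrefs.getOrElse(1, Nil)     // one call on the map itself
  assert(before == after)                          // both yield Nil for a missing key
  println(before == after)                         // prints: true
}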
--- core/src/main/scala/spark/ParallelCollection.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index ad23e5bec8..10adcd53ec 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -44,7 +44,7 @@ private[spark] class ParallelCollection[T: ClassManifest]( s.asInstanceOf[ParallelCollectionSplit[T]].iterator override def getPreferredLocations(s: Split): Seq[String] = { - locationPrefs.get(s.index) getOrElse Nil + locationPrefs.getOrElse(s.index, Nil) } override def clearDependencies() { From c75ae3622eeed068c44b1f823ef4d87d01a720fd Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 20 Jan 2013 15:12:54 -0800 Subject: [PATCH 154/291] Make AccumulatorParam an abstract base class. --- python/pyspark/accumulators.py | 29 ++++++++++++++++++++++++++--- python/pyspark/context.py | 15 +++++---------- 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 8011779ddc..5a9269f9bb 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -61,6 +61,7 @@ Traceback (most recent call last): Exception:... """ +from abc import ABCMeta, abstractmethod import struct import SocketServer import threading @@ -90,8 +91,7 @@ class Accumulator(object): While C{SparkContext} supports accumulators for primitive data types like C{int} and C{float}, users can also define accumulators for custom types by providing a custom - C{AccumulatorParam} object with a C{zero} and C{addInPlace} method. Refer to the doctest - of this module for an example. + L{AccumulatorParam} object. Refer to the doctest of this module for an example. """ def __init__(self, aid, value, accum_param): @@ -134,7 +134,30 @@ class Accumulator(object): return "Accumulator" % (self.aid, self._value) -class AddingAccumulatorParam(object): +class AccumulatorParam(object): + """ + Helper object that defines how to accumulate values of a given type. + """ + __metaclass__ = ABCMeta + + @abstractmethod + def zero(self, value): + """ + Provide a "zero value" for the type, compatible in dimensions with the + provided C{value} (e.g., a zero vector) + """ + return + + @abstractmethod + def addInPlace(self, value1, value2): + """ + Add two values of the accumulator's data type, returning a new value; + for efficiency, can also update C{value1} in place and return it. + """ + return + + +class AddingAccumulatorParam(AccumulatorParam): """ An AccumulatorParam that uses the + operators to add values. Designed for simple types such as integers, floats, and lists. Requires the zero value for the underlying type diff --git a/python/pyspark/context.py b/python/pyspark/context.py index dcbed37270..a17e7a4ad1 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -148,16 +148,11 @@ class SparkContext(object): def accumulator(self, value, accum_param=None): """ - Create an C{Accumulator} with the given initial value, using a given - AccumulatorParam helper object to define how to add values of the data - type if provided. Default AccumulatorParams are used for integers and - floating-point numbers if you do not provide one. For other types, the - AccumulatorParam must implement two methods: - - C{zero(value)}: provide a "zero value" for the type, compatible in - dimensions with the provided C{value} (e.g., a zero vector). 
- - C{addInPlace(val1, val2)}: add two values of the accumulator's data - type, returning a new value; for efficiency, can also update C{val1} - in place and return it. + Create an L{Accumulator} with the given initial value, using a given + L{AccumulatorParam} helper object to define how to add values of the + data type if provided. Default AccumulatorParams are used for integers + and floating-point numbers if you do not provide one. For other types, + a custom AccumulatorParam can be used. """ if accum_param == None: if isinstance(value, int): From 551a47a620c7dc207e3530e54d794a3c3aa8e45e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 21 Jan 2013 23:31:00 -0800 Subject: [PATCH 155/291] Refactor daemon thread pool creation. --- .../scala/spark/DaemonThreadFactory.scala | 18 ---------- core/src/main/scala/spark/Utils.scala | 33 ++++--------------- .../spark/network/ConnectionManager.scala | 5 ++- .../scheduler/local/LocalScheduler.scala | 2 +- .../streaming/dstream/RawInputDStream.scala | 5 +-- 5 files changed, 13 insertions(+), 50 deletions(-) delete mode 100644 core/src/main/scala/spark/DaemonThreadFactory.scala diff --git a/core/src/main/scala/spark/DaemonThreadFactory.scala b/core/src/main/scala/spark/DaemonThreadFactory.scala deleted file mode 100644 index 56e59adeb7..0000000000 --- a/core/src/main/scala/spark/DaemonThreadFactory.scala +++ /dev/null @@ -1,18 +0,0 @@ -package spark - -import java.util.concurrent.ThreadFactory - -/** - * A ThreadFactory that creates daemon threads - */ -private object DaemonThreadFactory extends ThreadFactory { - override def newThread(r: Runnable): Thread = new DaemonThread(r) -} - -private class DaemonThread(r: Runnable = null) extends Thread { - override def run() { - if (r != null) { - r.run() - } - } -} \ No newline at end of file diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 692a3f4050..9b8636f6c8 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -10,6 +10,7 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ import scala.io.Source import com.google.common.io.Files +import com.google.common.util.concurrent.ThreadFactoryBuilder /** * Various utility methods used by Spark. @@ -287,29 +288,14 @@ private object Utils extends Logging { customHostname.getOrElse(InetAddress.getLocalHost.getHostName) } - /** - * Returns a standard ThreadFactory except all threads are daemons. - */ - private def newDaemonThreadFactory: ThreadFactory = { - new ThreadFactory { - def newThread(r: Runnable): Thread = { - var t = Executors.defaultThreadFactory.newThread (r) - t.setDaemon (true) - return t - } - } - } + private[spark] val daemonThreadFactory: ThreadFactory = + new ThreadFactoryBuilder().setDaemon(true).build() /** * Wrapper over newCachedThreadPool. */ - def newDaemonCachedThreadPool(): ThreadPoolExecutor = { - var threadPool = Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor] - - threadPool.setThreadFactory (newDaemonThreadFactory) - - return threadPool - } + def newDaemonCachedThreadPool(): ThreadPoolExecutor = + Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] /** * Return the string to tell how long has passed in seconds. The passing parameter should be in @@ -322,13 +308,8 @@ private object Utils extends Logging { /** * Wrapper over newFixedThreadPool. 
*/ - def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = { - var threadPool = Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor] - - threadPool.setThreadFactory(newDaemonThreadFactory) - - return threadPool - } + def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = + Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] /** * Delete a file or directory and its contents recursively. diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 36c01ad629..2ecd14f536 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -52,9 +52,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging { val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] val sendMessageRequests = new Queue[(Message, SendingConnection)] - implicit val futureExecContext = ExecutionContext.fromExecutor( - Executors.newCachedThreadPool(DaemonThreadFactory)) - + implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool()) + var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null serverChannel.configureBlocking(false) diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index dff550036d..87f8474ea0 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -20,7 +20,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon with Logging { var attemptId = new AtomicInteger(0) - var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory) + var threadPool = Utils.newDaemonFixedThreadPool(threads) val env = SparkEnv.get var listener: TaskSchedulerListener = null diff --git a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala index 290fab1ce0..04e6b69b7b 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala @@ -1,6 +1,6 @@ package spark.streaming.dstream -import spark.{DaemonThread, Logging} +import spark.Logging import spark.storage.StorageLevel import spark.streaming.StreamingContext @@ -48,7 +48,8 @@ class RawNetworkReceiver(host: String, port: Int, storageLevel: StorageLevel) val queue = new ArrayBlockingQueue[ByteBuffer](2) - blockPushingThread = new DaemonThread { + blockPushingThread = new Thread { + setDaemon(true) override def run() { var nextBlockNumber = 0 while (true) { From e353886a8ca6179f25b4176d7a62b5d04ce79276 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 22 Jan 2013 00:23:31 -0800 Subject: [PATCH 156/291] Use generation numbers for fetch failure tracking --- .../scala/spark/scheduler/DAGScheduler.scala | 27 ++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 59f2099e91..39a1e6d6c6 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -72,8 +72,12 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val 
cacheTracker = env.cacheTracker val mapOutputTracker = env.mapOutputTracker - val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back; - // that's not going to be a realistic assumption in general + // For tracking failed nodes, we use the MapOutputTracker's generation number, which is + // sent with every task. When we detect a node failing, we note the current generation number + // and failed host, increment it for new tasks, and use this to ignore stray ShuffleMapTask + // results. + // TODO: Garbage collect information about failure generations when new stages start. + val failedGeneration = new HashMap[String, Long] val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done val running = new HashSet[Stage] // Stages we are running right now @@ -429,7 +433,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val status = event.result.asInstanceOf[MapStatus] val host = status.address.ip logInfo("ShuffleMapTask finished with host " + host) - if (!deadHosts.contains(host)) { // TODO: Make sure hostnames are consistent with Mesos + if (failedGeneration.contains(host) && smt.generation <= failedGeneration(host)) { + logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + host) + } else { stage.addOutputLoc(smt.partition, status) } if (running.contains(stage) && pendingTasks(stage).isEmpty) { @@ -495,7 +501,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock // TODO: mark the host as failed only if there were lots of fetch failures on it if (bmAddress != null) { - handleHostLost(bmAddress.ip) + handleHostLost(bmAddress.ip, Some(task.generation)) } case other => @@ -507,11 +513,15 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with /** * Responds to a host being lost. This is called inside the event loop so it assumes that it can * modify the scheduler's internal state. Use hostLost() to post a host lost event from outside. + * + * Optionally the generation during which the failure was caught can be passed to avoid allowing + * stray fetch failures from possibly retriggering the detection of a node as lost. */ - def handleHostLost(host: String) { - if (!deadHosts.contains(host)) { + def handleHostLost(host: String, maybeGeneration: Option[Long] = None) { + val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration) + if (!failedGeneration.contains(host) || failedGeneration(host) < currentGeneration) { + failedGeneration(host) = currentGeneration logInfo("Host lost: " + host) - deadHosts += host env.blockManager.master.notifyADeadHost(host) // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { @@ -519,6 +529,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray mapOutputTracker.registerMapOutputs(shuffleId, locs, true) } + if (shuffleToMapStage.isEmpty) { + mapOutputTracker.incrementGeneration() + } cacheTracker.cacheLost(host) updateCacheLocs() } From 364cdb679cf2b0d5e6ed7ab89628f15594d7947f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Jan 2013 00:43:31 -0800 Subject: [PATCH 157/291] Refactored DStreamCheckpointData. 
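The scheduler patch above replaces the one-shot deadHosts set with per-host failure generations so that stray ShuffleMapTask results from a lost host can be recognized and ignored. A minimal, stand-alone sketch of that check; the map and condition follow the diff, everything else is simplified for illustration and is not the DAGScheduler itself:

import scala.collection.mutable.HashMap

object GenerationCheckExample extends App {
  // host -> generation number recorded when the host was declared lost
  val failedGeneration = new HashMap[String, Long]

  // Accept a map output only if the task ran with a newer generation than the
  // host's last recorded failure (mirrors the condition added in the diff above).
  def acceptOutput(host: String, taskGeneration: Long): Boolean =
    !(failedGeneration.contains(host) && taskGeneration <= failedGeneration(host))

  failedGeneration("worker-1") = 5L
  println(acceptOutput("worker-1", 4L))  // false: stale result from before the failure
  println(acceptOutput("worker-1", 6L))  // true: launched after the failure was noted
  println(acceptOutput("worker-2", 1L))  // true: host has no recorded failure
}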
--- .../main/scala/spark/streaming/DStream.scala | 58 ++----------- .../streaming/DStreamCheckpointData.scala | 84 +++++++++++++++++++ .../streaming/dstream/KafkaInputDStream.scala | 9 -- .../spark/streaming/CheckpointSuite.scala | 12 +-- 4 files changed, 99 insertions(+), 64 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index b11ef443dc..3c1861a840 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -12,7 +12,7 @@ import scala.collection.mutable.HashMap import java.io.{ObjectInputStream, IOException, ObjectOutputStream} -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.conf.Configuration /** @@ -75,7 +75,7 @@ abstract class DStream[T: ClassManifest] ( // Checkpoint details protected[streaming] val mustCheckpoint = false protected[streaming] var checkpointDuration: Duration = null - protected[streaming] var checkpointData = new DStreamCheckpointData(HashMap[Time, Any]()) + protected[streaming] val checkpointData = new DStreamCheckpointData(this) // Reference to whole DStream graph protected[streaming] var graph: DStreamGraph = null @@ -85,10 +85,10 @@ abstract class DStream[T: ClassManifest] ( // Duration for which the DStream requires its parent DStream to remember each RDD created protected[streaming] def parentRememberDuration = rememberDuration - /** Returns the StreamingContext associated with this DStream */ + /** Return the StreamingContext associated with this DStream */ def context() = ssc - /** Persists the RDDs of this DStream with the given storage level */ + /** Persist the RDDs of this DStream with the given storage level */ def persist(level: StorageLevel): DStream[T] = { if (this.isInitialized) { throw new UnsupportedOperationException( @@ -342,40 +342,10 @@ abstract class DStream[T: ClassManifest] ( */ protected[streaming] def updateCheckpointData(currentTime: Time) { logInfo("Updating checkpoint data for time " + currentTime) - - // Get the checkpointed RDDs from the generated RDDs - val newRdds = generatedRDDs.filter(_._2.getCheckpointFile.isDefined) - .map(x => (x._1, x._2.getCheckpointFile.get)) - - // Make a copy of the existing checkpoint data (checkpointed RDDs) - val oldRdds = checkpointData.rdds.clone() - - // If the new checkpoint data has checkpoints then replace existing with the new one - if (newRdds.size > 0) { - checkpointData.rdds.clear() - checkpointData.rdds ++= newRdds - } - - // Make parent DStreams update their checkpoint data + checkpointData.update() dependencies.foreach(_.updateCheckpointData(currentTime)) - - // TODO: remove this, this is just for debugging - newRdds.foreach { - case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") } - } - - if (newRdds.size > 0) { - (oldRdds -- newRdds.keySet).foreach { - case (time, data) => { - val path = new Path(data.toString) - val fs = path.getFileSystem(new Configuration()) - fs.delete(path, true) - logInfo("Deleted checkpoint file '" + path + "' for time " + time) - } - } - } - logInfo("Updated checkpoint data for time " + currentTime + ", " + checkpointData.rdds.size + " checkpoints, " - + "[" + checkpointData.rdds.mkString(",") + "]") + checkpointData.cleanup() + logDebug("Updated checkpoint data for time " + currentTime + ": " + checkpointData) } 
/** @@ -386,14 +356,8 @@ abstract class DStream[T: ClassManifest] ( */ protected[streaming] def restoreCheckpointData() { // Create RDDs from the checkpoint data - logInfo("Restoring checkpoint data from " + checkpointData.rdds.size + " checkpointed RDDs") - checkpointData.rdds.foreach { - case(time, data) => { - logInfo("Restoring checkpointed RDD for time " + time + " from file '" + data.toString + "'") - val rdd = ssc.sc.checkpointFile[T](data.toString) - generatedRDDs += ((time, rdd)) - } - } + logInfo("Restoring checkpoint data from " + checkpointData.checkpointFiles.size + " checkpointed RDDs") + checkpointData.restore() dependencies.foreach(_.restoreCheckpointData()) logInfo("Restored checkpoint data") } @@ -651,7 +615,3 @@ abstract class DStream[T: ClassManifest] ( ssc.registerOutputStream(this) } } - -private[streaming] -case class DStreamCheckpointData(rdds: HashMap[Time, Any]) - diff --git a/streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala b/streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala new file mode 100644 index 0000000000..abf903293f --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala @@ -0,0 +1,84 @@ +package spark.streaming + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.conf.Configuration +import collection.mutable.HashMap +import spark.Logging + + + +private[streaming] +class DStreamCheckpointData[T: ClassManifest] (dstream: DStream[T]) + extends Serializable with Logging { + private[streaming] val checkpointFiles = new HashMap[Time, String]() + @transient private lazy val fileSystem = + new Path(dstream.context.checkpointDir).getFileSystem(new Configuration()) + @transient private var lastCheckpointFiles: HashMap[Time, String] = null + + /** + * Update the checkpoint data of the DStream. Default implementation records the checkpoint files to + * which the generate RDDs of the DStream has been saved. + */ + def update() { + + // Get the checkpointed RDDs from the generated RDDs + val newCheckpointFiles = dstream.generatedRDDs.filter(_._2.getCheckpointFile.isDefined) + .map(x => (x._1, x._2.getCheckpointFile.get)) + + // Make a copy of the existing checkpoint data (checkpointed RDDs) + lastCheckpointFiles = checkpointFiles.clone() + + // If the new checkpoint data has checkpoints then replace existing with the new one + if (newCheckpointFiles.size > 0) { + checkpointFiles.clear() + checkpointFiles ++= newCheckpointFiles + } + + // TODO: remove this, this is just for debugging + newCheckpointFiles.foreach { + case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") } + } + } + + /** + * Cleanup old checkpoint data. Default implementation, cleans up old checkpoint files. + */ + def cleanup() { + // If there is at least on checkpoint file in the current checkpoint files, + // then delete the old checkpoint files. + if (checkpointFiles.size > 0 && lastCheckpointFiles != null) { + (lastCheckpointFiles -- checkpointFiles.keySet).foreach { + case (time, file) => { + try { + val path = new Path(file) + fileSystem.delete(path, true) + logInfo("Deleted checkpoint file '" + file + "' for time " + time) + } catch { + case e: Exception => + logWarning("Error deleting old checkpoint file '" + file + "' for time " + time, e) + } + } + } + } + } + + /** + * Restore the checkpoint data. Default implementation restores the RDDs from their + * checkpoint files. 
+ */ + def restore() { + // Create RDDs from the checkpoint data + checkpointFiles.foreach { + case(time, file) => { + logInfo("Restoring checkpointed RDD for time " + time + " from file '" + file + "'") + dstream.generatedRDDs += ((time, dstream.context.sc.checkpointFile[T](file))) + } + } + } + + override def toString() = { + "[\n" + checkpointFiles.size + "\n" + checkpointFiles.mkString("\n") + "\n]" + } +} + diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index 2b4740bdf7..760d9b5cf3 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -19,15 +19,6 @@ import scala.collection.JavaConversions._ // Key for a specific Kafka Partition: (broker, topic, group, part) case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int) -// NOT USED - Originally intended for fault-tolerance -// Metadata for a Kafka Stream that it sent to the Master -private[streaming] -case class KafkaInputDStreamMetadata(timestamp: Long, data: Map[KafkaPartitionKey, Long]) -// NOT USED - Originally intended for fault-tolerance -// Checkpoint data specific to a KafkaInputDstream -private[streaming] -case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any], - savedOffsets: Map[KafkaPartitionKey, Long]) extends DStreamCheckpointData(kafkaRdds) /** * Input stream that pulls messages from a Kafka Broker. diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index d2f32c189b..58da4ee539 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -63,9 +63,9 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // then check whether some RDD has been checkpointed or not ssc.start() runStreamsWithRealDelay(ssc, firstNumBatches) - logInfo("Checkpoint data of state stream = \n[" + stateStream.checkpointData.rdds.mkString(",\n") + "]") - assert(!stateStream.checkpointData.rdds.isEmpty, "No checkpointed RDDs in state stream before first failure") - stateStream.checkpointData.rdds.foreach { + logInfo("Checkpoint data of state stream = \n" + stateStream.checkpointData) + assert(!stateStream.checkpointData.checkpointFiles.isEmpty, "No checkpointed RDDs in state stream before first failure") + stateStream.checkpointData.checkpointFiles.foreach { case (time, data) => { val file = new File(data.toString) assert(file.exists(), "Checkpoint file '" + file +"' for time " + time + " for state stream before first failure does not exist") @@ -74,7 +74,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // Run till a further time such that previous checkpoint files in the stream would be deleted // and check whether the earlier checkpoint files are deleted - val checkpointFiles = stateStream.checkpointData.rdds.map(x => new File(x._2.toString)) + val checkpointFiles = stateStream.checkpointData.checkpointFiles.map(x => new File(x._2)) runStreamsWithRealDelay(ssc, secondNumBatches) checkpointFiles.foreach(file => assert(!file.exists, "Checkpoint file '" + file + "' was not deleted")) ssc.stop() @@ -91,8 +91,8 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // is present in the checkpoint data or not ssc.start() runStreamsWithRealDelay(ssc, 1) - 
assert(!stateStream.checkpointData.rdds.isEmpty, "No checkpointed RDDs in state stream before second failure") - stateStream.checkpointData.rdds.foreach { + assert(!stateStream.checkpointData.checkpointFiles.isEmpty, "No checkpointed RDDs in state stream before second failure") + stateStream.checkpointData.checkpointFiles.foreach { case (time, data) => { val file = new File(data.toString) assert(file.exists(), From 7e9ee2e8335f085062d3fdeecd0b49ec63e92117 Mon Sep 17 00:00:00 2001 From: Leemoonsoo Date: Tue, 22 Jan 2013 23:08:34 +0900 Subject: [PATCH 158/291] Fix for hanging spark.HttpFileServer with kind of virtual network --- core/src/main/scala/spark/HttpServer.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/HttpServer.scala b/core/src/main/scala/spark/HttpServer.scala index 0196595ba1..4e0507c080 100644 --- a/core/src/main/scala/spark/HttpServer.scala +++ b/core/src/main/scala/spark/HttpServer.scala @@ -4,6 +4,7 @@ import java.io.File import java.net.InetAddress import org.eclipse.jetty.server.Server +import org.eclipse.jetty.server.bio.SocketConnector import org.eclipse.jetty.server.handler.DefaultHandler import org.eclipse.jetty.server.handler.HandlerList import org.eclipse.jetty.server.handler.ResourceHandler @@ -27,7 +28,13 @@ private[spark] class HttpServer(resourceBase: File) extends Logging { if (server != null) { throw new ServerStateException("Server is already started") } else { - server = new Server(0) + server = new Server() + val connector = new SocketConnector + connector.setMaxIdleTime(60*1000) + connector.setSoLingerTime(-1) + connector.setPort(0) + server.addConnector(connector) + val threadPool = new QueuedThreadPool threadPool.setDaemon(true) server.setThreadPool(threadPool) From 588b24197a85c4b46a38595007293abef9a41f2c Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 10:19:30 -0600 Subject: [PATCH 159/291] Use default arguments instead of constructor overloads. --- core/src/main/scala/spark/SparkContext.scala | 22 +++----------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 8b6f4b3b7d..495d1b6c78 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -58,27 +58,11 @@ import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend class SparkContext( val master: String, val jobName: String, - val sparkHome: String, - val jars: Seq[String], - environment: Map[String, String]) + val sparkHome: String = null, + val jars: Seq[String] = Nil, + environment: Map[String, String] = Map()) extends Logging { - /** - * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - * @param jobName A name for your job, to display on the cluster web UI - * @param sparkHome Location where Spark is installed on cluster nodes. - * @param jars Collection of JARs to send to the cluster. These can be paths on the local file - * system or HDFS, HTTP, HTTPS, or FTP URLs. - */ - def this(master: String, jobName: String, sparkHome: String, jars: Seq[String]) = - this(master, jobName, sparkHome, jars, Map()) - - /** - * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). 
- * @param jobName A name for your job, to display on the cluster web UI - */ - def this(master: String, jobName: String) = this(master, jobName, null, Nil, Map()) - // Ensure logging is initialized before we spawn any threads initLogging() From 50e2b23927956c14db40093d31bc80892764006a Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 22 Jan 2013 09:27:33 -0800 Subject: [PATCH 160/291] Fix up some problems from the merge --- .../scala/spark/storage/BlockManagerMasterActor.scala | 11 +++++++++++ .../scala/spark/storage/BlockManagerMessages.scala | 3 +++ core/src/main/scala/spark/storage/StorageUtils.scala | 8 ++++---- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala index f4d026da33..c945c34c71 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -68,6 +68,9 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { case GetMemoryStatus => getMemoryStatus + case GetStorageStatus => + getStorageStatus + case RemoveBlock(blockId) => removeBlock(blockId) @@ -177,6 +180,14 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { sender ! res } + private def getStorageStatus() { + val res = blockManagerInfo.map { case(blockManagerId, info) => + import collection.JavaConverters._ + StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala.toMap) + } + sender ! res + } + private def register(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { val startTimeMs = System.currentTimeMillis() val tmp = " " + blockManagerId + " " diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala index d73a9b790f..3a381fd385 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala @@ -100,3 +100,6 @@ case object GetMemoryStatus extends ToBlockManagerMaster private[spark] case object ExpireDeadHosts extends ToBlockManagerMaster + +private[spark] +case object GetStorageStatus extends ToBlockManagerMaster \ No newline at end of file diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala index ebc7390ee5..63ad5c125b 100644 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -1,6 +1,7 @@ package spark.storage import spark.SparkContext +import BlockManagerMasterActor.BlockStatus private[spark] case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, @@ -20,8 +21,8 @@ case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, } -case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, - numPartitions: Int, memSize: Long, diskSize: Long, locations: Array[BlockManagerId]) +case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, + numPartitions: Int, memSize: Long, diskSize: Long) /* Helper methods for storage-related objects */ @@ -58,8 +59,7 @@ object StorageUtils { val rddName = Option(sc.persistentRdds.get(rddId).name).getOrElse(rddKey) val rddStorageLevel = sc.persistentRdds.get(rddId).getStorageLevel - RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, memSize, diskSize, - rddBlocks.map(_.blockManagerId)) + RDDInfo(rddId, rddName, rddStorageLevel, 
rddBlocks.length, memSize, diskSize) }.toArray } From 27b3f3f0a980f86bac14a14516b5d52a32aa8cbb Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 15:30:42 -0600 Subject: [PATCH 161/291] Handle slaveLost before slaveIdToHost knows about it. --- .../scheduler/cluster/ClusterScheduler.scala | 31 +++++++++++-------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 20f6e65020..a639b72795 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -252,19 +252,24 @@ private[spark] class ClusterScheduler(val sc: SparkContext) def slaveLost(slaveId: String, reason: ExecutorLossReason) { var failedHost: Option[String] = None synchronized { - val host = slaveIdToHost(slaveId) - if (hostsAlive.contains(host)) { - logError("Lost an executor on " + host + ": " + reason) - slaveIdsWithExecutors -= slaveId - hostsAlive -= host - activeTaskSetsQueue.foreach(_.hostLost(host)) - failedHost = Some(host) - } else { - // We may get multiple slaveLost() calls with different loss reasons. For example, one - // may be triggered by a dropped connection from the slave while another may be a report - // of executor termination from Mesos. We produce log messages for both so we eventually - // report the termination reason. - logError("Lost an executor on " + host + " (already removed): " + reason) + slaveIdToHost.get(slaveId) match { + case Some(host) => + if (hostsAlive.contains(host)) { + logError("Lost an executor on " + host + ": " + reason) + slaveIdsWithExecutors -= slaveId + hostsAlive -= host + activeTaskSetsQueue.foreach(_.hostLost(host)) + failedHost = Some(host) + } else { + // We may get multiple slaveLost() calls with different loss reasons. For example, one + // may be triggered by a dropped connection from the slave while another may be a report + // of executor termination from Mesos. We produce log messages for both so we eventually + // report the termination reason. + logError("Lost an executor on " + host + " (already removed): " + reason) + } + case None => + // We were told about a slave being lost before we could even allocate work to it + logError("Lost slave " + slaveId + " (no work assigned yet)") } } if (failedHost != None) { From 6f2194f7576eb188c23f18125f5101ae0b4e9e4d Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 15:38:58 -0600 Subject: [PATCH 162/291] Call removeJob instead of killing the cluster. 
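Previously the Master built a SparkException and threw it once a job exceeded JobState.MAX_NUM_RETRY (the commented-out System.exit(1) in the diff below shows the original intent), which could take the whole standalone cluster down for one bad job; after this change only the offending job is removed and the Master keeps scheduling the others. A rough, self-contained sketch of that pattern, with hypothetical names rather than the actual Master code:

    case class Job(id: String, var retries: Int = 0)

    class Scheduler(maxRetries: Int) {
      private var jobs = Map.empty[String, Job]

      def submit(job: Job): Unit = jobs += (job.id -> job)

      // On repeated failure, drop the single job instead of tearing the service down.
      def onExecutorFailed(job: Job): Unit = {
        job.retries += 1
        if (job.retries <= maxRetries) reschedule(job)   // give it another chance
        else jobs -= job.id                              // remove only this job
      }

      private def reschedule(job: Job): Unit = ()        // placeholder for relaunch logic
    }
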
--- core/src/main/scala/spark/deploy/master/Master.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 2c2cd0231b..d1a65204b8 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -103,8 +103,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor val e = new SparkException("Job %s wth ID %s failed %d times.".format( jobInfo.desc.name, jobInfo.id, jobInfo.retryCount)) logError(e.getMessage, e) - throw e - //System.exit(1) + removeJob(jobInfo) } } } From 250fe89679bb59ef0d31f74985f72556dcfe2d06 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 16:29:05 -0600 Subject: [PATCH 163/291] Handle Master telling the Worker to kill an already-dead executor. --- core/src/main/scala/spark/deploy/worker/Worker.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 19bf2be118..d040b86908 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -143,9 +143,13 @@ private[spark] class Worker( case KillExecutor(jobId, execId) => val fullId = jobId + "/" + execId - val executor = executors(fullId) - logInfo("Asked to kill executor " + fullId) - executor.kill() + executors.get(fullId) match { + case Some(executor) => + logInfo("Asked to kill executor " + fullId) + executor.kill() + case None => + logInfo("Asked to kill non-existent existent " + fullId) + } case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => masterDisconnected() From 2437f6741b9c5b0a778d55d324aabdc4642889e5 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 18:01:03 -0600 Subject: [PATCH 164/291] Restore SPARK_MEM in executorEnvs. --- core/src/main/scala/spark/SparkContext.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index a5a1b75944..402355bd52 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -111,8 +111,9 @@ class SparkContext( // Environment variables to pass to our executors private[spark] val executorEnvs = HashMap[String, String]() - // Note: SPARK_MEM isn't included because it's set directly in ExecutorRunner - for (key <- Seq("SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", "SPARK_TESTING")) { + // Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner + for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", + "SPARK_TESTING")) { val value = System.getenv(key) if (value != null) { executorEnvs(key) = value From fdec42385a1a8f10f9dd803525cb3c132a25ba53 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 18:01:12 -0600 Subject: [PATCH 165/291] Fix SPARK_MEM in ExecutorRunner. 
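ExecutorRunner exported SPARK_MEM as a bare number of megabytes (e.g. "512"); the value is presumably consumed further down as a JVM heap setting, where a bare number is read as bytes rather than megabytes, so the suffix matters. A minimal illustration with assumed values (`memory` stands for the executor's memory in MB):

    val memory = 512
    val before = memory.toString          // "512"  -> read as bytes if fed to -Xmx
    val after  = memory.toString + "m"    // "512m" -> 512 megabytes, the intended amount
    println("SPARK_MEM was " + before + ", is now " + after)
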
--- core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index 2f2ea617ff..e910416235 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -118,7 +118,7 @@ private[spark] class ExecutorRunner( for ((key, value) <- jobDesc.command.environment) { env.put(key, value) } - env.put("SPARK_MEM", memory.toString) + env.put("SPARK_MEM", memory.toString + "m") // In case we are running this from within the Spark Shell, avoid creating a "scala" // parent process for the executor command env.put("SPARK_LAUNCH_WITH_SCALA", "0") From 8c51322cd05f2ae97a08c3af314c7608fcf71b57 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 18:09:10 -0600 Subject: [PATCH 166/291] Don't bother creating an exception. --- core/src/main/scala/spark/deploy/master/Master.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index d1a65204b8..361e5ac627 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -100,9 +100,8 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor if (jobInfo.incrementRetryCount <= JobState.MAX_NUM_RETRY) { schedule() } else { - val e = new SparkException("Job %s wth ID %s failed %d times.".format( + logError("Job %s wth ID %s failed %d times, removing it".format( jobInfo.desc.name, jobInfo.id, jobInfo.retryCount)) - logError(e.getMessage, e) removeJob(jobInfo) } } From 98d0b7747d7539db009a9bbc261f899955871524 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 18:11:51 -0600 Subject: [PATCH 167/291] Fix Worker logInfo about unknown executor. 
--- core/src/main/scala/spark/deploy/worker/Worker.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index d040b86908..5a83a42daf 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -148,7 +148,7 @@ private[spark] class Worker( logInfo("Asked to kill executor " + fullId) executor.kill() case None => - logInfo("Asked to kill non-existent existent " + fullId) + logInfo("Asked to kill unknown executor " + fullId) } case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => From 35168d9c89904f0dc0bb470c1799f5ca3b04221f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 22 Jan 2013 17:54:11 -0800 Subject: [PATCH 168/291] Fix sys.path bug in PySpark SparkContext.addPyFile --- python/pyspark/context.py | 2 -- python/pyspark/tests.py | 38 ++++++++++++++++++++++++++---- python/pyspark/worker.py | 1 + python/test_support/userlibrary.py | 7 ++++++ 4 files changed, 41 insertions(+), 7 deletions(-) create mode 100755 python/test_support/userlibrary.py diff --git a/python/pyspark/context.py b/python/pyspark/context.py index ec0cc7c2f9..b8d7dc05af 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -215,8 +215,6 @@ class SparkContext(object): """ self.addFile(path) filename = path.split("/")[-1] - os.environ["PYTHONPATH"] = \ - "%s:%s" % (filename, os.environ["PYTHONPATH"]) def setCheckpointDir(self, dirName, useExisting=False): """ diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index b0a403b580..4d70ee4f12 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -9,21 +9,32 @@ import time import unittest from pyspark.context import SparkContext +from pyspark.java_gateway import SPARK_HOME -class TestCheckpoint(unittest.TestCase): +class PySparkTestCase(unittest.TestCase): def setUp(self): - self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2) - self.checkpointDir = NamedTemporaryFile(delete=False) - os.unlink(self.checkpointDir.name) - self.sc.setCheckpointDir(self.checkpointDir.name) + class_name = self.__class__.__name__ + self.sc = SparkContext('local[4]', class_name , batchSize=2) def tearDown(self): self.sc.stop() # To avoid Akka rebinding to the same port, since it doesn't unbind # immediately on shutdown self.sc.jvm.System.clearProperty("spark.master.port") + + +class TestCheckpoint(PySparkTestCase): + + def setUp(self): + PySparkTestCase.setUp(self) + self.checkpointDir = NamedTemporaryFile(delete=False) + os.unlink(self.checkpointDir.name) + self.sc.setCheckpointDir(self.checkpointDir.name) + + def tearDown(self): + PySparkTestCase.tearDown(self) shutil.rmtree(self.checkpointDir.name) def test_basic_checkpointing(self): @@ -57,5 +68,22 @@ class TestCheckpoint(unittest.TestCase): self.assertEquals([1, 2, 3, 4], recovered.collect()) +class TestAddFile(PySparkTestCase): + + def test_add_py_file(self): + # To ensure that we're actually testing addPyFile's effects, check that + # this job fails due to `userlibrary` not being on the Python path: + def func(x): + from userlibrary import UserClass + return UserClass().hello() + self.assertRaises(Exception, + self.sc.parallelize(range(2)).map(func).first) + # Add the file, so the job should now succeed: + path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") + self.sc.addPyFile(path) + res = self.sc.parallelize(range(2)).map(func).first() + 
self.assertEqual("Hello World!", res) + + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index e7bdb7682b..4bf643da66 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -26,6 +26,7 @@ def main(): split_index = read_int(sys.stdin) spark_files_dir = load_pickle(read_with_length(sys.stdin)) SparkFiles._root_directory = spark_files_dir + sys.path.append(spark_files_dir) num_broadcast_variables = read_int(sys.stdin) for _ in range(num_broadcast_variables): bid = read_long(sys.stdin) diff --git a/python/test_support/userlibrary.py b/python/test_support/userlibrary.py new file mode 100755 index 0000000000..5bb6f5009f --- /dev/null +++ b/python/test_support/userlibrary.py @@ -0,0 +1,7 @@ +""" +Used to test shipping of code depenencies with SparkContext.addPyFile(). +""" + +class UserClass(object): + def hello(self): + return "Hello World!" From fad2b82fc8fb49f2171af10cf7e408d8b8dd7349 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Jan 2013 18:10:00 -0800 Subject: [PATCH 169/291] Added support for saving input files of FileInputDStream to graph checkpoints. Modified 'file input stream with checkpoint' testcase to test recovery of pre-master-failure input files. --- .../main/scala/spark/streaming/DStream.scala | 29 +++--- .../streaming/DStreamCheckpointData.scala | 27 ++++-- .../scala/spark/streaming/DStreamGraph.scala | 2 +- .../spark/streaming/StreamingContext.scala | 7 +- .../streaming/dstream/FileInputDStream.scala | 96 +++++++++++++++---- .../spark/streaming/InputStreamsSuite.scala | 64 +++++++++---- 6 files changed, 159 insertions(+), 66 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 3c1861a840..07ecb018ee 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -86,7 +86,7 @@ abstract class DStream[T: ClassManifest] ( protected[streaming] def parentRememberDuration = rememberDuration /** Return the StreamingContext associated with this DStream */ - def context() = ssc + def context = ssc /** Persist the RDDs of this DStream with the given storage level */ def persist(level: StorageLevel): DStream[T] = { @@ -159,7 +159,7 @@ abstract class DStream[T: ClassManifest] ( ) assert( - checkpointDuration == null || ssc.sc.checkpointDir.isDefined, + checkpointDuration == null || context.sparkContext.checkpointDir.isDefined, "The checkpoint directory has not been set. Please use StreamingContext.checkpoint()" + " or SparkContext.checkpoint() to set the checkpoint directory." ) @@ -298,8 +298,8 @@ abstract class DStream[T: ClassManifest] ( getOrCompute(time) match { case Some(rdd) => { val jobFunc = () => { - val emptyFunc = { (iterator: Iterator[T]) => {} } - ssc.sc.runJob(rdd, emptyFunc) + val emptyFunc = { (iterator: Iterator[T]) => {} } + context.sparkContext.runJob(rdd, emptyFunc) } Some(new Job(time, jobFunc)) } @@ -310,10 +310,9 @@ abstract class DStream[T: ClassManifest] ( /** * Dereference RDDs that are older than rememberDuration. 
*/ - protected[streaming] def forgetOldRDDs(time: Time) { - val keys = generatedRDDs.keys + protected[streaming] def forgetOldMetadata(time: Time) { var numForgotten = 0 - keys.foreach(t => { + generatedRDDs.keys.foreach(t => { if (t <= (time - rememberDuration)) { generatedRDDs.remove(t) numForgotten += 1 @@ -321,7 +320,7 @@ abstract class DStream[T: ClassManifest] ( } }) logInfo("Forgot " + numForgotten + " RDDs from " + this) - dependencies.foreach(_.forgetOldRDDs(time)) + dependencies.foreach(_.forgetOldMetadata(time)) } /* Adds metadata to the Stream while it is running. @@ -356,7 +355,7 @@ abstract class DStream[T: ClassManifest] ( */ protected[streaming] def restoreCheckpointData() { // Create RDDs from the checkpoint data - logInfo("Restoring checkpoint data from " + checkpointData.checkpointFiles.size + " checkpointed RDDs") + logInfo("Restoring checkpoint data") checkpointData.restore() dependencies.foreach(_.restoreCheckpointData()) logInfo("Restored checkpoint data") @@ -397,7 +396,7 @@ abstract class DStream[T: ClassManifest] ( /** Return a new DStream by applying a function to all elements of this DStream. */ def map[U: ClassManifest](mapFunc: T => U): DStream[U] = { - new MappedDStream(this, ssc.sc.clean(mapFunc)) + new MappedDStream(this, context.sparkContext.clean(mapFunc)) } /** @@ -405,7 +404,7 @@ abstract class DStream[T: ClassManifest] ( * and then flattening the results */ def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]): DStream[U] = { - new FlatMappedDStream(this, ssc.sc.clean(flatMapFunc)) + new FlatMappedDStream(this, context.sparkContext.clean(flatMapFunc)) } /** Return a new DStream containing only the elements that satisfy a predicate. */ @@ -427,7 +426,7 @@ abstract class DStream[T: ClassManifest] ( mapPartFunc: Iterator[T] => Iterator[U], preservePartitioning: Boolean = false ): DStream[U] = { - new MapPartitionedDStream(this, ssc.sc.clean(mapPartFunc), preservePartitioning) + new MapPartitionedDStream(this, context.sparkContext.clean(mapPartFunc), preservePartitioning) } /** @@ -456,7 +455,7 @@ abstract class DStream[T: ClassManifest] ( * this DStream will be registered as an output stream and therefore materialized. */ def foreach(foreachFunc: (RDD[T], Time) => Unit) { - val newStream = new ForEachDStream(this, ssc.sc.clean(foreachFunc)) + val newStream = new ForEachDStream(this, context.sparkContext.clean(foreachFunc)) ssc.registerOutputStream(newStream) newStream } @@ -474,7 +473,7 @@ abstract class DStream[T: ClassManifest] ( * on each RDD of this DStream. 
*/ def transform[U: ClassManifest](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = { - new TransformedDStream(this, ssc.sc.clean(transformFunc)) + new TransformedDStream(this, context.sparkContext.clean(transformFunc)) } /** @@ -491,7 +490,7 @@ abstract class DStream[T: ClassManifest] ( if (first11.size > 10) println("...") println() } - val newStream = new ForEachDStream(this, ssc.sc.clean(foreachFunc)) + val newStream = new ForEachDStream(this, context.sparkContext.clean(foreachFunc)) ssc.registerOutputStream(newStream) } diff --git a/streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala b/streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala index abf903293f..a375980b84 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala @@ -11,14 +11,17 @@ import spark.Logging private[streaming] class DStreamCheckpointData[T: ClassManifest] (dstream: DStream[T]) extends Serializable with Logging { - private[streaming] val checkpointFiles = new HashMap[Time, String]() - @transient private lazy val fileSystem = - new Path(dstream.context.checkpointDir).getFileSystem(new Configuration()) + protected val data = new HashMap[Time, AnyRef]() + + @transient private var fileSystem : FileSystem = null @transient private var lastCheckpointFiles: HashMap[Time, String] = null + protected[streaming] def checkpointFiles = data.asInstanceOf[HashMap[Time, String]] + /** - * Update the checkpoint data of the DStream. Default implementation records the checkpoint files to - * which the generate RDDs of the DStream has been saved. + * Updates the checkpoint data of the DStream. This gets called every time + * the graph checkpoint is initiated. Default implementation records the + * checkpoint files to which the generate RDDs of the DStream has been saved. */ def update() { @@ -42,7 +45,9 @@ class DStreamCheckpointData[T: ClassManifest] (dstream: DStream[T]) } /** - * Cleanup old checkpoint data. Default implementation, cleans up old checkpoint files. + * Cleanup old checkpoint data. This gets called every time the graph + * checkpoint is initiated, but after `update` is called. Default + * implementation, cleans up old checkpoint files. */ def cleanup() { // If there is at least on checkpoint file in the current checkpoint files, @@ -52,6 +57,9 @@ class DStreamCheckpointData[T: ClassManifest] (dstream: DStream[T]) case (time, file) => { try { val path = new Path(file) + if (fileSystem == null) { + fileSystem = path.getFileSystem(new Configuration()) + } fileSystem.delete(path, true) logInfo("Deleted checkpoint file '" + file + "' for time " + time) } catch { @@ -64,15 +72,16 @@ class DStreamCheckpointData[T: ClassManifest] (dstream: DStream[T]) } /** - * Restore the checkpoint data. Default implementation restores the RDDs from their - * checkpoint files. + * Restore the checkpoint data. This gets called once when the DStream graph + * (along with its DStreams) are being restored from a graph checkpoint file. + * Default implementation restores the RDDs from their checkpoint files. 
*/ def restore() { // Create RDDs from the checkpoint data checkpointFiles.foreach { case(time, file) => { logInfo("Restoring checkpointed RDD for time " + time + " from file '" + file + "'") - dstream.generatedRDDs += ((time, dstream.context.sc.checkpointFile[T](file))) + dstream.generatedRDDs += ((time, dstream.context.sparkContext.checkpointFile[T](file))) } } } diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index bc4a40d7bc..d5a5496839 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -87,7 +87,7 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { private[streaming] def forgetOldRDDs(time: Time) { this.synchronized { - outputStreams.foreach(_.forgetOldRDDs(time)) + outputStreams.foreach(_.forgetOldMetadata(time)) } } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 14500bdcb1..2cf00e3baa 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -61,7 +61,7 @@ class StreamingContext private ( protected[streaming] val isCheckpointPresent = (cp_ != null) - val sc: SparkContext = { + protected[streaming] val sc: SparkContext = { if (isCheckpointPresent) { new SparkContext(cp_.master, cp_.framework, cp_.sparkHome, cp_.jars) } else { @@ -100,6 +100,11 @@ class StreamingContext private ( protected[streaming] var receiverJobThread: Thread = null protected[streaming] var scheduler: Scheduler = null + /** + * Returns the associated Spark context + */ + def sparkContext = sc + /** * Sets each DStreams in this context to remember RDDs it generated in the last given duration. 
* DStreams remember RDDs only for a limited duration of time and releases them for garbage diff --git a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala index 1e6ad84b44..c6ffb252ce 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala @@ -2,13 +2,14 @@ package spark.streaming.dstream import spark.RDD import spark.rdd.UnionRDD -import spark.streaming.{StreamingContext, Time} +import spark.streaming.{DStreamCheckpointData, StreamingContext, Time} import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import scala.collection.mutable.HashSet +import scala.collection.mutable.{HashSet, HashMap} +import java.io.{ObjectInputStream, IOException} private[streaming] class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( @@ -18,21 +19,14 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K newFilesOnly: Boolean = true) extends InputDStream[(K, V)](ssc_) { + protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData + + private val lastModTimeFiles = new HashSet[String]() + private var lastModTime = 0L + @transient private var path_ : Path = null @transient private var fs_ : FileSystem = null - - var lastModTime = 0L - val lastModTimeFiles = new HashSet[String]() - - def path(): Path = { - if (path_ == null) path_ = new Path(directory) - path_ - } - - def fs(): FileSystem = { - if (fs_ == null) fs_ = path.getFileSystem(new Configuration()) - fs_ - } + @transient private var files = new HashMap[Time, Array[String]] override def start() { if (newFilesOnly) { @@ -79,8 +73,8 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K } } - val newFiles = fs.listStatus(path, newFilter) - logInfo("New files: " + newFiles.map(_.getPath).mkString(", ")) + val newFiles = fs.listStatus(path, newFilter).map(_.getPath.toString) + logInfo("New files: " + newFiles.mkString(", ")) if (newFiles.length > 0) { // Update the modification time and the files processed for that modification time if (lastModTime != newFilter.latestModTime) { @@ -89,9 +83,70 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K } lastModTimeFiles ++= newFilter.latestModTimeFiles } - val newRDD = new UnionRDD(ssc.sc, newFiles.map( - file => ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString))) - Some(newRDD) + files += ((validTime, newFiles)) + Some(filesToRDD(newFiles)) + } + + /** Forget the old time-to-files mappings along with old RDDs */ + protected[streaming] override def forgetOldMetadata(time: Time) { + super.forgetOldMetadata(time) + val filesToBeRemoved = files.filter(_._1 <= (time - rememberDuration)) + files --= filesToBeRemoved.keys + logInfo("Forgot " + filesToBeRemoved.size + " files from " + this) + } + + /** Generate one RDD from an array of files */ + protected[streaming] def filesToRDD(files: Seq[String]): RDD[(K, V)] = { + new UnionRDD( + context.sparkContext, + files.map(file => context.sparkContext.newAPIHadoopFile[K, V, F](file)) + ) + } + + private def path: Path = { + if (path_ == null) path_ = new Path(directory) + path_ + } + + private def fs: FileSystem = { + if (fs_ == null) fs_ = path.getFileSystem(new Configuration()) + fs_ + } + + 
@throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + logDebug(this.getClass().getSimpleName + ".readObject used") + ois.defaultReadObject() + generatedRDDs = new HashMap[Time, RDD[(K,V)]] () + files = new HashMap[Time, Array[String]] + } + + /** + * A custom version of the DStreamCheckpointData that stores names of + * Hadoop files as checkpoint data. + */ + private[streaming] + class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) { + + def hadoopFiles = data.asInstanceOf[HashMap[Time, Array[String]]] + + override def update() { + hadoopFiles.clear() + hadoopFiles ++= files + } + + override def cleanup() { } + + override def restore() { + hadoopFiles.foreach { + case (time, files) => { + logInfo("Restoring Hadoop RDD for time " + time + " from files " + + files.mkString("[", ",", "]") ) + files + generatedRDDs += ((time, filesToRDD(files))) + } + } + } } } @@ -100,3 +155,4 @@ object FileInputDStream { def defaultFilter(path: Path): Boolean = !path.getName().startsWith(".") } + diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index d7ba7a5d17..4f6204f205 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -214,10 +214,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { //Thread.sleep(100) } val startTime = System.currentTimeMillis() - /*while (output.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { - logInfo("output.size = " + output.size + ", expectedOutput.size = " + expectedOutput.size) - Thread.sleep(100) - }*/ Thread.sleep(1000) val timeTaken = System.currentTimeMillis() - startTime assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") @@ -226,11 +222,9 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Verify whether data received by Spark Streaming was as expected logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) - logInfo("output") + logInfo("output, size = " + outputBuffer.size) outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output.size = " + expectedOutput.size) - logInfo("expected output") + logInfo("expected output, size = " + expectedOutput.size) expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) logInfo("--------------------------------") @@ -256,8 +250,13 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Set up the streaming context and input streams var ssc = new StreamingContext(master, framework, batchDuration) ssc.checkpoint(checkpointDir, checkpointInterval) - val filestream = ssc.textFileStream(testDir.toString) - var outputStream = new TestOutputStream(filestream, new ArrayBuffer[Seq[String]]) + val fileStream = ssc.textFileStream(testDir.toString) + val outputBuffer = new ArrayBuffer[Seq[Int]] + // Reduced over a large window to ensure that recovery from master failure + // requires reprocessing of all the files seen before the failure + val reducedStream = fileStream.map(_.toInt) + .reduceByWindow(_ + _, batchDuration * 30, batchDuration) + var outputStream = new TestOutputStream(reducedStream, outputBuffer) ssc.registerOutputStream(outputStream) ssc.start() @@ -266,31 +265,56 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { Thread.sleep(1000) for (i <- Seq(1, 2, 3)) { 
FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") - Thread.sleep(100) + // wait to make sure that the file is written such that it gets shown in the file listings + Thread.sleep(500) clock.addToTime(batchDuration.milliseconds) + // wait to make sure that FileInputDStream picks up this file only and not any other file + Thread.sleep(500) } - Thread.sleep(500) logInfo("Output = " + outputStream.output.mkString(",")) - assert(outputStream.output.size > 0) + assert(outputStream.output.size > 0, "No files processed before restart") ssc.stop() + for (i <- Seq(4, 5, 6)) { + FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") + Thread.sleep(1000) + } + // Restart stream computation from checkpoint and create more files to see whether // they are being processed logInfo("*********** RESTARTING ************") ssc = new StreamingContext(checkpointDir) ssc.start() clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - Thread.sleep(500) - for (i <- Seq(4, 5, 6)) { + for (i <- Seq(7, 8, 9)) { FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") - Thread.sleep(100) + Thread.sleep(500) clock.addToTime(batchDuration.milliseconds) + Thread.sleep(500) } - Thread.sleep(500) - outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[String]] - logInfo("Output = " + outputStream.output.mkString(",")) - assert(outputStream.output.size > 0) + Thread.sleep(1000) + assert(outputStream.output.size > 0, "No files processed after restart") ssc.stop() + + // Append the new output to the old buffer + outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[Int]] + outputBuffer ++= outputStream.output + + // Verify whether data received by Spark Streaming was as expected + val expectedOutput = Seq(1, 3, 6, 28, 36, 45) + logInfo("--------------------------------") + logInfo("output, size = " + outputBuffer.size) + outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("expected output, size = " + expectedOutput.size) + expectedOutput.foreach(x => logInfo("[" + x + "]")) + logInfo("--------------------------------") + + // Verify whether all the elements received are as expected + assert(outputBuffer.size === expectedOutput.size) + for (i <- 0 until outputBuffer.size) { + assert(outputBuffer(i).size === 1) + assert(outputBuffer(i).head === expectedOutput(i)) + } } } From 325297e5c31418f32deeb2a3cc52755094a11cea Mon Sep 17 00:00:00 2001 From: Mikhail Bautin Date: Tue, 22 Jan 2013 17:31:11 -0800 Subject: [PATCH 170/291] Add an Avro dependency to REPL to make it compile with Hadoop 2 --- pom.xml | 11 +++++++++++ repl/pom.xml | 10 ++++++++++ 2 files changed, 21 insertions(+) diff --git a/pom.xml b/pom.xml index 483b0f9595..3ea989a082 100644 --- a/pom.xml +++ b/pom.xml @@ -542,6 +542,17 @@ hadoop-client 2.0.0-mr1-cdh${cdh.version}
    + + + org.apache.avro + avro + 1.7.1.cloudera.2 + + + org.apache.avro + avro-ipc + 1.7.1.cloudera.2 +
    diff --git a/repl/pom.xml b/repl/pom.xml index 2fc9692969..2dc96beaf5 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -175,6 +175,16 @@ hadoop-client provided
    + + org.apache.avro + avro + provided + + + org.apache.avro + avro-ipc + provided +
    From 284993100022cc4bd43bf84a0be4dd91cf7a4ac0 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 22 Jan 2013 22:19:30 -0800 Subject: [PATCH 171/291] Eliminate CacheTracker. Replaces DAGScheduler's queries of CacheTracker with BlockManagerMaster queries. Adds CacheManager to locally coordinate computation of cached RDDs. --- core/src/main/scala/spark/CacheTracker.scala | 240 ------------------ core/src/main/scala/spark/RDD.scala | 2 +- core/src/main/scala/spark/SparkEnv.scala | 8 +- .../scala/spark/scheduler/DAGScheduler.scala | 24 +- .../scala/spark/storage/BlockManager.scala | 24 +- .../test/scala/spark/CacheTrackerSuite.scala | 131 ---------- 6 files changed, 18 insertions(+), 411 deletions(-) delete mode 100644 core/src/main/scala/spark/CacheTracker.scala delete mode 100644 core/src/test/scala/spark/CacheTrackerSuite.scala diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala deleted file mode 100644 index 86ad737583..0000000000 --- a/core/src/main/scala/spark/CacheTracker.scala +++ /dev/null @@ -1,240 +0,0 @@ -package spark - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet - -import akka.actor._ -import akka.dispatch._ -import akka.pattern.ask -import akka.remote._ -import akka.util.Duration -import akka.util.Timeout -import akka.util.duration._ - -import spark.storage.BlockManager -import spark.storage.StorageLevel -import util.{TimeStampedHashSet, MetadataCleaner, TimeStampedHashMap} - -private[spark] sealed trait CacheTrackerMessage - -private[spark] case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L) - extends CacheTrackerMessage -private[spark] case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L) - extends CacheTrackerMessage -private[spark] case class MemoryCacheLost(host: String) extends CacheTrackerMessage -private[spark] case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage -private[spark] case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage -private[spark] case object GetCacheStatus extends CacheTrackerMessage -private[spark] case object GetCacheLocations extends CacheTrackerMessage -private[spark] case object StopCacheTracker extends CacheTrackerMessage - -private[spark] class CacheTrackerActor extends Actor with Logging { - // TODO: Should probably store (String, CacheType) tuples - private val locs = new TimeStampedHashMap[Int, Array[List[String]]] - - /** - * A map from the slave's host name to its cache size. - */ - private val slaveCapacity = new HashMap[String, Long] - private val slaveUsage = new HashMap[String, Long] - - private val metadataCleaner = new MetadataCleaner("CacheTrackerActor", locs.clearOldValues) - - private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L) - private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L) - private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host) - - def receive = { - case SlaveCacheStarted(host: String, size: Long) => - slaveCapacity.put(host, size) - slaveUsage.put(host, 0) - sender ! true - - case RegisterRDD(rddId: Int, numPartitions: Int) => - logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions") - locs(rddId) = Array.fill[List[String]](numPartitions)(Nil) - sender ! 
true - - case AddedToCache(rddId, partition, host, size) => - slaveUsage.put(host, getCacheUsage(host) + size) - locs(rddId)(partition) = host :: locs(rddId)(partition) - sender ! true - - case DroppedFromCache(rddId, partition, host, size) => - slaveUsage.put(host, getCacheUsage(host) - size) - // Do a sanity check to make sure usage is greater than 0. - locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host) - sender ! true - - case MemoryCacheLost(host) => - logInfo("Memory cache lost on " + host) - for ((id, locations) <- locs) { - for (i <- 0 until locations.length) { - locations(i) = locations(i).filterNot(_ == host) - } - } - sender ! true - - case GetCacheLocations => - logInfo("Asked for current cache locations") - sender ! locs.map{case (rrdId, array) => (rrdId -> array.clone())} - - case GetCacheStatus => - val status = slaveCapacity.map { case (host, capacity) => - (host, capacity, getCacheUsage(host)) - }.toSeq - sender ! status - - case StopCacheTracker => - logInfo("Stopping CacheTrackerActor") - sender ! true - metadataCleaner.cancel() - context.stop(self) - } -} - -private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager) - extends Logging { - - // Tracker actor on the master, or remote reference to it on workers - val ip: String = System.getProperty("spark.master.host", "localhost") - val port: Int = System.getProperty("spark.master.port", "7077").toInt - val actorName: String = "CacheTracker" - - val timeout = 10.seconds - - var trackerActor: ActorRef = if (isMaster) { - val actor = actorSystem.actorOf(Props[CacheTrackerActor], name = actorName) - logInfo("Registered CacheTrackerActor actor") - actor - } else { - val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) - actorSystem.actorFor(url) - } - - // TODO: Consider removing this HashSet completely as locs CacheTrackerActor already - // keeps track of registered RDDs - val registeredRddIds = new TimeStampedHashSet[Int] - - // Remembers which splits are currently being loaded (on worker nodes) - val loading = new HashSet[String] - - val metadataCleaner = new MetadataCleaner("CacheTracker", registeredRddIds.clearOldValues) - - // Send a message to the trackerActor and get its result within a default timeout, or - // throw a SparkException if this fails. - def askTracker(message: Any): Any = { - try { - val future = trackerActor.ask(message)(timeout) - return Await.result(future, timeout) - } catch { - case e: Exception => - throw new SparkException("Error communicating with CacheTracker", e) - } - } - - // Send a one-way message to the trackerActor, to which we expect it to reply with true. - def communicate(message: Any) { - if (askTracker(message) != true) { - throw new SparkException("Error reply received from CacheTracker") - } - } - - // Registers an RDD (on master only) - def registerRDD(rddId: Int, numPartitions: Int) { - registeredRddIds.synchronized { - if (!registeredRddIds.contains(rddId)) { - logInfo("Registering RDD ID " + rddId + " with cache") - registeredRddIds += rddId - communicate(RegisterRDD(rddId, numPartitions)) - } - } - } - - // For BlockManager.scala only - def cacheLost(host: String) { - communicate(MemoryCacheLost(host)) - logInfo("CacheTracker successfully removed entries on " + host) - } - - // Get the usage status of slave caches. Each tuple in the returned sequence - // is in the form of (host name, capacity, usage). 
- def getCacheStatus(): Seq[(String, Long, Long)] = { - askTracker(GetCacheStatus).asInstanceOf[Seq[(String, Long, Long)]] - } - - // For BlockManager.scala only - def notifyFromBlockManager(t: AddedToCache) { - communicate(t) - } - - // Get a snapshot of the currently known locations - def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = { - askTracker(GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]] - } - - // Gets or computes an RDD split - def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel) - : Iterator[T] = { - val key = "rdd_%d_%d".format(rdd.id, split.index) - logInfo("Cache key is " + key) - blockManager.get(key) match { - case Some(cachedValues) => - // Split is in cache, so just return its values - logInfo("Found partition in cache!") - return cachedValues.asInstanceOf[Iterator[T]] - - case None => - // Mark the split as loading (unless someone else marks it first) - loading.synchronized { - if (loading.contains(key)) { - logInfo("Loading contains " + key + ", waiting...") - while (loading.contains(key)) { - try {loading.wait()} catch {case _ =>} - } - logInfo("Loading no longer contains " + key + ", so returning cached result") - // See whether someone else has successfully loaded it. The main way this would fail - // is for the RDD-level cache eviction policy if someone else has loaded the same RDD - // partition but we didn't want to make space for it. However, that case is unlikely - // because it's unlikely that two threads would work on the same RDD partition. One - // downside of the current code is that threads wait serially if this does happen. - blockManager.get(key) match { - case Some(values) => - return values.asInstanceOf[Iterator[T]] - case None => - logInfo("Whoever was loading " + key + " failed; we'll try it ourselves") - loading.add(key) - } - } else { - loading.add(key) - } - } - try { - // If we got here, we have to load the split - val elements = new ArrayBuffer[Any] - logInfo("Computing partition " + split) - elements ++= rdd.compute(split, context) - // Try to put this block in the blockManager - blockManager.put(key, elements, storageLevel, true) - return elements.iterator.asInstanceOf[Iterator[T]] - } finally { - loading.synchronized { - loading.remove(key) - loading.notifyAll() - } - } - } - } - - // Called by the Cache to report that an entry has been dropped from it - def dropEntry(rddId: Int, partition: Int) { - communicate(DroppedFromCache(rddId, partition, Utils.localHostName())) - } - - def stop() { - communicate(StopCacheTracker) - registeredRddIds.clear() - trackerActor = null - } -} diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index e0d2eabb1d..c79f34342f 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -176,7 +176,7 @@ abstract class RDD[T: ClassManifest]( if (isCheckpointed) { checkpointData.get.iterator(split, context) } else if (storageLevel != StorageLevel.NONE) { - SparkEnv.get.cacheTracker.getOrCompute[T](this, split, context, storageLevel) + SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel) } else { compute(split, context) } diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 41441720a7..a080194980 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -22,7 +22,7 @@ class SparkEnv ( val actorSystem: ActorSystem, val serializer: Serializer, val 
closureSerializer: Serializer, - val cacheTracker: CacheTracker, + val cacheManager: CacheManager, val mapOutputTracker: MapOutputTracker, val shuffleFetcher: ShuffleFetcher, val broadcastManager: BroadcastManager, @@ -39,7 +39,6 @@ class SparkEnv ( def stop() { httpFileServer.stop() mapOutputTracker.stop() - cacheTracker.stop() shuffleFetcher.stop() broadcastManager.stop() blockManager.stop() @@ -100,8 +99,7 @@ object SparkEnv extends Logging { val closureSerializer = instantiateClass[Serializer]( "spark.closure.serializer", "spark.JavaSerializer") - val cacheTracker = new CacheTracker(actorSystem, isMaster, blockManager) - blockManager.cacheTracker = cacheTracker + val cacheManager = new CacheManager(blockManager) val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster) @@ -122,7 +120,7 @@ object SparkEnv extends Logging { actorSystem, serializer, closureSerializer, - cacheTracker, + cacheManager, mapOutputTracker, shuffleFetcher, broadcastManager, diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 59f2099e91..03d173ac3b 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -69,8 +69,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with var cacheLocs = new HashMap[Int, Array[List[String]]] val env = SparkEnv.get - val cacheTracker = env.cacheTracker val mapOutputTracker = env.mapOutputTracker + val blockManagerMaster = env.blockManager.master val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back; // that's not going to be a realistic assumption in general @@ -95,11 +95,17 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with }.start() def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { + if (!cacheLocs.contains(rdd.id)) { + val blockIds = rdd.splits.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray + cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map { + locations => locations.map(_.ip).toList + }.toArray + } cacheLocs(rdd.id) } - def updateCacheLocs() { - cacheLocs = cacheTracker.getLocationsSnapshot() + def clearCacheLocs() { + cacheLocs.clear } /** @@ -126,7 +132,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of splits is unknown logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") - cacheTracker.registerRDD(rdd.id, rdd.splits.size) if (shuffleDep != None) { mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size) } @@ -148,8 +153,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with visited += r // Kind of ugly: need to register RDDs with the cache here since // we can't do it in its constructor because # of splits is unknown - logInfo("Registering parent RDD " + r.id + " (" + r.origin + ")") - cacheTracker.registerRDD(r.id, r.splits.size) for (dep <- r.dependencies) { dep match { case shufDep: ShuffleDependency[_,_] => @@ -250,7 +253,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val runId = nextRunId.getAndIncrement() val finalStage = newStage(finalRDD, None, runId) val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener) - updateCacheLocs() + clearCacheLocs() logInfo("Got job " + job.runId + " 
(" + callSite + ") with " + partitions.length + " output partitions") logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")") @@ -293,7 +296,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // on the failed node. if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) { logInfo("Resubmitting failed stages") - updateCacheLocs() + clearCacheLocs() val failed2 = failed.toArray failed.clear() for (stage <- failed2.sortBy(_.priority)) { @@ -443,7 +446,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with stage.shuffleDep.get.shuffleId, stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray) } - updateCacheLocs() + clearCacheLocs() if (stage.outputLocs.count(_ == Nil) != 0) { // Some tasks had failed; let's resubmit this stage // TODO: Lower-level scheduler should also deal with this @@ -519,8 +522,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray mapOutputTracker.registerMapOutputs(shuffleId, locs, true) } - cacheTracker.cacheLost(host) - updateCacheLocs() + clearCacheLocs() } } diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 7a8ac10cdd..e049565f48 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -16,7 +16,7 @@ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream -import spark.{CacheTracker, Logging, SizeEstimator, SparkEnv, SparkException, Utils} +import spark.{Logging, SizeEstimator, SparkEnv, SparkException, Utils} import spark.network._ import spark.serializer.Serializer import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap} @@ -71,9 +71,6 @@ class BlockManager( val connectionManagerId = connectionManager.id val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port) - // TODO: This will be removed after cacheTracker is removed from the code base. - var cacheTracker: CacheTracker = null - // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory // for receiving shuffle outputs) val maxBytesInFlight = @@ -662,10 +659,6 @@ class BlockManager( BlockManager.dispose(bytesAfterPut) - // TODO: This code will be removed when CacheTracker is gone. - if (blockId.startsWith("rdd")) { - notifyCacheTracker(blockId) - } logDebug("Put block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)) return size @@ -733,11 +726,6 @@ class BlockManager( } } - // TODO: This code will be removed when CacheTracker is gone. - if (blockId.startsWith("rdd")) { - notifyCacheTracker(blockId) - } - // If replication had started, then wait for it to finish if (level.replication > 1) { if (replicationFuture == null) { @@ -780,16 +768,6 @@ class BlockManager( } } - // TODO: This code will be removed when CacheTracker is gone. - private def notifyCacheTracker(key: String) { - if (cacheTracker != null) { - val rddInfo = key.split("_") - val rddId: Int = rddInfo(1).toInt - val partition: Int = rddInfo(2).toInt - cacheTracker.notifyFromBlockManager(spark.AddedToCache(rddId, partition, host)) - } - } - /** * Read a block consisting of a single object. 
*/ diff --git a/core/src/test/scala/spark/CacheTrackerSuite.scala b/core/src/test/scala/spark/CacheTrackerSuite.scala deleted file mode 100644 index 467605981b..0000000000 --- a/core/src/test/scala/spark/CacheTrackerSuite.scala +++ /dev/null @@ -1,131 +0,0 @@ -package spark - -import org.scalatest.FunSuite - -import scala.collection.mutable.HashMap - -import akka.actor._ -import akka.dispatch._ -import akka.pattern.ask -import akka.remote._ -import akka.util.Duration -import akka.util.Timeout -import akka.util.duration._ - -class CacheTrackerSuite extends FunSuite { - // Send a message to an actor and wait for a reply, in a blocking manner - private def ask(actor: ActorRef, message: Any): Any = { - try { - val timeout = 10.seconds - val future = actor.ask(message)(timeout) - return Await.result(future, timeout) - } catch { - case e: Exception => - throw new SparkException("Error communicating with actor", e) - } - } - - test("CacheTrackerActor slave initialization & cache status") { - //System.setProperty("spark.master.port", "1345") - val initialSize = 2L << 20 - - val actorSystem = ActorSystem("test") - val tracker = actorSystem.actorOf(Props[CacheTrackerActor]) - - assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true) - - assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 0L))) - - assert(ask(tracker, StopCacheTracker) === true) - - actorSystem.shutdown() - actorSystem.awaitTermination() - } - - test("RegisterRDD") { - //System.setProperty("spark.master.port", "1345") - val initialSize = 2L << 20 - - val actorSystem = ActorSystem("test") - val tracker = actorSystem.actorOf(Props[CacheTrackerActor]) - - assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true) - - assert(ask(tracker, RegisterRDD(1, 3)) === true) - assert(ask(tracker, RegisterRDD(2, 1)) === true) - - assert(getCacheLocations(tracker) === Map(1 -> List(Nil, Nil, Nil), 2 -> List(Nil))) - - assert(ask(tracker, StopCacheTracker) === true) - - actorSystem.shutdown() - actorSystem.awaitTermination() - } - - test("AddedToCache") { - //System.setProperty("spark.master.port", "1345") - val initialSize = 2L << 20 - - val actorSystem = ActorSystem("test") - val tracker = actorSystem.actorOf(Props[CacheTrackerActor]) - - assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true) - - assert(ask(tracker, RegisterRDD(1, 2)) === true) - assert(ask(tracker, RegisterRDD(2, 1)) === true) - - assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true) - assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true) - assert(ask(tracker, AddedToCache(2, 0, "host001", 3L << 10)) === true) - - assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L))) - - assert(getCacheLocations(tracker) === - Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001")))) - - assert(ask(tracker, StopCacheTracker) === true) - - actorSystem.shutdown() - actorSystem.awaitTermination() - } - - test("DroppedFromCache") { - //System.setProperty("spark.master.port", "1345") - val initialSize = 2L << 20 - - val actorSystem = ActorSystem("test") - val tracker = actorSystem.actorOf(Props[CacheTrackerActor]) - - assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true) - - assert(ask(tracker, RegisterRDD(1, 2)) === true) - assert(ask(tracker, RegisterRDD(2, 1)) === true) - - assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true) - assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true) - assert(ask(tracker, 
AddedToCache(2, 0, "host001", 3L << 10)) === true) - - assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L))) - assert(getCacheLocations(tracker) === - Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001")))) - - assert(ask(tracker, DroppedFromCache(1, 1, "host001", 2L << 11)) === true) - - assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 68608L))) - assert(getCacheLocations(tracker) === - Map(1 -> List(List("host001"),List()), 2 -> List(List("host001")))) - - assert(ask(tracker, StopCacheTracker) === true) - - actorSystem.shutdown() - actorSystem.awaitTermination() - } - - /** - * Helper function to get cacheLocations from CacheTracker - */ - def getCacheLocations(tracker: ActorRef): HashMap[Int, List[List[String]]] = { - val answer = ask(tracker, GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]] - answer.map { case (i, arr) => (i, arr.toList) } - } -} From 43e9ff959645e533bcfa0a5c31e62e32c7e9d0a6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 22 Jan 2013 22:47:26 -0800 Subject: [PATCH 172/291] Add test for driver hanging on exit (SPARK-530). --- core/src/test/scala/spark/DriverSuite.scala | 31 +++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 core/src/test/scala/spark/DriverSuite.scala diff --git a/core/src/test/scala/spark/DriverSuite.scala b/core/src/test/scala/spark/DriverSuite.scala new file mode 100644 index 0000000000..70a7c8bc2f --- /dev/null +++ b/core/src/test/scala/spark/DriverSuite.scala @@ -0,0 +1,31 @@ +package spark + +import java.io.File + +import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts +import org.scalatest.prop.TableDrivenPropertyChecks._ +import org.scalatest.time.SpanSugar._ + +class DriverSuite extends FunSuite with Timeouts { + test("driver should exit after finishing") { + // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" + val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) + forAll(masters) { (master: String) => + failAfter(10 seconds) { + Utils.execute(Seq("./run", "spark.DriverWithoutCleanup", master), new File(System.getenv("SPARK_HOME"))) + } + } + } +} + +/** + * Program that creates a Spark driver but doesn't call SparkContext.stop() or + * Sys.exit() after finishing. + */ +object DriverWithoutCleanup { + def main(args: Array[String]) { + val sc = new SparkContext(args(0), "DriverWithoutCleanup") + sc.parallelize(1 to 100, 4).count() + } +} \ No newline at end of file From bacade6caf7527737dc6f02b1c2ca9114e02d8bc Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Jan 2013 22:55:26 -0800 Subject: [PATCH 173/291] Modified BlockManagerId API to ensure zero duplicate objects. Fixed BlockManagerId testcase in BlockManagerTestSuite. 
--- .../scala/spark/scheduler/MapStatus.scala | 2 +- .../scala/spark/storage/BlockManager.scala | 2 +- .../scala/spark/storage/BlockManagerId.scala | 33 +++++++++++++++---- .../spark/storage/BlockManagerMessages.scala | 3 +- .../scala/spark/MapOutputTrackerSuite.scala | 22 ++++++------- .../spark/storage/BlockManagerSuite.scala | 18 +++++----- 6 files changed, 51 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/MapStatus.scala b/core/src/main/scala/spark/scheduler/MapStatus.scala index 4532d9497f..fae643f3a8 100644 --- a/core/src/main/scala/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/spark/scheduler/MapStatus.scala @@ -20,7 +20,7 @@ private[spark] class MapStatus(var address: BlockManagerId, var compressedSizes: } def readExternal(in: ObjectInput) { - address = new BlockManagerId(in) + address = BlockManagerId(in) compressedSizes = new Array[Byte](in.readInt()) in.readFully(compressedSizes) } diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 7a8ac10cdd..596a69c583 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -69,7 +69,7 @@ class BlockManager( implicit val futureExecContext = connectionManager.futureExecContext val connectionManagerId = connectionManager.id - val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port) + val blockManagerId = BlockManagerId(connectionManagerId.host, connectionManagerId.port) // TODO: This will be removed after cacheTracker is removed from the code base. var cacheTracker: CacheTracker = null diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala index 488679f049..26c98f2ac8 100644 --- a/core/src/main/scala/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -3,20 +3,35 @@ package spark.storage import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.util.concurrent.ConcurrentHashMap +/** + * This class represents a unique identifier for a BlockManager. + * The first two constructors of this class are made private to ensure that + * BlockManagerId objects can be created only using the factory method in + * [[spark.storage.BlockManager$]]. This allows de-duplication of id objects. + * Also, constructor parameters are private to ensure that parameters cannot + * be modified from outside this class.
+ */ +private[spark] class BlockManagerId private ( + private var ip_ : String, + private var port_ : Int + ) extends Externalizable { + + private def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) -private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable { def this() = this(null, 0) // For deserialization only - def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) + def ip = ip_ + + def port = port_ override def writeExternal(out: ObjectOutput) { - out.writeUTF(ip) - out.writeInt(port) + out.writeUTF(ip_) + out.writeInt(port_) } override def readExternal(in: ObjectInput) { - ip = in.readUTF() - port = in.readInt() + ip_ = in.readUTF() + port_ = in.readInt() } @throws(classOf[IOException]) @@ -35,6 +50,12 @@ private[spark] class BlockManagerId(var ip: String, var port: Int) extends Exter private[spark] object BlockManagerId { + def apply(ip: String, port: Int) = + getCachedBlockManagerId(new BlockManagerId(ip, port)) + + def apply(in: ObjectInput) = + getCachedBlockManagerId(new BlockManagerId(in)) + val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala index d73a9b790f..7437fc63eb 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala @@ -54,8 +54,7 @@ class UpdateBlockInfo( } override def readExternal(in: ObjectInput) { - blockManagerId = new BlockManagerId() - blockManagerId.readExternal(in) + blockManagerId = BlockManagerId(in) blockId = in.readUTF() storageLevel = new StorageLevel() storageLevel.readExternal(in) diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index d3dd3a8fa4..095f415978 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -47,13 +47,13 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000), + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("hostA", 1000), Array(compressedSize1000, compressedSize10000))) - tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000), + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("hostB", 1000), Array(compressedSize10000, compressedSize1000))) val statuses = tracker.getServerStatuses(10, 0) - assert(statuses.toSeq === Seq((new BlockManagerId("hostA", 1000), size1000), - (new BlockManagerId("hostB", 1000), size10000))) + assert(statuses.toSeq === Seq((BlockManagerId("hostA", 1000), size1000), + (BlockManagerId("hostB", 1000), size10000))) tracker.stop() } @@ -65,14 +65,14 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000), + tracker.registerMapOutput(10, 0, new 
MapStatus(BlockManagerId("hostA", 1000), Array(compressedSize1000, compressedSize1000, compressedSize1000))) - tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000), + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("hostB", 1000), Array(compressedSize10000, compressedSize1000, compressedSize1000))) // As if we had two simulatenous fetch failures - tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000)) - tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000)) + tracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000)) + tracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000)) // The remaining reduce task might try to grab the output dispite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the @@ -95,13 +95,13 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { val compressedSize1000 = MapOutputTracker.compressSize(1000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) masterTracker.registerMapOutput(10, 0, new MapStatus( - new BlockManagerId("hostA", 1000), Array(compressedSize1000))) + BlockManagerId("hostA", 1000), Array(compressedSize1000))) masterTracker.incrementGeneration() slaveTracker.updateGeneration(masterTracker.getGeneration) assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((new BlockManagerId("hostA", 1000), size1000))) + Seq((BlockManagerId("hostA", 1000), size1000))) - masterTracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000)) + masterTracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000)) masterTracker.incrementGeneration() slaveTracker.updateGeneration(masterTracker.getGeneration) intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index 8f86e3170e..a33d3324ba 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -82,16 +82,18 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("BlockManagerId object caching") { - val id1 = new StorageLevel(false, false, false, 3) - val id2 = new StorageLevel(false, false, false, 3) + val id1 = BlockManagerId("XXX", 1) + val id2 = BlockManagerId("XXX", 1) // this should return the same object as id1 + assert(id2 === id1, "id2 is not same as id1") + assert(id2.eq(id1), "id2 is not the same object as id1") val bytes1 = spark.Utils.serialize(id1) - val id1_ = spark.Utils.deserialize[StorageLevel](bytes1) + val id1_ = spark.Utils.deserialize[BlockManagerId](bytes1) val bytes2 = spark.Utils.serialize(id2) - val id2_ = spark.Utils.deserialize[StorageLevel](bytes2) - assert(id1_ === id1, "Deserialized id1 not same as original id1") - assert(id2_ === id2, "Deserialized id2 not same as original id1") - assert(id1_ === id2_, "Deserialized id1 not same as deserialized id2") - assert(id2_.eq(id1_), "Deserialized id2 not the same object as deserialized level1") + val id2_ = spark.Utils.deserialize[BlockManagerId](bytes2) + assert(id1_ === id1, "Deserialized id1 is not same as original id1") + assert(id1_.eq(id1), "Deserialized id1 is not the same object as original id1") + assert(id2_ === id2, "Deserialized id2 is not same as original id2") + assert(id2_.eq(id1), "Deserialized id2 is not the same object as original id1") } test("master + 1 manager interaction") { From 
5e11f1e51f17113abb8d3a5bc261af5ba5ffce94 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Jan 2013 23:42:53 -0800 Subject: [PATCH 174/291] Modified StorageLevel API to ensure zero duplicate objects. --- .../scala/spark/storage/BlockManager.scala | 5 +- .../scala/spark/storage/BlockMessage.scala | 2 +- .../scala/spark/storage/StorageLevel.scala | 47 ++++++++++++------- .../spark/storage/BlockManagerSuite.scala | 16 +++++-- 4 files changed, 44 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 596a69c583..ca7eb13ec8 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -191,7 +191,7 @@ class BlockManager( case level => val inMem = level.useMemory && memoryStore.contains(blockId) val onDisk = level.useDisk && diskStore.contains(blockId) - val storageLevel = new StorageLevel(onDisk, inMem, level.deserialized, level.replication) + val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication) val memSize = if (inMem) memoryStore.getSize(blockId) else 0L val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L (storageLevel, memSize, diskSize, info.tellMaster) @@ -760,8 +760,7 @@ class BlockManager( */ var cachedPeers: Seq[BlockManagerId] = null private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) { - val tLevel: StorageLevel = - new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) + val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) if (cachedPeers == null) { cachedPeers = master.getPeers(blockManagerId, level.replication - 1) } diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala index 3f234df654..30d7500e01 100644 --- a/core/src/main/scala/spark/storage/BlockMessage.scala +++ b/core/src/main/scala/spark/storage/BlockMessage.scala @@ -64,7 +64,7 @@ private[spark] class BlockMessage() { val booleanInt = buffer.getInt() val replication = buffer.getInt() - level = new StorageLevel(booleanInt, replication) + level = StorageLevel(booleanInt, replication) val dataLength = buffer.getInt() data = ByteBuffer.allocate(dataLength) diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index e3544e5aae..f2535ae5ae 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -7,25 +7,30 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} * whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory * in a serialized format, and whether to replicate the RDD partitions on multiple nodes. * The [[spark.storage.StorageLevel$]] singleton object contains some static constants for - * commonly useful storage levels. + * commonly useful storage levels. The recommended method to create your own storage level + * object is to use `StorageLevel.apply(...)` from the singleton object. */ class StorageLevel( - var useDisk: Boolean, - var useMemory: Boolean, - var deserialized: Boolean, - var replication: Int = 1) + private var useDisk_ : Boolean, + private var useMemory_ : Boolean, + private var deserialized_ : Boolean, + private var replication_ : Int = 1) extends Externalizable { // TODO: Also add fields for caching priority, dataset ID, and flushing. 
- - assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes") - - def this(flags: Int, replication: Int) { + private def this(flags: Int, replication: Int) { this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) } def this() = this(false, true, false) // For deserialization + def useDisk = useDisk_ + def useMemory = useMemory_ + def deserialized = deserialized_ + def replication = replication_ + + assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes") + override def clone(): StorageLevel = new StorageLevel( this.useDisk, this.useMemory, this.deserialized, this.replication) @@ -43,13 +48,13 @@ class StorageLevel( def toInt: Int = { var ret = 0 - if (useDisk) { + if (useDisk_) { ret |= 4 } - if (useMemory) { + if (useMemory_) { ret |= 2 } - if (deserialized) { + if (deserialized_) { ret |= 1 } return ret @@ -57,15 +62,15 @@ class StorageLevel( override def writeExternal(out: ObjectOutput) { out.writeByte(toInt) - out.writeByte(replication) + out.writeByte(replication_) } override def readExternal(in: ObjectInput) { val flags = in.readByte() - useDisk = (flags & 4) != 0 - useMemory = (flags & 2) != 0 - deserialized = (flags & 1) != 0 - replication = in.readByte() + useDisk_ = (flags & 4) != 0 + useMemory_ = (flags & 2) != 0 + deserialized_ = (flags & 1) != 0 + replication_ = in.readByte() } @throws(classOf[IOException]) @@ -91,6 +96,14 @@ object StorageLevel { val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false) val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2) + /** Create a new StorageLevel object */ + def apply(useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, replication: Int = 1) = + getCachedStorageLevel(new StorageLevel(useDisk, useMemory, deserialized, replication)) + + /** Create a new StorageLevel object from its integer representation */ + def apply(flags: Int, replication: Int) = + getCachedStorageLevel(new StorageLevel(flags, replication)) + private[spark] val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index a33d3324ba..a1aeb12f25 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -69,23 +69,29 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("StorageLevel object caching") { - val level1 = new StorageLevel(false, false, false, 3) - val level2 = new StorageLevel(false, false, false, 3) + val level1 = StorageLevel(false, false, false, 3) + val level2 = StorageLevel(false, false, false, 3) // this should return the same object as level1 + val level3 = StorageLevel(false, false, false, 2) // this should return a different object + assert(level2 === level1, "level2 is not same as level1") + assert(level2.eq(level1), "level2 is not the same object as level1") + assert(level3 != level1, "level3 is same as level1") val bytes1 = spark.Utils.serialize(level1) val level1_ = spark.Utils.deserialize[StorageLevel](bytes1) val bytes2 = spark.Utils.serialize(level2) val level2_ = spark.Utils.deserialize[StorageLevel](bytes2) assert(level1_ === level1, "Deserialized level1 not same as original level1") - assert(level2_ === level2, "Deserialized level2 not same as original level1") - assert(level1_ === level2_, "Deserialized level1 not same as deserialized 
level2") - assert(level2_.eq(level1_), "Deserialized level2 not the same object as deserialized level1") + assert(level1_.eq(level1), "Deserialized level1 not the same object as original level2") + assert(level2_ === level2, "Deserialized level2 not same as original level2") + assert(level2_.eq(level1), "Deserialized level2 not the same object as original level1") } test("BlockManagerId object caching") { val id1 = BlockManagerId("XXX", 1) val id2 = BlockManagerId("XXX", 1) // this should return the same object as id1 + val id3 = BlockManagerId("XXX", 2) // this should return a different object assert(id2 === id1, "id2 is not same as id1") assert(id2.eq(id1), "id2 is not the same object as id1") + assert(id3 != id1, "id3 is same as id1") val bytes1 = spark.Utils.serialize(id1) val id1_ = spark.Utils.deserialize[BlockManagerId](bytes1) val bytes2 = spark.Utils.serialize(id2) From 155f31398dc83ecb88b4b3e07849a2a8a0a6592f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 23 Jan 2013 01:10:26 -0800 Subject: [PATCH 175/291] Made StorageLevel constructor private, and added StorageLevels.create() to the Java API. Updates scala and java programming guides. --- core/src/main/scala/spark/api/java/StorageLevels.java | 11 +++++++++++ core/src/main/scala/spark/storage/StorageLevel.scala | 6 +++--- docs/java-programming-guide.md | 3 ++- docs/scala-programming-guide.md | 3 ++- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/spark/api/java/StorageLevels.java b/core/src/main/scala/spark/api/java/StorageLevels.java index 722af3c06c..5e5845ac3a 100644 --- a/core/src/main/scala/spark/api/java/StorageLevels.java +++ b/core/src/main/scala/spark/api/java/StorageLevels.java @@ -17,4 +17,15 @@ public class StorageLevels { public static final StorageLevel MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2); public static final StorageLevel MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, 1); public static final StorageLevel MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2); + + /** + * Create a new StorageLevel object. + * @param useDisk saved to disk, if true + * @param useMemory saved to memory, if true + * @param deserialized saved as deserialized objects, if true + * @param replication replication factor + */ + public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized, int replication) { + return StorageLevel.apply(useDisk, useMemory, deserialized, replication); + } } diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index f2535ae5ae..45d6ea2656 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -7,10 +7,10 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} * whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory * in a serialized format, and whether to replicate the RDD partitions on multiple nodes. * The [[spark.storage.StorageLevel$]] singleton object contains some static constants for - * commonly useful storage levels. The recommended method to create your own storage level - * object is to use `StorageLevel.apply(...)` from the singleton object. + * commonly useful storage levels. To create your own storage level object, use the factor method + * of the singleton object (`StorageLevel(...)`). 
*/ -class StorageLevel( +class StorageLevel private( private var useDisk_ : Boolean, private var useMemory_ : Boolean, private var deserialized_ : Boolean, diff --git a/docs/java-programming-guide.md b/docs/java-programming-guide.md index 188ca4995e..37a906ea1c 100644 --- a/docs/java-programming-guide.md +++ b/docs/java-programming-guide.md @@ -75,7 +75,8 @@ class has a single abstract method, `call()`, that must be implemented. ## Storage Levels RDD [storage level](scala-programming-guide.html#rdd-persistence) constants, such as `MEMORY_AND_DISK`, are -declared in the [spark.api.java.StorageLevels](api/core/index.html#spark.api.java.StorageLevels) class. +declared in the [spark.api.java.StorageLevels](api/core/index.html#spark.api.java.StorageLevels) class. To +define your own storage level, you can use StorageLevels.create(...). # Other Features diff --git a/docs/scala-programming-guide.md b/docs/scala-programming-guide.md index 7350eca837..301b330a79 100644 --- a/docs/scala-programming-guide.md +++ b/docs/scala-programming-guide.md @@ -301,7 +301,8 @@ We recommend going through the following process to select one: * Use the replicated storage levels if you want fast fault recovery (e.g. if using Spark to serve requests from a web application). *All* the storage levels provide full fault tolerance by recomputing lost data, but the replicated ones let you continue running tasks on the RDD without waiting to recompute a lost partition. - + +If you want to define your own storage level (say, with replication factor of 3 instead of 2), then use the function factor method `apply()` of the [`StorageLevel`](api/core/index.html#spark.storage.StorageLevel$) singleton object. # Shared Variables From 9a27062260490336a3bfa97c6efd39b1e7e81573 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 01:34:44 -0800 Subject: [PATCH 176/291] Force generation increment after shuffle map stage --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 39a1e6d6c6..d8a9049e81 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -445,9 +445,16 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with logInfo("waiting: " + waiting) logInfo("failed: " + failed) if (stage.shuffleDep != None) { + // We supply true to increment the generation number here in case this is a + // recomputation of the map outputs. In that case, some nodes may have cached + // locations with holes (from when we detected the error) and will need the + // generation incremented to refetch them. + // TODO: Only increment the generation number if this is not the first time + // we registered these map outputs. 
mapOutputTracker.registerMapOutputs( stage.shuffleDep.get.shuffleId, - stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray) + stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, + true) } updateCacheLocs() if (stage.outputLocs.count(_ == Nil) != 0) { From d209b6b7641059610f734414ea05e0494b5510b0 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 01:35:14 -0800 Subject: [PATCH 177/291] Extra debugging from hostLost() --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index d8a9049e81..740aec2e61 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -528,7 +528,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration) if (!failedGeneration.contains(host) || failedGeneration(host) < currentGeneration) { failedGeneration(host) = currentGeneration - logInfo("Host lost: " + host) + logInfo("Host lost: " + host + " (generation " + currentGeneration + ")") env.blockManager.master.notifyADeadHost(host) // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { @@ -541,6 +541,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } cacheTracker.cacheLost(host) updateCacheLocs() + } else { + logDebug("Additional host lost message for " + host + + "(generation " + currentGeneration + ")") } } From 0b506dd2ecec909cd514143389d0846db2d194ed Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 01:37:51 -0800 Subject: [PATCH 178/291] Add tests of various node failure scenarios. 
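
The failure tests announced here exercise the generation-based bookkeeping added to the scheduler above: a host loss is acted on only once per generation, and later reports for the same or an older generation are just logged. Reduced to its core, that check looks like the sketch below (hypothetical `HostFailureTracker`, heavily simplified from the DAGScheduler logic).

    import scala.collection.mutable.HashMap

    // Act on the first "host lost" report seen for a given generation; treat
    // repeats carrying the same or an older generation as stale duplicates.
    class HostFailureTracker {
      private val failedGeneration = new HashMap[String, Long]
      private var currentGeneration = 0L

      def incrementGeneration(): Long = { currentGeneration += 1; currentGeneration }

      // Returns true if this report should trigger unregistering map outputs and
      // resubmitting stages; false if it only deserves a debug log line.
      def hostLost(host: String, maybeGeneration: Option[Long] = None): Boolean = {
        val generation = maybeGeneration.getOrElse(currentGeneration)
        if (!failedGeneration.contains(host) || failedGeneration(host) < generation) {
          failedGeneration(host) = generation
          true
        } else {
          false
        }
      }
    }

Bumping the generation after a shuffle-map stage finishes (the extra `true` argument in the hunk above) forces reducers to refetch output locations instead of trusting cached locations that may contain holes from the failure.
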
--- .../test/scala/spark/DistributedSuite.scala | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index cacc2796b6..0d6b265e54 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -188,4 +188,76 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter val values = sc.parallelize(1 to 2, 2).map(x => System.getenv("TEST_VAR")).collect() assert(values.toSeq === Seq("TEST_VALUE", "TEST_VALUE")) } + + test("recover from node failures") { + import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + DistributedSuite.amMaster = true + sc = new SparkContext(clusterUrl, "test") + val data = sc.parallelize(Seq(true, true), 2) + val singleton = sc.parallelize(Seq(true), 1) + assert(data.count === 2) // force executors to start + val masterId = SparkEnv.get.blockManager.blockManagerId + assert(data.map(markNodeIfIdentity).collect.size === 2) + assert(data.map(failOnMarkedIdentity).collect.size === 2) + } + + test("recover from repeated node failures during shuffle-map") { + import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + DistributedSuite.amMaster = true + sc = new SparkContext(clusterUrl, "test") + for (i <- 1 to 3) { + val data = sc.parallelize(Seq(true, false), 2) + val singleton = sc.parallelize(Seq(false), 1) + assert(data.count === 2) + assert(data.map(markNodeIfIdentity).collect.size === 2) + assert(data.map(failOnMarkedIdentity).map(x => x -> x).groupByKey.count === 2) + } + } + + test("recover from repeated node failures during shuffle-reduce") { + import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + DistributedSuite.amMaster = true + sc = new SparkContext(clusterUrl, "test") + for (i <- 1 to 3) { + val data = sc.parallelize(Seq(true, true), 2) + val singleton = sc.parallelize(Seq(false), 1) + assert(data.count === 2) + assert(data.map(markNodeIfIdentity).collect.size === 2) + // This relies on mergeCombiners being used to perform the actual reduce for this + // test to actually be testing what it claims. + val grouped = data.map(x => x -> x).combineByKey( + x => x, + (x: Boolean, y: Boolean) => x, + (x: Boolean, y: Boolean) => failOnMarkedIdentity(x) + ) + assert(grouped.collect.size === 1) + } + } +} + +object DistributedSuite { + // Indicates whether this JVM is marked for failure. + var mark = false + + // Set by test to remember if we are in the driver program so we can assert + // that we are not. + var amMaster = false + + // Act like an identity function, but if the argument is true, set mark to true. + def markNodeIfIdentity(item: Boolean): Boolean = { + if (item) { + assert(!amMaster) + mark = true + } + item + } + + // Act like an identity function, but if mark was set to true previously, fail, + // crashing the entire JVM. + def failOnMarkedIdentity(item: Boolean): Boolean = { + if (mark) { + System.exit(42) + } + item + } } From 79d55700ce2559051ac61cc2fb72a67fd7035926 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 23 Jan 2013 01:57:09 -0800 Subject: [PATCH 179/291] One more fix. Made even default constructor of BlockManagerId private to prevent such problems in the future. 
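
The remaining hole this fix closes is deserialization: if the no-argument constructor stays accessible, readers can build ids that bypass the cache. Making it private and funneling ObjectInput through the factory keeps every instance canonical. A compact sketch of that shape, extending the earlier interning example with a hypothetical `Endpoint` class (not the Spark source):

    import java.io.{Externalizable, ObjectInput, ObjectOutput}
    import java.util.concurrent.ConcurrentHashMap

    // Even the no-arg constructor used for deserialization is private, so the
    // only public entry points are factory methods that consult the cache.
    class Endpoint private (private var host_ : String, private var port_ : Int)
      extends Externalizable {

      private def this() = this(null, 0)   // for deserialization only

      def host = host_
      def port = port_

      override def writeExternal(out: ObjectOutput) {
        out.writeUTF(host_)
        out.writeInt(port_)
      }

      override def readExternal(in: ObjectInput) {
        host_ = in.readUTF()
        port_ = in.readInt()
      }

      override def equals(other: Any): Boolean = other match {
        case that: Endpoint => that.host == host && that.port == port
        case _ => false
      }
      override def hashCode: Int = host.hashCode * 41 + port
    }

    object Endpoint {
      private val cache = new ConcurrentHashMap[Endpoint, Endpoint]()

      private def cached(e: Endpoint): Endpoint = {
        val existing = cache.putIfAbsent(e, e)
        if (existing == null) e else existing
      }

      def apply(host: String, port: Int): Endpoint = cached(new Endpoint(host, port))

      // Deserialize via the factory so the result is the canonical cached copy.
      def apply(in: ObjectInput): Endpoint = {
        val obj = new Endpoint()
        obj.readExternal(in)
        cached(obj)
      }
    }
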
--- .../src/main/scala/spark/storage/BlockManagerId.scala | 11 ++++++----- .../scala/spark/storage/BlockManagerMessages.scala | 3 +-- core/src/main/scala/spark/storage/StorageLevel.scala | 7 +++++++ 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala index 26c98f2ac8..abb8b45a1f 100644 --- a/core/src/main/scala/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -16,9 +16,7 @@ private[spark] class BlockManagerId private ( private var port_ : Int ) extends Externalizable { - private def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) - - def this() = this(null, 0) // For deserialization only + private def this() = this(null, 0) // For deserialization only def ip = ip_ @@ -53,8 +51,11 @@ private[spark] object BlockManagerId { def apply(ip: String, port: Int) = getCachedBlockManagerId(new BlockManagerId(ip, port)) - def apply(in: ObjectInput) = - getCachedBlockManagerId(new BlockManagerId(in)) + def apply(in: ObjectInput) = { + val obj = new BlockManagerId() + obj.readExternal(in) + getCachedBlockManagerId(obj) + } val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala index 7437fc63eb..30483b0b37 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala @@ -56,8 +56,7 @@ class UpdateBlockInfo( override def readExternal(in: ObjectInput) { blockManagerId = BlockManagerId(in) blockId = in.readUTF() - storageLevel = new StorageLevel() - storageLevel.readExternal(in) + storageLevel = StorageLevel(in) memSize = in.readInt() diskSize = in.readInt() } diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index 45d6ea2656..d1d1c61c1c 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -104,6 +104,13 @@ object StorageLevel { def apply(flags: Int, replication: Int) = getCachedStorageLevel(new StorageLevel(flags, replication)) + /** Read StorageLevel object from ObjectInput stream */ + def apply(in: ObjectInput) = { + val obj = new StorageLevel() + obj.readExternal(in) + getCachedStorageLevel(obj) + } + private[spark] val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() From 666ce431aa03239d580a8c78b3a2f34a851eb413 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 23 Jan 2013 03:15:36 -0800 Subject: [PATCH 180/291] Added support for rescheduling unprocessed batches on master failure. 
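
Recovery here rests on remembering which batch times still had unfinished jobs when the driver died: the job manager records jobs per batch time, the checkpoint saves those pending times, and on restart the scheduler regenerates the work for each of them before restarting the timer. A minimal sketch of that bookkeeping, with a hypothetical `PendingBatchTracker` and jobs represented as plain strings:

    import scala.collection.mutable.{ArrayBuffer, HashMap}

    // Track which batch times still have unfinished jobs. The set of pending
    // times is what a checkpoint would persist; on recovery each pending time
    // is regenerated before normal scheduling resumes.
    class PendingBatchTracker {
      private val jobs = new HashMap[Long, ArrayBuffer[String]]

      def jobStarted(time: Long, job: String) {
        jobs.synchronized {
          jobs.getOrElseUpdate(time, new ArrayBuffer[String]) += job
        }
      }

      def jobFinished(time: Long, job: String) {
        jobs.synchronized {
          jobs.get(time) match {
            case Some(remaining) =>
              remaining -= job
              if (remaining.isEmpty) jobs -= time   // batch fully processed
            case None => // nothing recorded for this time
          }
        }
      }

      // Batch times that were received but not completely processed yet.
      def pendingTimes: Array[Long] = jobs.synchronized { jobs.keySet.toArray }
    }

Combined with a replayable input source, re-running the pending times means an interrupted batch is processed again rather than silently dropped, which is why the updated tests below expect the last pre-failure batch to show up twice after recovery.
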
--- .../scala/spark/streaming/Checkpoint.scala | 3 +- .../scala/spark/streaming/JobManager.scala | 30 ++++++++++++++++++- .../scala/spark/streaming/Scheduler.scala | 5 +++- .../spark/streaming/StreamingContext.scala | 4 +-- .../spark/streaming/InputStreamsSuite.scala | 23 +++++++++----- 5 files changed, 53 insertions(+), 12 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index 2f3adb39c2..b9eb7f8ec4 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -17,7 +17,8 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val jars = ssc.sc.jars val graph = ssc.graph val checkpointDir = ssc.checkpointDir - val checkpointDuration: Duration = ssc.checkpointDuration + val checkpointDuration = ssc.checkpointDuration + val pendingTimes = ssc.scheduler.jobManager.getPendingTimes() def validate() { assert(master != null, "Checkpoint.master is null") diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index 3b910538e0..5acdd01e58 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -3,6 +3,8 @@ package spark.streaming import spark.Logging import spark.SparkEnv import java.util.concurrent.Executors +import collection.mutable.HashMap +import collection.mutable.ArrayBuffer private[streaming] @@ -19,15 +21,41 @@ class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { case e: Exception => logError("Running " + job + " failed", e) } + clearJob(job) } } initLogging() val jobExecutor = Executors.newFixedThreadPool(numThreads) - + val jobs = new HashMap[Time, ArrayBuffer[Job]] + def runJob(job: Job) { + jobs.synchronized { + jobs.getOrElseUpdate(job.time, new ArrayBuffer[Job]) += job + } jobExecutor.execute(new JobHandler(ssc, job)) logInfo("Added " + job + " to queue") } + + private def clearJob(job: Job) { + jobs.synchronized { + val jobsOfTime = jobs.get(job.time) + if (jobsOfTime.isDefined) { + jobsOfTime.get -= job + if (jobsOfTime.get.isEmpty) { + jobs -= job.time + } + } else { + throw new Exception("Job finished for time " + job.time + + " but time does not exist in jobs") + } + } + } + + def getPendingTimes(): Array[Time] = { + jobs.synchronized { + jobs.keySet.toArray + } + } } diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index c04ed37de8..b77986a3ba 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -35,10 +35,13 @@ class Scheduler(ssc: StreamingContext) extends Logging { // either set the manual clock to the last checkpointed time, // or if the property is defined set it to that time if (clock.isInstanceOf[ManualClock]) { - val lastTime = ssc.getInitialCheckpoint.checkpointTime.milliseconds + val lastTime = ssc.initialCheckpoint.checkpointTime.milliseconds val jumpTime = System.getProperty("spark.streaming.manualClock.jump", "0").toLong clock.asInstanceOf[ManualClock].setTime(lastTime + jumpTime) } + // Reschedule the batches that were received but not processed before failure + ssc.initialCheckpoint.pendingTimes.foreach(time => generateRDDs(time)) + // Restart the timer timer.restart(graph.zeroTime.milliseconds) logInfo("Scheduler's timer 
restarted") } else { diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 2cf00e3baa..5781b1cc72 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -133,7 +133,7 @@ class StreamingContext private ( } } - protected[streaming] def getInitialCheckpoint(): Checkpoint = { + protected[streaming] def initialCheckpoint: Checkpoint = { if (isCheckpointPresent) cp_ else null } @@ -367,7 +367,7 @@ class StreamingContext private ( } /** - * Sstops the execution of the streams. + * Stops the execution of the streams. */ def stop() { try { diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index 4f6204f205..34e51e9562 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -44,7 +44,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.master.port") } - + /* test("network input stream") { // Start the server testServer = new TestServer(testPort) @@ -236,8 +236,8 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { assert(output(i).head.toString === expectedOutput(i)) } } - - test("file input stream with checkpoint") { + */ + test("file input stream with master failure") { // Create a temporary directory testDir = { var temp = File.createTempFile(".temp.", Random.nextInt().toString) @@ -251,11 +251,17 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { var ssc = new StreamingContext(master, framework, batchDuration) ssc.checkpoint(checkpointDir, checkpointInterval) val fileStream = ssc.textFileStream(testDir.toString) - val outputBuffer = new ArrayBuffer[Seq[Int]] - // Reduced over a large window to ensure that recovery from master failure + // Making value 3 take large time to process, to ensure that the master + // shuts down in the middle of processing the 3rd batch + val mappedStream = fileStream.map(s => { + val i = s.toInt + if (i == 3) Thread.sleep(1000) + i + }) + // Reducing over a large window to ensure that recovery from master failure // requires reprocessing of all the files seen before the failure - val reducedStream = fileStream.map(_.toInt) - .reduceByWindow(_ + _, batchDuration * 30, batchDuration) + val reducedStream = mappedStream.reduceByWindow(_ + _, batchDuration * 30, batchDuration) + val outputBuffer = new ArrayBuffer[Seq[Int]] var outputStream = new TestOutputStream(reducedStream, outputBuffer) ssc.registerOutputStream(outputStream) ssc.start() @@ -275,6 +281,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { assert(outputStream.output.size > 0, "No files processed before restart") ssc.stop() + // Create files while the master is down for (i <- Seq(4, 5, 6)) { FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") Thread.sleep(1000) @@ -293,6 +300,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { Thread.sleep(500) } Thread.sleep(1000) + logInfo("Output = " + outputStream.output.mkString(",")) assert(outputStream.output.size > 0, "No files processed after restart") ssc.stop() @@ -316,6 +324,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { 
assert(outputBuffer(i).head === expectedOutput(i)) } } + } From 9c8ff1e55fb97980e7f0bb7f305c1ed0e59b749e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 23 Jan 2013 07:31:49 -0800 Subject: [PATCH 181/291] Fixed checkpoint testcases --- streaming/src/test/java/JavaAPISuite.java | 23 +-- .../spark/streaming/CheckpointSuite.scala | 115 +++++++++++- .../spark/streaming/InputStreamsSuite.scala | 163 +----------------- 3 files changed, 129 insertions(+), 172 deletions(-) diff --git a/streaming/src/test/java/JavaAPISuite.java b/streaming/src/test/java/JavaAPISuite.java index c84e7331c7..7a189d85b4 100644 --- a/streaming/src/test/java/JavaAPISuite.java +++ b/streaming/src/test/java/JavaAPISuite.java @@ -45,7 +45,7 @@ public class JavaAPISuite implements Serializable { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.master.port"); } - + /* @Test public void testCount() { List> inputData = Arrays.asList( @@ -434,7 +434,7 @@ public class JavaAPISuite implements Serializable { assertOrderInvariantEquals(expected, result); } - + */ /* * Performs an order-invariant comparison of lists representing two RDD streams. This allows * us to account for ordering variation within individual RDD's which occurs during windowing. @@ -450,7 +450,7 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(expected, actual); } - + /* // PairDStream Functions @Test public void testPairFilter() { @@ -897,7 +897,7 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(expected, result); } - + */ @Test public void testCheckpointMasterRecovery() throws InterruptedException { List> inputData = Arrays.asList( @@ -911,7 +911,6 @@ public class JavaAPISuite implements Serializable { Arrays.asList(1,4), Arrays.asList(8,7)); - File tempDir = Files.createTempDir(); ssc.checkpoint(tempDir.getAbsolutePath(), new Duration(1000)); @@ -927,14 +926,16 @@ public class JavaAPISuite implements Serializable { assertOrderInvariantEquals(expectedInitial, initialResult); Thread.sleep(1000); - ssc.stop(); + ssc = new JavaStreamingContext(tempDir.getAbsolutePath()); - ssc.start(); - List> finalResult = JavaCheckpointTestUtils.runStreams(ssc, 2, 2); - assertOrderInvariantEquals(expectedFinal, finalResult); + // Tweak to take into consideration that the last batch before failure + // will be re-processed after recovery + List> finalResult = JavaCheckpointTestUtils.runStreams(ssc, 2, 3); + assertOrderInvariantEquals(expectedFinal, finalResult.subList(1, 3)); } + /** TEST DISABLED: Pending a discussion about checkpoint() semantics with TD @Test public void testCheckpointofIndividualStream() throws InterruptedException { @@ -963,7 +964,7 @@ public class JavaAPISuite implements Serializable { assertOrderInvariantEquals(expected, result1); } */ - + /* // Input stream tests. These mostly just test that we can instantiate a given InputStream with // Java arguments and assign it to a JavaDStream without producing type errors. Testing of the // InputStream functionality is deferred to the existing Scala tests. 
@@ -1025,5 +1026,5 @@ public class JavaAPISuite implements Serializable { public void testFileStream() { JavaPairDStream foo = ssc.fileStream("/tmp/foo"); - } + }*/ } diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index 58da4ee539..04ccca4c01 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -7,6 +7,8 @@ import org.scalatest.BeforeAndAfter import org.apache.commons.io.FileUtils import collection.mutable.{SynchronizedBuffer, ArrayBuffer} import util.{Clock, ManualClock} +import scala.util.Random +import com.google.common.io.Files class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { @@ -32,7 +34,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { override def actuallyWait = true - test("basic stream+rdd recovery") { + test("basic rdd checkpoints + dstream graph checkpoint recovery") { assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second") assert(checkpointInterval === batchDuration, "checkpointInterval for this test much be same as batchDuration") @@ -117,7 +119,10 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { ssc = null } - test("map and reduceByKey") { + // This tests whether the systm can recover from a master failure with simple + // non-stateful operations. This assumes as reliable, replayable input + // source - TestInputDStream. + test("recovery with map and reduceByKey operations") { testCheckpointedOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ), (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), @@ -126,7 +131,11 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { ) } - test("reduceByKeyAndWindowInv") { + + // This tests whether the ReduceWindowedDStream's RDD checkpoints works correctly such + // that the system can recover from a master failure. This assumes as reliable, + // replayable input source - TestInputDStream. + test("recovery with invertible reduceByKeyAndWindow operation") { val n = 10 val w = 4 val input = (1 to n).map(_ => Seq("a")).toSeq @@ -139,7 +148,11 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { testCheckpointedOperation(input, operation, output, 7) } - test("updateStateByKey") { + + // This tests whether the StateDStream's RDD checkpoints works correctly such + // that the system can recover from a master failure. This assumes as reliable, + // replayable input source - TestInputDStream. + test("recovery with updateStateByKey operation") { val input = (1 to 10).map(_ => Seq("a")).toSeq val output = (1 to 10).map(x => Seq(("a", x))).toSeq val operation = (st: DStream[String]) => { @@ -154,11 +167,99 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { testCheckpointedOperation(input, operation, output, 7) } + // This tests whether file input stream remembers what files were seen before + // the master failure and uses them again to process a large window operatoin. + // It also tests whether batches, whose processing was incomplete due to the + // failure, are re-processed or not. 
+ test("recovery with file input stream") { + // Set up the streaming context and input streams + val testDir = Files.createTempDir() + var ssc = new StreamingContext(master, framework, batchDuration) + ssc.checkpoint(checkpointDir, checkpointInterval) + val fileStream = ssc.textFileStream(testDir.toString) + // Making value 3 take large time to process, to ensure that the master + // shuts down in the middle of processing the 3rd batch + val mappedStream = fileStream.map(s => { + val i = s.toInt + if (i == 3) Thread.sleep(1000) + i + }) + // Reducing over a large window to ensure that recovery from master failure + // requires reprocessing of all the files seen before the failure + val reducedStream = mappedStream.reduceByWindow(_ + _, batchDuration * 30, batchDuration) + val outputBuffer = new ArrayBuffer[Seq[Int]] + var outputStream = new TestOutputStream(reducedStream, outputBuffer) + ssc.registerOutputStream(outputStream) + ssc.start() + + // Create files and advance manual clock to process them + var clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + Thread.sleep(1000) + for (i <- Seq(1, 2, 3)) { + FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") + // wait to make sure that the file is written such that it gets shown in the file listings + Thread.sleep(500) + clock.addToTime(batchDuration.milliseconds) + // wait to make sure that FileInputDStream picks up this file only and not any other file + Thread.sleep(500) + } + logInfo("Output = " + outputStream.output.mkString(",")) + assert(outputStream.output.size > 0, "No files processed before restart") + ssc.stop() + + // Create files while the master is down + for (i <- Seq(4, 5, 6)) { + FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") + Thread.sleep(1000) + } + + // Restart stream computation from checkpoint and create more files to see whether + // they are being processed + logInfo("*********** RESTARTING ************") + ssc = new StreamingContext(checkpointDir) + ssc.start() + clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + for (i <- Seq(7, 8, 9)) { + FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") + Thread.sleep(500) + clock.addToTime(batchDuration.milliseconds) + Thread.sleep(500) + } + Thread.sleep(1000) + logInfo("Output = " + outputStream.output.mkString(",")) + assert(outputStream.output.size > 0, "No files processed after restart") + ssc.stop() + + // Append the new output to the old buffer + outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[Int]] + outputBuffer ++= outputStream.output + + // Verify whether data received by Spark Streaming was as expected + val expectedOutput = Seq(1, 3, 6, 28, 36, 45) + logInfo("--------------------------------") + logInfo("output, size = " + outputBuffer.size) + outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("expected output, size = " + expectedOutput.size) + expectedOutput.foreach(x => logInfo("[" + x + "]")) + logInfo("--------------------------------") + + // Verify whether all the elements received are as expected + assert(outputBuffer.size === expectedOutput.size) + for (i <- 0 until outputBuffer.size) { + assert(outputBuffer(i).size === 1) + assert(outputBuffer(i).head === expectedOutput(i)) + } + } + + /** - * Tests a streaming operation under checkpointing, by restart the operation + * Tests a streaming operation under checkpointing, by restarting the operation * from checkpoint file and verifying whether the final output is 
correct. * The output is assumed to have come from a reliable queue which an replay * data as required. + * + * NOTE: This takes into consideration that the last batch processed before + * master failure will be re-processed after restart/recovery. */ def testCheckpointedOperation[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], @@ -172,7 +273,8 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { val totalNumBatches = input.size val nextNumBatches = totalNumBatches - initialNumBatches val initialNumExpectedOutputs = initialNumBatches - val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1 + // because the last batch will be processed again // Do the computation for initial number of batches, create checkpoint file and quit ssc = setupStreams[U, V](input, operation) @@ -188,6 +290,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { ) ssc = new StreamingContext(checkpointDir) val outputNew = runStreams[V](ssc, nextNumBatches, nextNumExpectedOutputs) + // the first element will be re-processed data of the last batch before restart verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) ssc = null } diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index 34e51e9562..aa08ea1141 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -19,35 +19,24 @@ import org.apache.avro.ipc.specific.SpecificRequestor import java.nio.ByteBuffer import collection.JavaConversions._ import java.nio.charset.Charset +import com.google.common.io.Files class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") - val testPort = 9999 - var testServer: TestServer = null - var testDir: File = null - override def checkpointDir = "checkpoint" after { - FileUtils.deleteDirectory(new File(checkpointDir)) - if (testServer != null) { - testServer.stop() - testServer = null - } - if (testDir != null && testDir.exists()) { - FileUtils.deleteDirectory(testDir) - testDir = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.master.port") } - /* + + test("network input stream") { // Start the server - testServer = new TestServer(testPort) + val testPort = 9999 + val testServer = new TestServer(testPort) testServer.start() // Set up the streaming context and input streams @@ -93,46 +82,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } } - test("network input stream with checkpoint") { - // Start the server - testServer = new TestServer(testPort) - testServer.start() - - // Set up the streaming context and input streams - var ssc = new StreamingContext(master, framework, batchDuration) - ssc.checkpoint(checkpointDir, checkpointInterval) - val networkStream = ssc.networkTextStream("localhost", testPort, StorageLevel.MEMORY_AND_DISK) - var outputStream = new TestOutputStream(networkStream, new ArrayBuffer[Seq[String]]) - ssc.registerOutputStream(outputStream) - ssc.start() - - // Feed data to the server to send to the network receiver - var clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - for (i <- Seq(1, 2, 3)) { - testServer.send(i.toString + "\n") - Thread.sleep(100) - 
clock.addToTime(batchDuration.milliseconds) - } - Thread.sleep(500) - assert(outputStream.output.size > 0) - ssc.stop() - - // Restart stream computation from checkpoint and feed more data to see whether - // they are being received and processed - logInfo("*********** RESTARTING ************") - ssc = new StreamingContext(checkpointDir) - ssc.start() - clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - for (i <- Seq(4, 5, 6)) { - testServer.send(i.toString + "\n") - Thread.sleep(100) - clock.addToTime(batchDuration.milliseconds) - } - Thread.sleep(500) - outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[String]] - assert(outputStream.output.size > 0) - ssc.stop() - } test("flume input stream") { // Set up the streaming context and input streams @@ -182,18 +131,10 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } } + test("file input stream") { - - // Create a temporary directory - testDir = { - var temp = File.createTempFile(".temp.", Random.nextInt().toString) - temp.delete() - temp.mkdirs() - logInfo("Created temp dir " + temp) - temp - } - // Set up the streaming context and input streams + val testDir = Files.createTempDir() val ssc = new StreamingContext(master, framework, batchDuration) val filestream = ssc.textFileStream(testDir.toString) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] @@ -235,96 +176,8 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { assert(output(i).size === 1) assert(output(i).head.toString === expectedOutput(i)) } + FileUtils.deleteDirectory(testDir) } - */ - test("file input stream with master failure") { - // Create a temporary directory - testDir = { - var temp = File.createTempFile(".temp.", Random.nextInt().toString) - temp.delete() - temp.mkdirs() - logInfo("Created temp dir " + temp) - temp - } - - // Set up the streaming context and input streams - var ssc = new StreamingContext(master, framework, batchDuration) - ssc.checkpoint(checkpointDir, checkpointInterval) - val fileStream = ssc.textFileStream(testDir.toString) - // Making value 3 take large time to process, to ensure that the master - // shuts down in the middle of processing the 3rd batch - val mappedStream = fileStream.map(s => { - val i = s.toInt - if (i == 3) Thread.sleep(1000) - i - }) - // Reducing over a large window to ensure that recovery from master failure - // requires reprocessing of all the files seen before the failure - val reducedStream = mappedStream.reduceByWindow(_ + _, batchDuration * 30, batchDuration) - val outputBuffer = new ArrayBuffer[Seq[Int]] - var outputStream = new TestOutputStream(reducedStream, outputBuffer) - ssc.registerOutputStream(outputStream) - ssc.start() - - // Create files and advance manual clock to process them - var clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - Thread.sleep(1000) - for (i <- Seq(1, 2, 3)) { - FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") - // wait to make sure that the file is written such that it gets shown in the file listings - Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) - // wait to make sure that FileInputDStream picks up this file only and not any other file - Thread.sleep(500) - } - logInfo("Output = " + outputStream.output.mkString(",")) - assert(outputStream.output.size > 0, "No files processed before restart") - ssc.stop() - - // Create files while the master is down - for (i <- Seq(4, 5, 6)) { - FileUtils.writeStringToFile(new File(testDir, 
i.toString), i.toString + "\n") - Thread.sleep(1000) - } - - // Restart stream computation from checkpoint and create more files to see whether - // they are being processed - logInfo("*********** RESTARTING ************") - ssc = new StreamingContext(checkpointDir) - ssc.start() - clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - for (i <- Seq(7, 8, 9)) { - FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") - Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) - Thread.sleep(500) - } - Thread.sleep(1000) - logInfo("Output = " + outputStream.output.mkString(",")) - assert(outputStream.output.size > 0, "No files processed after restart") - ssc.stop() - - // Append the new output to the old buffer - outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[Int]] - outputBuffer ++= outputStream.output - - // Verify whether data received by Spark Streaming was as expected - val expectedOutput = Seq(1, 3, 6, 28, 36, 45) - logInfo("--------------------------------") - logInfo("output, size = " + outputBuffer.size) - outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output, size = " + expectedOutput.size) - expectedOutput.foreach(x => logInfo("[" + x + "]")) - logInfo("--------------------------------") - - // Verify whether all the elements received are as expected - assert(outputBuffer.size === expectedOutput.size) - for (i <- 0 until outputBuffer.size) { - assert(outputBuffer(i).size === 1) - assert(outputBuffer(i).head === expectedOutput(i)) - } - } - } From ae2ed2947d43860c74a8d40767e289ca78073977 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 23 Jan 2013 10:36:18 -0800 Subject: [PATCH 182/291] Allow PySpark's SparkFiles to be used from driver Fix minor documentation formatting issues. --- core/src/main/scala/spark/SparkFiles.java | 8 +++---- python/pyspark/context.py | 27 ++++++++++++++++++----- python/pyspark/files.py | 20 ++++++++++++++--- python/pyspark/tests.py | 23 +++++++++++++++++++ python/pyspark/worker.py | 1 + python/test_support/hello.txt | 1 + 6 files changed, 67 insertions(+), 13 deletions(-) create mode 100755 python/test_support/hello.txt diff --git a/core/src/main/scala/spark/SparkFiles.java b/core/src/main/scala/spark/SparkFiles.java index b59d8ce93f..566aec622c 100644 --- a/core/src/main/scala/spark/SparkFiles.java +++ b/core/src/main/scala/spark/SparkFiles.java @@ -3,23 +3,23 @@ package spark; import java.io.File; /** - * Resolves paths to files added through `addFile(). + * Resolves paths to files added through `SparkContext.addFile()`. */ public class SparkFiles { private SparkFiles() {} /** - * Get the absolute path of a file added through `addFile()`. + * Get the absolute path of a file added through `SparkContext.addFile()`. */ public static String get(String filename) { return new File(getRootDirectory(), filename).getAbsolutePath(); } /** - * Get the root directory that contains files added through `addFile()`. + * Get the root directory that contains files added through `SparkContext.addFile()`. 
*/ public static String getRootDirectory() { return SparkEnv.get().sparkFilesDir(); } -} \ No newline at end of file +} diff --git a/python/pyspark/context.py b/python/pyspark/context.py index b8d7dc05af..3e33776af0 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -1,12 +1,15 @@ import os import atexit import shutil +import sys import tempfile +from threading import Lock from tempfile import NamedTemporaryFile from pyspark import accumulators from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast +from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import dump_pickle, write_with_length, batched from pyspark.rdd import RDD @@ -27,6 +30,8 @@ class SparkContext(object): _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile _takePartition = jvm.PythonRDD.takePartition _next_accum_id = 0 + _active_spark_context = None + _lock = Lock() def __init__(self, master, jobName, sparkHome=None, pyFiles=None, environment=None, batchSize=1024): @@ -46,6 +51,11 @@ class SparkContext(object): Java object. Set 1 to disable batching or -1 to use an unlimited batch size. """ + with SparkContext._lock: + if SparkContext._active_spark_context: + raise ValueError("Cannot run multiple SparkContexts at once") + else: + SparkContext._active_spark_context = self self.master = master self.jobName = jobName self.sparkHome = sparkHome or None # None becomes null in Py4J @@ -75,6 +85,8 @@ class SparkContext(object): # Deploy any code dependencies specified in the constructor for path in (pyFiles or []): self.addPyFile(path) + SparkFiles._sc = self + sys.path.append(SparkFiles.getRootDirectory()) @property def defaultParallelism(self): @@ -85,17 +97,20 @@ class SparkContext(object): return self._jsc.sc().defaultParallelism() def __del__(self): - if self._jsc: - self._jsc.stop() - if self._accumulatorServer: - self._accumulatorServer.shutdown() + self.stop() def stop(self): """ Shut down the SparkContext. """ - self._jsc.stop() - self._jsc = None + if self._jsc: + self._jsc.stop() + self._jsc = None + if self._accumulatorServer: + self._accumulatorServer.shutdown() + self._accumulatorServer = None + with SparkContext._lock: + SparkContext._active_spark_context = None def parallelize(self, c, numSlices=None): """ diff --git a/python/pyspark/files.py b/python/pyspark/files.py index de1334f046..98f6a399cc 100644 --- a/python/pyspark/files.py +++ b/python/pyspark/files.py @@ -4,13 +4,15 @@ import os class SparkFiles(object): """ Resolves paths to files added through - L{addFile()}. + L{SparkContext.addFile()}. SparkFiles contains only classmethods; users should not create SparkFiles instances. """ _root_directory = None + _is_running_on_worker = False + _sc = None def __init__(self): raise NotImplementedError("Do not construct SparkFiles objects") @@ -18,7 +20,19 @@ class SparkFiles(object): @classmethod def get(cls, filename): """ - Get the absolute path of a file added through C{addFile()}. + Get the absolute path of a file added through C{SparkContext.addFile()}. """ - path = os.path.join(SparkFiles._root_directory, filename) + path = os.path.join(SparkFiles.getRootDirectory(), filename) return os.path.abspath(path) + + @classmethod + def getRootDirectory(cls): + """ + Get the root directory that contains files added through + C{SparkContext.addFile()}. 
+ """ + if cls._is_running_on_worker: + return cls._root_directory + else: + # This will have to change if we support multiple SparkContexts: + return cls._sc.jvm.spark.SparkFiles.getRootDirectory() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 4d70ee4f12..46ab34f063 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -4,22 +4,26 @@ individual modules. """ import os import shutil +import sys from tempfile import NamedTemporaryFile import time import unittest from pyspark.context import SparkContext +from pyspark.files import SparkFiles from pyspark.java_gateway import SPARK_HOME class PySparkTestCase(unittest.TestCase): def setUp(self): + self._old_sys_path = list(sys.path) class_name = self.__class__.__name__ self.sc = SparkContext('local[4]', class_name , batchSize=2) def tearDown(self): self.sc.stop() + sys.path = self._old_sys_path # To avoid Akka rebinding to the same port, since it doesn't unbind # immediately on shutdown self.sc.jvm.System.clearProperty("spark.master.port") @@ -84,6 +88,25 @@ class TestAddFile(PySparkTestCase): res = self.sc.parallelize(range(2)).map(func).first() self.assertEqual("Hello World!", res) + def test_add_file_locally(self): + path = os.path.join(SPARK_HOME, "python/test_support/hello.txt") + self.sc.addFile(path) + download_path = SparkFiles.get("hello.txt") + self.assertNotEqual(path, download_path) + with open(download_path) as test_file: + self.assertEquals("Hello World!\n", test_file.readline()) + + def test_add_py_file_locally(self): + # To ensure that we're actually testing addPyFile's effects, check that + # this fails due to `userlibrary` not being on the Python path: + def func(): + from userlibrary import UserClass + self.assertRaises(ImportError, func) + path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") + self.sc.addFile(path) + from userlibrary import UserClass + self.assertEqual("Hello World!", UserClass().hello()) + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 4bf643da66..d33d6dd15f 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -26,6 +26,7 @@ def main(): split_index = read_int(sys.stdin) spark_files_dir = load_pickle(read_with_length(sys.stdin)) SparkFiles._root_directory = spark_files_dir + SparkFiles._is_running_on_worker = True sys.path.append(spark_files_dir) num_broadcast_variables = read_int(sys.stdin) for _ in range(num_broadcast_variables): diff --git a/python/test_support/hello.txt b/python/test_support/hello.txt new file mode 100755 index 0000000000..980a0d5f19 --- /dev/null +++ b/python/test_support/hello.txt @@ -0,0 +1 @@ +Hello World! From b47d054cfc5ef45b92a1c970388722ffa0283e66 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 23 Jan 2013 11:18:25 -0800 Subject: [PATCH 183/291] Remove use of abc.ABCMeta due to cloudpickle issue. cloudpickle runs into issues while pickling subclasses of AccumulatorParam, which may be related to this Python issue: http://bugs.python.org/issue7689 This seems hard to fix and the ABCMeta wasn't necessary, so I removed it. 
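For reference, a minimal sketch of a custom accumulator type under the plain-class AccumulatorParam; it mirrors the VectorAccumulatorParam doctest in the diff below and is illustrative only, with the commented usage assuming an active SparkContext `sc`:

    from pyspark.accumulators import AccumulatorParam

    class VectorAccumulatorParam(AccumulatorParam):
        """Accumulates equal-length lists of floats element-wise."""
        def zero(self, value):
            # A zero vector with the same dimensions as the supplied value.
            return [0.0] * len(value)

        def addInPlace(self, val1, val2):
            # Update val1 in place and return it, as the contract allows.
            for i in range(len(val1)):
                val1[i] += val2[i]
            return val1

    # Usage, assuming an active SparkContext `sc`:
    # va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam())
    # va += [1.0, 1.0, 1.0]

Since the methods now raise NotImplementedError instead of being abstract, forgetting to override zero or addInPlace fails at call time rather than at class definition time.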
--- python/pyspark/accumulators.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 5a9269f9bb..61fcbbd376 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -25,7 +25,8 @@ >>> a.value 13 ->>> class VectorAccumulatorParam(object): +>>> from pyspark.accumulators import AccumulatorParam +>>> class VectorAccumulatorParam(AccumulatorParam): ... def zero(self, value): ... return [0.0] * len(value) ... def addInPlace(self, val1, val2): @@ -61,7 +62,6 @@ Traceback (most recent call last): Exception:... """ -from abc import ABCMeta, abstractmethod import struct import SocketServer import threading @@ -138,23 +138,20 @@ class AccumulatorParam(object): """ Helper object that defines how to accumulate values of a given type. """ - __metaclass__ = ABCMeta - @abstractmethod def zero(self, value): """ Provide a "zero value" for the type, compatible in dimensions with the provided C{value} (e.g., a zero vector) """ - return + raise NotImplementedError - @abstractmethod def addInPlace(self, value1, value2): """ Add two values of the accumulator's data type, returning a new value; for efficiency, can also update C{value1} in place and return it. """ - return + raise NotImplementedError class AddingAccumulatorParam(AccumulatorParam): From e1027ca6398fd5b1a99a2203df840911c4dccb27 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 12:22:11 -0800 Subject: [PATCH 184/291] Actually add CacheManager. --- core/src/main/scala/spark/CacheManager.scala | 65 ++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 core/src/main/scala/spark/CacheManager.scala diff --git a/core/src/main/scala/spark/CacheManager.scala b/core/src/main/scala/spark/CacheManager.scala new file mode 100644 index 0000000000..a0b53fd9d6 --- /dev/null +++ b/core/src/main/scala/spark/CacheManager.scala @@ -0,0 +1,65 @@ +package spark + +import scala.collection.mutable.{ArrayBuffer, HashSet} +import spark.storage.{BlockManager, StorageLevel} + + +/** Spark class responsible for passing RDDs split contents to the BlockManager and making + sure a node doesn't load two copies of an RDD at once. + */ +private[spark] class CacheManager(blockManager: BlockManager) extends Logging { + private val loading = new HashSet[String] + + /** Gets or computes an RDD split. Used by RDD.iterator() when a RDD is cached. */ + def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel) + : Iterator[T] = { + val key = "rdd_%d_%d".format(rdd.id, split.index) + logInfo("Cache key is " + key) + blockManager.get(key) match { + case Some(cachedValues) => + // Split is in cache, so just return its values + logInfo("Found partition in cache!") + return cachedValues.asInstanceOf[Iterator[T]] + + case None => + // Mark the split as loading (unless someone else marks it first) + loading.synchronized { + if (loading.contains(key)) { + logInfo("Loading contains " + key + ", waiting...") + while (loading.contains(key)) { + try {loading.wait()} catch {case _ =>} + } + logInfo("Loading no longer contains " + key + ", so returning cached result") + // See whether someone else has successfully loaded it. The main way this would fail + // is for the RDD-level cache eviction policy if someone else has loaded the same RDD + // partition but we didn't want to make space for it. However, that case is unlikely + // because it's unlikely that two threads would work on the same RDD partition. 
One + // downside of the current code is that threads wait serially if this does happen. + blockManager.get(key) match { + case Some(values) => + return values.asInstanceOf[Iterator[T]] + case None => + logInfo("Whoever was loading " + key + " failed; we'll try it ourselves") + loading.add(key) + } + } else { + loading.add(key) + } + } + try { + // If we got here, we have to load the split + val elements = new ArrayBuffer[Any] + logInfo("Computing partition " + split) + elements ++= rdd.compute(split, context) + // Try to put this block in the blockManager + blockManager.put(key, elements, storageLevel, true) + return elements.iterator.asInstanceOf[Iterator[T]] + } finally { + loading.synchronized { + loading.remove(key) + loading.notifyAll() + } + } + } + } +} From 88b9d240fda7ca34c08752dfa66797eecb6db872 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 12:40:38 -0800 Subject: [PATCH 185/291] Remove dead code in test. --- core/src/test/scala/spark/DistributedSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index 0d6b265e54..af66d33aa3 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -194,7 +194,6 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter DistributedSuite.amMaster = true sc = new SparkContext(clusterUrl, "test") val data = sc.parallelize(Seq(true, true), 2) - val singleton = sc.parallelize(Seq(true), 1) assert(data.count === 2) // force executors to start val masterId = SparkEnv.get.blockManager.blockManagerId assert(data.map(markNodeIfIdentity).collect.size === 2) @@ -207,7 +206,6 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter sc = new SparkContext(clusterUrl, "test") for (i <- 1 to 3) { val data = sc.parallelize(Seq(true, false), 2) - val singleton = sc.parallelize(Seq(false), 1) assert(data.count === 2) assert(data.map(markNodeIfIdentity).collect.size === 2) assert(data.map(failOnMarkedIdentity).map(x => x -> x).groupByKey.count === 2) From be4a115a7ec7fb6ec0d34f1a1a1bb2c9bbe7600e Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 12:48:45 -0800 Subject: [PATCH 186/291] Clarify TODO. --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 740aec2e61..14a3ef8ad7 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -76,7 +76,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // sent with every task. When we detect a node failing, we note the current generation number // and failed host, increment it for new tasks, and use this to ignore stray ShuffleMapTask // results. - // TODO: Garbage collect information about failure generations when new stages start. + // TODO: Garbage collect information about failure generations when we know there are no more + // stray messages to detect. 
val failedGeneration = new HashMap[String, Long] val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done From e1985bfa04ad4583ac1f0f421cbe0182ce7c53df Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 21 Jan 2013 16:21:14 -0800 Subject: [PATCH 187/291] be sure to set class loader of kryo instances --- core/src/main/scala/spark/KryoSerializer.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 93d7327324..56919544e8 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -206,5 +206,8 @@ class KryoSerializer extends spark.serializer.Serializer with Logging { kryo } - def newInstance(): SerializerInstance = new KryoSerializerInstance(this) + def newInstance(): SerializerInstance = { + this.kryo.setClassLoader(Thread.currentThread().getContextClassLoader) + new KryoSerializerInstance(this) + } } From 5c7422292ecace947f78e5ebe97e83a355531af7 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 23 Jan 2013 12:59:51 -0800 Subject: [PATCH 188/291] Remove more dead code from test. --- core/src/test/scala/spark/DistributedSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index af66d33aa3..0487e06d12 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -218,7 +218,6 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter sc = new SparkContext(clusterUrl, "test") for (i <- 1 to 3) { val data = sc.parallelize(Seq(true, true), 2) - val singleton = sc.parallelize(Seq(false), 1) assert(data.count === 2) assert(data.map(markNodeIfIdentity).collect.size === 2) // This relies on mergeCombiners being used to perform the actual reduce for this From 1dd82743e09789f8fdae2f5628545c0cb9f79245 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Wed, 23 Jan 2013 13:07:27 -0800 Subject: [PATCH 189/291] Fix compile error due to cherry-pick --- core/src/main/scala/spark/KryoSerializer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 56919544e8..0bd73e936b 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -207,7 +207,7 @@ class KryoSerializer extends spark.serializer.Serializer with Logging { } def newInstance(): SerializerInstance = { - this.kryo.setClassLoader(Thread.currentThread().getContextClassLoader) + this.kryo.get().setClassLoader(Thread.currentThread().getContextClassLoader) new KryoSerializerInstance(this) } } From eb222b720647c9e92a867c591cc4914b9a6cb5c1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 15:29:02 -0800 Subject: [PATCH 190/291] Added pruntSplits method to RDD. 
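For reference, a minimal usage sketch of the new method; it mirrors the test added below, assumes an existing SparkContext `sc`, and note that a later patch in this series renames this API to PartitionPruningRDD:

    // Prune partitions by index so no tasks are launched on the dropped splits.
    val data = sc.parallelize(1 to 10, 10)                  // ten splits, indices 0 through 9
    val pruned = data.pruneSplits(splitNum => splitNum > 8) // keeps only the split with index 9
    assert(pruned.splits.size === 1)
    assert(pruned.collect().toList === List(10))
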
--- core/src/main/scala/spark/RDD.scala | 10 ++++++++ .../scala/spark/rdd/SplitsPruningRDD.scala | 24 +++++++++++++++++++ core/src/test/scala/spark/RDDSuite.scala | 22 +++++++++++------ 3 files changed, 49 insertions(+), 7 deletions(-) create mode 100644 core/src/main/scala/spark/rdd/SplitsPruningRDD.scala diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index e0d2eabb1d..3d93ff33bb 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -40,6 +40,7 @@ import spark.rdd.MapPartitionsRDD import spark.rdd.MapPartitionsWithSplitRDD import spark.rdd.PipedRDD import spark.rdd.SampledRDD +import spark.rdd.SplitsPruningRDD import spark.rdd.UnionRDD import spark.rdd.ZippedRDD import spark.storage.StorageLevel @@ -543,6 +544,15 @@ abstract class RDD[T: ClassManifest]( map(x => (f(x), x)) } + /** + * Prune splits (partitions) so Spark can avoid launching tasks on + * all splits. An example use case: If we know the RDD is partitioned by range, + * and the execution DAG has a filter on the key, we can avoid launching tasks + * on splits that don't have the range covering the key. + */ + def pruneSplits(splitsFilterFunc: Int => Boolean): RDD[T] = + new SplitsPruningRDD(this, splitsFilterFunc) + /** A private method for tests, to look at the contents of each partition */ private[spark] def collectPartitions(): Array[Array[T]] = { sc.runJob(this, (iter: Iterator[T]) => iter.toArray) diff --git a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala new file mode 100644 index 0000000000..74e10265fc --- /dev/null +++ b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala @@ -0,0 +1,24 @@ +package spark.rdd + +import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext} + +/** + * A RDD used to prune RDD splits so we can avoid launching tasks on + * all splits. An example use case: If we know the RDD is partitioned by range, + * and the execution DAG has a filter on the key, we can avoid launching tasks + * on splits that don't have the range covering the key. 
+ */ +class SplitsPruningRDD[T: ClassManifest]( + prev: RDD[T], + @transient splitsFilterFunc: Int => Boolean) + extends RDD[T](prev) { + + @transient + val _splits: Array[Split] = prev.splits.filter(s => splitsFilterFunc(s.index)) + + override def compute(split: Split, context: TaskContext) = prev.iterator(split, context) + + override protected def getSplits = _splits + + override val partitioner = prev.partitioner +} diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index db217f8482..03aa2845f4 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -1,11 +1,9 @@ package spark import scala.collection.mutable.HashMap -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter - +import org.scalatest.{BeforeAndAfter, FunSuite} +import spark.SparkContext._ import spark.rdd.CoalescedRDD -import SparkContext._ class RDDSuite extends FunSuite with BeforeAndAfter { @@ -104,7 +102,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter { } test("caching with failures") { - sc = new SparkContext("local", "test") + sc = new SparkContext("local", "test") val onlySplit = new Split { override def index: Int = 0 } var shouldFail = true val rdd = new RDD[Int](sc, Nil) { @@ -136,8 +134,10 @@ class RDDSuite extends FunSuite with BeforeAndAfter { List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10))) // Check that the narrow dependency is also specified correctly - assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === List(0, 1, 2, 3, 4)) - assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === List(5, 6, 7, 8, 9)) + assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === + List(0, 1, 2, 3, 4)) + assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === + List(5, 6, 7, 8, 9)) val coalesced2 = new CoalescedRDD(data, 3) assert(coalesced2.collect().toList === (1 to 10).toList) @@ -168,4 +168,12 @@ class RDDSuite extends FunSuite with BeforeAndAfter { nums.zip(sc.parallelize(1 to 4, 1)).collect() } } + + test("split pruning") { + sc = new SparkContext("local", "test") + val data = sc.parallelize(1 to 10, 10) + // Note that split number starts from 0, so > 8 means only 10th partition left. + val prunedData = data.pruneSplits(splitNum => splitNum > 8).collect + assert(prunedData.size == 1 && prunedData(0) == 10) + } } From c24b3819dd474e13d6098150c174b2e7e4bc6498 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 15:34:59 -0800 Subject: [PATCH 191/291] Added an extra assert for split size check. --- core/src/test/scala/spark/RDDSuite.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 03aa2845f4..ef74c99246 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -173,7 +173,10 @@ class RDDSuite extends FunSuite with BeforeAndAfter { sc = new SparkContext("local", "test") val data = sc.parallelize(1 to 10, 10) // Note that split number starts from 0, so > 8 means only 10th partition left. 
- val prunedData = data.pruneSplits(splitNum => splitNum > 8).collect - assert(prunedData.size == 1 && prunedData(0) == 10) + val prunedRdd = data.pruneSplits(splitNum => splitNum > 8) + assert(prunedRdd.splits.size == 1) + val prunedData = prunedRdd.collect + assert(prunedData.size == 1) + assert(prunedData(0) == 10) } } From 45cd50d5fe40869cdc237157e073cfb5ac47b27c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 16:06:58 -0800 Subject: [PATCH 192/291] Updated assert == to ===. --- core/src/test/scala/spark/RDDSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index ef74c99246..5a3a12dfff 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -174,9 +174,9 @@ class RDDSuite extends FunSuite with BeforeAndAfter { val data = sc.parallelize(1 to 10, 10) // Note that split number starts from 0, so > 8 means only 10th partition left. val prunedRdd = data.pruneSplits(splitNum => splitNum > 8) - assert(prunedRdd.splits.size == 1) + assert(prunedRdd.splits.size === 1) val prunedData = prunedRdd.collect - assert(prunedData.size == 1) - assert(prunedData(0) == 10) + assert(prunedData.size === 1) + assert(prunedData(0) === 10) } } From 636e912f3289e422be9550752f5279d519062b75 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 21:21:55 -0800 Subject: [PATCH 193/291] Created a PruneDependency to properly assign dependency for SplitsPruningRDD. --- core/src/main/scala/spark/Dependency.scala | 24 ++++++++++++++++--- .../scala/spark/rdd/SplitsPruningRDD.scala | 8 +++---- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala index b85d2732db..7d5858e88e 100644 --- a/core/src/main/scala/spark/Dependency.scala +++ b/core/src/main/scala/spark/Dependency.scala @@ -5,6 +5,7 @@ package spark */ abstract class Dependency[T](val rdd: RDD[T]) extends Serializable + /** * Base class for dependencies where each partition of the parent RDD is used by at most one * partition of the child RDD. Narrow dependencies allow for pipelined execution. @@ -12,12 +13,13 @@ abstract class Dependency[T](val rdd: RDD[T]) extends Serializable abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { /** * Get the parent partitions for a child partition. - * @param outputPartition a partition of the child RDD + * @param partitionId a partition of the child RDD * @return the partitions of the parent RDD that the child partition depends upon */ - def getParents(outputPartition: Int): Seq[Int] + def getParents(partitionId: Int): Seq[Int] } + /** * Represents a dependency on the output of a shuffle stage. * @param shuffleId the shuffle id @@ -32,6 +34,7 @@ class ShuffleDependency[K, V]( val shuffleId: Int = rdd.context.newShuffleId() } + /** * Represents a one-to-one dependency between partitions of the parent and child RDDs. */ @@ -39,6 +42,7 @@ class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) { override def getParents(partitionId: Int) = List(partitionId) } + /** * Represents a one-to-one dependency between ranges of partitions in the parent and child RDDs. 
* @param rdd the parent RDD @@ -48,7 +52,7 @@ class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) { */ class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) extends NarrowDependency[T](rdd) { - + override def getParents(partitionId: Int) = { if (partitionId >= outStart && partitionId < outStart + length) { List(partitionId - outStart + inStart) @@ -57,3 +61,17 @@ class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) } } } + + +/** + * Represents a dependency between the SplitsPruningRDD and its parent. In this + * case, the child RDD contains a subset of splits of the parents'. + */ +class PruneDependency[T](rdd: RDD[T], @transient splitsFilterFunc: Int => Boolean) + extends NarrowDependency[T](rdd) { + + @transient + val splits: Array[Split] = rdd.splits.filter(s => splitsFilterFunc(s.index)) + + override def getParents(partitionId: Int) = List(splits(partitionId).index) +} diff --git a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala index 74e10265fc..7b44d85bb5 100644 --- a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala @@ -1,6 +1,6 @@ package spark.rdd -import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext} +import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} /** * A RDD used to prune RDD splits so we can avoid launching tasks on @@ -11,12 +11,12 @@ import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext} class SplitsPruningRDD[T: ClassManifest]( prev: RDD[T], @transient splitsFilterFunc: Int => Boolean) - extends RDD[T](prev) { + extends RDD[T](prev.context, List(new PruneDependency(prev, splitsFilterFunc))) { @transient - val _splits: Array[Split] = prev.splits.filter(s => splitsFilterFunc(s.index)) + val _splits: Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].splits - override def compute(split: Split, context: TaskContext) = prev.iterator(split, context) + override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) override protected def getSplits = _splits From 81004b967e838fca0790727a3fea5a265ddbc69a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 21:54:27 -0800 Subject: [PATCH 194/291] Marked prev RDD as transient in SplitsPruningRDD. --- core/src/main/scala/spark/rdd/SplitsPruningRDD.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala index 7b44d85bb5..9b1a210ba3 100644 --- a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala @@ -9,7 +9,7 @@ import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} * on splits that don't have the range covering the key. 
*/ class SplitsPruningRDD[T: ClassManifest]( - prev: RDD[T], + @transient prev: RDD[T], @transient splitsFilterFunc: Int => Boolean) extends RDD[T](prev.context, List(new PruneDependency(prev, splitsFilterFunc))) { @@ -20,5 +20,5 @@ class SplitsPruningRDD[T: ClassManifest]( override protected def getSplits = _splits - override val partitioner = prev.partitioner + override val partitioner = firstParent[T].partitioner } From eedc542a0276a5248c81446ee84f56d691e5f488 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 22:14:23 -0800 Subject: [PATCH 195/291] Removed pruneSplits method in RDD and renamed SplitsPruningRDD to PartitionPruningRDD. --- core/src/main/scala/spark/RDD.scala | 10 -------- .../scala/spark/rdd/PartitionPruningRDD.scala | 24 +++++++++++++++++++ .../scala/spark/rdd/SplitsPruningRDD.scala | 24 ------------------- core/src/test/scala/spark/RDDSuite.scala | 6 ++--- 4 files changed, 27 insertions(+), 37 deletions(-) create mode 100644 core/src/main/scala/spark/rdd/PartitionPruningRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/SplitsPruningRDD.scala diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 3d93ff33bb..e0d2eabb1d 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -40,7 +40,6 @@ import spark.rdd.MapPartitionsRDD import spark.rdd.MapPartitionsWithSplitRDD import spark.rdd.PipedRDD import spark.rdd.SampledRDD -import spark.rdd.SplitsPruningRDD import spark.rdd.UnionRDD import spark.rdd.ZippedRDD import spark.storage.StorageLevel @@ -544,15 +543,6 @@ abstract class RDD[T: ClassManifest]( map(x => (f(x), x)) } - /** - * Prune splits (partitions) so Spark can avoid launching tasks on - * all splits. An example use case: If we know the RDD is partitioned by range, - * and the execution DAG has a filter on the key, we can avoid launching tasks - * on splits that don't have the range covering the key. - */ - def pruneSplits(splitsFilterFunc: Int => Boolean): RDD[T] = - new SplitsPruningRDD(this, splitsFilterFunc) - /** A private method for tests, to look at the contents of each partition */ private[spark] def collectPartitions(): Array[Array[T]] = { sc.runJob(this, (iter: Iterator[T]) => iter.toArray) diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala new file mode 100644 index 0000000000..3048949ef2 --- /dev/null +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -0,0 +1,24 @@ +package spark.rdd + +import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} + +/** + * A RDD used to prune RDD partitions/splits so we can avoid launching tasks on + * all partitions. An example use case: If we know the RDD is partitioned by range, + * and the execution DAG has a filter on the key, we can avoid launching tasks + * on partitions that don't have the range covering the key. 
+ */ +class PartitionPruningRDD[T: ClassManifest]( + @transient prev: RDD[T], + @transient partitionFilterFunc: Int => Boolean) + extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { + + @transient + val partitions_ : Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].splits + + override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) + + override protected def getSplits = partitions_ + + override val partitioner = firstParent[T].partitioner +} diff --git a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala deleted file mode 100644 index 9b1a210ba3..0000000000 --- a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala +++ /dev/null @@ -1,24 +0,0 @@ -package spark.rdd - -import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} - -/** - * A RDD used to prune RDD splits so we can avoid launching tasks on - * all splits. An example use case: If we know the RDD is partitioned by range, - * and the execution DAG has a filter on the key, we can avoid launching tasks - * on splits that don't have the range covering the key. - */ -class SplitsPruningRDD[T: ClassManifest]( - @transient prev: RDD[T], - @transient splitsFilterFunc: Int => Boolean) - extends RDD[T](prev.context, List(new PruneDependency(prev, splitsFilterFunc))) { - - @transient - val _splits: Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].splits - - override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) - - override protected def getSplits = _splits - - override val partitioner = firstParent[T].partitioner -} diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 5a3a12dfff..73846131a9 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -3,7 +3,7 @@ package spark import scala.collection.mutable.HashMap import org.scalatest.{BeforeAndAfter, FunSuite} import spark.SparkContext._ -import spark.rdd.CoalescedRDD +import spark.rdd.{CoalescedRDD, PartitionPruningRDD} class RDDSuite extends FunSuite with BeforeAndAfter { @@ -169,11 +169,11 @@ class RDDSuite extends FunSuite with BeforeAndAfter { } } - test("split pruning") { + test("partition pruning") { sc = new SparkContext("local", "test") val data = sc.parallelize(1 to 10, 10) // Note that split number starts from 0, so > 8 means only 10th partition left. - val prunedRdd = data.pruneSplits(splitNum => splitNum > 8) + val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8) assert(prunedRdd.splits.size === 1) val prunedData = prunedRdd.collect assert(prunedData.size === 1) From c109f29c97c9606dee45e6300d01a272dbb560aa Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 22:22:03 -0800 Subject: [PATCH 196/291] Updated PruneDependency to change "split" to "partition". --- core/src/main/scala/spark/Dependency.scala | 10 +++++----- .../src/main/scala/spark/rdd/PartitionPruningRDD.scala | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala index 7d5858e88e..647aee6eb5 100644 --- a/core/src/main/scala/spark/Dependency.scala +++ b/core/src/main/scala/spark/Dependency.scala @@ -64,14 +64,14 @@ class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) /** - * Represents a dependency between the SplitsPruningRDD and its parent. 
In this - * case, the child RDD contains a subset of splits of the parents'. + * Represents a dependency between the PartitionPruningRDD and its parent. In this + * case, the child RDD contains a subset of partitions of the parents'. */ -class PruneDependency[T](rdd: RDD[T], @transient splitsFilterFunc: Int => Boolean) +class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean) extends NarrowDependency[T](rdd) { @transient - val splits: Array[Split] = rdd.splits.filter(s => splitsFilterFunc(s.index)) + val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index)) - override def getParents(partitionId: Int) = List(splits(partitionId).index) + override def getParents(partitionId: Int) = List(partitions(partitionId).index) } diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala index 3048949ef2..787b59ae8c 100644 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -14,7 +14,7 @@ class PartitionPruningRDD[T: ClassManifest]( extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { @transient - val partitions_ : Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].splits + val partitions_ : Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].partitions override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) From 67a43bc7e622e4dd9d53ccf80b441740d6ff4df5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 23 Jan 2013 23:06:52 -0800 Subject: [PATCH 197/291] Added a clearDependencies method in PartitionPruningRDD. --- core/src/main/scala/spark/rdd/PartitionPruningRDD.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala index 787b59ae8c..97dd37950e 100644 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -14,11 +14,16 @@ class PartitionPruningRDD[T: ClassManifest]( extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { @transient - val partitions_ : Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].partitions + var partitions_ : Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].partitions override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) override protected def getSplits = partitions_ override val partitioner = firstParent[T].partitioner + + override def clearDependencies() { + super.clearDependencies() + partitions_ = null + } } From 230bda204778e6f3c0f5a20ad341f643146d97cb Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 15 Jan 2013 14:01:19 -0600 Subject: [PATCH 198/291] Add LocalSparkContext to manage common sc variable. 
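For reference, a minimal sketch of a suite written against the new trait; the pattern mirrors the refactored suites below, and the name `ExampleSuite` is a placeholder:

    import org.scalatest.FunSuite
    import spark.{LocalSparkContext, SparkContext}

    // LocalSparkContext supplies a managed `sc`, stops it after each test, and
    // clears spark.master.port so Akka can rebind on the next test.
    class ExampleSuite extends FunSuite with LocalSparkContext {
      test("basic count") {
        sc = new SparkContext("local", "test")
        assert(sc.parallelize(1 to 4, 2).count() === 4)
      }
    }

The companion object's withSpark helper covers one-off contexts, as used in ClosureCleanerSuite below.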
--- .../test/scala/spark/AccumulatorSuite.scala | 32 ++------ .../src/test/scala/spark/BroadcastSuite.scala | 14 +--- .../test/scala/spark/CheckpointSuite.scala | 19 ++--- .../scala/spark/ClosureCleanerSuite.scala | 73 +++++++++---------- .../test/scala/spark/DistributedSuite.scala | 23 ++---- core/src/test/scala/spark/FailureSuite.scala | 14 +--- .../test/scala/spark/FileServerSuite.scala | 16 ++-- core/src/test/scala/spark/FileSuite.scala | 16 +--- .../test/scala/spark/LocalSparkContext.scala | 41 +++++++++++ .../scala/spark/MapOutputTrackerSuite.scala | 7 +- .../test/scala/spark/PartitioningSuite.scala | 15 +--- core/src/test/scala/spark/PipedRDDSuite.scala | 16 +--- core/src/test/scala/spark/RDDSuite.scala | 14 +--- core/src/test/scala/spark/ShuffleSuite.scala | 14 +--- core/src/test/scala/spark/SortingSuite.scala | 13 +--- .../src/test/scala/spark/ThreadingSuite.scala | 14 +--- .../spark/scheduler/TaskContextSuite.scala | 14 +--- 17 files changed, 109 insertions(+), 246 deletions(-) create mode 100644 core/src/test/scala/spark/LocalSparkContext.scala diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala index d8be99dde7..78d64a44ae 100644 --- a/core/src/test/scala/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/spark/AccumulatorSuite.scala @@ -1,6 +1,5 @@ package spark -import org.scalatest.BeforeAndAfter import org.scalatest.FunSuite import org.scalatest.matchers.ShouldMatchers import collection.mutable @@ -9,18 +8,7 @@ import scala.math.exp import scala.math.signum import spark.SparkContext._ -class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter { - - var sc: SparkContext = null - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkContext { test ("basic accumulation"){ sc = new SparkContext("local", "test") @@ -53,10 +41,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter for (i <- 1 to maxI) { v should contain(i) } - sc.stop() - sc = null - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + resetSparkContext() } } @@ -86,10 +71,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter x => acc.value += x } } should produce [SparkException] - sc.stop() - sc = null - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + resetSparkContext() } } @@ -115,10 +97,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter bufferAcc.value should contain(i) mapAcc.value should contain (i -> i.toString) } - sc.stop() - sc = null - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + resetSparkContext() } } @@ -134,8 +113,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter x => acc.localValue ++= x } acc.value should be ( (0 to maxI).toSet) - sc.stop() - sc = null + resetSparkContext() } } diff --git a/core/src/test/scala/spark/BroadcastSuite.scala b/core/src/test/scala/spark/BroadcastSuite.scala index 2d3302f0aa..362a31fb0d 100644 --- a/core/src/test/scala/spark/BroadcastSuite.scala +++ 
b/core/src/test/scala/spark/BroadcastSuite.scala @@ -1,20 +1,8 @@ package spark import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter -class BroadcastSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class BroadcastSuite extends FunSuite with LocalSparkContext { test("basic broadcast") { sc = new SparkContext("local", "test") diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 51573254ca..33c317720c 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -1,34 +1,27 @@ package spark -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.FunSuite import java.io.File import spark.rdd._ import spark.SparkContext._ import storage.StorageLevel -class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { +class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { initLogging() - var sc: SparkContext = _ var checkpointDir: File = _ val partitioner = new HashPartitioner(2) - before { + override def beforeEach() { + super.beforeEach() checkpointDir = File.createTempFile("temp", "") checkpointDir.delete() - sc = new SparkContext("local", "test") sc.setCheckpointDir(checkpointDir.toString) } - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - + override def afterEach() { + super.afterEach() if (checkpointDir != null) { checkpointDir.delete() } diff --git a/core/src/test/scala/spark/ClosureCleanerSuite.scala b/core/src/test/scala/spark/ClosureCleanerSuite.scala index dfa2de80e6..b2d0dd4627 100644 --- a/core/src/test/scala/spark/ClosureCleanerSuite.scala +++ b/core/src/test/scala/spark/ClosureCleanerSuite.scala @@ -3,6 +3,7 @@ package spark import java.io.NotSerializableException import org.scalatest.FunSuite +import spark.LocalSparkContext._ import SparkContext._ class ClosureCleanerSuite extends FunSuite { @@ -43,13 +44,10 @@ object TestObject { def run(): Int = { var nonSer = new NonSerializable var x = 5 - val sc = new SparkContext("local", "test") - val nums = sc.parallelize(Array(1, 2, 3, 4)) - val answer = nums.map(_ + x).reduce(_ + _) - sc.stop() - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - return answer + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + x).reduce(_ + _) + } } } @@ -60,11 +58,10 @@ class TestClass extends Serializable { def run(): Int = { var nonSer = new NonSerializable - val sc = new SparkContext("local", "test") - val nums = sc.parallelize(Array(1, 2, 3, 4)) - val answer = nums.map(_ + getX).reduce(_ + _) - sc.stop() - return answer + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + getX).reduce(_ + _) + } } } @@ -73,11 +70,10 @@ class TestClassWithoutDefaultConstructor(x: Int) extends Serializable { def run(): Int = { var nonSer = new NonSerializable - val sc = new SparkContext("local", "test") - val nums = sc.parallelize(Array(1, 2, 3, 4)) - val answer = nums.map(_ + getX).reduce(_ + 
_) - sc.stop() - return answer + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + getX).reduce(_ + _) + } } } @@ -89,11 +85,10 @@ class TestClassWithoutFieldAccess { def run(): Int = { var nonSer2 = new NonSerializable var x = 5 - val sc = new SparkContext("local", "test") - val nums = sc.parallelize(Array(1, 2, 3, 4)) - val answer = nums.map(_ + x).reduce(_ + _) - sc.stop() - return answer + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map(_ + x).reduce(_ + _) + } } } @@ -102,16 +97,16 @@ object TestObjectWithNesting { def run(): Int = { var nonSer = new NonSerializable var answer = 0 - val sc = new SparkContext("local", "test") - val nums = sc.parallelize(Array(1, 2, 3, 4)) - var y = 1 - for (i <- 1 to 4) { - var nonSer2 = new NonSerializable - var x = i - answer += nums.map(_ + x + y).reduce(_ + _) + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + var y = 1 + for (i <- 1 to 4) { + var nonSer2 = new NonSerializable + var x = i + answer += nums.map(_ + x + y).reduce(_ + _) + } + answer } - sc.stop() - return answer } } @@ -121,14 +116,14 @@ class TestClassWithNesting(val y: Int) extends Serializable { def run(): Int = { var nonSer = new NonSerializable var answer = 0 - val sc = new SparkContext("local", "test") - val nums = sc.parallelize(Array(1, 2, 3, 4)) - for (i <- 1 to 4) { - var nonSer2 = new NonSerializable - var x = i - answer += nums.map(_ + x + getY).reduce(_ + _) + return withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + for (i <- 1 to 4) { + var nonSer2 = new NonSerializable + var x = i + answer += nums.map(_ + x + getY).reduce(_ + _) + } + answer } - sc.stop() - return answer } } diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index cacc2796b6..83a2a549a9 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -15,41 +15,28 @@ import scala.collection.mutable.ArrayBuffer import SparkContext._ import storage.StorageLevel -class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter { +class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter with LocalSparkContext { val clusterUrl = "local-cluster[2,1,512]" - @transient var sc: SparkContext = _ - after { - if (sc != null) { - sc.stop() - sc = null - } System.clearProperty("spark.reducer.maxMbInFlight") System.clearProperty("spark.storage.memoryFraction") - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") } test("local-cluster format") { sc = new SparkContext("local-cluster[2,1,512]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) - sc.stop() - System.clearProperty("spark.master.port") + resetSparkContext() sc = new SparkContext("local-cluster[2 , 1 , 512]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) - sc.stop() - System.clearProperty("spark.master.port") + resetSparkContext() sc = new SparkContext("local-cluster[2, 1, 512]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) - sc.stop() - System.clearProperty("spark.master.port") + resetSparkContext() sc = new SparkContext("local-cluster[ 2, 1, 512 ]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) - sc.stop() - 
System.clearProperty("spark.master.port") - sc = null + resetSparkContext() } test("simple groupByKey") { diff --git a/core/src/test/scala/spark/FailureSuite.scala b/core/src/test/scala/spark/FailureSuite.scala index a3454f25f6..8c1445a465 100644 --- a/core/src/test/scala/spark/FailureSuite.scala +++ b/core/src/test/scala/spark/FailureSuite.scala @@ -1,7 +1,6 @@ package spark import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import org.scalatest.prop.Checkers import scala.collection.mutable.ArrayBuffer @@ -23,18 +22,7 @@ object FailureSuiteState { } } -class FailureSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class FailureSuite extends FunSuite with LocalSparkContext { // Run a 3-task map job in which task 1 deterministically fails once, and check // whether the job completes successfully and we ran 4 tasks in total. diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala index b4283d9604..8215cbde02 100644 --- a/core/src/test/scala/spark/FileServerSuite.scala +++ b/core/src/test/scala/spark/FileServerSuite.scala @@ -2,17 +2,16 @@ package spark import com.google.common.io.Files import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import java.io.{File, PrintWriter, FileReader, BufferedReader} import SparkContext._ -class FileServerSuite extends FunSuite with BeforeAndAfter { +class FileServerSuite extends FunSuite with LocalSparkContext { - @transient var sc: SparkContext = _ @transient var tmpFile : File = _ @transient var testJarFile : File = _ - before { + override def beforeEach() { + super.beforeEach() // Create a sample text file val tmpdir = new File(Files.createTempDir(), "test") tmpdir.mkdir() @@ -22,17 +21,12 @@ class FileServerSuite extends FunSuite with BeforeAndAfter { pw.close() } - after { - if (sc != null) { - sc.stop() - sc = null - } + override def afterEach() { + super.afterEach() // Clean up downloaded file if (tmpFile.exists) { tmpFile.delete() } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") } test("Distributing files locally") { diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala index 554bea53a9..91b48c7456 100644 --- a/core/src/test/scala/spark/FileSuite.scala +++ b/core/src/test/scala/spark/FileSuite.scala @@ -6,24 +6,12 @@ import scala.io.Source import com.google.common.io.Files import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import org.apache.hadoop.io._ import SparkContext._ -class FileSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } - +class FileSuite extends FunSuite with LocalSparkContext { + test("text files") { sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() diff --git a/core/src/test/scala/spark/LocalSparkContext.scala b/core/src/test/scala/spark/LocalSparkContext.scala new file mode 100644 index 0000000000..b5e31ddae3 --- /dev/null +++ b/core/src/test/scala/spark/LocalSparkContext.scala @@ -0,0 +1,41 @@ +package spark + +import 
org.scalatest.Suite +import org.scalatest.BeforeAndAfterEach + +/** Manages a local `sc` {@link SparkContext} variable, correctly stopping it after each test. */ +trait LocalSparkContext extends BeforeAndAfterEach { self: Suite => + + @transient var sc: SparkContext = _ + + override def afterEach() { + resetSparkContext() + super.afterEach() + } + + def resetSparkContext() = { + if (sc != null) { + LocalSparkContext.stop(sc) + sc = null + } + } + +} + +object LocalSparkContext { + def stop(sc: SparkContext) { + sc.stop() + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port") + } + + /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */ + def withSpark[T](sc: SparkContext)(f: SparkContext => T) = { + try { + f(sc) + } finally { + stop(sc) + } + } + +} \ No newline at end of file diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index d3dd3a8fa4..774bbd65b1 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -1,17 +1,13 @@ package spark import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import akka.actor._ import spark.scheduler.MapStatus import spark.storage.BlockManagerId import spark.util.AkkaUtils -class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { - after { - System.clearProperty("spark.master.port") - } +class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("compressSize") { assert(MapOutputTracker.compressSize(0L) === 0) @@ -81,7 +77,6 @@ class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { } test("remote fetch") { - System.clearProperty("spark.master.host") val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0) System.setProperty("spark.master.port", boundPort.toString) diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala index eb3c8f238f..af1107cd19 100644 --- a/core/src/test/scala/spark/PartitioningSuite.scala +++ b/core/src/test/scala/spark/PartitioningSuite.scala @@ -1,25 +1,12 @@ package spark import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import scala.collection.mutable.ArrayBuffer import SparkContext._ -class PartitioningSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if(sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } - +class PartitioningSuite extends FunSuite with LocalSparkContext { test("HashPartitioner equality") { val p2 = new HashPartitioner(2) diff --git a/core/src/test/scala/spark/PipedRDDSuite.scala b/core/src/test/scala/spark/PipedRDDSuite.scala index 9b84b29227..a6344edf8f 100644 --- a/core/src/test/scala/spark/PipedRDDSuite.scala +++ b/core/src/test/scala/spark/PipedRDDSuite.scala @@ -1,21 +1,9 @@ package spark import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import SparkContext._ -class PipedRDDSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class PipedRDDSuite extends FunSuite with LocalSparkContext { test("basic pipe") { sc = 
new SparkContext("local", "test") @@ -51,5 +39,3 @@ class PipedRDDSuite extends FunSuite with BeforeAndAfter { } } - - diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index db217f8482..592427e97a 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -2,23 +2,11 @@ package spark import scala.collection.mutable.HashMap import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import spark.rdd.CoalescedRDD import SparkContext._ -class RDDSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class RDDSuite extends FunSuite with LocalSparkContext { test("basic operations") { sc = new SparkContext("local", "test") diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index bebb8ebe86..3493b9511f 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -3,7 +3,6 @@ package spark import scala.collection.mutable.ArrayBuffer import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter import org.scalatest.matchers.ShouldMatchers import org.scalatest.prop.Checkers import org.scalacheck.Arbitrary._ @@ -15,18 +14,7 @@ import com.google.common.io.Files import spark.rdd.ShuffledRDD import spark.SparkContext._ -class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { test("groupByKey") { sc = new SparkContext("local", "test") diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala index 1ad11ff4c3..edb8c839fc 100644 --- a/core/src/test/scala/spark/SortingSuite.scala +++ b/core/src/test/scala/spark/SortingSuite.scala @@ -5,18 +5,7 @@ import org.scalatest.BeforeAndAfter import org.scalatest.matchers.ShouldMatchers import SparkContext._ -class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with Logging { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers with Logging { test("sortByKey") { sc = new SparkContext("local", "test") diff --git a/core/src/test/scala/spark/ThreadingSuite.scala b/core/src/test/scala/spark/ThreadingSuite.scala index e9b1837d89..ff315b6693 100644 --- a/core/src/test/scala/spark/ThreadingSuite.scala +++ b/core/src/test/scala/spark/ThreadingSuite.scala @@ -22,19 +22,7 @@ object ThreadingSuiteState { } } -class ThreadingSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if(sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } - +class ThreadingSuite extends FunSuite with LocalSparkContext { test("accessing SparkContext form a 
different thread") { sc = new SparkContext("local", "test") diff --git a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala index ba6f8b588f..a5db7103f5 100644 --- a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala @@ -6,19 +6,9 @@ import spark.TaskContext import spark.RDD import spark.SparkContext import spark.Split +import spark.LocalSparkContext -class TaskContextSuite extends FunSuite with BeforeAndAfter { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") - } +class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { test("Calls executeOnCompleteCallbacks after failure") { var completed = false From b6fc6e67521e8a9a5291693cce3dc766da244395 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 24 Jan 2013 14:28:05 -0800 Subject: [PATCH 199/291] SPARK-541: Adding a warning for invalid Master URL Right now Spark silently parses master URL's which do not match any known regex as a Mesos URL. The Mesos error message when an invalid URL gets passed is really confusing, so this warns the user when the implicit conversion is happening. --- core/src/main/scala/spark/SparkContext.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 66bdbe7cda..bc9fdee8b6 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -112,6 +112,8 @@ class SparkContext( val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r // Regular expression for connecting to Spark deploy clusters val SPARK_REGEX = """(spark://.*)""".r + //Regular expression for connection to Mesos cluster + val MESOS_REGEX = """(mesos://.*)""".r master match { case "local" => @@ -152,6 +154,9 @@ class SparkContext( scheduler case _ => + if (MESOS_REGEX.findFirstIn(master).isEmpty) { + logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master)) + } MesosNativeLibrary.load() val scheduler = new ClusterScheduler(this) val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean From 7dfb82a992d47491174d7929e31351d26cadfcda Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 22 Jan 2013 15:25:41 -0600 Subject: [PATCH 200/291] Replace old 'master' term with 'driver'. 
--- bagel/src/test/scala/bagel/BagelSuite.scala | 2 +- .../main/scala/spark/MapOutputTracker.scala | 10 +-- core/src/main/scala/spark/SparkContext.scala | 20 ++--- core/src/main/scala/spark/SparkEnv.scala | 22 +++--- .../spark/broadcast/BitTorrentBroadcast.scala | 24 +++--- .../scala/spark/broadcast/Broadcast.scala | 6 +- .../spark/broadcast/BroadcastFactory.scala | 4 +- .../scala/spark/broadcast/HttpBroadcast.scala | 6 +- .../scala/spark/broadcast/MultiTracker.scala | 35 +++++---- .../scala/spark/broadcast/TreeBroadcast.scala | 52 ++++++------- .../spark/deploy/LocalSparkCluster.scala | 34 ++++----- .../spark/deploy/client/ClientListener.scala | 4 +- .../scala/spark/deploy/master/JobInfo.scala | 2 +- .../scala/spark/deploy/master/Master.scala | 18 ++--- .../executor/StandaloneExecutorBackend.scala | 26 +++---- .../cluster/SparkDeploySchedulerBackend.scala | 33 +++++---- .../cluster/StandaloneClusterMessage.scala | 8 +- .../cluster/StandaloneSchedulerBackend.scala | 74 +++++++++---------- .../mesos/CoarseMesosSchedulerBackend.scala | 6 +- .../spark/storage/BlockManagerMaster.scala | 69 +++++++++-------- .../scala/spark/storage/ThreadingTest.scala | 6 +- core/src/test/scala/spark/JavaAPISuite.java | 2 +- .../test/scala/spark/LocalSparkContext.scala | 2 +- .../scala/spark/MapOutputTrackerSuite.scala | 2 +- docs/configuration.md | 12 +-- python/pyspark/tests.py | 2 +- .../src/test/scala/spark/repl/ReplSuite.scala | 2 +- .../dstream/NetworkInputDStream.scala | 4 +- .../java/spark/streaming/JavaAPISuite.java | 2 +- .../streaming/BasicOperationsSuite.scala | 2 +- .../spark/streaming/CheckpointSuite.scala | 2 +- .../scala/spark/streaming/FailureSuite.scala | 2 +- .../spark/streaming/InputStreamsSuite.scala | 2 +- .../streaming/WindowOperationsSuite.scala | 2 +- 34 files changed, 248 insertions(+), 251 deletions(-) diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala index ca59f46843..3c2f9c4616 100644 --- a/bagel/src/test/scala/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/bagel/BagelSuite.scala @@ -23,7 +23,7 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter { sc = null } // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") } test("halting by voting") { diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index ac02f3363a..d4f5164f7d 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -38,10 +38,7 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac } } -private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logging { - val ip: String = System.getProperty("spark.master.host", "localhost") - val port: Int = System.getProperty("spark.master.port", "7077").toInt - val actorName: String = "MapOutputTracker" +private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolean) extends Logging { val timeout = 10.seconds @@ -56,11 +53,14 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea var cacheGeneration = generation val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] - var trackerActor: ActorRef = if (isMaster) { + val actorName: String = "MapOutputTracker" + var trackerActor: ActorRef = if (isDriver) { val actor = 
actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName) logInfo("Registered MapOutputTrackerActor actor") actor } else { + val ip = System.getProperty("spark.driver.host", "localhost") + val port = System.getProperty("spark.driver.port", "7077").toInt val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) actorSystem.actorFor(url) } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index bc9fdee8b6..d4991cb1e0 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -66,20 +66,20 @@ class SparkContext( // Ensure logging is initialized before we spawn any threads initLogging() - // Set Spark master host and port system properties - if (System.getProperty("spark.master.host") == null) { - System.setProperty("spark.master.host", Utils.localIpAddress) + // Set Spark driver host and port system properties + if (System.getProperty("spark.driver.host") == null) { + System.setProperty("spark.driver.host", Utils.localIpAddress) } - if (System.getProperty("spark.master.port") == null) { - System.setProperty("spark.master.port", "0") + if (System.getProperty("spark.driver.port") == null) { + System.setProperty("spark.driver.port", "0") } private val isLocal = (master == "local" || master.startsWith("local[")) // Create the Spark execution environment (cache, map output tracker, etc) private[spark] val env = SparkEnv.createFromSystemProperties( - System.getProperty("spark.master.host"), - System.getProperty("spark.master.port").toInt, + System.getProperty("spark.driver.host"), + System.getProperty("spark.driver.port").toInt, true, isLocal) SparkEnv.set(env) @@ -396,14 +396,14 @@ class SparkContext( /** * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `+=` method. Only the driver can access the accumulator's `value`. */ def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = new Accumulator(initialValue, param) /** * Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`. - * Only the master can access the accumuable's `value`. + * Only the driver can access the accumuable's `value`. * @tparam T accumulator type * @tparam R type that can be added to the accumulator */ @@ -530,7 +530,7 @@ class SparkContext( /** * Run a function on a given set of partitions in an RDD and return the results. This is the main * entry point to the scheduler, by which all actions get launched. The allowLocal flag specifies - * whether the scheduler can run the computation on the master rather than shipping it out to the + * whether the scheduler can run the computation on the driver rather than shipping it out to the * cluster, for short actions like first(). 
*/ def runJob[T, U: ClassManifest]( diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 2a7a8af83d..4034af610c 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -60,15 +60,15 @@ object SparkEnv extends Logging { def createFromSystemProperties( hostname: String, port: Int, - isMaster: Boolean, + isDriver: Boolean, isLocal: Boolean ) : SparkEnv = { val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port) - // Bit of a hack: If this is the master and our port was 0 (meaning bind to any free port), - // figure out which port number Akka actually bound to and set spark.master.port to it. - if (isMaster && port == 0) { - System.setProperty("spark.master.port", boundPort.toString) + // Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port), + // figure out which port number Akka actually bound to and set spark.driver.port to it. + if (isDriver && port == 0) { + System.setProperty("spark.driver.port", boundPort.toString) } val classLoader = Thread.currentThread.getContextClassLoader @@ -82,22 +82,22 @@ object SparkEnv extends Logging { val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer") - val masterIp: String = System.getProperty("spark.master.host", "localhost") - val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt + val driverIp: String = System.getProperty("spark.driver.host", "localhost") + val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt val blockManagerMaster = new BlockManagerMaster( - actorSystem, isMaster, isLocal, masterIp, masterPort) + actorSystem, isDriver, isLocal, driverIp, driverPort) val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer) val connectionManager = blockManager.connectionManager - val broadcastManager = new BroadcastManager(isMaster) + val broadcastManager = new BroadcastManager(isDriver) val closureSerializer = instantiateClass[Serializer]( "spark.closure.serializer", "spark.JavaSerializer") val cacheManager = new CacheManager(blockManager) - val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster) + val mapOutputTracker = new MapOutputTracker(actorSystem, isDriver) val shuffleFetcher = instantiateClass[ShuffleFetcher]( "spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher") @@ -109,7 +109,7 @@ object SparkEnv extends Logging { // Set the sparkFiles directory, used when downloading dependencies. In local mode, // this is a temporary directory; in distributed mode, this is the executor's current working // directory. - val sparkFilesDir: String = if (isMaster) { + val sparkFilesDir: String = if (isDriver) { Utils.createTempDir().getAbsolutePath } else { "." 
diff --git a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala index 386f505f2a..adcb2d2415 100644 --- a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala @@ -31,7 +31,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: @transient var totalBlocks = -1 @transient var hasBlocks = new AtomicInteger(0) - // Used ONLY by Master to track how many unique blocks have been sent out + // Used ONLY by driver to track how many unique blocks have been sent out @transient var sentBlocks = new AtomicInteger(0) @transient var listenPortLock = new Object @@ -42,7 +42,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: @transient var serveMR: ServeMultipleRequests = null - // Used only in Master + // Used only in driver @transient var guideMR: GuideMultipleRequests = null // Used only in Workers @@ -99,14 +99,14 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: } // Must always come AFTER listenPort is created - val masterSource = + val driverSource = SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes) hasBlocksBitVector.synchronized { - masterSource.hasBlocksBitVector = hasBlocksBitVector + driverSource.hasBlocksBitVector = hasBlocksBitVector } // In the beginning, this is the only known source to Guide - listOfSources += masterSource + listOfSources += driverSource // Register with the Tracker MultiTracker.registerBroadcast(id, @@ -122,7 +122,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: case None => logInfo("Started reading broadcast variable " + id) - // Initializing everything because Master will only send null/0 values + // Initializing everything because driver will only send null/0 values // Only the 1st worker in a node can be here. Others will get from cache initializeWorkerVariables() @@ -151,7 +151,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: } } - // Initialize variables in the worker node. Master sends everything as 0/null + // Initialize variables in the worker node. Driver sends everything as 0/null private def initializeWorkerVariables() { arrayOfBlocks = null hasBlocksBitVector = null @@ -248,7 +248,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: // Receive source information from Guide var suitableSources = oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]] - logDebug("Received suitableSources from Master " + suitableSources) + logDebug("Received suitableSources from Driver " + suitableSources) addToListOfSources(suitableSources) @@ -532,7 +532,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: oosSource.writeObject(blockToAskFor) oosSource.flush() - // CHANGED: Master might send some other block than the one + // CHANGED: Driver might send some other block than the one // requested to ensure fast spreading of all blocks. 
val recvStartTime = System.currentTimeMillis val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] @@ -982,9 +982,9 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: // Receive which block to send var blockToSend = ois.readObject.asInstanceOf[Int] - // If it is master AND at least one copy of each block has not been + // If it is driver AND at least one copy of each block has not been // sent out already, MODIFY blockToSend - if (MultiTracker.isMaster && sentBlocks.get < totalBlocks) { + if (MultiTracker.isDriver && sentBlocks.get < totalBlocks) { blockToSend = sentBlocks.getAndIncrement } @@ -1031,7 +1031,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: private[spark] class BitTorrentBroadcastFactory extends BroadcastFactory { - def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) } + def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) } def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = new BitTorrentBroadcast[T](value_, isLocal, id) diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala index 2ffe7f741d..415bde5d67 100644 --- a/core/src/main/scala/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/spark/broadcast/Broadcast.scala @@ -15,7 +15,7 @@ abstract class Broadcast[T](private[spark] val id: Long) extends Serializable { } private[spark] -class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializable { +class BroadcastManager(val _isDriver: Boolean) extends Logging with Serializable { private var initialized = false private var broadcastFactory: BroadcastFactory = null @@ -33,7 +33,7 @@ class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializabl Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] // Initialize appropriate BroadcastFactory and BroadcastObject - broadcastFactory.initialize(isMaster) + broadcastFactory.initialize(isDriver) initialized = true } @@ -49,5 +49,5 @@ class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializabl def newBroadcast[T](value_ : T, isLocal: Boolean) = broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) - def isMaster = isMaster_ + def isDriver = _isDriver } diff --git a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala index ab6d302827..5c6184c3c7 100644 --- a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala @@ -7,7 +7,7 @@ package spark.broadcast * entire Spark job. 
*/ private[spark] trait BroadcastFactory { - def initialize(isMaster: Boolean): Unit - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] + def initialize(isDriver: Boolean): Unit + def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] def stop(): Unit } diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala index 8e490e6bad..7e30b8f7d2 100644 --- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala @@ -48,7 +48,7 @@ extends Broadcast[T](id) with Logging with Serializable { } private[spark] class HttpBroadcastFactory extends BroadcastFactory { - def initialize(isMaster: Boolean) { HttpBroadcast.initialize(isMaster) } + def initialize(isDriver: Boolean) { HttpBroadcast.initialize(isDriver) } def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = new HttpBroadcast[T](value_, isLocal, id) @@ -69,12 +69,12 @@ private object HttpBroadcast extends Logging { private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup) - def initialize(isMaster: Boolean) { + def initialize(isDriver: Boolean) { synchronized { if (!initialized) { bufferSize = System.getProperty("spark.buffer.size", "65536").toInt compress = System.getProperty("spark.broadcast.compress", "true").toBoolean - if (isMaster) { + if (isDriver) { createServer() } serverUri = System.getProperty("spark.httpBroadcast.uri") diff --git a/core/src/main/scala/spark/broadcast/MultiTracker.scala b/core/src/main/scala/spark/broadcast/MultiTracker.scala index 5e76dedb94..3fd77af73f 100644 --- a/core/src/main/scala/spark/broadcast/MultiTracker.scala +++ b/core/src/main/scala/spark/broadcast/MultiTracker.scala @@ -23,25 +23,24 @@ extends Logging { var ranGen = new Random private var initialized = false - private var isMaster_ = false + private var _isDriver = false private var stopBroadcast = false private var trackMV: TrackMultipleValues = null - def initialize(isMaster__ : Boolean) { + def initialize(__isDriver: Boolean) { synchronized { if (!initialized) { + _isDriver = __isDriver - isMaster_ = isMaster__ - - if (isMaster) { + if (isDriver) { trackMV = new TrackMultipleValues trackMV.setDaemon(true) trackMV.start() - // Set masterHostAddress to the master's IP address for the slaves to read - System.setProperty("spark.MultiTracker.MasterHostAddress", Utils.localIpAddress) + // Set DriverHostAddress to the driver's IP address for the slaves to read + System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress) } initialized = true @@ -54,10 +53,10 @@ extends Logging { } // Load common parameters - private var MasterHostAddress_ = System.getProperty( - "spark.MultiTracker.MasterHostAddress", "") - private var MasterTrackerPort_ = System.getProperty( - "spark.broadcast.masterTrackerPort", "11111").toInt + private var DriverHostAddress_ = System.getProperty( + "spark.MultiTracker.DriverHostAddress", "") + private var DriverTrackerPort_ = System.getProperty( + "spark.broadcast.driverTrackerPort", "11111").toInt private var BlockSize_ = System.getProperty( "spark.broadcast.blockSize", "4096").toInt * 1024 private var MaxRetryCount_ = System.getProperty( @@ -91,11 +90,11 @@ extends Logging { private var EndGameFraction_ = System.getProperty( "spark.broadcast.endGameFraction", "0.95").toDouble - def isMaster = isMaster_ + def isDriver = _isDriver // Common config params - def MasterHostAddress = MasterHostAddress_ - def MasterTrackerPort = 
MasterTrackerPort_ + def DriverHostAddress = DriverHostAddress_ + def DriverTrackerPort = DriverTrackerPort_ def BlockSize = BlockSize_ def MaxRetryCount = MaxRetryCount_ @@ -123,7 +122,7 @@ extends Logging { var threadPool = Utils.newDaemonCachedThreadPool() var serverSocket: ServerSocket = null - serverSocket = new ServerSocket(MasterTrackerPort) + serverSocket = new ServerSocket(DriverTrackerPort) logInfo("TrackMultipleValues started at " + serverSocket) try { @@ -235,7 +234,7 @@ extends Logging { try { // Connect to the tracker to find out GuideInfo clientSocketToTracker = - new Socket(MultiTracker.MasterHostAddress, MultiTracker.MasterTrackerPort) + new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort) oosTracker = new ObjectOutputStream(clientSocketToTracker.getOutputStream) oosTracker.flush() @@ -276,7 +275,7 @@ extends Logging { } def registerBroadcast(id: Long, gInfo: SourceInfo) { - val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort) + val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort) val oosST = new ObjectOutputStream(socket.getOutputStream) oosST.flush() val oisST = new ObjectInputStream(socket.getInputStream) @@ -303,7 +302,7 @@ extends Logging { } def unregisterBroadcast(id: Long) { - val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort) + val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort) val oosST = new ObjectOutputStream(socket.getOutputStream) oosST.flush() val oisST = new ObjectInputStream(socket.getInputStream) diff --git a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala index f573512835..c55c476117 100644 --- a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala @@ -98,7 +98,7 @@ extends Broadcast[T](id) with Logging with Serializable { case None => logInfo("Started reading broadcast variable " + id) - // Initializing everything because Master will only send null/0 values + // Initializing everything because Driver will only send null/0 values // Only the 1st worker in a node can be here. 
Others will get from cache initializeWorkerVariables() @@ -157,55 +157,55 @@ extends Broadcast[T](id) with Logging with Serializable { listenPortLock.synchronized { listenPortLock.wait() } } - var clientSocketToMaster: Socket = null - var oosMaster: ObjectOutputStream = null - var oisMaster: ObjectInputStream = null + var clientSocketToDriver: Socket = null + var oosDriver: ObjectOutputStream = null + var oisDriver: ObjectInputStream = null // Connect and receive broadcast from the specified source, retrying the // specified number of times in case of failures var retriesLeft = MultiTracker.MaxRetryCount do { - // Connect to Master and send this worker's Information - clientSocketToMaster = new Socket(MultiTracker.MasterHostAddress, gInfo.listenPort) - oosMaster = new ObjectOutputStream(clientSocketToMaster.getOutputStream) - oosMaster.flush() - oisMaster = new ObjectInputStream(clientSocketToMaster.getInputStream) + // Connect to Driver and send this worker's Information + clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort) + oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream) + oosDriver.flush() + oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream) - logDebug("Connected to Master's guiding object") + logDebug("Connected to Driver's guiding object") // Send local source information - oosMaster.writeObject(SourceInfo(hostAddress, listenPort)) - oosMaster.flush() + oosDriver.writeObject(SourceInfo(hostAddress, listenPort)) + oosDriver.flush() - // Receive source information from Master - var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo] + // Receive source information from Driver + var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo] totalBlocks = sourceInfo.totalBlocks arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) totalBlocksLock.synchronized { totalBlocksLock.notifyAll() } totalBytes = sourceInfo.totalBytes - logDebug("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort) + logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort) val start = System.nanoTime val receptionSucceeded = receiveSingleTransmission(sourceInfo) val time = (System.nanoTime - start) / 1e9 - // Updating some statistics in sourceInfo. Master will be using them later + // Updating some statistics in sourceInfo. 
Driver will be using them later if (!receptionSucceeded) { sourceInfo.receptionFailed = true } - // Send back statistics to the Master - oosMaster.writeObject(sourceInfo) + // Send back statistics to the Driver + oosDriver.writeObject(sourceInfo) - if (oisMaster != null) { - oisMaster.close() + if (oisDriver != null) { + oisDriver.close() } - if (oosMaster != null) { - oosMaster.close() + if (oosDriver != null) { + oosDriver.close() } - if (clientSocketToMaster != null) { - clientSocketToMaster.close() + if (clientSocketToDriver != null) { + clientSocketToDriver.close() } retriesLeft -= 1 @@ -552,7 +552,7 @@ extends Broadcast[T](id) with Logging with Serializable { } private def sendObject() { - // Wait till receiving the SourceInfo from Master + // Wait till receiving the SourceInfo from Driver while (totalBlocks == -1) { totalBlocksLock.synchronized { totalBlocksLock.wait() } } @@ -576,7 +576,7 @@ extends Broadcast[T](id) with Logging with Serializable { private[spark] class TreeBroadcastFactory extends BroadcastFactory { - def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) } + def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) } def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = new TreeBroadcast[T](value_, isLocal, id) diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala index 4211d80596..ae083efc8d 100644 --- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala @@ -10,7 +10,7 @@ import spark.{Logging, Utils} import scala.collection.mutable.ArrayBuffer private[spark] -class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) extends Logging { +class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging { val localIpAddress = Utils.localIpAddress @@ -19,33 +19,31 @@ class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) var masterPort : Int = _ var masterUrl : String = _ - val slaveActorSystems = ArrayBuffer[ActorSystem]() - val slaveActors = ArrayBuffer[ActorRef]() + val workerActorSystems = ArrayBuffer[ActorSystem]() + val workerActors = ArrayBuffer[ActorRef]() def start() : String = { - logInfo("Starting a local Spark cluster with " + numSlaves + " slaves.") + logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") /* Start the Master */ val (actorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0) masterActorSystem = actorSystem masterUrl = "spark://" + localIpAddress + ":" + masterPort - val actor = masterActorSystem.actorOf( + masterActor = masterActorSystem.actorOf( Props(new Master(localIpAddress, masterPort, 0)), name = "Master") - masterActor = actor - /* Start the Slaves */ - for (slaveNum <- 1 to numSlaves) { - /* We can pretend to test distributed stuff by giving the slaves distinct hostnames. + /* Start the Workers */ + for (workerNum <- 1 to numWorkers) { + /* We can pretend to test distributed stuff by giving the workers distinct hostnames. All of 127/8 should be a loopback, we use 127.100.*.* in hopes that it is sufficiently distinctive. */ - val slaveIpAddress = "127.100.0." + (slaveNum % 256) + val workerIpAddress = "127.100.0." 
+ (workerNum % 256) val (actorSystem, boundPort) = - AkkaUtils.createActorSystem("sparkWorker" + slaveNum, slaveIpAddress, 0) - slaveActorSystems += actorSystem - val actor = actorSystem.actorOf( - Props(new Worker(slaveIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)), + AkkaUtils.createActorSystem("sparkWorker" + workerNum, workerIpAddress, 0) + workerActorSystems += actorSystem + workerActors += actorSystem.actorOf( + Props(new Worker(workerIpAddress, boundPort, 0, coresPerWorker, memoryPerWorker, masterUrl)), name = "Worker") - slaveActors += actor } return masterUrl @@ -53,9 +51,9 @@ class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) def stop() { logInfo("Shutting down local Spark cluster.") - // Stop the slaves before the master so they don't get upset that it disconnected - slaveActorSystems.foreach(_.shutdown()) - slaveActorSystems.foreach(_.awaitTermination()) + // Stop the workers before the master so they don't get upset that it disconnected + workerActorSystems.foreach(_.shutdown()) + workerActorSystems.foreach(_.awaitTermination()) masterActorSystem.shutdown() masterActorSystem.awaitTermination() } diff --git a/core/src/main/scala/spark/deploy/client/ClientListener.scala b/core/src/main/scala/spark/deploy/client/ClientListener.scala index da6abcc9c2..7035f4b394 100644 --- a/core/src/main/scala/spark/deploy/client/ClientListener.scala +++ b/core/src/main/scala/spark/deploy/client/ClientListener.scala @@ -12,7 +12,7 @@ private[spark] trait ClientListener { def disconnected(): Unit - def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int): Unit + def executorAdded(fullId: String, workerId: String, host: String, cores: Int, memory: Int): Unit - def executorRemoved(id: String, message: String, exitStatus: Option[Int]): Unit + def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit } diff --git a/core/src/main/scala/spark/deploy/master/JobInfo.scala b/core/src/main/scala/spark/deploy/master/JobInfo.scala index 130b031a2a..a274b21c34 100644 --- a/core/src/main/scala/spark/deploy/master/JobInfo.scala +++ b/core/src/main/scala/spark/deploy/master/JobInfo.scala @@ -10,7 +10,7 @@ private[spark] class JobInfo( val id: String, val desc: JobDescription, val submitDate: Date, - val actor: ActorRef) + val driver: ActorRef) { var state = JobState.WAITING var executors = new mutable.HashMap[Int, ExecutorInfo] diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 2c2cd0231b..3347207c6d 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -88,7 +88,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor execOption match { case Some(exec) => { exec.state = state - exec.job.actor ! ExecutorUpdated(execId, state, message, exitStatus) + exec.job.driver ! ExecutorUpdated(execId, state, message, exitStatus) if (ExecutorState.isFinished(state)) { val jobInfo = idToJob(jobId) // Remove this executor from the worker and job @@ -199,7 +199,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome) - exec.job.actor ! 
ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) + exec.job.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) } def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int, @@ -221,19 +221,19 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor actorToWorker -= worker.actor addressToWorker -= worker.actor.path.address for (exec <- worker.executors.values) { - exec.job.actor ! ExecutorStateChanged(exec.job.id, exec.id, ExecutorState.LOST, None, None) + exec.job.driver ! ExecutorStateChanged(exec.job.id, exec.id, ExecutorState.LOST, None, None) exec.job.executors -= exec.id } } - def addJob(desc: JobDescription, actor: ActorRef): JobInfo = { + def addJob(desc: JobDescription, driver: ActorRef): JobInfo = { val now = System.currentTimeMillis() val date = new Date(now) - val job = new JobInfo(now, newJobId(date), desc, date, actor) + val job = new JobInfo(now, newJobId(date), desc, date, driver) jobs += job idToJob(job.id) = job - actorToJob(sender) = job - addressToJob(sender.path.address) = job + actorToJob(driver) = job + addressToJob(driver.path.address) = job return job } @@ -242,8 +242,8 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor logInfo("Removing job " + job.id) jobs -= job idToJob -= job.id - actorToJob -= job.actor - addressToWorker -= job.actor.path.address + actorToJob -= job.driver + addressToWorker -= job.driver.path.address completedJobs += job // Remember it in our history waitingJobs -= job for (exec <- job.executors.values) { diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index a29bf974d2..f80f1b5274 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -16,33 +16,33 @@ import spark.scheduler.cluster.RegisterSlave private[spark] class StandaloneExecutorBackend( executor: Executor, - masterUrl: String, - slaveId: String, + driverUrl: String, + workerId: String, hostname: String, cores: Int) extends Actor with ExecutorBackend with Logging { - var master: ActorRef = null + var driver: ActorRef = null override def preStart() { try { - logInfo("Connecting to master: " + masterUrl) - master = context.actorFor(masterUrl) - master ! RegisterSlave(slaveId, hostname, cores) + logInfo("Connecting to driver: " + driverUrl) + driver = context.actorFor(driverUrl) + driver ! RegisterSlave(workerId, hostname, cores) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - context.watch(master) // Doesn't work with remote actors, but useful for testing + context.watch(driver) // Doesn't work with remote actors, but useful for testing } catch { case e: Exception => - logError("Failed to connect to master", e) + logError("Failed to connect to driver", e) System.exit(1) } } override def receive = { case RegisteredSlave(sparkProperties) => - logInfo("Successfully registered with master") + logInfo("Successfully registered with driver") executor.initialize(hostname, sparkProperties) case RegisterSlaveFailed(message) => @@ -55,24 +55,24 @@ private[spark] class StandaloneExecutorBackend( } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { - master ! StatusUpdate(slaveId, taskId, state, data) + driver ! 
StatusUpdate(workerId, taskId, state, data) } } private[spark] object StandaloneExecutorBackend { - def run(masterUrl: String, slaveId: String, hostname: String, cores: Int) { + def run(driverUrl: String, workerId: String, hostname: String, cores: Int) { // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // before getting started with all our system properties, etc val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0) val actor = actorSystem.actorOf( - Props(new StandaloneExecutorBackend(new Executor, masterUrl, slaveId, hostname, cores)), + Props(new StandaloneExecutorBackend(new Executor, driverUrl, workerId, hostname, cores)), name = "Executor") actorSystem.awaitTermination() } def main(args: Array[String]) { if (args.length != 4) { - System.err.println("Usage: StandaloneExecutorBackend ") + System.err.println("Usage: StandaloneExecutorBackend ") System.exit(1) } run(args(0), args(1), args(2), args(3).toInt) diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 4f82cd96dd..866beb6d01 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -19,7 +19,7 @@ private[spark] class SparkDeploySchedulerBackend( var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt - val executorIdToSlaveId = new HashMap[String, String] + val executorIdToWorkerId = new HashMap[String, String] // Memory used by each executor (in megabytes) val executorMemory = { @@ -34,10 +34,11 @@ private[spark] class SparkDeploySchedulerBackend( override def start() { super.start() - val masterUrl = "akka://spark@%s:%s/user/%s".format( - System.getProperty("spark.master.host"), System.getProperty("spark.master.port"), + // The endpoint for executors to talk to us + val driverUrl = "akka://spark@%s:%s/user/%s".format( + System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), StandaloneSchedulerBackend.ACTOR_NAME) - val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") + val args = Seq(driverUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone")) val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, sparkHome) @@ -55,35 +56,35 @@ private[spark] class SparkDeploySchedulerBackend( } } - def connected(jobId: String) { + override def connected(jobId: String) { logInfo("Connected to Spark cluster with job ID " + jobId) } - def disconnected() { + override def disconnected() { if (!stopping) { logError("Disconnected from Spark cluster!") scheduler.error("Disconnected from Spark cluster") } } - def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) { - executorIdToSlaveId += id -> workerId + override def executorAdded(fullId: String, workerId: String, host: String, cores: Int, memory: Int) { + executorIdToWorkerId += fullId -> workerId logInfo("Granted executor ID %s on host %s with %d cores, %s RAM".format( - id, host, cores, Utils.memoryMegabytesToString(memory))) + fullId, host, cores, Utils.memoryMegabytesToString(memory))) } 
- def executorRemoved(id: String, message: String, exitStatus: Option[Int]) { + override def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]) { val reason: ExecutorLossReason = exitStatus match { case Some(code) => ExecutorExited(code) case None => SlaveLost(message) } - logInfo("Executor %s removed: %s".format(id, message)) - executorIdToSlaveId.get(id) match { - case Some(slaveId) => - executorIdToSlaveId.remove(id) - scheduler.slaveLost(slaveId, reason) + logInfo("Executor %s removed: %s".format(fullId, message)) + executorIdToWorkerId.get(fullId) match { + case Some(workerId) => + executorIdToWorkerId.remove(fullId) + scheduler.slaveLost(workerId, reason) case None => - logInfo("No slave ID known for executor %s".format(id)) + logInfo("No worker ID known for executor %s".format(fullId)) } } } diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala index 1386cd9d44..bea9dc4f23 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala @@ -6,7 +6,7 @@ import spark.util.SerializableBuffer private[spark] sealed trait StandaloneClusterMessage extends Serializable -// Master to slaves +// Driver to executors private[spark] case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage @@ -16,7 +16,7 @@ case class RegisteredSlave(sparkProperties: Seq[(String, String)]) extends Stand private[spark] case class RegisterSlaveFailed(message: String) extends StandaloneClusterMessage -// Slaves to master +// Executors to driver private[spark] case class RegisterSlave(slaveId: String, host: String, cores: Int) extends StandaloneClusterMessage @@ -32,6 +32,6 @@ object StatusUpdate { } } -// Internal messages in master +// Internal messages in driver private[spark] case object ReviveOffers extends StandaloneClusterMessage -private[spark] case object StopMaster extends StandaloneClusterMessage +private[spark] case object StopDriver extends StandaloneClusterMessage diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index eeaae23dc8..d742a7b2bf 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -23,7 +23,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor // Use an atomic variable to track total number of cores in the cluster for simplicity and speed var totalCoreCount = new AtomicInteger(0) - class MasterActor(sparkProperties: Seq[(String, String)]) extends Actor { + class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor { val slaveActor = new HashMap[String, ActorRef] val slaveAddress = new HashMap[String, Address] val slaveHost = new HashMap[String, String] @@ -37,34 +37,34 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } def receive = { - case RegisterSlave(slaveId, host, cores) => - if (slaveActor.contains(slaveId)) { - sender ! RegisterSlaveFailed("Duplicate slave ID: " + slaveId) + case RegisterSlave(workerId, host, cores) => + if (slaveActor.contains(workerId)) { + sender ! 
RegisterSlaveFailed("Duplicate slave ID: " + workerId) } else { - logInfo("Registered slave: " + sender + " with ID " + slaveId) + logInfo("Registered slave: " + sender + " with ID " + workerId) sender ! RegisteredSlave(sparkProperties) context.watch(sender) - slaveActor(slaveId) = sender - slaveHost(slaveId) = host - freeCores(slaveId) = cores - slaveAddress(slaveId) = sender.path.address - actorToSlaveId(sender) = slaveId - addressToSlaveId(sender.path.address) = slaveId + slaveActor(workerId) = sender + slaveHost(workerId) = host + freeCores(workerId) = cores + slaveAddress(workerId) = sender.path.address + actorToSlaveId(sender) = workerId + addressToSlaveId(sender.path.address) = workerId totalCoreCount.addAndGet(cores) makeOffers() } - case StatusUpdate(slaveId, taskId, state, data) => + case StatusUpdate(workerId, taskId, state, data) => scheduler.statusUpdate(taskId, state, data.value) if (TaskState.isFinished(state)) { - freeCores(slaveId) += 1 - makeOffers(slaveId) + freeCores(workerId) += 1 + makeOffers(workerId) } case ReviveOffers => makeOffers() - case StopMaster => + case StopDriver => sender ! true context.stop(self) @@ -85,9 +85,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } // Make fake resource offers on just one slave - def makeOffers(slaveId: String) { + def makeOffers(workerId: String) { launchTasks(scheduler.resourceOffers( - Seq(new WorkerOffer(slaveId, slaveHost(slaveId), freeCores(slaveId))))) + Seq(new WorkerOffer(workerId, slaveHost(workerId), freeCores(workerId))))) } // Launch tasks returned by a set of resource offers @@ -99,24 +99,24 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } // Remove a disconnected slave from the cluster - def removeSlave(slaveId: String, reason: String) { - logInfo("Slave " + slaveId + " disconnected, so removing it") - val numCores = freeCores(slaveId) - actorToSlaveId -= slaveActor(slaveId) - addressToSlaveId -= slaveAddress(slaveId) - slaveActor -= slaveId - slaveHost -= slaveId - freeCores -= slaveId - slaveHost -= slaveId + def removeSlave(workerId: String, reason: String) { + logInfo("Slave " + workerId + " disconnected, so removing it") + val numCores = freeCores(workerId) + actorToSlaveId -= slaveActor(workerId) + addressToSlaveId -= slaveAddress(workerId) + slaveActor -= workerId + slaveHost -= workerId + freeCores -= workerId + slaveHost -= workerId totalCoreCount.addAndGet(-numCores) - scheduler.slaveLost(slaveId, SlaveLost(reason)) + scheduler.slaveLost(workerId, SlaveLost(reason)) } } - var masterActor: ActorRef = null + var driverActor: ActorRef = null val taskIdsOnSlave = new HashMap[String, HashSet[String]] - def start() { + override def start() { val properties = new ArrayBuffer[(String, String)] val iterator = System.getProperties.entrySet.iterator while (iterator.hasNext) { @@ -126,15 +126,15 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor properties += ((key, value)) } } - masterActor = actorSystem.actorOf( - Props(new MasterActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME) + driverActor = actorSystem.actorOf( + Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME) } - def stop() { + override def stop() { try { - if (masterActor != null) { + if (driverActor != null) { val timeout = 5.seconds - val future = masterActor.ask(StopMaster)(timeout) + val future = driverActor.ask(StopDriver)(timeout) Await.result(future, timeout) } } catch { @@ -143,11 +143,11 @@ 
class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } } - def reviveOffers() { - masterActor ! ReviveOffers + override def reviveOffers() { + driverActor ! ReviveOffers } - def defaultParallelism(): Int = math.max(totalCoreCount.get(), 2) + override def defaultParallelism(): Int = math.max(totalCoreCount.get(), 2) } private[spark] object StandaloneSchedulerBackend { diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala index 014906b028..7bf56a05d6 100644 --- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala @@ -104,11 +104,11 @@ private[spark] class CoarseMesosSchedulerBackend( def createCommand(offer: Offer, numCores: Int): CommandInfo = { val runScript = new File(sparkHome, "run").getCanonicalPath - val masterUrl = "akka://spark@%s:%s/user/%s".format( - System.getProperty("spark.master.host"), System.getProperty("spark.master.port"), + val driverUrl = "akka://spark@%s:%s/user/%s".format( + System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), StandaloneSchedulerBackend.ACTOR_NAME) val command = "\"%s\" spark.executor.StandaloneExecutorBackend %s %s %s %d".format( - runScript, masterUrl, offer.getSlaveId.getValue, offer.getHostname, numCores) + runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores) val environment = Environment.newBuilder() sc.executorEnvs.foreach { case (key, value) => environment.addVariables(Environment.Variable.newBuilder() diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index a3d8671834..9fd2b454a4 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -11,52 +11,51 @@ import akka.util.duration._ import spark.{Logging, SparkException, Utils} - private[spark] class BlockManagerMaster( val actorSystem: ActorSystem, - isMaster: Boolean, + isDriver: Boolean, isLocal: Boolean, - masterIp: String, - masterPort: Int) + driverIp: String, + driverPort: Int) extends Logging { val AKKA_RETRY_ATTEMPS: Int = System.getProperty("spark.akka.num.retries", "3").toInt val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt - val MASTER_AKKA_ACTOR_NAME = "BlockMasterManager" + val DRIVER_AKKA_ACTOR_NAME = "BlockMasterManager" val SLAVE_AKKA_ACTOR_NAME = "BlockSlaveManager" val DEFAULT_MANAGER_IP: String = Utils.localHostName() val timeout = 10.seconds - var masterActor: ActorRef = { - if (isMaster) { - val masterActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)), - name = MASTER_AKKA_ACTOR_NAME) + var driverActor: ActorRef = { + if (isDriver) { + val driverActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)), + name = DRIVER_AKKA_ACTOR_NAME) logInfo("Registered BlockManagerMaster Actor") - masterActor + driverActor } else { - val url = "akka://spark@%s:%s/user/%s".format(masterIp, masterPort, MASTER_AKKA_ACTOR_NAME) + val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, DRIVER_AKKA_ACTOR_NAME) logInfo("Connecting to BlockManagerMaster: " + url) actorSystem.actorFor(url) } } - /** Remove a dead host from the master actor. This is only called on the master side. */ + /** Remove a dead host from the driver actor. 
This is only called on the driver side. */ def notifyADeadHost(host: String) { tell(RemoveHost(host)) logInfo("Removed " + host + " successfully in notifyADeadHost") } /** - * Send the master actor a heart beat from the slave. Returns true if everything works out, - * false if the master does not know about the given block manager, which means the block + * Send the driver actor a heart beat from the slave. Returns true if everything works out, + * false if the driver does not know about the given block manager, which means the block * manager should re-register. */ def sendHeartBeat(blockManagerId: BlockManagerId): Boolean = { - askMasterWithRetry[Boolean](HeartBeat(blockManagerId)) + askDriverWithReply[Boolean](HeartBeat(blockManagerId)) } - /** Register the BlockManager's id with the master. */ + /** Register the BlockManager's id with the driver. */ def registerBlockManager( blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { logInfo("Trying to register BlockManager") @@ -70,25 +69,25 @@ private[spark] class BlockManagerMaster( storageLevel: StorageLevel, memSize: Long, diskSize: Long): Boolean = { - val res = askMasterWithRetry[Boolean]( + val res = askDriverWithReply[Boolean]( UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)) logInfo("Updated info of block " + blockId) res } - /** Get locations of the blockId from the master */ + /** Get locations of the blockId from the driver */ def getLocations(blockId: String): Seq[BlockManagerId] = { - askMasterWithRetry[Seq[BlockManagerId]](GetLocations(blockId)) + askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId)) } - /** Get locations of multiple blockIds from the master */ + /** Get locations of multiple blockIds from the driver */ def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { - askMasterWithRetry[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) + askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } - /** Get ids of other nodes in the cluster from the master */ + /** Get ids of other nodes in the cluster from the driver */ def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = { - val result = askMasterWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers)) + val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers)) if (result.length != numPeers) { throw new SparkException( "Error getting peers, only got " + result.size + " instead of " + numPeers) @@ -98,10 +97,10 @@ private[spark] class BlockManagerMaster( /** * Remove a block from the slaves that have it. This can only be used to remove - * blocks that the master knows about. + * blocks that the driver knows about. */ def removeBlock(blockId: String) { - askMasterWithRetry(RemoveBlock(blockId)) + askDriverWithReply(RemoveBlock(blockId)) } /** @@ -111,33 +110,33 @@ private[spark] class BlockManagerMaster( * amount of remaining memory. 
*/ def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { - askMasterWithRetry[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) + askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } - /** Stop the master actor, called only on the Spark master node */ + /** Stop the driver actor, called only on the Spark driver node */ def stop() { - if (masterActor != null) { + if (driverActor != null) { tell(StopBlockManagerMaster) - masterActor = null + driverActor = null logInfo("BlockManagerMaster stopped") } } /** Send a one-way message to the master actor, to which we expect it to reply with true. */ private def tell(message: Any) { - if (!askMasterWithRetry[Boolean](message)) { + if (!askDriverWithReply[Boolean](message)) { throw new SparkException("BlockManagerMasterActor returned false, expected true.") } } /** - * Send a message to the master actor and get its result within a default timeout, or + * Send a message to the driver actor and get its result within a default timeout, or * throw a SparkException if this fails. */ - private def askMasterWithRetry[T](message: Any): T = { + private def askDriverWithReply[T](message: Any): T = { // TODO: Consider removing multiple attempts - if (masterActor == null) { - throw new SparkException("Error sending message to BlockManager as masterActor is null " + + if (driverActor == null) { + throw new SparkException("Error sending message to BlockManager as driverActor is null " + "[message = " + message + "]") } var attempts = 0 @@ -145,7 +144,7 @@ private[spark] class BlockManagerMaster( while (attempts < AKKA_RETRY_ATTEMPS) { attempts += 1 try { - val future = masterActor.ask(message)(timeout) + val future = driverActor.ask(message)(timeout) val result = Await.result(future, timeout) if (result == null) { throw new Exception("BlockManagerMaster returned null") diff --git a/core/src/main/scala/spark/storage/ThreadingTest.scala b/core/src/main/scala/spark/storage/ThreadingTest.scala index 689f07b969..0b8f6d4303 100644 --- a/core/src/main/scala/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/spark/storage/ThreadingTest.scala @@ -75,9 +75,9 @@ private[spark] object ThreadingTest { System.setProperty("spark.kryoserializer.buffer.mb", "1") val actorSystem = ActorSystem("test") val serializer = new KryoSerializer - val masterIp: String = System.getProperty("spark.master.host", "localhost") - val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt - val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, masterIp, masterPort) + val driverIp: String = System.getProperty("spark.driver.host", "localhost") + val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt + val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, driverIp, driverPort) val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer, 1024 * 1024) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 01351de4ae..42ce6f3c74 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -46,7 +46,7 @@ public class JavaAPISuite implements Serializable { sc.stop(); sc = null; // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port"); + 
System.clearProperty("spark.driver.port"); } static class ReverseIntComparator implements Comparator, Serializable { diff --git a/core/src/test/scala/spark/LocalSparkContext.scala b/core/src/test/scala/spark/LocalSparkContext.scala index b5e31ddae3..ff00dd05dd 100644 --- a/core/src/test/scala/spark/LocalSparkContext.scala +++ b/core/src/test/scala/spark/LocalSparkContext.scala @@ -26,7 +26,7 @@ object LocalSparkContext { def stop(sc: SparkContext) { sc.stop() // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") } /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */ diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index 7d5305f1e0..718107d2b5 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -79,7 +79,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("remote fetch") { val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0) - System.setProperty("spark.master.port", boundPort.toString) + System.setProperty("spark.driver.port", boundPort.toString) val masterTracker = new MapOutputTracker(actorSystem, true) val slaveTracker = new MapOutputTracker(actorSystem, false) masterTracker.registerShuffle(10, 1) diff --git a/docs/configuration.md b/docs/configuration.md index 036a0df480..a7054b4321 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -202,7 +202,7 @@ Apart from these, the following properties are also available, and may be useful 10 Maximum message size to allow in "control plane" communication (for serialized tasks and task - results), in MB. Increase this if your tasks need to send back large results to the master + results), in MB. Increase this if your tasks need to send back large results to the driver (e.g. using collect() on a large dataset). @@ -211,7 +211,7 @@ Apart from these, the following properties are also available, and may be useful 4 Number of actor threads to use for communication. Can be useful to increase on large clusters - when the master has a lot of CPU cores. + when the driver has a lot of CPU cores. @@ -222,17 +222,17 @@ Apart from these, the following properties are also available, and may be useful - spark.master.host + spark.driver.host (local hostname) - Hostname or IP address for the master to listen on. + Hostname or IP address for the driver to listen on. - spark.master.port + spark.driver.port (random) - Port for the master to listen on. + Port for the driver to listen on. 
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 46ab34f063..df7235756d 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -26,7 +26,7 @@ class PySparkTestCase(unittest.TestCase): sys.path = self._old_sys_path # To avoid Akka rebinding to the same port, since it doesn't unbind # immediately on shutdown - self.sc.jvm.System.clearProperty("spark.master.port") + self.sc.jvm.System.clearProperty("spark.driver.port") class TestCheckpoint(PySparkTestCase): diff --git a/repl/src/test/scala/spark/repl/ReplSuite.scala b/repl/src/test/scala/spark/repl/ReplSuite.scala index db78d06d4f..43559b96d3 100644 --- a/repl/src/test/scala/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/spark/repl/ReplSuite.scala @@ -31,7 +31,7 @@ class ReplSuite extends FunSuite { if (interp.sparkContext != null) interp.sparkContext.stop() // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") return out.toString } diff --git a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala index aa6be95f30..8c322dd698 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala @@ -153,8 +153,8 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log /** A helper actor that communicates with the NetworkInputTracker */ private class NetworkReceiverActor extends Actor { logInfo("Attempting to register with tracker") - val ip = System.getProperty("spark.master.host", "localhost") - val port = System.getProperty("spark.master.port", "7077").toInt + val ip = System.getProperty("spark.driver.host", "localhost") + val port = System.getProperty("spark.driver.port", "7077").toInt val url = "akka://spark@%s:%s/user/NetworkInputTracker".format(ip, port) val tracker = env.actorSystem.actorFor(url) val timeout = 5.seconds diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java index c84e7331c7..79d6093429 100644 --- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java @@ -43,7 +43,7 @@ public class JavaAPISuite implements Serializable { ssc = null; // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port"); + System.clearProperty("spark.driver.port"); } @Test diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala index bfdf32c73e..4a036f0710 100644 --- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala @@ -10,7 +10,7 @@ class BasicOperationsSuite extends TestSuiteBase { after { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") } test("map") { diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index d2f32c189b..563a7d1458 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ 
b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -19,7 +19,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { FileUtils.deleteDirectory(new File(checkpointDir)) // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") } var ssc: StreamingContext = null diff --git a/streaming/src/test/scala/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/spark/streaming/FailureSuite.scala index 7493ac1207..c4cfffbfc1 100644 --- a/streaming/src/test/scala/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/spark/streaming/FailureSuite.scala @@ -24,7 +24,7 @@ class FailureSuite extends TestSuiteBase with BeforeAndAfter { FileUtils.deleteDirectory(new File(checkpointDir)) // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") } override def framework = "CheckpointSuite" diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index d7ba7a5d17..70ae6e3934 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -42,7 +42,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") } test("network input stream") { diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index 0c6e928835..cd9608df53 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -13,7 +13,7 @@ class WindowOperationsSuite extends TestSuiteBase { after { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.master.port") + System.clearProperty("spark.driver.port") } val largerSlideInput = Seq( From 539491bbc333834b9ae2721ae6cf3524cefb91ea Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 25 Jan 2013 09:29:59 -0800 Subject: [PATCH 201/291] code reformatting --- core/src/main/scala/spark/RDD.scala | 4 ++-- core/src/main/scala/spark/storage/BlockManagerUI.scala | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 870cc5ca78..4fcab9279a 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -94,7 +94,7 @@ abstract class RDD[T: ClassManifest]( /** How this RDD depends on any parent RDDs. */ protected def getDependencies(): List[Dependency[_]] = dependencies_ - // A friendly name for this RDD + /** A friendly name for this RDD */ var name: String = null /** Optionally overridden by subclasses to specify placement preferences. */ @@ -111,7 +111,7 @@ abstract class RDD[T: ClassManifest]( /** A unique ID for this RDD (within its SparkContext). 
*/ val id = sc.newRddId() - /* Assign a name to this RDD */ + /** Assign a name to this RDD */ def setName(_name: String) = { name = _name this diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index 35cbd59280..1003cc7a61 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -57,7 +57,8 @@ class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc) - spark.storage.html.index.render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList) + spark.storage.html.index. + render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList) } }}} ~ get { path("rdd") { parameter("id") { id => { completeWith { @@ -67,9 +68,10 @@ class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray - val filteredStorageStatusList = StorageUtils.filterStorageStatusByPrefix(storageStatusList, prefix) + val filteredStorageStatusList = StorageUtils. + filterStorageStatusByPrefix(storageStatusList, prefix) - val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).first + val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList) From 1cadaa164e9f078e4ca483edb9db7fd5507c9e64 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 25 Jan 2013 09:30:21 -0800 Subject: [PATCH 202/291] switch to TimeStampedHashMap for storing persistent Rdds --- core/src/main/scala/spark/SparkContext.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index d994648899..10ceeb3028 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -44,6 +44,7 @@ import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler} import spark.scheduler.local.LocalScheduler import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} +import util.TimeStampedHashMap /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -110,7 +111,7 @@ class SparkContext( private[spark] val addedJars = HashMap[String, Long]() // Keeps track of all persisted RDDs - private[spark] val persistentRdds = new ConcurrentHashMap[Int, RDD[_]]() + private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]() // Add each JAR given through the constructor jars.foreach { addJar(_) } From a1d9d1767d821c1e25e485e32d9356b12aba6a01 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 25 Jan 2013 10:05:26 -0800 Subject: [PATCH 203/291] fixup 1cadaa1, changed api of map --- core/src/main/scala/spark/storage/StorageUtils.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala index 63ad5c125b..a10e3a95c6 100644 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -56,8 +56,8 @@ object StorageUtils { // Find the id of the RDD, e.g. 
rdd_1 => 1 val rddId = rddKey.split("_").last.toInt // Get the friendly name for the rdd, if available. - val rddName = Option(sc.persistentRdds.get(rddId).name).getOrElse(rddKey) - val rddStorageLevel = sc.persistentRdds.get(rddId).getStorageLevel + val rddName = Option(sc.persistentRdds(rddId).name).getOrElse(rddKey) + val rddStorageLevel = sc.persistentRdds(rddId).getStorageLevel RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, memSize, diskSize) }.toArray From 8efbda0b179e3821a1221c6d78681fc74248cdac Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Fri, 25 Jan 2013 14:55:33 -0600 Subject: [PATCH 204/291] Call executeOnCompleteCallbacks in more finally blocks. --- .../scala/spark/scheduler/DAGScheduler.scala | 13 ++--- .../spark/scheduler/ShuffleMapTask.scala | 50 +++++++++---------- 2 files changed, 32 insertions(+), 31 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index b320be8863..f599eb00bd 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -40,7 +40,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with eventQueue.put(HostLost(host)) } - // Called by TaskScheduler to cancel an entier TaskSet due to repeated failures. + // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures. override def taskSetFailed(taskSet: TaskSet, reason: String) { eventQueue.put(TaskSetFailed(taskSet, reason)) } @@ -54,8 +54,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // resubmit failed stages val POLL_TIMEOUT = 10L - private val lock = new Object // Used for access to the entire DAGScheduler - private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent] val nextRunId = new AtomicInteger(0) @@ -337,9 +335,12 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val rdd = job.finalStage.rdd val split = rdd.splits(job.partitions(0)) val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0) - val result = job.func(taskContext, rdd.iterator(split, taskContext)) - taskContext.executeOnCompleteCallbacks() - job.listener.taskSucceeded(0, result) + try { + val result = job.func(taskContext, rdd.iterator(split, taskContext)) + job.listener.taskSucceeded(0, result) + } finally { + taskContext.executeOnCompleteCallbacks() + } } catch { case e: Exception => job.listener.jobFailed(e) diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 19f5328eee..83641a2a84 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -81,7 +81,7 @@ private[spark] class ShuffleMapTask( with Externalizable with Logging { - def this() = this(0, null, null, 0, null) + protected def this() = this(0, null, null, 0, null) var split = if (rdd == null) { null @@ -117,34 +117,34 @@ private[spark] class ShuffleMapTask( override def run(attemptId: Long): MapStatus = { val numOutputSplits = dep.partitioner.numPartitions - val partitioner = dep.partitioner val taskContext = new TaskContext(stageId, partition, attemptId) + try { + // Partition the map output. 
+ val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)]) + for (elem <- rdd.iterator(split, taskContext)) { + val pair = elem.asInstanceOf[(Any, Any)] + val bucketId = dep.partitioner.getPartition(pair._1) + buckets(bucketId) += pair + } + val bucketIterators = buckets.map(_.iterator) - // Partition the map output. - val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)]) - for (elem <- rdd.iterator(split, taskContext)) { - val pair = elem.asInstanceOf[(Any, Any)] - val bucketId = partitioner.getPartition(pair._1) - buckets(bucketId) += pair + val compressedSizes = new Array[Byte](numOutputSplits) + + val blockManager = SparkEnv.get.blockManager + for (i <- 0 until numOutputSplits) { + val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i + // Get a Scala iterator from Java map + val iter: Iterator[(Any, Any)] = bucketIterators(i) + val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false) + compressedSizes(i) = MapOutputTracker.compressSize(size) + } + + return new MapStatus(blockManager.blockManagerId, compressedSizes) + } finally { + // Execute the callbacks on task completion. + taskContext.executeOnCompleteCallbacks() } - val bucketIterators = buckets.map(_.iterator) - - val compressedSizes = new Array[Byte](numOutputSplits) - - val blockManager = SparkEnv.get.blockManager - for (i <- 0 until numOutputSplits) { - val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i - // Get a Scala iterator from Java map - val iter: Iterator[(Any, Any)] = bucketIterators(i) - val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false) - compressedSizes(i) = MapOutputTracker.compressSize(size) - } - - // Execute the callbacks on task completion. - taskContext.executeOnCompleteCallbacks() - - return new MapStatus(blockManager.blockManagerId, compressedSizes) } override def preferredLocations: Seq[String] = locs From 49c05608f5f27354da120e2367b6d4a63ec38948 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 25 Jan 2013 17:04:16 -0800 Subject: [PATCH 205/291] add metadatacleaner for persisentRdd map --- core/src/main/scala/spark/SparkContext.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 10ceeb3028..bff54dbdd1 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -44,7 +44,7 @@ import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler} import spark.scheduler.local.LocalScheduler import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} -import util.TimeStampedHashMap +import util.{MetadataCleaner, TimeStampedHashMap} /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -113,6 +113,9 @@ class SparkContext( // Keeps track of all persisted RDDs private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]() + private[spark] val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) + + // Add each JAR given through the constructor jars.foreach { addJar(_) } @@ -512,6 +515,7 @@ class SparkContext( /** Shut down the SparkContext. 
*/ def stop() { if (dagScheduler != null) { + metadataCleaner.cancel() dagScheduler.stop() dagScheduler = null taskScheduler = null @@ -654,6 +658,12 @@ class SparkContext( /** Register a new RDD, returning its RDD ID */ private[spark] def newRddId(): Int = nextRddId.getAndIncrement() + + private[spark] def cleanup(cleanupTime: Long) { + var sizeBefore = persistentRdds.size + persistentRdds.clearOldValues(cleanupTime) + logInfo("idToStage " + sizeBefore + " --> " + persistentRdds.size) + } } /** From d49cf0e587b7cbbd31917d9bb69f98466feb0f9f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 26 Jan 2013 15:57:01 -0800 Subject: [PATCH 206/291] Fix JavaRDDLike.flatMap(PairFlatMapFunction) (SPARK-668). This workaround is easier than rewriting JavaRDDLike in Java. --- .../scala/spark/api/java/JavaRDDLike.scala | 7 ++--- .../spark/api/java/PairFlatMapWorkaround.java | 20 +++++++++++++ core/src/test/scala/spark/JavaAPISuite.java | 28 +++++++++++++++++++ 3 files changed, 51 insertions(+), 4 deletions(-) create mode 100644 core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index b3698ffa44..4c95c989b5 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -12,7 +12,7 @@ import spark.storage.StorageLevel import com.google.common.base.Optional -trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { +trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround[T] { def wrapRDD(rdd: RDD[T]): This implicit val classManifest: ClassManifest[T] @@ -82,10 +82,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** - * Return a new RDD by first applying a function to all elements of this - * RDD, and then flattening the results. + * Part of the workaround for SPARK-668; called in PairFlatMapWorkaround.java. */ - def flatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = { + private[spark] def doFlatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = { import scala.collection.JavaConverters._ def fn = (x: T) => f.apply(x).asScala def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]] diff --git a/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java b/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java new file mode 100644 index 0000000000..68b6fd6622 --- /dev/null +++ b/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java @@ -0,0 +1,20 @@ +package spark.api.java; + +import spark.api.java.JavaPairRDD; +import spark.api.java.JavaRDDLike; +import spark.api.java.function.PairFlatMapFunction; + +import java.io.Serializable; + +/** + * Workaround for SPARK-668. + */ +class PairFlatMapWorkaround implements Serializable { + /** + * Return a new RDD by first applying a function to all elements of this + * RDD, and then flattening the results. 
+ */ + public JavaPairRDD flatMap(PairFlatMapFunction f) { + return ((JavaRDDLike ) this).doFlatMap(f); + } +} diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 01351de4ae..f50ba093e9 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -355,6 +355,34 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(11, pairs.count()); } + @Test + public void mapsFromPairsToPairs() { + List> pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD pairRDD = sc.parallelizePairs(pairs); + + // Regression test for SPARK-668: + JavaPairRDD swapped = pairRDD.flatMap( + new PairFlatMapFunction, String, Integer>() { + @Override + public Iterable> call(Tuple2 item) throws Exception { + return Collections.singletonList(item.swap()); + } + }); + swapped.collect(); + + // There was never a bug here, but it's worth testing: + pairRDD.map(new PairFunction, String, Integer>() { + @Override + public Tuple2 call(Tuple2 item) throws Exception { + return item.swap(); + } + }).collect(); + } + @Test public void mapPartitions() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); From ad4232b4dadc6290d3c4696d3cc007d3f01cb236 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Sat, 26 Jan 2013 18:07:14 -0800 Subject: [PATCH 207/291] Fix deadlock in BlockManager reregistration triggered by failed updates. --- .../scala/spark/storage/BlockManager.scala | 35 +++++++++++++++- .../spark/storage/BlockManagerSuite.scala | 40 ++++++++++++++++++- 2 files changed, 72 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 19cdaaa984..19d35b8667 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -90,7 +90,10 @@ class BlockManager( val slaveActor = master.actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)), name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) - @volatile private var shuttingDown = false + // Pending reregistration action being executed asynchronously or null if none + // is pending. Accesses should synchronize on asyncReregisterLock. + var asyncReregisterTask: Future[Unit] = null + val asyncReregisterLock = new Object private def heartBeat() { if (!master.sendHeartBeat(blockManagerId)) { @@ -147,6 +150,8 @@ class BlockManager( /** * Reregister with the master and report all blocks to it. This will be called by the heart beat * thread if our heartbeat to the block amnager indicates that we were not registered. + * + * Note that this method must be called without any BlockInfo locks held. */ def reregister() { // TODO: We might need to rate limit reregistering. @@ -155,6 +160,32 @@ class BlockManager( reportAllBlocks() } + /** + * Reregister with the master sometime soon. + */ + def asyncReregister() { + asyncReregisterLock.synchronized { + if (asyncReregisterTask == null) { + asyncReregisterTask = Future[Unit] { + reregister() + asyncReregisterLock.synchronized { + asyncReregisterTask = null + } + } + } + } + } + + /** + * For testing. Wait for any pending asynchronous reregistration; otherwise, do nothing. + */ + def waitForAsyncReregister() { + val task = asyncReregisterTask + if (task != null) { + Await.ready(task, Duration.Inf) + } + } + /** * Get storage level of local block. If no info exists for the block, then returns null. 
*/ @@ -170,7 +201,7 @@ class BlockManager( if (needReregister) { logInfo("Got told to reregister updating block " + blockId) // Reregistering will report our new block for free. - reregister() + asyncReregister() } logDebug("Told master about block " + blockId) } diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index a1aeb12f25..2165744689 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -219,18 +219,56 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT val a2 = new Array[Byte](400) store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - assert(master.getLocations("a1").size > 0, "master was not told about a1") master.notifyADeadHost(store.blockManagerId.ip) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") store.putSingle("a2", a1, StorageLevel.MEMORY_ONLY) + store.waitForAsyncReregister() assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master") assert(master.getLocations("a2").size > 0, "master was not told about a2") } + test("reregistration doesn't dead lock") { + val heartBeat = PrivateMethod[Unit]('heartBeat) + store = new BlockManager(actorSystem, master, serializer, 2000) + val a1 = new Array[Byte](400) + val a2 = List(new Array[Byte](400)) + + // try many times to trigger any deadlocks + for (i <- 1 to 100) { + master.notifyADeadHost(store.blockManagerId.ip) + val t1 = new Thread { + override def run = { + store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, true) + } + } + val t2 = new Thread { + override def run = { + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) + } + } + val t3 = new Thread { + override def run = { + store invokePrivate heartBeat() + } + } + + t1.start + t2.start + t3.start + t1.join + t2.join + t3.join + + store.dropFromMemory("a1", null) + store.dropFromMemory("a2", null) + store.waitForAsyncReregister() + } + } + test("in-memory LRU storage") { store = new BlockManager(actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) From 58fc6b2bed9f660fbf134aab188827b7d8975a62 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Sat, 26 Jan 2013 18:07:53 -0800 Subject: [PATCH 208/291] Handle duplicate registrations better. --- core/src/main/scala/spark/storage/BlockManagerMasterActor.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala index f4d026da33..2216c33b76 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -183,7 +183,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { if (blockManagerId.ip == Utils.localHostName() && !isLocal) { logInfo("Got Register Msg from master node, don't register it") - } else { + } else if (!blockManagerInfo.contains(blockManagerId)) { blockManagerIdByHost.get(blockManagerId.ip) match { case Some(managers) => // A block manager of the same host name already exists. 
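Patches 207 and 208 above change how a BlockManager re-registers with the driver: a failed heartbeat or block update no longer calls reregister() inline (which could deadlock while BlockInfo locks are held) but schedules it asynchronously, and the master actor now ignores a registration for an ID it already knows about. A minimal standalone sketch of the asynchronous guard, assuming scala.concurrent.Future rather than Spark's internal BlockManager state:

    import scala.concurrent.{Await, Future}
    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.concurrent.duration.Duration

    // Illustrative only: mirrors asyncReregister()/waitForAsyncReregister() from the patch.
    class Reregistrar(reregister: () => Unit) {
      private var pendingTask: Future[Unit] = null   // pending re-registration, or null
      private val lock = new Object                  // guards pendingTask

      def asyncReregister(): Unit = lock.synchronized {
        if (pendingTask == null) {                   // at most one re-registration in flight
          pendingTask = Future {
            reregister()
            lock.synchronized { pendingTask = null }
          }
        }
      }

      // Test helper: block until any pending re-registration has completed.
      def waitForAsyncReregister(): Unit = {
        val task = pendingTask
        if (task != null) Await.ready(task, Duration.Inf)
      }
    }

The "reregistration doesn't dead lock" test added in BlockManagerSuite exercises the same idea by racing puts, a heartbeat, and the re-registration from several threads.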
From 717b221cca79cb8a1603e9dcf7f0bb50e215ac41 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Sat, 26 Jan 2013 22:59:22 -0800 Subject: [PATCH 209/291] Detect whether we run on EC2 using ec2-metadata as well --- bin/start-master.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bin/start-master.sh b/bin/start-master.sh index a901b1c260..87feb261fe 100755 --- a/bin/start-master.sh +++ b/bin/start-master.sh @@ -26,7 +26,8 @@ fi # Set SPARK_PUBLIC_DNS so the master report the correct webUI address to the slaves if [ "$SPARK_PUBLIC_DNS" = "" ]; then # If we appear to be running on EC2, use the public address by default: - if [[ `hostname` == *ec2.internal ]]; then + # NOTE: ec2-metadata is installed on Amazon Linux AMI. Check based on that and hostname + if command -v ec2-metadata > /dev/null || [[ `hostname` == *ec2.internal ]]; then export SPARK_PUBLIC_DNS=`wget -q -O - http://instance-data.ec2.internal/latest/meta-data/public-hostname` fi fi From 44b4a0f88fcb31727347b755ae8ec14d69571b52 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 27 Jan 2013 19:23:49 -0800 Subject: [PATCH 210/291] Track workers by executor ID instead of hostname to allow multiple executors per machine and remove the need for multiple IP addresses in unit tests. --- .../main/scala/spark/MapOutputTracker.scala | 4 +- core/src/main/scala/spark/SparkContext.scala | 6 +- core/src/main/scala/spark/SparkEnv.scala | 9 +- .../spark/deploy/LocalSparkCluster.scala | 16 +-- .../scala/spark/deploy/master/Master.scala | 4 +- .../spark/deploy/worker/ExecutorRunner.scala | 2 +- .../main/scala/spark/executor/Executor.scala | 4 +- .../spark/executor/MesosExecutorBackend.scala | 3 +- .../executor/StandaloneExecutorBackend.scala | 14 +-- .../scala/spark/scheduler/DAGScheduler.scala | 44 +++---- .../spark/scheduler/DAGSchedulerEvent.scala | 2 +- .../scala/spark/scheduler/MapStatus.scala | 6 +- .../main/scala/spark/scheduler/Stage.scala | 11 +- .../scheduler/TaskSchedulerListener.scala | 2 +- .../scheduler/cluster/ClusterScheduler.scala | 110 ++++++++++-------- .../cluster/SparkDeploySchedulerBackend.scala | 4 +- .../cluster/StandaloneSchedulerBackend.scala | 64 +++++----- .../scheduler/cluster/TaskDescription.scala | 2 +- .../spark/scheduler/cluster/TaskInfo.scala | 7 +- .../scheduler/cluster/TaskSetManager.scala | 38 +++--- .../spark/scheduler/cluster/WorkerOffer.scala | 4 +- .../mesos/MesosSchedulerBackend.scala | 2 +- .../scala/spark/storage/BlockManager.scala | 10 +- .../scala/spark/storage/BlockManagerId.scala | 27 +++-- .../spark/storage/BlockManagerMaster.scala | 12 +- .../storage/BlockManagerMasterActor.scala | 66 +++++------ .../spark/storage/BlockManagerMessages.scala | 2 +- .../scala/spark/storage/BlockManagerUI.scala | 7 +- .../scala/spark/storage/ThreadingTest.scala | 3 +- .../src/main/scala/spark/util/AkkaUtils.scala | 6 +- .../scala/spark/util/TimeStampedHashMap.scala | 4 +- core/src/test/scala/spark/DriverSuite.scala | 5 +- .../scala/spark/MapOutputTrackerSuite.scala | 69 ++++++----- .../spark/storage/BlockManagerSuite.scala | 86 +++++++------- sbt/sbt | 2 +- 35 files changed, 343 insertions(+), 314 deletions(-) diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index ac02f3363a..c1f012b419 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -114,7 +114,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea var array = 
mapStatuses(shuffleId) if (array != null) { array.synchronized { - if (array(mapId) != null && array(mapId).address == bmAddress) { + if (array(mapId) != null && array(mapId).location == bmAddress) { array(mapId) = null } } @@ -277,7 +277,7 @@ private[spark] object MapOutputTracker { throw new FetchFailedException(null, shuffleId, -1, reduceId, new Exception("Missing an output location for shuffle " + shuffleId)) } else { - (status.address, decompressSize(status.compressedSizes(reduceId))) + (status.location, decompressSize(status.compressedSizes(reduceId))) } } } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 4581c0adcf..39721b47ae 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -80,6 +80,7 @@ class SparkContext( // Create the Spark execution environment (cache, map output tracker, etc) private[spark] val env = SparkEnv.createFromSystemProperties( + "", System.getProperty("spark.master.host"), System.getProperty("spark.master.port").toInt, true, @@ -97,7 +98,7 @@ class SparkContext( // Keeps track of all persisted RDDs private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]() - private[spark] val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) + private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup) // Add each JAR given through the constructor @@ -649,10 +650,9 @@ class SparkContext( /** Register a new RDD, returning its RDD ID */ private[spark] def newRddId(): Int = nextRddId.getAndIncrement() + /** Called by MetadataCleaner to clean up the persistentRdds map periodically */ private[spark] def cleanup(cleanupTime: Long) { - var sizeBefore = persistentRdds.size persistentRdds.clearOldValues(cleanupTime) - logInfo("idToStage " + sizeBefore + " --> " + persistentRdds.size) } } diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 2a7a8af83d..0c094edcf3 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -19,6 +19,7 @@ import spark.util.AkkaUtils * SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set. 
*/ class SparkEnv ( + val executorId: String, val actorSystem: ActorSystem, val serializer: Serializer, val closureSerializer: Serializer, @@ -58,11 +59,12 @@ object SparkEnv extends Logging { } def createFromSystemProperties( + executorId: String, hostname: String, port: Int, isMaster: Boolean, - isLocal: Boolean - ) : SparkEnv = { + isLocal: Boolean): SparkEnv = { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port) // Bit of a hack: If this is the master and our port was 0 (meaning bind to any free port), @@ -86,7 +88,7 @@ object SparkEnv extends Logging { val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt val blockManagerMaster = new BlockManagerMaster( actorSystem, isMaster, isLocal, masterIp, masterPort) - val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer) + val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer) val connectionManager = blockManager.connectionManager @@ -122,6 +124,7 @@ object SparkEnv extends Logging { } new SparkEnv( + executorId, actorSystem, serializer, closureSerializer, diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala index 4211d80596..8f51051e39 100644 --- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala @@ -9,6 +9,12 @@ import spark.{Logging, Utils} import scala.collection.mutable.ArrayBuffer +/** + * Testing class that creates a Spark standalone process in-cluster (that is, running the + * spark.deploy.master.Master and spark.deploy.worker.Workers in the same JVMs). Executors launched + * by the Workers still run in separate JVMs. This can be used to test distributed operation and + * fault recovery without spinning up a lot of processes. + */ private[spark] class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) extends Logging { @@ -35,16 +41,12 @@ class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) /* Start the Slaves */ for (slaveNum <- 1 to numSlaves) { - /* We can pretend to test distributed stuff by giving the slaves distinct hostnames. - All of 127/8 should be a loopback, we use 127.100.*.* in hopes that it is - sufficiently distinctive. */ - val slaveIpAddress = "127.100.0." + (slaveNum % 256) val (actorSystem, boundPort) = - AkkaUtils.createActorSystem("sparkWorker" + slaveNum, slaveIpAddress, 0) + AkkaUtils.createActorSystem("sparkWorker" + slaveNum, localIpAddress, 0) slaveActorSystems += actorSystem val actor = actorSystem.actorOf( - Props(new Worker(slaveIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)), - name = "Worker") + Props(new Worker(localIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)), + name = "Worker") slaveActors += actor } diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 2c2cd0231b..2e7e868579 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -97,10 +97,10 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor exec.worker.removeExecutor(exec) // Only retry certain number of times so we don't go into an infinite loop. 
- if (jobInfo.incrementRetryCount <= JobState.MAX_NUM_RETRY) { + if (jobInfo.incrementRetryCount < JobState.MAX_NUM_RETRY) { schedule() } else { - val e = new SparkException("Job %s wth ID %s failed %d times.".format( + val e = new SparkException("Job %s with ID %s failed %d times.".format( jobInfo.desc.name, jobInfo.id, jobInfo.retryCount)) logError(e.getMessage, e) throw e diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index 0d1fe2a6b4..af3acfecb6 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -67,7 +67,7 @@ private[spark] class ExecutorRunner( /** Replace variables such as {{SLAVEID}} and {{CORES}} in a command argument passed to us */ def substituteVariables(argument: String): String = argument match { - case "{{SLAVEID}}" => workerId + case "{{EXECUTOR_ID}}" => execId.toString case "{{HOSTNAME}}" => hostname case "{{CORES}}" => cores.toString case other => other diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 28d9d40d43..bd21ba719a 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -30,7 +30,7 @@ private[spark] class Executor extends Logging { initLogging() - def initialize(slaveHostname: String, properties: Seq[(String, String)]) { + def initialize(executorId: String, slaveHostname: String, properties: Seq[(String, String)]) { // Make sure the local hostname we report matches the cluster scheduler's name for this host Utils.setCustomHostname(slaveHostname) @@ -64,7 +64,7 @@ private[spark] class Executor extends Logging { ) // Initialize Spark environment (using system properties read above) - env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false) + env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false) SparkEnv.set(env) // Start worker thread pool diff --git a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala index eeab3959c6..1ef88075ad 100644 --- a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala @@ -29,9 +29,10 @@ private[spark] class MesosExecutorBackend(executor: Executor) executorInfo: ExecutorInfo, frameworkInfo: FrameworkInfo, slaveInfo: SlaveInfo) { + logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue) this.driver = driver val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) - executor.initialize(slaveInfo.getHostname, properties) + executor.initialize(executorInfo.getExecutorId.getValue, slaveInfo.getHostname, properties) } override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) { diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index a29bf974d2..435ee5743e 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -17,7 +17,7 @@ import spark.scheduler.cluster.RegisterSlave private[spark] class StandaloneExecutorBackend( executor: Executor, masterUrl: String, - slaveId: String, + executorId: String, hostname: String, cores: Int) extends Actor @@ -30,7 +30,7 @@ 
private[spark] class StandaloneExecutorBackend( try { logInfo("Connecting to master: " + masterUrl) master = context.actorFor(masterUrl) - master ! RegisterSlave(slaveId, hostname, cores) + master ! RegisterSlave(executorId, hostname, cores) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) context.watch(master) // Doesn't work with remote actors, but useful for testing } catch { @@ -43,7 +43,7 @@ private[spark] class StandaloneExecutorBackend( override def receive = { case RegisteredSlave(sparkProperties) => logInfo("Successfully registered with master") - executor.initialize(hostname, sparkProperties) + executor.initialize(executorId, hostname, sparkProperties) case RegisterSlaveFailed(message) => logError("Slave registration failed: " + message) @@ -55,24 +55,24 @@ private[spark] class StandaloneExecutorBackend( } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { - master ! StatusUpdate(slaveId, taskId, state, data) + master ! StatusUpdate(executorId, taskId, state, data) } } private[spark] object StandaloneExecutorBackend { - def run(masterUrl: String, slaveId: String, hostname: String, cores: Int) { + def run(masterUrl: String, executorId: String, hostname: String, cores: Int) { // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // before getting started with all our system properties, etc val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0) val actor = actorSystem.actorOf( - Props(new StandaloneExecutorBackend(new Executor, masterUrl, slaveId, hostname, cores)), + Props(new StandaloneExecutorBackend(new Executor, masterUrl, executorId, hostname, cores)), name = "Executor") actorSystem.awaitTermination() } def main(args: Array[String]) { if (args.length != 4) { - System.err.println("Usage: StandaloneExecutorBackend ") + System.err.println("Usage: StandaloneExecutorBackend ") System.exit(1) } run(args(0), args(1), args(2), args(3).toInt) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index f599eb00bd..bd541d4207 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -35,9 +35,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with eventQueue.put(CompletionEvent(task, reason, result, accumUpdates)) } - // Called by TaskScheduler when a host fails. - override def hostLost(host: String) { - eventQueue.put(HostLost(host)) + // Called by TaskScheduler when an executor fails. + override def executorLost(execId: String) { + eventQueue.put(ExecutorLost(execId)) } // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures. @@ -72,7 +72,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // For tracking failed nodes, we use the MapOutputTracker's generation number, which is // sent with every task. When we detect a node failing, we note the current generation number - // and failed host, increment it for new tasks, and use this to ignore stray ShuffleMapTask + // and failed executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask // results. // TODO: Garbage collect information about failure generations when we know there are no more // stray messages to detect. 
@@ -108,7 +108,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } def clearCacheLocs() { - cacheLocs.clear + cacheLocs.clear() } /** @@ -271,8 +271,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with submitStage(finalStage) } - case HostLost(host) => - handleHostLost(host) + case ExecutorLost(execId) => + handleExecutorLost(execId) case completion: CompletionEvent => handleTaskCompletion(completion) @@ -436,10 +436,10 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with case smt: ShuffleMapTask => val stage = idToStage(smt.stageId) val status = event.result.asInstanceOf[MapStatus] - val host = status.address.ip - logInfo("ShuffleMapTask finished with host " + host) - if (failedGeneration.contains(host) && smt.generation <= failedGeneration(host)) { - logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + host) + val execId = status.location.executorId + logDebug("ShuffleMapTask finished on " + execId) + if (failedGeneration.contains(execId) && smt.generation <= failedGeneration(execId)) { + logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId) } else { stage.addOutputLoc(smt.partition, status) } @@ -511,9 +511,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // Remember that a fetch failed now; this is used to resubmit the broken // stages later, after a small wait (to give other tasks the chance to fail) lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock - // TODO: mark the host as failed only if there were lots of fetch failures on it + // TODO: mark the executor as failed only if there were lots of fetch failures on it if (bmAddress != null) { - handleHostLost(bmAddress.ip, Some(task.generation)) + handleExecutorLost(bmAddress.executorId, Some(task.generation)) } case other => @@ -523,21 +523,21 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } /** - * Responds to a host being lost. This is called inside the event loop so it assumes that it can - * modify the scheduler's internal state. Use hostLost() to post a host lost event from outside. + * Responds to an executor being lost. This is called inside the event loop, so it assumes it can + * modify the scheduler's internal state. Use executorLost() to post a loss event from outside. * * Optionally the generation during which the failure was caught can be passed to avoid allowing * stray fetch failures from possibly retriggering the detection of a node as lost. 
*/ - def handleHostLost(host: String, maybeGeneration: Option[Long] = None) { + def handleExecutorLost(execId: String, maybeGeneration: Option[Long] = None) { val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration) - if (!failedGeneration.contains(host) || failedGeneration(host) < currentGeneration) { - failedGeneration(host) = currentGeneration - logInfo("Host lost: " + host + " (generation " + currentGeneration + ")") - env.blockManager.master.notifyADeadHost(host) + if (!failedGeneration.contains(execId) || failedGeneration(execId) < currentGeneration) { + failedGeneration(execId) = currentGeneration + logInfo("Executor lost: %s (generation %d)".format(execId, currentGeneration)) + env.blockManager.master.removeExecutor(execId) // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { - stage.removeOutputsOnHost(host) + stage.removeOutputsOnExecutor(execId) val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray mapOutputTracker.registerMapOutputs(shuffleId, locs, true) } @@ -546,7 +546,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } clearCacheLocs() } else { - logDebug("Additional host lost message for " + host + + logDebug("Additional executor lost message for " + execId + "(generation " + currentGeneration + ")") } } diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala index 3422a21d9d..b34fa78c07 100644 --- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala @@ -28,7 +28,7 @@ private[spark] case class CompletionEvent( accumUpdates: Map[Long, Any]) extends DAGSchedulerEvent -private[spark] case class HostLost(host: String) extends DAGSchedulerEvent +private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent diff --git a/core/src/main/scala/spark/scheduler/MapStatus.scala b/core/src/main/scala/spark/scheduler/MapStatus.scala index fae643f3a8..203abb917b 100644 --- a/core/src/main/scala/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/spark/scheduler/MapStatus.scala @@ -8,19 +8,19 @@ import java.io.{ObjectOutput, ObjectInput, Externalizable} * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks. * The map output sizes are compressed using MapOutputTracker.compressSize. 
*/ -private[spark] class MapStatus(var address: BlockManagerId, var compressedSizes: Array[Byte]) +private[spark] class MapStatus(var location: BlockManagerId, var compressedSizes: Array[Byte]) extends Externalizable { def this() = this(null, null) // For deserialization only def writeExternal(out: ObjectOutput) { - address.writeExternal(out) + location.writeExternal(out) out.writeInt(compressedSizes.length) out.write(compressedSizes) } def readExternal(in: ObjectInput) { - address = BlockManagerId(in) + location = BlockManagerId(in) compressedSizes = new Array[Byte](in.readInt()) in.readFully(compressedSizes) } diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala index 4846b66729..e9419728e3 100644 --- a/core/src/main/scala/spark/scheduler/Stage.scala +++ b/core/src/main/scala/spark/scheduler/Stage.scala @@ -51,18 +51,18 @@ private[spark] class Stage( def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) { val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.address == bmAddress) + val newList = prevList.filterNot(_.location == bmAddress) outputLocs(partition) = newList if (prevList != Nil && newList == Nil) { numAvailableOutputs -= 1 } } - def removeOutputsOnHost(host: String) { + def removeOutputsOnExecutor(execId: String) { var becameUnavailable = false for (partition <- 0 until numPartitions) { val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.address.ip == host) + val newList = prevList.filterNot(_.location.executorId == execId) outputLocs(partition) = newList if (prevList != Nil && newList == Nil) { becameUnavailable = true @@ -70,7 +70,8 @@ private[spark] class Stage( } } if (becameUnavailable) { - logInfo("%s is now unavailable on %s (%d/%d, %s)".format(this, host, numAvailableOutputs, numPartitions, isAvailable)) + logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( + this, execId, numAvailableOutputs, numPartitions, isAvailable)) } } @@ -82,7 +83,7 @@ private[spark] class Stage( def origin: String = rdd.origin - override def toString = "Stage " + id // + ": [RDD = " + rdd.id + ", isShuffle = " + isShuffleMap + "]" + override def toString = "Stage " + id override def hashCode(): Int = id } diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala index fa4de15d0d..9fcef86e46 100644 --- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala +++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala @@ -12,7 +12,7 @@ private[spark] trait TaskSchedulerListener { def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]): Unit // A node was lost from the cluster. - def hostLost(host: String): Unit + def executorLost(execId: String): Unit // The TaskScheduler wants to abort an entire task set. 
def taskSetFailed(taskSet: TaskSet, reason: String): Unit diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index a639b72795..0b4177805b 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -27,19 +27,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext) var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] val taskIdToTaskSetId = new HashMap[Long, String] - val taskIdToSlaveId = new HashMap[Long, String] + val taskIdToExecutorId = new HashMap[Long, String] val taskSetTaskIds = new HashMap[String, HashSet[Long]] // Incrementing Mesos task IDs val nextTaskId = new AtomicLong(0) - // Which hosts in the cluster are alive (contains hostnames) - val hostsAlive = new HashSet[String] + // Which executor IDs we have executors on + val activeExecutorIds = new HashSet[String] - // Which slave IDs we have executors on - val slaveIdsWithExecutors = new HashSet[String] + // The set of executors we have on each host; this is used to compute hostsAlive, which + // in turn is used to decide when we can attain data locality on a given host + val executorsByHost = new HashMap[String, HashSet[String]] - val slaveIdToHost = new HashMap[String, String] + val executorIdToHost = new HashMap[String, String] // JAR server, if any JARs were added by the user to the SparkContext var jarServer: HttpServer = null @@ -102,7 +103,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) activeTaskSets -= manager.taskSet.id activeTaskSetsQueue -= manager taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) - taskIdToSlaveId --= taskSetTaskIds(manager.taskSet.id) + taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id) taskSetTaskIds.remove(manager.taskSet.id) } } @@ -117,8 +118,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) SparkEnv.set(sc.env) // Mark each slave as alive and remember its hostname for (o <- offers) { - slaveIdToHost(o.slaveId) = o.hostname - hostsAlive += o.hostname + executorIdToHost(o.executorId) = o.hostname } // Build a list of tasks to assign to each slave val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) @@ -128,16 +128,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext) do { launchedTask = false for (i <- 0 until offers.size) { - val sid = offers(i).slaveId + val execId = offers(i).executorId val host = offers(i).hostname - manager.slaveOffer(sid, host, availableCpus(i)) match { + manager.slaveOffer(execId, host, availableCpus(i)) match { case Some(task) => tasks(i) += task val tid = task.taskId taskIdToTaskSetId(tid) = manager.taskSet.id taskSetTaskIds(manager.taskSet.id) += tid - taskIdToSlaveId(tid) = sid - slaveIdsWithExecutors += sid + taskIdToExecutorId(tid) = execId + activeExecutorIds += execId + if (!executorsByHost.contains(host)) { + executorsByHost(host) = new HashSet() + } + executorsByHost(host) += execId availableCpus(i) -= 1 launchedTask = true @@ -152,25 +156,21 @@ private[spark] class ClusterScheduler(val sc: SparkContext) def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { var taskSetToUpdate: Option[TaskSetManager] = None - var failedHost: Option[String] = None + var failedExecutor: Option[String] = None var taskFailed = false synchronized { try { - if (state == TaskState.LOST && taskIdToSlaveId.contains(tid)) { - // We lost the executor on this slave, so remember that it's 
gone - val slaveId = taskIdToSlaveId(tid) - val host = slaveIdToHost(slaveId) - if (hostsAlive.contains(host)) { - slaveIdsWithExecutors -= slaveId - hostsAlive -= host - activeTaskSetsQueue.foreach(_.hostLost(host)) - failedHost = Some(host) + if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { + // We lost this entire executor, so remember that it's gone + val execId = taskIdToExecutorId(tid) + if (activeExecutorIds.contains(execId)) { + removeExecutor(execId) + failedExecutor = Some(execId) } } taskIdToTaskSetId.get(tid) match { case Some(taskSetId) => if (activeTaskSets.contains(taskSetId)) { - //activeTaskSets(taskSetId).statusUpdate(status) taskSetToUpdate = Some(activeTaskSets(taskSetId)) } if (TaskState.isFinished(state)) { @@ -178,7 +178,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) if (taskSetTaskIds.contains(taskSetId)) { taskSetTaskIds(taskSetId) -= tid } - taskIdToSlaveId.remove(tid) + taskIdToExecutorId.remove(tid) } if (state == TaskState.FAILED) { taskFailed = true @@ -190,12 +190,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext) case e: Exception => logError("Exception in statusUpdate", e) } } - // Update the task set and DAGScheduler without holding a lock on this, because that can deadlock + // Update the task set and DAGScheduler without holding a lock on this, since that can deadlock if (taskSetToUpdate != None) { taskSetToUpdate.get.statusUpdate(tid, state, serializedData) } - if (failedHost != None) { - listener.hostLost(failedHost.get) + if (failedExecutor != None) { + listener.executorLost(failedExecutor.get) backend.reviveOffers() } if (taskFailed) { @@ -249,32 +249,42 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } - def slaveLost(slaveId: String, reason: ExecutorLossReason) { - var failedHost: Option[String] = None + def executorLost(executorId: String, reason: ExecutorLossReason) { + var failedExecutor: Option[String] = None synchronized { - slaveIdToHost.get(slaveId) match { - case Some(host) => - if (hostsAlive.contains(host)) { - logError("Lost an executor on " + host + ": " + reason) - slaveIdsWithExecutors -= slaveId - hostsAlive -= host - activeTaskSetsQueue.foreach(_.hostLost(host)) - failedHost = Some(host) - } else { - // We may get multiple slaveLost() calls with different loss reasons. For example, one - // may be triggered by a dropped connection from the slave while another may be a report - // of executor termination from Mesos. We produce log messages for both so we eventually - // report the termination reason. - logError("Lost an executor on " + host + " (already removed): " + reason) - } - case None => - // We were told about a slave being lost before we could even allocate work to it - logError("Lost slave " + slaveId + " (no work assigned yet)") + if (activeExecutorIds.contains(executorId)) { + val host = executorIdToHost(executorId) + logError("Lost executor %s on %s: %s".format(executorId, host, reason)) + removeExecutor(executorId) + failedExecutor = Some(executorId) + } else { + // We may get multiple executorLost() calls with different loss reasons. For example, one + // may be triggered by a dropped connection from the slave while another may be a report + // of executor termination from Mesos. We produce log messages for both so we eventually + // report the termination reason. 
+ logError("Lost an executor " + executorId + " (already removed): " + reason) } } - if (failedHost != None) { - listener.hostLost(failedHost.get) + // Call listener.executorLost without holding the lock on this to prevent deadlock + if (failedExecutor != None) { + listener.executorLost(failedExecutor.get) backend.reviveOffers() } } + + /** Get a list of hosts that currently have executors */ + def hostsAlive: scala.collection.Set[String] = executorsByHost.keySet + + /** Remove an executor from all our data structures and mark it as lost */ + private def removeExecutor(executorId: String) { + activeExecutorIds -= executorId + val host = executorIdToHost(executorId) + val execs = executorsByHost.getOrElse(host, new HashSet) + execs -= executorId + if (execs.isEmpty) { + executorsByHost -= host + } + executorIdToHost -= executorId + activeTaskSetsQueue.foreach(_.executorLost(executorId, host)) + } } diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 4f82cd96dd..f0792c1b76 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -37,7 +37,7 @@ private[spark] class SparkDeploySchedulerBackend( val masterUrl = "akka://spark@%s:%s/user/%s".format( System.getProperty("spark.master.host"), System.getProperty("spark.master.port"), StandaloneSchedulerBackend.ACTOR_NAME) - val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") + val args = Seq(masterUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone")) val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, sparkHome) @@ -81,7 +81,7 @@ private[spark] class SparkDeploySchedulerBackend( executorIdToSlaveId.get(id) match { case Some(slaveId) => executorIdToSlaveId.remove(id) - scheduler.slaveLost(slaveId, reason) + scheduler.executorLost(slaveId, reason) case None => logInfo("No slave ID known for executor %s".format(id)) } diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index eeaae23dc8..32be1e7a26 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -28,8 +28,8 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor val slaveAddress = new HashMap[String, Address] val slaveHost = new HashMap[String, String] val freeCores = new HashMap[String, Int] - val actorToSlaveId = new HashMap[ActorRef, String] - val addressToSlaveId = new HashMap[Address, String] + val actorToExecutorId = new HashMap[ActorRef, String] + val addressToExecutorId = new HashMap[Address, String] override def preStart() { // Listen for remote client disconnection events, since they don't go through Akka's watch() @@ -37,28 +37,28 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } def receive = { - case RegisterSlave(slaveId, host, cores) => - if (slaveActor.contains(slaveId)) { - sender ! 
RegisterSlaveFailed("Duplicate slave ID: " + slaveId) + case RegisterSlave(executorId, host, cores) => + if (slaveActor.contains(executorId)) { + sender ! RegisterSlaveFailed("Duplicate executor ID: " + executorId) } else { - logInfo("Registered slave: " + sender + " with ID " + slaveId) + logInfo("Registered executor: " + sender + " with ID " + executorId) sender ! RegisteredSlave(sparkProperties) context.watch(sender) - slaveActor(slaveId) = sender - slaveHost(slaveId) = host - freeCores(slaveId) = cores - slaveAddress(slaveId) = sender.path.address - actorToSlaveId(sender) = slaveId - addressToSlaveId(sender.path.address) = slaveId + slaveActor(executorId) = sender + slaveHost(executorId) = host + freeCores(executorId) = cores + slaveAddress(executorId) = sender.path.address + actorToExecutorId(sender) = executorId + addressToExecutorId(sender.path.address) = executorId totalCoreCount.addAndGet(cores) makeOffers() } - case StatusUpdate(slaveId, taskId, state, data) => + case StatusUpdate(executorId, taskId, state, data) => scheduler.statusUpdate(taskId, state, data.value) if (TaskState.isFinished(state)) { - freeCores(slaveId) += 1 - makeOffers(slaveId) + freeCores(executorId) += 1 + makeOffers(executorId) } case ReviveOffers => @@ -69,13 +69,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor context.stop(self) case Terminated(actor) => - actorToSlaveId.get(actor).foreach(removeSlave(_, "Akka actor terminated")) + actorToExecutorId.get(actor).foreach(removeSlave(_, "Akka actor terminated")) case RemoteClientDisconnected(transport, address) => - addressToSlaveId.get(address).foreach(removeSlave(_, "remote Akka client disconnected")) + addressToExecutorId.get(address).foreach(removeSlave(_, "remote Akka client disconnected")) case RemoteClientShutdown(transport, address) => - addressToSlaveId.get(address).foreach(removeSlave(_, "remote Akka client shutdown")) + addressToExecutorId.get(address).foreach(removeSlave(_, "remote Akka client shutdown")) } // Make fake resource offers on all slaves @@ -85,31 +85,31 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } // Make fake resource offers on just one slave - def makeOffers(slaveId: String) { + def makeOffers(executorId: String) { launchTasks(scheduler.resourceOffers( - Seq(new WorkerOffer(slaveId, slaveHost(slaveId), freeCores(slaveId))))) + Seq(new WorkerOffer(executorId, slaveHost(executorId), freeCores(executorId))))) } // Launch tasks returned by a set of resource offers def launchTasks(tasks: Seq[Seq[TaskDescription]]) { for (task <- tasks.flatten) { - freeCores(task.slaveId) -= 1 - slaveActor(task.slaveId) ! LaunchTask(task) + freeCores(task.executorId) -= 1 + slaveActor(task.executorId) ! 
LaunchTask(task) } } // Remove a disconnected slave from the cluster - def removeSlave(slaveId: String, reason: String) { - logInfo("Slave " + slaveId + " disconnected, so removing it") - val numCores = freeCores(slaveId) - actorToSlaveId -= slaveActor(slaveId) - addressToSlaveId -= slaveAddress(slaveId) - slaveActor -= slaveId - slaveHost -= slaveId - freeCores -= slaveId - slaveHost -= slaveId + def removeSlave(executorId: String, reason: String) { + logInfo("Slave " + executorId + " disconnected, so removing it") + val numCores = freeCores(executorId) + actorToExecutorId -= slaveActor(executorId) + addressToExecutorId -= slaveAddress(executorId) + slaveActor -= executorId + slaveHost -= executorId + freeCores -= executorId + slaveHost -= executorId totalCoreCount.addAndGet(-numCores) - scheduler.slaveLost(slaveId, SlaveLost(reason)) + scheduler.executorLost(executorId, SlaveLost(reason)) } } diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala index aa097fd3a2..b41e951be9 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala @@ -5,7 +5,7 @@ import spark.util.SerializableBuffer private[spark] class TaskDescription( val taskId: Long, - val slaveId: String, + val executorId: String, val name: String, _serializedTask: ByteBuffer) extends Serializable { diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala index ca84503780..0f975ce1eb 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala @@ -4,7 +4,12 @@ package spark.scheduler.cluster * Information about a running task attempt inside a TaskSet. */ private[spark] -class TaskInfo(val taskId: Long, val index: Int, val launchTime: Long, val host: String) { +class TaskInfo( + val taskId: Long, + val index: Int, + val launchTime: Long, + val executorId: String, + val host: String) { var finishTime: Long = 0 var failed = false diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index a089b71644..26201ad0dd 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -138,10 +138,11 @@ private[spark] class TaskSetManager( // attempt running on this host, in case the host is slow. In addition, if localOnly is set, the // task must have a preference for this host (or no preferred locations at all). 
def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = { + val hostsAlive = sched.hostsAlive speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set val localTask = speculatableTasks.find { index => - val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive + val locations = tasks(index).preferredLocations.toSet & hostsAlive val attemptLocs = taskAttempts(index).map(_.host) (locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host) } @@ -189,7 +190,7 @@ private[spark] class TaskSetManager( } // Respond to an offer of a single slave from the scheduler by finding a task - def slaveOffer(slaveId: String, host: String, availableCpus: Double): Option[TaskDescription] = { + def slaveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = { if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { val time = System.currentTimeMillis val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT) @@ -206,11 +207,11 @@ private[spark] class TaskSetManager( } else { "non-preferred, not one of " + task.preferredLocations.mkString(", ") } - logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( - taskSet.id, index, taskId, slaveId, host, prefStr)) + logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format( + taskSet.id, index, taskId, execId, host, prefStr)) // Do various bookkeeping copiesRunning(index) += 1 - val info = new TaskInfo(taskId, index, time, host) + val info = new TaskInfo(taskId, index, time, execId, host) taskInfos(taskId) = info taskAttempts(index) = info :: taskAttempts(index) if (preferred) { @@ -224,7 +225,7 @@ private[spark] class TaskSetManager( logInfo("Serialized task %s:%d as %d bytes in %d ms".format( taskSet.id, index, serializedTask.limit, timeTaken)) val taskName = "task %s:%d".format(taskSet.id, index) - return Some(new TaskDescription(taskId, slaveId, taskName, serializedTask)) + return Some(new TaskDescription(taskId, execId, taskName, serializedTask)) } case _ => } @@ -356,19 +357,22 @@ private[spark] class TaskSetManager( sched.taskSetFinished(this) } - def hostLost(hostname: String) { - logInfo("Re-queueing tasks for " + hostname + " from TaskSet " + taskSet.id) - // If some task has preferred locations only on hostname, put it in the no-prefs list - // to avoid the wait from delay scheduling - for (index <- getPendingTasksForHost(hostname)) { - val newLocs = tasks(index).preferredLocations.toSet & sched.hostsAlive - if (newLocs.isEmpty) { - pendingTasksWithNoPrefs += index + def executorLost(execId: String, hostname: String) { + logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) + val newHostsAlive = sched.hostsAlive + // If some task has preferred locations only on hostname, and there are no more executors there, + // put it in the no-prefs list to avoid the wait from delay scheduling + if (!newHostsAlive.contains(hostname)) { + for (index <- getPendingTasksForHost(hostname)) { + val newLocs = tasks(index).preferredLocations.toSet & newHostsAlive + if (newLocs.isEmpty) { + pendingTasksWithNoPrefs += index + } } } - // Re-enqueue any tasks that ran on the failed host if this is a shuffle map stage + // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage if (tasks(0).isInstanceOf[ShuffleMapTask]) { - for ((tid, info) <- taskInfos if info.host == hostname) { + for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index if 
(finished(index)) { finished(index) = false @@ -382,7 +386,7 @@ private[spark] class TaskSetManager( } } // Also re-enqueue any tasks that were running on the node - for ((tid, info) <- taskInfos if info.running && info.host == hostname) { + for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { taskLost(tid, TaskState.KILLED, null) } } diff --git a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala index 6b919d68b2..3c3afcbb14 100644 --- a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala +++ b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala @@ -1,8 +1,8 @@ package spark.scheduler.cluster /** - * Represents free resources available on a worker node. + * Represents free resources available on an executor. */ private[spark] -class WorkerOffer(val slaveId: String, val hostname: String, val cores: Int) { +class WorkerOffer(val executorId: String, val hostname: String, val cores: Int) { } diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala index 2989e31f5e..f3467db86b 100644 --- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala @@ -268,7 +268,7 @@ private[spark] class MesosSchedulerBackend( synchronized { slaveIdsWithExecutors -= slaveId.getValue } - scheduler.slaveLost(slaveId.getValue, reason) + scheduler.executorLost(slaveId.getValue, reason) } override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) { diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 19d35b8667..1215d5f5c8 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -30,6 +30,7 @@ extends Exception(message) private[spark] class BlockManager( + executorId: String, actorSystem: ActorSystem, val master: BlockManagerMaster, val serializer: Serializer, @@ -68,8 +69,8 @@ class BlockManager( val connectionManager = new ConnectionManager(0) implicit val futureExecContext = connectionManager.futureExecContext - val connectionManagerId = connectionManager.id - val blockManagerId = BlockManagerId(connectionManagerId.host, connectionManagerId.port) + val blockManagerId = BlockManagerId( + executorId, connectionManager.id.host, connectionManager.id.port) // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory // for receiving shuffle outputs) @@ -109,8 +110,9 @@ class BlockManager( /** * Construct a BlockManager with a memory limit set based on system properties. 
*/ - def this(actorSystem: ActorSystem, master: BlockManagerMaster, serializer: Serializer) = { - this(actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties) + def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster, + serializer: Serializer) = { + this(execId, actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties) } /** diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala index abb8b45a1f..f2f1e77d41 100644 --- a/core/src/main/scala/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -7,27 +7,32 @@ import java.util.concurrent.ConcurrentHashMap * This class represent an unique identifier for a BlockManager. * The first 2 constructors of this class is made private to ensure that * BlockManagerId objects can be created only using the factory method in - * [[spark.storage.BlockManager$]]. This allows de-duplication of id objects. + * [[spark.storage.BlockManager$]]. This allows de-duplication of ID objects. * Also, constructor parameters are private to ensure that parameters cannot * be modified from outside this class. */ private[spark] class BlockManagerId private ( + private var executorId_ : String, private var ip_ : String, private var port_ : Int ) extends Externalizable { - private def this() = this(null, 0) // For deserialization only + private def this() = this(null, null, 0) // For deserialization only - def ip = ip_ + def executorId: String = executorId_ - def port = port_ + def ip: String = ip_ + + def port: Int = port_ override def writeExternal(out: ObjectOutput) { + out.writeUTF(executorId_) out.writeUTF(ip_) out.writeInt(port_) } override def readExternal(in: ObjectInput) { + executorId_ = in.readUTF() ip_ = in.readUTF() port_ = in.readInt() } @@ -35,21 +40,23 @@ private[spark] class BlockManagerId private ( @throws(classOf[IOException]) private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) - override def toString = "BlockManagerId(" + ip + ", " + port + ")" + override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, ip, port) - override def hashCode = ip.hashCode * 41 + port + override def hashCode: Int = (executorId.hashCode * 41 + ip.hashCode) * 41 + port override def equals(that: Any) = that match { - case id: BlockManagerId => port == id.port && ip == id.ip - case _ => false + case id: BlockManagerId => + executorId == id.executorId && port == id.port && ip == id.ip + case _ => + false } } private[spark] object BlockManagerId { - def apply(ip: String, port: Int) = - getCachedBlockManagerId(new BlockManagerId(ip, port)) + def apply(execId: String, ip: String, port: Int) = + getCachedBlockManagerId(new BlockManagerId(execId, ip, port)) def apply(in: ObjectInput) = { val obj = new BlockManagerId() diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 937115e92c..55ff1dde9c 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -24,7 +24,7 @@ private[spark] class BlockManagerMaster( masterPort: Int) extends Logging { - val AKKA_RETRY_ATTEMPS: Int = System.getProperty("spark.akka.num.retries", "3").toInt + val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt 
val MASTER_AKKA_ACTOR_NAME = "BlockMasterManager" @@ -45,10 +45,10 @@ private[spark] class BlockManagerMaster( } } - /** Remove a dead host from the master actor. This is only called on the master side. */ - def notifyADeadHost(host: String) { - tell(RemoveHost(host)) - logInfo("Removed " + host + " successfully in notifyADeadHost") + /** Remove a dead executor from the master actor. This is only called on the master side. */ + def removeExecutor(execId: String) { + tell(RemoveExecutor(execId)) + logInfo("Removed " + execId + " successfully in removeExecutor") } /** @@ -146,7 +146,7 @@ private[spark] class BlockManagerMaster( } var attempts = 0 var lastException: Exception = null - while (attempts < AKKA_RETRY_ATTEMPS) { + while (attempts < AKKA_RETRY_ATTEMPTS) { attempts += 1 try { val future = masterActor.ask(message)(timeout) diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala index b31b6286d3..f88517f1a3 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -23,9 +23,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { private val blockManagerInfo = new HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo] - // Mapping from host name to block manager id. We allow multiple block managers - // on the same host name (ip). - private val blockManagerIdByHost = new HashMap[String, ArrayBuffer[BlockManagerId]] + // Mapping from executor ID to block manager ID. + private val blockManagerIdByExecutor = new HashMap[String, BlockManagerId] // Mapping from block id to the set of block managers that have the block. private val blockLocations = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]] @@ -74,8 +73,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { case RemoveBlock(blockId) => removeBlock(blockId) - case RemoveHost(host) => - removeHost(host) + case RemoveExecutor(execId) => + removeExecutor(execId) sender ! true case StopBlockManagerMaster => @@ -99,16 +98,12 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { def removeBlockManager(blockManagerId: BlockManagerId) { val info = blockManagerInfo(blockManagerId) - // Remove the block manager from blockManagerIdByHost. If the list of block - // managers belonging to the IP is empty, remove the entry from the hash map. - blockManagerIdByHost.get(blockManagerId.ip).foreach { managers: ArrayBuffer[BlockManagerId] => - managers -= blockManagerId - if (managers.size == 0) blockManagerIdByHost.remove(blockManagerId.ip) - } + // Remove the block manager from blockManagerIdByExecutor. + blockManagerIdByExecutor -= blockManagerId.executorId // Remove it from blockManagerInfo and remove all the blocks. 
blockManagerInfo.remove(blockManagerId) - var iterator = info.blocks.keySet.iterator + val iterator = info.blocks.keySet.iterator while (iterator.hasNext) { val blockId = iterator.next val locations = blockLocations.get(blockId)._2 @@ -133,17 +128,15 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { toRemove.foreach(removeBlockManager) } - def removeHost(host: String) { - logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.") - logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq) - blockManagerIdByHost.get(host).foreach(_.foreach(removeBlockManager)) - logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq) + def removeExecutor(execId: String) { + logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.") + blockManagerIdByExecutor.get(execId).foreach(removeBlockManager) sender ! true } def heartBeat(blockManagerId: BlockManagerId) { if (!blockManagerInfo.contains(blockManagerId)) { - if (blockManagerId.ip == Utils.localHostName() && !isLocal) { + if (blockManagerId.executorId == "" && !isLocal) { sender ! true } else { sender ! false @@ -188,24 +181,20 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { sender ! res } - private def register(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { - val startTimeMs = System.currentTimeMillis() - val tmp = " " + blockManagerId + " " - - if (blockManagerId.ip == Utils.localHostName() && !isLocal) { - logInfo("Got Register Msg from master node, don't register it") - } else if (!blockManagerInfo.contains(blockManagerId)) { - blockManagerIdByHost.get(blockManagerId.ip) match { - case Some(managers) => - // A block manager of the same host name already exists. - logInfo("Got another registration for host " + blockManagerId) - managers += blockManagerId + private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + if (id.executorId == "" && !isLocal) { + // Got a register message from the master node; don't register it + } else if (!blockManagerInfo.contains(id)) { + blockManagerIdByExecutor.get(id.executorId) match { + case Some(manager) => + // A block manager of the same host name already exists + logError("Got two different block manager registrations on " + id.executorId) + System.exit(1) case None => - blockManagerIdByHost += (blockManagerId.ip -> ArrayBuffer(blockManagerId)) + blockManagerIdByExecutor(id.executorId) = id } - - blockManagerInfo += (blockManagerId -> new BlockManagerMasterActor.BlockManagerInfo( - blockManagerId, System.currentTimeMillis(), maxMemSize, slaveActor)) + blockManagerInfo(id) = new BlockManagerMasterActor.BlockManagerInfo( + id, System.currentTimeMillis(), maxMemSize, slaveActor) } sender ! true } @@ -217,11 +206,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { memSize: Long, diskSize: Long) { - val startTimeMs = System.currentTimeMillis() - val tmp = " " + blockManagerId + " " + blockId + " " - if (!blockManagerInfo.contains(blockManagerId)) { - if (blockManagerId.ip == Utils.localHostName() && !isLocal) { + if (blockManagerId.executorId == "" && !isLocal) { // We intentionally do not register the master (except in local mode), // so we should not indicate failure. sender ! 
true @@ -353,8 +339,8 @@ object BlockManagerMasterActor { _lastSeenMs = System.currentTimeMillis() } - def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, diskSize: Long) - : Unit = synchronized { + def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, + diskSize: Long) { updateLastSeenMs() diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala index 3d03ff3a93..1494f90103 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala @@ -88,7 +88,7 @@ private[spark] case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster private[spark] -case class RemoveHost(host: String) extends ToBlockManagerMaster +case class RemoveExecutor(execId: String) extends ToBlockManagerMaster private[spark] case object StopBlockManagerMaster extends ToBlockManagerMaster diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index 1003cc7a61..b7423c7234 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -11,6 +11,7 @@ import cc.spray.typeconversion.TwirlSupport._ import scala.collection.mutable.ArrayBuffer import spark.{Logging, SparkContext, SparkEnv} import spark.util.AkkaUtils +import spark.Utils private[spark] @@ -20,10 +21,10 @@ object BlockManagerUI extends Logging { def start(actorSystem : ActorSystem, masterActor: ActorRef, sc: SparkContext) { val webUIDirectives = new BlockManagerUIDirectives(actorSystem, masterActor, sc) try { - logInfo("Starting BlockManager WebUI.") - val port = Option(System.getenv("BLOCKMANAGER_UI_PORT")).getOrElse("9080").toInt - AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, + val boundPort = AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", + Option(System.getenv("BLOCKMANAGER_UI_PORT")).getOrElse("9080").toInt, webUIDirectives.handler, "BlockManagerHTTPServer") + logInfo("Started BlockManager web UI at %s:%d".format(Utils.localHostName(), boundPort)) } catch { case e: Exception => logError("Failed to create BlockManager WebUI", e) diff --git a/core/src/main/scala/spark/storage/ThreadingTest.scala b/core/src/main/scala/spark/storage/ThreadingTest.scala index 689f07b969..f04c046c31 100644 --- a/core/src/main/scala/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/spark/storage/ThreadingTest.scala @@ -78,7 +78,8 @@ private[spark] object ThreadingTest { val masterIp: String = System.getProperty("spark.master.host", "localhost") val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, masterIp, masterPort) - val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer, 1024 * 1024) + val blockManager = new BlockManager( + "", actorSystem, blockManagerMaster, serializer, 1024 * 1024) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) producers.foreach(_.start) diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index ff2c3079be..775ff8f1aa 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -52,10 +52,10 @@ private[spark] object 
AkkaUtils { /** * Creates a Spray HTTP server bound to a given IP and port with a given Spray Route object to - * handle requests. Throws a SparkException if this fails. + * handle requests. Returns the bound port or throws a SparkException on failure. */ def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route, - name: String = "HttpServer") { + name: String = "HttpServer"): Int = { val ioWorker = new IoWorker(actorSystem).start() val httpService = actorSystem.actorOf(Props(new HttpService(route))) val rootService = actorSystem.actorOf(Props(new SprayCanRootService(httpService))) @@ -67,7 +67,7 @@ private[spark] object AkkaUtils { try { Await.result(future, timeout) match { case bound: HttpServer.Bound => - return + return bound.endpoint.getPort case other: Any => throw new SparkException("Failed to bind web UI to port " + port + ": " + other) } diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala index bb7c5c01c8..188f8910da 100644 --- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala @@ -63,9 +63,9 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging { override def empty: Map[A, B] = new TimeStampedHashMap[A, B]() - override def size(): Int = internalMap.size() + override def size: Int = internalMap.size - override def foreach[U](f: ((A, B)) => U): Unit = { + override def foreach[U](f: ((A, B)) => U) { val iterator = internalMap.entrySet().iterator() while(iterator.hasNext) { val entry = iterator.next() diff --git a/core/src/test/scala/spark/DriverSuite.scala b/core/src/test/scala/spark/DriverSuite.scala index 70a7c8bc2f..342610e1dd 100644 --- a/core/src/test/scala/spark/DriverSuite.scala +++ b/core/src/test/scala/spark/DriverSuite.scala @@ -13,7 +13,8 @@ class DriverSuite extends FunSuite with Timeouts { val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) forAll(masters) { (master: String) => failAfter(10 seconds) { - Utils.execute(Seq("./run", "spark.DriverWithoutCleanup", master), new File(System.getenv("SPARK_HOME"))) + Utils.execute(Seq("./run", "spark.DriverWithoutCleanup", master), + new File(System.getenv("SPARK_HOME"))) } } } @@ -28,4 +29,4 @@ object DriverWithoutCleanup { val sc = new SparkContext(args(0), "DriverWithoutCleanup") sc.parallelize(1 to 100, 4).count() } -} \ No newline at end of file +} diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index 7d5305f1e0..e8fe7ecabc 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -43,13 +43,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("hostA", 1000), + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000), Array(compressedSize1000, compressedSize10000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("hostB", 1000), + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000))) val statuses = tracker.getServerStatuses(10, 0) - assert(statuses.toSeq === 
Seq((BlockManagerId("hostA", 1000), size1000), - (BlockManagerId("hostB", 1000), size10000))) + assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000), + (BlockManagerId("b", "hostB", 1000), size10000))) tracker.stop() } @@ -61,47 +61,52 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("hostA", 1000), + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000), Array(compressedSize1000, compressedSize1000, compressedSize1000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("hostB", 1000), + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000, compressedSize1000))) // As if we had two simulatenous fetch failures - tracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000)) - tracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000)) + tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) + tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) - // The remaining reduce task might try to grab the output dispite the shuffle failure; + // The remaining reduce task might try to grab the output despite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the // stage already being aborted. intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } } test("remote fetch") { - val (actorSystem, boundPort) = - AkkaUtils.createActorSystem("test", "localhost", 0) - System.setProperty("spark.master.port", boundPort.toString) - val masterTracker = new MapOutputTracker(actorSystem, true) - val slaveTracker = new MapOutputTracker(actorSystem, false) - masterTracker.registerShuffle(10, 1) - masterTracker.incrementGeneration() - slaveTracker.updateGeneration(masterTracker.getGeneration) - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + try { + System.clearProperty("spark.master.host") // In case some previous test had set it + val (actorSystem, boundPort) = + AkkaUtils.createActorSystem("test", "localhost", 0) + System.setProperty("spark.master.port", boundPort.toString) + val masterTracker = new MapOutputTracker(actorSystem, true) + val slaveTracker = new MapOutputTracker(actorSystem, false) + masterTracker.registerShuffle(10, 1) + masterTracker.incrementGeneration() + slaveTracker.updateGeneration(masterTracker.getGeneration) + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val size1000 = MapOutputTracker.decompressSize(compressedSize1000) - masterTracker.registerMapOutput(10, 0, new MapStatus( - BlockManagerId("hostA", 1000), Array(compressedSize1000))) - masterTracker.incrementGeneration() - slaveTracker.updateGeneration(masterTracker.getGeneration) - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("hostA", 1000), size1000))) + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + masterTracker.registerMapOutput(10, 0, new MapStatus( + BlockManagerId("a", "hostA", 1000), Array(compressedSize1000))) + masterTracker.incrementGeneration() + 
slaveTracker.updateGeneration(masterTracker.getGeneration) + assert(slaveTracker.getServerStatuses(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000), size1000))) - masterTracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000)) - masterTracker.incrementGeneration() - slaveTracker.updateGeneration(masterTracker.getGeneration) - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) + masterTracker.incrementGeneration() + slaveTracker.updateGeneration(masterTracker.getGeneration) + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - // failure should be cached - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + // failure should be cached + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + } finally { + System.clearProperty("spark.master.port") + } } } diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index 2165744689..2d177bbf67 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -86,9 +86,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("BlockManagerId object caching") { - val id1 = BlockManagerId("XXX", 1) - val id2 = BlockManagerId("XXX", 1) // this should return the same object as id1 - val id3 = BlockManagerId("XXX", 2) // this should return a different object + val id1 = BlockManagerId("e1", "XXX", 1) + val id2 = BlockManagerId("e1", "XXX", 1) // this should return the same object as id1 + val id3 = BlockManagerId("e1", "XXX", 2) // this should return a different object assert(id2 === id1, "id2 is not same as id1") assert(id2.eq(id1), "id2 is not the same object as id1") assert(id3 != id1, "id3 is same as id1") @@ -103,7 +103,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("master + 1 manager interaction") { - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("", actorSystem, master, serializer, 2000) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -133,8 +133,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("master + 2 managers interaction") { - store = new BlockManager(actorSystem, master, serializer, 2000) - store2 = new BlockManager(actorSystem, master, new KryoSerializer, 2000) + store = new BlockManager("exec1", actorSystem, master, serializer, 2000) + store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer, 2000) val peers = master.getPeers(store.blockManagerId, 1) assert(peers.size === 1, "master did not return the other manager as a peer") @@ -149,7 +149,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("removing block") { - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("", actorSystem, master, serializer, 2000) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -198,7 +198,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("reregistration on heart beat") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("", actorSystem, master, 
serializer, 2000) val a1 = new Array[Byte](400) store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) @@ -206,7 +206,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(store.getSingle("a1") != None, "a1 was not in store") assert(master.getLocations("a1").size > 0, "master was not told about a1") - master.notifyADeadHost(store.blockManagerId.ip) + master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") store invokePrivate heartBeat() @@ -214,14 +214,14 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("reregistration on block update") { - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("", actorSystem, master, serializer, 2000) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) assert(master.getLocations("a1").size > 0, "master was not told about a1") - master.notifyADeadHost(store.blockManagerId.ip) + master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") store.putSingle("a2", a1, StorageLevel.MEMORY_ONLY) @@ -233,35 +233,35 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("reregistration doesn't dead lock") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("", actorSystem, master, serializer, 2000) val a1 = new Array[Byte](400) val a2 = List(new Array[Byte](400)) // try many times to trigger any deadlocks for (i <- 1 to 100) { - master.notifyADeadHost(store.blockManagerId.ip) + master.removeExecutor(store.blockManagerId.executorId) val t1 = new Thread { - override def run = { + override def run() { store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, true) } } val t2 = new Thread { - override def run = { + override def run() { store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) } } val t3 = new Thread { - override def run = { + override def run() { store invokePrivate heartBeat() } } - t1.start - t2.start - t3.start - t1.join - t2.join - t3.join + t1.start() + t2.start() + t3.start() + t1.join() + t2.join() + t3.join() store.dropFromMemory("a1", null) store.dropFromMemory("a2", null) @@ -270,7 +270,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU storage") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -289,7 +289,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU storage with serialization") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -308,14 +308,14 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU for partitions of same RDD") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) store.putSingle("rdd_0_1", a1, 
StorageLevel.MEMORY_ONLY) store.putSingle("rdd_0_2", a2, StorageLevel.MEMORY_ONLY) store.putSingle("rdd_0_3", a3, StorageLevel.MEMORY_ONLY) - // Even though we accessed rdd_0_3 last, it should not have replaced partitiosn 1 and 2 + // Even though we accessed rdd_0_3 last, it should not have replaced partitions 1 and 2 // from the same RDD assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store") @@ -327,7 +327,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU for partitions of multiple RDDs") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) store.putSingle("rdd_0_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle("rdd_0_2", new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle("rdd_1_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY) @@ -350,7 +350,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("on-disk storage") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -363,7 +363,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -378,7 +378,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with getLocalBytes") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -393,7 +393,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with serialization") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -408,7 +408,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with serialization and getLocalBytes") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -423,7 +423,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("LRU with mixed storage levels") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -448,7 +448,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU with streams") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val list1 = List(new 
Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -472,7 +472,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("LRU with mixed storage levels and streams") { - store = new BlockManager(actorSystem, master, serializer, 1200) + store = new BlockManager("", actorSystem, master, serializer, 1200) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -518,7 +518,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("overly large block") { - store = new BlockManager(actorSystem, master, serializer, 500) + store = new BlockManager("", actorSystem, master, serializer, 500) store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.getSingle("a1") === None, "a1 was in store") store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK) @@ -529,49 +529,49 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("block compression") { try { System.setProperty("spark.shuffle.compress", "true") - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("exec1", actorSystem, master, serializer, 2000) store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize("shuffle_0_0_0") <= 100, "shuffle_0_0_0 was not compressed") store.stop() store = null System.setProperty("spark.shuffle.compress", "false") - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("exec2", actorSystem, master, serializer, 2000) store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize("shuffle_0_0_0") >= 1000, "shuffle_0_0_0 was compressed") store.stop() store = null System.setProperty("spark.broadcast.compress", "true") - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("exec3", actorSystem, master, serializer, 2000) store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize("broadcast_0") <= 100, "broadcast_0 was not compressed") store.stop() store = null System.setProperty("spark.broadcast.compress", "false") - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("exec4", actorSystem, master, serializer, 2000) store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize("broadcast_0") >= 1000, "broadcast_0 was compressed") store.stop() store = null System.setProperty("spark.rdd.compress", "true") - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("exec5", actorSystem, master, serializer, 2000) store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize("rdd_0_0") <= 100, "rdd_0_0 was not compressed") store.stop() store = null System.setProperty("spark.rdd.compress", "false") - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("exec6", actorSystem, master, serializer, 2000) store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize("rdd_0_0") >= 1000, "rdd_0_0 was 
compressed") store.stop() store = null // Check that any other block types are also kept uncompressed - store = new BlockManager(actorSystem, master, serializer, 2000) + store = new BlockManager("exec7", actorSystem, master, serializer, 2000) store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed") store.stop() diff --git a/sbt/sbt b/sbt/sbt index a3055c13c1..8f426d18e8 100755 --- a/sbt/sbt +++ b/sbt/sbt @@ -5,4 +5,4 @@ if [ "$MESOS_HOME" != "" ]; then fi export SPARK_HOME=$(cd "$(dirname $0)/.."; pwd) export SPARK_TESTING=1 # To put test classes on classpath -java -Xmx1200M -XX:MaxPermSize=200m $EXTRA_ARGS -jar $SPARK_HOME/sbt/sbt-launch-*.jar "$@" +java -Xmx1200M -XX:MaxPermSize=250m $EXTRA_ARGS -jar $SPARK_HOME/sbt/sbt-launch-*.jar "$@" From 909850729ec59b788645575fdc03df7cc51fe42b Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 27 Jan 2013 23:17:20 -0800 Subject: [PATCH 211/291] Rename more things from slave to executor --- .../spark/deploy/worker/ExecutorRunner.scala | 2 +- .../executor/StandaloneExecutorBackend.scala | 12 ++--- .../scheduler/cluster/SlaveResources.scala | 4 -- .../cluster/SparkDeploySchedulerBackend.scala | 16 ++----- .../cluster/StandaloneClusterMessage.scala | 16 ++++--- .../cluster/StandaloneSchedulerBackend.scala | 48 +++++++++---------- .../scala/spark/storage/BlockManagerUI.scala | 2 + .../scala/spark/util/MetadataCleaner.scala | 10 ++-- 8 files changed, 50 insertions(+), 60 deletions(-) delete mode 100644 core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index af3acfecb6..f5ff267d44 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -65,7 +65,7 @@ private[spark] class ExecutorRunner( } } - /** Replace variables such as {{SLAVEID}} and {{CORES}} in a command argument passed to us */ + /** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */ def substituteVariables(argument: String): String = argument match { case "{{EXECUTOR_ID}}" => execId.toString case "{{HOSTNAME}}" => hostname diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index 435ee5743e..50871802ea 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -8,10 +8,10 @@ import akka.actor.{ActorRef, Actor, Props} import java.util.concurrent.{TimeUnit, ThreadPoolExecutor, SynchronousQueue} import akka.remote.RemoteClientLifeCycleEvent import spark.scheduler.cluster._ -import spark.scheduler.cluster.RegisteredSlave +import spark.scheduler.cluster.RegisteredExecutor import spark.scheduler.cluster.LaunchTask -import spark.scheduler.cluster.RegisterSlaveFailed -import spark.scheduler.cluster.RegisterSlave +import spark.scheduler.cluster.RegisterExecutorFailed +import spark.scheduler.cluster.RegisterExecutor private[spark] class StandaloneExecutorBackend( @@ -30,7 +30,7 @@ private[spark] class StandaloneExecutorBackend( try { logInfo("Connecting to master: " + masterUrl) master = context.actorFor(masterUrl) - master ! RegisterSlave(executorId, hostname, cores) + master ! 
RegisterExecutor(executorId, hostname, cores) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) context.watch(master) // Doesn't work with remote actors, but useful for testing } catch { @@ -41,11 +41,11 @@ private[spark] class StandaloneExecutorBackend( } override def receive = { - case RegisteredSlave(sparkProperties) => + case RegisteredExecutor(sparkProperties) => logInfo("Successfully registered with master") executor.initialize(executorId, hostname, sparkProperties) - case RegisterSlaveFailed(message) => + case RegisterExecutorFailed(message) => logError("Slave registration failed: " + message) System.exit(1) diff --git a/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala b/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala deleted file mode 100644 index 96ebaa4601..0000000000 --- a/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala +++ /dev/null @@ -1,4 +0,0 @@ -package spark.scheduler.cluster - -private[spark] -class SlaveResources(val slaveId: String, val hostname: String, val coresFree: Int) {} diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index f0792c1b76..6dd3ae003d 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -19,7 +19,6 @@ private[spark] class SparkDeploySchedulerBackend( var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt - val executorIdToSlaveId = new HashMap[String, String] // Memory used by each executor (in megabytes) val executorMemory = { @@ -47,7 +46,7 @@ private[spark] class SparkDeploySchedulerBackend( } override def stop() { - stopping = true; + stopping = true super.stop() client.stop() if (shutdownCallback != null) { @@ -67,23 +66,16 @@ private[spark] class SparkDeploySchedulerBackend( } def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) { - executorIdToSlaveId += id -> workerId logInfo("Granted executor ID %s on host %s with %d cores, %s RAM".format( id, host, cores, Utils.memoryMegabytesToString(memory))) } - def executorRemoved(id: String, message: String, exitStatus: Option[Int]) { + def executorRemoved(executorId: String, message: String, exitStatus: Option[Int]) { val reason: ExecutorLossReason = exitStatus match { case Some(code) => ExecutorExited(code) case None => SlaveLost(message) } - logInfo("Executor %s removed: %s".format(id, message)) - executorIdToSlaveId.get(id) match { - case Some(slaveId) => - executorIdToSlaveId.remove(id) - scheduler.executorLost(slaveId, reason) - case None => - logInfo("No slave ID known for executor %s".format(id)) - } + logInfo("Executor %s removed: %s".format(executorId, message)) + scheduler.executorLost(executorId, reason) } } diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala index 1386cd9d44..c68f15bdfa 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala @@ -11,24 +11,26 @@ private[spark] case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage private[spark] -case class RegisteredSlave(sparkProperties: Seq[(String, 
String)]) extends StandaloneClusterMessage +case class RegisteredExecutor(sparkProperties: Seq[(String, String)]) + extends StandaloneClusterMessage private[spark] -case class RegisterSlaveFailed(message: String) extends StandaloneClusterMessage +case class RegisterExecutorFailed(message: String) extends StandaloneClusterMessage -// Slaves to master +// Executors to master private[spark] -case class RegisterSlave(slaveId: String, host: String, cores: Int) extends StandaloneClusterMessage +case class RegisterExecutor(executorId: String, host: String, cores: Int) + extends StandaloneClusterMessage private[spark] -case class StatusUpdate(slaveId: String, taskId: Long, state: TaskState, data: SerializableBuffer) +case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, data: SerializableBuffer) extends StandaloneClusterMessage private[spark] object StatusUpdate { /** Alternate factory method that takes a ByteBuffer directly for the data field */ - def apply(slaveId: String, taskId: Long, state: TaskState, data: ByteBuffer): StatusUpdate = { - StatusUpdate(slaveId, taskId, state, new SerializableBuffer(data)) + def apply(executorId: String, taskId: Long, state: TaskState, data: ByteBuffer): StatusUpdate = { + StatusUpdate(executorId, taskId, state, new SerializableBuffer(data)) } } diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 32be1e7a26..69822f568c 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -24,9 +24,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor var totalCoreCount = new AtomicInteger(0) class MasterActor(sparkProperties: Seq[(String, String)]) extends Actor { - val slaveActor = new HashMap[String, ActorRef] - val slaveAddress = new HashMap[String, Address] - val slaveHost = new HashMap[String, String] + val executorActor = new HashMap[String, ActorRef] + val executorAddress = new HashMap[String, Address] + val executorHost = new HashMap[String, String] val freeCores = new HashMap[String, Int] val actorToExecutorId = new HashMap[ActorRef, String] val addressToExecutorId = new HashMap[Address, String] @@ -37,17 +37,17 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } def receive = { - case RegisterSlave(executorId, host, cores) => - if (slaveActor.contains(executorId)) { - sender ! RegisterSlaveFailed("Duplicate executor ID: " + executorId) + case RegisterExecutor(executorId, host, cores) => + if (executorActor.contains(executorId)) { + sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) } else { logInfo("Registered executor: " + sender + " with ID " + executorId) - sender ! RegisteredSlave(sparkProperties) + sender ! 
RegisteredExecutor(sparkProperties) context.watch(sender) - slaveActor(executorId) = sender - slaveHost(executorId) = host + executorActor(executorId) = sender + executorHost(executorId) = host freeCores(executorId) = cores - slaveAddress(executorId) = sender.path.address + executorAddress(executorId) = sender.path.address actorToExecutorId(sender) = executorId addressToExecutorId(sender.path.address) = executorId totalCoreCount.addAndGet(cores) @@ -69,45 +69,45 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor context.stop(self) case Terminated(actor) => - actorToExecutorId.get(actor).foreach(removeSlave(_, "Akka actor terminated")) + actorToExecutorId.get(actor).foreach(removeExecutor(_, "Akka actor terminated")) case RemoteClientDisconnected(transport, address) => - addressToExecutorId.get(address).foreach(removeSlave(_, "remote Akka client disconnected")) + addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disconnected")) case RemoteClientShutdown(transport, address) => - addressToExecutorId.get(address).foreach(removeSlave(_, "remote Akka client shutdown")) + addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client shutdown")) } - // Make fake resource offers on all slaves + // Make fake resource offers on all executors def makeOffers() { launchTasks(scheduler.resourceOffers( - slaveHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))})) + executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))})) } - // Make fake resource offers on just one slave + // Make fake resource offers on just one executor def makeOffers(executorId: String) { launchTasks(scheduler.resourceOffers( - Seq(new WorkerOffer(executorId, slaveHost(executorId), freeCores(executorId))))) + Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId))))) } // Launch tasks returned by a set of resource offers def launchTasks(tasks: Seq[Seq[TaskDescription]]) { for (task <- tasks.flatten) { freeCores(task.executorId) -= 1 - slaveActor(task.executorId) ! LaunchTask(task) + executorActor(task.executorId) ! LaunchTask(task) } } // Remove a disconnected slave from the cluster - def removeSlave(executorId: String, reason: String) { + def removeExecutor(executorId: String, reason: String) { logInfo("Slave " + executorId + " disconnected, so removing it") val numCores = freeCores(executorId) - actorToExecutorId -= slaveActor(executorId) - addressToExecutorId -= slaveAddress(executorId) - slaveActor -= executorId - slaveHost -= executorId + actorToExecutorId -= executorActor(executorId) + addressToExecutorId -= executorAddress(executorId) + executorActor -= executorId + executorHost -= executorId freeCores -= executorId - slaveHost -= executorId + executorHost -= executorId totalCoreCount.addAndGet(-numCores) scheduler.executorLost(executorId, SlaveLost(reason)) } diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index b7423c7234..956ede201e 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -21,6 +21,8 @@ object BlockManagerUI extends Logging { def start(actorSystem : ActorSystem, masterActor: ActorRef, sc: SparkContext) { val webUIDirectives = new BlockManagerUIDirectives(actorSystem, masterActor, sc) try { + // TODO: This needs to find a random free port to bind to. 
Unfortunately, there's no way + // in spray to do that, so we'll have to rely on something like new ServerSocket() val boundPort = AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", Option(System.getenv("BLOCKMANAGER_UI_PORT")).getOrElse("9080").toInt, webUIDirectives.handler, "BlockManagerHTTPServer") diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala index 139e21d09e..721c4c6029 100644 --- a/core/src/main/scala/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -14,18 +14,16 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging val task = new TimerTask { def run() { try { - if (delaySeconds > 0) { - cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) - logInfo("Ran metadata cleaner for " + name) - } + cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) + logInfo("Ran metadata cleaner for " + name) } catch { case e: Exception => logError("Error running cleanup task for " + name, e) } } } - if (periodSeconds > 0) { - logInfo( + if (delaySeconds > 0) { + logDebug( "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds and " + "period of " + periodSeconds + " secs") timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) From f03d9760fd8ac67fd0865cb355ba75d2eff507fe Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 27 Jan 2013 23:56:14 -0800 Subject: [PATCH 212/291] Clean up BlockManagerUI a little (make it not be an object, merge with Directives, and bind to a random port) --- core/src/main/scala/spark/SparkContext.scala | 7 +- core/src/main/scala/spark/Utils.scala | 17 ++- .../spark/deploy/master/MasterWebUI.scala | 6 +- .../spark/deploy/worker/WorkerWebUI.scala | 6 +- .../scala/spark/storage/BlockManagerUI.scala | 120 +++++++++--------- .../src/main/scala/spark/util/AkkaUtils.scala | 6 +- .../scala/spark/util/MetadataCleaner.scala | 3 + 7 files changed, 91 insertions(+), 74 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 39721b47ae..77036c1275 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -44,6 +44,7 @@ import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler} import spark.scheduler.local.LocalScheduler import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} +import storage.BlockManagerUI import util.{MetadataCleaner, TimeStampedHashMap} /** @@ -88,8 +89,9 @@ class SparkContext( SparkEnv.set(env) // Start the BlockManager UI - spark.storage.BlockManagerUI.start(SparkEnv.get.actorSystem, - SparkEnv.get.blockManager.master.masterActor, this) + private[spark] val ui = new BlockManagerUI( + env.actorSystem, env.blockManager.master.masterActor, this) + ui.start() // Used to store a URL for each static file/jar together with the file's local timestamp private[spark] val addedFiles = HashMap[String, Long]() @@ -97,7 +99,6 @@ class SparkContext( // Keeps track of all persisted RDDs private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]() - private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup) diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index ae77264372..1e58d01273 100644 --- a/core/src/main/scala/spark/Utils.scala 
+++ b/core/src/main/scala/spark/Utils.scala @@ -1,7 +1,7 @@ package spark import java.io._ -import java.net.{NetworkInterface, InetAddress, Inet4Address, URL, URI} +import java.net._ import java.util.{Locale, Random, UUID} import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} import org.apache.hadoop.conf.Configuration @@ -11,6 +11,7 @@ import scala.collection.JavaConversions._ import scala.io.Source import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder +import scala.Some /** * Various utility methods used by Spark. @@ -431,4 +432,18 @@ private object Utils extends Logging { } "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine) } + + /** + * Try to find a free port to bind to on the local host. This should ideally never be needed, + * except that, unfortunately, some of the networking libraries we currently rely on (e.g. Spray) + * don't let users bind to port 0 and then figure out which free port they actually bound to. + * We work around this by binding a ServerSocket and immediately unbinding it. This is *not* + * necessarily guaranteed to work, but it's the best we can do. + */ + def findFreePort(): Int = { + val socket = new ServerSocket(0) + val portBound = socket.getLocalPort + socket.close() + portBound + } } diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index 458ee2d665..a01774f511 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -14,12 +14,15 @@ import cc.spray.typeconversion.SprayJsonSupport._ import spark.deploy._ import spark.deploy.JsonProtocol._ +/** + * Web UI server for the standalone master. + */ private[spark] class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives { val RESOURCE_DIR = "spark/deploy/master/webui" val STATIC_RESOURCE_DIR = "spark/deploy/static" - implicit val timeout = Timeout(1 seconds) + implicit val timeout = Timeout(10 seconds) val handler = { get { @@ -76,5 +79,4 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct getFromResourceDirectory(RESOURCE_DIR) } } - } diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala index f9489d99fc..ef81f072a3 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala @@ -13,12 +13,15 @@ import cc.spray.typeconversion.SprayJsonSupport._ import spark.deploy.{WorkerState, RequestWorkerState} import spark.deploy.JsonProtocol._ +/** + * Web UI server for the standalone worker. 
+ */ private[spark] class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives { val RESOURCE_DIR = "spark/deploy/worker/webui" val STATIC_RESOURCE_DIR = "spark/deploy/static" - implicit val timeout = Timeout(1 seconds) + implicit val timeout = Timeout(10 seconds) val handler = { get { @@ -50,5 +53,4 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct getFromResourceDirectory(RESOURCE_DIR) } } - } diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index 956ede201e..eda320fa47 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -1,32 +1,41 @@ package spark.storage import akka.actor.{ActorRef, ActorSystem} -import akka.dispatch.Await import akka.pattern.ask import akka.util.Timeout import akka.util.duration._ -import cc.spray.Directives import cc.spray.directives._ import cc.spray.typeconversion.TwirlSupport._ +import cc.spray.Directives import scala.collection.mutable.ArrayBuffer -import spark.{Logging, SparkContext, SparkEnv} +import spark.{Logging, SparkContext} import spark.util.AkkaUtils import spark.Utils +/** + * Web UI server for the BlockManager inside each SparkContext. + */ private[spark] -object BlockManagerUI extends Logging { +class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, sc: SparkContext) + extends Directives with Logging { - /* Starts the Web interface for the BlockManager */ - def start(actorSystem : ActorSystem, masterActor: ActorRef, sc: SparkContext) { - val webUIDirectives = new BlockManagerUIDirectives(actorSystem, masterActor, sc) + val STATIC_RESOURCE_DIR = "spark/deploy/static" + + implicit val timeout = Timeout(10 seconds) + + /** Start a HTTP server to run the Web interface */ + def start() { try { - // TODO: This needs to find a random free port to bind to. Unfortunately, there's no way - // in spray to do that, so we'll have to rely on something like new ServerSocket() - val boundPort = AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", - Option(System.getenv("BLOCKMANAGER_UI_PORT")).getOrElse("9080").toInt, - webUIDirectives.handler, "BlockManagerHTTPServer") - logInfo("Started BlockManager web UI at %s:%d".format(Utils.localHostName(), boundPort)) + val port = if (System.getProperty("spark.ui.port") != null) { + System.getProperty("spark.ui.port").toInt + } else { + // TODO: Unfortunately, it's not possible to pass port 0 to spray and figure out which + // random port it bound to, so we have to try to find a local one by creating a socket. + Utils.findFreePort() + } + AkkaUtils.startSprayServer(actorSystem, "0.0.0.0", port, handler, "BlockManagerHTTPServer") + logInfo("Started BlockManager web UI at http://%s:%d".format(Utils.localHostName(), port)) } catch { case e: Exception => logError("Failed to create BlockManager WebUI", e) @@ -34,58 +43,43 @@ object BlockManagerUI extends Logging { } } -} - - -private[spark] -class BlockManagerUIDirectives(val actorSystem: ActorSystem, master: ActorRef, - sc: SparkContext) extends Directives { - - val STATIC_RESOURCE_DIR = "spark/deploy/static" - implicit val timeout = Timeout(1 seconds) - val handler = { - - get { path("") { completeWith { - // Request the current storage status from the Master - val future = master ? 
GetStorageStatus - future.map { status => - val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray - - // Calculate macro-level statistics - val maxMem = storageStatusList.map(_.maxMem).reduce(_+_) - val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_) - val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)) - .reduceOption(_+_).getOrElse(0L) - - val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc) - - spark.storage.html.index. - render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList) + get { + path("") { + completeWith { + // Request the current storage status from the Master + val future = blockManagerMaster ? GetStorageStatus + future.map { status => + // Calculate macro-level statistics + val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray + val maxMem = storageStatusList.map(_.maxMem).reduce(_+_) + val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_) + val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)) + .reduceOption(_+_).getOrElse(0L) + val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc) + spark.storage.html.index. + render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList) + } + } + } ~ + path("rdd") { + parameter("id") { id => + completeWith { + val future = blockManagerMaster ? GetStorageStatus + future.map { status => + val prefix = "rdd_" + id.toString + val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray + val filteredStorageStatusList = StorageUtils. + filterStorageStatusByPrefix(storageStatusList, prefix) + val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head + spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList) + } + } + } + } ~ + pathPrefix("static") { + getFromResourceDirectory(STATIC_RESOURCE_DIR) } - }}} ~ - get { path("rdd") { parameter("id") { id => { completeWith { - val future = master ? GetStorageStatus - future.map { status => - val prefix = "rdd_" + id.toString - - - val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray - val filteredStorageStatusList = StorageUtils. - filterStorageStatusByPrefix(storageStatusList, prefix) - - val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head - - spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList) - - } - }}}}} ~ - pathPrefix("static") { - getFromResourceDirectory(STATIC_RESOURCE_DIR) } - } - - - } diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index 775ff8f1aa..e0fdeffbc4 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -1,6 +1,6 @@ package spark.util -import akka.actor.{Props, ActorSystemImpl, ActorSystem} +import akka.actor.{ActorRef, Props, ActorSystemImpl, ActorSystem} import com.typesafe.config.ConfigFactory import akka.util.duration._ import akka.pattern.ask @@ -55,7 +55,7 @@ private[spark] object AkkaUtils { * handle requests. Returns the bound port or throws a SparkException on failure. 
*/ def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route, - name: String = "HttpServer"): Int = { + name: String = "HttpServer"): ActorRef = { val ioWorker = new IoWorker(actorSystem).start() val httpService = actorSystem.actorOf(Props(new HttpService(route))) val rootService = actorSystem.actorOf(Props(new SprayCanRootService(httpService))) @@ -67,7 +67,7 @@ private[spark] object AkkaUtils { try { Await.result(future, timeout) match { case bound: HttpServer.Bound => - return bound.endpoint.getPort + return server case other: Any => throw new SparkException("Failed to bind web UI to port " + port + ": " + other) } diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala index 721c4c6029..51fb440108 100644 --- a/core/src/main/scala/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -5,6 +5,9 @@ import java.util.{TimerTask, Timer} import spark.Logging +/** + * Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries) + */ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { val delaySeconds = MetadataCleaner.getDelaySeconds From 286f8f876ff495df33a7966e77ca90d69f338450 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 28 Jan 2013 01:29:27 -0800 Subject: [PATCH 213/291] Change time unit in MetadataCleaner to seconds --- core/src/main/scala/spark/util/MetadataCleaner.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala index 51fb440108..6cf93a9b17 100644 --- a/core/src/main/scala/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -9,7 +9,6 @@ import spark.Logging * Runs a timer task to periodically clean up metadata (e.g. 
old files or hashtable entries) */ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { - val delaySeconds = MetadataCleaner.getDelaySeconds val periodSeconds = math.max(10, delaySeconds / 10) val timer = new Timer(name + " cleanup timer", true) @@ -39,7 +38,7 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging object MetadataCleaner { - def getDelaySeconds = (System.getProperty("spark.cleaner.delay", "-100").toDouble * 60).toInt - def setDelaySeconds(delay: Long) { System.setProperty("spark.cleaner.delay", delay.toString) } + def getDelaySeconds = System.getProperty("spark.cleaner.delay", "-1").toInt + def setDelaySeconds(delay: Int) { System.setProperty("spark.cleaner.delay", delay.toString) } } From 07f568e1bfc67eead88e2c5dbfb9cac23e1ac8bc Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 24 Jan 2013 15:27:29 -0800 Subject: [PATCH 214/291] SPARK-658: Adding logging of stage duration --- .../scala/spark/scheduler/DAGScheduler.scala | 21 +++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index bd541d4207..8aad667182 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -86,6 +86,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val activeJobs = new HashSet[ActiveJob] val resultStageToJob = new HashMap[Stage, ActiveJob] + val stageSubmissionTimes = new HashMap[Stage, Long] val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) @@ -393,6 +394,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with logDebug("New pending tasks: " + myPending) taskSched.submitTasks( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority)) + if (!stageSubmissionTimes.contains(stage)) { + stageSubmissionTimes.put(stage, System.currentTimeMillis()) + } } else { logDebug("Stage " + stage + " is actually done; %b %d %d".format( stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) @@ -407,6 +411,15 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with def handleTaskCompletion(event: CompletionEvent) { val task = event.task val stage = idToStage(task.stageId) + + def stageFinished(stage: Stage) = { + val serviceTime = stageSubmissionTimes.remove(stage) match { + case Some(t) => (System.currentTimeMillis() - t).toString + case _ => "Unkown" + } + logInfo("%s (%s) finished in %s ms".format(stage, stage.origin, serviceTime)) + running -= stage + } event.reason match { case Success => logInfo("Completed " + task) @@ -421,13 +434,13 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with if (!job.finished(rt.outputId)) { job.finished(rt.outputId) = true job.numFinished += 1 - job.listener.taskSucceeded(rt.outputId, event.result) // If the whole job has finished, remove it if (job.numFinished == job.numPartitions) { activeJobs -= job resultStageToJob -= stage - running -= stage + stageFinished(stage) } + job.listener.taskSucceeded(rt.outputId, event.result) } case None => logInfo("Ignoring result from " + rt + " because its job has finished") @@ -444,8 +457,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with stage.addOutputLoc(smt.partition, status) } if (running.contains(stage) && pendingTasks(stage).isEmpty) { - logInfo(stage + " (" + 
stage.origin + ") finished; looking for newly runnable stages") - running -= stage + stageFinished(stage) + logInfo("looking for newly runnable stages") logInfo("running: " + running) logInfo("waiting: " + waiting) logInfo("failed: " + failed) From c423be7d8e1349fc00431328b76b52f4eee8a975 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 24 Jan 2013 18:25:57 -0800 Subject: [PATCH 215/291] Renaming stage finished function --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 8aad667182..bce7418e87 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -412,7 +412,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val task = event.task val stage = idToStage(task.stageId) - def stageFinished(stage: Stage) = { + def markStageAsFinished(stage: Stage) = { val serviceTime = stageSubmissionTimes.remove(stage) match { case Some(t) => (System.currentTimeMillis() - t).toString case _ => "Unkown" @@ -438,7 +438,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with if (job.numFinished == job.numPartitions) { activeJobs -= job resultStageToJob -= stage - stageFinished(stage) + markStageAsFinished(stage) } job.listener.taskSucceeded(rt.outputId, event.result) } @@ -457,7 +457,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with stage.addOutputLoc(smt.partition, status) } if (running.contains(stage) && pendingTasks(stage).isEmpty) { - stageFinished(stage) + markStageAsFinished(stage) logInfo("looking for newly runnable stages") logInfo("running: " + running) logInfo("waiting: " + waiting) From 501433f1d59b1b326c0a7169fa1fd6136f7628e3 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 28 Jan 2013 10:17:35 -0800 Subject: [PATCH 216/291] Making submission time a field --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 7 +++---- core/src/main/scala/spark/scheduler/Stage.scala | 3 +++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index bce7418e87..7ba1f3430a 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -86,7 +86,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val activeJobs = new HashSet[ActiveJob] val resultStageToJob = new HashMap[Stage, ActiveJob] - val stageSubmissionTimes = new HashMap[Stage, Long] val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) @@ -394,8 +393,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with logDebug("New pending tasks: " + myPending) taskSched.submitTasks( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority)) - if (!stageSubmissionTimes.contains(stage)) { - stageSubmissionTimes.put(stage, System.currentTimeMillis()) + if (!stage.submissionTime.isDefined) { + stage.submissionTime = Some(System.currentTimeMillis()) } } else { logDebug("Stage " + stage + " is actually done; %b %d %d".format( @@ -413,7 +412,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val stage = idToStage(task.stageId) def markStageAsFinished(stage: Stage) = { - val 
serviceTime = stageSubmissionTimes.remove(stage) match { + val serviceTime = stage.submissionTime match { case Some(t) => (System.currentTimeMillis() - t).toString case _ => "Unkown" } diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala index e9419728e3..374114d870 100644 --- a/core/src/main/scala/spark/scheduler/Stage.scala +++ b/core/src/main/scala/spark/scheduler/Stage.scala @@ -32,6 +32,9 @@ private[spark] class Stage( val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) var numAvailableOutputs = 0 + /** When first task was submitted to scheduler. */ + var submissionTime: Option[Long] = None + private var nextAttemptId = 0 def isAvailable: Boolean = { From a423ee546c389b5ce0d2117299456712370d7ad1 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 22 Jan 2013 18:48:43 -0800 Subject: [PATCH 217/291] expose RDD & storage info directly via SparkContext --- core/src/main/scala/spark/SparkContext.scala | 16 ++++++++ .../spark/storage/BlockManagerMaster.scala | 4 ++ .../scala/spark/storage/BlockManagerUI.scala | 39 +++++++------------ .../scala/spark/storage/StorageUtils.scala | 10 +++-- 4 files changed, 41 insertions(+), 28 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 77036c1275..be992250a9 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -46,6 +46,7 @@ import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, C import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import storage.BlockManagerUI import util.{MetadataCleaner, TimeStampedHashMap} +import storage.{StorageStatus, StorageUtils, RDDInfo} /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -473,6 +474,21 @@ class SparkContext( } } + /** + * Return information about what RDDs are cached, if they are in mem or on disk, how much space + * they take, etc. + */ + def getRDDStorageInfo : Array[RDDInfo] = { + StorageUtils.rddInfoFromStorageStatus(getSlavesStorageStatus, this) + } + + /** + * Return information about blocks stored in all of the slaves + */ + def getSlavesStorageStatus : Array[StorageStatus] = { + env.blockManager.master.getStorageStatus + } + /** * Clear the job's list of files added by `addFile` so that they do not get downloaded to * any new nodes. 
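The SparkContext additions above expose cache and storage information directly to driver programs. A minimal, illustrative sketch of querying them (the object name and the cached RDD below are hypothetical; only getRDDStorageInfo and getSlavesStorageStatus come from this patch):

    import spark.SparkContext

    object StorageInfoExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "StorageInfoExample")  // hypothetical app name
        val data = sc.parallelize(1 to 1000, 2).cache()
        data.count()  // materialize the RDD so its blocks are registered with the master

        // Per-RDD view: id, name, storage level, partitions, memory/disk size
        sc.getRDDStorageInfo.foreach(info => println(info))

        // Per-slave view: block manager id and remaining memory on each slave
        sc.getSlavesStorageStatus.foreach { status =>
          println(status.blockManagerId + ": " + status.memRemaining + " bytes free")
        }
        sc.stop()
      }
    }

Both accessors go through the BlockManagerMaster (see the getStorageStatus addition below), so they reflect whatever storage state the master has most recently collected from the slaves.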
diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 55ff1dde9c..c7ee76f0b7 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -118,6 +118,10 @@ private[spark] class BlockManagerMaster( askMasterWithRetry[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } + def getStorageStatus: Array[StorageStatus] = { + askMasterWithRetry[ArrayBuffer[StorageStatus]](GetStorageStatus).toArray + } + /** Stop the master actor, called only on the Spark master node */ def stop() { if (masterActor != null) { diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index eda320fa47..52f6d1b657 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -1,13 +1,10 @@ package spark.storage import akka.actor.{ActorRef, ActorSystem} -import akka.pattern.ask import akka.util.Timeout import akka.util.duration._ -import cc.spray.directives._ import cc.spray.typeconversion.TwirlSupport._ import cc.spray.Directives -import scala.collection.mutable.ArrayBuffer import spark.{Logging, SparkContext} import spark.util.AkkaUtils import spark.Utils @@ -48,32 +45,26 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, path("") { completeWith { // Request the current storage status from the Master - val future = blockManagerMaster ? GetStorageStatus - future.map { status => - // Calculate macro-level statistics - val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray - val maxMem = storageStatusList.map(_.maxMem).reduce(_+_) - val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_) - val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)) - .reduceOption(_+_).getOrElse(0L) - val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc) - spark.storage.html.index. - render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList) - } + val storageStatusList = sc.getSlavesStorageStatus + // Calculate macro-level statistics + val maxMem = storageStatusList.map(_.maxMem).reduce(_+_) + val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_) + val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)) + .reduceOption(_+_).getOrElse(0L) + val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc) + spark.storage.html.index. + render(maxMem, remainingMem, diskSpaceUsed, rdds, storageStatusList) } } ~ path("rdd") { parameter("id") { id => completeWith { - val future = blockManagerMaster ? GetStorageStatus - future.map { status => - val prefix = "rdd_" + id.toString - val storageStatusList = status.asInstanceOf[ArrayBuffer[StorageStatus]].toArray - val filteredStorageStatusList = StorageUtils. - filterStorageStatusByPrefix(storageStatusList, prefix) - val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head - spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList) - } + val prefix = "rdd_" + id.toString + val storageStatusList = sc.getSlavesStorageStatus + val filteredStorageStatusList = StorageUtils. 
+ filterStorageStatusByPrefix(storageStatusList, prefix) + val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head + spark.storage.html.rdd.render(rddInfo, filteredStorageStatusList) } } } ~ diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala index a10e3a95c6..d6e33c8619 100644 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -56,9 +56,11 @@ object StorageUtils { // Find the id of the RDD, e.g. rdd_1 => 1 val rddId = rddKey.split("_").last.toInt // Get the friendly name for the rdd, if available. - val rddName = Option(sc.persistentRdds(rddId).name).getOrElse(rddKey) - val rddStorageLevel = sc.persistentRdds(rddId).getStorageLevel - + val rdd = sc.persistentRdds(rddId) + val rddName = Option(rdd.name).getOrElse(rddKey) + val rddStorageLevel = rdd.getStorageLevel + //TODO get total number of partitions in rdd + RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, memSize, diskSize) }.toArray } @@ -75,4 +77,4 @@ object StorageUtils { } -} \ No newline at end of file +} From 0f22c4207f27bc8d1675af82f873141dda754f5c Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 28 Jan 2013 10:08:59 -0800 Subject: [PATCH 218/291] better formatting for RDDInfo --- core/src/main/scala/spark/storage/StorageUtils.scala | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala index d6e33c8619..ce7c067eea 100644 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -1,6 +1,6 @@ package spark.storage -import spark.SparkContext +import spark.{Utils, SparkContext} import BlockManagerMasterActor.BlockStatus private[spark] @@ -22,8 +22,14 @@ case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, } case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, - numPartitions: Int, memSize: Long, diskSize: Long) - + numPartitions: Int, memSize: Long, diskSize: Long) { + override def toString = { + import Utils.memoryBytesToString + import java.lang.{Integer => JInt} + String.format("RDD \"%s\" (%d) Storage: %s; Partitions: %d; MemorySize: %s; DiskSize: %s", name, id.asInstanceOf[JInt], + storageLevel.toString, numPartitions.asInstanceOf[JInt], memoryBytesToString(memSize), memoryBytesToString(diskSize)) + } +} /* Helper methods for storage-related objects */ private[spark] From efff7bfb3382f4e07f9fad0e6e647c0ec629355e Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Mon, 28 Jan 2013 20:23:11 -0800 Subject: [PATCH 219/291] add long and float accumulatorparams --- core/src/main/scala/spark/SparkContext.scala | 10 ++++++++++ core/src/test/scala/spark/AccumulatorSuite.scala | 6 ++++++ 2 files changed, 16 insertions(+) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 77036c1275..dc9b8688b3 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -673,6 +673,16 @@ object SparkContext { def zero(initialValue: Int) = 0 } + implicit object LongAccumulatorParam extends AccumulatorParam[Long] { + def addInPlace(t1: Long, t2: Long) = t1 + t2 + def zero(initialValue: Long) = 0l + } + + implicit object FloatAccumulatorParam extends AccumulatorParam[Float] { + def addInPlace(t1: Float, t2: Float) = t1 + t2 + def 
zero(initialValue: Float) = 0f + } + // TODO: Add AccumulatorParams for other types, e.g. lists and strings implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) = diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala index 78d64a44ae..ac8ae7d308 100644 --- a/core/src/test/scala/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/spark/AccumulatorSuite.scala @@ -17,6 +17,12 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte val d = sc.parallelize(1 to 20) d.foreach{x => acc += x} acc.value should be (210) + + + val longAcc = sc.accumulator(0l) + val maxInt = Integer.MAX_VALUE.toLong + d.foreach{x => longAcc += maxInt + x} + longAcc.value should be (210l + maxInt * 20) } test ("value not assignable from tasks") { From 1f9b486a8be49ef547ac1532cafd63c4c9d4ddda Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 28 Jan 2013 20:24:54 -0800 Subject: [PATCH 220/291] Some DEBUG-level log cleanup. A few changes to make the DEBUG-level logs less noisy and more readable. - Moved a few very frequent messages to Trace - Changed some BlockManger log messages to make them more understandable SPARK-666 #resolve --- .../main/scala/spark/scheduler/DAGScheduler.scala | 8 ++++---- .../main/scala/spark/storage/BlockManager.scala | 14 +++++++------- .../spark/storage/BlockManagerMasterActor.scala | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index bd541d4207..f10d7cc84e 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -308,10 +308,10 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } else { // TODO: We might want to run this less often, when we are sure that something has become // runnable that wasn't before. 
- logDebug("Checking for newly runnable parent stages") - logDebug("running: " + running) - logDebug("waiting: " + waiting) - logDebug("failed: " + failed) + logTrace("Checking for newly runnable parent stages") + logTrace("running: " + running) + logTrace("waiting: " + waiting) + logTrace("failed: " + failed) val waiting2 = waiting.toArray waiting.clear() for (stage <- waiting2.sortBy(_.priority)) { diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 1215d5f5c8..c61fd75c2b 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -243,7 +243,7 @@ class BlockManager( val startTimeMs = System.currentTimeMillis var managers = master.getLocations(blockId) val locations = managers.map(_.ip) - logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs)) + logDebug("Got block locations in " + Utils.getUsedTimeMs(startTimeMs)) return locations } @@ -253,7 +253,7 @@ class BlockManager( def getLocations(blockIds: Array[String]): Array[Seq[String]] = { val startTimeMs = System.currentTimeMillis val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray - logDebug("Get multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) + logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) return locations } @@ -645,7 +645,7 @@ class BlockManager( var size = 0L myInfo.synchronized { - logDebug("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") if (level.useMemory) { @@ -677,8 +677,10 @@ class BlockManager( } logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs)) + // Replicate block if required if (level.replication > 1) { + val remoteStartTime = System.currentTimeMillis // Serialize the block if not already done if (bytesAfterPut == null) { if (valuesAfterPut == null) { @@ -688,12 +690,10 @@ class BlockManager( bytesAfterPut = dataSerialize(blockId, valuesAfterPut) } replicate(blockId, bytesAfterPut, level) + logDebug("Put block " + blockId + " remotely took " + Utils.getUsedTimeMs(remoteStartTime)) } - BlockManager.dispose(bytesAfterPut) - logDebug("Put block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)) - return size } @@ -978,7 +978,7 @@ object BlockManager extends Logging { */ def dispose(buffer: ByteBuffer) { if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { - logDebug("Unmapping " + buffer) + logTrace("Unmapping " + buffer) if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) { buffer.asInstanceOf[DirectBuffer].cleaner().clean() } diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala index f88517f1a3..2830bc6297 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -115,7 +115,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { } def expireDeadHosts() { - logDebug("Checking for hosts with no recent heart beats in BlockManagerMaster.") + logTrace("Checking for hosts with no recent heart beats in BlockManagerMaster.") val now = System.currentTimeMillis() val minSeenTime = now - slaveTimeout val toRemove = new HashSet[BlockManagerId] From 7ee824e42ebaa1fc0b0248e0a35021108625ed14 Mon Sep 
17 00:00:00 2001 From: Patrick Wendell Date: Mon, 28 Jan 2013 21:48:32 -0800 Subject: [PATCH 221/291] Units from ms -> s --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 7ba1f3430a..b8336d9d06 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -413,10 +413,10 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with def markStageAsFinished(stage: Stage) = { val serviceTime = stage.submissionTime match { - case Some(t) => (System.currentTimeMillis() - t).toString + case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0) case _ => "Unkown" } - logInfo("%s (%s) finished in %s ms".format(stage, stage.origin, serviceTime)) + logInfo("%s (%s) finished in %s s".format(stage, stage.origin, serviceTime)) running -= stage } event.reason match { From b45857c965219e2d26f35adb2ea3a2b831fdb77f Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Mon, 28 Jan 2013 23:56:56 -0600 Subject: [PATCH 222/291] Add RDD.toDebugString. Original idea by Nathan Kronenfeld. --- core/src/main/scala/spark/RDD.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 0d3857f9dd..172431c31a 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -638,4 +638,14 @@ abstract class RDD[T: ClassManifest]( protected[spark] def clearDependencies() { dependencies_ = null } + + /** A description of this RDD and its recursive dependencies for debugging. */ + def toDebugString(): String = { + def debugString(rdd: RDD[_], prefix: String = ""): Seq[String] = { + Seq(prefix + rdd) ++ rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + " ")) + } + debugString(this).mkString("\n") + } + + override def toString() = "%s[%d] at %s".format(getClass.getSimpleName, id, origin) } From 951cfd9ba2a9239a777f156f10af820e9df49606 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 29 Jan 2013 00:02:17 -0600 Subject: [PATCH 223/291] Add JavaRDDLike.toDebugString(). --- core/src/main/scala/spark/api/java/JavaRDDLike.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 4c95c989b5..44f778e5c2 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -330,4 +330,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround case _ => Optional.absent() } } + + /** A description of this RDD and its recursive dependencies for debugging. */ + def toDebugString(): String = { + rdd.toDebugString() + } } From 3cda14af3fea97c2372c7335505e9dad7e0dd117 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 29 Jan 2013 00:12:31 -0600 Subject: [PATCH 224/291] Add number of splits. 
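Together with PATCH 222 and 223 above, this change makes toDebugString() print one line per recursive dependency, now including each RDD's split count. A minimal, illustrative driver sketch of calling it (the object name and the word-count pipeline are hypothetical; only toDebugString comes from these patches):

    import spark.SparkContext
    import spark.SparkContext._  // implicit conversion to pair-RDD functions

    object DebugStringExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "DebugStringExample")
        val words = sc.parallelize(Seq("a", "b", "a"), 2)
        val counts = words.map(w => (w, 1)).reduceByKey(_ + _)
        // Prints this RDD and each dependency on its own indented line,
        // with the number of splits added in this patch
        println(counts.toDebugString())
        sc.stop()
      }
    }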
--- core/src/main/scala/spark/RDD.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 172431c31a..39bacd2afb 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -642,7 +642,8 @@ abstract class RDD[T: ClassManifest]( /** A description of this RDD and its recursive dependencies for debugging. */ def toDebugString(): String = { def debugString(rdd: RDD[_], prefix: String = ""): Seq[String] = { - Seq(prefix + rdd) ++ rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + " ")) + Seq(prefix + rdd + " (" + rdd.splits.size + " splits)") ++ + rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + " ")) } debugString(this).mkString("\n") } From cbf72bffa5874319c7ee7117a073e9d01fa51585 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 29 Jan 2013 00:20:36 -0600 Subject: [PATCH 225/291] Include name, if set, in RDD.toString(). --- core/src/main/scala/spark/RDD.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 39bacd2afb..a23441483e 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -648,5 +648,10 @@ abstract class RDD[T: ClassManifest]( debugString(this).mkString("\n") } - override def toString() = "%s[%d] at %s".format(getClass.getSimpleName, id, origin) + override def toString(): String = "%s%s[%d] at %s".format( + Option(name).map(_ + " ").getOrElse(""), + getClass.getSimpleName, + id, + origin) + } From b29599e5cf0272f0d0e3ceceebb473a8163eab8c Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 28 Jan 2013 22:24:47 -0800 Subject: [PATCH 226/291] Fix code that depended on metadata cleaner interval being in minutes --- streaming/src/main/scala/spark/streaming/DStream.scala | 8 ++++---- .../src/main/scala/spark/streaming/StreamingContext.scala | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index b11ef443dc..352f83fe0c 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -198,10 +198,10 @@ abstract class DStream[T: ClassManifest] ( metadataCleanerDelay < 0 || rememberDuration.milliseconds < metadataCleanerDelay * 1000, "It seems you are doing some DStream window operation or setting a checkpoint interval " + "which requires " + this.getClass.getSimpleName + " to remember generated RDDs for more " + - "than " + rememberDuration.milliseconds + " milliseconds. But the Spark's metadata cleanup" + - "delay is set to " + (metadataCleanerDelay / 60.0) + " minutes, which is not sufficient. Please set " + - "the Java property 'spark.cleaner.delay' to more than " + - math.ceil(rememberDuration.milliseconds.toDouble / 60000.0).toInt + " minutes." + "than " + rememberDuration.milliseconds / 1000 + " seconds. But Spark's metadata cleanup" + + "delay is set to " + metadataCleanerDelay + " seconds, which is not sufficient. Please " + + "set the Java property 'spark.cleaner.delay' to more than " + + math.ceil(rememberDuration.milliseconds / 1000.0).toInt + " seconds." 
) dependencies.foreach(_.validate()) diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 14500bdcb1..37ba524b48 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -389,7 +389,7 @@ object StreamingContext { // Set the default cleaner delay to an hour if not already set. // This should be sufficient for even 1 second interval. if (MetadataCleaner.getDelaySeconds < 0) { - MetadataCleaner.setDelaySeconds(60) + MetadataCleaner.setDelaySeconds(3600) } new SparkContext(master, frameworkName) } From 64ba6a8c2c5f46e6de6deb6a6fd576a55cb3b198 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 28 Jan 2013 22:30:12 -0800 Subject: [PATCH 227/291] Simplify checkpointing code and RDD class a little: - RDD's getDependencies and getSplits methods are now guaranteed to be called only once, so subclasses can safely do computation in there without worrying about caching the results. - The management of a "splits_" variable that is cleared out when we checkpoint an RDD is now done in the RDD class. - A few of the RDD subclasses are simpler. - CheckpointRDD's compute() method no longer assumes that it is given a CheckpointRDDSplit -- it can work just as well on a split from the original RDD, because it only looks at its index. This is important because things like UnionRDD and ZippedRDD remember the parent's splits as part of their own and wouldn't work on checkpointed parents. - RDD.iterator can now reuse cached data if an RDD is computed before it is checkpointed. It seems like it wouldn't do this before (it always called iterator() on the CheckpointRDD, which read from HDFS). --- core/src/main/scala/spark/CacheManager.scala | 6 +- .../main/scala/spark/PairRDDFunctions.scala | 4 +- core/src/main/scala/spark/RDD.scala | 130 ++++++++++-------- .../main/scala/spark/RDDCheckpointData.scala | 19 +-- .../scala/spark/api/java/JavaRDDLike.scala | 2 +- .../main/scala/spark/rdd/CartesianRDD.scala | 12 +- .../main/scala/spark/rdd/CheckpointRDD.scala | 61 ++++---- .../main/scala/spark/rdd/CoalescedRDD.scala | 14 +- core/src/main/scala/spark/rdd/MappedRDD.scala | 6 +- .../scala/spark/rdd/PartitionPruningRDD.scala | 13 +- .../main/scala/spark/rdd/ShuffledRDD.scala | 8 +- core/src/main/scala/spark/rdd/UnionRDD.scala | 14 +- core/src/main/scala/spark/rdd/ZippedRDD.scala | 7 +- .../scala/spark/util/MetadataCleaner.scala | 4 +- .../test/scala/spark/CheckpointSuite.scala | 21 +-- 15 files changed, 153 insertions(+), 168 deletions(-) diff --git a/core/src/main/scala/spark/CacheManager.scala b/core/src/main/scala/spark/CacheManager.scala index a0b53fd9d6..711435c333 100644 --- a/core/src/main/scala/spark/CacheManager.scala +++ b/core/src/main/scala/spark/CacheManager.scala @@ -10,9 +10,9 @@ import spark.storage.{BlockManager, StorageLevel} private[spark] class CacheManager(blockManager: BlockManager) extends Logging { private val loading = new HashSet[String] - /** Gets or computes an RDD split. Used by RDD.iterator() when a RDD is cached. */ + /** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. 
*/ def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel) - : Iterator[T] = { + : Iterator[T] = { val key = "rdd_%d_%d".format(rdd.id, split.index) logInfo("Cache key is " + key) blockManager.get(key) match { @@ -50,7 +50,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { // If we got here, we have to load the split val elements = new ArrayBuffer[Any] logInfo("Computing partition " + split) - elements ++= rdd.compute(split, context) + elements ++= rdd.computeOrReadCheckpoint(split, context) // Try to put this block in the blockManager blockManager.put(key, elements, storageLevel, true) return elements.iterator.asInstanceOf[Iterator[T]] diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 53b051f1c5..231e23a7de 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -649,9 +649,7 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest]( } private[spark] -class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) - extends RDD[(K, U)](prev) { - +class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev) { override def getSplits = firstParent[(K, V)].splits override val partitioner = firstParent[(K, V)].partitioner override def compute(split: Split, context: TaskContext) = diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 0d3857f9dd..dbad6d4c83 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -1,27 +1,17 @@ package spark -import java.io.{ObjectOutputStream, IOException, EOFException, ObjectInputStream} import java.net.URL import java.util.{Date, Random} import java.util.{HashMap => JHashMap} -import java.util.concurrent.atomic.AtomicLong import scala.collection.Map import scala.collection.JavaConversions.mapAsScalaMap import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap -import org.apache.hadoop.fs.Path import org.apache.hadoop.io.BytesWritable import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.Text -import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapred.FileOutputCommitter -import org.apache.hadoop.mapred.HadoopWriter -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.OutputCommitter -import org.apache.hadoop.mapred.OutputFormat -import org.apache.hadoop.mapred.SequenceFileOutputFormat import org.apache.hadoop.mapred.TextOutputFormat import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} @@ -30,7 +20,6 @@ import spark.partial.BoundedDouble import spark.partial.CountEvaluator import spark.partial.GroupedCountEvaluator import spark.partial.PartialResult -import spark.rdd.BlockRDD import spark.rdd.CartesianRDD import spark.rdd.FilteredRDD import spark.rdd.FlatMappedRDD @@ -73,11 +62,11 @@ import SparkContext._ * on RDD internals. 
*/ abstract class RDD[T: ClassManifest]( - @transient var sc: SparkContext, - var dependencies_ : List[Dependency[_]] + @transient private var sc: SparkContext, + @transient private var deps: Seq[Dependency[_]] ) extends Serializable with Logging { - + /** Construct an RDD with just a one-to-one dependency on one parent */ def this(@transient oneParent: RDD[_]) = this(oneParent.context , List(new OneToOneDependency(oneParent))) @@ -85,25 +74,27 @@ abstract class RDD[T: ClassManifest]( // Methods that should be implemented by subclasses of RDD // ======================================================================= - /** Function for computing a given partition. */ + /** Implemented by subclasses to compute a given partition. */ def compute(split: Split, context: TaskContext): Iterator[T] - /** Set of partitions in this RDD. */ - protected def getSplits(): Array[Split] + /** + * Implemented by subclasses to return the set of partitions in this RDD. This method will only + * be called once, so it is safe to implement a time-consuming computation in it. + */ + protected def getSplits: Array[Split] - /** How this RDD depends on any parent RDDs. */ - protected def getDependencies(): List[Dependency[_]] = dependencies_ + /** + * Implemented by subclasses to return how this RDD depends on parent RDDs. This method will only + * be called once, so it is safe to implement a time-consuming computation in it. + */ + protected def getDependencies: Seq[Dependency[_]] = deps - /** A friendly name for this RDD */ - var name: String = null - /** Optionally overridden by subclasses to specify placement preferences. */ protected def getPreferredLocations(split: Split): Seq[String] = Nil /** Optionally overridden by subclasses to specify how they are partitioned. */ val partitioner: Option[Partitioner] = None - // ======================================================================= // Methods and fields available on all RDDs // ======================================================================= @@ -111,13 +102,16 @@ abstract class RDD[T: ClassManifest]( /** A unique ID for this RDD (within its SparkContext). */ val id = sc.newRddId() + /** A friendly name for this RDD */ + var name: String = null + /** Assign a name to this RDD */ def setName(_name: String) = { name = _name this } - /** + /** * Set this RDD's storage level to persist its values across operations after the first time * it is computed. Can only be called once on each RDD. */ @@ -142,15 +136,24 @@ abstract class RDD[T: ClassManifest]( /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel + // Our dependencies and splits will be gotten by calling subclass's methods below, and will + // be overwritten when we're checkpointed + private var dependencies_ : Seq[Dependency[_]] = null + @transient private var splits_ : Array[Split] = null + + /** An Option holding our checkpoint RDD, if we are checkpointed */ + private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD) + /** - * Get the preferred location of a split, taking into account whether the + * Get the list of dependencies of this RDD, taking into account whether the * RDD is checkpointed or not. 
*/ - final def preferredLocations(split: Split): Seq[String] = { - if (isCheckpointed) { - checkpointData.get.getPreferredLocations(split) - } else { - getPreferredLocations(split) + final def dependencies: Seq[Dependency[_]] = { + checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse { + if (dependencies_ == null) { + dependencies_ = getDependencies + } + dependencies_ } } @@ -159,22 +162,21 @@ abstract class RDD[T: ClassManifest]( * RDD is checkpointed or not. */ final def splits: Array[Split] = { - if (isCheckpointed) { - checkpointData.get.getSplits - } else { - getSplits + checkpointRDD.map(_.splits).getOrElse { + if (splits_ == null) { + splits_ = getSplits + } + splits_ } } /** - * Get the list of dependencies of this RDD, taking into account whether the + * Get the preferred location of a split, taking into account whether the * RDD is checkpointed or not. */ - final def dependencies: List[Dependency[_]] = { - if (isCheckpointed) { - dependencies_ - } else { - getDependencies + final def preferredLocations(split: Split): Seq[String] = { + checkpointRDD.map(_.getPreferredLocations(split)).getOrElse { + getPreferredLocations(split) } } @@ -184,10 +186,19 @@ abstract class RDD[T: ClassManifest]( * subclasses of RDD. */ final def iterator(split: Split, context: TaskContext): Iterator[T] = { - if (isCheckpointed) { - checkpointData.get.iterator(split, context) - } else if (storageLevel != StorageLevel.NONE) { + if (storageLevel != StorageLevel.NONE) { SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel) + } else { + computeOrReadCheckpoint(split, context) + } + } + + /** + * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing. + */ + private[spark] def computeOrReadCheckpoint(split: Split, context: TaskContext): Iterator[T] = { + if (isCheckpointed) { + firstParent[T].iterator(split, context) } else { compute(split, context) } @@ -578,15 +589,15 @@ abstract class RDD[T: ClassManifest]( /** * Return whether this RDD has been checkpointed or not */ - def isCheckpointed(): Boolean = { - if (checkpointData.isDefined) checkpointData.get.isCheckpointed() else false + def isCheckpointed: Boolean = { + checkpointData.map(_.isCheckpointed).getOrElse(false) } /** * Gets the name of the file to which this RDD was checkpointed */ - def getCheckpointFile(): Option[String] = { - if (checkpointData.isDefined) checkpointData.get.getCheckpointFile() else None + def getCheckpointFile: Option[String] = { + checkpointData.flatMap(_.getCheckpointFile) } // ======================================================================= @@ -611,31 +622,36 @@ abstract class RDD[T: ClassManifest]( def context = sc /** - * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler + * Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler * after a job using this RDD has completed (therefore the RDD has been materialized and * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs. */ - protected[spark] def doCheckpoint() { - if (checkpointData.isDefined) checkpointData.get.doCheckpoint() - dependencies.foreach(_.rdd.doCheckpoint()) + private[spark] def doCheckpoint() { + if (checkpointData.isDefined) { + checkpointData.get.doCheckpoint() + } else { + dependencies.foreach(_.rdd.doCheckpoint()) + } } /** - * Changes the dependencies of this RDD from its original parents to the new RDD - * (`newRDD`) created from the checkpoint file. 
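// Illustrative sketch, not part of the patch: how the checkpoint-aware accessors above behave
// from the driver side. setCheckpointDir and the local path are assumptions for this example;
// makeRDD, checkpoint(), count(), isCheckpointed and getCheckpointFile are the APIs shown in
// this diff.
import spark.SparkContext

object CheckpointUsageSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "CheckpointUsageSketch")
    sc.setCheckpointDir("/tmp/spark-checkpoints")   // assumed writable scratch directory
    val rdd = sc.makeRDD(1 to 1000, 10).map(_ * 2)
    rdd.checkpoint()             // only marks the RDD; nothing is written yet
    rdd.count()                  // the job finishes, then the DAGScheduler calls doCheckpoint()
    assert(rdd.isCheckpointed)   // splits/dependencies now come from the CheckpointRDD parent
    println(rdd.getCheckpointFile)
    sc.stop()
  }
}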
+ * Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`) + * created from the checkpoint file, and forget its old dependencies and splits. */ - protected[spark] def changeDependencies(newRDD: RDD[_]) { + private[spark] def markCheckpointed(checkpointRDD: RDD[_]) { clearDependencies() - dependencies_ = List(new OneToOneDependency(newRDD)) + dependencies_ = null + splits_ = null + deps = null // Forget the constructor argument for dependencies too } /** * Clears the dependencies of this RDD. This method must ensure that all references * to the original parent RDDs is removed to enable the parent RDDs to be garbage * collected. Subclasses of RDD may override this method for implementing their own cleaning - * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea. + * logic. See [[spark.rdd.UnionRDD]] for an example. */ - protected[spark] def clearDependencies() { + protected def clearDependencies() { dependencies_ = null } } diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index 18df530b7d..a4a4ebaf53 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -20,7 +20,7 @@ private[spark] object CheckpointState extends Enumeration { * of the checkpointed RDD. */ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) -extends Logging with Serializable { + extends Logging with Serializable { import CheckpointState._ @@ -31,7 +31,7 @@ extends Logging with Serializable { @transient var cpFile: Option[String] = None // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD. - @transient var cpRDD: Option[RDD[T]] = None + var cpRDD: Option[RDD[T]] = None // Mark the RDD for checkpointing def markForCheckpoint() { @@ -41,12 +41,12 @@ extends Logging with Serializable { } // Is the RDD already checkpointed - def isCheckpointed(): Boolean = { + def isCheckpointed: Boolean = { RDDCheckpointData.synchronized { cpState == Checkpointed } } // Get the file to which this RDD was checkpointed to as an Option - def getCheckpointFile(): Option[String] = { + def getCheckpointFile: Option[String] = { RDDCheckpointData.synchronized { cpFile } } @@ -71,7 +71,7 @@ extends Logging with Serializable { RDDCheckpointData.synchronized { cpFile = Some(path) cpRDD = Some(newRDD) - rdd.changeDependencies(newRDD) + rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and splits cpState = Checkpointed RDDCheckpointData.clearTaskCaches() logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) @@ -79,7 +79,7 @@ extends Logging with Serializable { } // Get preferred location of a split after checkpointing - def getPreferredLocations(split: Split) = { + def getPreferredLocations(split: Split): Seq[String] = { RDDCheckpointData.synchronized { cpRDD.get.preferredLocations(split) } @@ -91,9 +91,10 @@ extends Logging with Serializable { } } - // Get iterator. This is called at the worker nodes. 
- def iterator(split: Split, context: TaskContext): Iterator[T] = { - rdd.firstParent[T].iterator(split, context) + def checkpointRDD: Option[RDD[T]] = { + RDDCheckpointData.synchronized { + cpRDD + } } } diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 4c95c989b5..46fd8fe85e 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -319,7 +319,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround /** * Return whether this RDD has been checkpointed or not */ - def isCheckpointed(): Boolean = rdd.isCheckpointed() + def isCheckpointed: Boolean = rdd.isCheckpointed /** * Gets the name of the file to which this RDD was checkpointed diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 453d410ad4..0f9ca06531 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -1,7 +1,7 @@ package spark.rdd import java.io.{ObjectOutputStream, IOException} -import spark.{OneToOneDependency, NarrowDependency, RDD, SparkContext, Split, TaskContext} +import spark._ private[spark] @@ -35,7 +35,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( val numSplitsInRdd2 = rdd2.splits.size - @transient var splits_ = { + override def getSplits: Array[Split] = { // create the cross product split val array = new Array[Split](rdd1.splits.size * rdd2.splits.size) for (s1 <- rdd1.splits; s2 <- rdd2.splits) { @@ -45,8 +45,6 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( array } - override def getSplits = splits_ - override def getPreferredLocations(split: Split) = { val currSplit = split.asInstanceOf[CartesianSplit] rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) @@ -58,7 +56,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) } - var deps_ = List( + override def getDependencies: Seq[Dependency[_]] = List( new NarrowDependency(rdd1) { def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2) }, @@ -67,11 +65,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( } ) - override def getDependencies = deps_ - override def clearDependencies() { - deps_ = Nil - splits_ = null rdd1 = null rdd2 = null } diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala index 6f00f6ac73..96b593ba7c 100644 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -9,23 +9,26 @@ import org.apache.hadoop.fs.Path import java.io.{File, IOException, EOFException} import java.text.NumberFormat -private[spark] class CheckpointRDDSplit(idx: Int, val splitFile: String) extends Split { - override val index: Int = idx -} +private[spark] class CheckpointRDDSplit(val index: Int) extends Split {} /** * This RDD represents a RDD checkpoint file (similar to HadoopRDD). 
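// Illustrative sketch, not part of the patch: the shape that the CartesianRDD conversion above,
// and the CoalescedRDD, UnionRDD and ZippedRDD conversions later in this commit, converge on.
// Splits and dependencies are only described via getSplits/getDependencies; the base RDD caches
// the results and drops them once the RDD is checkpointed. The RDD below and its doubling logic
// are invented for illustration.
package spark.rdd

import spark.{Dependency, OneToOneDependency, RDD, Split, TaskContext}

private[spark] class TimesTwoRDD(var prev: RDD[Int])
  extends RDD[Int](prev.context, Nil) {   // Nil: dependencies come from getDependencies

  override def getSplits: Array[Split] = firstParent[Int].splits

  override def getDependencies: Seq[Dependency[_]] = List(new OneToOneDependency(prev))

  override def compute(split: Split, context: TaskContext): Iterator[Int] =
    firstParent[Int].iterator(split, context).map(_ * 2)

  override def clearDependencies() {
    prev = null   // drop the parent reference so a checkpointed parent can be garbage collected
  }
}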
*/ private[spark] -class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String) +class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: String) extends RDD[T](sc, Nil) { - @transient val path = new Path(checkpointPath) - @transient val fs = path.getFileSystem(new Configuration()) + @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) @transient val splits_ : Array[Split] = { - val splitFiles = fs.listStatus(path).map(_.getPath.toString).filter(_.contains("part-")).sorted - splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray + val dirContents = fs.listStatus(new Path(checkpointPath)) + val splitFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted + val numSplits = splitFiles.size + if (!splitFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || + !splitFiles(numSplits-1).endsWith(CheckpointRDD.splitIdToFile(numSplits-1))) { + throw new SparkException("Invalid checkpoint directory: " + checkpointPath) + } + Array.tabulate(numSplits)(i => new CheckpointRDDSplit(i)) } checkpointData = Some(new RDDCheckpointData[T](this)) @@ -34,36 +37,34 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String) override def getSplits = splits_ override def getPreferredLocations(split: Split): Seq[String] = { - val status = fs.getFileStatus(path) + val status = fs.getFileStatus(new Path(checkpointPath)) val locations = fs.getFileBlockLocations(status, 0, status.getLen) - locations.firstOption.toList.flatMap(_.getHosts).filter(_ != "localhost") + locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost") } override def compute(split: Split, context: TaskContext): Iterator[T] = { - CheckpointRDD.readFromFile(split.asInstanceOf[CheckpointRDDSplit].splitFile, context) + val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index)) + CheckpointRDD.readFromFile(file, context) } override def checkpoint() { - // Do nothing. Hadoop RDD should not be checkpointed. + // Do nothing. CheckpointRDD should not be checkpointed. } } private[spark] object CheckpointRDD extends Logging { - def splitIdToFileName(splitId: Int): String = { - val numfmt = NumberFormat.getInstance() - numfmt.setMinimumIntegerDigits(5) - numfmt.setGroupingUsed(false) - "part-" + numfmt.format(splitId) + def splitIdToFile(splitId: Int): String = { + "part-%05d".format(splitId) } - def writeToFile[T](path: String, blockSize: Int = -1)(context: TaskContext, iterator: Iterator[T]) { + def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) { val outputDir = new Path(path) val fs = outputDir.getFileSystem(new Configuration()) - val finalOutputName = splitIdToFileName(context.splitId) + val finalOutputName = splitIdToFile(ctx.splitId) val finalOutputPath = new Path(outputDir, finalOutputName) - val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptId) + val tempOutputPath = new Path(outputDir, "." 
+ finalOutputName + "-attempt-" + ctx.attemptId) if (fs.exists(tempOutputPath)) { throw new IOException("Checkpoint failed: temporary path " + @@ -83,22 +84,22 @@ private[spark] object CheckpointRDD extends Logging { serializeStream.close() if (!fs.rename(tempOutputPath, finalOutputPath)) { - if (!fs.delete(finalOutputPath, true)) { - throw new IOException("Checkpoint failed: failed to delete earlier output of task " - + context.attemptId) - } - if (!fs.rename(tempOutputPath, finalOutputPath)) { + if (!fs.exists(finalOutputPath)) { + fs.delete(tempOutputPath, false) throw new IOException("Checkpoint failed: failed to save output of task: " - + context.attemptId) + + ctx.attemptId + " and final output path does not exist") + } else { + // Some other copy of this task must've finished before us and renamed it + logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it") + fs.delete(tempOutputPath, false) } } } - def readFromFile[T](path: String, context: TaskContext): Iterator[T] = { - val inputPath = new Path(path) - val fs = inputPath.getFileSystem(new Configuration()) + def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = { + val fs = path.getFileSystem(new Configuration()) val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - val fileInputStream = fs.open(inputPath, bufferSize) + val fileInputStream = fs.open(path, bufferSize) val serializer = SparkEnv.get.serializer.newInstance() val deserializeStream = serializer.deserializeStream(fileInputStream) diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 167755bbba..4c57434b65 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -27,11 +27,11 @@ private[spark] case class CoalescedRDDSplit( * or to avoid having a large number of small tasks when processing a directory with many files. 
*/ class CoalescedRDD[T: ClassManifest]( - var prev: RDD[T], + @transient var prev: RDD[T], maxPartitions: Int) - extends RDD[T](prev.context, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs + extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies - @transient var splits_ : Array[Split] = { + override def getSplits: Array[Split] = { val prevSplits = prev.splits if (prevSplits.length < maxPartitions) { prevSplits.map(_.index).map{idx => new CoalescedRDDSplit(idx, prev, Array(idx)) } @@ -44,26 +44,20 @@ class CoalescedRDD[T: ClassManifest]( } } - override def getSplits = splits_ - override def compute(split: Split, context: TaskContext): Iterator[T] = { split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { parentSplit => firstParent[T].iterator(parentSplit, context) } } - var deps_ : List[Dependency[_]] = List( + override def getDependencies: Seq[Dependency[_]] = List( new NarrowDependency(prev) { def getParents(id: Int): Seq[Int] = splits(id).asInstanceOf[CoalescedRDDSplit].parentsIndices } ) - override def getDependencies() = deps_ - override def clearDependencies() { - deps_ = Nil - splits_ = null prev = null } } diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala index c6ceb272cd..5466c9c657 100644 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/spark/rdd/MappedRDD.scala @@ -3,13 +3,11 @@ package spark.rdd import spark.{RDD, Split, TaskContext} private[spark] -class MappedRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], - f: T => U) +class MappedRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: T => U) extends RDD[U](prev) { override def getSplits = firstParent[T].splits override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context).map(f) -} \ No newline at end of file +} diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala index 97dd37950e..b8482338c6 100644 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -7,23 +7,18 @@ import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} * all partitions. An example use case: If we know the RDD is partitioned by range, * and the execution DAG has a filter on the key, we can avoid launching tasks * on partitions that don't have the range covering the key. + * + * TODO: This currently doesn't give partition IDs properly! 
*/ class PartitionPruningRDD[T: ClassManifest]( @transient prev: RDD[T], @transient partitionFilterFunc: Int => Boolean) extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { - @transient - var partitions_ : Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].partitions - override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) - override protected def getSplits = partitions_ + override protected def getSplits = + getDependencies.head.asInstanceOf[PruneDependency[T]].partitions override val partitioner = firstParent[T].partitioner - - override def clearDependencies() { - super.clearDependencies() - partitions_ = null - } } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index 28ff19876d..d396478673 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -22,16 +22,10 @@ class ShuffledRDD[K, V]( override val partitioner = Some(part) - @transient var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) - - override def getSplits = splits_ + override def getSplits = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = { val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index) } - - override def clearDependencies() { - splits_ = null - } } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index 82f0a44ecd..26a2d511f2 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -26,9 +26,9 @@ private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIn class UnionRDD[T: ClassManifest]( sc: SparkContext, @transient var rdds: Seq[RDD[T]]) - extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs + extends RDD[T](sc, Nil) { // Nil since we implement getDependencies - @transient var splits_ : Array[Split] = { + override def getSplits: Array[Split] = { val array = new Array[Split](rdds.map(_.splits.size).sum) var pos = 0 for (rdd <- rdds; split <- rdd.splits) { @@ -38,20 +38,16 @@ class UnionRDD[T: ClassManifest]( array } - override def getSplits = splits_ - - @transient var deps_ = { + override def getDependencies: Seq[Dependency[_]] = { val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for (rdd <- rdds) { deps += new RangeDependency(rdd, 0, pos, rdd.splits.size) pos += rdd.splits.size } - deps.toList + deps } - override def getDependencies = deps_ - override def compute(s: Split, context: TaskContext): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator(context) @@ -59,8 +55,6 @@ class UnionRDD[T: ClassManifest]( s.asInstanceOf[UnionSplit[T]].preferredLocations() override def clearDependencies() { - deps_ = null - splits_ = null rdds = null } } diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala index d950b06c85..e5df6d8c72 100644 --- a/core/src/main/scala/spark/rdd/ZippedRDD.scala +++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala @@ -32,9 +32,7 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest]( extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))) with Serializable { - // TODO: FIX THIS. 
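// Worked example, not part of the patch, of the index arithmetic in UnionRDD.getDependencies
// above. Suppose rddA has 3 splits and rddB has 2; the union then has 5 splits and the loop
// produces
//
//   new RangeDependency(rddA, 0, 0, 3)   // union splits 0..2 read rddA splits 0..2
//   new RangeDependency(rddB, 0, 3, 2)   // union splits 3..4 read rddB splits 0..1
//
// where the third argument is the union-side starting index (`pos`) and the fourth is the
// number of splits that parent contributes. Union split i therefore maps back to (rddA, i)
// for i < 3 and to (rddB, i - 3) otherwise.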
- - @transient var splits_ : Array[Split] = { + override def getSplits: Array[Split] = { if (rdd1.splits.size != rdd2.splits.size) { throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") } @@ -45,8 +43,6 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest]( array } - override def getSplits = splits_ - override def compute(s: Split, context: TaskContext): Iterator[(T, U)] = { val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context)) @@ -58,7 +54,6 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest]( } override def clearDependencies() { - splits_ = null rdd1 = null rdd2 = null } diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala index 6cf93a9b17..eaff7ae581 100644 --- a/core/src/main/scala/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -26,8 +26,8 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging if (delaySeconds > 0) { logDebug( - "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds and " - + "period of " + periodSeconds + " secs") + "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds " + + "and period of " + periodSeconds + " secs") timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) } diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 33c317720c..0b74607fb8 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -99,7 +99,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits. // Note that this test is very specific to the current implementation of CartesianRDD. val ones = sc.makeRDD(1 to 100, 10).map(x => x) - ones.checkpoint // checkpoint that MappedRDD + ones.checkpoint() // checkpoint that MappedRDD val cartesian = new CartesianRDD(sc, ones, ones) val splitBeforeCheckpoint = serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit]) @@ -125,7 +125,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits. // Note that this test is very specific to the current implementation of CoalescedRDDSplits val ones = sc.makeRDD(1 to 100, 10).map(x => x) - ones.checkpoint // checkpoint that MappedRDD + ones.checkpoint() // checkpoint that MappedRDD val coalesced = new CoalescedRDD(ones, 2) val splitBeforeCheckpoint = serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit]) @@ -160,7 +160,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { // so only the RDD will reduce in serialized size, not the splits. 
testParentCheckpointing( rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false) - } /** @@ -176,7 +175,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { testRDDSplitSize: Boolean = false ) { // Generate the final RDD using given RDD operation - val baseRDD = generateLongLineageRDD + val baseRDD = generateLongLineageRDD() val operatedRDD = op(baseRDD) val parentRDD = operatedRDD.dependencies.headOption.orNull val rddType = operatedRDD.getClass.getSimpleName @@ -245,12 +244,16 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { testRDDSplitSize: Boolean ) { // Generate the final RDD using given RDD operation - val baseRDD = generateLongLineageRDD + val baseRDD = generateLongLineageRDD() val operatedRDD = op(baseRDD) val parentRDD = operatedRDD.dependencies.head.rdd val rddType = operatedRDD.getClass.getSimpleName val parentRDDType = parentRDD.getClass.getSimpleName + // Get the splits and dependencies of the parent in case they're lazily computed + parentRDD.dependencies + parentRDD.splits + // Find serialized sizes before and after the checkpoint val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) parentRDD.checkpoint() // checkpoint the parent RDD, not the generated one @@ -267,7 +270,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { if (testRDDSize) { assert( rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint, - "Size of " + rddType + " did not reduce after parent checkpointing parent " + parentRDDType + + "Size of " + rddType + " did not reduce after checkpointing parent " + parentRDDType + "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" ) } @@ -318,10 +321,12 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { } /** - * Get serialized sizes of the RDD and its splits + * Get serialized sizes of the RDD and its splits, in order to test whether the size shrinks + * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. */ def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { - (Utils.serialize(rdd).size, Utils.serialize(rdd.splits).size) + (Utils.serialize(rdd).length - Utils.serialize(rdd.checkpointData).length, + Utils.serialize(rdd.splits).length) } /** From a34096a76de9d07518ce33111ad43b88049c1ac2 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Mon, 28 Jan 2013 22:40:16 -0800 Subject: [PATCH 228/291] Add easymock to POMs --- core/pom.xml | 5 +++++ pom.xml | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/core/pom.xml b/core/pom.xml index 862d3ec37a..a2b9b726a6 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -98,6 +98,11 @@ scalacheck_${scala.version} test
    + + org.easymock + easymock + test + com.novocode junit-interface diff --git a/pom.xml b/pom.xml index 3ea989a082..4a4ff560e7 100644 --- a/pom.xml +++ b/pom.xml @@ -273,6 +273,12 @@ 1.8 test + + org.easymock + easymock + 3.1 + test + org.scalacheck scalacheck_${scala.version} From 16a0789e10d2ac714e7c623b026c4a58ca9678d6 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 29 Jan 2013 17:09:53 -0800 Subject: [PATCH 229/291] Remember ConnectionManagerId used to initiate SendingConnections. This prevents ConnectionManager from getting confused if a machine has multiple host names and the one getHostName() finds happens not to be the one that was passed from, e.g., the BlockManagerMaster. --- .../src/main/scala/spark/network/Connection.scala | 15 +++++++++++---- .../scala/spark/network/ConnectionManager.scala | 3 ++- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index c193bf7c8d..cd5b7d57f3 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -12,7 +12,14 @@ import java.net._ private[spark] -abstract class Connection(val channel: SocketChannel, val selector: Selector) extends Logging { +abstract class Connection(val channel: SocketChannel, val selector: Selector, + val remoteConnectionManagerId: ConnectionManagerId) extends Logging { + def this(channel_ : SocketChannel, selector_ : Selector) = { + this(channel_, selector_, + ConnectionManagerId.fromSocketAddress( + channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] + )) + } channel.configureBlocking(false) channel.socket.setTcpNoDelay(true) @@ -25,7 +32,6 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex var onKeyInterestChangeCallback: (Connection, Int) => Unit = null val remoteAddress = getRemoteAddress() - val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress) def key() = channel.keyFor(selector) @@ -103,8 +109,9 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex } -private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector) -extends Connection(SocketChannel.open, selector_) { +private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector, + remoteId_ : ConnectionManagerId) +extends Connection(SocketChannel.open, selector_, remoteId_) { class Outbox(fair: Int = 0) { val messages = new Queue[Message]() diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 2ecd14f536..c7f226044d 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -299,7 +299,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging { private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { def startNewConnection(): SendingConnection = { val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) - val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, new SendingConnection(inetSocketAddress, selector)) + val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, + new SendingConnection(inetSocketAddress, selector, connectionManagerId)) newConnection } val lookupKey = 
ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress) From 0f81025ecadbfd21edb64602658ae8ba26e5bf66 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 29 Jan 2013 18:54:58 -0800 Subject: [PATCH 230/291] Add easymock to SBT configuration. --- project/SparkBuild.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 03b8094f7d..af8b5ba017 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -92,7 +92,8 @@ object SparkBuild extends Build { "org.eclipse.jetty" % "jetty-server" % "7.5.3.v20111011", "org.scalatest" %% "scalatest" % "1.8" % "test", "org.scalacheck" %% "scalacheck" % "1.9" % "test", - "com.novocode" % "junit-interface" % "0.8" % "test" + "com.novocode" % "junit-interface" % "0.8" % "test", + "org.easymock" % "easymock" % "3.1" % "test" ), parallelExecution := false, /* Workaround for issue #206 (fixed after SBT 0.11.0) */ From a3d14c0404d6b28433784f84086a29ecc0045a12 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Mon, 28 Jan 2013 22:41:08 -0800 Subject: [PATCH 231/291] Refactoring to DAGScheduler to aid testing --- core/src/main/scala/spark/SparkContext.scala | 1 + .../scala/spark/scheduler/DAGScheduler.scala | 29 +++++++++++-------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index dc9b8688b3..6ae04f4a44 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -187,6 +187,7 @@ class SparkContext( taskScheduler.start() private var dagScheduler = new DAGScheduler(taskScheduler) + dagScheduler.start() /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration = { diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index b130be6a38..9655961162 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -23,7 +23,14 @@ import util.{MetadataCleaner, TimeStampedHashMap} * and to report fetch failures (the submitTasks method, and code to add CompletionEvents). */ private[spark] -class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging { +class DAGScheduler(taskSched: TaskScheduler, + mapOutputTracker: MapOutputTracker, + blockManagerMaster: BlockManagerMaster, + env: SparkEnv) + extends TaskSchedulerListener with Logging { + def this(taskSched: TaskScheduler) { + this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get) + } taskSched.setListener(this) // Called by TaskScheduler to report task completions or failures. @@ -66,10 +73,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with var cacheLocs = new HashMap[Int, Array[List[String]]] - val env = SparkEnv.get - val mapOutputTracker = env.mapOutputTracker - val blockManagerMaster = env.blockManager.master - // For tracking failed nodes, we use the MapOutputTracker's generation number, which is // sent with every task. 
When we detect a node failing, we note the current generation number // and failed executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask @@ -90,12 +93,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) // Start a thread to run the DAGScheduler event loop - new Thread("DAGScheduler") { - setDaemon(true) - override def run() { - DAGScheduler.this.run() - } - }.start() + def start() { + new Thread("DAGScheduler") { + setDaemon(true) + override def run() { + DAGScheduler.this.run() + } + }.start() + } def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { if (!cacheLocs.contains(rdd.id)) { @@ -546,7 +551,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with if (!failedGeneration.contains(execId) || failedGeneration(execId) < currentGeneration) { failedGeneration(execId) = currentGeneration logInfo("Executor lost: %s (generation %d)".format(execId, currentGeneration)) - env.blockManager.master.removeExecutor(execId) + blockManagerMaster.removeExecutor(execId) // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { stage.removeOutputsOnExecutor(execId) From 9eac7d01f0880d1d3d51e922ef2566c4ee92989f Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Mon, 28 Jan 2013 22:42:35 -0800 Subject: [PATCH 232/291] Add DAGScheduler tests. --- .../spark/scheduler/DAGSchedulerSuite.scala | 540 ++++++++++++++++++ 1 file changed, 540 insertions(+) create mode 100644 core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala new file mode 100644 index 0000000000..53f5214d7a --- /dev/null +++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala @@ -0,0 +1,540 @@ +package spark.scheduler + +import scala.collection.mutable.{Map, HashMap} + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.AsyncAssertions +import org.scalatest.concurrent.TimeLimitedTests +import org.scalatest.mock.EasyMockSugar +import org.scalatest.time.{Span, Seconds} + +import org.easymock.EasyMock._ +import org.easymock.EasyMock +import org.easymock.{IAnswer, IArgumentMatcher} + +import akka.actor.ActorSystem + +import spark.storage.BlockManager +import spark.storage.BlockManagerId +import spark.storage.BlockManagerMaster +import spark.{Dependency, ShuffleDependency, OneToOneDependency} +import spark.FetchFailedException +import spark.MapOutputTracker +import spark.RDD +import spark.SparkContext +import spark.SparkException +import spark.Split +import spark.TaskContext +import spark.TaskEndReason + +import spark.{FetchFailed, Success} + +class DAGSchedulerSuite extends FunSuite + with BeforeAndAfter with EasyMockSugar with TimeLimitedTests + with AsyncAssertions with spark.Logging { + + // If we crash the DAGScheduler thread, our test will probably hang. 
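// Sketch, not part of the patch: the wiring that the constructor change above enables,
// distilled from the suite that follows. A test hands the DAGScheduler its collaborators
// directly (EasyMock mocks plus a real MapOutputTracker), production code keeps the old
// one-argument constructor that reads them from SparkEnv.get, and SparkContext now starts the
// event loop explicitly via start().
//
//   val taskSched = mock[TaskScheduler]
//   val bmMaster  = mock[BlockManagerMaster]
//   val tracker   = new MapOutputTracker(actorSystem, true)
//   val scheduler = new DAGScheduler(taskSched, tracker, bmMaster, null)  // null SparkEnv, as below
//   // production path: new DAGScheduler(taskSched) still pulls everything from SparkEnv.get
//   scheduler.start()   // the event-loop thread is no longer started by the constructor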
+ override val timeLimit = Span(5, Seconds) + + val sc: SparkContext = new SparkContext("local", "DAGSchedulerSuite") + var scheduler: DAGScheduler = null + var w: Waiter = null + val taskScheduler = mock[TaskScheduler] + val blockManagerMaster = mock[BlockManagerMaster] + var mapOutputTracker: MapOutputTracker = null + var schedulerThread: Thread = null + var schedulerException: Throwable = null + val taskSetMatchers = new HashMap[MyRDD, IArgumentMatcher] + val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]] + + implicit val mocks = MockObjects(taskScheduler, blockManagerMaster) + + def makeBlockManagerId(host: String): BlockManagerId = + BlockManagerId("exec-" + host, host, 12345) + + def resetExpecting(f: => Unit) { + reset(taskScheduler) + reset(blockManagerMaster) + expecting(f) + } + + before { + taskSetMatchers.clear() + cacheLocations.clear() + val actorSystem = ActorSystem("test") + mapOutputTracker = new MapOutputTracker(actorSystem, true) + resetExpecting { + taskScheduler.setListener(anyObject()) + } + whenExecuting { + scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) + } + w = new Waiter + schedulerException = null + schedulerThread = new Thread("DAGScheduler under test") { + override def run() { + try { + scheduler.run() + } catch { + case t: Throwable => + logError("Got exception in DAGScheduler: ", t) + schedulerException = t + } finally { + w.dismiss() + } + } + } + schedulerThread.start + logInfo("finished before") + } + + after { + logInfo("started after") + resetExpecting { + taskScheduler.stop() + } + whenExecuting { + scheduler.stop + schedulerThread.join + } + w.await() + if (schedulerException != null) { + throw new Exception("Exception caught from scheduler thread", schedulerException) + } + } + + // Type of RDD we use for testing. Note that we should never call the real RDD compute methods. + // This is a pair RDD type so it can always be used in ShuffleDependencies. + type MyRDD = RDD[(Int, Int)] + + def makeRdd( + numSplits: Int, + dependencies: List[Dependency[_]], + locations: Seq[Seq[String]] = Nil + ): MyRDD = { + val maxSplit = numSplits - 1 + return new MyRDD(sc, dependencies) { + override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] = + throw new RuntimeException("should not be reached") + override def getSplits() = (0 to maxSplit).map(i => new Split { + override def index = i + }).toArray + override def getPreferredLocations(split: Split): Seq[String] = + if (locations.isDefinedAt(split.index)) + locations(split.index) + else + Nil + override def toString: String = "DAGSchedulerSuiteRDD " + id + } + } + + def taskSetForRdd(rdd: MyRDD): TaskSet = { + val matcher = taskSetMatchers.getOrElseUpdate(rdd, + new IArgumentMatcher { + override def matches(actual: Any): Boolean = { + val taskSet = actual.asInstanceOf[TaskSet] + taskSet.tasks(0) match { + case rt: ResultTask[_, _] => rt.rdd.id == rdd.id + case smt: ShuffleMapTask => smt.rdd.id == rdd.id + case _ => false + } + } + override def appendTo(buf: StringBuffer) { + buf.append("taskSetForRdd(" + rdd + ")") + } + }) + EasyMock.reportMatcher(matcher) + return null + } + + def expectGetLocations(): Unit = { + EasyMock.expect(blockManagerMaster.getLocations(anyObject().asInstanceOf[Array[String]])). 
+ andAnswer(new IAnswer[Seq[Seq[BlockManagerId]]] { + override def answer(): Seq[Seq[BlockManagerId]] = { + val blocks = getCurrentArguments()(0).asInstanceOf[Array[String]] + return blocks.map { name => + val pieces = name.split("_") + if (pieces(0) == "rdd") { + val key = pieces(1).toInt -> pieces(2).toInt + if (cacheLocations.contains(key)) { + cacheLocations(key) + } else { + Seq[BlockManagerId]() + } + } else { + Seq[BlockManagerId]() + } + }.toSeq + } + }).anyTimes() + } + + def expectStageAnd(rdd: MyRDD, results: Seq[(TaskEndReason, Any)], + preferredLocations: Option[Seq[Seq[String]]] = None)(afterSubmit: TaskSet => Unit) { + // TODO: Remember which submission + EasyMock.expect(taskScheduler.submitTasks(taskSetForRdd(rdd))).andAnswer(new IAnswer[Unit] { + override def answer(): Unit = { + val taskSet = getCurrentArguments()(0).asInstanceOf[TaskSet] + for (task <- taskSet.tasks) { + task.generation = mapOutputTracker.getGeneration + } + afterSubmit(taskSet) + preferredLocations match { + case None => + for (taskLocs <- taskSet.tasks.map(_.preferredLocations)) { + w { assert(taskLocs.size === 0) } + } + case Some(locations) => + w { assert(locations.size === taskSet.tasks.size) } + for ((expectLocs, taskLocs) <- + taskSet.tasks.map(_.preferredLocations).zip(locations)) { + w { assert(expectLocs === taskLocs) } + } + } + w { assert(taskSet.tasks.size >= results.size)} + for ((result, i) <- results.zipWithIndex) { + if (i < taskSet.tasks.size) { + scheduler.taskEnded(taskSet.tasks(i), result._1, result._2, Map[Long, Any]()) + } + } + } + }) + } + + def expectStage(rdd: MyRDD, results: Seq[(TaskEndReason, Any)], + preferredLocations: Option[Seq[Seq[String]]] = None) { + expectStageAnd(rdd, results, preferredLocations) { _ => } + } + + def submitRdd(rdd: MyRDD, allowLocal: Boolean = false): Array[Int] = { + return scheduler.runJob[(Int, Int), Int]( + rdd, + (context: TaskContext, it: Iterator[(Int, Int)]) => it.next._1.asInstanceOf[Int], + (0 to (rdd.splits.size - 1)), + "test-site", + allowLocal + ) + } + + def makeMapStatus(host: String, reduces: Int): MapStatus = + new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2)) + + test("zero split job") { + val rdd = makeRdd(0, Nil) + resetExpecting { + expectGetLocations() + // deliberately expect no stages to be submitted + } + whenExecuting { + assert(submitRdd(rdd) === Array[Int]()) + } + } + + test("run trivial job") { + val rdd = makeRdd(1, Nil) + resetExpecting { + expectGetLocations() + expectStage(rdd, List( (Success, 42) )) + } + whenExecuting { + assert(submitRdd(rdd) === Array(42)) + } + } + + test("local job") { + val rdd = new MyRDD(sc, Nil) { + override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] = + Array(42 -> 0).iterator + override def getSplits() = Array( new Split { override def index = 0 } ) + override def getPreferredLocations(split: Split) = Nil + override def toString = "DAGSchedulerSuite Local RDD" + } + resetExpecting { + expectGetLocations() + // deliberately expect no stages to be submitted + } + whenExecuting { + assert(submitRdd(rdd, true) === Array(42)) + } + } + + test("run trivial job w/ dependency") { + val baseRdd = makeRdd(1, Nil) + val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd))) + resetExpecting { + expectGetLocations() + expectStage(finalRdd, List( (Success, 42) )) + } + whenExecuting { + assert(submitRdd(finalRdd) === Array(42)) + } + } + + test("location preferences w/ dependency") { + val baseRdd = makeRdd(1, Nil) + val finalRdd = makeRdd(1, 
List(new OneToOneDependency(baseRdd))) + resetExpecting { + expectGetLocations() + cacheLocations(baseRdd.id -> 0) = + Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")) + expectStage(finalRdd, List( (Success, 42) ), + Some(List(Seq("hostA", "hostB")))) + } + whenExecuting { + assert(submitRdd(finalRdd) === Array(42)) + } + } + + test("trivial job failure") { + val rdd = makeRdd(1, Nil) + resetExpecting { + expectGetLocations() + expectStageAnd(rdd, List()) { taskSet => scheduler.taskSetFailed(taskSet, "test failure") } + } + whenExecuting(taskScheduler, blockManagerMaster) { + intercept[SparkException] { submitRdd(rdd) } + } + } + + test("run trivial shuffle") { + val shuffleMapRdd = makeRdd(2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = makeRdd(1, List(shuffleDep)) + + resetExpecting { + expectGetLocations() + expectStage(shuffleMapRdd, List( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)) + )) + expectStageAnd(reduceRdd, List( (Success, 42) )) { _ => + w { assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) } + } + } + whenExecuting { + assert(submitRdd(reduceRdd) === Array(42)) + } + } + + test("run trivial shuffle with fetch failure") { + val shuffleMapRdd = makeRdd(2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = makeRdd(2, List(shuffleDep)) + + resetExpecting { + expectGetLocations() + expectStage(shuffleMapRdd, List( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)) + )) + blockManagerMaster.removeExecutor("exec-hostA") + expectStage(reduceRdd, List( + (Success, 42), + (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null) + )) + // partial recompute + expectStage(shuffleMapRdd, List( (Success, makeMapStatus("hostA", 1)) )) + expectStageAnd(reduceRdd, List( (Success, 43) )) { _ => + w { assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostA"), + makeBlockManagerId("hostB"))) } + } + } + whenExecuting { + assert(submitRdd(reduceRdd) === Array(42, 43)) + } + } + + test("ignore late map task completions") { + val shuffleMapRdd = makeRdd(2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = makeRdd(2, List(shuffleDep)) + + resetExpecting { + expectGetLocations() + expectStageAnd(shuffleMapRdd, List( + (Success, makeMapStatus("hostA", 1)) + )) { taskSet => + val newGeneration = mapOutputTracker.getGeneration + 1 + scheduler.executorLost("exec-hostA") + val noAccum = Map[Long, Any]() + // We rely on the event queue being ordered and increasing the generation number by 1 + // should be ignored for being too old + scheduler.taskEnded(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum) + // should work because it's a non-failed host + scheduler.taskEnded(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum) + // should be ignored for being too old + scheduler.taskEnded(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum) + // should be ignored (not end the stage) because it's too old + scheduler.taskEnded(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum) + taskSet.tasks(1).generation = newGeneration + scheduler.taskEnded(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum) + } + 
blockManagerMaster.removeExecutor("exec-hostA") + expectStageAnd(reduceRdd, List( + (Success, 42), (Success, 43) + )) { _ => + w { assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) } + } + } + whenExecuting { + assert(submitRdd(reduceRdd) === Array(42, 43)) + } + } + + test("run trivial shuffle with out-of-band failure") { + val shuffleMapRdd = makeRdd(2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = makeRdd(1, List(shuffleDep)) + resetExpecting { + expectGetLocations() + blockManagerMaster.removeExecutor("exec-hostA") + expectStageAnd(shuffleMapRdd, List( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)) + )) { _ => scheduler.executorLost("exec-hostA") } + expectStage(shuffleMapRdd, List( + (Success, makeMapStatus("hostC", 1)) + )) + expectStageAnd(reduceRdd, List( (Success, 42) )) { _ => + w { assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostC"), + makeBlockManagerId("hostB"))) } + } + } + whenExecuting { + assert(submitRdd(reduceRdd) === Array(42)) + } + } + + test("recursive shuffle failures") { + val shuffleOneRdd = makeRdd(2, Nil) + val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) + val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne)) + val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) + val finalRdd = makeRdd(1, List(shuffleDepTwo)) + + resetExpecting { + expectGetLocations() + expectStage(shuffleOneRdd, List( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)) + )) + expectStage(shuffleTwoRdd, List( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostC", 1)) + )) + blockManagerMaster.removeExecutor("exec-hostA") + expectStage(finalRdd, List( + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null) + )) + // triggers a partial recompute of the first stage, then the second + expectStage(shuffleOneRdd, List( + (Success, makeMapStatus("hostA", 1)) + )) + expectStage(shuffleTwoRdd, List( + (Success, makeMapStatus("hostA", 1)) + )) + expectStage(finalRdd, List( + (Success, 42) + )) + } + whenExecuting { + assert(submitRdd(finalRdd) === Array(42)) + } + } + + test("cached post-shuffle") { + val shuffleOneRdd = makeRdd(2, Nil) + val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) + val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne)) + val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) + val finalRdd = makeRdd(1, List(shuffleDepTwo)) + + resetExpecting { + expectGetLocations() + expectStage(shuffleOneRdd, List( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)) + )) + expectStageAnd(shuffleTwoRdd, List( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostC", 1)) + )){ _ => + cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) + cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) + } + blockManagerMaster.removeExecutor("exec-hostA") + expectStage(finalRdd, List( + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null) + )) + // since we have a cached copy of the missing split of shuffleTwoRdd, we shouldn't + // immediately try to rerun shuffleOneRdd: + expectStage(shuffleTwoRdd, List( + (Success, makeMapStatus("hostD", 1)) + ), Some(Seq(List("hostD")))) + expectStage(finalRdd, List( + (Success, 42) + )) + } + 
whenExecuting { + assert(submitRdd(finalRdd) === Array(42)) + } + } + + test("cached post-shuffle but fails") { + val shuffleOneRdd = makeRdd(2, Nil) + val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) + val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne)) + val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) + val finalRdd = makeRdd(1, List(shuffleDepTwo)) + + resetExpecting { + expectGetLocations() + expectStage(shuffleOneRdd, List( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)) + )) + expectStageAnd(shuffleTwoRdd, List( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostC", 1)) + )){ _ => + cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) + cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) + } + blockManagerMaster.removeExecutor("exec-hostA") + expectStage(finalRdd, List( + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null) + )) + // since we have a cached copy of the missing split of shuffleTwoRdd, we shouldn't + // immediately try to rerun shuffleOneRdd: + expectStageAnd(shuffleTwoRdd, List( + (FetchFailed(null, shuffleDepOne.shuffleId, 0, 0), null) + ), Some(Seq(List("hostD")))) { _ => + w { + intercept[FetchFailedException]{ + mapOutputTracker.getServerStatuses(shuffleDepOne.shuffleId, 0) + } + } + cacheLocations.remove(shuffleTwoRdd.id -> 0) + } + // after that fetch failure, we should refetch the cache locations and try to recompute + // the whole chain. Note that we will ignore that a fetch failure previously occured on + // this host. + expectStage(shuffleOneRdd, List( (Success, makeMapStatus("hostA", 1)) )) + expectStage(shuffleTwoRdd, List( (Success, makeMapStatus("hostA", 1)) )) + expectStage(finalRdd, List( (Success, 42) )) + } + whenExecuting { + assert(submitRdd(finalRdd) === Array(42)) + } + } +} + From 4bf3d7ea1252454ca584a3dabf26bdeab4069409 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 29 Jan 2013 19:05:45 -0800 Subject: [PATCH 233/291] Clear spark.master.port to cleanup for other tests --- core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala index 53f5214d7a..6c577c2685 100644 --- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala @@ -102,6 +102,7 @@ class DAGSchedulerSuite extends FunSuite if (schedulerException != null) { throw new Exception("Exception caught from scheduler thread", schedulerException) } + System.clearProperty("spark.master.port") } // Type of RDD we use for testing. Note that we should never call the real RDD compute methods. From 178b89204c9dbee36886e757ddaafbd079672f4a Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 30 Jan 2013 09:19:55 -0800 Subject: [PATCH 234/291] Refactor DAGScheduler more to allow testing without a separate thread. 
--- .../scala/spark/scheduler/DAGScheduler.scala | 176 +++++++++++------- 1 file changed, 111 insertions(+), 65 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 9655961162..6892509ed1 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -23,11 +23,13 @@ import util.{MetadataCleaner, TimeStampedHashMap} * and to report fetch failures (the submitTasks method, and code to add CompletionEvents). */ private[spark] -class DAGScheduler(taskSched: TaskScheduler, - mapOutputTracker: MapOutputTracker, - blockManagerMaster: BlockManagerMaster, - env: SparkEnv) - extends TaskSchedulerListener with Logging { +class DAGScheduler( + taskSched: TaskScheduler, + mapOutputTracker: MapOutputTracker, + blockManagerMaster: BlockManagerMaster, + env: SparkEnv) + extends TaskSchedulerListener with Logging { + def this(taskSched: TaskScheduler) { this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get) } @@ -203,6 +205,27 @@ class DAGScheduler(taskSched: TaskScheduler, missing.toList } + /** Returns (and does not) submit a JobSubmitted event suitable to run a given job, and + * a JobWaiter whose getResult() method will return the result of the job when it is complete. + * + * The job is assumed to have at least one partition; zero partition jobs should be handled + * without a JobSubmitted event. + */ + private[scheduler] def prepareJob[T, U: ClassManifest]( + finalRdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + callSite: String, + allowLocal: Boolean) + : (JobSubmitted, JobWaiter) = + { + assert(partitions.size > 0) + val waiter = new JobWaiter(partitions.size) + val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] + val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter) + return (toSubmit, waiter) + } + def runJob[T, U: ClassManifest]( finalRdd: RDD[T], func: (TaskContext, Iterator[T]) => U, @@ -214,9 +237,8 @@ class DAGScheduler(taskSched: TaskScheduler, if (partitions.size == 0) { return new Array[U](0) } - val waiter = new JobWaiter(partitions.size) - val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter)) + val (toSubmit, waiter) = prepareJob(finalRdd, func, partitions, callSite, allowLocal) + eventQueue.put(toSubmit) waiter.getResult() match { case JobSucceeded(results: Seq[_]) => return results.asInstanceOf[Seq[U]].toArray @@ -241,6 +263,81 @@ class DAGScheduler(taskSched: TaskScheduler, return listener.getResult() // Will throw an exception if the job fails } + /** Process one event retrieved from the event queue. + * Returns true if we should stop the event loop. 
+ */ + private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = { + event match { + case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) => + val runId = nextRunId.getAndIncrement() + val finalStage = newStage(finalRDD, None, runId) + val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener) + clearCacheLocs() + logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length + + " output partitions (allowLocal=" + allowLocal + ")") + logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")") + logInfo("Parents of final stage: " + finalStage.parents) + logInfo("Missing parents: " + getMissingParentStages(finalStage)) + if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) { + // Compute very short actions like first() or take() with no parent stages locally. + runLocally(job) + } else { + activeJobs += job + resultStageToJob(finalStage) = job + submitStage(finalStage) + } + + case ExecutorLost(execId) => + handleExecutorLost(execId) + + case completion: CompletionEvent => + handleTaskCompletion(completion) + + case TaskSetFailed(taskSet, reason) => + abortStage(idToStage(taskSet.stageId), reason) + + case StopDAGScheduler => + // Cancel any active jobs + for (job <- activeJobs) { + val error = new SparkException("Job cancelled because SparkContext was shut down") + job.listener.jobFailed(error) + } + return true + } + return false + } + + /** Resubmit any failed stages. Ordinarily called after a small amount of time has passed since + * the last fetch failure. + */ + private[scheduler] def resubmitFailedStages() { + logInfo("Resubmitting failed stages") + clearCacheLocs() + val failed2 = failed.toArray + failed.clear() + for (stage <- failed2.sortBy(_.priority)) { + submitStage(stage) + } + } + + /** Check for waiting or failed stages which are now eligible for resubmission. + * Ordinarily run on every iteration of the event loop. + */ + private[scheduler] def submitWaitingStages() { + // TODO: We might want to run this less often, when we are sure that something has become + // runnable that wasn't before. + logTrace("Checking for newly runnable parent stages") + logTrace("running: " + running) + logTrace("waiting: " + waiting) + logTrace("failed: " + failed) + val waiting2 = waiting.toArray + waiting.clear() + for (stage <- waiting2.sortBy(_.priority)) { + submitStage(stage) + } + } + + /** * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure * events and responds by launching tasks. 
This runs in a dedicated thread and receives events @@ -251,77 +348,26 @@ class DAGScheduler(taskSched: TaskScheduler, while (true) { val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS) - val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability if (event != null) { logDebug("Got event of type " + event.getClass.getName) } - event match { - case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) => - val runId = nextRunId.getAndIncrement() - val finalStage = newStage(finalRDD, None, runId) - val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener) - clearCacheLocs() - logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length + - " output partitions") - logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")") - logInfo("Parents of final stage: " + finalStage.parents) - logInfo("Missing parents: " + getMissingParentStages(finalStage)) - if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) { - // Compute very short actions like first() or take() with no parent stages locally. - runLocally(job) - } else { - activeJobs += job - resultStageToJob(finalStage) = job - submitStage(finalStage) - } - - case ExecutorLost(execId) => - handleExecutorLost(execId) - - case completion: CompletionEvent => - handleTaskCompletion(completion) - - case TaskSetFailed(taskSet, reason) => - abortStage(idToStage(taskSet.stageId), reason) - - case StopDAGScheduler => - // Cancel any active jobs - for (job <- activeJobs) { - val error = new SparkException("Job cancelled because SparkContext was shut down") - job.listener.jobFailed(error) - } + if (event != null) { + if (processEvent(event)) { return - - case null => - // queue.poll() timed out, ignore it + } } + val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability // Periodically resubmit failed stages if some map output fetches have failed and we have // waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails, // tasks on many other nodes are bound to get a fetch failure, and they won't all get it at // the same time, so we want to make sure we've identified all the reduce tasks that depend // on the failed node. if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) { - logInfo("Resubmitting failed stages") - clearCacheLocs() - val failed2 = failed.toArray - failed.clear() - for (stage <- failed2.sortBy(_.priority)) { - submitStage(stage) - } + resubmitFailedStages } else { - // TODO: We might want to run this less often, when we are sure that something has become - // runnable that wasn't before. - logTrace("Checking for newly runnable parent stages") - logTrace("running: " + running) - logTrace("waiting: " + waiting) - logTrace("failed: " + failed) - val waiting2 = waiting.toArray - waiting.clear() - for (stage <- waiting2.sortBy(_.priority)) { - submitStage(stage) - } + submitWaitingStages } } } From 9c0bae75ade9e5b5a69077a5719adf4ee96e2c2e Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 30 Jan 2013 09:22:07 -0800 Subject: [PATCH 235/291] Change DAGSchedulerSuite to run DAGScheduler in the same Thread. 
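With processEvent, resubmitFailedStages, and submitWaitingStages now visible to the scheduler package, a test can feed events to a DAGScheduler directly instead of racing a background event-loop thread. A minimal sketch of that pattern, assuming it compiles inside package spark.scheduler (the object name is illustrative; the suite below wraps the same calls in its runEvent helper):

    package spark.scheduler

    // Drive the scheduler one event at a time, with no background thread.
    object SynchronousDriverSketch {
      def runEvent(scheduler: DAGScheduler, event: DAGSchedulerEvent) {
        // processEvent returns true only when it sees StopDAGScheduler
        assert(!scheduler.processEvent(event))
        // do the follow-up work the real event loop performs on each iteration
        scheduler.submitWaitingStages()
      }
    }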
--- .../spark/scheduler/DAGSchedulerSuite.scala | 582 ++++++++++-------- 1 file changed, 326 insertions(+), 256 deletions(-) diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala index 6c577c2685..89173540d4 100644 --- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala @@ -4,12 +4,12 @@ import scala.collection.mutable.{Map, HashMap} import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter -import org.scalatest.concurrent.AsyncAssertions import org.scalatest.concurrent.TimeLimitedTests import org.scalatest.mock.EasyMockSugar import org.scalatest.time.{Span, Seconds} import org.easymock.EasyMock._ +import org.easymock.Capture import org.easymock.EasyMock import org.easymock.{IAnswer, IArgumentMatcher} @@ -30,33 +30,55 @@ import spark.TaskEndReason import spark.{FetchFailed, Success} -class DAGSchedulerSuite extends FunSuite - with BeforeAndAfter with EasyMockSugar with TimeLimitedTests - with AsyncAssertions with spark.Logging { +class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar with TimeLimitedTests { - // If we crash the DAGScheduler thread, our test will probably hang. + // impose a time limit on this test in case we don't let the job finish. override val timeLimit = Span(5, Seconds) val sc: SparkContext = new SparkContext("local", "DAGSchedulerSuite") var scheduler: DAGScheduler = null - var w: Waiter = null val taskScheduler = mock[TaskScheduler] val blockManagerMaster = mock[BlockManagerMaster] var mapOutputTracker: MapOutputTracker = null var schedulerThread: Thread = null var schedulerException: Throwable = null + + /** Set of EasyMock argument matchers that match a TaskSet for a given RDD. + * We cache these so we do not create duplicate matchers for the same RDD. + * This allows us to easily setup a sequence of expectations for task sets for + * that RDD. + */ val taskSetMatchers = new HashMap[MyRDD, IArgumentMatcher] + + /** Set of cache locations to return from our mock BlockManagerMaster. + * Keys are (rdd ID, partition ID). Anything not present will return an empty + * list of cache locations silently. + */ val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]] + /** JobWaiter for the last JobSubmitted event we pushed. To keep tests (most of which + * will only submit one job) from needing to explicitly track it. + */ + var lastJobWaiter: JobWaiter = null + + /** Tell EasyMockSugar what mock objects we want to be configured by expecting {...} + * and whenExecuting {...} */ implicit val mocks = MockObjects(taskScheduler, blockManagerMaster) - def makeBlockManagerId(host: String): BlockManagerId = - BlockManagerId("exec-" + host, host, 12345) - + /** Utility function to reset mocks and set expectations on them. EasyMock wants mock objects + * to be reset after each time their expectations are set, and we tend to check mock object + * calls over a single call to DAGScheduler. + * + * We also set a default expectation here that blockManagerMaster.getLocations can be called + * and will return values from cacheLocations. 
+ */ def resetExpecting(f: => Unit) { reset(taskScheduler) reset(blockManagerMaster) - expecting(f) + expecting { + expectGetLocations() + f + } } before { @@ -70,45 +92,30 @@ class DAGSchedulerSuite extends FunSuite whenExecuting { scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) } - w = new Waiter - schedulerException = null - schedulerThread = new Thread("DAGScheduler under test") { - override def run() { - try { - scheduler.run() - } catch { - case t: Throwable => - logError("Got exception in DAGScheduler: ", t) - schedulerException = t - } finally { - w.dismiss() - } - } - } - schedulerThread.start - logInfo("finished before") } after { - logInfo("started after") + assert(scheduler.processEvent(StopDAGScheduler)) resetExpecting { taskScheduler.stop() } whenExecuting { - scheduler.stop - schedulerThread.join - } - w.await() - if (schedulerException != null) { - throw new Exception("Exception caught from scheduler thread", schedulerException) + scheduler.stop() } System.clearProperty("spark.master.port") } - // Type of RDD we use for testing. Note that we should never call the real RDD compute methods. - // This is a pair RDD type so it can always be used in ShuffleDependencies. + def makeBlockManagerId(host: String): BlockManagerId = + BlockManagerId("exec-" + host, host, 12345) + + /** Type of RDD we use for testing. Note that we should never call the real RDD compute methods. + * This is a pair RDD type so it can always be used in ShuffleDependencies. */ type MyRDD = RDD[(Int, Int)] + /** Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and + * preferredLocations (if any) that are passed to them. They are deliberately not executable + * so we can test that DAGScheduler does not try to execute RDDs locally. + */ def makeRdd( numSplits: Int, dependencies: List[Dependency[_]], @@ -130,6 +137,9 @@ class DAGSchedulerSuite extends FunSuite } } + /** EasyMock matcher method. For use as an argument matcher for a TaskSet whose first task + * is from a particular RDD. + */ def taskSetForRdd(rdd: MyRDD): TaskSet = { val matcher = taskSetMatchers.getOrElseUpdate(rdd, new IArgumentMatcher { @@ -149,6 +159,9 @@ class DAGSchedulerSuite extends FunSuite return null } + /** Setup an EasyMock expectation to repsond to blockManagerMaster.getLocations() called from + * cacheLocations. + */ def expectGetLocations(): Unit = { EasyMock.expect(blockManagerMaster.getLocations(anyObject().asInstanceOf[Array[String]])). 
andAnswer(new IAnswer[Seq[Seq[BlockManagerId]]] { @@ -171,51 +184,106 @@ class DAGSchedulerSuite extends FunSuite }).anyTimes() } - def expectStageAnd(rdd: MyRDD, results: Seq[(TaskEndReason, Any)], - preferredLocations: Option[Seq[Seq[String]]] = None)(afterSubmit: TaskSet => Unit) { - // TODO: Remember which submission - EasyMock.expect(taskScheduler.submitTasks(taskSetForRdd(rdd))).andAnswer(new IAnswer[Unit] { - override def answer(): Unit = { - val taskSet = getCurrentArguments()(0).asInstanceOf[TaskSet] - for (task <- taskSet.tasks) { - task.generation = mapOutputTracker.getGeneration - } - afterSubmit(taskSet) - preferredLocations match { - case None => - for (taskLocs <- taskSet.tasks.map(_.preferredLocations)) { - w { assert(taskLocs.size === 0) } - } - case Some(locations) => - w { assert(locations.size === taskSet.tasks.size) } - for ((expectLocs, taskLocs) <- - taskSet.tasks.map(_.preferredLocations).zip(locations)) { - w { assert(expectLocs === taskLocs) } - } - } - w { assert(taskSet.tasks.size >= results.size)} - for ((result, i) <- results.zipWithIndex) { - if (i < taskSet.tasks.size) { - scheduler.taskEnded(taskSet.tasks(i), result._1, result._2, Map[Long, Any]()) - } - } + /** Process the supplied event as if it were the top of the DAGScheduler event queue, expecting + * the scheduler not to exit. + * + * After processing the event, submit waiting stages as is done on most iterations of the + * DAGScheduler event loop. + */ + def runEvent(event: DAGSchedulerEvent) { + assert(!scheduler.processEvent(event)) + scheduler.submitWaitingStages() + } + + /** Expect a TaskSet for the specified RDD to be submitted to the TaskScheduler. Should be + * called from a resetExpecting { ... } block. + * + * Returns a easymock Capture that will contain the task set after the stage is submitted. + * Most tests should use interceptStage() instead of this directly. + */ + def expectStage(rdd: MyRDD): Capture[TaskSet] = { + val taskSetCapture = new Capture[TaskSet] + taskScheduler.submitTasks(and(capture(taskSetCapture), taskSetForRdd(rdd))) + return taskSetCapture + } + + /** Expect the supplied code snippet to submit a stage for the specified RDD. + * Return the resulting TaskSet. First marks all the tasks are belonging to the + * current MapOutputTracker generation. + */ + def interceptStage(rdd: MyRDD)(f: => Unit): TaskSet = { + var capture: Capture[TaskSet] = null + resetExpecting { + capture = expectStage(rdd) + } + whenExecuting { + f + } + val taskSet = capture.getValue + for (task <- taskSet.tasks) { + task.generation = mapOutputTracker.getGeneration + } + return taskSet + } + + /** Send the given CompletionEvent messages for the tasks in the TaskSet. */ + def respondToTaskSet(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) { + assert(taskSet.tasks.size >= results.size) + for ((result, i) <- results.zipWithIndex) { + if (i < taskSet.tasks.size) { + runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any]())) } - }) + } } - def expectStage(rdd: MyRDD, results: Seq[(TaskEndReason, Any)], - preferredLocations: Option[Seq[Seq[String]]] = None) { - expectStageAnd(rdd, results, preferredLocations) { _ => } + /** Assert that the supplied TaskSet has exactly the given preferredLocations. 
*/ + def expectTaskSetLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) { + assert(locations.size === taskSet.tasks.size) + for ((expectLocs, taskLocs) <- + taskSet.tasks.map(_.preferredLocations).zip(locations)) { + assert(expectLocs === taskLocs) + } } - def submitRdd(rdd: MyRDD, allowLocal: Boolean = false): Array[Int] = { - return scheduler.runJob[(Int, Int), Int]( + /** When we submit dummy Jobs, this is the compute function we supply. Except in a local test + * below, we do not expect this function to ever be executed; instead, we will return results + * directly through CompletionEvents. + */ + def jobComputeFunc(context: TaskContext, it: Iterator[(Int, Int)]): Int = + it.next._1.asInstanceOf[Int] + + + /** Start a job to compute the given RDD. Returns the JobWaiter that will + * collect the result of the job via callbacks from DAGScheduler. */ + def submitRdd(rdd: MyRDD, allowLocal: Boolean = false): JobWaiter = { + val (toSubmit, waiter) = scheduler.prepareJob[(Int, Int), Int]( rdd, - (context: TaskContext, it: Iterator[(Int, Int)]) => it.next._1.asInstanceOf[Int], + jobComputeFunc, (0 to (rdd.splits.size - 1)), "test-site", allowLocal ) + lastJobWaiter = waiter + runEvent(toSubmit) + return waiter + } + + /** Assert that a job we started has failed. */ + def expectJobException(waiter: JobWaiter = lastJobWaiter) { + waiter.getResult match { + case JobSucceeded(_) => fail() + case JobFailed(_) => return + } + } + + /** Assert that a job we started has succeeded and has the given result. */ + def expectJobResult(expected: Array[Int], waiter: JobWaiter = lastJobWaiter) { + waiter.getResult match { + case JobSucceeded(answer) => + assert(expected === answer.asInstanceOf[Seq[Int]].toArray ) + case JobFailed(_) => + fail() + } } def makeMapStatus(host: String, reduces: Int): MapStatus = @@ -223,24 +291,14 @@ class DAGSchedulerSuite extends FunSuite test("zero split job") { val rdd = makeRdd(0, Nil) - resetExpecting { - expectGetLocations() - // deliberately expect no stages to be submitted - } - whenExecuting { - assert(submitRdd(rdd) === Array[Int]()) - } + assert(scheduler.runJob(rdd, jobComputeFunc, Seq(), "test-site", false) === Array[Int]()) } test("run trivial job") { val rdd = makeRdd(1, Nil) - resetExpecting { - expectGetLocations() - expectStage(rdd, List( (Success, 42) )) - } - whenExecuting { - assert(submitRdd(rdd) === Array(42)) - } + val taskSet = interceptStage(rdd) { submitRdd(rdd) } + respondToTaskSet(taskSet, List( (Success, 42) )) + expectJobResult(Array(42)) } test("local job") { @@ -251,51 +309,34 @@ class DAGSchedulerSuite extends FunSuite override def getPreferredLocations(split: Split) = Nil override def toString = "DAGSchedulerSuite Local RDD" } - resetExpecting { - expectGetLocations() - // deliberately expect no stages to be submitted - } - whenExecuting { - assert(submitRdd(rdd, true) === Array(42)) - } + submitRdd(rdd, true) + expectJobResult(Array(42)) } test("run trivial job w/ dependency") { val baseRdd = makeRdd(1, Nil) val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd))) - resetExpecting { - expectGetLocations() - expectStage(finalRdd, List( (Success, 42) )) - } - whenExecuting { - assert(submitRdd(finalRdd) === Array(42)) - } + val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) } + respondToTaskSet(taskSet, List( (Success, 42) )) + expectJobResult(Array(42)) } - test("location preferences w/ dependency") { + test("cache location preferences w/ dependency") { val baseRdd = makeRdd(1, Nil) val finalRdd = makeRdd(1, List(new 
OneToOneDependency(baseRdd))) - resetExpecting { - expectGetLocations() - cacheLocations(baseRdd.id -> 0) = - Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")) - expectStage(finalRdd, List( (Success, 42) ), - Some(List(Seq("hostA", "hostB")))) - } - whenExecuting { - assert(submitRdd(finalRdd) === Array(42)) - } + cacheLocations(baseRdd.id -> 0) = + Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")) + val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) } + expectTaskSetLocations(taskSet, List(Seq("hostA", "hostB"))) + respondToTaskSet(taskSet, List( (Success, 42) )) + expectJobResult(Array(42)) } test("trivial job failure") { val rdd = makeRdd(1, Nil) - resetExpecting { - expectGetLocations() - expectStageAnd(rdd, List()) { taskSet => scheduler.taskSetFailed(taskSet, "test failure") } - } - whenExecuting(taskScheduler, blockManagerMaster) { - intercept[SparkException] { submitRdd(rdd) } - } + val taskSet = interceptStage(rdd) { submitRdd(rdd) } + runEvent(TaskSetFailed(taskSet, "test failure")) + expectJobException() } test("run trivial shuffle") { @@ -304,20 +345,17 @@ class DAGSchedulerSuite extends FunSuite val shuffleId = shuffleDep.shuffleId val reduceRdd = makeRdd(1, List(shuffleDep)) - resetExpecting { - expectGetLocations() - expectStage(shuffleMapRdd, List( + val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) } + val secondStage = interceptStage(reduceRdd) { + respondToTaskSet(firstStage, List( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)) )) - expectStageAnd(reduceRdd, List( (Success, 42) )) { _ => - w { assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) } - } - } - whenExecuting { - assert(submitRdd(reduceRdd) === Array(42)) } + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + respondToTaskSet(secondStage, List( (Success, 42) )) + expectJobResult(Array(42)) } test("run trivial shuffle with fetch failure") { @@ -326,28 +364,32 @@ class DAGSchedulerSuite extends FunSuite val shuffleId = shuffleDep.shuffleId val reduceRdd = makeRdd(2, List(shuffleDep)) - resetExpecting { - expectGetLocations() - expectStage(shuffleMapRdd, List( + val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) } + val secondStage = interceptStage(reduceRdd) { + respondToTaskSet(firstStage, List( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)) )) + } + resetExpecting { blockManagerMaster.removeExecutor("exec-hostA") - expectStage(reduceRdd, List( + } + whenExecuting { + respondToTaskSet(secondStage, List( (Success, 42), (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null) )) - // partial recompute - expectStage(shuffleMapRdd, List( (Success, makeMapStatus("hostA", 1)) )) - expectStageAnd(reduceRdd, List( (Success, 43) )) { _ => - w { assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostA"), - makeBlockManagerId("hostB"))) } - } } - whenExecuting { - assert(submitRdd(reduceRdd) === Array(42, 43)) + val thirdStage = interceptStage(shuffleMapRdd) { + scheduler.resubmitFailedStages() } + val fourthStage = interceptStage(reduceRdd) { + respondToTaskSet(thirdStage, List( (Success, makeMapStatus("hostA", 1)) )) + } + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + 
respondToTaskSet(fourthStage, List( (Success, 43) )) + expectJobResult(Array(42, 43)) } test("ignore late map task completions") { @@ -356,63 +398,64 @@ class DAGSchedulerSuite extends FunSuite val shuffleId = shuffleDep.shuffleId val reduceRdd = makeRdd(2, List(shuffleDep)) + val taskSet = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) } + val oldGeneration = mapOutputTracker.getGeneration resetExpecting { - expectGetLocations() - expectStageAnd(shuffleMapRdd, List( - (Success, makeMapStatus("hostA", 1)) - )) { taskSet => - val newGeneration = mapOutputTracker.getGeneration + 1 - scheduler.executorLost("exec-hostA") - val noAccum = Map[Long, Any]() - // We rely on the event queue being ordered and increasing the generation number by 1 - // should be ignored for being too old - scheduler.taskEnded(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum) - // should work because it's a non-failed host - scheduler.taskEnded(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum) - // should be ignored for being too old - scheduler.taskEnded(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum) - // should be ignored (not end the stage) because it's too old - scheduler.taskEnded(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum) - taskSet.tasks(1).generation = newGeneration - scheduler.taskEnded(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum) - } blockManagerMaster.removeExecutor("exec-hostA") - expectStageAnd(reduceRdd, List( - (Success, 42), (Success, 43) - )) { _ => - w { assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) } - } } whenExecuting { - assert(submitRdd(reduceRdd) === Array(42, 43)) + runEvent(ExecutorLost("exec-hostA")) } + val newGeneration = mapOutputTracker.getGeneration + assert(newGeneration > oldGeneration) + val noAccum = Map[Long, Any]() + // We rely on the event queue being ordered and increasing the generation number by 1 + // should be ignored for being too old + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum)) + // should work because it's a non-failed host + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum)) + // should be ignored for being too old + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum)) + taskSet.tasks(1).generation = newGeneration + val secondStage = interceptStage(reduceRdd) { + runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum)) + } + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) + respondToTaskSet(secondStage, List( (Success, 42), (Success, 43) )) + expectJobResult(Array(42, 43)) } - test("run trivial shuffle with out-of-band failure") { + test("run trivial shuffle with out-of-band failure and retry") { val shuffleMapRdd = makeRdd(2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) val shuffleId = shuffleDep.shuffleId val reduceRdd = makeRdd(1, List(shuffleDep)) + + val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) } resetExpecting { - expectGetLocations() blockManagerMaster.removeExecutor("exec-hostA") - expectStageAnd(shuffleMapRdd, List( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)) - )) { _ => scheduler.executorLost("exec-hostA") } - expectStage(shuffleMapRdd, List( - (Success, 
makeMapStatus("hostC", 1)) - )) - expectStageAnd(reduceRdd, List( (Success, 42) )) { _ => - w { assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostC"), - makeBlockManagerId("hostB"))) } - } } whenExecuting { - assert(submitRdd(reduceRdd) === Array(42)) + runEvent(ExecutorLost("exec-hostA")) } + // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks + // rather than marking it is as failed and waiting. + val secondStage = interceptStage(shuffleMapRdd) { + respondToTaskSet(firstStage, List( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)) + )) + } + val thirdStage = interceptStage(reduceRdd) { + respondToTaskSet(secondStage, List( + (Success, makeMapStatus("hostC", 1)) + )) + } + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + respondToTaskSet(thirdStage, List( (Success, 42) )) + expectJobResult(Array(42)) } test("recursive shuffle failures") { @@ -422,34 +465,42 @@ class DAGSchedulerSuite extends FunSuite val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) val finalRdd = makeRdd(1, List(shuffleDepTwo)) - resetExpecting { - expectGetLocations() - expectStage(shuffleOneRdd, List( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)) + val firstStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) } + val secondStage = interceptStage(shuffleTwoRdd) { + respondToTaskSet(firstStage, List( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)) )) - expectStage(shuffleTwoRdd, List( + } + val thirdStage = interceptStage(finalRdd) { + respondToTaskSet(secondStage, List( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostC", 1)) )) + } + resetExpecting { blockManagerMaster.removeExecutor("exec-hostA") - expectStage(finalRdd, List( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null) - )) - // triggers a partial recompute of the first stage, then the second - expectStage(shuffleOneRdd, List( - (Success, makeMapStatus("hostA", 1)) - )) - expectStage(shuffleTwoRdd, List( - (Success, makeMapStatus("hostA", 1)) - )) - expectStage(finalRdd, List( - (Success, 42) - )) } whenExecuting { - assert(submitRdd(finalRdd) === Array(42)) + respondToTaskSet(thirdStage, List( + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null) + )) } + val recomputeOne = interceptStage(shuffleOneRdd) { + scheduler.resubmitFailedStages + } + val recomputeTwo = interceptStage(shuffleTwoRdd) { + respondToTaskSet(recomputeOne, List( + (Success, makeMapStatus("hostA", 2)) + )) + } + val finalStage = interceptStage(finalRdd) { + respondToTaskSet(recomputeTwo, List( + (Success, makeMapStatus("hostA", 1)) + )) + } + respondToTaskSet(finalStage, List( (Success, 42) )) + expectJobResult(Array(42)) } test("cached post-shuffle") { @@ -459,35 +510,41 @@ class DAGSchedulerSuite extends FunSuite val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) val finalRdd = makeRdd(1, List(shuffleDepTwo)) - resetExpecting { - expectGetLocations() - expectStage(shuffleOneRdd, List( + val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) } + cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) + cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) + val secondShuffleStage = interceptStage(shuffleTwoRdd) { + respondToTaskSet(firstShuffleStage, 
List( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)) + )) + } + val reduceStage = interceptStage(finalRdd) { + respondToTaskSet(secondShuffleStage, List( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)) )) - expectStageAnd(shuffleTwoRdd, List( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostC", 1)) - )){ _ => - cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) - cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) - } + } + resetExpecting { blockManagerMaster.removeExecutor("exec-hostA") - expectStage(finalRdd, List( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null) - )) - // since we have a cached copy of the missing split of shuffleTwoRdd, we shouldn't - // immediately try to rerun shuffleOneRdd: - expectStage(shuffleTwoRdd, List( - (Success, makeMapStatus("hostD", 1)) - ), Some(Seq(List("hostD")))) - expectStage(finalRdd, List( - (Success, 42) - )) } whenExecuting { - assert(submitRdd(finalRdd) === Array(42)) + respondToTaskSet(reduceStage, List( + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null) + )) } + // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun. + val recomputeTwo = interceptStage(shuffleTwoRdd) { + scheduler.resubmitFailedStages() + } + expectTaskSetLocations(recomputeTwo, Seq(Seq("hostD"))) + val finalRetry = interceptStage(finalRdd) { + respondToTaskSet(recomputeTwo, List( + (Success, makeMapStatus("hostD", 1)) + )) + } + respondToTaskSet(finalRetry, List( (Success, 42) )) + expectJobResult(Array(42)) } test("cached post-shuffle but fails") { @@ -497,45 +554,58 @@ class DAGSchedulerSuite extends FunSuite val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null) val finalRdd = makeRdd(1, List(shuffleDepTwo)) - resetExpecting { - expectGetLocations() - expectStage(shuffleOneRdd, List( + val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) } + cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) + cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) + val secondShuffleStage = interceptStage(shuffleTwoRdd) { + respondToTaskSet(firstShuffleStage, List( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)) + )) + } + val reduceStage = interceptStage(finalRdd) { + respondToTaskSet(secondShuffleStage, List( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)) )) - expectStageAnd(shuffleTwoRdd, List( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostC", 1)) - )){ _ => - cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD")) - cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC")) - } + } + resetExpecting { blockManagerMaster.removeExecutor("exec-hostA") - expectStage(finalRdd, List( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null) - )) - // since we have a cached copy of the missing split of shuffleTwoRdd, we shouldn't - // immediately try to rerun shuffleOneRdd: - expectStageAnd(shuffleTwoRdd, List( - (FetchFailed(null, shuffleDepOne.shuffleId, 0, 0), null) - ), Some(Seq(List("hostD")))) { _ => - w { - intercept[FetchFailedException]{ - mapOutputTracker.getServerStatuses(shuffleDepOne.shuffleId, 0) - } - } - cacheLocations.remove(shuffleTwoRdd.id -> 0) - } - // after that fetch failure, we should refetch the cache locations and try to recompute - // the 
whole chain. Note that we will ignore that a fetch failure previously occured on - // this host. - expectStage(shuffleOneRdd, List( (Success, makeMapStatus("hostA", 1)) )) - expectStage(shuffleTwoRdd, List( (Success, makeMapStatus("hostA", 1)) )) - expectStage(finalRdd, List( (Success, 42) )) } whenExecuting { - assert(submitRdd(finalRdd) === Array(42)) + respondToTaskSet(reduceStage, List( + (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null) + )) } + val recomputeTwoCached = interceptStage(shuffleTwoRdd) { + scheduler.resubmitFailedStages() + } + expectTaskSetLocations(recomputeTwoCached, Seq(Seq("hostD"))) + intercept[FetchFailedException]{ + mapOutputTracker.getServerStatuses(shuffleDepOne.shuffleId, 0) + } + + // Simulate the shuffle input data failing to be cached. + cacheLocations.remove(shuffleTwoRdd.id -> 0) + respondToTaskSet(recomputeTwoCached, List( + (FetchFailed(null, shuffleDepOne.shuffleId, 0, 0), null) + )) + + // After the fetch failure, DAGScheduler should recheck the cache and decide to resubmit + // everything. + val recomputeOne = interceptStage(shuffleOneRdd) { + scheduler.resubmitFailedStages() + } + // We use hostA here to make sure DAGScheduler doesn't think it's still dead. + val recomputeTwoUncached = interceptStage(shuffleTwoRdd) { + respondToTaskSet(recomputeOne, List( (Success, makeMapStatus("hostA", 1)) )) + } + expectTaskSetLocations(recomputeTwoUncached, Seq(Seq[String]())) + val finalRetry = interceptStage(finalRdd) { + respondToTaskSet(recomputeTwoUncached, List( (Success, makeMapStatus("hostA", 1)) )) + + } + respondToTaskSet(finalRetry, List( (Success, 42) )) + expectJobResult(Array(42)) } } - From 7f51458774ce4561f1df3ba9b68704c3f63852f3 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 30 Jan 2013 09:34:53 -0800 Subject: [PATCH 236/291] Comment at top of DAGSchedulerSuite --- .../scala/spark/scheduler/DAGSchedulerSuite.scala | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala index 89173540d4..c31e2e7064 100644 --- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala @@ -30,9 +30,22 @@ import spark.TaskEndReason import spark.{FetchFailed, Success} +/** + * Tests for DAGScheduler. These tests directly call the event processing functinos in DAGScheduler + * rather than spawning an event loop thread as happens in the real code. They use EasyMock + * to mock out two classes that DAGScheduler interacts with: TaskScheduler (to which TaskSets are + * submitted) and BlockManagerMaster (from which cache locations are retrieved and to which dead + * host notifications are sent). In addition, tests may check for side effects on a non-mocked + * MapOutputTracker instance. + * + * Tests primarily consist of running DAGScheduler#processEvent and + * DAGScheduler#submitWaitingStages (via test utility functions like runEvent or respondToTaskSet) + * and capturing the resulting TaskSets from the mock TaskScheduler. + */ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar with TimeLimitedTests { - // impose a time limit on this test in case we don't let the job finish. + // impose a time limit on this test in case we don't let the job finish, in which case + // JobWaiter#getResult will hang. 
override val timeLimit = Span(5, Seconds) val sc: SparkContext = new SparkContext("local", "DAGSchedulerSuite") From f7de6978c14a331683e4a341fccd6e4c5e9fa523 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Tue, 29 Jan 2013 14:03:05 -0800 Subject: [PATCH 237/291] Use Mesos ExecutorIDs to hold SlaveIDs. Then we can safely use the Mesos ExecutorID as a Spark ExecutorID. --- .../spark/executor/MesosExecutorBackend.scala | 6 +++- .../mesos/MesosSchedulerBackend.scala | 30 ++++++++++--------- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala index 1ef88075ad..b981b26916 100644 --- a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala @@ -32,7 +32,11 @@ private[spark] class MesosExecutorBackend(executor: Executor) logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue) this.driver = driver val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) - executor.initialize(executorInfo.getExecutorId.getValue, slaveInfo.getHostname, properties) + executor.initialize( + slaveInfo.getId.getValue + "-" + executorInfo.getExecutorId.getValue, + slaveInfo.getHostname, + properties + ) } override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) { diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala index f3467db86b..eab1c60e0b 100644 --- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala @@ -51,7 +51,7 @@ private[spark] class MesosSchedulerBackend( val taskIdToSlaveId = new HashMap[Long, String] // An ExecutorInfo for our tasks - var executorInfo: ExecutorInfo = null + var execArgs: Array[Byte] = null override def start() { synchronized { @@ -70,12 +70,11 @@ private[spark] class MesosSchedulerBackend( } }.start() - executorInfo = createExecutorInfo() waitForRegister() } } - def createExecutorInfo(): ExecutorInfo = { + def createExecutorInfo(execId: String): ExecutorInfo = { val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException( "Spark home is not set; set it through the spark.home system " + "property, the SPARK_HOME environment variable or the SparkContext constructor")) @@ -97,7 +96,7 @@ private[spark] class MesosSchedulerBackend( .setEnvironment(environment) .build() ExecutorInfo.newBuilder() - .setExecutorId(ExecutorID.newBuilder().setValue("default").build()) + .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) .setCommand(command) .setData(ByteString.copyFrom(createExecArg())) .addResources(memory) @@ -109,17 +108,20 @@ private[spark] class MesosSchedulerBackend( * containing all the spark.* system properties in the form of (String, String) pairs. 
*/ private def createExecArg(): Array[Byte] = { - val props = new HashMap[String, String] - val iterator = System.getProperties.entrySet.iterator - while (iterator.hasNext) { - val entry = iterator.next - val (key, value) = (entry.getKey.toString, entry.getValue.toString) - if (key.startsWith("spark.")) { - props(key) = value + if (execArgs == null) { + val props = new HashMap[String, String] + val iterator = System.getProperties.entrySet.iterator + while (iterator.hasNext) { + val entry = iterator.next + val (key, value) = (entry.getKey.toString, entry.getValue.toString) + if (key.startsWith("spark.")) { + props(key) = value + } } + // Serialize the map as an array of (String, String) pairs + execArgs = Utils.serialize(props.toArray) } - // Serialize the map as an array of (String, String) pairs - return Utils.serialize(props.toArray) + return execArgs } override def offerRescinded(d: SchedulerDriver, o: OfferID) {} @@ -216,7 +218,7 @@ private[spark] class MesosSchedulerBackend( return MesosTaskInfo.newBuilder() .setTaskId(taskId) .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build()) - .setExecutor(executorInfo) + .setExecutor(createExecutorInfo(slaveId)) .setName(task.name) .addResources(cpuResource) .setData(ByteString.copyFrom(task.serializedTask)) From 252845d3046034d6e779bd7245d2f876debba8fd Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Wed, 30 Jan 2013 10:38:06 -0800 Subject: [PATCH 238/291] Remove remnants of attempt to use slaveId-executorId in MesosExecutorBackend --- core/src/main/scala/spark/executor/MesosExecutorBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala index b981b26916..818d6d1dda 100644 --- a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala @@ -33,7 +33,7 @@ private[spark] class MesosExecutorBackend(executor: Executor) this.driver = driver val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) executor.initialize( - slaveInfo.getId.getValue + "-" + executorInfo.getExecutorId.getValue, + executorInfo.getExecutorId.getValue, slaveInfo.getHostname, properties ) From 871476d506a2d543482defb923a42a2a01f206ab Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Wed, 30 Jan 2013 16:56:46 -0600 Subject: [PATCH 239/291] Include message and exitStatus if available.
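The log line below assembles its optional fields with Option.map(...).getOrElse(""), so nothing is appended when a field is absent. A small standalone sketch of that idiom (the executor id, state, message, and exitStatus values here are made up):

    // Illustration only: optional fields appended to a log message.
    object LogLineSketch {
      def main(args: Array[String]) {
        val message: Option[String] = Some("Command exited with code 1")
        val exitStatus: Option[Int] = Some(1)
        val line = "Executor job-0/5 finished with state FAILED" +
          message.map(" message " + _).getOrElse("") +
          exitStatus.map(" exitStatus " + _).getOrElse("")
        println(line)   // prints the state plus whichever optional fields are present
      }
    }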
--- core/src/main/scala/spark/deploy/worker/Worker.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 5a83a42daf..8b41620d98 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -134,7 +134,9 @@ private[spark] class Worker( val fullId = jobId + "/" + execId if (ExecutorState.isFinished(state)) { val executor = executors(fullId) - logInfo("Executor " + fullId + " finished with state " + state) + logInfo("Executor " + fullId + " finished with state " + state + + message.map(" message " + _).getOrElse("") + + exitStatus.map(" exitStatus " + _).getOrElse("")) finishedExecutors(fullId) = executor executors -= fullId coresUsed -= executor.cores From 58a7d320d7287f3773976f0efdf2bc2c1474f7f9 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 30 Jan 2013 14:49:18 -0800 Subject: [PATCH 240/291] Inclue packaging and launching pyspark in guide. It's nicer if all the commands you need are made explicit. --- docs/python-programming-guide.md | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index a840b9b34b..94311bd102 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -67,8 +67,14 @@ The script automatically adds the `pyspark` package to the `PYTHONPATH`. # Interactive Use -The `pyspark` script launches a Python interpreter that is configured to run PySpark jobs. -When run without any input files, `pyspark` launches a shell that can be used explore data interactively, which is a simple way to learn the API: +The `pyspark` script launches a Python interpreter that is configured to run PySpark jobs. To use `pyspark` interactively, first build Spark, then launch it directly from the command line without any options: + +{% highlight bash %} +$ sbt/sbt package +$ ./pyspark +{% endhighlight %} + +The Python shell can be used explore data interactively and is a simple way to learn the API: {% highlight python %} >>> words = sc.textFile("/usr/share/dict/words") From 3f945e3b830c5a7d50acd61c5aabf964f40f7f4b Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 30 Jan 2013 15:04:06 -0800 Subject: [PATCH 241/291] Make module help available in python shell. Also, adds a line in doc explaining how to use. --- docs/python-programming-guide.md | 1 + python/pyspark/shell.py | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index 94311bd102..4e84d23edf 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -80,6 +80,7 @@ The Python shell can be used explore data interactively and is a simple way to l >>> words = sc.textFile("/usr/share/dict/words") >>> words.filter(lambda w: w.startswith("spar")).take(5) [u'spar', u'sparable', u'sparada', u'sparadrap', u'sparagrass'] +>>> help(pyspark) # Show all pyspark functions {% endhighlight %} By default, the `pyspark` shell creates SparkContext that runs jobs locally. diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index f6328c561f..54ff1bf8e7 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -4,6 +4,7 @@ An interactive shell. This file is designed to be launched as a PYTHONSTARTUP script. 
""" import os +import pyspark from pyspark.context import SparkContext From c1df24d0850b0ac89f35f1a47ce6b2fb5b95df0a Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Wed, 30 Jan 2013 18:51:14 -0800 Subject: [PATCH 242/291] rename Slaves --> Executor --- core/src/main/scala/spark/SparkContext.scala | 6 +++--- core/src/main/scala/spark/storage/BlockManagerUI.scala | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index a09eca1dd0..39e3555de8 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -468,7 +468,7 @@ class SparkContext( * Return a map from the slave to the max memory available for caching and the remaining * memory available for caching. */ - def getSlavesMemoryStatus: Map[String, (Long, Long)] = { + def getExecutorMemoryStatus: Map[String, (Long, Long)] = { env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => (blockManagerId.ip + ":" + blockManagerId.port, mem) } @@ -479,13 +479,13 @@ class SparkContext( * they take, etc. */ def getRDDStorageInfo : Array[RDDInfo] = { - StorageUtils.rddInfoFromStorageStatus(getSlavesStorageStatus, this) + StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this) } /** * Return information about blocks stored in all of the slaves */ - def getSlavesStorageStatus : Array[StorageStatus] = { + def getExecutorStorageStatus : Array[StorageStatus] = { env.blockManager.master.getStorageStatus } diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index 52f6d1b657..9e6721ec17 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -45,7 +45,7 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, path("") { completeWith { // Request the current storage status from the Master - val storageStatusList = sc.getSlavesStorageStatus + val storageStatusList = sc.getExecutorStorageStatus // Calculate macro-level statistics val maxMem = storageStatusList.map(_.maxMem).reduce(_+_) val remainingMem = storageStatusList.map(_.memRemaining).reduce(_+_) @@ -60,7 +60,7 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, parameter("id") { id => completeWith { val prefix = "rdd_" + id.toString - val storageStatusList = sc.getSlavesStorageStatus + val storageStatusList = sc.getExecutorStorageStatus val filteredStorageStatusList = StorageUtils. filterStorageStatusByPrefix(storageStatusList, prefix) val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head From fe3eceab5724bec0103471eb905bb9701120b04a Mon Sep 17 00:00:00 2001 From: Mikhail Bautin Date: Thu, 31 Jan 2013 13:30:41 -0800 Subject: [PATCH 243/291] Remove activation of profiles by default See the discussion at https://github.com/mesos/spark/pull/355 for why default profile activation is a problem. 
--- bagel/pom.xml | 11 ----------- core/pom.xml | 11 ----------- examples/pom.xml | 11 ----------- pom.xml | 11 ----------- repl-bin/pom.xml | 11 ----------- repl/pom.xml | 11 ----------- streaming/pom.xml | 11 ----------- 7 files changed, 77 deletions(-) diff --git a/bagel/pom.xml b/bagel/pom.xml index 5f58347204..a8256a6e8b 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -45,11 +45,6 @@ hadoop1 - - - !hadoopVersion - - org.spark-project @@ -77,12 +72,6 @@ hadoop2 - - - hadoopVersion - 2 - - org.spark-project diff --git a/core/pom.xml b/core/pom.xml index 862d3ec37a..873e8a1d0f 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -163,11 +163,6 @@ hadoop1 - - - !hadoopVersion - - org.apache.hadoop @@ -220,12 +215,6 @@ hadoop2 - - - hadoopVersion - 2 - - org.apache.hadoop diff --git a/examples/pom.xml b/examples/pom.xml index 4d43103475..f43af670c6 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -50,11 +50,6 @@ hadoop1 - - - !hadoopVersion - - org.spark-project @@ -88,12 +83,6 @@ hadoop2 - - - hadoopVersion - 2 - - org.spark-project diff --git a/pom.xml b/pom.xml index 3ea989a082..c6b9012dc6 100644 --- a/pom.xml +++ b/pom.xml @@ -499,11 +499,6 @@ hadoop1 - - - !hadoopVersion - - 1 @@ -521,12 +516,6 @@ hadoop2 - - - hadoopVersion - 2 - - 2 diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index da91c0f3ab..0667b71cc7 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -70,11 +70,6 @@ hadoop1 - - - !hadoopVersion - - hadoop1 @@ -115,12 +110,6 @@ hadoop2 - - - hadoopVersion - 2 - - hadoop2 diff --git a/repl/pom.xml b/repl/pom.xml index 2dc96beaf5..4a296fa630 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -72,11 +72,6 @@ hadoop1 - - - !hadoopVersion - - hadoop1 @@ -128,12 +123,6 @@ hadoop2 - - - hadoopVersion - 2 - - hadoop2 diff --git a/streaming/pom.xml b/streaming/pom.xml index 3dae815e1a..6ee7e59df3 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -83,11 +83,6 @@ hadoop1 - - - !hadoopVersion - - org.spark-project @@ -115,12 +110,6 @@ hadoop2 - - - hadoopVersion - 2 - - org.spark-project From 418e36caa8fcd9a70026ab762ec709732fdebd6b Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Thu, 31 Jan 2013 17:18:33 -0600 Subject: [PATCH 244/291] Add more private declarations. 
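For reference, the scoped access modifiers this patch applies behave as follows; a minimal sketch (the class and method names are illustrative, not taken from the patch):

    package spark.scheduler

    private[spark] class VisibilitySketch {
      private def internalOnly() {}              // callable only from within VisibilitySketch
      private[scheduler] def schedulerLevel() {} // callable from anywhere under spark.scheduler
      private[spark] def sparkLevel() {}         // callable from anywhere under spark
    }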
--- .../main/scala/spark/MapOutputTracker.scala | 2 +- .../spark/deploy/master/MasterWebUI.scala | 22 +++------ .../scala/spark/scheduler/DAGScheduler.scala | 46 ++++++++++--------- .../spark/scheduler/ShuffleMapTask.scala | 3 +- .../scheduler/cluster/ClusterScheduler.scala | 2 +- .../scheduler/cluster/TaskSetManager.scala | 19 ++++---- .../scheduler/local/LocalScheduler.scala | 4 +- .../scala/spark/util/MetadataCleaner.scala | 10 ++-- 8 files changed, 49 insertions(+), 59 deletions(-) diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index aaf433b324..4735207585 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -170,7 +170,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea } } - def cleanup(cleanupTime: Long) { + private def cleanup(cleanupTime: Long) { mapStatuses.clearOldValues(cleanupTime) cachedSerializedStatuses.clearOldValues(cleanupTime) } diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index a01774f511..529f72e9da 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -45,13 +45,9 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct case (jobId, Some(js)) if (js.equalsIgnoreCase("json")) => val future = master ? RequestMasterState val jobInfo = for (masterState <- future.mapTo[MasterState]) yield { - masterState.activeJobs.find(_.id == jobId) match { - case Some(job) => job - case _ => masterState.completedJobs.find(_.id == jobId) match { - case Some(job) => job - case _ => null - } - } + masterState.activeJobs.find(_.id == jobId).getOrElse({ + masterState.completedJobs.find(_.id == jobId).getOrElse(null) + }) } respondWithMediaType(MediaTypes.`application/json`) { ctx => ctx.complete(jobInfo.mapTo[JobInfo]) @@ -61,14 +57,10 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct val future = master ? 
RequestMasterState future.map { state => val masterState = state.asInstanceOf[MasterState] - - masterState.activeJobs.find(_.id == jobId) match { - case Some(job) => spark.deploy.master.html.job_details.render(job) - case _ => masterState.completedJobs.find(_.id == jobId) match { - case Some(job) => spark.deploy.master.html.job_details.render(job) - case _ => null - } - } + val job = masterState.activeJobs.find(_.id == jobId).getOrElse({ + masterState.completedJobs.find(_.id == jobId).getOrElse(null) + }) + spark.deploy.master.html.job_details.render(job) } } } diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index b130be6a38..14f61f7e87 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -97,7 +97,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } }.start() - def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { + private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { if (!cacheLocs.contains(rdd.id)) { val blockIds = rdd.splits.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map { @@ -107,7 +107,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with cacheLocs(rdd.id) } - def clearCacheLocs() { + private def clearCacheLocs() { cacheLocs.clear() } @@ -116,7 +116,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * The priority value passed in will be used if the stage doesn't already exist with * a lower priority (we assume that priorities always increase across jobs for now). */ - def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = { + private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = { shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => @@ -131,11 +131,11 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * as a result stage for the final RDD used directly in an action. The stage will also be given * the provided priority. */ - def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = { - // Kind of ugly: need to register RDDs with the cache and map output tracker here - // since we can't do it in the RDD constructor because # of splits is unknown - logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") + private def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = { if (shuffleDep != None) { + // Kind of ugly: need to register RDDs with the cache and map output tracker here + // since we can't do it in the RDD constructor because # of splits is unknown + logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size) } val id = nextStageId.getAndIncrement() @@ -148,7 +148,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * Get or create the list of parent stages for a given RDD. The stages will be assigned the * provided priority if they haven't already been created with a lower priority. 
*/ - def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = { + private def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = { val parents = new HashSet[Stage] val visited = new HashSet[RDD[_]] def visit(r: RDD[_]) { @@ -170,7 +170,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with parents.toList } - def getMissingParentStages(stage: Stage): List[Stage] = { + private def getMissingParentStages(stage: Stage): List[Stage] = { val missing = new HashSet[Stage] val visited = new HashSet[RDD[_]] def visit(rdd: RDD[_]) { @@ -241,7 +241,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * events and responds by launching tasks. This runs in a dedicated thread and receives events * via the eventQueue. */ - def run() { + private def run() { SparkEnv.set(env) while (true) { @@ -326,7 +326,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * We run the operation in a separate thread just in case it takes a bunch of time, so that we * don't block the DAGScheduler event loop or other concurrent jobs. */ - def runLocally(job: ActiveJob) { + private def runLocally(job: ActiveJob) { logInfo("Computing the requested partition locally") new Thread("Local computation of job " + job.runId) { override def run() { @@ -349,13 +349,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with }.start() } - def submitStage(stage: Stage) { + /** Submits stage, but first recursively submits any missing parents. */ + private def submitStage(stage: Stage) { logDebug("submitStage(" + stage + ")") if (!waiting(stage) && !running(stage) && !failed(stage)) { val missing = getMissingParentStages(stage).sortBy(_.id) logDebug("missing: " + missing) if (missing == Nil) { - logInfo("Submitting " + stage + " (" + stage.origin + "), which has no missing parents") + logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents") submitMissingTasks(stage) running += stage } else { @@ -367,7 +368,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } } - def submitMissingTasks(stage: Stage) { + /** Called when stage's parents are available and we can now do its task. */ + private def submitMissingTasks(stage: Stage) { logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet) @@ -388,7 +390,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } } if (tasks.size > 0) { - logInfo("Submitting " + tasks.size + " missing tasks from " + stage) + logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") myPending ++= tasks logDebug("New pending tasks: " + myPending) taskSched.submitTasks( @@ -407,7 +409,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * Responds to a task finishing. This is called inside the event loop so it assumes that it can * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside. 
*/ - def handleTaskCompletion(event: CompletionEvent) { + private def handleTaskCompletion(event: CompletionEvent) { val task = event.task val stage = idToStage(task.stageId) @@ -492,7 +494,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with waiting --= newlyRunnable running ++= newlyRunnable for (stage <- newlyRunnable.sortBy(_.id)) { - logInfo("Submitting " + stage + " (" + stage.origin + "), which is now runnable") + logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable") submitMissingTasks(stage) } } @@ -541,7 +543,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * Optionally the generation during which the failure was caught can be passed to avoid allowing * stray fetch failures from possibly retriggering the detection of a node as lost. */ - def handleExecutorLost(execId: String, maybeGeneration: Option[Long] = None) { + private def handleExecutorLost(execId: String, maybeGeneration: Option[Long] = None) { val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration) if (!failedGeneration.contains(execId) || failedGeneration(execId) < currentGeneration) { failedGeneration(execId) = currentGeneration @@ -567,7 +569,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. */ - def abortStage(failedStage: Stage, reason: String) { + private def abortStage(failedStage: Stage, reason: String) { val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq for (resultStage <- dependentStages) { val job = resultStageToJob(resultStage) @@ -583,7 +585,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with /** * Return true if one of stage's ancestors is target. 
*/ - def stageDependsOn(stage: Stage, target: Stage): Boolean = { + private def stageDependsOn(stage: Stage, target: Stage): Boolean = { if (stage == target) { return true } @@ -610,7 +612,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with visitedRdds.contains(target.rdd) } - def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = { + private def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = { // If the partition is cached, return the cache locations val cached = getCacheLocs(rdd)(partition) if (cached != Nil) { @@ -636,7 +638,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with return Nil } - def cleanup(cleanupTime: Long) { + private def cleanup(cleanupTime: Long) { var sizeBefore = idToStage.size idToStage.clearOldValues(cleanupTime) logInfo("idToStage " + sizeBefore + " --> " + idToStage.size) diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 83641a2a84..b701b67c89 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -127,7 +127,6 @@ private[spark] class ShuffleMapTask( val bucketId = dep.partitioner.getPartition(pair._1) buckets(bucketId) += pair } - val bucketIterators = buckets.map(_.iterator) val compressedSizes = new Array[Byte](numOutputSplits) @@ -135,7 +134,7 @@ private[spark] class ShuffleMapTask( for (i <- 0 until numOutputSplits) { val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i // Get a Scala iterator from Java map - val iter: Iterator[(Any, Any)] = bucketIterators(i) + val iter: Iterator[(Any, Any)] = buckets(i).iterator val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false) compressedSizes(i) = MapOutputTracker.compressSize(size) } diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 0b4177805b..1e4fbdb874 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -86,7 +86,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } - def submitTasks(taskSet: TaskSet) { + override def submitTasks(taskSet: TaskSet) { val tasks = taskSet.tasks logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") this.synchronized { diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index 26201ad0dd..3dabdd76b1 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -17,10 +17,7 @@ import java.nio.ByteBuffer /** * Schedules the tasks within a single TaskSet in the ClusterScheduler. */ -private[spark] class TaskSetManager( - sched: ClusterScheduler, - val taskSet: TaskSet) - extends Logging { +private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSet) extends Logging { // Maximum time to wait to run a task in a preferred location (in ms) val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong @@ -100,7 +97,7 @@ private[spark] class TaskSetManager( } // Add a task to all the pending-task lists that it should be on. 
- def addPendingTask(index: Int) { + private def addPendingTask(index: Int) { val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive if (locations.size == 0) { pendingTasksWithNoPrefs += index @@ -115,7 +112,7 @@ private[spark] class TaskSetManager( // Return the pending tasks list for a given host, or an empty list if // there is no map entry for that host - def getPendingTasksForHost(host: String): ArrayBuffer[Int] = { + private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = { pendingTasksForHost.getOrElse(host, ArrayBuffer()) } @@ -123,7 +120,7 @@ private[spark] class TaskSetManager( // Return None if the list is empty. // This method also cleans up any tasks in the list that have already // been launched, since we want that to happen lazily. - def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { + private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { while (!list.isEmpty) { val index = list.last list.trimEnd(1) @@ -137,7 +134,7 @@ private[spark] class TaskSetManager( // Return a speculative task for a given host if any are available. The task should not have an // attempt running on this host, in case the host is slow. In addition, if localOnly is set, the // task must have a preference for this host (or no preferred locations at all). - def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = { + private def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = { val hostsAlive = sched.hostsAlive speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set val localTask = speculatableTasks.find { @@ -162,7 +159,7 @@ private[spark] class TaskSetManager( // Dequeue a pending task for a given node and return its index. // If localOnly is set to false, allow non-local tasks as well. - def findTask(host: String, localOnly: Boolean): Option[Int] = { + private def findTask(host: String, localOnly: Boolean): Option[Int] = { val localTask = findTaskFromList(getPendingTasksForHost(host)) if (localTask != None) { return localTask @@ -184,7 +181,7 @@ private[spark] class TaskSetManager( // Does a host count as a preferred location for a task? This is true if // either the task has preferred locations and this host is one, or it has // no preferred locations (in which we still count the launch as preferred). 
- def isPreferredLocation(task: Task[_], host: String): Boolean = { + private def isPreferredLocation(task: Task[_], host: String): Boolean = { val locs = task.preferredLocations return (locs.contains(host) || locs.isEmpty) } @@ -335,7 +332,7 @@ private[spark] class TaskSetManager( if (numFailures(index) > MAX_TASK_FAILURES) { logError("Task %s:%d failed more than %d times; aborting job".format( taskSet.id, index, MAX_TASK_FAILURES)) - abort("Task %d failed more than %d times".format(index, MAX_TASK_FAILURES)) + abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) } } } else { diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index 9ff7c02097..482d1cc853 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -53,7 +53,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon } def runTask(task: Task[_], idInJob: Int, attemptId: Int) { - logInfo("Running task " + idInJob) + logInfo("Running " + task) // Set the Spark execution environment for the worker thread SparkEnv.set(env) try { @@ -80,7 +80,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon val resultToReturn = ser.deserialize[Any](ser.serialize(result)) val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( ser.serialize(Accumulators.values)) - logInfo("Finished task " + idInJob) + logInfo("Finished " + task) // If the threadpool has not already been shutdown, notify DAGScheduler if (!Thread.currentThread().isInterrupted) diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala index eaff7ae581..a342d378ff 100644 --- a/core/src/main/scala/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -9,12 +9,12 @@ import spark.Logging * Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries) */ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { - val delaySeconds = MetadataCleaner.getDelaySeconds - val periodSeconds = math.max(10, delaySeconds / 10) - val timer = new Timer(name + " cleanup timer", true) + private val delaySeconds = MetadataCleaner.getDelaySeconds + private val periodSeconds = math.max(10, delaySeconds / 10) + private val timer = new Timer(name + " cleanup timer", true) - val task = new TimerTask { - def run() { + private val task = new TimerTask { + override def run() { try { cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) logInfo("Ran metadata cleaner for " + name) From 782187c21047ee31728bdb173a2b7ee708cef77b Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Thu, 31 Jan 2013 18:27:25 -0600 Subject: [PATCH 245/291] Once we find a split with no block, we don't have to look for more. 
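The change boils down to replacing a per-partition loop with a short-circuiting check. A minimal, self-contained sketch of that pattern follows; the data and the object name are made up for illustration and are not the real DAGScheduler types:

    object ShortCircuitSketch {
      def main(args: Array[String]): Unit = {
        // Stand-in for getCacheLocs(rdd): cached block locations per partition index.
        val locs: Array[Seq[String]] = Array(Seq("hostA"), Nil, Seq("hostB"))
        // `exists` stops scanning at the first uncached partition (index 1 here), so
        // the parent dependencies only need to be examined once, not once per
        // uncached partition as a nested loop would do.
        val atLeastOneMissing = (0 until locs.length).exists(p => locs(p) == Nil)
        println("at least one partition uncached: " + atLeastOneMissing)
      }
    }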
--- .../scala/spark/scheduler/DAGScheduler.scala | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index b130be6a38..b62b25f688 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -177,18 +177,17 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with if (!visited(rdd)) { visited += rdd val locs = getCacheLocs(rdd) - for (p <- 0 until rdd.splits.size) { - if (locs(p) == Nil) { - for (dep <- rdd.dependencies) { - dep match { - case shufDep: ShuffleDependency[_,_] => - val mapStage = getShuffleMapStage(shufDep, stage.priority) - if (!mapStage.isAvailable) { - missing += mapStage - } - case narrowDep: NarrowDependency[_] => - visit(narrowDep.rdd) - } + val atLeastOneMissing = (0 until rdd.splits.size).exists(locs(_) == Nil) + if (atLeastOneMissing) { + for (dep <- rdd.dependencies) { + dep match { + case shufDep: ShuffleDependency[_,_] => + val mapStage = getShuffleMapStage(shufDep, stage.priority) + if (!mapStage.isAvailable) { + missing += mapStage + } + case narrowDep: NarrowDependency[_] => + visit(narrowDep.rdd) } } } From 5b0fc265c2f2ce461d61904c2a4e6e47b24d2bbe Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 31 Jan 2013 17:48:39 -0800 Subject: [PATCH 246/291] Changed PartitionPruningRDD's split to make sure it returns the correct split index. --- core/src/main/scala/spark/Dependency.scala | 8 ++++++++ core/src/main/scala/spark/rdd/PartitionPruningRDD.scala | 4 +++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala index 647aee6eb5..827eac850a 100644 --- a/core/src/main/scala/spark/Dependency.scala +++ b/core/src/main/scala/spark/Dependency.scala @@ -72,6 +72,14 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo @transient val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index)) + .zipWithIndex + .map { case(split, idx) => new PruneDependency.PartitionPruningRDDSplit(idx, split) : Split } override def getParents(partitionId: Int) = List(partitions(partitionId).index) } + +object PruneDependency { + class PartitionPruningRDDSplit(idx: Int, val parentSplit: Split) extends Split { + override val index = idx + } +} diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala index b8482338c6..0989b149e1 100644 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -2,6 +2,7 @@ package spark.rdd import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} + /** * A RDD used to prune RDD partitions/splits so we can avoid launching tasks on * all partitions. 
An example use case: If we know the RDD is partitioned by range, @@ -15,7 +16,8 @@ class PartitionPruningRDD[T: ClassManifest]( @transient partitionFilterFunc: Int => Boolean) extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { - override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) + override def compute(split: Split, context: TaskContext) = firstParent[T].iterator( + split.asInstanceOf[PruneDependency.PartitionPruningRDDSplit].parentSplit, context) override protected def getSplits = getDependencies.head.asInstanceOf[PruneDependency[T]].partitions From 6289d9654e32fc92418d41cc6e32fee30f85c833 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 31 Jan 2013 17:49:36 -0800 Subject: [PATCH 247/291] Removed the TODO comment from PartitionPruningRDD. --- core/src/main/scala/spark/rdd/PartitionPruningRDD.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala index 0989b149e1..3756870fac 100644 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -8,8 +8,6 @@ import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} * all partitions. An example use case: If we know the RDD is partitioned by range, * and the execution DAG has a filter on the key, we can avoid launching tasks * on partitions that don't have the range covering the key. - * - * TODO: This currently doesn't give partition IDs properly! */ class PartitionPruningRDD[T: ClassManifest]( @transient prev: RDD[T], From 3446d5c8d6b385106ac85e46320d92faa8efb4e6 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 31 Jan 2013 18:02:28 -0800 Subject: [PATCH 248/291] SPARK-673: Capture and re-throw Python exceptions This patch alters the Python <-> executor protocol to pass on exception data when they occur in user Python code. 
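For reference, a rough, self-contained sketch of the framing this protocol relies on: a positive int gives the length of a data record, -2 signals a length-prefixed traceback from the worker, and -1 ends the data section. The object and method names below are invented for illustration and are not the actual PythonRDD API:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

    object FramingSketch {
      // Reads one frame: Right(bytes) for a normal record, Left(traceback) when the
      // worker signalled a failure with the -2 marker, Right(empty) at end of data.
      def readOne(in: DataInputStream): Either[String, Array[Byte]] = {
        in.readInt() match {
          case length if length > 0 =>
            val obj = new Array[Byte](length)
            in.readFully(obj)
            Right(obj)
          case -2 =>
            val len = in.readInt()
            val trace = new Array[Byte](len)
            in.readFully(trace)
            Left(new String(trace, "UTF-8"))
          case -1 =>
            Right(new Array[Byte](0))   // end of data; accumulator updates would follow
        }
      }

      def main(args: Array[String]): Unit = {
        // Writer side, standing in for the Python worker: one record, then -2 plus a traceback.
        val buf = new ByteArrayOutputStream()
        val out = new DataOutputStream(buf)
        val rec = "a result".getBytes("UTF-8")
        out.writeInt(rec.length); out.write(rec)
        val tb = "Traceback (most recent call last): ...".getBytes("UTF-8")
        out.writeInt(-2); out.writeInt(tb.length); out.write(tb)

        // Reader side, standing in for the JVM executor: dispatch on the framing int.
        val in = new DataInputStream(new ByteArrayInputStream(buf.toByteArray))
        for (_ <- 1 to 2) readOne(in) match {
          case Right(bytes) => println("record: " + new String(bytes, "UTF-8"))
          case Left(trace)  => println("worker exception (would be rethrown):\n" + trace)
        }
      }
    }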
--- .../scala/spark/api/python/PythonRDD.scala | 40 ++++++++++++------- python/pyspark/worker.py | 10 ++++- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index f43a152ca7..6b9ef62529 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -103,21 +103,30 @@ private[spark] class PythonRDD[T: ClassManifest]( private def read(): Array[Byte] = { try { - val length = stream.readInt() - if (length != -1) { - val obj = new Array[Byte](length) - stream.readFully(obj) - obj - } else { - // We've finished the data section of the output, but we can still read some - // accumulator updates; let's do that, breaking when we get EOFException - while (true) { - val len2 = stream.readInt() - val update = new Array[Byte](len2) - stream.readFully(update) - accumulator += Collections.singletonList(update) + stream.readInt() match { + case length if length > 0 => { + val obj = new Array[Byte](length) + stream.readFully(obj) + obj } - new Array[Byte](0) + case -2 => { + // Signals that an exception has been thrown in python + val exLength = stream.readInt() + val obj = new Array[Byte](exLength) + stream.readFully(obj) + throw new PythonException(new String(obj)) + } + case -1 => { + // We've finished the data section of the output, but we can still read some + // accumulator updates; let's do that, breaking when we get EOFException + while (true) { + val len2 = stream.readInt() + val update = new Array[Byte](len2) + stream.readFully(update) + accumulator += Collections.singletonList(update) + } + new Array[Byte](0) + } } } catch { case eof: EOFException => { @@ -140,6 +149,9 @@ private[spark] class PythonRDD[T: ClassManifest]( val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } +/** Thrown for exceptions in user Python code. */ +private class PythonException(msg: String) extends Exception(msg) + /** * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. * This is used by PySpark's shuffle operations. diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index d33d6dd15f..9622e0cfe4 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -2,6 +2,7 @@ Worker that receives input from Piped RDD. """ import sys +import traceback from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. 
@@ -40,8 +41,13 @@ def main(): else: dumps = dump_pickle iterator = read_from_pickle_file(sys.stdin) - for obj in func(split_index, iterator): - write_with_length(dumps(obj), old_stdout) + try: + for obj in func(split_index, iterator): + write_with_length(dumps(obj), old_stdout) + except Exception as e: + write_int(-2, old_stdout) + write_with_length(traceback.format_exc(), old_stdout) + sys.exit(-1) # Mark the beginning of the accumulators section of the output write_int(-1, old_stdout) for aid, accum in _accumulatorRegistry.items(): From c33f0ef41a1865de2bae01b52b860650d3734da4 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 31 Jan 2013 21:50:02 -0800 Subject: [PATCH 249/291] Some style cleanup --- core/src/main/scala/spark/api/python/PythonRDD.scala | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 6b9ef62529..23e3149248 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -104,19 +104,17 @@ private[spark] class PythonRDD[T: ClassManifest]( private def read(): Array[Byte] = { try { stream.readInt() match { - case length if length > 0 => { + case length if length > 0 => val obj = new Array[Byte](length) stream.readFully(obj) obj - } - case -2 => { + case -2 => // Signals that an exception has been thrown in python val exLength = stream.readInt() val obj = new Array[Byte](exLength) stream.readFully(obj) throw new PythonException(new String(obj)) - } - case -1 => { + case -1 => // We've finished the data section of the output, but we can still read some // accumulator updates; let's do that, breaking when we get EOFException while (true) { @@ -124,9 +122,8 @@ private[spark] class PythonRDD[T: ClassManifest]( val update = new Array[Byte](len2) stream.readFully(update) accumulator += Collections.singletonList(update) + new Array[Byte](0) } - new Array[Byte](0) - } } } catch { case eof: EOFException => { From 39ab83e9577a5449fb0d6ef944dffc0d7cd00b4a Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 31 Jan 2013 21:52:52 -0800 Subject: [PATCH 250/291] Small fix from last commit --- core/src/main/scala/spark/api/python/PythonRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 23e3149248..39758e94f4 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -122,8 +122,8 @@ private[spark] class PythonRDD[T: ClassManifest]( val update = new Array[Byte](len2) stream.readFully(update) accumulator += Collections.singletonList(update) - new Array[Byte](0) } + new Array[Byte](0) } } catch { case eof: EOFException => { From f9af9cee6fed9c6af896fb92556ad4f48c7f8e64 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 1 Feb 2013 00:02:46 -0800 Subject: [PATCH 251/291] Moved PruneDependency into PartitionPruningRDD.scala. 
--- core/src/main/scala/spark/Dependency.scala | 22 ---------------- .../scala/spark/rdd/PartitionPruningRDD.scala | 26 ++++++++++++++++--- 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala index 827eac850a..5eea907322 100644 --- a/core/src/main/scala/spark/Dependency.scala +++ b/core/src/main/scala/spark/Dependency.scala @@ -61,25 +61,3 @@ class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) } } } - - -/** - * Represents a dependency between the PartitionPruningRDD and its parent. In this - * case, the child RDD contains a subset of partitions of the parents'. - */ -class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean) - extends NarrowDependency[T](rdd) { - - @transient - val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index)) - .zipWithIndex - .map { case(split, idx) => new PruneDependency.PartitionPruningRDDSplit(idx, split) : Split } - - override def getParents(partitionId: Int) = List(partitions(partitionId).index) -} - -object PruneDependency { - class PartitionPruningRDDSplit(idx: Int, val parentSplit: Split) extends Split { - override val index = idx - } -} diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala index 3756870fac..a50ce75171 100644 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -1,6 +1,26 @@ package spark.rdd -import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} +import spark.{NarrowDependency, RDD, SparkEnv, Split, TaskContext} + + +class PartitionPruningRDDSplit(idx: Int, val parentSplit: Split) extends Split { + override val index = idx +} + + +/** + * Represents a dependency between the PartitionPruningRDD and its parent. In this + * case, the child RDD contains a subset of partitions of the parents'. 
+ */ +class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean) + extends NarrowDependency[T](rdd) { + + @transient + val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index)) + .zipWithIndex.map { case(split, idx) => new PartitionPruningRDDSplit(idx, split) : Split } + + override def getParents(partitionId: Int) = List(partitions(partitionId).index) +} /** @@ -15,10 +35,8 @@ class PartitionPruningRDD[T: ClassManifest]( extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { override def compute(split: Split, context: TaskContext) = firstParent[T].iterator( - split.asInstanceOf[PruneDependency.PartitionPruningRDDSplit].parentSplit, context) + split.asInstanceOf[PartitionPruningRDDSplit].parentSplit, context) override protected def getSplits = getDependencies.head.asInstanceOf[PruneDependency[T]].partitions - - override val partitioner = firstParent[T].partitioner } From f127f2ae76692b189d86b5a47293579d5657c6d5 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 1 Feb 2013 00:20:49 -0800 Subject: [PATCH 252/291] fixup merge (master -> driver renaming) --- core/src/main/scala/spark/storage/BlockManagerMaster.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 99324445ca..0372cb080a 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -118,7 +118,7 @@ private[spark] class BlockManagerMaster( } def getStorageStatus: Array[StorageStatus] = { - askMasterWithRetry[ArrayBuffer[StorageStatus]](GetStorageStatus).toArray + askDriverWithReply[ArrayBuffer[StorageStatus]](GetStorageStatus).toArray } /** Stop the driver actor, called only on the Spark driver node */ From 8a0a5ed53353ad6aa5656eb729d55ca7af2ab096 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 1 Feb 2013 00:23:38 -0800 Subject: [PATCH 253/291] track total partitions, in addition to cached partitions; use scala string formatting --- core/src/main/scala/spark/storage/StorageUtils.scala | 10 ++++------ core/src/main/twirl/spark/storage/rdd.scala.html | 6 +++++- core/src/main/twirl/spark/storage/rdd_table.scala.html | 6 ++++-- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala index ce7c067eea..5367b74bb6 100644 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -22,12 +22,11 @@ case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, } case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, - numPartitions: Int, memSize: Long, diskSize: Long) { + numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long) { override def toString = { import Utils.memoryBytesToString - import java.lang.{Integer => JInt} - String.format("RDD \"%s\" (%d) Storage: %s; Partitions: %d; MemorySize: %s; DiskSize: %s", name, id.asInstanceOf[JInt], - storageLevel.toString, numPartitions.asInstanceOf[JInt], memoryBytesToString(memSize), memoryBytesToString(diskSize)) + "RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; DiskSize: %s".format(name, id, + storageLevel.toString, numCachedPartitions, numPartitions, memoryBytesToString(memSize), memoryBytesToString(diskSize)) } } @@ -65,9 
+64,8 @@ object StorageUtils { val rdd = sc.persistentRdds(rddId) val rddName = Option(rdd.name).getOrElse(rddKey) val rddStorageLevel = rdd.getStorageLevel - //TODO get total number of partitions in rdd - RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, memSize, diskSize) + RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, rdd.splits.size, memSize, diskSize) }.toArray } diff --git a/core/src/main/twirl/spark/storage/rdd.scala.html b/core/src/main/twirl/spark/storage/rdd.scala.html index ac7f8c981f..d85addeb17 100644 --- a/core/src/main/twirl/spark/storage/rdd.scala.html +++ b/core/src/main/twirl/spark/storage/rdd.scala.html @@ -11,7 +11,11 @@ Storage Level: @(rddInfo.storageLevel.description)
  • - Partitions: + Cached Partitions: + @(rddInfo.numCachedPartitions) +
  • +
  • + Total Partitions: @(rddInfo.numPartitions)
  • diff --git a/core/src/main/twirl/spark/storage/rdd_table.scala.html b/core/src/main/twirl/spark/storage/rdd_table.scala.html index af801cf229..a51e64aed0 100644 --- a/core/src/main/twirl/spark/storage/rdd_table.scala.html +++ b/core/src/main/twirl/spark/storage/rdd_table.scala.html @@ -6,7 +6,8 @@ RDD Name Storage Level - Partitions + Cached Partitions + Fraction Partitions Cached Size in Memory Size on Disk @@ -21,7 +22,8 @@ @(rdd.storageLevel.description) - @rdd.numPartitions + @rdd.numCachedPartitions + @(rdd.numCachedPartitions / rdd.numPartitions.toDouble) @{Utils.memoryBytesToString(rdd.memSize)} @{Utils.memoryBytesToString(rdd.diskSize)} From 57b64d0d1902eb51bf79f595626c2b9f80a9d1e2 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 1 Feb 2013 00:25:19 -0800 Subject: [PATCH 254/291] Fix stdout redirection in PySpark. --- python/pyspark/tests.py | 9 +++++++++ python/pyspark/worker.py | 5 +++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index df7235756d..52297d44e6 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -108,5 +108,14 @@ class TestAddFile(PySparkTestCase): self.assertEqual("Hello World!", UserClass().hello()) +class TestIO(PySparkTestCase): + + def test_stdout_redirection(self): + import subprocess + def func(x): + subprocess.check_call('ls', shell=True) + self.sc.parallelize([1]).foreach(func) + + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 9622e0cfe4..812e7a9da5 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -1,6 +1,7 @@ """ Worker that receives input from Piped RDD. """ +import os import sys import traceback from base64 import standard_b64decode @@ -15,8 +16,8 @@ from pyspark.serializers import write_with_length, read_with_length, write_int, # Redirect stdout to stderr so that users must return values from functions. -old_stdout = sys.stdout -sys.stdout = sys.stderr +old_stdout = os.fdopen(os.dup(1), 'w') +os.dup2(2, 1) def load_obj(): From 59c57e48dfb362923610785b230d5b3b56c620c3 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Fri, 1 Feb 2013 10:34:02 -0600 Subject: [PATCH 255/291] Stop BlockManagers metadataCleaner. 
--- core/src/main/scala/spark/storage/BlockManager.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index c61fd75c2b..9893e9625d 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -950,6 +950,7 @@ class BlockManager( blockInfo.clear() memoryStore.clear() diskStore.clear() + metadataCleaner.cancel() logInfo("BlockManager stopped") } } From c6190067ae40cf457b7f2e58619904b6fd2b1cb6 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 1 Feb 2013 09:55:25 -0800 Subject: [PATCH 256/291] remove unneeded (and unused) filter on block info --- core/src/main/scala/spark/storage/StorageUtils.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala index 5367b74bb6..5f72b67b2c 100644 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -43,8 +43,6 @@ object StorageUtils { /* Given a list of BlockStatus objets, returns information for each RDD */ def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus], sc: SparkContext) : Array[RDDInfo] = { - // Find all RDD Blocks (ignore broadcast variables) - val rddBlocks = infos.filterKeys(_.startsWith("rdd")) // Group by rddId, ignore the partition name val groupedRddBlocks = infos.groupBy { case(k, v) => From 9cc6ff9c4e7eec2d62261fc166ad2ebade148752 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 1 Feb 2013 11:09:56 -0800 Subject: [PATCH 257/291] Do not launch JavaGateways on workers (SPARK-674). The problem was that the gateway was being initialized whenever the pyspark.context module was loaded. The fix uses lazy initialization that occurs only when SparkContext instances are actually constructed. I also made the gateway and jvm variables private. This change results in ~3-4x performance improvement when running the PySpark unit tests. --- python/pyspark/context.py | 27 +++++++++++++++++---------- python/pyspark/files.py | 2 +- python/pyspark/rdd.py | 12 ++++++------ python/pyspark/tests.py | 2 +- 4 files changed, 25 insertions(+), 18 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 783e3dc148..ba6896dda3 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -24,11 +24,10 @@ class SparkContext(object): broadcast variables on that cluster. 
""" - gateway = launch_gateway() - jvm = gateway.jvm - _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile - _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile - _takePartition = jvm.PythonRDD.takePartition + _gateway = None + _jvm = None + _writeIteratorToPickleFile = None + _takePartition = None _next_accum_id = 0 _active_spark_context = None _lock = Lock() @@ -56,6 +55,13 @@ class SparkContext(object): raise ValueError("Cannot run multiple SparkContexts at once") else: SparkContext._active_spark_context = self + if not SparkContext._gateway: + SparkContext._gateway = launch_gateway() + SparkContext._jvm = SparkContext._gateway.jvm + SparkContext._writeIteratorToPickleFile = \ + SparkContext._jvm.PythonRDD.writeIteratorToPickleFile + SparkContext._takePartition = \ + SparkContext._jvm.PythonRDD.takePartition self.master = master self.jobName = jobName self.sparkHome = sparkHome or None # None becomes null in Py4J @@ -63,8 +69,8 @@ class SparkContext(object): self.batchSize = batchSize # -1 represents a unlimited batch size # Create the Java SparkContext through Py4J - empty_string_array = self.gateway.new_array(self.jvm.String, 0) - self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome, + empty_string_array = self._gateway.new_array(self._jvm.String, 0) + self._jsc = self._jvm.JavaSparkContext(master, jobName, sparkHome, empty_string_array) # Create a single Accumulator in Java that we'll send all our updates through; @@ -72,8 +78,8 @@ class SparkContext(object): self._accumulatorServer = accumulators._start_update_server() (host, port) = self._accumulatorServer.server_address self._javaAccumulator = self._jsc.accumulator( - self.jvm.java.util.ArrayList(), - self.jvm.PythonAccumulatorParam(host, port)) + self._jvm.java.util.ArrayList(), + self._jvm.PythonAccumulatorParam(host, port)) self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') # Broadcast's __reduce__ method stores Broadcast instances here. 
@@ -127,7 +133,8 @@ class SparkContext(object): for x in c: write_with_length(dump_pickle(x), tempFile) tempFile.close() - jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices) + readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile + jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices) return RDD(jrdd, self) def textFile(self, name, minSplits=None): diff --git a/python/pyspark/files.py b/python/pyspark/files.py index 98f6a399cc..001b7a28b6 100644 --- a/python/pyspark/files.py +++ b/python/pyspark/files.py @@ -35,4 +35,4 @@ class SparkFiles(object): return cls._root_directory else: # This will have to change if we support multiple SparkContexts: - return cls._sc.jvm.spark.SparkFiles.getRootDirectory() + return cls._sc._jvm.spark.SparkFiles.getRootDirectory() diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d53355a8f1..d7cad2f372 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -407,7 +407,7 @@ class RDD(object): return (str(x).encode("utf-8") for x in iterator) keyed = PipelinedRDD(self, func) keyed._bypass_serializer = True - keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path) + keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path) # Pair functions @@ -550,8 +550,8 @@ class RDD(object): yield dump_pickle(Batch(items)) keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True - pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() - partitioner = self.ctx.jvm.PythonPartitioner(numSplits, + pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() + partitioner = self.ctx._jvm.PythonPartitioner(numSplits, id(partitionFunc)) jrdd = pairRDD.partitionBy(partitioner).values() rdd = RDD(jrdd, self.ctx) @@ -730,13 +730,13 @@ class PipelinedRDD(RDD): pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], - self.ctx.gateway._gateway_client) + self.ctx._gateway._gateway_client) self.ctx._pickled_broadcast_vars.clear() class_manifest = self._prev_jrdd.classManifest() env = copy.copy(self.ctx.environment) env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "") - env = MapConverter().convert(env, self.ctx.gateway._gateway_client) - python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), + env = MapConverter().convert(env, self.ctx._gateway._gateway_client) + python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator, class_manifest) self._jrdd_val = python_rdd.asJavaRDD() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 52297d44e6..6a1962d267 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -26,7 +26,7 @@ class PySparkTestCase(unittest.TestCase): sys.path = self._old_sys_path # To avoid Akka rebinding to the same port, since it doesn't unbind # immediately on shutdown - self.sc.jvm.System.clearProperty("spark.driver.port") + self.sc._jvm.System.clearProperty("spark.driver.port") class TestCheckpoint(PySparkTestCase): From e211f405bcb3cf02c3ae589cf81d9c9dfc70bc03 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 1 Feb 2013 11:48:11 -0800 Subject: [PATCH 258/291] Use spark.local.dir for PySpark temp files (SPARK-580). 
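The underlying idea is to resolve scratch space from the spark.local.dir setting instead of the system default temp directory. A rough sketch, assuming the setting is a single path (the real property may list several directories, and the helper name here is illustrative rather than the actual spark.Utils API):

    import java.io.File
    import java.util.UUID

    object LocalDirSketch {
      def createTempDir(): File = {
        // Fall back to java.io.tmpdir only when spark.local.dir is not set.
        val root = System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))
        val dir = new File(root, "spark-tmp-" + UUID.randomUUID().toString)
        if (!dir.mkdirs() && !dir.isDirectory) {
          throw new java.io.IOException("Failed to create temp dir " + dir)
        }
        dir.deleteOnExit()
        dir
      }

      def main(args: Array[String]): Unit = {
        // In a real deployment spark.local.dir would point at fast local disks;
        // here it is set to the system temp dir just so the demo runs anywhere.
        System.setProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))
        println("scratch files would go under: " + createTempDir())
      }
    }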
--- python/pyspark/context.py | 12 ++++++++---- python/pyspark/rdd.py | 7 +------ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index ba6896dda3..6831f9b7f8 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -1,8 +1,6 @@ import os -import atexit import shutil import sys -import tempfile from threading import Lock from tempfile import NamedTemporaryFile @@ -94,6 +92,11 @@ class SparkContext(object): SparkFiles._sc = self sys.path.append(SparkFiles.getRootDirectory()) + # Create a temporary directory inside spark.local.dir: + local_dir = self._jvm.spark.Utils.getLocalDir() + self._temp_dir = \ + self._jvm.spark.Utils.createTempDir(local_dir).getAbsolutePath() + @property def defaultParallelism(self): """ @@ -126,8 +129,7 @@ class SparkContext(object): # Calling the Java parallelize() method with an ArrayList is too slow, # because it sends O(n) Py4J commands. As an alternative, serialized # objects are written to a file and loaded through textFile(). - tempFile = NamedTemporaryFile(delete=False) - atexit.register(lambda: os.unlink(tempFile.name)) + tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir) if self.batchSize != 1: c = batched(c, self.batchSize) for x in c: @@ -247,7 +249,9 @@ class SparkContext(object): def _test(): + import atexit import doctest + import tempfile globs = globals().copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) globs['tempdir'] = tempfile.mkdtemp() diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d7cad2f372..41ea6e6e14 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1,4 +1,3 @@ -import atexit from base64 import standard_b64encode as b64enc import copy from collections import defaultdict @@ -264,12 +263,8 @@ class RDD(object): # Transferring lots of data through Py4J can be slow because # socket.readline() is inefficient. Instead, we'll dump the data to a # file and read it back. 
- tempFile = NamedTemporaryFile(delete=False) + tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir) tempFile.close() - def clean_up_file(): - try: os.unlink(tempFile.name) - except: pass - atexit.register(clean_up_file) self.ctx._writeIteratorToPickleFile(iterator, tempFile.name) # Read the data into Python and deserialize it: with open(tempFile.name, 'rb') as tempFile: From 9970926ede0d5a719b8f22e97977804d3c811e97 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 1 Feb 2013 14:07:34 -0800 Subject: [PATCH 259/291] formatting --- core/src/main/scala/spark/RDD.scala | 2 +- core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 210404d540..010e61dfdc 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -385,7 +385,7 @@ abstract class RDD[T: ClassManifest]( val reducePartition: Iterator[T] => Option[T] = iter => { if (iter.hasNext) { Some(iter.reduceLeft(cleanF)) - }else { + } else { None } } diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 83641a2a84..20f2c9e489 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -32,7 +32,7 @@ private[spark] object ShuffleMapTask { return old } else { val out = new ByteArrayOutputStream - val ser = SparkEnv.get.closureSerializer.newInstance + val ser = SparkEnv.get.closureSerializer.newInstance() val objOut = ser.serializeStream(new GZIPOutputStream(out)) objOut.writeObject(rdd) objOut.writeObject(dep) @@ -48,7 +48,7 @@ private[spark] object ShuffleMapTask { synchronized { val loader = Thread.currentThread.getContextClassLoader val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance + val ser = SparkEnv.get.closureSerializer.newInstance() val objIn = ser.deserializeStream(in) val rdd = objIn.readObject().asInstanceOf[RDD[_]] val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]] From 8b3041c7233011c4a96fab045a86df91eae7b6f3 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 1 Feb 2013 15:38:42 -0800 Subject: [PATCH 260/291] Reduced the memory usage of reduce and similar operations These operations used to wait for all the results to be available in an array on the driver program before merging them. They now merge values incrementally as they arrive. 
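The gist of the change: instead of materializing every partition's result in a driver-side array and folding it afterwards, the job passes a resultHandler callback that merges each result as it arrives, so only the running total is retained. A simplified, self-contained sketch of that shape; runJobSketch is a made-up stand-in for the scheduler that walks partitions sequentially, whereas the real one delivers results asynchronously:

    object IncrementalMergeSketch {
      // Applies func to each partition and hands every result to resultHandler
      // immediately instead of buffering all results in an array.
      def runJobSketch[T, U](partitions: Seq[Iterator[T]])
                            (func: Iterator[T] => U)
                            (resultHandler: (Int, U) => Unit): Unit = {
        for ((iter, index) <- partitions.zipWithIndex) {
          resultHandler(index, func(iter))
        }
      }

      def main(args: Array[String]): Unit = {
        val partitions = Seq(Iterator(1, 2, 3), Iterator(4, 5), Iterator[Int]())

        // reduce-style merge: only the running total is ever held on the driver.
        var jobResult: Option[Int] = None
        runJobSketch(partitions)(it => if (it.hasNext) Some(it.reduceLeft(_ + _)) else None) {
          (_, taskResult) =>
            taskResult.foreach(v => jobResult = Some(jobResult.fold(v)(_ + v)))
        }
        println(jobResult.getOrElse(sys.error("empty collection")))   // prints 15
      }
    }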
--- .../main/scala/spark/PairRDDFunctions.scala | 4 +- core/src/main/scala/spark/RDD.scala | 41 ++++++++----- core/src/main/scala/spark/SparkContext.scala | 61 +++++++++++++++---- core/src/main/scala/spark/Utils.scala | 8 +++ .../partial/ApproximateActionListener.scala | 4 +- .../scala/spark/scheduler/DAGScheduler.scala | 15 +++-- .../scala/spark/scheduler/JobResult.scala | 2 +- .../scala/spark/scheduler/JobWaiter.scala | 14 +++-- core/src/test/scala/spark/RDDSuite.scala | 12 ++-- 9 files changed, 111 insertions(+), 50 deletions(-) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 231e23a7de..cc3cca2571 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -465,7 +465,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( val res = self.context.runJob(self, process _, Array(index), false) res(0) case None => - self.filter(_._1 == key).map(_._2).collect + self.filter(_._1 == key).map(_._2).collect() } } @@ -590,7 +590,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( var count = 0 while(iter.hasNext) { - val record = iter.next + val record = iter.next() count += 1 writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) } diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 010e61dfdc..9d6ea782bd 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -389,16 +389,18 @@ abstract class RDD[T: ClassManifest]( None } } - val options = sc.runJob(this, reducePartition) - val results = new ArrayBuffer[T] - for (opt <- options; elem <- opt) { - results += elem - } - if (results.size == 0) { - throw new UnsupportedOperationException("empty collection") - } else { - return results.reduceLeft(cleanF) + var jobResult: Option[T] = None + val mergeResult = (index: Int, taskResult: Option[T]) => { + if (taskResult != None) { + jobResult = jobResult match { + case Some(value) => Some(f(value, taskResult.get)) + case None => taskResult + } + } } + sc.runJob(this, reducePartition, mergeResult) + // Get the final result out of our Option, or throw an exception if the RDD was empty + jobResult.getOrElse(throw new UnsupportedOperationException("empty collection")) } /** @@ -408,9 +410,13 @@ abstract class RDD[T: ClassManifest]( * modify t2. */ def fold(zeroValue: T)(op: (T, T) => T): T = { + // Clone the zero value since we will also be serializing it as part of tasks + var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) val cleanOp = sc.clean(op) - val results = sc.runJob(this, (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp)) - return results.fold(zeroValue)(cleanOp) + val foldPartition = (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp) + val mergeResult = (index: Int, taskResult: T) => jobResult = op(jobResult, taskResult) + sc.runJob(this, foldPartition, mergeResult) + jobResult } /** @@ -422,11 +428,14 @@ abstract class RDD[T: ClassManifest]( * allocation. 
*/ def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = { + // Clone the zero value since we will also be serializing it as part of tasks + var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) val cleanSeqOp = sc.clean(seqOp) val cleanCombOp = sc.clean(combOp) - val results = sc.runJob(this, - (iter: Iterator[T]) => iter.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)) - return results.fold(zeroValue)(cleanCombOp) + val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) + val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult) + sc.runJob(this, aggregatePartition, mergeResult) + jobResult } /** @@ -437,7 +446,7 @@ abstract class RDD[T: ClassManifest]( var result = 0L while (iter.hasNext) { result += 1L - iter.next + iter.next() } result }).sum @@ -452,7 +461,7 @@ abstract class RDD[T: ClassManifest]( var result = 0L while (iter.hasNext) { result += 1L - iter.next + iter.next() } result } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index b0d4b58240..ddbf8f95d9 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -543,10 +543,30 @@ class SparkContext( } /** - * Run a function on a given set of partitions in an RDD and return the results. This is the main - * entry point to the scheduler, by which all actions get launched. The allowLocal flag specifies - * whether the scheduler can run the computation on the driver rather than shipping it out to the - * cluster, for short actions like first(). + * Run a function on a given set of partitions in an RDD and pass the results to the given + * handler function. This is the main entry point for all actions in Spark. The allowLocal + * flag specifies whether the scheduler can run the computation on the driver rather than + * shipping it out to the cluster, for short actions like first(). + */ + def runJob[T, U: ClassManifest]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + allowLocal: Boolean, + resultHandler: (Int, U) => Unit) { + val callSite = Utils.getSparkCallSite + logInfo("Starting job: " + callSite) + val start = System.nanoTime + val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler) + logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") + rdd.doCheckpoint() + result + } + + /** + * Run a function on a given set of partitions in an RDD and return the results as an array. The + * allowLocal flag specifies whether the scheduler can run the computation on the driver rather + * than shipping it out to the cluster, for short actions like first(). 
*/ def runJob[T, U: ClassManifest]( rdd: RDD[T], @@ -554,13 +574,9 @@ class SparkContext( partitions: Seq[Int], allowLocal: Boolean ): Array[U] = { - val callSite = Utils.getSparkCallSite - logInfo("Starting job: " + callSite) - val start = System.nanoTime - val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal) - logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") - rdd.doCheckpoint() - result + val results = new Array[U](partitions.size) + runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res) + results } /** @@ -590,6 +606,29 @@ class SparkContext( runJob(rdd, func, 0 until rdd.splits.size, false) } + /** + * Run a job on all partitions in an RDD and pass the results to a handler function. + */ + def runJob[T, U: ClassManifest]( + rdd: RDD[T], + processPartition: (TaskContext, Iterator[T]) => U, + resultHandler: (Int, U) => Unit) + { + runJob[T, U](rdd, processPartition, 0 until rdd.splits.size, false, resultHandler) + } + + /** + * Run a job on all partitions in an RDD and pass the results to a handler function. + */ + def runJob[T, U: ClassManifest]( + rdd: RDD[T], + processPartition: Iterator[T] => U, + resultHandler: (Int, U) => Unit) + { + val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter) + runJob[T, U](rdd, processFunc, 0 until rdd.splits.size, false, resultHandler) + } + /** * Run a job that can return approximate results. */ diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 1e58d01273..28d643abca 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -12,6 +12,7 @@ import scala.io.Source import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder import scala.Some +import spark.serializer.SerializerInstance /** * Various utility methods used by Spark. @@ -446,4 +447,11 @@ private object Utils extends Logging { socket.close() portBound } + + /** + * Clone an object using a Spark serializer. + */ + def clone[T](value: T, serializer: SerializerInstance): T = { + serializer.deserialize[T](serializer.serialize(value)) + } } diff --git a/core/src/main/scala/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/spark/partial/ApproximateActionListener.scala index 42f46e06ed..24b4909380 100644 --- a/core/src/main/scala/spark/partial/ApproximateActionListener.scala +++ b/core/src/main/scala/spark/partial/ApproximateActionListener.scala @@ -32,7 +32,7 @@ private[spark] class ApproximateActionListener[T, U, R]( if (finishedTasks == totalTasks) { // If we had already returned a PartialResult, set its final value resultObject.foreach(r => r.setFinalValue(evaluator.currentResult())) - // Notify any waiting thread that may have called getResult + // Notify any waiting thread that may have called awaitResult this.notifyAll() } } @@ -49,7 +49,7 @@ private[spark] class ApproximateActionListener[T, U, R]( * Waits for up to timeout milliseconds since the listener was created and then returns a * PartialResult with the result so far. This may be complete if the whole job is done. 
*/ - def getResult(): PartialResult[R] = synchronized { + def awaitResult(): PartialResult[R] = synchronized { val finishTime = startTime + timeout while (true) { val time = System.currentTimeMillis() diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 14f61f7e87..908a22b2df 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -203,18 +203,17 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], callSite: String, - allowLocal: Boolean) - : Array[U] = + allowLocal: Boolean, + resultHandler: (Int, U) => Unit) { if (partitions.size == 0) { - return new Array[U](0) + return } - val waiter = new JobWaiter(partitions.size) + val waiter = new JobWaiter(partitions.size, resultHandler) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter)) - waiter.getResult() match { - case JobSucceeded(results: Seq[_]) => - return results.asInstanceOf[Seq[U]].toArray + waiter.awaitResult() match { + case JobSucceeded => {} case JobFailed(exception: Exception) => logInfo("Failed to run " + callSite) throw exception @@ -233,7 +232,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.splits.size).toArray eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener)) - return listener.getResult() // Will throw an exception if the job fails + return listener.awaitResult() // Will throw an exception if the job fails } /** diff --git a/core/src/main/scala/spark/scheduler/JobResult.scala b/core/src/main/scala/spark/scheduler/JobResult.scala index c4a74e526f..654131ee84 100644 --- a/core/src/main/scala/spark/scheduler/JobResult.scala +++ b/core/src/main/scala/spark/scheduler/JobResult.scala @@ -5,5 +5,5 @@ package spark.scheduler */ private[spark] sealed trait JobResult -private[spark] case class JobSucceeded(results: Seq[_]) extends JobResult +private[spark] case object JobSucceeded extends JobResult private[spark] case class JobFailed(exception: Exception) extends JobResult diff --git a/core/src/main/scala/spark/scheduler/JobWaiter.scala b/core/src/main/scala/spark/scheduler/JobWaiter.scala index b3d4feebe5..3cc6a86345 100644 --- a/core/src/main/scala/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/spark/scheduler/JobWaiter.scala @@ -3,10 +3,12 @@ package spark.scheduler import scala.collection.mutable.ArrayBuffer /** - * An object that waits for a DAGScheduler job to complete. + * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their + * results to the given handler function. */ -private[spark] class JobWaiter(totalTasks: Int) extends JobListener { - private val taskResults = ArrayBuffer.fill[Any](totalTasks)(null) +private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit) + extends JobListener { + private var finishedTasks = 0 private var jobFinished = false // Is the job as a whole finished (succeeded or failed)? 
@@ -17,11 +19,11 @@ private[spark] class JobWaiter(totalTasks: Int) extends JobListener { if (jobFinished) { throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter") } - taskResults(index) = result + resultHandler(index, result.asInstanceOf[T]) finishedTasks += 1 if (finishedTasks == totalTasks) { jobFinished = true - jobResult = JobSucceeded(taskResults) + jobResult = JobSucceeded this.notifyAll() } } @@ -38,7 +40,7 @@ private[spark] class JobWaiter(totalTasks: Int) extends JobListener { } } - def getResult(): JobResult = synchronized { + def awaitResult(): JobResult = synchronized { while (!jobFinished) { this.wait() } diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index ed03e65153..95d2e62730 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -12,9 +12,9 @@ class RDDSuite extends FunSuite with LocalSparkContext { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) assert(nums.collect().toList === List(1, 2, 3, 4)) val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) - assert(dups.distinct.count === 4) - assert(dups.distinct().collect === dups.distinct.collect) - assert(dups.distinct(2).collect === dups.distinct.collect) + assert(dups.distinct().count === 4) + assert(dups.distinct().collect === dups.distinct().collect) + assert(dups.distinct(2).collect === dups.distinct().collect) assert(nums.reduce(_ + _) === 10) assert(nums.fold(0)(_ + _) === 10) assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4")) @@ -31,6 +31,10 @@ class RDDSuite extends FunSuite with LocalSparkContext { case(split, iter) => Iterator((split, iter.reduceLeft(_ + _))) } assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7))) + + intercept[UnsupportedOperationException] { + nums.filter(_ > 5).reduce(_ + _) + } } test("SparkContext.union") { @@ -164,7 +168,7 @@ class RDDSuite extends FunSuite with LocalSparkContext { // Note that split number starts from 0, so > 8 means only 10th partition left. val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8) assert(prunedRdd.splits.size === 1) - val prunedData = prunedRdd.collect + val prunedData = prunedRdd.collect() assert(prunedData.size === 1) assert(prunedData(0) === 10) } From 12c1eb47568060efac57d6df7df7e5704a8d3fab Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Fri, 1 Feb 2013 21:21:44 -0600 Subject: [PATCH 261/291] Reduce the amount of duplicate logging Akka does to stdout. Given we have Akka logging go through SLF4j to log4j, we don't need all the extra noise of Akka's stdout logger that is supposedly only used during Akka init time but seems to continue logging lots of noisy network events that we either don't care about or are in the log4j logs anyway. 
See: http://doc.akka.io/docs/akka/2.0/general/configuration.html # Log level for the very basic logger activated during AkkaApplication startup # Options: ERROR, WARNING, INFO, DEBUG # stdout-loglevel = "WARNING" --- core/src/main/scala/spark/util/AkkaUtils.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index e0fdeffbc4..e43fbd6b1c 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -30,6 +30,7 @@ private[spark] object AkkaUtils { val akkaConf = ConfigFactory.parseString(""" akka.daemonic = on akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"] + akka.stdout-loglevel = "ERROR" akka.actor.provider = "akka.remote.RemoteActorRefProvider" akka.remote.transport = "akka.remote.netty.NettyRemoteTransport" akka.remote.log-remote-lifecycle-events = on From ae26911ec0d768dcdae8b7d706ca4544e36535e6 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 1 Feb 2013 21:07:24 -0800 Subject: [PATCH 262/291] Add back test for distinct without parens --- core/src/test/scala/spark/RDDSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 95d2e62730..89a3687386 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -12,7 +12,8 @@ class RDDSuite extends FunSuite with LocalSparkContext { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) assert(nums.collect().toList === List(1, 2, 3, 4)) val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) - assert(dups.distinct().count === 4) + assert(dups.distinct().count() === 4) + assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses? assert(dups.distinct().collect === dups.distinct().collect) assert(dups.distinct(2).collect === dups.distinct().collect) assert(nums.reduce(_ + _) === 10) From 1fd5ee323d127499bb3f173d4142c37532ec29b2 Mon Sep 17 00:00:00 2001 From: Charles Reiss Date: Fri, 1 Feb 2013 22:33:38 -0800 Subject: [PATCH 263/291] Code review changes: add sc.stop; style of multiline comments; parens on procedure calls. --- .../spark/scheduler/DAGSchedulerSuite.scala | 69 +++++++++++++------ 1 file changed, 47 insertions(+), 22 deletions(-) diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala index c31e2e7064..adce1f38bb 100644 --- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala @@ -31,7 +31,7 @@ import spark.TaskEndReason import spark.{FetchFailed, Success} /** - * Tests for DAGScheduler. These tests directly call the event processing functinos in DAGScheduler + * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler * rather than spawning an event loop thread as happens in the real code. They use EasyMock * to mock out two classes that DAGScheduler interacts with: TaskScheduler (to which TaskSets are * submitted) and BlockManagerMaster (from which cache locations are retrieved and to which dead @@ -56,29 +56,34 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar var schedulerThread: Thread = null var schedulerException: Throwable = null - /** Set of EasyMock argument matchers that match a TaskSet for a given RDD. 
+ /** + * Set of EasyMock argument matchers that match a TaskSet for a given RDD. * We cache these so we do not create duplicate matchers for the same RDD. * This allows us to easily setup a sequence of expectations for task sets for * that RDD. */ val taskSetMatchers = new HashMap[MyRDD, IArgumentMatcher] - /** Set of cache locations to return from our mock BlockManagerMaster. + /** + * Set of cache locations to return from our mock BlockManagerMaster. * Keys are (rdd ID, partition ID). Anything not present will return an empty * list of cache locations silently. */ val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]] - /** JobWaiter for the last JobSubmitted event we pushed. To keep tests (most of which + /** + * JobWaiter for the last JobSubmitted event we pushed. To keep tests (most of which * will only submit one job) from needing to explicitly track it. */ var lastJobWaiter: JobWaiter = null - /** Tell EasyMockSugar what mock objects we want to be configured by expecting {...} + /** + * Tell EasyMockSugar what mock objects we want to be configured by expecting {...} * and whenExecuting {...} */ implicit val mocks = MockObjects(taskScheduler, blockManagerMaster) - /** Utility function to reset mocks and set expectations on them. EasyMock wants mock objects + /** + * Utility function to reset mocks and set expectations on them. EasyMock wants mock objects * to be reset after each time their expectations are set, and we tend to check mock object * calls over a single call to DAGScheduler. * @@ -115,17 +120,21 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar whenExecuting { scheduler.stop() } + sc.stop() System.clearProperty("spark.master.port") } def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) - /** Type of RDD we use for testing. Note that we should never call the real RDD compute methods. - * This is a pair RDD type so it can always be used in ShuffleDependencies. */ + /** + * Type of RDD we use for testing. Note that we should never call the real RDD compute methods. + * This is a pair RDD type so it can always be used in ShuffleDependencies. + */ type MyRDD = RDD[(Int, Int)] - /** Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and + /** + * Create an RDD for passing to DAGScheduler. These RDDs will use the dependencies and * preferredLocations (if any) that are passed to them. They are deliberately not executable * so we can test that DAGScheduler does not try to execute RDDs locally. */ @@ -150,7 +159,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } } - /** EasyMock matcher method. For use as an argument matcher for a TaskSet whose first task + /** + * EasyMock matcher method. For use as an argument matcher for a TaskSet whose first task * is from a particular RDD. */ def taskSetForRdd(rdd: MyRDD): TaskSet = { @@ -172,7 +182,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar return null } - /** Setup an EasyMock expectation to repsond to blockManagerMaster.getLocations() called from + /** + * Setup an EasyMock expectation to repsond to blockManagerMaster.getLocations() called from * cacheLocations. 
*/ def expectGetLocations(): Unit = { @@ -197,7 +208,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar }).anyTimes() } - /** Process the supplied event as if it were the top of the DAGScheduler event queue, expecting + /** + * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting * the scheduler not to exit. * * After processing the event, submit waiting stages as is done on most iterations of the @@ -208,7 +220,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar scheduler.submitWaitingStages() } - /** Expect a TaskSet for the specified RDD to be submitted to the TaskScheduler. Should be + /** + * Expect a TaskSet for the specified RDD to be submitted to the TaskScheduler. Should be * called from a resetExpecting { ... } block. * * Returns a easymock Capture that will contain the task set after the stage is submitted. @@ -220,7 +233,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar return taskSetCapture } - /** Expect the supplied code snippet to submit a stage for the specified RDD. + /** + * Expect the supplied code snippet to submit a stage for the specified RDD. * Return the resulting TaskSet. First marks all the tasks are belonging to the * current MapOutputTracker generation. */ @@ -239,7 +253,9 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar return taskSet } - /** Send the given CompletionEvent messages for the tasks in the TaskSet. */ + /** + * Send the given CompletionEvent messages for the tasks in the TaskSet. + */ def respondToTaskSet(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) { assert(taskSet.tasks.size >= results.size) for ((result, i) <- results.zipWithIndex) { @@ -249,7 +265,9 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } } - /** Assert that the supplied TaskSet has exactly the given preferredLocations. */ + /** + * Assert that the supplied TaskSet has exactly the given preferredLocations. + */ def expectTaskSetLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) { assert(locations.size === taskSet.tasks.size) for ((expectLocs, taskLocs) <- @@ -258,7 +276,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } } - /** When we submit dummy Jobs, this is the compute function we supply. Except in a local test + /** + * When we submit dummy Jobs, this is the compute function we supply. Except in a local test * below, we do not expect this function to ever be executed; instead, we will return results * directly through CompletionEvents. */ @@ -266,8 +285,10 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar it.next._1.asInstanceOf[Int] - /** Start a job to compute the given RDD. Returns the JobWaiter that will - * collect the result of the job via callbacks from DAGScheduler. */ + /** + * Start a job to compute the given RDD. Returns the JobWaiter that will + * collect the result of the job via callbacks from DAGScheduler. + */ def submitRdd(rdd: MyRDD, allowLocal: Boolean = false): JobWaiter = { val (toSubmit, waiter) = scheduler.prepareJob[(Int, Int), Int]( rdd, @@ -281,7 +302,9 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar return waiter } - /** Assert that a job we started has failed. */ + /** + * Assert that a job we started has failed. 
+ */ def expectJobException(waiter: JobWaiter = lastJobWaiter) { waiter.getResult match { case JobSucceeded(_) => fail() @@ -289,7 +312,9 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } } - /** Assert that a job we started has succeeded and has the given result. */ + /** + * Assert that a job we started has succeeded and has the given result. + */ def expectJobResult(expected: Array[Int], waiter: JobWaiter = lastJobWaiter) { waiter.getResult match { case JobSucceeded(answer) => @@ -500,7 +525,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar )) } val recomputeOne = interceptStage(shuffleOneRdd) { - scheduler.resubmitFailedStages + scheduler.resubmitFailedStages() } val recomputeTwo = interceptStage(shuffleTwoRdd) { respondToTaskSet(recomputeOne, List( From 28e0cb9f312b7fb1b0236fd15ba0dd2f423e826d Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Sat, 2 Feb 2013 01:11:37 -0600 Subject: [PATCH 264/291] Fix createActorSystem not actually using the systemName parameter. This meant all system names were "spark", which worked, but didn't lead to the most intuitive log output. This fixes createActorSystem to use the passed system name, and refactors Master/Worker to encapsulate their system/actor names instead of having the clients guess at them. Note that the driver system name, "spark", is left as is, and is still repeated a few times, but that seems like a separate issue. --- .../spark/deploy/LocalSparkCluster.scala | 38 +++++------- .../scala/spark/deploy/client/Client.scala | 13 +---- .../scala/spark/deploy/master/Master.scala | 24 +++++++- .../scala/spark/deploy/worker/Worker.scala | 58 +++++++++---------- .../spark/storage/BlockManagerMaster.scala | 2 - .../src/main/scala/spark/util/AkkaUtils.scala | 6 +- 6 files changed, 68 insertions(+), 73 deletions(-) diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala index 2836574ecb..22319a96ca 100644 --- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala @@ -18,35 +18,23 @@ import scala.collection.mutable.ArrayBuffer private[spark] class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging { - val localIpAddress = Utils.localIpAddress + private val localIpAddress = Utils.localIpAddress + private val masterActorSystems = ArrayBuffer[ActorSystem]() + private val workerActorSystems = ArrayBuffer[ActorSystem]() - var masterActor : ActorRef = _ - var masterActorSystem : ActorSystem = _ - var masterPort : Int = _ - var masterUrl : String = _ - - val workerActorSystems = ArrayBuffer[ActorSystem]() - val workerActors = ArrayBuffer[ActorRef]() - - def start() : String = { + def start(): String = { logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") /* Start the Master */ - val (actorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0) - masterActorSystem = actorSystem - masterUrl = "spark://" + localIpAddress + ":" + masterPort - masterActor = masterActorSystem.actorOf( - Props(new Master(localIpAddress, masterPort, 0)), name = "Master") + val (masterSystem, masterPort) = Master.startSystemAndActor(localIpAddress, 0, 0) + masterActorSystems += masterSystem + val masterUrl = "spark://" + localIpAddress + ":" + masterPort - /* Start the Slaves */ + /* Start the Workers */ for (workerNum <- 1 to numWorkers) { - val (actorSystem, boundPort) = - 
AkkaUtils.createActorSystem("sparkWorker" + workerNum, localIpAddress, 0) - workerActorSystems += actorSystem - val actor = actorSystem.actorOf( - Props(new Worker(localIpAddress, boundPort, 0, coresPerWorker, memoryPerWorker, masterUrl)), - name = "Worker") - workerActors += actor + val (workerSystem, _) = Worker.startSystemAndActor(localIpAddress, 0, 0, coresPerWorker, + memoryPerWorker, masterUrl, null, Some(workerNum)) + workerActorSystems += workerSystem } return masterUrl @@ -57,7 +45,7 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I // Stop the workers before the master so they don't get upset that it disconnected workerActorSystems.foreach(_.shutdown()) workerActorSystems.foreach(_.awaitTermination()) - masterActorSystem.shutdown() - masterActorSystem.awaitTermination() + masterActorSystems.foreach(_.shutdown()) + masterActorSystems.foreach(_.awaitTermination()) } } diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala index 90fe9508cd..a63eee1233 100644 --- a/core/src/main/scala/spark/deploy/client/Client.scala +++ b/core/src/main/scala/spark/deploy/client/Client.scala @@ -9,6 +9,7 @@ import spark.{SparkException, Logging} import akka.remote.RemoteClientLifeCycleEvent import akka.remote.RemoteClientShutdown import spark.deploy.RegisterJob +import spark.deploy.master.Master import akka.remote.RemoteClientDisconnected import akka.actor.Terminated import akka.dispatch.Await @@ -24,26 +25,18 @@ private[spark] class Client( listener: ClientListener) extends Logging { - val MASTER_REGEX = "spark://([^:]+):([0-9]+)".r - var actor: ActorRef = null var jobId: String = null - if (MASTER_REGEX.unapplySeq(masterUrl) == None) { - throw new SparkException("Invalid master URL: " + masterUrl) - } - class ClientActor extends Actor with Logging { var master: ActorRef = null var masterAddress: Address = null var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times override def preStart() { - val Seq(masterHost, masterPort) = MASTER_REGEX.unapplySeq(masterUrl).get - logInfo("Connecting to master spark://" + masterHost + ":" + masterPort) - val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort) + logInfo("Connecting to master " + masterUrl) try { - master = context.actorFor(akkaUrl) + master = context.actorFor(Master.toAkkaUrl(masterUrl)) masterAddress = master.path.address master ! 
RegisterJob(jobDescription) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index c618e87cdd..92e7914b1b 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -262,11 +262,29 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor } private[spark] object Master { + private val systemName = "sparkMaster" + private val actorName = "Master" + private val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r + def main(argStrings: Array[String]) { val args = new MasterArguments(argStrings) - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port) - val actor = actorSystem.actorOf( - Props(new Master(args.ip, boundPort, args.webUiPort)), name = "Master") + val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort) actorSystem.awaitTermination() } + + /** Returns an `akka://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */ + def toAkkaUrl(sparkUrl: String): String = { + sparkUrl match { + case sparkUrlRegex(host, port) => + "akka://%s@%s:%s/user/%s".format(systemName, host, port, actorName) + case _ => + throw new SparkException("Invalid master URL: " + sparkUrl) + } + } + + def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int) = { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port) + val actor = actorSystem.actorOf(Props(new Master(host, boundPort, webUiPort)), name = actorName) + (actorSystem, boundPort) + } } diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 8b41620d98..2219dd6262 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -1,7 +1,7 @@ package spark.deploy.worker import scala.collection.mutable.{ArrayBuffer, HashMap} -import akka.actor.{ActorRef, Props, Actor} +import akka.actor.{ActorRef, Props, Actor, ActorSystem} import spark.{Logging, Utils} import spark.util.AkkaUtils import spark.deploy._ @@ -13,6 +13,7 @@ import akka.remote.RemoteClientDisconnected import spark.deploy.RegisterWorker import spark.deploy.LaunchExecutor import spark.deploy.RegisterWorkerFailed +import spark.deploy.master.Master import akka.actor.Terminated import java.io.File @@ -27,7 +28,6 @@ private[spark] class Worker( extends Actor with Logging { val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs - val MASTER_REGEX = "spark://([^:]+):([0-9]+)".r var master: ActorRef = null var masterWebUiUrl : String = "" @@ -48,11 +48,7 @@ private[spark] class Worker( def memoryFree: Int = memory - memoryUsed def createWorkDir() { - workDir = if (workDirPath != null) { - new File(workDirPath) - } else { - new File(sparkHome, "work") - } + workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work")) try { if (!workDir.exists() && !workDir.mkdirs()) { logError("Failed to create work directory " + workDir) @@ -68,8 +64,7 @@ private[spark] class Worker( override def preStart() { logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( ip, port, cores, Utils.memoryMegabytesToString(memory))) - val envVar = System.getenv("SPARK_HOME") - sparkHome = new File(if (envVar == null) "." 
else envVar) + sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse(".")) logInfo("Spark home: " + sparkHome) createWorkDir() connectToMaster() @@ -77,24 +72,15 @@ private[spark] class Worker( } def connectToMaster() { - masterUrl match { - case MASTER_REGEX(masterHost, masterPort) => { - logInfo("Connecting to master spark://" + masterHost + ":" + masterPort) - val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort) - try { - master = context.actorFor(akkaUrl) - master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress) - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - context.watch(master) // Doesn't work with remote actors, but useful for testing - } catch { - case e: Exception => - logError("Failed to connect to master", e) - System.exit(1) - } - } - - case _ => - logError("Invalid master URL: " + masterUrl) + logInfo("Connecting to master " + masterUrl) + try { + master = context.actorFor(Master.toAkkaUrl(masterUrl)) + master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress) + context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) + context.watch(master) // Doesn't work with remote actors, but useful for testing + } catch { + case e: Exception => + logError("Failed to connect to master", e) System.exit(1) } } @@ -183,11 +169,19 @@ private[spark] class Worker( private[spark] object Worker { def main(argStrings: Array[String]) { val args = new WorkerArguments(argStrings) - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port) - val actor = actorSystem.actorOf( - Props(new Worker(args.ip, boundPort, args.webUiPort, args.cores, args.memory, - args.master, args.workDir)), - name = "Worker") + val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort, args.cores, + args.memory, args.master, args.workDir) actorSystem.awaitTermination() } + + def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int, + masterUrl: String, workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = { + // The LocalSparkCluster runs multiple local sparkWorkerX actor systems + val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") + val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port) + val actor = actorSystem.actorOf(Props(new Worker(host, boundPort, webUiPort, cores, memory, + masterUrl, workDir)), name = "Worker") + (actorSystem, boundPort) + } + } diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 36398095a2..7be6b9fa87 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -27,8 +27,6 @@ private[spark] class BlockManagerMaster( val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt val DRIVER_AKKA_ACTOR_NAME = "BlockMasterManager" - val SLAVE_AKKA_ACTOR_NAME = "BlockSlaveManager" - val DEFAULT_MANAGER_IP: String = Utils.localHostName() val timeout = 10.seconds var driverActor: ActorRef = { diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index e0fdeffbc4..3a3626e8a0 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -18,9 +18,13 @@ import java.util.concurrent.TimeoutException * Various 
utility classes for working with Akka. */ private[spark] object AkkaUtils { + /** * Creates an ActorSystem ready for remoting, with various Spark features. Returns both the * ActorSystem itself and its port (which is hard to get from Akka). + * + * Note: the `name` parameter is important, as even if a client sends a message to right + * host + port, if the system name is incorrect, Akka will drop the message. */ def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = { val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt @@ -41,7 +45,7 @@ private[spark] object AkkaUtils { akka.actor.default-dispatcher.throughput = %d """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize)) - val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader) + val actorSystem = ActorSystem(name, akkaConf, getClass.getClassLoader) // Figure out the port number we bound to, in case port was passed as 0. This is a bit of a // hack because Akka doesn't let you figure out the port through the public API yet. From 696eec32c982ca516c506de33f383a173bcbd131 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Sat, 2 Feb 2013 02:03:26 -0600 Subject: [PATCH 265/291] Move executorMemory up into SchedulerBackend. --- .../spark/scheduler/cluster/SchedulerBackend.scala | 12 ++++++++++++ .../cluster/SparkDeploySchedulerBackend.scala | 9 --------- .../mesos/CoarseMesosSchedulerBackend.scala | 10 ---------- .../scheduler/mesos/MesosSchedulerBackend.scala | 10 ---------- 4 files changed, 12 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala index ddcd64d7c6..9ac875de3a 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala @@ -1,5 +1,7 @@ package spark.scheduler.cluster +import spark.Utils + /** * A backend interface for cluster scheduling systems that allows plugging in different ones under * ClusterScheduler. 
We assume a Mesos-like model where the application gets resource offers as @@ -11,5 +13,15 @@ private[spark] trait SchedulerBackend { def reviveOffers(): Unit def defaultParallelism(): Int + // Memory used by each executor (in megabytes) + protected val executorMemory = { + // TODO: Might need to add some extra memory for the non-heap parts of the JVM + Option(System.getProperty("spark.executor.memory")) + .orElse(Option(System.getenv("SPARK_MEM"))) + .map(Utils.memoryStringToMb) + .getOrElse(512) + } + + // TODO: Probably want to add a killTask too } diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 2f7099c5b9..59ff8bcb90 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -20,15 +20,6 @@ private[spark] class SparkDeploySchedulerBackend( val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt - // Memory used by each executor (in megabytes) - val executorMemory = { - // TODO: Might need to add some extra memory for the non-heap parts of the JVM - Option(System.getProperty("spark.executor.memory")) - .orElse(Option(System.getenv("SPARK_MEM"))) - .map(Utils.memoryStringToMb) - .getOrElse(512) - } - override def start() { super.start() diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala index 7bf56a05d6..b481ec0a72 100644 --- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala @@ -35,16 +35,6 @@ private[spark] class CoarseMesosSchedulerBackend( val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures - // Memory used by each executor (in megabytes) - val executorMemory = { - if (System.getenv("SPARK_MEM") != null) { - Utils.memoryStringToMb(System.getenv("SPARK_MEM")) - // TODO: Might need to add some extra memory for the non-heap parts of the JVM - } else { - 512 - } - } - // Lock used to wait for scheduler to be registered var isRegistered = false val registeredLock = new Object() diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala index eab1c60e0b..5c8b531de3 100644 --- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala @@ -29,16 +29,6 @@ private[spark] class MesosSchedulerBackend( with MScheduler with Logging { - // Memory used by each executor (in megabytes) - val EXECUTOR_MEMORY = { - if (System.getenv("SPARK_MEM") != null) { - Utils.memoryStringToMb(System.getenv("SPARK_MEM")) - // TODO: Might need to add some extra memory for the non-heap parts of the JVM - } else { - 512 - } - } - // Lock used to wait for scheduler to be registered var isRegistered = false val registeredLock = new Object() From cae8a6795c7f454b74c8d3c4425a6ced151d6d9b Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Sat, 2 Feb 2013 02:15:39 -0600 Subject: [PATCH 266/291] Fix dangling old variable names. 
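For reference, the resolution order that patch 265 hoists into SchedulerBackend (and whose callers this follow-up renames) is: the spark.executor.memory system property, then the SPARK_MEM environment variable, then a 512 MB default. A minimal standalone sketch, with the megabyte parser passed in rather than assuming Utils' exact signature:

    // Sketch of the executor-memory resolution shown in the diff above.
    // `memoryStringToMb` stands in for Utils.memoryStringToMb.
    def resolveExecutorMemoryMb(memoryStringToMb: String => Int): Int =
      Option(System.getProperty("spark.executor.memory"))
        .orElse(Option(System.getenv("SPARK_MEM")))
        .map(memoryStringToMb)
        .getOrElse(512)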
--- .../scala/spark/scheduler/mesos/MesosSchedulerBackend.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala index 5c8b531de3..300766d0f5 100644 --- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala @@ -79,7 +79,7 @@ private[spark] class MesosSchedulerBackend( val memory = Resource.newBuilder() .setName("mem") .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(EXECUTOR_MEMORY).build()) + .setScalar(Value.Scalar.newBuilder().setValue(executorMemory).build()) .build() val command = CommandInfo.newBuilder() .setValue(execScript) @@ -151,7 +151,7 @@ private[spark] class MesosSchedulerBackend( def enoughMemory(o: Offer) = { val mem = getResource(o.getResourcesList, "mem") val slaveId = o.getSlaveId.getValue - mem >= EXECUTOR_MEMORY || slaveIdsWithExecutors.contains(slaveId) + mem >= executorMemory || slaveIdsWithExecutors.contains(slaveId) } for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) { From 7aba123f0c0fd024105462b3a0b203cd357c67e9 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Sat, 2 Feb 2013 13:53:28 -0600 Subject: [PATCH 267/291] Further simplify checking for Nil. --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index b62b25f688..2a646dd0f5 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -176,9 +176,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with def visit(rdd: RDD[_]) { if (!visited(rdd)) { visited += rdd - val locs = getCacheLocs(rdd) - val atLeastOneMissing = (0 until rdd.splits.size).exists(locs(_) == Nil) - if (atLeastOneMissing) { + if (getCacheLocs(rdd).contains(Nil)) { for (dep <- rdd.dependencies) { dep match { case shufDep: ShuffleDependency[_,_] => From 34a7bcdb3a19deed18b25225daf47ff22ee20869 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 2 Feb 2013 19:40:30 -0800 Subject: [PATCH 268/291] Formatting --- .../main/scala/spark/scheduler/DAGScheduler.scala | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 8cfc08e5ac..2a35915560 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -205,8 +205,9 @@ class DAGScheduler( missing.toList } - /** Returns (and does not) submit a JobSubmitted event suitable to run a given job, and - * a JobWaiter whose getResult() method will return the result of the job when it is complete. + /** + * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a + * JobWaiter whose getResult() method will return the result of the job when it is complete. * * The job is assumed to have at least one partition; zero partition jobs should be handled * without a JobSubmitted event. @@ -308,7 +309,8 @@ class DAGScheduler( return false } - /** Resubmit any failed stages. Ordinarily called after a small amount of time has passed since + /** + * Resubmit any failed stages. 
Ordinarily called after a small amount of time has passed since * the last fetch failure. */ private[scheduler] def resubmitFailedStages() { @@ -321,7 +323,8 @@ class DAGScheduler( } } - /** Check for waiting or failed stages which are now eligible for resubmission. + /** + * Check for waiting or failed stages which are now eligible for resubmission. * Ordinarily run on every iteration of the event loop. */ private[scheduler] def submitWaitingStages() { @@ -366,9 +369,9 @@ class DAGScheduler( // the same time, so we want to make sure we've identified all the reduce tasks that depend // on the failed node. if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) { - resubmitFailedStages + resubmitFailedStages() } else { - submitWaitingStages + submitWaitingStages() } } } From 2415c18f48fc28d88f29b88c312f98054f530f20 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 3 Feb 2013 06:44:11 +0000 Subject: [PATCH 269/291] Fix reporting of PySpark doctest failures. --- python/pyspark/context.py | 4 +++- python/pyspark/rdd.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6831f9b7f8..657fe6f989 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -256,8 +256,10 @@ def _test(): globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) globs['tempdir'] = tempfile.mkdtemp() atexit.register(lambda: shutil.rmtree(globs['tempdir'])) - doctest.testmod(globs=globs) + (failure_count, test_count) = doctest.testmod(globs=globs) globs['sc'].stop() + if failure_count: + exit(-1) if __name__ == "__main__": diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 41ea6e6e14..fb144bc45d 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -748,8 +748,10 @@ def _test(): # The small batch size here ensures that we see multiple batches, # even in these small test examples: globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) - doctest.testmod(globs=globs) + (failure_count, test_count) = doctest.testmod(globs=globs) globs['sc'].stop() + if failure_count: + exit(-1) if __name__ == "__main__": From 8fbd5380b7f36842297f624bad3a2513f7eca47b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 3 Feb 2013 06:44:49 +0000 Subject: [PATCH 270/291] Fetch fewer objects in PySpark's take() method. 
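The idea, sketched below under the simplifying assumption that an RDD is just a sequence of already-materialized partitions, is to pull partitions one at a time and take at most the number of still-missing elements from each, instead of collecting whole partitions:

    // Illustrative sketch only; `partitions` stands in for an RDD's partitions.
    def takeFromPartitions[T](partitions: Seq[Seq[T]], num: Int): Seq[T] = {
      val items = scala.collection.mutable.ArrayBuffer[T]()
      val it = partitions.iterator
      while (items.size < num && it.hasNext) {
        // Only request as many elements as are still needed.
        items ++= it.next().take(num - items.size)
      }
      items.toSeq
    }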
--- core/src/main/scala/spark/api/python/PythonRDD.scala | 11 +++++++++-- python/pyspark/rdd.py | 4 ++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 39758e94f4..ab8351e55e 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -238,6 +238,11 @@ private[spark] object PythonRDD { } def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) { + import scala.collection.JavaConverters._ + writeIteratorToPickleFile(items.asScala, filename) + } + + def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) { val file = new DataOutputStream(new FileOutputStream(filename)) for (item <- items) { writeAsPickle(item, file) @@ -245,8 +250,10 @@ private[spark] object PythonRDD { file.close() } - def takePartition[T](rdd: RDD[T], partition: Int): java.util.Iterator[T] = - rdd.context.runJob(rdd, ((x: Iterator[T]) => x), Seq(partition), true).head + def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = { + implicit val cm : ClassManifest[T] = rdd.elementClassManifest + rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator + } } private object Pickle { diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index fb144bc45d..4cda6cf661 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -372,6 +372,10 @@ class RDD(object): items = [] for partition in range(self._jrdd.splits().size()): iterator = self.ctx._takePartition(self._jrdd.rdd(), partition) + # Each item in the iterator is a string, Python object, batch of + # Python objects. Regardless, it is sufficient to take `num` + # of these objects in order to collect `num` Python objects: + iterator = iterator.take(num) items.extend(self._collect_iterator_through_file(iterator)) if len(items) >= num: break From 9163c3705d98ca19c09fe5618e347b9d20f88f63 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 2 Feb 2013 23:34:47 -0800 Subject: [PATCH 271/291] Formatting --- core/src/main/scala/spark/scheduler/DAGScheduler.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 2a35915560..edbfd1c45f 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -265,7 +265,8 @@ class DAGScheduler( return listener.awaitResult() // Will throw an exception if the job fails } - /** Process one event retrieved from the event queue. + /** + * Process one event retrieved from the event queue. * Returns true if we should stop the event loop. */ private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = { From e61729113d3bf165d1ab9bd83ea55d52fd0bb72e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 3 Feb 2013 21:29:40 -0800 Subject: [PATCH 272/291] Remove unnecessary doctest __main__ methods. 
--- python/pyspark/accumulators.py | 9 --------- python/pyspark/broadcast.py | 9 --------- 2 files changed, 18 deletions(-) diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 61fcbbd376..3e9d7d36da 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -196,12 +196,3 @@ def _start_update_server(): thread.daemon = True thread.start() return server - - -def _test(): - import doctest - doctest.testmod() - - -if __name__ == "__main__": - _test() diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 93876fa738..def810dd46 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -37,12 +37,3 @@ class Broadcast(object): def __reduce__(self): self._pickle_registry.add(self) return (_from_id, (self.bid, )) - - -def _test(): - import doctest - doctest.testmod() - - -if __name__ == "__main__": - _test() From aa4ee1e9e5485c1b96474e704c76225a2b8a7da9 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 4 Feb 2013 11:06:31 -0800 Subject: [PATCH 273/291] Fix failing test --- core/src/test/scala/spark/MapOutputTrackerSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index f4e7ec39fe..dd19442dcb 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -79,8 +79,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("remote fetch") { try { System.clearProperty("spark.driver.host") // In case some previous test had set it - val (actorSystem, boundPort) = - AkkaUtils.createActorSystem("test", "localhost", 0) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", "localhost", 0) System.setProperty("spark.driver.port", boundPort.toString) val masterTracker = new MapOutputTracker(actorSystem, true) val slaveTracker = new MapOutputTracker(actorSystem, false) From f6ec547ea7b56ee607a4c2a69206f8952318eaf1 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 4 Feb 2013 13:14:54 -0800 Subject: [PATCH 274/291] Small fix to test for distinct --- core/src/test/scala/spark/RDDSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 89a3687386..fe7deb10d6 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -14,7 +14,7 @@ class RDDSuite extends FunSuite with LocalSparkContext { val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) assert(dups.distinct().count() === 4) assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses? - assert(dups.distinct().collect === dups.distinct().collect) + assert(dups.distinct.collect === dups.distinct().collect) assert(dups.distinct(2).collect === dups.distinct().collect) assert(nums.reduce(_ + _) === 10) assert(nums.fold(0)(_ + _) === 10) From 7eea64aa4c0d6a51406e0d1b039906ee9559cd58 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 5 Feb 2013 11:41:31 -0800 Subject: [PATCH 275/291] Streaming constructor which takes JavaSparkContext It's sometimes helpful to directly pass a JavaSparkContext, and take advantage of the various constructors available for that. 
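A hypothetical usage sketch of the new constructor (the object name and application name below are made up; only JavaStreamingContext, JavaSparkContext and Seconds come from this code base):

    import spark.api.java.JavaSparkContext
    import spark.streaming.Seconds
    import spark.streaming.api.java.JavaStreamingContext

    object StreamingFromExistingContext {
      def main(args: Array[String]) {
        // Reuse an existing JavaSparkContext instead of letting the
        // streaming context construct its own SparkContext.
        val jsc = new JavaSparkContext("local[2]", "StreamingFromExistingContext")
        val jssc = new JavaStreamingContext(jsc, Seconds(1))
        // ... define input streams and output operations on jssc here ...
      }
    }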
--- .../spark/streaming/api/java/JavaStreamingContext.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index f82e6a37cc..e7f446a49b 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -33,6 +33,14 @@ class JavaStreamingContext(val ssc: StreamingContext) { def this(master: String, frameworkName: String, batchDuration: Duration) = this(new StreamingContext(master, frameworkName, batchDuration)) + /** + * Creates a StreamingContext. + * @param sparkContext The underlying JavaSparkContext to use + * @param batchDuration The time interval at which streaming data will be divided into batches + */ + def this(sparkContext: JavaSparkContext, batchDuration: Duration) = + this(new StreamingContext(sparkContext.sc, batchDuration)) + /** * Re-creates a StreamingContext from a checkpoint file. * @param path Path either to the directory that was specified as the checkpoint directory, or From 8bd0e888f377f13ac239df4ffd49fc666095e764 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 5 Feb 2013 17:50:25 -0600 Subject: [PATCH 276/291] Inline mergePair to look more like the narrow dep branch. No functionality changes, I think this is just more consistent given mergePair isn't called multiple times/recursive. Also added a comment to explain the usual case of having two parent RDDs. --- core/src/main/scala/spark/rdd/CoGroupedRDD.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 8fafd27bb6..4893fe8d78 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -84,6 +84,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = { val split = s.asInstanceOf[CoGroupSplit] val numRdds = split.deps.size + // e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs) val map = new JHashMap[K, Seq[ArrayBuffer[Any]]] def getSeq(k: K): Seq[ArrayBuffer[Any]] = { val seq = map.get(k) @@ -104,13 +105,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) } case ShuffleCoGroupSplitDep(shuffleId) => { // Read map outputs of shuffle - def mergePair(pair: (K, Seq[Any])) { - val mySeq = getSeq(pair._1) - for (v <- pair._2) - mySeq(depNum) += v - } val fetcher = SparkEnv.get.shuffleFetcher - fetcher.fetch[K, Seq[Any]](shuffleId, split.index).foreach(mergePair) + for ((k, vs) <- fetcher.fetch[K, Seq[Any]](shuffleId, split.index)) { + getSeq(k)(depNum) ++= vs + } } } JavaConversions.mapAsScalaMap(map).iterator From 1ba3393ceb5709620a28b8bc01826153993fc444 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 5 Feb 2013 17:56:50 -0600 Subject: [PATCH 277/291] Increase DriverSuite timeout. 
--- core/src/test/scala/spark/DriverSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/spark/DriverSuite.scala b/core/src/test/scala/spark/DriverSuite.scala index 342610e1dd..5e84b3a66a 100644 --- a/core/src/test/scala/spark/DriverSuite.scala +++ b/core/src/test/scala/spark/DriverSuite.scala @@ -9,10 +9,11 @@ import org.scalatest.time.SpanSugar._ class DriverSuite extends FunSuite with Timeouts { test("driver should exit after finishing") { + assert(System.getenv("SPARK_HOME") != null) // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) forAll(masters) { (master: String) => - failAfter(10 seconds) { + failAfter(30 seconds) { Utils.execute(Seq("./run", "spark.DriverWithoutCleanup", master), new File(System.getenv("SPARK_HOME"))) } From 0e19093fd89ec9740f98cdcffd1ec09f4faf2490 Mon Sep 17 00:00:00 2001 From: Stephen Haberman Date: Tue, 5 Feb 2013 18:58:00 -0600 Subject: [PATCH 278/291] Handle Terminated to avoid endless DeathPactExceptions. Credit to Roland Kuhn, Akka's tech lead, for pointing out this various obvious fix, but StandaloneExecutorBackend.preStart's catch block would never (ever) get hit, because all of the operation's in preStart are async. So, the System.exit in the catch block was skipped, and instead Akka was sending Terminated messages which, since we didn't handle, it turned into DeathPactException, which started a postRestart/preStart infinite loop. --- .../scala/spark/deploy/worker/Worker.scala | 7 ++---- .../executor/StandaloneExecutorBackend.scala | 25 ++++++++----------- 2 files changed, 13 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 8b41620d98..48177a638a 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -1,19 +1,16 @@ package spark.deploy.worker import scala.collection.mutable.{ArrayBuffer, HashMap} -import akka.actor.{ActorRef, Props, Actor} +import akka.actor.{ActorRef, Props, Actor, Terminated} import spark.{Logging, Utils} import spark.util.AkkaUtils import spark.deploy._ -import akka.remote.RemoteClientLifeCycleEvent +import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected} import java.text.SimpleDateFormat import java.util.Date -import akka.remote.RemoteClientShutdown -import akka.remote.RemoteClientDisconnected import spark.deploy.RegisterWorker import spark.deploy.LaunchExecutor import spark.deploy.RegisterWorkerFailed -import akka.actor.Terminated import java.io.File private[spark] class Worker( diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index e45288ff53..224c126fdd 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -4,16 +4,15 @@ import java.nio.ByteBuffer import spark.Logging import spark.TaskState.TaskState import spark.util.AkkaUtils -import akka.actor.{ActorRef, Actor, Props} +import akka.actor.{ActorRef, Actor, Props, Terminated} +import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected} import java.util.concurrent.{TimeUnit, ThreadPoolExecutor, SynchronousQueue} -import akka.remote.RemoteClientLifeCycleEvent import 
spark.scheduler.cluster._ import spark.scheduler.cluster.RegisteredExecutor import spark.scheduler.cluster.LaunchTask import spark.scheduler.cluster.RegisterExecutorFailed import spark.scheduler.cluster.RegisterExecutor - private[spark] class StandaloneExecutorBackend( executor: Executor, driverUrl: String, @@ -27,17 +26,11 @@ private[spark] class StandaloneExecutorBackend( var driver: ActorRef = null override def preStart() { - try { - logInfo("Connecting to driver: " + driverUrl) - driver = context.actorFor(driverUrl) - driver ! RegisterExecutor(executorId, hostname, cores) - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - context.watch(driver) // Doesn't work with remote actors, but useful for testing - } catch { - case e: Exception => - logError("Failed to connect to driver", e) - System.exit(1) - } + logInfo("Connecting to driver: " + driverUrl) + driver = context.actorFor(driverUrl) + driver ! RegisterExecutor(executorId, hostname, cores) + context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) + context.watch(driver) // Doesn't work with remote actors, but useful for testing } override def receive = { @@ -52,6 +45,10 @@ private[spark] class StandaloneExecutorBackend( case LaunchTask(taskDesc) => logInfo("Got assigned task " + taskDesc.taskId) executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask) + + case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => + logError("Driver terminated or disconnected! Shutting down.") + System.exit(1) } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { From d55e3aa467ab7d406739255bd8dc3dfc60f3cb16 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 7 Feb 2013 13:59:18 -0800 Subject: [PATCH 279/291] Updated JavaStreamingContext with updated kafkaStream API. --- .../api/java/JavaStreamingContext.scala | 26 +++++++------------ 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index f82e6a37cc..70d6bd2b1b 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -45,27 +45,24 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Create an input stream that pulls messages form a Kafka Broker. - * @param hostname Zookeper hostname. - * @param port Zookeper port. + * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. */ def kafkaStream[T]( - hostname: String, - port: Int, + zkQuorum: String, groupId: String, topics: JMap[String, JInt]) : JavaDStream[T] = { implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - ssc.kafkaStream[T](hostname, port, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*)) + ssc.kafkaStream[T](zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*)) } /** * Create an input stream that pulls messages form a Kafka Broker. - * @param hostname Zookeper hostname. - * @param port Zookeper port. + * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. 
Each partition is consumed * in its own thread. @@ -73,8 +70,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { * By default the value is pulled from zookeper. */ def kafkaStream[T]( - hostname: String, - port: Int, + zkQuorum: String, groupId: String, topics: JMap[String, JInt], initialOffsets: JMap[KafkaPartitionKey, JLong]) @@ -82,8 +78,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] ssc.kafkaStream[T]( - hostname, - port, + zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), Map(initialOffsets.mapValues(_.longValue()).toSeq: _*)) @@ -91,8 +86,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Create an input stream that pulls messages form a Kafka Broker. - * @param hostname Zookeper hostname. - * @param port Zookeper port. + * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. @@ -101,8 +95,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param storageLevel RDD storage level. Defaults to memory-only */ def kafkaStream[T]( - hostname: String, - port: Int, + zkQuorum: String, groupId: String, topics: JMap[String, JInt], initialOffsets: JMap[KafkaPartitionKey, JLong], @@ -111,8 +104,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] ssc.kafkaStream[T]( - hostname, - port, + zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), Map(initialOffsets.mapValues(_.longValue()).toSeq: _*), From 99a5fc498acf3de14d754f8dda0df6bb81dd9595 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sat, 9 Feb 2013 15:18:05 -0800 Subject: [PATCH 280/291] Added an initial spark job to ensure worker nodes are initialized. --- .../main/scala/spark/streaming/NetworkInputTracker.scala | 7 ++++++- .../src/test/scala/spark/streaming/InputStreamsSuite.scala | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index e4152f3a61..b54f53b203 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -4,6 +4,7 @@ import spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver} import spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError} import spark.Logging import spark.SparkEnv +import spark.SparkContext._ import scala.collection.mutable.HashMap import scala.collection.mutable.Queue @@ -138,8 +139,12 @@ class NetworkInputTracker( } iterator.next().start() } + // Run the dummy Spark job to ensure that all slaves have registered. + // This avoids all the receivers to be scheduled on the same node. + //ssc.sparkContext.makeRDD(1 to 100, 100).map(x => (x, 1)).reduceByKey(_ + _, 20).collect() + // Distribute the receivers and start them - ssc.sc.runJob(tempRDD, startReceiver) + ssc.sparkContext.runJob(tempRDD, startReceiver) } /** Stops the receivers. 
*/ diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index c442210004..0eb9c7b81e 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -95,7 +95,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq(1, 2, 3, 4, 5) - + Thread.sleep(1000) val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", 33333)); val client = SpecificRequestor.getClient( classOf[AvroSourceProtocol], transceiver); From 16baea62bce62987158acce0595a0916c25b32b2 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 10 Feb 2013 19:14:49 -0800 Subject: [PATCH 281/291] Fixed bug in CheckpointRDD to prevent exception when the original RDD had zero splits. --- core/src/main/scala/spark/rdd/CheckpointRDD.scala | 4 ++-- core/src/test/scala/spark/CheckpointSuite.scala | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala index 96b593ba7c..a21338f85f 100644 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -24,8 +24,8 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: Stri val dirContents = fs.listStatus(new Path(checkpointPath)) val splitFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted val numSplits = splitFiles.size - if (!splitFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || - !splitFiles(numSplits-1).endsWith(CheckpointRDD.splitIdToFile(numSplits-1))) { + if (numSplits > 0 && (!splitFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || + !splitFiles(numSplits-1).endsWith(CheckpointRDD.splitIdToFile(numSplits-1)))) { throw new SparkException("Invalid checkpoint directory: " + checkpointPath) } Array.tabulate(numSplits)(i => new CheckpointRDDSplit(i)) diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 0b74607fb8..4425949f46 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -162,6 +162,16 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false) } + test("CheckpointRDD with zero partitions") { + val rdd = new BlockRDD[Int](sc, Array[String]()) + assert(rdd.splits.size === 0) + assert(rdd.isCheckpointed === false) + rdd.checkpoint() + assert(rdd.count() === 0) + assert(rdd.isCheckpointed === true) + assert(rdd.splits.size === 0) + } + /** * Test checkpointing of the final RDD generated by the given operation. By default, * this method tests whether the size of serialized RDD has reduced after checkpointing or not. From fd90daf850a922fe33c3638b18304d827953e2cb Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 10 Feb 2013 19:48:42 -0800 Subject: [PATCH 282/291] Fixed bugs in FileInputDStream and Scheduler that occasionally failed to reprocess old files after recovering from master failure. Completely modified spark.streaming.FailureTest to test multiple master failures using file input stream. 
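The core of the Scheduler change is to replay the checkpointed pending batch times in increasing order when recovering, before restarting the timer. A simplified stand-in sketch (BatchTime and replayPending are illustrative names, not the real classes):

    // Replay pending batches oldest-first, mirroring the sorted(Time.ordering)
    // call added to Scheduler in the diff below.
    case class BatchTime(millis: Long)

    def replayPending(pendingTimes: Seq[BatchTime], runBatch: BatchTime => Unit) {
      val oldestFirst = Ordering.by((t: BatchTime) => t.millis)
      pendingTimes.sorted(oldestFirst).foreach(runBatch)
    }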
--- .../scala/spark/streaming/DStreamGraph.scala | 2 + .../scala/spark/streaming/JobManager.scala | 4 +- .../scala/spark/streaming/Scheduler.scala | 8 +- .../src/main/scala/spark/streaming/Time.scala | 4 + .../streaming/dstream/FileInputDStream.scala | 13 +- .../scala/spark/streaming/FailureSuite.scala | 283 +++++++++++++----- 6 files changed, 222 insertions(+), 92 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index d5a5496839..7aa9d20004 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -81,12 +81,14 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { private[streaming] def generateRDDs(time: Time): Seq[Job] = { this.synchronized { + logInfo("Generating RDDs for time " + time) outputStreams.flatMap(outputStream => outputStream.generateJob(time)) } } private[streaming] def forgetOldRDDs(time: Time) { this.synchronized { + logInfo("Forgetting old RDDs for time " + time) outputStreams.foreach(_.forgetOldMetadata(time)) } } diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index 5acdd01e58..8b18c7bc6a 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -15,8 +15,8 @@ class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { SparkEnv.set(ssc.env) try { val timeTaken = job.run() - logInfo("Total delay: %.5f s for job %s (execution: %.5f s)".format( - (System.currentTimeMillis() - job.time.milliseconds) / 1000.0, job.id, timeTaken / 1000.0)) + logInfo("Total delay: %.5f s for job %s of time %s (execution: %.5f s)".format( + (System.currentTimeMillis() - job.time.milliseconds) / 1000.0, job.id, job.time.milliseconds, timeTaken / 1000.0)) } catch { case e: Exception => logError("Running " + job + " failed", e) diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index b77986a3ba..23a0f0974d 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -40,7 +40,11 @@ class Scheduler(ssc: StreamingContext) extends Logging { clock.asInstanceOf[ManualClock].setTime(lastTime + jumpTime) } // Reschedule the batches that were received but not processed before failure - ssc.initialCheckpoint.pendingTimes.foreach(time => generateRDDs(time)) + //ssc.initialCheckpoint.pendingTimes.foreach(time => generateRDDs(time)) + val pendingTimes = ssc.initialCheckpoint.pendingTimes.sorted(Time.ordering) + println(pendingTimes.mkString(", ")) + pendingTimes.foreach(time => + graph.generateRDDs(time).foreach(jobManager.runJob)) // Restart the timer timer.restart(graph.zeroTime.milliseconds) logInfo("Scheduler's timer restarted") @@ -64,11 +68,11 @@ class Scheduler(ssc: StreamingContext) extends Logging { graph.generateRDDs(time).foreach(jobManager.runJob) graph.forgetOldRDDs(time) doCheckpoint(time) - logInfo("Generated RDDs for time " + time) } private def doCheckpoint(time: Time) { if (ssc.checkpointDuration != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) { + logInfo("Checkpointing graph for time " + time) val startTime = System.currentTimeMillis() ssc.graph.updateCheckpointData(time) checkpointWriter.write(new Checkpoint(ssc, time)) diff 
--git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index 5daeb761dd..8a6c9a5cb5 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -39,4 +39,8 @@ case class Time(private val millis: Long) { override def toString: String = (millis.toString + " ms") +} + +object Time { + val ordering = Ordering.by((time: Time) => time.millis) } \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala index c6ffb252ce..10ccb4318d 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala @@ -128,7 +128,7 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K private[streaming] class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) { - def hadoopFiles = data.asInstanceOf[HashMap[Time, Array[String]]] + def hadoopFiles = data.asInstanceOf[HashMap[Time, Array[String]]] override def update() { hadoopFiles.clear() @@ -139,11 +139,12 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K override def restore() { hadoopFiles.foreach { - case (time, files) => { - logInfo("Restoring Hadoop RDD for time " + time + " from files " + - files.mkString("[", ",", "]") ) - files - generatedRDDs += ((time, filesToRDD(files))) + case (t, f) => { + // Restore the metadata in both files and generatedRDDs + logInfo("Restoring files for time " + t + " - " + + f.mkString("[", ", ", "]") ) + files += ((t, f)) + generatedRDDs += ((t, filesToRDD(f))) } } } diff --git a/streaming/src/test/scala/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/spark/streaming/FailureSuite.scala index c4cfffbfc1..efaa098d2e 100644 --- a/streaming/src/test/scala/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/spark/streaming/FailureSuite.scala @@ -1,58 +1,58 @@ package spark.streaming -import org.scalatest.BeforeAndAfter +import org.scalatest.{FunSuite, BeforeAndAfter} import org.apache.commons.io.FileUtils import java.io.File import scala.runtime.RichInt import scala.util.Random import spark.streaming.StreamingContext._ -import collection.mutable.ArrayBuffer +import collection.mutable.{SynchronizedBuffer, ArrayBuffer} import spark.Logging +import com.google.common.io.Files /** * This testsuite tests master failures at random times while the stream is running using * the real clock. 
*/ -class FailureSuite extends TestSuiteBase with BeforeAndAfter { +class FailureSuite extends FunSuite with BeforeAndAfter with Logging { + + var testDir: File = null + var checkpointDir: File = null + val batchDuration = Milliseconds(500) before { - FileUtils.deleteDirectory(new File(checkpointDir)) + testDir = Files.createTempDir() + checkpointDir = Files.createTempDir() } after { FailureSuite.reset() - FileUtils.deleteDirectory(new File(checkpointDir)) + FileUtils.deleteDirectory(checkpointDir) + FileUtils.deleteDirectory(testDir) // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") } - override def framework = "CheckpointSuite" - - override def batchDuration = Milliseconds(500) - - override def checkpointDir = "checkpoint" - - override def checkpointInterval = batchDuration - test("multiple failures with updateStateByKey") { val n = 30 // Input: time=1 ==> [ a ] , time=2 ==> [ a, a ] , time=3 ==> [ a, a, a ] , ... - val input = (1 to n).map(i => (1 to i).map(_ =>"a").toSeq).toSeq - // Last output: [ (a, 465) ] for n=30 - val lastOutput = Seq( ("a", (1 to n).reduce(_ + _)) ) + val input = (1 to n).map(i => (1 to i).map(_ => "a").mkString(" ")).toSeq + // Expected output: time=1 ==> [ (a, 1) ] , time=2 ==> [ (a, 3) ] , time=3 ==> [ (a,6) ] , ... + val expectedOutput = (1 to n).map(i => (1 to i).reduce(_ + _)).map(j => ("a", j)) val operation = (st: DStream[String]) => { val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) } - st.map(x => (x, 1)) - .updateStateByKey[RichInt](updateFunc) - .checkpoint(Seconds(2)) - .map(t => (t._1, t._2.self)) + st.flatMap(_.split(" ")) + .map(x => (x, 1)) + .updateStateByKey[RichInt](updateFunc) + .checkpoint(Seconds(2)) + .map(t => (t._1, t._2.self)) } - testOperationWithMultipleFailures(input, operation, lastOutput, n, n) + testOperationWithMultipleFailures(input, operation, expectedOutput) } test("multiple failures with reduceByKeyAndWindow") { @@ -60,17 +60,18 @@ class FailureSuite extends TestSuiteBase with BeforeAndAfter { val w = 100 assert(w > n, "Window should be much larger than the number of input sets in this test") // Input: time=1 ==> [ a ] , time=2 ==> [ a, a ] , time=3 ==> [ a, a, a ] , ... - val input = (1 to n).map(i => (1 to i).map(_ =>"a").toSeq).toSeq - // Last output: [ (a, 465) ] - val lastOutput = Seq( ("a", (1 to n).reduce(_ + _)) ) + val input = (1 to n).map(i => (1 to i).map(_ => "a").mkString(" ")).toSeq + // Expected output: time=1 ==> [ (a, 1) ] , time=2 ==> [ (a, 3) ] , time=3 ==> [ (a,6) ] , ... + val expectedOutput = (1 to n).map(i => (1 to i).reduce(_ + _)).map(j => ("a", j)) val operation = (st: DStream[String]) => { - st.map(x => (x, 1)) + st.flatMap(_.split(" ")) + .map(x => (x, 1)) .reduceByKeyAndWindow(_ + _, _ - _, batchDuration * w, batchDuration) .checkpoint(Seconds(2)) } - testOperationWithMultipleFailures(input, operation, lastOutput, n, n) + testOperationWithMultipleFailures(input, operation, expectedOutput) } @@ -79,113 +80,231 @@ class FailureSuite extends TestSuiteBase with BeforeAndAfter { * final set of output values is as expected or not. Checking the final value is * proof that no intermediate data was lost due to master failures. 
*/ - def testOperationWithMultipleFailures[U: ClassManifest, V: ClassManifest]( - input: Seq[Seq[U]], - operation: DStream[U] => DStream[V], - lastExpectedOutput: Seq[V], - numBatches: Int, - numExpectedOutput: Int + def testOperationWithMultipleFailures( + input: Seq[String], + operation: DStream[String] => DStream[(String, Int)], + expectedOutput: Seq[(String, Int)] ) { - var ssc = setupStreams[U, V](input, operation) - val mergedOutput = new ArrayBuffer[Seq[V]]() + var ssc = setupStreamsWithFileStream(operation) + val mergedOutput = new ArrayBuffer[(String, Int)]() + val lastExpectedOutput = expectedOutput.last + + val maxTimeToRun = expectedOutput.size * batchDuration.milliseconds * 2 var totalTimeRan = 0L - while(totalTimeRan <= numBatches * batchDuration.milliseconds * 2) { - new KillingThread(ssc, numBatches * batchDuration.milliseconds.toInt / 4).start() - val (output, timeRan) = runStreamsWithRealClock[V](ssc, numBatches, numExpectedOutput) + // Start generating files in the a different thread + val fileGeneratingThread = new FileGeneratingThread(input, testDir.getPath, batchDuration.milliseconds) + fileGeneratingThread.start() + + // Repeatedly start and kill the streaming context until timed out or + // all expected output is generated + while(!FailureSuite.outputGenerated && !FailureSuite.timedOut) { + + // Start the thread to kill the streaming after some time + FailureSuite.failed = false + val killingThread = new KillingThread(ssc, batchDuration.milliseconds * 10) + killingThread.start() + + // Run the streams with real clock until last expected output is seen or timed out + val (output, timeRan) = runStreamsWithRealClock(ssc, lastExpectedOutput, maxTimeToRun - totalTimeRan) + if (killingThread.isAlive) killingThread.interrupt() + + // Merge output and time ran and see whether already timed out or not mergedOutput ++= output totalTimeRan += timeRan logInfo("New output = " + output) logInfo("Merged output = " + mergedOutput) logInfo("Total time spent = " + totalTimeRan) - val sleepTime = Random.nextInt(numBatches * batchDuration.milliseconds.toInt / 8) - logInfo( - "\n-------------------------------------------\n" + - " Restarting stream computation in " + sleepTime + " ms " + - "\n-------------------------------------------\n" - ) - Thread.sleep(sleepTime) - FailureSuite.failed = false - ssc = new StreamingContext(checkpointDir) + if (totalTimeRan > maxTimeToRun) { + FailureSuite.timedOut = true + } + + if (!FailureSuite.outputGenerated && !FailureSuite.timedOut) { + val sleepTime = Random.nextInt(batchDuration.milliseconds.toInt * 2) + logInfo( + "\n-------------------------------------------\n" + + " Restarting stream computation in " + sleepTime + " ms " + + "\n-------------------------------------------\n" + ) + Thread.sleep(sleepTime) + } + + // Recreate the streaming context from checkpoint + ssc = new StreamingContext(checkpointDir.getPath) } ssc.stop() ssc = null + logInfo("Finished test after " + FailureSuite.failureCount + " failures") - // Verify whether the last output is the expected one - val lastOutput = mergedOutput(mergedOutput.lastIndexWhere(!_.isEmpty)) - assert(lastOutput.toSet === lastExpectedOutput.toSet) - logInfo("Finished computation after " + FailureSuite.failureCount + " failures") + if (FailureSuite.timedOut) { + logWarning("Timed out with run time of "+ maxTimeToRun + " ms for " + + expectedOutput.size + " batches of " + batchDuration) + } + + // Verify whether the output is as expected + verifyOutput(mergedOutput, expectedOutput) + if 
(fileGeneratingThread.isAlive) fileGeneratingThread.interrupt() + } + + /** Sets up the stream operations with file input stream */ + def setupStreamsWithFileStream( + operation: DStream[String] => DStream[(String, Int)] + ): StreamingContext = { + val ssc = new StreamingContext("local[4]", "FailureSuite", batchDuration) + ssc.checkpoint(checkpointDir.getPath) + val inputStream = ssc.textFileStream(testDir.getPath) + val operatedStream = operation(inputStream) + val outputBuffer = new ArrayBuffer[Seq[(String, Int)]] with SynchronizedBuffer[Seq[(String, Int)]] + val outputStream = new TestOutputStream(operatedStream, outputBuffer) + ssc.registerOutputStream(outputStream) + ssc } /** - * Runs the streams set up in `ssc` on real clock until the expected max number of + * Runs the streams set up in `ssc` on real clock. */ - def runStreamsWithRealClock[V: ClassManifest]( - ssc: StreamingContext, - numBatches: Int, - maxExpectedOutput: Int - ): (Seq[Seq[V]], Long) = { + def runStreamsWithRealClock( + ssc: StreamingContext, + lastExpectedOutput: (String, Int), + timeout: Long + ): (Seq[(String, Int)], Long) = { System.clearProperty("spark.streaming.clock") - assert(numBatches > 0, "Number of batches to run stream computation is zero") - assert(maxExpectedOutput > 0, "Max expected outputs after " + numBatches + " is zero") - logInfo("numBatches = " + numBatches + ", maxExpectedOutput = " + maxExpectedOutput) - // Get the output buffer - val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]] + val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[(String, Int)]] val output = outputStream.output - val waitTime = (batchDuration.milliseconds * (numBatches.toDouble + 0.5)).toLong val startTime = System.currentTimeMillis() - try { - // Start computation - ssc.start() + // Functions to detect various conditions + def hasFailed = FailureSuite.failed + def isLastOutputGenerated = !output.flatMap(x => x).isEmpty && output(output.lastIndexWhere(!_.isEmpty)).head == lastExpectedOutput + def isTimedOut = System.currentTimeMillis() - startTime > timeout - // Wait until expected number of output items have been generated - while (output.size < maxExpectedOutput && System.currentTimeMillis() - startTime < waitTime && !FailureSuite.failed) { - logInfo("output.size = " + output.size + ", maxExpectedOutput = " + maxExpectedOutput) + // Start the streaming computation and let it run while ... 
+ // (i) StreamingContext has not been shut down yet + // (ii) The last expected output has not been generated yet + // (iii) Its not timed out yet + try { + ssc.start() + while (!hasFailed && !isLastOutputGenerated && !isTimedOut) { Thread.sleep(100) } + logInfo("Has failed = " + hasFailed) + logInfo("Is last output generated = " + isLastOutputGenerated) + logInfo("Is timed out = " + isTimedOut) } catch { case e: Exception => logInfo("Exception while running streams: " + e) } finally { ssc.stop() } + + // Verify whether the output of each batch has only one element + assert(output.forall(_.size <= 1), "output of each batch should have only one element") + + // Set appropriate flags is timed out or output has been generated + if (isTimedOut) FailureSuite.timedOut = true + if (isLastOutputGenerated) FailureSuite.outputGenerated = true + val timeTaken = System.currentTimeMillis() - startTime logInfo("" + output.size + " sets of output generated in " + timeTaken + " ms") - (output, timeTaken) + (output.flatMap(_.headOption), timeTaken) } + /** + * Verifies the output value are the same as expected. Since failures can lead to + * a batch being processed twice, a batches output may appear more than once + * consecutively. To avoid getting confused with those, we eliminate consecutive + * duplicate batch outputs of values from the `output`. As a result, the + * expected output should not have consecutive batches with the same values as output. + */ + def verifyOutput(output: Seq[(String, Int)], expectedOutput: Seq[(String, Int)]) { + // Verify whether expected outputs do not consecutive batches with same output + for (i <- 0 until expectedOutput.size - 1) { + assert(expectedOutput(i) != expectedOutput(i+1), + "Expected output has consecutive duplicate sequence of values") + } + // Match the output with the expected output + logInfo( + "\n-------------------------------------------\n" + + " Verifying output " + + "\n-------------------------------------------\n" + ) + logInfo("Expected output, size = " + expectedOutput.size) + logInfo(expectedOutput.mkString("[", ",", "]")) + logInfo("Output, size = " + output.size) + logInfo(output.mkString("[", ",", "]")) + output.foreach(o => + assert(expectedOutput.contains(o), "Expected value " + o + " not found") + ) + } } object FailureSuite { var failed = false + var outputGenerated = false + var timedOut = false var failureCount = 0 def reset() { failed = false + outputGenerated = false + timedOut = false failureCount = 0 } } -class KillingThread(ssc: StreamingContext, maxKillWaitTime: Int) extends Thread with Logging { +/** + * Thread to kill streaming context after some time. 
+ */ +class KillingThread(ssc: StreamingContext, maxKillWaitTime: Long) extends Thread with Logging { initLogging() override def run() { - var minKillWaitTime = if (FailureSuite.failureCount == 0) 3000 else 1000 // to allow the first checkpoint - val killWaitTime = minKillWaitTime + Random.nextInt(maxKillWaitTime) - logInfo("Kill wait time = " + killWaitTime) - Thread.sleep(killWaitTime.toLong) - logInfo( - "\n---------------------------------------\n" + - "Killing streaming context after " + killWaitTime + " ms" + - "\n---------------------------------------\n" - ) - if (ssc != null) ssc.stop() - FailureSuite.failed = true - FailureSuite.failureCount += 1 + try { + var minKillWaitTime = if (FailureSuite.failureCount == 0) 5000 else 1000 // to allow the first checkpoint + val killWaitTime = minKillWaitTime + math.abs(Random.nextLong % maxKillWaitTime) + logInfo("Kill wait time = " + killWaitTime) + Thread.sleep(killWaitTime) + logInfo( + "\n---------------------------------------\n" + + "Killing streaming context after " + killWaitTime + " ms" + + "\n---------------------------------------\n" + ) + if (ssc != null) { + ssc.stop() + FailureSuite.failed = true + FailureSuite.failureCount += 1 + } + logInfo("Killing thread exited") + } catch { + case ie: InterruptedException => logInfo("Killing thread interrupted") + case e: Exception => logWarning("Exception in killing thread", e) + } } } + +/** + * Thread to generate input files periodically with the desired text + */ +class FileGeneratingThread(input: Seq[String], testDir: String, interval: Long) + extends Thread with Logging { + initLogging() + + override def run() { + try { + Thread.sleep(5000) // To make sure that all the streaming context has been set up + for (i <- 0 until input.size) { + FileUtils.writeStringToFile(new File(testDir, i.toString), input(i).toString + "\n") + Thread.sleep(interval) + } + logInfo("File generating thread exited") + } catch { + case ie: InterruptedException => logInfo("File generating thread interrupted") + case e: Exception => logWarning("File generating in killing thread", e) + } + } +} + From 39addd380363c0371e935fae50983fe87158c1ac Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 13 Feb 2013 12:17:45 -0800 Subject: [PATCH 283/291] Changed scheduler and file input stream to fix bugs in the driver fault tolerance. Added MasterFailureTest to rigorously test master fault tolerance with file input stream. 
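The rescheduling logic here works in two parts. On recovery the scheduler computes the batches that fell into the down time between the checkpoint time and the timer's restart time (checkpointTime.until(restartTime, batchDuration)), merges them with the batches that were still pending at checkpoint time, and re-runs the distinct, sorted result before the timer resumes. The restart time itself comes from the new RecurringTimer.getRestartTime. The sketch below is illustrative only, with made-up numbers, and shows that arithmetic in isolation: the timer resumes at the first period boundary, aligned to the original start time, that lies strictly after the current clock time.

    // Standalone illustration of the restart-time arithmetic added to RecurringTimer.
    // All of the concrete values below are assumptions for the example.
    object RestartTimeSketch {
      // First period boundary, aligned to originalStartTime, strictly after currentTime
      def getRestartTime(currentTime: Long, originalStartTime: Long, period: Long): Long = {
        val gap = currentTime - originalStartTime
        (math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime
      }

      def main(args: Array[String]) {
        val period = 1000L             // 1 second batch duration
        val originalStartTime = 5000L  // the graph's zero time
        val currentTime = 12300L       // clock reading once the driver is back up
        // Prints 13000: the next 1000 ms boundary after 12300, aligned to 5000
        println(getRestartTime(currentTime, originalStartTime, period))
      }
    }

Batches between the checkpoint time and that restart time are the down-time batches, and they are rescheduled together with the pending ones.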
--- .../main/scala/spark/streaming/DStream.scala | 23 +- .../streaming/DStreamCheckpointData.scala | 2 +- .../scala/spark/streaming/DStreamGraph.scala | 49 ++- .../scala/spark/streaming/JobManager.scala | 10 +- .../scala/spark/streaming/Scheduler.scala | 92 +++-- .../src/main/scala/spark/streaming/Time.scala | 10 + .../streaming/dstream/FileInputDStream.scala | 59 ++- .../dstream/NetworkInputDStream.scala | 11 +- .../streaming/util/MasterFailureTest.scala | 375 ++++++++++++++++++ .../spark/streaming/util/RecurringTimer.scala | 30 +- .../java/spark/streaming/JavaAPISuite.java | 21 +- streaming/src/test/resources/log4j.properties | 7 +- .../streaming/BasicOperationsSuite.scala | 2 + .../spark/streaming/CheckpointSuite.scala | 107 +++-- .../scala/spark/streaming/FailureSuite.scala | 304 +------------- .../spark/streaming/InputStreamsSuite.scala | 29 +- .../scala/spark/streaming/TestSuiteBase.scala | 12 +- .../streaming/WindowOperationsSuite.scala | 2 + 18 files changed, 693 insertions(+), 452 deletions(-) create mode 100644 streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 0eb6aad187..0c1b667c0a 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -292,7 +292,7 @@ abstract class DStream[T: ClassManifest] ( * Generate a SparkStreaming job for the given time. This is an internal method that * should not be called directly. This default implementation creates a job * that materializes the corresponding RDD. Subclasses of DStream may override this - * (eg. ForEachDStream). + * to generate their own jobs. */ protected[streaming] def generateJob(time: Time): Option[Job] = { getOrCompute(time) match { @@ -308,19 +308,18 @@ abstract class DStream[T: ClassManifest] ( } /** - * Dereference RDDs that are older than rememberDuration. + * Clear metadata that are older than `rememberDuration` of this DStream. + * This is an internal method that should not be called directly. This default + * implementation clears the old generated RDDs. Subclasses of DStream may override + * this to clear their own metadata along with the generated RDDs. */ - protected[streaming] def forgetOldMetadata(time: Time) { + protected[streaming] def clearOldMetadata(time: Time) { var numForgotten = 0 - generatedRDDs.keys.foreach(t => { - if (t <= (time - rememberDuration)) { - generatedRDDs.remove(t) - numForgotten += 1 - logInfo("Forgot RDD of time " + t + " from " + this) - } - }) - logInfo("Forgot " + numForgotten + " RDDs from " + this) - dependencies.foreach(_.forgetOldMetadata(time)) + val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration)) + generatedRDDs --= oldRDDs.keys + logInfo("Cleared " + oldRDDs.size + " RDDs that were older than " + + (time - rememberDuration) + ": " + oldRDDs.keys.mkString(", ")) + dependencies.foreach(_.clearOldMetadata(time)) } /* Adds metadata to the Stream while it is running. 
diff --git a/streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala b/streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala index a375980b84..6b0fade7c6 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala @@ -87,7 +87,7 @@ class DStreamCheckpointData[T: ClassManifest] (dstream: DStream[T]) } override def toString() = { - "[\n" + checkpointFiles.size + "\n" + checkpointFiles.mkString("\n") + "\n]" + "[\n" + checkpointFiles.size + " checkpoint files \n" + checkpointFiles.mkString("\n") + "\n]" } } diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index 7aa9d20004..22d9e24f05 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -11,17 +11,20 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { private val inputStreams = new ArrayBuffer[InputDStream[_]]() private val outputStreams = new ArrayBuffer[DStream[_]]() - private[streaming] var zeroTime: Time = null - private[streaming] var batchDuration: Duration = null - private[streaming] var rememberDuration: Duration = null - private[streaming] var checkpointInProgress = false + var rememberDuration: Duration = null + var checkpointInProgress = false - private[streaming] def start(time: Time) { + var zeroTime: Time = null + var startTime: Time = null + var batchDuration: Duration = null + + def start(time: Time) { this.synchronized { if (zeroTime != null) { throw new Exception("DStream graph computation already started") } zeroTime = time + startTime = time outputStreams.foreach(_.initialize(zeroTime)) outputStreams.foreach(_.remember(rememberDuration)) outputStreams.foreach(_.validate) @@ -29,19 +32,23 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { } } - private[streaming] def stop() { + def restart(time: Time) { + this.synchronized { startTime = time } + } + + def stop() { this.synchronized { inputStreams.par.foreach(_.stop()) } } - private[streaming] def setContext(ssc: StreamingContext) { + def setContext(ssc: StreamingContext) { this.synchronized { outputStreams.foreach(_.setContext(ssc)) } } - private[streaming] def setBatchDuration(duration: Duration) { + def setBatchDuration(duration: Duration) { this.synchronized { if (batchDuration != null) { throw new Exception("Batch duration already set as " + batchDuration + @@ -51,61 +58,61 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { batchDuration = duration } - private[streaming] def remember(duration: Duration) { + def remember(duration: Duration) { this.synchronized { if (rememberDuration != null) { throw new Exception("Batch duration already set as " + batchDuration + ". 
cannot set it again.") } + rememberDuration = duration } - rememberDuration = duration } - private[streaming] def addInputStream(inputStream: InputDStream[_]) { + def addInputStream(inputStream: InputDStream[_]) { this.synchronized { inputStream.setGraph(this) inputStreams += inputStream } } - private[streaming] def addOutputStream(outputStream: DStream[_]) { + def addOutputStream(outputStream: DStream[_]) { this.synchronized { outputStream.setGraph(this) outputStreams += outputStream } } - private[streaming] def getInputStreams() = this.synchronized { inputStreams.toArray } + def getInputStreams() = this.synchronized { inputStreams.toArray } - private[streaming] def getOutputStreams() = this.synchronized { outputStreams.toArray } + def getOutputStreams() = this.synchronized { outputStreams.toArray } - private[streaming] def generateRDDs(time: Time): Seq[Job] = { + def generateRDDs(time: Time): Seq[Job] = { this.synchronized { logInfo("Generating RDDs for time " + time) outputStreams.flatMap(outputStream => outputStream.generateJob(time)) } } - private[streaming] def forgetOldRDDs(time: Time) { + def clearOldMetadata(time: Time) { this.synchronized { - logInfo("Forgetting old RDDs for time " + time) - outputStreams.foreach(_.forgetOldMetadata(time)) + logInfo("Clearing old metadata for time " + time) + outputStreams.foreach(_.clearOldMetadata(time)) } } - private[streaming] def updateCheckpointData(time: Time) { + def updateCheckpointData(time: Time) { this.synchronized { outputStreams.foreach(_.updateCheckpointData(time)) } } - private[streaming] def restoreCheckpointData() { + def restoreCheckpointData() { this.synchronized { outputStreams.foreach(_.restoreCheckpointData()) } } - private[streaming] def validate() { + def validate() { this.synchronized { assert(batchDuration != null, "Batch duration has not been set") //assert(batchDuration >= Milliseconds(100), "Batch duration of " + batchDuration + " is very low") diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index 8b18c7bc6a..649494ff4a 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -38,13 +38,19 @@ class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { logInfo("Added " + job + " to queue") } + def stop() { + jobExecutor.shutdown() + } + private def clearJob(job: Job) { jobs.synchronized { - val jobsOfTime = jobs.get(job.time) + val time = job.time + val jobsOfTime = jobs.get(time) if (jobsOfTime.isDefined) { jobsOfTime.get -= job if (jobsOfTime.get.isEmpty) { - jobs -= job.time + ssc.scheduler.clearOldMetadata(time) + jobs -= time } } else { throw new Exception("Job finished for time " + job.time + diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index 23a0f0974d..57d494da83 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -9,11 +9,8 @@ class Scheduler(ssc: StreamingContext) extends Logging { initLogging() - val graph = ssc.graph - val concurrentJobs = System.getProperty("spark.streaming.concurrentJobs", "1").toInt val jobManager = new JobManager(ssc, concurrentJobs) - val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) { new CheckpointWriter(ssc.checkpointDir) } else { @@ -24,53 +21,80 @@ class Scheduler(ssc: StreamingContext) extends 
Logging { val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock] val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, longTime => generateRDDs(new Time(longTime))) + val graph = ssc.graph - def start() { - // If context was started from checkpoint, then restart timer such that - // this timer's triggers occur at the same time as the original timer. - // Otherwise just start the timer from scratch, and initialize graph based - // on this first trigger time of the timer. + def start() = synchronized { if (ssc.isCheckpointPresent) { - // If manual clock is being used for testing, then - // either set the manual clock to the last checkpointed time, - // or if the property is defined set it to that time - if (clock.isInstanceOf[ManualClock]) { - val lastTime = ssc.initialCheckpoint.checkpointTime.milliseconds - val jumpTime = System.getProperty("spark.streaming.manualClock.jump", "0").toLong - clock.asInstanceOf[ManualClock].setTime(lastTime + jumpTime) - } - // Reschedule the batches that were received but not processed before failure - //ssc.initialCheckpoint.pendingTimes.foreach(time => generateRDDs(time)) - val pendingTimes = ssc.initialCheckpoint.pendingTimes.sorted(Time.ordering) - println(pendingTimes.mkString(", ")) - pendingTimes.foreach(time => - graph.generateRDDs(time).foreach(jobManager.runJob)) - // Restart the timer - timer.restart(graph.zeroTime.milliseconds) - logInfo("Scheduler's timer restarted") + restart() } else { - val firstTime = new Time(timer.start()) - graph.start(firstTime - ssc.graph.batchDuration) - logInfo("Scheduler's timer started") + startFirstTime() } logInfo("Scheduler started") } - def stop() { + def stop() = synchronized { timer.stop() - graph.stop() + jobManager.stop() + ssc.graph.stop() logInfo("Scheduler stopped") } - - private def generateRDDs(time: Time) { + + private def startFirstTime() { + val startTime = new Time(timer.getStartTime()) + graph.start(startTime - graph.batchDuration) + timer.start(startTime.milliseconds) + logInfo("Scheduler's timer started at " + startTime) + } + + private def restart() { + + // If manual clock is being used for testing, then + // either set the manual clock to the last checkpointed time, + // or if the property is defined set it to that time + if (clock.isInstanceOf[ManualClock]) { + val lastTime = ssc.initialCheckpoint.checkpointTime.milliseconds + val jumpTime = System.getProperty("spark.streaming.manualClock.jump", "0").toLong + clock.asInstanceOf[ManualClock].setTime(lastTime + jumpTime) + } + + val batchDuration = ssc.graph.batchDuration + + // Batches when the master was down, that is, + // between the checkpoint and current restart time + val checkpointTime = ssc.initialCheckpoint.checkpointTime + val restartTime = new Time(timer.getRestartTime(graph.zeroTime.milliseconds)) + val downTimes = checkpointTime.until(restartTime, batchDuration) + logInfo("Batches during down time: " + downTimes.mkString(", ")) + + // Batches that were unprocessed before failure + val pendingTimes = ssc.initialCheckpoint.pendingTimes + logInfo("Batches pending processing: " + pendingTimes.mkString(", ")) + // Reschedule jobs for these times + val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering) + logInfo("Batches to reschedule: " + timesToReschedule.mkString(", ")) + timesToReschedule.foreach(time => + graph.generateRDDs(time).foreach(jobManager.runJob) + ) + + // Restart the timer + timer.start(restartTime.milliseconds) + logInfo("Scheduler's timer restarted") + 
} + + /** Generates the RDDs, clears old metadata and does checkpoint for the given time */ + def generateRDDs(time: Time) { SparkEnv.set(ssc.env) logInfo("\n-----------------------------------------------------\n") graph.generateRDDs(time).foreach(jobManager.runJob) - graph.forgetOldRDDs(time) doCheckpoint(time) } - private def doCheckpoint(time: Time) { + + def clearOldMetadata(time: Time) { + ssc.graph.clearOldMetadata(time) + } + + def doCheckpoint(time: Time) { if (ssc.checkpointDuration != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) { logInfo("Checkpointing graph for time " + time) val startTime = System.currentTimeMillis() diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index 8a6c9a5cb5..8201e84a20 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -37,6 +37,16 @@ case class Time(private val millis: Long) { def max(that: Time): Time = if (this > that) this else that + def until(that: Time, interval: Duration): Seq[Time] = { + assert(that > this, "Cannot create sequence as " + that + " not more than " + this) + assert( + (that - this).isMultipleOf(interval), + "Cannot create sequence as gap between " + that + " and " + + this + " is not multiple of " + interval + ) + (this.milliseconds) until (that.milliseconds) by (interval.milliseconds) map (new Time(_)) + } + override def toString: String = (millis.toString + " ms") } diff --git a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala index 10ccb4318d..41b9bd9461 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala @@ -21,19 +21,21 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData + // Latest file mod time seen till any point of time private val lastModTimeFiles = new HashSet[String]() private var lastModTime = 0L @transient private var path_ : Path = null @transient private var fs_ : FileSystem = null - @transient private var files = new HashMap[Time, Array[String]] + @transient private[streaming] var files = new HashMap[Time, Array[String]] override def start() { if (newFilesOnly) { - lastModTime = System.currentTimeMillis() + lastModTime = graph.zeroTime.milliseconds } else { lastModTime = 0 } + logDebug("LastModTime initialized to " + lastModTime + ", new files only = " + newFilesOnly) } override def stop() { } @@ -43,38 +45,50 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K * a union RDD out of them. Note that this maintains the list of files that were processed * in the latest modification time in the previous call to this method. This is because the * modification time returned by the FileStatus API seems to return times only at the - * granularity of seconds. Hence, new files may have the same modification time as the - * latest modification time in the previous call to this method and the list of files - * maintained is used to filter the one that have been processed. + * granularity of seconds. And new files may have the same modification time as the + * latest modification time in the previous call to this method yet was not reported in + * the previous call. 
*/ override def compute(validTime: Time): Option[RDD[(K, V)]] = { + assert(validTime.milliseconds >= lastModTime, "Trying to get new files for really old time [" + validTime + " < " + lastModTime) + // Create the filter for selecting new files val newFilter = new PathFilter() { + // Latest file mod time seen in this round of fetching files and its corresponding files var latestModTime = 0L val latestModTimeFiles = new HashSet[String]() def accept(path: Path): Boolean = { - if (!filter(path)) { + if (!filter(path)) { // Reject file if it does not satisfy filter + logDebug("Rejected by filter " + path) return false - } else { + } else { // Accept file only if val modTime = fs.getFileStatus(path).getModificationTime() - if (modTime < lastModTime){ - return false + logDebug("Mod time for " + path + " is " + modTime) + if (modTime < lastModTime) { + logDebug("Mod time less than last mod time") + return false // If the file was created before the last time it was called } else if (modTime == lastModTime && lastModTimeFiles.contains(path.toString)) { - return false + logDebug("Mod time equal to last mod time, but file considered already") + return false // If the file was created exactly as lastModTime but not reported yet + } else if (modTime > validTime.milliseconds) { + logDebug("Mod time more than valid time") + return false // If the file was created after the time this function call requires } if (modTime > latestModTime) { latestModTime = modTime latestModTimeFiles.clear() + logDebug("Latest mod time updated to " + latestModTime) } latestModTimeFiles += path.toString + logDebug("Accepted " + path) return true } } } - + logDebug("Finding new files at time " + validTime + " for last mod time = " + lastModTime) val newFiles = fs.listStatus(path, newFilter).map(_.getPath.toString) - logInfo("New files: " + newFiles.mkString(", ")) + logInfo("New files at time " + validTime + ":\n" + newFiles.mkString("\n")) if (newFiles.length > 0) { // Update the modification time and the files processed for that modification time if (lastModTime != newFilter.latestModTime) { @@ -82,17 +96,21 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K lastModTimeFiles.clear() } lastModTimeFiles ++= newFilter.latestModTimeFiles + logDebug("Last mod time updated to " + lastModTime) } files += ((validTime, newFiles)) Some(filesToRDD(newFiles)) } - /** Forget the old time-to-files mappings along with old RDDs */ - protected[streaming] override def forgetOldMetadata(time: Time) { - super.forgetOldMetadata(time) - val filesToBeRemoved = files.filter(_._1 <= (time - rememberDuration)) - files --= filesToBeRemoved.keys - logInfo("Forgot " + filesToBeRemoved.size + " files from " + this) + /** Clear the old time-to-files mappings along with old RDDs */ + protected[streaming] override def clearOldMetadata(time: Time) { + super.clearOldMetadata(time) + val oldFiles = files.filter(_._1 <= (time - rememberDuration)) + files --= oldFiles.keys + logInfo("Cleared " + oldFiles.size + " old files that were older than " + + (time - rememberDuration) + ": " + oldFiles.keys.mkString(", ")) + logDebug("Cleared files are:\n" + + oldFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n")) } /** Generate one RDD from an array of files */ @@ -148,6 +166,11 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K } } } + + override def toString() = { + "[\n" + hadoopFiles.size + " file sets\n" + + hadoopFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n") + "\n]" + } } } diff 
--git a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala index 8c322dd698..ecc75ec913 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala @@ -46,8 +46,15 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming def stop() {} override def compute(validTime: Time): Option[RDD[T]] = { - val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) - Some(new BlockRDD[T](ssc.sc, blockIds)) + // If this is called for any time before the start time of the context, + // then this returns an empty RDD. This may happen when recovering from a + // master failure forces + if (validTime >= graph.startTime) { + val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) + Some(new BlockRDD[T](ssc.sc, blockIds)) + } else { + Some(new BlockRDD[T](ssc.sc, Array[String]())) + } } } diff --git a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala new file mode 100644 index 0000000000..3ffe4b64d0 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala @@ -0,0 +1,375 @@ +package spark.streaming.util + +import spark.{Logging, RDD} +import spark.streaming._ +import spark.streaming.dstream.ForEachDStream +import StreamingContext._ + +import scala.util.Random +import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} + +import java.io.{File, ObjectInputStream, IOException} +import java.util.UUID + +import com.google.common.io.Files + +import org.apache.commons.io.FileUtils +import org.apache.hadoop.fs.{FileUtil, FileSystem, Path} +import org.apache.hadoop.conf.Configuration + + +private[streaming] +object MasterFailureTest extends Logging { + initLogging() + + @volatile var killed = false + @volatile var killCount = 0 + + def main(args: Array[String]) { + if (args.size < 2) { + println( + "Usage: MasterFailureTest <# batches> []") + System.exit(1) + } + val directory = args(0) + val numBatches = args(1).toInt + val batchDuration = if (args.size > 2) Milliseconds(args(2).toInt) else Seconds(1) + + println("\n\n========================= MAP TEST =========================\n\n") + testMap(directory, numBatches, batchDuration) + + println("\n\n================= UPDATE-STATE-BY-KEY TEST =================\n\n") + testUpdateStateByKey(directory, numBatches, batchDuration) + } + + def testMap(directory: String, numBatches: Int, batchDuration: Duration) { + // Input: time=1 ==> [ 1 ] , time=2 ==> [ 2 ] , time=3 ==> [ 3 ] , ... + val input = (1 to numBatches).map(_.toString).toSeq + // Expected output: time=1 ==> [ 1 ] , time=2 ==> [ 2 ] , time=3 ==> [ 3 ] , ... 
+ val expectedOutput = (1 to numBatches) + + val operation = (st: DStream[String]) => st.map(_.toInt) + + // Run streaming operation with multiple master failures + val output = testOperation(directory, batchDuration, input, operation, expectedOutput) + + logInfo("Expected output, size = " + expectedOutput.size) + logInfo(expectedOutput.mkString("[", ",", "]")) + logInfo("Output, size = " + output.size) + logInfo(output.mkString("[", ",", "]")) + + // Verify whether all the values of the expected output is present + // in the output + assert(output.distinct.toSet == expectedOutput.toSet) + } + + + def testUpdateStateByKey(directory: String, numBatches: Int, batchDuration: Duration) { + // Input: time=1 ==> [ a ] , time=2 ==> [ a, a ] , time=3 ==> [ a, a, a ] , ... + val input = (1 to numBatches).map(i => (1 to i).map(_ => "a").mkString(" ")).toSeq + // Expected output: time=1 ==> [ (a, 1) ] , time=2 ==> [ (a, 3) ] , time=3 ==> [ (a,6) ] , ... + val expectedOutput = (1L to numBatches).map(i => (1L to i).reduce(_ + _)).map(j => ("a", j)) + + val operation = (st: DStream[String]) => { + val updateFunc = (values: Seq[Long], state: Option[Long]) => { + Some(values.foldLeft(0L)(_ + _) + state.getOrElse(0L)) + } + st.flatMap(_.split(" ")) + .map(x => (x, 1L)) + .updateStateByKey[Long](updateFunc) + .checkpoint(batchDuration * 5) + } + + // Run streaming operation with multiple master failures + val output = testOperation(directory, batchDuration, input, operation, expectedOutput) + + logInfo("Expected output, size = " + expectedOutput.size + "\n" + expectedOutput) + logInfo("Output, size = " + output.size + "\n" + output) + + // Verify whether all the values in the output are among the expected output values + output.foreach(o => + assert(expectedOutput.contains(o), "Expected value " + o + " not found") + ) + + // Verify whether the last expected output value has been generated, there by + // confirming that none of the inputs have been missed + assert(output.last == expectedOutput.last) + } + + /** + * Tests stream operation with multiple master failures, and verifies whether the + * final set of output values is as expected or not. + */ + def testOperation[T: ClassManifest]( + directory: String, + batchDuration: Duration, + input: Seq[String], + operation: DStream[String] => DStream[T], + expectedOutput: Seq[T] + ): Seq[T] = { + + // Just making sure that the expected output does not have duplicates + assert(expectedOutput.distinct.toSet == expectedOutput.toSet) + + // Setup the stream computation with the given operation + val (ssc, checkpointDir, testDir) = setupStreams(directory, batchDuration, operation) + + // Start generating files in the a different thread + val fileGeneratingThread = new FileGeneratingThread(input, testDir, batchDuration.milliseconds) + fileGeneratingThread.start() + + // Run the streams and repeatedly kill it until the last expected output + // has been generated, or until it has run for twice the expected time + val lastExpectedOutput = expectedOutput.last + val maxTimeToRun = expectedOutput.size * batchDuration.milliseconds * 2 + val mergedOutput = runStreams(ssc, lastExpectedOutput, maxTimeToRun) + + // Delete directories + fileGeneratingThread.join() + val fs = checkpointDir.getFileSystem(new Configuration()) + fs.delete(checkpointDir, true) + fs.delete(testDir, true) + logInfo("Finished test after " + killCount + " failures") + mergedOutput + } + + /** + * Sets up the stream computation with the given operation, directory (local or HDFS), + * and batch duration. 
Returns the streaming context and the directory to which + * files should be written for testing. + */ + private def setupStreams[T: ClassManifest]( + directory: String, + batchDuration: Duration, + operation: DStream[String] => DStream[T] + ): (StreamingContext, Path, Path) = { + // Reset all state + reset() + + // Create the directories for this test + val uuid = UUID.randomUUID().toString + val rootDir = new Path(directory, uuid) + val fs = rootDir.getFileSystem(new Configuration()) + val checkpointDir = new Path(rootDir, "checkpoint") + val testDir = new Path(rootDir, "test") + fs.mkdirs(checkpointDir) + fs.mkdirs(testDir) + + // Setup the streaming computation with the given operation + System.clearProperty("spark.driver.port") + var ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration) + ssc.checkpoint(checkpointDir.toString) + val inputStream = ssc.textFileStream(testDir.toString) + val operatedStream = operation(inputStream) + val outputStream = new TestOutputStream(operatedStream) + ssc.registerOutputStream(outputStream) + (ssc, checkpointDir, testDir) + } + + + /** + * Repeatedly starts and kills the streaming context until timed out or + * the last expected output is generated. Finally, return + */ + private def runStreams[T: ClassManifest]( + ssc_ : StreamingContext, + lastExpectedOutput: T, + maxTimeToRun: Long + ): Seq[T] = { + + var ssc = ssc_ + var totalTimeRan = 0L + var isLastOutputGenerated = false + var isTimedOut = false + val mergedOutput = new ArrayBuffer[T]() + val checkpointDir = ssc.checkpointDir + var batchDuration = ssc.graph.batchDuration + + while(!isLastOutputGenerated && !isTimedOut) { + // Get the output buffer + val outputBuffer = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[T]].output + def output = outputBuffer.flatMap(x => x) + + // Start the thread to kill the streaming after some time + killed = false + val killingThread = new KillingThread(ssc, batchDuration.milliseconds * 10) + killingThread.start() + + var timeRan = 0L + try { + // Start the streaming computation and let it run while ... 
+ // (i) StreamingContext has not been shut down yet + // (ii) The last expected output has not been generated yet + // (iii) Its not timed out yet + System.clearProperty("spark.streaming.clock") + System.clearProperty("spark.driver.port") + ssc.start() + val startTime = System.currentTimeMillis() + while (!killed && !isLastOutputGenerated && !isTimedOut) { + Thread.sleep(100) + timeRan = System.currentTimeMillis() - startTime + isLastOutputGenerated = (!output.isEmpty && output.last == lastExpectedOutput) + isTimedOut = (timeRan + totalTimeRan > maxTimeToRun) + } + } catch { + case e: Exception => logError("Error running streaming context", e) + } + if (killingThread.isAlive) killingThread.interrupt() + ssc.stop() + + logInfo("Has been killed = " + killed) + logInfo("Is last output generated = " + isLastOutputGenerated) + logInfo("Is timed out = " + isTimedOut) + + // Verify whether the output of each batch has only one element or no element + // and then merge the new output with all the earlier output + mergedOutput ++= output + totalTimeRan += timeRan + logInfo("New output = " + output) + logInfo("Merged output = " + mergedOutput) + logInfo("Time ran = " + timeRan) + logInfo("Total time ran = " + totalTimeRan) + + if (!isLastOutputGenerated && !isTimedOut) { + val sleepTime = Random.nextInt(batchDuration.milliseconds.toInt * 10) + logInfo( + "\n-------------------------------------------\n" + + " Restarting stream computation in " + sleepTime + " ms " + + "\n-------------------------------------------\n" + ) + Thread.sleep(sleepTime) + // Recreate the streaming context from checkpoint + ssc = new StreamingContext(checkpointDir) + } + } + mergedOutput + } + + /** + * Verifies the output value are the same as expected. Since failures can lead to + * a batch being processed twice, a batches output may appear more than once + * consecutively. To avoid getting confused with those, we eliminate consecutive + * duplicate batch outputs of values from the `output`. As a result, the + * expected output should not have consecutive batches with the same values as output. + */ + private def verifyOutput[T: ClassManifest](output: Seq[T], expectedOutput: Seq[T]) { + // Verify whether expected outputs do not consecutive batches with same output + for (i <- 0 until expectedOutput.size - 1) { + assert(expectedOutput(i) != expectedOutput(i+1), + "Expected output has consecutive duplicate sequence of values") + } + + // Log the output + println("Expected output, size = " + expectedOutput.size) + println(expectedOutput.mkString("[", ",", "]")) + println("Output, size = " + output.size) + println(output.mkString("[", ",", "]")) + + // Match the output with the expected output + output.foreach(o => + assert(expectedOutput.contains(o), "Expected value " + o + " not found") + ) + } + + /** Resets counter to prepare for the test */ + private def reset() { + killed = false + killCount = 0 + } +} + +/** + * This is a output stream just for testing. All the output is collected into a + * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. 
+ */ +private[streaming] +class TestOutputStream[T: ClassManifest]( + parent: DStream[T], + val output: ArrayBuffer[Seq[T]] = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]] + ) extends ForEachDStream[T]( + parent, + (rdd: RDD[T], t: Time) => { + val collected = rdd.collect() + output += collected + println(t + ": " + collected.mkString("[", ",", "]")) + } + ) { + + // This is to clear the output buffer every it is read from a checkpoint + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + ois.defaultReadObject() + output.clear() + } +} + + +/** + * Thread to kill streaming context after a random period of time. + */ +private[streaming] +class KillingThread(ssc: StreamingContext, maxKillWaitTime: Long) extends Thread with Logging { + initLogging() + + override def run() { + try { + // If it is the first killing, then allow the first checkpoint to be created + var minKillWaitTime = if (MasterFailureTest.killCount == 0) 5000 else 1000 + val killWaitTime = minKillWaitTime + math.abs(Random.nextLong % maxKillWaitTime) + logInfo("Kill wait time = " + killWaitTime) + Thread.sleep(killWaitTime) + logInfo( + "\n---------------------------------------\n" + + "Killing streaming context after " + killWaitTime + " ms" + + "\n---------------------------------------\n" + ) + if (ssc != null) { + ssc.stop() + MasterFailureTest.killed = true + MasterFailureTest.killCount += 1 + } + logInfo("Killing thread finished normally") + } catch { + case ie: InterruptedException => logInfo("Killing thread interrupted") + case e: Exception => logWarning("Exception in killing thread", e) + } + + } +} + + +/** + * Thread to generate input files periodically with the desired text. + */ +private[streaming] +class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long) + extends Thread with Logging { + initLogging() + + override def run() { + val localTestDir = Files.createTempDir() + val fs = testDir.getFileSystem(new Configuration()) + try { + Thread.sleep(5000) // To make sure that all the streaming context has been set up + for (i <- 0 until input.size) { + // Write the data to a local file and then move it to the target test directory + val localFile = new File(localTestDir, (i+1).toString) + val hadoopFile = new Path(testDir, (i+1).toString) + FileUtils.writeStringToFile(localFile, input(i).toString + "\n") + //fs.moveFromLocalFile(new Path(localFile.toString), new Path(testDir, i.toString)) + fs.copyFromLocalFile(new Path(localFile.toString), hadoopFile) + logInfo("Generated file " + hadoopFile + " at " + System.currentTimeMillis) + Thread.sleep(interval) + localFile.delete() + } + logInfo("File generating thread finished normally") + } catch { + case ie: InterruptedException => logInfo("File generating thread interrupted") + case e: Exception => logWarning("File generating in killing thread", e) + } finally { + fs.close() + } + } +} + + diff --git a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala index db715cc295..8e10276deb 100644 --- a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala @@ -3,9 +3,9 @@ package spark.streaming.util private[streaming] class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => Unit) { - val minPollTime = 25L + private val minPollTime = 25L - val pollTime = { + private val pollTime = { if (period / 10.0 > minPollTime) { (period / 
10.0).toLong } else { @@ -13,11 +13,20 @@ class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => } } - val thread = new Thread() { + private val thread = new Thread() { override def run() { loop } } - var nextTime = 0L + private var nextTime = 0L + + def getStartTime(): Long = { + (math.floor(clock.currentTime.toDouble / period) + 1).toLong * period + } + + def getRestartTime(originalStartTime: Long): Long = { + val gap = clock.currentTime - originalStartTime + (math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime + } def start(startTime: Long): Long = { nextTime = startTime @@ -26,21 +35,14 @@ class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => } def start(): Long = { - val startTime = (math.floor(clock.currentTime.toDouble / period) + 1).toLong * period - start(startTime) + start(getStartTime()) } - def restart(originalStartTime: Long): Long = { - val gap = clock.currentTime - originalStartTime - val newStartTime = (math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime - start(newStartTime) - } - - def stop() { + def stop() { thread.interrupt() } - def loop() { + private def loop() { try { while (true) { clock.waitTillTime(nextTime) diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java index fbe4af4597..783a393a8f 100644 --- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java @@ -33,7 +33,8 @@ public class JavaAPISuite implements Serializable { @Before public void setUp() { - ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); ssc.checkpoint("checkpoint", new Duration(1000)); } @@ -45,7 +46,7 @@ public class JavaAPISuite implements Serializable { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port"); } - /* + @Test public void testCount() { List> inputData = Arrays.asList( @@ -434,7 +435,7 @@ public class JavaAPISuite implements Serializable { assertOrderInvariantEquals(expected, result); } - */ + /* * Performs an order-invariant comparison of lists representing two RDD streams. This allows * us to account for ordering variation within individual RDD's which occurs during windowing. @@ -450,7 +451,7 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(expected, actual); } - /* + // PairDStream Functions @Test public void testPairFilter() { @@ -897,7 +898,7 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(expected, result); } - */ + @Test public void testCheckpointMasterRecovery() throws InterruptedException { List> inputData = Arrays.asList( @@ -964,7 +965,7 @@ public class JavaAPISuite implements Serializable { assertOrderInvariantEquals(expected, result1); } */ - /* + // Input stream tests. These mostly just test that we can instantiate a given InputStream with // Java arguments and assign it to a JavaDStream without producing type errors. Testing of the // InputStream functionality is deferred to the existing Scala tests. 
@@ -972,9 +973,9 @@ public class JavaAPISuite implements Serializable { public void testKafkaStream() { HashMap topics = Maps.newHashMap(); HashMap offsets = Maps.newHashMap(); - JavaDStream test1 = ssc.kafkaStream("localhost", 12345, "group", topics); - JavaDStream test2 = ssc.kafkaStream("localhost", 12345, "group", topics, offsets); - JavaDStream test3 = ssc.kafkaStream("localhost", 12345, "group", topics, offsets, + JavaDStream test1 = ssc.kafkaStream("localhost:12345", "group", topics); + JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics, offsets); + JavaDStream test3 = ssc.kafkaStream("localhost:12345", "group", topics, offsets, StorageLevel.MEMORY_AND_DISK()); } @@ -1026,5 +1027,5 @@ public class JavaAPISuite implements Serializable { public void testFileStream() { JavaPairDStream foo = ssc.fileStream("/tmp/foo"); - }*/ + } } diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index edfa1243fa..5652596e1e 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -1,6 +1,7 @@ # Set everything to be logged to the file streaming/target/unit-tests.log -log4j.rootCategory=INFO, file -log4j.appender.file=org.apache.log4j.FileAppender +log4j.rootCategory=WARN, file +# log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file=org.apache.log4j.ConsoleAppender log4j.appender.file.append=false log4j.appender.file.file=streaming/target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout @@ -8,4 +9,6 @@ log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN +log4j.logger.spark.streaming=INFO +log4j.logger.spark.streaming.dstream.FileInputDStream=DEBUG diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala index c031949dd1..12388b8887 100644 --- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala @@ -6,6 +6,8 @@ import util.ManualClock class BasicOperationsSuite extends TestSuiteBase { + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + override def framework() = "BasicOperationsSuite" after { diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index 7126af62d9..c89c4a8d43 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -1,5 +1,6 @@ package spark.streaming +import dstream.FileInputDStream import spark.streaming.StreamingContext._ import java.io.File import runtime.RichInt @@ -10,8 +11,16 @@ import util.{Clock, ManualClock} import scala.util.Random import com.google.common.io.Files + +/** + * This test suites tests the checkpointing functionality of DStreams - + * the checkpointing of a DStream's RDDs as well as the checkpointing of + * the whole DStream graph. 
+ */ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + before { FileUtils.deleteDirectory(new File(checkpointDir)) } @@ -64,7 +73,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // Run till a time such that at least one RDD in the stream should have been checkpointed, // then check whether some RDD has been checkpointed or not ssc.start() - runStreamsWithRealDelay(ssc, firstNumBatches) + advanceTimeWithRealDelay(ssc, firstNumBatches) logInfo("Checkpoint data of state stream = \n" + stateStream.checkpointData) assert(!stateStream.checkpointData.checkpointFiles.isEmpty, "No checkpointed RDDs in state stream before first failure") stateStream.checkpointData.checkpointFiles.foreach { @@ -77,7 +86,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // Run till a further time such that previous checkpoint files in the stream would be deleted // and check whether the earlier checkpoint files are deleted val checkpointFiles = stateStream.checkpointData.checkpointFiles.map(x => new File(x._2)) - runStreamsWithRealDelay(ssc, secondNumBatches) + advanceTimeWithRealDelay(ssc, secondNumBatches) checkpointFiles.foreach(file => assert(!file.exists, "Checkpoint file '" + file + "' was not deleted")) ssc.stop() @@ -92,7 +101,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // Run one batch to generate a new checkpoint file and check whether some RDD // is present in the checkpoint data or not ssc.start() - runStreamsWithRealDelay(ssc, 1) + advanceTimeWithRealDelay(ssc, 1) assert(!stateStream.checkpointData.checkpointFiles.isEmpty, "No checkpointed RDDs in state stream before second failure") stateStream.checkpointData.checkpointFiles.foreach { case (time, data) => { @@ -113,7 +122,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // Adjust manual clock time as if it is being restarted after a delay System.setProperty("spark.streaming.manualClock.jump", (batchDuration.milliseconds * 7).toString) ssc.start() - runStreamsWithRealDelay(ssc, 4) + advanceTimeWithRealDelay(ssc, 4) ssc.stop() System.clearProperty("spark.streaming.manualClock.jump") ssc = null @@ -168,74 +177,95 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { } // This tests whether file input stream remembers what files were seen before - // the master failure and uses them again to process a large window operatoin. + // the master failure and uses them again to process a large window operation. // It also tests whether batches, whose processing was incomplete due to the // failure, are re-processed or not. 
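The test below exercises the driver-recovery path end to end: files are fed through textFileStream, the context is stopped to simulate a failure, and a new StreamingContext is rebuilt from the checkpoint directory. A minimal sketch of that recovery pattern outside the test harness is shown here; the directory paths and application name are illustrative assumptions, while the calls themselves (two-argument checkpoint, textFileStream, reduceByWindow, the checkpoint-directory constructor) are the ones used in the test:

// Sketch of the checkpoint-and-recover pattern the test below exercises.
// Directory paths and the application name are illustrative assumptions.
import spark.streaming.{Seconds, StreamingContext}

object CheckpointRecoverySketch {
  def main(args: Array[String]) {
    // First run: build the graph and enable checkpointing.
    val ssc = new StreamingContext("local[2]", "CheckpointRecoverySketch", Seconds(1))
    ssc.checkpoint("/tmp/sketch-checkpoint", Seconds(1))
    val sums = ssc.textFileStream("/tmp/sketch-input")
      .map(_.toInt)
      .reduceByWindow(_ + _, Seconds(30), Seconds(1))
    sums.print()
    ssc.start()

    // After a driver failure, the same graph (including the set of files
    // already seen by the FileInputDStream) is rebuilt from the checkpoint:
    //   val recovered = new StreamingContext("/tmp/sketch-checkpoint")
    //   recovered.start()
  }
}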
test("recovery with file input stream") { + // Disable manual clock as FileInputDStream does not work with manual clock + val clockProperty = System.getProperty("spark.streaming.clock") + System.clearProperty("spark.streaming.clock") + // Set up the streaming context and input streams val testDir = Files.createTempDir() - var ssc = new StreamingContext(master, framework, batchDuration) + var ssc = new StreamingContext(master, framework, Seconds(1)) ssc.checkpoint(checkpointDir, checkpointInterval) val fileStream = ssc.textFileStream(testDir.toString) // Making value 3 take large time to process, to ensure that the master // shuts down in the middle of processing the 3rd batch val mappedStream = fileStream.map(s => { val i = s.toInt - if (i == 3) Thread.sleep(1000) + if (i == 3) Thread.sleep(2000) i }) + // Reducing over a large window to ensure that recovery from master failure // requires reprocessing of all the files seen before the failure - val reducedStream = mappedStream.reduceByWindow(_ + _, batchDuration * 30, batchDuration) + val reducedStream = mappedStream.reduceByWindow(_ + _, Seconds(30), Seconds(1)) val outputBuffer = new ArrayBuffer[Seq[Int]] var outputStream = new TestOutputStream(reducedStream, outputBuffer) ssc.registerOutputStream(outputStream) ssc.start() // Create files and advance manual clock to process them - var clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + //var clock = ssc.scheduler.clock.asInstanceOf[ManualClock] Thread.sleep(1000) for (i <- Seq(1, 2, 3)) { FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") // wait to make sure that the file is written such that it gets shown in the file listings - Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) - // wait to make sure that FileInputDStream picks up this file only and not any other file - Thread.sleep(500) + Thread.sleep(1000) } logInfo("Output = " + outputStream.output.mkString(",")) assert(outputStream.output.size > 0, "No files processed before restart") ssc.stop() + // Verify whether files created have been recorded correctly or not + var fileInputDStream = ssc.graph.getInputStreams().head.asInstanceOf[FileInputDStream[_, _, _]] + def recordedFiles = fileInputDStream.files.values.flatMap(x => x) + assert(!recordedFiles.filter(_.endsWith("1")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("2")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("3")).isEmpty) + // Create files while the master is down for (i <- Seq(4, 5, 6)) { FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") Thread.sleep(1000) } - // Restart stream computation from checkpoint and create more files to see whether - // they are being processed + // Recover context from checkpoint file and verify whether the files that were + // recorded before failure were saved and successfully recovered logInfo("*********** RESTARTING ************") ssc = new StreamingContext(checkpointDir) + fileInputDStream = ssc.graph.getInputStreams().head.asInstanceOf[FileInputDStream[_, _, _]] + assert(!recordedFiles.filter(_.endsWith("1")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("2")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("3")).isEmpty) + + // Restart stream computation ssc.start() - clock = ssc.scheduler.clock.asInstanceOf[ManualClock] for (i <- Seq(7, 8, 9)) { FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") - Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) - Thread.sleep(500) + Thread.sleep(1000) } 
Thread.sleep(1000) - logInfo("Output = " + outputStream.output.mkString(",")) + logInfo("Output = " + outputStream.output.mkString("[", ", ", "]")) assert(outputStream.output.size > 0, "No files processed after restart") ssc.stop() + // Verify whether files created while the driver was down have been recorded or not + assert(!recordedFiles.filter(_.endsWith("4")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("5")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("6")).isEmpty) + + // Verify whether new files created after recover have been recorded or not + assert(!recordedFiles.filter(_.endsWith("7")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("8")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("9")).isEmpty) + // Append the new output to the old buffer outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[Int]] outputBuffer ++= outputStream.output - // Verify whether data received by Spark Streaming was as expected - val expectedOutput = Seq(1, 3, 6, 28, 36, 45) + val expectedOutput = Seq(1, 3, 6, 10, 15, 21, 28, 36, 45) logInfo("--------------------------------") logInfo("output, size = " + outputBuffer.size) outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) @@ -244,11 +274,17 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { logInfo("--------------------------------") // Verify whether all the elements received are as expected - assert(outputBuffer.size === expectedOutput.size) - for (i <- 0 until outputBuffer.size) { - assert(outputBuffer(i).size === 1) - assert(outputBuffer(i).head === expectedOutput(i)) - } + val output = outputBuffer.flatMap(x => x) + assert(output.contains(6)) // To ensure that the 3rd input (i.e., 3) was processed + output.foreach(o => // To ensure all the inputs are correctly added cumulatively + assert(expectedOutput.contains(o), "Expected value " + o + " not found") + ) + // To ensure that all the inputs were received correctly + assert(expectedOutput.last === output.last) + + // Enable manual clock back again for other tests + if (clockProperty != null) + System.setProperty("spark.streaming.clock", clockProperty) } @@ -278,7 +314,9 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // Do the computation for initial number of batches, create checkpoint file and quit ssc = setupStreams[U, V](input, operation) - val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs) + ssc.start() + val output = advanceTimeWithRealDelay[V](ssc, initialNumBatches) + ssc.stop() verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) Thread.sleep(1000) @@ -289,17 +327,20 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { "\n-------------------------------------------\n" ) ssc = new StreamingContext(checkpointDir) - val outputNew = runStreams[V](ssc, nextNumBatches, nextNumExpectedOutputs) + System.clearProperty("spark.driver.port") + ssc.start() + val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches) // the first element will be re-processed data of the last batch before restart verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) + ssc.stop() ssc = null } /** * Advances the manual clock on the streaming scheduler by given number of batches. - * It also wait for the expected amount of time for each batch. + * It also waits for the expected amount of time for each batch. 
*/ - def runStreamsWithRealDelay(ssc: StreamingContext, numBatches: Long) { + def advanceTimeWithRealDelay[V: ClassManifest](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] logInfo("Manual clock before advancing = " + clock.time) for (i <- 1 to numBatches.toInt) { @@ -308,6 +349,8 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { } logInfo("Manual clock after advancing = " + clock.time) Thread.sleep(batchDuration.milliseconds) - } + val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]] + outputStream.output + } } \ No newline at end of file diff --git a/streaming/src/test/scala/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/spark/streaming/FailureSuite.scala index efaa098d2e..a5fa7ab92d 100644 --- a/streaming/src/test/scala/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/spark/streaming/FailureSuite.scala @@ -1,14 +1,15 @@ package spark.streaming -import org.scalatest.{FunSuite, BeforeAndAfter} -import org.apache.commons.io.FileUtils -import java.io.File -import scala.runtime.RichInt -import scala.util.Random -import spark.streaming.StreamingContext._ -import collection.mutable.{SynchronizedBuffer, ArrayBuffer} import spark.Logging +import spark.streaming.util.MasterFailureTest +import StreamingContext._ + +import org.scalatest.{FunSuite, BeforeAndAfter} import com.google.common.io.Files +import java.io.File +import org.apache.commons.io.FileUtils +import collection.mutable.ArrayBuffer + /** * This testsuite tests master failures at random times while the stream is running using @@ -16,295 +17,24 @@ import com.google.common.io.Files */ class FailureSuite extends FunSuite with BeforeAndAfter with Logging { - var testDir: File = null - var checkpointDir: File = null - val batchDuration = Milliseconds(500) + var directory = "FailureSuite" + val numBatches = 30 + val batchDuration = Milliseconds(1000) before { - testDir = Files.createTempDir() - checkpointDir = Files.createTempDir() + FileUtils.deleteDirectory(new File(directory)) } after { - FailureSuite.reset() - FileUtils.deleteDirectory(checkpointDir) - FileUtils.deleteDirectory(testDir) + FileUtils.deleteDirectory(new File(directory)) + } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.driver.port") + test("multiple failures with map") { + MasterFailureTest.testMap(directory, numBatches, batchDuration) } test("multiple failures with updateStateByKey") { - val n = 30 - // Input: time=1 ==> [ a ] , time=2 ==> [ a, a ] , time=3 ==> [ a, a, a ] , ... - val input = (1 to n).map(i => (1 to i).map(_ => "a").mkString(" ")).toSeq - // Expected output: time=1 ==> [ (a, 1) ] , time=2 ==> [ (a, 3) ] , time=3 ==> [ (a,6) ] , ... 
- val expectedOutput = (1 to n).map(i => (1 to i).reduce(_ + _)).map(j => ("a", j)) - - val operation = (st: DStream[String]) => { - val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { - Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) - } - st.flatMap(_.split(" ")) - .map(x => (x, 1)) - .updateStateByKey[RichInt](updateFunc) - .checkpoint(Seconds(2)) - .map(t => (t._1, t._2.self)) - } - - testOperationWithMultipleFailures(input, operation, expectedOutput) - } - - test("multiple failures with reduceByKeyAndWindow") { - val n = 30 - val w = 100 - assert(w > n, "Window should be much larger than the number of input sets in this test") - // Input: time=1 ==> [ a ] , time=2 ==> [ a, a ] , time=3 ==> [ a, a, a ] , ... - val input = (1 to n).map(i => (1 to i).map(_ => "a").mkString(" ")).toSeq - // Expected output: time=1 ==> [ (a, 1) ] , time=2 ==> [ (a, 3) ] , time=3 ==> [ (a,6) ] , ... - val expectedOutput = (1 to n).map(i => (1 to i).reduce(_ + _)).map(j => ("a", j)) - - val operation = (st: DStream[String]) => { - st.flatMap(_.split(" ")) - .map(x => (x, 1)) - .reduceByKeyAndWindow(_ + _, _ - _, batchDuration * w, batchDuration) - .checkpoint(Seconds(2)) - } - - testOperationWithMultipleFailures(input, operation, expectedOutput) - } - - - /** - * Tests stream operation with multiple master failures, and verifies whether the - * final set of output values is as expected or not. Checking the final value is - * proof that no intermediate data was lost due to master failures. - */ - def testOperationWithMultipleFailures( - input: Seq[String], - operation: DStream[String] => DStream[(String, Int)], - expectedOutput: Seq[(String, Int)] - ) { - var ssc = setupStreamsWithFileStream(operation) - - val mergedOutput = new ArrayBuffer[(String, Int)]() - val lastExpectedOutput = expectedOutput.last - - val maxTimeToRun = expectedOutput.size * batchDuration.milliseconds * 2 - var totalTimeRan = 0L - - // Start generating files in the a different thread - val fileGeneratingThread = new FileGeneratingThread(input, testDir.getPath, batchDuration.milliseconds) - fileGeneratingThread.start() - - // Repeatedly start and kill the streaming context until timed out or - // all expected output is generated - while(!FailureSuite.outputGenerated && !FailureSuite.timedOut) { - - // Start the thread to kill the streaming after some time - FailureSuite.failed = false - val killingThread = new KillingThread(ssc, batchDuration.milliseconds * 10) - killingThread.start() - - // Run the streams with real clock until last expected output is seen or timed out - val (output, timeRan) = runStreamsWithRealClock(ssc, lastExpectedOutput, maxTimeToRun - totalTimeRan) - if (killingThread.isAlive) killingThread.interrupt() - - // Merge output and time ran and see whether already timed out or not - mergedOutput ++= output - totalTimeRan += timeRan - logInfo("New output = " + output) - logInfo("Merged output = " + mergedOutput) - logInfo("Total time spent = " + totalTimeRan) - if (totalTimeRan > maxTimeToRun) { - FailureSuite.timedOut = true - } - - if (!FailureSuite.outputGenerated && !FailureSuite.timedOut) { - val sleepTime = Random.nextInt(batchDuration.milliseconds.toInt * 2) - logInfo( - "\n-------------------------------------------\n" + - " Restarting stream computation in " + sleepTime + " ms " + - "\n-------------------------------------------\n" - ) - Thread.sleep(sleepTime) - } - - // Recreate the streaming context from checkpoint - ssc = new StreamingContext(checkpointDir.getPath) 
- } - ssc.stop() - ssc = null - logInfo("Finished test after " + FailureSuite.failureCount + " failures") - - if (FailureSuite.timedOut) { - logWarning("Timed out with run time of "+ maxTimeToRun + " ms for " + - expectedOutput.size + " batches of " + batchDuration) - } - - // Verify whether the output is as expected - verifyOutput(mergedOutput, expectedOutput) - if (fileGeneratingThread.isAlive) fileGeneratingThread.interrupt() - } - - /** Sets up the stream operations with file input stream */ - def setupStreamsWithFileStream( - operation: DStream[String] => DStream[(String, Int)] - ): StreamingContext = { - val ssc = new StreamingContext("local[4]", "FailureSuite", batchDuration) - ssc.checkpoint(checkpointDir.getPath) - val inputStream = ssc.textFileStream(testDir.getPath) - val operatedStream = operation(inputStream) - val outputBuffer = new ArrayBuffer[Seq[(String, Int)]] with SynchronizedBuffer[Seq[(String, Int)]] - val outputStream = new TestOutputStream(operatedStream, outputBuffer) - ssc.registerOutputStream(outputStream) - ssc - } - - /** - * Runs the streams set up in `ssc` on real clock. - */ - def runStreamsWithRealClock( - ssc: StreamingContext, - lastExpectedOutput: (String, Int), - timeout: Long - ): (Seq[(String, Int)], Long) = { - - System.clearProperty("spark.streaming.clock") - - // Get the output buffer - val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[(String, Int)]] - val output = outputStream.output - val startTime = System.currentTimeMillis() - - // Functions to detect various conditions - def hasFailed = FailureSuite.failed - def isLastOutputGenerated = !output.flatMap(x => x).isEmpty && output(output.lastIndexWhere(!_.isEmpty)).head == lastExpectedOutput - def isTimedOut = System.currentTimeMillis() - startTime > timeout - - // Start the streaming computation and let it run while ... - // (i) StreamingContext has not been shut down yet - // (ii) The last expected output has not been generated yet - // (iii) Its not timed out yet - try { - ssc.start() - while (!hasFailed && !isLastOutputGenerated && !isTimedOut) { - Thread.sleep(100) - } - logInfo("Has failed = " + hasFailed) - logInfo("Is last output generated = " + isLastOutputGenerated) - logInfo("Is timed out = " + isTimedOut) - } catch { - case e: Exception => logInfo("Exception while running streams: " + e) - } finally { - ssc.stop() - } - - // Verify whether the output of each batch has only one element - assert(output.forall(_.size <= 1), "output of each batch should have only one element") - - // Set appropriate flags is timed out or output has been generated - if (isTimedOut) FailureSuite.timedOut = true - if (isLastOutputGenerated) FailureSuite.outputGenerated = true - - val timeTaken = System.currentTimeMillis() - startTime - logInfo("" + output.size + " sets of output generated in " + timeTaken + " ms") - (output.flatMap(_.headOption), timeTaken) - } - - /** - * Verifies the output value are the same as expected. Since failures can lead to - * a batch being processed twice, a batches output may appear more than once - * consecutively. To avoid getting confused with those, we eliminate consecutive - * duplicate batch outputs of values from the `output`. As a result, the - * expected output should not have consecutive batches with the same values as output. 
- */ - def verifyOutput(output: Seq[(String, Int)], expectedOutput: Seq[(String, Int)]) { - // Verify whether expected outputs do not consecutive batches with same output - for (i <- 0 until expectedOutput.size - 1) { - assert(expectedOutput(i) != expectedOutput(i+1), - "Expected output has consecutive duplicate sequence of values") - } - - // Match the output with the expected output - logInfo( - "\n-------------------------------------------\n" + - " Verifying output " + - "\n-------------------------------------------\n" - ) - logInfo("Expected output, size = " + expectedOutput.size) - logInfo(expectedOutput.mkString("[", ",", "]")) - logInfo("Output, size = " + output.size) - logInfo(output.mkString("[", ",", "]")) - output.foreach(o => - assert(expectedOutput.contains(o), "Expected value " + o + " not found") - ) - } -} - -object FailureSuite { - var failed = false - var outputGenerated = false - var timedOut = false - var failureCount = 0 - - def reset() { - failed = false - outputGenerated = false - timedOut = false - failureCount = 0 - } -} - -/** - * Thread to kill streaming context after some time. - */ -class KillingThread(ssc: StreamingContext, maxKillWaitTime: Long) extends Thread with Logging { - initLogging() - - override def run() { - try { - var minKillWaitTime = if (FailureSuite.failureCount == 0) 5000 else 1000 // to allow the first checkpoint - val killWaitTime = minKillWaitTime + math.abs(Random.nextLong % maxKillWaitTime) - logInfo("Kill wait time = " + killWaitTime) - Thread.sleep(killWaitTime) - logInfo( - "\n---------------------------------------\n" + - "Killing streaming context after " + killWaitTime + " ms" + - "\n---------------------------------------\n" - ) - if (ssc != null) { - ssc.stop() - FailureSuite.failed = true - FailureSuite.failureCount += 1 - } - logInfo("Killing thread exited") - } catch { - case ie: InterruptedException => logInfo("Killing thread interrupted") - case e: Exception => logWarning("Exception in killing thread", e) - } - } -} - -/** - * Thread to generate input files periodically with the desired text - */ -class FileGeneratingThread(input: Seq[String], testDir: String, interval: Long) - extends Thread with Logging { - initLogging() - - override def run() { - try { - Thread.sleep(5000) // To make sure that all the streaming context has been set up - for (i <- 0 until input.size) { - FileUtils.writeStringToFile(new File(testDir, i.toString), input(i).toString + "\n") - Thread.sleep(interval) - } - logInfo("File generating thread exited") - } catch { - case ie: InterruptedException => logInfo("File generating thread interrupted") - case e: Exception => logWarning("File generating in killing thread", e) - } + MasterFailureTest.testUpdateStateByKey(directory, numBatches, batchDuration) } } diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index 0eb9c7b81e..7c1c2e1040 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -133,26 +133,29 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { test("file input stream") { + // Disable manual clock as FileInputDStream does not work with manual clock + System.clearProperty("spark.streaming.clock") + // Set up the streaming context and input streams val testDir = Files.createTempDir() val ssc = new StreamingContext(master, framework, batchDuration) - val filestream = 
ssc.textFileStream(testDir.toString) + val fileStream = ssc.textFileStream(testDir.toString) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] def output = outputBuffer.flatMap(x => x) - val outputStream = new TestOutputStream(filestream, outputBuffer) + val outputStream = new TestOutputStream(fileStream, outputBuffer) ssc.registerOutputStream(outputStream) ssc.start() // Create files in the temporary directory so that Spark Streaming can read data from it - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq(1, 2, 3, 4, 5) val expectedOutput = input.map(_.toString) Thread.sleep(1000) for (i <- 0 until input.size) { - FileUtils.writeStringToFile(new File(testDir, i.toString), input(i).toString + "\n") - Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) - //Thread.sleep(100) + val file = new File(testDir, i.toString) + FileUtils.writeStringToFile(file, input(i).toString + "\n") + logInfo("Created file " + file) + Thread.sleep(batchDuration.milliseconds) + Thread.sleep(1000) } val startTime = System.currentTimeMillis() Thread.sleep(1000) @@ -171,16 +174,16 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Verify whether all the elements received are as expected // (whether the elements were received one in each interval is not verified) - assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { - assert(output(i).size === 1) - assert(output(i).head.toString === expectedOutput(i)) - } + assert(output.toList === expectedOutput.toList) + FileUtils.deleteDirectory(testDir) + + // Enable manual clock back again for other tests + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") } } - +/** This is server to test the network input stream */ class TestServer(port: Int) extends Logging { val queue = new ArrayBlockingQueue[String](100) diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index c2733831b2..2cc31d6137 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -63,20 +63,28 @@ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBu */ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { + // Name of the framework for Spark context def framework = "TestSuiteBase" + // Master for Spark context def master = "local[2]" + // Batch duration def batchDuration = Seconds(1) + // Directory where the checkpoint data will be saved def checkpointDir = "checkpoint" + // Duration after which the graph is checkpointed def checkpointInterval = batchDuration + // Number of partitions of the input parallel collections created for testing def numInputPartitions = 2 + // Maximum time to wait before the test times out def maxWaitTimeMillis = 10000 + // Whether to actually wait in real time before changing manual clock def actuallyWait = false /** @@ -140,9 +148,6 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { numBatches: Int, numExpectedOutput: Int ): Seq[Seq[V]] = { - - System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") - assert(numBatches > 0, "Number of batches to run stream computation is zero") assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero") logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput) @@ -186,7 +191,6 @@ trait 
TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { } finally { ssc.stop() } - output } diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index cd9608df53..1080790147 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -5,6 +5,8 @@ import collection.mutable.ArrayBuffer class WindowOperationsSuite extends TestSuiteBase { + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + override def framework = "WindowOperationsSuite" override def maxWaitTimeMillis = 20000 From 12b020b6689b8db94df904d9b897a43bce18c971 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 13 Feb 2013 20:53:50 -0800 Subject: [PATCH 284/291] Added filter functionality to reduceByKeyAndWindow with inverse. Consolidated reduceByKeyAndWindow's many functions into smaller number of functions with optional parameters. --- .../main/scala/spark/streaming/DStream.scala | 2 +- .../streaming/PairDStreamFunctions.scala | 71 ++++++++----------- .../streaming/api/java/JavaPairDStream.scala | 28 +++++--- .../dstream/ReducedWindowedDStream.scala | 30 +++++--- .../streaming/util/MasterFailureTest.scala | 1 - streaming/src/test/resources/log4j.properties | 2 +- .../streaming/WindowOperationsSuite.scala | 49 ++++++++----- 7 files changed, 102 insertions(+), 81 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 0c1b667c0a..6abec9e6be 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -531,7 +531,7 @@ abstract class DStream[T: ClassManifest] ( windowDuration: Duration, slideDuration: Duration ): DStream[T] = { - this.window(windowDuration, slideDuration).reduce(reduceFunc) + this.reduce(reduceFunc).window(windowDuration, slideDuration).reduce(reduceFunc) } def reduceByWindow( diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index fbcf061126..021ff83b36 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -137,7 +137,8 @@ extends Serializable { * @param slideDuration sliding interval of the window (i.e., the interval after which * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval - * @param numPartitions Number of partitions of each RDD in the new DStream. + * @param numPartitions number of partitions of each RDD in the new DStream; if not specified + * then Spark's default number of partitions will be used */ def groupByKeyAndWindow( windowDuration: Duration, @@ -155,7 +156,7 @@ extends Serializable { * @param slideDuration sliding interval of the window (i.e., the interval after which * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval - * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + * @param partitioner partitioner for controlling the partitioning of each RDD in the new DStream. */ def groupByKeyAndWindow( windowDuration: Duration, @@ -213,7 +214,7 @@ extends Serializable { * @param numPartitions Number of partitions of each RDD in the new DStream. 
*/ def reduceByKeyAndWindow( - reduceFunc: (V, V) => V, + reduceFunc: (V, V) => V, windowDuration: Duration, slideDuration: Duration, numPartitions: Int @@ -230,7 +231,8 @@ extends Serializable { * @param slideDuration sliding interval of the window (i.e., the interval after which * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval - * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + * @param partitioner partitioner for controlling the partitioning of each RDD + * in the new DStream. */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, @@ -245,7 +247,7 @@ extends Serializable { } /** - * Create a new DStream by reducing over a using incremental computation. + * Create a new DStream by applying incremental `reduceByKey` over a sliding window. * The reduced value of over a new window is calculated using the old window's reduce value : * 1. reduce the new values that entered the window (e.g., adding new counts) * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) @@ -253,81 +255,64 @@ extends Serializable { * However, it is applicable to only "invertible reduce functions". * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. * @param reduceFunc associative reduce function - * @param invReduceFunc inverse function + * @param invReduceFunc inverse reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval + * @param filterFunc Optional function to filter expired key-value pairs; + * only pairs that satisfy the function are retained */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, windowDuration: Duration, - slideDuration: Duration + slideDuration: Duration = self.slideDuration, + numPartitions: Int = ssc.sc.defaultParallelism, + filterFunc: ((K, V)) => Boolean = null ): DStream[(K, V)] = { reduceByKeyAndWindow( - reduceFunc, invReduceFunc, windowDuration, slideDuration, defaultPartitioner()) + reduceFunc, invReduceFunc, windowDuration, + slideDuration, defaultPartitioner(numPartitions), filterFunc + ) } /** - * Create a new DStream by reducing over a using incremental computation. + * Create a new DStream by applying incremental `reduceByKey` over a sliding window. * The reduced value of over a new window is calculated using the old window's reduce value : * 1. reduce the new values that entered the window (e.g., adding new counts) * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. * However, it is applicable to only "invertible reduce functions". - * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. 
- * @param reduceFunc associative reduce function - * @param invReduceFunc inverse function + * @param reduceFunc associative reduce function + * @param invReduceFunc inverse reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval - * @param numPartitions Number of partitions of each RDD in the new DStream. + * @param partitioner partitioner for controlling the partitioning of each RDD in the new DStream. + * @param filterFunc Optional function to filter expired key-value pairs; + * only pairs that satisfy the function are retained */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, windowDuration: Duration, slideDuration: Duration, - numPartitions: Int - ): DStream[(K, V)] = { - - reduceByKeyAndWindow( - reduceFunc, invReduceFunc, windowDuration, slideDuration, defaultPartitioner(numPartitions)) - } - - /** - * Create a new DStream by reducing over a using incremental computation. - * The reduced value of over a new window is calculated using the old window's reduce value : - * 1. reduce the new values that entered the window (e.g., adding new counts) - * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) - * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. - * However, it is applicable to only "invertible reduce functions". - * @param reduceFunc associative reduce function - * @param invReduceFunc inverse function - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. - */ - def reduceByKeyAndWindow( - reduceFunc: (V, V) => V, - invReduceFunc: (V, V) => V, - windowDuration: Duration, - slideDuration: Duration, - partitioner: Partitioner + partitioner: Partitioner, + filterFunc: ((K, V)) => Boolean ): DStream[(K, V)] = { val cleanedReduceFunc = ssc.sc.clean(reduceFunc) val cleanedInvReduceFunc = ssc.sc.clean(invReduceFunc) + val cleanedFilterFunc = if (filterFunc != null) Some(ssc.sc.clean(filterFunc)) else None new ReducedWindowedDStream[K, V]( - self, cleanedReduceFunc, cleanedInvReduceFunc, windowDuration, slideDuration, partitioner) + self, cleanedReduceFunc, cleanedInvReduceFunc, cleanedFilterFunc, + windowDuration, slideDuration, partitioner + ) } /** diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala index ef10c091ca..4d3e0d0304 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala @@ -328,7 +328,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Create a new DStream by reducing over a using incremental computation. + * Create a new DStream by applying incremental `reduceByKey` over a sliding window. * The reduced value of over a new window is calculated using the old window's reduce value : * 1. reduce the new values that entered the window (e.g., adding new counts) * 2. 
"inverse reduce" the old values that left the window (e.g., subtracting old counts) @@ -342,25 +342,31 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * @param slideDuration sliding interval of the window (i.e., the interval after which * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval - * @param numPartitions Number of partitions of each RDD in the new DStream. + * @param numPartitions number of partitions of each RDD in the new DStream. + * @param filterFunc function to filter expired key-value pairs; + * only pairs that satisfy the function are retained + * set this to null if you do not want to filter */ def reduceByKeyAndWindow( reduceFunc: Function2[V, V, V], invReduceFunc: Function2[V, V, V], windowDuration: Duration, slideDuration: Duration, - numPartitions: Int + numPartitions: Int, + filterFunc: JFunction[(K, V), java.lang.Boolean] ): JavaPairDStream[K, V] = { dstream.reduceByKeyAndWindow( reduceFunc, invReduceFunc, windowDuration, slideDuration, - numPartitions) + numPartitions, + (p: (K, V)) => filterFunc(p).booleanValue() + ) } /** - * Create a new DStream by reducing over a using incremental computation. + * Create a new DStream by applying incremental `reduceByKey` over a sliding window. * The reduced value of over a new window is calculated using the old window's reduce value : * 1. reduce the new values that entered the window (e.g., adding new counts) * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) @@ -374,20 +380,26 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. 
+ * @param filterFunc function to filter expired key-value pairs; + * only pairs that satisfy the function are retained + * set this to null if you do not want to filter */ def reduceByKeyAndWindow( reduceFunc: Function2[V, V, V], invReduceFunc: Function2[V, V, V], windowDuration: Duration, slideDuration: Duration, - partitioner: Partitioner - ): JavaPairDStream[K, V] = { + partitioner: Partitioner, + filterFunc: JFunction[(K, V), java.lang.Boolean] + ): JavaPairDStream[K, V] = { dstream.reduceByKeyAndWindow( reduceFunc, invReduceFunc, windowDuration, slideDuration, - partitioner) + partitioner, + (p: (K, V)) => filterFunc(p).booleanValue() + ) } /** diff --git a/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala index 733d5c4a25..aa5a71e1ed 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -3,7 +3,7 @@ package spark.streaming.dstream import spark.streaming.StreamingContext._ import spark.RDD -import spark.rdd.CoGroupedRDD +import spark.rdd.{CoGroupedRDD, MapPartitionsRDD} import spark.Partitioner import spark.SparkContext._ import spark.storage.StorageLevel @@ -15,7 +15,8 @@ private[streaming] class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( parent: DStream[(K, V)], reduceFunc: (V, V) => V, - invReduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + filterFunc: Option[((K, V)) => Boolean], _windowDuration: Duration, _slideDuration: Duration, partitioner: Partitioner @@ -87,22 +88,25 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( // // Get the RDDs of the reduced values in "old time steps" - val oldRDDs = reducedStream.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration) + val oldRDDs = + reducedStream.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration) logDebug("# old RDDs = " + oldRDDs.size) // Get the RDDs of the reduced values in "new time steps" - val newRDDs = reducedStream.slice(previousWindow.endTime + parent.slideDuration, currentWindow.endTime) + val newRDDs = + reducedStream.slice(previousWindow.endTime + parent.slideDuration, currentWindow.endTime) logDebug("# new RDDs = " + newRDDs.size) // Get the RDD of the reduced value of the previous window - val previousWindowRDD = getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]())) + val previousWindowRDD = + getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]())) // Make the list of RDDs that needs to cogrouped together for reducing their reduced values val allRDDs = new ArrayBuffer[RDD[(K, V)]]() += previousWindowRDD ++= oldRDDs ++= newRDDs // Cogroup the reduced RDDs and merge the reduced values - val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(_, _)]]], partitioner) - //val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _ + val cogroupedRDD = + new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(_, _)]]], partitioner) val numOldValues = oldRDDs.size val numNewValues = newRDDs.size @@ -114,7 +118,9 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( // Getting reduced values "old time steps" that will be removed from current window val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head) // Getting reduced values "new time steps" - val newValues = (1 to 
numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) + val newValues = + (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) + if (seqOfValues(0).isEmpty) { // If previous window's reduce value does not exist, then at least new values should exist if (newValues.isEmpty) { @@ -140,10 +146,12 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValues) - Some(mergedValuesRDD) + if (filterFunc.isDefined) { + Some(mergedValuesRDD.filter(filterFunc.get)) + } else { + Some(mergedValuesRDD) + } } - - } diff --git a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala index 3ffe4b64d0..83d8591a3a 100644 --- a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala +++ b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala @@ -291,7 +291,6 @@ class TestOutputStream[T: ClassManifest]( (rdd: RDD[T], t: Time) => { val collected = rdd.collect() output += collected - println(t + ": " + collected.mkString("[", ",", "]")) } ) { diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index 5652596e1e..f0638e0e02 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -1,7 +1,7 @@ # Set everything to be logged to the file streaming/target/unit-tests.log log4j.rootCategory=WARN, file # log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file=org.apache.log4j.ConsoleAppender +log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false log4j.appender.file.file=streaming/target/unit-tests.log log4j.appender.file.layout=org.apache.log4j.PatternLayout diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index 1080790147..e6ac7b35aa 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -84,12 +84,9 @@ class WindowOperationsSuite extends TestSuiteBase { ) /* - The output of the reduceByKeyAndWindow with inverse reduce function is - different from the naive reduceByKeyAndWindow. Even if the count of a - particular key is 0, the key does not get eliminated from the RDDs of - ReducedWindowedDStream. This causes the number of keys in these RDDs to - increase forever. A more generalized version that allows elimination of - keys should be considered. + The output of the reduceByKeyAndWindow with inverse function but without a filter + function will be different from the naive reduceByKeyAndWindow, as no keys get + eliminated from the ReducedWindowedDStream even if the value of a key becomes 0. 
*/ val bigReduceInvOutput = Seq( @@ -177,31 +174,31 @@ class WindowOperationsSuite extends TestSuiteBase { // Testing reduceByKeyAndWindow (with invertible reduce function) - testReduceByKeyAndWindowInv( + testReduceByKeyAndWindowWithInverse( "basic reduction", Seq(Seq(("a", 1), ("a", 3)) ), Seq(Seq(("a", 4)) ) ) - testReduceByKeyAndWindowInv( + testReduceByKeyAndWindowWithInverse( "key already in window and new value added into window", Seq( Seq(("a", 1)), Seq(("a", 1)) ), Seq( Seq(("a", 1)), Seq(("a", 2)) ) ) - testReduceByKeyAndWindowInv( + testReduceByKeyAndWindowWithInverse( "new key added into window", Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ), Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)) ) ) - testReduceByKeyAndWindowInv( + testReduceByKeyAndWindowWithInverse( "key removed from window", Seq( Seq(("a", 1)), Seq(("a", 1)), Seq(), Seq() ), Seq( Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 1)), Seq(("a", 0)) ) ) - testReduceByKeyAndWindowInv( + testReduceByKeyAndWindowWithInverse( "larger slide time", largerSlideInput, largerSlideReduceOutput, @@ -209,7 +206,9 @@ class WindowOperationsSuite extends TestSuiteBase { Seconds(2) ) - testReduceByKeyAndWindowInv("big test", bigInput, bigReduceInvOutput) + testReduceByKeyAndWindowWithInverse("big test", bigInput, bigReduceInvOutput) + + testReduceByKeyAndWindowWithFilteredInverse("big test", bigInput, bigReduceOutput) test("groupByKeyAndWindow") { val input = bigInput @@ -276,27 +275,45 @@ class WindowOperationsSuite extends TestSuiteBase { test("reduceByKeyAndWindow - " + name) { val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt val operation = (s: DStream[(String, Int)]) => { - s.reduceByKeyAndWindow(_ + _, windowDuration, slideDuration).persist() + s.reduceByKeyAndWindow((x: Int, y: Int) => x + y, windowDuration, slideDuration) } testOperation(input, operation, expectedOutput, numBatches, true) } } - def testReduceByKeyAndWindowInv( + def testReduceByKeyAndWindowWithInverse( name: String, input: Seq[Seq[(String, Int)]], expectedOutput: Seq[Seq[(String, Int)]], windowDuration: Duration = Seconds(2), slideDuration: Duration = Seconds(1) ) { - test("reduceByKeyAndWindowInv - " + name) { + test("ReduceByKeyAndWindow with inverse function - " + name) { val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt val operation = (s: DStream[(String, Int)]) => { s.reduceByKeyAndWindow(_ + _, _ - _, windowDuration, slideDuration) - .persist() .checkpoint(Seconds(100)) // Large value to avoid effect of RDD checkpointing } testOperation(input, operation, expectedOutput, numBatches, true) } } + + def testReduceByKeyAndWindowWithFilteredInverse( + name: String, + input: Seq[Seq[(String, Int)]], + expectedOutput: Seq[Seq[(String, Int)]], + windowDuration: Duration = Seconds(2), + slideDuration: Duration = Seconds(1) + ) { + test("reduceByKeyAndWindow with inverse and filter functions - " + name) { + val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt + val filterFunc = (p: (String, Int)) => p._2 != 0 + val operation = (s: DStream[(String, Int)]) => { + s.reduceByKeyAndWindow(_ + _, _ - _, windowDuration, slideDuration, filterFunc = filterFunc) + .persist() + .checkpoint(Seconds(100)) // Large value to avoid effect of RDD checkpointing + } + testOperation(input, operation, expectedOutput, numBatches, true) + } + } } From 03e8dc6861936a0862fba1ca9f830d5ff507718f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 13 Feb 2013 20:59:29 -0800 Subject: [PATCH 285/291] Changes functions 
comments to make them more consistent. --- .../streaming/PairDStreamFunctions.scala | 42 ++++++++-------- .../streaming/api/java/JavaPairDStream.scala | 48 +++++++++---------- 2 files changed, 45 insertions(+), 45 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index 021ff83b36..835b20ae08 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -26,7 +26,7 @@ extends Serializable { } /** - * Create a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to + * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with Spark's default number of partitions. */ def groupByKey(): DStream[(K, Seq[V])] = { @@ -34,7 +34,7 @@ extends Serializable { } /** - * Create a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to + * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with `numPartitions` partitions. */ def groupByKey(numPartitions: Int): DStream[(K, Seq[V])] = { @@ -42,7 +42,7 @@ extends Serializable { } /** - * Create a new DStream by applying `groupByKey` on each RDD. The supplied [[spark.Partitioner]] + * Return a new DStream by applying `groupByKey` on each RDD. The supplied [[spark.Partitioner]] * is used to control the partitioning of each RDD. */ def groupByKey(partitioner: Partitioner): DStream[(K, Seq[V])] = { @@ -54,7 +54,7 @@ extends Serializable { } /** - * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are * merged using the associative reduce function. Hash partitioning is used to generate the RDDs * with Spark's default number of partitions. */ @@ -63,7 +63,7 @@ extends Serializable { } /** - * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are * merged using the supplied reduce function. Hash partitioning is used to generate the RDDs * with `numPartitions` partitions. */ @@ -72,7 +72,7 @@ extends Serializable { } /** - * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are * merged using the supplied reduce function. [[spark.Partitioner]] is used to control the * partitioning of each RDD. */ @@ -82,7 +82,7 @@ extends Serializable { } /** - * Combine elements of each key in DStream's RDDs using custom function. This is similar to the + * Combine elements of each key in DStream's RDDs using custom functions. This is similar to the * combineByKey for RDDs. Please refer to combineByKey in [[spark.PairRDDFunctions]] for more * information. */ @@ -95,7 +95,7 @@ extends Serializable { } /** - * Create a new DStream by counting the number of values of each key in each RDD. Hash + * Return a new DStream by counting the number of values of each key in each RDD. Hash * partitioning is used to generate the RDDs with Spark's `numPartitions` partitions. 
*/ def countByKey(numPartitions: Int = self.ssc.sc.defaultParallelism): DStream[(K, Long)] = { @@ -103,7 +103,7 @@ extends Serializable { } /** - * Creates a new DStream by applying `groupByKey` over a sliding window. This is similar to + * Return a new DStream by applying `groupByKey` over a sliding window. This is similar to * `DStream.groupByKey()` but applies it over a sliding window. The new DStream generates RDDs * with the same interval as this DStream. Hash partitioning is used to generate the RDDs with * Spark's default number of partitions. @@ -115,7 +115,7 @@ extends Serializable { } /** - * Create a new DStream by applying `groupByKey` over a sliding window. Similar to + * Return a new DStream by applying `groupByKey` over a sliding window. Similar to * `DStream.groupByKey()`, but applies it over a sliding window. Hash partitioning is used to * generate the RDDs with Spark's default number of partitions. * @param windowDuration width of the window; must be a multiple of this DStream's @@ -129,7 +129,7 @@ extends Serializable { } /** - * Create a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * Return a new DStream by applying `groupByKey` over a sliding window on `this` DStream. * Similar to `DStream.groupByKey()`, but applies it over a sliding window. * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. * @param windowDuration width of the window; must be a multiple of this DStream's @@ -167,7 +167,7 @@ extends Serializable { } /** - * Create a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. + * Return a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. * Similar to `DStream.reduceByKey()`, but applies it over a sliding window. The new DStream * generates RDDs with the same interval as this DStream. Hash partitioning is used to generate * the RDDs with Spark's default number of partitions. @@ -183,7 +183,7 @@ extends Serializable { } /** - * Create a new DStream by applying `reduceByKey` over a sliding window. This is similar to + * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to * generate the RDDs with Spark's default number of partitions. * @param reduceFunc associative reduce function @@ -202,7 +202,7 @@ extends Serializable { } /** - * Create a new DStream by applying `reduceByKey` over a sliding window. This is similar to + * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to * generate the RDDs with `numPartitions` partitions. * @param reduceFunc associative reduce function @@ -223,7 +223,7 @@ extends Serializable { } /** - * Create a new DStream by applying `reduceByKey` over a sliding window. Similar to + * Return a new DStream by applying `reduceByKey` over a sliding window. Similar to * `DStream.reduceByKey()`, but applies it over a sliding window. * @param reduceFunc associative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's @@ -247,7 +247,7 @@ extends Serializable { } /** - * Create a new DStream by applying incremental `reduceByKey` over a sliding window. + * Return a new DStream by applying incremental `reduceByKey` over a sliding window. 
* The reduced value of over a new window is calculated using the old window's reduce value : * 1. reduce the new values that entered the window (e.g., adding new counts) * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) @@ -280,7 +280,7 @@ extends Serializable { } /** - * Create a new DStream by applying incremental `reduceByKey` over a sliding window. + * Return a new DStream by applying incremental `reduceByKey` over a sliding window. * The reduced value of over a new window is calculated using the old window's reduce value : * 1. reduce the new values that entered the window (e.g., adding new counts) * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) @@ -316,7 +316,7 @@ extends Serializable { } /** - * Create a new DStream by counting the number of values for each key over a window. + * Return a new DStream by counting the number of values for each key over a window. * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -341,7 +341,7 @@ extends Serializable { } /** - * Create a new "state" DStream where the state for each key is updated by applying + * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. * @param updateFunc State update function. If `this` function returns None, then @@ -355,7 +355,7 @@ extends Serializable { } /** - * Create a new "state" DStream where the state for each key is updated by applying + * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. * @param updateFunc State update function. If `this` function returns None, then @@ -390,7 +390,7 @@ extends Serializable { } /** - * Create a new "state" DStream where the state for each key is updated by applying + * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. * [[spark.Paxrtitioner]] is used to control the partitioning of each RDD. * @param updateFunc State update function. If `this` function returns None, then diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala index 4d3e0d0304..048e10b69c 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala @@ -25,17 +25,17 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( // Methods common to all DStream's // ======================================================================= - /** Returns a new DStream containing only the elements that satisfy a predicate. */ + /** Return a new DStream containing only the elements that satisfy a predicate. 
*/ def filter(f: JFunction[(K, V), java.lang.Boolean]): JavaPairDStream[K, V] = dstream.filter((x => f(x).booleanValue())) - /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ + /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ def cache(): JavaPairDStream[K, V] = dstream.cache() - /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ + /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ def persist(): JavaPairDStream[K, V] = dstream.cache() - /** Persists the RDDs of this DStream with the given storage level */ + /** Persist the RDDs of this DStream with the given storage level */ def persist(storageLevel: StorageLevel): JavaPairDStream[K, V] = dstream.persist(storageLevel) /** Method that generates a RDD for the given Duration */ @@ -67,7 +67,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.window(windowDuration, slideDuration) /** - * Returns a new DStream which computed based on tumbling window on this DStream. + * Return a new DStream which computed based on tumbling window on this DStream. * This is equivalent to window(batchDuration, batchDuration). * @param batchDuration tumbling window duration; must be a multiple of this DStream's interval */ @@ -75,7 +75,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.tumble(batchDuration) /** - * Returns a new DStream by unifying data of another DStream with this DStream. + * Return a new DStream by unifying data of another DStream with this DStream. * @param that Another DStream having the same interval (i.e., slideDuration) as this DStream. */ def union(that: JavaPairDStream[K, V]): JavaPairDStream[K, V] = @@ -86,21 +86,21 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( // ======================================================================= /** - * Create a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to + * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with Spark's default number of partitions. */ def groupByKey(): JavaPairDStream[K, JList[V]] = dstream.groupByKey().mapValues(seqAsJavaList _) /** - * Create a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to + * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with `numPartitions` partitions. */ def groupByKey(numPartitions: Int): JavaPairDStream[K, JList[V]] = dstream.groupByKey(numPartitions).mapValues(seqAsJavaList _) /** - * Creates a new DStream by applying `groupByKey` on each RDD of `this` DStream. + * Return a new DStream by applying `groupByKey` on each RDD of `this` DStream. * Therefore, the values for each key in `this` DStream's RDDs are grouped into a * single sequence to generate the RDDs of the new DStream. [[spark.Partitioner]] * is used to control the partitioning of each RDD. @@ -109,7 +109,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.groupByKey(partitioner).mapValues(seqAsJavaList _) /** - * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are * merged using the associative reduce function. Hash partitioning is used to generate the RDDs * with Spark's default number of partitions. 
*/ @@ -117,7 +117,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.reduceByKey(func) /** - * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are * merged using the supplied reduce function. Hash partitioning is used to generate the RDDs * with `numPartitions` partitions. */ @@ -125,7 +125,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.reduceByKey(func, numPartitions) /** - * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are * merged using the supplied reduce function. [[spark.Partitioner]] is used to control the * partitioning of each RDD. */ @@ -149,7 +149,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Create a new DStream by counting the number of values of each key in each RDD. Hash + * Return a new DStream by counting the number of values of each key in each RDD. Hash * partitioning is used to generate the RDDs with Spark's `numPartitions` partitions. */ def countByKey(numPartitions: Int): JavaPairDStream[K, JLong] = { @@ -158,7 +158,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( /** - * Create a new DStream by counting the number of values of each key in each RDD. Hash + * Return a new DStream by counting the number of values of each key in each RDD. Hash * partitioning is used to generate the RDDs with the default number of partitions. */ def countByKey(): JavaPairDStream[K, JLong] = { @@ -166,7 +166,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Creates a new DStream by applying `groupByKey` over a sliding window. This is similar to + * Return a new DStream by applying `groupByKey` over a sliding window. This is similar to * `DStream.groupByKey()` but applies it over a sliding window. The new DStream generates RDDs * with the same interval as this DStream. Hash partitioning is used to generate the RDDs with * Spark's default number of partitions. @@ -178,7 +178,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Create a new DStream by applying `groupByKey` over a sliding window. Similar to + * Return a new DStream by applying `groupByKey` over a sliding window. Similar to * `DStream.groupByKey()`, but applies it over a sliding window. Hash partitioning is used to * generate the RDDs with Spark's default number of partitions. * @param windowDuration width of the window; must be a multiple of this DStream's @@ -193,7 +193,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Create a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * Return a new DStream by applying `groupByKey` over a sliding window on `this` DStream. * Similar to `DStream.groupByKey()`, but applies it over a sliding window. * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. * @param windowDuration width of the window; must be a multiple of this DStream's @@ -210,7 +210,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Create a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * Return a new DStream by applying `groupByKey` over a sliding window on `this` DStream. * Similar to `DStream.groupByKey()`, but applies it over a sliding window. 
* @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -243,7 +243,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Create a new DStream by applying `reduceByKey` over a sliding window. This is similar to + * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to * generate the RDDs with Spark's default number of partitions. * @param reduceFunc associative reduce function @@ -262,7 +262,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Create a new DStream by applying `reduceByKey` over a sliding window. This is similar to + * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to * generate the RDDs with `numPartitions` partitions. * @param reduceFunc associative reduce function @@ -283,7 +283,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Create a new DStream by applying `reduceByKey` over a sliding window. Similar to + * Return a new DStream by applying `reduceByKey` over a sliding window. Similar to * `DStream.reduceByKey()`, but applies it over a sliding window. * @param reduceFunc associative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's @@ -303,7 +303,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Create a new DStream by reducing over a using incremental computation. + * Return a new DStream by reducing over a using incremental computation. * The reduced value of over a new window is calculated using the old window's reduce value : * 1. reduce the new values that entered the window (e.g., adding new counts) * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) @@ -328,7 +328,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Create a new DStream by applying incremental `reduceByKey` over a sliding window. + * Return a new DStream by applying incremental `reduceByKey` over a sliding window. * The reduced value of over a new window is calculated using the old window's reduce value : * 1. reduce the new values that entered the window (e.g., adding new counts) * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) @@ -366,7 +366,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Create a new DStream by applying incremental `reduceByKey` over a sliding window. + * Return a new DStream by applying incremental `reduceByKey` over a sliding window. * The reduced value of over a new window is calculated using the old window's reduce value : * 1. reduce the new values that entered the window (e.g., adding new counts) * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) From 2eacf22401f75b956036fb0c32eb38baa16b224e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 14 Feb 2013 12:21:47 -0800 Subject: [PATCH 286/291] Removed countByKeyAndWindow on paired DStreams, and added countByValueAndWindow for all DStreams. Updated both scala and java API and testsuites. 
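For illustration, a minimal sketch of the new operations described in this patch, assuming a StreamingContext ssc, a host and port, and Seconds in scope (countByValueAndWindow maintains its counts incrementally, so a checkpoint directory is typically required):

    val words = ssc.networkTextStream(host, port).flatMap(_.split(" "))
    // Count occurrences of each distinct word in every batch
    // (replaces the old words.map(w => (w, 1)).countByKey() pattern).
    val counts = words.countByValue()
    // Count occurrences of each distinct word over the last 10 seconds, sliding every
    // 2 seconds (replaces the removed countByKeyAndWindow on paired DStreams).
    val windowedCounts = words.countByValueAndWindow(Seconds(10), Seconds(2))
    windowedCounts.print()
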
--- .../examples/clickstream/PageViewStream.scala | 11 ++- .../main/scala/spark/streaming/DStream.scala | 88 +++++++++++++++---- .../streaming/PairDStreamFunctions.scala | 43 ++------- .../streaming/api/java/JavaDStream.scala | 27 +++--- .../streaming/api/java/JavaDStreamLike.scala | 87 +++++++++++++++++- .../streaming/api/java/JavaPairDStream.scala | 56 +----------- .../java/spark/streaming/JavaAPISuite.java | 79 +++++++---------- .../streaming/BasicOperationsSuite.scala | 21 ++++- .../streaming/WindowOperationsSuite.scala | 8 +- 9 files changed, 231 insertions(+), 189 deletions(-) diff --git a/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala b/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala index a191321d91..60f228b8ad 100644 --- a/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala @@ -28,16 +28,15 @@ object PageViewStream { // Create a NetworkInputDStream on target host:port and convert each line to a PageView val pageViews = ssc.networkTextStream(host, port) - .flatMap(_.split("\n")) - .map(PageView.fromString(_)) + .flatMap(_.split("\n")) + .map(PageView.fromString(_)) // Return a count of views per URL seen in each batch - val pageCounts = pageViews.map(view => ((view.url, 1))).countByKey() + val pageCounts = pageViews.map(view => view.url).countByValue() // Return a sliding window of page views per URL in the last ten seconds - val slidingPageCounts = pageViews.map(view => ((view.url, 1))) - .window(Seconds(10), Seconds(2)) - .countByKey() + val slidingPageCounts = pageViews.map(view => view.url) + .countByValueAndWindow(Seconds(10), Seconds(2)) // Return the rate of error pages (a non 200 status) in each zip code over the last 30 seconds diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 6abec9e6be..ce42b742d7 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -441,6 +441,15 @@ abstract class DStream[T: ClassManifest] ( */ def count(): DStream[Long] = this.map(_ => 1L).reduce(_ + _) + /** + * Return a new DStream in which each RDD contains the counts of each distinct value in + * each RDD of this DStream. Hash partitioning is used to generate + * the RDDs with `numPartitions` partitions (Spark's default number of partitions if + * `numPartitions` not specified). + */ + def countByValue(numPartitions: Int = ssc.sc.defaultParallelism): DStream[(T, Long)] = + this.map(x => (x, 1L)).reduceByKey((x: Long, y: Long) => x + y, numPartitions) + /** * Apply a function to each RDD in this DStream. This is an output operator, so * this DStream will be registered as an output stream and therefore materialized. @@ -494,14 +503,16 @@ abstract class DStream[T: ClassManifest] ( } /** - * Return a new DStream which is computed based on windowed batches of this DStream. - * The new DStream generates RDDs with the same interval as this DStream. + * Return a new DStream in which each RDD contains all the elements in seen in a + * sliding window of time over this DStream. The new DStream generates RDDs with + * the same interval as this DStream. * @param windowDuration width of the window; must be a multiple of this DStream's interval. 
*/ def window(windowDuration: Duration): DStream[T] = window(windowDuration, this.slideDuration) /** - * Return a new DStream which is computed based on windowed batches of this DStream. + * Return a new DStream in which each RDD contains all the elements in seen in a + * sliding window of time over this DStream. * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -512,19 +523,15 @@ abstract class DStream[T: ClassManifest] ( new WindowedDStream(this, windowDuration, slideDuration) } - /** - * Return a new DStream which computed based on tumbling window on this DStream. - * This is equivalent to window(batchTime, batchTime). - * @param batchDuration tumbling window duration; must be a multiple of this DStream's - * batching interval - */ - def tumble(batchDuration: Duration): DStream[T] = window(batchDuration, batchDuration) - /** * Return a new DStream in which each RDD has a single element generated by reducing all - * elements in a window over this DStream. windowDuration and slideDuration are as defined - * in the window() operation. This is equivalent to - * window(windowDuration, slideDuration).reduce(reduceFunc) + * elements in a sliding window over this DStream. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval */ def reduceByWindow( reduceFunc: (T, T) => T, @@ -534,6 +541,22 @@ abstract class DStream[T: ClassManifest] ( this.reduce(reduceFunc).window(windowDuration, slideDuration).reduce(reduceFunc) } + /** + * Return a new DStream in which each RDD has a single element generated by reducing all + * elements in a sliding window over this DStream. However, the reduction is done incrementally + * using the old window's reduced value : + * 1. reduce the new values that entered the window (e.g., adding new counts) + * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + * This is more efficient than reduceByWindow without "inverse reduce" function. + * However, it is applicable to only "invertible reduce functions". + * @param reduceFunc associative reduce function + * @param invReduceFunc inverse reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ def reduceByWindow( reduceFunc: (T, T) => T, invReduceFunc: (T, T) => T, @@ -547,13 +570,46 @@ abstract class DStream[T: ClassManifest] ( /** * Return a new DStream in which each RDD has a single element generated by counting the number - * of elements in a window over this DStream. windowDuration and slideDuration are as defined in the - * window() operation. This is equivalent to window(windowDuration, slideDuration).count() + * of elements in a sliding window over this DStream. Hash partitioning is used to generate the RDDs with + * Spark's default number of partitions. 
+ * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval */ def countByWindow(windowDuration: Duration, slideDuration: Duration): DStream[Long] = { this.map(_ => 1L).reduceByWindow(_ + _, _ - _, windowDuration, slideDuration) } + /** + * Return a new DStream in which each RDD contains the count of distinct elements in + * RDDs in a sliding window over this DStream. Hash partitioning is used to generate + * the RDDs with `numPartitions` partitions (Spark's default number of partitions if + * `numPartitions` not specified). + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions number of partitions of each RDD in the new DStream. + */ + def countByValueAndWindow( + windowDuration: Duration, + slideDuration: Duration, + numPartitions: Int = ssc.sc.defaultParallelism + ): DStream[(T, Long)] = { + + this.map(x => (x, 1L)).reduceByKeyAndWindow( + (x: Long, y: Long) => x + y, + (x: Long, y: Long) => x - y, + windowDuration, + slideDuration, + numPartitions, + (x: (T, Long)) => x._2 != 0L + ) + } + /** * Return a new DStream by unifying data of another DStream with this DStream. * @param that Another DStream having the same slideDuration as this DStream. diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index 835b20ae08..5127db3bbc 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -94,14 +94,6 @@ extends Serializable { new ShuffledDStream[K, V, C](self, createCombiner, mergeValue, mergeCombiner, partitioner) } - /** - * Return a new DStream by counting the number of values of each key in each RDD. Hash - * partitioning is used to generate the RDDs with Spark's `numPartitions` partitions. - */ - def countByKey(numPartitions: Int = self.ssc.sc.defaultParallelism): DStream[(K, Long)] = { - self.map(x => (x._1, 1L)).reduceByKey((x: Long, y: Long) => x + y, numPartitions) - } - /** * Return a new DStream by applying `groupByKey` over a sliding window. This is similar to * `DStream.groupByKey()` but applies it over a sliding window. The new DStream generates RDDs @@ -211,7 +203,7 @@ extends Serializable { * @param slideDuration sliding interval of the window (i.e., the interval after which * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval - * @param numPartitions Number of partitions of each RDD in the new DStream. + * @param numPartitions number of partitions of each RDD in the new DStream. */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, @@ -248,10 +240,10 @@ extends Serializable { /** * Return a new DStream by applying incremental `reduceByKey` over a sliding window. - * The reduced value of over a new window is calculated using the old window's reduce value : + * The reduced value of over a new window is calculated using the old window's reduced value : * 1. reduce the new values that entered the window (e.g., adding new counts) * 2. 
"inverse reduce" the old values that left the window (e.g., subtracting old counts) - * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. + * This is more efficient than reduceByKeyAndWindow without "inverse reduce" function. * However, it is applicable to only "invertible reduce functions". * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. * @param reduceFunc associative reduce function @@ -281,10 +273,10 @@ extends Serializable { /** * Return a new DStream by applying incremental `reduceByKey` over a sliding window. - * The reduced value of over a new window is calculated using the old window's reduce value : + * The reduced value of over a new window is calculated using the old window's reduced value : * 1. reduce the new values that entered the window (e.g., adding new counts) * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) - * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. + * This is more efficient than reduceByKeyAndWindow without "inverse reduce" function. * However, it is applicable to only "invertible reduce functions". * @param reduceFunc associative reduce function * @param invReduceFunc inverse reduce function @@ -315,31 +307,6 @@ extends Serializable { ) } - /** - * Return a new DStream by counting the number of values for each key over a window. - * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - * @param numPartitions Number of partitions of each RDD in the new DStream. - */ - def countByKeyAndWindow( - windowDuration: Duration, - slideDuration: Duration, - numPartitions: Int = self.ssc.sc.defaultParallelism - ): DStream[(K, Long)] = { - - self.map(x => (x._1, 1L)).reduceByKeyAndWindow( - (x: Long, y: Long) => x + y, - (x: Long, y: Long) => x - y, - windowDuration, - slideDuration, - numPartitions - ) - } - /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala index 2e7466b16c..30985b4ebc 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala @@ -36,7 +36,7 @@ class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassM def cache(): JavaDStream[T] = dstream.cache() /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ - def persist(): JavaDStream[T] = dstream.cache() + def persist(): JavaDStream[T] = dstream.persist() /** Persist the RDDs of this DStream with the given storage level */ def persist(storageLevel: StorageLevel): JavaDStream[T] = dstream.persist(storageLevel) @@ -50,33 +50,26 @@ class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassM } /** - * Return a new DStream which is computed based on windowed batches of this DStream. - * The new DStream generates RDDs with the same interval as this DStream. 
+ * Return a new DStream in which each RDD contains all the elements in seen in a + * sliding window of time over this DStream. The new DStream generates RDDs with + * the same interval as this DStream. * @param windowDuration width of the window; must be a multiple of this DStream's interval. - * @return */ def window(windowDuration: Duration): JavaDStream[T] = dstream.window(windowDuration) /** - * Return a new DStream which is computed based on windowed batches of this DStream. - * @param windowDuration duration (i.e., width) of the window; - * must be a multiple of this DStream's interval + * Return a new DStream in which each RDD contains all the elements in seen in a + * sliding window of time over this DStream. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's interval + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval */ def window(windowDuration: Duration, slideDuration: Duration): JavaDStream[T] = dstream.window(windowDuration, slideDuration) - /** - * Return a new DStream which computed based on tumbling window on this DStream. - * This is equivalent to window(batchDuration, batchDuration). - * @param batchDuration tumbling window duration; must be a multiple of this DStream's interval - */ - def tumble(batchDuration: Duration): JavaDStream[T] = - dstream.tumble(batchDuration) - /** * Return a new DStream by unifying data of another DStream with this DStream. * @param that Another DStream having the same interval (i.e., slideDuration) as this DStream. diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala index b93cb7865a..1c1ba05ff9 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala @@ -33,6 +33,26 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable */ def count(): JavaDStream[JLong] = dstream.count() + /** + * Return a new DStream in which each RDD contains the counts of each distinct value in + * each RDD of this DStream. Hash partitioning is used to generate the RDDs with + * Spark's default number of partitions. + */ + def countByValue(): JavaPairDStream[T, JLong] = { + JavaPairDStream.scalaToJavaLong(dstream.countByValue()) + } + + /** + * Return a new DStream in which each RDD contains the counts of each distinct value in + * each RDD of this DStream. Hash partitioning is used to generate the RDDs with `numPartitions` + * partitions. + * @param numPartitions number of partitions of each RDD in the new DStream. + */ + def countByValue(numPartitions: Int): JavaPairDStream[T, JLong] = { + JavaPairDStream.scalaToJavaLong(dstream.countByValue(numPartitions)) + } + + /** * Return a new DStream in which each RDD has a single element generated by counting the number * of elements in a window over this DStream. windowDuration and slideDuration are as defined in the @@ -42,6 +62,39 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable dstream.countByWindow(windowDuration, slideDuration) } + /** + * Return a new DStream in which each RDD contains the count of distinct elements in + * RDDs in a sliding window over this DStream. 
Hash partitioning is used to generate the RDDs with + * Spark's default number of partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ + def countByValueAndWindow(windowDuration: Duration, slideDuration: Duration) + : JavaPairDStream[T, JLong] = { + JavaPairDStream.scalaToJavaLong( + dstream.countByValueAndWindow(windowDuration, slideDuration)) + } + + /** + * Return a new DStream in which each RDD contains the count of distinct elements in + * RDDs in a sliding window over this DStream. Hash partitioning is used to generate the RDDs with `numPartitions` + * partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions number of partitions of each RDD in the new DStream. + */ + def countByValueAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) + : JavaPairDStream[T, JLong] = { + JavaPairDStream.scalaToJavaLong( + dstream.countByValueAndWindow(windowDuration, slideDuration, numPartitions)) + } + /** * Return a new DStream in which each RDD is generated by applying glom() to each RDD of * this DStream. Applying glom() to an RDD coalesces all elements within each partition into @@ -114,8 +167,38 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable /** * Return a new DStream in which each RDD has a single element generated by reducing all - * elements in a window over this DStream. windowDuration and slideDuration are as defined in the - * window() operation. This is equivalent to window(windowDuration, slideDuration).reduce(reduceFunc) + * elements in a sliding window over this DStream. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ + def reduceByWindow( + reduceFunc: (T, T) => T, + windowDuration: Duration, + slideDuration: Duration + ): DStream[T] = { + dstream.reduceByWindow(reduceFunc, windowDuration, slideDuration) + } + + + /** + * Return a new DStream in which each RDD has a single element generated by reducing all + * elements in a sliding window over this DStream. However, the reduction is done incrementally + * using the old window's reduced value : + * 1. reduce the new values that entered the window (e.g., adding new counts) + * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + * This is more efficient than reduceByWindow without "inverse reduce" function. + * However, it is applicable to only "invertible reduce functions". 
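// For illustration only: a minimal sketch of the incremental form described above,
// assuming a DStream[Long] named amounts. Because addition is invertible, each new
// window is computed from the previous window's value by adding the batches that
// entered and subtracting the batches that left, instead of re-reducing the entire
// window:
//
//   val windowedTotals = amounts.reduceByWindow(_ + _, _ - _, Seconds(30), Seconds(10))
//
// The non-incremental overload, reduceByWindow(_ + _, Seconds(30), Seconds(10)),
// produces the same result but re-reduces every element in the window on each slide.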
+ * @param reduceFunc associative reduce function + * @param invReduceFunc inverse reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval */ def reduceByWindow( reduceFunc: JFunction2[T, T, T], diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala index 048e10b69c..952ca657bf 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala @@ -33,7 +33,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( def cache(): JavaPairDStream[K, V] = dstream.cache() /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ - def persist(): JavaPairDStream[K, V] = dstream.cache() + def persist(): JavaPairDStream[K, V] = dstream.persist() /** Persist the RDDs of this DStream with the given storage level */ def persist(storageLevel: StorageLevel): JavaPairDStream[K, V] = dstream.persist(storageLevel) @@ -66,14 +66,6 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( def window(windowDuration: Duration, slideDuration: Duration): JavaPairDStream[K, V] = dstream.window(windowDuration, slideDuration) - /** - * Return a new DStream which computed based on tumbling window on this DStream. - * This is equivalent to window(batchDuration, batchDuration). - * @param batchDuration tumbling window duration; must be a multiple of this DStream's interval - */ - def tumble(batchDuration: Duration): JavaPairDStream[K, V] = - dstream.tumble(batchDuration) - /** * Return a new DStream by unifying data of another DStream with this DStream. * @param that Another DStream having the same interval (i.e., slideDuration) as this DStream. @@ -148,23 +140,6 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.combineByKey(createCombiner, mergeValue, mergeCombiners, partitioner) } - /** - * Return a new DStream by counting the number of values of each key in each RDD. Hash - * partitioning is used to generate the RDDs with Spark's `numPartitions` partitions. - */ - def countByKey(numPartitions: Int): JavaPairDStream[K, JLong] = { - JavaPairDStream.scalaToJavaLong(dstream.countByKey(numPartitions)); - } - - - /** - * Return a new DStream by counting the number of values of each key in each RDD. Hash - * partitioning is used to generate the RDDs with the default number of partitions. - */ - def countByKey(): JavaPairDStream[K, JLong] = { - JavaPairDStream.scalaToJavaLong(dstream.countByKey()); - } - /** * Return a new DStream by applying `groupByKey` over a sliding window. This is similar to * `DStream.groupByKey()` but applies it over a sliding window. The new DStream generates RDDs @@ -402,35 +377,6 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( ) } - /** - * Create a new DStream by counting the number of values for each key over a window. - * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. 
- * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - */ - def countByKeyAndWindow(windowDuration: Duration, slideDuration: Duration) - : JavaPairDStream[K, JLong] = { - JavaPairDStream.scalaToJavaLong(dstream.countByKeyAndWindow(windowDuration, slideDuration)) - } - - /** - * Create a new DStream by counting the number of values for each key over a window. - * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - * @param numPartitions Number of partitions of each RDD in the new DStream. - */ - def countByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) - : JavaPairDStream[K, Long] = { - dstream.countByKeyAndWindow(windowDuration, slideDuration, numPartitions) - } - private def convertUpdateStateFunction[S](in: JFunction2[JList[V], Optional[S], Optional[S]]): (Seq[V], Option[S]) => Option[S] = { val scalaFunc: (Seq[V], Option[S]) => Option[S] = (values, state) => { diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java index 783a393a8f..7bea0b1fc4 100644 --- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java @@ -134,29 +134,6 @@ public class JavaAPISuite implements Serializable { assertOrderInvariantEquals(expected, result); } - @Test - public void testTumble() { - List> inputData = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6), - Arrays.asList(7,8,9), - Arrays.asList(10,11,12), - Arrays.asList(13,14,15), - Arrays.asList(16,17,18)); - - List> expected = Arrays.asList( - Arrays.asList(1,2,3,4,5,6), - Arrays.asList(7,8,9,10,11,12), - Arrays.asList(13,14,15,16,17,18)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream windowed = stream.tumble(new Duration(2000)); - JavaTestUtils.attachTestOutputStream(windowed); - List> result = JavaTestUtils.runStreams(ssc, 6, 3); - - assertOrderInvariantEquals(expected, result); - } - @Test public void testFilter() { List> inputData = Arrays.asList( @@ -584,24 +561,26 @@ public class JavaAPISuite implements Serializable { } @Test - public void testCountByKey() { - List>> inputData = stringStringKVStream; + public void testCountByValue() { + List> inputData = Arrays.asList( + Arrays.asList("hello", "world"), + Arrays.asList("hello", "moon"), + Arrays.asList("hello")); List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2("california", 2L), - new Tuple2("new york", 2L)), - Arrays.asList( - new Tuple2("california", 2L), - new Tuple2("new york", 2L))); + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("world", 1L)), + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("moon", 1L)), + Arrays.asList( + new Tuple2("hello", 1L))); - JavaDStream> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream counted = pairStream.countByKey(); + JavaDStream 
stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream counted = stream.countByValue(); JavaTestUtils.attachTestOutputStream(counted); - List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); Assert.assertEquals(expected, result); } @@ -712,26 +691,28 @@ public class JavaAPISuite implements Serializable { } @Test - public void testCountByKeyAndWindow() { - List>> inputData = stringStringKVStream; + public void testCountByValueAndWindow() { + List> inputData = Arrays.asList( + Arrays.asList("hello", "world"), + Arrays.asList("hello", "moon"), + Arrays.asList("hello")); List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("california", 2L), - new Tuple2("new york", 2L)), + new Tuple2("hello", 1L), + new Tuple2("world", 1L)), Arrays.asList( - new Tuple2("california", 4L), - new Tuple2("new york", 4L)), + new Tuple2("hello", 2L), + new Tuple2("world", 1L), + new Tuple2("moon", 1L)), Arrays.asList( - new Tuple2("california", 2L), - new Tuple2("new york", 2L))); + new Tuple2("hello", 2L), + new Tuple2("moon", 1L))); - JavaDStream> stream = JavaTestUtils.attachTestInputStream( + JavaDStream stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream counted = - pairStream.countByKeyAndWindow(new Duration(2000), new Duration(1000)); + stream.countByValueAndWindow(new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(counted); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala index 12388b8887..1e86cf49bb 100644 --- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala @@ -24,7 +24,7 @@ class BasicOperationsSuite extends TestSuiteBase { ) } - test("flatmap") { + test("flatMap") { val input = Seq(1 to 4, 5 to 8, 9 to 12) testOperation( input, @@ -88,6 +88,23 @@ class BasicOperationsSuite extends TestSuiteBase { ) } + test("count") { + testOperation( + Seq(1 to 1, 1 to 2, 1 to 3, 1 to 4), + (s: DStream[Int]) => s.count(), + Seq(Seq(1L), Seq(2L), Seq(3L), Seq(4L)) + ) + } + + test("countByValue") { + testOperation( + Seq(1 to 1, Seq(1, 1, 1), 1 to 2, Seq(1, 1, 2, 2)), + (s: DStream[Int]) => s.countByValue(), + Seq(Seq((1, 1L)), Seq((1, 3L)), Seq((1, 1L), (2, 1L)), Seq((2, 2L), (1, 2L))), + true + ) + } + test("mapValues") { testOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), @@ -206,7 +223,7 @@ class BasicOperationsSuite extends TestSuiteBase { case _ => Option(stateObj) } } - s.map(_ -> 1).updateStateByKey[StateObject](updateFunc).mapValues(_.counter) + s.map(x => (x, 1)).updateStateByKey[StateObject](updateFunc).mapValues(_.counter) } testOperation(inputData, updateStateOperation, outputData, true) diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index e6ac7b35aa..f8380af331 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -236,14 +236,14 @@ class WindowOperationsSuite extends TestSuiteBase { testOperation(input, operation, expectedOutput, numBatches, true) } - test("countByKeyAndWindow") { - val input = Seq(Seq(("a", 1)), Seq(("b", 
1), ("b", 2)), Seq(("a", 10), ("b", 20))) + test("countByValueAndWindow") { + val input = Seq(Seq("a"), Seq("b", "b"), Seq("a", "b")) val expectedOutput = Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 2)), Seq(("a", 1), ("b", 3))) val windowDuration = Seconds(2) val slideDuration = Seconds(1) val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt - val operation = (s: DStream[(String, Int)]) => { - s.countByKeyAndWindow(windowDuration, slideDuration).map(x => (x._1, x._2.toInt)) + val operation = (s: DStream[String]) => { + s.countByValueAndWindow(windowDuration, slideDuration).map(x => (x._1, x._2.toInt)) } testOperation(input, operation, expectedOutput, numBatches, true) } From def8126d7788a8bd991ac6f9f9403de701a39dc5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 14 Feb 2013 17:49:43 -0800 Subject: [PATCH 287/291] Added TwitterInputDStream from example to StreamingContext. Renamed example TwitterBasic to TwitterPopularTags. --- ...erBasic.scala => TwitterPopularTags.scala} | 33 +++++------- project/SparkBuild.scala | 8 ++- .../spark/streaming/StreamingContext.scala | 52 +++++++++++++------ .../dstream}/TwitterInputDStream.scala | 5 +- 4 files changed, 53 insertions(+), 45 deletions(-) rename examples/src/main/scala/spark/streaming/examples/{twitter/TwitterBasic.scala => TwitterPopularTags.scala} (55%) rename {examples/src/main/scala/spark/streaming/examples/twitter => streaming/src/main/scala/spark/streaming/dstream}/TwitterInputDStream.scala (94%) diff --git a/examples/src/main/scala/spark/streaming/examples/twitter/TwitterBasic.scala b/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala similarity index 55% rename from examples/src/main/scala/spark/streaming/examples/twitter/TwitterBasic.scala rename to examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala index 377bc0c98e..fdb3a4c73c 100644 --- a/examples/src/main/scala/spark/streaming/examples/twitter/TwitterBasic.scala +++ b/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala @@ -1,19 +1,19 @@ -package spark.streaming.examples.twitter +package spark.streaming.examples -import spark.streaming.StreamingContext._ import spark.streaming.{Seconds, StreamingContext} +import StreamingContext._ import spark.SparkContext._ -import spark.storage.StorageLevel /** * Calculates popular hashtags (topics) over sliding 10 and 60 second windows from a Twitter * stream. The stream is instantiated with credentials and optionally filters supplied by the * command line arguments. + * */ -object TwitterBasic { +object TwitterPopularTags { def main(args: Array[String]) { if (args.length < 3) { - System.err.println("Usage: TwitterBasic " + + System.err.println("Usage: TwitterPopularTags " + " [filter1] [filter2] ... 
[filter n]") System.exit(1) } @@ -21,10 +21,8 @@ object TwitterBasic { val Array(master, username, password) = args.slice(0, 3) val filters = args.slice(3, args.length) - val ssc = new StreamingContext(master, "TwitterBasic", Seconds(2)) - val stream = new TwitterInputDStream(ssc, username, password, filters, - StorageLevel.MEMORY_ONLY_SER) - ssc.registerInputStream(stream) + val ssc = new StreamingContext(master, "TwitterPopularTags", Seconds(2)) + val stream = ssc.twitterStream(username, password, filters) val hashTags = stream.flatMap(status => status.getText.split(" ").filter(_.startsWith("#"))) @@ -39,22 +37,17 @@ object TwitterBasic { // Print popular hashtags topCounts60.foreach(rdd => { - if (rdd.count() != 0) { - val topList = rdd.take(5) - println("\nPopular topics in last 60 seconds (%s total):".format(rdd.count())) - topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} - } + val topList = rdd.take(5) + println("\nPopular topics in last 60 seconds (%s total):".format(rdd.count())) + topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} }) topCounts10.foreach(rdd => { - if (rdd.count() != 0) { - val topList = rdd.take(5) - println("\nPopular topics in last 10 seconds (%s total):".format(rdd.count())) - topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} - } + val topList = rdd.take(5) + println("\nPopular topics in last 10 seconds (%s total):".format(rdd.count())) + topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} }) ssc.start() } - } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index af8b5ba017..c6d3cc8b15 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -154,10 +154,7 @@ object SparkBuild extends Build { ) def examplesSettings = sharedSettings ++ Seq( - name := "spark-examples", - libraryDependencies ++= Seq( - "org.twitter4j" % "twitter4j-stream" % "3.0.3" - ) + name := "spark-examples" ) def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel") @@ -166,7 +163,8 @@ object SparkBuild extends Build { name := "spark-streaming", libraryDependencies ++= Seq( "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile", - "com.github.sgroschupf" % "zkclient" % "0.1" + "com.github.sgroschupf" % "zkclient" % "0.1", + "org.twitter4j" % "twitter4j-stream" % "3.0.3" ) ) ++ assemblySettings ++ extraAssemblySettings diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 8cfbec51d2..9be9d884be 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -17,6 +17,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.hadoop.fs.Path import java.util.UUID +import twitter4j.Status /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -30,14 +31,14 @@ class StreamingContext private ( ) extends Logging { /** - * Creates a StreamingContext using an existing SparkContext. + * Create a StreamingContext using an existing SparkContext. 
* @param sparkContext Existing SparkContext * @param batchDuration The time interval at which streaming data will be divided into batches */ def this(sparkContext: SparkContext, batchDuration: Duration) = this(sparkContext, null, batchDuration) /** - * Creates a StreamingContext by providing the details necessary for creating a new SparkContext. + * Create a StreamingContext by providing the details necessary for creating a new SparkContext. * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). * @param frameworkName A name for your job, to display on the cluster web UI * @param batchDuration The time interval at which streaming data will be divided into batches @@ -46,7 +47,7 @@ class StreamingContext private ( this(StreamingContext.createNewSparkContext(master, frameworkName), null, batchDuration) /** - * Re-creates a StreamingContext from a checkpoint file. + * Re-create a StreamingContext from a checkpoint file. * @param path Path either to the directory that was specified as the checkpoint directory, or * to the checkpoint file 'graph' or 'graph.bk'. */ @@ -101,12 +102,12 @@ class StreamingContext private ( protected[streaming] var scheduler: Scheduler = null /** - * Returns the associated Spark context + * Return the associated Spark context */ def sparkContext = sc /** - * Sets each DStreams in this context to remember RDDs it generated in the last given duration. + * Set each DStreams in this context to remember RDDs it generated in the last given duration. * DStreams remember RDDs only for a limited duration of time and releases them for garbage * collection. This method allows the developer to specify how to long to remember the RDDs ( * if the developer wishes to query old data outside the DStream computation). @@ -117,7 +118,7 @@ class StreamingContext private ( } /** - * Sets the context to periodically checkpoint the DStream operations for master + * Set the context to periodically checkpoint the DStream operations for master * fault-tolerance. By default, the graph will be checkpointed every batch interval. * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored * @param interval checkpoint interval @@ -200,7 +201,7 @@ class StreamingContext private ( } /** - * Creates a input stream from a Flume source. + * Create a input stream from a Flume source. * @param hostname Hostname of the slave machine to which the flume data will be sent * @param port Port of the slave machine to which the flume data will be sent * @param storageLevel Storage level to use for storing the received objects @@ -236,7 +237,7 @@ class StreamingContext private ( } /** - * Creates a input stream that monitors a Hadoop-compatible filesystem + * Create a input stream that monitors a Hadoop-compatible filesystem * for new files and reads them using the given key-value types and input format. * File names starting with . are ignored. * @param directory HDFS directory to monitor for new file @@ -255,7 +256,7 @@ class StreamingContext private ( } /** - * Creates a input stream that monitors a Hadoop-compatible filesystem + * Create a input stream that monitors a Hadoop-compatible filesystem * for new files and reads them using the given key-value types and input format. 
* @param directory HDFS directory to monitor for new file * @param filter Function to filter paths to process @@ -274,9 +275,8 @@ class StreamingContext private ( inputStream } - /** - * Creates a input stream that monitors a Hadoop-compatible filesystem + * Create a input stream that monitors a Hadoop-compatible filesystem * for new files and reads them as text files (using key as LongWritable, value * as Text and input format as TextInputFormat). File names starting with . are ignored. * @param directory HDFS directory to monitor for new file @@ -286,7 +286,25 @@ class StreamingContext private ( } /** - * Creates an input stream from a queue of RDDs. In each batch, + * Create a input stream that returns tweets received from Twitter. + * @param username Twitter username + * @param password Twitter password + * @param filters Set of filter strings to get only those tweets that match them + * @param storageLevel Storage level to use for storing the received objects + */ + def twitterStream( + username: String, + password: String, + filters: Seq[String], + storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 + ): DStream[Status] = { + val inputStream = new TwitterInputDStream(this, username, password, filters, storageLevel) + registerInputStream(inputStream) + inputStream + } + + /** + * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval @@ -300,7 +318,7 @@ class StreamingContext private ( } /** - * Creates an input stream from a queue of RDDs. In each batch, + * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval @@ -325,7 +343,7 @@ class StreamingContext private ( } /** - * Registers an input stream that will be started (InputDStream.start() called) to get the + * Register an input stream that will be started (InputDStream.start() called) to get the * input data. */ def registerInputStream(inputStream: InputDStream[_]) { @@ -333,7 +351,7 @@ class StreamingContext private ( } /** - * Registers an output stream that will be computed every interval + * Register an output stream that will be computed every interval */ def registerOutputStream(outputStream: DStream[_]) { graph.addOutputStream(outputStream) @@ -351,7 +369,7 @@ class StreamingContext private ( } /** - * Starts the execution of the streams. + * Start the execution of the streams. */ def start() { if (checkpointDir != null && checkpointDuration == null && graph != null) { @@ -379,7 +397,7 @@ class StreamingContext private ( } /** - * Stops the execution of the streams. + * Stop the execution of the streams. 
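// For illustration only: a hypothetical use of the twitterStream method introduced
// above, assuming master, username and password are defined (filters may be empty):
//
//   val ssc = new StreamingContext(master, "TwitterPopularTags", Seconds(2))
//   val tweets = ssc.twitterStream(username, password, Seq("#spark"))
//   val hashTags = tweets.flatMap(status => status.getText.split(" ").filter(_.startsWith("#")))
//
// This replaces constructing a TwitterInputDStream directly and registering it with
// registerInputStream, as done previously in the TwitterPopularTags example.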
*/ def stop() { try { diff --git a/examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala similarity index 94% rename from examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala rename to streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala index 99ed4cdc1c..d733254ddb 100644 --- a/examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala @@ -1,4 +1,4 @@ -package spark.streaming.examples.twitter +package spark.streaming.dstream import spark._ import spark.streaming._ @@ -6,7 +6,6 @@ import dstream.{NetworkReceiver, NetworkInputDStream} import storage.StorageLevel import twitter4j._ import twitter4j.auth.BasicAuthorization -import collection.JavaConversions._ /* A stream of Twitter statuses, potentially filtered by one or more keywords. * @@ -50,7 +49,7 @@ class TwitterReceiver( def onTrackLimitationNotice(i: Int) {} def onScrubGeo(l: Long, l1: Long) {} def onStallWarning(stallWarning: StallWarning) {} - def onException(e: Exception) {} + def onException(e: Exception) { stopOnError(e) } }) val query: FilterQuery = new FilterQuery From 4b8402e900c803e64b8a4e2094fd845ccfc9df36 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 14 Feb 2013 18:10:37 -0800 Subject: [PATCH 288/291] Moved Java streaming examples to examples/src/main/java/spark/streaming/... and fixed logging in NetworkInputTracker to highlight errors when receiver deregisters/shuts down. --- .../spark/streaming/examples/JavaFlumeEventCount.java | 0 .../spark/streaming/examples/JavaNetworkWordCount.java | 0 .../spark/streaming/examples/JavaQueueStream.java | 0 .../src/main/scala/spark/streaming/NetworkInputTracker.scala | 2 +- 4 files changed, 1 insertion(+), 1 deletion(-) rename examples/src/main/{scala => java}/spark/streaming/examples/JavaFlumeEventCount.java (100%) rename examples/src/main/{scala => java}/spark/streaming/examples/JavaNetworkWordCount.java (100%) rename examples/src/main/{scala => java}/spark/streaming/examples/JavaQueueStream.java (100%) diff --git a/examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java b/examples/src/main/java/spark/streaming/examples/JavaFlumeEventCount.java similarity index 100% rename from examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java rename to examples/src/main/java/spark/streaming/examples/JavaFlumeEventCount.java diff --git a/examples/src/main/scala/spark/streaming/examples/JavaNetworkWordCount.java b/examples/src/main/java/spark/streaming/examples/JavaNetworkWordCount.java similarity index 100% rename from examples/src/main/scala/spark/streaming/examples/JavaNetworkWordCount.java rename to examples/src/main/java/spark/streaming/examples/JavaNetworkWordCount.java diff --git a/examples/src/main/scala/spark/streaming/examples/JavaQueueStream.java b/examples/src/main/java/spark/streaming/examples/JavaQueueStream.java similarity index 100% rename from examples/src/main/scala/spark/streaming/examples/JavaQueueStream.java rename to examples/src/main/java/spark/streaming/examples/JavaQueueStream.java diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index b54f53b203..ca5f11fdba 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ 
b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -86,7 +86,7 @@ class NetworkInputTracker( } case DeregisterReceiver(streamId, msg) => { receiverInfo -= streamId - logInfo("De-registered receiver for network stream " + streamId + logError("De-registered receiver for network stream " + streamId + " with message " + msg) //TODO: Do something about the corresponding NetworkInputDStream } From ddcb976b0d7ce4a76168da33c0e947a5a6b5a255 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 15 Feb 2013 06:54:47 +0000 Subject: [PATCH 289/291] Made MasterFailureTest more robust. --- .../streaming/util/MasterFailureTest.scala | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala index 83d8591a3a..776e676063 100644 --- a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala +++ b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala @@ -40,6 +40,8 @@ object MasterFailureTest extends Logging { println("\n\n================= UPDATE-STATE-BY-KEY TEST =================\n\n") testUpdateStateByKey(directory, numBatches, batchDuration) + + println("\n\nSUCCESS\n\n") } def testMap(directory: String, numBatches: Int, batchDuration: Duration) { @@ -347,7 +349,8 @@ class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long) override def run() { val localTestDir = Files.createTempDir() - val fs = testDir.getFileSystem(new Configuration()) + var fs = testDir.getFileSystem(new Configuration()) + val maxTries = 3 try { Thread.sleep(5000) // To make sure that all the streaming context has been set up for (i <- 0 until input.size) { @@ -355,9 +358,24 @@ class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long) val localFile = new File(localTestDir, (i+1).toString) val hadoopFile = new Path(testDir, (i+1).toString) FileUtils.writeStringToFile(localFile, input(i).toString + "\n") - //fs.moveFromLocalFile(new Path(localFile.toString), new Path(testDir, i.toString)) - fs.copyFromLocalFile(new Path(localFile.toString), hadoopFile) - logInfo("Generated file " + hadoopFile + " at " + System.currentTimeMillis) + var tries = 0 + var done = false + while (!done && tries < maxTries) { + tries += 1 + try { + fs.copyFromLocalFile(new Path(localFile.toString), hadoopFile) + done = true + } catch { + case ioe: IOException => { + fs = testDir.getFileSystem(new Configuration()) + logWarning("Attempt " + tries + " at generating file " + hadoopFile + " failed.", ioe) + } + } + } + if (!done) + logError("Could not generate file " + hadoopFile) + else + logInfo("Generated file " + hadoopFile + " at " + System.currentTimeMillis) Thread.sleep(interval) localFile.delete() } From f98c7da23ef66812b8b4888230ee98c07f09af23 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 17 Feb 2013 15:06:41 -0800 Subject: [PATCH 290/291] Many changes to ensure better 2nd recovery if 2nd failure happens while recovering from 1st failure - Made the scheduler to checkpoint after clearing old metadata which ensures that a new checkpoint is written as soon as at least one batch gets computed while recovering from a failure. This ensures that if there is a 2nd failure while recovering from 1st failure, the system start 2nd recovery from a newer checkpoint. - Modified Checkpoint writer to write checkpoint in a different thread. 
- Added a check to make sure that compute for InputDStreams gets called only for strictly increasing times. - Changed implementation of slice to call getOrCompute on parent DStream in time-increasing order. - Added testcase to test slice. - Fixed testGroupByKeyAndWindow testcase in JavaAPISuite to verify results with expected output in an order-independent manner. --- .../scala/spark/streaming/Checkpoint.scala | 71 ++++++++++++------- .../main/scala/spark/streaming/DStream.scala | 27 ++++--- .../scala/spark/streaming/DStreamGraph.scala | 13 +++- .../scala/spark/streaming/JobManager.scala | 8 ++- .../scala/spark/streaming/Scheduler.scala | 27 ++++--- .../spark/streaming/StreamingContext.scala | 7 +- .../src/main/scala/spark/streaming/Time.scala | 11 ++- .../api/java/JavaStreamingContext.scala | 7 +- .../streaming/dstream/InputDStream.scala | 36 +++++++++- .../dstream/TwitterInputDStream.scala | 4 +- .../streaming/util/MasterFailureTest.scala | 2 +- .../java/spark/streaming/JavaAPISuite.java | 54 +++++++++----- .../java/spark/streaming/JavaTestUtils.scala | 1 + streaming/src/test/resources/log4j.properties | 4 +- .../streaming/BasicOperationsSuite.scala | 20 ++++++ .../spark/streaming/CheckpointSuite.scala | 5 +- .../scala/spark/streaming/TestSuiteBase.scala | 7 +- .../streaming/WindowOperationsSuite.scala | 5 +- 18 files changed, 210 insertions(+), 99 deletions(-) diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index b9eb7f8ec4..7405c8b22e 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -6,6 +6,8 @@ import org.apache.hadoop.fs.{FileUtil, Path} import org.apache.hadoop.conf.Configuration import java.io._ +import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} +import java.util.concurrent.Executors private[streaming] @@ -38,32 +40,50 @@ class CheckpointWriter(checkpointDir: String) extends Logging { val conf = new Configuration() var fs = file.getFileSystem(conf) val maxAttempts = 3 + val executor = Executors.newFixedThreadPool(1) + + class CheckpointWriteHandler(checkpointTime: Time, bytes: Array[Byte]) extends Runnable { + def run() { + var attempts = 0 + val startTime = System.currentTimeMillis() + while (attempts < maxAttempts) { + attempts += 1 + try { + logDebug("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'") + if (fs.exists(file)) { + val bkFile = new Path(file.getParent, file.getName + ".bk") + FileUtil.copy(fs, file, fs, bkFile, true, true, conf) + logDebug("Moved existing checkpoint file to " + bkFile) + } + val fos = fs.create(file) + fos.write(bytes) + fos.close() + fos.close() + val finishTime = System.currentTimeMillis(); + logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + file + + "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " milliseconds") + return + } catch { + case ioe: IOException => + logWarning("Error writing checkpoint to file in " + attempts + " attempts", ioe) + } + } + logError("Could not write checkpoint for time " + checkpointTime + " to file '" + file + "'") + } + } def write(checkpoint: Checkpoint) { - // TODO: maybe do this in a different thread from the main stream execution thread - var attempts = 0 - while (attempts < maxAttempts) { - attempts += 1 - try { - logDebug("Saving checkpoint for time " + checkpoint.checkpointTime + " to file '" + file + "'") - if (fs.exists(file)) { - val bkFile = new 
Path(file.getParent, file.getName + ".bk") - FileUtil.copy(fs, file, fs, bkFile, true, true, conf) - logDebug("Moved existing checkpoint file to " + bkFile) - } - val fos = fs.create(file) - val oos = new ObjectOutputStream(fos) - oos.writeObject(checkpoint) - oos.close() - logInfo("Checkpoint for time " + checkpoint.checkpointTime + " saved to file '" + file + "'") - fos.close() - return - } catch { - case ioe: IOException => - logWarning("Error writing checkpoint to file in " + attempts + " attempts", ioe) - } - } - logError("Could not write checkpoint for time " + checkpoint.checkpointTime + " to file '" + file + "'") + val bos = new ByteArrayOutputStream() + val zos = new LZFOutputStream(bos) + val oos = new ObjectOutputStream(zos) + oos.writeObject(checkpoint) + oos.close() + bos.close() + executor.execute(new CheckpointWriteHandler(checkpoint.checkpointTime, bos.toByteArray)) + } + + def stop() { + executor.shutdown() } } @@ -85,7 +105,8 @@ object CheckpointReader extends Logging { // of ObjectInputStream is used to explicitly use the current thread's default class // loader to find and load classes. This is a well know Java issue and has popped up // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) - val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader) + val zis = new LZFInputStream(fis) + val ois = new ObjectInputStreamWithLoader(zis, Thread.currentThread().getContextClassLoader) val cp = ois.readObject.asInstanceOf[Checkpoint] ois.close() fs.close() diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index ce42b742d7..84e4b5bedb 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -238,13 +238,15 @@ abstract class DStream[T: ClassManifest] ( dependencies.foreach(_.remember(parentRememberDuration)) } - /** This method checks whether the 'time' is valid wrt slideDuration for generating RDD */ + /** Checks whether the 'time' is valid wrt slideDuration for generating RDD */ protected def isTimeValid(time: Time): Boolean = { if (!isInitialized) { throw new Exception (this + " has not been initialized") } else if (time <= zeroTime || ! 
(time - zeroTime).isMultipleOf(slideDuration)) { + logInfo("Time " + time + " is invalid as zeroTime is " + zeroTime + " and slideDuration is " + slideDuration + " and difference is " + (time - zeroTime)) false } else { + logInfo("Time " + time + " is valid") true } } @@ -627,16 +629,21 @@ abstract class DStream[T: ClassManifest] ( * Return all the RDDs between 'fromTime' to 'toTime' (both included) */ def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = { - val rdds = new ArrayBuffer[RDD[T]]() - var time = toTime.floor(slideDuration) - while (time >= zeroTime && time >= fromTime) { - getOrCompute(time) match { - case Some(rdd) => rdds += rdd - case None => //throw new Exception("Could not get RDD for time " + time) - } - time -= slideDuration + if (!(fromTime - zeroTime).isMultipleOf(slideDuration)) { + logWarning("fromTime (" + fromTime + ") is not a multiple of slideDuration (" + slideDuration + ")") } - rdds.toSeq + if (!(toTime - zeroTime).isMultipleOf(slideDuration)) { + logWarning("toTime (" + fromTime + ") is not a multiple of slideDuration (" + slideDuration + ")") + } + val alignedToTime = toTime.floor(slideDuration) + val alignedFromTime = fromTime.floor(slideDuration) + + logInfo("Slicing from " + fromTime + " to " + toTime + + " (aligned to " + alignedFromTime + " and " + alignedToTime + ")") + + alignedFromTime.to(alignedToTime, slideDuration).flatMap(time => { + if (time >= zeroTime) getOrCompute(time) else None + }) } /** diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index 22d9e24f05..adb7f3a24d 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -86,10 +86,12 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { def getOutputStreams() = this.synchronized { outputStreams.toArray } - def generateRDDs(time: Time): Seq[Job] = { + def generateJobs(time: Time): Seq[Job] = { this.synchronized { - logInfo("Generating RDDs for time " + time) - outputStreams.flatMap(outputStream => outputStream.generateJob(time)) + logInfo("Generating jobs for time " + time) + val jobs = outputStreams.flatMap(outputStream => outputStream.generateJob(time)) + logInfo("Generated " + jobs.length + " jobs for time " + time) + jobs } } @@ -97,18 +99,23 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { this.synchronized { logInfo("Clearing old metadata for time " + time) outputStreams.foreach(_.clearOldMetadata(time)) + logInfo("Cleared old metadata for time " + time) } } def updateCheckpointData(time: Time) { this.synchronized { + logInfo("Updating checkpoint data for time " + time) outputStreams.foreach(_.updateCheckpointData(time)) + logInfo("Updated checkpoint data for time " + time) } } def restoreCheckpointData() { this.synchronized { + logInfo("Restoring checkpoint data") outputStreams.foreach(_.restoreCheckpointData()) + logInfo("Restored checkpoint data") } } diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index 649494ff4a..7696c4a592 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -43,20 +43,24 @@ class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { } private def clearJob(job: Job) { + var timeCleared = false + val time = job.time jobs.synchronized { - val time 
= job.time val jobsOfTime = jobs.get(time) if (jobsOfTime.isDefined) { jobsOfTime.get -= job if (jobsOfTime.get.isEmpty) { - ssc.scheduler.clearOldMetadata(time) jobs -= time + timeCleared = true } } else { throw new Exception("Job finished for time " + job.time + " but time does not exist in jobs") } } + if (timeCleared) { + ssc.scheduler.clearOldMetadata(time) + } } def getPendingTimes(): Array[Time] = { diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index 57d494da83..1c4b22a898 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -20,8 +20,9 @@ class Scheduler(ssc: StreamingContext) extends Logging { val clockClass = System.getProperty("spark.streaming.clock", "spark.streaming.util.SystemClock") val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock] val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, - longTime => generateRDDs(new Time(longTime))) + longTime => generateJobs(new Time(longTime))) val graph = ssc.graph + var latestTime: Time = null def start() = synchronized { if (ssc.isCheckpointPresent) { @@ -35,6 +36,7 @@ class Scheduler(ssc: StreamingContext) extends Logging { def stop() = synchronized { timer.stop() jobManager.stop() + if (checkpointWriter != null) checkpointWriter.stop() ssc.graph.stop() logInfo("Scheduler stopped") } @@ -73,35 +75,38 @@ class Scheduler(ssc: StreamingContext) extends Logging { val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering) logInfo("Batches to reschedule: " + timesToReschedule.mkString(", ")) timesToReschedule.foreach(time => - graph.generateRDDs(time).foreach(jobManager.runJob) + graph.generateJobs(time).foreach(jobManager.runJob) ) // Restart the timer timer.start(restartTime.milliseconds) - logInfo("Scheduler's timer restarted") + logInfo("Scheduler's timer restarted at " + restartTime) } - /** Generates the RDDs, clears old metadata and does checkpoint for the given time */ - def generateRDDs(time: Time) { + /** Generate jobs and perform checkpoint for the given `time`. */ + def generateJobs(time: Time) { SparkEnv.set(ssc.env) logInfo("\n-----------------------------------------------------\n") - graph.generateRDDs(time).foreach(jobManager.runJob) + graph.generateJobs(time).foreach(jobManager.runJob) + latestTime = time doCheckpoint(time) } - + /** + * Clear old metadata assuming jobs of `time` have finished processing. + * And also perform checkpoint. + */ def clearOldMetadata(time: Time) { ssc.graph.clearOldMetadata(time) + doCheckpoint(time) } - def doCheckpoint(time: Time) { + /** Perform checkpoint for the give `time`. 
*/ + def doCheckpoint(time: Time) = synchronized { if (ssc.checkpointDuration != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) { logInfo("Checkpointing graph for time " + time) - val startTime = System.currentTimeMillis() ssc.graph.updateCheckpointData(time) checkpointWriter.write(new Checkpoint(ssc, time)) - val stopTime = System.currentTimeMillis() - logInfo("Checkpointing the graph took " + (stopTime - startTime) + " ms") } } } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 9be9d884be..d1407b7869 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -119,18 +119,15 @@ class StreamingContext private ( /** * Set the context to periodically checkpoint the DStream operations for master - * fault-tolerance. By default, the graph will be checkpointed every batch interval. + * fault-tolerance. The graph will be checkpointed every batch interval. * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored - * @param interval checkpoint interval */ - def checkpoint(directory: String, interval: Duration = null) { + def checkpoint(directory: String) { if (directory != null) { sc.setCheckpointDir(StreamingContext.getSparkCheckpointDir(directory)) checkpointDir = directory - checkpointDuration = interval } else { checkpointDir = null - checkpointDuration = null } } diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index 8201e84a20..f14decf08b 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -38,15 +38,14 @@ case class Time(private val millis: Long) { def max(that: Time): Time = if (this > that) this else that def until(that: Time, interval: Duration): Seq[Time] = { - assert(that > this, "Cannot create sequence as " + that + " not more than " + this) - assert( - (that - this).isMultipleOf(interval), - "Cannot create sequence as gap between " + that + " and " + - this + " is not multiple of " + interval - ) (this.milliseconds) until (that.milliseconds) by (interval.milliseconds) map (new Time(_)) } + def to(that: Time, interval: Duration): Seq[Time] = { + (this.milliseconds) to (that.milliseconds) by (interval.milliseconds) map (new Time(_)) + } + + override def toString: String = (millis.toString + " ms") } diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index 5bbf2b084f..03933aae93 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -314,12 +314,11 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Sets the context to periodically checkpoint the DStream operations for master - * fault-tolerance. By default, the graph will be checkpointed every batch interval. + * fault-tolerance. The graph will be checkpointed every batch interval. 
* @param directory HDFS-compatible directory where the checkpoint data will be reliably stored - * @param interval checkpoint interval */ - def checkpoint(directory: String, interval: Duration = null) { - ssc.checkpoint(directory, interval) + def checkpoint(directory: String) { + ssc.checkpoint(directory) } /** diff --git a/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala index 980ca5177e..a4db44a608 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala @@ -1,10 +1,42 @@ package spark.streaming.dstream -import spark.streaming.{Duration, StreamingContext, DStream} +import spark.streaming.{Time, Duration, StreamingContext, DStream} +/** + * This is the abstract base class for all input streams. This class provides to methods + * start() and stop() which called by the scheduler to start and stop receiving data/ + * Input streams that can generated RDDs from new data just by running a service on + * the driver node (that is, without running a receiver onworker nodes) can be + * implemented by directly subclassing this InputDStream. For example, + * FileInputDStream, a subclass of InputDStream, monitors a HDFS directory for + * new files and generates RDDs on the new files. For implementing input streams + * that requires running a receiver on the worker nodes, use NetworkInputDStream + * as the parent class. + */ abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext) extends DStream[T](ssc_) { + var lastValidTime: Time = null + + /** + * Checks whether the 'time' is valid wrt slideDuration for generating RDD. + * Additionally it also ensures valid times are in strictly increasing order. + * This ensures that InputDStream.compute() is called strictly on increasing + * times. + */ + override protected def isTimeValid(time: Time): Boolean = { + if (!super.isTimeValid(time)) { + false // Time not valid + } else { + // Time is valid, but check it it is more than lastValidTime + if (lastValidTime == null || lastValidTime <= time) { + logWarning("isTimeValid called with " + time + " where as last valid time is " + lastValidTime) + } + lastValidTime = time + true + } + } + override def dependencies = List() override def slideDuration: Duration = { @@ -13,7 +45,9 @@ abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContex ssc.graph.batchDuration } + /** Method called to start receiving data. Subclasses must implement this method. */ def start() + /** Method called to stop receiving data. Subclasses must implement this method. 
*/ def stop() } diff --git a/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala index d733254ddb..e70822e5c3 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala @@ -2,8 +2,8 @@ package spark.streaming.dstream import spark._ import spark.streaming._ -import dstream.{NetworkReceiver, NetworkInputDStream} import storage.StorageLevel + import twitter4j._ import twitter4j.auth.BasicAuthorization @@ -19,7 +19,7 @@ class TwitterInputDStream( password: String, filters: Seq[String], storageLevel: StorageLevel - ) extends NetworkInputDStream[Status](ssc_) { + ) extends NetworkInputDStream[Status](ssc_) { override def createReceiver(): NetworkReceiver[Status] = { new TwitterReceiver(username, password, filters, storageLevel) diff --git a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala index 776e676063..bdd9f4d753 100644 --- a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala +++ b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala @@ -315,7 +315,7 @@ class KillingThread(ssc: StreamingContext, maxKillWaitTime: Long) extends Thread override def run() { try { // If it is the first killing, then allow the first checkpoint to be created - var minKillWaitTime = if (MasterFailureTest.killCount == 0) 5000 else 1000 + var minKillWaitTime = if (MasterFailureTest.killCount == 0) 5000 else 2000 val killWaitTime = minKillWaitTime + math.abs(Random.nextLong % maxKillWaitTime) logInfo("Kill wait time = " + killWaitTime) Thread.sleep(killWaitTime) diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java index 7bea0b1fc4..16bacffb92 100644 --- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java @@ -23,6 +23,7 @@ import spark.streaming.JavaCheckpointTestUtils; import spark.streaming.dstream.KafkaPartitionKey; import java.io.*; +import java.text.Collator; import java.util.*; // The test suite itself is Serializable so that anonymous Function implementations can be @@ -35,7 +36,7 @@ public class JavaAPISuite implements Serializable { public void setUp() { System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock"); ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); - ssc.checkpoint("checkpoint", new Duration(1000)); + ssc.checkpoint("checkpoint"); } @After @@ -587,26 +588,47 @@ public class JavaAPISuite implements Serializable { @Test public void testGroupByKeyAndWindow() { - List>> inputData = stringStringKVStream; + List>> inputData = stringIntKVStream; - List>>> expected = Arrays.asList( - Arrays.asList(new Tuple2>("california", Arrays.asList("dodgers", "giants")), - new Tuple2>("new york", Arrays.asList("yankees", "mets"))), - Arrays.asList(new Tuple2>("california", - Arrays.asList("sharks", "ducks", "dodgers", "giants")), - new Tuple2>("new york", Arrays.asList("rangers", "islanders", "yankees", "mets"))), - Arrays.asList(new Tuple2>("california", Arrays.asList("sharks", "ducks")), - new Tuple2>("new york", Arrays.asList("rangers", "islanders")))); + List>>> expected = Arrays.asList( + Arrays.asList( + new Tuple2>("california", Arrays.asList(1, 3)), + new Tuple2>("new york", 
Arrays.asList(1, 4)) + ), + Arrays.asList( + new Tuple2>("california", Arrays.asList(1, 3, 5, 5)), + new Tuple2>("new york", Arrays.asList(1, 1, 3, 4)) + ), + Arrays.asList( + new Tuple2>("california", Arrays.asList(5, 5)), + new Tuple2>("new york", Arrays.asList(1, 3)) + ) + ); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream> groupWindowed = + JavaPairDStream> groupWindowed = pairStream.groupByKeyAndWindow(new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(groupWindowed); - List>>> result = JavaTestUtils.runStreams(ssc, 3, 3); + List>>> result = JavaTestUtils.runStreams(ssc, 3, 3); - Assert.assertEquals(expected, result); + assert(result.size() == expected.size()); + for (int i = 0; i < result.size(); i++) { + assert(convert(result.get(i)).equals(convert(expected.get(i)))); + } + } + + private HashSet>> convert(List>> listOfTuples) { + List>> newListOfTuples = new ArrayList>>(); + for (Tuple2> tuple: listOfTuples) { + newListOfTuples.add(convert(tuple)); + } + return new HashSet>>(newListOfTuples); + } + + private Tuple2> convert(Tuple2> tuple) { + return new Tuple2>(tuple._1(), new HashSet(tuple._2())); } @Test @@ -894,7 +916,7 @@ public class JavaAPISuite implements Serializable { Arrays.asList(8,7)); File tempDir = Files.createTempDir(); - ssc.checkpoint(tempDir.getAbsolutePath(), new Duration(1000)); + ssc.checkpoint(tempDir.getAbsolutePath()); JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { diff --git a/streaming/src/test/java/spark/streaming/JavaTestUtils.scala b/streaming/src/test/java/spark/streaming/JavaTestUtils.scala index 56349837e5..52ea28732a 100644 --- a/streaming/src/test/java/spark/streaming/JavaTestUtils.scala +++ b/streaming/src/test/java/spark/streaming/JavaTestUtils.scala @@ -57,6 +57,7 @@ trait JavaTestBase extends TestSuiteBase { } object JavaTestUtils extends JavaTestBase { + override def maxWaitTimeMillis = 20000 } diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index f0638e0e02..59c445e63f 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -1,5 +1,5 @@ # Set everything to be logged to the file streaming/target/unit-tests.log -log4j.rootCategory=WARN, file +log4j.rootCategory=INFO, file # log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false @@ -9,6 +9,4 @@ log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.eclipse.jetty=WARN -log4j.logger.spark.streaming=INFO -log4j.logger.spark.streaming.dstream.FileInputDStream=DEBUG diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala index 1e86cf49bb..8fce91853c 100644 --- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala @@ -229,6 +229,26 @@ class BasicOperationsSuite extends TestSuiteBase { testOperation(inputData, 
updateStateOperation, outputData, true) } + test("slice") { + val ssc = new StreamingContext("local[2]", "BasicOperationSuite", Seconds(1)) + val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4)) + val stream = new TestInputStream[Int](ssc, input, 2) + ssc.registerInputStream(stream) + stream.foreach(_ => {}) // Dummy output stream + ssc.start() + Thread.sleep(2000) + def getInputFromSlice(fromMillis: Long, toMillis: Long) = { + stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet + } + + assert(getInputFromSlice(0, 1000) == Set(1)) + assert(getInputFromSlice(0, 2000) == Set(1, 2)) + assert(getInputFromSlice(1000, 2000) == Set(1, 2)) + assert(getInputFromSlice(2000, 4000) == Set(2, 3, 4)) + ssc.stop() + Thread.sleep(1000) + } + test("forgetting of RDDs - map and window operations") { assert(batchDuration === Seconds(1), "Batch duration has changed from 1 second") diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index c89c4a8d43..5250667bcb 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -39,14 +39,11 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { override def batchDuration = Milliseconds(500) - override def checkpointInterval = batchDuration - override def actuallyWait = true test("basic rdd checkpoints + dstream graph checkpoint recovery") { assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second") - assert(checkpointInterval === batchDuration, "checkpointInterval for this test much be same as batchDuration") System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") @@ -188,7 +185,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // Set up the streaming context and input streams val testDir = Files.createTempDir() var ssc = new StreamingContext(master, framework, Seconds(1)) - ssc.checkpoint(checkpointDir, checkpointInterval) + ssc.checkpoint(checkpointDir) val fileStream = ssc.textFileStream(testDir.toString) // Making value 3 take large time to process, to ensure that the master // shuts down in the middle of processing the 3rd batch diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index 2cc31d6137..ad6aa79d10 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -75,9 +75,6 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Directory where the checkpoint data will be saved def checkpointDir = "checkpoint" - // Duration after which the graph is checkpointed - def checkpointInterval = batchDuration - // Number of partitions of the input parallel collections created for testing def numInputPartitions = 2 @@ -99,7 +96,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Create StreamingContext val ssc = new StreamingContext(master, framework, batchDuration) if (checkpointDir != null) { - ssc.checkpoint(checkpointDir, checkpointInterval) + ssc.checkpoint(checkpointDir) } // Setup the stream computation @@ -124,7 +121,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Create StreamingContext val ssc = new StreamingContext(master, framework, batchDuration) if (checkpointDir != null) { - ssc.checkpoint(checkpointDir, 
checkpointInterval) + ssc.checkpoint(checkpointDir) } // Setup the stream computation diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index f8380af331..1b66f3bda2 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -273,6 +273,7 @@ class WindowOperationsSuite extends TestSuiteBase { slideDuration: Duration = Seconds(1) ) { test("reduceByKeyAndWindow - " + name) { + logInfo("reduceByKeyAndWindow - " + name) val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt val operation = (s: DStream[(String, Int)]) => { s.reduceByKeyAndWindow((x: Int, y: Int) => x + y, windowDuration, slideDuration) @@ -288,7 +289,8 @@ class WindowOperationsSuite extends TestSuiteBase { windowDuration: Duration = Seconds(2), slideDuration: Duration = Seconds(1) ) { - test("ReduceByKeyAndWindow with inverse function - " + name) { + test("reduceByKeyAndWindow with inverse function - " + name) { + logInfo("reduceByKeyAndWindow with inverse function - " + name) val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt val operation = (s: DStream[(String, Int)]) => { s.reduceByKeyAndWindow(_ + _, _ - _, windowDuration, slideDuration) @@ -306,6 +308,7 @@ class WindowOperationsSuite extends TestSuiteBase { slideDuration: Duration = Seconds(1) ) { test("reduceByKeyAndWindow with inverse and filter functions - " + name) { + logInfo("reduceByKeyAndWindow with inverse and filter functions - " + name) val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt val filterFunc = (p: (String, Int)) => p._2 != 0 val operation = (s: DStream[(String, Int)]) => { From 8ad561dc7d6475d7b217ec3f57bac3b584fed31a Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 18 Feb 2013 02:12:41 -0800 Subject: [PATCH 291/291] Added checkpointing and fault-tolerance semantics to the programming guide. Fixed default checkpoint interval to being a multiple of slide duration. Fixed visibility of some classes and objects to clean up docs. --- docs/python-programming-guide.md | 2 +- docs/streaming-programming-guide.md | 252 ++++++++++++++---- .../main/scala/spark/streaming/DStream.scala | 2 +- .../main/scala/spark/streaming/Duration.scala | 2 +- .../main/scala/spark/streaming/Interval.scala | 1 + .../streaming/PairDStreamFunctions.scala | 8 +- .../dstream/TwitterInputDStream.scala | 2 + .../spark/streaming/CheckpointSuite.scala | 2 +- 8 files changed, 209 insertions(+), 62 deletions(-) diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index 4e84d23edf..2012241a6a 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -87,7 +87,7 @@ By default, the `pyspark` shell creates SparkContext that runs jobs locally. To connect to a non-local cluster, set the `MASTER` environment variable. 
For example, to use the `pyspark` shell with a [standalone Spark cluster](spark-standalone.html): -{% highlight shell %} +{% highlight bash %} $ MASTER=spark://IP:PORT ./pyspark {% endhighlight %} diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index b6da7af654..d408e80359 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -34,8 +34,8 @@ The StreamingContext is used to creating InputDStreams from input sources: {% highlight scala %} // Assuming ssc is the StreamingContext -ssc.networkStream(hostname, port) // Creates a stream that uses a TCP socket to read data from hostname:port -ssc.textFileStream(directory) // Creates a stream by monitoring and processing new files in a HDFS directory +ssc.networkStream(hostname, port) // Creates a stream that uses a TCP socket to read data from hostname:port +ssc.textFileStream(directory) // Creates a stream by monitoring and processing new files in a HDFS directory {% endhighlight %} A complete list of input sources is available in the [StreamingContext API documentation](api/streaming/index.html#spark.streaming.StreamingContext). Data received from these sources can be processed using DStream operations, which are explained next. @@ -50,18 +50,18 @@ Once an input DStream has been created, you can transform it using _DStream oper DStreams support many of the transformations available on normal Spark RDD's: - + - + - + - + @@ -70,73 +70,92 @@ DStreams support many of the transformations available on normal Spark RDD's: - + + + + + + + + + + + + + - - + - + - - - - - + + + + + +
 Transformation | Meaning
-map(func) | Returns a new DStream formed by passing each element of the source through a function func.
+map(func) | Returns a new DStream formed by passing each element of the source DStream through a function func.
-filter(func) | Returns a new stream formed by selecting those elements of the source on which func returns true.
+filter(func) | Returns a new DStream formed by selecting those elements of the source DStream on which func returns true.
 flatMap(func) | Similar to map, but each input item can be mapped to 0 or more output items (so func should return a Seq rather than a single item).
 mapPartitions(func) | Similar to map, but runs separately on each partition (block) of the DStream.
-union(otherStream) | Return a new stream that contains the union of the elements in the source stream and the argument.
+union(otherStream) | Return a new DStream that contains the union of the elements in the source DStream and the argument DStream.
+count() | Returns a new DStream of single-element RDDs by counting the number of elements in each RDD of the source DStream.
+reduce(func) | Returns a new DStream of single-element RDDs by aggregating the elements in each RDD of the source DStream using a function func (which takes two arguments and returns one). The function should be associative so that it can be computed in parallel.
+countByValue() | When called on a DStream of elements of type K, returns a new DStream of (K, Long) pairs where the value of each key is its frequency in each RDD of the source DStream.
-groupByKey([numTasks]) | When called on a stream of (K, V) pairs, returns a stream of (K, Seq[V]) pairs. Note: By default, this uses only 8 parallel tasks to do the grouping. You can pass an optional numTasks argument to set a different number of tasks.
+groupByKey([numTasks]) | When called on a DStream of (K, V) pairs, returns a new DStream of (K, Seq[V]) pairs by grouping together all the values of each key in the RDDs of the source DStream. Note: By default, this uses Spark's default number of parallel tasks (2 for local machine, 8 for a cluster) to do the grouping. You can pass an optional numTasks argument to set a different number of tasks.
-reduceByKey(func, [numTasks]) | When called on a stream of (K, V) pairs, returns a stream of (K, V) pairs where the values for each key are aggregated using the given reduce function. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument.
+reduceByKey(func, [numTasks]) | When called on a DStream of (K, V) pairs, returns a new DStream of (K, V) pairs where the values for each key are aggregated using the given reduce function. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument.
-join(otherStream, [numTasks]) | When called on streams of type (K, V) and (K, W), returns a stream of (K, (V, W)) pairs with all pairs of elements for each key.
+join(otherStream, [numTasks]) | When called on two DStreams of (K, V) and (K, W) pairs, returns a new DStream of (K, (V, W)) pairs with all pairs of elements for each key.
-cogroup(otherStream, [numTasks]) | When called on DStream of type (K, V) and (K, W), returns a DStream of (K, Seq[V], Seq[W]) tuples.
+cogroup(otherStream, [numTasks]) | When called on DStream of (K, V) and (K, W) pairs, returns a new DStream of (K, Seq[V], Seq[W]) tuples.
-reduce(func) | Returns a new DStream of single-element RDDs by aggregating the elements of the stream using a function func (which takes two arguments and returns one). The function should be associative so that it can be computed correctly in parallel.
+transform(func) | Returns a new DStream by applying func (a RDD-to-RDD function) to every RDD of the stream. This can be used to do arbitrary RDD operations on the DStream.
+updateStateByKey(func) | Return a new "state" DStream where the state for each key is updated by applying the given function on the previous state of the key and the new values of each key. This can be used to track session state by using the session-id as the key and updating the session state as new data is received.
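
To make the composition of these transformations concrete, here is a minimal sketch of a running word count. It assumes `lines` is a `DStream[String]` obtained from one of the input sources shown earlier and that `spark.streaming.StreamingContext._` has been imported so the pair operations are available; all variable names are placeholders.

{% highlight scala %}
// Minimal sketch: count words per batch, then keep a running total per word.
val words = lines.flatMap(_.split(" "))          // one element per word
val pairs = words.map(word => (word, 1))         // (word, 1) pairs
val wordCounts = pairs.reduceByKey(_ + _)        // counts within each batch

// updateStateByKey keeps per-key state across batches; it needs a checkpoint
// directory to be set on the context (see the checkpointing section below).
val runningCounts = wordCounts.updateStateByKey[Int] {
  (newValues: Seq[Int], state: Option[Int]) => Some(newValues.sum + state.getOrElse(0))
}
runningCounts.print()
{% endhighlight %}

Here `wordCounts` gives the counts seen in each individual batch, while `runningCounts` accumulates them across the lifetime of the stream.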
-Spark Streaming features windowed computations, which allow you to report statistics over a sliding window of data. All window functions take a windowDuration, which represents the width of the window and a slideTime, which represents the frequency during which the window is calculated.
+Spark Streaming features windowed computations, which allow you to apply transformations over a sliding window of data. All window functions take a windowDuration, which represents the width of the window and a slideTime, which represents the frequency during which the window is calculated.

 Transformation | Meaning
-window(windowDuration, slideTime) | Return a new stream which is computed based on windowed batches of the source stream. windowDuration is the width of the window and slideTime is the frequency during which the window is calculated. Both times must be multiples of the batch interval.
+window(windowDuration, slideDuration) | Return a new DStream which is computed based on windowed batches of the source DStream. windowDuration is the width of the window and slideTime is the frequency during which the window is calculated. Both times must be multiples of the batch interval.
-countByWindow(windowDuration, slideTime) | Return a sliding count of elements in the stream. windowDuration and slideDuration are exactly as defined in window().
+countByWindow(windowDuration, slideDuration) | Return a sliding count of elements in the stream. windowDuration and slideDuration are exactly as defined in window().
 reduceByWindow(func, windowDuration, slideDuration) | Return a new single-element stream, created by aggregating elements in the stream over a sliding interval using func. The function should be associative so that it can be computed correctly in parallel. windowDuration and slideDuration are exactly as defined in window().
-groupByKeyAndWindow(windowDuration, slideDuration, [numTasks]) | When called on a stream of (K, V) pairs, returns a stream of (K, Seq[V]) pairs over a sliding window. Note: By default, this uses only 8 parallel tasks to do the grouping. You can pass an optional numTasks argument to set a different number of tasks. windowDuration and slideDuration are exactly as defined in window().
+groupByKeyAndWindow(windowDuration, slideDuration, [numTasks]) | When called on a DStream of (K, V) pairs, returns a new DStream of (K, Seq[V]) pairs by grouping together values of each key over batches in a sliding window. Note: By default, this uses Spark's default number of parallel tasks (2 for local machine, 8 for a cluster) to do the grouping. You can pass an optional numTasks argument to set a different number of tasks.
+reduceByKeyAndWindow(func, windowDuration, slideDuration, [numTasks]) | When called on a DStream of (K, V) pairs, returns a new DStream of (K, V) pairs where the values for each key are aggregated using the given reduce function func over batches in a sliding window. Like in groupByKeyAndWindow, the number of reduce tasks is configurable through an optional second argument. windowDuration and slideDuration are exactly as defined in window().
+reduceByKeyAndWindow(func, invFunc, windowDuration, slideDuration, [numTasks]) | A more efficient version of the above reduceByKeyAndWindow() where the reduce value of each window is calculated incrementally using the reduce values of the previous window. This is done by reducing the new data that enter the sliding window, and "inverse reducing" the old data that leave the window. An example would be that of "adding" and "subtracting" counts of keys as the window slides. However, it is applicable to only "invertible reduce functions", that is, those reduce functions which have a corresponding "inverse reduce" function (taken as parameter invFunc). Like in groupByKeyAndWindow, the number of reduce tasks is configurable through an optional second argument. windowDuration and slideDuration are exactly as defined in window().
-reduceByKeyAndWindow(func, [numTasks]) | When called on a stream of (K, V) pairs, returns a stream of (K, V) pairs where the values for each key are aggregated using the given reduce function over batches within a sliding window. Like in groupByKeyAndWindow, the number of reduce tasks is configurable through an optional second argument.
+countByValueAndWindow(windowDuration, slideDuration, [numTasks]) | When called on a DStream of (K, V) pairs, returns a new DStream of (K, Long) pairs where the value of each key is its frequency within a sliding window. Like in groupByKeyAndWindow, the number of reduce tasks is configurable through an optional second argument. windowDuration and slideDuration are exactly as defined in window().
-countByKeyAndWindow([numTasks]) | When called on a stream of (K, V) pairs, returns a stream of (K, Int) pairs where the values for each key are the count within a sliding window. Like in countByKeyAndWindow, the number of reduce tasks is configurable through an optional second argument. windowDuration and slideDuration are exactly as defined in window().
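
As a sketch of how the windowed variants are used, the snippet below maintains word counts over the last 30 seconds of data, sliding every 10 seconds. It assumes `pairs` is a `DStream[(String, Int)]` of (word, 1) records and that a checkpoint directory has been set on the context, since the inverse-function form requires it.

{% highlight scala %}
// Minimal sketch of the incremental (inverse-function) windowed reduce.
val windowedCounts = pairs.reduceByKeyAndWindow(
  (a: Int, b: Int) => a + b,   // reduce the new data entering the window
  (a: Int, b: Int) => a - b,   // "inverse reduce" the old data leaving the window
  Seconds(30),                 // windowDuration
  Seconds(10)                  // slideDuration
)
windowedCounts.print()
{% endhighlight %}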
    @@ -147,7 +166,7 @@ A complete list of DStream operations is available in the API documentation of [ When an output operator is called, it triggers the computation of a stream. Currently the following output operators are defined: - + @@ -176,11 +195,6 @@ When an output operator is called, it triggers the computation of a stream. Curr
 Operator | Meaning
 foreach(func) | The fundamental output operator. Applies a function, func, to each RDD generated from the stream. This function should have side effects, such as printing output, saving the RDD to external files, or writing it over the network to an external system.
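
A minimal sketch of the `foreach` operator is shown below. It assumes `wordCounts` is a `DStream[(String, Int)]`; the printing is only a placeholder for whatever side effect (writing to a database, pushing to a dashboard) the application actually needs.

{% highlight scala %}
// Minimal sketch: the function passed to foreach runs on the driver once per batch,
// while operations on the RDD itself still execute on the cluster.
wordCounts.foreach(rdd => {
  val sample = rdd.take(10)    // bring only a small sample back to the driver
  sample.foreach(println)      // placeholder side effect
})
{% endhighlight %}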
    -## DStream Persistence -Similar to RDDs, DStreams also allow developers to persist the stream's data in memory. That is, using `persist()` method on a DStream would automatically persist every RDD of that DStream in memory. This is useful if the data in the DStream will be computed multiple times (e.g., multiple DStream operations on the same data). For window-based operations like `reduceByWindow` and `reduceByKeyAndWindow` and state-based operations like `updateStateByKey`, this is implicitly true. Hence, DStreams generated by window-based operations are automatically persisted in memory, without the developer calling `persist()`. - -Note that, unlike RDDs, the default persistence level of DStreams keeps the data serialized in memory. This is further discussed in the [Performance Tuning](#memory-tuning) section. More information on different persistence levels can be found in [Spark Programming Guide](scala-programming-guide.html#rdd-persistence). - # Starting the Streaming computation All the above DStream operations are completely lazy, that is, the operations will start executing only after the context is started by using {% highlight scala %} @@ -192,8 +206,8 @@ Conversely, the computation can be stopped by using ssc.stop() {% endhighlight %} -# Example - NetworkWordCount.scala -A good example to start off is the spark.streaming.examples.NetworkWordCount. This example counts the words received from a network server every second. Given below is the relevant sections of the source code. You can find the full source code in /streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala. +# Example +A simple example to start off is the [NetworkWordCount](https://github.com/mesos/spark/tree/master/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala). This example counts the words received from a network server every second. Given below is the relevant sections of the source code. You can find the full source code in `/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala` . {% highlight scala %} import spark.streaming.{Seconds, StreamingContext} @@ -260,6 +274,31 @@ Time: 1357008430000 ms +You can find more examples in `/streaming/src/main/scala/spark/streaming/examples/`. They can be run in the similar manner using `./run spark.streaming.examples....` . Executing without any parameter would give the required parameter list. Further explanation to run them can be found in comments in the files. + +# DStream Persistence +Similar to RDDs, DStreams also allow developers to persist the stream's data in memory. That is, using `persist()` method on a DStream would automatically persist every RDD of that DStream in memory. This is useful if the data in the DStream will be computed multiple times (e.g., multiple operations on the same data). For window-based operations like `reduceByWindow` and `reduceByKeyAndWindow` and state-based operations like `updateStateByKey`, this is implicitly true. Hence, DStreams generated by window-based operations are automatically persisted in memory, without the developer calling `persist()`. + +For input streams that receive data from the network (that is, subclasses of NetworkInputDStream like FlumeInputDStream and KafkaInputDStream), the default persistence level is set to replicate the data to two nodes for fault-tolerance. + +Note that, unlike RDDs, the default persistence level of DStreams keeps the data serialized in memory. This is further discussed in the [Performance Tuning](#memory-tuning) section. 
More information on different persistence levels can be found in [Spark Programming Guide](scala-programming-guide.html#rdd-persistence). + +# RDD Checkpointing within DStreams +DStreams created by stateful operations like `updateStateByKey` require the RDDs in the DStream to be periodically saved to HDFS files for checkpointing. This is because, unless checkpointed, the lineage of operations of the state RDDs can increase indefinitely (since each RDD in the DStream depends on the previous RDD). This leads to two problems - (i) the size of Spark tasks increase proportionally with the RDD lineage leading higher task launch times, (ii) no limit on the amount of recomputation required on failure. Checkpointing RDDs at some interval by writing them to HDFS allows the lineage to be truncated. Note that checkpointing also incurs the cost of saving to HDFS which may cause the corresponding batch to take longer to process. Hence, the interval of checkpointing needs to be set carefully. At small batch sizes (say 1 second), checkpointing every batch may significantly reduce operation throughput. Conversely, checkpointing too slowly causes the lineage and task sizes to grow which may have detrimental effects. Typically, a checkpoint interval of 5 - 10 times of sliding interval of a DStream is good setting to try. + +To enable checkpointing, the developer has to provide the HDFS path to which RDD will be saved. This is done by using + +{% highlight scala %} +ssc.checkpoint(hdfsPath) // assuming ssc is the StreamingContext +{% endhighlight %} + +The interval of checkpointing of a DStream can be set by using + +{% highlight scala %} +dstream.checkpoint(checkpointInterval) // checkpointInterval must be a multiple of slide duration of dstream +{% endhighlight %} + +For DStreams that must be checkpointed (that is, DStreams created by `updateStateByKey` and `reduceByKeyAndWindow` with inverse function), the checkpoint interval of the DStream is by default set to a multiple of the DStream's sliding interval such that its at least 10 seconds. # Performance Tuning @@ -273,17 +312,21 @@ Getting the best performance of a Spark Streaming application on a cluster requi There are a number of optimizations that can be done in Spark to minimize the processing time of each batch. These have been discussed in detail in [Tuning Guide](tuning.html). This section highlights some of the most important ones. ### Level of Parallelism -Cluster resources maybe underutilized if the number of parallel tasks used in any stage of the computation is not high enough. For example, for distributed reduce operations like `reduceByKey` and `reduceByKeyAndWindow`, the default number of parallel tasks is 8. You can pass the level of parallelism as an argument (see the [`spark.PairDStreamFunctions`](api/streaming/index.html#spark.PairDStreamFunctions) documentation), or set the system property `spark.default.parallelism` to change the default. +Cluster resources maybe under-utilized if the number of parallel tasks used in any stage of the computation is not high enough. For example, for distributed reduce operations like `reduceByKey` and `reduceByKeyAndWindow`, the default number of parallel tasks is 8. You can pass the level of parallelism as an argument (see the [`spark.PairDStreamFunctions`](api/streaming/index.html#spark.PairDStreamFunctions) documentation), or set the system property `spark.default.parallelism` to change the default. 
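
For example, either of the following raises the parallelism of a shuffle-like DStream operation; the value 16 and the variable names are placeholders, and the sketch assumes `pairs` is a `DStream[(String, Int)]`.

{% highlight scala %}
// 1. Pass the number of tasks directly to the operation:
val counts = pairs.reduceByKey(_ + _, 16)            // use 16 reduce tasks for this operation

// 2. Or raise the default for all such operations (set before the context is created):
System.setProperty("spark.default.parallelism", "16")
{% endhighlight %}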
### Data Serialization The overhead of data serialization can be significant, especially when sub-second batch sizes are to be achieved. There are two aspects to it. -* Serialization of RDD data in Spark: Please refer to the detailed discussion on data serialization in the [Tuning Guide](tuning.html). However, note that unlike Spark, by default RDDs are persisted as serialized byte arrays to minimize pauses related to GC. -* Serialization of input data: To ingest external data into Spark, data received as bytes (say, from the network) needs to deserialized from bytes and re-serialized into Spark's serialization format. Hence, the deserialization overhead of input data may be a bottleneck. + +* **Serialization of RDD data in Spark**: Please refer to the detailed discussion on data serialization in the [Tuning Guide](tuning.html). However, note that unlike Spark, by default RDDs are persisted as serialized byte arrays to minimize pauses related to GC. + +* **Serialization of input data**: To ingest external data into Spark, data received as bytes (say, from the network) needs to deserialized from bytes and re-serialized into Spark's serialization format. Hence, the deserialization overhead of input data may be a bottleneck. ### Task Launching Overheads If the number of tasks launched per second is high (say, 50 or more per second), then the overhead of sending out tasks to the slaves maybe significant and will make it hard to achieve sub-second latencies. The overhead can be reduced by the following changes: -* Task Serialization: Using Kryo serialization for serializing tasks can reduced the task sizes, and therefore reduce the time taken to send them to the slaves. -* Execution mode: Running Spark in Standalone mode or coarse-grained Mesos mode leads to better task launch times than the fine-grained Mesos mode. Please refer to the [Running on Mesos guide](running-on-mesos.html) for more details. + +* **Task Serialization**: Using Kryo serialization for serializing tasks can reduced the task sizes, and therefore reduce the time taken to send them to the slaves. + +* **Execution mode**: Running Spark in Standalone mode or coarse-grained Mesos mode leads to better task launch times than the fine-grained Mesos mode. Please refer to the [Running on Mesos guide](running-on-mesos.html) for more details. These changes may reduce batch processing time by 100s of milliseconds, thus allowing sub-second batch size to be viable. ## Setting the Right Batch Size @@ -292,22 +335,121 @@ For a Spark Streaming application running on a cluster to be stable, the process A good approach to figure out the right batch size for your application is to test it with a conservative batch size (say, 5-10 seconds) and a low data rate. To verify whether the system is able to keep up with data rate, you can check the value of the end-to-end delay experienced by each processed batch (in the Spark master logs, find the line having the phrase "Total delay"). If the delay is maintained to be less than the batch size, then system is stable. Otherwise, if the delay is continuously increasing, it means that the system is unable to keep up and it therefore unstable. Once you have an idea of a stable configuration, you can try increasing the data rate and/or reducing the batch size. Note that momentary increase in the delay due to temporary data rate increases maybe fine as long as the delay reduces back to a low value (i.e., less than batch size). 
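
A sketch of such a test run is shown below; the master URL, application name and the 5 second batch duration are placeholders for a conservative starting configuration.

{% highlight scala %}
// Minimal sketch: start with a generous batch size while validating stability.
val ssc = new StreamingContext("spark://master:7077", "BatchSizeTest", Seconds(5))
// ... attach a low-rate input stream and the intended computation here ...
ssc.start()
// While this runs, look for "Total delay" in the master logs; if the reported delay stays
// below the 5 second batch duration, the configuration is keeping up with the data rate.
{% endhighlight %}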
## 24/7 Operation
-By default, Spark does not forget any of the metadata (RDDs generated, stages processed, etc.). But for a Spark Streaming application to operate 24/7, it is necessary for Spark to do periodic cleanup of it metadata. This can be enabled by setting the Java system property `spark.cleaner.delay` to the number of minutes you want any metadata to persist. For example, setting `spark.cleaner.delay` to 10 would cause Spark periodically cleanup all metadata and persisted RDDs that are older than 10 minutes. Note, that this property needs to be set before the SparkContext is created.
+By default, Spark does not forget any of the metadata (RDDs generated, stages processed, etc.). But for a Spark Streaming application to operate 24/7, it is necessary for Spark to do periodic cleanup of its metadata. This can be enabled by setting the Java system property `spark.cleaner.delay` to the number of seconds you want any metadata to persist. For example, setting `spark.cleaner.delay` to 600 would cause Spark to periodically clean up all metadata and persisted RDDs that are older than 10 minutes. Note that this property needs to be set before the SparkContext is created. This value is closely tied to any window operation that is being used. Any window operation requires the input data to be persisted in memory for at least the duration of the window. Hence, it is necessary to set the delay to at least the duration of the largest window operation used in the Spark Streaming application. If this delay is set too low, the application will throw an exception saying so.

## Memory Tuning
Tuning the memory usage and GC behavior of Spark applications has been discussed in great detail in the [Tuning Guide](tuning.html). It is recommended that you read that. In this section, we highlight a few customizations that are strongly recommended to minimize GC-related pauses in Spark Streaming applications and achieve more consistent batch processing times.
-* Default persistence level of DStreams: Unlike RDDs, the default persistence level of DStreams serializes the data in memory (that is, [StorageLevel.MEMORY_ONLY_SER](api/core/index.html#spark.storage.StorageLevel$) for DStream compared to [StorageLevel.MEMORY_ONLY](api/core/index.html#spark.storage.StorageLevel$) for RDDs). Even though keeping the data serialized incurs a higher serialization overheads, it significantly reduces GC pauses.
+* **Default persistence level of DStreams**: Unlike RDDs, the default persistence level of DStreams serializes the data in memory (that is, [StorageLevel.MEMORY_ONLY_SER](api/core/index.html#spark.storage.StorageLevel$) for DStreams compared to [StorageLevel.MEMORY_ONLY](api/core/index.html#spark.storage.StorageLevel$) for RDDs). Even though keeping the data serialized incurs higher serialization overhead, it significantly reduces GC pauses.

-* Concurrent garbage collector: Using the concurrent mark-and-sweep GC further minimizes the variability of GC pauses. Even though concurrent GC is known to reduce the overall processing throughput of the system, its use is still recommended to achieve more consistent batch processing times.
+* **Concurrent garbage collector**: Using the concurrent mark-and-sweep GC further minimizes the variability of GC pauses. Even though concurrent GC is known to reduce the overall processing throughput of the system, its use is still recommended to achieve more consistent batch processing times.
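The last two sections boil down to a couple of settings. The following is a sketch under stated assumptions (a largest window of 10 minutes, and worker JVM options supplied through `conf/spark-env.sh`), not a prescription:

{% highlight scala %}
// Sketch only: keep metadata around at least as long as the largest window used.
// Here a maximum window of 10 minutes is assumed, so the cleaner delay is 600 seconds.
// This must be set before the SparkContext (and hence the StreamingContext) is created.
System.setProperty("spark.cleaner.delay", "600")

// The concurrent mark-and-sweep collector is a JVM setting rather than a Spark property;
// it can be passed to the JVMs, e.g. via SPARK_JAVA_OPTS="-XX:+UseConcMarkSweepGC" in spark-env.sh.
{% endhighlight %}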
-# Master Fault-tolerance (Alpha)
-TODO
+# Fault-tolerance Properties
+There are two aspects to fault-tolerance - the failure of a worker node and the failure of the driver node. In this section, we are going to discuss the fault-tolerance behavior and the semantics of the processed data.
-* Checkpointing of DStream graph
+## Failure of a Worker Node
+In case of a worker node failure, none of the processed data will be lost because
-* Recovery from master faults
+1. All the input data is fault-tolerant (either the data is on HDFS, or it is replicated by Spark Streaming if received from the network)
+1. All intermediate data is expressed as RDDs with their lineage to the input data, which allows Spark to recompute any part of the intermediate data that is lost due to a worker node failure.
-* Current state and future directions
\ No newline at end of file
+If the worker node where a network data receiver is running fails, then the receiver will be restarted on a different node and it will continue to receive data. However, data that was accepted by the receiver but not yet replicated to other Spark nodes may be lost, which is a fraction of a second of data.
+
+Since all data is modeled as RDDs with their lineage of deterministic operations, any recomputation always leads to the same result. As a result, all DStream transformations are guaranteed to have _exactly-once_ semantics. That is, the final transformed result will be the same even if there was a worker node failure. However, output operations (like `foreach`) have _at-least-once_ semantics, that is, the transformed data may get written to an external entity more than once in the event of a worker failure. While this is acceptable for saving to HDFS using the `saveAs*Files` operations (as the file will simply get over-written by the same data), additional transaction-like mechanisms may be necessary to achieve exactly-once semantics for output operations.
+
+## Failure of a Driver Node
+A system that is required to operate 24/7 needs to be able to tolerate the failure of the driver node as well. Spark Streaming does this by periodically saving the state of the DStream computation to an HDFS file, which can be used to restart the streaming computation in the event of a failure of the driver node. To elaborate, the following state is periodically saved to a file.
+
+1. The DStream operator graph (input streams, output streams, etc.)
+1. The configuration of each DStream (checkpoint interval, etc.)
+1. The RDD checkpoint files of each DStream
+
+All this is periodically saved in the file `<checkpoint directory>/graph`, where `<checkpoint directory>` is the HDFS path set using `ssc.checkpoint(...)` as described earlier. To recover, a new StreamingContext can be created with this directory by using
+
+{% highlight scala %}
+val ssc = new StreamingContext(checkpointDirectory)
+{% endhighlight %}
+
+Calling `ssc.start()` on this new context will restart the receivers and the stream computations.
+
+In case of stateful operations (that is, `updateStateByKey` and `reduceByKeyAndWindow` with an inverse function), the intermediate data at the time of failure also needs to be recomputed. This requires two things - (i) the RDD checkpoints and (ii) the data received since the checkpoints. In the current _alpha_ release, the input data received from the network is not saved durably across driver failures (the data is only replicated in memory of the worker processes and is lost when the driver fails).
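To tie the pieces of this recovery story together, here is a hedged sketch of the overall pattern; the master URL, application name, batch duration, and checkpoint path are illustrative assumptions, and the two halves would of course run in separate driver processes (the original run and the restarted one).

{% highlight scala %}
// Sketch of the driver-recovery pattern described above (paths and URLs are placeholders).

// Original run: create the context, enable checkpointing, set up the streams, and start.
val checkpointDirectory = "hdfs://namenode:8020/user/spark/checkpoints/myApp"
val ssc = new StreamingContext("spark://master:7077", "MyApp", Seconds(1))
ssc.checkpoint(checkpointDirectory)
// ... define input DStreams, stateful transformations, and output operations here ...
ssc.start()

// Restarted driver (a separate process): rebuild the context from the saved state and start it.
val recoveredSsc = new StreamingContext(checkpointDirectory)
recoveredSsc.start()
{% endhighlight %}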
+Only with file input streams (where the data is already durably stored) is the recovery from driver failure complete, with all intermediate data recomputed. In a future release, this will be true for all input streams. Note that for non-stateful operations, with _all_ input streams, the system will recover and continue receiving and processing new data.
+
+To understand the behavior of the system under driver failure, let's consider what will happen with a file input stream. Specifically, the file input stream will correctly identify new files that were created while the driver was down and process them in the same way as it would have if the driver had not failed. To illustrate, let's say files are being generated every second, and a Spark Streaming program reads every new file and outputs the number of lines in the file. This is what the sequence of outputs would be with and without a driver failure.
+
+| Time | Number of lines in input file | Output without driver failure | Output with driver failure |
+|------|-------------------------------|-------------------------------|----------------------------|
+| 1    | 10  | 10  | 10  |
+| 2    | 20  | 20  | 20  |
+| 3    | 30  | 30  | 30  |
+| 4    | 40  | 40  | [DRIVER FAILS] no output |
+| 5    | 50  | 50  | no output |
+| 6    | 60  | 60  | no output |
+| 7    | 70  | 70  | [DRIVER RECOVERS] 40, 50, 60, 70 |
+| 8    | 80  | 80  | 80  |
+| 9    | 90  | 90  | 90  |
+| 10   | 100 | 100 | 100 |
    + +If the driver had crashed in the middle of the processing of time 3, then it will process time 3 and output 30 after recovery. + +# Where to Go from Here +* Documentation - [Scala and Java](api/streaming/index.html) +* More examples - [Scala](https://github.com/mesos/spark/tree/master/examples/src/main/scala/spark/streaming/examples) and [Java](https://github.com/mesos/spark/tree/master/examples/src/main/java/spark/streaming/examples) \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 84e4b5bedb..e1be5ef51c 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -132,7 +132,7 @@ abstract class DStream[T: ClassManifest] ( // Set the checkpoint interval to be slideDuration or 10 seconds, which ever is larger if (mustCheckpoint && checkpointDuration == null) { - checkpointDuration = slideDuration.max(Seconds(10)) + checkpointDuration = slideDuration * math.ceil(Seconds(10) / slideDuration).toInt logInfo("Checkpoint interval automatically set to " + checkpointDuration) } diff --git a/streaming/src/main/scala/spark/streaming/Duration.scala b/streaming/src/main/scala/spark/streaming/Duration.scala index e4dc579a17..ee26206e24 100644 --- a/streaming/src/main/scala/spark/streaming/Duration.scala +++ b/streaming/src/main/scala/spark/streaming/Duration.scala @@ -16,7 +16,7 @@ case class Duration (private val millis: Long) { def * (times: Int): Duration = new Duration(millis * times) - def / (that: Duration): Long = millis / that.millis + def / (that: Duration): Double = millis.toDouble / that.millis.toDouble def isMultipleOf(that: Duration): Boolean = (this.millis % that.millis == 0) diff --git a/streaming/src/main/scala/spark/streaming/Interval.scala b/streaming/src/main/scala/spark/streaming/Interval.scala index dc21dfb722..6a8b81760e 100644 --- a/streaming/src/main/scala/spark/streaming/Interval.scala +++ b/streaming/src/main/scala/spark/streaming/Interval.scala @@ -30,6 +30,7 @@ class Interval(val beginTime: Time, val endTime: Time) { override def toString = "[" + beginTime + ", " + endTime + "]" } +private[streaming] object Interval { def currentInterval(duration: Duration): Interval = { val time = new Time(System.currentTimeMillis) diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index 5127db3bbc..5a2dd46fa0 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -18,8 +18,8 @@ import org.apache.hadoop.conf.Configuration class PairDStreamFunctions[K: ClassManifest, V: ClassManifest](self: DStream[(K,V)]) extends Serializable { - - def ssc = self.ssc + + private[streaming] def ssc = self.ssc private[streaming] def defaultPartitioner(numPartitions: Int = self.ssc.sc.defaultParallelism) = { new HashPartitioner(numPartitions) @@ -242,7 +242,9 @@ extends Serializable { * Return a new DStream by applying incremental `reduceByKey` over a sliding window. * The reduced value of over a new window is calculated using the old window's reduced value : * 1. reduce the new values that entered the window (e.g., adding new counts) + * * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + * * This is more efficient than reduceByKeyAndWindow without "inverse reduce" function. 
* However, it is applicable to only "invertible reduce functions". * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. @@ -399,7 +401,7 @@ extends Serializable { } /** - * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this` + * Cogroup `this` DStream with `other` DStream using a partitioner. For each key k in corresponding RDDs of `this` * or `other` DStreams, the generated RDD will contains a tuple with the list of values for that * key in both RDDs. Partitioner is used to partition each generated RDD. */ diff --git a/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala index e70822e5c3..0e21b7480c 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala @@ -13,6 +13,7 @@ import twitter4j.auth.BasicAuthorization * An optional set of string filters can be used to restrict the set of tweets. The Twitter API is * such that this may return a sampled subset of all tweets during each interval. */ +private[streaming] class TwitterInputDStream( @transient ssc_ : StreamingContext, username: String, @@ -26,6 +27,7 @@ class TwitterInputDStream( } } +private[streaming] class TwitterReceiver( username: String, password: String, diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index 5250667bcb..cac86deeaf 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -50,7 +50,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { val stateStreamCheckpointInterval = Seconds(1) // this ensure checkpointing occurs at least once - val firstNumBatches = (stateStreamCheckpointInterval / batchDuration) * 2 + val firstNumBatches = (stateStreamCheckpointInterval / batchDuration).toLong * 2 val secondNumBatches = firstNumBatches // Setup the streams