From 3c37928fab0801fa1e2662d873dac4b4f93c547d Mon Sep 17 00:00:00 2001 From: "Joseph E. Gonzalez" Date: Mon, 4 Nov 2013 20:10:15 -0800 Subject: [PATCH 1/3] This commit adds a new graphx-shell which is essentially the same as the spark shell but with GraphX packages automatically imported and with Kryo serialization enabled for GraphX types. In addition the graphx-shell has a nifty new logo. To make these changes minimally invasive in the SparkILoop.scala I added some additional environment variables: SPARK_BANNER_TEXT: If set this string is displayed instead of the spark logo SPARK_SHELL_INIT_BLOCK: if set this expression is evaluated in the spark shell after the spark context is created. --- graphx-shell | 124 +++++++++++++ .../org/apache/spark/repl/SparkILoop.scala | 174 +++++++++--------- 2 files changed, 215 insertions(+), 83 deletions(-) create mode 100755 graphx-shell diff --git a/graphx-shell b/graphx-shell new file mode 100755 index 0000000000..4dd6c68ace --- /dev/null +++ b/graphx-shell @@ -0,0 +1,124 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# Shell script for starting the Spark Shell REPL +# Note that it will set MASTER to spark://${SPARK_MASTER_IP}:${SPARK_MASTER_PORT} +# if those two env vars are set in spark-env.sh but MASTER is not. +# Options: +# -c Set the number of cores for REPL to use +# + +# Enter posix mode for bash +set -o posix + + +# Update the the banner logo +export SPARK_BANNER_TEXT="Welcome to + ______ __ _ __ + / ____/________ _____ / /_ | |/ / + / / __/ ___/ __ \`/ __ \/ __ \| / + / /_/ / / / /_/ / /_/ / / / / | + \____/_/ \__,_/ .___/_/ /_/_/|_| + /_/ Alpha Release + +Powered by: + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ \`/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ + /_/ version 0.9.0 + +Example: + + scala> val graph = GraphLoader.textFile(sc, \"hdfs://links\") + scala> graph.numVertices + scala> graph.numEdges + scala> val pageRankGraph = Analytics.pagerank(graph, 10) // 10 iterations + scala> val maxPr = pageRankGraph.vertices.map{ case (vid, pr) => pr }.max + scala> println(maxPr) + +" + +export SPARK_SHELL_INIT_BLOCK="import org.apache.spark.graph._;" + +# Set the serializer to use Kryo for graphx objects +SPARK_JAVA_OPTS+=" -Dspark.serializer=org.apache.spark.serializer.KryoSerializer " +SPARK_JAVA_OPTS+="-Dspark.kryo.registrator=org.apache.spark.graph.GraphKryoRegistrator " +SPARK_JAVA_OPTS+="-Dspark.kryoserializer.buffer.mb=10 " + + + +FWDIR="`dirname $0`" + +for o in "$@"; do + if [ "$1" = "-c" -o "$1" = "--cores" ]; then + shift + if [ -n "$1" ]; then + OPTIONS="-Dspark.cores.max=$1" + shift + fi + fi +done + +# Set MASTER from spark-env if possible +if [ -z "$MASTER" ]; then + if [ -e "$FWDIR/conf/spark-env.sh" ]; then + . 
"$FWDIR/conf/spark-env.sh" + fi + if [[ "x" != "x$SPARK_MASTER_IP" && "y" != "y$SPARK_MASTER_PORT" ]]; then + MASTER="spark://${SPARK_MASTER_IP}:${SPARK_MASTER_PORT}" + export MASTER + fi +fi + +# Copy restore-TTY-on-exit functions from Scala script so spark-shell exits properly even in +# binary distribution of Spark where Scala is not installed +exit_status=127 +saved_stty="" + +# restore stty settings (echo in particular) +function restoreSttySettings() { + stty $saved_stty + saved_stty="" +} + +function onExit() { + if [[ "$saved_stty" != "" ]]; then + restoreSttySettings + fi + exit $exit_status +} + +# to reenable echo if we are interrupted before completing. +trap onExit INT + +# save terminal settings +saved_stty=$(stty -g 2>/dev/null) +# clear on error so we don't later try to restore them +if [[ ! $? ]]; then + saved_stty="" +fi + +$FWDIR/spark-class $OPTIONS org.apache.spark.repl.Main "$@" + +# record the exit status lest it be overwritten: +# then reenable echo and propagate the code. +exit_status=$? +onExit diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 0ced284da6..efdd90c47f 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -45,7 +45,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: def this(in0: BufferedReader, out: PrintWriter, master: String) = this(Some(in0), out, Some(master)) def this(in0: BufferedReader, out: PrintWriter) = this(Some(in0), out, None) def this() = this(None, new PrintWriter(Console.out, true), None) - + var in: InteractiveReader = _ // the input stream from which commands come var settings: Settings = _ var intp: SparkIMain = _ @@ -56,16 +56,16 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: Power[g.type](this, g) } */ - + // TODO // object opt extends AestheticSettings - // + // @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp - + @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: SparkIMain): Unit = intp = i - + def history = in.history /** The context class loader at the time this object was created */ @@ -75,7 +75,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: private val signallable = /*if (isReplDebug) Signallable("Dump repl state.")(dumpCommand()) else*/ null - + // classpath entries added via :cp var addedClasspath: String = "" @@ -87,10 +87,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: /** Record a command for replay should the user request a :replay */ def addReplay(cmd: String) = replayCommandStack ::= cmd - + /** Try to install sigint handler: ignore failure. Signal handler * will interrupt current line execution if any is in progress. - * + * * Attempting to protect the repl from accidental exit, we only honor * a single ctrl-C if the current buffer is empty: otherwise we look * for a second one within a short time. @@ -124,7 +124,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: Thread.currentThread.setContextClassLoader(originalClassLoader) } } - + class SparkILoopInterpreter extends SparkIMain(settings, out) { override lazy val formatting = new Formatting { def prompt = SparkILoop.this.prompt @@ -135,7 +135,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: |// She's gone rogue, captain! Have to take her out! 
|// Calling Thread.stop on runaway %s with offending code: |// scala> %s""".stripMargin - + echo(template.format(line.thread, line.code)) // XXX no way to suppress the deprecation warning line.thread.stop() @@ -151,7 +151,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: def createInterpreter() { if (addedClasspath != "") settings.classpath append addedClasspath - + intp = new SparkILoopInterpreter intp.setContextClassLoader() installSigIntHandler() @@ -168,10 +168,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: private def helpSummary() = { val usageWidth = commands map (_.usageMsg.length) max val formatStr = "%-" + usageWidth + "s %s %s" - + echo("All commands can be abbreviated, e.g. :he instead of :help.") echo("Those marked with a * have more detailed help, e.g. :help imports.\n") - + commands foreach { cmd => val star = if (cmd.hasLongHelp) "*" else " " echo(formatStr.format(cmd.usageMsg, star, cmd.help)) @@ -182,7 +182,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: case Nil => echo(cmd + ": no such command. Type :help for help.") case xs => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?") } - Result(true, None) + Result(true, None) } private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd) private def uniqueCommand(cmd: String): Option[LoopCommand] = { @@ -193,31 +193,35 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: case xs => xs find (_.name == cmd) } } - + /** Print a welcome message */ def printWelcome() { - echo("""Welcome to - ____ __ + val prop = System.getenv("SPARK_BANNER_TEXT") + val bannerText = + if (prop != null) prop else + """Welcome to + ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /___/ .__/\_,_/_/ /_/\_\ version 0.9.0-SNAPSHOT - /_/ -""") + /_/ + """ + echo(bannerText) import Properties._ val welcomeMsg = "Using Scala %s (%s, Java %s)".format( - versionString, javaVmName, javaVersion) + versionString, javaVmName, javaVersion) echo(welcomeMsg) } - + /** Show the history */ lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { override def usage = "[num]" def defaultLines = 20 - + def apply(line: String): Result = { if (history eq NoHistory) return "No history available." 
- + val xs = words(line) val current = history.index val count = try xs.head.toInt catch { case _: Exception => defaultLines } @@ -237,21 +241,21 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: out print msg out.flush() } - + /** Search the history */ def searchHistory(_cmdline: String) { val cmdline = _cmdline.toLowerCase val offset = history.index - history.size + 1 - + for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline) echo("%d %s".format(index + offset, line)) } - + private var currentPrompt = Properties.shellPromptString def setPrompt(prompt: String) = currentPrompt = prompt /** Prompt to print when awaiting input */ def prompt = currentPrompt - + import LoopCommand.{ cmd, nullary } /** Standard commands **/ @@ -273,7 +277,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: nullary("silent", "disable/enable automatic printing of results", verbosity), cmd("type", "", "display the type of an expression without evaluating it", typeCommand) ) - + /** Power user commands */ lazy val powerCommands: List[LoopCommand] = List( //nullary("dump", "displays a view of the interpreter's internal state", dumpCommand), @@ -298,10 +302,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: |An argument of clear will remove the wrapper if any is active. |Note that wrappers do not compose (a new one replaces the old |one) and also that the :phase command uses the same machinery, - |so setting :wrap will clear any :phase setting. + |so setting :wrap will clear any :phase setting. """.stripMargin.trim) ) - + /* private def dumpCommand(): Result = { echo("" + power) @@ -309,7 +313,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: in.redrawLine() } */ - + private val typeTransforms = List( "scala.collection.immutable." -> "immutable.", "scala.collection.mutable." -> "mutable.", @@ -317,7 +321,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: "java.lang." -> "jl.", "scala.runtime." -> "runtime." ) - + private def importsCommand(line: String): Result = { val tokens = words(line) val handlers = intp.languageWildcardHandlers ++ intp.importHandlers @@ -333,7 +337,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: val implicitMsg = if (imps.isEmpty) "" else imps.size + " are implicit" val foundMsg = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "") val statsMsg = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")") - + intp.reporter.printMessage("%2d) %-30s %s%s".format( idx + 1, handler.importString, @@ -342,12 +346,12 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: )) } } - + private def implicitsCommand(line: String): Result = { val intp = SparkILoop.this.intp import intp._ import global.Symbol - + def p(x: Any) = intp.reporter.printMessage("" + x) // If an argument is given, only show a source with that @@ -360,14 +364,14 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: else (args exists (source.name.toString contains _)) } } - + if (filtered.isEmpty) return "No implicits have been imported other than those in Predef." 
- + filtered foreach { case (source, syms) => p("/* " + syms.size + " implicit members imported from " + source.fullName + " */") - + // This groups the members by where the symbol is defined val byOwner = syms groupBy (_.owner) val sortedOwners = byOwner.toList sortBy { case (owner, _) => intp.afterTyper(source.info.baseClasses indexOf owner) } @@ -388,10 +392,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: xss map (xs => xs sortBy (_.name.toString)) } - - val ownerMessage = if (owner == source) " defined in " else " inherited from " + + val ownerMessage = if (owner == source) " defined in " else " inherited from " p(" /* " + members.size + ownerMessage + owner.fullName + " */") - + memberGroups foreach { group => group foreach (s => p(" " + intp.symbolDefString(s))) p("") @@ -400,7 +404,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: p("") } } - + protected def newJavap() = new Javap(intp.classLoader, new SparkIMain.ReplStrippingWriter(intp)) { override def tryClass(path: String): Array[Byte] = { // Look for Foo first, then Foo$, but if Foo$ is given explicitly, @@ -417,20 +421,20 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: private lazy val javap = try newJavap() catch { case _: Exception => null } - + private def typeCommand(line: String): Result = { intp.typeOfExpression(line) match { case Some(tp) => tp.toString case _ => "Failed to determine type." } } - + private def javapCommand(line: String): Result = { if (javap == null) return ":javap unavailable on this platform." if (line == "") return ":javap [-lcsvp] [path1 path2 ...]" - + javap(words(line)) foreach { res => if (res.isError) return "Failed: " + res.value else res.show() @@ -504,25 +508,25 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: } else { val what = phased.parse(name) - if (what.isEmpty || !phased.set(what)) + if (what.isEmpty || !phased.set(what)) "'" + name + "' does not appear to represent a valid phase." else { intp.setExecutionWrapper(pathToPhaseWrapper) val activeMessage = if (what.toString.length == name.length) "" + what else "%s (%s)".format(what, name) - + "Active phase is now: " + activeMessage } } } */ - + /** Available commands */ def commands: List[LoopCommand] = standardCommands /* ++ ( if (isReplPower) powerCommands else Nil )*/ - + val replayQuestionMessage = """|The repl compiler has crashed spectacularly. Shall I replay your |session? I can re-run all lines except the last one. 
@@ -579,10 +583,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: } /** interpret all lines from a specified file */ - def interpretAllFrom(file: File) { + def interpretAllFrom(file: File) { val oldIn = in val oldReplay = replayCommandStack - + try file applyReader { reader => in = SimpleReader(reader, out, false) echo("Loading " + file + "...") @@ -604,26 +608,26 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: echo("") } } - + /** fork a shell and run a command */ lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") { override def usage = "" def apply(line: String): Result = line match { case "" => showUsage() - case _ => + case _ => val toRun = classOf[ProcessResult].getName + "(" + string2codeQuoted(line) + ")" intp interpret toRun () } } - + def withFile(filename: String)(action: File => Unit) { val f = File(filename) - + if (f.exists) action(f) else echo("That file does not exist") } - + def loadCommand(arg: String) = { var shouldReplay: Option[String] = None withFile(arg)(f => { @@ -657,7 +661,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: } else echo("The path '" + f + "' doesn't seem to exist.") } - + def powerCmd(): Result = { if (isReplPower) "Already in power mode." else enablePowerMode() @@ -667,13 +671,13 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: //power.unleash() //echo(power.banner) } - + def verbosity() = { val old = intp.printResults intp.printResults = !old echo("Switched " + (if (old) "off" else "on") + " result printing.") } - + /** Run one command submitted by the user. Two values are returned: * (1) whether to keep running, (2) the line to record for replay, * if any. */ @@ -688,11 +692,11 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: else if (intp.global == null) Result(false, None) // Notice failure to create compiler else Result(true, interpretStartingWith(line)) } - + private def readWhile(cond: String => Boolean) = { Iterator continually in.readLine("") takeWhile (x => x != null && cond(x)) } - + def pasteCommand(): Result = { echo("// Entering paste mode (ctrl-D to finish)\n") val code = readWhile(_ => true) mkString "\n" @@ -700,17 +704,17 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: intp interpret code () } - + private object paste extends Pasted { val ContinueString = " | " val PromptString = "scala> " - + def interpret(line: String): Unit = { echo(line.trim) intp interpret line echo("") } - + def transcript(start: String) = { // Printing this message doesn't work very well because it's buried in the // transcript they just pasted. Todo: a short timer goes off when @@ -731,7 +735,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: def interpretStartingWith(code: String): Option[String] = { // signal completion non-completion input has been received in.completion.resetVerbosity() - + def reallyInterpret = { val reallyResult = intp.interpret(code) (reallyResult, reallyResult match { @@ -741,7 +745,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: if (in.interactive && code.endsWith("\n\n")) { echo("You typed two blank lines. 
Starting a new command.") None - } + } else in.readLine(ContinueString) match { case null => // we know compilation is going to fail since we're at EOF and the @@ -755,10 +759,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: } }) } - + /** Here we place ourselves between the user and the interpreter and examine * the input they are ostensibly submitting. We intervene in several cases: - * + * * 1) If the line starts with "scala> " it is assumed to be an interpreter paste. * 2) If the line starts with "." (but not ".." or "./") it is treated as an invocation * on the previous result. @@ -787,7 +791,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: val (code, result) = reallyInterpret //if (power != null && code == IR.Error) // runCompletion - + result } else runCompletion match { @@ -808,7 +812,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: } case _ => } - + /** Tries to create a JLineReader, falling back to SimpleReader: * unless settings or properties are such that it should start * with SimpleReader. @@ -837,6 +841,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: org.apache.spark.repl.Main.interp.out.flush(); """) command("import org.apache.spark.SparkContext._") + val prop = System.getenv("SPARK_SHELL_INIT_BLOCK") + if (prop != null) { + command(prop) + } } echo("Type in expressions to have them evaluated.") echo("Type :help for more information.") @@ -884,7 +892,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: this.settings = settings createInterpreter() - + // sets in to some kind of reader depending on environmental cues in = in0 match { case Some(reader) => SimpleReader(reader, out, true) @@ -895,10 +903,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: // it is broken on startup; go ahead and exit if (intp.reporter.hasErrors) return false - - try { + + try { // this is about the illusion of snappiness. We call initialize() - // which spins off a separate thread, then print the prompt and try + // which spins off a separate thread, then print the prompt and try // our best to look ready. Ideally the user will spend a // couple seconds saying "wow, it starts so fast!" and by the time // they type a command the compiler is ready to roll. @@ -920,19 +928,19 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: def neededHelp(): String = (if (command.settings.help.value) command.usageMsg + "\n" else "") + (if (command.settings.Xhelp.value) command.xusageMsg + "\n" else "") - + // if they asked for no help and command is valid, we call the real main neededHelp() match { case "" => command.ok && process(command.settings) case help => echoNoNL(help) ; true } } - + @deprecated("Use `process` instead", "2.9.0") def main(args: Array[String]): Unit = { if (isReplDebug) System.out.println(new java.util.Date) - + process(args) } @deprecated("Use `process` instead", "2.9.0") @@ -948,7 +956,7 @@ object SparkILoop { // like if you'd just typed it into the repl. 
def runForTranscript(code: String, settings: Settings): String = { import java.io.{ BufferedReader, StringReader, OutputStreamWriter } - + stringFromStream { ostream => Console.withOut(ostream) { val output = new PrintWriter(new OutputStreamWriter(ostream), true) { @@ -977,19 +985,19 @@ object SparkILoop { } } } - + /** Creates an interpreter loop with default settings and feeds * the given code to it as input. */ def run(code: String, sets: Settings = new Settings): String = { import java.io.{ BufferedReader, StringReader, OutputStreamWriter } - + stringFromStream { ostream => Console.withOut(ostream) { val input = new BufferedReader(new StringReader(code)) val output = new PrintWriter(new OutputStreamWriter(ostream), true) val repl = new SparkILoop(input, output) - + if (sets.classpath.isDefault) sets.classpath.value = sys.props("java.class.path") @@ -1017,7 +1025,7 @@ object SparkILoop { repl.settings.embeddedDefaults[T] repl.createInterpreter() repl.in = SparkJLineReader(repl) - + // rebind exit so people don't accidentally call sys.exit by way of predef repl.quietRun("""def exit = println("Type :quit to resume program execution.")""") args foreach (p => repl.bind(p.name, p.tpe, p.value)) @@ -1025,5 +1033,5 @@ object SparkILoop { echo("\nDebug repl exiting.") repl.closeInterpreter() - } + } } From 2406bf33e4381dc172d28311646954b08b614a6c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 7 Nov 2013 19:18:58 -0800 Subject: [PATCH 2/3] Use custom serializer for aggregation messages when the data type is int/double. --- .../apache/spark/graph/impl/GraphImpl.scala | 61 ++------- .../spark/graph/impl/MessageToPartition.scala | 35 ++++- .../apache/spark/graph/impl/Serializers.scala | 125 ++++++++++++++++++ 3 files changed, 171 insertions(+), 50 deletions(-) create mode 100644 graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala index b88c952feb..d0df35d422 100644 --- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala +++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala @@ -5,7 +5,6 @@ import scala.collection.JavaConversions._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer - import org.apache.spark.SparkContext._ import org.apache.spark.HashPartitioner import org.apache.spark.util.ClosureCleaner @@ -72,8 +71,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( def this() = this(null, null, null, null) - - /** * (localVidMap: VertexSetRDD[Pid, VertexIdToIndexMap]) is a version of the * vertex data after it is replicated. Within each partition, it holds a map @@ -86,22 +83,18 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( @transient val vTableReplicatedValues: RDD[(Pid, Array[VD])] = createVTableReplicated(vTable, vid2pid, localVidMap) - /** Return a RDD of vertices. */ @transient override val vertices = vTable - /** Return a RDD of edges. */ @transient override val edges: RDD[Edge[ED]] = { eTable.mapPartitions( iter => iter.next()._2.iterator , true ) } - /** Return a RDD that brings edges with its source and destination vertices together. 
*/ @transient override val triplets: RDD[EdgeTriplet[VD, ED]] = makeTriplets(localVidMap, vTableReplicatedValues, eTable) - override def cache(): Graph[VD, ED] = { eTable.cache() vid2pid.cache() @@ -109,7 +102,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( this } - override def statistics: Map[String, Any] = { val numVertices = this.numVertices val numEdges = this.numEdges @@ -125,7 +117,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( "Min Load" -> minLoad, "Max Load" -> maxLoad) } - /** * Display the lineage information for this graph. */ @@ -183,14 +174,12 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( println(visited) } // end of print lineage - override def reverse: Graph[VD, ED] = { val newEtable = eTable.mapPartitions( _.map{ case (pid, epart) => (pid, epart.reverse) }, preservesPartitioning = true) new GraphImpl(vTable, vid2pid, localVidMap, newEtable) } - override def mapVertices[VD2: ClassManifest](f: (Vid, VD) => VD2): Graph[VD2, ED] = { val newVTable = vTable.mapValuesWithKeys((vid, data) => f(vid, data)) new GraphImpl(newVTable, vid2pid, localVidMap, eTable) @@ -202,11 +191,9 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( new GraphImpl(vTable, vid2pid, localVidMap, newETable) } - override def mapTriplets[ED2: ClassManifest](f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = GraphImpl.mapTriplets(this, f) - override def subgraph(epred: EdgeTriplet[VD,ED] => Boolean = (x => true), vpred: (Vid, VD) => Boolean = ((a,b) => true) ): Graph[VD, ED] = { @@ -246,7 +233,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( new GraphImpl(newVTable, newVid2Pid, localVidMap, newETable) } - override def groupEdgeTriplets[ED2: ClassManifest]( f: Iterator[EdgeTriplet[VD,ED]] => ED2 ): Graph[VD,ED2] = { val newEdges: RDD[Edge[ED2]] = triplets.mapPartitions { partIter => @@ -271,7 +257,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( new GraphImpl(vTable, vid2pid, localVidMap, newETable) } - override def groupEdges[ED2: ClassManifest](f: Iterator[Edge[ED]] => ED2 ): Graph[VD,ED2] = { @@ -289,8 +274,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( new GraphImpl(vTable, vid2pid, localVidMap, newETable) } - - ////////////////////////////////////////////////////////////////////////////////////////////////// // Lower level transformation methods ////////////////////////////////////////////////////////////////////////////////////////////////// @@ -301,7 +284,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( : VertexSetRDD[A] = GraphImpl.mapReduceTriplets(this, mapFunc, reduceFunc) - override def outerJoinVertices[U: ClassManifest, VD2: ClassManifest] (updates: RDD[(Vid, U)])(updateF: (Vid, VD, Option[U]) => VD2) : Graph[VD2, ED] = { @@ -309,15 +291,9 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( val newVTable = vTable.leftJoin(updates)(updateF) new GraphImpl(newVTable, vid2pid, localVidMap, eTable) } - - } // end of class GraphImpl - - - - object GraphImpl { def apply[VD: ClassManifest, ED: ClassManifest]( @@ -327,7 +303,6 @@ object GraphImpl { apply(vertices, edges, defaultVertexAttr, (a:VD, b:VD) => a) } - def apply[VD: ClassManifest, ED: ClassManifest]( vertices: RDD[(Vid, VD)], edges: RDD[Edge[ED]], @@ -353,7 +328,6 @@ object GraphImpl { new GraphImpl(vtable, vid2pid, localVidMap, etable) } - /** * Create the edge table RDD, which is much more efficient for Java heap storage than the * normal edges data 
structure (RDD[(Vid, Vid, ED)]). @@ -389,7 +363,6 @@ object GraphImpl { }, preservesPartitioning = true).cache() } - protected def createVid2Pid[ED: ClassManifest]( eTable: RDD[(Pid, EdgePartition[ED])], vTableIndex: VertexSetIndex): VertexSetRDD[Array[Pid]] = { @@ -406,7 +379,6 @@ object GraphImpl { .mapValues(a => a.toArray).cache() } - protected def createLocalVidMap[ED: ClassManifest](eTable: RDD[(Pid, EdgePartition[ED])]): RDD[(Pid, VertexIdToIndexMap)] = { eTable.mapPartitions( _.map{ case (pid, epart) => @@ -419,7 +391,6 @@ object GraphImpl { }, preservesPartitioning = true).cache() } - protected def createVTableReplicated[VD: ClassManifest]( vTable: VertexSetRDD[VD], vid2pid: VertexSetRDD[Array[Pid]], @@ -428,9 +399,9 @@ object GraphImpl { // Join vid2pid and vTable, generate a shuffle dependency on the joined // result, and get the shuffle id so we can use it on the slave. val msgsByPartition = vTable.zipJoinFlatMap(vid2pid) { (vid, vdata, pids) => - pids.iterator.map { pid => MessageToPartition(pid, (vid, vdata)) } + pids.iterator.map { pid => new VertexMessage[VD](pid, vid, vdata) } }.partitionBy(replicationMap.partitioner.get).cache() - + replicationMap.zipPartitions(msgsByPartition){ (mapIter, msgsIter) => val (pid, vidToIndex) = mapIter.next() @@ -438,8 +409,8 @@ object GraphImpl { // Populate the vertex array using the vidToIndex map val vertexArray = new Array[VD](vidToIndex.capacity) for (msg <- msgsIter) { - val ind = vidToIndex.getPos(msg.data._1) & OpenHashSet.POSITION_MASK - vertexArray(ind) = msg.data._2 + val ind = vidToIndex.getPos(msg.vid) & OpenHashSet.POSITION_MASK + vertexArray(ind) = msg.data } Iterator((pid, vertexArray)) }.cache() @@ -447,7 +418,6 @@ object GraphImpl { // @todo assert edge table has partitioner } - def makeTriplets[VD: ClassManifest, ED: ClassManifest]( localVidMap: RDD[(Pid, VertexIdToIndexMap)], vTableReplicatedValues: RDD[(Pid, Array[VD]) ], @@ -461,7 +431,6 @@ object GraphImpl { } } - def mapTriplets[VD: ClassManifest, ED: ClassManifest, ED2: ClassManifest]( g: GraphImpl[VD, ED], f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = { @@ -483,7 +452,6 @@ object GraphImpl { new GraphImpl(g.vTable, g.vid2pid, g.localVidMap, newETable) } - def mapReduceTriplets[VD: ClassManifest, ED: ClassManifest, A: ClassManifest]( g: GraphImpl[VD, ED], mapFunc: EdgeTriplet[VD, ED] => Array[(Vid, A)], @@ -495,33 +463,34 @@ object GraphImpl { // Map and preaggregate val preAgg = g.eTable.zipPartitions(g.localVidMap, g.vTableReplicatedValues){ (edgePartitionIter, vidToIndexIter, vertexArrayIter) => - val (pid, edgePartition) = edgePartitionIter.next() + val (_, edgePartition) = edgePartitionIter.next() val (_, vidToIndex) = vidToIndexIter.next() val (_, vertexArray) = vertexArrayIter.next() assert(!edgePartitionIter.hasNext) assert(!vidToIndexIter.hasNext) assert(!vertexArrayIter.hasNext) assert(vidToIndex.capacity == vertexArray.size) + // Reuse the vidToIndex map to run aggregation. val vmap = new PrimitiveKeyOpenHashMap[Vid, VD](vidToIndex, vertexArray) - // We can reuse the vidToIndex map for aggregation here as well. - /** @todo Since this has the downside of not allowing "messages" to arbitrary - * vertices we should consider just using a fresh map. - */ + // TODO(jegonzal): This doesn't allow users to send messages to arbitrary vertices. 
val msgArray = new Array[A](vertexArray.size) val msgBS = new BitSet(vertexArray.size) // Iterate over the partition val et = new EdgeTriplet[VD, ED] - edgePartition.foreach{e => + edgePartition.foreach { e => et.set(e) et.srcAttr = vmap(e.srcId) et.dstAttr = vmap(e.dstId) + // TODO(rxin): rewrite the foreach using a simple while loop to speed things up. + // Also given we are only allowing zero, one, or two messages, we can completely unroll + // the for loop. mapFunc(et).foreach{ case (vid, msg) => // verify that the vid is valid assert(vid == et.srcId || vid == et.dstId) // Get the index of the key val ind = vidToIndex.getPos(vid) & OpenHashSet.POSITION_MASK // Populate the aggregator map - if(msgBS.get(ind)) { + if (msgBS.get(ind)) { msgArray(ind) = reduceFunc(msgArray(ind), msg) } else { msgArray(ind) = msg @@ -536,14 +505,11 @@ object GraphImpl { VertexSetRDD(preAgg, g.vTable.index, reduceFunc) } - protected def edgePartitionFunction1D(src: Vid, dst: Vid, numParts: Pid): Pid = { val mixingPrime: Vid = 1125899906842597L (math.abs(src) * mixingPrime).toInt % numParts } - - /** * This function implements a classic 2D-Partitioning of a sparse matrix. * Suppose we have a graph with 11 vertices that we want to partition @@ -596,7 +562,6 @@ object GraphImpl { (col * ceilSqrtNumParts + row) % numParts } - /** * Assign edges to an aribtrary machine corresponding to a * random vertex cut. @@ -605,7 +570,6 @@ object GraphImpl { math.abs((src, dst).hashCode()) % numParts } - /** * @todo This will only partition edges to the upper diagonal * of the 2D processor space. @@ -622,4 +586,3 @@ object GraphImpl { } } // end of object GraphImpl - diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala b/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala index b7bbf257a4..9ac2c59bf8 100644 --- a/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala +++ b/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala @@ -1,10 +1,24 @@ package org.apache.spark.graph.impl import org.apache.spark.Partitioner -import org.apache.spark.graph.Pid +import org.apache.spark.graph.{Pid, Vid} import org.apache.spark.rdd.{ShuffledRDD, RDD} +class VertexMessage[@specialized(Int, Long, Double, Boolean/*, AnyRef*/) T]( + @transient var partition: Pid, + var vid: Vid, + var data: T) + extends Product2[Pid, (Vid, T)] { + + override def _1 = partition + + override def _2 = (vid, data) + + override def canEqual(that: Any): Boolean = that.isInstanceOf[VertexMessage[_]] +} + + /** * A message used to send a specific value to a partition. * @param partition index of the target partition. @@ -30,6 +44,21 @@ object MessageToPartition { } +class VertexMessageRDDFunctions[T: ClassManifest](self: RDD[VertexMessage[T]]) { + def partitionBy(partitioner: Partitioner): RDD[VertexMessage[T]] = { + val rdd = new ShuffledRDD[Pid, (Vid, T), VertexMessage[T]](self, partitioner) + + // Set a custom serializer if the data is of int or double type. 
+ if (classManifest[T] == ClassManifest.Int) { + rdd.setSerializer(classOf[IntVertexMessageSerializer].getName) + } else if (classManifest[T] == ClassManifest.Double) { + rdd.setSerializer(classOf[DoubleVertexMessageSerializer].getName) + } + rdd + } +} + + class MessageToPartitionRDDFunctions[T: ClassManifest](self: RDD[MessageToPartition[T]]) { /** @@ -46,4 +75,8 @@ object MessageToPartitionRDDFunctions { implicit def rdd2PartitionRDDFunctions[T: ClassManifest](rdd: RDD[MessageToPartition[T]]) = { new MessageToPartitionRDDFunctions(rdd) } + + implicit def rdd2vertexMessageRDDFunctions[T: ClassManifest](rdd: RDD[VertexMessage[T]]) = { + new VertexMessageRDDFunctions(rdd) + } } diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala b/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala new file mode 100644 index 0000000000..0092aa7c6b --- /dev/null +++ b/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala @@ -0,0 +1,125 @@ +package org.apache.spark.graph.impl + +import java.io.{InputStream, OutputStream} +import java.nio.ByteBuffer + +import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance, Serializer} + + +/** A special shuffle serializer for VertexMessage[Int]. */ +class IntVertexMessageSerializer extends Serializer { + override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { + + override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { + def writeObject[T](t: T) = { + val msg = t.asInstanceOf[VertexMessage[Int]] + writeLong(msg.vid) + writeInt(msg.data) + this + } + } + + override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { + override def readObject[T](): T = { + new VertexMessage[Int](0, readLong(), readInt()).asInstanceOf[T] + } + } + } +} + + +/** A special shuffle serializer for VertexMessage[Double]. */ +class DoubleVertexMessageSerializer extends Serializer { + override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { + + override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { + def writeObject[T](t: T) = { + val msg = t.asInstanceOf[VertexMessage[Double]] + writeLong(msg.vid) + writeDouble(msg.data) + this + } + } + + override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { + def readObject[T](): T = { + new VertexMessage[Double](0, readLong(), readDouble()).asInstanceOf[T] + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Helper classes to shorten the implementation of those special serializers. +//////////////////////////////////////////////////////////////////////////////// + +sealed abstract class ShuffleSerializationStream(s: OutputStream) extends SerializationStream { + // The implementation should override this one. 
+ def writeObject[T](t: T): SerializationStream + + def writeInt(v: Int) { + s.write(v >> 24) + s.write(v >> 16) + s.write(v >> 8) + s.write(v) + } + + def writeLong(v: Long) { + s.write((v >>> 56).toInt) + s.write((v >>> 48).toInt) + s.write((v >>> 40).toInt) + s.write((v >>> 32).toInt) + s.write((v >>> 24).toInt) + s.write((v >>> 16).toInt) + s.write((v >>> 8).toInt) + s.write(v.toInt) + } + + def writeDouble(v: Double) { + writeLong(java.lang.Double.doubleToLongBits(v)) + } + + override def flush(): Unit = s.flush() + + override def close(): Unit = s.close() +} + + +sealed abstract class ShuffleDeserializationStream(s: InputStream) extends DeserializationStream { + // The implementation should override this one. + def readObject[T](): T + + def readInt(): Int = { + (s.read() & 0xFF) << 24 | (s.read() & 0xFF) << 16 | (s.read() & 0xFF) << 8 | (s.read() & 0xFF) + } + + def readLong(): Long = { + (s.read().toLong << 56) | + (s.read() & 0xFF).toLong << 48 | + (s.read() & 0xFF).toLong << 40 | + (s.read() & 0xFF).toLong << 32 | + (s.read() & 0xFF).toLong << 24 | + (s.read() & 0xFF) << 16 | + (s.read() & 0xFF) << 8 | + (s.read() & 0xFF) + } + + def readDouble(): Double = java.lang.Double.longBitsToDouble(readLong()) + + override def close(): Unit = s.close() +} + + +sealed trait ShuffleSerializerInstance extends SerializerInstance { + + override def serialize[T](t: T): ByteBuffer = throw new UnsupportedOperationException + + override def deserialize[T](bytes: ByteBuffer): T = throw new UnsupportedOperationException + + override def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = + throw new UnsupportedOperationException + + // The implementation should override the following two. + override def serializeStream(s: OutputStream): SerializationStream + override def deserializeStream(s: InputStream): DeserializationStream +} From bac7be30cd9d58ee4bbc86fa78ba0cc90a84892e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 7 Nov 2013 19:39:48 -0800 Subject: [PATCH 3/3] Made more specialized messages. 
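
The VertexMessage class is split into VertexBroadcastMsg (keyed by target
partition, used to replicate vertex data to edge partitions) and AggregationMsg
(keyed by destination vertex id, used for the mapReduceTriplets shuffle), each
with dedicated Int and Double shuffle serializers. A rough usage sketch of the
aggregation path (illustrative only; it assumes a live SparkContext named sc,
as in the REPL, and the names introduced in the hunks below):

    import org.apache.spark.HashPartitioner
    import org.apache.spark.graph.impl.AggregationMsg
    import org.apache.spark.graph.impl.MsgRDDFunctions._  // implicit rdd2aggMessageRDDFunctions

    // Aggregation messages are keyed by the destination vertex id, not a partition id.
    val msgs = sc.parallelize(Seq(
      new AggregationMsg[Double](1L, 0.15),
      new AggregationMsg[Double](2L, 0.85)))

    // Because T = Double, partitionBy installs DoubleAggMsgSerializer for this shuffle.
    val shuffled = msgs.partitionBy(new HashPartitioner(4))
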
--- .../apache/spark/graph/impl/GraphImpl.scala | 13 ++-- .../spark/graph/impl/MessageToPartition.scala | 61 +++++++++++++------ .../apache/spark/graph/impl/Serializers.scala | 60 +++++++++++++++--- 3 files changed, 103 insertions(+), 31 deletions(-) diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala index a6c4cc4b66..c38780a265 100644 --- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala +++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala @@ -11,7 +11,7 @@ import org.apache.spark.util.ClosureCleaner import org.apache.spark.graph._ import org.apache.spark.graph.impl.GraphImpl._ -import org.apache.spark.graph.impl.MessageToPartitionRDDFunctions._ +import org.apache.spark.graph.impl.MsgRDDFunctions._ import org.apache.spark.rdd.RDD import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveKeyOpenHashMap} @@ -349,7 +349,7 @@ object GraphImpl { //val part: Pid = canonicalEdgePartitionFunction2D(e.srcId, e.dstId, numPartitions, ceilSqrt) // Should we be using 3-tuple or an optimized class - MessageToPartition(part, (e.srcId, e.dstId, e.attr)) + new MessageToPartition(part, (e.srcId, e.dstId, e.attr)) } .partitionBy(new HashPartitioner(numPartitions)) .mapPartitionsWithIndex( (pid, iter) => { @@ -399,7 +399,10 @@ object GraphImpl { // Join vid2pid and vTable, generate a shuffle dependency on the joined // result, and get the shuffle id so we can use it on the slave. val msgsByPartition = vTable.zipJoinFlatMap(vid2pid) { (vid, vdata, pids) => - pids.iterator.map { pid => new VertexMessage[VD](pid, vid, vdata) } + // TODO(rxin): reuse VertexBroadcastMessage + pids.iterator.map { pid => + new VertexBroadcastMsg[VD](pid, vid, vdata) + } }.partitionBy(replicationMap.partitioner.get).cache() replicationMap.zipPartitions(msgsByPartition){ @@ -500,7 +503,9 @@ object GraphImpl { } } // construct an iterator of tuples Iterator[(Vid, A)] - msgBS.iterator.map( ind => (vidToIndex.getValue(ind), msgArray(ind)) ) + msgBS.iterator.map { ind => + new AggregationMsg[A](vidToIndex.getValue(ind), msgArray(ind)) + } }.partitionBy(g.vTable.index.rdd.partitioner.get) // do the final reduction reusing the index map VertexSetRDD(preAgg, g.vTable.index, reduceFunc) diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala b/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala index 9ac2c59bf8..3fc0b7c0f7 100644 --- a/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala +++ b/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala @@ -5,7 +5,7 @@ import org.apache.spark.graph.{Pid, Vid} import org.apache.spark.rdd.{ShuffledRDD, RDD} -class VertexMessage[@specialized(Int, Long, Double, Boolean/*, AnyRef*/) T]( +class VertexBroadcastMsg[@specialized(Int, Long, Double, Boolean) T]( @transient var partition: Pid, var vid: Vid, var data: T) @@ -15,7 +15,18 @@ class VertexMessage[@specialized(Int, Long, Double, Boolean/*, AnyRef*/) T]( override def _2 = (vid, data) - override def canEqual(that: Any): Boolean = that.isInstanceOf[VertexMessage[_]] + override def canEqual(that: Any): Boolean = that.isInstanceOf[VertexBroadcastMsg[_]] +} + + +class AggregationMsg[@specialized(Int, Long, Double, Boolean) T](var vid: Vid, var data: T) + extends Product2[Vid, T] { + + override def _1 = vid + + override def _2 = data + + override def canEqual(that: Any): Boolean = that.isInstanceOf[AggregationMsg[_]] } @@ 
-36,30 +47,38 @@ class MessageToPartition[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef override def canEqual(that: Any): Boolean = that.isInstanceOf[MessageToPartition[_]] } -/** - * Companion object for MessageToPartition. - */ -object MessageToPartition { - def apply[T](partition: Pid, value: T) = new MessageToPartition(partition, value) -} - -class VertexMessageRDDFunctions[T: ClassManifest](self: RDD[VertexMessage[T]]) { - def partitionBy(partitioner: Partitioner): RDD[VertexMessage[T]] = { - val rdd = new ShuffledRDD[Pid, (Vid, T), VertexMessage[T]](self, partitioner) +class VertexBroadcastMsgRDDFunctions[T: ClassManifest](self: RDD[VertexBroadcastMsg[T]]) { + def partitionBy(partitioner: Partitioner): RDD[VertexBroadcastMsg[T]] = { + val rdd = new ShuffledRDD[Pid, (Vid, T), VertexBroadcastMsg[T]](self, partitioner) // Set a custom serializer if the data is of int or double type. if (classManifest[T] == ClassManifest.Int) { - rdd.setSerializer(classOf[IntVertexMessageSerializer].getName) + rdd.setSerializer(classOf[IntVertexBroadcastMsgSerializer].getName) } else if (classManifest[T] == ClassManifest.Double) { - rdd.setSerializer(classOf[DoubleVertexMessageSerializer].getName) + rdd.setSerializer(classOf[DoubleVertexBroadcastMsgSerializer].getName) } rdd } } -class MessageToPartitionRDDFunctions[T: ClassManifest](self: RDD[MessageToPartition[T]]) { +class AggregationMessageRDDFunctions[T: ClassManifest](self: RDD[AggregationMsg[T]]) { + def partitionBy(partitioner: Partitioner): RDD[AggregationMsg[T]] = { + val rdd = new ShuffledRDD[Vid, T, AggregationMsg[T]](self, partitioner) + + // Set a custom serializer if the data is of int or double type. + if (classManifest[T] == ClassManifest.Int) { + rdd.setSerializer(classOf[IntAggMsgSerializer].getName) + } else if (classManifest[T] == ClassManifest.Double) { + rdd.setSerializer(classOf[DoubleAggMsgSerializer].getName) + } + rdd + } +} + + +class MsgRDDFunctions[T: ClassManifest](self: RDD[MessageToPartition[T]]) { /** * Return a copy of the RDD partitioned using the specified partitioner. @@ -71,12 +90,16 @@ class MessageToPartitionRDDFunctions[T: ClassManifest](self: RDD[MessageToPartit } -object MessageToPartitionRDDFunctions { +object MsgRDDFunctions { implicit def rdd2PartitionRDDFunctions[T: ClassManifest](rdd: RDD[MessageToPartition[T]]) = { - new MessageToPartitionRDDFunctions(rdd) + new MsgRDDFunctions(rdd) } - implicit def rdd2vertexMessageRDDFunctions[T: ClassManifest](rdd: RDD[VertexMessage[T]]) = { - new VertexMessageRDDFunctions(rdd) + implicit def rdd2vertexMessageRDDFunctions[T: ClassManifest](rdd: RDD[VertexBroadcastMsg[T]]) = { + new VertexBroadcastMsgRDDFunctions(rdd) + } + + implicit def rdd2aggMessageRDDFunctions[T: ClassManifest](rdd: RDD[AggregationMsg[T]]) = { + new AggregationMessageRDDFunctions(rdd) } } diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala b/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala index 0092aa7c6b..8b4c0868b1 100644 --- a/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala +++ b/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala @@ -6,13 +6,13 @@ import java.nio.ByteBuffer import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance, Serializer} -/** A special shuffle serializer for VertexMessage[Int]. */ -class IntVertexMessageSerializer extends Serializer { +/** A special shuffle serializer for VertexBroadcastMessage[Int]. 
*/ +class IntVertexBroadcastMsgSerializer extends Serializer { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { def writeObject[T](t: T) = { - val msg = t.asInstanceOf[VertexMessage[Int]] + val msg = t.asInstanceOf[VertexBroadcastMsg[Int]] writeLong(msg.vid) writeInt(msg.data) this @@ -21,20 +21,20 @@ class IntVertexMessageSerializer extends Serializer { override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { override def readObject[T](): T = { - new VertexMessage[Int](0, readLong(), readInt()).asInstanceOf[T] + new VertexBroadcastMsg[Int](0, readLong(), readInt()).asInstanceOf[T] } } } } -/** A special shuffle serializer for VertexMessage[Double]. */ -class DoubleVertexMessageSerializer extends Serializer { +/** A special shuffle serializer for VertexBroadcastMessage[Double]. */ +class DoubleVertexBroadcastMsgSerializer extends Serializer { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { def writeObject[T](t: T) = { - val msg = t.asInstanceOf[VertexMessage[Double]] + val msg = t.asInstanceOf[VertexBroadcastMsg[Double]] writeLong(msg.vid) writeDouble(msg.data) this @@ -43,7 +43,51 @@ class DoubleVertexMessageSerializer extends Serializer { override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { def readObject[T](): T = { - new VertexMessage[Double](0, readLong(), readDouble()).asInstanceOf[T] + new VertexBroadcastMsg[Double](0, readLong(), readDouble()).asInstanceOf[T] + } + } + } +} + + +/** A special shuffle serializer for AggregationMessage[Int]. */ +class IntAggMsgSerializer extends Serializer { + override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { + + override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { + def writeObject[T](t: T) = { + val msg = t.asInstanceOf[AggregationMsg[Int]] + writeLong(msg.vid) + writeInt(msg.data) + this + } + } + + override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { + override def readObject[T](): T = { + new AggregationMsg[Int](readLong(), readInt()).asInstanceOf[T] + } + } + } +} + + +/** A special shuffle serializer for AggregationMessage[Double]. */ +class DoubleAggMsgSerializer extends Serializer { + override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { + + override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { + def writeObject[T](t: T) = { + val msg = t.asInstanceOf[AggregationMsg[Double]] + writeLong(msg.vid) + writeDouble(msg.data) + this + } + } + + override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { + def readObject[T](): T = { + new AggregationMsg[Double](readLong(), readDouble()).asInstanceOf[T] } } }
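
For reference, a minimal round trip through the Double aggregation-message
serializer added above (an illustrative sketch only; it assumes the
org.apache.spark.graph.impl classes from this patch are on the classpath,
e.g. from within the graphx-shell):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
    import org.apache.spark.graph.impl.{AggregationMsg, DoubleAggMsgSerializer}

    val bytes = new ByteArrayOutputStream()
    val ser = new DoubleAggMsgSerializer().newInstance()

    // Each message is framed as an 8-byte big-endian vertex id followed by the
    // 8-byte IEEE-754 bits of the value: 16 bytes, with no object headers.
    val out = ser.serializeStream(bytes)
    out.writeObject(new AggregationMsg[Double](42L, 0.15))
    out.flush()

    // Read it back with the matching deserialization stream.
    val in = ser.deserializeStream(new ByteArrayInputStream(bytes.toByteArray))
    val msg = in.readObject[AggregationMsg[Double]]()
    assert(msg.vid == 42L && msg.data == 0.15)

The fixed framing is what lets these serializers bypass generic Java/Kryo
serialization entirely on the int and double shuffle paths.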