merged and debugged

This commit is contained in:
Joseph E. Gonzalez 2013-11-07 20:19:49 -08:00
commit e523f0d2fb
6 changed files with 514 additions and 150 deletions

View file

@ -22,7 +22,7 @@ import org.apache.spark.SparkContext._
import org.apache.spark.rdd._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveKeyOpenHashMap}
import org.apache.spark.graph.impl.AggregationMsg
/**
* The `VertexSetIndex` maintains the per-partition mapping from
@ -659,6 +659,43 @@ object VertexSetRDD {
apply(rdd,index, (v:V) => v, reduceFunc, reduceFunc)
def aggregate[V: ClassManifest](
rdd: RDD[AggregationMsg[V]], index: VertexSetIndex,
reduceFunc: (V, V) => V): VertexSetRDD[V] = {
val cReduceFunc = index.rdd.context.clean(reduceFunc)
assert(rdd.partitioner == index.rdd.partitioner)
// Use the index to build the new values table
val values: RDD[ (Array[V], BitSet) ] = index.rdd.zipPartitions(rdd)( (indexIter, tblIter) => {
// There is only one map
val index = indexIter.next()
assert(!indexIter.hasNext)
val values = new Array[V](index.capacity)
val bs = new BitSet(index.capacity)
for (msg <- tblIter) {
// Get the location of the key in the index
val pos = index.getPos(msg.vid)
if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
throw new SparkException("Error: Trying to bind an external index " +
"to an RDD which contains keys that are not in the index.")
} else {
// Get the actual index
val ind = pos & OpenHashSet.POSITION_MASK
// If this value has already been seen then merge
if (bs.get(ind)) {
values(ind) = cReduceFunc(values(ind), msg.data)
} else { // otherwise just store the new value
bs.set(ind)
values(ind) = msg.data
}
}
}
Iterator((values, bs))
})
new VertexSetRDD(index, values)
}
/**
* Construct a vertex set from an RDD using an existing index and a
* user defined `combiner` to merge duplicate vertices.

View file

@ -5,14 +5,13 @@ import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkContext._
import org.apache.spark.HashPartitioner
import org.apache.spark.util.ClosureCleaner
import org.apache.spark.graph._
import org.apache.spark.graph.impl.GraphImpl._
import org.apache.spark.graph.impl.MessageToPartitionRDDFunctions._
import org.apache.spark.graph.impl.MsgRDDFunctions._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveKeyOpenHashMap}
@ -73,8 +72,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
def this() = this(null, null, null, null)
/**
* (localVidMap: VertexSetRDD[Pid, VertexIdToIndexMap]) is a version of the
* vertex data after it is replicated. Within each partition, it holds a map
@ -87,22 +84,18 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
@transient val vTableReplicatedValues: RDD[(Pid, Array[VD])] =
createVTableReplicated(vTable, vid2pid, localVidMap)
/** Return a RDD of vertices. */
@transient override val vertices = vTable
/** Return a RDD of edges. */
@transient override val edges: RDD[Edge[ED]] = {
eTable.mapPartitions( iter => iter.next()._2.iterator , true )
}
/** Return a RDD that brings edges with its source and destination vertices together. */
@transient override val triplets: RDD[EdgeTriplet[VD, ED]] =
makeTriplets(localVidMap, vTableReplicatedValues, eTable)
override def persist(newLevel: StorageLevel): Graph[VD, ED] = {
eTable.persist(newLevel)
vid2pid.persist(newLevel)
@ -129,7 +122,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
"Min Load" -> minLoad, "Max Load" -> maxLoad)
}
/**
* Display the lineage information for this graph.
*/
@ -187,14 +179,12 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
println(visited)
} // end of print lineage
override def reverse: Graph[VD, ED] = {
val newEtable = eTable.mapPartitions( _.map{ case (pid, epart) => (pid, epart.reverse) },
preservesPartitioning = true)
new GraphImpl(vTable, vid2pid, localVidMap, newEtable)
}
override def mapVertices[VD2: ClassManifest](f: (Vid, VD) => VD2): Graph[VD2, ED] = {
val newVTable = vTable.mapValuesWithKeys((vid, data) => f(vid, data))
new GraphImpl(newVTable, vid2pid, localVidMap, eTable)
@ -206,11 +196,9 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
new GraphImpl(vTable, vid2pid, localVidMap, newETable)
}
override def mapTriplets[ED2: ClassManifest](f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] =
GraphImpl.mapTriplets(this, f)
override def subgraph(epred: EdgeTriplet[VD,ED] => Boolean = (x => true),
vpred: (Vid, VD) => Boolean = ((a,b) => true) ): Graph[VD, ED] = {
@ -250,7 +238,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
new GraphImpl(newVTable, newVid2Pid, localVidMap, newETable)
}
override def groupEdgeTriplets[ED2: ClassManifest](
f: Iterator[EdgeTriplet[VD,ED]] => ED2 ): Graph[VD,ED2] = {
val newEdges: RDD[Edge[ED2]] = triplets.mapPartitions { partIter =>
@ -275,7 +262,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
new GraphImpl(vTable, vid2pid, localVidMap, newETable)
}
override def groupEdges[ED2: ClassManifest](f: Iterator[Edge[ED]] => ED2 ):
Graph[VD,ED2] = {
@ -293,8 +279,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
new GraphImpl(vTable, vid2pid, localVidMap, newETable)
}
//////////////////////////////////////////////////////////////////////////////////////////////////
// Lower level transformation methods
//////////////////////////////////////////////////////////////////////////////////////////////////
@ -305,7 +289,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
: VertexSetRDD[A] =
GraphImpl.mapReduceTriplets(this, mapFunc, reduceFunc)
override def outerJoinVertices[U: ClassManifest, VD2: ClassManifest]
(updates: RDD[(Vid, U)])(updateF: (Vid, VD, Option[U]) => VD2)
: Graph[VD2, ED] = {
@ -313,15 +296,9 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
val newVTable = vTable.leftJoin(updates)(updateF)
new GraphImpl(newVTable, vid2pid, localVidMap, eTable)
}
} // end of class GraphImpl
object GraphImpl {
def apply[VD: ClassManifest, ED: ClassManifest](
@ -331,7 +308,6 @@ object GraphImpl {
apply(vertices, edges, defaultVertexAttr, (a:VD, b:VD) => a)
}
def apply[VD: ClassManifest, ED: ClassManifest](
vertices: RDD[(Vid, VD)],
edges: RDD[Edge[ED]],
@ -357,7 +333,6 @@ object GraphImpl {
new GraphImpl(vtable, vid2pid, localVidMap, etable)
}
/**
* Create the edge table RDD, which is much more efficient for Java heap storage than the
* normal edges data structure (RDD[(Vid, Vid, ED)]).
@ -379,7 +354,7 @@ object GraphImpl {
//val part: Pid = canonicalEdgePartitionFunction2D(e.srcId, e.dstId, numPartitions, ceilSqrt)
// Should we be using 3-tuple or an optimized class
MessageToPartition(part, (e.srcId, e.dstId, e.attr))
new MessageToPartition(part, (e.srcId, e.dstId, e.attr))
}
.partitionBy(new HashPartitioner(numPartitions))
.mapPartitionsWithIndex( (pid, iter) => {
@ -393,7 +368,6 @@ object GraphImpl {
}, preservesPartitioning = true).cache()
}
protected def createVid2Pid[ED: ClassManifest](
eTable: RDD[(Pid, EdgePartition[ED])],
vTableIndex: VertexSetIndex): VertexSetRDD[Array[Pid]] = {
@ -410,7 +384,6 @@ object GraphImpl {
.mapValues(a => a.toArray).cache()
}
protected def createLocalVidMap[ED: ClassManifest](eTable: RDD[(Pid, EdgePartition[ED])]):
RDD[(Pid, VertexIdToIndexMap)] = {
eTable.mapPartitions( _.map{ case (pid, epart) =>
@ -423,7 +396,6 @@ object GraphImpl {
}, preservesPartitioning = true).cache()
}
protected def createVTableReplicated[VD: ClassManifest](
vTable: VertexSetRDD[VD],
vid2pid: VertexSetRDD[Array[Pid]],
@ -432,7 +404,10 @@ object GraphImpl {
// Join vid2pid and vTable, generate a shuffle dependency on the joined
// result, and get the shuffle id so we can use it on the slave.
val msgsByPartition = vTable.zipJoinFlatMap(vid2pid) { (vid, vdata, pids) =>
pids.iterator.map { pid => MessageToPartition(pid, (vid, vdata)) }
// TODO(rxin): reuse VertexBroadcastMessage
pids.iterator.map { pid =>
new VertexBroadcastMsg[VD](pid, vid, vdata)
}
}.partitionBy(replicationMap.partitioner.get).cache()
replicationMap.zipPartitions(msgsByPartition){
@ -442,8 +417,8 @@ object GraphImpl {
// Populate the vertex array using the vidToIndex map
val vertexArray = new Array[VD](vidToIndex.capacity)
for (msg <- msgsIter) {
val ind = vidToIndex.getPos(msg.data._1) & OpenHashSet.POSITION_MASK
vertexArray(ind) = msg.data._2
val ind = vidToIndex.getPos(msg.vid) & OpenHashSet.POSITION_MASK
vertexArray(ind) = msg.data
}
Iterator((pid, vertexArray))
}.cache()
@ -451,7 +426,6 @@ object GraphImpl {
// @todo assert edge table has partitioner
}
def makeTriplets[VD: ClassManifest, ED: ClassManifest](
localVidMap: RDD[(Pid, VertexIdToIndexMap)],
vTableReplicatedValues: RDD[(Pid, Array[VD]) ],
@ -465,7 +439,6 @@ object GraphImpl {
}
}
def mapTriplets[VD: ClassManifest, ED: ClassManifest, ED2: ClassManifest](
g: GraphImpl[VD, ED],
f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
@ -487,7 +460,6 @@ object GraphImpl {
new GraphImpl(g.vTable, g.vid2pid, g.localVidMap, newETable)
}
def mapReduceTriplets[VD: ClassManifest, ED: ClassManifest, A: ClassManifest](
g: GraphImpl[VD, ED],
mapFunc: EdgeTriplet[VD, ED] => Array[(Vid, A)],
@ -499,33 +471,35 @@ object GraphImpl {
// Map and preaggregate
val preAgg = g.eTable.zipPartitions(g.localVidMap, g.vTableReplicatedValues){
(edgePartitionIter, vidToIndexIter, vertexArrayIter) =>
val (pid, edgePartition) = edgePartitionIter.next()
val (_, edgePartition) = edgePartitionIter.next()
val (_, vidToIndex) = vidToIndexIter.next()
val (_, vertexArray) = vertexArrayIter.next()
assert(!edgePartitionIter.hasNext)
assert(!vidToIndexIter.hasNext)
assert(!vertexArrayIter.hasNext)
assert(vidToIndex.capacity == vertexArray.size)
// Reuse the vidToIndex map to run aggregation.
val vmap = new PrimitiveKeyOpenHashMap[Vid, VD](vidToIndex, vertexArray)
// We can reuse the vidToIndex map for aggregation here as well.
/** @todo Since this has the downside of not allowing "messages" to arbitrary
* vertices we should consider just using a fresh map.
*/
// TODO(jegonzal): This doesn't allow users to send messages to arbitrary vertices.
val msgArray = new Array[A](vertexArray.size)
val msgBS = new BitSet(vertexArray.size)
// Iterate over the partition
val et = new EdgeTriplet[VD, ED]
edgePartition.foreach{e =>
edgePartition.foreach { e =>
et.set(e)
et.srcAttr = vmap(e.srcId)
et.dstAttr = vmap(e.dstId)
mapFunc(et).foreach{ case (vid, msg) =>
// TODO(rxin): rewrite the foreach using a simple while loop to speed things up.
// Also given we are only allowing zero, one, or two messages, we can completely unroll
// the for loop.
mapFunc(et).foreach { case (vid, msg) =>
// verify that the vid is valid
assert(vid == et.srcId || vid == et.dstId)
// Get the index of the key
val ind = vidToIndex.getPos(vid) & OpenHashSet.POSITION_MASK
// Populate the aggregator map
if(msgBS.get(ind)) {
if (msgBS.get(ind)) {
msgArray(ind) = reduceFunc(msgArray(ind), msg)
} else {
msgArray(ind) = msg
@ -534,20 +508,19 @@ object GraphImpl {
}
}
// construct an iterator of tuples Iterator[(Vid, A)]
msgBS.iterator.map( ind => (vidToIndex.getValue(ind), msgArray(ind)) )
msgBS.iterator.map { ind =>
new AggregationMsg[A](vidToIndex.getValue(ind), msgArray(ind))
}
}.partitionBy(g.vTable.index.rdd.partitioner.get)
// do the final reduction reusing the index map
VertexSetRDD(preAgg, g.vTable.index, reduceFunc)
VertexSetRDD.aggregate(preAgg, g.vTable.index, reduceFunc)
}
protected def edgePartitionFunction1D(src: Vid, dst: Vid, numParts: Pid): Pid = {
val mixingPrime: Vid = 1125899906842597L
(math.abs(src) * mixingPrime).toInt % numParts
}
/**
* This function implements a classic 2D-Partitioning of a sparse matrix.
* Suppose we have a graph with 11 vertices that we want to partition
@ -600,7 +573,6 @@ object GraphImpl {
(col * ceilSqrtNumParts + row) % numParts
}
/**
* Assign edges to an aribtrary machine corresponding to a
* random vertex cut.
@ -609,7 +581,6 @@ object GraphImpl {
math.abs((src, dst).hashCode()) % numParts
}
/**
* @todo This will only partition edges to the upper diagonal
* of the 2D processor space.
@ -626,4 +597,3 @@ object GraphImpl {
}
} // end of object GraphImpl

View file

@ -1,10 +1,35 @@
package org.apache.spark.graph.impl
import org.apache.spark.Partitioner
import org.apache.spark.graph.Pid
import org.apache.spark.graph.{Pid, Vid}
import org.apache.spark.rdd.{ShuffledRDD, RDD}
class VertexBroadcastMsg[@specialized(Int, Long, Double, Boolean) T](
@transient var partition: Pid,
var vid: Vid,
var data: T)
extends Product2[Pid, (Vid, T)] {
override def _1 = partition
override def _2 = (vid, data)
override def canEqual(that: Any): Boolean = that.isInstanceOf[VertexBroadcastMsg[_]]
}
class AggregationMsg[@specialized(Int, Long, Double, Boolean) T](var vid: Vid, var data: T)
extends Product2[Vid, T] {
override def _1 = vid
override def _2 = data
override def canEqual(that: Any): Boolean = that.isInstanceOf[AggregationMsg[_]]
}
/**
* A message used to send a specific value to a partition.
* @param partition index of the target partition.
@ -22,15 +47,38 @@ class MessageToPartition[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef
override def canEqual(that: Any): Boolean = that.isInstanceOf[MessageToPartition[_]]
}
/**
* Companion object for MessageToPartition.
*/
object MessageToPartition {
def apply[T](partition: Pid, value: T) = new MessageToPartition(partition, value)
class VertexBroadcastMsgRDDFunctions[T: ClassManifest](self: RDD[VertexBroadcastMsg[T]]) {
def partitionBy(partitioner: Partitioner): RDD[VertexBroadcastMsg[T]] = {
val rdd = new ShuffledRDD[Pid, (Vid, T), VertexBroadcastMsg[T]](self, partitioner)
// Set a custom serializer if the data is of int or double type.
if (classManifest[T] == ClassManifest.Int) {
rdd.setSerializer(classOf[IntVertexBroadcastMsgSerializer].getName)
} else if (classManifest[T] == ClassManifest.Double) {
rdd.setSerializer(classOf[DoubleVertexBroadcastMsgSerializer].getName)
}
rdd
}
}
class MessageToPartitionRDDFunctions[T: ClassManifest](self: RDD[MessageToPartition[T]]) {
class AggregationMessageRDDFunctions[T: ClassManifest](self: RDD[AggregationMsg[T]]) {
def partitionBy(partitioner: Partitioner): RDD[AggregationMsg[T]] = {
val rdd = new ShuffledRDD[Vid, T, AggregationMsg[T]](self, partitioner)
// Set a custom serializer if the data is of int or double type.
if (classManifest[T] == ClassManifest.Int) {
rdd.setSerializer(classOf[IntAggMsgSerializer].getName)
} else if (classManifest[T] == ClassManifest.Double) {
rdd.setSerializer(classOf[DoubleAggMsgSerializer].getName)
}
rdd
}
}
class MsgRDDFunctions[T: ClassManifest](self: RDD[MessageToPartition[T]]) {
/**
* Return a copy of the RDD partitioned using the specified partitioner.
@ -42,8 +90,16 @@ class MessageToPartitionRDDFunctions[T: ClassManifest](self: RDD[MessageToPartit
}
object MessageToPartitionRDDFunctions {
object MsgRDDFunctions {
implicit def rdd2PartitionRDDFunctions[T: ClassManifest](rdd: RDD[MessageToPartition[T]]) = {
new MessageToPartitionRDDFunctions(rdd)
new MsgRDDFunctions(rdd)
}
implicit def rdd2vertexMessageRDDFunctions[T: ClassManifest](rdd: RDD[VertexBroadcastMsg[T]]) = {
new VertexBroadcastMsgRDDFunctions(rdd)
}
implicit def rdd2aggMessageRDDFunctions[T: ClassManifest](rdd: RDD[AggregationMsg[T]]) = {
new AggregationMessageRDDFunctions(rdd)
}
}

View file

@ -0,0 +1,169 @@
package org.apache.spark.graph.impl
import java.io.{InputStream, OutputStream}
import java.nio.ByteBuffer
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance, Serializer}
/** A special shuffle serializer for VertexBroadcastMessage[Int]. */
class IntVertexBroadcastMsgSerializer extends Serializer {
override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
def writeObject[T](t: T) = {
val msg = t.asInstanceOf[VertexBroadcastMsg[Int]]
writeLong(msg.vid)
writeInt(msg.data)
this
}
}
override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
override def readObject[T](): T = {
new VertexBroadcastMsg[Int](0, readLong(), readInt()).asInstanceOf[T]
}
}
}
}
/** A special shuffle serializer for VertexBroadcastMessage[Double]. */
class DoubleVertexBroadcastMsgSerializer extends Serializer {
override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
def writeObject[T](t: T) = {
val msg = t.asInstanceOf[VertexBroadcastMsg[Double]]
writeLong(msg.vid)
writeDouble(msg.data)
this
}
}
override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
def readObject[T](): T = {
new VertexBroadcastMsg[Double](0, readLong(), readDouble()).asInstanceOf[T]
}
}
}
}
/** A special shuffle serializer for AggregationMessage[Int]. */
class IntAggMsgSerializer extends Serializer {
override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
def writeObject[T](t: T) = {
val msg = t.asInstanceOf[AggregationMsg[Int]]
writeLong(msg.vid)
writeInt(msg.data)
this
}
}
override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
override def readObject[T](): T = {
new AggregationMsg[Int](readLong(), readInt()).asInstanceOf[T]
}
}
}
}
/** A special shuffle serializer for AggregationMessage[Double]. */
class DoubleAggMsgSerializer extends Serializer {
override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
def writeObject[T](t: T) = {
val msg = t.asInstanceOf[AggregationMsg[Double]]
writeLong(msg.vid)
writeDouble(msg.data)
this
}
}
override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
def readObject[T](): T = {
new AggregationMsg[Double](readLong(), readDouble()).asInstanceOf[T]
}
}
}
}
////////////////////////////////////////////////////////////////////////////////
// Helper classes to shorten the implementation of those special serializers.
////////////////////////////////////////////////////////////////////////////////
sealed abstract class ShuffleSerializationStream(s: OutputStream) extends SerializationStream {
// The implementation should override this one.
def writeObject[T](t: T): SerializationStream
def writeInt(v: Int) {
s.write(v >> 24)
s.write(v >> 16)
s.write(v >> 8)
s.write(v)
}
def writeLong(v: Long) {
s.write((v >>> 56).toInt)
s.write((v >>> 48).toInt)
s.write((v >>> 40).toInt)
s.write((v >>> 32).toInt)
s.write((v >>> 24).toInt)
s.write((v >>> 16).toInt)
s.write((v >>> 8).toInt)
s.write(v.toInt)
}
def writeDouble(v: Double) {
writeLong(java.lang.Double.doubleToLongBits(v))
}
override def flush(): Unit = s.flush()
override def close(): Unit = s.close()
}
sealed abstract class ShuffleDeserializationStream(s: InputStream) extends DeserializationStream {
// The implementation should override this one.
def readObject[T](): T
def readInt(): Int = {
(s.read() & 0xFF) << 24 | (s.read() & 0xFF) << 16 | (s.read() & 0xFF) << 8 | (s.read() & 0xFF)
}
def readLong(): Long = {
(s.read().toLong << 56) |
(s.read() & 0xFF).toLong << 48 |
(s.read() & 0xFF).toLong << 40 |
(s.read() & 0xFF).toLong << 32 |
(s.read() & 0xFF).toLong << 24 |
(s.read() & 0xFF) << 16 |
(s.read() & 0xFF) << 8 |
(s.read() & 0xFF)
}
def readDouble(): Double = java.lang.Double.longBitsToDouble(readLong())
override def close(): Unit = s.close()
}
sealed trait ShuffleSerializerInstance extends SerializerInstance {
override def serialize[T](t: T): ByteBuffer = throw new UnsupportedOperationException
override def deserialize[T](bytes: ByteBuffer): T = throw new UnsupportedOperationException
override def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T =
throw new UnsupportedOperationException
// The implementation should override the following two.
override def serializeStream(s: OutputStream): SerializationStream
override def deserializeStream(s: InputStream): DeserializationStream
}

124
graphx-shell Executable file
View file

@ -0,0 +1,124 @@
#!/usr/bin/env bash
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# Shell script for starting the Spark Shell REPL
# Note that it will set MASTER to spark://${SPARK_MASTER_IP}:${SPARK_MASTER_PORT}
# if those two env vars are set in spark-env.sh but MASTER is not.
# Options:
# -c <cores> Set the number of cores for REPL to use
#
# Enter posix mode for bash
set -o posix
# Update the the banner logo
export SPARK_BANNER_TEXT="Welcome to
______ __ _ __
/ ____/________ _____ / /_ | |/ /
/ / __/ ___/ __ \`/ __ \/ __ \| /
/ /_/ / / / /_/ / /_/ / / / / |
\____/_/ \__,_/ .___/_/ /_/_/|_|
/_/ Alpha Release
Powered by:
____ __
/ __/__ ___ _____/ /__
_\ \/ _ \/ _ \`/ __/ '_/
/___/ .__/\_,_/_/ /_/\_\
/_/ version 0.9.0
Example:
scala> val graph = GraphLoader.textFile(sc, \"hdfs://links\")
scala> graph.numVertices
scala> graph.numEdges
scala> val pageRankGraph = Analytics.pagerank(graph, 10) // 10 iterations
scala> val maxPr = pageRankGraph.vertices.map{ case (vid, pr) => pr }.max
scala> println(maxPr)
"
export SPARK_SHELL_INIT_BLOCK="import org.apache.spark.graph._;"
# Set the serializer to use Kryo for graphx objects
SPARK_JAVA_OPTS+=" -Dspark.serializer=org.apache.spark.serializer.KryoSerializer "
SPARK_JAVA_OPTS+="-Dspark.kryo.registrator=org.apache.spark.graph.GraphKryoRegistrator "
SPARK_JAVA_OPTS+="-Dspark.kryoserializer.buffer.mb=10 "
FWDIR="`dirname $0`"
for o in "$@"; do
if [ "$1" = "-c" -o "$1" = "--cores" ]; then
shift
if [ -n "$1" ]; then
OPTIONS="-Dspark.cores.max=$1"
shift
fi
fi
done
# Set MASTER from spark-env if possible
if [ -z "$MASTER" ]; then
if [ -e "$FWDIR/conf/spark-env.sh" ]; then
. "$FWDIR/conf/spark-env.sh"
fi
if [[ "x" != "x$SPARK_MASTER_IP" && "y" != "y$SPARK_MASTER_PORT" ]]; then
MASTER="spark://${SPARK_MASTER_IP}:${SPARK_MASTER_PORT}"
export MASTER
fi
fi
# Copy restore-TTY-on-exit functions from Scala script so spark-shell exits properly even in
# binary distribution of Spark where Scala is not installed
exit_status=127
saved_stty=""
# restore stty settings (echo in particular)
function restoreSttySettings() {
stty $saved_stty
saved_stty=""
}
function onExit() {
if [[ "$saved_stty" != "" ]]; then
restoreSttySettings
fi
exit $exit_status
}
# to reenable echo if we are interrupted before completing.
trap onExit INT
# save terminal settings
saved_stty=$(stty -g 2>/dev/null)
# clear on error so we don't later try to restore them
if [[ ! $? ]]; then
saved_stty=""
fi
$FWDIR/spark-class $OPTIONS org.apache.spark.repl.Main "$@"
# record the exit status lest it be overwritten:
# then reenable echo and propagate the code.
exit_status=$?
onExit

View file

@ -196,13 +196,17 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
/** Print a welcome message */
def printWelcome() {
echo("""Welcome to
val prop = System.getenv("SPARK_BANNER_TEXT")
val bannerText =
if (prop != null) prop else
"""Welcome to
____ __
/ __/__ ___ _____/ /__
_\ \/ _ \/ _ `/ __/ '_/
/___/ .__/\_,_/_/ /_/\_\ version 0.9.0-SNAPSHOT
/_/
""")
"""
echo(bannerText)
import Properties._
val welcomeMsg = "Using Scala %s (%s, Java %s)".format(
versionString, javaVmName, javaVersion)
@ -837,6 +841,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
org.apache.spark.repl.Main.interp.out.flush();
""")
command("import org.apache.spark.SparkContext._")
val prop = System.getenv("SPARK_SHELL_INIT_BLOCK")
if (prop != null) {
command(prop)
}
}
echo("Type in expressions to have them evaluated.")
echo("Type :help for more information.")