Enable the Broadcast examples to work in a cluster setting

Since they rely on println to display results, we need to first collect
those results to the driver to have them actually display locally.
This commit is contained in:
Aaron Davidson 2013-11-18 22:51:35 -08:00
parent e2ebc3a9d8
commit 50fd8d98c0
2 changed files with 14 additions and 11 deletions

View file

@ -32,13 +32,13 @@ object BroadcastTest {
System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast." + bcName + "BroadcastFactory")
System.setProperty("spark.broadcast.blockSize", blockSize)
val sc = new SparkContext(args(0), "Broadcast Test 2",
val sc = new SparkContext(args(0), "Broadcast Test",
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
val slices = if (args.length > 1) args(1).toInt else 2
val num = if (args.length > 2) args(2).toInt else 1000000
var arr1 = new Array[Int](num)
val arr1 = new Array[Int](num)
for (i <- 0 until arr1.length) {
arr1(i) = i
}
@ -48,9 +48,9 @@ object BroadcastTest {
println("===========")
val startTime = System.nanoTime
val barr1 = sc.broadcast(arr1)
sc.parallelize(1 to 10, slices).foreach {
i => println(barr1.value.size)
}
val observedSizes = sc.parallelize(1 to 10, slices).map(_ => barr1.value.size)
// Collect the small RDD so we can print the observed sizes locally.
observedSizes.collect().foreach(i => println(i))
println("Iteration %d took %.0f milliseconds".format(i, (System.nanoTime - startTime) / 1E6))
}

View file

@ -18,35 +18,38 @@
package org.apache.spark.examples
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
object MultiBroadcastTest {
def main(args: Array[String]) {
if (args.length == 0) {
System.err.println("Usage: BroadcastTest <master> [<slices>] [numElem]")
System.err.println("Usage: MultiBroadcastTest <master> [<slices>] [numElem]")
System.exit(1)
}
val sc = new SparkContext(args(0), "Broadcast Test",
val sc = new SparkContext(args(0), "Multi-Broadcast Test",
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
val slices = if (args.length > 1) args(1).toInt else 2
val num = if (args.length > 2) args(2).toInt else 1000000
var arr1 = new Array[Int](num)
val arr1 = new Array[Int](num)
for (i <- 0 until arr1.length) {
arr1(i) = i
}
var arr2 = new Array[Int](num)
val arr2 = new Array[Int](num)
for (i <- 0 until arr2.length) {
arr2(i) = i
}
val barr1 = sc.broadcast(arr1)
val barr2 = sc.broadcast(arr2)
sc.parallelize(1 to 10, slices).foreach {
i => println(barr1.value.size + barr2.value.size)
val observedSizes: RDD[(Int, Int)] = sc.parallelize(1 to 10, slices).map { _ =>
(barr1.value.size, barr2.value.size)
}
// Collect the small RDD so we can print the observed sizes locally.
observedSizes.collect().foreach(i => println(i))
System.exit(0)
}