From fa28f25619d6712e5f920f498ec03085ea208b4d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 11 Dec 2012 13:59:43 -0800 Subject: [PATCH] Fixed bug in UnionRDD and CoGroupedRDD --- .../main/scala/spark/rdd/CoGroupedRDD.scala | 9 +- core/src/main/scala/spark/rdd/UnionRDD.scala | 12 +- .../test/scala/spark/CheckpointSuite.scala | 104 ------------------ 3 files changed, 10 insertions(+), 115 deletions(-) diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index e4e70b13ba..bc6d16ee8b 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -15,8 +15,11 @@ import spark.Split import java.io.{ObjectOutputStream, IOException} private[spark] sealed trait CoGroupSplitDep extends Serializable -private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], splitIndex: Int, var split: Split = null) - extends CoGroupSplitDep { +private[spark] case class NarrowCoGroupSplitDep( + rdd: RDD[_], + splitIndex: Int, + var split: Split + ) extends CoGroupSplitDep { @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { @@ -75,7 +78,7 @@ CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) case s: ShuffleDependency[_, _] => new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep case _ => - new NarrowCoGroupSplitDep(r, i): CoGroupSplitDep + new NarrowCoGroupSplitDep(r, i, r.splits(i)): CoGroupSplitDep } }.toList) } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index 808729f18d..a84867492b 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -5,14 +5,10 @@ import scala.collection.mutable.ArrayBuffer import spark._ import java.io.{ObjectOutputStream, IOException} -private[spark] class UnionSplit[T: ClassManifest]( - idx: Int, - rdd: RDD[T], - splitIndex: Int, - var split: Split = null) - extends Split - with Serializable { - +private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int) + extends Split { + var split: Split = rdd.splits(splitIndex) + def iterator() = rdd.iterator(split) def preferredLocations() = rdd.preferredLocations(split) override val index: Int = idx diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 7b323e089c..909c55c91c 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -329,114 +329,10 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val bytes = Utils.serialize(obj) Utils.deserialize[T](bytes) } - /* - test("Consistency check for ResultTask") { - // Time -----------------------> - // Core 1: |<- count in thread 1, task 1 ->| |<-- checkpoint, task 1 ---->| |<- count in thread 2, task 2 ->| - // Core 2: |<- count in thread 1, task 2 ->| |<--- checkpoint, task 2 ---------->| |<- count in thread 2, task 1 ->| - // | - // checkpoint completed - sc.stop(); sc = null - System.clearProperty("spark.master.port") - - val dir = File.createTempFile("temp_", "") - dir.delete() - val ctxt = new SparkContext("local[2]", "ResultTask") - ctxt.setCheckpointDir(dir.toString) - - try { - val rdd = ctxt.makeRDD(1 to 2, 2).map(x => { - val state = CheckpointSuite.incrementState() - println("State = " + state) - if (state <= 3) { - // If executing the two tasks for the job comouting rdd.count - // of thread 1, or the first task for the recomputation due - // to checkpointing (saveing to HDFS), then do nothing - } else if (state == 4) { - // If executing the second task for the recomputation due to - // checkpointing. then prolong this task, to allow rdd.count - // of thread 2 to start before checkpoint of this RDD is completed - - Thread.sleep(1000) - println("State = " + state + " wake up") - } else { - // Else executing the tasks from thread 2 - Thread.sleep(1000) - println("State = " + state + " wake up") - } - - (x, 1) - }) - rdd.checkpoint() - val env = SparkEnv.get - - val thread1 = new Thread() { - override def run() { - try { - SparkEnv.set(env) - rdd.count() - } catch { - case e: Exception => CheckpointSuite.failed("Exception in thread 1", e) - } - } - } - thread1.start() - - val thread2 = new Thread() { - override def run() { - try { - SparkEnv.set(env) - CheckpointSuite.waitTillState(3) - println("\n\n\n\n") - rdd.count() - } catch { - case e: Exception => CheckpointSuite.failed("Exception in thread 2", e) - } - } - } - thread2.start() - - thread1.join() - thread2.join() - } finally { - dir.delete() - } - - assert(!CheckpointSuite.failed, CheckpointSuite.failureMessage) - - ctxt.stop() - - } - */ } object CheckpointSuite { - /* - var state = 0 - var failed = false - var failureMessage = "" - - def incrementState(): Int = { - this.synchronized { state += 1; this.notifyAll(); state } - } - - def getState(): Int = { - this.synchronized( state ) - } - - def waitTillState(s: Int) { - while(state < s) { - this.synchronized { this.wait() } - } - } - - def failed(msg: String, ex: Exception) { - failed = true - failureMessage += msg + "\n" + ex + "\n\n" - } - */ - // This is a custom cogroup function that does not use mapValues like // the PairRDDFunctions.cogroup() def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = {