From 32508e20d468fcb72fb89e6ae23c9fdd6475f0c8 Mon Sep 17 00:00:00 2001
From: Ankur Dave
Date: Fri, 20 Dec 2013 12:59:07 -0800
Subject: [PATCH] Test VertexPartition and fix bugs

---
Notes (ignored by `git am`; not part of the commit message):
  * Each of the three VertexPartition.scala hunks applies the same fix: a
    (vid, value) pair is written into newMask/newValues only when
    index.getPos(vid) >= 0, so pairs whose vid is absent from the index are
    skipped instead of being used as a bit/array position.
  * The createUsingIndex, innerJoinKeepLeft and aggregateUsingIndex tests in
    the new suite feed vid 3L into a partition built from vids 0L, 1L, 2L and
    expect that entry to be dropped.

 .../spark/graph/impl/VertexPartition.scala    |  24 ++--
 .../graph/impl/VertexPartitionSuite.scala     | 113 ++++++++++++++++++
 2 files changed, 128 insertions(+), 9 deletions(-)
 create mode 100644 graph/src/test/scala/org/apache/spark/graph/impl/VertexPartitionSuite.scala

diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala b/graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala
index 7710d6eada..9b2d66999c 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala
@@ -188,8 +188,10 @@ class VertexPartition[@specialized(Long, Int, Double) VD: ClassManifest](
     val newValues = new Array[VD2](capacity)
     iter.foreach { case (vid, vdata) =>
       val pos = index.getPos(vid)
-      newMask.set(pos)
-      newValues(pos) = vdata
+      if (pos >= 0) {
+        newMask.set(pos)
+        newValues(pos) = vdata
+      }
     }
     new VertexPartition[VD2](index, newValues, newMask)
   }
@@ -204,8 +206,10 @@ class VertexPartition[@specialized(Long, Int, Double) VD: ClassManifest](
     System.arraycopy(values, 0, newValues, 0, newValues.length)
     iter.foreach { case (vid, vdata) =>
       val pos = index.getPos(vid)
-      newMask.set(pos)
-      newValues(pos) = vdata
+      if (pos >= 0) {
+        newMask.set(pos)
+        newValues(pos) = vdata
+      }
     }
     new VertexPartition(index, newValues, newMask)
   }
@@ -219,11 +223,13 @@ class VertexPartition[@specialized(Long, Int, Double) VD: ClassManifest](
       val vid = product._1
       val vdata = product._2
       val pos = index.getPos(vid)
-      if (newMask.get(pos)) {
-        newValues(pos) = reduceFunc(newValues(pos), vdata)
-      } else { // otherwise just store the new value
-        newMask.set(pos)
-        newValues(pos) = vdata
+      if (pos >= 0) {
+        if (newMask.get(pos)) {
+          newValues(pos) = reduceFunc(newValues(pos), vdata)
+        } else { // otherwise just store the new value
+          newMask.set(pos)
+          newValues(pos) = vdata
+        }
       }
     }
     new VertexPartition[VD2](index, newValues, newMask)
diff --git a/graph/src/test/scala/org/apache/spark/graph/impl/VertexPartitionSuite.scala b/graph/src/test/scala/org/apache/spark/graph/impl/VertexPartitionSuite.scala
new file mode 100644
index 0000000000..72579a48c2
--- /dev/null
+++ b/graph/src/test/scala/org/apache/spark/graph/impl/VertexPartitionSuite.scala
@@ -0,0 +1,113 @@
+package org.apache.spark.graph.impl
+
+import org.apache.spark.graph._
+import org.scalatest.FunSuite
+
+class VertexPartitionSuite extends FunSuite {
+
+  test("isDefined, filter") {
+    val vp = VertexPartition(Iterator((0L, 1), (1L, 1))).filter { (vid, attr) => vid == 0 }
+    assert(vp.isDefined(0))
+    assert(!vp.isDefined(1))
+    assert(!vp.isDefined(2))
+    assert(!vp.isDefined(-1))
+  }
+
+  test("isActive, numActives, replaceActives") {
+    val vp = VertexPartition(Iterator((0L, 1), (1L, 1)))
+      .filter { (vid, attr) => vid == 0 }
+      .replaceActives(Iterator(0, 2, 0))
+    assert(vp.isActive(0))
+    assert(!vp.isActive(1))
+    assert(vp.isActive(2))
+    assert(!vp.isActive(-1))
+    assert(vp.numActives == Some(2))
+  }
+
+  test("map") {
+    val vp = VertexPartition(Iterator((0L, 1), (1L, 1))).map { (vid, attr) => 2 }
+    assert(vp(0) === 2)
+  }
+
+  test("diff") {
+    val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+    val vp2 = vp.filter { (vid, attr) => vid <= 1 }
+    val vp3a = vp.map { (vid, attr) => 2 }
+    val vp3b = VertexPartition(vp3a.iterator)
+    // diff with same index
+    val diff1 = vp2.diff(vp3a)
+    assert(diff1(0) === 2)
+    assert(diff1(1) === 2)
+    assert(diff1(2) === 2)
+    assert(!diff1.isDefined(2))
+    // diff with different indexes
+    val diff2 = vp2.diff(vp3b)
+    assert(diff2(0) === 2)
+    assert(diff2(1) === 2)
+    assert(diff2(2) === 2)
+    assert(!diff2.isDefined(2))
+  }
+
+  test("leftJoin") {
+    val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+    val vp2a = vp.filter { (vid, attr) => vid <= 1 }.map { (vid, attr) => 2 }
+    val vp2b = VertexPartition(vp2a.iterator)
+    // leftJoin with same index
+    val join1 = vp.leftJoin(vp2a) { (vid, a, bOpt) => bOpt.getOrElse(a) }
+    assert(join1.iterator.toSet === Set((0L, 2), (1L, 2), (2L, 1)))
+    // leftJoin with different indexes
+    val join2 = vp.leftJoin(vp2b) { (vid, a, bOpt) => bOpt.getOrElse(a) }
+    assert(join2.iterator.toSet === Set((0L, 2), (1L, 2), (2L, 1)))
+    // leftJoin an iterator
+    val join3 = vp.leftJoin(vp2a.iterator) { (vid, a, bOpt) => bOpt.getOrElse(a) }
+    assert(join3.iterator.toSet === Set((0L, 2), (1L, 2), (2L, 1)))
+  }
+
+  test("innerJoin") {
+    val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+    val vp2a = vp.filter { (vid, attr) => vid <= 1 }.map { (vid, attr) => 2 }
+    val vp2b = VertexPartition(vp2a.iterator)
+    // innerJoin with same index
+    val join1 = vp.innerJoin(vp2a) { (vid, a, b) => b }
+    assert(join1.iterator.toSet === Set((0L, 2), (1L, 2)))
+    // innerJoin with different indexes
+    val join2 = vp.innerJoin(vp2b) { (vid, a, b) => b }
+    assert(join2.iterator.toSet === Set((0L, 2), (1L, 2)))
+    // innerJoin an iterator
+    val join3 = vp.innerJoin(vp2a.iterator) { (vid, a, b) => b }
+    assert(join3.iterator.toSet === Set((0L, 2), (1L, 2)))
+  }
+
+  test("createUsingIndex") {
+    val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+    val elems = List((0L, 2), (2L, 2), (3L, 2))
+    val vp2 = vp.createUsingIndex(elems.iterator)
+    assert(vp2.iterator.toSet === Set((0L, 2), (2L, 2)))
+    assert(vp.index === vp2.index)
+  }
+
+  test("innerJoinKeepLeft") {
+    val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+    val elems = List((0L, 2), (2L, 2), (3L, 2))
+    val vp2 = vp.innerJoinKeepLeft(elems.iterator)
+    assert(vp2.iterator.toSet === Set((0L, 2), (2L, 2)))
+    assert(vp2(1) === 1)
+  }
+
+  test("aggregateUsingIndex") {
+    val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+    val messages = List((0L, "a"), (2L, "b"), (0L, "c"), (3L, "d"))
+    val vp2 = vp.aggregateUsingIndex[String](messages.iterator, _ + _)
+    assert(vp2.iterator.toSet === Set((0L, "ac"), (2L, "b")))
+  }
+
+  test("reindex") {
+    val vp = VertexPartition(Iterator((0L, 1), (1L, 1), (2L, 1)))
+    val vp2 = vp.filter { (vid, attr) => vid <= 1 }
+    val vp3 = vp2.reindex()
+    assert(vp2.iterator.toSet === vp3.iterator.toSet)
+    assert(vp2(2) === 1)
+    assert(vp3.index.getPos(2) === -1)
+  }
+
+}