diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 82527fe663..57bde8d85f 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -78,9 +78,7 @@ private[spark] class PythonRDD[T: ClassTag](
           dataOut.writeInt(command.length)
           dataOut.write(command)
           // Data values
-          for (elem <- parent.iterator(split, context)) {
-            PythonRDD.writeToStream(elem, dataOut)
-          }
+          PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
           dataOut.flush()
           worker.shutdownOutput()
         } catch {
@@ -206,20 +204,43 @@ private[spark] object PythonRDD {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
-  def writeToStream(elem: Any, dataOut: DataOutputStream) {
-    elem match {
-      case bytes: Array[Byte] =>
-        dataOut.writeInt(bytes.length)
-        dataOut.write(bytes)
-      case pair: (Array[Byte], Array[Byte]) =>
-        dataOut.writeInt(pair._1.length)
-        dataOut.write(pair._1)
-        dataOut.writeInt(pair._2.length)
-        dataOut.write(pair._2)
-      case str: String =>
-        dataOut.writeUTF(str)
-      case other =>
-        throw new SparkException("Unexpected element type " + other.getClass)
+  def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
+    // The right way to implement this would be to use TypeTags to get the full
+    // type of T. Since I don't want to introduce breaking changes throughout the
+    // entire Spark API, I have to use this hacky approach:
+    if (iter.hasNext) {
+      val first = iter.next()
+      val newIter = Seq(first).iterator ++ iter
+      first match {
+        case arr: Array[Byte] =>
+          newIter.asInstanceOf[Iterator[Array[Byte]]].foreach { bytes =>
+            dataOut.writeInt(bytes.length)
+            dataOut.write(bytes)
+          }
+        case string: String =>
+          newIter.asInstanceOf[Iterator[String]].foreach { str =>
+            dataOut.writeUTF(str)
+          }
+        case pair: Tuple2[_, _] =>
+          pair._1 match {
+            case bytePair: Array[Byte] =>
+              newIter.asInstanceOf[Iterator[Tuple2[Array[Byte], Array[Byte]]]].foreach { pair =>
+                dataOut.writeInt(pair._1.length)
+                dataOut.write(pair._1)
+                dataOut.writeInt(pair._2.length)
+                dataOut.write(pair._2)
+              }
+            case stringPair: String =>
+              newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair =>
+                dataOut.writeUTF(pair._1)
+                dataOut.writeUTF(pair._2)
+              }
+            case other =>
+              throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass)
+          }
+        case other =>
+          throw new SparkException("Unexpected element type " + first.getClass)
+      }
     }
   }
 
@@ -230,9 +251,7 @@
 
   def writeToFile[T](items: Iterator[T], filename: String) {
     val file = new DataOutputStream(new FileOutputStream(filename))
-    for (item <- items) {
-      writeToStream(item, file)
-    }
+    writeIteratorToStream(items, file)
     file.close()
   }
 
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 05a9f7f0d1..acd1ca5676 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -159,6 +159,15 @@ class TestRDDFunctions(PySparkTestCase):
         cart = rdd1.cartesian(rdd2)
         result = cart.map(lambda (x, y): x + y).collect()
 
+    def test_cartesian_on_textfile(self):
+        # Regression test for
+        path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+        a = self.sc.textFile(path)
+        result = a.cartesian(a).collect()
+        (x, y) = result[0]
+        self.assertEqual("Hello World!", x.strip())
+        self.assertEqual("Hello World!", y.strip())
+
 
 class TestIO(PySparkTestCase):
 
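
Note (not part of the patch): a minimal PySpark sketch of the scenario the new regression test covers. The old writeToStream matched pairs only as (Array[Byte], Array[Byte]); after type erasure that case matches any Tuple2, so the subsequent Array[Byte] access fails at runtime when the elements are Strings, as they are for RDDs produced by textFile(). writeIteratorToStream dispatches on the first element's runtime type instead and writes (String, String) pairs with writeUTF. The file path and app name below are placeholders; the expected contents mirror the test's assertions.

    # Illustrative only -- mirrors test_cartesian_on_textfile above.
    from pyspark import SparkContext

    sc = SparkContext("local", "cartesian-on-textfile")
    lines = sc.textFile("python/test_support/hello.txt")  # one-line file: "Hello World!"
    first_pair = lines.cartesian(lines).collect()[0]      # failed before this patch
    x, y = first_pair
    assert x.strip() == "Hello World!" and y.strip() == "Hello World!"
    sc.stop()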