Fix SPARK-978: ClassCastException in PySpark cartesian.
commit 61569906cc
parent 0035dbbc81
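The root cause: the old writeToStream matched each element against the pattern `(Array[Byte], Array[Byte])`, which the JVM erases to a bare Tuple2 check, so a (String, String) pair coming from a text RDD matched that case and the subsequent cast to a byte array threw the ClassCastException. A minimal standalone sketch of that mechanism (illustrative code, not part of this patch):

    object ErasureDemo {
      def describe(elem: Any): String = elem match {
        // Compiles with an "unchecked" warning: erasure reduces this to
        // `case pair: Tuple2[_, _]`, so it matches ("hello", "world") too.
        case pair: (Array[Byte], Array[Byte]) =>
          s"byte pair of lengths ${pair._1.length}/${pair._2.length}"
        case other =>
          "other: " + other.getClass
      }

      def main(args: Array[String]): Unit = {
        println(describe((Array[Byte](1, 2), Array[Byte](3, 4))))  // fine
        println(describe(("hello", "world")))  // matches the pair case, then
                                               // pair._1 throws ClassCastException
      }
    }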
@@ -78,9 +78,7 @@ private[spark] class PythonRDD[T: ClassTag](
         dataOut.writeInt(command.length)
         dataOut.write(command)
         // Data values
-        for (elem <- parent.iterator(split, context)) {
-          PythonRDD.writeToStream(elem, dataOut)
-        }
+        PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
         dataOut.flush()
         worker.shutdownOutput()
       } catch {
@@ -206,20 +204,43 @@ private[spark] object PythonRDD {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }

-  def writeToStream(elem: Any, dataOut: DataOutputStream) {
-    elem match {
-      case bytes: Array[Byte] =>
-        dataOut.writeInt(bytes.length)
-        dataOut.write(bytes)
-      case pair: (Array[Byte], Array[Byte]) =>
-        dataOut.writeInt(pair._1.length)
-        dataOut.write(pair._1)
-        dataOut.writeInt(pair._2.length)
-        dataOut.write(pair._2)
-      case str: String =>
-        dataOut.writeUTF(str)
-      case other =>
-        throw new SparkException("Unexpected element type " + other.getClass)
-    }
-  }
+  def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
+    // The right way to implement this would be to use TypeTags to get the full
+    // type of T. Since I don't want to introduce breaking changes throughout the
+    // entire Spark API, I have to use this hacky approach:
+    if (iter.hasNext) {
+      val first = iter.next()
+      val newIter = Seq(first).iterator ++ iter
+      first match {
+        case arr: Array[Byte] =>
+          newIter.asInstanceOf[Iterator[Array[Byte]]].foreach { bytes =>
+            dataOut.writeInt(bytes.length)
+            dataOut.write(bytes)
+          }
+        case string: String =>
+          newIter.asInstanceOf[Iterator[String]].foreach { str =>
+            dataOut.writeUTF(str)
+          }
+        case pair: Tuple2[_, _] =>
+          pair._1 match {
+            case bytePair: Array[Byte] =>
+              newIter.asInstanceOf[Iterator[Tuple2[Array[Byte], Array[Byte]]]].foreach { pair =>
+                dataOut.writeInt(pair._1.length)
+                dataOut.write(pair._1)
+                dataOut.writeInt(pair._2.length)
+                dataOut.write(pair._2)
+              }
+            case stringPair: String =>
+              newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair =>
+                dataOut.writeUTF(pair._1)
+                dataOut.writeUTF(pair._2)
+              }
+            case other =>
+              throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass)
+          }
+        case other =>
+          throw new SparkException("Unexpected element type " + first.getClass)
+      }
+    }
+  }
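A note on the approach: writeIteratorToStream inspects only the first element's runtime class and then rebuilds the full iterator with `Seq(first).iterator ++ iter`, so the type dispatch happens once per partition rather than once per element. A minimal standalone sketch of that idiom (names here are illustrative, not from the patch):

    object PeekDemo {
      def dispatchOnFirst[T](iter: Iterator[T])(handle: (T, Iterator[T]) => Unit): Unit = {
        if (iter.hasNext) {
          val first = iter.next()                     // consume one element to inspect it
          val restored = Seq(first).iterator ++ iter  // prepend it so nothing is lost
          handle(first, restored)                     // dispatch once for the whole iterator
        }
      }

      def main(args: Array[String]): Unit = {
        dispatchOnFirst(Iterator("a", "b", "c")) { (first, all) =>
          println("first is a " + first.getClass.getSimpleName)  // String
          println(all.mkString(","))  // a,b,c -- includes the peeked element
        }
      }
    }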
@@ -230,9 +251,7 @@ private[spark] object PythonRDD {

   def writeToFile[T](items: Iterator[T], filename: String) {
     val file = new DataOutputStream(new FileOutputStream(filename))
-    for (item <- items) {
-      writeToStream(item, file)
-    }
+    writeIteratorToStream(items, file)
     file.close()
   }

@@ -159,6 +159,15 @@ class TestRDDFunctions(PySparkTestCase):
         cart = rdd1.cartesian(rdd2)
         result = cart.map(lambda (x, y): x + y).collect()

+    def test_cartesian_on_textfile(self):
+        # Regression test for SPARK-978
+        path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+        a = self.sc.textFile(path)
+        result = a.cartesian(a).collect()
+        (x, y) = result[0]
+        self.assertEqual("Hello World!", x.strip())
+        self.assertEqual("Hello World!", y.strip())
+

 class TestIO(PySparkTestCase):
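For context on what the new test exercises: textFile elements reach the JVM side as Strings, so a.cartesian(a) hands writeIteratorToStream an iterator of (String, String) pairs, which the stringPair branch serializes with writeUTF. A hedged standalone sketch of that data shape (plain Scala collections stand in for RDDs; not part of the patch):

    import java.io.{ByteArrayOutputStream, DataOutputStream}

    object TextCartesianDemo {
      def main(args: Array[String]): Unit = {
        val lines = Seq("Hello World!")
        val pairs = for (x <- lines; y <- lines) yield (x, y)  // stand-in for a.cartesian(a)
        val buf = new ByteArrayOutputStream()
        val out = new DataOutputStream(buf)
        // Same wire format as the stringPair branch: UTF-encode both halves.
        pairs.foreach { case (a, b) => out.writeUTF(a); out.writeUTF(b) }
        println(s"wrote ${buf.size()} bytes for ${pairs.size} pair(s)")
      }
    }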