[SPARK-6660][MLLIB] pythonToJava doesn't recognize object arrays

davies

Author: Xiangrui Meng <meng@databricks.com>

Closes #5318 from mengxr/SPARK-6660 and squashes the following commits:

0f66ec2 [Xiangrui Meng] recognize object arrays
ad8c42f [Xiangrui Meng] add a test for SPARK-6660
This commit is contained in:
Xiangrui Meng 2015-04-01 18:17:07 -07:00
parent 757b2e9175
commit 4815bc2128
2 changed files with 12 additions and 1 deletion

View file

@@ -1113,7 +1113,10 @@ private[spark] object SerDe extends Serializable {
iter.flatMap { row =>
val obj = unpickle.loads(row)
if (batched) {
obj.asInstanceOf[JArrayList[_]].asScala
obj match {
case list: JArrayList[_] => list.asScala
case arr: Array[_] => arr
}
} else {
Seq(obj)
}

View file

@@ -36,6 +36,7 @@ if sys.version_info[:2] <= (2, 6):
else:
import unittest
from pyspark.mllib.common import _to_java_object_rdd
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
DenseMatrix, Vectors, Matrices
from pyspark.mllib.regression import LabeledPoint
@@ -641,6 +642,13 @@ class FeatureTest(PySparkTestCase):
idf = model.idf()
self.assertEqual(len(idf), 11)
class SerDeTest(PySparkTestCase):
def test_to_java_object_rdd(self): # SPARK-6660
data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L)
self.assertEqual(_to_java_object_rdd(data).count(), 10)
if __name__ == "__main__":
if not _have_scipy:
print "NOTE: Skipping SciPy tests as it does not seem to be installed"