From 4815bc2128c7f6d4d21da730b8c72da087233b34 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 1 Apr 2015 18:17:07 -0700 Subject: [PATCH] [SPARK-6660][MLLIB] pythonToJava doesn't recognize object arrays davies Author: Xiangrui Meng Closes #5318 from mengxr/SPARK-6660 and squashes the following commits: 0f66ec2 [Xiangrui Meng] recognize object arrays ad8c42f [Xiangrui Meng] add a test for SPARK-6660 --- .../apache/spark/mllib/api/python/PythonMLLibAPI.scala | 5 ++++- python/pyspark/mllib/tests.py | 8 ++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 662ec5fbed..5995d6df97 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -1113,7 +1113,10 @@ private[spark] object SerDe extends Serializable { iter.flatMap { row => val obj = unpickle.loads(row) if (batched) { - obj.asInstanceOf[JArrayList[_]].asScala + obj match { + case list: JArrayList[_] => list.asScala + case arr: Array[_] => arr + } } else { Seq(obj) } diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 893fc6f491..6e9c68ec8a 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -36,6 +36,7 @@ if sys.version_info[:2] <= (2, 6): else: import unittest +from pyspark.mllib.common import _to_java_object_rdd from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ DenseMatrix, Vectors, Matrices from pyspark.mllib.regression import LabeledPoint @@ -641,6 +642,13 @@ class FeatureTest(PySparkTestCase): idf = model.idf() self.assertEqual(len(idf), 11) + +class SerDeTest(PySparkTestCase): + def test_to_java_object_rdd(self): # SPARK-6660 + data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L) + self.assertEqual(_to_java_object_rdd(data).count(), 10) + + if __name__ == "__main__": if not _have_scipy: print "NOTE: Skipping SciPy tests as it does not seem to be installed"