[SPARK-8373] [PYSPARK] Add emptyRDD to pyspark and fix the issue when calling sum on an empty RDD

This PR fixes the sum issue and also adds `emptyRDD` so that it's easy to create a test case.

Author: zsxwing <zsxwing@gmail.com>

Closes #6826 from zsxwing/python-emptyRDD and squashes the following commits:

b36993f [zsxwing] Update the return type to JavaRDD[T]
71df047 [zsxwing] Add emptyRDD to pyspark and fix the issue when calling sum on an empty RDD

(cherry picked from commit 0fc4b96f3e)
Signed-off-by: Andrew Or <andrew@databricks.com>
This commit is contained in:
zsxwing 2015-06-17 13:59:39 -07:00 committed by Andrew Or
parent f0513733d4
commit 5e7973df0e
4 changed files with 20 additions and 1 deletions

View file

@@ -425,6 +425,11 @@ private[spark] object PythonRDD extends Logging {
iter.foreach(write) iter.foreach(write)
} }
/**
 * Build a JavaRDD with zero partitions and zero elements, for use from the
 * Python API (exposed so PySpark's SparkContext.emptyRDD can delegate here).
 */
def emptyRDD[T](sc: JavaSparkContext): JavaRDD[T] = sc.emptyRDD[T]
/** /**
* Create an RDD from a path using [[org.apache.hadoop.mapred.SequenceFileInputFormat]], * Create an RDD from a path using [[org.apache.hadoop.mapred.SequenceFileInputFormat]],
* key and value class. * key and value class.

View file

@@ -324,6 +324,12 @@ class SparkContext(object):
with SparkContext._lock: with SparkContext._lock:
SparkContext._active_spark_context = None SparkContext._active_spark_context = None
def emptyRDD(self):
    """
    Return an RDD containing no partitions and no elements.

    Uses a NoOpSerializer since there is never any data to deserialize.
    """
    java_rdd = self._jsc.emptyRDD()
    return RDD(java_rdd, self, NoOpSerializer())
def range(self, start, end=None, step=1, numSlices=None): def range(self, start, end=None, step=1, numSlices=None):
""" """
Create a new RDD of int containing elements from `start` to `end` Create a new RDD of int containing elements from `start` to `end`

View file

@@ -960,7 +960,7 @@ class RDD(object):
>>> sc.parallelize([1.0, 2.0, 3.0]).sum() >>> sc.parallelize([1.0, 2.0, 3.0]).sum()
6.0 6.0
""" """
return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add) return self.mapPartitions(lambda x: [sum(x)]).fold(0, operator.add)
def count(self): def count(self):
""" """

View file

@@ -458,6 +458,14 @@ class RDDTests(ReusedPySparkTestCase):
self.assertEqual(id + 1, id2) self.assertEqual(id + 1, id2)
self.assertEqual(id2, rdd2.id()) self.assertEqual(id2, rdd2.id())
def test_empty_rdd(self):
    # An RDD created via emptyRDD must report itself as empty.
    empty = self.sc.emptyRDD()
    self.assertTrue(empty.isEmpty())
def test_sum(self):
    # sum() on an empty RDD should yield 0 rather than raising,
    # and the normal non-empty path must still add up correctly.
    self.assertEqual(0, self.sc.emptyRDD().sum())
    self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum())
def test_save_as_textfile_with_unicode(self): def test_save_as_textfile_with_unicode(self):
# Regression test for SPARK-970 # Regression test for SPARK-970
x = u"\u00A1Hola, mundo!" x = u"\u00A1Hola, mundo!"