[SPARK-2334] Fix AttributeError when calling PipelinedRDD.id()
The underlying JavaRDD for a PipelinedRDD is created lazily; it is not built until `_jrdd` is accessed. The id of the JavaRDD is cached as `_id`, which saves an RPC round trip through py4j on subsequent calls. closes #1276 Author: Davies Liu <davies.liu@gmail.com> Closes #2296 from davies/id and squashes the following commits: e197958 [Davies Liu] fix style 9721716 [Davies Liu] fix id of PipelinedRDD
This commit is contained in:
parent
21a1e1bb89
commit
110fb8b24d
|
@ -2075,6 +2075,7 @@ class PipelinedRDD(RDD):
|
|||
self.ctx = prev.ctx
|
||||
self.prev = prev
|
||||
self._jrdd_val = None
|
||||
self._id = None
|
||||
self._jrdd_deserializer = self.ctx.serializer
|
||||
self._bypass_serializer = False
|
||||
self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None
|
||||
|
@ -2105,6 +2106,11 @@ class PipelinedRDD(RDD):
|
|||
self._jrdd_val = python_rdd.asJavaRDD()
|
||||
return self._jrdd_val
|
||||
|
||||
def id(self):
    """Return the unique id of this RDD.

    The backing JavaRDD is created lazily on first access to ``_jrdd``;
    once fetched, the id is cached in ``self._id`` so later calls avoid
    another py4j round trip.
    """
    if self._id is not None:
        return self._id
    self._id = self._jrdd.id()
    return self._id
|
||||
|
||||
def _is_pipelinable(self):
|
||||
return not (self.is_cached or self.is_checkpointed)
|
||||
|
||||
|
|
|
@ -1525,7 +1525,7 @@ class SchemaRDD(RDD):
|
|||
self.sql_ctx = sql_ctx
|
||||
self._sc = sql_ctx._sc
|
||||
self._jschema_rdd = jschema_rdd
|
||||
|
||||
self._id = None
|
||||
self.is_cached = False
|
||||
self.is_checkpointed = False
|
||||
self.ctx = self.sql_ctx._sc
|
||||
|
@ -1543,9 +1543,10 @@ class SchemaRDD(RDD):
|
|||
self._lazy_jrdd = self._jschema_rdd.javaToPython()
|
||||
return self._lazy_jrdd
|
||||
|
||||
@property
|
||||
def _id(self):
|
||||
return self._jrdd.id()
|
||||
def id(self):
    """Return this RDD's unique id, caching it after the first lookup.

    The id comes from the underlying JavaRDD via ``_jrdd``; it is stored
    in ``self._id`` so repeated calls skip the py4j call.
    """
    rdd_id = self._id
    if rdd_id is None:
        rdd_id = self._jrdd.id()
        self._id = rdd_id
    return rdd_id
|
||||
|
||||
def saveAsParquetFile(self, path):
|
||||
"""Save the contents as a Parquet file, preserving the schema.
|
||||
|
|
|
@ -281,6 +281,15 @@ class TestAddFile(PySparkTestCase):
|
|||
|
||||
class TestRDDFunctions(PySparkTestCase):
|
||||
|
||||
def test_id(self):
    """Regression test for SPARK-2334: RDD.id() must work on PipelinedRDD.

    Checks that id() is stable across calls for both a plain RDD and a
    pipelined (map/filter) RDD.
    """
    rdd = self.sc.parallelize(range(10))
    # Renamed from `id` to avoid shadowing the builtin id().
    rdd_id = rdd.id()
    self.assertEqual(rdd_id, rdd.id())
    rdd2 = rdd.map(str).filter(bool)
    rdd2_id = rdd2.id()
    # NOTE(review): assumes the context allocates RDD ids sequentially,
    # so the derived RDD gets the next id.
    self.assertEqual(rdd_id + 1, rdd2_id)
    self.assertEqual(rdd2_id, rdd2.id())
|
||||
|
||||
def test_failed_sparkcontext_creation(self):
|
||||
# Regression test for SPARK-1550
|
||||
self.sc.stop()
|
||||
|
|
Loading…
Reference in a new issue