[SPARK-31788][CORE][PYTHON] Fix UnionRDD of PairRDDs
### What changes were proposed in this pull request? UnionRDD of PairRDDs causing a bug. The fix is to check for instance type before proceeding ### Why are the changes needed? Changes are needed to avoid users running into issues with union rdd operation with any other type other than JavaRDD. ### Does this PR introduce _any_ user-facing change? Yes Before: SparkSession available as 'spark'. >>> rdd1 = sc.parallelize([1,2,3,4,5]) >>> rdd2 = sc.parallelize([6,7,8,9,10]) >>> pairRDD1 = rdd1.zip(rdd2) >>> unionRDD1 = sc.union([pairRDD1, pairRDD1]) Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/home/gs/spark/latest/python/pyspark/context.py", line 870, in union jrdds[i] = rdds[i]._jrdd File "/home/gs/spark/latest/python/lib/py4j-0.10.9-src.zip/py4j/java_collections.py", line 238, in setitem File "/home/gs/spark/latest/python/lib/py4j-0.10.9-src.zip/py4j/java_collections.py", line 221, in __set_item File "/home/gs/spark/latest/python/lib/py4j-0.10.9-src.zip/py4j/protocol.py", line 332, in get_return_value py4j.protocol.Py4JError: An error occurred while calling None.None. Trace: py4j.Py4JException: Cannot convert org.apache.spark.api.java.JavaPairRDD to org.apache.spark.api.java.JavaRDD at py4j.commands.ArrayCommand.convertArgument(ArrayCommand.java:166) at py4j.commands.ArrayCommand.setArray(ArrayCommand.java:144) at py4j.commands.ArrayCommand.execute(ArrayCommand.java:97) at py4j.GatewayConnection.run(GatewayConnection.java:238) at java.lang.Thread.run(Thread.java:748) After: >>> rdd2 = sc.parallelize([6,7,8,9,10]) >>> pairRDD1 = rdd1.zip(rdd2) >>> unionRDD1 = sc.union([pairRDD1, pairRDD1]) >>> unionRDD1.collect() [(1, 6), (2, 7), (3, 8), (4, 9), (5, 10), (1, 6), (2, 7), (3, 8), (4, 9), (5, 10)] ### How was this patch tested? Tested with the reproduced piece of code above manually Closes #28603 from redsanket/SPARK-31788. Authored-by: schintap <schintap@verizonmedia.com> Signed-off-by: HyukjinKwon <gurwls223@apache.org>
This commit is contained in:
parent
753636e86b
commit
a61911c50c
|
@ -25,6 +25,7 @@ from threading import RLock
|
|||
from tempfile import NamedTemporaryFile
|
||||
|
||||
from py4j.protocol import Py4JError
|
||||
from py4j.java_gateway import is_instance_of
|
||||
|
||||
from pyspark import accumulators
|
||||
from pyspark.accumulators import Accumulator
|
||||
|
@ -864,10 +865,17 @@ class SparkContext(object):
|
|||
first_jrdd_deserializer = rdds[0]._jrdd_deserializer
|
||||
if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
|
||||
rdds = [x._reserialize() for x in rdds]
|
||||
gw = SparkContext._gateway
|
||||
cls = SparkContext._jvm.org.apache.spark.api.java.JavaRDD
|
||||
jrdds = SparkContext._gateway.new_array(cls, len(rdds))
|
||||
is_jrdd = is_instance_of(gw, rdds[0]._jrdd, cls)
|
||||
jrdds = gw.new_array(cls, len(rdds))
|
||||
for i in range(0, len(rdds)):
|
||||
jrdds[i] = rdds[i]._jrdd
|
||||
if is_jrdd:
|
||||
jrdds[i] = rdds[i]._jrdd
|
||||
else:
|
||||
# zip could return JavaPairRDD hence we ensure `_jrdd`
|
||||
# to be `JavaRDD` by wrapping it in a `map`
|
||||
jrdds[i] = rdds[i].map(lambda x: x)._jrdd
|
||||
return RDD(self._jsc.union(jrdds), self, rdds[0]._jrdd_deserializer)
|
||||
|
||||
def broadcast(self, value):
|
||||
|
|
|
@ -168,6 +168,15 @@ class RDDTests(ReusedPySparkTestCase):
|
|||
set([(x, (x, x)) for x in 'abc'])
|
||||
)
|
||||
|
||||
def test_union_pair_rdd(self):
|
||||
# Regression test for SPARK-31788
|
||||
rdd = self.sc.parallelize([1, 2])
|
||||
pair_rdd = rdd.zip(rdd)
|
||||
self.assertEqual(
|
||||
self.sc.union([pair_rdd, pair_rdd]).collect(),
|
||||
[((1, 1), (2, 2)), ((1, 1), (2, 2))]
|
||||
)
|
||||
|
||||
def test_deleting_input_files(self):
|
||||
# Regression test for SPARK-1025
|
||||
tempFile = tempfile.NamedTemporaryFile(delete=False)
|
||||
|
|
Loading…
Reference in a new issue