[SPARK-31788][CORE][PYTHON] Fix UnionRDD of PairRDDs

### What changes were proposed in this pull request?
`sc.union` fails on pair RDDs (for example, the result of `zip`) because their underlying `_jrdd` is a `JavaPairRDD` rather than a `JavaRDD`. The fix is to check the instance type before filling the Java array, and to re-wrap non-`JavaRDD` inputs via an identity `map`.

### Why are the changes needed?
Changes are needed so that users do not run into errors when calling `union` on RDDs whose underlying Java object is any type other than `JavaRDD`.
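In short, the patch uses py4j's `is_instance_of` to test whether an RDD's underlying `_jrdd` is a plain `JavaRDD`; if it is not (e.g. after `zip`, which produces a `JavaPairRDD`), the RDD is routed through an identity `map` so its `_jrdd` is rebuilt as a `JavaRDD`. A minimal sketch of the idea, assuming an active shell with `sc` and some RDD `rdd` (illustrative only; the real change is in the diff below):

from py4j.java_gateway import is_instance_of

gw = sc._gateway
java_rdd_cls = sc._jvm.org.apache.spark.api.java.JavaRDD
if not is_instance_of(gw, rdd._jrdd, java_rdd_cls):
    # e.g. rdd came from zip() and its _jrdd is a JavaPairRDD;
    # an identity map rebuilds it as a plain JavaRDD
    rdd = rdd.map(lambda x: x)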

### Does this PR introduce _any_ user-facing change?
Yes

Before:
SparkSession available as 'spark'.
>>> rdd1 = sc.parallelize([1,2,3,4,5])
>>> rdd2 = sc.parallelize([6,7,8,9,10])
>>> pairRDD1 = rdd1.zip(rdd2)
>>> unionRDD1 = sc.union([pairRDD1, pairRDD1])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/gs/spark/latest/python/pyspark/context.py", line 870, in union
    jrdds[i] = rdds[i]._jrdd
  File "/home/gs/spark/latest/python/lib/py4j-0.10.9-src.zip/py4j/java_collections.py", line 238, in __setitem__
  File "/home/gs/spark/latest/python/lib/py4j-0.10.9-src.zip/py4j/java_collections.py", line 221, in __set_item
  File "/home/gs/spark/latest/python/lib/py4j-0.10.9-src.zip/py4j/protocol.py", line 332, in get_return_value
py4j.protocol.Py4JError: An error occurred while calling None.None. Trace:
py4j.Py4JException: Cannot convert org.apache.spark.api.java.JavaPairRDD to org.apache.spark.api.java.JavaRDD
    at py4j.commands.ArrayCommand.convertArgument(ArrayCommand.java:166)
    at py4j.commands.ArrayCommand.setArray(ArrayCommand.java:144)
    at py4j.commands.ArrayCommand.execute(ArrayCommand.java:97)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.lang.Thread.run(Thread.java:748)

After:
>>> rdd1 = sc.parallelize([1,2,3,4,5])
>>> rdd2 = sc.parallelize([6,7,8,9,10])
>>> pairRDD1 = rdd1.zip(rdd2)
>>> unionRDD1 = sc.union([pairRDD1, pairRDD1])
>>> unionRDD1.collect()
[(1, 6), (2, 7), (3, 8), (4, 9), (5, 10), (1, 6), (2, 7), (3, 8), (4, 9), (5, 10)]

### How was this patch tested?
Tested manually with the reproduction above; the patch also adds a `test_union_pair_rdd` regression test (see the diff below).
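For reference, the new regression test should be runnable with Spark's Python test runner, e.g. `./python/run-tests --testnames 'pyspark.tests.test_rdd'` (assuming a standard Spark development checkout).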

Closes #28603 from redsanket/SPARK-31788.

Authored-by: schintap <schintap@verizonmedia.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
2 changed files with 19 additions and 2 deletions

python/pyspark/context.py

@@ -25,6 +25,7 @@ from threading import RLock
 from tempfile import NamedTemporaryFile

 from py4j.protocol import Py4JError
+from py4j.java_gateway import is_instance_of

 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
@@ -864,10 +865,17 @@ class SparkContext(object):
         first_jrdd_deserializer = rdds[0]._jrdd_deserializer
         if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
             rdds = [x._reserialize() for x in rdds]
+        gw = SparkContext._gateway
         cls = SparkContext._jvm.org.apache.spark.api.java.JavaRDD
-        jrdds = SparkContext._gateway.new_array(cls, len(rdds))
+        is_jrdd = is_instance_of(gw, rdds[0]._jrdd, cls)
+        jrdds = gw.new_array(cls, len(rdds))
         for i in range(0, len(rdds)):
-            jrdds[i] = rdds[i]._jrdd
+            if is_jrdd:
+                jrdds[i] = rdds[i]._jrdd
+            else:
+                # zip could return JavaPairRDD hence we ensure `_jrdd`
+                # to be `JavaRDD` by wrapping it in a `map`
+                jrdds[i] = rdds[i].map(lambda x: x)._jrdd
         return RDD(self._jsc.union(jrdds), self, rdds[0]._jrdd_deserializer)

     def broadcast(self, value):
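For context on why the identity `map` works: a mapped RDD becomes a `PipelinedRDD`, whose `_jrdd` is rebuilt as a plain `JavaRDD`. A quick way to see this in the shell (a sketch; class names taken from the Py4JException above):

>>> rdd = sc.parallelize([1, 2])
>>> pair = rdd.zip(rdd)
>>> pair._jrdd.getClass().getName()
'org.apache.spark.api.java.JavaPairRDD'
>>> pair.map(lambda x: x)._jrdd.getClass().getName()
'org.apache.spark.api.java.JavaRDD'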

python/pyspark/tests/test_rdd.py

@@ -168,6 +168,15 @@ class RDDTests(ReusedPySparkTestCase):
             set([(x, (x, x)) for x in 'abc'])
         )

+    def test_union_pair_rdd(self):
+        # Regression test for SPARK-31788
+        rdd = self.sc.parallelize([1, 2])
+        pair_rdd = rdd.zip(rdd)
+        self.assertEqual(
+            self.sc.union([pair_rdd, pair_rdd]).collect(),
+            [((1, 1), (2, 2)), ((1, 1), (2, 2))]
+        )
+
     def test_deleting_input_files(self):
         # Regression test for SPARK-1025
         tempFile = tempfile.NamedTemporaryFile(delete=False)