[SPARK-16348][ML][MLLIB][PYTHON] Use full classpaths for pyspark ML JVM calls
## What changes were proposed in this pull request?

Issue: Omitting the full classpath (the fully qualified class name, e.g. `org.apache.spark.ml.python.MLSerDe` rather than just `MLSerDe`) can cause problems when calling JVM methods or classes from pyspark.

This PR: Changed all uses of `jvm.X` in pyspark.ml and pyspark.mllib to use the full classpath for X.

## How was this patch tested?

Existing unit tests. Manual testing in an environment where this was an issue.

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #14023 from jkbradley/SPARK-16348.
This commit is contained in:
parent
59f9c1bd1a
commit
fdde7d0aa0
|
@ -63,7 +63,7 @@ def _to_java_object_rdd(rdd):
|
||||||
RDD is serialized in batch or not.
|
RDD is serialized in batch or not.
|
||||||
"""
|
"""
|
||||||
rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
|
rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
|
||||||
return rdd.ctx._jvm.MLSerDe.pythonToJava(rdd._jrdd, True)
|
return rdd.ctx._jvm.org.apache.spark.ml.python.MLSerDe.pythonToJava(rdd._jrdd, True)
|
||||||
|
|
||||||
|
|
||||||
def _py2java(sc, obj):
|
def _py2java(sc, obj):
|
||||||
|
@ -82,7 +82,7 @@ def _py2java(sc, obj):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
data = bytearray(PickleSerializer().dumps(obj))
|
data = bytearray(PickleSerializer().dumps(obj))
|
||||||
obj = sc._jvm.MLSerDe.loads(data)
|
obj = sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(data)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
@ -95,17 +95,17 @@ def _java2py(sc, r, encoding="bytes"):
|
||||||
clsName = 'JavaRDD'
|
clsName = 'JavaRDD'
|
||||||
|
|
||||||
if clsName == 'JavaRDD':
|
if clsName == 'JavaRDD':
|
||||||
jrdd = sc._jvm.MLSerDe.javaToPython(r)
|
jrdd = sc._jvm.org.apache.spark.ml.python.MLSerDe.javaToPython(r)
|
||||||
return RDD(jrdd, sc)
|
return RDD(jrdd, sc)
|
||||||
|
|
||||||
if clsName == 'Dataset':
|
if clsName == 'Dataset':
|
||||||
return DataFrame(r, SQLContext.getOrCreate(sc))
|
return DataFrame(r, SQLContext.getOrCreate(sc))
|
||||||
|
|
||||||
if clsName in _picklable_classes:
|
if clsName in _picklable_classes:
|
||||||
r = sc._jvm.MLSerDe.dumps(r)
|
r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r)
|
||||||
elif isinstance(r, (JavaArray, JavaList)):
|
elif isinstance(r, (JavaArray, JavaList)):
|
||||||
try:
|
try:
|
||||||
r = sc._jvm.MLSerDe.dumps(r)
|
r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r)
|
||||||
except Py4JJavaError:
|
except Py4JJavaError:
|
||||||
pass # not pickable
|
pass # not pickable
|
||||||
|
|
||||||
|
|
|
@ -1195,12 +1195,12 @@ class VectorTests(MLlibTestCase):
|
||||||
|
|
||||||
def _test_serialize(self, v):
|
def _test_serialize(self, v):
|
||||||
self.assertEqual(v, ser.loads(ser.dumps(v)))
|
self.assertEqual(v, ser.loads(ser.dumps(v)))
|
||||||
jvec = self.sc._jvm.MLSerDe.loads(bytearray(ser.dumps(v)))
|
jvec = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(v)))
|
||||||
nv = ser.loads(bytes(self.sc._jvm.MLSerDe.dumps(jvec)))
|
nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvec)))
|
||||||
self.assertEqual(v, nv)
|
self.assertEqual(v, nv)
|
||||||
vs = [v] * 100
|
vs = [v] * 100
|
||||||
jvecs = self.sc._jvm.MLSerDe.loads(bytearray(ser.dumps(vs)))
|
jvecs = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(vs)))
|
||||||
nvs = ser.loads(bytes(self.sc._jvm.MLSerDe.dumps(jvecs)))
|
nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvecs)))
|
||||||
self.assertEqual(vs, nvs)
|
self.assertEqual(vs, nvs)
|
||||||
|
|
||||||
def test_serialize(self):
|
def test_serialize(self):
|
||||||
|
|
|
@ -507,7 +507,7 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader):
|
||||||
Path to where the model is stored.
|
Path to where the model is stored.
|
||||||
"""
|
"""
|
||||||
model = cls._load_java(sc, path)
|
model = cls._load_java(sc, path)
|
||||||
wrapper = sc._jvm.GaussianMixtureModelWrapper(model)
|
wrapper = sc._jvm.org.apache.spark.mllib.api.python.GaussianMixtureModelWrapper(model)
|
||||||
return cls(wrapper)
|
return cls(wrapper)
|
||||||
|
|
||||||
|
|
||||||
|
@ -638,7 +638,8 @@ class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader):
|
||||||
Load a model from the given path.
|
Load a model from the given path.
|
||||||
"""
|
"""
|
||||||
model = cls._load_java(sc, path)
|
model = cls._load_java(sc, path)
|
||||||
wrapper = sc._jvm.PowerIterationClusteringModelWrapper(model)
|
wrapper =\
|
||||||
|
sc._jvm.org.apache.spark.mllib.api.python.PowerIterationClusteringModelWrapper(model)
|
||||||
return PowerIterationClusteringModel(wrapper)
|
return PowerIterationClusteringModel(wrapper)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -66,7 +66,7 @@ def _to_java_object_rdd(rdd):
|
||||||
RDD is serialized in batch or not.
|
RDD is serialized in batch or not.
|
||||||
"""
|
"""
|
||||||
rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
|
rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
|
||||||
return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True)
|
return rdd.ctx._jvm.org.apache.spark.mllib.api.python.SerDe.pythonToJava(rdd._jrdd, True)
|
||||||
|
|
||||||
|
|
||||||
def _py2java(sc, obj):
|
def _py2java(sc, obj):
|
||||||
|
@ -85,7 +85,7 @@ def _py2java(sc, obj):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
data = bytearray(PickleSerializer().dumps(obj))
|
data = bytearray(PickleSerializer().dumps(obj))
|
||||||
obj = sc._jvm.SerDe.loads(data)
|
obj = sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(data)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
@ -98,17 +98,17 @@ def _java2py(sc, r, encoding="bytes"):
|
||||||
clsName = 'JavaRDD'
|
clsName = 'JavaRDD'
|
||||||
|
|
||||||
if clsName == 'JavaRDD':
|
if clsName == 'JavaRDD':
|
||||||
jrdd = sc._jvm.SerDe.javaToPython(r)
|
jrdd = sc._jvm.org.apache.spark.mllib.api.python.SerDe.javaToPython(r)
|
||||||
return RDD(jrdd, sc)
|
return RDD(jrdd, sc)
|
||||||
|
|
||||||
if clsName == 'Dataset':
|
if clsName == 'Dataset':
|
||||||
return DataFrame(r, SQLContext.getOrCreate(sc))
|
return DataFrame(r, SQLContext.getOrCreate(sc))
|
||||||
|
|
||||||
if clsName in _picklable_classes:
|
if clsName in _picklable_classes:
|
||||||
r = sc._jvm.SerDe.dumps(r)
|
r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r)
|
||||||
elif isinstance(r, (JavaArray, JavaList)):
|
elif isinstance(r, (JavaArray, JavaList)):
|
||||||
try:
|
try:
|
||||||
r = sc._jvm.SerDe.dumps(r)
|
r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r)
|
||||||
except Py4JJavaError:
|
except Py4JJavaError:
|
||||||
pass # not pickable
|
pass # not pickable
|
||||||
|
|
||||||
|
|
|
@ -553,7 +553,7 @@ class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader):
|
||||||
"""
|
"""
|
||||||
jmodel = sc._jvm.org.apache.spark.mllib.feature \
|
jmodel = sc._jvm.org.apache.spark.mllib.feature \
|
||||||
.Word2VecModel.load(sc._jsc.sc(), path)
|
.Word2VecModel.load(sc._jsc.sc(), path)
|
||||||
model = sc._jvm.Word2VecModelWrapper(jmodel)
|
model = sc._jvm.org.apache.spark.mllib.api.python.Word2VecModelWrapper(jmodel)
|
||||||
return Word2VecModel(model)
|
return Word2VecModel(model)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -64,7 +64,7 @@ class FPGrowthModel(JavaModelWrapper, JavaSaveable, JavaLoader):
|
||||||
Load a model from the given path.
|
Load a model from the given path.
|
||||||
"""
|
"""
|
||||||
model = cls._load_java(sc, path)
|
model = cls._load_java(sc, path)
|
||||||
wrapper = sc._jvm.FPGrowthModelWrapper(model)
|
wrapper = sc._jvm.org.apache.spark.mllib.api.python.FPGrowthModelWrapper(model)
|
||||||
return FPGrowthModel(wrapper)
|
return FPGrowthModel(wrapper)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -207,7 +207,7 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
|
||||||
def load(cls, sc, path):
|
def load(cls, sc, path):
|
||||||
"""Load a model from the given path"""
|
"""Load a model from the given path"""
|
||||||
model = cls._load_java(sc, path)
|
model = cls._load_java(sc, path)
|
||||||
wrapper = sc._jvm.MatrixFactorizationModelWrapper(model)
|
wrapper = sc._jvm.org.apache.spark.mllib.api.python.MatrixFactorizationModelWrapper(model)
|
||||||
return MatrixFactorizationModel(wrapper)
|
return MatrixFactorizationModel(wrapper)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -150,12 +150,12 @@ class VectorTests(MLlibTestCase):
|
||||||
|
|
||||||
def _test_serialize(self, v):
|
def _test_serialize(self, v):
|
||||||
self.assertEqual(v, ser.loads(ser.dumps(v)))
|
self.assertEqual(v, ser.loads(ser.dumps(v)))
|
||||||
jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v)))
|
jvec = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(v)))
|
||||||
nv = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvec)))
|
nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvec)))
|
||||||
self.assertEqual(v, nv)
|
self.assertEqual(v, nv)
|
||||||
vs = [v] * 100
|
vs = [v] * 100
|
||||||
jvecs = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(vs)))
|
jvecs = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(vs)))
|
||||||
nvs = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvecs)))
|
nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvecs)))
|
||||||
self.assertEqual(vs, nvs)
|
self.assertEqual(vs, nvs)
|
||||||
|
|
||||||
def test_serialize(self):
|
def test_serialize(self):
|
||||||
|
@ -1650,8 +1650,8 @@ class ALSTests(MLlibTestCase):
|
||||||
|
|
||||||
def test_als_ratings_serialize(self):
|
def test_als_ratings_serialize(self):
|
||||||
r = Rating(7, 1123, 3.14)
|
r = Rating(7, 1123, 3.14)
|
||||||
jr = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(r)))
|
jr = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(r)))
|
||||||
nr = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jr)))
|
nr = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jr)))
|
||||||
self.assertEqual(r.user, nr.user)
|
self.assertEqual(r.user, nr.user)
|
||||||
self.assertEqual(r.product, nr.product)
|
self.assertEqual(r.product, nr.product)
|
||||||
self.assertAlmostEqual(r.rating, nr.rating, 2)
|
self.assertAlmostEqual(r.rating, nr.rating, 2)
|
||||||
|
@ -1659,7 +1659,8 @@ class ALSTests(MLlibTestCase):
|
||||||
def test_als_ratings_id_long_error(self):
|
def test_als_ratings_id_long_error(self):
|
||||||
r = Rating(1205640308657491975, 50233468418, 1.0)
|
r = Rating(1205640308657491975, 50233468418, 1.0)
|
||||||
# rating user id exceeds max int value, should fail when pickled
|
# rating user id exceeds max int value, should fail when pickled
|
||||||
self.assertRaises(Py4JJavaError, self.sc._jvm.SerDe.loads, bytearray(ser.dumps(r)))
|
self.assertRaises(Py4JJavaError, self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads,
|
||||||
|
bytearray(ser.dumps(r)))
|
||||||
|
|
||||||
|
|
||||||
class HashingTFTest(MLlibTestCase):
|
class HashingTFTest(MLlibTestCase):
|
||||||
|
|
Loading…
Reference in a new issue