[SPARK-6097][MLLIB] Support tree model save/load in PySpark/MLlib
Similar to `MatrixFactorizationModel`, we only need wrappers to support save/load for tree models in Python. jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #4854 from mengxr/SPARK-6097 and squashes the following commits: 4586a4d [Xiangrui Meng] fix more typos 8ebcac2 [Xiangrui Meng] fix python style 91172d8 [Xiangrui Meng] fix typos 201b3b9 [Xiangrui Meng] update user guide b5158e2 [Xiangrui Meng] support tree model save/load in PySpark/MLlib
This commit is contained in:
parent
54d19689ff
commit
7e53a79c30
|
@ -293,11 +293,9 @@ DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath");
|
|||
|
||||
<div data-lang="python">
|
||||
|
||||
Note that the Python API does not yet support model save/load but will in the future.
|
||||
|
||||
{% highlight python %}
|
||||
from pyspark.mllib.regression import LabeledPoint
|
||||
from pyspark.mllib.tree import DecisionTree
|
||||
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
|
||||
from pyspark.mllib.util import MLUtils
|
||||
|
||||
# Load and parse the data file into an RDD of LabeledPoint.
|
||||
|
@ -317,6 +315,10 @@ testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(tes
|
|||
print('Test Error = ' + str(testErr))
|
||||
print('Learned classification tree model:')
|
||||
print(model.toDebugString())
|
||||
|
||||
# Save and load model
|
||||
model.save(sc, "myModelPath")
|
||||
sameModel = DecisionTreeModel.load(sc, "myModelPath")
|
||||
{% endhighlight %}
|
||||
</div>
|
||||
|
||||
|
@ -440,11 +442,9 @@ DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath");
|
|||
|
||||
<div data-lang="python">
|
||||
|
||||
Note that the Python API does not yet support model save/load but will in the future.
|
||||
|
||||
{% highlight python %}
|
||||
from pyspark.mllib.regression import LabeledPoint
|
||||
from pyspark.mllib.tree import DecisionTree
|
||||
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
|
||||
from pyspark.mllib.util import MLUtils
|
||||
|
||||
# Load and parse the data file into an RDD of LabeledPoint.
|
||||
|
@ -464,6 +464,10 @@ testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / flo
|
|||
print('Test Mean Squared Error = ' + str(testMSE))
|
||||
print('Learned regression tree model:')
|
||||
print(model.toDebugString())
|
||||
|
||||
# Save and load model
|
||||
model.save(sc, "myModelPath")
|
||||
sameModel = DecisionTreeModel.load(sc, "myModelPath")
|
||||
{% endhighlight %}
|
||||
</div>
|
||||
|
||||
|
|
|
@ -202,10 +202,8 @@ RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath");
|
|||
|
||||
<div data-lang="python">
|
||||
|
||||
Note that the Python API does not yet support model save/load but will in the future.
|
||||
|
||||
{% highlight python %}
|
||||
from pyspark.mllib.tree import RandomForest
|
||||
from pyspark.mllib.tree import RandomForest, RandomForestModel
|
||||
from pyspark.mllib.util import MLUtils
|
||||
|
||||
# Load and parse the data file into an RDD of LabeledPoint.
|
||||
|
@ -228,6 +226,10 @@ testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(tes
|
|||
print('Test Error = ' + str(testErr))
|
||||
print('Learned classification forest model:')
|
||||
print(model.toDebugString())
|
||||
|
||||
# Save and load model
|
||||
model.save(sc, "myModelPath")
|
||||
sameModel = RandomForestModel.load(sc, "myModelPath")
|
||||
{% endhighlight %}
|
||||
</div>
|
||||
|
||||
|
@ -354,10 +356,8 @@ RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath");
|
|||
|
||||
<div data-lang="python">
|
||||
|
||||
Note that the Python API does not yet support model save/load but will in the future.
|
||||
|
||||
{% highlight python %}
|
||||
from pyspark.mllib.tree import RandomForest
|
||||
from pyspark.mllib.tree import RandomForest, RandomForestModel
|
||||
from pyspark.mllib.util import MLUtils
|
||||
|
||||
# Load and parse the data file into an RDD of LabeledPoint.
|
||||
|
@ -380,6 +380,10 @@ testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / flo
|
|||
print('Test Mean Squared Error = ' + str(testMSE))
|
||||
print('Learned regression forest model:')
|
||||
print(model.toDebugString())
|
||||
|
||||
# Save and load model
|
||||
model.save(sc, "myModelPath")
|
||||
sameModel = RandomForestModel.load(sc, "myModelPath")
|
||||
{% endhighlight %}
|
||||
</div>
|
||||
|
||||
|
@ -581,10 +585,8 @@ GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "m
|
|||
|
||||
<div data-lang="python">
|
||||
|
||||
Note that the Python API does not yet support model save/load but will in the future.
|
||||
|
||||
{% highlight python %}
|
||||
from pyspark.mllib.tree import GradientBoostedTrees
|
||||
from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel
|
||||
from pyspark.mllib.util import MLUtils
|
||||
|
||||
# Load and parse the data file.
|
||||
|
@ -605,6 +607,10 @@ testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(tes
|
|||
print('Test Error = ' + str(testErr))
|
||||
print('Learned classification GBT model:')
|
||||
print(model.toDebugString())
|
||||
|
||||
# Save and load model
|
||||
model.save(sc, "myModelPath")
|
||||
sameModel = GradientBoostedTreesModel.load(sc, "myModelPath")
|
||||
{% endhighlight %}
|
||||
</div>
|
||||
|
||||
|
@ -732,10 +738,8 @@ GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "m
|
|||
|
||||
<div data-lang="python">
|
||||
|
||||
Note that the Python API does not yet support model save/load but will in the future.
|
||||
|
||||
{% highlight python %}
|
||||
from pyspark.mllib.tree import GradientBoostedTrees
|
||||
from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel
|
||||
from pyspark.mllib.util import MLUtils
|
||||
|
||||
# Load and parse the data file.
|
||||
|
@ -756,6 +760,10 @@ testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / flo
|
|||
print('Test Mean Squared Error = ' + str(testMSE))
|
||||
print('Learned regression GBT model:')
|
||||
print(model.toDebugString())
|
||||
|
||||
# Save and load model
|
||||
model.save(sc, "myModelPath")
|
||||
sameModel = GradientBoostedTreesModel.load(sc, "myModelPath")
|
||||
{% endhighlight %}
|
||||
</div>
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ from collections import namedtuple
|
|||
from pyspark import SparkContext
|
||||
from pyspark.rdd import RDD
|
||||
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
|
||||
from pyspark.mllib.util import Saveable, JavaLoader
|
||||
from pyspark.mllib.util import JavaLoader, JavaSaveable
|
||||
|
||||
__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
|
||||
|
||||
|
@ -41,7 +41,7 @@ class Rating(namedtuple("Rating", ["user", "product", "rating"])):
|
|||
|
||||
|
||||
@inherit_doc
|
||||
class MatrixFactorizationModel(JavaModelWrapper, Saveable, JavaLoader):
|
||||
class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
|
||||
|
||||
"""A matrix factorisation model trained by regularized alternating
|
||||
least-squares.
|
||||
|
@ -92,7 +92,7 @@ class MatrixFactorizationModel(JavaModelWrapper, Saveable, JavaLoader):
|
|||
0.43...
|
||||
>>> try:
|
||||
... os.removedirs(path)
|
||||
... except:
|
||||
... except OSError:
|
||||
... pass
|
||||
"""
|
||||
def predict(self, user, product):
|
||||
|
@ -111,9 +111,6 @@ class MatrixFactorizationModel(JavaModelWrapper, Saveable, JavaLoader):
|
|||
def productFeatures(self):
|
||||
return self.call("getProductFeatures")
|
||||
|
||||
def save(self, sc, path):
|
||||
self.call("save", sc._jsc.sc(), path)
|
||||
|
||||
|
||||
class ALS(object):
|
||||
|
||||
|
|
|
@ -19,7 +19,9 @@
|
|||
Fuller unit tests for Python MLlib.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import array as pyarray
|
||||
|
||||
from numpy import array, array_equal
|
||||
|
@ -195,7 +197,8 @@ class ListTests(PySparkTestCase):
|
|||
|
||||
def test_classification(self):
|
||||
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
|
||||
from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees
|
||||
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\
|
||||
RandomForestModel, GradientBoostedTrees, GradientBoostedTreesModel
|
||||
data = [
|
||||
LabeledPoint(0.0, [1, 0, 0]),
|
||||
LabeledPoint(1.0, [0, 1, 1]),
|
||||
|
@ -205,6 +208,8 @@ class ListTests(PySparkTestCase):
|
|||
rdd = self.sc.parallelize(data)
|
||||
features = [p.features.tolist() for p in data]
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
|
||||
lr_model = LogisticRegressionWithSGD.train(rdd)
|
||||
self.assertTrue(lr_model.predict(features[0]) <= 0)
|
||||
self.assertTrue(lr_model.predict(features[1]) > 0)
|
||||
|
@ -231,6 +236,11 @@ class ListTests(PySparkTestCase):
|
|||
self.assertTrue(dt_model.predict(features[2]) <= 0)
|
||||
self.assertTrue(dt_model.predict(features[3]) > 0)
|
||||
|
||||
dt_model_dir = os.path.join(temp_dir, "dt")
|
||||
dt_model.save(self.sc, dt_model_dir)
|
||||
same_dt_model = DecisionTreeModel.load(self.sc, dt_model_dir)
|
||||
self.assertEqual(same_dt_model.toDebugString(), dt_model.toDebugString())
|
||||
|
||||
rf_model = RandomForest.trainClassifier(
|
||||
rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100)
|
||||
self.assertTrue(rf_model.predict(features[0]) <= 0)
|
||||
|
@ -238,6 +248,11 @@ class ListTests(PySparkTestCase):
|
|||
self.assertTrue(rf_model.predict(features[2]) <= 0)
|
||||
self.assertTrue(rf_model.predict(features[3]) > 0)
|
||||
|
||||
rf_model_dir = os.path.join(temp_dir, "rf")
|
||||
rf_model.save(self.sc, rf_model_dir)
|
||||
same_rf_model = RandomForestModel.load(self.sc, rf_model_dir)
|
||||
self.assertEqual(same_rf_model.toDebugString(), rf_model.toDebugString())
|
||||
|
||||
gbt_model = GradientBoostedTrees.trainClassifier(
|
||||
rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
|
||||
self.assertTrue(gbt_model.predict(features[0]) <= 0)
|
||||
|
@ -245,6 +260,16 @@ class ListTests(PySparkTestCase):
|
|||
self.assertTrue(gbt_model.predict(features[2]) <= 0)
|
||||
self.assertTrue(gbt_model.predict(features[3]) > 0)
|
||||
|
||||
gbt_model_dir = os.path.join(temp_dir, "gbt")
|
||||
gbt_model.save(self.sc, gbt_model_dir)
|
||||
same_gbt_model = GradientBoostedTreesModel.load(self.sc, gbt_model_dir)
|
||||
self.assertEqual(same_gbt_model.toDebugString(), gbt_model.toDebugString())
|
||||
|
||||
try:
|
||||
os.removedirs(temp_dir)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def test_regression(self):
|
||||
from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \
|
||||
RidgeRegressionWithSGD
|
||||
|
|
|
@ -23,12 +23,13 @@ from pyspark import SparkContext, RDD
|
|||
from pyspark.mllib.common import callMLlibFunc, inherit_doc, JavaModelWrapper
|
||||
from pyspark.mllib.linalg import _convert_to_vector
|
||||
from pyspark.mllib.regression import LabeledPoint
|
||||
from pyspark.mllib.util import JavaLoader, JavaSaveable
|
||||
|
||||
__all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel',
|
||||
'RandomForest', 'GradientBoostedTreesModel', 'GradientBoostedTrees']
|
||||
|
||||
|
||||
class TreeEnsembleModel(JavaModelWrapper):
|
||||
class TreeEnsembleModel(JavaModelWrapper, JavaSaveable):
|
||||
def predict(self, x):
|
||||
"""
|
||||
Predict values for a single data point or an RDD of points using
|
||||
|
@ -66,7 +67,7 @@ class TreeEnsembleModel(JavaModelWrapper):
|
|||
return self._java_model.toDebugString()
|
||||
|
||||
|
||||
class DecisionTreeModel(JavaModelWrapper):
|
||||
class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader):
|
||||
"""
|
||||
.. note:: Experimental
|
||||
|
||||
|
@ -103,6 +104,10 @@ class DecisionTreeModel(JavaModelWrapper):
|
|||
""" full model. """
|
||||
return self._java_model.toDebugString()
|
||||
|
||||
@classmethod
|
||||
def _java_loader_class(cls):
|
||||
return "org.apache.spark.mllib.tree.model.DecisionTreeModel"
|
||||
|
||||
|
||||
class DecisionTree(object):
|
||||
"""
|
||||
|
@ -227,13 +232,17 @@ class DecisionTree(object):
|
|||
|
||||
|
||||
@inherit_doc
|
||||
class RandomForestModel(TreeEnsembleModel):
|
||||
class RandomForestModel(TreeEnsembleModel, JavaLoader):
|
||||
"""
|
||||
.. note:: Experimental
|
||||
|
||||
Represents a random forest model.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _java_loader_class(cls):
|
||||
return "org.apache.spark.mllib.tree.model.RandomForestModel"
|
||||
|
||||
|
||||
class RandomForest(object):
|
||||
"""
|
||||
|
@ -406,13 +415,17 @@ class RandomForest(object):
|
|||
|
||||
|
||||
@inherit_doc
|
||||
class GradientBoostedTreesModel(TreeEnsembleModel):
|
||||
class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader):
|
||||
"""
|
||||
.. note:: Experimental
|
||||
|
||||
Represents a gradient-boosted tree model.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _java_loader_class(cls):
|
||||
return "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel"
|
||||
|
||||
|
||||
class GradientBoostedTrees(object):
|
||||
"""
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
import numpy as np
|
||||
import warnings
|
||||
|
||||
from pyspark.mllib.common import callMLlibFunc
|
||||
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper, inherit_doc
|
||||
from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector
|
||||
from pyspark.mllib.regression import LabeledPoint
|
||||
|
||||
|
@ -191,6 +191,17 @@ class Saveable(object):
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
@inherit_doc
|
||||
class JavaSaveable(Saveable):
|
||||
"""
|
||||
Mixin for models that provide save() through their Scala
|
||||
implementation.
|
||||
"""
|
||||
|
||||
def save(self, sc, path):
|
||||
self._java_model.save(sc._jsc.sc(), path)
|
||||
|
||||
|
||||
class Loader(object):
|
||||
"""
|
||||
Mixin for classes which can load saved models from files.
|
||||
|
@ -210,6 +221,7 @@ class Loader(object):
|
|||
raise NotImplemented
|
||||
|
||||
|
||||
@inherit_doc
|
||||
class JavaLoader(Loader):
|
||||
"""
|
||||
Mixin for classes which can load saved models using their Scala
|
||||
|
@ -217,13 +229,30 @@ class JavaLoader(Loader):
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def load(cls, sc, path):
|
||||
def _java_loader_class(cls):
|
||||
"""
|
||||
Returns the full class name of the Java loader. The default
|
||||
implementation replaces "pyspark" by "org.apache.spark" in
|
||||
the Python full class name.
|
||||
"""
|
||||
java_package = cls.__module__.replace("pyspark", "org.apache.spark")
|
||||
java_class = ".".join([java_package, cls.__name__])
|
||||
return ".".join([java_package, cls.__name__])
|
||||
|
||||
@classmethod
|
||||
def _load_java(cls, sc, path):
|
||||
"""
|
||||
Load a Java model from the given path.
|
||||
"""
|
||||
java_class = cls._java_loader_class()
|
||||
java_obj = sc._jvm
|
||||
for name in java_class.split("."):
|
||||
java_obj = getattr(java_obj, name)
|
||||
return cls(java_obj.load(sc._jsc.sc(), path))
|
||||
return java_obj.load(sc._jsc.sc(), path)
|
||||
|
||||
@classmethod
|
||||
def load(cls, sc, path):
|
||||
java_model = cls._load_java(sc, path)
|
||||
return cls(java_model)
|
||||
|
||||
|
||||
def _test():
|
||||
|
|
Loading…
Reference in a new issue