[SPARK-6097][MLLIB] Support tree model save/load in PySpark/MLlib

Similar to `MatrixFactorizationModel`, we only need wrappers to support save/load for tree models in Python.
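In other words, the heavy lifting stays on the Scala side; Python only forwards calls through py4j. A minimal, self-contained sketch of the wrapper pattern (plain Python, with a fake JVM model standing in for the py4j proxy; names mirror the diff, but this is illustrative, not PySpark code):

```python
# Illustrative sketch only: FakeJavaModel stands in for the py4j proxy
# that PySpark models keep in self._java_model.
class FakeJavaModel(object):
    def save(self, sc, path):
        print("Scala side writes model to %s" % path)

class JavaSaveable(object):
    """Mixin: delegate save() to the wrapped Java model."""
    def save(self, sc, path):
        # The real mixin passes sc._jsc.sc(); the fake just forwards.
        self._java_model.save(sc, path)

class JavaLoader(object):
    """Mixin: load the Scala model, then wrap it in the Python class."""
    @classmethod
    def load(cls, sc, path):
        # The real mixin walks sc._jvm to the Scala loader class.
        return cls(FakeJavaModel())

class DecisionTreeModel(JavaSaveable, JavaLoader):
    def __init__(self, java_model):
        self._java_model = java_model

model = DecisionTreeModel(FakeJavaModel())
model.save(None, "myModelPath")
sameModel = DecisionTreeModel.load(None, "myModelPath")
```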

jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #4854 from mengxr/SPARK-6097 and squashes the following commits:

4586a4d [Xiangrui Meng] fix more typos
8ebcac2 [Xiangrui Meng] fix python style
91172d8 [Xiangrui Meng] fix typos
201b3b9 [Xiangrui Meng] update user guide
b5158e2 [Xiangrui Meng] support tree model save/load in PySpark/MLlib
commit 7e53a79c30 (parent 54d19689ff), committed 2015-03-02 22:27:01 -08:00
6 changed files with 109 additions and 33 deletions

docs/mllib-decision-tree.md

@@ -293,11 +293,9 @@ DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath");
 
 <div data-lang="python">
-Note that the Python API does not yet support model save/load but will in the future.
-
 {% highlight python %}
 from pyspark.mllib.regression import LabeledPoint
-from pyspark.mllib.tree import DecisionTree
+from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
 from pyspark.mllib.util import MLUtils
 
 # Load and parse the data file into an RDD of LabeledPoint.
@@ -317,6 +315,10 @@ testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(tes
 print('Test Error = ' + str(testErr))
 print('Learned classification tree model:')
 print(model.toDebugString())
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = DecisionTreeModel.load(sc, "myModelPath")
 {% endhighlight %}
 </div>
@@ -440,11 +442,9 @@ DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath");
 
 <div data-lang="python">
-Note that the Python API does not yet support model save/load but will in the future.
-
 {% highlight python %}
 from pyspark.mllib.regression import LabeledPoint
-from pyspark.mllib.tree import DecisionTree
+from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
 from pyspark.mllib.util import MLUtils
 
 # Load and parse the data file into an RDD of LabeledPoint.
@@ -464,6 +464,10 @@ testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / flo
 print('Test Mean Squared Error = ' + str(testMSE))
 print('Learned regression tree model:')
 print(model.toDebugString())
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = DecisionTreeModel.load(sc, "myModelPath")
 {% endhighlight %}
 </div>
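A quick way to sanity-check the new round trip is to compare the reloaded model against the original. A sketch, assuming the `model`, `testData`, and `sc` defined earlier in the doc example:

```python
# Sketch: verify a save/load round trip (assumes `model`, `testData`,
# and `sc` from the surrounding example).
sameModel = DecisionTreeModel.load(sc, "myModelPath")
assert sameModel.toDebugString() == model.toDebugString()
predictions = model.predict(testData.map(lambda lp: lp.features)).collect()
reloaded = sameModel.predict(testData.map(lambda lp: lp.features)).collect()
assert predictions == reloaded
```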

docs/mllib-ensembles.md

@@ -202,10 +202,8 @@ RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath");
 
 <div data-lang="python">
-Note that the Python API does not yet support model save/load but will in the future.
-
 {% highlight python %}
-from pyspark.mllib.tree import RandomForest
+from pyspark.mllib.tree import RandomForest, RandomForestModel
 from pyspark.mllib.util import MLUtils
 
 # Load and parse the data file into an RDD of LabeledPoint.
@@ -228,6 +226,10 @@ testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(tes
 print('Test Error = ' + str(testErr))
 print('Learned classification forest model:')
 print(model.toDebugString())
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = RandomForestModel.load(sc, "myModelPath")
 {% endhighlight %}
 </div>
@@ -354,10 +356,8 @@ RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath");
 
 <div data-lang="python">
-Note that the Python API does not yet support model save/load but will in the future.
-
 {% highlight python %}
-from pyspark.mllib.tree import RandomForest
+from pyspark.mllib.tree import RandomForest, RandomForestModel
 from pyspark.mllib.util import MLUtils
 
 # Load and parse the data file into an RDD of LabeledPoint.
@@ -380,6 +380,10 @@ testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / flo
 print('Test Mean Squared Error = ' + str(testMSE))
 print('Learned regression forest model:')
 print(model.toDebugString())
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = RandomForestModel.load(sc, "myModelPath")
 {% endhighlight %}
 </div>
@@ -581,10 +585,8 @@ GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "m
 
 <div data-lang="python">
-Note that the Python API does not yet support model save/load but will in the future.
-
 {% highlight python %}
-from pyspark.mllib.tree import GradientBoostedTrees
+from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel
 from pyspark.mllib.util import MLUtils
 
 # Load and parse the data file.
@@ -605,6 +607,10 @@ testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(tes
 print('Test Error = ' + str(testErr))
 print('Learned classification GBT model:')
 print(model.toDebugString())
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = GradientBoostedTreesModel.load(sc, "myModelPath")
 {% endhighlight %}
 </div>
@@ -732,10 +738,8 @@ GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "m
 
 <div data-lang="python">
-Note that the Python API does not yet support model save/load but will in the future.
-
 {% highlight python %}
-from pyspark.mllib.tree import GradientBoostedTrees
+from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel
 from pyspark.mllib.util import MLUtils
 
 # Load and parse the data file.
@@ -756,6 +760,10 @@ testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / flo
 print('Test Mean Squared Error = ' + str(testMSE))
 print('Learned regression GBT model:')
 print(model.toDebugString())
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = GradientBoostedTreesModel.load(sc, "myModelPath")
 {% endhighlight %}
 </div>
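All four ensemble examples change the same way; only the trainer/model pair differs. A hedged sketch of the shared surface (assumes `sc` and a `LabeledPoint` RDD `data`; the paths are arbitrary, and the loop itself is illustrative, not part of the patch):

```python
# Sketch: the save/load surface is identical across tree ensembles
# (assumes `sc` and an RDD `data` of LabeledPoint; illustrative only).
from pyspark.mllib.tree import (RandomForest, RandomForestModel,
                                GradientBoostedTrees, GradientBoostedTreesModel)

cases = [
    (lambda: RandomForest.trainClassifier(data, 2, {}, numTrees=3),
     RandomForestModel, "/tmp/rfModel"),
    (lambda: GradientBoostedTrees.trainClassifier(data, {}),
     GradientBoostedTreesModel, "/tmp/gbtModel"),
]
for train, model_cls, path in cases:
    model = train()
    model.save(sc, path)
    sameModel = model_cls.load(sc, path)
    assert sameModel.toDebugString() == model.toDebugString()
```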

python/pyspark/mllib/recommendation.py

@@ -20,7 +20,7 @@ from collections import namedtuple
 from pyspark import SparkContext
 from pyspark.rdd import RDD
 from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
-from pyspark.mllib.util import Saveable, JavaLoader
+from pyspark.mllib.util import JavaLoader, JavaSaveable
 
 __all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
@@ -41,7 +41,7 @@ class Rating(namedtuple("Rating", ["user", "product", "rating"])):
 
 @inherit_doc
-class MatrixFactorizationModel(JavaModelWrapper, Saveable, JavaLoader):
+class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
     """A matrix factorisation model trained by regularized alternating
     least-squares.
@@ -92,7 +92,7 @@ class MatrixFactorizationModel(JavaModelWrapper, Saveable, JavaLoader):
     0.43...
     >>> try:
     ...     os.removedirs(path)
-    ... except:
+    ... except OSError:
     ...     pass
     """
     def predict(self, user, product):
@@ -111,9 +111,6 @@ class MatrixFactorizationModel(JavaModelWrapper, Saveable, JavaLoader):
     def productFeatures(self):
         return self.call("getProductFeatures")
 
-    def save(self, sc, path):
-        self.call("save", sc._jsc.sc(), path)
-
 
 class ALS(object):
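The explicit `save` method is dropped because `JavaSaveable` now supplies an equivalent one, so callers see no change. A usage sketch with toy data (assumes only a live `SparkContext` `sc`; the path is arbitrary):

```python
# Sketch: the public surface of MatrixFactorizationModel is unchanged
# (toy ratings; assumes a live SparkContext `sc`).
from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating

ratings = sc.parallelize([Rating(1, 1, 5.0), Rating(1, 2, 1.0), Rating(2, 1, 1.0)])
model = ALS.train(ratings, rank=10, iterations=5)
model.save(sc, "/tmp/alsModel")  # now inherited from JavaSaveable
sameModel = MatrixFactorizationModel.load(sc, "/tmp/alsModel")
print(sameModel.predict(2, 2))
```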

python/pyspark/mllib/tests.py

@@ -19,7 +19,9 @@
 Fuller unit tests for Python MLlib.
 """
 
+import os
 import sys
+import tempfile
 import array as pyarray
 
 from numpy import array, array_equal
@@ -195,7 +197,8 @@ class ListTests(PySparkTestCase):
     def test_classification(self):
         from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
-        from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees
+        from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\
+            RandomForestModel, GradientBoostedTrees, GradientBoostedTreesModel
 
         data = [
             LabeledPoint(0.0, [1, 0, 0]),
             LabeledPoint(1.0, [0, 1, 1]),
@@ -205,6 +208,8 @@ class ListTests(PySparkTestCase):
         rdd = self.sc.parallelize(data)
         features = [p.features.tolist() for p in data]
 
+        temp_dir = tempfile.mkdtemp()
+
         lr_model = LogisticRegressionWithSGD.train(rdd)
         self.assertTrue(lr_model.predict(features[0]) <= 0)
         self.assertTrue(lr_model.predict(features[1]) > 0)
@@ -231,6 +236,11 @@ class ListTests(PySparkTestCase):
         self.assertTrue(dt_model.predict(features[2]) <= 0)
         self.assertTrue(dt_model.predict(features[3]) > 0)
 
+        dt_model_dir = os.path.join(temp_dir, "dt")
+        dt_model.save(self.sc, dt_model_dir)
+        same_dt_model = DecisionTreeModel.load(self.sc, dt_model_dir)
+        self.assertEqual(same_dt_model.toDebugString(), dt_model.toDebugString())
+
         rf_model = RandomForest.trainClassifier(
             rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100)
         self.assertTrue(rf_model.predict(features[0]) <= 0)
@@ -238,6 +248,11 @@ class ListTests(PySparkTestCase):
         self.assertTrue(rf_model.predict(features[2]) <= 0)
         self.assertTrue(rf_model.predict(features[3]) > 0)
 
+        rf_model_dir = os.path.join(temp_dir, "rf")
+        rf_model.save(self.sc, rf_model_dir)
+        same_rf_model = RandomForestModel.load(self.sc, rf_model_dir)
+        self.assertEqual(same_rf_model.toDebugString(), rf_model.toDebugString())
+
         gbt_model = GradientBoostedTrees.trainClassifier(
             rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
         self.assertTrue(gbt_model.predict(features[0]) <= 0)
@@ -245,6 +260,16 @@ class ListTests(PySparkTestCase):
         self.assertTrue(gbt_model.predict(features[2]) <= 0)
         self.assertTrue(gbt_model.predict(features[3]) > 0)
 
+        gbt_model_dir = os.path.join(temp_dir, "gbt")
+        gbt_model.save(self.sc, gbt_model_dir)
+        same_gbt_model = GradientBoostedTreesModel.load(self.sc, gbt_model_dir)
+        self.assertEqual(same_gbt_model.toDebugString(), gbt_model.toDebugString())
+
+        try:
+            os.removedirs(temp_dir)
+        except OSError:
+            pass
+
     def test_regression(self):
         from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \
             RidgeRegressionWithSGD
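One detail in the cleanup: `os.removedirs` only deletes empty directories, so it raises `OSError` when saved models still sit under `temp_dir`; the `try/except` deliberately swallows that. An unconditional cleanup, if one were wanted instead (an alternative, not what the patch does), would be:

```python
import shutil

# Removes temp_dir and everything under it; never raises if it is missing.
shutil.rmtree(temp_dir, ignore_errors=True)
```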

python/pyspark/mllib/tree.py

@@ -23,12 +23,13 @@ from pyspark import SparkContext, RDD
 from pyspark.mllib.common import callMLlibFunc, inherit_doc, JavaModelWrapper
 from pyspark.mllib.linalg import _convert_to_vector
 from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.util import JavaLoader, JavaSaveable
 
 __all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel',
            'RandomForest', 'GradientBoostedTreesModel', 'GradientBoostedTrees']
 
 
-class TreeEnsembleModel(JavaModelWrapper):
+class TreeEnsembleModel(JavaModelWrapper, JavaSaveable):
     def predict(self, x):
         """
         Predict values for a single data point or an RDD of points using
@@ -66,7 +67,7 @@ class TreeEnsembleModel(JavaModelWrapper):
         return self._java_model.toDebugString()
 
 
-class DecisionTreeModel(JavaModelWrapper):
+class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader):
     """
     .. note:: Experimental
@@ -103,6 +104,10 @@ class DecisionTreeModel(JavaModelWrapper):
         """ full model. """
         return self._java_model.toDebugString()
 
+    @classmethod
+    def _java_loader_class(cls):
+        return "org.apache.spark.mllib.tree.model.DecisionTreeModel"
+
 
 class DecisionTree(object):
     """
@@ -227,13 +232,17 @@ class DecisionTree(object):
 
 @inherit_doc
-class RandomForestModel(TreeEnsembleModel):
+class RandomForestModel(TreeEnsembleModel, JavaLoader):
     """
     .. note:: Experimental
 
     Represents a random forest model.
     """
 
+    @classmethod
+    def _java_loader_class(cls):
+        return "org.apache.spark.mllib.tree.model.RandomForestModel"
+
 
 class RandomForest(object):
     """
@@ -406,13 +415,17 @@ class RandomForest(object):
 
 @inherit_doc
-class GradientBoostedTreesModel(TreeEnsembleModel):
+class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader):
     """
     .. note:: Experimental
 
     Represents a gradient-boosted tree model.
     """
 
+    @classmethod
+    def _java_loader_class(cls):
+        return "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel"
+
 
 class GradientBoostedTrees(object):
     """

python/pyspark/mllib/util.py

@@ -18,7 +18,7 @@
 import numpy as np
 import warnings
 
-from pyspark.mllib.common import callMLlibFunc
+from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper, inherit_doc
 from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector
 from pyspark.mllib.regression import LabeledPoint
@@ -191,6 +191,17 @@ class Saveable(object):
         raise NotImplementedError
 
+
+@inherit_doc
+class JavaSaveable(Saveable):
+    """
+    Mixin for models that provide save() through their Scala
+    implementation.
+    """
+
+    def save(self, sc, path):
+        self._java_model.save(sc._jsc.sc(), path)
+
 
 class Loader(object):
     """
     Mixin for classes which can load saved models from files.
@@ -210,6 +221,7 @@ class Loader(object):
         raise NotImplemented
 
 
+@inherit_doc
 class JavaLoader(Loader):
     """
     Mixin for classes which can load saved models using its Scala
@@ -217,13 +229,30 @@ class JavaLoader(Loader):
     """
 
     @classmethod
-    def load(cls, sc, path):
+    def _java_loader_class(cls):
+        """
+        Returns the full class name of the Java loader. The default
+        implementation replaces "pyspark" by "org.apache.spark" in
+        the Python full class name.
+        """
         java_package = cls.__module__.replace("pyspark", "org.apache.spark")
-        java_class = ".".join([java_package, cls.__name__])
+        return ".".join([java_package, cls.__name__])
+
+    @classmethod
+    def _load_java(cls, sc, path):
+        """
+        Load a Java model from the given path.
+        """
+        java_class = cls._java_loader_class()
         java_obj = sc._jvm
         for name in java_class.split("."):
             java_obj = getattr(java_obj, name)
-        return cls(java_obj.load(sc._jsc.sc(), path))
+        return java_obj.load(sc._jsc.sc(), path)
+
+    @classmethod
+    def load(cls, sc, path):
+        java_model = cls._load_java(sc, path)
+        return cls(java_model)
 
 
 def _test():
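The loop in `_load_java` is just py4j attribute walking: it turns the dotted class name into a `sc._jvm.org.apache.spark...` lookup one attribute at a time. Written out by hand for one concrete case (a sketch, assuming a live `SparkContext` `sc` and a previously saved model):

```python
# Sketch: what _load_java does for DecisionTreeModel, spelled out
# (assumes a live SparkContext `sc` and a model saved at "myModelPath").
java_cls = sc._jvm.org.apache.spark.mllib.tree.model.DecisionTreeModel
java_model = java_cls.load(sc._jsc.sc(), "myModelPath")
# JavaLoader.load then wraps it: DecisionTreeModel(java_model)
```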