[SPARK-5094][MLlib] Add Python API for Gradient Boosted Trees

This PR is implementing the Gradient Boosted Trees for Python API.

Author: Kazuki Taniguchi <kazuki.t.1018@gmail.com>

Closes #3951 from kazk1018/gbt_for_py and squashes the following commits:

620d247 [Kazuki Taniguchi] [SPARK-5094][MLlib] Add Python API for Gradient Boosted Trees
This commit is contained in:
Kazuki Taniguchi 2015-01-30 00:39:44 -08:00 committed by Xiangrui Meng
parent dd4d84cf80
commit bc1fc9b60d
4 changed files with 318 additions and 56 deletions

View file

@ -0,0 +1,76 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Gradient boosted Trees classification and regression using MLlib.
"""
import sys
from pyspark.context import SparkContext
from pyspark.mllib.tree import GradientBoostedTrees
from pyspark.mllib.util import MLUtils
def testClassification(trainingData, testData):
# Train a GradientBoostedTrees model.
# Empty categoricalFeaturesInfo indicates all features are continuous.
model = GradientBoostedTrees.trainClassifier(trainingData, categoricalFeaturesInfo={},
numIterations=30, maxDepth=4)
# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() \
/ float(testData.count())
print('Test Error = ' + str(testErr))
print('Learned classification ensemble model:')
print(model.toDebugString())
def testRegression(trainingData, testData):
# Train a GradientBoostedTrees model.
# Empty categoricalFeaturesInfo indicates all features are continuous.
model = GradientBoostedTrees.trainRegressor(trainingData, categoricalFeaturesInfo={},
numIterations=30, maxDepth=4)
# Evaluate model on test instances and compute test error
predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() \
/ float(testData.count())
print('Test Mean Squared Error = ' + str(testMSE))
print('Learned regression ensemble model:')
print(model.toDebugString())
if __name__ == "__main__":
if len(sys.argv) > 1:
print >> sys.stderr, "Usage: gradient_boosted_trees"
exit(1)
sc = SparkContext(appName="PythonGradientBoostedTrees")
# Load and parse the data file into an RDD of LabeledPoint.
data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = data.randomSplit([0.7, 0.3])
print('\nRunning example of classification using GradientBoostedTrees\n')
testClassification(trainingData, testData)
print('\nRunning example of regression using GradientBoostedTrees\n')
testRegression(trainingData, testData)
sc.stop()

View file

@ -41,10 +41,11 @@ import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
import org.apache.spark.mllib.stat.correlation.CorrelationNames
import org.apache.spark.mllib.stat.test.ChiSqTestResult
import org.apache.spark.mllib.tree.{RandomForest, DecisionTree}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.{GradientBoostedTrees, RandomForest, DecisionTree}
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo, Strategy}
import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel}
import org.apache.spark.mllib.tree.loss.Losses
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel, RandomForestModel, DecisionTreeModel}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@ -532,6 +533,35 @@ class PythonMLLibAPI extends Serializable {
}
}
/**
* Java stub for Python mllib GradientBoostedTrees.train().
* This stub returns a handle to the Java object instead of the content of the Java object.
* Extra care needs to be taken in the Python code to ensure it gets freed on exit;
* see the Py4J documentation.
*/
def trainGradientBoostedTreesModel(
data: JavaRDD[LabeledPoint],
algoStr: String,
categoricalFeaturesInfo: JMap[Int, Int],
lossStr: String,
numIterations: Int,
learningRate: Double,
maxDepth: Int): GradientBoostedTreesModel = {
val boostingStrategy = BoostingStrategy.defaultParams(algoStr)
boostingStrategy.setLoss(Losses.fromString(lossStr))
boostingStrategy.setNumIterations(numIterations)
boostingStrategy.setLearningRate(learningRate)
boostingStrategy.treeStrategy.setMaxDepth(maxDepth)
boostingStrategy.treeStrategy.categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap
val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK)
try {
GradientBoostedTrees.train(cached, boostingStrategy)
} finally {
cached.unpersist(blocking = false)
}
}
/**
* Java stub for mllib Statistics.colStats(X: RDD[Vector]).
* TODO figure out return type.

View file

@ -169,7 +169,7 @@ class ListTests(PySparkTestCase):
def test_classification(self):
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
from pyspark.mllib.tree import DecisionTree
from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees
data = [
LabeledPoint(0.0, [1, 0, 0]),
LabeledPoint(1.0, [0, 1, 1]),
@ -198,18 +198,31 @@ class ListTests(PySparkTestCase):
self.assertTrue(nb_model.predict(features[3]) > 0)
categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories
dt_model = \
DecisionTree.trainClassifier(rdd, numClasses=2,
categoricalFeaturesInfo=categoricalFeaturesInfo)
dt_model = DecisionTree.trainClassifier(
rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo)
self.assertTrue(dt_model.predict(features[0]) <= 0)
self.assertTrue(dt_model.predict(features[1]) > 0)
self.assertTrue(dt_model.predict(features[2]) <= 0)
self.assertTrue(dt_model.predict(features[3]) > 0)
rf_model = RandomForest.trainClassifier(
rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100)
self.assertTrue(rf_model.predict(features[0]) <= 0)
self.assertTrue(rf_model.predict(features[1]) > 0)
self.assertTrue(rf_model.predict(features[2]) <= 0)
self.assertTrue(rf_model.predict(features[3]) > 0)
gbt_model = GradientBoostedTrees.trainClassifier(
rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
self.assertTrue(gbt_model.predict(features[0]) <= 0)
self.assertTrue(gbt_model.predict(features[1]) > 0)
self.assertTrue(gbt_model.predict(features[2]) <= 0)
self.assertTrue(gbt_model.predict(features[3]) > 0)
def test_regression(self):
from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \
RidgeRegressionWithSGD
from pyspark.mllib.tree import DecisionTree
from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees
data = [
LabeledPoint(-1.0, [0, -1]),
LabeledPoint(1.0, [0, 1]),
@ -238,13 +251,27 @@ class ListTests(PySparkTestCase):
self.assertTrue(rr_model.predict(features[3]) > 0)
categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories
dt_model = \
DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
dt_model = DecisionTree.trainRegressor(
rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
self.assertTrue(dt_model.predict(features[0]) <= 0)
self.assertTrue(dt_model.predict(features[1]) > 0)
self.assertTrue(dt_model.predict(features[2]) <= 0)
self.assertTrue(dt_model.predict(features[3]) > 0)
rf_model = RandomForest.trainRegressor(
rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100)
self.assertTrue(rf_model.predict(features[0]) <= 0)
self.assertTrue(rf_model.predict(features[1]) > 0)
self.assertTrue(rf_model.predict(features[2]) <= 0)
self.assertTrue(rf_model.predict(features[3]) > 0)
gbt_model = GradientBoostedTrees.trainRegressor(
rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
self.assertTrue(gbt_model.predict(features[0]) <= 0)
self.assertTrue(gbt_model.predict(features[1]) > 0)
self.assertTrue(gbt_model.predict(features[2]) <= 0)
self.assertTrue(gbt_model.predict(features[3]) > 0)
class StatTests(PySparkTestCase):
# SPARK-4023

View file

@ -24,16 +24,48 @@ from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
from pyspark.mllib.linalg import _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
__all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel', 'RandomForest']
__all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel',
'RandomForest', 'GradientBoostedTrees']
class TreeEnsembleModel(JavaModelWrapper):
def predict(self, x):
"""
Predict values for a single data point or an RDD of points using
the model trained.
"""
if isinstance(x, RDD):
return self.call("predict", x.map(_convert_to_vector))
else:
return self.call("predict", _convert_to_vector(x))
def numTrees(self):
"""
Get number of trees in ensemble.
"""
return self.call("numTrees")
def totalNumNodes(self):
"""
Get total number of nodes, summed over all trees in the ensemble.
"""
return self.call("totalNumNodes")
def __repr__(self):
""" Summary of model """
return self._java_model.toString()
def toDebugString(self):
""" Full model """
return self._java_model.toDebugString()
class DecisionTreeModel(JavaModelWrapper):
"""
A decision tree model for classification or regression.
.. note:: Experimental
EXPERIMENTAL: This is an experimental API.
It will probably be modified in future.
A decision tree model for classification or regression.
"""
def predict(self, x):
"""
@ -64,12 +96,10 @@ class DecisionTreeModel(JavaModelWrapper):
class DecisionTree(object):
"""
Learning algorithm for a decision tree model for classification or regression.
.. note:: Experimental
EXPERIMENTAL: This is an experimental API.
It will probably be modified in future.
Learning algorithm for a decision tree model for classification or regression.
"""
@classmethod
@ -186,51 +216,19 @@ class DecisionTree(object):
impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
class RandomForestModel(JavaModelWrapper):
class RandomForestModel(TreeEnsembleModel):
"""
.. note:: Experimental
Represents a random forest model.
EXPERIMENTAL: This is an experimental API.
It will probably be modified in future.
"""
def predict(self, x):
"""
Predict values for a single data point or an RDD of points using
the model trained.
"""
if isinstance(x, RDD):
return self.call("predict", x.map(_convert_to_vector))
else:
return self.call("predict", _convert_to_vector(x))
def numTrees(self):
"""
Get number of trees in forest.
"""
return self.call("numTrees")
def totalNumNodes(self):
"""
Get total number of nodes, summed over all trees in the forest.
"""
return self.call("totalNumNodes")
def __repr__(self):
""" Summary of model """
return self._java_model.toString()
def toDebugString(self):
""" Full model """
return self._java_model.toDebugString()
class RandomForest(object):
"""
Learning algorithm for a random forest model for classification or regression.
.. note:: Experimental
EXPERIMENTAL: This is an experimental API.
It will probably be modified in future.
Learning algorithm for a random forest model for classification or regression.
"""
supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird")
@ -383,6 +381,137 @@ class RandomForest(object):
featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
class GradientBoostedTreesModel(TreeEnsembleModel):
"""
.. note:: Experimental
Represents a gradient-boosted tree model.
"""
class GradientBoostedTrees(object):
"""
.. note:: Experimental
Learning algorithm for a gradient boosted trees model for classification or regression.
"""
@classmethod
def _train(cls, data, algo, categoricalFeaturesInfo,
loss, numIterations, learningRate, maxDepth):
first = data.first()
assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint"
model = callMLlibFunc("trainGradientBoostedTreesModel", data, algo, categoricalFeaturesInfo,
loss, numIterations, learningRate, maxDepth)
return GradientBoostedTreesModel(model)
@classmethod
def trainClassifier(cls, data, categoricalFeaturesInfo,
loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3):
"""
Method to train a gradient-boosted trees model for classification.
:param data: Training dataset: RDD of LabeledPoint. Labels should take values {0, 1}.
:param categoricalFeaturesInfo: Map storing arity of categorical
features. E.g., an entry (n -> k) indicates that feature
n is categorical with k categories indexed from 0:
{0, 1, ..., k-1}.
:param loss: Loss function used for minimization during gradient boosting.
Supported: {"logLoss" (default), "leastSquaresError", "leastAbsoluteError"}.
:param numIterations: Number of iterations of boosting.
(default: 100)
:param learningRate: Learning rate for shrinking the contribution of each estimator.
The learning rate should be between in the interval (0, 1]
(default: 0.1)
:param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1
leaf node; depth 1 means 1 internal node + 2 leaf nodes.
(default: 3)
:return: GradientBoostedTreesModel that can be used for prediction
Example usage:
>>> from pyspark.mllib.regression import LabeledPoint
>>> from pyspark.mllib.tree import GradientBoostedTrees
>>>
>>> data = [
... LabeledPoint(0.0, [0.0]),
... LabeledPoint(0.0, [1.0]),
... LabeledPoint(1.0, [2.0]),
... LabeledPoint(1.0, [3.0])
... ]
>>>
>>> model = GradientBoostedTrees.trainClassifier(sc.parallelize(data), {})
>>> model.numTrees()
100
>>> model.totalNumNodes()
300
>>> print model, # it already has newline
TreeEnsembleModel classifier with 100 trees
>>> model.predict([2.0])
1.0
>>> model.predict([0.0])
0.0
>>> rdd = sc.parallelize([[2.0], [0.0]])
>>> model.predict(rdd).collect()
[1.0, 0.0]
"""
return cls._train(data, "classification", categoricalFeaturesInfo,
loss, numIterations, learningRate, maxDepth)
@classmethod
def trainRegressor(cls, data, categoricalFeaturesInfo,
loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3):
"""
Method to train a gradient-boosted trees model for regression.
:param data: Training dataset: RDD of LabeledPoint. Labels are
real numbers.
:param categoricalFeaturesInfo: Map storing arity of categorical
features. E.g., an entry (n -> k) indicates that feature
n is categorical with k categories indexed from 0:
{0, 1, ..., k-1}.
:param loss: Loss function used for minimization during gradient boosting.
Supported: {"logLoss" (default), "leastSquaresError", "leastAbsoluteError"}.
:param numIterations: Number of iterations of boosting.
(default: 100)
:param learningRate: Learning rate for shrinking the contribution of each estimator.
The learning rate should be between in the interval (0, 1]
(default: 0.1)
:param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1
leaf node; depth 1 means 1 internal node + 2 leaf nodes.
(default: 3)
:return: GradientBoostedTreesModel that can be used for prediction
Example usage:
>>> from pyspark.mllib.regression import LabeledPoint
>>> from pyspark.mllib.tree import GradientBoostedTrees
>>> from pyspark.mllib.linalg import SparseVector
>>>
>>> sparse_data = [
... LabeledPoint(0.0, SparseVector(2, {0: 1.0})),
... LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
... LabeledPoint(0.0, SparseVector(2, {0: 1.0})),
... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
... ]
>>>
>>> model = GradientBoostedTrees.trainRegressor(sc.parallelize(sparse_data), {})
>>> model.numTrees()
100
>>> model.totalNumNodes()
102
>>> model.predict(SparseVector(2, {1: 1.0}))
1.0
>>> model.predict(SparseVector(2, {0: 1.0}))
0.0
>>> rdd = sc.parallelize([[0.0, 1.0], [1.0, 0.0]])
>>> model.predict(rdd).collect()
[1.0, 0.0]
"""
return cls._train(data, "regression", categoricalFeaturesInfo,
loss, numIterations, learningRate, maxDepth)
def _test():
import doctest
globs = globals().copy()