---
layout: global
title: <a href="mllib-guide.html">MLlib</a> - Decision Tree
---

* Table of contents
{:toc}

Decision trees and their ensembles are popular methods for the machine learning tasks of
classification and regression. Decision trees are widely used since they are easy to interpret,
handle categorical variables, extend to the multiclass classification setting, do not require
feature scaling, and are able to capture nonlinearities and feature interactions. Tree ensemble
algorithms such as decision forests and boosting are among the top performers for classification and
regression tasks.

## Basic algorithm

The decision tree is a greedy algorithm that performs a recursive binary partitioning of the feature
space by selecting, at each tree node, the split that maximizes the information gain. In other words,
the split used at each tree node is chosen from the set
`$\underset{s}{\operatorname{argmax}} IG(D,s)$`, where `$IG(D,s)$` is the information gain when a
split `$s$` is applied to a dataset `$D$`.

### Node impurity and information gain

The *node impurity* is a measure of the homogeneity of the labels at the node. The current
implementation provides two impurity measures for classification (Gini impurity and entropy) and one
impurity measure for regression (variance).

<table class="table">
|
|
<thead>
|
|
<tr><th>Impurity</th><th>Task</th><th>Formula</th><th>Description</th></tr>
|
|
</thead>
|
|
<tbody>
|
|
<tr>
|
|
<td>Gini impurity</td>
|
|
<td>Classification</td>
|
|
<td>$\sum_{i=1}^{M} f_i(1-f_i)$</td><td>$f_i$ is the frequency of label $i$ at a node and $M$ is the number of unique labels.</td>
|
|
</tr>
|
|
<tr>
|
|
<td>Entropy</td>
|
|
<td>Classification</td>
|
|
<td>$\sum_{i=1}^{M} -f_ilog(f_i)$</td><td>$f_i$ is the frequency of label $i$ at a node and $M$ is the number of unique labels.</td>
|
|
</tr>
|
|
<tr>
|
|
<td>Variance</td>
|
|
<td>Regression</td>
|
|
<td>$\frac{1}{n} \sum_{i=1}^{N} (x_i - \mu)^2$</td><td>$y_i$ is label for an instance,
|
|
$N$ is the number of instances and $\mu$ is the mean given by $\frac{1}{N} \sum_{i=1}^n x_i$.</td>
|
|
</tr>
|
|
</tbody>
|
|
</table>
|
|
|
|
The *information gain* is the difference between the parent node impurity and the weighted sum of the
two child node impurities. Assuming that a split `$s$` partitions the dataset `$D$` of size `$N$`
into two datasets `$D_{left}$` and `$D_{right}$` of sizes `$N_{left}$` and `$N_{right}$`,
respectively:

`$IG(D,s) = Impurity(D) - \frac{N_{left}}{N} Impurity(D_{left}) - \frac{N_{right}}{N} Impurity(D_{right})$`

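As a small worked example (using the formulas above, not output from MLlib), consider a node whose
dataset `$D$` holds 10 instances, 4 with label 0 and 6 with label 1, so its Gini impurity is
`$0.4 \cdot 0.6 + 0.6 \cdot 0.4 = 0.48$`. A split that sends 4 instances (3 with label 0) left and
6 instances (5 with label 1) right yields child impurities of `$0.375$` and `$5/18 \approx 0.278$`,
so `$IG(D,s) \approx 0.48 - 0.4 \cdot 0.375 - 0.6 \cdot 0.278 \approx 0.163$`.
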
### Split candidates

**Continuous features**

For small datasets in single-machine implementations, the split candidates for each continuous
feature are typically the unique values for the feature. Some implementations sort the feature
values and then use the ordered unique values as split candidates for faster tree calculations.

Finding ordered unique feature values is computationally intensive for large distributed
datasets. One can get an approximate set of split candidates by performing a quantile calculation
over a sampled fraction of the data. The ordered splits create "bins", and the maximum number of such
bins can be specified using the `maxBins` parameter.

Note that the number of bins cannot be greater than the number of instances `$N$` (a rare scenario
since the default `maxBins` value is 100). The tree algorithm automatically reduces the number of
bins if the condition is not satisfied.

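The snippet below is a minimal, hypothetical sketch of the quantile idea for a single continuous
feature. It is not MLlib's internal implementation; the helper name `candidateSplits`, its
parameters, and the sampling fraction are all illustrative assumptions, and it uses a plain Scala
collection where the real computation would run over a sampled distributed dataset.

{% highlight scala %}
import scala.util.Random

// Approximate split candidates for one continuous feature: sample a fraction of the
// values, sort the sample, and take roughly equi-frequency thresholds so that at most
// maxBins bins are created.
def candidateSplits(values: Seq[Double], maxBins: Int, fraction: Double = 0.1): Seq[Double] = {
  val sampleSize = math.max(1, (values.size * fraction).toInt)
  val sampled = Random.shuffle(values).take(sampleSize).sorted.toVector
  val stride = math.max(1, sampled.size / maxBins)
  (stride until sampled.size by stride).map(sampled(_)).distinct
}

// e.g. candidateSplits(featureColumn, maxBins = 100) for the values of one feature.
{% endhighlight %}
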
**Categorical features**

For a categorical feature with `$M$` possible values, one could come up with `$2^{M-1}-1$` split
candidates. However, for binary classification, the number of split candidates can be reduced to
`$M-1$` by ordering the categorical feature values by the proportion of labels falling in one of the
two classes (see Section 9.2.4 in
[The Elements of Statistical Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for
details). For example, in a binary classification problem with one categorical feature that has three
categories A, B and C, whose corresponding proportions of label 1 are 0.2, 0.6 and 0.4, the
categorical feature values are ordered as A followed by C followed by B, i.e., A, C, B. The two split
candidates are A \| C, B and A, C \| B, where \| denotes the split.

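As a tiny illustration of this ordering trick (the values match the example above; this is not MLlib
code, and the names are made up for the example):

{% highlight scala %}
// Proportion of label 1 for each category (values from the example above).
val labelOneFraction = Map("A" -> 0.2, "B" -> 0.6, "C" -> 0.4)

// Order categories by that proportion: Vector(A, C, B).
val ordered = labelOneFraction.toSeq.sortBy(_._2).map(_._1).toVector

// The M - 1 split candidates are the prefixes of the ordering:
// (Vector(A), Vector(C, B)) and (Vector(A, C), Vector(B)).
val splits = (1 until ordered.size).map(i => (ordered.take(i), ordered.drop(i)))
{% endhighlight %}
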
### Stopping rule

The recursive tree construction is stopped at a node when one of the following two conditions is met:

1. The node depth is equal to the `maxDepth` training parameter.
2. No split candidate leads to an information gain at the node.

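Putting together the greedy recursion, the Gini impurity, the information gain, the small-data split
candidates, and the stopping rules described above, here is a minimal single-machine sketch for a
classification tree on one continuous feature. It is only an illustration of the algorithm and is
unrelated to MLlib's distributed implementation; `Point`, `Tree`, and the helper names are made up
for this example.

{% highlight scala %}
// Toy single-machine version of the algorithm above (illustration only).
case class Point(feature: Double, label: Double)

sealed trait Tree
case class Leaf(prediction: Double) extends Tree
case class Node(threshold: Double, left: Tree, right: Tree) extends Tree

// Gini impurity of a set of labeled points: sum_i f_i (1 - f_i).
def gini(points: Seq[Point]): Double =
  points.groupBy(_.label).values.map(_.size.toDouble / points.size).map(f => f * (1 - f)).sum

// Information gain of splitting at `threshold` (feature <= threshold goes left).
def informationGain(points: Seq[Point], threshold: Double): Double = {
  val (l, r) = points.partition(_.feature <= threshold)
  if (l.isEmpty || r.isEmpty) 0.0
  else gini(points) - (l.size * gini(l) + r.size * gini(r)) / points.size
}

def buildTree(points: Seq[Point], depth: Int, maxDepth: Int): Tree = {
  val majorityLabel = points.groupBy(_.label).maxBy(_._2.size)._1
  // Small-data split candidates: the unique feature values (see "Split candidates").
  val (bestThreshold, bestGain) =
    points.map(_.feature).distinct.map(t => (t, informationGain(points, t))).maxBy(_._2)
  if (depth == maxDepth || bestGain <= 0.0) {
    Leaf(majorityLabel)  // stopping rules 1 and 2
  } else {
    val (l, r) = points.partition(_.feature <= bestThreshold)
    Node(bestThreshold, buildTree(l, depth + 1, maxDepth), buildTree(r, depth + 1, maxDepth))
  }
}
{% endhighlight %}

Calling `buildTree(trainingPoints, 0, maxDepth)` on a small in-memory dataset would return a tree
analogous, in spirit, to the models trained in the examples below.
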
### Practical limitations

1. The tree implementation stores an `Array[Double]` of size *O(#features \* #splits \* 2^maxDepth)*
   in memory for aggregating histograms over partitions. The current implementation might not scale
   to very deep trees since the memory requirement grows exponentially with tree depth.
2. The implemented algorithm reads both sparse and dense data. However, it is not optimized for
   sparse input.
3. Python is not supported in this release.

We are planning to solve these problems in the near future. Please drop us a line if you encounter
any issues.

## Examples

### Classification

The example below demonstrates how to load a CSV file, parse it as an RDD of `LabeledPoint`, and then
perform classification with a decision tree, using Gini impurity as the impurity measure and a
maximum tree depth of 5. The training error is calculated to measure the algorithm's accuracy.

<div class="codetabs">
|
|
<div data-lang="scala">
|
|
{% highlight scala %}
|
|
import org.apache.spark.SparkContext
|
|
import org.apache.spark.mllib.tree.DecisionTree
|
|
import org.apache.spark.mllib.regression.LabeledPoint
|
|
import org.apache.spark.mllib.linalg.Vectors
|
|
import org.apache.spark.mllib.tree.configuration.Algo._
|
|
import org.apache.spark.mllib.tree.impurity.Gini
|
|
|
|
// Load and parse the data file
|
|
val data = sc.textFile("mllib/data/sample_tree_data.csv")
|
|
val parsedData = data.map { line =>
|
|
val parts = line.split(',').map(_.toDouble)
|
|
LabeledPoint(parts(0), Vectors.dense(parts.tail))
|
|
}
|
|
|
|
// Run training algorithm to build the model
|
|
val maxDepth = 5
|
|
val model = DecisionTree.train(parsedData, Classification, Gini, maxDepth)
|
|
|
|
// Evaluate model on training examples and compute training error
|
|
val labelAndPreds = parsedData.map { point =>
|
|
val prediction = model.predict(point.features)
|
|
(point.label, prediction)
|
|
}
|
|
val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / parsedData.count
|
|
println("Training Error = " + trainErr)
|
|
{% endhighlight %}
|
|
</div>
|
|
</div>
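If the training data is already in LIBSVM format, it can be loaded into an `RDD[LabeledPoint]` with
`MLUtils.loadLibSVMFile` instead of being parsed by hand; the file path below is only a placeholder
for a LIBSVM-formatted file of your own.

{% highlight scala %}
import org.apache.spark.mllib.util.MLUtils

// Load training data in LIBSVM format into an RDD[LabeledPoint].
// The path is a placeholder; point it at your own LIBSVM-formatted file.
val examples = MLUtils.loadLibSVMFile(sc, "mllib/data/sample_libsvm_data.txt")
{% endhighlight %}
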

### Regression

The example below demonstrates how to load a CSV file, parse it as an RDD of `LabeledPoint`, and then
perform regression with a decision tree, using variance as the impurity measure and a maximum tree
depth of 5. The mean squared error (MSE) is computed at the end to evaluate
[goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit).

<div class="codetabs">
|
|
<div data-lang="scala">
|
|
{% highlight scala %}
|
|
import org.apache.spark.SparkContext
|
|
import org.apache.spark.mllib.tree.DecisionTree
|
|
import org.apache.spark.mllib.regression.LabeledPoint
|
|
import org.apache.spark.mllib.linalg.Vectors
|
|
import org.apache.spark.mllib.tree.configuration.Algo._
|
|
import org.apache.spark.mllib.tree.impurity.Variance
|
|
|
|
// Load and parse the data file
|
|
val data = sc.textFile("mllib/data/sample_tree_data.csv")
|
|
val parsedData = data.map { line =>
|
|
val parts = line.split(',').map(_.toDouble)
|
|
LabeledPoint(parts(0), Vectors.dense(parts.tail))
|
|
}
|
|
|
|
// Run training algorithm to build the model
|
|
val maxDepth = 5
|
|
val model = DecisionTree.train(parsedData, Regression, Variance, maxDepth)
|
|
|
|
// Evaluate model on training examples and compute training error
|
|
val valuesAndPreds = parsedData.map { point =>
|
|
val prediction = model.predict(point.features)
|
|
(point.label, prediction)
|
|
}
|
|
val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2)}.reduce(_ + _)/valuesAndPreds.count
|
|
println("training Mean Squared Error = " + MSE)
|
|
{% endhighlight %}
|
|
</div>
|
|
</div>
|