[SPARK-11552][DOCS][Replaced example code in ml-decision-tree.md using include_example]
I have tested it on my local, it is working fine, please review Author: sachin aggarwal <different.sachin@gmail.com> Closes #9539 from agsachin/SPARK-11552-real.
This commit is contained in:
parent
5039a49b63
commit
51d41e4b1a
|
@ -118,196 +118,24 @@ We use two feature transformers to prepare the data; these help index categories
|
||||||
|
|
||||||
More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.classification.DecisionTreeClassifier).
|
More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.classification.DecisionTreeClassifier).
|
||||||
|
|
||||||
{% highlight scala %}
|
{% include_example scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala %}
|
||||||
import org.apache.spark.ml.Pipeline
|
|
||||||
import org.apache.spark.ml.classification.DecisionTreeClassifier
|
|
||||||
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
|
|
||||||
import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer}
|
|
||||||
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
|
|
||||||
import org.apache.spark.mllib.util.MLUtils
|
|
||||||
|
|
||||||
// Load and parse the data file, converting it to a DataFrame.
|
|
||||||
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
|
|
||||||
|
|
||||||
// Index labels, adding metadata to the label column.
|
|
||||||
// Fit on whole dataset to include all labels in index.
|
|
||||||
val labelIndexer = new StringIndexer()
|
|
||||||
.setInputCol("label")
|
|
||||||
.setOutputCol("indexedLabel")
|
|
||||||
.fit(data)
|
|
||||||
// Automatically identify categorical features, and index them.
|
|
||||||
val featureIndexer = new VectorIndexer()
|
|
||||||
.setInputCol("features")
|
|
||||||
.setOutputCol("indexedFeatures")
|
|
||||||
.setMaxCategories(4) // features with > 4 distinct values are treated as continuous
|
|
||||||
.fit(data)
|
|
||||||
|
|
||||||
// Split the data into training and test sets (30% held out for testing)
|
|
||||||
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
|
|
||||||
|
|
||||||
// Train a DecisionTree model.
|
|
||||||
val dt = new DecisionTreeClassifier()
|
|
||||||
.setLabelCol("indexedLabel")
|
|
||||||
.setFeaturesCol("indexedFeatures")
|
|
||||||
|
|
||||||
// Convert indexed labels back to original labels.
|
|
||||||
val labelConverter = new IndexToString()
|
|
||||||
.setInputCol("prediction")
|
|
||||||
.setOutputCol("predictedLabel")
|
|
||||||
.setLabels(labelIndexer.labels)
|
|
||||||
|
|
||||||
// Chain indexers and tree in a Pipeline
|
|
||||||
val pipeline = new Pipeline()
|
|
||||||
.setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))
|
|
||||||
|
|
||||||
// Train model. This also runs the indexers.
|
|
||||||
val model = pipeline.fit(trainingData)
|
|
||||||
|
|
||||||
// Make predictions.
|
|
||||||
val predictions = model.transform(testData)
|
|
||||||
|
|
||||||
// Select example rows to display.
|
|
||||||
predictions.select("predictedLabel", "label", "features").show(5)
|
|
||||||
|
|
||||||
// Select (prediction, true label) and compute test error
|
|
||||||
val evaluator = new MulticlassClassificationEvaluator()
|
|
||||||
.setLabelCol("indexedLabel")
|
|
||||||
.setPredictionCol("prediction")
|
|
||||||
.setMetricName("precision")
|
|
||||||
val accuracy = evaluator.evaluate(predictions)
|
|
||||||
println("Test Error = " + (1.0 - accuracy))
|
|
||||||
|
|
||||||
val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel]
|
|
||||||
println("Learned classification tree model:\n" + treeModel.toDebugString)
|
|
||||||
{% endhighlight %}
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div data-lang="java" markdown="1">
|
<div data-lang="java" markdown="1">
|
||||||
|
|
||||||
More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/classification/DecisionTreeClassifier.html).
|
More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/classification/DecisionTreeClassifier.html).
|
||||||
|
|
||||||
{% highlight java %}
|
{% include_example java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java %}
|
||||||
import org.apache.spark.ml.Pipeline;
|
|
||||||
import org.apache.spark.ml.PipelineModel;
|
|
||||||
import org.apache.spark.ml.PipelineStage;
|
|
||||||
import org.apache.spark.ml.classification.DecisionTreeClassifier;
|
|
||||||
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
|
|
||||||
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
|
|
||||||
import org.apache.spark.ml.feature.*;
|
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
|
||||||
import org.apache.spark.mllib.util.MLUtils;
|
|
||||||
import org.apache.spark.rdd.RDD;
|
|
||||||
import org.apache.spark.sql.DataFrame;
|
|
||||||
|
|
||||||
// Load and parse the data file, converting it to a DataFrame.
|
|
||||||
RDD<LabeledPoint> rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt");
|
|
||||||
DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class);
|
|
||||||
|
|
||||||
// Index labels, adding metadata to the label column.
|
|
||||||
// Fit on whole dataset to include all labels in index.
|
|
||||||
StringIndexerModel labelIndexer = new StringIndexer()
|
|
||||||
.setInputCol("label")
|
|
||||||
.setOutputCol("indexedLabel")
|
|
||||||
.fit(data);
|
|
||||||
// Automatically identify categorical features, and index them.
|
|
||||||
VectorIndexerModel featureIndexer = new VectorIndexer()
|
|
||||||
.setInputCol("features")
|
|
||||||
.setOutputCol("indexedFeatures")
|
|
||||||
.setMaxCategories(4) // features with > 4 distinct values are treated as continuous
|
|
||||||
.fit(data);
|
|
||||||
|
|
||||||
// Split the data into training and test sets (30% held out for testing)
|
|
||||||
DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3});
|
|
||||||
DataFrame trainingData = splits[0];
|
|
||||||
DataFrame testData = splits[1];
|
|
||||||
|
|
||||||
// Train a DecisionTree model.
|
|
||||||
DecisionTreeClassifier dt = new DecisionTreeClassifier()
|
|
||||||
.setLabelCol("indexedLabel")
|
|
||||||
.setFeaturesCol("indexedFeatures");
|
|
||||||
|
|
||||||
// Convert indexed labels back to original labels.
|
|
||||||
IndexToString labelConverter = new IndexToString()
|
|
||||||
.setInputCol("prediction")
|
|
||||||
.setOutputCol("predictedLabel")
|
|
||||||
.setLabels(labelIndexer.labels());
|
|
||||||
|
|
||||||
// Chain indexers and tree in a Pipeline
|
|
||||||
Pipeline pipeline = new Pipeline()
|
|
||||||
.setStages(new PipelineStage[] {labelIndexer, featureIndexer, dt, labelConverter});
|
|
||||||
|
|
||||||
// Train model. This also runs the indexers.
|
|
||||||
PipelineModel model = pipeline.fit(trainingData);
|
|
||||||
|
|
||||||
// Make predictions.
|
|
||||||
DataFrame predictions = model.transform(testData);
|
|
||||||
|
|
||||||
// Select example rows to display.
|
|
||||||
predictions.select("predictedLabel", "label", "features").show(5);
|
|
||||||
|
|
||||||
// Select (prediction, true label) and compute test error
|
|
||||||
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
|
|
||||||
.setLabelCol("indexedLabel")
|
|
||||||
.setPredictionCol("prediction")
|
|
||||||
.setMetricName("precision");
|
|
||||||
double accuracy = evaluator.evaluate(predictions);
|
|
||||||
System.out.println("Test Error = " + (1.0 - accuracy));
|
|
||||||
|
|
||||||
DecisionTreeClassificationModel treeModel =
|
|
||||||
(DecisionTreeClassificationModel)(model.stages()[2]);
|
|
||||||
System.out.println("Learned classification tree model:\n" + treeModel.toDebugString());
|
|
||||||
{% endhighlight %}
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div data-lang="python" markdown="1">
|
<div data-lang="python" markdown="1">
|
||||||
|
|
||||||
More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.classification.DecisionTreeClassifier).
|
More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.classification.DecisionTreeClassifier).
|
||||||
|
|
||||||
{% highlight python %}
|
{% include_example python/ml/decision_tree_classification_example.py %}
|
||||||
from pyspark.ml import Pipeline
|
|
||||||
from pyspark.ml.classification import DecisionTreeClassifier
|
|
||||||
from pyspark.ml.feature import StringIndexer, VectorIndexer
|
|
||||||
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
|
|
||||||
from pyspark.mllib.util import MLUtils
|
|
||||||
|
|
||||||
# Load and parse the data file, converting it to a DataFrame.
|
|
||||||
data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
|
|
||||||
|
|
||||||
# Index labels, adding metadata to the label column.
|
|
||||||
# Fit on whole dataset to include all labels in index.
|
|
||||||
labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data)
|
|
||||||
# Automatically identify categorical features, and index them.
|
|
||||||
# We specify maxCategories so features with > 4 distinct values are treated as continuous.
|
|
||||||
featureIndexer =\
|
|
||||||
VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data)
|
|
||||||
|
|
||||||
# Split the data into training and test sets (30% held out for testing)
|
|
||||||
(trainingData, testData) = data.randomSplit([0.7, 0.3])
|
|
||||||
|
|
||||||
# Train a DecisionTree model.
|
|
||||||
dt = DecisionTreeClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures")
|
|
||||||
|
|
||||||
# Chain indexers and tree in a Pipeline
|
|
||||||
pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dt])
|
|
||||||
|
|
||||||
# Train model. This also runs the indexers.
|
|
||||||
model = pipeline.fit(trainingData)
|
|
||||||
|
|
||||||
# Make predictions.
|
|
||||||
predictions = model.transform(testData)
|
|
||||||
|
|
||||||
# Select example rows to display.
|
|
||||||
predictions.select("prediction", "indexedLabel", "features").show(5)
|
|
||||||
|
|
||||||
# Select (prediction, true label) and compute test error
|
|
||||||
evaluator = MulticlassClassificationEvaluator(
|
|
||||||
labelCol="indexedLabel", predictionCol="prediction", metricName="precision")
|
|
||||||
accuracy = evaluator.evaluate(predictions)
|
|
||||||
print "Test Error = %g" % (1.0 - accuracy)
|
|
||||||
|
|
||||||
treeModel = model.stages[2]
|
|
||||||
print treeModel # summary only
|
|
||||||
{% endhighlight %}
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
@ -323,171 +151,21 @@ We use a feature transformer to index categorical features, adding metadata to t
|
||||||
|
|
||||||
More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.regression.DecisionTreeRegressor).
|
More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.regression.DecisionTreeRegressor).
|
||||||
|
|
||||||
{% highlight scala %}
|
{% include_example scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala %}
|
||||||
import org.apache.spark.ml.Pipeline
|
|
||||||
import org.apache.spark.ml.regression.DecisionTreeRegressor
|
|
||||||
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
|
|
||||||
import org.apache.spark.ml.feature.VectorIndexer
|
|
||||||
import org.apache.spark.ml.evaluation.RegressionEvaluator
|
|
||||||
import org.apache.spark.mllib.util.MLUtils
|
|
||||||
|
|
||||||
// Load and parse the data file, converting it to a DataFrame.
|
|
||||||
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
|
|
||||||
|
|
||||||
// Automatically identify categorical features, and index them.
|
|
||||||
// Here, we treat features with > 4 distinct values as continuous.
|
|
||||||
val featureIndexer = new VectorIndexer()
|
|
||||||
.setInputCol("features")
|
|
||||||
.setOutputCol("indexedFeatures")
|
|
||||||
.setMaxCategories(4)
|
|
||||||
.fit(data)
|
|
||||||
|
|
||||||
// Split the data into training and test sets (30% held out for testing)
|
|
||||||
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
|
|
||||||
|
|
||||||
// Train a DecisionTree model.
|
|
||||||
val dt = new DecisionTreeRegressor()
|
|
||||||
.setLabelCol("label")
|
|
||||||
.setFeaturesCol("indexedFeatures")
|
|
||||||
|
|
||||||
// Chain indexer and tree in a Pipeline
|
|
||||||
val pipeline = new Pipeline()
|
|
||||||
.setStages(Array(featureIndexer, dt))
|
|
||||||
|
|
||||||
// Train model. This also runs the indexer.
|
|
||||||
val model = pipeline.fit(trainingData)
|
|
||||||
|
|
||||||
// Make predictions.
|
|
||||||
val predictions = model.transform(testData)
|
|
||||||
|
|
||||||
// Select example rows to display.
|
|
||||||
predictions.select("prediction", "label", "features").show(5)
|
|
||||||
|
|
||||||
// Select (prediction, true label) and compute test error
|
|
||||||
val evaluator = new RegressionEvaluator()
|
|
||||||
.setLabelCol("label")
|
|
||||||
.setPredictionCol("prediction")
|
|
||||||
.setMetricName("rmse")
|
|
||||||
val rmse = evaluator.evaluate(predictions)
|
|
||||||
println("Root Mean Squared Error (RMSE) on test data = " + rmse)
|
|
||||||
|
|
||||||
val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel]
|
|
||||||
println("Learned regression tree model:\n" + treeModel.toDebugString)
|
|
||||||
{% endhighlight %}
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div data-lang="java" markdown="1">
|
<div data-lang="java" markdown="1">
|
||||||
|
|
||||||
More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/regression/DecisionTreeRegressor.html).
|
More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/regression/DecisionTreeRegressor.html).
|
||||||
|
|
||||||
{% highlight java %}
|
{% include_example java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java %}
|
||||||
import org.apache.spark.ml.Pipeline;
|
|
||||||
import org.apache.spark.ml.PipelineModel;
|
|
||||||
import org.apache.spark.ml.PipelineStage;
|
|
||||||
import org.apache.spark.ml.evaluation.RegressionEvaluator;
|
|
||||||
import org.apache.spark.ml.feature.VectorIndexer;
|
|
||||||
import org.apache.spark.ml.feature.VectorIndexerModel;
|
|
||||||
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
|
|
||||||
import org.apache.spark.ml.regression.DecisionTreeRegressor;
|
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
|
||||||
import org.apache.spark.mllib.util.MLUtils;
|
|
||||||
import org.apache.spark.rdd.RDD;
|
|
||||||
import org.apache.spark.sql.DataFrame;
|
|
||||||
|
|
||||||
// Load and parse the data file, converting it to a DataFrame.
|
|
||||||
RDD<LabeledPoint> rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt");
|
|
||||||
DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class);
|
|
||||||
|
|
||||||
// Automatically identify categorical features, and index them.
|
|
||||||
// Set maxCategories so features with > 4 distinct values are treated as continuous.
|
|
||||||
VectorIndexerModel featureIndexer = new VectorIndexer()
|
|
||||||
.setInputCol("features")
|
|
||||||
.setOutputCol("indexedFeatures")
|
|
||||||
.setMaxCategories(4)
|
|
||||||
.fit(data);
|
|
||||||
|
|
||||||
// Split the data into training and test sets (30% held out for testing)
|
|
||||||
DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3});
|
|
||||||
DataFrame trainingData = splits[0];
|
|
||||||
DataFrame testData = splits[1];
|
|
||||||
|
|
||||||
// Train a DecisionTree model.
|
|
||||||
DecisionTreeRegressor dt = new DecisionTreeRegressor()
|
|
||||||
.setFeaturesCol("indexedFeatures");
|
|
||||||
|
|
||||||
// Chain indexer and tree in a Pipeline
|
|
||||||
Pipeline pipeline = new Pipeline()
|
|
||||||
.setStages(new PipelineStage[] {featureIndexer, dt});
|
|
||||||
|
|
||||||
// Train model. This also runs the indexer.
|
|
||||||
PipelineModel model = pipeline.fit(trainingData);
|
|
||||||
|
|
||||||
// Make predictions.
|
|
||||||
DataFrame predictions = model.transform(testData);
|
|
||||||
|
|
||||||
// Select example rows to display.
|
|
||||||
predictions.select("label", "features").show(5);
|
|
||||||
|
|
||||||
// Select (prediction, true label) and compute test error
|
|
||||||
RegressionEvaluator evaluator = new RegressionEvaluator()
|
|
||||||
.setLabelCol("label")
|
|
||||||
.setPredictionCol("prediction")
|
|
||||||
.setMetricName("rmse");
|
|
||||||
double rmse = evaluator.evaluate(predictions);
|
|
||||||
System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse);
|
|
||||||
|
|
||||||
DecisionTreeRegressionModel treeModel =
|
|
||||||
(DecisionTreeRegressionModel)(model.stages()[1]);
|
|
||||||
System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
|
|
||||||
{% endhighlight %}
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div data-lang="python" markdown="1">
|
<div data-lang="python" markdown="1">
|
||||||
|
|
||||||
More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.regression.DecisionTreeRegressor).
|
More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.regression.DecisionTreeRegressor).
|
||||||
|
|
||||||
{% highlight python %}
|
{% include_example python/ml/decision_tree_regression_example.py %}
|
||||||
from pyspark.ml import Pipeline
|
|
||||||
from pyspark.ml.regression import DecisionTreeRegressor
|
|
||||||
from pyspark.ml.feature import VectorIndexer
|
|
||||||
from pyspark.ml.evaluation import RegressionEvaluator
|
|
||||||
from pyspark.mllib.util import MLUtils
|
|
||||||
|
|
||||||
# Load and parse the data file, converting it to a DataFrame.
|
|
||||||
data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
|
|
||||||
|
|
||||||
# Automatically identify categorical features, and index them.
|
|
||||||
# We specify maxCategories so features with > 4 distinct values are treated as continuous.
|
|
||||||
featureIndexer =\
|
|
||||||
VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data)
|
|
||||||
|
|
||||||
# Split the data into training and test sets (30% held out for testing)
|
|
||||||
(trainingData, testData) = data.randomSplit([0.7, 0.3])
|
|
||||||
|
|
||||||
# Train a DecisionTree model.
|
|
||||||
dt = DecisionTreeRegressor(featuresCol="indexedFeatures")
|
|
||||||
|
|
||||||
# Chain indexer and tree in a Pipeline
|
|
||||||
pipeline = Pipeline(stages=[featureIndexer, dt])
|
|
||||||
|
|
||||||
# Train model. This also runs the indexer.
|
|
||||||
model = pipeline.fit(trainingData)
|
|
||||||
|
|
||||||
# Make predictions.
|
|
||||||
predictions = model.transform(testData)
|
|
||||||
|
|
||||||
# Select example rows to display.
|
|
||||||
predictions.select("prediction", "label", "features").show(5)
|
|
||||||
|
|
||||||
# Select (prediction, true label) and compute test error
|
|
||||||
evaluator = RegressionEvaluator(
|
|
||||||
labelCol="label", predictionCol="prediction", metricName="rmse")
|
|
||||||
rmse = evaluator.evaluate(predictions)
|
|
||||||
print "Root Mean Squared Error (RMSE) on test data = %g" % rmse
|
|
||||||
|
|
||||||
treeModel = model.stages[1]
|
|
||||||
print treeModel # summary only
|
|
||||||
{% endhighlight %}
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
|
@ -0,0 +1,103 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
// scalastyle:off println
|
||||||
|
package org.apache.spark.examples.ml;
|
||||||
|
// $example on$
|
||||||
|
import org.apache.spark.SparkConf;
|
||||||
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
import org.apache.spark.ml.Pipeline;
|
||||||
|
import org.apache.spark.ml.PipelineModel;
|
||||||
|
import org.apache.spark.ml.PipelineStage;
|
||||||
|
import org.apache.spark.ml.classification.DecisionTreeClassifier;
|
||||||
|
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
|
||||||
|
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
|
||||||
|
import org.apache.spark.ml.feature.*;
|
||||||
|
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||||
|
import org.apache.spark.mllib.util.MLUtils;
|
||||||
|
import org.apache.spark.rdd.RDD;
|
||||||
|
import org.apache.spark.sql.DataFrame;
|
||||||
|
import org.apache.spark.sql.SQLContext;
|
||||||
|
// $example off$
|
||||||
|
|
||||||
|
public class JavaDecisionTreeClassificationExample {
|
||||||
|
public static void main(String[] args) {
|
||||||
|
SparkConf conf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample");
|
||||||
|
JavaSparkContext jsc = new JavaSparkContext(conf);
|
||||||
|
SQLContext sqlContext = new SQLContext(jsc);
|
||||||
|
|
||||||
|
// $example on$
|
||||||
|
// Load and parse the data file, converting it to a DataFrame.
|
||||||
|
RDD<LabeledPoint> rdd = MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt");
|
||||||
|
DataFrame data = sqlContext.createDataFrame(rdd, LabeledPoint.class);
|
||||||
|
|
||||||
|
// Index labels, adding metadata to the label column.
|
||||||
|
// Fit on whole dataset to include all labels in index.
|
||||||
|
StringIndexerModel labelIndexer = new StringIndexer()
|
||||||
|
.setInputCol("label")
|
||||||
|
.setOutputCol("indexedLabel")
|
||||||
|
.fit(data);
|
||||||
|
|
||||||
|
// Automatically identify categorical features, and index them.
|
||||||
|
VectorIndexerModel featureIndexer = new VectorIndexer()
|
||||||
|
.setInputCol("features")
|
||||||
|
.setOutputCol("indexedFeatures")
|
||||||
|
.setMaxCategories(4) // features with > 4 distinct values are treated as continuous
|
||||||
|
.fit(data);
|
||||||
|
|
||||||
|
// Split the data into training and test sets (30% held out for testing)
|
||||||
|
DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
|
||||||
|
DataFrame trainingData = splits[0];
|
||||||
|
DataFrame testData = splits[1];
|
||||||
|
|
||||||
|
// Train a DecisionTree model.
|
||||||
|
DecisionTreeClassifier dt = new DecisionTreeClassifier()
|
||||||
|
.setLabelCol("indexedLabel")
|
||||||
|
.setFeaturesCol("indexedFeatures");
|
||||||
|
|
||||||
|
// Convert indexed labels back to original labels.
|
||||||
|
IndexToString labelConverter = new IndexToString()
|
||||||
|
.setInputCol("prediction")
|
||||||
|
.setOutputCol("predictedLabel")
|
||||||
|
.setLabels(labelIndexer.labels());
|
||||||
|
|
||||||
|
// Chain indexers and tree in a Pipeline
|
||||||
|
Pipeline pipeline = new Pipeline()
|
||||||
|
.setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter});
|
||||||
|
|
||||||
|
// Train model. This also runs the indexers.
|
||||||
|
PipelineModel model = pipeline.fit(trainingData);
|
||||||
|
|
||||||
|
// Make predictions.
|
||||||
|
DataFrame predictions = model.transform(testData);
|
||||||
|
|
||||||
|
// Select example rows to display.
|
||||||
|
predictions.select("predictedLabel", "label", "features").show(5);
|
||||||
|
|
||||||
|
// Select (prediction, true label) and compute test error
|
||||||
|
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
|
||||||
|
.setLabelCol("indexedLabel")
|
||||||
|
.setPredictionCol("prediction")
|
||||||
|
.setMetricName("precision");
|
||||||
|
double accuracy = evaluator.evaluate(predictions);
|
||||||
|
System.out.println("Test Error = " + (1.0 - accuracy));
|
||||||
|
|
||||||
|
DecisionTreeClassificationModel treeModel =
|
||||||
|
(DecisionTreeClassificationModel) (model.stages()[2]);
|
||||||
|
System.out.println("Learned classification tree model:\n" + treeModel.toDebugString());
|
||||||
|
// $example off$
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,90 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
// scalastyle:off println
|
||||||
|
package org.apache.spark.examples.ml;
|
||||||
|
// $example on$
|
||||||
|
import org.apache.spark.SparkConf;
|
||||||
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
import org.apache.spark.ml.Pipeline;
|
||||||
|
import org.apache.spark.ml.PipelineModel;
|
||||||
|
import org.apache.spark.ml.PipelineStage;
|
||||||
|
import org.apache.spark.ml.evaluation.RegressionEvaluator;
|
||||||
|
import org.apache.spark.ml.feature.VectorIndexer;
|
||||||
|
import org.apache.spark.ml.feature.VectorIndexerModel;
|
||||||
|
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
|
||||||
|
import org.apache.spark.ml.regression.DecisionTreeRegressor;
|
||||||
|
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||||
|
import org.apache.spark.mllib.util.MLUtils;
|
||||||
|
import org.apache.spark.rdd.RDD;
|
||||||
|
import org.apache.spark.sql.DataFrame;
|
||||||
|
import org.apache.spark.sql.SQLContext;
|
||||||
|
// $example off$
|
||||||
|
|
||||||
|
public class JavaDecisionTreeRegressionExample {
|
||||||
|
public static void main(String[] args) {
|
||||||
|
SparkConf conf = new SparkConf().setAppName("JavaDecisionTreeRegressionExample");
|
||||||
|
JavaSparkContext jsc = new JavaSparkContext(conf);
|
||||||
|
SQLContext sqlContext = new SQLContext(jsc);
|
||||||
|
// $example on$
|
||||||
|
// Load and parse the data file, converting it to a DataFrame.
|
||||||
|
RDD<LabeledPoint> rdd = MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt");
|
||||||
|
DataFrame data = sqlContext.createDataFrame(rdd, LabeledPoint.class);
|
||||||
|
|
||||||
|
// Automatically identify categorical features, and index them.
|
||||||
|
// Set maxCategories so features with > 4 distinct values are treated as continuous.
|
||||||
|
VectorIndexerModel featureIndexer = new VectorIndexer()
|
||||||
|
.setInputCol("features")
|
||||||
|
.setOutputCol("indexedFeatures")
|
||||||
|
.setMaxCategories(4)
|
||||||
|
.fit(data);
|
||||||
|
|
||||||
|
// Split the data into training and test sets (30% held out for testing)
|
||||||
|
DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
|
||||||
|
DataFrame trainingData = splits[0];
|
||||||
|
DataFrame testData = splits[1];
|
||||||
|
|
||||||
|
// Train a DecisionTree model.
|
||||||
|
DecisionTreeRegressor dt = new DecisionTreeRegressor()
|
||||||
|
.setFeaturesCol("indexedFeatures");
|
||||||
|
|
||||||
|
// Chain indexer and tree in a Pipeline
|
||||||
|
Pipeline pipeline = new Pipeline()
|
||||||
|
.setStages(new PipelineStage[]{featureIndexer, dt});
|
||||||
|
|
||||||
|
// Train model. This also runs the indexer.
|
||||||
|
PipelineModel model = pipeline.fit(trainingData);
|
||||||
|
|
||||||
|
// Make predictions.
|
||||||
|
DataFrame predictions = model.transform(testData);
|
||||||
|
|
||||||
|
// Select example rows to display.
|
||||||
|
predictions.select("label", "features").show(5);
|
||||||
|
|
||||||
|
// Select (prediction, true label) and compute test error
|
||||||
|
RegressionEvaluator evaluator = new RegressionEvaluator()
|
||||||
|
.setLabelCol("label")
|
||||||
|
.setPredictionCol("prediction")
|
||||||
|
.setMetricName("rmse");
|
||||||
|
double rmse = evaluator.evaluate(predictions);
|
||||||
|
System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse);
|
||||||
|
|
||||||
|
DecisionTreeRegressionModel treeModel =
|
||||||
|
(DecisionTreeRegressionModel) (model.stages()[1]);
|
||||||
|
System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
|
||||||
|
// $example off$
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,77 @@
|
||||||
|
#
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
# contributor license agreements. See the NOTICE file distributed with
|
||||||
|
# this work for additional information regarding copyright ownership.
|
||||||
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
# (the "License"); you may not use this file except in compliance with
|
||||||
|
# the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
"""
|
||||||
|
Decision Tree Classification Example.
|
||||||
|
"""
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# $example on$
|
||||||
|
from pyspark import SparkContext, SQLContext
|
||||||
|
from pyspark.ml import Pipeline
|
||||||
|
from pyspark.ml.classification import DecisionTreeClassifier
|
||||||
|
from pyspark.ml.feature import StringIndexer, VectorIndexer
|
||||||
|
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
|
||||||
|
from pyspark.mllib.util import MLUtils
|
||||||
|
# $example off$
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sc = SparkContext(appName="decision_tree_classification_example")
|
||||||
|
sqlContext = SQLContext(sc)
|
||||||
|
|
||||||
|
# $example on$
|
||||||
|
# Load and parse the data file, converting it to a DataFrame.
|
||||||
|
data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
|
||||||
|
|
||||||
|
# Index labels, adding metadata to the label column.
|
||||||
|
# Fit on whole dataset to include all labels in index.
|
||||||
|
labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data)
|
||||||
|
# Automatically identify categorical features, and index them.
|
||||||
|
# We specify maxCategories so features with > 4 distinct values are treated as continuous.
|
||||||
|
featureIndexer =\
|
||||||
|
VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data)
|
||||||
|
|
||||||
|
# Split the data into training and test sets (30% held out for testing)
|
||||||
|
(trainingData, testData) = data.randomSplit([0.7, 0.3])
|
||||||
|
|
||||||
|
# Train a DecisionTree model.
|
||||||
|
dt = DecisionTreeClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures")
|
||||||
|
|
||||||
|
# Chain indexers and tree in a Pipeline
|
||||||
|
pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dt])
|
||||||
|
|
||||||
|
# Train model. This also runs the indexers.
|
||||||
|
model = pipeline.fit(trainingData)
|
||||||
|
|
||||||
|
# Make predictions.
|
||||||
|
predictions = model.transform(testData)
|
||||||
|
|
||||||
|
# Select example rows to display.
|
||||||
|
predictions.select("prediction", "indexedLabel", "features").show(5)
|
||||||
|
|
||||||
|
# Select (prediction, true label) and compute test error
|
||||||
|
evaluator = MulticlassClassificationEvaluator(
|
||||||
|
labelCol="indexedLabel", predictionCol="prediction", metricName="precision")
|
||||||
|
accuracy = evaluator.evaluate(predictions)
|
||||||
|
print("Test Error = %g " % (1.0 - accuracy))
|
||||||
|
|
||||||
|
treeModel = model.stages[2]
|
||||||
|
# summary only
|
||||||
|
print(treeModel)
|
||||||
|
# $example off$
|
|
@ -0,0 +1,74 @@
|
||||||
|
#
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
# contributor license agreements. See the NOTICE file distributed with
|
||||||
|
# this work for additional information regarding copyright ownership.
|
||||||
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
# (the "License"); you may not use this file except in compliance with
|
||||||
|
# the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
"""
|
||||||
|
Decision Tree Regression Example.
|
||||||
|
"""
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from pyspark import SparkContext, SQLContext
|
||||||
|
# $example on$
|
||||||
|
from pyspark.ml import Pipeline
|
||||||
|
from pyspark.ml.regression import DecisionTreeRegressor
|
||||||
|
from pyspark.ml.feature import VectorIndexer
|
||||||
|
from pyspark.ml.evaluation import RegressionEvaluator
|
||||||
|
from pyspark.mllib.util import MLUtils
|
||||||
|
# $example off$
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sc = SparkContext(appName="decision_tree_classification_example")
|
||||||
|
sqlContext = SQLContext(sc)
|
||||||
|
|
||||||
|
# $example on$
|
||||||
|
# Load and parse the data file, converting it to a DataFrame.
|
||||||
|
data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
|
||||||
|
|
||||||
|
# Automatically identify categorical features, and index them.
|
||||||
|
# We specify maxCategories so features with > 4 distinct values are treated as continuous.
|
||||||
|
featureIndexer =\
|
||||||
|
VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data)
|
||||||
|
|
||||||
|
# Split the data into training and test sets (30% held out for testing)
|
||||||
|
(trainingData, testData) = data.randomSplit([0.7, 0.3])
|
||||||
|
|
||||||
|
# Train a DecisionTree model.
|
||||||
|
dt = DecisionTreeRegressor(featuresCol="indexedFeatures")
|
||||||
|
|
||||||
|
# Chain indexer and tree in a Pipeline
|
||||||
|
pipeline = Pipeline(stages=[featureIndexer, dt])
|
||||||
|
|
||||||
|
# Train model. This also runs the indexer.
|
||||||
|
model = pipeline.fit(trainingData)
|
||||||
|
|
||||||
|
# Make predictions.
|
||||||
|
predictions = model.transform(testData)
|
||||||
|
|
||||||
|
# Select example rows to display.
|
||||||
|
predictions.select("prediction", "label", "features").show(5)
|
||||||
|
|
||||||
|
# Select (prediction, true label) and compute test error
|
||||||
|
evaluator = RegressionEvaluator(
|
||||||
|
labelCol="label", predictionCol="prediction", metricName="rmse")
|
||||||
|
rmse = evaluator.evaluate(predictions)
|
||||||
|
print("Root Mean Squared Error (RMSE) on test data = %g" % rmse)
|
||||||
|
|
||||||
|
treeModel = model.stages[1]
|
||||||
|
# summary only
|
||||||
|
print(treeModel)
|
||||||
|
# $example off$
|
|
@ -0,0 +1,94 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// scalastyle:off println
package org.apache.spark.examples.ml

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
// $example on$
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.mllib.util.MLUtils
// $example off$

/**
 * Decision tree classification Pipeline example: indexes labels and
 * categorical features, fits a decision tree, converts predictions back to
 * the original labels, and prints the test error.
 */
object DecisionTreeClassificationExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("DecisionTreeClassificationExample")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    // Needed for the RDD[LabeledPoint].toDF() conversion below.
    import sqlContext.implicits._

    // $example on$
    // Load and parse the data file, converting it to a DataFrame.
    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()

    // Index labels, adding metadata to the label column.
    // Fit on whole dataset to include all labels in index.
    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
      .fit(data)
    // Automatically identify categorical features, and index them.
    val featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(4) // features with > 4 distinct values are treated as continuous
      .fit(data)

    // Split the data into training and test sets (30% held out for testing)
    val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

    // Train a DecisionTree model.
    val dt = new DecisionTreeClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("indexedFeatures")

    // Convert indexed labels back to original labels.
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictedLabel")
      .setLabels(labelIndexer.labels)

    // Chain indexers and tree in a Pipeline
    val pipeline = new Pipeline()
      .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))

    // Train model.  This also runs the indexers.
    val model = pipeline.fit(trainingData)

    // Make predictions.
    val predictions = model.transform(testData)

    // Select example rows to display.
    predictions.select("predictedLabel", "label", "features").show(5)

    // Select (prediction, true label) and compute test error
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("precision")
    val accuracy = evaluator.evaluate(predictions)
    println("Test Error = " + (1.0 - accuracy))

    // Stage 2 of [labelIndexer, featureIndexer, dt, labelConverter] is the tree.
    val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel]
    println("Learned classification tree model:\n" + treeModel.toDebugString)
    // $example off$

    // Release cluster resources now that the example has finished.
    sc.stop()
  }
}
// scalastyle:on println
|
|
@ -0,0 +1,81 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// scalastyle:off println
package org.apache.spark.examples.ml

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
// $example on$
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
import org.apache.spark.mllib.util.MLUtils
// $example off$

/**
 * Decision tree regression Pipeline example: indexes categorical features,
 * fits a decision tree regressor, and prints RMSE on a held-out test split.
 */
object DecisionTreeRegressionExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("DecisionTreeRegressionExample")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    // Needed for the RDD[LabeledPoint].toDF() conversion below.
    import sqlContext.implicits._

    // $example on$
    // Load and parse the data file, converting it to a DataFrame.
    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()

    // Automatically identify categorical features, and index them.
    // Here, we treat features with > 4 distinct values as continuous.
    val featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(4)
      .fit(data)

    // Split the data into training and test sets (30% held out for testing)
    val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

    // Train a DecisionTree model.
    val dt = new DecisionTreeRegressor()
      .setLabelCol("label")
      .setFeaturesCol("indexedFeatures")

    // Chain indexer and tree in a Pipeline
    val pipeline = new Pipeline()
      .setStages(Array(featureIndexer, dt))

    // Train model.  This also runs the indexer.
    val model = pipeline.fit(trainingData)

    // Make predictions.
    val predictions = model.transform(testData)

    // Select example rows to display.
    predictions.select("prediction", "label", "features").show(5)

    // Select (prediction, true label) and compute test error
    val evaluator = new RegressionEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("rmse")
    val rmse = evaluator.evaluate(predictions)
    println("Root Mean Squared Error (RMSE) on test data = " + rmse)

    // Stage 1 of [featureIndexer, dt] is the fitted tree.
    val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel]
    println("Learned regression tree model:\n" + treeModel.toDebugString)
    // $example off$

    // Release cluster resources now that the example has finished.
    sc.stop()
  }
}
// scalastyle:on println
|
Loading…
Reference in a new issue