[SPARK-7387] [ML] [DOC] CrossValidator example code in Python
Author: Ram Sriharsha <rsriharsha@hw11853.local> Closes #6358 from harsha2010/SPARK-7387 and squashes the following commits: 63efda2 [Ram Sriharsha] more examples for classifier to distinguish mapreduce from spark properly aeb6bb6 [Ram Sriharsha] Python Style Fix 54a500c [Ram Sriharsha] Merge branch 'master' into SPARK-7387 615e91c [Ram Sriharsha] cleanup 204c4e3 [Ram Sriharsha] Merge branch 'master' into SPARK-7387 7246d35 [Ram Sriharsha] [SPARK-7387][ml][doc] CrossValidator example code in Python
This commit is contained in:
parent
5cd6a63d96
commit
c3f4c32571
96
examples/src/main/python/ml/cross_validator.py
Normal file
96
examples/src/main/python/ml/cross_validator.py
Normal file
|
@ -0,0 +1,96 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
# contributor license agreements. See the NOTICE file distributed with
|
||||
# this work for additional information regarding copyright ownership.
|
||||
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
# (the "License"); you may not use this file except in compliance with
|
||||
# the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
from pyspark import SparkContext
|
||||
from pyspark.ml import Pipeline
|
||||
from pyspark.ml.classification import LogisticRegression
|
||||
from pyspark.ml.evaluation import BinaryClassificationEvaluator
|
||||
from pyspark.ml.feature import HashingTF, Tokenizer
|
||||
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
|
||||
from pyspark.sql import Row, SQLContext
|
||||
|
||||
"""
|
||||
A simple example demonstrating model selection using CrossValidator.
|
||||
This example also demonstrates how Pipelines are Estimators.
|
||||
Run with:
|
||||
|
||||
bin/spark-submit examples/src/main/python/ml/cross_validator.py
|
||||
"""
|
||||
|
||||
if __name__ == "__main__":
    # Example driver: tune a text-classification Pipeline with CrossValidator
    # and print the predictions of the tuned model on a few unlabeled documents.
    sc = SparkContext(appName="CrossValidatorExample")
    # NOTE(review): sqlContext is not referenced again below; presumably it is
    # created because RDD.toDF() needs an SQLContext to exist -- confirm
    # against the Spark version in use before removing it.
    sqlContext = SQLContext(sc)

    # Prepare training documents, which are labeled.
    LabeledDocument = Row("id", "text", "label")
    training = sc.parallelize([(0, "a b c d e spark", 1.0),
                               (1, "b d", 0.0),
                               (2, "spark f g h", 1.0),
                               (3, "hadoop mapreduce", 0.0),
                               (4, "b spark who", 1.0),
                               (5, "g d a y", 0.0),
                               (6, "spark fly", 1.0),
                               (7, "was mapreduce", 0.0),
                               (8, "e spark program", 1.0),
                               (9, "a e c l", 0.0),
                               (10, "spark compile", 1.0),
                               (11, "hadoop software", 0.0)
                               ]) \
        .map(lambda x: LabeledDocument(*x)).toDF()

    # Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
    tokenizer = Tokenizer(inputCol="text", outputCol="words")
    hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
    lr = LogisticRegression(maxIter=10)
    pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])

    # We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
    # This will allow us to jointly choose parameters for all Pipeline stages.
    # A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
    # We use a ParamGridBuilder to construct a grid of parameters to search over.
    # With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
    # this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
    paramGrid = ParamGridBuilder() \
        .addGrid(hashingTF.numFeatures, [10, 100, 1000]) \
        .addGrid(lr.regParam, [0.1, 0.01]) \
        .build()

    crossval = CrossValidator(estimator=pipeline,
                              estimatorParamMaps=paramGrid,
                              evaluator=BinaryClassificationEvaluator(),
                              numFolds=2)  # use 3+ folds in practice

    # Run cross-validation, and choose the best set of parameters.
    cvModel = crossval.fit(training)

    # Prepare test documents, which are unlabeled.
    # Plain int ids are used instead of Python 2 long literals (e.g. 4L), which
    # are a SyntaxError on Python 3.
    Document = Row("id", "text")
    test = sc.parallelize([(4, "spark i j k"),
                           (5, "l m n"),
                           (6, "mapreduce spark"),
                           (7, "apache hadoop")]) \
        .map(lambda x: Document(*x)).toDF()

    # Make predictions on test documents. cvModel uses the best model found (lrModel).
    prediction = cvModel.transform(test)
    selected = prediction.select("id", "text", "probability", "prediction")
    for row in selected.collect():
        print(row)

    sc.stop()
|
|
@ -41,8 +41,8 @@ if __name__ == "__main__":
|
|||
|
||||
# prepare training data.
|
||||
# We create an RDD of LabeledPoints and convert them into a DataFrame.
|
||||
# Spark DataFrames can automatically infer the schema from named tuples
|
||||
# and LabeledPoint implements __reduce__ to behave like a named tuple.
|
||||
# A LabeledPoint is an Object with two fields named label and features
|
||||
# and Spark SQL identifies these fields and creates the schema appropriately.
|
||||
training = sc.parallelize([
|
||||
LabeledPoint(1.0, DenseVector([0.0, 1.1, 0.1])),
|
||||
LabeledPoint(0.0, DenseVector([2.0, 1.0, -1.0])),
|
||||
|
|
Loading…
Reference in a new issue