[SPARK-15149][EXAMPLE][DOC] update kmeans example
## What changes were proposed in this pull request?
A Python example for ml.kmeans already exists but is not included in the user guide.
1. Small changes, such as adding the `$example on$` / `$example off$` markers.
2. Add the example to the user guide.
3. Update the Java, Python, and Scala examples to read the data file directly.

## How was this patch tested?
Manual tests:
`./bin/spark-submit examples/src/main/python/ml/kmeans_example.py`

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #12925 from zhengruifeng/km_pe.
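For reference, a minimal sketch of the updated Python example flow, mirroring the new `kmeans_example.py` in the diff below (data path, `k`, and seed values are taken from the diff and assume the bundled `data/mllib/sample_kmeans_data.txt` is available relative to the Spark home directory):

```python
# Sketch of the updated example: load the sample data directly instead of
# parsing a user-supplied text file, then fit and inspect a k-means model.
from pyspark.ml.clustering import KMeans
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("PythonKMeansExample").getOrCreate()

# Loads the bundled sample data in libsvm format.
dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

# Trains a k-means model with a fixed k and seed.
kmeans = KMeans().setK(2).setSeed(1)
model = kmeans.fit(dataset)

# Evaluates clustering by computing Within Set Sum of Squared Errors.
wssse = model.computeCost(dataset)
print("Within Set Sum of Squared Errors = " + str(wssse))

# Shows the cluster centers.
for center in model.clusterCenters():
    print(center)

spark.stop()
```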
Commit 8beae59144 (parent cef73b5638)
@@ -79,6 +79,11 @@ Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/KMeans.html

{% include_example java/org/apache/spark/examples/ml/JavaKMeansExample.java %}
</div>

<div data-lang="python" markdown="1">

Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.clustering.KMeans) for more details.

{% include_example python/ml/kmeans_example.py %}
</div>
</div>
@@ -17,77 +17,45 @@

package org.apache.spark.examples.ml;

import java.util.regex.Pattern;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
// $example on$
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.clustering.KMeans;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
// $example off$
import org.apache.spark.sql.SparkSession;

/**
 * An example demonstrating a k-means clustering.
 * An example demonstrating k-means clustering.
 * Run with
 * <pre>
 * bin/run-example ml.JavaKMeansExample <file> <k>
 * bin/run-example ml.JavaKMeansExample
 * </pre>
 */
public class JavaKMeansExample {

  private static class ParsePoint implements Function<String, Row> {
    private static final Pattern separator = Pattern.compile(" ");

    @Override
    public Row call(String line) {
      String[] tok = separator.split(line);
      double[] point = new double[tok.length];
      for (int i = 0; i < tok.length; ++i) {
        point[i] = Double.parseDouble(tok[i]);
      }
      Vector[] points = {Vectors.dense(point)};
      return new GenericRow(points);
    }
  }

  public static void main(String[] args) {
    if (args.length != 2) {
      System.err.println("Usage: ml.JavaKMeansExample <file> <k>");
      System.exit(1);
    }
    String inputFile = args[0];
    int k = Integer.parseInt(args[1]);

    // Parses the arguments
    // Create a SparkSession.
    SparkSession spark = SparkSession
      .builder()
      .appName("JavaKMeansExample")
      .getOrCreate();

    // $example on$
    // Loads data
    JavaRDD<Row> points = spark.read().text(inputFile).javaRDD().map(new ParsePoint());
    StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())};
    StructType schema = new StructType(fields);
    Dataset<Row> dataset = spark.createDataFrame(points, schema);
    // Loads data.
    Dataset<Row> dataset = spark.read().format("libsvm").load("data/mllib/sample_kmeans_data.txt");

    // Trains a k-means model
    KMeans kmeans = new KMeans()
      .setK(k);
    // Trains a k-means model.
    KMeans kmeans = new KMeans().setK(2).setSeed(1L);
    KMeansModel model = kmeans.fit(dataset);

    // Shows the result
    // Evaluate clustering by computing Within Set Sum of Squared Errors.
    double WSSSE = model.computeCost(dataset);
    System.out.println("Within Set Sum of Squared Errors = " + WSSSE);

    // Shows the result.
    Vector[] centers = model.clusterCenters();
    System.out.println("Cluster Centers: ");
    for (Vector center: centers) {
@@ -17,55 +17,45 @@

from __future__ import print_function

import sys
# $example on$
from pyspark.ml.clustering import KMeans
# $example off$

import numpy as np
from pyspark.ml.clustering import KMeans, KMeansModel
from pyspark.mllib.linalg import VectorUDT, _convert_to_vector
from pyspark.sql import SparkSession
from pyspark.sql.types import Row, StructField, StructType

"""
A simple example demonstrating a k-means clustering.
An example demonstrating k-means clustering.
Run with:
  bin/spark-submit examples/src/main/python/ml/kmeans_example.py <input> <k>
  bin/spark-submit examples/src/main/python/ml/kmeans_example.py

This example requires NumPy (http://www.numpy.org/).
"""


def parseVector(row):
    array = np.array([float(x) for x in row.value.split(' ')])
    return _convert_to_vector(array)


if __name__ == "__main__":

    FEATURES_COL = "features"

    if len(sys.argv) != 3:
        print("Usage: kmeans_example.py <file> <k>", file=sys.stderr)
        exit(-1)
    path = sys.argv[1]
    k = sys.argv[2]

    spark = SparkSession\
        .builder\
        .appName("PythonKMeansExample")\
        .getOrCreate()

    lines = spark.read.text(path).rdd
    data = lines.map(parseVector)
    row_rdd = data.map(lambda x: Row(x))
    schema = StructType([StructField(FEATURES_COL, VectorUDT(), False)])
    df = spark.createDataFrame(row_rdd, schema)
    # $example on$
    # Loads data.
    dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

    kmeans = KMeans().setK(2).setSeed(1).setFeaturesCol(FEATURES_COL)
    model = kmeans.fit(df)
    # Trains a k-means model.
    kmeans = KMeans().setK(2).setSeed(1)
    model = kmeans.fit(dataset)

    # Evaluate clustering by computing Within Set Sum of Squared Errors.
    wssse = model.computeCost(dataset)
    print("Within Set Sum of Squared Errors = " + str(wssse))

    # Shows the result.
    centers = model.clusterCenters()

    print("Cluster Centers: ")
    for center in centers:
        print(center)
    # $example off$

    spark.stop()
@@ -21,12 +21,11 @@ package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.{DataFrame, SparkSession}
// $example off$
import org.apache.spark.sql.SparkSession

/**
 * An example demonstrating a k-means clustering.
 * An example demonstrating k-means clustering.
 * Run with
 * {{{
 * bin/run-example ml.KMeansExample

@@ -35,32 +34,26 @@ import org.apache.spark.sql.{DataFrame, SparkSession}

object KMeansExample {

  def main(args: Array[String]): Unit = {
    // Creates a Spark context and a SQL context
    // Creates a SparkSession.
    val spark = SparkSession
      .builder
      .appName(s"${this.getClass.getSimpleName}")
      .getOrCreate()

    // $example on$
    // Crates a DataFrame
    val dataset: DataFrame = spark.createDataFrame(Seq(
      (1, Vectors.dense(0.0, 0.0, 0.0)),
      (2, Vectors.dense(0.1, 0.1, 0.1)),
      (3, Vectors.dense(0.2, 0.2, 0.2)),
      (4, Vectors.dense(9.0, 9.0, 9.0)),
      (5, Vectors.dense(9.1, 9.1, 9.1)),
      (6, Vectors.dense(9.2, 9.2, 9.2))
    )).toDF("id", "features")
    // Loads data.
    val dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

    // Trains a k-means model
    val kmeans = new KMeans()
      .setK(2)
      .setFeaturesCol("features")
      .setPredictionCol("prediction")
    // Trains a k-means model.
    val kmeans = new KMeans().setK(2).setSeed(1L)
    val model = kmeans.fit(dataset)

    // Shows the result
    println("Final Centers: ")
    // Evaluate clustering by computing Within Set Sum of Squared Errors.
    val WSSSE = model.computeCost(dataset)
    println(s"Within Set Sum of Squared Errors = $WSSSE")

    // Shows the result.
    println("Cluster Centers: ")
    model.clusterCenters.foreach(println)
    // $example off$