From 0221c7f0efe2512f3ae3839b83aa8abb0806d516 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 1 Jun 2015 22:03:29 -0700 Subject: [PATCH] [SPARK-7582] [MLLIB] user guide for StringIndexer This PR adds a Java unit test and user guide for `StringIndexer`. I put it before `OneHotEncoder` because they are closely related. jkbradley Author: Xiangrui Meng Closes #6561 from mengxr/SPARK-7582 and squashes the following commits: 4bba4f1 [Xiangrui Meng] fix example ba1cd1b [Xiangrui Meng] fix style 7fa18d1 [Xiangrui Meng] add user guide for StringIndexer 136cb93 [Xiangrui Meng] add a Java unit test for StringIndexer --- docs/ml-features.md | 116 ++++++++++++++++++ .../ml/feature/JavaStringIndexerSuite.java | 77 ++++++++++++ 2 files changed, 193 insertions(+) create mode 100644 mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java diff --git a/docs/ml-features.md b/docs/ml-features.md index 9ee5696122..f88c0248c1 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -456,6 +456,122 @@ for expanded in polyDF.select("polyFeatures").take(3): +## StringIndexer + +`StringIndexer` encodes a string column of labels to a column of label indices. +The indices are in `[0, numLabels)`, ordered by label frequencies. +So the most frequent label gets index `0`. +If the input column is numeric, we cast it to string and index the string values. + +**Examples** + +Assume that we have the following DataFrame with columns `id` and `category`: + +~~~~ + id | category +----|---------- + 0 | a + 1 | b + 2 | c + 3 | a + 4 | a + 5 | c +~~~~ + +`category` is a string column with three labels: "a", "b", and "c". +Applying `StringIndexer` with `category` as the input column and `categoryIndex` as the output +column, we should get the following: + +~~~~ + id | category | categoryIndex +----|----------|--------------- + 0 | a | 0.0 + 1 | b | 2.0 + 2 | c | 1.0 + 3 | a | 0.0 + 4 | a | 0.0 + 5 | c | 1.0 +~~~~ + +"a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with +index `2`. + +
+ +
+ +[`StringIndexer`](api/scala/index.html#org.apache.spark.ml.feature.StringIndexer) takes an input +column name and an output column name. + +{% highlight scala %} +import org.apache.spark.ml.feature.StringIndexer + +val df = sqlContext.createDataFrame( + Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) +).toDF("id", "category") +val indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex") +val indexed = indexer.fit(df).transform(df) +indexed.show() +{% endhighlight %} +
+ +
+[`StringIndexer`](api/java/org/apache/spark/ml/feature/StringIndexer.html) takes an input column +name and an output column name. + +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StringIndexer; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import static org.apache.spark.sql.types.DataTypes.*; + +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(0, "a"), + RowFactory.create(1, "b"), + RowFactory.create(2, "c"), + RowFactory.create(3, "a"), + RowFactory.create(4, "a"), + RowFactory.create(5, "c") +)); +StructType schema = new StructType(new StructField[] { + createStructField("id", DoubleType, false), + createStructField("category", StringType, false) +}); +DataFrame df = sqlContext.createDataFrame(jrdd, schema); +StringIndexer indexer = new StringIndexer() + .setInputCol("category") + .setOutputCol("categoryIndex"); +DataFrame indexed = indexer.fit(df).transform(df); +indexed.show(); +{% endhighlight %} +
+ +
+ +[`StringIndexer`](api/python/pyspark.ml.html#pyspark.ml.feature.StringIndexer) takes an input +column name and an output column name. + +{% highlight python %} +from pyspark.ml.feature import StringIndexer + +df = sqlContext.createDataFrame( + [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], + ["id", "category"]) +indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") +indexed = indexer.fit(df).transform(df) +indexed.show() +{% endhighlight %} +
+
+ ## OneHotEncoder [One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java new file mode 100644 index 0000000000..35b18c5308 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.Arrays; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import static org.apache.spark.sql.types.DataTypes.*; + +public class JavaStringIndexerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaStringIndexerSuite"); + sqlContext = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + sqlContext = null; + } + + @Test + public void testStringIndexer() { + StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("label", StringType, false) + }); + JavaRDD rdd = jsc.parallelize( + Arrays.asList(c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c"))); + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + + StringIndexer indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex"); + DataFrame output = indexer.fit(dataset).transform(dataset); + + Assert.assertArrayEquals( + new Row[] { c(0, 0.0), c(1, 2.0), c(2, 1.0), c(3, 0.0), c(4, 0.0), c(5, 1.0) }, + output.orderBy("id").select("id", "labelIndex").collect()); + } + + /** An alias for RowFactory.create. */ + private Row c(Object... values) { + return RowFactory.create(values); + } +}