[SPARK-7582] [MLLIB] user guide for StringIndexer

This PR adds a Java unit test and user guide for `StringIndexer`. I put it before `OneHotEncoder` because they are closely related. jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #6561 from mengxr/SPARK-7582 and squashes the following commits:

4bba4f1 [Xiangrui Meng] fix example
ba1cd1b [Xiangrui Meng] fix style
7fa18d1 [Xiangrui Meng] add user guide for StringIndexer
136cb93 [Xiangrui Meng] add a Java unit test for StringIndexer
This commit is contained in:
Xiangrui Meng 2015-06-01 22:03:29 -07:00
parent b53a011647
commit 0221c7f0ef
2 changed files with 193 additions and 0 deletions

View file

@ -456,6 +456,122 @@ for expanded in polyDF.select("polyFeatures").take(3):
</div>
</div>
## StringIndexer
`StringIndexer` encodes a string column of labels to a column of label indices.
The indices are in `[0, numLabels)`, ordered by label frequencies.
So the most frequent label gets index `0`.
If the input column is numeric, we cast it to string and index the string values.
**Examples**
Assume that we have the following DataFrame with columns `id` and `category`:
~~~~
id | category
----|----------
0 | a
1 | b
2 | c
3 | a
4 | a
5 | c
~~~~
`category` is a string column with three labels: "a", "b", and "c".
Applying `StringIndexer` with `category` as the input column and `categoryIndex` as the output
column, we should get the following:
~~~~
id | category | categoryIndex
----|----------|---------------
0 | a | 0.0
1 | b | 2.0
2 | c | 1.0
3 | a | 0.0
4 | a | 0.0
5 | c | 1.0
~~~~
"a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with
index `2`.
<div class="codetabs">
<div data-lang="scala" markdown="1">
[`StringIndexer`](api/scala/index.html#org.apache.spark.ml.feature.StringIndexer) takes an input
column name and an output column name.
{% highlight scala %}
import org.apache.spark.ml.feature.StringIndexer
val df = sqlContext.createDataFrame(
Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
).toDF("id", "category")
val indexer = new StringIndexer()
.setInputCol("category")
.setOutputCol("categoryIndex")
val indexed = indexer.fit(df).transform(df)
indexed.show()
{% endhighlight %}
</div>
<div data-lang="java" markdown="1">
[`StringIndexer`](api/java/org/apache/spark/ml/feature/StringIndexer.html) takes an input column
name and an output column name.
{% highlight java %}
import java.util.Arrays;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import static org.apache.spark.sql.types.DataTypes.*;
JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
RowFactory.create(0, "a"),
RowFactory.create(1, "b"),
RowFactory.create(2, "c"),
RowFactory.create(3, "a"),
RowFactory.create(4, "a"),
RowFactory.create(5, "c")
));
StructType schema = new StructType(new StructField[] {
createStructField("id", DoubleType, false),
createStructField("category", StringType, false)
});
DataFrame df = sqlContext.createDataFrame(jrdd, schema);
StringIndexer indexer = new StringIndexer()
.setInputCol("category")
.setOutputCol("categoryIndex");
DataFrame indexed = indexer.fit(df).transform(df);
indexed.show();
{% endhighlight %}
</div>
<div data-lang="python" markdown="1">
[`StringIndexer`](api/python/pyspark.ml.html#pyspark.ml.feature.StringIndexer) takes an input
column name and an output column name.
{% highlight python %}
from pyspark.ml.feature import StringIndexer
df = sqlContext.createDataFrame(
[(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")],
["id", "category"])
indexer = StringIndexer(inputCol="category", outputCol="categoryIndex")
indexed = indexer.fit(df).transform(df)
indexed.show()
{% endhighlight %}
</div>
</div>
## OneHotEncoder
[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features

View file

@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.feature;
import java.util.Arrays;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import static org.apache.spark.sql.types.DataTypes.*;
public class JavaStringIndexerSuite {
private transient JavaSparkContext jsc;
private transient SQLContext sqlContext;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaStringIndexerSuite");
sqlContext = new SQLContext(jsc);
}
@After
public void tearDown() {
jsc.stop();
sqlContext = null;
}
@Test
public void testStringIndexer() {
StructType schema = createStructType(new StructField[] {
createStructField("id", IntegerType, false),
createStructField("label", StringType, false)
});
JavaRDD<Row> rdd = jsc.parallelize(
Arrays.asList(c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c")));
DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
StringIndexer indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex");
DataFrame output = indexer.fit(dataset).transform(dataset);
Assert.assertArrayEquals(
new Row[] { c(0, 0.0), c(1, 2.0), c(2, 1.0), c(3, 0.0), c(4, 0.0), c(5, 1.0) },
output.orderBy("id").select("id", "labelIndex").collect());
}
/** An alias for RowFactory.create. */
private Row c(Object... values) {
return RowFactory.create(values);
}
}