[SPARK-7584] [MLLIB] User guide for VectorAssembler
This PR adds a section in the user guide for `VectorAssembler` with code examples in Python/Java/Scala. It also adds a unit test in Java. jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #6556 from mengxr/SPARK-7584 and squashes the following commits: 11313f6 [Xiangrui Meng] simplify Java example 0cd47f3 [Xiangrui Meng] update user guide fd36292 [Xiangrui Meng] update Java unit test ce61ca0 [Xiangrui Meng] add Java unit test for VectorAssembler e399942 [Xiangrui Meng] scala/python example code
This commit is contained in:
parent
b7ab0299b0
commit
90c606925e
|
@ -964,5 +964,119 @@ DataFrame transformedData = transformer.transform(dataFrame);
|
|||
</div>
|
||||
</div>
|
||||
|
||||
## VectorAssembler
|
||||
|
||||
`VectorAssembler` is a transformer that combines a given list of columns into a single vector
|
||||
column.
|
||||
It is useful for combining raw features and features generated by different feature transformers
|
||||
into a single feature vector, in order to train ML models like logistic regression and decision
|
||||
trees.
|
||||
`VectorAssembler` accepts the following input column types: all numeric types, boolean type,
|
||||
and vector type.
|
||||
In each row, the values of the input columns will be concatenated into a vector in the specified
|
||||
order.
|
||||
|
||||
**Examples**
|
||||
|
||||
Assume that we have a DataFrame with the columns `id`, `hour`, `mobile`, `userFeatures`,
|
||||
and `clicked`:
|
||||
|
||||
~~~
|
||||
id | hour | mobile | userFeatures | clicked
|
||||
----|------|--------|------------------|---------
|
||||
0 | 18 | 1.0 | [0.0, 10.0, 0.5] | 1.0
|
||||
~~~
|
||||
|
||||
`userFeatures` is a vector column that contains three user features.
|
||||
We want to combine `hour`, `mobile`, and `userFeatures` into a single feature vector
|
||||
called `features` and use it to predict `clicked` or not.
|
||||
If we set `VectorAssembler`'s input columns to `hour`, `mobile`, and `userFeatures` and
|
||||
output column to `features`, after transformation we should get the following DataFrame:
|
||||
|
||||
~~~
|
||||
id | hour | mobile | userFeatures | clicked | features
|
||||
----|------|--------|------------------|---------|-----------------------------
|
||||
0 | 18 | 1.0 | [0.0, 10.0, 0.5] | 1.0 | [18.0, 1.0, 0.0, 10.0, 0.5]
|
||||
~~~
|
||||
|
||||
<div class="codetabs">
|
||||
<div data-lang="scala" markdown="1">
|
||||
|
||||
[`VectorAssembler`](api/scala/index.html#org.apache.spark.ml.feature.VectorAssembler) takes an array
|
||||
of input column names and an output column name.
|
||||
|
||||
{% highlight scala %}
|
||||
import org.apache.spark.mllib.linalg.Vectors
|
||||
import org.apache.spark.ml.feature.VectorAssembler
|
||||
|
||||
val dataset = sqlContext.createDataFrame(
|
||||
Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0))
|
||||
).toDF("id", "hour", "mobile", "userFeatures", "clicked")
|
||||
val assembler = new VectorAssembler()
|
||||
.setInputCols(Array("hour", "mobile", "userFeatures"))
|
||||
.setOutputCol("features")
|
||||
val output = assembler.transform(dataset)
|
||||
println(output.select("features", "clicked").first())
|
||||
{% endhighlight %}
|
||||
</div>
|
||||
|
||||
<div data-lang="java" markdown="1">
|
||||
|
||||
[`VectorAssembler`](api/java/org/apache/spark/ml/feature/VectorAssembler.html) takes an array
|
||||
of input column names and an output column name.
|
||||
|
||||
{% highlight java %}
|
||||
import java.util.Arrays;
|
||||
|
||||
import org.apache.spark.api.java.JavaRDD;
|
||||
import org.apache.spark.mllib.linalg.VectorUDT;
|
||||
import org.apache.spark.mllib.linalg.Vectors;
|
||||
import org.apache.spark.sql.DataFrame;
|
||||
import org.apache.spark.sql.Row;
|
||||
import org.apache.spark.sql.RowFactory;
|
||||
import org.apache.spark.sql.types.*;
|
||||
import static org.apache.spark.sql.types.DataTypes.*;
|
||||
|
||||
StructType schema = createStructType(new StructField[] {
|
||||
createStructField("id", IntegerType, false),
|
||||
createStructField("hour", IntegerType, false),
|
||||
createStructField("mobile", DoubleType, false),
|
||||
createStructField("userFeatures", new VectorUDT(), false),
|
||||
createStructField("clicked", DoubleType, false)
|
||||
});
|
||||
Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0);
|
||||
JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(row));
|
||||
DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
|
||||
|
||||
VectorAssembler assembler = new VectorAssembler()
|
||||
.setInputCols(new String[] {"hour", "mobile", "userFeatures"})
|
||||
.setOutputCol("features");
|
||||
|
||||
DataFrame output = assembler.transform(dataset);
|
||||
System.out.println(output.select("features", "clicked").first());
|
||||
{% endhighlight %}
|
||||
</div>
|
||||
|
||||
<div data-lang="python" markdown="1">
|
||||
|
||||
[`VectorAssembler`](api/python/pyspark.ml.html#pyspark.ml.feature.VectorAssembler) takes a list
|
||||
of input column names and an output column name.
|
||||
|
||||
{% highlight python %}
|
||||
from pyspark.mllib.linalg import Vectors
|
||||
from pyspark.ml.feature import VectorAssembler
|
||||
|
||||
dataset = sqlContext.createDataFrame(
|
||||
[(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)],
|
||||
["id", "hour", "mobile", "userFeatures", "clicked"])
|
||||
assembler = VectorAssembler(
|
||||
inputCols=["hour", "mobile", "userFeatures"],
|
||||
outputCol="features")
|
||||
output = assembler.transform(dataset)
|
||||
print(output.select("features", "clicked").first())
|
||||
{% endhighlight %}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
# Feature Selectors
|
||||
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.feature;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
import org.junit.After;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import org.apache.spark.api.java.JavaRDD;
|
||||
import org.apache.spark.api.java.JavaSparkContext;
|
||||
import org.apache.spark.mllib.linalg.Vector;
|
||||
import org.apache.spark.mllib.linalg.VectorUDT;
|
||||
import org.apache.spark.mllib.linalg.Vectors;
|
||||
import org.apache.spark.sql.DataFrame;
|
||||
import org.apache.spark.sql.Row;
|
||||
import org.apache.spark.sql.RowFactory;
|
||||
import org.apache.spark.sql.SQLContext;
|
||||
import org.apache.spark.sql.types.*;
|
||||
import static org.apache.spark.sql.types.DataTypes.*;
|
||||
|
||||
public class JavaVectorAssemblerSuite {
|
||||
private transient JavaSparkContext jsc;
|
||||
private transient SQLContext sqlContext;
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
jsc = new JavaSparkContext("local", "JavaVectorAssemblerSuite");
|
||||
sqlContext = new SQLContext(jsc);
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() {
|
||||
jsc.stop();
|
||||
jsc = null;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVectorAssembler() {
|
||||
StructType schema = createStructType(new StructField[] {
|
||||
createStructField("id", IntegerType, false),
|
||||
createStructField("x", DoubleType, false),
|
||||
createStructField("y", new VectorUDT(), false),
|
||||
createStructField("name", StringType, false),
|
||||
createStructField("z", new VectorUDT(), false),
|
||||
createStructField("n", LongType, false)
|
||||
});
|
||||
Row row = RowFactory.create(
|
||||
0, 0.0, Vectors.dense(1.0, 2.0), "a",
|
||||
Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L);
|
||||
JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(row));
|
||||
DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
|
||||
VectorAssembler assembler = new VectorAssembler()
|
||||
.setInputCols(new String[] {"x", "y", "z", "n"})
|
||||
.setOutputCol("features");
|
||||
DataFrame output = assembler.transform(dataset);
|
||||
Assert.assertEquals(
|
||||
Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}),
|
||||
output.select("features").first().<Vector>getAs(0));
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue