[SPARK-7577] [ML] [DOC] add bucketizer doc
CC jkbradley
Author: Xusen Yin <yinxusen@gmail.com>
Closes #6451 from yinxusen/SPARK-7577 and squashes the following commits:
e2dc32e [Xusen Yin] rename colums
e350e49 [Xusen Yin] add all demos
006ddf1 [Xusen Yin] add java test
3238481 [Xusen Yin] add bucketizer
(cherry picked from commit 1bd63e82fd
)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
This commit is contained in:
parent
8f4a86eaa1
commit
7bb445a38c
|
@ -789,6 +789,92 @@ scaledData = scalerModel.transform(dataFrame)
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
## Bucketizer
|
||||||
|
|
||||||
|
`Bucketizer` transforms a column of continuous features to a column of feature buckets, where the buckets are specified by users. It takes a parameter:
|
||||||
|
|
||||||
|
* `splits`: Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last bucket, which also includes y. Splits should be strictly increasing. Values at -inf, inf must be explicitly provided to cover all Double values; Otherwise, values outside the splits specified will be treated as errors. Two examples of `splits` are `Array(Double.NegativeInfinity, 0.0, 1.0, Double.PositiveInfinity)` and `Array(0.0, 1.0, 2.0)`.
|
||||||
|
|
||||||
|
Note that if you have no idea of the upper bound and lower bound of the targeted column, you would better add the `Double.NegativeInfinity` and `Double.PositiveInfinity` as the bounds of your splits to prevent a potenial out of Bucketizer bounds exception.
|
||||||
|
|
||||||
|
Note also that the splits that you provided have to be in strictly increasing order, i.e. `s0 < s1 < s2 < ... < sn`.
|
||||||
|
|
||||||
|
More details can be found in the API docs for [Bucketizer](api/scala/index.html#org.apache.spark.ml.feature.Bucketizer).
|
||||||
|
|
||||||
|
The following example demonstrates how to bucketize a column of `Double`s into another index-wised column.
|
||||||
|
|
||||||
|
<div class="codetabs">
|
||||||
|
<div data-lang="scala">
|
||||||
|
{% highlight scala %}
|
||||||
|
import org.apache.spark.ml.feature.Bucketizer
|
||||||
|
import org.apache.spark.sql.DataFrame
|
||||||
|
|
||||||
|
val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
|
||||||
|
|
||||||
|
val data = Array(-0.5, -0.3, 0.0, 0.2)
|
||||||
|
val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features")
|
||||||
|
|
||||||
|
val bucketizer = new Bucketizer()
|
||||||
|
.setInputCol("features")
|
||||||
|
.setOutputCol("bucketedFeatures")
|
||||||
|
.setSplits(splits)
|
||||||
|
|
||||||
|
// Transform original data into its bucket index.
|
||||||
|
val bucketedData = bucketizer.transform(dataFrame)
|
||||||
|
{% endhighlight %}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div data-lang="java">
|
||||||
|
{% highlight java %}
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
|
||||||
|
import org.apache.spark.sql.DataFrame;
|
||||||
|
import org.apache.spark.sql.Row;
|
||||||
|
import org.apache.spark.sql.RowFactory;
|
||||||
|
import org.apache.spark.sql.types.DataTypes;
|
||||||
|
import org.apache.spark.sql.types.Metadata;
|
||||||
|
import org.apache.spark.sql.types.StructField;
|
||||||
|
import org.apache.spark.sql.types.StructType;
|
||||||
|
|
||||||
|
double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY};
|
||||||
|
|
||||||
|
JavaRDD<Row> data = jsc.parallelize(Lists.newArrayList(
|
||||||
|
RowFactory.create(-0.5),
|
||||||
|
RowFactory.create(-0.3),
|
||||||
|
RowFactory.create(0.0),
|
||||||
|
RowFactory.create(0.2)
|
||||||
|
));
|
||||||
|
StructType schema = new StructType(new StructField[] {
|
||||||
|
new StructField("features", DataTypes.DoubleType, false, Metadata.empty())
|
||||||
|
});
|
||||||
|
DataFrame dataFrame = jsql.createDataFrame(data, schema);
|
||||||
|
|
||||||
|
Bucketizer bucketizer = new Bucketizer()
|
||||||
|
.setInputCol("features")
|
||||||
|
.setOutputCol("bucketedFeatures")
|
||||||
|
.setSplits(splits);
|
||||||
|
|
||||||
|
// Transform original data into its bucket index.
|
||||||
|
DataFrame bucketedData = bucketizer.transform(dataFrame);
|
||||||
|
{% endhighlight %}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div data-lang="python">
|
||||||
|
{% highlight python %}
|
||||||
|
from pyspark.ml.feature import Bucketizer
|
||||||
|
|
||||||
|
splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")]
|
||||||
|
|
||||||
|
data = [(-0.5,), (-0.3,), (0.0,), (0.2,)]
|
||||||
|
dataFrame = sqlContext.createDataFrame(data, ["features"])
|
||||||
|
|
||||||
|
bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures")
|
||||||
|
|
||||||
|
# Transform original data into its bucket index.
|
||||||
|
bucketedData = bucketizer.transform(dataFrame)
|
||||||
|
{% endhighlight %}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
# Feature Selectors
|
# Feature Selectors
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,80 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.ml.feature;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import org.junit.After;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
|
import org.apache.spark.api.java.JavaSparkContext;
|
||||||
|
import org.apache.spark.sql.DataFrame;
|
||||||
|
import org.apache.spark.sql.Row;
|
||||||
|
import org.apache.spark.sql.RowFactory;
|
||||||
|
import org.apache.spark.sql.SQLContext;
|
||||||
|
import org.apache.spark.sql.types.DataTypes;
|
||||||
|
import org.apache.spark.sql.types.Metadata;
|
||||||
|
import org.apache.spark.sql.types.StructField;
|
||||||
|
import org.apache.spark.sql.types.StructType;
|
||||||
|
|
||||||
|
public class JavaBucketizerSuite {
|
||||||
|
private transient JavaSparkContext jsc;
|
||||||
|
private transient SQLContext jsql;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setUp() {
|
||||||
|
jsc = new JavaSparkContext("local", "JavaBucketizerSuite");
|
||||||
|
jsql = new SQLContext(jsc);
|
||||||
|
}
|
||||||
|
|
||||||
|
@After
|
||||||
|
public void tearDown() {
|
||||||
|
jsc.stop();
|
||||||
|
jsc = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void bucketizerTest() {
|
||||||
|
double[] splits = {-0.5, 0.0, 0.5};
|
||||||
|
|
||||||
|
JavaRDD<Row> data = jsc.parallelize(Lists.newArrayList(
|
||||||
|
RowFactory.create(-0.5),
|
||||||
|
RowFactory.create(-0.3),
|
||||||
|
RowFactory.create(0.0),
|
||||||
|
RowFactory.create(0.2)
|
||||||
|
));
|
||||||
|
StructType schema = new StructType(new StructField[] {
|
||||||
|
new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
|
||||||
|
});
|
||||||
|
DataFrame dataset = jsql.createDataFrame(data, schema);
|
||||||
|
|
||||||
|
Bucketizer bucketizer = new Bucketizer()
|
||||||
|
.setInputCol("feature")
|
||||||
|
.setOutputCol("result")
|
||||||
|
.setSplits(splits);
|
||||||
|
|
||||||
|
Row[] result = bucketizer.transform(dataset).select("result").collect();
|
||||||
|
|
||||||
|
for (Row r : result) {
|
||||||
|
double index = r.getDouble(0);
|
||||||
|
Assert.assertTrue((index >= 0) && (index <= 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue