[SPARK-6530] [ML] Add chi-square selector for ml package
See JIRA [here](https://issues.apache.org/jira/browse/SPARK-6530). Author: Xusen Yin <yinxusen@gmail.com> Closes #5742 from yinxusen/SPARK-6530.
This commit is contained in:
parent
23a9448c04
commit
633aaae0a1
|
@ -0,0 +1,150 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.feature
|
||||
|
||||
import org.apache.spark.annotation.Experimental
|
||||
import org.apache.spark.ml._
|
||||
import org.apache.spark.ml.attribute.{AttributeGroup, _}
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.param.shared._
|
||||
import org.apache.spark.ml.util.Identifiable
|
||||
import org.apache.spark.ml.util.SchemaUtils
|
||||
import org.apache.spark.mllib.feature
|
||||
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
|
||||
|
||||
/**
 * Parameters shared by the [[ChiSqSelector]] estimator and the [[ChiSqSelectorModel]] it
 * produces: the features, label and output columns plus the number of features to keep.
 */
private[feature] trait ChiSqSelectorParams
  extends Params with HasFeaturesCol with HasOutputCol with HasLabelCol {

  /**
   * How many features the selector keeps, ranked by statistic value in descending order.
   * When the input has fewer features than numTopFeatures, all of them are kept.
   * Validated with `ParamValidators.gtEq(1)`; the default value is 50.
   * @group param
   */
  final val numTopFeatures = new IntParam(this, "numTopFeatures",
    "Number of features that selector will select, ordered by statistics value descending. If the" +
      " number of features is < numTopFeatures, then this will select all features.",
    ParamValidators.gtEq(1))
  setDefault(numTopFeatures, 50)

  /** @group getParam */
  def getNumTopFeatures: Int = getOrDefault(numTopFeatures)
}
|
||||
|
||||
/**
 * :: Experimental ::
 * Estimator performing Chi-Squared feature selection: it picks the categorical features to
 * use for predicting a categorical label.
 */
@Experimental
final class ChiSqSelector(override val uid: String)
  extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams {

  /** Creates a selector whose uid comes from `Identifiable.randomUID("chiSqSelector")`. */
  def this() = this(Identifiable.randomUID("chiSqSelector"))

  /** @group setParam */
  def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value)

  /** @group setParam */
  def setFeaturesCol(value: String): this.type = set(featuresCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /** @group setParam */
  def setLabelCol(value: String): this.type = set(labelCol, value)

  /**
   * Validates the dataset schema, turns the label and features columns into labeled points and
   * delegates the actual selection to the mllib `ChiSqSelector`. The resulting model has this
   * estimator as its parent and receives a copy of this estimator's param values.
   */
  override def fit(dataset: DataFrame): ChiSqSelectorModel = {
    transformSchema(dataset.schema, logging = true)

    val labeledPoints = dataset.select($(labelCol), $(featuresCol)).map {
      case Row(label: Double, features: Vector) => LabeledPoint(label, features)
    }
    val mllibModel = new feature.ChiSqSelector($(numTopFeatures)).fit(labeledPoints)
    val fittedModel = new ChiSqSelectorModel(uid, mllibModel).setParent(this)
    copyValues(fittedModel)
  }

  /**
   * Checks that the features column is a vector and the label column is a double, then
   * appends the vector-typed output column to the schema.
   */
  override def transformSchema(schema: StructType): StructType = {
    val vectorType = new VectorUDT
    SchemaUtils.checkColumnType(schema, $(featuresCol), vectorType)
    SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
    SchemaUtils.appendColumn(schema, $(outputCol), vectorType)
  }

  override def copy(extra: ParamMap): ChiSqSelector = defaultCopy(extra)
}
|
||||
|
||||
/**
 * :: Experimental ::
 * Model fitted by [[ChiSqSelector]]. It wraps the fitted mllib selector model and applies it
 * to the features column, writing the selected features to the output column.
 */
@Experimental
final class ChiSqSelectorModel private[ml] (
    override val uid: String,
    private val chiSqSelector: feature.ChiSqSelectorModel)
  extends Model[ChiSqSelectorModel] with ChiSqSelectorParams {

  /** @group setParam */
  def setFeaturesCol(value: String): this.type = set(featuresCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /** @group setParam */
  def setLabelCol(value: String): this.type = set(labelCol, value)

  /**
   * Appends the output column, computed by applying the wrapped mllib selector to every vector
   * of the features column. The column carries the attribute metadata built by
   * [[prepOutputField]].
   */
  override def transform(dataset: DataFrame): DataFrame = {
    val transformedSchema = transformSchema(dataset.schema, logging = true)
    // transformSchema always appends the output field last, so `last` is the new column.
    val outputField = transformedSchema.last
    val selectFeatures = udf { chiSqSelector.transform _ }
    dataset.withColumn($(outputCol), selectFeatures(col($(featuresCol))), outputField.metadata)
  }

  /**
   * Checks that the features column is a vector and returns the input schema with the output
   * field (including its attribute metadata) appended at the end.
   */
  override def transformSchema(schema: StructType): StructType = {
    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
    val newField = prepOutputField(schema)
    val outputFields = schema.fields :+ newField
    StructType(outputFields)
  }

  /**
   * Prepare the output column field, including per-feature metadata.
   *
   * When the features column carries per-feature attributes, the attributes at the selected
   * indices are kept in their original order; otherwise one default nominal attribute is
   * emitted per selected feature.
   */
  private def prepOutputField(schema: StructType): StructField = {
    val selectedIndices = chiSqSelector.selectedFeatures.toSet
    val origAttrGroup = AttributeGroup.fromStructField(schema($(featuresCol)))
    // Match on the Option instead of testing nonEmpty and then calling get.
    val featureAttributes: Array[Attribute] = origAttrGroup.attributes match {
      case Some(attrs) =>
        attrs.zipWithIndex.collect { case (attr, idx) if selectedIndices.contains(idx) => attr }
      case None =>
        Array.fill[Attribute](selectedIndices.size)(NominalAttribute.defaultAttr)
    }
    val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes)
    newAttributeGroup.toStructField()
  }

  /** Copies this model, reusing the wrapped mllib model and keeping the same parent. */
  override def copy(extra: ParamMap): ChiSqSelectorModel = {
    val copied = new ChiSqSelectorModel(uid, chiSqSelector)
    copyValues(copied, extra).setParent(parent)
  }
}
|
|
@ -109,6 +109,8 @@ class ChiSqSelectorModel @Since("1.3.0") (
|
|||
* Creates a ChiSquared feature selector.
|
||||
* @param numTopFeatures number of features that selector will select
|
||||
* (ordered by statistic value descending)
|
||||
* Note that if the number of features is < numTopFeatures, then this will
|
||||
* select all features.
|
||||
*/
|
||||
@Since("1.3.0")
|
||||
@Experimental
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.ml.feature
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.mllib.util.TestingUtils._
|
||||
import org.apache.spark.sql.{Row, SQLContext}
|
||||
|
||||
/**
 * Checks that [[ChiSqSelector]] with numTopFeatures = 1 produces, for every row of a small
 * labeled data set, the expected single-feature vector.
 */
class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
  test("Test Chi-Square selector") {
    val sqlContext = SQLContext.getOrCreate(sc)
    import sqlContext.implicits._

    // Labeled input rows, mixing sparse and dense feature vectors of size 3.
    val labeledPoints = Seq(
      LabeledPoint(0.0, Vectors.sparse(3, Seq((0, 8.0), (1, 7.0)))),
      LabeledPoint(1.0, Vectors.sparse(3, Seq((1, 9.0), (2, 6.0)))),
      LabeledPoint(1.0, Vectors.dense(0.0, 9.0, 8.0)),
      LabeledPoint(2.0, Vectors.dense(8.0, 9.0, 5.0))
    )

    // Expected output per row: the value of feature index 2 from the rows above.
    val expectedVectors = Seq(0.0, 6.0, 8.0, 5.0).map(value => Vectors.dense(value))

    val df = sc.parallelize(labeledPoints.zip(expectedVectors))
      .map { case (point, expected) => (point.label, point.features, expected) }
      .toDF("label", "data", "preFilteredData")

    val selector = new ChiSqSelector()
      .setNumTopFeatures(1)
      .setFeaturesCol("data")
      .setLabelCol("label")
      .setOutputCol("filtered")

    val result = selector.fit(df).transform(df)
    result.select("filtered", "preFilteredData").collect().foreach {
      case Row(vec1: Vector, vec2: Vector) =>
        assert(vec1 ~== vec2 absTol 1e-1)
    }
  }
}
|
Loading…
Reference in a new issue