[SPARK-18088][ML] Various ChiSqSelector cleanups
## What changes were proposed in this pull request? - Renamed kbest to numTopFeatures - Renamed alpha to fpr - Added missing Since annotations - Doc cleanups ## How was this patch tested? Added new standardized unit tests for spark.ml. Improved existing unit test coverage a bit. Author: Joseph K. Bradley <joseph@databricks.com> Closes #15647 from jkbradley/chisqselector-follow-ups.
This commit is contained in:
parent
b929537b6e
commit
91c33a0ca5
|
@ -1338,14 +1338,14 @@ for more details on the API.
|
|||
`ChiSqSelector` stands for Chi-Squared feature selection. It operates on labeled data with
|
||||
categorical features. ChiSqSelector uses the
|
||||
[Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) to decide which
|
||||
features to choose. It supports three selection methods: `KBest`, `Percentile` and `FPR`:
|
||||
features to choose. It supports three selection methods: `numTopFeatures`, `percentile`, `fpr`:
|
||||
|
||||
* `KBest` chooses the `k` top features according to a chi-squared test. This is akin to yielding the features with the most predictive power.
|
||||
* `Percentile` is similar to `KBest` but chooses a fraction of all features instead of a fixed number.
|
||||
* `FPR` chooses all features whose false positive rate meets some threshold.
|
||||
* `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. This is akin to yielding the features with the most predictive power.
|
||||
* `percentile` is similar to `numTopFeatures` but chooses a fraction of all features instead of a fixed number.
|
||||
* `fpr` chooses all features whose p-value is below a threshold, thus controlling the false positive rate of selection.
|
||||
|
||||
By default, the selection method is `KBest`, the default number of top features is 50. User can use
|
||||
`setNumTopFeatures`, `setPercentile` and `setAlpha` to set different selection methods.
|
||||
By default, the selection method is `numTopFeatures`, with the default number of top features set to 50.
|
||||
The user can choose a selection method using `setSelectorType`.
|
||||
|
||||
**Examples**
|
||||
|
||||
|
|
|
@ -227,22 +227,19 @@ both speed and statistical learning behavior.
|
|||
[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) implements
|
||||
Chi-Squared feature selection. It operates on labeled data with categorical features. ChiSqSelector uses the
|
||||
[Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) to decide which
|
||||
features to choose. It supports three selection methods: `KBest`, `Percentile` and `FPR`:
|
||||
features to choose. It supports three selection methods: `numTopFeatures`, `percentile`, `fpr`:
|
||||
|
||||
* `KBest` chooses the `k` top features according to a chi-squared test. This is akin to yielding the features with the most predictive power.
|
||||
* `Percentile` is similar to `KBest` but chooses a fraction of all features instead of a fixed number.
|
||||
* `FPR` chooses all features whose false positive rate meets some threshold.
|
||||
* `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. This is akin to yielding the features with the most predictive power.
|
||||
* `percentile` is similar to `numTopFeatures` but chooses a fraction of all features instead of a fixed number.
|
||||
* `fpr` chooses all features whose p-value is below a threshold, thus controlling the false positive rate of selection.
|
||||
|
||||
By default, the selection method is `KBest`, the default number of top features is 50. User can use
|
||||
`setNumTopFeatures`, `setPercentile` and `setAlpha` to set different selection methods.
|
||||
By default, the selection method is `numTopFeatures`, with the default number of top features set to 50.
|
||||
The user can choose a selection method using `setSelectorType`.
|
||||
|
||||
The number of features to select can be tuned using a held-out validation set.
|
||||
|
||||
### Model Fitting
|
||||
|
||||
`ChiSqSelector` takes a `numTopFeatures` parameter specifying the number of top features that
|
||||
the selector will select.
|
||||
|
||||
The [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method takes
|
||||
an input of `RDD[LabeledPoint]` with categorical features, learns the summary statistics, and then
|
||||
returns a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space.
|
||||
|
|
|
@ -42,69 +42,80 @@ private[feature] trait ChiSqSelectorParams extends Params
|
|||
with HasFeaturesCol with HasOutputCol with HasLabelCol {
|
||||
|
||||
/**
|
||||
* Number of features that selector will select (ordered by statistic value descending). If the
|
||||
* Number of features that selector will select, ordered by ascending p-value. If the
|
||||
* number of features is less than numTopFeatures, then this will select all features.
|
||||
* Only applicable when selectorType = "kbest".
|
||||
* Only applicable when selectorType = "numTopFeatures".
|
||||
* The default value of numTopFeatures is 50.
|
||||
*
|
||||
* @group param
|
||||
*/
|
||||
@Since("1.6.0")
|
||||
final val numTopFeatures = new IntParam(this, "numTopFeatures",
|
||||
"Number of features that selector will select, ordered by statistics value descending. If the" +
|
||||
"Number of features that selector will select, ordered by ascending p-value. If the" +
|
||||
" number of features is < numTopFeatures, then this will select all features.",
|
||||
ParamValidators.gtEq(1))
|
||||
setDefault(numTopFeatures -> 50)
|
||||
|
||||
/** @group getParam */
|
||||
@Since("1.6.0")
|
||||
def getNumTopFeatures: Int = $(numTopFeatures)
|
||||
|
||||
/**
|
||||
* Percentile of features that selector will select, ordered by statistics value descending.
|
||||
* Only applicable when selectorType = "percentile".
|
||||
* Default value is 0.1.
|
||||
* @group param
|
||||
*/
|
||||
@Since("2.1.0")
|
||||
final val percentile = new DoubleParam(this, "percentile",
|
||||
"Percentile of features that selector will select, ordered by statistics value descending.",
|
||||
"Percentile of features that selector will select, ordered by ascending p-value.",
|
||||
ParamValidators.inRange(0, 1))
|
||||
setDefault(percentile -> 0.1)
|
||||
|
||||
/** @group getParam */
|
||||
@Since("2.1.0")
|
||||
def getPercentile: Double = $(percentile)
|
||||
|
||||
/**
|
||||
* The highest p-value for features to be kept.
|
||||
* Only applicable when selectorType = "fpr".
|
||||
* Default value is 0.05.
|
||||
* @group param
|
||||
*/
|
||||
final val alpha = new DoubleParam(this, "alpha", "The highest p-value for features to be kept.",
|
||||
final val fpr = new DoubleParam(this, "fpr", "The highest p-value for features to be kept.",
|
||||
ParamValidators.inRange(0, 1))
|
||||
setDefault(alpha -> 0.05)
|
||||
setDefault(fpr -> 0.05)
|
||||
|
||||
/** @group getParam */
|
||||
def getAlpha: Double = $(alpha)
|
||||
def getFpr: Double = $(fpr)
|
||||
|
||||
/**
|
||||
* The selector type of the ChisqSelector.
|
||||
* Supported options: "kbest" (default), "percentile" and "fpr".
|
||||
* Supported options: "numTopFeatures" (default), "percentile", "fpr".
|
||||
* @group param
|
||||
*/
|
||||
@Since("2.1.0")
|
||||
final val selectorType = new Param[String](this, "selectorType",
|
||||
"The selector type of the ChisqSelector. " +
|
||||
"Supported options: kbest (default), percentile and fpr.",
|
||||
ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes.toArray))
|
||||
setDefault(selectorType -> OldChiSqSelector.KBest)
|
||||
"Supported options: " + OldChiSqSelector.supportedSelectorTypes.mkString(", "),
|
||||
ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes))
|
||||
setDefault(selectorType -> OldChiSqSelector.NumTopFeatures)
|
||||
|
||||
/** @group getParam */
|
||||
@Since("2.1.0")
|
||||
def getSelectorType: String = $(selectorType)
|
||||
}
|
||||
|
||||
/**
|
||||
* Chi-Squared feature selection, which selects categorical features to use for predicting a
|
||||
* categorical label.
|
||||
* The selector supports three selection methods: `kbest`, `percentile` and `fpr`.
|
||||
* `kbest` chooses the `k` top features according to a chi-squared test.
|
||||
* `percentile` is similar but chooses a fraction of all features instead of a fixed number.
|
||||
* `fpr` chooses all features whose false positive rate meets some threshold.
|
||||
* By default, the selection method is `kbest`, the default number of top features is 50.
|
||||
* The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`.
|
||||
* - `numTopFeatures` chooses a fixed number of top features according to a chi-squared test.
|
||||
* - `percentile` is similar but chooses a fraction of all features instead of a fixed number.
|
||||
* - `fpr` chooses all features whose p-value is below a threshold, thus controlling the false
|
||||
* positive rate of selection.
|
||||
* By default, the selection method is `numTopFeatures`, with the default number of top features
|
||||
* set to 50.
|
||||
*/
|
||||
@Since("1.6.0")
|
||||
final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: String)
|
||||
|
@ -113,10 +124,6 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
|
|||
@Since("1.6.0")
|
||||
def this() = this(Identifiable.randomUID("chiSqSelector"))
|
||||
|
||||
/** @group setParam */
|
||||
@Since("2.1.0")
|
||||
def setSelectorType(value: String): this.type = set(selectorType, value)
|
||||
|
||||
/** @group setParam */
|
||||
@Since("1.6.0")
|
||||
def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value)
|
||||
|
@ -127,7 +134,11 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
|
|||
|
||||
/** @group setParam */
|
||||
@Since("2.1.0")
|
||||
def setAlpha(value: Double): this.type = set(alpha, value)
|
||||
def setFpr(value: Double): this.type = set(fpr, value)
|
||||
|
||||
/** @group setParam */
|
||||
@Since("2.1.0")
|
||||
def setSelectorType(value: String): this.type = set(selectorType, value)
|
||||
|
||||
/** @group setParam */
|
||||
@Since("1.6.0")
|
||||
|
@ -153,15 +164,15 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
|
|||
.setSelectorType($(selectorType))
|
||||
.setNumTopFeatures($(numTopFeatures))
|
||||
.setPercentile($(percentile))
|
||||
.setAlpha($(alpha))
|
||||
.setFpr($(fpr))
|
||||
val model = selector.fit(input)
|
||||
copyValues(new ChiSqSelectorModel(uid, model).setParent(this))
|
||||
}
|
||||
|
||||
@Since("1.6.0")
|
||||
override def transformSchema(schema: StructType): StructType = {
|
||||
val otherPairs = OldChiSqSelector.supportedTypeAndParamPairs.filter(_._1 != $(selectorType))
|
||||
otherPairs.foreach { case (_, paramName: String) =>
|
||||
val otherPairs = OldChiSqSelector.supportedSelectorTypes.filter(_ != $(selectorType))
|
||||
otherPairs.foreach { paramName: String =>
|
||||
if (isSet(getParam(paramName))) {
|
||||
logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.")
|
||||
}
|
||||
|
|
|
@ -638,13 +638,13 @@ private[python] class PythonMLLibAPI extends Serializable {
|
|||
selectorType: String,
|
||||
numTopFeatures: Int,
|
||||
percentile: Double,
|
||||
alpha: Double,
|
||||
fpr: Double,
|
||||
data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
|
||||
new ChiSqSelector()
|
||||
.setSelectorType(selectorType)
|
||||
.setNumTopFeatures(numTopFeatures)
|
||||
.setPercentile(percentile)
|
||||
.setAlpha(alpha)
|
||||
.setFpr(fpr)
|
||||
.fit(data.rdd)
|
||||
}
|
||||
|
||||
|
|
|
@ -161,7 +161,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
|
|||
Loader.checkSchema[Data](dataFrame.schema)
|
||||
|
||||
val features = dataArray.rdd.map {
|
||||
case Row(feature: Int) => (feature)
|
||||
case Row(feature: Int) => feature
|
||||
}.collect()
|
||||
|
||||
new ChiSqSelectorModel(features)
|
||||
|
@ -171,18 +171,20 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
|
|||
|
||||
/**
|
||||
* Creates a ChiSquared feature selector.
|
||||
* The selector supports three selection methods: `kbest`, `percentile` and `fpr`.
|
||||
* `kbest` chooses the `k` top features according to a chi-squared test.
|
||||
* `percentile` is similar but chooses a fraction of all features instead of a fixed number.
|
||||
* `fpr` chooses all features whose false positive rate meets some threshold.
|
||||
* By default, the selection method is `kbest`, the default number of top features is 50.
|
||||
* The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`.
|
||||
* - `numTopFeatures` chooses a fixed number of top features according to a chi-squared test.
|
||||
* - `percentile` is similar but chooses a fraction of all features instead of a fixed number.
|
||||
* - `fpr` chooses all features whose p-value is below a threshold, thus controlling the false
|
||||
* positive rate of selection.
|
||||
* By default, the selection method is `numTopFeatures`, with the default number of top features
|
||||
* set to 50.
|
||||
*/
|
||||
@Since("1.3.0")
|
||||
class ChiSqSelector @Since("2.1.0") () extends Serializable {
|
||||
var numTopFeatures: Int = 50
|
||||
var percentile: Double = 0.1
|
||||
var alpha: Double = 0.05
|
||||
var selectorType = ChiSqSelector.KBest
|
||||
var fpr: Double = 0.05
|
||||
var selectorType = ChiSqSelector.NumTopFeatures
|
||||
|
||||
/**
|
||||
* The is the same to call this() and setNumTopFeatures(numTopFeatures)
|
||||
|
@ -207,15 +209,15 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
|
|||
}
|
||||
|
||||
@Since("2.1.0")
|
||||
def setAlpha(value: Double): this.type = {
|
||||
require(0.0 <= value && value <= 1.0, "Alpha must be in [0,1]")
|
||||
alpha = value
|
||||
def setFpr(value: Double): this.type = {
|
||||
require(0.0 <= value && value <= 1.0, "FPR must be in [0,1]")
|
||||
fpr = value
|
||||
this
|
||||
}
|
||||
|
||||
@Since("2.1.0")
|
||||
def setSelectorType(value: String): this.type = {
|
||||
require(ChiSqSelector.supportedSelectorTypes.toSeq.contains(value),
|
||||
require(ChiSqSelector.supportedSelectorTypes.contains(value),
|
||||
s"ChiSqSelector Type: $value was not supported.")
|
||||
selectorType = value
|
||||
this
|
||||
|
@ -232,7 +234,7 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
|
|||
def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
|
||||
val chiSqTestResult = Statistics.chiSqTest(data).zipWithIndex
|
||||
val features = selectorType match {
|
||||
case ChiSqSelector.KBest =>
|
||||
case ChiSqSelector.NumTopFeatures =>
|
||||
chiSqTestResult
|
||||
.sortBy { case (res, _) => res.pValue }
|
||||
.take(numTopFeatures)
|
||||
|
@ -242,7 +244,7 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
|
|||
.take((chiSqTestResult.length * percentile).toInt)
|
||||
case ChiSqSelector.FPR =>
|
||||
chiSqTestResult
|
||||
.filter { case (res, _) => res.pValue < alpha }
|
||||
.filter { case (res, _) => res.pValue < fpr }
|
||||
case errorType =>
|
||||
throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
|
||||
}
|
||||
|
@ -251,22 +253,17 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
|
|||
}
|
||||
}
|
||||
|
||||
@Since("2.1.0")
|
||||
object ChiSqSelector {
|
||||
private[spark] object ChiSqSelector {
|
||||
|
||||
/** String name for `kbest` selector type. */
|
||||
private[spark] val KBest: String = "kbest"
|
||||
/** String name for `numTopFeatures` selector type. */
|
||||
val NumTopFeatures: String = "numTopFeatures"
|
||||
|
||||
/** String name for `percentile` selector type. */
|
||||
private[spark] val Percentile: String = "percentile"
|
||||
val Percentile: String = "percentile"
|
||||
|
||||
/** String name for `fpr` selector type. */
|
||||
private[spark] val FPR: String = "fpr"
|
||||
|
||||
/** Set of selector type and param pairs that ChiSqSelector supports. */
|
||||
private[spark] val supportedTypeAndParamPairs = Set(KBest -> "numTopFeatures",
|
||||
Percentile -> "percentile", FPR -> "alpha")
|
||||
|
||||
/** Set of selector types that ChiSqSelector supports. */
|
||||
private[spark] val supportedSelectorTypes = supportedTypeAndParamPairs.map(_._1)
|
||||
val supportedSelectorTypes: Array[String] = Array(NumTopFeatures, Percentile, FPR)
|
||||
}
|
||||
|
|
|
@ -19,85 +19,72 @@ package org.apache.spark.ml.feature
|
|||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.ml.param.ParamsSuite
|
||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
|
||||
import org.apache.spark.ml.util.TestingUtils._
|
||||
import org.apache.spark.mllib.feature
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.{Dataset, Row}
|
||||
|
||||
class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
|
||||
with DefaultReadWriteTest {
|
||||
|
||||
test("Test Chi-Square selector") {
|
||||
import testImplicits._
|
||||
val data = Seq(
|
||||
LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))),
|
||||
LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))),
|
||||
LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))),
|
||||
LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))
|
||||
)
|
||||
@transient var dataset: Dataset[_] = _
|
||||
|
||||
val preFilteredData = Seq(
|
||||
Vectors.dense(8.0),
|
||||
Vectors.dense(0.0),
|
||||
Vectors.dense(0.0),
|
||||
Vectors.dense(8.0)
|
||||
)
|
||||
override def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
|
||||
val df = sc.parallelize(data.zip(preFilteredData))
|
||||
.map(x => (x._1.label, x._1.features, x._2))
|
||||
.toDF("label", "data", "preFilteredData")
|
||||
// Toy dataset, including the top feature for a chi-squared test.
|
||||
// These data are chosen such that each feature's test has a distinct p-value.
|
||||
/* To verify the results with R, run:
|
||||
library(stats)
|
||||
x1 <- c(8.0, 0.0, 0.0, 7.0, 8.0)
|
||||
x2 <- c(7.0, 9.0, 9.0, 9.0, 7.0)
|
||||
x3 <- c(0.0, 6.0, 8.0, 5.0, 3.0)
|
||||
y <- c(0.0, 1.0, 1.0, 2.0, 2.0)
|
||||
chisq.test(x1,y)
|
||||
chisq.test(x2,y)
|
||||
chisq.test(x3,y)
|
||||
*/
|
||||
dataset = spark.createDataFrame(Seq(
|
||||
(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0))), Vectors.dense(8.0)),
|
||||
(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0))), Vectors.dense(0.0)),
|
||||
(1.0, Vectors.dense(Array(0.0, 9.0, 8.0)), Vectors.dense(0.0)),
|
||||
(2.0, Vectors.dense(Array(7.0, 9.0, 5.0)), Vectors.dense(7.0)),
|
||||
(2.0, Vectors.dense(Array(8.0, 7.0, 3.0)), Vectors.dense(8.0))
|
||||
)).toDF("label", "features", "topFeature")
|
||||
}
|
||||
|
||||
test("params") {
|
||||
ParamsSuite.checkParams(new ChiSqSelector)
|
||||
val model = new ChiSqSelectorModel("myModel",
|
||||
new org.apache.spark.mllib.feature.ChiSqSelectorModel(Array(1, 3, 4)))
|
||||
ParamsSuite.checkParams(model)
|
||||
}
|
||||
|
||||
test("Test Chi-Square selector: numTopFeatures") {
|
||||
val selector = new ChiSqSelector()
|
||||
.setSelectorType("kbest")
|
||||
.setNumTopFeatures(1)
|
||||
.setFeaturesCol("data")
|
||||
.setLabelCol("label")
|
||||
.setOutputCol("filtered")
|
||||
.setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(1)
|
||||
ChiSqSelectorSuite.testSelector(selector, dataset)
|
||||
}
|
||||
|
||||
selector.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach {
|
||||
case Row(vec1: Vector, vec2: Vector) =>
|
||||
assert(vec1 ~== vec2 absTol 1e-1)
|
||||
test("Test Chi-Square selector: percentile") {
|
||||
val selector = new ChiSqSelector()
|
||||
.setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.34)
|
||||
ChiSqSelectorSuite.testSelector(selector, dataset)
|
||||
}
|
||||
|
||||
test("Test Chi-Square selector: fpr") {
|
||||
val selector = new ChiSqSelector()
|
||||
.setOutputCol("filtered").setSelectorType("fpr").setFpr(0.2)
|
||||
ChiSqSelectorSuite.testSelector(selector, dataset)
|
||||
}
|
||||
|
||||
test("read/write") {
|
||||
def checkModelData(model: ChiSqSelectorModel, model2: ChiSqSelectorModel): Unit = {
|
||||
assert(model.selectedFeatures === model2.selectedFeatures)
|
||||
}
|
||||
|
||||
selector.setSelectorType("percentile").setPercentile(0.34).fit(df).transform(df)
|
||||
.select("filtered", "preFilteredData").collect().foreach {
|
||||
case Row(vec1: Vector, vec2: Vector) =>
|
||||
assert(vec1 ~== vec2 absTol 1e-1)
|
||||
}
|
||||
|
||||
val preFilteredData2 = Seq(
|
||||
Vectors.dense(8.0, 7.0),
|
||||
Vectors.dense(0.0, 9.0),
|
||||
Vectors.dense(0.0, 9.0),
|
||||
Vectors.dense(8.0, 9.0)
|
||||
)
|
||||
|
||||
val df2 = sc.parallelize(data.zip(preFilteredData2))
|
||||
.map(x => (x._1.label, x._1.features, x._2))
|
||||
.toDF("label", "data", "preFilteredData")
|
||||
|
||||
selector.setSelectorType("fpr").setAlpha(0.2).fit(df2).transform(df2)
|
||||
.select("filtered", "preFilteredData").collect().foreach {
|
||||
case Row(vec1: Vector, vec2: Vector) =>
|
||||
assert(vec1 ~== vec2 absTol 1e-1)
|
||||
}
|
||||
}
|
||||
|
||||
test("ChiSqSelector read/write") {
|
||||
val t = new ChiSqSelector()
|
||||
.setFeaturesCol("myFeaturesCol")
|
||||
.setLabelCol("myLabelCol")
|
||||
.setOutputCol("myOutputCol")
|
||||
.setNumTopFeatures(2)
|
||||
testDefaultReadWrite(t)
|
||||
}
|
||||
|
||||
test("ChiSqSelectorModel read/write") {
|
||||
val oldModel = new feature.ChiSqSelectorModel(Array(1, 3))
|
||||
val instance = new ChiSqSelectorModel("myChiSqSelectorModel", oldModel)
|
||||
val newInstance = testDefaultReadWrite(instance)
|
||||
assert(newInstance.selectedFeatures === instance.selectedFeatures)
|
||||
val nb = new ChiSqSelector
|
||||
testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings, checkModelData)
|
||||
}
|
||||
|
||||
test("should support all NumericType labels and not support other types") {
|
||||
|
@ -108,3 +95,25 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
object ChiSqSelectorSuite {
|
||||
|
||||
private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): Unit = {
|
||||
selector.fit(dataset).transform(dataset).select("filtered", "topFeature").collect()
|
||||
.foreach { case Row(vec1: Vector, vec2: Vector) =>
|
||||
assert(vec1 ~== vec2 absTol 1e-1)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Mapping from all Params to valid settings which differ from the defaults.
|
||||
* This is useful for tests which need to exercise all Params, such as save/load.
|
||||
* This excludes input columns to simplify some tests.
|
||||
*/
|
||||
val allParamSettings: Map[String, Any] = Map(
|
||||
"selectorType" -> "percentile",
|
||||
"numTopFeatures" -> 1,
|
||||
"percentile" -> 0.12,
|
||||
"outputCol" -> "myOutput"
|
||||
)
|
||||
}
|
||||
|
|
|
@ -54,33 +54,34 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))),
|
||||
LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))), 2)
|
||||
val preFilteredData =
|
||||
Set(LabeledPoint(0.0, Vectors.dense(Array(8.0))),
|
||||
Seq(LabeledPoint(0.0, Vectors.dense(Array(8.0))),
|
||||
LabeledPoint(1.0, Vectors.dense(Array(0.0))),
|
||||
LabeledPoint(1.0, Vectors.dense(Array(0.0))),
|
||||
LabeledPoint(2.0, Vectors.dense(Array(8.0))))
|
||||
val model = new ChiSqSelector(1).fit(labeledDiscreteData)
|
||||
val filteredData = labeledDiscreteData.map { lp =>
|
||||
LabeledPoint(lp.label, model.transform(lp.features))
|
||||
}.collect().toSet
|
||||
assert(filteredData == preFilteredData)
|
||||
}.collect().toSeq
|
||||
assert(filteredData === preFilteredData)
|
||||
}
|
||||
|
||||
test("ChiSqSelector by FPR transform test (sparse & dense vector)") {
|
||||
test("ChiSqSelector by fpr transform test (sparse & dense vector)") {
|
||||
val labeledDiscreteData = sc.parallelize(
|
||||
Seq(LabeledPoint(0.0, Vectors.sparse(4, Array((0, 8.0), (1, 7.0)))),
|
||||
LabeledPoint(1.0, Vectors.sparse(4, Array((1, 9.0), (2, 6.0), (3, 4.0)))),
|
||||
LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0, 4.0))),
|
||||
LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0, 9.0)))), 2)
|
||||
val preFilteredData =
|
||||
Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))),
|
||||
Seq(LabeledPoint(0.0, Vectors.dense(Array(0.0))),
|
||||
LabeledPoint(1.0, Vectors.dense(Array(4.0))),
|
||||
LabeledPoint(1.0, Vectors.dense(Array(4.0))),
|
||||
LabeledPoint(2.0, Vectors.dense(Array(9.0))))
|
||||
val model = new ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(labeledDiscreteData)
|
||||
val model: ChiSqSelectorModel = new ChiSqSelector().setSelectorType("fpr")
|
||||
.setFpr(0.1).fit(labeledDiscreteData)
|
||||
val filteredData = labeledDiscreteData.map { lp =>
|
||||
LabeledPoint(lp.label, model.transform(lp.features))
|
||||
}.collect().toSet
|
||||
assert(filteredData == preFilteredData)
|
||||
}.collect().toSeq
|
||||
assert(filteredData === preFilteredData)
|
||||
}
|
||||
|
||||
test("model load / save") {
|
||||
|
|
|
@ -2606,42 +2606,43 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja
|
|||
|
||||
selectorType = Param(Params._dummy(), "selectorType",
|
||||
"The selector type of the ChisqSelector. " +
|
||||
"Supported options: kbest (default), percentile and fpr.",
|
||||
"Supported options: numTopFeatures (default), percentile and fpr.",
|
||||
typeConverter=TypeConverters.toString)
|
||||
|
||||
numTopFeatures = \
|
||||
Param(Params._dummy(), "numTopFeatures",
|
||||
"Number of features that selector will select, ordered by statistics value " +
|
||||
"descending. If the number of features is < numTopFeatures, then this will select " +
|
||||
"Number of features that selector will select, ordered by ascending p-value. " +
|
||||
"If the number of features is < numTopFeatures, then this will select " +
|
||||
"all features.", typeConverter=TypeConverters.toInt)
|
||||
|
||||
percentile = Param(Params._dummy(), "percentile", "Percentile of features that selector " +
|
||||
"will select, ordered by statistics value descending.",
|
||||
"will select, ordered by ascending p-value.",
|
||||
typeConverter=TypeConverters.toFloat)
|
||||
|
||||
alpha = Param(Params._dummy(), "alpha", "The highest p-value for features to be kept.",
|
||||
typeConverter=TypeConverters.toFloat)
|
||||
fpr = Param(Params._dummy(), "fpr", "The highest p-value for features to be kept.",
|
||||
typeConverter=TypeConverters.toFloat)
|
||||
|
||||
@keyword_only
|
||||
def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None,
|
||||
labelCol="label", selectorType="kbest", percentile=0.1, alpha=0.05):
|
||||
labelCol="label", selectorType="numTopFeatures", percentile=0.1, fpr=0.05):
|
||||
"""
|
||||
__init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, \
|
||||
labelCol="label", selectorType="kbest", percentile=0.1, alpha=0.05)
|
||||
labelCol="label", selectorType="numTopFeatures", percentile=0.1, fpr=0.05)
|
||||
"""
|
||||
super(ChiSqSelector, self).__init__()
|
||||
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid)
|
||||
self._setDefault(numTopFeatures=50, selectorType="kbest", percentile=0.1, alpha=0.05)
|
||||
self._setDefault(numTopFeatures=50, selectorType="numTopFeatures", percentile=0.1,
|
||||
fpr=0.05)
|
||||
kwargs = self.__init__._input_kwargs
|
||||
self.setParams(**kwargs)
|
||||
|
||||
@keyword_only
|
||||
@since("2.0.0")
|
||||
def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,
|
||||
labelCol="labels", selectorType="kbest", percentile=0.1, alpha=0.05):
|
||||
labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05):
|
||||
"""
|
||||
setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, \
|
||||
labelCol="labels", selectorType="kbest", percentile=0.1, alpha=0.05)
|
||||
labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05)
|
||||
Sets params for this ChiSqSelector.
|
||||
"""
|
||||
kwargs = self.setParams._input_kwargs
|
||||
|
@ -2665,7 +2666,7 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja
|
|||
def setNumTopFeatures(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`numTopFeatures`.
|
||||
Only applicable when selectorType = "kbest".
|
||||
Only applicable when selectorType = "numTopFeatures".
|
||||
"""
|
||||
return self._set(numTopFeatures=value)
|
||||
|
||||
|
@ -2692,19 +2693,19 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja
|
|||
return self.getOrDefault(self.percentile)
|
||||
|
||||
@since("2.1.0")
|
||||
def setAlpha(self, value):
|
||||
def setFpr(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`alpha`.
|
||||
Sets the value of :py:attr:`fpr`.
|
||||
Only applicable when selectorType = "fpr".
|
||||
"""
|
||||
return self._set(alpha=value)
|
||||
return self._set(fpr=value)
|
||||
|
||||
@since("2.1.0")
|
||||
def getAlpha(self):
|
||||
def getFpr(self):
|
||||
"""
|
||||
Gets the value of alpha or its default value.
|
||||
Gets the value of fpr or its default value.
|
||||
"""
|
||||
return self.getOrDefault(self.alpha)
|
||||
return self.getOrDefault(self.fpr)
|
||||
|
||||
def _create_model(self, java_model):
|
||||
return ChiSqSelectorModel(java_model)
|
||||
|
|
|
@ -274,52 +274,48 @@ class ChiSqSelectorModel(JavaVectorTransformer):
|
|||
class ChiSqSelector(object):
|
||||
"""
|
||||
Creates a ChiSquared feature selector.
|
||||
The selector supports three selection methods: `KBest`, `Percentile` and `FPR`.
|
||||
`kbest` chooses the `k` top features according to a chi-squared test.
|
||||
The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`.
|
||||
`numTopFeatures` chooses a fixed number of top features according to a chi-squared test.
|
||||
`percentile` is similar but chooses a fraction of all features instead of a fixed number.
|
||||
`fpr` chooses all features whose false positive rate meets some threshold.
|
||||
By default, the selection method is `kbest`, the default number of top features is 50.
|
||||
`fpr` chooses all features whose p-value is below a threshold, thus controlling the false
|
||||
positive rate of selection.
|
||||
By default, the selection method is `numTopFeatures`, with the default number of top features
|
||||
set to 50.
|
||||
|
||||
>>> data = [
|
||||
>>> data = sc.parallelize([
|
||||
... LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})),
|
||||
... LabeledPoint(1.0, SparseVector(3, {1: 9.0, 2: 6.0})),
|
||||
... LabeledPoint(1.0, [0.0, 9.0, 8.0]),
|
||||
... LabeledPoint(2.0, [8.0, 9.0, 5.0])
|
||||
... ]
|
||||
>>> model = ChiSqSelector().setNumTopFeatures(1).fit(sc.parallelize(data))
|
||||
... LabeledPoint(2.0, [7.0, 9.0, 5.0]),
|
||||
... LabeledPoint(2.0, [8.0, 7.0, 3.0])
|
||||
... ])
|
||||
>>> model = ChiSqSelector(numTopFeatures=1).fit(data)
|
||||
>>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0}))
|
||||
SparseVector(1, {})
|
||||
>>> model.transform(DenseVector([8.0, 9.0, 5.0]))
|
||||
DenseVector([8.0])
|
||||
>>> model = ChiSqSelector().setSelectorType("percentile").setPercentile(0.34).fit(
|
||||
... sc.parallelize(data))
|
||||
>>> model.transform(DenseVector([7.0, 9.0, 5.0]))
|
||||
DenseVector([7.0])
|
||||
>>> model = ChiSqSelector(selectorType="fpr", fpr=0.2).fit(data)
|
||||
>>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0}))
|
||||
SparseVector(1, {})
|
||||
>>> model.transform(DenseVector([8.0, 9.0, 5.0]))
|
||||
DenseVector([8.0])
|
||||
>>> data = [
|
||||
... LabeledPoint(0.0, SparseVector(4, {0: 8.0, 1: 7.0})),
|
||||
... LabeledPoint(1.0, SparseVector(4, {1: 9.0, 2: 6.0, 3: 4.0})),
|
||||
... LabeledPoint(1.0, [0.0, 9.0, 8.0, 4.0]),
|
||||
... LabeledPoint(2.0, [8.0, 9.0, 5.0, 9.0])
|
||||
... ]
|
||||
>>> model = ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(sc.parallelize(data))
|
||||
>>> model.transform(DenseVector([1.0,2.0,3.0,4.0]))
|
||||
DenseVector([4.0])
|
||||
>>> model.transform(DenseVector([7.0, 9.0, 5.0]))
|
||||
DenseVector([7.0])
|
||||
>>> model = ChiSqSelector(selectorType="percentile", percentile=0.34).fit(data)
|
||||
>>> model.transform(DenseVector([7.0, 9.0, 5.0]))
|
||||
DenseVector([7.0])
|
||||
|
||||
.. versionadded:: 1.4.0
|
||||
"""
|
||||
def __init__(self, numTopFeatures=50, selectorType="kbest", percentile=0.1, alpha=0.05):
|
||||
def __init__(self, numTopFeatures=50, selectorType="numTopFeatures", percentile=0.1, fpr=0.05):
|
||||
self.numTopFeatures = numTopFeatures
|
||||
self.selectorType = selectorType
|
||||
self.percentile = percentile
|
||||
self.alpha = alpha
|
||||
self.fpr = fpr
|
||||
|
||||
@since('2.1.0')
|
||||
def setNumTopFeatures(self, numTopFeatures):
|
||||
"""
|
||||
set numTopFeature for feature selection by number of top features.
|
||||
Only applicable when selectorType = "kbest".
|
||||
Only applicable when selectorType = "numTopFeatures".
|
||||
"""
|
||||
self.numTopFeatures = int(numTopFeatures)
|
||||
return self
|
||||
|
@ -334,19 +330,19 @@ class ChiSqSelector(object):
|
|||
return self
|
||||
|
||||
@since('2.1.0')
|
||||
def setAlpha(self, alpha):
|
||||
def setFpr(self, fpr):
|
||||
"""
|
||||
set alpha [0.0, 1.0] for feature selection by FPR.
|
||||
set FPR [0.0, 1.0] for feature selection by FPR.
|
||||
Only applicable when selectorType = "fpr".
|
||||
"""
|
||||
self.alpha = float(alpha)
|
||||
self.fpr = float(fpr)
|
||||
return self
|
||||
|
||||
@since('2.1.0')
|
||||
def setSelectorType(self, selectorType):
|
||||
"""
|
||||
set the selector type of the ChisqSelector.
|
||||
Supported options: "kbest" (default), "percentile" and "fpr".
|
||||
Supported options: "numTopFeatures" (default), "percentile", "fpr".
|
||||
"""
|
||||
self.selectorType = str(selectorType)
|
||||
return self
|
||||
|
@ -362,7 +358,7 @@ class ChiSqSelector(object):
|
|||
Apply feature discretizer before using this function.
|
||||
"""
|
||||
jmodel = callMLlibFunc("fitChiSqSelector", self.selectorType, self.numTopFeatures,
|
||||
self.percentile, self.alpha, data)
|
||||
self.percentile, self.fpr, data)
|
||||
return ChiSqSelectorModel(jmodel)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue