[SPARK-17704][ML][MLLIB] ChiSqSelector performance improvement.
## What changes were proposed in this pull request? Partial revert of #15277 to instead sort and store input to model rather than require sorted input ## How was this patch tested? Existing tests. Author: Sean Owen <sowen@cloudera.com> Closes #15299 from srowen/SPARK-17704.2.
This commit is contained in:
parent
af6ece33d3
commit
b88cb63da3
|
@ -193,7 +193,7 @@ final class ChiSqSelectorModel private[ml] (
|
|||
|
||||
import ChiSqSelectorModel._
|
||||
|
||||
/** list of indices to select (filter). Must be ordered asc */
|
||||
/** list of indices to select (filter). */
|
||||
@Since("1.6.0")
|
||||
val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures
|
||||
|
||||
|
|
|
@ -35,14 +35,15 @@ import org.apache.spark.sql.{Row, SparkSession}
|
|||
/**
|
||||
* Chi Squared selector model.
|
||||
*
|
||||
* @param selectedFeatures list of indices to select (filter). Must be ordered asc
|
||||
* @param selectedFeatures list of indices to select (filter).
|
||||
*/
|
||||
@Since("1.3.0")
|
||||
class ChiSqSelectorModel @Since("1.3.0") (
|
||||
@Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable {
|
||||
|
||||
require(isSorted(selectedFeatures), "Array has to be sorted asc")
|
||||
private val filterIndices = selectedFeatures.sorted
|
||||
|
||||
@deprecated("not intended for subclasses to use", "2.1.0")
|
||||
protected def isSorted(array: Array[Int]): Boolean = {
|
||||
var i = 1
|
||||
val len = array.length
|
||||
|
@ -61,7 +62,7 @@ class ChiSqSelectorModel @Since("1.3.0") (
|
|||
*/
|
||||
@Since("1.3.0")
|
||||
override def transform(vector: Vector): Vector = {
|
||||
compress(vector, selectedFeatures)
|
||||
compress(vector)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -69,9 +70,8 @@ class ChiSqSelectorModel @Since("1.3.0") (
|
|||
* Preserves the order of filtered features the same as their indices are stored.
|
||||
* Might be moved to Vector as .slice
|
||||
* @param features vector
|
||||
* @param filterIndices indices of features to filter, must be ordered asc
|
||||
*/
|
||||
private def compress(features: Vector, filterIndices: Array[Int]): Vector = {
|
||||
private def compress(features: Vector): Vector = {
|
||||
features match {
|
||||
case SparseVector(size, indices, values) =>
|
||||
val newSize = filterIndices.length
|
||||
|
@ -230,23 +230,23 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
|
|||
*/
|
||||
@Since("1.3.0")
|
||||
def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
|
||||
val chiSqTestResult = Statistics.chiSqTest(data)
|
||||
val chiSqTestResult = Statistics.chiSqTest(data).zipWithIndex
|
||||
val features = selectorType match {
|
||||
case ChiSqSelector.KBest =>
|
||||
chiSqTestResult.zipWithIndex
|
||||
chiSqTestResult
|
||||
.sortBy { case (res, _) => -res.statistic }
|
||||
.take(numTopFeatures)
|
||||
case ChiSqSelector.Percentile =>
|
||||
chiSqTestResult.zipWithIndex
|
||||
chiSqTestResult
|
||||
.sortBy { case (res, _) => -res.statistic }
|
||||
.take((chiSqTestResult.length * percentile).toInt)
|
||||
case ChiSqSelector.FPR =>
|
||||
chiSqTestResult.zipWithIndex
|
||||
.filter{ case (res, _) => res.pValue < alpha }
|
||||
chiSqTestResult
|
||||
.filter { case (res, _) => res.pValue < alpha }
|
||||
case errorType =>
|
||||
throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
|
||||
}
|
||||
val indices = features.map { case (_, indices) => indices }.sorted
|
||||
val indices = features.map { case (_, index) => index }
|
||||
new ChiSqSelectorModel(indices)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2705,7 +2705,7 @@ class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable):
|
|||
@since("2.0.0")
|
||||
def selectedFeatures(self):
|
||||
"""
|
||||
List of indices to select (filter). Must be ordered asc.
|
||||
List of indices to select (filter).
|
||||
"""
|
||||
return self._call_java("selectedFeatures")
|
||||
|
||||
|
|
Loading…
Reference in a new issue