[SPARK-17704][ML][MLLIB] ChiSqSelector performance improvement.

## What changes were proposed in this pull request?

Partial revert of #15277 to instead sort and store input to model rather than require sorted input

## How was this patch tested?

Existing tests.

Author: Sean Owen <sowen@cloudera.com>

Closes #15299 from srowen/SPARK-17704.2.
This commit is contained in:
Sean Owen 2016-10-01 16:10:39 -04:00
parent af6ece33d3
commit b88cb63da3
No known key found for this signature in database
GPG key ID: BEB3956D6717BDDC
3 changed files with 13 additions and 13 deletions

View file

@ -193,7 +193,7 @@ final class ChiSqSelectorModel private[ml] (
import ChiSqSelectorModel._
/** list of indices to select (filter). Must be ordered asc */
/** list of indices to select (filter). */
@Since("1.6.0")
val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures

View file

@ -35,14 +35,15 @@ import org.apache.spark.sql.{Row, SparkSession}
/**
* Chi Squared selector model.
*
* @param selectedFeatures list of indices to select (filter). Must be ordered asc
* @param selectedFeatures list of indices to select (filter).
*/
@Since("1.3.0")
class ChiSqSelectorModel @Since("1.3.0") (
@Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable {
require(isSorted(selectedFeatures), "Array has to be sorted asc")
private val filterIndices = selectedFeatures.sorted
@deprecated("not intended for subclasses to use", "2.1.0")
protected def isSorted(array: Array[Int]): Boolean = {
var i = 1
val len = array.length
@ -61,7 +62,7 @@ class ChiSqSelectorModel @Since("1.3.0") (
*/
@Since("1.3.0")
override def transform(vector: Vector): Vector = {
compress(vector, selectedFeatures)
compress(vector)
}
/**
@ -69,9 +70,8 @@ class ChiSqSelectorModel @Since("1.3.0") (
* Preserves the order of filtered features the same as their indices are stored.
* Might be moved to Vector as .slice
* @param features vector
* @param filterIndices indices of features to filter, must be ordered asc
*/
private def compress(features: Vector, filterIndices: Array[Int]): Vector = {
private def compress(features: Vector): Vector = {
features match {
case SparseVector(size, indices, values) =>
val newSize = filterIndices.length
@ -230,23 +230,23 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
*/
@Since("1.3.0")
def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
val chiSqTestResult = Statistics.chiSqTest(data)
val chiSqTestResult = Statistics.chiSqTest(data).zipWithIndex
val features = selectorType match {
case ChiSqSelector.KBest =>
chiSqTestResult.zipWithIndex
chiSqTestResult
.sortBy { case (res, _) => -res.statistic }
.take(numTopFeatures)
case ChiSqSelector.Percentile =>
chiSqTestResult.zipWithIndex
chiSqTestResult
.sortBy { case (res, _) => -res.statistic }
.take((chiSqTestResult.length * percentile).toInt)
case ChiSqSelector.FPR =>
chiSqTestResult.zipWithIndex
.filter{ case (res, _) => res.pValue < alpha }
chiSqTestResult
.filter { case (res, _) => res.pValue < alpha }
case errorType =>
throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
}
val indices = features.map { case (_, indices) => indices }.sorted
val indices = features.map { case (_, index) => index }
new ChiSqSelectorModel(indices)
}
}

View file

@ -2705,7 +2705,7 @@ class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable):
@since("2.0.0")
def selectedFeatures(self):
"""
List of indices to select (filter). Must be ordered asc.
List of indices to select (filter).
"""
return self._call_java("selectedFeatures")