[SPARK-5890] [ML] Add feature discretizer
JIRA issue [here](https://issues.apache.org/jira/browse/SPARK-5890). I borrowed the `findSplits` code from `RandomForest` rather than calling it from `RandomForest` directly, which I don't think would be a good design. Author: Xusen Yin <yinxusen@gmail.com> Closes #5779 from yinxusen/SPARK-5890.
This commit is contained in:
parent
2a717821bb
commit
23a9448c04
|
@ -0,0 +1,176 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.ml.feature
|
||||||
|
|
||||||
|
import scala.collection.mutable
|
||||||
|
|
||||||
|
import org.apache.spark.Logging
|
||||||
|
import org.apache.spark.annotation.Experimental
|
||||||
|
import org.apache.spark.ml._
|
||||||
|
import org.apache.spark.ml.attribute.NominalAttribute
|
||||||
|
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
|
||||||
|
import org.apache.spark.ml.param.{IntParam, _}
|
||||||
|
import org.apache.spark.ml.util._
|
||||||
|
import org.apache.spark.sql.types.{DoubleType, StructType}
|
||||||
|
import org.apache.spark.sql.{DataFrame, Row}
|
||||||
|
import org.apache.spark.util.random.XORShiftRandom
|
||||||
|
|
||||||
|
/**
 * Shared parameters for [[QuantileDiscretizer]]: the input and output column names plus the
 * requested number of buckets.
 */
private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol with HasOutputCol {

  /**
   * Upper bound on the number of buckets (quantiles, or categories) that data points are grouped
   * into. The attached validator rejects values below 2.
   * default: 2
   * @group param
   */
  val numBuckets: IntParam = new IntParam(
    this,
    "numBuckets",
    "Maximum number of buckets (quantiles, or categories) into which data points are grouped. " +
      "Must be >= 2.",
    ParamValidators.gtEq(2))

  setDefault(numBuckets -> 2)

  /** @group getParam */
  def getNumBuckets: Int = $(numBuckets)
}
|
||||||
|
|
||||||
|
/**
 * :: Experimental ::
 * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
 * categorical features. The bin ranges are chosen by taking a sample of the data and dividing it
 * into roughly equal parts. The lower and upper bin bounds will be -Infinity and +Infinity,
 * covering all real values. This attempts to find numBuckets partitions based on a sample of data,
 * but it may find fewer depending on the data sample values.
 */
@Experimental
final class QuantileDiscretizer(override val uid: String)
  extends Estimator[Bucketizer] with QuantileDiscretizerBase {

  def this() = this(Identifiable.randomUID("quantileDiscretizer"))

  /** @group setParam */
  def setNumBuckets(value: Int): this.type = set(numBuckets, value)

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * Checks that the input column is of `DoubleType` and that the output column name is not
   * already taken, then returns the input schema with a nominal-attribute field named after the
   * output column appended.
   */
  override def transformSchema(schema: StructType): StructType = {
    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
    val inputFields = schema.fields
    require(inputFields.forall(_.name != $(outputCol)),
      s"Output column ${$(outputCol)} already exists.")
    val attr = NominalAttribute.defaultAttr.withName($(outputCol))
    val outputFields = inputFields :+ attr.toStructField()
    StructType(outputFields)
  }

  /**
   * Samples the input column, derives at most `numBuckets - 1` split candidates from the sample,
   * pads them with the infinite outer bounds and returns a [[Bucketizer]] configured with those
   * splits. The returned model shares this estimator's uid, has this estimator set as its parent
   * and receives this estimator's param values through `copyValues`.
   */
  override def fit(dataset: DataFrame): Bucketizer = {
    // NOTE: a sampled row whose single value is not a Double (for example null) does not match
    // the pattern below and fails with a scala.MatchError.
    val samples = QuantileDiscretizer.getSampledInput(dataset.select($(inputCol)), $(numBuckets))
      .map { case Row(feature: Double) => feature }
    val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1)
    val splits = QuantileDiscretizer.getSplits(candidates)
    val bucketizer = new Bucketizer(uid).setSplits(splits)
    // Record the estimator that produced the model so that `model.parent` is usable by callers.
    copyValues(bucketizer.setParent(this))
  }

  override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra)
}
|
||||||
|
|
||||||
|
private[feature] object QuantileDiscretizer extends Logging {

  /**
   * Draws a random sample (without replacement) from `dataset` and collects it.
   *
   * The targeted sample size is `max(numBins * numBins, 10000)` rows; the sampling fraction is
   * that target divided by the row count, capped at 1.0.
   *
   * @param dataset DataFrame to sample from; must contain at least one row
   * @param numBins number of bins requested, used only to size the sample
   * @return the sampled rows
   * @throws IllegalArgumentException if `dataset` is empty
   */
  def getSampledInput(dataset: DataFrame, numBins: Int): Array[Row] = {
    val totalSamples = dataset.count()
    require(totalSamples > 0,
      "QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
    // Widen to Long before multiplying: numBins * numBins overflows Int for large numBins.
    val requiredSamples = math.max(numBins.toLong * numBins, 10000L)
    // Divide in floating point. Integer division would truncate the fraction to 0 whenever the
    // dataset has more rows than requiredSamples, producing an empty sample. The row count
    // computed above is reused instead of running a second count() job.
    val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0)
    dataset.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect()
  }

  /**
   * Compute split points with respect to the sample distribution.
   *
   * If the sample has no more distinct values than `numSplits`, every distinct value is returned
   * in ascending order. Otherwise the distinct values are walked in ascending order while
   * accumulating their counts, and a value becomes a split whenever the running count is closest
   * to the next multiple of the stride `ceil(samples.length / (numSplits + 1))`.
   *
   * @param samples   sampled feature values
   * @param numSplits desired number of split points
   * @return split candidates in ascending order
   */
  def findSplitCandidates(samples: Array[Double], numSplits: Int): Array[Double] = {
    // Histogram of distinct sample values.
    val valueCountMap = samples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
      m + ((x, m.getOrElse(x, 0) + 1))
    }
    // The loop below emits the value *before* the one being visited, so a trailing sentinel entry
    // is appended to let the largest real value be considered as a split as well.
    val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray ++ Array((Double.MaxValue, 1))
    val possibleSplits = valueCounts.length - 1
    if (possibleSplits <= numSplits) {
      valueCounts.dropRight(1).map(_._1)
    } else {
      val stride: Double = math.ceil(samples.length.toDouble / (numSplits + 1))
      val splitsBuilder = mutable.ArrayBuilder.make[Double]
      var index = 1
      // currentCount: sum of counts of values that have been visited
      var currentCount = valueCounts(0)._2
      // targetCount: target value for `currentCount`. If `currentCount` is closest value to
      // `targetCount`, then current value is a split threshold. After finding a split threshold,
      // `targetCount` is added by stride.
      var targetCount = stride
      while (index < valueCounts.length) {
        val previousCount = currentCount
        currentCount += valueCounts(index)._2
        val previousGap = math.abs(previousCount - targetCount)
        val currentGap = math.abs(currentCount - targetCount)
        // If adding the count of the current value moves currentCount further away from
        // targetCount, the previous value is a split threshold.
        if (previousGap < currentGap) {
          splitsBuilder += valueCounts(index - 1)._1
          targetCount += stride
        }
        index += 1
      }
      splitsBuilder.result()
    }
  }

  /**
   * Adjust split candidates to proper splits by: adding positive/negative infinity to both sides as
   * needed, and adding a default split value of 0 if no good candidates are found.
   *
   * At most one leading negative infinity and one trailing positive infinity are stripped from
   * `candidates` before the infinite bounds are re-attached, so they never appear twice.
   *
   * @param candidates split candidates, possibly empty
   * @return splits that start with -Infinity and end with +Infinity
   */
  def getSplits(candidates: Array[Double]): Array[Double] = {
    val dropFront =
      if (candidates.nonEmpty && candidates.head == Double.NegativeInfinity) 1 else 0
    val dropBack =
      if (candidates.nonEmpty && candidates.last == Double.PositiveInfinity) 1 else 0
    val effectiveValues = candidates.slice(dropFront, candidates.length - dropBack)

    if (effectiveValues.isEmpty) {
      Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity)
    } else {
      Array(Double.NegativeInfinity) ++ effectiveValues ++ Array(Double.PositiveInfinity)
    }
  }
}
|
||||||
|
|
|
@ -0,0 +1,98 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.ml.feature
|
||||||
|
|
||||||
|
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
|
||||||
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
|
import org.apache.spark.sql.{Row, SQLContext}
|
||||||
|
import org.apache.spark.{SparkContext, SparkFunSuite}
|
||||||
|
|
||||||
|
class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext {
  import org.apache.spark.ml.feature.QuantileDiscretizerSuite._

  test("Test quantile discretizer") {
    // Three distinct values, the largest of which is heavily repeated.
    val input = Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3)

    // (numBuckets, expected bucket index per input value, expected bucket labels)
    val cases: Seq[(Int, Array[Double], Array[String])] = Seq(
      (10,
        Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
        Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity")),
      (4,
        Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
        Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity")),
      (3,
        Array[Double](0, 1, 2, 2, 2, 2, 2, 2, 2),
        Array("-Infinity, 2.0", "2.0, 3.0", "3.0, Infinity")),
      (2,
        Array[Double](0, 1, 1, 1, 1, 1, 1, 1, 1),
        Array("-Infinity, 2.0", "2.0, Infinity"))
    )

    cases.foreach { case (numBuckets, expectedResult, expectedAttrs) =>
      checkDiscretizedData(sc, input, numBuckets, expectedResult, expectedAttrs)
    }
  }

  test("Test getting splits") {
    val negInf = Double.NegativeInfinity
    val posInf = Double.PositiveInfinity
    // The split that getSplits falls back to when no usable candidate remains.
    def fallback: Array[Double] = Array(negInf, 0.0, posInf)

    // (candidates handed to getSplits, splits expected back)
    val splitTestPoints: Seq[(Array[Double], Array[Double])] = Seq(
      (Array[Double](), fallback),
      (Array(negInf), fallback),
      (Array(posInf), fallback),
      (Array(negInf, posInf), fallback),
      (Array(0.0), Array(negInf, 0.0, posInf)),
      (Array(1.0), Array(negInf, 1.0, posInf)),
      (Array(0.0, 1.0), Array(negInf, 0.0, 1.0, posInf))
    )

    splitTestPoints.foreach { case (ori, res) =>
      assert(QuantileDiscretizer.getSplits(ori) === res, "Returned splits are invalid.")
    }
  }
}
|
||||||
|
|
||||||
|
private object QuantileDiscretizerSuite extends SparkFunSuite {

  /**
   * Fits a [[QuantileDiscretizer]] with `numBucket` buckets on `data`, transforms that same data
   * with the fitted model, and asserts on both the transformed values and the nominal attribute
   * values attached to the output column.
   *
   * @param sc             Spark context used to build the input DataFrame
   * @param data           raw feature values placed in the "input" column
   * @param numBucket      value passed to `setNumBuckets`
   * @param expectedResult expected content of the "result" column, in input order
   * @param expectedAttrs  expected nominal attribute values of the "result" column
   */
  def checkDiscretizedData(
      sc: SparkContext,
      data: Array[Double],
      numBucket: Int,
      expectedResult: Array[Double],
      expectedAttrs: Array[String]): Unit = {
    val sqlCtx = SQLContext.getOrCreate(sc)
    import sqlCtx.implicits._

    val inputDf = sc.parallelize(data.map(value => Tuple1(value))).toDF("input")
    val model = new QuantileDiscretizer()
      .setInputCol("input")
      .setOutputCol("result")
      .setNumBuckets(numBucket)
      .fit(inputDf)
    val output = model.transform(inputDf)

    val actualBuckets = output.select("result").collect().map {
      case Row(bucket: Double) => bucket
    }
    val resultAttr = Attribute.fromStructField(output.schema("result"))
      .asInstanceOf[NominalAttribute]
    val actualAttrs = resultAttr.values.get

    assert(actualBuckets === expectedResult,
      "Transformed features do not equal expected features.")
    assert(actualAttrs === expectedAttrs,
      "Transformed attributes do not equal expected attributes.")
  }
}
|
Loading…
Reference in a new issue