[SPARK-19591][ML][MLLIB] Add sample weights to decision trees
This is PR https://github.com/apache/spark/pull/16722, updated to the latest master.

## What changes were proposed in this pull request?

This patch adds support for sample weights to DecisionTreeRegressor and DecisionTreeClassifier.

Note: This patch does not add support for sample weights to RandomForest. As discussed in the JIRA, we would like to add sample weights into the bagging process. This patch is large enough as is, and there are some additional considerations to be made for random forests. Since the machinery introduced here needs to be present regardless, I have opted to leave random forests for a follow-up PR.

## How was this patch tested?

The algorithms are tested to ensure that:

1. Arbitrary scaling of constant weights has no effect
2. Outliers with small weights do not affect the learned model
3. Oversampling and weighting are equivalent

Unit tests are also added to test other smaller components.

## Summary of changes

- Impurity aggregators now store weighted sufficient statistics. They also store a raw count, however, since this is needed to use minInstancesPerNode.
- Impurity aggregators now also hold the raw count.
- This patch maintains the meaning of minInstancesPerNode, in that the parameter still corresponds to raw, unweighted counts. It also adds a new parameter minWeightFractionPerNode which requires that nodes must contain at least minWeightFractionPerNode * weightedNumExamples total weight.
- This patch modifies findSplitsForContinuousFeatures to use weighted sums. Unit tests are added.
- TreePoint is modified to hold a sample weight.
- BaggedPoint is modified from:

```scala
private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
  extends Serializable
```

to

```scala
private[spark] class BaggedPoint[Datum](
    val datum: Datum,
    val subsampleCounts: Array[Int],
    val sampleWeight: Double) extends Serializable
```

We do not simply multiply the counts by the weight and store that, because we need both the raw counts and the weight in order to use both minInstancesPerNode and minWeightPerNode.

**Note**: many of the changed files are due simply to using Instance instead of LabeledPoint.

Closes #21632 from imatiach-msft/ilmat/sample-weights.

Authored-by: Ilya Matiach <ilmat@microsoft.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
This commit is contained in:
parent
3699763fda
commit
b2d36f65db
|
@ -52,7 +52,7 @@ object TestingUtils {
|
|||
/**
|
||||
* Private helper function for comparing two values using absolute tolerance.
|
||||
*/
|
||||
private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
|
||||
private[ml] def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
|
||||
// Special case for NaNs
|
||||
if (x.isNaN && y.isNaN) {
|
||||
return true
|
||||
|
|
|
@ -77,17 +77,37 @@ abstract class Classifier[
|
|||
* @note Throws `SparkException` if any label is a non-integer or is negative
|
||||
*/
|
||||
protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
|
||||
require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
|
||||
s" $numClasses, but requires numClasses > 0.")
|
||||
validateNumClasses(numClasses)
|
||||
dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
|
||||
case Row(label: Double, features: Vector) =>
|
||||
require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
|
||||
s" dataset with invalid label $label. Labels must be integers in range" +
|
||||
s" [0, $numClasses).")
|
||||
validateLabel(label, numClasses)
|
||||
LabeledPoint(label, features)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates that number of classes is greater than zero.
|
||||
*
|
||||
* @param numClasses Number of classes label can take.
|
||||
*/
|
||||
protected def validateNumClasses(numClasses: Int): Unit = {
|
||||
// `require` raises IllegalArgumentException when numClasses is zero or negative.
// NOTE(review): the message still names extractLabeledPoints, although this helper is also
// called directly (e.g. from DecisionTreeClassifier.train) - confirm the wording is intended.
require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
|
||||
s" $numClasses, but requires numClasses > 0.")
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates the label on the classifier is a valid integer in the range [0, numClasses).
|
||||
*
|
||||
* @param label The label to validate.
|
||||
* @param numClasses Number of classes label can take. Labels must be integers in the range
|
||||
* [0, numClasses).
|
||||
*/
|
||||
protected def validateLabel(label: Double, numClasses: Int): Unit = {
|
||||
// `label.toLong == label` holds only when the Double has no fractional part (toLong
// truncates), so fractional labels are rejected; NaN fails too because it never compares
// equal. The remaining two conditions bound the label to [0, numClasses).
require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
|
||||
s" dataset with invalid label $label. Labels must be integers in range" +
|
||||
s" [0, $numClasses).")
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the number of classes. This looks in column metadata first, and if that is missing,
|
||||
* then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
|
||||
|
|
|
@ -22,10 +22,12 @@ import org.json4s.{DefaultFormats, JObject}
|
|||
import org.json4s.JsonDSL._
|
||||
|
||||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
|
||||
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.ml.param.shared.HasWeightCol
|
||||
import org.apache.spark.ml.tree._
|
||||
import org.apache.spark.ml.tree.{DecisionTreeModel, Node, TreeClassifierParams}
|
||||
import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
|
||||
import org.apache.spark.ml.tree.impl.RandomForest
|
||||
import org.apache.spark.ml.util._
|
||||
|
@ -33,8 +35,9 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
|
|||
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
|
||||
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.Dataset
|
||||
|
||||
import org.apache.spark.sql.{Dataset, Row}
|
||||
import org.apache.spark.sql.functions.{col, lit}
|
||||
import org.apache.spark.sql.types.DoubleType
|
||||
|
||||
/**
|
||||
* Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
|
||||
|
@ -66,6 +69,9 @@ class DecisionTreeClassifier @Since("1.4.0") (
|
|||
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
|
||||
|
||||
/** @group setParam */
|
||||
@Since("3.0.0")
|
||||
def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
|
||||
|
||||
@Since("1.4.0")
|
||||
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
|
||||
|
||||
|
@ -97,6 +103,16 @@ class DecisionTreeClassifier @Since("1.4.0") (
|
|||
@Since("1.6.0")
|
||||
def setSeed(value: Long): this.type = set(seed, value)
|
||||
|
||||
/**
|
||||
* Sets the value of param [[weightCol]].
|
||||
* If this is not set or empty, we treat all instance weights as 1.0.
|
||||
* Default is not set, so all instances have weight one.
|
||||
*
|
||||
* @group setParam
|
||||
*/
|
||||
@Since("3.0.0")
|
||||
def setWeightCol(value: String): this.type = set(weightCol, value)
|
||||
|
||||
override protected def train(
|
||||
dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr =>
|
||||
instr.logPipelineStage(this)
|
||||
|
@ -104,22 +120,27 @@ class DecisionTreeClassifier @Since("1.4.0") (
|
|||
val categoricalFeatures: Map[Int, Int] =
|
||||
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
||||
val numClasses: Int = getNumClasses(dataset)
|
||||
instr.logNumClasses(numClasses)
|
||||
|
||||
if (isDefined(thresholds)) {
|
||||
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
|
||||
".train() called with non-matching numClasses and thresholds.length." +
|
||||
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
|
||||
}
|
||||
|
||||
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
|
||||
validateNumClasses(numClasses)
|
||||
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
|
||||
val instances =
|
||||
dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
|
||||
case Row(label: Double, weight: Double, features: Vector) =>
|
||||
validateLabel(label, numClasses)
|
||||
Instance(label, weight, features)
|
||||
}
|
||||
val strategy = getOldStrategy(categoricalFeatures, numClasses)
|
||||
|
||||
instr.logNumClasses(numClasses)
|
||||
instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol,
|
||||
probabilityCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
|
||||
cacheNodeIds, checkpointInterval, impurity, seed)
|
||||
|
||||
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
|
||||
val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
|
||||
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
|
||||
|
||||
trees.head.asInstanceOf[DecisionTreeClassificationModel]
|
||||
|
@ -128,13 +149,13 @@ class DecisionTreeClassifier @Since("1.4.0") (
|
|||
/** (private[ml]) Train a decision tree on an RDD */
|
||||
private[ml] def train(data: RDD[LabeledPoint],
|
||||
oldStrategy: OldStrategy): DecisionTreeClassificationModel = instrumented { instr =>
|
||||
val instances = data.map(_.toInstance)
|
||||
instr.logPipelineStage(this)
|
||||
instr.logDataset(data)
|
||||
instr.logDataset(instances)
|
||||
instr.logParams(this, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
|
||||
cacheNodeIds, checkpointInterval, impurity, seed)
|
||||
|
||||
val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
|
||||
seed = 0L, instr = Some(instr), parentUID = Some(uid))
|
||||
val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
|
||||
featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID = Some(uid))
|
||||
|
||||
trees.head.asInstanceOf[DecisionTreeClassificationModel]
|
||||
}
|
||||
|
@ -180,6 +201,7 @@ class DecisionTreeClassificationModel private[ml] (
|
|||
|
||||
/**
|
||||
* Construct a decision tree classification model.
|
||||
*
|
||||
* @param rootNode Root node of tree, with other nodes attached.
|
||||
*/
|
||||
private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
|
||||
|
|
|
@ -21,20 +21,21 @@ import org.json4s.{DefaultFormats, JObject}
|
|||
import org.json4s.JsonDSL._
|
||||
|
||||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.feature.Instance
|
||||
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.ml.tree._
|
||||
import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
|
||||
import org.apache.spark.ml.tree.impl.RandomForest
|
||||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
|
||||
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
|
||||
import org.apache.spark.ml.util.Instrumentation.instrumented
|
||||
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
|
||||
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||
import org.apache.spark.sql.functions._
|
||||
|
||||
import org.apache.spark.sql.functions.{col, udf}
|
||||
|
||||
/**
|
||||
* <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a> learning algorithm for
|
||||
|
@ -130,7 +131,7 @@ class RandomForestClassifier @Since("1.4.0") (
|
|||
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
|
||||
}
|
||||
|
||||
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
|
||||
val instances: RDD[Instance] = extractLabeledPoints(dataset, numClasses).map(_.toInstance)
|
||||
val strategy =
|
||||
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
|
||||
|
||||
|
@ -139,7 +140,7 @@ class RandomForestClassifier @Since("1.4.0") (
|
|||
minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)
|
||||
|
||||
val trees = RandomForest
|
||||
.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
|
||||
.run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
|
||||
.map(_.asInstanceOf[DecisionTreeClassificationModel])
|
||||
|
||||
val numFeatures = trees.head.numFeatures
|
||||
|
|
|
@ -37,4 +37,13 @@ case class LabeledPoint(@Since("2.0.0") label: Double, @Since("2.0.0") features:
|
|||
override def toString: String = {
|
||||
s"($label,$features)"
|
||||
}
|
||||
|
||||
/**
* Converts this labeled point into an [[Instance]] that carries the given sample weight.
*
* @param weight Weight to attach to the resulting instance.
* @return An `Instance` built from this point's label and features plus `weight`.
*/
private[spark] def toInstance(weight: Double): Instance = {
|
||||
Instance(label, weight, features)
|
||||
}
|
||||
|
||||
/**
* Converts this labeled point into an [[Instance]] with a weight of 1.0.
*
* @return An `Instance` built from this point's label and features, weighted 1.0.
*/
private[spark] def toInstance: Instance = {
|
||||
Instance(label, 1.0, features)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -23,9 +23,10 @@ import org.json4s.JsonDSL._
|
|||
|
||||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.ml.{PredictionModel, Predictor}
|
||||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.ml.param.shared.HasWeightCol
|
||||
import org.apache.spark.ml.tree._
|
||||
import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
|
||||
import org.apache.spark.ml.tree.impl.RandomForest
|
||||
|
@ -34,8 +35,9 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
|
|||
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
|
||||
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types.DoubleType
|
||||
|
||||
|
||||
/**
|
||||
|
@ -65,6 +67,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
|
|||
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
|
||||
|
||||
/** @group setParam */
|
||||
@Since("3.0.0")
|
||||
def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
|
||||
|
||||
@Since("1.4.0")
|
||||
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
|
||||
|
||||
|
@ -100,18 +105,33 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
|
|||
@Since("2.0.0")
|
||||
def setVarianceCol(value: String): this.type = set(varianceCol, value)
|
||||
|
||||
/**
|
||||
* Sets the value of param [[weightCol]].
|
||||
* If this is not set or empty, we treat all instance weights as 1.0.
|
||||
* Default is not set, so all instances have weight one.
|
||||
*
|
||||
* @group setParam
|
||||
*/
|
||||
@Since("3.0.0")
|
||||
def setWeightCol(value: String): this.type = set(weightCol, value)
|
||||
|
||||
override protected def train(
|
||||
dataset: Dataset[_]): DecisionTreeRegressionModel = instrumented { instr =>
|
||||
val categoricalFeatures: Map[Int, Int] =
|
||||
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
||||
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
|
||||
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
|
||||
val instances =
|
||||
dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
|
||||
case Row(label: Double, weight: Double, features: Vector) =>
|
||||
Instance(label, weight, features)
|
||||
}
|
||||
val strategy = getOldStrategy(categoricalFeatures)
|
||||
|
||||
instr.logPipelineStage(this)
|
||||
instr.logDataset(oldDataset)
|
||||
instr.logDataset(instances)
|
||||
instr.logParams(this, params: _*)
|
||||
|
||||
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
|
||||
val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
|
||||
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
|
||||
|
||||
trees.head.asInstanceOf[DecisionTreeRegressionModel]
|
||||
|
@ -126,8 +146,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
|
|||
instr.logDataset(data)
|
||||
instr.logParams(this, params: _*)
|
||||
|
||||
val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy,
|
||||
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
|
||||
val instances = data.map(_.toInstance)
|
||||
val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
|
||||
featureSubsetStrategy, seed = $(seed), instr = Some(instr), parentUID = Some(uid))
|
||||
|
||||
trees.head.asInstanceOf[DecisionTreeRegressionModel]
|
||||
}
|
||||
|
@ -155,6 +176,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor
|
|||
* <a href="http://en.wikipedia.org/wiki/Decision_tree_learning">
|
||||
* Decision tree (Wikipedia)</a> model for regression.
|
||||
* It supports both continuous and categorical features.
|
||||
*
|
||||
* @param rootNode Root of the decision tree
|
||||
*/
|
||||
@Since("1.4.0")
|
||||
|
@ -173,6 +195,7 @@ class DecisionTreeRegressionModel private[ml] (
|
|||
|
||||
/**
|
||||
* Construct a decision tree regression model.
|
||||
*
|
||||
* @param rootNode Root node of tree, with other nodes attached.
|
||||
*/
|
||||
private[ml] def this(rootNode: Node, numFeatures: Int) =
|
||||
|
|
|
@ -22,7 +22,6 @@ import org.json4s.JsonDSL._
|
|||
|
||||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.ml.{PredictionModel, Predictor}
|
||||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.ml.tree._
|
||||
|
@ -32,10 +31,8 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
|
|||
import org.apache.spark.ml.util.Instrumentation.instrumented
|
||||
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
|
||||
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||
import org.apache.spark.sql.functions._
|
||||
|
||||
import org.apache.spark.sql.functions.{col, udf}
|
||||
|
||||
/**
|
||||
* <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a>
|
||||
|
@ -119,18 +116,19 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
|
|||
dataset: Dataset[_]): RandomForestRegressionModel = instrumented { instr =>
|
||||
val categoricalFeatures: Map[Int, Int] =
|
||||
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
||||
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
|
||||
|
||||
val instances = extractLabeledPoints(dataset).map(_.toInstance)
|
||||
val strategy =
|
||||
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
|
||||
|
||||
instr.logPipelineStage(this)
|
||||
instr.logDataset(dataset)
|
||||
instr.logDataset(instances)
|
||||
instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, numTrees,
|
||||
featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
|
||||
minInstancesPerNode, seed, subsamplingRate, cacheNodeIds, checkpointInterval)
|
||||
|
||||
val trees = RandomForest
|
||||
.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
|
||||
.run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
|
||||
.map(_.asInstanceOf[DecisionTreeRegressionModel])
|
||||
|
||||
val numFeatures = trees.head.numFeatures
|
||||
|
|
|
@ -33,13 +33,13 @@ import org.apache.spark.util.random.XORShiftRandom
|
|||
* this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively.
|
||||
*
|
||||
* @param datum Data instance
|
||||
* @param subsampleWeights Weight of this instance in each subsampled dataset.
|
||||
*
|
||||
* TODO: This does not currently support (Double) weighted instances. Once MLlib has weighted
|
||||
* dataset support, update. (We store subsampleWeights as Double for this future extension.)
|
||||
* @param subsampleCounts Number of samples of this instance in each subsampled dataset.
|
||||
* @param sampleWeight The weight of this instance.
|
||||
*/
|
||||
private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
|
||||
extends Serializable
|
||||
private[spark] class BaggedPoint[Datum](
|
||||
val datum: Datum,
|
||||
val subsampleCounts: Array[Int],
|
||||
val sampleWeight: Double = 1.0) extends Serializable
|
||||
|
||||
private[spark] object BaggedPoint {
|
||||
|
||||
|
@ -52,6 +52,7 @@ private[spark] object BaggedPoint {
|
|||
* @param subsamplingRate Fraction of the training data used for learning decision tree.
|
||||
* @param numSubsamples Number of subsamples of this RDD to take.
|
||||
* @param withReplacement Sampling with/without replacement.
|
||||
* @param extractSampleWeight A function to get the sample weight of each datum.
|
||||
* @param seed Random seed.
|
||||
* @return BaggedPoint dataset representation.
|
||||
*/
|
||||
|
@ -60,12 +61,14 @@ private[spark] object BaggedPoint {
|
|||
subsamplingRate: Double,
|
||||
numSubsamples: Int,
|
||||
withReplacement: Boolean,
|
||||
extractSampleWeight: (Datum => Double) = (_: Datum) => 1.0,
|
||||
seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
|
||||
// TODO: implement weighted bootstrapping
|
||||
if (withReplacement) {
|
||||
convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
|
||||
} else {
|
||||
if (numSubsamples == 1 && subsamplingRate == 1.0) {
|
||||
convertToBaggedRDDWithoutSampling(input)
|
||||
convertToBaggedRDDWithoutSampling(input, extractSampleWeight)
|
||||
} else {
|
||||
convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
|
||||
}
|
||||
|
@ -82,16 +85,15 @@ private[spark] object BaggedPoint {
|
|||
val rng = new XORShiftRandom
|
||||
rng.setSeed(seed + partitionIndex + 1)
|
||||
instances.map { instance =>
|
||||
val subsampleWeights = new Array[Double](numSubsamples)
|
||||
val subsampleCounts = new Array[Int](numSubsamples)
|
||||
var subsampleIndex = 0
|
||||
while (subsampleIndex < numSubsamples) {
|
||||
val x = rng.nextDouble()
|
||||
subsampleWeights(subsampleIndex) = {
|
||||
if (x < subsamplingRate) 1.0 else 0.0
|
||||
if (rng.nextDouble() < subsamplingRate) {
|
||||
subsampleCounts(subsampleIndex) = 1
|
||||
}
|
||||
subsampleIndex += 1
|
||||
}
|
||||
new BaggedPoint(instance, subsampleWeights)
|
||||
new BaggedPoint(instance, subsampleCounts)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -106,20 +108,20 @@ private[spark] object BaggedPoint {
|
|||
val poisson = new PoissonDistribution(subsample)
|
||||
poisson.reseedRandomGenerator(seed + partitionIndex + 1)
|
||||
instances.map { instance =>
|
||||
val subsampleWeights = new Array[Double](numSubsamples)
|
||||
val subsampleCounts = new Array[Int](numSubsamples)
|
||||
var subsampleIndex = 0
|
||||
while (subsampleIndex < numSubsamples) {
|
||||
subsampleWeights(subsampleIndex) = poisson.sample()
|
||||
subsampleCounts(subsampleIndex) = poisson.sample()
|
||||
subsampleIndex += 1
|
||||
}
|
||||
new BaggedPoint(instance, subsampleWeights)
|
||||
new BaggedPoint(instance, subsampleCounts)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def convertToBaggedRDDWithoutSampling[Datum] (
|
||||
input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
|
||||
input.map(datum => new BaggedPoint(datum, Array(1.0)))
|
||||
input: RDD[Datum],
|
||||
extractSampleWeight: (Datum => Double)): RDD[BaggedPoint[Datum]] = {
|
||||
input.map(datum => new BaggedPoint(datum, Array(1), extractSampleWeight(datum)))
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -104,16 +104,21 @@ private[spark] class DTStatsAggregator(
|
|||
/**
|
||||
* Update the stats for a given (feature, bin) for ordered features, using the given label.
|
||||
*/
|
||||
def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
|
||||
def update(
|
||||
featureIndex: Int,
|
||||
binIndex: Int,
|
||||
label: Double,
|
||||
numSamples: Int,
|
||||
sampleWeight: Double): Unit = {
|
||||
val i = featureOffsets(featureIndex) + binIndex * statsSize
|
||||
impurityAggregator.update(allStats, i, label, instanceWeight)
|
||||
impurityAggregator.update(allStats, i, label, numSamples, sampleWeight)
|
||||
}
|
||||
|
||||
/**
|
||||
* Update the parent node stats using the given label.
|
||||
*/
|
||||
def updateParent(label: Double, instanceWeight: Double): Unit = {
|
||||
impurityAggregator.update(parentStats, 0, label, instanceWeight)
|
||||
def updateParent(label: Double, numSamples: Int, sampleWeight: Double): Unit = {
|
||||
impurityAggregator.update(parentStats, 0, label, numSamples, sampleWeight)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -127,9 +132,10 @@ private[spark] class DTStatsAggregator(
|
|||
featureOffset: Int,
|
||||
binIndex: Int,
|
||||
label: Double,
|
||||
instanceWeight: Double): Unit = {
|
||||
numSamples: Int,
|
||||
sampleWeight: Double): Unit = {
|
||||
impurityAggregator.update(allStats, featureOffset + binIndex * statsSize,
|
||||
label, instanceWeight)
|
||||
label, numSamples, sampleWeight)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -21,7 +21,7 @@ import scala.collection.mutable
|
|||
import scala.util.Try
|
||||
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.feature.Instance
|
||||
import org.apache.spark.ml.tree.TreeEnsembleParams
|
||||
import org.apache.spark.mllib.tree.configuration.Algo._
|
||||
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
|
||||
|
@ -32,16 +32,20 @@ import org.apache.spark.rdd.RDD
|
|||
/**
|
||||
* Learning and dataset metadata for DecisionTree.
|
||||
*
|
||||
* @param weightedNumExamples Weighted count of samples in the tree.
|
||||
* @param numClasses For classification: labels can take values {0, ..., numClasses - 1}.
|
||||
* For regression: fixed at 0 (no meaning).
|
||||
* @param maxBins Maximum number of bins, for all features.
|
||||
* @param featureArity Map: categorical feature index to arity.
|
||||
* I.e., the feature takes values in {0, ..., arity - 1}.
|
||||
* @param numBins Number of bins for each feature.
|
||||
* @param minWeightFractionPerNode The minimum fraction of the total sample weight that must be
|
||||
* present in a leaf node in order to be considered a valid split.
|
||||
*/
|
||||
private[spark] class DecisionTreeMetadata(
|
||||
val numFeatures: Int,
|
||||
val numExamples: Long,
|
||||
val weightedNumExamples: Double,
|
||||
val numClasses: Int,
|
||||
val maxBins: Int,
|
||||
val featureArity: Map[Int, Int],
|
||||
|
@ -51,6 +55,7 @@ private[spark] class DecisionTreeMetadata(
|
|||
val quantileStrategy: QuantileStrategy,
|
||||
val maxDepth: Int,
|
||||
val minInstancesPerNode: Int,
|
||||
val minWeightFractionPerNode: Double,
|
||||
val minInfoGain: Double,
|
||||
val numTrees: Int,
|
||||
val numFeaturesPerNode: Int) extends Serializable {
|
||||
|
@ -67,6 +72,8 @@ private[spark] class DecisionTreeMetadata(
|
|||
|
||||
def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)
|
||||
|
||||
/**
* Absolute weight threshold obtained by scaling the total weighted example count by
* `minWeightFractionPerNode`.
* NOTE(review): presumably the minimum total weight a node must hold to be a valid split
* result; the code that consumes this value is outside this view - confirm.
*/
def minWeightPerNode: Double = minWeightFractionPerNode * weightedNumExamples
|
||||
|
||||
/**
|
||||
* Number of splits for the given feature.
|
||||
* For unordered features, there is 1 bin per split.
|
||||
|
@ -104,7 +111,7 @@ private[spark] object DecisionTreeMetadata extends Logging {
|
|||
* as well as the number of splits and bins for each feature.
|
||||
*/
|
||||
def buildMetadata(
|
||||
input: RDD[LabeledPoint],
|
||||
input: RDD[Instance],
|
||||
strategy: Strategy,
|
||||
numTrees: Int,
|
||||
featureSubsetStrategy: String): DecisionTreeMetadata = {
|
||||
|
@ -115,7 +122,11 @@ private[spark] object DecisionTreeMetadata extends Logging {
|
|||
}
|
||||
require(numFeatures > 0, s"DecisionTree requires number of features > 0, " +
|
||||
s"but was given an empty features vector")
|
||||
val numExamples = input.count()
|
||||
val (numExamples, weightSum) = input.aggregate((0L, 0.0))(
|
||||
seqOp = (cw, instance) => (cw._1 + 1L, cw._2 + instance.weight),
|
||||
combOp = (cw1, cw2) => (cw1._1 + cw2._1, cw1._2 + cw2._2)
|
||||
)
|
||||
|
||||
val numClasses = strategy.algo match {
|
||||
case Classification => strategy.numClasses
|
||||
case Regression => 0
|
||||
|
@ -206,17 +217,18 @@ private[spark] object DecisionTreeMetadata extends Logging {
|
|||
}
|
||||
}
|
||||
|
||||
new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
|
||||
strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
|
||||
new DecisionTreeMetadata(numFeatures, numExamples, weightSum, numClasses,
|
||||
numBins.max, strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
|
||||
strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
|
||||
strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode)
|
||||
strategy.minInstancesPerNode, strategy.minWeightFractionPerNode, strategy.minInfoGain,
|
||||
numTrees, numFeaturesPerNode)
|
||||
}
|
||||
|
||||
/**
|
||||
* Version of [[DecisionTreeMetadata#buildMetadata]] for DecisionTree.
|
||||
*/
|
||||
def buildMetadata(
|
||||
input: RDD[LabeledPoint],
|
||||
input: RDD[Instance],
|
||||
strategy: Strategy): DecisionTreeMetadata = {
|
||||
buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all")
|
||||
}
|
||||
|
|
|
@ -24,10 +24,12 @@ import scala.util.Random
|
|||
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
|
||||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.feature.Instance
|
||||
import org.apache.spark.ml.impl.Utils
|
||||
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
|
||||
import org.apache.spark.ml.tree._
|
||||
import org.apache.spark.ml.util.Instrumentation
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
|
||||
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
|
||||
import org.apache.spark.mllib.tree.model.ImpurityStats
|
||||
|
@ -90,6 +92,24 @@ private[spark] object RandomForest extends Logging with Serializable {
|
|||
strategy: OldStrategy,
|
||||
numTrees: Int,
|
||||
featureSubsetStrategy: String,
|
||||
seed: Long): Array[DecisionTreeModel] = {
|
||||
val instances = input.map { case LabeledPoint(label, features) =>
|
||||
Instance(label, 1.0, features.asML)
|
||||
}
|
||||
run(instances, strategy, numTrees, featureSubsetStrategy, seed, None)
|
||||
}
|
||||
|
||||
/**
|
||||
* Train a random forest.
|
||||
*
|
||||
* @param input Training data: RDD of `Instance`
|
||||
* @return an unweighted set of trees
|
||||
*/
|
||||
def run(
|
||||
input: RDD[Instance],
|
||||
strategy: OldStrategy,
|
||||
numTrees: Int,
|
||||
featureSubsetStrategy: String,
|
||||
seed: Long,
|
||||
instr: Option[Instrumentation],
|
||||
prune: Boolean = true, // exposed for testing only, real trees are always pruned
|
||||
|
@ -101,9 +121,10 @@ private[spark] object RandomForest extends Logging with Serializable {
|
|||
|
||||
timer.start("init")
|
||||
|
||||
val retaggedInput = input.retag(classOf[LabeledPoint])
|
||||
val retaggedInput = input.retag(classOf[Instance])
|
||||
val metadata =
|
||||
DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
|
||||
|
||||
instr match {
|
||||
case Some(instrumentation) =>
|
||||
instrumentation.logNumFeatures(metadata.numFeatures)
|
||||
|
@ -132,7 +153,8 @@ private[spark] object RandomForest extends Logging with Serializable {
|
|||
val withReplacement = numTrees > 1
|
||||
|
||||
val baggedInput = BaggedPoint
|
||||
.convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, seed)
|
||||
.convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement,
|
||||
(tp: TreePoint) => tp.weight, seed = seed)
|
||||
.persist(StorageLevel.MEMORY_AND_DISK)
|
||||
|
||||
// depth of the decision tree
|
||||
|
@ -254,19 +276,21 @@ private[spark] object RandomForest extends Logging with Serializable {
|
|||
* For unordered features, bins correspond to subsets of categories; either the left or right bin
|
||||
* for each subset is updated.
|
||||
*
|
||||
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
|
||||
* each (feature, bin).
|
||||
* @param treePoint Data point being aggregated.
|
||||
* @param splits possible splits indexed (numFeatures)(numSplits)
|
||||
* @param unorderedFeatures Set of indices of unordered features.
|
||||
* @param instanceWeight Weight (importance) of instance in dataset.
|
||||
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
|
||||
* each (feature, bin).
|
||||
* @param treePoint Data point being aggregated.
|
||||
* @param splits Possible splits indexed (numFeatures)(numSplits)
|
||||
* @param unorderedFeatures Set of indices of unordered features.
|
||||
* @param numSamples Number of times this instance occurs in the sample.
|
||||
* @param sampleWeight Weight (importance) of instance in dataset.
|
||||
*/
|
||||
private def mixedBinSeqOp(
|
||||
agg: DTStatsAggregator,
|
||||
treePoint: TreePoint,
|
||||
splits: Array[Array[Split]],
|
||||
unorderedFeatures: Set[Int],
|
||||
instanceWeight: Double,
|
||||
numSamples: Int,
|
||||
sampleWeight: Double,
|
||||
featuresForNode: Option[Array[Int]]): Unit = {
|
||||
val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
|
||||
// Use subsampled features
|
||||
|
@ -293,14 +317,15 @@ private[spark] object RandomForest extends Logging with Serializable {
|
|||
var splitIndex = 0
|
||||
while (splitIndex < numSplits) {
|
||||
if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
|
||||
agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
|
||||
agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, numSamples,
|
||||
sampleWeight)
|
||||
}
|
||||
splitIndex += 1
|
||||
}
|
||||
} else {
|
||||
// Ordered feature
|
||||
val binIndex = treePoint.binnedFeatures(featureIndex)
|
||||
agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
|
||||
agg.update(featureIndexIdx, binIndex, treePoint.label, numSamples, sampleWeight)
|
||||
}
|
||||
featureIndexIdx += 1
|
||||
}
|
||||
|
@ -314,12 +339,14 @@ private[spark] object RandomForest extends Logging with Serializable {
|
|||
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
|
||||
* each (feature, bin).
|
||||
* @param treePoint Data point being aggregated.
|
||||
* @param instanceWeight Weight (importance) of instance in dataset.
|
||||
* @param numSamples Number of times this instance occurs in the sample.
|
||||
* @param sampleWeight Weight (importance) of instance in dataset.
|
||||
*/
|
||||
private def orderedBinSeqOp(
|
||||
agg: DTStatsAggregator,
|
||||
treePoint: TreePoint,
|
||||
instanceWeight: Double,
|
||||
numSamples: Int,
|
||||
sampleWeight: Double,
|
||||
featuresForNode: Option[Array[Int]]): Unit = {
|
||||
val label = treePoint.label
|
||||
|
||||
|
@ -329,7 +356,7 @@ private[spark] object RandomForest extends Logging with Serializable {
|
|||
var featureIndexIdx = 0
|
||||
while (featureIndexIdx < featuresForNode.get.length) {
|
||||
val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
|
||||
agg.update(featureIndexIdx, binIndex, label, instanceWeight)
|
||||
agg.update(featureIndexIdx, binIndex, label, numSamples, sampleWeight)
|
||||
featureIndexIdx += 1
|
||||
}
|
||||
} else {
|
||||
|
@ -338,7 +365,7 @@ private[spark] object RandomForest extends Logging with Serializable {
|
|||
var featureIndex = 0
|
||||
while (featureIndex < numFeatures) {
|
||||
val binIndex = treePoint.binnedFeatures(featureIndex)
|
||||
agg.update(featureIndex, binIndex, label, instanceWeight)
|
||||
agg.update(featureIndex, binIndex, label, numSamples, sampleWeight)
|
||||
featureIndex += 1
|
||||
}
|
||||
}
|
||||
|
@ -427,14 +454,16 @@ private[spark] object RandomForest extends Logging with Serializable {
|
|||
if (nodeInfo != null) {
|
||||
val aggNodeIndex = nodeInfo.nodeIndexInGroup
|
||||
val featuresForNode = nodeInfo.featureSubset
|
||||
val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
|
||||
val numSamples = baggedPoint.subsampleCounts(treeIndex)
|
||||
val sampleWeight = baggedPoint.sampleWeight
|
||||
if (metadata.unorderedFeatures.isEmpty) {
|
||||
orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
|
||||
orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, numSamples, sampleWeight,
|
||||
featuresForNode)
|
||||
} else {
|
||||
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
|
||||
metadata.unorderedFeatures, instanceWeight, featuresForNode)
|
||||
metadata.unorderedFeatures, numSamples, sampleWeight, featuresForNode)
|
||||
}
|
||||
agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
|
||||
agg(aggNodeIndex).updateParent(baggedPoint.datum.label, numSamples, sampleWeight)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -594,8 +623,8 @@ private[spark] object RandomForest extends Logging with Serializable {
|
|||
if (!isLeaf) {
|
||||
node.split = Some(split)
|
||||
val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
|
||||
val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
|
||||
val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
|
||||
val leftChildIsLeaf = childIsLeaf || (math.abs(stats.leftImpurity) < Utils.EPSILON)
|
||||
val rightChildIsLeaf = childIsLeaf || (math.abs(stats.rightImpurity) < Utils.EPSILON)
|
||||
node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
|
||||
leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
|
||||
node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
|
||||
|
@ -659,15 +688,20 @@ private[spark] object RandomForest extends Logging with Serializable {
|
|||
stats.impurity
|
||||
}
|
||||
|
||||
val leftRawCount = leftImpurityCalculator.rawCount
|
||||
val rightRawCount = rightImpurityCalculator.rawCount
|
||||
val leftCount = leftImpurityCalculator.count
|
||||
val rightCount = rightImpurityCalculator.count
|
||||
|
||||
val totalCount = leftCount + rightCount
|
||||
|
||||
// If left child or right child doesn't satisfy minimum instances per node,
|
||||
// then this split is invalid, return invalid information gain stats.
|
||||
if ((leftCount < metadata.minInstancesPerNode) ||
|
||||
(rightCount < metadata.minInstancesPerNode)) {
|
||||
val violatesMinInstancesPerNode = (leftRawCount < metadata.minInstancesPerNode) ||
|
||||
(rightRawCount < metadata.minInstancesPerNode)
|
||||
val violatesMinWeightPerNode = (leftCount < metadata.minWeightPerNode) ||
|
||||
(rightCount < metadata.minWeightPerNode)
|
||||
// If left child or right child doesn't satisfy minimum weight per node or minimum
|
||||
// instances per node, then this split is invalid, return invalid information gain stats.
|
||||
if (violatesMinInstancesPerNode || violatesMinWeightPerNode) {
|
||||
return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
|
||||
}
|
||||
|
||||
|
@ -734,7 +768,8 @@ private[spark] object RandomForest extends Logging with Serializable {
|
|||
// Find best split.
|
||||
val (bestFeatureSplitIndex, bestFeatureGainStats) =
|
||||
Range(0, numSplits).map { case splitIdx =>
|
||||
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
|
||||
val leftChildStats =
|
||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
|
||||
val rightChildStats =
|
||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
|
||||
rightChildStats.subtract(leftChildStats)
|
||||
|
@ -876,14 +911,14 @@ private[spark] object RandomForest extends Logging with Serializable {
|
|||
* and for multiclass classification with a high-arity feature,
|
||||
* there is one bin per category.
|
||||
*
|
||||
* @param input Training data: RDD of [[LabeledPoint]]
|
||||
* @param input Training data: RDD of [[Instance]]
|
||||
* @param metadata Learning and dataset metadata
|
||||
* @param seed random seed
|
||||
* @return Splits, an Array of [[Split]]
|
||||
* of size (numFeatures, numSplits)
|
||||
*/
|
||||
protected[tree] def findSplits(
|
||||
input: RDD[LabeledPoint],
|
||||
input: RDD[Instance],
|
||||
metadata: DecisionTreeMetadata,
|
||||
seed: Long): Array[Array[Split]] = {
|
||||
|
||||
|
@ -898,14 +933,14 @@ private[spark] object RandomForest extends Logging with Serializable {
|
|||
logDebug("fraction of data used for calculating quantiles = " + fraction)
|
||||
input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
|
||||
} else {
|
||||
input.sparkContext.emptyRDD[LabeledPoint]
|
||||
input.sparkContext.emptyRDD[Instance]
|
||||
}
|
||||
|
||||
findSplitsBySorting(sampledInput, metadata, continuousFeatures)
|
||||
}
|
||||
|
||||
private def findSplitsBySorting(
|
||||
input: RDD[LabeledPoint],
|
||||
input: RDD[Instance],
|
||||
metadata: DecisionTreeMetadata,
|
||||
continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = {
|
||||
|
||||
|
@ -917,7 +952,8 @@ private[spark] object RandomForest extends Logging with Serializable {
|
|||
|
||||
input
|
||||
.flatMap { point =>
|
||||
continuousFeatures.map(idx => (idx, point.features(idx))).filter(_._2 != 0.0)
|
||||
continuousFeatures.map(idx => (idx, (point.weight, point.features(idx))))
|
||||
.filter(_._2._2 != 0.0)
|
||||
}.groupByKey(numPartitions)
|
||||
.map { case (idx, samples) =>
|
||||
val thresholds = findSplitsForContinuousFeature(samples, metadata, idx)
|
||||
|
@ -982,7 +1018,7 @@ private[spark] object RandomForest extends Logging with Serializable {
|
|||
* could be different from the specified `numSplits`.
|
||||
* The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
|
||||
*
|
||||
* @param featureSamples feature values of each sample
|
||||
* @param featureSamples feature values and sample weights of each sample
|
||||
* @param metadata decision tree metadata
|
||||
* NOTE: `metadata.numBins` will be changed accordingly
|
||||
* if there are not enough splits to be found
|
||||
|
@ -990,7 +1026,7 @@ private[spark] object RandomForest extends Logging with Serializable {
|
|||
* @return array of split thresholds
|
||||
*/
|
||||
private[tree] def findSplitsForContinuousFeature(
|
||||
featureSamples: Iterable[Double],
|
||||
featureSamples: Iterable[(Double, Double)],
|
||||
metadata: DecisionTreeMetadata,
|
||||
featureIndex: Int): Array[Double] = {
|
||||
require(metadata.isContinuous(featureIndex),
|
||||
|
@ -1002,19 +1038,24 @@ private[spark] object RandomForest extends Logging with Serializable {
|
|||
val numSplits = metadata.numSplits(featureIndex)
|
||||
|
||||
// get the (weighted) count for each distinct value except zero value
|
||||
val partNumSamples = featureSamples.size
|
||||
val partValueCountMap = scala.collection.mutable.Map[Double, Int]()
|
||||
featureSamples.foreach { x =>
|
||||
partValueCountMap(x) = partValueCountMap.getOrElse(x, 0) + 1
|
||||
val partValueCountMap = mutable.Map[Double, Double]()
|
||||
var partNumSamples = 0.0
|
||||
var unweightedNumSamples = 0.0
|
||||
featureSamples.foreach { case (sampleWeight, feature) =>
|
||||
partValueCountMap(feature) = partValueCountMap.getOrElse(feature, 0.0) + sampleWeight;
|
||||
partNumSamples += sampleWeight;
|
||||
unweightedNumSamples += 1.0
|
||||
}
|
||||
|
||||
// Calculate the expected number of samples for finding splits
|
||||
val numSamples = (samplesFractionForFindSplits(metadata) * metadata.numExamples).toInt
|
||||
val weightedNumSamples = samplesFractionForFindSplits(metadata) *
|
||||
metadata.weightedNumExamples
|
||||
// add expected zero value count and get complete statistics
|
||||
val valueCountMap: Map[Double, Int] = if (numSamples - partNumSamples > 0) {
|
||||
partValueCountMap.toMap + (0.0 -> (numSamples - partNumSamples))
|
||||
val tolerance = Utils.EPSILON * unweightedNumSamples * unweightedNumSamples
|
||||
val valueCountMap = if (weightedNumSamples - partNumSamples > tolerance) {
|
||||
partValueCountMap + (0.0 -> (weightedNumSamples - partNumSamples))
|
||||
} else {
|
||||
partValueCountMap.toMap
|
||||
partValueCountMap
|
||||
}
|
||||
|
||||
// sort distinct values
|
||||
|
@ -1031,7 +1072,7 @@ private[spark] object RandomForest extends Logging with Serializable {
|
|||
.toArray
|
||||
} else {
|
||||
// stride between splits
|
||||
val stride: Double = numSamples.toDouble / (numSplits + 1)
|
||||
val stride: Double = weightedNumSamples / (numSplits + 1)
|
||||
logDebug("stride = " + stride)
|
||||
|
||||
// iterate `valueCount` to find splits
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
package org.apache.spark.ml.tree.impl
|
||||
|
||||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.feature.Instance
|
||||
import org.apache.spark.ml.tree.{ContinuousSplit, Split}
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
|
@ -36,10 +36,12 @@ import org.apache.spark.rdd.RDD
|
|||
* @param label Label of the input data point
|
||||
* @param binnedFeatures Binned feature values.
|
||||
* Same length as the input point's feature vector, but values are bin indices.
|
||||
* @param weight Sample weight for this TreePoint.
|
||||
*/
|
||||
private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
|
||||
extends Serializable {
|
||||
}
|
||||
private[spark] class TreePoint(
|
||||
val label: Double,
|
||||
val binnedFeatures: Array[Int],
|
||||
val weight: Double) extends Serializable
|
||||
|
||||
private[spark] object TreePoint {
|
||||
|
||||
|
@ -52,7 +54,7 @@ private[spark] object TreePoint {
|
|||
* @return TreePoint dataset representation
|
||||
*/
|
||||
def convertToTreeRDD(
|
||||
input: RDD[LabeledPoint],
|
||||
input: RDD[Instance],
|
||||
splits: Array[Array[Split]],
|
||||
metadata: DecisionTreeMetadata): RDD[TreePoint] = {
|
||||
// Construct arrays for featureArity for efficiency in the inner loop.
|
||||
|
@ -82,18 +84,18 @@ private[spark] object TreePoint {
|
|||
* for categorical features.
|
||||
*/
|
||||
private def labeledPointToTreePoint(
|
||||
labeledPoint: LabeledPoint,
|
||||
instance: Instance,
|
||||
thresholds: Array[Array[Double]],
|
||||
featureArity: Array[Int]): TreePoint = {
|
||||
val numFeatures = labeledPoint.features.size
|
||||
val numFeatures = instance.features.size
|
||||
val arr = new Array[Int](numFeatures)
|
||||
var featureIndex = 0
|
||||
while (featureIndex < numFeatures) {
|
||||
arr(featureIndex) =
|
||||
findBin(featureIndex, labeledPoint, featureArity(featureIndex), thresholds(featureIndex))
|
||||
findBin(featureIndex, instance, featureArity(featureIndex), thresholds(featureIndex))
|
||||
featureIndex += 1
|
||||
}
|
||||
new TreePoint(labeledPoint.label, arr)
|
||||
new TreePoint(instance.label, arr, instance.weight)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -106,10 +108,10 @@ private[spark] object TreePoint {
|
|||
*/
|
||||
private def findBin(
|
||||
featureIndex: Int,
|
||||
labeledPoint: LabeledPoint,
|
||||
instance: Instance,
|
||||
featureArity: Int,
|
||||
thresholds: Array[Double]): Int = {
|
||||
val featureValue = labeledPoint.features(featureIndex)
|
||||
val featureValue = instance.features(featureIndex)
|
||||
|
||||
if (featureArity == 0) {
|
||||
val idx = java.util.Arrays.binarySearch(thresholds, featureValue)
|
||||
|
@ -125,7 +127,7 @@ private[spark] object TreePoint {
|
|||
s"DecisionTree given invalid data:" +
|
||||
s" Feature $featureIndex is categorical with values in {0,...,${featureArity - 1}," +
|
||||
s" but a data point gives it value $featureValue.\n" +
|
||||
" Bad data point: " + labeledPoint.toString)
|
||||
s" Bad data point: $instance")
|
||||
}
|
||||
featureValue.toInt
|
||||
}
|
||||
|
|
|
@ -282,6 +282,7 @@ private[ml] object DecisionTreeModelReadWrite {
|
|||
*
|
||||
* @param id Index used for tree reconstruction. Indices follow a pre-order traversal.
|
||||
* @param impurityStats Stats array. Impurity type is stored in metadata.
|
||||
* @param rawCount The unweighted number of samples falling in this node.
|
||||
* @param gain Gain, or arbitrary value if leaf node.
|
||||
* @param leftChild Left child index, or arbitrary value if leaf node.
|
||||
* @param rightChild Right child index, or arbitrary value if leaf node.
|
||||
|
@ -292,6 +293,7 @@ private[ml] object DecisionTreeModelReadWrite {
|
|||
prediction: Double,
|
||||
impurity: Double,
|
||||
impurityStats: Array[Double],
|
||||
rawCount: Long,
|
||||
gain: Double,
|
||||
leftChild: Int,
|
||||
rightChild: Int,
|
||||
|
@ -311,11 +313,12 @@ private[ml] object DecisionTreeModelReadWrite {
|
|||
val (leftNodeData, leftIdx) = build(n.leftChild, id + 1)
|
||||
val (rightNodeData, rightIdx) = build(n.rightChild, leftIdx + 1)
|
||||
val thisNodeData = NodeData(id, n.prediction, n.impurity, n.impurityStats.stats,
|
||||
n.gain, leftNodeData.head.id, rightNodeData.head.id, SplitData(n.split))
|
||||
n.impurityStats.rawCount, n.gain, leftNodeData.head.id, rightNodeData.head.id,
|
||||
SplitData(n.split))
|
||||
(thisNodeData +: (leftNodeData ++ rightNodeData), rightIdx)
|
||||
case _: LeafNode =>
|
||||
(Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats,
|
||||
-1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))),
|
||||
node.impurityStats.rawCount, -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))),
|
||||
id)
|
||||
}
|
||||
}
|
||||
|
@ -360,7 +363,8 @@ private[ml] object DecisionTreeModelReadWrite {
|
|||
// traversal, this guarantees that child nodes will be built before parent nodes.
|
||||
val finalNodes = new Array[Node](nodes.length)
|
||||
nodes.reverseIterator.foreach { case n: NodeData =>
|
||||
val impurityStats = ImpurityCalculator.getCalculator(impurityType, n.impurityStats)
|
||||
val impurityStats =
|
||||
ImpurityCalculator.getCalculator(impurityType, n.impurityStats, n.rawCount)
|
||||
val node = if (n.leftChild != -1) {
|
||||
val leftChild = finalNodes(n.leftChild)
|
||||
val rightChild = finalNodes(n.rightChild)
|
||||
|
|
|
@ -37,7 +37,7 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
|
|||
* Note: Marked as private and DeveloperApi since this may be made public in the future.
|
||||
*/
|
||||
private[ml] trait DecisionTreeParams extends PredictorParams
|
||||
with HasCheckpointInterval with HasSeed {
|
||||
with HasCheckpointInterval with HasSeed with HasWeightCol {
|
||||
|
||||
/**
|
||||
* Maximum depth of the tree (>= 0).
|
||||
|
@ -74,6 +74,21 @@ private[ml] trait DecisionTreeParams extends PredictorParams
|
|||
" child to have fewer than minInstancesPerNode, the split will be discarded as invalid." +
|
||||
" Should be >= 1.", ParamValidators.gtEq(1))
|
||||
|
||||
/**
|
||||
* Minimum fraction of the weighted sample count that each child must have after split.
|
||||
* If a split causes the fraction of the total weight in the left or right child to be less than
|
||||
* minWeightFractionPerNode, the split will be discarded as invalid.
|
||||
* Should be in the interval [0.0, 0.5).
|
||||
* (default = 0.0)
|
||||
* @group param
|
||||
*/
|
||||
final val minWeightFractionPerNode: DoubleParam = new DoubleParam(this,
|
||||
"minWeightFractionPerNode", "Minimum fraction of the weighted sample count that each child " +
|
||||
"must have after split. If a split causes the fraction of the total weight in the left or " +
|
||||
"right child to be less than minWeightFractionPerNode, the split will be discarded as " +
|
||||
"invalid. Should be in interval [0.0, 0.5)",
|
||||
ParamValidators.inRange(0.0, 0.5, lowerInclusive = true, upperInclusive = false))
|
||||
|
||||
/**
|
||||
* Minimum information gain for a split to be considered at a tree node.
|
||||
* Should be >= 0.0.
|
||||
|
@ -107,8 +122,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams
|
|||
" algorithm will cache node IDs for each instance. Caching can speed up training of deeper" +
|
||||
" trees.")
|
||||
|
||||
setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0,
|
||||
maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
|
||||
setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1,
|
||||
minWeightFractionPerNode -> 0.0, minInfoGain -> 0.0, maxMemoryInMB -> 256,
|
||||
cacheNodeIds -> false, checkpointInterval -> 10)
|
||||
|
||||
/** @group getParam */
|
||||
final def getMaxDepth: Int = $(maxDepth)
|
||||
|
@ -119,6 +135,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams
|
|||
/** @group getParam */
|
||||
final def getMinInstancesPerNode: Int = $(minInstancesPerNode)
|
||||
|
||||
/** @group getParam */
|
||||
final def getMinWeightFractionPerNode: Double = $(minWeightFractionPerNode)
|
||||
|
||||
/** @group getParam */
|
||||
final def getMinInfoGain: Double = $(minInfoGain)
|
||||
|
||||
|
@ -143,6 +162,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams
|
|||
strategy.maxMemoryInMB = getMaxMemoryInMB
|
||||
strategy.minInfoGain = getMinInfoGain
|
||||
strategy.minInstancesPerNode = getMinInstancesPerNode
|
||||
strategy.minWeightFractionPerNode = getMinWeightFractionPerNode
|
||||
strategy.useNodeIdCache = getCacheNodeIds
|
||||
strategy.numClasses = numClasses
|
||||
strategy.categoricalFeaturesInfo = categoricalFeatures
|
||||
|
|
|
@ -23,6 +23,7 @@ import scala.util.Try
|
|||
import org.apache.spark.annotation.Since
|
||||
import org.apache.spark.api.java.JavaRDD
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.ml.feature.Instance
|
||||
import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, TreeEnsembleParams => NewRFParams}
|
||||
import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest}
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
|
@ -91,8 +92,8 @@ private class RandomForest (
|
|||
* @return RandomForestModel that can be used for prediction.
|
||||
*/
|
||||
def run(input: RDD[LabeledPoint]): RandomForestModel = {
|
||||
val trees: Array[NewDTModel] = NewRandomForest.run(input.map(_.asML), strategy, numTrees,
|
||||
featureSubsetStrategy, seed.toLong, None)
|
||||
val trees: Array[NewDTModel] =
|
||||
NewRandomForest.run(input, strategy, numTrees, featureSubsetStrategy, seed.toLong)
|
||||
new RandomForestModel(strategy.algo, trees.map(_.toOld))
|
||||
}
|
||||
|
||||
|
|
|
@ -80,7 +80,8 @@ class Strategy @Since("1.3.0") (
|
|||
@Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256,
|
||||
@Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1,
|
||||
@Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,
|
||||
@Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends Serializable {
|
||||
@Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10,
|
||||
@Since("3.0.0") @BeanProperty var minWeightFractionPerNode: Double = 0.0) extends Serializable {
|
||||
|
||||
/**
|
||||
*/
|
||||
|
@ -96,6 +97,31 @@ class Strategy @Since("1.3.0") (
|
|||
isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
|
||||
}
|
||||
|
||||
// scalastyle:off argcount
|
||||
/**
|
||||
* Backwards compatible constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]]
|
||||
*/
|
||||
@Since("1.0.0")
|
||||
def this(
|
||||
algo: Algo,
|
||||
impurity: Impurity,
|
||||
maxDepth: Int,
|
||||
numClasses: Int,
|
||||
maxBins: Int,
|
||||
quantileCalculationStrategy: QuantileStrategy,
|
||||
categoricalFeaturesInfo: Map[Int, Int],
|
||||
minInstancesPerNode: Int,
|
||||
minInfoGain: Double,
|
||||
maxMemoryInMB: Int,
|
||||
subsamplingRate: Double,
|
||||
useNodeIdCache: Boolean,
|
||||
checkpointInterval: Int) {
|
||||
this(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy,
|
||||
categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, maxMemoryInMB,
|
||||
subsamplingRate, useNodeIdCache, checkpointInterval, 0.0)
|
||||
}
|
||||
// scalastyle:on argcount
|
||||
|
||||
/**
|
||||
* Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]]
|
||||
*/
|
||||
|
@ -108,7 +134,8 @@ class Strategy @Since("1.3.0") (
|
|||
maxBins: Int,
|
||||
categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]) {
|
||||
this(algo, impurity, maxDepth, numClasses, maxBins, Sort,
|
||||
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
|
||||
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
|
||||
minWeightFractionPerNode = 0.0)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -171,8 +198,9 @@ class Strategy @Since("1.3.0") (
|
|||
@Since("1.2.0")
|
||||
def copy: Strategy = {
|
||||
new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
|
||||
quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain,
|
||||
maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval)
|
||||
quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode,
|
||||
minInfoGain, maxMemoryInMB, subsamplingRate, useNodeIdCache,
|
||||
checkpointInterval, minWeightFractionPerNode)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -83,23 +83,29 @@ object Entropy extends Impurity {
|
|||
* @param numClasses Number of classes for label.
|
||||
*/
|
||||
private[spark] class EntropyAggregator(numClasses: Int)
|
||||
extends ImpurityAggregator(numClasses) with Serializable {
|
||||
extends ImpurityAggregator(numClasses + 1) with Serializable {
|
||||
|
||||
/**
|
||||
* Update stats for one (node, feature, bin) with the given label.
|
||||
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
|
||||
* @param offset Start index of stats for this (node, feature, bin).
|
||||
*/
|
||||
def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
|
||||
if (label >= statsSize) {
|
||||
def update(
|
||||
allStats: Array[Double],
|
||||
offset: Int,
|
||||
label: Double,
|
||||
numSamples: Int,
|
||||
sampleWeight: Double): Unit = {
|
||||
if (label >= numClasses) {
|
||||
throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
|
||||
s" but requires label < numClasses (= $statsSize).")
|
||||
s" but requires label < numClasses (= ${numClasses}).")
|
||||
}
|
||||
if (label < 0) {
|
||||
throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
|
||||
s"but requires label is non-negative.")
|
||||
}
|
||||
allStats(offset + label.toInt) += instanceWeight
|
||||
allStats(offset + label.toInt) += numSamples * sampleWeight
|
||||
allStats(offset + statsSize - 1) += numSamples
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -108,7 +114,8 @@ private[spark] class EntropyAggregator(numClasses: Int)
|
|||
* @param offset Start index of stats for this (node, feature, bin).
|
||||
*/
|
||||
def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = {
|
||||
new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
|
||||
new EntropyCalculator(allStats.view(offset, offset + statsSize - 1).toArray,
|
||||
allStats(offset + statsSize - 1).toLong)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -118,12 +125,13 @@ private[spark] class EntropyAggregator(numClasses: Int)
|
|||
* (node, feature, bin).
|
||||
* @param stats Array of sufficient statistics for a (node, feature, bin).
|
||||
*/
|
||||
private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
|
||||
private[spark] class EntropyCalculator(stats: Array[Double], var rawCount: Long)
|
||||
extends ImpurityCalculator(stats) {
|
||||
|
||||
/**
|
||||
* Make a deep copy of this [[ImpurityCalculator]].
|
||||
*/
|
||||
def copy: EntropyCalculator = new EntropyCalculator(stats.clone())
|
||||
def copy: EntropyCalculator = new EntropyCalculator(stats.clone(), rawCount)
|
||||
|
||||
/**
|
||||
* Calculate the impurity from the stored sufficient statistics.
|
||||
|
@ -131,9 +139,9 @@ private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCal
|
|||
def calculate(): Double = Entropy.calculate(stats, stats.sum)
|
||||
|
||||
/**
|
||||
* Number of data points accounted for in the sufficient statistics.
|
||||
* Weighted number of data points accounted for in the sufficient statistics.
|
||||
*/
|
||||
def count: Long = stats.sum.toLong
|
||||
def count: Double = stats.sum
|
||||
|
||||
/**
|
||||
* Prediction which should be made based on the sufficient statistics.
|
||||
|
|
|
@ -80,23 +80,29 @@ object Gini extends Impurity {
|
|||
* @param numClasses Number of classes for label.
|
||||
*/
|
||||
private[spark] class GiniAggregator(numClasses: Int)
|
||||
extends ImpurityAggregator(numClasses) with Serializable {
|
||||
extends ImpurityAggregator(numClasses + 1) with Serializable {
|
||||
|
||||
/**
|
||||
* Update stats for one (node, feature, bin) with the given label.
|
||||
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
|
||||
* @param offset Start index of stats for this (node, feature, bin).
|
||||
*/
|
||||
def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
|
||||
if (label >= statsSize) {
|
||||
def update(
|
||||
allStats: Array[Double],
|
||||
offset: Int,
|
||||
label: Double,
|
||||
numSamples: Int,
|
||||
sampleWeight: Double): Unit = {
|
||||
if (label >= numClasses) {
|
||||
throw new IllegalArgumentException(s"GiniAggregator given label $label" +
|
||||
s" but requires label < numClasses (= $statsSize).")
|
||||
s" but requires label < numClasses (= ${numClasses}).")
|
||||
}
|
||||
if (label < 0) {
|
||||
throw new IllegalArgumentException(s"GiniAggregator given label $label" +
|
||||
s"but requires label is non-negative.")
|
||||
s"but requires label to be non-negative.")
|
||||
}
|
||||
allStats(offset + label.toInt) += instanceWeight
|
||||
allStats(offset + label.toInt) += numSamples * sampleWeight
|
||||
allStats(offset + statsSize - 1) += numSamples
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -105,7 +111,8 @@ private[spark] class GiniAggregator(numClasses: Int)
|
|||
* @param offset Start index of stats for this (node, feature, bin).
|
||||
*/
|
||||
def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = {
|
||||
new GiniCalculator(allStats.view(offset, offset + statsSize).toArray)
|
||||
new GiniCalculator(allStats.view(offset, offset + statsSize - 1).toArray,
|
||||
allStats(offset + statsSize - 1).toLong)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -115,12 +122,13 @@ private[spark] class GiniAggregator(numClasses: Int)
|
|||
* (node, feature, bin).
|
||||
* @param stats Array of sufficient statistics for a (node, feature, bin).
|
||||
*/
|
||||
private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
|
||||
private[spark] class GiniCalculator(stats: Array[Double], var rawCount: Long)
|
||||
extends ImpurityCalculator(stats) {
|
||||
|
||||
/**
 * Build an independent [[ImpurityCalculator]]: the stats array is cloned and the raw count
 * is carried over unchanged.
 */
def copy: GiniCalculator = {
  val clonedStats = stats.clone()
  new GiniCalculator(clonedStats, rawCount)
}
|
||||
|
||||
/**
|
||||
* Calculate the impurity from the stored sufficient statistics.
|
||||
|
@ -128,9 +136,9 @@ private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcul
|
|||
def calculate(): Double = Gini.calculate(stats, stats.sum)
|
||||
|
||||
/**
 * Weighted number of data points accounted for in the sufficient statistics,
 * i.e. the sum of all entries of `stats`.
 */
def count: Double = {
  stats.sum
}
|
||||
|
||||
/**
|
||||
* Prediction which should be made based on the sufficient statistics.
|
||||
|
|
|
@ -81,7 +81,12 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser
|
|||
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
|
||||
* @param offset Start index of stats for this (node, feature, bin).
|
||||
*/
|
||||
def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit
|
||||
def update(
|
||||
allStats: Array[Double],
|
||||
offset: Int,
|
||||
label: Double,
|
||||
numSamples: Int,
|
||||
sampleWeight: Double): Unit
|
||||
|
||||
/**
|
||||
* Get an [[ImpurityCalculator]] for a (node, feature, bin).
|
||||
|
@ -122,6 +127,7 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten
|
|||
stats(i) += other.stats(i)
|
||||
i += 1
|
||||
}
|
||||
rawCount += other.rawCount
|
||||
this
|
||||
}
|
||||
|
||||
|
@ -139,13 +145,19 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten
|
|||
stats(i) -= other.stats(i)
|
||||
i += 1
|
||||
}
|
||||
rawCount -= other.rawCount
|
||||
this
|
||||
}
|
||||
|
||||
/**
|
||||
* Number of data points accounted for in the sufficient statistics.
|
||||
* Weighted number of data points accounted for in the sufficient statistics.
|
||||
*/
|
||||
def count: Long
|
||||
def count: Double
|
||||
|
||||
/**
|
||||
* Raw number of data points accounted for in the sufficient statistics.
|
||||
*/
|
||||
var rawCount: Long
|
||||
|
||||
/**
|
||||
* Prediction which should be made based on the sufficient statistics.
|
||||
|
@ -185,11 +197,14 @@ private[spark] object ImpurityCalculator {
|
|||
* Create an [[ImpurityCalculator]] instance of the given impurity type and with
|
||||
* the given stats.
|
||||
*/
|
||||
def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = {
|
||||
def getCalculator(
|
||||
impurity: String,
|
||||
stats: Array[Double],
|
||||
rawCount: Long): ImpurityCalculator = {
|
||||
impurity.toLowerCase(Locale.ROOT) match {
|
||||
case "gini" => new GiniCalculator(stats)
|
||||
case "entropy" => new EntropyCalculator(stats)
|
||||
case "variance" => new VarianceCalculator(stats)
|
||||
case "gini" => new GiniCalculator(stats, rawCount)
|
||||
case "entropy" => new EntropyCalculator(stats, rawCount)
|
||||
case "variance" => new VarianceCalculator(stats, rawCount)
|
||||
case _ =>
|
||||
throw new IllegalArgumentException(
|
||||
s"ImpurityCalculator builder did not recognize impurity type: $impurity")
|
||||
|
|
|
@ -66,21 +66,32 @@ object Variance extends Impurity {
|
|||
|
||||
/**
|
||||
* Class for updating views of a vector of sufficient statistics,
|
||||
* in order to compute impurity from a sample.
|
||||
* in order to compute impurity from a sample. For variance, we track:
|
||||
* - sum(w_i)
|
||||
* - sum(w_i * y_i)
|
||||
* - sum(w_i * y_i * y_i)
|
||||
* - count(y_i)
|
||||
* Note: Instances of this class do not hold the data; they operate on views of the data.
|
||||
*/
|
||||
private[spark] class VarianceAggregator()
|
||||
extends ImpurityAggregator(statsSize = 3) with Serializable {
|
||||
extends ImpurityAggregator(statsSize = 4) with Serializable {
|
||||
|
||||
/**
 * Fold one labelled observation into the stats of a single (node, feature, bin).
 *
 * Four consecutive slots starting at `offset` are updated, in order: the weighted count,
 * the weighted sum of labels, the weighted sum of squared labels, and the raw count.
 *
 * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
 * @param offset    Start index of stats for this (node, feature, bin).
 * @param label     Label value of the observation.
 * @param numSamples  Unweighted count added to the raw-count slot.
 * @param sampleWeight  Weight multiplied into `numSamples` for the weighted slots.
 */
def update(
    allStats: Array[Double],
    offset: Int,
    label: Double,
    numSamples: Int,
    sampleWeight: Double): Unit = {
  val weightedCount = numSamples * sampleWeight
  val weightedLabel = weightedCount * label
  allStats(offset) += weightedCount
  allStats(offset + 1) += weightedLabel
  allStats(offset + 2) += weightedLabel * label
  allStats(offset + 3) += numSamples
}
|
||||
|
||||
/**
|
||||
|
@ -89,7 +100,8 @@ private[spark] class VarianceAggregator()
|
|||
* @param offset Start index of stats for this (node, feature, bin).
|
||||
*/
|
||||
def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = {
  // The bin's last slot stores the raw count; the slots before it are copied out as the stats.
  val rawCountIndex = offset + statsSize - 1
  val binStats = allStats.view(offset, rawCountIndex).toArray
  new VarianceCalculator(binStats, allStats(rawCountIndex).toLong)
}
|
||||
}
|
||||
|
||||
|
@ -99,7 +111,8 @@ private[spark] class VarianceAggregator()
|
|||
* (node, feature, bin).
|
||||
* @param stats Array of sufficient statistics for a (node, feature, bin).
|
||||
*/
|
||||
private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
|
||||
private[spark] class VarianceCalculator(stats: Array[Double], var rawCount: Long)
|
||||
extends ImpurityCalculator(stats) {
|
||||
|
||||
require(stats.length == 3,
|
||||
s"VarianceCalculator requires sufficient statistics array stats to be of length 3," +
|
||||
|
@ -108,7 +121,7 @@ private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCa
|
|||
/**
 * Build an independent [[ImpurityCalculator]]: the stats array is cloned and the raw count
 * is carried over unchanged.
 */
def copy: VarianceCalculator = {
  val clonedStats = stats.clone()
  new VarianceCalculator(clonedStats, rawCount)
}
|
||||
|
||||
/**
|
||||
* Calculate the impurity from the stored sufficient statistics.
|
||||
|
@ -116,9 +129,9 @@ private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCa
|
|||
def calculate(): Double = Variance.calculate(stats(0), stats(1), stats(2))
|
||||
|
||||
/**
 * Weighted number of data points accounted for in the sufficient statistics,
 * read from the first entry of `stats`.
 */
def count: Double = {
  stats(0)
}
|
||||
|
||||
/**
|
||||
* Prediction which should be made based on the sufficient statistics.
|
||||
|
|
|
@ -42,6 +42,8 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
|
|||
private var continuousDataPointsForMulticlassRDD: RDD[LabeledPoint] = _
|
||||
private var categoricalDataPointsForMulticlassForOrderedFeaturesRDD: RDD[LabeledPoint] = _
|
||||
|
||||
private val seed = 42
|
||||
|
||||
override def beforeAll() {
|
||||
super.beforeAll()
|
||||
categoricalDataPointsRDD =
|
||||
|
@ -250,7 +252,7 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
|
|||
|
||||
MLTestingUtils.checkCopyAndUids(dt, newTree)
|
||||
|
||||
testTransformer[(Vector, Double)](newData, newTree,
|
||||
testTransformer[(Vector, Double, Double)](newData, newTree,
|
||||
"prediction", "rawPrediction", "probability") {
|
||||
case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
|
||||
assert(pred === rawPred.argmax,
|
||||
|
@ -327,6 +329,49 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
|
|||
dt.fit(df)
|
||||
}
|
||||
|
||||
test("training with sample weights") {
  val numClasses = 3
  val nPoints = 100
  val coefficients = Array(
    -0.57997, 0.912083, -0.371077,
    -0.16624, -0.84355, -0.048509)
  val xMean = Array(5.843, 3.057)
  val xVariance = Array(0.6856, 0.1899)

  // Synthetic multinomial data, spread over 4 partitions.
  val testData = LogisticRegressionSuite.generateMultinomialLogisticInput(
    coefficients, xMean, xVariance, addIntercept = true, nPoints, seed)
  val df = sc.parallelize(testData, 4).toDF()

  // Class predictions are compared for exact equality.
  val predEquals = (x: Double, y: Double) => x == y

  // (impurity, maxDepth)
  val testParams = Seq(
    ("gini", 10),
    ("entropy", 10),
    ("gini", 5)
  )
  testParams.foreach { case (impurity, maxDepth) =>
    val estimator = new DecisionTreeClassifier()
      .setMaxDepth(maxDepth)
      .setSeed(seed)
      .setMinWeightFractionPerNode(0.049)
      .setImpurity(impurity)

    // Check 1: arbitrarily scaled weights.
    MLTestingUtils.testArbitrarilyScaledWeights[DecisionTreeClassificationModel,
      DecisionTreeClassifier](df.as[LabeledPoint], estimator,
      MLTestingUtils.modelPredictionEquals(df, predEquals, 0.7))
    // Check 2: outliers carrying small weights.
    MLTestingUtils.testOutliersWithSmallWeights[DecisionTreeClassificationModel,
      DecisionTreeClassifier](df.as[LabeledPoint], estimator,
      numClasses, MLTestingUtils.modelPredictionEquals(df, predEquals, 0.8),
      outlierRatio = 2)
    // Check 3: oversampling versus weighting.
    MLTestingUtils.testOversamplingVsWeighting[DecisionTreeClassificationModel,
      DecisionTreeClassifier](df.as[LabeledPoint], estimator,
      MLTestingUtils.modelPredictionEquals(df, predEquals, 0.7), seed)
  }
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Tests of model save/load
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -350,7 +395,6 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
|
|||
TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
|
||||
testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings,
|
||||
allParamSettings, checkModelData)
|
||||
|
||||
// Continuous splits with tree depth 2
|
||||
val continuousData: DataFrame =
|
||||
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
package org.apache.spark.ml.classification
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
|
||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.ml.param.ParamsSuite
|
||||
import org.apache.spark.ml.tree.LeafNode
|
||||
|
@ -141,7 +141,7 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
|
|||
|
||||
MLTestingUtils.checkCopyAndUids(rf, model)
|
||||
|
||||
testTransformer[(Vector, Double)](df, model, "prediction", "rawPrediction",
|
||||
testTransformer[(Vector, Double, Double)](df, model, "prediction", "rawPrediction",
|
||||
"probability") { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
|
||||
assert(pred === rawPred.argmax,
|
||||
s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
|
||||
|
@ -180,7 +180,6 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
|
|||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Tests of feature importance
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
test("Feature importance with toy data") {
|
||||
val numClasses = 2
|
||||
val rf = new RandomForestClassifier()
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
package org.apache.spark.ml.regression
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.ml.tree.impl.TreeTests
|
||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
|
||||
|
@ -26,6 +26,7 @@ import org.apache.spark.ml.util.TestingUtils._
|
|||
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
|
||||
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
|
||||
DecisionTreeSuite => OldDecisionTreeSuite}
|
||||
import org.apache.spark.mllib.util.LinearDataGenerator
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.{DataFrame, Row}
|
||||
|
||||
|
@ -35,11 +36,17 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
|
|||
import testImplicits._
|
||||
|
||||
private var categoricalDataPointsRDD: RDD[LabeledPoint] = _
|
||||
private var linearRegressionData: DataFrame = _
|
||||
|
||||
private val seed = 42
|
||||
|
||||
override def beforeAll() {
  super.beforeAll()

  // Categorical fixture: points are converted to the ML representation before parallelizing.
  val categoricalPoints = OldDecisionTreeSuite.generateCategoricalDataPoints().map(_.asML)
  categoricalDataPointsRDD = sc.parallelize(categoricalPoints)

  // Linear-regression fixture: 1000 generated points in 2 partitions, exposed as a DataFrame.
  val linearPoints = LinearDataGenerator.generateLinearInput(
    intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3),
    xVariance = Array(0.7, 1.2), nPoints = 1000, seed, eps = 0.5)
  linearRegressionData = sc.parallelize(linearPoints, 2).map(_.asML).toDF()
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -88,7 +95,7 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
|
|||
val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
|
||||
val model = dt.fit(df)
|
||||
|
||||
testTransformer[(Vector, Double)](df, model, "features", "variance") {
|
||||
testTransformer[(Vector, Double, Double)](df, model, "features", "variance") {
|
||||
case Row(features: Vector, variance: Double) =>
|
||||
val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate()
|
||||
assert(variance === expectedVariance,
|
||||
|
@ -101,7 +108,7 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
|
|||
.setMaxBins(6)
|
||||
.setSeed(0)
|
||||
|
||||
testTransformerByGlobalCheckFunc[(Vector, Double)](varianceDF, dt.fit(varianceDF),
|
||||
testTransformerByGlobalCheckFunc[(Vector, Double, Double)](varianceDF, dt.fit(varianceDF),
|
||||
"variance") { case rows: Seq[Row] =>
|
||||
val calculatedVariances = rows.map(_.getDouble(0))
|
||||
|
||||
|
@ -159,6 +166,28 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
|
|||
}
|
||||
}
|
||||
|
||||
test("training with sample weights") {
  val df = linearRegressionData
  // Passed to the outlier check; 0 is used here for the regression case.
  val numClasses = 0

  // Each value is a maxDepth setting to exercise.
  Seq(5, 10).foreach { maxDepth =>
    val estimator = new DecisionTreeRegressor()
      .setMaxDepth(maxDepth)
      .setMinWeightFractionPerNode(0.05)
      .setSeed(123)

    // Check 1: arbitrarily scaled weights.
    MLTestingUtils.testArbitrarilyScaledWeights[DecisionTreeRegressionModel,
      DecisionTreeRegressor](df.as[LabeledPoint], estimator,
      MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, 0.99))
    // Check 2: outliers carrying small weights.
    MLTestingUtils.testOutliersWithSmallWeights[DecisionTreeRegressionModel,
      DecisionTreeRegressor](df.as[LabeledPoint], estimator, numClasses,
      MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, 0.99),
      outlierRatio = 2)
    // Check 3: oversampling versus weighting, with a tighter tolerance than the other two.
    MLTestingUtils.testOversamplingVsWeighting[DecisionTreeRegressionModel,
      DecisionTreeRegressor](df.as[LabeledPoint], estimator,
      MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.01, 1.0), seed)
  }
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Tests of model save/load
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -891,6 +891,7 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe
|
|||
.setStandardization(standardization)
|
||||
.setRegParam(regParam)
|
||||
.setElasticNetParam(elasticNetParam)
|
||||
.setSolver(solver)
|
||||
MLTestingUtils.testArbitrarilyScaledWeights[LinearRegressionModel, LinearRegression](
|
||||
datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals)
|
||||
MLTestingUtils.testOutliersWithSmallWeights[LinearRegressionModel, LinearRegression](
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.ml.tree.impl
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
|
||||
import org.apache.spark.mllib.tree.EnsembleTestHelper
|
||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||
|
||||
|
@ -26,12 +27,16 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
|
|||
*/
|
||||
class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||
|
||||
test("BaggedPoint RDD: without subsampling with weights") {
  // Every generated point becomes an Instance carrying a weight of 0.5.
  val instances = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map { point =>
    Instance(point.label, 0.5, point.features.asML)
  }
  // The extractor multiplies the stored weight by 4.0, so the assertion below
  // expects 0.5 * 4.0 = 2.0 on every bagged point.
  val weightOf = (instance: Instance) => instance.weight * 4.0
  val bagged = BaggedPoint.convertToBaggedRDD(sc.parallelize(instances), 1.0, 1, false,
    weightOf, seed = 42)
  for (baggedPoint <- bagged.collect()) {
    assert(baggedPoint.subsampleCounts.size === 1 && baggedPoint.subsampleCounts(0) === 1)
    assert(baggedPoint.sampleWeight === 2.0)
  }
}
|
||||
|
||||
|
@ -40,13 +45,17 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
val (expectedMean, expectedStddev) = (1.0, 1.0)
|
||||
|
||||
val seeds = Array(123, 5354, 230, 349867, 23987)
|
||||
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
|
||||
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map(_.asML)
|
||||
val rdd = sc.parallelize(arr)
|
||||
seeds.foreach { seed =>
|
||||
val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed)
|
||||
val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
|
||||
val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true,
|
||||
(_: LabeledPoint) => 2.0, seed)
|
||||
val subsampleCounts: Array[Array[Double]] =
|
||||
baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
|
||||
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
|
||||
expectedStddev, epsilon = 0.01)
|
||||
// should ignore weight function for now
|
||||
assert(baggedRDD.collect().forall(_.sampleWeight === 1.0))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -59,8 +68,10 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
|
||||
val rdd = sc.parallelize(arr)
|
||||
seeds.foreach { seed =>
|
||||
val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed)
|
||||
val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
|
||||
val baggedRDD =
|
||||
BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed = seed)
|
||||
val subsampleCounts: Array[Array[Double]] =
|
||||
baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
|
||||
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
|
||||
expectedStddev, epsilon = 0.01)
|
||||
}
|
||||
|
@ -71,13 +82,17 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
val (expectedMean, expectedStddev) = (1.0, 0)
|
||||
|
||||
val seeds = Array(123, 5354, 230, 349867, 23987)
|
||||
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
|
||||
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map(_.asML)
|
||||
val rdd = sc.parallelize(arr)
|
||||
seeds.foreach { seed =>
|
||||
val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed)
|
||||
val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
|
||||
val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false,
|
||||
(_: LabeledPoint) => 2.0, seed)
|
||||
val subsampleCounts: Array[Array[Double]] =
|
||||
baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
|
||||
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
|
||||
expectedStddev, epsilon = 0.01)
|
||||
// should ignore weight function for now
|
||||
assert(baggedRDD.collect().forall(_.sampleWeight === 1.0))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -90,8 +105,10 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
|
||||
val rdd = sc.parallelize(arr)
|
||||
seeds.foreach { seed =>
|
||||
val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed)
|
||||
val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
|
||||
val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false,
|
||||
seed = seed)
|
||||
val subsampleCounts: Array[Array[Double]] =
|
||||
baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
|
||||
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
|
||||
expectedStddev, epsilon = 0.01)
|
||||
}
|
||||
|
|
|
@ -19,10 +19,11 @@ package org.apache.spark.ml.tree.impl
|
|||
|
||||
import scala.annotation.tailrec
|
||||
import scala.collection.mutable
|
||||
import scala.language.implicitConversions
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
|
||||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
|
||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||
import org.apache.spark.ml.tree._
|
||||
import org.apache.spark.ml.util.TestingUtils._
|
||||
|
@ -46,7 +47,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
/////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
test("Binary classification with continuous features: split calculation") {
|
||||
val arr = OldDTSuite.generateOrderedLabeledPointsWithLabel1().map(_.asML)
|
||||
val arr = OldDTSuite.generateOrderedLabeledPointsWithLabel1().map(_.asML.toInstance)
|
||||
assert(arr.length === 1000)
|
||||
val rdd = sc.parallelize(arr)
|
||||
val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2, 100)
|
||||
|
@ -58,7 +59,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
}
|
||||
|
||||
test("Binary classification with binary (ordered) categorical features: split calculation") {
|
||||
val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML)
|
||||
val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance)
|
||||
assert(arr.length === 1000)
|
||||
val rdd = sc.parallelize(arr)
|
||||
val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2,
|
||||
|
@ -75,7 +76,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
|
||||
test("Binary classification with 3-ary (ordered) categorical features," +
|
||||
" with no samples for one category: split calculation") {
|
||||
val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML)
|
||||
val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance)
|
||||
assert(arr.length === 1000)
|
||||
val rdd = sc.parallelize(arr)
|
||||
val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2,
|
||||
|
@ -93,12 +94,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
test("find splits for a continuous feature") {
|
||||
// find splits for normal case
|
||||
{
|
||||
val fakeMetadata = new DecisionTreeMetadata(1, 200000, 0, 0,
|
||||
val fakeMetadata = new DecisionTreeMetadata(1, 200000, 200000.0, 0, 0,
|
||||
Map(), Set(),
|
||||
Array(6), Gini, QuantileStrategy.Sort,
|
||||
0, 0, 0.0, 0, 0
|
||||
0, 0, 0.0, 0.0, 0, 0
|
||||
)
|
||||
val featureSamples = Array.fill(10000)(math.random).filter(_ != 0.0)
|
||||
val featureSamples = Array.fill(10000)((1.0, math.random)).filter(_._2 != 0.0)
|
||||
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
||||
assert(splits.length === 5)
|
||||
assert(fakeMetadata.numSplits(0) === 5)
|
||||
|
@ -109,15 +110,16 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
|
||||
// SPARK-16957: Use midpoints for split values.
|
||||
{
|
||||
val fakeMetadata = new DecisionTreeMetadata(1, 8, 0, 0,
|
||||
val fakeMetadata = new DecisionTreeMetadata(1, 8, 8.0, 0, 0,
|
||||
Map(), Set(),
|
||||
Array(3), Gini, QuantileStrategy.Sort,
|
||||
0, 0, 0.0, 0, 0
|
||||
0, 0, 0.0, 0.0, 0, 0
|
||||
)
|
||||
|
||||
// possibleSplits <= numSplits
|
||||
{
|
||||
val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble).filter(_ != 0.0)
|
||||
val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1)
|
||||
.map(x => (1.0, x.toDouble)).filter(_._2 != 0.0)
|
||||
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
||||
val expectedSplits = Array((0.0 + 1.0) / 2)
|
||||
assert(splits === expectedSplits)
|
||||
|
@ -125,7 +127,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
|
||||
// possibleSplits > numSplits
|
||||
{
|
||||
val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble).filter(_ != 0.0)
|
||||
val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3)
|
||||
.map(x => (1.0, x.toDouble)).filter(_._2 != 0.0)
|
||||
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
||||
val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2)
|
||||
assert(splits === expectedSplits)
|
||||
|
@ -135,12 +138,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
// find splits should not return identical splits
|
||||
// when there are not enough split candidates, reduce the number of splits in metadata
|
||||
{
|
||||
val fakeMetadata = new DecisionTreeMetadata(1, 12, 0, 0,
|
||||
val fakeMetadata = new DecisionTreeMetadata(1, 12, 12.0, 0, 0,
|
||||
Map(), Set(),
|
||||
Array(5), Gini, QuantileStrategy.Sort,
|
||||
0, 0, 0.0, 0, 0
|
||||
0, 0, 0.0, 0.0, 0, 0
|
||||
)
|
||||
val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(_.toDouble)
|
||||
val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(x => (1.0, x.toDouble))
|
||||
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
||||
val expectedSplits = Array((1.0 + 2.0) / 2, (2.0 + 3.0) / 2)
|
||||
assert(splits === expectedSplits)
|
||||
|
@ -150,13 +153,13 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
|
||||
// find splits when most samples close to the minimum
|
||||
{
|
||||
val fakeMetadata = new DecisionTreeMetadata(1, 18, 0, 0,
|
||||
val fakeMetadata = new DecisionTreeMetadata(1, 18, 18.0, 0, 0,
|
||||
Map(), Set(),
|
||||
Array(3), Gini, QuantileStrategy.Sort,
|
||||
0, 0, 0.0, 0, 0
|
||||
0, 0, 0.0, 0.0, 0, 0
|
||||
)
|
||||
val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5)
|
||||
.map(_.toDouble)
|
||||
val featureSamples =
|
||||
Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(x => (1.0, x.toDouble))
|
||||
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
||||
val expectedSplits = Array((2.0 + 3.0) / 2, (3.0 + 4.0) / 2)
|
||||
assert(splits === expectedSplits)
|
||||
|
@ -164,37 +167,55 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
|
||||
// find splits when most samples close to the maximum
|
||||
{
|
||||
val fakeMetadata = new DecisionTreeMetadata(1, 17, 0, 0,
|
||||
val fakeMetadata = new DecisionTreeMetadata(1, 17, 17.0, 0, 0,
|
||||
Map(), Set(),
|
||||
Array(2), Gini, QuantileStrategy.Sort,
|
||||
0, 0, 0.0, 0, 0
|
||||
0, 0, 0.0, 0.0, 0, 0
|
||||
)
|
||||
val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)
|
||||
.map(_.toDouble).filter(_ != 0.0)
|
||||
val featureSamples =
|
||||
Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(x => (1.0, x.toDouble))
|
||||
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
||||
val expectedSplits = Array((1.0 + 2.0) / 2)
|
||||
assert(splits === expectedSplits)
|
||||
}
|
||||
|
||||
// find splits for constant feature
|
||||
// find splits for arbitrarily scaled data
|
||||
{
|
||||
val fakeMetadata = new DecisionTreeMetadata(1, 3, 0, 0,
|
||||
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
|
||||
Map(), Set(),
|
||||
Array(6), Gini, QuantileStrategy.Sort,
|
||||
0, 0, 0.0, 0.0, 0, 0
|
||||
)
|
||||
val featureSamplesUnitWeight = Array.fill(10)((1.0, math.random))
|
||||
val featureSamplesSmallWeight = featureSamplesUnitWeight.map { case (w, x) => (w * 0.001, x)}
|
||||
val featureSamplesLargeWeight = featureSamplesUnitWeight.map { case (w, x) => (w * 1000, x)}
|
||||
val splitsUnitWeight = RandomForest
|
||||
.findSplitsForContinuousFeature(featureSamplesUnitWeight, fakeMetadata, 0)
|
||||
val splitsSmallWeight = RandomForest
|
||||
.findSplitsForContinuousFeature(featureSamplesSmallWeight, fakeMetadata, 0)
|
||||
val splitsLargeWeight = RandomForest
|
||||
.findSplitsForContinuousFeature(featureSamplesLargeWeight, fakeMetadata, 0)
|
||||
assert(splitsUnitWeight === splitsSmallWeight)
|
||||
assert(splitsUnitWeight === splitsLargeWeight)
|
||||
}
|
||||
|
||||
// find splits when most weight is close to the minimum
|
||||
{
|
||||
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
|
||||
Map(), Set(),
|
||||
Array(3), Gini, QuantileStrategy.Sort,
|
||||
0, 0, 0.0, 0, 0
|
||||
0, 0, 0.0, 0.0, 0, 0
|
||||
)
|
||||
val featureSamples = Array(0, 0, 0).map(_.toDouble).filter(_ != 0.0)
|
||||
val featureSamplesEmpty = Array.empty[Double]
|
||||
val featureSamples = Array((10, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6)).map {
|
||||
case (w, x) => (w.toDouble, x.toDouble)
|
||||
}
|
||||
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
||||
assert(splits === Array.empty[Double])
|
||||
val splitsEmpty =
|
||||
RandomForest.findSplitsForContinuousFeature(featureSamplesEmpty, fakeMetadata, 0)
|
||||
assert(splitsEmpty === Array.empty[Double])
|
||||
assert(splits === Array(1.5, 2.5, 3.5, 4.5, 5.5))
|
||||
}
|
||||
}
|
||||
|
||||
test("train with empty arrays") {
|
||||
val lp = LabeledPoint(1.0, Vectors.dense(Array.empty[Double]))
|
||||
val lp = LabeledPoint(1.0, Vectors.dense(Array.empty[Double])).toInstance
|
||||
val data = Array.fill(5)(lp)
|
||||
val rdd = sc.parallelize(data)
|
||||
|
||||
|
@ -209,8 +230,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
}
|
||||
|
||||
test("train with constant features") {
|
||||
val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0))
|
||||
val data = Array.fill(5)(lp)
|
||||
val instance = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0)).toInstance
|
||||
val data = Array.fill(5)(instance)
|
||||
val rdd = sc.parallelize(data)
|
||||
val strategy = new OldStrategy(
|
||||
OldAlgo.Classification,
|
||||
|
@ -222,7 +243,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)
|
||||
assert(tree.rootNode.impurity === -1.0)
|
||||
assert(tree.depth === 0)
|
||||
assert(tree.rootNode.prediction === lp.label)
|
||||
assert(tree.rootNode.prediction === instance.label)
|
||||
|
||||
// Test with no categorical features
|
||||
val strategy2 = new OldStrategy(
|
||||
|
@ -233,11 +254,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
val Array(tree2) = RandomForest.run(rdd, strategy2, 1, "all", 42L, instr = None)
|
||||
assert(tree2.rootNode.impurity === -1.0)
|
||||
assert(tree2.depth === 0)
|
||||
assert(tree2.rootNode.prediction === lp.label)
|
||||
assert(tree2.rootNode.prediction === instance.label)
|
||||
}
|
||||
|
||||
test("Multiclass classification with unordered categorical features: split calculations") {
|
||||
val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML)
|
||||
val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance)
|
||||
assert(arr.length === 1000)
|
||||
val rdd = sc.parallelize(arr)
|
||||
val strategy = new OldStrategy(
|
||||
|
@ -278,7 +299,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
}
|
||||
|
||||
test("Multiclass classification with ordered categorical features: split calculations") {
|
||||
val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures().map(_.asML)
|
||||
val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
|
||||
.map(_.asML.toInstance)
|
||||
assert(arr.length === 3000)
|
||||
val rdd = sc.parallelize(arr)
|
||||
val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 100,
|
||||
|
@ -310,7 +332,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
|
||||
LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
|
||||
val input = sc.parallelize(arr)
|
||||
val input = sc.parallelize(arr.map(_.toInstance))
|
||||
|
||||
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1,
|
||||
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
|
||||
|
@ -352,7 +374,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
|
||||
LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
|
||||
val input = sc.parallelize(arr)
|
||||
val input = sc.parallelize(arr.map(_.toInstance))
|
||||
|
||||
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 5,
|
||||
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
|
||||
|
@ -404,7 +426,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
LabeledPoint(0.0, Vectors.dense(2.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(2.0)),
|
||||
LabeledPoint(1.0, Vectors.dense(2.0)))
|
||||
val input = sc.parallelize(arr)
|
||||
val input = sc.parallelize(arr.map(_.toInstance))
|
||||
|
||||
// Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
|
||||
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1,
|
||||
|
@ -424,7 +446,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
}
|
||||
|
||||
test("Second level node building with vs. without groups") {
|
||||
val arr = OldDTSuite.generateOrderedLabeledPoints().map(_.asML)
|
||||
val arr = OldDTSuite.generateOrderedLabeledPoints().map(_.asML.toInstance)
|
||||
assert(arr.length === 1000)
|
||||
val rdd = sc.parallelize(arr)
|
||||
// For tree with 1 group
|
||||
|
@ -468,7 +490,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: OldStrategy) {
|
||||
val numFeatures = 50
|
||||
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000)
|
||||
val rdd = sc.parallelize(arr).map(_.asML)
|
||||
val rdd = sc.parallelize(arr).map(_.asML.toInstance)
|
||||
|
||||
// Select feature subset for top nodes. Return true if OK.
|
||||
def checkFeatureSubsetStrategy(
|
||||
|
@ -581,16 +603,16 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
left2 parent
|
||||
left right
|
||||
*/
|
||||
val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0))
|
||||
val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0), 6L)
|
||||
val left = new LeafNode(0.0, leftImp.calculate(), leftImp)
|
||||
|
||||
val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0))
|
||||
val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0), 8L)
|
||||
val right = new LeafNode(2.0, rightImp.calculate(), rightImp)
|
||||
|
||||
val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5))
|
||||
val parentImp = parent.impurityStats
|
||||
|
||||
val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0))
|
||||
val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0), 8L)
|
||||
val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp)
|
||||
|
||||
val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0))
|
||||
|
@ -647,12 +669,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
// feature_0 = 0 improves the impurity measure, despite the prediction will always be 0
|
||||
// in both branches.
|
||||
val arr = Array(
|
||||
LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
|
||||
LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
|
||||
LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
|
||||
LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
|
||||
Instance(0.0, 1.0, Vectors.dense(0.0, 1.0)),
|
||||
Instance(1.0, 1.0, Vectors.dense(0.0, 1.0)),
|
||||
Instance(0.0, 1.0, Vectors.dense(0.0, 0.0)),
|
||||
Instance(1.0, 1.0, Vectors.dense(1.0, 0.0)),
|
||||
Instance(0.0, 1.0, Vectors.dense(1.0, 0.0)),
|
||||
Instance(1.0, 1.0, Vectors.dense(1.0, 1.0))
|
||||
)
|
||||
val rdd = sc.parallelize(arr)
|
||||
|
||||
|
@ -677,13 +699,13 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
// feature_1 = 1 improves the impurity measure, despite the prediction will always be 0.5
|
||||
// in both branches.
|
||||
val arr = Array(
|
||||
LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
|
||||
LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
|
||||
LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
|
||||
LabeledPoint(0.0, Vectors.dense(1.0, 1.0)),
|
||||
LabeledPoint(0.5, Vectors.dense(1.0, 1.0))
|
||||
Instance(0.0, 1.0, Vectors.dense(0.0, 1.0)),
|
||||
Instance(1.0, 1.0, Vectors.dense(0.0, 1.0)),
|
||||
Instance(0.0, 1.0, Vectors.dense(0.0, 0.0)),
|
||||
Instance(0.0, 1.0, Vectors.dense(1.0, 0.0)),
|
||||
Instance(1.0, 1.0, Vectors.dense(1.0, 1.0)),
|
||||
Instance(0.0, 1.0, Vectors.dense(1.0, 1.0)),
|
||||
Instance(0.5, 1.0, Vectors.dense(1.0, 1.0))
|
||||
)
|
||||
val rdd = sc.parallelize(arr)
|
||||
|
||||
|
@ -700,6 +722,56 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
assert(unprunedTree.numNodes === 5)
|
||||
assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size)
|
||||
}
|
||||
|
||||
test("weights at arbitrary scale") {
  // Baseline forest: every instance carries the weight produced by `toInstance`.
  val points = EnsembleTestHelper.generateOrderedLabeledPoints(3, 10)
  val unitRdd = sc.parallelize(points.map(_.asML.toInstance))
  val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2)
  val unitTrees = RandomForest.run(unitRdd, strategy, 3, "all", 42L, None)

  // Re-train with one constant weight on every instance (first tiny, then huge) and
  // require each resulting tree to match the corresponding baseline tree.
  Seq(0.001, 1000.0).foreach { weight =>
    val rescaledRdd = unitRdd.map(inst => Instance(inst.label, weight, inst.features))
    val rescaledTrees = RandomForest.run(rescaledRdd, strategy, 3, "all", 42L, None)
    for ((unitTree, rescaledTree) <- unitTrees.zip(rescaledTrees)) {
      TreeTests.checkEqual(unitTree, rescaledTree)
    }
  }
}
|
||||
|
||||
test("minWeightFraction and minInstancesPerNode") {
  // Four instances of label 0 with weight 1.0, plus one instance of label 1 with weight 0.1.
  val heavy = Array.fill(4)(Instance(0.0, 1.0, Vectors.dense(0.0)))
  val data = heavy :+ Instance(1.0, 0.1, Vectors.dense(1.0))
  val rdd = sc.parallelize(data)
  val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2,
    minWeightFractionPerNode = 0.5)

  // Trains a single tree with the strategy's current settings and reports its depth.
  def trainedDepth(): Int = {
    val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
    tree.depth
  }

  // With minWeightFractionPerNode = 0.5 the tree stays a single leaf.
  assert(trainedDepth() === 0)

  // Dropping the weight-fraction requirement allows one split.
  strategy.minWeightFractionPerNode = 0.0
  assert(trainedDepth() === 1)

  // Requiring two raw instances per node prevents the split again.
  strategy.minInstancesPerNode = 2
  assert(trainedDepth() === 0)

  // Back to one instance per node: the split is allowed once more.
  strategy.minInstancesPerNode = 1
  assert(trainedDepth() === 1)
}
|
||||
}
|
||||
|
||||
private object RandomForestSuite {
|
||||
|
@ -717,7 +789,7 @@ private object RandomForestSuite {
|
|||
else {
|
||||
nodes.head match {
|
||||
case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc)
|
||||
case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.count)
|
||||
case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.rawCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,7 +28,7 @@ class TreePointSuite extends SparkFunSuite {
|
|||
|
||||
val ser = new KryoSerializer(conf).newInstance()
|
||||
|
||||
val point = new TreePoint(1.0, Array(1, 2, 3))
|
||||
val point = new TreePoint(1.0, Array(1, 2, 3), 1.0)
|
||||
val point2 = ser.deserialize[TreePoint](ser.serialize(point))
|
||||
assert(point.label === point2.label)
|
||||
assert(point.binnedFeatures === point2.binnedFeatures)
|
||||
|
|
|
@ -18,13 +18,15 @@
|
|||
package org.apache.spark.ml.tree.impl
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.util.Random
|
||||
|
||||
import org.apache.spark.{SparkContext, SparkFunSuite}
|
||||
import org.apache.spark.api.java.JavaRDD
|
||||
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
|
||||
import org.apache.spark.ml.feature.LabeledPoint
|
||||
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
|
||||
import org.apache.spark.ml.linalg.Vectors
|
||||
import org.apache.spark.ml.tree._
|
||||
import org.apache.spark.mllib.util.TestingUtils._
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.{DataFrame, SparkSession}
|
||||
|
||||
|
@ -32,6 +34,7 @@ private[ml] object TreeTests extends SparkFunSuite {
|
|||
|
||||
/**
|
||||
* Convert the given data to a DataFrame, and set the features and label metadata.
|
||||
*
|
||||
* @param data Dataset. Categorical features and labels must already have 0-based indices.
|
||||
* This must be non-empty.
|
||||
* @param categoricalFeatures Map: categorical feature index to number of distinct values
|
||||
|
@ -39,16 +42,22 @@ private[ml] object TreeTests extends SparkFunSuite {
|
|||
* @return DataFrame with metadata
|
||||
*/
|
||||
def setMetadata(
|
||||
data: RDD[LabeledPoint],
|
||||
data: RDD[_],
|
||||
categoricalFeatures: Map[Int, Int],
|
||||
numClasses: Int): DataFrame = {
|
||||
val dataOfInstance: RDD[Instance] = data.map {
|
||||
_ match {
|
||||
case instance: Instance => instance
|
||||
case labeledPoint: LabeledPoint => labeledPoint.toInstance
|
||||
}
|
||||
}
|
||||
val spark = SparkSession.builder()
|
||||
.sparkContext(data.sparkContext)
|
||||
.getOrCreate()
|
||||
import spark.implicits._
|
||||
|
||||
val df = data.toDF()
|
||||
val numFeatures = data.first().features.size
|
||||
val df = dataOfInstance.toDF()
|
||||
val numFeatures = dataOfInstance.first().features.size
|
||||
val featuresAttributes = Range(0, numFeatures).map { feature =>
|
||||
if (categoricalFeatures.contains(feature)) {
|
||||
NominalAttribute.defaultAttr.withIndex(feature).withNumValues(categoricalFeatures(feature))
|
||||
|
@ -64,7 +73,7 @@ private[ml] object TreeTests extends SparkFunSuite {
|
|||
}
|
||||
val labelMetadata = labelAttribute.toMetadata()
|
||||
df.select(df("features").as("features", featuresMetadata),
|
||||
df("label").as("label", labelMetadata))
|
||||
df("label").as("label", labelMetadata), df("weight"))
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -80,6 +89,7 @@ private[ml] object TreeTests extends SparkFunSuite {
|
|||
|
||||
/**
|
||||
* Set label metadata (particularly the number of classes) on a DataFrame.
|
||||
*
|
||||
* @param data Dataset. Categorical features and labels must already have 0-based indices.
|
||||
* This must be non-empty.
|
||||
* @param numClasses Number of classes label can take. If 0, mark as continuous.
|
||||
|
@ -124,8 +134,8 @@ private[ml] object TreeTests extends SparkFunSuite {
|
|||
* make mistakes such as creating loops of Nodes.
|
||||
*/
|
||||
private def checkEqual(a: Node, b: Node): Unit = {
|
||||
assert(a.prediction === b.prediction)
|
||||
assert(a.impurity === b.impurity)
|
||||
assert(a.prediction ~== b.prediction absTol 1e-8)
|
||||
assert(a.impurity ~== b.impurity absTol 1e-8)
|
||||
(a, b) match {
|
||||
case (aye: InternalNode, bee: InternalNode) =>
|
||||
assert(aye.split === bee.split)
|
||||
|
@ -156,6 +166,7 @@ private[ml] object TreeTests extends SparkFunSuite {
|
|||
/**
|
||||
* Helper method for constructing a tree for testing.
|
||||
* Given left, right children, construct a parent node.
|
||||
*
|
||||
* @param split Split for parent node
|
||||
* @return Parent node with children attached
|
||||
*/
|
||||
|
@ -163,8 +174,8 @@ private[ml] object TreeTests extends SparkFunSuite {
|
|||
val leftImp = left.impurityStats
|
||||
val rightImp = right.impurityStats
|
||||
val parentImp = leftImp.copy.add(rightImp)
|
||||
val leftWeight = leftImp.count / parentImp.count.toDouble
|
||||
val rightWeight = rightImp.count / parentImp.count.toDouble
|
||||
val leftWeight = leftImp.count / parentImp.count
|
||||
val rightWeight = rightImp.count / parentImp.count
|
||||
val gain = parentImp.calculate() -
|
||||
(leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate())
|
||||
val pred = parentImp.predict
|
||||
|
|
|
@ -23,7 +23,7 @@ import org.apache.spark.ml.evaluation.Evaluator
|
|||
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
|
||||
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
|
||||
import org.apache.spark.ml.param.ParamMap
|
||||
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol}
|
||||
import org.apache.spark.ml.param.shared.HasWeightCol
|
||||
import org.apache.spark.ml.recommendation.{ALS, ALSModel}
|
||||
import org.apache.spark.ml.tree.impl.TreeTests
|
||||
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
|
||||
|
@ -205,8 +205,8 @@ object MLTestingUtils extends SparkFunSuite {
|
|||
seed: Long): Unit = {
|
||||
val (overSampledData, weightedData) = genEquivalentOversampledAndWeightedInstances(
|
||||
data, seed)
|
||||
val weightedModel = estimator.set(estimator.weightCol, "weight").fit(weightedData)
|
||||
val overSampledModel = estimator.set(estimator.weightCol, "").fit(overSampledData)
|
||||
val weightedModel = estimator.set(estimator.weightCol, "weight").fit(weightedData)
|
||||
modelEquals(weightedModel, overSampledModel)
|
||||
}
|
||||
|
||||
|
@ -228,7 +228,8 @@ object MLTestingUtils extends SparkFunSuite {
|
|||
List.fill(outlierRatio)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f))
|
||||
}
|
||||
val trueModel = estimator.set(estimator.weightCol, "").fit(data)
|
||||
val outlierModel = estimator.set(estimator.weightCol, "weight").fit(outlierDS)
|
||||
val outlierModel = estimator.set(estimator.weightCol, "weight")
|
||||
.fit(outlierDS)
|
||||
modelEquals(trueModel, outlierModel)
|
||||
}
|
||||
|
||||
|
@ -241,7 +242,7 @@ object MLTestingUtils extends SparkFunSuite {
|
|||
estimator: E with HasWeightCol,
|
||||
modelEquals: (M, M) => Unit): Unit = {
|
||||
estimator.set(estimator.weightCol, "weight")
|
||||
val models = Seq(0.001, 1.0, 1000.0).map { w =>
|
||||
val models = Seq(0.01, 1.0, 1000.0).map { w =>
|
||||
val df = data.withColumn("weight", lit(w))
|
||||
estimator.fit(df)
|
||||
}
|
||||
|
@ -268,4 +269,20 @@ object MLTestingUtils extends SparkFunSuite {
|
|||
assert(newDatasetF.schema(featuresColName).dataType.equals(new ArrayType(FloatType, false)))
|
||||
(newDataset, newDatasetD, newDatasetF)
|
||||
}
|
||||
|
||||
/**
 * Asserts that two models agree on at least a given fraction of their predictions.
 *
 * Both models transform `data`; the values of their prediction columns are collected,
 * paired positionally, and compared with `compareFunc`.
 *
 * @param data dataset both models are applied to; must produce at least one row
 * @param compareFunc returns true when two predictions should count as equal
 * @param fractionInTol minimum fraction of prediction pairs for which `compareFunc` must hold
 * @param model1 first model to compare
 * @param model2 second model to compare
 */
def modelPredictionEquals[M <: PredictionModel[_, M]](
    data: DataFrame,
    compareFunc: (Double, Double) => Boolean,
    fractionInTol: Double)(
    model1: M,
    model2: M): Unit = {
  val pred1 = model1.transform(data).select(model1.getPredictionCol).collect()
  val pred2 = model2.transform(data).select(model2.getPredictionCol).collect()
  // An empty result would turn the fraction below into 0 / 0.0 = NaN; fail with a clear reason.
  assert(pred1.nonEmpty, "cannot compare model predictions on an empty dataset")
  // `zip` silently truncates to the shorter array, which would skew the fraction.
  assert(pred1.length == pred2.length,
    s"models produced different numbers of predictions: ${pred1.length} vs ${pred2.length}")
  val inTol = pred1.zip(pred2).count { case (p1, p2) =>
    compareFunc(p1.getDouble(0), p2.getDouble(0))
  }
  val fraction = inTol / pred1.length.toDouble
  assert(fraction >= fractionInTol,
    s"only $inTol of ${pred1.length} predictions agree (fraction $fraction), " +
      s"expected at least $fractionInTol")
}
|
||||
}
|
||||
|
|
|
@ -73,7 +73,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
maxBins = 100,
|
||||
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
|
||||
|
||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
|
||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
|
||||
assert(!metadata.isUnordered(featureIndex = 0))
|
||||
assert(!metadata.isUnordered(featureIndex = 1))
|
||||
|
||||
|
@ -100,7 +100,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
maxDepth = 2,
|
||||
maxBins = 100,
|
||||
categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2))
|
||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
|
||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
|
||||
assert(!metadata.isUnordered(featureIndex = 0))
|
||||
assert(!metadata.isUnordered(featureIndex = 1))
|
||||
|
||||
|
@ -116,7 +116,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
val rdd = sc.parallelize(arr)
|
||||
val strategy = new Strategy(Classification, Gini, maxDepth = 3,
|
||||
numClasses = 2, maxBins = 100)
|
||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
|
||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
|
||||
assert(!metadata.isUnordered(featureIndex = 0))
|
||||
assert(!metadata.isUnordered(featureIndex = 1))
|
||||
|
||||
|
@ -133,7 +133,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
val rdd = sc.parallelize(arr)
|
||||
val strategy = new Strategy(Classification, Gini, maxDepth = 3,
|
||||
numClasses = 2, maxBins = 100)
|
||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
|
||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
|
||||
assert(!metadata.isUnordered(featureIndex = 0))
|
||||
assert(!metadata.isUnordered(featureIndex = 1))
|
||||
|
||||
|
@ -150,7 +150,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
val rdd = sc.parallelize(arr)
|
||||
val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
|
||||
numClasses = 2, maxBins = 100)
|
||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
|
||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
|
||||
assert(!metadata.isUnordered(featureIndex = 0))
|
||||
assert(!metadata.isUnordered(featureIndex = 1))
|
||||
|
||||
|
@ -167,7 +167,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
val rdd = sc.parallelize(arr)
|
||||
val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
|
||||
numClasses = 2, maxBins = 100)
|
||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
|
||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
|
||||
assert(!metadata.isUnordered(featureIndex = 0))
|
||||
assert(!metadata.isUnordered(featureIndex = 1))
|
||||
|
||||
|
@ -183,7 +183,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
val rdd = sc.parallelize(arr)
|
||||
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
|
||||
numClasses = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
|
||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
|
||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
|
||||
assert(strategy.isMulticlassClassification)
|
||||
assert(metadata.isUnordered(featureIndex = 0))
|
||||
assert(metadata.isUnordered(featureIndex = 1))
|
||||
|
@ -240,7 +240,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
numClasses = 3, maxBins = maxBins,
|
||||
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
|
||||
assert(strategy.isMulticlassClassification)
|
||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
|
||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
|
||||
assert(metadata.isUnordered(featureIndex = 0))
|
||||
assert(metadata.isUnordered(featureIndex = 1))
|
||||
|
||||
|
@ -288,7 +288,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
|
||||
numClasses = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3))
|
||||
assert(strategy.isMulticlassClassification)
|
||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
|
||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
|
||||
assert(metadata.isUnordered(featureIndex = 0))
|
||||
|
||||
val model = DecisionTree.train(rdd, strategy)
|
||||
|
@ -310,7 +310,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
|||
numClasses = 3, maxBins = 100,
|
||||
categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
|
||||
assert(strategy.isMulticlassClassification)
|
||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
|
||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
|
||||
assert(!metadata.isUnordered(featureIndex = 0))
|
||||
assert(!metadata.isUnordered(featureIndex = 1))
|
||||
|
||||
|
|
|
@ -18,23 +18,63 @@
|
|||
package org.apache.spark.mllib.tree
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator}
|
||||
import org.apache.spark.ml.util.TestingUtils._
|
||||
import org.apache.spark.mllib.tree.impurity._
|
||||
|
||||
/**
 * Test suite for the tree impurity measures: rejection of negative labels by
 * `GiniAggregator` and `EntropyAggregator`, and invariance of `Gini`, `Entropy` and
 * `Variance` when all weights are multiplied by a common factor.
 */
class ImpuritySuite extends SparkFunSuite {

  // Fixed seed so the randomly generated counts, samples and weights are reproducible.
  private val seed = 42

  test("Gini impurity does not support negative labels") {
    val gini = new GiniAggregator(2)
    intercept[IllegalArgumentException] {
      // -1 is the invalid label. Only the five-argument call is kept: a preceding stale
      // four-argument call would throw first and leave this one unexercised.
      // NOTE(review): trailing arguments are presumably (raw count, sample weight) — confirm.
      gini.update(Array(0.0, 1.0, 2.0), 0, -1, 3, 0.0)
    }
  }

  test("Entropy does not support negative labels") {
    val entropy = new EntropyAggregator(2)
    intercept[IllegalArgumentException] {
      // Same shape of call as in the Gini test above: -1 is the invalid label.
      entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 3, 0.0)
    }
  }

  test("Classification impurities are insensitive to scaling") {
    val rng = new scala.util.Random(seed)
    val weightedCounts = Array.fill(5)(rng.nextDouble())
    Seq(Gini, Entropy).foreach { impurity =>
      val reference = impurity.calculate(weightedCounts, weightedCounts.sum)
      // Shrinking or inflating every count by the same factor must leave the impurity
      // (approximately) unchanged.
      Seq(0.0001, 10000.0).foreach { scale =>
        val scaledCounts = weightedCounts.map(_ * scale)
        assert(impurity.calculate(scaledCounts, scaledCounts.sum) ~== reference relTol 0.005)
      }
    }
  }

  test("Regression impurities are insensitive to scaling") {
    // Accumulates (sum of weights, weighted sum of y, weighted sum of y^2).
    def computeStats(samples: Seq[Double], weights: Seq[Double]): (Double, Double, Double) = {
      samples.zip(weights).foldLeft((0.0, 0.0, 0.0)) { case ((wn, wy, wyy), (y, w)) =>
        (wn + w, wy + w * y, wyy + w * y * y)
      }
    }
    val rng = new scala.util.Random(seed)
    val samples = Array.fill(10)(rng.nextDouble())
    val _weights = Array.fill(10)(rng.nextDouble())
    val (count, sum, sumSquared) = computeStats(samples, _weights)
    Seq(Variance).foreach { impurity =>
      val reference = impurity.calculate(count, sum, sumSquared)
      // Rescale every weight by a common factor and recompute the sufficient statistics.
      Seq(0.0001, 10000.0).foreach { scale =>
        val (scaledCount, scaledSum, scaledSumSquared) =
          computeStats(samples, _weights.map(_ * scale))
        assert(
          impurity.calculate(scaledCount, scaledSum, scaledSumSquared) ~== reference relTol 0.005)
      }
    }
  }

}
|
||||
|
|
Loading…
Reference in a new issue