[SPARK-19591][ML][MLLIB] Add sample weights to decision trees

This is PR https://github.com/apache/spark/pull/16722, updated to the latest master.

## What changes were proposed in this pull request?

This patch adds support for sample weights to DecisionTreeRegressor and DecisionTreeClassifier.

Note: This patch does not add support for sample weights to RandomForest. As discussed in the JIRA, we would like to incorporate sample weights into the bagging process itself. This patch is large enough as is, and there are additional considerations to be made for random forests. Since the machinery introduced here is needed regardless, I have opted to leave random forests for a follow-up PR.
## How was this patch tested?

The algorithms are tested to ensure that:
1. Arbitrary scaling of constant weights has no effect.
2. Outliers with small weights do not affect the learned model.
3. Oversampling and weighting are equivalent.

Unit tests are also added to test other smaller components.
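
As a rough illustration of the first check, the sketch below trains a `DecisionTreeClassifier` twice, once with unit weights and once with every weight multiplied by a constant, and asserts that the learned trees match. This is a minimal, hypothetical example rather than the actual test code; the toy dataset, the 42.0 scale factor, and the debug-string comparison are illustrative choices only.

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object WeightScalingInvarianceCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("weight-invariance").getOrCreate()
    import spark.implicits._

    // Tiny toy dataset with a constant weight column.
    val data = Seq(
      (0.0, 1.0, Vectors.dense(0.0, 1.0)),
      (1.0, 1.0, Vectors.dense(1.0, 0.0)),
      (0.0, 1.0, Vectors.dense(0.1, 0.9)),
      (1.0, 1.0, Vectors.dense(0.9, 0.2))
    ).toDF("label", "weight", "features")

    // Same data with every weight multiplied by an arbitrary constant.
    val scaled = data.withColumn("weight", $"weight" * 42.0)

    val dt = new DecisionTreeClassifier()
      .setWeightCol("weight")
      .setMinWeightFractionPerNode(0.1)

    // Scaling all weights by a constant should not change the learned tree.
    val model = dt.fit(data)
    val scaledModel = dt.fit(scaled)
    assert(model.toDebugString == scaledModel.toDebugString)

    spark.stop()
  }
}
```

The real suite applies the same idea to all three properties listed above.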
## Summary of changes

   - Impurity aggregators now store weighted sufficient statistics. They also store the raw (unweighted) count, since it is needed to enforce minInstancesPerNode.

   - This patch preserves the meaning of minInstancesPerNode: the parameter still refers to raw, unweighted counts. It also adds a new parameter, minWeightFractionPerNode, which requires that each node contain at least minWeightFractionPerNode * weightedNumExamples total weight (see the sketch after this list).

   - This patch modifies findSplitsForContinuousFeature to use weighted sums (a simplified sketch appears below, after the note). Unit tests are added.

   - TreePoint is modified to hold a sample weight.

   - BaggedPoint is modified from:
``` Scala
private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double]) extends Serializable
```
to
``` Scala
private[spark] class BaggedPoint[Datum](
    val datum: Datum,
    val subsampleCounts: Array[Int],
    val sampleWeight: Double) extends Serializable
```
We do not simply multiply the counts by the weight and store the product, because the raw counts and the weight are both needed to enforce minInstancesPerNode and minWeightPerNode; the sketch below illustrates this check.
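
To make the interaction of the two thresholds concrete, here is a small self-contained sketch of the validity rule (hypothetical names and numbers, not the patch's code): a candidate split is rejected if either child falls below the raw-count minimum (minInstancesPerNode) or below the weighted minimum derived from minWeightFractionPerNode.

```scala
// Hypothetical, self-contained sketch of the split-validity rule (not the patch's code).
object SplitValiditySketch {
  def isValidSplit(
      leftRawCount: Long, rightRawCount: Long,  // unweighted row counts per child
      leftWeight: Double, rightWeight: Double,  // summed sample weights per child
      minInstancesPerNode: Int,
      minWeightFractionPerNode: Double,
      weightedNumExamples: Double): Boolean = {
    // Derived the same way as DecisionTreeMetadata.minWeightPerNode in the patch.
    val minWeightPerNode = minWeightFractionPerNode * weightedNumExamples
    val enoughInstances =
      leftRawCount >= minInstancesPerNode && rightRawCount >= minInstancesPerNode
    val enoughWeight =
      leftWeight >= minWeightPerNode && rightWeight >= minWeightPerNode
    enoughInstances && enoughWeight
  }

  def main(args: Array[String]): Unit = {
    // Total weight 80.0 and minWeightFractionPerNode = 0.1 => each child needs >= 8.0 weight.
    println(isValidSplit(60, 40, 70.0, 10.0, 1, 0.1, 80.0)) // true
    println(isValidSplit(60, 40, 75.0, 5.0, 1, 0.1, 80.0))  // false: right child too light
  }
}
```

This is also why BaggedPoint keeps subsampleCounts and sampleWeight separately rather than a single pre-multiplied weight.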

**Note**: many of the changed files are touched only because they now use Instance instead of LabeledPoint.
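
For intuition about the weighted split search in findSplitsForContinuousFeature, the following simplified, hypothetical sketch (not the patch's implementation) accumulates weight per distinct value instead of a raw count and places a threshold every `stride` units of total weight:

```scala
import scala.collection.mutable

// Hypothetical sketch of weighted split finding for one continuous feature
// (simplified; not RandomForest.findSplitsForContinuousFeature itself).
object WeightedSplitsSketch {
  /** samples are (weight, featureValue) pairs; returns candidate thresholds. */
  def weightedSplits(samples: Seq[(Double, Double)], numSplits: Int): Array[Double] = {
    // Accumulate total weight per distinct value instead of a raw count.
    val valueWeights = mutable.Map[Double, Double]()
    var totalWeight = 0.0
    samples.foreach { case (w, v) =>
      valueWeights(v) = valueWeights.getOrElse(v, 0.0) + w
      totalWeight += w
    }
    val sorted = valueWeights.toSeq.sortBy(_._1)

    // Walk the cumulative weight and emit a threshold every `stride` units of weight.
    val stride = totalWeight / (numSplits + 1)
    val splits = mutable.ArrayBuffer[Double]()
    var cumWeight = 0.0
    var target = stride
    sorted.init.foreach { case (value, weight) => // the largest value can never be a threshold
      cumWeight += weight
      if (cumWeight >= target) {
        splits += value
        target += stride
      }
    }
    splits.toArray
  }

  def main(args: Array[String]): Unit = {
    val unweighted = Seq((1.0, 1.0), (1.0, 2.0), (1.0, 3.0), (1.0, 4.0))
    val upweighted = Seq((5.0, 1.0), (1.0, 2.0), (1.0, 3.0), (1.0, 4.0))
    println(weightedSplits(unweighted, 1).mkString(", ")) // 2.0 (median by count)
    println(weightedSplits(upweighted, 1).mkString(", ")) // 1.0 (weight shifts the split left)
  }
}
```

With unit weights this reduces to the old count-based behavior; up-weighting a value shifts the candidate thresholds toward it.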

Closes #21632 from imatiach-msft/ilmat/sample-weights.

Authored-by: Ilya Matiach <ilmat@microsoft.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
Ilya Matiach authored on 2019-01-24 18:20:28 -07:00; committed by Sean Owen
commit b2d36f65db (parent 3699763fda)
31 changed files with 743 additions and 280 deletions


```diff
@@ -52,7 +52,7 @@ object TestingUtils {
   /**
    * Private helper function for comparing two values using absolute tolerance.
    */
-  private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
+  private[ml] def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
     // Special case for NaNs
     if (x.isNaN && y.isNaN) {
       return true
```


```diff
@@ -77,17 +77,37 @@ abstract class Classifier[
    * @note Throws `SparkException` if any label is a non-integer or is negative
    */
   protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
-    require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
-      s" $numClasses, but requires numClasses > 0.")
+    validateNumClasses(numClasses)
     dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
       case Row(label: Double, features: Vector) =>
-        require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
-          s" dataset with invalid label $label. Labels must be integers in range" +
-          s" [0, $numClasses).")
+        validateLabel(label, numClasses)
         LabeledPoint(label, features)
     }
   }
 
+  /**
+   * Validates that number of classes is greater than zero.
+   *
+   * @param numClasses Number of classes label can take.
+   */
+  protected def validateNumClasses(numClasses: Int): Unit = {
+    require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
+      s" $numClasses, but requires numClasses > 0.")
+  }
+
+  /**
+   * Validates the label on the classifier is a valid integer in the range [0, numClasses).
+   *
+   * @param label The label to validate.
+   * @param numClasses Number of classes label can take. Labels must be integers in the range
+   *                   [0, numClasses).
+   */
+  protected def validateLabel(label: Double, numClasses: Int): Unit = {
+    require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
+      s" dataset with invalid label $label. Labels must be integers in range" +
+      s" [0, $numClasses).")
+  }
+
   /**
    * Get the number of classes. This looks in column metadata first, and if that is missing,
    * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
```


```diff
@@ -22,10 +22,12 @@ import org.json4s.{DefaultFormats, JObject}
 import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.{DecisionTreeModel, Node, TreeClassifierParams}
 import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
 import org.apache.spark.ml.tree.impl.RandomForest
 import org.apache.spark.ml.util._
@@ -33,8 +35,9 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.DoubleType
 
 /**
  * Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
@@ -66,6 +69,9 @@ class DecisionTreeClassifier @Since("1.4.0") (
   def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
 
   /** @group setParam */
+  @Since("3.0.0")
+  def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
+
   @Since("1.4.0")
   def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
@@ -97,6 +103,16 @@ class DecisionTreeClassifier @Since("1.4.0") (
   @Since("1.6.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
+  /**
+   * Sets the value of param [[weightCol]].
+   * If this is not set or empty, we treat all instance weights as 1.0.
+   * Default is not set, so all instances have weight one.
+   *
+   * @group setParam
+   */
+  @Since("3.0.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   override protected def train(
       dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr =>
     instr.logPipelineStage(this)
@@ -104,22 +120,27 @@ class DecisionTreeClassifier @Since("1.4.0") (
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val numClasses: Int = getNumClasses(dataset)
+    instr.logNumClasses(numClasses)
 
     if (isDefined(thresholds)) {
       require($(thresholds).length == numClasses, this.getClass.getSimpleName +
         ".train() called with non-matching numClasses and thresholds.length." +
         s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
     }
 
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
+    validateNumClasses(numClasses)
+    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+    val instances =
+      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+        case Row(label: Double, weight: Double, features: Vector) =>
+          validateLabel(label, numClasses)
+          Instance(label, weight, features)
+      }
     val strategy = getOldStrategy(categoricalFeatures, numClasses)
 
-    instr.logNumClasses(numClasses)
     instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol,
       probabilityCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
       cacheNodeIds, checkpointInterval, impurity, seed)
 
-    val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
+    val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
       seed = $(seed), instr = Some(instr), parentUID = Some(uid))
 
     trees.head.asInstanceOf[DecisionTreeClassificationModel]
@@ -128,13 +149,13 @@ class DecisionTreeClassifier @Since("1.4.0") (
   /** (private[ml]) Train a decision tree on an RDD */
   private[ml] def train(data: RDD[LabeledPoint],
       oldStrategy: OldStrategy): DecisionTreeClassificationModel = instrumented { instr =>
+    val instances = data.map(_.toInstance)
     instr.logPipelineStage(this)
-    instr.logDataset(data)
+    instr.logDataset(instances)
     instr.logParams(this, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
       cacheNodeIds, checkpointInterval, impurity, seed)
 
-    val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
-      seed = 0L, instr = Some(instr), parentUID = Some(uid))
+    val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
+      featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID = Some(uid))
 
     trees.head.asInstanceOf[DecisionTreeClassificationModel]
   }
@@ -180,6 +201,7 @@ class DecisionTreeClassificationModel private[ml] (
 
   /**
    * Construct a decision tree classification model.
+   *
    * @param rootNode Root node of tree, with other nodes attached.
    */
   private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
```


```diff
@@ -21,20 +21,21 @@ import org.json4s.JsonDSL._
 import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
 import org.apache.spark.ml.tree.impl.RandomForest
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.functions._
+import org.apache.spark.sql.functions.{col, udf}
 
 /**
  * <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a> learning algorithm for
@@ -130,7 +131,7 @@ class RandomForestClassifier @Since("1.4.0") (
         s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
     }
 
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
+    val instances: RDD[Instance] = extractLabeledPoints(dataset, numClasses).map(_.toInstance)
     val strategy =
       super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
@@ -139,7 +140,7 @@ class RandomForestClassifier @Since("1.4.0") (
       minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)
 
     val trees = RandomForest
-      .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
+      .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
       .map(_.asInstanceOf[DecisionTreeClassificationModel])
 
     val numFeatures = trees.head.numFeatures
```


```diff
@@ -37,4 +37,13 @@ case class LabeledPoint(@Since("2.0.0") label: Double, @Since("2.0.0") features:
   override def toString: String = {
     s"($label,$features)"
   }
+
+  private[spark] def toInstance(weight: Double): Instance = {
+    Instance(label, weight, features)
+  }
+
+  private[spark] def toInstance: Instance = {
+    Instance(label, 1.0, features)
+  }
 }
```


```diff
@@ -23,9 +23,10 @@ import org.json4s.JsonDSL._
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
 import org.apache.spark.ml.tree.impl.RandomForest
@@ -34,8 +35,9 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
 
 /**
@@ -65,6 +67,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
   def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
 
   /** @group setParam */
+  @Since("3.0.0")
+  def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
+
   @Since("1.4.0")
   def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
@@ -100,18 +105,33 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
   @Since("2.0.0")
   def setVarianceCol(value: String): this.type = set(varianceCol, value)
 
+  /**
+   * Sets the value of param [[weightCol]].
+   * If this is not set or empty, we treat all instance weights as 1.0.
+   * Default is not set, so all instances have weight one.
+   *
+   * @group setParam
+   */
+  @Since("3.0.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   override protected def train(
       dataset: Dataset[_]): DecisionTreeRegressionModel = instrumented { instr =>
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+    val instances =
+      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+        case Row(label: Double, weight: Double, features: Vector) =>
+          Instance(label, weight, features)
+      }
     val strategy = getOldStrategy(categoricalFeatures)
 
     instr.logPipelineStage(this)
-    instr.logDataset(oldDataset)
+    instr.logDataset(instances)
     instr.logParams(this, params: _*)
 
-    val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
+    val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
       seed = $(seed), instr = Some(instr), parentUID = Some(uid))
 
     trees.head.asInstanceOf[DecisionTreeRegressionModel]
@@ -126,8 +146,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
     instr.logDataset(data)
     instr.logParams(this, params: _*)
 
-    val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy,
-      seed = $(seed), instr = Some(instr), parentUID = Some(uid))
+    val instances = data.map(_.toInstance)
+    val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
+      featureSubsetStrategy, seed = $(seed), instr = Some(instr), parentUID = Some(uid))
 
     trees.head.asInstanceOf[DecisionTreeRegressionModel]
   }
@@ -155,6 +176,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor
  * <a href="http://en.wikipedia.org/wiki/Decision_tree_learning">
  * Decision tree (Wikipedia)</a> model for regression.
  * It supports both continuous and categorical features.
+ *
  * @param rootNode Root of the decision tree
  */
 @Since("1.4.0")
@@ -173,6 +195,7 @@ class DecisionTreeRegressionModel private[ml] (
 
   /**
    * Construct a decision tree regression model.
+   *
   * @param rootNode Root node of tree, with other nodes attached.
   */
  private[ml] def this(rootNode: Node, numFeatures: Int) =
```


```diff
@@ -22,7 +22,6 @@ import org.json4s.JsonDSL._
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree._
@@ -32,10 +31,8 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.functions._
+import org.apache.spark.sql.functions.{col, udf}
 
 /**
  * <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a>
@@ -119,18 +116,19 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
       dataset: Dataset[_]): RandomForestRegressionModel = instrumented { instr =>
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+
+    val instances = extractLabeledPoints(dataset).map(_.toInstance)
     val strategy =
       super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
 
     instr.logPipelineStage(this)
-    instr.logDataset(dataset)
+    instr.logDataset(instances)
     instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, numTrees,
       featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
       minInstancesPerNode, seed, subsamplingRate, cacheNodeIds, checkpointInterval)
 
     val trees = RandomForest
-      .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
+      .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
       .map(_.asInstanceOf[DecisionTreeRegressionModel])
 
     val numFeatures = trees.head.numFeatures
```


```diff
@@ -33,13 +33,13 @@ import org.apache.spark.util.random.XORShiftRandom
  * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively.
  *
  * @param datum Data instance
- * @param subsampleWeights Weight of this instance in each subsampled dataset.
- *
- * TODO: This does not currently support (Double) weighted instances. Once MLlib has weighted
- * dataset support, update. (We store subsampleWeights as Double for this future extension.)
+ * @param subsampleCounts Number of samples of this instance in each subsampled dataset.
+ * @param sampleWeight The weight of this instance.
  */
-private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
-  extends Serializable
+private[spark] class BaggedPoint[Datum](
+    val datum: Datum,
+    val subsampleCounts: Array[Int],
+    val sampleWeight: Double = 1.0) extends Serializable
 
 private[spark] object BaggedPoint {
@@ -52,6 +52,7 @@ private[spark] object BaggedPoint {
    * @param subsamplingRate Fraction of the training data used for learning decision tree.
    * @param numSubsamples Number of subsamples of this RDD to take.
    * @param withReplacement Sampling with/without replacement.
+   * @param extractSampleWeight A function to get the sample weight of each datum.
    * @param seed Random seed.
    * @return BaggedPoint dataset representation.
    */
@@ -60,12 +61,14 @@ private[spark] object BaggedPoint {
       subsamplingRate: Double,
       numSubsamples: Int,
       withReplacement: Boolean,
+      extractSampleWeight: (Datum => Double) = (_: Datum) => 1.0,
       seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
+    // TODO: implement weighted bootstrapping
     if (withReplacement) {
       convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
     } else {
       if (numSubsamples == 1 && subsamplingRate == 1.0) {
-        convertToBaggedRDDWithoutSampling(input)
+        convertToBaggedRDDWithoutSampling(input, extractSampleWeight)
       } else {
         convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
       }
@@ -82,16 +85,15 @@ private[spark] object BaggedPoint {
       val rng = new XORShiftRandom
       rng.setSeed(seed + partitionIndex + 1)
       instances.map { instance =>
-        val subsampleWeights = new Array[Double](numSubsamples)
+        val subsampleCounts = new Array[Int](numSubsamples)
         var subsampleIndex = 0
         while (subsampleIndex < numSubsamples) {
-          val x = rng.nextDouble()
-          subsampleWeights(subsampleIndex) = {
-            if (x < subsamplingRate) 1.0 else 0.0
+          if (rng.nextDouble() < subsamplingRate) {
+            subsampleCounts(subsampleIndex) = 1
           }
           subsampleIndex += 1
         }
-        new BaggedPoint(instance, subsampleWeights)
+        new BaggedPoint(instance, subsampleCounts)
       }
     }
   }
@@ -106,20 +108,20 @@ private[spark] object BaggedPoint {
       val poisson = new PoissonDistribution(subsample)
       poisson.reseedRandomGenerator(seed + partitionIndex + 1)
       instances.map { instance =>
-        val subsampleWeights = new Array[Double](numSubsamples)
+        val subsampleCounts = new Array[Int](numSubsamples)
         var subsampleIndex = 0
         while (subsampleIndex < numSubsamples) {
-          subsampleWeights(subsampleIndex) = poisson.sample()
+          subsampleCounts(subsampleIndex) = poisson.sample()
           subsampleIndex += 1
         }
-        new BaggedPoint(instance, subsampleWeights)
+        new BaggedPoint(instance, subsampleCounts)
       }
     }
   }
 
   private def convertToBaggedRDDWithoutSampling[Datum] (
-      input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
-    input.map(datum => new BaggedPoint(datum, Array(1.0)))
+      input: RDD[Datum],
+      extractSampleWeight: (Datum => Double)): RDD[BaggedPoint[Datum]] = {
+    input.map(datum => new BaggedPoint(datum, Array(1), extractSampleWeight(datum)))
   }
 }
```


```diff
@@ -104,16 +104,21 @@ private[spark] class DTStatsAggregator(
   /**
    * Update the stats for a given (feature, bin) for ordered features, using the given label.
    */
-  def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
+  def update(
+      featureIndex: Int,
+      binIndex: Int,
+      label: Double,
+      numSamples: Int,
+      sampleWeight: Double): Unit = {
     val i = featureOffsets(featureIndex) + binIndex * statsSize
-    impurityAggregator.update(allStats, i, label, instanceWeight)
+    impurityAggregator.update(allStats, i, label, numSamples, sampleWeight)
   }
 
   /**
    * Update the parent node stats using the given label.
    */
-  def updateParent(label: Double, instanceWeight: Double): Unit = {
-    impurityAggregator.update(parentStats, 0, label, instanceWeight)
+  def updateParent(label: Double, numSamples: Int, sampleWeight: Double): Unit = {
+    impurityAggregator.update(parentStats, 0, label, numSamples, sampleWeight)
   }
 
   /**
@@ -127,9 +132,10 @@ private[spark] class DTStatsAggregator(
       featureOffset: Int,
       binIndex: Int,
       label: Double,
-      instanceWeight: Double): Unit = {
+      numSamples: Int,
+      sampleWeight: Double): Unit = {
     impurityAggregator.update(allStats, featureOffset + binIndex * statsSize,
-      label, instanceWeight)
+      label, numSamples, sampleWeight)
   }
 
   /**
```


```diff
@@ -21,7 +21,7 @@ import scala.collection.mutable
 import scala.util.Try
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.tree.TreeEnsembleParams
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
@@ -32,16 +32,20 @@ import org.apache.spark.rdd.RDD
 /**
  * Learning and dataset metadata for DecisionTree.
  *
+ * @param weightedNumExamples Weighted count of samples in the tree.
 * @param numClasses For classification: labels can take values {0, ..., numClasses - 1}.
 *                   For regression: fixed at 0 (no meaning).
 * @param maxBins Maximum number of bins, for all features.
 * @param featureArity Map: categorical feature index to arity.
 *                     I.e., the feature takes values in {0, ..., arity - 1}.
 * @param numBins Number of bins for each feature.
+ * @param minWeightFractionPerNode The minimum fraction of the total sample weight that must be
+ *                                 present in a leaf node in order to be considered a valid split.
 */
 private[spark] class DecisionTreeMetadata(
     val numFeatures: Int,
     val numExamples: Long,
+    val weightedNumExamples: Double,
     val numClasses: Int,
     val maxBins: Int,
     val featureArity: Map[Int, Int],
@@ -51,6 +55,7 @@ private[spark] class DecisionTreeMetadata(
     val quantileStrategy: QuantileStrategy,
     val maxDepth: Int,
     val minInstancesPerNode: Int,
+    val minWeightFractionPerNode: Double,
     val minInfoGain: Double,
     val numTrees: Int,
     val numFeaturesPerNode: Int) extends Serializable {
@@ -67,6 +72,8 @@ private[spark] class DecisionTreeMetadata(
   def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)
 
+  def minWeightPerNode: Double = minWeightFractionPerNode * weightedNumExamples
+
   /**
    * Number of splits for the given feature.
    * For unordered features, there is 1 bin per split.
@@ -104,7 +111,7 @@ private[spark] object DecisionTreeMetadata extends Logging {
   * as well as the number of splits and bins for each feature.
   */
  def buildMetadata(
-      input: RDD[LabeledPoint],
+      input: RDD[Instance],
      strategy: Strategy,
      numTrees: Int,
      featureSubsetStrategy: String): DecisionTreeMetadata = {
@@ -115,7 +122,11 @@ private[spark] object DecisionTreeMetadata extends Logging {
    }
    require(numFeatures > 0, s"DecisionTree requires number of features > 0, " +
      s"but was given an empty features vector")
-    val numExamples = input.count()
+    val (numExamples, weightSum) = input.aggregate((0L, 0.0))(
+      seqOp = (cw, instance) => (cw._1 + 1L, cw._2 + instance.weight),
+      combOp = (cw1, cw2) => (cw1._1 + cw2._1, cw1._2 + cw2._2)
+    )
+
    val numClasses = strategy.algo match {
      case Classification => strategy.numClasses
      case Regression => 0
@@ -206,17 +217,18 @@ private[spark] object DecisionTreeMetadata extends Logging {
      }
    }
 
-    new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
-      strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
+    new DecisionTreeMetadata(numFeatures, numExamples, weightSum, numClasses,
+      numBins.max, strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
      strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
-      strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode)
+      strategy.minInstancesPerNode, strategy.minWeightFractionPerNode, strategy.minInfoGain,
+      numTrees, numFeaturesPerNode)
  }
 
  /**
   * Version of [[DecisionTreeMetadata#buildMetadata]] for DecisionTree.
   */
  def buildMetadata(
-      input: RDD[LabeledPoint],
+      input: RDD[Instance],
      strategy: Strategy): DecisionTreeMetadata = {
    buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all")
  }
```


```diff
@@ -24,10 +24,12 @@ import scala.util.Random
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.classification.DecisionTreeClassificationModel
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.impl.Utils
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.util.Instrumentation
+import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
 import org.apache.spark.mllib.tree.model.ImpurityStats
@@ -90,6 +92,24 @@ private[spark] object RandomForest extends Logging with Serializable {
       strategy: OldStrategy,
       numTrees: Int,
       featureSubsetStrategy: String,
+      seed: Long): Array[DecisionTreeModel] = {
+    val instances = input.map { case LabeledPoint(label, features) =>
+      Instance(label, 1.0, features.asML)
+    }
+    run(instances, strategy, numTrees, featureSubsetStrategy, seed, None)
+  }
+
+  /**
+   * Train a random forest.
+   *
+   * @param input Training data: RDD of `Instance`
+   * @return an unweighted set of trees
+   */
+  def run(
+      input: RDD[Instance],
+      strategy: OldStrategy,
+      numTrees: Int,
+      featureSubsetStrategy: String,
       seed: Long,
       instr: Option[Instrumentation],
       prune: Boolean = true, // exposed for testing only, real trees are always pruned
@@ -101,9 +121,10 @@ private[spark] object RandomForest extends Logging with Serializable {
     timer.start("init")
 
-    val retaggedInput = input.retag(classOf[LabeledPoint])
+    val retaggedInput = input.retag(classOf[Instance])
     val metadata =
       DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
     instr match {
       case Some(instrumentation) =>
         instrumentation.logNumFeatures(metadata.numFeatures)
@@ -132,7 +153,8 @@ private[spark] object RandomForest extends Logging with Serializable {
     val withReplacement = numTrees > 1
 
     val baggedInput = BaggedPoint
-      .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, seed)
+      .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement,
+        (tp: TreePoint) => tp.weight, seed = seed)
       .persist(StorageLevel.MEMORY_AND_DISK)
 
     // depth of the decision tree
@@ -254,19 +276,21 @@ private[spark] object RandomForest extends Logging with Serializable {
   * For unordered features, bins correspond to subsets of categories; either the left or right bin
   * for each subset is updated.
   *
   * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
   *            each (feature, bin).
   * @param treePoint Data point being aggregated.
-   * @param splits possible splits indexed (numFeatures)(numSplits)
+   * @param splits Possible splits indexed (numFeatures)(numSplits)
   * @param unorderedFeatures Set of indices of unordered features.
-   * @param instanceWeight Weight (importance) of instance in dataset.
+   * @param numSamples Number of times this instance occurs in the sample.
+   * @param sampleWeight Weight (importance) of instance in dataset.
   */
  private def mixedBinSeqOp(
      agg: DTStatsAggregator,
      treePoint: TreePoint,
      splits: Array[Array[Split]],
      unorderedFeatures: Set[Int],
-      instanceWeight: Double,
+      numSamples: Int,
+      sampleWeight: Double,
      featuresForNode: Option[Array[Int]]): Unit = {
    val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
      // Use subsampled features
@@ -293,14 +317,15 @@ private[spark] object RandomForest extends Logging with Serializable {
        var splitIndex = 0
        while (splitIndex < numSplits) {
          if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
-            agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
+            agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, numSamples,
+              sampleWeight)
          }
          splitIndex += 1
        }
      } else {
        // Ordered feature
        val binIndex = treePoint.binnedFeatures(featureIndex)
-        agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
+        agg.update(featureIndexIdx, binIndex, treePoint.label, numSamples, sampleWeight)
      }
      featureIndexIdx += 1
    }
@@ -314,12 +339,14 @@ private[spark] object RandomForest extends Logging with Serializable {
   * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
   *            each (feature, bin).
   * @param treePoint Data point being aggregated.
-   * @param instanceWeight Weight (importance) of instance in dataset.
+   * @param numSamples Number of times this instance occurs in the sample.
+   * @param sampleWeight Weight (importance) of instance in dataset.
   */
  private def orderedBinSeqOp(
      agg: DTStatsAggregator,
      treePoint: TreePoint,
-      instanceWeight: Double,
+      numSamples: Int,
+      sampleWeight: Double,
      featuresForNode: Option[Array[Int]]): Unit = {
    val label = treePoint.label
@@ -329,7 +356,7 @@ private[spark] object RandomForest extends Logging with Serializable {
      var featureIndexIdx = 0
      while (featureIndexIdx < featuresForNode.get.length) {
        val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
-        agg.update(featureIndexIdx, binIndex, label, instanceWeight)
+        agg.update(featureIndexIdx, binIndex, label, numSamples, sampleWeight)
        featureIndexIdx += 1
      }
    } else {
@@ -338,7 +365,7 @@ private[spark] object RandomForest extends Logging with Serializable {
      var featureIndex = 0
      while (featureIndex < numFeatures) {
        val binIndex = treePoint.binnedFeatures(featureIndex)
-        agg.update(featureIndex, binIndex, label, instanceWeight)
+        agg.update(featureIndex, binIndex, label, numSamples, sampleWeight)
        featureIndex += 1
      }
    }
@@ -427,14 +454,16 @@ private[spark] object RandomForest extends Logging with Serializable {
        if (nodeInfo != null) {
          val aggNodeIndex = nodeInfo.nodeIndexInGroup
          val featuresForNode = nodeInfo.featureSubset
-          val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
+          val numSamples = baggedPoint.subsampleCounts(treeIndex)
+          val sampleWeight = baggedPoint.sampleWeight
          if (metadata.unorderedFeatures.isEmpty) {
-            orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
+            orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, numSamples, sampleWeight,
+              featuresForNode)
          } else {
            mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
-              metadata.unorderedFeatures, instanceWeight, featuresForNode)
+              metadata.unorderedFeatures, numSamples, sampleWeight, featuresForNode)
          }
-          agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
+          agg(aggNodeIndex).updateParent(baggedPoint.datum.label, numSamples, sampleWeight)
        }
      }
@@ -594,8 +623,8 @@ private[spark] object RandomForest extends Logging with Serializable {
      if (!isLeaf) {
        node.split = Some(split)
        val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
-        val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
-        val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
+        val leftChildIsLeaf = childIsLeaf || (math.abs(stats.leftImpurity) < Utils.EPSILON)
+        val rightChildIsLeaf = childIsLeaf || (math.abs(stats.rightImpurity) < Utils.EPSILON)
        node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
          leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
        node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
@@ -659,15 +688,20 @@ private[spark] object RandomForest extends Logging with Serializable {
        stats.impurity
      }
 
+    val leftRawCount = leftImpurityCalculator.rawCount
+    val rightRawCount = rightImpurityCalculator.rawCount
    val leftCount = leftImpurityCalculator.count
    val rightCount = rightImpurityCalculator.count
 
    val totalCount = leftCount + rightCount
 
-    // If left child or right child doesn't satisfy minimum instances per node,
-    // then this split is invalid, return invalid information gain stats.
-    if ((leftCount < metadata.minInstancesPerNode) ||
-      (rightCount < metadata.minInstancesPerNode)) {
+    val violatesMinInstancesPerNode = (leftRawCount < metadata.minInstancesPerNode) ||
+      (rightRawCount < metadata.minInstancesPerNode)
+    val violatesMinWeightPerNode = (leftCount < metadata.minWeightPerNode) ||
+      (rightCount < metadata.minWeightPerNode)
+    // If left child or right child doesn't satisfy minimum weight per node or minimum
+    // instances per node, then this split is invalid, return invalid information gain stats.
+    if (violatesMinInstancesPerNode || violatesMinWeightPerNode) {
      return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
    }
 
@@ -734,7 +768,8 @@ private[spark] object RandomForest extends Logging with Serializable {
        // Find best split.
        val (bestFeatureSplitIndex, bestFeatureGainStats) =
          Range(0, numSplits).map { case splitIdx =>
-            val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
+            val leftChildStats =
+              binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
            val rightChildStats =
              binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
            rightChildStats.subtract(leftChildStats)
@@ -876,14 +911,14 @@ private[spark] object RandomForest extends Logging with Serializable {
   *       and for multiclass classification with a high-arity feature,
   *       there is one bin per category.
   *
-   * @param input Training data: RDD of [[LabeledPoint]]
+   * @param input Training data: RDD of [[Instance]]
   * @param metadata Learning and dataset metadata
   * @param seed random seed
   * @return Splits, an Array of [[Split]]
   *         of size (numFeatures, numSplits)
   */
  protected[tree] def findSplits(
-      input: RDD[LabeledPoint],
+      input: RDD[Instance],
      metadata: DecisionTreeMetadata,
      seed: Long): Array[Array[Split]] = {
@@ -898,14 +933,14 @@ private[spark] object RandomForest extends Logging with Serializable {
      logDebug("fraction of data used for calculating quantiles = " + fraction)
      input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
    } else {
-      input.sparkContext.emptyRDD[LabeledPoint]
+      input.sparkContext.emptyRDD[Instance]
    }
 
    findSplitsBySorting(sampledInput, metadata, continuousFeatures)
  }
 
  private def findSplitsBySorting(
-      input: RDD[LabeledPoint],
+      input: RDD[Instance],
      metadata: DecisionTreeMetadata,
      continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = {
@@ -917,7 +952,8 @@ private[spark] object RandomForest extends Logging with Serializable {
      input
        .flatMap { point =>
-          continuousFeatures.map(idx => (idx, point.features(idx))).filter(_._2 != 0.0)
+          continuousFeatures.map(idx => (idx, (point.weight, point.features(idx))))
+            .filter(_._2._2 != 0.0)
        }.groupByKey(numPartitions)
        .map { case (idx, samples) =>
          val thresholds = findSplitsForContinuousFeature(samples, metadata, idx)
@@ -982,7 +1018,7 @@ private[spark] object RandomForest extends Logging with Serializable {
   *       could be different from the specified `numSplits`.
   *       The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
   *
-   * @param featureSamples feature values of each sample
+   * @param featureSamples feature values and sample weights of each sample
   * @param metadata decision tree metadata
   *                 NOTE: `metadata.numbins` will be changed accordingly
   *                       if there are not enough splits to be found
@@ -990,7 +1026,7 @@ private[spark] object RandomForest extends Logging with Serializable {
   * @return array of split thresholds
   */
  private[tree] def findSplitsForContinuousFeature(
-      featureSamples: Iterable[Double],
+      featureSamples: Iterable[(Double, Double)],
      metadata: DecisionTreeMetadata,
      featureIndex: Int): Array[Double] = {
    require(metadata.isContinuous(featureIndex),
@@ -1002,19 +1038,24 @@ private[spark] object RandomForest extends Logging with Serializable {
      val numSplits = metadata.numSplits(featureIndex)
 
      // get count for each distinct value except zero value
-      val partNumSamples = featureSamples.size
-      val partValueCountMap = scala.collection.mutable.Map[Double, Int]()
-      featureSamples.foreach { x =>
-        partValueCountMap(x) = partValueCountMap.getOrElse(x, 0) + 1
+      val partValueCountMap = mutable.Map[Double, Double]()
+      var partNumSamples = 0.0
+      var unweightedNumSamples = 0.0
+      featureSamples.foreach { case (sampleWeight, feature) =>
+        partValueCountMap(feature) = partValueCountMap.getOrElse(feature, 0.0) + sampleWeight;
+        partNumSamples += sampleWeight;
+        unweightedNumSamples += 1.0
      }
 
      // Calculate the expected number of samples for finding splits
-      val numSamples = (samplesFractionForFindSplits(metadata) * metadata.numExamples).toInt
+      val weightedNumSamples = samplesFractionForFindSplits(metadata) *
+        metadata.weightedNumExamples
+
      // add expected zero value count and get complete statistics
-      val valueCountMap: Map[Double, Int] = if (numSamples - partNumSamples > 0) {
-        partValueCountMap.toMap + (0.0 -> (numSamples - partNumSamples))
+      val tolerance = Utils.EPSILON * unweightedNumSamples * unweightedNumSamples
+      val valueCountMap = if (weightedNumSamples - partNumSamples > tolerance) {
+        partValueCountMap + (0.0 -> (weightedNumSamples - partNumSamples))
      } else {
-        partValueCountMap.toMap
+        partValueCountMap
      }
 
      // sort distinct values
@@ -1031,7 +1072,7 @@ private[spark] object RandomForest extends Logging with Serializable {
          .toArray
      } else {
        // stride between splits
-        val stride: Double = numSamples.toDouble / (numSplits + 1)
+        val stride: Double = weightedNumSamples / (numSplits + 1)
        logDebug("stride = " + stride)
 
        // iterate `valueCount` to find splits
```


@ -17,7 +17,7 @@
package org.apache.spark.ml.tree.impl package org.apache.spark.ml.tree.impl
import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.tree.{ContinuousSplit, Split} import org.apache.spark.ml.tree.{ContinuousSplit, Split}
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
@ -36,10 +36,12 @@ import org.apache.spark.rdd.RDD
* @param label Label from LabeledPoint * @param label Label from LabeledPoint
* @param binnedFeatures Binned feature values. * @param binnedFeatures Binned feature values.
* Same length as LabeledPoint.features, but values are bin indices. * Same length as LabeledPoint.features, but values are bin indices.
* @param weight Sample weight for this TreePoint.
*/ */
private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int]) private[spark] class TreePoint(
extends Serializable { val label: Double,
} val binnedFeatures: Array[Int],
val weight: Double) extends Serializable
private[spark] object TreePoint { private[spark] object TreePoint {
@ -52,7 +54,7 @@ private[spark] object TreePoint {
* @return TreePoint dataset representation * @return TreePoint dataset representation
*/ */
def convertToTreeRDD( def convertToTreeRDD(
input: RDD[LabeledPoint], input: RDD[Instance],
splits: Array[Array[Split]], splits: Array[Array[Split]],
metadata: DecisionTreeMetadata): RDD[TreePoint] = { metadata: DecisionTreeMetadata): RDD[TreePoint] = {
// Construct arrays for featureArity for efficiency in the inner loop. // Construct arrays for featureArity for efficiency in the inner loop.
@ -82,18 +84,18 @@ private[spark] object TreePoint {
* for categorical features. * for categorical features.
*/ */
private def labeledPointToTreePoint( private def labeledPointToTreePoint(
labeledPoint: LabeledPoint, instance: Instance,
thresholds: Array[Array[Double]], thresholds: Array[Array[Double]],
featureArity: Array[Int]): TreePoint = { featureArity: Array[Int]): TreePoint = {
val numFeatures = labeledPoint.features.size val numFeatures = instance.features.size
val arr = new Array[Int](numFeatures) val arr = new Array[Int](numFeatures)
var featureIndex = 0 var featureIndex = 0
while (featureIndex < numFeatures) { while (featureIndex < numFeatures) {
arr(featureIndex) = arr(featureIndex) =
findBin(featureIndex, labeledPoint, featureArity(featureIndex), thresholds(featureIndex)) findBin(featureIndex, instance, featureArity(featureIndex), thresholds(featureIndex))
featureIndex += 1 featureIndex += 1
} }
new TreePoint(labeledPoint.label, arr) new TreePoint(instance.label, arr, instance.weight)
} }
/** /**
@ -106,10 +108,10 @@ private[spark] object TreePoint {
*/ */
private def findBin( private def findBin(
featureIndex: Int, featureIndex: Int,
labeledPoint: LabeledPoint, instance: Instance,
featureArity: Int, featureArity: Int,
thresholds: Array[Double]): Int = { thresholds: Array[Double]): Int = {
val featureValue = labeledPoint.features(featureIndex) val featureValue = instance.features(featureIndex)
if (featureArity == 0) { if (featureArity == 0) {
val idx = java.util.Arrays.binarySearch(thresholds, featureValue) val idx = java.util.Arrays.binarySearch(thresholds, featureValue)
@ -125,7 +127,7 @@ private[spark] object TreePoint {
s"DecisionTree given invalid data:" + s"DecisionTree given invalid data:" +
s" Feature $featureIndex is categorical with values in {0,...,${featureArity - 1}," + s" Feature $featureIndex is categorical with values in {0,...,${featureArity - 1}," +
s" but a data point gives it value $featureValue.\n" + s" but a data point gives it value $featureValue.\n" +
" Bad data point: " + labeledPoint.toString) s" Bad data point: $instance")
} }
featureValue.toInt featureValue.toInt
} }
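For continuous features, the bin lookup in findBin is a binary search over the learned thresholds, and the sample weight is carried through untouched. A standalone sketch of that continuous-only path, using simplified stand-in types rather than the private Spark classes:

``` Scala
import java.util.Arrays

final case class SimpleTreePoint(label: Double, binnedFeatures: Array[Int], weight: Double)

object SimpleBinning {
  /** Bin index for a continuous feature: the insertion point among the sorted thresholds. */
  def binContinuous(featureValue: Double, thresholds: Array[Double]): Int = {
    val idx = Arrays.binarySearch(thresholds, featureValue)
    if (idx >= 0) idx else -idx - 1
  }

  /** Convert one (label, features, weight) triple into its binned representation. */
  def toTreePoint(
      label: Double,
      features: Array[Double],
      weight: Double,
      thresholds: Array[Array[Double]]): SimpleTreePoint = {
    val bins = Array.tabulate(features.length)(i => binContinuous(features(i), thresholds(i)))
    SimpleTreePoint(label, bins, weight)
  }
}
```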
View file
@ -282,6 +282,7 @@ private[ml] object DecisionTreeModelReadWrite {
* *
* @param id Index used for tree reconstruction. Indices follow a pre-order traversal. * @param id Index used for tree reconstruction. Indices follow a pre-order traversal.
* @param impurityStats Stats array. Impurity type is stored in metadata. * @param impurityStats Stats array. Impurity type is stored in metadata.
* @param rawCount The unweighted number of samples falling in this node.
* @param gain Gain, or arbitrary value if leaf node. * @param gain Gain, or arbitrary value if leaf node.
* @param leftChild Left child index, or arbitrary value if leaf node. * @param leftChild Left child index, or arbitrary value if leaf node.
* @param rightChild Right child index, or arbitrary value if leaf node. * @param rightChild Right child index, or arbitrary value if leaf node.
@ -292,6 +293,7 @@ private[ml] object DecisionTreeModelReadWrite {
prediction: Double, prediction: Double,
impurity: Double, impurity: Double,
impurityStats: Array[Double], impurityStats: Array[Double],
rawCount: Long,
gain: Double, gain: Double,
leftChild: Int, leftChild: Int,
rightChild: Int, rightChild: Int,
@ -311,11 +313,12 @@ private[ml] object DecisionTreeModelReadWrite {
val (leftNodeData, leftIdx) = build(n.leftChild, id + 1) val (leftNodeData, leftIdx) = build(n.leftChild, id + 1)
val (rightNodeData, rightIdx) = build(n.rightChild, leftIdx + 1) val (rightNodeData, rightIdx) = build(n.rightChild, leftIdx + 1)
val thisNodeData = NodeData(id, n.prediction, n.impurity, n.impurityStats.stats, val thisNodeData = NodeData(id, n.prediction, n.impurity, n.impurityStats.stats,
n.gain, leftNodeData.head.id, rightNodeData.head.id, SplitData(n.split)) n.impurityStats.rawCount, n.gain, leftNodeData.head.id, rightNodeData.head.id,
SplitData(n.split))
(thisNodeData +: (leftNodeData ++ rightNodeData), rightIdx) (thisNodeData +: (leftNodeData ++ rightNodeData), rightIdx)
case _: LeafNode => case _: LeafNode =>
(Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats, (Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats,
-1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))), node.impurityStats.rawCount, -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))),
id) id)
} }
} }
@ -360,7 +363,8 @@ private[ml] object DecisionTreeModelReadWrite {
// traversal, this guarantees that child nodes will be built before parent nodes. // traversal, this guarantees that child nodes will be built before parent nodes.
val finalNodes = new Array[Node](nodes.length) val finalNodes = new Array[Node](nodes.length)
nodes.reverseIterator.foreach { case n: NodeData => nodes.reverseIterator.foreach { case n: NodeData =>
val impurityStats = ImpurityCalculator.getCalculator(impurityType, n.impurityStats) val impurityStats =
ImpurityCalculator.getCalculator(impurityType, n.impurityStats, n.rawCount)
val node = if (n.leftChild != -1) { val node = if (n.leftChild != -1) {
val leftChild = finalNodes(n.leftChild) val leftChild = finalNodes(n.leftChild)
val rightChild = finalNodes(n.rightChild) val rightChild = finalNodes(n.rightChild)
View file
@ -37,7 +37,7 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
* Note: Marked as private and DeveloperApi since this may be made public in the future. * Note: Marked as private and DeveloperApi since this may be made public in the future.
*/ */
private[ml] trait DecisionTreeParams extends PredictorParams private[ml] trait DecisionTreeParams extends PredictorParams
with HasCheckpointInterval with HasSeed { with HasCheckpointInterval with HasSeed with HasWeightCol {
/** /**
* Maximum depth of the tree (>= 0). * Maximum depth of the tree (>= 0).
@ -74,6 +74,21 @@ private[ml] trait DecisionTreeParams extends PredictorParams
" child to have fewer than minInstancesPerNode, the split will be discarded as invalid." + " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." +
" Should be >= 1.", ParamValidators.gtEq(1)) " Should be >= 1.", ParamValidators.gtEq(1))
/**
* Minimum fraction of the weighted sample count that each child must have after split.
* If a split causes the fraction of the total weight in the left or right child to be less than
* minWeightFractionPerNode, the split will be discarded as invalid.
* Should be in the interval [0.0, 0.5).
* (default = 0.0)
* @group param
*/
final val minWeightFractionPerNode: DoubleParam = new DoubleParam(this,
"minWeightFractionPerNode", "Minimum fraction of the weighted sample count that each child " +
"must have after split. If a split causes the fraction of the total weight in the left or " +
"right child to be less than minWeightFractionPerNode, the split will be discarded as " +
"invalid. Should be in interval [0.0, 0.5)",
ParamValidators.inRange(0.0, 0.5, lowerInclusive = true, upperInclusive = false))
/** /**
* Minimum information gain for a split to be considered at a tree node. * Minimum information gain for a split to be considered at a tree node.
* Should be >= 0.0. * Should be >= 0.0.
@ -107,8 +122,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams
" algorithm will cache node IDs for each instance. Caching can speed up training of deeper" + " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" +
" trees.") " trees.")
setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1,
maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) minWeightFractionPerNode -> 0.0, minInfoGain -> 0.0, maxMemoryInMB -> 256,
cacheNodeIds -> false, checkpointInterval -> 10)
/** @group getParam */ /** @group getParam */
final def getMaxDepth: Int = $(maxDepth) final def getMaxDepth: Int = $(maxDepth)
@ -119,6 +135,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams
/** @group getParam */ /** @group getParam */
final def getMinInstancesPerNode: Int = $(minInstancesPerNode) final def getMinInstancesPerNode: Int = $(minInstancesPerNode)
/** @group getParam */
final def getMinWeightFractionPerNode: Double = $(minWeightFractionPerNode)
/** @group getParam */ /** @group getParam */
final def getMinInfoGain: Double = $(minInfoGain) final def getMinInfoGain: Double = $(minInfoGain)
@ -143,6 +162,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams
strategy.maxMemoryInMB = getMaxMemoryInMB strategy.maxMemoryInMB = getMaxMemoryInMB
strategy.minInfoGain = getMinInfoGain strategy.minInfoGain = getMinInfoGain
strategy.minInstancesPerNode = getMinInstancesPerNode strategy.minInstancesPerNode = getMinInstancesPerNode
strategy.minWeightFractionPerNode = getMinWeightFractionPerNode
strategy.useNodeIdCache = getCacheNodeIds strategy.useNodeIdCache = getCacheNodeIds
strategy.numClasses = numClasses strategy.numClasses = numClasses
strategy.categoricalFeaturesInfo = categoricalFeatures strategy.categoricalFeaturesInfo = categoricalFeatures
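With HasWeightCol mixed into DecisionTreeParams, the weight column sits next to the existing stopping criteria. A hedged usage sketch, assuming the estimators expose matching setters (as this patch does for the decision tree estimators); the data path and column names are illustrative only.

``` Scala
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.lit

object WeightedDecisionTreeExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("weighted-dt").master("local[*]").getOrCreate()
    // Placeholder data with constant weights, purely for illustration.
    val training = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
      .withColumn("weight", lit(1.0))

    val dt = new DecisionTreeClassifier()
      .setWeightCol("weight")                // per-row sample weights
      .setMinWeightFractionPerNode(0.05)     // weighted analogue of minInstancesPerNode
      .setMinInstancesPerNode(5)             // still enforced on raw, unweighted counts

    val model = dt.fit(training)
    println(model.toDebugString)
    spark.stop()
  }
}
```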
View file
@ -23,6 +23,7 @@ import scala.util.Try
import org.apache.spark.annotation.Since import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, TreeEnsembleParams => NewRFParams} import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, TreeEnsembleParams => NewRFParams}
import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest} import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest}
import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.regression.LabeledPoint
@ -91,8 +92,8 @@ private class RandomForest (
* @return RandomForestModel that can be used for prediction. * @return RandomForestModel that can be used for prediction.
*/ */
def run(input: RDD[LabeledPoint]): RandomForestModel = { def run(input: RDD[LabeledPoint]): RandomForestModel = {
val trees: Array[NewDTModel] = NewRandomForest.run(input.map(_.asML), strategy, numTrees, val trees: Array[NewDTModel] =
featureSubsetStrategy, seed.toLong, None) NewRandomForest.run(input, strategy, numTrees, featureSubsetStrategy, seed.toLong)
new RandomForestModel(strategy.algo, trees.map(_.toOld)) new RandomForestModel(strategy.algo, trees.map(_.toOld))
} }
View file
@ -80,7 +80,8 @@ class Strategy @Since("1.3.0") (
@Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256, @Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256,
@Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1, @Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1,
@Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false, @Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,
@Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends Serializable { @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10,
@Since("3.0.0") @BeanProperty var minWeightFractionPerNode: Double = 0.0) extends Serializable {
/** /**
*/ */
@ -96,6 +97,31 @@ class Strategy @Since("1.3.0") (
isMulticlassClassification && (categoricalFeaturesInfo.size > 0) isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
} }
// scalastyle:off argcount
/**
* Backwards compatible constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]]
*/
@Since("1.0.0")
def this(
algo: Algo,
impurity: Impurity,
maxDepth: Int,
numClasses: Int,
maxBins: Int,
quantileCalculationStrategy: QuantileStrategy,
categoricalFeaturesInfo: Map[Int, Int],
minInstancesPerNode: Int,
minInfoGain: Double,
maxMemoryInMB: Int,
subsamplingRate: Double,
useNodeIdCache: Boolean,
checkpointInterval: Int) {
this(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy,
categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, maxMemoryInMB,
subsamplingRate, useNodeIdCache, checkpointInterval, 0.0)
}
// scalastyle:on argcount
/** /**
* Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]] * Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]]
*/ */
@ -108,7 +134,8 @@ class Strategy @Since("1.3.0") (
maxBins: Int, maxBins: Int,
categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]) { categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]) {
this(algo, impurity, maxDepth, numClasses, maxBins, Sort, this(algo, impurity, maxDepth, numClasses, maxBins, Sort,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
minWeightFractionPerNode = 0.0)
} }
/** /**
@ -171,8 +198,9 @@ class Strategy @Since("1.3.0") (
@Since("1.2.0") @Since("1.2.0")
def copy: Strategy = { def copy: Strategy = {
new Strategy(algo, impurity, maxDepth, numClasses, maxBins, new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode,
maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval) minInfoGain, maxMemoryInMB, subsamplingRate, useNodeIdCache,
checkpointInterval, minWeightFractionPerNode)
} }
} }
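A short compatibility sketch: old 13-argument call sites resolve against the secondary constructor, the new field defaults to 0.0, and @BeanProperty generates the usual accessors. Values below are arbitrary.

``` Scala
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.impurity.Gini

object StrategyCompatSketch {
  def main(args: Array[String]): Unit = {
    // Existing call sites keep compiling; minWeightFractionPerNode defaults to 0.0.
    val strategy =
      new Strategy(Algo.Classification, Gini, maxDepth = 5, numClasses = 2, maxBins = 32)
    strategy.setMinWeightFractionPerNode(0.05) // setter generated by @BeanProperty
    // copy threads the new field through as well.
    assert(strategy.copy.getMinWeightFractionPerNode == 0.05)
  }
}
```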
View file
@ -83,23 +83,29 @@ object Entropy extends Impurity {
* @param numClasses Number of classes for label. * @param numClasses Number of classes for label.
*/ */
private[spark] class EntropyAggregator(numClasses: Int) private[spark] class EntropyAggregator(numClasses: Int)
extends ImpurityAggregator(numClasses) with Serializable { extends ImpurityAggregator(numClasses + 1) with Serializable {
/** /**
* Update stats for one (node, feature, bin) with the given label. * Update stats for one (node, feature, bin) with the given label.
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin). * @param offset Start index of stats for this (node, feature, bin).
*/ */
def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = { def update(
if (label >= statsSize) { allStats: Array[Double],
offset: Int,
label: Double,
numSamples: Int,
sampleWeight: Double): Unit = {
if (label >= numClasses) {
throw new IllegalArgumentException(s"EntropyAggregator given label $label" + throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
s" but requires label < numClasses (= $statsSize).") s" but requires label < numClasses (= ${numClasses}).")
} }
if (label < 0) { if (label < 0) {
throw new IllegalArgumentException(s"EntropyAggregator given label $label" + throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
s"but requires label is non-negative.") s"but requires label is non-negative.")
} }
allStats(offset + label.toInt) += instanceWeight allStats(offset + label.toInt) += numSamples * sampleWeight
allStats(offset + statsSize - 1) += numSamples
} }
/** /**
@ -108,7 +114,8 @@ private[spark] class EntropyAggregator(numClasses: Int)
* @param offset Start index of stats for this (node, feature, bin). * @param offset Start index of stats for this (node, feature, bin).
*/ */
def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = { def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = {
new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray) new EntropyCalculator(allStats.view(offset, offset + statsSize - 1).toArray,
allStats(offset + statsSize - 1).toLong)
} }
} }
@ -118,12 +125,13 @@ private[spark] class EntropyAggregator(numClasses: Int)
* (node, feature, bin). * (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin).
*/ */
private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { private[spark] class EntropyCalculator(stats: Array[Double], var rawCount: Long)
extends ImpurityCalculator(stats) {
/** /**
* Make a deep copy of this [[ImpurityCalculator]]. * Make a deep copy of this [[ImpurityCalculator]].
*/ */
def copy: EntropyCalculator = new EntropyCalculator(stats.clone()) def copy: EntropyCalculator = new EntropyCalculator(stats.clone(), rawCount)
/** /**
* Calculate the impurity from the stored sufficient statistics. * Calculate the impurity from the stored sufficient statistics.
@ -131,9 +139,9 @@ private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCal
def calculate(): Double = Entropy.calculate(stats, stats.sum) def calculate(): Double = Entropy.calculate(stats, stats.sum)
/** /**
* Number of data points accounted for in the sufficient statistics. * Weighted number of data points accounted for in the sufficient statistics.
*/ */
def count: Long = stats.sum.toLong def count: Double = stats.sum
/** /**
* Prediction which should be made based on the sufficient statistics. * Prediction which should be made based on the sufficient statistics.
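The layout change is easier to see in isolation: the aggregator reserves numClasses weighted slots plus one trailing raw-count slot, so statsSize becomes numClasses + 1. A standalone sketch (not the private Spark classes; entropy is shown in log base 2, matching MLlib's convention):

``` Scala
final class WeightedClassStatsSketch(numClasses: Int) {
  // numClasses weighted class counts + 1 raw-count slot.
  private val stats = new Array[Double](numClasses + 1)

  def update(label: Int, numSamples: Int, sampleWeight: Double): Unit = {
    require(label >= 0 && label < numClasses, s"label $label out of range [0, $numClasses)")
    stats(label) += numSamples * sampleWeight // weighted sufficient statistic
    stats(numClasses) += numSamples           // raw count, kept for minInstancesPerNode
  }

  def weightedCount: Double = stats.take(numClasses).sum
  def rawCount: Long = stats(numClasses).toLong

  /** Entropy of the weighted class proportions, in bits. */
  def entropy: Double = {
    val total = weightedCount
    if (total <= 0.0) 0.0
    else stats.take(numClasses).map { c =>
      if (c <= 0.0) 0.0 else { val p = c / total; -p * (math.log(p) / math.log(2)) }
    }.sum
  }
}
```

The GiniAggregator below uses the same stats layout; only the impurity formula differs.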
View file
@ -80,23 +80,29 @@ object Gini extends Impurity {
* @param numClasses Number of classes for label. * @param numClasses Number of classes for label.
*/ */
private[spark] class GiniAggregator(numClasses: Int) private[spark] class GiniAggregator(numClasses: Int)
extends ImpurityAggregator(numClasses) with Serializable { extends ImpurityAggregator(numClasses + 1) with Serializable {
/** /**
* Update stats for one (node, feature, bin) with the given label. * Update stats for one (node, feature, bin) with the given label.
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin). * @param offset Start index of stats for this (node, feature, bin).
*/ */
def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = { def update(
if (label >= statsSize) { allStats: Array[Double],
offset: Int,
label: Double,
numSamples: Int,
sampleWeight: Double): Unit = {
if (label >= numClasses) {
throw new IllegalArgumentException(s"GiniAggregator given label $label" + throw new IllegalArgumentException(s"GiniAggregator given label $label" +
s" but requires label < numClasses (= $statsSize).") s" but requires label < numClasses (= ${numClasses}).")
} }
if (label < 0) { if (label < 0) {
throw new IllegalArgumentException(s"GiniAggregator given label $label" + throw new IllegalArgumentException(s"GiniAggregator given label $label" +
s"but requires label is non-negative.") s"but requires label to be non-negative.")
} }
allStats(offset + label.toInt) += instanceWeight allStats(offset + label.toInt) += numSamples * sampleWeight
allStats(offset + statsSize - 1) += numSamples
} }
/** /**
@ -105,7 +111,8 @@ private[spark] class GiniAggregator(numClasses: Int)
* @param offset Start index of stats for this (node, feature, bin). * @param offset Start index of stats for this (node, feature, bin).
*/ */
def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = { def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = {
new GiniCalculator(allStats.view(offset, offset + statsSize).toArray) new GiniCalculator(allStats.view(offset, offset + statsSize - 1).toArray,
allStats(offset + statsSize - 1).toLong)
} }
} }
@ -115,12 +122,13 @@ private[spark] class GiniAggregator(numClasses: Int)
* (node, feature, bin). * (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin).
*/ */
private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { private[spark] class GiniCalculator(stats: Array[Double], var rawCount: Long)
extends ImpurityCalculator(stats) {
/** /**
* Make a deep copy of this [[ImpurityCalculator]]. * Make a deep copy of this [[ImpurityCalculator]].
*/ */
def copy: GiniCalculator = new GiniCalculator(stats.clone()) def copy: GiniCalculator = new GiniCalculator(stats.clone(), rawCount)
/** /**
* Calculate the impurity from the stored sufficient statistics. * Calculate the impurity from the stored sufficient statistics.
@ -128,9 +136,9 @@ private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcul
def calculate(): Double = Gini.calculate(stats, stats.sum) def calculate(): Double = Gini.calculate(stats, stats.sum)
/** /**
* Number of data points accounted for in the sufficient statistics. * Weighted number of data points accounted for in the sufficient statistics.
*/ */
def count: Long = stats.sum.toLong def count: Double = stats.sum
/** /**
* Prediction which should be made based on the sufficient statistics. * Prediction which should be made based on the sufficient statistics.
View file
@ -81,7 +81,12 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin). * @param offset Start index of stats for this (node, feature, bin).
*/ */
def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit def update(
allStats: Array[Double],
offset: Int,
label: Double,
numSamples: Int,
sampleWeight: Double): Unit
/** /**
* Get an [[ImpurityCalculator]] for a (node, feature, bin). * Get an [[ImpurityCalculator]] for a (node, feature, bin).
@ -122,6 +127,7 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten
stats(i) += other.stats(i) stats(i) += other.stats(i)
i += 1 i += 1
} }
rawCount += other.rawCount
this this
} }
@ -139,13 +145,19 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten
stats(i) -= other.stats(i) stats(i) -= other.stats(i)
i += 1 i += 1
} }
rawCount -= other.rawCount
this this
} }
/** /**
* Number of data points accounted for in the sufficient statistics. * Weighted number of data points accounted for in the sufficient statistics.
*/ */
def count: Long def count: Double
/**
* Raw number of data points accounted for in the sufficient statistics.
*/
var rawCount: Long
/** /**
* Prediction which should be made based on the sufficient statistics. * Prediction which should be made based on the sufficient statistics.
@ -185,11 +197,14 @@ private[spark] object ImpurityCalculator {
* Create an [[ImpurityCalculator]] instance of the given impurity type and with * Create an [[ImpurityCalculator]] instance of the given impurity type and with
* the given stats. * the given stats.
*/ */
def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = { def getCalculator(
impurity: String,
stats: Array[Double],
rawCount: Long): ImpurityCalculator = {
impurity.toLowerCase(Locale.ROOT) match { impurity.toLowerCase(Locale.ROOT) match {
case "gini" => new GiniCalculator(stats) case "gini" => new GiniCalculator(stats, rawCount)
case "entropy" => new EntropyCalculator(stats) case "entropy" => new EntropyCalculator(stats, rawCount)
case "variance" => new VarianceCalculator(stats) case "variance" => new VarianceCalculator(stats, rawCount)
case _ => case _ =>
throw new IllegalArgumentException( throw new IllegalArgumentException(
s"ImpurityCalculator builder did not recognize impurity type: $impurity") s"ImpurityCalculator builder did not recognize impurity type: $impurity")
View file
@ -66,21 +66,32 @@ object Variance extends Impurity {
/** /**
* Class for updating views of a vector of sufficient statistics, * Class for updating views of a vector of sufficient statistics,
* in order to compute impurity from a sample. * in order to compute impurity from a sample. For variance, we track:
* - sum(w_i)
* - sum(w_i * y_i)
* - sum(w_i * y_i * y_i)
* - count(y_i)
* Note: Instances of this class do not hold the data; they operate on views of the data. * Note: Instances of this class do not hold the data; they operate on views of the data.
*/ */
private[spark] class VarianceAggregator() private[spark] class VarianceAggregator()
extends ImpurityAggregator(statsSize = 3) with Serializable { extends ImpurityAggregator(statsSize = 4) with Serializable {
/** /**
* Update stats for one (node, feature, bin) with the given label. * Update stats for one (node, feature, bin) with the given label.
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin). * @param offset Start index of stats for this (node, feature, bin).
*/ */
def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = { def update(
allStats: Array[Double],
offset: Int,
label: Double,
numSamples: Int,
sampleWeight: Double): Unit = {
val instanceWeight = numSamples * sampleWeight
allStats(offset) += instanceWeight allStats(offset) += instanceWeight
allStats(offset + 1) += instanceWeight * label allStats(offset + 1) += instanceWeight * label
allStats(offset + 2) += instanceWeight * label * label allStats(offset + 2) += instanceWeight * label * label
allStats(offset + 3) += numSamples
} }
/** /**
@ -89,7 +100,8 @@ private[spark] class VarianceAggregator()
* @param offset Start index of stats for this (node, feature, bin). * @param offset Start index of stats for this (node, feature, bin).
*/ */
def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = { def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = {
new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray) new VarianceCalculator(allStats.view(offset, offset + statsSize - 1).toArray,
allStats(offset + statsSize - 1).toLong)
} }
} }
@ -99,7 +111,8 @@ private[spark] class VarianceAggregator()
* (node, feature, bin). * (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin).
*/ */
private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { private[spark] class VarianceCalculator(stats: Array[Double], var rawCount: Long)
extends ImpurityCalculator(stats) {
require(stats.length == 3, require(stats.length == 3,
s"VarianceCalculator requires sufficient statistics array stats to be of length 3," + s"VarianceCalculator requires sufficient statistics array stats to be of length 3," +
@ -108,7 +121,7 @@ private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCa
/** /**
* Make a deep copy of this [[ImpurityCalculator]]. * Make a deep copy of this [[ImpurityCalculator]].
*/ */
def copy: VarianceCalculator = new VarianceCalculator(stats.clone()) def copy: VarianceCalculator = new VarianceCalculator(stats.clone(), rawCount)
/** /**
* Calculate the impurity from the stored sufficient statistics. * Calculate the impurity from the stored sufficient statistics.
@ -116,9 +129,9 @@ private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCa
def calculate(): Double = Variance.calculate(stats(0), stats(1), stats(2)) def calculate(): Double = Variance.calculate(stats(0), stats(1), stats(2))
/** /**
* Number of data points accounted for in the sufficient statistics. * Weighted number of data points accounted for in the sufficient statistics.
*/ */
def count: Long = stats(0).toLong def count: Double = stats(0)
/** /**
* Prediction which should be made based on the sufficient statistics. * Prediction which should be made based on the sufficient statistics.
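A worked sketch of the regression statistics: sum(w), sum(w * y) and sum(w * y^2) are enough to recover the weighted (biased) variance, and an integer weight of 2 behaves exactly like duplicating the point, which is the oversampling-vs-weighting equivalence the new tests rely on. Illustrative code only.

``` Scala
object WeightedVarianceSketch {
  /** stats: (sum of weights, weighted sum of labels, weighted sum of squared labels). */
  def variance(sumW: Double, sumWY: Double, sumWY2: Double): Double = {
    if (sumW == 0.0) 0.0
    else {
      val mean = sumWY / sumW
      sumWY2 / sumW - mean * mean // E_w[y^2] - (E_w[y])^2
    }
  }

  def main(args: Array[String]): Unit = {
    val data = Seq((2.0, 1.0), (1.0, 4.0)) // (weight, label) pairs
    val sumW   = data.map(_._1).sum
    val sumWY  = data.map { case (w, y) => w * y }.sum
    val sumWY2 = data.map { case (w, y) => w * y * y }.sum
    // Same as the unweighted values 1.0, 1.0, 4.0, whose biased variance is 2.0.
    println(variance(sumW, sumWY, sumWY2)) // prints 2.0
  }
}
```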
View file
@ -42,6 +42,8 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
private var continuousDataPointsForMulticlassRDD: RDD[LabeledPoint] = _ private var continuousDataPointsForMulticlassRDD: RDD[LabeledPoint] = _
private var categoricalDataPointsForMulticlassForOrderedFeaturesRDD: RDD[LabeledPoint] = _ private var categoricalDataPointsForMulticlassForOrderedFeaturesRDD: RDD[LabeledPoint] = _
private val seed = 42
override def beforeAll() { override def beforeAll() {
super.beforeAll() super.beforeAll()
categoricalDataPointsRDD = categoricalDataPointsRDD =
@ -250,7 +252,7 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
MLTestingUtils.checkCopyAndUids(dt, newTree) MLTestingUtils.checkCopyAndUids(dt, newTree)
testTransformer[(Vector, Double)](newData, newTree, testTransformer[(Vector, Double, Double)](newData, newTree,
"prediction", "rawPrediction", "probability") { "prediction", "rawPrediction", "probability") {
case Row(pred: Double, rawPred: Vector, probPred: Vector) => case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
assert(pred === rawPred.argmax, assert(pred === rawPred.argmax,
@ -327,6 +329,49 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
dt.fit(df) dt.fit(df)
} }
test("training with sample weights") {
val df = {
val nPoints = 100
val coefficients = Array(
-0.57997, 0.912083, -0.371077,
-0.16624, -0.84355, -0.048509)
val xMean = Array(5.843, 3.057)
val xVariance = Array(0.6856, 0.1899)
val testData = LogisticRegressionSuite.generateMultinomialLogisticInput(
coefficients, xMean, xVariance, addIntercept = true, nPoints, seed)
sc.parallelize(testData, 4).toDF()
}
val numClasses = 3
val predEquals = (x: Double, y: Double) => x == y
// (impurity, maxDepth)
val testParams = Seq(
("gini", 10),
("entropy", 10),
("gini", 5)
)
for ((impurity, maxDepth) <- testParams) {
val estimator = new DecisionTreeClassifier()
.setMaxDepth(maxDepth)
.setSeed(seed)
.setMinWeightFractionPerNode(0.049)
.setImpurity(impurity)
MLTestingUtils.testArbitrarilyScaledWeights[DecisionTreeClassificationModel,
DecisionTreeClassifier](df.as[LabeledPoint], estimator,
MLTestingUtils.modelPredictionEquals(df, predEquals, 0.7))
MLTestingUtils.testOutliersWithSmallWeights[DecisionTreeClassificationModel,
DecisionTreeClassifier](df.as[LabeledPoint], estimator,
numClasses, MLTestingUtils.modelPredictionEquals(df, predEquals, 0.8),
outlierRatio = 2)
MLTestingUtils.testOversamplingVsWeighting[DecisionTreeClassificationModel,
DecisionTreeClassifier](df.as[LabeledPoint], estimator,
MLTestingUtils.modelPredictionEquals(df, predEquals, 0.7), seed)
}
}
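The helper-based checks above can be approximated by hand. A hedged sketch of the first invariant (scaling a constant weight column does not change the fitted model), assuming a small local DataFrame with the usual label and features columns and a deterministic row order:

``` Scala
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.lit

object ScalingInvarianceCheck {
  def check(df: DataFrame): Boolean = {
    val dt = new DecisionTreeClassifier().setWeightCol("w").setMaxDepth(5).setSeed(42)
    val m1 = dt.fit(df.withColumn("w", lit(1.0)))
    val m2 = dt.fit(df.withColumn("w", lit(10000.0)))
    val p1 = m1.transform(df).select("prediction").collect()
    val p2 = m2.transform(df).select("prediction").collect()
    p1.sameElements(p2)
  }
}
```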
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Tests of model save/load // Tests of model save/load
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
@ -350,7 +395,6 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2) TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings,
allParamSettings, checkModelData) allParamSettings, checkModelData)
// Continuous splits with tree depth 2 // Continuous splits with tree depth 2
val continuousData: DataFrame = val continuousData: DataFrame =
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
View file
@ -18,7 +18,7 @@
package org.apache.spark.ml.classification package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.LeafNode
@ -141,7 +141,7 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
MLTestingUtils.checkCopyAndUids(rf, model) MLTestingUtils.checkCopyAndUids(rf, model)
testTransformer[(Vector, Double)](df, model, "prediction", "rawPrediction", testTransformer[(Vector, Double, Double)](df, model, "prediction", "rawPrediction",
"probability") { case Row(pred: Double, rawPred: Vector, probPred: Vector) => "probability") { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
assert(pred === rawPred.argmax, assert(pred === rawPred.argmax,
s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
@ -180,7 +180,6 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Tests of feature importance // Tests of feature importance
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
test("Feature importance with toy data") { test("Feature importance with toy data") {
val numClasses = 2 val numClasses = 2
val rf = new RandomForestClassifier() val rf = new RandomForestClassifier()
View file
@ -18,7 +18,7 @@
package org.apache.spark.ml.regression package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
@ -26,6 +26,7 @@ import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
DecisionTreeSuite => OldDecisionTreeSuite} DecisionTreeSuite => OldDecisionTreeSuite}
import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.{DataFrame, Row}
@ -35,11 +36,17 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._ import testImplicits._
private var categoricalDataPointsRDD: RDD[LabeledPoint] = _ private var categoricalDataPointsRDD: RDD[LabeledPoint] = _
private var linearRegressionData: DataFrame = _
private val seed = 42
override def beforeAll() { override def beforeAll() {
super.beforeAll() super.beforeAll()
categoricalDataPointsRDD = categoricalDataPointsRDD =
sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints().map(_.asML)) sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints().map(_.asML))
linearRegressionData = sc.parallelize(LinearDataGenerator.generateLinearInput(
intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3),
xVariance = Array(0.7, 1.2), nPoints = 1000, seed, eps = 0.5), 2).map(_.asML).toDF()
} }
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
@ -88,7 +95,7 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
val model = dt.fit(df) val model = dt.fit(df)
testTransformer[(Vector, Double)](df, model, "features", "variance") { testTransformer[(Vector, Double, Double)](df, model, "features", "variance") {
case Row(features: Vector, variance: Double) => case Row(features: Vector, variance: Double) =>
val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate() val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate()
assert(variance === expectedVariance, assert(variance === expectedVariance,
@ -101,7 +108,7 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
.setMaxBins(6) .setMaxBins(6)
.setSeed(0) .setSeed(0)
testTransformerByGlobalCheckFunc[(Vector, Double)](varianceDF, dt.fit(varianceDF), testTransformerByGlobalCheckFunc[(Vector, Double, Double)](varianceDF, dt.fit(varianceDF),
"variance") { case rows: Seq[Row] => "variance") { case rows: Seq[Row] =>
val calculatedVariances = rows.map(_.getDouble(0)) val calculatedVariances = rows.map(_.getDouble(0))
@ -159,6 +166,28 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
} }
} }
test("training with sample weights") {
val df = linearRegressionData
val numClasses = 0
val testParams = Seq(5, 10)
for (maxDepth <- testParams) {
val estimator = new DecisionTreeRegressor()
.setMaxDepth(maxDepth)
.setMinWeightFractionPerNode(0.05)
.setSeed(123)
MLTestingUtils.testArbitrarilyScaledWeights[DecisionTreeRegressionModel,
DecisionTreeRegressor](df.as[LabeledPoint], estimator,
MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, 0.99))
MLTestingUtils.testOutliersWithSmallWeights[DecisionTreeRegressionModel,
DecisionTreeRegressor](df.as[LabeledPoint], estimator, numClasses,
MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, 0.99),
outlierRatio = 2)
MLTestingUtils.testOversamplingVsWeighting[DecisionTreeRegressionModel,
DecisionTreeRegressor](df.as[LabeledPoint], estimator,
MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.01, 1.0), seed)
}
}
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
// Tests of model save/load // Tests of model save/load
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
View file
@ -891,6 +891,7 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe
.setStandardization(standardization) .setStandardization(standardization)
.setRegParam(regParam) .setRegParam(regParam)
.setElasticNetParam(elasticNetParam) .setElasticNetParam(elasticNetParam)
.setSolver(solver)
MLTestingUtils.testArbitrarilyScaledWeights[LinearRegressionModel, LinearRegression]( MLTestingUtils.testArbitrarilyScaledWeights[LinearRegressionModel, LinearRegression](
datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals) datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals)
MLTestingUtils.testOutliersWithSmallWeights[LinearRegressionModel, LinearRegression]( MLTestingUtils.testOutliersWithSmallWeights[LinearRegressionModel, LinearRegression](
View file
@ -18,6 +18,7 @@
package org.apache.spark.ml.tree.impl package org.apache.spark.ml.tree.impl
import org.apache.spark.SparkFunSuite import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.mllib.tree.EnsembleTestHelper import org.apache.spark.mllib.tree.EnsembleTestHelper
import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.MLlibTestSparkContext
@ -26,12 +27,16 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
*/ */
class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext { class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
test("BaggedPoint RDD: without subsampling") { test("BaggedPoint RDD: without subsampling with weights") {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map { lp =>
Instance(lp.label, 0.5, lp.features.asML)
}
val rdd = sc.parallelize(arr) val rdd = sc.parallelize(arr)
val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42) val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false,
(instance: Instance) => instance.weight * 4.0, seed = 42)
baggedRDD.collect().foreach { baggedPoint => baggedRDD.collect().foreach { baggedPoint =>
assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1) assert(baggedPoint.subsampleCounts.size === 1 && baggedPoint.subsampleCounts(0) === 1)
assert(baggedPoint.sampleWeight === 2.0)
} }
} }
@ -40,13 +45,17 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
val (expectedMean, expectedStddev) = (1.0, 1.0) val (expectedMean, expectedStddev) = (1.0, 1.0)
val seeds = Array(123, 5354, 230, 349867, 23987) val seeds = Array(123, 5354, 230, 349867, 23987)
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map(_.asML)
val rdd = sc.parallelize(arr) val rdd = sc.parallelize(arr)
seeds.foreach { seed => seeds.foreach { seed =>
val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed) val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true,
val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() (_: LabeledPoint) => 2.0, seed)
val subsampleCounts: Array[Array[Double]] =
baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01) expectedStddev, epsilon = 0.01)
// should ignore weight function for now
assert(baggedRDD.collect().forall(_.sampleWeight === 1.0))
} }
} }
@ -59,8 +68,10 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
val rdd = sc.parallelize(arr) val rdd = sc.parallelize(arr)
seeds.foreach { seed => seeds.foreach { seed =>
val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed) val baggedRDD =
val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed = seed)
val subsampleCounts: Array[Array[Double]] =
baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01) expectedStddev, epsilon = 0.01)
} }
@ -71,13 +82,17 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
val (expectedMean, expectedStddev) = (1.0, 0) val (expectedMean, expectedStddev) = (1.0, 0)
val seeds = Array(123, 5354, 230, 349867, 23987) val seeds = Array(123, 5354, 230, 349867, 23987)
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map(_.asML)
val rdd = sc.parallelize(arr) val rdd = sc.parallelize(arr)
seeds.foreach { seed => seeds.foreach { seed =>
val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed) val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false,
val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() (_: LabeledPoint) => 2.0, seed)
val subsampleCounts: Array[Array[Double]] =
baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01) expectedStddev, epsilon = 0.01)
// should ignore weight function for now
assert(baggedRDD.collect().forall(_.sampleWeight === 1.0))
} }
} }
@ -90,8 +105,10 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
val rdd = sc.parallelize(arr) val rdd = sc.parallelize(arr)
seeds.foreach { seed => seeds.foreach { seed =>
val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed) val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false,
val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() seed = seed)
val subsampleCounts: Array[Array[Double]] =
baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01) expectedStddev, epsilon = 0.01)
} }
View file
@ -19,10 +19,11 @@ package org.apache.spark.ml.tree.impl
import scala.annotation.tailrec import scala.annotation.tailrec
import scala.collection.mutable import scala.collection.mutable
import scala.language.implicitConversions
import org.apache.spark.SparkFunSuite import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.classification.DecisionTreeClassificationModel import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree._
import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.ml.util.TestingUtils._
@ -46,7 +47,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
test("Binary classification with continuous features: split calculation") { test("Binary classification with continuous features: split calculation") {
val arr = OldDTSuite.generateOrderedLabeledPointsWithLabel1().map(_.asML) val arr = OldDTSuite.generateOrderedLabeledPointsWithLabel1().map(_.asML.toInstance)
assert(arr.length === 1000) assert(arr.length === 1000)
val rdd = sc.parallelize(arr) val rdd = sc.parallelize(arr)
val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2, 100) val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2, 100)
@ -58,7 +59,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
} }
test("Binary classification with binary (ordered) categorical features: split calculation") { test("Binary classification with binary (ordered) categorical features: split calculation") {
val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML) val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance)
assert(arr.length === 1000) assert(arr.length === 1000)
val rdd = sc.parallelize(arr) val rdd = sc.parallelize(arr)
val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2, val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2,
@ -75,7 +76,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Binary classification with 3-ary (ordered) categorical features," + test("Binary classification with 3-ary (ordered) categorical features," +
" with no samples for one category: split calculation") { " with no samples for one category: split calculation") {
val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML) val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance)
assert(arr.length === 1000) assert(arr.length === 1000)
val rdd = sc.parallelize(arr) val rdd = sc.parallelize(arr)
val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2, val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2,
@ -93,12 +94,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
test("find splits for a continuous feature") { test("find splits for a continuous feature") {
// find splits for normal case // find splits for normal case
{ {
val fakeMetadata = new DecisionTreeMetadata(1, 200000, 0, 0, val fakeMetadata = new DecisionTreeMetadata(1, 200000, 200000.0, 0, 0,
Map(), Set(), Map(), Set(),
Array(6), Gini, QuantileStrategy.Sort, Array(6), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0 0, 0, 0.0, 0.0, 0, 0
) )
val featureSamples = Array.fill(10000)(math.random).filter(_ != 0.0) val featureSamples = Array.fill(10000)((1.0, math.random)).filter(_._2 != 0.0)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length === 5) assert(splits.length === 5)
assert(fakeMetadata.numSplits(0) === 5) assert(fakeMetadata.numSplits(0) === 5)
@ -109,15 +110,16 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// SPARK-16957: Use midpoints for split values. // SPARK-16957: Use midpoints for split values.
{ {
val fakeMetadata = new DecisionTreeMetadata(1, 8, 0, 0, val fakeMetadata = new DecisionTreeMetadata(1, 8, 8.0, 0, 0,
Map(), Set(), Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort, Array(3), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0 0, 0, 0.0, 0.0, 0, 0
) )
// possibleSplits <= numSplits // possibleSplits <= numSplits
{ {
val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble).filter(_ != 0.0) val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1)
.map(x => (1.0, x.toDouble)).filter(_._2 != 0.0)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
val expectedSplits = Array((0.0 + 1.0) / 2) val expectedSplits = Array((0.0 + 1.0) / 2)
assert(splits === expectedSplits) assert(splits === expectedSplits)
@ -125,7 +127,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// possibleSplits > numSplits // possibleSplits > numSplits
{ {
val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble).filter(_ != 0.0) val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3)
.map(x => (1.0, x.toDouble)).filter(_._2 != 0.0)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2) val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2)
assert(splits === expectedSplits) assert(splits === expectedSplits)
@ -135,12 +138,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// find splits should not return identical splits // find splits should not return identical splits
// when there are not enough split candidates, reduce the number of splits in metadata // when there are not enough split candidates, reduce the number of splits in metadata
{ {
val fakeMetadata = new DecisionTreeMetadata(1, 12, 0, 0, val fakeMetadata = new DecisionTreeMetadata(1, 12, 12.0, 0, 0,
Map(), Set(), Map(), Set(),
Array(5), Gini, QuantileStrategy.Sort, Array(5), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0 0, 0, 0.0, 0.0, 0, 0
) )
val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(_.toDouble) val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(x => (1.0, x.toDouble))
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
val expectedSplits = Array((1.0 + 2.0) / 2, (2.0 + 3.0) / 2) val expectedSplits = Array((1.0 + 2.0) / 2, (2.0 + 3.0) / 2)
assert(splits === expectedSplits) assert(splits === expectedSplits)
@ -150,13 +153,13 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// find splits when most samples close to the minimum // find splits when most samples close to the minimum
{ {
val fakeMetadata = new DecisionTreeMetadata(1, 18, 0, 0, val fakeMetadata = new DecisionTreeMetadata(1, 18, 18.0, 0, 0,
Map(), Set(), Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort, Array(3), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0 0, 0, 0.0, 0.0, 0, 0
) )
val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5) val featureSamples =
.map(_.toDouble) Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(x => (1.0, x.toDouble))
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
val expectedSplits = Array((2.0 + 3.0) / 2, (3.0 + 4.0) / 2) val expectedSplits = Array((2.0 + 3.0) / 2, (3.0 + 4.0) / 2)
assert(splits === expectedSplits) assert(splits === expectedSplits)
@ -164,37 +167,55 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// find splits when most samples close to the maximum // find splits when most samples close to the maximum
{ {
val fakeMetadata = new DecisionTreeMetadata(1, 17, 0, 0, val fakeMetadata = new DecisionTreeMetadata(1, 17, 17.0, 0, 0,
Map(), Set(), Map(), Set(),
Array(2), Gini, QuantileStrategy.Sort, Array(2), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0 0, 0, 0.0, 0.0, 0, 0
) )
val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) val featureSamples =
.map(_.toDouble).filter(_ != 0.0) Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(x => (1.0, x.toDouble))
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
val expectedSplits = Array((1.0 + 2.0) / 2) val expectedSplits = Array((1.0 + 2.0) / 2)
assert(splits === expectedSplits) assert(splits === expectedSplits)
} }
// find splits for constant feature // find splits for arbitrarily scaled data
{ {
val fakeMetadata = new DecisionTreeMetadata(1, 3, 0, 0, val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
Map(), Set(),
Array(6), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0.0, 0, 0
)
val featureSamplesUnitWeight = Array.fill(10)((1.0, math.random))
val featureSamplesSmallWeight = featureSamplesUnitWeight.map { case (w, x) => (w * 0.001, x)}
val featureSamplesLargeWeight = featureSamplesUnitWeight.map { case (w, x) => (w * 1000, x)}
val splitsUnitWeight = RandomForest
.findSplitsForContinuousFeature(featureSamplesUnitWeight, fakeMetadata, 0)
val splitsSmallWeight = RandomForest
.findSplitsForContinuousFeature(featureSamplesSmallWeight, fakeMetadata, 0)
val splitsLargeWeight = RandomForest
.findSplitsForContinuousFeature(featureSamplesLargeWeight, fakeMetadata, 0)
assert(splitsUnitWeight === splitsSmallWeight)
assert(splitsUnitWeight === splitsLargeWeight)
}
// find splits when most weight is close to the minimum
{
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
Map(), Set(), Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort, Array(3), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0 0, 0, 0.0, 0.0, 0, 0
) )
val featureSamples = Array(0, 0, 0).map(_.toDouble).filter(_ != 0.0) val featureSamples = Array((10, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6)).map {
val featureSamplesEmpty = Array.empty[Double] case (w, x) => (w.toDouble, x.toDouble)
}
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits === Array.empty[Double]) assert(splits === Array(1.5, 2.5, 3.5, 4.5, 5.5))
val splitsEmpty =
RandomForest.findSplitsForContinuousFeature(featureSamplesEmpty, fakeMetadata, 0)
assert(splitsEmpty === Array.empty[Double])
} }
} }
test("train with empty arrays") { test("train with empty arrays") {
val lp = LabeledPoint(1.0, Vectors.dense(Array.empty[Double])) val lp = LabeledPoint(1.0, Vectors.dense(Array.empty[Double])).toInstance
val data = Array.fill(5)(lp) val data = Array.fill(5)(lp)
val rdd = sc.parallelize(data) val rdd = sc.parallelize(data)
@ -209,8 +230,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
} }
test("train with constant features") { test("train with constant features") {
val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0)) val instance = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0)).toInstance
val data = Array.fill(5)(lp) val data = Array.fill(5)(instance)
val rdd = sc.parallelize(data) val rdd = sc.parallelize(data)
val strategy = new OldStrategy( val strategy = new OldStrategy(
OldAlgo.Classification, OldAlgo.Classification,
@ -222,7 +243,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None) val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)
assert(tree.rootNode.impurity === -1.0) assert(tree.rootNode.impurity === -1.0)
assert(tree.depth === 0) assert(tree.depth === 0)
assert(tree.rootNode.prediction === lp.label) assert(tree.rootNode.prediction === instance.label)
// Test with no categorical features // Test with no categorical features
val strategy2 = new OldStrategy( val strategy2 = new OldStrategy(
@ -233,11 +254,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val Array(tree2) = RandomForest.run(rdd, strategy2, 1, "all", 42L, instr = None) val Array(tree2) = RandomForest.run(rdd, strategy2, 1, "all", 42L, instr = None)
assert(tree2.rootNode.impurity === -1.0) assert(tree2.rootNode.impurity === -1.0)
assert(tree2.depth === 0) assert(tree2.depth === 0)
assert(tree2.rootNode.prediction === lp.label) assert(tree2.rootNode.prediction === instance.label)
} }
test("Multiclass classification with unordered categorical features: split calculations") { test("Multiclass classification with unordered categorical features: split calculations") {
val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML) val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance)
assert(arr.length === 1000) assert(arr.length === 1000)
val rdd = sc.parallelize(arr) val rdd = sc.parallelize(arr)
val strategy = new OldStrategy( val strategy = new OldStrategy(
@ -278,7 +299,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
} }
test("Multiclass classification with ordered categorical features: split calculations") { test("Multiclass classification with ordered categorical features: split calculations") {
val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures().map(_.asML) val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
.map(_.asML.toInstance)
assert(arr.length === 3000) assert(arr.length === 3000)
val rdd = sc.parallelize(arr) val rdd = sc.parallelize(arr)
val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 100, val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 100,
@ -310,7 +332,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
val input = sc.parallelize(arr) val input = sc.parallelize(arr.map(_.toInstance))
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1, val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1,
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
@ -352,7 +374,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
val input = sc.parallelize(arr) val input = sc.parallelize(arr.map(_.toInstance))
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 5, val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 5,
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
@ -404,7 +426,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
LabeledPoint(0.0, Vectors.dense(2.0)), LabeledPoint(0.0, Vectors.dense(2.0)),
LabeledPoint(0.0, Vectors.dense(2.0)), LabeledPoint(0.0, Vectors.dense(2.0)),
LabeledPoint(1.0, Vectors.dense(2.0))) LabeledPoint(1.0, Vectors.dense(2.0)))
val input = sc.parallelize(arr) val input = sc.parallelize(arr.map(_.toInstance))
// Must set maxBins s.t. the feature will be treated as an ordered categorical feature. // Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1, val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1,
@ -424,7 +446,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
} }
test("Second level node building with vs. without groups") { test("Second level node building with vs. without groups") {
val arr = OldDTSuite.generateOrderedLabeledPoints().map(_.asML) val arr = OldDTSuite.generateOrderedLabeledPoints().map(_.asML.toInstance)
assert(arr.length === 1000) assert(arr.length === 1000)
val rdd = sc.parallelize(arr) val rdd = sc.parallelize(arr)
// For tree with 1 group // For tree with 1 group
@ -468,7 +490,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: OldStrategy) { def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: OldStrategy) {
val numFeatures = 50 val numFeatures = 50
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000) val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000)
val rdd = sc.parallelize(arr).map(_.asML) val rdd = sc.parallelize(arr).map(_.asML.toInstance)
// Select feature subset for top nodes. Return true if OK. // Select feature subset for top nodes. Return true if OK.
def checkFeatureSubsetStrategy( def checkFeatureSubsetStrategy(
@ -581,16 +603,16 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
left2 parent left2 parent
left right left right
*/ */
val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0)) val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0), 6L)
val left = new LeafNode(0.0, leftImp.calculate(), leftImp) val left = new LeafNode(0.0, leftImp.calculate(), leftImp)
val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0)) val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0), 8L)
val right = new LeafNode(2.0, rightImp.calculate(), rightImp) val right = new LeafNode(2.0, rightImp.calculate(), rightImp)
val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5)) val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5))
val parentImp = parent.impurityStats val parentImp = parent.impurityStats
val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0)) val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0), 8L)
val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp) val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp)
val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0)) val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0))
@ -647,12 +669,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// feature_0 = 0 improves the impurity measure, despite the prediction will always be 0 // feature_0 = 0 improves the impurity measure, despite the prediction will always be 0
// in both branches. // in both branches.
val arr = Array( val arr = Array(
LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), Instance(0.0, 1.0, Vectors.dense(0.0, 1.0)),
LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), Instance(1.0, 1.0, Vectors.dense(0.0, 1.0)),
LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), Instance(0.0, 1.0, Vectors.dense(0.0, 0.0)),
LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), Instance(1.0, 1.0, Vectors.dense(1.0, 0.0)),
LabeledPoint(0.0, Vectors.dense(1.0, 0.0)), Instance(0.0, 1.0, Vectors.dense(1.0, 0.0)),
LabeledPoint(1.0, Vectors.dense(1.0, 1.0)) Instance(1.0, 1.0, Vectors.dense(1.0, 1.0))
) )
val rdd = sc.parallelize(arr) val rdd = sc.parallelize(arr)
@ -677,13 +699,13 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// feature_1 = 1 improves the impurity measure, despite the prediction will always be 0.5 // feature_1 = 1 improves the impurity measure, despite the prediction will always be 0.5
// in both branches. // in both branches.
val arr = Array( val arr = Array(
LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), Instance(0.0, 1.0, Vectors.dense(0.0, 1.0)),
LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), Instance(1.0, 1.0, Vectors.dense(0.0, 1.0)),
LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), Instance(0.0, 1.0, Vectors.dense(0.0, 0.0)),
LabeledPoint(0.0, Vectors.dense(1.0, 0.0)), Instance(0.0, 1.0, Vectors.dense(1.0, 0.0)),
LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), Instance(1.0, 1.0, Vectors.dense(1.0, 1.0)),
LabeledPoint(0.0, Vectors.dense(1.0, 1.0)), Instance(0.0, 1.0, Vectors.dense(1.0, 1.0)),
LabeledPoint(0.5, Vectors.dense(1.0, 1.0)) Instance(0.5, 1.0, Vectors.dense(1.0, 1.0))
) )
val rdd = sc.parallelize(arr) val rdd = sc.parallelize(arr)
@ -700,6 +722,56 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(unprunedTree.numNodes === 5) assert(unprunedTree.numNodes === 5)
assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size) assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size)
} }
test("weights at arbitrary scale") {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(3, 10)
val rddWithUnitWeights = sc.parallelize(arr.map(_.asML.toInstance))
val rddWithSmallWeights = rddWithUnitWeights.map { inst =>
Instance(inst.label, 0.001, inst.features)
}
val rddWithBigWeights = rddWithUnitWeights.map { inst =>
Instance(inst.label, 1000, inst.features)
}
val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2)
val unitWeightTrees = RandomForest.run(rddWithUnitWeights, strategy, 3, "all", 42L, None)
val smallWeightTrees = RandomForest.run(rddWithSmallWeights, strategy, 3, "all", 42L, None)
unitWeightTrees.zip(smallWeightTrees).foreach { case (unitTree, smallWeightTree) =>
TreeTests.checkEqual(unitTree, smallWeightTree)
}
val bigWeightTrees = RandomForest.run(rddWithBigWeights, strategy, 3, "all", 42L, None)
unitWeightTrees.zip(bigWeightTrees).foreach { case (unitTree, bigWeightTree) =>
TreeTests.checkEqual(unitTree, bigWeightTree)
}
}
test("minWeightFraction and minInstancesPerNode") {
val data = Array(
Instance(0.0, 1.0, Vectors.dense(0.0)),
Instance(0.0, 1.0, Vectors.dense(0.0)),
Instance(0.0, 1.0, Vectors.dense(0.0)),
Instance(0.0, 1.0, Vectors.dense(0.0)),
Instance(1.0, 0.1, Vectors.dense(1.0))
)
val rdd = sc.parallelize(data)
val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2,
minWeightFractionPerNode = 0.5)
val Array(tree1) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
assert(tree1.depth === 0)
strategy.minWeightFractionPerNode = 0.0
val Array(tree2) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
assert(tree2.depth === 1)
strategy.minInstancesPerNode = 2
val Array(tree3) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
assert(tree3.depth === 0)
strategy.minInstancesPerNode = 1
val Array(tree4) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
assert(tree4.depth === 1)
}
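
For reference, the two stopping criteria exercised here behave differently: minInstancesPerNode is checked against the raw, unweighted count falling into each child, while minWeightFractionPerNode requires every child to carry at least that fraction of the total weighted count. A minimal sketch of such a check, with hypothetical names rather than the actual Spark implementation:

``` Scala
// Illustrative only: the split-validity semantics the assertions above rely on.
// All names here (splitIsValid, childRawCounts, ...) are made up for this sketch.
def splitIsValid(
    childRawCounts: (Long, Long),      // unweighted instance counts per child
    childWeights: (Double, Double),    // weighted instance mass per child
    totalWeight: Double,               // weighted count of the whole training set
    minInstancesPerNode: Int,
    minWeightFractionPerNode: Double): Boolean = {
  val minWeightPerNode = minWeightFractionPerNode * totalWeight
  childRawCounts._1 >= minInstancesPerNode && childRawCounts._2 >= minInstancesPerNode &&
    childWeights._1 >= minWeightPerNode && childWeights._2 >= minWeightPerNode
}
```

In the data above the total weight is 4.1 and the only candidate split isolates the outlier, leaving one child with weight 0.1 (about 2% of the total) and raw count 1. That split is rejected at minWeightFractionPerNode = 0.5 and at minInstancesPerNode = 2, and accepted otherwise, which is what the four depth assertions check.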
} }
private object RandomForestSuite { private object RandomForestSuite {
@ -717,7 +789,7 @@ private object RandomForestSuite {
else { else {
nodes.head match { nodes.head match {
case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc) case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc)
case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.count) case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.rawCount)
} }
} }
} }


@ -28,7 +28,7 @@ class TreePointSuite extends SparkFunSuite {
val ser = new KryoSerializer(conf).newInstance() val ser = new KryoSerializer(conf).newInstance()
val point = new TreePoint(1.0, Array(1, 2, 3)) val point = new TreePoint(1.0, Array(1, 2, 3), 1.0)
val point2 = ser.deserialize[TreePoint](ser.serialize(point)) val point2 = ser.deserialize[TreePoint](ser.serialize(point))
assert(point.label === point2.label) assert(point.label === point2.label)
assert(point.binnedFeatures === point2.binnedFeatures) assert(point.binnedFeatures === point2.binnedFeatures)


@ -18,13 +18,15 @@
package org.apache.spark.ml.tree.impl package org.apache.spark.ml.tree.impl
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
import scala.util.Random
import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaRDD
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree._
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.{DataFrame, SparkSession}
@ -32,6 +34,7 @@ private[ml] object TreeTests extends SparkFunSuite {
/** /**
* Convert the given data to a DataFrame, and set the features and label metadata. * Convert the given data to a DataFrame, and set the features and label metadata.
*
* @param data Dataset. Categorical features and labels must already have 0-based indices. * @param data Dataset. Categorical features and labels must already have 0-based indices.
* This must be non-empty. * This must be non-empty.
* @param categoricalFeatures Map: categorical feature index to number of distinct values * @param categoricalFeatures Map: categorical feature index to number of distinct values
@ -39,16 +42,22 @@ private[ml] object TreeTests extends SparkFunSuite {
* @return DataFrame with metadata * @return DataFrame with metadata
*/ */
def setMetadata( def setMetadata(
data: RDD[LabeledPoint], data: RDD[_],
categoricalFeatures: Map[Int, Int], categoricalFeatures: Map[Int, Int],
numClasses: Int): DataFrame = { numClasses: Int): DataFrame = {
val dataOfInstance: RDD[Instance] = data.map {
_ match {
case instance: Instance => instance
case labeledPoint: LabeledPoint => labeledPoint.toInstance
}
}
val spark = SparkSession.builder() val spark = SparkSession.builder()
.sparkContext(data.sparkContext) .sparkContext(data.sparkContext)
.getOrCreate() .getOrCreate()
import spark.implicits._ import spark.implicits._
val df = data.toDF() val df = dataOfInstance.toDF()
val numFeatures = data.first().features.size val numFeatures = dataOfInstance.first().features.size
val featuresAttributes = Range(0, numFeatures).map { feature => val featuresAttributes = Range(0, numFeatures).map { feature =>
if (categoricalFeatures.contains(feature)) { if (categoricalFeatures.contains(feature)) {
NominalAttribute.defaultAttr.withIndex(feature).withNumValues(categoricalFeatures(feature)) NominalAttribute.defaultAttr.withIndex(feature).withNumValues(categoricalFeatures(feature))
@ -64,7 +73,7 @@ private[ml] object TreeTests extends SparkFunSuite {
} }
val labelMetadata = labelAttribute.toMetadata() val labelMetadata = labelAttribute.toMetadata()
df.select(df("features").as("features", featuresMetadata), df.select(df("features").as("features", featuresMetadata),
df("label").as("label", labelMetadata)) df("label").as("label", labelMetadata), df("weight"))
} }
/** /**
@ -80,6 +89,7 @@ private[ml] object TreeTests extends SparkFunSuite {
/** /**
* Set label metadata (particularly the number of classes) on a DataFrame. * Set label metadata (particularly the number of classes) on a DataFrame.
*
* @param data Dataset. Categorical features and labels must already have 0-based indices. * @param data Dataset. Categorical features and labels must already have 0-based indices.
* This must be non-empty. * This must be non-empty.
* @param numClasses Number of classes label can take. If 0, mark as continuous. * @param numClasses Number of classes label can take. If 0, mark as continuous.
@ -124,8 +134,8 @@ private[ml] object TreeTests extends SparkFunSuite {
* make mistakes such as creating loops of Nodes. * make mistakes such as creating loops of Nodes.
*/ */
private def checkEqual(a: Node, b: Node): Unit = { private def checkEqual(a: Node, b: Node): Unit = {
assert(a.prediction === b.prediction) assert(a.prediction ~== b.prediction absTol 1e-8)
assert(a.impurity === b.impurity) assert(a.impurity ~== b.impurity absTol 1e-8)
(a, b) match { (a, b) match {
case (aye: InternalNode, bee: InternalNode) => case (aye: InternalNode, bee: InternalNode) =>
assert(aye.split === bee.split) assert(aye.split === bee.split)
@ -156,6 +166,7 @@ private[ml] object TreeTests extends SparkFunSuite {
/** /**
* Helper method for constructing a tree for testing. * Helper method for constructing a tree for testing.
* Given left, right children, construct a parent node. * Given left, right children, construct a parent node.
*
* @param split Split for parent node * @param split Split for parent node
* @return Parent node with children attached * @return Parent node with children attached
*/ */
@ -163,8 +174,8 @@ private[ml] object TreeTests extends SparkFunSuite {
val leftImp = left.impurityStats val leftImp = left.impurityStats
val rightImp = right.impurityStats val rightImp = right.impurityStats
val parentImp = leftImp.copy.add(rightImp) val parentImp = leftImp.copy.add(rightImp)
val leftWeight = leftImp.count / parentImp.count.toDouble val leftWeight = leftImp.count / parentImp.count
val rightWeight = rightImp.count / parentImp.count.toDouble val rightWeight = rightImp.count / parentImp.count
val gain = parentImp.calculate() - val gain = parentImp.calculate() -
(leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate()) (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate())
val pred = parentImp.predict val pred = parentImp.predict
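
Because the impurity statistics now carry a weighted Double count, the .toDouble conversions disappear and the gain is computed entirely from weighted mass. Restated as a small sketch over plain doubles (not the calculator objects themselves):

``` Scala
// Sketch of the information gain computed above: child fractions are shares of
// weighted mass, not of raw instance counts.
def weightedGain(
    parentImpurity: Double, parentWeight: Double,
    leftImpurity: Double, leftWeight: Double,
    rightImpurity: Double, rightWeight: Double): Double = {
  val leftFrac = leftWeight / parentWeight
  val rightFrac = rightWeight / parentWeight
  parentImpurity - (leftFrac * leftImpurity + rightFrac * rightImpurity)
}
```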


@ -23,7 +23,7 @@ import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol} import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.recommendation.{ALS, ALSModel} import org.apache.spark.ml.recommendation.{ALS, ALSModel}
import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
@ -205,8 +205,8 @@ object MLTestingUtils extends SparkFunSuite {
seed: Long): Unit = { seed: Long): Unit = {
val (overSampledData, weightedData) = genEquivalentOversampledAndWeightedInstances( val (overSampledData, weightedData) = genEquivalentOversampledAndWeightedInstances(
data, seed) data, seed)
val weightedModel = estimator.set(estimator.weightCol, "weight").fit(weightedData)
val overSampledModel = estimator.set(estimator.weightCol, "").fit(overSampledData) val overSampledModel = estimator.set(estimator.weightCol, "").fit(overSampledData)
val weightedModel = estimator.set(estimator.weightCol, "weight").fit(weightedData)
modelEquals(weightedModel, overSampledModel) modelEquals(weightedModel, overSampledModel)
} }
@ -228,7 +228,8 @@ object MLTestingUtils extends SparkFunSuite {
List.fill(outlierRatio)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f)) List.fill(outlierRatio)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f))
} }
val trueModel = estimator.set(estimator.weightCol, "").fit(data) val trueModel = estimator.set(estimator.weightCol, "").fit(data)
val outlierModel = estimator.set(estimator.weightCol, "weight").fit(outlierDS) val outlierModel = estimator.set(estimator.weightCol, "weight")
.fit(outlierDS)
modelEquals(trueModel, outlierModel) modelEquals(trueModel, outlierModel)
} }
@ -241,7 +242,7 @@ object MLTestingUtils extends SparkFunSuite {
estimator: E with HasWeightCol, estimator: E with HasWeightCol,
modelEquals: (M, M) => Unit): Unit = { modelEquals: (M, M) => Unit): Unit = {
estimator.set(estimator.weightCol, "weight") estimator.set(estimator.weightCol, "weight")
val models = Seq(0.001, 1.0, 1000.0).map { w => val models = Seq(0.01, 1.0, 1000.0).map { w =>
val df = data.withColumn("weight", lit(w)) val df = data.withColumn("weight", lit(w))
estimator.fit(df) estimator.fit(df)
} }
@ -268,4 +269,20 @@ object MLTestingUtils extends SparkFunSuite {
assert(newDatasetF.schema(featuresColName).dataType.equals(new ArrayType(FloatType, false))) assert(newDatasetF.schema(featuresColName).dataType.equals(new ArrayType(FloatType, false)))
(newDataset, newDatasetD, newDatasetF) (newDataset, newDatasetD, newDatasetF)
} }
def modelPredictionEquals[M <: PredictionModel[_, M]](
data: DataFrame,
compareFunc: (Double, Double) => Boolean,
fractionInTol: Double)(
model1: M,
model2: M): Unit = {
val pred1 = model1.transform(data).select(model1.getPredictionCol).collect()
val pred2 = model2.transform(data).select(model2.getPredictionCol).collect()
val inTol = pred1.zip(pred2).count { case (p1, p2) =>
val x = p1.getDouble(0)
val y = p2.getDouble(0)
compareFunc(x, y)
}
assert(inTol / pred1.length.toDouble >= fractionInTol)
}
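
modelPredictionEquals complements modelEquals: instead of requiring structurally identical models, it only requires that a given fraction of predictions agree under a caller-supplied comparison. A hypothetical usage, where data, treeModel1 and treeModel2 stand in for a test's own DataFrame and fitted models:

``` Scala
// Require that at least 90% of predictions agree within an absolute tolerance
// of 0.01; the tolerance and fraction are arbitrary values for illustration.
val closeEnough = (p1: Double, p2: Double) => math.abs(p1 - p2) <= 0.01
MLTestingUtils.modelPredictionEquals(data, closeEnough, 0.9)(treeModel1, treeModel2)
```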
} }


@ -73,7 +73,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
maxBins = 100, maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1)) assert(!metadata.isUnordered(featureIndex = 1))
@ -100,7 +100,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
maxDepth = 2, maxDepth = 2,
maxBins = 100, maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2))
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1)) assert(!metadata.isUnordered(featureIndex = 1))
@ -116,7 +116,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val rdd = sc.parallelize(arr) val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, maxDepth = 3, val strategy = new Strategy(Classification, Gini, maxDepth = 3,
numClasses = 2, maxBins = 100) numClasses = 2, maxBins = 100)
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1)) assert(!metadata.isUnordered(featureIndex = 1))
@ -133,7 +133,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val rdd = sc.parallelize(arr) val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, maxDepth = 3, val strategy = new Strategy(Classification, Gini, maxDepth = 3,
numClasses = 2, maxBins = 100) numClasses = 2, maxBins = 100)
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1)) assert(!metadata.isUnordered(featureIndex = 1))
@ -150,7 +150,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val rdd = sc.parallelize(arr) val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, maxDepth = 3, val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
numClasses = 2, maxBins = 100) numClasses = 2, maxBins = 100)
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1)) assert(!metadata.isUnordered(featureIndex = 1))
@ -167,7 +167,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val rdd = sc.parallelize(arr) val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, maxDepth = 3, val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
numClasses = 2, maxBins = 100) numClasses = 2, maxBins = 100)
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1)) assert(!metadata.isUnordered(featureIndex = 1))
@ -183,7 +183,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val rdd = sc.parallelize(arr) val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClasses = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) numClasses = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
assert(strategy.isMulticlassClassification) assert(strategy.isMulticlassClassification)
assert(metadata.isUnordered(featureIndex = 0)) assert(metadata.isUnordered(featureIndex = 0))
assert(metadata.isUnordered(featureIndex = 1)) assert(metadata.isUnordered(featureIndex = 1))
@ -240,7 +240,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
numClasses = 3, maxBins = maxBins, numClasses = 3, maxBins = maxBins,
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
assert(strategy.isMulticlassClassification) assert(strategy.isMulticlassClassification)
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
assert(metadata.isUnordered(featureIndex = 0)) assert(metadata.isUnordered(featureIndex = 0))
assert(metadata.isUnordered(featureIndex = 1)) assert(metadata.isUnordered(featureIndex = 1))
@ -288,7 +288,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClasses = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3)) numClasses = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3))
assert(strategy.isMulticlassClassification) assert(strategy.isMulticlassClassification)
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
assert(metadata.isUnordered(featureIndex = 0)) assert(metadata.isUnordered(featureIndex = 0))
val model = DecisionTree.train(rdd, strategy) val model = DecisionTree.train(rdd, strategy)
@ -310,7 +310,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
numClasses = 3, maxBins = 100, numClasses = 3, maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
assert(strategy.isMulticlassClassification) assert(strategy.isMulticlassClassification)
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1)) assert(!metadata.isUnordered(featureIndex = 1))


@ -18,23 +18,63 @@
package org.apache.spark.mllib.tree package org.apache.spark.mllib.tree
import org.apache.spark.SparkFunSuite import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator} import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.tree.impurity._
/** /**
* Test suites for `GiniAggregator` and `EntropyAggregator`. * Test suites for `GiniAggregator` and `EntropyAggregator`.
*/ */
class ImpuritySuite extends SparkFunSuite { class ImpuritySuite extends SparkFunSuite {
private val seed = 42
test("Gini impurity does not support negative labels") { test("Gini impurity does not support negative labels") {
val gini = new GiniAggregator(2) val gini = new GiniAggregator(2)
intercept[IllegalArgumentException] { intercept[IllegalArgumentException] {
gini.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0) gini.update(Array(0.0, 1.0, 2.0), 0, -1, 3, 0.0)
} }
} }
test("Entropy does not support negative labels") { test("Entropy does not support negative labels") {
val entropy = new EntropyAggregator(2) val entropy = new EntropyAggregator(2)
intercept[IllegalArgumentException] { intercept[IllegalArgumentException] {
entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0) entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 3, 0.0)
} }
} }
test("Classification impurities are insensitive to scaling") {
val rng = new scala.util.Random(seed)
val weightedCounts = Array.fill(5)(rng.nextDouble())
val smallWeightedCounts = weightedCounts.map(_ * 0.0001)
val largeWeightedCounts = weightedCounts.map(_ * 10000)
Seq(Gini, Entropy).foreach { impurity =>
val impurity1 = impurity.calculate(weightedCounts, weightedCounts.sum)
assert(impurity.calculate(smallWeightedCounts, smallWeightedCounts.sum)
~== impurity1 relTol 0.005)
assert(impurity.calculate(largeWeightedCounts, largeWeightedCounts.sum)
~== impurity1 relTol 0.005)
}
}
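
Both Gini and entropy are functions of the class proportions only, so multiplying every weighted count by a common factor leaves them unchanged; the relative tolerance just absorbs floating-point error. A standalone sketch of the property for Gini (not the MLlib implementation):

``` Scala
// Gini impurity from weighted class counts: 1 - sum of squared proportions.
def gini(weightedCounts: Array[Double]): Double = {
  val total = weightedCounts.sum
  1.0 - weightedCounts.map { c => val p = c / total; p * p }.sum
}

val counts = Array(3.0, 2.0, 1.0)
assert(math.abs(gini(counts) - gini(counts.map(_ * 1000))) < 1e-12)
```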
test("Regression impurities are insensitive to scaling") {
def computeStats(samples: Seq[Double], weights: Seq[Double]): (Double, Double, Double) = {
samples.zip(weights).foldLeft((0.0, 0.0, 0.0)) { case ((wn, wy, wyy), (y, w)) =>
(wn + w, wy + w * y, wyy + w * y * y)
}
}
val rng = new scala.util.Random(seed)
val samples = Array.fill(10)(rng.nextDouble())
val _weights = Array.fill(10)(rng.nextDouble())
val smallWeights = _weights.map(_ * 0.0001)
val largeWeights = _weights.map(_ * 10000)
val (count, sum, sumSquared) = computeStats(samples, _weights)
Seq(Variance).foreach { impurity =>
val impurity1 = impurity.calculate(count, sum, sumSquared)
val (smallCount, smallSum, smallSumSquared) = computeStats(samples, smallWeights)
val (largeCount, largeSum, largeSumSquared) = computeStats(samples, largeWeights)
assert(impurity.calculate(smallCount, smallSum, smallSumSquared) ~== impurity1 relTol 0.005)
assert(impurity.calculate(largeCount, largeSum, largeSumSquared) ~== impurity1 relTol 0.005)
}
}
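
The variance impurity is scale-invariant for the same reason: count, sum and sum of squares all scale linearly with the weights, so the ratios cancel. A standalone sketch, assuming the usual sumSquared / count - (sum / count)^2 form:

``` Scala
// Variance impurity from weighted sufficient statistics; scaling every weight
// by k multiplies count, sum and sumSquared by k, which cancels in the ratios.
def varianceImpurity(count: Double, sum: Double, sumSquared: Double): Double =
  sumSquared / count - math.pow(sum / count, 2)

val (count, sum, sumSq) = (4.0, 6.0, 14.0)
val k = 10000.0
assert(math.abs(varianceImpurity(count, sum, sumSq) -
  varianceImpurity(k * count, k * sum, k * sumSq)) < 1e-9)
```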
} }