[SPARK-19591][ML][MLLIB] Add sample weights to decision trees
This is updated PR https://github.com/apache/spark/pull/16722 to latest master ## What changes were proposed in this pull request? This patch adds support for sample weights to DecisionTreeRegressor and DecisionTreeClassifier. Note: This patch does not add support for sample weights to RandomForest. As discussed in the JIRA, we would like to add sample weights into the bagging process. This patch is large enough as is, and there are some additional considerations to be made for random forests. Since the machinery introduced here needs to be present regardless, I have opted to leave random forests for a follow up pr. ## How was this patch tested? The algorithms are tested to ensure that: 1. Arbitrary scaling of constant weights has no effect 2. Outliers with small weights do not affect the learned model 3. Oversampling and weighting are equivalent Unit tests are also added to test other smaller components. ## Summary of changes - Impurity aggregators now store weighted sufficient statistics. They also store a raw count, however, since this is needed to use minInstancesPerNode. - Impurity aggregators now also hold the raw count. - This patch maintains the meaning of minInstancesPerNode, in that the parameter still corresponds to raw, unweighted counts. It also adds a new parameter minWeightFractionPerNode which requires that nodes must contain at least minWeightFractionPerNode * weightedNumExamples total weight. - This patch modifies findSplitsForContinuousFeatures to use weighted sums. Unit tests are added. - TreePoint is modified to hold a sample weight - BaggedPoint is modified from: ``` Scala private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double]) extends Serializable ``` to ``` Scala private[spark] class BaggedPoint[Datum]( val datum: Datum, val subsampleCounts: Array[Int], val sampleWeight: Double) extends Serializable ``` We do not simply multiply the counts by the weight and store that because we need the raw counts and the weight in order to use both minInstancesPerNode and minWeightPerNode **Note**: many of the changed files are due simply to using Instance instead of LabeledPoint Closes #21632 from imatiach-msft/ilmat/sample-weights. Authored-by: Ilya Matiach <ilmat@microsoft.com> Signed-off-by: Sean Owen <sean.owen@databricks.com>
This commit is contained in:
parent
3699763fda
commit
b2d36f65db
|
@ -52,7 +52,7 @@ object TestingUtils {
|
||||||
/**
|
/**
|
||||||
* Private helper function for comparing two values using absolute tolerance.
|
* Private helper function for comparing two values using absolute tolerance.
|
||||||
*/
|
*/
|
||||||
private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
|
private[ml] def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
|
||||||
// Special case for NaNs
|
// Special case for NaNs
|
||||||
if (x.isNaN && y.isNaN) {
|
if (x.isNaN && y.isNaN) {
|
||||||
return true
|
return true
|
||||||
|
|
|
@ -77,17 +77,37 @@ abstract class Classifier[
|
||||||
* @note Throws `SparkException` if any label is a non-integer or is negative
|
* @note Throws `SparkException` if any label is a non-integer or is negative
|
||||||
*/
|
*/
|
||||||
protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
|
protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
|
||||||
require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
|
validateNumClasses(numClasses)
|
||||||
s" $numClasses, but requires numClasses > 0.")
|
|
||||||
dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
|
dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
|
||||||
case Row(label: Double, features: Vector) =>
|
case Row(label: Double, features: Vector) =>
|
||||||
require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
|
validateLabel(label, numClasses)
|
||||||
s" dataset with invalid label $label. Labels must be integers in range" +
|
|
||||||
s" [0, $numClasses).")
|
|
||||||
LabeledPoint(label, features)
|
LabeledPoint(label, features)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validates that number of classes is greater than zero.
|
||||||
|
*
|
||||||
|
* @param numClasses Number of classes label can take.
|
||||||
|
*/
|
||||||
|
protected def validateNumClasses(numClasses: Int): Unit = {
|
||||||
|
require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
|
||||||
|
s" $numClasses, but requires numClasses > 0.")
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validates the label on the classifier is a valid integer in the range [0, numClasses).
|
||||||
|
*
|
||||||
|
* @param label The label to validate.
|
||||||
|
* @param numClasses Number of classes label can take. Labels must be integers in the range
|
||||||
|
* [0, numClasses).
|
||||||
|
*/
|
||||||
|
protected def validateLabel(label: Double, numClasses: Int): Unit = {
|
||||||
|
require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
|
||||||
|
s" dataset with invalid label $label. Labels must be integers in range" +
|
||||||
|
s" [0, $numClasses).")
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the number of classes. This looks in column metadata first, and if that is missing,
|
* Get the number of classes. This looks in column metadata first, and if that is missing,
|
||||||
* then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
|
* then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
|
||||||
|
|
|
@ -22,10 +22,12 @@ import org.json4s.{DefaultFormats, JObject}
|
||||||
import org.json4s.JsonDSL._
|
import org.json4s.JsonDSL._
|
||||||
|
|
||||||
import org.apache.spark.annotation.Since
|
import org.apache.spark.annotation.Since
|
||||||
import org.apache.spark.ml.feature.LabeledPoint
|
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
|
||||||
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
|
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
|
||||||
import org.apache.spark.ml.param.ParamMap
|
import org.apache.spark.ml.param.ParamMap
|
||||||
|
import org.apache.spark.ml.param.shared.HasWeightCol
|
||||||
import org.apache.spark.ml.tree._
|
import org.apache.spark.ml.tree._
|
||||||
|
import org.apache.spark.ml.tree.{DecisionTreeModel, Node, TreeClassifierParams}
|
||||||
import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
|
import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
|
||||||
import org.apache.spark.ml.tree.impl.RandomForest
|
import org.apache.spark.ml.tree.impl.RandomForest
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
|
@ -33,8 +35,9 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
|
||||||
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
|
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
|
||||||
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
|
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.Dataset
|
import org.apache.spark.sql.{Dataset, Row}
|
||||||
|
import org.apache.spark.sql.functions.{col, lit}
|
||||||
|
import org.apache.spark.sql.types.DoubleType
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
|
* Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
|
||||||
|
@ -66,6 +69,9 @@ class DecisionTreeClassifier @Since("1.4.0") (
|
||||||
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
|
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
|
||||||
|
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
|
@Since("3.0.0")
|
||||||
|
def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
|
||||||
|
|
||||||
@Since("1.4.0")
|
@Since("1.4.0")
|
||||||
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
|
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
|
||||||
|
|
||||||
|
@ -97,6 +103,16 @@ class DecisionTreeClassifier @Since("1.4.0") (
|
||||||
@Since("1.6.0")
|
@Since("1.6.0")
|
||||||
def setSeed(value: Long): this.type = set(seed, value)
|
def setSeed(value: Long): this.type = set(seed, value)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the value of param [[weightCol]].
|
||||||
|
* If this is not set or empty, we treat all instance weights as 1.0.
|
||||||
|
* Default is not set, so all instances have weight one.
|
||||||
|
*
|
||||||
|
* @group setParam
|
||||||
|
*/
|
||||||
|
@Since("3.0.0")
|
||||||
|
def setWeightCol(value: String): this.type = set(weightCol, value)
|
||||||
|
|
||||||
override protected def train(
|
override protected def train(
|
||||||
dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr =>
|
dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr =>
|
||||||
instr.logPipelineStage(this)
|
instr.logPipelineStage(this)
|
||||||
|
@ -104,22 +120,27 @@ class DecisionTreeClassifier @Since("1.4.0") (
|
||||||
val categoricalFeatures: Map[Int, Int] =
|
val categoricalFeatures: Map[Int, Int] =
|
||||||
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
||||||
val numClasses: Int = getNumClasses(dataset)
|
val numClasses: Int = getNumClasses(dataset)
|
||||||
instr.logNumClasses(numClasses)
|
|
||||||
|
|
||||||
if (isDefined(thresholds)) {
|
if (isDefined(thresholds)) {
|
||||||
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
|
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
|
||||||
".train() called with non-matching numClasses and thresholds.length." +
|
".train() called with non-matching numClasses and thresholds.length." +
|
||||||
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
|
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
|
||||||
}
|
}
|
||||||
|
validateNumClasses(numClasses)
|
||||||
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
|
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
|
||||||
|
val instances =
|
||||||
|
dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
|
||||||
|
case Row(label: Double, weight: Double, features: Vector) =>
|
||||||
|
validateLabel(label, numClasses)
|
||||||
|
Instance(label, weight, features)
|
||||||
|
}
|
||||||
val strategy = getOldStrategy(categoricalFeatures, numClasses)
|
val strategy = getOldStrategy(categoricalFeatures, numClasses)
|
||||||
|
instr.logNumClasses(numClasses)
|
||||||
instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol,
|
instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol,
|
||||||
probabilityCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
|
probabilityCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
|
||||||
cacheNodeIds, checkpointInterval, impurity, seed)
|
cacheNodeIds, checkpointInterval, impurity, seed)
|
||||||
|
|
||||||
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
|
val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
|
||||||
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
|
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
|
||||||
|
|
||||||
trees.head.asInstanceOf[DecisionTreeClassificationModel]
|
trees.head.asInstanceOf[DecisionTreeClassificationModel]
|
||||||
|
@ -128,13 +149,13 @@ class DecisionTreeClassifier @Since("1.4.0") (
|
||||||
/** (private[ml]) Train a decision tree on an RDD */
|
/** (private[ml]) Train a decision tree on an RDD */
|
||||||
private[ml] def train(data: RDD[LabeledPoint],
|
private[ml] def train(data: RDD[LabeledPoint],
|
||||||
oldStrategy: OldStrategy): DecisionTreeClassificationModel = instrumented { instr =>
|
oldStrategy: OldStrategy): DecisionTreeClassificationModel = instrumented { instr =>
|
||||||
|
val instances = data.map(_.toInstance)
|
||||||
instr.logPipelineStage(this)
|
instr.logPipelineStage(this)
|
||||||
instr.logDataset(data)
|
instr.logDataset(instances)
|
||||||
instr.logParams(this, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
|
instr.logParams(this, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
|
||||||
cacheNodeIds, checkpointInterval, impurity, seed)
|
cacheNodeIds, checkpointInterval, impurity, seed)
|
||||||
|
val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
|
||||||
val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
|
featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID = Some(uid))
|
||||||
seed = 0L, instr = Some(instr), parentUID = Some(uid))
|
|
||||||
|
|
||||||
trees.head.asInstanceOf[DecisionTreeClassificationModel]
|
trees.head.asInstanceOf[DecisionTreeClassificationModel]
|
||||||
}
|
}
|
||||||
|
@ -180,6 +201,7 @@ class DecisionTreeClassificationModel private[ml] (
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct a decision tree classification model.
|
* Construct a decision tree classification model.
|
||||||
|
*
|
||||||
* @param rootNode Root node of tree, with other nodes attached.
|
* @param rootNode Root node of tree, with other nodes attached.
|
||||||
*/
|
*/
|
||||||
private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
|
private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
|
||||||
|
|
|
@ -21,20 +21,21 @@ import org.json4s.{DefaultFormats, JObject}
|
||||||
import org.json4s.JsonDSL._
|
import org.json4s.JsonDSL._
|
||||||
|
|
||||||
import org.apache.spark.annotation.Since
|
import org.apache.spark.annotation.Since
|
||||||
import org.apache.spark.ml.feature.LabeledPoint
|
import org.apache.spark.ml.feature.Instance
|
||||||
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
|
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
|
||||||
import org.apache.spark.ml.param.ParamMap
|
import org.apache.spark.ml.param.ParamMap
|
||||||
import org.apache.spark.ml.tree._
|
import org.apache.spark.ml.tree._
|
||||||
|
import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
|
||||||
import org.apache.spark.ml.tree.impl.RandomForest
|
import org.apache.spark.ml.tree.impl.RandomForest
|
||||||
import org.apache.spark.ml.util._
|
import org.apache.spark.ml.util._
|
||||||
|
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
|
||||||
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
|
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
|
||||||
import org.apache.spark.ml.util.Instrumentation.instrumented
|
import org.apache.spark.ml.util.Instrumentation.instrumented
|
||||||
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
|
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
|
||||||
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
|
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.{DataFrame, Dataset}
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions.{col, udf}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a> learning algorithm for
|
* <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a> learning algorithm for
|
||||||
|
@ -130,7 +131,7 @@ class RandomForestClassifier @Since("1.4.0") (
|
||||||
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
|
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
|
||||||
}
|
}
|
||||||
|
|
||||||
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
|
val instances: RDD[Instance] = extractLabeledPoints(dataset, numClasses).map(_.toInstance)
|
||||||
val strategy =
|
val strategy =
|
||||||
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
|
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
|
||||||
|
|
||||||
|
@ -139,7 +140,7 @@ class RandomForestClassifier @Since("1.4.0") (
|
||||||
minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)
|
minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)
|
||||||
|
|
||||||
val trees = RandomForest
|
val trees = RandomForest
|
||||||
.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
|
.run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
|
||||||
.map(_.asInstanceOf[DecisionTreeClassificationModel])
|
.map(_.asInstanceOf[DecisionTreeClassificationModel])
|
||||||
|
|
||||||
val numFeatures = trees.head.numFeatures
|
val numFeatures = trees.head.numFeatures
|
||||||
|
|
|
@ -37,4 +37,13 @@ case class LabeledPoint(@Since("2.0.0") label: Double, @Since("2.0.0") features:
|
||||||
override def toString: String = {
|
override def toString: String = {
|
||||||
s"($label,$features)"
|
s"($label,$features)"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private[spark] def toInstance(weight: Double): Instance = {
|
||||||
|
Instance(label, weight, features)
|
||||||
|
}
|
||||||
|
|
||||||
|
private[spark] def toInstance: Instance = {
|
||||||
|
Instance(label, 1.0, features)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,9 +23,10 @@ import org.json4s.JsonDSL._
|
||||||
|
|
||||||
import org.apache.spark.annotation.Since
|
import org.apache.spark.annotation.Since
|
||||||
import org.apache.spark.ml.{PredictionModel, Predictor}
|
import org.apache.spark.ml.{PredictionModel, Predictor}
|
||||||
import org.apache.spark.ml.feature.LabeledPoint
|
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
|
||||||
import org.apache.spark.ml.linalg.Vector
|
import org.apache.spark.ml.linalg.Vector
|
||||||
import org.apache.spark.ml.param.ParamMap
|
import org.apache.spark.ml.param.ParamMap
|
||||||
|
import org.apache.spark.ml.param.shared.HasWeightCol
|
||||||
import org.apache.spark.ml.tree._
|
import org.apache.spark.ml.tree._
|
||||||
import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
|
import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
|
||||||
import org.apache.spark.ml.tree.impl.RandomForest
|
import org.apache.spark.ml.tree.impl.RandomForest
|
||||||
|
@ -34,8 +35,9 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
|
||||||
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
|
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
|
||||||
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
|
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.{DataFrame, Dataset}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
|
import org.apache.spark.sql.types.DoubleType
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -65,6 +67,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
|
||||||
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
|
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
|
||||||
|
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
|
@Since("3.0.0")
|
||||||
|
def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
|
||||||
|
|
||||||
@Since("1.4.0")
|
@Since("1.4.0")
|
||||||
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
|
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
|
||||||
|
|
||||||
|
@ -100,18 +105,33 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
|
||||||
@Since("2.0.0")
|
@Since("2.0.0")
|
||||||
def setVarianceCol(value: String): this.type = set(varianceCol, value)
|
def setVarianceCol(value: String): this.type = set(varianceCol, value)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the value of param [[weightCol]].
|
||||||
|
* If this is not set or empty, we treat all instance weights as 1.0.
|
||||||
|
* Default is not set, so all instances have weight one.
|
||||||
|
*
|
||||||
|
* @group setParam
|
||||||
|
*/
|
||||||
|
@Since("3.0.0")
|
||||||
|
def setWeightCol(value: String): this.type = set(weightCol, value)
|
||||||
|
|
||||||
override protected def train(
|
override protected def train(
|
||||||
dataset: Dataset[_]): DecisionTreeRegressionModel = instrumented { instr =>
|
dataset: Dataset[_]): DecisionTreeRegressionModel = instrumented { instr =>
|
||||||
val categoricalFeatures: Map[Int, Int] =
|
val categoricalFeatures: Map[Int, Int] =
|
||||||
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
||||||
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
|
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
|
||||||
|
val instances =
|
||||||
|
dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
|
||||||
|
case Row(label: Double, weight: Double, features: Vector) =>
|
||||||
|
Instance(label, weight, features)
|
||||||
|
}
|
||||||
val strategy = getOldStrategy(categoricalFeatures)
|
val strategy = getOldStrategy(categoricalFeatures)
|
||||||
|
|
||||||
instr.logPipelineStage(this)
|
instr.logPipelineStage(this)
|
||||||
instr.logDataset(oldDataset)
|
instr.logDataset(instances)
|
||||||
instr.logParams(this, params: _*)
|
instr.logParams(this, params: _*)
|
||||||
|
|
||||||
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
|
val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
|
||||||
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
|
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
|
||||||
|
|
||||||
trees.head.asInstanceOf[DecisionTreeRegressionModel]
|
trees.head.asInstanceOf[DecisionTreeRegressionModel]
|
||||||
|
@ -126,8 +146,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
|
||||||
instr.logDataset(data)
|
instr.logDataset(data)
|
||||||
instr.logParams(this, params: _*)
|
instr.logParams(this, params: _*)
|
||||||
|
|
||||||
val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy,
|
val instances = data.map(_.toInstance)
|
||||||
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
|
val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
|
||||||
|
featureSubsetStrategy, seed = $(seed), instr = Some(instr), parentUID = Some(uid))
|
||||||
|
|
||||||
trees.head.asInstanceOf[DecisionTreeRegressionModel]
|
trees.head.asInstanceOf[DecisionTreeRegressionModel]
|
||||||
}
|
}
|
||||||
|
@ -155,6 +176,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor
|
||||||
* <a href="http://en.wikipedia.org/wiki/Decision_tree_learning">
|
* <a href="http://en.wikipedia.org/wiki/Decision_tree_learning">
|
||||||
* Decision tree (Wikipedia)</a> model for regression.
|
* Decision tree (Wikipedia)</a> model for regression.
|
||||||
* It supports both continuous and categorical features.
|
* It supports both continuous and categorical features.
|
||||||
|
*
|
||||||
* @param rootNode Root of the decision tree
|
* @param rootNode Root of the decision tree
|
||||||
*/
|
*/
|
||||||
@Since("1.4.0")
|
@Since("1.4.0")
|
||||||
|
@ -173,6 +195,7 @@ class DecisionTreeRegressionModel private[ml] (
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct a decision tree regression model.
|
* Construct a decision tree regression model.
|
||||||
|
*
|
||||||
* @param rootNode Root node of tree, with other nodes attached.
|
* @param rootNode Root node of tree, with other nodes attached.
|
||||||
*/
|
*/
|
||||||
private[ml] def this(rootNode: Node, numFeatures: Int) =
|
private[ml] def this(rootNode: Node, numFeatures: Int) =
|
||||||
|
|
|
@ -22,7 +22,6 @@ import org.json4s.JsonDSL._
|
||||||
|
|
||||||
import org.apache.spark.annotation.Since
|
import org.apache.spark.annotation.Since
|
||||||
import org.apache.spark.ml.{PredictionModel, Predictor}
|
import org.apache.spark.ml.{PredictionModel, Predictor}
|
||||||
import org.apache.spark.ml.feature.LabeledPoint
|
|
||||||
import org.apache.spark.ml.linalg.Vector
|
import org.apache.spark.ml.linalg.Vector
|
||||||
import org.apache.spark.ml.param.ParamMap
|
import org.apache.spark.ml.param.ParamMap
|
||||||
import org.apache.spark.ml.tree._
|
import org.apache.spark.ml.tree._
|
||||||
|
@ -32,10 +31,8 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
|
||||||
import org.apache.spark.ml.util.Instrumentation.instrumented
|
import org.apache.spark.ml.util.Instrumentation.instrumented
|
||||||
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
|
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
|
||||||
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
|
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
|
||||||
import org.apache.spark.rdd.RDD
|
|
||||||
import org.apache.spark.sql.{DataFrame, Dataset}
|
import org.apache.spark.sql.{DataFrame, Dataset}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions.{col, udf}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a>
|
* <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a>
|
||||||
|
@ -119,18 +116,19 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
|
||||||
dataset: Dataset[_]): RandomForestRegressionModel = instrumented { instr =>
|
dataset: Dataset[_]): RandomForestRegressionModel = instrumented { instr =>
|
||||||
val categoricalFeatures: Map[Int, Int] =
|
val categoricalFeatures: Map[Int, Int] =
|
||||||
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
||||||
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
|
|
||||||
|
val instances = extractLabeledPoints(dataset).map(_.toInstance)
|
||||||
val strategy =
|
val strategy =
|
||||||
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
|
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
|
||||||
|
|
||||||
instr.logPipelineStage(this)
|
instr.logPipelineStage(this)
|
||||||
instr.logDataset(dataset)
|
instr.logDataset(instances)
|
||||||
instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, numTrees,
|
instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, numTrees,
|
||||||
featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
|
featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
|
||||||
minInstancesPerNode, seed, subsamplingRate, cacheNodeIds, checkpointInterval)
|
minInstancesPerNode, seed, subsamplingRate, cacheNodeIds, checkpointInterval)
|
||||||
|
|
||||||
val trees = RandomForest
|
val trees = RandomForest
|
||||||
.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
|
.run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
|
||||||
.map(_.asInstanceOf[DecisionTreeRegressionModel])
|
.map(_.asInstanceOf[DecisionTreeRegressionModel])
|
||||||
|
|
||||||
val numFeatures = trees.head.numFeatures
|
val numFeatures = trees.head.numFeatures
|
||||||
|
|
|
@ -33,13 +33,13 @@ import org.apache.spark.util.random.XORShiftRandom
|
||||||
* this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively.
|
* this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively.
|
||||||
*
|
*
|
||||||
* @param datum Data instance
|
* @param datum Data instance
|
||||||
* @param subsampleWeights Weight of this instance in each subsampled dataset.
|
* @param subsampleCounts Number of samples of this instance in each subsampled dataset.
|
||||||
*
|
* @param sampleWeight The weight of this instance.
|
||||||
* TODO: This does not currently support (Double) weighted instances. Once MLlib has weighted
|
|
||||||
* dataset support, update. (We store subsampleWeights as Double for this future extension.)
|
|
||||||
*/
|
*/
|
||||||
private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
|
private[spark] class BaggedPoint[Datum](
|
||||||
extends Serializable
|
val datum: Datum,
|
||||||
|
val subsampleCounts: Array[Int],
|
||||||
|
val sampleWeight: Double = 1.0) extends Serializable
|
||||||
|
|
||||||
private[spark] object BaggedPoint {
|
private[spark] object BaggedPoint {
|
||||||
|
|
||||||
|
@ -52,6 +52,7 @@ private[spark] object BaggedPoint {
|
||||||
* @param subsamplingRate Fraction of the training data used for learning decision tree.
|
* @param subsamplingRate Fraction of the training data used for learning decision tree.
|
||||||
* @param numSubsamples Number of subsamples of this RDD to take.
|
* @param numSubsamples Number of subsamples of this RDD to take.
|
||||||
* @param withReplacement Sampling with/without replacement.
|
* @param withReplacement Sampling with/without replacement.
|
||||||
|
* @param extractSampleWeight A function to get the sample weight of each datum.
|
||||||
* @param seed Random seed.
|
* @param seed Random seed.
|
||||||
* @return BaggedPoint dataset representation.
|
* @return BaggedPoint dataset representation.
|
||||||
*/
|
*/
|
||||||
|
@ -60,12 +61,14 @@ private[spark] object BaggedPoint {
|
||||||
subsamplingRate: Double,
|
subsamplingRate: Double,
|
||||||
numSubsamples: Int,
|
numSubsamples: Int,
|
||||||
withReplacement: Boolean,
|
withReplacement: Boolean,
|
||||||
|
extractSampleWeight: (Datum => Double) = (_: Datum) => 1.0,
|
||||||
seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
|
seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
|
||||||
|
// TODO: implement weighted bootstrapping
|
||||||
if (withReplacement) {
|
if (withReplacement) {
|
||||||
convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
|
convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
|
||||||
} else {
|
} else {
|
||||||
if (numSubsamples == 1 && subsamplingRate == 1.0) {
|
if (numSubsamples == 1 && subsamplingRate == 1.0) {
|
||||||
convertToBaggedRDDWithoutSampling(input)
|
convertToBaggedRDDWithoutSampling(input, extractSampleWeight)
|
||||||
} else {
|
} else {
|
||||||
convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
|
convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
|
||||||
}
|
}
|
||||||
|
@ -82,16 +85,15 @@ private[spark] object BaggedPoint {
|
||||||
val rng = new XORShiftRandom
|
val rng = new XORShiftRandom
|
||||||
rng.setSeed(seed + partitionIndex + 1)
|
rng.setSeed(seed + partitionIndex + 1)
|
||||||
instances.map { instance =>
|
instances.map { instance =>
|
||||||
val subsampleWeights = new Array[Double](numSubsamples)
|
val subsampleCounts = new Array[Int](numSubsamples)
|
||||||
var subsampleIndex = 0
|
var subsampleIndex = 0
|
||||||
while (subsampleIndex < numSubsamples) {
|
while (subsampleIndex < numSubsamples) {
|
||||||
val x = rng.nextDouble()
|
if (rng.nextDouble() < subsamplingRate) {
|
||||||
subsampleWeights(subsampleIndex) = {
|
subsampleCounts(subsampleIndex) = 1
|
||||||
if (x < subsamplingRate) 1.0 else 0.0
|
|
||||||
}
|
}
|
||||||
subsampleIndex += 1
|
subsampleIndex += 1
|
||||||
}
|
}
|
||||||
new BaggedPoint(instance, subsampleWeights)
|
new BaggedPoint(instance, subsampleCounts)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -106,20 +108,20 @@ private[spark] object BaggedPoint {
|
||||||
val poisson = new PoissonDistribution(subsample)
|
val poisson = new PoissonDistribution(subsample)
|
||||||
poisson.reseedRandomGenerator(seed + partitionIndex + 1)
|
poisson.reseedRandomGenerator(seed + partitionIndex + 1)
|
||||||
instances.map { instance =>
|
instances.map { instance =>
|
||||||
val subsampleWeights = new Array[Double](numSubsamples)
|
val subsampleCounts = new Array[Int](numSubsamples)
|
||||||
var subsampleIndex = 0
|
var subsampleIndex = 0
|
||||||
while (subsampleIndex < numSubsamples) {
|
while (subsampleIndex < numSubsamples) {
|
||||||
subsampleWeights(subsampleIndex) = poisson.sample()
|
subsampleCounts(subsampleIndex) = poisson.sample()
|
||||||
subsampleIndex += 1
|
subsampleIndex += 1
|
||||||
}
|
}
|
||||||
new BaggedPoint(instance, subsampleWeights)
|
new BaggedPoint(instance, subsampleCounts)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private def convertToBaggedRDDWithoutSampling[Datum] (
|
private def convertToBaggedRDDWithoutSampling[Datum] (
|
||||||
input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
|
input: RDD[Datum],
|
||||||
input.map(datum => new BaggedPoint(datum, Array(1.0)))
|
extractSampleWeight: (Datum => Double)): RDD[BaggedPoint[Datum]] = {
|
||||||
|
input.map(datum => new BaggedPoint(datum, Array(1), extractSampleWeight(datum)))
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -104,16 +104,21 @@ private[spark] class DTStatsAggregator(
|
||||||
/**
|
/**
|
||||||
* Update the stats for a given (feature, bin) for ordered features, using the given label.
|
* Update the stats for a given (feature, bin) for ordered features, using the given label.
|
||||||
*/
|
*/
|
||||||
def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
|
def update(
|
||||||
|
featureIndex: Int,
|
||||||
|
binIndex: Int,
|
||||||
|
label: Double,
|
||||||
|
numSamples: Int,
|
||||||
|
sampleWeight: Double): Unit = {
|
||||||
val i = featureOffsets(featureIndex) + binIndex * statsSize
|
val i = featureOffsets(featureIndex) + binIndex * statsSize
|
||||||
impurityAggregator.update(allStats, i, label, instanceWeight)
|
impurityAggregator.update(allStats, i, label, numSamples, sampleWeight)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Update the parent node stats using the given label.
|
* Update the parent node stats using the given label.
|
||||||
*/
|
*/
|
||||||
def updateParent(label: Double, instanceWeight: Double): Unit = {
|
def updateParent(label: Double, numSamples: Int, sampleWeight: Double): Unit = {
|
||||||
impurityAggregator.update(parentStats, 0, label, instanceWeight)
|
impurityAggregator.update(parentStats, 0, label, numSamples, sampleWeight)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -127,9 +132,10 @@ private[spark] class DTStatsAggregator(
|
||||||
featureOffset: Int,
|
featureOffset: Int,
|
||||||
binIndex: Int,
|
binIndex: Int,
|
||||||
label: Double,
|
label: Double,
|
||||||
instanceWeight: Double): Unit = {
|
numSamples: Int,
|
||||||
|
sampleWeight: Double): Unit = {
|
||||||
impurityAggregator.update(allStats, featureOffset + binIndex * statsSize,
|
impurityAggregator.update(allStats, featureOffset + binIndex * statsSize,
|
||||||
label, instanceWeight)
|
label, numSamples, sampleWeight)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -21,7 +21,7 @@ import scala.collection.mutable
|
||||||
import scala.util.Try
|
import scala.util.Try
|
||||||
|
|
||||||
import org.apache.spark.internal.Logging
|
import org.apache.spark.internal.Logging
|
||||||
import org.apache.spark.ml.feature.LabeledPoint
|
import org.apache.spark.ml.feature.Instance
|
||||||
import org.apache.spark.ml.tree.TreeEnsembleParams
|
import org.apache.spark.ml.tree.TreeEnsembleParams
|
||||||
import org.apache.spark.mllib.tree.configuration.Algo._
|
import org.apache.spark.mllib.tree.configuration.Algo._
|
||||||
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
|
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
|
||||||
|
@ -32,16 +32,20 @@ import org.apache.spark.rdd.RDD
|
||||||
/**
|
/**
|
||||||
* Learning and dataset metadata for DecisionTree.
|
* Learning and dataset metadata for DecisionTree.
|
||||||
*
|
*
|
||||||
|
* @param weightedNumExamples Weighted count of samples in the tree.
|
||||||
* @param numClasses For classification: labels can take values {0, ..., numClasses - 1}.
|
* @param numClasses For classification: labels can take values {0, ..., numClasses - 1}.
|
||||||
* For regression: fixed at 0 (no meaning).
|
* For regression: fixed at 0 (no meaning).
|
||||||
* @param maxBins Maximum number of bins, for all features.
|
* @param maxBins Maximum number of bins, for all features.
|
||||||
* @param featureArity Map: categorical feature index to arity.
|
* @param featureArity Map: categorical feature index to arity.
|
||||||
* I.e., the feature takes values in {0, ..., arity - 1}.
|
* I.e., the feature takes values in {0, ..., arity - 1}.
|
||||||
* @param numBins Number of bins for each feature.
|
* @param numBins Number of bins for each feature.
|
||||||
|
* @param minWeightFractionPerNode The minimum fraction of the total sample weight that must be
|
||||||
|
* present in a leaf node in order to be considered a valid split.
|
||||||
*/
|
*/
|
||||||
private[spark] class DecisionTreeMetadata(
|
private[spark] class DecisionTreeMetadata(
|
||||||
val numFeatures: Int,
|
val numFeatures: Int,
|
||||||
val numExamples: Long,
|
val numExamples: Long,
|
||||||
|
val weightedNumExamples: Double,
|
||||||
val numClasses: Int,
|
val numClasses: Int,
|
||||||
val maxBins: Int,
|
val maxBins: Int,
|
||||||
val featureArity: Map[Int, Int],
|
val featureArity: Map[Int, Int],
|
||||||
|
@ -51,6 +55,7 @@ private[spark] class DecisionTreeMetadata(
|
||||||
val quantileStrategy: QuantileStrategy,
|
val quantileStrategy: QuantileStrategy,
|
||||||
val maxDepth: Int,
|
val maxDepth: Int,
|
||||||
val minInstancesPerNode: Int,
|
val minInstancesPerNode: Int,
|
||||||
|
val minWeightFractionPerNode: Double,
|
||||||
val minInfoGain: Double,
|
val minInfoGain: Double,
|
||||||
val numTrees: Int,
|
val numTrees: Int,
|
||||||
val numFeaturesPerNode: Int) extends Serializable {
|
val numFeaturesPerNode: Int) extends Serializable {
|
||||||
|
@ -67,6 +72,8 @@ private[spark] class DecisionTreeMetadata(
|
||||||
|
|
||||||
def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)
|
def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)
|
||||||
|
|
||||||
|
def minWeightPerNode: Double = minWeightFractionPerNode * weightedNumExamples
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Number of splits for the given feature.
|
* Number of splits for the given feature.
|
||||||
* For unordered features, there is 1 bin per split.
|
* For unordered features, there is 1 bin per split.
|
||||||
|
@ -104,7 +111,7 @@ private[spark] object DecisionTreeMetadata extends Logging {
|
||||||
* as well as the number of splits and bins for each feature.
|
* as well as the number of splits and bins for each feature.
|
||||||
*/
|
*/
|
||||||
def buildMetadata(
|
def buildMetadata(
|
||||||
input: RDD[LabeledPoint],
|
input: RDD[Instance],
|
||||||
strategy: Strategy,
|
strategy: Strategy,
|
||||||
numTrees: Int,
|
numTrees: Int,
|
||||||
featureSubsetStrategy: String): DecisionTreeMetadata = {
|
featureSubsetStrategy: String): DecisionTreeMetadata = {
|
||||||
|
@ -115,7 +122,11 @@ private[spark] object DecisionTreeMetadata extends Logging {
|
||||||
}
|
}
|
||||||
require(numFeatures > 0, s"DecisionTree requires number of features > 0, " +
|
require(numFeatures > 0, s"DecisionTree requires number of features > 0, " +
|
||||||
s"but was given an empty features vector")
|
s"but was given an empty features vector")
|
||||||
val numExamples = input.count()
|
val (numExamples, weightSum) = input.aggregate((0L, 0.0))(
|
||||||
|
seqOp = (cw, instance) => (cw._1 + 1L, cw._2 + instance.weight),
|
||||||
|
combOp = (cw1, cw2) => (cw1._1 + cw2._1, cw1._2 + cw2._2)
|
||||||
|
)
|
||||||
|
|
||||||
val numClasses = strategy.algo match {
|
val numClasses = strategy.algo match {
|
||||||
case Classification => strategy.numClasses
|
case Classification => strategy.numClasses
|
||||||
case Regression => 0
|
case Regression => 0
|
||||||
|
@ -206,17 +217,18 @@ private[spark] object DecisionTreeMetadata extends Logging {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
|
new DecisionTreeMetadata(numFeatures, numExamples, weightSum, numClasses,
|
||||||
strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
|
numBins.max, strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
|
||||||
strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
|
strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
|
||||||
strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode)
|
strategy.minInstancesPerNode, strategy.minWeightFractionPerNode, strategy.minInfoGain,
|
||||||
|
numTrees, numFeaturesPerNode)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Version of [[DecisionTreeMetadata#buildMetadata]] for DecisionTree.
|
* Version of [[DecisionTreeMetadata#buildMetadata]] for DecisionTree.
|
||||||
*/
|
*/
|
||||||
def buildMetadata(
|
def buildMetadata(
|
||||||
input: RDD[LabeledPoint],
|
input: RDD[Instance],
|
||||||
strategy: Strategy): DecisionTreeMetadata = {
|
strategy: Strategy): DecisionTreeMetadata = {
|
||||||
buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all")
|
buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all")
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,10 +24,12 @@ import scala.util.Random
|
||||||
|
|
||||||
import org.apache.spark.internal.Logging
|
import org.apache.spark.internal.Logging
|
||||||
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
|
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
|
||||||
import org.apache.spark.ml.feature.LabeledPoint
|
import org.apache.spark.ml.feature.Instance
|
||||||
|
import org.apache.spark.ml.impl.Utils
|
||||||
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
|
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
|
||||||
import org.apache.spark.ml.tree._
|
import org.apache.spark.ml.tree._
|
||||||
import org.apache.spark.ml.util.Instrumentation
|
import org.apache.spark.ml.util.Instrumentation
|
||||||
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
|
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
|
||||||
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
|
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
|
||||||
import org.apache.spark.mllib.tree.model.ImpurityStats
|
import org.apache.spark.mllib.tree.model.ImpurityStats
|
||||||
|
@ -90,6 +92,24 @@ private[spark] object RandomForest extends Logging with Serializable {
|
||||||
strategy: OldStrategy,
|
strategy: OldStrategy,
|
||||||
numTrees: Int,
|
numTrees: Int,
|
||||||
featureSubsetStrategy: String,
|
featureSubsetStrategy: String,
|
||||||
|
seed: Long): Array[DecisionTreeModel] = {
|
||||||
|
val instances = input.map { case LabeledPoint(label, features) =>
|
||||||
|
Instance(label, 1.0, features.asML)
|
||||||
|
}
|
||||||
|
run(instances, strategy, numTrees, featureSubsetStrategy, seed, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Train a random forest.
|
||||||
|
*
|
||||||
|
* @param input Training data: RDD of `Instance`
|
||||||
|
* @return an unweighted set of trees
|
||||||
|
*/
|
||||||
|
def run(
|
||||||
|
input: RDD[Instance],
|
||||||
|
strategy: OldStrategy,
|
||||||
|
numTrees: Int,
|
||||||
|
featureSubsetStrategy: String,
|
||||||
seed: Long,
|
seed: Long,
|
||||||
instr: Option[Instrumentation],
|
instr: Option[Instrumentation],
|
||||||
prune: Boolean = true, // exposed for testing only, real trees are always pruned
|
prune: Boolean = true, // exposed for testing only, real trees are always pruned
|
||||||
|
@ -101,9 +121,10 @@ private[spark] object RandomForest extends Logging with Serializable {
|
||||||
|
|
||||||
timer.start("init")
|
timer.start("init")
|
||||||
|
|
||||||
val retaggedInput = input.retag(classOf[LabeledPoint])
|
val retaggedInput = input.retag(classOf[Instance])
|
||||||
val metadata =
|
val metadata =
|
||||||
DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
|
DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
|
||||||
|
|
||||||
instr match {
|
instr match {
|
||||||
case Some(instrumentation) =>
|
case Some(instrumentation) =>
|
||||||
instrumentation.logNumFeatures(metadata.numFeatures)
|
instrumentation.logNumFeatures(metadata.numFeatures)
|
||||||
|
@ -132,7 +153,8 @@ private[spark] object RandomForest extends Logging with Serializable {
|
||||||
val withReplacement = numTrees > 1
|
val withReplacement = numTrees > 1
|
||||||
|
|
||||||
val baggedInput = BaggedPoint
|
val baggedInput = BaggedPoint
|
||||||
.convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, seed)
|
.convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement,
|
||||||
|
(tp: TreePoint) => tp.weight, seed = seed)
|
||||||
.persist(StorageLevel.MEMORY_AND_DISK)
|
.persist(StorageLevel.MEMORY_AND_DISK)
|
||||||
|
|
||||||
// depth of the decision tree
|
// depth of the decision tree
|
||||||
|
@ -254,19 +276,21 @@ private[spark] object RandomForest extends Logging with Serializable {
|
||||||
* For unordered features, bins correspond to subsets of categories; either the left or right bin
|
* For unordered features, bins correspond to subsets of categories; either the left or right bin
|
||||||
* for each subset is updated.
|
* for each subset is updated.
|
||||||
*
|
*
|
||||||
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
|
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
|
||||||
* each (feature, bin).
|
* each (feature, bin).
|
||||||
* @param treePoint Data point being aggregated.
|
* @param treePoint Data point being aggregated.
|
||||||
* @param splits possible splits indexed (numFeatures)(numSplits)
|
* @param splits Possible splits indexed (numFeatures)(numSplits)
|
||||||
* @param unorderedFeatures Set of indices of unordered features.
|
* @param unorderedFeatures Set of indices of unordered features.
|
||||||
* @param instanceWeight Weight (importance) of instance in dataset.
|
* @param numSamples Number of times this instance occurs in the sample.
|
||||||
|
* @param sampleWeight Weight (importance) of instance in dataset.
|
||||||
*/
|
*/
|
||||||
private def mixedBinSeqOp(
|
private def mixedBinSeqOp(
|
||||||
agg: DTStatsAggregator,
|
agg: DTStatsAggregator,
|
||||||
treePoint: TreePoint,
|
treePoint: TreePoint,
|
||||||
splits: Array[Array[Split]],
|
splits: Array[Array[Split]],
|
||||||
unorderedFeatures: Set[Int],
|
unorderedFeatures: Set[Int],
|
||||||
instanceWeight: Double,
|
numSamples: Int,
|
||||||
|
sampleWeight: Double,
|
||||||
featuresForNode: Option[Array[Int]]): Unit = {
|
featuresForNode: Option[Array[Int]]): Unit = {
|
||||||
val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
|
val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
|
||||||
// Use subsampled features
|
// Use subsampled features
|
||||||
|
@ -293,14 +317,15 @@ private[spark] object RandomForest extends Logging with Serializable {
|
||||||
var splitIndex = 0
|
var splitIndex = 0
|
||||||
while (splitIndex < numSplits) {
|
while (splitIndex < numSplits) {
|
||||||
if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
|
if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
|
||||||
agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
|
agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, numSamples,
|
||||||
|
sampleWeight)
|
||||||
}
|
}
|
||||||
splitIndex += 1
|
splitIndex += 1
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Ordered feature
|
// Ordered feature
|
||||||
val binIndex = treePoint.binnedFeatures(featureIndex)
|
val binIndex = treePoint.binnedFeatures(featureIndex)
|
||||||
agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
|
agg.update(featureIndexIdx, binIndex, treePoint.label, numSamples, sampleWeight)
|
||||||
}
|
}
|
||||||
featureIndexIdx += 1
|
featureIndexIdx += 1
|
||||||
}
|
}
|
||||||
|
@ -314,12 +339,14 @@ private[spark] object RandomForest extends Logging with Serializable {
|
||||||
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
|
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
|
||||||
* each (feature, bin).
|
* each (feature, bin).
|
||||||
* @param treePoint Data point being aggregated.
|
* @param treePoint Data point being aggregated.
|
||||||
* @param instanceWeight Weight (importance) of instance in dataset.
|
* @param numSamples Number of times this instance occurs in the sample.
|
||||||
|
* @param sampleWeight Weight (importance) of instance in dataset.
|
||||||
*/
|
*/
|
||||||
private def orderedBinSeqOp(
|
private def orderedBinSeqOp(
|
||||||
agg: DTStatsAggregator,
|
agg: DTStatsAggregator,
|
||||||
treePoint: TreePoint,
|
treePoint: TreePoint,
|
||||||
instanceWeight: Double,
|
numSamples: Int,
|
||||||
|
sampleWeight: Double,
|
||||||
featuresForNode: Option[Array[Int]]): Unit = {
|
featuresForNode: Option[Array[Int]]): Unit = {
|
||||||
val label = treePoint.label
|
val label = treePoint.label
|
||||||
|
|
||||||
|
@ -329,7 +356,7 @@ private[spark] object RandomForest extends Logging with Serializable {
|
||||||
var featureIndexIdx = 0
|
var featureIndexIdx = 0
|
||||||
while (featureIndexIdx < featuresForNode.get.length) {
|
while (featureIndexIdx < featuresForNode.get.length) {
|
||||||
val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
|
val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
|
||||||
agg.update(featureIndexIdx, binIndex, label, instanceWeight)
|
agg.update(featureIndexIdx, binIndex, label, numSamples, sampleWeight)
|
||||||
featureIndexIdx += 1
|
featureIndexIdx += 1
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -338,7 +365,7 @@ private[spark] object RandomForest extends Logging with Serializable {
|
||||||
var featureIndex = 0
|
var featureIndex = 0
|
||||||
while (featureIndex < numFeatures) {
|
while (featureIndex < numFeatures) {
|
||||||
val binIndex = treePoint.binnedFeatures(featureIndex)
|
val binIndex = treePoint.binnedFeatures(featureIndex)
|
||||||
agg.update(featureIndex, binIndex, label, instanceWeight)
|
agg.update(featureIndex, binIndex, label, numSamples, sampleWeight)
|
||||||
featureIndex += 1
|
featureIndex += 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -427,14 +454,16 @@ private[spark] object RandomForest extends Logging with Serializable {
|
||||||
if (nodeInfo != null) {
|
if (nodeInfo != null) {
|
||||||
val aggNodeIndex = nodeInfo.nodeIndexInGroup
|
val aggNodeIndex = nodeInfo.nodeIndexInGroup
|
||||||
val featuresForNode = nodeInfo.featureSubset
|
val featuresForNode = nodeInfo.featureSubset
|
||||||
val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
|
val numSamples = baggedPoint.subsampleCounts(treeIndex)
|
||||||
|
val sampleWeight = baggedPoint.sampleWeight
|
||||||
if (metadata.unorderedFeatures.isEmpty) {
|
if (metadata.unorderedFeatures.isEmpty) {
|
||||||
orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
|
orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, numSamples, sampleWeight,
|
||||||
|
featuresForNode)
|
||||||
} else {
|
} else {
|
||||||
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
|
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
|
||||||
metadata.unorderedFeatures, instanceWeight, featuresForNode)
|
metadata.unorderedFeatures, numSamples, sampleWeight, featuresForNode)
|
||||||
}
|
}
|
||||||
agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
|
agg(aggNodeIndex).updateParent(baggedPoint.datum.label, numSamples, sampleWeight)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -594,8 +623,8 @@ private[spark] object RandomForest extends Logging with Serializable {
|
||||||
if (!isLeaf) {
|
if (!isLeaf) {
|
||||||
node.split = Some(split)
|
node.split = Some(split)
|
||||||
val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
|
val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
|
||||||
val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
|
val leftChildIsLeaf = childIsLeaf || (math.abs(stats.leftImpurity) < Utils.EPSILON)
|
||||||
val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
|
val rightChildIsLeaf = childIsLeaf || (math.abs(stats.rightImpurity) < Utils.EPSILON)
|
||||||
node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
|
node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
|
||||||
leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
|
leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
|
||||||
node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
|
node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
|
||||||
|
@ -659,15 +688,20 @@ private[spark] object RandomForest extends Logging with Serializable {
|
||||||
stats.impurity
|
stats.impurity
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val leftRawCount = leftImpurityCalculator.rawCount
|
||||||
|
val rightRawCount = rightImpurityCalculator.rawCount
|
||||||
val leftCount = leftImpurityCalculator.count
|
val leftCount = leftImpurityCalculator.count
|
||||||
val rightCount = rightImpurityCalculator.count
|
val rightCount = rightImpurityCalculator.count
|
||||||
|
|
||||||
val totalCount = leftCount + rightCount
|
val totalCount = leftCount + rightCount
|
||||||
|
|
||||||
// If left child or right child doesn't satisfy minimum instances per node,
|
val violatesMinInstancesPerNode = (leftRawCount < metadata.minInstancesPerNode) ||
|
||||||
// then this split is invalid, return invalid information gain stats.
|
(rightRawCount < metadata.minInstancesPerNode)
|
||||||
if ((leftCount < metadata.minInstancesPerNode) ||
|
val violatesMinWeightPerNode = (leftCount < metadata.minWeightPerNode) ||
|
||||||
(rightCount < metadata.minInstancesPerNode)) {
|
(rightCount < metadata.minWeightPerNode)
|
||||||
|
// If left child or right child doesn't satisfy minimum weight per node or minimum
|
||||||
|
// instances per node, then this split is invalid, return invalid information gain stats.
|
||||||
|
if (violatesMinInstancesPerNode || violatesMinWeightPerNode) {
|
||||||
return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
|
return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -734,7 +768,8 @@ private[spark] object RandomForest extends Logging with Serializable {
|
||||||
// Find best split.
|
// Find best split.
|
||||||
val (bestFeatureSplitIndex, bestFeatureGainStats) =
|
val (bestFeatureSplitIndex, bestFeatureGainStats) =
|
||||||
Range(0, numSplits).map { case splitIdx =>
|
Range(0, numSplits).map { case splitIdx =>
|
||||||
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
|
val leftChildStats =
|
||||||
|
binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
|
||||||
val rightChildStats =
|
val rightChildStats =
|
||||||
binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
|
binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
|
||||||
rightChildStats.subtract(leftChildStats)
|
rightChildStats.subtract(leftChildStats)
|
||||||
|
@ -876,14 +911,14 @@ private[spark] object RandomForest extends Logging with Serializable {
|
||||||
* and for multiclass classification with a high-arity feature,
|
* and for multiclass classification with a high-arity feature,
|
||||||
* there is one bin per category.
|
* there is one bin per category.
|
||||||
*
|
*
|
||||||
* @param input Training data: RDD of [[LabeledPoint]]
|
* @param input Training data: RDD of [[Instance]]
|
||||||
* @param metadata Learning and dataset metadata
|
* @param metadata Learning and dataset metadata
|
||||||
* @param seed random seed
|
* @param seed random seed
|
||||||
* @return Splits, an Array of [[Split]]
|
* @return Splits, an Array of [[Split]]
|
||||||
* of size (numFeatures, numSplits)
|
* of size (numFeatures, numSplits)
|
||||||
*/
|
*/
|
||||||
protected[tree] def findSplits(
|
protected[tree] def findSplits(
|
||||||
input: RDD[LabeledPoint],
|
input: RDD[Instance],
|
||||||
metadata: DecisionTreeMetadata,
|
metadata: DecisionTreeMetadata,
|
||||||
seed: Long): Array[Array[Split]] = {
|
seed: Long): Array[Array[Split]] = {
|
||||||
|
|
||||||
|
@ -898,14 +933,14 @@ private[spark] object RandomForest extends Logging with Serializable {
|
||||||
logDebug("fraction of data used for calculating quantiles = " + fraction)
|
logDebug("fraction of data used for calculating quantiles = " + fraction)
|
||||||
input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
|
input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
|
||||||
} else {
|
} else {
|
||||||
input.sparkContext.emptyRDD[LabeledPoint]
|
input.sparkContext.emptyRDD[Instance]
|
||||||
}
|
}
|
||||||
|
|
||||||
findSplitsBySorting(sampledInput, metadata, continuousFeatures)
|
findSplitsBySorting(sampledInput, metadata, continuousFeatures)
|
||||||
}
|
}
|
||||||
|
|
||||||
private def findSplitsBySorting(
|
private def findSplitsBySorting(
|
||||||
input: RDD[LabeledPoint],
|
input: RDD[Instance],
|
||||||
metadata: DecisionTreeMetadata,
|
metadata: DecisionTreeMetadata,
|
||||||
continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = {
|
continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = {
|
||||||
|
|
||||||
|
@ -917,7 +952,8 @@ private[spark] object RandomForest extends Logging with Serializable {
|
||||||
|
|
||||||
input
|
input
|
||||||
.flatMap { point =>
|
.flatMap { point =>
|
||||||
continuousFeatures.map(idx => (idx, point.features(idx))).filter(_._2 != 0.0)
|
continuousFeatures.map(idx => (idx, (point.weight, point.features(idx))))
|
||||||
|
.filter(_._2._2 != 0.0)
|
||||||
}.groupByKey(numPartitions)
|
}.groupByKey(numPartitions)
|
||||||
.map { case (idx, samples) =>
|
.map { case (idx, samples) =>
|
||||||
val thresholds = findSplitsForContinuousFeature(samples, metadata, idx)
|
val thresholds = findSplitsForContinuousFeature(samples, metadata, idx)
|
||||||
|
@ -982,7 +1018,7 @@ private[spark] object RandomForest extends Logging with Serializable {
|
||||||
* could be different from the specified `numSplits`.
|
* could be different from the specified `numSplits`.
|
||||||
* The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
|
* The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
|
||||||
*
|
*
|
||||||
* @param featureSamples feature values of each sample
|
* @param featureSamples feature values and sample weights of each sample
|
||||||
* @param metadata decision tree metadata
|
* @param metadata decision tree metadata
|
||||||
* NOTE: `metadata.numbins` will be changed accordingly
|
* NOTE: `metadata.numbins` will be changed accordingly
|
||||||
* if there are not enough splits to be found
|
* if there are not enough splits to be found
|
||||||
|
@ -990,7 +1026,7 @@ private[spark] object RandomForest extends Logging with Serializable {
|
||||||
* @return array of split thresholds
|
* @return array of split thresholds
|
||||||
*/
|
*/
|
||||||
private[tree] def findSplitsForContinuousFeature(
|
private[tree] def findSplitsForContinuousFeature(
|
||||||
featureSamples: Iterable[Double],
|
featureSamples: Iterable[(Double, Double)],
|
||||||
metadata: DecisionTreeMetadata,
|
metadata: DecisionTreeMetadata,
|
||||||
featureIndex: Int): Array[Double] = {
|
featureIndex: Int): Array[Double] = {
|
||||||
require(metadata.isContinuous(featureIndex),
|
require(metadata.isContinuous(featureIndex),
|
||||||
|
@ -1002,19 +1038,24 @@ private[spark] object RandomForest extends Logging with Serializable {
|
||||||
val numSplits = metadata.numSplits(featureIndex)
|
val numSplits = metadata.numSplits(featureIndex)
|
||||||
|
|
||||||
// get count for each distinct value except zero value
|
// get count for each distinct value except zero value
|
||||||
val partNumSamples = featureSamples.size
|
val partValueCountMap = mutable.Map[Double, Double]()
|
||||||
val partValueCountMap = scala.collection.mutable.Map[Double, Int]()
|
var partNumSamples = 0.0
|
||||||
featureSamples.foreach { x =>
|
var unweightedNumSamples = 0.0
|
||||||
partValueCountMap(x) = partValueCountMap.getOrElse(x, 0) + 1
|
featureSamples.foreach { case (sampleWeight, feature) =>
|
||||||
|
partValueCountMap(feature) = partValueCountMap.getOrElse(feature, 0.0) + sampleWeight;
|
||||||
|
partNumSamples += sampleWeight;
|
||||||
|
unweightedNumSamples += 1.0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate the expected number of samples for finding splits
|
// Calculate the expected number of samples for finding splits
|
||||||
val numSamples = (samplesFractionForFindSplits(metadata) * metadata.numExamples).toInt
|
val weightedNumSamples = samplesFractionForFindSplits(metadata) *
|
||||||
|
metadata.weightedNumExamples
|
||||||
// add expected zero value count and get complete statistics
|
// add expected zero value count and get complete statistics
|
||||||
val valueCountMap: Map[Double, Int] = if (numSamples - partNumSamples > 0) {
|
val tolerance = Utils.EPSILON * unweightedNumSamples * unweightedNumSamples
|
||||||
partValueCountMap.toMap + (0.0 -> (numSamples - partNumSamples))
|
val valueCountMap = if (weightedNumSamples - partNumSamples > tolerance) {
|
||||||
|
partValueCountMap + (0.0 -> (weightedNumSamples - partNumSamples))
|
||||||
} else {
|
} else {
|
||||||
partValueCountMap.toMap
|
partValueCountMap
|
||||||
}
|
}
|
||||||
|
|
||||||
// sort distinct values
|
// sort distinct values
|
||||||
|
@ -1031,7 +1072,7 @@ private[spark] object RandomForest extends Logging with Serializable {
|
||||||
.toArray
|
.toArray
|
||||||
} else {
|
} else {
|
||||||
// stride between splits
|
// stride between splits
|
||||||
val stride: Double = numSamples.toDouble / (numSplits + 1)
|
val stride: Double = weightedNumSamples / (numSplits + 1)
|
||||||
logDebug("stride = " + stride)
|
logDebug("stride = " + stride)
|
||||||
|
|
||||||
// iterate `valueCount` to find splits
|
// iterate `valueCount` to find splits
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
|
|
||||||
package org.apache.spark.ml.tree.impl
|
package org.apache.spark.ml.tree.impl
|
||||||
|
|
||||||
import org.apache.spark.ml.feature.LabeledPoint
|
import org.apache.spark.ml.feature.Instance
|
||||||
import org.apache.spark.ml.tree.{ContinuousSplit, Split}
|
import org.apache.spark.ml.tree.{ContinuousSplit, Split}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
|
|
||||||
|
@ -36,10 +36,12 @@ import org.apache.spark.rdd.RDD
|
||||||
* @param label Label from LabeledPoint
|
* @param label Label from LabeledPoint
|
||||||
* @param binnedFeatures Binned feature values.
|
* @param binnedFeatures Binned feature values.
|
||||||
* Same length as LabeledPoint.features, but values are bin indices.
|
* Same length as LabeledPoint.features, but values are bin indices.
|
||||||
|
* @param weight Sample weight for this TreePoint.
|
||||||
*/
|
*/
|
||||||
private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
|
private[spark] class TreePoint(
|
||||||
extends Serializable {
|
val label: Double,
|
||||||
}
|
val binnedFeatures: Array[Int],
|
||||||
|
val weight: Double) extends Serializable
|
||||||
|
|
||||||
private[spark] object TreePoint {
|
private[spark] object TreePoint {
|
||||||
|
|
||||||
|
@ -52,7 +54,7 @@ private[spark] object TreePoint {
|
||||||
* @return TreePoint dataset representation
|
* @return TreePoint dataset representation
|
||||||
*/
|
*/
|
||||||
def convertToTreeRDD(
|
def convertToTreeRDD(
|
||||||
input: RDD[LabeledPoint],
|
input: RDD[Instance],
|
||||||
splits: Array[Array[Split]],
|
splits: Array[Array[Split]],
|
||||||
metadata: DecisionTreeMetadata): RDD[TreePoint] = {
|
metadata: DecisionTreeMetadata): RDD[TreePoint] = {
|
||||||
// Construct arrays for featureArity for efficiency in the inner loop.
|
// Construct arrays for featureArity for efficiency in the inner loop.
|
||||||
|
@ -82,18 +84,18 @@ private[spark] object TreePoint {
|
||||||
* for categorical features.
|
* for categorical features.
|
||||||
*/
|
*/
|
||||||
private def labeledPointToTreePoint(
|
private def labeledPointToTreePoint(
|
||||||
labeledPoint: LabeledPoint,
|
instance: Instance,
|
||||||
thresholds: Array[Array[Double]],
|
thresholds: Array[Array[Double]],
|
||||||
featureArity: Array[Int]): TreePoint = {
|
featureArity: Array[Int]): TreePoint = {
|
||||||
val numFeatures = labeledPoint.features.size
|
val numFeatures = instance.features.size
|
||||||
val arr = new Array[Int](numFeatures)
|
val arr = new Array[Int](numFeatures)
|
||||||
var featureIndex = 0
|
var featureIndex = 0
|
||||||
while (featureIndex < numFeatures) {
|
while (featureIndex < numFeatures) {
|
||||||
arr(featureIndex) =
|
arr(featureIndex) =
|
||||||
findBin(featureIndex, labeledPoint, featureArity(featureIndex), thresholds(featureIndex))
|
findBin(featureIndex, instance, featureArity(featureIndex), thresholds(featureIndex))
|
||||||
featureIndex += 1
|
featureIndex += 1
|
||||||
}
|
}
|
||||||
new TreePoint(labeledPoint.label, arr)
|
new TreePoint(instance.label, arr, instance.weight)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -106,10 +108,10 @@ private[spark] object TreePoint {
|
||||||
*/
|
*/
|
||||||
private def findBin(
|
private def findBin(
|
||||||
featureIndex: Int,
|
featureIndex: Int,
|
||||||
labeledPoint: LabeledPoint,
|
instance: Instance,
|
||||||
featureArity: Int,
|
featureArity: Int,
|
||||||
thresholds: Array[Double]): Int = {
|
thresholds: Array[Double]): Int = {
|
||||||
val featureValue = labeledPoint.features(featureIndex)
|
val featureValue = instance.features(featureIndex)
|
||||||
|
|
||||||
if (featureArity == 0) {
|
if (featureArity == 0) {
|
||||||
val idx = java.util.Arrays.binarySearch(thresholds, featureValue)
|
val idx = java.util.Arrays.binarySearch(thresholds, featureValue)
|
||||||
|
@ -125,7 +127,7 @@ private[spark] object TreePoint {
|
||||||
s"DecisionTree given invalid data:" +
|
s"DecisionTree given invalid data:" +
|
||||||
s" Feature $featureIndex is categorical with values in {0,...,${featureArity - 1}," +
|
s" Feature $featureIndex is categorical with values in {0,...,${featureArity - 1}," +
|
||||||
s" but a data point gives it value $featureValue.\n" +
|
s" but a data point gives it value $featureValue.\n" +
|
||||||
" Bad data point: " + labeledPoint.toString)
|
s" Bad data point: $instance")
|
||||||
}
|
}
|
||||||
featureValue.toInt
|
featureValue.toInt
|
||||||
}
|
}
|
||||||
|
|
|
@ -282,6 +282,7 @@ private[ml] object DecisionTreeModelReadWrite {
|
||||||
*
|
*
|
||||||
* @param id Index used for tree reconstruction. Indices follow a pre-order traversal.
|
* @param id Index used for tree reconstruction. Indices follow a pre-order traversal.
|
||||||
* @param impurityStats Stats array. Impurity type is stored in metadata.
|
* @param impurityStats Stats array. Impurity type is stored in metadata.
|
||||||
|
* @param rawCount The unweighted number of samples falling in this node.
|
||||||
* @param gain Gain, or arbitrary value if leaf node.
|
* @param gain Gain, or arbitrary value if leaf node.
|
||||||
* @param leftChild Left child index, or arbitrary value if leaf node.
|
* @param leftChild Left child index, or arbitrary value if leaf node.
|
||||||
* @param rightChild Right child index, or arbitrary value if leaf node.
|
* @param rightChild Right child index, or arbitrary value if leaf node.
|
||||||
|
@ -292,6 +293,7 @@ private[ml] object DecisionTreeModelReadWrite {
|
||||||
prediction: Double,
|
prediction: Double,
|
||||||
impurity: Double,
|
impurity: Double,
|
||||||
impurityStats: Array[Double],
|
impurityStats: Array[Double],
|
||||||
|
rawCount: Long,
|
||||||
gain: Double,
|
gain: Double,
|
||||||
leftChild: Int,
|
leftChild: Int,
|
||||||
rightChild: Int,
|
rightChild: Int,
|
||||||
|
@ -311,11 +313,12 @@ private[ml] object DecisionTreeModelReadWrite {
|
||||||
val (leftNodeData, leftIdx) = build(n.leftChild, id + 1)
|
val (leftNodeData, leftIdx) = build(n.leftChild, id + 1)
|
||||||
val (rightNodeData, rightIdx) = build(n.rightChild, leftIdx + 1)
|
val (rightNodeData, rightIdx) = build(n.rightChild, leftIdx + 1)
|
||||||
val thisNodeData = NodeData(id, n.prediction, n.impurity, n.impurityStats.stats,
|
val thisNodeData = NodeData(id, n.prediction, n.impurity, n.impurityStats.stats,
|
||||||
n.gain, leftNodeData.head.id, rightNodeData.head.id, SplitData(n.split))
|
n.impurityStats.rawCount, n.gain, leftNodeData.head.id, rightNodeData.head.id,
|
||||||
|
SplitData(n.split))
|
||||||
(thisNodeData +: (leftNodeData ++ rightNodeData), rightIdx)
|
(thisNodeData +: (leftNodeData ++ rightNodeData), rightIdx)
|
||||||
case _: LeafNode =>
|
case _: LeafNode =>
|
||||||
(Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats,
|
(Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats,
|
||||||
-1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))),
|
node.impurityStats.rawCount, -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))),
|
||||||
id)
|
id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -360,7 +363,8 @@ private[ml] object DecisionTreeModelReadWrite {
|
||||||
// traversal, this guarantees that child nodes will be built before parent nodes.
|
// traversal, this guarantees that child nodes will be built before parent nodes.
|
||||||
val finalNodes = new Array[Node](nodes.length)
|
val finalNodes = new Array[Node](nodes.length)
|
||||||
nodes.reverseIterator.foreach { case n: NodeData =>
|
nodes.reverseIterator.foreach { case n: NodeData =>
|
||||||
val impurityStats = ImpurityCalculator.getCalculator(impurityType, n.impurityStats)
|
val impurityStats =
|
||||||
|
ImpurityCalculator.getCalculator(impurityType, n.impurityStats, n.rawCount)
|
||||||
val node = if (n.leftChild != -1) {
|
val node = if (n.leftChild != -1) {
|
||||||
val leftChild = finalNodes(n.leftChild)
|
val leftChild = finalNodes(n.leftChild)
|
||||||
val rightChild = finalNodes(n.rightChild)
|
val rightChild = finalNodes(n.rightChild)
|
||||||
|
|
|
@ -37,7 +37,7 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
|
||||||
* Note: Marked as private and DeveloperApi since this may be made public in the future.
|
* Note: Marked as private and DeveloperApi since this may be made public in the future.
|
||||||
*/
|
*/
|
||||||
private[ml] trait DecisionTreeParams extends PredictorParams
|
private[ml] trait DecisionTreeParams extends PredictorParams
|
||||||
with HasCheckpointInterval with HasSeed {
|
with HasCheckpointInterval with HasSeed with HasWeightCol {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Maximum depth of the tree (>= 0).
|
* Maximum depth of the tree (>= 0).
|
||||||
|
@ -74,6 +74,21 @@ private[ml] trait DecisionTreeParams extends PredictorParams
|
||||||
" child to have fewer than minInstancesPerNode, the split will be discarded as invalid." +
|
" child to have fewer than minInstancesPerNode, the split will be discarded as invalid." +
|
||||||
" Should be >= 1.", ParamValidators.gtEq(1))
|
" Should be >= 1.", ParamValidators.gtEq(1))
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Minimum fraction of the weighted sample count that each child must have after split.
|
||||||
|
* If a split causes the fraction of the total weight in the left or right child to be less than
|
||||||
|
* minWeightFractionPerNode, the split will be discarded as invalid.
|
||||||
|
* Should be in the interval [0.0, 0.5).
|
||||||
|
* (default = 0.0)
|
||||||
|
* @group param
|
||||||
|
*/
|
||||||
|
final val minWeightFractionPerNode: DoubleParam = new DoubleParam(this,
|
||||||
|
"minWeightFractionPerNode", "Minimum fraction of the weighted sample count that each child " +
|
||||||
|
"must have after split. If a split causes the fraction of the total weight in the left or " +
|
||||||
|
"right child to be less than minWeightFractionPerNode, the split will be discarded as " +
|
||||||
|
"invalid. Should be in interval [0.0, 0.5)",
|
||||||
|
ParamValidators.inRange(0.0, 0.5, lowerInclusive = true, upperInclusive = false))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Minimum information gain for a split to be considered at a tree node.
|
* Minimum information gain for a split to be considered at a tree node.
|
||||||
* Should be >= 0.0.
|
* Should be >= 0.0.
|
||||||
|
@ -107,8 +122,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams
|
||||||
" algorithm will cache node IDs for each instance. Caching can speed up training of deeper" +
|
" algorithm will cache node IDs for each instance. Caching can speed up training of deeper" +
|
||||||
" trees.")
|
" trees.")
|
||||||
|
|
||||||
setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0,
|
setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1,
|
||||||
maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
|
minWeightFractionPerNode -> 0.0, minInfoGain -> 0.0, maxMemoryInMB -> 256,
|
||||||
|
cacheNodeIds -> false, checkpointInterval -> 10)
|
||||||
|
|
||||||
/** @group getParam */
|
/** @group getParam */
|
||||||
final def getMaxDepth: Int = $(maxDepth)
|
final def getMaxDepth: Int = $(maxDepth)
|
||||||
|
@ -119,6 +135,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams
|
||||||
/** @group getParam */
|
/** @group getParam */
|
||||||
final def getMinInstancesPerNode: Int = $(minInstancesPerNode)
|
final def getMinInstancesPerNode: Int = $(minInstancesPerNode)
|
||||||
|
|
||||||
|
/** @group getParam */
|
||||||
|
final def getMinWeightFractionPerNode: Double = $(minWeightFractionPerNode)
|
||||||
|
|
||||||
/** @group getParam */
|
/** @group getParam */
|
||||||
final def getMinInfoGain: Double = $(minInfoGain)
|
final def getMinInfoGain: Double = $(minInfoGain)
|
||||||
|
|
||||||
|
@ -143,6 +162,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams
|
||||||
strategy.maxMemoryInMB = getMaxMemoryInMB
|
strategy.maxMemoryInMB = getMaxMemoryInMB
|
||||||
strategy.minInfoGain = getMinInfoGain
|
strategy.minInfoGain = getMinInfoGain
|
||||||
strategy.minInstancesPerNode = getMinInstancesPerNode
|
strategy.minInstancesPerNode = getMinInstancesPerNode
|
||||||
|
strategy.minWeightFractionPerNode = getMinWeightFractionPerNode
|
||||||
strategy.useNodeIdCache = getCacheNodeIds
|
strategy.useNodeIdCache = getCacheNodeIds
|
||||||
strategy.numClasses = numClasses
|
strategy.numClasses = numClasses
|
||||||
strategy.categoricalFeaturesInfo = categoricalFeatures
|
strategy.categoricalFeaturesInfo = categoricalFeatures
|
||||||
|
|
|
@ -23,6 +23,7 @@ import scala.util.Try
|
||||||
import org.apache.spark.annotation.Since
|
import org.apache.spark.annotation.Since
|
||||||
import org.apache.spark.api.java.JavaRDD
|
import org.apache.spark.api.java.JavaRDD
|
||||||
import org.apache.spark.internal.Logging
|
import org.apache.spark.internal.Logging
|
||||||
|
import org.apache.spark.ml.feature.Instance
|
||||||
import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, TreeEnsembleParams => NewRFParams}
|
import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, TreeEnsembleParams => NewRFParams}
|
||||||
import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest}
|
import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest}
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
|
@ -91,8 +92,8 @@ private class RandomForest (
|
||||||
* @return RandomForestModel that can be used for prediction.
|
* @return RandomForestModel that can be used for prediction.
|
||||||
*/
|
*/
|
||||||
def run(input: RDD[LabeledPoint]): RandomForestModel = {
|
def run(input: RDD[LabeledPoint]): RandomForestModel = {
|
||||||
val trees: Array[NewDTModel] = NewRandomForest.run(input.map(_.asML), strategy, numTrees,
|
val trees: Array[NewDTModel] =
|
||||||
featureSubsetStrategy, seed.toLong, None)
|
NewRandomForest.run(input, strategy, numTrees, featureSubsetStrategy, seed.toLong)
|
||||||
new RandomForestModel(strategy.algo, trees.map(_.toOld))
|
new RandomForestModel(strategy.algo, trees.map(_.toOld))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -80,7 +80,8 @@ class Strategy @Since("1.3.0") (
|
||||||
@Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256,
|
@Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256,
|
||||||
@Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1,
|
@Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1,
|
||||||
@Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,
|
@Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,
|
||||||
@Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends Serializable {
|
@Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10,
|
||||||
|
@Since("3.0.0") @BeanProperty var minWeightFractionPerNode: Double = 0.0) extends Serializable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*/
|
*/
|
||||||
|
@ -96,6 +97,31 @@ class Strategy @Since("1.3.0") (
|
||||||
isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
|
isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// scalastyle:off argcount
|
||||||
|
/**
|
||||||
|
* Backwards compatible constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]]
|
||||||
|
*/
|
||||||
|
@Since("1.0.0")
|
||||||
|
def this(
|
||||||
|
algo: Algo,
|
||||||
|
impurity: Impurity,
|
||||||
|
maxDepth: Int,
|
||||||
|
numClasses: Int,
|
||||||
|
maxBins: Int,
|
||||||
|
quantileCalculationStrategy: QuantileStrategy,
|
||||||
|
categoricalFeaturesInfo: Map[Int, Int],
|
||||||
|
minInstancesPerNode: Int,
|
||||||
|
minInfoGain: Double,
|
||||||
|
maxMemoryInMB: Int,
|
||||||
|
subsamplingRate: Double,
|
||||||
|
useNodeIdCache: Boolean,
|
||||||
|
checkpointInterval: Int) {
|
||||||
|
this(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy,
|
||||||
|
categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, maxMemoryInMB,
|
||||||
|
subsamplingRate, useNodeIdCache, checkpointInterval, 0.0)
|
||||||
|
}
|
||||||
|
// scalastyle:on argcount
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]]
|
* Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]]
|
||||||
*/
|
*/
|
||||||
|
@ -108,7 +134,8 @@ class Strategy @Since("1.3.0") (
|
||||||
maxBins: Int,
|
maxBins: Int,
|
||||||
categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]) {
|
categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]) {
|
||||||
this(algo, impurity, maxDepth, numClasses, maxBins, Sort,
|
this(algo, impurity, maxDepth, numClasses, maxBins, Sort,
|
||||||
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
|
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
|
||||||
|
minWeightFractionPerNode = 0.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -171,8 +198,9 @@ class Strategy @Since("1.3.0") (
|
||||||
@Since("1.2.0")
|
@Since("1.2.0")
|
||||||
def copy: Strategy = {
|
def copy: Strategy = {
|
||||||
new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
|
new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
|
||||||
quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain,
|
quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode,
|
||||||
maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval)
|
minInfoGain, maxMemoryInMB, subsamplingRate, useNodeIdCache,
|
||||||
|
checkpointInterval, minWeightFractionPerNode)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -83,23 +83,29 @@ object Entropy extends Impurity {
|
||||||
* @param numClasses Number of classes for label.
|
* @param numClasses Number of classes for label.
|
||||||
*/
|
*/
|
||||||
private[spark] class EntropyAggregator(numClasses: Int)
|
private[spark] class EntropyAggregator(numClasses: Int)
|
||||||
extends ImpurityAggregator(numClasses) with Serializable {
|
extends ImpurityAggregator(numClasses + 1) with Serializable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Update stats for one (node, feature, bin) with the given label.
|
* Update stats for one (node, feature, bin) with the given label.
|
||||||
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
|
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
|
||||||
* @param offset Start index of stats for this (node, feature, bin).
|
* @param offset Start index of stats for this (node, feature, bin).
|
||||||
*/
|
*/
|
||||||
def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
|
def update(
|
||||||
if (label >= statsSize) {
|
allStats: Array[Double],
|
||||||
|
offset: Int,
|
||||||
|
label: Double,
|
||||||
|
numSamples: Int,
|
||||||
|
sampleWeight: Double): Unit = {
|
||||||
|
if (label >= numClasses) {
|
||||||
throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
|
throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
|
||||||
s" but requires label < numClasses (= $statsSize).")
|
s" but requires label < numClasses (= ${numClasses}).")
|
||||||
}
|
}
|
||||||
if (label < 0) {
|
if (label < 0) {
|
||||||
throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
|
throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
|
||||||
s"but requires label is non-negative.")
|
s"but requires label is non-negative.")
|
||||||
}
|
}
|
||||||
allStats(offset + label.toInt) += instanceWeight
|
allStats(offset + label.toInt) += numSamples * sampleWeight
|
||||||
|
allStats(offset + statsSize - 1) += numSamples
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -108,7 +114,8 @@ private[spark] class EntropyAggregator(numClasses: Int)
|
||||||
* @param offset Start index of stats for this (node, feature, bin).
|
* @param offset Start index of stats for this (node, feature, bin).
|
||||||
*/
|
*/
|
||||||
def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = {
|
def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = {
|
||||||
new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
|
new EntropyCalculator(allStats.view(offset, offset + statsSize - 1).toArray,
|
||||||
|
allStats(offset + statsSize - 1).toLong)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,12 +125,13 @@ private[spark] class EntropyAggregator(numClasses: Int)
|
||||||
* (node, feature, bin).
|
* (node, feature, bin).
|
||||||
* @param stats Array of sufficient statistics for a (node, feature, bin).
|
* @param stats Array of sufficient statistics for a (node, feature, bin).
|
||||||
*/
|
*/
|
||||||
private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
|
private[spark] class EntropyCalculator(stats: Array[Double], var rawCount: Long)
|
||||||
|
extends ImpurityCalculator(stats) {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Make a deep copy of this [[ImpurityCalculator]].
|
* Make a deep copy of this [[ImpurityCalculator]].
|
||||||
*/
|
*/
|
||||||
def copy: EntropyCalculator = new EntropyCalculator(stats.clone())
|
def copy: EntropyCalculator = new EntropyCalculator(stats.clone(), rawCount)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calculate the impurity from the stored sufficient statistics.
|
* Calculate the impurity from the stored sufficient statistics.
|
||||||
|
@ -131,9 +139,9 @@ private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCal
|
||||||
def calculate(): Double = Entropy.calculate(stats, stats.sum)
|
def calculate(): Double = Entropy.calculate(stats, stats.sum)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Number of data points accounted for in the sufficient statistics.
|
* Weighted number of data points accounted for in the sufficient statistics.
|
||||||
*/
|
*/
|
||||||
def count: Long = stats.sum.toLong
|
def count: Double = stats.sum
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Prediction which should be made based on the sufficient statistics.
|
* Prediction which should be made based on the sufficient statistics.
|
||||||
|
|
|
@ -80,23 +80,29 @@ object Gini extends Impurity {
|
||||||
* @param numClasses Number of classes for label.
|
* @param numClasses Number of classes for label.
|
||||||
*/
|
*/
|
||||||
private[spark] class GiniAggregator(numClasses: Int)
|
private[spark] class GiniAggregator(numClasses: Int)
|
||||||
extends ImpurityAggregator(numClasses) with Serializable {
|
extends ImpurityAggregator(numClasses + 1) with Serializable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Update stats for one (node, feature, bin) with the given label.
|
* Update stats for one (node, feature, bin) with the given label.
|
||||||
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
|
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
|
||||||
* @param offset Start index of stats for this (node, feature, bin).
|
* @param offset Start index of stats for this (node, feature, bin).
|
||||||
*/
|
*/
|
||||||
def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
|
def update(
|
||||||
if (label >= statsSize) {
|
allStats: Array[Double],
|
||||||
|
offset: Int,
|
||||||
|
label: Double,
|
||||||
|
numSamples: Int,
|
||||||
|
sampleWeight: Double): Unit = {
|
||||||
|
if (label >= numClasses) {
|
||||||
throw new IllegalArgumentException(s"GiniAggregator given label $label" +
|
throw new IllegalArgumentException(s"GiniAggregator given label $label" +
|
||||||
s" but requires label < numClasses (= $statsSize).")
|
s" but requires label < numClasses (= ${numClasses}).")
|
||||||
}
|
}
|
||||||
if (label < 0) {
|
if (label < 0) {
|
||||||
throw new IllegalArgumentException(s"GiniAggregator given label $label" +
|
throw new IllegalArgumentException(s"GiniAggregator given label $label" +
|
||||||
s"but requires label is non-negative.")
|
s"but requires label to be non-negative.")
|
||||||
}
|
}
|
||||||
allStats(offset + label.toInt) += instanceWeight
|
allStats(offset + label.toInt) += numSamples * sampleWeight
|
||||||
|
allStats(offset + statsSize - 1) += numSamples
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -105,7 +111,8 @@ private[spark] class GiniAggregator(numClasses: Int)
|
||||||
* @param offset Start index of stats for this (node, feature, bin).
|
* @param offset Start index of stats for this (node, feature, bin).
|
||||||
*/
|
*/
|
||||||
def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = {
|
def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = {
|
||||||
new GiniCalculator(allStats.view(offset, offset + statsSize).toArray)
|
new GiniCalculator(allStats.view(offset, offset + statsSize - 1).toArray,
|
||||||
|
allStats(offset + statsSize - 1).toLong)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -115,12 +122,13 @@ private[spark] class GiniAggregator(numClasses: Int)
|
||||||
* (node, feature, bin).
|
* (node, feature, bin).
|
||||||
* @param stats Array of sufficient statistics for a (node, feature, bin).
|
* @param stats Array of sufficient statistics for a (node, feature, bin).
|
||||||
*/
|
*/
|
||||||
private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
|
private[spark] class GiniCalculator(stats: Array[Double], var rawCount: Long)
|
||||||
|
extends ImpurityCalculator(stats) {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Make a deep copy of this [[ImpurityCalculator]].
|
* Make a deep copy of this [[ImpurityCalculator]].
|
||||||
*/
|
*/
|
||||||
def copy: GiniCalculator = new GiniCalculator(stats.clone())
|
def copy: GiniCalculator = new GiniCalculator(stats.clone(), rawCount)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calculate the impurity from the stored sufficient statistics.
|
* Calculate the impurity from the stored sufficient statistics.
|
||||||
|
@ -128,9 +136,9 @@ private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcul
|
||||||
def calculate(): Double = Gini.calculate(stats, stats.sum)
|
def calculate(): Double = Gini.calculate(stats, stats.sum)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Number of data points accounted for in the sufficient statistics.
|
* Weighted number of data points accounted for in the sufficient statistics.
|
||||||
*/
|
*/
|
||||||
def count: Long = stats.sum.toLong
|
def count: Double = stats.sum
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Prediction which should be made based on the sufficient statistics.
|
* Prediction which should be made based on the sufficient statistics.
|
||||||
|
|
|
@ -81,7 +81,12 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser
|
||||||
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
|
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
|
||||||
* @param offset Start index of stats for this (node, feature, bin).
|
* @param offset Start index of stats for this (node, feature, bin).
|
||||||
*/
|
*/
|
||||||
def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit
|
def update(
|
||||||
|
allStats: Array[Double],
|
||||||
|
offset: Int,
|
||||||
|
label: Double,
|
||||||
|
numSamples: Int,
|
||||||
|
sampleWeight: Double): Unit
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get an [[ImpurityCalculator]] for a (node, feature, bin).
|
* Get an [[ImpurityCalculator]] for a (node, feature, bin).
|
||||||
|
@ -122,6 +127,7 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten
|
||||||
stats(i) += other.stats(i)
|
stats(i) += other.stats(i)
|
||||||
i += 1
|
i += 1
|
||||||
}
|
}
|
||||||
|
rawCount += other.rawCount
|
||||||
this
|
this
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -139,13 +145,19 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten
|
||||||
stats(i) -= other.stats(i)
|
stats(i) -= other.stats(i)
|
||||||
i += 1
|
i += 1
|
||||||
}
|
}
|
||||||
|
rawCount -= other.rawCount
|
||||||
this
|
this
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Number of data points accounted for in the sufficient statistics.
|
* Weighted number of data points accounted for in the sufficient statistics.
|
||||||
*/
|
*/
|
||||||
def count: Long
|
def count: Double
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Raw number of data points accounted for in the sufficient statistics.
|
||||||
|
*/
|
||||||
|
var rawCount: Long
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Prediction which should be made based on the sufficient statistics.
|
* Prediction which should be made based on the sufficient statistics.
|
||||||
|
@ -185,11 +197,14 @@ private[spark] object ImpurityCalculator {
|
||||||
* Create an [[ImpurityCalculator]] instance of the given impurity type and with
|
* Create an [[ImpurityCalculator]] instance of the given impurity type and with
|
||||||
* the given stats.
|
* the given stats.
|
||||||
*/
|
*/
|
||||||
def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = {
|
def getCalculator(
|
||||||
|
impurity: String,
|
||||||
|
stats: Array[Double],
|
||||||
|
rawCount: Long): ImpurityCalculator = {
|
||||||
impurity.toLowerCase(Locale.ROOT) match {
|
impurity.toLowerCase(Locale.ROOT) match {
|
||||||
case "gini" => new GiniCalculator(stats)
|
case "gini" => new GiniCalculator(stats, rawCount)
|
||||||
case "entropy" => new EntropyCalculator(stats)
|
case "entropy" => new EntropyCalculator(stats, rawCount)
|
||||||
case "variance" => new VarianceCalculator(stats)
|
case "variance" => new VarianceCalculator(stats, rawCount)
|
||||||
case _ =>
|
case _ =>
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
s"ImpurityCalculator builder did not recognize impurity type: $impurity")
|
s"ImpurityCalculator builder did not recognize impurity type: $impurity")
|
||||||
|
|
|
@ -66,21 +66,32 @@ object Variance extends Impurity {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Class for updating views of a vector of sufficient statistics,
|
* Class for updating views of a vector of sufficient statistics,
|
||||||
* in order to compute impurity from a sample.
|
* in order to compute impurity from a sample. For variance, we track:
|
||||||
|
* - sum(w_i)
|
||||||
|
* - sum(w_i * y_i)
|
||||||
|
* - sum(w_i * y_i * y_i)
|
||||||
|
* - count(y_i)
|
||||||
* Note: Instances of this class do not hold the data; they operate on views of the data.
|
* Note: Instances of this class do not hold the data; they operate on views of the data.
|
||||||
*/
|
*/
|
||||||
private[spark] class VarianceAggregator()
|
private[spark] class VarianceAggregator()
|
||||||
extends ImpurityAggregator(statsSize = 3) with Serializable {
|
extends ImpurityAggregator(statsSize = 4) with Serializable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Update stats for one (node, feature, bin) with the given label.
|
* Update stats for one (node, feature, bin) with the given label.
|
||||||
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
|
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
|
||||||
* @param offset Start index of stats for this (node, feature, bin).
|
* @param offset Start index of stats for this (node, feature, bin).
|
||||||
*/
|
*/
|
||||||
def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
|
def update(
|
||||||
|
allStats: Array[Double],
|
||||||
|
offset: Int,
|
||||||
|
label: Double,
|
||||||
|
numSamples: Int,
|
||||||
|
sampleWeight: Double): Unit = {
|
||||||
|
val instanceWeight = numSamples * sampleWeight
|
||||||
allStats(offset) += instanceWeight
|
allStats(offset) += instanceWeight
|
||||||
allStats(offset + 1) += instanceWeight * label
|
allStats(offset + 1) += instanceWeight * label
|
||||||
allStats(offset + 2) += instanceWeight * label * label
|
allStats(offset + 2) += instanceWeight * label * label
|
||||||
|
allStats(offset + 3) += numSamples
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -89,7 +100,8 @@ private[spark] class VarianceAggregator()
|
||||||
* @param offset Start index of stats for this (node, feature, bin).
|
* @param offset Start index of stats for this (node, feature, bin).
|
||||||
*/
|
*/
|
||||||
def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = {
|
def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = {
|
||||||
new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray)
|
new VarianceCalculator(allStats.view(offset, offset + statsSize - 1).toArray,
|
||||||
|
allStats(offset + statsSize - 1).toLong)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -99,7 +111,8 @@ private[spark] class VarianceAggregator()
|
||||||
* (node, feature, bin).
|
* (node, feature, bin).
|
||||||
* @param stats Array of sufficient statistics for a (node, feature, bin).
|
* @param stats Array of sufficient statistics for a (node, feature, bin).
|
||||||
*/
|
*/
|
||||||
private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
|
private[spark] class VarianceCalculator(stats: Array[Double], var rawCount: Long)
|
||||||
|
extends ImpurityCalculator(stats) {
|
||||||
|
|
||||||
require(stats.length == 3,
|
require(stats.length == 3,
|
||||||
s"VarianceCalculator requires sufficient statistics array stats to be of length 3," +
|
s"VarianceCalculator requires sufficient statistics array stats to be of length 3," +
|
||||||
|
@ -108,7 +121,7 @@ private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCa
|
||||||
/**
|
/**
|
||||||
* Make a deep copy of this [[ImpurityCalculator]].
|
* Make a deep copy of this [[ImpurityCalculator]].
|
||||||
*/
|
*/
|
||||||
def copy: VarianceCalculator = new VarianceCalculator(stats.clone())
|
def copy: VarianceCalculator = new VarianceCalculator(stats.clone(), rawCount)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calculate the impurity from the stored sufficient statistics.
|
* Calculate the impurity from the stored sufficient statistics.
|
||||||
|
@ -116,9 +129,9 @@ private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCa
|
||||||
def calculate(): Double = Variance.calculate(stats(0), stats(1), stats(2))
|
def calculate(): Double = Variance.calculate(stats(0), stats(1), stats(2))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Number of data points accounted for in the sufficient statistics.
|
* Weighted number of data points accounted for in the sufficient statistics.
|
||||||
*/
|
*/
|
||||||
def count: Long = stats(0).toLong
|
def count: Double = stats(0)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Prediction which should be made based on the sufficient statistics.
|
* Prediction which should be made based on the sufficient statistics.
|
||||||
|
|
|
@ -42,6 +42,8 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
|
||||||
private var continuousDataPointsForMulticlassRDD: RDD[LabeledPoint] = _
|
private var continuousDataPointsForMulticlassRDD: RDD[LabeledPoint] = _
|
||||||
private var categoricalDataPointsForMulticlassForOrderedFeaturesRDD: RDD[LabeledPoint] = _
|
private var categoricalDataPointsForMulticlassForOrderedFeaturesRDD: RDD[LabeledPoint] = _
|
||||||
|
|
||||||
|
private val seed = 42
|
||||||
|
|
||||||
override def beforeAll() {
|
override def beforeAll() {
|
||||||
super.beforeAll()
|
super.beforeAll()
|
||||||
categoricalDataPointsRDD =
|
categoricalDataPointsRDD =
|
||||||
|
@ -250,7 +252,7 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
|
||||||
|
|
||||||
MLTestingUtils.checkCopyAndUids(dt, newTree)
|
MLTestingUtils.checkCopyAndUids(dt, newTree)
|
||||||
|
|
||||||
testTransformer[(Vector, Double)](newData, newTree,
|
testTransformer[(Vector, Double, Double)](newData, newTree,
|
||||||
"prediction", "rawPrediction", "probability") {
|
"prediction", "rawPrediction", "probability") {
|
||||||
case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
|
case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
|
||||||
assert(pred === rawPred.argmax,
|
assert(pred === rawPred.argmax,
|
||||||
|
@ -327,6 +329,49 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
|
||||||
dt.fit(df)
|
dt.fit(df)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("training with sample weights") {
|
||||||
|
val df = {
|
||||||
|
val nPoints = 100
|
||||||
|
val coefficients = Array(
|
||||||
|
-0.57997, 0.912083, -0.371077,
|
||||||
|
-0.16624, -0.84355, -0.048509)
|
||||||
|
|
||||||
|
val xMean = Array(5.843, 3.057)
|
||||||
|
val xVariance = Array(0.6856, 0.1899)
|
||||||
|
|
||||||
|
val testData = LogisticRegressionSuite.generateMultinomialLogisticInput(
|
||||||
|
coefficients, xMean, xVariance, addIntercept = true, nPoints, seed)
|
||||||
|
|
||||||
|
sc.parallelize(testData, 4).toDF()
|
||||||
|
}
|
||||||
|
val numClasses = 3
|
||||||
|
val predEquals = (x: Double, y: Double) => x == y
|
||||||
|
// (impurity, maxDepth)
|
||||||
|
val testParams = Seq(
|
||||||
|
("gini", 10),
|
||||||
|
("entropy", 10),
|
||||||
|
("gini", 5)
|
||||||
|
)
|
||||||
|
for ((impurity, maxDepth) <- testParams) {
|
||||||
|
val estimator = new DecisionTreeClassifier()
|
||||||
|
.setMaxDepth(maxDepth)
|
||||||
|
.setSeed(seed)
|
||||||
|
.setMinWeightFractionPerNode(0.049)
|
||||||
|
.setImpurity(impurity)
|
||||||
|
|
||||||
|
MLTestingUtils.testArbitrarilyScaledWeights[DecisionTreeClassificationModel,
|
||||||
|
DecisionTreeClassifier](df.as[LabeledPoint], estimator,
|
||||||
|
MLTestingUtils.modelPredictionEquals(df, predEquals, 0.7))
|
||||||
|
MLTestingUtils.testOutliersWithSmallWeights[DecisionTreeClassificationModel,
|
||||||
|
DecisionTreeClassifier](df.as[LabeledPoint], estimator,
|
||||||
|
numClasses, MLTestingUtils.modelPredictionEquals(df, predEquals, 0.8),
|
||||||
|
outlierRatio = 2)
|
||||||
|
MLTestingUtils.testOversamplingVsWeighting[DecisionTreeClassificationModel,
|
||||||
|
DecisionTreeClassifier](df.as[LabeledPoint], estimator,
|
||||||
|
MLTestingUtils.modelPredictionEquals(df, predEquals, 0.7), seed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Tests of model save/load
|
// Tests of model save/load
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -350,7 +395,6 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
|
||||||
TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
|
TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
|
||||||
testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings,
|
testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings,
|
||||||
allParamSettings, checkModelData)
|
allParamSettings, checkModelData)
|
||||||
|
|
||||||
// Continuous splits with tree depth 2
|
// Continuous splits with tree depth 2
|
||||||
val continuousData: DataFrame =
|
val continuousData: DataFrame =
|
||||||
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
|
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
package org.apache.spark.ml.classification
|
package org.apache.spark.ml.classification
|
||||||
|
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.ml.feature.LabeledPoint
|
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
|
||||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||||
import org.apache.spark.ml.param.ParamsSuite
|
import org.apache.spark.ml.param.ParamsSuite
|
||||||
import org.apache.spark.ml.tree.LeafNode
|
import org.apache.spark.ml.tree.LeafNode
|
||||||
|
@ -141,7 +141,7 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
|
||||||
|
|
||||||
MLTestingUtils.checkCopyAndUids(rf, model)
|
MLTestingUtils.checkCopyAndUids(rf, model)
|
||||||
|
|
||||||
testTransformer[(Vector, Double)](df, model, "prediction", "rawPrediction",
|
testTransformer[(Vector, Double, Double)](df, model, "prediction", "rawPrediction",
|
||||||
"probability") { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
|
"probability") { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
|
||||||
assert(pred === rawPred.argmax,
|
assert(pred === rawPred.argmax,
|
||||||
s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
|
s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
|
||||||
|
@ -180,7 +180,6 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Tests of feature importance
|
// Tests of feature importance
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
test("Feature importance with toy data") {
|
test("Feature importance with toy data") {
|
||||||
val numClasses = 2
|
val numClasses = 2
|
||||||
val rf = new RandomForestClassifier()
|
val rf = new RandomForestClassifier()
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
package org.apache.spark.ml.regression
|
package org.apache.spark.ml.regression
|
||||||
|
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.ml.feature.LabeledPoint
|
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
|
||||||
import org.apache.spark.ml.linalg.Vector
|
import org.apache.spark.ml.linalg.Vector
|
||||||
import org.apache.spark.ml.tree.impl.TreeTests
|
import org.apache.spark.ml.tree.impl.TreeTests
|
||||||
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
|
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
|
||||||
|
@ -26,6 +26,7 @@ import org.apache.spark.ml.util.TestingUtils._
|
||||||
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
|
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
|
||||||
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
|
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
|
||||||
DecisionTreeSuite => OldDecisionTreeSuite}
|
DecisionTreeSuite => OldDecisionTreeSuite}
|
||||||
|
import org.apache.spark.mllib.util.LinearDataGenerator
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Row}
|
||||||
|
|
||||||
|
@ -35,11 +36,17 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
|
||||||
import testImplicits._
|
import testImplicits._
|
||||||
|
|
||||||
private var categoricalDataPointsRDD: RDD[LabeledPoint] = _
|
private var categoricalDataPointsRDD: RDD[LabeledPoint] = _
|
||||||
|
private var linearRegressionData: DataFrame = _
|
||||||
|
|
||||||
|
private val seed = 42
|
||||||
|
|
||||||
override def beforeAll() {
|
override def beforeAll() {
|
||||||
super.beforeAll()
|
super.beforeAll()
|
||||||
categoricalDataPointsRDD =
|
categoricalDataPointsRDD =
|
||||||
sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints().map(_.asML))
|
sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints().map(_.asML))
|
||||||
|
linearRegressionData = sc.parallelize(LinearDataGenerator.generateLinearInput(
|
||||||
|
intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3),
|
||||||
|
xVariance = Array(0.7, 1.2), nPoints = 1000, seed, eps = 0.5), 2).map(_.asML).toDF()
|
||||||
}
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -88,7 +95,7 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
|
||||||
val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
|
val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
|
||||||
val model = dt.fit(df)
|
val model = dt.fit(df)
|
||||||
|
|
||||||
testTransformer[(Vector, Double)](df, model, "features", "variance") {
|
testTransformer[(Vector, Double, Double)](df, model, "features", "variance") {
|
||||||
case Row(features: Vector, variance: Double) =>
|
case Row(features: Vector, variance: Double) =>
|
||||||
val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate()
|
val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate()
|
||||||
assert(variance === expectedVariance,
|
assert(variance === expectedVariance,
|
||||||
|
@ -101,7 +108,7 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
|
||||||
.setMaxBins(6)
|
.setMaxBins(6)
|
||||||
.setSeed(0)
|
.setSeed(0)
|
||||||
|
|
||||||
testTransformerByGlobalCheckFunc[(Vector, Double)](varianceDF, dt.fit(varianceDF),
|
testTransformerByGlobalCheckFunc[(Vector, Double, Double)](varianceDF, dt.fit(varianceDF),
|
||||||
"variance") { case rows: Seq[Row] =>
|
"variance") { case rows: Seq[Row] =>
|
||||||
val calculatedVariances = rows.map(_.getDouble(0))
|
val calculatedVariances = rows.map(_.getDouble(0))
|
||||||
|
|
||||||
|
@ -159,6 +166,28 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("training with sample weights") {
|
||||||
|
val df = linearRegressionData
|
||||||
|
val numClasses = 0
|
||||||
|
val testParams = Seq(5, 10)
|
||||||
|
for (maxDepth <- testParams) {
|
||||||
|
val estimator = new DecisionTreeRegressor()
|
||||||
|
.setMaxDepth(maxDepth)
|
||||||
|
.setMinWeightFractionPerNode(0.05)
|
||||||
|
.setSeed(123)
|
||||||
|
MLTestingUtils.testArbitrarilyScaledWeights[DecisionTreeRegressionModel,
|
||||||
|
DecisionTreeRegressor](df.as[LabeledPoint], estimator,
|
||||||
|
MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, 0.99))
|
||||||
|
MLTestingUtils.testOutliersWithSmallWeights[DecisionTreeRegressionModel,
|
||||||
|
DecisionTreeRegressor](df.as[LabeledPoint], estimator, numClasses,
|
||||||
|
MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, 0.99),
|
||||||
|
outlierRatio = 2)
|
||||||
|
MLTestingUtils.testOversamplingVsWeighting[DecisionTreeRegressionModel,
|
||||||
|
DecisionTreeRegressor](df.as[LabeledPoint], estimator,
|
||||||
|
MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.01, 1.0), seed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Tests of model save/load
|
// Tests of model save/load
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -891,6 +891,7 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe
|
||||||
.setStandardization(standardization)
|
.setStandardization(standardization)
|
||||||
.setRegParam(regParam)
|
.setRegParam(regParam)
|
||||||
.setElasticNetParam(elasticNetParam)
|
.setElasticNetParam(elasticNetParam)
|
||||||
|
.setSolver(solver)
|
||||||
MLTestingUtils.testArbitrarilyScaledWeights[LinearRegressionModel, LinearRegression](
|
MLTestingUtils.testArbitrarilyScaledWeights[LinearRegressionModel, LinearRegression](
|
||||||
datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals)
|
datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals)
|
||||||
MLTestingUtils.testOutliersWithSmallWeights[LinearRegressionModel, LinearRegression](
|
MLTestingUtils.testOutliersWithSmallWeights[LinearRegressionModel, LinearRegression](
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
package org.apache.spark.ml.tree.impl
|
package org.apache.spark.ml.tree.impl
|
||||||
|
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
|
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
|
||||||
import org.apache.spark.mllib.tree.EnsembleTestHelper
|
import org.apache.spark.mllib.tree.EnsembleTestHelper
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
|
|
||||||
|
@ -26,12 +27,16 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
*/
|
*/
|
||||||
class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
|
class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
|
|
||||||
test("BaggedPoint RDD: without subsampling") {
|
test("BaggedPoint RDD: without subsampling with weights") {
|
||||||
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
|
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map { lp =>
|
||||||
|
Instance(lp.label, 0.5, lp.features.asML)
|
||||||
|
}
|
||||||
val rdd = sc.parallelize(arr)
|
val rdd = sc.parallelize(arr)
|
||||||
val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42)
|
val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false,
|
||||||
|
(instance: Instance) => instance.weight * 4.0, seed = 42)
|
||||||
baggedRDD.collect().foreach { baggedPoint =>
|
baggedRDD.collect().foreach { baggedPoint =>
|
||||||
assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1)
|
assert(baggedPoint.subsampleCounts.size === 1 && baggedPoint.subsampleCounts(0) === 1)
|
||||||
|
assert(baggedPoint.sampleWeight === 2.0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -40,13 +45,17 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
val (expectedMean, expectedStddev) = (1.0, 1.0)
|
val (expectedMean, expectedStddev) = (1.0, 1.0)
|
||||||
|
|
||||||
val seeds = Array(123, 5354, 230, 349867, 23987)
|
val seeds = Array(123, 5354, 230, 349867, 23987)
|
||||||
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
|
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map(_.asML)
|
||||||
val rdd = sc.parallelize(arr)
|
val rdd = sc.parallelize(arr)
|
||||||
seeds.foreach { seed =>
|
seeds.foreach { seed =>
|
||||||
val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed)
|
val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true,
|
||||||
val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
|
(_: LabeledPoint) => 2.0, seed)
|
||||||
|
val subsampleCounts: Array[Array[Double]] =
|
||||||
|
baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
|
||||||
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
|
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
|
||||||
expectedStddev, epsilon = 0.01)
|
expectedStddev, epsilon = 0.01)
|
||||||
|
// should ignore weight function for now
|
||||||
|
assert(baggedRDD.collect().forall(_.sampleWeight === 1.0))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -59,8 +68,10 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
|
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
|
||||||
val rdd = sc.parallelize(arr)
|
val rdd = sc.parallelize(arr)
|
||||||
seeds.foreach { seed =>
|
seeds.foreach { seed =>
|
||||||
val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed)
|
val baggedRDD =
|
||||||
val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
|
BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed = seed)
|
||||||
|
val subsampleCounts: Array[Array[Double]] =
|
||||||
|
baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
|
||||||
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
|
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
|
||||||
expectedStddev, epsilon = 0.01)
|
expectedStddev, epsilon = 0.01)
|
||||||
}
|
}
|
||||||
|
@ -71,13 +82,17 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
val (expectedMean, expectedStddev) = (1.0, 0)
|
val (expectedMean, expectedStddev) = (1.0, 0)
|
||||||
|
|
||||||
val seeds = Array(123, 5354, 230, 349867, 23987)
|
val seeds = Array(123, 5354, 230, 349867, 23987)
|
||||||
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
|
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map(_.asML)
|
||||||
val rdd = sc.parallelize(arr)
|
val rdd = sc.parallelize(arr)
|
||||||
seeds.foreach { seed =>
|
seeds.foreach { seed =>
|
||||||
val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed)
|
val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false,
|
||||||
val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
|
(_: LabeledPoint) => 2.0, seed)
|
||||||
|
val subsampleCounts: Array[Array[Double]] =
|
||||||
|
baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
|
||||||
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
|
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
|
||||||
expectedStddev, epsilon = 0.01)
|
expectedStddev, epsilon = 0.01)
|
||||||
|
// should ignore weight function for now
|
||||||
|
assert(baggedRDD.collect().forall(_.sampleWeight === 1.0))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -90,8 +105,10 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
|
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
|
||||||
val rdd = sc.parallelize(arr)
|
val rdd = sc.parallelize(arr)
|
||||||
seeds.foreach { seed =>
|
seeds.foreach { seed =>
|
||||||
val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed)
|
val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false,
|
||||||
val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
|
seed = seed)
|
||||||
|
val subsampleCounts: Array[Array[Double]] =
|
||||||
|
baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
|
||||||
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
|
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
|
||||||
expectedStddev, epsilon = 0.01)
|
expectedStddev, epsilon = 0.01)
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,10 +19,11 @@ package org.apache.spark.ml.tree.impl
|
||||||
|
|
||||||
import scala.annotation.tailrec
|
import scala.annotation.tailrec
|
||||||
import scala.collection.mutable
|
import scala.collection.mutable
|
||||||
|
import scala.language.implicitConversions
|
||||||
|
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
|
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
|
||||||
import org.apache.spark.ml.feature.LabeledPoint
|
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
|
||||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||||
import org.apache.spark.ml.tree._
|
import org.apache.spark.ml.tree._
|
||||||
import org.apache.spark.ml.util.TestingUtils._
|
import org.apache.spark.ml.util.TestingUtils._
|
||||||
|
@ -46,7 +47,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
test("Binary classification with continuous features: split calculation") {
|
test("Binary classification with continuous features: split calculation") {
|
||||||
val arr = OldDTSuite.generateOrderedLabeledPointsWithLabel1().map(_.asML)
|
val arr = OldDTSuite.generateOrderedLabeledPointsWithLabel1().map(_.asML.toInstance)
|
||||||
assert(arr.length === 1000)
|
assert(arr.length === 1000)
|
||||||
val rdd = sc.parallelize(arr)
|
val rdd = sc.parallelize(arr)
|
||||||
val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2, 100)
|
val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2, 100)
|
||||||
|
@ -58,7 +59,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Binary classification with binary (ordered) categorical features: split calculation") {
|
test("Binary classification with binary (ordered) categorical features: split calculation") {
|
||||||
val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML)
|
val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance)
|
||||||
assert(arr.length === 1000)
|
assert(arr.length === 1000)
|
||||||
val rdd = sc.parallelize(arr)
|
val rdd = sc.parallelize(arr)
|
||||||
val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2,
|
val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2,
|
||||||
|
@ -75,7 +76,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
|
|
||||||
test("Binary classification with 3-ary (ordered) categorical features," +
|
test("Binary classification with 3-ary (ordered) categorical features," +
|
||||||
" with no samples for one category: split calculation") {
|
" with no samples for one category: split calculation") {
|
||||||
val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML)
|
val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance)
|
||||||
assert(arr.length === 1000)
|
assert(arr.length === 1000)
|
||||||
val rdd = sc.parallelize(arr)
|
val rdd = sc.parallelize(arr)
|
||||||
val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2,
|
val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2,
|
||||||
|
@ -93,12 +94,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
test("find splits for a continuous feature") {
|
test("find splits for a continuous feature") {
|
||||||
// find splits for normal case
|
// find splits for normal case
|
||||||
{
|
{
|
||||||
val fakeMetadata = new DecisionTreeMetadata(1, 200000, 0, 0,
|
val fakeMetadata = new DecisionTreeMetadata(1, 200000, 200000.0, 0, 0,
|
||||||
Map(), Set(),
|
Map(), Set(),
|
||||||
Array(6), Gini, QuantileStrategy.Sort,
|
Array(6), Gini, QuantileStrategy.Sort,
|
||||||
0, 0, 0.0, 0, 0
|
0, 0, 0.0, 0.0, 0, 0
|
||||||
)
|
)
|
||||||
val featureSamples = Array.fill(10000)(math.random).filter(_ != 0.0)
|
val featureSamples = Array.fill(10000)((1.0, math.random)).filter(_._2 != 0.0)
|
||||||
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
||||||
assert(splits.length === 5)
|
assert(splits.length === 5)
|
||||||
assert(fakeMetadata.numSplits(0) === 5)
|
assert(fakeMetadata.numSplits(0) === 5)
|
||||||
|
@ -109,15 +110,16 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
|
|
||||||
// SPARK-16957: Use midpoints for split values.
|
// SPARK-16957: Use midpoints for split values.
|
||||||
{
|
{
|
||||||
val fakeMetadata = new DecisionTreeMetadata(1, 8, 0, 0,
|
val fakeMetadata = new DecisionTreeMetadata(1, 8, 8.0, 0, 0,
|
||||||
Map(), Set(),
|
Map(), Set(),
|
||||||
Array(3), Gini, QuantileStrategy.Sort,
|
Array(3), Gini, QuantileStrategy.Sort,
|
||||||
0, 0, 0.0, 0, 0
|
0, 0, 0.0, 0.0, 0, 0
|
||||||
)
|
)
|
||||||
|
|
||||||
// possibleSplits <= numSplits
|
// possibleSplits <= numSplits
|
||||||
{
|
{
|
||||||
val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble).filter(_ != 0.0)
|
val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1)
|
||||||
|
.map(x => (1.0, x.toDouble)).filter(_._2 != 0.0)
|
||||||
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
||||||
val expectedSplits = Array((0.0 + 1.0) / 2)
|
val expectedSplits = Array((0.0 + 1.0) / 2)
|
||||||
assert(splits === expectedSplits)
|
assert(splits === expectedSplits)
|
||||||
|
@ -125,7 +127,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
|
|
||||||
// possibleSplits > numSplits
|
// possibleSplits > numSplits
|
||||||
{
|
{
|
||||||
val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble).filter(_ != 0.0)
|
val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3)
|
||||||
|
.map(x => (1.0, x.toDouble)).filter(_._2 != 0.0)
|
||||||
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
||||||
val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2)
|
val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2)
|
||||||
assert(splits === expectedSplits)
|
assert(splits === expectedSplits)
|
||||||
|
@ -135,12 +138,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
// find splits should not return identical splits
|
// find splits should not return identical splits
|
||||||
// when there are not enough split candidates, reduce the number of splits in metadata
|
// when there are not enough split candidates, reduce the number of splits in metadata
|
||||||
{
|
{
|
||||||
val fakeMetadata = new DecisionTreeMetadata(1, 12, 0, 0,
|
val fakeMetadata = new DecisionTreeMetadata(1, 12, 12.0, 0, 0,
|
||||||
Map(), Set(),
|
Map(), Set(),
|
||||||
Array(5), Gini, QuantileStrategy.Sort,
|
Array(5), Gini, QuantileStrategy.Sort,
|
||||||
0, 0, 0.0, 0, 0
|
0, 0, 0.0, 0.0, 0, 0
|
||||||
)
|
)
|
||||||
val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(_.toDouble)
|
val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(x => (1.0, x.toDouble))
|
||||||
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
||||||
val expectedSplits = Array((1.0 + 2.0) / 2, (2.0 + 3.0) / 2)
|
val expectedSplits = Array((1.0 + 2.0) / 2, (2.0 + 3.0) / 2)
|
||||||
assert(splits === expectedSplits)
|
assert(splits === expectedSplits)
|
||||||
|
@ -150,13 +153,13 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
|
|
||||||
// find splits when most samples close to the minimum
|
// find splits when most samples close to the minimum
|
||||||
{
|
{
|
||||||
val fakeMetadata = new DecisionTreeMetadata(1, 18, 0, 0,
|
val fakeMetadata = new DecisionTreeMetadata(1, 18, 18.0, 0, 0,
|
||||||
Map(), Set(),
|
Map(), Set(),
|
||||||
Array(3), Gini, QuantileStrategy.Sort,
|
Array(3), Gini, QuantileStrategy.Sort,
|
||||||
0, 0, 0.0, 0, 0
|
0, 0, 0.0, 0.0, 0, 0
|
||||||
)
|
)
|
||||||
val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5)
|
val featureSamples =
|
||||||
.map(_.toDouble)
|
Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(x => (1.0, x.toDouble))
|
||||||
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
||||||
val expectedSplits = Array((2.0 + 3.0) / 2, (3.0 + 4.0) / 2)
|
val expectedSplits = Array((2.0 + 3.0) / 2, (3.0 + 4.0) / 2)
|
||||||
assert(splits === expectedSplits)
|
assert(splits === expectedSplits)
|
||||||
|
@ -164,37 +167,55 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
|
|
||||||
// find splits when most samples close to the maximum
|
// find splits when most samples close to the maximum
|
||||||
{
|
{
|
||||||
val fakeMetadata = new DecisionTreeMetadata(1, 17, 0, 0,
|
val fakeMetadata = new DecisionTreeMetadata(1, 17, 17.0, 0, 0,
|
||||||
Map(), Set(),
|
Map(), Set(),
|
||||||
Array(2), Gini, QuantileStrategy.Sort,
|
Array(2), Gini, QuantileStrategy.Sort,
|
||||||
0, 0, 0.0, 0, 0
|
0, 0, 0.0, 0.0, 0, 0
|
||||||
)
|
)
|
||||||
val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)
|
val featureSamples =
|
||||||
.map(_.toDouble).filter(_ != 0.0)
|
Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(x => (1.0, x.toDouble))
|
||||||
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
||||||
val expectedSplits = Array((1.0 + 2.0) / 2)
|
val expectedSplits = Array((1.0 + 2.0) / 2)
|
||||||
assert(splits === expectedSplits)
|
assert(splits === expectedSplits)
|
||||||
}
|
}
|
||||||
|
|
||||||
// find splits for constant feature
|
// find splits for arbitrarily scaled data
|
||||||
{
|
{
|
||||||
val fakeMetadata = new DecisionTreeMetadata(1, 3, 0, 0,
|
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
|
||||||
|
Map(), Set(),
|
||||||
|
Array(6), Gini, QuantileStrategy.Sort,
|
||||||
|
0, 0, 0.0, 0.0, 0, 0
|
||||||
|
)
|
||||||
|
val featureSamplesUnitWeight = Array.fill(10)((1.0, math.random))
|
||||||
|
val featureSamplesSmallWeight = featureSamplesUnitWeight.map { case (w, x) => (w * 0.001, x)}
|
||||||
|
val featureSamplesLargeWeight = featureSamplesUnitWeight.map { case (w, x) => (w * 1000, x)}
|
||||||
|
val splitsUnitWeight = RandomForest
|
||||||
|
.findSplitsForContinuousFeature(featureSamplesUnitWeight, fakeMetadata, 0)
|
||||||
|
val splitsSmallWeight = RandomForest
|
||||||
|
.findSplitsForContinuousFeature(featureSamplesSmallWeight, fakeMetadata, 0)
|
||||||
|
val splitsLargeWeight = RandomForest
|
||||||
|
.findSplitsForContinuousFeature(featureSamplesLargeWeight, fakeMetadata, 0)
|
||||||
|
assert(splitsUnitWeight === splitsSmallWeight)
|
||||||
|
assert(splitsUnitWeight === splitsLargeWeight)
|
||||||
|
}
|
||||||
|
|
||||||
|
// find splits when most weight is close to the minimum
|
||||||
|
{
|
||||||
|
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
|
||||||
Map(), Set(),
|
Map(), Set(),
|
||||||
Array(3), Gini, QuantileStrategy.Sort,
|
Array(3), Gini, QuantileStrategy.Sort,
|
||||||
0, 0, 0.0, 0, 0
|
0, 0, 0.0, 0.0, 0, 0
|
||||||
)
|
)
|
||||||
val featureSamples = Array(0, 0, 0).map(_.toDouble).filter(_ != 0.0)
|
val featureSamples = Array((10, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6)).map {
|
||||||
val featureSamplesEmpty = Array.empty[Double]
|
case (w, x) => (w.toDouble, x.toDouble)
|
||||||
|
}
|
||||||
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
|
||||||
assert(splits === Array.empty[Double])
|
assert(splits === Array(1.5, 2.5, 3.5, 4.5, 5.5))
|
||||||
val splitsEmpty =
|
|
||||||
RandomForest.findSplitsForContinuousFeature(featureSamplesEmpty, fakeMetadata, 0)
|
|
||||||
assert(splitsEmpty === Array.empty[Double])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
test("train with empty arrays") {
|
test("train with empty arrays") {
|
||||||
val lp = LabeledPoint(1.0, Vectors.dense(Array.empty[Double]))
|
val lp = LabeledPoint(1.0, Vectors.dense(Array.empty[Double])).toInstance
|
||||||
val data = Array.fill(5)(lp)
|
val data = Array.fill(5)(lp)
|
||||||
val rdd = sc.parallelize(data)
|
val rdd = sc.parallelize(data)
|
||||||
|
|
||||||
|
@ -209,8 +230,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
}
|
}
|
||||||
|
|
||||||
test("train with constant features") {
|
test("train with constant features") {
|
||||||
val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0))
|
val instance = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0)).toInstance
|
||||||
val data = Array.fill(5)(lp)
|
val data = Array.fill(5)(instance)
|
||||||
val rdd = sc.parallelize(data)
|
val rdd = sc.parallelize(data)
|
||||||
val strategy = new OldStrategy(
|
val strategy = new OldStrategy(
|
||||||
OldAlgo.Classification,
|
OldAlgo.Classification,
|
||||||
|
@ -222,7 +243,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)
|
val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)
|
||||||
assert(tree.rootNode.impurity === -1.0)
|
assert(tree.rootNode.impurity === -1.0)
|
||||||
assert(tree.depth === 0)
|
assert(tree.depth === 0)
|
||||||
assert(tree.rootNode.prediction === lp.label)
|
assert(tree.rootNode.prediction === instance.label)
|
||||||
|
|
||||||
// Test with no categorical features
|
// Test with no categorical features
|
||||||
val strategy2 = new OldStrategy(
|
val strategy2 = new OldStrategy(
|
||||||
|
@ -233,11 +254,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
val Array(tree2) = RandomForest.run(rdd, strategy2, 1, "all", 42L, instr = None)
|
val Array(tree2) = RandomForest.run(rdd, strategy2, 1, "all", 42L, instr = None)
|
||||||
assert(tree2.rootNode.impurity === -1.0)
|
assert(tree2.rootNode.impurity === -1.0)
|
||||||
assert(tree2.depth === 0)
|
assert(tree2.depth === 0)
|
||||||
assert(tree2.rootNode.prediction === lp.label)
|
assert(tree2.rootNode.prediction === instance.label)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Multiclass classification with unordered categorical features: split calculations") {
|
test("Multiclass classification with unordered categorical features: split calculations") {
|
||||||
val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML)
|
val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance)
|
||||||
assert(arr.length === 1000)
|
assert(arr.length === 1000)
|
||||||
val rdd = sc.parallelize(arr)
|
val rdd = sc.parallelize(arr)
|
||||||
val strategy = new OldStrategy(
|
val strategy = new OldStrategy(
|
||||||
|
@ -278,7 +299,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Multiclass classification with ordered categorical features: split calculations") {
|
test("Multiclass classification with ordered categorical features: split calculations") {
|
||||||
val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures().map(_.asML)
|
val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
|
||||||
|
.map(_.asML.toInstance)
|
||||||
assert(arr.length === 3000)
|
assert(arr.length === 3000)
|
||||||
val rdd = sc.parallelize(arr)
|
val rdd = sc.parallelize(arr)
|
||||||
val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 100,
|
val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 100,
|
||||||
|
@ -310,7 +332,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
|
LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
|
||||||
LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
|
LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
|
||||||
LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
|
LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
|
||||||
val input = sc.parallelize(arr)
|
val input = sc.parallelize(arr.map(_.toInstance))
|
||||||
|
|
||||||
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1,
|
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1,
|
||||||
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
|
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
|
||||||
|
@ -352,7 +374,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
|
LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
|
||||||
LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
|
LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
|
||||||
LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
|
LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
|
||||||
val input = sc.parallelize(arr)
|
val input = sc.parallelize(arr.map(_.toInstance))
|
||||||
|
|
||||||
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 5,
|
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 5,
|
||||||
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
|
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
|
||||||
|
@ -404,7 +426,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
LabeledPoint(0.0, Vectors.dense(2.0)),
|
LabeledPoint(0.0, Vectors.dense(2.0)),
|
||||||
LabeledPoint(0.0, Vectors.dense(2.0)),
|
LabeledPoint(0.0, Vectors.dense(2.0)),
|
||||||
LabeledPoint(1.0, Vectors.dense(2.0)))
|
LabeledPoint(1.0, Vectors.dense(2.0)))
|
||||||
val input = sc.parallelize(arr)
|
val input = sc.parallelize(arr.map(_.toInstance))
|
||||||
|
|
||||||
// Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
|
// Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
|
||||||
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1,
|
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1,
|
||||||
|
@ -424,7 +446,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Second level node building with vs. without groups") {
|
test("Second level node building with vs. without groups") {
|
||||||
val arr = OldDTSuite.generateOrderedLabeledPoints().map(_.asML)
|
val arr = OldDTSuite.generateOrderedLabeledPoints().map(_.asML.toInstance)
|
||||||
assert(arr.length === 1000)
|
assert(arr.length === 1000)
|
||||||
val rdd = sc.parallelize(arr)
|
val rdd = sc.parallelize(arr)
|
||||||
// For tree with 1 group
|
// For tree with 1 group
|
||||||
|
@ -468,7 +490,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: OldStrategy) {
|
def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: OldStrategy) {
|
||||||
val numFeatures = 50
|
val numFeatures = 50
|
||||||
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000)
|
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000)
|
||||||
val rdd = sc.parallelize(arr).map(_.asML)
|
val rdd = sc.parallelize(arr).map(_.asML.toInstance)
|
||||||
|
|
||||||
// Select feature subset for top nodes. Return true if OK.
|
// Select feature subset for top nodes. Return true if OK.
|
||||||
def checkFeatureSubsetStrategy(
|
def checkFeatureSubsetStrategy(
|
||||||
|
@ -581,16 +603,16 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
left2 parent
|
left2 parent
|
||||||
left right
|
left right
|
||||||
*/
|
*/
|
||||||
val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0))
|
val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0), 6L)
|
||||||
val left = new LeafNode(0.0, leftImp.calculate(), leftImp)
|
val left = new LeafNode(0.0, leftImp.calculate(), leftImp)
|
||||||
|
|
||||||
val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0))
|
val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0), 8L)
|
||||||
val right = new LeafNode(2.0, rightImp.calculate(), rightImp)
|
val right = new LeafNode(2.0, rightImp.calculate(), rightImp)
|
||||||
|
|
||||||
val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5))
|
val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5))
|
||||||
val parentImp = parent.impurityStats
|
val parentImp = parent.impurityStats
|
||||||
|
|
||||||
val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0))
|
val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0), 8L)
|
||||||
val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp)
|
val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp)
|
||||||
|
|
||||||
val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0))
|
val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0))
|
||||||
|
@ -647,12 +669,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
// feature_0 = 0 improves the impurity measure, despite the prediction will always be 0
|
// feature_0 = 0 improves the impurity measure, despite the prediction will always be 0
|
||||||
// in both branches.
|
// in both branches.
|
||||||
val arr = Array(
|
val arr = Array(
|
||||||
LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
|
Instance(0.0, 1.0, Vectors.dense(0.0, 1.0)),
|
||||||
LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
|
Instance(1.0, 1.0, Vectors.dense(0.0, 1.0)),
|
||||||
LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
|
Instance(0.0, 1.0, Vectors.dense(0.0, 0.0)),
|
||||||
LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
|
Instance(1.0, 1.0, Vectors.dense(1.0, 0.0)),
|
||||||
LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
|
Instance(0.0, 1.0, Vectors.dense(1.0, 0.0)),
|
||||||
LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
|
Instance(1.0, 1.0, Vectors.dense(1.0, 1.0))
|
||||||
)
|
)
|
||||||
val rdd = sc.parallelize(arr)
|
val rdd = sc.parallelize(arr)
|
||||||
|
|
||||||
|
@ -677,13 +699,13 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
// feature_1 = 1 improves the impurity measure, despite the prediction will always be 0.5
|
// feature_1 = 1 improves the impurity measure, despite the prediction will always be 0.5
|
||||||
// in both branches.
|
// in both branches.
|
||||||
val arr = Array(
|
val arr = Array(
|
||||||
LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
|
Instance(0.0, 1.0, Vectors.dense(0.0, 1.0)),
|
||||||
LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
|
Instance(1.0, 1.0, Vectors.dense(0.0, 1.0)),
|
||||||
LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
|
Instance(0.0, 1.0, Vectors.dense(0.0, 0.0)),
|
||||||
LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
|
Instance(0.0, 1.0, Vectors.dense(1.0, 0.0)),
|
||||||
LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
|
Instance(1.0, 1.0, Vectors.dense(1.0, 1.0)),
|
||||||
LabeledPoint(0.0, Vectors.dense(1.0, 1.0)),
|
Instance(0.0, 1.0, Vectors.dense(1.0, 1.0)),
|
||||||
LabeledPoint(0.5, Vectors.dense(1.0, 1.0))
|
Instance(0.5, 1.0, Vectors.dense(1.0, 1.0))
|
||||||
)
|
)
|
||||||
val rdd = sc.parallelize(arr)
|
val rdd = sc.parallelize(arr)
|
||||||
|
|
||||||
|
@ -700,6 +722,56 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
assert(unprunedTree.numNodes === 5)
|
assert(unprunedTree.numNodes === 5)
|
||||||
assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size)
|
assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("weights at arbitrary scale") {
|
||||||
|
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(3, 10)
|
||||||
|
val rddWithUnitWeights = sc.parallelize(arr.map(_.asML.toInstance))
|
||||||
|
val rddWithSmallWeights = rddWithUnitWeights.map { inst =>
|
||||||
|
Instance(inst.label, 0.001, inst.features)
|
||||||
|
}
|
||||||
|
val rddWithBigWeights = rddWithUnitWeights.map { inst =>
|
||||||
|
Instance(inst.label, 1000, inst.features)
|
||||||
|
}
|
||||||
|
val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2)
|
||||||
|
val unitWeightTrees = RandomForest.run(rddWithUnitWeights, strategy, 3, "all", 42L, None)
|
||||||
|
|
||||||
|
val smallWeightTrees = RandomForest.run(rddWithSmallWeights, strategy, 3, "all", 42L, None)
|
||||||
|
unitWeightTrees.zip(smallWeightTrees).foreach { case (unitTree, smallWeightTree) =>
|
||||||
|
TreeTests.checkEqual(unitTree, smallWeightTree)
|
||||||
|
}
|
||||||
|
|
||||||
|
val bigWeightTrees = RandomForest.run(rddWithBigWeights, strategy, 3, "all", 42L, None)
|
||||||
|
unitWeightTrees.zip(bigWeightTrees).foreach { case (unitTree, bigWeightTree) =>
|
||||||
|
TreeTests.checkEqual(unitTree, bigWeightTree)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("minWeightFraction and minInstancesPerNode") {
|
||||||
|
val data = Array(
|
||||||
|
Instance(0.0, 1.0, Vectors.dense(0.0)),
|
||||||
|
Instance(0.0, 1.0, Vectors.dense(0.0)),
|
||||||
|
Instance(0.0, 1.0, Vectors.dense(0.0)),
|
||||||
|
Instance(0.0, 1.0, Vectors.dense(0.0)),
|
||||||
|
Instance(1.0, 0.1, Vectors.dense(1.0))
|
||||||
|
)
|
||||||
|
val rdd = sc.parallelize(data)
|
||||||
|
val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2,
|
||||||
|
minWeightFractionPerNode = 0.5)
|
||||||
|
val Array(tree1) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
|
||||||
|
assert(tree1.depth === 0)
|
||||||
|
|
||||||
|
strategy.minWeightFractionPerNode = 0.0
|
||||||
|
val Array(tree2) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
|
||||||
|
assert(tree2.depth === 1)
|
||||||
|
|
||||||
|
strategy.minInstancesPerNode = 2
|
||||||
|
val Array(tree3) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
|
||||||
|
assert(tree3.depth === 0)
|
||||||
|
|
||||||
|
strategy.minInstancesPerNode = 1
|
||||||
|
val Array(tree4) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
|
||||||
|
assert(tree4.depth === 1)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private object RandomForestSuite {
|
private object RandomForestSuite {
|
||||||
|
@ -717,7 +789,7 @@ private object RandomForestSuite {
|
||||||
else {
|
else {
|
||||||
nodes.head match {
|
nodes.head match {
|
||||||
case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc)
|
case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc)
|
||||||
case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.count)
|
case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.rawCount)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,7 +28,7 @@ class TreePointSuite extends SparkFunSuite {
|
||||||
|
|
||||||
val ser = new KryoSerializer(conf).newInstance()
|
val ser = new KryoSerializer(conf).newInstance()
|
||||||
|
|
||||||
val point = new TreePoint(1.0, Array(1, 2, 3))
|
val point = new TreePoint(1.0, Array(1, 2, 3), 1.0)
|
||||||
val point2 = ser.deserialize[TreePoint](ser.serialize(point))
|
val point2 = ser.deserialize[TreePoint](ser.serialize(point))
|
||||||
assert(point.label === point2.label)
|
assert(point.label === point2.label)
|
||||||
assert(point.binnedFeatures === point2.binnedFeatures)
|
assert(point.binnedFeatures === point2.binnedFeatures)
|
||||||
|
|
|
@ -18,13 +18,15 @@
|
||||||
package org.apache.spark.ml.tree.impl
|
package org.apache.spark.ml.tree.impl
|
||||||
|
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
|
import scala.util.Random
|
||||||
|
|
||||||
import org.apache.spark.{SparkContext, SparkFunSuite}
|
import org.apache.spark.{SparkContext, SparkFunSuite}
|
||||||
import org.apache.spark.api.java.JavaRDD
|
import org.apache.spark.api.java.JavaRDD
|
||||||
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
|
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
|
||||||
import org.apache.spark.ml.feature.LabeledPoint
|
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
|
||||||
import org.apache.spark.ml.linalg.Vectors
|
import org.apache.spark.ml.linalg.Vectors
|
||||||
import org.apache.spark.ml.tree._
|
import org.apache.spark.ml.tree._
|
||||||
|
import org.apache.spark.mllib.util.TestingUtils._
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.{DataFrame, SparkSession}
|
import org.apache.spark.sql.{DataFrame, SparkSession}
|
||||||
|
|
||||||
|
@ -32,6 +34,7 @@ private[ml] object TreeTests extends SparkFunSuite {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert the given data to a DataFrame, and set the features and label metadata.
|
* Convert the given data to a DataFrame, and set the features and label metadata.
|
||||||
|
*
|
||||||
* @param data Dataset. Categorical features and labels must already have 0-based indices.
|
* @param data Dataset. Categorical features and labels must already have 0-based indices.
|
||||||
* This must be non-empty.
|
* This must be non-empty.
|
||||||
* @param categoricalFeatures Map: categorical feature index to number of distinct values
|
* @param categoricalFeatures Map: categorical feature index to number of distinct values
|
||||||
|
@ -39,16 +42,22 @@ private[ml] object TreeTests extends SparkFunSuite {
|
||||||
* @return DataFrame with metadata
|
* @return DataFrame with metadata
|
||||||
*/
|
*/
|
||||||
def setMetadata(
|
def setMetadata(
|
||||||
data: RDD[LabeledPoint],
|
data: RDD[_],
|
||||||
categoricalFeatures: Map[Int, Int],
|
categoricalFeatures: Map[Int, Int],
|
||||||
numClasses: Int): DataFrame = {
|
numClasses: Int): DataFrame = {
|
||||||
|
val dataOfInstance: RDD[Instance] = data.map {
|
||||||
|
_ match {
|
||||||
|
case instance: Instance => instance
|
||||||
|
case labeledPoint: LabeledPoint => labeledPoint.toInstance
|
||||||
|
}
|
||||||
|
}
|
||||||
val spark = SparkSession.builder()
|
val spark = SparkSession.builder()
|
||||||
.sparkContext(data.sparkContext)
|
.sparkContext(data.sparkContext)
|
||||||
.getOrCreate()
|
.getOrCreate()
|
||||||
import spark.implicits._
|
import spark.implicits._
|
||||||
|
|
||||||
val df = data.toDF()
|
val df = dataOfInstance.toDF()
|
||||||
val numFeatures = data.first().features.size
|
val numFeatures = dataOfInstance.first().features.size
|
||||||
val featuresAttributes = Range(0, numFeatures).map { feature =>
|
val featuresAttributes = Range(0, numFeatures).map { feature =>
|
||||||
if (categoricalFeatures.contains(feature)) {
|
if (categoricalFeatures.contains(feature)) {
|
||||||
NominalAttribute.defaultAttr.withIndex(feature).withNumValues(categoricalFeatures(feature))
|
NominalAttribute.defaultAttr.withIndex(feature).withNumValues(categoricalFeatures(feature))
|
||||||
|
@ -64,7 +73,7 @@ private[ml] object TreeTests extends SparkFunSuite {
|
||||||
}
|
}
|
||||||
val labelMetadata = labelAttribute.toMetadata()
|
val labelMetadata = labelAttribute.toMetadata()
|
||||||
df.select(df("features").as("features", featuresMetadata),
|
df.select(df("features").as("features", featuresMetadata),
|
||||||
df("label").as("label", labelMetadata))
|
df("label").as("label", labelMetadata), df("weight"))
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -80,6 +89,7 @@ private[ml] object TreeTests extends SparkFunSuite {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set label metadata (particularly the number of classes) on a DataFrame.
|
* Set label metadata (particularly the number of classes) on a DataFrame.
|
||||||
|
*
|
||||||
* @param data Dataset. Categorical features and labels must already have 0-based indices.
|
* @param data Dataset. Categorical features and labels must already have 0-based indices.
|
||||||
* This must be non-empty.
|
* This must be non-empty.
|
||||||
* @param numClasses Number of classes label can take. If 0, mark as continuous.
|
* @param numClasses Number of classes label can take. If 0, mark as continuous.
|
||||||
|
@ -124,8 +134,8 @@ private[ml] object TreeTests extends SparkFunSuite {
|
||||||
* make mistakes such as creating loops of Nodes.
|
* make mistakes such as creating loops of Nodes.
|
||||||
*/
|
*/
|
||||||
private def checkEqual(a: Node, b: Node): Unit = {
|
private def checkEqual(a: Node, b: Node): Unit = {
|
||||||
assert(a.prediction === b.prediction)
|
assert(a.prediction ~== b.prediction absTol 1e-8)
|
||||||
assert(a.impurity === b.impurity)
|
assert(a.impurity ~== b.impurity absTol 1e-8)
|
||||||
(a, b) match {
|
(a, b) match {
|
||||||
case (aye: InternalNode, bee: InternalNode) =>
|
case (aye: InternalNode, bee: InternalNode) =>
|
||||||
assert(aye.split === bee.split)
|
assert(aye.split === bee.split)
|
||||||
|
@ -156,6 +166,7 @@ private[ml] object TreeTests extends SparkFunSuite {
|
||||||
/**
|
/**
|
||||||
* Helper method for constructing a tree for testing.
|
* Helper method for constructing a tree for testing.
|
||||||
* Given left, right children, construct a parent node.
|
* Given left, right children, construct a parent node.
|
||||||
|
*
|
||||||
* @param split Split for parent node
|
* @param split Split for parent node
|
||||||
* @return Parent node with children attached
|
* @return Parent node with children attached
|
||||||
*/
|
*/
|
||||||
|
@ -163,8 +174,8 @@ private[ml] object TreeTests extends SparkFunSuite {
|
||||||
val leftImp = left.impurityStats
|
val leftImp = left.impurityStats
|
||||||
val rightImp = right.impurityStats
|
val rightImp = right.impurityStats
|
||||||
val parentImp = leftImp.copy.add(rightImp)
|
val parentImp = leftImp.copy.add(rightImp)
|
||||||
val leftWeight = leftImp.count / parentImp.count.toDouble
|
val leftWeight = leftImp.count / parentImp.count
|
||||||
val rightWeight = rightImp.count / parentImp.count.toDouble
|
val rightWeight = rightImp.count / parentImp.count
|
||||||
val gain = parentImp.calculate() -
|
val gain = parentImp.calculate() -
|
||||||
(leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate())
|
(leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate())
|
||||||
val pred = parentImp.predict
|
val pred = parentImp.predict
|
||||||
|
|
|
@ -23,7 +23,7 @@ import org.apache.spark.ml.evaluation.Evaluator
|
||||||
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
|
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
|
||||||
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
|
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
|
||||||
import org.apache.spark.ml.param.ParamMap
|
import org.apache.spark.ml.param.ParamMap
|
||||||
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol}
|
import org.apache.spark.ml.param.shared.HasWeightCol
|
||||||
import org.apache.spark.ml.recommendation.{ALS, ALSModel}
|
import org.apache.spark.ml.recommendation.{ALS, ALSModel}
|
||||||
import org.apache.spark.ml.tree.impl.TreeTests
|
import org.apache.spark.ml.tree.impl.TreeTests
|
||||||
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
|
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
|
||||||
|
@ -205,8 +205,8 @@ object MLTestingUtils extends SparkFunSuite {
|
||||||
seed: Long): Unit = {
|
seed: Long): Unit = {
|
||||||
val (overSampledData, weightedData) = genEquivalentOversampledAndWeightedInstances(
|
val (overSampledData, weightedData) = genEquivalentOversampledAndWeightedInstances(
|
||||||
data, seed)
|
data, seed)
|
||||||
val weightedModel = estimator.set(estimator.weightCol, "weight").fit(weightedData)
|
|
||||||
val overSampledModel = estimator.set(estimator.weightCol, "").fit(overSampledData)
|
val overSampledModel = estimator.set(estimator.weightCol, "").fit(overSampledData)
|
||||||
|
val weightedModel = estimator.set(estimator.weightCol, "weight").fit(weightedData)
|
||||||
modelEquals(weightedModel, overSampledModel)
|
modelEquals(weightedModel, overSampledModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -228,7 +228,8 @@ object MLTestingUtils extends SparkFunSuite {
|
||||||
List.fill(outlierRatio)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f))
|
List.fill(outlierRatio)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f))
|
||||||
}
|
}
|
||||||
val trueModel = estimator.set(estimator.weightCol, "").fit(data)
|
val trueModel = estimator.set(estimator.weightCol, "").fit(data)
|
||||||
val outlierModel = estimator.set(estimator.weightCol, "weight").fit(outlierDS)
|
val outlierModel = estimator.set(estimator.weightCol, "weight")
|
||||||
|
.fit(outlierDS)
|
||||||
modelEquals(trueModel, outlierModel)
|
modelEquals(trueModel, outlierModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -241,7 +242,7 @@ object MLTestingUtils extends SparkFunSuite {
|
||||||
estimator: E with HasWeightCol,
|
estimator: E with HasWeightCol,
|
||||||
modelEquals: (M, M) => Unit): Unit = {
|
modelEquals: (M, M) => Unit): Unit = {
|
||||||
estimator.set(estimator.weightCol, "weight")
|
estimator.set(estimator.weightCol, "weight")
|
||||||
val models = Seq(0.001, 1.0, 1000.0).map { w =>
|
val models = Seq(0.01, 1.0, 1000.0).map { w =>
|
||||||
val df = data.withColumn("weight", lit(w))
|
val df = data.withColumn("weight", lit(w))
|
||||||
estimator.fit(df)
|
estimator.fit(df)
|
||||||
}
|
}
|
||||||
|
@ -268,4 +269,20 @@ object MLTestingUtils extends SparkFunSuite {
|
||||||
assert(newDatasetF.schema(featuresColName).dataType.equals(new ArrayType(FloatType, false)))
|
assert(newDatasetF.schema(featuresColName).dataType.equals(new ArrayType(FloatType, false)))
|
||||||
(newDataset, newDatasetD, newDatasetF)
|
(newDataset, newDatasetD, newDatasetF)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def modelPredictionEquals[M <: PredictionModel[_, M]](
|
||||||
|
data: DataFrame,
|
||||||
|
compareFunc: (Double, Double) => Boolean,
|
||||||
|
fractionInTol: Double)(
|
||||||
|
model1: M,
|
||||||
|
model2: M): Unit = {
|
||||||
|
val pred1 = model1.transform(data).select(model1.getPredictionCol).collect()
|
||||||
|
val pred2 = model2.transform(data).select(model2.getPredictionCol).collect()
|
||||||
|
val inTol = pred1.zip(pred2).count { case (p1, p2) =>
|
||||||
|
val x = p1.getDouble(0)
|
||||||
|
val y = p2.getDouble(0)
|
||||||
|
compareFunc(x, y)
|
||||||
|
}
|
||||||
|
assert(inTol / pred1.length.toDouble >= fractionInTol)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -73,7 +73,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
maxBins = 100,
|
maxBins = 100,
|
||||||
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
|
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
|
||||||
|
|
||||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
|
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
|
||||||
assert(!metadata.isUnordered(featureIndex = 0))
|
assert(!metadata.isUnordered(featureIndex = 0))
|
||||||
assert(!metadata.isUnordered(featureIndex = 1))
|
assert(!metadata.isUnordered(featureIndex = 1))
|
||||||
|
|
||||||
|
@ -100,7 +100,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
maxDepth = 2,
|
maxDepth = 2,
|
||||||
maxBins = 100,
|
maxBins = 100,
|
||||||
categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2))
|
categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2))
|
||||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
|
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
|
||||||
assert(!metadata.isUnordered(featureIndex = 0))
|
assert(!metadata.isUnordered(featureIndex = 0))
|
||||||
assert(!metadata.isUnordered(featureIndex = 1))
|
assert(!metadata.isUnordered(featureIndex = 1))
|
||||||
|
|
||||||
|
@ -116,7 +116,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
val rdd = sc.parallelize(arr)
|
val rdd = sc.parallelize(arr)
|
||||||
val strategy = new Strategy(Classification, Gini, maxDepth = 3,
|
val strategy = new Strategy(Classification, Gini, maxDepth = 3,
|
||||||
numClasses = 2, maxBins = 100)
|
numClasses = 2, maxBins = 100)
|
||||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
|
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
|
||||||
assert(!metadata.isUnordered(featureIndex = 0))
|
assert(!metadata.isUnordered(featureIndex = 0))
|
||||||
assert(!metadata.isUnordered(featureIndex = 1))
|
assert(!metadata.isUnordered(featureIndex = 1))
|
||||||
|
|
||||||
|
@ -133,7 +133,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
val rdd = sc.parallelize(arr)
|
val rdd = sc.parallelize(arr)
|
||||||
val strategy = new Strategy(Classification, Gini, maxDepth = 3,
|
val strategy = new Strategy(Classification, Gini, maxDepth = 3,
|
||||||
numClasses = 2, maxBins = 100)
|
numClasses = 2, maxBins = 100)
|
||||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
|
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
|
||||||
assert(!metadata.isUnordered(featureIndex = 0))
|
assert(!metadata.isUnordered(featureIndex = 0))
|
||||||
assert(!metadata.isUnordered(featureIndex = 1))
|
assert(!metadata.isUnordered(featureIndex = 1))
|
||||||
|
|
||||||
|
@ -150,7 +150,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
val rdd = sc.parallelize(arr)
|
val rdd = sc.parallelize(arr)
|
||||||
val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
|
val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
|
||||||
numClasses = 2, maxBins = 100)
|
numClasses = 2, maxBins = 100)
|
||||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
|
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
|
||||||
assert(!metadata.isUnordered(featureIndex = 0))
|
assert(!metadata.isUnordered(featureIndex = 0))
|
||||||
assert(!metadata.isUnordered(featureIndex = 1))
|
assert(!metadata.isUnordered(featureIndex = 1))
|
||||||
|
|
||||||
|
@ -167,7 +167,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
val rdd = sc.parallelize(arr)
|
val rdd = sc.parallelize(arr)
|
||||||
val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
|
val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
|
||||||
numClasses = 2, maxBins = 100)
|
numClasses = 2, maxBins = 100)
|
||||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
|
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
|
||||||
assert(!metadata.isUnordered(featureIndex = 0))
|
assert(!metadata.isUnordered(featureIndex = 0))
|
||||||
assert(!metadata.isUnordered(featureIndex = 1))
|
assert(!metadata.isUnordered(featureIndex = 1))
|
||||||
|
|
||||||
|
@ -183,7 +183,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
val rdd = sc.parallelize(arr)
|
val rdd = sc.parallelize(arr)
|
||||||
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
|
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
|
||||||
numClasses = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
|
numClasses = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
|
||||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
|
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
|
||||||
assert(strategy.isMulticlassClassification)
|
assert(strategy.isMulticlassClassification)
|
||||||
assert(metadata.isUnordered(featureIndex = 0))
|
assert(metadata.isUnordered(featureIndex = 0))
|
||||||
assert(metadata.isUnordered(featureIndex = 1))
|
assert(metadata.isUnordered(featureIndex = 1))
|
||||||
|
@ -240,7 +240,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
numClasses = 3, maxBins = maxBins,
|
numClasses = 3, maxBins = maxBins,
|
||||||
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
|
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
|
||||||
assert(strategy.isMulticlassClassification)
|
assert(strategy.isMulticlassClassification)
|
||||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
|
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
|
||||||
assert(metadata.isUnordered(featureIndex = 0))
|
assert(metadata.isUnordered(featureIndex = 0))
|
||||||
assert(metadata.isUnordered(featureIndex = 1))
|
assert(metadata.isUnordered(featureIndex = 1))
|
||||||
|
|
||||||
|
@ -288,7 +288,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
|
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
|
||||||
numClasses = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3))
|
numClasses = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3))
|
||||||
assert(strategy.isMulticlassClassification)
|
assert(strategy.isMulticlassClassification)
|
||||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
|
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
|
||||||
assert(metadata.isUnordered(featureIndex = 0))
|
assert(metadata.isUnordered(featureIndex = 0))
|
||||||
|
|
||||||
val model = DecisionTree.train(rdd, strategy)
|
val model = DecisionTree.train(rdd, strategy)
|
||||||
|
@ -310,7 +310,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
numClasses = 3, maxBins = 100,
|
numClasses = 3, maxBins = 100,
|
||||||
categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
|
categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
|
||||||
assert(strategy.isMulticlassClassification)
|
assert(strategy.isMulticlassClassification)
|
||||||
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
|
val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
|
||||||
assert(!metadata.isUnordered(featureIndex = 0))
|
assert(!metadata.isUnordered(featureIndex = 0))
|
||||||
assert(!metadata.isUnordered(featureIndex = 1))
|
assert(!metadata.isUnordered(featureIndex = 1))
|
||||||
|
|
||||||
|
|
|
@ -18,23 +18,63 @@
|
||||||
package org.apache.spark.mllib.tree
|
package org.apache.spark.mllib.tree
|
||||||
|
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator}
|
import org.apache.spark.ml.util.TestingUtils._
|
||||||
|
import org.apache.spark.mllib.tree.impurity._
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Test suites for `GiniAggregator` and `EntropyAggregator`.
|
* Test suites for `GiniAggregator` and `EntropyAggregator`.
|
||||||
*/
|
*/
|
||||||
class ImpuritySuite extends SparkFunSuite {
|
class ImpuritySuite extends SparkFunSuite {
|
||||||
|
|
||||||
|
private val seed = 42
|
||||||
|
|
||||||
test("Gini impurity does not support negative labels") {
|
test("Gini impurity does not support negative labels") {
|
||||||
val gini = new GiniAggregator(2)
|
val gini = new GiniAggregator(2)
|
||||||
intercept[IllegalArgumentException] {
|
intercept[IllegalArgumentException] {
|
||||||
gini.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
|
gini.update(Array(0.0, 1.0, 2.0), 0, -1, 3, 0.0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Entropy does not support negative labels") {
|
test("Entropy does not support negative labels") {
|
||||||
val entropy = new EntropyAggregator(2)
|
val entropy = new EntropyAggregator(2)
|
||||||
intercept[IllegalArgumentException] {
|
intercept[IllegalArgumentException] {
|
||||||
entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
|
entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 3, 0.0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("Classification impurities are insensitive to scaling") {
|
||||||
|
val rng = new scala.util.Random(seed)
|
||||||
|
val weightedCounts = Array.fill(5)(rng.nextDouble())
|
||||||
|
val smallWeightedCounts = weightedCounts.map(_ * 0.0001)
|
||||||
|
val largeWeightedCounts = weightedCounts.map(_ * 10000)
|
||||||
|
Seq(Gini, Entropy).foreach { impurity =>
|
||||||
|
val impurity1 = impurity.calculate(weightedCounts, weightedCounts.sum)
|
||||||
|
assert(impurity.calculate(smallWeightedCounts, smallWeightedCounts.sum)
|
||||||
|
~== impurity1 relTol 0.005)
|
||||||
|
assert(impurity.calculate(largeWeightedCounts, largeWeightedCounts.sum)
|
||||||
|
~== impurity1 relTol 0.005)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("Regression impurities are insensitive to scaling") {
|
||||||
|
def computeStats(samples: Seq[Double], weights: Seq[Double]): (Double, Double, Double) = {
|
||||||
|
samples.zip(weights).foldLeft((0.0, 0.0, 0.0)) { case ((wn, wy, wyy), (y, w)) =>
|
||||||
|
(wn + w, wy + w * y, wyy + w * y * y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
val rng = new scala.util.Random(seed)
|
||||||
|
val samples = Array.fill(10)(rng.nextDouble())
|
||||||
|
val _weights = Array.fill(10)(rng.nextDouble())
|
||||||
|
val smallWeights = _weights.map(_ * 0.0001)
|
||||||
|
val largeWeights = _weights.map(_ * 10000)
|
||||||
|
val (count, sum, sumSquared) = computeStats(samples, _weights)
|
||||||
|
Seq(Variance).foreach { impurity =>
|
||||||
|
val impurity1 = impurity.calculate(count, sum, sumSquared)
|
||||||
|
val (smallCount, smallSum, smallSumSquared) = computeStats(samples, smallWeights)
|
||||||
|
val (largeCount, largeSum, largeSumSquared) = computeStats(samples, largeWeights)
|
||||||
|
assert(impurity.calculate(smallCount, smallSum, smallSumSquared) ~== impurity1 relTol 0.005)
|
||||||
|
assert(impurity.calculate(largeCount, largeSum, largeSumSquared) ~== impurity1 relTol 0.005)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue