From 098be27ad53c485ee2fc7f5871c47f899020e87b Mon Sep 17 00:00:00 2001 From: sethah Date: Wed, 23 Sep 2015 15:00:52 -0700 Subject: [PATCH] [SPARK-9715] [ML] Store numFeatures in all ML PredictionModel types All prediction models should store `numFeatures` indicating the number of features the model was trained on. Default value of -1 added for backwards compatibility. Author: sethah Closes #8675 from sethah/SPARK-9715. --- .../examples/ml/JavaDeveloperApiExample.java | 5 ++++ .../examples/ml/DeveloperApiExample.scala | 3 +++ .../scala/org/apache/spark/ml/Predictor.scala | 6 ++++- .../DecisionTreeClassifier.scala | 13 ++++++---- .../ml/classification/GBTClassifier.scala | 26 ++++++++++++++----- .../classification/LogisticRegression.scala | 2 ++ .../MultilayerPerceptronClassifier.scala | 2 ++ .../spark/ml/classification/NaiveBayes.scala | 2 ++ .../RandomForestClassifier.scala | 8 +++--- .../ml/regression/DecisionTreeRegressor.scala | 13 ++++++---- .../spark/ml/regression/GBTRegressor.scala | 24 ++++++++++++----- .../ml/regression/LinearRegression.scala | 2 ++ .../ml/regression/RandomForestRegressor.scala | 7 ++--- .../spark/ml/tree/impl/RandomForest.scala | 14 +++++++--- .../DecisionTreeClassifierSuite.scala | 4 ++- .../classification/GBTClassifierSuite.scala | 11 +++++--- .../LogisticRegressionSuite.scala | 2 ++ .../MultilayerPerceptronClassifierSuite.scala | 4 ++- .../ProbabilisticClassifierSuite.scala | 6 +++-- .../RandomForestClassifierSuite.scala | 8 +++--- .../DecisionTreeRegressorSuite.scala | 2 ++ .../ml/regression/GBTRegressorSuite.scala | 7 +++-- .../ml/regression/LinearRegressionSuite.scala | 4 ++- .../RandomForestRegressorSuite.scala | 2 ++ .../ml/tree/impl/RandomForestSuite.scala | 3 ++- 25 files changed, 130 insertions(+), 50 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index a377694507..0b4c0d9ba9 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -219,6 +219,11 @@ class MyJavaLogisticRegressionModel */ public int numClasses() { return 2; } + /** + * Number of features the model was trained on. + */ + public int numFeatures() { return weights_.size(); } + /** * Create a copy of the model. * The copy is shallow, except for the embedded paramMap, which gets a deep copy. diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 340c3559b1..3758edc561 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -172,6 +172,9 @@ private class MyLogisticRegressionModel( /** Number of classes the label can take. 2 indicates binary classification. */ override val numClasses: Int = 2 + /** Number of features the model was trained on. */ + override val numFeatures: Int = weights.size + /** * Create a copy of the model. * The copy is shallow, except for the embedded paramMap, which gets a deep copy. diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index 19fe039b8f..e0dcd427fa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils @@ -145,6 +145,10 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, /** @group setParam */ def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M] + /** Returns the number of features the model was trained on. If unknown, returns -1 */ + @Since("1.6.0") + def numFeatures: Int = -1 + /** * Returns the SQL DataType corresponding to the FeaturesType type parameter. * diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index b8eb49f9bd..a6f6d463bf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -107,6 +107,7 @@ object DecisionTreeClassifier { final class DecisionTreeClassificationModel private[ml] ( override val uid: String, override val rootNode: Node, + override val numFeatures: Int, override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel] with DecisionTreeModel with Serializable { @@ -118,8 +119,8 @@ final class DecisionTreeClassificationModel private[ml] ( * Construct a decision tree classification model. * @param rootNode Root node of tree, with other nodes attached. */ - private[ml] def this(rootNode: Node, numClasses: Int) = - this(Identifiable.randomUID("dtc"), rootNode, numClasses) + private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) = + this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses) override protected def predict(features: Vector): Double = { rootNode.predictImpl(features).prediction @@ -141,7 +142,7 @@ final class DecisionTreeClassificationModel private[ml] ( } override def copy(extra: ParamMap): DecisionTreeClassificationModel = { - copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra) + copyValues(new DecisionTreeClassificationModel(uid, rootNode, numFeatures, numClasses), extra) .setParent(parent) } @@ -161,12 +162,14 @@ private[ml] object DecisionTreeClassificationModel { def fromOld( oldModel: OldDecisionTreeModel, parent: DecisionTreeClassifier, - categoricalFeatures: Map[Int, Int]): DecisionTreeClassificationModel = { + categoricalFeatures: Map[Int, Int], + numFeatures: Int = -1): DecisionTreeClassificationModel = { require(oldModel.algo == OldAlgo.Classification, s"Cannot convert non-classification DecisionTreeModel (old API) to" + s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}") val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc") - new DecisionTreeClassificationModel(uid, rootNode, -1) + // Can't infer number of features from old model, so default to -1 + new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index ad8683648b..74aef94bf7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{Row, DataFrame} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DoubleType @@ -138,10 +138,11 @@ final class GBTClassifier(override val uid: String) require(numClasses == 2, s"GBTClassifier only supports binary classification but was given numClasses = $numClasses") val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) val oldGBT = new OldGBT(boostingStrategy) val oldModel = oldGBT.run(oldDataset) - GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures) + GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures, numFeatures) } override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra) @@ -164,10 +165,11 @@ object GBTClassifier { * @param _treeWeights Weights for the decision trees in the ensemble. */ @Experimental -final class GBTClassificationModel( +final class GBTClassificationModel private[ml]( override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], - private val _treeWeights: Array[Double]) + private val _treeWeights: Array[Double], + override val numFeatures: Int) extends PredictionModel[Vector, GBTClassificationModel] with TreeEnsembleModel with Serializable { @@ -175,6 +177,14 @@ final class GBTClassificationModel( require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" + s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).") + /** + * Construct a GBTClassificationModel + * @param _trees Decision trees in the ensemble. + * @param _treeWeights Weights for the decision trees in the ensemble. + */ + def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) = + this(uid, _trees, _treeWeights, -1) + override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] override def treeWeights: Array[Double] = _treeWeights @@ -196,7 +206,8 @@ final class GBTClassificationModel( } override def copy(extra: ParamMap): GBTClassificationModel = { - copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra).setParent(parent) + copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures), + extra).setParent(parent) } override def toString: String = { @@ -215,7 +226,8 @@ private[ml] object GBTClassificationModel { def fromOld( oldModel: OldGBTModel, parent: GBTClassifier, - categoricalFeatures: Map[Int, Int]): GBTClassificationModel = { + categoricalFeatures: Map[Int, Int], + numFeatures: Int = -1): GBTClassificationModel = { require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" + s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => @@ -223,6 +235,6 @@ private[ml] object GBTClassificationModel { DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc") - new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights) + new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index bd96e8d000..c17a7b0c36 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -426,6 +426,8 @@ class LogisticRegressionModel private[ml] ( 1.0 / (1.0 + math.exp(-m)) } + override val numFeatures: Int = weights.size + override val numClasses: Int = 2 private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 5f60dea91f..cd7462596d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -181,6 +181,8 @@ class MultilayerPerceptronClassificationModel private[ml] ( extends PredictionModel[Vector, MultilayerPerceptronClassificationModel] with Serializable { + override val numFeatures: Int = layers.head + private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 082ea1ffad..a14dcecbaf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -137,6 +137,8 @@ class NaiveBayesModel private[ml] ( throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") } + override val numFeatures: Int = theta.numCols + override val numClasses: Int = pi.size private def multinomialCalculation(features: Vector) = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index a6ebee1bb1..bae329692a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -119,13 +119,12 @@ object RandomForestClassifier { * features. * @param _trees Decision trees in the ensemble. * Warning: These have null parents. - * @param numFeatures Number of features used by this model */ @Experimental final class RandomForestClassificationModel private[ml] ( override val uid: String, private val _trees: Array[DecisionTreeClassificationModel], - val numFeatures: Int, + override val numFeatures: Int, override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel] with TreeEnsembleModel with Serializable { @@ -226,7 +225,8 @@ private[ml] object RandomForestClassificationModel { oldModel: OldRandomForestModel, parent: RandomForestClassifier, categoricalFeatures: Map[Int, Int], - numClasses: Int): RandomForestClassificationModel = { + numClasses: Int, + numFeatures: Int = -1): RandomForestClassificationModel = { require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => @@ -234,6 +234,6 @@ private[ml] object RandomForestClassificationModel { DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc") - new RandomForestClassificationModel(uid, newTrees, -1, numClasses) + new RandomForestClassificationModel(uid, newTrees, numFeatures, numClasses) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index d9a244bea2..88b79a4eb8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -96,7 +96,8 @@ object DecisionTreeRegressor { @Experimental final class DecisionTreeRegressionModel private[ml] ( override val uid: String, - override val rootNode: Node) + override val rootNode: Node, + override val numFeatures: Int) extends PredictionModel[Vector, DecisionTreeRegressionModel] with DecisionTreeModel with Serializable { @@ -107,14 +108,15 @@ final class DecisionTreeRegressionModel private[ml] ( * Construct a decision tree regression model. * @param rootNode Root node of tree, with other nodes attached. */ - private[ml] def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode) + private[ml] def this(rootNode: Node, numFeatures: Int) = + this(Identifiable.randomUID("dtr"), rootNode, numFeatures) override protected def predict(features: Vector): Double = { rootNode.predictImpl(features).prediction } override def copy(extra: ParamMap): DecisionTreeRegressionModel = { - copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra).setParent(parent) + copyValues(new DecisionTreeRegressionModel(uid, rootNode, numFeatures), extra).setParent(parent) } override def toString: String = { @@ -133,12 +135,13 @@ private[ml] object DecisionTreeRegressionModel { def fromOld( oldModel: OldDecisionTreeModel, parent: DecisionTreeRegressor, - categoricalFeatures: Map[Int, Int]): DecisionTreeRegressionModel = { + categoricalFeatures: Map[Int, Int], + numFeatures: Int = -1): DecisionTreeRegressionModel = { require(oldModel.algo == OldAlgo.Regression, s"Cannot convert non-regression DecisionTreeModel (old API) to" + s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}") val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr") - new DecisionTreeRegressionModel(uid, rootNode) + new DecisionTreeRegressionModel(uid, rootNode, numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index d841ecb9e5..65b5b3e072 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -128,10 +128,11 @@ final class GBTRegressor(override val uid: String) val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) val oldGBT = new OldGBT(boostingStrategy) val oldModel = oldGBT.run(oldDataset) - GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures) + GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures, numFeatures) } override def copy(extra: ParamMap): GBTRegressor = defaultCopy(extra) @@ -154,10 +155,11 @@ object GBTRegressor { * @param _treeWeights Weights for the decision trees in the ensemble. */ @Experimental -final class GBTRegressionModel( +final class GBTRegressionModel private[ml]( override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], - private val _treeWeights: Array[Double]) + private val _treeWeights: Array[Double], + override val numFeatures: Int) extends PredictionModel[Vector, GBTRegressionModel] with TreeEnsembleModel with Serializable { @@ -165,6 +167,14 @@ final class GBTRegressionModel( require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" + s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).") + /** + * Construct a GBTRegressionModel + * @param _trees Decision trees in the ensemble. + * @param _treeWeights Weights for the decision trees in the ensemble. + */ + def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) = + this(uid, _trees, _treeWeights, -1) + override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] override def treeWeights: Array[Double] = _treeWeights @@ -185,7 +195,8 @@ final class GBTRegressionModel( } override def copy(extra: ParamMap): GBTRegressionModel = { - copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra).setParent(parent) + copyValues(new GBTRegressionModel(uid, _trees, _treeWeights, numFeatures), + extra).setParent(parent) } override def toString: String = { @@ -204,7 +215,8 @@ private[ml] object GBTRegressionModel { def fromOld( oldModel: OldGBTModel, parent: GBTRegressor, - categoricalFeatures: Map[Int, Int]): GBTRegressionModel = { + categoricalFeatures: Map[Int, Int], + numFeatures: Int = -1): GBTRegressionModel = { require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" + s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).") val newTrees = oldModel.trees.map { tree => @@ -212,6 +224,6 @@ private[ml] object GBTRegressionModel { DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr") - new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights) + new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 78a67c5fda..a77e702141 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -293,6 +293,8 @@ class LinearRegressionModel private[ml] ( private var trainingSummary: Option[LinearRegressionTrainingSummary] = None + override val numFeatures: Int = weights.size + /** * Gets summary (e.g. residuals, mse, r-squared ) of model on training set. An exception is * thrown if `trainingSummary == None`. diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index ddb7214416..64fc17247c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -115,7 +115,7 @@ object RandomForestRegressor { final class RandomForestRegressionModel private[ml] ( override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], - val numFeatures: Int) + override val numFeatures: Int) extends PredictionModel[Vector, RandomForestRegressionModel] with TreeEnsembleModel with Serializable { @@ -187,13 +187,14 @@ private[ml] object RandomForestRegressionModel { def fromOld( oldModel: OldRandomForestModel, parent: RandomForestRegressor, - categoricalFeatures: Map[Int, Int]): RandomForestRegressionModel = { + categoricalFeatures: Map[Int, Int], + numFeatures: Int = -1): RandomForestRegressionModel = { require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).") val newTrees = oldModel.trees.map { tree => // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } - new RandomForestRegressionModel(parent.uid, newTrees, -1) + new RandomForestRegressionModel(parent.uid, newTrees, numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 4ac51a4754..c494556085 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -179,22 +179,28 @@ private[ml] object RandomForest extends Logging { } } + val numFeatures = metadata.numFeatures + parentUID match { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(uid, rootNode.toNode, strategy.getNumClasses) + new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures, + strategy.getNumClasses) } } else { - topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode)) + topNodes.map { rootNode => + new DecisionTreeRegressionModel(uid, rootNode.toNode, numFeatures) + } } case None => if (strategy.algo == OldAlgo.Classification) { topNodes.map { rootNode => - new DecisionTreeClassificationModel(rootNode.toNode, strategy.getNumClasses) + new DecisionTreeClassificationModel(rootNode.toNode, numFeatures, + strategy.getNumClasses) } } else { - topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode)) + topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode, numFeatures)) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index f680d8d3c4..815f6fd997 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -59,7 +59,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte test("params") { ParamsSuite.checkParams(new DecisionTreeClassifier) - val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2) + val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2) ParamsSuite.checkParams(model) } @@ -310,6 +310,7 @@ private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { dt: DecisionTreeClassifier, categoricalFeatures: Map[Int, Int], numClasses: Int): Unit = { + val numFeatures = data.first().features.size val oldStrategy = dt.getOldStrategy(categoricalFeatures, numClasses) val oldTree = OldDecisionTree.train(data, oldStrategy) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) @@ -318,5 +319,6 @@ private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { val oldTreeAsNew = DecisionTreeClassificationModel.fromOld( oldTree, newTree.parent.asInstanceOf[DecisionTreeClassifier], categoricalFeatures) TreeTests.checkEqual(oldTreeAsNew, newTree) + assert(newTree.numFeatures === numFeatures) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index e3909bccaa..039141aeb6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -59,8 +59,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { test("params") { ParamsSuite.checkParams(new GBTClassifier) val model = new GBTClassificationModel("gbtc", - Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null))), - Array(1.0)) + Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null), 1)), + Array(1.0), 1) ParamsSuite.checkParams(model) } @@ -145,7 +145,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { */ } -private object GBTClassifierSuite { +private object GBTClassifierSuite extends SparkFunSuite { /** * Train 2 models on the given dataset, one using the old API and one using the new API. @@ -156,6 +156,7 @@ private object GBTClassifierSuite { validationData: Option[RDD[LabeledPoint]], gbt: GBTClassifier, categoricalFeatures: Map[Int, Int]): Unit = { + val numFeatures = data.first().features.size val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) val oldGBT = new OldGBT(oldBoostingStrategy) @@ -164,7 +165,9 @@ private object GBTClassifierSuite { val newModel = gbt.fit(newData) // Use parent from newTree since this is not checked anyways. val oldModelAsNew = GBTClassificationModel.fromOld( - oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures) + oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures, numFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) + assert(newModel.numFeatures === numFeatures) + assert(oldModelAsNew.numFeatures === numFeatures) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index f5219f9f57..ec01998601 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -194,6 +194,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val model = lr.fit(dataset) assert(model.numClasses === 2) + val numFeatures = dataset.select("features").first().getAs[Vector](0).size + assert(model.numFeatures === numFeatures) val threshold = model.getThreshold val results = model.transform(dataset) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index ddc948f65d..2d1df9b2b8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row @@ -73,6 +73,8 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp .setSeed(11L) .setMaxIter(numIterations) val model = trainer.fit(dataFrame) + val numFeatures = dataFrame.select("features").first().getAs[Vector](0).size + assert(model.numFeatures === numFeatures) val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label") .map { case Row(p: Double, l: Double) => (p, l) } // train multinomial logistic regression diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index 8f50cb924e..fb5f00e064 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} final class TestProbabilisticClassificationModel( override val uid: String, + override val numFeatures: Int, override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, TestProbabilisticClassificationModel] { @@ -45,13 +46,14 @@ class ProbabilisticClassifierSuite extends SparkFunSuite { test("test thresholding") { val thresholds = Array(0.5, 0.2) - val testModel = new TestProbabilisticClassificationModel("myuid", 2).setThresholds(thresholds) + val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2) + .setThresholds(thresholds) assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0) assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0) } test("test thresholding not required") { - val testModel = new TestProbabilisticClassificationModel("myuid", 2) + val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2) assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index b4403ec300..deb8ec771c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -68,7 +68,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte test("params") { ParamsSuite.checkParams(new RandomForestClassifier) val model = new RandomForestClassificationModel("rfc", - Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2, 2) + Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)), 2, 2) ParamsSuite.checkParams(model) } @@ -209,7 +209,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte */ } -private object RandomForestClassifierSuite { +private object RandomForestClassifierSuite extends SparkFunSuite { /** * Train 2 models on the given dataset, one using the old API and one using the new API. @@ -220,6 +220,7 @@ private object RandomForestClassifierSuite { rf: RandomForestClassifier, categoricalFeatures: Map[Int, Int], numClasses: Int): Unit = { + val numFeatures = data.first().features.size val oldStrategy = rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, rf.getOldImpurity) val oldModel = OldRandomForest.trainClassifier( @@ -233,6 +234,7 @@ private object RandomForestClassifierSuite { TreeTests.checkEqual(oldModelAsNew, newModel) assert(newModel.hasParent) assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent) - assert(newModel.numClasses == numClasses) + assert(newModel.numClasses === numClasses) + assert(newModel.numFeatures === numFeatures) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index b092bcd6a7..868fb8eecb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -89,6 +89,7 @@ private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { data: RDD[LabeledPoint], dt: DecisionTreeRegressor, categoricalFeatures: Map[Int, Int]): Unit = { + val numFeatures = data.first().features.size val oldStrategy = dt.getOldStrategy(categoricalFeatures) val oldTree = OldDecisionTree.train(data, oldStrategy) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) @@ -97,5 +98,6 @@ private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { val oldTreeAsNew = DecisionTreeRegressionModel.fromOld( oldTree, newTree.parent.asInstanceOf[DecisionTreeRegressor], categoricalFeatures) TreeTests.checkEqual(oldTreeAsNew, newTree) + assert(newTree.numFeatures === numFeatures) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index a68197b591..09326600e6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -156,7 +156,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { */ } -private object GBTRegressorSuite { +private object GBTRegressorSuite extends SparkFunSuite { /** * Train 2 models on the given dataset, one using the old API and one using the new API. @@ -167,6 +167,7 @@ private object GBTRegressorSuite { validationData: Option[RDD[LabeledPoint]], gbt: GBTRegressor, categoricalFeatures: Map[Int, Int]): Unit = { + val numFeatures = data.first().features.size val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) val oldGBT = new OldGBT(oldBoostingStrategy) val oldModel = oldGBT.run(data) @@ -174,7 +175,9 @@ private object GBTRegressorSuite { val newModel = gbt.fit(newData) // Use parent from newTree since this is not checked anyways. val oldModelAsNew = GBTRegressionModel.fromOld( - oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures) + oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures, numFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) + assert(newModel.numFeatures === numFeatures) + assert(oldModelAsNew.numFeatures === numFeatures) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 8428f4f00b..7cb9471e69 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -22,8 +22,8 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.MLTestingUtils -import org.apache.spark.mllib.linalg.{DenseVector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.linalg.{Vector, DenseVector, Vectors} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} @@ -87,6 +87,8 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.getPredictionCol === "prediction") assert(model.intercept !== 0.0) assert(model.hasParent) + val numFeatures = dataset.select("features").first().getAs[Vector](0).size + assert(model.numFeatures === numFeatures) } test("linear regression with intercept without regularization") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 7b1b3f1148..7e751e4b55 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -137,6 +137,7 @@ private object RandomForestRegressorSuite extends SparkFunSuite { data: RDD[LabeledPoint], rf: RandomForestRegressor, categoricalFeatures: Map[Int, Int]): Unit = { + val numFeatures = data.first().features.size val oldStrategy = rf.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, rf.getOldImpurity) val oldModel = OldRandomForest.trainRegressor( @@ -147,5 +148,6 @@ private object RandomForestRegressorSuite extends SparkFunSuite { val oldModelAsNew = RandomForestRegressionModel.fromOld( oldModel, newModel.parent.asInstanceOf[RandomForestRegressor], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) + assert(newModel.numFeatures === numFeatures) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index dc852795c7..d5c238e9ae 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -77,7 +77,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // Forest consisting of (full tree) + (internal node with 2 leafs) val trees = Array(parent, grandParent).map { root => - new DecisionTreeClassificationModel(root, numClasses = 3).asInstanceOf[DecisionTreeModel] + new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3) + .asInstanceOf[DecisionTreeModel] } val importances: Vector = RandomForest.featureImportances(trees, 2) val tree2norm = feature0importance + feature1importance