[SPARK-9715] [ML] Store numFeatures in all ML PredictionModel types
All prediction models should store `numFeatures` indicating the number of features the model was trained on. Default value of -1 added for backwards compatibility. Author: sethah <seth.hendrickson16@gmail.com> Closes #8675 from sethah/SPARK-9715.
This commit is contained in:
parent
a18208047f
commit
098be27ad5
|
@ -219,6 +219,11 @@ class MyJavaLogisticRegressionModel
|
||||||
*/
|
*/
|
||||||
public int numClasses() { return 2; }
|
public int numClasses() { return 2; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Number of features the model was trained on.
|
||||||
|
*/
|
||||||
|
public int numFeatures() { return weights_.size(); }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a copy of the model.
|
* Create a copy of the model.
|
||||||
* The copy is shallow, except for the embedded paramMap, which gets a deep copy.
|
* The copy is shallow, except for the embedded paramMap, which gets a deep copy.
|
||||||
|
|
|
@ -172,6 +172,9 @@ private class MyLogisticRegressionModel(
|
||||||
/** Number of classes the label can take. 2 indicates binary classification. */
|
/** Number of classes the label can take. 2 indicates binary classification. */
|
||||||
override val numClasses: Int = 2
|
override val numClasses: Int = 2
|
||||||
|
|
||||||
|
/** Number of features the model was trained on. */
|
||||||
|
override val numFeatures: Int = weights.size
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a copy of the model.
|
* Create a copy of the model.
|
||||||
* The copy is shallow, except for the embedded paramMap, which gets a deep copy.
|
* The copy is shallow, except for the embedded paramMap, which gets a deep copy.
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
|
|
||||||
package org.apache.spark.ml
|
package org.apache.spark.ml
|
||||||
|
|
||||||
import org.apache.spark.annotation.DeveloperApi
|
import org.apache.spark.annotation.{DeveloperApi, Since}
|
||||||
import org.apache.spark.ml.param._
|
import org.apache.spark.ml.param._
|
||||||
import org.apache.spark.ml.param.shared._
|
import org.apache.spark.ml.param.shared._
|
||||||
import org.apache.spark.ml.util.SchemaUtils
|
import org.apache.spark.ml.util.SchemaUtils
|
||||||
|
@ -145,6 +145,10 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
|
||||||
/** @group setParam */
|
/** @group setParam */
|
||||||
def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M]
|
def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M]
|
||||||
|
|
||||||
|
/** Returns the number of features the model was trained on. If unknown, returns -1 */
|
||||||
|
@Since("1.6.0")
|
||||||
|
def numFeatures: Int = -1
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the SQL DataType corresponding to the FeaturesType type parameter.
|
* Returns the SQL DataType corresponding to the FeaturesType type parameter.
|
||||||
*
|
*
|
||||||
|
|
|
@ -107,6 +107,7 @@ object DecisionTreeClassifier {
|
||||||
final class DecisionTreeClassificationModel private[ml] (
|
final class DecisionTreeClassificationModel private[ml] (
|
||||||
override val uid: String,
|
override val uid: String,
|
||||||
override val rootNode: Node,
|
override val rootNode: Node,
|
||||||
|
override val numFeatures: Int,
|
||||||
override val numClasses: Int)
|
override val numClasses: Int)
|
||||||
extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
|
extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
|
||||||
with DecisionTreeModel with Serializable {
|
with DecisionTreeModel with Serializable {
|
||||||
|
@ -118,8 +119,8 @@ final class DecisionTreeClassificationModel private[ml] (
|
||||||
* Construct a decision tree classification model.
|
* Construct a decision tree classification model.
|
||||||
* @param rootNode Root node of tree, with other nodes attached.
|
* @param rootNode Root node of tree, with other nodes attached.
|
||||||
*/
|
*/
|
||||||
private[ml] def this(rootNode: Node, numClasses: Int) =
|
private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
|
||||||
this(Identifiable.randomUID("dtc"), rootNode, numClasses)
|
this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses)
|
||||||
|
|
||||||
override protected def predict(features: Vector): Double = {
|
override protected def predict(features: Vector): Double = {
|
||||||
rootNode.predictImpl(features).prediction
|
rootNode.predictImpl(features).prediction
|
||||||
|
@ -141,7 +142,7 @@ final class DecisionTreeClassificationModel private[ml] (
|
||||||
}
|
}
|
||||||
|
|
||||||
override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
|
override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
|
||||||
copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra)
|
copyValues(new DecisionTreeClassificationModel(uid, rootNode, numFeatures, numClasses), extra)
|
||||||
.setParent(parent)
|
.setParent(parent)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -161,12 +162,14 @@ private[ml] object DecisionTreeClassificationModel {
|
||||||
def fromOld(
|
def fromOld(
|
||||||
oldModel: OldDecisionTreeModel,
|
oldModel: OldDecisionTreeModel,
|
||||||
parent: DecisionTreeClassifier,
|
parent: DecisionTreeClassifier,
|
||||||
categoricalFeatures: Map[Int, Int]): DecisionTreeClassificationModel = {
|
categoricalFeatures: Map[Int, Int],
|
||||||
|
numFeatures: Int = -1): DecisionTreeClassificationModel = {
|
||||||
require(oldModel.algo == OldAlgo.Classification,
|
require(oldModel.algo == OldAlgo.Classification,
|
||||||
s"Cannot convert non-classification DecisionTreeModel (old API) to" +
|
s"Cannot convert non-classification DecisionTreeModel (old API) to" +
|
||||||
s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
|
s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
|
||||||
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
|
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
|
||||||
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
|
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
|
||||||
new DecisionTreeClassificationModel(uid, rootNode, -1)
|
// Can't infer number of features from old model, so default to -1
|
||||||
|
new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,7 +33,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
|
||||||
import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
|
import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
|
||||||
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
|
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.DataFrame
|
import org.apache.spark.sql.{Row, DataFrame}
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.apache.spark.sql.types.DoubleType
|
import org.apache.spark.sql.types.DoubleType
|
||||||
|
|
||||||
|
@ -138,10 +138,11 @@ final class GBTClassifier(override val uid: String)
|
||||||
require(numClasses == 2,
|
require(numClasses == 2,
|
||||||
s"GBTClassifier only supports binary classification but was given numClasses = $numClasses")
|
s"GBTClassifier only supports binary classification but was given numClasses = $numClasses")
|
||||||
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
|
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
|
||||||
|
val numFeatures = oldDataset.first().features.size
|
||||||
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
|
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
|
||||||
val oldGBT = new OldGBT(boostingStrategy)
|
val oldGBT = new OldGBT(boostingStrategy)
|
||||||
val oldModel = oldGBT.run(oldDataset)
|
val oldModel = oldGBT.run(oldDataset)
|
||||||
GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures)
|
GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures, numFeatures)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra)
|
override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra)
|
||||||
|
@ -164,10 +165,11 @@ object GBTClassifier {
|
||||||
* @param _treeWeights Weights for the decision trees in the ensemble.
|
* @param _treeWeights Weights for the decision trees in the ensemble.
|
||||||
*/
|
*/
|
||||||
@Experimental
|
@Experimental
|
||||||
final class GBTClassificationModel(
|
final class GBTClassificationModel private[ml](
|
||||||
override val uid: String,
|
override val uid: String,
|
||||||
private val _trees: Array[DecisionTreeRegressionModel],
|
private val _trees: Array[DecisionTreeRegressionModel],
|
||||||
private val _treeWeights: Array[Double])
|
private val _treeWeights: Array[Double],
|
||||||
|
override val numFeatures: Int)
|
||||||
extends PredictionModel[Vector, GBTClassificationModel]
|
extends PredictionModel[Vector, GBTClassificationModel]
|
||||||
with TreeEnsembleModel with Serializable {
|
with TreeEnsembleModel with Serializable {
|
||||||
|
|
||||||
|
@ -175,6 +177,14 @@ final class GBTClassificationModel(
|
||||||
require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
|
require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
|
||||||
s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
|
s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct a GBTClassificationModel
|
||||||
|
* @param _trees Decision trees in the ensemble.
|
||||||
|
* @param _treeWeights Weights for the decision trees in the ensemble.
|
||||||
|
*/
|
||||||
|
def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
|
||||||
|
this(uid, _trees, _treeWeights, -1)
|
||||||
|
|
||||||
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
|
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
|
||||||
|
|
||||||
override def treeWeights: Array[Double] = _treeWeights
|
override def treeWeights: Array[Double] = _treeWeights
|
||||||
|
@ -196,7 +206,8 @@ final class GBTClassificationModel(
|
||||||
}
|
}
|
||||||
|
|
||||||
override def copy(extra: ParamMap): GBTClassificationModel = {
|
override def copy(extra: ParamMap): GBTClassificationModel = {
|
||||||
copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra).setParent(parent)
|
copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
|
||||||
|
extra).setParent(parent)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def toString: String = {
|
override def toString: String = {
|
||||||
|
@ -215,7 +226,8 @@ private[ml] object GBTClassificationModel {
|
||||||
def fromOld(
|
def fromOld(
|
||||||
oldModel: OldGBTModel,
|
oldModel: OldGBTModel,
|
||||||
parent: GBTClassifier,
|
parent: GBTClassifier,
|
||||||
categoricalFeatures: Map[Int, Int]): GBTClassificationModel = {
|
categoricalFeatures: Map[Int, Int],
|
||||||
|
numFeatures: Int = -1): GBTClassificationModel = {
|
||||||
require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
|
require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
|
||||||
s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
|
s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
|
||||||
val newTrees = oldModel.trees.map { tree =>
|
val newTrees = oldModel.trees.map { tree =>
|
||||||
|
@ -223,6 +235,6 @@ private[ml] object GBTClassificationModel {
|
||||||
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
|
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
|
||||||
}
|
}
|
||||||
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
|
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
|
||||||
new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights)
|
new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -426,6 +426,8 @@ class LogisticRegressionModel private[ml] (
|
||||||
1.0 / (1.0 + math.exp(-m))
|
1.0 / (1.0 + math.exp(-m))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override val numFeatures: Int = weights.size
|
||||||
|
|
||||||
override val numClasses: Int = 2
|
override val numClasses: Int = 2
|
||||||
|
|
||||||
private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None
|
private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None
|
||||||
|
|
|
@ -181,6 +181,8 @@ class MultilayerPerceptronClassificationModel private[ml] (
|
||||||
extends PredictionModel[Vector, MultilayerPerceptronClassificationModel]
|
extends PredictionModel[Vector, MultilayerPerceptronClassificationModel]
|
||||||
with Serializable {
|
with Serializable {
|
||||||
|
|
||||||
|
override val numFeatures: Int = layers.head
|
||||||
|
|
||||||
private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights)
|
private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -137,6 +137,8 @@ class NaiveBayesModel private[ml] (
|
||||||
throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
|
throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override val numFeatures: Int = theta.numCols
|
||||||
|
|
||||||
override val numClasses: Int = pi.size
|
override val numClasses: Int = pi.size
|
||||||
|
|
||||||
private def multinomialCalculation(features: Vector) = {
|
private def multinomialCalculation(features: Vector) = {
|
||||||
|
|
|
@ -119,13 +119,12 @@ object RandomForestClassifier {
|
||||||
* features.
|
* features.
|
||||||
* @param _trees Decision trees in the ensemble.
|
* @param _trees Decision trees in the ensemble.
|
||||||
* Warning: These have null parents.
|
* Warning: These have null parents.
|
||||||
* @param numFeatures Number of features used by this model
|
|
||||||
*/
|
*/
|
||||||
@Experimental
|
@Experimental
|
||||||
final class RandomForestClassificationModel private[ml] (
|
final class RandomForestClassificationModel private[ml] (
|
||||||
override val uid: String,
|
override val uid: String,
|
||||||
private val _trees: Array[DecisionTreeClassificationModel],
|
private val _trees: Array[DecisionTreeClassificationModel],
|
||||||
val numFeatures: Int,
|
override val numFeatures: Int,
|
||||||
override val numClasses: Int)
|
override val numClasses: Int)
|
||||||
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
|
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
|
||||||
with TreeEnsembleModel with Serializable {
|
with TreeEnsembleModel with Serializable {
|
||||||
|
@ -226,7 +225,8 @@ private[ml] object RandomForestClassificationModel {
|
||||||
oldModel: OldRandomForestModel,
|
oldModel: OldRandomForestModel,
|
||||||
parent: RandomForestClassifier,
|
parent: RandomForestClassifier,
|
||||||
categoricalFeatures: Map[Int, Int],
|
categoricalFeatures: Map[Int, Int],
|
||||||
numClasses: Int): RandomForestClassificationModel = {
|
numClasses: Int,
|
||||||
|
numFeatures: Int = -1): RandomForestClassificationModel = {
|
||||||
require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
|
require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
|
||||||
s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
|
s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
|
||||||
val newTrees = oldModel.trees.map { tree =>
|
val newTrees = oldModel.trees.map { tree =>
|
||||||
|
@ -234,6 +234,6 @@ private[ml] object RandomForestClassificationModel {
|
||||||
DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
|
DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
|
||||||
}
|
}
|
||||||
val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
|
val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
|
||||||
new RandomForestClassificationModel(uid, newTrees, -1, numClasses)
|
new RandomForestClassificationModel(uid, newTrees, numFeatures, numClasses)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -96,7 +96,8 @@ object DecisionTreeRegressor {
|
||||||
@Experimental
|
@Experimental
|
||||||
final class DecisionTreeRegressionModel private[ml] (
|
final class DecisionTreeRegressionModel private[ml] (
|
||||||
override val uid: String,
|
override val uid: String,
|
||||||
override val rootNode: Node)
|
override val rootNode: Node,
|
||||||
|
override val numFeatures: Int)
|
||||||
extends PredictionModel[Vector, DecisionTreeRegressionModel]
|
extends PredictionModel[Vector, DecisionTreeRegressionModel]
|
||||||
with DecisionTreeModel with Serializable {
|
with DecisionTreeModel with Serializable {
|
||||||
|
|
||||||
|
@ -107,14 +108,15 @@ final class DecisionTreeRegressionModel private[ml] (
|
||||||
* Construct a decision tree regression model.
|
* Construct a decision tree regression model.
|
||||||
* @param rootNode Root node of tree, with other nodes attached.
|
* @param rootNode Root node of tree, with other nodes attached.
|
||||||
*/
|
*/
|
||||||
private[ml] def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode)
|
private[ml] def this(rootNode: Node, numFeatures: Int) =
|
||||||
|
this(Identifiable.randomUID("dtr"), rootNode, numFeatures)
|
||||||
|
|
||||||
override protected def predict(features: Vector): Double = {
|
override protected def predict(features: Vector): Double = {
|
||||||
rootNode.predictImpl(features).prediction
|
rootNode.predictImpl(features).prediction
|
||||||
}
|
}
|
||||||
|
|
||||||
override def copy(extra: ParamMap): DecisionTreeRegressionModel = {
|
override def copy(extra: ParamMap): DecisionTreeRegressionModel = {
|
||||||
copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra).setParent(parent)
|
copyValues(new DecisionTreeRegressionModel(uid, rootNode, numFeatures), extra).setParent(parent)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def toString: String = {
|
override def toString: String = {
|
||||||
|
@ -133,12 +135,13 @@ private[ml] object DecisionTreeRegressionModel {
|
||||||
def fromOld(
|
def fromOld(
|
||||||
oldModel: OldDecisionTreeModel,
|
oldModel: OldDecisionTreeModel,
|
||||||
parent: DecisionTreeRegressor,
|
parent: DecisionTreeRegressor,
|
||||||
categoricalFeatures: Map[Int, Int]): DecisionTreeRegressionModel = {
|
categoricalFeatures: Map[Int, Int],
|
||||||
|
numFeatures: Int = -1): DecisionTreeRegressionModel = {
|
||||||
require(oldModel.algo == OldAlgo.Regression,
|
require(oldModel.algo == OldAlgo.Regression,
|
||||||
s"Cannot convert non-regression DecisionTreeModel (old API) to" +
|
s"Cannot convert non-regression DecisionTreeModel (old API) to" +
|
||||||
s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}")
|
s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}")
|
||||||
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
|
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
|
||||||
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr")
|
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr")
|
||||||
new DecisionTreeRegressionModel(uid, rootNode)
|
new DecisionTreeRegressionModel(uid, rootNode, numFeatures)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -128,10 +128,11 @@ final class GBTRegressor(override val uid: String)
|
||||||
val categoricalFeatures: Map[Int, Int] =
|
val categoricalFeatures: Map[Int, Int] =
|
||||||
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
|
||||||
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
|
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
|
||||||
|
val numFeatures = oldDataset.first().features.size
|
||||||
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
|
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
|
||||||
val oldGBT = new OldGBT(boostingStrategy)
|
val oldGBT = new OldGBT(boostingStrategy)
|
||||||
val oldModel = oldGBT.run(oldDataset)
|
val oldModel = oldGBT.run(oldDataset)
|
||||||
GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures)
|
GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures, numFeatures)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def copy(extra: ParamMap): GBTRegressor = defaultCopy(extra)
|
override def copy(extra: ParamMap): GBTRegressor = defaultCopy(extra)
|
||||||
|
@ -154,10 +155,11 @@ object GBTRegressor {
|
||||||
* @param _treeWeights Weights for the decision trees in the ensemble.
|
* @param _treeWeights Weights for the decision trees in the ensemble.
|
||||||
*/
|
*/
|
||||||
@Experimental
|
@Experimental
|
||||||
final class GBTRegressionModel(
|
final class GBTRegressionModel private[ml](
|
||||||
override val uid: String,
|
override val uid: String,
|
||||||
private val _trees: Array[DecisionTreeRegressionModel],
|
private val _trees: Array[DecisionTreeRegressionModel],
|
||||||
private val _treeWeights: Array[Double])
|
private val _treeWeights: Array[Double],
|
||||||
|
override val numFeatures: Int)
|
||||||
extends PredictionModel[Vector, GBTRegressionModel]
|
extends PredictionModel[Vector, GBTRegressionModel]
|
||||||
with TreeEnsembleModel with Serializable {
|
with TreeEnsembleModel with Serializable {
|
||||||
|
|
||||||
|
@ -165,6 +167,14 @@ final class GBTRegressionModel(
|
||||||
require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
|
require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
|
||||||
s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
|
s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct a GBTRegressionModel
|
||||||
|
* @param _trees Decision trees in the ensemble.
|
||||||
|
* @param _treeWeights Weights for the decision trees in the ensemble.
|
||||||
|
*/
|
||||||
|
def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
|
||||||
|
this(uid, _trees, _treeWeights, -1)
|
||||||
|
|
||||||
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
|
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
|
||||||
|
|
||||||
override def treeWeights: Array[Double] = _treeWeights
|
override def treeWeights: Array[Double] = _treeWeights
|
||||||
|
@ -185,7 +195,8 @@ final class GBTRegressionModel(
|
||||||
}
|
}
|
||||||
|
|
||||||
override def copy(extra: ParamMap): GBTRegressionModel = {
|
override def copy(extra: ParamMap): GBTRegressionModel = {
|
||||||
copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra).setParent(parent)
|
copyValues(new GBTRegressionModel(uid, _trees, _treeWeights, numFeatures),
|
||||||
|
extra).setParent(parent)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def toString: String = {
|
override def toString: String = {
|
||||||
|
@ -204,7 +215,8 @@ private[ml] object GBTRegressionModel {
|
||||||
def fromOld(
|
def fromOld(
|
||||||
oldModel: OldGBTModel,
|
oldModel: OldGBTModel,
|
||||||
parent: GBTRegressor,
|
parent: GBTRegressor,
|
||||||
categoricalFeatures: Map[Int, Int]): GBTRegressionModel = {
|
categoricalFeatures: Map[Int, Int],
|
||||||
|
numFeatures: Int = -1): GBTRegressionModel = {
|
||||||
require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" +
|
require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" +
|
||||||
s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).")
|
s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).")
|
||||||
val newTrees = oldModel.trees.map { tree =>
|
val newTrees = oldModel.trees.map { tree =>
|
||||||
|
@ -212,6 +224,6 @@ private[ml] object GBTRegressionModel {
|
||||||
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
|
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
|
||||||
}
|
}
|
||||||
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr")
|
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr")
|
||||||
new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights)
|
new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -293,6 +293,8 @@ class LinearRegressionModel private[ml] (
|
||||||
|
|
||||||
private var trainingSummary: Option[LinearRegressionTrainingSummary] = None
|
private var trainingSummary: Option[LinearRegressionTrainingSummary] = None
|
||||||
|
|
||||||
|
override val numFeatures: Int = weights.size
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets summary (e.g. residuals, mse, r-squared ) of model on training set. An exception is
|
* Gets summary (e.g. residuals, mse, r-squared ) of model on training set. An exception is
|
||||||
* thrown if `trainingSummary == None`.
|
* thrown if `trainingSummary == None`.
|
||||||
|
|
|
@ -115,7 +115,7 @@ object RandomForestRegressor {
|
||||||
final class RandomForestRegressionModel private[ml] (
|
final class RandomForestRegressionModel private[ml] (
|
||||||
override val uid: String,
|
override val uid: String,
|
||||||
private val _trees: Array[DecisionTreeRegressionModel],
|
private val _trees: Array[DecisionTreeRegressionModel],
|
||||||
val numFeatures: Int)
|
override val numFeatures: Int)
|
||||||
extends PredictionModel[Vector, RandomForestRegressionModel]
|
extends PredictionModel[Vector, RandomForestRegressionModel]
|
||||||
with TreeEnsembleModel with Serializable {
|
with TreeEnsembleModel with Serializable {
|
||||||
|
|
||||||
|
@ -187,13 +187,14 @@ private[ml] object RandomForestRegressionModel {
|
||||||
def fromOld(
|
def fromOld(
|
||||||
oldModel: OldRandomForestModel,
|
oldModel: OldRandomForestModel,
|
||||||
parent: RandomForestRegressor,
|
parent: RandomForestRegressor,
|
||||||
categoricalFeatures: Map[Int, Int]): RandomForestRegressionModel = {
|
categoricalFeatures: Map[Int, Int],
|
||||||
|
numFeatures: Int = -1): RandomForestRegressionModel = {
|
||||||
require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" +
|
require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" +
|
||||||
s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).")
|
s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).")
|
||||||
val newTrees = oldModel.trees.map { tree =>
|
val newTrees = oldModel.trees.map { tree =>
|
||||||
// parent for each tree is null since there is no good way to set this.
|
// parent for each tree is null since there is no good way to set this.
|
||||||
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
|
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
|
||||||
}
|
}
|
||||||
new RandomForestRegressionModel(parent.uid, newTrees, -1)
|
new RandomForestRegressionModel(parent.uid, newTrees, numFeatures)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -179,22 +179,28 @@ private[ml] object RandomForest extends Logging {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val numFeatures = metadata.numFeatures
|
||||||
|
|
||||||
parentUID match {
|
parentUID match {
|
||||||
case Some(uid) =>
|
case Some(uid) =>
|
||||||
if (strategy.algo == OldAlgo.Classification) {
|
if (strategy.algo == OldAlgo.Classification) {
|
||||||
topNodes.map { rootNode =>
|
topNodes.map { rootNode =>
|
||||||
new DecisionTreeClassificationModel(uid, rootNode.toNode, strategy.getNumClasses)
|
new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures,
|
||||||
|
strategy.getNumClasses)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode))
|
topNodes.map { rootNode =>
|
||||||
|
new DecisionTreeRegressionModel(uid, rootNode.toNode, numFeatures)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case None =>
|
case None =>
|
||||||
if (strategy.algo == OldAlgo.Classification) {
|
if (strategy.algo == OldAlgo.Classification) {
|
||||||
topNodes.map { rootNode =>
|
topNodes.map { rootNode =>
|
||||||
new DecisionTreeClassificationModel(rootNode.toNode, strategy.getNumClasses)
|
new DecisionTreeClassificationModel(rootNode.toNode, numFeatures,
|
||||||
|
strategy.getNumClasses)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode))
|
topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode, numFeatures))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -59,7 +59,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
|
||||||
|
|
||||||
test("params") {
|
test("params") {
|
||||||
ParamsSuite.checkParams(new DecisionTreeClassifier)
|
ParamsSuite.checkParams(new DecisionTreeClassifier)
|
||||||
val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)
|
val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)
|
||||||
ParamsSuite.checkParams(model)
|
ParamsSuite.checkParams(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -310,6 +310,7 @@ private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {
|
||||||
dt: DecisionTreeClassifier,
|
dt: DecisionTreeClassifier,
|
||||||
categoricalFeatures: Map[Int, Int],
|
categoricalFeatures: Map[Int, Int],
|
||||||
numClasses: Int): Unit = {
|
numClasses: Int): Unit = {
|
||||||
|
val numFeatures = data.first().features.size
|
||||||
val oldStrategy = dt.getOldStrategy(categoricalFeatures, numClasses)
|
val oldStrategy = dt.getOldStrategy(categoricalFeatures, numClasses)
|
||||||
val oldTree = OldDecisionTree.train(data, oldStrategy)
|
val oldTree = OldDecisionTree.train(data, oldStrategy)
|
||||||
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
|
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
|
||||||
|
@ -318,5 +319,6 @@ private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {
|
||||||
val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(
|
val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(
|
||||||
oldTree, newTree.parent.asInstanceOf[DecisionTreeClassifier], categoricalFeatures)
|
oldTree, newTree.parent.asInstanceOf[DecisionTreeClassifier], categoricalFeatures)
|
||||||
TreeTests.checkEqual(oldTreeAsNew, newTree)
|
TreeTests.checkEqual(oldTreeAsNew, newTree)
|
||||||
|
assert(newTree.numFeatures === numFeatures)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -59,8 +59,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
test("params") {
|
test("params") {
|
||||||
ParamsSuite.checkParams(new GBTClassifier)
|
ParamsSuite.checkParams(new GBTClassifier)
|
||||||
val model = new GBTClassificationModel("gbtc",
|
val model = new GBTClassificationModel("gbtc",
|
||||||
Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null))),
|
Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null), 1)),
|
||||||
Array(1.0))
|
Array(1.0), 1)
|
||||||
ParamsSuite.checkParams(model)
|
ParamsSuite.checkParams(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -145,7 +145,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
*/
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
private object GBTClassifierSuite {
|
private object GBTClassifierSuite extends SparkFunSuite {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Train 2 models on the given dataset, one using the old API and one using the new API.
|
* Train 2 models on the given dataset, one using the old API and one using the new API.
|
||||||
|
@ -156,6 +156,7 @@ private object GBTClassifierSuite {
|
||||||
validationData: Option[RDD[LabeledPoint]],
|
validationData: Option[RDD[LabeledPoint]],
|
||||||
gbt: GBTClassifier,
|
gbt: GBTClassifier,
|
||||||
categoricalFeatures: Map[Int, Int]): Unit = {
|
categoricalFeatures: Map[Int, Int]): Unit = {
|
||||||
|
val numFeatures = data.first().features.size
|
||||||
val oldBoostingStrategy =
|
val oldBoostingStrategy =
|
||||||
gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
|
gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
|
||||||
val oldGBT = new OldGBT(oldBoostingStrategy)
|
val oldGBT = new OldGBT(oldBoostingStrategy)
|
||||||
|
@ -164,7 +165,9 @@ private object GBTClassifierSuite {
|
||||||
val newModel = gbt.fit(newData)
|
val newModel = gbt.fit(newData)
|
||||||
// Use parent from newTree since this is not checked anyways.
|
// Use parent from newTree since this is not checked anyways.
|
||||||
val oldModelAsNew = GBTClassificationModel.fromOld(
|
val oldModelAsNew = GBTClassificationModel.fromOld(
|
||||||
oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures)
|
oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures, numFeatures)
|
||||||
TreeTests.checkEqual(oldModelAsNew, newModel)
|
TreeTests.checkEqual(oldModelAsNew, newModel)
|
||||||
|
assert(newModel.numFeatures === numFeatures)
|
||||||
|
assert(oldModelAsNew.numFeatures === numFeatures)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -194,6 +194,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
|
|
||||||
val model = lr.fit(dataset)
|
val model = lr.fit(dataset)
|
||||||
assert(model.numClasses === 2)
|
assert(model.numClasses === 2)
|
||||||
|
val numFeatures = dataset.select("features").first().getAs[Vector](0).size
|
||||||
|
assert(model.numFeatures === numFeatures)
|
||||||
|
|
||||||
val threshold = model.getThreshold
|
val threshold = model.getThreshold
|
||||||
val results = model.transform(dataset)
|
val results = model.transform(dataset)
|
||||||
|
|
|
@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
|
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
|
||||||
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
|
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
|
||||||
import org.apache.spark.mllib.evaluation.MulticlassMetrics
|
import org.apache.spark.mllib.evaluation.MulticlassMetrics
|
||||||
import org.apache.spark.mllib.linalg.Vectors
|
import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||||
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
import org.apache.spark.mllib.util.MLlibTestSparkContext
|
||||||
import org.apache.spark.mllib.util.TestingUtils._
|
import org.apache.spark.mllib.util.TestingUtils._
|
||||||
import org.apache.spark.sql.Row
|
import org.apache.spark.sql.Row
|
||||||
|
@ -73,6 +73,8 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp
|
||||||
.setSeed(11L)
|
.setSeed(11L)
|
||||||
.setMaxIter(numIterations)
|
.setMaxIter(numIterations)
|
||||||
val model = trainer.fit(dataFrame)
|
val model = trainer.fit(dataFrame)
|
||||||
|
val numFeatures = dataFrame.select("features").first().getAs[Vector](0).size
|
||||||
|
assert(model.numFeatures === numFeatures)
|
||||||
val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label")
|
val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label")
|
||||||
.map { case Row(p: Double, l: Double) => (p, l) }
|
.map { case Row(p: Double, l: Double) => (p, l) }
|
||||||
// train multinomial logistic regression
|
// train multinomial logistic regression
|
||||||
|
|
|
@ -22,6 +22,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
||||||
|
|
||||||
final class TestProbabilisticClassificationModel(
|
final class TestProbabilisticClassificationModel(
|
||||||
override val uid: String,
|
override val uid: String,
|
||||||
|
override val numFeatures: Int,
|
||||||
override val numClasses: Int)
|
override val numClasses: Int)
|
||||||
extends ProbabilisticClassificationModel[Vector, TestProbabilisticClassificationModel] {
|
extends ProbabilisticClassificationModel[Vector, TestProbabilisticClassificationModel] {
|
||||||
|
|
||||||
|
@ -45,13 +46,14 @@ class ProbabilisticClassifierSuite extends SparkFunSuite {
|
||||||
|
|
||||||
test("test thresholding") {
|
test("test thresholding") {
|
||||||
val thresholds = Array(0.5, 0.2)
|
val thresholds = Array(0.5, 0.2)
|
||||||
val testModel = new TestProbabilisticClassificationModel("myuid", 2).setThresholds(thresholds)
|
val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
|
||||||
|
.setThresholds(thresholds)
|
||||||
assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0)
|
assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0)
|
||||||
assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0)
|
assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("test thresholding not required") {
|
test("test thresholding not required") {
|
||||||
val testModel = new TestProbabilisticClassificationModel("myuid", 2)
|
val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
|
||||||
assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0)
|
assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -68,7 +68,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
|
||||||
test("params") {
|
test("params") {
|
||||||
ParamsSuite.checkParams(new RandomForestClassifier)
|
ParamsSuite.checkParams(new RandomForestClassifier)
|
||||||
val model = new RandomForestClassificationModel("rfc",
|
val model = new RandomForestClassificationModel("rfc",
|
||||||
Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2, 2)
|
Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)), 2, 2)
|
||||||
ParamsSuite.checkParams(model)
|
ParamsSuite.checkParams(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -209,7 +209,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
|
||||||
*/
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
private object RandomForestClassifierSuite {
|
private object RandomForestClassifierSuite extends SparkFunSuite {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Train 2 models on the given dataset, one using the old API and one using the new API.
|
* Train 2 models on the given dataset, one using the old API and one using the new API.
|
||||||
|
@ -220,6 +220,7 @@ private object RandomForestClassifierSuite {
|
||||||
rf: RandomForestClassifier,
|
rf: RandomForestClassifier,
|
||||||
categoricalFeatures: Map[Int, Int],
|
categoricalFeatures: Map[Int, Int],
|
||||||
numClasses: Int): Unit = {
|
numClasses: Int): Unit = {
|
||||||
|
val numFeatures = data.first().features.size
|
||||||
val oldStrategy =
|
val oldStrategy =
|
||||||
rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, rf.getOldImpurity)
|
rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, rf.getOldImpurity)
|
||||||
val oldModel = OldRandomForest.trainClassifier(
|
val oldModel = OldRandomForest.trainClassifier(
|
||||||
|
@ -233,6 +234,7 @@ private object RandomForestClassifierSuite {
|
||||||
TreeTests.checkEqual(oldModelAsNew, newModel)
|
TreeTests.checkEqual(oldModelAsNew, newModel)
|
||||||
assert(newModel.hasParent)
|
assert(newModel.hasParent)
|
||||||
assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
|
assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
|
||||||
assert(newModel.numClasses == numClasses)
|
assert(newModel.numClasses === numClasses)
|
||||||
|
assert(newModel.numFeatures === numFeatures)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -89,6 +89,7 @@ private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {
|
||||||
data: RDD[LabeledPoint],
|
data: RDD[LabeledPoint],
|
||||||
dt: DecisionTreeRegressor,
|
dt: DecisionTreeRegressor,
|
||||||
categoricalFeatures: Map[Int, Int]): Unit = {
|
categoricalFeatures: Map[Int, Int]): Unit = {
|
||||||
|
val numFeatures = data.first().features.size
|
||||||
val oldStrategy = dt.getOldStrategy(categoricalFeatures)
|
val oldStrategy = dt.getOldStrategy(categoricalFeatures)
|
||||||
val oldTree = OldDecisionTree.train(data, oldStrategy)
|
val oldTree = OldDecisionTree.train(data, oldStrategy)
|
||||||
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
|
val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
|
||||||
|
@ -97,5 +98,6 @@ private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {
|
||||||
val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(
|
val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(
|
||||||
oldTree, newTree.parent.asInstanceOf[DecisionTreeRegressor], categoricalFeatures)
|
oldTree, newTree.parent.asInstanceOf[DecisionTreeRegressor], categoricalFeatures)
|
||||||
TreeTests.checkEqual(oldTreeAsNew, newTree)
|
TreeTests.checkEqual(oldTreeAsNew, newTree)
|
||||||
|
assert(newTree.numFeatures === numFeatures)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -156,7 +156,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
*/
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
private object GBTRegressorSuite {
|
private object GBTRegressorSuite extends SparkFunSuite {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Train 2 models on the given dataset, one using the old API and one using the new API.
|
* Train 2 models on the given dataset, one using the old API and one using the new API.
|
||||||
|
@ -167,6 +167,7 @@ private object GBTRegressorSuite {
|
||||||
validationData: Option[RDD[LabeledPoint]],
|
validationData: Option[RDD[LabeledPoint]],
|
||||||
gbt: GBTRegressor,
|
gbt: GBTRegressor,
|
||||||
categoricalFeatures: Map[Int, Int]): Unit = {
|
categoricalFeatures: Map[Int, Int]): Unit = {
|
||||||
|
val numFeatures = data.first().features.size
|
||||||
val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
|
val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
|
||||||
val oldGBT = new OldGBT(oldBoostingStrategy)
|
val oldGBT = new OldGBT(oldBoostingStrategy)
|
||||||
val oldModel = oldGBT.run(data)
|
val oldModel = oldGBT.run(data)
|
||||||
|
@ -174,7 +175,9 @@ private object GBTRegressorSuite {
|
||||||
val newModel = gbt.fit(newData)
|
val newModel = gbt.fit(newData)
|
||||||
// Use parent from newTree since this is not checked anyways.
|
// Use parent from newTree since this is not checked anyways.
|
||||||
val oldModelAsNew = GBTRegressionModel.fromOld(
|
val oldModelAsNew = GBTRegressionModel.fromOld(
|
||||||
oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures)
|
oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures, numFeatures)
|
||||||
TreeTests.checkEqual(oldModelAsNew, newModel)
|
TreeTests.checkEqual(oldModelAsNew, newModel)
|
||||||
|
assert(newModel.numFeatures === numFeatures)
|
||||||
|
assert(oldModelAsNew.numFeatures === numFeatures)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,8 +22,8 @@ import scala.util.Random
|
||||||
import org.apache.spark.SparkFunSuite
|
import org.apache.spark.SparkFunSuite
|
||||||
import org.apache.spark.ml.param.ParamsSuite
|
import org.apache.spark.ml.param.ParamsSuite
|
||||||
import org.apache.spark.ml.util.MLTestingUtils
|
import org.apache.spark.ml.util.MLTestingUtils
|
||||||
import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
|
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
|
import org.apache.spark.mllib.linalg.{Vector, DenseVector, Vectors}
|
||||||
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
|
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
|
||||||
import org.apache.spark.mllib.util.TestingUtils._
|
import org.apache.spark.mllib.util.TestingUtils._
|
||||||
import org.apache.spark.sql.{DataFrame, Row}
|
import org.apache.spark.sql.{DataFrame, Row}
|
||||||
|
@ -87,6 +87,8 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
assert(model.getPredictionCol === "prediction")
|
assert(model.getPredictionCol === "prediction")
|
||||||
assert(model.intercept !== 0.0)
|
assert(model.intercept !== 0.0)
|
||||||
assert(model.hasParent)
|
assert(model.hasParent)
|
||||||
|
val numFeatures = dataset.select("features").first().getAs[Vector](0).size
|
||||||
|
assert(model.numFeatures === numFeatures)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("linear regression with intercept without regularization") {
|
test("linear regression with intercept without regularization") {
|
||||||
|
|
|
@ -137,6 +137,7 @@ private object RandomForestRegressorSuite extends SparkFunSuite {
|
||||||
data: RDD[LabeledPoint],
|
data: RDD[LabeledPoint],
|
||||||
rf: RandomForestRegressor,
|
rf: RandomForestRegressor,
|
||||||
categoricalFeatures: Map[Int, Int]): Unit = {
|
categoricalFeatures: Map[Int, Int]): Unit = {
|
||||||
|
val numFeatures = data.first().features.size
|
||||||
val oldStrategy =
|
val oldStrategy =
|
||||||
rf.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, rf.getOldImpurity)
|
rf.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, rf.getOldImpurity)
|
||||||
val oldModel = OldRandomForest.trainRegressor(
|
val oldModel = OldRandomForest.trainRegressor(
|
||||||
|
@ -147,5 +148,6 @@ private object RandomForestRegressorSuite extends SparkFunSuite {
|
||||||
val oldModelAsNew = RandomForestRegressionModel.fromOld(
|
val oldModelAsNew = RandomForestRegressionModel.fromOld(
|
||||||
oldModel, newModel.parent.asInstanceOf[RandomForestRegressor], categoricalFeatures)
|
oldModel, newModel.parent.asInstanceOf[RandomForestRegressor], categoricalFeatures)
|
||||||
TreeTests.checkEqual(oldModelAsNew, newModel)
|
TreeTests.checkEqual(oldModelAsNew, newModel)
|
||||||
|
assert(newModel.numFeatures === numFeatures)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -77,7 +77,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
|
||||||
|
|
||||||
// Forest consisting of (full tree) + (internal node with 2 leafs)
|
// Forest consisting of (full tree) + (internal node with 2 leafs)
|
||||||
val trees = Array(parent, grandParent).map { root =>
|
val trees = Array(parent, grandParent).map { root =>
|
||||||
new DecisionTreeClassificationModel(root, numClasses = 3).asInstanceOf[DecisionTreeModel]
|
new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3)
|
||||||
|
.asInstanceOf[DecisionTreeModel]
|
||||||
}
|
}
|
||||||
val importances: Vector = RandomForest.featureImportances(trees, 2)
|
val importances: Vector = RandomForest.featureImportances(trees, 2)
|
||||||
val tree2norm = feature0importance + feature1importance
|
val tree2norm = feature0importance + feature1importance
|
||||||
|
|
Loading…
Reference in a new issue