[SPARK-10243] [MLLIB] update since versions in mllib.tree
Same as #8421 but for `mllib.tree`. cc jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #8442 from mengxr/SPARK-10236.
This commit is contained in:
parent
d703372f86
commit
fb7e12fe2e
|
@ -46,7 +46,8 @@ import org.apache.spark.util.random.XORShiftRandom
|
||||||
*/
|
*/
|
||||||
@Since("1.0.0")
|
@Since("1.0.0")
|
||||||
@Experimental
|
@Experimental
|
||||||
class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
|
class DecisionTree @Since("1.0.0") (private val strategy: Strategy)
|
||||||
|
extends Serializable with Logging {
|
||||||
|
|
||||||
strategy.assertValid()
|
strategy.assertValid()
|
||||||
|
|
||||||
|
|
|
@ -51,7 +51,7 @@ import org.apache.spark.storage.StorageLevel
|
||||||
*/
|
*/
|
||||||
@Since("1.2.0")
|
@Since("1.2.0")
|
||||||
@Experimental
|
@Experimental
|
||||||
class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
|
class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: BoostingStrategy)
|
||||||
extends Serializable with Logging {
|
extends Serializable with Logging {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -26,7 +26,9 @@ import org.apache.spark.annotation.{Experimental, Since}
|
||||||
@Since("1.0.0")
|
@Since("1.0.0")
|
||||||
@Experimental
|
@Experimental
|
||||||
object Algo extends Enumeration {
|
object Algo extends Enumeration {
|
||||||
|
@Since("1.0.0")
|
||||||
type Algo = Value
|
type Algo = Value
|
||||||
|
@Since("1.0.0")
|
||||||
val Classification, Regression = Value
|
val Classification, Regression = Value
|
||||||
|
|
||||||
private[mllib] def fromString(name: String): Algo = name match {
|
private[mllib] def fromString(name: String): Algo = name match {
|
||||||
|
|
|
@ -41,14 +41,14 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
|
||||||
*/
|
*/
|
||||||
@Since("1.2.0")
|
@Since("1.2.0")
|
||||||
@Experimental
|
@Experimental
|
||||||
case class BoostingStrategy(
|
case class BoostingStrategy @Since("1.4.0") (
|
||||||
// Required boosting parameters
|
// Required boosting parameters
|
||||||
@BeanProperty var treeStrategy: Strategy,
|
@Since("1.2.0") @BeanProperty var treeStrategy: Strategy,
|
||||||
@BeanProperty var loss: Loss,
|
@Since("1.2.0") @BeanProperty var loss: Loss,
|
||||||
// Optional boosting parameters
|
// Optional boosting parameters
|
||||||
@BeanProperty var numIterations: Int = 100,
|
@Since("1.2.0") @BeanProperty var numIterations: Int = 100,
|
||||||
@BeanProperty var learningRate: Double = 0.1,
|
@Since("1.2.0") @BeanProperty var learningRate: Double = 0.1,
|
||||||
@BeanProperty var validationTol: Double = 1e-5) extends Serializable {
|
@Since("1.4.0") @BeanProperty var validationTol: Double = 1e-5) extends Serializable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check validity of parameters.
|
* Check validity of parameters.
|
||||||
|
|
|
@ -26,6 +26,8 @@ import org.apache.spark.annotation.{Experimental, Since}
|
||||||
@Since("1.0.0")
|
@Since("1.0.0")
|
||||||
@Experimental
|
@Experimental
|
||||||
object FeatureType extends Enumeration {
|
object FeatureType extends Enumeration {
|
||||||
|
@Since("1.0.0")
|
||||||
type FeatureType = Value
|
type FeatureType = Value
|
||||||
|
@Since("1.0.0")
|
||||||
val Continuous, Categorical = Value
|
val Continuous, Categorical = Value
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,8 @@ import org.apache.spark.annotation.{Experimental, Since}
|
||||||
@Since("1.0.0")
|
@Since("1.0.0")
|
||||||
@Experimental
|
@Experimental
|
||||||
object QuantileStrategy extends Enumeration {
|
object QuantileStrategy extends Enumeration {
|
||||||
|
@Since("1.0.0")
|
||||||
type QuantileStrategy = Value
|
type QuantileStrategy = Value
|
||||||
|
@Since("1.0.0")
|
||||||
val Sort, MinMax, ApproxHist = Value
|
val Sort, MinMax, ApproxHist = Value
|
||||||
}
|
}
|
||||||
|
|
|
@ -69,20 +69,20 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
|
||||||
*/
|
*/
|
||||||
@Since("1.0.0")
|
@Since("1.0.0")
|
||||||
@Experimental
|
@Experimental
|
||||||
class Strategy (
|
class Strategy @Since("1.3.0") (
|
||||||
@BeanProperty var algo: Algo,
|
@Since("1.0.0") @BeanProperty var algo: Algo,
|
||||||
@BeanProperty var impurity: Impurity,
|
@Since("1.0.0") @BeanProperty var impurity: Impurity,
|
||||||
@BeanProperty var maxDepth: Int,
|
@Since("1.0.0") @BeanProperty var maxDepth: Int,
|
||||||
@BeanProperty var numClasses: Int = 2,
|
@Since("1.2.0") @BeanProperty var numClasses: Int = 2,
|
||||||
@BeanProperty var maxBins: Int = 32,
|
@Since("1.0.0") @BeanProperty var maxBins: Int = 32,
|
||||||
@BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort,
|
@Since("1.0.0") @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort,
|
||||||
@BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
|
@Since("1.0.0") @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
|
||||||
@BeanProperty var minInstancesPerNode: Int = 1,
|
@Since("1.2.0") @BeanProperty var minInstancesPerNode: Int = 1,
|
||||||
@BeanProperty var minInfoGain: Double = 0.0,
|
@Since("1.2.0") @BeanProperty var minInfoGain: Double = 0.0,
|
||||||
@BeanProperty var maxMemoryInMB: Int = 256,
|
@Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256,
|
||||||
@BeanProperty var subsamplingRate: Double = 1,
|
@Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1,
|
||||||
@BeanProperty var useNodeIdCache: Boolean = false,
|
@Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,
|
||||||
@BeanProperty var checkpointInterval: Int = 10) extends Serializable {
|
@Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends Serializable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*/
|
*/
|
||||||
|
@ -206,6 +206,7 @@ object Strategy {
|
||||||
}
|
}
|
||||||
|
|
||||||
@deprecated("Use Strategy.defaultStrategy instead.", "1.5.0")
|
@deprecated("Use Strategy.defaultStrategy instead.", "1.5.0")
|
||||||
|
@Since("1.2.0")
|
||||||
def defaultStategy(algo: Algo): Strategy = defaultStrategy(algo)
|
def defaultStategy(algo: Algo): Strategy = defaultStrategy(algo)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,7 +43,9 @@ import org.apache.spark.util.Utils
|
||||||
*/
|
*/
|
||||||
@Since("1.0.0")
|
@Since("1.0.0")
|
||||||
@Experimental
|
@Experimental
|
||||||
class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable with Saveable {
|
class DecisionTreeModel @Since("1.0.0") (
|
||||||
|
@Since("1.0.0") val topNode: Node,
|
||||||
|
@Since("1.0.0") val algo: Algo) extends Serializable with Saveable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Predict values for a single data point using the model trained.
|
* Predict values for a single data point using the model trained.
|
||||||
|
@ -110,6 +112,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
|
||||||
/**
|
/**
|
||||||
* Print the full model to a string.
|
* Print the full model to a string.
|
||||||
*/
|
*/
|
||||||
|
@Since("1.2.0")
|
||||||
def toDebugString: String = {
|
def toDebugString: String = {
|
||||||
val header = toString + "\n"
|
val header = toString + "\n"
|
||||||
header + topNode.subtreeToString(2)
|
header + topNode.subtreeToString(2)
|
||||||
|
|
|
@ -41,15 +41,15 @@ import org.apache.spark.mllib.linalg.Vector
|
||||||
*/
|
*/
|
||||||
@Since("1.0.0")
|
@Since("1.0.0")
|
||||||
@DeveloperApi
|
@DeveloperApi
|
||||||
class Node (
|
class Node @Since("1.2.0") (
|
||||||
val id: Int,
|
@Since("1.0.0") val id: Int,
|
||||||
var predict: Predict,
|
@Since("1.0.0") var predict: Predict,
|
||||||
var impurity: Double,
|
@Since("1.2.0") var impurity: Double,
|
||||||
var isLeaf: Boolean,
|
@Since("1.0.0") var isLeaf: Boolean,
|
||||||
var split: Option[Split],
|
@Since("1.0.0") var split: Option[Split],
|
||||||
var leftNode: Option[Node],
|
@Since("1.0.0") var leftNode: Option[Node],
|
||||||
var rightNode: Option[Node],
|
@Since("1.0.0") var rightNode: Option[Node],
|
||||||
var stats: Option[InformationGainStats]) extends Serializable with Logging {
|
@Since("1.0.0") var stats: Option[InformationGainStats]) extends Serializable with Logging {
|
||||||
|
|
||||||
override def toString: String = {
|
override def toString: String = {
|
||||||
s"id = $id, isLeaf = $isLeaf, predict = $predict, impurity = $impurity, " +
|
s"id = $id, isLeaf = $isLeaf, predict = $predict, impurity = $impurity, " +
|
||||||
|
|
|
@ -26,9 +26,9 @@ import org.apache.spark.annotation.{DeveloperApi, Since}
|
||||||
*/
|
*/
|
||||||
@Since("1.2.0")
|
@Since("1.2.0")
|
||||||
@DeveloperApi
|
@DeveloperApi
|
||||||
class Predict(
|
class Predict @Since("1.2.0") (
|
||||||
val predict: Double,
|
@Since("1.2.0") val predict: Double,
|
||||||
val prob: Double = 0.0) extends Serializable {
|
@Since("1.2.0") val prob: Double = 0.0) extends Serializable {
|
||||||
|
|
||||||
override def toString: String = s"$predict (prob = $prob)"
|
override def toString: String = s"$predict (prob = $prob)"
|
||||||
|
|
||||||
|
|
|
@ -34,10 +34,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
|
||||||
@Since("1.0.0")
|
@Since("1.0.0")
|
||||||
@DeveloperApi
|
@DeveloperApi
|
||||||
case class Split(
|
case class Split(
|
||||||
feature: Int,
|
@Since("1.0.0") feature: Int,
|
||||||
threshold: Double,
|
@Since("1.0.0") threshold: Double,
|
||||||
featureType: FeatureType,
|
@Since("1.0.0") featureType: FeatureType,
|
||||||
categories: List[Double]) {
|
@Since("1.0.0") categories: List[Double]) {
|
||||||
|
|
||||||
override def toString: String = {
|
override def toString: String = {
|
||||||
s"Feature = $feature, threshold = $threshold, featureType = $featureType, " +
|
s"Feature = $feature, threshold = $threshold, featureType = $featureType, " +
|
||||||
|
|
|
@ -48,7 +48,9 @@ import org.apache.spark.util.Utils
|
||||||
*/
|
*/
|
||||||
@Since("1.2.0")
|
@Since("1.2.0")
|
||||||
@Experimental
|
@Experimental
|
||||||
class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel])
|
class RandomForestModel @Since("1.2.0") (
|
||||||
|
@Since("1.2.0") override val algo: Algo,
|
||||||
|
@Since("1.2.0") override val trees: Array[DecisionTreeModel])
|
||||||
extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0),
|
extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0),
|
||||||
combiningStrategy = if (algo == Classification) Vote else Average)
|
combiningStrategy = if (algo == Classification) Vote else Average)
|
||||||
with Saveable {
|
with Saveable {
|
||||||
|
@ -115,10 +117,10 @@ object RandomForestModel extends Loader[RandomForestModel] {
|
||||||
*/
|
*/
|
||||||
@Since("1.2.0")
|
@Since("1.2.0")
|
||||||
@Experimental
|
@Experimental
|
||||||
class GradientBoostedTreesModel(
|
class GradientBoostedTreesModel @Since("1.2.0") (
|
||||||
override val algo: Algo,
|
@Since("1.2.0") override val algo: Algo,
|
||||||
override val trees: Array[DecisionTreeModel],
|
@Since("1.2.0") override val trees: Array[DecisionTreeModel],
|
||||||
override val treeWeights: Array[Double])
|
@Since("1.2.0") override val treeWeights: Array[Double])
|
||||||
extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum)
|
extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum)
|
||||||
with Saveable {
|
with Saveable {
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue