[SPARK-10243] [MLLIB] update since versions in mllib.tree
Same as #8421 but for `mllib.tree`. cc jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #8442 from mengxr/SPARK-10236.
This commit is contained in:
parent
d703372f86
commit
fb7e12fe2e
|
@ -46,7 +46,8 @@ import org.apache.spark.util.random.XORShiftRandom
|
|||
*/
|
||||
@Since("1.0.0")
|
||||
@Experimental
|
||||
class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
|
||||
class DecisionTree @Since("1.0.0") (private val strategy: Strategy)
|
||||
extends Serializable with Logging {
|
||||
|
||||
strategy.assertValid()
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ import org.apache.spark.storage.StorageLevel
|
|||
*/
|
||||
@Since("1.2.0")
|
||||
@Experimental
|
||||
class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
|
||||
class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: BoostingStrategy)
|
||||
extends Serializable with Logging {
|
||||
|
||||
/**
|
||||
|
|
|
@ -26,7 +26,9 @@ import org.apache.spark.annotation.{Experimental, Since}
|
|||
@Since("1.0.0")
|
||||
@Experimental
|
||||
object Algo extends Enumeration {
|
||||
@Since("1.0.0")
|
||||
type Algo = Value
|
||||
@Since("1.0.0")
|
||||
val Classification, Regression = Value
|
||||
|
||||
private[mllib] def fromString(name: String): Algo = name match {
|
||||
|
|
|
@ -41,14 +41,14 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
|
|||
*/
|
||||
@Since("1.2.0")
|
||||
@Experimental
|
||||
case class BoostingStrategy(
|
||||
case class BoostingStrategy @Since("1.4.0") (
|
||||
// Required boosting parameters
|
||||
@BeanProperty var treeStrategy: Strategy,
|
||||
@BeanProperty var loss: Loss,
|
||||
@Since("1.2.0") @BeanProperty var treeStrategy: Strategy,
|
||||
@Since("1.2.0") @BeanProperty var loss: Loss,
|
||||
// Optional boosting parameters
|
||||
@BeanProperty var numIterations: Int = 100,
|
||||
@BeanProperty var learningRate: Double = 0.1,
|
||||
@BeanProperty var validationTol: Double = 1e-5) extends Serializable {
|
||||
@Since("1.2.0") @BeanProperty var numIterations: Int = 100,
|
||||
@Since("1.2.0") @BeanProperty var learningRate: Double = 0.1,
|
||||
@Since("1.4.0") @BeanProperty var validationTol: Double = 1e-5) extends Serializable {
|
||||
|
||||
/**
|
||||
* Check validity of parameters.
|
||||
|
|
|
@ -26,6 +26,8 @@ import org.apache.spark.annotation.{Experimental, Since}
|
|||
@Since("1.0.0")
|
||||
@Experimental
|
||||
object FeatureType extends Enumeration {
|
||||
@Since("1.0.0")
|
||||
type FeatureType = Value
|
||||
@Since("1.0.0")
|
||||
val Continuous, Categorical = Value
|
||||
}
|
||||
|
|
|
@ -26,6 +26,8 @@ import org.apache.spark.annotation.{Experimental, Since}
|
|||
@Since("1.0.0")
|
||||
@Experimental
|
||||
object QuantileStrategy extends Enumeration {
|
||||
@Since("1.0.0")
|
||||
type QuantileStrategy = Value
|
||||
@Since("1.0.0")
|
||||
val Sort, MinMax, ApproxHist = Value
|
||||
}
|
||||
|
|
|
@ -69,20 +69,20 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
|
|||
*/
|
||||
@Since("1.0.0")
|
||||
@Experimental
|
||||
class Strategy (
|
||||
@BeanProperty var algo: Algo,
|
||||
@BeanProperty var impurity: Impurity,
|
||||
@BeanProperty var maxDepth: Int,
|
||||
@BeanProperty var numClasses: Int = 2,
|
||||
@BeanProperty var maxBins: Int = 32,
|
||||
@BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort,
|
||||
@BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
|
||||
@BeanProperty var minInstancesPerNode: Int = 1,
|
||||
@BeanProperty var minInfoGain: Double = 0.0,
|
||||
@BeanProperty var maxMemoryInMB: Int = 256,
|
||||
@BeanProperty var subsamplingRate: Double = 1,
|
||||
@BeanProperty var useNodeIdCache: Boolean = false,
|
||||
@BeanProperty var checkpointInterval: Int = 10) extends Serializable {
|
||||
class Strategy @Since("1.3.0") (
|
||||
@Since("1.0.0") @BeanProperty var algo: Algo,
|
||||
@Since("1.0.0") @BeanProperty var impurity: Impurity,
|
||||
@Since("1.0.0") @BeanProperty var maxDepth: Int,
|
||||
@Since("1.2.0") @BeanProperty var numClasses: Int = 2,
|
||||
@Since("1.0.0") @BeanProperty var maxBins: Int = 32,
|
||||
@Since("1.0.0") @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort,
|
||||
@Since("1.0.0") @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
|
||||
@Since("1.2.0") @BeanProperty var minInstancesPerNode: Int = 1,
|
||||
@Since("1.2.0") @BeanProperty var minInfoGain: Double = 0.0,
|
||||
@Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256,
|
||||
@Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1,
|
||||
@Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,
|
||||
@Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends Serializable {
|
||||
|
||||
/**
|
||||
*/
|
||||
|
@ -206,6 +206,7 @@ object Strategy {
|
|||
}
|
||||
|
||||
@deprecated("Use Strategy.defaultStrategy instead.", "1.5.0")
|
||||
@Since("1.2.0")
|
||||
def defaultStategy(algo: Algo): Strategy = defaultStrategy(algo)
|
||||
|
||||
}
|
||||
|
|
|
@ -43,7 +43,9 @@ import org.apache.spark.util.Utils
|
|||
*/
|
||||
@Since("1.0.0")
|
||||
@Experimental
|
||||
class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable with Saveable {
|
||||
class DecisionTreeModel @Since("1.0.0") (
|
||||
@Since("1.0.0") val topNode: Node,
|
||||
@Since("1.0.0") val algo: Algo) extends Serializable with Saveable {
|
||||
|
||||
/**
|
||||
* Predict values for a single data point using the model trained.
|
||||
|
@ -110,6 +112,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
|
|||
/**
|
||||
* Print the full model to a string.
|
||||
*/
|
||||
@Since("1.2.0")
|
||||
def toDebugString: String = {
|
||||
val header = toString + "\n"
|
||||
header + topNode.subtreeToString(2)
|
||||
|
|
|
@ -41,15 +41,15 @@ import org.apache.spark.mllib.linalg.Vector
|
|||
*/
|
||||
@Since("1.0.0")
|
||||
@DeveloperApi
|
||||
class Node (
|
||||
val id: Int,
|
||||
var predict: Predict,
|
||||
var impurity: Double,
|
||||
var isLeaf: Boolean,
|
||||
var split: Option[Split],
|
||||
var leftNode: Option[Node],
|
||||
var rightNode: Option[Node],
|
||||
var stats: Option[InformationGainStats]) extends Serializable with Logging {
|
||||
class Node @Since("1.2.0") (
|
||||
@Since("1.0.0") val id: Int,
|
||||
@Since("1.0.0") var predict: Predict,
|
||||
@Since("1.2.0") var impurity: Double,
|
||||
@Since("1.0.0") var isLeaf: Boolean,
|
||||
@Since("1.0.0") var split: Option[Split],
|
||||
@Since("1.0.0") var leftNode: Option[Node],
|
||||
@Since("1.0.0") var rightNode: Option[Node],
|
||||
@Since("1.0.0") var stats: Option[InformationGainStats]) extends Serializable with Logging {
|
||||
|
||||
override def toString: String = {
|
||||
s"id = $id, isLeaf = $isLeaf, predict = $predict, impurity = $impurity, " +
|
||||
|
|
|
@ -26,9 +26,9 @@ import org.apache.spark.annotation.{DeveloperApi, Since}
|
|||
*/
|
||||
@Since("1.2.0")
|
||||
@DeveloperApi
|
||||
class Predict(
|
||||
val predict: Double,
|
||||
val prob: Double = 0.0) extends Serializable {
|
||||
class Predict @Since("1.2.0") (
|
||||
@Since("1.2.0") val predict: Double,
|
||||
@Since("1.2.0") val prob: Double = 0.0) extends Serializable {
|
||||
|
||||
override def toString: String = s"$predict (prob = $prob)"
|
||||
|
||||
|
|
|
@ -34,10 +34,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
|
|||
@Since("1.0.0")
|
||||
@DeveloperApi
|
||||
case class Split(
|
||||
feature: Int,
|
||||
threshold: Double,
|
||||
featureType: FeatureType,
|
||||
categories: List[Double]) {
|
||||
@Since("1.0.0") feature: Int,
|
||||
@Since("1.0.0") threshold: Double,
|
||||
@Since("1.0.0") featureType: FeatureType,
|
||||
@Since("1.0.0") categories: List[Double]) {
|
||||
|
||||
override def toString: String = {
|
||||
s"Feature = $feature, threshold = $threshold, featureType = $featureType, " +
|
||||
|
|
|
@ -48,7 +48,9 @@ import org.apache.spark.util.Utils
|
|||
*/
|
||||
@Since("1.2.0")
|
||||
@Experimental
|
||||
class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel])
|
||||
class RandomForestModel @Since("1.2.0") (
|
||||
@Since("1.2.0") override val algo: Algo,
|
||||
@Since("1.2.0") override val trees: Array[DecisionTreeModel])
|
||||
extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0),
|
||||
combiningStrategy = if (algo == Classification) Vote else Average)
|
||||
with Saveable {
|
||||
|
@ -115,10 +117,10 @@ object RandomForestModel extends Loader[RandomForestModel] {
|
|||
*/
|
||||
@Since("1.2.0")
|
||||
@Experimental
|
||||
class GradientBoostedTreesModel(
|
||||
override val algo: Algo,
|
||||
override val trees: Array[DecisionTreeModel],
|
||||
override val treeWeights: Array[Double])
|
||||
class GradientBoostedTreesModel @Since("1.2.0") (
|
||||
@Since("1.2.0") override val algo: Algo,
|
||||
@Since("1.2.0") override val trees: Array[DecisionTreeModel],
|
||||
@Since("1.2.0") override val treeWeights: Array[Double])
|
||||
extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum)
|
||||
with Saveable {
|
||||
|
||||
|
|
Loading…
Reference in a new issue