[SPARK-29116][PYTHON][ML] Refactor py classes related to DecisionTree

### What changes were proposed in this pull request?

- Move tree-related classes to a separate file, ```tree.py```
- Add the method ```predictLeaf``` to ```DecisionTreeModel``` and ```TreeEnsembleModel``` (see the sketch below)
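
A minimal sketch of the new ```predictLeaf``` API, mirroring the doctests added in this PR (targets Spark 3.0; the toy DataFrame is hypothetical):

```python
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import DecisionTreeClassifier

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([
    (1.0, Vectors.dense(1.0)),
    (0.0, Vectors.sparse(1, [], []))], ["label", "features"])

model = DecisionTreeClassifier(maxDepth=2, leafCol="leafId").fit(df)
# predictLeaf maps a feature vector to the preorder index of the leaf
# it falls into; a single tree returns one float.
print(model.predictLeaf(Vectors.dense(-1.0)))  # 0.0 for this toy tree
```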

### Why are the changes needed?

- Keep parity between Scala and Python
- Easier code maintenance

### Does this PR introduce any user-facing change?
Yes.
- Adds the method ```predictLeaf``` to ```DecisionTreeModel``` and ```TreeEnsembleModel```
- Adds ```setMinWeightFractionPerNode``` to ```DecisionTreeClassifier``` and ```DecisionTreeRegressor``` (see the sketch below)
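
A sketch of the new setter (the param itself lives in ```tree.py```; values must be in [0.0, 0.5), default 0.0):

```python
from pyspark.sql import SparkSession
from pyspark.ml.classification import DecisionTreeClassifier

spark = SparkSession.builder.getOrCreate()  # a JVM is needed to build the estimator

dt = DecisionTreeClassifier()
# Each child must hold at least this fraction of the total sample
# weight after a split; violating splits are discarded as invalid.
dt.setMinWeightFractionPerNode(0.05)
print(dt.getMinWeightFractionPerNode())  # 0.05
```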

### How was this patch tested?
Added doctests; a condensed example follows.
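
For ensembles, the new doctests check that ```predictLeaf``` returns one leaf index per tree. A condensed version of the RandomForestRegressor doctest from the diff below:

```python
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import RandomForestRegressor

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([
    (1.0, Vectors.dense(1.0)),
    (0.0, Vectors.sparse(1, [], []))], ["label", "features"])

model = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42).fit(df)
# One leaf index per tree, packed into a DenseVector.
print(model.predictLeaf(Vectors.dense(-1.0)))  # DenseVector([0.0, 0.0])
```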

Closes #25929 from huaxingao/spark_29116.

Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: zhengruifeng <ruifengz@foxmail.com>
3 changed files with 490 additions and 370 deletions

python/pyspark/ml/classification.py

@@ -22,9 +22,10 @@ from multiprocessing.pool import ThreadPool
from pyspark import since, keyword_only
from pyspark.ml import Estimator, Model
from pyspark.ml.param.shared import *
from pyspark.ml.regression import DecisionTreeModel, DecisionTreeParams, \
DecisionTreeRegressionModel, GBTParams, HasVarianceImpurity, RandomForestParams, \
TreeEnsembleModel
from pyspark.ml.tree import _DecisionTreeModel, _DecisionTreeParams, \
_TreeEnsembleModel, _RandomForestParams, _GBTParams, \
_HasVarianceImpurity, _TreeClassifierParams, _TreeEnsembleParams
from pyspark.ml.regression import DecisionTreeRegressionModel
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, \
JavaPredictor, JavaPredictorParams, JavaPredictionModel, JavaWrapper
@@ -939,34 +940,17 @@ class BinaryLogisticRegressionTrainingSummary(BinaryLogisticRegressionSummary,
pass
class TreeClassifierParams(object):
@inherit_doc
class _DecisionTreeClassifierParams(_DecisionTreeParams, _TreeClassifierParams):
"""
Private class to track supported impurity measures.
.. versionadded:: 1.4.0
Params for :py:class:`DecisionTreeClassifier` and :py:class:`DecisionTreeClassificationModel`.
"""
supportedImpurities = ["entropy", "gini"]
impurity = Param(Params._dummy(), "impurity",
"Criterion used for information gain calculation (case-insensitive). " +
"Supported options: " +
", ".join(supportedImpurities), typeConverter=TypeConverters.toString)
def __init__(self):
super(TreeClassifierParams, self).__init__()
@since("1.6.0")
def getImpurity(self):
"""
Gets the value of impurity or its default value.
"""
return self.getOrDefault(self.impurity)
pass
@inherit_doc
class DecisionTreeClassifier(JavaProbabilisticClassifier, HasWeightCol,
DecisionTreeParams, TreeClassifierParams, HasCheckpointInterval,
HasSeed, JavaMLWritable, JavaMLReadable):
class DecisionTreeClassifier(JavaProbabilisticClassifier, _DecisionTreeClassifierParams,
JavaMLWritable, JavaMLReadable):
"""
`Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
learning algorithm for classification.
@@ -1045,20 +1029,20 @@ class DecisionTreeClassifier(JavaProbabilisticClassifier, HasWeightCol,
probabilityCol="probability", rawPredictionCol="rawPrediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
seed=None, weightCol=None, leafCol=""):
seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
seed=None, weightCol=None, leafCol="")
seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0)
"""
super(DecisionTreeClassifier, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid)
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
impurity="gini", leafCol="")
impurity="gini", leafCol="", minWeightFractionPerNode=0.0)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -1068,13 +1052,14 @@ class DecisionTreeClassifier(JavaProbabilisticClassifier, HasWeightCol,
probabilityCol="probability", rawPredictionCol="rawPrediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
impurity="gini", seed=None, weightCol=None, leafCol=""):
impurity="gini", seed=None, weightCol=None, leafCol="",
minWeightFractionPerNode=0.0):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
seed=None, weightCol=None, leafCol="")
seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0)
Sets params for the DecisionTreeClassifier.
"""
kwargs = self._input_kwargs
@@ -1101,6 +1086,13 @@ class DecisionTreeClassifier(JavaProbabilisticClassifier, HasWeightCol,
"""
return self._set(minInstancesPerNode=value)
@since("3.0.0")
def setMinWeightFractionPerNode(self, value):
"""
Sets the value of :py:attr:`minWeightFractionPerNode`.
"""
return self._set(minWeightFractionPerNode=value)
def setMinInfoGain(self, value):
"""
Sets the value of :py:attr:`minInfoGain`.
@@ -1128,8 +1120,9 @@ class DecisionTreeClassifier(JavaProbabilisticClassifier, HasWeightCol,
@inherit_doc
class DecisionTreeClassificationModel(DecisionTreeModel, JavaProbabilisticClassificationModel,
JavaMLWritable, JavaMLReadable):
class DecisionTreeClassificationModel(_DecisionTreeModel, JavaProbabilisticClassificationModel,
_DecisionTreeClassifierParams, JavaMLWritable,
JavaMLReadable):
"""
Model fitted by DecisionTreeClassifier.
@@ -1159,8 +1152,15 @@ class DecisionTreeClassificationModel(DecisionTreeModel, JavaProbabilisticClassi
@inherit_doc
class RandomForestClassifier(JavaProbabilisticClassifier, HasSeed, RandomForestParams,
TreeClassifierParams, HasCheckpointInterval,
class _RandomForestClassifierParams(_RandomForestParams, _TreeClassifierParams):
"""
Params for :py:class:`RandomForestClassifier` and :py:class:`RandomForestClassificationModel`.
"""
pass
@inherit_doc
class RandomForestClassifier(JavaProbabilisticClassifier, _RandomForestClassifierParams,
JavaMLWritable, JavaMLReadable):
"""
`Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
@@ -1230,14 +1230,14 @@ class RandomForestClassifier(JavaProbabilisticClassifier, HasSeed, RandomForestP
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0,
leafCol=""):
leafCol="", minWeightFractionPerNode=0.0):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0, \
leafCol="")
leafCol="", minWeightFractionPerNode=0.0)
"""
super(RandomForestClassifier, self).__init__()
self._java_obj = self._new_java_obj(
@@ -1245,7 +1245,7 @@ class RandomForestClassifier(JavaProbabilisticClassifier, HasSeed, RandomForestP
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
impurity="gini", numTrees=20, featureSubsetStrategy="auto",
subsamplingRate=1.0, leafCol="")
subsamplingRate=1.0, leafCol="", minWeightFractionPerNode=0.0)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -1256,14 +1256,14 @@ class RandomForestClassifier(JavaProbabilisticClassifier, HasSeed, RandomForestP
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0,
leafCol=""):
leafCol="", minWeightFractionPerNode=0.0):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \
impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0, \
leafCol="")
leafCol="", minWeightFractionPerNode=0.0)
Sets params for the RandomForestClassifier.
"""
kwargs = self._input_kwargs
@@ -1337,8 +1337,9 @@ class RandomForestClassifier(JavaProbabilisticClassifier, HasSeed, RandomForestP
return self._set(featureSubsetStrategy=value)
class RandomForestClassificationModel(TreeEnsembleModel, JavaProbabilisticClassificationModel,
JavaMLWritable, JavaMLReadable):
class RandomForestClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassificationModel,
_RandomForestClassifierParams, JavaMLWritable,
JavaMLReadable):
"""
Model fitted by RandomForestClassifier.
@@ -1367,7 +1368,7 @@ class RandomForestClassificationModel(TreeEnsembleModel, JavaProbabilisticClassi
return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]
class GBTClassifierParams(GBTParams, HasVarianceImpurity):
class GBTClassifierParams(_GBTParams, _HasVarianceImpurity):
"""
Private class to track supported GBTClassifier params.
@@ -1390,8 +1391,8 @@ class GBTClassifierParams(GBTParams, HasVarianceImpurity):
@inherit_doc
class GBTClassifier(JavaProbabilisticClassifier, GBTClassifierParams, HasCheckpointInterval,
HasSeed, JavaMLWritable, JavaMLReadable):
class GBTClassifier(JavaProbabilisticClassifier, GBTClassifierParams,
JavaMLWritable, JavaMLReadable):
"""
`Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
learning algorithm for classification.
@@ -1485,14 +1486,14 @@ class GBTClassifier(JavaProbabilisticClassifier, GBTClassifierParams, HasCheckpo
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, impurity="variance",
featureSubsetStrategy="all", validationTol=0.01, validationIndicatorCol=None,
leafCol=""):
leafCol="", minWeightFractionPerNode=0.0):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
validationIndicatorCol=None, leafCol="")
validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0)
"""
super(GBTClassifier, self).__init__()
self._java_obj = self._new_java_obj(
@@ -1501,7 +1502,7 @@ class GBTClassifier(JavaProbabilisticClassifier, GBTClassifierParams, HasCheckpo
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0,
impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
leafCol="")
leafCol="", minWeightFractionPerNode=0.0)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -1512,14 +1513,14 @@ class GBTClassifier(JavaProbabilisticClassifier, GBTClassifierParams, HasCheckpo
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
validationIndicatorCol=None, leafCol=""):
validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
validationIndicatorCol=None, leafCol="")
validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0)
Sets params for Gradient Boosted Tree Classification.
"""
kwargs = self._input_kwargs
@@ -1600,8 +1601,8 @@ class GBTClassifier(JavaProbabilisticClassifier, GBTClassifierParams, HasCheckpo
return self._set(validationIndicatorCol=value)
class GBTClassificationModel(TreeEnsembleModel, JavaProbabilisticClassificationModel,
JavaMLWritable, JavaMLReadable):
class GBTClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassificationModel,
GBTClassifierParams, JavaMLWritable, JavaMLReadable):
"""
Model fitted by GBTClassifier.

python/pyspark/ml/regression.py

@@ -19,6 +19,9 @@ import sys
from pyspark import since, keyword_only
from pyspark.ml.param.shared import *
from pyspark.ml.tree import _DecisionTreeModel, _DecisionTreeParams, \
_TreeEnsembleModel, _TreeEnsembleParams, _RandomForestParams, _GBTParams, \
_HasVarianceImpurity, _TreeRegressorParams
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, \
JavaPredictor, JavaPredictionModel, JavaWrapper
@@ -600,233 +603,19 @@ class IsotonicRegressionModel(JavaModel, _IsotonicRegressionBase,
return self._call_java("predictions")
class DecisionTreeParams(Params):
class _DecisionTreeRegressorParams(_DecisionTreeParams, _TreeRegressorParams, HasVarianceCol):
"""
Mixin for Decision Tree parameters.
"""
leafCol = Param(Params._dummy(), "leafCol", "Leaf indices column name. Predicted leaf " +
"index of each instance in each tree by preorder.",
typeConverter=TypeConverters.toString)
maxDepth = Param(Params._dummy(), "maxDepth", "Maximum depth of the tree. (>= 0) E.g., " +
"depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.",
typeConverter=TypeConverters.toInt)
maxBins = Param(Params._dummy(), "maxBins", "Max number of bins for discretizing continuous " +
"features. Must be >=2 and >= number of categories for any categorical " +
"feature.", typeConverter=TypeConverters.toInt)
minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode", "Minimum number of " +
"instances each child must have after split. If a split causes " +
"the left or right child to have fewer than " +
"minInstancesPerNode, the split will be discarded as invalid. " +
"Should be >= 1.", typeConverter=TypeConverters.toInt)
minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split " +
"to be considered at a tree node.", typeConverter=TypeConverters.toFloat)
maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to " +
"histogram aggregation. If too small, then 1 node will be split per " +
"iteration, and its aggregates may exceed this size.",
typeConverter=TypeConverters.toInt)
cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass " +
"trees to executors to match instances with nodes. If true, the " +
"algorithm will cache node IDs for each instance. Caching can speed " +
"up training of deeper trees. Users can set how often should the cache " +
"be checkpointed or disable it by setting checkpointInterval.",
typeConverter=TypeConverters.toBoolean)
def __init__(self):
super(DecisionTreeParams, self).__init__()
def setLeafCol(self, value):
"""
Sets the value of :py:attr:`leafCol`.
"""
return self._set(leafCol=value)
def getLeafCol(self):
"""
Gets the value of leafCol or its default value.
"""
return self.getOrDefault(self.leafCol)
def getMaxDepth(self):
"""
Gets the value of maxDepth or its default value.
"""
return self.getOrDefault(self.maxDepth)
def getMaxBins(self):
"""
Gets the value of maxBins or its default value.
"""
return self.getOrDefault(self.maxBins)
def getMinInstancesPerNode(self):
"""
Gets the value of minInstancesPerNode or its default value.
"""
return self.getOrDefault(self.minInstancesPerNode)
def getMinInfoGain(self):
"""
Gets the value of minInfoGain or its default value.
"""
return self.getOrDefault(self.minInfoGain)
def getMaxMemoryInMB(self):
"""
Gets the value of maxMemoryInMB or its default value.
"""
return self.getOrDefault(self.maxMemoryInMB)
def getCacheNodeIds(self):
"""
Gets the value of cacheNodeIds or its default value.
"""
return self.getOrDefault(self.cacheNodeIds)
class TreeEnsembleParams(DecisionTreeParams):
"""
Mixin for Decision Tree-based ensemble algorithms parameters.
"""
subsamplingRate = Param(Params._dummy(), "subsamplingRate", "Fraction of the training data " +
"used for learning each decision tree, in range (0, 1].",
typeConverter=TypeConverters.toFloat)
supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"]
featureSubsetStrategy = \
Param(Params._dummy(), "featureSubsetStrategy",
"The number of features to consider for splits at each tree node. Supported " +
"options: 'auto' (choose automatically for task: If numTrees == 1, set to " +
"'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to " +
"'onethird' for regression), 'all' (use all features), 'onethird' (use " +
"1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use " +
"log2(number of features)), 'n' (when n is in the range (0, 1.0], use " +
"n * number of features. When n is in the range (1, number of features), use" +
" n features). default = 'auto'", typeConverter=TypeConverters.toString)
def __init__(self):
super(TreeEnsembleParams, self).__init__()
@since("1.4.0")
def getSubsamplingRate(self):
"""
Gets the value of subsamplingRate or its default value.
"""
return self.getOrDefault(self.subsamplingRate)
@since("1.4.0")
def getFeatureSubsetStrategy(self):
"""
Gets the value of featureSubsetStrategy or its default value.
"""
return self.getOrDefault(self.featureSubsetStrategy)
class HasVarianceImpurity(Params):
"""
Private class to track supported impurity measures.
"""
supportedImpurities = ["variance"]
impurity = Param(Params._dummy(), "impurity",
"Criterion used for information gain calculation (case-insensitive). " +
"Supported options: " +
", ".join(supportedImpurities), typeConverter=TypeConverters.toString)
def __init__(self):
super(HasVarianceImpurity, self).__init__()
@since("1.4.0")
def getImpurity(self):
"""
Gets the value of impurity or its default value.
"""
return self.getOrDefault(self.impurity)
class TreeRegressorParams(HasVarianceImpurity):
pass
class RandomForestParams(TreeEnsembleParams):
"""
Private class to track supported random forest parameters.
"""
numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).",
typeConverter=TypeConverters.toInt)
def __init__(self):
super(RandomForestParams, self).__init__()
@since("1.4.0")
def getNumTrees(self):
"""
Gets the value of numTrees or its default value.
"""
return self.getOrDefault(self.numTrees)
class GBTParams(TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndicatorCol):
"""
Private class to track supported GBT params.
"""
stepSize = Param(Params._dummy(), "stepSize",
"Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " +
"the contribution of each estimator.",
typeConverter=TypeConverters.toFloat)
validationTol = Param(Params._dummy(), "validationTol",
"Threshold for stopping early when fit with validation is used. " +
"If the error rate on the validation input changes by less than the " +
"validationTol, then learning will stop early (before `maxIter`). " +
"This parameter is ignored when fit without validation is used.",
typeConverter=TypeConverters.toFloat)
@since("3.0.0")
def getValidationTol(self):
"""
Gets the value of validationTol or its default value.
"""
return self.getOrDefault(self.validationTol)
class GBTRegressorParams(GBTParams, TreeRegressorParams):
"""
Private class to track supported GBTRegressor params.
Params for :py:class:`DecisionTreeRegressor` and :py:class:`DecisionTreeRegressionModel`.
.. versionadded:: 3.0.0
"""
supportedLossTypes = ["squared", "absolute"]
lossType = Param(Params._dummy(), "lossType",
"Loss function which GBT tries to minimize (case-insensitive). " +
"Supported options: " + ", ".join(supportedLossTypes),
typeConverter=TypeConverters.toString)
@since("1.4.0")
def getLossType(self):
"""
Gets the value of lossType or its default value.
"""
return self.getOrDefault(self.lossType)
pass
@inherit_doc
class DecisionTreeRegressor(JavaPredictor, HasWeightCol, DecisionTreeParams, TreeRegressorParams,
HasCheckpointInterval, HasSeed, JavaMLWritable, JavaMLReadable,
HasVarianceCol):
class DecisionTreeRegressor(JavaPredictor, _DecisionTreeRegressorParams, JavaMLWritable,
JavaMLReadable):
"""
`Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
learning algorithm for regression.
@@ -836,8 +625,12 @@ class DecisionTreeRegressor(JavaPredictor, HasWeightCol, DecisionTreeParams, Tre
>>> df = spark.createDataFrame([
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> dt = DecisionTreeRegressor(maxDepth=2, varianceCol="variance", leafCol="leafId")
>>> dt = DecisionTreeRegressor(maxDepth=2, varianceCol="variance")
>>> model = dt.fit(df)
>>> model.getVarianceCol()
'variance'
>>> model.setLeafCol("leafId")
DecisionTreeRegressionModel...
>>> model.depth
1
>>> model.numNodes
@@ -852,6 +645,8 @@ class DecisionTreeRegressor(JavaPredictor, HasWeightCol, DecisionTreeParams, Tre
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
>>> model.predictLeaf(test0.head().features)
0.0
>>> result.leafId
0.0
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
@@ -888,20 +683,21 @@ class DecisionTreeRegressor(JavaPredictor, HasWeightCol, DecisionTreeParams, Tre
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance",
seed=None, varianceCol=None, weightCol=None, leafCol=""):
seed=None, varianceCol=None, weightCol=None, leafCol="",
minWeightFractionPerNode=0.0):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
impurity="variance", seed=None, varianceCol=None, weightCol=None, \
leafCol="")
leafCol="", minWeightFractionPerNode=0.0)
"""
super(DecisionTreeRegressor, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.regression.DecisionTreeRegressor", self.uid)
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
impurity="variance", leafCol="")
impurity="variance", leafCol="", minWeightFractionPerNode=0.0)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -911,13 +707,13 @@ class DecisionTreeRegressor(JavaPredictor, HasWeightCol, DecisionTreeParams, Tre
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
impurity="variance", seed=None, varianceCol=None, weightCol=None,
leafCol=""):
leafCol="", minWeightFractionPerNode=0.0):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
impurity="variance", seed=None, varianceCol=None, weightCol=None, \
leafCol="")
leafCol="", minWeightFractionPerNode=0.0)
Sets params for the DecisionTreeRegressor.
"""
kwargs = self._input_kwargs
@@ -944,6 +740,13 @@ class DecisionTreeRegressor(JavaPredictor, HasWeightCol, DecisionTreeParams, Tre
"""
return self._set(minInstancesPerNode=value)
@since("3.0.0")
def setMinWeightFractionPerNode(self, value):
"""
Sets the value of :py:attr:`minWeightFractionPerNode`.
"""
return self._set(minWeightFractionPerNode=value)
def setMinInfoGain(self, value):
"""
Sets the value of :py:attr:`minInfoGain`.
@@ -971,79 +774,8 @@ class DecisionTreeRegressor(JavaPredictor, HasWeightCol, DecisionTreeParams, Tre
@inherit_doc
class DecisionTreeModel(JavaPredictionModel):
"""
Abstraction for Decision Tree models.
.. versionadded:: 1.5.0
"""
@property
@since("1.5.0")
def numNodes(self):
"""Return number of nodes of the decision tree."""
return self._call_java("numNodes")
@property
@since("1.5.0")
def depth(self):
"""Return depth of the decision tree."""
return self._call_java("depth")
@property
@since("2.0.0")
def toDebugString(self):
"""Full description of model."""
return self._call_java("toDebugString")
def __repr__(self):
return self._call_java("toString")
@inherit_doc
class TreeEnsembleModel(JavaModel):
"""
(private abstraction)
Represents a tree ensemble model.
"""
@property
@since("2.0.0")
def trees(self):
"""Trees in this ensemble. Warning: These have null parent Estimators."""
return [DecisionTreeModel(m) for m in list(self._call_java("trees"))]
@property
@since("2.0.0")
def getNumTrees(self):
"""Number of trees in ensemble."""
return self._call_java("getNumTrees")
@property
@since("1.5.0")
def treeWeights(self):
"""Return the weights for each tree"""
return list(self._call_java("javaTreeWeights"))
@property
@since("2.0.0")
def totalNumNodes(self):
"""Total number of nodes, summed over all trees in the ensemble."""
return self._call_java("totalNumNodes")
@property
@since("2.0.0")
def toDebugString(self):
"""Full description of model."""
return self._call_java("toDebugString")
def __repr__(self):
return self._call_java("toString")
@inherit_doc
class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
class DecisionTreeRegressionModel(_DecisionTreeModel, _DecisionTreeRegressorParams,
JavaMLWritable, JavaMLReadable):
"""
Model fitted by :class:`DecisionTreeRegressor`.
@@ -1072,9 +804,18 @@ class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReada
return self._call_java("featureImportances")
class _RandomForestRegressorParams(_RandomForestParams, _TreeRegressorParams):
"""
Params for :py:class:`RandomForestRegressor` and :py:class:`RandomForestRegressionModel`.
.. versionadded:: 3.0.0
"""
pass
@inherit_doc
class RandomForestRegressor(JavaPredictor, HasSeed, RandomForestParams, TreeRegressorParams,
HasCheckpointInterval, JavaMLWritable, JavaMLReadable):
class RandomForestRegressor(JavaPredictor, _RandomForestRegressorParams, JavaMLWritable,
JavaMLReadable):
"""
`Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
learning algorithm for regression.
@@ -1085,8 +826,12 @@ class RandomForestRegressor(JavaPredictor, HasSeed, RandomForestParams, TreeRegr
>>> df = spark.createDataFrame([
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42, leafCol="leafId")
>>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42)
>>> model = rf.fit(df)
>>> model.getSeed()
42
>>> model.setLeafCol("leafId")
RandomForestRegressionModel...
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> allclose(model.treeWeights, [1.0, 1.0])
@@ -1094,6 +839,8 @@ class RandomForestRegressor(JavaPredictor, HasSeed, RandomForestParams, TreeRegr
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.predict(test0.head().features)
0.0
>>> model.predictLeaf(test0.head().features)
DenseVector([0.0, 0.0])
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
@@ -1127,13 +874,13 @@ class RandomForestRegressor(JavaPredictor, HasSeed, RandomForestParams, TreeRegr
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20,
featureSubsetStrategy="auto", leafCol=""):
featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \
featureSubsetStrategy="auto", leafCol="")
featureSubsetStrategy="auto", leafCol=", minWeightFractionPerNode=0.0")
"""
super(RandomForestRegressor, self).__init__()
self._java_obj = self._new_java_obj(
@@ -1141,7 +888,7 @@ class RandomForestRegressor(JavaPredictor, HasSeed, RandomForestParams, TreeRegr
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
impurity="variance", subsamplingRate=1.0, numTrees=20,
featureSubsetStrategy="auto", leafCol="")
featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -1151,13 +898,13 @@ class RandomForestRegressor(JavaPredictor, HasSeed, RandomForestParams, TreeRegr
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20,
featureSubsetStrategy="auto", leafCol=""):
featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \
featureSubsetStrategy="auto", leafCol="")
featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0)
Sets params for the RandomForestRegressor.
"""
kwargs = self._input_kwargs
@@ -1231,8 +978,8 @@ class RandomForestRegressor(JavaPredictor, HasSeed, RandomForestParams, TreeRegr
return self._set(featureSubsetStrategy=value)
class RandomForestRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
JavaMLReadable):
class RandomForestRegressionModel(_TreeEnsembleModel, _RandomForestRegressorParams,
JavaMLWritable, JavaMLReadable):
"""
Model fitted by :class:`RandomForestRegressor`.
@@ -1261,9 +1008,30 @@ class RandomForestRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaML
return self._call_java("featureImportances")
class _GBTRegressorParams(_GBTParams, _TreeRegressorParams):
"""
Params for :py:class:`GBTRegressor` and :py:class:`GBTRegressionModel`.
.. versionadded:: 3.0.0
"""
supportedLossTypes = ["squared", "absolute"]
lossType = Param(Params._dummy(), "lossType",
"Loss function which GBT tries to minimize (case-insensitive). " +
"Supported options: " + ", ".join(supportedLossTypes),
typeConverter=TypeConverters.toString)
@since("1.4.0")
def getLossType(self):
"""
Gets the value of lossType or its default value.
"""
return self.getOrDefault(self.lossType)
@inherit_doc
class GBTRegressor(JavaPredictor, GBTRegressorParams, HasCheckpointInterval, HasSeed,
JavaMLWritable, JavaMLReadable):
class GBTRegressor(JavaPredictor, _GBTRegressorParams, JavaMLWritable, JavaMLReadable):
"""
`Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
learning algorithm for regression.
@@ -1280,8 +1048,6 @@ class GBTRegressor(JavaPredictor, GBTRegressorParams, HasCheckpointInterval, Has
>>> print(gbt.getFeatureSubsetStrategy())
all
>>> model = gbt.fit(df)
>>> model.setFeaturesCol("features")
GBTRegressionModel...
>>> model.featureImportances
SparseVector(1, {0: 1.0})
>>> model.numFeatures
@@ -1291,6 +1057,8 @@ class GBTRegressor(JavaPredictor, GBTRegressorParams, HasCheckpointInterval, Has
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.predict(test0.head().features)
0.0
>>> model.predictLeaf(test0.head().features)
DenseVector([0.0, 0.0, 0.0, 0.0, 0.0])
>>> result = model.transform(test0).head()
>>> result.prediction
0.0
@@ -1332,14 +1100,14 @@ class GBTRegressor(JavaPredictor, GBTRegressorParams, HasCheckpointInterval, Has
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None,
impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
validationIndicatorCol=None, leafCol=""):
validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
validationIndicatorCol=None, leafCol="")
validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0)
"""
super(GBTRegressor, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid)
@@ -1347,7 +1115,7 @@ class GBTRegressor(JavaPredictor, GBTRegressorParams, HasCheckpointInterval, Has
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1,
impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
leafCol="")
leafCol="", minWeightFractionPerNode=0.0)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -1358,14 +1126,14 @@ class GBTRegressor(JavaPredictor, GBTRegressorParams, HasCheckpointInterval, Has
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None,
impuriy="variance", featureSubsetStrategy="all", validationTol=0.01,
validationIndicatorCol=None, leafCol=""):
validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
validationIndicatorCol=None, leafCol="")
validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0)
Sets params for Gradient Boosted Tree Regression.
"""
kwargs = self._input_kwargs
@@ -1446,7 +1214,7 @@ class GBTRegressor(JavaPredictor, GBTRegressorParams, HasCheckpointInterval, Has
return self._set(validationIndicatorCol=value)
class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable):
class GBTRegressionModel(_TreeEnsembleModel, _GBTRegressorParams, JavaMLWritable, JavaMLReadable):
"""
Model fitted by :class:`GBTRegressor`.

python/pyspark/ml/tree.py (new file, 351 lines)

@@ -0,0 +1,351 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark import since, keyword_only
from pyspark.ml.param.shared import *
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, \
JavaPredictor, JavaPredictionModel
from pyspark.ml.common import inherit_doc, _java2py, _py2java
@inherit_doc
class _DecisionTreeModel(JavaPredictionModel):
"""
Abstraction for Decision Tree models.
.. versionadded:: 1.5.0
"""
@property
@since("1.5.0")
def numNodes(self):
"""Return number of nodes of the decision tree."""
return self._call_java("numNodes")
@property
@since("1.5.0")
def depth(self):
"""Return depth of the decision tree."""
return self._call_java("depth")
@property
@since("2.0.0")
def toDebugString(self):
"""Full description of model."""
return self._call_java("toDebugString")
@since("3.0.0")
def predictLeaf(self, value):
"""
Predict the indices of the leaves corresponding to the feature vector.
"""
return self._call_java("predictLeaf", value)
def __repr__(self):
return self._call_java("toString")
class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol):
"""
Mixin for Decision Tree parameters.
"""
leafCol = Param(Params._dummy(), "leafCol", "Leaf indices column name. Predicted leaf " +
"index of each instance in each tree by preorder.",
typeConverter=TypeConverters.toString)
maxDepth = Param(Params._dummy(), "maxDepth", "Maximum depth of the tree. (>= 0) E.g., " +
"depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.",
typeConverter=TypeConverters.toInt)
maxBins = Param(Params._dummy(), "maxBins", "Max number of bins for discretizing continuous " +
"features. Must be >=2 and >= number of categories for any categorical " +
"feature.", typeConverter=TypeConverters.toInt)
minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode", "Minimum number of " +
"instances each child must have after split. If a split causes " +
"the left or right child to have fewer than " +
"minInstancesPerNode, the split will be discarded as invalid. " +
"Should be >= 1.", typeConverter=TypeConverters.toInt)
minWeightFractionPerNode = Param(Params._dummy(), "minWeightFractionPerNode", "Minimum "
"fraction of the weighted sample count that each child "
"must have after split. If a split causes the fraction "
"of the total weight in the left or right child to be "
"less than minWeightFractionPerNode, the split will be "
"discarded as invalid. Should be in interval [0.0, 0.5).",
typeConverter=TypeConverters.toFloat)
minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split " +
"to be considered at a tree node.", typeConverter=TypeConverters.toFloat)
maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to " +
"histogram aggregation. If too small, then 1 node will be split per " +
"iteration, and its aggregates may exceed this size.",
typeConverter=TypeConverters.toInt)
cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass " +
"trees to executors to match instances with nodes. If true, the " +
"algorithm will cache node IDs for each instance. Caching can speed " +
"up training of deeper trees. Users can set how often should the cache " +
"be checkpointed or disable it by setting checkpointInterval.",
typeConverter=TypeConverters.toBoolean)
def __init__(self):
super(_DecisionTreeParams, self).__init__()
def setLeafCol(self, value):
"""
Sets the value of :py:attr:`leafCol`.
"""
return self._set(leafCol=value)
def getLeafCol(self):
"""
Gets the value of leafCol or its default value.
"""
return self.getOrDefault(self.leafCol)
def getMaxDepth(self):
"""
Gets the value of maxDepth or its default value.
"""
return self.getOrDefault(self.maxDepth)
def getMaxBins(self):
"""
Gets the value of maxBins or its default value.
"""
return self.getOrDefault(self.maxBins)
def getMinInstancesPerNode(self):
"""
Gets the value of minInstancesPerNode or its default value.
"""
return self.getOrDefault(self.minInstancesPerNode)
def getMinWeightFractionPerNode(self):
"""
Gets the value of minWeightFractionPerNode or its default value.
"""
return self.getOrDefault(self.minWeightFractionPerNode)
def getMinInfoGain(self):
"""
Gets the value of minInfoGain or its default value.
"""
return self.getOrDefault(self.minInfoGain)
def getMaxMemoryInMB(self):
"""
Gets the value of maxMemoryInMB or its default value.
"""
return self.getOrDefault(self.maxMemoryInMB)
def getCacheNodeIds(self):
"""
Gets the value of cacheNodeIds or its default value.
"""
return self.getOrDefault(self.cacheNodeIds)
@inherit_doc
class _TreeEnsembleModel(JavaPredictionModel):
"""
(private abstraction)
Represents a tree ensemble model.
"""
@property
@since("2.0.0")
def trees(self):
"""Trees in this ensemble. Warning: These have null parent Estimators."""
return [_DecisionTreeModel(m) for m in list(self._call_java("trees"))]
@property
@since("2.0.0")
def getNumTrees(self):
"""Number of trees in ensemble."""
return self._call_java("getNumTrees")
@property
@since("1.5.0")
def treeWeights(self):
"""Return the weights for each tree"""
return list(self._call_java("javaTreeWeights"))
@property
@since("2.0.0")
def totalNumNodes(self):
"""Total number of nodes, summed over all trees in the ensemble."""
return self._call_java("totalNumNodes")
@property
@since("2.0.0")
def toDebugString(self):
"""Full description of model."""
return self._call_java("toDebugString")
@since("3.0.0")
def predictLeaf(self, value):
"""
Predict the indices of the leaves corresponding to the feature vector.
"""
return self._call_java("predictLeaf", value)
def __repr__(self):
return self._call_java("toString")
class _TreeEnsembleParams(_DecisionTreeParams):
"""
Mixin for Decision Tree-based ensemble algorithms parameters.
"""
subsamplingRate = Param(Params._dummy(), "subsamplingRate", "Fraction of the training data " +
"used for learning each decision tree, in range (0, 1].",
typeConverter=TypeConverters.toFloat)
supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"]
featureSubsetStrategy = \
Param(Params._dummy(), "featureSubsetStrategy",
"The number of features to consider for splits at each tree node. Supported " +
"options: 'auto' (choose automatically for task: If numTrees == 1, set to " +
"'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to " +
"'onethird' for regression), 'all' (use all features), 'onethird' (use " +
"1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use " +
"log2(number of features)), 'n' (when n is in the range (0, 1.0], use " +
"n * number of features. When n is in the range (1, number of features), use" +
" n features). default = 'auto'", typeConverter=TypeConverters.toString)
def __init__(self):
super(_TreeEnsembleParams, self).__init__()
@since("1.4.0")
def getSubsamplingRate(self):
"""
Gets the value of subsamplingRate or its default value.
"""
return self.getOrDefault(self.subsamplingRate)
@since("1.4.0")
def getFeatureSubsetStrategy(self):
"""
Gets the value of featureSubsetStrategy or its default value.
"""
return self.getOrDefault(self.featureSubsetStrategy)
class _RandomForestParams(_TreeEnsembleParams):
"""
Private class to track supported random forest parameters.
"""
numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).",
typeConverter=TypeConverters.toInt)
def __init__(self):
super(_RandomForestParams, self).__init__()
@since("1.4.0")
def getNumTrees(self):
"""
Gets the value of numTrees or its default value.
"""
return self.getOrDefault(self.numTrees)
class _GBTParams(_TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndicatorCol):
"""
Private class to track supported GBT params.
"""
stepSize = Param(Params._dummy(), "stepSize",
"Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " +
"the contribution of each estimator.",
typeConverter=TypeConverters.toFloat)
validationTol = Param(Params._dummy(), "validationTol",
"Threshold for stopping early when fit with validation is used. " +
"If the error rate on the validation input changes by less than the " +
"validationTol, then learning will stop early (before `maxIter`). " +
"This parameter is ignored when fit without validation is used.",
typeConverter=TypeConverters.toFloat)
@since("3.0.0")
def getValidationTol(self):
"""
Gets the value of validationTol or its default value.
"""
return self.getOrDefault(self.validationTol)
class _HasVarianceImpurity(Params):
"""
Private class to track supported impurity measures.
"""
supportedImpurities = ["variance"]
impurity = Param(Params._dummy(), "impurity",
"Criterion used for information gain calculation (case-insensitive). " +
"Supported options: " +
", ".join(supportedImpurities), typeConverter=TypeConverters.toString)
def __init__(self):
super(_HasVarianceImpurity, self).__init__()
@since("1.4.0")
def getImpurity(self):
"""
Gets the value of impurity or its default value.
"""
return self.getOrDefault(self.impurity)
class _TreeClassifierParams(object):
"""
Private class to track supported impurity measures.
.. versionadded:: 1.4.0
"""
supportedImpurities = ["entropy", "gini"]
impurity = Param(Params._dummy(), "impurity",
"Criterion used for information gain calculation (case-insensitive). " +
"Supported options: " +
", ".join(supportedImpurities), typeConverter=TypeConverters.toString)
def __init__(self):
super(_TreeClassifierParams, self).__init__()
@since("1.6.0")
def getImpurity(self):
"""
Gets the value of impurity or its default value.
"""
return self.getOrDefault(self.impurity)
class _TreeRegressorParams(_HasVarianceImpurity):
"""
Private class to track supported impurity measures.
"""
pass
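
A side effect of the refactor worth noting: the fitted models now mix in the params classes (e.g. ```DecisionTreeClassificationModel``` extends ```_DecisionTreeClassifierParams``` above), so the getters defined in this file resolve on models as well as estimators. A hedged sketch on the same toy data as the doctests:

```python
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import DecisionTreeClassifier

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([
    (1.0, Vectors.dense(1.0)),
    (0.0, Vectors.sparse(1, [], []))], ["label", "features"])

model = DecisionTreeClassifier(maxDepth=2).fit(df)
# Getters inherited from the _DecisionTreeParams/_TreeClassifierParams
# mixins are now available on the fitted model itself.
print(model.getImpurity())  # 'gini'
print(model.getMaxDepth())  # 2
```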