[SPARK-29116][PYTHON][ML] Refactor py classes related to DecisionTree
### What changes were proposed in this pull request?
- Move tree-related classes to a separate file, `tree.py`
- Add method `predictLeaf` in `DecisionTreeModel` & `TreeEnsembleModel`

### Why are the changes needed?
- Keep parity between Scala and Python
- Easier code maintenance

### Does this PR introduce any user-facing change?
Yes:
- Add method `predictLeaf` in `DecisionTreeModel` & `TreeEnsembleModel`
- Add `setMinWeightFractionPerNode` in `DecisionTreeClassifier` and `DecisionTreeRegressor`

### How was this patch tested?
Added some doc tests.

Closes #25929 from huaxingao/spark_29116.

Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: zhengruifeng <ruifengz@foxmail.com>
This commit is contained in:
parent
beb8d2f8ad
commit
81362956a7
|
@ -22,9 +22,10 @@ from multiprocessing.pool import ThreadPool
|
||||||
from pyspark import since, keyword_only
|
from pyspark import since, keyword_only
|
||||||
from pyspark.ml import Estimator, Model
|
from pyspark.ml import Estimator, Model
|
||||||
from pyspark.ml.param.shared import *
|
from pyspark.ml.param.shared import *
|
||||||
from pyspark.ml.regression import DecisionTreeModel, DecisionTreeParams, \
|
from pyspark.ml.tree import _DecisionTreeModel, _DecisionTreeParams, \
|
||||||
DecisionTreeRegressionModel, GBTParams, HasVarianceImpurity, RandomForestParams, \
|
_TreeEnsembleModel, _RandomForestParams, _GBTParams, \
|
||||||
TreeEnsembleModel
|
_HasVarianceImpurity, _TreeClassifierParams, _TreeEnsembleParams
|
||||||
|
from pyspark.ml.regression import DecisionTreeRegressionModel
|
||||||
from pyspark.ml.util import *
|
from pyspark.ml.util import *
|
||||||
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, \
|
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, \
|
||||||
JavaPredictor, JavaPredictorParams, JavaPredictionModel, JavaWrapper
|
JavaPredictor, JavaPredictorParams, JavaPredictionModel, JavaWrapper
|
||||||
|
@ -939,34 +940,17 @@ class BinaryLogisticRegressionTrainingSummary(BinaryLogisticRegressionSummary,
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TreeClassifierParams(object):
|
@inherit_doc
|
||||||
|
class _DecisionTreeClassifierParams(_DecisionTreeParams, _TreeClassifierParams):
|
||||||
"""
|
"""
|
||||||
Private class to track supported impurity measures.
|
Params for :py:class:`DecisionTreeClassifier` and :py:class:`DecisionTreeClassificationModel`.
|
||||||
|
|
||||||
.. versionadded:: 1.4.0
|
|
||||||
"""
|
"""
|
||||||
supportedImpurities = ["entropy", "gini"]
|
pass
|
||||||
|
|
||||||
impurity = Param(Params._dummy(), "impurity",
|
|
||||||
"Criterion used for information gain calculation (case-insensitive). " +
|
|
||||||
"Supported options: " +
|
|
||||||
", ".join(supportedImpurities), typeConverter=TypeConverters.toString)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super(TreeClassifierParams, self).__init__()
|
|
||||||
|
|
||||||
@since("1.6.0")
|
|
||||||
def getImpurity(self):
|
|
||||||
"""
|
|
||||||
Gets the value of impurity or its default value.
|
|
||||||
"""
|
|
||||||
return self.getOrDefault(self.impurity)
|
|
||||||
|
|
||||||
|
|
||||||
@inherit_doc
|
@inherit_doc
|
||||||
class DecisionTreeClassifier(JavaProbabilisticClassifier, HasWeightCol,
|
class DecisionTreeClassifier(JavaProbabilisticClassifier, _DecisionTreeClassifierParams,
|
||||||
DecisionTreeParams, TreeClassifierParams, HasCheckpointInterval,
|
JavaMLWritable, JavaMLReadable):
|
||||||
HasSeed, JavaMLWritable, JavaMLReadable):
|
|
||||||
"""
|
"""
|
||||||
`Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
|
`Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
|
||||||
learning algorithm for classification.
|
learning algorithm for classification.
|
||||||
|
@ -1045,20 +1029,20 @@ class DecisionTreeClassifier(JavaProbabilisticClassifier, HasWeightCol,
|
||||||
probabilityCol="probability", rawPredictionCol="rawPrediction",
|
probabilityCol="probability", rawPredictionCol="rawPrediction",
|
||||||
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
|
||||||
seed=None, weightCol=None, leafCol=""):
|
seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0):
|
||||||
"""
|
"""
|
||||||
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
||||||
probabilityCol="probability", rawPredictionCol="rawPrediction", \
|
probabilityCol="probability", rawPredictionCol="rawPrediction", \
|
||||||
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
|
||||||
seed=None, weightCol=None, leafCol="")
|
seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0)
|
||||||
"""
|
"""
|
||||||
super(DecisionTreeClassifier, self).__init__()
|
super(DecisionTreeClassifier, self).__init__()
|
||||||
self._java_obj = self._new_java_obj(
|
self._java_obj = self._new_java_obj(
|
||||||
"org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid)
|
"org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid)
|
||||||
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
|
||||||
impurity="gini", leafCol="")
|
impurity="gini", leafCol="", minWeightFractionPerNode=0.0)
|
||||||
kwargs = self._input_kwargs
|
kwargs = self._input_kwargs
|
||||||
self.setParams(**kwargs)
|
self.setParams(**kwargs)
|
||||||
|
|
||||||
|
@ -1068,13 +1052,14 @@ class DecisionTreeClassifier(JavaProbabilisticClassifier, HasWeightCol,
|
||||||
probabilityCol="probability", rawPredictionCol="rawPrediction",
|
probabilityCol="probability", rawPredictionCol="rawPrediction",
|
||||||
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
|
||||||
impurity="gini", seed=None, weightCol=None, leafCol=""):
|
impurity="gini", seed=None, weightCol=None, leafCol="",
|
||||||
|
minWeightFractionPerNode=0.0):
|
||||||
"""
|
"""
|
||||||
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
||||||
probabilityCol="probability", rawPredictionCol="rawPrediction", \
|
probabilityCol="probability", rawPredictionCol="rawPrediction", \
|
||||||
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
|
||||||
seed=None, weightCol=None, leafCol="")
|
seed=None, weightCol=None, leafCol="", minWeightFractionPerNode=0.0)
|
||||||
Sets params for the DecisionTreeClassifier.
|
Sets params for the DecisionTreeClassifier.
|
||||||
"""
|
"""
|
||||||
kwargs = self._input_kwargs
|
kwargs = self._input_kwargs
|
||||||
|
@ -1101,6 +1086,13 @@ class DecisionTreeClassifier(JavaProbabilisticClassifier, HasWeightCol,
|
||||||
"""
|
"""
|
||||||
return self._set(minInstancesPerNode=value)
|
return self._set(minInstancesPerNode=value)
|
||||||
|
|
||||||
|
@since("3.0.0")
|
||||||
|
def setMinWeightFractionPerNode(self, value):
|
||||||
|
"""
|
||||||
|
Sets the value of :py:attr:`minWeightFractionPerNode`.
|
||||||
|
"""
|
||||||
|
return self._set(minWeightFractionPerNode=value)
|
||||||
|
|
||||||
def setMinInfoGain(self, value):
|
def setMinInfoGain(self, value):
|
||||||
"""
|
"""
|
||||||
Sets the value of :py:attr:`minInfoGain`.
|
Sets the value of :py:attr:`minInfoGain`.
|
||||||
|
@ -1128,8 +1120,9 @@ class DecisionTreeClassifier(JavaProbabilisticClassifier, HasWeightCol,
|
||||||
|
|
||||||
|
|
||||||
@inherit_doc
|
@inherit_doc
|
||||||
class DecisionTreeClassificationModel(DecisionTreeModel, JavaProbabilisticClassificationModel,
|
class DecisionTreeClassificationModel(_DecisionTreeModel, JavaProbabilisticClassificationModel,
|
||||||
JavaMLWritable, JavaMLReadable):
|
_DecisionTreeClassifierParams, JavaMLWritable,
|
||||||
|
JavaMLReadable):
|
||||||
"""
|
"""
|
||||||
Model fitted by DecisionTreeClassifier.
|
Model fitted by DecisionTreeClassifier.
|
||||||
|
|
||||||
|
@ -1159,8 +1152,15 @@ class DecisionTreeClassificationModel(DecisionTreeModel, JavaProbabilisticClassi
|
||||||
|
|
||||||
|
|
||||||
@inherit_doc
|
@inherit_doc
|
||||||
class RandomForestClassifier(JavaProbabilisticClassifier, HasSeed, RandomForestParams,
|
class _RandomForestClassifierParams(_RandomForestParams, _TreeClassifierParams):
|
||||||
TreeClassifierParams, HasCheckpointInterval,
|
"""
|
||||||
|
Params for :py:class:`RandomForestClassifier` and :py:class:`RandomForestClassificationModel`.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@inherit_doc
|
||||||
|
class RandomForestClassifier(JavaProbabilisticClassifier, _RandomForestClassifierParams,
|
||||||
JavaMLWritable, JavaMLReadable):
|
JavaMLWritable, JavaMLReadable):
|
||||||
"""
|
"""
|
||||||
`Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
|
`Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
|
||||||
|
@ -1230,14 +1230,14 @@ class RandomForestClassifier(JavaProbabilisticClassifier, HasSeed, RandomForestP
|
||||||
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
|
||||||
numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0,
|
numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0,
|
||||||
leafCol=""):
|
leafCol="", minWeightFractionPerNode=0.0):
|
||||||
"""
|
"""
|
||||||
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
||||||
probabilityCol="probability", rawPredictionCol="rawPrediction", \
|
probabilityCol="probability", rawPredictionCol="rawPrediction", \
|
||||||
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
|
||||||
numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0, \
|
numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0, \
|
||||||
leafCol="")
|
leafCol="", minWeightFractionPerNode=0.0)
|
||||||
"""
|
"""
|
||||||
super(RandomForestClassifier, self).__init__()
|
super(RandomForestClassifier, self).__init__()
|
||||||
self._java_obj = self._new_java_obj(
|
self._java_obj = self._new_java_obj(
|
||||||
|
@ -1245,7 +1245,7 @@ class RandomForestClassifier(JavaProbabilisticClassifier, HasSeed, RandomForestP
|
||||||
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
|
||||||
impurity="gini", numTrees=20, featureSubsetStrategy="auto",
|
impurity="gini", numTrees=20, featureSubsetStrategy="auto",
|
||||||
subsamplingRate=1.0, leafCol="")
|
subsamplingRate=1.0, leafCol="", minWeightFractionPerNode=0.0)
|
||||||
kwargs = self._input_kwargs
|
kwargs = self._input_kwargs
|
||||||
self.setParams(**kwargs)
|
self.setParams(**kwargs)
|
||||||
|
|
||||||
|
@ -1256,14 +1256,14 @@ class RandomForestClassifier(JavaProbabilisticClassifier, HasSeed, RandomForestP
|
||||||
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
|
||||||
impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0,
|
impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0,
|
||||||
leafCol=""):
|
leafCol="", minWeightFractionPerNode=0.0):
|
||||||
"""
|
"""
|
||||||
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
||||||
probabilityCol="probability", rawPredictionCol="rawPrediction", \
|
probabilityCol="probability", rawPredictionCol="rawPrediction", \
|
||||||
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \
|
||||||
impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0, \
|
impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0, \
|
||||||
leafCol="")
|
leafCol="", minWeightFractionPerNode=0.0)
|
||||||
Sets params for linear classification.
|
Sets params for linear classification.
|
||||||
"""
|
"""
|
||||||
kwargs = self._input_kwargs
|
kwargs = self._input_kwargs
|
||||||
|
@ -1337,8 +1337,9 @@ class RandomForestClassifier(JavaProbabilisticClassifier, HasSeed, RandomForestP
|
||||||
return self._set(featureSubsetStrategy=value)
|
return self._set(featureSubsetStrategy=value)
|
||||||
|
|
||||||
|
|
||||||
class RandomForestClassificationModel(TreeEnsembleModel, JavaProbabilisticClassificationModel,
|
class RandomForestClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassificationModel,
|
||||||
JavaMLWritable, JavaMLReadable):
|
_RandomForestClassifierParams, JavaMLWritable,
|
||||||
|
JavaMLReadable):
|
||||||
"""
|
"""
|
||||||
Model fitted by RandomForestClassifier.
|
Model fitted by RandomForestClassifier.
|
||||||
|
|
||||||
|
@ -1367,7 +1368,7 @@ class RandomForestClassificationModel(TreeEnsembleModel, JavaProbabilisticClassi
|
||||||
return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]
|
return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]
|
||||||
|
|
||||||
|
|
||||||
class GBTClassifierParams(GBTParams, HasVarianceImpurity):
|
class GBTClassifierParams(_GBTParams, _HasVarianceImpurity):
|
||||||
"""
|
"""
|
||||||
Private class to track supported GBTClassifier params.
|
Private class to track supported GBTClassifier params.
|
||||||
|
|
||||||
|
@ -1390,8 +1391,8 @@ class GBTClassifierParams(GBTParams, HasVarianceImpurity):
|
||||||
|
|
||||||
|
|
||||||
@inherit_doc
|
@inherit_doc
|
||||||
class GBTClassifier(JavaProbabilisticClassifier, GBTClassifierParams, HasCheckpointInterval,
|
class GBTClassifier(JavaProbabilisticClassifier, GBTClassifierParams,
|
||||||
HasSeed, JavaMLWritable, JavaMLReadable):
|
JavaMLWritable, JavaMLReadable):
|
||||||
"""
|
"""
|
||||||
`Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
|
`Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
|
||||||
learning algorithm for classification.
|
learning algorithm for classification.
|
||||||
|
@ -1485,14 +1486,14 @@ class GBTClassifier(JavaProbabilisticClassifier, GBTClassifierParams, HasCheckpo
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
|
||||||
maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, impurity="variance",
|
maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, impurity="variance",
|
||||||
featureSubsetStrategy="all", validationTol=0.01, validationIndicatorCol=None,
|
featureSubsetStrategy="all", validationTol=0.01, validationIndicatorCol=None,
|
||||||
leafCol=""):
|
leafCol="", minWeightFractionPerNode=0.0):
|
||||||
"""
|
"""
|
||||||
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
||||||
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
|
||||||
lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
|
lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
|
||||||
impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
|
impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
|
||||||
validationIndicatorCol=None, leafCol="")
|
validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0)
|
||||||
"""
|
"""
|
||||||
super(GBTClassifier, self).__init__()
|
super(GBTClassifier, self).__init__()
|
||||||
self._java_obj = self._new_java_obj(
|
self._java_obj = self._new_java_obj(
|
||||||
|
@ -1501,7 +1502,7 @@ class GBTClassifier(JavaProbabilisticClassifier, GBTClassifierParams, HasCheckpo
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
|
||||||
lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0,
|
lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0,
|
||||||
impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
|
impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
|
||||||
leafCol="")
|
leafCol="", minWeightFractionPerNode=0.0)
|
||||||
kwargs = self._input_kwargs
|
kwargs = self._input_kwargs
|
||||||
self.setParams(**kwargs)
|
self.setParams(**kwargs)
|
||||||
|
|
||||||
|
@ -1512,14 +1513,14 @@ class GBTClassifier(JavaProbabilisticClassifier, GBTClassifierParams, HasCheckpo
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
|
||||||
lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
|
lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
|
||||||
impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
|
impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
|
||||||
validationIndicatorCol=None, leafCol=""):
|
validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0):
|
||||||
"""
|
"""
|
||||||
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
||||||
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
|
||||||
lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
|
lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
|
||||||
impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
|
impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
|
||||||
validationIndicatorCol=None, leafCol="")
|
validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0)
|
||||||
Sets params for Gradient Boosted Tree Classification.
|
Sets params for Gradient Boosted Tree Classification.
|
||||||
"""
|
"""
|
||||||
kwargs = self._input_kwargs
|
kwargs = self._input_kwargs
|
||||||
|
@ -1600,8 +1601,8 @@ class GBTClassifier(JavaProbabilisticClassifier, GBTClassifierParams, HasCheckpo
|
||||||
return self._set(validationIndicatorCol=value)
|
return self._set(validationIndicatorCol=value)
|
||||||
|
|
||||||
|
|
||||||
class GBTClassificationModel(TreeEnsembleModel, JavaProbabilisticClassificationModel,
|
class GBTClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassificationModel,
|
||||||
JavaMLWritable, JavaMLReadable):
|
GBTClassifierParams, JavaMLWritable, JavaMLReadable):
|
||||||
"""
|
"""
|
||||||
Model fitted by GBTClassifier.
|
Model fitted by GBTClassifier.
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,9 @@ import sys
|
||||||
|
|
||||||
from pyspark import since, keyword_only
|
from pyspark import since, keyword_only
|
||||||
from pyspark.ml.param.shared import *
|
from pyspark.ml.param.shared import *
|
||||||
|
from pyspark.ml.tree import _DecisionTreeModel, _DecisionTreeParams, \
|
||||||
|
_TreeEnsembleModel, _TreeEnsembleParams, _RandomForestParams, _GBTParams, \
|
||||||
|
_HasVarianceImpurity, _TreeRegressorParams
|
||||||
from pyspark.ml.util import *
|
from pyspark.ml.util import *
|
||||||
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, \
|
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, \
|
||||||
JavaPredictor, JavaPredictionModel, JavaWrapper
|
JavaPredictor, JavaPredictionModel, JavaWrapper
|
||||||
|
@ -600,233 +603,19 @@ class IsotonicRegressionModel(JavaModel, _IsotonicRegressionBase,
|
||||||
return self._call_java("predictions")
|
return self._call_java("predictions")
|
||||||
|
|
||||||
|
|
||||||
class DecisionTreeParams(Params):
|
class _DecisionTreeRegressorParams(_DecisionTreeParams, _TreeRegressorParams, HasVarianceCol):
|
||||||
"""
|
"""
|
||||||
Mixin for Decision Tree parameters.
|
Params for :py:class:`DecisionTreeRegressor` and :py:class:`DecisionTreeRegressionModel`.
|
||||||
"""
|
|
||||||
|
|
||||||
leafCol = Param(Params._dummy(), "leafCol", "Leaf indices column name. Predicted leaf " +
|
|
||||||
"index of each instance in each tree by preorder.",
|
|
||||||
typeConverter=TypeConverters.toString)
|
|
||||||
|
|
||||||
maxDepth = Param(Params._dummy(), "maxDepth", "Maximum depth of the tree. (>= 0) E.g., " +
|
|
||||||
"depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.",
|
|
||||||
typeConverter=TypeConverters.toInt)
|
|
||||||
|
|
||||||
maxBins = Param(Params._dummy(), "maxBins", "Max number of bins for discretizing continuous " +
|
|
||||||
"features. Must be >=2 and >= number of categories for any categorical " +
|
|
||||||
"feature.", typeConverter=TypeConverters.toInt)
|
|
||||||
|
|
||||||
minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode", "Minimum number of " +
|
|
||||||
"instances each child must have after split. If a split causes " +
|
|
||||||
"the left or right child to have fewer than " +
|
|
||||||
"minInstancesPerNode, the split will be discarded as invalid. " +
|
|
||||||
"Should be >= 1.", typeConverter=TypeConverters.toInt)
|
|
||||||
|
|
||||||
minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split " +
|
|
||||||
"to be considered at a tree node.", typeConverter=TypeConverters.toFloat)
|
|
||||||
|
|
||||||
maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to " +
|
|
||||||
"histogram aggregation. If too small, then 1 node will be split per " +
|
|
||||||
"iteration, and its aggregates may exceed this size.",
|
|
||||||
typeConverter=TypeConverters.toInt)
|
|
||||||
|
|
||||||
cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass " +
|
|
||||||
"trees to executors to match instances with nodes. If true, the " +
|
|
||||||
"algorithm will cache node IDs for each instance. Caching can speed " +
|
|
||||||
"up training of deeper trees. Users can set how often should the cache " +
|
|
||||||
"be checkpointed or disable it by setting checkpointInterval.",
|
|
||||||
typeConverter=TypeConverters.toBoolean)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super(DecisionTreeParams, self).__init__()
|
|
||||||
|
|
||||||
def setLeafCol(self, value):
|
|
||||||
"""
|
|
||||||
Sets the value of :py:attr:`leafCol`.
|
|
||||||
"""
|
|
||||||
return self._set(leafCol=value)
|
|
||||||
|
|
||||||
def getLeafCol(self):
|
|
||||||
"""
|
|
||||||
Gets the value of leafCol or its default value.
|
|
||||||
"""
|
|
||||||
return self.getOrDefault(self.leafCol)
|
|
||||||
|
|
||||||
def getMaxDepth(self):
|
|
||||||
"""
|
|
||||||
Gets the value of maxDepth or its default value.
|
|
||||||
"""
|
|
||||||
return self.getOrDefault(self.maxDepth)
|
|
||||||
|
|
||||||
def getMaxBins(self):
|
|
||||||
"""
|
|
||||||
Gets the value of maxBins or its default value.
|
|
||||||
"""
|
|
||||||
return self.getOrDefault(self.maxBins)
|
|
||||||
|
|
||||||
def getMinInstancesPerNode(self):
|
|
||||||
"""
|
|
||||||
Gets the value of minInstancesPerNode or its default value.
|
|
||||||
"""
|
|
||||||
return self.getOrDefault(self.minInstancesPerNode)
|
|
||||||
|
|
||||||
def getMinInfoGain(self):
|
|
||||||
"""
|
|
||||||
Gets the value of minInfoGain or its default value.
|
|
||||||
"""
|
|
||||||
return self.getOrDefault(self.minInfoGain)
|
|
||||||
|
|
||||||
def getMaxMemoryInMB(self):
|
|
||||||
"""
|
|
||||||
Gets the value of maxMemoryInMB or its default value.
|
|
||||||
"""
|
|
||||||
return self.getOrDefault(self.maxMemoryInMB)
|
|
||||||
|
|
||||||
def getCacheNodeIds(self):
|
|
||||||
"""
|
|
||||||
Gets the value of cacheNodeIds or its default value.
|
|
||||||
"""
|
|
||||||
return self.getOrDefault(self.cacheNodeIds)
|
|
||||||
|
|
||||||
|
|
||||||
class TreeEnsembleParams(DecisionTreeParams):
|
|
||||||
"""
|
|
||||||
Mixin for Decision Tree-based ensemble algorithms parameters.
|
|
||||||
"""
|
|
||||||
|
|
||||||
subsamplingRate = Param(Params._dummy(), "subsamplingRate", "Fraction of the training data " +
|
|
||||||
"used for learning each decision tree, in range (0, 1].",
|
|
||||||
typeConverter=TypeConverters.toFloat)
|
|
||||||
|
|
||||||
supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"]
|
|
||||||
|
|
||||||
featureSubsetStrategy = \
|
|
||||||
Param(Params._dummy(), "featureSubsetStrategy",
|
|
||||||
"The number of features to consider for splits at each tree node. Supported " +
|
|
||||||
"options: 'auto' (choose automatically for task: If numTrees == 1, set to " +
|
|
||||||
"'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to " +
|
|
||||||
"'onethird' for regression), 'all' (use all features), 'onethird' (use " +
|
|
||||||
"1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use " +
|
|
||||||
"log2(number of features)), 'n' (when n is in the range (0, 1.0], use " +
|
|
||||||
"n * number of features. When n is in the range (1, number of features), use" +
|
|
||||||
" n features). default = 'auto'", typeConverter=TypeConverters.toString)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super(TreeEnsembleParams, self).__init__()
|
|
||||||
|
|
||||||
@since("1.4.0")
|
|
||||||
def getSubsamplingRate(self):
|
|
||||||
"""
|
|
||||||
Gets the value of subsamplingRate or its default value.
|
|
||||||
"""
|
|
||||||
return self.getOrDefault(self.subsamplingRate)
|
|
||||||
|
|
||||||
@since("1.4.0")
|
|
||||||
def getFeatureSubsetStrategy(self):
|
|
||||||
"""
|
|
||||||
Gets the value of featureSubsetStrategy or its default value.
|
|
||||||
"""
|
|
||||||
return self.getOrDefault(self.featureSubsetStrategy)
|
|
||||||
|
|
||||||
|
|
||||||
class HasVarianceImpurity(Params):
|
|
||||||
"""
|
|
||||||
Private class to track supported impurity measures.
|
|
||||||
"""
|
|
||||||
|
|
||||||
supportedImpurities = ["variance"]
|
|
||||||
|
|
||||||
impurity = Param(Params._dummy(), "impurity",
|
|
||||||
"Criterion used for information gain calculation (case-insensitive). " +
|
|
||||||
"Supported options: " +
|
|
||||||
", ".join(supportedImpurities), typeConverter=TypeConverters.toString)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super(HasVarianceImpurity, self).__init__()
|
|
||||||
|
|
||||||
@since("1.4.0")
|
|
||||||
def getImpurity(self):
|
|
||||||
"""
|
|
||||||
Gets the value of impurity or its default value.
|
|
||||||
"""
|
|
||||||
return self.getOrDefault(self.impurity)
|
|
||||||
|
|
||||||
|
|
||||||
class TreeRegressorParams(HasVarianceImpurity):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class RandomForestParams(TreeEnsembleParams):
|
|
||||||
"""
|
|
||||||
Private class to track supported random forest parameters.
|
|
||||||
"""
|
|
||||||
|
|
||||||
numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).",
|
|
||||||
typeConverter=TypeConverters.toInt)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super(RandomForestParams, self).__init__()
|
|
||||||
|
|
||||||
@since("1.4.0")
|
|
||||||
def getNumTrees(self):
|
|
||||||
"""
|
|
||||||
Gets the value of numTrees or its default value.
|
|
||||||
"""
|
|
||||||
return self.getOrDefault(self.numTrees)
|
|
||||||
|
|
||||||
|
|
||||||
class GBTParams(TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndicatorCol):
|
|
||||||
"""
|
|
||||||
Private class to track supported GBT params.
|
|
||||||
"""
|
|
||||||
|
|
||||||
stepSize = Param(Params._dummy(), "stepSize",
|
|
||||||
"Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " +
|
|
||||||
"the contribution of each estimator.",
|
|
||||||
typeConverter=TypeConverters.toFloat)
|
|
||||||
|
|
||||||
validationTol = Param(Params._dummy(), "validationTol",
|
|
||||||
"Threshold for stopping early when fit with validation is used. " +
|
|
||||||
"If the error rate on the validation input changes by less than the " +
|
|
||||||
"validationTol, then learning will stop early (before `maxIter`). " +
|
|
||||||
"This parameter is ignored when fit without validation is used.",
|
|
||||||
typeConverter=TypeConverters.toFloat)
|
|
||||||
|
|
||||||
@since("3.0.0")
|
|
||||||
def getValidationTol(self):
|
|
||||||
"""
|
|
||||||
Gets the value of validationTol or its default value.
|
|
||||||
"""
|
|
||||||
return self.getOrDefault(self.validationTol)
|
|
||||||
|
|
||||||
|
|
||||||
class GBTRegressorParams(GBTParams, TreeRegressorParams):
|
|
||||||
"""
|
|
||||||
Private class to track supported GBTRegressor params.
|
|
||||||
|
|
||||||
.. versionadded:: 3.0.0
|
.. versionadded:: 3.0.0
|
||||||
"""
|
"""
|
||||||
|
|
||||||
supportedLossTypes = ["squared", "absolute"]
|
pass
|
||||||
|
|
||||||
lossType = Param(Params._dummy(), "lossType",
|
|
||||||
"Loss function which GBT tries to minimize (case-insensitive). " +
|
|
||||||
"Supported options: " + ", ".join(supportedLossTypes),
|
|
||||||
typeConverter=TypeConverters.toString)
|
|
||||||
|
|
||||||
@since("1.4.0")
|
|
||||||
def getLossType(self):
|
|
||||||
"""
|
|
||||||
Gets the value of lossType or its default value.
|
|
||||||
"""
|
|
||||||
return self.getOrDefault(self.lossType)
|
|
||||||
|
|
||||||
|
|
||||||
@inherit_doc
|
@inherit_doc
|
||||||
class DecisionTreeRegressor(JavaPredictor, HasWeightCol, DecisionTreeParams, TreeRegressorParams,
|
class DecisionTreeRegressor(JavaPredictor, _DecisionTreeRegressorParams, JavaMLWritable,
|
||||||
HasCheckpointInterval, HasSeed, JavaMLWritable, JavaMLReadable,
|
JavaMLReadable):
|
||||||
HasVarianceCol):
|
|
||||||
"""
|
"""
|
||||||
`Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
|
`Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
|
||||||
learning algorithm for regression.
|
learning algorithm for regression.
|
||||||
|
@ -836,8 +625,12 @@ class DecisionTreeRegressor(JavaPredictor, HasWeightCol, DecisionTreeParams, Tre
|
||||||
>>> df = spark.createDataFrame([
|
>>> df = spark.createDataFrame([
|
||||||
... (1.0, Vectors.dense(1.0)),
|
... (1.0, Vectors.dense(1.0)),
|
||||||
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
|
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
|
||||||
>>> dt = DecisionTreeRegressor(maxDepth=2, varianceCol="variance", leafCol="leafId")
|
>>> dt = DecisionTreeRegressor(maxDepth=2, varianceCol="variance")
|
||||||
>>> model = dt.fit(df)
|
>>> model = dt.fit(df)
|
||||||
|
>>> model.getVarianceCol()
|
||||||
|
'variance'
|
||||||
|
>>> model.setLeafCol("leafId")
|
||||||
|
DecisionTreeRegressionModel...
|
||||||
>>> model.depth
|
>>> model.depth
|
||||||
1
|
1
|
||||||
>>> model.numNodes
|
>>> model.numNodes
|
||||||
|
@ -852,6 +645,8 @@ class DecisionTreeRegressor(JavaPredictor, HasWeightCol, DecisionTreeParams, Tre
|
||||||
>>> result = model.transform(test0).head()
|
>>> result = model.transform(test0).head()
|
||||||
>>> result.prediction
|
>>> result.prediction
|
||||||
0.0
|
0.0
|
||||||
|
>>> model.predictLeaf(test0.head().features)
|
||||||
|
0.0
|
||||||
>>> result.leafId
|
>>> result.leafId
|
||||||
0.0
|
0.0
|
||||||
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
|
>>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
|
||||||
|
@ -888,20 +683,21 @@ class DecisionTreeRegressor(JavaPredictor, HasWeightCol, DecisionTreeParams, Tre
|
||||||
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
|
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
|
||||||
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance",
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance",
|
||||||
seed=None, varianceCol=None, weightCol=None, leafCol=""):
|
seed=None, varianceCol=None, weightCol=None, leafCol="",
|
||||||
|
minWeightFractionPerNode=0.0):
|
||||||
"""
|
"""
|
||||||
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
||||||
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
|
||||||
impurity="variance", seed=None, varianceCol=None, weightCol=None, \
|
impurity="variance", seed=None, varianceCol=None, weightCol=None, \
|
||||||
leafCol="")
|
leafCol="", minWeightFractionPerNode=0.0)
|
||||||
"""
|
"""
|
||||||
super(DecisionTreeRegressor, self).__init__()
|
super(DecisionTreeRegressor, self).__init__()
|
||||||
self._java_obj = self._new_java_obj(
|
self._java_obj = self._new_java_obj(
|
||||||
"org.apache.spark.ml.regression.DecisionTreeRegressor", self.uid)
|
"org.apache.spark.ml.regression.DecisionTreeRegressor", self.uid)
|
||||||
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
|
||||||
impurity="variance", leafCol="")
|
impurity="variance", leafCol="", minWeightFractionPerNode=0.0)
|
||||||
kwargs = self._input_kwargs
|
kwargs = self._input_kwargs
|
||||||
self.setParams(**kwargs)
|
self.setParams(**kwargs)
|
||||||
|
|
||||||
|
@ -911,13 +707,13 @@ class DecisionTreeRegressor(JavaPredictor, HasWeightCol, DecisionTreeParams, Tre
|
||||||
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
|
||||||
impurity="variance", seed=None, varianceCol=None, weightCol=None,
|
impurity="variance", seed=None, varianceCol=None, weightCol=None,
|
||||||
leafCol=""):
|
leafCol="", minWeightFractionPerNode=0.0):
|
||||||
"""
|
"""
|
||||||
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
||||||
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
|
||||||
impurity="variance", seed=None, varianceCol=None, weightCol=None, \
|
impurity="variance", seed=None, varianceCol=None, weightCol=None, \
|
||||||
leafCol="")
|
leafCol="", minWeightFractionPerNode=0.0)
|
||||||
Sets params for the DecisionTreeRegressor.
|
Sets params for the DecisionTreeRegressor.
|
||||||
"""
|
"""
|
||||||
kwargs = self._input_kwargs
|
kwargs = self._input_kwargs
|
||||||
|
@ -944,6 +740,13 @@ class DecisionTreeRegressor(JavaPredictor, HasWeightCol, DecisionTreeParams, Tre
|
||||||
"""
|
"""
|
||||||
return self._set(minInstancesPerNode=value)
|
return self._set(minInstancesPerNode=value)
|
||||||
|
|
||||||
|
@since("3.0.0")
|
||||||
|
def setMinWeightFractionPerNode(self, value):
|
||||||
|
"""
|
||||||
|
Sets the value of :py:attr:`minWeightFractionPerNode`.
|
||||||
|
"""
|
||||||
|
return self._set(minWeightFractionPerNode=value)
|
||||||
|
|
||||||
def setMinInfoGain(self, value):
|
def setMinInfoGain(self, value):
|
||||||
"""
|
"""
|
||||||
Sets the value of :py:attr:`minInfoGain`.
|
Sets the value of :py:attr:`minInfoGain`.
|
||||||
|
@ -971,79 +774,8 @@ class DecisionTreeRegressor(JavaPredictor, HasWeightCol, DecisionTreeParams, Tre
|
||||||
|
|
||||||
|
|
||||||
@inherit_doc
|
@inherit_doc
|
||||||
class DecisionTreeModel(JavaPredictionModel):
|
class DecisionTreeRegressionModel(_DecisionTreeModel, _DecisionTreeRegressorParams,
|
||||||
"""
|
JavaMLWritable, JavaMLReadable):
|
||||||
Abstraction for Decision Tree models.
|
|
||||||
|
|
||||||
.. versionadded:: 1.5.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
@since("1.5.0")
|
|
||||||
def numNodes(self):
|
|
||||||
"""Return number of nodes of the decision tree."""
|
|
||||||
return self._call_java("numNodes")
|
|
||||||
|
|
||||||
@property
|
|
||||||
@since("1.5.0")
|
|
||||||
def depth(self):
|
|
||||||
"""Return depth of the decision tree."""
|
|
||||||
return self._call_java("depth")
|
|
||||||
|
|
||||||
@property
|
|
||||||
@since("2.0.0")
|
|
||||||
def toDebugString(self):
|
|
||||||
"""Full description of model."""
|
|
||||||
return self._call_java("toDebugString")
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return self._call_java("toString")
|
|
||||||
|
|
||||||
|
|
||||||
@inherit_doc
|
|
||||||
class TreeEnsembleModel(JavaModel):
|
|
||||||
"""
|
|
||||||
(private abstraction)
|
|
||||||
|
|
||||||
Represents a tree ensemble model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
@since("2.0.0")
|
|
||||||
def trees(self):
|
|
||||||
"""Trees in this ensemble. Warning: These have null parent Estimators."""
|
|
||||||
return [DecisionTreeModel(m) for m in list(self._call_java("trees"))]
|
|
||||||
|
|
||||||
@property
|
|
||||||
@since("2.0.0")
|
|
||||||
def getNumTrees(self):
|
|
||||||
"""Number of trees in ensemble."""
|
|
||||||
return self._call_java("getNumTrees")
|
|
||||||
|
|
||||||
@property
|
|
||||||
@since("1.5.0")
|
|
||||||
def treeWeights(self):
|
|
||||||
"""Return the weights for each tree"""
|
|
||||||
return list(self._call_java("javaTreeWeights"))
|
|
||||||
|
|
||||||
@property
|
|
||||||
@since("2.0.0")
|
|
||||||
def totalNumNodes(self):
|
|
||||||
"""Total number of nodes, summed over all trees in the ensemble."""
|
|
||||||
return self._call_java("totalNumNodes")
|
|
||||||
|
|
||||||
@property
|
|
||||||
@since("2.0.0")
|
|
||||||
def toDebugString(self):
|
|
||||||
"""Full description of model."""
|
|
||||||
return self._call_java("toDebugString")
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return self._call_java("toString")
|
|
||||||
|
|
||||||
|
|
||||||
@inherit_doc
|
|
||||||
class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
|
|
||||||
"""
|
"""
|
||||||
Model fitted by :class:`DecisionTreeRegressor`.
|
Model fitted by :class:`DecisionTreeRegressor`.
|
||||||
|
|
||||||
|
@ -1072,9 +804,18 @@ class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReada
|
||||||
return self._call_java("featureImportances")
|
return self._call_java("featureImportances")
|
||||||
|
|
||||||
|
|
||||||
|
class _RandomForestRegressorParams(_RandomForestParams, _TreeRegressorParams):
|
||||||
|
"""
|
||||||
|
Params for :py:class:`RandomForestRegressor` and :py:class:`RandomForestRegressionModel`.
|
||||||
|
|
||||||
|
.. versionadded:: 3.0.0
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@inherit_doc
|
@inherit_doc
|
||||||
class RandomForestRegressor(JavaPredictor, HasSeed, RandomForestParams, TreeRegressorParams,
|
class RandomForestRegressor(JavaPredictor, _RandomForestRegressorParams, JavaMLWritable,
|
||||||
HasCheckpointInterval, JavaMLWritable, JavaMLReadable):
|
JavaMLReadable):
|
||||||
"""
|
"""
|
||||||
`Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
|
`Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
|
||||||
learning algorithm for regression.
|
learning algorithm for regression.
|
||||||
|
@ -1085,8 +826,12 @@ class RandomForestRegressor(JavaPredictor, HasSeed, RandomForestParams, TreeRegr
|
||||||
>>> df = spark.createDataFrame([
|
>>> df = spark.createDataFrame([
|
||||||
... (1.0, Vectors.dense(1.0)),
|
... (1.0, Vectors.dense(1.0)),
|
||||||
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
|
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
|
||||||
>>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42, leafCol="leafId")
|
>>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42)
|
||||||
>>> model = rf.fit(df)
|
>>> model = rf.fit(df)
|
||||||
|
>>> model.getSeed()
|
||||||
|
42
|
||||||
|
>>> model.setLeafCol("leafId")
|
||||||
|
RandomForestRegressionModel...
|
||||||
>>> model.featureImportances
|
>>> model.featureImportances
|
||||||
SparseVector(1, {0: 1.0})
|
SparseVector(1, {0: 1.0})
|
||||||
>>> allclose(model.treeWeights, [1.0, 1.0])
|
>>> allclose(model.treeWeights, [1.0, 1.0])
|
||||||
|
@ -1094,6 +839,8 @@ class RandomForestRegressor(JavaPredictor, HasSeed, RandomForestParams, TreeRegr
|
||||||
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
|
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
|
||||||
>>> model.predict(test0.head().features)
|
>>> model.predict(test0.head().features)
|
||||||
0.0
|
0.0
|
||||||
|
>>> model.predictLeaf(test0.head().features)
|
||||||
|
DenseVector([0.0, 0.0])
|
||||||
>>> result = model.transform(test0).head()
|
>>> result = model.transform(test0).head()
|
||||||
>>> result.prediction
|
>>> result.prediction
|
||||||
0.0
|
0.0
|
||||||
|
@ -1127,13 +874,13 @@ class RandomForestRegressor(JavaPredictor, HasSeed, RandomForestParams, TreeRegr
|
||||||
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
|
||||||
impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20,
|
impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20,
|
||||||
featureSubsetStrategy="auto", leafCol=""):
|
featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0):
|
||||||
"""
|
"""
|
||||||
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
||||||
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
|
||||||
impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \
|
impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \
|
||||||
featureSubsetStrategy="auto", leafCol="")
|
featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0)
|
||||||
"""
|
"""
|
||||||
super(RandomForestRegressor, self).__init__()
|
super(RandomForestRegressor, self).__init__()
|
||||||
self._java_obj = self._new_java_obj(
|
self._java_obj = self._new_java_obj(
|
||||||
|
@ -1141,7 +888,7 @@ class RandomForestRegressor(JavaPredictor, HasSeed, RandomForestParams, TreeRegr
|
||||||
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
|
||||||
impurity="variance", subsamplingRate=1.0, numTrees=20,
|
impurity="variance", subsamplingRate=1.0, numTrees=20,
|
||||||
featureSubsetStrategy="auto", leafCol="")
|
featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0)
|
||||||
kwargs = self._input_kwargs
|
kwargs = self._input_kwargs
|
||||||
self.setParams(**kwargs)
|
self.setParams(**kwargs)
|
||||||
|
|
||||||
|
@ -1151,13 +898,13 @@ class RandomForestRegressor(JavaPredictor, HasSeed, RandomForestParams, TreeRegr
|
||||||
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
|
||||||
impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20,
|
impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20,
|
||||||
featureSubsetStrategy="auto", leafCol=""):
|
featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0):
|
||||||
"""
|
"""
|
||||||
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
||||||
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
|
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
|
||||||
impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \
|
impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \
|
||||||
featureSubsetStrategy="auto", leafCol="")
|
featureSubsetStrategy="auto", leafCol="", minWeightFractionPerNode=0.0)
|
||||||
Sets params for Random Forest Regression.
|
Sets params for Random Forest Regression.
|
||||||
"""
|
"""
|
||||||
kwargs = self._input_kwargs
|
kwargs = self._input_kwargs
|
||||||
|
@ -1231,8 +978,8 @@ class RandomForestRegressor(JavaPredictor, HasSeed, RandomForestParams, TreeRegr
|
||||||
return self._set(featureSubsetStrategy=value)
|
return self._set(featureSubsetStrategy=value)
|
||||||
|
|
||||||
|
|
||||||
class RandomForestRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
|
class RandomForestRegressionModel(_TreeEnsembleModel, _RandomForestRegressorParams,
|
||||||
JavaMLReadable):
|
JavaMLWritable, JavaMLReadable):
|
||||||
"""
|
"""
|
||||||
Model fitted by :class:`RandomForestRegressor`.
|
Model fitted by :class:`RandomForestRegressor`.
|
||||||
|
|
||||||
|
@ -1261,9 +1008,30 @@ class RandomForestRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaML
|
||||||
return self._call_java("featureImportances")
|
return self._call_java("featureImportances")
|
||||||
|
|
||||||
|
|
||||||
|
class _GBTRegressorParams(_GBTParams, _TreeRegressorParams):
|
||||||
|
"""
|
||||||
|
Params for :py:class:`GBTRegressor` and :py:class:`GBTRegressionModel`.
|
||||||
|
|
||||||
|
.. versionadded:: 3.0.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
supportedLossTypes = ["squared", "absolute"]
|
||||||
|
|
||||||
|
lossType = Param(Params._dummy(), "lossType",
|
||||||
|
"Loss function which GBT tries to minimize (case-insensitive). " +
|
||||||
|
"Supported options: " + ", ".join(supportedLossTypes),
|
||||||
|
typeConverter=TypeConverters.toString)
|
||||||
|
|
||||||
|
@since("1.4.0")
|
||||||
|
def getLossType(self):
|
||||||
|
"""
|
||||||
|
Gets the value of lossType or its default value.
|
||||||
|
"""
|
||||||
|
return self.getOrDefault(self.lossType)
|
||||||
|
|
||||||
|
|
||||||
@inherit_doc
|
@inherit_doc
|
||||||
class GBTRegressor(JavaPredictor, GBTRegressorParams, HasCheckpointInterval, HasSeed,
|
class GBTRegressor(JavaPredictor, _GBTRegressorParams, JavaMLWritable, JavaMLReadable):
|
||||||
JavaMLWritable, JavaMLReadable):
|
|
||||||
"""
|
"""
|
||||||
`Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
|
`Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
|
||||||
learning algorithm for regression.
|
learning algorithm for regression.
|
||||||
|
@ -1280,8 +1048,6 @@ class GBTRegressor(JavaPredictor, GBTRegressorParams, HasCheckpointInterval, Has
|
||||||
>>> print(gbt.getFeatureSubsetStrategy())
|
>>> print(gbt.getFeatureSubsetStrategy())
|
||||||
all
|
all
|
||||||
>>> model = gbt.fit(df)
|
>>> model = gbt.fit(df)
|
||||||
>>> model.setFeaturesCol("features")
|
|
||||||
GBTRegressionModel...
|
|
||||||
>>> model.featureImportances
|
>>> model.featureImportances
|
||||||
SparseVector(1, {0: 1.0})
|
SparseVector(1, {0: 1.0})
|
||||||
>>> model.numFeatures
|
>>> model.numFeatures
|
||||||
|
@ -1291,6 +1057,8 @@ class GBTRegressor(JavaPredictor, GBTRegressorParams, HasCheckpointInterval, Has
|
||||||
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
|
>>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
|
||||||
>>> model.predict(test0.head().features)
|
>>> model.predict(test0.head().features)
|
||||||
0.0
|
0.0
|
||||||
|
>>> model.predictLeaf(test0.head().features)
|
||||||
|
DenseVector([0.0, 0.0, 0.0, 0.0, 0.0])
|
||||||
>>> result = model.transform(test0).head()
|
>>> result = model.transform(test0).head()
|
||||||
>>> result.prediction
|
>>> result.prediction
|
||||||
0.0
|
0.0
|
||||||
|
@ -1332,14 +1100,14 @@ class GBTRegressor(JavaPredictor, GBTRegressorParams, HasCheckpointInterval, Has
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
|
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
|
||||||
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None,
|
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None,
|
||||||
impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
|
impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
|
||||||
validationIndicatorCol=None, leafCol=""):
|
validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0):
|
||||||
"""
|
"""
|
||||||
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
||||||
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
|
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
|
||||||
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
|
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
|
||||||
impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
|
impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
|
||||||
validationIndicatorCol=None, leafCol="")
|
validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0)
|
||||||
"""
|
"""
|
||||||
super(GBTRegressor, self).__init__()
|
super(GBTRegressor, self).__init__()
|
||||||
self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid)
|
self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid)
|
||||||
|
@ -1347,7 +1115,7 @@ class GBTRegressor(JavaPredictor, GBTRegressorParams, HasCheckpointInterval, Has
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
|
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
|
||||||
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1,
|
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1,
|
||||||
impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
|
impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
|
||||||
leafCol="")
|
leafCol="", minWeightFractionPerNode=0.0)
|
||||||
kwargs = self._input_kwargs
|
kwargs = self._input_kwargs
|
||||||
self.setParams(**kwargs)
|
self.setParams(**kwargs)
|
||||||
|
|
||||||
|
@ -1358,14 +1126,14 @@ class GBTRegressor(JavaPredictor, GBTRegressorParams, HasCheckpointInterval, Has
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
|
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
|
||||||
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None,
|
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None,
|
||||||
impuriy="variance", featureSubsetStrategy="all", validationTol=0.01,
|
impuriy="variance", featureSubsetStrategy="all", validationTol=0.01,
|
||||||
validationIndicatorCol=None, leafCol=""):
|
validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0):
|
||||||
"""
|
"""
|
||||||
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
||||||
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
|
||||||
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
|
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
|
||||||
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
|
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
|
||||||
impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
|
impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
|
||||||
validationIndicatorCol=None, leafCol="")
|
validationIndicatorCol=None, leafCol="", minWeightFractionPerNode=0.0)
|
||||||
Sets params for Gradient Boosted Tree Regression.
|
Sets params for Gradient Boosted Tree Regression.
|
||||||
"""
|
"""
|
||||||
kwargs = self._input_kwargs
|
kwargs = self._input_kwargs
|
||||||
|
@ -1446,7 +1214,7 @@ class GBTRegressor(JavaPredictor, GBTRegressorParams, HasCheckpointInterval, Has
|
||||||
return self._set(validationIndicatorCol=value)
|
return self._set(validationIndicatorCol=value)
|
||||||
|
|
||||||
|
|
||||||
class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable):
|
class GBTRegressionModel(_TreeEnsembleModel, _GBTRegressorParams, JavaMLWritable, JavaMLReadable):
|
||||||
"""
|
"""
|
||||||
Model fitted by :class:`GBTRegressor`.
|
Model fitted by :class:`GBTRegressor`.
|
||||||
|
|
||||||
|
|
351
python/pyspark/ml/tree.py
Normal file
351
python/pyspark/ml/tree.py
Normal file
|
@ -0,0 +1,351 @@
|
||||||
|
#
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
# contributor license agreements. See the NOTICE file distributed with
|
||||||
|
# this work for additional information regarding copyright ownership.
|
||||||
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
# (the "License"); you may not use this file except in compliance with
|
||||||
|
# the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
from pyspark import since, keyword_only
|
||||||
|
from pyspark.ml.param.shared import *
|
||||||
|
from pyspark.ml.util import *
|
||||||
|
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, \
|
||||||
|
JavaPredictor, JavaPredictionModel
|
||||||
|
from pyspark.ml.common import inherit_doc, _java2py, _py2java
|
||||||
|
|
||||||
|
|
||||||
|
@inherit_doc
|
||||||
|
class _DecisionTreeModel(JavaPredictionModel):
|
||||||
|
"""
|
||||||
|
Abstraction for Decision Tree models.
|
||||||
|
|
||||||
|
.. versionadded:: 1.5.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@since("1.5.0")
|
||||||
|
def numNodes(self):
|
||||||
|
"""Return number of nodes of the decision tree."""
|
||||||
|
return self._call_java("numNodes")
|
||||||
|
|
||||||
|
@property
|
||||||
|
@since("1.5.0")
|
||||||
|
def depth(self):
|
||||||
|
"""Return depth of the decision tree."""
|
||||||
|
return self._call_java("depth")
|
||||||
|
|
||||||
|
@property
|
||||||
|
@since("2.0.0")
|
||||||
|
def toDebugString(self):
|
||||||
|
"""Full description of model."""
|
||||||
|
return self._call_java("toDebugString")
|
||||||
|
|
||||||
|
@since("3.0.0")
|
||||||
|
def predictLeaf(self, value):
|
||||||
|
"""
|
||||||
|
Predict the indices of the leaves corresponding to the feature vector.
|
||||||
|
"""
|
||||||
|
return self._call_java("predictLeaf", value)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self._call_java("toString")
|
||||||
|
|
||||||
|
|
||||||
|
class _DecisionTreeParams(HasCheckpointInterval, HasSeed, HasWeightCol):
|
||||||
|
"""
|
||||||
|
Mixin for Decision Tree parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
leafCol = Param(Params._dummy(), "leafCol", "Leaf indices column name. Predicted leaf " +
|
||||||
|
"index of each instance in each tree by preorder.",
|
||||||
|
typeConverter=TypeConverters.toString)
|
||||||
|
|
||||||
|
maxDepth = Param(Params._dummy(), "maxDepth", "Maximum depth of the tree. (>= 0) E.g., " +
|
||||||
|
"depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.",
|
||||||
|
typeConverter=TypeConverters.toInt)
|
||||||
|
|
||||||
|
maxBins = Param(Params._dummy(), "maxBins", "Max number of bins for discretizing continuous " +
|
||||||
|
"features. Must be >=2 and >= number of categories for any categorical " +
|
||||||
|
"feature.", typeConverter=TypeConverters.toInt)
|
||||||
|
|
||||||
|
minInstancesPerNode = Param(Params._dummy(), "minInstancesPerNode", "Minimum number of " +
|
||||||
|
"instances each child must have after split. If a split causes " +
|
||||||
|
"the left or right child to have fewer than " +
|
||||||
|
"minInstancesPerNode, the split will be discarded as invalid. " +
|
||||||
|
"Should be >= 1.", typeConverter=TypeConverters.toInt)
|
||||||
|
|
||||||
|
minWeightFractionPerNode = Param(Params._dummy(), "minWeightFractionPerNode", "Minimum "
|
||||||
|
"fraction of the weighted sample count that each child "
|
||||||
|
"must have after split. If a split causes the fraction "
|
||||||
|
"of the total weight in the left or right child to be "
|
||||||
|
"less than minWeightFractionPerNode, the split will be "
|
||||||
|
"discarded as invalid. Should be in interval [0.0, 0.5).",
|
||||||
|
typeConverter=TypeConverters.toFloat)
|
||||||
|
|
||||||
|
minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split " +
|
||||||
|
"to be considered at a tree node.", typeConverter=TypeConverters.toFloat)
|
||||||
|
|
||||||
|
maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to " +
|
||||||
|
"histogram aggregation. If too small, then 1 node will be split per " +
|
||||||
|
"iteration, and its aggregates may exceed this size.",
|
||||||
|
typeConverter=TypeConverters.toInt)
|
||||||
|
|
||||||
|
cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass " +
|
||||||
|
"trees to executors to match instances with nodes. If true, the " +
|
||||||
|
"algorithm will cache node IDs for each instance. Caching can speed " +
|
||||||
|
"up training of deeper trees. Users can set how often should the cache " +
|
||||||
|
"be checkpointed or disable it by setting checkpointInterval.",
|
||||||
|
typeConverter=TypeConverters.toBoolean)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(_DecisionTreeParams, self).__init__()
|
||||||
|
|
||||||
|
def setLeafCol(self, value):
|
||||||
|
"""
|
||||||
|
Sets the value of :py:attr:`leafCol`.
|
||||||
|
"""
|
||||||
|
return self._set(leafCol=value)
|
||||||
|
|
||||||
|
def getLeafCol(self):
|
||||||
|
"""
|
||||||
|
Gets the value of leafCol or its default value.
|
||||||
|
"""
|
||||||
|
return self.getOrDefault(self.leafCol)
|
||||||
|
|
||||||
|
def getMaxDepth(self):
|
||||||
|
"""
|
||||||
|
Gets the value of maxDepth or its default value.
|
||||||
|
"""
|
||||||
|
return self.getOrDefault(self.maxDepth)
|
||||||
|
|
||||||
|
def getMaxBins(self):
|
||||||
|
"""
|
||||||
|
Gets the value of maxBins or its default value.
|
||||||
|
"""
|
||||||
|
return self.getOrDefault(self.maxBins)
|
||||||
|
|
||||||
|
def getMinInstancesPerNode(self):
|
||||||
|
"""
|
||||||
|
Gets the value of minInstancesPerNode or its default value.
|
||||||
|
"""
|
||||||
|
return self.getOrDefault(self.minInstancesPerNode)
|
||||||
|
|
||||||
|
def getMinWeightFractionPerNode(self):
|
||||||
|
"""
|
||||||
|
Gets the value of minWeightFractionPerNode or its default value.
|
||||||
|
"""
|
||||||
|
return self.getOrDefault(self.minWeightFractionPerNode)
|
||||||
|
|
||||||
|
def getMinInfoGain(self):
|
||||||
|
"""
|
||||||
|
Gets the value of minInfoGain or its default value.
|
||||||
|
"""
|
||||||
|
return self.getOrDefault(self.minInfoGain)
|
||||||
|
|
||||||
|
def getMaxMemoryInMB(self):
|
||||||
|
"""
|
||||||
|
Gets the value of maxMemoryInMB or its default value.
|
||||||
|
"""
|
||||||
|
return self.getOrDefault(self.maxMemoryInMB)
|
||||||
|
|
||||||
|
def getCacheNodeIds(self):
|
||||||
|
"""
|
||||||
|
Gets the value of cacheNodeIds or its default value.
|
||||||
|
"""
|
||||||
|
return self.getOrDefault(self.cacheNodeIds)
|
||||||
|
|
||||||
|
|
||||||
|
@inherit_doc
class _TreeEnsembleModel(JavaPredictionModel):
    """
    (private abstraction)

    Represents a tree ensemble model. Every accessor below delegates to
    the wrapped Java model through ``_call_java``.
    """

    @property
    @since("2.0.0")
    def trees(self):
        """Trees in this ensemble. Warning: These have null parent Estimators."""
        java_trees = self._call_java("trees")
        # Each Java tree is wrapped in the Python-side decision tree model.
        return [_DecisionTreeModel(java_tree) for java_tree in list(java_trees)]

    # NOTE(review): exposed as a property despite the getter-style name, so
    # callers read ``model.getNumTrees`` without parentheses.
    @property
    @since("2.0.0")
    def getNumTrees(self):
        """Number of trees in ensemble."""
        num_trees = self._call_java("getNumTrees")
        return num_trees

    @property
    @since("1.5.0")
    def treeWeights(self):
        """Return the weights for each tree"""
        java_weights = self._call_java("javaTreeWeights")
        return list(java_weights)

    @property
    @since("2.0.0")
    def totalNumNodes(self):
        """Total number of nodes, summed over all trees in the ensemble."""
        node_count = self._call_java("totalNumNodes")
        return node_count

    @property
    @since("2.0.0")
    def toDebugString(self):
        """Full description of model."""
        description = self._call_java("toDebugString")
        return description

    @since("3.0.0")
    def predictLeaf(self, value):
        """
        Predict the indices of the leaves corresponding to the feature vector.
        """
        leaf_indices = self._call_java("predictLeaf", value)
        return leaf_indices

    def __repr__(self):
        # The Python repr is whatever the Java model's ``toString`` returns.
        java_repr = self._call_java("toString")
        return java_repr
|
||||||
|
|
||||||
|
|
||||||
|
class _TreeEnsembleParams(_DecisionTreeParams):
    """
    Mixin for Decision Tree-based ensemble algorithms parameters.

    Adds the ``subsamplingRate`` and ``featureSubsetStrategy`` params and
    their getters on top of the decision-tree params.
    """

    subsamplingRate = Param(
        Params._dummy(),
        "subsamplingRate",
        "Fraction of the training data "
        "used for learning each decision tree, in range (0, 1].",
        typeConverter=TypeConverters.toFloat)

    supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"]

    featureSubsetStrategy = Param(
        Params._dummy(),
        "featureSubsetStrategy",
        "The number of features to consider for splits at each tree node. Supported "
        "options: 'auto' (choose automatically for task: If numTrees == 1, set to "
        "'all'. If numTrees > 1 (forest), set to 'sqrt' for classification and to "
        "'onethird' for regression), 'all' (use all features), 'onethird' (use "
        "1/3 of the features), 'sqrt' (use sqrt(number of features)), 'log2' (use "
        "log2(number of features)), 'n' (when n is in the range (0, 1.0], use "
        "n * number of features. When n is in the range (1, number of features), use"
        " n features). default = 'auto'",
        typeConverter=TypeConverters.toString)

    def __init__(self):
        super(_TreeEnsembleParams, self).__init__()

    @since("1.4.0")
    def getSubsamplingRate(self):
        """
        Gets the value of subsamplingRate or its default value.
        """
        subsampling_param = self.subsamplingRate
        return self.getOrDefault(subsampling_param)

    @since("1.4.0")
    def getFeatureSubsetStrategy(self):
        """
        Gets the value of featureSubsetStrategy or its default value.
        """
        strategy_param = self.featureSubsetStrategy
        return self.getOrDefault(strategy_param)
|
||||||
|
|
||||||
|
|
||||||
|
class _RandomForestParams(_TreeEnsembleParams):
    """
    Private class to track supported random forest parameters.

    Adds the ``numTrees`` param and its getter to the tree-ensemble params.
    """

    numTrees = Param(
        Params._dummy(),
        "numTrees",
        "Number of trees to train (>= 1).",
        typeConverter=TypeConverters.toInt)

    def __init__(self):
        super(_RandomForestParams, self).__init__()

    @since("1.4.0")
    def getNumTrees(self):
        """
        Gets the value of numTrees or its default value.
        """
        num_trees_param = self.numTrees
        return self.getOrDefault(num_trees_param)
|
||||||
|
|
||||||
|
|
||||||
|
class _GBTParams(_TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndicatorCol):
    """
    Private class to track supported GBT params.

    Declares its own ``stepSize`` param alongside the ``HasStepSize`` mixin
    and adds the ``validationTol`` param with its getter.
    """

    stepSize = Param(
        Params._dummy(),
        "stepSize",
        "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking "
        "the contribution of each estimator.",
        typeConverter=TypeConverters.toFloat)

    validationTol = Param(
        Params._dummy(),
        "validationTol",
        "Threshold for stopping early when fit with validation is used. "
        "If the error rate on the validation input changes by less than the "
        "validationTol, then learning will stop early (before `maxIter`). "
        "This parameter is ignored when fit without validation is used.",
        typeConverter=TypeConverters.toFloat)

    @since("3.0.0")
    def getValidationTol(self):
        """
        Gets the value of validationTol or its default value.
        """
        validation_tol_param = self.validationTol
        return self.getOrDefault(validation_tol_param)
|
||||||
|
|
||||||
|
|
||||||
|
class _HasVarianceImpurity(Params):
    """
    Private class to track supported impurity measures.

    Only ``"variance"`` is listed as a supported impurity here.
    """

    supportedImpurities = ["variance"]

    # The param description lists the supported options by joining
    # ``supportedImpurities`` when the class body is evaluated.
    impurity = Param(
        Params._dummy(),
        "impurity",
        "Criterion used for information gain calculation (case-insensitive). "
        "Supported options: " + ", ".join(supportedImpurities),
        typeConverter=TypeConverters.toString)

    def __init__(self):
        super(_HasVarianceImpurity, self).__init__()

    @since("1.4.0")
    def getImpurity(self):
        """
        Gets the value of impurity or its default value.
        """
        impurity_param = self.impurity
        return self.getOrDefault(impurity_param)
|
||||||
|
|
||||||
|
|
||||||
|
class _TreeClassifierParams(Params):
    """
    Private class to track supported impurity measures.

    Inherits from ``Params`` (like ``_HasVarianceImpurity``) because the
    class declares a ``Param`` and its getter calls ``getOrDefault``;
    with a plain ``object`` base that call only resolved when the class
    was mixed into another ``Params`` subclass.

    .. versionadded:: 1.4.0
    """

    supportedImpurities = ["entropy", "gini"]

    # The param description lists the supported options by joining
    # ``supportedImpurities`` when the class body is evaluated.
    impurity = Param(Params._dummy(), "impurity",
                     "Criterion used for information gain calculation (case-insensitive). " +
                     "Supported options: " +
                     ", ".join(supportedImpurities), typeConverter=TypeConverters.toString)

    def __init__(self):
        super(_TreeClassifierParams, self).__init__()

    @since("1.6.0")
    def getImpurity(self):
        """
        Gets the value of impurity or its default value.
        """
        return self.getOrDefault(self.impurity)
|
||||||
|
|
||||||
|
|
||||||
|
class _TreeRegressorParams(_HasVarianceImpurity):
    """
    Private class to track supported impurity measures.

    Adds nothing of its own; everything comes from ``_HasVarianceImpurity``.
    """
|
Loading…
Reference in a new issue