[SPARK-17197][ML][PYSPARK] PySpark LiR/LoR supports tree aggregation level configurable.
## What changes were proposed in this pull request? [SPARK-17090](https://issues.apache.org/jira/browse/SPARK-17090) makes tree aggregation level in LiR/LoR configurable, this PR makes PySpark support this function. ## How was this patch tested? Since ```aggregationDepth``` is an expert param, I'm not prefer to test it in doctest which is also used for example. Here is the offline test result: ![image](https://cloud.githubusercontent.com/assets/1962026/17879457/f83d7760-68a6-11e6-9936-d0a884d5d6ec.png) Author: Yanbo Liang <ybliang8@gmail.com> Closes #14766 from yanboliang/spark-17197.
This commit is contained in:
parent
e0b20f9f24
commit
6b8cb1fe52
|
@ -64,7 +64,7 @@ class JavaClassificationModel(JavaPredictionModel):
|
|||
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
|
||||
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
|
||||
HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds,
|
||||
HasWeightCol, JavaMLWritable, JavaMLReadable):
|
||||
HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable):
|
||||
"""
|
||||
Logistic regression.
|
||||
Currently, this class only supports binary classification.
|
||||
|
@ -121,12 +121,14 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
|
|||
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
|
||||
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
|
||||
threshold=0.5, thresholds=None, probabilityCol="probability",
|
||||
rawPredictionCol="rawPrediction", standardization=True, weightCol=None):
|
||||
rawPredictionCol="rawPrediction", standardization=True, weightCol=None,
|
||||
aggregationDepth=2):
|
||||
"""
|
||||
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
||||
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
|
||||
threshold=0.5, thresholds=None, probabilityCol="probability", \
|
||||
rawPredictionCol="rawPrediction", standardization=True, weightCol=None)
|
||||
rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \
|
||||
aggregationDepth=2)
|
||||
If the threshold and thresholds Params are both set, they must be equivalent.
|
||||
"""
|
||||
super(LogisticRegression, self).__init__()
|
||||
|
@ -142,12 +144,14 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
|
|||
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
|
||||
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
|
||||
threshold=0.5, thresholds=None, probabilityCol="probability",
|
||||
rawPredictionCol="rawPrediction", standardization=True, weightCol=None):
|
||||
rawPredictionCol="rawPrediction", standardization=True, weightCol=None,
|
||||
aggregationDepth=2):
|
||||
"""
|
||||
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
||||
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
|
||||
threshold=0.5, thresholds=None, probabilityCol="probability", \
|
||||
rawPredictionCol="rawPrediction", standardization=True, weightCol=None)
|
||||
rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \
|
||||
aggregationDepth=2)
|
||||
Sets params for logistic regression.
|
||||
If the threshold and thresholds Params are both set, they must be equivalent.
|
||||
"""
|
||||
|
|
|
@ -147,7 +147,9 @@ if __name__ == "__main__":
|
|||
("solver", "the solver algorithm for optimization. If this is not set or empty, " +
|
||||
"default value is 'auto'.", "'auto'", "TypeConverters.toString"),
|
||||
("varianceCol", "column name for the biased sample variance of prediction.",
|
||||
None, "TypeConverters.toString")]
|
||||
None, "TypeConverters.toString"),
|
||||
("aggregationDepth", "suggested depth for treeAggregate (>= 2).", "2",
|
||||
"TypeConverters.toInt")]
|
||||
|
||||
code = []
|
||||
for name, doc, defaultValueStr, typeConverter in shared:
|
||||
|
|
|
@ -560,6 +560,30 @@ class HasVarianceCol(Params):
|
|||
return self.getOrDefault(self.varianceCol)
|
||||
|
||||
|
||||
class HasAggregationDepth(Params):
|
||||
"""
|
||||
Mixin for param aggregationDepth: suggested depth for treeAggregate (>= 2).
|
||||
"""
|
||||
|
||||
aggregationDepth = Param(Params._dummy(), "aggregationDepth", "suggested depth for treeAggregate (>= 2).", typeConverter=TypeConverters.toInt)
|
||||
|
||||
def __init__(self):
|
||||
super(HasAggregationDepth, self).__init__()
|
||||
self._setDefault(aggregationDepth=2)
|
||||
|
||||
def setAggregationDepth(self, value):
|
||||
"""
|
||||
Sets the value of :py:attr:`aggregationDepth`.
|
||||
"""
|
||||
return self._set(aggregationDepth=value)
|
||||
|
||||
def getAggregationDepth(self):
|
||||
"""
|
||||
Gets the value of aggregationDepth or its default value.
|
||||
"""
|
||||
return self.getOrDefault(self.aggregationDepth)
|
||||
|
||||
|
||||
class DecisionTreeParams(Params):
|
||||
"""
|
||||
Mixin for Decision Tree parameters.
|
||||
|
|
|
@ -39,7 +39,8 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
|
|||
@inherit_doc
|
||||
class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
|
||||
HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
|
||||
HasStandardization, HasSolver, HasWeightCol, JavaMLWritable, JavaMLReadable):
|
||||
HasStandardization, HasSolver, HasWeightCol, HasAggregationDepth,
|
||||
JavaMLWritable, JavaMLReadable):
|
||||
"""
|
||||
Linear regression.
|
||||
|
||||
|
@ -97,11 +98,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
|
|||
@keyword_only
|
||||
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
|
||||
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
|
||||
standardization=True, solver="auto", weightCol=None):
|
||||
standardization=True, solver="auto", weightCol=None, aggregationDepth=2):
|
||||
"""
|
||||
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
||||
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
|
||||
standardization=True, solver="auto", weightCol=None)
|
||||
standardization=True, solver="auto", weightCol=None, aggregationDepth=2)
|
||||
"""
|
||||
super(LinearRegression, self).__init__()
|
||||
self._java_obj = self._new_java_obj(
|
||||
|
@ -114,11 +115,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
|
|||
@since("1.4.0")
|
||||
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
|
||||
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
|
||||
standardization=True, solver="auto", weightCol=None):
|
||||
standardization=True, solver="auto", weightCol=None, aggregationDepth=2):
|
||||
"""
|
||||
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
|
||||
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
|
||||
standardization=True, solver="auto", weightCol=None)
|
||||
standardization=True, solver="auto", weightCol=None, aggregationDepth=2)
|
||||
Sets params for linear regression.
|
||||
"""
|
||||
kwargs = self.setParams._input_kwargs
|
||||
|
|
Loading…
Reference in a new issue