[SPARK-17197][ML][PYSPARK] PySpark LiR/LoR supports tree aggregation level configurable.

## What changes were proposed in this pull request?
[SPARK-17090](https://issues.apache.org/jira/browse/SPARK-17090) makes tree aggregation level in LiR/LoR configurable, this PR makes PySpark support this function.

## How was this patch tested?
Since ```aggregationDepth``` is an expert param, I'm not prefer to test it in doctest which is also used for example. Here is the offline test result:
![image](https://cloud.githubusercontent.com/assets/1962026/17879457/f83d7760-68a6-11e6-9936-d0a884d5d6ec.png)

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #14766 from yanboliang/spark-17197.
This commit is contained in:
Yanbo Liang 2016-08-25 02:26:33 -07:00
parent e0b20f9f24
commit 6b8cb1fe52
4 changed files with 42 additions and 11 deletions

View file

@ -64,7 +64,7 @@ class JavaClassificationModel(JavaPredictionModel):
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds,
HasWeightCol, JavaMLWritable, JavaMLReadable):
HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable):
"""
Logistic regression.
Currently, this class only supports binary classification.
@ -121,12 +121,14 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
threshold=0.5, thresholds=None, probabilityCol="probability",
rawPredictionCol="rawPrediction", standardization=True, weightCol=None):
rawPredictionCol="rawPrediction", standardization=True, weightCol=None,
aggregationDepth=2):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
threshold=0.5, thresholds=None, probabilityCol="probability", \
rawPredictionCol="rawPrediction", standardization=True, weightCol=None)
rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \
aggregationDepth=2)
If the threshold and thresholds Params are both set, they must be equivalent.
"""
super(LogisticRegression, self).__init__()
@ -142,12 +144,14 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
threshold=0.5, thresholds=None, probabilityCol="probability",
rawPredictionCol="rawPrediction", standardization=True, weightCol=None):
rawPredictionCol="rawPrediction", standardization=True, weightCol=None,
aggregationDepth=2):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
threshold=0.5, thresholds=None, probabilityCol="probability", \
rawPredictionCol="rawPrediction", standardization=True, weightCol=None)
rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \
aggregationDepth=2)
Sets params for logistic regression.
If the threshold and thresholds Params are both set, they must be equivalent.
"""

View file

@ -147,7 +147,9 @@ if __name__ == "__main__":
("solver", "the solver algorithm for optimization. If this is not set or empty, " +
"default value is 'auto'.", "'auto'", "TypeConverters.toString"),
("varianceCol", "column name for the biased sample variance of prediction.",
None, "TypeConverters.toString")]
None, "TypeConverters.toString"),
("aggregationDepth", "suggested depth for treeAggregate (>= 2).", "2",
"TypeConverters.toInt")]
code = []
for name, doc, defaultValueStr, typeConverter in shared:

View file

@ -560,6 +560,30 @@ class HasVarianceCol(Params):
return self.getOrDefault(self.varianceCol)
class HasAggregationDepth(Params):
"""
Mixin for param aggregationDepth: suggested depth for treeAggregate (>= 2).
"""
aggregationDepth = Param(Params._dummy(), "aggregationDepth", "suggested depth for treeAggregate (>= 2).", typeConverter=TypeConverters.toInt)
def __init__(self):
super(HasAggregationDepth, self).__init__()
self._setDefault(aggregationDepth=2)
def setAggregationDepth(self, value):
"""
Sets the value of :py:attr:`aggregationDepth`.
"""
return self._set(aggregationDepth=value)
def getAggregationDepth(self):
"""
Gets the value of aggregationDepth or its default value.
"""
return self.getOrDefault(self.aggregationDepth)
class DecisionTreeParams(Params):
"""
Mixin for Decision Tree parameters.

View file

@ -39,7 +39,8 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
@inherit_doc
class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
HasStandardization, HasSolver, HasWeightCol, JavaMLWritable, JavaMLReadable):
HasStandardization, HasSolver, HasWeightCol, HasAggregationDepth,
JavaMLWritable, JavaMLReadable):
"""
Linear regression.
@ -97,11 +98,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
standardization=True, solver="auto", weightCol=None):
standardization=True, solver="auto", weightCol=None, aggregationDepth=2):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
standardization=True, solver="auto", weightCol=None)
standardization=True, solver="auto", weightCol=None, aggregationDepth=2)
"""
super(LinearRegression, self).__init__()
self._java_obj = self._new_java_obj(
@ -114,11 +115,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
@since("1.4.0")
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
standardization=True, solver="auto", weightCol=None):
standardization=True, solver="auto", weightCol=None, aggregationDepth=2):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
standardization=True, solver="auto", weightCol=None)
standardization=True, solver="auto", weightCol=None, aggregationDepth=2)
Sets params for linear regression.
"""
kwargs = self.setParams._input_kwargs