From df22535bbd1745235ffb712012183d4c7613c280 Mon Sep 17 00:00:00 2001 From: zero323 Date: Fri, 4 Oct 2019 18:04:21 -0500 Subject: [PATCH] [SPARK-28985][PYTHON][ML][FOLLOW-UP] Add _AFTSurvivalRegressionParams ### What changes were proposed in this pull request? Adds ```python _AFTSurvivalRegressionParams(HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasTol, HasFitIntercept, HasAggregationDepth): ... ``` with related Params and uses it to replace `HasFitIntercept`, `HasMaxIter`, `HasTol` and `HasAggregationDepth` in `AFTSurvivalRegression` base classes and `JavaPredictionModel` in `AFTSurvivalRegressionModel` base classes. ### Why are the changes needed? Previous work (#25776) on [SPARK-28985](https://issues.apache.org/jira/browse/SPARK-28985) replaced `JavaEstimator`, `HasFeaturesCol`, `HasLabelCol`, `HasPredictionCol` in `AFTSurvivalRegression` and `JavaModel` in `AFTSurvivalRegressionModel` with the newly added `JavaPredictor`: https://github.com/apache/spark/blob/e97b55d32285052a1f76cca35377c4b21eb2e7d7/python/pyspark/ml/wrapper.py#L377 and `JavaPredictionModel` https://github.com/apache/spark/blob/e97b55d32285052a1f76cca35377c4b21eb2e7d7/python/pyspark/ml/wrapper.py#L405 respectively. 
This, however, is inconsistent with the Scala counterpart, where both classes extend the private `AFTSurvivalRegressionBase` https://github.com/apache/spark/blob/eb037a8180be4ab7570eda1fa9cbf3c84b92c3f7/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala#L48-L50 This preserves some of the existing inconsistencies (variables as defined in [the official example](https://github.com/apache/spark/blob/master/examples/src/main/python/ml/aft_survival_regression.py)) ``` from pyspark.ml.regression import AFTSurvivalRegression, AFTSurvivalRegressionModel from pyspark.ml.param.shared import HasMaxIter, HasTol, HasFitIntercept, HasAggregationDepth from pyspark.ml.param import Param issubclass(AFTSurvivalRegressionModel, HasMaxIter) # False hasattr(model, "maxIter") and isinstance(model.maxIter, Param) # True issubclass(AFTSurvivalRegressionModel, HasTol) # False hasattr(model, "tol") and isinstance(model.tol, Param) # True ``` and can cause problems in the future if the Predictor / PredictionModel API changes (unlike [`IsotonicRegression`](https://github.com/apache/spark/pull/26023), the current implementation is, technically speaking, correct, though incomplete). ### Does this PR introduce any user-facing change? Yes, it adds a number of base classes to `AFTSurvivalRegressionModel`. These changes are purely additive and have negligible potential for breaking existing code (and none, compared to changes already made in #25776). Additionally, the affected API hasn't been released in its current form yet. ### How was this patch tested? - Existing unit tests. - Manual testing. CC huaxingao, zhengruifeng Closes #26024 from zero323/SPARK-28985-FOLLOW-UP-aftsurival-regression. 
Authored-by: zero323 Signed-off-by: Sean Owen --- python/pyspark/ml/regression.py | 89 +++++++++++++++++++-------------- 1 file changed, 51 insertions(+), 38 deletions(-) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index f2bcc66203..1a7d39ba89 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -1480,9 +1480,56 @@ class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, return self._call_java("evaluateEachIteration", dataset, loss) +class _AFTSurvivalRegressionParams(HasFeaturesCol, HasLabelCol, HasPredictionCol, + HasMaxIter, HasTol, HasFitIntercept, + HasAggregationDepth): + """ + Params for :py:class:`AFTSurvivalRegression` and :py:class:`AFTSurvivalRegressionModel`. + + .. versionadded:: 3.0.0 + """ + + censorCol = Param( + Params._dummy(), "censorCol", + "censor column name. The value of this column could be 0 or 1. " + + "If the value is 1, it means the event has occurred i.e. " + + "uncensored; otherwise censored.", typeConverter=TypeConverters.toString) + quantileProbabilities = Param( + Params._dummy(), "quantileProbabilities", + "quantile probabilities array. Values of the quantile probabilities array " + + "should be in the range (0, 1) and the array should be non-empty.", + typeConverter=TypeConverters.toListFloat) + quantilesCol = Param( + Params._dummy(), "quantilesCol", + "quantiles column name. This column will output quantiles of " + + "corresponding quantileProbabilities if it is set.", + typeConverter=TypeConverters.toString) + + @since("1.6.0") + def getCensorCol(self): + """ + Gets the value of censorCol or its default value. + """ + return self.getOrDefault(self.censorCol) + + @since("1.6.0") + def getQuantileProbabilities(self): + """ + Gets the value of quantileProbabilities or its default value. 
+ """ + return self.getOrDefault(self.quantileProbabilities) + + @since("1.6.0") + def getQuantilesCol(self): + """ + Gets the value of quantilesCol or its default value. + """ + return self.getOrDefault(self.quantilesCol) + + @inherit_doc -class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol, - HasAggregationDepth, JavaMLWritable, JavaMLReadable): +class AFTSurvivalRegression(JavaEstimator, _AFTSurvivalRegressionParams, + JavaMLWritable, JavaMLReadable): """ Accelerated Failure Time (AFT) Model Survival Regression @@ -1529,20 +1576,6 @@ class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol, .. versionadded:: 1.6.0 """ - censorCol = Param(Params._dummy(), "censorCol", - "censor column name. The value of this column could be 0 or 1. " + - "If the value is 1, it means the event has occurred i.e. " + - "uncensored; otherwise censored.", typeConverter=TypeConverters.toString) - quantileProbabilities = \ - Param(Params._dummy(), "quantileProbabilities", - "quantile probabilities array. Values of the quantile probabilities array " + - "should be in the range (0, 1) and the array should be non-empty.", - typeConverter=TypeConverters.toListFloat) - quantilesCol = Param(Params._dummy(), "quantilesCol", - "quantiles column name. This column will output quantiles of " + - "corresponding quantileProbabilities if it is set.", - typeConverter=TypeConverters.toString) - @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", @@ -1588,13 +1621,6 @@ class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol, """ return self._set(censorCol=value) - @since("1.6.0") - def getCensorCol(self): - """ - Gets the value of censorCol or its default value. 
- """ - return self.getOrDefault(self.censorCol) - @since("1.6.0") def setQuantileProbabilities(self, value): """ @@ -1602,13 +1628,6 @@ class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol, """ return self._set(quantileProbabilities=value) - @since("1.6.0") - def getQuantileProbabilities(self): - """ - Gets the value of quantileProbabilities or its default value. - """ - return self.getOrDefault(self.quantileProbabilities) - @since("1.6.0") def setQuantilesCol(self, value): """ @@ -1616,15 +1635,9 @@ class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol, """ return self._set(quantilesCol=value) - @since("1.6.0") - def getQuantilesCol(self): - """ - Gets the value of quantilesCol or its default value. - """ - return self.getOrDefault(self.quantilesCol) - -class AFTSurvivalRegressionModel(JavaPredictionModel, JavaMLWritable, JavaMLReadable): +class AFTSurvivalRegressionModel(JavaModel, _AFTSurvivalRegressionParams, + JavaMLWritable, JavaMLReadable): """ Model fitted by :class:`AFTSurvivalRegression`.