[SPARK-28985][PYTHON][ML][FOLLOW-UP] Add _AFTSurvivalRegressionParams

### What changes were proposed in this pull request?

Adds

```python
_AFTSurvivalRegressionParams(HasFeaturesCol, HasLabelCol, HasPredictionCol,
                                   HasMaxIter, HasTol, HasFitIntercept,
                                   HasAggregationDepth): ...
```

with related Params and uses it to replace `HasFitIntercept`, `HasMaxIter`, `HasTol` and  `HasAggregationDepth` in `AFTSurvivalRegression` base classes and `JavaPredictionModel,` in `AFTSurvivalRegressionModel` base classes.

### Why are the changes needed?

Previous work (#25776) on [SPARK-28985](https://issues.apache.org/jira/browse/SPARK-28985) replaced `JavaEstimator`, `HasFeaturesCol`, `HasLabelCol`, `HasPredictionCol` in `AFTSurvivalRegression` and  `JavaModel` in `AFTSurvivalRegressionModel` with newly added `JavaPredictor`:

e97b55d322/python/pyspark/ml/wrapper.py (L377)

and `JavaPredictionModel`

e97b55d322/python/pyspark/ml/wrapper.py (L405)

respectively.

This however is inconsistent with Scala counterpart where both classes extend private `AFTSurvivalRegressionBase`

eb037a8180/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala (L48-L50)

This preserves some of the existing inconsistencies (variables as defined in [the official example](https://github.com/apache/spark/blob/master/examples/src/main/python/ml/aft_survival_regression.p))

```
from pyspark.ml.regression import AFTSurvivalRegression, AFTSurvivalRegressionModel
from pyspark.ml.param.shared import HasMaxIter, HasTol, HasFitIntercept, HasAggregationDepth
from pyspark.ml.param import Param

issubclass(AFTSurvivalRegressionModel, HasMaxIter)
# False
hasattr(model, "maxIter")  and isinstance(model.maxIter, Param)
# True

issubclass(AFTSurvivalRegressionModel, HasTol)
# False
hasattr(model, "tol")  and isinstance(model.tol, Param)
# True
```

and can cause problems in the future, if Predictor / PredictionModel API changes (unlike [`IsotonicRegression`](https://github.com/apache/spark/pull/26023), current implementation is technically speaking correct, though incomplete).

### Does this PR introduce any user-facing change?

Yes, it adds a number of base classes to `AFTSurvivalRegressionModel`. These change purely additive and have negligible potential for breaking existing code (and none, compared to changes already made in #25776). Additionally affected API hasn't been released in the current form yet.

### How was this patch tested?

- Existing unit tests.
- Manual testing.

CC huaxingao, zhengruifeng

Closes #26024 from zero323/SPARK-28985-FOLLOW-UP-aftsurival-regression.

Authored-by: zero323 <mszymkiewicz@gmail.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
This commit is contained in:
zero323 2019-10-04 18:04:21 -05:00 committed by Sean Owen
parent 228b1ea96c
commit df22535bbd

View file

@ -1480,9 +1480,56 @@ class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
return self._call_java("evaluateEachIteration", dataset, loss)
class _AFTSurvivalRegressionParams(HasFeaturesCol, HasLabelCol, HasPredictionCol,
HasMaxIter, HasTol, HasFitIntercept,
HasAggregationDepth):
"""
Params for :py:class:`AFTSurvivalRegression` and :py:class:`AFTSurvivalRegressionModel`.
.. versionadded:: 3.0.0
"""
censorCol = Param(
Params._dummy(), "censorCol",
"censor column name. The value of this column could be 0 or 1. " +
"If the value is 1, it means the event has occurred i.e. " +
"uncensored; otherwise censored.", typeConverter=TypeConverters.toString)
quantileProbabilities = Param(
Params._dummy(), "quantileProbabilities",
"quantile probabilities array. Values of the quantile probabilities array " +
"should be in the range (0, 1) and the array should be non-empty.",
typeConverter=TypeConverters.toListFloat)
quantilesCol = Param(
Params._dummy(), "quantilesCol",
"quantiles column name. This column will output quantiles of " +
"corresponding quantileProbabilities if it is set.",
typeConverter=TypeConverters.toString)
@since("1.6.0")
def getCensorCol(self):
"""
Gets the value of censorCol or its default value.
"""
return self.getOrDefault(self.censorCol)
@since("1.6.0")
def getQuantileProbabilities(self):
"""
Gets the value of quantileProbabilities or its default value.
"""
return self.getOrDefault(self.quantileProbabilities)
@since("1.6.0")
def getQuantilesCol(self):
"""
Gets the value of quantilesCol or its default value.
"""
return self.getOrDefault(self.quantilesCol)
@inherit_doc
class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol,
HasAggregationDepth, JavaMLWritable, JavaMLReadable):
class AFTSurvivalRegression(JavaEstimator, _AFTSurvivalRegressionParams,
JavaMLWritable, JavaMLReadable):
"""
Accelerated Failure Time (AFT) Model Survival Regression
@ -1529,20 +1576,6 @@ class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol,
.. versionadded:: 1.6.0
"""
censorCol = Param(Params._dummy(), "censorCol",
"censor column name. The value of this column could be 0 or 1. " +
"If the value is 1, it means the event has occurred i.e. " +
"uncensored; otherwise censored.", typeConverter=TypeConverters.toString)
quantileProbabilities = \
Param(Params._dummy(), "quantileProbabilities",
"quantile probabilities array. Values of the quantile probabilities array " +
"should be in the range (0, 1) and the array should be non-empty.",
typeConverter=TypeConverters.toListFloat)
quantilesCol = Param(Params._dummy(), "quantilesCol",
"quantiles column name. This column will output quantiles of " +
"corresponding quantileProbabilities if it is set.",
typeConverter=TypeConverters.toString)
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
@ -1588,13 +1621,6 @@ class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol,
"""
return self._set(censorCol=value)
@since("1.6.0")
def getCensorCol(self):
"""
Gets the value of censorCol or its default value.
"""
return self.getOrDefault(self.censorCol)
@since("1.6.0")
def setQuantileProbabilities(self, value):
"""
@ -1602,13 +1628,6 @@ class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol,
"""
return self._set(quantileProbabilities=value)
@since("1.6.0")
def getQuantileProbabilities(self):
"""
Gets the value of quantileProbabilities or its default value.
"""
return self.getOrDefault(self.quantileProbabilities)
@since("1.6.0")
def setQuantilesCol(self, value):
"""
@ -1616,15 +1635,9 @@ class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol,
"""
return self._set(quantilesCol=value)
@since("1.6.0")
def getQuantilesCol(self):
"""
Gets the value of quantilesCol or its default value.
"""
return self.getOrDefault(self.quantilesCol)
class AFTSurvivalRegressionModel(JavaPredictionModel, JavaMLWritable, JavaMLReadable):
class AFTSurvivalRegressionModel(JavaModel, _AFTSurvivalRegressionParams,
JavaMLWritable, JavaMLReadable):
"""
Model fitted by :class:`AFTSurvivalRegression`.