[SPARK-28985][PYTHON][ML][FOLLOW-UP] Add _AFTSurvivalRegressionParams
### What changes were proposed in this pull request? Adds ```python _AFTSurvivalRegressionParams(HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasTol, HasFitIntercept, HasAggregationDepth): ... ``` with related Params and uses it to replace `HasFitIntercept`, `HasMaxIter`, `HasTol` and `HasAggregationDepth` in `AFTSurvivalRegression` base classes and `JavaPredictionModel,` in `AFTSurvivalRegressionModel` base classes. ### Why are the changes needed? Previous work (#25776) on [SPARK-28985](https://issues.apache.org/jira/browse/SPARK-28985) replaced `JavaEstimator`, `HasFeaturesCol`, `HasLabelCol`, `HasPredictionCol` in `AFTSurvivalRegression` and `JavaModel` in `AFTSurvivalRegressionModel` with newly added `JavaPredictor`:e97b55d322/python/pyspark/ml/wrapper.py (L377)
and `JavaPredictionModel`e97b55d322/python/pyspark/ml/wrapper.py (L405)
respectively. This however is inconsistent with Scala counterpart where both classes extend private `AFTSurvivalRegressionBase`eb037a8180/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala (L48-L50)
This preserves some of the existing inconsistencies (variables as defined in [the official example](https://github.com/apache/spark/blob/master/examples/src/main/python/ml/aft_survival_regression.p)) ``` from pyspark.ml.regression import AFTSurvivalRegression, AFTSurvivalRegressionModel from pyspark.ml.param.shared import HasMaxIter, HasTol, HasFitIntercept, HasAggregationDepth from pyspark.ml.param import Param issubclass(AFTSurvivalRegressionModel, HasMaxIter) # False hasattr(model, "maxIter") and isinstance(model.maxIter, Param) # True issubclass(AFTSurvivalRegressionModel, HasTol) # False hasattr(model, "tol") and isinstance(model.tol, Param) # True ``` and can cause problems in the future, if Predictor / PredictionModel API changes (unlike [`IsotonicRegression`](https://github.com/apache/spark/pull/26023), current implementation is technically speaking correct, though incomplete). ### Does this PR introduce any user-facing change? Yes, it adds a number of base classes to `AFTSurvivalRegressionModel`. These change purely additive and have negligible potential for breaking existing code (and none, compared to changes already made in #25776). Additionally affected API hasn't been released in the current form yet. ### How was this patch tested? - Existing unit tests. - Manual testing. CC huaxingao, zhengruifeng Closes #26024 from zero323/SPARK-28985-FOLLOW-UP-aftsurival-regression. Authored-by: zero323 <mszymkiewicz@gmail.com> Signed-off-by: Sean Owen <sean.owen@databricks.com>
This commit is contained in:
parent
228b1ea96c
commit
df22535bbd
|
@ -1480,9 +1480,56 @@ class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
|
|||
return self._call_java("evaluateEachIteration", dataset, loss)
|
||||
|
||||
|
||||
class _AFTSurvivalRegressionParams(HasFeaturesCol, HasLabelCol, HasPredictionCol,
|
||||
HasMaxIter, HasTol, HasFitIntercept,
|
||||
HasAggregationDepth):
|
||||
"""
|
||||
Params for :py:class:`AFTSurvivalRegression` and :py:class:`AFTSurvivalRegressionModel`.
|
||||
|
||||
.. versionadded:: 3.0.0
|
||||
"""
|
||||
|
||||
censorCol = Param(
|
||||
Params._dummy(), "censorCol",
|
||||
"censor column name. The value of this column could be 0 or 1. " +
|
||||
"If the value is 1, it means the event has occurred i.e. " +
|
||||
"uncensored; otherwise censored.", typeConverter=TypeConverters.toString)
|
||||
quantileProbabilities = Param(
|
||||
Params._dummy(), "quantileProbabilities",
|
||||
"quantile probabilities array. Values of the quantile probabilities array " +
|
||||
"should be in the range (0, 1) and the array should be non-empty.",
|
||||
typeConverter=TypeConverters.toListFloat)
|
||||
quantilesCol = Param(
|
||||
Params._dummy(), "quantilesCol",
|
||||
"quantiles column name. This column will output quantiles of " +
|
||||
"corresponding quantileProbabilities if it is set.",
|
||||
typeConverter=TypeConverters.toString)
|
||||
|
||||
@since("1.6.0")
|
||||
def getCensorCol(self):
|
||||
"""
|
||||
Gets the value of censorCol or its default value.
|
||||
"""
|
||||
return self.getOrDefault(self.censorCol)
|
||||
|
||||
@since("1.6.0")
|
||||
def getQuantileProbabilities(self):
|
||||
"""
|
||||
Gets the value of quantileProbabilities or its default value.
|
||||
"""
|
||||
return self.getOrDefault(self.quantileProbabilities)
|
||||
|
||||
@since("1.6.0")
|
||||
def getQuantilesCol(self):
|
||||
"""
|
||||
Gets the value of quantilesCol or its default value.
|
||||
"""
|
||||
return self.getOrDefault(self.quantilesCol)
|
||||
|
||||
|
||||
@inherit_doc
|
||||
class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol,
|
||||
HasAggregationDepth, JavaMLWritable, JavaMLReadable):
|
||||
class AFTSurvivalRegression(JavaEstimator, _AFTSurvivalRegressionParams,
|
||||
JavaMLWritable, JavaMLReadable):
|
||||
"""
|
||||
Accelerated Failure Time (AFT) Model Survival Regression
|
||||
|
||||
|
@ -1529,20 +1576,6 @@ class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol,
|
|||
.. versionadded:: 1.6.0
|
||||
"""
|
||||
|
||||
censorCol = Param(Params._dummy(), "censorCol",
|
||||
"censor column name. The value of this column could be 0 or 1. " +
|
||||
"If the value is 1, it means the event has occurred i.e. " +
|
||||
"uncensored; otherwise censored.", typeConverter=TypeConverters.toString)
|
||||
quantileProbabilities = \
|
||||
Param(Params._dummy(), "quantileProbabilities",
|
||||
"quantile probabilities array. Values of the quantile probabilities array " +
|
||||
"should be in the range (0, 1) and the array should be non-empty.",
|
||||
typeConverter=TypeConverters.toListFloat)
|
||||
quantilesCol = Param(Params._dummy(), "quantilesCol",
|
||||
"quantiles column name. This column will output quantiles of " +
|
||||
"corresponding quantileProbabilities if it is set.",
|
||||
typeConverter=TypeConverters.toString)
|
||||
|
||||
@keyword_only
|
||||
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
|
||||
fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
|
||||
|
@ -1588,13 +1621,6 @@ class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol,
|
|||
"""
|
||||
return self._set(censorCol=value)
|
||||
|
||||
@since("1.6.0")
|
||||
def getCensorCol(self):
|
||||
"""
|
||||
Gets the value of censorCol or its default value.
|
||||
"""
|
||||
return self.getOrDefault(self.censorCol)
|
||||
|
||||
@since("1.6.0")
|
||||
def setQuantileProbabilities(self, value):
|
||||
"""
|
||||
|
@ -1602,13 +1628,6 @@ class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol,
|
|||
"""
|
||||
return self._set(quantileProbabilities=value)
|
||||
|
||||
@since("1.6.0")
|
||||
def getQuantileProbabilities(self):
|
||||
"""
|
||||
Gets the value of quantileProbabilities or its default value.
|
||||
"""
|
||||
return self.getOrDefault(self.quantileProbabilities)
|
||||
|
||||
@since("1.6.0")
|
||||
def setQuantilesCol(self, value):
|
||||
"""
|
||||
|
@ -1616,15 +1635,9 @@ class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol,
|
|||
"""
|
||||
return self._set(quantilesCol=value)
|
||||
|
||||
@since("1.6.0")
|
||||
def getQuantilesCol(self):
|
||||
"""
|
||||
Gets the value of quantilesCol or its default value.
|
||||
"""
|
||||
return self.getOrDefault(self.quantilesCol)
|
||||
|
||||
|
||||
class AFTSurvivalRegressionModel(JavaPredictionModel, JavaMLWritable, JavaMLReadable):
|
||||
class AFTSurvivalRegressionModel(JavaModel, _AFTSurvivalRegressionParams,
|
||||
JavaMLWritable, JavaMLReadable):
|
||||
"""
|
||||
Model fitted by :class:`AFTSurvivalRegression`.
|
||||
|
||||
|
|
Loading…
Reference in a new issue