[SPARK-28985][PYTHON][ML][FOLLOW-UP] Add _AFTSurvivalRegressionParams
### What changes were proposed in this pull request? Adds ```python _AFTSurvivalRegressionParams(HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasTol, HasFitIntercept, HasAggregationDepth): ... ``` with related Params and uses it to replace `HasFitIntercept`, `HasMaxIter`, `HasTol` and `HasAggregationDepth` in `AFTSurvivalRegression` base classes and `JavaPredictionModel,` in `AFTSurvivalRegressionModel` base classes. ### Why are the changes needed? Previous work (#25776) on [SPARK-28985](https://issues.apache.org/jira/browse/SPARK-28985) replaced `JavaEstimator`, `HasFeaturesCol`, `HasLabelCol`, `HasPredictionCol` in `AFTSurvivalRegression` and `JavaModel` in `AFTSurvivalRegressionModel` with newly added `JavaPredictor`:e97b55d322/python/pyspark/ml/wrapper.py (L377)
and `JavaPredictionModel`e97b55d322/python/pyspark/ml/wrapper.py (L405)
respectively. This however is inconsistent with Scala counterpart where both classes extend private `AFTSurvivalRegressionBase`eb037a8180/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala (L48-L50)
This preserves some of the existing inconsistencies (variables as defined in [the official example](https://github.com/apache/spark/blob/master/examples/src/main/python/ml/aft_survival_regression.p)) ``` from pyspark.ml.regression import AFTSurvivalRegression, AFTSurvivalRegressionModel from pyspark.ml.param.shared import HasMaxIter, HasTol, HasFitIntercept, HasAggregationDepth from pyspark.ml.param import Param issubclass(AFTSurvivalRegressionModel, HasMaxIter) # False hasattr(model, "maxIter") and isinstance(model.maxIter, Param) # True issubclass(AFTSurvivalRegressionModel, HasTol) # False hasattr(model, "tol") and isinstance(model.tol, Param) # True ``` and can cause problems in the future, if Predictor / PredictionModel API changes (unlike [`IsotonicRegression`](https://github.com/apache/spark/pull/26023), current implementation is technically speaking correct, though incomplete). ### Does this PR introduce any user-facing change? Yes, it adds a number of base classes to `AFTSurvivalRegressionModel`. These change purely additive and have negligible potential for breaking existing code (and none, compared to changes already made in #25776). Additionally affected API hasn't been released in the current form yet. ### How was this patch tested? - Existing unit tests. - Manual testing. CC huaxingao, zhengruifeng Closes #26024 from zero323/SPARK-28985-FOLLOW-UP-aftsurival-regression. Authored-by: zero323 <mszymkiewicz@gmail.com> Signed-off-by: Sean Owen <sean.owen@databricks.com>
This commit is contained in:
parent
228b1ea96c
commit
df22535bbd
|
@ -1480,9 +1480,56 @@ class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
|
||||||
return self._call_java("evaluateEachIteration", dataset, loss)
|
return self._call_java("evaluateEachIteration", dataset, loss)
|
||||||
|
|
||||||
|
|
||||||
|
class _AFTSurvivalRegressionParams(HasFeaturesCol, HasLabelCol, HasPredictionCol,
|
||||||
|
HasMaxIter, HasTol, HasFitIntercept,
|
||||||
|
HasAggregationDepth):
|
||||||
|
"""
|
||||||
|
Params for :py:class:`AFTSurvivalRegression` and :py:class:`AFTSurvivalRegressionModel`.
|
||||||
|
|
||||||
|
.. versionadded:: 3.0.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
censorCol = Param(
|
||||||
|
Params._dummy(), "censorCol",
|
||||||
|
"censor column name. The value of this column could be 0 or 1. " +
|
||||||
|
"If the value is 1, it means the event has occurred i.e. " +
|
||||||
|
"uncensored; otherwise censored.", typeConverter=TypeConverters.toString)
|
||||||
|
quantileProbabilities = Param(
|
||||||
|
Params._dummy(), "quantileProbabilities",
|
||||||
|
"quantile probabilities array. Values of the quantile probabilities array " +
|
||||||
|
"should be in the range (0, 1) and the array should be non-empty.",
|
||||||
|
typeConverter=TypeConverters.toListFloat)
|
||||||
|
quantilesCol = Param(
|
||||||
|
Params._dummy(), "quantilesCol",
|
||||||
|
"quantiles column name. This column will output quantiles of " +
|
||||||
|
"corresponding quantileProbabilities if it is set.",
|
||||||
|
typeConverter=TypeConverters.toString)
|
||||||
|
|
||||||
|
@since("1.6.0")
|
||||||
|
def getCensorCol(self):
|
||||||
|
"""
|
||||||
|
Gets the value of censorCol or its default value.
|
||||||
|
"""
|
||||||
|
return self.getOrDefault(self.censorCol)
|
||||||
|
|
||||||
|
@since("1.6.0")
|
||||||
|
def getQuantileProbabilities(self):
|
||||||
|
"""
|
||||||
|
Gets the value of quantileProbabilities or its default value.
|
||||||
|
"""
|
||||||
|
return self.getOrDefault(self.quantileProbabilities)
|
||||||
|
|
||||||
|
@since("1.6.0")
|
||||||
|
def getQuantilesCol(self):
|
||||||
|
"""
|
||||||
|
Gets the value of quantilesCol or its default value.
|
||||||
|
"""
|
||||||
|
return self.getOrDefault(self.quantilesCol)
|
||||||
|
|
||||||
|
|
||||||
@inherit_doc
|
@inherit_doc
|
||||||
class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol,
|
class AFTSurvivalRegression(JavaEstimator, _AFTSurvivalRegressionParams,
|
||||||
HasAggregationDepth, JavaMLWritable, JavaMLReadable):
|
JavaMLWritable, JavaMLReadable):
|
||||||
"""
|
"""
|
||||||
Accelerated Failure Time (AFT) Model Survival Regression
|
Accelerated Failure Time (AFT) Model Survival Regression
|
||||||
|
|
||||||
|
@ -1529,20 +1576,6 @@ class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol,
|
||||||
.. versionadded:: 1.6.0
|
.. versionadded:: 1.6.0
|
||||||
"""
|
"""
|
||||||
|
|
||||||
censorCol = Param(Params._dummy(), "censorCol",
|
|
||||||
"censor column name. The value of this column could be 0 or 1. " +
|
|
||||||
"If the value is 1, it means the event has occurred i.e. " +
|
|
||||||
"uncensored; otherwise censored.", typeConverter=TypeConverters.toString)
|
|
||||||
quantileProbabilities = \
|
|
||||||
Param(Params._dummy(), "quantileProbabilities",
|
|
||||||
"quantile probabilities array. Values of the quantile probabilities array " +
|
|
||||||
"should be in the range (0, 1) and the array should be non-empty.",
|
|
||||||
typeConverter=TypeConverters.toListFloat)
|
|
||||||
quantilesCol = Param(Params._dummy(), "quantilesCol",
|
|
||||||
"quantiles column name. This column will output quantiles of " +
|
|
||||||
"corresponding quantileProbabilities if it is set.",
|
|
||||||
typeConverter=TypeConverters.toString)
|
|
||||||
|
|
||||||
@keyword_only
|
@keyword_only
|
||||||
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
|
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
|
||||||
fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
|
fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
|
||||||
|
@ -1588,13 +1621,6 @@ class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol,
|
||||||
"""
|
"""
|
||||||
return self._set(censorCol=value)
|
return self._set(censorCol=value)
|
||||||
|
|
||||||
@since("1.6.0")
|
|
||||||
def getCensorCol(self):
|
|
||||||
"""
|
|
||||||
Gets the value of censorCol or its default value.
|
|
||||||
"""
|
|
||||||
return self.getOrDefault(self.censorCol)
|
|
||||||
|
|
||||||
@since("1.6.0")
|
@since("1.6.0")
|
||||||
def setQuantileProbabilities(self, value):
|
def setQuantileProbabilities(self, value):
|
||||||
"""
|
"""
|
||||||
|
@ -1602,13 +1628,6 @@ class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol,
|
||||||
"""
|
"""
|
||||||
return self._set(quantileProbabilities=value)
|
return self._set(quantileProbabilities=value)
|
||||||
|
|
||||||
@since("1.6.0")
|
|
||||||
def getQuantileProbabilities(self):
|
|
||||||
"""
|
|
||||||
Gets the value of quantileProbabilities or its default value.
|
|
||||||
"""
|
|
||||||
return self.getOrDefault(self.quantileProbabilities)
|
|
||||||
|
|
||||||
@since("1.6.0")
|
@since("1.6.0")
|
||||||
def setQuantilesCol(self, value):
|
def setQuantilesCol(self, value):
|
||||||
"""
|
"""
|
||||||
|
@ -1616,15 +1635,9 @@ class AFTSurvivalRegression(JavaPredictor, HasFitIntercept, HasMaxIter, HasTol,
|
||||||
"""
|
"""
|
||||||
return self._set(quantilesCol=value)
|
return self._set(quantilesCol=value)
|
||||||
|
|
||||||
@since("1.6.0")
|
|
||||||
def getQuantilesCol(self):
|
|
||||||
"""
|
|
||||||
Gets the value of quantilesCol or its default value.
|
|
||||||
"""
|
|
||||||
return self.getOrDefault(self.quantilesCol)
|
|
||||||
|
|
||||||
|
class AFTSurvivalRegressionModel(JavaModel, _AFTSurvivalRegressionParams,
|
||||||
class AFTSurvivalRegressionModel(JavaPredictionModel, JavaMLWritable, JavaMLReadable):
|
JavaMLWritable, JavaMLReadable):
|
||||||
"""
|
"""
|
||||||
Model fitted by :class:`AFTSurvivalRegression`.
|
Model fitted by :class:`AFTSurvivalRegression`.
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue