Yanbo Liang f92b7b98e9 [SPARK-11367][ML][PYSPARK] Python LinearRegression should support setting solver
[SPARK-10668](https://issues.apache.org/jira/browse/SPARK-10668) has provided the ```WeightedLeastSquares``` solver ("normal") in ```LinearRegression``` with L2 regularization in Scala and R; Python ML ```LinearRegression``` should also support setting the solver ("auto", "normal", "l-bfgs").

Author: Yanbo Liang <>

Closes #9328 from yanboliang/spark-11367.
2015-10-28 08:54:20 -07:00

843 lines
32 KiB

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pyspark import since
from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *
from pyspark.mllib.common import inherit_doc
__all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
'DecisionTreeRegressor', 'DecisionTreeRegressionModel',
'GBTRegressor', 'GBTRegressionModel',
'IsotonicRegression', 'IsotonicRegressionModel',
'LinearRegression', 'LinearRegressionModel',
'RandomForestRegressor', 'RandomForestRegressionModel']
class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
                       HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
                       HasStandardization, HasSolver):
    """
    Linear regression.

    The learning objective is to minimize the squared error, with regularization.
    The specific squared error loss function used is: L = 1/2n ||A weights - y||^2^

    This supports multiple types of regularization:
     - none (a.k.a. ordinary least squares)
     - L2 (ridge regression)
     - L1 (Lasso)
     - L2 + L1 (elastic net)

    >>> from pyspark.mllib.linalg import Vectors
    >>> df = sqlContext.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal")
    >>> model = lr.fit(df)
    >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> model.transform(test0).head().prediction
    -1.0
    >>> model.weights
    DenseVector([1.0])
    >>> model.intercept
    0.0
    >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> model.transform(test1).head().prediction
    1.0
    >>> lr.setParams("vector")
    Traceback (most recent call last):
        ...
    TypeError: Method setParams forces keyword arguments.

    .. versionadded:: 1.4.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
                 standardization=True, solver="auto"):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
                 standardization=True, solver="auto")
        """
        super(LinearRegression, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.LinearRegression", self.uid)
        self._setDefault(maxIter=100, regParam=0.0, tol=1e-6)
        # Forward the keyword arguments captured by @keyword_only so that
        # values passed to the constructor actually take effect.
        kwargs = self.__init__._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
                  standardization=True, solver="auto"):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
                  standardization=True, solver="auto")
        Sets params for linear regression.
        """
        kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        """
        Wraps the given Java model in a :py:class:`LinearRegressionModel`.
        """
        return LinearRegressionModel(java_model)
class LinearRegressionModel(JavaModel):
    """
    Model fitted by LinearRegression.

    .. versionadded:: 1.4.0
    """

    # Exposed as properties: callers read ``model.weights`` and
    # ``model.intercept`` as plain attributes.
    @property
    def weights(self):
        """
        Model weights.
        """
        return self._call_java("weights")

    @property
    def intercept(self):
        """
        Model intercept.
        """
        return self._call_java("intercept")
class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                         HasWeightCol):
    """
    .. note:: Experimental

    Currently implemented using parallelized pool adjacent violators algorithm.
    Only univariate (single feature) algorithm supported.

    >>> from pyspark.mllib.linalg import Vectors
    >>> df = sqlContext.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> ir = IsotonicRegression()
    >>> model = ir.fit(df)
    >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> model.transform(test0).head().prediction
    0.0
    >>> model.boundaries
    DenseVector([0.0, 1.0])
    """

    # a placeholder to make it appear in the generated doc
    isotonic = \
        Param(Params._dummy(), "isotonic",
              "whether the output sequence should be isotonic/increasing (true) or " +
              "antitonic/decreasing (false).")
    featureIndex = \
        Param(Params._dummy(), "featureIndex",
              "The index of the feature if featuresCol is a vector column, no effect otherwise.")

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 weightCol=None, isotonic=True, featureIndex=0):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 weightCol=None, isotonic=True, featureIndex=0):
        """
        super(IsotonicRegression, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.IsotonicRegression", self.uid)
        self.isotonic = \
            Param(self, "isotonic",
                  "whether the output sequence should be isotonic/increasing (true) or " +
                  "antitonic/decreasing (false).")
        self.featureIndex = \
            Param(self, "featureIndex",
                  "The index of the feature if featuresCol is a vector column, no effect " +
                  "otherwise.")
        self._setDefault(isotonic=True, featureIndex=0)
        # Forward the keyword arguments captured by @keyword_only so that
        # values passed to the constructor actually take effect.
        kwargs = self.__init__._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  weightCol=None, isotonic=True, featureIndex=0):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 weightCol=None, isotonic=True, featureIndex=0):
        Set the params for IsotonicRegression.
        """
        kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        """
        Wraps the given Java model in an :py:class:`IsotonicRegressionModel`.
        """
        return IsotonicRegressionModel(java_model)

    def setIsotonic(self, value):
        """
        Sets the value of :py:attr:`isotonic`.
        """
        self._paramMap[self.isotonic] = value
        return self

    def getIsotonic(self):
        """
        Gets the value of isotonic or its default value.
        """
        return self.getOrDefault(self.isotonic)

    def setFeatureIndex(self, value):
        """
        Sets the value of :py:attr:`featureIndex`.
        """
        self._paramMap[self.featureIndex] = value
        return self

    def getFeatureIndex(self):
        """
        Gets the value of featureIndex or its default value.
        """
        return self.getOrDefault(self.featureIndex)
class IsotonicRegressionModel(JavaModel):
    """
    .. note:: Experimental

    Model fitted by IsotonicRegression.
    """

    # ``boundaries`` is read as a plain attribute by callers, so both
    # accessors are exposed as properties.
    @property
    def boundaries(self):
        """
        Model boundaries.
        """
        return self._call_java("boundaries")

    @property
    def predictions(self):
        """
        Predictions associated with the boundaries at the same index, monotone because of isotonic
        regression.
        """
        return self._call_java("predictions")
class TreeEnsembleParams(DecisionTreeParams):
    """
    Mixin for Decision Tree-based ensemble algorithms parameters.
    """

    # a placeholder to make it appear in the generated doc
    subsamplingRate = Param(Params._dummy(), "subsamplingRate", "Fraction of the training data " +
                            "used for learning each decision tree, in range (0, 1].")

    def __init__(self):
        super(TreeEnsembleParams, self).__init__()
        #: param for Fraction of the training data, in range (0, 1].
        self.subsamplingRate = Param(self, "subsamplingRate", "Fraction of the training data " +
                                     "used for learning each decision tree, in range (0, 1].")

    def setSubsamplingRate(self, value):
        """
        Sets the value of :py:attr:`subsamplingRate`.
        """
        self._paramMap[self.subsamplingRate] = value
        return self

    def getSubsamplingRate(self):
        """
        Gets the value of subsamplingRate or its default value.
        """
        return self.getOrDefault(self.subsamplingRate)
class TreeRegressorParams(Params):
    """
    Private class to track supported impurity measures.
    """

    supportedImpurities = ["variance"]
    # a placeholder to make it appear in the generated doc
    impurity = Param(Params._dummy(), "impurity",
                     "Criterion used for information gain calculation (case-insensitive). " +
                     "Supported options: " +
                     ", ".join(supportedImpurities))

    def __init__(self):
        super(TreeRegressorParams, self).__init__()
        #: param for Criterion used for information gain calculation (case-insensitive).
        self.impurity = Param(self, "impurity", "Criterion used for information " +
                              "gain calculation (case-insensitive). Supported options: " +
                              ", ".join(self.supportedImpurities))

    def setImpurity(self, value):
        """
        Sets the value of :py:attr:`impurity`.
        """
        self._paramMap[self.impurity] = value
        return self

    def getImpurity(self):
        """
        Gets the value of impurity or its default value.
        """
        return self.getOrDefault(self.impurity)
class RandomForestParams(TreeEnsembleParams):
    """
    Private class to track supported random forest parameters.
    """

    supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"]
    # a placeholder to make it appear in the generated doc
    numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).")
    featureSubsetStrategy = \
        Param(Params._dummy(), "featureSubsetStrategy",
              "The number of features to consider for splits at each tree node. Supported " +
              "options: " + ", ".join(supportedFeatureSubsetStrategies))

    def __init__(self):
        super(RandomForestParams, self).__init__()
        #: param for Number of trees to train (>= 1).
        self.numTrees = Param(self, "numTrees", "Number of trees to train (>= 1).")
        #: param for The number of features to consider for splits at each tree node.
        self.featureSubsetStrategy = \
            Param(self, "featureSubsetStrategy",
                  "The number of features to consider for splits at each tree node. Supported " +
                  "options: " + ", ".join(self.supportedFeatureSubsetStrategies))

    def setNumTrees(self, value):
        """
        Sets the value of :py:attr:`numTrees`.
        """
        self._paramMap[self.numTrees] = value
        return self

    def getNumTrees(self):
        """
        Gets the value of numTrees or its default value.
        """
        return self.getOrDefault(self.numTrees)

    def setFeatureSubsetStrategy(self, value):
        """
        Sets the value of :py:attr:`featureSubsetStrategy`.
        """
        self._paramMap[self.featureSubsetStrategy] = value
        return self

    def getFeatureSubsetStrategy(self):
        """
        Gets the value of featureSubsetStrategy or its default value.
        """
        return self.getOrDefault(self.featureSubsetStrategy)
class GBTParams(TreeEnsembleParams):
    """
    Private class to track supported GBT params.
    """

    # Loss names joined into the help text of the ``lossType`` param.
    supportedLossTypes = ["squared", "absolute"]
class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                            DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval):
    """
    `Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_
    learning algorithm for regression.
    It supports both continuous and categorical features.

    >>> from pyspark.mllib.linalg import Vectors
    >>> df = sqlContext.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> dt = DecisionTreeRegressor(maxDepth=2)
    >>> model = dt.fit(df)
    >>> model.depth
    1
    >>> model.numNodes
    3
    >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> model.transform(test0).head().prediction
    0.0
    >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> model.transform(test1).head().prediction
    1.0

    .. versionadded:: 1.4.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance"):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance")
        """
        super(DecisionTreeRegressor, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.DecisionTreeRegressor", self.uid)
        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                         impurity="variance")
        # Forward the keyword arguments captured by @keyword_only so that
        # values passed to the constructor actually take effect.
        kwargs = self.__init__._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                  impurity="variance"):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance")
        Sets params for the DecisionTreeRegressor.
        """
        kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        """
        Wraps the given Java model in a :py:class:`DecisionTreeRegressionModel`.
        """
        return DecisionTreeRegressionModel(java_model)
class DecisionTreeModel(JavaModel):
    """Abstraction for Decision Tree models.

    .. versionadded:: 1.5.0
    """

    # Exposed as properties: callers read ``model.numNodes`` and
    # ``model.depth`` as plain attributes.
    @property
    def numNodes(self):
        """Return number of nodes of the decision tree."""
        return self._call_java("numNodes")

    @property
    def depth(self):
        """Return depth of the decision tree."""
        return self._call_java("depth")

    def __repr__(self):
        # Delegate to the Java model's toString for the textual representation.
        return self._call_java("toString")
class TreeEnsembleModels(JavaModel):
    """Represents a tree ensemble model.

    .. versionadded:: 1.5.0
    """

    # Exposed as a property: callers read ``model.treeWeights`` as a plain
    # attribute.
    @property
    def treeWeights(self):
        """Return the weights for each tree"""
        return list(self._call_java("javaTreeWeights"))

    def __repr__(self):
        # Delegate to the Java model's toString for the textual representation.
        return self._call_java("toString")
class DecisionTreeRegressionModel(DecisionTreeModel):
    """
    Model fitted by DecisionTreeRegressor.

    .. versionadded:: 1.4.0
    """
class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
                            RandomForestParams, TreeRegressorParams, HasCheckpointInterval):
    """
    `Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_
    learning algorithm for regression.
    It supports both continuous and categorical features.

    >>> from numpy import allclose
    >>> from pyspark.mllib.linalg import Vectors
    >>> df = sqlContext.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42)
    >>> model = rf.fit(df)
    >>> allclose(model.treeWeights, [1.0, 1.0])
    True
    >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> model.transform(test0).head().prediction
    0.0
    >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> model.transform(test1).head().prediction
    0.5

    .. versionadded:: 1.4.0
    """

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                 impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20,
                 featureSubsetStrategy="auto"):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                 impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \
                 featureSubsetStrategy="auto")
        """
        super(RandomForestRegressor, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.RandomForestRegressor", self.uid)
        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                         maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                         impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20,
                         featureSubsetStrategy="auto")
        # Forward the keyword arguments captured by @keyword_only so that
        # values passed to the constructor actually take effect.
        kwargs = self.__init__._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                  impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20,
                  featureSubsetStrategy="auto"):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
                  impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, \
                  featureSubsetStrategy="auto")
        Sets params for random forest regression.
        """
        kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        """
        Wraps the given Java model in a :py:class:`RandomForestRegressionModel`.
        """
        return RandomForestRegressionModel(java_model)
class RandomForestRegressionModel(TreeEnsembleModels):
    """
    Model fitted by RandomForestRegressor.

    .. versionadded:: 1.4.0
    """
class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
                   GBTParams, HasCheckpointInterval, HasStepSize, HasSeed):
    """
    `Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
    learning algorithm for regression.
    It supports both continuous and categorical features.

    >>> from numpy import allclose
    >>> from pyspark.mllib.linalg import Vectors
    >>> df = sqlContext.createDataFrame([
    ...     (1.0, Vectors.dense(1.0)),
    ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    >>> gbt = GBTRegressor(maxIter=5, maxDepth=2)
    >>> model = gbt.fit(df)
    >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
    True
    >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    >>> model.transform(test0).head().prediction
    0.0
    >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
    >>> model.transform(test1).head().prediction
    1.0

    .. versionadded:: 1.4.0
    """

    # a placeholder to make it appear in the generated doc
    lossType = Param(Params._dummy(), "lossType",
                     "Loss function which GBT tries to minimize (case-insensitive). " +
                     "Supported options: " + ", ".join(GBTParams.supportedLossTypes))

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                 maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
                 checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                 maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
                 checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1)
        """
        super(GBTRegressor, self).__init__()
        self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid)
        #: param for Loss function which GBT tries to minimize (case-insensitive).
        self.lossType = Param(self, "lossType",
                              "Loss function which GBT tries to minimize (case-insensitive). " +
                              "Supported options: " + ", ".join(GBTParams.supportedLossTypes))
        self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                         maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
                         checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1)
        # Forward the keyword arguments captured by @keyword_only so that
        # values passed to the constructor actually take effect.
        kwargs = self.__init__._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
                  checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
                  checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1)
        Sets params for Gradient Boosted Tree Regression.
        """
        kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        """
        Wraps the given Java model in a :py:class:`GBTRegressionModel`.
        """
        return GBTRegressionModel(java_model)

    def setLossType(self, value):
        """
        Sets the value of :py:attr:`lossType`.
        """
        self._paramMap[self.lossType] = value
        return self

    def getLossType(self):
        """
        Gets the value of lossType or its default value.
        """
        return self.getOrDefault(self.lossType)
class GBTRegressionModel(TreeEnsembleModels):
    """
    Model fitted by GBTRegressor.

    .. versionadded:: 1.4.0
    """
class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                            HasFitIntercept, HasMaxIter, HasTol):
    """
    Accelerated Failure Time (AFT) Model Survival Regression

    Fit a parametric AFT survival regression model based on the Weibull distribution
    of the survival time.

    .. seealso:: `AFT Model <https://en.wikipedia.org/wiki/Accelerated_failure_time_model>`_

    >>> from pyspark.mllib.linalg import Vectors
    >>> df = sqlContext.createDataFrame([
    ...     (1.0, Vectors.dense(1.0), 1.0),
    ...     (0.0, Vectors.sparse(1, [], []), 0.0)], ["label", "features", "censor"])
    >>> aftsr = AFTSurvivalRegression()
    >>> model = aftsr.fit(df)
    >>> model.predict(Vectors.dense(6.3))
    1.0
    >>> model.predictQuantiles(Vectors.dense(6.3))
    DenseVector([0.0101, 0.0513, 0.1054, 0.2877, 0.6931, 1.3863, 2.3026, 2.9957, 4.6052])
    >>> model.transform(df).show()
    +-----+---------+------+----------+
    |label| features|censor|prediction|
    +-----+---------+------+----------+
    |  1.0|    [1.0]|   1.0|       1.0|
    |  0.0|(1,[],[])|   0.0|       1.0|
    +-----+---------+------+----------+
    ...

    .. versionadded:: 1.6.0
    """

    # a placeholder to make it appear in the generated doc
    censorCol = Param(Params._dummy(), "censorCol",
                      "censor column name. The value of this column could be 0 or 1. " +
                      "If the value is 1, it means the event has occurred i.e. " +
                      "uncensored; otherwise censored.")
    quantileProbabilities = \
        Param(Params._dummy(), "quantileProbabilities",
              "quantile probabilities array. Values of the quantile probabilities array " +
              "should be in the range (0, 1) and the array should be non-empty.")
    quantilesCol = Param(Params._dummy(), "quantilesCol",
                         "quantiles column name. This column will output quantiles of " +
                         "corresponding quantileProbabilities if it is set.")

    @keyword_only
    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                 fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
                 quantileProbabilities=None, quantilesCol=None):
        """
        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                 fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
                 quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
                 quantilesCol=None)
        """
        super(AFTSurvivalRegression, self).__init__()
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid)
        #: Param for censor column name
        self.censorCol = Param(self, "censorCol",
                               "censor column name. The value of this column could be 0 or 1. " +
                               "If the value is 1, it means the event has occurred i.e. " +
                               "uncensored; otherwise censored.")
        #: Param for quantile probabilities array
        self.quantileProbabilities = \
            Param(self, "quantileProbabilities",
                  "quantile probabilities array. Values of the quantile probabilities array " +
                  "should be in the range (0, 1) and the array should be non-empty.")
        #: Param for quantiles column name
        self.quantilesCol = Param(self, "quantilesCol",
                                  "quantiles column name. This column will output quantiles of " +
                                  "corresponding quantileProbabilities if it is set.")
        # The signature default for quantileProbabilities is None (avoiding a
        # mutable default argument); the effective default list is set here.
        self._setDefault(censorCol="censor",
                         quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
        # Forward the keyword arguments captured by @keyword_only so that
        # values passed to the constructor actually take effect.
        kwargs = self.__init__._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
                  quantileProbabilities=None, quantilesCol=None):
        """
        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
                  quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
                  quantilesCol=None)
        """
        kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)

    def _create_model(self, java_model):
        """
        Wraps the given Java model in an :py:class:`AFTSurvivalRegressionModel`.
        """
        return AFTSurvivalRegressionModel(java_model)

    def setCensorCol(self, value):
        """
        Sets the value of :py:attr:`censorCol`.
        """
        self._paramMap[self.censorCol] = value
        return self

    def getCensorCol(self):
        """
        Gets the value of censorCol or its default value.
        """
        return self.getOrDefault(self.censorCol)

    def setQuantileProbabilities(self, value):
        """
        Sets the value of :py:attr:`quantileProbabilities`.
        """
        self._paramMap[self.quantileProbabilities] = value
        return self

    def getQuantileProbabilities(self):
        """
        Gets the value of quantileProbabilities or its default value.
        """
        return self.getOrDefault(self.quantileProbabilities)

    def setQuantilesCol(self, value):
        """
        Sets the value of :py:attr:`quantilesCol`.
        """
        self._paramMap[self.quantilesCol] = value
        return self

    def getQuantilesCol(self):
        """
        Gets the value of quantilesCol or its default value.
        """
        return self.getOrDefault(self.quantilesCol)
class AFTSurvivalRegressionModel(JavaModel):
    """
    Model fitted by AFTSurvivalRegression.

    .. versionadded:: 1.6.0
    """

    def predictQuantiles(self, features):
        """
        Predicted Quantiles
        """
        return self._call_java("predictQuantiles", features)

    def predict(self, features):
        """
        Predicted value
        """
        return self._call_java("predict", features)
if __name__ == "__main__":
import doctest
from pyspark.context import SparkContext
from pyspark.sql import SQLContext
globs = globals().copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
sc = SparkContext("local[2]", "ml.regression tests")
sqlContext = SQLContext(sc)
globs['sc'] = sc
globs['sqlContext'] = sqlContext
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
if failure_count: