From bb9b50c2172bc1aa7d261626aa74f9be6d1d5e79 Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Fri, 8 May 2020 14:06:36 +0800 Subject: [PATCH] [SPARK-31656][ML][PYSPARK] AFT blockify input vectors ### What changes were proposed in this pull request? 1, add new param blockSize; 2, add a new class InstanceBlock; 3, if blockSize==1, keep original behavior; if blockSize>1, stack input vectors to blocks (like ALS/MLP); 4, if blockSize>1, standardize the input outside of optimization procedure; ### Why are the changes needed? it will obtain performance gain on dense datasets, such as epsilon 1, reduce RAM to persist traing dataset; (save about 40% RAM) 2, use Level-2 BLAS routines; (~10X speedup) ### Does this PR introduce _any_ user-facing change? Yes, a new param is added ### How was this patch tested? existing and added testsuites Closes #28473 from zhengruifeng/blockify_aft. Authored-by: zhengruifeng Signed-off-by: zhengruifeng --- .../spark/ml/classification/LinearSVC.scala | 18 +- .../classification/LogisticRegression.scala | 23 +-- .../ml/optim/aggregator/AFTAggregator.scala | 96 +++++++++- .../ml/regression/AFTSurvivalRegression.scala | 168 +++++++++++++----- .../ml/regression/LinearRegression.scala | 18 +- .../AFTSurvivalRegressionSuite.scala | 16 ++ python/pyspark/ml/regression.py | 21 ++- 7 files changed, 279 insertions(+), 81 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index 69c35a8a80..217398c51b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -187,17 +187,15 @@ class LinearSVC @Since("2.2.0") ( val instances = extractInstances(dataset) .setName("training instances") - val (summarizer, labelSummarizer) = if ($(blockSize) == 1) { - if (dataset.storageLevel == StorageLevel.NONE) { - instances.persist(StorageLevel.MEMORY_AND_DISK) - } - Summarizer.getClassificationSummarizers(instances, $(aggregationDepth)) - } else { - // instances will be standardized and converted to blocks, so no need to cache instances. - Summarizer.getClassificationSummarizers(instances, $(aggregationDepth), - Seq("mean", "std", "count", "numNonZeros")) + if (dataset.storageLevel == StorageLevel.NONE && $(blockSize) == 1) { + instances.persist(StorageLevel.MEMORY_AND_DISK) } + var requestedMetrics = Seq("mean", "std", "count") + if ($(blockSize) != 1) requestedMetrics +:= "numNonZeros" + val (summarizer, labelSummarizer) = Summarizer + .getClassificationSummarizers(instances, $(aggregationDepth), requestedMetrics) + val histogram = labelSummarizer.histogram val numInvalid = labelSummarizer.countInvalid val numFeatures = summarizer.mean.size @@ -316,7 +314,7 @@ class LinearSVC @Since("2.2.0") ( } val blocks = InstanceBlock.blokify(standardized, $(blockSize)) .persist(StorageLevel.MEMORY_AND_DISK) - .setName(s"training dataset (blockSize=${$(blockSize)})") + .setName(s"training blocks (blockSize=${$(blockSize)})") val getAggregatorFunc = new BlockHingeAggregator($(fitIntercept))(_) val costFun = new RDDLossFunction(blocks, getAggregatorFunc, diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 10cf961800..c1dd677f08 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -517,17 +517,18 @@ class LogisticRegression @Since("1.2.0") ( probabilityCol, regParam, elasticNetParam, standardization, threshold, maxIter, tol, fitIntercept, blockSize) - val instances = extractInstances(dataset).setName("training instances") - if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + val instances = extractInstances(dataset) + .setName("training instances") - val (summarizer, labelSummarizer) = if ($(blockSize) == 1) { - Summarizer.getClassificationSummarizers(instances, $(aggregationDepth)) - } else { - // instances will be standardized and converted to blocks, so no need to cache instances. - Summarizer.getClassificationSummarizers(instances, $(aggregationDepth), - Seq("mean", "std", "count", "numNonZeros")) + if (handlePersistence && $(blockSize) == 1) { + instances.persist(StorageLevel.MEMORY_AND_DISK) } + var requestedMetrics = Seq("mean", "std", "count") + if ($(blockSize) != 1) requestedMetrics +:= "numNonZeros" + val (summarizer, labelSummarizer) = Summarizer + .getClassificationSummarizers(instances, $(aggregationDepth), requestedMetrics) + val numFeatures = summarizer.mean.size val histogram = labelSummarizer.histogram val numInvalid = labelSummarizer.countInvalid @@ -591,7 +592,7 @@ class LogisticRegression @Since("1.2.0") ( } else { Vectors.dense(if (numClasses == 2) Double.PositiveInfinity else Double.NegativeInfinity) } - if (handlePersistence) instances.unpersist() + if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist() return createModel(dataset, numClasses, coefMatrix, interceptVec, Array.empty) } @@ -650,7 +651,7 @@ class LogisticRegression @Since("1.2.0") ( trainOnBlocks(instances, featuresStd, numClasses, initialCoefWithInterceptMatrix, regularization, optimizer) } - if (handlePersistence) instances.unpersist() + if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist() if (allCoefficients == null) { val msg = s"${optimizer.getClass.getName} failed." @@ -1002,7 +1003,7 @@ class LogisticRegression @Since("1.2.0") ( } val blocks = InstanceBlock.blokify(standardized, $(blockSize)) .persist(StorageLevel.MEMORY_AND_DISK) - .setName(s"training dataset (blockSize=${$(blockSize)})") + .setName(s"training blocks (blockSize=${$(blockSize)})") val getAggregatorFunc = new BlockLogisticAggregator(numFeatures, numClasses, $(fitIntercept), checkMultinomial(numClasses))(_) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/AFTAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/AFTAggregator.scala index 6482c619e6..8a5d7fe34e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/AFTAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/AFTAggregator.scala @@ -155,8 +155,102 @@ private[ml] class AFTAggregator( } gradientSumArray(dim - 2) += { if (fitIntercept) multiplier else 0.0 } gradientSumArray(dim - 1) += delta + multiplier * sigma * epsilon - weightSum += 1.0 + + this + } +} + + +/** + * BlockAFTAggregator computes the gradient and loss as used in AFT survival regression + * for blocks in sparse or dense matrix in an online fashion. + * + * Two BlockAFTAggregators can be merged together to have a summary of loss and gradient of + * the corresponding joint dataset. + * + * NOTE: The feature values are expected to be standardized before computation. + * + * @param bcCoefficients The coefficients corresponding to the features. + * @param fitIntercept Whether to fit an intercept term. + */ +private[ml] class BlockAFTAggregator( + fitIntercept: Boolean)(bcCoefficients: Broadcast[Vector]) + extends DifferentiableLossAggregator[(Matrix, Array[Double], Array[Double]), + BlockAFTAggregator] { + + protected override val dim: Int = bcCoefficients.value.size + private val numFeatures = dim - 2 + + @transient private lazy val coefficientsArray = bcCoefficients.value match { + case DenseVector(values) => values + case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector" + + s" but got type ${bcCoefficients.value.getClass}.") + } + + @transient private lazy val linear = Vectors.dense(coefficientsArray.take(numFeatures)) + + /** + * Add a new training instance block to this BlockAFTAggregator, and update the loss and + * gradient of the objective function. + * + * @return This BlockAFTAggregator object. + */ + def add(block: (Matrix, Array[Double], Array[Double])): this.type = { + val (matrix, labels, censors) = block + require(matrix.isTransposed) + require(numFeatures == matrix.numCols, s"Dimensions mismatch when adding new " + + s"instance. Expecting $numFeatures but got ${matrix.numCols}.") + require(labels.forall(_ > 0.0), "The lifetime or label should be greater than 0.") + + val size = matrix.numRows + require(labels.length == size && censors.length == size) + + val intercept = coefficientsArray(dim - 2) + // sigma is the scale parameter of the AFT model + val sigma = math.exp(coefficientsArray(dim - 1)) + + // vec here represents margins + val vec = if (fitIntercept) { + Vectors.dense(Array.fill(size)(intercept)).toDense + } else { + Vectors.zeros(size).toDense + } + BLAS.gemv(1.0, matrix, linear, 1.0, vec) + + // in-place convert margins to gradient scales + // then, vec represents gradient scales + var i = 0 + var sigmaGradSum = 0.0 + while (i < size) { + val ti = labels(i) + val delta = censors(i) + val margin = vec(i) + val epsilon = (math.log(ti) - margin) / sigma + val expEpsilon = math.exp(epsilon) + lossSum += delta * math.log(sigma) - delta * epsilon + expEpsilon + val multiplier = (delta - expEpsilon) / sigma + vec.values(i) = multiplier + sigmaGradSum += delta + multiplier * sigma * epsilon + i += 1 + } + + matrix match { + case dm: DenseMatrix => + BLAS.nativeBLAS.dgemv("N", dm.numCols, dm.numRows, 1.0, dm.values, dm.numCols, + vec.values, 1, 1.0, gradientSumArray, 1) + + case sm: SparseMatrix => + val linearGradSumVec = Vectors.zeros(numFeatures).toDense + BLAS.gemv(1.0, sm.transpose, vec, 0.0, linearGradSumVec) + BLAS.getBLAS(numFeatures).daxpy(numFeatures, 1.0, linearGradSumVec.values, 1, + gradientSumArray, 1) + } + + if (fitIntercept) gradientSumArray(dim - 2) += vec.values.sum + gradientSumArray(dim - 1) += sigmaGradSum + weightSum += size + this } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 8cc5f864de..2c30e44b93 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -27,8 +27,9 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml.PredictorParams -import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT} -import org.apache.spark.ml.optim.aggregator.AFTAggregator +import org.apache.spark.ml.feature.StandardScalerModel +import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.optim.aggregator._ import org.apache.spark.ml.optim.loss.RDDLossFunction import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -46,7 +47,8 @@ import org.apache.spark.storage.StorageLevel * Params for accelerated failure time (AFT) regression. */ private[regression] trait AFTSurvivalRegressionParams extends PredictorParams - with HasMaxIter with HasTol with HasFitIntercept with HasAggregationDepth with Logging { + with HasMaxIter with HasTol with HasFitIntercept with HasAggregationDepth with HasBlockSize + with Logging { /** * Param for censor column name. @@ -183,6 +185,25 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value) setDefault(aggregationDepth -> 2) + /** + * Set block size for stacking input data in matrices. + * If blockSize == 1, then stacking will be skipped, and each vector is treated individually; + * If blockSize > 1, then vectors will be stacked to blocks, and high-level BLAS routines + * will be used if possible (for example, GEMV instead of DOT, GEMM instead of GEMV). + * Recommended size is between 10 and 1000. An appropriate choice of the block size depends + * on the sparsity and dim of input datasets, the underlying BLAS implementation (for example, + * f2jBLAS, OpenBLAS, intel MKL) and its configuration (for example, number of threads). + * Note that existing BLAS implementations are mainly optimized for dense matrices, if the + * input dataset is sparse, stacking may bring no performance gain, the worse is possible + * performance regression. + * Default is 1. + * + * @group expertSetParam + */ + @Since("3.1.0") + def setBlockSize(value: Int): this.type = set(blockSize, value) + setDefault(blockSize -> 1) + /** * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset, * and put it in an RDD with strong types. @@ -197,39 +218,50 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S override protected def train( dataset: Dataset[_]): AFTSurvivalRegressionModel = instrumented { instr => - val instances = extractAFTPoints(dataset) - val handlePersistence = dataset.storageLevel == StorageLevel.NONE - if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, censorCol, predictionCol, quantilesCol, + fitIntercept, maxIter, tol, aggregationDepth, blockSize) + instr.logNamedValue("quantileProbabilities.size", $(quantileProbabilities).length) - val featuresSummarizer = instances.treeAggregate( - Summarizer.createSummarizerBuffer("mean", "std", "count"))( + val instances = extractAFTPoints(dataset) + .setName("training instances") + + if ($(blockSize) == 1 && dataset.storageLevel == StorageLevel.NONE) { + instances.persist(StorageLevel.MEMORY_AND_DISK) + } + + var requestedMetrics = Seq("mean", "std", "count") + if ($(blockSize) != 1) requestedMetrics +:= "numNonZeros" + val summarizer = instances.treeAggregate( + Summarizer.createSummarizerBuffer(requestedMetrics: _*))( seqOp = (c: SummarizerBuffer, v: AFTPoint) => c.add(v.features), combOp = (c1: SummarizerBuffer, c2: SummarizerBuffer) => c1.merge(c2), depth = $(aggregationDepth) ) - val featuresStd = featuresSummarizer.std.toArray + val featuresStd = summarizer.std.toArray val numFeatures = featuresStd.length - - instr.logPipelineStage(this) - instr.logDataset(dataset) - instr.logParams(this, labelCol, featuresCol, censorCol, predictionCol, quantilesCol, - fitIntercept, maxIter, tol, aggregationDepth) - instr.logNamedValue("quantileProbabilities.size", $(quantileProbabilities).length) instr.logNumFeatures(numFeatures) - instr.logNumExamples(featuresSummarizer.count) + instr.logNumExamples(summarizer.count) + if ($(blockSize) > 1) { + val scale = 1.0 / summarizer.count / numFeatures + val sparsity = 1 - summarizer.numNonzeros.toArray.map(_ * scale).sum + instr.logNamedValue("sparsity", sparsity.toString) + if (sparsity > 0.5) { + instr.logWarning(s"sparsity of input dataset is $sparsity, " + + s"which may hurt performance in high-level BLAS.") + } + } if (!$(fitIntercept) && (0 until numFeatures).exists { i => - featuresStd(i) == 0.0 && featuresSummarizer.mean(i) != 0.0 }) { + featuresStd(i) == 0.0 && summarizer.mean(i) != 0.0 }) { instr.logWarning("Fitting AFTSurvivalRegressionModel without intercept on dataset with " + "constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero " + "columns. This behavior is different from R survival::survreg.") } - val bcFeaturesStd = instances.context.broadcast(featuresStd) - val getAggregatorFunc = new AFTAggregator(bcFeaturesStd, $(fitIntercept))(_) val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) - val costFun = new RDDLossFunction(instances, getAggregatorFunc, None, $(aggregationDepth)) /* The parameters vector has three parts: @@ -239,36 +271,86 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S */ val initialParameters = Vectors.zeros(numFeatures + 2) + val (rawCoefficients, objectiveHistory) = if ($(blockSize) == 1) { + trainOnRows(instances, featuresStd, optimizer, initialParameters) + } else { + trainOnBlocks(instances, featuresStd, optimizer, initialParameters) + } + if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist() + + if (rawCoefficients == null) { + val msg = s"${optimizer.getClass.getName} failed." + instr.logError(msg) + throw new SparkException(msg) + } + + val coefficientArray = Array.tabulate(numFeatures) { i => + if (featuresStd(i) != 0) rawCoefficients(i) / featuresStd(i) else 0.0 + } + val coefficients = Vectors.dense(coefficientArray) + val intercept = rawCoefficients(numFeatures) + val scale = math.exp(rawCoefficients(numFeatures + 1)) + new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale) + } + + private def trainOnRows( + instances: RDD[AFTPoint], + featuresStd: Array[Double], + optimizer: BreezeLBFGS[BDV[Double]], + initialParameters: Vector): (Array[Double], Array[Double]) = { + val bcFeaturesStd = instances.context.broadcast(featuresStd) + val getAggregatorFunc = new AFTAggregator(bcFeaturesStd, $(fitIntercept))(_) + val costFun = new RDDLossFunction(instances, getAggregatorFunc, None, $(aggregationDepth)) + val states = optimizer.iterations(new CachedDiffFunction(costFun), initialParameters.asBreeze.toDenseVector) - val parameters = { - val arrayBuilder = mutable.ArrayBuilder.make[Double] - var state: optimizer.State = null - while (states.hasNext) { - state = states.next() - arrayBuilder += state.adjustedValue - } - if (state == null) { - val msg = s"${optimizer.getClass.getName} failed." - throw new SparkException(msg) - } - state.x.toArray.clone() + val arrayBuilder = mutable.ArrayBuilder.make[Double] + var state: optimizer.State = null + while (states.hasNext) { + state = states.next() + arrayBuilder += state.adjustedValue } - bcFeaturesStd.destroy() - if (handlePersistence) instances.unpersist() - val rawCoefficients = parameters.take(numFeatures) - var i = 0 - while (i < numFeatures) { - rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } - i += 1 + (if (state != null) state.x.toArray else null, arrayBuilder.result) + } + + private def trainOnBlocks( + instances: RDD[AFTPoint], + featuresStd: Array[Double], + optimizer: BreezeLBFGS[BDV[Double]], + initialParameters: Vector): (Array[Double], Array[Double]) = { + val bcFeaturesStd = instances.context.broadcast(featuresStd) + val blocks = instances.mapPartitions { iter => + val inverseStd = bcFeaturesStd.value.map { std => if (std != 0) 1.0 / std else 0.0 } + val func = StandardScalerModel.getTransformFunc(Array.empty, inverseStd, false, true) + iter.grouped($(blockSize)).map { seq => + val matrix = Matrices.fromVectors(seq.map(point => func(point.features))) + val labels = seq.map(_.label).toArray + val censors = seq.map(_.censor).toArray + (matrix, labels, censors) + } } - val coefficients = Vectors.dense(rawCoefficients) - val intercept = parameters(numFeatures) - val scale = math.exp(parameters(numFeatures + 1)) - new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale) + blocks.persist(StorageLevel.MEMORY_AND_DISK) + .setName(s"training blocks (blockSize=${$(blockSize)})") + + val getAggregatorFunc = new BlockAFTAggregator($(fitIntercept))(_) + val costFun = new RDDLossFunction(blocks, getAggregatorFunc, None, $(aggregationDepth)) + + val states = optimizer.iterations(new CachedDiffFunction(costFun), + initialParameters.asBreeze.toDenseVector) + + val arrayBuilder = mutable.ArrayBuilder.make[Double] + var state: optimizer.State = null + while (states.hasNext) { + state = states.next() + arrayBuilder += state.adjustedValue + } + blocks.unpersist() + bcFeaturesStd.destroy() + + (if (state != null) state.x.toArray else null, arrayBuilder.result) } @Since("1.6.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 9f18c84931..bcf9b7c042 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -355,17 +355,15 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val instances = extractInstances(dataset) .setName("training instances") - val (featuresSummarizer, ySummarizer) = if ($(blockSize) == 1) { - if (dataset.storageLevel == StorageLevel.NONE) { - instances.persist(StorageLevel.MEMORY_AND_DISK) - } - Summarizer.getRegressionSummarizers(instances, $(aggregationDepth)) - } else { - // instances will be standardized and converted to blocks, so no need to cache instances. - Summarizer.getRegressionSummarizers(instances, $(aggregationDepth), - Seq("mean", "std", "count", "numNonZeros")) + if (dataset.storageLevel == StorageLevel.NONE && $(blockSize) == 1) { + instances.persist(StorageLevel.MEMORY_AND_DISK) } + var requestedMetrics = Seq("mean", "std", "count") + if ($(blockSize) != 1) requestedMetrics +:= "numNonZeros" + val (featuresSummarizer, ySummarizer) = Summarizer + .getRegressionSummarizers(instances, $(aggregationDepth), requestedMetrics) + val yMean = ySummarizer.mean(0) val rawYStd = ySummarizer.std(0) @@ -617,7 +615,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } val blocks = InstanceBlock.blokify(standardized, $(blockSize)) .persist(StorageLevel.MEMORY_AND_DISK) - .setName(s"training dataset (blockSize=${$(blockSize)})") + .setName(s"training blocks (blockSize=${$(blockSize)})") val costFun = $(loss) match { case SquaredError => diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 6cc73e040e..a66143ab12 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -428,6 +428,22 @@ class AFTSurvivalRegressionSuite extends MLTest with DefaultReadWriteTest { val trainer = new AFTSurvivalRegression() trainer.fit(dataset) } + + test("AFTSurvivalRegression on blocks") { + val quantileProbabilities = Array(0.1, 0.5, 0.9) + for (dataset <- Seq(datasetUnivariate, datasetUnivariateScaled, datasetMultivariate)) { + val aft = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles") + val model = aft.fit(dataset) + Seq(4, 16, 64).foreach { blockSize => + val model2 = aft.setBlockSize(blockSize).fit(dataset) + assert(model.coefficients ~== model2.coefficients relTol 1e-9) + assert(model.intercept ~== model2.intercept relTol 1e-9) + assert(model.scale ~== model2.scale relTol 1e-9) + } + } + } } object AFTSurvivalRegressionSuite { diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 77a50bfe4a..2ce467308e 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -1607,7 +1607,7 @@ class GBTRegressionModel( class _AFTSurvivalRegressionParams(_PredictorParams, HasMaxIter, HasTol, HasFitIntercept, - HasAggregationDepth): + HasAggregationDepth, HasBlockSize): """ Params for :py:class:`AFTSurvivalRegression` and :py:class:`AFTSurvivalRegressionModel`. @@ -1674,6 +1674,8 @@ class AFTSurvivalRegression(_JavaRegressor, _AFTSurvivalRegressionParams, 10 >>> aftsr.clear(aftsr.maxIter) >>> model = aftsr.fit(df) + >>> model.getBlockSize() + 1 >>> model.setFeaturesCol("features") AFTSurvivalRegressionModel... >>> model.predict(Vectors.dense(6.3)) @@ -1710,19 +1712,19 @@ class AFTSurvivalRegression(_JavaRegressor, _AFTSurvivalRegressionParams, def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]), - quantilesCol=None, aggregationDepth=2): + quantilesCol=None, aggregationDepth=2, blockSize=1): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \ quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \ - quantilesCol=None, aggregationDepth=2) + quantilesCol=None, aggregationDepth=2, blockSize=1) """ super(AFTSurvivalRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid) self._setDefault(censorCol="censor", quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], - maxIter=100, tol=1E-6) + maxIter=100, tol=1E-6, blockSize=1) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -1731,12 +1733,12 @@ class AFTSurvivalRegression(_JavaRegressor, _AFTSurvivalRegressionParams, def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]), - quantilesCol=None, aggregationDepth=2): + quantilesCol=None, aggregationDepth=2, blockSize=1): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \ quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \ - quantilesCol=None, aggregationDepth=2): + quantilesCol=None, aggregationDepth=2, blockSize=1): """ kwargs = self._input_kwargs return self._set(**kwargs) @@ -1793,6 +1795,13 @@ class AFTSurvivalRegression(_JavaRegressor, _AFTSurvivalRegressionParams, """ return self._set(aggregationDepth=value) + @since("3.1.0") + def setBlockSize(self, value): + """ + Sets the value of :py:attr:`blockSize`. + """ + return self._set(blockSize=value) + class AFTSurvivalRegressionModel(_JavaRegressionModel, _AFTSurvivalRegressionParams, JavaMLWritable, JavaMLReadable):