[SPARK-31656][ML][PYSPARK] AFT blockify input vectors
### What changes were proposed in this pull request?
1. Add a new param `blockSize`;
2. Add a new class `InstanceBlock`;
3. If `blockSize == 1`, keep the original behavior; if `blockSize > 1`, stack input vectors into blocks (like ALS/MLP);
4. If `blockSize > 1`, standardize the input outside of the optimization procedure.

### Why are the changes needed?
It yields a performance gain on dense datasets, such as epsilon:
1. Less RAM is needed to persist the training dataset (saves about 40% RAM);
2. Level-2 BLAS routines can be used (~10X speedup).

### Does this PR introduce _any_ user-facing change?
Yes, a new param is added.

### How was this patch tested?
Existing and added test suites.

Closes #28473 from zhengruifeng/blockify_aft.

Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: zhengruifeng <ruifengz@foxmail.com>
commit bb9b50c217 (parent 18d2ba53e4)
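The core idea of the patch, as a standalone sketch (plain Scala collections stand in for the RDD pipeline; this `AFTPoint` case class and the `blockify` helper name are illustrative stand-ins, not the patch's private types):

    import org.apache.spark.ml.linalg.{Matrices, Matrix, Vector}

    case class AFTPoint(features: Vector, label: Double, censor: Double)

    // Stack every `blockSize` consecutive (already standardized) vectors into one
    // matrix, so the aggregator can issue a single GEMV per block instead of one
    // DOT per instance.
    def blockify(points: Iterator[AFTPoint], blockSize: Int)
        : Iterator[(Matrix, Array[Double], Array[Double])] =
      points.grouped(blockSize).map { seq =>
        val matrix = Matrices.fromVectors(seq.map(_.features))
        (matrix, seq.map(_.label).toArray, seq.map(_.censor).toArray)
      }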
@@ -187,17 +187,15 @@ class LinearSVC @Since("2.2.0") (
     val instances = extractInstances(dataset)
       .setName("training instances")
 
-    val (summarizer, labelSummarizer) = if ($(blockSize) == 1) {
-      if (dataset.storageLevel == StorageLevel.NONE) {
-        instances.persist(StorageLevel.MEMORY_AND_DISK)
-      }
-      Summarizer.getClassificationSummarizers(instances, $(aggregationDepth))
-    } else {
-      // instances will be standardized and converted to blocks, so no need to cache instances.
-      Summarizer.getClassificationSummarizers(instances, $(aggregationDepth),
-        Seq("mean", "std", "count", "numNonZeros"))
-    }
+    if (dataset.storageLevel == StorageLevel.NONE && $(blockSize) == 1) {
+      instances.persist(StorageLevel.MEMORY_AND_DISK)
+    }
+
+    var requestedMetrics = Seq("mean", "std", "count")
+    if ($(blockSize) != 1) requestedMetrics +:= "numNonZeros"
+    val (summarizer, labelSummarizer) = Summarizer
+      .getClassificationSummarizers(instances, $(aggregationDepth), requestedMetrics)
 
     val histogram = labelSummarizer.histogram
     val numInvalid = labelSummarizer.countInvalid
     val numFeatures = summarizer.mean.size
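A note on the Seq idiom in the added lines: on a `var`, `+:=` prepends an element, so the optional metric lands at the front of the requested list. A minimal illustration:

    var requestedMetrics = Seq("mean", "std", "count")
    requestedMetrics +:= "numNonZeros"
    // desugars to: requestedMetrics = "numNonZeros" +: requestedMetrics
    // requestedMetrics == Seq("numNonZeros", "mean", "std", "count")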
@@ -316,7 +314,7 @@ class LinearSVC @Since("2.2.0") (
     }
     val blocks = InstanceBlock.blokify(standardized, $(blockSize))
       .persist(StorageLevel.MEMORY_AND_DISK)
-      .setName(s"training dataset (blockSize=${$(blockSize)})")
+      .setName(s"training blocks (blockSize=${$(blockSize)})")
 
     val getAggregatorFunc = new BlockHingeAggregator($(fitIntercept))(_)
     val costFun = new RDDLossFunction(blocks, getAggregatorFunc,
@@ -517,17 +517,18 @@ class LogisticRegression @Since("1.2.0") (
       probabilityCol, regParam, elasticNetParam, standardization, threshold, maxIter, tol,
       fitIntercept, blockSize)
 
-    val instances = extractInstances(dataset).setName("training instances")
-    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+    val instances = extractInstances(dataset)
+      .setName("training instances")
 
-    val (summarizer, labelSummarizer) = if ($(blockSize) == 1) {
-      Summarizer.getClassificationSummarizers(instances, $(aggregationDepth))
-    } else {
-      // instances will be standardized and converted to blocks, so no need to cache instances.
-      Summarizer.getClassificationSummarizers(instances, $(aggregationDepth),
-        Seq("mean", "std", "count", "numNonZeros"))
-    }
+    if (handlePersistence && $(blockSize) == 1) {
+      instances.persist(StorageLevel.MEMORY_AND_DISK)
+    }
+
+    var requestedMetrics = Seq("mean", "std", "count")
+    if ($(blockSize) != 1) requestedMetrics +:= "numNonZeros"
+    val (summarizer, labelSummarizer) = Summarizer
+      .getClassificationSummarizers(instances, $(aggregationDepth), requestedMetrics)
 
     val numFeatures = summarizer.mean.size
     val histogram = labelSummarizer.histogram
     val numInvalid = labelSummarizer.countInvalid
@@ -591,7 +592,7 @@ class LogisticRegression @Since("1.2.0") (
       } else {
         Vectors.dense(if (numClasses == 2) Double.PositiveInfinity else Double.NegativeInfinity)
       }
-      if (handlePersistence) instances.unpersist()
+      if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist()
       return createModel(dataset, numClasses, coefMatrix, interceptVec, Array.empty)
     }
 
@@ -650,7 +651,7 @@ class LogisticRegression @Since("1.2.0") (
         trainOnBlocks(instances, featuresStd, numClasses, initialCoefWithInterceptMatrix,
           regularization, optimizer)
       }
-    if (handlePersistence) instances.unpersist()
+    if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist()
 
     if (allCoefficients == null) {
       val msg = s"${optimizer.getClass.getName} failed."
@@ -1002,7 +1003,7 @@ class LogisticRegression @Since("1.2.0") (
     }
     val blocks = InstanceBlock.blokify(standardized, $(blockSize))
       .persist(StorageLevel.MEMORY_AND_DISK)
-      .setName(s"training dataset (blockSize=${$(blockSize)})")
+      .setName(s"training blocks (blockSize=${$(blockSize)})")
 
     val getAggregatorFunc = new BlockLogisticAggregator(numFeatures, numClasses, $(fitIntercept),
       checkMultinomial(numClasses))(_)
@@ -155,8 +155,102 @@ private[ml] class AFTAggregator(
     }
     gradientSumArray(dim - 2) += { if (fitIntercept) multiplier else 0.0 }
     gradientSumArray(dim - 1) += delta + multiplier * sigma * epsilon
 
     weightSum += 1.0
 
     this
   }
 }
+
+
+/**
+ * BlockAFTAggregator computes the gradient and loss as used in AFT survival regression
+ * for blocks in sparse or dense matrix in an online fashion.
+ *
+ * Two BlockAFTAggregators can be merged together to have a summary of loss and gradient of
+ * the corresponding joint dataset.
+ *
+ * NOTE: The feature values are expected to be standardized before computation.
+ *
+ * @param bcCoefficients The coefficients corresponding to the features.
+ * @param fitIntercept Whether to fit an intercept term.
+ */
+private[ml] class BlockAFTAggregator(
+    fitIntercept: Boolean)(bcCoefficients: Broadcast[Vector])
+  extends DifferentiableLossAggregator[(Matrix, Array[Double], Array[Double]),
+    BlockAFTAggregator] {
+
+  protected override val dim: Int = bcCoefficients.value.size
+  private val numFeatures = dim - 2
+
+  @transient private lazy val coefficientsArray = bcCoefficients.value match {
+    case DenseVector(values) => values
+    case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector" +
+      s" but got type ${bcCoefficients.value.getClass}.")
+  }
+
+  @transient private lazy val linear = Vectors.dense(coefficientsArray.take(numFeatures))
+
+  /**
+   * Add a new training instance block to this BlockAFTAggregator, and update the loss and
+   * gradient of the objective function.
+   *
+   * @return This BlockAFTAggregator object.
+   */
+  def add(block: (Matrix, Array[Double], Array[Double])): this.type = {
+    val (matrix, labels, censors) = block
+    require(matrix.isTransposed)
+    require(numFeatures == matrix.numCols, s"Dimensions mismatch when adding new " +
+      s"instance. Expecting $numFeatures but got ${matrix.numCols}.")
+    require(labels.forall(_ > 0.0), "The lifetime or label should be greater than 0.")
+
+    val size = matrix.numRows
+    require(labels.length == size && censors.length == size)
+
+    val intercept = coefficientsArray(dim - 2)
+    // sigma is the scale parameter of the AFT model
+    val sigma = math.exp(coefficientsArray(dim - 1))
+
+    // vec here represents margins
+    val vec = if (fitIntercept) {
+      Vectors.dense(Array.fill(size)(intercept)).toDense
+    } else {
+      Vectors.zeros(size).toDense
+    }
+    BLAS.gemv(1.0, matrix, linear, 1.0, vec)
+
+    // in-place convert margins to gradient scales
+    // then, vec represents gradient scales
+    var i = 0
+    var sigmaGradSum = 0.0
+    while (i < size) {
+      val ti = labels(i)
+      val delta = censors(i)
+      val margin = vec(i)
+      val epsilon = (math.log(ti) - margin) / sigma
+      val expEpsilon = math.exp(epsilon)
+      lossSum += delta * math.log(sigma) - delta * epsilon + expEpsilon
+      val multiplier = (delta - expEpsilon) / sigma
+      vec.values(i) = multiplier
+      sigmaGradSum += delta + multiplier * sigma * epsilon
+      i += 1
+    }
+
+    matrix match {
+      case dm: DenseMatrix =>
+        BLAS.nativeBLAS.dgemv("N", dm.numCols, dm.numRows, 1.0, dm.values, dm.numCols,
+          vec.values, 1, 1.0, gradientSumArray, 1)
+
+      case sm: SparseMatrix =>
+        val linearGradSumVec = Vectors.zeros(numFeatures).toDense
+        BLAS.gemv(1.0, sm.transpose, vec, 0.0, linearGradSumVec)
+        BLAS.getBLAS(numFeatures).daxpy(numFeatures, 1.0, linearGradSumVec.values, 1,
+          gradientSumArray, 1)
+    }
+
+    if (fitIntercept) gradientSumArray(dim - 2) += vec.values.sum
+    gradientSumArray(dim - 1) += sigmaGradSum
+    weightSum += size
+
+    this
+  }
+}
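For reference, the loop body above accumulates the AFT negative log-likelihood and its gradient. Writing m_i for the margin of instance i, t_i for its lifetime, \delta_i for its censor indicator, and \sigma for the scale parameter, the quantities in the code are:

    \varepsilon_i = \frac{\log t_i - m_i}{\sigma}, \qquad
    \ell_i = \delta_i \log\sigma - \delta_i \varepsilon_i + e^{\varepsilon_i}

    \mathrm{multiplier}_i = \frac{\partial \ell_i}{\partial m_i}
      = \frac{\delta_i - e^{\varepsilon_i}}{\sigma}, \qquad
    \frac{\partial \ell_i}{\partial \log\sigma}
      = \delta_i + \mathrm{multiplier}_i \cdot \sigma \cdot \varepsilon_i

The closing GEMV then forms X^T times the multiplier vector in a single BLAS call, which is the coefficient part of the block's gradient.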
@@ -27,8 +27,9 @@ import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.PredictorParams
-import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT}
-import org.apache.spark.ml.optim.aggregator.AFTAggregator
+import org.apache.spark.ml.feature.StandardScalerModel
+import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.optim.aggregator._
 import org.apache.spark.ml.optim.loss.RDDLossFunction
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -46,7 +47,8 @@ import org.apache.spark.storage.StorageLevel
  * Params for accelerated failure time (AFT) regression.
  */
 private[regression] trait AFTSurvivalRegressionParams extends PredictorParams
-  with HasMaxIter with HasTol with HasFitIntercept with HasAggregationDepth with Logging {
+  with HasMaxIter with HasTol with HasFitIntercept with HasAggregationDepth with HasBlockSize
+  with Logging {
 
   /**
    * Param for censor column name.
@@ -183,6 +185,25 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
   def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
   setDefault(aggregationDepth -> 2)
 
+  /**
+   * Set block size for stacking input data in matrices.
+   * If blockSize == 1, then stacking will be skipped, and each vector is treated individually;
+   * If blockSize > 1, then vectors will be stacked to blocks, and high-level BLAS routines
+   * will be used if possible (for example, GEMV instead of DOT, GEMM instead of GEMV).
+   * Recommended size is between 10 and 1000. An appropriate choice of the block size depends
+   * on the sparsity and dim of input datasets, the underlying BLAS implementation (for example,
+   * f2jBLAS, OpenBLAS, intel MKL) and its configuration (for example, number of threads).
+   * Note that existing BLAS implementations are mainly optimized for dense matrices, if the
+   * input dataset is sparse, stacking may bring no performance gain, the worse is possible
+   * performance regression.
+   * Default is 1.
+   *
+   * @group expertSetParam
+   */
+  @Since("3.1.0")
+  def setBlockSize(value: Int): this.type = set(blockSize, value)
+  setDefault(blockSize -> 1)
+
   /**
    * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset,
    * and put it in an RDD with strong types.
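A usage sketch for the new param (the `training` DataFrame and the chosen block size are illustrative, not from the patch):

    import org.apache.spark.ml.regression.AFTSurvivalRegression

    val aft = new AFTSurvivalRegression()
      .setCensorCol("censor")
      .setBlockSize(64)  // stack 64 standardized vectors per matrix block
    val model = aft.fit(training)

With the default blockSize = 1 the estimator keeps the original per-row code path, so existing jobs are unaffected.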
@@ -197,39 +218,50 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
 
   override protected def train(
       dataset: Dataset[_]): AFTSurvivalRegressionModel = instrumented { instr =>
-    val instances = extractAFTPoints(dataset)
-    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
-    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+    instr.logPipelineStage(this)
+    instr.logDataset(dataset)
+    instr.logParams(this, labelCol, featuresCol, censorCol, predictionCol, quantilesCol,
+      fitIntercept, maxIter, tol, aggregationDepth, blockSize)
+    instr.logNamedValue("quantileProbabilities.size", $(quantileProbabilities).length)
 
-    val featuresSummarizer = instances.treeAggregate(
-      Summarizer.createSummarizerBuffer("mean", "std", "count"))(
+    val instances = extractAFTPoints(dataset)
+      .setName("training instances")
+
+    if ($(blockSize) == 1 && dataset.storageLevel == StorageLevel.NONE) {
+      instances.persist(StorageLevel.MEMORY_AND_DISK)
+    }
+
+    var requestedMetrics = Seq("mean", "std", "count")
+    if ($(blockSize) != 1) requestedMetrics +:= "numNonZeros"
+    val summarizer = instances.treeAggregate(
+      Summarizer.createSummarizerBuffer(requestedMetrics: _*))(
       seqOp = (c: SummarizerBuffer, v: AFTPoint) => c.add(v.features),
       combOp = (c1: SummarizerBuffer, c2: SummarizerBuffer) => c1.merge(c2),
       depth = $(aggregationDepth)
     )
 
-    val featuresStd = featuresSummarizer.std.toArray
+    val featuresStd = summarizer.std.toArray
     val numFeatures = featuresStd.length
 
-    instr.logPipelineStage(this)
-    instr.logDataset(dataset)
-    instr.logParams(this, labelCol, featuresCol, censorCol, predictionCol, quantilesCol,
-      fitIntercept, maxIter, tol, aggregationDepth)
-    instr.logNamedValue("quantileProbabilities.size", $(quantileProbabilities).length)
     instr.logNumFeatures(numFeatures)
-    instr.logNumExamples(featuresSummarizer.count)
+    instr.logNumExamples(summarizer.count)
+    if ($(blockSize) > 1) {
+      val scale = 1.0 / summarizer.count / numFeatures
+      val sparsity = 1 - summarizer.numNonzeros.toArray.map(_ * scale).sum
+      instr.logNamedValue("sparsity", sparsity.toString)
+      if (sparsity > 0.5) {
+        instr.logWarning(s"sparsity of input dataset is $sparsity, " +
+          s"which may hurt performance in high-level BLAS.")
+      }
+    }
 
     if (!$(fitIntercept) && (0 until numFeatures).exists { i =>
-      featuresStd(i) == 0.0 && featuresSummarizer.mean(i) != 0.0 }) {
+      featuresStd(i) == 0.0 && summarizer.mean(i) != 0.0 }) {
       instr.logWarning("Fitting AFTSurvivalRegressionModel without intercept on dataset with " +
         "constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero " +
         "columns. This behavior is different from R survival::survreg.")
     }
 
-    val bcFeaturesStd = instances.context.broadcast(featuresStd)
-    val getAggregatorFunc = new AFTAggregator(bcFeaturesStd, $(fitIntercept))(_)
     val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
-    val costFun = new RDDLossFunction(instances, getAggregatorFunc, None, $(aggregationDepth))
 
     /*
        The parameters vector has three parts:
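The sparsity check above is just the complement of the average nonzero fraction. A minimal sketch of the same arithmetic on hypothetical summarizer output (3 columns, 100 rows; the numbers are illustrative):

    val numNonzeros = Array(100.0, 20.0, 10.0)  // nonzero count per column
    val count = 100L                            // number of rows
    val numFeatures = numNonzeros.length
    val scale = 1.0 / count / numFeatures
    val sparsity = 1 - numNonzeros.map(_ * scale).sum
    // sparsity = 1 - 130.0 / 300 ≈ 0.567 > 0.5, so the warning would be logged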
@@ -239,36 +271,86 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
      */
     val initialParameters = Vectors.zeros(numFeatures + 2)
 
-    val states = optimizer.iterations(new CachedDiffFunction(costFun),
-      initialParameters.asBreeze.toDenseVector)
-
-    val parameters = {
-      val arrayBuilder = mutable.ArrayBuilder.make[Double]
-      var state: optimizer.State = null
-      while (states.hasNext) {
-        state = states.next()
-        arrayBuilder += state.adjustedValue
-      }
-      if (state == null) {
-        val msg = s"${optimizer.getClass.getName} failed."
-        throw new SparkException(msg)
-      }
-      state.x.toArray.clone()
-    }
-
-    bcFeaturesStd.destroy()
-    if (handlePersistence) instances.unpersist()
-
-    val rawCoefficients = parameters.take(numFeatures)
-    var i = 0
-    while (i < numFeatures) {
-      rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
-      i += 1
-    }
-
-    val coefficients = Vectors.dense(rawCoefficients)
-    val intercept = parameters(numFeatures)
-    val scale = math.exp(parameters(numFeatures + 1))
-    new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale)
+    val (rawCoefficients, objectiveHistory) = if ($(blockSize) == 1) {
+      trainOnRows(instances, featuresStd, optimizer, initialParameters)
+    } else {
+      trainOnBlocks(instances, featuresStd, optimizer, initialParameters)
+    }
+    if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist()
+
+    if (rawCoefficients == null) {
+      val msg = s"${optimizer.getClass.getName} failed."
+      instr.logError(msg)
+      throw new SparkException(msg)
+    }
+
+    val coefficientArray = Array.tabulate(numFeatures) { i =>
+      if (featuresStd(i) != 0) rawCoefficients(i) / featuresStd(i) else 0.0
+    }
+    val coefficients = Vectors.dense(coefficientArray)
+    val intercept = rawCoefficients(numFeatures)
+    val scale = math.exp(rawCoefficients(numFeatures + 1))
+    new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale)
+  }
+
+  private def trainOnRows(
+      instances: RDD[AFTPoint],
+      featuresStd: Array[Double],
+      optimizer: BreezeLBFGS[BDV[Double]],
+      initialParameters: Vector): (Array[Double], Array[Double]) = {
+    val bcFeaturesStd = instances.context.broadcast(featuresStd)
+    val getAggregatorFunc = new AFTAggregator(bcFeaturesStd, $(fitIntercept))(_)
+    val costFun = new RDDLossFunction(instances, getAggregatorFunc, None, $(aggregationDepth))
+
+    val states = optimizer.iterations(new CachedDiffFunction(costFun),
+      initialParameters.asBreeze.toDenseVector)
+
+    val arrayBuilder = mutable.ArrayBuilder.make[Double]
+    var state: optimizer.State = null
+    while (states.hasNext) {
+      state = states.next()
+      arrayBuilder += state.adjustedValue
+    }
+
+    bcFeaturesStd.destroy()
+
+    (if (state != null) state.x.toArray else null, arrayBuilder.result)
+  }
+
+  private def trainOnBlocks(
+      instances: RDD[AFTPoint],
+      featuresStd: Array[Double],
+      optimizer: BreezeLBFGS[BDV[Double]],
+      initialParameters: Vector): (Array[Double], Array[Double]) = {
+    val bcFeaturesStd = instances.context.broadcast(featuresStd)
+    val blocks = instances.mapPartitions { iter =>
+      val inverseStd = bcFeaturesStd.value.map { std => if (std != 0) 1.0 / std else 0.0 }
+      val func = StandardScalerModel.getTransformFunc(Array.empty, inverseStd, false, true)
+      iter.grouped($(blockSize)).map { seq =>
+        val matrix = Matrices.fromVectors(seq.map(point => func(point.features)))
+        val labels = seq.map(_.label).toArray
+        val censors = seq.map(_.censor).toArray
+        (matrix, labels, censors)
+      }
+    }
+    blocks.persist(StorageLevel.MEMORY_AND_DISK)
+      .setName(s"training blocks (blockSize=${$(blockSize)})")
+
+    val getAggregatorFunc = new BlockAFTAggregator($(fitIntercept))(_)
+    val costFun = new RDDLossFunction(blocks, getAggregatorFunc, None, $(aggregationDepth))
+
+    val states = optimizer.iterations(new CachedDiffFunction(costFun),
+      initialParameters.asBreeze.toDenseVector)
+
+    val arrayBuilder = mutable.ArrayBuilder.make[Double]
+    var state: optimizer.State = null
+    while (states.hasNext) {
+      state = states.next()
+      arrayBuilder += state.adjustedValue
+    }
+    blocks.unpersist()
+    bcFeaturesStd.destroy()
+
+    (if (state != null) state.x.toArray else null, arrayBuilder.result)
   }
 
   @Since("1.6.0")
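Since the blocks are standardized up front, the coefficients come out in the scaled feature space and are mapped back by dividing by each feature's standard deviation, mirroring the `coefficientArray` computation in train() above. A minimal sketch with illustrative values:

    val featuresStd = Array(2.0, 0.0, 4.0)
    // raw solution layout: coefficients, then intercept, then log(scale)
    val rawCoefficients = Array(1.0, 5.0, -2.0, 0.3, 0.0)
    val numFeatures = featuresStd.length
    val coefficients = Array.tabulate(numFeatures) { i =>
      if (featuresStd(i) != 0) rawCoefficients(i) / featuresStd(i) else 0.0
    }
    // coefficients == Array(0.5, 0.0, -0.5): constant columns get a zero weight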
@@ -355,17 +355,15 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     val instances = extractInstances(dataset)
       .setName("training instances")
 
-    val (featuresSummarizer, ySummarizer) = if ($(blockSize) == 1) {
-      if (dataset.storageLevel == StorageLevel.NONE) {
-        instances.persist(StorageLevel.MEMORY_AND_DISK)
-      }
-      Summarizer.getRegressionSummarizers(instances, $(aggregationDepth))
-    } else {
-      // instances will be standardized and converted to blocks, so no need to cache instances.
-      Summarizer.getRegressionSummarizers(instances, $(aggregationDepth),
-        Seq("mean", "std", "count", "numNonZeros"))
-    }
+    if (dataset.storageLevel == StorageLevel.NONE && $(blockSize) == 1) {
+      instances.persist(StorageLevel.MEMORY_AND_DISK)
+    }
+
+    var requestedMetrics = Seq("mean", "std", "count")
+    if ($(blockSize) != 1) requestedMetrics +:= "numNonZeros"
+    val (featuresSummarizer, ySummarizer) = Summarizer
+      .getRegressionSummarizers(instances, $(aggregationDepth), requestedMetrics)
 
     val yMean = ySummarizer.mean(0)
     val rawYStd = ySummarizer.std(0)
 
@@ -617,7 +615,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     }
     val blocks = InstanceBlock.blokify(standardized, $(blockSize))
       .persist(StorageLevel.MEMORY_AND_DISK)
-      .setName(s"training dataset (blockSize=${$(blockSize)})")
+      .setName(s"training blocks (blockSize=${$(blockSize)})")
 
     val costFun = $(loss) match {
       case SquaredError =>
@@ -428,6 +428,22 @@ class AFTSurvivalRegressionSuite extends MLTest with DefaultReadWriteTest {
     val trainer = new AFTSurvivalRegression()
     trainer.fit(dataset)
   }
 
+  test("AFTSurvivalRegression on blocks") {
+    val quantileProbabilities = Array(0.1, 0.5, 0.9)
+    for (dataset <- Seq(datasetUnivariate, datasetUnivariateScaled, datasetMultivariate)) {
+      val aft = new AFTSurvivalRegression()
+        .setQuantileProbabilities(quantileProbabilities)
+        .setQuantilesCol("quantiles")
+      val model = aft.fit(dataset)
+      Seq(4, 16, 64).foreach { blockSize =>
+        val model2 = aft.setBlockSize(blockSize).fit(dataset)
+        assert(model.coefficients ~== model2.coefficients relTol 1e-9)
+        assert(model.intercept ~== model2.intercept relTol 1e-9)
+        assert(model.scale ~== model2.scale relTol 1e-9)
+      }
+    }
+  }
+
 object AFTSurvivalRegressionSuite {
@@ -1607,7 +1607,7 @@ class GBTRegressionModel(
 
 
 class _AFTSurvivalRegressionParams(_PredictorParams, HasMaxIter, HasTol, HasFitIntercept,
-                                   HasAggregationDepth):
+                                   HasAggregationDepth, HasBlockSize):
     """
     Params for :py:class:`AFTSurvivalRegression` and :py:class:`AFTSurvivalRegressionModel`.
 
@@ -1674,6 +1674,8 @@ class AFTSurvivalRegression(_JavaRegressor, _AFTSurvivalRegressionParams,
     10
     >>> aftsr.clear(aftsr.maxIter)
     >>> model = aftsr.fit(df)
+    >>> model.getBlockSize()
+    1
     >>> model.setFeaturesCol("features")
     AFTSurvivalRegressionModel...
     >>> model.predict(Vectors.dense(6.3))
@@ -1710,19 +1712,19 @@ class AFTSurvivalRegression(_JavaRegressor, _AFTSurvivalRegressionParams,
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
                  quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),
-                 quantilesCol=None, aggregationDepth=2):
+                 quantilesCol=None, aggregationDepth=2, blockSize=1):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
                  quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
-                 quantilesCol=None, aggregationDepth=2)
+                 quantilesCol=None, aggregationDepth=2, blockSize=1)
         """
         super(AFTSurvivalRegression, self).__init__()
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid)
         self._setDefault(censorCol="censor",
                          quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99],
-                         maxIter=100, tol=1E-6)
+                         maxIter=100, tol=1E-6, blockSize=1)
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
@@ -1731,12 +1733,12 @@ class AFTSurvivalRegression(_JavaRegressor, _AFTSurvivalRegressionParams,
     def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                   fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
                   quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),
-                  quantilesCol=None, aggregationDepth=2):
+                  quantilesCol=None, aggregationDepth=2, blockSize=1):
         """
         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
                   quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
-                  quantilesCol=None, aggregationDepth=2):
+                  quantilesCol=None, aggregationDepth=2, blockSize=1):
         """
         kwargs = self._input_kwargs
         return self._set(**kwargs)
@@ -1793,6 +1795,13 @@ class AFTSurvivalRegression(_JavaRegressor, _AFTSurvivalRegressionParams,
         """
         return self._set(aggregationDepth=value)
 
+    @since("3.1.0")
+    def setBlockSize(self, value):
+        """
+        Sets the value of :py:attr:`blockSize`.
+        """
+        return self._set(blockSize=value)
+
 
 class AFTSurvivalRegressionModel(_JavaRegressionModel, _AFTSurvivalRegressionParams,
                                  JavaMLWritable, JavaMLReadable):