[SPARK-33111][ML][FOLLOW-UP] aft transform optimization - predictQuantiles

### What changes were proposed in this pull request?
Optimize `predictQuantiles` by pre-computing an auxiliary variable.

### Why are the changes needed?
In https://github.com/apache/spark/pull/30000, I optimized the `transform` method. I found that `predictQuantiles` can also be optimized by pre-computing an auxiliary variable.

It is about 56% faster than the existing implementation.
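For illustration only (not part of the patch): a minimal standalone sketch of the rewrite. Since the Weibull shape parameter is `k = 1 / scale`, the per-probability factor `exp(log(-log1p(-q)) * scale)` does not depend on the input row and can be pre-computed once; per row, only a multiplication by `lambda` remains. The object name and sample values below are made up.

```scala
object QuantilePrecomputeSketch {
  def main(args: Array[String]): Unit = {
    val scale = 0.5                             // hypothetical Weibull scale parameter of a fitted model
    val k = 1 / scale                           // shape parameter used by the old per-row formula
    val probabilities = Array(0.1, 0.5, 0.9)    // quantileProbabilities

    // Pre-computed once, independently of the input row.
    val baseQuantiles = probabilities.map(q => math.exp(math.log(-math.log1p(-q)) * scale))

    // Per row, only a multiplication by lambda remains.
    val lambda = 2.0                            // hypothetical exp(coefficients . features + intercept)
    val fast = baseQuantiles.map(_ * lambda)

    // Old formula, evaluated per row and per probability.
    val slow = probabilities.map(q => lambda * math.exp(math.log(-math.log1p(-q)) / k))

    assert(fast.zip(slow).forall { case (a, b) => math.abs(a - b) < 1e-12 })
    println(fast.mkString(", "))
  }
}
```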

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing test suites.

Closes #30034 from zhengruifeng/aft_quantiles_opt.

Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
zhengruifeng authored on 2020-10-21 08:49:25 -05:00; committed by Sean Owen
parent dcb0820433
commit 618695b78f
2 changed files with 25 additions and 19 deletions


@@ -383,22 +383,32 @@ class AFTSurvivalRegressionModel private[ml] (
   /** @group setParam */
   @Since("1.6.0")
-  def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value)
+  def setQuantileProbabilities(value: Array[Double]): this.type = {
+    set(quantileProbabilities, value)
+    _quantiles(0) = $(quantileProbabilities).map(q => math.exp(math.log(-math.log1p(-q)) * scale))
+    this
+  }
 
   /** @group setParam */
   @Since("1.6.0")
   def setQuantilesCol(value: String): this.type = set(quantilesCol, value)
 
+  private lazy val _quantiles = {
+    Array($(quantileProbabilities).map(q => math.exp(math.log(-math.log1p(-q)) * scale)))
+  }
+
+  private def lambda2Quantiles(lambda: Double): Vector = {
+    val quantiles = _quantiles(0).clone()
+    var i = 0
+    while (i < quantiles.length) { quantiles(i) *= lambda; i += 1 }
+    Vectors.dense(quantiles)
+  }
+
   @Since("2.0.0")
   def predictQuantiles(features: Vector): Vector = {
-    // scale parameter for the Weibull distribution of lifetime
-    val lambda = math.exp(BLAS.dot(coefficients, features) + intercept)
-    // shape parameter for the Weibull distribution of lifetime
-    val k = 1 / scale
-    val quantiles = $(quantileProbabilities).map {
-      q => lambda * math.exp(math.log(-math.log1p(-q)) / k)
-    }
-    Vectors.dense(quantiles)
+    val lambda = predict(features)
+    lambda2Quantiles(lambda)
   }
 
   @Since("2.0.0")
@@ -414,24 +424,20 @@ class AFTSurvivalRegressionModel private[ml] (
     var predictionColumns = Seq.empty[Column]
 
     if ($(predictionCol).nonEmpty) {
-      val predictUDF = udf { features: Vector => predict(features) }
+      val predCol = udf(predict _).apply(col($(featuresCol)))
       predictionColNames :+= $(predictionCol)
-      predictionColumns :+= predictUDF(col($(featuresCol)))
+      predictionColumns :+= predCol
         .as($(predictionCol), outputSchema($(predictionCol)).metadata)
     }
 
     if (hasQuantilesCol) {
-      val baseQuantiles = $(quantileProbabilities)
-        .map(q => math.exp(math.log(-math.log1p(-q)) * scale))
-      val lambdaCol = if ($(predictionCol).nonEmpty) {
-        predictionColumns.head
+      val quanCol = if ($(predictionCol).nonEmpty) {
+        udf(lambda2Quantiles _).apply(predictionColumns.head)
       } else {
-        udf { features: Vector => predict(features) }.apply(col($(featuresCol)))
+        udf(predictQuantiles _).apply(col($(featuresCol)))
       }
-      val predictQuantilesUDF =
-        udf { lambda: Double => Vectors.dense(baseQuantiles.map(q => q * lambda)) }
       predictionColNames :+= $(quantilesCol)
-      predictionColumns :+= predictQuantilesUDF(lambdaCol)
+      predictionColumns :+= quanCol
        .as($(quantilesCol), outputSchema($(quantilesCol)).metadata)
    }
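For context only (not part of the patch): a hedged usage sketch of the public API these hunks touch. It assumes Spark is on the classpath; the tiny dataset is made up and only mirrors the shape of the usual AFT example (`label`, `censor`, `features`).

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.AFTSurvivalRegression
import org.apache.spark.sql.SparkSession

object AftQuantilesUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("aft-sketch").getOrCreate()
    import spark.implicits._

    // Made-up survival data: (label, censor, features).
    val training = Seq(
      (1.218, 1.0, Vectors.dense(1.560, -0.605)),
      (2.949, 0.0, Vectors.dense(0.346, 2.158)),
      (3.627, 0.0, Vectors.dense(1.380, 0.231)),
      (0.273, 1.0, Vectors.dense(0.520, 1.151)),
      (4.199, 0.0, Vectors.dense(0.795, -0.226))
    ).toDF("label", "censor", "features")

    val model = new AFTSurvivalRegression()
      .setQuantilesCol("quantiles")
      .fit(training)

    // Quantile probabilities can also be set on the fitted model, as the updated test does.
    model.setQuantileProbabilities(Array(0.1, 0.5, 0.9))
    model.transform(training).select("prediction", "quantiles").show(false)

    spark.stop()
  }
}
```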


@@ -130,9 +130,9 @@ class AFTSurvivalRegressionSuite extends MLTest with DefaultReadWriteTest {
   test("aft survival regression with univariate") {
     val quantileProbabilities = Array(0.1, 0.5, 0.9)
     val trainer = new AFTSurvivalRegression()
-      .setQuantileProbabilities(quantileProbabilities)
       .setQuantilesCol("quantiles")
     val model = trainer.fit(datasetUnivariate)
+    model.setQuantileProbabilities(quantileProbabilities)
 
     /*
        Using the following R code to load the data and train the model using survival package.