[SPARK-33111][ML][FOLLOW-UP] aft transform optimization - predictQuantiles
### What changes were proposed in this pull request? 1, optimize `predictQuantiles` by pre-computing an auxiliary var. ### Why are the changes needed? In https://github.com/apache/spark/pull/30000, I optimized the `transform` method. I find that we can also optimize `predictQuantiles` by pre-computing an auxiliary var. It is about 56% faster than existing impl. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? existing testsuites Closes #30034 from zhengruifeng/aft_quantiles_opt. Authored-by: zhengruifeng <ruifengz@foxmail.com> Signed-off-by: Sean Owen <srowen@gmail.com>
This commit is contained in:
parent
dcb0820433
commit
618695b78f
|
@ -383,22 +383,32 @@ class AFTSurvivalRegressionModel private[ml] (
|
|||
|
||||
/** @group setParam */
|
||||
@Since("1.6.0")
|
||||
def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value)
|
||||
def setQuantileProbabilities(value: Array[Double]): this.type = {
|
||||
set(quantileProbabilities, value)
|
||||
_quantiles(0) = $(quantileProbabilities).map(q => math.exp(math.log(-math.log1p(-q)) * scale))
|
||||
this
|
||||
}
|
||||
|
||||
/** @group setParam */
|
||||
@Since("1.6.0")
|
||||
def setQuantilesCol(value: String): this.type = set(quantilesCol, value)
|
||||
|
||||
private lazy val _quantiles = {
|
||||
Array($(quantileProbabilities).map(q => math.exp(math.log(-math.log1p(-q)) * scale)))
|
||||
}
|
||||
|
||||
private def lambda2Quantiles(lambda: Double): Vector = {
|
||||
val quantiles = _quantiles(0).clone()
|
||||
var i = 0
|
||||
while (i < quantiles.length) { quantiles(i) *= lambda; i += 1 }
|
||||
Vectors.dense(quantiles)
|
||||
}
|
||||
|
||||
@Since("2.0.0")
|
||||
def predictQuantiles(features: Vector): Vector = {
|
||||
// scale parameter for the Weibull distribution of lifetime
|
||||
val lambda = math.exp(BLAS.dot(coefficients, features) + intercept)
|
||||
// shape parameter for the Weibull distribution of lifetime
|
||||
val k = 1 / scale
|
||||
val quantiles = $(quantileProbabilities).map {
|
||||
q => lambda * math.exp(math.log(-math.log1p(-q)) / k)
|
||||
}
|
||||
Vectors.dense(quantiles)
|
||||
val lambda = predict(features)
|
||||
lambda2Quantiles(lambda)
|
||||
}
|
||||
|
||||
@Since("2.0.0")
|
||||
|
@ -414,24 +424,20 @@ class AFTSurvivalRegressionModel private[ml] (
|
|||
var predictionColumns = Seq.empty[Column]
|
||||
|
||||
if ($(predictionCol).nonEmpty) {
|
||||
val predictUDF = udf { features: Vector => predict(features) }
|
||||
val predCol = udf(predict _).apply(col($(featuresCol)))
|
||||
predictionColNames :+= $(predictionCol)
|
||||
predictionColumns :+= predictUDF(col($(featuresCol)))
|
||||
predictionColumns :+= predCol
|
||||
.as($(predictionCol), outputSchema($(predictionCol)).metadata)
|
||||
}
|
||||
|
||||
if (hasQuantilesCol) {
|
||||
val baseQuantiles = $(quantileProbabilities)
|
||||
.map(q => math.exp(math.log(-math.log1p(-q)) * scale))
|
||||
val lambdaCol = if ($(predictionCol).nonEmpty) {
|
||||
predictionColumns.head
|
||||
val quanCol = if ($(predictionCol).nonEmpty) {
|
||||
udf(lambda2Quantiles _).apply(predictionColumns.head)
|
||||
} else {
|
||||
udf { features: Vector => predict(features) }.apply(col($(featuresCol)))
|
||||
udf(predictQuantiles _).apply(col($(featuresCol)))
|
||||
}
|
||||
val predictQuantilesUDF =
|
||||
udf { lambda: Double => Vectors.dense(baseQuantiles.map(q => q * lambda)) }
|
||||
predictionColNames :+= $(quantilesCol)
|
||||
predictionColumns :+= predictQuantilesUDF(lambdaCol)
|
||||
predictionColumns :+= quanCol
|
||||
.as($(quantilesCol), outputSchema($(quantilesCol)).metadata)
|
||||
}
|
||||
|
||||
|
|
|
@ -130,9 +130,9 @@ class AFTSurvivalRegressionSuite extends MLTest with DefaultReadWriteTest {
|
|||
test("aft survival regression with univariate") {
|
||||
val quantileProbabilities = Array(0.1, 0.5, 0.9)
|
||||
val trainer = new AFTSurvivalRegression()
|
||||
.setQuantileProbabilities(quantileProbabilities)
|
||||
.setQuantilesCol("quantiles")
|
||||
val model = trainer.fit(datasetUnivariate)
|
||||
model.setQuantileProbabilities(quantileProbabilities)
|
||||
|
||||
/*
|
||||
Using the following R code to load the data and train the model using survival package.
|
||||
|
|
Loading…
Reference in a new issue