[SPARK-29754][ML] LoR/AFT/LiR/SVC use Summarizer instead of MultivariateOnlineSummarizer
### What changes were proposed in this pull request? 1, change the scope of `ml.SummarizerBuffer` and add a method `createSummarizerBuffer` for it, so it can be used as an aggregator like `MultivariateOnlineSummarizer`; 2, In LoR/AFT/LiR/SVC, use Summarizer instead of MultivariateOnlineSummarizer ### Why are the changes needed? The computation of the summary before the learning iterations is a bottleneck in high-dimension cases, since `MultivariateOnlineSummarizer` computes much more than needed. The [ticket](https://issues.apache.org/jira/browse/SPARK-29754) contains an example: with `--driver-memory=4G` LoR will always fail on the KDDA dataset. If we switch to `ml.Summarizer`, then `--driver-memory=3G` is enough to train a model. ### Does this PR introduce any user-facing change? No ### How was this patch tested? existing test suites & manual test in REPL Closes #26396 from zhengruifeng/using_SummarizerBuffer. Authored-by: zhengruifeng <ruifengz@foxmail.com> Signed-off-by: zhengruifeng <ruifengz@foxmail.com>
This commit is contained in:
parent
0dcd739534
commit
5853e8b330
|
@ -32,10 +32,9 @@ import org.apache.spark.ml.optim.aggregator.HingeAggregator
|
|||
import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.param.shared._
|
||||
import org.apache.spark.ml.stat.SummaryBuilderImpl._
|
||||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.ml.util.Instrumentation.instrumented
|
||||
import org.apache.spark.mllib.linalg.VectorImplicits._
|
||||
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
|
||||
import org.apache.spark.sql.{Dataset, Row}
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
|
||||
|
@ -170,19 +169,15 @@ class LinearSVC @Since("2.2.0") (
|
|||
instr.logParams(this, labelCol, weightCol, featuresCol, predictionCol, rawPredictionCol,
|
||||
regParam, maxIter, fitIntercept, tol, standardization, threshold, aggregationDepth)
|
||||
|
||||
val (summarizer, labelSummarizer) = {
|
||||
val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
|
||||
instance: Instance) =>
|
||||
(c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight))
|
||||
|
||||
val combOp = (c1: (MultivariateOnlineSummarizer, MultiClassSummarizer),
|
||||
c2: (MultivariateOnlineSummarizer, MultiClassSummarizer)) =>
|
||||
(c1._1.merge(c2._1), c1._2.merge(c2._2))
|
||||
|
||||
instances.treeAggregate(
|
||||
(new MultivariateOnlineSummarizer, new MultiClassSummarizer)
|
||||
)(seqOp, combOp, $(aggregationDepth))
|
||||
}
|
||||
val (summarizer, labelSummarizer) = instances.treeAggregate(
|
||||
(createSummarizerBuffer("mean", "variance", "count"), new MultiClassSummarizer))(
|
||||
seqOp = (c: (SummarizerBuffer, MultiClassSummarizer), instance: Instance) =>
|
||||
(c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight)),
|
||||
combOp = (c1: (SummarizerBuffer, MultiClassSummarizer),
|
||||
c2: (SummarizerBuffer, MultiClassSummarizer)) =>
|
||||
(c1._1.merge(c2._1), c1._2.merge(c2._2)),
|
||||
depth = $(aggregationDepth)
|
||||
)
|
||||
instr.logNumExamples(summarizer.count)
|
||||
instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString)
|
||||
instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString)
|
||||
|
|
|
@ -34,11 +34,10 @@ import org.apache.spark.ml.optim.aggregator.LogisticAggregator
|
|||
import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.param.shared._
|
||||
import org.apache.spark.ml.stat.SummaryBuilderImpl._
|
||||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.ml.util.Instrumentation.instrumented
|
||||
import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
|
||||
import org.apache.spark.mllib.linalg.VectorImplicits._
|
||||
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
|
||||
import org.apache.spark.mllib.util.MLUtils
|
||||
import org.apache.spark.sql.{DataFrame, Dataset, Row}
|
||||
import org.apache.spark.sql.functions.col
|
||||
|
@ -501,19 +500,16 @@ class LogisticRegression @Since("1.2.0") (
|
|||
probabilityCol, regParam, elasticNetParam, standardization, threshold, maxIter, tol,
|
||||
fitIntercept)
|
||||
|
||||
val (summarizer, labelSummarizer) = {
|
||||
val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
|
||||
instance: Instance) =>
|
||||
(c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight))
|
||||
val (summarizer, labelSummarizer) = instances.treeAggregate(
|
||||
(createSummarizerBuffer("mean", "variance", "count"), new MultiClassSummarizer))(
|
||||
seqOp = (c: (SummarizerBuffer, MultiClassSummarizer), instance: Instance) =>
|
||||
(c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight)),
|
||||
combOp = (c1: (SummarizerBuffer, MultiClassSummarizer),
|
||||
c2: (SummarizerBuffer, MultiClassSummarizer)) =>
|
||||
(c1._1.merge(c2._1), c1._2.merge(c2._2)),
|
||||
depth = $(aggregationDepth)
|
||||
)
|
||||
|
||||
val combOp = (c1: (MultivariateOnlineSummarizer, MultiClassSummarizer),
|
||||
c2: (MultivariateOnlineSummarizer, MultiClassSummarizer)) =>
|
||||
(c1._1.merge(c2._1), c1._2.merge(c2._2))
|
||||
|
||||
instances.treeAggregate(
|
||||
(new MultivariateOnlineSummarizer, new MultiClassSummarizer)
|
||||
)(seqOp, combOp, $(aggregationDepth))
|
||||
}
|
||||
instr.logNumExamples(summarizer.count)
|
||||
instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString)
|
||||
instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString)
|
||||
|
|
|
@ -31,10 +31,9 @@ import org.apache.spark.ml.{Estimator, Model}
|
|||
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT}
|
||||
import org.apache.spark.ml.param._
|
||||
import org.apache.spark.ml.param.shared._
|
||||
import org.apache.spark.ml.stat.SummaryBuilderImpl._
|
||||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.ml.util.Instrumentation.instrumented
|
||||
import org.apache.spark.mllib.linalg.VectorImplicits._
|
||||
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
|
||||
import org.apache.spark.mllib.util.MLUtils
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
|
||||
|
@ -215,15 +214,12 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
|
|||
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
|
||||
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
|
||||
|
||||
val featuresSummarizer = {
|
||||
val seqOp = (c: MultivariateOnlineSummarizer, v: AFTPoint) => c.add(v.features)
|
||||
val combOp = (c1: MultivariateOnlineSummarizer, c2: MultivariateOnlineSummarizer) => {
|
||||
c1.merge(c2)
|
||||
}
|
||||
instances.treeAggregate(
|
||||
new MultivariateOnlineSummarizer
|
||||
)(seqOp, combOp, $(aggregationDepth))
|
||||
}
|
||||
val featuresSummarizer = instances.treeAggregate(
|
||||
createSummarizerBuffer("mean", "variance", "count"))(
|
||||
seqOp = (c: SummarizerBuffer, v: AFTPoint) => c.add(v.features),
|
||||
combOp = (c1: SummarizerBuffer, c2: SummarizerBuffer) => c1.merge(c2),
|
||||
depth = $(aggregationDepth)
|
||||
)
|
||||
|
||||
val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
|
||||
val numFeatures = featuresStd.size
|
||||
|
|
|
@ -36,12 +36,12 @@ import org.apache.spark.ml.optim.aggregator.{HuberAggregator, LeastSquaresAggreg
|
|||
import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
|
||||
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
|
||||
import org.apache.spark.ml.param.shared._
|
||||
import org.apache.spark.ml.stat.SummaryBuilderImpl._
|
||||
import org.apache.spark.ml.util._
|
||||
import org.apache.spark.ml.util.Instrumentation.instrumented
|
||||
import org.apache.spark.mllib.evaluation.RegressionMetrics
|
||||
import org.apache.spark.mllib.linalg.VectorImplicits._
|
||||
import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel}
|
||||
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
|
||||
import org.apache.spark.mllib.util.MLUtils
|
||||
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
|
||||
import org.apache.spark.sql.functions._
|
||||
|
@ -357,20 +357,17 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
|
|||
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
|
||||
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
|
||||
|
||||
val (featuresSummarizer, ySummarizer) = {
|
||||
val seqOp = (c: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer),
|
||||
instance: Instance) =>
|
||||
val (featuresSummarizer, ySummarizer) = instances.treeAggregate(
|
||||
(createSummarizerBuffer("mean", "variance"),
|
||||
createSummarizerBuffer("mean", "variance", "count")))(
|
||||
seqOp = (c: (SummarizerBuffer, SummarizerBuffer), instance: Instance) =>
|
||||
(c._1.add(instance.features, instance.weight),
|
||||
c._2.add(Vectors.dense(instance.label), instance.weight))
|
||||
|
||||
val combOp = (c1: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer),
|
||||
c2: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer)) =>
|
||||
(c1._1.merge(c2._1), c1._2.merge(c2._2))
|
||||
|
||||
instances.treeAggregate(
|
||||
(new MultivariateOnlineSummarizer, new MultivariateOnlineSummarizer)
|
||||
)(seqOp, combOp, $(aggregationDepth))
|
||||
}
|
||||
c._2.add(Vectors.dense(instance.label), instance.weight)),
|
||||
combOp = (c1: (SummarizerBuffer, SummarizerBuffer),
|
||||
c2: (SummarizerBuffer, SummarizerBuffer)) =>
|
||||
(c1._1.merge(c2._1), c1._2.merge(c2._2)),
|
||||
depth = $(aggregationDepth)
|
||||
)
|
||||
|
||||
val yMean = ySummarizer.mean(0)
|
||||
val rawYStd = math.sqrt(ySummarizer.variance(0))
|
||||
|
|
|
@ -230,6 +230,11 @@ private[ml] object SummaryBuilderImpl extends Logging {
|
|||
StructType(fields)
|
||||
}
|
||||
|
||||
private[ml] def createSummarizerBuffer(requested: String*): SummarizerBuffer = {
|
||||
val (metrics, computeMetrics) = getRelevantMetrics(requested)
|
||||
new SummarizerBuffer(metrics, computeMetrics)
|
||||
}
|
||||
|
||||
private val vectorUDT = new VectorUDT
|
||||
|
||||
/**
|
||||
|
@ -277,7 +282,7 @@ private[ml] object SummaryBuilderImpl extends Logging {
|
|||
private[stat] case object ComputeMax extends ComputeMetric
|
||||
private[stat] case object ComputeMin extends ComputeMetric
|
||||
|
||||
private[stat] class SummarizerBuffer(
|
||||
private[ml] class SummarizerBuffer(
|
||||
requestedMetrics: Seq[Metric],
|
||||
requestedCompMetrics: Seq[ComputeMetric]
|
||||
) extends Serializable {
|
||||
|
|
Loading…
Reference in a new issue